diff --git a/pkg/frontend/mysql_cmd_executor.go b/pkg/frontend/mysql_cmd_executor.go index e944c32091def0fe8ae830168b2172226e87dbe3..dc370de3ca60c49853084480c266c7610228db93 100644 --- a/pkg/frontend/mysql_cmd_executor.go +++ b/pkg/frontend/mysql_cmd_executor.go @@ -1397,6 +1397,12 @@ func (mce *MysqlCmdExecutor) handleExplainStmt(stmt *tree.ExplainStmt) error { // handlePrepareStmt func (mce *MysqlCmdExecutor) handlePrepareStmt(st *tree.PrepareStmt) error { + switch st.Stmt.(type) { + case *tree.Update: + mce.ses.GetTxnCompileCtx().SetQueryType(TXN_UPDATE) + case *tree.Delete: + mce.ses.GetTxnCompileCtx().SetQueryType(TXN_DELETE) + } preparePlan, err := buildPlan(mce.ses.txnCompileCtx, st) if err != nil { return err @@ -1411,15 +1417,21 @@ func (mce *MysqlCmdExecutor) handlePrepareStmt(st *tree.PrepareStmt) error { // handlePrepareString func (mce *MysqlCmdExecutor) handlePrepareString(st *tree.PrepareString) error { - preparePlan, err := buildPlan(mce.ses.txnCompileCtx, st) + stmts, err := mysql.Parse(st.Sql) if err != nil { return err } - stmts, err := mysql.Parse(st.Sql) + switch stmts[0].(type) { + case *tree.Update: + mce.ses.GetTxnCompileCtx().SetQueryType(TXN_UPDATE) + case *tree.Delete: + mce.ses.GetTxnCompileCtx().SetQueryType(TXN_DELETE) + } + + preparePlan, err := buildPlan(mce.ses.txnCompileCtx, st) if err != nil { return err } - return mce.ses.SetPrepareStmt(preparePlan.GetDcl().GetPrepare().GetName(), &PrepareStmt{ Name: preparePlan.GetDcl().GetPrepare().GetName(), PreparePlan: preparePlan, @@ -1922,6 +1934,7 @@ func (mce *MysqlCmdExecutor) doComQuery(requestCtx context.Context, sql string) if ret, err = cw.Compile(requestCtx, ses, ses.outputCallback); err != nil { goto handleFailed } + stmt = cw.GetAst() runner = ret.(ComputationRunner) if ses.Pu.SV.GetRecordTimeElapsedOfSqlRequest() { diff --git a/pkg/sql/plan/deepcopy.go b/pkg/sql/plan/deepcopy.go index 4f40b896ebfaffd1bb781ebd9e3a9d7f0f224cf8..85d7613ea622ed662007ebe18cdf22acc95ae76c 100644 --- a/pkg/sql/plan/deepcopy.go +++ b/pkg/sql/plan/deepcopy.go @@ -113,11 +113,7 @@ func DeepCopyNode(node *plan.Node) *plan.Node { Size: col.Typ.Size, Scale: col.Typ.Scale, }, - Default: &plan.Default{ - NullAbility: col.Default.NullAbility, - Expr: DeepCopyExpr(col.Default.Expr), - OriginString: col.Default.String(), - }, + Default: DeepCopyDefault(col.Default), Primary: col.Primary, Pkidx: col.Pkidx, } @@ -189,6 +185,17 @@ func DeepCopyNode(node *plan.Node) *plan.Node { return newNode } +func DeepCopyDefault(def *plan.Default) *plan.Default { + if def == nil { + return nil + } + return &plan.Default{ + NullAbility: def.NullAbility, + Expr: DeepCopyExpr(def.Expr), + OriginString: def.OriginString, + } +} + func DeepCopyTableDef(table *plan.TableDef) *plan.TableDef { newTable := &plan.TableDef{ Name: table.Name, @@ -207,11 +214,7 @@ func DeepCopyTableDef(table *plan.TableDef) *plan.TableDef { Size: col.Typ.Size, Scale: col.Typ.Scale, }, - Default: &plan.Default{ - NullAbility: col.Default.NullAbility, - Expr: DeepCopyExpr(col.Default.Expr), - OriginString: col.Default.OriginString, - }, + Default: DeepCopyDefault(col.Default), Primary: col.Primary, Pkidx: col.Pkidx, } diff --git a/test/cases/prepare/prepare.test b/test/cases/prepare/prepare.test index e3065ad1ae8c50ebf090ec0919ab290ae35c24ea..60c309cfc160dd5b42648d04d0841d4d067b2273 100644 --- a/test/cases/prepare/prepare.test +++ b/test/cases/prepare/prepare.test @@ -14,6 +14,18 @@ delete from t1 where a > 3; execute stmt1 using @a_var; deallocate prepare stmt1; execute stmt1 using @a_var; + +prepare stmt1 from 'update t1 set a=999 where b = ?'; +set @b_var = 33; +execute stmt1 using @b_var; +select * from t1; +deallocate prepare stmt1; + +prepare stmt1 from 'delete from t1 where b = ?'; +execute stmt1 using @b_var; +select * from t1; +deallocate prepare stmt1; + drop table t1; prepare stmt2 from 'select @var_t1'; diff --git a/test/result/prepare/prepare.result b/test/result/prepare/prepare.result index f1622f58f7e64452c7d1cc73e37adf6974322338..0cf15e2d94bd9433f57730ab0a2f5e535deacddf 100644 --- a/test/result/prepare/prepare.result +++ b/test/result/prepare/prepare.result @@ -23,6 +23,25 @@ a b deallocate prepare stmt1; execute stmt1 using @a_var; prepare statement 'stmt1' does not exist + +prepare stmt1 from 'update t1 set a=999 where b = ?'; +set @b_var = 33; +execute stmt1 using @b_var; +select * from t1; +a b +1 11 +2 22 +999 33 +deallocate prepare stmt1; + +prepare stmt1 from 'delete from t1 where b = ?'; +execute stmt1 using @b_var; +select * from t1; +a b +1 11 +2 22 +deallocate prepare stmt1; + drop table t1; prepare stmt2 from 'select @var_t1';