From 76a720f0746b79911a4bd986f297aa3481b2945d Mon Sep 17 00:00:00 2001
From: ou yuanning <45346669+ouyuanning@users.noreply.github.com>
Date: Tue, 16 Aug 2022 13:59:35 +0800
Subject: [PATCH] Fix prepare/execute stmt bug (#4517)

Fix prepare/execute stmt bug

Approved by: @daviszhen, @aressu1985, @aunjgr
---
 pkg/frontend/mysql_cmd_executor.go | 19 ++++++++++++++++---
 pkg/sql/plan/deepcopy.go           | 23 +++++++++++++----------
 test/cases/prepare/prepare.test    | 12 ++++++++++++
 test/result/prepare/prepare.result | 19 +++++++++++++++++++
 4 files changed, 60 insertions(+), 13 deletions(-)

diff --git a/pkg/frontend/mysql_cmd_executor.go b/pkg/frontend/mysql_cmd_executor.go
index e944c3209..dc370de3c 100644
--- a/pkg/frontend/mysql_cmd_executor.go
+++ b/pkg/frontend/mysql_cmd_executor.go
@@ -1397,6 +1397,12 @@ func (mce *MysqlCmdExecutor) handleExplainStmt(stmt *tree.ExplainStmt) error {
 
 // handlePrepareStmt
 func (mce *MysqlCmdExecutor) handlePrepareStmt(st *tree.PrepareStmt) error {
+	switch st.Stmt.(type) {
+	case *tree.Update:
+		mce.ses.GetTxnCompileCtx().SetQueryType(TXN_UPDATE)
+	case *tree.Delete:
+		mce.ses.GetTxnCompileCtx().SetQueryType(TXN_DELETE)
+	}
 	preparePlan, err := buildPlan(mce.ses.txnCompileCtx, st)
 	if err != nil {
 		return err
@@ -1411,15 +1417,21 @@ func (mce *MysqlCmdExecutor) handlePrepareStmt(st *tree.PrepareStmt) error {
 
 // handlePrepareString
 func (mce *MysqlCmdExecutor) handlePrepareString(st *tree.PrepareString) error {
-	preparePlan, err := buildPlan(mce.ses.txnCompileCtx, st)
+	stmts, err := mysql.Parse(st.Sql)
 	if err != nil {
 		return err
 	}
-	stmts, err := mysql.Parse(st.Sql)
+	switch stmts[0].(type) {
+	case *tree.Update:
+		mce.ses.GetTxnCompileCtx().SetQueryType(TXN_UPDATE)
+	case *tree.Delete:
+		mce.ses.GetTxnCompileCtx().SetQueryType(TXN_DELETE)
+	}
+
+	preparePlan, err := buildPlan(mce.ses.txnCompileCtx, st)
 	if err != nil {
 		return err
 	}
-
 	return mce.ses.SetPrepareStmt(preparePlan.GetDcl().GetPrepare().GetName(), &PrepareStmt{
 		Name:        preparePlan.GetDcl().GetPrepare().GetName(),
 		PreparePlan: preparePlan,
@@ -1922,6 +1934,7 @@ func (mce *MysqlCmdExecutor) doComQuery(requestCtx context.Context, sql string)
 		if ret, err = cw.Compile(requestCtx, ses, ses.outputCallback); err != nil {
 			goto handleFailed
 		}
+		stmt = cw.GetAst()
 
 		runner = ret.(ComputationRunner)
 		if ses.Pu.SV.GetRecordTimeElapsedOfSqlRequest() {
diff --git a/pkg/sql/plan/deepcopy.go b/pkg/sql/plan/deepcopy.go
index 4f40b896e..85d7613ea 100644
--- a/pkg/sql/plan/deepcopy.go
+++ b/pkg/sql/plan/deepcopy.go
@@ -113,11 +113,7 @@ func DeepCopyNode(node *plan.Node) *plan.Node {
 					Size:      col.Typ.Size,
 					Scale:     col.Typ.Scale,
 				},
-				Default: &plan.Default{
-					NullAbility:  col.Default.NullAbility,
-					Expr:         DeepCopyExpr(col.Default.Expr),
-					OriginString: col.Default.String(),
-				},
+				Default: DeepCopyDefault(col.Default),
 				Primary: col.Primary,
 				Pkidx:   col.Pkidx,
 			}
@@ -189,6 +185,17 @@ func DeepCopyNode(node *plan.Node) *plan.Node {
 	return newNode
 }
 
+func DeepCopyDefault(def *plan.Default) *plan.Default {
+	if def == nil {
+		return nil
+	}
+	return &plan.Default{
+		NullAbility:  def.NullAbility,
+		Expr:         DeepCopyExpr(def.Expr),
+		OriginString: def.OriginString,
+	}
+}
+
 func DeepCopyTableDef(table *plan.TableDef) *plan.TableDef {
 	newTable := &plan.TableDef{
 		Name: table.Name,
@@ -207,11 +214,7 @@ func DeepCopyTableDef(table *plan.TableDef) *plan.TableDef {
 				Size:      col.Typ.Size,
 				Scale:     col.Typ.Scale,
 			},
-			Default: &plan.Default{
-				NullAbility:  col.Default.NullAbility,
-				Expr:         DeepCopyExpr(col.Default.Expr),
-				OriginString: col.Default.OriginString,
-			},
+			Default: DeepCopyDefault(col.Default),
 			Primary: col.Primary,
 			Pkidx:   col.Pkidx,
 		}
diff --git a/test/cases/prepare/prepare.test b/test/cases/prepare/prepare.test
index e3065ad1a..60c309cfc 100644
--- a/test/cases/prepare/prepare.test
+++ b/test/cases/prepare/prepare.test
@@ -14,6 +14,18 @@ delete from t1 where a > 3;
 execute stmt1 using @a_var;
 deallocate prepare stmt1;
 execute stmt1 using @a_var;
+
+prepare stmt1 from 'update t1 set a=999 where b = ?';
+set @b_var = 33;
+execute stmt1 using @b_var;
+select * from t1;
+deallocate prepare stmt1;
+
+prepare stmt1 from 'delete from t1 where b = ?';
+execute stmt1 using @b_var;
+select * from t1;
+deallocate prepare stmt1;
+
 drop table t1;
 
 prepare stmt2 from 'select @var_t1';
diff --git a/test/result/prepare/prepare.result b/test/result/prepare/prepare.result
index f1622f58f..0cf15e2d9 100644
--- a/test/result/prepare/prepare.result
+++ b/test/result/prepare/prepare.result
@@ -23,6 +23,25 @@ a	b
 deallocate prepare stmt1;
 execute stmt1 using @a_var;
 prepare statement 'stmt1' does not exist
+
+prepare stmt1 from 'update t1 set a=999 where b = ?';
+set @b_var = 33;
+execute stmt1 using @b_var;
+select * from t1;
+a	b
+1	11
+2	22
+999	33
+deallocate prepare stmt1;
+
+prepare stmt1 from 'delete from t1 where b = ?';
+execute stmt1 using @b_var;
+select * from t1;
+a	b
+1	11
+2	22
+deallocate prepare stmt1;
+
 drop table t1;
 
 prepare stmt2 from 'select @var_t1';
-- 
GitLab