Skip to content
Snippets Groups Projects
Unverified Commit 83eb08f8 authored by reusee's avatar reusee Committed by GitHub
Browse files

frontend: set session request contexts (#4646)

frontend: set session request contexts

Approved by: @daviszhen
parent 05add466
No related branches found
No related tags found
No related merge requests found
......@@ -117,7 +117,7 @@ func (res *internalExecResult) StringValueByName(ridx uint64, col string) (strin
func (ie *internalExecutor) Exec(ctx context.Context, sql string, opts ie.SessionOverrideOptions) (err error) {
ie.Lock()
defer ie.Unlock()
sess := ie.newCmdSession(opts)
sess := ie.newCmdSession(ctx, opts)
ie.executor.PrepareSessionBeforeExecRequest(sess)
ie.proto.stashResult = false
return ie.executor.doComQuery(ctx, sql)
......@@ -126,7 +126,7 @@ func (ie *internalExecutor) Exec(ctx context.Context, sql string, opts ie.Sessio
func (ie *internalExecutor) Query(ctx context.Context, sql string, opts ie.SessionOverrideOptions) ie.InternalExecResult {
ie.Lock()
defer ie.Unlock()
sess := ie.newCmdSession(opts)
sess := ie.newCmdSession(ctx, opts)
ie.executor.PrepareSessionBeforeExecRequest(sess)
ie.proto.stashResult = true
err := ie.executor.doComQuery(ctx, sql)
......@@ -135,8 +135,9 @@ func (ie *internalExecutor) Query(ctx context.Context, sql string, opts ie.Sessi
return res
}
func (ie *internalExecutor) newCmdSession(opts ie.SessionOverrideOptions) *Session {
func (ie *internalExecutor) newCmdSession(ctx context.Context, opts ie.SessionOverrideOptions) *Session {
sess := NewSession(ie.proto, guest.New(ie.pu.SV.GuestMmuLimitation, ie.pu.HostMmu), ie.pu.Mempool, ie.pu, gSysVariables)
sess.SetRequestContext(ctx)
applyOverride(sess, ie.baseSessOpts)
applyOverride(sess, opts)
return sess
......
......@@ -51,15 +51,16 @@ func (e *miniExec) PrepareSessionBeforeExecRequest(sess *Session) {
}
func TestIe(t *testing.T) {
ctx := context.TODO()
pu := config.NewParameterUnit(&config.FrontendParameters{}, nil, nil, nil, nil, nil)
executor := newIe(pu, &miniExec{})
executor.ApplySessionOverride(ie.NewOptsBuilder().Username("dump").Finish())
sess := executor.newCmdSession(ie.NewOptsBuilder().Database("mo_catalog").Internal(true).Finish())
sess := executor.newCmdSession(ctx, ie.NewOptsBuilder().Database("mo_catalog").Internal(true).Finish())
assert.Equal(t, "dump", sess.GetMysqlProtocol().GetUserName())
err := executor.Exec(context.TODO(), "whatever", ie.NewOptsBuilder().Finish())
err := executor.Exec(ctx, "whatever", ie.NewOptsBuilder().Finish())
assert.NoError(t, err)
res := executor.Query(context.TODO(), "whatever", ie.NewOptsBuilder().Finish())
res := executor.Query(ctx, "whatever", ie.NewOptsBuilder().Finish())
assert.NoError(t, err)
assert.Equal(t, uint64(0), res.RowCount())
}
......
......@@ -133,6 +133,7 @@ func Test_readTextFile(t *testing.T) {
config.StorageEngine = nil
}()
ses := NewSession(proto, guestMmu, pu.Mempool, pu, gSysVariables)
ses.SetRequestContext(ctx)
mce := NewMysqlCmdExecutor()
......@@ -317,6 +318,7 @@ func Test_readTextFile(t *testing.T) {
guestMmu := guest.New(pu.SV.GuestMmuLimitation, pu.HostMmu)
ses := NewSession(proto, guestMmu, pu.Mempool, pu, gSysVariables)
ses.SetRequestContext(ctx)
mce := NewMysqlCmdExecutor()
......@@ -515,6 +517,7 @@ func Test_readTextFile(t *testing.T) {
guestMmu := guest.New(pu.SV.GuestMmuLimitation, pu.HostMmu)
ses := NewSession(proto, guestMmu, pu.Mempool, pu, gSysVariables)
ses.SetRequestContext(ctx)
mce := NewMysqlCmdExecutor()
......
......@@ -52,8 +52,8 @@ func Test_mce(t *testing.T) {
txnOperator := mock_frontend.NewMockTxnOperator(ctrl)
eng.EXPECT().Database(ctx, gomock.Any(), txnOperator).Return(nil, nil).AnyTimes()
txnOperator.EXPECT().Commit(nil).Return(nil).AnyTimes()
txnOperator.EXPECT().Rollback(nil).Return(nil).AnyTimes()
txnOperator.EXPECT().Commit(gomock.Any()).Return(nil).AnyTimes()
txnOperator.EXPECT().Rollback(gomock.Any()).Return(nil).AnyTimes()
txnClient := mock_frontend.NewMockTxnClient(ctrl)
txnClient.EXPECT().New().Return(txnOperator, nil).AnyTimes()
......@@ -194,6 +194,7 @@ func Test_mce(t *testing.T) {
InitGlobalSystemVariables(&gSys)
ses := NewSession(proto, guestMmu, pu.Mempool, pu, &gSys)
ses.SetRequestContext(ctx)
mce := NewMysqlCmdExecutor()
......@@ -296,6 +297,7 @@ func Test_mce_selfhandle(t *testing.T) {
var gSys GlobalSystemVariables
InitGlobalSystemVariables(&gSys)
ses := NewSession(proto, guestMmu, pu.Mempool, pu, &gSys)
ses.SetRequestContext(ctx)
mce := NewMysqlCmdExecutor()
mce.PrepareSessionBeforeExecRequest(ses)
......@@ -332,6 +334,7 @@ func Test_mce_selfhandle(t *testing.T) {
InitGlobalSystemVariables(&gSys)
ses := NewSession(proto, guestMmu, pu.Mempool, pu, &gSys)
ses.SetRequestContext(ctx)
ses.Mrs = &MysqlResultSet{}
mce := NewMysqlCmdExecutor()
......@@ -434,6 +437,7 @@ func Test_getDataFromPipeline(t *testing.T) {
InitGlobalSystemVariables(&gSys)
ses := NewSession(proto, guestMmu, pu.Mempool, pu, &gSys)
ses.SetRequestContext(ctx)
ses.Mrs = &MysqlResultSet{}
// mce := NewMysqlCmdExecutor()
......@@ -505,6 +509,7 @@ func Test_getDataFromPipeline(t *testing.T) {
InitGlobalSystemVariables(&gSys)
ses := NewSession(proto, guestMmu, pu.Mempool, pu, &gSys)
ses.SetRequestContext(ctx)
ses.Mrs = &MysqlResultSet{}
convey.So(getDataFromPipeline(ses, nil), convey.ShouldBeNil)
......@@ -696,6 +701,7 @@ func Test_handleSelectVariables(t *testing.T) {
var gSys GlobalSystemVariables
InitGlobalSystemVariables(&gSys)
ses := NewSession(proto, nil, nil, pu, &gSys)
ses.SetRequestContext(ctx)
ses.Mrs = &MysqlResultSet{}
mce := &MysqlCmdExecutor{}
mce.PrepareSessionBeforeExecRequest(ses)
......@@ -735,6 +741,7 @@ func Test_handleShowVariables(t *testing.T) {
var gSys GlobalSystemVariables
InitGlobalSystemVariables(&gSys)
ses := NewSession(proto, nil, nil, pu, &gSys)
ses.SetRequestContext(ctx)
ses.Mrs = &MysqlResultSet{}
mce := &MysqlCmdExecutor{}
mce.PrepareSessionBeforeExecRequest(ses)
......@@ -766,6 +773,7 @@ func Test_GetComputationWrapper(t *testing.T) {
}
func Test_handleShowCreateTable(t *testing.T) {
ctx := context.TODO()
convey.Convey("handleShowCreateTable succ", t, func() {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
......@@ -784,6 +792,7 @@ func Test_handleShowCreateTable(t *testing.T) {
var gSys GlobalSystemVariables
InitGlobalSystemVariables(&gSys)
ses := NewSession(proto, guestMmu, pu.Mempool, pu, &gSys)
ses.SetRequestContext(ctx)
ses.Data = make([][]interface{}, 1)
ses.Data[0] = make([]interface{}, showCreateTableAttrCount)
ses.Data[0][tableNamePos] = []byte("tableName")
......@@ -802,6 +811,7 @@ func Test_handleShowCreateTable(t *testing.T) {
}
func Test_handleShowCreateDatabase(t *testing.T) {
ctx := context.TODO()
convey.Convey("handleShowCreateDatabase succ", t, func() {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
......@@ -820,6 +830,7 @@ func Test_handleShowCreateDatabase(t *testing.T) {
var gSys GlobalSystemVariables
InitGlobalSystemVariables(&gSys)
ses := NewSession(proto, guestMmu, pu.Mempool, pu, &gSys)
ses.SetRequestContext(ctx)
ses.Mrs = &MysqlResultSet{}
ses.Mrs.Name2Index = make(map[string]uint64)
......@@ -834,6 +845,7 @@ func Test_handleShowCreateDatabase(t *testing.T) {
}
func Test_handleShowColumns(t *testing.T) {
ctx := context.TODO()
convey.Convey("handleShowColumns succ", t, func() {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
......@@ -852,6 +864,7 @@ func Test_handleShowColumns(t *testing.T) {
var gSys GlobalSystemVariables
InitGlobalSystemVariables(&gSys)
ses := NewSession(proto, guestMmu, pu.Mempool, pu, &gSys)
ses.SetRequestContext(ctx)
ses.Data = make([][]interface{}, 1)
ses.Data[0] = make([]interface{}, primaryKeyPos+1)
ses.Data[0][0] = []byte("col1")
......@@ -889,6 +902,7 @@ func runTestHandle(funName string, t *testing.T, handleFun func(*MysqlCmdExecuto
var gSys GlobalSystemVariables
InitGlobalSystemVariables(&gSys)
ses := NewSession(proto, nil, nil, pu, &gSys)
ses.SetRequestContext(ctx)
ses.Mrs = &MysqlResultSet{}
mce := &MysqlCmdExecutor{}
mce.PrepareSessionBeforeExecRequest(ses)
......@@ -973,6 +987,7 @@ func Test_CMD_FIELD_LIST(t *testing.T) {
var gSys GlobalSystemVariables
InitGlobalSystemVariables(&gSys)
ses := NewSession(proto, nil, nil, pu, &gSys)
ses.SetRequestContext(ctx)
ses.Mrs = &MysqlResultSet{}
ses.SetDatabaseName("t")
mce := &MysqlCmdExecutor{}
......
......@@ -637,11 +637,9 @@ func (th *TxnHandler) CommitTxn() error {
if !th.IsValidTxn() {
return nil
}
var ctx context.Context
if th.ses != nil {
ctx = th.ses.GetRequestContext()
} else {
ctx = context.Background()
ctx := th.ses.GetRequestContext()
if ctx == nil {
panic("context should not be nil")
}
err := th.txn.Commit(ctx)
th.SetInvalid()
......@@ -652,11 +650,9 @@ func (th *TxnHandler) RollbackTxn() error {
if !th.IsValidTxn() {
return nil
}
var ctx context.Context
if th.ses != nil {
ctx = th.ses.GetRequestContext()
} else {
ctx = context.Background()
ctx := th.ses.GetRequestContext()
if ctx == nil {
panic("context should not be nil")
}
err := th.txn.Rollback(ctx)
th.SetInvalid()
......
......@@ -36,7 +36,7 @@ func TestTxnHandler_NewTxn(t *testing.T) {
ctx := context.TODO()
txnOperator := mock_frontend.NewMockTxnOperator(ctrl)
txnOperator.EXPECT().Commit(ctx).Return(nil).AnyTimes()
txnOperator.EXPECT().Commit(gomock.Any()).Return(nil).AnyTimes()
txnClient := mock_frontend.NewMockTxnClient(ctrl)
cnt := 0
txnClient.EXPECT().New().DoAndReturn(
......@@ -51,6 +51,9 @@ func TestTxnHandler_NewTxn(t *testing.T) {
eng := mock_frontend.NewMockEngine(ctrl)
txn := InitTxnHandler(eng, txnClient)
txn.ses = &Session{
requestCtx: ctx,
}
err := txn.NewTxn()
convey.So(err, convey.ShouldBeNil)
err = txn.NewTxn()
......@@ -68,7 +71,7 @@ func TestTxnHandler_CommitTxn(t *testing.T) {
ctx := context.TODO()
txnOperator := mock_frontend.NewMockTxnOperator(ctrl)
cnt := 0
txnOperator.EXPECT().Commit(ctx).DoAndReturn(
txnOperator.EXPECT().Commit(gomock.Any()).DoAndReturn(
func(context.Context) error {
cnt++
if cnt%2 != 0 {
......@@ -84,6 +87,9 @@ func TestTxnHandler_CommitTxn(t *testing.T) {
txnClient.EXPECT().New().Return(txnOperator, nil).AnyTimes()
txn := InitTxnHandler(eng, txnClient)
txn.ses = &Session{
requestCtx: ctx,
}
err := txn.NewTxn()
convey.So(err, convey.ShouldBeNil)
err = txn.CommitTxn()
......@@ -103,7 +109,7 @@ func TestTxnHandler_RollbackTxn(t *testing.T) {
ctx := context.TODO()
txnOperator := mock_frontend.NewMockTxnOperator(ctrl)
cnt := 0
txnOperator.EXPECT().Rollback(ctx).DoAndReturn(
txnOperator.EXPECT().Rollback(gomock.Any()).DoAndReturn(
func(ctc context.Context) error {
cnt++
if cnt%2 != 0 {
......@@ -119,6 +125,9 @@ func TestTxnHandler_RollbackTxn(t *testing.T) {
txnClient.EXPECT().New().Return(txnOperator, nil).AnyTimes()
txn := InitTxnHandler(eng, txnClient)
txn.ses = &Session{
requestCtx: ctx,
}
err := txn.NewTxn()
convey.So(err, convey.ShouldBeNil)
err = txn.RollbackTxn()
......@@ -142,7 +151,9 @@ func TestSession_TxnBegin(t *testing.T) {
proto := NewMysqlClientProtocol(0, ioses, 1024, sv)
txnClient := mock_frontend.NewMockTxnClient(ctrl)
txnClient.EXPECT().New().AnyTimes()
return NewSession(proto, nil, nil, config.NewParameterUnit(&config.FrontendParameters{}, nil, nil, nil, txnClient, nil), gSysVars)
session := NewSession(proto, nil, nil, config.NewParameterUnit(&config.FrontendParameters{}, nil, nil, nil, txnClient, nil), gSysVars)
session.SetRequestContext(context.Background())
return session
}
convey.Convey("new session", t, func() {
ctrl := gomock.NewController(t)
......@@ -184,7 +195,9 @@ func TestVariables(t *testing.T) {
proto := NewMysqlClientProtocol(0, ioses, 1024, sv)
txnClient := mock_frontend.NewMockTxnClient(ctrl)
txnClient.EXPECT().New().AnyTimes()
return NewSession(proto, nil, nil, config.NewParameterUnit(&config.FrontendParameters{}, nil, nil, nil, txnClient, nil), gSysVars)
session := NewSession(proto, nil, nil, config.NewParameterUnit(&config.FrontendParameters{}, nil, nil, nil, txnClient, nil), gSysVars)
session.SetRequestContext(context.Background())
return session
}
checkWant := func(ses, existSes, newSesAfterSession *Session,
......@@ -456,7 +469,9 @@ func TestSession_TxnCompilerContext(t *testing.T) {
t.Error(err)
}
proto := NewMysqlClientProtocol(0, ioses, 1024, sv)
return NewSession(proto, nil, nil, pu, gSysVars)
session := NewSession(proto, nil, nil, pu, gSysVars)
session.SetRequestContext(context.Background())
return session
}
convey.Convey("test", t, func() {
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment