Skip to content
Snippets Groups Projects
Unverified Commit 19d6c553 authored by daviszhen's avatar daviszhen Committed by GitHub
Browse files

connect the context.Context in modules (frontend,metric and trace) (#4494)

In previous design, there is not context.Context going through the frontend and the computation engine.
In 0.6, the computation engine introduces the context.Context.
So, I add the context.Context in the frontend and try to connect context.Context with modules -- metric and trace.
I am not sure the usage of the conext.Context in these modules is right. So I propose the PR.

The hierarchy of the context.Context:
```
rootCtx(context.Background)
|
|
WithCancel()
|
|
\|/
cancelMoServerCtx
|
|
WithValue(configuration)
|
|
moServerCtx
|
|
----------------------------> context in metric, trace
|
|
\|/
cancelRoutineCtx (every network connection has an independent context)
|
|
------> WithCancel----->AuthenticateCtx (login authentication -- user, password, role)
|
|
cancelRequestCtx (every request from the client has a new context)
|
|
-------> WithCancel ------> LoadBatchCtx ( every batch has a independent context)
```

Approved by: @fengttt, @yingfeng, @aptend
parent eb4884a2
No related branches found
No related tags found
No related merge requests found
Showing
with 322 additions and 208 deletions
......@@ -61,13 +61,15 @@ var (
MoVersion = ""
)
func createMOServer() {
func createMOServer(inputCtx context.Context) {
address := fmt.Sprintf("%s:%d", config.GlobalSystemVariables.GetHost(), config.GlobalSystemVariables.GetPort())
pu := config.NewParameterUnit(&config.GlobalSystemVariables, config.HostMmu, config.Mempool, config.StorageEngine, config.ClusterNodes)
mo = frontend.NewMOServer(address, pu)
moServerCtx := context.WithValue(inputCtx, config.ParameterUnitKey, pu)
mo = frontend.NewMOServer(moServerCtx, address, pu)
{
// init trace/log/error framework
if _, err := trace.Init(context.Background(),
if _, err := trace.Init(moServerCtx,
trace.WithMOVersion(MoVersion),
trace.WithNode(0, trace.NodeTypeNode),
trace.EnableTracer(config.GlobalSystemVariables.GetEnableTrace()),
......@@ -80,11 +82,12 @@ func createMOServer() {
panic(err)
}
}
if config.GlobalSystemVariables.GetEnableMetric() {
ieFactory := func() ie.InternalExecutor {
return frontend.NewInternalExecutor(pu)
}
metric.InitMetric(ieFactory, pu, 0, metric.ALL_IN_ONE_MODE)
metric.InitMetric(moServerCtx, ieFactory, pu, 0, metric.ALL_IN_ONE_MODE)
}
frontend.InitServerVersion(MoVersion)
}
......@@ -199,11 +202,14 @@ func main() {
logutil.SetupMOLogger(&logConf)
rootCtx := context.Background()
cancelMoServerCtx, cancelMoServerFunc := context.WithCancel(rootCtx)
//just initialize the tae after configuration has been loaded
if len(args) == 2 && args[1] == "initdb" {
fmt.Println("Initialize the TAE engine ...")
taeWrapper := initTae()
err := frontend.InitDB(taeWrapper.eng)
err := frontend.InitDB(cancelMoServerCtx, taeWrapper.eng)
if err != nil {
logutil.Infof("Initialize catalog failed. error:%v", err)
os.Exit(InitCatalogExit)
......@@ -233,7 +239,7 @@ func main() {
if engineName == "tae" {
fmt.Println("Initialize the TAE engine ...")
tae = initTae()
err := frontend.InitDB(tae.eng)
err := frontend.InitDB(cancelMoServerCtx, tae.eng)
if err != nil {
logutil.Infof("Initialize catalog failed. error:%v", err)
os.Exit(InitCatalogExit)
......@@ -249,7 +255,7 @@ func main() {
os.Exit(StartMOExit)
}
createMOServer()
createMOServer(cancelMoServerCtx)
err := runMOServer()
if err != nil {
......@@ -266,6 +272,9 @@ func main() {
agent.Close()
//cancel mo server
cancelMoServerFunc()
cleanup()
if engineName == "tae" {
......
......@@ -24,7 +24,7 @@ require (
github.com/lni/dragonboat/v4 v4.0.0-20220803152440-a83f853de8b1
github.com/lni/goutils v1.3.1-0.20220604063047-388d67b4dbc4
github.com/matrixorigin/matrixcube v0.3.1-0.20220606032431-c944d801f1e5
github.com/matrixorigin/simdcsv v0.0.0-20210926114300-591bf748a770
github.com/matrixorigin/simdcsv v0.0.0-20220815110641-d4a2c7888f52
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/panjf2000/ants/v2 v2.4.6
github.com/pierrec/lz4/v4 v4.1.14
......
......@@ -15,11 +15,18 @@
package config
import (
"context"
"github.com/matrixorigin/matrixone/pkg/vm/engine"
"github.com/matrixorigin/matrixone/pkg/vm/mempool"
"github.com/matrixorigin/matrixone/pkg/vm/mmu/host"
)
type ConfigurationKeyType int
const (
ParameterUnitKey ConfigurationKeyType = 1
)
var GlobalSystemVariables SystemVariables
// HostMmu host memory
......@@ -59,3 +66,12 @@ func NewParameterUnit(sv *SystemVariables, hostMmu *host.Mmu, mempool *mempool.M
ClusterNodes: clusterNodes,
}
}
// GetParameterUnit gets the configuration from the context.
func GetParameterUnit(ctx context.Context) *ParameterUnit {
pu := ctx.Value(ParameterUnitKey).(*ParameterUnit)
if pu == nil {
panic("parameter unit is invalid")
}
return pu
}
......@@ -14,12 +14,14 @@
package frontend
import "context"
// CmdExecutor handle the command from the client
type CmdExecutor interface {
PrepareSessionBeforeExecRequest(*Session)
// ExecRequest execute the request and get the response
ExecRequest(req *Request) (*Response, error)
ExecRequest(context.Context, *Request) (*Response, error)
Close()
}
......
......@@ -604,7 +604,7 @@ func FillInitialDataForMoUser() *batch.Batch {
}
// InitDB setups the initial catalog tables in tae
func InitDB(tae engine.Engine) error {
func InitDB(ctx context.Context, tae engine.Engine) error {
taeEngine, ok := tae.(moengine.TxnEngine)
if !ok {
return errorIsNotTaeEngine
......@@ -632,7 +632,6 @@ func InitDB(tae engine.Engine) error {
// return err
//}
ctx := context.TODO()
catalogDB, err := tae.Database(ctx, catalogDbName, engine.Snapshot(txnCtx.GetCtx()))
if err != nil {
logutil.Infof("get database %v failed.error:%v", catalogDbName, err)
......@@ -756,11 +755,11 @@ func InitDB(tae engine.Engine) error {
return err
}
return sanityCheck(tae)
return sanityCheck(ctx, tae)
}
// sanityCheck checks the catalog is ready or not
func sanityCheck(tae engine.Engine) error {
func sanityCheck(ctx context.Context, tae engine.Engine) error {
taeEngine, ok := tae.(moengine.TxnEngine)
if !ok {
return errorIsNotTaeEngine
......@@ -770,7 +769,6 @@ func sanityCheck(tae engine.Engine) error {
if err != nil {
return err
}
ctx := context.TODO()
// databases: mo_catalog,information_schema
dbs, err := tae.Databases(ctx, engine.Snapshot(txnCtx.GetCtx()))
if err != nil {
......@@ -792,7 +790,7 @@ func sanityCheck(tae engine.Engine) error {
DefineSchemaForMoUser(),
}
catalogDbName := "mo_catalog"
err = isWantedDatabase(taeEngine, txnCtx, catalogDbName, wantTablesOfMoCatalog, wantSchemasOfCatalog)
err = isWantedDatabase(ctx, taeEngine, txnCtx, catalogDbName, wantTablesOfMoCatalog, wantSchemasOfCatalog)
if err != nil {
return err
}
......@@ -821,9 +819,7 @@ func isWanted(want, actual []string) bool {
}
// isWantedDatabase checks the database has the right tables
func isWantedDatabase(taeEngine moengine.TxnEngine, txnCtx moengine.Txn,
dbName string, tables []string, schemas []*CatalogSchema) error {
ctx := context.TODO()
func isWantedDatabase(ctx context.Context, taeEngine moengine.TxnEngine, txnCtx moengine.Txn, dbName string, tables []string, schemas []*CatalogSchema) error {
db, err := taeEngine.Database(ctx, dbName, engine.Snapshot(txnCtx.GetCtx()))
if err != nil {
logutil.Infof("get database %v failed.error:%v", dbName, err)
......@@ -846,7 +842,7 @@ func isWantedDatabase(taeEngine moengine.TxnEngine, txnCtx moengine.Txn,
//TODO:fix it after tae is ready
//check table attributes
for i, tableName := range tables {
err = isWantedTable(db, txnCtx, tableName, schemas[i])
err = isWantedTable(ctx, db, txnCtx, tableName, schemas[i])
if err != nil {
return err
}
......@@ -856,9 +852,7 @@ func isWantedDatabase(taeEngine moengine.TxnEngine, txnCtx moengine.Txn,
}
// isWantedTable checks the table has the right attributes
func isWantedTable(db engine.Database, txnCtx moengine.Txn,
tableName string, schema *CatalogSchema) error {
ctx := context.TODO()
func isWantedTable(ctx context.Context, db engine.Database, txnCtx moengine.Txn, tableName string, schema *CatalogSchema) error {
table, err := db.Relation(ctx, tableName)
if err != nil {
logutil.Infof("get table %v failed.error:%v", tableName, err)
......
......@@ -15,6 +15,7 @@
package frontend
import (
"context"
"sync"
"github.com/matrixorigin/matrixone/pkg/config"
......@@ -37,7 +38,7 @@ func applyOverride(sess *Session, opts ie.SessionOverrideOptions) {
}
type internalMiniExec interface {
doComQuery(string) error
doComQuery(requestCtx context.Context, sql string) error
PrepareSessionBeforeExecRequest(*Session)
}
......@@ -113,22 +114,22 @@ func (res *internalExecResult) StringValueByName(ridx uint64, col string) (strin
}
}
func (ie *internalExecutor) Exec(sql string, opts ie.SessionOverrideOptions) (err error) {
func (ie *internalExecutor) Exec(ctx context.Context, sql string, opts ie.SessionOverrideOptions) (err error) {
ie.Lock()
defer ie.Unlock()
sess := ie.newCmdSession(opts)
ie.executor.PrepareSessionBeforeExecRequest(sess)
ie.proto.stashResult = false
return ie.executor.doComQuery(sql)
return ie.executor.doComQuery(ctx, sql)
}
func (ie *internalExecutor) Query(sql string, opts ie.SessionOverrideOptions) ie.InternalExecResult {
func (ie *internalExecutor) Query(ctx context.Context, sql string, opts ie.SessionOverrideOptions) ie.InternalExecResult {
ie.Lock()
defer ie.Unlock()
sess := ie.newCmdSession(opts)
ie.executor.PrepareSessionBeforeExecRequest(sess)
ie.proto.stashResult = true
err := ie.executor.doComQuery(sql)
err := ie.executor.doComQuery(ctx, sql)
res := ie.proto.swapOutResult()
res.err = err
return res
......
......@@ -16,6 +16,7 @@ package frontend
import (
"bytes"
"context"
"errors"
"testing"
......@@ -41,7 +42,7 @@ type miniExec struct {
sess *Session
}
func (e *miniExec) doComQuery(string) error {
func (e *miniExec) doComQuery(context.Context, string) error {
_ = e.sess.GetMysqlProtocol()
return nil
}
......@@ -56,9 +57,9 @@ func TestIe(t *testing.T) {
sess := executor.newCmdSession(ie.NewOptsBuilder().Database("mo_catalog").Internal(true).Finish())
assert.Equal(t, "dump", sess.GetMysqlProtocol().GetUserName())
err := executor.Exec("whatever", ie.NewOptsBuilder().Finish())
err := executor.Exec(context.TODO(), "whatever", ie.NewOptsBuilder().Finish())
assert.NoError(t, err)
res := executor.Query("whatever", ie.NewOptsBuilder().Finish())
res := executor.Query(context.TODO(), "whatever", ie.NewOptsBuilder().Finish())
assert.NoError(t, err)
assert.Equal(t, uint64(0), res.RowCount())
}
......
......@@ -114,6 +114,8 @@ type SharePart struct {
//result of load
result *LoadResult
loadCtx context.Context
}
type notifyEventType int
......@@ -173,9 +175,9 @@ type ParseLineHandler struct {
SharePart
DebugTime
threadInfo map[int]*ThreadInfo
simdCsvReader *simdcsv.Reader
closeOnceGetParsedLinesChan sync.Once
threadInfo map[int]*ThreadInfo
simdCsvReader *simdcsv.Reader
//closeOnceGetParsedLinesChan sync.Once
//csv read put lines into the channel
simdCsvGetParsedLinesChan atomic.Value // chan simdcsv.LineOut
//the count of writing routine
......@@ -185,8 +187,6 @@ type ParseLineHandler struct {
simdCsvBatchPool chan *PoolElement
simdCsvNotiyEventChan chan *notifyEvent
closeOnce sync.Once
closeRef *CloseLoadData
}
type WriteBatchHandler struct {
......@@ -198,8 +198,6 @@ type WriteBatchHandler struct {
pl *PoolElement
batchFilled int
simdCsvErr error
closeRef *CloseLoadData
}
type CloseLoadData struct {
......@@ -226,6 +224,50 @@ func getLineOutChan(v atomic.Value) chan simdcsv.LineOut {
return v.Load().(chan simdcsv.LineOut)
}
func (plh *ParseLineHandler) getLineOutCallback(lineOut simdcsv.LineOut) error {
wait_a := time.Now()
defer func() {
AtomicAddDuration(plh.asyncChan, time.Since(wait_a))
}()
wait_d := time.Now()
if lineOut.Line == nil && lineOut.Lines == nil {
return nil
}
if lineOut.Line != nil {
//step 1 : skip dropped lines
if plh.lineCount < plh.load.IgnoredLines {
plh.lineCount++
return nil
}
wait_b := time.Now()
//step 2 : append line into line array
plh.simdCsvLineArray[plh.lineIdx] = lineOut.Line
plh.lineIdx++
plh.lineCount++
plh.maxFieldCnt = Max(plh.maxFieldCnt, len(lineOut.Line))
AtomicAddDuration(plh.csvLineArray1, time.Since(wait_b))
if plh.lineIdx == plh.batchSize {
//logutil.Infof("+++++ batch bytes %v B %v MB",plh.bytes,plh.bytes / 1024.0 / 1024.0)
err := saveLinesToStorage(plh, false)
if err != nil {
return err
}
plh.lineIdx = 0
plh.maxFieldCnt = 0
plh.bytes = 0
}
}
AtomicAddDuration(plh.asyncChanLoop, time.Since(wait_d))
return nil
}
func (plh *ParseLineHandler) getLineOutFromSimdCsvRoutine() error {
wait_a := time.Now()
defer func() {
......@@ -233,20 +275,28 @@ func (plh *ParseLineHandler) getLineOutFromSimdCsvRoutine() error {
}()
var lineOut simdcsv.LineOut
var status bool
for {
fmt.Println("11111111")
quit := false
select {
case <-plh.closeRef.stopLoadData:
case <-plh.loadCtx.Done():
logutil.Infof("----- get stop in getLineOutFromSimdCsvRoutine")
quit = true
case lineOut = <-getLineOutChan(plh.simdCsvGetParsedLinesChan):
case lineOut, status = <-getLineOutChan(plh.simdCsvGetParsedLinesChan):
if !status {
quit = true
}
}
fmt.Println("xxxxxx")
fmt.Println(lineOut)
if quit {
break
}
wait_d := time.Now()
if lineOut.Line == nil && lineOut.Lines == nil {
fmt.Println("tttttttttt")
break
}
if lineOut.Line != nil {
......@@ -308,15 +358,14 @@ func AtomicAddDuration(v atomic.Value, t interface{}) {
}
func (plh *ParseLineHandler) close() {
plh.closeOnceGetParsedLinesChan.Do(func() {
close(getLineOutChan(plh.simdCsvGetParsedLinesChan))
})
//plh.closeOnceGetParsedLinesChan.Do(func() {
// close(getLineOutChan(plh.simdCsvGetParsedLinesChan))
//})
plh.closeOnce.Do(func() {
close(plh.simdCsvBatchPool)
close(plh.simdCsvNotiyEventChan)
plh.simdCsvReader.Close()
})
plh.closeRef.Close()
}
/*
......@@ -388,13 +437,12 @@ func makeBatch(handler *ParseLineHandler, id int) *PoolElement {
/*
Init ParseLineHandler
*/
func initParseLineHandler(handler *ParseLineHandler) error {
func initParseLineHandler(requestCtx context.Context, handler *ParseLineHandler) error {
relation := handler.tableHandler
load := handler.load
var cols []*engine.AttributeDef = nil
ctx := context.TODO()
defs, err := relation.TableDefs(ctx)
defs, err := relation.TableDefs(requestCtx)
if err != nil {
return err
}
......@@ -496,9 +544,9 @@ func initWriteBatchHandler(handler *ParseLineHandler, wHandler *WriteBatchHandle
wHandler.oneTxnPerBatch = handler.oneTxnPerBatch
wHandler.timestamp = handler.timestamp
wHandler.result = &LoadResult{}
wHandler.closeRef = handler.closeRef
wHandler.lineCount = handler.lineCount
wHandler.skipWriteBatch = handler.skipWriteBatch
wHandler.loadCtx = handler.loadCtx
wHandler.pl = allocBatch(handler)
wHandler.ThreadInfo = handler.threadInfo[wHandler.pl.id]
......@@ -1662,7 +1710,7 @@ when force is true, batchsize will be changed.
func writeBatchToStorage(handler *WriteBatchHandler, force bool) error {
var err error = nil
ctx := context.TODO()
ctx := handler.loadCtx
if handler.batchFilled == handler.batchSize {
//batchBytes := 0
//for _, vec := range handler.batchData.Vecs {
......@@ -1686,7 +1734,8 @@ func writeBatchToStorage(handler *WriteBatchHandler, force bool) error {
var txnHandler *TxnHandler
tableHandler := handler.tableHandler
initSes := handler.ses
tmpSes := NewSession(initSes.GetMysqlProtocol(), initSes.GuestMmu, initSes.Mempool, initSes.Pu, gSysVariables)
tmpSes := NewBackgroundSession(ctx, initSes.GuestMmu, initSes.Mempool, initSes.Pu, gSysVariables)
defer tmpSes.Close()
if !handler.skipWriteBatch {
if handler.oneTxnPerBatch {
txnHandler = tmpSes.GetTxnHandler()
......@@ -1832,7 +1881,8 @@ func writeBatchToStorage(handler *WriteBatchHandler, force bool) error {
tableHandler := handler.tableHandler
// dbHandler := handler.dbHandler
initSes := handler.ses
tmpSes := NewSession(initSes.GetMysqlProtocol(), initSes.GuestMmu, initSes.Mempool, initSes.Pu, gSysVariables)
tmpSes := NewBackgroundSession(ctx, initSes.GuestMmu, initSes.Mempool, initSes.Pu, gSysVariables)
defer tmpSes.Close()
var dbHandler engine.Database
if !handler.skipWriteBatch {
if handler.oneTxnPerBatch {
......@@ -1947,10 +1997,9 @@ func PrintThreadInfo(handler *ParseLineHandler, close *CloseFlag, a time.Duratio
/*
LoadLoop reads data from stream, extracts the fields, and saves into the table
*/
func (mce *MysqlCmdExecutor) LoadLoop(load *tree.Load, dbHandler engine.Database, tableHandler engine.Relation, dbName string) (*LoadResult, error) {
func (mce *MysqlCmdExecutor) LoadLoop(requestCtx context.Context, load *tree.Load, dbHandler engine.Database, tableHandler engine.Relation, dbName string) (*LoadResult, error) {
ses := mce.GetSession()
var m sync.Mutex
//begin:= time.Now()
//defer func() {
// logutil.Infof("-----load loop exit %s",time.Since(begin))
......@@ -1996,6 +2045,7 @@ func (mce *MysqlCmdExecutor) LoadLoop(load *tree.Load, dbHandler engine.Database
batchSize: curBatchSize,
result: result,
skipWriteBatch: ses.Pu.SV.GetLoadDataSkipWritingBatch(),
loadCtx: requestCtx,
},
threadInfo: make(map[int]*ThreadInfo),
simdCsvGetParsedLinesChan: atomic.Value{},
......@@ -2026,14 +2076,6 @@ func (mce *MysqlCmdExecutor) LoadLoop(load *tree.Load, dbHandler engine.Database
notifyChanSize := handler.simdCsvConcurrencyCountOfWriteBatch * 2
notifyChanSize = Max(100, notifyChanSize)
/*
make close reference
*/
handler.closeRef = NewCloseLoadData()
//put closeRef into the executor
mce.loadDataClose = handler.closeRef
handler.simdCsvReader = simdcsv.NewReaderWithOptions(dataFile,
rune(load.Fields.Terminated[0]),
'#',
......@@ -2048,7 +2090,7 @@ func (mce *MysqlCmdExecutor) LoadLoop(load *tree.Load, dbHandler engine.Database
//release resources of handler
defer handler.close()
err = initParseLineHandler(handler)
err = initParseLineHandler(requestCtx, handler)
if err != nil {
return nil, err
}
......@@ -2067,15 +2109,18 @@ func (mce *MysqlCmdExecutor) LoadLoop(load *tree.Load, dbHandler engine.Database
read from the output channel of the simdcsv parser, make a batch,
deliver it to async routine writing batch
*/
wg.Add(1)
go func() {
defer wg.Done()
err := handler.getLineOutFromSimdCsvRoutine()
if err != nil {
logutil.Errorf("get line from simdcsv failed. err:%v", err)
handler.simdCsvNotiyEventChan <- newNotifyEvent(NOTIFY_EVENT_OUTPUT_SIMDCSV_ERROR, err, nil)
}
}()
//wg.Add(1)
//go func() {
// defer wg.Done()
// fmt.Println("ccccccccc")
// //TODO:remove it
// err := handler.getLineOutFromSimdCsvRoutine()
// fmt.Println("ddddddddd")
// if err != nil {
// logutil.Errorf("get line from simdcsv failed. err:%v", err)
// handler.simdCsvNotiyEventChan <- newNotifyEvent(NOTIFY_EVENT_OUTPUT_SIMDCSV_ERROR, err, nil)
// }
//}()
/*
get lines from simdcsv, deliver them to the output channel.
......@@ -2084,10 +2129,17 @@ func (mce *MysqlCmdExecutor) LoadLoop(load *tree.Load, dbHandler engine.Database
go func() {
defer wg.Done()
wait_b := time.Now()
m.Lock()
defer m.Unlock()
err := handler.simdCsvReader.ReadLoop(getLineOutChan(handler.simdCsvGetParsedLinesChan))
fmt.Println("aaaaaaaaaa")
//TODO: add a output callback
//TODO: remove the channel
err = handler.simdCsvReader.ReadLoop(requestCtx, nil, handler.getLineOutCallback)
//last batch
err = saveLinesToStorage(handler, true)
if err != nil {
logutil.Errorf("get line from simdcsv failed. err:%v", err)
handler.simdCsvNotiyEventChan <- newNotifyEvent(NOTIFY_EVENT_OUTPUT_SIMDCSV_ERROR, err, nil)
}
fmt.Println("bbbbbbbbbb")
if err != nil {
handler.simdCsvNotiyEventChan <- newNotifyEvent(NOTIFY_EVENT_READ_SIMDCSV_ERROR, err, nil)
}
......@@ -2108,12 +2160,10 @@ func (mce *MysqlCmdExecutor) LoadLoop(load *tree.Load, dbHandler engine.Database
for {
quit := false
select {
case <-handler.closeRef.stopLoadData:
//get obvious cancel
case <-requestCtx.Done():
logutil.Info("cancel the load")
retErr = NewMysqlError(ER_QUERY_INTERRUPTED)
quit = true
//logutil.Infof("----- get stop in load ")
case ne = <-handler.simdCsvNotiyEventChan:
switch ne.neType {
case NOTIFY_EVENT_WRITE_BATCH_RESULT:
......@@ -2136,13 +2186,8 @@ func (mce *MysqlCmdExecutor) LoadLoop(load *tree.Load, dbHandler engine.Database
}
if quit {
//
handler.simdCsvReader.Close()
handler.closeOnceGetParsedLinesChan.Do(func() {
m.Lock()
defer m.Unlock()
close(getLineOutChan(handler.simdCsvGetParsedLinesChan))
})
go func() {
for closechannel.IsOpened() {
select {
......@@ -2179,6 +2224,5 @@ func (mce *MysqlCmdExecutor) LoadLoop(load *tree.Load, dbHandler engine.Database
statsWg.Wait()
close.Close()
closechannel.Close()
return result, retErr
}
......@@ -18,6 +18,11 @@ import (
"context"
"errors"
"fmt"
"github.com/matrixorigin/matrixone/pkg/config"
"github.com/matrixorigin/matrixone/pkg/sql/parsers"
"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect"
"github.com/matrixorigin/matrixone/pkg/vm/engine"
"github.com/matrixorigin/matrixone/pkg/vm/mmu/guest"
"os"
"sync/atomic"
"testing"
......@@ -25,17 +30,12 @@ import (
"github.com/fagongzi/goetty/v2/buf"
"github.com/golang/mock/gomock"
"github.com/matrixorigin/matrixone/pkg/config"
"github.com/matrixorigin/matrixone/pkg/container/batch"
"github.com/matrixorigin/matrixone/pkg/container/nulls"
"github.com/matrixorigin/matrixone/pkg/container/types"
"github.com/matrixorigin/matrixone/pkg/container/vector"
mock_frontend "github.com/matrixorigin/matrixone/pkg/frontend/test"
"github.com/matrixorigin/matrixone/pkg/sql/parsers"
"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect"
"github.com/matrixorigin/matrixone/pkg/sql/parsers/tree"
"github.com/matrixorigin/matrixone/pkg/vm/engine"
"github.com/matrixorigin/matrixone/pkg/vm/mmu/guest"
"github.com/matrixorigin/simdcsv"
"github.com/prashantv/gostub"
"github.com/smartystreets/goconvey/convey"
......@@ -230,7 +230,7 @@ func Test_load(t *testing.T) {
data: []byte("test anywhere"),
}
resp, err := mce.ExecRequest(req)
resp, err := mce.ExecRequest(ctx, req)
convey.So(err, convey.ShouldBeNil)
convey.So(resp, convey.ShouldBeNil)
})
......@@ -434,7 +434,7 @@ func Test_load(t *testing.T) {
row2col = gostub.Stub(&row2colChoose, false)
}
_, err := mce.LoadLoop(cws[i], db, rel, "T")
_, err := mce.LoadLoop(context.TODO(), cws[i], db, rel, "T")
if kases[i].fail {
convey.So(err, convey.ShouldBeError)
} else {
......@@ -458,20 +458,23 @@ func getParsedLinesChan(simdCsvGetParsedLinesChan chan simdcsv.LineOut) {
func Test_getLineOutFromSimdCsvRoutine(t *testing.T) {
convey.Convey("getLineOutFromSimdCsvRoutine succ", t, func() {
handler := &ParseLineHandler{
closeRef: &CloseLoadData{stopLoadData: make(chan interface{}, 1)},
simdCsvGetParsedLinesChan: atomic.Value{},
SharePart: SharePart{
load: &tree.Load{IgnoredLines: 1},
simdCsvLineArray: make([][]string, 100)},
}
handler.simdCsvGetParsedLinesChan.Store(make(chan simdcsv.LineOut, 100))
handler.closeRef.stopLoadData <- 1
gostub.StubFunc(&saveLinesToStorage, nil)
var cancel context.CancelFunc
handler.loadCtx, cancel = context.WithTimeout(context.TODO(), time.Second)
convey.So(handler.getLineOutFromSimdCsvRoutine(), convey.ShouldBeNil)
cancel()
handler.closeRef.stopLoadData <- 1
gostub.StubFunc(&saveLinesToStorage, errors.New("1"))
handler.loadCtx, cancel = context.WithTimeout(context.TODO(), time.Second)
convey.So(handler.getLineOutFromSimdCsvRoutine(), convey.ShouldNotBeNil)
cancel()
getParsedLinesChan(getLineOutChan(handler.simdCsvGetParsedLinesChan))
stubs := gostub.StubFunc(&saveLinesToStorage, nil)
......
......@@ -24,7 +24,6 @@ import (
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/matrixorigin/matrixone/pkg/common/moerr"
......@@ -89,22 +88,14 @@ type MysqlCmdExecutor struct {
//the count of sql has been processed
sqlCount uint64
//for load data closing
loadDataClose *CloseLoadData
//for export data closing
exportDataClose *CloseExportData
ses *Session
sessionRWLock sync.RWMutex
routineMgr *RoutineManager
cancelRequestFunc context.CancelFunc
}
func (mce *MysqlCmdExecutor) PrepareSessionBeforeExecRequest(ses *Session) {
mce.sessionRWLock.Lock()
defer mce.sessionRWLock.Unlock()
mce.ses = ses
}
......@@ -490,7 +481,7 @@ func getDataFromPipeline(obj interface{}, bat *batch.Batch) error {
for j := 0; j < n; j++ { //row index
if oq.ep.Outfile {
select {
case <-ses.closeRef.stopExportData:
case <-ses.requestCtx.Done():
{
return nil
}
......@@ -849,13 +840,12 @@ func extractRowFromVector(vec *vector.Vector, i int, row []interface{}, rowIndex
return nil
}
func (mce *MysqlCmdExecutor) handleChangeDB(db string) error {
func (mce *MysqlCmdExecutor) handleChangeDB(requestCtx context.Context, db string) error {
ses := mce.GetSession()
txnHandler := ses.GetTxnHandler()
txnCtx := txnHandler.GetTxn().GetCtx()
//TODO: check meta data
ctx := context.TODO()
if _, err := ses.Pu.StorageEngine.Database(ctx, db, engine.Snapshot(txnCtx)); err != nil {
if _, err := ses.Pu.StorageEngine.Database(requestCtx, db, engine.Snapshot(txnCtx)); err != nil {
//echo client. no such database
return NewMysqlError(ER_BAD_DB_ERROR, db)
}
......@@ -918,7 +908,7 @@ func (mce *MysqlCmdExecutor) handleSelectVariables(ve *tree.VarExpr) error {
/*
handle Load DataSource statement
*/
func (mce *MysqlCmdExecutor) handleLoadData(load *tree.Load) error {
func (mce *MysqlCmdExecutor) handleLoadData(requestCtx context.Context, load *tree.Load) error {
var err error
ses := mce.GetSession()
proto := ses.protocol
......@@ -965,12 +955,11 @@ func (mce *MysqlCmdExecutor) handleLoadData(load *tree.Load) error {
loadDb = ses.protocol.GetDatabaseName()
}
ctx := context.TODO()
txnHandler := ses.GetTxnHandler()
if ses.InMultiStmtTransactionMode() {
return fmt.Errorf("do not support the Load in a transaction started by BEGIN/START TRANSACTION statement")
}
dbHandler, err := ses.GetStorage().Database(ctx, loadDb, engine.Snapshot(txnHandler.GetTxn().GetCtx()))
dbHandler, err := ses.GetStorage().Database(requestCtx, loadDb, engine.Snapshot(txnHandler.GetTxn().GetCtx()))
if err != nil {
//echo client. no such database
return NewMysqlError(ER_BAD_DB_ERROR, loadDb)
......@@ -986,7 +975,7 @@ func (mce *MysqlCmdExecutor) handleLoadData(load *tree.Load) error {
/*
check table
*/
tableHandler, err := dbHandler.Relation(ctx, loadTable)
tableHandler, err := dbHandler.Relation(requestCtx, loadTable)
if err != nil {
//echo client. no such table
return NewMysqlError(ER_NO_SUCH_TABLE, loadDb, loadTable)
......@@ -995,7 +984,7 @@ func (mce *MysqlCmdExecutor) handleLoadData(load *tree.Load) error {
/*
execute load data
*/
result, err := mce.LoadLoop(load, dbHandler, tableHandler, loadDb)
result, err := mce.LoadLoop(requestCtx, load, dbHandler, tableHandler, loadDb)
if err != nil {
return err
}
......@@ -1014,14 +1003,12 @@ func (mce *MysqlCmdExecutor) handleLoadData(load *tree.Load) error {
/*
handle cmd CMD_FIELD_LIST
*/
func (mce *MysqlCmdExecutor) handleCmdFieldList(icfl *InternalCmdFieldList) error {
func (mce *MysqlCmdExecutor) handleCmdFieldList(requestCtx context.Context, icfl *InternalCmdFieldList) error {
var err error
ses := mce.GetSession()
proto := ses.GetMysqlProtocol()
tableName := icfl.tableName
ctx := context.TODO()
dbName := ses.GetDatabaseName()
if dbName == "" {
return NewMysqlError(ER_NO_DB_ERROR)
......@@ -1035,22 +1022,22 @@ func (mce *MysqlCmdExecutor) handleCmdFieldList(icfl *InternalCmdFieldList) erro
if mce.tableInfos == nil || mce.db != dbName {
txnHandler := ses.GetTxnHandler()
eng := ses.GetStorage()
db, err := eng.Database(ctx, dbName, engine.Snapshot(txnHandler.GetTxn().GetCtx()))
db, err := eng.Database(requestCtx, dbName, engine.Snapshot(txnHandler.GetTxn().GetCtx()))
if err != nil {
return err
}
names, err := db.Relations(ctx)
names, err := db.Relations(requestCtx)
if err != nil {
return err
}
for _, name := range names {
table, err := db.Relation(ctx, name)
table, err := db.Relation(requestCtx, name)
if err != nil {
return err
}
defs, err := table.TableDefs(ctx)
defs, err := table.TableDefs(requestCtx)
if err != nil {
return err
}
......@@ -1270,7 +1257,7 @@ func (mce *MysqlCmdExecutor) handleShowVariables(sv *tree.ShowVariables) error {
return err
}
func (mce *MysqlCmdExecutor) handleAnalyzeStmt(stmt *tree.AnalyzeStmt) error {
func (mce *MysqlCmdExecutor) handleAnalyzeStmt(requestCtx context.Context, stmt *tree.AnalyzeStmt) error {
// rewrite analyzeStmt to `select approx_count_distinct(col), .. from tbl`
// IMO, this approach is simple and future-proof
// Although this rewriting processing could have been handled in rewrite module,
......@@ -1288,7 +1275,7 @@ func (mce *MysqlCmdExecutor) handleAnalyzeStmt(stmt *tree.AnalyzeStmt) error {
ctx.WriteString(" from ")
stmt.Table.Format(ctx)
sql := ctx.String()
return mce.doComQuery(sql)
return mce.doComQuery(requestCtx, sql)
}
// Note: for pass the compile quickly. We will remove the comments in the future.
......@@ -1552,7 +1539,7 @@ func (cwft *TxnComputationWrapper) GetAffectedRows() uint64 {
return cwft.compile.GetAffectedRows()
}
func (cwft *TxnComputationWrapper) Compile(u interface{}, fill func(interface{}, *batch.Batch) error) (interface{}, error) {
func (cwft *TxnComputationWrapper) Compile(requestCtx context.Context, u interface{}, fill func(interface{}, *batch.Batch) error) (interface{}, error) {
var err error
cwft.plan, err = buildPlan(cwft.ses.GetTxnCompilerContext(), cwft.stmt)
if err != nil {
......@@ -1609,7 +1596,7 @@ func (cwft *TxnComputationWrapper) Compile(u interface{}, fill func(interface{},
cwft.proc.UnixTime = time.Now().UnixNano()
txnHandler := cwft.ses.GetTxnHandler()
cwft.proc.Snapshot = txnHandler.GetTxn().GetCtx()
cwft.compile = compile.New(cwft.ses.GetDatabaseName(), cwft.ses.GetSql(), cwft.ses.GetUserName(), context.TODO(), cwft.ses.GetStorage(), cwft.proc, cwft.stmt)
cwft.compile = compile.New(cwft.ses.GetDatabaseName(), cwft.ses.GetSql(), cwft.ses.GetUserName(), requestCtx, cwft.ses.GetStorage(), cwft.proc, cwft.stmt)
err = cwft.compile.Compile(cwft.plan, cwft.ses, fill)
if err != nil {
return nil, err
......@@ -1718,7 +1705,7 @@ func (mce *MysqlCmdExecutor) afterRun(stmt tree.Statement, beginInstant time.Tim
}
// execute query
func (mce *MysqlCmdExecutor) doComQuery(sql string) (retErr error) {
func (mce *MysqlCmdExecutor) doComQuery(requestCtx context.Context, sql string) (retErr error) {
beginInstant := time.Now()
ses := mce.GetSession()
ses.showStmtType = NotShowStatement
......@@ -1822,14 +1809,11 @@ func (mce *MysqlCmdExecutor) doComQuery(sql string) (retErr error) {
switch st := stmt.(type) {
case *tree.Select:
if st.Ep != nil {
mce.exportDataClose = NewCloseExportData()
ses.ep = st.Ep
ses.closeRef = mce.exportDataClose
}
}
selfHandle = false
ses.GetTxnCompileCtx().SetQueryType(TXN_DEFAULT)
switch st := stmt.(type) {
......@@ -1841,7 +1825,7 @@ func (mce *MysqlCmdExecutor) doComQuery(sql string) (retErr error) {
}
case *tree.Use:
selfHandle = true
err = mce.handleChangeDB(st.Name)
err = mce.handleChangeDB(requestCtx, st.Name)
if err != nil {
goto handleFailed
}
......@@ -1857,7 +1841,7 @@ func (mce *MysqlCmdExecutor) doComQuery(sql string) (retErr error) {
case *tree.Load:
fromLoadData = true
selfHandle = true
err = mce.handleLoadData(st)
err = mce.handleLoadData(requestCtx, st)
if err != nil {
goto handleFailed
}
......@@ -1895,7 +1879,7 @@ func (mce *MysqlCmdExecutor) doComQuery(sql string) (retErr error) {
}
case *tree.AnalyzeStmt:
selfHandle = true
if err = mce.handleAnalyzeStmt(st); err != nil {
if err = mce.handleAnalyzeStmt(requestCtx, st); err != nil {
goto handleFailed
}
case *tree.ExplainStmt:
......@@ -1921,7 +1905,7 @@ func (mce *MysqlCmdExecutor) doComQuery(sql string) (retErr error) {
ses.GetTxnCompileCtx().SetQueryType(TXN_UPDATE)
case *InternalCmdFieldList:
selfHandle = true
if err = mce.handleCmdFieldList(st); err != nil {
if err = mce.handleCmdFieldList(requestCtx, st); err != nil {
goto handleFailed
}
}
......@@ -1935,7 +1919,7 @@ func (mce *MysqlCmdExecutor) doComQuery(sql string) (retErr error) {
cmpBegin = time.Now()
if ret, err = cw.Compile(ses, ses.outputCallback); err != nil {
if ret, err = cw.Compile(requestCtx, ses, ses.outputCallback); err != nil {
goto handleFailed
}
......@@ -2112,7 +2096,7 @@ func (mce *MysqlCmdExecutor) doComQuery(sql string) (retErr error) {
}
// ExecRequest the server execute the commands from the client following the mysql's routine
func (mce *MysqlCmdExecutor) ExecRequest(req *Request) (resp *Response, err error) {
func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, req *Request) (resp *Response, err error) {
defer func() {
if e := recover(); e != nil {
err = moerr.NewPanicError(e)
......@@ -2144,6 +2128,14 @@ func (mce *MysqlCmdExecutor) ExecRequest(req *Request) (resp *Response, err erro
if strings.ToLower(seps[0]) == "kill" {
//last one is processID
/*
The 'kill query xxx' is processed in an independent connection.
When a 'Ctrl+C' is received from the user in mysql client shell,
an independent connection is established and the 'kill query xxx'
is sent to the server. The server cancels the 'query xxx' after it
receives the 'kill query xxx'. The server responses the OK.
Then, the client quit this connection.
*/
procIdStr := seps[len(seps)-1]
procID, err := strconv.ParseUint(procIdStr, 10, 64)
if err != nil {
......@@ -2159,7 +2151,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(req *Request) (resp *Response, err erro
return resp, nil
}
err := mce.doComQuery(query)
err := mce.doComQuery(requestCtx, query)
if err != nil {
resp = NewGeneralErrorResponse(COM_QUERY, err)
}
......@@ -2168,7 +2160,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(req *Request) (resp *Response, err erro
var dbname = string(req.GetData().([]byte))
mce.addSqlCount(1)
query := "use `" + dbname + "`"
err := mce.doComQuery(query)
err := mce.doComQuery(requestCtx, query)
if err != nil {
resp = NewGeneralErrorResponse(COM_INIT_DB, err)
}
......@@ -2178,7 +2170,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(req *Request) (resp *Response, err erro
var payload = string(req.GetData().([]byte))
mce.addSqlCount(1)
query := makeCmdFieldListSql(payload)
err := mce.doComQuery(query)
err := mce.doComQuery(requestCtx, query)
if err != nil {
resp = NewGeneralErrorResponse(COM_FIELD_LIST, err)
}
......@@ -2195,18 +2187,23 @@ func (mce *MysqlCmdExecutor) ExecRequest(req *Request) (resp *Response, err erro
return resp, nil
}
func (mce *MysqlCmdExecutor) setCancelRequestFunc(cancelFunc context.CancelFunc) {
mce.cancelRequestFunc = cancelFunc
}
func (mce *MysqlCmdExecutor) getCancelRequestFunc() context.CancelFunc {
return mce.cancelRequestFunc
}
func (mce *MysqlCmdExecutor) Close() {
//logutil.Infof("close executor")
if mce.loadDataClose != nil {
//logutil.Infof("close process load data")
mce.loadDataClose.Close()
cancelRequestFunc := mce.getCancelRequestFunc()
if cancelRequestFunc != nil {
cancelRequestFunc()
}
if mce.exportDataClose != nil {
mce.exportDataClose.Close()
}
mce.sessionRWLock.Lock()
defer mce.sessionRWLock.Unlock()
err := mce.ses.TxnRollback()
fmt.Println("----close mce")
ses := mce.GetSession()
err := ses.TxnRollback()
if err != nil {
logutil.Errorf("rollback txn in mce.Close failed.error:%v", err)
}
......
......@@ -207,7 +207,7 @@ func Test_mce(t *testing.T) {
data: []byte("test anywhere"),
}
resp, err := mce.ExecRequest(req)
resp, err := mce.ExecRequest(ctx, req)
convey.So(err, convey.ShouldBeNil)
convey.So(resp, convey.ShouldBeNil)
......@@ -215,7 +215,7 @@ func Test_mce(t *testing.T) {
cmd: int(COM_QUERY),
data: []byte("kill"),
}
resp, err = mce.ExecRequest(req)
resp, err = mce.ExecRequest(ctx, req)
convey.So(err, convey.ShouldBeNil)
convey.So(resp, convey.ShouldNotBeNil)
......@@ -224,7 +224,7 @@ func Test_mce(t *testing.T) {
data: []byte("kill 10"),
}
mce.SetRoutineManager(&RoutineManager{})
resp, err = mce.ExecRequest(req)
resp, err = mce.ExecRequest(ctx, req)
convey.So(err, convey.ShouldBeNil)
convey.So(resp, convey.ShouldNotBeNil)
......@@ -233,7 +233,7 @@ func Test_mce(t *testing.T) {
data: []byte("test anywhere"),
}
_, err = mce.ExecRequest(req)
_, err = mce.ExecRequest(ctx, req)
convey.So(err, convey.ShouldBeNil)
//COM_INIT_DB replaced by changeDB()
//convey.So(resp.category, convey.ShouldEqual, OkResponse)
......@@ -243,7 +243,7 @@ func Test_mce(t *testing.T) {
data: []byte("test anywhere"),
}
resp, err = mce.ExecRequest(req)
resp, err = mce.ExecRequest(ctx, req)
convey.So(err, convey.ShouldBeNil)
convey.So(resp.category, convey.ShouldEqual, OkResponse)
......@@ -252,7 +252,7 @@ func Test_mce(t *testing.T) {
data: []byte("test anywhere"),
}
resp, err = mce.ExecRequest(req)
resp, err = mce.ExecRequest(ctx, req)
convey.So(err, convey.ShouldBeNil)
convey.So(resp, convey.ShouldBeNil)
......@@ -314,11 +314,11 @@ func Test_mce_selfhandle(t *testing.T) {
mce := NewMysqlCmdExecutor()
mce.PrepareSessionBeforeExecRequest(ses)
err = mce.handleChangeDB("T")
err = mce.handleChangeDB(ctx, "T")
convey.So(err, convey.ShouldBeNil)
convey.So(ses.protocol.GetDatabaseName(), convey.ShouldEqual, "T")
err = mce.handleChangeDB("T")
err = mce.handleChangeDB(ctx, "T")
convey.So(err, convey.ShouldBeError)
})
......@@ -390,7 +390,7 @@ func Test_mce_selfhandle(t *testing.T) {
query := string(queryData)
cflStmt, err := parseCmdFieldList(makeCmdFieldListSql(query))
convey.So(err, convey.ShouldBeNil)
err = mce.handleCmdFieldList(cflStmt)
err = mce.handleCmdFieldList(ctx, cflStmt)
convey.So(err, convey.ShouldBeError)
ses.Mrs = &MysqlResultSet{}
......@@ -401,11 +401,11 @@ func Test_mce_selfhandle(t *testing.T) {
typ: types.Type{Oid: types.T_varchar},
}}
err = mce.handleCmdFieldList(cflStmt)
err = mce.handleCmdFieldList(ctx, cflStmt)
convey.So(err, convey.ShouldBeNil)
mce.db = ses.protocol.GetDatabaseName()
err = mce.handleCmdFieldList(cflStmt)
err = mce.handleCmdFieldList(ctx, cflStmt)
convey.So(err, convey.ShouldBeNil)
set := "set @@tx_isolation=`READ-COMMITTED`"
......@@ -420,7 +420,7 @@ func Test_mce_selfhandle(t *testing.T) {
data: []byte{'A', 0},
}
resp, err := mce.ExecRequest(req)
resp, err := mce.ExecRequest(ctx, req)
convey.So(err, convey.ShouldBeNil)
convey.So(resp, convey.ShouldBeNil)
})
......@@ -1004,7 +1004,7 @@ func Test_CMD_FIELD_LIST(t *testing.T) {
mce := &MysqlCmdExecutor{}
mce.PrepareSessionBeforeExecRequest(ses)
err = mce.doComQuery(cmdFieldListQuery)
err = mce.doComQuery(ctx, cmdFieldListQuery)
convey.So(err, convey.ShouldBeNil)
})
}
......
......@@ -16,6 +16,7 @@ package frontend
import (
"bytes"
"context"
"database/sql"
"encoding/binary"
"errors"
......@@ -53,7 +54,7 @@ func (tRM *TestRoutineManager) Created(rs goetty.IOSession) {
pro := NewMysqlClientProtocol(nextConnectionID(), rs, 1024, tRM.pu.SV)
pro.SetSkipCheckUser(true)
exe := NewMysqlCmdExecutor()
routine := NewRoutine(pro, exe, tRM.pu)
routine := NewRoutine(context.TODO(), pro, exe, tRM.pu)
hsV10pkt := pro.makeHandshakeV10Payload()
err := pro.writePackets(hsV10pkt)
......@@ -271,8 +272,8 @@ func TestMysqlClientProtocol_Handshake(t *testing.T) {
config.HostMmu = host.New(config.GlobalSystemVariables.GetHostMmuLimitation())
config.Mempool = mempool.New( /*int(config.GlobalSystemVariables.GetMempoolMaxSize()), int(config.GlobalSystemVariables.GetMempoolFactor())*/ )
pu := config.NewParameterUnit(&config.GlobalSystemVariables, config.HostMmu, config.Mempool, config.StorageEngine, config.ClusterNodes)
rm := NewRoutineManager(pu)
ctx := context.WithValue(context.TODO(), config.ParameterUnitKey, pu)
rm := NewRoutineManager(ctx, pu)
rm.SetSkipCheckUser(true)
wg := sync.WaitGroup{}
......
......@@ -15,11 +15,11 @@
package frontend
import (
"context"
"github.com/matrixorigin/matrixone/pkg/config"
"github.com/matrixorigin/matrixone/pkg/logutil"
"github.com/matrixorigin/matrixone/pkg/vm/mempool"
"github.com/matrixorigin/matrixone/pkg/vm/mmu/guest"
"sync"
"time"
)
......@@ -40,10 +40,8 @@ type Routine struct {
//channel of request
requestChan chan *Request
//channel of notify
notifyChan chan interface{}
onceCloseNotifyChan sync.Once
cancelRoutineCtx context.Context
cancelRoutineFunc context.CancelFunc
routineMgr *RoutineManager
......@@ -81,7 +79,7 @@ func (routine *Routine) GetSession() *Session {
/*
After the handshake with the client is done, the routine goes into processing loop.
*/
func (routine *Routine) Loop() {
func (routine *Routine) Loop(routineCtx context.Context) {
var req *Request = nil
var err error
var resp *Response
......@@ -90,8 +88,8 @@ func (routine *Routine) Loop() {
for {
quit := false
select {
case <-routine.notifyChan:
logutil.Infof("-----routine quit")
case <-routineCtx.Done():
logutil.Infof("-----cancel routine")
quit = true
case req = <-routine.requestChan:
}
......@@ -107,9 +105,13 @@ func (routine *Routine) Loop() {
mpi := routine.protocol.(*MysqlProtocolImpl)
mpi.sequenceId = req.seq
cancelRequestCtx, cancelRequestFunc := context.WithCancel(routineCtx)
routine.executor.(*MysqlCmdExecutor).setCancelRequestFunc(cancelRequestFunc)
ses := routine.GetSession()
ses.SetRequestContext(cancelRequestCtx)
routine.executor.PrepareSessionBeforeExecRequest(routine.GetSession())
if resp, err = routine.executor.ExecRequest(req); err != nil {
if resp, err = routine.executor.ExecRequest(cancelRequestCtx, req); err != nil {
logutil.Errorf("routine execute request failed. error:%v \n", err)
}
......@@ -122,6 +124,8 @@ func (routine *Routine) Loop() {
if mgr.getParameterUnit().SV.GetRecordTimeElapsedOfSqlRequest() {
logutil.Infof("connection id %d , the time of handling the request %s", routine.getConnID(), time.Since(reqBegin).String())
}
cancelRequestFunc()
}
}
......@@ -131,10 +135,9 @@ When the io is closed, the Quit will be called.
func (routine *Routine) Quit() {
routine.notifyClose()
routine.onceCloseNotifyChan.Do(func() {
//logutil.Infof("---------notify close")
close(routine.notifyChan)
})
if routine.cancelRoutineFunc != nil {
routine.cancelRoutineFunc()
}
if routine.protocol != nil {
routine.protocol.Quit()
......@@ -150,18 +153,20 @@ func (routine *Routine) notifyClose() {
}
}
func NewRoutine(protocol MysqlProtocol, executor CmdExecutor, pu *config.ParameterUnit) *Routine {
func NewRoutine(ctx context.Context, protocol MysqlProtocol, executor CmdExecutor, pu *config.ParameterUnit) *Routine {
cancelRoutineCtx, cancelRoutineFunc := context.WithCancel(ctx)
ri := &Routine{
protocol: protocol,
executor: executor,
requestChan: make(chan *Request, 1),
notifyChan: make(chan interface{}),
guestMmu: guest.New(pu.SV.GetGuestMmuLimitation(), pu.HostMmu),
mempool: pu.Mempool,
protocol: protocol,
executor: executor,
requestChan: make(chan *Request, 1),
guestMmu: guest.New(pu.SV.GetGuestMmuLimitation(), pu.HostMmu),
mempool: pu.Mempool,
cancelRoutineCtx: cancelRoutineCtx,
cancelRoutineFunc: cancelRoutineFunc,
}
//async process request
go ri.Loop()
go ri.Loop(cancelRoutineCtx)
return ri
}
......@@ -15,6 +15,7 @@
package frontend
import (
"context"
"errors"
"sync"
......@@ -25,6 +26,7 @@ import (
type RoutineManager struct {
rwlock sync.RWMutex
ctx context.Context
clients map[goetty.IOSession]*Routine
pu *config.ParameterUnit
skipCheckUser bool
......@@ -52,9 +54,10 @@ func (rm *RoutineManager) Created(rs goetty.IOSession) {
exe := NewMysqlCmdExecutor()
exe.SetRoutineManager(rm)
routine := NewRoutine(pro, exe, rm.pu)
routine := NewRoutine(rm.ctx, pro, exe, rm.pu)
routine.SetRoutineMgr(rm)
ses := NewSession(routine.protocol, routine.guestMmu, routine.mempool, rm.pu, gSysVariables)
ses.SetRequestContext(routine.cancelRoutineCtx)
routine.SetSession(ses)
pro.SetSession(ses)
......@@ -171,8 +174,9 @@ func (rm *RoutineManager) Handler(rs goetty.IOSession, msg interface{}, received
return nil
}
func NewRoutineManager(pu *config.ParameterUnit) *RoutineManager {
func NewRoutineManager(ctx context.Context, pu *config.ParameterUnit) *RoutineManager {
rm := &RoutineManager{
ctx: ctx,
clients: make(map[goetty.IOSession]*Routine),
pu: pu,
}
......
......@@ -15,6 +15,7 @@
package frontend
import (
"context"
"fmt"
"github.com/matrixorigin/matrixone/pkg/config"
"github.com/matrixorigin/matrixone/pkg/vm/mempool"
......@@ -43,7 +44,8 @@ func create_test_server() *MOServer {
pu := config.NewParameterUnit(&config.GlobalSystemVariables, config.HostMmu, config.Mempool, config.StorageEngine, config.ClusterNodes)
address := fmt.Sprintf("%s:%d", config.GlobalSystemVariables.GetHost(), config.GlobalSystemVariables.GetPort())
return NewMOServer(address, pu)
moServerCtx := context.WithValue(context.TODO(), config.ParameterUnitKey, pu)
return NewMOServer(moServerCtx, address, pu)
}
func Test_Closed(t *testing.T) {
......
......@@ -15,6 +15,7 @@
package frontend
import (
"context"
"fmt"
"sync/atomic"
......@@ -58,9 +59,9 @@ func nextConnectionID() uint32 {
return atomic.AddUint32(&initConnectionID, 1)
}
func NewMOServer(addr string, pu *config.ParameterUnit) *MOServer {
func NewMOServer(ctx context.Context, addr string, pu *config.ParameterUnit) *MOServer {
codec := NewSqlCodec()
rm := NewRoutineManager(pu)
rm := NewRoutineManager(ctx, pu)
// TODO asyncFlushBatch
app, err := goetty.NewApplication(addr, rm.Handler,
goetty.WithAppLogger(logutil.GetGlobalLogger()),
......
......@@ -77,7 +77,6 @@ type Session struct {
ep *tree.ExportParam
showStmtType ShowStatementType
closeRef *CloseExportData
txnHandler *TxnHandler
txnCompileCtx *TxnCompilerContext
storage engine.Engine
......@@ -95,6 +94,8 @@ type Session struct {
prepareStmts map[string]*PrepareStmt
requestCtx context.Context
//it gets the result set from the pipeline and send it to the client
outputCallback func(interface{}, *batch.Batch) error
......@@ -136,6 +137,39 @@ func NewSession(proto Protocol, gm *guest.Mmu, mp *mempool.Mempool, PU *config.P
return ses
}
// BackgroundSession executing the sql in background
type BackgroundSession struct {
*Session
cancel context.CancelFunc
}
// NewBackgroundSession generates an independent background session executing the sql
func NewBackgroundSession(ctx context.Context, gm *guest.Mmu, mp *mempool.Mempool, PU *config.ParameterUnit, gSysVars *GlobalSystemVariables) *BackgroundSession {
ses := NewSession(&FakeProtocol{}, gm, mp, PU, gSysVars)
ses.SetOutputCallback(fakeDataSetFetcher)
cancelBackgroundCtx, cancelBackgroundFunc := context.WithCancel(ctx)
ses.SetRequestContext(cancelBackgroundCtx)
backSes := &BackgroundSession{
Session: ses,
cancel: cancelBackgroundFunc,
}
return backSes
}
func (bgs *BackgroundSession) Close() {
if bgs.cancel != nil {
bgs.cancel()
}
}
func (ses *Session) SetRequestContext(reqCtx context.Context) {
ses.requestCtx = reqCtx
}
func (ses *Session) GetRequestContext() context.Context {
return ses.requestCtx
}
func (ses *Session) SetMysqlResultSet(mrs *MysqlResultSet) {
ses.Mrs = mrs
}
......@@ -522,7 +556,7 @@ func (ses *Session) AuthenticateUser(userInput string) error {
ses.SetTenantInfo(tenant)
//Get the password of the user in an independent session
err = executeSQLInBackgroundSession(ses.GuestMmu, ses.Mempool, ses.Pu, "use mo_catalog; select * from mo_database;")
err = executeSQLInBackgroundSession(ses.requestCtx, ses.GuestMmu, ses.Mempool, ses.Pu, "use mo_catalog; select * from mo_database;")
return err
}
......@@ -640,8 +674,7 @@ func (tcc *TxnCompilerContext) GetRootSql() string {
func (tcc *TxnCompilerContext) DatabaseExists(name string) bool {
var err error
//open database
ctx := context.TODO()
_, err = tcc.txnHandler.GetStorage().Database(ctx, name, engine.Snapshot(tcc.txnHandler.GetTxn().GetCtx()))
_, err = tcc.txnHandler.GetStorage().Database(tcc.ses.GetRequestContext(), name, engine.Snapshot(tcc.txnHandler.GetTxn().GetCtx()))
if err != nil {
logutil.Errorf("get database %v failed. error %v", name, err)
return false
......@@ -656,7 +689,7 @@ func (tcc *TxnCompilerContext) getRelation(dbName string, tableName string) (eng
return nil, err
}
ctx := context.TODO()
ctx := tcc.ses.GetRequestContext()
//open database
db, err := tcc.txnHandler.GetStorage().Database(ctx, dbName, engine.Snapshot(tcc.txnHandler.GetTxn().GetCtx()))
if err != nil {
......@@ -698,7 +731,7 @@ func (tcc *TxnCompilerContext) Resolve(dbName string, tableName string) (*plan2.
if err != nil {
return nil, nil
}
ctx := context.TODO()
ctx := tcc.ses.GetRequestContext()
engineDefs, err := table.TableDefs(ctx)
if err != nil {
return nil, nil
......@@ -790,7 +823,7 @@ func (tcc *TxnCompilerContext) ResolveVariable(varName string, isSystemVar, isGl
}
func (tcc *TxnCompilerContext) GetPrimaryKeyDef(dbName string, tableName string) []*plan2.ColDef {
ctx := context.TODO()
ctx := tcc.ses.GetRequestContext()
dbName, err := tcc.ensureDatabaseIsNotEmpty(dbName)
if err != nil {
return nil
......@@ -826,7 +859,7 @@ func (tcc *TxnCompilerContext) GetPrimaryKeyDef(dbName string, tableName string)
}
func (tcc *TxnCompilerContext) GetHideKeyDef(dbName string, tableName string) *plan2.ColDef {
ctx := context.TODO()
ctx := tcc.ses.GetRequestContext()
dbName, err := tcc.ensureDatabaseIsNotEmpty(dbName)
if err != nil {
return nil
......@@ -903,13 +936,13 @@ func fakeDataSetFetcher(handle interface{}, dataSet *batch.Batch) error {
// executeSQLInBackgroundSession executes the sql in an independent session and transaction.
// It sends nothing to the client.
func executeSQLInBackgroundSession(gm *guest.Mmu, mp *mempool.Mempool, pu *config.ParameterUnit, sql string) error {
func executeSQLInBackgroundSession(ctx context.Context, gm *guest.Mmu, mp *mempool.Mempool, pu *config.ParameterUnit, sql string) error {
mce := NewMysqlCmdExecutor()
defer mce.Close()
ses := NewSession(&FakeProtocol{}, gm, mp, pu, gSysVariables)
ses.SetOutputCallback(fakeDataSetFetcher)
mce.PrepareSessionBeforeExecRequest(ses)
err := mce.doComQuery(sql)
backSess := NewBackgroundSession(ctx, gm, mp, pu, gSysVariables)
mce.PrepareSessionBeforeExecRequest(backSess.Session)
defer backSess.Close()
err := mce.doComQuery(backSess.GetRequestContext(), sql)
if err != nil {
return err
}
......
......@@ -9,9 +9,9 @@ import (
reflect "reflect"
time "time"
goetty "github.com/fagongzi/goetty/v2"
buf "github.com/fagongzi/goetty/v2/buf"
gomock "github.com/golang/mock/gomock"
goetty "github.com/fagongzi/goetty/v2"
)
// MockIOSession is a mock of IOSession interface.
......
......@@ -5,6 +5,7 @@
package mock_frontend
import (
"context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
......@@ -73,7 +74,7 @@ func (m *MockComputationWrapper) EXPECT() *MockComputationWrapperMockRecorder {
}
// Compile mocks base method.
func (m *MockComputationWrapper) Compile(u interface{}, fill func(interface{}, *batch.Batch) error) (interface{}, error) {
func (m *MockComputationWrapper) Compile(requestCtx context.Context, u interface{}, fill func(interface{}, *batch.Batch) error) (interface{}, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Compile", u, fill)
ret0, _ := ret[0].(interface{})
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment