diff --git a/pkg/vm/engine/tae/catalog/base.go b/pkg/vm/engine/tae/catalog/base.go index e1a533a9af356aa66a7b8b9f439f02c5c95c6761..c2001eed61d870f5031a6472d403accbe24f224a 100644 --- a/pkg/vm/engine/tae/catalog/base.go +++ b/pkg/vm/engine/tae/catalog/base.go @@ -38,7 +38,7 @@ func CompareUint64(left, right uint64) int { type BaseEntry struct { *sync.RWMutex - MVCC *common.SortedDList + MVCC *common.GenericSortedDList[*UpdateNode] length uint64 // meta * ID uint64 @@ -47,7 +47,7 @@ type BaseEntry struct { func NewReplayBaseEntry() *BaseEntry { be := &BaseEntry{ RWMutex: &sync.RWMutex{}, - MVCC: new(common.SortedDList), + MVCC: common.NewGenericSortedDList[*UpdateNode](compareUpdateNode), } return be } @@ -55,7 +55,7 @@ func NewReplayBaseEntry() *BaseEntry { func NewBaseEntry(id uint64) *BaseEntry { return &BaseEntry{ ID: id, - MVCC: new(common.SortedDList), + MVCC: common.NewGenericSortedDList[*UpdateNode](compareUpdateNode), RWMutex: &sync.RWMutex{}, } } @@ -63,9 +63,9 @@ func (be *BaseEntry) StringLocked() string { var w bytes.Buffer _, _ = w.WriteString(fmt.Sprintf("[%d %p]", be.ID, be.RWMutex)) - it := common.NewSortedDListIt(nil, be.MVCC, false) + it := common.NewGenericSortedDListIt(nil, be.MVCC, false) for it.Valid() { - version := it.Get().GetPayload().(*UpdateNode) + version := it.Get().GetPayload() _, _ = w.WriteString(" -> ") _, _ = w.WriteString(version.String()) it.Next() @@ -104,8 +104,8 @@ func (be *BaseEntry) GetID() uint64 { return be.ID } func (be *BaseEntry) GetIndexes() []*wal.Index { ret := make([]*wal.Index, 0) - be.MVCC.Loop(func(n *common.DLNode) bool { - un := n.GetPayload().(*UpdateNode) + be.MVCC.Loop(func(n *common.GenericDLNode[*UpdateNode]) bool { + un := n.GetPayload() ret = append(ret, un.LogIndex...) return true }, true) @@ -134,8 +134,8 @@ func (be *BaseEntry) CreateWithTxn(txn txnif.AsyncTxn) { be.InsertNode(node) } func (be *BaseEntry) ExistUpdate(minTs, MaxTs types.TS) (exist bool) { - be.MVCC.Loop(func(n *common.DLNode) bool { - un := n.GetPayload().(*UpdateNode) + be.MVCC.Loop(func(n *common.GenericDLNode[*UpdateNode]) bool { + un := n.GetPayload() if un.End.IsEmpty() { return true } @@ -153,7 +153,7 @@ func (be *BaseEntry) ExistUpdate(minTs, MaxTs types.TS) (exist bool) { // TODO update create func (be *BaseEntry) DeleteLocked(txn txnif.TxnReader, impl INode) (node INode, err error) { - entry := be.MVCC.GetHead().GetPayload().(*UpdateNode) + entry := be.MVCC.GetHead().GetPayload() if entry.Txn == nil || entry.IsSameTxn(txn.GetStartTS()) { if be.HasDropped() { err = ErrNotFound @@ -185,15 +185,15 @@ func (be *BaseEntry) GetUpdateNodeLocked() *UpdateNode { if payload == nil { return nil } - entry := payload.(*UpdateNode) + entry := payload return entry } // GetCommittedNode gets the latest committed UpdateNode. // It's useful when check whether the catalog/metadata entry is deleted. func (be *BaseEntry) GetCommittedNode() (node *UpdateNode) { - be.MVCC.Loop(func(n *common.DLNode) bool { - un := n.GetPayload().(*UpdateNode) + be.MVCC.Loop(func(n *common.GenericDLNode[*UpdateNode]) bool { + un := n.GetPayload() if !un.IsActive() { node = un return false @@ -207,8 +207,8 @@ func (be *BaseEntry) GetCommittedNode() (node *UpdateNode) { // It returns the UpdateNode in the same txn as the read txn // or returns the latest UpdateNode with commitTS less than the timestamp. func (be *BaseEntry) GetNodeToRead(startts types.TS) (node *UpdateNode) { - be.MVCC.Loop(func(n *common.DLNode) bool { - un := n.GetPayload().(*UpdateNode) + be.MVCC.Loop(func(n *common.GenericDLNode[*UpdateNode]) bool { + un := n.GetPayload() if un.IsSameTxn(startts) { node = un return false @@ -236,8 +236,8 @@ func (be *BaseEntry) DeleteBefore(ts types.TS) bool { // GetExactUpdateNode gets the exact UpdateNode with the startTs. // It's only used in replay func (be *BaseEntry) GetExactUpdateNode(startts types.TS) (node *UpdateNode) { - be.MVCC.Loop(func(n *common.DLNode) bool { - un := n.GetPayload().(*UpdateNode) + be.MVCC.Loop(func(n *common.GenericDLNode[*UpdateNode]) bool { + un := n.GetPayload() if un.Start == startts { node = un return false @@ -330,9 +330,9 @@ func (be *BaseEntry) WriteAllTo(w io.Writer) (n int64, err error) { return } n += 8 - be.MVCC.Loop(func(node *common.DLNode) bool { + be.MVCC.Loop(func(node *common.GenericDLNode[*UpdateNode]) bool { var n2 int64 - n2, err = node.GetPayload().(*UpdateNode).WriteTo(w) + n2, err = node.GetPayload().WriteTo(w) if err != nil { return false } @@ -449,7 +449,7 @@ func (be *BaseEntry) MetaTxnCanRead(txn txnif.AsyncTxn, mu *sync.RWMutex) (canRe } func (be *BaseEntry) CloneCreateEntry() *BaseEntry { cloned := &BaseEntry{ - MVCC: &common.SortedDList{}, + MVCC: common.NewGenericSortedDList[*UpdateNode](compareUpdateNode), RWMutex: &sync.RWMutex{}, ID: be.ID, } @@ -562,8 +562,8 @@ func (be *BaseEntry) IsCommitted() bool { } func (be *BaseEntry) CloneCommittedInRange(start, end types.TS) (ret *BaseEntry) { - be.MVCC.Loop(func(n *common.DLNode) bool { - un := n.GetPayload().(*UpdateNode) + be.MVCC.Loop(func(n *common.GenericDLNode[*UpdateNode]) bool { + un := n.GetPayload() if un.IsActive() { return true } @@ -608,3 +608,24 @@ func (be *BaseEntry) GetDeleteAt() types.TS { } return un.DeletedAt } + +func (be *BaseEntry) TxnCanGet(ts types.TS) (can, dropped bool) { + be.RLock() + defer be.RUnlock() + needWait, txnToWait := be.NeedWaitCommitting(ts) + if needWait { + be.RUnlock() + txnToWait.GetTxnState(true) + be.RLock() + } + un := be.GetNodeToRead(ts) + if un == nil { + return + } + if un.HasDropped() { + can, dropped = true, true + return + } + can, dropped = true, false + return +} diff --git a/pkg/vm/engine/tae/catalog/block.go b/pkg/vm/engine/tae/catalog/block.go index 94ace5db7addd7c70770615ed823f3bcab9eb962..d78020f502feb7386fdf1f7f89ab1079415d208e 100644 --- a/pkg/vm/engine/tae/catalog/block.go +++ b/pkg/vm/engine/tae/catalog/block.go @@ -28,6 +28,10 @@ import ( type BlockDataFactory = func(meta *BlockEntry) data.Block +func compareBlockFn(a, b *BlockEntry) int { + return a.BaseEntry.DoCompre(b.BaseEntry) +} + type BlockEntry struct { *BaseEntry segment *SegmentEntry @@ -92,11 +96,6 @@ func (entry *BlockEntry) MakeCommand(id uint32) (cmd txnif.TxnCmd, err error) { return newBlockCmd(id, cmdType, entry), nil } -func (entry *BlockEntry) Compare(o common.NodePayload) int { - oe := o.(*BlockEntry).BaseEntry - return entry.DoCompre(oe) -} - func (entry *BlockEntry) PPString(level common.PPLevel, depth int, prefix string) string { s := fmt.Sprintf("%s%s%s", common.RepeatStr("\t", depth), prefix, entry.StringLocked()) return s diff --git a/pkg/vm/engine/tae/catalog/catalog.go b/pkg/vm/engine/tae/catalog/catalog.go index f93d9d201c2cfa03455bbb071491555f05cfbe0d..9b18b8052bd619ccda77628a5c8b592736fc4363 100644 --- a/pkg/vm/engine/tae/catalog/catalog.go +++ b/pkg/vm/engine/tae/catalog/catalog.go @@ -54,9 +54,9 @@ type Catalog struct { ckpmu sync.RWMutex checkpoints []*Checkpoint - entries map[uint64]*common.DLNode - nameNodes map[string]*nodeList - link *common.SortedDList + entries map[uint64]*common.GenericDLNode[*DBEntry] + nameNodes map[string]*nodeList[*DBEntry] + link *common.GenericSortedDList[*DBEntry] nodesMu sync.RWMutex @@ -71,15 +71,19 @@ func genDBFullName(tenantID uint32, name string) string { return fmt.Sprintf("%d-%s", tenantID, name) } +func compareDBFn(a, b *DBEntry) int { + return a.BaseEntry.DoCompre(b.BaseEntry) +} + func MockCatalog(dir, name string, cfg *batchstoredriver.StoreCfg, scheduler tasks.TaskScheduler) *Catalog { driver := store.NewStoreWithBatchStoreDriver(dir, name, cfg) catalog := &Catalog{ RWMutex: new(sync.RWMutex), IDAlloctor: NewIDAllocator(), store: driver, - entries: make(map[uint64]*common.DLNode), - nameNodes: make(map[string]*nodeList), - link: new(common.SortedDList), + entries: make(map[uint64]*common.GenericDLNode[*DBEntry]), + nameNodes: make(map[string]*nodeList[*DBEntry]), + link: common.NewGenericSortedDList(compareDBFn), checkpoints: make([]*Checkpoint, 0), scheduler: scheduler, } @@ -93,9 +97,9 @@ func OpenCatalog(dir, name string, cfg *batchstoredriver.StoreCfg, scheduler tas RWMutex: new(sync.RWMutex), IDAlloctor: NewIDAllocator(), store: driver, - entries: make(map[uint64]*common.DLNode), - nameNodes: make(map[string]*nodeList), - link: new(common.SortedDList), + entries: make(map[uint64]*common.GenericDLNode[*DBEntry]), + nameNodes: make(map[string]*nodeList[*DBEntry]), + link: common.NewGenericSortedDList(compareDBFn), checkpoints: make([]*Checkpoint, 0), scheduler: scheduler, } @@ -221,8 +225,8 @@ func (catalog *Catalog) onReplayDatabase(cmd *EntryCommand) { return } - cmd.DB.MVCC.Loop(func(n *common.DLNode) bool { - un := n.GetPayload().(*UpdateNode) + cmd.DB.MVCC.Loop(func(n *common.GenericDLNode[*UpdateNode]) bool { + un := n.GetPayload() dbun := db.GetExactUpdateNode(un.Start) if dbun == nil { db.InsertNode(un) //TODO isvalid @@ -289,8 +293,8 @@ func (catalog *Catalog) onReplayTable(cmd *EntryCommand, dataFactory DataFactory panic(err) } } else { - cmd.Table.MVCC.Loop(func(n *common.DLNode) bool { - un := n.GetPayload().(*UpdateNode) + cmd.Table.MVCC.Loop(func(n *common.GenericDLNode[*UpdateNode]) bool { + un := n.GetPayload() node := rel.GetExactUpdateNode(un.Start) if node == nil { rel.InsertNode(un) @@ -354,8 +358,8 @@ func (catalog *Catalog) onReplaySegment(cmd *EntryCommand, dataFactory DataFacto cmd.Segment.table = rel rel.AddEntryLocked(cmd.Segment) } else { - cmd.Segment.MVCC.Loop(func(n *common.DLNode) bool { - un := n.GetPayload().(*UpdateNode) + cmd.Segment.MVCC.Loop(func(n *common.GenericDLNode[*UpdateNode]) bool { + un := n.GetPayload() segun := seg.GetExactUpdateNode(un.Start) if segun != nil { segun.UpdateNode(un) @@ -435,8 +439,8 @@ func (catalog *Catalog) onReplayBlock(cmd *EntryCommand, dataFactory DataFactory cmd.Block.segment = seg seg.AddEntryLocked(cmd.Block) } else { - cmd.Block.MVCC.Loop(func(n *common.DLNode) bool { - un := n.GetPayload().(*UpdateNode) + cmd.Block.MVCC.Loop(func(n *common.GenericDLNode[*UpdateNode]) bool { + un := n.GetPayload() blkun := blk.GetExactUpdateNode(un.Start) if blkun != nil { blkun.UpdateNode(un) @@ -510,6 +514,10 @@ func (catalog *Catalog) AddColumnCnt(cnt int) { } } +func (catalog *Catalog) GetItemNodeByIDLocked(id uint64) *common.GenericDLNode[*DBEntry] { + return catalog.entries[id] +} + func (catalog *Catalog) GetScheduler() tasks.TaskScheduler { return catalog.scheduler } func (catalog *Catalog) GetDatabaseByID(id uint64) (db *DBEntry, err error) { catalog.RLock() @@ -519,7 +527,7 @@ func (catalog *Catalog) GetDatabaseByID(id uint64) (db *DBEntry, err error) { err = ErrNotFound return } - db = node.GetPayload().(*DBEntry) + db = node.GetPayload() return } @@ -529,13 +537,16 @@ func (catalog *Catalog) AddEntryLocked(database *DBEntry, txn txnif.TxnReader) e n := catalog.link.Insert(database) catalog.entries[database.GetID()] = n - nn := newNodeList(catalog, &catalog.nodesMu, database.name) + nn := newNodeList[*DBEntry](catalog.GetItemNodeByIDLocked, + databaseTxnCanGetFn[*DBEntry], + &catalog.nodesMu, + database.name) catalog.nameNodes[database.GetFullName()] = nn nn.CreateNode(database.GetID()) } else { - node := nn.GetDBNode() - record := node.GetPayload().(*DBEntry) + node := nn.GetNode() + record := node.GetPayload() err := record.PrepareAdd(txn) if err != nil { return err @@ -547,10 +558,10 @@ func (catalog *Catalog) AddEntryLocked(database *DBEntry, txn txnif.TxnReader) e return nil } -func (catalog *Catalog) MakeDBIt(reverse bool) *common.SortedDListIt { +func (catalog *Catalog) MakeDBIt(reverse bool) *common.GenericSortedDListIt[*DBEntry] { catalog.RLock() defer catalog.RUnlock() - return common.NewSortedDListIt(catalog.RWMutex, catalog.link, reverse) + return common.NewGenericSortedDListIt[*DBEntry](catalog.RWMutex, catalog.link, reverse) } func (catalog *Catalog) SimplePPString(level common.PPLevel) string { @@ -563,7 +574,7 @@ func (catalog *Catalog) PPString(level common.PPLevel, depth int, prefix string) it := catalog.MakeDBIt(true) for it.Valid() { cnt++ - entry := it.Get().GetPayload().(*DBEntry) + entry := it.Get().GetPayload() _ = w.WriteByte('\n') _, _ = w.WriteString(entry.PPString(level, depth+1, "")) it.Next() @@ -604,7 +615,7 @@ func (catalog *Catalog) RemoveEntry(database *DBEntry) error { return nil } -func (catalog *Catalog) txnGetNodeByNameLocked(name string, txnCtx txnif.AsyncTxn) (*common.DLNode, error) { +func (catalog *Catalog) txnGetNodeByNameLocked(name string, txnCtx txnif.AsyncTxn) (*common.GenericDLNode[*DBEntry], error) { catalog.RLock() defer catalog.RUnlock() fullName := genDBFullName(txnCtx.GetTenantID(), name) @@ -612,7 +623,7 @@ func (catalog *Catalog) txnGetNodeByNameLocked(name string, txnCtx txnif.AsyncTx if node == nil { return nil, ErrNotFound } - return node.TxnGetDBNodeLocked(txnCtx) + return node.TxnGetNodeLocked(txnCtx) } func (catalog *Catalog) GetDBEntry(name string, txnCtx txnif.AsyncTxn) (*DBEntry, error) { @@ -620,7 +631,7 @@ func (catalog *Catalog) GetDBEntry(name string, txnCtx txnif.AsyncTxn) (*DBEntry if err != nil { return nil, err } - return n.GetPayload().(*DBEntry), nil + return n.GetPayload(), nil } func (catalog *Catalog) DropDBEntry(name string, txnCtx txnif.AsyncTxn) (deleted *DBEntry, err error) { @@ -632,7 +643,7 @@ func (catalog *Catalog) DropDBEntry(name string, txnCtx txnif.AsyncTxn) (deleted if err != nil { return } - entry := dn.GetPayload().(*DBEntry) + entry := dn.GetPayload() entry.Lock() defer entry.Unlock() err = entry.DropEntryLocked(txnCtx) @@ -661,7 +672,7 @@ func (catalog *Catalog) CreateDBEntryByTS(name string, ts types.TS) (*DBEntry, e func (catalog *Catalog) RecurLoop(processor Processor) (err error) { dbIt := catalog.MakeDBIt(true) for dbIt.Valid() { - dbEntry := dbIt.Get().GetPayload().(*DBEntry) + dbEntry := dbIt.Get().GetPayload() if err = processor.OnDatabase(dbEntry); err != nil { if err == ErrStopCurrRecur { err = nil diff --git a/pkg/vm/engine/tae/catalog/catalog_test.go b/pkg/vm/engine/tae/catalog/catalog_test.go index d796b77457653261347bbf6d4d8d551f06ac5e0a..2a8ca1099387cb5f0afc2d96aa8ee032040dcac1 100644 --- a/pkg/vm/engine/tae/catalog/catalog_test.go +++ b/pkg/vm/engine/tae/catalog/catalog_test.go @@ -77,8 +77,8 @@ func TestCreateDB1(t *testing.T) { assert.Equal(t, 2, len(catalog.entries)) cnt := 0 - catalog.link.Loop(func(n *common.DLNode) bool { - t.Log(n.GetPayload().(*DBEntry).GetID()) + catalog.link.Loop(func(n *common.GenericDLNode[*DBEntry]) bool { + t.Log(n.GetPayload().GetID()) cnt++ return true }, true) @@ -116,7 +116,7 @@ func TestCreateDB1(t *testing.T) { assert.Nil(t, err) cnt = 0 - catalog.link.Loop(func(n *common.DLNode) bool { + catalog.link.Loop(func(n *common.GenericDLNode[*DBEntry]) bool { // t.Log(n.payload.(*DBEntry).String()) cnt++ return true diff --git a/pkg/vm/engine/tae/catalog/database.go b/pkg/vm/engine/tae/catalog/database.go index 5f73a303d88836470777e9c473d7db0a1709b5f8..6adfb47a08f33e09f9a94933da3893741d32d0c0 100644 --- a/pkg/vm/engine/tae/catalog/database.go +++ b/pkg/vm/engine/tae/catalog/database.go @@ -58,6 +58,12 @@ func (ai *accessInfo) ReadFrom(r io.Reader) (n int64, err error) { return 20, nil } +func databaseTxnCanGetFn[T *DBEntry](n *common.GenericDLNode[*DBEntry], ts types.TS) (can, dropped bool) { + db := n.GetPayload() + can, dropped = db.TxnCanGet(ts) + return +} + type DBEntry struct { *BaseEntry catalog *Catalog @@ -66,13 +72,17 @@ type DBEntry struct { fullName string isSys bool - entries map[uint64]*common.DLNode - nameNodes map[string]*nodeList - link *common.SortedDList + entries map[uint64]*common.GenericDLNode[*TableEntry] + nameNodes map[string]*nodeList[*TableEntry] + link *common.GenericSortedDList[*TableEntry] nodesMu sync.RWMutex } +func compareTableFn(a, b *TableEntry) int { + return a.BaseEntry.DoCompre(b.BaseEntry) +} + func NewDBEntry(catalog *Catalog, name string, txnCtx txnif.AsyncTxn) *DBEntry { id := catalog.NextDB() @@ -80,9 +90,9 @@ func NewDBEntry(catalog *Catalog, name string, txnCtx txnif.AsyncTxn) *DBEntry { BaseEntry: NewBaseEntry(id), catalog: catalog, name: name, - entries: make(map[uint64]*common.DLNode), - nameNodes: make(map[string]*nodeList), - link: new(common.SortedDList), + entries: make(map[uint64]*common.GenericDLNode[*TableEntry]), + nameNodes: make(map[string]*nodeList[*TableEntry]), + link: common.NewGenericSortedDList(compareTableFn), } if txnCtx != nil { // Only in unit test, txnCtx can be nil @@ -101,9 +111,9 @@ func NewDBEntryByTS(catalog *Catalog, name string, ts types.TS) *DBEntry { BaseEntry: NewBaseEntry(id), catalog: catalog, name: name, - entries: make(map[uint64]*common.DLNode), - nameNodes: make(map[string]*nodeList), - link: new(common.SortedDList), + entries: make(map[uint64]*common.GenericDLNode[*TableEntry]), + nameNodes: make(map[string]*nodeList[*TableEntry]), + link: common.NewGenericSortedDList(compareTableFn), } e.CreateWithTS(ts) e.acInfo.CreateAt = types.CurrentTimestamp() @@ -116,9 +126,9 @@ func NewSystemDBEntry(catalog *Catalog) *DBEntry { BaseEntry: NewBaseEntry(id), catalog: catalog, name: SystemDBName, - entries: make(map[uint64]*common.DLNode), - nameNodes: make(map[string]*nodeList), - link: new(common.SortedDList), + entries: make(map[uint64]*common.GenericDLNode[*TableEntry]), + nameNodes: make(map[string]*nodeList[*TableEntry]), + link: common.NewGenericSortedDList(compareTableFn), isSys: true, } entry.CreateWithTS(types.SystemDBTS) @@ -128,9 +138,9 @@ func NewSystemDBEntry(catalog *Catalog) *DBEntry { func NewReplayDBEntry() *DBEntry { entry := &DBEntry{ BaseEntry: NewReplayBaseEntry(), - entries: make(map[uint64]*common.DLNode), - nameNodes: make(map[string]*nodeList), - link: new(common.SortedDList), + entries: make(map[uint64]*common.GenericDLNode[*TableEntry]), + nameNodes: make(map[string]*nodeList[*TableEntry]), + link: common.NewGenericSortedDList(compareTableFn), } return entry } @@ -142,11 +152,6 @@ func (e *DBEntry) CoarseTableCnt() int { return len(e.entries) } -func (e *DBEntry) Compare(o common.NodePayload) int { - oe := o.(*DBEntry).BaseEntry - return e.DoCompre(oe) -} - func (e *DBEntry) GetTenantID() uint32 { return e.acInfo.TenantID } func (e *DBEntry) GetUserID() uint32 { return e.acInfo.UserID } func (e *DBEntry) GetRoleID() uint32 { return e.acInfo.RoleID } @@ -169,10 +174,10 @@ func (e *DBEntry) StringLocked() string { return fmt.Sprintf("DB%s[name=%s]", e.BaseEntry.StringLocked(), e.GetFullName()) } -func (e *DBEntry) MakeTableIt(reverse bool) *common.SortedDListIt { +func (e *DBEntry) MakeTableIt(reverse bool) *common.GenericSortedDListIt[*TableEntry] { e.RLock() defer e.RUnlock() - return common.NewSortedDListIt(e.RWMutex, e.link, reverse) + return common.NewGenericSortedDListIt(e.RWMutex, e.link, reverse) } func (e *DBEntry) PPString(level common.PPLevel, depth int, prefix string) string { @@ -183,7 +188,7 @@ func (e *DBEntry) PPString(level common.PPLevel, depth int, prefix string) strin } it := e.MakeTableIt(true) for it.Valid() { - table := it.Get().GetPayload().(*TableEntry) + table := it.Get().GetPayload() _ = w.WriteByte('\n') _, _ = w.WriteString(table.PPString(level, depth+1, "")) it.Next() @@ -206,6 +211,10 @@ func (e *DBEntry) GetBlockEntryByID(id *common.ID) (blk *BlockEntry, err error) return } +func (e *DBEntry) GetItemNodeByIDLocked(id uint64) *common.GenericDLNode[*TableEntry] { + return e.entries[id] +} + func (e *DBEntry) GetTableEntryByID(id uint64) (table *TableEntry, err error) { e.RLock() defer e.RUnlock() @@ -213,11 +222,12 @@ func (e *DBEntry) GetTableEntryByID(id uint64) (table *TableEntry, err error) { if node == nil { return nil, ErrNotFound } - table = node.GetPayload().(*TableEntry) + table = node.GetPayload() return } -func (e *DBEntry) txnGetNodeByName(name string, txnCtx txnif.AsyncTxn) (*common.DLNode, error) { +func (e *DBEntry) txnGetNodeByName(name string, + txnCtx txnif.AsyncTxn) (*common.GenericDLNode[*TableEntry], error) { e.RLock() defer e.RUnlock() fullName := genTblFullName(txnCtx.GetTenantID(), name) @@ -225,7 +235,7 @@ func (e *DBEntry) txnGetNodeByName(name string, txnCtx txnif.AsyncTxn) (*common. if node == nil { return nil, ErrNotFound } - return node.TxnGetTableNodeLocked(txnCtx) + return node.TxnGetNodeLocked(txnCtx) } func (e *DBEntry) GetTableEntry(name string, txnCtx txnif.AsyncTxn) (entry *TableEntry, err error) { @@ -233,7 +243,7 @@ func (e *DBEntry) GetTableEntry(name string, txnCtx txnif.AsyncTxn) (entry *Tabl if err != nil { return } - entry = n.GetPayload().(*TableEntry) + entry = n.GetPayload() return } @@ -251,7 +261,7 @@ func (e *DBEntry) DropTableEntry(name string, txnCtx txnif.AsyncTxn) (deleted *T if err != nil { return } - entry := dn.GetPayload().(*TableEntry) + entry := dn.GetPayload() entry.Lock() defer entry.Unlock() err = entry.DropEntryLocked(txnCtx) @@ -319,13 +329,16 @@ func (e *DBEntry) AddEntryLocked(table *TableEntry, txn txnif.AsyncTxn) (err err n := e.link.Insert(table) e.entries[table.GetID()] = n - nn := newNodeList(e, &e.nodesMu, fullName) + nn := newNodeList(e.GetItemNodeByIDLocked, + tableTxnCanGetFn[*TableEntry], + &e.nodesMu, + fullName) e.nameNodes[fullName] = nn nn.CreateNode(table.GetID()) } else { - node := nn.GetTableNode() - record := node.GetPayload().(*TableEntry) + node := nn.GetNode() + record := node.GetPayload() err = record.PrepareAdd(txn) if err != nil { return @@ -349,7 +362,7 @@ func (e *DBEntry) GetCatalog() *Catalog { return e.catalog } func (e *DBEntry) RecurLoop(processor Processor) (err error) { tableIt := e.MakeTableIt(true) for tableIt.Valid() { - table := tableIt.Get().GetPayload().(*TableEntry) + table := tableIt.Get().GetPayload() if err = processor.OnTable(table); err != nil { if err == ErrStopCurrRecur { err = nil diff --git a/pkg/vm/engine/tae/catalog/node.go b/pkg/vm/engine/tae/catalog/node.go index b428e379b2ef82c1755bd2e71de772cfd2d860da..96652f5fa9ff182b2490debc0c457113a795fd44 100644 --- a/pkg/vm/engine/tae/catalog/node.go +++ b/pkg/vm/engine/tae/catalog/node.go @@ -18,35 +18,41 @@ import ( "fmt" "sync" + "github.com/matrixorigin/matrixone/pkg/container/types" "github.com/matrixorigin/matrixone/pkg/vm/engine/tae/common" "github.com/matrixorigin/matrixone/pkg/vm/engine/tae/iface/txnif" ) -type nodeList struct { +type nodeList[T any] struct { common.SSLLNode - host any - rwlocker *sync.RWMutex - name string -} - -func newNodeList(host any, rwlocker *sync.RWMutex, name string) *nodeList { - return &nodeList{ - SSLLNode: *common.NewSSLLNode(), - host: host, - rwlocker: rwlocker, - name: name, + getter func(uint64) *common.GenericDLNode[T] + txnChecker func(*common.GenericDLNode[T], types.TS) (bool, bool) + rwlocker *sync.RWMutex + name string +} + +func newNodeList[T any](getter func(uint64) *common.GenericDLNode[T], + txnChecker func(*common.GenericDLNode[T], types.TS) (bool, bool), + rwlocker *sync.RWMutex, + name string) *nodeList[T] { + return &nodeList[T]{ + SSLLNode: *common.NewSSLLNode(), + getter: getter, + txnChecker: txnChecker, + rwlocker: rwlocker, + name: name, } } -func (n *nodeList) CreateNode(id uint64) *nameNode { - nn := newNameNode(n.host, id) +func (n *nodeList[T]) CreateNode(id uint64) *nameNode[T] { + nn := newNameNode[T](id, n.getter) n.rwlocker.Lock() defer n.rwlocker.Unlock() n.Insert(nn) return nn } -func (n *nodeList) DeleteNode(id uint64) (deleted *nameNode, empty bool) { +func (n *nodeList[T]) DeleteNode(id uint64) (deleted *nameNode[T], empty bool) { n.rwlocker.Lock() defer n.rwlocker.Unlock() var prev common.ISSLLNode @@ -54,10 +60,10 @@ func (n *nodeList) DeleteNode(id uint64) (deleted *nameNode, empty bool) { curr := n.GetNext() depth := 0 for curr != nil { - nid := curr.(*nameNode).Id + nid := curr.(*nameNode[T]).id if id == nid { prev.ReleaseNextNode() - deleted = curr.(*nameNode) + deleted = curr.(*nameNode[T]) next := curr.GetNext() if next == nil && depth == 0 { empty = true @@ -71,16 +77,16 @@ func (n *nodeList) DeleteNode(id uint64) (deleted *nameNode, empty bool) { return } -func (n *nodeList) ForEachNodes(fn func(*nameNode) bool) { +func (n *nodeList[T]) ForEachNodes(fn func(*nameNode[T]) bool) { n.rwlocker.RLock() defer n.rwlocker.RUnlock() n.ForEachNodesLocked(fn) } -func (n *nodeList) ForEachNodesLocked(fn func(*nameNode) bool) { +func (n *nodeList[T]) ForEachNodesLocked(fn func(*nameNode[T]) bool) { curr := n.GetNext() for curr != nil { - nn := curr.(*nameNode) + nn := curr.(*nameNode[T]) if ok := fn(nn); !ok { break } @@ -88,9 +94,9 @@ func (n *nodeList) ForEachNodesLocked(fn func(*nameNode) bool) { } } -func (n *nodeList) LengthLocked() int { +func (n *nodeList[T]) LengthLocked() int { length := 0 - fn := func(*nameNode) bool { + fn := func(*nameNode[T]) bool { length++ return true } @@ -98,31 +104,16 @@ func (n *nodeList) LengthLocked() int { return length } -func (n *nodeList) Length() int { +func (n *nodeList[T]) Length() int { n.rwlocker.RLock() defer n.rwlocker.RUnlock() return n.LengthLocked() } -func (n *nodeList) GetTableNode() *common.DLNode { - n.rwlocker.RLock() - defer n.rwlocker.RUnlock() - return n.GetNext().(*nameNode).GetTableNode() -} - -func (n *nodeList) GetDBNode() *common.DLNode { +func (n *nodeList[T]) GetNode() *common.GenericDLNode[T] { n.rwlocker.RLock() defer n.rwlocker.RUnlock() - return n.GetNext().(*nameNode).GetDBNode() -} - -func (n *nodeList) TxnGetTableNodeLocked(txn txnif.TxnReader) (dn *common.DLNode, err error) { - getter := func(nn *nameNode) (n *common.DLNode, entry *BaseEntry) { - n = nn.GetTableNode() - entry = n.GetPayload().(*TableEntry).BaseEntry - return - } - return n.TxnGetNodeLocked(txn, getter) + return n.GetNext().(*nameNode[T]).GetNode() } // Create Deleted @@ -143,30 +134,17 @@ func (n *nodeList) TxnGetTableNodeLocked(txn txnif.TxnReader) (dn *common.DLNode // 7. Txn3 commit // 8. Txn4 can still find "tb1" // 9. Txn5 start and cannot find "tb1" -func (n *nodeList) TxnGetNodeLocked( - txn txnif.TxnReader, - getter func(*nameNode) (*common.DLNode, *BaseEntry, - )) (dn *common.DLNode, err error) { - fn := func(nn *nameNode) (goNext bool) { - dlNode, entry := getter(nn) - entry.RLock() - goNext = true - needWait, txnToWait := entry.NeedWaitCommitting(txn.GetStartTS()) - if needWait { - entry.RUnlock() - txnToWait.GetTxnState(true) - entry.RLock() - } - un := entry.GetNodeToRead(txn.GetStartTS()) - if un == nil { - entry.RUnlock() +func (n *nodeList[T]) TxnGetNodeLocked( + txn txnif.TxnReader) (dn *common.GenericDLNode[T], err error) { + fn := func(nn *nameNode[T]) bool { + dlNode := nn.GetNode() + can, dropped := n.txnChecker(dlNode, txn.GetStartTS()) + if !can { return true } - if un.HasDropped() { - entry.RUnlock() + if dropped { return false } - entry.RUnlock() dn = dlNode return true } @@ -177,22 +155,13 @@ func (n *nodeList) TxnGetNodeLocked( return } -func (n *nodeList) TxnGetDBNodeLocked(txn txnif.TxnReader) (*common.DLNode, error) { - getter := func(nn *nameNode) (n *common.DLNode, entry *BaseEntry) { - n = nn.GetDBNode() - entry = n.GetPayload().(*DBEntry).BaseEntry - return - } - return n.TxnGetNodeLocked(txn, getter) -} - -func (n *nodeList) PString(level common.PPLevel) string { +func (n *nodeList[T]) PString(level common.PPLevel) string { curr := n.GetNext() if curr == nil { return fmt.Sprintf("TableNode[\"%s\"](Len=0)", n.name) } - node := curr.(*nameNode) - s := fmt.Sprintf("TableNode[\"%s\"](Len=%d)->[%d", n.name, n.Length(), node.Id) + node := curr.(*nameNode[T]) + s := fmt.Sprintf("TableNode[\"%s\"](Len=%d)->[%d", n.name, n.Length(), node.id) if level == common.PPL0 { s = fmt.Sprintf("%s]", s) return s @@ -200,38 +169,32 @@ func (n *nodeList) PString(level common.PPLevel) string { curr = curr.GetNext() for curr != nil { - node := curr.(*nameNode) - s = fmt.Sprintf("%s->%d", s, node.Id) + node := curr.(*nameNode[T]) + s = fmt.Sprintf("%s->%d", s, node.id) curr = curr.GetNext() } s = fmt.Sprintf("%s]", s) return s } -type nameNode struct { +type nameNode[T any] struct { common.SSLLNode - Id uint64 - host any + getter func(uint64) *common.GenericDLNode[T] + id uint64 } -func newNameNode(host any, id uint64) *nameNode { - return &nameNode{ - Id: id, +func newNameNode[T any](id uint64, + getter func(uint64) *common.GenericDLNode[T]) *nameNode[T] { + return &nameNode[T]{ SSLLNode: *common.NewSSLLNode(), - host: host, - } -} - -func (n *nameNode) GetDBNode() *common.DLNode { - if n == nil { - return nil + getter: getter, + id: id, } - return n.host.(*Catalog).entries[n.Id] } -func (n *nameNode) GetTableNode() *common.DLNode { +func (n *nameNode[T]) GetNode() *common.GenericDLNode[T] { if n == nil { return nil } - return n.host.(*DBEntry).entries[n.Id] + return n.getter(n.id) } diff --git a/pkg/vm/engine/tae/catalog/node_test.go b/pkg/vm/engine/tae/catalog/node_test.go index b4309a1bc28c99a43ade72bb1826a1c7b29152fb..0b7c9570f5874b8f97b2f16f7382dc3c867cd825 100644 --- a/pkg/vm/engine/tae/catalog/node_test.go +++ b/pkg/vm/engine/tae/catalog/node_test.go @@ -29,9 +29,7 @@ type testNode struct { func newTestNode(val int) *testNode { return &testNode{val: val} } - -func (n *testNode) Compare(o common.NodePayload) int { - on := o.(*testNode) +func compareTestNode(n, on *testNode) int { if n.val > on.val { return 1 } else if n.val < on.val { @@ -41,10 +39,9 @@ func (n *testNode) Compare(o common.NodePayload) int { } func TestDLNode(t *testing.T) { - link := new(common.SortedDList) + link := common.NewGenericSortedDList[*testNode](compareTestNode) now := time.Now() - var node *common.DLNode - // for i := 10; i >= 0; i-- { + var node *common.GenericDLNode[*testNode] nodeCnt := 10 for i := 0; i < nodeCnt; i++ { n := link.Insert(newTestNode(i)) @@ -54,17 +51,17 @@ func TestDLNode(t *testing.T) { } t.Log(time.Since(now)) cnt := 0 - link.Loop(func(node *common.DLNode) bool { + link.Loop(func(n *common.GenericDLNode[*testNode]) bool { cnt++ return true }, true) assert.Equal(t, nodeCnt, cnt) - assert.Equal(t, 5, node.GetPayload().(*testNode).val) + assert.Equal(t, 5, node.GetPayload().val) link.Delete(node) cnt = 0 - link.Loop(func(node *common.DLNode) bool { - t.Logf("%d", node.GetPayload().(*testNode).val) + link.Loop(func(n *common.GenericDLNode[*testNode]) bool { + t.Logf("%d", node.GetPayload().val) cnt++ return true }, true) diff --git a/pkg/vm/engine/tae/catalog/segment.go b/pkg/vm/engine/tae/catalog/segment.go index 9a99ad7c7b2b214d8c5a6ac783411c801b8112dd..77ac24cb897157c84b1a35d472d4ed5c196ac776 100644 --- a/pkg/vm/engine/tae/catalog/segment.go +++ b/pkg/vm/engine/tae/catalog/segment.go @@ -31,11 +31,15 @@ import ( type SegmentDataFactory = func(meta *SegmentEntry) data.Segment +func compareSegmentFn(a, b *SegmentEntry) int { + return a.BaseEntry.DoCompre(b.BaseEntry) +} + type SegmentEntry struct { *BaseEntry table *TableEntry - entries map[uint64]*common.DLNode - link *common.SortedDList + entries map[uint64]*common.GenericDLNode[*BlockEntry] + link *common.GenericSortedDList[*BlockEntry] state EntryState segData data.Segment } @@ -45,8 +49,8 @@ func NewSegmentEntry(table *TableEntry, txn txnif.AsyncTxn, state EntryState, da e := &SegmentEntry{ BaseEntry: NewBaseEntry(id), table: table, - link: new(common.SortedDList), - entries: make(map[uint64]*common.DLNode), + link: common.NewGenericSortedDList(compareBlockFn), + entries: make(map[uint64]*common.GenericDLNode[*BlockEntry]), state: state, } e.CreateWithTxn(txn) @@ -59,8 +63,8 @@ func NewSegmentEntry(table *TableEntry, txn txnif.AsyncTxn, state EntryState, da func NewReplaySegmentEntry() *SegmentEntry { e := &SegmentEntry{ BaseEntry: NewReplayBaseEntry(), - link: new(common.SortedDList), - entries: make(map[uint64]*common.DLNode), + link: common.NewGenericSortedDList(compareBlockFn), + entries: make(map[uint64]*common.GenericDLNode[*BlockEntry]), } return e } @@ -69,8 +73,8 @@ func NewStandaloneSegment(table *TableEntry, id uint64, ts types.TS) *SegmentEnt e := &SegmentEntry{ BaseEntry: NewBaseEntry(id), table: table, - link: new(common.SortedDList), - entries: make(map[uint64]*common.DLNode), + link: common.NewGenericSortedDList(compareBlockFn), + entries: make(map[uint64]*common.GenericDLNode[*BlockEntry]), state: ES_Appendable, } e.CreateWithTS(ts) @@ -81,8 +85,8 @@ func NewSysSegmentEntry(table *TableEntry, id uint64) *SegmentEntry { e := &SegmentEntry{ BaseEntry: NewBaseEntry(id), table: table, - link: new(common.SortedDList), - entries: make(map[uint64]*common.DLNode), + link: common.NewGenericSortedDList(compareBlockFn), + entries: make(map[uint64]*common.GenericDLNode[*BlockEntry]), state: ES_Appendable, } e.CreateWithTS(types.SystemDBTS) @@ -113,7 +117,7 @@ func (entry *SegmentEntry) GetBlockEntryByIDLocked(id uint64) (blk *BlockEntry, err = ErrNotFound return } - blk = node.GetPayload().(*BlockEntry) + blk = node.GetPayload() return } @@ -132,7 +136,7 @@ func (entry *SegmentEntry) PPString(level common.PPLevel, depth int, prefix stri } it := entry.MakeBlockIt(true) for it.Valid() { - block := it.Get().GetPayload().(*BlockEntry) + block := it.Get().GetPayload() block.RLock() _ = w.WriteByte('\n') _, _ = w.WriteString(block.PPString(level, depth+1, prefix)) @@ -169,16 +173,11 @@ func (entry *SegmentEntry) GetTable() *TableEntry { return entry.table } -func (entry *SegmentEntry) Compare(o common.NodePayload) int { - oe := o.(*SegmentEntry).BaseEntry - return entry.DoCompre(oe) -} - func (entry *SegmentEntry) GetAppendableBlockCnt() int { cnt := 0 it := entry.MakeBlockIt(true) for it.Valid() { - if it.Get().GetPayload().(*BlockEntry).IsAppendable() { + if it.Get().GetPayload().IsAppendable() { cnt++ } it.Next() @@ -189,7 +188,7 @@ func (entry *SegmentEntry) GetAppendableBlockCnt() int { func (entry *SegmentEntry) LastAppendableBlock() (blk *BlockEntry) { it := entry.MakeBlockIt(false) for it.Valid() { - itBlk := it.Get().GetPayload().(*BlockEntry) + itBlk := it.Get().GetPayload() if itBlk.IsAppendable() { blk = itBlk break @@ -227,10 +226,10 @@ func (entry *SegmentEntry) DropBlockEntry(id uint64, txn txnif.AsyncTxn) (delete return } -func (entry *SegmentEntry) MakeBlockIt(reverse bool) *common.SortedDListIt { +func (entry *SegmentEntry) MakeBlockIt(reverse bool) *common.GenericSortedDListIt[*BlockEntry] { entry.RLock() defer entry.RUnlock() - return common.NewSortedDListIt(entry.RWMutex, entry.link, reverse) + return common.NewGenericSortedDListIt(entry.RWMutex, entry.link, reverse) } func (entry *SegmentEntry) AddEntryLocked(block *BlockEntry) { @@ -338,7 +337,7 @@ func (entry *SegmentEntry) CollectBlockEntries(commitFilter func(be *BaseEntry) blks := make([]*BlockEntry, 0) blkIt := entry.MakeBlockIt(true) for blkIt.Valid() { - blk := blkIt.Get().GetPayload().(*BlockEntry) + blk := blkIt.Get().GetPayload() blk.RLock() if commitFilter != nil && blockFilter != nil { if commitFilter(blk.BaseEntry) && blockFilter(blk) { diff --git a/pkg/vm/engine/tae/catalog/table.go b/pkg/vm/engine/tae/catalog/table.go index d8cd0bbeab434ee3b6f2afd4170cf47cff53de2a..8d7728ec2aa899b2fcca230c1604c02787369240 100644 --- a/pkg/vm/engine/tae/catalog/table.go +++ b/pkg/vm/engine/tae/catalog/table.go @@ -29,12 +29,18 @@ import ( type TableDataFactory = func(meta *TableEntry) data.Table +func tableTxnCanGetFn[T *TableEntry](n *common.GenericDLNode[*TableEntry], ts types.TS) (can, dropped bool) { + table := n.GetPayload() + can, dropped = table.TxnCanGet(ts) + return +} + type TableEntry struct { *BaseEntry db *DBEntry schema *Schema - entries map[uint64]*common.DLNode - link *common.SortedDList + entries map[uint64]*common.GenericDLNode[*SegmentEntry] + link *common.GenericSortedDList[*SegmentEntry] tableData data.Table rows uint64 // fullname is format as 'tenantID-tableName', the tenantID prefix is only used 'mo_catalog' database @@ -60,8 +66,8 @@ func NewTableEntry(db *DBEntry, schema *Schema, txnCtx txnif.AsyncTxn, dataFacto BaseEntry: NewBaseEntry(id), db: db, schema: schema, - link: new(common.SortedDList), - entries: make(map[uint64]*common.DLNode), + link: common.NewGenericSortedDList(compareSegmentFn), + entries: make(map[uint64]*common.GenericDLNode[*SegmentEntry]), } if dataFactory != nil { e.tableData = dataFactory(e) @@ -75,8 +81,8 @@ func NewSystemTableEntry(db *DBEntry, id uint64, schema *Schema) *TableEntry { BaseEntry: NewBaseEntry(id), db: db, schema: schema, - link: new(common.SortedDList), - entries: make(map[uint64]*common.DLNode), + link: common.NewGenericSortedDList(compareSegmentFn), + entries: make(map[uint64]*common.GenericDLNode[*SegmentEntry]), } e.CreateWithTS(types.SystemDBTS) var sid uint64 @@ -97,8 +103,8 @@ func NewSystemTableEntry(db *DBEntry, id uint64, schema *Schema) *TableEntry { func NewReplayTableEntry() *TableEntry { e := &TableEntry{ BaseEntry: NewReplayBaseEntry(), - link: new(common.SortedDList), - entries: make(map[uint64]*common.DLNode), + link: common.NewGenericSortedDList(compareSegmentFn), + entries: make(map[uint64]*common.GenericDLNode[*SegmentEntry]), } return e } @@ -107,8 +113,8 @@ func MockStaloneTableEntry(id uint64, schema *Schema) *TableEntry { return &TableEntry{ BaseEntry: NewBaseEntry(id), schema: schema, - link: new(common.SortedDList), - entries: make(map[uint64]*common.DLNode), + link: common.NewGenericSortedDList(compareSegmentFn), + entries: make(map[uint64]*common.GenericDLNode[*SegmentEntry]), } } @@ -140,13 +146,13 @@ func (entry *TableEntry) GetSegmentByID(id uint64) (seg *SegmentEntry, err error if node == nil { return nil, ErrNotFound } - return node.GetPayload().(*SegmentEntry), nil + return node.GetPayload(), nil } -func (entry *TableEntry) MakeSegmentIt(reverse bool) *common.SortedDListIt { +func (entry *TableEntry) MakeSegmentIt(reverse bool) *common.GenericSortedDListIt[*SegmentEntry] { entry.RLock() defer entry.RUnlock() - return common.NewSortedDListIt(entry.RWMutex, entry.link, reverse) + return common.NewGenericSortedDListIt(entry.RWMutex, entry.link, reverse) } func (entry *TableEntry) CreateSegment(txn txnif.AsyncTxn, state EntryState, dataFactory SegmentDataFactory) (created *SegmentEntry, err error) { @@ -190,11 +196,6 @@ func (entry *TableEntry) GetFullName() string { return entry.fullName } -func (entry *TableEntry) Compare(o common.NodePayload) int { - oe := o.(*TableEntry).BaseEntry - return entry.DoCompre(oe) -} - func (entry *TableEntry) GetDB() *DBEntry { return entry.db } @@ -207,7 +208,7 @@ func (entry *TableEntry) PPString(level common.PPLevel, depth int, prefix string } it := entry.MakeSegmentIt(true) for it.Valid() { - segment := it.Get().GetPayload().(*SegmentEntry) + segment := it.Get().GetPayload() _ = w.WriteByte('\n') _, _ = w.WriteString(segment.PPString(level, depth+1, prefix)) it.Next() @@ -232,7 +233,7 @@ func (entry *TableEntry) GetTableData() data.Table { return entry.tableData } func (entry *TableEntry) LastAppendableSegmemt() (seg *SegmentEntry) { it := entry.MakeSegmentIt(false) for it.Valid() { - itSeg := it.Get().GetPayload().(*SegmentEntry) + itSeg := it.Get().GetPayload() if itSeg.IsAppendable() { seg = itSeg break @@ -251,7 +252,7 @@ func (entry *TableEntry) AsCommonID() *common.ID { func (entry *TableEntry) RecurLoop(processor Processor) (err error) { segIt := entry.MakeSegmentIt(true) for segIt.Valid() { - segment := segIt.Get().GetPayload().(*SegmentEntry) + segment := segIt.Get().GetPayload() if err = processor.OnSegment(segment); err != nil { if err == ErrStopCurrRecur { err = nil @@ -262,7 +263,7 @@ func (entry *TableEntry) RecurLoop(processor Processor) (err error) { } blkIt := segment.MakeBlockIt(true) for blkIt.Valid() { - block := blkIt.Get().GetPayload().(*BlockEntry) + block := blkIt.Get().GetPayload() if err = processor.OnBlock(block); err != nil { if err == ErrStopCurrRecur { err = nil diff --git a/pkg/vm/engine/tae/catalog/updatenode.go b/pkg/vm/engine/tae/catalog/updatenode.go index dc02e6f4b00e79f51ebf2be2e7647ceab5386fac..d1ee709563d6804a2f957b9bbf11d79e3c58d7e4 100644 --- a/pkg/vm/engine/tae/catalog/updatenode.go +++ b/pkg/vm/engine/tae/catalog/updatenode.go @@ -22,7 +22,6 @@ import ( "io" "github.com/matrixorigin/matrixone/pkg/container/types" - "github.com/matrixorigin/matrixone/pkg/vm/engine/tae/common" "github.com/matrixorigin/matrixone/pkg/vm/engine/tae/iface/txnif" "github.com/matrixorigin/matrixone/pkg/vm/engine/tae/wal" ) @@ -178,7 +177,7 @@ func (e *UpdateNode) ApplyDelete() (err error) { return } -func (e *UpdateNode) DoCompre(o *UpdateNode) int { +func compareUpdateNode(e, o *UpdateNode) int { if e.Start.Less(o.Start) { return -1 } @@ -186,11 +185,7 @@ func (e *UpdateNode) DoCompre(o *UpdateNode) int { return 0 } return 1 -} -func (e *UpdateNode) Compare(o common.NodePayload) int { - oe := o.(*UpdateNode) - return e.DoCompre(oe) } func (e *UpdateNode) AddLogIndex(idx *wal.Index) { if e.LogIndex == nil { diff --git a/pkg/vm/engine/tae/common/dlnode.go b/pkg/vm/engine/tae/common/dlnode.go index 041b4e8c42670da2bbd14f852b7b0319cd82344c..506c950a565ad722a60b2b8db3dbc2d198da3e91 100644 --- a/pkg/vm/engine/tae/common/dlnode.go +++ b/pkg/vm/engine/tae/common/dlnode.go @@ -24,17 +24,16 @@ type Row struct { id int } -func (r *Row) Compare(o *Row) int { - or := o.(*Row) - if r.id > ir.id { +func compare(a, b *Row) int { + if a.id > b.id { return 1 - } else if r.id < ir.id { + } else if a.id < b.id { return -1 } return 0 } -dlist := new(SortedDList) +dlist := NewGenericSortedDList[*Row](compare) n1 := dlist.Insert(&Row{id: 10}) // [10] n2 := dlist.Insert(&Row{id: 5}) // [10]<->[5] n3 := dlist.Insert(&Row{id: 13}) // [13]<->[10]<->[5] @@ -42,7 +41,7 @@ n3.id = 8 dlist.Update(n3) // [10]<->[8]<->[5] dlist.Delete(n1) // [8]<->[5] -it := NewSortedDListIt(nil, dlist,true) +it := NewGenericSortedDListIt(nil, dlist,true) for it.Valid() { n := it.GetPayload() // n.xxx @@ -51,18 +50,25 @@ for it.Valid() { */ // Sorted doubly linked-list -type SortedDList struct { - head *DLNode - tail *DLNode +type GenericSortedDList[T any] struct { + head *GenericDLNode[T] + tail *GenericDLNode[T] + compare func(T, T) int +} + +func NewGenericSortedDList[T any](compare func(T, T) int) *GenericSortedDList[T] { + return &GenericSortedDList[T]{ + compare: compare, + } } // Get the head node -func (l *SortedDList) GetHead() *DLNode { +func (l *GenericSortedDList[T]) GetHead() *GenericDLNode[T] { return l.head } // Get the tail node -func (l *SortedDList) GetTail() *DLNode { +func (l *GenericSortedDList[T]) GetTail() *GenericDLNode[T] { return l.tail } @@ -77,8 +83,8 @@ func (l *SortedDList) GetTail() *DLNode { // --------- UPDATE [10,x10] TO [2, x10]-------------- // // [List] [1,x1] <-> [2,x10] <-> [3,x3] <-> [20,x20] -func (l *SortedDList) Update(n *DLNode) { - nhead, ntail := n.KeepSorted() +func (l *GenericSortedDList[T]) Update(n *GenericDLNode[T]) { + nhead, ntail := n.KeepSorted(l.compare) if nhead != nil { l.head = nhead } @@ -88,26 +94,26 @@ func (l *SortedDList) Update(n *DLNode) { } // Get the length of the list -func (l *SortedDList) Depth() int { +func (l *GenericSortedDList[T]) Depth() int { depth := 0 - l.Loop(func(_ *DLNode) bool { + l.Loop(func(_ *GenericDLNode[T]) bool { depth++ return true }, false) return depth } -// Insert a object and wrap it as DLNode into the list +// Insert a object and wrap it as GenericDLNode into the list // The inserted object must be instance of interface NodePayload // [List]: [1,x1] <-> [5,x5] <-> [10,x10] // Insert a node [7,x7] // [List]: [1,x1] <-> [5,x5] <-> [7,x7] <-> [10,x10] -func (l *SortedDList) Insert(payload NodePayload) *DLNode { +func (l *GenericSortedDList[T]) Insert(payload T) *GenericDLNode[T] { var ( - n *DLNode - tail *DLNode + n *GenericDLNode[T] + tail *GenericDLNode[T] ) - n, l.head, tail = InsertDLNode(payload, l.head) + n, l.head, tail = InsertGenericDLNode(payload, l.head, l.compare) if tail != nil { l.tail = tail } @@ -119,7 +125,7 @@ func (l *SortedDList) Insert(payload NodePayload) *DLNode { // Delete [node] // // [prev node] <-> [node] <-> [next node] =============> [prev node] <-> [next node] -func (l *SortedDList) Delete(n *DLNode) { +func (l *GenericSortedDList[T]) Delete(n *GenericDLNode[T]) { prev := n.prev next := n.next if prev != nil && next != nil { @@ -138,40 +144,31 @@ func (l *SortedDList) Delete(n *DLNode) { } // Loop the list and apply fn on each node -func (l *SortedDList) Loop(fn func(n *DLNode) bool, reverse bool) { +func (l *GenericSortedDList[T]) Loop(fn func(n *GenericDLNode[T]) bool, reverse bool) { if reverse { - LoopSortedDList(l.tail, fn, reverse) + LoopGenericSortedDList[T](l.tail, fn, reverse) } else { - LoopSortedDList(l.head, fn, reverse) + LoopGenericSortedDList[T](l.head, fn, reverse) } } -// wrapped object type by a DLNode -type NodePayload interface { - Compare(NodePayload) int -} - // Doubly sorted linked-list node -type DLNode struct { - prev, next *DLNode - payload NodePayload +type GenericDLNode[T any] struct { + prev, next *GenericDLNode[T] + payload T } -func (l *DLNode) Compare(o *DLNode) int { - return l.payload.Compare(o.payload) -} - -func (l *DLNode) GetPayload() NodePayload { return l.payload } -func (l *DLNode) GetPrev() *DLNode { return l.prev } -func (l *DLNode) GetNext() *DLNode { return l.next } +func (l *GenericDLNode[T]) GetPayload() T { return l.payload } +func (l *GenericDLNode[T]) GetPrev() *GenericDLNode[T] { return l.prev } +func (l *GenericDLNode[T]) GetNext() *GenericDLNode[T] { return l.next } // Keep node be sorted in the list -func (l *DLNode) KeepSorted() (head *DLNode, tail *DLNode) { +func (l *GenericDLNode[T]) KeepSorted(compare func(T, T) int) (head, tail *GenericDLNode[T]) { curr := l head = curr prev := l.prev next := l.next - for (curr != nil && next != nil) && (curr.Compare(next) < 0) { + for (curr != nil && next != nil) && (compare(curr.payload, next.payload) < 0) { if head == curr { head = next } @@ -205,8 +202,10 @@ func (l *DLNode) KeepSorted() (head *DLNode, tail *DLNode) { // nhead is the new head node // ntail is the new tail node. // If ntail is not nil, tail is updated. Else tail is not updated -func InsertDLNode(payload NodePayload, head *DLNode) (node, nhead, ntail *DLNode) { - node = &DLNode{ +func InsertGenericDLNode[T any](payload T, + head *GenericDLNode[T], + compare func(T, T) int) (node, nhead, ntail *GenericDLNode[T]) { + node = &GenericDLNode[T]{ payload: payload, } if head == nil { @@ -217,25 +216,18 @@ func InsertDLNode(payload NodePayload, head *DLNode) (node, nhead, ntail *DLNode node.next = head head.prev = node - nhead, ntail = node.KeepSorted() + nhead, ntail = node.KeepSorted(compare) return } -// Given a node of a dlist list, find the head node -func FindHead(n *DLNode) *DLNode { - head := n - for head.prev != nil { - head = head.prev - } - return head -} - // Loop the list and apply fn on each node // head is the head node of a list // fn is operation applied to each node during iterating. // if fn(node) returns false, stop iterating. // reverse is true to loop in reversed way. -func LoopSortedDList(head *DLNode, fn func(node *DLNode) bool, reverse bool) { +func LoopGenericSortedDList[T any](head *GenericDLNode[T], + fn func(node *GenericDLNode[T]) bool, + reverse bool) { curr := head for curr != nil { goNext := fn(curr) @@ -251,24 +243,26 @@ func LoopSortedDList(head *DLNode, fn func(node *DLNode) bool, reverse bool) { } // Sorted doubly linked-list iterator -type SortedDListIt struct { +type GenericSortedDListIt[T any] struct { linkLocker *sync.RWMutex - curr *DLNode - nextFunc func(*DLNode) *DLNode + curr *GenericDLNode[T] + nextFunc func(*GenericDLNode[T]) *GenericDLNode[T] } // linkLocker is the outer locker to guard dlist access -func NewSortedDListIt(linkLocker *sync.RWMutex, dlist *SortedDList, reverse bool) *SortedDListIt { - it := &SortedDListIt{ +func NewGenericSortedDListIt[T any](linkLocker *sync.RWMutex, + dlist *GenericSortedDList[T], + reverse bool) *GenericSortedDListIt[T] { + it := &GenericSortedDListIt[T]{ linkLocker: linkLocker, } if reverse { - it.nextFunc = func(n *DLNode) *DLNode { + it.nextFunc = func(n *GenericDLNode[T]) *GenericDLNode[T] { return n.prev } it.curr = dlist.tail } else { - it.nextFunc = func(n *DLNode) *DLNode { + it.nextFunc = func(n *GenericDLNode[T]) *GenericDLNode[T] { return n.next } it.curr = dlist.head @@ -276,11 +270,11 @@ func NewSortedDListIt(linkLocker *sync.RWMutex, dlist *SortedDList, reverse bool return it } -func (it *SortedDListIt) Valid() bool { +func (it *GenericSortedDListIt[T]) Valid() bool { return it.curr != nil } -func (it *SortedDListIt) Next() { +func (it *GenericSortedDListIt[T]) Next() { if it.linkLocker == nil { it.curr = it.nextFunc(it.curr) return @@ -290,6 +284,6 @@ func (it *SortedDListIt) Next() { it.linkLocker.RUnlock() } -func (it *SortedDListIt) Get() *DLNode { +func (it *GenericSortedDListIt[T]) Get() *GenericDLNode[T] { return it.curr } diff --git a/pkg/vm/engine/tae/db/gcop.go b/pkg/vm/engine/tae/db/gcop.go index 017352da1f4adfbc0946fb912385f8d9eb077f38..a25a7e09f2710060c4dbed4e146f7f9de9cb308b 100644 --- a/pkg/vm/engine/tae/db/gcop.go +++ b/pkg/vm/engine/tae/db/gcop.go @@ -70,7 +70,7 @@ func gcSegmentClosure(entry *catalog.SegmentEntry, gct GCType) tasks.FuncT { table := entry.GetTable() it := entry.MakeBlockIt(false) for it.Valid() { - blk := it.Get().GetPayload().(*catalog.BlockEntry) + blk := it.Get().GetPayload() scopes = append(scopes, *blk.AsCommonID()) err = gcBlockClosure(blk, gct)() if err != nil { @@ -97,7 +97,7 @@ func gcTableClosure(entry *catalog.TableEntry, gct GCType) tasks.FuncT { dbEntry := entry.GetDB() it := entry.MakeSegmentIt(false) for it.Valid() { - seg := it.Get().GetPayload().(*catalog.SegmentEntry) + seg := it.Get().GetPayload() scopes = append(scopes, *seg.AsCommonID()) if err = gcSegmentClosure(seg, gct)(); err != nil { return @@ -119,7 +119,7 @@ func gcDatabaseClosure(entry *catalog.DBEntry) tasks.FuncT { }() it := entry.MakeTableIt(false) for it.Valid() { - table := it.Get().GetPayload().(*catalog.TableEntry) + table := it.Get().GetPayload() scopes = append(scopes, *table.AsCommonID()) if err = gcTableClosure(table, GCType_DB)(); err != nil { return diff --git a/pkg/vm/engine/tae/iface/txnif/types.go b/pkg/vm/engine/tae/iface/txnif/types.go index 3fb50dd79caba340dc90676c3ac1e73f603e3d8a..c3303ac025f7c7e4586e15e50f12a073de949680 100644 --- a/pkg/vm/engine/tae/iface/txnif/types.go +++ b/pkg/vm/engine/tae/iface/txnif/types.go @@ -128,8 +128,8 @@ type UpdateChain interface { RUnlock() GetID() *common.ID - DeleteNode(*common.DLNode) - DeleteNodeLocked(*common.DLNode) + // DeleteNode(*common.DLNode) + // DeleteNodeLocked(*common.DLNode) AddNode(txn AsyncTxn) UpdateNode AddNodeLocked(txn AsyncTxn) UpdateNode @@ -178,7 +178,7 @@ type UpdateNode interface { GetID() *common.ID String() string GetChain() UpdateChain - GetDLNode() *common.DLNode + // GetDLNode() *common.DLNode GetMask() *roaring.Bitmap GetValues() map[uint32]interface{} diff --git a/pkg/vm/engine/tae/tables/block.go b/pkg/vm/engine/tae/tables/block.go index 65dff518a2e12e28aa9c4076ea8bbf551ca35127..4ecae45419c62ecd5c41f57b49cd914091ff0196 100644 --- a/pkg/vm/engine/tae/tables/block.go +++ b/pkg/vm/engine/tae/tables/block.go @@ -17,10 +17,11 @@ package tables import ( "bytes" "fmt" - "github.com/matrixorigin/matrixone/pkg/container/types" "sync" "sync/atomic" + "github.com/matrixorigin/matrixone/pkg/container/types" + "github.com/RoaringBitmap/roaring" "github.com/matrixorigin/matrixone/pkg/vm/engine/tae/compute" "github.com/matrixorigin/matrixone/pkg/vm/engine/tae/containers" @@ -543,7 +544,7 @@ func (blk *dataBlock) updateWithFineLock( chain.Lock() node = chain.AddNodeLocked(txn) if err = chain.TryUpdateNodeLocked(row, v, node); err != nil { - chain.DeleteNodeLocked(node.GetDLNode()) + chain.DeleteNodeLocked(node.(*updates.ColumnUpdateNode)) } chain.Unlock() } diff --git a/pkg/vm/engine/tae/tables/updates/chain_test.go b/pkg/vm/engine/tae/tables/updates/chain_test.go index 002fd2ddf3dbe6c2f1e2d283305066272a41eb04..c0a63c3a9fd4274843737fc59cda48372f082b35 100644 --- a/pkg/vm/engine/tae/tables/updates/chain_test.go +++ b/pkg/vm/engine/tae/tables/updates/chain_test.go @@ -78,7 +78,7 @@ func TestColumnChain1(t *testing.T) { } t.Log(chain.StringLocked()) assert.Equal(t, cnt1+cnt2+cnt3+cnt4, chain.DepthLocked()) - t.Log(chain.GetHead().GetPayload().(*ColumnUpdateNode).StringLocked()) + t.Log(chain.GetHead().GetPayload().StringLocked()) } func TestColumnChain2(t *testing.T) { @@ -234,7 +234,7 @@ func TestColumnChain3(t *testing.T) { // t.Log(chain.StringLocked()) assert.Equal(t, ncnt, chain.DepthLocked()) - node := chain.GetHead().GetPayload().(*ColumnUpdateNode) + node := chain.GetHead().GetPayload() cmd, err := node.MakeCommand(1) assert.Nil(t, err) defer cmd.Close() diff --git a/pkg/vm/engine/tae/tables/updates/colupdate.go b/pkg/vm/engine/tae/tables/updates/colupdate.go index d7325800bca669b476897bb1a056b7346854b1e3..3ebefd1c49ab61016cae31c7d365e0f10e07bc38 100644 --- a/pkg/vm/engine/tae/tables/updates/colupdate.go +++ b/pkg/vm/engine/tae/tables/updates/colupdate.go @@ -31,8 +31,29 @@ import ( "github.com/matrixorigin/matrixone/pkg/vm/engine/tae/wal" ) +func compareUpdateNode(a, b *ColumnUpdateNode) int { + a.RLock() + defer a.RUnlock() + b.RLock() + defer b.RUnlock() + if a.commitTs == b.commitTs { + if a.startTs.Less(b.startTs) { + return -1 + } else if a.startTs.Greater(b.startTs) { + return 1 + } + return 0 + } + if a.commitTs == txnif.UncommitTS { + return 1 + } else if b.commitTs == txnif.UncommitTS { + return -1 + } + return 0 +} + type ColumnUpdateNode struct { - *common.DLNode + *common.GenericDLNode[*ColumnUpdateNode] *sync.RWMutex mask *roaring.Bitmap vals map[uint32]any @@ -92,7 +113,7 @@ func (node *ColumnUpdateNode) MakeCommand(id uint32) (cmd txnif.TxnCmd, err erro func (node *ColumnUpdateNode) AttachTo(chain *ColumnChain) { node.chain = chain - node.DLNode = chain.Insert(node) + node.GenericDLNode = chain.Insert(node) } func (node *ColumnUpdateNode) GetID() *common.ID { @@ -107,9 +128,9 @@ func (node *ColumnUpdateNode) GetChain() txnif.UpdateChain { return node.chain } -func (node *ColumnUpdateNode) GetDLNode() *common.DLNode { - return node.DLNode -} +// func (node *ColumnUpdateNode) GetDLNode() *common.DLNode { +// return node.GenericDLNode +// } func (node *ColumnUpdateNode) SetMask(mask *roaring.Bitmap) { node.mask = mask } @@ -120,27 +141,6 @@ func (node *ColumnUpdateNode) SetValues(vals map[uint32]any) { node.vals = vals func (node *ColumnUpdateNode) GetValues() map[uint32]any { return node.vals } -func (node *ColumnUpdateNode) Compare(o common.NodePayload) int { - op := o.(*ColumnUpdateNode) - node.RLock() - defer node.RUnlock() - op.RLock() - defer op.RUnlock() - if node.commitTs == op.commitTs { - if node.startTs.Less(op.startTs) { - return -1 - } else if node.startTs.Greater(op.startTs) { - return 1 - } - return 0 - } - if node.commitTs == txnif.UncommitTS { - return 1 - } else if op.commitTs == txnif.UncommitTS { - return -1 - } - return 0 -} func (node *ColumnUpdateNode) GetValueLocked(row uint32) (v any, err error) { v = node.vals[row] @@ -377,7 +377,7 @@ func (node *ColumnUpdateNode) ApplyCommit(index *wal.Index) (err error) { } func (node *ColumnUpdateNode) PrepareRollback() (err error) { - node.chain.DeleteNode(node.DLNode) + node.chain.DeleteNode(node) return } diff --git a/pkg/vm/engine/tae/tables/updates/delchain.go b/pkg/vm/engine/tae/tables/updates/delchain.go index d1a0c19e985c443a3a5cf7bdfcba321530bf8c4e..3fdf49a39e3c4d4247e29e2c71fd808bf0cd96e8 100644 --- a/pkg/vm/engine/tae/tables/updates/delchain.go +++ b/pkg/vm/engine/tae/tables/updates/delchain.go @@ -32,7 +32,7 @@ import ( type DeleteChain struct { *sync.RWMutex - *common.SortedDList + *common.GenericSortedDList[*DeleteNode] mvcc *MVCCHandle cnt uint32 } @@ -42,9 +42,9 @@ func NewDeleteChain(rwlocker *sync.RWMutex, mvcc *MVCCHandle) *DeleteChain { rwlocker = new(sync.RWMutex) } chain := &DeleteChain{ - RWMutex: rwlocker, - SortedDList: new(common.SortedDList), - mvcc: mvcc, + RWMutex: rwlocker, + GenericSortedDList: common.NewGenericSortedDList(compareDeleteNode), + mvcc: mvcc, } return chain } @@ -73,8 +73,8 @@ func (chain *DeleteChain) StringLocked() string { func (chain *DeleteChain) GetController() *MVCCHandle { return chain.mvcc } func (chain *DeleteChain) LoopChainLocked(fn func(node *DeleteNode) bool, reverse bool) { - wrapped := func(n *common.DLNode) bool { - dnode := n.GetPayload().(*DeleteNode) + wrapped := func(n *common.GenericDLNode[*DeleteNode]) bool { + dnode := n.GetPayload() return fn(dnode) } chain.Loop(wrapped, reverse) @@ -144,14 +144,14 @@ func (chain *DeleteChain) PrepareRangeDelete(start, end uint32, ts types.TS) (er } func (chain *DeleteChain) UpdateLocked(node *DeleteNode) { - chain.Update(node.DLNode) + chain.Update(node.GenericDLNode) } func (chain *DeleteChain) RemoveNodeLocked(node txnif.DeleteNode) { - chain.Delete(node.(*DeleteNode).DLNode) + chain.Delete(node.(*DeleteNode).GenericDLNode) } -func (chain *DeleteChain) DepthLocked() int { return chain.SortedDList.Depth() } +func (chain *DeleteChain) DepthLocked() int { return chain.GenericSortedDList.Depth() } func (chain *DeleteChain) AddNodeLocked(txn txnif.AsyncTxn, deleteType handle.DeleteType) txnif.DeleteNode { node := NewDeleteNode(txn, deleteType) diff --git a/pkg/vm/engine/tae/tables/updates/delete.go b/pkg/vm/engine/tae/tables/updates/delete.go index 3e5dc90cd0eda37730912c7924447bd9968e719c..d129bbad92e4402151c24cbf715950a18961754c 100644 --- a/pkg/vm/engine/tae/tables/updates/delete.go +++ b/pkg/vm/engine/tae/tables/updates/delete.go @@ -17,10 +17,11 @@ package updates import ( "encoding/binary" "fmt" - "github.com/matrixorigin/matrixone/pkg/container/types" "io" "sync" + "github.com/matrixorigin/matrixone/pkg/container/types" + "github.com/RoaringBitmap/roaring" "github.com/matrixorigin/matrixone/pkg/vm/engine/tae/common" "github.com/matrixorigin/matrixone/pkg/vm/engine/tae/iface/handle" @@ -36,9 +37,30 @@ const ( NT_Merge ) +func compareDeleteNode(a, b *DeleteNode) int { + a.RLock() + defer a.RUnlock() + b.RLock() + defer b.RUnlock() + if a.commitTs == b.commitTs { + if a.startTs.Less(b.startTs) { + return -1 + } else if a.startTs.Greater(b.startTs) { + return 1 + } + return 0 + } + if a.commitTs.Equal(txnif.UncommitTS) { + return 1 + } else if b.commitTs.Equal(txnif.UncommitTS) { + return -1 + } + return 0 +} + type DeleteNode struct { *sync.RWMutex - *common.DLNode + *common.GenericDLNode[*DeleteNode] chain *DeleteChain txn txnif.AsyncTxn logIndex *wal.Index @@ -96,29 +118,7 @@ func (node *DeleteNode) AddLogIndexLocked(index *wal.Index) { func (node *DeleteNode) IsMerged() bool { return node.nt == NT_Merge } func (node *DeleteNode) AttachTo(chain *DeleteChain) { node.chain = chain - node.DLNode = chain.Insert(node) -} - -func (node *DeleteNode) Compare(o common.NodePayload) int { - op := o.(*DeleteNode) - node.RLock() - defer node.RUnlock() - op.RLock() - defer op.RUnlock() - if node.commitTs == op.commitTs { - if node.startTs.Less(op.startTs) { - return -1 - } else if node.startTs.Greater(op.startTs) { - return 1 - } - return 0 - } - if node.commitTs.Equal(txnif.UncommitTS) { - return 1 - } else if op.commitTs.Equal(txnif.UncommitTS) { - return -1 - } - return 0 + node.GenericDLNode = chain.Insert(node) } func (node *DeleteNode) GetChain() txnif.DeleteChain { return node.chain } diff --git a/pkg/vm/engine/tae/tables/updates/mvcc.go b/pkg/vm/engine/tae/tables/updates/mvcc.go index 9905dabd776c9f9700f16aae2bd12a971f67e936..5b031802cef7ba7e615b18709858b9e20800c61f 100644 --- a/pkg/vm/engine/tae/tables/updates/mvcc.go +++ b/pkg/vm/engine/tae/tables/updates/mvcc.go @@ -16,10 +16,11 @@ package updates import ( "fmt" - "github.com/matrixorigin/matrixone/pkg/container/types" "sync" "sync/atomic" + "github.com/matrixorigin/matrixone/pkg/container/types" + "github.com/RoaringBitmap/roaring" "github.com/matrixorigin/matrixone/pkg/vm/engine/tae/catalog" "github.com/matrixorigin/matrixone/pkg/vm/engine/tae/common" @@ -181,7 +182,7 @@ func (n *MVCCHandle) CreateUpdateNode(colIdx uint16, txn txnif.AsyncTxn) txnif.U func (n *MVCCHandle) DropUpdateNode(colIdx uint16, node txnif.UpdateNode) { chain := n.columns[colIdx] - chain.DeleteNodeLocked(node.GetDLNode()) + chain.DeleteNodeLocked(node.(*ColumnUpdateNode)) } func (n *MVCCHandle) PrepareUpdate(row uint32, colIdx uint16, update txnif.UpdateNode) error { diff --git a/pkg/vm/engine/tae/tables/updates/updatechain.go b/pkg/vm/engine/tae/tables/updates/updatechain.go index 274e44952e213acdfb6efb19daf1b3185941ec08..18c00b973c62c8193026d31a05b09ba547882538 100644 --- a/pkg/vm/engine/tae/tables/updates/updatechain.go +++ b/pkg/vm/engine/tae/tables/updates/updatechain.go @@ -28,7 +28,7 @@ import ( ) type ColumnChain struct { - *common.SortedDList + *common.GenericSortedDList[*ColumnUpdateNode] *sync.RWMutex id *common.ID view *ColumnView @@ -38,9 +38,9 @@ type ColumnChain struct { func MockColumnUpdateChain() *ColumnChain { chain := &ColumnChain{ - SortedDList: new(common.SortedDList), - RWMutex: new(sync.RWMutex), - id: &common.ID{}, + GenericSortedDList: common.NewGenericSortedDList(compareUpdateNode), + RWMutex: new(sync.RWMutex), + id: &common.ID{}, } chain.view = NewColumnView() return chain @@ -53,10 +53,10 @@ func NewColumnChain(rwlocker *sync.RWMutex, colIdx uint16, mvcc *MVCCHandle) *Co id := *mvcc.GetID() id.Idx = colIdx chain := &ColumnChain{ - SortedDList: new(common.SortedDList), - RWMutex: rwlocker, - mvcc: mvcc, - id: &id, + GenericSortedDList: common.NewGenericSortedDList(compareUpdateNode), + RWMutex: rwlocker, + mvcc: mvcc, + id: &id, } chain.view = NewColumnView() return chain @@ -110,18 +110,17 @@ func (chain *ColumnChain) AddNodeLocked(txn txnif.AsyncTxn) txnif.UpdateNode { return node } -func (chain *ColumnChain) DeleteNode(node *common.DLNode) { +func (chain *ColumnChain) DeleteNode(n *ColumnUpdateNode) { chain.Lock() defer chain.Unlock() - chain.DeleteNodeLocked(node) + chain.DeleteNodeLocked(n) } -func (chain *ColumnChain) DeleteNodeLocked(node *common.DLNode) { - n := node.GetPayload().(*ColumnUpdateNode) +func (chain *ColumnChain) DeleteNodeLocked(n *ColumnUpdateNode) { for row := range n.vals { _ = chain.view.Delete(row, n) } - chain.Delete(node) + chain.Delete(n.GenericDLNode) chain.SetUpdateCnt(uint32(chain.view.mask.GetCardinality())) } @@ -134,8 +133,8 @@ func (chain *ColumnChain) AddNode(txn txnif.AsyncTxn) txnif.UpdateNode { } func (chain *ColumnChain) LoopChainLocked(fn func(col *ColumnUpdateNode) bool, reverse bool) { - wrapped := func(node *common.DLNode) bool { - col := node.GetPayload().(*ColumnUpdateNode) + wrapped := func(node *common.GenericDLNode[*ColumnUpdateNode]) bool { + col := node.GetPayload() return fn(col) } chain.Loop(wrapped, reverse) @@ -159,7 +158,7 @@ func (chain *ColumnChain) PrepareUpdate(row uint32, n txnif.UpdateNode) error { } func (chain *ColumnChain) UpdateLocked(node *ColumnUpdateNode) { - chain.Update(node.DLNode) + chain.Update(node.GenericDLNode) } func (chain *ColumnChain) StringLocked() string { diff --git a/pkg/vm/engine/tae/tables/updates/updateview.go b/pkg/vm/engine/tae/tables/updates/updateview.go index 595e0757efbee5c03741d0688006d4a3f2370911..28dad4bef74353b45cad9601c818f597cee8095b 100644 --- a/pkg/vm/engine/tae/tables/updates/updateview.go +++ b/pkg/vm/engine/tae/tables/updates/updateview.go @@ -27,14 +27,14 @@ import ( ) type ColumnView struct { - links map[uint32]*common.SortedDList + links map[uint32]*common.GenericSortedDList[*ColumnUpdateNode] mask *roaring.Bitmap } func NewColumnView() *ColumnView { // func NewColumnView(chain *ColumnChain) *ColumnView { return &ColumnView{ - links: make(map[uint32]*common.SortedDList), + links: make(map[uint32]*common.GenericSortedDList[*ColumnUpdateNode]), mask: roaring.New(), } } @@ -69,7 +69,7 @@ func (view *ColumnView) GetValue(key uint32, startTs types.TS) (v any, err error } head := link.GetHead() for head != nil { - node := head.GetPayload().(*ColumnUpdateNode) + node := head.GetPayload() if node.GetStartTS().Less(startTs) { node.RLock() // | @@ -130,12 +130,12 @@ func (view *ColumnView) GetValue(key uint32, startTs types.TS) (v any, err error func (view *ColumnView) PrepapreInsert(key uint32, ts types.TS) (err error) { // First update to key - var link *common.SortedDList + var link *common.GenericSortedDList[*ColumnUpdateNode] if link = view.links[key]; link == nil { return } - node := link.GetHead().GetPayload().(*ColumnUpdateNode) + node := link.GetHead().GetPayload() node.RLock() // 1. The specified row has committed update if node.txn == nil { @@ -164,16 +164,16 @@ func (view *ColumnView) PrepapreInsert(key uint32, ts types.TS) (err error) { func (view *ColumnView) Insert(key uint32, un txnif.UpdateNode) (err error) { n := un.(*ColumnUpdateNode) // First update to key - var link *common.SortedDList + var link *common.GenericSortedDList[*ColumnUpdateNode] if link = view.links[key]; link == nil { - link = new(common.SortedDList) + link = common.NewGenericSortedDList[*ColumnUpdateNode](compareUpdateNode) link.Insert(n) view.mask.Add(key) view.links[key] = link return } - node := link.GetHead().GetPayload().(*ColumnUpdateNode) + node := link.GetHead().GetPayload() node.RLock() // 1. The specified row has committed update if node.txn == nil { @@ -203,9 +203,9 @@ func (view *ColumnView) Insert(key uint32, un txnif.UpdateNode) (err error) { func (view *ColumnView) Delete(key uint32, n *ColumnUpdateNode) (err error) { link := view.links[key] - var target *common.DLNode - link.Loop(func(dlnode *common.DLNode) bool { - node := dlnode.GetPayload().(*ColumnUpdateNode) + var target *common.GenericDLNode[*ColumnUpdateNode] + link.Loop(func(dlnode *common.GenericDLNode[*ColumnUpdateNode]) bool { + node := dlnode.GetPayload() if node.GetStartTS() == n.GetStartTS() { target = dlnode return false @@ -222,10 +222,11 @@ func (view *ColumnView) Delete(key uint32, n *ColumnUpdateNode) (err error) { return } -func (view *ColumnView) RowStringLocked(row uint32, link *common.SortedDList) string { +func (view *ColumnView) RowStringLocked(row uint32, + link *common.GenericSortedDList[*ColumnUpdateNode]) string { s := fmt.Sprintf("[ROW=%d]:", row) - link.Loop(func(dlnode *common.DLNode) bool { - n := dlnode.GetPayload().(*ColumnUpdateNode) + link.Loop(func(dlnode *common.GenericDLNode[*ColumnUpdateNode]) bool { + n := dlnode.GetPayload() n.RLock() s = fmt.Sprintf("%s\n%s", s, n.StringLocked()) n.RUnlock() diff --git a/pkg/vm/engine/tae/txn/txnimpl/block.go b/pkg/vm/engine/tae/txn/txnimpl/block.go index ae1113f54524e4729af616d62bfc8d5594876a34..6be1409278661b3a572a15733d28ae60f10b60df 100644 --- a/pkg/vm/engine/tae/txn/txnimpl/block.go +++ b/pkg/vm/engine/tae/txn/txnimpl/block.go @@ -37,7 +37,7 @@ type txnBlock struct { type blockIt struct { sync.RWMutex - linkIt *common.SortedDListIt + linkIt *common.GenericSortedDListIt[*catalog.BlockEntry] curr *catalog.BlockEntry table *txnTable err error @@ -59,7 +59,7 @@ func newBlockIt(table *txnTable, meta *catalog.SegmentEntry) *blockIt { var ok bool var err error for it.linkIt.Valid() { - curr := it.linkIt.Get().GetPayload().(*catalog.BlockEntry) + curr := it.linkIt.Get().GetPayload() curr.RLock() ok, err = curr.TxnCanRead(it.table.store.txn, curr.RWMutex) if err != nil { @@ -97,7 +97,7 @@ func (it *blockIt) Next() { it.curr = nil break } - entry := node.GetPayload().(*catalog.BlockEntry) + entry := node.GetPayload() entry.RLock() valid, err = entry.TxnCanRead(it.table.store.txn, entry.RWMutex) entry.RUnlock() diff --git a/pkg/vm/engine/tae/txn/txnimpl/database.go b/pkg/vm/engine/tae/txn/txnimpl/database.go index ae63ae69eacab822d6f7d8e1e8862369c7740373..881746a136cbc0a9791a369eeefe583c6a5e444e 100644 --- a/pkg/vm/engine/tae/txn/txnimpl/database.go +++ b/pkg/vm/engine/tae/txn/txnimpl/database.go @@ -28,7 +28,7 @@ import ( type txnDBIt struct { *sync.RWMutex txn txnif.AsyncTxn - linkIt *common.SortedDListIt + linkIt *common.GenericSortedDListIt[*catalog.DBEntry] itered bool // linkIt has no dummy head, use this to avoid duplicate filter logic for the very first entry curr *catalog.DBEntry err error @@ -66,7 +66,7 @@ func (it *txnDBIt) Next() { it.curr = nil break } - curr := node.GetPayload().(*catalog.DBEntry) + curr := node.GetPayload() curr.RLock() if curr.GetTenantID() == it.txn.GetTenantID() || isSysSharedDB(curr.GetName()) { valid, err = curr.TxnCanRead(it.txn, curr.RWMutex) diff --git a/pkg/vm/engine/tae/txn/txnimpl/relation.go b/pkg/vm/engine/tae/txn/txnimpl/relation.go index 29311f79645d2f5a9211ef106bc1e76febc002d4..858916f383b102de918fbb07ee16373e0ecc4587 100644 --- a/pkg/vm/engine/tae/txn/txnimpl/relation.go +++ b/pkg/vm/engine/tae/txn/txnimpl/relation.go @@ -32,7 +32,7 @@ var _ handle.RelationIt = (*txnRelationIt)(nil) type txnRelationIt struct { *sync.RWMutex txnDB *txnDB - linkIt *common.SortedDListIt + linkIt *common.GenericSortedDListIt[*catalog.TableEntry] itered bool // linkIt has no dummy head, use this to avoid duplicate filter logic for the very first entry curr *catalog.TableEntry err error @@ -72,7 +72,7 @@ func (it *txnRelationIt) Next() { it.curr = nil break } - entry := node.GetPayload().(*catalog.TableEntry) + entry := node.GetPayload() entry.RLock() // SystemDB can hold table created by different tenant, filter needed. // while the 3 shared tables are not affected diff --git a/pkg/vm/engine/tae/txn/txnimpl/segment.go b/pkg/vm/engine/tae/txn/txnimpl/segment.go index 82dbde4036364caeffd3bd210faead6c496bd971..c1250b9242269ec1cbbb2b6af0ef9c8685ac553c 100644 --- a/pkg/vm/engine/tae/txn/txnimpl/segment.go +++ b/pkg/vm/engine/tae/txn/txnimpl/segment.go @@ -32,7 +32,7 @@ type txnSegment struct { type segmentIt struct { sync.RWMutex - linkIt *common.SortedDListIt + linkIt *common.GenericSortedDListIt[*catalog.SegmentEntry] curr *catalog.SegmentEntry table *txnTable err error @@ -51,7 +51,7 @@ func newSegmentIt(table *txnTable) handle.SegmentIt { var err error var ok bool for it.linkIt.Valid() { - curr := it.linkIt.Get().GetPayload().(*catalog.SegmentEntry) + curr := it.linkIt.Get().GetPayload() curr.RLock() ok, err = curr.TxnCanRead(it.table.store.txn, curr.RWMutex) if err != nil { @@ -97,7 +97,7 @@ func (it *segmentIt) Next() { it.curr = nil break } - entry := node.GetPayload().(*catalog.SegmentEntry) + entry := node.GetPayload() entry.RLock() valid, err = entry.TxnCanRead(it.table.store.txn, entry.RWMutex) entry.RUnlock() diff --git a/pkg/vm/engine/tae/txn/txnimpl/table.go b/pkg/vm/engine/tae/txn/txnimpl/table.go index b053d9ee13bee88ba1d73646eafef3157a50582e..81e649b75158e68c94096ae7844e6c0a1a20b53f 100644 --- a/pkg/vm/engine/tae/txn/txnimpl/table.go +++ b/pkg/vm/engine/tae/txn/txnimpl/table.go @@ -554,7 +554,7 @@ func (tbl *txnTable) PreCommitOr2PCPrepareDedup() (err error) { func (tbl *txnTable) DoDedup(pks containers.Vector, preCommit bool) (err error) { segIt := tbl.entry.MakeSegmentIt(false) for segIt.Valid() { - seg := segIt.Get().GetPayload().(*catalog.SegmentEntry) + seg := segIt.Get().GetPayload() if preCommit && seg.GetID() < tbl.maxSegId { return } @@ -585,7 +585,7 @@ func (tbl *txnTable) DoDedup(pks containers.Vector, preCommit bool) (err error) err = nil blkIt := seg.MakeBlockIt(false) for blkIt.Valid() { - blk := blkIt.Get().GetPayload().(*catalog.BlockEntry) + blk := blkIt.Get().GetPayload() if preCommit && blk.GetID() < tbl.maxBlkId { return }