diff --git a/pkg/vm/engine/tae/tables/updates/chain_test.go b/pkg/vm/engine/tae/tables/updates/chain_test.go
index efa2777b106141cdb4c519d079e8f2a56da3bb0d..002fd2ddf3dbe6c2f1e2d283305066272a41eb04 100644
--- a/pkg/vm/engine/tae/tables/updates/chain_test.go
+++ b/pkg/vm/engine/tae/tables/updates/chain_test.go
@@ -16,11 +16,12 @@ package updates
 
 import (
 	"bytes"
-	"github.com/matrixorigin/matrixone/pkg/container/types"
 	"sync"
 	"testing"
 	"time"
 
+	"github.com/matrixorigin/matrixone/pkg/container/types"
+
 	"github.com/matrixorigin/matrixone/pkg/vm/engine/tae/catalog"
 	"github.com/matrixorigin/matrixone/pkg/vm/engine/tae/common"
 	"github.com/matrixorigin/matrixone/pkg/vm/engine/tae/iface/handle"
@@ -37,7 +38,7 @@ const (
 
 func mockTxn() *txnbase.Txn {
 	txn := new(txnbase.Txn)
-	txn.TxnCtx = txnbase.NewTxnCtx(nil, common.NextGlobalSeqNum(), types.NextGlobalTsForTest(), nil)
+	txn.TxnCtx = txnbase.NewTxnCtx(common.NextGlobalSeqNum(), types.NextGlobalTsForTest(), nil)
 	return txn
 }
 
@@ -66,7 +67,7 @@ func TestColumnChain1(t *testing.T) {
 	cnt4 := 5
 	for i := 0; i < cnt1+cnt2+cnt3+cnt4; i++ {
 		txn := new(txnbase.Txn)
-		txn.TxnCtx = txnbase.NewTxnCtx(nil, common.NextGlobalSeqNum(),
+		txn.TxnCtx = txnbase.NewTxnCtx(common.NextGlobalSeqNum(),
 			types.NextGlobalTsForTest(), nil)
 		n := chain.AddNode(txn)
 		if (i >= cnt1 && i < cnt1+cnt2) || (i >= cnt1+cnt2+cnt3) {
@@ -95,7 +96,7 @@ func TestColumnChain2(t *testing.T) {
 	controller := NewMVCCHandle(blk)
 	chain := NewColumnChain(nil, 0, controller)
 	txn1 := new(txnbase.Txn)
-	txn1.TxnCtx = txnbase.NewTxnCtx(nil, common.NextGlobalSeqNum(), types.NextGlobalTsForTest(), nil)
+	txn1.TxnCtx = txnbase.NewTxnCtx(common.NextGlobalSeqNum(), types.NextGlobalTsForTest(), nil)
 	n1 := chain.AddNode(txn1)
 
 	err := chain.TryUpdateNodeLocked(1, int32(11), n1)
@@ -109,7 +110,7 @@ func TestColumnChain2(t *testing.T) {
 	assert.Equal(t, 3, chain.view.RowCnt())
 
 	txn2 := new(txnbase.Txn)
-	txn2.TxnCtx = txnbase.NewTxnCtx(nil, common.NextGlobalSeqNum(),
+	txn2.TxnCtx = txnbase.NewTxnCtx(common.NextGlobalSeqNum(),
 		types.NextGlobalTsForTest(), nil)
 	n2 := chain.AddNode(txn2)
 	err = chain.TryUpdateNodeLocked(2, int32(222), n2)
@@ -131,7 +132,7 @@ func TestColumnChain2(t *testing.T) {
 	assert.Equal(t, 1, chain.view.links[4].Depth())
 
 	txn3 := new(txnbase.Txn)
-	txn3.TxnCtx = txnbase.NewTxnCtx(nil, common.NextGlobalSeqNum(),
+	txn3.TxnCtx = txnbase.NewTxnCtx(common.NextGlobalSeqNum(),
 		types.NextGlobalTsForTest(), nil)
 	n3 := chain.AddNode(txn3)
 	err = chain.TryUpdateNodeLocked(2, int32(2222), n3)
@@ -143,7 +144,7 @@ func TestColumnChain2(t *testing.T) {
 		return func() {
 			defer wg.Done()
 			txn := new(txnbase.Txn)
-			txn.TxnCtx = txnbase.NewTxnCtx(nil, common.NextGlobalSeqNum(),
+			txn.TxnCtx = txnbase.NewTxnCtx(common.NextGlobalSeqNum(),
 				types.NextGlobalTsForTest(), nil)
 			n := chain.AddNode(txn)
 			for j := 0; j < 4; j++ {
@@ -376,7 +377,7 @@ func TestDeleteChain1(t *testing.T) {
 	controller := NewMVCCHandle(blk)
 	chain := NewDeleteChain(nil, controller)
 	txn1 := new(txnbase.Txn)
-	txn1.TxnCtx = txnbase.NewTxnCtx(nil, common.NextGlobalSeqNum(), types.NextGlobalTsForTest(), nil)
+	txn1.TxnCtx = txnbase.NewTxnCtx(common.NextGlobalSeqNum(), types.NextGlobalTsForTest(), nil)
 	n1 := chain.AddNodeLocked(txn1, handle.DeleteType(handle.DT_Normal)).(*DeleteNode)
 	assert.Equal(t, 1, chain.Depth())
 
diff --git a/pkg/vm/engine/tae/txn/txnbase/txn.go b/pkg/vm/engine/tae/txn/txnbase/txn.go
index 9a54be744d1fd1324b33c23a0e41d77145755772..fff6a5e409e7620f95ef44601e41d76b48765f9d 100644
--- a/pkg/vm/engine/tae/txn/txnbase/txn.go
+++ b/pkg/vm/engine/tae/txn/txnbase/txn.go
@@ -16,9 +16,10 @@ package txnbase
 
 import (
 	"fmt"
-	"github.com/matrixorigin/matrixone/pkg/container/types"
 	"sync"
 
+	"github.com/matrixorigin/matrixone/pkg/container/types"
+
 	"github.com/matrixorigin/matrixone/pkg/logutil"
 	"github.com/matrixorigin/matrixone/pkg/vm/engine/tae/common"
 	"github.com/matrixorigin/matrixone/pkg/vm/engine/tae/iface/handle"
@@ -50,14 +51,12 @@ var DefaultTxnFactory = func(mgr *TxnManager, store txnif.TxnStore, id uint64, s
 }
 
 type Txn struct {
-	sync.RWMutex
 	sync.WaitGroup
 	*TxnCtx
-	Mgr      *TxnManager
-	Store    txnif.TxnStore
-	Err      error
-	DoneCond sync.Cond
-	LSN      uint64
+	Mgr   *TxnManager
+	Store txnif.TxnStore
+	Err   error
+	LSN   uint64
 
 	PrepareCommitFn   func(txnif.AsyncTxn) error
 	PrepareRollbackFn func(txnif.AsyncTxn) error
@@ -70,8 +69,7 @@ func NewTxn(mgr *TxnManager, store txnif.TxnStore, txnId uint64, start types.TS,
 		Mgr:   mgr,
 		Store: store,
 	}
-	txn.TxnCtx = NewTxnCtx(&txn.RWMutex, txnId, start, info)
-	txn.DoneCond = *sync.NewCond(txn)
+	txn.TxnCtx = NewTxnCtx(txnId, start, info)
 	return txn
 }
 
@@ -164,30 +162,6 @@ func (txn *Txn) IsTerminated(waitIfcommitting bool) bool {
 	return state == txnif.TxnStateCommitted || state == txnif.TxnStateRollbacked
 }
 
-func (txn *Txn) GetTxnState(waitIfcommitting bool) txnif.TxnState {
-	txn.RLock()
-	state := txn.State
-	if !waitIfcommitting {
-		txn.RUnlock()
-		return state
-	}
-	if state != txnif.TxnStateCommitting {
-		txn.RUnlock()
-		return state
-	}
-	txn.RUnlock()
-	txn.DoneCond.L.Lock()
-	state = txn.State
-	if state != txnif.TxnStateCommitting {
-		txn.DoneCond.L.Unlock()
-		return state
-	}
-	txn.DoneCond.Wait()
-	state = txn.State
-	txn.DoneCond.L.Unlock()
-	return state
-}
-
 func (txn *Txn) PrepareCommit() (err error) {
 	logutil.Debugf("Prepare Committing %d", txn.ID)
 	if txn.PrepareCommitFn != nil {
diff --git a/pkg/vm/engine/tae/txn/txnbase/txnctx.go b/pkg/vm/engine/tae/txn/txnbase/txnctx.go
index 7f009aeae411b0b9893093ade2ed4b588467d801..aa22f3d6fff6aac5be81b017c901255f1b052d72 100644
--- a/pkg/vm/engine/tae/txn/txnbase/txnctx.go
+++ b/pkg/vm/engine/tae/txn/txnbase/txnctx.go
@@ -17,9 +17,10 @@ package txnbase
 import (
 	"encoding/binary"
 	"fmt"
-	"github.com/matrixorigin/matrixone/pkg/container/types"
 	"sync"
 
+	"github.com/matrixorigin/matrixone/pkg/container/types"
+
 	"github.com/matrixorigin/matrixone/pkg/vm/engine/tae/iface/txnif"
 )
 
@@ -34,7 +35,8 @@ func IDCtxToID(buf []byte) uint64 {
 }
 
 type TxnCtx struct {
-	*sync.RWMutex
+	sync.RWMutex
+	DoneCond          sync.Cond
 	ID                uint64
 	IDCtx             []byte
 	StartTS, CommitTS types.TS
@@ -42,18 +44,16 @@ type TxnCtx struct {
 	State             txnif.TxnState
 }
 
-func NewTxnCtx(rwlocker *sync.RWMutex, id uint64, start types.TS, info []byte) *TxnCtx {
-	if rwlocker == nil {
-		rwlocker = new(sync.RWMutex)
-	}
-	return &TxnCtx{
+func NewTxnCtx(id uint64, start types.TS, info []byte) *TxnCtx {
+	ctx := &TxnCtx{
 		ID:       id,
 		IDCtx:    IDToIDCtx(id),
-		RWMutex:  rwlocker,
 		StartTS:  start,
 		CommitTS: txnif.UncommitTS,
 		Info:     info,
 	}
+	ctx.DoneCond = *sync.NewCond(ctx)
+	return ctx
 }
 
 func (ctx *TxnCtx) GetCtx() []byte {
@@ -63,7 +63,7 @@ func (ctx *TxnCtx) GetCtx() []byte {
 func (ctx *TxnCtx) Repr() string {
 	ctx.RLock()
 	defer ctx.RUnlock()
-	repr := fmt.Sprintf("Txn[%d][%d->%d][%s]", ctx.ID, ctx.StartTS, ctx.CommitTS, txnif.TxnStrState(ctx.State))
+	repr := fmt.Sprintf("ctx[%d][%d->%d][%s]", ctx.ID, ctx.StartTS, ctx.CommitTS, txnif.TxnStrState(ctx.State))
 	return repr
 }
 
@@ -85,6 +85,44 @@ func (ctx *TxnCtx) GetCommitTS() types.TS {
 	return ctx.CommitTS
 }
 
+// Atomically returns the current txn state
+func (ctx *TxnCtx) getTxnState() txnif.TxnState {
+	ctx.RLock()
+	defer ctx.RUnlock()
+	return ctx.State
+}
+
+// Wait txn state to be rollbacked or committed
+func (ctx *TxnCtx) resolveTxnState() txnif.TxnState {
+	ctx.DoneCond.L.Lock()
+	defer ctx.DoneCond.L.Unlock()
+	state := ctx.State
+	if state != txnif.TxnStateCommitting {
+		return state
+	}
+	ctx.DoneCond.Wait()
+	return ctx.State
+}
+
+// False when atomically get the current txn state
+//
+// True when the txn state is committing, wait it to be committed or rollbacked. It
+// is used during snapshot reads. If TxnStateActive is currently returned, this value will
+// definitely not be used, because even if it becomes TxnStateCommitting later, the timestamp
+// would be larger than the current read timestamp.
+func (ctx *TxnCtx) GetTxnState(waitIfcommitting bool) (state txnif.TxnState) {
+	// Quick get the current txn state
+	// If waitIfcommitting is false, return the state
+	// If state is not txnif.TxnStateCommitting, return the state
+	if state = ctx.getTxnState(); !waitIfcommitting || state != txnif.TxnStateCommitting {
+		return state
+	}
+
+	// Wait the committing txn to be committed or rollbacked
+	state = ctx.resolveTxnState()
+	return
+}
+
 func (ctx *TxnCtx) IsVisible(o txnif.TxnReader) bool {
 	ostart := o.GetStartTS()
 	ctx.RLock()