diff --git a/pkg/compare/compare.go b/pkg/compare/compare.go index 1dfee3fac270eef8b68de65f9f6c174a1b6211fd..67894888d7eade25ce6adbe0ac131cb4a452b34c 100644 --- a/pkg/compare/compare.go +++ b/pkg/compare/compare.go @@ -1,39 +1,9 @@ package compare import ( - abools "matrixbase/pkg/compare/asc/bools" - abytes "matrixbase/pkg/compare/asc/bytes" - afloats "matrixbase/pkg/compare/asc/floats" - aints "matrixbase/pkg/compare/asc/ints" - dbools "matrixbase/pkg/compare/desc/bools" - dbytes "matrixbase/pkg/compare/desc/bytes" - dfloats "matrixbase/pkg/compare/desc/floats" - dints "matrixbase/pkg/compare/desc/ints" "matrixbase/pkg/container/types" ) func New(typ types.T, desc bool) Compare { - switch typ { - case types.T_int: - if desc { - return dints.New() - } - return aints.New() - case types.T_bool: - if desc { - return dbools.New() - } - return abools.New() - case types.T_float: - if desc { - return dfloats.New() - } - return afloats.New() - case types.T_bytes, types.T_json: - if desc { - return dbytes.New() - } - return abytes.New() - } return nil } diff --git a/pkg/sql/colexec/limit/limit.go b/pkg/sql/colexec/limit/limit.go index 904f2cb8eac06b02cae4678110773614efdc1384..411c8d2f838b875a2e8b165ddf996c83a00e91f1 100644 --- a/pkg/sql/colexec/limit/limit.go +++ b/pkg/sql/colexec/limit/limit.go @@ -4,6 +4,7 @@ import ( "matrixbase/pkg/container/batch" "matrixbase/pkg/encoding" "matrixbase/pkg/vm/process" + "matrixbase/pkg/vm/register" ) func Prepare(_ *process.Process, _ interface{}) error { @@ -22,6 +23,7 @@ func Call(proc *process.Process, arg interface{}) (bool, error) { } n.Seen = newSeen proc.Reg.Ax = bat + register.FreeRegisters(proc) return false, nil } length, err := bat.Length(proc) @@ -41,6 +43,7 @@ func Call(proc *process.Process, arg interface{}) (bool, error) { } n.Seen = newSeen proc.Reg.Ax = bat + register.FreeRegisters(proc) return false, nil } diff --git a/pkg/sql/colexec/top/top.go b/pkg/sql/colexec/top/top.go new file mode 100644 index 0000000000000000000000000000000000000000..0f26ed41eb7ef81c81f8705b988cb7a7145ec86c --- /dev/null +++ b/pkg/sql/colexec/top/top.go @@ -0,0 +1,106 @@ +package top + +import ( + "container/heap" + "matrixbase/pkg/compare" + "matrixbase/pkg/container/batch" + "matrixbase/pkg/container/vector" + "matrixbase/pkg/encoding" + "matrixbase/pkg/vm/process" + "matrixbase/pkg/vm/register" +) + +func Prepare(proc *process.Process, arg interface{}) error { + n := arg.(Argument) + data, err := proc.Alloc(n.Limit * 8) + if err != nil { + return err + } + sels := encoding.DecodeInt64Slice(data) + for i := int64(0); i < n.Limit; i++ { + sels[i] = i + } + n.Ctr.n = len(n.Fs) + n.Ctr.sels = sels + n.Ctr.selsData = data + n.Ctr.vecs = make([]*vector.Vector, len(n.Fs)) + n.Ctr.cmps = make([]compare.Compare, len(n.Fs)) + for i, f := range n.Fs { + n.Ctr.cmps[i] = compare.New(f.Oid, f.Type == Descending) + } + return nil +} + +func Call(proc *process.Process, arg interface{}) (bool, error) { + var err error + + n := arg.(Argument) + bat := proc.Reg.Ax.(*batch.Batch) + for i, f := range n.Fs { + n.Ctr.vecs[i], err = bat.GetVector(f.Attr, proc) + if err != nil { + for j := 0; j < i; j++ { + n.Ctr.vecs[i].Free(proc) + } + return false, err + } + } + processBatch(bat, n) + data, err := proc.Alloc(int64(len(n.Ctr.sels)) * 8) + if err != nil { + for _, vec := range n.Ctr.vecs { + vec.Free(proc) + } + proc.Free(n.Ctr.selsData) + return false, err + } + sels := encoding.DecodeInt64Slice(data) + for i, j := 0, len(n.Ctr.sels); i < j; i++ { + sels[len(sels)-1-i] = heap.Pop(&n.Ctr).(int64) + } + if len(bat.Sels) > 0 { + proc.Free(bat.SelsData) + } + bat.Sels = sels + bat.SelsData = data + proc.Reg.Ax = bat + register.FreeRegisters(proc) + return false, nil +} + +func processBatch(bat *batch.Batch, n Argument) { + if length := int64(len(bat.Sels)); length > 0 { + if length < n.Limit { + for i := int64(0); i < length; i++ { + n.Ctr.sels[i] = bat.Sels[i] + } + n.Ctr.sels = n.Ctr.sels[:length] + heap.Init(&n.Ctr) + return + } + for i := int64(0); i < n.Limit; i++ { + n.Ctr.sels[i] = bat.Sels[i] + } + heap.Init(&n.Ctr) + for i, j := n.Limit, length; i < j; i++ { + if n.Ctr.compare(bat.Sels[i], n.Ctr.sels[0]) < 0 { + n.Ctr.sels[0] = bat.Sels[i] + } + heap.Fix(&n.Ctr, 0) + } + return + } + length := int64(n.Ctr.vecs[0].Length()) + if length < n.Limit { + n.Ctr.sels = n.Ctr.sels[:length] + heap.Init(&n.Ctr) + return + } + heap.Init(&n.Ctr) + for i, j := n.Limit, length; i < j; i++ { + if n.Ctr.compare(i, n.Ctr.sels[0]) < 0 { + n.Ctr.sels[0] = i + } + heap.Fix(&n.Ctr, 0) + } +} diff --git a/pkg/sql/colexec/top/types.go b/pkg/sql/colexec/top/types.go new file mode 100644 index 0000000000000000000000000000000000000000..c0d37136798419653b18ba3b8e4466c237dd3015 --- /dev/null +++ b/pkg/sql/colexec/top/types.go @@ -0,0 +1,70 @@ +package top + +import ( + "matrixbase/pkg/compare" + "matrixbase/pkg/container/types" + "matrixbase/pkg/container/vector" +) + +// Direction for ordering results. +type Direction int8 + +// Direction values. +const ( + DefaultDirection Direction = iota + Ascending + Descending +) + +type Container struct { + n int // number of attributes involved in sorting + sels []int64 + selsData []byte + vecs []*vector.Vector + cmps []compare.Compare +} + +type Field struct { + Oid types.T + Attr string + Type Direction +} + +type Argument struct { + Limit int64 + Fs []Field + Ctr Container +} + +func (ctr *Container) compare(i, j int64) int { + for k := 0; k < ctr.n; k++ { + if r := ctr.cmps[k].Compare(0, 0, i, j); r != 0 { + return r + } + } + return 0 +} + +// maximum heap +func (ctr *Container) Len() int { + return len(ctr.sels) +} + +func (ctr *Container) Less(i, j int) bool { + return ctr.compare(ctr.sels[i], ctr.sels[j]) > 0 +} + +func (ctr *Container) Swap(i, j int) { + ctr.sels[i], ctr.sels[j] = ctr.sels[j], ctr.sels[i] +} + +func (ctr *Container) Push(x interface{}) { + ctr.sels = append(ctr.sels, x.(int64)) +} + +func (ctr *Container) Pop() interface{} { + n := len(ctr.sels) - 1 + x := ctr.sels[n] + ctr.sels = ctr.sels[:n] + return x +}