diff --git a/pkg/container/ring/count/count.go b/pkg/container/ring/count/count.go index 06f257d6ab75568c1ff1731cb45f904490acd615..31e65c1c87ff6c1e5d49792ee432032ae99fc7a8 100644 --- a/pkg/container/ring/count/count.go +++ b/pkg/container/ring/count/count.go @@ -207,3 +207,255 @@ func (r *CountRing) Eval(_ []int64) *vector.Vector { Typ: types.Type{Oid: types.T_int64, Size: 8}, } } + +func NewDistinctCount(typ types.Type) *DistCountRing { + return &DistCountRing{Typ: typ} +} + +func (r *DistCountRing) String() string { + return fmt.Sprintf("%v-%v", r.Vs, r.Ns) +} + +func (r *DistCountRing) Free(m *mheap.Mheap) { + if r.Da != nil { + mheap.Free(m, r.Da) + r.Da = nil + r.Vs = nil + r.Ns = nil + } +} + +func (r *DistCountRing) Count() int { + return len(r.Vs) +} + +func (r *DistCountRing) Size() int { + return cap(r.Da) +} + +func (r *DistCountRing) Dup() ring.Ring { + return &DistCountRing{ + Typ: r.Typ, + } +} + +func (r *DistCountRing) Type() types.Type { + return r.Typ +} + +func (r *DistCountRing) SetLength(n int) { + r.Vs = r.Vs[:n] + r.Ns = r.Ns[:n] + r.Ms = r.Ms[:n] +} + +func (r *DistCountRing) Shrink(sels []int64) { + for i, sel := range sels { + r.Vs[i] = r.Vs[sel] + r.Ns[i] = r.Ns[sel] + r.Ms[i] = r.Ms[sel] + } + r.Vs = r.Vs[:len(sels)] + r.Ns = r.Ns[:len(sels)] + r.Ms = r.Ms[:len(sels)] +} + +func (r *DistCountRing) Shuffle(_ []int64, _ *mheap.Mheap) error { + return nil +} + +func (r *DistCountRing) Grow(m *mheap.Mheap) error { + n := len(r.Vs) + if n == 0 { + data, err := mheap.Alloc(m, 64) + if err != nil { + return err + } + r.Da = data + r.Ns = make([]int64, 0, 8) + r.Vs = encoding.DecodeInt64Slice(data) + } else if n+1 >= cap(r.Vs) { + r.Da = r.Da[:n*8] + data, err := mheap.Grow(m, r.Da, int64(n+1)*8) + if err != nil { + return err + } + mheap.Free(m, r.Da) + r.Da = data + r.Vs = encoding.DecodeInt64Slice(data) + } + r.Vs = r.Vs[:n+1] + r.Vs[n] = 0 + r.Ns = append(r.Ns, 0) + r.Ms = append(r.Ms, make(map[any]uint8)) + return nil +} + +func (r *DistCountRing) Grows(size int, m *mheap.Mheap) error { + n := len(r.Vs) + if n == 0 { + data, err := mheap.Alloc(m, int64(size*8)) + if err != nil { + return err + } + r.Da = data + r.Ns = make([]int64, 0, size) + r.Vs = encoding.DecodeInt64Slice(data) + } else if n+size >= cap(r.Vs) { + r.Da = r.Da[:n*8] + data, err := mheap.Grow(m, r.Da, int64(n+size)*8) + if err != nil { + return err + } + mheap.Free(m, r.Da) + r.Da = data + r.Vs = encoding.DecodeInt64Slice(data) + } + r.Vs = r.Vs[:n+size] + for i := 0; i < size; i++ { + r.Ns = append(r.Ns, 0) + r.Ms = append(r.Ms, make(map[any]uint8)) + } + return nil +} + +func (r *DistCountRing) Fill(i int64, sel, z int64, vec *vector.Vector) { + if nulls.Contains(vec.Nsp, uint64(sel)) { + r.Ns[i] += z + } else { + if insertIntoMap(r.Ms[i], getValue(vec, sel)) { + r.Vs[i] += z + } + } +} + +func (r *DistCountRing) BatchFill(start int64, os []uint8, vps []uint64, zs []int64, vec *vector.Vector) { + if nulls.Any(vec.Nsp) { + for i := range os { + if nulls.Contains(vec.Nsp, uint64(start)+uint64(i)) { + r.Ns[vps[i]-1] += zs[int64(i)+start] + } else { + if insertIntoMap(r.Ms[vps[i]-1], getValue(vec, int64(i)+start)) { + r.Vs[vps[i]-1] += zs[int64(i)+start] + } + } + } + } else { + for i := range os { + if insertIntoMap(r.Ms[vps[i]-1], getValue(vec, int64(i)+start)) { + r.Vs[vps[i]-1] += zs[int64(i)+start] + } + } + } +} + +func (r *DistCountRing) BulkFill(i int64, zs []int64, vec *vector.Vector) { + if nulls.Any(vec.Nsp) { + for j, z := range zs { + if nulls.Contains(vec.Nsp, uint64(j)) { + r.Ns[i] += z + } else { + if insertIntoMap(r.Ms[i], getValue(vec, int64(j))) { + r.Vs[i] += z + } + } + } + } else { + for j, z := range zs { + if insertIntoMap(r.Ms[i], getValue(vec, int64(j))) { + r.Vs[i] += z + } + } + } +} + +func (r *DistCountRing) Add(a interface{}, x, y int64) { + ar := a.(*DistCountRing) + if insertIntoMap(r.Ms[x], ar.Vs[y]) { + r.Vs[x] += ar.Vs[y] + } + r.Ns[x] += ar.Ns[y] +} + +func (r *DistCountRing) BatchAdd(a interface{}, start int64, os []uint8, vps []uint64) { + ar := a.(*DistCountRing) + for i := range os { + if insertIntoMap(r.Ms[vps[i]-1], ar.Vs[int64(i)+start]) { + r.Vs[vps[i]-1] += ar.Vs[int64(i)+start] + } + r.Ns[vps[i]-1] += ar.Ns[int64(i)+start] + } +} + +// r[x] += a[y] * z +func (r *DistCountRing) Mul(a interface{}, x, y, z int64) { + ar := a.(*DistCountRing) + r.Vs[x] += ar.Vs[y] * z + r.Ns[x] += ar.Ns[y] * z +} + +func (r *DistCountRing) Eval(_ []int64) *vector.Vector { + defer func() { + r.Da = nil + r.Vs = nil + r.Ns = nil + }() + nsp := new(nulls.Nulls) + return &vector.Vector{ + Nsp: nsp, + Data: r.Da, + Col: r.Vs, + Or: false, + Typ: types.Type{Oid: types.T_int64, Size: 8}, + } +} + +func insertIntoMap(mp map[any]uint8, v any) bool { + if _, ok := mp[v]; ok { + return false + } + mp[v] = 0 + return true +} + +func getValue(vec *vector.Vector, sel int64) any { + switch vec.Typ.Oid { + case types.T_bool: + return vec.Col.([]bool)[sel] + case types.T_int8: + return vec.Col.([]int8)[sel] + case types.T_int16: + return vec.Col.([]int16)[sel] + case types.T_int32: + return vec.Col.([]int32)[sel] + case types.T_int64: + return vec.Col.([]int64)[sel] + case types.T_uint8: + return vec.Col.([]uint8)[sel] + case types.T_uint16: + return vec.Col.([]uint16)[sel] + case types.T_uint32: + return vec.Col.([]uint32)[sel] + case types.T_uint64: + return vec.Col.([]uint64)[sel] + case types.T_float32: + return vec.Col.([]float32)[sel] + case types.T_float64: + return vec.Col.([]float64)[sel] + case types.T_date: + return vec.Col.([]types.Date)[sel] + case types.T_datetime: + return vec.Col.([]types.Datetime)[sel] + case types.T_timestamp: + return vec.Col.([]types.Timestamp)[sel] + case types.T_decimal64: + return vec.Col.([]types.Decimal64)[sel] + case types.T_decimal128: + return vec.Col.([]types.Decimal128)[sel] + case types.T_char, types.T_varchar, types.T_json: + vs := vec.Col.(*types.Bytes) + return string(vs.Get(sel)) + default: + panic(fmt.Sprintf("unexpect type %s for function vector.SetLength", vec.Typ)) + } +} diff --git a/pkg/container/ring/count/types.go b/pkg/container/ring/count/types.go index 5387f97e80fd1312c10db8a6a622fda216601974..f247be00bbcd8a384f424a982e40f901da43b780 100644 --- a/pkg/container/ring/count/types.go +++ b/pkg/container/ring/count/types.go @@ -22,3 +22,11 @@ type CountRing struct { Vs []int64 Typ types.Type } + +type DistCountRing struct { + Da []byte + Ns []int64 + Vs []int64 + Typ types.Type + Ms []map[any]uint8 +} diff --git a/pkg/sql/colexec2/aggregate/aggregate.go b/pkg/sql/colexec2/aggregate/aggregate.go index 56049525b3d8a52406d597bcdf5f2af29d718f09..93857fed5b1cccb08fd97b735107670538f5e0ff 100755 --- a/pkg/sql/colexec2/aggregate/aggregate.go +++ b/pkg/sql/colexec2/aggregate/aggregate.go @@ -75,7 +75,7 @@ func ReturnType(op int, typ types.T) types.T { return 0 } -func New(op int, typ types.Type) (ring.Ring, error) { +func New(op int, dist bool, typ types.Type) (ring.Ring, error) { switch op { case Sum: return NewSum(typ) @@ -89,6 +89,9 @@ func New(op int, typ types.Type) (ring.Ring, error) { case Min: return NewMin(typ) case Count: + if dist { + return count.NewDistinctCount(typ), nil + } return count.NewCount(typ), nil case StarCount: return starcount.NewCount(typ), nil diff --git a/pkg/sql/colexec2/aggregate/types.go b/pkg/sql/colexec2/aggregate/types.go index 4ec5ac8eb112f7b32933415d98524299a7709165..54c51006a687f7615b17bee7f50decc389a535a4 100755 --- a/pkg/sql/colexec2/aggregate/types.go +++ b/pkg/sql/colexec2/aggregate/types.go @@ -47,6 +47,7 @@ var Names = [...]string{ } type Aggregate struct { - Op int - E *plan.Expr + Op int + Dist bool + E *plan.Expr } diff --git a/pkg/sql/colexec2/group/group.go b/pkg/sql/colexec2/group/group.go index 9cd2936fbbf1d810411787026bed1d38147b9343..21781939880fecabc99d1c939173bb992b53dda7 100644 --- a/pkg/sql/colexec2/group/group.go +++ b/pkg/sql/colexec2/group/group.go @@ -114,7 +114,7 @@ func (ctr *Container) process(ap *Argument, proc *process.Process) (bool, error) ctr.bat.Zs = []int64{0} ctr.bat.Rs = make([]ring.Ring, len(ap.Aggs)) for i, agg := range ap.Aggs { - if ctr.bat.Rs[i], err = aggregate.New(agg.Op, ctr.aggVecs[i].vec.Typ); err != nil { + if ctr.bat.Rs[i], err = aggregate.New(agg.Op, agg.Dist, ctr.aggVecs[i].vec.Typ); err != nil { return false, err } } @@ -248,7 +248,7 @@ func (ctr *Container) processWithGroup(ap *Argument, proc *process.Process) (boo } ctr.bat.Rs = make([]ring.Ring, len(ap.Aggs)) for i, agg := range ap.Aggs { - if ctr.bat.Rs[i], err = aggregate.New(agg.Op, ctr.aggVecs[i].vec.Typ); err != nil { + if ctr.bat.Rs[i], err = aggregate.New(agg.Op, agg.Dist, ctr.aggVecs[i].vec.Typ); err != nil { return false, err } } diff --git a/pkg/sql/colexec2/mergegroup/group_test.go b/pkg/sql/colexec2/mergegroup/group_test.go index 5c833561915e3a80c53a8dbe884af5152009e505..463f2725dcab584230acdd2392417adfabd05a46 100644 --- a/pkg/sql/colexec2/mergegroup/group_test.go +++ b/pkg/sql/colexec2/mergegroup/group_test.go @@ -306,7 +306,7 @@ func newBatch(t *testing.T, flgs []bool, ts []types.Type, proc *process.Process, bat.Vecs[i] = vec } { - r, _ := aggregate.New(aggregate.Max, ts[0]) + r, _ := aggregate.New(aggregate.Max, false, ts[0]) r.Grows(int(rows), proc.Mp) for i := int64(0); i < rows; i++ { r.Fill(i, i, 1, bat.Vecs[0]) diff --git a/pkg/sql/compile2/operator.go b/pkg/sql/compile2/operator.go index 1c3ee85617d84532f738e4106698c6e75c6fafad..00ba99556d856e070a2d5b55ae09e76b46775918 100644 --- a/pkg/sql/compile2/operator.go +++ b/pkg/sql/compile2/operator.go @@ -16,6 +16,7 @@ package compile2 import ( "fmt" + "github.com/matrixorigin/matrixone/pkg/sql/colexec2/deletion" "github.com/matrixorigin/matrixone/pkg/vm/engine" @@ -313,14 +314,18 @@ func constructGroup(n *plan.Node) *group.Argument { aggs := make([]aggregate.Aggregate, len(n.AggList)) for i, expr := range n.AggList { if f, ok := expr.Expr.(*plan.Expr_F); ok { + dist := (uint64(f.F.Func.Obj) & function.Distinct) != 0 + f.F.Func.Obj = int64(uint64(f.F.Func.Obj) & function.DistinctMask) fun, err := function.GetFunctionByID(f.F.Func.GetObj()) if err != nil { panic(err) } aggs[i] = aggregate.Aggregate{ - Op: fun.AggregateInfo, - E: f.F.Args[0], + E: f.F.Args[0], + Dist: dist, + Op: fun.AggregateInfo, } + } } diff --git a/pkg/sql/plan2/build_expr.go b/pkg/sql/plan2/build_expr.go index c4d9d48b8d1390cf5297b8903f668190346f43fa..f842878d8fb68d6b862f3550caf67e0cf1c65b66 100644 --- a/pkg/sql/plan2/build_expr.go +++ b/pkg/sql/plan2/build_expr.go @@ -116,13 +116,13 @@ func buildExpr(stmt tree.Expr, ctx CompilerContext, query *Query, node *Node, bi case *tree.ParenExpr: resultExpr, isAgg, err = buildExpr(astExpr.Expr, ctx, query, node, binderCtx, needAgg) case *tree.OrExpr: - resultExpr, isAgg, err = getFunctionExprByNameAndAstExprs("or", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + resultExpr, isAgg, err = getFunctionExprByNameAndAstExprs("or", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case *tree.XorExpr: - resultExpr, isAgg, err = getFunctionExprByNameAndAstExprs("xor", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + resultExpr, isAgg, err = getFunctionExprByNameAndAstExprs("xor", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case *tree.NotExpr: - resultExpr, isAgg, err = getFunctionExprByNameAndAstExprs("not", []tree.Expr{astExpr.Expr}, ctx, query, node, binderCtx, needAgg) + resultExpr, isAgg, err = getFunctionExprByNameAndAstExprs("not", false, []tree.Expr{astExpr.Expr}, ctx, query, node, binderCtx, needAgg) case *tree.AndExpr: - resultExpr, isAgg, err = getFunctionExprByNameAndAstExprs("and", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + resultExpr, isAgg, err = getFunctionExprByNameAndAstExprs("and", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case *tree.UnaryExpr: resultExpr, isAgg, err = buildUnaryExpr(astExpr, ctx, query, node, binderCtx, needAgg) case *tree.BinaryExpr: @@ -138,13 +138,13 @@ func buildExpr(stmt tree.Expr, ctx CompilerContext, query *Query, node *Node, bi case *tree.CastExpr: resultExpr, isAgg, err = buildCastExpr(astExpr, ctx, query, node, binderCtx, needAgg) case *tree.IsNullExpr: - resultExpr, isAgg, err = getFunctionExprByNameAndAstExprs("ifnull", []tree.Expr{astExpr.Expr}, ctx, query, node, binderCtx, needAgg) + resultExpr, isAgg, err = getFunctionExprByNameAndAstExprs("ifnull", false, []tree.Expr{astExpr.Expr}, ctx, query, node, binderCtx, needAgg) case *tree.IsNotNullExpr: - resultExpr, isAgg, err = getFunctionExprByNameAndAstExprs("ifnull", []tree.Expr{astExpr.Expr}, ctx, query, node, binderCtx, needAgg) + resultExpr, isAgg, err = getFunctionExprByNameAndAstExprs("ifnull", false, []tree.Expr{astExpr.Expr}, ctx, query, node, binderCtx, needAgg) if err != nil { return } - resultExpr, _, err = getFunctionExprByNameAndPlanExprs("not", []*Expr{resultExpr}) + resultExpr, _, err = getFunctionExprByNameAndPlanExprs("not", false, []*Expr{resultExpr}) case *tree.Tuple: exprs := make([]*Expr, 0, len(astExpr.Exprs)) for _, ast := range astExpr.Exprs { @@ -247,7 +247,7 @@ func buildCaseExpr(astExpr *tree.CaseExpr, ctx CompilerContext, query *Query, no isAgg = isAgg && paramIsAgg if caseExpr != nil { // rewrite "case col when 1 then '1' else '2'" to "case when col=1 then '1' else '2'" - condExpr, _, err = getFunctionExprByNameAndPlanExprs("=", []*Expr{caseExpr, condExpr}) + condExpr, _, err = getFunctionExprByNameAndPlanExprs("=", false, []*Expr{caseExpr, condExpr}) if err != nil { return nil, false, err } @@ -281,7 +281,7 @@ func buildCaseExpr(astExpr *tree.CaseExpr, ctx CompilerContext, query *Query, no }, }) } - return getFunctionExprByNameAndPlanExprs("case", args) + return getFunctionExprByNameAndPlanExprs("case", false, args) } func buildColRefExpr(astExpr *tree.UnresolvedName, ctx CompilerContext, query *Query, node *Node, binderCtx *BinderContext, needAgg bool) (expr *Expr, err error) { @@ -322,30 +322,30 @@ func buildColRefExpr(astExpr *tree.UnresolvedName, ctx CompilerContext, query *Q func buildRangeCond(astExpr *tree.RangeCond, ctx CompilerContext, query *Query, node *Node, binderCtx *BinderContext, needAgg bool) (resultExpr *Expr, isAgg bool, err error) { if astExpr.Not { - left, paramIsAgg, err := getFunctionExprByNameAndAstExprs("<", []tree.Expr{astExpr.Left, astExpr.From}, ctx, query, node, binderCtx, needAgg) + left, paramIsAgg, err := getFunctionExprByNameAndAstExprs("<", false, []tree.Expr{astExpr.Left, astExpr.From}, ctx, query, node, binderCtx, needAgg) if err != nil { return nil, false, err } isAgg = paramIsAgg - right, paramIsAgg, err := getFunctionExprByNameAndAstExprs(">", []tree.Expr{astExpr.Left, astExpr.To}, ctx, query, node, binderCtx, needAgg) + right, paramIsAgg, err := getFunctionExprByNameAndAstExprs(">", false, []tree.Expr{astExpr.Left, astExpr.To}, ctx, query, node, binderCtx, needAgg) if err != nil { return nil, false, err } isAgg = isAgg && paramIsAgg - resultExpr, _, err = getFunctionExprByNameAndPlanExprs("or", []*Expr{left, right}) + resultExpr, _, err = getFunctionExprByNameAndPlanExprs("or", false, []*Expr{left, right}) return resultExpr, isAgg, err } else { - left, paramIsAgg, err := getFunctionExprByNameAndAstExprs(">=", []tree.Expr{astExpr.Left, astExpr.From}, ctx, query, node, binderCtx, needAgg) + left, paramIsAgg, err := getFunctionExprByNameAndAstExprs(">=", false, []tree.Expr{astExpr.Left, astExpr.From}, ctx, query, node, binderCtx, needAgg) if err != nil { return nil, false, err } isAgg = paramIsAgg - right, paramIsAgg, err := getFunctionExprByNameAndAstExprs("<=", []tree.Expr{astExpr.Left, astExpr.To}, ctx, query, node, binderCtx, needAgg) + right, paramIsAgg, err := getFunctionExprByNameAndAstExprs("<=", false, []tree.Expr{astExpr.Left, astExpr.To}, ctx, query, node, binderCtx, needAgg) if err != nil { return nil, false, err } isAgg = isAgg && paramIsAgg - resultExpr, _, err = getFunctionExprByNameAndPlanExprs("and", []*Expr{left, right}) + resultExpr, _, err = getFunctionExprByNameAndPlanExprs("and", false, []*Expr{left, right}) return resultExpr, isAgg, err } } @@ -356,31 +356,31 @@ func buildFunctionExpr(astExpr *tree.FuncExpr, ctx CompilerContext, query *Query return nil, false, errors.New(errno.SyntaxErrororAccessRuleViolation, fmt.Sprintf("function expr '%v' is not support now", astExpr)) } funcName := funcReference.Parts[0] - return getFunctionExprByNameAndAstExprs(funcName, astExpr.Exprs, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs(funcName, astExpr.Type == tree.FUNC_TYPE_DISTINCT, astExpr.Exprs, ctx, query, node, binderCtx, needAgg) } func buildComparisonExpr(astExpr *tree.ComparisonExpr, ctx CompilerContext, query *Query, node *Node, binderCtx *BinderContext, needAgg bool) (resultExpr *Expr, isAgg bool, err error) { switch astExpr.Op { case tree.EQUAL: - return getFunctionExprByNameAndAstExprs("=", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs("=", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case tree.LESS_THAN: - return getFunctionExprByNameAndAstExprs("<", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs("<", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case tree.LESS_THAN_EQUAL: - return getFunctionExprByNameAndAstExprs("<=", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs("<=", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case tree.GREAT_THAN: - return getFunctionExprByNameAndAstExprs(">", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs(">", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case tree.GREAT_THAN_EQUAL: - return getFunctionExprByNameAndAstExprs(">=", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs(">=", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case tree.NOT_EQUAL: - return getFunctionExprByNameAndAstExprs("<>", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs("<>", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case tree.LIKE: - return getFunctionExprByNameAndAstExprs("like", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs("like", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case tree.NOT_LIKE: - resultExpr, isAgg, err = getFunctionExprByNameAndAstExprs("like", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + resultExpr, isAgg, err = getFunctionExprByNameAndAstExprs("like", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) if err != nil { return } - resultExpr, _, err = getFunctionExprByNameAndPlanExprs("not", []*Expr{resultExpr}) + resultExpr, _, err = getFunctionExprByNameAndPlanExprs("not", false, []*Expr{resultExpr}) return case tree.IN: return buildInExpr(astExpr, ctx, query, node, binderCtx, needAgg) @@ -404,7 +404,7 @@ func buildInExpr(astExpr *tree.ComparisonExpr, ctx CompilerContext, query *Query } return buildExpr(new_expr, ctx, query, node, binderCtx, needAgg) default: - return getFunctionExprByNameAndAstExprs("in", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs("in", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) } } @@ -422,20 +422,20 @@ func buildNotInExpr(astExpr *tree.ComparisonExpr, ctx CompilerContext, query *Qu } return buildExpr(new_expr, ctx, query, node, binderCtx, needAgg) default: - resultExpr, _, err := getFunctionExprByNameAndAstExprs("in", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + resultExpr, _, err := getFunctionExprByNameAndAstExprs("in", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) if err != nil { return nil, false, err } - return getFunctionExprByNameAndPlanExprs("not", []*Expr{resultExpr}) + return getFunctionExprByNameAndPlanExprs("not", false, []*Expr{resultExpr}) } } func buildUnaryExpr(astExpr *tree.UnaryExpr, ctx CompilerContext, query *Query, node *Node, binderCtx *BinderContext, needAgg bool) (expr *Expr, isAgg bool, err error) { switch astExpr.Op { case tree.UNARY_MINUS: - return getFunctionExprByNameAndAstExprs("unary_minus", []tree.Expr{astExpr.Expr}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs("unary_minus", false, []tree.Expr{astExpr.Expr}, ctx, query, node, binderCtx, needAgg) case tree.UNARY_PLUS: - return getFunctionExprByNameAndAstExprs("unary_plus", []tree.Expr{astExpr.Expr}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs("unary_plus", false, []tree.Expr{astExpr.Expr}, ctx, query, node, binderCtx, needAgg) case tree.UNARY_TILDE: return nil, false, errors.New(errno.SyntaxErrororAccessRuleViolation, fmt.Sprintf("'%v' is not support now", astExpr)) case tree.UNARY_MARK: @@ -447,17 +447,17 @@ func buildUnaryExpr(astExpr *tree.UnaryExpr, ctx CompilerContext, query *Query, func buildBinaryExpr(astExpr *tree.BinaryExpr, ctx CompilerContext, query *Query, node *Node, binderCtx *BinderContext, needAgg bool) (expr *Expr, isAgg bool, err error) { switch astExpr.Op { case tree.PLUS: - return getFunctionExprByNameAndAstExprs("+", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs("+", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case tree.MINUS: - return getFunctionExprByNameAndAstExprs("-", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs("-", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case tree.MULTI: - return getFunctionExprByNameAndAstExprs("*", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs("*", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case tree.MOD: - return getFunctionExprByNameAndAstExprs("%", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs("%", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case tree.DIV: - return getFunctionExprByNameAndAstExprs("/", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs("/", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) case tree.INTEGER_DIV: - return getFunctionExprByNameAndAstExprs("div", []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) + return getFunctionExprByNameAndAstExprs("div", false, []tree.Expr{astExpr.Left, astExpr.Right}, ctx, query, node, binderCtx, needAgg) } return nil, false, errors.New(errno.SyntaxErrororAccessRuleViolation, fmt.Sprintf("'%v' is not support now", astExpr)) diff --git a/pkg/sql/plan2/build_from.go b/pkg/sql/plan2/build_from.go index 74ee25e2986f4ea5f8a2cc742212cee7b6b81b4a..03ec7c3551c2b83ccf8daae7938f126c66ca0835 100644 --- a/pkg/sql/plan2/build_from.go +++ b/pkg/sql/plan2/build_from.go @@ -220,7 +220,7 @@ func buildJoinTable(tbl *tree.JoinTableExpr, ctx CompilerContext, query *Query, // append equal function expr to onlist var equalFunctionExpr *Expr - equalFunctionExpr, _, err = getFunctionExprByNameAndPlanExprs("=", []*Expr{leftColExpr, rigthColExpr}) + equalFunctionExpr, _, err = getFunctionExprByNameAndPlanExprs("=", false, []*Expr{leftColExpr, rigthColExpr}) if err != nil { return } diff --git a/pkg/sql/plan2/build_function.go b/pkg/sql/plan2/build_function.go index c3beb7a19a2afe24e88de397e859b9ffbafb231f..d626f1228e99b6a5d3ca1ff77f4a2ba33a8d7a25 100644 --- a/pkg/sql/plan2/build_function.go +++ b/pkg/sql/plan2/build_function.go @@ -26,7 +26,7 @@ import ( "github.com/matrixorigin/matrixone/pkg/sql/plan2/function" ) -func getFunctionExprByNameAndPlanExprs(name string, exprs []*Expr) (resultExpr *Expr, isAgg bool, err error) { +func getFunctionExprByNameAndPlanExprs(name string, distinct bool, exprs []*Expr) (resultExpr *Expr, isAgg bool, err error) { // deal with special function switch name { case "+", "-": @@ -94,7 +94,10 @@ func getFunctionExprByNameAndPlanExprs(name string, exprs []*Expr) (resultExpr * Typ: returnType, } isAgg = funcDef.IsAggregate() - + if isAgg && distinct { + fe := resultExpr.Expr.(*plan.Expr_F) + fe.F.Func.Obj = int64(uint64(fe.F.Func.Obj) | function.Distinct) + } return } @@ -111,7 +114,7 @@ func rewriteStarToCol(query *Query, node *Node) (string, error) { return "", errors.New(errno.InvalidColumnReference, "can not find any column when rewrite count(*) to starcount(col)") } -func getFunctionExprByNameAndAstExprs(name string, astExprs []tree.Expr, ctx CompilerContext, query *Query, node *Node, binderCtx *BinderContext, needAgg bool) (resultExpr *Expr, isAgg bool, err error) { +func getFunctionExprByNameAndAstExprs(name string, distinct bool, astExprs []tree.Expr, ctx CompilerContext, query *Query, node *Node, binderCtx *BinderContext, needAgg bool) (resultExpr *Expr, isAgg bool, err error) { // name = strings.ToLower(name) args := make([]*Expr, len(astExprs)) // deal with special function [rewrite some ast function expr] @@ -185,7 +188,7 @@ func getFunctionExprByNameAndAstExprs(name string, astExprs []tree.Expr, ctx Com } } - resultExpr, paramIsAgg, err = getFunctionExprByNameAndPlanExprs(name, args) + resultExpr, paramIsAgg, err = getFunctionExprByNameAndPlanExprs(name, distinct, args) if paramIsAgg { node.AggList = append(node.AggList, resultExpr) resultExpr = &Expr{ @@ -305,7 +308,7 @@ func getIntervalFunction(name string, dateExpr *Expr, intervalExpr *Expr) (*Expr "+": "date_add", "-": "date_sub", } - resultExpr, _, err := getFunctionExprByNameAndPlanExprs(namesMap[name], exprs) + resultExpr, _, err := getFunctionExprByNameAndPlanExprs(namesMap[name], false, exprs) return resultExpr, err } diff --git a/pkg/sql/plan2/build_subquery.go b/pkg/sql/plan2/build_subquery.go index 8020c7e171b674499a6355713948bc9901f05fa3..f93b086c5269e921746784457e65f974f3a8df3d 100644 --- a/pkg/sql/plan2/build_subquery.go +++ b/pkg/sql/plan2/build_subquery.go @@ -61,7 +61,7 @@ func buildSubQuery(subquery *tree.Subquery, ctx CompilerContext, query *Query, n }, } if subquery.Exists { - returnExpr, _, err = getFunctionExprByNameAndPlanExprs("exists", []*Expr{returnExpr}) + returnExpr, _, err = getFunctionExprByNameAndPlanExprs("exists", false, []*Expr{returnExpr}) if err != nil { return nil, err } diff --git a/pkg/sql/plan2/build_util.go b/pkg/sql/plan2/build_util.go index 950f3efb72f9f3ee4191cf0c8e996133cb98516b..e1c6d14b451ab844b8ab89306397eb348e8ca9c7 100644 --- a/pkg/sql/plan2/build_util.go +++ b/pkg/sql/plan2/build_util.go @@ -88,7 +88,7 @@ func getColumnsWithSameName(leftProjList []*Expr, rightProjList []*Expr) ([]*Exp Typ: leftProjList[leftIdx].Typ, } - equalFunctionExpr, _, err := getFunctionExprByNameAndPlanExprs("=", []*Expr{leftColExpr, rightColExpr}) + equalFunctionExpr, _, err := getFunctionExprByNameAndPlanExprs("=", false, []*Expr{leftColExpr, rightColExpr}) if err != nil { return nil, nil, err } diff --git a/pkg/sql/plan2/explain/explain_expr.go b/pkg/sql/plan2/explain/explain_expr.go index e2cfe548662f100da558dc85cff515bc101e2a5f..0f4af022fb823b07b510c44d17c8a0d0a6b2f817 100644 --- a/pkg/sql/plan2/explain/explain_expr.go +++ b/pkg/sql/plan2/explain/explain_expr.go @@ -85,7 +85,7 @@ func funcExprExplain(funcExpr *plan.Expr_F, Typ *plan.Type, options *ExplainOpti funcName := funcExpr.F.GetFunc().GetObjName() funcDef := funcExpr.F.GetFunc() - funcProtoType, err := function.GetFunctionByID(funcDef.Obj) + funcProtoType, err := function.GetFunctionByID(funcDef.Obj & function.DistinctMask) if err != nil { return result, errors.New(errno.InvalidName, "invalid function or opreator name '"+funcName+"'") } diff --git a/pkg/sql/plan2/function/function_id.go b/pkg/sql/plan2/function/function_id.go index 6b7499c2b0dfce4131b7419d0344664ed6ba3502..d563d7a4d703ae1311177e406fd0126c7a091d92 100644 --- a/pkg/sql/plan2/function/function_id.go +++ b/pkg/sql/plan2/function/function_id.go @@ -14,6 +14,11 @@ package function +const ( + Distinct = 0x8000000000000000 + DistinctMask = 0x7FFFFFFFFFFFFFFF +) + // All function IDs const ( EQUAL = iota // =