From 9fcf0d95a7b9d5fae84ecf883b124f139be76d93 Mon Sep 17 00:00:00 2001 From: dongdongyang <47596332+dongdongyang33@users.noreply.github.com> Date: Tue, 16 Aug 2022 10:48:28 +0800 Subject: [PATCH] implement intersect (#4485) add intersect Approved by: @nnsgmsone, @aressu1985 --- pkg/sql/colexec/intersect/intersect.go | 213 ++++++++++++++++++++ pkg/sql/colexec/intersect/intersect_test.go | 129 ++++++++++++ pkg/sql/colexec/intersect/types.go | 51 +++++ pkg/sql/compile/compile.go | 10 +- pkg/sql/compile/operator.go | 18 ++ pkg/vm/overload.go | 17 +- pkg/vm/types.go | 1 + test/cases/set/set_operator.test | 25 +++ test/result/set/set_operator.result | 29 +++ 9 files changed, 486 insertions(+), 7 deletions(-) create mode 100644 pkg/sql/colexec/intersect/intersect.go create mode 100644 pkg/sql/colexec/intersect/intersect_test.go create mode 100644 pkg/sql/colexec/intersect/types.go create mode 100644 test/cases/set/set_operator.test create mode 100644 test/result/set/set_operator.result diff --git a/pkg/sql/colexec/intersect/intersect.go b/pkg/sql/colexec/intersect/intersect.go new file mode 100644 index 000000000..bd8727298 --- /dev/null +++ b/pkg/sql/colexec/intersect/intersect.go @@ -0,0 +1,213 @@ +// Copyright 2022 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package intersect + +import ( + "bytes" + + "github.com/matrixorigin/matrixone/pkg/common/hashmap" + "github.com/matrixorigin/matrixone/pkg/container/batch" + "github.com/matrixorigin/matrixone/pkg/container/vector" + "github.com/matrixorigin/matrixone/pkg/vm/process" +) + +func String(_ any, buf *bytes.Buffer) { + buf.WriteString(" intersect ") +} + +func Prepare(proc *process.Process, argument any) error { + var err error + arg := argument.(*Argument) + arg.ctr.btc = nil + arg.ctr.hashTable, err = hashmap.NewStrMap(true, arg.IBucket, arg.NBucket, proc.Mp) + if err != nil { + return err + } + arg.ctr.inBuckets = make([]uint8, hashmap.UnitLimit) + return nil +} + +func Call(idx int, proc *process.Process, argument any) (bool, error) { + arg := argument.(*Argument) + + analyze := proc.GetAnalyze(idx) + analyze.Start() + defer analyze.Stop() + + for { + switch arg.ctr.state { + case build: + if err := arg.ctr.buildHashTable(proc, analyze, 1); err != nil { + arg.ctr.state = end + return true, err + } + arg.ctr.state = probe + + case probe: + var err error + isLast := false + if isLast, err = arg.ctr.probeHashTable(proc, analyze, 0); err != nil { + arg.ctr.state = end + return true, err + } + if isLast { + arg.ctr.state = end + continue + } + + return false, nil + + case end: + arg.ctr.hashTable.Free() + arg.ctr.freeSels(proc) + proc.SetInputBatch(nil) + return true, nil + } + } +} + +// buildHashTable drains MergeReceivers[idx] until a nil batch arrives, inserting every batch into c.hashTable and appending a one-element counter (value 1) to c.cnts for each group id above the previous group count; every received batch is cleaned before returning. +func (c *container) buildHashTable(proc *process.Process, analyse process.Analyze, idx int) error { + for { + btc := <-proc.Reg.MergeReceivers[idx].Ch + + // a nil batch marks the end of this input: stop building + if btc == nil { + break + } + + // empty batch: nothing to insert, wait for the next one + if btc.Length() == 0 { + continue + } + + analyse.Input(btc) + + cnt := btc.Length() + itr := c.hashTable.NewIterator() + for i := 0; i < cnt; i += hashmap.UnitLimit { + rowcnt := c.hashTable.GroupCount() + + n := cnt - i + if n > hashmap.UnitLimit { + n = hashmap.UnitLimit + } + + vs, zs, err := itr.Insert(i, n, btc.Vecs) + if err != nil { 
+ btc.Clean(proc.Mp) + return err + } + + for j, v := range vs { + if zs[j] == 0 { + continue + } + + if v > rowcnt { + c.cnts = append(c.cnts, proc.GetMheap().GetSels()) + c.cnts[v-1] = append(c.cnts[v-1], 1) + rowcnt++ + } + } + } + btc.Clean(proc.Mp) + } + return nil +} + +func (c *container) probeHashTable(proc *process.Process, analyze process.Analyze, idx int) (bool, error) { + btc := <-proc.Reg.MergeReceivers[idx].Ch + + // a nil batch marks the end of input: report isLast = true so Call moves to the end state + if btc == nil { + return true, nil + } + + // empty batch: nothing to probe. NOTE(review): proc.SetInputBatch is not called on this path, so the process input batch is left as it was; confirm downstream operators cannot re-consume a stale batch + if btc.Length() == 0 { + return false, nil + } + + analyze.Input(btc) + defer btc.Clean(proc.Mp) + + c.btc = batch.NewWithSize(len(btc.Vecs)) + for i := range btc.Vecs { + c.btc.Vecs[i] = vector.New(btc.Vecs[i].Typ) + } + needInsert := make([]uint8, hashmap.UnitLimit) + resetsNeedInsert := make([]uint8, hashmap.UnitLimit) + cnt := btc.Length() + itr := c.hashTable.NewIterator() + for i := 0; i < cnt; i += hashmap.UnitLimit { + n := cnt - i + if n > hashmap.UnitLimit { + n = hashmap.UnitLimit + } + + copy(c.inBuckets, hashmap.OneUInt8s) + copy(needInsert, resetsNeedInsert) + insertcnt := 0 + + vs, zs := itr.Find(i, n, btc.Vecs, c.inBuckets) + + for j, v := range vs { + // not in the processed bucket + if c.inBuckets[j] == 0 { + continue + } + + // null value + if zs[j] == 0 { + continue + } + + // not found in the hash table filled by buildHashTable + if v == 0 { + continue + } + + // this group was already emitted (its counter is cleared below), so skip the duplicate row + if c.cnts[v-1][0] == 0 { + continue + } + + needInsert[j] = 1 + c.cnts[v-1][0] = 0 + c.btc.Zs = append(c.btc.Zs, 1) + insertcnt++ + } + + if insertcnt > 0 { + for pos := range btc.Vecs { + if err := vector.UnionBatch(c.btc.Vecs[pos], btc.Vecs[pos], int64(i), insertcnt, needInsert, proc.Mp); err != nil { + return false, err + } + } + } + } + + analyze.Output(c.btc) + proc.SetInputBatch(c.btc) + return false, nil +} + +func (c *container) freeSels(proc *process.Process) { + for i := range c.cnts { + proc.GetMheap().PutSels(c.cnts[i]) + } + c.cnts = nil +} diff --git 
a/pkg/sql/colexec/intersect/intersect_test.go b/pkg/sql/colexec/intersect/intersect_test.go new file mode 100644 index 000000000..2850c86f2 --- /dev/null +++ b/pkg/sql/colexec/intersect/intersect_test.go @@ -0,0 +1,129 @@ +// Copyright 2022 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package intersect + +import ( + "context" + "testing" + + "github.com/matrixorigin/matrixone/pkg/container/batch" + "github.com/matrixorigin/matrixone/pkg/container/types" + "github.com/matrixorigin/matrixone/pkg/container/vector" + "github.com/matrixorigin/matrixone/pkg/testutil" + "github.com/matrixorigin/matrixone/pkg/vm/mheap" + "github.com/matrixorigin/matrixone/pkg/vm/process" + "github.com/stretchr/testify/require" +) + +type intersectTestCase struct { + proc *process.Process + arg *Argument + cancel context.CancelFunc +} + +func TestIntersect(t *testing.T) { + proc := testutil.NewProcess() + // [2 rows + 2 row, 3 columns] intersect [1 row + 1 rows, 3 columns] + /* + {1, 2, 3} {1, 2, 3} + {1, 2, 3} intersect {4, 5, 6} ==> {1, 2, 3} + {3, 4, 5} + {3, 4, 5} + */ + c := newIntersectTestCase( + proc, + []*batch.Batch{ + testutil.NewBatchWithVectors( + []*vector.Vector{ + testutil.NewVector(2, types.T_int64.ToType(), proc.Mp, false, []int64{1, 1}), + testutil.NewVector(2, types.T_int64.ToType(), proc.Mp, false, []int64{2, 2}), + testutil.NewVector(2, types.T_int64.ToType(), proc.Mp, false, []int64{3, 3}), + }, nil), + testutil.NewBatchWithVectors( + 
[]*vector.Vector{ + testutil.NewVector(2, types.T_int64.ToType(), proc.Mp, false, []int64{3, 3}), + testutil.NewVector(2, types.T_int64.ToType(), proc.Mp, false, []int64{4, 4}), + testutil.NewVector(2, types.T_int64.ToType(), proc.Mp, false, []int64{5, 5}), + }, nil), + }, + []*batch.Batch{ + testutil.NewBatchWithVectors( + []*vector.Vector{ + testutil.NewVector(1, types.T_int64.ToType(), proc.Mp, false, []int64{1}), + testutil.NewVector(1, types.T_int64.ToType(), proc.Mp, false, []int64{2}), + testutil.NewVector(1, types.T_int64.ToType(), proc.Mp, false, []int64{3}), + }, nil), + testutil.NewBatchWithVectors( + []*vector.Vector{ + testutil.NewVector(1, types.T_int64.ToType(), proc.Mp, false, []int64{4}), + testutil.NewVector(1, types.T_int64.ToType(), proc.Mp, false, []int64{5}), + testutil.NewVector(1, types.T_int64.ToType(), proc.Mp, false, []int64{6}), + }, nil), + }, + ) + + err := Prepare(c.proc, c.arg) + require.NoError(t, err) + cnt := 0 + end := false + for { + end, err = Call(0, c.proc, c.arg) + if end { + break + } + require.NoError(t, err) + result := c.proc.InputBatch() + if result != nil && len(result.Zs) != 0 { + cnt += result.Length() + require.Equal(t, 3, len(result.Vecs)) // 3 column + c.proc.InputBatch().Clean(c.proc.Mp) + } + } + require.Equal(t, 1, cnt) // 1 row + require.Equal(t, int64(0), mheap.Size(c.proc.Mp)) +} + +func newIntersectTestCase(proc *process.Process, leftBatches, rightBatches []*batch.Batch) intersectTestCase { + ctx, cancel := context.WithCancel(context.Background()) + proc.Reg.MergeReceivers = make([]*process.WaitRegister, 2) + { + c := make(chan *batch.Batch, len(leftBatches)+10) + for i := range leftBatches { + c <- leftBatches[i] + } + c <- nil + proc.Reg.MergeReceivers[0] = &process.WaitRegister{ + Ctx: ctx, + Ch: c, + } + } + { + c := make(chan *batch.Batch, len(rightBatches)+10) + for i := range rightBatches { + c <- rightBatches[i] + } + c <- nil + proc.Reg.MergeReceivers[1] = &process.WaitRegister{ + Ctx: ctx, + Ch: 
c, + } + } + arg := new(Argument) + return intersectTestCase{ + proc: proc, + arg: arg, + cancel: cancel, + } +} diff --git a/pkg/sql/colexec/intersect/types.go b/pkg/sql/colexec/intersect/types.go new file mode 100644 index 000000000..c24c33541 --- /dev/null +++ b/pkg/sql/colexec/intersect/types.go @@ -0,0 +1,51 @@ +// Copyright 2022 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package intersect + +import ( + "github.com/matrixorigin/matrixone/pkg/common/hashmap" + "github.com/matrixorigin/matrixone/pkg/container/batch" +) + +const ( + build = iota + probe + end +) + +type Argument struct { + ctr container + + // hash table bucket related information. 
+ IBucket uint64 + NBucket uint64 +} + +type container struct { + // operator state: one of build, probe or end + state int + + // one counter slice per hash group, indexed by group id minus one: element 0 is 1 until the group has been written to an output batch, then 0 + cnts [][]int64 + + // Hash table for checking duplicate data + hashTable *hashmap.StrHashMap + + // Result batch of the intersect operator + btc *batch.Batch + + // process bucket mark + inBuckets []uint8 +} diff --git a/pkg/sql/compile/compile.go b/pkg/sql/compile/compile.go index c0019fab0..26b02d771 100644 --- a/pkg/sql/compile/compile.go +++ b/pkg/sql/compile/compile.go @@ -420,7 +420,7 @@ func (c *Compile) compilePlanScope(n *plan.Node, ns []*plan.Node) ([]*Scope, err } c.anal.curr = curr return c.compileSort(n, c.compileUnion(n, ss, children, ns)), nil - case plan.Node_MINUS: + case plan.Node_MINUS, plan.Node_INTERSECT: curr := c.anal.curr c.anal.curr = int(n.Children[0]) ss, err := c.compilePlanScope(ns[n.Children[0]], ns) @@ -564,6 +564,14 @@ func (c *Compile) compileMinusAndIntersect(n *plan.Node, ss []*Scope, children [ Arg: constructMinus(n, c.proc, i, len(rs)), } } + case plan.Node_INTERSECT: + for i := range rs { + rs[i].Instructions[0] = vm.Instruction{ + Op: vm.Intersect, + Idx: c.anal.curr, + Arg: constructIntersect(n, c.proc, i, len(rs)), + } + } } return []*Scope{c.newMergeScope(append(append(rs, left), right))} } diff --git a/pkg/sql/compile/operator.go b/pkg/sql/compile/operator.go index a39b9ec17..25a382408 100644 --- a/pkg/sql/compile/operator.go +++ b/pkg/sql/compile/operator.go @@ -19,6 +19,7 @@ import ( "fmt" "github.com/matrixorigin/matrixone/pkg/sql/colexec/anti" + "github.com/matrixorigin/matrixone/pkg/sql/colexec/intersect" "github.com/matrixorigin/matrixone/pkg/sql/colexec/loopanti" "github.com/matrixorigin/matrixone/pkg/sql/colexec/minus" @@ -174,6 +175,16 @@ func dupInstruction(in vm.Instruction) vm.Instruction { } case *dispatch.Argument: case *connector.Argument: + case *minus.Argument: + rin.Arg = &minus.Argument{ + IBucket: arg.IBucket, + NBucket: arg.NBucket, + } + case *intersect.Argument: 
+ rin.Arg = &intersect.Argument{ + IBucket: arg.IBucket, + NBucket: arg.NBucket, + } default: panic(errors.New(errno.SyntaxErrororAccessRuleViolation, fmt.Sprintf("Unsupport instruction %T\n", in.Arg))) } @@ -439,6 +450,13 @@ func constructMinus(n *plan.Node, proc *process.Process, ibucket, nbucket int) * } } +func constructIntersect(n *plan.Node, proc *process.Process, ibucket, nbucket int) *intersect.Argument { + return &intersect.Argument{ + IBucket: uint64(ibucket), + NBucket: uint64(nbucket), + } +} + func constructDispatch(all bool, regs []*process.WaitRegister) *dispatch.Argument { arg := new(dispatch.Argument) arg.All = all diff --git a/pkg/vm/overload.go b/pkg/vm/overload.go index 264488564..8d63dcb4d 100644 --- a/pkg/vm/overload.go +++ b/pkg/vm/overload.go @@ -16,7 +16,9 @@ package vm import ( "bytes" + "github.com/matrixorigin/matrixone/pkg/sql/colexec/anti" + "github.com/matrixorigin/matrixone/pkg/sql/colexec/intersect" "github.com/matrixorigin/matrixone/pkg/sql/colexec/loopanti" "github.com/matrixorigin/matrixone/pkg/sql/colexec/minus" @@ -91,8 +93,9 @@ var stringFunc = [...]func(any, *bytes.Buffer){ Insert: insert.String, Update: update.String, - Union: union.String, - Minus: minus.String, + Union: union.String, + Minus: minus.String, + Intersect: intersect.String, } var prepareFunc = [...]func(*process.Process, any) error{ @@ -130,8 +133,9 @@ var prepareFunc = [...]func(*process.Process, any) error{ Insert: insert.Prepare, Update: update.Prepare, - Union: union.Prepare, - Minus: minus.Prepare, + Union: union.Prepare, + Minus: minus.Prepare, + Intersect: intersect.Prepare, } var execFunc = [...]func(int, *process.Process, any) (bool, error){ @@ -169,6 +173,7 @@ var execFunc = [...]func(int, *process.Process, any) (bool, error){ Insert: insert.Call, Update: update.Call, - Union: union.Call, - Minus: minus.Call, + Union: union.Call, + Minus: minus.Call, + Intersect: intersect.Call, } diff --git a/pkg/vm/types.go b/pkg/vm/types.go index 
324e51a0c..0c240b293 100644 --- a/pkg/vm/types.go +++ b/pkg/vm/types.go @@ -51,6 +51,7 @@ const ( Union Minus + Intersect ) // Instruction contains relational algebra diff --git a/test/cases/set/set_operator.test b/test/cases/set/set_operator.test new file mode 100644 index 000000000..51f724bf7 --- /dev/null +++ b/test/cases/set/set_operator.test @@ -0,0 +1,25 @@ +-- @suite +-- @setup + +drop table if exists t1; +drop table if exists t2; +create table t1 (a smallint, b bigint, c int); +insert into t1 values (1,2,3); +insert into t1 values (1,2,3); +insert into t1 values (3,4,5); +insert into t1 values (4,5,6); +insert into t1 values (4,5,6); +insert into t1 values (1,1,2); +create table t2 (a smallint, b bigint, c int); +insert into t2 values (1,2,3); +insert into t2 values (3,4,5); +insert into t2 values (1,2,1); + +-- @case +-- @desc:test for set operators +-- @label:bvt + +select * from t1 minus select * from t2; +select * from t1 intersect select * from t2; +select a, b from t1 minus select b, c from t2; +select a, b from t1 intersect select b, c from t2; diff --git a/test/result/set/set_operator.result b/test/result/set/set_operator.result new file mode 100644 index 000000000..0de44dd9d --- /dev/null +++ b/test/result/set/set_operator.result @@ -0,0 +1,29 @@ +drop table if exists t1; +drop table if exists t2; +create table t1 (a smallint, b bigint, c int); +insert into t1 values (1,2,3); +insert into t1 values (1,2,3); +insert into t1 values (3,4,5); +insert into t1 values (4,5,6); +insert into t1 values (4,5,6); +insert into t1 values (1,1,2); +create table t2 (a smallint, b bigint, c int); +insert into t2 values (1,2,3); +insert into t2 values (3,4,5); +insert into t2 values (1,2,1); +select * from t1 minus select * from t2; +a b c +1 1 2 +4 5 6 +select * from t1 intersect select * from t2; +a b c +1 2 3 +3 4 5 +select a, b from t1 minus select b, c from t2; +a b +3 4 +1 1 +1 2 +select a, b from t1 intersect select b, c from t2; +a b +4 5 -- GitLab