From 9fcf0d95a7b9d5fae84ecf883b124f139be76d93 Mon Sep 17 00:00:00 2001
From: dongdongyang <47596332+dongdongyang33@users.noreply.github.com>
Date: Tue, 16 Aug 2022 10:48:28 +0800
Subject: [PATCH] implement intersect (#4485)

add intersect

Approved by: @nnsgmsone, @aressu1985
---
 pkg/sql/colexec/intersect/intersect.go      | 213 ++++++++++++++++++++
 pkg/sql/colexec/intersect/intersect_test.go | 129 ++++++++++++
 pkg/sql/colexec/intersect/types.go          |  51 +++++
 pkg/sql/compile/compile.go                  |  10 +-
 pkg/sql/compile/operator.go                 |  18 ++
 pkg/vm/overload.go                          |  17 +-
 pkg/vm/types.go                             |   1 +
 test/cases/set/set_operator.test            |  25 +++
 test/result/set/set_operator.result         |  29 +++
 9 files changed, 486 insertions(+), 7 deletions(-)
 create mode 100644 pkg/sql/colexec/intersect/intersect.go
 create mode 100644 pkg/sql/colexec/intersect/intersect_test.go
 create mode 100644 pkg/sql/colexec/intersect/types.go
 create mode 100644 test/cases/set/set_operator.test
 create mode 100644 test/result/set/set_operator.result

diff --git a/pkg/sql/colexec/intersect/intersect.go b/pkg/sql/colexec/intersect/intersect.go
new file mode 100644
index 000000000..bd8727298
--- /dev/null
+++ b/pkg/sql/colexec/intersect/intersect.go
@@ -0,0 +1,213 @@
+// Copyright 2022 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package intersect
+
+import (
+	"bytes"
+
+	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
+	"github.com/matrixorigin/matrixone/pkg/container/batch"
+	"github.com/matrixorigin/matrixone/pkg/container/vector"
+	"github.com/matrixorigin/matrixone/pkg/vm/process"
+)
+
+func String(_ any, buf *bytes.Buffer) {
+	buf.WriteString(" intersect ") // fixed operator label; the argument is not consulted
+}
+
+func Prepare(proc *process.Process, argument any) error {
+	arg := argument.(*Argument)
+	ctr := &arg.ctr
+	ctr.btc = nil // no result batch exists before the first probe
+	var err error
+	if ctr.hashTable, err = hashmap.NewStrMap(true, arg.IBucket, arg.NBucket, proc.Mp); err != nil {
+		return err
+	}
+	ctr.inBuckets = make([]uint8, hashmap.UnitLimit) // bucket marks, refilled by every probe unit
+	return nil
+}
+
+func Call(idx int, proc *process.Process, argument any) (bool, error) {
+	arg := argument.(*Argument)
+
+	analyze := proc.GetAnalyze(idx)
+	analyze.Start()
+	defer analyze.Stop()
+
+	for {
+		switch arg.ctr.state {
+		case build:
+			if err := arg.ctr.buildHashTable(proc, analyze, 1); err != nil {
+				arg.ctr.state = end
+				return true, err
+			}
+			arg.ctr.state = probe
+
+		case probe:
+			var err error
+			isLast := false
+			if isLast, err = arg.ctr.probeHashTable(proc, analyze, 0); err != nil {
+				arg.ctr.state = end
+				return true, err
+			}
+			if isLast {
+				arg.ctr.state = end
+				continue
+			}
+
+			return false, nil
+
+		case end:
+			arg.ctr.hashTable.Free()
+			arg.ctr.freeSels(proc)
+			proc.SetInputBatch(nil)
+			return true, nil
+		}
+	}
+}
+
+// buildHashTable reads receiver idx until a nil batch arrives and inserts every
+// row into the hash table, giving each newly created group one entry in c.cnts.
+func (c *container) buildHashTable(proc *process.Process, analyse process.Analyze, idx int) error {
+	for {
+		bat := <-proc.Reg.MergeReceivers[idx].Ch
+		if bat == nil {
+			// nil marks the end of the build input
+			return nil
+		}
+
+		if bat.Length() == 0 {
+			// nothing to insert
+			continue
+		}
+
+		analyse.Input(bat)
+
+		rows := bat.Length()
+		iter := c.hashTable.NewIterator()
+		for start := 0; start < rows; start += hashmap.UnitLimit {
+			// Group ids above this count are the ones Insert creates below.
+			known := c.hashTable.GroupCount()
+
+			// Rows go to the hash table in units of at most UnitLimit.
+			unit := rows - start
+			if unit > hashmap.UnitLimit {
+				unit = hashmap.UnitLimit
+			}
+
+			groups, marks, err := iter.Insert(start, unit, bat.Vecs)
+			if err != nil {
+				bat.Clean(proc.Mp)
+				return err
+			}
+
+			for k, g := range groups {
+				if marks[k] == 0 || g <= known {
+					continue
+				}
+				// New group: store a one-element slice holding 1; the probe side
+				// zeroes it once the row has been emitted.
+				c.cnts = append(c.cnts, proc.GetMheap().GetSels())
+				c.cnts[g-1] = append(c.cnts[g-1], 1)
+				known++
+			}
+		}
+		bat.Clean(proc.Mp)
+	}
+}
+
+func (c *container) probeHashTable(proc *process.Process, analyze process.Analyze, idx int) (bool, error) {
+	btc := <-proc.Reg.MergeReceivers[idx].Ch
+
+	// last batch of block
+	if btc == nil {
+		return true, nil
+	}
+
+	// empty batch
+	if btc.Length() == 0 {
+		return false, nil
+	}
+
+	analyze.Input(btc)
+	defer btc.Clean(proc.Mp)
+
+	c.btc = batch.NewWithSize(len(btc.Vecs))
+	for i := range btc.Vecs {
+		c.btc.Vecs[i] = vector.New(btc.Vecs[i].Typ)
+	}
+	needInsert := make([]uint8, hashmap.UnitLimit)
+	resetsNeedInsert := make([]uint8, hashmap.UnitLimit)
+	cnt := btc.Length()
+	itr := c.hashTable.NewIterator()
+	for i := 0; i < cnt; i += hashmap.UnitLimit {
+		n := cnt - i
+		if n > hashmap.UnitLimit {
+			n = hashmap.UnitLimit
+		}
+
+		copy(c.inBuckets, hashmap.OneUInt8s)
+		copy(needInsert, resetsNeedInsert)
+		insertcnt := 0
+
+		vs, zs := itr.Find(i, n, btc.Vecs, c.inBuckets)
+
+		for j, v := range vs {
+			// not in the processed bucket
+			if c.inBuckets[j] == 0 {
+				continue
+			}
+
+			// null value
+			if zs[j] == 0 {
+				continue
+			}
+
+			// not found
+			if v == 0 {
+				continue
+			}
+
+			// has been added into output batch
+			if c.cnts[v-1][0] == 0 {
+				continue
+			}
+
+			needInsert[j] = 1
+			c.cnts[v-1][0] = 0
+			c.btc.Zs = append(c.btc.Zs, 1)
+			insertcnt++
+		}
+
+		if insertcnt > 0 {
+			for pos := range btc.Vecs {
+				if err := vector.UnionBatch(c.btc.Vecs[pos], btc.Vecs[pos], int64(i), insertcnt, needInsert, proc.Mp); err != nil {
+					return false, err
+				}
+			}
+		}
+	}
+
+	analyze.Output(c.btc)
+	proc.SetInputBatch(c.btc)
+	return false, nil
+}
+
+func (c *container) freeSels(proc *process.Process) {
+	for _, sels := range c.cnts { // hand each per-group slice back to the mheap
+		proc.GetMheap().PutSels(sels)
+	}
+	c.cnts = nil // drop the references so a repeated call has nothing to return
+}
diff --git a/pkg/sql/colexec/intersect/intersect_test.go b/pkg/sql/colexec/intersect/intersect_test.go
new file mode 100644
index 000000000..2850c86f2
--- /dev/null
+++ b/pkg/sql/colexec/intersect/intersect_test.go
@@ -0,0 +1,129 @@
+// Copyright 2022 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package intersect
+
+import (
+	"context"
+	"testing"
+
+	"github.com/matrixorigin/matrixone/pkg/container/batch"
+	"github.com/matrixorigin/matrixone/pkg/container/types"
+	"github.com/matrixorigin/matrixone/pkg/container/vector"
+	"github.com/matrixorigin/matrixone/pkg/testutil"
+	"github.com/matrixorigin/matrixone/pkg/vm/mheap"
+	"github.com/matrixorigin/matrixone/pkg/vm/process"
+	"github.com/stretchr/testify/require"
+)
+
+type intersectTestCase struct {
+	proc   *process.Process   // process whose merge receivers feed the operator
+	arg    *Argument          // operator argument passed to Prepare and Call
+	cancel context.CancelFunc // cancels the context shared by both receivers
+}
+
+func TestIntersect(t *testing.T) {
+	proc := testutil.NewProcess()
+	// [2 rows + 2 rows, 3 columns] intersect [1 row + 1 row, 3 columns]
+	/*
+		{1, 2, 3}	    {1, 2, 3}
+		{1, 2, 3} intersect {4, 5, 6} ==> {1, 2, 3}
+		{3, 4, 5}
+		{3, 4, 5}
+	*/
+	c := newIntersectTestCase(
+		proc,
+		[]*batch.Batch{
+			testutil.NewBatchWithVectors(
+				[]*vector.Vector{
+					testutil.NewVector(2, types.T_int64.ToType(), proc.Mp, false, []int64{1, 1}),
+					testutil.NewVector(2, types.T_int64.ToType(), proc.Mp, false, []int64{2, 2}),
+					testutil.NewVector(2, types.T_int64.ToType(), proc.Mp, false, []int64{3, 3}),
+				}, nil),
+			testutil.NewBatchWithVectors(
+				[]*vector.Vector{
+					testutil.NewVector(2, types.T_int64.ToType(), proc.Mp, false, []int64{3, 3}),
+					testutil.NewVector(2, types.T_int64.ToType(), proc.Mp, false, []int64{4, 4}),
+					testutil.NewVector(2, types.T_int64.ToType(), proc.Mp, false, []int64{5, 5}),
+				}, nil),
+		},
+		[]*batch.Batch{
+			testutil.NewBatchWithVectors(
+				[]*vector.Vector{
+					testutil.NewVector(1, types.T_int64.ToType(), proc.Mp, false, []int64{1}),
+					testutil.NewVector(1, types.T_int64.ToType(), proc.Mp, false, []int64{2}),
+					testutil.NewVector(1, types.T_int64.ToType(), proc.Mp, false, []int64{3}),
+				}, nil),
+			testutil.NewBatchWithVectors(
+				[]*vector.Vector{
+					testutil.NewVector(1, types.T_int64.ToType(), proc.Mp, false, []int64{4}),
+					testutil.NewVector(1, types.T_int64.ToType(), proc.Mp, false, []int64{5}),
+					testutil.NewVector(1, types.T_int64.ToType(), proc.Mp, false, []int64{6}),
+				}, nil),
+		},
+	)
+
+	defer c.cancel() // release the context created for the receivers
+	require.NoError(t, Prepare(c.proc, c.arg))
+	cnt := 0
+	for {
+		end, err := Call(0, c.proc, c.arg)
+		// Check the error before honoring end: Call reports failures as (true, err).
+		require.NoError(t, err)
+		if end {
+			break
+		}
+		result := c.proc.InputBatch()
+		if result != nil && len(result.Zs) != 0 {
+			cnt += result.Length()
+			require.Equal(t, 3, len(result.Vecs)) // 3 column
+			c.proc.InputBatch().Clean(c.proc.Mp)
+		}
+	}
+	require.Equal(t, 1, cnt) // 1 row
+	require.Equal(t, int64(0), mheap.Size(c.proc.Mp))
+}
+
+// newIntersectTestCase wires proc with two pre-filled merge receivers, one for
+// leftBatches and one for rightBatches, and returns it together with a zero
+// Argument and the cancel function of the receivers' shared context.
+func newIntersectTestCase(proc *process.Process, leftBatches, rightBatches []*batch.Batch) intersectTestCase {
+	ctx, cancel := context.WithCancel(context.Background())
+
+	// feed returns a receiver whose channel already holds every batch followed
+	// by a nil end marker, so reads from it never block.
+	feed := func(batches []*batch.Batch) *process.WaitRegister {
+		ch := make(chan *batch.Batch, len(batches)+10)
+		for _, bat := range batches {
+			ch <- bat
+		}
+		ch <- nil
+		return &process.WaitRegister{
+			Ctx: ctx,
+			Ch:  ch,
+		}
+	}
+
+	// Call probes with receiver 0 and builds its hash table from receiver 1.
+	proc.Reg.MergeReceivers = []*process.WaitRegister{
+		feed(leftBatches),
+		feed(rightBatches),
+	}
+
+	return intersectTestCase{
+		proc:   proc,
+		arg:    new(Argument),
+		cancel: cancel,
+	}
+}
diff --git a/pkg/sql/colexec/intersect/types.go b/pkg/sql/colexec/intersect/types.go
new file mode 100644
index 000000000..c24c33541
--- /dev/null
+++ b/pkg/sql/colexec/intersect/types.go
@@ -0,0 +1,51 @@
+// Copyright 2022 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package intersect
+
+import (
+	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
+	"github.com/matrixorigin/matrixone/pkg/container/batch"
+)
+
+const (
+	build = iota // Call fills the hash table from the build input
+	probe        // Call looks up one probe batch per invocation
+	end          // Call frees its resources and reports completion
+)
+
+// Argument is the argument of the intersect operator.
+type Argument struct {
+	// ctr is the operator's private state, set up by Prepare.
+	ctr container
+	// IBucket and NBucket are passed to hashmap.NewStrMap by Prepare.
+	IBucket, NBucket uint64
+}
+
+type container struct {
+	// state is the current phase of Call: build, probe or end.
+	state int
+
+	// cnts has one single-element slice per hash-table group: 1 until the probe emits that group, then 0.
+	cnts [][]int64
+
+	// hashTable stores the rows of the build input; probe rows are looked up in it.
+	hashTable *hashmap.StrHashMap
+
+	// btc is the result batch produced by the latest probe of the intersect operator.
+	btc *batch.Batch
+
+	// inBuckets is the per-row bucket mark that the probe passes to the hash-table iterator's Find.
+	inBuckets []uint8
+}
diff --git a/pkg/sql/compile/compile.go b/pkg/sql/compile/compile.go
index c0019fab0..26b02d771 100644
--- a/pkg/sql/compile/compile.go
+++ b/pkg/sql/compile/compile.go
@@ -420,7 +420,7 @@ func (c *Compile) compilePlanScope(n *plan.Node, ns []*plan.Node) ([]*Scope, err
 		}
 		c.anal.curr = curr
 		return c.compileSort(n, c.compileUnion(n, ss, children, ns)), nil
-	case plan.Node_MINUS:
+	case plan.Node_MINUS, plan.Node_INTERSECT:
 		curr := c.anal.curr
 		c.anal.curr = int(n.Children[0])
 		ss, err := c.compilePlanScope(ns[n.Children[0]], ns)
@@ -564,6 +564,14 @@ func (c *Compile) compileMinusAndIntersect(n *plan.Node, ss []*Scope, children [
 				Arg: constructMinus(n, c.proc, i, len(rs)),
 			}
 		}
+	case plan.Node_INTERSECT:
+		for i := range rs {
+			rs[i].Instructions[0] = vm.Instruction{
+				Op:  vm.Intersect,
+				Idx: c.anal.curr,
+				Arg: constructIntersect(n, c.proc, i, len(rs)),
+			}
+		}
 	}
 	return []*Scope{c.newMergeScope(append(append(rs, left), right))}
 }
diff --git a/pkg/sql/compile/operator.go b/pkg/sql/compile/operator.go
index a39b9ec17..25a382408 100644
--- a/pkg/sql/compile/operator.go
+++ b/pkg/sql/compile/operator.go
@@ -19,6 +19,7 @@ import (
 	"fmt"
 
 	"github.com/matrixorigin/matrixone/pkg/sql/colexec/anti"
+	"github.com/matrixorigin/matrixone/pkg/sql/colexec/intersect"
 	"github.com/matrixorigin/matrixone/pkg/sql/colexec/loopanti"
 	"github.com/matrixorigin/matrixone/pkg/sql/colexec/minus"
 
@@ -174,6 +175,16 @@ func dupInstruction(in vm.Instruction) vm.Instruction {
 		}
 	case *dispatch.Argument:
 	case *connector.Argument:
+	case *minus.Argument:
+		rin.Arg = &minus.Argument{
+			IBucket: arg.IBucket,
+			NBucket: arg.NBucket,
+		}
+	case *intersect.Argument:
+		rin.Arg = &intersect.Argument{
+			IBucket: arg.IBucket,
+			NBucket: arg.NBucket,
+		}
 	default:
 		panic(errors.New(errno.SyntaxErrororAccessRuleViolation, fmt.Sprintf("Unsupport instruction %T\n", in.Arg)))
 	}
@@ -439,6 +450,13 @@ func constructMinus(n *plan.Node, proc *process.Process, ibucket, nbucket int) *
 	}
 }
 
+func constructIntersect(n *plan.Node, proc *process.Process, ibucket, nbucket int) *intersect.Argument {
+	arg := new(intersect.Argument) // n and proc are not consulted
+	arg.IBucket = uint64(ibucket)
+	arg.NBucket = uint64(nbucket)
+	return arg
+}
+
 func constructDispatch(all bool, regs []*process.WaitRegister) *dispatch.Argument {
 	arg := new(dispatch.Argument)
 	arg.All = all
diff --git a/pkg/vm/overload.go b/pkg/vm/overload.go
index 264488564..8d63dcb4d 100644
--- a/pkg/vm/overload.go
+++ b/pkg/vm/overload.go
@@ -16,7 +16,9 @@ package vm
 
 import (
 	"bytes"
+
 	"github.com/matrixorigin/matrixone/pkg/sql/colexec/anti"
+	"github.com/matrixorigin/matrixone/pkg/sql/colexec/intersect"
 	"github.com/matrixorigin/matrixone/pkg/sql/colexec/loopanti"
 	"github.com/matrixorigin/matrixone/pkg/sql/colexec/minus"
 
@@ -91,8 +93,9 @@ var stringFunc = [...]func(any, *bytes.Buffer){
 	Insert:   insert.String,
 	Update:   update.String,
 
-	Union: union.String,
-	Minus: minus.String,
+	Union:     union.String,
+	Minus:     minus.String,
+	Intersect: intersect.String,
 }
 
 var prepareFunc = [...]func(*process.Process, any) error{
@@ -130,8 +133,9 @@ var prepareFunc = [...]func(*process.Process, any) error{
 	Insert:   insert.Prepare,
 	Update:   update.Prepare,
 
-	Union: union.Prepare,
-	Minus: minus.Prepare,
+	Union:     union.Prepare,
+	Minus:     minus.Prepare,
+	Intersect: intersect.Prepare,
 }
 
 var execFunc = [...]func(int, *process.Process, any) (bool, error){
@@ -169,6 +173,7 @@ var execFunc = [...]func(int, *process.Process, any) (bool, error){
 	Insert:   insert.Call,
 	Update:   update.Call,
 
-	Union: union.Call,
-	Minus: minus.Call,
+	Union:     union.Call,
+	Minus:     minus.Call,
+	Intersect: intersect.Call,
 }
diff --git a/pkg/vm/types.go b/pkg/vm/types.go
index 324e51a0c..0c240b293 100644
--- a/pkg/vm/types.go
+++ b/pkg/vm/types.go
@@ -51,6 +51,7 @@ const (
 
 	Union
 	Minus
+	Intersect
 )
 
 // Instruction contains relational algebra
diff --git a/test/cases/set/set_operator.test b/test/cases/set/set_operator.test
new file mode 100644
index 000000000..51f724bf7
--- /dev/null
+++ b/test/cases/set/set_operator.test
@@ -0,0 +1,25 @@
+-- @suite
+-- @setup
+
+drop table if exists t1;
+drop table if exists t2;
+create table t1 (a smallint, b bigint, c int);
+insert into t1 values (1,2,3);
+insert into t1 values (1,2,3);
+insert into t1 values (3,4,5);
+insert into t1 values (4,5,6);
+insert into t1 values (4,5,6);
+insert into t1 values (1,1,2);
+create table t2 (a smallint, b bigint, c int);
+insert into t2 values (1,2,3);
+insert into t2 values (3,4,5);
+insert into t2 values (1,2,1);
+
+-- @case
+-- @desc:test for set operators
+-- @label:bvt
+
+select * from t1 minus select * from t2;
+select * from t1 intersect select * from t2;
+select a, b from t1 minus select b, c from t2;
+select a, b from t1 intersect select b, c from t2;
diff --git a/test/result/set/set_operator.result b/test/result/set/set_operator.result
new file mode 100644
index 000000000..0de44dd9d
--- /dev/null
+++ b/test/result/set/set_operator.result
@@ -0,0 +1,29 @@
+drop table if exists t1;
+drop table if exists t2;
+create table t1 (a smallint, b bigint, c int);
+insert into t1 values (1,2,3);
+insert into t1 values (1,2,3);
+insert into t1 values (3,4,5);
+insert into t1 values (4,5,6);
+insert into t1 values (4,5,6);
+insert into t1 values (1,1,2);
+create table t2 (a smallint, b bigint, c int);
+insert into t2 values (1,2,3);
+insert into t2 values (3,4,5);
+insert into t2 values (1,2,1);
+select * from t1 minus select * from t2;
+a	b	c
+1	1	2
+4	5	6
+select * from t1 intersect select * from t2;
+a	b	c
+1	2	3
+3	4	5
+select a, b from t1 minus select b, c from t2;
+a	b
+3	4
+1	1
+1	2
+select a, b from t1 intersect select b, c from t2;
+a	b
+4	5
-- 
GitLab