From 51a18beb35e6b3fdea546a627998d305f82a423c Mon Sep 17 00:00:00 2001
From: nnsgmsone <nnsmgsone@outlook.com>
Date: Thu, 18 Aug 2022 11:16:55 +0800
Subject: [PATCH] Fix the bug of union (#4562)

Approved by: @yingfeng
---
 pkg/sql/colexec/agg/agg.go     |  6 ++++++
 pkg/sql/colexec/agg/avg/avg.go |  6 ++++++
 pkg/sql/colexec/group/group.go |  5 +++++
 pkg/sql/compile/compile.go     | 34 ++++++++++++----------------------
 4 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/pkg/sql/colexec/agg/agg.go b/pkg/sql/colexec/agg/agg.go
index 922150c1f..73b6fccc7 100644
--- a/pkg/sql/colexec/agg/agg.go
+++ b/pkg/sql/colexec/agg/agg.go
@@ -150,11 +150,17 @@ func (a *UnaryAgg[T1, T2]) BatchFill(start int64, os []uint8, vps []uint64, zs [
 		if nsp.Any() {
 			for i := range os {
 				if !nsp.Contains(uint64(i) + uint64(start)) {
+					if vps[i] == 0 {
+						continue
+					}
 					a.es[vps[i]-1] = false
 				}
 			}
 		} else {
 			for i := range os {
+				if vps[i] == 0 {
+					continue
+				}
 				a.es[vps[i]-1] = false
 			}
 		}
diff --git a/pkg/sql/colexec/agg/avg/avg.go b/pkg/sql/colexec/agg/avg/avg.go
index f95265039..ddc68cd21 100644
--- a/pkg/sql/colexec/agg/avg/avg.go
+++ b/pkg/sql/colexec/agg/avg/avg.go
@@ -122,6 +122,9 @@ func (a *Decimal64Avg) BatchFill(rs, vs any, start, count int64, vps []uint64, z
 		if nsp.Contains(uint64(i + start)) {
 			continue
 		}
+		if vps[i] == 0 {
+			continue
+		}
 		j := vps[i] - 1
 		a.cnts[j] += zs[j]
 	}
@@ -173,6 +176,9 @@ func (a *Decimal128Avg) BatchFill(rs, vs any, start, count int64, vps []uint64,
 		if nsp.Contains(uint64(i + start)) {
 			continue
 		}
+		if vps[i] == 0 {
+			continue
+		}
 		j := vps[i] - 1
 		a.cnts[j] += zs[j]
 	}
diff --git a/pkg/sql/colexec/group/group.go b/pkg/sql/colexec/group/group.go
index 444b77237..0cb9688bd 100644
--- a/pkg/sql/colexec/group/group.go
+++ b/pkg/sql/colexec/group/group.go
@@ -321,6 +321,7 @@ func (ctr *container) processHStr(bat *batch.Batch, proc *process.Process) error
 
 func (ctr *container) batchFill(i int, n int, bat *batch.Batch, vals []uint64, hashRows uint64, proc *process.Process) error {
 	cnt := 0
+	valCnt := 0
 	copy(ctr.inserted[:n], ctr.zInserted[:n])
 	for k, v := range vals[:n] {
 		if v == 0 {
@@ -332,6 +333,7 @@ func (ctr *container) batchFill(i int, n int, bat *batch.Batch, vals []uint64, h
 			cnt++
 			ctr.bat.Zs = append(ctr.bat.Zs, 0)
 		}
+		valCnt++
 		ai := int64(v) - 1
 		ctr.bat.Zs[ai] += bat.Zs[i+k]
 	}
@@ -347,6 +349,9 @@ func (ctr *container) batchFill(i int, n int, bat *batch.Batch, vals []uint64, h
 			}
 		}
 	}
+	if valCnt == 0 {
+		return nil
+	}
 	for j, ag := range ctr.bat.Aggs {
 		err := ag.BatchFill(int64(i), ctr.inserted[:n], vals, bat.Zs, []*vector.Vector{ctr.aggVecs[j].vec})
 		if err != nil {
diff --git a/pkg/sql/compile/compile.go b/pkg/sql/compile/compile.go
index 8e7f4e03c..55fbffb49 100644
--- a/pkg/sql/compile/compile.go
+++ b/pkg/sql/compile/compile.go
@@ -554,13 +554,14 @@ func (c *Compile) compileProjection(n *plan.Node, ss []*Scope) []*Scope {
 func (c *Compile) compileUnion(n *plan.Node, ss []*Scope, children []*Scope, ns []*plan.Node) []*Scope {
 	ss = append(ss, children...)
 	rs := c.newScopeList(validScopeCount(ss))
-	regs := extraGroupRegisters(rs)
+	j := 0
 	for i := range ss {
 		if !ss[i].IsEnd {
 			ss[i].appendInstruction(vm.Instruction{
 				Op:  vm.Dispatch,
-				Arg: constructDispatch(true, regs),
+				Arg: constructDispatch(true, extraRegisters(rs, j)),
 			})
+			j++
 			ss[i].IsEnd = true
 		}
 	}
@@ -812,13 +813,14 @@ func (c *Compile) compileAgg(n *plan.Node, ss []*Scope, ns []*plan.Node) []*Scop
 
 func (c *Compile) compileGroup(n *plan.Node, ss []*Scope, ns []*plan.Node) []*Scope {
 	rs := c.newScopeList(validScopeCount(ss))
-	regs := extraGroupRegisters(rs)
+	j := 0
 	for i := range ss {
 		if !ss[i].IsEnd {
 			ss[i].appendInstruction(vm.Instruction{
 				Op:  vm.Dispatch,
-				Arg: constructDispatch(true, regs),
+				Arg: constructDispatch(true, extraRegisters(rs, j)),
 			})
+			j++
 			ss[i].IsEnd = true
 		}
 	}
@@ -890,12 +892,12 @@ func (c *Compile) newScopeListWithNode(mcpu, childrenCount int) []*Scope {
 func (c *Compile) newJoinScopeListWithBucket(rs, ss, children []*Scope) ([]*Scope, *Scope, *Scope) {
 	left := c.newMergeScope(ss)
 	right := c.newMergeScope(children)
-	leftRegs := extraJoinRegisters(rs, 0)
+	leftRegs := extraRegisters(rs, 0)
 	left.appendInstruction(vm.Instruction{
 		Op:  vm.Dispatch,
 		Arg: constructDispatch(true, leftRegs),
 	})
-	rightRegs := extraJoinRegisters(rs, 1)
+	rightRegs := extraRegisters(rs, 1)
 	right.appendInstruction(vm.Instruction{
 		Op:  vm.Dispatch,
 		Arg: constructDispatch(true, rightRegs),
@@ -929,7 +931,7 @@ func (c *Compile) newJoinScopeList(ss []*Scope, children []*Scope) ([]*Scope, *S
 	}
 	chp.Instructions = append(chp.Instructions, vm.Instruction{
 		Op:  vm.Dispatch,
-		Arg: constructDispatch(true, extraJoinRegisters(rs, 1)),
+		Arg: constructDispatch(true, extraRegisters(rs, 1)),
 	})
 	return rs, chp
 }
@@ -944,7 +946,7 @@ func (c *Compile) newLeftScope(s *Scope, ss []*Scope) *Scope {
 	})
 	rs.appendInstruction(vm.Instruction{
 		Op:  vm.Dispatch,
-		Arg: constructDispatch(false, extraJoinRegisters(ss, 0)),
+		Arg: constructDispatch(false, extraRegisters(ss, 0)),
 	})
 	rs.IsEnd = true
 	rs.Proc = process.NewWithAnalyze(s.Proc, c.ctx, 1, c.anal.Nodes())
@@ -963,7 +965,7 @@ func (c *Compile) newRightScope(s *Scope, ss []*Scope) *Scope {
 	})
 	rs.appendInstruction(vm.Instruction{
 		Op:  vm.Dispatch,
-		Arg: constructDispatch(true, extraJoinRegisters(ss, 1)),
+		Arg: constructDispatch(true, extraRegisters(ss, 1)),
 	})
 	rs.IsEnd = true
 	rs.Proc = process.NewWithAnalyze(s.Proc, c.ctx, 1, c.anal.Nodes())
@@ -1018,19 +1020,7 @@ func validScopeCount(ss []*Scope) int {
 	return cnt
 }
 
-func extraGroupRegisters(ss []*Scope) []*process.WaitRegister {
-	var regs []*process.WaitRegister
-
-	for _, s := range ss {
-		if s.IsEnd {
-			continue
-		}
-		regs = append(regs, s.Proc.Reg.MergeReceivers...)
-	}
-	return regs
-}
-
-func extraJoinRegisters(ss []*Scope, i int) []*process.WaitRegister {
+func extraRegisters(ss []*Scope, i int) []*process.WaitRegister {
 	regs := make([]*process.WaitRegister, 0, len(ss))
 	for _, s := range ss {
 		if s.IsEnd {
-- 
GitLab