From 51a18beb35e6b3fdea546a627998d305f82a423c Mon Sep 17 00:00:00 2001 From: nnsgmsone <nnsmgsone@outlook.com> Date: Thu, 18 Aug 2022 11:16:55 +0800 Subject: [PATCH] Fix the bug of union (#4562) Approved by: @yingfeng --- pkg/sql/colexec/agg/agg.go | 6 ++++++ pkg/sql/colexec/agg/avg/avg.go | 6 ++++++ pkg/sql/colexec/group/group.go | 5 +++++ pkg/sql/compile/compile.go | 34 ++++++++++++---------------------- 4 files changed, 29 insertions(+), 22 deletions(-) diff --git a/pkg/sql/colexec/agg/agg.go b/pkg/sql/colexec/agg/agg.go index 922150c1f..73b6fccc7 100644 --- a/pkg/sql/colexec/agg/agg.go +++ b/pkg/sql/colexec/agg/agg.go @@ -150,11 +150,17 @@ func (a *UnaryAgg[T1, T2]) BatchFill(start int64, os []uint8, vps []uint64, zs [ if nsp.Any() { for i := range os { if !nsp.Contains(uint64(i) + uint64(start)) { + if vps[i] == 0 { + continue + } a.es[vps[i]-1] = false } } } else { for i := range os { + if vps[i] == 0 { + continue + } a.es[vps[i]-1] = false } } diff --git a/pkg/sql/colexec/agg/avg/avg.go b/pkg/sql/colexec/agg/avg/avg.go index f95265039..ddc68cd21 100644 --- a/pkg/sql/colexec/agg/avg/avg.go +++ b/pkg/sql/colexec/agg/avg/avg.go @@ -122,6 +122,9 @@ func (a *Decimal64Avg) BatchFill(rs, vs any, start, count int64, vps []uint64, z if nsp.Contains(uint64(i + start)) { continue } + if vps[i] == 0 { + continue + } j := vps[i] - 1 a.cnts[j] += zs[j] } @@ -173,6 +176,9 @@ func (a *Decimal128Avg) BatchFill(rs, vs any, start, count int64, vps []uint64, if nsp.Contains(uint64(i + start)) { continue } + if vps[i] == 0 { + continue + } j := vps[i] - 1 a.cnts[j] += zs[j] } diff --git a/pkg/sql/colexec/group/group.go b/pkg/sql/colexec/group/group.go index 444b77237..0cb9688bd 100644 --- a/pkg/sql/colexec/group/group.go +++ b/pkg/sql/colexec/group/group.go @@ -321,6 +321,7 @@ func (ctr *container) processHStr(bat *batch.Batch, proc *process.Process) error func (ctr *container) batchFill(i int, n int, bat *batch.Batch, vals []uint64, hashRows uint64, proc *process.Process) error { cnt := 0 + valCnt := 0 copy(ctr.inserted[:n], ctr.zInserted[:n]) for k, v := range vals[:n] { if v == 0 { @@ -332,6 +333,7 @@ func (ctr *container) batchFill(i int, n int, bat *batch.Batch, vals []uint64, h cnt++ ctr.bat.Zs = append(ctr.bat.Zs, 0) } + valCnt++ ai := int64(v) - 1 ctr.bat.Zs[ai] += bat.Zs[i+k] } @@ -347,6 +349,9 @@ func (ctr *container) batchFill(i int, n int, bat *batch.Batch, vals []uint64, h } } } + if valCnt == 0 { + return nil + } for j, ag := range ctr.bat.Aggs { err := ag.BatchFill(int64(i), ctr.inserted[:n], vals, bat.Zs, []*vector.Vector{ctr.aggVecs[j].vec}) if err != nil { diff --git a/pkg/sql/compile/compile.go b/pkg/sql/compile/compile.go index 8e7f4e03c..55fbffb49 100644 --- a/pkg/sql/compile/compile.go +++ b/pkg/sql/compile/compile.go @@ -554,13 +554,14 @@ func (c *Compile) compileProjection(n *plan.Node, ss []*Scope) []*Scope { func (c *Compile) compileUnion(n *plan.Node, ss []*Scope, children []*Scope, ns []*plan.Node) []*Scope { ss = append(ss, children...) rs := c.newScopeList(validScopeCount(ss)) - regs := extraGroupRegisters(rs) + j := 0 for i := range ss { if !ss[i].IsEnd { ss[i].appendInstruction(vm.Instruction{ Op: vm.Dispatch, - Arg: constructDispatch(true, regs), + Arg: constructDispatch(true, extraRegisters(rs, j)), }) + j++ ss[i].IsEnd = true } } @@ -812,13 +813,14 @@ func (c *Compile) compileAgg(n *plan.Node, ss []*Scope, ns []*plan.Node) []*Scop func (c *Compile) compileGroup(n *plan.Node, ss []*Scope, ns []*plan.Node) []*Scope { rs := c.newScopeList(validScopeCount(ss)) - regs := extraGroupRegisters(rs) + j := 0 for i := range ss { if !ss[i].IsEnd { ss[i].appendInstruction(vm.Instruction{ Op: vm.Dispatch, - Arg: constructDispatch(true, regs), + Arg: constructDispatch(true, extraRegisters(rs, j)), }) + j++ ss[i].IsEnd = true } } @@ -890,12 +892,12 @@ func (c *Compile) newScopeListWithNode(mcpu, childrenCount int) []*Scope { func (c *Compile) newJoinScopeListWithBucket(rs, ss, children []*Scope) ([]*Scope, *Scope, *Scope) { left := c.newMergeScope(ss) right := c.newMergeScope(children) - leftRegs := extraJoinRegisters(rs, 0) + leftRegs := extraRegisters(rs, 0) left.appendInstruction(vm.Instruction{ Op: vm.Dispatch, Arg: constructDispatch(true, leftRegs), }) - rightRegs := extraJoinRegisters(rs, 1) + rightRegs := extraRegisters(rs, 1) right.appendInstruction(vm.Instruction{ Op: vm.Dispatch, Arg: constructDispatch(true, rightRegs), @@ -929,7 +931,7 @@ func (c *Compile) newJoinScopeList(ss []*Scope, children []*Scope) ([]*Scope, *S } chp.Instructions = append(chp.Instructions, vm.Instruction{ Op: vm.Dispatch, - Arg: constructDispatch(true, extraJoinRegisters(rs, 1)), + Arg: constructDispatch(true, extraRegisters(rs, 1)), }) return rs, chp } @@ -944,7 +946,7 @@ func (c *Compile) newLeftScope(s *Scope, ss []*Scope) *Scope { }) rs.appendInstruction(vm.Instruction{ Op: vm.Dispatch, - Arg: constructDispatch(false, extraJoinRegisters(ss, 0)), + Arg: constructDispatch(false, extraRegisters(ss, 0)), }) rs.IsEnd = true rs.Proc = process.NewWithAnalyze(s.Proc, c.ctx, 1, c.anal.Nodes()) @@ -963,7 +965,7 @@ func (c *Compile) newRightScope(s *Scope, ss []*Scope) *Scope { }) rs.appendInstruction(vm.Instruction{ Op: vm.Dispatch, - Arg: constructDispatch(true, extraJoinRegisters(ss, 1)), + Arg: constructDispatch(true, extraRegisters(ss, 1)), }) rs.IsEnd = true rs.Proc = process.NewWithAnalyze(s.Proc, c.ctx, 1, c.anal.Nodes()) @@ -1018,19 +1020,7 @@ func validScopeCount(ss []*Scope) int { return cnt } -func extraGroupRegisters(ss []*Scope) []*process.WaitRegister { - var regs []*process.WaitRegister - - for _, s := range ss { - if s.IsEnd { - continue - } - regs = append(regs, s.Proc.Reg.MergeReceivers...) - } - return regs -} - -func extraJoinRegisters(ss []*Scope, i int) []*process.WaitRegister { +func extraRegisters(ss []*Scope, i int) []*process.WaitRegister { regs := make([]*process.WaitRegister, 0, len(ss)) for _, s := range ss { if s.IsEnd { -- GitLab