Skip to content
Snippets Groups Projects
Commit 4015d724 authored by neza2017's avatar neza2017 Committed by yefu.chen
Browse files

Merge operation


Signed-off-by: default avatarneza2017 <yefu.chen@zilliz.com>
parent 810be533
No related branches found
No related tags found
No related merge requests found
Showing
with 355 additions and 433 deletions
package main
import (
"context"
"crypto/md5"
"flag"
"fmt"
"log"
"math/rand"
"os"
"sync"
"sync/atomic"
"time"
"github.com/pivotal-golang/bytefmt"
"github.com/zilliztech/milvus-distributed/internal/storage"
storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
)
// Global variables
var durationSecs, threads, loops, numVersion, batchOpSize int
var valueSize uint64
var valueData []byte
var batchValueData [][]byte
var counter, totalKeyCount, keyNum int32
var endTime, setFinish, getFinish, deleteFinish time.Time
var totalKeys [][]byte
var logFileName = "benchmark.log"
var logFile *os.File
var store storagetype.Store
var wg sync.WaitGroup
// runSet writes freshly numbered keys one at a time, with numVersion
// versions each, until endTime passes. Each PutRow counts as one operation
// in counter. Any storage error aborts the whole benchmark via log.Fatalf.
func runSet() {
	for time.Now().Before(endTime) {
		id := atomic.AddInt32(&keyNum, 1)
		k := []byte(fmt.Sprint("key", id))
		for v := 1; v <= numVersion; v++ {
			atomic.AddInt32(&counter, 1)
			if err := store.PutRow(context.Background(), k, valueData, "empty", uint64(v)); err != nil {
				log.Fatalf("Error setting key %s, %s", k, err.Error())
				//atomic.AddInt32(&setCount, -1)
			}
		}
	}
	// Remember last done time
	setFinish = time.Now()
	wg.Done()
}
// runBatchSet writes batches of batchOpSize freshly numbered keys, with
// numVersion versions each, until endTime passes. Each PutRows call counts
// as one operation in counter. Any storage error aborts the benchmark.
func runBatchSet() {
	for time.Now().Before(endTime) {
		num := atomic.AddInt32(&keyNum, int32(batchOpSize))
		keys := make([][]byte, batchOpSize)
		versions := make([]uint64, batchOpSize)
		batchSuffix := make([]string, batchOpSize)
		for n := batchOpSize; n > 0; n-- {
			keys[n-1] = []byte(fmt.Sprint("key", num-int32(n)))
		}
		for ver := 1; ver <= numVersion; ver++ {
			// BUG FIX: versions was allocated but never filled, so every
			// iteration wrote version 0. Write the current version for the
			// whole batch, mirroring runSet's uint64(ver).
			for i := range versions {
				versions[i] = uint64(ver)
			}
			atomic.AddInt32(&counter, 1)
			err := store.PutRows(context.Background(), keys, batchValueData, batchSuffix, versions)
			if err != nil {
				log.Fatalf("Error setting batch keys %s, %s", keys, err.Error())
				//atomic.AddInt32(&batchSetCount, -1)
			}
		}
	}
	// Remember last done time
	setFinish = time.Now()
	wg.Done()
}
// runGet reads keys from totalKeys in round-robin order (indexed by the
// shared counter modulo totalKeyCount) until endTime passes. Any storage
// error aborts the whole benchmark via log.Fatalf.
func runGet() {
	for time.Now().Before(endTime) {
		idx := atomic.AddInt32(&counter, 1)
		//num := atomic.AddInt32(&keyNum, 1)
		//key := []byte(fmt.Sprint("key", num))
		k := totalKeys[idx%totalKeyCount]
		if _, err := store.GetRow(context.Background(), k, uint64(numVersion)); err != nil {
			log.Fatalf("Error getting key %s, %s", k, err.Error())
			//atomic.AddInt32(&getCount, -1)
		}
	}
	// Remember last done time
	getFinish = time.Now()
	wg.Done()
}
// runBatchGet reads windows of batchOpSize consecutive keys out of totalKeys
// until endTime passes. The window end is keyNum modulo totalKeyCount,
// clamped so the slice never starts below zero. Each GetRows call counts as
// one operation in counter.
func runBatchGet() {
	for time.Now().Before(endTime) {
		cursor := atomic.AddInt32(&keyNum, int32(batchOpSize))
		//keys := make([][]byte, batchOpSize)
		//for n := batchOpSize; n > 0; n-- {
		//	keys[n-1] = []byte(fmt.Sprint("key", num-int32(n)))
		//}
		end := cursor % totalKeyCount
		if end < int32(batchOpSize) {
			end = int32(batchOpSize)
		}
		batch := totalKeys[end-int32(batchOpSize) : end]
		versions := make([]uint64, batchOpSize)
		for i := 0; i < batchOpSize; i++ {
			versions[i] = uint64(numVersion)
		}
		atomic.AddInt32(&counter, 1)
		if _, err := store.GetRows(context.Background(), batch, versions); err != nil {
			log.Fatalf("Error getting key %s, %s", batch, err.Error())
			//atomic.AddInt32(&batchGetCount, -1)
		}
	}
	// Remember last done time
	getFinish = time.Now()
	wg.Done()
}
// runDelete deletes keys from totalKeys in round-robin order (indexed by the
// shared counter modulo totalKeyCount) until endTime passes. Any storage
// error aborts the whole benchmark via log.Fatalf.
func runDelete() {
	for time.Now().Before(endTime) {
		num := atomic.AddInt32(&counter, 1)
		//num := atomic.AddInt32(&keyNum, 1)
		//key := []byte(fmt.Sprint("key", num))
		num = num % totalKeyCount
		key := totalKeys[num]
		err := store.DeleteRow(context.Background(), key, uint64(numVersion))
		if err != nil {
			// BUG FIX: message said "Error getting key" in the delete path.
			log.Fatalf("Error deleting key %s, %s", key, err.Error())
			//atomic.AddInt32(&deleteCount, -1)
		}
	}
	// Remember last done time
	deleteFinish = time.Now()
	wg.Done()
}
// runBatchDelete deletes windows of batchOpSize consecutive keys out of
// totalKeys until endTime passes, using the same window-clamping scheme as
// runBatchGet. Each DeleteRows call counts as one operation in counter.
func runBatchDelete() {
	for time.Now().Before(endTime) {
		num := atomic.AddInt32(&keyNum, int32(batchOpSize))
		//keys := make([][]byte, batchOpSize)
		//for n := batchOpSize; n > 0; n-- {
		//	keys[n-1] = []byte(fmt.Sprint("key", num-int32(n)))
		//}
		end := num % totalKeyCount
		if end < int32(batchOpSize) {
			end = int32(batchOpSize)
		}
		start := end - int32(batchOpSize)
		keys := totalKeys[start:end]
		atomic.AddInt32(&counter, 1)
		versions := make([]uint64, batchOpSize)
		for i := range versions {
			versions[i] = uint64(numVersion)
		}
		err := store.DeleteRows(context.Background(), keys, versions)
		if err != nil {
			// BUG FIX: message said "Error getting key" in the delete path.
			log.Fatalf("Error deleting keys %s, %s", keys, err.Error())
			//atomic.AddInt32(&batchDeleteCount, -1)
		}
	}
	// Remember last done time.
	// BUG FIX: was recorded into getFinish; this is the delete benchmark.
	deleteFinish = time.Now()
	wg.Done()
}
// main parses the command-line flags, initializes the storage backend and
// random test data, then runs `loops` rounds of the enabled benchmark cases,
// appending throughput results to the benchmark log file.
func main() {
	// Parse command line
	myflag := flag.NewFlagSet("myflag", flag.ExitOnError)
	myflag.IntVar(&durationSecs, "d", 5, "Duration of each test in seconds")
	myflag.IntVar(&threads, "t", 1, "Number of threads to run")
	myflag.IntVar(&loops, "l", 1, "Number of times to repeat test")
	var sizeArg string
	var storeType string
	myflag.StringVar(&sizeArg, "z", "1k", "Size of objects in bytes with postfix K, M, and G")
	myflag.StringVar(&storeType, "s", "s3", "Storage type, tikv or minio or s3")
	myflag.IntVar(&numVersion, "v", 1, "Max versions for each key")
	myflag.IntVar(&batchOpSize, "b", 100, "Batch operation kv pair number")
	if err := myflag.Parse(os.Args[1:]); err != nil {
		os.Exit(1)
	}

	// Check the arguments
	var err error
	if valueSize, err = bytefmt.ToBytes(sizeArg); err != nil {
		log.Fatalf("Invalid -z argument for object size: %v", err)
	}
	var option = storagetype.Option{TikvAddress: "localhost:2379", Type: storeType, BucketName: "zilliz-hz"}
	store, err = storage.NewStore(context.Background(), option)
	if err != nil {
		// BUG FIX: was log.Fatalf(constant + err.Error()) — a non-constant
		// format string with no verbs, flagged by `go vet`.
		log.Fatalf("Error when creating storage: %s", err.Error())
	}
	logFile, err = os.OpenFile(logFileName, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0777)
	if err != nil {
		log.Fatalf("Prepare log file error: %s", err.Error())
	}
	// BUG FIX: the log file was never closed; release it when main returns.
	defer logFile.Close()

	// Echo the parameters
	log.Printf("Benchmark log will write to file %s\n", logFile.Name())
	fmt.Fprintf(logFile, "Parameters: duration=%d, threads=%d, loops=%d, valueSize=%s, batchSize=%d, versions=%d\n", durationSecs, threads, loops, sizeArg, batchOpSize, numVersion)

	// Init test data: one random value for single ops, batchOpSize random
	// values for batch ops. The md5 digests are computed but intentionally
	// unused for now (placeholder for future read-back verification).
	valueData = make([]byte, valueSize)
	rand.Read(valueData)
	hasher := md5.New()
	hasher.Write(valueData)
	batchValueData = make([][]byte, batchOpSize)
	for i := range batchValueData {
		batchValueData[i] = make([]byte, valueSize)
		rand.Read(batchValueData[i])
		hasher := md5.New()
		hasher.Write(batchValueData[i])
	}

	// Loop running the tests
	for loop := 1; loop <= loops; loop++ {
		// reset counters
		counter = 0
		keyNum = 0
		totalKeyCount = 0
		totalKeys = nil

		// Run the batchSet case
		// key seq start from setCount
		counter = 0
		startTime := time.Now()
		endTime = startTime.Add(time.Second * time.Duration(durationSecs))
		for n := 1; n <= threads; n++ {
			wg.Add(1)
			go runBatchSet()
		}
		wg.Wait()
		setTime := setFinish.Sub(startTime).Seconds()
		bps := float64(uint64(counter)*valueSize*uint64(batchOpSize)) / setTime
		fmt.Fprintf(logFile, "Loop %d: BATCH PUT time %.1f secs, batchs = %d, kv pairs = %d, speed = %sB/sec, %.1f operations/sec, %.1f kv/sec.\n",
			loop, setTime, counter, counter*int32(batchOpSize), bytefmt.ByteSize(uint64(bps)), float64(counter)/setTime, float64(counter*int32(batchOpSize))/setTime)

		// Record all test keys
		//totalKeyCount = keyNum
		//totalKeys = make([][]byte, totalKeyCount)
		//for i := int32(0); i < totalKeyCount; i++ {
		//	totalKeys[i] = []byte(fmt.Sprint("key", i))
		//}
		//
		//// Run the get case
		//counter = 0
		//startTime = time.Now()
		//endTime = startTime.Add(time.Second * time.Duration(durationSecs))
		//for n := 1; n <= threads; n++ {
		//	wg.Add(1)
		//	go runGet()
		//}
		//wg.Wait()
		//
		//getTime := getFinish.Sub(startTime).Seconds()
		//bps = float64(uint64(counter)*valueSize) / getTime
		//fmt.Fprint(logFile, fmt.Sprintf("Loop %d: GET time %.1f secs, kv pairs = %d, speed = %sB/sec, %.1f operations/sec, %.1f kv/sec.\n",
		//	loop, getTime, counter, bytefmt.ByteSize(uint64(bps)), float64(counter)/getTime, float64(counter)/getTime))

		// Run the batchGet case
		//counter = 0
		//startTime = time.Now()
		//endTime = startTime.Add(time.Second * time.Duration(durationSecs))
		//for n := 1; n <= threads; n++ {
		//	wg.Add(1)
		//	go runBatchGet()
		//}
		//wg.Wait()
		//
		//getTime = getFinish.Sub(startTime).Seconds()
		//bps = float64(uint64(counter)*valueSize*uint64(batchOpSize)) / getTime
		//fmt.Fprint(logFile, fmt.Sprintf("Loop %d: BATCH GET time %.1f secs, batchs = %d, kv pairs = %d, speed = %sB/sec, %.1f operations/sec, %.1f kv/sec.\n",
		//	loop, getTime, counter, counter*int32(batchOpSize), bytefmt.ByteSize(uint64(bps)), float64(counter)/getTime, float64(counter * int32(batchOpSize))/getTime))
		//
		//// Run the delete case
		//counter = 0
		//startTime = time.Now()
		//endTime = startTime.Add(time.Second * time.Duration(durationSecs))
		//for n := 1; n <= threads; n++ {
		//	wg.Add(1)
		//	go runDelete()
		//}
		//wg.Wait()
		//
		//deleteTime := deleteFinish.Sub(startTime).Seconds()
		//bps = float64(uint64(counter)*valueSize) / deleteTime
		//fmt.Fprint(logFile, fmt.Sprintf("Loop %d: Delete time %.1f secs, kv pairs = %d, %.1f operations/sec, %.1f kv/sec.\n",
		//	loop, deleteTime, counter, float64(counter)/deleteTime, float64(counter)/deleteTime))
		//
		//// Run the batchDelete case
		//counter = 0
		//startTime = time.Now()
		//endTime = startTime.Add(time.Second * time.Duration(durationSecs))
		//for n := 1; n <= threads; n++ {
		//	wg.Add(1)
		//	go runBatchDelete()
		//}
		//wg.Wait()
		//
		//deleteTime = setFinish.Sub(startTime).Seconds()
		//bps = float64(uint64(counter)*valueSize*uint64(batchOpSize)) / setTime
		//fmt.Fprint(logFile, fmt.Sprintf("Loop %d: BATCH DELETE time %.1f secs, batchs = %d, kv pairs = %d, %.1f operations/sec, %.1f kv/sec.\n",
		//	loop, setTime, counter, counter*int32(batchOpSize), float64(counter)/setTime, float64(counter * int32(batchOpSize))/setTime))

		// Print line mark
		lineMark := "\n"
		fmt.Fprint(logFile, lineMark)
	}
	log.Print("Benchmark test done.")
}
......@@ -36,6 +36,14 @@ services:
networks:
- milvus
jaeger:
image: jaegertracing/all-in-one:latest
ports:
- "6831:6831/udp"
- "16686:16686"
networks:
- milvus
networks:
milvus:
......
......@@ -83,5 +83,10 @@ services:
networks:
- milvus
jaeger:
image: jaegertracing/all-in-one:latest
networks:
- milvus
networks:
milvus:
......@@ -4,14 +4,17 @@ go 1.15
require (
code.cloudfoundry.org/bytefmt v0.0.0-20200131002437-cf55d5288a48 // indirect
github.com/HdrHistogram/hdrhistogram-go v1.0.1 // indirect
github.com/apache/pulsar-client-go v0.1.1
github.com/aws/aws-sdk-go v1.30.8
github.com/apache/thrift v0.13.0
github.com/aws/aws-sdk-go v1.30.8 // indirect
github.com/coreos/etcd v3.3.25+incompatible // indirect
github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548
github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 // indirect
github.com/frankban/quicktest v1.10.2 // indirect
github.com/fsnotify/fsnotify v1.4.9 // indirect
github.com/git-hooks/git-hooks v1.3.1 // indirect
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect
github.com/golang/mock v1.3.1
github.com/golang/protobuf v1.3.2
github.com/google/btree v1.0.0
github.com/klauspost/compress v1.10.11 // indirect
......@@ -20,12 +23,12 @@ require (
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/onsi/ginkgo v1.12.1 // indirect
github.com/onsi/gomega v1.10.0 // indirect
github.com/opentracing/opentracing-go v1.2.0 // indirect
github.com/opentracing/opentracing-go v1.2.0
github.com/pierrec/lz4 v2.5.2+incompatible // indirect
github.com/pingcap/check v0.0.0-20200212061837-5e12011dc712 // indirect
github.com/pingcap/errors v0.11.4 // indirect
github.com/pingcap/log v0.0.0-20200828042413-fce0951f1463 // indirect
github.com/pivotal-golang/bytefmt v0.0.0-20200131002437-cf55d5288a48
github.com/pivotal-golang/bytefmt v0.0.0-20200131002437-cf55d5288a48 // indirect
github.com/prometheus/client_golang v1.5.1 // indirect
github.com/prometheus/common v0.10.0 // indirect
github.com/prometheus/procfs v0.1.3 // indirect
......@@ -35,7 +38,9 @@ require (
github.com/spf13/cast v1.3.0
github.com/spf13/viper v1.7.1
github.com/stretchr/testify v1.6.1
github.com/tikv/client-go v0.0.0-20200824032810-95774393107b
github.com/tikv/client-go v0.0.0-20200824032810-95774393107b // indirect
github.com/uber/jaeger-client-go v2.25.0+incompatible
github.com/uber/jaeger-lib v2.4.0+incompatible // indirect
github.com/urfave/cli v1.22.5 // indirect
github.com/yahoo/athenz v1.9.16 // indirect
go.etcd.io/etcd v0.5.0-alpha.5.0.20191023171146-3cf2f69b5738
......@@ -50,7 +55,7 @@ require (
google.golang.org/genproto v0.0.0-20200122232147-0452cf42e150 // indirect
google.golang.org/grpc v1.31.0
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
gopkg.in/yaml.v2 v2.3.0
gopkg.in/yaml.v2 v2.3.0 // indirect
honnef.co/go/tools v0.0.1-2020.1.4 // indirect
sigs.k8s.io/yaml v1.2.0 // indirect
)
......
......@@ -15,6 +15,8 @@ dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/HdrHistogram/hdrhistogram-go v1.0.1 h1:GX8GAYDuhlFQnI2fRDHQhTlkHMz8bEn0jTI6LJU0mpw=
github.com/HdrHistogram/hdrhistogram-go v1.0.1/go.mod h1:BWJ+nMSHY3L41Zj7CA3uXnloDp7xxV0YvstAE7nKTaM=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM=
......@@ -24,6 +26,8 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4 h1:Hs82Z41s6SdL1C
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/apache/pulsar-client-go v0.1.1 h1:v/kU+2ZCC6yFIcbZrFtWa9/nvVzVr18L+xYJUvZSxEQ=
github.com/apache/pulsar-client-go v0.1.1/go.mod h1:mlxC65KL1BLhGO2bnT9zWMttVzR2czVPb27D477YpyU=
github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI=
github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
github.com/ardielle/ardielle-go v1.5.2 h1:TilHTpHIQJ27R1Tl/iITBzMwiUGSlVfiVhwDNGM3Zj4=
github.com/ardielle/ardielle-go v1.5.2/go.mod h1:I4hy1n795cUhaVt/ojz83SNVCYIGsAFAONtv2Dr7HUI=
github.com/ardielle/ardielle-tools v1.5.4/go.mod h1:oZN+JRMnqGiIhrzkRN9l26Cej9dEx4jeNG6A+AdkShk=
......@@ -117,6 +121,7 @@ github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18h
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s=
github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
github.com/golang/protobuf v0.0.0-20180814211427-aa810b61a9c7/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
......@@ -343,6 +348,7 @@ github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+Gx
github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFBS8=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/protocolbuffers/protobuf v3.14.0+incompatible h1:8r0H76h/Q/lEnFFY60AuM23NOnaDMi6bd7zuboSYM+o=
github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446 h1:/NRJ5vAYoqz+7sG51ubIDHXeWO8DlTSrToPu6q11ziA=
github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446/go.mod h1:uYEyJGbgTkfkS4+E/PavXkNJcbFIpEtjt2B0KDQ5+9M=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
......@@ -403,6 +409,12 @@ github.com/tikv/client-go v0.0.0-20200824032810-95774393107b/go.mod h1:K0NcdVNrX
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5 h1:LnC5Kc/wtumK+WB441p7ynQJzVuNRJiqddSIE3IlSEQ=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/uber/jaeger-client-go v1.6.0 h1:3+zLlq+4npI5fg8IsgAje3YsP7TcEdNzJScyqFIzxEQ=
github.com/uber/jaeger-client-go v2.25.0+incompatible h1:IxcNZ7WRY1Y3G4poYlx24szfsn/3LvK9QHCq9oQw8+U=
github.com/uber/jaeger-client-go v2.25.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk=
github.com/uber/jaeger-lib v1.5.0 h1:OHbgr8l656Ub3Fw5k9SWnBfIEwvoHQ+W2y+Aa9D1Uyo=
github.com/uber/jaeger-lib v2.4.0+incompatible h1:fY7QsGQWiCt8pajv4r7JEvmATdCVaWxXbjwyYwsNaLQ=
github.com/uber/jaeger-lib v2.4.0+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U=
github.com/ugorji/go v1.1.2/go.mod h1:hnLbHMwcvSihnDhEfx2/BzKp2xb0Y+ErdfYcrs9tkJQ=
github.com/ugorji/go/codec v0.0.0-20190204201341-e444a5086c43/go.mod h1:iT03XoTwV7xq/+UGwKO3UbC1nNNlopQiY61beSdrtOA=
github.com/unrolled/render v1.0.0 h1:XYtvhA3UkpB7PqkvhUFYmpKD55OudoIeygcfus4vcd4=
......
......@@ -18,7 +18,7 @@
namespace milvus {
inline int
field_sizeof(DataType data_type, int dim = 1) {
datatype_sizeof(DataType data_type, int dim = 1) {
switch (data_type) {
case DataType::BOOL:
return sizeof(bool);
......@@ -78,7 +78,7 @@ datatype_name(DataType data_type) {
}
inline bool
field_is_vector(DataType datatype) {
datatype_is_vector(DataType datatype) {
return datatype == DataType::VECTOR_BINARY || datatype == DataType::VECTOR_FLOAT;
}
......@@ -119,9 +119,9 @@ struct FieldMeta {
int
get_sizeof() const {
if (is_vector()) {
return field_sizeof(type_, get_dim());
return datatype_sizeof(type_, get_dim());
} else {
return field_sizeof(type_, 1);
return datatype_sizeof(type_, 1);
}
}
......
......@@ -50,7 +50,7 @@ Schema::ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto) {
schema->primary_key_offset_opt_ = schema->size();
}
if (field_is_vector(data_type)) {
if (datatype_is_vector(data_type)) {
auto type_map = RepeatedKeyValToMap(child.type_params());
auto index_map = RepeatedKeyValToMap(child.index_params());
if (!index_map.count("metric_type")) {
......
......@@ -14,13 +14,15 @@
#include <faiss/MetricType.h>
#include <string>
#include <boost/align/aligned_allocator.hpp>
#include <memory>
#include <vector>
namespace milvus {
using Timestamp = uint64_t; // TODO: use TiKV-like timestamp
using engine::DataType;
using engine::FieldElementType;
using engine::QueryResult;
using engine::idx_t;
using MetricType = faiss::MetricType;
MetricType
......@@ -39,4 +41,33 @@ constexpr std::false_type always_false{};
template <typename T>
using aligned_vector = std::vector<T, boost::alignment::aligned_allocator<T, 512>>;
///////////////////////////////////////////////////////////////////////////////////////////////////
struct QueryResult {
QueryResult() = default;
QueryResult(uint64_t num_queries, uint64_t topK) : topK_(topK), num_queries_(num_queries) {
auto count = get_row_count();
result_distances_.resize(count);
internal_seg_offsets_.resize(count);
}
[[nodiscard]] uint64_t
get_row_count() const {
return topK_ * num_queries_;
}
public:
uint64_t num_queries_;
uint64_t topK_;
uint64_t seg_id_;
std::vector<float> result_distances_;
public:
// TODO(gexi): utilize these field
std::vector<int64_t> internal_seg_offsets_;
std::vector<int64_t> result_offsets_;
std::vector<std::vector<char>> row_data_;
};
using QueryResultPtr = std::shared_ptr<QueryResult>;
} // namespace milvus
......@@ -55,6 +55,7 @@ IndexWrapper::parse_impl(const std::string& serialized_params_str, knowhere::Con
}
auto stoi_closure = [](const std::string& s) -> int { return std::stoi(s); };
auto stof_closure = [](const std::string& s) -> int { return std::stof(s); };
/***************************** meta *******************************/
check_parameter<int>(conf, milvus::knowhere::meta::DIM, stoi_closure, std::nullopt);
......@@ -88,7 +89,7 @@ IndexWrapper::parse_impl(const std::string& serialized_params_str, knowhere::Con
check_parameter<int>(conf, milvus::knowhere::IndexParams::edge_size, stoi_closure, std::nullopt);
/************************** NGT Search Params *****************************/
check_parameter<int>(conf, milvus::knowhere::IndexParams::epsilon, stoi_closure, std::nullopt);
check_parameter<float>(conf, milvus::knowhere::IndexParams::epsilon, stof_closure, std::nullopt);
check_parameter<int>(conf, milvus::knowhere::IndexParams::max_search_edges, stoi_closure, std::nullopt);
/************************** NGT_PANNG Params *****************************/
......@@ -274,6 +275,12 @@ IndexWrapper::QueryWithParam(const knowhere::DatasetPtr& dataset, const char* se
std::unique_ptr<IndexWrapper::QueryResult>
IndexWrapper::QueryImpl(const knowhere::DatasetPtr& dataset, const knowhere::Config& conf) {
auto load_raw_data_closure = [&]() { LoadRawData(); }; // hide this pointer
auto index_type = get_index_type();
if (is_in_nm_list(index_type)) {
std::call_once(raw_data_loaded_, load_raw_data_closure);
}
auto res = index_->Query(dataset, conf, nullptr);
auto ids = res->Get<int64_t*>(milvus::knowhere::meta::IDS);
auto distances = res->Get<float*>(milvus::knowhere::meta::DISTANCE);
......@@ -291,5 +298,19 @@ IndexWrapper::QueryImpl(const knowhere::DatasetPtr& dataset, const knowhere::Con
return std::move(query_res);
}
void
IndexWrapper::LoadRawData() {
auto index_type = get_index_type();
if (is_in_nm_list(index_type)) {
auto bs = index_->Serialize(config_);
auto bptr = std::make_shared<milvus::knowhere::Binary>();
auto deleter = [&](uint8_t*) {}; // avoid repeated deconstruction
bptr->data = std::shared_ptr<uint8_t[]>(static_cast<uint8_t*>(raw_data_.data()), deleter);
bptr->size = raw_data_.size();
bs.Append(RAW_DATA, bptr);
index_->Load(bs);
}
}
} // namespace indexbuilder
} // namespace milvus
......@@ -66,6 +66,9 @@ class IndexWrapper {
void
StoreRawData(const knowhere::DatasetPtr& dataset);
void
LoadRawData();
template <typename T>
void
check_parameter(knowhere::Config& conf,
......@@ -92,6 +95,7 @@ class IndexWrapper {
milvus::json index_config_;
knowhere::Config config_;
std::vector<uint8_t> raw_data_;
std::once_flag raw_data_loaded_;
};
} // namespace indexbuilder
......
......@@ -4,13 +4,16 @@ set(MILVUS_QUERY_SRCS
generated/PlanNode.cpp
generated/Expr.cpp
visitors/ShowPlanNodeVisitor.cpp
visitors/ExecPlanNodeVisitor.cpp
visitors/ShowExprVisitor.cpp
visitors/ExecPlanNodeVisitor.cpp
visitors/ExecExprVisitor.cpp
visitors/VerifyPlanNodeVisitor.cpp
visitors/VerifyExprVisitor.cpp
Plan.cpp
Search.cpp
SearchOnSealed.cpp
BruteForceSearch.cpp
SearchBruteForce.cpp
SubQueryResult.cpp
)
add_library(milvus_query ${MILVUS_QUERY_SRCS})
target_link_libraries(milvus_query milvus_proto milvus_utils)
target_link_libraries(milvus_query milvus_proto milvus_utils knowhere)
......@@ -21,6 +21,7 @@
#include <boost/align/aligned_allocator.hpp>
#include <boost/algorithm/string.hpp>
#include <algorithm>
#include "query/generated/VerifyPlanNodeVisitor.h"
namespace milvus::query {
......@@ -106,7 +107,7 @@ Parser::ParseRangeNode(const Json& out_body) {
auto field_name = out_iter.key();
auto body = out_iter.value();
auto data_type = schema[field_name].get_data_type();
Assert(!field_is_vector(data_type));
Assert(!datatype_is_vector(data_type));
switch (data_type) {
case DataType::BOOL:
......@@ -138,6 +139,8 @@ Parser::CreatePlanImpl(const std::string& dsl_str) {
if (predicate != nullptr) {
vec_node->predicate_ = std::move(predicate);
}
VerifyPlanNodeVisitor verifier;
vec_node->accept(verifier);
auto plan = std::make_unique<Plan>(schema);
plan->tag2field_ = std::move(tag2field_);
......@@ -152,7 +155,7 @@ Parser::ParseTermNode(const Json& out_body) {
auto field_name = out_iter.key();
auto body = out_iter.value();
auto data_type = schema[field_name].get_data_type();
Assert(!field_is_vector(data_type));
Assert(!datatype_is_vector(data_type));
switch (data_type) {
case DataType::BOOL: {
return ParseTermNodeImpl<bool>(field_name, body);
......
......@@ -16,7 +16,7 @@
#include <faiss/utils/distances.h>
#include "utils/tools.h"
#include "query/BruteForceSearch.h"
#include "query/SearchBruteForce.h"
namespace milvus::query {
......@@ -34,13 +34,13 @@ create_bitmap_view(std::optional<const BitmapSimple*> bitmaps_opt, int64_t chunk
}
Status
QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
const query::QueryInfo& info,
const float* query_data,
int64_t num_queries,
Timestamp timestamp,
std::optional<const BitmapSimple*> bitmaps_opt,
QueryResult& results) {
FloatSearch(const segcore::SegmentSmallIndex& segment,
const query::QueryInfo& info,
const float* query_data,
int64_t num_queries,
Timestamp timestamp,
std::optional<const BitmapSimple*> bitmaps_opt,
QueryResult& results) {
auto& schema = segment.get_schema();
auto& indexing_record = segment.get_indexing_record();
auto& record = segment.get_insert_record();
......@@ -75,6 +75,7 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
const auto& indexing_entry = indexing_record.get_vec_entry(vecfield_offset);
auto search_conf = indexing_entry.get_search_conf(topK);
// TODO: use sub_qr
for (int chunk_id = 0; chunk_id < max_indexed_id; ++chunk_id) {
auto indexing = indexing_entry.get_vec_indexing(chunk_id);
auto dataset = knowhere::GenDataset(num_queries, dim, query_data);
......@@ -99,10 +100,12 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
Assert(vec_chunk_size == indexing_entry.get_chunk_size());
auto max_chunk = upper_div(ins_barrier, vec_chunk_size);
// TODO: use sub_qr
for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) {
std::vector<int64_t> buf_uids(total_count, -1);
std::vector<float> buf_dis(total_count, std::numeric_limits<float>::max());
// should be not visitable
faiss::float_maxheap_array_t buf = {(size_t)num_queries, (size_t)topK, buf_uids.data(), buf_dis.data()};
auto& chunk = vec_ptr->get_chunk(chunk_id);
......@@ -112,6 +115,7 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
auto nsize = element_end - element_begin;
auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id);
// TODO: make it wrapped
faiss::knn_L2sqr(query_data, chunk.data(), dim, num_queries, nsize, &buf, bitmap_view);
Assert(buf_uids.size() == total_count);
......@@ -134,13 +138,13 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
}
Status
BinaryQueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
const query::QueryInfo& info,
const uint8_t* query_data,
int64_t num_queries,
Timestamp timestamp,
std::optional<const BitmapSimple*> bitmaps_opt,
QueryResult& results) {
BinarySearch(const segcore::SegmentSmallIndex& segment,
const query::QueryInfo& info,
const uint8_t* query_data,
int64_t num_queries,
Timestamp timestamp,
std::optional<const BitmapSimple*> bitmaps_opt,
QueryResult& results) {
auto& schema = segment.get_schema();
auto& indexing_record = segment.get_indexing_record();
auto& record = segment.get_insert_record();
......@@ -169,8 +173,8 @@ BinaryQueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
auto total_count = topK * num_queries;
// step 3: small indexing search
std::vector<int64_t> final_uids(total_count, -1);
std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
// TODO: this is too intrusive
// TODO: use QuerySubResult instead
query::dataset::BinaryQueryDataset query_dataset{metric_type, num_queries, topK, code_size, query_data};
using segcore::BinaryVector;
......@@ -181,30 +185,27 @@ BinaryQueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
auto vec_chunk_size = vec_ptr->get_chunk_size();
auto max_chunk = upper_div(ins_barrier, vec_chunk_size);
SubQueryResult final_result(num_queries, topK, metric_type);
for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) {
std::vector<int64_t> buf_uids(total_count, -1);
std::vector<float> buf_dis(total_count, std::numeric_limits<float>::max());
auto& chunk = vec_ptr->get_chunk(chunk_id);
auto element_begin = chunk_id * vec_chunk_size;
auto element_end = std::min(ins_barrier, (chunk_id + 1) * vec_chunk_size);
auto nsize = element_end - element_begin;
auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id);
BinarySearchBruteForce(query_dataset, chunk.data(), nsize, buf_dis.data(), buf_uids.data(), bitmap_view);
auto sub_result = BinarySearchBruteForce(query_dataset, chunk.data(), nsize, bitmap_view);
// convert chunk uid to segment uid
for (auto& x : buf_uids) {
for (auto& x : sub_result.mutable_labels()) {
if (x != -1) {
x += chunk_id * vec_chunk_size;
}
}
segcore::merge_into(num_queries, topK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data());
final_result.merge(sub_result);
}
results.result_distances_ = std::move(final_dis);
results.internal_seg_offsets_ = std::move(final_uids);
results.result_distances_ = std::move(final_result.mutable_values());
results.internal_seg_offsets_ = std::move(final_result.mutable_labels());
results.topK_ = topK;
results.num_queries_ = num_queries;
......
......@@ -14,27 +14,29 @@
#include "segcore/SegmentSmallIndex.h"
#include <deque>
#include <boost/dynamic_bitset.hpp>
#include "query/SubQueryResult.h"
namespace milvus::query {
using BitmapChunk = boost::dynamic_bitset<>;
using BitmapSimple = std::deque<BitmapChunk>;
// TODO: merge these two search into one
// note: c++17 don't support optional ref
Status
QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
const QueryInfo& info,
const float* query_data,
int64_t num_queries,
Timestamp timestamp,
std::optional<const BitmapSimple*> bitmap_opt,
QueryResult& results);
FloatSearch(const segcore::SegmentSmallIndex& segment,
const QueryInfo& info,
const float* query_data,
int64_t num_queries,
Timestamp timestamp,
std::optional<const BitmapSimple*> bitmap_opt,
QueryResult& results);
Status
BinaryQueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
const query::QueryInfo& info,
const uint8_t* query_data,
int64_t num_queries,
Timestamp timestamp,
std::optional<const BitmapSimple*> bitmaps_opt,
QueryResult& results);
BinarySearch(const segcore::SegmentSmallIndex& segment,
const query::QueryInfo& info,
const uint8_t* query_data,
int64_t num_queries,
Timestamp timestamp,
std::optional<const BitmapSimple*> bitmaps_opt,
QueryResult& results);
} // namespace milvus::query
......@@ -9,58 +9,16 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "BruteForceSearch.h"
#include "SearchBruteForce.h"
#include <vector>
#include <common/Types.h>
#include <boost/dynamic_bitset.hpp>
#include <queue>
#include "SubQueryResult.h"
namespace milvus::query {
// Naive reference implementation of binary brute-force top-k search.
// For every query code it scans the whole chunk, keeps the topk smallest
// Jaccard distances in a bounded max-heap, and writes them out in ascending
// distance order.
//
// binary_chunk:     chunk_size binary codes, code_size bytes each.
// query_data:       num_queries binary codes, code_size bytes each.
// result_distances: row-major [num_queries x topk] output distances.
// result_labels:    row-major [num_queries x topk] output row ids in chunk.
// bitset:           accepted but never read in this naive path —
//                   NOTE(review): deleted-row filtering is NOT applied here;
//                   confirm this function is only used as a reference/test.
void
BinarySearchBruteForceNaive(MetricType metric_type,
                            int64_t code_size,
                            const uint8_t* binary_chunk,
                            int64_t chunk_size,
                            int64_t topk,
                            int64_t num_queries,
                            const uint8_t* query_data,
                            float* result_distances,
                            idx_t* result_labels,
                            faiss::ConcurrentBitsetPtr bitset) {
    // THIS IS A NAIVE IMPLEMENTATION, ready for optimize
    // Only Jaccard is supported, and codes must be a whole number of words.
    Assert(metric_type == faiss::METRIC_Jaccard);
    Assert(code_size % 4 == 0);
    // (distance, chunk-row-id) pairs ordered by distance first.
    using T = std::tuple<float, int>;
    for (int64_t q = 0; q < num_queries; ++q) {
        auto query_ptr = query_data + code_size * q;
        auto query = boost::dynamic_bitset(query_ptr, query_ptr + code_size);
        // topk+1 slots: slots [0, topk) form a max-heap of the best-so-far,
        // slot topk is a spare used to push a candidate / pop the worst.
        // All-equal sentinel values (FLT_MAX, -1) are trivially a valid heap.
        std::vector<T> max_heap(topk + 1, std::make_tuple(std::numeric_limits<float>::max(), -1));
        for (int64_t i = 0; i < chunk_size; ++i) {
            auto element_ptr = binary_chunk + code_size * i;
            auto element = boost::dynamic_bitset(element_ptr, element_ptr + code_size);
            auto the_and = (query & element).count();
            auto the_or = (query | element).count();
            // Jaccard distance = (|A∪B| - |A∩B|) / |A∪B|; defined as 0 when
            // both codes are all-zero (the_or == 0).
            auto distance = the_or ? (float)(the_and) / the_or : 0;
            if (distance < std::get<0>(max_heap[0])) {
                // Candidate beats the current worst: push it via the spare
                // slot, then pop the heap max back into the spare slot so
                // [0, topk) again holds the topk best.
                max_heap[topk] = std::make_tuple(distance, i);
                std::push_heap(max_heap.begin(), max_heap.end());
                std::pop_heap(max_heap.begin(), max_heap.end());
            }
        }
        // Ascending sort over all topk+1 entries; only the first topk (the
        // best ones) are emitted below.
        std::sort(max_heap.begin(), max_heap.end());
        for (int k = 0; k < topk; ++k) {
            auto info = max_heap[k];
            result_distances[k + q * topk] = std::get<0>(info);
            result_labels[k + q * topk] = std::get<1>(info);
        }
    }
}
void
SubQueryResult
BinarySearchBruteForceFast(MetricType metric_type,
int64_t code_size,
const uint8_t* binary_chunk,
......@@ -68,9 +26,11 @@ BinarySearchBruteForceFast(MetricType metric_type,
int64_t topk,
int64_t num_queries,
const uint8_t* query_data,
float* result_distances,
idx_t* result_labels,
faiss::ConcurrentBitsetPtr bitset) {
const faiss::BitsetView& bitset) {
SubQueryResult sub_result(num_queries, topk, metric_type);
float* result_distances = sub_result.get_values();
idx_t* result_labels = sub_result.get_labels();
const idx_t block_size = chunk_size;
bool use_heap = true;
......@@ -132,18 +92,26 @@ BinarySearchBruteForceFast(MetricType metric_type,
} else {
PanicInfo("Unsupported metric type");
}
return sub_result;
}
// Placeholder for the float-vector brute-force search over one chunk;
// the body is not implemented yet.
// NOTE(review): the signature carries no query data, query count, or topk —
// presumably those parameters will be added when the TODO is filled in;
// confirm before wiring up callers.
void
FloatSearchBruteForceFast(MetricType metric_type,
                          const float* chunk_data,
                          int64_t chunk_size,
                          float* result_distances,
                          idx_t* result_labels,
                          const faiss::BitsetView& bitset) {
    // TODO
}
// Public entry point for binary brute-force search over one chunk: unpacks
// the query dataset and forwards to BinarySearchBruteForceFast, returning
// the per-query top-k as a SubQueryResult (labels + values owned by it).
// NOTE(review): the scraped diff had the removed out-parameter lines
// (result_distances / result_labels and the old void-call body) interleaved
// with the new ones, which is not valid C++; this keeps only the post-image.
SubQueryResult
BinarySearchBruteForce(const dataset::BinaryQueryDataset& query_dataset,
                       const uint8_t* binary_chunk,
                       int64_t chunk_size,
                       const faiss::BitsetView& bitset) {
    // TODO: refactor the internal function
    return BinarySearchBruteForceFast(query_dataset.metric_type, query_dataset.code_size, binary_chunk, chunk_size,
                                      query_dataset.topk, query_dataset.num_queries, query_dataset.query_data, bitset);
}
} // namespace milvus::query
......@@ -13,6 +13,7 @@
#include <faiss/utils/BinaryDistance.h>
#include "segcore/ConcurrentVector.h"
#include "common/Schema.h"
#include "query/SubQueryResult.h"
namespace milvus::query {
using MetricType = faiss::MetricType;
......@@ -28,12 +29,10 @@ struct BinaryQueryDataset {
} // namespace dataset
void
SubQueryResult
BinarySearchBruteForce(const dataset::BinaryQueryDataset& query_dataset,
const uint8_t* binary_chunk,
int64_t chunk_size,
float* result_distances,
idx_t* result_labels,
faiss::ConcurrentBitsetPtr bitset = nullptr);
const faiss::BitsetView& bitset = nullptr);
} // namespace milvus::query
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "utils/EasyAssert.h"
#include "query/SubQueryResult.h"
#include "segcore/Reduce.h"
namespace milvus::query {
// Row-by-row two-pointer merge of another sorted top-k result into this one.
// Each query's topk entries in both operands must already be sorted in the
// metric's rank order (descending when is_desc, ascending otherwise); the
// best topk of the combined 2*topk candidates are kept, in the same order.
template <bool is_desc>
void
SubQueryResult::merge_impl(const SubQueryResult& right) {
    Assert(num_queries_ == right.num_queries_);
    Assert(topk_ == right.topk_);
    Assert(metric_type_ == right.metric_type_);
    Assert(is_desc == is_descending(metric_type_));

    // Scratch buffers reused across queries; results are copied back into
    // this object's arrays at the end of each row.
    std::vector<float> merged_values(topk_);
    std::vector<int64_t> merged_labels(topk_);
    for (int64_t query = 0; query < num_queries_; ++query) {
        const auto base = query * topk_;
        int64_t* __restrict__ dst_labels = this->get_labels() + base;
        float* __restrict__ dst_values = this->get_values() + base;
        const auto* src_labels = right.get_labels() + base;
        const auto* src_values = right.get_values() + base;

        int64_t li = 0;  // cursor into this (left) row
        int64_t ri = 0;  // cursor into right row
        for (int64_t out = 0; out < topk_; ++out) {
            const auto left_value = dst_values[li];
            const auto right_value = src_values[ri];
            // The comparison direction is resolved at compile time.
            const bool take_left = is_desc ? (left_value >= right_value) : (left_value <= right_value);
            if (take_left) {
                merged_values[out] = left_value;
                merged_labels[out] = dst_labels[li];
                ++li;
            } else {
                merged_values[out] = right_value;
                merged_labels[out] = src_labels[ri];
                ++ri;
            }
        }
        std::copy_n(merged_values.data(), topk_, dst_values);
        std::copy_n(merged_labels.data(), topk_, dst_labels);
    }
}
// Folds sub_result into this result, dispatching to the merge_impl
// instantiation whose comparison direction matches the metric.
void
SubQueryResult::merge(const SubQueryResult& sub_result) {
    Assert(metric_type_ == sub_result.metric_type_);
    if (!is_descending(metric_type_)) {
        this->merge_impl<false>(sub_result);
        return;
    }
    this->merge_impl<true>(sub_result);
}
// Non-mutating merge: copies left, folds right into the copy, returns it.
SubQueryResult
SubQueryResult::merge(const SubQueryResult& left, const SubQueryResult& right) {
    SubQueryResult combined = left;
    combined.merge(right);
    return combined;
}
} // namespace milvus::query
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "common/Types.h"
#include <limits>
#include <vector>
namespace milvus::query {
// Per-chunk search result holder: the top-k (label, value) pairs for each of
// num_queries queries, stored in two flat row-major [num_queries x topk]
// arrays. Partial results from several chunks are combined with merge().
//
// Fix: the constructor's member-initializer list is reordered to match the
// member declaration order (num_queries_, topk_, metric_type_); the original
// order triggered -Wreorder. Behavior is unchanged — the initializers read
// only constructor parameters, never other members.
class SubQueryResult {
 public:
    // Pre-fills labels with -1 ("no hit") and values with the metric's worst
    // possible score, so any genuine candidate replaces the filler on merge.
    SubQueryResult(int64_t num_queries, int64_t topk, MetricType metric_type)
        : num_queries_(num_queries),
          topk_(topk),
          metric_type_(metric_type),
          labels_(num_queries * topk, -1),
          values_(num_queries * topk, init_value(metric_type)) {
    }

 public:
    // Worst score for the metric: -FLT_MAX when larger-is-better,
    // +FLT_MAX when smaller-is-better.
    static constexpr float
    init_value(MetricType metric_type) {
        return (is_descending(metric_type) ? -1 : 1) * std::numeric_limits<float>::max();
    }

    // True when larger values rank better (inner product only, so far).
    static constexpr bool
    is_descending(MetricType metric_type) {
        // TODO: cover the remaining larger-is-better metrics
        if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
            return true;
        } else {
            return false;
        }
    }

 public:
    int64_t
    get_num_queries() const {
        return num_queries_;
    }

    int64_t
    get_topk() const {
        return topk_;
    }

    const int64_t*
    get_labels() const {
        return labels_.data();
    }

    int64_t*
    get_labels() {
        return labels_.data();
    }

    const float*
    get_values() const {
        return values_.data();
    }

    float*
    get_values() {
        return values_.data();
    }

    auto&
    mutable_labels() {
        return labels_;
    }

    auto&
    mutable_values() {
        return values_;
    }

    // Returns a new result equal to left merged with right.
    static SubQueryResult
    merge(const SubQueryResult& left, const SubQueryResult& right);

    // Merges sub_result into *this; both sides must agree on
    // num_queries/topk/metric_type and hold rank-sorted rows.
    void
    merge(const SubQueryResult& sub_result);

 private:
    template <bool is_desc>
    void
    merge_impl(const SubQueryResult& sub_result);

 private:
    int64_t num_queries_;
    int64_t topk_;
    MetricType metric_type_;
    std::vector<int64_t> labels_;
    std::vector<float> values_;
};
} // namespace milvus::query
......@@ -21,7 +21,7 @@
#include "ExprVisitor.h"
namespace milvus::query {
class ExecExprVisitor : ExprVisitor {
class ExecExprVisitor : public ExprVisitor {
public:
void
visit(BoolUnaryExpr& expr) override;
......
......@@ -19,7 +19,7 @@
#include "PlanNodeVisitor.h"
namespace milvus::query {
class ExecPlanNodeVisitor : PlanNodeVisitor {
class ExecPlanNodeVisitor : public PlanNodeVisitor {
public:
void
visit(FloatVectorANNS& node) override;
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment