diff --git a/.devcontainer.json b/.devcontainer.json index fff529de85e22edb4389a9a156e547576f704c11..565eb008c71d9b8ab7b335f0e44061023ae5c2e2 100644 --- a/.devcontainer.json +++ b/.devcontainer.json @@ -2,7 +2,7 @@ "name": "Milvus Distributed Dev Container Definition", "dockerComposeFile": ["./docker-compose-vscode.yml"], "service": "ubuntu", - "initializeCommand": "scripts/init_devcontainer.sh && docker-compose -f docker-compose-vscode.yml down || true && docker-compose -f docker-compose-vscode.yml pull --ignore-pull-failures ubuntu", + "initializeCommand": "scripts/init_devcontainer.sh && docker-compose -f docker-compose-vscode.yml down all -v || true && docker-compose -f docker-compose-vscode.yml pull --ignore-pull-failures ubuntu", "workspaceFolder": "/go/src/github.com/zilliztech/milvus-distributed", "shutdownAction": "stopCompose", "extensions": [ diff --git a/Makefile b/Makefile index 48f90b3fbfa858053f5357abc6723f3e0746fcf3..1ae87e6930a5061ce356517b24aa7e61d7032e58 100644 --- a/Makefile +++ b/Makefile @@ -127,7 +127,6 @@ install: all @mkdir -p $(GOPATH)/bin && cp -f $(PWD)/bin/master $(GOPATH)/bin/master @mkdir -p $(GOPATH)/bin && cp -f $(PWD)/bin/proxy $(GOPATH)/bin/proxy @mkdir -p $(GOPATH)/bin && cp -f $(PWD)/bin/writenode $(GOPATH)/bin/writenode - @mkdir -p $(GOPATH)/bin && cp -f $(PWD)/bin/indexbuilder $(GOPATH)/bin/indexbuilder @mkdir -p $(LIBRARY_PATH) && cp -f $(PWD)/internal/core/output/lib/* $(LIBRARY_PATH) @echo "Installation successful." @@ -135,10 +134,7 @@ clean: @echo "Cleaning up all the generated files" @find . -name '*.test' | xargs rm -fv @find . -name '*~' | xargs rm -fv - @rm -rf bin/ - @rm -rf lib/ - @rm -rf $(GOPATH)/bin/master - @rm -rf $(GOPATH)/bin/proxy - @rm -rf $(GOPATH)/bin/querynode - @rm -rf $(GOPATH)/bin/writenode - @rm -rf $(GOPATH)/bin/indexbuilder + @rm -rvf querynode + @rm -rvf master + @rm -rvf proxy + @rm -rvf writenode diff --git a/deployments/docker/docker-compose.yml b/deployments/docker/docker-compose.yml index 0ae708a19ecb9ababafe5fcdb6bd5f9d5eac529e..60bf5d9fff1f2a0aa47274becc742f778aaeb77a 100644 --- a/deployments/docker/docker-compose.yml +++ b/deployments/docker/docker-compose.yml @@ -36,14 +36,6 @@ services: networks: - milvus - jaeger: - image: jaegertracing/all-in-one:latest - ports: - - "6831:6831/udp" - - "16686:16686" - networks: - - milvus - networks: milvus: diff --git a/docker-compose.yml b/docker-compose.yml index 4f7f32bcd7030a3d7826f101f67095e26c3e55a7..d78e40515124ab50ed6fac5d1958d2a1bd8cfbb7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -86,10 +86,5 @@ services: networks: - milvus - jaeger: - image: jaegertracing/all-in-one:latest - networks: - - milvus - networks: milvus: diff --git a/go.mod b/go.mod index bb426c8ba01c5a6bcfdd5778e54e5f1aba3f7a79..47afde2bff278667721c394a64396793788e92f2 100644 --- a/go.mod +++ b/go.mod @@ -4,17 +4,14 @@ go 1.15 require ( code.cloudfoundry.org/bytefmt v0.0.0-20200131002437-cf55d5288a48 // indirect - github.com/HdrHistogram/hdrhistogram-go v1.0.1 // indirect github.com/apache/pulsar-client-go v0.1.1 - github.com/apache/thrift v0.13.0 - github.com/aws/aws-sdk-go v1.30.8 // indirect + github.com/aws/aws-sdk-go v1.30.8 github.com/coreos/etcd v3.3.25+incompatible // indirect - github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 // indirect + github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 github.com/frankban/quicktest v1.10.2 // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/git-hooks/git-hooks v1.3.1 // indirect 
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect - github.com/golang/mock v1.3.1 github.com/golang/protobuf v1.3.2 github.com/google/btree v1.0.0 github.com/klauspost/compress v1.10.11 // indirect @@ -23,12 +20,12 @@ require ( github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/onsi/ginkgo v1.12.1 // indirect github.com/onsi/gomega v1.10.0 // indirect - github.com/opentracing/opentracing-go v1.2.0 + github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pierrec/lz4 v2.5.2+incompatible // indirect github.com/pingcap/check v0.0.0-20200212061837-5e12011dc712 // indirect github.com/pingcap/errors v0.11.4 // indirect github.com/pingcap/log v0.0.0-20200828042413-fce0951f1463 // indirect - github.com/pivotal-golang/bytefmt v0.0.0-20200131002437-cf55d5288a48 // indirect + github.com/pivotal-golang/bytefmt v0.0.0-20200131002437-cf55d5288a48 github.com/prometheus/client_golang v1.5.1 // indirect github.com/prometheus/common v0.10.0 // indirect github.com/prometheus/procfs v0.1.3 // indirect @@ -38,9 +35,7 @@ require ( github.com/spf13/cast v1.3.0 github.com/spf13/viper v1.7.1 github.com/stretchr/testify v1.6.1 - github.com/tikv/client-go v0.0.0-20200824032810-95774393107b // indirect - github.com/uber/jaeger-client-go v2.25.0+incompatible - github.com/uber/jaeger-lib v2.4.0+incompatible // indirect + github.com/tikv/client-go v0.0.0-20200824032810-95774393107b github.com/urfave/cli v1.22.5 // indirect github.com/yahoo/athenz v1.9.16 // indirect go.etcd.io/etcd v0.5.0-alpha.5.0.20191023171146-3cf2f69b5738 @@ -55,7 +50,7 @@ require ( google.golang.org/genproto v0.0.0-20200122232147-0452cf42e150 // indirect google.golang.org/grpc v1.31.0 gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect - gopkg.in/yaml.v2 v2.3.0 // indirect + gopkg.in/yaml.v2 v2.3.0 honnef.co/go/tools v0.0.1-2020.1.4 // indirect sigs.k8s.io/yaml v1.2.0 // indirect ) diff --git a/go.sum b/go.sum index 21b39b35f6c8665605565fadd7974b4434170bcc..eb4ef6b6a4059c712a42b92603ceb5d687117066 100644 --- a/go.sum +++ b/go.sum @@ -15,8 +15,6 @@ dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7 github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/HdrHistogram/hdrhistogram-go v1.0.1 h1:GX8GAYDuhlFQnI2fRDHQhTlkHMz8bEn0jTI6LJU0mpw= -github.com/HdrHistogram/hdrhistogram-go v1.0.1/go.mod h1:BWJ+nMSHY3L41Zj7CA3uXnloDp7xxV0YvstAE7nKTaM= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM= @@ -26,8 +24,6 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4 h1:Hs82Z41s6SdL1C github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/apache/pulsar-client-go v0.1.1 h1:v/kU+2ZCC6yFIcbZrFtWa9/nvVzVr18L+xYJUvZSxEQ= github.com/apache/pulsar-client-go v0.1.1/go.mod h1:mlxC65KL1BLhGO2bnT9zWMttVzR2czVPb27D477YpyU= -github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI= -github.com/apache/thrift v0.13.0/go.mod 
h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/ardielle/ardielle-go v1.5.2 h1:TilHTpHIQJ27R1Tl/iITBzMwiUGSlVfiVhwDNGM3Zj4= github.com/ardielle/ardielle-go v1.5.2/go.mod h1:I4hy1n795cUhaVt/ojz83SNVCYIGsAFAONtv2Dr7HUI= github.com/ardielle/ardielle-tools v1.5.4/go.mod h1:oZN+JRMnqGiIhrzkRN9l26Cej9dEx4jeNG6A+AdkShk= @@ -121,7 +117,6 @@ github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18h github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= github.com/golang/protobuf v0.0.0-20180814211427-aa810b61a9c7/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -173,7 +168,6 @@ github.com/grpc-ecosystem/grpc-gateway v1.8.1/go.mod h1:vNeuVxBJEsws4ogUvrchl83t github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/grpc-ecosystem/grpc-gateway v1.9.5 h1:UImYN5qQ8tuGpGE16ZmjvcTtTw24zw1QAp/SlnNrZhI= github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= -github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645/go.mod h1:6iZfnjpejD4L/4DwD7NryNaJyCQdzwWwH2MWhCA90Kw= github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q= github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -349,7 +343,6 @@ github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+Gx github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFBS8= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= -github.com/protocolbuffers/protobuf v3.14.0+incompatible h1:8r0H76h/Q/lEnFFY60AuM23NOnaDMi6bd7zuboSYM+o= github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446 h1:/NRJ5vAYoqz+7sG51ubIDHXeWO8DlTSrToPu6q11ziA= github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446/go.mod h1:uYEyJGbgTkfkS4+E/PavXkNJcbFIpEtjt2B0KDQ5+9M= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= @@ -410,12 +403,6 @@ github.com/tikv/client-go v0.0.0-20200824032810-95774393107b/go.mod h1:K0NcdVNrX github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5 h1:LnC5Kc/wtumK+WB441p7ynQJzVuNRJiqddSIE3IlSEQ= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/uber/jaeger-client-go v1.6.0 h1:3+zLlq+4npI5fg8IsgAje3YsP7TcEdNzJScyqFIzxEQ= -github.com/uber/jaeger-client-go v2.25.0+incompatible h1:IxcNZ7WRY1Y3G4poYlx24szfsn/3LvK9QHCq9oQw8+U= -github.com/uber/jaeger-client-go v2.25.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= -github.com/uber/jaeger-lib v1.5.0 
h1:OHbgr8l656Ub3Fw5k9SWnBfIEwvoHQ+W2y+Aa9D1Uyo= -github.com/uber/jaeger-lib v2.4.0+incompatible h1:fY7QsGQWiCt8pajv4r7JEvmATdCVaWxXbjwyYwsNaLQ= -github.com/uber/jaeger-lib v2.4.0+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= github.com/ugorji/go v1.1.2/go.mod h1:hnLbHMwcvSihnDhEfx2/BzKp2xb0Y+ErdfYcrs9tkJQ= github.com/ugorji/go/codec v0.0.0-20190204201341-e444a5086c43/go.mod h1:iT03XoTwV7xq/+UGwKO3UbC1nNNlopQiY61beSdrtOA= github.com/unrolled/render v1.0.0 h1:XYtvhA3UkpB7PqkvhUFYmpKD55OudoIeygcfus4vcd4= diff --git a/internal/core/src/indexbuilder/IndexWrapper.cpp b/internal/core/src/indexbuilder/IndexWrapper.cpp index 5d95eabf3d97ba40fcb9ee646b480a9c5a280254..4484d1e119bc4853e5cd49df6396047e48df8e8f 100644 --- a/internal/core/src/indexbuilder/IndexWrapper.cpp +++ b/internal/core/src/indexbuilder/IndexWrapper.cpp @@ -54,8 +54,12 @@ IndexWrapper::parse_impl(const std::string& serialized_params_str, knowhere::Con conf[key] = value; } - auto stoi_closure = [](const std::string& s) -> int { return std::stoi(s); }; - auto stof_closure = [](const std::string& s) -> int { return std::stof(s); }; + auto stoi_closure = [](const std::string& s) -> auto { + return std::stoi(s); + }; + auto stof_closure = [](const std::string& s) -> auto { + return std::stof(s); + }; /***************************** meta *******************************/ check_parameter<int>(conf, milvus::knowhere::meta::DIM, stoi_closure, std::nullopt); diff --git a/internal/core/unittest/test_index_wrapper.cpp b/internal/core/unittest/test_index_wrapper.cpp index bd335951f8053029f720e368d7907cb0d65d451d..d50f11d18a78fc2540cb15fcdbe1322b4c9238ef 100644 --- a/internal/core/unittest/test_index_wrapper.cpp +++ b/internal/core/unittest/test_index_wrapper.cpp @@ -267,10 +267,11 @@ L2(const float* point_a, const float* point_b, int dim) { return dis; } -int hamming_weight(uint8_t n) { - int count=0; - while(n != 0){ - count += n&1; +int +hamming_weight(uint8_t n) { + int count = 0; + while (n != 0) { + count += n & 1; n >>= 1; } return count; diff --git a/internal/indexbuilder/task.go b/internal/indexbuilder/task.go index 0c92b153e3cb2956fdda7000cc2e6fa7e2d3e66b..a7c62df7255c03780b69f716fbb4a10049eeb1a2 100644 --- a/internal/indexbuilder/task.go +++ b/internal/indexbuilder/task.go @@ -223,14 +223,24 @@ func (it *IndexBuildTask) Execute() error { for _, value := range insertData.Data { // TODO: BinaryVectorFieldData - floatVectorFieldData, ok := value.(*storage.FloatVectorFieldData) - if !ok { - return errors.New("we expect FloatVectorFieldData or BinaryVectorFieldData") + floatVectorFieldData, fOk := value.(*storage.FloatVectorFieldData) + if fOk { + err = it.index.BuildFloatVecIndexWithoutIds(floatVectorFieldData.Data) + if err != nil { + return err + } } - err = it.index.BuildFloatVecIndexWithoutIds(floatVectorFieldData.Data) - if err != nil { - return err + binaryVectorFieldData, bOk := value.(*storage.BinaryVectorFieldData) + if bOk { + err = it.index.BuildBinaryVecIndexWithoutIds(binaryVectorFieldData.Data) + if err != nil { + return err + } + } + + if !fOk || !bOk { + return errors.New("we expect FloatVectorFieldData or BinaryVectorFieldData") } indexBlobs, err := it.index.Serialize() diff --git a/internal/master/master.go b/internal/master/master.go index 6ade92eaad31c50a8923bdf37d79ed429df7750b..adb002adf804461a267412e6bd13bacfe16de96f 100644 --- a/internal/master/master.go +++ b/internal/master/master.go @@ -218,7 +218,6 @@ func CreateServer(ctx context.Context) (*Master, error) { m.grpcServer = grpc.NewServer() 
masterpb.RegisterMasterServer(m.grpcServer, m) - return m, nil } diff --git a/internal/master/master_test.go b/internal/master/master_test.go index a605e73aa76127c24918cdd3826fae0d0d186ad8..0a44ed90e886b55d6b7bd5bec9e2d1842041fd2c 100644 --- a/internal/master/master_test.go +++ b/internal/master/master_test.go @@ -110,7 +110,6 @@ func TestMaster(t *testing.T) { conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock()) require.Nil(t, err) - cli := masterpb.NewMasterClient(conn) t.Run("TestConfigTask", func(t *testing.T) { @@ -887,6 +886,12 @@ func TestMaster(t *testing.T) { var k2sMsgstream ms.MsgStream = k2sMs assert.True(t, receiveTimeTickMsg(&k2sMsgstream)) + conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock()) + assert.Nil(t, err) + defer conn.Close() + + cli := masterpb.NewMasterClient(conn) + sch := schemapb.CollectionSchema{ Name: "name" + strconv.FormatUint(rand.Uint64(), 10), Description: "test collection", diff --git a/internal/master/segment_manager.go b/internal/master/segment_manager.go index e7a4b64dce64454033e4cb6d98c8765e8ce25bf6..63ef0a11b6f4ee7aeec92e8e3dbb9c3de6286416 100644 --- a/internal/master/segment_manager.go +++ b/internal/master/segment_manager.go @@ -297,7 +297,8 @@ func (manager *SegmentManagerImpl) syncWriteNodeTimestamp(timeTick Timestamp) er manager.mu.Lock() defer manager.mu.Unlock() for _, status := range manager.collStatus { - for i, segStatus := range status.segments { + for i := 0; i < len(status.segments); i++ { + segStatus := status.segments[i] if !segStatus.closable { closable, err := manager.judgeSegmentClosable(segStatus) if err != nil { @@ -318,6 +319,7 @@ func (manager *SegmentManagerImpl) syncWriteNodeTimestamp(timeTick Timestamp) er continue } status.segments = append(status.segments[:i], status.segments[i+1:]...) 
+ i-- ts, err := manager.globalTSOAllocator() if err != nil { log.Printf("allocate tso error: %s", err.Error()) diff --git a/internal/msgstream/msg.go b/internal/msgstream/msg.go index 3b8f4dacd7ccd5b313e29701852df800fb716ef5..518bcfa7afe56a34cf86ff1464eb309a42560a9c 100644 --- a/internal/msgstream/msg.go +++ b/internal/msgstream/msg.go @@ -1,8 +1,6 @@ package msgstream import ( - "context" - "github.com/golang/protobuf/proto" internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" ) @@ -10,8 +8,6 @@ import ( type MsgType = internalPb.MsgType type TsMsg interface { - GetMsgContext() context.Context - SetMsgContext(context.Context) BeginTs() Timestamp EndTs() Timestamp Type() MsgType @@ -21,7 +17,6 @@ type TsMsg interface { } type BaseMsg struct { - MsgCtx context.Context BeginTimestamp Timestamp EndTimestamp Timestamp HashValues []uint32 @@ -49,14 +44,6 @@ func (it *InsertMsg) Type() MsgType { return it.MsgType } -func (it *InsertMsg) GetMsgContext() context.Context { - return it.MsgCtx -} - -func (it *InsertMsg) SetMsgContext(ctx context.Context) { - it.MsgCtx = ctx -} - func (it *InsertMsg) Marshal(input TsMsg) ([]byte, error) { insertMsg := input.(*InsertMsg) insertRequest := &insertMsg.InsertRequest @@ -101,13 +88,6 @@ func (fl *FlushMsg) Type() MsgType { return fl.GetMsgType() } -func (fl *FlushMsg) GetMsgContext() context.Context { - return fl.MsgCtx -} -func (fl *FlushMsg) SetMsgContext(ctx context.Context) { - fl.MsgCtx = ctx -} - func (fl *FlushMsg) Marshal(input TsMsg) ([]byte, error) { flushMsgTask := input.(*FlushMsg) flushMsg := &flushMsgTask.FlushMsg @@ -141,14 +121,6 @@ func (dt *DeleteMsg) Type() MsgType { return dt.MsgType } -func (dt *DeleteMsg) GetMsgContext() context.Context { - return dt.MsgCtx -} - -func (dt *DeleteMsg) SetMsgContext(ctx context.Context) { - dt.MsgCtx = ctx -} - func (dt *DeleteMsg) Marshal(input TsMsg) ([]byte, error) { deleteTask := input.(*DeleteMsg) deleteRequest := &deleteTask.DeleteRequest @@ -193,14 +165,6 @@ func (st *SearchMsg) Type() MsgType { return st.MsgType } -func (st *SearchMsg) GetMsgContext() context.Context { - return st.MsgCtx -} - -func (st *SearchMsg) SetMsgContext(ctx context.Context) { - st.MsgCtx = ctx -} - func (st *SearchMsg) Marshal(input TsMsg) ([]byte, error) { searchTask := input.(*SearchMsg) searchRequest := &searchTask.SearchRequest @@ -234,14 +198,6 @@ func (srt *SearchResultMsg) Type() MsgType { return srt.MsgType } -func (srt *SearchResultMsg) GetMsgContext() context.Context { - return srt.MsgCtx -} - -func (srt *SearchResultMsg) SetMsgContext(ctx context.Context) { - srt.MsgCtx = ctx -} - func (srt *SearchResultMsg) Marshal(input TsMsg) ([]byte, error) { searchResultTask := input.(*SearchResultMsg) searchResultRequest := &searchResultTask.SearchResult @@ -275,14 +231,6 @@ func (tst *TimeTickMsg) Type() MsgType { return tst.MsgType } -func (tst *TimeTickMsg) GetMsgContext() context.Context { - return tst.MsgCtx -} - -func (tst *TimeTickMsg) SetMsgContext(ctx context.Context) { - tst.MsgCtx = ctx -} - func (tst *TimeTickMsg) Marshal(input TsMsg) ([]byte, error) { timeTickTask := input.(*TimeTickMsg) timeTick := &timeTickTask.TimeTickMsg @@ -316,14 +264,6 @@ func (qs *QueryNodeStatsMsg) Type() MsgType { return qs.MsgType } -func (qs *QueryNodeStatsMsg) GetMsgContext() context.Context { - return qs.MsgCtx -} - -func (qs *QueryNodeStatsMsg) SetMsgContext(ctx context.Context) { - qs.MsgCtx = ctx -} - func (qs *QueryNodeStatsMsg) Marshal(input TsMsg) ([]byte, error) { queryNodeSegStatsTask := 
input.(*QueryNodeStatsMsg) queryNodeSegStats := &queryNodeSegStatsTask.QueryNodeStats @@ -365,14 +305,6 @@ func (cc *CreateCollectionMsg) Type() MsgType { return cc.MsgType } -func (cc *CreateCollectionMsg) GetMsgContext() context.Context { - return cc.MsgCtx -} - -func (cc *CreateCollectionMsg) SetMsgContext(ctx context.Context) { - cc.MsgCtx = ctx -} - func (cc *CreateCollectionMsg) Marshal(input TsMsg) ([]byte, error) { createCollectionMsg := input.(*CreateCollectionMsg) createCollectionRequest := &createCollectionMsg.CreateCollectionRequest @@ -405,13 +337,6 @@ type DropCollectionMsg struct { func (dc *DropCollectionMsg) Type() MsgType { return dc.MsgType } -func (dc *DropCollectionMsg) GetMsgContext() context.Context { - return dc.MsgCtx -} - -func (dc *DropCollectionMsg) SetMsgContext(ctx context.Context) { - dc.MsgCtx = ctx -} func (dc *DropCollectionMsg) Marshal(input TsMsg) ([]byte, error) { dropCollectionMsg := input.(*DropCollectionMsg) @@ -436,18 +361,109 @@ func (dc *DropCollectionMsg) Unmarshal(input []byte) (TsMsg, error) { return dropCollectionMsg, nil } -/////////////////////////////////////////CreatePartition////////////////////////////////////////// -type CreatePartitionMsg struct { +/////////////////////////////////////////HasCollection////////////////////////////////////////// +type HasCollectionMsg struct { BaseMsg - internalPb.CreatePartitionRequest + internalPb.HasCollectionRequest } -func (cc *CreatePartitionMsg) GetMsgContext() context.Context { - return cc.MsgCtx +func (hc *HasCollectionMsg) Type() MsgType { + return hc.MsgType } -func (cc *CreatePartitionMsg) SetMsgContext(ctx context.Context) { - cc.MsgCtx = ctx +func (hc *HasCollectionMsg) Marshal(input TsMsg) ([]byte, error) { + hasCollectionMsg := input.(*HasCollectionMsg) + hasCollectionRequest := &hasCollectionMsg.HasCollectionRequest + mb, err := proto.Marshal(hasCollectionRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (hc *HasCollectionMsg) Unmarshal(input []byte) (TsMsg, error) { + hasCollectionRequest := internalPb.HasCollectionRequest{} + err := proto.Unmarshal(input, &hasCollectionRequest) + if err != nil { + return nil, err + } + hasCollectionMsg := &HasCollectionMsg{HasCollectionRequest: hasCollectionRequest} + hasCollectionMsg.BeginTimestamp = hasCollectionMsg.Timestamp + hasCollectionMsg.EndTimestamp = hasCollectionMsg.Timestamp + + return hasCollectionMsg, nil +} + +/////////////////////////////////////////DescribeCollection////////////////////////////////////////// +type DescribeCollectionMsg struct { + BaseMsg + internalPb.DescribeCollectionRequest +} + +func (dc *DescribeCollectionMsg) Type() MsgType { + return dc.MsgType +} + +func (dc *DescribeCollectionMsg) Marshal(input TsMsg) ([]byte, error) { + describeCollectionMsg := input.(*DescribeCollectionMsg) + describeCollectionRequest := &describeCollectionMsg.DescribeCollectionRequest + mb, err := proto.Marshal(describeCollectionRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (dc *DescribeCollectionMsg) Unmarshal(input []byte) (TsMsg, error) { + describeCollectionRequest := internalPb.DescribeCollectionRequest{} + err := proto.Unmarshal(input, &describeCollectionRequest) + if err != nil { + return nil, err + } + describeCollectionMsg := &DescribeCollectionMsg{DescribeCollectionRequest: describeCollectionRequest} + describeCollectionMsg.BeginTimestamp = describeCollectionMsg.Timestamp + describeCollectionMsg.EndTimestamp = describeCollectionMsg.Timestamp + + return 
describeCollectionMsg, nil +} + +/////////////////////////////////////////ShowCollection////////////////////////////////////////// +type ShowCollectionMsg struct { + BaseMsg + internalPb.ShowCollectionRequest +} + +func (sc *ShowCollectionMsg) Type() MsgType { + return sc.MsgType +} + +func (sc *ShowCollectionMsg) Marshal(input TsMsg) ([]byte, error) { + showCollectionMsg := input.(*ShowCollectionMsg) + showCollectionRequest := &showCollectionMsg.ShowCollectionRequest + mb, err := proto.Marshal(showCollectionRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (sc *ShowCollectionMsg) Unmarshal(input []byte) (TsMsg, error) { + showCollectionRequest := internalPb.ShowCollectionRequest{} + err := proto.Unmarshal(input, &showCollectionRequest) + if err != nil { + return nil, err + } + showCollectionMsg := &ShowCollectionMsg{ShowCollectionRequest: showCollectionRequest} + showCollectionMsg.BeginTimestamp = showCollectionMsg.Timestamp + showCollectionMsg.EndTimestamp = showCollectionMsg.Timestamp + + return showCollectionMsg, nil +} + +/////////////////////////////////////////CreatePartition////////////////////////////////////////// +type CreatePartitionMsg struct { + BaseMsg + internalPb.CreatePartitionRequest } func (cc *CreatePartitionMsg) Type() MsgType { @@ -483,14 +499,6 @@ type DropPartitionMsg struct { internalPb.DropPartitionRequest } -func (dc *DropPartitionMsg) GetMsgContext() context.Context { - return dc.MsgCtx -} - -func (dc *DropPartitionMsg) SetMsgContext(ctx context.Context) { - dc.MsgCtx = ctx -} - func (dc *DropPartitionMsg) Type() MsgType { return dc.MsgType } @@ -518,6 +526,105 @@ func (dc *DropPartitionMsg) Unmarshal(input []byte) (TsMsg, error) { return dropPartitionMsg, nil } +/////////////////////////////////////////HasPartition////////////////////////////////////////// +type HasPartitionMsg struct { + BaseMsg + internalPb.HasPartitionRequest +} + +func (hc *HasPartitionMsg) Type() MsgType { + return hc.MsgType +} + +func (hc *HasPartitionMsg) Marshal(input TsMsg) ([]byte, error) { + hasPartitionMsg := input.(*HasPartitionMsg) + hasPartitionRequest := &hasPartitionMsg.HasPartitionRequest + mb, err := proto.Marshal(hasPartitionRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (hc *HasPartitionMsg) Unmarshal(input []byte) (TsMsg, error) { + hasPartitionRequest := internalPb.HasPartitionRequest{} + err := proto.Unmarshal(input, &hasPartitionRequest) + if err != nil { + return nil, err + } + hasPartitionMsg := &HasPartitionMsg{HasPartitionRequest: hasPartitionRequest} + hasPartitionMsg.BeginTimestamp = hasPartitionMsg.Timestamp + hasPartitionMsg.EndTimestamp = hasPartitionMsg.Timestamp + + return hasPartitionMsg, nil +} + +/////////////////////////////////////////DescribePartition////////////////////////////////////////// +type DescribePartitionMsg struct { + BaseMsg + internalPb.DescribePartitionRequest +} + +func (dc *DescribePartitionMsg) Type() MsgType { + return dc.MsgType +} + +func (dc *DescribePartitionMsg) Marshal(input TsMsg) ([]byte, error) { + describePartitionMsg := input.(*DescribePartitionMsg) + describePartitionRequest := &describePartitionMsg.DescribePartitionRequest + mb, err := proto.Marshal(describePartitionRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (dc *DescribePartitionMsg) Unmarshal(input []byte) (TsMsg, error) { + describePartitionRequest := internalPb.DescribePartitionRequest{} + err := proto.Unmarshal(input, &describePartitionRequest) + if err != nil { + return nil, 
err + } + describePartitionMsg := &DescribePartitionMsg{DescribePartitionRequest: describePartitionRequest} + describePartitionMsg.BeginTimestamp = describePartitionMsg.Timestamp + describePartitionMsg.EndTimestamp = describePartitionMsg.Timestamp + + return describePartitionMsg, nil +} + +/////////////////////////////////////////ShowPartition////////////////////////////////////////// +type ShowPartitionMsg struct { + BaseMsg + internalPb.ShowPartitionRequest +} + +func (sc *ShowPartitionMsg) Type() MsgType { + return sc.MsgType +} + +func (sc *ShowPartitionMsg) Marshal(input TsMsg) ([]byte, error) { + showPartitionMsg := input.(*ShowPartitionMsg) + showPartitionRequest := &showPartitionMsg.ShowPartitionRequest + mb, err := proto.Marshal(showPartitionRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (sc *ShowPartitionMsg) Unmarshal(input []byte) (TsMsg, error) { + showPartitionRequest := internalPb.ShowPartitionRequest{} + err := proto.Unmarshal(input, &showPartitionRequest) + if err != nil { + return nil, err + } + showPartitionMsg := &ShowPartitionMsg{ShowPartitionRequest: showPartitionRequest} + showPartitionMsg.BeginTimestamp = showPartitionMsg.Timestamp + showPartitionMsg.EndTimestamp = showPartitionMsg.Timestamp + + return showPartitionMsg, nil +} + /////////////////////////////////////////LoadIndex////////////////////////////////////////// type LoadIndexMsg struct { BaseMsg @@ -528,14 +635,6 @@ func (lim *LoadIndexMsg) Type() MsgType { return lim.MsgType } -func (lim *LoadIndexMsg) GetMsgContext() context.Context { - return lim.MsgCtx -} - -func (lim *LoadIndexMsg) SetMsgContext(ctx context.Context) { - lim.MsgCtx = ctx -} - func (lim *LoadIndexMsg) Marshal(input TsMsg) ([]byte, error) { loadIndexMsg := input.(*LoadIndexMsg) loadIndexRequest := &loadIndexMsg.LoadIndex diff --git a/internal/msgstream/msgstream.go b/internal/msgstream/msgstream.go index d59c5be4e70dc553e71b18c9fceac1feeb5fe88d..37dd71c053441673334cbda1c60adac9fdc5fc5f 100644 --- a/internal/msgstream/msgstream.go +++ b/internal/msgstream/msgstream.go @@ -4,15 +4,12 @@ import ( "context" "log" "reflect" - "strings" "sync" "time" "github.com/apache/pulsar-client-go/pulsar" "github.com/golang/protobuf/proto" - "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/ext" - oplog "github.com/opentracing/opentracing-go/log" + "github.com/zilliztech/milvus-distributed/internal/errors" commonPb "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" @@ -154,29 +151,6 @@ func (ms *PulsarMsgStream) Close() { } } -type propertiesReaderWriter struct { - ppMap map[string]string -} - -func (ppRW *propertiesReaderWriter) Set(key, val string) { - // The GRPC HPACK implementation rejects any uppercase keys here. - // - // As such, since the HTTP_HEADERS format is case-insensitive anyway, we - // blindly lowercase the key (which is guaranteed to work in the - // Inject/Extract sense per the OpenTracing spec). 
- key = strings.ToLower(key) - ppRW.ppMap[key] = val -} - -func (ppRW *propertiesReaderWriter) ForeachKey(handler func(key, val string) error) error { - for k, val := range ppRW.ppMap { - if err := handler(k, val); err != nil { - return err - } - } - return nil -} - func (ms *PulsarMsgStream) Produce(msgPack *MsgPack) error { tsMsgs := msgPack.Msgs if len(tsMsgs) <= 0 { @@ -226,50 +200,12 @@ func (ms *PulsarMsgStream) Produce(msgPack *MsgPack) error { if err != nil { return err } - - msg := &pulsar.ProducerMessage{Payload: mb} - var child opentracing.Span - if v.Msgs[i].Type() == internalPb.MsgType_kSearch || - v.Msgs[i].Type() == internalPb.MsgType_kSearchResult { - tracer := opentracing.GlobalTracer() - ctx := v.Msgs[i].GetMsgContext() - if ctx == nil { - ctx = context.Background() - } - - if parent := opentracing.SpanFromContext(ctx); parent != nil { - child = tracer.StartSpan("start send pulsar msg", - opentracing.FollowsFrom(parent.Context())) - } else { - child = tracer.StartSpan("start send pulsar msg") - } - child.SetTag("hash keys", v.Msgs[i].HashKeys()) - child.SetTag("start time", v.Msgs[i].BeginTs()) - child.SetTag("end time", v.Msgs[i].EndTs()) - child.SetTag("msg type", v.Msgs[i].Type()) - msg.Properties = make(map[string]string) - err = tracer.Inject(child.Context(), opentracing.TextMap, &propertiesReaderWriter{msg.Properties}) - if err != nil { - child.LogFields(oplog.Error(err)) - child.Finish() - return err - } - child.LogFields(oplog.String("inject success", "inject success")) - } - if _, err := (*ms.producers[k]).Send( context.Background(), - msg, + &pulsar.ProducerMessage{Payload: mb}, ); err != nil { - if child != nil { - child.LogFields(oplog.Error(err)) - child.Finish() - } return err } - if child != nil { - child.Finish() - } } } return nil @@ -282,49 +218,14 @@ func (ms *PulsarMsgStream) Broadcast(msgPack *MsgPack) error { if err != nil { return err } - msg := &pulsar.ProducerMessage{Payload: mb} - var child opentracing.Span - if v.Type() == internalPb.MsgType_kSearch || - v.Type() == internalPb.MsgType_kSearchResult { - tracer := opentracing.GlobalTracer() - ctx := v.GetMsgContext() - if ctx == nil { - ctx = context.Background() - } - if parent := opentracing.SpanFromContext(ctx); parent != nil { - child = tracer.StartSpan("start send pulsar msg", - opentracing.FollowsFrom(parent.Context())) - } else { - child = tracer.StartSpan("start send pulsar msg, start time: %d") - } - child.SetTag("hash keys", v.HashKeys()) - child.SetTag("start time", v.BeginTs()) - child.SetTag("end time", v.EndTs()) - child.SetTag("msg type", v.Type()) - msg.Properties = make(map[string]string) - err = tracer.Inject(child.Context(), opentracing.TextMap, &propertiesReaderWriter{msg.Properties}) - if err != nil { - child.LogFields(oplog.Error(err)) - child.Finish() - return err - } - child.LogFields(oplog.String("inject success", "inject success")) - } for i := 0; i < producerLen; i++ { if _, err := (*ms.producers[i]).Send( context.Background(), - msg, + &pulsar.ProducerMessage{Payload: mb}, ); err != nil { - if child != nil { - child.LogFields(oplog.Error(err)) - child.Finish() - } return err } } - if child != nil { - child.Finish() - } } return nil } @@ -357,7 +258,6 @@ func (ms *PulsarMsgStream) bufMsgPackToChannel() { for { select { case <-ms.ctx.Done(): - log.Println("done") return default: tsMsgList := make([]TsMsg, 0) @@ -370,7 +270,6 @@ func (ms *PulsarMsgStream) bufMsgPackToChannel() { } pulsarMsg, ok := value.Interface().(pulsar.ConsumerMessage) - if !ok { log.Printf("type 
assertion failed, not consumer message type") continue @@ -384,23 +283,6 @@ func (ms *PulsarMsgStream) bufMsgPackToChannel() { continue } tsMsg, err := ms.unmarshal.Unmarshal(pulsarMsg.Payload(), headerMsg.MsgType) - if tsMsg.Type() == internalPb.MsgType_kSearch || - tsMsg.Type() == internalPb.MsgType_kSearchResult { - tracer := opentracing.GlobalTracer() - spanContext, err := tracer.Extract(opentracing.HTTPHeaders, &propertiesReaderWriter{pulsarMsg.Properties()}) - if err != nil { - log.Println("extract message err") - log.Println(err.Error()) - } - span := opentracing.StartSpan("pulsar msg received", - ext.RPCServerOption(spanContext)) - span.SetTag("msg type", tsMsg.Type()) - span.SetTag("hash keys", tsMsg.HashKeys()) - span.SetTag("start time", tsMsg.BeginTs()) - span.SetTag("end time", tsMsg.EndTs()) - tsMsg.SetMsgContext(opentracing.ContextWithSpan(context.Background(), span)) - span.Finish() - } if err != nil { log.Printf("Failed to unmarshal tsMsg, error = %v", err) continue @@ -464,8 +346,6 @@ func (ms *PulsarTtMsgStream) bufMsgPackToChannel() { ms.inputBuf = make([]TsMsg, 0) isChannelReady := make([]bool, len(ms.consumers)) eofMsgTimeStamp := make(map[int]Timestamp) - spans := make(map[Timestamp]opentracing.Span) - ctxs := make(map[Timestamp]context.Context) for { select { case <-ms.ctx.Done(): @@ -491,22 +371,8 @@ func (ms *PulsarTtMsgStream) bufMsgPackToChannel() { ms.inputBuf = append(ms.inputBuf, ms.unsolvedBuf...) ms.unsolvedBuf = ms.unsolvedBuf[:0] for _, v := range ms.inputBuf { - var ctx context.Context - var span opentracing.Span - if v.Type() == internalPb.MsgType_kInsert { - if _, ok := spans[v.BeginTs()]; !ok { - span, ctx = opentracing.StartSpanFromContext(v.GetMsgContext(), "after find time tick") - ctxs[v.BeginTs()] = ctx - spans[v.BeginTs()] = span - } - } if v.EndTs() <= timeStamp { timeTickBuf = append(timeTickBuf, v) - if v.Type() == internalPb.MsgType_kInsert { - v.SetMsgContext(ctxs[v.BeginTs()]) - spans[v.BeginTs()].Finish() - delete(spans, v.BeginTs()) - } } else { ms.unsolvedBuf = append(ms.unsolvedBuf, v) } @@ -554,24 +420,6 @@ func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int, if err != nil { log.Printf("Failed to unmarshal, error = %v", err) } - - if tsMsg.Type() == internalPb.MsgType_kInsert { - tracer := opentracing.GlobalTracer() - spanContext, err := tracer.Extract(opentracing.HTTPHeaders, &propertiesReaderWriter{pulsarMsg.Properties()}) - if err != nil { - log.Println("extract message err") - log.Println(err.Error()) - } - span := opentracing.StartSpan("pulsar msg received", - ext.RPCServerOption(spanContext)) - span.SetTag("hash keys", tsMsg.HashKeys()) - span.SetTag("start time", tsMsg.BeginTs()) - span.SetTag("end time", tsMsg.EndTs()) - span.SetTag("msg type", tsMsg.Type()) - tsMsg.SetMsgContext(opentracing.ContextWithSpan(context.Background(), span)) - span.Finish() - } - if headerMsg.MsgType == internalPb.MsgType_kTimeTick { eofMsgMap[channelIndex] = tsMsg.(*TimeTickMsg).Timestamp return @@ -652,7 +500,7 @@ func insertRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e result := make(map[int32]*MsgPack) for i, request := range tsMsgs { if request.Type() != internalPb.MsgType_kInsert { - return nil, errors.New("msg's must be Insert") + return nil, errors.New(string("msg's must be Insert")) } insertRequest := request.(*InsertMsg) keys := hashKeys[i] @@ -663,7 +511,7 @@ func insertRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e keysLen := len(keys) if keysLen != timestampLen || keysLen != rowIDLen 
|| keysLen != rowDataLen { - return nil, errors.New("the length of hashValue, timestamps, rowIDs, RowData are not equal") + return nil, errors.New(string("the length of hashValue, timestamps, rowIDs, RowData are not equal")) } for index, key := range keys { _, ok := result[key] @@ -686,9 +534,6 @@ func insertRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e } insertMsg := &InsertMsg{ - BaseMsg: BaseMsg{ - MsgCtx: request.GetMsgContext(), - }, InsertRequest: sliceRequest, } result[key].Msgs = append(result[key].Msgs, insertMsg) @@ -701,7 +546,7 @@ func deleteRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e result := make(map[int32]*MsgPack) for i, request := range tsMsgs { if request.Type() != internalPb.MsgType_kDelete { - return nil, errors.New("msg's must be Delete") + return nil, errors.New(string("msg's must be Delete")) } deleteRequest := request.(*DeleteMsg) keys := hashKeys[i] @@ -711,7 +556,7 @@ func deleteRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e keysLen := len(keys) if keysLen != timestampLen || keysLen != primaryKeysLen { - return nil, errors.New("the length of hashValue, timestamps, primaryKeys are not equal") + return nil, errors.New(string("the length of hashValue, timestamps, primaryKeys are not equal")) } for index, key := range keys { @@ -745,7 +590,7 @@ func defaultRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, for i, request := range tsMsgs { keys := hashKeys[i] if len(keys) != 1 { - return nil, errors.New("len(msg.hashValue) must equal 1") + return nil, errors.New(string("len(msg.hashValue) must equal 1")) } key := keys[0] _, ok := result[key] diff --git a/internal/proxy/grpc_service.go b/internal/proxy/grpc_service.go index 29bb0eeecd5ff4354a6eaab3226715ba007321d3..df298ba1ecd3b22d36ec9dcadc111e3d8d777169 100644 --- a/internal/proxy/grpc_service.go +++ b/internal/proxy/grpc_service.go @@ -6,7 +6,6 @@ import ( "log" "time" - "github.com/opentracing/opentracing-go" "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" @@ -19,13 +18,8 @@ const ( ) func (p *Proxy) Insert(ctx context.Context, in *servicepb.RowBatch) (*servicepb.IntegerRangeResponse, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "insert grpc received") - defer span.Finish() - span.SetTag("collection name", in.CollectionName) - span.SetTag("partition tag", in.PartitionTag) log.Println("insert into: ", in.CollectionName) it := &InsertTask{ - ctx: ctx, Condition: NewTaskCondition(ctx), BaseInsertTask: BaseInsertTask{ BaseMsg: msgstream.BaseMsg{ @@ -125,14 +119,8 @@ func (p *Proxy) CreateCollection(ctx context.Context, req *schemapb.CollectionSc } func (p *Proxy) Search(ctx context.Context, req *servicepb.Query) (*servicepb.QueryResult, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "search grpc received") - defer span.Finish() - span.SetTag("collection name", req.CollectionName) - span.SetTag("partition tag", req.PartitionTags) - span.SetTag("dsl", req.Dsl) log.Println("search: ", req.CollectionName, req.Dsl) qt := &QueryTask{ - ctx: ctx, Condition: NewTaskCondition(ctx), SearchRequest: internalpb.SearchRequest{ ProxyID: Params.ProxyID(), diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index cc2ba4832cfac0d8140efafce98bcd1cf080f5cf..f4232bc82234f07d54d713772e89d70f77b5586f 100644 --- a/internal/proxy/proxy.go +++ 
b/internal/proxy/proxy.go @@ -2,8 +2,6 @@ package proxy import ( "context" - "fmt" - "io" "log" "math/rand" "net" @@ -11,10 +9,6 @@ import ( "sync" "time" - "github.com/opentracing/opentracing-go" - "github.com/uber/jaeger-client-go" - "github.com/uber/jaeger-client-go/config" - "google.golang.org/grpc" "github.com/zilliztech/milvus-distributed/internal/allocator" @@ -45,9 +39,6 @@ type Proxy struct { manipulationMsgStream *msgstream.PulsarMsgStream queryMsgStream *msgstream.PulsarMsgStream - tracer opentracing.Tracer - closer io.Closer - // Add callback functions at different stages startCallbacks []func() closeCallbacks []func() @@ -60,28 +51,11 @@ func Init() { func CreateProxy(ctx context.Context) (*Proxy, error) { rand.Seed(time.Now().UnixNano()) ctx1, cancel := context.WithCancel(ctx) - var err error p := &Proxy{ proxyLoopCtx: ctx1, proxyLoopCancel: cancel, } - cfg := &config.Configuration{ - ServiceName: "tracing", - Sampler: &config.SamplerConfig{ - Type: "const", - Param: 1, - }, - Reporter: &config.ReporterConfig{ - LogSpans: true, - }, - } - p.tracer, p.closer, err = cfg.NewTracer(config.Logger(jaeger.StdLogger)) - if err != nil { - panic(fmt.Sprintf("ERROR: cannot init Jaeger: %v\n", err)) - } - opentracing.SetGlobalTracer(p.tracer) - pulsarAddress := Params.PulsarAddress() p.queryMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamSearchBufSize()) @@ -224,17 +198,12 @@ func (p *Proxy) stopProxyLoop() { p.tick.Close() p.proxyLoopWg.Wait() - } // Close closes the server. func (p *Proxy) Close() { p.stopProxyLoop() - if p.closer != nil { - p.closer.Close() - } - for _, cb := range p.closeCallbacks { cb() } diff --git a/internal/proxy/repack_func.go b/internal/proxy/repack_func.go index 45c4f4d0abc6b5af643e32a7e7229d9442ae298c..44139999e0403719ca9eaf141f110980b808b6e1 100644 --- a/internal/proxy/repack_func.go +++ b/internal/proxy/repack_func.go @@ -182,7 +182,6 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg, insertMsg := &msgstream.InsertMsg{ InsertRequest: sliceRequest, } - insertMsg.SetMsgContext(request.GetMsgContext()) if together { // all rows with same hash value are accumulated to only one message if len(result[key].Msgs) <= 0 { result[key].Msgs = append(result[key].Msgs, insertMsg) diff --git a/internal/proxy/task.go b/internal/proxy/task.go index d28e5654a4c76499f91052aa475bd00c94d12ccc..425cae75cfb3e4de24b55cf4a24cf3cc5aa55dbe 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -7,9 +7,6 @@ import ( "math" "strconv" - "github.com/opentracing/opentracing-go" - oplog "github.com/opentracing/opentracing-go/log" - "github.com/golang/protobuf/proto" "github.com/zilliztech/milvus-distributed/internal/allocator" "github.com/zilliztech/milvus-distributed/internal/msgstream" @@ -77,21 +74,12 @@ func (it *InsertTask) Type() internalpb.MsgType { } func (it *InsertTask) PreExecute() error { - span, ctx := opentracing.StartSpanFromContext(it.ctx, "InsertTask preExecute") - defer span.Finish() - it.ctx = ctx - span.SetTag("hash keys", it.ReqID) - span.SetTag("start time", it.BeginTs()) collectionName := it.BaseInsertTask.CollectionName if err := ValidateCollectionName(collectionName); err != nil { - span.LogFields(oplog.Error(err)) - span.Finish() return err } partitionTag := it.BaseInsertTask.PartitionTag if err := ValidatePartitionTag(partitionTag, true); err != nil { - span.LogFields(oplog.Error(err)) - span.Finish() return err } @@ -99,36 +87,22 @@ func (it *InsertTask) PreExecute() error { } func (it *InsertTask) Execute() error { - span, 
ctx := opentracing.StartSpanFromContext(it.ctx, "InsertTask Execute") - defer span.Finish() - it.ctx = ctx - span.SetTag("hash keys", it.ReqID) - span.SetTag("start time", it.BeginTs()) collectionName := it.BaseInsertTask.CollectionName - span.LogFields(oplog.String("collection_name", collectionName)) if !globalMetaCache.Hit(collectionName) { err := globalMetaCache.Sync(collectionName) if err != nil { - span.LogFields(oplog.Error(err)) - span.Finish() return err } } description, err := globalMetaCache.Get(collectionName) if err != nil || description == nil { - span.LogFields(oplog.Error(err)) - span.Finish() return err } autoID := description.Schema.AutoID - span.LogFields(oplog.Bool("auto_id", autoID)) var rowIDBegin UniqueID var rowIDEnd UniqueID rowNums := len(it.BaseInsertTask.RowData) rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(uint32(rowNums)) - span.LogFields(oplog.Int("rowNums", rowNums), - oplog.Int("rowIDBegin", int(rowIDBegin)), - oplog.Int("rowIDEnd", int(rowIDEnd))) it.BaseInsertTask.RowIDs = make([]UniqueID, rowNums) for i := rowIDBegin; i < rowIDEnd; i++ { offset := i - rowIDBegin @@ -151,8 +125,6 @@ func (it *InsertTask) Execute() error { EndTs: it.EndTs(), Msgs: make([]msgstream.TsMsg, 1), } - tsMsg.SetMsgContext(ctx) - span.LogFields(oplog.String("send msg", "send msg")) msgPack.Msgs[0] = tsMsg err = it.manipulationMsgStream.Produce(msgPack) @@ -166,14 +138,11 @@ func (it *InsertTask) Execute() error { if err != nil { it.result.Status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR it.result.Status.Reason = err.Error() - span.LogFields(oplog.Error(err)) } return nil } func (it *InsertTask) PostExecute() error { - span, _ := opentracing.StartSpanFromContext(it.ctx, "InsertTask postExecute") - defer span.Finish() return nil } @@ -383,38 +352,24 @@ func (qt *QueryTask) SetTs(ts Timestamp) { } func (qt *QueryTask) PreExecute() error { - span, ctx := opentracing.StartSpanFromContext(qt.ctx, "QueryTask preExecute") - defer span.Finish() - qt.ctx = ctx - span.SetTag("hash keys", qt.ReqID) - span.SetTag("start time", qt.BeginTs()) - collectionName := qt.query.CollectionName if !globalMetaCache.Hit(collectionName) { err := globalMetaCache.Sync(collectionName) if err != nil { - span.LogFields(oplog.Error(err)) - span.Finish() return err } } _, err := globalMetaCache.Get(collectionName) if err != nil { // err is not nil if collection not exists - span.LogFields(oplog.Error(err)) - span.Finish() return err } if err := ValidateCollectionName(qt.query.CollectionName); err != nil { - span.LogFields(oplog.Error(err)) - span.Finish() return err } for _, tag := range qt.query.PartitionTags { if err := ValidatePartitionTag(tag, false); err != nil { - span.LogFields(oplog.Error(err)) - span.Finish() return err } } @@ -424,8 +379,6 @@ func (qt *QueryTask) PreExecute() error { } queryBytes, err := proto.Marshal(qt.query) if err != nil { - span.LogFields(oplog.Error(err)) - span.Finish() return err } qt.Query = &commonpb.Blob{ @@ -435,11 +388,6 @@ func (qt *QueryTask) PreExecute() error { } func (qt *QueryTask) Execute() error { - span, ctx := opentracing.StartSpanFromContext(qt.ctx, "QueryTask Execute") - defer span.Finish() - qt.ctx = ctx - span.SetTag("hash keys", qt.ReqID) - span.SetTag("start time", qt.BeginTs()) var tsMsg msgstream.TsMsg = &msgstream.SearchMsg{ SearchRequest: qt.SearchRequest, BaseMsg: msgstream.BaseMsg{ @@ -453,31 +401,22 @@ func (qt *QueryTask) Execute() error { EndTs: qt.Timestamp, Msgs: make([]msgstream.TsMsg, 1), } - tsMsg.SetMsgContext(ctx) msgPack.Msgs[0] = 
tsMsg err := qt.queryMsgStream.Produce(msgPack) log.Printf("[Proxy] length of searchMsg: %v", len(msgPack.Msgs)) if err != nil { - span.LogFields(oplog.Error(err)) - span.Finish() log.Printf("[Proxy] send search request failed: %v", err) } return err } func (qt *QueryTask) PostExecute() error { - span, _ := opentracing.StartSpanFromContext(qt.ctx, "QueryTask postExecute") - defer span.Finish() - span.SetTag("hash keys", qt.ReqID) - span.SetTag("start time", qt.BeginTs()) for { select { case <-qt.ctx.Done(): log.Print("wait to finish failed, timeout!") - span.LogFields(oplog.String("wait to finish failed, timeout", "wait to finish failed, timeout")) return errors.New("wait to finish failed, timeout") case searchResults := <-qt.resultBuf: - span.LogFields(oplog.String("receive result", "receive result")) filterSearchResult := make([]*internalpb.SearchResult, 0) var filterReason string for _, partialSearchResult := range searchResults { @@ -496,7 +435,6 @@ func (qt *QueryTask) PostExecute() error { Reason: filterReason, }, } - span.LogFields(oplog.Error(errors.New(filterReason))) return errors.New(filterReason) } @@ -588,7 +526,6 @@ func (qt *QueryTask) PostExecute() error { reducedHitsBs, err := proto.Marshal(reducedHits) if err != nil { log.Println("marshal error") - span.LogFields(oplog.Error(err)) return err } qt.result.Hits = append(qt.result.Hits, reducedHitsBs) @@ -700,10 +637,7 @@ func (dct *DescribeCollectionTask) PreExecute() error { func (dct *DescribeCollectionTask) Execute() error { var err error dct.result, err = dct.masterClient.DescribeCollection(dct.ctx, &dct.DescribeCollectionRequest) - if err != nil { - return err - } - err = globalMetaCache.Update(dct.CollectionName.CollectionName, dct.result) + globalMetaCache.Update(dct.CollectionName.CollectionName, dct.result) return err } diff --git a/internal/querynode/flow_graph_filter_dm_node.go b/internal/querynode/flow_graph_filter_dm_node.go index 3368e9e31f5f5b0b5ceeb3da1c382c960e947961..fbc8eedb5c82b00e868561b6e5971a9af3f78468 100644 --- a/internal/querynode/flow_graph_filter_dm_node.go +++ b/internal/querynode/flow_graph_filter_dm_node.go @@ -1,11 +1,9 @@ package querynode import ( - "context" "log" "math" - "github.com/opentracing/opentracing-go" "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" @@ -34,28 +32,6 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg { // TODO: add error handling } - var childs []opentracing.Span - tracer := opentracing.GlobalTracer() - if tracer != nil && msgStreamMsg != nil { - for _, msg := range msgStreamMsg.TsMessages() { - if msg.Type() == internalPb.MsgType_kInsert || msg.Type() == internalPb.MsgType_kSearch { - var child opentracing.Span - ctx := msg.GetMsgContext() - if parent := opentracing.SpanFromContext(ctx); parent != nil { - child = tracer.StartSpan("pass filter node", - opentracing.FollowsFrom(parent.Context())) - } else { - child = tracer.StartSpan("pass filter node") - } - child.SetTag("hash keys", msg.HashKeys()) - child.SetTag("start time", msg.BeginTs()) - child.SetTag("end time", msg.EndTs()) - msg.SetMsgContext(opentracing.ContextWithSpan(ctx, child)) - childs = append(childs, child) - } - } - } - ddMsg, ok := (*in[1]).(*ddMsg) if !ok { log.Println("type assertion failed for ddMsg") @@ -70,20 +46,11 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg { timestampMax: msgStreamMsg.TimestampMax(), }, } - for 
key, msg := range msgStreamMsg.TsMessages() { + for _, msg := range msgStreamMsg.TsMessages() { switch msg.Type() { case internalPb.MsgType_kInsert: - var ctx2 context.Context - if childs != nil { - if childs[key] != nil { - ctx2 = opentracing.ContextWithSpan(msg.GetMsgContext(), childs[key]) - } else { - ctx2 = context.Background() - } - } resMsg := fdmNode.filterInvalidInsertMessage(msg.(*msgstream.InsertMsg)) if resMsg != nil { - resMsg.SetMsgContext(ctx2) iMsg.insertMessages = append(iMsg.insertMessages, resMsg) } // case internalPb.MsgType_kDelete: @@ -95,10 +62,6 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg { iMsg.gcRecord = ddMsg.gcRecord var res Msg = &iMsg - - for _, child := range childs { - child.Finish() - } return []*Msg{&res} } diff --git a/internal/querynode/flow_graph_insert_node.go b/internal/querynode/flow_graph_insert_node.go index 1ba6cefd70b0fbcd0601377f8afbb88df1b487fd..9a2c8ca1f11e34738dfbfae93eaeb8e715b70ef3 100644 --- a/internal/querynode/flow_graph_insert_node.go +++ b/internal/querynode/flow_graph_insert_node.go @@ -1,15 +1,11 @@ package querynode import ( - "context" "fmt" "log" "sync" - "github.com/opentracing/opentracing-go" - oplog "github.com/opentracing/opentracing-go/log" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" - internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" ) type insertNode struct { @@ -18,7 +14,6 @@ type insertNode struct { } type InsertData struct { - insertContext map[int64]context.Context insertIDs map[UniqueID][]UniqueID insertTimestamps map[UniqueID][]Timestamp insertRecords map[UniqueID][]*commonpb.Blob @@ -43,30 +38,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg { // TODO: add error handling } - var childs []opentracing.Span - tracer := opentracing.GlobalTracer() - if tracer != nil && iMsg != nil { - for _, msg := range iMsg.insertMessages { - if msg.Type() == internalPb.MsgType_kInsert || msg.Type() == internalPb.MsgType_kSearch { - var child opentracing.Span - ctx := msg.GetMsgContext() - if parent := opentracing.SpanFromContext(ctx); parent != nil { - child = tracer.StartSpan("pass filter node", - opentracing.FollowsFrom(parent.Context())) - } else { - child = tracer.StartSpan("pass filter node") - } - child.SetTag("hash keys", msg.HashKeys()) - child.SetTag("start time", msg.BeginTs()) - child.SetTag("end time", msg.EndTs()) - msg.SetMsgContext(opentracing.ContextWithSpan(ctx, child)) - childs = append(childs, child) - } - } - } - insertData := InsertData{ - insertContext: make(map[int64]context.Context), insertIDs: make(map[int64][]int64), insertTimestamps: make(map[int64][]uint64), insertRecords: make(map[int64][]*commonpb.Blob), @@ -75,7 +47,6 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg { // 1. hash insertMessages to insertData for _, task := range iMsg.insertMessages { - insertData.insertContext[task.SegmentID] = task.GetMsgContext() insertData.insertIDs[task.SegmentID] = append(insertData.insertIDs[task.SegmentID], task.RowIDs...) insertData.insertTimestamps[task.SegmentID] = append(insertData.insertTimestamps[task.SegmentID], task.Timestamps...) insertData.insertRecords[task.SegmentID] = append(insertData.insertRecords[task.SegmentID], task.RowData...) 
@@ -114,7 +85,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg { wg := sync.WaitGroup{} for segmentID := range insertData.insertRecords { wg.Add(1) - go iNode.insert(insertData.insertContext[segmentID], &insertData, segmentID, &wg) + go iNode.insert(&insertData, segmentID, &wg) } wg.Wait() @@ -122,21 +93,15 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg { gcRecord: iMsg.gcRecord, timeRange: iMsg.timeRange, } - for _, child := range childs { - child.Finish() - } return []*Msg{&res} } -func (iNode *insertNode) insert(ctx context.Context, insertData *InsertData, segmentID int64, wg *sync.WaitGroup) { - span, _ := opentracing.StartSpanFromContext(ctx, "insert node insert") - defer span.Finish() +func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *sync.WaitGroup) { var targetSegment, err = iNode.replica.getSegmentByID(segmentID) if err != nil { log.Println("cannot find segment:", segmentID) // TODO: add error handling wg.Done() - span.LogFields(oplog.Error(err)) return } @@ -150,7 +115,6 @@ func (iNode *insertNode) insert(ctx context.Context, insertData *InsertData, seg log.Println(err) // TODO: add error handling wg.Done() - span.LogFields(oplog.Error(err)) return } diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index e2eb8c8cb962963f2f3fb14ae85132c49e7021ff..819d2b85546af76905ea3d694411a7ee2f17f54d 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -14,12 +14,6 @@ import "C" import ( "context" - "fmt" - "io" - - "github.com/opentracing/opentracing-go" - "github.com/uber/jaeger-client-go" - "github.com/uber/jaeger-client-go/config" ) type QueryNode struct { @@ -36,10 +30,6 @@ type QueryNode struct { searchService *searchService loadIndexService *loadIndexService statsService *statsService - - //opentracing - tracer opentracing.Tracer - closer io.Closer } func Init() { @@ -49,47 +39,31 @@ func Init() { func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode { ctx1, cancel := context.WithCancel(ctx) - q := &QueryNode{ - queryNodeLoopCtx: ctx1, - queryNodeLoopCancel: cancel, - QueryNodeID: queryNodeID, - - dataSyncService: nil, - metaService: nil, - searchService: nil, - statsService: nil, - } - - var err error - cfg := &config.Configuration{ - ServiceName: "tracing", - Sampler: &config.SamplerConfig{ - Type: "const", - Param: 1, - }, - Reporter: &config.ReporterConfig{ - LogSpans: true, - }, - } - q.tracer, q.closer, err = cfg.NewTracer(config.Logger(jaeger.StdLogger)) - if err != nil { - panic(fmt.Sprintf("ERROR: cannot init Jaeger: %v\n", err)) - } - opentracing.SetGlobalTracer(q.tracer) segmentsMap := make(map[int64]*Segment) collections := make([]*Collection, 0) tSafe := newTSafe() - q.replica = &collectionReplicaImpl{ + var replica collectionReplica = &collectionReplicaImpl{ collections: collections, segments: segmentsMap, tSafe: tSafe, } - return q + return &QueryNode{ + queryNodeLoopCtx: ctx1, + queryNodeLoopCancel: cancel, + QueryNodeID: queryNodeID, + + replica: replica, + + dataSyncService: nil, + metaService: nil, + searchService: nil, + statsService: nil, + } } func (node *QueryNode) Start() error { @@ -126,8 +100,4 @@ func (node *QueryNode) Close() { if node.statsService != nil { node.statsService.close() } - if node.closer != nil { - node.closer.Close() - } - } diff --git a/internal/querynode/search_service.go b/internal/querynode/search_service.go index 43512b90192c2a49371ce43c1b292eb858c9e697..c2e0fa5d93b37f8d11be301dcc165fe523fe37ff 100644 --- 
a/internal/querynode/search_service.go +++ b/internal/querynode/search_service.go @@ -5,8 +5,6 @@ import ( "context" "errors" "fmt" - "github.com/opentracing/opentracing-go" - oplog "github.com/opentracing/opentracing-go/log" "log" "sync" @@ -135,27 +133,22 @@ func (ss *searchService) receiveSearchMsg() { } searchMsg := make([]msgstream.TsMsg, 0) serverTime := ss.getServiceableTime() - for i, msg := range msgPack.Msgs { - if msg.BeginTs() > serverTime { - ss.msgBuffer <- msg + for i := range msgPack.Msgs { + if msgPack.Msgs[i].BeginTs() > serverTime { + ss.msgBuffer <- msgPack.Msgs[i] continue } searchMsg = append(searchMsg, msgPack.Msgs[i]) } for _, msg := range searchMsg { - span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "receive search msg") - msg.SetMsgContext(ctx) err := ss.search(msg) if err != nil { log.Println(err) - span.LogFields(oplog.Error(err)) err2 := ss.publishFailedSearchResult(msg, err.Error()) if err2 != nil { - span.LogFields(oplog.Error(err2)) log.Println("publish FailedSearchResult failed, error message: ", err2) } } - span.Finish() } log.Println("ReceiveSearchMsg, do search done, num of searchMsg = ", len(searchMsg)) } @@ -217,12 +210,8 @@ func (ss *searchService) doUnsolvedMsgSearch() { // TODO:: cache map[dsl]plan // TODO: reBatched search requests func (ss *searchService) search(msg msgstream.TsMsg) error { - span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "do search") - defer span.Finish() - msg.SetMsgContext(ctx) searchMsg, ok := msg.(*msgstream.SearchMsg) if !ok { - span.LogFields(oplog.Error(errors.New("invalid request type = " + string(msg.Type())))) return errors.New("invalid request type = " + string(msg.Type())) } @@ -231,27 +220,23 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { query := servicepb.Query{} err := proto.Unmarshal(queryBlob, &query) if err != nil { - span.LogFields(oplog.Error(err)) return errors.New("unmarshal query failed") } collectionName := query.CollectionName partitionTags := query.PartitionTags collection, err := ss.replica.getCollectionByName(collectionName) if err != nil { - span.LogFields(oplog.Error(err)) return err } collectionID := collection.ID() dsl := query.Dsl plan, err := createPlan(*collection, dsl) if err != nil { - span.LogFields(oplog.Error(err)) return err } placeHolderGroupBlob := query.PlaceholderGroup placeholderGroup, err := parserPlaceholderGroup(plan, placeHolderGroupBlob) if err != nil { - span.LogFields(oplog.Error(err)) return err } placeholderGroups := make([]*PlaceholderGroup, 0) @@ -263,7 +248,6 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { for _, partitionTag := range partitionTags { hasPartition := ss.replica.hasPartition(collectionID, partitionTag) if !hasPartition { - span.LogFields(oplog.Error(errors.New("search Failed, invalid partitionTag"))) return errors.New("search Failed, invalid partitionTag") } } @@ -276,7 +260,6 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { searchResult, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp}) if err != nil { - span.LogFields(oplog.Error(err)) return err } searchResults = append(searchResults, searchResult) @@ -296,18 +279,13 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { Hits: nil, } searchResultMsg := &msgstream.SearchResultMsg{ - BaseMsg: msgstream.BaseMsg{ - MsgCtx: searchMsg.MsgCtx, - HashValues: []uint32{uint32(searchMsg.ResultChannelID)}, - }, + BaseMsg: msgstream.BaseMsg{HashValues: 
[]uint32{uint32(searchMsg.ResultChannelID)}}, SearchResult: results, } err = ss.publishSearchResult(searchResultMsg) if err != nil { - span.LogFields(oplog.Error(err)) return err } - span.LogFields(oplog.String("publish search research success", "publish search research success")) return nil } @@ -315,22 +293,18 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { numSegment := int64(len(searchResults)) err2 := reduceSearchResults(searchResults, numSegment, inReduced) if err2 != nil { - span.LogFields(oplog.Error(err2)) return err2 } err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced) if err != nil { - span.LogFields(oplog.Error(err)) return err } marshaledHits, err := reorganizeQueryResults(plan, placeholderGroups, searchResults, numSegment, inReduced) if err != nil { - span.LogFields(oplog.Error(err)) return err } hitsBlob, err := marshaledHits.getHitsBlob() if err != nil { - span.LogFields(oplog.Error(err)) return err } @@ -365,14 +339,11 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { MetricType: plan.getMetricType(), } searchResultMsg := &msgstream.SearchResultMsg{ - BaseMsg: msgstream.BaseMsg{ - MsgCtx: searchMsg.MsgCtx, - HashValues: []uint32{uint32(searchMsg.ResultChannelID)}}, + BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}}, SearchResult: results, } err = ss.publishSearchResult(searchResultMsg) if err != nil { - span.LogFields(oplog.Error(err)) return err } } @@ -385,9 +356,6 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { } func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error { - span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "publish search result") - defer span.Finish() - msg.SetMsgContext(ctx) fmt.Println("Public SearchResult", msg.HashKeys()) msgPack := msgstream.MsgPack{} msgPack.Msgs = append(msgPack.Msgs, msg) @@ -396,9 +364,6 @@ func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error { } func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg string) error { - span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "receive search msg") - defer span.Finish() - msg.SetMsgContext(ctx) msgPack := msgstream.MsgPack{} searchMsg, ok := msg.(*msgstream.SearchMsg) if !ok { diff --git a/internal/util/flowgraph/input_node.go b/internal/util/flowgraph/input_node.go index 26907eddb5c6e5d799093f31b5e10d94e5c5faa5..7c4271b23be5e31373966c3b64acfc395285916f 100644 --- a/internal/util/flowgraph/input_node.go +++ b/internal/util/flowgraph/input_node.go @@ -1,12 +1,8 @@ package flowgraph import ( - "fmt" "log" - "github.com/opentracing/opentracing-go" - "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" - "github.com/zilliztech/milvus-distributed/internal/msgstream" ) @@ -29,33 +25,11 @@ func (inNode *InputNode) InStream() *msgstream.MsgStream { } // empty input and return one *Msg -func (inNode *InputNode) Operate([]*Msg) []*Msg { +func (inNode *InputNode) Operate(in []*Msg) []*Msg { //fmt.Println("Do InputNode operation") msgPack := (*inNode.inStream).Consume() - var childs []opentracing.Span - tracer := opentracing.GlobalTracer() - if tracer != nil && msgPack != nil { - for _, msg := range msgPack.Msgs { - if msg.Type() == internalpb.MsgType_kInsert { - var child opentracing.Span - ctx := msg.GetMsgContext() - if parent := opentracing.SpanFromContext(ctx); parent != nil { - child = tracer.StartSpan(fmt.Sprintf("through msg input node, start time = %d", msg.BeginTs()), - 
opentracing.FollowsFrom(parent.Context())) - } else { - child = tracer.StartSpan(fmt.Sprintf("through msg input node, start time = %d", msg.BeginTs())) - } - child.SetTag("hash keys", msg.HashKeys()) - child.SetTag("start time", msg.BeginTs()) - child.SetTag("end time", msg.EndTs()) - msg.SetMsgContext(opentracing.ContextWithSpan(ctx, child)) - childs = append(childs, child) - } - } - } - // TODO: add status if msgPack == nil { log.Println("null msg pack") @@ -68,10 +42,6 @@ func (inNode *InputNode) Operate([]*Msg) []*Msg { timestampMax: msgPack.EndTs, } - for _, child := range childs { - child.Finish() - } - return []*Msg{&msgStreamMsg} } diff --git a/internal/writenode/data_sync_service_test.go b/internal/writenode/data_sync_service_test.go index df82cec4d93fa212db9432a624a2d2e2cefd134b..48257803606e7284a6884ce3e67eaf18c145cd76 100644 --- a/internal/writenode/data_sync_service_test.go +++ b/internal/writenode/data_sync_service_test.go @@ -236,7 +236,7 @@ func newMeta() { }, { FieldID: 0, - Name: "RawID", + Name: "RowID", Description: "test collection filed 1", DataType: schemapb.DataType_INT64, TypeParams: []*commonpb.KeyValuePair{ diff --git a/internal/writenode/flow_graph_filter_dm_node.go b/internal/writenode/flow_graph_filter_dm_node.go index ea60f6c987027c466d614c233af68706960a7f4b..48ac781ddc4be96502fa79bbe5509e4a41ff9d46 100644 --- a/internal/writenode/flow_graph_filter_dm_node.go +++ b/internal/writenode/flow_graph_filter_dm_node.go @@ -1,12 +1,9 @@ package writenode import ( - "context" "log" "math" - "github.com/opentracing/opentracing-go" - "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" @@ -35,34 +32,11 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg { // TODO: add error handling } - var childs []opentracing.Span - tracer := opentracing.GlobalTracer() - if tracer != nil { - for _, msg := range msgStreamMsg.TsMessages() { - if msg.Type() == internalPb.MsgType_kInsert { - var child opentracing.Span - ctx := msg.GetMsgContext() - if parent := opentracing.SpanFromContext(ctx); parent != nil { - child = tracer.StartSpan("pass filter node", - opentracing.FollowsFrom(parent.Context())) - } else { - child = tracer.StartSpan("pass filter node") - } - child.SetTag("hash keys", msg.HashKeys()) - child.SetTag("start time", msg.BeginTs()) - child.SetTag("end time", msg.EndTs()) - msg.SetMsgContext(opentracing.ContextWithSpan(ctx, child)) - childs = append(childs, child) - } - } - } - ddMsg, ok := (*in[1]).(*ddMsg) if !ok { log.Println("type assertion failed for ddMsg") // TODO: add error handling } - fdmNode.ddMsg = ddMsg var iMsg = insertMsg{ @@ -83,20 +57,11 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg { } } - for key, msg := range msgStreamMsg.TsMessages() { + for _, msg := range msgStreamMsg.TsMessages() { switch msg.Type() { case internalPb.MsgType_kInsert: - var ctx2 context.Context - if childs != nil { - if childs[key] != nil { - ctx2 = opentracing.ContextWithSpan(msg.GetMsgContext(), childs[key]) - } else { - ctx2 = context.Background() - } - } resMsg := fdmNode.filterInvalidInsertMessage(msg.(*msgstream.InsertMsg)) if resMsg != nil { - resMsg.SetMsgContext(ctx2) iMsg.insertMessages = append(iMsg.insertMessages, resMsg) } // case internalPb.MsgType_kDelete: @@ -108,9 +73,6 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg { iMsg.gcRecord = ddMsg.gcRecord var res Msg = &iMsg - for _, child := 
range childs { - child.Finish() - } return []*Msg{&res} } diff --git a/internal/writenode/flow_graph_insert_buffer_node.go b/internal/writenode/flow_graph_insert_buffer_node.go index 05db320a4fbf8826948c172f48dba15750bc86b3..6cc7b7b36ffdaa342bcb3172abe49b03c65535ef 100644 --- a/internal/writenode/flow_graph_insert_buffer_node.go +++ b/internal/writenode/flow_graph_insert_buffer_node.go @@ -4,15 +4,11 @@ import ( "bytes" "context" "encoding/binary" - "fmt" "log" "path" "strconv" "unsafe" - "github.com/opentracing/opentracing-go" - oplog "github.com/opentracing/opentracing-go/log" - "github.com/zilliztech/milvus-distributed/internal/allocator" "github.com/zilliztech/milvus-distributed/internal/kv" miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio" @@ -100,23 +96,12 @@ func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg { // iMsg is insertMsg // 1. iMsg -> buffer for _, msg := range iMsg.insertMessages { - ctx := msg.GetMsgContext() - var span opentracing.Span - if ctx != nil { - span, _ = opentracing.StartSpanFromContext(ctx, fmt.Sprintf("insert buffer node, start time = %d", msg.BeginTs())) - } else { - span = opentracing.StartSpan(fmt.Sprintf("insert buffer node, start time = %d", msg.BeginTs())) - } - span.SetTag("hash keys", msg.HashKeys()) - span.SetTag("start time", msg.BeginTs()) - span.SetTag("end time", msg.EndTs()) if len(msg.RowIDs) != len(msg.Timestamps) || len(msg.RowIDs) != len(msg.RowData) { log.Println("Error: misaligned messages detected") continue } currentSegID := msg.GetSegmentID() collectionName := msg.GetCollectionName() - span.LogFields(oplog.Int("segment id", int(currentSegID))) idata, ok := ibNode.insertBuffer.insertData[currentSegID] if !ok { @@ -125,21 +110,6 @@ func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg { } } - // Timestamps - _, ok = idata.Data[1].(*storage.Int64FieldData) - if !ok { - idata.Data[1] = &storage.Int64FieldData{ - Data: []int64{}, - NumRows: 0, - } - } - tsData := idata.Data[1].(*storage.Int64FieldData) - for _, ts := range msg.Timestamps { - tsData.Data = append(tsData.Data, int64(ts)) - } - tsData.NumRows += len(msg.Timestamps) - span.LogFields(oplog.Int("tsData numRows", tsData.NumRows)) - // 1.1 Get CollectionMeta from etcd collection, err := ibNode.replica.getCollectionByName(collectionName) //collSchema, err := ibNode.getCollectionSchemaByName(collectionName) @@ -388,11 +358,9 @@ func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg { // 1.3 store in buffer ibNode.insertBuffer.insertData[currentSegID] = idata - span.LogFields(oplog.String("store in buffer", "store in buffer")) // 1.4 if full // 1.4.1 generate binlogs - span.LogFields(oplog.String("generate binlogs", "generate binlogs")) if ibNode.insertBuffer.full(currentSegID) { log.Printf(". 
Insert Buffer full, auto flushing (%v) rows of data...", ibNode.insertBuffer.size(currentSegID)) // partitionTag -> partitionID @@ -461,7 +429,6 @@ func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg { ibNode.outCh <- inBinlogMsg } } - span.Finish() } if len(iMsg.insertMessages) > 0 { diff --git a/internal/writenode/flush_sync_service_test.go b/internal/writenode/flush_sync_service_test.go index 4c62cfcacfd15bb7975fd8b442c4944bfcafadbd..7da80503d6191fcbe38aa299829c2c8e69181e16 100644 --- a/internal/writenode/flush_sync_service_test.go +++ b/internal/writenode/flush_sync_service_test.go @@ -89,6 +89,12 @@ func TestFlushSyncService_Start(t *testing.T) { time.Sleep(time.Millisecond * 50) } + for { + if len(ddChan) == 0 && len(insertChan) == 0 { + break + } + } + ret, err := fService.metaTable.getSegBinlogPaths(SegID) assert.NoError(t, err) assert.Equal(t, map[int64][]string{ diff --git a/internal/writenode/write_node.go b/internal/writenode/write_node.go index b08b94bf91c2533a7b201e1ea3b3574b39a102b6..d3ce6f84c70ac40d1fd52fc3a9f13df7d295d8ed 100644 --- a/internal/writenode/write_node.go +++ b/internal/writenode/write_node.go @@ -2,12 +2,6 @@ package writenode import ( "context" - "fmt" - "io" - - "github.com/opentracing/opentracing-go" - "github.com/uber/jaeger-client-go" - "github.com/uber/jaeger-client-go/config" ) type WriteNode struct { @@ -17,8 +11,6 @@ type WriteNode struct { flushSyncService *flushSyncService metaService *metaService replica collectionReplica - tracer opentracing.Tracer - closer io.Closer } func NewWriteNode(ctx context.Context, writeNodeID uint64) *WriteNode { @@ -46,22 +38,6 @@ func Init() { } func (node *WriteNode) Start() error { - cfg := &config.Configuration{ - ServiceName: "tracing", - Sampler: &config.SamplerConfig{ - Type: "const", - Param: 1, - }, - Reporter: &config.ReporterConfig{ - LogSpans: true, - }, - } - var err error - node.tracer, node.closer, err = cfg.NewTracer(config.Logger(jaeger.StdLogger)) - if err != nil { - panic(fmt.Sprintf("ERROR: cannot init Jaeger: %v\n", err)) - } - opentracing.SetGlobalTracer(node.tracer) // TODO GOOSE Init Size?? 
chanSize := 100 @@ -85,9 +61,4 @@ func (node *WriteNode) Close() { if node.dataSyncService != nil { (*node.dataSyncService).close() } - - if node.closer != nil { - node.closer.Close() - } - } diff --git a/tests/python/requirements.txt b/tests/python/requirements.txt index efb68f38fcafe252235836d47faafa57faba765c..9bee462aee885e91e3b5735d9a8c2ef4f58a45df 100644 --- a/tests/python/requirements.txt +++ b/tests/python/requirements.txt @@ -4,5 +4,5 @@ numpy==1.18.1 pytest==5.3.4 pytest-cov==2.8.1 pytest-timeout==1.3.4 -pymilvus-distributed==0.0.9 +pymilvus-distributed==0.0.10 sklearn==0.0 diff --git a/tests/python/test_index.py b/tests/python/test_index.py new file mode 100644 index 0000000000000000000000000000000000000000..687e6573a68070eaee449f7d0cf425cf971fb82b --- /dev/null +++ b/tests/python/test_index.py @@ -0,0 +1,824 @@ +import logging +import time +import pdb +import threading +from multiprocessing import Pool, Process +import numpy +import pytest +import sklearn.preprocessing +from .utils import * +from .constants import * + +uid = "test_index" +BUILD_TIMEOUT = 300 +field_name = default_float_vec_field_name +binary_field_name = default_binary_vec_field_name +query, query_vecs = gen_query_vectors(field_name, default_entities, default_top_k, 1) +default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} + + +@pytest.mark.skip("wait for debugging...") +class TestIndexBase: + @pytest.fixture( + scope="function", + params=gen_simple_index() + ) + def get_simple_index(self, request, connect): + logging.getLogger().info(request.param) + # TODO: Determine the service mode + # if str(connect._cmd("mode")) == "CPU": + if request.param["index_type"] in index_cpu_not_support(): + pytest.skip("sq8h not support in CPU mode") + return request.param + + @pytest.fixture( + scope="function", + params=[ + 1, + 10, + 1111 + ], + ) + def get_nq(self, request): + yield request.param + + """ + ****************************************************************** + The following cases are used to test `create_index` function + ****************************************************************** + """ + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index(self, connect, collection, get_simple_index): + ''' + target: test create index interface + method: create collection and add entities in it, create index + expected: return search success + ''' + ids = connect.bulk_insert(collection, default_entities) + connect.create_index(collection, field_name, get_simple_index) + + def test_create_index_on_field_not_existed(self, connect, collection, get_simple_index): + ''' + target: test create index interface + method: create collection and add entities in it, create index on field not existed + expected: error raised + ''' + tmp_field_name = gen_unique_str() + ids = connect.bulk_insert(collection, default_entities) + with pytest.raises(Exception) as e: + connect.create_index(collection, tmp_field_name, get_simple_index) + + @pytest.mark.level(2) + def test_create_index_on_field(self, connect, collection, get_simple_index): + ''' + target: test create index interface + method: create collection and add entities in it, create index on other field + expected: error raised + ''' + tmp_field_name = "int64" + ids = connect.bulk_insert(collection, default_entities) + with pytest.raises(Exception) as e: + connect.create_index(collection, tmp_field_name, get_simple_index) + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_no_vectors(self, connect, collection, get_simple_index): 
+        '''
+        target: test create index interface
+        method: create collection and create index without inserting entities in it
+        expected: return search success
+        '''
+        connect.create_index(collection, field_name, get_simple_index)
+
+    @pytest.mark.timeout(BUILD_TIMEOUT)
+    def test_create_index_partition(self, connect, collection, get_simple_index):
+        '''
+        target: test create index interface
+        method: create collection, create partition, and add entities in it, create index
+        expected: return search success
+        '''
+        connect.create_partition(collection, default_tag)
+        ids = connect.bulk_insert(collection, default_entities, partition_tag=default_tag)
+        connect.flush([collection])
+        connect.create_index(collection, field_name, get_simple_index)
+
+    @pytest.mark.timeout(BUILD_TIMEOUT)
+    def test_create_index_partition_flush(self, connect, collection, get_simple_index):
+        '''
+        target: test create index interface
+        method: create collection, create partition, and add entities in it, create index
+        expected: return search success
+        '''
+        connect.create_partition(collection, default_tag)
+        ids = connect.bulk_insert(collection, default_entities, partition_tag=default_tag)
+        connect.flush()
+        connect.create_index(collection, field_name, get_simple_index)
+
+    def test_create_index_without_connect(self, dis_connect, collection):
+        '''
+        target: test create index without connection
+        method: create index on the collection after the connection is dropped
+        expected: raise exception
+        '''
+        with pytest.raises(Exception) as e:
+            dis_connect.create_index(collection, field_name, default_index)
+
+    @pytest.mark.timeout(BUILD_TIMEOUT)
+    def test_create_index_search_with_query_vectors(self, connect, collection, get_simple_index, get_nq):
+        '''
+        target: test create index interface, search with more query vectors
+        method: create collection and add entities in it, create index
+        expected: return search success
+        '''
+        ids = connect.bulk_insert(collection, default_entities)
+        connect.create_index(collection, field_name, get_simple_index)
+        logging.getLogger().info(connect.get_collection_stats(collection))
+        nq = get_nq
+        index_type = get_simple_index["index_type"]
+        search_param = get_search_param(index_type)
+        query, vecs = gen_query_vectors(field_name, default_entities, default_top_k, nq, search_params=search_param)
+        res = connect.search(collection, query)
+        assert len(res) == nq
+
+    @pytest.mark.timeout(BUILD_TIMEOUT)
+    @pytest.mark.level(2)
+    def test_create_index_multithread(self, connect, collection, args):
+        '''
+        target: test create index interface with multiprocess
+        method: create collection and add entities in it, create index
+        expected: return search success
+        '''
+        connect.bulk_insert(collection, default_entities)
+
+        def build(connect):
+            connect.create_index(collection, field_name, default_index)
+
+        threads_num = 8
+        threads = []
+        for i in range(threads_num):
+            m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"])
+            t = MilvusTestThread(target=build, args=(m,))
+            threads.append(t)
+            t.start()
+            time.sleep(0.2)
+        for t in threads:
+            t.join()
+
+    def test_create_index_collection_not_existed(self, connect):
+        '''
+        target: test create index interface when collection name not existed
+        method: create collection and add entities in it, create index
+        , make sure the collection name not in index
+        expected: create index failed
+        '''
+        collection_name = gen_unique_str(uid)
+        with pytest.raises(Exception) as e:
+            connect.create_index(collection_name, field_name, default_index)
+
+
@pytest.mark.level(2) + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_insert_flush(self, connect, collection, get_simple_index): + ''' + target: test create index + method: create collection and create index, add entities in it + expected: create index ok, and count correct + ''' + connect.create_index(collection, field_name, get_simple_index) + ids = connect.bulk_insert(collection, default_entities) + connect.flush([collection]) + count = connect.count_entities(collection) + assert count == default_nb + + @pytest.mark.level(2) + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_same_index_repeatedly(self, connect, collection, get_simple_index): + ''' + target: check if index can be created repeatedly, with the same create_index params + method: create index after index have been built + expected: return code success, and search ok + ''' + connect.create_index(collection, field_name, get_simple_index) + connect.create_index(collection, field_name, get_simple_index) + + # TODO: + @pytest.mark.level(2) + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_different_index_repeatedly(self, connect, collection): + ''' + target: check if index can be created repeatedly, with the different create_index params + method: create another index with different index_params after index have been built + expected: return code 0, and describe index result equals with the second index params + ''' + ids = connect.bulk_insert(collection, default_entities) + indexs = [default_index, {"metric_type":"L2", "index_type": "FLAT", "params":{"nlist": 1024}}] + for index in indexs: + connect.create_index(collection, field_name, index) + stats = connect.get_collection_stats(collection) + # assert stats["partitions"][0]["segments"][0]["index_name"] == index["index_type"] + assert stats["row_count"] == default_nb + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_ip(self, connect, collection, get_simple_index): + ''' + target: test create index interface + method: create collection and add entities in it, create index + expected: return search success + ''' + ids = connect.bulk_insert(collection, default_entities) + get_simple_index["metric_type"] = "IP" + connect.create_index(collection, field_name, get_simple_index) + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_no_vectors_ip(self, connect, collection, get_simple_index): + ''' + target: test create index interface + method: create collection and add entities in it, create index + expected: return search success + ''' + get_simple_index["metric_type"] = "IP" + connect.create_index(collection, field_name, get_simple_index) + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_partition_ip(self, connect, collection, get_simple_index): + ''' + target: test create index interface + method: create collection, create partition, and add entities in it, create index + expected: return search success + ''' + connect.create_partition(collection, default_tag) + ids = connect.bulk_insert(collection, default_entities, partition_tag=default_tag) + connect.flush([collection]) + get_simple_index["metric_type"] = "IP" + connect.create_index(collection, field_name, get_simple_index) + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_partition_flush_ip(self, connect, collection, get_simple_index): + ''' + target: test create index interface + method: create collection, create partition, and add entities in it, create index + expected: return search success + ''' + connect.create_partition(collection, default_tag) + 
ids = connect.bulk_insert(collection, default_entities, partition_tag=default_tag) + connect.flush() + get_simple_index["metric_type"] = "IP" + connect.create_index(collection, field_name, get_simple_index) + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_search_with_query_vectors_ip(self, connect, collection, get_simple_index, get_nq): + ''' + target: test create index interface, search with more query vectors + method: create collection and add entities in it, create index + expected: return search success + ''' + metric_type = "IP" + ids = connect.bulk_insert(collection, default_entities) + get_simple_index["metric_type"] = metric_type + connect.create_index(collection, field_name, get_simple_index) + logging.getLogger().info(connect.get_collection_stats(collection)) + nq = get_nq + index_type = get_simple_index["index_type"] + search_param = get_search_param(index_type) + query, vecs = gen_query_vectors(field_name, default_entities, default_top_k, nq, metric_type=metric_type, search_params=search_param) + res = connect.search(collection, query) + assert len(res) == nq + + @pytest.mark.timeout(BUILD_TIMEOUT) + @pytest.mark.level(2) + def test_create_index_multithread_ip(self, connect, collection, args): + ''' + target: test create index interface with multiprocess + method: create collection and add entities in it, create index + expected: return search success + ''' + connect.bulk_insert(collection, default_entities) + + def build(connect): + default_index["metric_type"] = "IP" + connect.create_index(collection, field_name, default_index) + + threads_num = 8 + threads = [] + for i in range(threads_num): + m = get_milvus(host=args["ip"], port=args["port"], handler=args["handler"]) + t = MilvusTestThread(target=build, args=(m,)) + threads.append(t) + t.start() + time.sleep(0.2) + for t in threads: + t.join() + + def test_create_index_collection_not_existed_ip(self, connect, collection): + ''' + target: test create index interface when collection name not existed + method: create collection and add entities in it, create index + , make sure the collection name not in index + expected: return code not equals to 0, create index failed + ''' + collection_name = gen_unique_str(uid) + default_index["metric_type"] = "IP" + with pytest.raises(Exception) as e: + connect.create_index(collection_name, field_name, default_index) + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_no_vectors_insert_ip(self, connect, collection, get_simple_index): + ''' + target: test create index interface when there is no vectors in collection, and does not affect the subsequent process + method: create collection and add no vectors in it, and then create index, add entities in it + expected: return code equals to 0 + ''' + default_index["metric_type"] = "IP" + connect.create_index(collection, field_name, get_simple_index) + ids = connect.bulk_insert(collection, default_entities) + connect.flush([collection]) + count = connect.count_entities(collection) + assert count == default_nb + + @pytest.mark.level(2) + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_same_index_repeatedly_ip(self, connect, collection, get_simple_index): + ''' + target: check if index can be created repeatedly, with the same create_index params + method: create index after index have been built + expected: return code success, and search ok + ''' + default_index["metric_type"] = "IP" + connect.create_index(collection, field_name, get_simple_index) + connect.create_index(collection, field_name, get_simple_index) 
+
+    # TODO:
+    @pytest.mark.level(2)
+    @pytest.mark.timeout(BUILD_TIMEOUT)
+    def test_create_different_index_repeatedly_ip(self, connect, collection):
+        '''
+        target: check if index can be created repeatedly, with the different create_index params
+        method: create another index with different index_params after index have been built
+        expected: return code 0, and describe index result equals with the second index params
+        '''
+        ids = connect.bulk_insert(collection, default_entities)
+        indexs = [default_index, {"index_type": "FLAT", "params": {"nlist": 1024}, "metric_type": "IP"}]
+        for index in indexs:
+            connect.create_index(collection, field_name, index)
+            stats = connect.get_collection_stats(collection)
+            # assert stats["partitions"][0]["segments"][0]["index_name"] == index["index_type"]
+            assert stats["row_count"] == default_nb
+
+    """
+    ******************************************************************
+      The following cases are used to test `drop_index` function
+    ******************************************************************
+    """
+
+    def test_drop_index(self, connect, collection, get_simple_index):
+        '''
+        target: test drop index interface
+        method: create collection and add entities in it, create index, call drop index
+        expected: return code 0, and default index param
+        '''
+        # ids = connect.bulk_insert(collection, entities)
+        connect.create_index(collection, field_name, get_simple_index)
+        connect.drop_index(collection, field_name)
+        stats = connect.get_collection_stats(collection)
+        # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type
+        assert not stats["partitions"][0]["segments"]
+
+    @pytest.mark.level(2)
+    def test_drop_index_repeatly(self, connect, collection, get_simple_index):
+        '''
+        target: test drop index repeatedly
+        method: create index, call drop index, and drop again
+        expected: return code 0
+        '''
+        connect.create_index(collection, field_name, get_simple_index)
+        stats = connect.get_collection_stats(collection)
+        connect.drop_index(collection, field_name)
+        connect.drop_index(collection, field_name)
+        stats = connect.get_collection_stats(collection)
+        logging.getLogger().info(stats)
+        # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type
+        assert not stats["partitions"][0]["segments"]
+
+    @pytest.mark.level(2)
+    def test_drop_index_without_connect(self, dis_connect, collection):
+        '''
+        target: test drop index without connection
+        method: drop index, and check if drop successfully
+        expected: raise exception
+        '''
+        with pytest.raises(Exception) as e:
+            dis_connect.drop_index(collection, field_name)
+
+    def test_drop_index_collection_not_existed(self, connect):
+        '''
+        target: test drop index interface when collection name not existed
+        method: create collection and add entities in it, create index
+        , make sure the collection name not in index, and then drop it
+        expected: return code not equals to 0, drop index failed
+        '''
+        collection_name = gen_unique_str(uid)
+        with pytest.raises(Exception) as e:
+            connect.drop_index(collection_name, field_name)
+
+    def test_drop_index_collection_not_create(self, connect, collection):
+        '''
+        target: test drop index interface when index not created
+        method: create collection and add entities in it, create index
+        expected: return code not equals to 0, drop index failed
+        '''
+        # ids = connect.bulk_insert(collection, entities)
+        # no create index
+        connect.drop_index(collection, field_name)
+
+    @pytest.mark.level(2)
+    def test_create_drop_index_repeatly(self, connect, collection, get_simple_index):
+        '''
+        target: test create / drop index repeatedly, use the same index params
+        method: create index, drop index, four times
+        expected: return code 0
+        '''
+        for i in range(4):
+            connect.create_index(collection, field_name, get_simple_index)
+            connect.drop_index(collection, field_name)
+
+    def test_drop_index_ip(self, connect, collection, get_simple_index):
+        '''
+        target: test drop index interface
+        method: create collection and add entities in it, create index, call drop index
+        expected: return code 0, and default index param
+        '''
+        # ids = connect.bulk_insert(collection, entities)
+        get_simple_index["metric_type"] = "IP"
+        connect.create_index(collection, field_name, get_simple_index)
+        connect.drop_index(collection, field_name)
+        stats = connect.get_collection_stats(collection)
+        # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type
+        assert not stats["partitions"][0]["segments"]
+
+    @pytest.mark.level(2)
+    def test_drop_index_repeatly_ip(self, connect, collection, get_simple_index):
+        '''
+        target: test drop index repeatedly
+        method: create index, call drop index, and drop again
+        expected: return code 0
+        '''
+        get_simple_index["metric_type"] = "IP"
+        connect.create_index(collection, field_name, get_simple_index)
+        stats = connect.get_collection_stats(collection)
+        connect.drop_index(collection, field_name)
+        connect.drop_index(collection, field_name)
+        stats = connect.get_collection_stats(collection)
+        logging.getLogger().info(stats)
+        # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type
+        assert not stats["partitions"][0]["segments"]
+
+    @pytest.mark.level(2)
+    def test_drop_index_without_connect_ip(self, dis_connect, collection):
+        '''
+        target: test drop index without connection
+        method: drop index, and check if drop successfully
+        expected: raise exception
+        '''
+        with pytest.raises(Exception) as e:
+            dis_connect.drop_index(collection, field_name)
+
+    def test_drop_index_collection_not_create_ip(self, connect, collection):
+        '''
+        target: test drop index interface when index not created
+        method: create collection and add entities in it, create index
+        expected: return code not equals to 0, drop index failed
+        '''
+        # ids = connect.bulk_insert(collection, entities)
+        # no create index
+        connect.drop_index(collection, field_name)
+
+    @pytest.mark.level(2)
+    def test_create_drop_index_repeatly_ip(self, connect, collection, get_simple_index):
+        '''
+        target: test create / drop index repeatedly, use the same index params
+        method: create index, drop index, four times
+        expected: return code 0
+        '''
+        get_simple_index["metric_type"] = "IP"
+        for i in range(4):
+            connect.create_index(collection, field_name, get_simple_index)
+            connect.drop_index(collection, field_name)
+
+
+@pytest.mark.skip("binary")
+class TestIndexBinary:
+    @pytest.fixture(
+        scope="function",
+        params=gen_simple_index()
+    )
+    def get_simple_index(self, request, connect):
+        # TODO: Determine the service mode
+        # if str(connect._cmd("mode")) == "CPU":
+        if request.param["index_type"] in index_cpu_not_support():
+            pytest.skip("sq8h not support in CPU mode")
+        return request.param
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_binary_index()
+    )
+    def get_jaccard_index(self, request, connect):
+        if request.param["index_type"] in binary_support():
+            request.param["metric_type"] = "JACCARD"
+            return request.param
+        else:
+            pytest.skip("Skip index")
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_binary_index()
+    )
+    def get_l2_index(self, request, connect):
+        request.param["metric_type"] = "L2"
+        return request.param
+
+    @pytest.fixture(
+        scope="function",
+        params=[
+            1,
+            10,
+            1111
+        ],
+    )
+    def get_nq(self, request):
+        yield request.param
+
+    """
+    ******************************************************************
+      The following cases are used to test `create_index` function
+    ******************************************************************
+    """
+
+    @pytest.mark.timeout(BUILD_TIMEOUT)
+    def test_create_index(self, connect, binary_collection, get_jaccard_index):
+        '''
+        target: test create index interface
+        method: create collection and add entities in it, create index
+        expected: return search success
+        '''
+        ids = connect.bulk_insert(binary_collection, default_binary_entities)
+        connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
+
+    @pytest.mark.timeout(BUILD_TIMEOUT)
+    def test_create_index_partition(self, connect, binary_collection, get_jaccard_index):
+        '''
+        target: test create index interface
+        method: create collection, create partition, and add entities in it, create index
+        expected: return search success
+        '''
+        connect.create_partition(binary_collection, default_tag)
+        ids = connect.bulk_insert(binary_collection, default_binary_entities, partition_tag=default_tag)
+        connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
+
+    @pytest.mark.timeout(BUILD_TIMEOUT)
+    def test_create_index_search_with_query_vectors(self, connect, binary_collection, get_jaccard_index, get_nq):
+        '''
+        target: test create index interface, search with more query vectors
+        method: create collection and add entities in it, create index
+        expected: return search success
+        '''
+        nq = get_nq
+        ids = connect.bulk_insert(binary_collection, default_binary_entities)
+        connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
+        query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, nq, metric_type="JACCARD")
+        search_param = get_search_param(get_jaccard_index["index_type"], metric_type="JACCARD")
+        logging.getLogger().info(search_param)
+        res = connect.search(binary_collection, query, search_params=search_param)
+        assert len(res) == nq
+
+    @pytest.mark.timeout(BUILD_TIMEOUT)
+    def test_create_index_invalid_metric_type_binary(self, connect, binary_collection, get_l2_index):
+        '''
+        target: test create index interface with invalid metric type
+        method: add entities into binary collection, flush, create index with L2 metric type
+        expected: return create_index failure
+        '''
+        # insert 6000 vectors
+        ids = connect.bulk_insert(binary_collection, default_binary_entities)
+        connect.flush([binary_collection])
+
+        if get_l2_index["index_type"] == "BIN_FLAT":
+            res = connect.create_index(binary_collection, binary_field_name, get_l2_index)
+        else:
+            with pytest.raises(Exception) as e:
+                res = connect.create_index(binary_collection, binary_field_name, get_l2_index)
+
+    """
+    ******************************************************************
+      The following cases are used to test `get_index_info` function
+    ******************************************************************
+    """
+
+    def test_get_index_info(self, connect, binary_collection, get_jaccard_index):
+        '''
+        target: test describe index interface
+        method: create collection and add entities in it, create index, call describe index
+        expected: return code 0, and index structure
+        '''
+        ids = connect.bulk_insert(binary_collection, default_binary_entities)
+        connect.flush([binary_collection])
+        connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
+        stats = connect.get_collection_stats(binary_collection)
+        assert stats["row_count"] == default_nb
+        for partition in stats["partitions"]:
+            segments = partition["segments"]
+            if segments:
+                for segment in segments:
+                    for file in segment["files"]:
+                        if "index_type" in file:
+                            assert file["index_type"] == get_jaccard_index["index_type"]
+
+    def test_get_index_info_partition(self, connect, binary_collection, get_jaccard_index):
+        '''
+        target: test describe index interface
+        method: create collection, create partition and add entities in it, create index, call describe index
+        expected: return code 0, and index structure
+        '''
+        connect.create_partition(binary_collection, default_tag)
+        ids = connect.bulk_insert(binary_collection, default_binary_entities, partition_tag=default_tag)
+        connect.flush([binary_collection])
+        connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
+        stats = connect.get_collection_stats(binary_collection)
+        logging.getLogger().info(stats)
+        assert stats["row_count"] == default_nb
+        assert len(stats["partitions"]) == 2
+        for partition in stats["partitions"]:
+            segments = partition["segments"]
+            if segments:
+                for segment in segments:
+                    for file in segment["files"]:
+                        if "index_type" in file:
+                            assert file["index_type"] == get_jaccard_index["index_type"]
+
+    """
+    ******************************************************************
+      The following cases are used to test `drop_index` function
+    ******************************************************************
+    """
+
+    def test_drop_index(self, connect, binary_collection, get_jaccard_index):
+        '''
+        target: test drop index interface
+        method: create collection and add entities in it, create index, call drop index
+        expected: return code 0, and default index param
+        '''
+        connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
+        stats = connect.get_collection_stats(binary_collection)
+        logging.getLogger().info(stats)
+        connect.drop_index(binary_collection, binary_field_name)
+        stats = connect.get_collection_stats(binary_collection)
+        # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type
+        assert not stats["partitions"][0]["segments"]
+
+    def test_drop_index_partition(self, connect, binary_collection, get_jaccard_index):
+        '''
+        target: test drop index interface
+        method: create collection, create partition and add entities in it, create index on collection, call drop collection index
+        expected: return code 0, and default index param
+        '''
+        connect.create_partition(binary_collection, default_tag)
+        ids = connect.bulk_insert(binary_collection, default_binary_entities, partition_tag=default_tag)
+        connect.flush([binary_collection])
+        connect.create_index(binary_collection, binary_field_name, get_jaccard_index)
+        stats = connect.get_collection_stats(binary_collection)
+        connect.drop_index(binary_collection, binary_field_name)
+        stats = connect.get_collection_stats(binary_collection)
+        assert stats["row_count"] == default_nb
+        for partition in stats["partitions"]:
+            segments = partition["segments"]
+            if segments:
+                for segment in segments:
+                    for file in segment["files"]:
+                        if "index_type" not in file:
+                            continue
+                        if file["index_type"] == get_jaccard_index["index_type"]:
+                            assert False
+
+
+@pytest.mark.skip("wait for debugging...")
+class TestIndexInvalid(object):
+    """
+    Test create / describe / drop index interfaces with invalid collection names
+    """
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_invalid_strs()
+    )
+    def get_collection_name(self, request):
+        yield request.param
+
+    @pytest.mark.level(1)
+    def test_create_index_with_invalid_collectionname(self, connect, get_collection_name):
+        collection_name = get_collection_name
+        with pytest.raises(Exception) as e:
+            connect.create_index(collection_name, field_name, default_index)
+
+    @pytest.mark.level(1)
+    def test_drop_index_with_invalid_collectionname(self, connect, get_collection_name):
+        collection_name = get_collection_name
+        with pytest.raises(Exception) as e:
+            connect.drop_index(collection_name, field_name)
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_invalid_index()
+    )
+    def get_index(self, request):
+        yield request.param
+
+    @pytest.mark.level(2)
+    def test_create_index_with_invalid_index_params(self, connect, collection, get_index):
+        logging.getLogger().info(get_index)
+        with pytest.raises(Exception) as e:
+            connect.create_index(collection, field_name, get_index)
+
+
+@pytest.mark.skip("wait for debugging...")
+class TestIndexAsync:
+    @pytest.fixture(scope="function", autouse=True)
+    def skip_http_check(self, args):
+        if args["handler"] == "HTTP":
+            pytest.skip("skip in http mode")
+
+    """
+    ******************************************************************
+      The following cases are used to test `create_index` function
+    ******************************************************************
+    """
+
+    @pytest.fixture(
+        scope="function",
+        params=gen_simple_index()
+    )
+    def get_simple_index(self, request, connect):
+        # TODO: Determine the service mode
+        # if str(connect._cmd("mode")) == "CPU":
+        if request.param["index_type"] in index_cpu_not_support():
+            pytest.skip("sq8h not support in CPU mode")
+        return request.param
+
+    def check_result(self, res):
+        logging.getLogger().info("In callback check search result")
+        logging.getLogger().info(res)
+
+    """
+    ******************************************************************
+      The following cases are used to test `create_index` function
+    ******************************************************************
+    """
+
+    @pytest.mark.timeout(BUILD_TIMEOUT)
+    def test_create_index(self, connect, collection, get_simple_index):
+        '''
+        target: test create index interface
+        method: create collection and add entities in it, create index
+        expected: return search success
+        '''
+        ids = connect.bulk_insert(collection, default_entities)
+        logging.getLogger().info("start index")
+        future = connect.create_index(collection,
field_name, get_simple_index, _async=True) + logging.getLogger().info("before result") + res = future.result() + # TODO: + logging.getLogger().info(res) + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_drop(self, connect, collection, get_simple_index): + ''' + target: test create index interface + method: create collection and add entities in it, create index + expected: return search success + ''' + ids = connect.bulk_insert(collection, default_entities) + logging.getLogger().info("start index") + future = connect.create_index(collection, field_name, get_simple_index, _async=True) + logging.getLogger().info("DROP") + connect.drop_collection(collection) + + @pytest.mark.level(2) + def test_create_index_with_invalid_collectionname(self, connect): + collection_name = " " + future = connect.create_index(collection_name, field_name, default_index, _async=True) + with pytest.raises(Exception) as e: + res = future.result() + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_callback(self, connect, collection, get_simple_index): + ''' + target: test create index interface + method: create collection and add entities in it, create index + expected: return search success + ''' + ids = connect.bulk_insert(collection, default_entities) + logging.getLogger().info("start index") + future = connect.create_index(collection, field_name, get_simple_index, _async=True, + _callback=self.check_result) + logging.getLogger().info("before result") + res = future.result() + # TODO: + logging.getLogger().info(res)