diff --git a/.devcontainer.json b/.devcontainer.json index fff529de85e22edb4389a9a156e547576f704c11..a7368ab60cbd5825cf39f67cb0acdc358fa3af48 100644 --- a/.devcontainer.json +++ b/.devcontainer.json @@ -2,7 +2,7 @@ "name": "Milvus Distributed Dev Container Definition", "dockerComposeFile": ["./docker-compose-vscode.yml"], "service": "ubuntu", - "initializeCommand": "scripts/init_devcontainer.sh && docker-compose -f docker-compose-vscode.yml down || true && docker-compose -f docker-compose-vscode.yml pull --ignore-pull-failures ubuntu", + "initializeCommand": "scripts/init_devcontainer.sh && docker-compose -f docker-compose-vscode.yml down || true", "workspaceFolder": "/go/src/github.com/zilliztech/milvus-distributed", "shutdownAction": "stopCompose", "extensions": [ diff --git a/.gitignore b/.gitignore index 9c71802f8f82d8a89de42b91fa00f6de8f4e8c0c..6a495ae331957a89592553341ac6319ad71e6fa8 100644 --- a/.gitignore +++ b/.gitignore @@ -12,7 +12,7 @@ pulsar/client-cpp/build/* # vscode generated files .vscode docker-compose-vscode.yml -docker-compose-vscode.yml.bak +docker-compose-vscode.yml.tmp cmake-build-debug cmake-build-release diff --git a/.jenkins/modules/Build/Build.groovy b/.jenkins/modules/Build/Build.groovy index 14fd0b9cdd619f108cc65f966543df9df6272716..1d35fd71ffd92f1cd8317688aa957eeb61092e26 100644 --- a/.jenkins/modules/Build/Build.groovy +++ b/.jenkins/modules/Build/Build.groovy @@ -1,22 +1,20 @@ timeout(time: 20, unit: 'MINUTES') { - dir ("scripts") { - sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./check_cache.sh -l $CCACHE_ARTFACTORY_URL --cache_dir=\$CCACHE_DIR -f ccache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz || echo \"Ccache artfactory files not found!\"' - sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./check_cache.sh -l $GO_CACHE_ARTFACTORY_URL --cache_dir=\$(go env GOCACHE) -f go-cache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz || echo \"Go cache artfactory files not found!\"' - sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./check_cache.sh -l $THIRDPARTY_ARTFACTORY_URL --cache_dir=$CUSTOM_THIRDPARTY_PATH -f thirdparty-download.tar.gz || echo \"Thirdparty artfactory files not found!\"' - sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./check_cache.sh -l $GO_MOD_ARTFACTORY_URL --cache_dir=\$GOPATH/pkg/mod -f milvus-distributed-go-mod-cache.tar.gz || echo \"Go mod artfactory files not found!\"' - } + + sh '. ./scripts/before-install.sh && unset http_proxy && unset https_proxy && ./scripts/check_cache.sh -l $CCACHE_ARTFACTORY_URL --cache_dir=\$CCACHE_DIR -f ccache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz || echo \"Ccache artfactory files not found!\"' + sh '. ./scripts/before-install.sh && unset http_proxy && unset https_proxy && ./scripts/check_cache.sh -l $GO_CACHE_ARTFACTORY_URL --cache_dir=\$(go env GOCACHE) -f go-cache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz || echo \"Go cache artfactory files not found!\"' + sh '. ./scripts/before-install.sh && unset http_proxy && unset https_proxy && ./scripts/check_cache.sh -l $THIRDPARTY_ARTFACTORY_URL --cache_dir=$CUSTOM_THIRDPARTY_PATH -f thirdparty-download.tar.gz || echo \"Thirdparty artfactory files not found!\"' + sh '. 
./scripts/before-install.sh && unset http_proxy && unset https_proxy && ./scripts/check_cache.sh -l $GO_MOD_ARTFACTORY_URL --cache_dir=\$GOPATH/pkg/mod -f milvus-distributed-go-mod-\$(md5sum go.mod).tar.gz || echo \"Go mod artfactory files not found!\"' // Zero the cache statistics (but not the configuration options) sh 'ccache -z' sh '. ./scripts/before-install.sh && make install' sh 'echo -e "===\n=== ccache statistics after build\n===" && ccache --show-stats' - dir ("scripts") { - withCredentials([usernamePassword(credentialsId: "${env.JFROG_CREDENTIALS_ID}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) { - sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./update_cache.sh -l $CCACHE_ARTFACTORY_URL --cache_dir=\$CCACHE_DIR -f ccache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz -u ${USERNAME} -p ${PASSWORD}' - sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./update_cache.sh -l $GO_CACHE_ARTFACTORY_URL --cache_dir=\$(go env GOCACHE) -f go-cache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz -u ${USERNAME} -p ${PASSWORD}' - sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./update_cache.sh -l $THIRDPARTY_ARTFACTORY_URL --cache_dir=$CUSTOM_THIRDPARTY_PATH -f thirdparty-download.tar.gz -u ${USERNAME} -p ${PASSWORD}' - sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./update_cache.sh -l $GO_MOD_ARTFACTORY_URL --cache_dir=\$GOPATH/pkg/mod -f milvus-distributed-go-mod-cache.tar.gz -u ${USERNAME} -p ${PASSWORD}' - } + + withCredentials([usernamePassword(credentialsId: "${env.JFROG_CREDENTIALS_ID}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) { + sh '. ./scripts/before-install.sh && unset http_proxy && unset https_proxy && ./scripts/update_cache.sh -l $CCACHE_ARTFACTORY_URL --cache_dir=\$CCACHE_DIR -f ccache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz -u ${USERNAME} -p ${PASSWORD}' + sh '. ./scripts/before-install.sh && unset http_proxy && unset https_proxy && ./scripts/update_cache.sh -l $GO_CACHE_ARTFACTORY_URL --cache_dir=\$(go env GOCACHE) -f go-cache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz -u ${USERNAME} -p ${PASSWORD}' + sh '. ./scripts/before-install.sh && unset http_proxy && unset https_proxy && ./scripts/update_cache.sh -l $THIRDPARTY_ARTFACTORY_URL --cache_dir=$CUSTOM_THIRDPARTY_PATH -f thirdparty-download.tar.gz -u ${USERNAME} -p ${PASSWORD}' + sh '. 
./scripts/before-install.sh && unset http_proxy && unset https_proxy && ./scripts/update_cache.sh -l $GO_MOD_ARTFACTORY_URL --cache_dir=\$GOPATH/pkg/mod -f milvus-distributed-go-mod-\$(md5sum go.mod).tar.gz -u ${USERNAME} -p ${PASSWORD}' } } diff --git a/.jenkins/modules/Regression/PythonRegression.groovy b/.jenkins/modules/Regression/PythonRegression.groovy index f726d92da14945da8b90bbf5ceb8e546860b4ad2..89a16f22eaad6dfdeb17c106af2d367aaac908d2 100644 --- a/.jenkins/modules/Regression/PythonRegression.groovy +++ b/.jenkins/modules/Regression/PythonRegression.groovy @@ -31,12 +31,14 @@ try { } catch(exc) { throw exc } finally { - sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} rm -f -s -v pulsar' - sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} rm -f -s -v etcd' - sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} rm -f -s -v minio' dir ('build/docker/deploy') { + sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} ps | tail -n +3 | awk \'{ print $1 }\' | ( while read arg; do docker logs -t $arg > $arg.log 2>&1; done )' + archiveArtifacts artifacts: "**.log", allowEmptyArchive: true sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} down --rmi all -v || true' } + sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} rm -f -s -v pulsar' + sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} rm -f -s -v etcd' + sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} rm -f -s -v minio' dir ('build/docker/test') { sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run --rm regression /bin/bash -c "rm -rf __pycache__ && rm -rf .pytest_cache"' sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} down --rmi all -v || true' diff --git a/Makefile b/Makefile index 4d8a81973f8b43fbe9f65380cae1bedbc54f2073..2f758ea022b162d48ec386a5aa9f9583f1658d71 100644 --- a/Makefile +++ b/Makefile @@ -85,7 +85,7 @@ endif verifiers: getdeps cppcheck fmt static-check # Builds various components locally. -build-go: build-cpp +build-go: build-cpp get-rocksdb @echo "Building each component's binary to './bin'" @echo "Building master ..." @mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="0" && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/master $(PWD)/cmd/master/main.go 1>/dev/null diff --git a/build/docker/deploy/.env b/build/docker/deploy/.env index fe63196a782ecb9306dadc8667bba14cb49a91c0..5ce139901f79681aadadd359fc94b2f811b86de8 100644 --- a/build/docker/deploy/.env +++ b/build/docker/deploy/.env @@ -6,4 +6,4 @@ PULSAR_ADDRESS=pulsar://pulsar:6650 ETCD_ADDRESS=etcd:2379 MASTER_ADDRESS=master:53100 MINIO_ADDRESS=minio:9000 -INDEX_BUILDER_ADDRESS=indexbuider:31000 +INDEX_BUILDER_ADDRESS=indexbuilder:31000 diff --git a/build/docker/deploy/indexbuilder/DockerFile b/build/docker/deploy/indexbuilder/DockerFile index 4d40f27a981265ebcdc156941f1d82991657e6c0..26804b1e4bea655e60af58f98247de599e496a00 100644 --- a/build/docker/deploy/indexbuilder/DockerFile +++ b/build/docker/deploy/indexbuilder/DockerFile @@ -9,12 +9,31 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under the License. 
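The Jenkins steps above now embed `$(md5sum go.mod)` in the Go-module cache archive name, so the cached modules are refreshed whenever `go.mod` changes rather than being reused forever. A minimal Go sketch of the same content-keyed idea, for illustration only (the CI itself simply interpolates the shell `md5sum` output; this helper is not part of the patch):

```go
// Illustrative sketch: derive a Go-module cache archive name from the md5 of
// go.mod, mirroring the $(md5sum go.mod) key used in the Jenkins steps above.
package main

import (
	"crypto/md5"
	"fmt"
	"io/ioutil"
	"log"
)

func goModCacheKey(goModPath string) (string, error) {
	data, err := ioutil.ReadFile(goModPath)
	if err != nil {
		return "", err
	}
	sum := md5.Sum(data)
	return fmt.Sprintf("milvus-distributed-go-mod-%x.tar.gz", sum), nil
}

func main() {
	key, err := goModCacheKey("go.mod")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(key) // e.g. milvus-distributed-go-mod-<md5>.tar.gz
}
```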
-FROM alpine:3.12.1 +FROM milvusdb/milvus-distributed-dev:amd64-ubuntu18.04-latest AS openblas + +#FROM alpine +FROM ubuntu:bionic-20200921 + +RUN apt-get update && apt-get install -y --no-install-recommends libtbb-dev gfortran + +#RUN echo "http://dl-cdn.alpinelinux.org/alpine/edge/testing" >> /etc/apk/repositories + +#RUN sed -i "s/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g" /etc/apk/repositories \ +# && apk add --no-cache libtbb gfortran + +COPY --from=openblas /usr/lib/libopenblas-r0.3.9.so /usr/lib/ + +RUN ln -s /usr/lib/libopenblas-r0.3.9.so /usr/lib/libopenblas.so.0 && \ + ln -s /usr/lib/libopenblas.so.0 /usr/lib/libopenblas.so COPY ./bin/indexbuilder /milvus-distributed/bin/indexbuilder COPY ./configs/ /milvus-distributed/configs/ +COPY ./lib/ /milvus-distributed/lib/ + +ENV LD_LIBRARY_PATH=/milvus-distributed/lib:$LD_LIBRARY_PATH:/usr/lib + WORKDIR /milvus-distributed/ CMD ["./bin/indexbuilder"] diff --git a/docs/developer_guides/chap03_index_service.md b/docs/developer_guides/chap03_index_service.md index 3edf51088d332e22257df6c2a18756556494560c..22f3f6671581768e125a097732c135b6284ae21f 100644 --- a/docs/developer_guides/chap03_index_service.md +++ b/docs/developer_guides/chap03_index_service.md @@ -13,8 +13,8 @@ ```go type Client interface { BuildIndex(req BuildIndexRequest) (BuildIndexResponse, error) - DescribeIndex(indexID UniqueID) (IndexDescription, error) - GetIndexFilePaths(indexID UniqueID) (IndexFilePaths, error) + GetIndexStates(req IndexStatesRequest) (IndexStatesResponse, error) + GetIndexFilePaths(req IndexFilePathRequest) (IndexFilePathsResponse, error) } ``` @@ -36,19 +36,23 @@ type BuildIndexResponse struct { -* *DescribeIndex* +* *GetIndexStates* ```go -enum IndexStatus { +type IndexStatesRequest struct { + IndexID UniqueID +} + +enum IndexState { NONE = 0; UNISSUED = 1; INPROGRESS = 2; FINISHED = 3; } -type IndexDescription struct { +type IndexStatesResponse struct { ID UniqueID - Status IndexStatus + State IndexState EnqueueTime time.Time ScheduleTime time.Time BuildCompleteTime time.Time @@ -60,7 +64,11 @@ type IndexDescription struct { * *GetIndexFilePaths* ```go -type IndexFilePaths struct { +type IndexFilePathRequest struct { + IndexID UniqueID +} + +type IndexFilePathsResponse struct { FilePaths []string } ``` @@ -74,7 +82,7 @@ type IndexNode interface { Start() error Close() error - SetTimeTickChannel(channelID string) error +// SetTimeTickChannel(channelID string) error SetStatsChannel(channelID string) error BuildIndex(req BuildIndexRequest) (BuildIndexResponse, error) diff --git a/docs/developer_guides/chap04_message_stream.md b/docs/developer_guides/chap04_message_stream.md index b97321a77c5444f1226c25eb0c00d87731083078..31fa4aa10d6c8e495677ce55111668e8b3a8f8d8 100644 --- a/docs/developer_guides/chap04_message_stream.md +++ b/docs/developer_guides/chap04_message_stream.md @@ -6,13 +6,13 @@ -#### 8.2 API +#### 8.2 Message Stream Service API ```go type Client interface { - CreateChannels(req CreateChannelRequest) (ChannelID []string, error) - DestoryChannels(channelID []string) error - DescribeChannels(channelID []string) (ChannelDescriptions, error) + CreateChannels(req CreateChannelRequest) (CreateChannelResponse, error) + DestoryChannels(req DestoryChannelRequest) error + DescribeChannels(req DescribeChannelRequest) (DescribeChannelResponse, error) } ``` @@ -28,7 +28,19 @@ type OwnerDescription struct { type CreateChannelRequest struct { OwnerDescription OwnerDescription - numChannels int + NumChannels int +} + +type CreateChannelResponse struct { 
+ ChannelIDs []string +} +``` + +* *DestoryChannels* + +```go +type DestoryChannelRequest struct { + ChannelIDs []string } ``` @@ -37,11 +49,16 @@ type CreateChannelRequest struct { * *DescribeChannels* ```go +type DescribeChannelRequest struct { + ChannelIDs []string +} + type ChannelDescription struct { + ChannelID string Owner OwnerDescription } -type ChannelDescriptions struct { +type DescribeChannelResponse struct { Descriptions []ChannelDescription } ``` @@ -56,7 +73,7 @@ const { kInsert MsgType = 400 kDelete MsgType = 401 kSearch MsgType = 500 - KSearchResult MsgType = 1000 + kSearchResult MsgType = 1000 kSegStatistics MsgType = 1100 diff --git a/docs/developer_guides/chap05_proxy.md b/docs/developer_guides/chap05_proxy.md index 4ea764d7fe1d44afd7f95018ec6e58676bec5281..d9f15af375da0c891a0aa314871c0771bb229965 100644 --- a/docs/developer_guides/chap05_proxy.md +++ b/docs/developer_guides/chap05_proxy.md @@ -51,13 +51,16 @@ type ProxyNode interface { CreateCollection(req CreateCollectionRequest) error DropCollection(req DropCollectionRequest) error HasCollection(req HasCollectionRequest) (bool, error) - DescribeCollection(req DescribeCollectionRequest) (CollectionDescription, error) + DescribeCollection(req DescribeCollectionRequest) (DescribeCollectionResponse, error) + GetCollectionStatistics(req CollectionStatsRequest) (CollectionStatsResponse, error) ShowCollections(req ShowCollectionRequest) ([]string, error) + CreatePartition(req CreatePartitionRequest) error DropPartition(req DropPartitionRequest) error HasPartition(req HasPartitionRequest) (bool, error) - DescribePartition(req DescribePartitionRequest) (PartitionDescription, error) + GetPartitionStatistics(req PartitionStatsRequest) (PartitionStatsResponse, error) ShowPartitions(req ShowPartitionRequest) ([]string, error) + CreateIndex(req CreateIndexRequest) error DescribeIndex(DescribeIndexRequest) (DescribeIndexResponse, error) diff --git a/docs/developer_guides/chap06_master.md b/docs/developer_guides/chap06_master.md index 28c3b599de0fa43baf893702b9341dfba2a75cc1..1e55ebb925882799b6eee0c3e80ddd8da0d1e77c 100644 --- a/docs/developer_guides/chap06_master.md +++ b/docs/developer_guides/chap06_master.md @@ -12,15 +12,22 @@ type Client interface { CreateCollection(req CreateCollectionRequest) error DropCollection(req DropCollectionRequest) error HasCollection(req HasCollectionRequest) (bool, error) - DescribeCollection(req DescribeCollectionRequest) (CollectionDescription, error) - ShowCollections(req ShowCollectionRequest) ([]string, error) + DescribeCollection(req DescribeCollectionRequest) (CollectionDescriptionResponse, error) + GetCollectionStatistics(req CollectionStatsRequest) (CollectionStatsResponse, error) + ShowCollections(req ShowCollectionRequest) (ShowCollectionResponse, error) + CreatePartition(req CreatePartitionRequest) error DropPartition(req DropPartitionRequest) error HasPartition(req HasPartitionRequest) (bool, error) - DescribePartition(req DescribePartitionRequest) (PartitionDescription, error) - ShowPartitions(req ShowPartitionRequest) ([]string, error) + GetPartitionStatistics(req PartitionStatsRequest) (PartitionStatsResponse, error) + ShowPartitions(req ShowPartitionRequest) (ShowPartitionResponse, error) + + CreateIndex(req CreateIndexRequest) error + DescribeIndex(DescribeIndexRequest) (DescribeIndexResponse, error) + AllocTimestamp(req TsoRequest) (TsoResponse, error) AllocID(req IDRequest) (IDResponse, error) + GetDdChannel() (string, error) GetTimeTickChannel() (string, error) GetStatsChannel() 
(string, error) @@ -29,6 +36,81 @@ type Client interface { +* *DescribeCollection* + +```go +type DescribeCollectionRequest struct { + CollectionName string +} + +type CollectionDescriptionResponse struct { + Schema CollectionSchema +} +``` + +* *GetCollectionStatistics* + +```go +type CollectionStatsRequest struct { + CollectionName string +} + +type CollectionStatsResponse struct { + Stats []KeyValuePair +} +``` + +* *ShowCollections* + +```go +type ShowCollectionResponse struct { + CollectionNames []string +} +``` + +* *GetPartitionStatistics* + +```go +type PartitionStatsRequest struct { + CollectionName string + PartitionTag string +} + +type PartitionStatsResponse struct { + Stats []KeyValuePair +} +``` + +* *ShowPartitions* + +```go +type ShowPartitionResponse struct { + PartitionTags []string +} +``` + +* *DescribeIndex* + +```go +type DescribeIndexRequest struct { + CollectionName string + FieldName string +} + +type IndexDescription struct { + IndexName string + params []KeyValuePair +} + +type DescribeIndexResponse struct { + IndexDescriptions []IndexDescription +} +``` + + + + + #### 10.1 Interfaces (RPC) | RPC | description | diff --git a/docs/developer_guides/chap07_query_service.md b/docs/developer_guides/chap07_query_service.md index 39a9bb7a5d270772da086d34bddd59ad19293534..be89e767978cb83a659241d364db61c660f0b096 100644 --- a/docs/developer_guides/chap07_query_service.md +++ b/docs/developer_guides/chap07_query_service.md @@ -8,13 +8,15 @@ -#### 8.2 API +#### 8.2 Query Service API ```go type Client interface { RegisterNode(req NodeInfo) (InitParams, error) - DescribeService() (ServiceDescription, error) - DescribeParition(req DescribeParitionRequest) (PartitionDescriptions, error) + GetServiceStates() (ServiceStatesResponse, error) + ShowCollections(req ShowCollectionRequest) (ShowCollectionResponse, error) + ShowPartitions(req ShowPartitionRequest) (ShowPartitionResponse, error) + GetPartitionStates(req PartitionStatesRequest) (PartitionStatesResponse, error) LoadPartitions(req LoadPartitonRequest) error ReleasePartitions(req ReleasePartitionRequest) error CreateQueryChannel() (QueryChannels, error) @@ -33,7 +35,7 @@ type NodeInfo struct {} type InitParams struct {} ``` -* *DescribeService* +* *GetServiceStates* ```go type NodeState = int @@ -44,36 +46,51 @@ const ( ABNORMAL NodeState = 2 ) -type QueryNodeDescription struct { +//type ResourceCost struct { +// MemUsage int64 +// CpuUsage float32 +//} + +type QueryNodeStates struct { NodeState NodeState - ResourceCost ResourceCost + //ResourceCost ResourceCost } -type CollectionDescription struct { - ParitionIDs []UniqueID +type ServiceStatesResponse struct { + ServiceState NodeState } +``` -type DbDescription struct { - CollectionDescriptions []CollectionDescription +* *ShowCollections* + +```go +type ShowCollectionRequest struct { + DbID UniqueID } -type ServiceDescription struct { - DbDescriptions map[UniqueID]DbDescription - NodeDescriptions map[UniqueID]QueryNodeDescription +type ShowCollectionResponse struct { + CollectionIDs []UniqueID } ``` - - -* *DescribeParition* +* *ShowPartitions* ```go -type DescribeParitionRequest struct { +type ShowPartitionRequest struct { DbID UniqueID CollectionID UniqueID - partitionIDs []UniqueID } +type ShowPartitionResponse struct { + PartitionIDs []UniqueID +} +``` + + + +* *GetPartitionStates* + +```go type PartitionState = int const ( @@ -86,19 +103,19 @@ const ( IN_GPU PartitionState = 6 ) -type ResourceCost struct { - MemUsage int64 - CpuUsage float32 +type 
PartitionStatesRequest struct { + DbID UniqueID + CollectionID UniqueID + partitionIDs []UniqueID } -type PartitionDescription struct { - ID UniqueID +type PartitionStates struct { + PartitionID UniqueID State PartitionState - ResourceCost ResourceCost } -type PartitionDescriptions struct { - PartitionDescriptions []PartitionDescription +type PartitionStatesResponse struct { + States []PartitionStates } ``` diff --git a/docs/developer_guides/chap09_data_service.md b/docs/developer_guides/chap09_data_service.md index df8c789767d5d68fe9c958e428f963ceb80c2f7b..dc58c2a8194aa44c40c9709cd011550e215f408d 100644 --- a/docs/developer_guides/chap09_data_service.md +++ b/docs/developer_guides/chap09_data_service.md @@ -15,7 +15,10 @@ type Client interface { RegisterNode(req NodeInfo) (InitParams, error) AssignSegmentID(req AssignSegIDRequest) (AssignSegIDResponse, error) Flush(req FlushRequest) error + ShowSegments(req ShowSegmentRequest) (ShowSegmentResponse, error) + GetSegmentStates(req SegmentStatesRequest) (SegmentStatesResponse, error) GetInsertBinlogPaths(req InsertBinlogPathRequest) (InsertBinlogPathsResponse, error) + GetInsertChannels(req InsertChannelRequest) ([]string, error) GetTimeTickChannel() (string, error) GetStatsChannel() (string, error) @@ -73,15 +76,53 @@ type FlushRequest struct { +* *ShowSegments* + +```go +type ShowSegmentRequest struct { + CollectionID UniqueID + PartitionID UniqueID +} + +type ShowSegmentResponse struct { + SegmentIDs []UniqueID +} +``` + + + +* *GetSegmentStates* + +```go +enum SegmentState { + NONE = 0; + NOT_EXIST = 1; + GROWING = 2; + SEALED = 3; +} + +type SegmentStatesRequest struct { + SegmentID UniqueID +} + +type SegmentStatesResponse struct { + State SegmentState + CreateTime Timestamp + SealedTime Timestamp +} +``` + + + * *GetInsertBinlogPaths* ```go type InsertBinlogPathRequest struct { - segmentID UniqueID + SegmentID UniqueID } type InsertBinlogPathsResponse struct { - FieldIdxToPaths map[int32][]string + FieldIDToPaths map[int64][]string } ``` diff --git a/internal/allocator/allocator.go b/internal/allocator/allocator.go index 9cd7e935d718e19b84a7423c66cc8775c34c298b..a1aec5d8a234cc8db4388ec014b25986c17620c0 100644 --- a/internal/allocator/allocator.go +++ b/internal/allocator/allocator.go @@ -137,7 +137,10 @@ type Allocator struct { } func (ta *Allocator) Start() error { - err := ta.connectMaster() + connectMasterFn := func() error { + return ta.connectMaster() + } + err := Retry(10, time.Millisecond*200, connectMasterFn) if err != nil { panic("connect to master failed") } diff --git a/internal/allocator/retry.go b/internal/allocator/retry.go new file mode 100644 index 0000000000000000000000000000000000000000..89ab43cd00d3cffc036fb4a84f237ab8c7df5f89 --- /dev/null +++ b/internal/allocator/retry.go @@ -0,0 +1,40 @@ +package allocator + +import ( + "log" + "time" +) + +// Reference: https://blog.cyeam.com/golang/2018/08/27/retry + +func RetryImpl(attempts int, sleep time.Duration, fn func() error, maxSleepTime time.Duration) error { + if err := fn(); err != nil { + if s, ok := err.(InterruptError); ok { + return s.error + } + + if attempts--; attempts > 0 { + log.Printf("retry func error: %s. 
attempts #%d after %s.", err.Error(), attempts, sleep) + time.Sleep(sleep) + if sleep < maxSleepTime { + return RetryImpl(attempts, 2*sleep, fn, maxSleepTime) + } + return RetryImpl(attempts, maxSleepTime, fn, maxSleepTime) + } + return err + } + return nil +} + +func Retry(attempts int, sleep time.Duration, fn func() error) error { + maxSleepTime := time.Millisecond * 1000 + return RetryImpl(attempts, sleep, fn, maxSleepTime) +} + +type InterruptError struct { + error +} + +func NoRetryError(err error) InterruptError { + return InterruptError{err} +} diff --git a/internal/core/src/indexbuilder/IndexWrapper.cpp b/internal/core/src/indexbuilder/IndexWrapper.cpp index 4484d1e119bc4853e5cd49df6396047e48df8e8f..6a5db6b1c67fdc4d7685e65b1024d52e04dcb1a2 100644 --- a/internal/core/src/indexbuilder/IndexWrapper.cpp +++ b/internal/core/src/indexbuilder/IndexWrapper.cpp @@ -19,6 +19,7 @@ #include "utils/EasyAssert.h" #include "IndexWrapper.h" #include "indexbuilder/utils.h" +#include "index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.h" namespace milvus { namespace indexbuilder { @@ -29,10 +30,10 @@ IndexWrapper::IndexWrapper(const char* serialized_type_params, const char* seria parse(); - std::map<std::string, knowhere::IndexMode> mode_map = {{"CPU", knowhere::IndexMode::MODE_CPU}, - {"GPU", knowhere::IndexMode::MODE_GPU}}; - auto mode = get_config_by_name<std::string>("index_mode"); - auto index_mode = mode.has_value() ? mode_map[mode.value()] : knowhere::IndexMode::MODE_CPU; + auto index_mode = get_index_mode(); + auto index_type = get_index_type(); + auto metric_type = get_metric_type(); + AssertInfo(!is_unsupported(index_type, metric_type), index_type + " doesn't support metric: " + metric_type); index_ = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(get_index_type(), index_mode); Assert(index_ != nullptr); @@ -154,6 +155,11 @@ IndexWrapper::dim() { void IndexWrapper::BuildWithoutIds(const knowhere::DatasetPtr& dataset) { auto index_type = get_index_type(); + auto index_mode = get_index_mode(); + config_[knowhere::meta::ROWS] = dataset->Get<int64_t>(knowhere::meta::ROWS); + auto conf_adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type); + AssertInfo(conf_adapter->CheckTrain(config_, index_mode), "something wrong in index parameters!"); + if (is_in_need_id_list(index_type)) { PanicInfo(std::string(index_type) + " doesn't support build without ids yet!"); } @@ -173,6 +179,11 @@ IndexWrapper::BuildWithoutIds(const knowhere::DatasetPtr& dataset) { void IndexWrapper::BuildWithIds(const knowhere::DatasetPtr& dataset) { Assert(dataset->data().find(milvus::knowhere::meta::IDS) != dataset->data().end()); + auto index_type = get_index_type(); + auto index_mode = get_index_mode(); + config_[knowhere::meta::ROWS] = dataset->Get<int64_t>(knowhere::meta::ROWS); + auto conf_adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type); + AssertInfo(conf_adapter->CheckTrain(config_, index_mode), "something wrong in index parameters!"); // index_->Train(dataset, config_); // index_->Add(dataset, config_); index_->BuildAll(dataset, config_); @@ -263,6 +274,31 @@ IndexWrapper::get_index_type() { return type.has_value() ? 
type.value() : knowhere::IndexEnum::INDEX_FAISS_IVFPQ; } +std::string +IndexWrapper::get_metric_type() { + auto type = get_config_by_name<std::string>(knowhere::Metric::TYPE); + if (type.has_value()) { + return type.value(); + } else { + auto index_type = get_index_type(); + if (is_in_bin_list(index_type)) { + return knowhere::Metric::JACCARD; + } else { + return knowhere::Metric::L2; + } + } +} + +knowhere::IndexMode +IndexWrapper::get_index_mode() { + static std::map<std::string, knowhere::IndexMode> mode_map = { + {"CPU", knowhere::IndexMode::MODE_CPU}, + {"GPU", knowhere::IndexMode::MODE_GPU}, + }; + auto mode = get_config_by_name<std::string>("index_mode"); + return mode.has_value() ? mode_map[mode.value()] : knowhere::IndexMode::MODE_CPU; +} + std::unique_ptr<IndexWrapper::QueryResult> IndexWrapper::Query(const knowhere::DatasetPtr& dataset) { return std::move(QueryImpl(dataset, config_)); diff --git a/internal/core/src/indexbuilder/IndexWrapper.h b/internal/core/src/indexbuilder/IndexWrapper.h index 16f2721712c655bff7b2e7d53a235e32ed1d6458..8bf2ed881c8a7e2ee4f86cdb88da7200e47ebfbe 100644 --- a/internal/core/src/indexbuilder/IndexWrapper.h +++ b/internal/core/src/indexbuilder/IndexWrapper.h @@ -59,6 +59,12 @@ class IndexWrapper { std::string get_index_type(); + std::string + get_metric_type(); + + knowhere::IndexMode + get_index_mode(); + template <typename T> std::optional<T> get_config_by_name(std::string name); diff --git a/internal/core/src/indexbuilder/index_c.cpp b/internal/core/src/indexbuilder/index_c.cpp index 217372700bb56faa63e24816d3b17ac791137fa0..e01d98989768acde93cab89f1ee8cb3fe0cc854f 100644 --- a/internal/core/src/indexbuilder/index_c.cpp +++ b/internal/core/src/indexbuilder/index_c.cpp @@ -35,7 +35,7 @@ CreateIndex(const char* serialized_type_params, const char* serialized_index_par *res_index = index.release(); status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -59,7 +59,7 @@ BuildFloatVecIndexWithoutIds(CIndex index, int64_t float_value_num, const float* cIndex->BuildWithoutIds(ds); status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -77,7 +77,7 @@ BuildBinaryVecIndexWithoutIds(CIndex index, int64_t data_size, const uint8_t* ve cIndex->BuildWithoutIds(ds); status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -94,7 +94,7 @@ SerializeToSlicedBuffer(CIndex index, int32_t* buffer_size, char** res_buffer) { *res_buffer = binary.data; status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -109,7 +109,7 @@ LoadFromSlicedBuffer(CIndex index, const char* serialized_sliced_blob_buffer, in cIndex->Load(serialized_sliced_blob_buffer, size); status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -129,7 +129,7 @@ QueryOnFloatVecIndex(CIndex index, int64_t float_value_num, const float* vectors status.error_code = Success; status.error_msg = ""; 
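The two new `retry.go` files introduced above (for `internal/allocator` and `internal/indexbuilder`) define the same bounded, exponential-backoff helper, and `Allocator.Start` / `CreateBuilder` now wrap their connection setup in it. A small usage sketch against that helper; `dialMaster` and the wrapper function are hypothetical stand-ins added only for illustration, the backoff behavior (sleep doubles per attempt, capped at one second) is as implemented in `RetryImpl`:

```go
// Usage sketch for the Retry helper above (allocator package).
package allocator

import (
	"errors"
	"time"
)

func connectWithRetrySketch() error {
	attempt := 0
	dialMaster := func() error {
		attempt++
		if attempt < 3 {
			return errors.New("master not ready") // retried with backoff
		}
		return nil
	}
	// 10 attempts, starting at 200ms between tries, as Allocator.Start does.
	if err := Retry(10, 200*time.Millisecond, dialMaster); err != nil {
		return err
	}
	// Wrapping an error in NoRetryError makes Retry return it immediately
	// instead of retrying.
	fatal := func() error { return NoRetryError(errors.New("bad config")) }
	return Retry(10, 200*time.Millisecond, fatal)
}
```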
- } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -153,7 +153,7 @@ QueryOnFloatVecIndexWithParam(CIndex index, status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -173,7 +173,7 @@ QueryOnBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors, C status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -197,7 +197,7 @@ QueryOnBinaryVecIndexWithParam(CIndex index, status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -213,7 +213,7 @@ CreateQueryResult(CIndexQueryResult* res) { status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -259,7 +259,7 @@ DeleteQueryResult(CIndexQueryResult res) { status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } diff --git a/internal/core/src/indexbuilder/utils.h b/internal/core/src/indexbuilder/utils.h index e1ed0804965cd9698a0a42cb182931d6c2a94cc3..6e1d89a2d9d82b585b46a358ff8cc47ed9dd510e 100644 --- a/internal/core/src/indexbuilder/utils.h +++ b/internal/core/src/indexbuilder/utils.h @@ -14,6 +14,7 @@ #include <vector> #include <string> #include <algorithm> +#include <tuple> #include "index/knowhere/knowhere/index/IndexType.h" @@ -57,6 +58,14 @@ Need_BuildAll_list() { return ret; } +std::vector<std::tuple<std::string, std::string>> +unsupported_index_combinations() { + static std::vector<std::tuple<std::string, std::string>> ret{ + std::make_tuple(std::string(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT), std::string(knowhere::Metric::L2)), + }; + return ret; +} + template <typename T> bool is_in_list(const T& t, std::function<std::vector<T>()> list_func) { @@ -84,5 +93,11 @@ is_in_need_id_list(const milvus::knowhere::IndexType& index_type) { return is_in_list<std::string>(index_type, Need_ID_List); } +bool +is_unsupported(const milvus::knowhere::IndexType& index_type, const milvus::knowhere::MetricType& metric_type) { + return is_in_list<std::tuple<std::string, std::string>>(std::make_tuple(index_type, metric_type), + unsupported_index_combinations); +} + } // namespace indexbuilder } // namespace milvus diff --git a/internal/core/src/query/SearchOnSealed.cpp b/internal/core/src/query/SearchOnSealed.cpp index 981251839821a29cdef880c38a12cf76a89b2b1f..94d111f3eb227d53cd85e3e32e00cd95c57c8f32 100644 --- a/internal/core/src/query/SearchOnSealed.cpp +++ b/internal/core/src/query/SearchOnSealed.cpp @@ -63,6 +63,8 @@ SearchOnSealed(const Schema& schema, Assert(record.test_readiness(field_offset)); auto indexing_entry = record.get_entry(field_offset); + std::cout << " SearchOnSealed, indexing_entry->metric:" << indexing_entry->metric_type_ << std::endl; + std::cout << " SearchOnSealed, query_info.metric_type_:" << query_info.metric_type_ << std::endl; Assert(indexing_entry->metric_type_ == 
GetMetricType(query_info.metric_type_)); auto final = [&] { diff --git a/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp b/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp index 263390a39a52831c0eb7f1bd71f8dc0cc1cecdb5..28b0730e6a3538510701abe11c5db5ad329478b0 100644 --- a/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp @@ -61,6 +61,17 @@ InferIndexType(const Json& search_params) { PanicInfo("failed to infer index type"); } +static knowhere::IndexType +InferBinaryIndexType(const Json& search_params) { + namespace ip = knowhere::IndexParams; + namespace ie = knowhere::IndexEnum; + if (search_params.contains(ip::nprobe)) { + return ie::INDEX_FAISS_BIN_IVFFLAT; + } else { + return ie::INDEX_FAISS_BIN_IDMAP; + } +} + void VerifyPlanNodeVisitor::visit(FloatVectorANNS& node) { auto& search_params = node.query_info_.search_params_; @@ -79,7 +90,18 @@ VerifyPlanNodeVisitor::visit(FloatVectorANNS& node) { void VerifyPlanNodeVisitor::visit(BinaryVectorANNS& node) { - // TODO + auto& search_params = node.query_info_.search_params_; + auto inferred_type = InferBinaryIndexType(search_params); + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(inferred_type); + auto index_mode = knowhere::IndexMode::MODE_CPU; + + // mock the api, topk will be passed from placeholder + auto params_copy = search_params; + params_copy[knowhere::meta::TOPK] = 10; + + // NOTE: the second parameter is not checked in knowhere, may be redundant + auto passed = adapter->CheckSearch(params_copy, inferred_type, index_mode); + AssertInfo(passed, "invalid search params"); } } // namespace milvus::query diff --git a/internal/core/src/segcore/load_index_c.cpp b/internal/core/src/segcore/load_index_c.cpp index 5c8da297958b9665eccca37228e2b2c961983f92..9afd0b6262e35e4aae490cbb5f1ca17754a2d5d8 100644 --- a/internal/core/src/segcore/load_index_c.cpp +++ b/internal/core/src/segcore/load_index_c.cpp @@ -133,7 +133,7 @@ AppendBinaryIndex(CBinarySet c_binary_set, void* index_binary, int64_t index_siz auto binary_set = (milvus::knowhere::BinarySet*)c_binary_set; std::string index_key(c_index_key); uint8_t* index = (uint8_t*)index_binary; - std::shared_ptr<uint8_t[]> data(index); + std::shared_ptr<uint8_t[]> data(index, [](void*) {}); binary_set->Append(index_key, data, index_size); auto status = CStatus(); diff --git a/internal/indexbuilder/client/client.go b/internal/indexbuilder/client/client.go index 622fde9b1507251a67c170a015c106f56fe50e35..c18e02017b4f9f76a10f2232a6ab581dc4d513aa 100644 --- a/internal/indexbuilder/client/client.go +++ b/internal/indexbuilder/client/client.go @@ -2,6 +2,10 @@ package indexbuilderclient import ( "context" + "encoding/json" + "fmt" + "github.com/zilliztech/milvus-distributed/internal/errors" + "log" "time" "google.golang.org/grpc" @@ -54,20 +58,59 @@ func (c *Client) BuildIndexWithoutID(columnDataPaths []string, typeParams map[st if c.tryConnect() != nil { panic("BuildIndexWithoutID: failed to connect index builder") } + parseMap := func(mStr string) (map[string]string, error) { + buffer := make(map[string]interface{}) + err := json.Unmarshal([]byte(mStr), &buffer) + if err != nil { + return nil, errors.New("Unmarshal params failed") + } + ret := make(map[string]string) + for key, value := range buffer { + valueStr := fmt.Sprintf("%v", value) + ret[key] = valueStr + } + return ret, nil + } var typeParamsKV []*commonpb.KeyValuePair - for typeParam := range typeParams { - typeParamsKV = 
append(typeParamsKV, &commonpb.KeyValuePair{ - Key: typeParam, - Value: typeParams[typeParam], - }) + for key := range typeParams { + if key == "params" { + mapParams, err := parseMap(typeParams[key]) + if err != nil { + log.Println("parse params error: ", err) + } + for pk, pv := range mapParams { + typeParamsKV = append(typeParamsKV, &commonpb.KeyValuePair{ + Key: pk, + Value: pv, + }) + } + } else { + typeParamsKV = append(typeParamsKV, &commonpb.KeyValuePair{ + Key: key, + Value: typeParams[key], + }) + } } var indexParamsKV []*commonpb.KeyValuePair - for indexParam := range indexParams { - indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{ - Key: indexParam, - Value: indexParams[indexParam], - }) + for key := range indexParams { + if key == "params" { + mapParams, err := parseMap(indexParams[key]) + if err != nil { + log.Println("parse params error: ", err) + } + for pk, pv := range mapParams { + indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{ + Key: pk, + Value: pv, + }) + } + } else { + indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{ + Key: key, + Value: indexParams[key], + }) + } } ctx := context.TODO() diff --git a/internal/indexbuilder/index.go b/internal/indexbuilder/index.go index fceec4bcf06083a4bff5a29a13a94e0e0864f29d..16439c4b299feafc829e581ee508b5076d8cbab7 100644 --- a/internal/indexbuilder/index.go +++ b/internal/indexbuilder/index.go @@ -14,6 +14,7 @@ package indexbuilder import "C" import ( "errors" + "fmt" "strconv" "unsafe" @@ -105,10 +106,13 @@ func (index *CIndex) BuildFloatVecIndexWithoutIds(vectors []float32) error { CStatus BuildFloatVecIndexWithoutIds(CIndex index, int64_t float_value_num, const float* vectors); */ + fmt.Println("before BuildFloatVecIndexWithoutIds") status := C.BuildFloatVecIndexWithoutIds(index.indexPtr, (C.int64_t)(len(vectors)), (*C.float)(&vectors[0])) errorCode := status.error_code + fmt.Println("BuildFloatVecIndexWithoutIds error code: ", errorCode) if errorCode != 0 { errorMsg := C.GoString(status.error_msg) + fmt.Println("BuildFloatVecIndexWithoutIds error msg: ", errorMsg) defer C.free(unsafe.Pointer(status.error_msg)) return errors.New("BuildFloatVecIndexWithoutIds failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) } @@ -142,6 +146,8 @@ func (index *CIndex) Delete() error { } func NewCIndex(typeParams, indexParams map[string]string) (Index, error) { + fmt.Println("NNNNNNNNNNNNNNNNNNNNNNNNNNN typeParams: ", typeParams) + fmt.Println("NNNNNNNNNNNNNNNNNNNNNNNNNNN indexParams: ", indexParams) protoTypeParams := &indexcgopb.TypeParams{ Params: make([]*commonpb.KeyValuePair, 0), } @@ -168,10 +174,14 @@ func NewCIndex(typeParams, indexParams map[string]string) (Index, error) { CIndex* res_index); */ var indexPtr C.CIndex + fmt.Println("before create index ........................................") status := C.CreateIndex(typeParamsPointer, indexParamsPointer, &indexPtr) + fmt.Println("after create index ........................................") errorCode := status.error_code + fmt.Println("EEEEEEEEEEEEEEEEEEEEEEEEEE error code: ", errorCode) if errorCode != 0 { errorMsg := C.GoString(status.error_msg) + fmt.Println("EEEEEEEEEEEEEEEEEEEEEEEEEE error msg: ", errorMsg) defer C.free(unsafe.Pointer(status.error_msg)) return nil, errors.New(" failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) } diff --git a/internal/indexbuilder/indexbuilder.go b/internal/indexbuilder/indexbuilder.go index 
4acfffc3d284665158d9f3c237682a5792257d50..712d17ef80103b48590e1bec01e8889ee0855d06 100644 --- a/internal/indexbuilder/indexbuilder.go +++ b/internal/indexbuilder/indexbuilder.go @@ -54,34 +54,48 @@ func CreateBuilder(ctx context.Context) (*Builder, error) { loopCancel: cancel, } - etcdAddress := Params.EtcdAddress - etcdClient, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}}) - if err != nil { - return nil, err + connectEtcdFn := func() error { + etcdAddress := Params.EtcdAddress + etcdClient, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}}) + if err != nil { + return err + } + etcdKV := etcdkv.NewEtcdKV(etcdClient, Params.MetaRootPath) + metakv, err := NewMetaTable(etcdKV) + if err != nil { + return err + } + b.metaTable = metakv + return nil } - etcdKV := etcdkv.NewEtcdKV(etcdClient, Params.MetaRootPath) - metakv, err := NewMetaTable(etcdKV) + err := Retry(10, time.Millisecond*200, connectEtcdFn) if err != nil { return nil, err } - b.metaTable = metakv idAllocator, err := allocator.NewIDAllocator(b.loopCtx, Params.MasterAddress) + b.idAllocator = idAllocator - option := &miniokv.Option{ - Address: Params.MinIOAddress, - AccessKeyID: Params.MinIOAccessKeyID, - SecretAccessKeyID: Params.MinIOSecretAccessKey, - UseSSL: Params.MinIOUseSSL, - BucketName: Params.MinioBucketName, - CreateBucket: true, + connectMinIOFn := func() error { + option := &miniokv.Option{ + Address: Params.MinIOAddress, + AccessKeyID: Params.MinIOAccessKeyID, + SecretAccessKeyID: Params.MinIOSecretAccessKey, + UseSSL: Params.MinIOUseSSL, + BucketName: Params.MinioBucketName, + CreateBucket: true, + } + + b.kv, err = miniokv.NewMinIOKV(b.loopCtx, option) + if err != nil { + return err + } + return nil } - - b.kv, err = miniokv.NewMinIOKV(b.loopCtx, option) + err = Retry(10, time.Millisecond*200, connectMinIOFn) if err != nil { return nil, err } - b.idAllocator = idAllocator b.sched, err = NewTaskScheduler(b.loopCtx, b.idAllocator, b.kv, b.metaTable) if err != nil { diff --git a/internal/indexbuilder/retry.go b/internal/indexbuilder/retry.go new file mode 100644 index 0000000000000000000000000000000000000000..2cf4c6ecf576ddf095b85290b30bdb8f5b108047 --- /dev/null +++ b/internal/indexbuilder/retry.go @@ -0,0 +1,40 @@ +package indexbuilder + +import ( + "log" + "time" +) + +// Reference: https://blog.cyeam.com/golang/2018/08/27/retry + +func RetryImpl(attempts int, sleep time.Duration, fn func() error, maxSleepTime time.Duration) error { + if err := fn(); err != nil { + if s, ok := err.(InterruptError); ok { + return s.error + } + + if attempts--; attempts > 0 { + log.Printf("retry func error: %s. 
attempts #%d after %s.", err.Error(), attempts, sleep) + time.Sleep(sleep) + if sleep < maxSleepTime { + return RetryImpl(attempts, 2*sleep, fn, maxSleepTime) + } + return RetryImpl(attempts, maxSleepTime, fn, maxSleepTime) + } + return err + } + return nil +} + +func Retry(attempts int, sleep time.Duration, fn func() error) error { + maxSleepTime := time.Millisecond * 1000 + return RetryImpl(attempts, sleep, fn, maxSleepTime) +} + +type InterruptError struct { + error +} + +func NoRetryError(err error) InterruptError { + return InterruptError{err} +} diff --git a/internal/indexbuilder/task.go b/internal/indexbuilder/task.go index 15647e1563e53f3abab9976ca438ef334212fb34..73a5b77508b1e36b69a6f5cef440d6451f97b9c9 100644 --- a/internal/indexbuilder/task.go +++ b/internal/indexbuilder/task.go @@ -2,6 +2,7 @@ package indexbuilder import ( "context" + "fmt" "log" "strconv" "time" @@ -171,10 +172,12 @@ func (it *IndexBuildTask) Execute() error { indexParams[key] = value } + fmt.Println("before NewCIndex ..........................") it.index, err = NewCIndex(typeParams, indexParams) if err != nil { return err } + fmt.Println("after NewCIndex ..........................") getKeyByPathNaive := func(path string) string { // splitElements := strings.Split(path, "/") @@ -223,6 +226,7 @@ func (it *IndexBuildTask) Execute() error { for _, value := range insertData.Data { // TODO: BinaryVectorFieldData + fmt.Println("before build index ..................................") floatVectorFieldData, fOk := value.(*storage.FloatVectorFieldData) if fOk { err = it.index.BuildFloatVecIndexWithoutIds(floatVectorFieldData.Data) @@ -238,12 +242,15 @@ func (it *IndexBuildTask) Execute() error { return err } } + fmt.Println("after build index ..................................") if !fOk && !bOk { return errors.New("we expect FloatVectorFieldData or BinaryVectorFieldData") } + fmt.Println("before serialize .............................................") indexBlobs, err := it.index.Serialize() + fmt.Println("after serialize .............................................") if err != nil { return err } diff --git a/internal/kv/rocksdb/rocksdb_kv.go b/internal/kv/rocksdb/rocksdb_kv.go index 318f3078e84d6a3e3c6f21799e7cca26aaceb922..e986810684e2ee7d336e6da4e79bd18222d5b678 100644 --- a/internal/kv/rocksdb/rocksdb_kv.go +++ b/internal/kv/rocksdb/rocksdb_kv.go @@ -51,7 +51,17 @@ func (kv *RocksdbKV) Load(key string) (string, error) { } func (kv *RocksdbKV) LoadWithPrefix(key string) ([]string, []string, error) { + kv.readOptions.SetPrefixSameAsStart(true) + kv.db.Close() + kv.opts.SetPrefixExtractor(gorocksdb.NewFixedPrefixTransform(len(key))) + var err error + kv.db, err = gorocksdb.OpenDb(kv.opts, kv.GetName()) + if err != nil { + return nil, nil, err + } + iter := kv.db.NewIterator(kv.readOptions) + defer iter.Close() keys := make([]string, 0) values := make([]string, 0) iter.Seek([]byte(key)) @@ -97,7 +107,17 @@ func (kv *RocksdbKV) MultiSave(kvs map[string]string) error { } func (kv *RocksdbKV) RemoveWithPrefix(prefix string) error { + kv.readOptions.SetPrefixSameAsStart(true) + kv.db.Close() + kv.opts.SetPrefixExtractor(gorocksdb.NewFixedPrefixTransform(len(prefix))) + var err error + kv.db, err = gorocksdb.OpenDb(kv.opts, kv.GetName()) + if err != nil { + return err + } + iter := kv.db.NewIterator(kv.readOptions) + defer iter.Close() iter.Seek([]byte(prefix)) for ; iter.Valid(); iter.Next() { key := iter.Key() diff --git a/internal/kv/rocksdb/rocksdb_kv_test.go b/internal/kv/rocksdb/rocksdb_kv_test.go index 
8898415a0594ac7a996bf85efd19b552923ae510..6513ee60e772271d5cb4a7e3b510d42649db6d92 100644 --- a/internal/kv/rocksdb/rocksdb_kv_test.go +++ b/internal/kv/rocksdb/rocksdb_kv_test.go @@ -52,3 +52,43 @@ func TestRocksdbKV(t *testing.T) { assert.Equal(t, vals[0], "123") assert.Equal(t, vals[1], "456") } + +func TestRocksdbKV_Prefix(t *testing.T) { + name := "/tmp/rocksdb" + rocksdbKV, err := rocksdbkv.NewRocksdbKV(name) + if err != nil { + panic(err) + } + + defer rocksdbKV.Close() + // Need to call RemoveWithPrefix + defer rocksdbKV.RemoveWithPrefix("") + + err = rocksdbKV.Save("abcd", "123") + assert.Nil(t, err) + + err = rocksdbKV.Save("abdd", "1234") + assert.Nil(t, err) + + err = rocksdbKV.Save("abddqqq", "1234555") + assert.Nil(t, err) + + keys, vals, err := rocksdbKV.LoadWithPrefix("abc") + assert.Nil(t, err) + assert.Equal(t, len(keys), 1) + assert.Equal(t, len(vals), 1) + //fmt.Println(keys) + //fmt.Println(vals) + + err = rocksdbKV.RemoveWithPrefix("abc") + assert.Nil(t, err) + val, err := rocksdbKV.Load("abc") + assert.Nil(t, err) + assert.Equal(t, len(val), 0) + val, err = rocksdbKV.Load("abdd") + assert.Nil(t, err) + assert.Equal(t, val, "1234") + val, err = rocksdbKV.Load("abddqqq") + assert.Nil(t, err) + assert.Equal(t, val, "1234555") +} diff --git a/internal/master/client.go b/internal/master/client.go index 88e44d8f70a02aba66217094d8f2327c6b06c814..a35151767620339a0d20a6b11e3f3ff88f822328 100644 --- a/internal/master/client.go +++ b/internal/master/client.go @@ -1,6 +1,7 @@ package master import ( + "sync" "time" buildindexclient "github.com/zilliztech/milvus-distributed/internal/indexbuilder/client" @@ -20,9 +21,12 @@ type MockWriteNodeClient struct { partitionTag string timestamp Timestamp collectionID UniqueID + lock sync.RWMutex } func (m *MockWriteNodeClient) FlushSegment(segmentID UniqueID, collectionID UniqueID, partitionTag string, timestamp Timestamp) error { + m.lock.Lock() + defer m.lock.Unlock() m.flushTime = time.Now() m.segmentID = segmentID m.collectionID = collectionID @@ -33,6 +37,8 @@ func (m *MockWriteNodeClient) FlushSegment(segmentID UniqueID, collectionID Uniq func (m *MockWriteNodeClient) DescribeSegment(segmentID UniqueID) (*writerclient.SegmentDescription, error) { now := time.Now() + m.lock.RLock() + defer m.lock.RUnlock() if now.Sub(m.flushTime).Seconds() > 2 { return &writerclient.SegmentDescription{ SegmentID: segmentID, diff --git a/internal/master/index_task.go b/internal/master/index_task.go index 444b777c9f36691a83b773526cfcec87296a3150..c5f320d0db7d01f1abe5de3bd3e67391288b4d38 100644 --- a/internal/master/index_task.go +++ b/internal/master/index_task.go @@ -24,11 +24,6 @@ func (task *createIndexTask) Ts() (Timestamp, error) { } func (task *createIndexTask) Execute() error { - // modify schema - if err := task.mt.UpdateFieldIndexParams(task.req.CollectionName, task.req.FieldName, task.req.ExtraParams); err != nil { - return err - } - // check if closed segment has the same index build history collMeta, err := task.mt.GetCollectionByName(task.req.CollectionName) if err != nil { return err @@ -44,6 +39,20 @@ func (task *createIndexTask) Execute() error { return fmt.Errorf("can not find field name %s", task.req.FieldName) } + // pre checks + isIndexable, err := task.mt.IsIndexable(collMeta.ID, fieldID) + if err != nil { + return err + } + if !isIndexable { + return fmt.Errorf("field %s is not vector", task.req.FieldName) + } + + // modify schema + if err := task.mt.UpdateFieldIndexParams(task.req.CollectionName, task.req.FieldName, 
task.req.ExtraParams); err != nil { + return err + } + // check if closed segment has the same index build history for _, segID := range collMeta.SegmentIDs { segMeta, err := task.mt.GetSegmentByID(segID) if err != nil { diff --git a/internal/master/master_test.go b/internal/master/master_test.go index a605e73aa76127c24918cdd3826fae0d0d186ad8..c8130d03a09990484699df7ea5cface56af91658 100644 --- a/internal/master/master_test.go +++ b/internal/master/master_test.go @@ -65,12 +65,8 @@ func refreshChannelNames() { } func receiveTimeTickMsg(stream *ms.MsgStream) bool { - for { - result := (*stream).Consume() - if len(result.Msgs) > 0 { - return true - } - } + result := (*stream).Consume() + return result != nil } func getTimeTickMsgPack(ttmsgs [][2]uint64) *ms.MsgPack { @@ -81,6 +77,14 @@ func getTimeTickMsgPack(ttmsgs [][2]uint64) *ms.MsgPack { return &msgPack } +func mockTimeTickBroadCast(msgStream ms.MsgStream, time Timestamp) error { + timeTick := [][2]uint64{ + {0, time}, + } + ttMsgPackForDD := getTimeTickMsgPack(timeTick) + return msgStream.Broadcast(ttMsgPackForDD) +} + func TestMaster(t *testing.T) { Init() refreshMasterAddress() @@ -534,10 +538,15 @@ func TestMaster(t *testing.T) { assert.Nil(t, err) assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS) + time.Sleep(1000 * time.Millisecond) + timestampNow := Timestamp(time.Now().Unix()) + err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow) + assert.NoError(t, err) + //consume msg - ddMs := ms.NewPulsarMsgStream(ctx, 1024) + ddMs := ms.NewPulsarTtMsgStream(ctx, 1024) ddMs.SetPulsarClient(pulsarAddr) - ddMs.CreatePulsarConsumers(Params.DDChannelNames, "DDStream", ms.NewUnmarshalDispatcher(), 1024) + ddMs.CreatePulsarConsumers(Params.DDChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024) ddMs.Start() var consumeMsg ms.MsgStream = ddMs @@ -823,11 +832,16 @@ func TestMaster(t *testing.T) { assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS) //consume msg - ddMs := ms.NewPulsarMsgStream(ctx, 1024) + ddMs := ms.NewPulsarTtMsgStream(ctx, 1024) ddMs.SetPulsarClient(pulsarAddr) - ddMs.CreatePulsarConsumers(Params.DDChannelNames, "DDStream", ms.NewUnmarshalDispatcher(), 1024) + ddMs.CreatePulsarConsumers(Params.DDChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024) ddMs.Start() + time.Sleep(1000 * time.Millisecond) + timestampNow := Timestamp(time.Now().Unix()) + err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow) + assert.NoError(t, err) + var consumeMsg ms.MsgStream = ddMs for { result := consumeMsg.Consume() @@ -850,19 +864,19 @@ func TestMaster(t *testing.T) { writeNodeStream.CreatePulsarProducers(Params.WriteNodeTimeTickChannelNames) writeNodeStream.Start() - ddMs := ms.NewPulsarMsgStream(ctx, 1024) + ddMs := ms.NewPulsarTtMsgStream(ctx, 1024) ddMs.SetPulsarClient(pulsarAddr) - ddMs.CreatePulsarConsumers(Params.DDChannelNames, "DDStream", ms.NewUnmarshalDispatcher(), 1024) + ddMs.CreatePulsarConsumers(Params.DDChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024) ddMs.Start() - dMMs := ms.NewPulsarMsgStream(ctx, 1024) + dMMs := ms.NewPulsarTtMsgStream(ctx, 1024) dMMs.SetPulsarClient(pulsarAddr) - dMMs.CreatePulsarConsumers(Params.InsertChannelNames, "DMStream", ms.NewUnmarshalDispatcher(), 1024) + dMMs.CreatePulsarConsumers(Params.InsertChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024) dMMs.Start() k2sMs := ms.NewPulsarMsgStream(ctx, 1024) k2sMs.SetPulsarClient(pulsarAddr) - 
k2sMs.CreatePulsarConsumers(Params.K2SChannelNames, "K2SStream", ms.NewUnmarshalDispatcher(), 1024) + k2sMs.CreatePulsarConsumers(Params.K2SChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024) k2sMs.Start() ttsoftmsgs := [][2]uint64{ @@ -897,10 +911,11 @@ func TestMaster(t *testing.T) { schemaBytes, err := proto.Marshal(&sch) assert.Nil(t, err) + ////////////////////////////CreateCollection//////////////////////// createCollectionReq := internalpb.CreateCollectionRequest{ MsgType: internalpb.MsgType_kCreateCollection, ReqID: 1, - Timestamp: uint64(time.Now().Unix()), + Timestamp: Timestamp(time.Now().Unix()), ProxyID: 1, Schema: &commonpb.Blob{Value: schemaBytes}, } @@ -908,6 +923,11 @@ func TestMaster(t *testing.T) { assert.Nil(t, err) assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS) + time.Sleep(1000 * time.Millisecond) + timestampNow := Timestamp(time.Now().Unix()) + err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow) + assert.NoError(t, err) + var consumeMsg ms.MsgStream = ddMs var createCollectionMsg *ms.CreateCollectionMsg for { @@ -942,6 +962,11 @@ func TestMaster(t *testing.T) { assert.Nil(t, err) assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS) + time.Sleep(1000 * time.Millisecond) + timestampNow = Timestamp(time.Now().Unix()) + err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow) + assert.NoError(t, err) + var createPartitionMsg *ms.CreatePartitionMsg for { result := consumeMsg.Consume() @@ -976,6 +1001,11 @@ func TestMaster(t *testing.T) { assert.Nil(t, err) assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS) + time.Sleep(1000 * time.Millisecond) + timestampNow = Timestamp(time.Now().Unix()) + err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow) + assert.NoError(t, err) + var dropPartitionMsg *ms.DropPartitionMsg for { result := consumeMsg.Consume() @@ -1006,6 +1036,11 @@ func TestMaster(t *testing.T) { assert.Nil(t, err) assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS) + time.Sleep(1000 * time.Millisecond) + timestampNow = Timestamp(time.Now().Unix()) + err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow) + assert.NoError(t, err) + var dropCollectionMsg *ms.DropCollectionMsg for { result := consumeMsg.Consume() diff --git a/internal/master/scheduler_test.go b/internal/master/scheduler_test.go index a40f7584fa5cea0616de4a60494aef3e869519e5..f735a891c18544dfbecd4fc6ca5909457491688d 100644 --- a/internal/master/scheduler_test.go +++ b/internal/master/scheduler_test.go @@ -46,7 +46,7 @@ func TestMaster_Scheduler_Collection(t *testing.T) { pulsarDDStream.Start() defer pulsarDDStream.Close() - consumeMs := ms.NewPulsarMsgStream(ctx, 1024) + consumeMs := ms.NewPulsarTtMsgStream(ctx, 1024) consumeMs.SetPulsarClient(pulsarAddr) consumeMs.CreatePulsarConsumers(consumerChannels, consumerSubName, ms.NewUnmarshalDispatcher(), 1024) consumeMs.Start() @@ -96,6 +96,9 @@ func TestMaster_Scheduler_Collection(t *testing.T) { err = createCollectionTask.WaitToFinish(ctx) assert.Nil(t, err) + err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(12)) + assert.NoError(t, err) + var consumeMsg ms.MsgStream = consumeMs var createCollectionMsg *ms.CreateCollectionMsg for { @@ -118,7 +121,7 @@ func TestMaster_Scheduler_Collection(t *testing.T) { dropCollectionReq := internalpb.DropCollectionRequest{ MsgType: internalpb.MsgType_kDropCollection, ReqID: 1, - Timestamp: 11, + Timestamp: 13, ProxyID: 1, CollectionName: 
&servicepb.CollectionName{CollectionName: sch.Name}, } @@ -138,6 +141,9 @@ func TestMaster_Scheduler_Collection(t *testing.T) { err = dropCollectionTask.WaitToFinish(ctx) assert.Nil(t, err) + err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(14)) + assert.NoError(t, err) + var dropCollectionMsg *ms.DropCollectionMsg for { result := consumeMsg.Consume() @@ -184,7 +190,7 @@ func TestMaster_Scheduler_Partition(t *testing.T) { pulsarDDStream.Start() defer pulsarDDStream.Close() - consumeMs := ms.NewPulsarMsgStream(ctx, 1024) + consumeMs := ms.NewPulsarTtMsgStream(ctx, 1024) consumeMs.SetPulsarClient(pulsarAddr) consumeMs.CreatePulsarConsumers(consumerChannels, consumerSubName, ms.NewUnmarshalDispatcher(), 1024) consumeMs.Start() @@ -234,6 +240,9 @@ func TestMaster_Scheduler_Partition(t *testing.T) { err = createCollectionTask.WaitToFinish(ctx) assert.Nil(t, err) + err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(12)) + assert.NoError(t, err) + var consumeMsg ms.MsgStream = consumeMs var createCollectionMsg *ms.CreateCollectionMsg for { @@ -257,7 +266,7 @@ func TestMaster_Scheduler_Partition(t *testing.T) { createPartitionReq := internalpb.CreatePartitionRequest{ MsgType: internalpb.MsgType_kCreatePartition, ReqID: 1, - Timestamp: 11, + Timestamp: 13, ProxyID: 1, PartitionName: &servicepb.PartitionName{ CollectionName: sch.Name, @@ -279,6 +288,9 @@ func TestMaster_Scheduler_Partition(t *testing.T) { err = createPartitionTask.WaitToFinish(ctx) assert.Nil(t, err) + err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(14)) + assert.NoError(t, err) + var createPartitionMsg *ms.CreatePartitionMsg for { result := consumeMsg.Consume() @@ -301,7 +313,7 @@ func TestMaster_Scheduler_Partition(t *testing.T) { dropPartitionReq := internalpb.DropPartitionRequest{ MsgType: internalpb.MsgType_kDropPartition, ReqID: 1, - Timestamp: 11, + Timestamp: 15, ProxyID: 1, PartitionName: &servicepb.PartitionName{ CollectionName: sch.Name, @@ -323,6 +335,9 @@ func TestMaster_Scheduler_Partition(t *testing.T) { err = dropPartitionTask.WaitToFinish(ctx) assert.Nil(t, err) + err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(16)) + assert.NoError(t, err) + var dropPartitionMsg *ms.DropPartitionMsg for { result := consumeMsg.Consume() diff --git a/internal/master/segment_manager_test.go b/internal/master/segment_manager_test.go index 28e16ec7eca16dacc512e598b769dfbf0dad4a34..69078054d541145a32e09aaf74c04dc42aa162bf 100644 --- a/internal/master/segment_manager_test.go +++ b/internal/master/segment_manager_test.go @@ -126,7 +126,7 @@ func TestSegmentManager_AssignSegment(t *testing.T) { } } - time.Sleep(time.Duration(Params.SegIDAssignExpiration)) + time.Sleep(time.Duration(Params.SegIDAssignExpiration) * time.Millisecond) timestamp, err := globalTsoAllocator() assert.Nil(t, err) err = mt.UpdateSegment(&pb.SegmentMeta{ @@ -156,3 +156,124 @@ func TestSegmentManager_AssignSegment(t *testing.T) { assert.Nil(t, err) assert.NotEqualValues(t, 0, segMeta.CloseTime) } + +func TestSegmentManager_SycnWritenode(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.TODO()) + defer cancelFunc() + + Init() + Params.TopicNum = 5 + Params.QueryNodeNum = 3 + Params.SegmentSize = 536870912 / 1024 / 1024 + Params.SegmentSizeFactor = 0.75 + Params.DefaultRecordSize = 1024 + Params.MinSegIDAssignCnt = 1048576 / 1024 + Params.SegIDAssignExpiration = 2000 + etcdAddress := Params.EtcdAddress + cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}}) + assert.Nil(t, err) + rootPath := "/test/root" + _, err = 
cli.Delete(ctx, rootPath, clientv3.WithPrefix()) + assert.Nil(t, err) + + kvBase := etcdkv.NewEtcdKV(cli, rootPath) + defer kvBase.Close() + mt, err := NewMetaTable(kvBase) + assert.Nil(t, err) + + collName := "segmgr_test_coll" + var collID int64 = 1001 + partitionTag := "test_part" + schema := &schemapb.CollectionSchema{ + Name: collName, + Fields: []*schemapb.FieldSchema{ + {FieldID: 1, Name: "f1", IsPrimaryKey: false, DataType: schemapb.DataType_INT32}, + {FieldID: 2, Name: "f2", IsPrimaryKey: false, DataType: schemapb.DataType_VECTOR_FLOAT, TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "128"}, + }}, + }, + } + err = mt.AddCollection(&pb.CollectionMeta{ + ID: collID, + Schema: schema, + CreateTime: 0, + SegmentIDs: []UniqueID{}, + PartitionTags: []string{}, + }) + assert.Nil(t, err) + err = mt.AddPartition(collID, partitionTag) + assert.Nil(t, err) + + var cnt int64 + globalIDAllocator := func() (UniqueID, error) { + val := atomic.AddInt64(&cnt, 1) + return val, nil + } + globalTsoAllocator := func() (Timestamp, error) { + val := atomic.AddInt64(&cnt, 1) + phy := time.Now().UnixNano() / int64(time.Millisecond) + ts := tsoutil.ComposeTS(phy, val) + return ts, nil + } + syncWriteChan := make(chan *msgstream.TimeTickMsg) + syncProxyChan := make(chan *msgstream.TimeTickMsg) + + segAssigner := NewSegmentAssigner(ctx, mt, globalTsoAllocator, syncProxyChan) + mockScheduler := &MockFlushScheduler{} + segManager, err := NewSegmentManager(ctx, mt, globalIDAllocator, globalTsoAllocator, syncWriteChan, mockScheduler, segAssigner) + assert.Nil(t, err) + + segManager.Start() + defer segManager.Close() + sizePerRecord, err := typeutil.EstimateSizePerRecord(schema) + assert.Nil(t, err) + maxCount := uint32(Params.SegmentSize * 1024 * 1024 / float64(sizePerRecord)) + + req := []*internalpb.SegIDRequest{ + {Count: maxCount, ChannelID: 1, CollName: collName, PartitionTag: partitionTag}, + {Count: maxCount, ChannelID: 2, CollName: collName, PartitionTag: partitionTag}, + {Count: maxCount, ChannelID: 3, CollName: collName, PartitionTag: partitionTag}, + } + assignSegment, err := segManager.AssignSegment(req) + assert.Nil(t, err) + timestamp, err := globalTsoAllocator() + assert.Nil(t, err) + for i := 0; i < len(assignSegment); i++ { + assert.EqualValues(t, maxCount, assignSegment[i].Count) + assert.EqualValues(t, i+1, assignSegment[i].ChannelID) + + err = mt.UpdateSegment(&pb.SegmentMeta{ + SegmentID: assignSegment[i].SegID, + CollectionID: collID, + PartitionTag: partitionTag, + ChannelStart: 0, + ChannelEnd: 1, + CloseTime: timestamp, + NumRows: int64(maxCount), + MemSize: 500000, + }) + assert.Nil(t, err) + } + + time.Sleep(time.Duration(Params.SegIDAssignExpiration) * time.Millisecond) + + timestamp, err = globalTsoAllocator() + assert.Nil(t, err) + tsMsg := &msgstream.TimeTickMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: timestamp, EndTimestamp: timestamp, HashValues: []uint32{}, + }, + TimeTickMsg: internalpb.TimeTickMsg{ + MsgType: internalpb.MsgType_kTimeTick, + PeerID: 1, + Timestamp: timestamp, + }, + } + syncWriteChan <- tsMsg + time.Sleep(300 * time.Millisecond) + + segManager.mu.RLock() + defer segManager.mu.RUnlock() + status := segManager.collStatus[collID] + assert.Empty(t, status.segments) +} diff --git a/internal/master/time_snyc_producer_test.go b/internal/master/time_snyc_producer_test.go index e55b1ec42756e9014d738d4850a16b6eda75d945..3c0cc2e9aa02bd501df86aa296309d846121c021 100644 --- a/internal/master/time_snyc_producer_test.go +++ 
b/internal/master/time_snyc_producer_test.go @@ -58,6 +58,7 @@ func initTestPulsarStream(ctx context.Context, pulsarAddress string, return &input, &output } + func receiveMsg(stream *ms.MsgStream) []uint64 { receiveCount := 0 var results []uint64 diff --git a/internal/master/timesync.go b/internal/master/timesync.go index 49c388a8dd9249349777a99a3f54fd9b5a2ea472..79863f7ac8db2f6f585b16b81f355e54d3ec85d9 100644 --- a/internal/master/timesync.go +++ b/internal/master/timesync.go @@ -81,7 +81,7 @@ func (ttBarrier *softTimeTickBarrier) Start() error { // get a legal Timestamp ts := ttBarrier.minTimestamp() lastTt := atomic.LoadInt64(&(ttBarrier.lastTt)) - if ttBarrier.lastTt != 0 && ttBarrier.minTtInterval > ts-Timestamp(lastTt) { + if lastTt != 0 && ttBarrier.minTtInterval > ts-Timestamp(lastTt) { continue } ttBarrier.outTt <- ts diff --git a/internal/master/timesync_test.go b/internal/master/timesync_test.go index 59fb7b27621c8a523e023d3a10f747066eed3876..cab1c740273ab75cfd31f103746e35a75a5e4d4a 100644 --- a/internal/master/timesync_test.go +++ b/internal/master/timesync_test.go @@ -192,15 +192,15 @@ func TestTt_SoftTtBarrierStart(t *testing.T) { func TestTt_SoftTtBarrierGetTimeTickClose(t *testing.T) { channels := []string{"SoftTtBarrierGetTimeTickClose"} - ttmsgs := [][2]int{ - {1, 10}, - {2, 20}, - {3, 30}, - {4, 40}, - {1, 30}, - {2, 30}, - } - inStream, ttStream := producer(channels, ttmsgs) + //ttmsgs := [][2]int{ + // {1, 10}, + // {2, 20}, + // {3, 30}, + // {4, 40}, + // {1, 30}, + // {2, 30}, + //} + inStream, ttStream := producer(channels, nil) defer func() { (*inStream).Close() (*ttStream).Close() @@ -259,15 +259,15 @@ func TestTt_SoftTtBarrierGetTimeTickClose(t *testing.T) { func TestTt_SoftTtBarrierGetTimeTickCancel(t *testing.T) { channels := []string{"SoftTtBarrierGetTimeTickCancel"} - ttmsgs := [][2]int{ - {1, 10}, - {2, 20}, - {3, 30}, - {4, 40}, - {1, 30}, - {2, 30}, - } - inStream, ttStream := producer(channels, ttmsgs) + //ttmsgs := [][2]int{ + // {1, 10}, + // {2, 20}, + // {3, 30}, + // {4, 40}, + // {1, 30}, + // {2, 30}, + //} + inStream, ttStream := producer(channels, nil) defer func() { (*inStream).Close() (*ttStream).Close() diff --git a/internal/msgstream/msgstream.go b/internal/msgstream/msgstream.go index 7fe0b58f1c0907d7ff12acaf7b365bd535449f5b..2cb4dd0c0c03c47cb3d042406a79c3db1d64c44b 100644 --- a/internal/msgstream/msgstream.go +++ b/internal/msgstream/msgstream.go @@ -183,6 +183,9 @@ func (ms *PulsarMsgStream) Produce(msgPack *MsgPack) error { log.Printf("Warning: Receive empty msgPack") return nil } + if len(ms.producers) <= 0 { + return errors.New("nil producer in msg stream") + } reBucketValues := make([][]int32, len(tsMsgs)) for channelID, tsMsg := range tsMsgs { hashValues := tsMsg.HashKeys() @@ -475,15 +478,16 @@ func (ms *PulsarTtMsgStream) bufMsgPackToChannel() { default: wg := sync.WaitGroup{} mu := sync.Mutex{} + findMapMutex := sync.RWMutex{} for i := 0; i < len(ms.consumers); i++ { if isChannelReady[i] { continue } wg.Add(1) - go ms.findTimeTick(i, eofMsgTimeStamp, &wg, &mu) + go ms.findTimeTick(i, eofMsgTimeStamp, &wg, &mu, &findMapMutex) } wg.Wait() - timeStamp, ok := checkTimeTickMsg(eofMsgTimeStamp, isChannelReady) + timeStamp, ok := checkTimeTickMsg(eofMsgTimeStamp, isChannelReady, &findMapMutex) if !ok || timeStamp <= ms.lastTimeStamp { log.Printf("All timeTick's timestamps are inconsistent") continue @@ -530,7 +534,8 @@ func (ms *PulsarTtMsgStream) bufMsgPackToChannel() { func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int, 
eofMsgMap map[int]Timestamp, wg *sync.WaitGroup, - mu *sync.Mutex) { + mu *sync.Mutex, + findMapMutex *sync.RWMutex) { defer wg.Done() for { select { @@ -575,7 +580,9 @@ func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int, } if headerMsg.MsgType == internalPb.MsgType_kTimeTick { + findMapMutex.Lock() eofMsgMap[channelIndex] = tsMsg.(*TimeTickMsg).Timestamp + findMapMutex.Unlock() return } mu.Lock() @@ -624,7 +631,7 @@ func (ms *InMemMsgStream) Chan() <- chan *MsgPack { } */ -func checkTimeTickMsg(msg map[int]Timestamp, isChannelReady []bool) (Timestamp, bool) { +func checkTimeTickMsg(msg map[int]Timestamp, isChannelReady []bool, mu *sync.RWMutex) (Timestamp, bool) { checkMap := make(map[Timestamp]int) var maxTime Timestamp = 0 for _, v := range msg { @@ -639,7 +646,10 @@ func checkTimeTickMsg(msg map[int]Timestamp, isChannelReady []bool) (Timestamp, } return maxTime, true } - for i, v := range msg { + for i := range msg { + mu.RLock() + v := msg[i] + mu.Unlock() if v != maxTime { isChannelReady[i] = false } else { diff --git a/internal/msgstream/msgstream_test.go b/internal/msgstream/msgstream_test.go index c3b694c3c80e86e9ba8de305de2060a0420c81c8..55c95cebaf2012ddccc7b2e13ede6fc7028725c1 100644 --- a/internal/msgstream/msgstream_test.go +++ b/internal/msgstream/msgstream_test.go @@ -526,8 +526,6 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) { log.Fatalf("broadcast error = %v", err) } receiveMsg(outputStream, len(msgPack1.Msgs)) - outputTtStream := (*outputStream).(*PulsarTtMsgStream) - fmt.Printf("timestamp = %v", outputTtStream.lastTimeStamp) (*inputStream).Close() (*outputStream).Close() } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 3993e6aae5033f53c3c001f474034a09efa33353..486e5a23fb161de8ec5b44410c330de4faa0f57a 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -210,6 +210,7 @@ func TestProxy_CreateCollection(t *testing.T) { wg.Add(1) go func(group *sync.WaitGroup) { defer group.Done() + println("collectionName:", collectionName) createCollection(t, collectionName) dropCollection(t, collectionName) }(&wg) @@ -488,7 +489,9 @@ func TestProxy_CreateIndex(t *testing.T) { go func(group *sync.WaitGroup) { defer group.Done() createCollection(t, collName) - createIndex(t, collName, fieldName) + if i%2 == 0 { + createIndex(t, collName, fieldName) + } dropCollection(t, collName) // dropIndex(t, collectionName, fieldName, indexName) }(&wg) @@ -510,7 +513,9 @@ func TestProxy_DescribeIndex(t *testing.T) { go func(group *sync.WaitGroup) { defer group.Done() createCollection(t, collName) - createIndex(t, collName, fieldName) + if i%2 == 0 { + createIndex(t, collName, fieldName) + } req := &servicepb.DescribeIndexRequest{ CollectionName: collName, FieldName: fieldName, @@ -539,7 +544,9 @@ func TestProxy_DescribeIndexProgress(t *testing.T) { go func(group *sync.WaitGroup) { defer group.Done() createCollection(t, collName) - createIndex(t, collName, fieldName) + if i%2 == 0 { + createIndex(t, collName, fieldName) + } req := &servicepb.DescribeIndexProgressRequest{ CollectionName: collName, FieldName: fieldName, diff --git a/internal/proxy/task.go b/internal/proxy/task.go index d28e5654a4c76499f91052aa475bd00c94d12ccc..bd2251a2d851b729548c1f92376fd4779730298c 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -419,9 +419,6 @@ func (qt *QueryTask) PreExecute() error { } } qt.MsgType = internalpb.MsgType_kSearch - if qt.query.PartitionTags == nil || len(qt.query.PartitionTags) <= 0 { - 
qt.query.PartitionTags = []string{Params.defaultPartitionTag()} - } queryBytes, err := proto.Marshal(qt.query) if err != nil { span.LogFields(oplog.Error(err)) @@ -502,7 +499,7 @@ func (qt *QueryTask) PostExecute() error { hits := make([][]*servicepb.Hits, 0) for _, partialSearchResult := range filterSearchResult { - if len(partialSearchResult.Hits) <= 0 { + if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 { filterReason += "nq is zero\n" continue } @@ -541,7 +538,16 @@ func (qt *QueryTask) PostExecute() error { return nil } - topk := len(hits[0][0].IDs) + topk := 0 + getMax := func(a, b int) int { + if a > b { + return a + } + return b + } + for _, hit := range hits { + topk = getMax(topk, len(hit[0].IDs)) + } qt.result = &servicepb.QueryResult{ Status: &commonpb.Status{ ErrorCode: 0, @@ -559,14 +565,22 @@ func (qt *QueryTask) PostExecute() error { } for j := 0; j < topk; j++ { + valid := false choice, maxDistance := 0, minFloat32 for q, loc := range locs { // query num, the number of ways to merge + if loc >= len(hits[q][i].IDs) { + continue + } distance := hits[q][i].Scores[loc] - if distance > maxDistance { + if distance > maxDistance || (distance == maxDistance && choice != q) { choice = q maxDistance = distance + valid = true } } + if !valid { + break + } choiceOffset := locs[choice] // check if distance is valid, `invalid` here means very very big, // in this process, distance here is the smallest, so the rest of distance are all invalid diff --git a/internal/proxy/task_scheduler.go b/internal/proxy/task_scheduler.go index b0c7debb2936264dad9db1f5b69ddbb23f802845..c529e2299148157d0fe97d78882948a923671fcb 100644 --- a/internal/proxy/task_scheduler.go +++ b/internal/proxy/task_scheduler.go @@ -14,7 +14,7 @@ import ( type TaskQueue interface { utChan() <-chan int - utEmpty() bool + UTEmpty() bool utFull() bool addUnissuedTask(t task) error FrontUnissuedTask() task @@ -44,7 +44,9 @@ func (queue *BaseTaskQueue) utChan() <-chan int { return queue.utBufChan } -func (queue *BaseTaskQueue) utEmpty() bool { +func (queue *BaseTaskQueue) UTEmpty() bool { + queue.utLock.Lock() + defer queue.utLock.Unlock() return queue.unissuedTasks.Len() == 0 } @@ -316,7 +318,7 @@ func (sched *TaskScheduler) definitionLoop() { case <-sched.ctx.Done(): return case <-sched.DdQueue.utChan(): - if !sched.DdQueue.utEmpty() { + if !sched.DdQueue.UTEmpty() { t := sched.scheduleDdTask() sched.processTask(t, sched.DdQueue) } @@ -331,7 +333,7 @@ func (sched *TaskScheduler) manipulationLoop() { case <-sched.ctx.Done(): return case <-sched.DmQueue.utChan(): - if !sched.DmQueue.utEmpty() { + if !sched.DmQueue.UTEmpty() { t := sched.scheduleDmTask() go sched.processTask(t, sched.DmQueue) } @@ -348,7 +350,7 @@ func (sched *TaskScheduler) queryLoop() { return case <-sched.DqQueue.utChan(): log.Print("scheduler receive query request ...") - if !sched.DqQueue.utEmpty() { + if !sched.DqQueue.UTEmpty() { t := sched.scheduleDqTask() go sched.processTask(t, sched.DqQueue) } else { diff --git a/internal/proxy/timetick.go b/internal/proxy/timetick.go index 34b79ec26c42ea85ca7d5cb7213042b61179b2d9..f47960e3af8a936aec5e6151331f47de8f10e34a 100644 --- a/internal/proxy/timetick.go +++ b/internal/proxy/timetick.go @@ -24,12 +24,12 @@ type timeTick struct { tsoAllocator *allocator.TimestampAllocator tickMsgStream *msgstream.PulsarMsgStream - peerID UniqueID - wg sync.WaitGroup - ctx context.Context - cancel func() - timer *time.Ticker - + peerID UniqueID + wg sync.WaitGroup + ctx context.Context + cancel func() + 
timer *time.Ticker + tickLock sync.RWMutex checkFunc tickCheckFunc } @@ -85,6 +85,8 @@ func (tt *timeTick) tick() error { } else { //log.Printf("proxy send time tick message") } + tt.tickLock.Lock() + defer tt.tickLock.Unlock() tt.lastTick = tt.currentTick return nil } @@ -105,6 +107,8 @@ func (tt *timeTick) tickLoop() { } func (tt *timeTick) LastTick() Timestamp { + tt.tickLock.RLock() + defer tt.tickLock.RUnlock() return tt.lastTick } diff --git a/internal/querynode/collection_replica.go b/internal/querynode/collection_replica.go index 430fd55dd79acb93cc64c83305746f99868f32e1..c035069146627edf0c4af0282eb5e6d8b68ad10d 100644 --- a/internal/querynode/collection_replica.go +++ b/internal/querynode/collection_replica.go @@ -64,11 +64,11 @@ type collectionReplica interface { } type collectionReplicaImpl struct { - mu sync.RWMutex + tSafe tSafe + + mu sync.RWMutex // guards collections and segments collections []*Collection segments map[UniqueID]*Segment - - tSafe tSafe } //----------------------------------------------------------------------------------------------------- tSafe @@ -95,11 +95,10 @@ func (colReplica *collectionReplicaImpl) addCollection(collectionID UniqueID, sc } func (colReplica *collectionReplicaImpl) removeCollection(collectionID UniqueID) error { - collection, err := colReplica.getCollectionByID(collectionID) - colReplica.mu.Lock() defer colReplica.mu.Unlock() + collection, err := colReplica.getCollectionByIDPrivate(collectionID) if err != nil { return err } @@ -128,6 +127,10 @@ func (colReplica *collectionReplicaImpl) getCollectionByID(collectionID UniqueID colReplica.mu.RLock() defer colReplica.mu.RUnlock() + return colReplica.getCollectionByIDPrivate(collectionID) +} + +func (colReplica *collectionReplicaImpl) getCollectionByIDPrivate(collectionID UniqueID) (*Collection, error) { for _, collection := range colReplica.collections { if collection.ID() == collectionID { return collection, nil @@ -164,26 +167,26 @@ func (colReplica *collectionReplicaImpl) hasCollection(collectionID UniqueID) bo //----------------------------------------------------------------------------------------------------- partition func (colReplica *collectionReplicaImpl) getPartitionNum(collectionID UniqueID) (int, error) { - collection, err := colReplica.getCollectionByID(collectionID) + colReplica.mu.RLock() + defer colReplica.mu.RUnlock() + + collection, err := colReplica.getCollectionByIDPrivate(collectionID) if err != nil { return -1, err } - colReplica.mu.RLock() - defer colReplica.mu.RUnlock() - return len(collection.partitions), nil } func (colReplica *collectionReplicaImpl) addPartition(collectionID UniqueID, partitionTag string) error { - collection, err := colReplica.getCollectionByID(collectionID) + colReplica.mu.Lock() + defer colReplica.mu.Unlock() + + collection, err := colReplica.getCollectionByIDPrivate(collectionID) if err != nil { return err } - colReplica.mu.Lock() - defer colReplica.mu.Unlock() - var newPartition = newPartition(partitionTag) *collection.Partitions() = append(*collection.Partitions(), newPartition) @@ -191,14 +194,18 @@ func (colReplica *collectionReplicaImpl) addPartition(collectionID UniqueID, par } func (colReplica *collectionReplicaImpl) removePartition(collectionID UniqueID, partitionTag string) error { - collection, err := colReplica.getCollectionByID(collectionID) + colReplica.mu.Lock() + defer colReplica.mu.Unlock() + + return colReplica.removePartitionPrivate(collectionID, partitionTag) +} + +func (colReplica *collectionReplicaImpl) 
removePartitionPrivate(collectionID UniqueID, partitionTag string) error { + collection, err := colReplica.getCollectionByIDPrivate(collectionID) if err != nil { return err } - colReplica.mu.Lock() - defer colReplica.mu.Unlock() - var tmpPartitions = make([]*Partition, 0) for _, p := range *collection.Partitions() { if p.Tag() == partitionTag { @@ -215,6 +222,7 @@ func (colReplica *collectionReplicaImpl) removePartition(collectionID UniqueID, return nil } +// deprecated func (colReplica *collectionReplicaImpl) addPartitionsByCollectionMeta(colMeta *etcdpb.CollectionMeta) error { if !colReplica.hasCollection(colMeta.ID) { err := errors.New("Cannot find collection, id = " + strconv.FormatInt(colMeta.ID, 10)) @@ -239,13 +247,14 @@ func (colReplica *collectionReplicaImpl) addPartitionsByCollectionMeta(colMeta * } func (colReplica *collectionReplicaImpl) removePartitionsByCollectionMeta(colMeta *etcdpb.CollectionMeta) error { - col, err := colReplica.getCollectionByID(colMeta.ID) + colReplica.mu.Lock() + defer colReplica.mu.Unlock() + + col, err := colReplica.getCollectionByIDPrivate(colMeta.ID) if err != nil { return err } - colReplica.mu.Lock() - pToDel := make([]string, 0) for _, partition := range col.partitions { hasPartition := false @@ -259,10 +268,8 @@ func (colReplica *collectionReplicaImpl) removePartitionsByCollectionMeta(colMet } } - colReplica.mu.Unlock() - for _, tag := range pToDel { - err := colReplica.removePartition(col.ID(), tag) + err := colReplica.removePartitionPrivate(col.ID(), tag) if err != nil { log.Println(err) } @@ -273,14 +280,18 @@ func (colReplica *collectionReplicaImpl) removePartitionsByCollectionMeta(colMet } func (colReplica *collectionReplicaImpl) getPartitionByTag(collectionID UniqueID, partitionTag string) (*Partition, error) { - collection, err := colReplica.getCollectionByID(collectionID) + colReplica.mu.RLock() + defer colReplica.mu.RUnlock() + + return colReplica.getPartitionByTagPrivate(collectionID, partitionTag) +} + +func (colReplica *collectionReplicaImpl) getPartitionByTagPrivate(collectionID UniqueID, partitionTag string) (*Partition, error) { + collection, err := colReplica.getCollectionByIDPrivate(collectionID) if err != nil { return nil, err } - colReplica.mu.RLock() - defer colReplica.mu.RUnlock() - for _, p := range *collection.Partitions() { if p.Tag() == partitionTag { return p, nil @@ -291,15 +302,15 @@ func (colReplica *collectionReplicaImpl) getPartitionByTag(collectionID UniqueID } func (colReplica *collectionReplicaImpl) hasPartition(collectionID UniqueID, partitionTag string) bool { - collection, err := colReplica.getCollectionByID(collectionID) + colReplica.mu.RLock() + defer colReplica.mu.RUnlock() + + collection, err := colReplica.getCollectionByIDPrivate(collectionID) if err != nil { log.Println(err) return false } - colReplica.mu.RLock() - defer colReplica.mu.RUnlock() - for _, p := range *collection.Partitions() { if p.Tag() == partitionTag { return true @@ -332,30 +343,30 @@ func (colReplica *collectionReplicaImpl) getSegmentStatistics() []*internalpb.Se SegmentID: segmentID, MemorySize: currentMemSize, NumRows: segmentNumOfRows, - RecentlyModified: segment.recentlyModified, + RecentlyModified: segment.GetRecentlyModified(), } statisticData = append(statisticData, &stat) - segment.recentlyModified = false + segment.SetRecentlyModified(false) } return statisticData } func (colReplica *collectionReplicaImpl) addSegment(segmentID UniqueID, partitionTag string, collectionID UniqueID) error { - collection, err := 
colReplica.getCollectionByID(collectionID) + colReplica.mu.Lock() + defer colReplica.mu.Unlock() + + collection, err := colReplica.getCollectionByIDPrivate(collectionID) if err != nil { return err } - partition, err2 := colReplica.getPartitionByTag(collectionID, partitionTag) + partition, err2 := colReplica.getPartitionByTagPrivate(collectionID, partitionTag) if err2 != nil { return err2 } - colReplica.mu.Lock() - defer colReplica.mu.Unlock() - var newSegment = newSegment(collection, segmentID, partitionTag, collectionID) colReplica.segments[segmentID] = newSegment diff --git a/internal/querynode/load_index_service.go b/internal/querynode/load_index_service.go index 32b276bcf597ad2619d165ac3fb7a6cac4f87204..10857b4f3a4d313c7073e4841c97750e50055cfb 100644 --- a/internal/querynode/load_index_service.go +++ b/internal/querynode/load_index_service.go @@ -100,8 +100,18 @@ func (lis *loadIndexService) start() { continue } // 1. use msg's index paths to get index bytes - var indexBuffer [][]byte + fmt.Println("start load index") var err error + ok, err = lis.checkIndexReady(indexMsg) + if err != nil { + log.Println(err) + continue + } + if ok { + continue + } + + var indexBuffer [][]byte fn := func() error { indexBuffer, err = lis.loadIndex(indexMsg.IndexPaths) if err != nil { @@ -138,6 +148,13 @@ func (lis *loadIndexService) start() { } } +func (lis *loadIndexService) close() { + if lis.loadIndexMsgStream != nil { + lis.loadIndexMsgStream.Close() + } + lis.cancel() +} + func (lis *loadIndexService) printIndexParams(index []*commonpb.KeyValuePair) { fmt.Println("=================================================") for i := 0; i < len(index); i++ { @@ -190,6 +207,7 @@ func (lis *loadIndexService) updateSegmentIndexStats(indexMsg *msgstream.LoadInd fieldStatsKey := lis.fieldsStatsIDs2Key(targetSegment.collectionID, indexMsg.FieldID) _, ok := lis.fieldIndexes[fieldStatsKey] newIndexParams := indexMsg.IndexParams + // sort index params by key sort.Slice(newIndexParams, func(i, j int) bool { return newIndexParams[i].Key < newIndexParams[j].Key }) if !ok { @@ -215,6 +233,7 @@ func (lis *loadIndexService) updateSegmentIndexStats(indexMsg *msgstream.LoadInd }) } } + targetSegment.setIndexParam(indexMsg.FieldID, indexMsg.IndexParams) return nil } @@ -286,3 +305,15 @@ func (lis *loadIndexService) sendQueryNodeStats() error { fmt.Println("sent field stats") return nil } + +func (lis *loadIndexService) checkIndexReady(loadIndexMsg *msgstream.LoadIndexMsg) (bool, error) { + segment, err := lis.replica.getSegmentByID(loadIndexMsg.SegmentID) + if err != nil { + return false, err + } + if !segment.matchIndexParam(loadIndexMsg.FieldID, loadIndexMsg.IndexParams) { + return false, nil + } + return true, nil + +} diff --git a/internal/querynode/load_index_service_test.go b/internal/querynode/load_index_service_test.go index b214b408242fcee05c61afd8381e7d6d71b1cd21..852d976366c00285f2953f7201e7635cb47f8ad2 100644 --- a/internal/querynode/load_index_service_test.go +++ b/internal/querynode/load_index_service_test.go @@ -22,26 +22,29 @@ import ( "github.com/zilliztech/milvus-distributed/internal/querynode/client" ) -func TestLoadIndexService(t *testing.T) { +func TestLoadIndexService_FloatVector(t *testing.T) { node := newQueryNode() collectionID := rand.Int63n(1000000) segmentID := rand.Int63n(1000000) initTestMeta(t, node, "collection0", collectionID, segmentID) // loadIndexService and statsService + suffix := "-test-search" + strconv.FormatInt(rand.Int63n(1000000), 10) oldSearchChannelNames := 
Params.SearchChannelNames - var newSearchChannelNames []string - for _, channel := range oldSearchChannelNames { - newSearchChannelNames = append(newSearchChannelNames, channel+"new") - } + newSearchChannelNames := makeNewChannelNames(oldSearchChannelNames, suffix) Params.SearchChannelNames = newSearchChannelNames oldSearchResultChannelNames := Params.SearchChannelNames - var newSearchResultChannelNames []string - for _, channel := range oldSearchResultChannelNames { - newSearchResultChannelNames = append(newSearchResultChannelNames, channel+"new") - } + newSearchResultChannelNames := makeNewChannelNames(oldSearchResultChannelNames, suffix) Params.SearchResultChannelNames = newSearchResultChannelNames + + oldLoadIndexChannelNames := Params.LoadIndexChannelNames + newLoadIndexChannelNames := makeNewChannelNames(oldLoadIndexChannelNames, suffix) + Params.LoadIndexChannelNames = newLoadIndexChannelNames + + oldStatsChannelName := Params.StatsChannelName + newStatsChannelNames := makeNewChannelNames([]string{oldStatsChannelName}, suffix) + Params.StatsChannelName = newStatsChannelNames[0] go node.Start() //generate insert data @@ -328,9 +331,319 @@ func TestLoadIndexService(t *testing.T) { } Params.SearchChannelNames = oldSearchChannelNames Params.SearchResultChannelNames = oldSearchResultChannelNames + Params.LoadIndexChannelNames = oldLoadIndexChannelNames + Params.StatsChannelName = oldStatsChannelName fmt.Println("loadIndex floatVector test Done!") defer assert.Equal(t, findFiledStats, true) <-node.queryNodeLoopCtx.Done() node.Close() } + +func TestLoadIndexService_BinaryVector(t *testing.T) { + node := newQueryNode() + collectionID := rand.Int63n(1000000) + segmentID := rand.Int63n(1000000) + initTestMeta(t, node, "collection0", collectionID, segmentID, true) + + // loadIndexService and statsService + suffix := "-test-search-binary" + strconv.FormatInt(rand.Int63n(1000000), 10) + oldSearchChannelNames := Params.SearchChannelNames + newSearchChannelNames := makeNewChannelNames(oldSearchChannelNames, suffix) + Params.SearchChannelNames = newSearchChannelNames + + oldSearchResultChannelNames := Params.SearchChannelNames + newSearchResultChannelNames := makeNewChannelNames(oldSearchResultChannelNames, suffix) + Params.SearchResultChannelNames = newSearchResultChannelNames + + oldLoadIndexChannelNames := Params.LoadIndexChannelNames + newLoadIndexChannelNames := makeNewChannelNames(oldLoadIndexChannelNames, suffix) + Params.LoadIndexChannelNames = newLoadIndexChannelNames + + oldStatsChannelName := Params.StatsChannelName + newStatsChannelNames := makeNewChannelNames([]string{oldStatsChannelName}, suffix) + Params.StatsChannelName = newStatsChannelNames[0] + go node.Start() + + const msgLength = 1000 + const receiveBufSize = 1024 + const DIM = 128 + + // generator index data + var indexRowData []byte + for n := 0; n < msgLength; n++ { + for i := 0; i < DIM/8; i++ { + indexRowData = append(indexRowData, byte(rand.Intn(8))) + } + } + + //generator insert data + var insertRowBlob []*commonpb.Blob + var timestamps []uint64 + var rowIDs []int64 + var hashValues []uint32 + offset := 0 + for n := 0; n < msgLength; n++ { + rowData := make([]byte, 0) + rowData = append(rowData, indexRowData[offset:offset+(DIM/8)]...) + offset += DIM / 8 + age := make([]byte, 4) + binary.LittleEndian.PutUint32(age, 1) + rowData = append(rowData, age...) 
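+			// Row layout note (descriptive comment, based on the loop above): each
+			// generated row packs DIM/8 bytes of one binary vector followed by a
+			// 4-byte little-endian "age" value, so the blob lines up with the binary
+			// collection schema set up by initTestMeta (128-dim binary vector plus a
+			// scalar field).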
+ blob := &commonpb.Blob{ + Value: rowData, + } + insertRowBlob = append(insertRowBlob, blob) + timestamps = append(timestamps, uint64(n)) + rowIDs = append(rowIDs, int64(n)) + hashValues = append(hashValues, uint32(n)) + } + + var insertMsg msgstream.TsMsg = &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: hashValues, + }, + InsertRequest: internalpb.InsertRequest{ + MsgType: internalpb.MsgType_kInsert, + ReqID: 0, + CollectionName: "collection0", + PartitionTag: "default", + SegmentID: segmentID, + ChannelID: int64(0), + ProxyID: int64(0), + Timestamps: timestamps, + RowIDs: rowIDs, + RowData: insertRowBlob, + }, + } + insertMsgPack := msgstream.MsgPack{ + BeginTs: 0, + EndTs: math.MaxUint64, + Msgs: []msgstream.TsMsg{insertMsg}, + } + + // generate timeTick + timeTickMsg := &msgstream.TimeTickMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []uint32{0}, + }, + TimeTickMsg: internalpb.TimeTickMsg{ + MsgType: internalpb.MsgType_kTimeTick, + PeerID: UniqueID(0), + Timestamp: math.MaxUint64, + }, + } + timeTickMsgPack := &msgstream.MsgPack{ + Msgs: []msgstream.TsMsg{timeTickMsg}, + } + + // pulsar produce + insertChannels := Params.InsertChannelNames + ddChannels := Params.DDChannelNames + + insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + insertStream.SetPulsarClient(Params.PulsarAddress) + insertStream.CreatePulsarProducers(insertChannels) + ddStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + ddStream.SetPulsarClient(Params.PulsarAddress) + ddStream.CreatePulsarProducers(ddChannels) + + var insertMsgStream msgstream.MsgStream = insertStream + insertMsgStream.Start() + var ddMsgStream msgstream.MsgStream = ddStream + ddMsgStream.Start() + + err := insertMsgStream.Produce(&insertMsgPack) + assert.NoError(t, err) + err = insertMsgStream.Broadcast(timeTickMsgPack) + assert.NoError(t, err) + err = ddMsgStream.Broadcast(timeTickMsgPack) + assert.NoError(t, err) + + //generate search data and send search msg + searchRowData := indexRowData[42*(DIM/8) : 43*(DIM/8)] + dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"JACCARD\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" + placeholderValue := servicepb.PlaceholderValue{ + Tag: "$0", + Type: servicepb.PlaceholderType_VECTOR_BINARY, + Values: [][]byte{searchRowData}, + } + placeholderGroup := servicepb.PlaceholderGroup{ + Placeholders: []*servicepb.PlaceholderValue{&placeholderValue}, + } + placeGroupByte, err := proto.Marshal(&placeholderGroup) + if err != nil { + log.Print("marshal placeholderGroup failed") + } + query := servicepb.Query{ + CollectionName: "collection0", + PartitionTags: []string{"default"}, + Dsl: dslString, + PlaceholderGroup: placeGroupByte, + } + queryByte, err := proto.Marshal(&query) + if err != nil { + log.Print("marshal query failed") + } + blob := commonpb.Blob{ + Value: queryByte, + } + fn := func(n int64) *msgstream.MsgPack { + searchMsg := &msgstream.SearchMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: []uint32{0}, + }, + SearchRequest: internalpb.SearchRequest{ + MsgType: internalpb.MsgType_kSearch, + ReqID: n, + ProxyID: int64(1), + Timestamp: uint64(msgLength), + ResultChannelID: int64(0), + Query: &blob, + }, + } + return &msgstream.MsgPack{ + Msgs: []msgstream.TsMsg{searchMsg}, + } + } + searchStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + 
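+	// fn(n) wraps a single SearchMsg (request ID n) into a MsgPack so the same
+	// JACCARD query can be produced once before the index is loaded and once
+	// after it (fn(1) and fn(2) below), letting the test compare brute-force
+	// segment search results with the index-backed results.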
searchStream.SetPulsarClient(Params.PulsarAddress) + searchStream.CreatePulsarProducers(newSearchChannelNames) + searchStream.Start() + err = searchStream.Produce(fn(1)) + assert.NoError(t, err) + + //get search result + searchResultStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + searchResultStream.SetPulsarClient(Params.PulsarAddress) + unmarshalDispatcher := msgstream.NewUnmarshalDispatcher() + searchResultStream.CreatePulsarConsumers(newSearchResultChannelNames, "loadIndexTestSubSearchResult2", unmarshalDispatcher, receiveBufSize) + searchResultStream.Start() + searchResult := searchResultStream.Consume() + assert.NotNil(t, searchResult) + unMarshaledHit := servicepb.Hits{} + err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) + assert.Nil(t, err) + + // gen load index message pack + indexParams := make(map[string]string) + indexParams["index_type"] = "BIN_IVF_FLAT" + indexParams["index_mode"] = "cpu" + indexParams["dim"] = "128" + indexParams["k"] = "10" + indexParams["nlist"] = "100" + indexParams["nprobe"] = "10" + indexParams["m"] = "4" + indexParams["nbits"] = "8" + indexParams["metric_type"] = "JACCARD" + indexParams["SLICE_SIZE"] = "4" + + var indexParamsKV []*commonpb.KeyValuePair + for key, value := range indexParams { + indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{ + Key: key, + Value: value, + }) + } + + // generator index + typeParams := make(map[string]string) + typeParams["dim"] = "128" + index, err := indexbuilder.NewCIndex(typeParams, indexParams) + assert.Nil(t, err) + err = index.BuildBinaryVecIndexWithoutIds(indexRowData) + assert.Equal(t, err, nil) + + option := &minioKV.Option{ + Address: Params.MinioEndPoint, + AccessKeyID: Params.MinioAccessKeyID, + SecretAccessKeyID: Params.MinioSecretAccessKey, + UseSSL: Params.MinioUseSSLStr, + BucketName: Params.MinioBucketName, + CreateBucket: true, + } + + minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, option) + assert.Equal(t, err, nil) + //save index to minio + binarySet, err := index.Serialize() + assert.Equal(t, err, nil) + indexPaths := make([]string, 0) + for _, index := range binarySet { + path := strconv.Itoa(int(segmentID)) + "/" + index.Key + indexPaths = append(indexPaths, path) + minioKV.Save(path, string(index.Value)) + } + + //test index search result + indexResult, err := index.QueryOnBinaryVecIndexWithParam(searchRowData, indexParams) + assert.Equal(t, err, nil) + + // create loadIndexClient + fieldID := UniqueID(100) + loadIndexChannelNames := Params.LoadIndexChannelNames + client := client.NewLoadIndexClient(node.queryNodeLoopCtx, Params.PulsarAddress, loadIndexChannelNames) + client.LoadIndex(indexPaths, segmentID, fieldID, "vec", indexParams) + + // init message stream consumer and do checks + statsMs := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, Params.StatsReceiveBufSize) + statsMs.SetPulsarClient(Params.PulsarAddress) + statsMs.CreatePulsarConsumers([]string{Params.StatsChannelName}, Params.MsgChannelSubName, msgstream.NewUnmarshalDispatcher(), Params.StatsReceiveBufSize) + statsMs.Start() + + findFiledStats := false + for { + receiveMsg := msgstream.MsgStream(statsMs).Consume() + assert.NotNil(t, receiveMsg) + assert.NotEqual(t, len(receiveMsg.Msgs), 0) + + for _, msg := range receiveMsg.Msgs { + statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg) + if statsMsg.FieldStats == nil || len(statsMsg.FieldStats) == 0 { + continue + } + findFiledStats = true + assert.Equal(t, ok, true) + 
assert.Equal(t, len(statsMsg.FieldStats), 1) + fieldStats0 := statsMsg.FieldStats[0] + assert.Equal(t, fieldStats0.FieldID, fieldID) + assert.Equal(t, fieldStats0.CollectionID, collectionID) + assert.Equal(t, len(fieldStats0.IndexStats), 1) + indexStats0 := fieldStats0.IndexStats[0] + params := indexStats0.IndexParams + // sort index params by key + sort.Slice(indexParamsKV, func(i, j int) bool { return indexParamsKV[i].Key < indexParamsKV[j].Key }) + indexEqual := node.loadIndexService.indexParamsEqual(params, indexParamsKV) + assert.Equal(t, indexEqual, true) + } + + if findFiledStats { + break + } + } + + err = searchStream.Produce(fn(2)) + assert.NoError(t, err) + searchResult = searchResultStream.Consume() + assert.NotNil(t, searchResult) + err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) + assert.Nil(t, err) + + idsIndex := indexResult.IDs() + idsSegment := unMarshaledHit.IDs + assert.Equal(t, len(idsIndex), len(idsSegment)) + for i := 0; i < len(idsIndex); i++ { + assert.Equal(t, idsIndex[i], idsSegment[i]) + } + Params.SearchChannelNames = oldSearchChannelNames + Params.SearchResultChannelNames = oldSearchResultChannelNames + Params.LoadIndexChannelNames = oldLoadIndexChannelNames + Params.StatsChannelName = oldStatsChannelName + fmt.Println("loadIndex binaryVector test Done!") + + defer assert.Equal(t, findFiledStats, true) + <-node.queryNodeLoopCtx.Done() + node.Close() +} diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index 46a072b7a0ac819d7f32e84de618ebfc87287e1e..acf0f1ab12f3e9482525797ac79621dba9f2c6f1 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -123,6 +123,9 @@ func (node *QueryNode) Close() { if node.searchService != nil { node.searchService.close() } + if node.loadIndexService != nil { + node.loadIndexService.close() + } if node.statsService != nil { node.statsService.close() } diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index 34ec092f5287a77247765d1f411958204d74e2f7..1217fa3da34d8069dd5a41bb08f5af2d557651bc 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -35,7 +35,7 @@ func genTestCollectionMeta(collectionName string, collectionID UniqueID, isBinar TypeParams: []*commonpb.KeyValuePair{ { Key: "dim", - Value: "16", + Value: "128", }, }, IndexParams: []*commonpb.KeyValuePair{ @@ -92,8 +92,12 @@ func genTestCollectionMeta(collectionName string, collectionID UniqueID, isBinar return &collectionMeta } -func initTestMeta(t *testing.T, node *QueryNode, collectionName string, collectionID UniqueID, segmentID UniqueID) { - collectionMeta := genTestCollectionMeta(collectionName, collectionID, false) +func initTestMeta(t *testing.T, node *QueryNode, collectionName string, collectionID UniqueID, segmentID UniqueID, optional ...bool) { + isBinary := false + if len(optional) > 0 { + isBinary = optional[0] + } + collectionMeta := genTestCollectionMeta(collectionName, collectionID, isBinary) schemaBlob := proto.MarshalTextString(collectionMeta.Schema) assert.NotEqual(t, "", schemaBlob) diff --git a/internal/querynode/search_service.go b/internal/querynode/search_service.go index 43512b90192c2a49371ce43c1b292eb858c9e697..5a6ce44a55ba9d171912a0d13cd929039e97be87 100644 --- a/internal/querynode/search_service.go +++ b/internal/querynode/search_service.go @@ -8,6 +8,7 @@ import ( "github.com/opentracing/opentracing-go" oplog "github.com/opentracing/opentracing-go/log" 
"log" + "regexp" "sync" "github.com/golang/protobuf/proto" @@ -26,8 +27,8 @@ type searchService struct { replica collectionReplica tSafeWatcher *tSafeWatcher + serviceableTimeMutex sync.Mutex // guards serviceableTime serviceableTime Timestamp - serviceableTimeMutex sync.Mutex msgBuffer chan msgstream.TsMsg unsolvedMsg []msgstream.TsMsg @@ -235,7 +236,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { return errors.New("unmarshal query failed") } collectionName := query.CollectionName - partitionTags := query.PartitionTags + partitionTagsInQuery := query.PartitionTags collection, err := ss.replica.getCollectionByName(collectionName) if err != nil { span.LogFields(oplog.Error(err)) @@ -260,15 +261,28 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { searchResults := make([]*SearchResult, 0) matchedSegments := make([]*Segment, 0) - for _, partitionTag := range partitionTags { - hasPartition := ss.replica.hasPartition(collectionID, partitionTag) - if !hasPartition { - span.LogFields(oplog.Error(errors.New("search Failed, invalid partitionTag"))) - return errors.New("search Failed, invalid partitionTag") + fmt.Println("search msg's partitionTag = ", partitionTagsInQuery) + + var partitionTagsInCol []string + for _, partition := range collection.partitions { + partitionTag := partition.partitionTag + partitionTagsInCol = append(partitionTagsInCol, partitionTag) + } + var searchPartitionTag []string + if len(partitionTagsInQuery) == 0 { + searchPartitionTag = partitionTagsInCol + } else { + for _, tag := range partitionTagsInCol { + for _, toMatchTag := range partitionTagsInQuery { + re := regexp.MustCompile("^" + toMatchTag + "$") + if re.MatchString(tag) { + searchPartitionTag = append(searchPartitionTag, tag) + } + } } } - for _, partitionTag := range partitionTags { + for _, partitionTag := range searchPartitionTag { partition, _ := ss.replica.getPartitionByTag(collectionID, partitionTag) for _, segment := range partition.segments { //fmt.Println("dsl = ", dsl) @@ -285,30 +299,39 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { } if len(searchResults) <= 0 { - var results = internalpb.SearchResult{ - MsgType: internalpb.MsgType_kSearchResult, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}, - ReqID: searchMsg.ReqID, - ProxyID: searchMsg.ProxyID, - QueryNodeID: ss.queryNodeID, - Timestamp: searchTimestamp, - ResultChannelID: searchMsg.ResultChannelID, - Hits: nil, - } - searchResultMsg := &msgstream.SearchResultMsg{ - BaseMsg: msgstream.BaseMsg{ - MsgCtx: searchMsg.MsgCtx, - HashValues: []uint32{uint32(searchMsg.ResultChannelID)}, - }, - SearchResult: results, - } - err = ss.publishSearchResult(searchResultMsg) - if err != nil { - span.LogFields(oplog.Error(err)) - return err + for _, group := range placeholderGroups { + nq := group.getNumOfQuery() + nilHits := make([][]byte, nq) + hit := &servicepb.Hits{} + for i := 0; i < int(nq); i++ { + bs, err := proto.Marshal(hit) + if err != nil { + span.LogFields(oplog.Error(err)) + return err + } + nilHits[i] = bs + } + var results = internalpb.SearchResult{ + MsgType: internalpb.MsgType_kSearchResult, + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}, + ReqID: searchMsg.ReqID, + ProxyID: searchMsg.ProxyID, + QueryNodeID: ss.queryNodeID, + Timestamp: searchTimestamp, + ResultChannelID: searchMsg.ResultChannelID, + Hits: nilHits, + } + searchResultMsg := &msgstream.SearchResultMsg{ + BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}}, + 
SearchResult: results, + } + err = ss.publishSearchResult(searchResultMsg) + if err != nil { + span.LogFields(oplog.Error(err)) + return err + } + return nil } - span.LogFields(oplog.String("publish search research success", "publish search research success")) - return nil } inReduced := make([]bool, len(searchResults)) @@ -385,9 +408,9 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { } func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error { - span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "publish search result") - defer span.Finish() - msg.SetMsgContext(ctx) + // span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "publish search result") + // defer span.Finish() + // msg.SetMsgContext(ctx) fmt.Println("Public SearchResult", msg.HashKeys()) msgPack := msgstream.MsgPack{} msgPack.Msgs = append(msgPack.Msgs, msg) @@ -396,9 +419,10 @@ func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error { } func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg string) error { - span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "receive search msg") - defer span.Finish() - msg.SetMsgContext(ctx) + // span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "receive search msg") + // defer span.Finish() + // msg.SetMsgContext(ctx) + fmt.Println("Public fail SearchResult!") msgPack := msgstream.MsgPack{} searchMsg, ok := msg.(*msgstream.SearchMsg) if !ok { diff --git a/internal/querynode/segment.go b/internal/querynode/segment.go index 037909286549cf3408c94df2c3db9bb938417d97..2e00200a2fb6f5af96e5886ffb80cd8b531dd9bc 100644 --- a/internal/querynode/segment.go +++ b/internal/querynode/segment.go @@ -13,6 +13,7 @@ package querynode import "C" import ( "strconv" + "sync" "unsafe" "github.com/stretchr/testify/assert" @@ -21,32 +22,53 @@ import ( "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" ) +type indexParam = map[string]string + type Segment struct { - segmentPtr C.CSegmentBase - segmentID UniqueID - partitionTag string // TODO: use partitionID - collectionID UniqueID - lastMemSize int64 - lastRowCount int64 + segmentPtr C.CSegmentBase + segmentID UniqueID + partitionTag string // TODO: use partitionID + collectionID UniqueID + lastMemSize int64 + lastRowCount int64 + + rmMutex sync.Mutex // guards recentlyModified recentlyModified bool + + paramMutex sync.RWMutex // guards indexParam + indexParam map[int64]indexParam } func (s *Segment) ID() UniqueID { return s.segmentID } +func (s *Segment) SetRecentlyModified(modify bool) { + s.rmMutex.Lock() + defer s.rmMutex.Unlock() + s.recentlyModified = modify +} + +func (s *Segment) GetRecentlyModified() bool { + s.rmMutex.Lock() + defer s.rmMutex.Unlock() + return s.recentlyModified +} + //-------------------------------------------------------------------------------------- constructor and destructor func newSegment(collection *Collection, segmentID int64, partitionTag string, collectionID UniqueID) *Segment { /* CSegmentBase newSegment(CPartition partition, unsigned long segment_id); */ + initIndexParam := make(map[int64]indexParam) segmentPtr := C.NewSegment(collection.collectionPtr, C.ulong(segmentID)) var newSegment = &Segment{ segmentPtr: segmentPtr, segmentID: segmentID, partitionTag: partitionTag, collectionID: collectionID, + indexParam: initIndexParam, } return newSegment @@ -161,7 +183,7 @@ func (s *Segment) segmentInsert(offset int64, entityIDs *[]UniqueID, timestamps return errors.New("Insert 
failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) } - s.recentlyModified = true + s.SetRecentlyModified(true) return nil } @@ -256,3 +278,39 @@ func (s *Segment) updateSegmentIndex(loadIndexInfo *LoadIndexInfo) error { return nil } + +func (s *Segment) setIndexParam(fieldID int64, indexParamKv []*commonpb.KeyValuePair) error { + s.paramMutex.Lock() + defer s.paramMutex.Unlock() + indexParamMap := make(indexParam) + if indexParamKv == nil { + return errors.New("loadIndexMsg's indexParam empty") + } + for _, param := range indexParamKv { + indexParamMap[param.Key] = param.Value + } + s.indexParam[fieldID] = indexParamMap + return nil +} + +func (s *Segment) matchIndexParam(fieldID int64, indexParamKv []*commonpb.KeyValuePair) bool { + s.paramMutex.RLock() + defer s.paramMutex.RUnlock() + fieldIndexParam := s.indexParam[fieldID] + if fieldIndexParam == nil { + return false + } + paramSize := len(s.indexParam) + matchCount := 0 + for _, param := range indexParamKv { + value, ok := fieldIndexParam[param.Key] + if !ok { + return false + } + if param.Value != value { + return false + } + matchCount++ + } + return paramSize == matchCount +} diff --git a/internal/querynode/tsafe.go b/internal/querynode/tsafe.go index 27a1b640046a420a5b0aa2701fb4246b936ebc38..60529a3c9868837c3e651dc057b8d407281edeb7 100644 --- a/internal/querynode/tsafe.go +++ b/internal/querynode/tsafe.go @@ -31,7 +31,7 @@ type tSafe interface { } type tSafeImpl struct { - tSafeMu sync.Mutex + tSafeMu sync.Mutex // guards all fields tSafe Timestamp watcherList []*tSafeWatcher } @@ -44,6 +44,8 @@ func newTSafe() tSafe { } func (ts *tSafeImpl) registerTSafeWatcher(t *tSafeWatcher) { + ts.tSafeMu.Lock() + defer ts.tSafeMu.Unlock() ts.watcherList = append(ts.watcherList, t) } @@ -55,8 +57,9 @@ func (ts *tSafeImpl) get() Timestamp { func (ts *tSafeImpl) set(t Timestamp) { ts.tSafeMu.Lock() + defer ts.tSafeMu.Unlock() + ts.tSafe = t - ts.tSafeMu.Unlock() for _, watcher := range ts.watcherList { watcher.notify() } diff --git a/internal/writenode/flow_graph_dd_node.go b/internal/writenode/flow_graph_dd_node.go index 8bd71886ebabe7213c0ed29313027e9c3b31839a..7cb5ef2ab659d832e7233976054d04a6b5c2775a 100644 --- a/internal/writenode/flow_graph_dd_node.go +++ b/internal/writenode/flow_graph_dd_node.go @@ -103,8 +103,6 @@ func (ddNode *ddNode) Operate(in []*Msg) []*Msg { return tsMessages[i].BeginTs() < tsMessages[j].BeginTs() }) - var flush bool = false - var flushSegID UniqueID // do dd tasks for _, msg := range tsMessages { switch msg.Type() { @@ -118,98 +116,100 @@ func (ddNode *ddNode) Operate(in []*Msg) []*Msg { ddNode.dropPartition(msg.(*msgstream.DropPartitionMsg)) case internalPb.MsgType_kFlush: fMsg := msg.(*msgstream.FlushMsg) - flush = true - flushSegID = fMsg.SegmentID + flushSegID := fMsg.SegmentID ddMsg.flushMessages = append(ddMsg.flushMessages, fMsg) + ddNode.flush() + + log.Println(".. manual flush completed ...") + ddlFlushMsg := &ddlFlushSyncMsg{ + flushCompleted: true, + ddlBinlogPathMsg: ddlBinlogPathMsg{ + segID: flushSegID, + }, + } + + ddNode.outCh <- ddlFlushMsg + default: log.Println("Non supporting message type:", msg.Type()) } } // generate binlog - if ddNode.ddBuffer.full() || flush { - log.Println(". 
dd buffer full or receive Flush msg ...") - ddCodec := &storage.DataDefinitionCodec{} - for collectionID, data := range ddNode.ddBuffer.ddData { - // buffer data to binlog - binLogs, err := ddCodec.Serialize(data.timestamps, data.ddRequestString, data.eventTypes) + if ddNode.ddBuffer.full() { + ddNode.flush() + } + + var res Msg = ddNode.ddMsg + return []*Msg{&res} +} + +func (ddNode *ddNode) flush() { + // generate binlog + log.Println(". dd buffer full or receive Flush msg ...") + ddCodec := &storage.DataDefinitionCodec{} + for collectionID, data := range ddNode.ddBuffer.ddData { + // buffer data to binlog + binLogs, err := ddCodec.Serialize(data.timestamps, data.ddRequestString, data.eventTypes) + if err != nil { + log.Println(err) + continue + } + if len(binLogs) != 2 { + log.Println("illegal binLogs") + continue + } + + // binLogs -> minIO/S3 + if len(data.ddRequestString) != len(data.timestamps) || + len(data.timestamps) != len(data.eventTypes) { + log.Println("illegal ddBuffer, failed to save binlog") + continue + } else { + log.Println(".. dd buffer flushing ...") + // Blob key example: + // ${tenant}/data_definition_log/${collection_id}/ts/${log_idx} + // ${tenant}/data_definition_log/${collection_id}/ddl/${log_idx} + keyCommon := path.Join(Params.DdLogRootPath, strconv.FormatInt(collectionID, 10)) + + // save ts binlog + timestampLogIdx, err := ddNode.idAllocator.AllocOne() if err != nil { log.Println(err) - continue } - if len(binLogs) != 2 { - log.Println("illegal binLogs") - continue + timestampKey := path.Join(keyCommon, binLogs[0].GetKey(), strconv.FormatInt(timestampLogIdx, 10)) + err = ddNode.kv.Save(timestampKey, string(binLogs[0].GetValue())) + if err != nil { + log.Println(err) } + log.Println("save ts binlog, key = ", timestampKey) - // binLogs -> minIO/S3 - if len(data.ddRequestString) != len(data.timestamps) || - len(data.timestamps) != len(data.eventTypes) { - log.Println("illegal ddBuffer, failed to save binlog") - continue - } else { - log.Println(".. 
dd buffer flushing ...") - // Blob key example: - // ${tenant}/data_definition_log/${collection_id}/ts/${log_idx} - // ${tenant}/data_definition_log/${collection_id}/ddl/${log_idx} - keyCommon := path.Join(Params.DdLogRootPath, strconv.FormatInt(collectionID, 10)) - - // save ts binlog - timestampLogIdx, err := ddNode.idAllocator.AllocOne() - if err != nil { - log.Println(err) - } - timestampKey := path.Join(keyCommon, binLogs[0].GetKey(), strconv.FormatInt(timestampLogIdx, 10)) - err = ddNode.kv.Save(timestampKey, string(binLogs[0].GetValue())) - if err != nil { - log.Println(err) - } - log.Println("save ts binlog, key = ", timestampKey) - - // save dd binlog - ddLogIdx, err := ddNode.idAllocator.AllocOne() - if err != nil { - log.Println(err) - } - ddKey := path.Join(keyCommon, binLogs[1].GetKey(), strconv.FormatInt(ddLogIdx, 10)) - err = ddNode.kv.Save(ddKey, string(binLogs[1].GetValue())) - if err != nil { - log.Println(err) - } - log.Println("save dd binlog, key = ", ddKey) - - ddlFlushMsg := &ddlFlushSyncMsg{ - flushCompleted: false, - ddlBinlogPathMsg: ddlBinlogPathMsg{ - collID: collectionID, - paths: []string{timestampKey, ddKey}, - }, - } - - ddNode.outCh <- ddlFlushMsg + // save dd binlog + ddLogIdx, err := ddNode.idAllocator.AllocOne() + if err != nil { + log.Println(err) + } + ddKey := path.Join(keyCommon, binLogs[1].GetKey(), strconv.FormatInt(ddLogIdx, 10)) + err = ddNode.kv.Save(ddKey, string(binLogs[1].GetValue())) + if err != nil { + log.Println(err) + } + log.Println("save dd binlog, key = ", ddKey) + + ddlFlushMsg := &ddlFlushSyncMsg{ + flushCompleted: false, + ddlBinlogPathMsg: ddlBinlogPathMsg{ + collID: collectionID, + paths: []string{timestampKey, ddKey}, + }, } + ddNode.outCh <- ddlFlushMsg } - // clear buffer - ddNode.ddBuffer.ddData = make(map[UniqueID]*ddData) - } - - if flush { - - log.Println(".. 
manual flush completed ...") - ddlFlushMsg := &ddlFlushSyncMsg{ - flushCompleted: true, - ddlBinlogPathMsg: ddlBinlogPathMsg{ - segID: flushSegID, - }, - } - - ddNode.outCh <- ddlFlushMsg } - - var res Msg = ddNode.ddMsg - return []*Msg{&res} + // clear buffer + ddNode.ddBuffer.ddData = make(map[UniqueID]*ddData) } func (ddNode *ddNode) createCollection(msg *msgstream.CreateCollectionMsg) { diff --git a/internal/writenode/flush_sync_service.go b/internal/writenode/flush_sync_service.go index 7b6587b711c5d79c3bde04573e889d2024c7b287..77e17bda86b2e7db6e4bea3add0d73e34e4f35e2 100644 --- a/internal/writenode/flush_sync_service.go +++ b/internal/writenode/flush_sync_service.go @@ -112,6 +112,7 @@ func (fService *flushSyncService) start() { fService.completeInsertFlush(insertFlushMsg.segID) if fService.FlushCompleted(insertFlushMsg.segID) { + log.Printf("Seg(%d) flush completed.", insertFlushMsg.segID) fService.metaTable.CompleteFlush(insertFlushMsg.ts, insertFlushMsg.segID) } } diff --git a/internal/writenode/flush_sync_service_test.go b/internal/writenode/flush_sync_service_test.go index 7da80503d6191fcbe38aa299829c2c8e69181e16..59e0442f4244808fda0832dfcff0df51aa02375c 100644 --- a/internal/writenode/flush_sync_service_test.go +++ b/internal/writenode/flush_sync_service_test.go @@ -90,7 +90,7 @@ func TestFlushSyncService_Start(t *testing.T) { } for { - if len(ddChan) == 0 && len(insertChan) == 0 { + if len(ddChan) == 0 && len(insertChan) == 0 && fService.FlushCompleted(SegID) { break } } @@ -117,10 +117,6 @@ func TestFlushSyncService_Start(t *testing.T) { assert.NoError(t, err) assert.Equal(t, true, cp) - cp, err = fService.metaTable.checkFlushComplete(SegID) - assert.NoError(t, err) - assert.Equal(t, true, cp) - }) } diff --git a/internal/writenode/meta_table.go b/internal/writenode/meta_table.go index ab5ac04461587a53ec3a4c0ea6ec4a7bbee1d96c..ea7828874f935da9e41880c46b87d55595f222b5 100644 --- a/internal/writenode/meta_table.go +++ b/internal/writenode/meta_table.go @@ -171,8 +171,8 @@ func (mt *metaTable) addSegmentFlush(segmentID UniqueID, timestamp Timestamp) er } func (mt *metaTable) getFlushCloseTime(segmentID UniqueID) (Timestamp, error) { - mt.lock.Lock() - defer mt.lock.Unlock() + mt.lock.RLock() + defer mt.lock.RUnlock() meta, ok := mt.segID2FlushMeta[segmentID] if !ok { return typeutil.ZeroTimestamp, errors.Errorf("segment not exists with ID = " + strconv.FormatInt(segmentID, 10)) @@ -181,8 +181,8 @@ func (mt *metaTable) getFlushCloseTime(segmentID UniqueID) (Timestamp, error) { } func (mt *metaTable) getFlushOpenTime(segmentID UniqueID) (Timestamp, error) { - mt.lock.Lock() - defer mt.lock.Unlock() + mt.lock.RLock() + defer mt.lock.RUnlock() meta, ok := mt.segID2FlushMeta[segmentID] if !ok { return typeutil.ZeroTimestamp, errors.Errorf("segment not exists with ID = " + strconv.FormatInt(segmentID, 10)) @@ -191,8 +191,8 @@ func (mt *metaTable) getFlushOpenTime(segmentID UniqueID) (Timestamp, error) { } func (mt *metaTable) checkFlushComplete(segmentID UniqueID) (bool, error) { - mt.lock.Lock() - defer mt.lock.Unlock() + mt.lock.RLock() + defer mt.lock.RUnlock() meta, ok := mt.segID2FlushMeta[segmentID] if !ok { return false, errors.Errorf("segment not exists with ID = " + strconv.FormatInt(segmentID, 10)) @@ -201,9 +201,8 @@ func (mt *metaTable) checkFlushComplete(segmentID UniqueID) (bool, error) { } func (mt *metaTable) getSegBinlogPaths(segmentID UniqueID) (map[int64][]string, error) { - mt.lock.Lock() - defer mt.lock.Unlock() - + mt.lock.RLock() + defer mt.lock.RUnlock() meta, 
ok := mt.segID2FlushMeta[segmentID] if !ok { return nil, errors.Errorf("segment not exists with ID = " + strconv.FormatInt(segmentID, 10)) @@ -216,9 +215,8 @@ func (mt *metaTable) getSegBinlogPaths(segmentID UniqueID) (map[int64][]string, } func (mt *metaTable) getDDLBinlogPaths(collID UniqueID) (map[UniqueID][]string, error) { - mt.lock.Lock() - defer mt.lock.Unlock() - + mt.lock.RLock() + defer mt.lock.RUnlock() meta, ok := mt.collID2DdlMeta[collID] if !ok { return nil, errors.Errorf("collection not exists with ID = " + strconv.FormatInt(collID, 10)) diff --git a/scripts/init_devcontainer.sh b/scripts/init_devcontainer.sh index 6b30db984ee3e073987a8e473f4589ac16ac196e..80321670b9968754eca42575da89b7cbd78327a7 100755 --- a/scripts/init_devcontainer.sh +++ b/scripts/init_devcontainer.sh @@ -8,6 +8,15 @@ while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symli done ROOT_DIR="$( cd -P "$( dirname "$SOURCE" )/.." && pwd )" +unameOut="$(uname -s)" +case "${unameOut}" in + Linux*) machine=Linux;; + Darwin*) machine=Mac;; + CYGWIN*) machine=Cygwin;; + MINGW*) machine=MinGw;; + *) machine="UNKNOWN:${unameOut}" +esac + # Attempt to run in the container with the same UID/GID as we have on the host, # as this results in the correct permissions on files created in the shared # volumes. This isn't always possible, however, as IDs less than 100 are @@ -21,8 +30,14 @@ gid=$(id -g) [ "$uid" -lt 500 ] && uid=501 [ "$gid" -lt 500 ] && gid=$uid -awk 'c&&c--{sub(/^/,"#")} /# Build devcontainer/{c=5} 1' $ROOT_DIR/docker-compose.yml > $ROOT_DIR/docker-compose-vscode.yml.bak +awk 'c&&c--{sub(/^/,"#")} /# Build devcontainer/{c=5} 1' $ROOT_DIR/docker-compose.yml > $ROOT_DIR/docker-compose-vscode.yml.tmp + +awk 'c&&c--{sub(/^/,"#")} /# Command/{c=3} 1' $ROOT_DIR/docker-compose-vscode.yml.tmp > $ROOT_DIR/docker-compose-vscode.yml -awk 'c&&c--{sub(/^/,"#")} /# Command/{c=3} 1' $ROOT_DIR/docker-compose-vscode.yml.bak > $ROOT_DIR/docker-compose-vscode.yml +rm $ROOT_DIR/docker-compose-vscode.yml.tmp -sed -i '.bak' "s/# user: {{ CURRENT_ID }}/user: \"$uid:$gid\"/g" $ROOT_DIR/docker-compose-vscode.yml +if [ "${machine}" == "Mac" ];then + sed -i '' "s/# user: {{ CURRENT_ID }}/user: \"$uid:$gid\"/g" $ROOT_DIR/docker-compose-vscode.yml +else + sed -i "s/# user: {{ CURRENT_ID }}/user: \"$uid:$gid\"/g" $ROOT_DIR/docker-compose-vscode.yml +fi \ No newline at end of file diff --git a/scripts/run_go_unittest.sh b/scripts/run_go_unittest.sh index f4b4bac81f072b6c8353e675479c7c7fd149a450..b48b9f6717d353b2f7de2f31af7381a571122d39 100755 --- a/scripts/run_go_unittest.sh +++ b/scripts/run_go_unittest.sh @@ -8,13 +8,15 @@ while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symli SOURCE="$(readlink "$SOURCE")" [[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE" # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located done -SCRIPTS_DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )" +ROOT_DIR="$( cd -P "$( dirname "$SOURCE" )/.." && pwd )" # ignore Minio,S3 unittes -MILVUS_DIR="${SCRIPTS_DIR}/../internal/" +MILVUS_DIR="${ROOT_DIR}/internal/" echo $MILVUS_DIR -go test -cover "${MILVUS_DIR}/kv/..." -failfast -go test -cover "${MILVUS_DIR}/proxy/..." -failfast -go test -cover "${MILVUS_DIR}/writenode/..." -failfast -go test -cover "${MILVUS_DIR}/msgstream/..." "${MILVUS_DIR}/master/..." "${MILVUS_DIR}/querynode/..." "${MILVUS_DIR}/storage" "${MILVUS_DIR}/util/..." -failfast -#go test -cover "${MILVUS_DIR}/kv/..." "${MILVUS_DIR}/msgstream/..." 
"${MILVUS_DIR}/master/..." "${MILVUS_DIR}/querynode/..." -failfast + +go test -race -cover "${MILVUS_DIR}/kv/..." -failfast +go test -race -cover "${MILVUS_DIR}/proxy/..." -failfast +go test -race -cover "${MILVUS_DIR}/writenode/..." -failfast +go test -race -cover "${MILVUS_DIR}/master/..." -failfast +go test -cover "${MILVUS_DIR}/msgstream/..." "${MILVUS_DIR}/querynode/..." "${MILVUS_DIR}/storage" "${MILVUS_DIR}/util/..." -failfast +#go test -race -cover "${MILVUS_DIR}/kv/..." "${MILVUS_DIR}/msgstream/..." "${MILVUS_DIR}/master/..." "${MILVUS_DIR}/querynode/..." -failfast diff --git a/tests/python/requirements.txt b/tests/python/requirements.txt index 9bee462aee885e91e3b5735d9a8c2ef4f58a45df..98dd4e02c717761d411870a5bc05d3a0a68e2c74 100644 --- a/tests/python/requirements.txt +++ b/tests/python/requirements.txt @@ -4,5 +4,5 @@ numpy==1.18.1 pytest==5.3.4 pytest-cov==2.8.1 pytest-timeout==1.3.4 -pymilvus-distributed==0.0.10 +pymilvus-distributed==0.0.14 sklearn==0.0 diff --git a/tests/python/test_bulk_insert.py b/tests/python/test_bulk_insert.py index d8ff5cbd7abd807492311669b9dd4cd159965b21..97f3724ec03411fd7d509fb03556d27bf403c9b1 100644 --- a/tests/python/test_bulk_insert.py +++ b/tests/python/test_bulk_insert.py @@ -101,7 +101,6 @@ class TestInsertBase: connect.flush([collection]) connect.drop_collection(collection) - @pytest.mark.skip("create index") @pytest.mark.timeout(ADD_TIMEOUT) def test_insert_create_index(self, connect, collection, get_simple_index): ''' @@ -119,7 +118,6 @@ class TestInsertBase: if field["name"] == field_name: assert field["indexes"][0] == get_simple_index - @pytest.mark.skip("create index") @pytest.mark.timeout(ADD_TIMEOUT) def test_insert_create_index_new(self, connect, collection, get_simple_index): ''' @@ -137,7 +135,6 @@ class TestInsertBase: if field["name"] == field_name: assert field["indexes"][0] == get_simple_index - @pytest.mark.skip("create index") @pytest.mark.timeout(ADD_TIMEOUT) def test_insert_after_create_index(self, connect, collection, get_simple_index): ''' @@ -154,7 +151,6 @@ class TestInsertBase: if field["name"] == field_name: assert field["indexes"][0] == get_simple_index - # @pytest.mark.skip(" later ") @pytest.mark.timeout(ADD_TIMEOUT) def test_insert_search(self, connect, collection): ''' @@ -645,7 +641,6 @@ class TestInsertBinary: connect.flush([binary_collection]) assert connect.count_entities(binary_collection) == default_nb - @pytest.mark.skip("create index") def test_insert_binary_after_create_index(self, connect, binary_collection, get_binary_index): ''' target: test insert binary entities after build index @@ -662,7 +657,6 @@ class TestInsertBinary: if field["name"] == binary_field_name: assert field["indexes"][0] == get_binary_index - @pytest.mark.skip("create index") @pytest.mark.timeout(ADD_TIMEOUT) def test_insert_binary_create_index(self, connect, binary_collection, get_binary_index): ''' @@ -863,7 +857,6 @@ class TestInsertMultiCollections: connect.flush([collection_name]) assert len(ids) == 1 - @pytest.mark.skip("create index") @pytest.mark.timeout(ADD_TIMEOUT) def test_create_index_insert_vector_another(self, connect, collection, get_simple_index): ''' @@ -877,7 +870,7 @@ class TestInsertMultiCollections: ids = connect.bulk_insert(collection, default_entity) connect.drop_collection(collection_name) - @pytest.mark.skip("create index") + @pytest.mark.skip("count entities") @pytest.mark.timeout(ADD_TIMEOUT) def test_insert_vector_create_index_another(self, connect, collection, get_simple_index): ''' @@ -892,7 +885,7 @@ 
class TestInsertMultiCollections: count = connect.count_entities(collection_name) assert count == 0 - @pytest.mark.skip("create index") + @pytest.mark.skip("count entities") @pytest.mark.timeout(ADD_TIMEOUT) def test_insert_vector_sleep_create_index_another(self, connect, collection, get_simple_index): ''' diff --git a/tests/python/test_index.py b/tests/python/test_index.py index 687e6573a68070eaee449f7d0cf425cf971fb82b..8162fa37fe82c30a55a1a423d89b45b1b128cc6f 100644 --- a/tests/python/test_index.py +++ b/tests/python/test_index.py @@ -17,19 +17,19 @@ query, query_vecs = gen_query_vectors(field_name, default_entities, default_top_ default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} -@pytest.mark.skip("wait for debugging...") +# @pytest.mark.skip("wait for debugging...") class TestIndexBase: @pytest.fixture( scope="function", params=gen_simple_index() ) def get_simple_index(self, request, connect): + import copy logging.getLogger().info(request.param) - # TODO: Determine the service mode - # if str(connect._cmd("mode")) == "CPU": - if request.param["index_type"] in index_cpu_not_support(): - pytest.skip("sq8h not support in CPU mode") - return request.param + if str(connect._cmd("mode")) == "CPU": + if request.param["index_type"] in index_cpu_not_support(): + pytest.skip("sq8h not support in CPU mode") + return copy.deepcopy(request.param) @pytest.fixture( scope="function", @@ -132,7 +132,7 @@ class TestIndexBase: ''' ids = connect.bulk_insert(collection, default_entities) connect.create_index(collection, field_name, get_simple_index) - logging.getLogger().info(connect.get_collection_stats(collection)) + # logging.getLogger().info(connect.get_collection_stats(collection)) nq = get_nq index_type = get_simple_index["index_type"] search_param = get_search_param(index_type) @@ -140,6 +140,7 @@ class TestIndexBase: res = connect.search(collection, query) assert len(res) == nq + @pytest.mark.skip("can't_pass_ci") @pytest.mark.timeout(BUILD_TIMEOUT) @pytest.mark.level(2) def test_create_index_multithread(self, connect, collection, args): @@ -175,6 +176,7 @@ class TestIndexBase: with pytest.raises(Exception) as e: connect.create_index(collection_name, field_name, default_index) + @pytest.mark.skip("count_entries") @pytest.mark.level(2) @pytest.mark.timeout(BUILD_TIMEOUT) def test_create_index_insert_flush(self, connect, collection, get_simple_index): @@ -201,6 +203,7 @@ class TestIndexBase: connect.create_index(collection, field_name, get_simple_index) # TODO: + @pytest.mark.skip("get_collection_stats") @pytest.mark.level(2) @pytest.mark.timeout(BUILD_TIMEOUT) def test_create_different_index_repeatedly(self, connect, collection): @@ -275,7 +278,7 @@ class TestIndexBase: ids = connect.bulk_insert(collection, default_entities) get_simple_index["metric_type"] = metric_type connect.create_index(collection, field_name, get_simple_index) - logging.getLogger().info(connect.get_collection_stats(collection)) + # logging.getLogger().info(connect.get_collection_stats(collection)) nq = get_nq index_type = get_simple_index["index_type"] search_param = get_search_param(index_type) @@ -320,6 +323,7 @@ class TestIndexBase: with pytest.raises(Exception) as e: connect.create_index(collection_name, field_name, default_index) + @pytest.mark.skip("count_entries") @pytest.mark.timeout(BUILD_TIMEOUT) def test_create_index_no_vectors_insert_ip(self, connect, collection, get_simple_index): ''' @@ -347,6 +351,8 @@ class TestIndexBase: connect.create_index(collection, field_name, 
get_simple_index) # TODO: + + @pytest.mark.skip("get_collection_stats") @pytest.mark.level(2) @pytest.mark.timeout(BUILD_TIMEOUT) def test_create_different_index_repeatedly_ip(self, connect, collection): @@ -369,6 +375,7 @@ class TestIndexBase: ****************************************************************** """ + @pytest.mark.skip("drop_index") def test_drop_index(self, connect, collection, get_simple_index): ''' target: test drop index interface @@ -382,6 +389,7 @@ class TestIndexBase: # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type assert not stats["partitions"][0]["segments"] + @pytest.mark.skip("drop_index") @pytest.mark.level(2) def test_drop_index_repeatly(self, connect, collection, get_simple_index): ''' @@ -398,6 +406,7 @@ class TestIndexBase: # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type assert not stats["partitions"][0]["segments"] + @pytest.mark.skip("drop_index") @pytest.mark.level(2) def test_drop_index_without_connect(self, dis_connect, collection): ''' @@ -408,6 +417,7 @@ class TestIndexBase: with pytest.raises(Exception) as e: dis_connect.drop_index(collection, field_name) + @pytest.mark.skip("drop_index") def test_drop_index_collection_not_existed(self, connect): ''' target: test drop index interface when collection name not existed @@ -419,6 +429,7 @@ class TestIndexBase: with pytest.raises(Exception) as e: connect.drop_index(collection_name, field_name) + @pytest.mark.skip("drop_index") def test_drop_index_collection_not_create(self, connect, collection): ''' target: test drop index interface when index not created @@ -429,6 +440,7 @@ class TestIndexBase: # no create index connect.drop_index(collection, field_name) + @pytest.mark.skip("drop_index") @pytest.mark.level(2) def test_create_drop_index_repeatly(self, connect, collection, get_simple_index): ''' @@ -440,6 +452,7 @@ class TestIndexBase: connect.create_index(collection, field_name, get_simple_index) connect.drop_index(collection, field_name) + @pytest.mark.skip("drop_index") def test_drop_index_ip(self, connect, collection, get_simple_index): ''' target: test drop index interface @@ -454,6 +467,7 @@ class TestIndexBase: # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type assert not stats["partitions"][0]["segments"] + @pytest.mark.skip("drop_index") @pytest.mark.level(2) def test_drop_index_repeatly_ip(self, connect, collection, get_simple_index): ''' @@ -471,6 +485,7 @@ class TestIndexBase: # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type assert not stats["partitions"][0]["segments"] + @pytest.mark.skip("drop_index") @pytest.mark.level(2) def test_drop_index_without_connect_ip(self, dis_connect, collection): ''' @@ -481,6 +496,7 @@ class TestIndexBase: with pytest.raises(Exception) as e: dis_connect.drop_index(collection, field_name) + @pytest.mark.skip("drop_index") def test_drop_index_collection_not_create_ip(self, connect, collection): ''' target: test drop index interface when index not created @@ -491,6 +507,7 @@ class TestIndexBase: # no create index connect.drop_index(collection, field_name) + @pytest.mark.skip("drop_index") @pytest.mark.level(2) def test_create_drop_index_repeatly_ip(self, connect, collection, get_simple_index): ''' @@ -504,7 +521,6 @@ class TestIndexBase: connect.drop_index(collection, field_name) -@pytest.mark.skip("binary") class TestIndexBinary: @pytest.fixture( scope="function", @@ -590,6 +606,7 @@ class TestIndexBinary: res = 
connect.search(binary_collection, query, search_params=search_param) assert len(res) == nq + @pytest.mark.skip("get status for build index failed") @pytest.mark.timeout(BUILD_TIMEOUT) def test_create_index_invalid_metric_type_binary(self, connect, binary_collection, get_l2_index): ''' @@ -613,6 +630,7 @@ class TestIndexBinary: ****************************************************************** """ + @pytest.mark.skip("get_collection_stats does not impl") def test_get_index_info(self, connect, binary_collection, get_jaccard_index): ''' target: test describe index interface @@ -632,6 +650,7 @@ class TestIndexBinary: if "index_type" in file: assert file["index_type"] == get_jaccard_index["index_type"] + @pytest.mark.skip("get_collection_stats does not impl") def test_get_index_info_partition(self, connect, binary_collection, get_jaccard_index): ''' target: test describe index interface @@ -660,6 +679,7 @@ class TestIndexBinary: ****************************************************************** """ + @pytest.mark.skip("get_collection_stats and drop_index do not impl") def test_drop_index(self, connect, binary_collection, get_jaccard_index): ''' target: test drop index interface @@ -674,6 +694,7 @@ class TestIndexBinary: # assert stats["partitions"][0]["segments"][0]["index_name"] == default_index_type assert not stats["partitions"][0]["segments"] + @pytest.mark.skip("get_collection_stats does not impl") def test_drop_index_partition(self, connect, binary_collection, get_jaccard_index): ''' target: test drop index interface @@ -699,7 +720,6 @@ class TestIndexBinary: assert False -@pytest.mark.skip("wait for debugging...") class TestIndexInvalid(object): """ Test create / describe / drop index interfaces with invalid collection names @@ -738,7 +758,6 @@ class TestIndexInvalid(object): connect.create_index(collection, field_name, get_simple_index) -@pytest.mark.skip("wait for debugging...") class TestIndexAsync: @pytest.fixture(scope="function", autouse=True) def skip_http_check(self, args): diff --git a/tests/python/test_insert.py b/tests/python/test_insert.py index 14948a4c1f1eb42c6d76344bd245f0a09306975b..3c0420aa1057cd0fd6ff3329f84402d04f7043c0 100644 --- a/tests/python/test_insert.py +++ b/tests/python/test_insert.py @@ -101,7 +101,6 @@ class TestInsertBase: connect.flush([collection]) connect.drop_collection(collection) - @pytest.mark.skip("create_index") @pytest.mark.timeout(ADD_TIMEOUT) def test_insert_create_index(self, connect, collection, get_simple_index): ''' @@ -119,7 +118,6 @@ class TestInsertBase: if field["name"] == field_name: assert field["indexes"][0] == get_simple_index - @pytest.mark.skip("create_index") @pytest.mark.timeout(ADD_TIMEOUT) def test_insert_after_create_index(self, connect, collection, get_simple_index): ''' @@ -136,7 +134,6 @@ class TestInsertBase: if field["name"] == field_name: assert field["indexes"][0] == get_simple_index - @pytest.mark.skip(" todo fix search") @pytest.mark.timeout(ADD_TIMEOUT) def test_insert_search(self, connect, collection): ''' @@ -313,7 +310,6 @@ class TestInsertBinary: connect.flush([binary_collection]) assert connect.count_entities(binary_collection) == default_nb - @pytest.mark.skip("create index") def test_insert_binary_after_create_index(self, connect, binary_collection, get_binary_index): ''' target: test insert binary entities after build index @@ -330,7 +326,6 @@ class TestInsertBinary: if field["name"] == binary_field_name: assert field["indexes"][0] == get_binary_index - @pytest.mark.skip("create index") 
@pytest.mark.timeout(ADD_TIMEOUT) def test_insert_binary_create_index(self, connect, binary_collection, get_binary_index): ''' diff --git a/tests/python/test_search.py b/tests/python/test_search.py index 7ecdcb54958fd2aa9bbba18ed73cbda6ba5cdb4d..d23e4ff0f884cb4f222fc3cebe1f69fc93d8a46b 100644 --- a/tests/python/test_search.py +++ b/tests/python/test_search.py @@ -89,10 +89,11 @@ class TestSearchBase: params=gen_simple_index() ) def get_simple_index(self, request, connect): + import copy if str(connect._cmd("mode")) == "CPU": if request.param["index_type"] in index_cpu_not_support(): pytest.skip("sq8h not support in CPU mode") - return request.param + return copy.deepcopy(request.param) @pytest.fixture( scope="function", @@ -255,7 +256,7 @@ class TestSearchBase: assert res2[0][0].id == res[0][1].id assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64") - @pytest.mark.skip("search_after_index") + # Pass @pytest.mark.level(2) def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq): ''' @@ -302,7 +303,7 @@ class TestSearchBase: assert len(res) == nq assert len(res[0]) == default_top_k - @pytest.mark.skip("search_index_partition") + # pass @pytest.mark.level(2) def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq): ''' @@ -333,7 +334,7 @@ class TestSearchBase: res = connect.search(collection, query, partition_tags=[default_tag]) assert len(res) == nq - @pytest.mark.skip("search_index_partition_B") + # PASS @pytest.mark.level(2) def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq): ''' @@ -383,7 +384,7 @@ class TestSearchBase: assert len(res) == nq assert len(res[0]) == 0 - @pytest.mark.skip("search_index_partitions") + # PASS @pytest.mark.level(2) def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k): ''' @@ -417,7 +418,7 @@ class TestSearchBase: assert res[0]._distances[0] > epsilon assert res[1]._distances[0] > epsilon - @pytest.mark.skip("search_index_partitions_B") + # Pass @pytest.mark.level(2) def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k): ''' @@ -451,7 +452,7 @@ class TestSearchBase: assert res[0]._distances[0] < epsilon assert res[1]._distances[0] < epsilon - # + # pass # test for ip metric # # TODO: reopen after we supporting ip flat @@ -477,7 +478,7 @@ class TestSearchBase: with pytest.raises(Exception) as e: res = connect.search(collection, query) - @pytest.mark.skip("search_ip_after_index") + # PASS @pytest.mark.level(2) def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq): ''' @@ -506,7 +507,6 @@ class TestSearchBase: assert check_id_result(res[0], ids[0]) assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0]) - @pytest.mark.skip("search_ip_index_partition") @pytest.mark.level(2) def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq): ''' @@ -539,7 +539,7 @@ class TestSearchBase: res = connect.search(collection, query, partition_tags=[default_tag]) assert len(res) == nq - @pytest.mark.skip("search_ip_index_partitions") + # PASS @pytest.mark.level(2) def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k): ''' @@ -618,7 +618,7 @@ class TestSearchBase: res = connect.search(collection, query) assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0]) - 
@pytest.mark.skip("search_distance_l2_after_index") + # Pass def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index): ''' target: search collection, and check the result: distance @@ -672,7 +672,7 @@ class TestSearchBase: res = connect.search(collection, query) assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon - @pytest.mark.skip("search_distance_ip_after_index") + # Pass def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index): ''' target: search collection, and check the result: distance @@ -942,7 +942,7 @@ class TestSearchBase: assert res[i]._distances[0] < epsilon assert res[i]._distances[1] > epsilon - @pytest.mark.skip("query_entities_with_field_less_than_top_k") + @pytest.mark.skip("test_query_entities_with_field_less_than_top_k") def test_query_entities_with_field_less_than_top_k(self, connect, id_collection): """ target: test search with field, and let return entities less than topk @@ -1741,8 +1741,7 @@ class TestSearchInvalid(object): def get_search_params(self, request): yield request.param - # TODO: reopen after we supporting create index - @pytest.mark.skip("search_with_invalid_params") + # Pass @pytest.mark.level(2) def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params): ''' @@ -1763,8 +1762,7 @@ class TestSearchInvalid(object): with pytest.raises(Exception) as e: res = connect.search(collection, query) - # TODO: reopen after we supporting binary type - @pytest.mark.skip("search_with_invalid_params_binary") + # pass @pytest.mark.level(2) def test_search_with_invalid_params_binary(self, connect, binary_collection): ''' @@ -1783,7 +1781,7 @@ class TestSearchInvalid(object): with pytest.raises(Exception) as e: res = connect.search(binary_collection, query) - @pytest.mark.skip("search_with_empty_params") + # Pass @pytest.mark.level(2) def test_search_with_empty_params(self, connect, collection, args, get_simple_index): ''' diff --git a/tests/python/utils.py b/tests/python/utils.py index dfc99eb37efd8f8c2f6d48b8627e5aac9db0d07d..282f7e36f765def05538aa19a9591698389dfcdd 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -55,7 +55,7 @@ default_index_params = [ {"nlist": 128}, {"nlist": 128}, {"nlist": 128}, - {"nlist": 128, "m": 16}, + {"nlist": 128, "m": 16, "nbits": 8}, {"M": 48, "efConstruction": 500}, # {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50}, {"n_trees": 50}, @@ -281,9 +281,9 @@ def gen_entities(nb, is_normal=False): def gen_entities_new(nb, is_normal=False): vectors = gen_vectors(nb, default_dim, is_normal) entities = [ - {"name": "int64", "values": [i for i in range(nb)]}, - {"name": "float", "values": [float(i) for i in range(nb)]}, - {"name": default_float_vec_field_name, "values": vectors} + {"name": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]}, + {"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]}, + {"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": vectors} ] return entities