From 861576f77a6ebf00b16bff55c410701b08430914 Mon Sep 17 00:00:00 2001 From: xige-16 <xi.ge@zilliz.com> Date: Mon, 22 Feb 2021 10:44:38 +0800 Subject: [PATCH] Checkout field ids when load segment in query node Signed-off-by: xige-16 <xi.ge@zilliz.com> --- internal/proxynode/paramtable.go | 1 + internal/querynode/load_service.go | 7 ++++++- internal/querynode/load_service_test.go | 2 +- internal/querynode/segment_loader.go | 16 +++++++++------- tests/python/test_index.py | 2 +- 5 files changed, 18 insertions(+), 10 deletions(-) diff --git a/internal/proxynode/paramtable.go b/internal/proxynode/paramtable.go index a3ca31df1..307c47713 100644 --- a/internal/proxynode/paramtable.go +++ b/internal/proxynode/paramtable.go @@ -177,6 +177,7 @@ func (pt *ParamTable) initQueryNodeIDList() []UniqueID { } ret = append(ret, UniqueID(v)) } + pt.QueryNodeIDList = ret return ret } diff --git a/internal/querynode/load_service.go b/internal/querynode/load_service.go index 003904226..73043849f 100644 --- a/internal/querynode/load_service.go +++ b/internal/querynode/load_service.go @@ -114,7 +114,12 @@ func (s *loadService) loadSegmentInternal(collectionID UniqueID, partitionID Uni return err } - targetFields := s.segLoader.getTargetFields(paths, srcFieldIDs, fieldIDs) + //fmt.Println("srcFieldIDs in internal:", srcFieldIDs) + //fmt.Println("dstFieldIDs in internal:", fieldIDs) + targetFields, err := s.segLoader.checkTargetFields(paths, srcFieldIDs, fieldIDs) + if err != nil { + return err + } err = s.segLoader.loadSegmentFieldsData(segment, targetFields) if err != nil { return err diff --git a/internal/querynode/load_service_test.go b/internal/querynode/load_service_test.go index a8f2145da..8ec4ee440 100644 --- a/internal/querynode/load_service_test.go +++ b/internal/querynode/load_service_test.go @@ -1142,7 +1142,7 @@ func TestSegmentLoad_Search_Vector(t *testing.T) { paths, srcFieldIDs, err := generateInsertBinLog(collectionID, partitionID, segmentID, keyPrefix) assert.NoError(t, err) - fieldsMap := node.loadService.segLoader.getTargetFields(paths, srcFieldIDs, fieldIDs) + fieldsMap, _ := node.loadService.segLoader.checkTargetFields(paths, srcFieldIDs, fieldIDs) assert.Equal(t, len(fieldsMap), 2) segment, err := node.replica.getSegmentByID(segmentID) diff --git a/internal/querynode/segment_loader.go b/internal/querynode/segment_loader.go index 883c96e17..b2a18961c 100644 --- a/internal/querynode/segment_loader.go +++ b/internal/querynode/segment_loader.go @@ -8,6 +8,7 @@ import ( "github.com/zilliztech/milvus-distributed/internal/kv" minioKV "github.com/zilliztech/milvus-distributed/internal/kv/minio" "github.com/zilliztech/milvus-distributed/internal/msgstream" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/datapb" "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" "github.com/zilliztech/milvus-distributed/internal/storage" @@ -53,11 +54,11 @@ func (loader *segmentLoader) getInsertBinlogPaths(segmentID UniqueID) ([]*intern } pathResponse, err := loader.dataClient.GetInsertBinlogPaths(insertBinlogPathRequest) - if err != nil { + if err != nil || pathResponse.Status.ErrorCode != commonpb.ErrorCode_SUCCESS { return nil, nil, err } - if len(pathResponse.FieldIDs) != len(pathResponse.Paths) { + if len(pathResponse.FieldIDs) != len(pathResponse.Paths) || len(pathResponse.FieldIDs) <= 0 { return nil, nil, errors.New("illegal InsertBinlogPathsResponse") } @@ -82,7 +83,7 @@ func (loader *segmentLoader) filterOutVectorFields(fieldIDs []int64, vectorField return targetFields } -func (loader *segmentLoader) getTargetFields(paths []*internalpb2.StringList, srcFieldIDS []int64, dstFields []int64) map[int64]*internalpb2.StringList { +func (loader *segmentLoader) checkTargetFields(paths []*internalpb2.StringList, srcFieldIDs []int64, dstFieldIDs []int64) (map[int64]*internalpb2.StringList, error) { targetFields := make(map[int64]*internalpb2.StringList) containsFunc := func(s []int64, e int64) bool { @@ -94,13 +95,14 @@ func (loader *segmentLoader) getTargetFields(paths []*internalpb2.StringList, sr return false } - for i, fieldID := range srcFieldIDS { - if containsFunc(dstFields, fieldID) { - targetFields[fieldID] = paths[i] + for i, fieldID := range dstFieldIDs { + if !containsFunc(srcFieldIDs, fieldID) { + return nil, errors.New("uncompleted fields") } + targetFields[fieldID] = paths[i] } - return targetFields + return targetFields, nil } func (loader *segmentLoader) loadSegmentFieldsData(segment *Segment, targetFields map[int64]*internalpb2.StringList) error { diff --git a/tests/python/test_index.py b/tests/python/test_index.py index ee7b36c50..cad7695c9 100644 --- a/tests/python/test_index.py +++ b/tests/python/test_index.py @@ -524,7 +524,6 @@ class TestIndexBase: connect.drop_index(collection, field_name) -@pytest.mark.skip("r0.3-test") class TestIndexBinary: @pytest.fixture( scope="function", @@ -594,6 +593,7 @@ class TestIndexBinary: ids = connect.bulk_insert(binary_collection, default_binary_entities, partition_tag=default_tag) connect.create_index(binary_collection, binary_field_name, get_jaccard_index) + @pytest.mark.skip("r0.3-test") @pytest.mark.timeout(BUILD_TIMEOUT) def test_create_index_search_with_query_vectors(self, connect, binary_collection, get_jaccard_index, get_nq): ''' -- GitLab