Commit 067c30c4 authored by Yihao Dai, committed by yefu.chen

Add sealedSegment cgo unittest, fix growingSegment field id check


Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
parent 511aa419
Showing with 735 additions and 357 deletions
......@@ -187,64 +187,55 @@ func NewUnmarshalDispatcher() *UnmarshalDispatcher
RocksMQ is a RocksDB-based messaging/streaming library.

```go
+// All the following UniqueIDs are 64-bit integers composed of a timestamp and an increasing counter.
type ProducerMessage struct {
	payload []byte
}
```

```go
type ConsumerMessage struct {
-	msgID MessageID
+	msgID UniqueID
	payload []byte
}
```

```go
-type Channel struct {
-	beginOffset MessageID
-	endOffset MessageID
-}
-
-type ConsumerGroupContext struct {
-	currentOffset MessageID
-}
+type IDAllocator interface {
+	Alloc(count uint32) (UniqueID, UniqueID, error)
+	AllocOne() (UniqueID, error)
+	UpdateID() error
+}

// Every collection has its RocksMQ
type RocksMQ struct {
-	channels map[string]Channel
-	cgCtxs   map[string]ConsumerGroupContext
-	mu       sync.Mutex
+	store       *gorocksdb.DB
+	kv          kv.Base
+	idAllocator IDAllocator
+	produceMu   sync.Mutex
+	consumeMu   sync.Mutex
}

-func (rmq *RocksMQ) CreateChannel(channelName string) error // create channel, add record in meta-store
-func (rmq *RocksMQ) DestroyChannel(channelName string) error // drop channel, delete record in meta-store
-func (rmq *RocksMQ) CreateConsumerGroup(groupName string) error // create consumer group, add record in meta-store
-func (rmq *RocksMQ) DestroyConsumerGroup(groupName string) error // drop consumer group, delete record in meta-store
-func (rmq *RocksMQ) Produce(channelName string, messages []ProducerMessage) error // produce a batch of messages, insert into rocksdb
-func (rmq *RocksMQ) Consume(groupName string, channelName string, n int) ([]ConsumerMessage, error) // consume up to n messages, advance current_id in Etcd
-func (rmq *RocksMQ) Seek(groupName string, channelName string, msgID MessageID) error // modify current_id in Etcd
+func (rmq *RocksMQ) CreateChannel(channelName string) error
+func (rmq *RocksMQ) DestroyChannel(channelName string) error
+func (rmq *RocksMQ) CreateConsumerGroup(groupName string) error
+func (rmq *RocksMQ) DestroyConsumerGroup(groupName string) error
+func (rmq *RocksMQ) Produce(channelName string, messages []ProducerMessage) error
+func (rmq *RocksMQ) Consume(groupName string, channelName string, n int) ([]ConsumerMessage, error)
+func (rmq *RocksMQ) Seek(groupName string, channelName string, msgID MessageID) error
+
+func NewRocksMQ(name string, idAllocator IDAllocator) (*RocksMQ, error)
```
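A minimal usage sketch of the new API follows. This is not part of the library: `exampleProduceConsume` is a hypothetical helper that would live inside the rocksmq package (since `ProducerMessage.payload` is unexported), `allocator` stands in for any concrete `IDAllocator` implementation, and `fmt` is assumed imported.

```go
// Sketch only: exercises the API surface described above.
func exampleProduceConsume(allocator IDAllocator) error {
	rmq, err := NewRocksMQ("collection_0", allocator)
	if err != nil {
		return err
	}

	if err := rmq.CreateChannel("chan_0"); err != nil {
		return err
	}
	if err := rmq.CreateConsumerGroup("group_0"); err != nil {
		return err
	}

	// Produce a batch of messages; RocksMQ assigns each one a UniqueID.
	batch := []ProducerMessage{{payload: []byte("hello")}, {payload: []byte("world")}}
	if err := rmq.Produce("chan_0", batch); err != nil {
		return err
	}

	// Consume up to 2 messages; the group's current_id advances in the meta-store.
	msgs, err := rmq.Consume("group_0", "chan_0", 2)
	if err != nil {
		return err
	}
	for _, m := range msgs {
		fmt.Printf("id=%d payload=%s\n", m.msgID, m.payload)
	}
	return nil
}
```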
##### A.4.1 Meta (stored in Etcd)

* channel meta

```go
-"$(channel_name)/begin_id", MessageID
-"$(channel_name)/end_id", MessageID
+// channel meta
+"$(channel_name)/begin_id", UniqueID
+"$(channel_name)/end_id", UniqueID
```

* consumer group meta

```go
-"$(group_name)/$(channel_name)/current_id", MessageID
+// consumer group meta
+"$(group_name)/$(channel_name)/current_id", UniqueID
```
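The entries above are flat string templates. A sketch of hypothetical helpers (not part of the library) that render them into concrete meta-store keys:

```go
// Illustrative only: mirrors the key layout described above.
func channelBeginIDKey(channelName string) string {
	return fmt.Sprintf("%s/begin_id", channelName)
}

func channelEndIDKey(channelName string) string {
	return fmt.Sprintf("%s/end_id", channelName)
}

func currentIDKey(groupName, channelName string) string {
	return fmt.Sprintf("%s/%s/current_id", groupName, channelName)
}
```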
......
......@@ -1317,6 +1317,7 @@ class DescribeCollectionResponse :
enum : int {
kStatusFieldNumber = 1,
kSchemaFieldNumber = 2,
kCollectionIDFieldNumber = 3,
};
// .milvus.proto.common.Status status = 1;
bool has_status() const;
......@@ -1334,6 +1335,11 @@ class DescribeCollectionResponse :
::milvus::proto::schema::CollectionSchema* mutable_schema();
void set_allocated_schema(::milvus::proto::schema::CollectionSchema* schema);
 
// int64 collectionID = 3;
void clear_collectionid();
::PROTOBUF_NAMESPACE_ID::int64 collectionid() const;
void set_collectionid(::PROTOBUF_NAMESPACE_ID::int64 value);
// @@protoc_insertion_point(class_scope:milvus.proto.milvus.DescribeCollectionResponse)
private:
class _Internal;
......@@ -1341,6 +1347,7 @@ class DescribeCollectionResponse :
::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_;
::milvus::proto::common::Status* status_;
::milvus::proto::schema::CollectionSchema* schema_;
::PROTOBUF_NAMESPACE_ID::int64 collectionid_;
mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_;
friend struct ::TableStruct_milvus_2eproto;
};
......@@ -7783,6 +7790,20 @@ inline void DescribeCollectionResponse::set_allocated_schema(::milvus::proto::sc
// @@protoc_insertion_point(field_set_allocated:milvus.proto.milvus.DescribeCollectionResponse.schema)
}
 
// int64 collectionID = 3;
inline void DescribeCollectionResponse::clear_collectionid() {
collectionid_ = PROTOBUF_LONGLONG(0);
}
inline ::PROTOBUF_NAMESPACE_ID::int64 DescribeCollectionResponse::collectionid() const {
// @@protoc_insertion_point(field_get:milvus.proto.milvus.DescribeCollectionResponse.collectionID)
return collectionid_;
}
inline void DescribeCollectionResponse::set_collectionid(::PROTOBUF_NAMESPACE_ID::int64 value) {
collectionid_ = value;
// @@protoc_insertion_point(field_set:milvus.proto.milvus.DescribeCollectionResponse.collectionID)
}
// -------------------------------------------------------------------
 
// LoadCollectionRequest
......
......@@ -240,7 +240,7 @@ SegmentGrowingImpl::GetMemoryUsageInBytes() const {
Status
SegmentGrowingImpl::LoadIndexing(const LoadIndexInfo& info) {
-    auto field_offset = schema_->get_offset(FieldName(info.field_name));
+    auto field_offset = schema_->get_offset(FieldId(info.field_id));
Assert(info.index_params.count("metric_type"));
auto metric_type_str = info.index_params.at("metric_type");
......
......@@ -32,7 +32,7 @@ namespace chrono = std::chrono;
using namespace milvus;
using namespace milvus::segcore;
-//using namespace milvus::proto;
+// using namespace milvus::proto;
using namespace milvus::knowhere;
TEST(CApiTest, CollectionTest) {
......@@ -937,7 +937,7 @@ TEST(CApiTest, UpdateSegmentIndex_Without_Predicate) {
AppendIndexParam(c_load_index_info, index_type_key.c_str(), index_type_value.c_str());
AppendIndexParam(c_load_index_info, index_mode_key.c_str(), index_mode_value.c_str());
AppendIndexParam(c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str());
AppendFieldInfo(c_load_index_info, "fakevec", 0);
AppendFieldInfo(c_load_index_info, "fakevec", 100);
AppendIndex(c_load_index_info, (CBinarySet)&binary_set);
status = UpdateSegmentIndex(segment, c_load_index_info);
......@@ -1074,7 +1074,7 @@ TEST(CApiTest, UpdateSegmentIndex_With_float_Predicate_Range) {
AppendIndexParam(c_load_index_info, index_type_key.c_str(), index_type_value.c_str());
AppendIndexParam(c_load_index_info, index_mode_key.c_str(), index_mode_value.c_str());
AppendIndexParam(c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str());
AppendFieldInfo(c_load_index_info, "fakevec", 0);
AppendFieldInfo(c_load_index_info, "fakevec", 100);
AppendIndex(c_load_index_info, (CBinarySet)&binary_set);
status = UpdateSegmentIndex(segment, c_load_index_info);
......@@ -1211,7 +1211,7 @@ TEST(CApiTest, UpdateSegmentIndex_With_float_Predicate_Term) {
AppendIndexParam(c_load_index_info, index_type_key.c_str(), index_type_value.c_str());
AppendIndexParam(c_load_index_info, index_mode_key.c_str(), index_mode_value.c_str());
AppendIndexParam(c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str());
AppendFieldInfo(c_load_index_info, "fakevec", 0);
AppendFieldInfo(c_load_index_info, "fakevec", 100);
AppendIndex(c_load_index_info, (CBinarySet)&binary_set);
status = UpdateSegmentIndex(segment, c_load_index_info);
......@@ -1350,7 +1350,7 @@ TEST(CApiTest, UpdateSegmentIndex_With_binary_Predicate_Range) {
AppendIndexParam(c_load_index_info, index_type_key.c_str(), index_type_value.c_str());
AppendIndexParam(c_load_index_info, index_mode_key.c_str(), index_mode_value.c_str());
AppendIndexParam(c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str());
AppendFieldInfo(c_load_index_info, "fakevec", 0);
AppendFieldInfo(c_load_index_info, "fakevec", 100);
AppendIndex(c_load_index_info, (CBinarySet)&binary_set);
status = UpdateSegmentIndex(segment, c_load_index_info);
......@@ -1488,7 +1488,7 @@ TEST(CApiTest, UpdateSegmentIndex_With_binary_Predicate_Term) {
AppendIndexParam(c_load_index_info, index_type_key.c_str(), index_type_value.c_str());
AppendIndexParam(c_load_index_info, index_mode_key.c_str(), index_mode_value.c_str());
AppendIndexParam(c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str());
AppendFieldInfo(c_load_index_info, "fakevec", 0);
AppendFieldInfo(c_load_index_info, "fakevec", 100);
AppendIndex(c_load_index_info, (CBinarySet)&binary_set);
status = UpdateSegmentIndex(segment, c_load_index_info);
......@@ -1559,12 +1559,157 @@ TEST(CApiTest, SealedSegmentTest) {
auto load_info = CLoadFieldDataInfo{101, blob, N};
-    // TODO: open load test
-    // auto res = LoadFieldData(segment, load_info);
-    // assert(res.error_code == Success);
-    // auto count = GetRowCount(segment);
-    // assert(count == N);
+    auto res = LoadFieldData(segment, load_info);
+    assert(res.error_code == Success);
+
+    auto count = GetRowCount(segment);
+    assert(count == N);
DeleteCollection(collection);
DeleteSegment(segment);
}
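// Loads raw field data and an IVF_PQ index into a sealed segment, then checks
// that a range-filtered vector search returns the expected row offsets.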
TEST(CApiTest, SealedSegment_search_float_Predicate_Range) {
constexpr auto DIM = 16;
constexpr auto K = 5;
std::string schema_string = generate_collection_shema("L2", "16", false);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
auto segment = NewSegment(collection, 0, Sealed);
auto N = 1000 * 1000;
auto dataset = DataGen(schema, N);
auto vec_col = dataset.get_col<float>(0);
auto counter_col = dataset.get_col<int64_t>(1);
auto query_ptr = vec_col.data() + 420000 * DIM;
const char* dsl_string = R"({
"bool": {
"must": [
{
"range": {
"counter": {
"GE": 420000,
"LT": 420010
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5
}
}
}
]
}
})";
// create place_holder_group
int num_queries = 10;
auto raw_group = CreatePlaceholderGroupFromBlob(num_queries, DIM, query_ptr);
auto blob = raw_group.SerializeAsString();
// search on segment's small index
void* plan = nullptr;
auto status = CreatePlan(collection, dsl_string, &plan);
assert(status.error_code == Success);
void* placeholderGroup = nullptr;
status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
assert(status.error_code == Success);
std::vector<CPlaceholderGroup> placeholderGroups;
placeholderGroups.push_back(placeholderGroup);
Timestamp time = 10000000;
// load index to segment
auto conf = milvus::knowhere::Config{{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
{milvus::knowhere::IndexParams::nlist, 100},
{milvus::knowhere::IndexParams::nprobe, 10},
{milvus::knowhere::IndexParams::m, 4},
{milvus::knowhere::IndexParams::nbits, 8},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{milvus::knowhere::meta::DEVICEID, 0}};
auto indexing = generate_index(vec_col.data(), conf, DIM, K, N, IndexEnum::INDEX_FAISS_IVFPQ);
// gen query dataset
auto query_dataset = milvus::knowhere::GenDataset(num_queries, DIM, query_ptr);
auto result_on_index = indexing->Query(query_dataset, conf, nullptr);
auto ids = result_on_index->Get<int64_t*>(milvus::knowhere::meta::IDS);
auto dis = result_on_index->Get<float*>(milvus::knowhere::meta::DISTANCE);
std::vector<int64_t> vec_ids(ids, ids + K * num_queries);
std::vector<float> vec_dis;
for (int j = 0; j < K * num_queries; ++j) {
vec_dis.push_back(dis[j] * -1);
}
auto binary_set = indexing->Serialize(conf);
void* c_load_index_info = nullptr;
status = NewLoadIndexInfo(&c_load_index_info);
assert(status.error_code == Success);
std::string index_type_key = "index_type";
std::string index_type_value = "IVF_PQ";
std::string index_mode_key = "index_mode";
std::string index_mode_value = "cpu";
std::string metric_type_key = "metric_type";
std::string metric_type_value = "L2";
AppendIndexParam(c_load_index_info, index_type_key.c_str(), index_type_value.c_str());
AppendIndexParam(c_load_index_info, index_mode_key.c_str(), index_mode_value.c_str());
AppendIndexParam(c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str());
AppendFieldInfo(c_load_index_info, "fakevec", 100);
AppendIndex(c_load_index_info, (CBinarySet)&binary_set);
auto load_index_info = (LoadIndexInfo*)c_load_index_info;
auto query_dataset2 = milvus::knowhere::GenDataset(num_queries, DIM, query_ptr);
auto loaded_index = load_index_info->index;
auto result_on_index2 = loaded_index->Query(query_dataset2, conf, nullptr);
auto ids2 = result_on_index2->Get<int64_t*>(milvus::knowhere::meta::IDS);
auto dis2 = result_on_index2->Get<float*>(milvus::knowhere::meta::DISTANCE);
auto c_counter_field_data = CLoadFieldDataInfo{
101,
counter_col.data(),
N,
};
status = LoadFieldData(segment, c_counter_field_data);
assert(status.error_code == Success);
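// Field 0 is the reserved row-id field; the counter column doubles as fake row ids here.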
auto c_id_field_data = CLoadFieldDataInfo{
0,
counter_col.data(),
N,
};
status = LoadFieldData(segment, c_id_field_data);
assert(status.error_code == Success);
status = UpdateSealedSegmentIndex(segment, c_load_index_info);
assert(status.error_code == Success);
CQueryResult c_search_result_on_bigIndex;
auto res_after_load_index = Search(segment, plan, placeholderGroups.data(), &time, 1, &c_search_result_on_bigIndex);
assert(res_after_load_index.error_code == Success);
auto search_result_on_bigIndex = (*(QueryResult*)c_search_result_on_bigIndex);
for (int i = 0; i < num_queries; ++i) {
auto offset = i * K;
ASSERT_EQ(search_result_on_bigIndex.internal_seg_offsets_[offset], 420000 + i);
}
DeleteLoadIndexInfo(c_load_index_info);
DeletePlan(plan);
DeletePlaceholderGroup(placeholderGroup);
DeleteQueryResult(c_search_result_on_bigIndex);
DeleteCollection(collection);
DeleteSegment(segment);
}
......@@ -31,7 +31,7 @@ TEST(Sealed, without_predicate) {
auto dim = 16;
auto topK = 5;
auto metric_type = MetricType::METRIC_L2;
schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
auto fake_id = schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
schema->AddDebugField("age", DataType::FLOAT);
std::string dsl = R"({
"bool": {
......@@ -106,7 +106,7 @@ TEST(Sealed, without_predicate) {
LoadIndexInfo load_info;
load_info.field_name = "fakevec";
-    load_info.field_id = 42;
+    load_info.field_id = fake_id.get();
load_info.index = indexing;
load_info.index_params["metric_type"] = "L2";
......@@ -128,7 +128,7 @@ TEST(Sealed, with_predicate) {
auto dim = 16;
auto topK = 5;
auto metric_type = MetricType::METRIC_L2;
schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
auto fake_id = schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
schema->AddDebugField("counter", DataType::INT64);
std::string dsl = R"({
"bool": {
......@@ -199,7 +199,7 @@ TEST(Sealed, with_predicate) {
LoadIndexInfo load_info;
load_info.field_name = "fakevec";
-    load_info.field_id = 42;
+    load_info.field_id = fake_id.get();
load_info.index = indexing;
load_info.index_params["metric_type"] = "L2";
......
......@@ -51,6 +51,7 @@ func TestGrpcService(t *testing.T) {
assert.Nil(t, err)
core.ProxyTimeTickChan = make(chan typeutil.Timestamp, 8)
core.DataNodeSegmentFlushCompletedChan = make(chan typeutil.UniqueID, 8)
timeTickArray := make([]typeutil.Timestamp, 0, 16)
core.SendTimeTick = func(ts typeutil.Timestamp) error {
......@@ -199,6 +200,8 @@ func TestGrpcService(t *testing.T) {
})
t.Run("describe collection", func(t *testing.T) {
collMeta, err := core.MetaTable.GetCollectionByName("testColl")
assert.Nil(t, err)
req := &milvuspb.DescribeCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_kDescribeCollection,
......@@ -213,6 +216,7 @@ func TestGrpcService(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, rsp.Status.ErrorCode, commonpb.ErrorCode_SUCCESS)
assert.Equal(t, rsp.Schema.Name, "testColl")
assert.Equal(t, rsp.CollectionID, collMeta.ID)
})
t.Run("show collection", func(t *testing.T) {
......@@ -275,6 +279,8 @@ func TestGrpcService(t *testing.T) {
})
t.Run("show partition", func(t *testing.T) {
coll, err := core.MetaTable.GetCollectionByName("testColl")
assert.Nil(t, err)
req := &milvuspb.ShowPartitionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_kShowPartitions,
......@@ -284,6 +290,7 @@ func TestGrpcService(t *testing.T) {
},
DbName: "testDb",
CollectionName: "testColl",
CollectionID: coll.ID,
}
rsp, err := cli.ShowPartitions(req)
assert.Nil(t, err)
......@@ -312,7 +319,7 @@ func TestGrpcService(t *testing.T) {
req := &milvuspb.ShowSegmentRequest{
Base: &commonpb.MsgBase{
-            MsgType: 111, //TODO show segment request msg type
+            MsgType: commonpb.MsgType_kShowSegment,
MsgID: 111,
Timestamp: 111,
SourceID: 111,
......@@ -358,7 +365,7 @@ func TestGrpcService(t *testing.T) {
req := &milvuspb.DescribeSegmentRequest{
Base: &commonpb.MsgBase{
-            MsgType: 113, //TODO, describe segment request msg type
+            MsgType: commonpb.MsgType_kDescribeSegment,
MsgID: 113,
Timestamp: 113,
SourceID: 113,
......@@ -392,6 +399,47 @@ func TestGrpcService(t *testing.T) {
assert.Equal(t, rsp.IndexDescriptions[0].IndexName, cms.Params.DefaultIndexName)
})
t.Run("flush segment", func(t *testing.T) {
coll, err := core.MetaTable.GetCollectionByName("testColl")
assert.Nil(t, err)
partID := coll.PartitionIDs[1]
part, err := core.MetaTable.GetPartitionByID(partID)
assert.Nil(t, err)
assert.Equal(t, len(part.SegmentIDs), 1)
seg := &datapb.SegmentInfo{
SegmentID: 1001,
CollectionID: coll.ID,
PartitionID: part.PartitionID,
}
core.DataServiceSegmentChan <- seg
time.Sleep(time.Millisecond * 100)
part, err = core.MetaTable.GetPartitionByID(partID)
assert.Nil(t, err)
assert.Equal(t, len(part.SegmentIDs), 2)
core.DataNodeSegmentFlushCompletedChan <- 1001
time.Sleep(time.Millisecond * 100)
req := &milvuspb.DescribeIndexRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_kDescribeIndex,
MsgID: 115,
Timestamp: 115,
SourceID: 115,
},
DbName: "",
CollectionName: "testColl",
FieldName: "vector",
IndexName: "",
}
rsp, err := cli.DescribeIndex(req)
assert.Nil(t, err)
assert.Equal(t, rsp.Status.ErrorCode, commonpb.ErrorCode_SUCCESS)
assert.Equal(t, len(rsp.IndexDescriptions), 2)
assert.Equal(t, rsp.IndexDescriptions[0].IndexName, cms.Params.DefaultIndexName)
assert.Equal(t, rsp.IndexDescriptions[1].IndexName, "index_100")
})
t.Run("drop partition", func(t *testing.T) {
req := &milvuspb.DropPartitionRequest{
Base: &commonpb.MsgBase{
......
......@@ -4,6 +4,7 @@ import (
"context"
"log"
"math/rand"
"strconv"
"sync"
"sync/atomic"
"time"
......@@ -138,6 +139,9 @@ type Core struct {
	//setMsgStreams: segment channel, receives segment info from the data service when master creates a segment
	DataServiceSegmentChan chan *datapb.SegmentInfo

	//setMsgStreams: when a data node completes a segment flush, it puts the segment id into this channel
	DataNodeSegmentFlushCompletedChan chan typeutil.UniqueID

	//TODO: get binlog file paths from data service
	GetBinlogFilePathsFromDataServiceReq func(segID typeutil.UniqueID, fieldID typeutil.UniqueID) ([]string, error)
......@@ -227,6 +231,9 @@ func (c *Core) checkInit() error {
if c.indexTaskQueue == nil {
return errors.Errorf("indexTaskQueue is nil")
}
if c.DataNodeSegmentFlushCompletedChan == nil {
return errors.Errorf("DataNodeSegmentFlushCompletedChan is nil")
}
log.Printf("master node id = %d\n", Params.NodeID)
return nil
}
......@@ -311,7 +318,7 @@ func (c *Core) startCreateIndexLoop() {
return
case t, ok := <-c.indexTaskQueue:
if !ok {
log.Printf("index task chan is close, exit loop")
log.Printf("index task chan has closed, exit loop")
return
}
if err := t.BuildIndex(); err != nil {
......@@ -321,6 +328,34 @@ func (c *Core) startCreateIndexLoop() {
}
}
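// startSegmentFlushCompletedLoop watches for flushed segment ids and schedules
// an index-build task for every vector field of the flushed segment.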
func (c *Core) startSegmentFlushCompletedLoop() {
for {
select {
case <-c.ctx.Done():
log.Printf("close segment flush completed loop")
return
		case seg, ok := <-c.DataNodeSegmentFlushCompletedChan:
			if !ok {
				log.Printf("data node segment flush completed chan has closed, exit loop")
				return
			}
fields, err := c.MetaTable.GetSegmentVectorFields(seg)
if err != nil {
log.Printf("GetSegmentVectorFields, error = %s ", err.Error())
}
for _, f := range fields {
t := &CreateIndexTask{
core: c,
segmentID: seg,
indexName: "index_" + strconv.FormatInt(f.FieldID, 10),
fieldSchema: f,
indexParams: nil,
}
c.indexTaskQueue <- t
}
}
}
}
func (c *Core) setMsgStreams() error {
//proxy time tick stream,
proxyTimeTickStream := pulsarms.NewPulsarMsgStream(c.ctx, 1024)
......@@ -542,6 +577,7 @@ func (c *Core) Start() error {
go c.startTimeTickLoop()
go c.startDataServiceSegmentLoop()
go c.startCreateIndexLoop()
go c.startSegmentFlushCompletedLoop()
c.stateCode.Store(internalpb2.StateCode_HEALTHY)
})
return nil
......
......@@ -182,6 +182,10 @@ func (mt *metaTable) AddCollection(coll *pb.CollectionInfo, part *pb.PartitionIn
if len(coll.PartitionIDs) != 0 {
return errors.Errorf("partitions should be empty when creating collection")
}
if _, ok := mt.collName2ID[coll.Schema.Name]; ok {
return errors.Errorf("collection %s exist", coll.Schema.Name)
}
coll.PartitionIDs = append(coll.PartitionIDs, part.PartitionID)
mt.collID2Meta[coll.ID] = *coll
mt.collName2ID[coll.Schema.Name] = coll.ID
......@@ -405,7 +409,7 @@ func (mt *metaTable) DeletePartition(collID typeutil.UniqueID, partitionName str
for _, segID := range partMeta.SegmentIDs {
segIndexMeta, ok := mt.segID2IndexMeta[segID]
if !ok {
log.Printf("segment id = %d not exist", segID)
log.Printf("segment id = %d has no index meta", segID)
continue
}
delete(mt.segID2IndexMeta, segID)
......@@ -630,6 +634,27 @@ func (mt *metaTable) GetNotIndexedSegments(collName string, fieldName string, in
return rstID, fieldSchema, nil
}
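// GetSegmentVectorFields returns a cloned schema for every vector field (float
// or binary) of the collection that owns the given segment.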
func (mt *metaTable) GetSegmentVectorFields(segID typeutil.UniqueID) ([]*schemapb.FieldSchema, error) {
mt.ddLock.RLock()
defer mt.ddLock.RUnlock()
collID, ok := mt.segID2CollID[segID]
if !ok {
return nil, errors.Errorf("segment id %d not belong to any collection", segID)
}
collMeta, ok := mt.collID2Meta[collID]
if !ok {
return nil, errors.Errorf("segment id %d not belong to any collection which has dropped", segID)
}
rst := make([]*schemapb.FieldSchema, 0, 2)
for _, f := range collMeta.Schema.Fields {
if f.DataType == schemapb.DataType_VECTOR_BINARY || f.DataType == schemapb.DataType_VECTOR_FLOAT {
field := proto.Clone(f)
rst = append(rst, field.(*schemapb.FieldSchema))
}
}
return rst, nil
}
func (mt *metaTable) GetIndexByName(collName string, fieldName string, indexName string) ([]pb.IndexInfo, error) {
mt.ddLock.RLock()
	defer mt.ddLock.RUnlock()
......
......@@ -212,6 +212,7 @@ func (t *DescribeCollectionReqTask) Execute() error {
return err
}
t.Rsp.Schema = proto.Clone(coll.Schema).(*schemapb.CollectionSchema)
t.Rsp.CollectionID = coll.ID
var newField []*schemapb.FieldSchema
for _, field := range t.Rsp.Schema.Fields {
if field.FieldID >= StartOfUserFieldID {
......@@ -368,10 +369,13 @@ func (t *ShowPartitionReqTask) Ts() (typeutil.Timestamp, error) {
}
func (t *ShowPartitionReqTask) Execute() error {
-	coll, err := t.core.MetaTable.GetCollectionByName(t.Req.CollectionName)
+	coll, err := t.core.MetaTable.GetCollectionByID(t.Req.CollectionID)
if err != nil {
return err
}
if coll.Schema.Name != t.Req.CollectionName {
return errors.Errorf("collection %s not exist", t.Req.CollectionName)
}
for _, partID := range coll.PartitionIDs {
partMeta, err := t.core.MetaTable.GetPartitionByID(partID)
if err != nil {
......@@ -477,6 +481,9 @@ func (t *CreateIndexReqTask) Execute() error {
if err != nil {
return err
}
if field.DataType != schemapb.DataType_VECTOR_FLOAT && field.DataType != schemapb.DataType_VECTOR_BINARY {
return errors.Errorf("field name = %s, data type = %s", t.Req.FieldName, schemapb.DataType_name[int32(field.DataType)])
}
for _, seg := range segIDs {
task := CreateIndexTask{
core: t.core,
......
......@@ -46,6 +46,7 @@ message DescribeCollectionRequest {
message DescribeCollectionResponse {
common.Status status = 1;
schema.CollectionSchema schema = 2;
int64 collectionID = 3;
}
message LoadCollectionRequest {
......
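With collectionID now carried in the response, a client can read the collection id straight off DescribeCollection. A hedged sketch, assuming a connected client `cli` like the one in TestGrpcService above, with `log` and `fmt` imported:

```go
// Sketch only: reads the new collectionID field (field number 3), which
// DescribeCollectionReqTask.Execute fills in on the master side.
req := &milvuspb.DescribeCollectionRequest{
	Base: &commonpb.MsgBase{
		MsgType: commonpb.MsgType_kDescribeCollection,
	},
	DbName:         "testDb",
	CollectionName: "testColl",
}
rsp, err := cli.DescribeCollection(req)
if err != nil || rsp.Status.ErrorCode != commonpb.ErrorCode_SUCCESS {
	log.Fatal("describe collection failed")
}
fmt.Printf("collection %s has id %d\n", rsp.Schema.Name, rsp.CollectionID)
```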
......@@ -2,18 +2,23 @@ package querynode
import (
"context"
"encoding/binary"
"fmt"
"math"
"math/rand"
"path"
"strconv"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/indexnode"
minioKV "github.com/zilliztech/milvus-distributed/internal/kv/minio"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/storage"
)
......@@ -209,11 +214,11 @@ func generateIndex(segmentID UniqueID) ([]string, indexParam, error) {
return indexPaths, indexParams, nil
}
-func TestSegmentManager_load_and_release(t *testing.T) {
+func TestSegmentManager_load_release_and_search(t *testing.T) {
collectionID := UniqueID(0)
partitionID := UniqueID(1)
segmentID := UniqueID(2)
-	fieldIDs := []int64{101}
+	fieldIDs := []int64{0, 101}
node := newQueryNodeMock()
defer node.Stop()
......@@ -236,7 +241,7 @@ func TestSegmentManager_load_and_release(t *testing.T) {
assert.NoError(t, err)
fieldsMap := node.segManager.filterOutNeedlessFields(paths, srcFieldIDs, fieldIDs)
-	assert.Equal(t, len(fieldsMap), 1)
+	assert.Equal(t, len(fieldsMap), 2)
err = node.segManager.loadSegmentFieldsData(segmentID, fieldsMap)
assert.NoError(t, err)
......@@ -247,5 +252,50 @@ func TestSegmentManager_load_and_release(t *testing.T) {
err = node.segManager.loadIndex(segmentID, indexPaths, indexParams)
assert.NoError(t, err)
// do search
	dslString := `{
		"bool": {
			"vector": {
				"vec": {
					"metric_type": "L2",
					"params": {"nprobe": 10},
					"query": "$0",
					"topk": 10
				}
			}
		}
	}`
const DIM = 16
var searchRawData []byte
var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
for _, ele := range vec {
buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
searchRawData = append(searchRawData, buf...)
}
placeholderValue := milvuspb.PlaceholderValue{
Tag: "$0",
Type: milvuspb.PlaceholderType_VECTOR_FLOAT,
Values: [][]byte{searchRawData},
}
placeholderGroup := milvuspb.PlaceholderGroup{
Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue},
}
placeHolderGroupBlob, err := proto.Marshal(&placeholderGroup)
assert.NoError(t, err)
searchTimestamp := Timestamp(1020)
collection, err := node.replica.getCollectionByID(collectionID)
assert.NoError(t, err)
plan, err := createPlan(*collection, dslString)
assert.NoError(t, err)
holder, err := parserPlaceholderGroup(plan, placeHolderGroupBlob)
assert.NoError(t, err)
placeholderGroups := make([]*PlaceholderGroup, 0)
placeholderGroups = append(placeholderGroups, holder)
// wait for segment building index
time.Sleep(3 * time.Second)
segment, err := node.replica.getSegmentByID(segmentID)
assert.NoError(t, err)
_, err = segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp})
assert.Nil(t, err)
plan.delete()
holder.delete()
<-ctx.Done()
}