Skip to content
Snippets Groups Projects
test_search.py 75.18 KiB
import time
import pdb
import copy
import logging
from multiprocessing import Pool, Process
import pytest
import numpy as np

from milvus import DataType
from .utils import *
from .constants import *

uid = "test_search"  # passed to gen_unique_str() to build unique collection names
nq = 1  # number of query vectors used by the module-level default queries below
epsilon = 0.001  # tolerance used when comparing returned distances in the asserts
field_name = default_float_vec_field_name  # float vector field searched by most tests
binary_field_name = default_binary_vec_field_name  # vector field used by binary-collection tests
search_param = {"nprobe": 1}  # module default search params; several tests shadow it locally

# Data generated once at import time and shared across tests. init_data() / init_binary_data()
# reuse `entities` / `raw_vectors` + `binary_entities` instead of generating new data when called
# with their default row count.
# NOTE(review): the generators are presumably random, so keep this statement order stable.
entity = gen_entities(1, is_normal=True)
entities = gen_entities(default_nb, is_normal=True)
raw_vectors, binary_entities = gen_binary_entities(default_nb)
# Default float / binary queries of `nq` vectors built from the shared data with default_top_k.
default_query, default_query_vecs = gen_query_vectors(field_name, entities, default_top_k, nq)
default_binary_query, default_binary_query_vecs = gen_query_vectors(binary_field_name, binary_entities, default_top_k,
                                                                    nq)


def init_data(connect, collection, nb=1200, partition_tags=None, auto_id=True):
    '''
    Generate float-vector entities and insert them into ``collection``.

    :param connect: client connection used to call ``insert``.
    :param collection: name of the target collection.
    :param nb: number of entities to insert. When it equals ``default_nb`` the module-level
        ``entities`` (generated once at import time) are reused instead of new data.
    :param partition_tags: optional partition tag, forwarded to ``insert`` as ``partition_tag``.
    :param auto_id: when False, explicit ids ``0 .. nb-1`` are passed to ``insert``.
    :return: tuple ``(insert_entities, ids)``, where ``ids`` is what ``connect.insert`` returned.
    '''
    # The cached module-level `entities` holds exactly default_nb rows, so it may only be reused
    # when that many rows are requested (a hard-coded 1200 silently diverges if default_nb changes).
    if nb == default_nb:
        insert_entities = entities
    else:
        insert_entities = gen_entities(nb, is_normal=True)

    # Build the optional insert arguments once instead of branching over every combination.
    insert_kwargs = {}
    if not auto_id:
        insert_kwargs["ids"] = list(range(nb))
    if partition_tags is not None:
        insert_kwargs["partition_tag"] = partition_tags
    ids = connect.insert(collection, insert_entities, **insert_kwargs)
    # NOTE(review): no flush here (init_binary_data does flush); it was commented out upstream —
    # confirm callers do not need the data flushed before searching.
    return insert_entities, ids


def init_binary_data(connect, collection, nb=1200, insert=True, partition_tags=None):
    '''
    Generate binary-vector entities and, when ``insert`` is set, add them to ``collection``.

    :param connect: client connection used for ``insert`` and ``flush``.
    :param collection: name of the target collection.
    :param nb: number of entities. When it equals ``default_nb`` the module-level
        ``raw_vectors`` / ``binary_entities`` (generated once at import time) are reused.
    :param insert: when falsy, only generate the data (e.g. to build query vectors).
    :param partition_tags: optional partition tag, forwarded to ``insert`` as ``partition_tag``.
    :return: tuple ``(insert_raw_vectors, insert_entities, ids)``; ``ids`` is ``[]`` when
        nothing was inserted.
    '''
    # The cached module-level data holds exactly default_nb rows, so reuse it only when that
    # many rows are requested (a hard-coded 1200 silently diverges if default_nb changes).
    if nb == default_nb:
        insert_raw_vectors, insert_entities = raw_vectors, binary_entities
    else:
        insert_raw_vectors, insert_entities = gen_binary_entities(nb)

    ids = []
    if insert:
        insert_kwargs = {}
        if partition_tags is not None:
            insert_kwargs["partition_tag"] = partition_tags
        ids = connect.insert(collection, insert_entities, **insert_kwargs)
        connect.flush([collection])
    return insert_raw_vectors, insert_entities, ids


class TestSearchBase:
    """
    generate valid create_index params
    """

    @pytest.fixture(
        scope="function",
        params=gen_index()
    )
    def get_index(self, request, connect):
        '''
        Fixture returning one index configuration from gen_index().

        Configurations whose index_type appears in index_cpu_not_support() are skipped.
        ``connect`` is not read in the body (the CPU-mode probe that used it is disabled).
        '''
        index_params = request.param
        index_type = index_params["index_type"]
        if index_type not in index_cpu_not_support():
            return index_params
        pytest.skip("sq8h not support in CPU mode")

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        '''
        Fixture returning a deep copy of one index configuration from gen_simple_index().

        A copy is returned because tests mutate the dict they receive
        (e.g. ``get_simple_index["metric_type"] = "IP"``); mutating the copy leaves the
        shared ``params`` entry untouched for later parametrizations.
        Index types listed by index_cpu_not_support() are skipped.
        ``connect`` is not read in the body (the CPU-mode probe that used it is disabled).
        '''
        # `copy` comes from the module-level import; no function-local import is needed.
        if request.param["index_type"] in index_cpu_not_support():
            pytest.skip("sq8h not support in CPU mode")
        return copy.deepcopy(request.param)

    @pytest.fixture(
        scope="function",
        params=gen_binary_index()
    )
    def get_jaccard_index(self, request, connect):
        '''
        Fixture returning one binary index configuration from gen_binary_index()
        (intended for JACCARD tests, per its name).

        The configuration is logged; index types not in binary_support() are skipped.
        '''
        index_params = request.param
        logging.getLogger().info(index_params)
        if index_params["index_type"] not in binary_support():
            pytest.skip("Skip index Temporary")
        return index_params

    @pytest.fixture(
        scope="function",
        params=gen_binary_index()
    )
    def get_hamming_index(self, request, connect):
        '''
        Fixture returning one binary index configuration from gen_binary_index()
        (intended for HAMMING tests, per its name).

        The configuration is logged; index types not in binary_support() are skipped.
        '''
        index_params = request.param
        logging.getLogger().info(index_params)
        if index_params["index_type"] not in binary_support():
            pytest.skip("Skip index Temporary")
        return index_params

    @pytest.fixture(
        scope="function",
        params=gen_binary_index()
    )
    def get_structure_index(self, request, connect):
        '''
        Fixture returning a binary index configuration only when its index_type is "FLAT".

        The configuration is logged; every other gen_binary_index() entry is skipped.
        '''
        index_params = request.param
        logging.getLogger().info(index_params)
        if index_params["index_type"] != "FLAT":
            pytest.skip("Skip index Temporary")
        return index_params

    """
    generate top-k params
    """

    @pytest.fixture(
        scope="function",
        params=[1, 10]
    )
    def get_top_k(self, request):
        '''Provide each parametrized top-k value (1, 10) to the requesting test.'''
        top_k_value = request.param
        yield top_k_value

    @pytest.fixture(
        scope="function",
        params=[1, 10, 1100]
    )
    def get_nq(self, request):
        '''Provide each parametrized number of query vectors (1, 10, 1100) to the requesting test.'''
        nq_value = request.param
        yield nq_value

    # PASS
    @pytest.mark.skip("r0.3-test")
    def test_search_flat(self, connect, collection, get_top_k, get_nq):
        '''
        target: test basic search function, all the search params are correct, change top-k value
        method: search with the given vectors, check the result
        expected: the length of the result is top_k; a top_k above max_top_k raises
        '''
        top_k, nq = get_top_k, get_nq
        inserted, ids = init_data(connect, collection)
        query, _ = gen_query_vectors(field_name, inserted, top_k, nq)
        if top_k > max_top_k:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert len(res[0]) == top_k
        assert res[0]._distances[0] <= epsilon
        assert check_id_result(res[0], ids[0])

    # milvus-distributed dose not have the limitation of top_k
    def test_search_flat_top_k(self, connect, collection, get_nq):
        '''
        target: test search with a large top_k (16385)
        method: search with the given vectors; if top_k exceeds max_top_k expect an error,
                otherwise check the result
        expected: an exception when top_k > max_top_k; otherwise one hit per stored entity,
                  capped at top_k, with the query itself as the nearest hit
        '''
        top_k = 16385
        nq = get_nq
        entities, ids = init_data(connect, collection)
        query, _ = gen_query_vectors(field_name, entities, top_k, nq)
        if top_k <= max_top_k:
            res = connect.search(collection, query)
            # Only len(ids) entities are stored (assumes the `collection` fixture starts empty),
            # so a search can never return more hits than that, even when top_k is larger.
            assert len(res[0]) == min(top_k, len(ids))
            assert res[0]._distances[0] <= epsilon
            assert check_id_result(res[0], ids[0])
        else:
            with pytest.raises(Exception):
                connect.search(collection, query)

    # TODO: reopen after we supporting targetEntry
    @pytest.mark.skip("search_field")
    def test_search_field(self, connect, collection, get_top_k, get_nq):
        '''
        target: test search with the `fields` argument, change top-k value
        method: search returning "float_vector", then search returning "float" and look for
                each query's inserted float value among its hits
        expected: the length of the result is top_k, and the i-th query's hits contain the
                  i-th inserted "float" value
        '''
        top_k = get_top_k
        nq = get_nq
        inserted, ids = init_data(connect, collection)
        query, _ = gen_query_vectors(field_name, inserted, top_k, nq)
        if top_k > max_top_k:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query, fields=["float_vector"])
        assert len(res[0]) == top_k
        assert res[0]._distances[0] <= epsilon
        assert check_id_result(res[0], ids[0])
        res = connect.search(collection, query, fields=["float"])
        float_values = inserted[1]["values"][:nq]
        for query_idx in range(nq):
            returned_floats = [hit.entity.get('float') for hit in res[query_idx]]
            assert float_values[query_idx] in returned_floats

    @pytest.mark.skip("search_after_delete")
    def test_search_after_delete(self, connect, collection, get_top_k, get_nq):
        '''
        target: test basic search function before and after deletion, all the search params are
                correct, change top-k value.
                check issue <a href="https://github.com/milvus-io/milvus/issues/4200">#4200</a>
        method: search with the first inserted vector, delete that entity, flush, search again
        expected: the deleted entity is no longer the best hit of the second search; when
                  top_k > 1, the former second-best hit becomes the best hit.
        '''
        top_k = get_top_k
        nq = get_nq

        entities, ids = init_data(connect, collection, nb=10000)
        # Values of the first inserted entity: entities[0] is compared against the "int64" field
        # below and entities[2] supplies the query vector.
        first_int64_value = entities[0]["values"][0]
        first_vector = entities[2]["values"][0]

        search_param = get_search_param("FLAT")
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
        # Replace the generated query vectors in place with the single first_vector. This only
        # affects `query` if it references this same `vecs` list (presumably true for
        # gen_query_vectors — confirm); the `len(res) == 1` checks below rely on it.
        vecs[:] = []
        vecs.append(first_vector)

        res = None
        if top_k > max_top_k:
            with pytest.raises(Exception):
                connect.search(collection, query, fields=['int64'])
            # pytest.skip raises, so the post-delete part never runs with res = None.
            pytest.skip("top_k value is larger than max_topp_k")
        else:
            res = connect.search(collection, query, fields=['int64'])
            assert len(res) == 1
            assert len(res[0]) >= top_k
            # The best hit is the queried entity itself, at a distance below epsilon.
            assert res[0][0].id == ids[0]
            assert res[0][0].entity.get("int64") == first_int64_value
            assert res[0]._distances[0] < epsilon
            assert check_id_result(res[0], ids[0])

        connect.delete_entity_by_id(collection, ids[:1])
        # Flush before searching again.
        connect.flush([collection])

        res2 = connect.search(collection, query, fields=['int64'])
        assert len(res2) == 1
        assert len(res2[0]) >= top_k
        assert res2[0][0].id != ids[0]
        if top_k > 1:
            # With the first entity deleted, the former second-best hit should now rank first.
            assert res2[0][0].id == res[0][1].id
            assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64")

    # Pass
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test basic search after building each simple index, change top-k and nq
        method: insert, create the index, search with the given vectors, check the result
        expected: nq result sets with at least top_k hits; the first query's closest distance is
                  below epsilon and its hits contain ids[0]
        '''
        top_k, nq = get_top_k, get_nq
        index_type = get_simple_index["index_type"]
        if index_type in skip_pq():
            pytest.skip("Skip PQ")
        inserted, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        params = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, inserted, top_k, nq, search_params=params)
        if top_k > max_top_k:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) >= top_k
        assert res[0]._distances[0] < epsilon
        assert check_id_result(res[0], ids[0])

    # DOG: TODO INVALID TYPE UNKNOWN
    @pytest.mark.skip("search_after_index_different_metric_type")
    def test_search_after_index_different_metric_type(self, connect, collection, get_simple_index):
        '''
        target: test search with a metric_type different from the index's
        method: build the index with the fixture's params (presumably L2), and search using IP
        expected: search ok, nq result sets of default_top_k hits each
        '''
        index_type = get_simple_index["index_type"]
        inserted, _ = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, _ = gen_query_vectors(
            field_name, inserted, default_top_k, nq,
            metric_type="IP",
            search_params=get_search_param(index_type),
        )
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k

    # pass
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test search on an indexed collection that also has an (empty) partition
        method: create a partition, insert into the collection itself, build the index, then search
                the whole collection and the partition
        expected: the whole-collection search finds the data; the partition search returns nq
                  result sets (only their count is checked)
        '''
        top_k, nq = get_top_k, get_nq
        index_type = get_simple_index["index_type"]
        if index_type in skip_pq():
            pytest.skip("Skip PQ")
        connect.create_partition(collection, default_tag)
        inserted, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        params = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, inserted, top_k, nq, search_params=params)
        if top_k > max_top_k:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) >= top_k
        assert res[0]._distances[0] < epsilon
        assert check_id_result(res[0], ids[0])
        res = connect.search(collection, query, partition_tags=[default_tag])
        assert len(res) == nq

    # PASS
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test search restricted to partition tags after building an index
        method: insert into a partition, build the index, search with [default_tag] and with
                [default_tag, "new_tag"]
        expected: for both tag lists, nq result sets with at least top_k hits, and the first
                  query finds ids[0] at a distance below epsilon
        '''
        top_k, nq = get_top_k, get_nq
        index_type = get_simple_index["index_type"]
        if index_type in skip_pq():
            pytest.skip("Skip PQ")
        connect.create_partition(collection, default_tag)
        inserted, ids = init_data(connect, collection, partition_tags=default_tag)
        connect.create_index(collection, field_name, get_simple_index)
        params = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, inserted, top_k, nq, search_params=params)
        for tags in ([default_tag], [default_tag, "new_tag"]):
            if top_k > max_top_k:
                with pytest.raises(Exception):
                    connect.search(collection, query, partition_tags=tags)
                continue
            res = connect.search(collection, query, partition_tags=tags)
            assert len(res) == nq
            assert len(res[0]) >= top_k
            assert res[0]._distances[0] < epsilon
            assert check_id_result(res[0], ids[0])

    @pytest.mark.skip("search_index_partition_C")
    @pytest.mark.level(2)
    def test_search_index_partition_C(self, connect, collection, get_top_k, get_nq):
        '''
        target: test search with a partition tag that does not exist in the collection
        method: search with the given vectors and tag "new_tag", check the result
        expected: nq result sets, the first one empty (an error only when top_k > max_top_k)
        '''
        top_k, nq = get_top_k, get_nq
        inserted, _ = init_data(connect, collection)
        query, _ = gen_query_vectors(field_name, inserted, top_k, nq)
        if top_k > max_top_k:
            with pytest.raises(Exception):
                connect.search(collection, query, partition_tags=["new_tag"])
            return
        res = connect.search(collection, query, partition_tags=["new_tag"])
        assert len(res) == nq
        assert len(res[0]) == 0

    # PASS
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: test search across two partitions after building an index
        method: insert 1200 rows into default_tag and 6001 rows into "new_tag", build the index,
                query with vectors from the first partition; search everything, then only "new_tag"
        expected: the full search finds the queries at distance below epsilon; searching only
                  "new_tag" returns nearest distances above epsilon
        '''
        top_k = get_top_k
        nq = 2
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type in skip_pq():
            pytest.skip("Skip PQ")
        connect.create_partition(collection, default_tag)
        connect.create_partition(collection, new_tag)
        inserted, ids = init_data(connect, collection, partition_tags=default_tag)
        _, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        connect.create_index(collection, field_name, get_simple_index)
        params = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, inserted, top_k, nq, search_params=params)
        if top_k > max_top_k:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert check_id_result(res[0], ids[0])
        assert not check_id_result(res[1], new_ids[0])
        for query_idx in (0, 1):
            assert res[query_idx]._distances[0] < epsilon
        res = connect.search(collection, query, partition_tags=[new_tag])
        for query_idx in (0, 1):
            assert res[query_idx]._distances[0] > epsilon

    # Pass
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: test search with regex-like partition tags after building an index
        method: insert into "tag" and "new_tag", build the index, query with vectors from
                "new_tag", search with tags "(.*)tag" and "new(.*)"
        expected: both searches find the queries at distance below epsilon
        '''
        top_k = get_top_k
        nq = 2
        tag, new_tag = "tag", "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type in skip_pq():
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        _, ids = init_data(connect, collection, partition_tags=tag)
        new_inserted, _ = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        connect.create_index(collection, field_name, get_simple_index)
        params = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, new_inserted, top_k, nq, search_params=params)
        if top_k > max_top_k:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query, partition_tags=["(.*)tag"])
        assert not check_id_result(res[0], ids[0])
        for query_idx in (0, 1):
            assert res[query_idx]._distances[0] < epsilon
        res = connect.search(collection, query, partition_tags=["new(.*)"])
        for query_idx in (0, 1):
            assert res[query_idx]._distances[0] < epsilon

    # pass
    # test for ip metric
    #
    # TODO: reopen after we supporting ip flat
    # DOG: TODO REDUCE
    @pytest.mark.skip("search_ip_flat")
    @pytest.mark.level(2)
    def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test basic IP search without an index, change top-k value
        method: search with the given vectors using metric_type "IP", check the result
        expected: the length of the result is top_k and the best IP score is close to 1
        (get_simple_index is only requested for parametrization; its value is not used)
        '''
        top_k, nq = get_top_k, get_nq
        inserted, ids = init_data(connect, collection)
        query, _ = gen_query_vectors(field_name, inserted, top_k, nq, metric_type="IP")
        if top_k > max_top_k:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert len(res[0]) == top_k
        best_score = res[0]._distances[0]
        assert best_score >= 1 - gen_inaccuracy(best_score)
        assert check_id_result(res[0], ids[0])

    # PASS
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test IP search after building each simple index with metric_type "IP"
        method: insert, build the IP index, search with the given vectors, check the result
        expected: nq result sets with at least top_k hits; the first query finds ids[0] with an
                  IP score close to 1
        '''
        top_k, nq = get_top_k, get_nq
        metric = "IP"
        index_type = get_simple_index["index_type"]
        if index_type in skip_pq():
            pytest.skip("Skip PQ")
        inserted, ids = init_data(connect, collection)
        get_simple_index["metric_type"] = metric
        connect.create_index(collection, field_name, get_simple_index)
        params = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, inserted, top_k, nq, metric_type=metric, search_params=params)
        if top_k > max_top_k:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) >= top_k
        assert check_id_result(res[0], ids[0])
        best_score = res[0]._distances[0]
        assert best_score >= 1 - gen_inaccuracy(best_score)

    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test IP search on an indexed collection that also has an (empty) partition
        method: create a partition, insert into the collection itself, build an IP index, then
                search the whole collection and the partition
        expected: the whole-collection search finds ids[0] with an IP score close to 1; the
                  partition search returns nq result sets (only their count is checked)
        '''
        top_k, nq = get_top_k, get_nq
        metric = "IP"
        index_type = get_simple_index["index_type"]
        if index_type in skip_pq():
            pytest.skip("Skip PQ")
        connect.create_partition(collection, default_tag)
        inserted, ids = init_data(connect, collection)
        get_simple_index["metric_type"] = metric
        connect.create_index(collection, field_name, get_simple_index)
        params = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, inserted, top_k, nq, metric_type=metric, search_params=params)
        if top_k > max_top_k:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) >= top_k
        best_score = res[0]._distances[0]
        assert best_score >= 1 - gen_inaccuracy(best_score)
        assert check_id_result(res[0], ids[0])
        res = connect.search(collection, query, partition_tags=[default_tag])
        assert len(res) == nq

    # PASS
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: test IP search across two partitions after building an IP index
        method: insert 1200 rows into default_tag and 6001 rows into "new_tag", build the index,
                query with vectors from the first partition; search everything, then only "new_tag"
        expected: the full search scores both queries close to 1; searching only "new_tag"
                  gives the first query a lower best score
        '''
        top_k = get_top_k
        nq = 2
        metric = "IP"
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type in skip_pq():
            pytest.skip("Skip PQ")
        connect.create_partition(collection, default_tag)
        connect.create_partition(collection, new_tag)
        inserted, ids = init_data(connect, collection, partition_tags=default_tag)
        _, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        get_simple_index["metric_type"] = metric
        connect.create_index(collection, field_name, get_simple_index)
        params = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, inserted, top_k, nq, metric_type=metric, search_params=params)
        if top_k > max_top_k:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert check_id_result(res[0], ids[0])
        assert not check_id_result(res[1], new_ids[0])
        for query_idx in (0, 1):
            assert res[query_idx]._distances[0] >= 1 - gen_inaccuracy(res[query_idx]._distances[0])
        res = connect.search(collection, query, partition_tags=[new_tag])
        assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0])
        # TODO: re-enable once stable:
        # assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])

    # PASS
    @pytest.mark.level(2)
    def test_search_without_connect(self, dis_connect, collection):
        '''
        target: test search vectors without connection
        method: call search on a disconnected client instance
        expected: raise exception
        '''
        with pytest.raises(Exception):
            dis_connect.search(collection, default_query)

    # PASS
    # TODO: proxy or SDK checks if collection exists
    def test_search_collection_name_not_existed(self, connect):
        '''
        target: search a collection that does not exist
        method: search with a freshly generated collection name that was never created
        expected: search raises an exception
        '''
        missing_collection = gen_unique_str(uid)
        with pytest.raises(Exception):
            connect.search(missing_collection, default_query)

    # PASS
    @pytest.mark.skip("r0.3-test")
    def test_search_distance_l2(self, connect, collection):
        '''
        target: search collection, and check the result: distance
        method: insert two vectors, query with a random vector, and compare the best returned
                distance with the smaller Euclidean distance computed locally
        expected: the square root of the returned distance equals the computed value
        '''
        nq = 2
        params = {"nprobe": 1}
        inserted, _ = init_data(connect, collection, nb=nq)
        query, query_vecs = gen_query_vectors(field_name, inserted, default_top_k, nq, rand_vector=True,
                                              search_params=params)
        _, stored_vecs = gen_query_vectors(field_name, inserted, default_top_k, nq, search_params=params)
        expected = min(l2(query_vecs[0], stored_vecs[0]), l2(query_vecs[0], stored_vecs[1]))
        res = connect.search(collection, query)
        # np.sqrt suggests the server reports squared L2 distances (presumably — confirm).
        returned = res[0]._distances[0]
        assert abs(np.sqrt(returned) - expected) <= gen_inaccuracy(returned)

    # Pass
    @pytest.mark.skip("r0.3-test")
    def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index):
        '''
        target: search an indexed collection, and check the result against brute force
        method: find the stored vector nearest (L2) to a random query by brute force and check
                that its id appears in the search result
        expected: the brute-force nearest id is among the returned hits
        '''
        index_type = get_simple_index["index_type"]
        nq = 2
        entities, ids = init_data(connect, id_collection, auto_id=False)
        connect.create_index(id_collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True,
                                        search_params=search_param)
        inside_vecs = entities[-1]["values"]
        # Brute-force nearest neighbour of the first query vector. Start from +inf so min_id is
        # always set, whatever the scale of the distances (a start value of 1.0 leaves it None
        # when every distance exceeds 1.0).
        min_distance = float("inf")
        min_id = None
        for vec_id, inside_vec in zip(ids, inside_vecs):
            tmp_dis = l2(vecs[0], inside_vec)
            if tmp_dis < min_distance:
                min_distance = tmp_dis
                min_id = vec_id
        res = connect.search(id_collection, query)
        # The check must be asserted; a bare call discards its result and can never fail.
        assert check_id_result(res[0], min_id)
        # TODO: also compare distances (ANNOY / IVF_PQ may need a looser tolerance such as 0.1):
        # assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= epsilon

    # DOG: TODO REDUCE
    # TODO: reopen after we supporting ip flat
    @pytest.mark.skip("search_distance_ip")
    @pytest.mark.level(2)
    def test_search_distance_ip(self, connect, collection):
        '''
        target: search collection, and check the result: distance
        method: insert two vectors, query with a random vector using IP, and compare the best
                returned score with the larger inner product computed locally
        expected: the returned distance equals the computed value within epsilon
        '''
        nq = 2
        metric = "IP"
        params = {"nprobe": 1}
        inserted, _ = init_data(connect, collection, nb=nq)
        query, query_vecs = gen_query_vectors(field_name, inserted, default_top_k, nq, rand_vector=True,
                                              metric_type=metric,
                                              search_params=params)
        _, stored_vecs = gen_query_vectors(field_name, inserted, default_top_k, nq, search_params=params)
        expected = max(ip(query_vecs[0], stored_vecs[0]), ip(query_vecs[0], stored_vecs[1]))
        res = connect.search(collection, query)
        assert abs(res[0]._distances[0] - expected) <= epsilon

    # Pass
    @pytest.mark.skip("r0.3-test")
    def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index):
        '''
        target: search an IP-indexed collection, and check the result against brute force
        method: find the stored vector with the largest inner product with a random query by
                brute force and check that its id appears in the search result
        expected: the brute-force best id is among the returned hits
        '''
        index_type = get_simple_index["index_type"]
        nq = 2
        metric_type = "IP"
        entities, ids = init_data(connect, id_collection, auto_id=False)
        get_simple_index["metric_type"] = metric_type
        connect.create_index(id_collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True,
                                        metric_type=metric_type,
                                        search_params=search_param)
        inside_vecs = entities[-1]["values"]
        # Brute-force best match of the first query vector. Start from -inf so max_id is always
        # set, even if every inner product is <= 0.
        max_distance = float("-inf")
        max_id = None
        for vec_id, inside_vec in zip(ids, inside_vecs):
            tmp_dis = ip(vecs[0], inside_vec)
            if tmp_dis > max_distance:
                max_distance = tmp_dis
                max_id = vec_id
        res = connect.search(id_collection, query)
        # The check must be asserted; a bare call discards its result and can never fail.
        assert check_id_result(res[0], max_id)
        # TODO: also compare distances (ANNOY / IVF_PQ may need a looser tolerance such as 0.1):
        # assert abs(res[0]._distances[0] - max_distance) <= epsilon

    # PASS
    @pytest.mark.skip("r0.3-test")
    def test_search_distance_jaccard_flat_index(self, connect, binary_collection):
        '''
        target: search binary_collection, and check the result: distance
        method: insert two binary vectors, query with a third (not inserted) one using JACCARD,
                and compare the best returned distance with the smaller locally computed Jaccard
                distance
        expected: the returned distance equals the computed value within epsilon
        '''
        nq = 1
        stored_ints, _, _ = init_binary_data(connect, binary_collection, nb=2)
        query_ints, query_entities, _ = init_binary_data(connect, binary_collection, nb=1, insert=False)
        expected = min(jaccard(query_ints[0], stored_ints[0]), jaccard(query_ints[0], stored_ints[1]))
        query, _ = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="JACCARD")
        res = connect.search(binary_collection, query)
        assert abs(res[0]._distances[0] - expected) <= epsilon

    # DOG: TODO INVALID TYPE
    @pytest.mark.skip("search_distance_jaccard_flat_index_L2")
    @pytest.mark.level(2)
    def test_search_distance_jaccard_flat_index_L2(self, connect, binary_collection):
        '''
        target: search binary_collection with a metric type mismatched with binary vectors
        method: insert 2 binary entities, then search with a binary query vector and metric_type "L2"
        expected: throw error of mismatched metric type
        '''
        nq = 1
        init_binary_data(connect, binary_collection, nb=2)
        # insert=False: presumably only generates the query vector without inserting it
        _, query_entities, _ = init_binary_data(connect, binary_collection, nb=1, insert=False)
        query, _ = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="L2")
        with pytest.raises(Exception):
            connect.search(binary_collection, query)

    # PASS
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_distance_hamming_flat_index(self, connect, binary_collection):
        '''
        target: search binary_collection, and check the result: distance
        method: insert 2 binary entities, search with a separately generated binary vector using
                HAMMING, and compare the top-1 distance with the Hamming distances computed locally
        expected: the returned top-1 distance equals the smaller of the two computed distances
        '''
        nq = 1
        int_vectors, _, _ = init_binary_data(connect, binary_collection, nb=2)
        # insert=False: presumably only generates the query vector without inserting it
        query_int_vectors, query_entities, _ = init_binary_data(connect, binary_collection, nb=1, insert=False)
        distance_0 = hamming(query_int_vectors[0], int_vectors[0])
        distance_1 = hamming(query_int_vectors[0], int_vectors[1])
        query, _ = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="HAMMING")
        res = connect.search(binary_collection, query)
        # NOTE(review): .astype assumes hamming() returns a numpy scalar — confirm in utils
        assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon

    # PASS
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_distance_substructure_flat_index(self, connect, binary_collection):
        '''
        target: search binary_collection with the SUBSTRUCTURE metric
        method: insert 2 binary entities, then search with a separately generated binary vector
        expected: the query returns no entity
        '''
        nq = 1
        init_binary_data(connect, binary_collection, nb=2)
        # insert=False: presumably only generates the query vector without inserting it
        _, query_entities, _ = init_binary_data(connect, binary_collection, nb=1, insert=False)
        query, _ = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq,
                                     metric_type="SUBSTRUCTURE")
        res = connect.search(binary_collection, query)
        # the test assumes an unrelated query vector is a substructure of neither inserted vector
        assert len(res[0]) == 0

    # PASS
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
        '''
        target: search binary_collection with the SUBSTRUCTURE metric
        method: insert 2 binary entities, derive substructure query vectors from them with
                gen_binary_sub_vectors, and search with those vectors
        expected: the top-1 hit of query i is the i-th inserted entity, at distance <= epsilon
        '''
        top_k = 3
        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
        _, query_vecs = gen_binary_sub_vectors(int_vectors, 2)
        # the query carries len(query_vecs) vectors, so use that as nq instead of the module-level nq (1)
        query, _ = gen_query_vectors(binary_field_name, entities, top_k, len(query_vecs),
                                     metric_type="SUBSTRUCTURE", replace_vecs=query_vecs)
        res = connect.search(binary_collection, query)
        for i in range(2):
            assert res[i][0].distance <= epsilon
            assert res[i][0].id == ids[i]

    # PASS
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_distance_superstructure_flat_index(self, connect, binary_collection):
        '''
        target: search binary_collection with the SUPERSTRUCTURE metric
        method: insert 2 binary entities, then search with a separately generated binary vector
        expected: the query returns no entity
        '''
        nq = 1
        init_binary_data(connect, binary_collection, nb=2)
        # insert=False: presumably only generates the query vector without inserting it
        _, query_entities, _ = init_binary_data(connect, binary_collection, nb=1, insert=False)
        query, _ = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq,
                                     metric_type="SUPERSTRUCTURE")
        res = connect.search(binary_collection, query)
        # the test assumes an unrelated query vector is a superstructure of neither inserted vector
        assert len(res[0]) == 0

    # PASS
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_distance_superstructure_flat_index_B(self, connect, binary_collection):
        '''
        target: search binary_collection with the SUPERSTRUCTURE metric
        method: insert 2 binary entities, derive superstructure query vectors from them with
                gen_binary_super_vectors, and search with those vectors
        expected: each query returns both entities, and its top-1 hit is an inserted entity
                  at distance <= epsilon
        '''
        top_k = 3
        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
        _, query_vecs = gen_binary_super_vectors(int_vectors, 2)
        # the query carries len(query_vecs) vectors, so use that as nq instead of the module-level nq (1)
        query, _ = gen_query_vectors(binary_field_name, entities, top_k, len(query_vecs),
                                     metric_type="SUPERSTRUCTURE", replace_vecs=query_vecs)
        res = connect.search(binary_collection, query)
        for i in range(2):
            assert len(res[i]) == 2
            assert res[i][0].id in ids
            assert res[i][0].distance <= epsilon

    # PASS
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_distance_tanimoto_flat_index(self, connect, binary_collection):
        '''
        target: search binary_collection, and check the result: distance
        method: insert 2 binary entities, search with a separately generated binary vector using
                TANIMOTO, and compare the top-1 distance with the Tanimoto distances computed locally
        expected: the returned top-1 distance equals the smaller of the two computed distances
        '''
        nq = 1
        int_vectors, _, _ = init_binary_data(connect, binary_collection, nb=2)
        # insert=False: presumably only generates the query vector without inserting it
        query_int_vectors, query_entities, _ = init_binary_data(connect, binary_collection, nb=1, insert=False)
        distance_0 = tanimoto(query_int_vectors[0], int_vectors[0])
        distance_1 = tanimoto(query_int_vectors[0], int_vectors[1])
        query, _ = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="TANIMOTO")
        res = connect.search(binary_collection, query)
        assert abs(res[0][0].distance - min(distance_0, distance_1)) <= epsilon

    # PASS
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    @pytest.mark.timeout(30)
    def test_search_concurrent_multithreads(self, connect, args):
        '''
        target: test concurrent search with multiple threads
        method: search the same collection from 4 threads, each thread using its own connection
        expected: each search returns one result set whose top-1 hit is an inserted entity
                  at distance < epsilon
        '''
        threads_num = 4
        threads = []
        collection = gen_unique_str(uid)
        # create the collection and insert data through a dedicated connection; it is kept
        # under its own name so the per-thread connections below do not shadow it
        admin_milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        admin_milvus.create_collection(collection, default_fields)
        entities, ids = init_data(admin_milvus, collection)

        def search(milvus):
            res = milvus.search(collection, default_query)
            assert len(res) == 1
            assert res[0]._entities[0].id in ids
            assert res[0]._distances[0] < epsilon

        try:
            for _ in range(threads_num):
                thread_milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
                t = MilvusTestThread(target=search, args=(thread_milvus,))
                threads.append(t)
                t.start()
                time.sleep(0.2)
            for t in threads:
                t.join()
        finally:
            # this test creates its own collection, so it must also remove it
            admin_milvus.drop_collection(collection)

    # PASS
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    @pytest.mark.timeout(30)
    def test_search_concurrent_multithreads_single_connection(self, connect, args):
        '''
        target: test concurrent search with multiple threads sharing one connection
        method: search the same collection from 4 threads, all using the same connection
        expected: each search returns one result set whose top-1 hit is an inserted entity
                  at distance < epsilon
        '''
        threads_num = 4
        threads = []
        collection = gen_unique_str(uid)
        # create collection; this single connection is also shared by every search thread
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(collection, default_fields)
        entities, ids = init_data(milvus, collection)

        def search(milvus):
            res = milvus.search(collection, default_query)
            assert len(res) == 1
            assert res[0]._entities[0].id in ids
            assert res[0]._distances[0] < epsilon

        try:
            for _ in range(threads_num):
                t = MilvusTestThread(target=search, args=(milvus,))
                threads.append(t)
                t.start()
                time.sleep(0.2)
            for t in threads:
                t.join()
        finally:
            # this test creates its own collection, so it must also remove it
            milvus.drop_collection(collection)

    # PASS
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_multi_collections(self, connect, args):
        '''
        target: test search multi collections of L2
        method: add vectors into 10 collections, and search each of them with 20 query vectors
        expected: search status ok, the length of result equals nq, and the top-1 hit of
                  query i is the i-th inserted entity
        '''
        num = 10
        top_k = 10
        nq = 20
        for coll_index in range(num):
            collection = gen_unique_str(uid + str(coll_index))
            connect.create_collection(collection, default_fields)
            try:
                entities, ids = init_data(connect, collection)
                assert len(ids) == default_nb
                query, _ = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
                res = connect.search(collection, query)
                assert len(res) == nq
                # separate index name so the collection loop variable is not shadowed
                for i in range(nq):
                    assert check_id_result(res[i], ids[i])
                    assert res[i]._distances[0] < epsilon
                    assert res[i]._distances[1] > epsilon
            finally:
                # collections created here are not managed by a fixture, so drop them explicitly
                connect.drop_collection(collection)

    @pytest.mark.skip("test_query_entities_with_field_less_than_top_k")
    def test_query_entities_with_field_less_than_top_k(self, connect, id_collection):
        """
        target: test search with field, and let return entities less than topk
        method: insert entities with ids 0..nb-1 and build an IVF_FLAT index (nlist=200),
                then search with top_k=300, nprobe=1 and the "int64" field requested
        expected: fewer than top_k entities are returned, and each returned entity's
                  "int64" field equals its id
        """
        entities, ids = init_data(connect, id_collection, auto_id=False)
        simple_index = {"index_type": "IVF_FLAT", "params": {"nlist": 200}, "metric_type": "L2"}
        connect.create_index(id_collection, field_name, simple_index)
        top_k = 300
        # local name "query" avoids shadowing the module-level default_query
        query, _ = gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe": 1})
        expr = {"must": [gen_default_vector_expr(query)]}
        query = update_query_expr(query, expr=expr)
        res = connect.search(id_collection, query, fields=["int64"])
        assert len(res) == nq
        # nprobe=1 over nlist=200 buckets should yield fewer hits than top_k
        assert len(res[0]) < top_k
        for r in res[0]:
            assert getattr(r.entity, "int64") == getattr(r.entity, "id")


@pytest.mark.skip("r0.3-test")
class TestSearchDSL(object):
    """
    Search cases that build boolean DSL queries from vector, term and range
    exprs combined under "must" / "should", covering both invalid and valid forms.
    """

    """
    ******************************************************************
    #  The following cases are used to build invalid query expr
    ******************************************************************
    """

    # PASS
    def test_query_no_must(self, connect, collection):
        '''
        method: build query without must expr
        expected: error raised
        '''
        query = update_query_expr(default_query, keep_old=False)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # PASS
    def test_query_no_vector_term_only(self, connect, collection):
        '''
        method: build query without vector, with a term expr only
        expected: error raised
        '''
        # the generator must be called: putting the function object itself in the expr
        # would make the query invalid for an unrelated reason
        expr = {
            "must": [gen_default_term_expr()]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # PASS
    def test_query_no_vector_range_only(self, connect, collection):
        '''
        method: build query without vector, with a range expr only
        expected: error raised
        '''
        # the generator must be called: putting the function object itself in the expr
        # would make the query invalid for an unrelated reason
        expr = {
            "must": [gen_default_range_expr()]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # PASS
    def test_query_vector_only(self, connect, collection):
        '''
        method: search with the default query, which holds a vector expr only
        expected: nq result sets, each with default_top_k hits
        '''
        entities, ids = init_data(connect, collection)
        res = connect.search(collection, default_query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k

    # PASS
    def test_query_wrong_format(self, connect, collection):
        '''
        method: build query without must expr, with wrong expr name
        expected: error raised
        '''
        # call the generator so that only the wrong key "must1" makes the query invalid
        expr = {
            "must1": [gen_default_term_expr()]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # PASS
    def test_query_empty(self, connect, collection):
        '''
        method: search with empty query
        expected: error raised
        '''
        query = {}
        with pytest.raises(Exception):
            connect.search(collection, query)

    """
    ******************************************************************
    #  The following cases are used to build valid query expr
    ******************************************************************
    """

    # PASS
    @pytest.mark.level(2)
    def test_query_term_value_not_in(self, connect, collection):
        '''
        method: build query with vector and term expr, where no entity matches the term
        expected: filter pass, empty result
        '''
        entities, ids = init_data(connect, collection)
        expr = {
            "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[100000])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0
        # TODO:

    # PASS
    @pytest.mark.level(2)
    def test_query_term_value_all_in(self, connect, collection):
        '''
        method: build query with vector and a single-value term expr that matches one entity
        expected: filter pass, one hit
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 1
        # TODO:

    # PASS
    @pytest.mark.level(2)
    def test_query_term_values_not_in(self, connect, collection):
        '''
        method: build query with vector and term expr, where no entity matches the term values
        expected: filter pass, empty result
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query),
                         gen_default_term_expr(values=[i for i in range(100000, 100010)])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0
        # TODO:

    # PASS
    def test_query_term_values_all_in(self, connect, collection):
        '''
        method: build query with vector and the default term expr, whose values all match
        expected: filter pass, every hit belongs to the first half of the inserted ids
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr()]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k
        limit = default_nb // 2
        # built once: membership in a set instead of slicing a list for every hit
        allowed_ids = set(ids[:limit])
        for i in range(nq):
            for result in res[i]:
                logging.getLogger().info(result.id)
                assert result.id in allowed_ids
        # TODO:

    # PASS
    def test_query_term_values_parts_in(self, connect, collection):
        '''
        method: build query with vector and term expr, where only part of the values match
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = {"must": [gen_default_vector_expr(default_query),
                         gen_default_term_expr(
                             values=[i for i in range(default_nb // 2, default_nb + default_nb // 2)])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k
        # TODO:

    # PASS
    @pytest.mark.level(2)
    def test_query_term_values_repeat(self, connect, collection):
        '''
        method: build query with vector and term expr, with the same value repeated
        expected: filter pass, one hit
        '''
        entities, ids = init_data(connect, collection)
        expr = {
            "must": [gen_default_vector_expr(default_query),
                     gen_default_term_expr(values=[1 for i in range(1, default_nb)])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 1
        # TODO:

    # DOG: BUG, please fix
    @pytest.mark.skip("query_term_value_empty")
    def test_query_term_value_empty(self, connect, collection):
        '''
        method: build query with term value empty
        expected: return null
        '''
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[])]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0

    # PASS
    def test_query_complex_dsl(self, connect, collection):
        '''
        method: query with complicated dsl
        expected: no error raised
        '''
        expr = {"must": [
            {"must": [{"should": [gen_default_term_expr(values=[1]), gen_default_range_expr()]}]},
            {"must": [gen_default_vector_expr(default_query)]}
            ]}
        logging.getLogger().info(expr)
        query = update_query_expr(default_query, expr=expr)
        logging.getLogger().info(query)
        res = connect.search(collection, query)
        logging.getLogger().info(res)

    """
    ******************************************************************
    #  The following cases are used to build invalid term query expr
    ******************************************************************
    """

    # PASS
    @pytest.mark.level(2)
    def test_query_term_key_error(self, connect, collection):
        '''
        method: build query with term key error
        expected: Exception raised
        '''
        expr = {"must": [gen_default_vector_expr(default_query),
                         gen_default_term_expr(keyword="terrm", values=[i for i in range(default_nb // 2)])]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    @pytest.fixture(
        scope="function",
        params=gen_invalid_term()
    )
    def get_invalid_term(self, request):
        return request.param

    # PASS
    @pytest.mark.level(2)
    def test_query_term_wrong_format(self, connect, collection, get_invalid_term):
        '''
        method: build query with wrong format term
        expected: Exception raised
        '''
        entities, ids = init_data(connect, collection)
        term = get_invalid_term
        expr = {"must": [gen_default_vector_expr(default_query), term]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # DOG: PLEASE IMPLEMENT connect.count_entities
    # TODO
    @pytest.mark.skip("query_term_field_named_term")
    @pytest.mark.level(2)
    def test_query_term_field_named_term(self, connect, collection):
        '''
        method: build query with field named "term"
        expected: error raised
        '''
        term_fields = add_field_default(default_fields, field_name="term")
        collection_term = gen_unique_str("term")
        connect.create_collection(collection_term, term_fields)
        term_entities = add_field(entities, field_name="term")
        ids = connect.insert(collection_term, term_entities)
        assert len(ids) == default_nb
        connect.flush([collection_term])
        count = connect.count_entities(collection_term)  # count_entities is not implemented
        assert count == default_nb                       # with these two lines removed, this test passes
        term_param = {"term": {"term": {"values": [i for i in range(default_nb // 2)]}}}
        expr = {"must": [gen_default_vector_expr(default_query),
                         term_param]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection_term, query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k
        connect.drop_collection(collection_term)

    # PASS
    @pytest.mark.level(2)
    def test_query_term_one_field_not_existed(self, connect, collection):
        '''
        method: build query with two fields term, one of it not existed
        expected: exception raised
        '''
        entities, ids = init_data(connect, collection)
        term = gen_default_term_expr()
        term["term"].update({"a": [0]})
        expr = {"must": [gen_default_vector_expr(default_query), term]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    """
    ******************************************************************
    #  The following cases are used to build valid range query expr
    ******************************************************************
    """

    # PASS
    def test_query_range_key_error(self, connect, collection):
        '''
        method: build query with range key error
        expected: Exception raised
        '''
        range_expr = gen_default_range_expr(keyword="ranges")
        expr = {"must": [gen_default_vector_expr(default_query), range_expr]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    @pytest.fixture(
        scope="function",
        params=gen_invalid_range()
    )
    def get_invalid_range(self, request):
        return request.param

    # PASS
    @pytest.mark.level(2)
    def test_query_range_wrong_format(self, connect, collection, get_invalid_range):
        '''
        method: build query with wrong format range
        expected: Exception raised
        '''
        entities, ids = init_data(connect, collection)
        range_expr = get_invalid_range
        expr = {"must": [gen_default_vector_expr(default_query), range_expr]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # PASS
    @pytest.mark.level(2)
    def test_query_range_string_ranges(self, connect, collection):
        '''
        method: build query with string-typed range bounds
        expected: raise Exception
        '''
        entities, ids = init_data(connect, collection)
        ranges = {"GT": "0", "LT": "1000"}
        range_expr = gen_default_range_expr(ranges=ranges)
        expr = {"must": [gen_default_vector_expr(default_query), range_expr]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # PASS
    @pytest.mark.level(2)
    def test_query_range_invalid_ranges(self, connect, collection):
        '''
        method: build query with invalid ranges (GT bound larger than LT bound)
        expected: exception raised
        '''
        entities, ids = init_data(connect, collection)
        ranges = {"GT": default_nb, "LT": 0}
        range_expr = gen_default_range_expr(ranges=ranges)
        expr = {"must": [gen_default_vector_expr(default_query), range_expr]}
        query = update_query_expr(default_query, expr=expr)
        # only the search call belongs here: an assert inside pytest.raises(Exception)
        # would have its AssertionError swallowed and make the test pass
        with pytest.raises(Exception):
            connect.search(collection, query)

    @pytest.fixture(
        scope="function",
        params=gen_valid_ranges()
    )
    def get_valid_ranges(self, request):
        return request.param

    # PASS
    @pytest.mark.level(2)
    def test_query_range_valid_ranges(self, connect, collection, get_valid_ranges):
        '''
        method: build query with valid ranges
        expected: pass
        '''
        entities, ids = init_data(connect, collection)
        ranges = get_valid_ranges
        range_expr = gen_default_range_expr(ranges=ranges)
        expr = {"must": [gen_default_vector_expr(default_query), range_expr]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k

    # PASS
    def test_query_range_one_field_not_existed(self, connect, collection):
        '''
        method: build query with two fields ranges, one of fields not existed
        expected: exception raised
        '''
        entities, ids = init_data(connect, collection)
        range_expr = gen_default_range_expr()
        range_expr["range"].update({"a": {"GT": 1, "LT": default_nb // 2}})
        expr = {"must": [gen_default_vector_expr(default_query), range_expr]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    """
    ************************************************************************
    #  The following cases are used to build query expr multi range and term
    ************************************************************************
    """

    # PASS
    def test_query_multi_term_has_common(self, connect, collection):
        '''
        method: build query with multi term with same field, and values has common
        expected: pass
        '''
        entities, ids = init_data(connect, collection)
        term_first = gen_default_term_expr()
        term_second = gen_default_term_expr(values=[i for i in range(default_nb // 3)])
        expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k

    # PASS
    @pytest.mark.level(2)
    def test_query_multi_term_no_common(self, connect, collection):
        '''
        method: build query with multi term with same field, and values no common
        expected: pass, empty result
        '''
        entities, ids = init_data(connect, collection)
        term_first = gen_default_term_expr()
        term_second = gen_default_term_expr(values=[i for i in range(default_nb // 2, default_nb + default_nb // 2)])
        expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0

    # PASS
    def test_query_multi_term_different_fields(self, connect, collection):
        '''
        method: build query with multi term on different fields, whose values select no common entity
        expected: pass, empty result
        '''
        entities, ids = init_data(connect, collection)
        term_first = gen_default_term_expr()
        term_second = gen_default_term_expr(field="float",
                                            values=[float(i) for i in range(default_nb // 2, default_nb)])
        expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0

    # PASS
    @pytest.mark.level(2)
    def test_query_single_term_multi_fields(self, connect, collection):
        '''
        method: build query with a single term expr that holds two different fields
        expected: exception raised
        '''
        entities, ids = init_data(connect, collection)
        term_first = {"int64": {"values": [i for i in range(default_nb // 2)]}}
        term_second = {"float": {"values": [float(i) for i in range(default_nb // 2, default_nb)]}}
        term = update_term_expr({"term": {}}, [term_first, term_second])
        expr = {"must": [gen_default_vector_expr(default_query), term]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # PASS
    @pytest.mark.level(2)
    def test_query_multi_range_has_common(self, connect, collection):
        '''
        method: build query with multi range with same field, and ranges has common
        expected: pass
        '''
        entities, ids = init_data(connect, collection)
        range_one = gen_default_range_expr()
        range_two = gen_default_range_expr(ranges={"GT": 1, "LT": default_nb // 3})
        expr = {"must": [gen_default_vector_expr(default_query), range_one, range_two]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k

    # PASS
    @pytest.mark.level(2)
    def test_query_multi_range_no_common(self, connect, collection):
        '''
        method: build query with multi range with same field, and ranges no common
        expected: pass, empty result
        '''
        entities, ids = init_data(connect, collection)
        range_one = gen_default_range_expr()
        range_two = gen_default_range_expr(ranges={"GT": default_nb // 2, "LT": default_nb})
        expr = {"must": [gen_default_vector_expr(default_query), range_one, range_two]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0

    # PASS
    @pytest.mark.level(2)
    def test_query_multi_range_different_fields(self, connect, collection):
        '''
        method: build query with multi range, different field each range
        expected: pass, empty result
        '''
        entities, ids = init_data(connect, collection)
        range_first = gen_default_range_expr()
        range_second = gen_default_range_expr(field="float", ranges={"GT": default_nb // 2, "LT": default_nb})
        expr = {"must": [gen_default_vector_expr(default_query), range_first, range_second]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0

    # PASS
    @pytest.mark.level(2)
    def test_query_single_range_multi_fields(self, connect, collection):
        '''
        method: build query with a single range expr that holds two different fields
        expected: exception raised
        '''
        entities, ids = init_data(connect, collection)
        range_first = {"int64": {"GT": 0, "LT": default_nb // 2}}
        range_second = {"float": {"GT": default_nb / 2, "LT": float(default_nb)}}
        range_expr = update_range_expr({"range": {}}, [range_first, range_second])
        expr = {"must": [gen_default_vector_expr(default_query), range_expr]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    """
    ******************************************************************
    #  The following cases are used to build query expr both term and range
    ******************************************************************
    """

    # PASS
    @pytest.mark.level(2)
    def test_query_single_term_range_has_common(self, connect, collection):
        '''
        method: build query with single term single range, which select common entities
        expected: pass
        '''
        entities, ids = init_data(connect, collection)
        term = gen_default_term_expr()
        range_expr = gen_default_range_expr(ranges={"GT": -1, "LT": default_nb // 2})
        expr = {"must": [gen_default_vector_expr(default_query), term, range_expr]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k

    # PASS
    def test_query_single_term_range_no_common(self, connect, collection):
        '''
        method: build query with single term single range, which select no common entity
        expected: pass, empty result
        '''
        entities, ids = init_data(connect, collection)
        term = gen_default_term_expr()
        range_expr = gen_default_range_expr(ranges={"GT": default_nb // 2, "LT": default_nb})
        expr = {"must": [gen_default_vector_expr(default_query), term, range_expr]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0

    """
    ******************************************************************
    #  The following cases are used to build multi vectors query expr
    ******************************************************************
    """

    # PASS
    def test_query_multi_vectors_same_field(self, connect, collection):
        '''
        method: build query with two vector exprs on the same field
        expected: error raised
        '''
        entities, ids = init_data(connect, collection)
        # gen_query_vectors returns (query, vectors); keep the query and pull its vector expr,
        # so the only problem with the DSL is the duplicated vector field
        second_query, _ = gen_query_vectors(field_name, entities, default_top_k, nq=2)
        expr = {
            "must": [gen_default_vector_expr(default_query), gen_default_vector_expr(second_query)]
        }
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)


@pytest.mark.skip("r0.3-test")
class TestSearchDSLBools(object):
    """
    ******************************************************************
    #  The following cases are used to build invalid query expr
    ******************************************************************
    """

    # PASS
    @pytest.mark.level(2)
    def test_query_no_bool(self, connect, collection):
        '''
        method: build query without bool expr
        expected: error raised
        '''
        init_data(connect, collection)
        expr = {"bool1": {}}
        query = expr
        with pytest.raises(Exception):
            connect.search(collection, query)

    # PASS
    def test_query_should_only_term(self, connect, collection):
        '''
        method: build query without must, with should.term instead
        expected: error raised
        '''
        # Call the generator: passing the bare function object would make the
        # search fail on serialization instead of on the missing "must".
        expr = {"should": [gen_default_term_expr()]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # PASS
    def test_query_should_only_vector(self, connect, collection):
        '''
        method: build query without must, with should.vector instead
        expected: error raised
        '''
        expr = {"should": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # PASS
    def test_query_must_not_only_term(self, connect, collection):
        '''
        method: build query without must, with must_not.term instead
        expected: error raised
        '''
        # Call the generator so the clause holds a real term expression.
        expr = {"must_not": [gen_default_term_expr()]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # PASS
    def test_query_must_not_vector(self, connect, collection):
        '''
        method: build query without must, with must_not.vector instead
        expected: error raised
        '''
        expr = {"must_not": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # PASS
    def test_query_must_should(self, connect, collection):
        '''
        method: build query must, and with should.term
        expected: error raised
        '''
        # Call the generator so the clause holds a real term expression.
        expr = {"should": [gen_default_term_expr()]}
        query = update_query_expr(default_query, keep_old=True, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)


"""
******************************************************************
#  The following cases are used to test `search` function
#  with invalid collection_name, or invalid query expr
******************************************************************
"""


class TestSearchInvalid(object):
    """
    Test search collection with invalid collection names
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_tag(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        # if str(connect._cmd("mode")) == "CPU":
        if request.param["index_type"] in index_cpu_not_support():
            pytest.skip("sq8h not support in CPU mode")
        return request.param

    # PASS
    @pytest.mark.level(2)
    def test_search_with_invalid_collection(self, connect, get_collection_name):
        collection_name = get_collection_name
        with pytest.raises(Exception):
            connect.search(collection_name, default_query)

    # PASS
    # TODO(yukun)
    @pytest.mark.level(2)
    def test_search_with_invalid_tag(self, connect, collection):
        tag = " "
        with pytest.raises(Exception):
            connect.search(collection, default_query, partition_tags=tag)

    # TODO: reopen after we supporting targetEntry
    @pytest.mark.skip("search_with_invalid_field_name")
    @pytest.mark.level(2)
    def test_search_with_invalid_field_name(self, connect, collection, get_invalid_field):
        fields = [get_invalid_field]
        with pytest.raises(Exception):
            connect.search(collection, default_query, fields=fields)

    # TODO: reopen after we supporting targetEntry
    @pytest.mark.skip("search_with_not_existed_field_name")
    @pytest.mark.level(1)
    def test_search_with_not_existed_field_name(self, connect, collection):
        fields = [gen_unique_str("field_name")]
        with pytest.raises(Exception):
            connect.search(collection, default_query, fields=fields)

    """
    Test search collection with invalid query
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_ints()
    )
    def get_top_k(self, request):
        yield request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
        '''
        target: test search function, with the wrong top_k
        method: search with top_k
        expected: raise an error, and the connection is normal
        '''
        top_k = get_top_k
        # Work on a deep copy: default_query is a module-level object shared by
        # many tests, and writing an invalid topk into it would leak into them.
        query = copy.deepcopy(default_query)
        query["bool"]["must"][0]["vector"][field_name]["topk"] = top_k
        with pytest.raises(Exception):
            connect.search(collection, query)

    """
    Test search collection with invalid search params
    """

    @pytest.fixture(
        scope="function",
        params=gen_invaild_search_params()
    )
    def get_search_params(self, request):
        yield request.param

    # Pass
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
        '''
        target: test search function, with the wrong nprobe
        method: search with nprobe
        expected: raise an error, and the connection is normal
        '''
        search_params = get_search_params
        index_type = get_simple_index["index_type"]
        if index_type in ["FLAT"]:
            pytest.skip("skip in FLAT index")
        if index_type != search_params["index_type"]:
            pytest.skip("skip if index_type not matched")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors(field_name, entities, default_top_k, 1,
                                        search_params=search_params["search_params"])
        with pytest.raises(Exception):
            connect.search(collection, query)

    # pass
    @pytest.mark.skip("r0.3-test")
    @pytest.mark.level(2)
    def test_search_with_invalid_params_binary(self, connect, binary_collection):
        '''
        target: test search function, with the wrong nprobe
        method: search with nprobe
        expected: raise an error, and the connection is normal
        '''
        nq = 1
        index_type = "BIN_IVF_FLAT"
        int_vectors, entities, ids = init_binary_data(connect, binary_collection)
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
        connect.create_index(binary_collection, binary_field_name,
                             {"index_type": index_type, "metric_type": "JACCARD", "params": {"nlist": 128}})
        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq,
                                        search_params={"nprobe": 0}, metric_type="JACCARD")
        with pytest.raises(Exception):
            connect.search(binary_collection, query)

    # Pass
    @pytest.mark.level(2)
    def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
        '''
        target: test search function, with empty search params
        method: search with params
        expected: raise an error, and the connection is normal
        '''
        index_type = get_simple_index["index_type"]
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")
        if index_type == "FLAT":
            pytest.skip("skip in FLAT index")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors(field_name, entities, default_top_k, 1, search_params={})
        with pytest.raises(Exception):
            connect.search(collection, query)


def check_id_result(result, id):
    """Return True if `id` appears among the first 5 hits of `result`.

    Each element of `result` must expose an `.id` attribute. When `result`
    has fewer than 5 hits, every hit is considered.
    """
    top_n = 5
    hit_ids = [hit.id for hit in result]
    # Slicing past the end of a list yields the whole list, so a single slice
    # covers both the "at least top_n hits" and the "fewer hits" cases.
    return id in hit_ids[:top_n]