Skip to content
Snippets Groups Projects
Unverified Commit 1f05c9ca authored by David Buchaca Prats's avatar David Buchaca Prats Committed by GitHub
Browse files

fix(array): consider singleton to return singleton (#286)

parent d6deafdc
No related branches found
Tags v0.12.9
No related merge requests found
......@@ -136,8 +136,7 @@ class FindMixin:
isinstance(query, list) and isinstance(query[0], str)
):
result = self._find_by_text(query, index=index, limit=limit, **kwargs)
if len(result) == 1:
if isinstance(query, str):
return result[0]
else:
return result
......@@ -206,6 +205,7 @@ class FindMixin:
if len(matches) >= _limit:
break
result.append(matches)
else:
raise TypeError(
f'unsupported type `{type(_result)}` returned from `._find()`'
......@@ -219,10 +219,11 @@ class FindMixin:
else:
result[i] = matches
if len(result) == 1:
return result[0]
else:
return result
# ensures query=np.array([1,2,3]) returns DocumentArray not list with 1 DocumentArray
if n_dim == 1:
result = result[0]
return result
@abc.abstractmethod
def _find(
......
......@@ -31,18 +31,23 @@ def test_find(storage, config, limit, query, start_storage):
da.extend([Document(embedding=v) for v in embeddings])
result = da.find(query, limit=limit)
n_rows_query, _ = ndarray.get_array_rows(query)
n_rows_query, n_dim = ndarray.get_array_rows(query)
# check for each row on the query a DocumentArray is returned
if n_rows_query == 1:
if n_rows_query == 1 and n_dim == 1:
# we expect a result to be DocumentArray
assert len(result) == limit
elif n_rows_query == 1 and n_dim == 2:
# we expect a result to be a list with 1 DocumentArray
assert len(result) == 1
assert len(result[0]) == limit
else:
# check for each row on the query a DocumentArray is returned
assert len(result) == n_rows_query
# check returned objects are sorted according to the storage backend metric
# weaviate uses cosine similarity by default
# annlite uses cosine distance by default
if n_rows_query == 1:
if n_dim == 1:
if storage == 'weaviate':
cosine_similarities = [
t['cosine_similarity'].value for t in result[:, 'scores']
......@@ -144,6 +149,7 @@ def test_find_by_tag(storage, config, start_storage):
)
results = da.find('token1 token2', index='attr1')
assert isinstance(results, DocumentArray)
assert len(results) == 2
assert results[0].id == '1'
assert results[1].id == '2'
......@@ -176,3 +182,13 @@ def test_find_by_tag(storage, config, start_storage):
assert len(results) == 1
assert results[0].id == '3'
assert all(['token1' in r.tags['attr3'] for r in results]) == True
results = da.find(['token1 token2'], index='attr1')
assert isinstance(results, list)
assert len(results) == 1
assert isinstance(results[0], DocumentArray)
results = da.find(['token1 token2', 'token1'], index='attr1')
assert isinstance(results, list)
assert len(results) == 2
assert all([isinstance(result, DocumentArray) for result in results]) == True
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment