mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
release/v2.4 -> master (#844)
This commit is contained in:
parent
a24df8e990
commit
89cabee1b4
386 changed files with 7202 additions and 5741 deletions
|
|
@ -24,11 +24,10 @@ def _make_processor(qdrant_client=None):
|
|||
return proc
|
||||
|
||||
|
||||
def _make_request(vector=None, user="test-user", collection="test-col",
|
||||
def _make_request(vector=None, collection="test-col",
|
||||
schema_name="customers", limit=10, index_name=None):
|
||||
return RowEmbeddingsRequest(
|
||||
vector=vector or [0.1, 0.2, 0.3],
|
||||
user=user,
|
||||
collection=collection,
|
||||
schema_name=schema_name,
|
||||
limit=limit,
|
||||
|
|
@ -36,6 +35,14 @@ def _make_request(vector=None, user="test-user", collection="test-col",
|
|||
)
|
||||
|
||||
|
||||
def _make_flow(workspace="test-workspace", pub=None):
|
||||
"""Make a mock flow object that is callable and has .workspace."""
|
||||
flow = MagicMock()
|
||||
flow.return_value = pub if pub is not None else AsyncMock()
|
||||
flow.workspace = workspace
|
||||
return flow
|
||||
|
||||
|
||||
def _make_search_point(index_name, index_value, text, score):
|
||||
point = MagicMock()
|
||||
point.payload = {
|
||||
|
|
@ -85,34 +92,33 @@ class TestFindCollection:
|
|||
def test_finds_matching_collection(self):
|
||||
proc = _make_processor()
|
||||
mock_coll = MagicMock()
|
||||
mock_coll.name = "rows_test_user_test_col_customers_384"
|
||||
mock_coll.name = "rows_test_workspace_test_col_customers_384"
|
||||
|
||||
mock_collections = MagicMock()
|
||||
mock_collections.collections = [mock_coll]
|
||||
proc.qdrant.get_collections.return_value = mock_collections
|
||||
|
||||
result = proc.find_collection("test-user", "test-col", "customers")
|
||||
result = proc.find_collection("test-workspace", "test-col", "customers")
|
||||
|
||||
# Prefix: rows_test_user_test_col_customers_
|
||||
assert result == "rows_test_user_test_col_customers_384"
|
||||
assert result == "rows_test_workspace_test_col_customers_384"
|
||||
|
||||
def test_returns_none_when_no_match(self):
|
||||
proc = _make_processor()
|
||||
mock_coll = MagicMock()
|
||||
mock_coll.name = "rows_other_user_other_col_schema_768"
|
||||
mock_coll.name = "rows_other_workspace_other_col_schema_768"
|
||||
|
||||
mock_collections = MagicMock()
|
||||
mock_collections.collections = [mock_coll]
|
||||
proc.qdrant.get_collections.return_value = mock_collections
|
||||
|
||||
result = proc.find_collection("test-user", "test-col", "customers")
|
||||
result = proc.find_collection("test-workspace", "test-col", "customers")
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_on_error(self):
|
||||
proc = _make_processor()
|
||||
proc.qdrant.get_collections.side_effect = Exception("connection error")
|
||||
|
||||
result = proc.find_collection("user", "col", "schema")
|
||||
result = proc.find_collection("workspace", "col", "schema")
|
||||
assert result is None
|
||||
|
||||
|
||||
|
|
@ -127,7 +133,7 @@ class TestQueryRowEmbeddings:
|
|||
proc = _make_processor()
|
||||
request = _make_request(vector=[])
|
||||
|
||||
result = await proc.query_row_embeddings(request)
|
||||
result = await proc.query_row_embeddings("test-workspace", request)
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -136,13 +142,13 @@ class TestQueryRowEmbeddings:
|
|||
proc.find_collection = MagicMock(return_value=None)
|
||||
request = _make_request()
|
||||
|
||||
result = await proc.query_row_embeddings(request)
|
||||
result = await proc.query_row_embeddings("test-workspace", request)
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_query_returns_matches(self):
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
|
||||
points = [
|
||||
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
|
||||
|
|
@ -153,7 +159,7 @@ class TestQueryRowEmbeddings:
|
|||
proc.qdrant.query_points.return_value = mock_result
|
||||
|
||||
request = _make_request()
|
||||
result = await proc.query_row_embeddings(request)
|
||||
result = await proc.query_row_embeddings("test-workspace", request)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], RowIndexMatch)
|
||||
|
|
@ -166,14 +172,14 @@ class TestQueryRowEmbeddings:
|
|||
async def test_index_name_filter_applied(self):
|
||||
"""When index_name is specified, a Qdrant filter should be used."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.points = []
|
||||
proc.qdrant.query_points.return_value = mock_result
|
||||
|
||||
request = _make_request(index_name="address")
|
||||
await proc.query_row_embeddings(request)
|
||||
await proc.query_row_embeddings("test-workspace", request)
|
||||
|
||||
call_kwargs = proc.qdrant.query_points.call_args[1]
|
||||
assert call_kwargs["query_filter"] is not None
|
||||
|
|
@ -182,14 +188,14 @@ class TestQueryRowEmbeddings:
|
|||
async def test_no_index_name_no_filter(self):
|
||||
"""When index_name is empty, no filter should be applied."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.points = []
|
||||
proc.qdrant.query_points.return_value = mock_result
|
||||
|
||||
request = _make_request(index_name="")
|
||||
await proc.query_row_embeddings(request)
|
||||
await proc.query_row_embeddings("test-workspace", request)
|
||||
|
||||
call_kwargs = proc.qdrant.query_points.call_args[1]
|
||||
assert call_kwargs["query_filter"] is None
|
||||
|
|
@ -198,7 +204,7 @@ class TestQueryRowEmbeddings:
|
|||
async def test_missing_payload_fields_default(self):
|
||||
"""Points with missing payload fields should use defaults."""
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
|
||||
point = MagicMock()
|
||||
point.payload = {} # Empty payload
|
||||
|
|
@ -209,7 +215,7 @@ class TestQueryRowEmbeddings:
|
|||
proc.qdrant.query_points.return_value = mock_result
|
||||
|
||||
request = _make_request()
|
||||
result = await proc.query_row_embeddings(request)
|
||||
result = await proc.query_row_embeddings("test-workspace", request)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].index_name == ""
|
||||
|
|
@ -219,13 +225,13 @@ class TestQueryRowEmbeddings:
|
|||
@pytest.mark.asyncio
|
||||
async def test_qdrant_error_propagates(self):
|
||||
proc = _make_processor()
|
||||
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
|
||||
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
|
||||
proc.qdrant.query_points.side_effect = Exception("qdrant down")
|
||||
|
||||
request = _make_request()
|
||||
|
||||
with pytest.raises(Exception, match="qdrant down"):
|
||||
await proc.query_row_embeddings(request)
|
||||
await proc.query_row_embeddings("test-workspace", request)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -243,7 +249,7 @@ class TestOnMessage:
|
|||
])
|
||||
|
||||
mock_pub = AsyncMock()
|
||||
flow = lambda name: mock_pub
|
||||
flow = _make_flow(pub=mock_pub)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = _make_request()
|
||||
|
|
@ -264,7 +270,7 @@ class TestOnMessage:
|
|||
)
|
||||
|
||||
mock_pub = AsyncMock()
|
||||
flow = lambda name: mock_pub
|
||||
flow = _make_flow(pub=mock_pub)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = _make_request()
|
||||
|
|
@ -284,7 +290,7 @@ class TestOnMessage:
|
|||
proc.query_row_embeddings = AsyncMock(return_value=[])
|
||||
|
||||
mock_pub = AsyncMock()
|
||||
flow = lambda name: mock_pub
|
||||
flow = _make_flow(pub=mock_pub)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = _make_request()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue