From 87c6c97af663a24576891758297ac15db629c2f7 Mon Sep 17 00:00:00 2001 From: Cyber MacGeddon Date: Tue, 21 Apr 2026 13:26:14 +0100 Subject: [PATCH] Fix tests --- .../test_metadata_preservation.py | 24 ++++++------- .../test_null_embedding_protection.py | 36 +++++++------------ .../unit/test_retrieval/test_document_rag.py | 12 ++++++- .../test_document_rag_service.py | 1 + .../test_knowledge_translator_roundtrip.py | 2 +- 5 files changed, 37 insertions(+), 38 deletions(-) diff --git a/tests/unit/test_reliability/test_metadata_preservation.py b/tests/unit/test_reliability/test_metadata_preservation.py index aded7253..2170c763 100644 --- a/tests/unit/test_reliability/test_metadata_preservation.py +++ b/tests/unit/test_reliability/test_metadata_preservation.py @@ -30,7 +30,7 @@ class TestDocumentMetadataTranslator: "title": "Test Document", "comments": "No comments", "metadata": [], - "user": "alice", + "workspace": "alice", "tags": ["finance", "q4"], "parent-id": "doc-100", "document-type": "page", @@ -40,14 +40,14 @@ class TestDocumentMetadataTranslator: assert obj.time == 1710000000 assert obj.kind == "application/pdf" assert obj.title == "Test Document" - assert obj.user == "alice" + assert obj.workspace == "alice" assert obj.tags == ["finance", "q4"] assert obj.parent_id == "doc-100" assert obj.document_type == "page" wire = self.tx.encode(obj) assert wire["id"] == "doc-123" - assert wire["user"] == "alice" + assert wire["workspace"] == "alice" assert wire["parent-id"] == "doc-100" assert wire["document-type"] == "page" @@ -80,10 +80,10 @@ class TestDocumentMetadataTranslator: def test_falsy_fields_omitted_from_wire(self): """Empty string fields should be omitted from wire format.""" - obj = DocumentMetadata(id="", time=0, user="") + obj = DocumentMetadata(id="", time=0, workspace="") wire = self.tx.encode(obj) assert "id" not in wire - assert "user" not in wire + assert "workspace" not in wire # --------------------------------------------------------------------------- @@ -101,7 +101,7 @@ class TestProcessingMetadataTranslator: "document-id": "doc-123", "time": 1710000000, "flow": "default", - "user": "alice", + "workspace": "alice", "collection": "my-collection", "tags": ["tag1"], } @@ -109,20 +109,20 @@ class TestProcessingMetadataTranslator: assert obj.id == "proc-1" assert obj.document_id == "doc-123" assert obj.flow == "default" - assert obj.user == "alice" + assert obj.workspace == "alice" assert obj.collection == "my-collection" assert obj.tags == ["tag1"] wire = self.tx.encode(obj) assert wire["id"] == "proc-1" assert wire["document-id"] == "doc-123" - assert wire["user"] == "alice" + assert wire["workspace"] == "alice" assert wire["collection"] == "my-collection" def test_missing_fields_use_defaults(self): obj = self.tx.decode({}) assert obj.id is None - assert obj.user is None + assert obj.workspace is None assert obj.collection is None def test_tags_none_omitted(self): @@ -135,10 +135,10 @@ class TestProcessingMetadataTranslator: wire = self.tx.encode(obj) assert wire["tags"] == [] - def test_user_and_collection_preserved(self): + def test_workspace_and_collection_preserved(self): """Core pipeline routing fields must survive round-trip.""" - data = {"user": "bob", "collection": "research"} + data = {"workspace": "bob", "collection": "research"} obj = self.tx.decode(data) wire = self.tx.encode(obj) - assert wire["user"] == "bob" + assert wire["workspace"] == "bob" assert wire["collection"] == "research" diff --git a/tests/unit/test_reliability/test_null_embedding_protection.py b/tests/unit/test_reliability/test_null_embedding_protection.py index 41a5c621..2296e961 100644 --- a/tests/unit/test_reliability/test_null_embedding_protection.py +++ b/tests/unit/test_reliability/test_null_embedding_protection.py @@ -61,7 +61,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" emb = MagicMock() @@ -69,7 +68,7 @@ class TestDocEmbeddingsNullProtection: emb.vector = [] # Empty vector msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) # No upsert should be called proc.qdrant.upsert.assert_not_called() @@ -83,7 +82,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" emb = MagicMock() @@ -91,7 +89,7 @@ class TestDocEmbeddingsNullProtection: emb.vector = None # None vector msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -103,7 +101,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" emb = MagicMock() @@ -111,7 +108,7 @@ class TestDocEmbeddingsNullProtection: emb.vector = [0.1, 0.2, 0.3] msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -124,7 +121,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" emb = MagicMock() @@ -132,7 +128,7 @@ class TestDocEmbeddingsNullProtection: emb.vector = [0.1, 0.2, 0.3] msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) proc.qdrant.upsert.assert_called_once() @pytest.mark.asyncio @@ -146,7 +142,6 @@ class TestDocEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "alice" msg.metadata.collection = "docs" emb = MagicMock() @@ -154,7 +149,7 @@ class TestDocEmbeddingsNullProtection: emb.vector = [0.0] * 384 # 384-dim vector msg.chunks = [emb] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("alice", msg) call_args = proc.qdrant.upsert.call_args assert "d_alice_docs_384" in call_args[1]["collection_name"] @@ -175,7 +170,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" entity = MagicMock() @@ -183,7 +177,7 @@ class TestGraphEmbeddingsNullProtection: entity.vector = [0.1, 0.2, 0.3] msg.entities = [entity] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -195,7 +189,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" entity = MagicMock() @@ -203,7 +196,7 @@ class TestGraphEmbeddingsNullProtection: entity.vector = [0.1, 0.2, 0.3] msg.entities = [entity] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -215,7 +208,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" entity = MagicMock() @@ -223,7 +215,7 @@ class TestGraphEmbeddingsNullProtection: entity.vector = [] # Empty vector msg.entities = [entity] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -236,7 +228,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "col1" entity = MagicMock() @@ -245,7 +236,7 @@ class TestGraphEmbeddingsNullProtection: entity.chunk_id = "c1" msg.entities = [entity] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_called_once() @pytest.mark.asyncio @@ -258,7 +249,6 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) msg = MagicMock() - msg.metadata.user = "alice" msg.metadata.collection = "graphs" entity = MagicMock() @@ -267,7 +257,7 @@ class TestGraphEmbeddingsNullProtection: entity.chunk_id = "" msg.entities = [entity] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("alice", msg) # Collection should be created with correct dimension proc.qdrant.create_collection.assert_called_once() @@ -290,11 +280,10 @@ class TestCollectionValidation: proc.collection_exists = MagicMock(return_value=False) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "deleted-col" msg.chunks = [MagicMock()] - await proc.store_document_embeddings(msg) + await proc.store_document_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() @pytest.mark.asyncio @@ -306,9 +295,8 @@ class TestCollectionValidation: proc.collection_exists = MagicMock(return_value=False) msg = MagicMock() - msg.metadata.user = "user1" msg.metadata.collection = "deleted-col" msg.entities = [MagicMock()] - await proc.store_graph_embeddings(msg) + await proc.store_graph_embeddings("user1", msg) proc.qdrant.upsert.assert_not_called() diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py index 1ff85f5a..d96d7cd4 100644 --- a/tests/unit/test_retrieval/test_document_rag.py +++ b/tests/unit/test_retrieval/test_document_rag.py @@ -92,6 +92,7 @@ class TestQuery: # Initialize Query with defaults query = Query( rag=mock_rag, + workspace="test_workspace", user="test_user", collection="test_collection", verbose=False @@ -112,6 +113,7 @@ class TestQuery: # Initialize Query with custom doc_limit query = Query( rag=mock_rag, + workspace="test_workspace", user="custom_user", collection="custom_collection", verbose=True, @@ -137,6 +139,7 @@ class TestQuery: query = Query( rag=mock_rag, + workspace="test_workspace", user="test_user", collection="test_collection", verbose=False @@ -162,6 +165,7 @@ class TestQuery: query = Query( rag=mock_rag, + workspace="test_workspace", user="test_user", collection="test_collection", verbose=False @@ -184,6 +188,7 @@ class TestQuery: query = Query( rag=mock_rag, + workspace="test_workspace", user="test_user", collection="test_collection", verbose=False @@ -223,6 +228,7 @@ class TestQuery: query = Query( rag=mock_rag, + workspace="test_workspace", user="test_user", collection="test_collection", verbose=False, @@ -350,7 +356,7 @@ class TestQuery: mock_doc_embeddings_client.query.assert_called_once_with( vector=[[0.1, 0.2]], limit=20, # Default doc_limit - user="trustgraph", # Default user + user="", # Default user (empty passthrough) collection="default" # Default collection ) @@ -380,6 +386,7 @@ class TestQuery: query = Query( rag=mock_rag, + workspace="test_workspace", user="test_user", collection="test_collection", verbose=True, @@ -453,6 +460,7 @@ class TestQuery: query = Query( rag=mock_rag, + workspace="test_workspace", user="test_user", collection="test_collection", verbose=False @@ -509,6 +517,7 @@ class TestQuery: query = Query( rag=mock_rag, + workspace="test_workspace", user="test_user", collection="test_collection", verbose=True @@ -619,6 +628,7 @@ class TestQuery: query = Query( rag=mock_rag, + workspace="test_workspace", user="test_user", collection="test_collection", verbose=False, diff --git a/tests/unit/test_retrieval/test_document_rag_service.py b/tests/unit/test_retrieval/test_document_rag_service.py index a5d42f3a..1469d075 100644 --- a/tests/unit/test_retrieval/test_document_rag_service.py +++ b/tests/unit/test_retrieval/test_document_rag_service.py @@ -64,6 +64,7 @@ class TestDocumentRagService: # Verify: DocumentRag.query was called with correct parameters mock_rag_instance.query.assert_called_once_with( "test query", + workspace=ANY, # Workspace comes from flow.workspace (mock) user="my_user", # Must be from message, not hardcoded default collection="test_coll_1", # Must be from message, not hardcoded default doc_limit=5, diff --git a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py index 47c864c7..6b55c420 100644 --- a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py +++ b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py @@ -109,7 +109,7 @@ class TestKnowledgeRequestTranslatorGraphEmbeddings: assert isinstance(decoded, KnowledgeRequest) assert decoded.operation == "put-kg-core" - assert decoded.user == "alice" + assert decoded.workspace == "alice" assert decoded.id == "doc-1" assert decoded.flow == "default" assert decoded.collection == "testcoll"