From f0ad282708e876624aa76fde188bb3ca5e7bb3df Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 14 May 2026 10:30:21 +0100 Subject: [PATCH 01/16] CLI auth migration, document embeddings core lifecycle (#913) Migrate get_kg_core and put_kg_core CLI tools to use Api/SocketClient with first-frame auth (fixes broken raw websocket path). Fix wire format field names (root/vector). Remove ~600 lines of dead raw websocket code from invoke_graph_rag.py. Add document embeddings core lifecycle to the knowledge service: list/get/put/delete/load operations across schema, translator, Cassandra table store, knowledge manager, gateway registry, REST API, socket client, and CLI (tg-get-de-core, tg-put-de-core). Fix delete_kg_core to also clean up document embeddings rows. --- trustgraph-base/trustgraph/api/knowledge.py | 31 + .../trustgraph/api/socket_client.py | 52 ++ .../messaging/translators/knowledge.py | 56 +- .../trustgraph/schema/knowledge/knowledge.py | 6 +- trustgraph-cli/pyproject.toml | 2 + trustgraph-cli/trustgraph/cli/get_de_core.py | 111 ++++ trustgraph-cli/trustgraph/cli/get_kg_core.py | 77 +-- .../trustgraph/cli/invoke_graph_rag.py | 604 ------------------ trustgraph-cli/trustgraph/cli/put_de_core.py | 119 ++++ trustgraph-cli/trustgraph/cli/put_kg_core.py | 99 +-- trustgraph-flow/trustgraph/cores/knowledge.py | 325 +++++++--- trustgraph-flow/trustgraph/cores/service.py | 5 + .../trustgraph/gateway/registry.py | 6 + .../trustgraph/tables/knowledge.py | 94 +++ 14 files changed, 762 insertions(+), 825 deletions(-) create mode 100644 trustgraph-cli/trustgraph/cli/get_de_core.py create mode 100644 trustgraph-cli/trustgraph/cli/put_de_core.py diff --git a/trustgraph-base/trustgraph/api/knowledge.py b/trustgraph-base/trustgraph/api/knowledge.py index c3ec2308..06357d70 100644 --- a/trustgraph-base/trustgraph/api/knowledge.py +++ b/trustgraph-base/trustgraph/api/knowledge.py @@ -132,3 +132,34 @@ class Knowledge: self.request(request = input) + def list_de_cores(self): + + input = { + "operation": "list-de-cores", + "workspace": self.api.workspace, + } + + return self.request(request = input)["ids"] + + def delete_de_core(self, id): + + input = { + "operation": "delete-de-core", + "workspace": self.api.workspace, + "id": id, + } + + self.request(request = input) + + def load_de_core(self, id, flow="default", collection="default"): + + input = { + "operation": "load-de-core", + "workspace": self.api.workspace, + "id": id, + "flow": flow, + "collection": collection, + } + + self.request(request = input) + diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index aeb15f85..75a7be9a 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -491,6 +491,58 @@ class SocketClient: triples=raw_triples, ) + def get_kg_core(self, id: str) -> Iterator[Dict[str, Any]]: + request = { + "operation": "get-kg-core", + "workspace": self.workspace, + "id": id, + } + for response in self._send_request_sync( + "knowledge", None, request, streaming_raw=True, + ): + if response.get("eos"): + break + yield response + + def put_kg_core( + self, id: str, triples=None, graph_embeddings=None, + ) -> Dict[str, Any]: + request = { + "operation": "put-kg-core", + "workspace": self.workspace, + "id": id, + } + if triples is not None: + request["triples"] = triples + if graph_embeddings is not None: + request["graph-embeddings"] = graph_embeddings + return self._send_request_sync("knowledge", None, request) + + def get_de_core(self, id: str) -> Iterator[Dict[str, Any]]: + request = { + "operation": "get-de-core", + "workspace": self.workspace, + "id": id, + } + for response in self._send_request_sync( + "knowledge", None, request, streaming_raw=True, + ): + if response.get("eos"): + break + yield response + + def put_de_core( + self, id: str, document_embeddings=None, + ) -> Dict[str, Any]: + request = { + "operation": "put-de-core", + "workspace": self.workspace, + "id": id, + } + if document_embeddings is not None: + request["document-embeddings"] = document_embeddings + return self._send_request_sync("knowledge", None, request) + def close(self) -> None: """Close the persistent WebSocket connection.""" if self._loop and not self._loop.is_closed(): diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py index f2cc8e46..3830bf59 100644 --- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py +++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py @@ -1,6 +1,7 @@ from typing import Dict, Any, Tuple, Optional from ...schema import ( KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings, + DocumentEmbeddings, ChunkEmbeddings, Metadata, EntityEmbeddings ) from .base import MessageTranslator @@ -43,6 +44,23 @@ class KnowledgeRequestTranslator(MessageTranslator): ] ) + document_embeddings = None + if "document-embeddings" in data: + document_embeddings = DocumentEmbeddings( + metadata=Metadata( + id=data["document-embeddings"]["metadata"]["id"], + root=data["document-embeddings"]["metadata"].get("root", ""), + collection=data["document-embeddings"]["metadata"]["collection"] + ), + chunks=[ + ChunkEmbeddings( + chunk_id=ch["chunk_id"], + vector=ch["vector"], + ) + for ch in data["document-embeddings"]["chunks"] + ] + ) + return KnowledgeRequest( operation=data.get("operation"), id=data.get("id"), @@ -50,6 +68,7 @@ class KnowledgeRequestTranslator(MessageTranslator): collection=data.get("collection"), triples=triples, graph_embeddings=graph_embeddings, + document_embeddings=document_embeddings, ) def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]: @@ -90,6 +109,22 @@ class KnowledgeRequestTranslator(MessageTranslator): ], } + if obj.document_embeddings: + result["document-embeddings"] = { + "metadata": { + "id": obj.document_embeddings.metadata.id, + "root": obj.document_embeddings.metadata.root, + "collection": obj.document_embeddings.metadata.collection, + }, + "chunks": [ + { + "chunk_id": ch.chunk_id, + "vector": ch.vector, + } + for ch in obj.document_embeddings.chunks + ], + } + return result @@ -140,6 +175,25 @@ class KnowledgeResponseTranslator(MessageTranslator): } } + # Streaming document embeddings response + if obj.document_embeddings: + return { + "document-embeddings": { + "metadata": { + "id": obj.document_embeddings.metadata.id, + "root": obj.document_embeddings.metadata.root, + "collection": obj.document_embeddings.metadata.collection, + }, + "chunks": [ + { + "chunk_id": ch.chunk_id, + "vector": ch.vector, + } + for ch in obj.document_embeddings.chunks + ], + } + } + # End of stream marker if obj.eos is True: return {"eos": True} @@ -155,7 +209,7 @@ class KnowledgeResponseTranslator(MessageTranslator): is_final = ( obj.ids is not None or # List response obj.eos is True or # End of stream - (not obj.triples and not obj.graph_embeddings) # Empty response + (not obj.triples and not obj.graph_embeddings and not obj.document_embeddings) # Empty response ) return response, is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py index 64cb7082..a3879103 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py +++ b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py @@ -4,7 +4,7 @@ from ..core.topic import queue from ..core.metadata import Metadata from .document import Document, TextDocument from .graph import Triples -from .embeddings import GraphEmbeddings +from .embeddings import GraphEmbeddings, DocumentEmbeddings # get-kg-core # -> (???) @@ -41,6 +41,9 @@ class KnowledgeRequest: triples: Triples | None = None graph_embeddings: GraphEmbeddings | None = None + # put-de-core + document_embeddings: DocumentEmbeddings | None = None + @dataclass class KnowledgeResponse: error: Error | None = None @@ -48,6 +51,7 @@ class KnowledgeResponse: eos: bool = False # Indicates end of knowledge core stream triples: Triples | None = None graph_embeddings: GraphEmbeddings | None = None + document_embeddings: DocumentEmbeddings | None = None knowledge_request_queue = queue('knowledge', cls='request') knowledge_response_queue = queue('knowledge', cls='response') diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index e8062fba..10dca2e8 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -37,6 +37,7 @@ tg-dump-msgpack = "trustgraph.cli.dump_msgpack:main" tg-dump-queues = "trustgraph.cli.dump_queues:main" tg-monitor-prompts = "trustgraph.cli.monitor_prompts:main" tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main" +tg-get-de-core = "trustgraph.cli.get_de_core:main" tg-get-kg-core = "trustgraph.cli.get_kg_core:main" tg-get-document-content = "trustgraph.cli.get_document_content:main" tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main" @@ -77,6 +78,7 @@ tg-load-turtle = "trustgraph.cli.load_turtle:main" tg-load-knowledge = "trustgraph.cli.load_knowledge:main" tg-load-structured-data = "trustgraph.cli.load_structured_data:main" tg-put-flow-blueprint = "trustgraph.cli.put_flow_blueprint:main" +tg-put-de-core = "trustgraph.cli.put_de_core:main" tg-put-kg-core = "trustgraph.cli.put_kg_core:main" tg-remove-library-document = "trustgraph.cli.remove_library_document:main" tg-save-doc-embeds = "trustgraph.cli.save_doc_embeds:main" diff --git a/trustgraph-cli/trustgraph/cli/get_de_core.py b/trustgraph-cli/trustgraph/cli/get_de_core.py new file mode 100644 index 00000000..caf74ba9 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/get_de_core.py @@ -0,0 +1,111 @@ +""" +Uses the knowledge service to fetch a document embeddings core which is +saved to a local file in msgpack format. +""" + +import argparse +import os +import msgpack + +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +def write_de(f, data): + msg = ( + "de", + { + "m": { + "i": data["metadata"]["id"], + "m": data["metadata"]["root"], + "c": data["metadata"]["collection"], + }, + "c": [ + { + "i": ch["chunk_id"], + "v": ch["vector"], + } + for ch in data["chunks"] + ] + } + ) + f.write(msgpack.packb(msg, use_bin_type=True)) + +def fetch(url, workspace, id, output, token=None): + + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + + try: + de = 0 + + with open(output, "wb") as f: + + for response in socket.get_de_core(id): + + if "document-embeddings" in response: + de += 1 + write_de(f, response["document-embeddings"]) + + print(f"Got: {de} document embeddings messages.") + + finally: + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-get-de-core', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + + parser.add_argument( + '--id', '--identifier', + required=True, + help=f'Document embeddings core ID', + ) + + parser.add_argument( + '-o', '--output', + required=True, + help=f'Output file' + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + args = parser.parse_args() + + try: + + fetch( + url=args.url, + workspace=args.workspace, + id=args.id, + output=args.output, + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/get_kg_core.py b/trustgraph-cli/trustgraph/cli/get_kg_core.py index 8bee4115..b4f37b81 100644 --- a/trustgraph-cli/trustgraph/cli/get_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/get_kg_core.py @@ -5,13 +5,11 @@ to a local file in msgpack format. import argparse import os -import uuid -import asyncio -import json -from websockets.asyncio.client import connect import msgpack -default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") @@ -21,7 +19,7 @@ def write_triple(f, data): { "m": { "i": data["metadata"]["id"], - "m": data["metadata"]["metadata"], + "m": data["metadata"]["root"], "c": data["metadata"]["collection"], }, "t": data["triples"], @@ -35,13 +33,13 @@ def write_ge(f, data): { "m": { "i": data["metadata"]["id"], - "m": data["metadata"]["metadata"], + "m": data["metadata"]["root"], "c": data["metadata"]["collection"], }, "e": [ { "e": ent["entity"], - "v": ent["vectors"], + "v": ent["vector"], } for ent in data["entities"] ] @@ -49,54 +47,18 @@ def write_ge(f, data): ) f.write(msgpack.packb(msg, use_bin_type=True)) -async def fetch(url, workspace, id, output, token=None): +def fetch(url, workspace, id, output, token=None): - if not url.endswith("/"): - url += "/" - - url = url + "api/v1/socket" - - if token: - url = f"{url}?token={token}" - - mid = str(uuid.uuid4()) - - async with connect(url) as ws: - - req = json.dumps({ - "id": mid, - "workspace": workspace, - "service": "knowledge", - "request": { - "operation": "get-kg-core", - "workspace": workspace, - "id": id, - } - }) - - await ws.send(req) + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + try: ge = 0 t = 0 with open(output, "wb") as f: - while True: - - msg = await ws.recv() - - obj = json.loads(msg) - - if "response" not in obj: - raise RuntimeError("No response?") - - response = obj["response"] - - if "error" in response: - raise RuntimeError(obj["error"]) - - if "eos" in response: - if response["eos"]: break + for response in socket.get_kg_core(id): if "triples" in response: t += 1 @@ -108,7 +70,8 @@ async def fetch(url, workspace, id, output, token=None): print(f"Got: {t} triple, {ge} GE messages.") - await ws.close() + finally: + socket.close() def main(): @@ -151,14 +114,12 @@ def main(): try: - asyncio.run( - fetch( - url=args.url, - workspace=args.workspace, - id=args.id, - output=args.output, - token=args.token, - ) + fetch( + url=args.url, + workspace=args.workspace, + id=args.id, + output=args.output, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index 23d6bcac..f39cdab0 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -3,11 +3,8 @@ Uses the GraphRAG service to answer a question """ import argparse -import json import os import sys -import websockets -import asyncio from trustgraph.api import ( Api, ExplainabilityClient, @@ -31,607 +28,6 @@ default_max_path_length = 2 default_edge_score_limit = 30 default_edge_limit = 25 -# Provenance predicates -TG = "https://trustgraph.ai/ns/" -TG_QUERY = TG + "query" -TG_CONCEPT = TG + "concept" -TG_ENTITY = TG + "entity" -TG_EDGE_COUNT = TG + "edgeCount" -TG_SELECTED_EDGE = TG + "selectedEdge" -TG_EDGE = TG + "edge" -TG_REASONING = TG + "reasoning" -TG_DOCUMENT = TG + "document" -TG_CONTAINS = TG + "contains" -PROV = "http://www.w3.org/ns/prov#" -PROV_STARTED_AT_TIME = PROV + "startedAtTime" -PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" -RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" - - -def _get_event_type(prov_id): - """Extract event type from provenance_id""" - if "question" in prov_id: - return "question" - elif "grounding" in prov_id: - return "grounding" - elif "exploration" in prov_id: - return "exploration" - elif "focus" in prov_id: - return "focus" - elif "synthesis" in prov_id: - return "synthesis" - return "provenance" - - -def _format_provenance_details(event_type, triples): - """Format provenance details based on event type and triples""" - lines = [] - - if event_type == "question": - # Show query and timestamp - for s, p, o in triples: - if p == TG_QUERY: - lines.append(f" Query: {o}") - elif p == PROV_STARTED_AT_TIME: - lines.append(f" Time: {o}") - - elif event_type == "grounding": - # Show extracted concepts - concepts = [o for s, p, o in triples if p == TG_CONCEPT] - if concepts: - lines.append(f" Concepts: {len(concepts)}") - for concept in concepts: - lines.append(f" - {concept}") - - elif event_type == "exploration": - # Show edge count (seed entities resolved separately with labels) - for s, p, o in triples: - if p == TG_EDGE_COUNT: - lines.append(f" Edges explored: {o}") - - elif event_type == "focus": - # For focus, just count edge selection URIs - # The actual edge details are fetched separately via edge_selections parameter - edge_sel_uris = [] - for s, p, o in triples: - if p == TG_SELECTED_EDGE: - edge_sel_uris.append(o) - if edge_sel_uris: - lines.append(f" Focused on {len(edge_sel_uris)} edge(s)") - - elif event_type == "synthesis": - # Show document reference (content already streamed) - for s, p, o in triples: - if p == TG_DOCUMENT: - lines.append(f" Document: {o}") - - return lines - - -async def _query_triples_once(ws_url, flow_id, prov_id, collection, graph=None, debug=False): - """Query triples for a provenance node (single attempt)""" - request = { - "id": "triples-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": prov_id}, - "collection": collection, - "limit": 100 - } - } - # Add graph filter if specified (for named graph queries) - if graph is not None: - request["request"]["g"] = graph - - if debug: - print(f" [debug] querying triples for s={prov_id}", file=sys.stderr) - - triples = [] - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if debug: - print(f" [debug] response: {json.dumps(response)[:200]}", file=sys.stderr) - - if response.get("id") != "triples-request": - continue - - if "error" in response: - if debug: - print(f" [debug] error: {response['error']}", file=sys.stderr) - break - - if "response" in response: - resp = response["response"] - # Handle triples response - # Response format: {"response": [triples...]} - # Each triple uses compact keys: "i" for iri, "v" for value, "t" for type - triple_list = resp.get("response", []) - for t in triple_list: - s = t.get("s", {}).get("i", t.get("s", {}).get("v", "")) - p = t.get("p", {}).get("i", t.get("p", {}).get("v", "")) - # Handle quoted triples (type "t") and regular values - o_term = t.get("o", {}) - if o_term.get("t") == "t": - # Quoted triple - extract s, p, o from nested structure - tr = o_term.get("tr", {}) - o = { - "s": tr.get("s", {}).get("i", ""), - "p": tr.get("p", {}).get("i", ""), - "o": tr.get("o", {}).get("i", tr.get("o", {}).get("v", "")), - } - else: - o = o_term.get("i", o_term.get("v", "")) - triples.append((s, p, o)) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception: {e}", file=sys.stderr) - - if debug: - print(f" [debug] got {len(triples)} triples", file=sys.stderr) - - return triples - - -async def _query_triples(ws_url, flow_id, prov_id, collection, graph=None, max_retries=5, retry_delay=0.2, debug=False): - """Query triples for a provenance node with retries for race condition""" - for attempt in range(max_retries): - triples = await _query_triples_once(ws_url, flow_id, prov_id, collection, graph=graph, debug=debug) - if triples: - return triples - # Wait before retry if empty (triples may not be stored yet) - if attempt < max_retries - 1: - if debug: - print(f" [debug] retry {attempt + 1}/{max_retries}...", file=sys.stderr) - await asyncio.sleep(retry_delay) - return [] - - -async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, collection, debug=False): - """ - Query for provenance of an edge (s, p, o) in the knowledge graph. - - Finds subgraphs that contain the edge via tg:contains, then follows - prov:wasDerivedFrom to find source documents. - - Returns list of source URIs (chunks, pages, documents). - """ - # Query for subgraphs that contain this edge: ?subgraph tg:contains <> - request = { - "id": "edge-prov-request", - "service": "triples", - "flow": flow_id, - "request": { - "p": {"t": "i", "i": TG_CONTAINS}, - "o": { - "t": "t", # Quoted triple type - "tr": { - "s": {"t": "i", "i": edge_s}, - "p": {"t": "i", "i": edge_p}, - "o": {"t": "i", "i": edge_o} if edge_o.startswith("http") or edge_o.startswith("urn:") else {"t": "l", "v": edge_o}, - } - }, - "collection": collection, - "limit": 10 - } - } - - if debug: - print(f" [debug] querying edge provenance for ({edge_s}, {edge_p}, {edge_o})", file=sys.stderr) - - stmt_uris = [] - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "edge-prov-request": - continue - - if "error" in response: - if debug: - print(f" [debug] error: {response['error']}", file=sys.stderr) - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - for t in triple_list: - s = t.get("s", {}).get("i", "") - if s: - stmt_uris.append(s) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying edge provenance: {e}", file=sys.stderr) - - if debug: - print(f" [debug] found {len(stmt_uris)} reifying statements", file=sys.stderr) - - # For each statement, query wasDerivedFrom to find sources - sources = [] - for stmt_uri in stmt_uris: - # Query: stmt_uri prov:wasDerivedFrom ?source - request = { - "id": "derived-from-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": stmt_uri}, - "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, - "collection": collection, - "limit": 10 - } - } - - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "derived-from-request": - continue - - if "error" in response: - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - for t in triple_list: - o = t.get("o", {}).get("i", "") - if o: - sources.append(o) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying wasDerivedFrom: {e}", file=sys.stderr) - - if debug: - print(f" [debug] found {len(sources)} source(s): {sources}", file=sys.stderr) - - return sources - - -async def _query_derived_from(ws_url, flow_id, uri, collection, debug=False): - """Query for the prov:wasDerivedFrom parent of a URI. Returns None if no parent.""" - request = { - "id": "parent-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": uri}, - "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, - "collection": collection, - "limit": 1 - } - } - - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "parent-request": - continue - - if "error" in response: - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - if triple_list: - return triple_list[0].get("o", {}).get("i", None) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying parent: {e}", file=sys.stderr) - - return None - - -async def _trace_provenance_chain(ws_url, flow_id, source_uri, collection, label_cache, debug=False): - """ - Trace the full provenance chain from a source URI up to the root document. - Returns a list of (uri, label) tuples from leaf to root. - """ - chain = [] - current = source_uri - max_depth = 10 # Prevent infinite loops - - for _ in range(max_depth): - if not current: - break - - # Get label for current entity - label = await _query_label(ws_url, flow_id, current, collection, label_cache, debug) - chain.append((current, label)) - - # Get parent - parent = await _query_derived_from(ws_url, flow_id, current, collection, debug) - if not parent or parent == current: - break - current = parent - - return chain - - -def _format_provenance_chain(chain): - """ - Format a provenance chain as a human-readable string. - Chain is [(uri, label), ...] from leaf to root. - """ - if not chain: - return "" - - # Show labels, from leaf to root - labels = [label for uri, label in chain] - return " → ".join(labels) - - -def _is_iri(value): - """Check if a value looks like an IRI.""" - if not isinstance(value, str): - return False - return value.startswith("http://") or value.startswith("https://") or value.startswith("urn:") - - -async def _query_label(ws_url, flow_id, iri, collection, label_cache, debug=False): - """ - Query for the rdfs:label of an IRI. - Uses label_cache to avoid repeated queries. - Returns the label if found, otherwise returns the IRI. - """ - if not _is_iri(iri): - return iri - - # Check cache first - if iri in label_cache: - return label_cache[iri] - - request = { - "id": "label-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": iri}, - "p": {"t": "i", "i": RDFS_LABEL}, - "collection": collection, - "limit": 1 - } - } - - label = iri # Default to IRI if no label found - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "label-request": - continue - - if "error" in response: - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - if triple_list: - # Get the label value - o = triple_list[0].get("o", {}) - label = o.get("v", o.get("i", iri)) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying label for {iri}: {e}", file=sys.stderr) - - # Cache the result - label_cache[iri] = label - return label - - -async def _resolve_edge_labels(ws_url, flow_id, edge_triple, collection, label_cache, debug=False): - """ - Resolve labels for all IRI components of an edge triple. - Returns (s_label, p_label, o_label). - """ - s = edge_triple.get("s", "?") - p = edge_triple.get("p", "?") - o = edge_triple.get("o", "?") - - s_label = await _query_label(ws_url, flow_id, s, collection, label_cache, debug) - p_label = await _query_label(ws_url, flow_id, p, collection, label_cache, debug) - o_label = await _query_label(ws_url, flow_id, o, collection, label_cache, debug) - - return s_label, p_label, o_label - - -async def _question_explainable( - url, flow_id, question, collection, entity_limit, triple_limit, - max_subgraph_size, max_path_length, token=None, debug=False -): - """Execute graph RAG with explainability - shows provenance events with details""" - # Convert HTTP URL to WebSocket URL - if url.startswith("http://"): - ws_url = url.replace("http://", "ws://", 1) - elif url.startswith("https://"): - ws_url = url.replace("https://", "wss://", 1) - else: - ws_url = f"ws://{url}" - - ws_url = f"{ws_url.rstrip('/')}/api/v1/socket" - if token: - ws_url = f"{ws_url}?token={token}" - - # Cache for label lookups to avoid repeated queries - label_cache = {} - - request = { - "id": "cli-request", - "service": "graph-rag", - "flow": flow_id, - "request": { - "query": question, - "collection": collection, - "entity-limit": entity_limit, - "triple-limit": triple_limit, - "max-subgraph-size": max_subgraph_size, - "max-path-length": max_path_length, - "streaming": True - } - } - - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=300) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "cli-request": - continue - - if "error" in response: - print(f"\nError: {response['error']}", file=sys.stderr) - break - - if "response" in response: - resp = response["response"] - - # Check for errors in response - if "error" in resp and resp["error"]: - err = resp["error"] - print(f"\nError: {err.get('message', 'Unknown error')}", file=sys.stderr) - break - - message_type = resp.get("message_type", "") - - if debug: - print(f" [debug] message_type={message_type}, keys={list(resp.keys())}", file=sys.stderr) - - if message_type == "explain": - # Display explain event with details - explain_id = resp.get("explain_id", "") - explain_graph = resp.get("explain_graph") # Named graph (e.g., urn:graph:retrieval) - if explain_id: - event_type = _get_event_type(explain_id) - print(f"\n [{event_type}] {explain_id}", file=sys.stderr) - - # Query triples for this explain node (using named graph filter) - triples = await _query_triples( - ws_url, flow_id, explain_id, collection, graph=explain_graph, debug=debug - ) - - # Format and display details - details = _format_provenance_details(event_type, triples) - for line in details: - print(line, file=sys.stderr) - - # For exploration events, resolve entity labels - if event_type == "exploration": - entity_iris = [o for s, p, o in triples if p == TG_ENTITY] - if entity_iris: - print(f" Seed entities: {len(entity_iris)}", file=sys.stderr) - for iri in entity_iris: - label = await _query_label( - ws_url, flow_id, iri, collection, - label_cache, debug=debug - ) - print(f" - {label}", file=sys.stderr) - - # For focus events, query each edge selection for details - if event_type == "focus": - for s, p, o in triples: - if debug: - print(f" [debug] triple: p={p}, o={o}, o_type={type(o).__name__}", file=sys.stderr) - if p == TG_SELECTED_EDGE and isinstance(o, str): - if debug: - print(f" [debug] querying edge selection: {o}", file=sys.stderr) - # Query the edge selection entity (using named graph filter) - edge_triples = await _query_triples( - ws_url, flow_id, o, collection, graph=explain_graph, debug=debug - ) - if debug: - print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr) - # Extract edge and reasoning - edge_triple = None # Store the actual triple for provenance lookup - reasoning = None - for es, ep, eo in edge_triples: - if debug: - print(f" [debug] edge triple: ep={ep}, eo={eo}", file=sys.stderr) - if ep == TG_EDGE and isinstance(eo, dict): - # eo is a quoted triple dict - edge_triple = eo - elif ep == TG_REASONING: - reasoning = eo - if edge_triple: - # Resolve labels for edge components - s_label, p_label, o_label = await _resolve_edge_labels( - ws_url, flow_id, edge_triple, collection, - label_cache, debug=debug - ) - print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) - if reasoning: - r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning - print(f" Reason: {r_short}", file=sys.stderr) - - # Trace edge provenance in the workspace collection (not explainability) - if edge_triple: - sources = await _query_edge_provenance( - ws_url, flow_id, - edge_triple.get("s", ""), - edge_triple.get("p", ""), - edge_triple.get("o", ""), - collection, # Use the query collection, not explainability - debug=debug - ) - if sources: - for src in sources: - # Trace full chain from source to root document - chain = await _trace_provenance_chain( - ws_url, flow_id, src, collection, - label_cache, debug=debug - ) - chain_str = _format_provenance_chain(chain) - print(f" Source: {chain_str}", file=sys.stderr) - - elif message_type == "chunk" or not message_type: - # Display response chunk - chunk = resp.get("response", "") - if chunk: - print(chunk, end="", flush=True) - - # Check if session is complete - if resp.get("end_of_session"): - break - - print() # Final newline - - def _question_explainable_api( url, flow_id, question_text, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, edge_score_limit=30, diff --git a/trustgraph-cli/trustgraph/cli/put_de_core.py b/trustgraph-cli/trustgraph/cli/put_de_core.py new file mode 100644 index 00000000..1d6589af --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/put_de_core.py @@ -0,0 +1,119 @@ +""" +Puts a document embeddings core into the knowledge manager via the API +socket. +""" + +import argparse +import os +import msgpack + +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +def read_message(unpacked, id): + + if unpacked[0] == "de": + msg = unpacked[1] + return { + "metadata": { + "id": id, + "root": msg["m"]["m"], + "collection": "default", + }, + "chunks": [ + { + "chunk_id": ch["i"], + "vector": ch["v"], + } + for ch in msg["c"] + ], + } + else: + raise RuntimeError("Unexpected message type", unpacked[0]) + +def put(url, workspace, id, input, token=None): + + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + + try: + de = 0 + + with open(input, "rb") as f: + + unpacker = msgpack.Unpacker(f, raw=False) + + while True: + + try: + unpacked = unpacker.unpack() + except msgpack.OutOfData: + break + + msg = read_message(unpacked, id) + de += 1 + socket.put_de_core(id, document_embeddings=msg) + + print(f"Put: {de} document embeddings messages.") + + finally: + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-put-de-core', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + + parser.add_argument( + '--id', '--identifier', + required=True, + help=f'Document embeddings core ID', + ) + + parser.add_argument( + '-i', '--input', + required=True, + help=f'Input file' + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + args = parser.parse_args() + + try: + + put( + url=args.url, + workspace=args.workspace, + id=args.id, + input=args.input, + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/put_kg_core.py b/trustgraph-cli/trustgraph/cli/put_kg_core.py index bd3169c8..fe0981a5 100644 --- a/trustgraph-cli/trustgraph/cli/put_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/put_kg_core.py @@ -4,13 +4,11 @@ Puts a knowledge core into the knowledge manager via the API socket. import argparse import os -import uuid -import asyncio -import json -from websockets.asyncio.client import connect import msgpack -default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") @@ -21,13 +19,13 @@ def read_message(unpacked, id): return "ge", { "metadata": { "id": id, - "metadata": msg["m"]["m"], - "collection": "default", # Not used? + "root": msg["m"]["m"], + "collection": "default", }, "entities": [ { "entity": ent["e"], - "vectors": ent["v"], + "vector": ent["v"], } for ent in msg["e"] ], @@ -37,26 +35,20 @@ def read_message(unpacked, id): return "t", { "metadata": { "id": id, - "metadata": msg["m"]["m"], - "collection": "default", # Not used by receiver? + "root": msg["m"]["m"], + "collection": "default", }, "triples": msg["t"], } else: raise RuntimeError("Unpacked unexpected messsage type", unpacked[0]) -async def put(url, workspace, id, input, token=None): +def put(url, workspace, id, input, token=None): - if not url.endswith("/"): - url += "/" - - url = url + "api/v1/socket" - - if token: - url = f"{url}?token={token}" - - async with connect(url) as ws: + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + try: ge = 0 t = 0 @@ -68,69 +60,26 @@ async def put(url, workspace, id, input, token=None): try: unpacked = unpacker.unpack() - except: + except msgpack.OutOfData: break kind, msg = read_message(unpacked, id) - mid = str(uuid.uuid4()) - if kind == "ge": - ge += 1 - - req = json.dumps({ - "id": mid, - "workspace": workspace, - "service": "knowledge", - "request": { - "operation": "put-kg-core", - "workspace": workspace, - "id": id, - "graph-embeddings": msg - } - }) + socket.put_kg_core(id, graph_embeddings=msg) elif kind == "t": - t += 1 - - req = json.dumps({ - "id": mid, - "workspace": workspace, - "service": "knowledge", - "request": { - "operation": "put-kg-core", - "workspace": workspace, - "id": id, - "triples": msg - } - }) + socket.put_kg_core(id, triples=msg) else: - raise RuntimeError("Unexpected message kind", kind) - await ws.send(req) - - # Retry loop, wait for right response to come back - while True: - - msg = await ws.recv() - msg = json.loads(msg) - - if msg["id"] != mid: - continue - - if "response" in msg: - if "error" in msg["response"]: - raise RuntimeError(msg["response"]["error"]) - - break - print(f"Put: {t} triple, {ge} GE messages.") - await ws.close() + finally: + socket.close() def main(): @@ -173,14 +122,12 @@ def main(): try: - asyncio.run( - put( - url=args.url, - workspace=args.workspace, - id=args.id, - input=args.input, - token=args.token, - ) + put( + url=args.url, + workspace=args.workspace, + id=args.id, + input=args.input, + token=args.token, ) except Exception as e: diff --git a/trustgraph-flow/trustgraph/cores/knowledge.py b/trustgraph-flow/trustgraph/cores/knowledge.py index 09c6137d..f1fa53f5 100644 --- a/trustgraph-flow/trustgraph/cores/knowledge.py +++ b/trustgraph-flow/trustgraph/cores/knowledge.py @@ -1,5 +1,6 @@ from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings +from .. schema import DocumentEmbeddings from .. knowledge import hash from .. exceptions import RequestError from .. tables.knowledge import KnowledgeTableStore @@ -157,6 +158,98 @@ class KnowledgeManager: ) ) + async def list_de_cores(self, request, respond, workspace): + + ids = await self.table_store.list_de_cores(workspace) + + await respond( + KnowledgeResponse( + error = None, + ids = ids, + eos = False, + triples = None, + graph_embeddings = None, + ) + ) + + async def get_de_core(self, request, respond, workspace): + + logger.info("Getting document embeddings core...") + + async def publish_de(de): + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None, + document_embeddings = de, + ) + ) + + await self.table_store.get_document_embeddings( + workspace, + request.id, + publish_de, + ) + + logger.debug("Document embeddings core retrieval complete") + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = True, + triples = None, + graph_embeddings = None, + ) + ) + + async def put_de_core(self, request, respond, workspace): + + if request.document_embeddings: + await self.table_store.add_document_embeddings( + workspace, request.document_embeddings + ) + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None, + ) + ) + + async def delete_de_core(self, request, respond, workspace): + + logger.info("Deleting document embeddings core...") + + await self.table_store.delete_document_embeddings( + workspace, request.id + ) + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None, + ) + ) + + async def load_de_core(self, request, respond, workspace): + + if self.background_task is None: + self.background_task = asyncio.create_task( + self.core_loader() + ) + + await self.loader_queue.put((request, respond, workspace)) + async def core_loader(self): logger.info("Knowledge background processor running...") @@ -165,7 +258,7 @@ class KnowledgeManager: logger.debug("Waiting for next load...") request, respond, workspace = await self.loader_queue.get() - logger.info(f"Loading knowledge: {request.id}") + logger.info(f"Loading: {request.operation} {request.id}") try: @@ -187,25 +280,14 @@ class KnowledgeManager: if "interfaces" not in flow: raise RuntimeError("No defined interfaces") - if "triples-store" not in flow["interfaces"]: - raise RuntimeError("Flow has no triples-store") - - if "graph-embeddings-store" not in flow["interfaces"]: - raise RuntimeError("Flow has no graph-embeddings-store") - - t_q = flow["interfaces"]["triples-store"]["flow"] - ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"] - - # Got this far, it should all work - await respond( - KnowledgeResponse( - error = None, - ids = None, - eos = False, - triples = None, - graph_embeddings = None + if request.operation == "load-de-core": + await self._load_de_core( + request, respond, workspace, flow, + ) + else: + await self._load_kg_core( + request, respond, workspace, flow, ) - ) except Exception as e: @@ -223,72 +305,145 @@ class KnowledgeManager: ) ) - - logger.debug("Starting knowledge loading process...") - - try: - - t_pub = None - ge_pub = None - - logger.debug(f"Triples queue: {t_q}") - logger.debug(f"Graph embeddings queue: {ge_q}") - - t_pub = Publisher( - self.flow_config.pubsub, t_q, - schema=Triples, - ) - ge_pub = Publisher( - self.flow_config.pubsub, ge_q, - schema=GraphEmbeddings - ) - - logger.debug("Starting publishers...") - - await t_pub.start() - await ge_pub.start() - - async def publish_triples(t): - # Override collection with request collection - if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'): - t.metadata.collection = request.collection or "default" - await t_pub.send(None, t) - - logger.debug("Publishing triples...") - - await self.table_store.get_triples( - workspace, - request.id, - publish_triples, - ) - - async def publish_ge(g): - # Override collection with request collection - if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'): - g.metadata.collection = request.collection or "default" - await ge_pub.send(None, g) - - logger.debug("Publishing graph embeddings...") - - await self.table_store.get_graph_embeddings( - workspace, - request.id, - publish_ge, - ) - - logger.debug("Knowledge loading completed") - - except Exception as e: - - logger.error(f"Knowledge exception: {e}", exc_info=True) - - finally: - - logger.debug("Stopping publishers...") - - if t_pub: await t_pub.stop() - if ge_pub: await ge_pub.stop() - logger.debug("Knowledge processing done") continue + + async def _load_kg_core(self, request, respond, workspace, flow): + + if "triples-store" not in flow["interfaces"]: + raise RuntimeError("Flow has no triples-store") + + if "graph-embeddings-store" not in flow["interfaces"]: + raise RuntimeError("Flow has no graph-embeddings-store") + + t_q = flow["interfaces"]["triples-store"]["flow"] + ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"] + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None + ) + ) + + t_pub = None + ge_pub = None + + try: + + logger.debug(f"Triples queue: {t_q}") + logger.debug(f"Graph embeddings queue: {ge_q}") + + t_pub = Publisher( + self.flow_config.pubsub, t_q, + schema=Triples, + ) + ge_pub = Publisher( + self.flow_config.pubsub, ge_q, + schema=GraphEmbeddings + ) + + logger.debug("Starting publishers...") + + await t_pub.start() + await ge_pub.start() + + async def publish_triples(t): + if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'): + t.metadata.collection = request.collection or "default" + await t_pub.send(None, t) + + logger.debug("Publishing triples...") + + await self.table_store.get_triples( + workspace, + request.id, + publish_triples, + ) + + async def publish_ge(g): + if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'): + g.metadata.collection = request.collection or "default" + await ge_pub.send(None, g) + + logger.debug("Publishing graph embeddings...") + + await self.table_store.get_graph_embeddings( + workspace, + request.id, + publish_ge, + ) + + logger.debug("Knowledge core loading completed") + + except Exception as e: + + logger.error(f"Knowledge exception: {e}", exc_info=True) + + finally: + + logger.debug("Stopping publishers...") + + if t_pub: await t_pub.stop() + if ge_pub: await ge_pub.stop() + + async def _load_de_core(self, request, respond, workspace, flow): + + if "document-embeddings-store" not in flow["interfaces"]: + raise RuntimeError("Flow has no document-embeddings-store") + + de_q = flow["interfaces"]["document-embeddings-store"]["flow"] + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None + ) + ) + + de_pub = None + + try: + + logger.debug(f"Document embeddings queue: {de_q}") + + de_pub = Publisher( + self.flow_config.pubsub, de_q, + schema=DocumentEmbeddings, + ) + + logger.debug("Starting publisher...") + + await de_pub.start() + + async def publish_de(de): + if hasattr(de, 'metadata') and hasattr(de.metadata, 'collection'): + de.metadata.collection = request.collection or "default" + await de_pub.send(None, de) + + logger.debug("Publishing document embeddings...") + + await self.table_store.get_document_embeddings( + workspace, + request.id, + publish_de, + ) + + logger.debug("Document embeddings core loading completed") + + except Exception as e: + + logger.error(f"Knowledge exception: {e}", exc_info=True) + + finally: + + logger.debug("Stopping publisher...") + + if de_pub: await de_pub.stop() diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py index c84b536c..a04e42ca 100755 --- a/trustgraph-flow/trustgraph/cores/service.py +++ b/trustgraph-flow/trustgraph/cores/service.py @@ -187,6 +187,11 @@ class Processor(WorkspaceProcessor): "put-kg-core": self.knowledge.put_kg_core, "load-kg-core": self.knowledge.load_kg_core, "unload-kg-core": self.knowledge.unload_kg_core, + "list-de-cores": self.knowledge.list_de_cores, + "get-de-core": self.knowledge.get_de_core, + "delete-de-core": self.knowledge.delete_de_core, + "put-de-core": self.knowledge.put_de_core, + "load-de-core": self.knowledge.load_de_core, } if v.operation not in impls: diff --git a/trustgraph-flow/trustgraph/gateway/registry.py b/trustgraph-flow/trustgraph/gateway/registry.py index 5e3344f4..4d439097 100644 --- a/trustgraph-flow/trustgraph/gateway/registry.py +++ b/trustgraph-flow/trustgraph/gateway/registry.py @@ -457,6 +457,12 @@ for _op in ("put-kg-core", "delete-kg-core", "load-kg-core", "unload-kg-core"): _register_kind_op("knowledge", _op, "knowledge:write") +# knowledge: document-embeddings core service. +for _op in ("get-de-core", "list-de-cores"): + _register_kind_op("knowledge", _op, "knowledge:read") +for _op in ("put-de-core", "delete-de-core", "load-de-core"): + _register_kind_op("knowledge", _op, "knowledge:write") + # collection-management: workspace collection lifecycle. _register_kind_op("collection-management", "list-collections", "collections:read") diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py index 5d45358d..cf085fdd 100644 --- a/trustgraph-flow/trustgraph/tables/knowledge.py +++ b/trustgraph-flow/trustgraph/tables/knowledge.py @@ -1,6 +1,7 @@ from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings +from .. schema import DocumentEmbeddings, ChunkEmbeddings from cassandra.cluster import Cluster @@ -217,6 +218,16 @@ class KnowledgeTableStore: WHERE workspace = ? AND document_id = ? """) + self.delete_document_embeddings_stmt = self.cassandra.prepare(""" + DELETE FROM document_embeddings + WHERE workspace = ? AND document_id = ? + """) + + self.list_de_cores_stmt = self.cassandra.prepare(""" + SELECT DISTINCT workspace, document_id FROM document_embeddings + WHERE workspace = ? + """) + async def add_triples(self, workspace, m): when = int(time.time() * 1000) @@ -338,6 +349,50 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise + try: + await async_execute( + self.cassandra, + self.delete_document_embeddings_stmt, + (workspace, document_id), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + async def delete_document_embeddings(self, workspace, document_id): + + logger.debug("Delete document embeddings...") + + try: + await async_execute( + self.cassandra, + self.delete_document_embeddings_stmt, + (workspace, document_id), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + async def list_de_cores(self, workspace): + + logger.debug("List DE cores...") + + try: + rows = await async_execute( + self.cassandra, + self.list_de_cores_stmt, + (workspace,), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + lst = [row[1] for row in rows] + + logger.debug("Done") + + return lst + async def get_triples(self, workspace, document_id, receiver): logger.debug("Get triples...") @@ -417,3 +472,42 @@ class KnowledgeTableStore: logger.debug("Done") + async def get_document_embeddings(self, workspace, document_id, receiver): + + logger.debug("Get DE...") + + try: + rows = await async_execute( + self.cassandra, + self.get_document_embeddings_stmt, + (workspace, document_id), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + for row in rows: + + if row[3]: + chunks = [ + ChunkEmbeddings( + chunk_id=ch[0], + vector=ch[1], + ) + for ch in row[3] + ] + else: + chunks = [] + + await receiver( + DocumentEmbeddings( + metadata = Metadata( + id = document_id, + collection = "default", + ), + chunks = chunks + ) + ) + + logger.debug("Done") + From bb1109963c1ffc5e7dc8e091251db730546748d7 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 14 May 2026 12:03:43 +0100 Subject: [PATCH 02/16] Remove spurious workspace parameter from SPARQL algebra evaluator (#915) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix threading of workspace paramater: - The SPARQL algebra evaluator was threading a workspace parameter through every function and passing it to TriplesClient.query(), which doesn't accept it. Workspace isolation is handled by pub/sub topic routing — the TriplesClient is already scoped to a workspace-specific flow, same as GraphRAG. Passing workspace explicitly was both incorrect and unnecessary. Update tests: - tests/unit/test_query/test_sparql_algebra.py (new) — Tests _query_pattern, _eval_bgp, and evaluate() with various algebra nodes. Key tests assert workspace is never in tc.query() kwargs, plus correctness tests for BGP, JOIN, UNION, SLICE, DISTINCT, and edge cases. - tests/unit/test_retrieval/test_graph_rag.py — Added test_triples_query_never_passes_workspace (checks query()) and test_follow_edges_never_passes_workspace (checks query_stream()). --- tests/unit/test_query/test_sparql_algebra.py | 302 ++++++++++++++++++ tests/unit/test_retrieval/test_graph_rag.py | 51 +++ .../trustgraph/query/sparql/algebra.py | 84 +++-- .../trustgraph/query/sparql/service.py | 1 - 4 files changed, 394 insertions(+), 44 deletions(-) create mode 100644 tests/unit/test_query/test_sparql_algebra.py diff --git a/tests/unit/test_query/test_sparql_algebra.py b/tests/unit/test_query/test_sparql_algebra.py new file mode 100644 index 00000000..9827b2de --- /dev/null +++ b/tests/unit/test_query/test_sparql_algebra.py @@ -0,0 +1,302 @@ +""" +Tests for the SPARQL algebra evaluator. + +Verifies that evaluate() and _query_pattern() call TriplesClient.query() +with the correct arguments, and in particular that workspace is never +passed — workspace isolation is handled by pub/sub topic routing. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, call + +from rdflib.term import Variable, URIRef, Literal +from rdflib.plugins.sparql.parserutils import CompValue + +from trustgraph.schema import Term, IRI, LITERAL +from trustgraph.query.sparql.algebra import ( + evaluate, _query_pattern, _eval_bgp, +) + + +# --- Helpers --- + +def iri(v): + return Term(type=IRI, iri=v) + + +def lit(v): + return Term(type=LITERAL, value=v) + + +def make_triple(s, p, o): + t = MagicMock() + t.s = s + t.p = p + t.o = o + return t + + +def make_bgp(*patterns): + """Build a CompValue BGP node from (s, p, o) tuples of rdflib terms.""" + node = CompValue("BGP") + node.triples = list(patterns) + return node + + +def make_project(inner, variables): + node = CompValue("Project") + node.p = inner + node.PV = [Variable(v) for v in variables] + return node + + +def make_select(inner): + node = CompValue("SelectQuery") + node.p = inner + return node + + +def make_join(left, right): + node = CompValue("Join") + node.p1 = left + node.p2 = right + return node + + +def make_union(left, right): + node = CompValue("Union") + node.p1 = left + node.p2 = right + return node + + +def make_slice(inner, start, length): + node = CompValue("Slice") + node.p = inner + node.start = start + node.length = length + return node + + +def make_distinct(inner): + node = CompValue("Distinct") + node.p = inner + return node + + +class TestQueryPattern: + """Tests for _query_pattern — the leaf that calls TriplesClient.""" + + @pytest.mark.asyncio + async def test_passes_correct_args(self): + tc = AsyncMock() + tc.query.return_value = [] + + await _query_pattern( + tc, + s=iri("http://example.com/s"), + p=iri("http://example.com/p"), + o=None, + collection="my-collection", + limit=100, + ) + + tc.query.assert_called_once_with( + s=iri("http://example.com/s"), + p=iri("http://example.com/p"), + o=None, + limit=100, + collection="my-collection", + ) + + @pytest.mark.asyncio + async def test_workspace_not_passed(self): + tc = AsyncMock() + tc.query.return_value = [] + + await _query_pattern(tc, None, None, None, "default", 10) + + kwargs = tc.query.call_args.kwargs + assert "workspace" not in kwargs + + @pytest.mark.asyncio + async def test_returns_query_results(self): + tc = AsyncMock() + triple = make_triple(iri("http://a"), iri("http://b"), lit("c")) + tc.query.return_value = [triple] + + results = await _query_pattern(tc, None, None, None, "default", 10) + + assert len(results) == 1 + assert results[0].s.iri == "http://a" + + +class TestEvalBgp: + """Tests for BGP evaluation — triple pattern queries.""" + + @pytest.mark.asyncio + async def test_single_pattern_all_variables(self): + tc = AsyncMock() + triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) + tc.query.return_value = [triple] + + bgp = make_bgp( + (Variable("s"), Variable("p"), Variable("o")), + ) + + solutions = await evaluate(bgp, tc, collection="default", limit=100) + + assert len(solutions) == 1 + assert solutions[0]["s"].iri == "http://s" + assert solutions[0]["p"].iri == "http://p" + assert solutions[0]["o"].value == "o" + + @pytest.mark.asyncio + async def test_single_pattern_bound_subject(self): + tc = AsyncMock() + tc.query.return_value = [ + make_triple(iri("http://s"), iri("http://p"), lit("val")), + ] + + bgp = make_bgp( + (URIRef("http://s"), Variable("p"), Variable("o")), + ) + + solutions = await evaluate(bgp, tc, collection="default") + + tc.query.assert_called_once() + kwargs = tc.query.call_args.kwargs + assert "workspace" not in kwargs + assert kwargs["collection"] == "default" + + @pytest.mark.asyncio + async def test_empty_bgp_returns_empty_solution(self): + tc = AsyncMock() + + bgp = make_bgp() + + solutions = await evaluate(bgp, tc, collection="default") + + assert solutions == [{}] + tc.query.assert_not_called() + + @pytest.mark.asyncio + async def test_no_results_returns_empty(self): + tc = AsyncMock() + tc.query.return_value = [] + + bgp = make_bgp( + (Variable("s"), Variable("p"), Variable("o")), + ) + + solutions = await evaluate(bgp, tc, collection="default") + + assert solutions == [] + + +class TestEvaluate: + """Tests for the top-level evaluate() dispatcher.""" + + @pytest.mark.asyncio + async def test_select_query_node(self): + tc = AsyncMock() + tc.query.return_value = [ + make_triple(iri("http://s"), iri("http://p"), lit("o")), + ] + + bgp = make_bgp( + (Variable("s"), Variable("p"), Variable("o")), + ) + select = make_select(make_project(bgp, ["s", "p"])) + + solutions = await evaluate(select, tc, collection="default") + + assert len(solutions) == 1 + assert "s" in solutions[0] + assert "p" in solutions[0] + assert "o" not in solutions[0] + + @pytest.mark.asyncio + async def test_workspace_never_in_query_calls(self): + """Verify that no matter the algebra structure, workspace is never + passed to TriplesClient.query().""" + tc = AsyncMock() + tc.query.return_value = [ + make_triple(iri("http://s"), iri("http://p"), lit("o")), + ] + + bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o"))) + bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c"))) + tree = make_select(make_project( + make_union(bgp1, bgp2), ["s", "p", "o"] + )) + + await evaluate(tree, tc, collection="test-coll") + + for c in tc.query.call_args_list: + assert "workspace" not in c.kwargs + + @pytest.mark.asyncio + async def test_join(self): + tc = AsyncMock() + tc.query.side_effect = [ + [make_triple(iri("http://a"), iri("http://p"), lit("v"))], + [make_triple(iri("http://a"), iri("http://q"), lit("w"))], + ] + + bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1"))) + bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2"))) + tree = make_join(bgp1, bgp2) + + solutions = await evaluate(tree, tc, collection="default") + + assert len(solutions) == 1 + assert solutions[0]["s"].iri == "http://a" + + @pytest.mark.asyncio + async def test_slice(self): + tc = AsyncMock() + triples = [ + make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}")) + for i in range(5) + ] + tc.query.return_value = triples + + bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) + tree = make_slice(bgp, start=1, length=2) + + solutions = await evaluate(tree, tc, collection="default") + + assert len(solutions) == 2 + + @pytest.mark.asyncio + async def test_distinct(self): + tc = AsyncMock() + triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) + tc.query.return_value = [triple, triple] + + bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) + tree = make_distinct(bgp) + + solutions = await evaluate(tree, tc, collection="default") + + assert len(solutions) == 1 + + @pytest.mark.asyncio + async def test_unsupported_node_returns_empty_solution(self): + tc = AsyncMock() + + node = CompValue("SomethingUnknown") + + solutions = await evaluate(node, tc, collection="default") + + assert solutions == [{}] + tc.query.assert_not_called() + + @pytest.mark.asyncio + async def test_non_compvalue_returns_empty_solution(self): + tc = AsyncMock() + + solutions = await evaluate("not a node", tc, collection="default") + + assert solutions == [{}] diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index e0f41357..d1979211 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -337,6 +337,57 @@ class TestQuery: cache_key = "test_collection:unlabeled_entity" mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity") + @pytest.mark.asyncio + async def test_triples_query_never_passes_workspace(self): + """Workspace isolation is handled by pub/sub topic routing, not + by passing workspace to TriplesClient.query(). Verify that + GraphRAG never passes workspace as a keyword argument.""" + mock_rag = MagicMock() + mock_cache = MagicMock() + mock_cache.get.return_value = None + mock_rag.label_cache = mock_cache + mock_triples_client = AsyncMock() + mock_rag.triples_client = mock_triples_client + + mock_triple = MagicMock() + mock_triple.o = "Label" + mock_triples_client.query.return_value = [mock_triple] + + query = Query( + rag=mock_rag, + collection="test_collection", + verbose=False + ) + + await query.maybe_label("http://example.com/entity") + + for c in mock_triples_client.query.call_args_list: + assert "workspace" not in c.kwargs + + @pytest.mark.asyncio + async def test_follow_edges_never_passes_workspace(self): + """Verify follow_edges never passes workspace to query_stream.""" + mock_rag = MagicMock() + mock_triples_client = AsyncMock() + mock_rag.triples_client = mock_triples_client + + mock_triple = MagicMock() + mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1" + mock_triples_client.query_stream.return_value = [mock_triple] + + query = Query( + rag=mock_rag, + collection="test_collection", + verbose=False, + triple_limit=10 + ) + + subgraph = set() + await query.follow_edges("e1", subgraph, path_length=1) + + for c in mock_triples_client.query_stream.call_args_list: + assert "workspace" not in c.kwargs + @pytest.mark.asyncio async def test_follow_edges_basic_functionality(self): """Test Query.follow_edges method basic triple discovery""" diff --git a/trustgraph-flow/trustgraph/query/sparql/algebra.py b/trustgraph-flow/trustgraph/query/sparql/algebra.py index bff9a336..76b1ad8e 100644 --- a/trustgraph-flow/trustgraph/query/sparql/algebra.py +++ b/trustgraph-flow/trustgraph/query/sparql/algebra.py @@ -30,14 +30,13 @@ class EvaluationError(Exception): pass -async def evaluate(node, triples_client, workspace, collection, limit=10000): +async def evaluate(node, triples_client, collection, limit=10000): """ Evaluate a SPARQL algebra node. Args: node: rdflib CompValue algebra node triples_client: TriplesClient instance for triple pattern queries - workspace: workspace/keyspace identifier collection: collection identifier limit: safety limit on results @@ -55,24 +54,24 @@ async def evaluate(node, triples_client, workspace, collection, limit=10000): logger.warning(f"Unsupported algebra node: {name}") return [{}] - return await handler(node, triples_client, workspace, collection, limit) + return await handler(node, triples_client, collection, limit) # --- Node handlers --- -async def _eval_select_query(node, tc, workspace, collection, limit): +async def _eval_select_query(node, tc, collection, limit): """Evaluate a SelectQuery node.""" - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) -async def _eval_project(node, tc, workspace, collection, limit): +async def _eval_project(node, tc, collection, limit): """Evaluate a Project node (SELECT variable projection).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) variables = [str(v) for v in node.PV] return project(solutions, variables) -async def _eval_bgp(node, tc, workspace, collection, limit): +async def _eval_bgp(node, tc, collection, limit): """ Evaluate a Basic Graph Pattern. @@ -107,7 +106,7 @@ async def _eval_bgp(node, tc, workspace, collection, limit): # Query the triples store results = await _query_pattern( - tc, s_val, p_val, o_val, workspace, collection, limit + tc, s_val, p_val, o_val, collection, limit ) # Map results back to variable bindings, @@ -130,17 +129,17 @@ async def _eval_bgp(node, tc, workspace, collection, limit): return solutions[:limit] -async def _eval_join(node, tc, workspace, collection, limit): +async def _eval_join(node, tc, collection, limit): """Evaluate a Join node.""" - left = await evaluate(node.p1, tc, workspace, collection, limit) - right = await evaluate(node.p2, tc, workspace, collection, limit) + left = await evaluate(node.p1, tc, collection, limit) + right = await evaluate(node.p2, tc, collection, limit) return hash_join(left, right)[:limit] -async def _eval_left_join(node, tc, workspace, collection, limit): +async def _eval_left_join(node, tc, collection, limit): """Evaluate a LeftJoin node (OPTIONAL).""" - left_sols = await evaluate(node.p1, tc, workspace, collection, limit) - right_sols = await evaluate(node.p2, tc, workspace, collection, limit) + left_sols = await evaluate(node.p1, tc, collection, limit) + right_sols = await evaluate(node.p2, tc, collection, limit) filter_fn = None if hasattr(node, "expr") and node.expr is not None: @@ -153,16 +152,16 @@ async def _eval_left_join(node, tc, workspace, collection, limit): return left_join(left_sols, right_sols, filter_fn)[:limit] -async def _eval_union(node, tc, workspace, collection, limit): +async def _eval_union(node, tc, collection, limit): """Evaluate a Union node.""" - left = await evaluate(node.p1, tc, workspace, collection, limit) - right = await evaluate(node.p2, tc, workspace, collection, limit) + left = await evaluate(node.p1, tc, collection, limit) + right = await evaluate(node.p2, tc, collection, limit) return union(left, right)[:limit] -async def _eval_filter(node, tc, workspace, collection, limit): +async def _eval_filter(node, tc, collection, limit): """Evaluate a Filter node.""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) expr = node.expr return [ sol for sol in solutions @@ -170,22 +169,22 @@ async def _eval_filter(node, tc, workspace, collection, limit): ] -async def _eval_distinct(node, tc, workspace, collection, limit): +async def _eval_distinct(node, tc, collection, limit): """Evaluate a Distinct node.""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) return distinct(solutions) -async def _eval_reduced(node, tc, workspace, collection, limit): +async def _eval_reduced(node, tc, collection, limit): """Evaluate a Reduced node (like Distinct but implementation-defined).""" # Treat same as Distinct - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) return distinct(solutions) -async def _eval_order_by(node, tc, workspace, collection, limit): +async def _eval_order_by(node, tc, collection, limit): """Evaluate an OrderBy node.""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) key_fns = [] for cond in node.expr: @@ -206,7 +205,7 @@ async def _eval_order_by(node, tc, workspace, collection, limit): return order_by(solutions, key_fns) -async def _eval_slice(node, tc, workspace, collection, limit): +async def _eval_slice(node, tc, collection, limit): """Evaluate a Slice node (LIMIT/OFFSET).""" # Pass tighter limit downstream if possible inner_limit = limit @@ -214,13 +213,13 @@ async def _eval_slice(node, tc, workspace, collection, limit): offset = node.start or 0 inner_limit = min(limit, offset + node.length) - solutions = await evaluate(node.p, tc, workspace, collection, inner_limit) + solutions = await evaluate(node.p, tc, collection, inner_limit) return slice_solutions(solutions, node.start or 0, node.length) -async def _eval_extend(node, tc, workspace, collection, limit): +async def _eval_extend(node, tc, collection, limit): """Evaluate an Extend node (BIND).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) var_name = str(node.var) expr = node.expr @@ -246,9 +245,9 @@ async def _eval_extend(node, tc, workspace, collection, limit): return result -async def _eval_group(node, tc, workspace, collection, limit): +async def _eval_group(node, tc, collection, limit): """Evaluate a Group node (GROUP BY with aggregation).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) # Extract grouping expressions group_exprs = [] @@ -289,9 +288,9 @@ async def _eval_group(node, tc, workspace, collection, limit): return result -async def _eval_aggregate_join(node, tc, workspace, collection, limit): +async def _eval_aggregate_join(node, tc, collection, limit): """Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) result = [] for sol in solutions: @@ -310,7 +309,7 @@ async def _eval_aggregate_join(node, tc, workspace, collection, limit): return result -async def _eval_graph(node, tc, workspace, collection, limit): +async def _eval_graph(node, tc, collection, limit): """Evaluate a Graph node (GRAPH clause).""" term = node.term @@ -319,16 +318,16 @@ async def _eval_graph(node, tc, workspace, collection, limit): # We'd need to pass graph to triples queries # For now, evaluate inner pattern normally logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) elif isinstance(term, Variable): # GRAPH ?g { ... } — variable graph logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) else: - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) -async def _eval_values(node, tc, workspace, collection, limit): +async def _eval_values(node, tc, collection, limit): """Evaluate a VALUES clause (inline data).""" variables = [str(v) for v in node.var] solutions = [] @@ -343,9 +342,9 @@ async def _eval_values(node, tc, workspace, collection, limit): return solutions -async def _eval_to_multiset(node, tc, workspace, collection, limit): +async def _eval_to_multiset(node, tc, collection, limit): """Evaluate a ToMultiSet node (subquery).""" - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) # --- Aggregate computation --- @@ -487,7 +486,7 @@ def _resolve_term(tmpl, solution): return rdflib_term_to_term(tmpl) -async def _query_pattern(tc, s, p, o, workspace, collection, limit): +async def _query_pattern(tc, s, p, o, collection, limit): """ Issue a streaming triple pattern query via TriplesClient. @@ -496,7 +495,6 @@ async def _query_pattern(tc, s, p, o, workspace, collection, limit): results = await tc.query( s=s, p=p, o=o, limit=limit, - workspace=workspace, collection=collection, ) return results diff --git a/trustgraph-flow/trustgraph/query/sparql/service.py b/trustgraph-flow/trustgraph/query/sparql/service.py index 983cd4f6..75c00dba 100644 --- a/trustgraph-flow/trustgraph/query/sparql/service.py +++ b/trustgraph-flow/trustgraph/query/sparql/service.py @@ -141,7 +141,6 @@ class Processor(FlowProcessor): solutions = await evaluate( parsed.algebra, triples_client, - workspace=flow.workspace, collection=request.collection or "default", limit=request.limit or 10000, ) From a2dde9cafbdea5e6c0f48c5b6ef52c0b6b30c2b1 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 14 May 2026 16:00:54 +0100 Subject: [PATCH 03/16] Make all Cassandra and Qdrant I/O async-safe with proper concurrency controls (#916) Cassandra triples services were using syncronous EntityCentricKnowledgeGraph methods from async contexts, and connection state was managed with threading.local which is wrong for asyncio coroutines sharing a single thread. Qdrant services had no async wrapping at all, blocking the event loop on every network call. Rows services had unprotected shared state mutations across concurrent coroutines. - Add async methods to EntityCentricKnowledgeGraph (async_insert, async_get_s/p/o/sp/po/os/spo/all, async_collection_exists, async_create_collection, async_delete_collection) using the existing cassandra_async.async_execute bridge - Rewrite triples write + query services: replace threading.local with asyncio.Lock + dict cache for per-workspace connections, use async ECKG methods for all data operations, keep asyncio.to_thread only for one-time blocking ECKG construction - Wrap all Qdrant calls in asyncio.to_thread across all 6 services (doc/graph/row embeddings write + query), add asyncio.Lock + set cache for collection existence checks - Add asyncio.Lock to rows write + query services to protect shared state (schemas, sessions, config caches) from concurrent mutation - Update all affected tests to match new async patterns --- .../test_cassandra_config_end_to_end.py | 79 +++--- .../test_rows_cassandra_integration.py | 3 + .../test_rows_graphql_query_integration.py | 12 +- .../test_query/test_rows_cassandra_query.py | 7 +- .../test_triples_cassandra_query.py | 112 ++++----- .../test_null_embedding_protection.py | 12 + .../test_doc_embeddings_qdrant_storage.py | 4 +- .../test_row_embeddings_qdrant_storage.py | 14 +- .../test_rows_cassandra_storage.py | 3 + .../test_triples_cassandra_storage.py | 78 +++--- .../test_row_embeddings_query.py | 27 ++- .../trustgraph/direct/cassandra_kg.py | 226 +++++++++++++++++- .../query/doc_embeddings/qdrant/service.py | 42 +--- .../query/graph_embeddings/qdrant/service.py | 42 +--- .../query/row_embeddings/qdrant/service.py | 20 +- .../query/rows/cassandra/service.py | 60 ++--- .../query/triples/cassandra/service.py | 167 +++++-------- .../storage/doc_embeddings/qdrant/write.py | 58 +++-- .../storage/graph_embeddings/qdrant/write.py | 58 +++-- .../storage/row_embeddings/qdrant/write.py | 76 +++--- .../storage/rows/cassandra/write.py | 78 +++--- .../storage/triples/cassandra/write.py | 179 ++++---------- 22 files changed, 736 insertions(+), 621 deletions(-) diff --git a/tests/integration/test_cassandra_config_end_to_end.py b/tests/integration/test_cassandra_config_end_to_end.py index 514a5dbf..290e1348 100644 --- a/tests/integration/test_cassandra_config_end_to_end.py +++ b/tests/integration/test_cassandra_config_end_to_end.py @@ -63,26 +63,26 @@ class TestEndToEndConfigurationFlow: 'CASSANDRA_USERNAME': 'obj-user', 'CASSANDRA_PASSWORD': 'obj-pass' } - + mock_auth_instance = MagicMock() mock_auth_provider.return_value = mock_auth_instance mock_cluster_instance = MagicMock() mock_session = MagicMock() mock_cluster_instance.connect.return_value = mock_session mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) - + # Trigger Cassandra connection processor.connect_cassandra() - + # Verify auth provider was created with env vars mock_auth_provider.assert_called_once_with( username='obj-user', password='obj-pass' ) - + # Verify cluster was created with hosts from env and auth mock_cluster.assert_called_once() call_args = mock_cluster.call_args @@ -188,37 +188,34 @@ class TestConfigurationPriorityEndToEnd: ) @pytest.mark.asyncio - @patch('trustgraph.direct.cassandra_kg.Cluster') - async def test_no_config_defaults_end_to_end(self, mock_cluster): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_no_config_defaults_end_to_end(self, mock_kg_class): """Test that defaults are used when no configuration provided end-to-end.""" - mock_cluster_instance = MagicMock() - mock_session = MagicMock() - mock_cluster_instance.connect.return_value = mock_session - mock_cluster.return_value = mock_cluster_instance - + from unittest.mock import AsyncMock + + mock_tg_instance = MagicMock() + mock_tg_instance.async_get_all = AsyncMock(return_value=[]) + mock_kg_class.return_value = mock_tg_instance + with patch.dict(os.environ, {}, clear=True): processor = TriplesQuery(taskgroup=MagicMock()) - + # Mock query to trigger TrustGraph creation mock_query = MagicMock() mock_query.collection = 'default_collection' mock_query.s = None mock_query.p = None mock_query.o = None + mock_query.g = None mock_query.limit = 100 - - # Mock the get_all method to return empty list - mock_tg_instance = MagicMock() - mock_tg_instance.get_all.return_value = [] - processor.tg = mock_tg_instance - + await processor.query_triples('default_user', mock_query) - + # Should use defaults - mock_cluster.assert_called_once() - call_args = mock_cluster.call_args - assert call_args.args[0] == ['cassandra'] # Default host - assert 'auth_provider' not in call_args.kwargs # No auth with default config + mock_kg_class.assert_called_once_with( + hosts=['cassandra'], + keyspace='default_user' + ) class TestNoBackwardCompatibilityEndToEnd: @@ -324,16 +321,16 @@ class TestMultipleHostsHandling: env_vars = { 'CASSANDRA_HOST': 'host1,host2,host3,host4,host5' } - + mock_cluster_instance = MagicMock() mock_session = MagicMock() mock_cluster_instance.connect.return_value = mock_session mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() - + # Verify all hosts were passed to Cluster mock_cluster.assert_called_once() call_args = mock_cluster.call_args @@ -392,27 +389,27 @@ class TestAuthenticationFlow: 'CASSANDRA_USERNAME': 'auth-user', 'CASSANDRA_PASSWORD': 'auth-secret' } - + mock_auth_instance = MagicMock() mock_auth_provider.return_value = mock_auth_instance mock_cluster_instance = MagicMock() mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() - + # Auth provider should be created mock_auth_provider.assert_called_once_with( username='auth-user', password='auth-secret' ) - + # Cluster should be created with auth provider call_args = mock_cluster.call_args assert 'auth_provider' in call_args.kwargs assert call_args.kwargs['auth_provider'] == mock_auth_instance - + @patch('trustgraph.storage.rows.cassandra.write.Cluster') @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_no_authentication_when_credentials_missing(self, mock_auth_provider, mock_cluster): @@ -421,21 +418,21 @@ class TestAuthenticationFlow: 'CASSANDRA_HOST': 'no-auth-host' # No username/password } - + mock_cluster_instance = MagicMock() mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() - + # Auth provider should not be created mock_auth_provider.assert_not_called() - + # Cluster should be created without auth provider call_args = mock_cluster.call_args assert 'auth_provider' not in call_args.kwargs - + @patch('trustgraph.storage.rows.cassandra.write.Cluster') @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_no_authentication_when_only_username_provided(self, mock_auth_provider, mock_cluster): @@ -446,15 +443,15 @@ class TestAuthenticationFlow: cassandra_username='partial-user' # No password ) - + mock_cluster_instance = MagicMock() mock_cluster.return_value = mock_cluster_instance - + processor.connect_cassandra() - + # Auth provider should not be created (needs both username AND password) mock_auth_provider.assert_not_called() - + # Cluster should be created without auth provider call_args = mock_cluster.call_args assert 'auth_provider' not in call_args.kwargs \ No newline at end of file diff --git a/tests/integration/test_rows_cassandra_integration.py b/tests/integration/test_rows_cassandra_integration.py index 1358d420..d668600c 100644 --- a/tests/integration/test_rows_cassandra_integration.py +++ b/tests/integration/test_rows_cassandra_integration.py @@ -101,6 +101,8 @@ class TestRowsCassandraIntegration: processor.session = None # Bind actual methods from the new unified table implementation + import asyncio + processor._setup_lock = asyncio.Lock() processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor) processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor) processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor) @@ -108,6 +110,7 @@ class TestRowsCassandraIntegration: processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) + processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) processor.on_object = Processor.on_object.__get__(processor, Processor) processor.collection_exists = MagicMock(return_value=True) diff --git a/tests/integration/test_rows_graphql_query_integration.py b/tests/integration/test_rows_graphql_query_integration.py index 29b4464d..a455accd 100644 --- a/tests/integration/test_rows_graphql_query_integration.py +++ b/tests/integration/test_rows_graphql_query_integration.py @@ -184,7 +184,7 @@ class TestObjectsGraphQLQueryIntegration: await processor.on_schema_config("default", sample_schema_config, version=1) # Connect to Cassandra - processor.connect_cassandra() + await processor.connect_cassandra() assert processor.session is not None # Create test keyspace and table @@ -219,7 +219,7 @@ class TestObjectsGraphQLQueryIntegration: """Test inserting data and querying via GraphQL""" # Load schema and connect await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() # Setup test data keyspace = "test_user" @@ -293,7 +293,7 @@ class TestObjectsGraphQLQueryIntegration: """Test GraphQL queries with filtering on indexed fields""" # Setup (reuse previous setup) await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() keyspace = "test_user" collection = "filter_test" @@ -387,7 +387,7 @@ class TestObjectsGraphQLQueryIntegration: """Test full message processing workflow""" # Setup await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() # Create mock message request = RowsQueryRequest( @@ -433,7 +433,7 @@ class TestObjectsGraphQLQueryIntegration: """Test handling multiple concurrent GraphQL queries""" # Setup await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() # Create multiple query tasks queries = [ @@ -519,7 +519,7 @@ class TestObjectsGraphQLQueryIntegration: """Test handling of large query result sets""" # Setup await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() keyspace = "large_test_user" collection = "large_collection" diff --git a/tests/unit/test_query/test_rows_cassandra_query.py b/tests/unit/test_query/test_rows_cassandra_query.py index bb6bbe84..b61500a4 100644 --- a/tests/unit/test_query/test_rows_cassandra_query.py +++ b/tests/unit/test_query/test_rows_cassandra_query.py @@ -89,12 +89,15 @@ class TestRowsGraphQLQueryLogic: @pytest.mark.asyncio async def test_schema_config_parsing(self): """Test parsing of schema configuration""" + import asyncio processor = MagicMock() processor.schemas = {} processor.schema_builders = {} processor.graphql_schemas = {} processor.config_key = "schema" processor.query_cassandra = MagicMock() + processor._setup_lock = asyncio.Lock() + processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) # Create test config @@ -335,7 +338,7 @@ class TestUnifiedTableQueries: """Test query execution with matching index""" processor = MagicMock() processor.session = MagicMock() - processor.connect_cassandra = MagicMock() + processor.connect_cassandra = AsyncMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) @@ -396,7 +399,7 @@ class TestUnifiedTableQueries: """Test query execution without matching index (scan mode)""" processor = MagicMock() processor.session = MagicMock() - processor.connect_cassandra = MagicMock() + processor.connect_cassandra = AsyncMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) diff --git a/tests/unit/test_query/test_triples_cassandra_query.py b/tests/unit/test_query/test_triples_cassandra_query.py index 09681214..980fa904 100644 --- a/tests/unit/test_query/test_triples_cassandra_query.py +++ b/tests/unit/test_query/test_triples_cassandra_query.py @@ -2,8 +2,10 @@ Tests for Cassandra triples query service """ +import asyncio + import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, AsyncMock from trustgraph.query.triples.cassandra.service import Processor, create_term from trustgraph.schema import Term, IRI, LITERAL @@ -18,7 +20,7 @@ class TestCassandraQueryProcessor: return Processor( taskgroup=MagicMock(), id='test-cassandra-query', - graph_host='localhost' + cassandra_host='localhost' ) def test_create_term_with_http_uri(self, processor): @@ -85,7 +87,7 @@ class TestCassandraQueryProcessor: mock_result.dtype = None mock_result.lang = None mock_result.o = 'test_object' - mock_tg_instance.get_spo.return_value = [mock_result] + mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result]) processor = Processor( taskgroup=MagicMock(), @@ -110,8 +112,8 @@ class TestCassandraQueryProcessor: keyspace='test_user' ) - # Verify get_spo was called with correct parameters - mock_tg_instance.get_spo.assert_called_once_with( + # Verify async_get_spo was called with correct parameters + mock_tg_instance.async_get_spo.assert_called_once_with( 'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100 ) @@ -130,23 +132,25 @@ class TestCassandraQueryProcessor: assert processor.cassandra_host == ['cassandra'] # Updated default assert processor.cassandra_username is None assert processor.cassandra_password is None - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) def test_processor_initialization_with_custom_params(self): """Test processor initialization with custom parameters""" taskgroup_mock = MagicMock() - + processor = Processor( taskgroup=taskgroup_mock, cassandra_host='cassandra.example.com', cassandra_username='queryuser', cassandra_password='querypass' ) - + assert processor.cassandra_host == ['cassandra.example.com'] assert processor.cassandra_username == 'queryuser' assert processor.cassandra_password == 'querypass' - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) @pytest.mark.asyncio @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') @@ -164,7 +168,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_sp.return_value = [mock_result] + mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -178,7 +182,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50) + mock_tg_instance.async_get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50) assert len(result) == 1 assert result[0].s.iri == 'test_subject' assert result[0].p.iri == 'test_predicate' @@ -200,7 +204,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_s.return_value = [mock_result] + mock_tg_instance.async_get_s = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -214,7 +218,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25) + mock_tg_instance.async_get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25) assert len(result) == 1 assert result[0].s.iri == 'test_subject' assert result[0].p.iri == 'result_predicate' @@ -236,7 +240,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_p.return_value = [mock_result] + mock_tg_instance.async_get_p = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -250,7 +254,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10) + mock_tg_instance.async_get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10) assert len(result) == 1 assert result[0].s.iri == 'result_subject' assert result[0].p.iri == 'test_predicate' @@ -272,7 +276,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_o.return_value = [mock_result] + mock_tg_instance.async_get_o = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -286,7 +290,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75) + mock_tg_instance.async_get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75) assert len(result) == 1 assert result[0].s.iri == 'result_subject' assert result[0].p.iri == 'result_predicate' @@ -305,11 +309,11 @@ class TestCassandraQueryProcessor: mock_result.s = 'all_subject' mock_result.p = 'all_predicate' mock_result.o = 'all_object' - mock_result.g = '' + mock_result.d = '' mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_all.return_value = [mock_result] + mock_tg_instance.async_get_all = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -323,7 +327,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000) + mock_tg_instance.async_get_all.assert_called_once_with('test_collection', limit=1000) assert len(result) == 1 assert result[0].s.iri == 'all_subject' assert result[0].p.iri == 'all_predicate' @@ -410,7 +414,7 @@ class TestCassandraQueryProcessor: mock_result.dtype = None mock_result.lang = None mock_result.o = 'test_object' - mock_tg_instance.get_spo.return_value = [mock_result] + mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result]) processor = Processor( taskgroup=MagicMock(), @@ -451,7 +455,7 @@ class TestCassandraQueryProcessor: mock_result.dtype = None mock_result.lang = None mock_result.o = 'test_object' - mock_tg_instance.get_spo.return_value = [mock_result] + mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -489,8 +493,8 @@ class TestCassandraQueryProcessor: mock_result.lang = None mock_result.p = 'p' mock_result.o = 'o' - mock_tg_instance1.get_s.return_value = [mock_result] - mock_tg_instance2.get_s.return_value = [mock_result] + mock_tg_instance1.async_get_s = AsyncMock(return_value=[mock_result]) + mock_tg_instance2.async_get_s = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -504,7 +508,6 @@ class TestCassandraQueryProcessor: ) await processor.query_triples('user1', query1) - assert processor.table == 'user1' # Second query with different table query2 = TriplesQueryRequest( @@ -516,10 +519,11 @@ class TestCassandraQueryProcessor: ) await processor.query_triples('user2', query2) - assert processor.table == 'user2' - # Verify TrustGraph was created twice + # Verify TrustGraph was created twice for different workspaces assert mock_kg_class.call_count == 2 + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1') + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2') @pytest.mark.asyncio @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') @@ -529,7 +533,7 @@ class TestCassandraQueryProcessor: mock_tg_instance = MagicMock() mock_kg_class.return_value = mock_tg_instance - mock_tg_instance.get_spo.side_effect = Exception("Query failed") + mock_tg_instance.async_get_spo = AsyncMock(side_effect=Exception("Query failed")) processor = Processor(taskgroup=MagicMock()) @@ -566,7 +570,7 @@ class TestCassandraQueryProcessor: mock_result2.otype = None mock_result2.dtype = None mock_result2.lang = None - mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2] + mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result1, mock_result2]) processor = Processor(taskgroup=MagicMock()) @@ -603,7 +607,7 @@ class TestCassandraQueryPerformanceOptimizations: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_po.return_value = [mock_result] + mock_tg_instance.async_get_po = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -618,8 +622,8 @@ class TestCassandraQueryPerformanceOptimizations: result = await processor.query_triples('test_user', query) - # Verify get_po was called (should use optimized po_table) - mock_tg_instance.get_po.assert_called_once_with( + # Verify async_get_po was called (should use optimized po_table) + mock_tg_instance.async_get_po.assert_called_once_with( 'test_collection', 'test_predicate', 'test_object', g=None, limit=50 ) @@ -643,7 +647,7 @@ class TestCassandraQueryPerformanceOptimizations: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_os.return_value = [mock_result] + mock_tg_instance.async_get_os = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -658,8 +662,8 @@ class TestCassandraQueryPerformanceOptimizations: result = await processor.query_triples('test_user', query) - # Verify get_os was called (should use optimized subject_table with clustering) - mock_tg_instance.get_os.assert_called_once_with( + # Verify async_get_os was called (should use optimized subject_table with clustering) + mock_tg_instance.async_get_os.assert_called_once_with( 'test_collection', 'test_object', 'test_subject', g=None, limit=25 ) @@ -678,28 +682,28 @@ class TestCassandraQueryPerformanceOptimizations: mock_kg_class.return_value = mock_tg_instance # Mock empty results for all queries - mock_tg_instance.get_all.return_value = [] - mock_tg_instance.get_s.return_value = [] - mock_tg_instance.get_p.return_value = [] - mock_tg_instance.get_o.return_value = [] - mock_tg_instance.get_sp.return_value = [] - mock_tg_instance.get_po.return_value = [] - mock_tg_instance.get_os.return_value = [] - mock_tg_instance.get_spo.return_value = [] + mock_tg_instance.async_get_all = AsyncMock(return_value=[]) + mock_tg_instance.async_get_s = AsyncMock(return_value=[]) + mock_tg_instance.async_get_p = AsyncMock(return_value=[]) + mock_tg_instance.async_get_o = AsyncMock(return_value=[]) + mock_tg_instance.async_get_sp = AsyncMock(return_value=[]) + mock_tg_instance.async_get_po = AsyncMock(return_value=[]) + mock_tg_instance.async_get_os = AsyncMock(return_value=[]) + mock_tg_instance.async_get_spo = AsyncMock(return_value=[]) processor = Processor(taskgroup=MagicMock()) # Test each query pattern test_patterns = [ # (s, p, o, expected_method) - (None, None, None, 'get_all'), # All triples - ('s1', None, None, 'get_s'), # Subject only - (None, 'p1', None, 'get_p'), # Predicate only - (None, None, 'o1', 'get_o'), # Object only - ('s1', 'p1', None, 'get_sp'), # Subject + Predicate - (None, 'p1', 'o1', 'get_po'), # Predicate + Object (CRITICAL OPTIMIZATION) - ('s1', None, 'o1', 'get_os'), # Object + Subject - ('s1', 'p1', 'o1', 'get_spo'), # All three + (None, None, None, 'async_get_all'), # All triples + ('s1', None, None, 'async_get_s'), # Subject only + (None, 'p1', None, 'async_get_p'), # Predicate only + (None, None, 'o1', 'async_get_o'), # Object only + ('s1', 'p1', None, 'async_get_sp'), # Subject + Predicate + (None, 'p1', 'o1', 'async_get_po'), # Predicate + Object (CRITICAL OPTIMIZATION) + ('s1', None, 'o1', 'async_get_os'), # Object + Subject + ('s1', 'p1', 'o1', 'async_get_spo'), # All three ] for s, p, o, expected_method in test_patterns: @@ -759,7 +763,7 @@ class TestCassandraQueryPerformanceOptimizations: mock_result.lang = None mock_results.append(mock_result) - mock_tg_instance.get_po.return_value = mock_results + mock_tg_instance.async_get_po = AsyncMock(return_value=mock_results) processor = Processor(taskgroup=MagicMock()) @@ -774,8 +778,8 @@ class TestCassandraQueryPerformanceOptimizations: result = await processor.query_triples('large_dataset_user', query) - # Verify optimized get_po was used (no ALLOW FILTERING needed!) - mock_tg_instance.get_po.assert_called_once_with( + # Verify optimized async_get_po was used (no ALLOW FILTERING needed!) + mock_tg_instance.async_get_po.assert_called_once_with( 'massive_collection', 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type', 'http://example.com/Person', diff --git a/tests/unit/test_reliability/test_null_embedding_protection.py b/tests/unit/test_reliability/test_null_embedding_protection.py index 2296e961..dbe06b40 100644 --- a/tests/unit/test_reliability/test_null_embedding_protection.py +++ b/tests/unit/test_reliability/test_null_embedding_protection.py @@ -113,12 +113,15 @@ class TestDocEmbeddingsNullProtection: @pytest.mark.asyncio async def test_valid_embedding_upserted(self): + import asyncio from trustgraph.storage.doc_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = True proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "col1" @@ -134,12 +137,15 @@ class TestDocEmbeddingsNullProtection: @pytest.mark.asyncio async def test_dimension_in_collection_name(self): """Collection name should include vector dimension.""" + import asyncio from trustgraph.storage.doc_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = True proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "docs" @@ -220,12 +226,15 @@ class TestGraphEmbeddingsNullProtection: @pytest.mark.asyncio async def test_valid_entity_and_vector_upserted(self): + import asyncio from trustgraph.storage.graph_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = True proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "col1" @@ -241,12 +250,15 @@ class TestGraphEmbeddingsNullProtection: @pytest.mark.asyncio async def test_lazy_collection_creation_on_new_dimension(self): + import asyncio from trustgraph.storage.graph_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = False proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "graphs" diff --git a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py index ce6e6b3d..360ac3dc 100644 --- a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py @@ -413,8 +413,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Assert expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions - # Verify collection existence is checked on each write - mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) + # Second write uses cached collection state — no collection_exists check + mock_qdrant_instance.collection_exists.assert_not_called() # But upsert should still be called mock_qdrant_instance.upsert.assert_called_once() diff --git a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py index 8754f47c..44fdf516 100644 --- a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py @@ -125,13 +125,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) - processor.ensure_collection("test_collection", 384) + await processor.ensure_collection("test_collection", 384) mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection") mock_qdrant_instance.create_collection.assert_called_once() # Verify the collection is cached - assert "test_collection" in processor.created_collections + assert "test_collection" in processor._known_collections @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') async def test_ensure_collection_skips_existing(self, mock_qdrant_client): @@ -149,7 +149,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) - processor.ensure_collection("existing_collection", 384) + await processor.ensure_collection("existing_collection", 384) mock_qdrant_instance.collection_exists.assert_called_once() mock_qdrant_instance.create_collection.assert_not_called() @@ -168,9 +168,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.created_collections.add("cached_collection") + processor._known_collections.add("cached_collection") - processor.ensure_collection("cached_collection", 384) + await processor.ensure_collection("cached_collection", 384) # Should not check or create - just return mock_qdrant_instance.collection_exists.assert_not_called() @@ -391,7 +391,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.created_collections.add('rows_test_workspace_test_collection_schema1_384') + processor._known_collections.add('rows_test_workspace_test_collection_schema1_384') await processor.delete_collection('test_workspace', 'test_collection') @@ -399,7 +399,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): assert mock_qdrant_instance.delete_collection.call_count == 2 # Verify the cached collection was removed - assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections + assert 'rows_test_workspace_test_collection_schema1_384' not in processor._known_collections @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') async def test_delete_collection_schema(self, mock_qdrant_client): diff --git a/tests/unit/test_storage/test_rows_cassandra_storage.py b/tests/unit/test_storage/test_rows_cassandra_storage.py index 852f01a1..3e5664ea 100644 --- a/tests/unit/test_storage/test_rows_cassandra_storage.py +++ b/tests/unit/test_storage/test_rows_cassandra_storage.py @@ -121,10 +121,13 @@ class TestRowsCassandraStorageLogic: @pytest.mark.asyncio async def test_schema_config_parsing(self): """Test parsing of schema configurations""" + import asyncio processor = MagicMock() processor.schemas = {} processor.config_key = "schema" processor.registered_partitions = set() + processor._setup_lock = asyncio.Lock() + processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) # Create test configuration diff --git a/tests/unit/test_storage/test_triples_cassandra_storage.py b/tests/unit/test_storage/test_triples_cassandra_storage.py index 04acbb16..394f0e54 100644 --- a/tests/unit/test_storage/test_triples_cassandra_storage.py +++ b/tests/unit/test_storage/test_triples_cassandra_storage.py @@ -2,6 +2,8 @@ Tests for Cassandra triples storage service """ +import asyncio + import pytest from unittest.mock import MagicMock, patch, AsyncMock @@ -24,12 +26,13 @@ class TestCassandraStorageProcessor: assert processor.cassandra_host == ['cassandra'] # Updated default assert processor.cassandra_username is None assert processor.cassandra_password is None - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) def test_processor_initialization_with_custom_params(self): """Test processor initialization with custom parameters (new cassandra_* names)""" taskgroup_mock = MagicMock() - + processor = Processor( taskgroup=taskgroup_mock, id='custom-storage', @@ -37,11 +40,12 @@ class TestCassandraStorageProcessor: cassandra_username='testuser', cassandra_password='testpass' ) - + assert processor.cassandra_host == ['cassandra.example.com'] assert processor.cassandra_username == 'testuser' assert processor.cassandra_password == 'testpass' - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) def test_processor_initialization_with_partial_auth(self): """Test processor initialization with only username (no password)""" @@ -92,6 +96,7 @@ class TestCassandraStorageProcessor: """Test table switching logic when authentication is provided""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor( @@ -114,7 +119,6 @@ class TestCassandraStorageProcessor: username='testuser', password='testpass' ) - assert processor.table == 'user1' @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @@ -122,6 +126,7 @@ class TestCassandraStorageProcessor: """Test table switching logic when no authentication is provided""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -138,7 +143,6 @@ class TestCassandraStorageProcessor: hosts=['cassandra'], # Updated default keyspace='user2' ) - assert processor.table == 'user2' @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @@ -146,6 +150,7 @@ class TestCassandraStorageProcessor: """Test that TrustGraph is not recreated when table hasn't changed""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -169,6 +174,7 @@ class TestCassandraStorageProcessor: """Test that triples are properly inserted into Cassandra""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -208,12 +214,12 @@ class TestCassandraStorageProcessor: await processor.store_triples('user1', mock_message) # Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters) - assert mock_tg_instance.insert.call_count == 2 - mock_tg_instance.insert.assert_any_call( + assert mock_tg_instance.async_insert.call_count == 2 + mock_tg_instance.async_insert.assert_any_call( 'collection1', 'subject1', 'predicate1', 'object1', g=DEFAULT_GRAPH, otype='l', dtype='', lang='' ) - mock_tg_instance.insert.assert_any_call( + mock_tg_instance.async_insert.assert_any_call( 'collection1', 'subject2', 'predicate2', 'object2', g=DEFAULT_GRAPH, otype='l', dtype='', lang='' ) @@ -224,6 +230,7 @@ class TestCassandraStorageProcessor: """Test behavior when message has no triples""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -236,19 +243,17 @@ class TestCassandraStorageProcessor: await processor.store_triples('user1', mock_message) # Verify no triples were inserted - mock_tg_instance.insert.assert_not_called() + mock_tg_instance.async_insert.assert_not_called() @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') - @patch('trustgraph.storage.triples.cassandra.write.time.sleep') - async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class): + async def test_exception_handling_on_connection_failure(self, mock_kg_class): """Test exception handling during TrustGraph creation""" taskgroup_mock = MagicMock() mock_kg_class.side_effect = Exception("Connection failed") processor = Processor(taskgroup=taskgroup_mock) - # Create mock message mock_message = MagicMock() mock_message.metadata.collection = 'collection1' mock_message.triples = [] @@ -256,9 +261,6 @@ class TestCassandraStorageProcessor: with pytest.raises(Exception, match="Connection failed"): await processor.store_triples('user1', mock_message) - # Verify sleep was called before re-raising - mock_sleep.assert_called_once_with(1) - def test_add_args_method(self): """Test that add_args properly configures argument parser""" from argparse import ArgumentParser @@ -359,8 +361,6 @@ class TestCassandraStorageProcessor: mock_message1.triples = [] await processor.store_triples('user1', mock_message1) - assert processor.table == 'user1' - assert processor.tg == mock_tg_instance1 # Second message with different table mock_message2 = MagicMock() @@ -368,11 +368,11 @@ class TestCassandraStorageProcessor: mock_message2.triples = [] await processor.store_triples('user2', mock_message2) - assert processor.table == 'user2' - assert processor.tg == mock_tg_instance2 - # Verify TrustGraph was created twice for different tables + # Verify TrustGraph was created twice for different workspaces assert mock_kg_class.call_count == 2 + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1') + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2') @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @@ -380,6 +380,7 @@ class TestCassandraStorageProcessor: """Test storing triples with special characters and unicode""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -405,7 +406,7 @@ class TestCassandraStorageProcessor: await processor.store_triples('test_workspace', mock_message) # Verify the triple was inserted with special characters preserved - mock_tg_instance.insert.assert_called_once_with( + mock_tg_instance.async_insert.assert_called_once_with( 'test_collection', 'subject with spaces & symbols', 'predicate:with/colons', @@ -418,29 +419,29 @@ class TestCassandraStorageProcessor: @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') - async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class): - """Test that table remains unchanged when TrustGraph creation fails""" + async def test_connection_failure_does_not_cache_stale_state(self, mock_kg_class): + """Test that a failed connection doesn't leave stale cached state""" taskgroup_mock = MagicMock() + mock_good_instance = MagicMock() processor = Processor(taskgroup=taskgroup_mock) - # Set an initial table - processor.table = ('old_user', 'old_collection') - - # Mock TrustGraph to raise exception - mock_kg_class.side_effect = Exception("Connection failed") - mock_message = MagicMock() - mock_message.metadata.collection = 'new_collection' + mock_message.metadata.collection = 'collection1' mock_message.triples = [] + # First call fails + mock_kg_class.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): - await processor.store_triples('new_user', mock_message) + await processor.store_triples('user1', mock_message) - # Table should remain unchanged since self.table = table happens after try/except - assert processor.table == ('old_user', 'old_collection') - # TrustGraph should be set to None though - assert processor.tg is None + # Second call succeeds — should retry connection, not use stale state + mock_kg_class.side_effect = None + mock_kg_class.return_value = mock_good_instance + await processor.store_triples('user1', mock_message) + + # Connection was attempted twice (failed + succeeded) + assert mock_kg_class.call_count == 2 class TestCassandraPerformanceOptimizations: @@ -452,6 +453,7 @@ class TestCassandraPerformanceOptimizations: """Test that legacy mode still works with single table""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}): @@ -472,6 +474,7 @@ class TestCassandraPerformanceOptimizations: """Test that optimized mode uses multi-table schema""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}): @@ -492,6 +495,7 @@ class TestCassandraPerformanceOptimizations: """Test that all tables stay consistent during batch writes""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -517,7 +521,7 @@ class TestCassandraPerformanceOptimizations: await processor.store_triples('user1', mock_message) # Verify insert was called for the triple (implementation details tested in KnowledgeGraph) - mock_tg_instance.insert.assert_called_once_with( + mock_tg_instance.async_insert.assert_called_once_with( 'collection1', 'test_subject', 'test_predicate', 'test_object', g=DEFAULT_GRAPH, otype='l', dtype='', lang='' ) diff --git a/tests/unit/test_structured_data/test_row_embeddings_query.py b/tests/unit/test_structured_data/test_row_embeddings_query.py index 51cf834f..f1297e1c 100644 --- a/tests/unit/test_structured_data/test_row_embeddings_query.py +++ b/tests/unit/test_structured_data/test_row_embeddings_query.py @@ -89,7 +89,8 @@ class TestSanitizeName: class TestFindCollection: - def test_finds_matching_collection(self): + @pytest.mark.asyncio + async def test_finds_matching_collection(self): proc = _make_processor() mock_coll = MagicMock() mock_coll.name = "rows_test_workspace_test_col_customers_384" @@ -98,11 +99,12 @@ class TestFindCollection: mock_collections.collections = [mock_coll] proc.qdrant.get_collections.return_value = mock_collections - result = proc.find_collection("test-workspace", "test-col", "customers") + result = await proc.find_collection("test-workspace", "test-col", "customers") assert result == "rows_test_workspace_test_col_customers_384" - def test_returns_none_when_no_match(self): + @pytest.mark.asyncio + async def test_returns_none_when_no_match(self): proc = _make_processor() mock_coll = MagicMock() mock_coll.name = "rows_other_workspace_other_col_schema_768" @@ -111,14 +113,15 @@ class TestFindCollection: mock_collections.collections = [mock_coll] proc.qdrant.get_collections.return_value = mock_collections - result = proc.find_collection("test-workspace", "test-col", "customers") + result = await proc.find_collection("test-workspace", "test-col", "customers") assert result is None - def test_returns_none_on_error(self): + @pytest.mark.asyncio + async def test_returns_none_on_error(self): proc = _make_processor() proc.qdrant.get_collections.side_effect = Exception("connection error") - result = proc.find_collection("workspace", "col", "schema") + result = await proc.find_collection("workspace", "col", "schema") assert result is None @@ -139,7 +142,7 @@ class TestQueryRowEmbeddings: @pytest.mark.asyncio async def test_no_collection_returns_empty(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value=None) + proc.find_collection = AsyncMock(return_value=None) request = _make_request() result = await proc.query_row_embeddings("test-workspace", request) @@ -148,7 +151,7 @@ class TestQueryRowEmbeddings: @pytest.mark.asyncio async def test_successful_query_returns_matches(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") points = [ _make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95), @@ -172,7 +175,7 @@ class TestQueryRowEmbeddings: async def test_index_name_filter_applied(self): """When index_name is specified, a Qdrant filter should be used.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") mock_result = MagicMock() mock_result.points = [] @@ -188,7 +191,7 @@ class TestQueryRowEmbeddings: async def test_no_index_name_no_filter(self): """When index_name is empty, no filter should be applied.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") mock_result = MagicMock() mock_result.points = [] @@ -204,7 +207,7 @@ class TestQueryRowEmbeddings: async def test_missing_payload_fields_default(self): """Points with missing payload fields should use defaults.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") point = MagicMock() point.payload = {} # Empty payload @@ -225,7 +228,7 @@ class TestQueryRowEmbeddings: @pytest.mark.asyncio async def test_qdrant_error_propagates(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") proc.qdrant.query_points.side_effect = Exception("qdrant down") request = _make_request() diff --git a/trustgraph-flow/trustgraph/direct/cassandra_kg.py b/trustgraph-flow/trustgraph/direct/cassandra_kg.py index 59d2a2a1..d7abd1a9 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra_kg.py +++ b/trustgraph-flow/trustgraph/direct/cassandra_kg.py @@ -1,10 +1,14 @@ +import datetime +import os +import logging + from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider from cassandra.query import BatchStatement, SimpleStatement from ssl import SSLContext, PROTOCOL_TLSv1_2 -import os -import logging + +from ..tables.cassandra_async import async_execute # Global list to track clusters for cleanup _active_clusters = [] @@ -461,7 +465,6 @@ class KnowledgeGraph: def create_collection(self, collection): """Create collection by inserting metadata row""" try: - import datetime self.session.execute( f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", (collection, datetime.datetime.now()) @@ -954,7 +957,6 @@ class EntityCentricKnowledgeGraph: def create_collection(self, collection): """Create collection by inserting metadata row""" try: - import datetime self.session.execute( f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", (collection, datetime.datetime.now()) @@ -1045,6 +1047,222 @@ class EntityCentricKnowledgeGraph: logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads") + # ======================================================================== + # Async methods — use cassandra driver's native async API via async_execute + # ======================================================================== + + async def async_insert(self, collection, s, p, o, g=None, otype=None, dtype="", lang=""): + if g is None: + g = DEFAULT_GRAPH + if otype is None: + if o.startswith("http://") or o.startswith("https://"): + otype = "u" + else: + otype = "l" + + batch = BatchStatement() + batch.add(self.insert_entity_stmt, (collection, s, 'S', p, otype, s, o, g, dtype, lang)) + batch.add(self.insert_entity_stmt, (collection, p, 'P', p, otype, s, o, g, dtype, lang)) + if otype == 'u' or otype == 't': + batch.add(self.insert_entity_stmt, (collection, o, 'O', p, otype, s, o, g, dtype, lang)) + if g != DEFAULT_GRAPH: + batch.add(self.insert_entity_stmt, (collection, g, 'G', p, otype, s, o, g, dtype, lang)) + batch.add(self.insert_collection_stmt, (collection, g, s, p, o, otype, dtype, lang)) + + await async_execute(self.session, batch) + + async def async_get_all(self, collection, limit=50): + return await async_execute( + self.session, self.get_collection_all_stmt, (collection, limit) + ) + + async def async_get_s(self, collection, s, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_stmt, (collection, s, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_p(self, collection, p, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_p_stmt, (collection, p, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_o(self, collection, o, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_o_stmt, (collection, o, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_sp(self, collection, s, p, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=s, p=p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_po(self, collection, p, o, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_o_p_stmt, (collection, o, p, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_os(self, collection, o, s, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_stmt, (collection, s, limit) + ) + results = [] + for row in rows: + if row.o != o: + continue + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=s, p=row.p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_spo(self, collection, s, p, o, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit) + ) + results = [] + for row in rows: + if row.o != o: + continue + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=s, p=p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_g(self, collection, g, limit=50): + if g is None: + g = DEFAULT_GRAPH + return await async_execute( + self.session, self.get_collection_by_graph_stmt, (collection, g, limit) + ) + + async def async_collection_exists(self, collection): + try: + result = await async_execute( + self.session, + f"SELECT collection FROM {self.collection_metadata_table} WHERE collection = %s LIMIT 1", + (collection,) + ) + return bool(result) + except Exception as e: + logger.error(f"Error checking collection existence: {e}") + return False + + async def async_create_collection(self, collection): + await async_execute( + self.session, + f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", + (collection, datetime.datetime.now()) + ) + logger.info(f"Created collection metadata for {collection}") + + async def async_delete_collection(self, collection): + rows = await async_execute( + self.session, + f"SELECT d, s, p, o, otype, dtype, lang FROM {self.collection_table} WHERE collection = %s", + (collection,) + ) + + entities = set() + quads = [] + for row in rows: + d, s, p, o = row.d, row.s, row.p, row.o + otype = row.otype + dtype = row.dtype if hasattr(row, 'dtype') else '' + lang = row.lang if hasattr(row, 'lang') else '' + quads.append((d, s, p, o, otype, dtype, lang)) + entities.add(s) + entities.add(p) + if otype == 'u' or otype == 't': + entities.add(o) + if d != DEFAULT_GRAPH: + entities.add(d) + + batch = BatchStatement() + count = 0 + for entity in entities: + batch.add(self.delete_entity_partition_stmt, (collection, entity)) + count += 1 + if count % 50 == 0: + await async_execute(self.session, batch) + batch = BatchStatement() + if count % 50 != 0: + await async_execute(self.session, batch) + + batch = BatchStatement() + count = 0 + for d, s, p, o, otype, dtype, lang in quads: + batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o, otype, dtype, lang)) + count += 1 + if count % 50 == 0: + await async_execute(self.session, batch) + batch = BatchStatement() + if count % 50 != 0: + await async_execute(self.session, batch) + + await async_execute( + self.session, + f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s", + (collection,) + ) + logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads") + def close(self): """Close connections""" if hasattr(self, 'session') and self.session: diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index 1d59c835..f6770744 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -4,11 +4,10 @@ Document embeddings query service. Input is vector, output is an array of chunk_ids """ +import asyncio import logging from qdrant_client import QdrantClient -from qdrant_client.models import PointStruct -from qdrant_client.models import Distance, VectorParams from .... schema import DocumentEmbeddingsResponse, ChunkMatch from .... schema import Error @@ -38,32 +37,6 @@ class Processor(DocumentEmbeddingsQueryService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) - self.last_collection = None - - def ensure_collection_exists(self, collection, dim): - """Ensure collection exists, create if it doesn't""" - if collection != self.last_collection: - if not self.qdrant.collection_exists(collection): - try: - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, distance=Distance.COSINE - ), - ) - logger.info(f"Created collection: {collection}") - except Exception as e: - logger.error(f"Qdrant collection creation failed: {e}") - raise e - self.last_collection = collection - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) async def query_document_embeddings(self, workspace, msg): @@ -73,21 +46,24 @@ class Processor(DocumentEmbeddingsQueryService): if not vec: return [] - # Use dimension suffix in collection name dim = len(vec) collection = f"d_{workspace}_{msg.collection}_{dim}" - # Check if collection exists - return empty if not - if not self.collection_exists(collection): + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection + ) + if not exists: logger.info(f"Collection {collection} does not exist, returning empty results") return [] - search_result = self.qdrant.query_points( + result = await asyncio.to_thread( + self.qdrant.query_points, collection_name=collection, query=vec, limit=msg.limit, with_payload=True, - ).points + ) + search_result = result.points chunks = [] for r in search_result: diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index b8fb1361..167130c9 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -4,11 +4,10 @@ Graph embeddings query service. Input is vector, output is list of entities """ +import asyncio import logging from qdrant_client import QdrantClient -from qdrant_client.models import PointStruct -from qdrant_client.models import Distance, VectorParams from .... schema import GraphEmbeddingsResponse, EntityMatch from .... schema import Error, Term, IRI, LITERAL @@ -38,32 +37,6 @@ class Processor(GraphEmbeddingsQueryService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) - self.last_collection = None - - def ensure_collection_exists(self, collection, dim): - """Ensure collection exists, create if it doesn't""" - if collection != self.last_collection: - if not self.qdrant.collection_exists(collection): - try: - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, distance=Distance.COSINE - ), - ) - logger.info(f"Created collection: {collection}") - except Exception as e: - logger.error(f"Qdrant collection creation failed: {e}") - raise e - self.last_collection = collection - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): @@ -79,23 +52,26 @@ class Processor(GraphEmbeddingsQueryService): if not vec: return [] - # Use dimension suffix in collection name dim = len(vec) collection = f"t_{workspace}_{msg.collection}_{dim}" - # Check if collection exists - return empty if not - if not self.collection_exists(collection): + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection + ) + if not exists: logger.info(f"Collection {collection} does not exist") return [] # Heuristic hack, get (2*limit), so that we have more chance # of getting (limit) unique entities - search_result = self.qdrant.query_points( + result = await asyncio.to_thread( + self.qdrant.query_points, collection_name=collection, query=vec, limit=msg.limit * 2, with_payload=True, - ).points + ) + search_result = result.points entity_set = set() entities = [] diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py index dd89a8d8..1534c044 100644 --- a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py @@ -6,6 +6,7 @@ Output is matching row index information (index_name, index_value) for use in subsequent Cassandra lookups. """ +import asyncio import logging import re from typing import Optional @@ -70,7 +71,7 @@ class Processor(FlowProcessor): safe_name = 'r_' + safe_name return safe_name.lower() - def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]: + async def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]: """Find the Qdrant collection for a given workspace/collection/schema""" prefix = ( f"rows_{self.sanitize_name(workspace)}_" @@ -78,14 +79,15 @@ class Processor(FlowProcessor): ) try: - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching = [ coll.name for coll in all_collections if coll.name.startswith(prefix) ] if matching: - # Return first match (there should typically be only one per dimension) return matching[0] except Exception as e: @@ -100,8 +102,7 @@ class Processor(FlowProcessor): if not vec: return [] - # Find the collection for this workspace/collection/schema - qdrant_collection = self.find_collection( + qdrant_collection = await self.find_collection( workspace, request.collection, request.schema_name ) @@ -113,7 +114,6 @@ class Processor(FlowProcessor): return [] try: - # Build optional filter for index_name query_filter = None if request.index_name: query_filter = Filter( @@ -125,16 +125,16 @@ class Processor(FlowProcessor): ] ) - # Query Qdrant - search_result = self.qdrant.query_points( + result = await asyncio.to_thread( + self.qdrant.query_points, collection_name=qdrant_collection, query=vec, limit=request.limit, with_payload=True, query_filter=query_filter, - ).points + ) + search_result = result.points - # Convert to RowIndexMatch objects matches = [] for point in search_result: payload = point.payload or {} diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py index 73cfcd83..7157daae 100644 --- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py @@ -11,6 +11,7 @@ Queries against the unified 'rows' table with schema: - source: text """ +import asyncio import json import logging import re @@ -97,34 +98,38 @@ class Processor(FlowProcessor): # Cassandra session self.cluster = None self.session = None + self._setup_lock = asyncio.Lock() # Known keyspaces self.known_keyspaces: Set[str] = set() - def connect_cassandra(self): + async def connect_cassandra(self): """Connect to Cassandra cluster""" - if self.session: - return + async with self._setup_lock: + if self.session: + return - try: - if self.cassandra_username and self.cassandra_password: - auth_provider = PlainTextAuthProvider( - username=self.cassandra_username, - password=self.cassandra_password - ) - self.cluster = Cluster( - contact_points=self.cassandra_host, - auth_provider=auth_provider - ) - else: - self.cluster = Cluster(contact_points=self.cassandra_host) + try: + if self.cassandra_username and self.cassandra_password: + auth_provider = PlainTextAuthProvider( + username=self.cassandra_username, + password=self.cassandra_password + ) + cluster = Cluster( + contact_points=self.cassandra_host, + auth_provider=auth_provider + ) + else: + cluster = Cluster(contact_points=self.cassandra_host) - self.session = self.cluster.connect() - logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") + session = await asyncio.to_thread(cluster.connect) + self.cluster = cluster + self.session = session + logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") - except Exception as e: - logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) - raise + except Exception as e: + logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) + raise def sanitize_name(self, name: str) -> str: """Sanitize names for Cassandra compatibility""" @@ -140,14 +145,17 @@ class Processor(FlowProcessor): f"for workspace {workspace}" ) - # Replace existing schemas for this workspace + async with self._setup_lock: + await self._apply_schema_config(workspace, config) + + async def _apply_schema_config(self, workspace, config): + ws_schemas: Dict[str, RowSchema] = {} self.schemas[workspace] = ws_schemas builder = GraphQLSchemaBuilder() self.schema_builders[workspace] = builder - # Check if our config type exists if self.config_key not in config: logger.warning( f"No '{self.config_key}' type in configuration " @@ -156,16 +164,12 @@ class Processor(FlowProcessor): self.graphql_schemas[workspace] = None return - # Get the schemas dictionary for our type schemas_config = config[self.config_key] - # Process each schema in the schemas config for schema_name, schema_json in schemas_config.items(): try: - # Parse the JSON schema definition schema_def = json.loads(schema_json) - # Create Field objects fields = [] for field_def in schema_def.get("fields", []): field = SchemaField( @@ -180,7 +184,6 @@ class Processor(FlowProcessor): ) fields.append(field) - # Create RowSchema row_schema = RowSchema( name=schema_def.get("name", schema_name), description=schema_def.get("description", ""), @@ -202,7 +205,6 @@ class Processor(FlowProcessor): f"{len(ws_schemas)} schemas" ) - # Regenerate GraphQL schema for this workspace self.graphql_schemas[workspace] = builder.build(self.query_cassandra) def get_index_names(self, schema: RowSchema) -> List[str]: @@ -254,7 +256,7 @@ class Processor(FlowProcessor): For other queries, we need to scan and post-filter. """ # Connect if needed - self.connect_cassandra() + await self.connect_cassandra() safe_keyspace = self.sanitize_name(workspace) diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index a9bdbbac..1fadaab3 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -6,8 +6,8 @@ null. Output is a list of quads. import asyncio import logging - import json + from cassandra.query import SimpleStatement from .... direct.cassandra_kg import ( @@ -17,6 +17,7 @@ from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error from .... schema import Term, Triple, IRI, LITERAL, TRIPLE, BLANK from .... base import TriplesQueryService from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config +from .... tables.cassandra_async import async_execute # Module logger logger = logging.getLogger(__name__) @@ -176,45 +177,42 @@ class Processor(TriplesQueryService): self.cassandra_host = hosts self.cassandra_username = username self.cassandra_password = password - self.table = None - def ensure_connection(self, workspace): - """Ensure we have a connection to the correct keyspace.""" - if workspace != self.table: - KGClass = EntityCentricKnowledgeGraph + self._connections = {} + self._conn_lock = asyncio.Lock() - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - self.table = workspace + async def _get_connection(self, workspace): + async with self._conn_lock: + if workspace not in self._connections: + if self.cassandra_username and self.cassandra_password: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + username=self.cassandra_username, + password=self.cassandra_password, + ) + else: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + ) + self._connections[workspace] = tg + return self._connections[workspace] async def query_triples(self, workspace, query): try: - # ensure_connection may construct a fresh - # EntityCentricKnowledgeGraph which does sync schema - # setup against Cassandra. Push it to a worker thread - # so the event loop doesn't block on first-use per workspace. - await asyncio.to_thread(self.ensure_connection, workspace) - - # Extract values from query s_val = get_term_value(query.s) p_val = get_term_value(query.p) o_val = get_term_value(query.o) - g_val = query.g # Already a string or None + g_val = query.g + + tg = await self._get_connection(workspace) def get_object_metadata(row): - """Extract term type metadata from result row""" return ( getattr(row, 'otype', None), getattr(row, 'dtype', None), @@ -223,33 +221,21 @@ class Processor(TriplesQueryService): quads = [] - # All self.tg.get_* calls below are sync wrappers around - # cassandra session.execute. Materialise inside a worker - # thread so iteration never triggers sync paging back on - # the event loop. - - # Route to appropriate query method based on which fields are specified if s_val is not None: if p_val is not None: if o_val is not None: - # SPO specified - find matching graphs - resp = await asyncio.to_thread( - lambda: list(self.tg.get_spo( - query.collection, s_val, p_val, o_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_spo( + query.collection, s_val, p_val, o_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((s_val, p_val, o_val, g, term_type, datatype, language)) else: - # SP specified - resp = await asyncio.to_thread( - lambda: list(self.tg.get_sp( - query.collection, s_val, p_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_sp( + query.collection, s_val, p_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -257,24 +243,18 @@ class Processor(TriplesQueryService): quads.append((s_val, p_val, t.o, g, term_type, datatype, language)) else: if o_val is not None: - # SO specified - resp = await asyncio.to_thread( - lambda: list(self.tg.get_os( - query.collection, o_val, s_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_os( + query.collection, o_val, s_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((s_val, t.p, o_val, g, term_type, datatype, language)) else: - # S only - resp = await asyncio.to_thread( - lambda: list(self.tg.get_s( - query.collection, s_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_s( + query.collection, s_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -283,24 +263,18 @@ class Processor(TriplesQueryService): else: if p_val is not None: if o_val is not None: - # PO specified - resp = await asyncio.to_thread( - lambda: list(self.tg.get_po( - query.collection, p_val, o_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_po( + query.collection, p_val, o_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((t.s, p_val, o_val, g, term_type, datatype, language)) else: - # P only - resp = await asyncio.to_thread( - lambda: list(self.tg.get_p( - query.collection, p_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_p( + query.collection, p_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -308,40 +282,26 @@ class Processor(TriplesQueryService): quads.append((t.s, p_val, t.o, g, term_type, datatype, language)) else: if o_val is not None: - # O only - resp = await asyncio.to_thread( - lambda: list(self.tg.get_o( - query.collection, o_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_o( + query.collection, o_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((t.s, t.p, o_val, g, term_type, datatype, language)) else: - # Nothing specified - get all - resp = await asyncio.to_thread( - lambda: list(self.tg.get_all( - query.collection, limit=query.limit, - )) + resp = await tg.async_get_all( + query.collection, limit=query.limit, ) for t in resp: - # Note: quads_by_collection uses 'd' for graph field g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH - # Filter by graph - # g_val=None means all graphs (no filter) - # g_val="" means default graph only - # otherwise filter to specific named graph if g_val is not None: if g != g_val: continue term_type, datatype, language = get_object_metadata(t) quads.append((t.s, t.p, t.o, g, term_type, datatype, language)) - # Convert to Triple objects (with g field) - # s and p are always IRIs in RDF - # Object uses term_type/datatype/language metadata from database triples = [ Triple( s=create_term(q[0], term_type='u'), @@ -365,51 +325,36 @@ class Processor(TriplesQueryService): Uses Cassandra's paging to fetch results incrementally. """ try: - await asyncio.to_thread(self.ensure_connection, workspace) batch_size = query.batch_size if query.batch_size > 0 else 20 limit = query.limit if query.limit > 0 else 10000 - # Extract query pattern s_val = get_term_value(query.s) p_val = get_term_value(query.p) o_val = get_term_value(query.o) g_val = query.g def get_object_metadata(row): - """Extract term type metadata from result row""" return ( getattr(row, 'otype', None), getattr(row, 'dtype', None), getattr(row, 'lang', None), ) - # For streaming, we need to execute with fetch_size - # Use the collection table for get_all queries (most common streaming case) - - # Determine which query to use based on pattern if s_val is None and p_val is None and o_val is None: - # Get all - use collection table with paging - cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s" + + tg = await self._get_connection(workspace) + + cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {tg.collection_table} WHERE collection = %s" params = [query.collection] + statement = SimpleStatement(cql, fetch_size=batch_size) + result_set = await async_execute(tg.session, statement, params) + else: - # For specific patterns, fall back to non-streaming - # (these typically return small result sets anyway) async for batch, is_final in self._fallback_stream(workspace, query, batch_size): yield batch, is_final return - # Materialise in a worker thread. We lose true streaming - # paging (the driver fetches all pages eagerly inside the - # thread) but the event loop stays responsive, and result - # sets at this layer are typically small enough that this - # is acceptable. If true async paging is needed later, - # revisit using ResponseFuture page callbacks. - statement = SimpleStatement(cql, fetch_size=batch_size) - result_set = await asyncio.to_thread( - lambda: list(self.tg.session.execute(statement, params)) - ) - batch = [] count = 0 diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index fb7166b5..2bfef99c 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -3,11 +3,13 @@ Accepts entity/vector pairs and writes them to a Qdrant store. """ +import asyncio +import uuid +import logging + from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -import uuid -import logging from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer @@ -35,13 +37,35 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self._cache_lock = asyncio.Lock() + self._known_collections: set[str] = set() # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) + async def ensure_collection(self, collection_name, dim): + async with self._cache_lock: + if collection_name in self._known_collections: + return + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection_name + ) + if not exists: + logger.info( + f"Lazily creating Qdrant collection {collection_name} " + f"with dimension {dim}" + ) + await asyncio.to_thread( + self.qdrant.create_collection, + collection_name=collection_name, + vectors_config=VectorParams( + size=dim, distance=Distance.COSINE + ), + ) + self._known_collections.add(collection_name) + async def store_document_embeddings(self, workspace, message): - # Validate collection exists in config before processing if not self.collection_exists(workspace, message.metadata.collection): logger.warning( f"Collection {message.metadata.collection} for workspace {workspace} " @@ -60,24 +84,15 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): if not vec: continue - # Create collection name with dimension suffix for lazy creation dim = len(vec) collection = ( f"d_{workspace}_{message.metadata.collection}_{dim}" ) - # Lazily create collection if it doesn't exist (but only if authorized in config) - if not self.qdrant.collection_exists(collection): - logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, - distance=Distance.COSINE - ) - ) + await self.ensure_collection(collection, dim) - self.qdrant.upsert( + await asyncio.to_thread( + self.qdrant.upsert, collection_name=collection, points=[ PointStruct( @@ -87,7 +102,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): "chunk_id": chunk_id, } ) - ] + ], ) @staticmethod @@ -124,8 +139,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): try: prefix = f"d_{workspace}_{collection}_" - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -135,7 +151,11 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): logger.info(f"No collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 391c2a04..13dcdba8 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -3,11 +3,13 @@ Accepts entity/vector pairs and writes them to a Qdrant store. """ +import asyncio +import uuid +import logging + from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -import uuid -import logging from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer @@ -50,13 +52,35 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self._cache_lock = asyncio.Lock() + self._known_collections: set[str] = set() # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) + async def ensure_collection(self, collection_name, dim): + async with self._cache_lock: + if collection_name in self._known_collections: + return + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection_name + ) + if not exists: + logger.info( + f"Lazily creating Qdrant collection {collection_name} " + f"with dimension {dim}" + ) + await asyncio.to_thread( + self.qdrant.create_collection, + collection_name=collection_name, + vectors_config=VectorParams( + size=dim, distance=Distance.COSINE + ), + ) + self._known_collections.add(collection_name) + async def store_graph_embeddings(self, workspace, message): - # Validate collection exists in config before processing if not self.collection_exists(workspace, message.metadata.collection): logger.warning( f"Collection {message.metadata.collection} for workspace {workspace} " @@ -75,22 +99,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if not vec: continue - # Create collection name with dimension suffix for lazy creation dim = len(vec) collection = ( f"t_{workspace}_{message.metadata.collection}_{dim}" ) - # Lazily create collection if it doesn't exist (but only if authorized in config) - if not self.qdrant.collection_exists(collection): - logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, - distance=Distance.COSINE - ) - ) + await self.ensure_collection(collection, dim) payload = { "entity": entity_value, @@ -98,7 +112,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if entity.chunk_id: payload["chunk_id"] = entity.chunk_id - self.qdrant.upsert( + await asyncio.to_thread( + self.qdrant.upsert, collection_name=collection, points=[ PointStruct( @@ -106,7 +121,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): vector=vec, payload=payload, ) - ] + ], ) @staticmethod @@ -143,8 +158,9 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): try: prefix = f"t_{workspace}_{collection}_" - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -154,7 +170,11 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): logger.info(f"No collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py index 32d87871..a01629c5 100644 --- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -16,10 +16,10 @@ Payload structure: - text: The text that was embedded (for debugging/display) """ +import asyncio import logging import re import uuid -from typing import Set, Tuple from qdrant_client import QdrantClient from qdrant_client.models import PointStruct, Distance, VectorParams @@ -63,11 +63,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Register config handler for collection management self.register_config_handler(self.on_collection_config, types=["collection"]) - # Cache of created Qdrant collections - self.created_collections: Set[str] = set() - - # Qdrant client self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self._cache_lock = asyncio.Lock() + self._known_collections: set[str] = set() def sanitize_name(self, name: str) -> str: """Sanitize names for Qdrant collection naming""" @@ -85,25 +83,28 @@ class Processor(CollectionConfigHandler, FlowProcessor): safe_schema = self.sanitize_name(schema_name) return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}" - def ensure_collection(self, collection_name: str, dimension: int): + async def ensure_collection(self, collection_name: str, dimension: int): """Create Qdrant collection if it doesn't exist""" - if collection_name in self.created_collections: - return - - if not self.qdrant.collection_exists(collection_name): - logger.info( - f"Creating Qdrant collection {collection_name} " - f"with dimension {dimension}" + async with self._cache_lock: + if collection_name in self._known_collections: + return + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection_name ) - self.qdrant.create_collection( - collection_name=collection_name, - vectors_config=VectorParams( - size=dimension, - distance=Distance.COSINE + if not exists: + logger.info( + f"Creating Qdrant collection {collection_name} " + f"with dimension {dimension}" ) - ) - - self.created_collections.add(collection_name) + await asyncio.to_thread( + self.qdrant.create_collection, + collection_name=collection_name, + vectors_config=VectorParams( + size=dimension, + distance=Distance.COSINE + ), + ) + self._known_collections.add(collection_name) async def on_embeddings(self, msg, consumer, flow): """Process incoming RowEmbeddings and write to Qdrant""" @@ -143,15 +144,14 @@ class Processor(CollectionConfigHandler, FlowProcessor): dimension = len(vector) - # Create/get collection name (lazily on first vector) if qdrant_collection is None: qdrant_collection = self.get_collection_name( workspace, collection, schema_name, dimension ) - self.ensure_collection(qdrant_collection, dimension) + await self.ensure_collection(qdrant_collection, dimension) - # Write to Qdrant - self.qdrant.upsert( + await asyncio.to_thread( + self.qdrant.upsert, collection_name=qdrant_collection, points=[ PointStruct( @@ -163,7 +163,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): "text": row_emb.text } ) - ] + ], ) embeddings_written += 1 @@ -181,8 +181,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): try: prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_" - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -192,8 +193,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.info(f"No Qdrant collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) - self.created_collections.discard(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info( f"Deleted {len(matching_collections)} collection(s) " @@ -217,8 +221,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" ) - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -228,8 +233,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.info(f"No Qdrant collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) - self.created_collections.discard(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") except Exception as e: diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index a5dad748..65eeee06 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -82,7 +82,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Cache of known keyspaces and whether tables exist self.known_keyspaces: Set[str] = set() - self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables + self.tables_initialized: Set[str] = set() # Cache of registered (collection, schema_name) pairs self.registered_partitions: Set[Tuple[str, str]] = set() @@ -94,6 +94,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): self.cluster = None self.session = None + # Protects connection setup and cache mutations + self._setup_lock = asyncio.Lock() + def connect_cassandra(self): """Connect to Cassandra cluster""" if self.session: @@ -126,6 +129,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"for workspace {workspace}" ) + async with self._setup_lock: + return await self._apply_schema_config(workspace, config, version) + + async def _apply_schema_config(self, workspace, config, version): + # Track which schemas changed in this workspace old_schemas = self.schemas.get(workspace, {}) old_schema_names = set(old_schemas.keys()) @@ -391,16 +399,12 @@ class Processor(CollectionConfigHandler, FlowProcessor): schema_name = obj.schema_name source = getattr(obj.metadata, 'source', '') or '' - # Ensure tables exist (sync DDL — push to a worker thread - # so the event loop stays responsive when running in a - # processor group sharing the loop with siblings). - await asyncio.to_thread(self.ensure_tables, keyspace) - - # Register partitions if first time seeing this (collection, schema_name) - await asyncio.to_thread( - self.register_partitions, - keyspace, collection, schema_name, workspace, - ) + async with self._setup_lock: + await asyncio.to_thread(self.ensure_tables, keyspace) + await asyncio.to_thread( + self.register_partitions, + keyspace, collection, schema_name, workspace, + ) safe_keyspace = self.sanitize_name(keyspace) @@ -461,35 +465,27 @@ class Processor(CollectionConfigHandler, FlowProcessor): async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create/verify collection exists in Cassandra row store""" - # Connect if not already connected (sync, push to thread) - await asyncio.to_thread(self.connect_cassandra) - - # Ensure tables exist (sync DDL, push to thread) - await asyncio.to_thread(self.ensure_tables, workspace) + async with self._setup_lock: + await asyncio.to_thread(self.connect_cassandra) + await asyncio.to_thread(self.ensure_tables, workspace) logger.info(f"Collection {collection} ready for workspace {workspace}") async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection using partition tracking""" - # Connect if not already connected - await asyncio.to_thread(self.connect_cassandra) + async with self._setup_lock: + await asyncio.to_thread(self.connect_cassandra) + if workspace not in self.known_keyspaces: + safe_ks = self.sanitize_name(workspace) + check_cql = "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s" + result = await async_execute(self.session, check_cql, (safe_ks,)) + if not result: + logger.info(f"Keyspace {safe_ks} does not exist, nothing to delete") + return + self.known_keyspaces.add(workspace) safe_keyspace = self.sanitize_name(workspace) - # Check if keyspace exists - if workspace not in self.known_keyspaces: - check_keyspace_cql = """ - SELECT keyspace_name FROM system_schema.keyspaces - WHERE keyspace_name = %s - """ - result = await async_execute( - self.session, check_keyspace_cql, (safe_keyspace,) - ) - if not result: - logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete") - return - self.known_keyspaces.add(workspace) - # Discover all partitions for this collection select_partitions_cql = f""" SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions @@ -540,11 +536,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.error(f"Failed to clean up row_partitions for {collection}: {e}") raise - # Clear from local cache - self.registered_partitions = { - (col, sch) for col, sch in self.registered_partitions - if col != collection - } + async with self._setup_lock: + self.registered_partitions = { + (col, sch) for col, sch in self.registered_partitions + if col != collection + } logger.info( f"Deleted collection {collection}: {partitions_deleted} partitions " @@ -553,8 +549,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str): """Delete all data for a specific collection + schema combination""" - # Connect if not already connected - await asyncio.to_thread(self.connect_cassandra) + async with self._setup_lock: + await asyncio.to_thread(self.connect_cassandra) safe_keyspace = self.sanitize_name(workspace) @@ -614,8 +610,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): ) raise - # Clear from local cache - self.registered_partitions.discard((collection, schema_name)) + async with self._setup_lock: + self.registered_partitions.discard((collection, schema_name)) logger.info( f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions " diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index 0774153b..79d6c549 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -4,12 +4,7 @@ Graph writer. Input is graph edge. Writes edges to Cassandra graph. """ import asyncio -import base64 -import os -import argparse -import time import logging -import json from .... direct.cassandra_kg import ( EntityCentricKnowledgeGraph, DEFAULT_GRAPH @@ -28,6 +23,8 @@ default_ident = "triples-write" def serialize_triple(triple): """Serialize a Triple object to JSON for storage.""" + import json + if triple is None: return None @@ -141,156 +138,84 @@ class Processor(CollectionConfigHandler, TriplesStoreService): self.cassandra_host = hosts self.cassandra_username = username self.cassandra_password = password - self.table = None - self.tg = None + + self._connections = {} + self._conn_lock = asyncio.Lock() # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) + async def _get_connection(self, workspace): + async with self._conn_lock: + if workspace not in self._connections: + if self.cassandra_username and self.cassandra_password: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + username=self.cassandra_username, + password=self.cassandra_password, + ) + else: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + ) + self._connections[workspace] = tg + return self._connections[workspace] + async def store_triples(self, workspace, message): - # The cassandra-driver work below — connection, schema - # setup, and per-triple inserts — is all synchronous. - # Wrap the whole batch in a worker thread so the event - # loop stays responsive for sibling processors when - # running in a processor group. + tg = await self._get_connection(workspace) - def _do_store(): + for t in message.triples: + s_val = get_term_value(t.s) + p_val = get_term_value(t.p) + o_val = get_term_value(t.o) + g_val = t.g if t.g is not None else DEFAULT_GRAPH - if self.table is None or self.table != workspace: + otype = get_term_otype(t.o) + dtype = get_term_dtype(t.o) + lang = get_term_lang(t.o) - self.tg = None - - # Use factory function to select implementation - KGClass = EntityCentricKnowledgeGraph - - try: - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password, - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - except Exception as e: - logger.error(f"Exception: {e}", exc_info=True) - time.sleep(1) - raise e - - self.table = workspace - - for t in message.triples: - # Extract values from Term objects - s_val = get_term_value(t.s) - p_val = get_term_value(t.p) - o_val = get_term_value(t.o) - # t.g is None for default graph, or a graph IRI - g_val = t.g if t.g is not None else DEFAULT_GRAPH - - # Extract object type metadata for entity-centric storage - otype = get_term_otype(t.o) - dtype = get_term_dtype(t.o) - lang = get_term_lang(t.o) - - self.tg.insert( - message.metadata.collection, - s_val, - p_val, - o_val, - g=g_val, - otype=otype, - dtype=dtype, - lang=lang, - ) - - await asyncio.to_thread(_do_store) + await tg.async_insert( + message.metadata.collection, + s_val, + p_val, + o_val, + g=g_val, + otype=otype, + dtype=dtype, + lang=lang, + ) async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create a collection in Cassandra triple store via config push""" + try: + tg = await self._get_connection(workspace) - def _do_create(): - # Create or reuse connection for this workspace's keyspace - if self.table is None or self.table != workspace: - self.tg = None - - # Use factory function to select implementation - KGClass = EntityCentricKnowledgeGraph - - try: - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password, - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - except Exception as e: - logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}") - raise - - self.table = workspace - - # Create collection using the built-in method logger.info(f"Creating collection {collection} for workspace {workspace}") - if self.tg.collection_exists(collection): + exists = await tg.async_collection_exists(collection) + if exists: logger.info(f"Collection {collection} already exists") else: - self.tg.create_collection(collection) + await tg.async_create_collection(collection) logger.info(f"Created collection {collection}") - try: - await asyncio.to_thread(_do_create) except Exception as e: logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection from the unified triples table""" + try: + tg = await self._get_connection(workspace) - def _do_delete(): - # Create or reuse connection for this workspace's keyspace - if self.table is None or self.table != workspace: - self.tg = None - - # Use factory function to select implementation - KGClass = EntityCentricKnowledgeGraph - - try: - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password, - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - except Exception as e: - logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}") - raise - - self.table = workspace - - # Delete all triples for this collection using the built-in method - self.tg.delete_collection(collection) + await tg.async_delete_collection(collection) logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}") - try: - await asyncio.to_thread(_do_delete) except Exception as e: logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise From 846282c37521fd75fe1849658710c37d33aa43b1 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 14 May 2026 21:03:09 +0100 Subject: [PATCH 04/16] Fixed error only returning a page of results (#921) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The root cause: async_execute only materialises the first result page (by design — it says so in its docstring). The streaming query set fetch_size=20 and expected to iterate all results, but only got the first 20 rows back. The fix uses asyncio.to_thread(lambda: list(tg.session.execute(...))) which lets the sync driver iterate all pages in a worker thread — exactly what the pre-async code did. --- .../trustgraph/query/triples/cassandra/service.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index 1fadaab3..822dba25 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -17,7 +17,6 @@ from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error from .... schema import Term, Triple, IRI, LITERAL, TRIPLE, BLANK from .... base import TriplesQueryService from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config -from .... tables.cassandra_async import async_execute # Module logger logger = logging.getLogger(__name__) @@ -348,7 +347,12 @@ class Processor(TriplesQueryService): cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {tg.collection_table} WHERE collection = %s" params = [query.collection] statement = SimpleStatement(cql, fetch_size=batch_size) - result_set = await async_execute(tg.session, statement, params) + # async_execute only materialises the first page; + # this query needs all pages, so use sync execute + # in a worker thread where page iteration can block. + result_set = await asyncio.to_thread( + lambda: list(tg.session.execute(statement, params)) + ) else: async for batch, is_final in self._fallback_stream(workspace, query, batch_size): From 01b1fd849d2e7b1621320846ca42719cf25431c4 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Fri, 15 May 2026 12:58:12 +0100 Subject: [PATCH 05/16] Optional test warning suppression (#923) * Fix test collection module errors & silence upstream Pytest warnings (#823) * chore: add virtual environment and .env directories to gitignore * test: filter upstream DeprecationWarning and UserWarning messages * fix(namespace): remove empty __init__.py files to fix PEP 420 implicit namespace routing for trustgraph sub-packages * Revert __init__.py deletions * Add .ini changes but commented out, will be useful at times --------- Co-authored-by: Salil M --- .gitignore | 5 ++++- tests/pytest.ini | 11 ++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 32942156..366edb4a 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,7 @@ trustgraph-vertexai/trustgraph/vertexai_version.py trustgraph-unstructured/trustgraph/unstructured_version.py trustgraph-mcp/trustgraph/mcp_version.py trustgraph/trustgraph/trustgraph_version.py -vertexai/ \ No newline at end of file +vertexai/ +venv/ +.venv/ +.env diff --git a/tests/pytest.ini b/tests/pytest.ini index 5dcc095c..a89759ab 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -16,4 +16,13 @@ markers = unit: marks tests as unit tests contract: marks tests as contract tests (service interface validation) vertexai: marks tests as vertex ai specific tests - asyncio: marks tests that use asyncio \ No newline at end of file + asyncio: marks tests that use asyncio +# This is helpful if you're bored with deprecationwarnings. I prefer to +# keep the warnings for now, it avoids masking problems. +# +# filterwarnings = +# ignore:Core Pydantic V1 functionality isn't compatible with Python 3.14.*:UserWarning +# ignore:builtin type SwigPyPacked has no __module__ attribute:DeprecationWarning +# ignore:builtin type SwigPyObject has no __module__ attribute:DeprecationWarning +# ignore:builtin type swigvarlink has no __module__ attribute:DeprecationWarning +# ignore:.*_UnionGenericAlias.*is deprecated and slated for removal in Python 3.17:DeprecationWarning From 58b5c5c8d571252362e1df566754047ec94e8c84 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Fri, 15 May 2026 13:32:30 +0100 Subject: [PATCH 06/16] fix(openai): fail fast on unrecoverable RateLimitError codes (#901) (#904) (#925) Co-authored-by: Sahil Yadav --- .../model/text_completion/openai/llm.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index 1cefcbe9..df1bfdd1 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -104,7 +104,15 @@ class Processor(LlmService): return resp - except RateLimitError: + except RateLimitError as e: + try: + body = getattr(e, 'body', {}) + if isinstance(body, dict): + code = body.get('error', {}).get('code') + if code in ('insufficient_quota', 'invalid_api_key', 'account_deactivated'): + raise RuntimeError(f"OpenAI unrecoverable error: {code} - {body['error'].get('message', '')}") + except Exception: + pass # Leave rate limit retries to the base handler raise TooManyRequests() @@ -188,7 +196,16 @@ class Processor(LlmService): logger.debug("Streaming complete") - except RateLimitError: + except RateLimitError as e: + try: + body = getattr(e, 'body', {}) + if isinstance(body, dict): + code = body.get('error', {}).get('code') + if code in ('insufficient_quota', 'invalid_api_key', 'account_deactivated'): + logger.warning(f"Hit unrecoverable rate limit error during streaming: {code}") + raise RuntimeError(f"OpenAI unrecoverable error: {code} - {body['error'].get('message', '')}") + except Exception: + pass logger.warning("Hit rate limit during streaming") raise TooManyRequests() From 913f610db5ab46ef5df9c537693aea07defb01c9 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Fri, 15 May 2026 13:35:04 +0100 Subject: [PATCH 07/16] Ensure retry exception is properly raised (#926) --- .../trustgraph/model/text_completion/openai/llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index df1bfdd1..c8ab9c36 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -111,7 +111,7 @@ class Processor(LlmService): code = body.get('error', {}).get('code') if code in ('insufficient_quota', 'invalid_api_key', 'account_deactivated'): raise RuntimeError(f"OpenAI unrecoverable error: {code} - {body['error'].get('message', '')}") - except Exception: + except (ValueError, KeyError, TypeError, AttributeError): pass # Leave rate limit retries to the base handler raise TooManyRequests() @@ -204,7 +204,7 @@ class Processor(LlmService): if code in ('insufficient_quota', 'invalid_api_key', 'account_deactivated'): logger.warning(f"Hit unrecoverable rate limit error during streaming: {code}") raise RuntimeError(f"OpenAI unrecoverable error: {code} - {body['error'].get('message', '')}") - except Exception: + except (ValueError, KeyError, TypeError, AttributeError): pass logger.warning("Hit rate limit during streaming") raise TooManyRequests() From aea4c2df8e27faf9c5a2c2881d0e4b71222b16b5 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Sat, 16 May 2026 11:32:51 +0100 Subject: [PATCH 08/16] fix: library API get/update document round-trip bugs (#893) (#928) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix 5 cascading bugs in the Library API wrapper that prevented the get_documents → update_document round-trip from working: - Tolerate missing title field in document metadata (use .get()) - Use attribute access on Triple objects instead of subscript - Serialize datetime to int seconds for JSON compatibility - Handle empty server response on successful update - Send both id and document-id keys in update request Added library API tests --- tests/unit/test_api/test_library_api.py | 296 ++++++++++++++++++++++ trustgraph-base/trustgraph/api/library.py | 27 +- 2 files changed, 312 insertions(+), 11 deletions(-) create mode 100644 tests/unit/test_api/test_library_api.py diff --git a/tests/unit/test_api/test_library_api.py b/tests/unit/test_api/test_library_api.py new file mode 100644 index 00000000..086ecd63 --- /dev/null +++ b/tests/unit/test_api/test_library_api.py @@ -0,0 +1,296 @@ +""" +Tests for the Library API wrapper round-trip behavior. +Covers the get_documents → update_document path and edge cases +from issue #893. +""" + +import datetime +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.api.library import Library, to_value, from_value +from trustgraph.api.types import DocumentMetadata, Triple +from trustgraph.knowledge import Uri, Literal + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_library(response=None): + api = MagicMock() + api.workspace = "default" + api.request.return_value = response or {} + lib = Library(api) + return lib, api + + +def _wire_triple(s_iri, p_iri, o_val): + return { + "s": {"t": "i", "i": s_iri}, + "p": {"t": "i", "i": p_iri}, + "o": {"t": "l", "v": o_val}, + } + + +def _doc_wire(id="doc-1", time=1700000000, title="Test Doc", + kind="text/plain", comments="", tags=None, + metadata=None, parent_id="", document_type="source", + include_title=True): + doc = { + "id": id, + "time": time, + "kind": kind, + "comments": comments, + "metadata": metadata or [], + "tags": tags or [], + "parent-id": parent_id, + "document-type": document_type, + } + if include_title: + doc["title"] = title + return doc + + +# --------------------------------------------------------------------------- +# Bug 1: get_documents tolerates missing title +# --------------------------------------------------------------------------- + +class TestGetDocumentsMissingTitle: + + def test_missing_title_defaults_to_empty(self): + doc = _doc_wire(include_title=False) + lib, api = _make_library({"document-metadatas": [doc]}) + + result = lib.get_documents() + + assert len(result) == 1 + assert result[0].title == "" + + def test_present_title_preserved(self): + doc = _doc_wire(title="My Title") + lib, api = _make_library({"document-metadatas": [doc]}) + + result = lib.get_documents() + + assert result[0].title == "My Title" + + +# --------------------------------------------------------------------------- +# Bug 2: update_document handles Triple objects (attribute access) +# --------------------------------------------------------------------------- + +class TestUpdateDocumentTripleAccess: + + def test_triple_objects_serialized_correctly(self): + lib, api = _make_library({}) + + metadata = DocumentMetadata( + id="doc-1", + time=datetime.datetime.fromtimestamp(1700000000), + kind="text/plain", + title="Test", + comments="", + metadata=[ + Triple( + s=Uri("http://example.org/entity/alice"), + p=Uri("http://example.org/rel/knows"), + o=Literal("Bob"), + ), + ], + tags=["test"], + ) + + lib.update_document(id="doc-1", metadata=metadata) + + call_args = api.request.call_args[0][1] + triples = call_args["document-metadata"]["metadata"] + + assert len(triples) == 1 + assert triples[0]["s"]["i"] == "http://example.org/entity/alice" + assert triples[0]["p"]["i"] == "http://example.org/rel/knows" + assert triples[0]["o"]["v"] == "Bob" + + def test_empty_metadata_list(self): + lib, api = _make_library({}) + + metadata = DocumentMetadata( + id="doc-1", + time=datetime.datetime.fromtimestamp(1700000000), + kind="text/plain", + title="Test", + comments="", + metadata=[], + tags=[], + ) + + lib.update_document(id="doc-1", metadata=metadata) + + call_args = api.request.call_args[0][1] + assert call_args["document-metadata"]["metadata"] == [] + + +# --------------------------------------------------------------------------- +# Bug 3: update_document serializes datetime to int seconds +# --------------------------------------------------------------------------- + +class TestUpdateDocumentTimeSerialization: + + def test_datetime_serialized_to_int(self): + lib, api = _make_library({}) + + ts = 1700000000 + metadata = DocumentMetadata( + id="doc-1", + time=datetime.datetime.fromtimestamp(ts), + kind="text/plain", + title="Test", + comments="", + metadata=[], + tags=[], + ) + + lib.update_document(id="doc-1", metadata=metadata) + + call_args = api.request.call_args[0][1] + wire_time = call_args["document-metadata"]["time"] + + assert isinstance(wire_time, int) + assert wire_time == ts + + def test_int_time_passed_through(self): + lib, api = _make_library({}) + + metadata = DocumentMetadata( + id="doc-1", + time=1700000000, + kind="text/plain", + title="Test", + comments="", + metadata=[], + tags=[], + ) + + lib.update_document(id="doc-1", metadata=metadata) + + call_args = api.request.call_args[0][1] + assert call_args["document-metadata"]["time"] == 1700000000 + + +# --------------------------------------------------------------------------- +# Bug 4: update_document handles empty server response +# --------------------------------------------------------------------------- + +class TestUpdateDocumentEmptyResponse: + + def test_empty_response_returns_input_metadata(self): + lib, api = _make_library({}) + + metadata = DocumentMetadata( + id="doc-1", + time=datetime.datetime.fromtimestamp(1700000000), + kind="text/plain", + title="Updated Title", + comments="notes", + metadata=[], + tags=["a"], + ) + + result = lib.update_document(id="doc-1", metadata=metadata) + + assert result is metadata + + def test_full_response_parsed(self): + response_doc = _doc_wire( + id="doc-1", title="Server Title", tags=["b"], + ) + lib, api = _make_library({"document-metadata": response_doc}) + + metadata = DocumentMetadata( + id="doc-1", + time=datetime.datetime.fromtimestamp(1700000000), + kind="text/plain", + title="Client Title", + comments="", + metadata=[], + tags=["a"], + ) + + result = lib.update_document(id="doc-1", metadata=metadata) + + assert result.title == "Server Title" + assert result.tags == ["b"] + + +# --------------------------------------------------------------------------- +# Bug 5: update_document sends both id and document-id +# --------------------------------------------------------------------------- + +class TestUpdateDocumentIdKeys: + + def test_both_id_keys_sent(self): + lib, api = _make_library({}) + + metadata = DocumentMetadata( + id="doc-1", + time=datetime.datetime.fromtimestamp(1700000000), + kind="text/plain", + title="Test", + comments="", + metadata=[], + tags=[], + ) + + lib.update_document(id="doc-1", metadata=metadata) + + call_args = api.request.call_args[0][1] + doc_meta = call_args["document-metadata"] + + assert doc_meta["id"] == "doc-1" + assert doc_meta["document-id"] == "doc-1" + + +# --------------------------------------------------------------------------- +# Round-trip: get_documents → update_document +# --------------------------------------------------------------------------- + +class TestGetUpdateRoundTrip: + + def test_full_round_trip(self): + wire_doc = _doc_wire( + id="doc-42", + title="Original", + tags=["v1"], + metadata=[_wire_triple( + "http://example.org/e/1", + "http://example.org/r/type", + "report", + )], + ) + + lib, api = _make_library({"document-metadatas": [wire_doc]}) + + docs = lib.get_documents() + assert len(docs) == 1 + + doc = docs[0] + doc.title = "Updated" + doc.tags.append("v2") + + # Server returns empty on update + api.request.return_value = {} + result = lib.update_document(id=doc.id, metadata=doc) + + # Should not raise, should return the input metadata + assert result.title == "Updated" + assert "v2" in result.tags + + # Verify the wire format sent + call_args = api.request.call_args[0][1] + doc_meta = call_args["document-metadata"] + + assert doc_meta["id"] == "doc-42" + assert doc_meta["title"] == "Updated" + assert isinstance(doc_meta["time"], int) + assert len(doc_meta["metadata"]) == 1 + assert doc_meta["metadata"][0]["o"]["v"] == "report" diff --git a/trustgraph-base/trustgraph/api/library.py b/trustgraph-base/trustgraph/api/library.py index 024e933d..b3506bb7 100644 --- a/trustgraph-base/trustgraph/api/library.py +++ b/trustgraph-base/trustgraph/api/library.py @@ -365,7 +365,7 @@ class Library: id = v["id"], time = datetime.datetime.fromtimestamp(v["time"]), kind = v["kind"], - title = v["title"], + title = v.get("title", ""), comments = v.get("comments", ""), metadata = [ Triple( @@ -482,14 +482,15 @@ class Library: "workspace": self.api.workspace, "document-metadata": { "document-id": id, - "time": metadata.time, + "id": id, + "time": int(metadata.time.timestamp()) if hasattr(metadata.time, 'timestamp') else metadata.time, "title": metadata.title, "comments": metadata.comments, "metadata": [ { - "s": from_value(t["s"]), - "p": from_value(t["p"]), - "o": from_value(t["o"]), + "s": from_value(t.s), + "p": from_value(t.p), + "o": from_value(t.o), } for t in metadata.metadata ], @@ -498,14 +499,17 @@ class Library: } object = self.request(input) - doc = object["document-metadata"] + doc = object.get("document-metadata") if isinstance(object, dict) else None + + if not doc: + return metadata try: - DocumentMetadata( + return DocumentMetadata( id = doc["id"], time = datetime.datetime.fromtimestamp(doc["time"]), kind = doc["kind"], - title = doc["title"], + title = doc.get("title", ""), comments = doc.get("comments", ""), metadata = [ Triple( @@ -513,10 +517,11 @@ class Library: p = to_value(w["p"]), o = to_value(w["o"]) ) - for w in doc["metadata"] + for w in doc.get("metadata", []) ], - workspace = doc.get("workspace", ""), - tags = doc["tags"] + tags = doc.get("tags", []), + parent_id = doc.get("parent-id", ""), + document_type = doc.get("document-type", "source"), ) except Exception as e: logger.error("Failed to parse document update response", exc_info=True) From 38d9c746a8fff98d462d5c8662f2c6ed286b0479 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Sat, 16 May 2026 15:13:38 +0100 Subject: [PATCH 09/16] Fix ontology selector defaults, add bypass mode, enforce domain/range (#929) - Align similarity_threshold default to 0.3 everywhere (class signature had stale 0.7). Fix matching contradiction in tech-spec. - Add bypass_selector_below parameter (default 5) to skip vector similarity selection when ontology element count is small enough. - Enforce domain/range constraints in TripleConverter for object properties and datatype properties, with subclass hierarchy support. Properties with no declared domain/range pass through unchanged. - Add unit tests for domain/range validation, subclass acceptance, polymorphic pass-through, and selector bypass. Fixes #908, #920 --- docs/tech-specs/ontorag.md | 2 +- .../test_triple_converter_validation.py | 389 ++++++++++++++++++ .../trustgraph/extract/kg/ontology/extract.py | 11 +- .../extract/kg/ontology/ontology_selector.py | 52 ++- .../extract/kg/ontology/triple_converter.py | 60 ++- 5 files changed, 501 insertions(+), 13 deletions(-) create mode 100644 tests/unit/test_extract/test_ontology/test_triple_converter_validation.py diff --git a/docs/tech-specs/ontorag.md b/docs/tech-specs/ontorag.md index 86a3cd19..460e72ba 100644 --- a/docs/tech-specs/ontorag.md +++ b/docs/tech-specs/ontorag.md @@ -278,7 +278,7 @@ The system uses **FAISS (Facebook AI Similarity Search)** with IndexFlatIP for e 3. **Similarity Search**: - For each text segment embedding, search the vector store - Retrieve top-k (e.g., 10) most similar ontology elements - - Apply similarity threshold (e.g., 0.7) to filter weak matches + - Apply similarity threshold (e.g., 0.3) to filter weak matches - Aggregate results across all segments, tracking match frequencies 4. **Dependency Resolution**: diff --git a/tests/unit/test_extract/test_ontology/test_triple_converter_validation.py b/tests/unit/test_extract/test_ontology/test_triple_converter_validation.py new file mode 100644 index 00000000..195e8adf --- /dev/null +++ b/tests/unit/test_extract/test_ontology/test_triple_converter_validation.py @@ -0,0 +1,389 @@ +""" +Tests for TripleConverter domain/range enforcement and +OntologySelector bypass for small ontologies. + +Covers fixes for #908 (bypass_selector_below) and #920 (domain/range validation). +""" + +import pytest +from unittest.mock import Mock, AsyncMock + +from trustgraph.extract.kg.ontology.triple_converter import TripleConverter +from trustgraph.extract.kg.ontology.ontology_selector import ( + OntologySelector, + OntologySubset, +) +from trustgraph.extract.kg.ontology.ontology_loader import ( + Ontology, + OntologyClass, + OntologyProperty, +) +from trustgraph.extract.kg.ontology.simplified_parser import ( + Relationship, + Attribute, +) +from trustgraph.extract.kg.ontology.text_processor import TextSegment + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def ontology_subset(): + """Ontology subset with classes, hierarchy, and constrained properties.""" + return OntologySubset( + ontology_id="test", + classes={ + "Person": { + "uri": "http://example.org/Person", + "type": "owl:Class", + "labels": [{"value": "Person"}], + "subclass_of": None, + }, + "Employee": { + "uri": "http://example.org/Employee", + "type": "owl:Class", + "labels": [{"value": "Employee"}], + "subclass_of": "Person", + }, + "Manager": { + "uri": "http://example.org/Manager", + "type": "owl:Class", + "labels": [{"value": "Manager"}], + "subclass_of": "Employee", + }, + "Company": { + "uri": "http://example.org/Company", + "type": "owl:Class", + "labels": [{"value": "Company"}], + "subclass_of": None, + }, + "Product": { + "uri": "http://example.org/Product", + "type": "owl:Class", + "labels": [{"value": "Product"}], + "subclass_of": None, + }, + }, + object_properties={ + "worksFor": { + "uri": "http://example.org/worksFor", + "type": "owl:ObjectProperty", + "labels": [{"value": "works for"}], + "domain": "Person", + "range": "Company", + }, + "manages": { + "uri": "http://example.org/manages", + "type": "owl:ObjectProperty", + "labels": [{"value": "manages"}], + "domain": "Manager", + "range": "Employee", + }, + "relatedTo": { + "uri": "http://example.org/relatedTo", + "type": "owl:ObjectProperty", + "labels": [{"value": "related to"}], + "domain": None, + "range": None, + }, + }, + datatype_properties={ + "employeeId": { + "uri": "http://example.org/employeeId", + "type": "owl:DatatypeProperty", + "labels": [{"value": "employee ID"}], + "domain": "Employee", + }, + "description": { + "uri": "http://example.org/description", + "type": "owl:DatatypeProperty", + "labels": [{"value": "description"}], + "domain": None, + }, + }, + metadata={"name": "Test Ontology"}, + ) + + +@pytest.fixture +def converter(ontology_subset): + return TripleConverter(ontology_subset=ontology_subset, ontology_id="test") + + +# --------------------------------------------------------------------------- +# Domain/range enforcement — relationships +# --------------------------------------------------------------------------- + +class TestRelationshipDomainRange: + + def test_valid_domain_and_range(self, converter): + rel = Relationship( + subject="Alice", subject_type="Person", + relation="worksFor", + object="Acme Corp", object_type="Company", + ) + triple = converter.convert_relationship(rel) + assert triple is not None + + def test_domain_violation_rejected(self, converter): + rel = Relationship( + subject="Widget", subject_type="Product", + relation="worksFor", + object="Acme Corp", object_type="Company", + ) + assert converter.convert_relationship(rel) is None + + def test_range_violation_rejected(self, converter): + rel = Relationship( + subject="Alice", subject_type="Person", + relation="worksFor", + object="Widget", object_type="Product", + ) + assert converter.convert_relationship(rel) is None + + def test_both_domain_and_range_violated(self, converter): + rel = Relationship( + subject="Widget", subject_type="Product", + relation="worksFor", + object="Gadget", object_type="Product", + ) + assert converter.convert_relationship(rel) is None + + +# --------------------------------------------------------------------------- +# Subclass acceptance +# --------------------------------------------------------------------------- + +class TestSubclassAcceptance: + + def test_direct_subclass_matches_domain(self, converter): + """Employee is subclass of Person; worksFor domain is Person.""" + rel = Relationship( + subject="Bob", subject_type="Employee", + relation="worksFor", + object="Acme Corp", object_type="Company", + ) + assert converter.convert_relationship(rel) is not None + + def test_transitive_subclass_matches_domain(self, converter): + """Manager → Employee → Person; worksFor domain is Person.""" + rel = Relationship( + subject="Carol", subject_type="Manager", + relation="worksFor", + object="Acme Corp", object_type="Company", + ) + assert converter.convert_relationship(rel) is not None + + def test_subclass_matches_range(self, converter): + """manages range is Employee; Manager is subclass of Employee.""" + rel = Relationship( + subject="Carol", subject_type="Manager", + relation="manages", + object="Dave", object_type="Manager", + ) + assert converter.convert_relationship(rel) is not None + + def test_superclass_does_not_match_subclass_constraint(self, converter): + """manages domain is Manager; Person is NOT a subclass of Manager.""" + rel = Relationship( + subject="Alice", subject_type="Person", + relation="manages", + object="Bob", object_type="Employee", + ) + assert converter.convert_relationship(rel) is None + + +# --------------------------------------------------------------------------- +# Polymorphic properties (no domain/range) +# --------------------------------------------------------------------------- + +class TestPolymorphicProperties: + + def test_no_domain_no_range_allows_anything(self, converter): + rel = Relationship( + subject="Alice", subject_type="Person", + relation="relatedTo", + object="Acme Corp", object_type="Company", + ) + assert converter.convert_relationship(rel) is not None + + def test_polymorphic_with_unrelated_types(self, converter): + rel = Relationship( + subject="Widget", subject_type="Product", + relation="relatedTo", + object="Bob", object_type="Employee", + ) + assert converter.convert_relationship(rel) is not None + + +# --------------------------------------------------------------------------- +# Datatype property domain enforcement +# --------------------------------------------------------------------------- + +class TestAttributeDomainValidation: + + def test_valid_domain(self, converter): + attr = Attribute( + entity="Bob", entity_type="Employee", + attribute="employeeId", value="E-1234", + ) + assert converter.convert_attribute(attr) is not None + + def test_subclass_matches_domain(self, converter): + """Manager is subclass of Employee; employeeId domain is Employee.""" + attr = Attribute( + entity="Carol", entity_type="Manager", + attribute="employeeId", value="M-5678", + ) + assert converter.convert_attribute(attr) is not None + + def test_domain_violation_rejected(self, converter): + attr = Attribute( + entity="Acme Corp", entity_type="Company", + attribute="employeeId", value="E-0000", + ) + assert converter.convert_attribute(attr) is None + + def test_no_domain_allows_anything(self, converter): + attr = Attribute( + entity="Widget", entity_type="Product", + attribute="description", value="A useful widget", + ) + assert converter.convert_attribute(attr) is not None + + +# --------------------------------------------------------------------------- +# OntologySelector bypass for small ontologies (#908) +# --------------------------------------------------------------------------- + +def _make_ontology(n_classes, n_obj_props=0, n_dt_props=0): + classes = { + f"C{i}": OntologyClass(uri=f"http://example.org/C{i}") + for i in range(n_classes) + } + obj_props = { + f"op{i}": OntologyProperty( + uri=f"http://example.org/op{i}", type="owl:ObjectProperty" + ) + for i in range(n_obj_props) + } + dt_props = { + f"dp{i}": OntologyProperty( + uri=f"http://example.org/dp{i}", type="owl:DatatypeProperty" + ) + for i in range(n_dt_props) + } + return Ontology( + id="tiny", + metadata={"name": "Tiny"}, + classes=classes, + object_properties=obj_props, + datatype_properties=dt_props, + ) + + +def _make_loader(ontology): + loader = Mock() + loader.get_ontology.return_value = ontology + loader.get_all_ontologies.return_value = {"tiny": ontology} + return loader + + +class TestBypassSelectorBelow: + + async def test_bypass_returns_full_ontology(self): + """With 3 elements and bypass_selector_below=5, selector is bypassed.""" + ont = _make_ontology(2, 1, 0) + loader = _make_loader(ont) + embedder = Mock() + + selector = OntologySelector( + ontology_embedder=embedder, + ontology_loader=loader, + bypass_selector_below=5, + ) + + segments = [TextSegment(text="some text", type="sentence", position=0)] + subsets = await selector.select_ontology_subset(segments) + + assert len(subsets) == 1 + assert subsets[0].ontology_id == "tiny" + assert len(subsets[0].classes) == 2 + assert len(subsets[0].object_properties) == 1 + assert subsets[0].relevance_score == 1.0 + # Embedder should never be called + embedder.embed_text.assert_not_called() + + async def test_no_bypass_when_above_threshold(self): + """With 10 elements and bypass_selector_below=5, selector runs normally.""" + ont = _make_ontology(6, 3, 1) + loader = _make_loader(ont) + + embedder = Mock() + embedder.embed_text = AsyncMock(return_value=[0.1, 0.2]) + vector_store = Mock() + vector_store.size.return_value = 10 + vector_store.search.return_value = [] + embedder.get_vector_store.return_value = vector_store + + selector = OntologySelector( + ontology_embedder=embedder, + ontology_loader=loader, + bypass_selector_below=5, + ) + + segments = [TextSegment(text="some text", type="sentence", position=0)] + subsets = await selector.select_ontology_subset(segments) + + # Vector store was consulted (selector ran normally) + vector_store.size.assert_called_once() + + async def test_bypass_at_exact_threshold_not_triggered(self): + """With exactly 5 elements and bypass_selector_below=5, selector runs (< not <=).""" + ont = _make_ontology(3, 1, 1) # total = 5 + loader = _make_loader(ont) + + embedder = Mock() + embedder.embed_text = AsyncMock(return_value=[0.1, 0.2]) + vector_store = Mock() + vector_store.size.return_value = 5 + vector_store.search.return_value = [] + embedder.get_vector_store.return_value = vector_store + + selector = OntologySelector( + ontology_embedder=embedder, + ontology_loader=loader, + bypass_selector_below=5, + ) + + segments = [TextSegment(text="some text", type="sentence", position=0)] + subsets = await selector.select_ontology_subset(segments) + + # Should NOT bypass — 5 is not < 5 + vector_store.size.assert_called_once() + + async def test_bypass_zero_disables(self): + """bypass_selector_below=0 means bypass never triggers.""" + ont = _make_ontology(0, 0, 0) # empty ontology + loader = _make_loader(ont) + + embedder = Mock() + embedder.embed_text = AsyncMock(return_value=[0.1]) + vector_store = Mock() + vector_store.size.return_value = 0 + vector_store.search.return_value = [] + embedder.get_vector_store.return_value = vector_store + + selector = OntologySelector( + ontology_embedder=embedder, + ontology_loader=loader, + bypass_selector_below=0, + ) + + segments = [TextSegment(text="some text", type="sentence", position=0)] + subsets = await selector.select_ontology_subset(segments) + + # 0 is not < 0, so bypass doesn't trigger + vector_store.size.assert_called_once() diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index 1d45d3f9..6a43e547 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -121,6 +121,7 @@ class Processor(FlowProcessor): # Configuration self.top_k = params.get("top_k", 10) self.similarity_threshold = params.get("similarity_threshold", 0.3) + self.bypass_selector_below = params.get("bypass_selector_below", 5) # Per-workspace ontology version tracking self.current_ontology_versions = {} # workspace -> version @@ -187,7 +188,8 @@ class Processor(FlowProcessor): ontology_embedder=ontology_embedder, ontology_loader=loader, top_k=self.top_k, - similarity_threshold=self.similarity_threshold + similarity_threshold=self.similarity_threshold, + bypass_selector_below=self.bypass_selector_below, ) # Store flow-specific components @@ -981,6 +983,13 @@ class Processor(FlowProcessor): default=0.3, help='Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)' ) + parser.add_argument( + '--bypass-selector-below', + type=int, + default=5, + help='Bypass ontology selector when total ontology elements ' + '(classes + properties) is below this value (default: 5)' + ) parser.add_argument( '--triples-batch-size', type=int, diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py index 5111529a..5fd60a0f 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py @@ -33,19 +33,44 @@ class OntologySelector: def __init__(self, ontology_embedder: OntologyEmbedder, ontology_loader: OntologyLoader, top_k: int = 10, - similarity_threshold: float = 0.7): - """Initialize the ontology selector. - - Args: - ontology_embedder: Embedder with vector store - ontology_loader: Loader with ontology definitions - top_k: Number of top results to retrieve per segment - similarity_threshold: Minimum similarity score - """ + similarity_threshold: float = 0.3, + bypass_selector_below: int = 5): self.embedder = ontology_embedder self.loader = ontology_loader self.top_k = top_k self.similarity_threshold = similarity_threshold + self.bypass_selector_below = bypass_selector_below + + def _total_ontology_elements(self) -> int: + total = 0 + for ontology in self.loader.get_all_ontologies().values(): + total += len(ontology.classes) + total += len(ontology.object_properties) + total += len(ontology.datatype_properties) + return total + + def _build_full_subsets(self) -> List[OntologySubset]: + subsets = [] + for ont_id, ontology in self.loader.get_all_ontologies().items(): + subset = OntologySubset( + ontology_id=ont_id, + classes={ + cid: cls.__dict__ + for cid, cls in ontology.classes.items() + }, + object_properties={ + pid: prop.__dict__ + for pid, prop in ontology.object_properties.items() + }, + datatype_properties={ + pid: prop.__dict__ + for pid, prop in ontology.datatype_properties.items() + }, + metadata=ontology.metadata, + relevance_score=1.0, + ) + subsets.append(subset) + return subsets async def select_ontology_subset(self, segments: List[TextSegment]) -> List[OntologySubset]: """Select relevant ontology subsets for text segments. @@ -56,6 +81,15 @@ class OntologySelector: Returns: List of ontology subsets with relevant elements """ + total = self._total_ontology_elements() + if total < self.bypass_selector_below: + logger.info( + f"Ontology has {total} elements (below " + f"bypass_selector_below={self.bypass_selector_below}), " + f"using full ontology" + ) + return self._build_full_subsets() + # Collect all relevant elements relevant_elements = await self._find_relevant_elements(segments) diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py b/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py index 06fff4f4..d9e6c837 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py @@ -6,7 +6,7 @@ with full URIs and correct is_uri flags. """ import logging -from typing import List, Optional +from typing import List, Optional, Set from .... schema import Triple, Term, IRI, LITERAL from .... rdf import RDF_TYPE, RDF_LABEL @@ -32,6 +32,25 @@ class TripleConverter: self.ontology_id = ontology_id self.entity_registry = EntityRegistry(ontology_id) + def _get_ancestor_classes(self, class_id: str) -> Set[str]: + ancestors = set() + current = class_id + while current: + cls_def = self.ontology_subset.classes.get(current) + if not cls_def: + break + parent = cls_def.get("subclass_of") if isinstance(cls_def, dict) else getattr(cls_def, "subclass_of", None) + if not parent or parent in ancestors: + break + ancestors.add(parent) + current = parent + return ancestors + + def _matches_class_constraint(self, actual_type: str, expected_type: str) -> bool: + if actual_type == expected_type: + return True + return expected_type in self._get_ancestor_classes(actual_type) + def convert_all(self, extraction: ExtractionResult) -> List[Triple]: """Convert complete extraction result to RDF triples. @@ -129,6 +148,29 @@ class TripleConverter: logger.warning(f"Unknown relationship '{relationship.relation}', skipping") return None + # Enforce domain/range constraints when declared + prop_def = self.ontology_subset.object_properties.get( + relationship.relation, {} + ) + domain = prop_def.get("domain") if isinstance(prop_def, dict) else getattr(prop_def, "domain", None) + range_ = prop_def.get("range") if isinstance(prop_def, dict) else getattr(prop_def, "range", None) + + if domain and not self._matches_class_constraint(relationship.subject_type, domain): + logger.warning( + f"Domain violation: '{relationship.relation}' expects " + f"domain '{domain}', got subject type " + f"'{relationship.subject_type}', skipping" + ) + return None + + if range_ and not self._matches_class_constraint(relationship.object_type, range_): + logger.warning( + f"Range violation: '{relationship.relation}' expects " + f"range '{range_}', got object type " + f"'{relationship.object_type}', skipping" + ) + return None + # Generate triple: subject property object return Triple( s=Term(type=IRI, iri=subject_uri), @@ -157,11 +199,25 @@ class TripleConverter: logger.warning(f"Unknown attribute '{attribute.attribute}', skipping") return None + # Enforce domain constraint when declared + prop_def = self.ontology_subset.datatype_properties.get( + attribute.attribute, {} + ) + domain = prop_def.get("domain") if isinstance(prop_def, dict) else getattr(prop_def, "domain", None) + + if domain and not self._matches_class_constraint(attribute.entity_type, domain): + logger.warning( + f"Domain violation: attribute '{attribute.attribute}' " + f"expects domain '{domain}', got entity type " + f"'{attribute.entity_type}', skipping" + ) + return None + # Generate triple: entity property "literal value" return Triple( s=Term(type=IRI, iri=entity_uri), p=Term(type=IRI, iri=property_uri), - o=Term(type=LITERAL, value=attribute.value) # Literal! + o=Term(type=LITERAL, value=attribute.value) ) def _get_class_uri(self, class_id: str) -> Optional[str]: From 2b70a1ea8e1c39a706db534fb990d192679c56dc Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Sat, 16 May 2026 16:07:16 +0100 Subject: [PATCH 10/16] Close producers on flow stop to prevent stale non-persistent topics (#930) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Flow.stop() only stopped consumers, leaving response producers connected to non-persistent Pulsar topics. After flow restart, the orphaned producers held stale broker routing state, causing response messages to never reach new consumers — manifesting as 120s timeouts on document-embeddings and similar RPC paths. Fix: Flow.stop() now explicitly stops all producers. Producer.stop() closes the underlying Pulsar producer connection rather than just setting a flag. Fixes #906 --- trustgraph-base/trustgraph/base/flow.py | 2 ++ trustgraph-base/trustgraph/base/producer.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/trustgraph-base/trustgraph/base/flow.py b/trustgraph-base/trustgraph/base/flow.py index 0f42bbe2..3b928d3e 100644 --- a/trustgraph-base/trustgraph/base/flow.py +++ b/trustgraph-base/trustgraph/base/flow.py @@ -34,6 +34,8 @@ class Flow: async def stop(self): for c in self.consumer.values(): await c.stop() + for p in self.producer.values(): + await p.stop() if self.librarian: await self.librarian.stop() diff --git a/trustgraph-base/trustgraph/base/producer.py b/trustgraph-base/trustgraph/base/producer.py index 20b4b0d6..9af9d22e 100644 --- a/trustgraph-base/trustgraph/base/producer.py +++ b/trustgraph-base/trustgraph/base/producer.py @@ -34,6 +34,9 @@ class Producer: async def stop(self): self.running = False + if self.producer: + self.producer.close() + self.producer = None async def send(self, msg, properties={}): From ab83c81d8a06470745e3ac836a601706ea2842c6 Mon Sep 17 00:00:00 2001 From: Mister Lobster Date: Mon, 18 May 2026 04:43:59 -0400 Subject: [PATCH 11/16] fix(gateway): propagate --timeout flag to per-service dispatchers (#931) The api-gateway accepts a --timeout flag (default 600s) but the value was not propagated into DispatcherManager, which hard-coded timeout=120 for every per-service dispatcher (graph-rag, document-rag, text-completion, embeddings, librarian, etc.). This meant any synchronous request taking more than 120 seconds would always return a Timeout error at the 120s mark, regardless of the --timeout value set on the gateway. Changes: - Add timeout parameter to DispatcherManager.__init__ (default: 120 for backward compatibility) - Store self.timeout in DispatcherManager - Replace both hardcoded timeout=120 with self.timeout in invoke_global_service and invoke_flow_service - Pass self.timeout from Api to DispatcherManager in service.py - Document the timeout parameter in the docstring Fixes #894 --- .../trustgraph/gateway/dispatch/manager.py | 14 +++++++++++--- trustgraph-flow/trustgraph/gateway/service.py | 1 + 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index 51161f9b..bddb009d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -135,13 +135,19 @@ class DispatcherWrapper: class DispatcherManager: def __init__(self, backend, config_receiver, auth, - prefix="api-gateway", queue_overrides=None): + prefix="api-gateway", queue_overrides=None, timeout=120): """ ``auth`` is required. It flows into the Mux for first-frame WebSocket authentication and into downstream dispatcher construction. There is no permissive default — constructing a DispatcherManager without an authenticator would be a silent downgrade to no-auth on the socket path. + + ``timeout`` is the per-request timeout in seconds, propagated + to every dispatcher created by this manager. Must match the + gateway's ``--timeout`` flag so that long-running requests + are not prematurely cut off at the old hard-coded 120 s + ceiling. """ if auth is None: raise ValueError( @@ -149,6 +155,8 @@ class DispatcherManager: "is no no-auth mode" ) + self.timeout = timeout + self.backend = backend self.config_receiver = config_receiver self.config_receiver.add_handler(self) @@ -291,7 +299,7 @@ class DispatcherManager: dispatcher = global_dispatchers[kind]( backend = self.backend, - timeout = 120, + timeout = self.timeout, consumer = consumer_name, subscriber = consumer_name, request_queue = request_queue, @@ -448,7 +456,7 @@ class DispatcherManager: backend = self.backend, request_queue = qconfig["request"], response_queue = qconfig["response"], - timeout = 120, + timeout = self.timeout, consumer = f"{self.prefix}-{workspace}-{flow}-{kind}-request", subscriber = f"{self.prefix}-{workspace}-{flow}-{kind}-request", ) diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index 0f6a5070..fb51e1a2 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -119,6 +119,7 @@ class Api: prefix = "gateway", queue_overrides = queue_overrides, auth = self.auth, + timeout = self.timeout, ) self.endpoint_manager = EndpointManager( From da7d10e99540f4d60a18b22162919ae17b6955fa Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Mon, 18 May 2026 14:10:05 +0100 Subject: [PATCH 12/16] feat: add no-auth IAM regime as a drop-in replacement for iam-svc (#933) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds `no-auth-svc`, a lightweight IAM service that permits all access unconditionally — no database, no bootstrap, no signing keys. Deploy it in place of `iam-svc` for development, demos, and single-user setups where authentication overhead is unwanted. The gateway no longer hard-codes a 401 on missing credentials. Instead it asks the IAM regime via a new `authenticate-anonymous` operation whether token-free access is allowed. This keeps the gateway regime-agnostic: `iam-svc` rejects anonymous auth (preserving existing security), while `no-auth-svc` grants it with a configurable default user and workspace. Includes a tech spec (docs/tech-specs/no-auth-regime.md) and tests that pin the safety boundary — malformed tokens never fall through to the anonymous path, and a contract test ensures the full iam-svc always rejects `authenticate-anonymous`. --- docs/tech-specs/no-auth-regime.md | 186 ++++++++++++++++++ tests/unit/test_gateway/test_auth.py | 151 +++++++++++++- tests/unit/test_iam/__init__.py | 0 .../test_iam/test_iam_rejects_anonymous.py | 44 +++++ tests/unit/test_iam/test_noauth_handler.py | 138 +++++++++++++ .../trustgraph/api/async_socket_client.py | 8 +- .../trustgraph/api/socket_client.py | 8 +- trustgraph-base/trustgraph/base/iam_client.py | 16 ++ trustgraph-flow/pyproject.toml | 1 + trustgraph-flow/trustgraph/gateway/auth.py | 24 ++- .../trustgraph/gateway/dispatch/mux.py | 11 +- .../trustgraph/iam/noauth/__init__.py | 1 + .../trustgraph/iam/noauth/__main__.py | 4 + .../trustgraph/iam/noauth/handler.py | 131 ++++++++++++ .../trustgraph/iam/noauth/service.py | 182 +++++++++++++++++ trustgraph-flow/trustgraph/iam/service/iam.py | 3 + 16 files changed, 876 insertions(+), 32 deletions(-) create mode 100644 docs/tech-specs/no-auth-regime.md create mode 100644 tests/unit/test_iam/__init__.py create mode 100644 tests/unit/test_iam/test_iam_rejects_anonymous.py create mode 100644 tests/unit/test_iam/test_noauth_handler.py create mode 100644 trustgraph-flow/trustgraph/iam/noauth/__init__.py create mode 100644 trustgraph-flow/trustgraph/iam/noauth/__main__.py create mode 100644 trustgraph-flow/trustgraph/iam/noauth/handler.py create mode 100644 trustgraph-flow/trustgraph/iam/noauth/service.py diff --git a/docs/tech-specs/no-auth-regime.md b/docs/tech-specs/no-auth-regime.md new file mode 100644 index 00000000..ae8b427f --- /dev/null +++ b/docs/tech-specs/no-auth-regime.md @@ -0,0 +1,186 @@ +--- +layout: default +title: "No-Auth IAM Regime" +parent: "Tech Specs" +--- + +# No-Auth IAM Regime + +## Overview + +A minimal IAM regime that permits all access unconditionally. +Implements the same Pulsar request/response protocol as `iam-svc` +(see [iam-contract.md](iam-contract.md)) so it is a drop-in +replacement: swap `iam-svc` for `no-auth-svc` in the deployment +and the gateway, bootstrapper, and all other components continue +to work without modification. + +Intended for development, testing, single-tenant self-hosted +deployments, and evaluation environments where authentication +overhead is unwanted. + +## Motivation + +The full IAM regime requires Cassandra tables, a bootstrap +sequence, API key management, and signing key rotation. For +many deployments this is unnecessary friction: + +- Local development and CI/CD pipelines. +- Single-user or small-team self-hosted instances. +- Evaluation and demo environments. +- Deployments behind an external authentication proxy + (e.g. OAuth2 reverse proxy, VPN-gated access). + +Today operators who want no auth must still deploy `iam-svc` and +complete the bootstrap ceremony. A purpose-built no-auth regime +eliminates that requirement entirely. + +## Design + +### Deployment + +Replace `iam-svc` with `no-auth-svc` in the processor group or +container configuration. No other services change. The no-auth +service listens on the standard IAM Pulsar topics: + +- Request: `request::iam` +- Response: `response::iam` + +### Dependencies + +None. No database, no config entries, no signing keys, no +bootstrap sequence. + +### Operation responses + +The service implements the IAM contract +([iam-contract.md](iam-contract.md)) with the following +behaviour for each operation: + +| Operation | Behaviour | +|---|---| +| `authenticate-anonymous` | Returns a default identity: `user_id="anonymous"`, `workspace="default"`, `roles=["admin"]`. This is the key operation that distinguishes no-auth from the full regime. | +| `resolve-api-key` | Accepts any token. Returns the same default identity as `authenticate-anonymous`. | +| `authorise` | Always allows. Returns `decision_allow=True`, `decision_ttl_seconds=3600`. | +| `authorise-many` | Always allows all checks. | +| `get-signing-key-public` | Returns an empty string. The gateway skips JWT validation when no key is available. | +| `bootstrap` | No-op. Returns empty admin user/key. | +| `bootstrap-status` | Returns `bootstrap_available=False`. | +| `whoami` | Returns a stub user record for the actor. | +| `login` | Returns empty JWT (not supported under no-auth). | +| `create-user`, `list-users`, `get-user`, `update-user`, `delete-user`, `disable-user`, `enable-user` | Return empty/stub responses. User management is meaningless without auth. | +| `create-workspace`, `list-workspaces`, `get-workspace`, `update-workspace`, `disable-workspace` | Return empty/stub responses. | +| `create-api-key`, `list-api-keys`, `revoke-api-key` | Return empty/stub responses. | +| `change-password`, `reset-password` | No-op. | +| `rotate-signing-key` | No-op. | +| Unknown operation | Returns an error response (same as `iam-svc`). | + +### Workspace resolution + +When `resolve-api-key` is called, the returned workspace +determines which workspace the request operates against. The +no-auth service defaults to `"default"`. + +A configurable `--default-workspace` flag allows operators to +change this without code changes. + +### Anonymous authentication + +A new `authenticate-anonymous` operation is added to the IAM +protocol. This is a small, backward-compatible addition to the +contract: + +**Gateway change** (`auth.py`): when `authenticate()` receives a +request with no `Authorization` header (or an empty bearer +token), instead of immediately returning 401, it sends an +`authenticate-anonymous` request to the IAM service. If the +regime returns a valid identity, the request proceeds. If the +regime returns an error, the gateway returns 401 as before. + +**`iam-svc` (full regime)**: returns `auth-failed` for +`authenticate-anonymous`. Behaviour is unchanged — unauthenticated +requests are rejected exactly as they are today. + +**`no-auth-svc`**: returns the default identity (`anonymous` / +`default` workspace). No token required. + +This keeps the policy decision ("is anonymous access allowed?") +in the IAM regime, not in the gateway. The gateway is a generic +enforcement point that asks and respects the answer. + +**Wire format**: uses the existing `IamRequest` / `IamResponse` +schema with `operation="authenticate-anonymous"`. No new fields +required — the response uses `resolved_user_id`, +`resolved_workspace`, and `resolved_roles`, same as +`resolve-api-key`. + +Requests that do carry a bearer token follow the existing +`resolve-api-key` / JWT paths unchanged. + +## Implementation + +### Service structure + +The service is a standard `AsyncProcessor` that consumes IAM +requests and produces IAM responses, identical in shape to the +existing `iam-svc` processor: + +``` +trustgraph-flow/ + trustgraph/ + iam/ + noauth/ + __init__.py + __main__.py + service.py # AsyncProcessor wiring + handler.py # Operation dispatch, always-allow logic +``` + +### Handler + +The handler is a single `handle(request) -> response` function +with a dispatch table. Each operation returns a pre-built +`IamResponse` with the appropriate fields set. No database +access, no crypto, no state. + +### Configuration + +| Flag | Default | Description | +|---|---|---| +| `--default-workspace` | `"default"` | Workspace returned by `resolve-api-key` | +| `--default-user-id` | `"anonymous"` | User ID returned by `resolve-api-key` | + +### Entry point + +``` +tg-no-auth-svc +``` + +Or via processor group: + +```yaml +- class: trustgraph.iam.noauth.Processor + params: + <<: *defaults + id: no-auth-svc +``` + +## Security considerations + +This regime provides **no security whatsoever**. Any caller with +network access to the API gateway has full admin access to all +workspaces. + +Operators must ensure that network-level controls (firewall, +VPN, private network) provide adequate protection when deploying +this regime. The regime is explicitly not suitable for multi- +tenant or internet-facing deployments. + +## Testing + +- Unit: verify each operation returns the expected stub response. +- Integration: deploy `no-auth-svc` in place of `iam-svc`, confirm + the gateway starts, accepts requests with a dummy bearer token, + and routes them to the default workspace. +- E2E: run the standard e2e test suite with `no-auth-svc` to + confirm no regressions. diff --git a/tests/unit/test_gateway/test_auth.py b/tests/unit/test_gateway/test_auth.py index 26e93fd9..8ffcafa1 100644 --- a/tests/unit/test_gateway/test_auth.py +++ b/tests/unit/test_gateway/test_auth.py @@ -165,22 +165,37 @@ class TestIamAuthDispatch: by shape of the bearer.""" @pytest.mark.asyncio - async def test_no_authorization_header_raises_401(self): + async def test_no_authorization_header_tries_anonymous(self): auth = IamAuth(backend=Mock()) - with pytest.raises(web.HTTPUnauthorized): - await auth.authenticate(make_request(None)) + + async def fake_with_client(op): + raise RuntimeError("auth-failed: anonymous access not permitted") + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request(None)) @pytest.mark.asyncio - async def test_non_bearer_header_raises_401(self): + async def test_non_bearer_header_tries_anonymous(self): auth = IamAuth(backend=Mock()) - with pytest.raises(web.HTTPUnauthorized): - await auth.authenticate(make_request("Basic whatever")) + + async def fake_with_client(op): + raise RuntimeError("auth-failed: anonymous access not permitted") + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Basic whatever")) @pytest.mark.asyncio - async def test_empty_bearer_raises_401(self): + async def test_empty_bearer_tries_anonymous(self): auth = IamAuth(backend=Mock()) - with pytest.raises(web.HTTPUnauthorized): - await auth.authenticate(make_request("Bearer ")) + + async def fake_with_client(op): + raise RuntimeError("auth-failed: anonymous access not permitted") + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Bearer ")) @pytest.mark.asyncio async def test_unknown_format_raises_401(self): @@ -445,3 +460,121 @@ class TestAuthorise: # Different resource → different cache key → two IAM calls. assert calls["n"] == 2 + + +# -- Anonymous authentication boundary ------------------------------------ + + +class TestAnonymousAuthBoundary: + """The gateway must only attempt anonymous auth when no credential + is presented. A malformed token must NOT fall through to the + anonymous path — that would let an attacker bypass a broken token + by simply sending garbage.""" + + @pytest.mark.asyncio + async def test_no_header_attempts_anonymous(self): + auth = IamAuth(backend=Mock()) + + async def fake_with_client(op): + return await op(Mock( + authenticate_anonymous=AsyncMock( + return_value=("anon", "default", ["reader"]), + ) + )) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + ident = await auth.authenticate(make_request(None)) + assert ident.handle == "anon" + assert ident.source == "anonymous" + + @pytest.mark.asyncio + async def test_empty_bearer_attempts_anonymous(self): + auth = IamAuth(backend=Mock()) + + async def fake_with_client(op): + return await op(Mock( + authenticate_anonymous=AsyncMock( + return_value=("anon", "default", ["reader"]), + ) + )) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + ident = await auth.authenticate(make_request("Bearer ")) + assert ident.handle == "anon" + assert ident.source == "anonymous" + + @pytest.mark.asyncio + async def test_malformed_token_does_not_fall_through_to_anonymous(self): + auth = IamAuth(backend=Mock()) + called = {"anonymous": False} + + original = auth._authenticate_anonymous + + async def spy_anonymous(): + called["anonymous"] = True + return await original() + + auth._authenticate_anonymous = spy_anonymous + + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Bearer garbage")) + assert not called["anonymous"] + + @pytest.mark.asyncio + async def test_bad_api_key_does_not_fall_through_to_anonymous(self): + auth = IamAuth(backend=Mock()) + called = {"anonymous": False} + + async def spy_anonymous(): + called["anonymous"] = True + + auth._authenticate_anonymous = spy_anonymous + + async def fake_with_client(op): + raise RuntimeError("auth-failed: unknown key") + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Bearer tg_bad")) + assert not called["anonymous"] + + @pytest.mark.asyncio + async def test_bad_jwt_does_not_fall_through_to_anonymous(self): + auth = IamAuth(backend=Mock()) + auth._signing_public_pem = "not-a-real-pem" + called = {"anonymous": False} + + async def spy_anonymous(): + called["anonymous"] = True + + auth._authenticate_anonymous = spy_anonymous + + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Bearer a.b.c")) + assert not called["anonymous"] + + @pytest.mark.asyncio + async def test_anonymous_rejected_by_iam_raises_401(self): + auth = IamAuth(backend=Mock()) + + async def fake_with_client(op): + raise RuntimeError("auth-failed: anonymous access not permitted") + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request(None)) + + @pytest.mark.asyncio + async def test_anonymous_with_empty_user_id_raises_401(self): + auth = IamAuth(backend=Mock()) + + async def fake_with_client(op): + return await op(Mock( + authenticate_anonymous=AsyncMock( + return_value=("", "default", []), + ) + )) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request(None)) diff --git a/tests/unit/test_iam/__init__.py b/tests/unit/test_iam/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_iam/test_iam_rejects_anonymous.py b/tests/unit/test_iam/test_iam_rejects_anonymous.py new file mode 100644 index 00000000..492b570a --- /dev/null +++ b/tests/unit/test_iam/test_iam_rejects_anonymous.py @@ -0,0 +1,44 @@ +""" +Contract test: the full iam-svc MUST reject authenticate-anonymous. + +This is a safety pin — if someone accidentally adds anonymous access +to the production IAM handler, this test catches it. +""" + +import asyncio +from unittest.mock import Mock, AsyncMock + +import pytest + +from trustgraph.iam.service.iam import IamService + + +def _make_request(**kwargs): + req = Mock() + for k, v in kwargs.items(): + setattr(req, k, v) + return req + + +class TestIamRejectsAnonymous: + + @pytest.fixture + def handler(self): + svc = object.__new__(IamService) + svc.table_store = Mock(spec=[]) + svc.bootstrap_mode = "token" + svc.bootstrap_token = "tok" + svc._on_workspace_created = None + svc._on_workspace_deleted = None + svc._signing_key = None + svc._signing_key_lock = asyncio.Lock() + return svc + + @pytest.mark.asyncio + async def test_authenticate_anonymous_returns_auth_failed(self, handler): + resp = await handler.handle( + _make_request(operation="authenticate-anonymous") + ) + assert resp.error is not None + assert resp.error.type == "auth-failed" + assert "anonymous" in resp.error.message.lower() diff --git a/tests/unit/test_iam/test_noauth_handler.py b/tests/unit/test_iam/test_noauth_handler.py new file mode 100644 index 00000000..38461b62 --- /dev/null +++ b/tests/unit/test_iam/test_noauth_handler.py @@ -0,0 +1,138 @@ +""" +Tests for the no-auth IAM handler. + +Verifies that NoAuthHandler returns the expected permissive responses +and that the always-allow authorise path returns the correct shape. +""" + +import json +from unittest.mock import Mock + +import pytest + +from trustgraph.iam.noauth.handler import NoAuthHandler + + +def _make_request(**kwargs): + req = Mock() + for k, v in kwargs.items(): + setattr(req, k, v) + return req + + +class TestAuthenticateAnonymous: + + @pytest.mark.asyncio + async def test_returns_default_identity(self): + h = NoAuthHandler( + default_user_id="anon", default_workspace="ws", + ) + resp = await h.handle( + _make_request(operation="authenticate-anonymous") + ) + assert resp.error is None + assert resp.resolved_user_id == "anon" + assert resp.resolved_workspace == "ws" + assert "admin" in list(resp.resolved_roles) + + @pytest.mark.asyncio + async def test_custom_defaults_propagate(self): + h = NoAuthHandler( + default_user_id="dev-user", default_workspace="dev-ws", + ) + resp = await h.handle( + _make_request(operation="authenticate-anonymous") + ) + assert resp.resolved_user_id == "dev-user" + assert resp.resolved_workspace == "dev-ws" + + +class TestResolveApiKey: + + @pytest.mark.asyncio + async def test_any_key_resolves_to_default_identity(self): + h = NoAuthHandler() + resp = await h.handle( + _make_request(operation="resolve-api-key", api_key="tg_bogus") + ) + assert resp.error is None + assert resp.resolved_user_id == "anonymous" + assert resp.resolved_workspace == "default" + + +class TestAuthorise: + + @pytest.mark.asyncio + async def test_always_allows(self): + h = NoAuthHandler() + resp = await h.handle( + _make_request( + operation="authorise", + user_id="anyone", + capability="anything", + resource_json="{}", + parameters_json="{}", + ) + ) + assert resp.error is None + assert resp.decision_allow is True + assert resp.decision_ttl_seconds > 0 + + @pytest.mark.asyncio + async def test_authorise_many_returns_matching_count(self): + h = NoAuthHandler() + checks = [ + {"capability": "a", "resource": {}, "parameters": {}}, + {"capability": "b", "resource": {}, "parameters": {}}, + {"capability": "c", "resource": {}, "parameters": {}}, + ] + resp = await h.handle( + _make_request( + operation="authorise-many", + user_id="u", + authorise_checks=json.dumps(checks), + ) + ) + assert resp.error is None + decisions = json.loads(resp.decisions_json) + assert len(decisions) == 3 + assert all(d["allow"] is True for d in decisions) + + +class TestCreateWorkspaceCallback: + + @pytest.mark.asyncio + async def test_create_workspace_calls_callback(self): + called_with = [] + + async def on_created(ws_id): + called_with.append(ws_id) + + h = NoAuthHandler(on_workspace_created=on_created) + req = _make_request(operation="create-workspace") + req.workspace_record = Mock() + req.workspace_record.id = "test-ws" + resp = await h.handle(req) + assert resp.error is None + assert called_with == ["test-ws"] + + @pytest.mark.asyncio + async def test_create_workspace_without_callback_still_succeeds(self): + h = NoAuthHandler() + req = _make_request(operation="create-workspace") + req.workspace_record = Mock() + req.workspace_record.id = "test-ws" + resp = await h.handle(req) + assert resp.error is None + + +class TestUnknownOperation: + + @pytest.mark.asyncio + async def test_unknown_op_returns_error(self): + h = NoAuthHandler() + resp = await h.handle( + _make_request(operation="not-a-real-op") + ) + assert resp.error is not None + assert resp.error.type == "invalid-argument" diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index ca9146b9..d18bee34 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -62,12 +62,6 @@ class AsyncSocketClient: if self._connected: return - if not self.token: - raise ProtocolException( - "AsyncSocketClient requires a token for first-frame " - "auth against /api/v1/socket" - ) - ws_url = self._build_ws_url() self._connect_cm = websockets.connect( ws_url, ping_interval=20, ping_timeout=self.timeout @@ -79,7 +73,7 @@ class AsyncSocketClient: # reader task so the response isn't consumed by the reader's # id-based routing. await self._socket.send(json.dumps({ - "type": "auth", "token": self.token, + "type": "auth", "token": self.token or "", })) try: raw = await asyncio.wait_for( diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 75a7be9a..9874c8af 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -137,12 +137,6 @@ class SocketClient: if self._connected: return - if not self.token: - raise ProtocolException( - "SocketClient requires a token for first-frame auth " - "against /api/v1/socket" - ) - ws_url = self._build_ws_url() self._connect_cm = websockets.connect( ws_url, ping_interval=20, ping_timeout=self.timeout @@ -153,7 +147,7 @@ class SocketClient: # auth-ok / auth-failed response isn't consumed by the reader # loop's id-based routing. await self._socket.send(json.dumps({ - "type": "auth", "token": self.token, + "type": "auth", "token": self.token or "", })) try: raw = await asyncio.wait_for( diff --git a/trustgraph-base/trustgraph/base/iam_client.py b/trustgraph-base/trustgraph/base/iam_client.py index 4be59de1..e0457d19 100644 --- a/trustgraph-base/trustgraph/base/iam_client.py +++ b/trustgraph-base/trustgraph/base/iam_client.py @@ -62,6 +62,22 @@ class IamClient(RequestResponse): ) return resp.user + async def authenticate_anonymous(self, timeout=IAM_TIMEOUT): + """Request anonymous access from the IAM regime. + + Returns ``(user_id, workspace, roles)`` if the regime permits + anonymous access, or raises ``RuntimeError`` with error type + ``auth-failed`` if it does not.""" + resp = await self._request( + operation="authenticate-anonymous", + timeout=timeout, + ) + return ( + resp.resolved_user_id, + resp.resolved_workspace, + list(resp.resolved_roles), + ) + async def resolve_api_key(self, api_key, timeout=IAM_TIMEOUT): """Resolve a plaintext API key to its identity triple. diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index d8c690b5..8488a0a7 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -64,6 +64,7 @@ bootstrap = "trustgraph.bootstrap.bootstrapper:run" config-svc = "trustgraph.config.service:run" flow-svc = "trustgraph.flow.service:run" iam-svc = "trustgraph.iam.service:run" +no-auth-svc = "trustgraph.iam.noauth:run" doc-embeddings-query-milvus = "trustgraph.query.doc_embeddings.milvus:run" doc-embeddings-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run" doc-embeddings-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run" diff --git a/trustgraph-flow/trustgraph/gateway/auth.py b/trustgraph-flow/trustgraph/gateway/auth.py index 1309ecfc..273fcb5a 100644 --- a/trustgraph-flow/trustgraph/gateway/auth.py +++ b/trustgraph-flow/trustgraph/gateway/auth.py @@ -233,10 +233,10 @@ class IamAuth: header = request.headers.get("Authorization", "") if not header.startswith("Bearer "): - raise _auth_failure() + return await self._authenticate_anonymous() token = header[len("Bearer "):].strip() if not token: - raise _auth_failure() + return await self._authenticate_anonymous() # API keys always start with "tg_". JWTs have two dots and # no "tg_" prefix. Discriminate cheaply. @@ -266,6 +266,26 @@ class IamAuth: handle=sub, workspace=ws, principal_id=sub, source="jwt", ) + async def _authenticate_anonymous(self): + try: + async def _call(client): + return await client.authenticate_anonymous() + user_id, workspace, _roles = await self._with_client(_call) + except Exception as e: + logger.debug( + f"Anonymous authentication rejected: " + f"{type(e).__name__}: {e}" + ) + raise _auth_failure() + + if not user_id or not workspace: + raise _auth_failure() + + return Identity( + handle=user_id, workspace=workspace, + principal_id=user_id, source="anonymous", + ) + async def _resolve_api_key(self, plaintext): h = hashlib.sha256(plaintext.encode("utf-8")).hexdigest() diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py index 02c0eed2..bdbd18d8 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py @@ -57,16 +57,13 @@ class Mux: (important for browsers, which treat a handshake-time 401 as terminal).""" token = data.get("token", "") - if not token: - await self.ws.send_json({ - "type": "auth-failed", - "error": "auth failure", - }) - return class _Shim: def __init__(self, tok): - self.headers = {"Authorization": f"Bearer {tok}"} + self.headers = ( + {"Authorization": f"Bearer {tok}"} if tok + else {} + ) try: identity = await self.auth.authenticate(_Shim(token)) diff --git a/trustgraph-flow/trustgraph/iam/noauth/__init__.py b/trustgraph-flow/trustgraph/iam/noauth/__init__.py new file mode 100644 index 00000000..98f4d9da --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/noauth/__init__.py @@ -0,0 +1 @@ +from . service import * diff --git a/trustgraph-flow/trustgraph/iam/noauth/__main__.py b/trustgraph-flow/trustgraph/iam/noauth/__main__.py new file mode 100644 index 00000000..a731dd63 --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/noauth/__main__.py @@ -0,0 +1,4 @@ + +from . service import run + +run() diff --git a/trustgraph-flow/trustgraph/iam/noauth/handler.py b/trustgraph-flow/trustgraph/iam/noauth/handler.py new file mode 100644 index 00000000..d457697e --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/noauth/handler.py @@ -0,0 +1,131 @@ +""" +No-auth IAM handler. Implements the IAM contract with every operation +returning a permissive or stub response. No database, no crypto, +no state. +""" + +import json +import logging + +from trustgraph.schema import IamResponse, Error, UserRecord + +logger = logging.getLogger(__name__) + + +def _err(type, message): + return IamResponse(error=Error(type=type, message=message)) + + +class NoAuthHandler: + + def __init__(self, default_user_id="anonymous", + default_workspace="default", + on_workspace_created=None): + self.default_user_id = default_user_id + self.default_workspace = default_workspace + self._on_workspace_created = on_workspace_created + + def _default_identity_response(self): + return IamResponse( + resolved_user_id=self.default_user_id, + resolved_workspace=self.default_workspace, + resolved_roles=["admin"], + ) + + def _default_user_record(self): + return UserRecord( + id=self.default_user_id, + workspace=self.default_workspace, + username=self.default_user_id, + name="Anonymous User", + roles=["admin"], + enabled=True, + ) + + async def handle(self, v): + op = v.operation + + try: + if op == "authenticate-anonymous": + return self._default_identity_response() + + if op == "resolve-api-key": + return self._default_identity_response() + + if op == "authorise": + return IamResponse( + decision_allow=True, + decision_ttl_seconds=3600, + ) + + if op == "authorise-many": + checks = json.loads(v.authorise_checks or "[]") + decisions = [ + {"allow": True, "ttl": 3600} + for _ in checks + ] + return IamResponse( + decisions_json=json.dumps(decisions), + ) + + if op == "get-signing-key-public": + return IamResponse(signing_key_public="") + + if op == "bootstrap": + return IamResponse() + + if op == "bootstrap-status": + return IamResponse(bootstrap_available=False) + + if op == "whoami": + return IamResponse(user=self._default_user_record()) + + if op == "login": + return IamResponse() + + if op in ( + "create-user", "get-user", "update-user", + "disable-user", "enable-user", + ): + return IamResponse(user=self._default_user_record()) + + if op == "list-users": + return IamResponse(users=[self._default_user_record()]) + + if op == "delete-user": + return IamResponse() + + if op == "create-workspace": + if self._on_workspace_created and v.workspace_record: + await self._on_workspace_created(v.workspace_record.id) + return IamResponse() + + if op in ( + "get-workspace", "update-workspace", + "disable-workspace", + ): + return IamResponse() + + if op == "list-workspaces": + return IamResponse() + + if op in ("create-api-key", "list-api-keys", "revoke-api-key"): + return IamResponse() + + if op in ("change-password", "reset-password"): + return IamResponse() + + if op == "rotate-signing-key": + return IamResponse() + + return _err( + "invalid-argument", + f"unknown operation: {op!r}", + ) + + except Exception as e: + logger.error( + f"no-auth {op} failed: {type(e).__name__}: {e}", + exc_info=True, + ) + return _err("internal-error", str(e)) diff --git a/trustgraph-flow/trustgraph/iam/noauth/service.py b/trustgraph-flow/trustgraph/iam/noauth/service.py new file mode 100644 index 00000000..76d13a3c --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/noauth/service.py @@ -0,0 +1,182 @@ +""" +No-auth IAM service. Drop-in replacement for iam-svc that permits +all access unconditionally. No database, no bootstrap, no signing keys. +""" + +import logging +import uuid + +from trustgraph.schema import Error +from trustgraph.schema import IamRequest, IamResponse +from trustgraph.schema import iam_request_queue, iam_response_queue +from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigValue +from trustgraph.schema import config_request_queue, config_response_queue + +from trustgraph.base import AsyncProcessor, Consumer, Producer +from trustgraph.base import ConsumerMetrics, ProducerMetrics +from trustgraph.base.metrics import SubscriberMetrics +from trustgraph.base.request_response_spec import RequestResponse + +from . handler import NoAuthHandler + +logger = logging.getLogger(__name__) + +default_ident = "no-auth-svc" + +default_iam_request_queue = iam_request_queue +default_iam_response_queue = iam_response_queue + + +class Processor(AsyncProcessor): + + def __init__(self, **params): + + iam_req_q = params.get( + "iam_request_queue", default_iam_request_queue, + ) + iam_resp_q = params.get( + "iam_response_queue", default_iam_response_queue, + ) + + default_user_id = params.get("default_user_id", "anonymous") + default_workspace = params.get("default_workspace", "default") + + super().__init__(**params) + + iam_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="iam-request", + ) + iam_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="iam-response", + ) + + self.iam_request_topic = iam_req_q + + self.iam_request_consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=iam_req_q, + subscriber=self.id, + schema=IamRequest, + handler=self.on_iam_request, + metrics=iam_request_metrics, + ) + + self.iam_response_producer = Producer( + backend=self.pubsub, + topic=iam_resp_q, + schema=IamResponse, + metrics=iam_response_metrics, + ) + + self.handler = NoAuthHandler( + default_user_id=default_user_id, + default_workspace=default_workspace, + on_workspace_created=self._ensure_workspace_registered, + ) + + logger.info( + f"No-auth IAM service initialised " + f"(user={default_user_id}, workspace={default_workspace})" + ) + + async def start(self): + await self.pubsub.ensure_topic(self.iam_request_topic) + await self.iam_request_consumer.start() + + def _create_config_client(self): + config_rr_id = str(uuid.uuid4()) + config_req_metrics = ProducerMetrics( + processor=self.id, flow=None, name="config-request", + ) + config_resp_metrics = SubscriberMetrics( + processor=self.id, flow=None, name="config-response", + ) + return RequestResponse( + backend=self.pubsub, + subscription=f"{self.id}--config--{config_rr_id}", + consumer_name=self.id, + request_topic=config_request_queue, + request_schema=ConfigRequest, + request_metrics=config_req_metrics, + response_topic=config_response_queue, + response_schema=ConfigResponse, + response_metrics=config_resp_metrics, + ) + + async def _ensure_workspace_registered(self, workspace_id): + client = self._create_config_client() + try: + await client.start() + await client.request( + ConfigRequest( + operation="put", + workspace="__workspaces__", + values=[ConfigValue( + type="workspace", key=workspace_id, + value='{"enabled": true}', + )], + ), + timeout=10, + ) + finally: + await client.stop() + logger.info( + f"Registered workspace in config: {workspace_id}" + ) + + async def on_iam_request(self, msg, consumer, flow): + + id = None + try: + v = msg.value() + id = msg.properties()["id"] + logger.debug( + f"Handling IAM request {id} op={v.operation!r}" + ) + resp = await self.handler.handle(v) + await self.iam_response_producer.send( + resp, properties={"id": id}, + ) + except Exception as e: + logger.error( + f"IAM request failed: {type(e).__name__}: {e}", + exc_info=True, + ) + resp = IamResponse( + error=Error(type="internal-error", message=str(e)), + ) + if id is not None: + await self.iam_response_producer.send( + resp, properties={"id": id}, + ) + + @staticmethod + def add_args(parser): + AsyncProcessor.add_args(parser) + + parser.add_argument( + "--iam-request-queue", + default=default_iam_request_queue, + help=f"IAM request queue (default: {default_iam_request_queue})", + ) + parser.add_argument( + "--iam-response-queue", + default=default_iam_response_queue, + help=f"IAM response queue (default: {default_iam_response_queue})", + ) + parser.add_argument( + "--default-user-id", + default="anonymous", + help="User ID for all requests (default: anonymous)", + ) + parser.add_argument( + "--default-workspace", + default="default", + help="Workspace for all requests (default: default)", + ) + + +def run(): + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/iam/service/iam.py b/trustgraph-flow/trustgraph/iam/service/iam.py index 755a1c5d..7beaf5ed 100644 --- a/trustgraph-flow/trustgraph/iam/service/iam.py +++ b/trustgraph-flow/trustgraph/iam/service/iam.py @@ -287,6 +287,9 @@ class IamService: op = v.operation try: + if op == "authenticate-anonymous": + return _err("auth-failed", "anonymous access not permitted") + if op == "bootstrap": return await self.handle_bootstrap(v) if op == "bootstrap-status": From 76e3358ed31f559efb186b0967a512fe25b0a7d8 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Mon, 18 May 2026 14:19:19 +0100 Subject: [PATCH 13/16] fix: guard against empty query in SPARQL generator (#934) Split the query once and check the parts list before indexing, preventing an IndexError if the LLM returns an empty or whitespace-only string. Fixes #870. --- .../trustgraph/query/ontology/sparql_generator.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/trustgraph-flow/trustgraph/query/ontology/sparql_generator.py b/trustgraph-flow/trustgraph/query/ontology/sparql_generator.py index 44c7e0a1..97fc5f4d 100644 --- a/trustgraph-flow/trustgraph/query/ontology/sparql_generator.py +++ b/trustgraph-flow/trustgraph/query/ontology/sparql_generator.py @@ -202,11 +202,14 @@ ASK {{ if response and isinstance(response, dict): query = response.get('query', '').strip() - if query.upper().startswith(('SELECT', 'ASK', 'CONSTRUCT', 'DESCRIBE')): + parts = query.split() + if parts and parts[0].upper() in ( + 'SELECT', 'ASK', 'CONSTRUCT', 'DESCRIBE', + ): return SPARQLQuery( query=query, variables=self._extract_variables(query), - query_type=query.split()[0].upper(), + query_type=parts[0].upper(), explanation=response.get('explanation', 'Generated by LLM'), complexity_score=self._calculate_complexity(query) ) From 29d3100c46a971c93ac979e2c51c0a77bb4176db Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Mon, 18 May 2026 22:08:12 +0100 Subject: [PATCH 14/16] fix: IAM bootstrap atomicity and bootstrapper startup ordering (#935) IAM auto-bootstrap could get permanently stuck in a half-done state: _seed_tables wrote the workspace first, so any_workspace_exists() returned true on restart even when user/key/signing-key creation had failed. Remove workspace creation from _seed_tables (WorkspaceInit handles it) and use any_signing_key_exists() as the completion check since the signing key is the last thing written. Run pre-service initialisers (PulsarTopology) in start() before opening pub/sub connections, breaking the chicken-and-egg where the bootstrapper needed Pulsar namespaces that it was responsible for creating. Guard against empty cluster list when broker isn't ready. --- .../bootstrap/bootstrapper/service.py | 74 ++++++++++++++----- .../bootstrap/initialisers/pulsar_topology.py | 4 + trustgraph-flow/trustgraph/iam/service/iam.py | 27 +++---- trustgraph-flow/trustgraph/tables/iam.py | 4 + 4 files changed, 74 insertions(+), 35 deletions(-) diff --git a/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py b/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py index 3c658fe3..81b7e98d 100644 --- a/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py +++ b/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py @@ -326,6 +326,58 @@ class Processor(AsyncProcessor): # Main loop. # ------------------------------------------------------------------ + async def _run_pre_service(self): + """Run pre-service initialisers before opening pub/sub clients. + + These bring up infrastructure that other services depend on + (e.g. Pulsar tenant/namespaces). They use out-of-band APIs + (HTTP admin), not pub/sub, so they don't need a config client. + They run without flag tracking — they must be idempotent. + """ + pre_specs = [ + s for s in self.specs + if not s.instance.wait_for_services + ] + if not pre_specs: + return + + for spec in pre_specs: + child_logger = logger.getChild(spec.name) + child_ctx = InitContext( + logger=child_logger, + config=None, + make_flow_client=self._make_flow_client, + make_iam_client=self._make_iam_client, + ) + child_logger.info(f"Running pre-service initialiser") + try: + await spec.instance.run(child_ctx, None, spec.flag) + child_logger.info(f"Pre-service initialiser completed") + except Exception as e: + child_logger.error( + f"Pre-service initialiser failed: " + f"{type(e).__name__}: {e}", + exc_info=True, + ) + raise + + async def start(self): + # Run pre-service initialisers before opening any pub/sub + # connections. They bring up infrastructure (Pulsar + # namespaces, etc.) that super().start() depends on. + while self.running: + try: + await self._run_pre_service() + break + except Exception as e: + logger.info( + f"Pre-service initialisation failed " + f"({type(e).__name__}: {e}); retry in {GATE_BACKOFF}s" + ) + await asyncio.sleep(GATE_BACKOFF) + + await super().start() + async def run(self): logger.info( @@ -347,29 +399,18 @@ class Processor(AsyncProcessor): continue try: - # Phase 1: pre-service initialisers run unconditionally. - pre_specs = [ - s for s in self.specs - if not s.instance.wait_for_services - ] - pre_results = {} - for spec in pre_specs: - pre_results[spec.name] = await self._run_spec( - spec, config, - ) - - # Phase 2: gate. + # Phase 1: gate. gate_ok = await self._gate_ready(config) - # Phase 3: post-service initialisers, if gate passed. - post_results = {} + # Phase 2: post-service initialisers, if gate passed. + results = {} if gate_ok: post_specs = [ s for s in self.specs if s.instance.wait_for_services ] for spec in post_specs: - post_results[spec.name] = await self._run_spec( + results[spec.name] = await self._run_spec( spec, config, ) @@ -377,8 +418,7 @@ class Processor(AsyncProcessor): if not gate_ok: sleep_for = GATE_BACKOFF else: - all_results = {**pre_results, **post_results} - if any(r != "skip" for r in all_results.values()): + if any(r != "skip" for r in results.values()): sleep_for = INIT_RETRY else: sleep_for = STEADY_INTERVAL diff --git a/trustgraph-flow/trustgraph/bootstrap/initialisers/pulsar_topology.py b/trustgraph-flow/trustgraph/bootstrap/initialisers/pulsar_topology.py index 843fe056..1e4805de 100644 --- a/trustgraph-flow/trustgraph/bootstrap/initialisers/pulsar_topology.py +++ b/trustgraph-flow/trustgraph/bootstrap/initialisers/pulsar_topology.py @@ -112,6 +112,10 @@ class PulsarTopology(Initialiser): def _reconcile_sync(self, logger): if not self._tenant_exists(): clusters = self._get_clusters() + if not clusters: + raise RuntimeError( + "Pulsar cluster list is empty — broker not ready yet" + ) logger.info( f"Creating tenant {self.tenant!r} with clusters {clusters}" ) diff --git a/trustgraph-flow/trustgraph/iam/service/iam.py b/trustgraph-flow/trustgraph/iam/service/iam.py index 7beaf5ed..0335012e 100644 --- a/trustgraph-flow/trustgraph/iam/service/iam.py +++ b/trustgraph-flow/trustgraph/iam/service/iam.py @@ -397,8 +397,8 @@ class IamService: async def auto_bootstrap_if_token_mode(self): """Called from the service processor at startup. In - ``token`` mode, if tables are empty, seeds the default - workspace / admin / signing key using the operator-provided + ``token`` mode, if tables are empty, seeds the admin user, + API key, and signing key using the operator-provided bootstrap token. The admin's API key plaintext is *the* ``bootstrap_token`` — the operator already knows it, nothing needs to be returned or logged. @@ -408,7 +408,7 @@ class IamService: if self.bootstrap_mode != "token": return - if await self.table_store.any_workspace_exists(): + if await self.table_store.any_signing_key_exists(): logger.info( "IAM: token mode, tables already populated; skipping " "auto-bootstrap" @@ -423,22 +423,13 @@ class IamService: async def _seed_tables(self, api_key_plaintext): """Shared seeding logic used by token-mode auto-bootstrap and - bootstrap-mode handle_bootstrap. Creates the default - workspace, admin user, admin API key (using the given - plaintext), and an initial signing key. Returns the admin + bootstrap-mode handle_bootstrap. Creates the admin user, + admin API key (using the given plaintext), and an initial + signing key. The workspace is created separately by the + bootstrapper's WorkspaceInit initialiser. Returns the admin user id.""" now = _now_dt() - await self.table_store.put_workspace( - id=DEFAULT_WORKSPACE, - name="Default", - enabled=True, - created=now, - ) - - if self._on_workspace_created: - await self._on_workspace_created(DEFAULT_WORKSPACE) - admin_user_id = str(uuid.uuid4()) admin_password = secrets.token_urlsafe(32) await self.table_store.put_user( @@ -491,7 +482,7 @@ class IamService: if self.bootstrap_mode != "bootstrap": return _err("auth-failed", "auth failure") - if await self.table_store.any_workspace_exists(): + if await self.table_store.any_signing_key_exists(): return _err("auth-failed", "auth failure") plaintext = _generate_api_key() @@ -531,7 +522,7 @@ class IamService: instead of forcing callers to probe the masked-failure path.""" available = ( self.bootstrap_mode == "bootstrap" - and not await self.table_store.any_workspace_exists() + and not await self.table_store.any_signing_key_exists() ) return IamResponse(bootstrap_available=available) diff --git a/trustgraph-flow/trustgraph/tables/iam.py b/trustgraph-flow/trustgraph/tables/iam.py index 8bf9c8b4..d7bf5e3d 100644 --- a/trustgraph-flow/trustgraph/tables/iam.py +++ b/trustgraph-flow/trustgraph/tables/iam.py @@ -435,3 +435,7 @@ class IamTableStore: async def any_workspace_exists(self): rows = await self.list_workspaces() return bool(rows) + + async def any_signing_key_exists(self): + rows = await self.list_signing_keys() + return bool(rows) From 47dfc30c1cc0f6c1895ebafca1200408bc526c0c Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Mon, 18 May 2026 22:08:52 +0100 Subject: [PATCH 15/16] fix: suppress Pulsar C++ client log noise (#936) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Revert consumer receive timeout from 100ms back to the original 2000ms. The 100ms change was based on a misunderstanding — receive() is a blocking call that returns immediately when a message arrives, so the timeout only affects how quickly a consumer checks the shutdown flag during idle periods. 100ms generated ~200 WARN lines/sec from the C++ client with no latency benefit. Also set the Pulsar C++ client logger to Error level so residual timeout warnings from the subscriber (250ms) don't produce noise. Update poll timeout test to match reverted 2000ms value --- .../test_consumer_concurrency.py | 17 ++++++++--------- trustgraph-base/trustgraph/base/consumer.py | 2 +- .../trustgraph/base/pulsar_backend.py | 4 ++++ 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/unit/test_concurrency/test_consumer_concurrency.py b/tests/unit/test_concurrency/test_consumer_concurrency.py index 59c7f2b5..44d82182 100644 --- a/tests/unit/test_concurrency/test_consumer_concurrency.py +++ b/tests/unit/test_concurrency/test_consumer_concurrency.py @@ -272,23 +272,22 @@ class TestMetricsIntegration: class TestPollTimeout: @pytest.mark.asyncio - async def test_poll_timeout_is_100ms(self): - """Consumer receive timeout should be 100ms, not the original 2000ms. + async def test_poll_timeout_is_2000ms(self): + """Consumer receive timeout should be 2000ms. - A 2000ms poll timeout means every service adds up to 2s of idle - blocking between message bursts. With many sequential hops in a - query pipeline, this compounds into seconds of unnecessary latency. - 100ms keeps responsiveness high without significant CPU overhead. + receive() is a blocking call that returns immediately when a + message arrives — the timeout only governs how often the loop + checks the shutdown flag during idle periods. Lower values + (e.g. 100ms) generate excessive C++ client WARN logging with + no latency benefit. """ consumer = _make_consumer() - # Wire up a mock Pulsar consumer that records the receive kwargs mock_pulsar_consumer = MagicMock() received_kwargs = {} def capture_receive(**kwargs): received_kwargs.update(kwargs) - # Stop after one call consumer.running = False raise type('Timeout', (Exception,), {})("timeout") @@ -296,7 +295,7 @@ class TestPollTimeout: await consumer.consume_from_queue(mock_pulsar_consumer) - assert received_kwargs.get("timeout_millis") == 100 + assert received_kwargs.get("timeout_millis") == 2000 # --------------------------------------------------------------------------- diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index 5c59c515..86cc4ceb 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -188,7 +188,7 @@ class Consumer: try: msg = await loop.run_in_executor( executor, - lambda: consumer.receive(timeout_millis=100), + lambda: consumer.receive(timeout_millis=2000), ) except Exception as e: # Handle timeout from any backend diff --git a/trustgraph-base/trustgraph/base/pulsar_backend.py b/trustgraph-base/trustgraph/base/pulsar_backend.py index e27d16af..dc5e4083 100644 --- a/trustgraph-base/trustgraph/base/pulsar_backend.py +++ b/trustgraph-base/trustgraph/base/pulsar_backend.py @@ -139,6 +139,10 @@ class PulsarBackend: if api_key: client_args['authentication'] = pulsar.AuthenticationToken(api_key) + client_args['logger'] = pulsar.ConsoleLogger( + _pulsar.LoggerLevel.Error + ) + self.client = pulsar.Client(**client_args) logger.info(f"Pulsar client connected to {host}") From fd6e3e1269d954cd76b5047a3ad97ff678e6e534 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 19 May 2026 13:26:39 +0100 Subject: [PATCH 16/16] fix: stop dropping messages on Pulsar flow restarts (#938) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit consumer.py called unsubscribe() on every flow stop, deleting the server-side subscription cursor. On restart, initial_position='latest' skipped any messages published during the gap — causing intermittent data loss (e.g. graph embeddings silently never reaching Qdrant). Replace unsubscribe() with close() so the cursor survives restarts. Move subscription cleanup to where it belongs: the Pulsar backend's delete_topic(), called by the flow controller on deliberate flow deletion. This was previously a no-op TODO. --- trustgraph-base/trustgraph/base/consumer.py | 26 ++-- trustgraph-base/trustgraph/base/pubsub.py | 10 ++ .../trustgraph/base/pulsar_backend.py | 136 ++++++++++++++++-- 3 files changed, 151 insertions(+), 21 deletions(-) diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index 86cc4ceb..b9f2ee0b 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -76,8 +76,10 @@ class Consumer: if hasattr(self, "consumer"): if self.consumer: - self.consumer.unsubscribe() - self.consumer.close() + try: + self.consumer.close() + except Exception: + pass self.consumer = None async def stop(self): @@ -157,12 +159,14 @@ class Consumer: except Exception as e: logger.error(f"Consumer loop exception: {e}", exc_info=True) - for c in consumers: + for i, c in enumerate(consumers): try: - c.unsubscribe() c.close() - except Exception: - pass + except Exception as ce: + logger.warning( + f"Consumer {i} close failed (error path): " + f"{type(ce).__name__}: {ce}" + ) for ex in executors: ex.shutdown(wait=False) consumers = [] @@ -171,12 +175,14 @@ class Consumer: continue finally: - for c in consumers: + for i, c in enumerate(consumers): try: - c.unsubscribe() c.close() - except Exception: - pass + except Exception as ce: + logger.warning( + f"Consumer {i} close failed: " + f"{type(ce).__name__}: {ce}" + ) for ex in executors: ex.shutdown(wait=False) diff --git a/trustgraph-base/trustgraph/base/pubsub.py b/trustgraph-base/trustgraph/base/pubsub.py index fb4765c1..4ae8d2d0 100644 --- a/trustgraph-base/trustgraph/base/pubsub.py +++ b/trustgraph-base/trustgraph/base/pubsub.py @@ -10,6 +10,7 @@ logger = logging.getLogger(__name__) # Default connection settings from environment DEFAULT_PULSAR_HOST = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') DEFAULT_PULSAR_API_KEY = os.getenv("PULSAR_API_KEY", None) +DEFAULT_PULSAR_ADMIN_URL = os.getenv("PULSAR_ADMIN_URL", 'http://pulsar:8080') DEFAULT_RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", 'rabbitmq') DEFAULT_RABBITMQ_PORT = int(os.getenv("RABBITMQ_PORT", '5672')) @@ -43,6 +44,7 @@ def get_pubsub(**config: Any) -> Any: host=config.get('pulsar_host', DEFAULT_PULSAR_HOST), api_key=config.get('pulsar_api_key', DEFAULT_PULSAR_API_KEY), listener=config.get('pulsar_listener'), + admin_url=config.get('pulsar_admin_url', DEFAULT_PULSAR_ADMIN_URL), ) elif backend_type == 'rabbitmq': from .rabbitmq_backend import RabbitMQBackend @@ -77,6 +79,7 @@ def get_pubsub(**config: Any) -> Any: STANDALONE_PULSAR_HOST = 'pulsar://localhost:6650' +STANDALONE_PULSAR_ADMIN_URL = 'http://localhost:8080' def add_pubsub_args(parser: ArgumentParser, standalone: bool = False) -> None: @@ -88,6 +91,7 @@ def add_pubsub_args(parser: ArgumentParser, standalone: bool = False) -> None: that run outside containers) """ pulsar_host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST + pulsar_admin_url = STANDALONE_PULSAR_ADMIN_URL if standalone else DEFAULT_PULSAR_ADMIN_URL pulsar_listener = 'localhost' if standalone else None rabbitmq_host = 'localhost' if standalone else DEFAULT_RABBITMQ_HOST kafka_bootstrap = 'localhost:9092' if standalone else DEFAULT_KAFKA_BOOTSTRAP @@ -105,6 +109,12 @@ def add_pubsub_args(parser: ArgumentParser, standalone: bool = False) -> None: help=f'Pulsar host (default: {pulsar_host})', ) + parser.add_argument( + '--pulsar-admin-url', + default=pulsar_admin_url, + help=f'Pulsar admin REST API URL (default: {pulsar_admin_url})', + ) + parser.add_argument( '--pulsar-api-key', default=DEFAULT_PULSAR_API_KEY, diff --git a/trustgraph-base/trustgraph/base/pulsar_backend.py b/trustgraph-base/trustgraph/base/pulsar_backend.py index dc5e4083..e85dfbef 100644 --- a/trustgraph-base/trustgraph/base/pulsar_backend.py +++ b/trustgraph-base/trustgraph/base/pulsar_backend.py @@ -7,8 +7,12 @@ handling topic mapping, serialization, and Pulsar client management. import pulsar import _pulsar +import asyncio import json import logging +import urllib.request +import urllib.error +import urllib.parse from typing import Any from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message @@ -117,7 +121,10 @@ class PulsarBackend: producers and consumers. """ - def __init__(self, host: str, api_key: str = None, listener: str = None): + def __init__( + self, host: str, api_key: str = None, listener: str = None, + admin_url: str = None, + ): """ Initialize Pulsar backend. @@ -125,10 +132,12 @@ class PulsarBackend: host: Pulsar broker URL (e.g., pulsar://localhost:6650) api_key: Optional API key for authentication listener: Optional listener name for multi-homed setups + admin_url: Pulsar admin REST API URL (e.g., http://pulsar:8080) """ self.host = host self.api_key = api_key self.listener = listener + self.admin_url = admin_url # Create Pulsar client client_args = {'service_url': host} @@ -270,24 +279,129 @@ class PulsarBackend: return PulsarBackendConsumer(pulsar_consumer, schema) + def _admin_api_path(self, pulsar_uri: str) -> str: + """ + Convert a Pulsar topic URI to an admin REST API path. + + persistent://tg/flow/triples-store:default:explain-flow + -> /admin/v2/persistent/tg/flow/triples-store%3Adefault%3Aexplain-flow + """ + scheme, rest = pulsar_uri.split('://', 1) + tenant, namespace, topic = rest.split('/', 2) + encoded_topic = urllib.parse.quote(topic, safe='') + return f"/admin/v2/{scheme}/{tenant}/{namespace}/{encoded_topic}" + + def _admin_request(self, method, path): + """ + Make a synchronous admin REST API request. + + Returns parsed JSON for GET, None for DELETE/PUT. + Raises urllib.error.HTTPError for non-404 errors. + 404 is treated as success (idempotent deletion). + """ + url = f"{self.admin_url}{path}" + req = urllib.request.Request(url, method=method) + + try: + with urllib.request.urlopen(req) as resp: + if method == 'GET': + return json.loads(resp.read().decode('utf-8')) + return None + except urllib.error.HTTPError as e: + if e.code == 404: + return None + raise + + def _delete_topic_sync(self, topic: str): + """ + Delete a persistent topic and all its subscriptions. + + Subscriptions must be removed first — Pulsar rejects topic + deletion while subscriptions exist. Force-deletes each + subscription to disconnect any lingering consumers. + """ + pulsar_uri = self.map_topic(topic) + + if pulsar_uri.startswith('non-persistent://'): + return + + api_path = self._admin_api_path(pulsar_uri) + + try: + subs = self._admin_request('GET', f"{api_path}/subscriptions") + except Exception as e: + logger.warning(f"Failed to list subscriptions for {topic}: {e}") + return + + if subs: + for sub in subs: + encoded_sub = urllib.parse.quote(sub, safe='') + try: + self._admin_request( + 'DELETE', + f"{api_path}/subscription/{encoded_sub}" + f"?force=true" + ) + logger.info( + f"Deleted subscription {sub} from {topic}" + ) + except Exception as e: + logger.warning( + f"Failed to delete subscription {sub} " + f"from {topic}: {e}" + ) + + try: + self._admin_request('DELETE', api_path) + logger.info(f"Deleted topic: {topic}") + except Exception as e: + logger.warning(f"Failed to delete topic {topic}: {e}") + + def _topic_exists_sync(self, topic: str) -> bool: + """Check topic existence via admin API.""" + pulsar_uri = self.map_topic(topic) + + if pulsar_uri.startswith('non-persistent://'): + return False + + api_path = self._admin_api_path(pulsar_uri) + + try: + result = self._admin_request('GET', f"{api_path}/stats") + return result is not None + except Exception: + return False + async def create_topic(self, topic: str) -> None: - """No-op — Pulsar auto-creates topics on first use. - TODO: Use admin REST API for explicit persistent topic creation.""" + """No-op — Pulsar auto-creates topics on first use.""" pass async def delete_topic(self, topic: str) -> None: - """No-op — to be replaced with admin REST API calls. - TODO: Delete persistent topic via admin API.""" - pass + """ + Delete a persistent topic and all its subscriptions via + the admin REST API. + + Called by the flow controller during deliberate flow deletion. + Non-persistent topics are skipped. Idempotent. + """ + if not self.admin_url: + logger.warning( + f"Cannot delete topic {topic}: " + f"no admin URL configured" + ) + return + + await asyncio.to_thread(self._delete_topic_sync, topic) async def topic_exists(self, topic: str) -> bool: - """Returns True — Pulsar auto-creates on subscribe. - TODO: Use admin REST API for actual existence check.""" - return True + """Check whether a persistent topic exists via the admin API.""" + if not self.admin_url: + return True + + return await asyncio.to_thread(self._topic_exists_sync, topic) async def ensure_topic(self, topic: str) -> None: - """No-op — Pulsar auto-creates topics on first use. - TODO: Use admin REST API for explicit creation.""" + """No-op — Pulsar auto-creates topics on first use.""" pass def close(self) -> None: