From f0ad282708e876624aa76fde188bb3ca5e7bb3df Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 14 May 2026 10:30:21 +0100 Subject: [PATCH] CLI auth migration, document embeddings core lifecycle (#913) Migrate get_kg_core and put_kg_core CLI tools to use Api/SocketClient with first-frame auth (fixes broken raw websocket path). Fix wire format field names (root/vector). Remove ~600 lines of dead raw websocket code from invoke_graph_rag.py. Add document embeddings core lifecycle to the knowledge service: list/get/put/delete/load operations across schema, translator, Cassandra table store, knowledge manager, gateway registry, REST API, socket client, and CLI (tg-get-de-core, tg-put-de-core). Fix delete_kg_core to also clean up document embeddings rows. --- trustgraph-base/trustgraph/api/knowledge.py | 31 + .../trustgraph/api/socket_client.py | 52 ++ .../messaging/translators/knowledge.py | 56 +- .../trustgraph/schema/knowledge/knowledge.py | 6 +- trustgraph-cli/pyproject.toml | 2 + trustgraph-cli/trustgraph/cli/get_de_core.py | 111 ++++ trustgraph-cli/trustgraph/cli/get_kg_core.py | 77 +-- .../trustgraph/cli/invoke_graph_rag.py | 604 ------------------ trustgraph-cli/trustgraph/cli/put_de_core.py | 119 ++++ trustgraph-cli/trustgraph/cli/put_kg_core.py | 99 +-- trustgraph-flow/trustgraph/cores/knowledge.py | 325 +++++++--- trustgraph-flow/trustgraph/cores/service.py | 5 + .../trustgraph/gateway/registry.py | 6 + .../trustgraph/tables/knowledge.py | 94 +++ 14 files changed, 762 insertions(+), 825 deletions(-) create mode 100644 trustgraph-cli/trustgraph/cli/get_de_core.py create mode 100644 trustgraph-cli/trustgraph/cli/put_de_core.py diff --git a/trustgraph-base/trustgraph/api/knowledge.py b/trustgraph-base/trustgraph/api/knowledge.py index c3ec2308..06357d70 100644 --- a/trustgraph-base/trustgraph/api/knowledge.py +++ b/trustgraph-base/trustgraph/api/knowledge.py @@ -132,3 +132,34 @@ class Knowledge: self.request(request = input) + def list_de_cores(self): + + input = { + "operation": "list-de-cores", + "workspace": self.api.workspace, + } + + return self.request(request = input)["ids"] + + def delete_de_core(self, id): + + input = { + "operation": "delete-de-core", + "workspace": self.api.workspace, + "id": id, + } + + self.request(request = input) + + def load_de_core(self, id, flow="default", collection="default"): + + input = { + "operation": "load-de-core", + "workspace": self.api.workspace, + "id": id, + "flow": flow, + "collection": collection, + } + + self.request(request = input) + diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index aeb15f85..75a7be9a 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -491,6 +491,58 @@ class SocketClient: triples=raw_triples, ) + def get_kg_core(self, id: str) -> Iterator[Dict[str, Any]]: + request = { + "operation": "get-kg-core", + "workspace": self.workspace, + "id": id, + } + for response in self._send_request_sync( + "knowledge", None, request, streaming_raw=True, + ): + if response.get("eos"): + break + yield response + + def put_kg_core( + self, id: str, triples=None, graph_embeddings=None, + ) -> Dict[str, Any]: + request = { + "operation": "put-kg-core", + "workspace": self.workspace, + "id": id, + } + if triples is not None: + request["triples"] = triples + if graph_embeddings is not None: + request["graph-embeddings"] = graph_embeddings + return self._send_request_sync("knowledge", None, request) + + def get_de_core(self, id: str) -> Iterator[Dict[str, Any]]: + request = { + "operation": "get-de-core", + "workspace": self.workspace, + "id": id, + } + for response in self._send_request_sync( + "knowledge", None, request, streaming_raw=True, + ): + if response.get("eos"): + break + yield response + + def put_de_core( + self, id: str, document_embeddings=None, + ) -> Dict[str, Any]: + request = { + "operation": "put-de-core", + "workspace": self.workspace, + "id": id, + } + if document_embeddings is not None: + request["document-embeddings"] = document_embeddings + return self._send_request_sync("knowledge", None, request) + def close(self) -> None: """Close the persistent WebSocket connection.""" if self._loop and not self._loop.is_closed(): diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py index f2cc8e46..3830bf59 100644 --- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py +++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py @@ -1,6 +1,7 @@ from typing import Dict, Any, Tuple, Optional from ...schema import ( KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings, + DocumentEmbeddings, ChunkEmbeddings, Metadata, EntityEmbeddings ) from .base import MessageTranslator @@ -43,6 +44,23 @@ class KnowledgeRequestTranslator(MessageTranslator): ] ) + document_embeddings = None + if "document-embeddings" in data: + document_embeddings = DocumentEmbeddings( + metadata=Metadata( + id=data["document-embeddings"]["metadata"]["id"], + root=data["document-embeddings"]["metadata"].get("root", ""), + collection=data["document-embeddings"]["metadata"]["collection"] + ), + chunks=[ + ChunkEmbeddings( + chunk_id=ch["chunk_id"], + vector=ch["vector"], + ) + for ch in data["document-embeddings"]["chunks"] + ] + ) + return KnowledgeRequest( operation=data.get("operation"), id=data.get("id"), @@ -50,6 +68,7 @@ class KnowledgeRequestTranslator(MessageTranslator): collection=data.get("collection"), triples=triples, graph_embeddings=graph_embeddings, + document_embeddings=document_embeddings, ) def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]: @@ -90,6 +109,22 @@ class KnowledgeRequestTranslator(MessageTranslator): ], } + if obj.document_embeddings: + result["document-embeddings"] = { + "metadata": { + "id": obj.document_embeddings.metadata.id, + "root": obj.document_embeddings.metadata.root, + "collection": obj.document_embeddings.metadata.collection, + }, + "chunks": [ + { + "chunk_id": ch.chunk_id, + "vector": ch.vector, + } + for ch in obj.document_embeddings.chunks + ], + } + return result @@ -140,6 +175,25 @@ class KnowledgeResponseTranslator(MessageTranslator): } } + # Streaming document embeddings response + if obj.document_embeddings: + return { + "document-embeddings": { + "metadata": { + "id": obj.document_embeddings.metadata.id, + "root": obj.document_embeddings.metadata.root, + "collection": obj.document_embeddings.metadata.collection, + }, + "chunks": [ + { + "chunk_id": ch.chunk_id, + "vector": ch.vector, + } + for ch in obj.document_embeddings.chunks + ], + } + } + # End of stream marker if obj.eos is True: return {"eos": True} @@ -155,7 +209,7 @@ class KnowledgeResponseTranslator(MessageTranslator): is_final = ( obj.ids is not None or # List response obj.eos is True or # End of stream - (not obj.triples and not obj.graph_embeddings) # Empty response + (not obj.triples and not obj.graph_embeddings and not obj.document_embeddings) # Empty response ) return response, is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py index 64cb7082..a3879103 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py +++ b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py @@ -4,7 +4,7 @@ from ..core.topic import queue from ..core.metadata import Metadata from .document import Document, TextDocument from .graph import Triples -from .embeddings import GraphEmbeddings +from .embeddings import GraphEmbeddings, DocumentEmbeddings # get-kg-core # -> (???) @@ -41,6 +41,9 @@ class KnowledgeRequest: triples: Triples | None = None graph_embeddings: GraphEmbeddings | None = None + # put-de-core + document_embeddings: DocumentEmbeddings | None = None + @dataclass class KnowledgeResponse: error: Error | None = None @@ -48,6 +51,7 @@ class KnowledgeResponse: eos: bool = False # Indicates end of knowledge core stream triples: Triples | None = None graph_embeddings: GraphEmbeddings | None = None + document_embeddings: DocumentEmbeddings | None = None knowledge_request_queue = queue('knowledge', cls='request') knowledge_response_queue = queue('knowledge', cls='response') diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index e8062fba..10dca2e8 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -37,6 +37,7 @@ tg-dump-msgpack = "trustgraph.cli.dump_msgpack:main" tg-dump-queues = "trustgraph.cli.dump_queues:main" tg-monitor-prompts = "trustgraph.cli.monitor_prompts:main" tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main" +tg-get-de-core = "trustgraph.cli.get_de_core:main" tg-get-kg-core = "trustgraph.cli.get_kg_core:main" tg-get-document-content = "trustgraph.cli.get_document_content:main" tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main" @@ -77,6 +78,7 @@ tg-load-turtle = "trustgraph.cli.load_turtle:main" tg-load-knowledge = "trustgraph.cli.load_knowledge:main" tg-load-structured-data = "trustgraph.cli.load_structured_data:main" tg-put-flow-blueprint = "trustgraph.cli.put_flow_blueprint:main" +tg-put-de-core = "trustgraph.cli.put_de_core:main" tg-put-kg-core = "trustgraph.cli.put_kg_core:main" tg-remove-library-document = "trustgraph.cli.remove_library_document:main" tg-save-doc-embeds = "trustgraph.cli.save_doc_embeds:main" diff --git a/trustgraph-cli/trustgraph/cli/get_de_core.py b/trustgraph-cli/trustgraph/cli/get_de_core.py new file mode 100644 index 00000000..caf74ba9 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/get_de_core.py @@ -0,0 +1,111 @@ +""" +Uses the knowledge service to fetch a document embeddings core which is +saved to a local file in msgpack format. +""" + +import argparse +import os +import msgpack + +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +def write_de(f, data): + msg = ( + "de", + { + "m": { + "i": data["metadata"]["id"], + "m": data["metadata"]["root"], + "c": data["metadata"]["collection"], + }, + "c": [ + { + "i": ch["chunk_id"], + "v": ch["vector"], + } + for ch in data["chunks"] + ] + } + ) + f.write(msgpack.packb(msg, use_bin_type=True)) + +def fetch(url, workspace, id, output, token=None): + + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + + try: + de = 0 + + with open(output, "wb") as f: + + for response in socket.get_de_core(id): + + if "document-embeddings" in response: + de += 1 + write_de(f, response["document-embeddings"]) + + print(f"Got: {de} document embeddings messages.") + + finally: + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-get-de-core', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + + parser.add_argument( + '--id', '--identifier', + required=True, + help=f'Document embeddings core ID', + ) + + parser.add_argument( + '-o', '--output', + required=True, + help=f'Output file' + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + args = parser.parse_args() + + try: + + fetch( + url=args.url, + workspace=args.workspace, + id=args.id, + output=args.output, + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/get_kg_core.py b/trustgraph-cli/trustgraph/cli/get_kg_core.py index 8bee4115..b4f37b81 100644 --- a/trustgraph-cli/trustgraph/cli/get_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/get_kg_core.py @@ -5,13 +5,11 @@ to a local file in msgpack format. import argparse import os -import uuid -import asyncio -import json -from websockets.asyncio.client import connect import msgpack -default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") @@ -21,7 +19,7 @@ def write_triple(f, data): { "m": { "i": data["metadata"]["id"], - "m": data["metadata"]["metadata"], + "m": data["metadata"]["root"], "c": data["metadata"]["collection"], }, "t": data["triples"], @@ -35,13 +33,13 @@ def write_ge(f, data): { "m": { "i": data["metadata"]["id"], - "m": data["metadata"]["metadata"], + "m": data["metadata"]["root"], "c": data["metadata"]["collection"], }, "e": [ { "e": ent["entity"], - "v": ent["vectors"], + "v": ent["vector"], } for ent in data["entities"] ] @@ -49,54 +47,18 @@ def write_ge(f, data): ) f.write(msgpack.packb(msg, use_bin_type=True)) -async def fetch(url, workspace, id, output, token=None): +def fetch(url, workspace, id, output, token=None): - if not url.endswith("/"): - url += "/" - - url = url + "api/v1/socket" - - if token: - url = f"{url}?token={token}" - - mid = str(uuid.uuid4()) - - async with connect(url) as ws: - - req = json.dumps({ - "id": mid, - "workspace": workspace, - "service": "knowledge", - "request": { - "operation": "get-kg-core", - "workspace": workspace, - "id": id, - } - }) - - await ws.send(req) + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + try: ge = 0 t = 0 with open(output, "wb") as f: - while True: - - msg = await ws.recv() - - obj = json.loads(msg) - - if "response" not in obj: - raise RuntimeError("No response?") - - response = obj["response"] - - if "error" in response: - raise RuntimeError(obj["error"]) - - if "eos" in response: - if response["eos"]: break + for response in socket.get_kg_core(id): if "triples" in response: t += 1 @@ -108,7 +70,8 @@ async def fetch(url, workspace, id, output, token=None): print(f"Got: {t} triple, {ge} GE messages.") - await ws.close() + finally: + socket.close() def main(): @@ -151,14 +114,12 @@ def main(): try: - asyncio.run( - fetch( - url=args.url, - workspace=args.workspace, - id=args.id, - output=args.output, - token=args.token, - ) + fetch( + url=args.url, + workspace=args.workspace, + id=args.id, + output=args.output, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index 23d6bcac..f39cdab0 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -3,11 +3,8 @@ Uses the GraphRAG service to answer a question """ import argparse -import json import os import sys -import websockets -import asyncio from trustgraph.api import ( Api, ExplainabilityClient, @@ -31,607 +28,6 @@ default_max_path_length = 2 default_edge_score_limit = 30 default_edge_limit = 25 -# Provenance predicates -TG = "https://trustgraph.ai/ns/" -TG_QUERY = TG + "query" -TG_CONCEPT = TG + "concept" -TG_ENTITY = TG + "entity" -TG_EDGE_COUNT = TG + "edgeCount" -TG_SELECTED_EDGE = TG + "selectedEdge" -TG_EDGE = TG + "edge" -TG_REASONING = TG + "reasoning" -TG_DOCUMENT = TG + "document" -TG_CONTAINS = TG + "contains" -PROV = "http://www.w3.org/ns/prov#" -PROV_STARTED_AT_TIME = PROV + "startedAtTime" -PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" -RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" - - -def _get_event_type(prov_id): - """Extract event type from provenance_id""" - if "question" in prov_id: - return "question" - elif "grounding" in prov_id: - return "grounding" - elif "exploration" in prov_id: - return "exploration" - elif "focus" in prov_id: - return "focus" - elif "synthesis" in prov_id: - return "synthesis" - return "provenance" - - -def _format_provenance_details(event_type, triples): - """Format provenance details based on event type and triples""" - lines = [] - - if event_type == "question": - # Show query and timestamp - for s, p, o in triples: - if p == TG_QUERY: - lines.append(f" Query: {o}") - elif p == PROV_STARTED_AT_TIME: - lines.append(f" Time: {o}") - - elif event_type == "grounding": - # Show extracted concepts - concepts = [o for s, p, o in triples if p == TG_CONCEPT] - if concepts: - lines.append(f" Concepts: {len(concepts)}") - for concept in concepts: - lines.append(f" - {concept}") - - elif event_type == "exploration": - # Show edge count (seed entities resolved separately with labels) - for s, p, o in triples: - if p == TG_EDGE_COUNT: - lines.append(f" Edges explored: {o}") - - elif event_type == "focus": - # For focus, just count edge selection URIs - # The actual edge details are fetched separately via edge_selections parameter - edge_sel_uris = [] - for s, p, o in triples: - if p == TG_SELECTED_EDGE: - edge_sel_uris.append(o) - if edge_sel_uris: - lines.append(f" Focused on {len(edge_sel_uris)} edge(s)") - - elif event_type == "synthesis": - # Show document reference (content already streamed) - for s, p, o in triples: - if p == TG_DOCUMENT: - lines.append(f" Document: {o}") - - return lines - - -async def _query_triples_once(ws_url, flow_id, prov_id, collection, graph=None, debug=False): - """Query triples for a provenance node (single attempt)""" - request = { - "id": "triples-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": prov_id}, - "collection": collection, - "limit": 100 - } - } - # Add graph filter if specified (for named graph queries) - if graph is not None: - request["request"]["g"] = graph - - if debug: - print(f" [debug] querying triples for s={prov_id}", file=sys.stderr) - - triples = [] - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if debug: - print(f" [debug] response: {json.dumps(response)[:200]}", file=sys.stderr) - - if response.get("id") != "triples-request": - continue - - if "error" in response: - if debug: - print(f" [debug] error: {response['error']}", file=sys.stderr) - break - - if "response" in response: - resp = response["response"] - # Handle triples response - # Response format: {"response": [triples...]} - # Each triple uses compact keys: "i" for iri, "v" for value, "t" for type - triple_list = resp.get("response", []) - for t in triple_list: - s = t.get("s", {}).get("i", t.get("s", {}).get("v", "")) - p = t.get("p", {}).get("i", t.get("p", {}).get("v", "")) - # Handle quoted triples (type "t") and regular values - o_term = t.get("o", {}) - if o_term.get("t") == "t": - # Quoted triple - extract s, p, o from nested structure - tr = o_term.get("tr", {}) - o = { - "s": tr.get("s", {}).get("i", ""), - "p": tr.get("p", {}).get("i", ""), - "o": tr.get("o", {}).get("i", tr.get("o", {}).get("v", "")), - } - else: - o = o_term.get("i", o_term.get("v", "")) - triples.append((s, p, o)) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception: {e}", file=sys.stderr) - - if debug: - print(f" [debug] got {len(triples)} triples", file=sys.stderr) - - return triples - - -async def _query_triples(ws_url, flow_id, prov_id, collection, graph=None, max_retries=5, retry_delay=0.2, debug=False): - """Query triples for a provenance node with retries for race condition""" - for attempt in range(max_retries): - triples = await _query_triples_once(ws_url, flow_id, prov_id, collection, graph=graph, debug=debug) - if triples: - return triples - # Wait before retry if empty (triples may not be stored yet) - if attempt < max_retries - 1: - if debug: - print(f" [debug] retry {attempt + 1}/{max_retries}...", file=sys.stderr) - await asyncio.sleep(retry_delay) - return [] - - -async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, collection, debug=False): - """ - Query for provenance of an edge (s, p, o) in the knowledge graph. - - Finds subgraphs that contain the edge via tg:contains, then follows - prov:wasDerivedFrom to find source documents. - - Returns list of source URIs (chunks, pages, documents). - """ - # Query for subgraphs that contain this edge: ?subgraph tg:contains <> - request = { - "id": "edge-prov-request", - "service": "triples", - "flow": flow_id, - "request": { - "p": {"t": "i", "i": TG_CONTAINS}, - "o": { - "t": "t", # Quoted triple type - "tr": { - "s": {"t": "i", "i": edge_s}, - "p": {"t": "i", "i": edge_p}, - "o": {"t": "i", "i": edge_o} if edge_o.startswith("http") or edge_o.startswith("urn:") else {"t": "l", "v": edge_o}, - } - }, - "collection": collection, - "limit": 10 - } - } - - if debug: - print(f" [debug] querying edge provenance for ({edge_s}, {edge_p}, {edge_o})", file=sys.stderr) - - stmt_uris = [] - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "edge-prov-request": - continue - - if "error" in response: - if debug: - print(f" [debug] error: {response['error']}", file=sys.stderr) - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - for t in triple_list: - s = t.get("s", {}).get("i", "") - if s: - stmt_uris.append(s) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying edge provenance: {e}", file=sys.stderr) - - if debug: - print(f" [debug] found {len(stmt_uris)} reifying statements", file=sys.stderr) - - # For each statement, query wasDerivedFrom to find sources - sources = [] - for stmt_uri in stmt_uris: - # Query: stmt_uri prov:wasDerivedFrom ?source - request = { - "id": "derived-from-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": stmt_uri}, - "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, - "collection": collection, - "limit": 10 - } - } - - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "derived-from-request": - continue - - if "error" in response: - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - for t in triple_list: - o = t.get("o", {}).get("i", "") - if o: - sources.append(o) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying wasDerivedFrom: {e}", file=sys.stderr) - - if debug: - print(f" [debug] found {len(sources)} source(s): {sources}", file=sys.stderr) - - return sources - - -async def _query_derived_from(ws_url, flow_id, uri, collection, debug=False): - """Query for the prov:wasDerivedFrom parent of a URI. Returns None if no parent.""" - request = { - "id": "parent-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": uri}, - "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, - "collection": collection, - "limit": 1 - } - } - - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "parent-request": - continue - - if "error" in response: - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - if triple_list: - return triple_list[0].get("o", {}).get("i", None) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying parent: {e}", file=sys.stderr) - - return None - - -async def _trace_provenance_chain(ws_url, flow_id, source_uri, collection, label_cache, debug=False): - """ - Trace the full provenance chain from a source URI up to the root document. - Returns a list of (uri, label) tuples from leaf to root. - """ - chain = [] - current = source_uri - max_depth = 10 # Prevent infinite loops - - for _ in range(max_depth): - if not current: - break - - # Get label for current entity - label = await _query_label(ws_url, flow_id, current, collection, label_cache, debug) - chain.append((current, label)) - - # Get parent - parent = await _query_derived_from(ws_url, flow_id, current, collection, debug) - if not parent or parent == current: - break - current = parent - - return chain - - -def _format_provenance_chain(chain): - """ - Format a provenance chain as a human-readable string. - Chain is [(uri, label), ...] from leaf to root. - """ - if not chain: - return "" - - # Show labels, from leaf to root - labels = [label for uri, label in chain] - return " → ".join(labels) - - -def _is_iri(value): - """Check if a value looks like an IRI.""" - if not isinstance(value, str): - return False - return value.startswith("http://") or value.startswith("https://") or value.startswith("urn:") - - -async def _query_label(ws_url, flow_id, iri, collection, label_cache, debug=False): - """ - Query for the rdfs:label of an IRI. - Uses label_cache to avoid repeated queries. - Returns the label if found, otherwise returns the IRI. - """ - if not _is_iri(iri): - return iri - - # Check cache first - if iri in label_cache: - return label_cache[iri] - - request = { - "id": "label-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": iri}, - "p": {"t": "i", "i": RDFS_LABEL}, - "collection": collection, - "limit": 1 - } - } - - label = iri # Default to IRI if no label found - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "label-request": - continue - - if "error" in response: - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - if triple_list: - # Get the label value - o = triple_list[0].get("o", {}) - label = o.get("v", o.get("i", iri)) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying label for {iri}: {e}", file=sys.stderr) - - # Cache the result - label_cache[iri] = label - return label - - -async def _resolve_edge_labels(ws_url, flow_id, edge_triple, collection, label_cache, debug=False): - """ - Resolve labels for all IRI components of an edge triple. - Returns (s_label, p_label, o_label). - """ - s = edge_triple.get("s", "?") - p = edge_triple.get("p", "?") - o = edge_triple.get("o", "?") - - s_label = await _query_label(ws_url, flow_id, s, collection, label_cache, debug) - p_label = await _query_label(ws_url, flow_id, p, collection, label_cache, debug) - o_label = await _query_label(ws_url, flow_id, o, collection, label_cache, debug) - - return s_label, p_label, o_label - - -async def _question_explainable( - url, flow_id, question, collection, entity_limit, triple_limit, - max_subgraph_size, max_path_length, token=None, debug=False -): - """Execute graph RAG with explainability - shows provenance events with details""" - # Convert HTTP URL to WebSocket URL - if url.startswith("http://"): - ws_url = url.replace("http://", "ws://", 1) - elif url.startswith("https://"): - ws_url = url.replace("https://", "wss://", 1) - else: - ws_url = f"ws://{url}" - - ws_url = f"{ws_url.rstrip('/')}/api/v1/socket" - if token: - ws_url = f"{ws_url}?token={token}" - - # Cache for label lookups to avoid repeated queries - label_cache = {} - - request = { - "id": "cli-request", - "service": "graph-rag", - "flow": flow_id, - "request": { - "query": question, - "collection": collection, - "entity-limit": entity_limit, - "triple-limit": triple_limit, - "max-subgraph-size": max_subgraph_size, - "max-path-length": max_path_length, - "streaming": True - } - } - - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=300) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "cli-request": - continue - - if "error" in response: - print(f"\nError: {response['error']}", file=sys.stderr) - break - - if "response" in response: - resp = response["response"] - - # Check for errors in response - if "error" in resp and resp["error"]: - err = resp["error"] - print(f"\nError: {err.get('message', 'Unknown error')}", file=sys.stderr) - break - - message_type = resp.get("message_type", "") - - if debug: - print(f" [debug] message_type={message_type}, keys={list(resp.keys())}", file=sys.stderr) - - if message_type == "explain": - # Display explain event with details - explain_id = resp.get("explain_id", "") - explain_graph = resp.get("explain_graph") # Named graph (e.g., urn:graph:retrieval) - if explain_id: - event_type = _get_event_type(explain_id) - print(f"\n [{event_type}] {explain_id}", file=sys.stderr) - - # Query triples for this explain node (using named graph filter) - triples = await _query_triples( - ws_url, flow_id, explain_id, collection, graph=explain_graph, debug=debug - ) - - # Format and display details - details = _format_provenance_details(event_type, triples) - for line in details: - print(line, file=sys.stderr) - - # For exploration events, resolve entity labels - if event_type == "exploration": - entity_iris = [o for s, p, o in triples if p == TG_ENTITY] - if entity_iris: - print(f" Seed entities: {len(entity_iris)}", file=sys.stderr) - for iri in entity_iris: - label = await _query_label( - ws_url, flow_id, iri, collection, - label_cache, debug=debug - ) - print(f" - {label}", file=sys.stderr) - - # For focus events, query each edge selection for details - if event_type == "focus": - for s, p, o in triples: - if debug: - print(f" [debug] triple: p={p}, o={o}, o_type={type(o).__name__}", file=sys.stderr) - if p == TG_SELECTED_EDGE and isinstance(o, str): - if debug: - print(f" [debug] querying edge selection: {o}", file=sys.stderr) - # Query the edge selection entity (using named graph filter) - edge_triples = await _query_triples( - ws_url, flow_id, o, collection, graph=explain_graph, debug=debug - ) - if debug: - print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr) - # Extract edge and reasoning - edge_triple = None # Store the actual triple for provenance lookup - reasoning = None - for es, ep, eo in edge_triples: - if debug: - print(f" [debug] edge triple: ep={ep}, eo={eo}", file=sys.stderr) - if ep == TG_EDGE and isinstance(eo, dict): - # eo is a quoted triple dict - edge_triple = eo - elif ep == TG_REASONING: - reasoning = eo - if edge_triple: - # Resolve labels for edge components - s_label, p_label, o_label = await _resolve_edge_labels( - ws_url, flow_id, edge_triple, collection, - label_cache, debug=debug - ) - print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) - if reasoning: - r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning - print(f" Reason: {r_short}", file=sys.stderr) - - # Trace edge provenance in the workspace collection (not explainability) - if edge_triple: - sources = await _query_edge_provenance( - ws_url, flow_id, - edge_triple.get("s", ""), - edge_triple.get("p", ""), - edge_triple.get("o", ""), - collection, # Use the query collection, not explainability - debug=debug - ) - if sources: - for src in sources: - # Trace full chain from source to root document - chain = await _trace_provenance_chain( - ws_url, flow_id, src, collection, - label_cache, debug=debug - ) - chain_str = _format_provenance_chain(chain) - print(f" Source: {chain_str}", file=sys.stderr) - - elif message_type == "chunk" or not message_type: - # Display response chunk - chunk = resp.get("response", "") - if chunk: - print(chunk, end="", flush=True) - - # Check if session is complete - if resp.get("end_of_session"): - break - - print() # Final newline - - def _question_explainable_api( url, flow_id, question_text, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, edge_score_limit=30, diff --git a/trustgraph-cli/trustgraph/cli/put_de_core.py b/trustgraph-cli/trustgraph/cli/put_de_core.py new file mode 100644 index 00000000..1d6589af --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/put_de_core.py @@ -0,0 +1,119 @@ +""" +Puts a document embeddings core into the knowledge manager via the API +socket. +""" + +import argparse +import os +import msgpack + +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +def read_message(unpacked, id): + + if unpacked[0] == "de": + msg = unpacked[1] + return { + "metadata": { + "id": id, + "root": msg["m"]["m"], + "collection": "default", + }, + "chunks": [ + { + "chunk_id": ch["i"], + "vector": ch["v"], + } + for ch in msg["c"] + ], + } + else: + raise RuntimeError("Unexpected message type", unpacked[0]) + +def put(url, workspace, id, input, token=None): + + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + + try: + de = 0 + + with open(input, "rb") as f: + + unpacker = msgpack.Unpacker(f, raw=False) + + while True: + + try: + unpacked = unpacker.unpack() + except msgpack.OutOfData: + break + + msg = read_message(unpacked, id) + de += 1 + socket.put_de_core(id, document_embeddings=msg) + + print(f"Put: {de} document embeddings messages.") + + finally: + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-put-de-core', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + + parser.add_argument( + '--id', '--identifier', + required=True, + help=f'Document embeddings core ID', + ) + + parser.add_argument( + '-i', '--input', + required=True, + help=f'Input file' + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + args = parser.parse_args() + + try: + + put( + url=args.url, + workspace=args.workspace, + id=args.id, + input=args.input, + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/put_kg_core.py b/trustgraph-cli/trustgraph/cli/put_kg_core.py index bd3169c8..fe0981a5 100644 --- a/trustgraph-cli/trustgraph/cli/put_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/put_kg_core.py @@ -4,13 +4,11 @@ Puts a knowledge core into the knowledge manager via the API socket. import argparse import os -import uuid -import asyncio -import json -from websockets.asyncio.client import connect import msgpack -default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") @@ -21,13 +19,13 @@ def read_message(unpacked, id): return "ge", { "metadata": { "id": id, - "metadata": msg["m"]["m"], - "collection": "default", # Not used? + "root": msg["m"]["m"], + "collection": "default", }, "entities": [ { "entity": ent["e"], - "vectors": ent["v"], + "vector": ent["v"], } for ent in msg["e"] ], @@ -37,26 +35,20 @@ def read_message(unpacked, id): return "t", { "metadata": { "id": id, - "metadata": msg["m"]["m"], - "collection": "default", # Not used by receiver? + "root": msg["m"]["m"], + "collection": "default", }, "triples": msg["t"], } else: raise RuntimeError("Unpacked unexpected messsage type", unpacked[0]) -async def put(url, workspace, id, input, token=None): +def put(url, workspace, id, input, token=None): - if not url.endswith("/"): - url += "/" - - url = url + "api/v1/socket" - - if token: - url = f"{url}?token={token}" - - async with connect(url) as ws: + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + try: ge = 0 t = 0 @@ -68,69 +60,26 @@ async def put(url, workspace, id, input, token=None): try: unpacked = unpacker.unpack() - except: + except msgpack.OutOfData: break kind, msg = read_message(unpacked, id) - mid = str(uuid.uuid4()) - if kind == "ge": - ge += 1 - - req = json.dumps({ - "id": mid, - "workspace": workspace, - "service": "knowledge", - "request": { - "operation": "put-kg-core", - "workspace": workspace, - "id": id, - "graph-embeddings": msg - } - }) + socket.put_kg_core(id, graph_embeddings=msg) elif kind == "t": - t += 1 - - req = json.dumps({ - "id": mid, - "workspace": workspace, - "service": "knowledge", - "request": { - "operation": "put-kg-core", - "workspace": workspace, - "id": id, - "triples": msg - } - }) + socket.put_kg_core(id, triples=msg) else: - raise RuntimeError("Unexpected message kind", kind) - await ws.send(req) - - # Retry loop, wait for right response to come back - while True: - - msg = await ws.recv() - msg = json.loads(msg) - - if msg["id"] != mid: - continue - - if "response" in msg: - if "error" in msg["response"]: - raise RuntimeError(msg["response"]["error"]) - - break - print(f"Put: {t} triple, {ge} GE messages.") - await ws.close() + finally: + socket.close() def main(): @@ -173,14 +122,12 @@ def main(): try: - asyncio.run( - put( - url=args.url, - workspace=args.workspace, - id=args.id, - input=args.input, - token=args.token, - ) + put( + url=args.url, + workspace=args.workspace, + id=args.id, + input=args.input, + token=args.token, ) except Exception as e: diff --git a/trustgraph-flow/trustgraph/cores/knowledge.py b/trustgraph-flow/trustgraph/cores/knowledge.py index 09c6137d..f1fa53f5 100644 --- a/trustgraph-flow/trustgraph/cores/knowledge.py +++ b/trustgraph-flow/trustgraph/cores/knowledge.py @@ -1,5 +1,6 @@ from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings +from .. schema import DocumentEmbeddings from .. knowledge import hash from .. exceptions import RequestError from .. tables.knowledge import KnowledgeTableStore @@ -157,6 +158,98 @@ class KnowledgeManager: ) ) + async def list_de_cores(self, request, respond, workspace): + + ids = await self.table_store.list_de_cores(workspace) + + await respond( + KnowledgeResponse( + error = None, + ids = ids, + eos = False, + triples = None, + graph_embeddings = None, + ) + ) + + async def get_de_core(self, request, respond, workspace): + + logger.info("Getting document embeddings core...") + + async def publish_de(de): + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None, + document_embeddings = de, + ) + ) + + await self.table_store.get_document_embeddings( + workspace, + request.id, + publish_de, + ) + + logger.debug("Document embeddings core retrieval complete") + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = True, + triples = None, + graph_embeddings = None, + ) + ) + + async def put_de_core(self, request, respond, workspace): + + if request.document_embeddings: + await self.table_store.add_document_embeddings( + workspace, request.document_embeddings + ) + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None, + ) + ) + + async def delete_de_core(self, request, respond, workspace): + + logger.info("Deleting document embeddings core...") + + await self.table_store.delete_document_embeddings( + workspace, request.id + ) + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None, + ) + ) + + async def load_de_core(self, request, respond, workspace): + + if self.background_task is None: + self.background_task = asyncio.create_task( + self.core_loader() + ) + + await self.loader_queue.put((request, respond, workspace)) + async def core_loader(self): logger.info("Knowledge background processor running...") @@ -165,7 +258,7 @@ class KnowledgeManager: logger.debug("Waiting for next load...") request, respond, workspace = await self.loader_queue.get() - logger.info(f"Loading knowledge: {request.id}") + logger.info(f"Loading: {request.operation} {request.id}") try: @@ -187,25 +280,14 @@ class KnowledgeManager: if "interfaces" not in flow: raise RuntimeError("No defined interfaces") - if "triples-store" not in flow["interfaces"]: - raise RuntimeError("Flow has no triples-store") - - if "graph-embeddings-store" not in flow["interfaces"]: - raise RuntimeError("Flow has no graph-embeddings-store") - - t_q = flow["interfaces"]["triples-store"]["flow"] - ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"] - - # Got this far, it should all work - await respond( - KnowledgeResponse( - error = None, - ids = None, - eos = False, - triples = None, - graph_embeddings = None + if request.operation == "load-de-core": + await self._load_de_core( + request, respond, workspace, flow, + ) + else: + await self._load_kg_core( + request, respond, workspace, flow, ) - ) except Exception as e: @@ -223,72 +305,145 @@ class KnowledgeManager: ) ) - - logger.debug("Starting knowledge loading process...") - - try: - - t_pub = None - ge_pub = None - - logger.debug(f"Triples queue: {t_q}") - logger.debug(f"Graph embeddings queue: {ge_q}") - - t_pub = Publisher( - self.flow_config.pubsub, t_q, - schema=Triples, - ) - ge_pub = Publisher( - self.flow_config.pubsub, ge_q, - schema=GraphEmbeddings - ) - - logger.debug("Starting publishers...") - - await t_pub.start() - await ge_pub.start() - - async def publish_triples(t): - # Override collection with request collection - if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'): - t.metadata.collection = request.collection or "default" - await t_pub.send(None, t) - - logger.debug("Publishing triples...") - - await self.table_store.get_triples( - workspace, - request.id, - publish_triples, - ) - - async def publish_ge(g): - # Override collection with request collection - if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'): - g.metadata.collection = request.collection or "default" - await ge_pub.send(None, g) - - logger.debug("Publishing graph embeddings...") - - await self.table_store.get_graph_embeddings( - workspace, - request.id, - publish_ge, - ) - - logger.debug("Knowledge loading completed") - - except Exception as e: - - logger.error(f"Knowledge exception: {e}", exc_info=True) - - finally: - - logger.debug("Stopping publishers...") - - if t_pub: await t_pub.stop() - if ge_pub: await ge_pub.stop() - logger.debug("Knowledge processing done") continue + + async def _load_kg_core(self, request, respond, workspace, flow): + + if "triples-store" not in flow["interfaces"]: + raise RuntimeError("Flow has no triples-store") + + if "graph-embeddings-store" not in flow["interfaces"]: + raise RuntimeError("Flow has no graph-embeddings-store") + + t_q = flow["interfaces"]["triples-store"]["flow"] + ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"] + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None + ) + ) + + t_pub = None + ge_pub = None + + try: + + logger.debug(f"Triples queue: {t_q}") + logger.debug(f"Graph embeddings queue: {ge_q}") + + t_pub = Publisher( + self.flow_config.pubsub, t_q, + schema=Triples, + ) + ge_pub = Publisher( + self.flow_config.pubsub, ge_q, + schema=GraphEmbeddings + ) + + logger.debug("Starting publishers...") + + await t_pub.start() + await ge_pub.start() + + async def publish_triples(t): + if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'): + t.metadata.collection = request.collection or "default" + await t_pub.send(None, t) + + logger.debug("Publishing triples...") + + await self.table_store.get_triples( + workspace, + request.id, + publish_triples, + ) + + async def publish_ge(g): + if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'): + g.metadata.collection = request.collection or "default" + await ge_pub.send(None, g) + + logger.debug("Publishing graph embeddings...") + + await self.table_store.get_graph_embeddings( + workspace, + request.id, + publish_ge, + ) + + logger.debug("Knowledge core loading completed") + + except Exception as e: + + logger.error(f"Knowledge exception: {e}", exc_info=True) + + finally: + + logger.debug("Stopping publishers...") + + if t_pub: await t_pub.stop() + if ge_pub: await ge_pub.stop() + + async def _load_de_core(self, request, respond, workspace, flow): + + if "document-embeddings-store" not in flow["interfaces"]: + raise RuntimeError("Flow has no document-embeddings-store") + + de_q = flow["interfaces"]["document-embeddings-store"]["flow"] + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None + ) + ) + + de_pub = None + + try: + + logger.debug(f"Document embeddings queue: {de_q}") + + de_pub = Publisher( + self.flow_config.pubsub, de_q, + schema=DocumentEmbeddings, + ) + + logger.debug("Starting publisher...") + + await de_pub.start() + + async def publish_de(de): + if hasattr(de, 'metadata') and hasattr(de.metadata, 'collection'): + de.metadata.collection = request.collection or "default" + await de_pub.send(None, de) + + logger.debug("Publishing document embeddings...") + + await self.table_store.get_document_embeddings( + workspace, + request.id, + publish_de, + ) + + logger.debug("Document embeddings core loading completed") + + except Exception as e: + + logger.error(f"Knowledge exception: {e}", exc_info=True) + + finally: + + logger.debug("Stopping publisher...") + + if de_pub: await de_pub.stop() diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py index c84b536c..a04e42ca 100755 --- a/trustgraph-flow/trustgraph/cores/service.py +++ b/trustgraph-flow/trustgraph/cores/service.py @@ -187,6 +187,11 @@ class Processor(WorkspaceProcessor): "put-kg-core": self.knowledge.put_kg_core, "load-kg-core": self.knowledge.load_kg_core, "unload-kg-core": self.knowledge.unload_kg_core, + "list-de-cores": self.knowledge.list_de_cores, + "get-de-core": self.knowledge.get_de_core, + "delete-de-core": self.knowledge.delete_de_core, + "put-de-core": self.knowledge.put_de_core, + "load-de-core": self.knowledge.load_de_core, } if v.operation not in impls: diff --git a/trustgraph-flow/trustgraph/gateway/registry.py b/trustgraph-flow/trustgraph/gateway/registry.py index 5e3344f4..4d439097 100644 --- a/trustgraph-flow/trustgraph/gateway/registry.py +++ b/trustgraph-flow/trustgraph/gateway/registry.py @@ -457,6 +457,12 @@ for _op in ("put-kg-core", "delete-kg-core", "load-kg-core", "unload-kg-core"): _register_kind_op("knowledge", _op, "knowledge:write") +# knowledge: document-embeddings core service. +for _op in ("get-de-core", "list-de-cores"): + _register_kind_op("knowledge", _op, "knowledge:read") +for _op in ("put-de-core", "delete-de-core", "load-de-core"): + _register_kind_op("knowledge", _op, "knowledge:write") + # collection-management: workspace collection lifecycle. _register_kind_op("collection-management", "list-collections", "collections:read") diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py index 5d45358d..cf085fdd 100644 --- a/trustgraph-flow/trustgraph/tables/knowledge.py +++ b/trustgraph-flow/trustgraph/tables/knowledge.py @@ -1,6 +1,7 @@ from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings +from .. schema import DocumentEmbeddings, ChunkEmbeddings from cassandra.cluster import Cluster @@ -217,6 +218,16 @@ class KnowledgeTableStore: WHERE workspace = ? AND document_id = ? """) + self.delete_document_embeddings_stmt = self.cassandra.prepare(""" + DELETE FROM document_embeddings + WHERE workspace = ? AND document_id = ? + """) + + self.list_de_cores_stmt = self.cassandra.prepare(""" + SELECT DISTINCT workspace, document_id FROM document_embeddings + WHERE workspace = ? + """) + async def add_triples(self, workspace, m): when = int(time.time() * 1000) @@ -338,6 +349,50 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise + try: + await async_execute( + self.cassandra, + self.delete_document_embeddings_stmt, + (workspace, document_id), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + async def delete_document_embeddings(self, workspace, document_id): + + logger.debug("Delete document embeddings...") + + try: + await async_execute( + self.cassandra, + self.delete_document_embeddings_stmt, + (workspace, document_id), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + async def list_de_cores(self, workspace): + + logger.debug("List DE cores...") + + try: + rows = await async_execute( + self.cassandra, + self.list_de_cores_stmt, + (workspace,), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + lst = [row[1] for row in rows] + + logger.debug("Done") + + return lst + async def get_triples(self, workspace, document_id, receiver): logger.debug("Get triples...") @@ -417,3 +472,42 @@ class KnowledgeTableStore: logger.debug("Done") + async def get_document_embeddings(self, workspace, document_id, receiver): + + logger.debug("Get DE...") + + try: + rows = await async_execute( + self.cassandra, + self.get_document_embeddings_stmt, + (workspace, document_id), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + for row in rows: + + if row[3]: + chunks = [ + ChunkEmbeddings( + chunk_id=ch[0], + vector=ch[1], + ) + for ch in row[3] + ] + else: + chunks = [] + + await receiver( + DocumentEmbeddings( + metadata = Metadata( + id = document_id, + collection = "default", + ), + chunks = chunks + ) + ) + + logger.debug("Done") +