CLI auth migration, document embeddings core lifecycle (#913)

Migrate get_kg_core and put_kg_core CLI tools to use Api/SocketClient
with first-frame auth (fixes broken raw websocket path). Fix wire
format field names (root/vector). Remove ~600 lines of dead raw
websocket code from invoke_graph_rag.py.

Add document embeddings core lifecycle to the knowledge service:
list/get/put/delete/load operations across schema, translator,
Cassandra table store, knowledge manager, gateway registry, REST API,
socket client, and CLI (tg-get-de-core, tg-put-de-core).

Fix delete_kg_core to also clean up document embeddings rows.
cybermaggedon 2026-05-14 10:30:21 +01:00 committed by GitHub
parent dd974b0cac
commit f0ad282708
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 762 additions and 825 deletions
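For context, a minimal sketch of the Api/SocketClient pattern the migrated CLI tools now follow (drawn from the get_kg_core.py changes below); the URL, token, and core ID are illustrative, and authentication now travels in the first frame rather than being appended to the socket URL.

from trustgraph.api import Api

api = Api(url="http://localhost:8088/", token=None, workspace="default")
socket = api.socket()
try:
    # Streamed responses carry triples or graph-embeddings messages until eos.
    for response in socket.get_kg_core("example-core"):
        if "triples" in response:
            pass  # handle a triples message
finally:
    socket.close()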

View file

@ -132,3 +132,34 @@ class Knowledge:
self.request(request = input)
def list_de_cores(self):
input = {
"operation": "list-de-cores",
"workspace": self.api.workspace,
}
return self.request(request = input)["ids"]
def delete_de_core(self, id):
input = {
"operation": "delete-de-core",
"workspace": self.api.workspace,
"id": id,
}
self.request(request = input)
def load_de_core(self, id, flow="default", collection="default"):
input = {
"operation": "load-de-core",
"workspace": self.api.workspace,
"id": id,
"flow": flow,
"collection": collection,
}
self.request(request = input)
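A hedged sketch of driving the new document-embeddings core methods added above; `knowledge` stands in for however the client library hands you this Knowledge object (its construction is not shown in this diff), and the IDs are illustrative.

# `knowledge` is assumed to be an already-constructed Knowledge client
# wrapping an authenticated Api instance.
for core_id in knowledge.list_de_cores():
    print("document embeddings core:", core_id)

knowledge.load_de_core("example-doc", flow="default", collection="default")
knowledge.delete_de_core("example-doc")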

View file

@ -491,6 +491,58 @@ class SocketClient:
triples=raw_triples,
)
def get_kg_core(self, id: str) -> Iterator[Dict[str, Any]]:
request = {
"operation": "get-kg-core",
"workspace": self.workspace,
"id": id,
}
for response in self._send_request_sync(
"knowledge", None, request, streaming_raw=True,
):
if response.get("eos"):
break
yield response
def put_kg_core(
self, id: str, triples=None, graph_embeddings=None,
) -> Dict[str, Any]:
request = {
"operation": "put-kg-core",
"workspace": self.workspace,
"id": id,
}
if triples is not None:
request["triples"] = triples
if graph_embeddings is not None:
request["graph-embeddings"] = graph_embeddings
return self._send_request_sync("knowledge", None, request)
def get_de_core(self, id: str) -> Iterator[Dict[str, Any]]:
request = {
"operation": "get-de-core",
"workspace": self.workspace,
"id": id,
}
for response in self._send_request_sync(
"knowledge", None, request, streaming_raw=True,
):
if response.get("eos"):
break
yield response
def put_de_core(
self, id: str, document_embeddings=None,
) -> Dict[str, Any]:
request = {
"operation": "put-de-core",
"workspace": self.workspace,
"id": id,
}
if document_embeddings is not None:
request["document-embeddings"] = document_embeddings
return self._send_request_sync("knowledge", None, request)
def close(self) -> None:
"""Close the persistent WebSocket connection."""
if self._loop and not self._loop.is_closed():
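As a rough usage sketch of the two new socket methods above, assuming `socket` comes from `Api(...).socket()`; the ID, root, and vector values are illustrative.

core = {
    "metadata": {"id": "doc-1", "root": "doc-1", "collection": "default"},
    "chunks": [{"chunk_id": "c0", "vector": [0.1, 0.2, 0.3]}],
}
socket.put_de_core("doc-1", document_embeddings=core)

# get_de_core streams messages until the service sends eos.
for response in socket.get_de_core("doc-1"):
    if "document-embeddings" in response:
        print(len(response["document-embeddings"]["chunks"]), "chunks")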

View file

@ -1,6 +1,7 @@
from typing import Dict, Any, Tuple, Optional
from ...schema import (
KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings,
DocumentEmbeddings, ChunkEmbeddings,
Metadata, EntityEmbeddings
)
from .base import MessageTranslator
@ -43,6 +44,23 @@ class KnowledgeRequestTranslator(MessageTranslator):
]
)
document_embeddings = None
if "document-embeddings" in data:
document_embeddings = DocumentEmbeddings(
metadata=Metadata(
id=data["document-embeddings"]["metadata"]["id"],
root=data["document-embeddings"]["metadata"].get("root", ""),
collection=data["document-embeddings"]["metadata"]["collection"]
),
chunks=[
ChunkEmbeddings(
chunk_id=ch["chunk_id"],
vector=ch["vector"],
)
for ch in data["document-embeddings"]["chunks"]
]
)
return KnowledgeRequest(
operation=data.get("operation"),
id=data.get("id"),
@ -50,6 +68,7 @@ class KnowledgeRequestTranslator(MessageTranslator):
collection=data.get("collection"),
triples=triples,
graph_embeddings=graph_embeddings,
document_embeddings=document_embeddings,
)
def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]:
@ -90,6 +109,22 @@ class KnowledgeRequestTranslator(MessageTranslator):
],
}
if obj.document_embeddings:
result["document-embeddings"] = {
"metadata": {
"id": obj.document_embeddings.metadata.id,
"root": obj.document_embeddings.metadata.root,
"collection": obj.document_embeddings.metadata.collection,
},
"chunks": [
{
"chunk_id": ch.chunk_id,
"vector": ch.vector,
}
for ch in obj.document_embeddings.chunks
],
}
return result
@ -140,6 +175,25 @@ class KnowledgeResponseTranslator(MessageTranslator):
}
}
# Streaming document embeddings response
if obj.document_embeddings:
return {
"document-embeddings": {
"metadata": {
"id": obj.document_embeddings.metadata.id,
"root": obj.document_embeddings.metadata.root,
"collection": obj.document_embeddings.metadata.collection,
},
"chunks": [
{
"chunk_id": ch.chunk_id,
"vector": ch.vector,
}
for ch in obj.document_embeddings.chunks
],
}
}
# End of stream marker
if obj.eos is True:
return {"eos": True}
@ -155,7 +209,7 @@ class KnowledgeResponseTranslator(MessageTranslator):
is_final = (
obj.ids is not None or # List response
obj.eos is True or # End of stream
(not obj.triples and not obj.graph_embeddings) # Empty response
(not obj.triples and not obj.graph_embeddings and not obj.document_embeddings) # Empty response
)
return response, is_final
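For reference, this is roughly the JSON payload shape the request translator above decodes for a put-de-core operation; the field values are illustrative.

request = {
    "operation": "put-de-core",
    "workspace": "default",
    "id": "doc-1",
    "document-embeddings": {
        "metadata": {"id": "doc-1", "root": "doc-1", "collection": "default"},
        "chunks": [
            {"chunk_id": "c0", "vector": [0.1, 0.2, 0.3]},
        ],
    },
}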

View file

@ -4,7 +4,7 @@ from ..core.topic import queue
from ..core.metadata import Metadata
from .document import Document, TextDocument
from .graph import Triples
from .embeddings import GraphEmbeddings
from .embeddings import GraphEmbeddings, DocumentEmbeddings
# get-kg-core
# -> (???)
@ -41,6 +41,9 @@ class KnowledgeRequest:
triples: Triples | None = None
graph_embeddings: GraphEmbeddings | None = None
# put-de-core
document_embeddings: DocumentEmbeddings | None = None
@dataclass
class KnowledgeResponse:
error: Error | None = None
@ -48,6 +51,7 @@ class KnowledgeResponse:
eos: bool = False # Indicates end of knowledge core stream
triples: Triples | None = None
graph_embeddings: GraphEmbeddings | None = None
document_embeddings: DocumentEmbeddings | None = None
knowledge_request_queue = queue('knowledge', cls='request')
knowledge_response_queue = queue('knowledge', cls='response')
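A small sketch of building the corresponding schema objects added above; the import path and field values are assumptions for illustration.

# Assumes these classes are exported from trustgraph.schema, as elsewhere
# in the codebase; all field values below are illustrative.
from trustgraph.schema import (
    KnowledgeRequest, DocumentEmbeddings, ChunkEmbeddings, Metadata,
)

req = KnowledgeRequest(
    operation="put-de-core",
    id="doc-1",
    document_embeddings=DocumentEmbeddings(
        metadata=Metadata(id="doc-1", root="doc-1", collection="default"),
        chunks=[ChunkEmbeddings(chunk_id="c0", vector=[0.1, 0.2, 0.3])],
    ),
)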

View file

@ -37,6 +37,7 @@ tg-dump-msgpack = "trustgraph.cli.dump_msgpack:main"
tg-dump-queues = "trustgraph.cli.dump_queues:main"
tg-monitor-prompts = "trustgraph.cli.monitor_prompts:main"
tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main"
tg-get-de-core = "trustgraph.cli.get_de_core:main"
tg-get-kg-core = "trustgraph.cli.get_kg_core:main"
tg-get-document-content = "trustgraph.cli.get_document_content:main"
tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main"
@ -77,6 +78,7 @@ tg-load-turtle = "trustgraph.cli.load_turtle:main"
tg-load-knowledge = "trustgraph.cli.load_knowledge:main"
tg-load-structured-data = "trustgraph.cli.load_structured_data:main"
tg-put-flow-blueprint = "trustgraph.cli.put_flow_blueprint:main"
tg-put-de-core = "trustgraph.cli.put_de_core:main"
tg-put-kg-core = "trustgraph.cli.put_kg_core:main"
tg-remove-library-document = "trustgraph.cli.remove_library_document:main"
tg-save-doc-embeds = "trustgraph.cli.save_doc_embeds:main"

View file

@ -0,0 +1,111 @@
"""
Uses the knowledge service to fetch a document embeddings core which is
saved to a local file in msgpack format.
"""
import argparse
import os
import msgpack
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def write_de(f, data):
msg = (
"de",
{
"m": {
"i": data["metadata"]["id"],
"m": data["metadata"]["root"],
"c": data["metadata"]["collection"],
},
"c": [
{
"i": ch["chunk_id"],
"v": ch["vector"],
}
for ch in data["chunks"]
]
}
)
f.write(msgpack.packb(msg, use_bin_type=True))
def fetch(url, workspace, id, output, token=None):
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
try:
de = 0
with open(output, "wb") as f:
for response in socket.get_de_core(id):
if "document-embeddings" in response:
de += 1
write_de(f, response["document-embeddings"])
print(f"Got: {de} document embeddings messages.")
finally:
socket.close()
def main():
parser = argparse.ArgumentParser(
prog='tg-get-de-core',
description=__doc__,
)
parser.add_argument(
'-u', '--url',
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'--id', '--identifier',
required=True,
help=f'Document embeddings core ID',
)
parser.add_argument(
'-o', '--output',
required=True,
help=f'Output file'
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
args = parser.parse_args()
try:
fetch(
url=args.url,
workspace=args.workspace,
id=args.id,
output=args.output,
token=args.token,
)
except Exception as e:
print("Exception:", e, flush=True)
if __name__ == "__main__":
main()
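A quick way to sanity-check a file produced by this tool: each record is a ("de", {...}) tuple packed with msgpack, as written by write_de() above; the file name is illustrative.

import msgpack

with open("embeddings.core", "rb") as f:
    for kind, msg in msgpack.Unpacker(f, raw=False):
        assert kind == "de"
        print(msg["m"]["i"], "-", len(msg["c"]), "chunk vectors")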

View file

@ -5,13 +5,11 @@ to a local file in msgpack format.
import argparse
import os
import uuid
import asyncio
import json
from websockets.asyncio.client import connect
import msgpack
default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
@ -21,7 +19,7 @@ def write_triple(f, data):
{
"m": {
"i": data["metadata"]["id"],
"m": data["metadata"]["metadata"],
"m": data["metadata"]["root"],
"c": data["metadata"]["collection"],
},
"t": data["triples"],
@ -35,13 +33,13 @@ def write_ge(f, data):
{
"m": {
"i": data["metadata"]["id"],
"m": data["metadata"]["metadata"],
"m": data["metadata"]["root"],
"c": data["metadata"]["collection"],
},
"e": [
{
"e": ent["entity"],
"v": ent["vectors"],
"v": ent["vector"],
}
for ent in data["entities"]
]
@ -49,54 +47,18 @@ def write_ge(f, data):
)
f.write(msgpack.packb(msg, use_bin_type=True))
async def fetch(url, workspace, id, output, token=None):
def fetch(url, workspace, id, output, token=None):
if not url.endswith("/"):
api = Api(url=url, token=token, workspace=workspace)
url += "/"
socket = api.socket()
url = url + "api/v1/socket"
if token:
url = f"{url}?token={token}"
mid = str(uuid.uuid4())
async with connect(url) as ws:
req = json.dumps({
"id": mid,
"workspace": workspace,
"service": "knowledge",
"request": {
"operation": "get-kg-core",
"workspace": workspace,
"id": id,
}
})
await ws.send(req)
try:
ge = 0
t = 0
with open(output, "wb") as f:
while True:
for response in socket.get_kg_core(id):
msg = await ws.recv()
obj = json.loads(msg)
if "response" not in obj:
raise RuntimeError("No response?")
response = obj["response"]
if "error" in response:
raise RuntimeError(obj["error"])
if "eos" in response:
if response["eos"]: break
if "triples" in response: if "triples" in response:
t += 1 t += 1
@ -108,7 +70,8 @@ async def fetch(url, workspace, id, output, token=None):
print(f"Got: {t} triple, {ge} GE messages.") print(f"Got: {t} triple, {ge} GE messages.")
await ws.close() finally:
socket.close()
def main(): def main():
@ -151,7 +114,6 @@ def main():
try:
asyncio.run(
fetch(
url=args.url,
workspace=args.workspace,
@ -159,7 +121,6 @@ def main():
output=args.output,
token=args.token,
)
)
except Exception as e:

View file

@ -3,11 +3,8 @@ Uses the GraphRAG service to answer a question
""" """
import argparse import argparse
import json
import os import os
import sys import sys
import websockets
import asyncio
from trustgraph.api import ( from trustgraph.api import (
Api, Api,
ExplainabilityClient, ExplainabilityClient,
@ -31,607 +28,6 @@ default_max_path_length = 2
default_edge_score_limit = 30 default_edge_score_limit = 30
default_edge_limit = 25 default_edge_limit = 25
# Provenance predicates
TG = "https://trustgraph.ai/ns/"
TG_QUERY = TG + "query"
TG_CONCEPT = TG + "concept"
TG_ENTITY = TG + "entity"
TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_DOCUMENT = TG + "document"
TG_CONTAINS = TG + "contains"
PROV = "http://www.w3.org/ns/prov#"
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
def _get_event_type(prov_id):
"""Extract event type from provenance_id"""
if "question" in prov_id:
return "question"
elif "grounding" in prov_id:
return "grounding"
elif "exploration" in prov_id:
return "exploration"
elif "focus" in prov_id:
return "focus"
elif "synthesis" in prov_id:
return "synthesis"
return "provenance"
def _format_provenance_details(event_type, triples):
"""Format provenance details based on event type and triples"""
lines = []
if event_type == "question":
# Show query and timestamp
for s, p, o in triples:
if p == TG_QUERY:
lines.append(f" Query: {o}")
elif p == PROV_STARTED_AT_TIME:
lines.append(f" Time: {o}")
elif event_type == "grounding":
# Show extracted concepts
concepts = [o for s, p, o in triples if p == TG_CONCEPT]
if concepts:
lines.append(f" Concepts: {len(concepts)}")
for concept in concepts:
lines.append(f" - {concept}")
elif event_type == "exploration":
# Show edge count (seed entities resolved separately with labels)
for s, p, o in triples:
if p == TG_EDGE_COUNT:
lines.append(f" Edges explored: {o}")
elif event_type == "focus":
# For focus, just count edge selection URIs
# The actual edge details are fetched separately via edge_selections parameter
edge_sel_uris = []
for s, p, o in triples:
if p == TG_SELECTED_EDGE:
edge_sel_uris.append(o)
if edge_sel_uris:
lines.append(f" Focused on {len(edge_sel_uris)} edge(s)")
elif event_type == "synthesis":
# Show document reference (content already streamed)
for s, p, o in triples:
if p == TG_DOCUMENT:
lines.append(f" Document: {o}")
return lines
async def _query_triples_once(ws_url, flow_id, prov_id, collection, graph=None, debug=False):
"""Query triples for a provenance node (single attempt)"""
request = {
"id": "triples-request",
"service": "triples",
"flow": flow_id,
"request": {
"s": {"t": "i", "i": prov_id},
"collection": collection,
"limit": 100
}
}
# Add graph filter if specified (for named graph queries)
if graph is not None:
request["request"]["g"] = graph
if debug:
print(f" [debug] querying triples for s={prov_id}", file=sys.stderr)
triples = []
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if debug:
print(f" [debug] response: {json.dumps(response)[:200]}", file=sys.stderr)
if response.get("id") != "triples-request":
continue
if "error" in response:
if debug:
print(f" [debug] error: {response['error']}", file=sys.stderr)
break
if "response" in response:
resp = response["response"]
# Handle triples response
# Response format: {"response": [triples...]}
# Each triple uses compact keys: "i" for iri, "v" for value, "t" for type
triple_list = resp.get("response", [])
for t in triple_list:
s = t.get("s", {}).get("i", t.get("s", {}).get("v", ""))
p = t.get("p", {}).get("i", t.get("p", {}).get("v", ""))
# Handle quoted triples (type "t") and regular values
o_term = t.get("o", {})
if o_term.get("t") == "t":
# Quoted triple - extract s, p, o from nested structure
tr = o_term.get("tr", {})
o = {
"s": tr.get("s", {}).get("i", ""),
"p": tr.get("p", {}).get("i", ""),
"o": tr.get("o", {}).get("i", tr.get("o", {}).get("v", "")),
}
else:
o = o_term.get("i", o_term.get("v", ""))
triples.append((s, p, o))
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception: {e}", file=sys.stderr)
if debug:
print(f" [debug] got {len(triples)} triples", file=sys.stderr)
return triples
async def _query_triples(ws_url, flow_id, prov_id, collection, graph=None, max_retries=5, retry_delay=0.2, debug=False):
"""Query triples for a provenance node with retries for race condition"""
for attempt in range(max_retries):
triples = await _query_triples_once(ws_url, flow_id, prov_id, collection, graph=graph, debug=debug)
if triples:
return triples
# Wait before retry if empty (triples may not be stored yet)
if attempt < max_retries - 1:
if debug:
print(f" [debug] retry {attempt + 1}/{max_retries}...", file=sys.stderr)
await asyncio.sleep(retry_delay)
return []
async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, collection, debug=False):
"""
Query for provenance of an edge (s, p, o) in the knowledge graph.
Finds subgraphs that contain the edge via tg:contains, then follows
prov:wasDerivedFrom to find source documents.
Returns list of source URIs (chunks, pages, documents).
"""
# Query for subgraphs that contain this edge: ?subgraph tg:contains <<s p o>>
request = {
"id": "edge-prov-request",
"service": "triples",
"flow": flow_id,
"request": {
"p": {"t": "i", "i": TG_CONTAINS},
"o": {
"t": "t", # Quoted triple type
"tr": {
"s": {"t": "i", "i": edge_s},
"p": {"t": "i", "i": edge_p},
"o": {"t": "i", "i": edge_o} if edge_o.startswith("http") or edge_o.startswith("urn:") else {"t": "l", "v": edge_o},
}
},
"collection": collection,
"limit": 10
}
}
if debug:
print(f" [debug] querying edge provenance for ({edge_s}, {edge_p}, {edge_o})", file=sys.stderr)
stmt_uris = []
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "edge-prov-request":
continue
if "error" in response:
if debug:
print(f" [debug] error: {response['error']}", file=sys.stderr)
break
if "response" in response:
resp = response["response"]
triple_list = resp.get("response", [])
for t in triple_list:
s = t.get("s", {}).get("i", "")
if s:
stmt_uris.append(s)
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception querying edge provenance: {e}", file=sys.stderr)
if debug:
print(f" [debug] found {len(stmt_uris)} reifying statements", file=sys.stderr)
# For each statement, query wasDerivedFrom to find sources
sources = []
for stmt_uri in stmt_uris:
# Query: stmt_uri prov:wasDerivedFrom ?source
request = {
"id": "derived-from-request",
"service": "triples",
"flow": flow_id,
"request": {
"s": {"t": "i", "i": stmt_uri},
"p": {"t": "i", "i": PROV_WAS_DERIVED_FROM},
"collection": collection,
"limit": 10
}
}
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "derived-from-request":
continue
if "error" in response:
break
if "response" in response:
resp = response["response"]
triple_list = resp.get("response", [])
for t in triple_list:
o = t.get("o", {}).get("i", "")
if o:
sources.append(o)
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception querying wasDerivedFrom: {e}", file=sys.stderr)
if debug:
print(f" [debug] found {len(sources)} source(s): {sources}", file=sys.stderr)
return sources
async def _query_derived_from(ws_url, flow_id, uri, collection, debug=False):
"""Query for the prov:wasDerivedFrom parent of a URI. Returns None if no parent."""
request = {
"id": "parent-request",
"service": "triples",
"flow": flow_id,
"request": {
"s": {"t": "i", "i": uri},
"p": {"t": "i", "i": PROV_WAS_DERIVED_FROM},
"collection": collection,
"limit": 1
}
}
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "parent-request":
continue
if "error" in response:
break
if "response" in response:
resp = response["response"]
triple_list = resp.get("response", [])
if triple_list:
return triple_list[0].get("o", {}).get("i", None)
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception querying parent: {e}", file=sys.stderr)
return None
async def _trace_provenance_chain(ws_url, flow_id, source_uri, collection, label_cache, debug=False):
"""
Trace the full provenance chain from a source URI up to the root document.
Returns a list of (uri, label) tuples from leaf to root.
"""
chain = []
current = source_uri
max_depth = 10 # Prevent infinite loops
for _ in range(max_depth):
if not current:
break
# Get label for current entity
label = await _query_label(ws_url, flow_id, current, collection, label_cache, debug)
chain.append((current, label))
# Get parent
parent = await _query_derived_from(ws_url, flow_id, current, collection, debug)
if not parent or parent == current:
break
current = parent
return chain
def _format_provenance_chain(chain):
"""
Format a provenance chain as a human-readable string.
Chain is [(uri, label), ...] from leaf to root.
"""
if not chain:
return ""
# Show labels, from leaf to root
labels = [label for uri, label in chain]
return "".join(labels)
def _is_iri(value):
"""Check if a value looks like an IRI."""
if not isinstance(value, str):
return False
return value.startswith("http://") or value.startswith("https://") or value.startswith("urn:")
async def _query_label(ws_url, flow_id, iri, collection, label_cache, debug=False):
"""
Query for the rdfs:label of an IRI.
Uses label_cache to avoid repeated queries.
Returns the label if found, otherwise returns the IRI.
"""
if not _is_iri(iri):
return iri
# Check cache first
if iri in label_cache:
return label_cache[iri]
request = {
"id": "label-request",
"service": "triples",
"flow": flow_id,
"request": {
"s": {"t": "i", "i": iri},
"p": {"t": "i", "i": RDFS_LABEL},
"collection": collection,
"limit": 1
}
}
label = iri # Default to IRI if no label found
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "label-request":
continue
if "error" in response:
break
if "response" in response:
resp = response["response"]
triple_list = resp.get("response", [])
if triple_list:
# Get the label value
o = triple_list[0].get("o", {})
label = o.get("v", o.get("i", iri))
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception querying label for {iri}: {e}", file=sys.stderr)
# Cache the result
label_cache[iri] = label
return label
async def _resolve_edge_labels(ws_url, flow_id, edge_triple, collection, label_cache, debug=False):
"""
Resolve labels for all IRI components of an edge triple.
Returns (s_label, p_label, o_label).
"""
s = edge_triple.get("s", "?")
p = edge_triple.get("p", "?")
o = edge_triple.get("o", "?")
s_label = await _query_label(ws_url, flow_id, s, collection, label_cache, debug)
p_label = await _query_label(ws_url, flow_id, p, collection, label_cache, debug)
o_label = await _query_label(ws_url, flow_id, o, collection, label_cache, debug)
return s_label, p_label, o_label
async def _question_explainable(
url, flow_id, question, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, token=None, debug=False
):
"""Execute graph RAG with explainability - shows provenance events with details"""
# Convert HTTP URL to WebSocket URL
if url.startswith("http://"):
ws_url = url.replace("http://", "ws://", 1)
elif url.startswith("https://"):
ws_url = url.replace("https://", "wss://", 1)
else:
ws_url = f"ws://{url}"
ws_url = f"{ws_url.rstrip('/')}/api/v1/socket"
if token:
ws_url = f"{ws_url}?token={token}"
# Cache for label lookups to avoid repeated queries
label_cache = {}
request = {
"id": "cli-request",
"service": "graph-rag",
"flow": flow_id,
"request": {
"query": question,
"collection": collection,
"entity-limit": entity_limit,
"triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size,
"max-path-length": max_path_length,
"streaming": True
}
}
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=300) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "cli-request":
continue
if "error" in response:
print(f"\nError: {response['error']}", file=sys.stderr)
break
if "response" in response:
resp = response["response"]
# Check for errors in response
if "error" in resp and resp["error"]:
err = resp["error"]
print(f"\nError: {err.get('message', 'Unknown error')}", file=sys.stderr)
break
message_type = resp.get("message_type", "")
if debug:
print(f" [debug] message_type={message_type}, keys={list(resp.keys())}", file=sys.stderr)
if message_type == "explain":
# Display explain event with details
explain_id = resp.get("explain_id", "")
explain_graph = resp.get("explain_graph") # Named graph (e.g., urn:graph:retrieval)
if explain_id:
event_type = _get_event_type(explain_id)
print(f"\n [{event_type}] {explain_id}", file=sys.stderr)
# Query triples for this explain node (using named graph filter)
triples = await _query_triples(
ws_url, flow_id, explain_id, collection, graph=explain_graph, debug=debug
)
# Format and display details
details = _format_provenance_details(event_type, triples)
for line in details:
print(line, file=sys.stderr)
# For exploration events, resolve entity labels
if event_type == "exploration":
entity_iris = [o for s, p, o in triples if p == TG_ENTITY]
if entity_iris:
print(f" Seed entities: {len(entity_iris)}", file=sys.stderr)
for iri in entity_iris:
label = await _query_label(
ws_url, flow_id, iri, collection,
label_cache, debug=debug
)
print(f" - {label}", file=sys.stderr)
# For focus events, query each edge selection for details
if event_type == "focus":
for s, p, o in triples:
if debug:
print(f" [debug] triple: p={p}, o={o}, o_type={type(o).__name__}", file=sys.stderr)
if p == TG_SELECTED_EDGE and isinstance(o, str):
if debug:
print(f" [debug] querying edge selection: {o}", file=sys.stderr)
# Query the edge selection entity (using named graph filter)
edge_triples = await _query_triples(
ws_url, flow_id, o, collection, graph=explain_graph, debug=debug
)
if debug:
print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr)
# Extract edge and reasoning
edge_triple = None # Store the actual triple for provenance lookup
reasoning = None
for es, ep, eo in edge_triples:
if debug:
print(f" [debug] edge triple: ep={ep}, eo={eo}", file=sys.stderr)
if ep == TG_EDGE and isinstance(eo, dict):
# eo is a quoted triple dict
edge_triple = eo
elif ep == TG_REASONING:
reasoning = eo
if edge_triple:
# Resolve labels for edge components
s_label, p_label, o_label = await _resolve_edge_labels(
ws_url, flow_id, edge_triple, collection,
label_cache, debug=debug
)
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
if reasoning:
r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning
print(f" Reason: {r_short}", file=sys.stderr)
# Trace edge provenance in the workspace collection (not explainability)
if edge_triple:
sources = await _query_edge_provenance(
ws_url, flow_id,
edge_triple.get("s", ""),
edge_triple.get("p", ""),
edge_triple.get("o", ""),
collection, # Use the query collection, not explainability
debug=debug
)
if sources:
for src in sources:
# Trace full chain from source to root document
chain = await _trace_provenance_chain(
ws_url, flow_id, src, collection,
label_cache, debug=debug
)
chain_str = _format_provenance_chain(chain)
print(f" Source: {chain_str}", file=sys.stderr)
elif message_type == "chunk" or not message_type:
# Display response chunk
chunk = resp.get("response", "")
if chunk:
print(chunk, end="", flush=True)
# Check if session is complete
if resp.get("end_of_session"):
break
print() # Final newline
def _question_explainable_api(
url, flow_id, question_text, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, edge_score_limit=30,

View file

@ -0,0 +1,119 @@
"""
Puts a document embeddings core into the knowledge manager via the API
socket.
"""
import argparse
import os
import msgpack
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def read_message(unpacked, id):
if unpacked[0] == "de":
msg = unpacked[1]
return {
"metadata": {
"id": id,
"root": msg["m"]["m"],
"collection": "default",
},
"chunks": [
{
"chunk_id": ch["i"],
"vector": ch["v"],
}
for ch in msg["c"]
],
}
else:
raise RuntimeError("Unexpected message type", unpacked[0])
def put(url, workspace, id, input, token=None):
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
try:
de = 0
with open(input, "rb") as f:
unpacker = msgpack.Unpacker(f, raw=False)
while True:
try:
unpacked = unpacker.unpack()
except msgpack.OutOfData:
break
msg = read_message(unpacked, id)
de += 1
socket.put_de_core(id, document_embeddings=msg)
print(f"Put: {de} document embeddings messages.")
finally:
socket.close()
def main():
parser = argparse.ArgumentParser(
prog='tg-put-de-core',
description=__doc__,
)
parser.add_argument(
'-u', '--url',
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'--id', '--identifier',
required=True,
help=f'Document embeddings core ID',
)
parser.add_argument(
'-i', '--input',
required=True,
help=f'Input file'
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
args = parser.parse_args()
try:
put(
url=args.url,
workspace=args.workspace,
id=args.id,
input=args.input,
token=args.token,
)
except Exception as e:
print("Exception:", e, flush=True)
if __name__ == "__main__":
main()
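The inverse of the reading sketch earlier: producing an input file this tool can read. read_message() above takes the root from msg["m"]["m"] and ignores the rest of the embedded metadata; the IDs and vector below are illustrative.

import msgpack

record = ("de", {
    "m": {"i": "doc-1", "m": "doc-1", "c": "default"},
    "c": [{"i": "c0", "v": [0.1, 0.2, 0.3]}],
})
with open("embeddings.core", "wb") as f:
    f.write(msgpack.packb(record, use_bin_type=True))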

View file

@ -4,13 +4,11 @@ Puts a knowledge core into the knowledge manager via the API socket.
import argparse
import os
import uuid
import asyncio
import json
from websockets.asyncio.client import connect
import msgpack
default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
@ -21,13 +19,13 @@ def read_message(unpacked, id):
return "ge", { return "ge", {
"metadata": { "metadata": {
"id": id, "id": id,
"metadata": msg["m"]["m"], "root": msg["m"]["m"],
"collection": "default", # Not used? "collection": "default",
}, },
"entities": [ "entities": [
{ {
"entity": ent["e"], "entity": ent["e"],
"vectors": ent["v"], "vector": ent["v"],
} }
for ent in msg["e"] for ent in msg["e"]
], ],
@ -37,26 +35,20 @@ def read_message(unpacked, id):
return "t", { return "t", {
"metadata": { "metadata": {
"id": id, "id": id,
"metadata": msg["m"]["m"], "root": msg["m"]["m"],
"collection": "default", # Not used by receiver? "collection": "default",
}, },
"triples": msg["t"], "triples": msg["t"],
} }
else: else:
raise RuntimeError("Unpacked unexpected messsage type", unpacked[0]) raise RuntimeError("Unpacked unexpected messsage type", unpacked[0])
async def put(url, workspace, id, input, token=None):
def put(url, workspace, id, input, token=None):
if not url.endswith("/"):
api = Api(url=url, token=token, workspace=workspace)
url += "/"
socket = api.socket()
url = url + "api/v1/socket"
if token:
url = f"{url}?token={token}"
async with connect(url) as ws:
try:
ge = 0
t = 0
@ -68,69 +60,26 @@ async def put(url, workspace, id, input, token=None):
try:
unpacked = unpacker.unpack()
except:
except msgpack.OutOfData:
break
kind, msg = read_message(unpacked, id)
mid = str(uuid.uuid4())
if kind == "ge":
ge += 1
socket.put_kg_core(id, graph_embeddings=msg)
req = json.dumps({
"id": mid,
"workspace": workspace,
"service": "knowledge",
"request": {
"operation": "put-kg-core",
"workspace": workspace,
"id": id,
"graph-embeddings": msg
}
})
elif kind == "t":
t += 1
socket.put_kg_core(id, triples=msg)
req = json.dumps({
"id": mid,
"workspace": workspace,
"service": "knowledge",
"request": {
"operation": "put-kg-core",
"workspace": workspace,
"id": id,
"triples": msg
}
})
else:
raise RuntimeError("Unexpected message kind", kind)
await ws.send(req)
# Retry loop, wait for right response to come back
while True:
msg = await ws.recv()
msg = json.loads(msg)
if msg["id"] != mid:
continue
if "response" in msg:
if "error" in msg["response"]:
raise RuntimeError(msg["response"]["error"])
break
print(f"Put: {t} triple, {ge} GE messages.") print(f"Put: {t} triple, {ge} GE messages.")
await ws.close() finally:
socket.close()
def main():
@ -173,7 +122,6 @@ def main():
try:
asyncio.run(
put(
url=args.url,
workspace=args.workspace,
@ -181,7 +129,6 @@ def main():
input=args.input,
token=args.token,
)
)
except Exception as e:

View file

@ -1,5 +1,6 @@
from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings
from .. schema import DocumentEmbeddings
from .. knowledge import hash
from .. exceptions import RequestError
from .. tables.knowledge import KnowledgeTableStore
@ -157,6 +158,98 @@ class KnowledgeManager:
)
)
async def list_de_cores(self, request, respond, workspace):
ids = await self.table_store.list_de_cores(workspace)
await respond(
KnowledgeResponse(
error = None,
ids = ids,
eos = False,
triples = None,
graph_embeddings = None,
)
)
async def get_de_core(self, request, respond, workspace):
logger.info("Getting document embeddings core...")
async def publish_de(de):
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None,
document_embeddings = de,
)
)
await self.table_store.get_document_embeddings(
workspace,
request.id,
publish_de,
)
logger.debug("Document embeddings core retrieval complete")
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = True,
triples = None,
graph_embeddings = None,
)
)
async def put_de_core(self, request, respond, workspace):
if request.document_embeddings:
await self.table_store.add_document_embeddings(
workspace, request.document_embeddings
)
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None,
)
)
async def delete_de_core(self, request, respond, workspace):
logger.info("Deleting document embeddings core...")
await self.table_store.delete_document_embeddings(
workspace, request.id
)
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None,
)
)
async def load_de_core(self, request, respond, workspace):
if self.background_task is None:
self.background_task = asyncio.create_task(
self.core_loader()
)
await self.loader_queue.put((request, respond, workspace))
async def core_loader(self):
logger.info("Knowledge background processor running...")
@ -165,7 +258,7 @@ class KnowledgeManager:
logger.debug("Waiting for next load...")
request, respond, workspace = await self.loader_queue.get()
logger.info(f"Loading knowledge: {request.id}")
logger.info(f"Loading: {request.operation} {request.id}")
try:
@ -187,24 +280,13 @@ class KnowledgeManager:
if "interfaces" not in flow: if "interfaces" not in flow:
raise RuntimeError("No defined interfaces") raise RuntimeError("No defined interfaces")
if "triples-store" not in flow["interfaces"]: if request.operation == "load-de-core":
raise RuntimeError("Flow has no triples-store") await self._load_de_core(
request, respond, workspace, flow,
if "graph-embeddings-store" not in flow["interfaces"]:
raise RuntimeError("Flow has no graph-embeddings-store")
t_q = flow["interfaces"]["triples-store"]["flow"]
ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"]
# Got this far, it should all work
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None
)
else:
await self._load_kg_core(
request, respond, workspace, flow,
)
except Exception as e:
@ -223,14 +305,36 @@ class KnowledgeManager:
)
)
logger.debug("Knowledge processing done")
logger.debug("Starting knowledge loading process...")
continue
try:
async def _load_kg_core(self, request, respond, workspace, flow):
if "triples-store" not in flow["interfaces"]:
raise RuntimeError("Flow has no triples-store")
if "graph-embeddings-store" not in flow["interfaces"]:
raise RuntimeError("Flow has no graph-embeddings-store")
t_q = flow["interfaces"]["triples-store"]["flow"]
ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"]
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None
)
)
t_pub = None
ge_pub = None
try:
logger.debug(f"Triples queue: {t_q}") logger.debug(f"Triples queue: {t_q}")
logger.debug(f"Graph embeddings queue: {ge_q}") logger.debug(f"Graph embeddings queue: {ge_q}")
@ -249,7 +353,6 @@ class KnowledgeManager:
await ge_pub.start()
async def publish_triples(t):
# Override collection with request collection
if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'):
t.metadata.collection = request.collection or "default"
await t_pub.send(None, t)
@ -263,7 +366,6 @@ class KnowledgeManager:
)
async def publish_ge(g):
# Override collection with request collection
if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'):
g.metadata.collection = request.collection or "default"
await ge_pub.send(None, g)
@ -276,7 +378,7 @@ class KnowledgeManager:
publish_ge,
)
logger.debug("Knowledge loading completed")
logger.debug("Knowledge core loading completed")
except Exception as e:
@ -289,6 +391,59 @@ class KnowledgeManager:
if t_pub: await t_pub.stop()
if ge_pub: await ge_pub.stop()
logger.debug("Knowledge processing done")
async def _load_de_core(self, request, respond, workspace, flow):
continue
if "document-embeddings-store" not in flow["interfaces"]:
raise RuntimeError("Flow has no document-embeddings-store")
de_q = flow["interfaces"]["document-embeddings-store"]["flow"]
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None
)
)
de_pub = None
try:
logger.debug(f"Document embeddings queue: {de_q}")
de_pub = Publisher(
self.flow_config.pubsub, de_q,
schema=DocumentEmbeddings,
)
logger.debug("Starting publisher...")
await de_pub.start()
async def publish_de(de):
if hasattr(de, 'metadata') and hasattr(de.metadata, 'collection'):
de.metadata.collection = request.collection or "default"
await de_pub.send(None, de)
logger.debug("Publishing document embeddings...")
await self.table_store.get_document_embeddings(
workspace,
request.id,
publish_de,
)
logger.debug("Document embeddings core loading completed")
except Exception as e:
logger.error(f"Knowledge exception: {e}", exc_info=True)
finally:
logger.debug("Stopping publisher...")
if de_pub: await de_pub.stop()

View file

@ -187,6 +187,11 @@ class Processor(WorkspaceProcessor):
"put-kg-core": self.knowledge.put_kg_core, "put-kg-core": self.knowledge.put_kg_core,
"load-kg-core": self.knowledge.load_kg_core, "load-kg-core": self.knowledge.load_kg_core,
"unload-kg-core": self.knowledge.unload_kg_core, "unload-kg-core": self.knowledge.unload_kg_core,
"list-de-cores": self.knowledge.list_de_cores,
"get-de-core": self.knowledge.get_de_core,
"delete-de-core": self.knowledge.delete_de_core,
"put-de-core": self.knowledge.put_de_core,
"load-de-core": self.knowledge.load_de_core,
}
if v.operation not in impls:

View file

@ -457,6 +457,12 @@ for _op in ("put-kg-core", "delete-kg-core",
"load-kg-core", "unload-kg-core"): "load-kg-core", "unload-kg-core"):
_register_kind_op("knowledge", _op, "knowledge:write") _register_kind_op("knowledge", _op, "knowledge:write")
# knowledge: document-embeddings core service.
for _op in ("get-de-core", "list-de-cores"):
_register_kind_op("knowledge", _op, "knowledge:read")
for _op in ("put-de-core", "delete-de-core", "load-de-core"):
_register_kind_op("knowledge", _op, "knowledge:write")
# collection-management: workspace collection lifecycle.
_register_kind_op("collection-management", "list-collections", "collections:read")

View file

@ -1,6 +1,7 @@
from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings
from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings
from .. schema import DocumentEmbeddings, ChunkEmbeddings
from cassandra.cluster import Cluster
@ -217,6 +218,16 @@ class KnowledgeTableStore:
WHERE workspace = ? AND document_id = ?
""")
self.delete_document_embeddings_stmt = self.cassandra.prepare("""
DELETE FROM document_embeddings
WHERE workspace = ? AND document_id = ?
""")
self.list_de_cores_stmt = self.cassandra.prepare("""
SELECT DISTINCT workspace, document_id FROM document_embeddings
WHERE workspace = ?
""")
async def add_triples(self, workspace, m):
when = int(time.time() * 1000)
@ -338,6 +349,50 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
try:
await async_execute(
self.cassandra,
self.delete_document_embeddings_stmt,
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
async def delete_document_embeddings(self, workspace, document_id):
logger.debug("Delete document embeddings...")
try:
await async_execute(
self.cassandra,
self.delete_document_embeddings_stmt,
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
async def list_de_cores(self, workspace):
logger.debug("List DE cores...")
try:
rows = await async_execute(
self.cassandra,
self.list_de_cores_stmt,
(workspace,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
lst = [row[1] for row in rows]
logger.debug("Done")
return lst
async def get_triples(self, workspace, document_id, receiver):
logger.debug("Get triples...")
@ -417,3 +472,42 @@ class KnowledgeTableStore:
logger.debug("Done")
async def get_document_embeddings(self, workspace, document_id, receiver):
logger.debug("Get DE...")
try:
rows = await async_execute(
self.cassandra,
self.get_document_embeddings_stmt,
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
for row in rows:
if row[3]:
chunks = [
ChunkEmbeddings(
chunk_id=ch[0],
vector=ch[1],
)
for ch in row[3]
]
else:
chunks = []
await receiver(
DocumentEmbeddings(
metadata = Metadata(
id = document_id,
collection = "default",
),
chunks = chunks
)
)
logger.debug("Done")