CLI auth migration, document embeddings core lifecycle (#913)

Migrate get_kg_core and put_kg_core CLI tools to use Api/SocketClient
with first-frame auth (fixes broken raw websocket path). Fix wire
format field names (root/vector). Remove ~600 lines of dead raw
websocket code from invoke_graph_rag.py.

Add document embeddings core lifecycle to the knowledge service:
list/get/put/delete/load operations across schema, translator,
Cassandra table store, knowledge manager, gateway registry, REST API,
socket client, and CLI (tg-get-de-core, tg-put-de-core).

Fix delete_kg_core to also clean up document embeddings rows.
This commit is contained in:
cybermaggedon 2026-05-14 10:30:21 +01:00 committed by GitHub
parent dd974b0cac
commit f0ad282708
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 762 additions and 825 deletions

View file

@ -37,6 +37,7 @@ tg-dump-msgpack = "trustgraph.cli.dump_msgpack:main"
tg-dump-queues = "trustgraph.cli.dump_queues:main"
tg-monitor-prompts = "trustgraph.cli.monitor_prompts:main"
tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main"
tg-get-de-core = "trustgraph.cli.get_de_core:main"
tg-get-kg-core = "trustgraph.cli.get_kg_core:main"
tg-get-document-content = "trustgraph.cli.get_document_content:main"
tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main"
@ -77,6 +78,7 @@ tg-load-turtle = "trustgraph.cli.load_turtle:main"
tg-load-knowledge = "trustgraph.cli.load_knowledge:main"
tg-load-structured-data = "trustgraph.cli.load_structured_data:main"
tg-put-flow-blueprint = "trustgraph.cli.put_flow_blueprint:main"
tg-put-de-core = "trustgraph.cli.put_de_core:main"
tg-put-kg-core = "trustgraph.cli.put_kg_core:main"
tg-remove-library-document = "trustgraph.cli.remove_library_document:main"
tg-save-doc-embeds = "trustgraph.cli.save_doc_embeds:main"

View file

@ -0,0 +1,111 @@
"""
Uses the knowledge service to fetch a document embeddings core which is
saved to a local file in msgpack format.
"""
import argparse
import os
import msgpack
from trustgraph.api import Api
# API URL used when -u/--url is not given; taken from TRUSTGRAPH_URL when set.
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
# Authentication token taken from TRUSTGRAPH_TOKEN; None when the variable is unset.
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
# Workspace taken from TRUSTGRAPH_WORKSPACE, falling back to "default".
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def write_de(f, data):
    """
    Append one document-embeddings message to *f* as a msgpack record.

    The record written is the pair ``("de", body)``.  ``body`` holds the
    message metadata under ``"m"`` (``i`` = id, ``m`` = root,
    ``c`` = collection) and the chunk list under ``"c"`` (each chunk as
    ``i`` = chunk id, ``v`` = vector).

    *f* must be a binary file-like object; *data* is the
    ``document-embeddings`` payload with ``metadata`` and ``chunks`` keys.
    """
    meta = data["metadata"]

    header = {
        "i": meta["id"],
        "m": meta["root"],
        "c": meta["collection"],
    }

    packed_chunks = []
    for chunk in data["chunks"]:
        packed_chunks.append({"i": chunk["chunk_id"], "v": chunk["vector"]})

    record = ("de", {"m": header, "c": packed_chunks})

    f.write(msgpack.packb(record, use_bin_type=True))
def fetch(url, workspace, id, output, token=None):
    """
    Download a document embeddings core and save it to *output*.

    A socket client is obtained from ``Api``; every response yielded by
    ``get_de_core(id)`` that carries a ``document-embeddings`` payload is
    written to the output file through ``write_de``.  The number of
    messages written is printed on completion, and the socket is closed
    whether or not the transfer succeeds.
    """
    client = Api(url=url, token=token, workspace=workspace).socket()

    try:
        count = 0

        with open(output, "wb") as out:
            for reply in client.get_de_core(id):
                # Responses without an embeddings payload are ignored.
                if "document-embeddings" not in reply:
                    continue
                count += 1
                write_de(out, reply["document-embeddings"])

        print(f"Got: {count} document embeddings messages.")
    finally:
        client.close()
def main():
    """
    Command-line entry point for ``tg-get-de-core``.

    Reads the API URL, workspace, core ID, output path and token from the
    command line and hands them to ``fetch``.  An exception raised while
    fetching is caught and printed instead of being propagated.
    """
    parser = argparse.ArgumentParser(
        prog='tg-get-de-core',
        description=__doc__,
    )

    # (option strings, keyword settings) for each command-line option, in
    # the order they are registered with the parser.
    option_table = (
        (('-u', '--url'), {
            'default': default_url,
            'help': f'API URL (default: {default_url})',
        }),
        (('-w', '--workspace'), {
            'default': default_workspace,
            'help': f'Workspace (default: {default_workspace})',
        }),
        (('--id', '--identifier'), {
            'required': True,
            'help': 'Document embeddings core ID',
        }),
        (('-o', '--output'), {
            'required': True,
            'help': 'Output file',
        }),
        (('-t', '--token'), {
            'default': default_token,
            'help': 'Authentication token (default: $TRUSTGRAPH_TOKEN)',
        }),
    )
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    args = parser.parse_args()

    try:
        fetch(
            url=args.url,
            workspace=args.workspace,
            id=args.id,
            output=args.output,
            token=args.token,
        )
    except Exception as err:
        print("Exception:", err, flush=True)


if __name__ == "__main__":
    main()

View file

@ -5,13 +5,11 @@ to a local file in msgpack format.
import argparse
import os
import uuid
import asyncio
import json
from websockets.asyncio.client import connect
import msgpack
default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
@ -21,7 +19,7 @@ def write_triple(f, data):
{
"m": {
"i": data["metadata"]["id"],
"m": data["metadata"]["metadata"],
"m": data["metadata"]["root"],
"c": data["metadata"]["collection"],
},
"t": data["triples"],
@ -35,13 +33,13 @@ def write_ge(f, data):
{
"m": {
"i": data["metadata"]["id"],
"m": data["metadata"]["metadata"],
"m": data["metadata"]["root"],
"c": data["metadata"]["collection"],
},
"e": [
{
"e": ent["entity"],
"v": ent["vectors"],
"v": ent["vector"],
}
for ent in data["entities"]
]
@ -49,54 +47,18 @@ def write_ge(f, data):
)
f.write(msgpack.packb(msg, use_bin_type=True))
async def fetch(url, workspace, id, output, token=None):
def fetch(url, workspace, id, output, token=None):
if not url.endswith("/"):
url += "/"
url = url + "api/v1/socket"
if token:
url = f"{url}?token={token}"
mid = str(uuid.uuid4())
async with connect(url) as ws:
req = json.dumps({
"id": mid,
"workspace": workspace,
"service": "knowledge",
"request": {
"operation": "get-kg-core",
"workspace": workspace,
"id": id,
}
})
await ws.send(req)
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
try:
ge = 0
t = 0
with open(output, "wb") as f:
while True:
msg = await ws.recv()
obj = json.loads(msg)
if "response" not in obj:
raise RuntimeError("No response?")
response = obj["response"]
if "error" in response:
raise RuntimeError(obj["error"])
if "eos" in response:
if response["eos"]: break
for response in socket.get_kg_core(id):
if "triples" in response:
t += 1
@ -108,7 +70,8 @@ async def fetch(url, workspace, id, output, token=None):
print(f"Got: {t} triple, {ge} GE messages.")
await ws.close()
finally:
socket.close()
def main():
@ -151,14 +114,12 @@ def main():
try:
asyncio.run(
fetch(
url=args.url,
workspace=args.workspace,
id=args.id,
output=args.output,
token=args.token,
)
fetch(
url=args.url,
workspace=args.workspace,
id=args.id,
output=args.output,
token=args.token,
)
except Exception as e:

View file

@ -3,11 +3,8 @@ Uses the GraphRAG service to answer a question
"""
import argparse
import json
import os
import sys
import websockets
import asyncio
from trustgraph.api import (
Api,
ExplainabilityClient,
@ -31,607 +28,6 @@ default_max_path_length = 2
default_edge_score_limit = 30
default_edge_limit = 25
# Provenance predicates
# Base namespace for the TrustGraph predicates below.
TG = "https://trustgraph.ai/ns/"
TG_QUERY = TG + "query"
TG_CONCEPT = TG + "concept"
TG_ENTITY = TG + "entity"
TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_DOCUMENT = TG + "document"
TG_CONTAINS = TG + "contains"
# Base namespace for the W3C PROV predicates below.
PROV = "http://www.w3.org/ns/prov#"
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
# Predicate queried when resolving a display label for an IRI.
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
def _get_event_type(prov_id):
    """
    Classify a provenance ID by the event keyword it contains.

    The keywords are tried in a fixed order and the first one found as a
    substring of *prov_id* is returned; ``"provenance"`` is returned when
    none of them occurs.
    """
    known_events = ("question", "grounding", "exploration", "focus", "synthesis")
    for event in known_events:
        if event in prov_id:
            return event
    return "provenance"
def _format_provenance_details(event_type, triples):
    """
    Build the display lines describing one provenance event.

    *triples* is an iterable of ``(s, p, o)`` tuples; which predicates are
    reported depends on *event_type*:

    - ``question``: the query text and the start time
    - ``grounding``: the number of concepts followed by each concept
    - ``exploration``: the explored edge count
    - ``focus``: how many edge selections were made (the edge details
      themselves are fetched separately by the caller)
    - ``synthesis``: the document reference

    Any other event type yields an empty list.
    """
    lines = []

    if event_type == "question":
        for _, pred, obj in triples:
            if pred == TG_QUERY:
                lines.append(f" Query: {obj}")
            elif pred == PROV_STARTED_AT_TIME:
                lines.append(f" Time: {obj}")

    elif event_type == "grounding":
        concepts = [obj for _, pred, obj in triples if pred == TG_CONCEPT]
        if concepts:
            lines.append(f" Concepts: {len(concepts)}")
            lines.extend(f" - {concept}" for concept in concepts)

    elif event_type == "exploration":
        lines.extend(
            f" Edges explored: {obj}"
            for _, pred, obj in triples
            if pred == TG_EDGE_COUNT
        )

    elif event_type == "focus":
        # Only the number of selections is reported here.
        edge_sel_uris = [obj for _, pred, obj in triples if pred == TG_SELECTED_EDGE]
        if edge_sel_uris:
            lines.append(f" Focused on {len(edge_sel_uris)} edge(s)")

    elif event_type == "synthesis":
        lines.extend(
            f" Document: {obj}"
            for _, pred, obj in triples
            if pred == TG_DOCUMENT
        )

    return lines
async def _query_triples_once(ws_url, flow_id, prov_id, collection, graph=None, debug=False):
    """
    Fetch the triples whose subject is *prov_id* (single attempt, no retry).

    Opens a websocket to *ws_url*, sends one ``triples`` request limited to
    100 results (restricted to the named graph *graph* when one is given)
    and gathers replies until one is flagged complete or reports an error.

    Returns a list of ``(s, p, o)`` tuples.  ``s`` and ``p`` are the term's
    IRI, or its value when no IRI is present.  ``o`` is resolved the same
    way, except for quoted triples, which are returned as a dict with
    ``s``/``p``/``o`` keys.  Exceptions are swallowed (reported on stderr
    only when *debug* is set); whatever was collected so far is returned.
    """

    def term_value(term):
        # Compact term encoding: "i" holds an IRI, "v" a plain value.
        return term.get("i", term.get("v", ""))

    query = {
        "s": {"t": "i", "i": prov_id},
        "collection": collection,
        "limit": 100,
    }

    # Add graph filter if specified (for named graph queries)
    if graph is not None:
        query["g"] = graph

    request = {
        "id": "triples-request",
        "service": "triples",
        "flow": flow_id,
        "request": query,
    }

    if debug:
        print(f" [debug] querying triples for s={prov_id}", file=sys.stderr)

    triples = []

    try:
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as conn:
            await conn.send(json.dumps(request))

            async for raw in conn:
                reply = json.loads(raw)

                if debug:
                    print(f" [debug] response: {json.dumps(reply)[:200]}", file=sys.stderr)

                # Skip anything that does not answer this request.
                if reply.get("id") != "triples-request":
                    continue

                if "error" in reply:
                    if debug:
                        print(f" [debug] error: {reply['error']}", file=sys.stderr)
                    break

                if "response" not in reply:
                    continue

                # Response format: {"response": [triples...]}
                payload = reply["response"]

                for item in payload.get("response", []):
                    subj = term_value(item.get("s", {}))
                    pred = term_value(item.get("p", {}))

                    obj_term = item.get("o", {})
                    if obj_term.get("t") == "t":
                        # Quoted triple: unpack the nested s, p, o terms.
                        quoted = obj_term.get("tr", {})
                        obj = {
                            "s": quoted.get("s", {}).get("i", ""),
                            "p": quoted.get("p", {}).get("i", ""),
                            "o": term_value(quoted.get("o", {})),
                        }
                    else:
                        obj = term_value(obj_term)

                    triples.append((subj, pred, obj))

                if payload.get("complete") or reply.get("complete"):
                    break

    except Exception as e:
        if debug:
            print(f" [debug] exception: {e}", file=sys.stderr)

    if debug:
        print(f" [debug] got {len(triples)} triples", file=sys.stderr)

    return triples
async def _query_triples(ws_url, flow_id, prov_id, collection, graph=None, max_retries=5, retry_delay=0.2, debug=False):
    """
    Query the triples of a provenance node, retrying while none come back.

    The node's triples may not have been stored yet when the first query
    is made, so an empty result is retried up to *max_retries* times with
    a pause of *retry_delay* seconds between attempts.  Returns the first
    non-empty result, or an empty list once the attempts are used up.
    """
    for attempt in range(max_retries):
        found = await _query_triples_once(
            ws_url, flow_id, prov_id, collection, graph=graph, debug=debug
        )
        if found:
            return found

        # Nothing to wait for after the final attempt.
        if attempt >= max_retries - 1:
            continue

        if debug:
            print(f" [debug] retry {attempt + 1}/{max_retries}...", file=sys.stderr)
        await asyncio.sleep(retry_delay)

    return []
async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, collection, debug=False):
    """
    Find the sources an edge (s, p, o) of the knowledge graph derives from.

    Two rounds of triples queries are made, each over its own websocket
    connection:

    1. ``?subgraph tg:contains <<s p o>>`` collects the subject of every
       triple whose object is the quoted edge.
    2. For each subject found, ``<subject> prov:wasDerivedFrom ?source``
       collects the source IRIs.

    An error reply or an exception ends the affected query early; these
    are reported on stderr only when *debug* is set.

    Returns list of source URIs (chunks, pages, documents).
    """
    # The edge object is sent as an IRI when it looks like one, otherwise
    # as a literal.
    if edge_o.startswith(("http", "urn:")):
        object_term = {"t": "i", "i": edge_o}
    else:
        object_term = {"t": "l", "v": edge_o}

    # Query for subgraphs that contain this edge: ?subgraph tg:contains <<s p o>>
    contains_request = {
        "id": "edge-prov-request",
        "service": "triples",
        "flow": flow_id,
        "request": {
            "p": {"t": "i", "i": TG_CONTAINS},
            "o": {
                "t": "t",  # Quoted triple type
                "tr": {
                    "s": {"t": "i", "i": edge_s},
                    "p": {"t": "i", "i": edge_p},
                    "o": object_term,
                },
            },
            "collection": collection,
            "limit": 10,
        },
    }

    if debug:
        print(f" [debug] querying edge provenance for ({edge_s}, {edge_p}, {edge_o})", file=sys.stderr)

    stmt_uris = []

    try:
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as conn:
            await conn.send(json.dumps(contains_request))

            async for raw in conn:
                reply = json.loads(raw)

                if reply.get("id") != "edge-prov-request":
                    continue

                if "error" in reply:
                    if debug:
                        print(f" [debug] error: {reply['error']}", file=sys.stderr)
                    break

                if "response" not in reply:
                    continue

                payload = reply["response"]
                for item in payload.get("response", []):
                    subject = item.get("s", {}).get("i", "")
                    if subject:
                        stmt_uris.append(subject)

                if payload.get("complete") or reply.get("complete"):
                    break

    except Exception as e:
        if debug:
            print(f" [debug] exception querying edge provenance: {e}", file=sys.stderr)

    if debug:
        print(f" [debug] found {len(stmt_uris)} reifying statements", file=sys.stderr)

    # For each statement, query wasDerivedFrom to find sources
    sources = []

    for stmt_uri in stmt_uris:
        # Query: stmt_uri prov:wasDerivedFrom ?source
        derived_request = {
            "id": "derived-from-request",
            "service": "triples",
            "flow": flow_id,
            "request": {
                "s": {"t": "i", "i": stmt_uri},
                "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM},
                "collection": collection,
                "limit": 10,
            },
        }

        try:
            async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as conn:
                await conn.send(json.dumps(derived_request))

                async for raw in conn:
                    reply = json.loads(raw)

                    if reply.get("id") != "derived-from-request":
                        continue

                    if "error" in reply:
                        break

                    if "response" not in reply:
                        continue

                    payload = reply["response"]
                    for item in payload.get("response", []):
                        source = item.get("o", {}).get("i", "")
                        if source:
                            sources.append(source)

                    if payload.get("complete") or reply.get("complete"):
                        break

        except Exception as e:
            if debug:
                print(f" [debug] exception querying wasDerivedFrom: {e}", file=sys.stderr)

    if debug:
        print(f" [debug] found {len(sources)} source(s): {sources}", file=sys.stderr)

    return sources
async def _query_derived_from(ws_url, flow_id, uri, collection, debug=False):
    """
    Look up the ``prov:wasDerivedFrom`` parent of *uri*.

    Sends a single triples query (limit 1) over a new websocket connection
    and returns the ``"i"`` field of the first matching triple's object.
    Returns None when no triple comes back, when the server replies with
    an error, or when an exception occurs (reported on stderr only when
    *debug* is set).
    """
    request = {
        "id": "parent-request",
        "service": "triples",
        "flow": flow_id,
        "request": {
            "s": {"t": "i", "i": uri},
            "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM},
            "collection": collection,
            "limit": 1,
        },
    }

    try:
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as conn:
            await conn.send(json.dumps(request))

            async for raw in conn:
                reply = json.loads(raw)

                if reply.get("id") != "parent-request":
                    continue

                if "error" in reply:
                    break

                if "response" not in reply:
                    continue

                payload = reply["response"]
                matches = payload.get("response", [])
                if matches:
                    return matches[0].get("o", {}).get("i", None)

                if payload.get("complete") or reply.get("complete"):
                    break

    except Exception as e:
        if debug:
            print(f" [debug] exception querying parent: {e}", file=sys.stderr)

    return None
async def _trace_provenance_chain(ws_url, flow_id, source_uri, collection, label_cache, debug=False):
    """
    Walk the provenance chain upwards from *source_uri*.

    Each step resolves the label of the current URI, records it, then
    moves to its ``prov:wasDerivedFrom`` parent.  The walk stops when a
    URI has no parent, is its own parent, or after 10 steps.

    Returns a list of (uri, label) tuples from leaf to root.
    """
    chain = []
    node = source_uri
    steps_left = 10  # Prevent infinite loops

    while steps_left > 0 and node:
        steps_left -= 1

        # Get label for current entity
        label = await _query_label(ws_url, flow_id, node, collection, label_cache, debug)
        chain.append((node, label))

        # Get parent
        parent = await _query_derived_from(ws_url, flow_id, node, collection, debug)
        if not parent or parent == node:
            break

        node = parent

    return chain
def _format_provenance_chain(chain):
    """
    Render a provenance chain as a single string.

    *chain* is [(uri, label), ...] from leaf to root; only the labels are
    used, in that order.  An empty chain gives an empty string.
    """
    if not chain:
        return ""

    # NOTE(review): the labels are joined with an empty separator, so they
    # run together in the output - confirm a delimiter was not intended.
    return "".join([label for _, label in chain])
def _is_iri(value):
    """
    Tell whether *value* looks like an IRI.

    Only strings starting with ``http://``, ``https://`` or ``urn:``
    qualify; anything that is not a string is rejected.
    """
    iri_prefixes = ("http://", "https://", "urn:")
    return isinstance(value, str) and value.startswith(iri_prefixes)
async def _query_label(ws_url, flow_id, iri, collection, label_cache, debug=False):
    """
    Resolve the ``rdfs:label`` of an IRI.

    Values that are not IRIs are returned unchanged without any query.
    Results are stored in *label_cache* (a dict keyed by IRI), which is
    consulted first so each IRI is queried at most once.  When no label is
    found - or the query errors or raises - the IRI itself is cached and
    returned.
    """
    if not _is_iri(iri):
        return iri

    # Check cache first
    if iri in label_cache:
        return label_cache[iri]

    request = {
        "id": "label-request",
        "service": "triples",
        "flow": flow_id,
        "request": {
            "s": {"t": "i", "i": iri},
            "p": {"t": "i", "i": RDFS_LABEL},
            "collection": collection,
            "limit": 1,
        },
    }

    resolved = iri  # Default to IRI if no label found

    try:
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as conn:
            await conn.send(json.dumps(request))

            async for raw in conn:
                reply = json.loads(raw)

                if reply.get("id") != "label-request":
                    continue

                if "error" in reply:
                    break

                if "response" not in reply:
                    continue

                payload = reply["response"]
                matches = payload.get("response", [])
                if matches:
                    # Prefer the literal value, then the IRI form.
                    obj_term = matches[0].get("o", {})
                    resolved = obj_term.get("v", obj_term.get("i", iri))

                if payload.get("complete") or reply.get("complete"):
                    break

    except Exception as e:
        if debug:
            print(f" [debug] exception querying label for {iri}: {e}", file=sys.stderr)

    # Cache the result
    label_cache[iri] = resolved
    return resolved
async def _resolve_edge_labels(ws_url, flow_id, edge_triple, collection, label_cache, debug=False):
    """
    Resolve display labels for the three components of an edge triple.

    *edge_triple* is a dict with ``s``, ``p`` and ``o`` entries; a missing
    entry is treated as ``"?"``.  Each component is passed through
    ``_query_label`` in subject, predicate, object order.

    Returns (s_label, p_label, o_label).
    """
    components = [edge_triple.get(key, "?") for key in ("s", "p", "o")]

    labels = []
    for component in components:
        labels.append(
            await _query_label(ws_url, flow_id, component, collection, label_cache, debug)
        )

    return tuple(labels)
async def _question_explainable(
url, flow_id, question, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, token=None, debug=False
):
"""
Execute graph RAG with explainability - shows provenance events with details.

Builds the websocket URL from `url` (http -> ws, https -> wss, anything
else is prefixed with ws://), appends /api/v1/socket and, when a token is
given, a ?token= query parameter.  A streaming graph-rag request is sent
and replies are processed until an error is reported or a reply carries
end_of_session.

Response chunks are printed to stdout.  "explain" messages are printed to
stderr together with details looked up through follow-up triples queries:
seed entity labels for exploration events, and for focus events the
selected edges, their reasoning and the provenance chain of each source.
"""
# Convert HTTP URL to WebSocket URL
if url.startswith("http://"):
ws_url = url.replace("http://", "ws://", 1)
elif url.startswith("https://"):
ws_url = url.replace("https://", "wss://", 1)
else:
ws_url = f"ws://{url}"
ws_url = f"{ws_url.rstrip('/')}/api/v1/socket"
# NOTE(review): the token is appended to the URL as-is, without URL-encoding.
if token:
ws_url = f"{ws_url}?token={token}"
# Cache for label lookups to avoid repeated queries
label_cache = {}
request = {
"id": "cli-request",
"service": "graph-rag",
"flow": flow_id,
"request": {
"query": question,
"collection": collection,
"entity-limit": entity_limit,
"triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size,
"max-path-length": max_path_length,
"streaming": True
}
}
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=300) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
# Ignore messages that do not answer this request
if response.get("id") != "cli-request":
continue
if "error" in response:
print(f"\nError: {response['error']}", file=sys.stderr)
break
if "response" in response:
resp = response["response"]
# Check for errors in response
if "error" in resp and resp["error"]:
err = resp["error"]
print(f"\nError: {err.get('message', 'Unknown error')}", file=sys.stderr)
break
message_type = resp.get("message_type", "")
if debug:
print(f" [debug] message_type={message_type}, keys={list(resp.keys())}", file=sys.stderr)
if message_type == "explain":
# Display explain event with details
explain_id = resp.get("explain_id", "")
explain_graph = resp.get("explain_graph") # Named graph (e.g., urn:graph:retrieval)
if explain_id:
event_type = _get_event_type(explain_id)
print(f"\n [{event_type}] {explain_id}", file=sys.stderr)
# Query triples for this explain node (using named graph filter)
triples = await _query_triples(
ws_url, flow_id, explain_id, collection, graph=explain_graph, debug=debug
)
# Format and display details
details = _format_provenance_details(event_type, triples)
for line in details:
print(line, file=sys.stderr)
# For exploration events, resolve entity labels
if event_type == "exploration":
entity_iris = [o for s, p, o in triples if p == TG_ENTITY]
if entity_iris:
print(f" Seed entities: {len(entity_iris)}", file=sys.stderr)
for iri in entity_iris:
label = await _query_label(
ws_url, flow_id, iri, collection,
label_cache, debug=debug
)
print(f" - {label}", file=sys.stderr)
# For focus events, query each edge selection for details
if event_type == "focus":
for s, p, o in triples:
if debug:
print(f" [debug] triple: p={p}, o={o}, o_type={type(o).__name__}", file=sys.stderr)
if p == TG_SELECTED_EDGE and isinstance(o, str):
if debug:
print(f" [debug] querying edge selection: {o}", file=sys.stderr)
# Query the edge selection entity (using named graph filter)
edge_triples = await _query_triples(
ws_url, flow_id, o, collection, graph=explain_graph, debug=debug
)
if debug:
print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr)
# Extract edge and reasoning
edge_triple = None # Store the actual triple for provenance lookup
reasoning = None
for es, ep, eo in edge_triples:
if debug:
print(f" [debug] edge triple: ep={ep}, eo={eo}", file=sys.stderr)
if ep == TG_EDGE and isinstance(eo, dict):
# eo is a quoted triple dict
edge_triple = eo
elif ep == TG_REASONING:
reasoning = eo
if edge_triple:
# Resolve labels for edge components
s_label, p_label, o_label = await _resolve_edge_labels(
ws_url, flow_id, edge_triple, collection,
label_cache, debug=debug
)
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
if reasoning:
# Reasoning longer than 100 characters is truncated for display
r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning
print(f" Reason: {r_short}", file=sys.stderr)
# Trace edge provenance in the workspace collection (not explainability)
if edge_triple:
sources = await _query_edge_provenance(
ws_url, flow_id,
edge_triple.get("s", ""),
edge_triple.get("p", ""),
edge_triple.get("o", ""),
collection, # Use the query collection, not explainability
debug=debug
)
if sources:
for src in sources:
# Trace full chain from source to root document
chain = await _trace_provenance_chain(
ws_url, flow_id, src, collection,
label_cache, debug=debug
)
chain_str = _format_provenance_chain(chain)
print(f" Source: {chain_str}", file=sys.stderr)
elif message_type == "chunk" or not message_type:
# Display response chunk
chunk = resp.get("response", "")
if chunk:
print(chunk, end="", flush=True)
# Check if session is complete
if resp.get("end_of_session"):
break
print() # Final newline
def _question_explainable_api(
url, flow_id, question_text, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, edge_score_limit=30,

View file

@ -0,0 +1,119 @@
"""
Puts a document embeddings core into the knowledge manager via the API
socket.
"""
import argparse
import os
import msgpack
from trustgraph.api import Api
# API URL used when -u/--url is not given; taken from TRUSTGRAPH_URL when set.
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
# Authentication token taken from TRUSTGRAPH_TOKEN; None when the variable is unset.
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
# Workspace taken from TRUSTGRAPH_WORKSPACE, falling back to "default".
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def read_message(unpacked, id):
    """
    Convert one unpacked msgpack record into a document-embeddings message.

    *unpacked* is a ``(kind, body)`` pair as stored in a core file.  Only
    the ``"de"`` kind is accepted; its short keys are expanded back to the
    wire field names (``root``, ``chunk_id``, ``vector``).  The message ID
    is set to *id*, and the collection is always ``"default"`` - the ID and
    collection recorded in the file are not used.

    Raises RuntimeError for any other record kind.
    """
    if unpacked[0] != "de":
        raise RuntimeError("Unexpected message type", unpacked[0])

    body = unpacked[1]

    metadata = {
        "id": id,
        "root": body["m"]["m"],
        "collection": "default",
    }

    chunks = []
    for entry in body["c"]:
        chunks.append({"chunk_id": entry["i"], "vector": entry["v"]})

    return {"metadata": metadata, "chunks": chunks}
def put(url, workspace, id, input, token=None):
    """
    Upload a document embeddings core from the msgpack file *input*.

    Records are unpacked from the file one at a time, converted with
    ``read_message`` and sent through ``put_de_core`` on a socket client
    obtained from ``Api``.  The number of messages sent is printed on
    completion, and the socket is closed whether or not the upload
    succeeds.
    """
    client = Api(url=url, token=token, workspace=workspace).socket()

    try:
        sent = 0

        with open(input, "rb") as source:
            reader = msgpack.Unpacker(source, raw=False)

            while True:
                try:
                    record = reader.unpack()
                except msgpack.OutOfData:
                    # The unpacker has no further complete record.
                    break

                message = read_message(record, id)
                sent += 1
                client.put_de_core(id, document_embeddings=message)

        print(f"Put: {sent} document embeddings messages.")
    finally:
        client.close()
def main():
    """
    Command-line entry point for ``tg-put-de-core``.

    Reads the API URL, workspace, core ID, input path and token from the
    command line and hands them to ``put``.  An exception raised while
    uploading is caught and printed instead of being propagated.
    """
    parser = argparse.ArgumentParser(
        prog='tg-put-de-core',
        description=__doc__,
    )

    # (option strings, keyword settings) for each command-line option, in
    # the order they are registered with the parser.
    option_table = (
        (('-u', '--url'), {
            'default': default_url,
            'help': f'API URL (default: {default_url})',
        }),
        (('-w', '--workspace'), {
            'default': default_workspace,
            'help': f'Workspace (default: {default_workspace})',
        }),
        (('--id', '--identifier'), {
            'required': True,
            'help': 'Document embeddings core ID',
        }),
        (('-i', '--input'), {
            'required': True,
            'help': 'Input file',
        }),
        (('-t', '--token'), {
            'default': default_token,
            'help': 'Authentication token (default: $TRUSTGRAPH_TOKEN)',
        }),
    )
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    args = parser.parse_args()

    try:
        put(
            url=args.url,
            workspace=args.workspace,
            id=args.id,
            input=args.input,
            token=args.token,
        )
    except Exception as err:
        print("Exception:", err, flush=True)


if __name__ == "__main__":
    main()

View file

@ -4,13 +4,11 @@ Puts a knowledge core into the knowledge manager via the API socket.
import argparse
import os
import uuid
import asyncio
import json
from websockets.asyncio.client import connect
import msgpack
default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
@ -21,13 +19,13 @@ def read_message(unpacked, id):
return "ge", {
"metadata": {
"id": id,
"metadata": msg["m"]["m"],
"collection": "default", # Not used?
"root": msg["m"]["m"],
"collection": "default",
},
"entities": [
{
"entity": ent["e"],
"vectors": ent["v"],
"vector": ent["v"],
}
for ent in msg["e"]
],
@ -37,26 +35,20 @@ def read_message(unpacked, id):
return "t", {
"metadata": {
"id": id,
"metadata": msg["m"]["m"],
"collection": "default", # Not used by receiver?
"root": msg["m"]["m"],
"collection": "default",
},
"triples": msg["t"],
}
else:
raise RuntimeError("Unpacked unexpected messsage type", unpacked[0])
async def put(url, workspace, id, input, token=None):
def put(url, workspace, id, input, token=None):
if not url.endswith("/"):
url += "/"
url = url + "api/v1/socket"
if token:
url = f"{url}?token={token}"
async with connect(url) as ws:
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
try:
ge = 0
t = 0
@ -68,69 +60,26 @@ async def put(url, workspace, id, input, token=None):
try:
unpacked = unpacker.unpack()
except:
except msgpack.OutOfData:
break
kind, msg = read_message(unpacked, id)
mid = str(uuid.uuid4())
if kind == "ge":
ge += 1
req = json.dumps({
"id": mid,
"workspace": workspace,
"service": "knowledge",
"request": {
"operation": "put-kg-core",
"workspace": workspace,
"id": id,
"graph-embeddings": msg
}
})
socket.put_kg_core(id, graph_embeddings=msg)
elif kind == "t":
t += 1
req = json.dumps({
"id": mid,
"workspace": workspace,
"service": "knowledge",
"request": {
"operation": "put-kg-core",
"workspace": workspace,
"id": id,
"triples": msg
}
})
socket.put_kg_core(id, triples=msg)
else:
raise RuntimeError("Unexpected message kind", kind)
await ws.send(req)
# Retry loop, wait for right response to come back
while True:
msg = await ws.recv()
msg = json.loads(msg)
if msg["id"] != mid:
continue
if "response" in msg:
if "error" in msg["response"]:
raise RuntimeError(msg["response"]["error"])
break
print(f"Put: {t} triple, {ge} GE messages.")
await ws.close()
finally:
socket.close()
def main():
@ -173,14 +122,12 @@ def main():
try:
asyncio.run(
put(
url=args.url,
workspace=args.workspace,
id=args.id,
input=args.input,
token=args.token,
)
put(
url=args.url,
workspace=args.workspace,
id=args.id,
input=args.input,
token=args.token,
)
except Exception as e: