mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-16 10:55:13 +02:00
CLI auth migration, document embeddings core lifecycle (#913)
Migrate get_kg_core and put_kg_core CLI tools to use Api/SocketClient with first-frame auth (fixes broken raw websocket path). Fix wire format field names (root/vector). Remove ~600 lines of dead raw websocket code from invoke_graph_rag.py. Add document embeddings core lifecycle to the knowledge service: list/get/put/delete/load operations across schema, translator, Cassandra table store, knowledge manager, gateway registry, REST API, socket client, and CLI (tg-get-de-core, tg-put-de-core). Fix delete_kg_core to also clean up document embeddings rows.
This commit is contained in:
parent
dd974b0cac
commit
f0ad282708
14 changed files with 762 additions and 825 deletions
|
|
@ -132,3 +132,34 @@ class Knowledge:
|
||||||
|
|
||||||
self.request(request = input)
|
self.request(request = input)
|
||||||
|
|
||||||
|
def list_de_cores(self):
|
||||||
|
|
||||||
|
input = {
|
||||||
|
"operation": "list-de-cores",
|
||||||
|
"workspace": self.api.workspace,
|
||||||
|
}
|
||||||
|
|
||||||
|
return self.request(request = input)["ids"]
|
||||||
|
|
||||||
|
def delete_de_core(self, id):
|
||||||
|
|
||||||
|
input = {
|
||||||
|
"operation": "delete-de-core",
|
||||||
|
"workspace": self.api.workspace,
|
||||||
|
"id": id,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.request(request = input)
|
||||||
|
|
||||||
|
def load_de_core(self, id, flow="default", collection="default"):
|
||||||
|
|
||||||
|
input = {
|
||||||
|
"operation": "load-de-core",
|
||||||
|
"workspace": self.api.workspace,
|
||||||
|
"id": id,
|
||||||
|
"flow": flow,
|
||||||
|
"collection": collection,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.request(request = input)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -491,6 +491,58 @@ class SocketClient:
|
||||||
triples=raw_triples,
|
triples=raw_triples,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_kg_core(self, id: str) -> Iterator[Dict[str, Any]]:
|
||||||
|
request = {
|
||||||
|
"operation": "get-kg-core",
|
||||||
|
"workspace": self.workspace,
|
||||||
|
"id": id,
|
||||||
|
}
|
||||||
|
for response in self._send_request_sync(
|
||||||
|
"knowledge", None, request, streaming_raw=True,
|
||||||
|
):
|
||||||
|
if response.get("eos"):
|
||||||
|
break
|
||||||
|
yield response
|
||||||
|
|
||||||
|
def put_kg_core(
|
||||||
|
self, id: str, triples=None, graph_embeddings=None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
request = {
|
||||||
|
"operation": "put-kg-core",
|
||||||
|
"workspace": self.workspace,
|
||||||
|
"id": id,
|
||||||
|
}
|
||||||
|
if triples is not None:
|
||||||
|
request["triples"] = triples
|
||||||
|
if graph_embeddings is not None:
|
||||||
|
request["graph-embeddings"] = graph_embeddings
|
||||||
|
return self._send_request_sync("knowledge", None, request)
|
||||||
|
|
||||||
|
def get_de_core(self, id: str) -> Iterator[Dict[str, Any]]:
|
||||||
|
request = {
|
||||||
|
"operation": "get-de-core",
|
||||||
|
"workspace": self.workspace,
|
||||||
|
"id": id,
|
||||||
|
}
|
||||||
|
for response in self._send_request_sync(
|
||||||
|
"knowledge", None, request, streaming_raw=True,
|
||||||
|
):
|
||||||
|
if response.get("eos"):
|
||||||
|
break
|
||||||
|
yield response
|
||||||
|
|
||||||
|
def put_de_core(
|
||||||
|
self, id: str, document_embeddings=None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
request = {
|
||||||
|
"operation": "put-de-core",
|
||||||
|
"workspace": self.workspace,
|
||||||
|
"id": id,
|
||||||
|
}
|
||||||
|
if document_embeddings is not None:
|
||||||
|
request["document-embeddings"] = document_embeddings
|
||||||
|
return self._send_request_sync("knowledge", None, request)
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
"""Close the persistent WebSocket connection."""
|
"""Close the persistent WebSocket connection."""
|
||||||
if self._loop and not self._loop.is_closed():
|
if self._loop and not self._loop.is_closed():
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import Dict, Any, Tuple, Optional
|
from typing import Dict, Any, Tuple, Optional
|
||||||
from ...schema import (
|
from ...schema import (
|
||||||
KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings,
|
KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings,
|
||||||
|
DocumentEmbeddings, ChunkEmbeddings,
|
||||||
Metadata, EntityEmbeddings
|
Metadata, EntityEmbeddings
|
||||||
)
|
)
|
||||||
from .base import MessageTranslator
|
from .base import MessageTranslator
|
||||||
|
|
@ -43,6 +44,23 @@ class KnowledgeRequestTranslator(MessageTranslator):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
document_embeddings = None
|
||||||
|
if "document-embeddings" in data:
|
||||||
|
document_embeddings = DocumentEmbeddings(
|
||||||
|
metadata=Metadata(
|
||||||
|
id=data["document-embeddings"]["metadata"]["id"],
|
||||||
|
root=data["document-embeddings"]["metadata"].get("root", ""),
|
||||||
|
collection=data["document-embeddings"]["metadata"]["collection"]
|
||||||
|
),
|
||||||
|
chunks=[
|
||||||
|
ChunkEmbeddings(
|
||||||
|
chunk_id=ch["chunk_id"],
|
||||||
|
vector=ch["vector"],
|
||||||
|
)
|
||||||
|
for ch in data["document-embeddings"]["chunks"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return KnowledgeRequest(
|
return KnowledgeRequest(
|
||||||
operation=data.get("operation"),
|
operation=data.get("operation"),
|
||||||
id=data.get("id"),
|
id=data.get("id"),
|
||||||
|
|
@ -50,6 +68,7 @@ class KnowledgeRequestTranslator(MessageTranslator):
|
||||||
collection=data.get("collection"),
|
collection=data.get("collection"),
|
||||||
triples=triples,
|
triples=triples,
|
||||||
graph_embeddings=graph_embeddings,
|
graph_embeddings=graph_embeddings,
|
||||||
|
document_embeddings=document_embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]:
|
def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]:
|
||||||
|
|
@ -90,6 +109,22 @@ class KnowledgeRequestTranslator(MessageTranslator):
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if obj.document_embeddings:
|
||||||
|
result["document-embeddings"] = {
|
||||||
|
"metadata": {
|
||||||
|
"id": obj.document_embeddings.metadata.id,
|
||||||
|
"root": obj.document_embeddings.metadata.root,
|
||||||
|
"collection": obj.document_embeddings.metadata.collection,
|
||||||
|
},
|
||||||
|
"chunks": [
|
||||||
|
{
|
||||||
|
"chunk_id": ch.chunk_id,
|
||||||
|
"vector": ch.vector,
|
||||||
|
}
|
||||||
|
for ch in obj.document_embeddings.chunks
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -140,6 +175,25 @@ class KnowledgeResponseTranslator(MessageTranslator):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Streaming document embeddings response
|
||||||
|
if obj.document_embeddings:
|
||||||
|
return {
|
||||||
|
"document-embeddings": {
|
||||||
|
"metadata": {
|
||||||
|
"id": obj.document_embeddings.metadata.id,
|
||||||
|
"root": obj.document_embeddings.metadata.root,
|
||||||
|
"collection": obj.document_embeddings.metadata.collection,
|
||||||
|
},
|
||||||
|
"chunks": [
|
||||||
|
{
|
||||||
|
"chunk_id": ch.chunk_id,
|
||||||
|
"vector": ch.vector,
|
||||||
|
}
|
||||||
|
for ch in obj.document_embeddings.chunks
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# End of stream marker
|
# End of stream marker
|
||||||
if obj.eos is True:
|
if obj.eos is True:
|
||||||
return {"eos": True}
|
return {"eos": True}
|
||||||
|
|
@ -155,7 +209,7 @@ class KnowledgeResponseTranslator(MessageTranslator):
|
||||||
is_final = (
|
is_final = (
|
||||||
obj.ids is not None or # List response
|
obj.ids is not None or # List response
|
||||||
obj.eos is True or # End of stream
|
obj.eos is True or # End of stream
|
||||||
(not obj.triples and not obj.graph_embeddings) # Empty response
|
(not obj.triples and not obj.graph_embeddings and not obj.document_embeddings) # Empty response
|
||||||
)
|
)
|
||||||
|
|
||||||
return response, is_final
|
return response, is_final
|
||||||
|
|
@ -4,7 +4,7 @@ from ..core.topic import queue
|
||||||
from ..core.metadata import Metadata
|
from ..core.metadata import Metadata
|
||||||
from .document import Document, TextDocument
|
from .document import Document, TextDocument
|
||||||
from .graph import Triples
|
from .graph import Triples
|
||||||
from .embeddings import GraphEmbeddings
|
from .embeddings import GraphEmbeddings, DocumentEmbeddings
|
||||||
|
|
||||||
# get-kg-core
|
# get-kg-core
|
||||||
# -> (???)
|
# -> (???)
|
||||||
|
|
@ -41,6 +41,9 @@ class KnowledgeRequest:
|
||||||
triples: Triples | None = None
|
triples: Triples | None = None
|
||||||
graph_embeddings: GraphEmbeddings | None = None
|
graph_embeddings: GraphEmbeddings | None = None
|
||||||
|
|
||||||
|
# put-de-core
|
||||||
|
document_embeddings: DocumentEmbeddings | None = None
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class KnowledgeResponse:
|
class KnowledgeResponse:
|
||||||
error: Error | None = None
|
error: Error | None = None
|
||||||
|
|
@ -48,6 +51,7 @@ class KnowledgeResponse:
|
||||||
eos: bool = False # Indicates end of knowledge core stream
|
eos: bool = False # Indicates end of knowledge core stream
|
||||||
triples: Triples | None = None
|
triples: Triples | None = None
|
||||||
graph_embeddings: GraphEmbeddings | None = None
|
graph_embeddings: GraphEmbeddings | None = None
|
||||||
|
document_embeddings: DocumentEmbeddings | None = None
|
||||||
|
|
||||||
knowledge_request_queue = queue('knowledge', cls='request')
|
knowledge_request_queue = queue('knowledge', cls='request')
|
||||||
knowledge_response_queue = queue('knowledge', cls='response')
|
knowledge_response_queue = queue('knowledge', cls='response')
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,7 @@ tg-dump-msgpack = "trustgraph.cli.dump_msgpack:main"
|
||||||
tg-dump-queues = "trustgraph.cli.dump_queues:main"
|
tg-dump-queues = "trustgraph.cli.dump_queues:main"
|
||||||
tg-monitor-prompts = "trustgraph.cli.monitor_prompts:main"
|
tg-monitor-prompts = "trustgraph.cli.monitor_prompts:main"
|
||||||
tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main"
|
tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main"
|
||||||
|
tg-get-de-core = "trustgraph.cli.get_de_core:main"
|
||||||
tg-get-kg-core = "trustgraph.cli.get_kg_core:main"
|
tg-get-kg-core = "trustgraph.cli.get_kg_core:main"
|
||||||
tg-get-document-content = "trustgraph.cli.get_document_content:main"
|
tg-get-document-content = "trustgraph.cli.get_document_content:main"
|
||||||
tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main"
|
tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main"
|
||||||
|
|
@ -77,6 +78,7 @@ tg-load-turtle = "trustgraph.cli.load_turtle:main"
|
||||||
tg-load-knowledge = "trustgraph.cli.load_knowledge:main"
|
tg-load-knowledge = "trustgraph.cli.load_knowledge:main"
|
||||||
tg-load-structured-data = "trustgraph.cli.load_structured_data:main"
|
tg-load-structured-data = "trustgraph.cli.load_structured_data:main"
|
||||||
tg-put-flow-blueprint = "trustgraph.cli.put_flow_blueprint:main"
|
tg-put-flow-blueprint = "trustgraph.cli.put_flow_blueprint:main"
|
||||||
|
tg-put-de-core = "trustgraph.cli.put_de_core:main"
|
||||||
tg-put-kg-core = "trustgraph.cli.put_kg_core:main"
|
tg-put-kg-core = "trustgraph.cli.put_kg_core:main"
|
||||||
tg-remove-library-document = "trustgraph.cli.remove_library_document:main"
|
tg-remove-library-document = "trustgraph.cli.remove_library_document:main"
|
||||||
tg-save-doc-embeds = "trustgraph.cli.save_doc_embeds:main"
|
tg-save-doc-embeds = "trustgraph.cli.save_doc_embeds:main"
|
||||||
|
|
|
||||||
111
trustgraph-cli/trustgraph/cli/get_de_core.py
Normal file
111
trustgraph-cli/trustgraph/cli/get_de_core.py
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
"""
|
||||||
|
Uses the knowledge service to fetch a document embeddings core which is
|
||||||
|
saved to a local file in msgpack format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
from trustgraph.api import Api
|
||||||
|
|
||||||
|
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||||
|
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||||
|
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
|
||||||
|
|
||||||
|
def write_de(f, data):
|
||||||
|
msg = (
|
||||||
|
"de",
|
||||||
|
{
|
||||||
|
"m": {
|
||||||
|
"i": data["metadata"]["id"],
|
||||||
|
"m": data["metadata"]["root"],
|
||||||
|
"c": data["metadata"]["collection"],
|
||||||
|
},
|
||||||
|
"c": [
|
||||||
|
{
|
||||||
|
"i": ch["chunk_id"],
|
||||||
|
"v": ch["vector"],
|
||||||
|
}
|
||||||
|
for ch in data["chunks"]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
f.write(msgpack.packb(msg, use_bin_type=True))
|
||||||
|
|
||||||
|
def fetch(url, workspace, id, output, token=None):
|
||||||
|
|
||||||
|
api = Api(url=url, token=token, workspace=workspace)
|
||||||
|
socket = api.socket()
|
||||||
|
|
||||||
|
try:
|
||||||
|
de = 0
|
||||||
|
|
||||||
|
with open(output, "wb") as f:
|
||||||
|
|
||||||
|
for response in socket.get_de_core(id):
|
||||||
|
|
||||||
|
if "document-embeddings" in response:
|
||||||
|
de += 1
|
||||||
|
write_de(f, response["document-embeddings"])
|
||||||
|
|
||||||
|
print(f"Got: {de} document embeddings messages.")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
socket.close()
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog='tg-get-de-core',
|
||||||
|
description=__doc__,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-u', '--url',
|
||||||
|
default=default_url,
|
||||||
|
help=f'API URL (default: {default_url})',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-w', '--workspace',
|
||||||
|
default=default_workspace,
|
||||||
|
help=f'Workspace (default: {default_workspace})',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--id', '--identifier',
|
||||||
|
required=True,
|
||||||
|
help=f'Document embeddings core ID',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-o', '--output',
|
||||||
|
required=True,
|
||||||
|
help=f'Output file'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-t', '--token',
|
||||||
|
default=default_token,
|
||||||
|
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
fetch(
|
||||||
|
url=args.url,
|
||||||
|
workspace=args.workspace,
|
||||||
|
id=args.id,
|
||||||
|
output=args.output,
|
||||||
|
token=args.token,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
|
||||||
|
print("Exception:", e, flush=True)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -5,13 +5,11 @@ to a local file in msgpack format.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import uuid
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
from websockets.asyncio.client import connect
|
|
||||||
import msgpack
|
import msgpack
|
||||||
|
|
||||||
default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
|
from trustgraph.api import Api
|
||||||
|
|
||||||
|
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||||
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
|
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
|
||||||
|
|
||||||
|
|
@ -21,7 +19,7 @@ def write_triple(f, data):
|
||||||
{
|
{
|
||||||
"m": {
|
"m": {
|
||||||
"i": data["metadata"]["id"],
|
"i": data["metadata"]["id"],
|
||||||
"m": data["metadata"]["metadata"],
|
"m": data["metadata"]["root"],
|
||||||
"c": data["metadata"]["collection"],
|
"c": data["metadata"]["collection"],
|
||||||
},
|
},
|
||||||
"t": data["triples"],
|
"t": data["triples"],
|
||||||
|
|
@ -35,13 +33,13 @@ def write_ge(f, data):
|
||||||
{
|
{
|
||||||
"m": {
|
"m": {
|
||||||
"i": data["metadata"]["id"],
|
"i": data["metadata"]["id"],
|
||||||
"m": data["metadata"]["metadata"],
|
"m": data["metadata"]["root"],
|
||||||
"c": data["metadata"]["collection"],
|
"c": data["metadata"]["collection"],
|
||||||
},
|
},
|
||||||
"e": [
|
"e": [
|
||||||
{
|
{
|
||||||
"e": ent["entity"],
|
"e": ent["entity"],
|
||||||
"v": ent["vectors"],
|
"v": ent["vector"],
|
||||||
}
|
}
|
||||||
for ent in data["entities"]
|
for ent in data["entities"]
|
||||||
]
|
]
|
||||||
|
|
@ -49,54 +47,18 @@ def write_ge(f, data):
|
||||||
)
|
)
|
||||||
f.write(msgpack.packb(msg, use_bin_type=True))
|
f.write(msgpack.packb(msg, use_bin_type=True))
|
||||||
|
|
||||||
async def fetch(url, workspace, id, output, token=None):
|
def fetch(url, workspace, id, output, token=None):
|
||||||
|
|
||||||
if not url.endswith("/"):
|
api = Api(url=url, token=token, workspace=workspace)
|
||||||
url += "/"
|
socket = api.socket()
|
||||||
|
|
||||||
url = url + "api/v1/socket"
|
|
||||||
|
|
||||||
if token:
|
|
||||||
url = f"{url}?token={token}"
|
|
||||||
|
|
||||||
mid = str(uuid.uuid4())
|
|
||||||
|
|
||||||
async with connect(url) as ws:
|
|
||||||
|
|
||||||
req = json.dumps({
|
|
||||||
"id": mid,
|
|
||||||
"workspace": workspace,
|
|
||||||
"service": "knowledge",
|
|
||||||
"request": {
|
|
||||||
"operation": "get-kg-core",
|
|
||||||
"workspace": workspace,
|
|
||||||
"id": id,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
await ws.send(req)
|
|
||||||
|
|
||||||
|
try:
|
||||||
ge = 0
|
ge = 0
|
||||||
t = 0
|
t = 0
|
||||||
|
|
||||||
with open(output, "wb") as f:
|
with open(output, "wb") as f:
|
||||||
|
|
||||||
while True:
|
for response in socket.get_kg_core(id):
|
||||||
|
|
||||||
msg = await ws.recv()
|
|
||||||
|
|
||||||
obj = json.loads(msg)
|
|
||||||
|
|
||||||
if "response" not in obj:
|
|
||||||
raise RuntimeError("No response?")
|
|
||||||
|
|
||||||
response = obj["response"]
|
|
||||||
|
|
||||||
if "error" in response:
|
|
||||||
raise RuntimeError(obj["error"])
|
|
||||||
|
|
||||||
if "eos" in response:
|
|
||||||
if response["eos"]: break
|
|
||||||
|
|
||||||
if "triples" in response:
|
if "triples" in response:
|
||||||
t += 1
|
t += 1
|
||||||
|
|
@ -108,7 +70,8 @@ async def fetch(url, workspace, id, output, token=None):
|
||||||
|
|
||||||
print(f"Got: {t} triple, {ge} GE messages.")
|
print(f"Got: {t} triple, {ge} GE messages.")
|
||||||
|
|
||||||
await ws.close()
|
finally:
|
||||||
|
socket.close()
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
||||||
|
|
@ -151,7 +114,6 @@ def main():
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
asyncio.run(
|
|
||||||
fetch(
|
fetch(
|
||||||
url=args.url,
|
url=args.url,
|
||||||
workspace=args.workspace,
|
workspace=args.workspace,
|
||||||
|
|
@ -159,7 +121,6 @@ def main():
|
||||||
output=args.output,
|
output=args.output,
|
||||||
token=args.token,
|
token=args.token,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,8 @@ Uses the GraphRAG service to answer a question
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import websockets
|
|
||||||
import asyncio
|
|
||||||
from trustgraph.api import (
|
from trustgraph.api import (
|
||||||
Api,
|
Api,
|
||||||
ExplainabilityClient,
|
ExplainabilityClient,
|
||||||
|
|
@ -31,607 +28,6 @@ default_max_path_length = 2
|
||||||
default_edge_score_limit = 30
|
default_edge_score_limit = 30
|
||||||
default_edge_limit = 25
|
default_edge_limit = 25
|
||||||
|
|
||||||
# Provenance predicates
|
|
||||||
TG = "https://trustgraph.ai/ns/"
|
|
||||||
TG_QUERY = TG + "query"
|
|
||||||
TG_CONCEPT = TG + "concept"
|
|
||||||
TG_ENTITY = TG + "entity"
|
|
||||||
TG_EDGE_COUNT = TG + "edgeCount"
|
|
||||||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
|
||||||
TG_EDGE = TG + "edge"
|
|
||||||
TG_REASONING = TG + "reasoning"
|
|
||||||
TG_DOCUMENT = TG + "document"
|
|
||||||
TG_CONTAINS = TG + "contains"
|
|
||||||
PROV = "http://www.w3.org/ns/prov#"
|
|
||||||
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
|
|
||||||
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
|
|
||||||
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
|
||||||
|
|
||||||
|
|
||||||
def _get_event_type(prov_id):
|
|
||||||
"""Extract event type from provenance_id"""
|
|
||||||
if "question" in prov_id:
|
|
||||||
return "question"
|
|
||||||
elif "grounding" in prov_id:
|
|
||||||
return "grounding"
|
|
||||||
elif "exploration" in prov_id:
|
|
||||||
return "exploration"
|
|
||||||
elif "focus" in prov_id:
|
|
||||||
return "focus"
|
|
||||||
elif "synthesis" in prov_id:
|
|
||||||
return "synthesis"
|
|
||||||
return "provenance"
|
|
||||||
|
|
||||||
|
|
||||||
def _format_provenance_details(event_type, triples):
|
|
||||||
"""Format provenance details based on event type and triples"""
|
|
||||||
lines = []
|
|
||||||
|
|
||||||
if event_type == "question":
|
|
||||||
# Show query and timestamp
|
|
||||||
for s, p, o in triples:
|
|
||||||
if p == TG_QUERY:
|
|
||||||
lines.append(f" Query: {o}")
|
|
||||||
elif p == PROV_STARTED_AT_TIME:
|
|
||||||
lines.append(f" Time: {o}")
|
|
||||||
|
|
||||||
elif event_type == "grounding":
|
|
||||||
# Show extracted concepts
|
|
||||||
concepts = [o for s, p, o in triples if p == TG_CONCEPT]
|
|
||||||
if concepts:
|
|
||||||
lines.append(f" Concepts: {len(concepts)}")
|
|
||||||
for concept in concepts:
|
|
||||||
lines.append(f" - {concept}")
|
|
||||||
|
|
||||||
elif event_type == "exploration":
|
|
||||||
# Show edge count (seed entities resolved separately with labels)
|
|
||||||
for s, p, o in triples:
|
|
||||||
if p == TG_EDGE_COUNT:
|
|
||||||
lines.append(f" Edges explored: {o}")
|
|
||||||
|
|
||||||
elif event_type == "focus":
|
|
||||||
# For focus, just count edge selection URIs
|
|
||||||
# The actual edge details are fetched separately via edge_selections parameter
|
|
||||||
edge_sel_uris = []
|
|
||||||
for s, p, o in triples:
|
|
||||||
if p == TG_SELECTED_EDGE:
|
|
||||||
edge_sel_uris.append(o)
|
|
||||||
if edge_sel_uris:
|
|
||||||
lines.append(f" Focused on {len(edge_sel_uris)} edge(s)")
|
|
||||||
|
|
||||||
elif event_type == "synthesis":
|
|
||||||
# Show document reference (content already streamed)
|
|
||||||
for s, p, o in triples:
|
|
||||||
if p == TG_DOCUMENT:
|
|
||||||
lines.append(f" Document: {o}")
|
|
||||||
|
|
||||||
return lines
|
|
||||||
|
|
||||||
|
|
||||||
async def _query_triples_once(ws_url, flow_id, prov_id, collection, graph=None, debug=False):
|
|
||||||
"""Query triples for a provenance node (single attempt)"""
|
|
||||||
request = {
|
|
||||||
"id": "triples-request",
|
|
||||||
"service": "triples",
|
|
||||||
"flow": flow_id,
|
|
||||||
"request": {
|
|
||||||
"s": {"t": "i", "i": prov_id},
|
|
||||||
"collection": collection,
|
|
||||||
"limit": 100
|
|
||||||
}
|
|
||||||
}
|
|
||||||
# Add graph filter if specified (for named graph queries)
|
|
||||||
if graph is not None:
|
|
||||||
request["request"]["g"] = graph
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] querying triples for s={prov_id}", file=sys.stderr)
|
|
||||||
|
|
||||||
triples = []
|
|
||||||
try:
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
|
|
||||||
await websocket.send(json.dumps(request))
|
|
||||||
|
|
||||||
async for raw_message in websocket:
|
|
||||||
response = json.loads(raw_message)
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] response: {json.dumps(response)[:200]}", file=sys.stderr)
|
|
||||||
|
|
||||||
if response.get("id") != "triples-request":
|
|
||||||
continue
|
|
||||||
|
|
||||||
if "error" in response:
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] error: {response['error']}", file=sys.stderr)
|
|
||||||
break
|
|
||||||
|
|
||||||
if "response" in response:
|
|
||||||
resp = response["response"]
|
|
||||||
# Handle triples response
|
|
||||||
# Response format: {"response": [triples...]}
|
|
||||||
# Each triple uses compact keys: "i" for iri, "v" for value, "t" for type
|
|
||||||
triple_list = resp.get("response", [])
|
|
||||||
for t in triple_list:
|
|
||||||
s = t.get("s", {}).get("i", t.get("s", {}).get("v", ""))
|
|
||||||
p = t.get("p", {}).get("i", t.get("p", {}).get("v", ""))
|
|
||||||
# Handle quoted triples (type "t") and regular values
|
|
||||||
o_term = t.get("o", {})
|
|
||||||
if o_term.get("t") == "t":
|
|
||||||
# Quoted triple - extract s, p, o from nested structure
|
|
||||||
tr = o_term.get("tr", {})
|
|
||||||
o = {
|
|
||||||
"s": tr.get("s", {}).get("i", ""),
|
|
||||||
"p": tr.get("p", {}).get("i", ""),
|
|
||||||
"o": tr.get("o", {}).get("i", tr.get("o", {}).get("v", "")),
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
o = o_term.get("i", o_term.get("v", ""))
|
|
||||||
triples.append((s, p, o))
|
|
||||||
|
|
||||||
if resp.get("complete") or response.get("complete"):
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] exception: {e}", file=sys.stderr)
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] got {len(triples)} triples", file=sys.stderr)
|
|
||||||
|
|
||||||
return triples
|
|
||||||
|
|
||||||
|
|
||||||
async def _query_triples(ws_url, flow_id, prov_id, collection, graph=None, max_retries=5, retry_delay=0.2, debug=False):
|
|
||||||
"""Query triples for a provenance node with retries for race condition"""
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
triples = await _query_triples_once(ws_url, flow_id, prov_id, collection, graph=graph, debug=debug)
|
|
||||||
if triples:
|
|
||||||
return triples
|
|
||||||
# Wait before retry if empty (triples may not be stored yet)
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] retry {attempt + 1}/{max_retries}...", file=sys.stderr)
|
|
||||||
await asyncio.sleep(retry_delay)
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, collection, debug=False):
|
|
||||||
"""
|
|
||||||
Query for provenance of an edge (s, p, o) in the knowledge graph.
|
|
||||||
|
|
||||||
Finds subgraphs that contain the edge via tg:contains, then follows
|
|
||||||
prov:wasDerivedFrom to find source documents.
|
|
||||||
|
|
||||||
Returns list of source URIs (chunks, pages, documents).
|
|
||||||
"""
|
|
||||||
# Query for subgraphs that contain this edge: ?subgraph tg:contains <<s p o>>
|
|
||||||
request = {
|
|
||||||
"id": "edge-prov-request",
|
|
||||||
"service": "triples",
|
|
||||||
"flow": flow_id,
|
|
||||||
"request": {
|
|
||||||
"p": {"t": "i", "i": TG_CONTAINS},
|
|
||||||
"o": {
|
|
||||||
"t": "t", # Quoted triple type
|
|
||||||
"tr": {
|
|
||||||
"s": {"t": "i", "i": edge_s},
|
|
||||||
"p": {"t": "i", "i": edge_p},
|
|
||||||
"o": {"t": "i", "i": edge_o} if edge_o.startswith("http") or edge_o.startswith("urn:") else {"t": "l", "v": edge_o},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"collection": collection,
|
|
||||||
"limit": 10
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] querying edge provenance for ({edge_s}, {edge_p}, {edge_o})", file=sys.stderr)
|
|
||||||
|
|
||||||
stmt_uris = []
|
|
||||||
try:
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
|
|
||||||
await websocket.send(json.dumps(request))
|
|
||||||
|
|
||||||
async for raw_message in websocket:
|
|
||||||
response = json.loads(raw_message)
|
|
||||||
|
|
||||||
if response.get("id") != "edge-prov-request":
|
|
||||||
continue
|
|
||||||
|
|
||||||
if "error" in response:
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] error: {response['error']}", file=sys.stderr)
|
|
||||||
break
|
|
||||||
|
|
||||||
if "response" in response:
|
|
||||||
resp = response["response"]
|
|
||||||
triple_list = resp.get("response", [])
|
|
||||||
for t in triple_list:
|
|
||||||
s = t.get("s", {}).get("i", "")
|
|
||||||
if s:
|
|
||||||
stmt_uris.append(s)
|
|
||||||
|
|
||||||
if resp.get("complete") or response.get("complete"):
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] exception querying edge provenance: {e}", file=sys.stderr)
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] found {len(stmt_uris)} reifying statements", file=sys.stderr)
|
|
||||||
|
|
||||||
# For each statement, query wasDerivedFrom to find sources
|
|
||||||
sources = []
|
|
||||||
for stmt_uri in stmt_uris:
|
|
||||||
# Query: stmt_uri prov:wasDerivedFrom ?source
|
|
||||||
request = {
|
|
||||||
"id": "derived-from-request",
|
|
||||||
"service": "triples",
|
|
||||||
"flow": flow_id,
|
|
||||||
"request": {
|
|
||||||
"s": {"t": "i", "i": stmt_uri},
|
|
||||||
"p": {"t": "i", "i": PROV_WAS_DERIVED_FROM},
|
|
||||||
"collection": collection,
|
|
||||||
"limit": 10
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
|
|
||||||
await websocket.send(json.dumps(request))
|
|
||||||
|
|
||||||
async for raw_message in websocket:
|
|
||||||
response = json.loads(raw_message)
|
|
||||||
|
|
||||||
if response.get("id") != "derived-from-request":
|
|
||||||
continue
|
|
||||||
|
|
||||||
if "error" in response:
|
|
||||||
break
|
|
||||||
|
|
||||||
if "response" in response:
|
|
||||||
resp = response["response"]
|
|
||||||
triple_list = resp.get("response", [])
|
|
||||||
for t in triple_list:
|
|
||||||
o = t.get("o", {}).get("i", "")
|
|
||||||
if o:
|
|
||||||
sources.append(o)
|
|
||||||
|
|
||||||
if resp.get("complete") or response.get("complete"):
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] exception querying wasDerivedFrom: {e}", file=sys.stderr)
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] found {len(sources)} source(s): {sources}", file=sys.stderr)
|
|
||||||
|
|
||||||
return sources
|
|
||||||
|
|
||||||
|
|
||||||
async def _query_derived_from(ws_url, flow_id, uri, collection, debug=False):
|
|
||||||
"""Query for the prov:wasDerivedFrom parent of a URI. Returns None if no parent."""
|
|
||||||
request = {
|
|
||||||
"id": "parent-request",
|
|
||||||
"service": "triples",
|
|
||||||
"flow": flow_id,
|
|
||||||
"request": {
|
|
||||||
"s": {"t": "i", "i": uri},
|
|
||||||
"p": {"t": "i", "i": PROV_WAS_DERIVED_FROM},
|
|
||||||
"collection": collection,
|
|
||||||
"limit": 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
|
|
||||||
await websocket.send(json.dumps(request))
|
|
||||||
|
|
||||||
async for raw_message in websocket:
|
|
||||||
response = json.loads(raw_message)
|
|
||||||
|
|
||||||
if response.get("id") != "parent-request":
|
|
||||||
continue
|
|
||||||
|
|
||||||
if "error" in response:
|
|
||||||
break
|
|
||||||
|
|
||||||
if "response" in response:
|
|
||||||
resp = response["response"]
|
|
||||||
triple_list = resp.get("response", [])
|
|
||||||
if triple_list:
|
|
||||||
return triple_list[0].get("o", {}).get("i", None)
|
|
||||||
|
|
||||||
if resp.get("complete") or response.get("complete"):
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] exception querying parent: {e}", file=sys.stderr)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def _trace_provenance_chain(ws_url, flow_id, source_uri, collection, label_cache, debug=False):
|
|
||||||
"""
|
|
||||||
Trace the full provenance chain from a source URI up to the root document.
|
|
||||||
Returns a list of (uri, label) tuples from leaf to root.
|
|
||||||
"""
|
|
||||||
chain = []
|
|
||||||
current = source_uri
|
|
||||||
max_depth = 10 # Prevent infinite loops
|
|
||||||
|
|
||||||
for _ in range(max_depth):
|
|
||||||
if not current:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Get label for current entity
|
|
||||||
label = await _query_label(ws_url, flow_id, current, collection, label_cache, debug)
|
|
||||||
chain.append((current, label))
|
|
||||||
|
|
||||||
# Get parent
|
|
||||||
parent = await _query_derived_from(ws_url, flow_id, current, collection, debug)
|
|
||||||
if not parent or parent == current:
|
|
||||||
break
|
|
||||||
current = parent
|
|
||||||
|
|
||||||
return chain
|
|
||||||
|
|
||||||
|
|
||||||
def _format_provenance_chain(chain):
|
|
||||||
"""
|
|
||||||
Format a provenance chain as a human-readable string.
|
|
||||||
Chain is [(uri, label), ...] from leaf to root.
|
|
||||||
"""
|
|
||||||
if not chain:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# Show labels, from leaf to root
|
|
||||||
labels = [label for uri, label in chain]
|
|
||||||
return " → ".join(labels)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_iri(value):
|
|
||||||
"""Check if a value looks like an IRI."""
|
|
||||||
if not isinstance(value, str):
|
|
||||||
return False
|
|
||||||
return value.startswith("http://") or value.startswith("https://") or value.startswith("urn:")
|
|
||||||
|
|
||||||
|
|
||||||
async def _query_label(ws_url, flow_id, iri, collection, label_cache, debug=False):
|
|
||||||
"""
|
|
||||||
Query for the rdfs:label of an IRI.
|
|
||||||
Uses label_cache to avoid repeated queries.
|
|
||||||
Returns the label if found, otherwise returns the IRI.
|
|
||||||
"""
|
|
||||||
if not _is_iri(iri):
|
|
||||||
return iri
|
|
||||||
|
|
||||||
# Check cache first
|
|
||||||
if iri in label_cache:
|
|
||||||
return label_cache[iri]
|
|
||||||
|
|
||||||
request = {
|
|
||||||
"id": "label-request",
|
|
||||||
"service": "triples",
|
|
||||||
"flow": flow_id,
|
|
||||||
"request": {
|
|
||||||
"s": {"t": "i", "i": iri},
|
|
||||||
"p": {"t": "i", "i": RDFS_LABEL},
|
|
||||||
"collection": collection,
|
|
||||||
"limit": 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
label = iri # Default to IRI if no label found
|
|
||||||
try:
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
|
|
||||||
await websocket.send(json.dumps(request))
|
|
||||||
|
|
||||||
async for raw_message in websocket:
|
|
||||||
response = json.loads(raw_message)
|
|
||||||
|
|
||||||
if response.get("id") != "label-request":
|
|
||||||
continue
|
|
||||||
|
|
||||||
if "error" in response:
|
|
||||||
break
|
|
||||||
|
|
||||||
if "response" in response:
|
|
||||||
resp = response["response"]
|
|
||||||
triple_list = resp.get("response", [])
|
|
||||||
if triple_list:
|
|
||||||
# Get the label value
|
|
||||||
o = triple_list[0].get("o", {})
|
|
||||||
label = o.get("v", o.get("i", iri))
|
|
||||||
|
|
||||||
if resp.get("complete") or response.get("complete"):
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] exception querying label for {iri}: {e}", file=sys.stderr)
|
|
||||||
|
|
||||||
# Cache the result
|
|
||||||
label_cache[iri] = label
|
|
||||||
return label
|
|
||||||
|
|
||||||
|
|
||||||
async def _resolve_edge_labels(ws_url, flow_id, edge_triple, collection, label_cache, debug=False):
|
|
||||||
"""
|
|
||||||
Resolve labels for all IRI components of an edge triple.
|
|
||||||
Returns (s_label, p_label, o_label).
|
|
||||||
"""
|
|
||||||
s = edge_triple.get("s", "?")
|
|
||||||
p = edge_triple.get("p", "?")
|
|
||||||
o = edge_triple.get("o", "?")
|
|
||||||
|
|
||||||
s_label = await _query_label(ws_url, flow_id, s, collection, label_cache, debug)
|
|
||||||
p_label = await _query_label(ws_url, flow_id, p, collection, label_cache, debug)
|
|
||||||
o_label = await _query_label(ws_url, flow_id, o, collection, label_cache, debug)
|
|
||||||
|
|
||||||
return s_label, p_label, o_label
|
|
||||||
|
|
||||||
|
|
||||||
async def _question_explainable(
|
|
||||||
url, flow_id, question, collection, entity_limit, triple_limit,
|
|
||||||
max_subgraph_size, max_path_length, token=None, debug=False
|
|
||||||
):
|
|
||||||
"""Execute graph RAG with explainability - shows provenance events with details"""
|
|
||||||
# Convert HTTP URL to WebSocket URL
|
|
||||||
if url.startswith("http://"):
|
|
||||||
ws_url = url.replace("http://", "ws://", 1)
|
|
||||||
elif url.startswith("https://"):
|
|
||||||
ws_url = url.replace("https://", "wss://", 1)
|
|
||||||
else:
|
|
||||||
ws_url = f"ws://{url}"
|
|
||||||
|
|
||||||
ws_url = f"{ws_url.rstrip('/')}/api/v1/socket"
|
|
||||||
if token:
|
|
||||||
ws_url = f"{ws_url}?token={token}"
|
|
||||||
|
|
||||||
# Cache for label lookups to avoid repeated queries
|
|
||||||
label_cache = {}
|
|
||||||
|
|
||||||
request = {
|
|
||||||
"id": "cli-request",
|
|
||||||
"service": "graph-rag",
|
|
||||||
"flow": flow_id,
|
|
||||||
"request": {
|
|
||||||
"query": question,
|
|
||||||
"collection": collection,
|
|
||||||
"entity-limit": entity_limit,
|
|
||||||
"triple-limit": triple_limit,
|
|
||||||
"max-subgraph-size": max_subgraph_size,
|
|
||||||
"max-path-length": max_path_length,
|
|
||||||
"streaming": True
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=300) as websocket:
|
|
||||||
await websocket.send(json.dumps(request))
|
|
||||||
|
|
||||||
async for raw_message in websocket:
|
|
||||||
response = json.loads(raw_message)
|
|
||||||
|
|
||||||
if response.get("id") != "cli-request":
|
|
||||||
continue
|
|
||||||
|
|
||||||
if "error" in response:
|
|
||||||
print(f"\nError: {response['error']}", file=sys.stderr)
|
|
||||||
break
|
|
||||||
|
|
||||||
if "response" in response:
|
|
||||||
resp = response["response"]
|
|
||||||
|
|
||||||
# Check for errors in response
|
|
||||||
if "error" in resp and resp["error"]:
|
|
||||||
err = resp["error"]
|
|
||||||
print(f"\nError: {err.get('message', 'Unknown error')}", file=sys.stderr)
|
|
||||||
break
|
|
||||||
|
|
||||||
message_type = resp.get("message_type", "")
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] message_type={message_type}, keys={list(resp.keys())}", file=sys.stderr)
|
|
||||||
|
|
||||||
if message_type == "explain":
|
|
||||||
# Display explain event with details
|
|
||||||
explain_id = resp.get("explain_id", "")
|
|
||||||
explain_graph = resp.get("explain_graph") # Named graph (e.g., urn:graph:retrieval)
|
|
||||||
if explain_id:
|
|
||||||
event_type = _get_event_type(explain_id)
|
|
||||||
print(f"\n [{event_type}] {explain_id}", file=sys.stderr)
|
|
||||||
|
|
||||||
# Query triples for this explain node (using named graph filter)
|
|
||||||
triples = await _query_triples(
|
|
||||||
ws_url, flow_id, explain_id, collection, graph=explain_graph, debug=debug
|
|
||||||
)
|
|
||||||
|
|
||||||
# Format and display details
|
|
||||||
details = _format_provenance_details(event_type, triples)
|
|
||||||
for line in details:
|
|
||||||
print(line, file=sys.stderr)
|
|
||||||
|
|
||||||
# For exploration events, resolve entity labels
|
|
||||||
if event_type == "exploration":
|
|
||||||
entity_iris = [o for s, p, o in triples if p == TG_ENTITY]
|
|
||||||
if entity_iris:
|
|
||||||
print(f" Seed entities: {len(entity_iris)}", file=sys.stderr)
|
|
||||||
for iri in entity_iris:
|
|
||||||
label = await _query_label(
|
|
||||||
ws_url, flow_id, iri, collection,
|
|
||||||
label_cache, debug=debug
|
|
||||||
)
|
|
||||||
print(f" - {label}", file=sys.stderr)
|
|
||||||
|
|
||||||
# For focus events, query each edge selection for details
|
|
||||||
if event_type == "focus":
|
|
||||||
for s, p, o in triples:
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] triple: p={p}, o={o}, o_type={type(o).__name__}", file=sys.stderr)
|
|
||||||
if p == TG_SELECTED_EDGE and isinstance(o, str):
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] querying edge selection: {o}", file=sys.stderr)
|
|
||||||
# Query the edge selection entity (using named graph filter)
|
|
||||||
edge_triples = await _query_triples(
|
|
||||||
ws_url, flow_id, o, collection, graph=explain_graph, debug=debug
|
|
||||||
)
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr)
|
|
||||||
# Extract edge and reasoning
|
|
||||||
edge_triple = None # Store the actual triple for provenance lookup
|
|
||||||
reasoning = None
|
|
||||||
for es, ep, eo in edge_triples:
|
|
||||||
if debug:
|
|
||||||
print(f" [debug] edge triple: ep={ep}, eo={eo}", file=sys.stderr)
|
|
||||||
if ep == TG_EDGE and isinstance(eo, dict):
|
|
||||||
# eo is a quoted triple dict
|
|
||||||
edge_triple = eo
|
|
||||||
elif ep == TG_REASONING:
|
|
||||||
reasoning = eo
|
|
||||||
if edge_triple:
|
|
||||||
# Resolve labels for edge components
|
|
||||||
s_label, p_label, o_label = await _resolve_edge_labels(
|
|
||||||
ws_url, flow_id, edge_triple, collection,
|
|
||||||
label_cache, debug=debug
|
|
||||||
)
|
|
||||||
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
|
|
||||||
if reasoning:
|
|
||||||
r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning
|
|
||||||
print(f" Reason: {r_short}", file=sys.stderr)
|
|
||||||
|
|
||||||
# Trace edge provenance in the workspace collection (not explainability)
|
|
||||||
if edge_triple:
|
|
||||||
sources = await _query_edge_provenance(
|
|
||||||
ws_url, flow_id,
|
|
||||||
edge_triple.get("s", ""),
|
|
||||||
edge_triple.get("p", ""),
|
|
||||||
edge_triple.get("o", ""),
|
|
||||||
collection, # Use the query collection, not explainability
|
|
||||||
debug=debug
|
|
||||||
)
|
|
||||||
if sources:
|
|
||||||
for src in sources:
|
|
||||||
# Trace full chain from source to root document
|
|
||||||
chain = await _trace_provenance_chain(
|
|
||||||
ws_url, flow_id, src, collection,
|
|
||||||
label_cache, debug=debug
|
|
||||||
)
|
|
||||||
chain_str = _format_provenance_chain(chain)
|
|
||||||
print(f" Source: {chain_str}", file=sys.stderr)
|
|
||||||
|
|
||||||
elif message_type == "chunk" or not message_type:
|
|
||||||
# Display response chunk
|
|
||||||
chunk = resp.get("response", "")
|
|
||||||
if chunk:
|
|
||||||
print(chunk, end="", flush=True)
|
|
||||||
|
|
||||||
# Check if session is complete
|
|
||||||
if resp.get("end_of_session"):
|
|
||||||
break
|
|
||||||
|
|
||||||
print() # Final newline
|
|
||||||
|
|
||||||
|
|
||||||
def _question_explainable_api(
|
def _question_explainable_api(
|
||||||
url, flow_id, question_text, collection, entity_limit, triple_limit,
|
url, flow_id, question_text, collection, entity_limit, triple_limit,
|
||||||
max_subgraph_size, max_path_length, edge_score_limit=30,
|
max_subgraph_size, max_path_length, edge_score_limit=30,
|
||||||
|
|
|
||||||
119
trustgraph-cli/trustgraph/cli/put_de_core.py
Normal file
119
trustgraph-cli/trustgraph/cli/put_de_core.py
Normal file
|
|
@ -0,0 +1,119 @@
|
||||||
|
"""
|
||||||
|
Puts a document embeddings core into the knowledge manager via the API
|
||||||
|
socket.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
from trustgraph.api import Api
|
||||||
|
|
||||||
|
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||||
|
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||||
|
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
|
||||||
|
|
||||||
|
def read_message(unpacked, id):
|
||||||
|
|
||||||
|
if unpacked[0] == "de":
|
||||||
|
msg = unpacked[1]
|
||||||
|
return {
|
||||||
|
"metadata": {
|
||||||
|
"id": id,
|
||||||
|
"root": msg["m"]["m"],
|
||||||
|
"collection": "default",
|
||||||
|
},
|
||||||
|
"chunks": [
|
||||||
|
{
|
||||||
|
"chunk_id": ch["i"],
|
||||||
|
"vector": ch["v"],
|
||||||
|
}
|
||||||
|
for ch in msg["c"]
|
||||||
|
],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unexpected message type", unpacked[0])
|
||||||
|
|
||||||
|
def put(url, workspace, id, input, token=None):
|
||||||
|
|
||||||
|
api = Api(url=url, token=token, workspace=workspace)
|
||||||
|
socket = api.socket()
|
||||||
|
|
||||||
|
try:
|
||||||
|
de = 0
|
||||||
|
|
||||||
|
with open(input, "rb") as f:
|
||||||
|
|
||||||
|
unpacker = msgpack.Unpacker(f, raw=False)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
|
||||||
|
try:
|
||||||
|
unpacked = unpacker.unpack()
|
||||||
|
except msgpack.OutOfData:
|
||||||
|
break
|
||||||
|
|
||||||
|
msg = read_message(unpacked, id)
|
||||||
|
de += 1
|
||||||
|
socket.put_de_core(id, document_embeddings=msg)
|
||||||
|
|
||||||
|
print(f"Put: {de} document embeddings messages.")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
socket.close()
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog='tg-put-de-core',
|
||||||
|
description=__doc__,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-u', '--url',
|
||||||
|
default=default_url,
|
||||||
|
help=f'API URL (default: {default_url})',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-w', '--workspace',
|
||||||
|
default=default_workspace,
|
||||||
|
help=f'Workspace (default: {default_workspace})',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--id', '--identifier',
|
||||||
|
required=True,
|
||||||
|
help=f'Document embeddings core ID',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-i', '--input',
|
||||||
|
required=True,
|
||||||
|
help=f'Input file'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-t', '--token',
|
||||||
|
default=default_token,
|
||||||
|
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
put(
|
||||||
|
url=args.url,
|
||||||
|
workspace=args.workspace,
|
||||||
|
id=args.id,
|
||||||
|
input=args.input,
|
||||||
|
token=args.token,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
|
||||||
|
print("Exception:", e, flush=True)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -4,13 +4,11 @@ Puts a knowledge core into the knowledge manager via the API socket.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import uuid
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
from websockets.asyncio.client import connect
|
|
||||||
import msgpack
|
import msgpack
|
||||||
|
|
||||||
default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
|
from trustgraph.api import Api
|
||||||
|
|
||||||
|
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||||
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
|
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
|
||||||
|
|
||||||
|
|
@ -21,13 +19,13 @@ def read_message(unpacked, id):
|
||||||
return "ge", {
|
return "ge", {
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": id,
|
"id": id,
|
||||||
"metadata": msg["m"]["m"],
|
"root": msg["m"]["m"],
|
||||||
"collection": "default", # Not used?
|
"collection": "default",
|
||||||
},
|
},
|
||||||
"entities": [
|
"entities": [
|
||||||
{
|
{
|
||||||
"entity": ent["e"],
|
"entity": ent["e"],
|
||||||
"vectors": ent["v"],
|
"vector": ent["v"],
|
||||||
}
|
}
|
||||||
for ent in msg["e"]
|
for ent in msg["e"]
|
||||||
],
|
],
|
||||||
|
|
@ -37,26 +35,20 @@ def read_message(unpacked, id):
|
||||||
return "t", {
|
return "t", {
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": id,
|
"id": id,
|
||||||
"metadata": msg["m"]["m"],
|
"root": msg["m"]["m"],
|
||||||
"collection": "default", # Not used by receiver?
|
"collection": "default",
|
||||||
},
|
},
|
||||||
"triples": msg["t"],
|
"triples": msg["t"],
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Unpacked unexpected messsage type", unpacked[0])
|
raise RuntimeError("Unpacked unexpected messsage type", unpacked[0])
|
||||||
|
|
||||||
async def put(url, workspace, id, input, token=None):
|
def put(url, workspace, id, input, token=None):
|
||||||
|
|
||||||
if not url.endswith("/"):
|
api = Api(url=url, token=token, workspace=workspace)
|
||||||
url += "/"
|
socket = api.socket()
|
||||||
|
|
||||||
url = url + "api/v1/socket"
|
|
||||||
|
|
||||||
if token:
|
|
||||||
url = f"{url}?token={token}"
|
|
||||||
|
|
||||||
async with connect(url) as ws:
|
|
||||||
|
|
||||||
|
try:
|
||||||
ge = 0
|
ge = 0
|
||||||
t = 0
|
t = 0
|
||||||
|
|
||||||
|
|
@ -68,69 +60,26 @@ async def put(url, workspace, id, input, token=None):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
unpacked = unpacker.unpack()
|
unpacked = unpacker.unpack()
|
||||||
except:
|
except msgpack.OutOfData:
|
||||||
break
|
break
|
||||||
|
|
||||||
kind, msg = read_message(unpacked, id)
|
kind, msg = read_message(unpacked, id)
|
||||||
|
|
||||||
mid = str(uuid.uuid4())
|
|
||||||
|
|
||||||
if kind == "ge":
|
if kind == "ge":
|
||||||
|
|
||||||
ge += 1
|
ge += 1
|
||||||
|
socket.put_kg_core(id, graph_embeddings=msg)
|
||||||
req = json.dumps({
|
|
||||||
"id": mid,
|
|
||||||
"workspace": workspace,
|
|
||||||
"service": "knowledge",
|
|
||||||
"request": {
|
|
||||||
"operation": "put-kg-core",
|
|
||||||
"workspace": workspace,
|
|
||||||
"id": id,
|
|
||||||
"graph-embeddings": msg
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
elif kind == "t":
|
elif kind == "t":
|
||||||
|
|
||||||
t += 1
|
t += 1
|
||||||
|
socket.put_kg_core(id, triples=msg)
|
||||||
req = json.dumps({
|
|
||||||
"id": mid,
|
|
||||||
"workspace": workspace,
|
|
||||||
"service": "knowledge",
|
|
||||||
"request": {
|
|
||||||
"operation": "put-kg-core",
|
|
||||||
"workspace": workspace,
|
|
||||||
"id": id,
|
|
||||||
"triples": msg
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
raise RuntimeError("Unexpected message kind", kind)
|
raise RuntimeError("Unexpected message kind", kind)
|
||||||
|
|
||||||
await ws.send(req)
|
|
||||||
|
|
||||||
# Retry loop, wait for right response to come back
|
|
||||||
while True:
|
|
||||||
|
|
||||||
msg = await ws.recv()
|
|
||||||
msg = json.loads(msg)
|
|
||||||
|
|
||||||
if msg["id"] != mid:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if "response" in msg:
|
|
||||||
if "error" in msg["response"]:
|
|
||||||
raise RuntimeError(msg["response"]["error"])
|
|
||||||
|
|
||||||
break
|
|
||||||
|
|
||||||
print(f"Put: {t} triple, {ge} GE messages.")
|
print(f"Put: {t} triple, {ge} GE messages.")
|
||||||
|
|
||||||
await ws.close()
|
finally:
|
||||||
|
socket.close()
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
||||||
|
|
@ -173,7 +122,6 @@ def main():
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
asyncio.run(
|
|
||||||
put(
|
put(
|
||||||
url=args.url,
|
url=args.url,
|
||||||
workspace=args.workspace,
|
workspace=args.workspace,
|
||||||
|
|
@ -181,7 +129,6 @@ def main():
|
||||||
input=args.input,
|
input=args.input,
|
||||||
token=args.token,
|
token=args.token,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
|
|
||||||
from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings
|
from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings
|
||||||
|
from .. schema import DocumentEmbeddings
|
||||||
from .. knowledge import hash
|
from .. knowledge import hash
|
||||||
from .. exceptions import RequestError
|
from .. exceptions import RequestError
|
||||||
from .. tables.knowledge import KnowledgeTableStore
|
from .. tables.knowledge import KnowledgeTableStore
|
||||||
|
|
@ -157,6 +158,98 @@ class KnowledgeManager:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def list_de_cores(self, request, respond, workspace):
|
||||||
|
|
||||||
|
ids = await self.table_store.list_de_cores(workspace)
|
||||||
|
|
||||||
|
await respond(
|
||||||
|
KnowledgeResponse(
|
||||||
|
error = None,
|
||||||
|
ids = ids,
|
||||||
|
eos = False,
|
||||||
|
triples = None,
|
||||||
|
graph_embeddings = None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_de_core(self, request, respond, workspace):
|
||||||
|
|
||||||
|
logger.info("Getting document embeddings core...")
|
||||||
|
|
||||||
|
async def publish_de(de):
|
||||||
|
await respond(
|
||||||
|
KnowledgeResponse(
|
||||||
|
error = None,
|
||||||
|
ids = None,
|
||||||
|
eos = False,
|
||||||
|
triples = None,
|
||||||
|
graph_embeddings = None,
|
||||||
|
document_embeddings = de,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.table_store.get_document_embeddings(
|
||||||
|
workspace,
|
||||||
|
request.id,
|
||||||
|
publish_de,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Document embeddings core retrieval complete")
|
||||||
|
|
||||||
|
await respond(
|
||||||
|
KnowledgeResponse(
|
||||||
|
error = None,
|
||||||
|
ids = None,
|
||||||
|
eos = True,
|
||||||
|
triples = None,
|
||||||
|
graph_embeddings = None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def put_de_core(self, request, respond, workspace):
|
||||||
|
|
||||||
|
if request.document_embeddings:
|
||||||
|
await self.table_store.add_document_embeddings(
|
||||||
|
workspace, request.document_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
await respond(
|
||||||
|
KnowledgeResponse(
|
||||||
|
error = None,
|
||||||
|
ids = None,
|
||||||
|
eos = False,
|
||||||
|
triples = None,
|
||||||
|
graph_embeddings = None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_de_core(self, request, respond, workspace):
|
||||||
|
|
||||||
|
logger.info("Deleting document embeddings core...")
|
||||||
|
|
||||||
|
await self.table_store.delete_document_embeddings(
|
||||||
|
workspace, request.id
|
||||||
|
)
|
||||||
|
|
||||||
|
await respond(
|
||||||
|
KnowledgeResponse(
|
||||||
|
error = None,
|
||||||
|
ids = None,
|
||||||
|
eos = False,
|
||||||
|
triples = None,
|
||||||
|
graph_embeddings = None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_de_core(self, request, respond, workspace):
|
||||||
|
|
||||||
|
if self.background_task is None:
|
||||||
|
self.background_task = asyncio.create_task(
|
||||||
|
self.core_loader()
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.loader_queue.put((request, respond, workspace))
|
||||||
|
|
||||||
async def core_loader(self):
|
async def core_loader(self):
|
||||||
|
|
||||||
logger.info("Knowledge background processor running...")
|
logger.info("Knowledge background processor running...")
|
||||||
|
|
@ -165,7 +258,7 @@ class KnowledgeManager:
|
||||||
logger.debug("Waiting for next load...")
|
logger.debug("Waiting for next load...")
|
||||||
request, respond, workspace = await self.loader_queue.get()
|
request, respond, workspace = await self.loader_queue.get()
|
||||||
|
|
||||||
logger.info(f"Loading knowledge: {request.id}")
|
logger.info(f"Loading: {request.operation} {request.id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
|
@ -187,24 +280,13 @@ class KnowledgeManager:
|
||||||
if "interfaces" not in flow:
|
if "interfaces" not in flow:
|
||||||
raise RuntimeError("No defined interfaces")
|
raise RuntimeError("No defined interfaces")
|
||||||
|
|
||||||
if "triples-store" not in flow["interfaces"]:
|
if request.operation == "load-de-core":
|
||||||
raise RuntimeError("Flow has no triples-store")
|
await self._load_de_core(
|
||||||
|
request, respond, workspace, flow,
|
||||||
if "graph-embeddings-store" not in flow["interfaces"]:
|
|
||||||
raise RuntimeError("Flow has no graph-embeddings-store")
|
|
||||||
|
|
||||||
t_q = flow["interfaces"]["triples-store"]["flow"]
|
|
||||||
ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"]
|
|
||||||
|
|
||||||
# Got this far, it should all work
|
|
||||||
await respond(
|
|
||||||
KnowledgeResponse(
|
|
||||||
error = None,
|
|
||||||
ids = None,
|
|
||||||
eos = False,
|
|
||||||
triples = None,
|
|
||||||
graph_embeddings = None
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
await self._load_kg_core(
|
||||||
|
request, respond, workspace, flow,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -223,14 +305,36 @@ class KnowledgeManager:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.debug("Knowledge processing done")
|
||||||
|
|
||||||
logger.debug("Starting knowledge loading process...")
|
continue
|
||||||
|
|
||||||
try:
|
async def _load_kg_core(self, request, respond, workspace, flow):
|
||||||
|
|
||||||
|
if "triples-store" not in flow["interfaces"]:
|
||||||
|
raise RuntimeError("Flow has no triples-store")
|
||||||
|
|
||||||
|
if "graph-embeddings-store" not in flow["interfaces"]:
|
||||||
|
raise RuntimeError("Flow has no graph-embeddings-store")
|
||||||
|
|
||||||
|
t_q = flow["interfaces"]["triples-store"]["flow"]
|
||||||
|
ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"]
|
||||||
|
|
||||||
|
await respond(
|
||||||
|
KnowledgeResponse(
|
||||||
|
error = None,
|
||||||
|
ids = None,
|
||||||
|
eos = False,
|
||||||
|
triples = None,
|
||||||
|
graph_embeddings = None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
t_pub = None
|
t_pub = None
|
||||||
ge_pub = None
|
ge_pub = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
logger.debug(f"Triples queue: {t_q}")
|
logger.debug(f"Triples queue: {t_q}")
|
||||||
logger.debug(f"Graph embeddings queue: {ge_q}")
|
logger.debug(f"Graph embeddings queue: {ge_q}")
|
||||||
|
|
||||||
|
|
@ -249,7 +353,6 @@ class KnowledgeManager:
|
||||||
await ge_pub.start()
|
await ge_pub.start()
|
||||||
|
|
||||||
async def publish_triples(t):
|
async def publish_triples(t):
|
||||||
# Override collection with request collection
|
|
||||||
if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'):
|
if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'):
|
||||||
t.metadata.collection = request.collection or "default"
|
t.metadata.collection = request.collection or "default"
|
||||||
await t_pub.send(None, t)
|
await t_pub.send(None, t)
|
||||||
|
|
@ -263,7 +366,6 @@ class KnowledgeManager:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def publish_ge(g):
|
async def publish_ge(g):
|
||||||
# Override collection with request collection
|
|
||||||
if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'):
|
if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'):
|
||||||
g.metadata.collection = request.collection or "default"
|
g.metadata.collection = request.collection or "default"
|
||||||
await ge_pub.send(None, g)
|
await ge_pub.send(None, g)
|
||||||
|
|
@ -276,7 +378,7 @@ class KnowledgeManager:
|
||||||
publish_ge,
|
publish_ge,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Knowledge loading completed")
|
logger.debug("Knowledge core loading completed")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
|
|
@ -289,6 +391,59 @@ class KnowledgeManager:
|
||||||
if t_pub: await t_pub.stop()
|
if t_pub: await t_pub.stop()
|
||||||
if ge_pub: await ge_pub.stop()
|
if ge_pub: await ge_pub.stop()
|
||||||
|
|
||||||
logger.debug("Knowledge processing done")
|
async def _load_de_core(self, request, respond, workspace, flow):
|
||||||
|
|
||||||
continue
|
if "document-embeddings-store" not in flow["interfaces"]:
|
||||||
|
raise RuntimeError("Flow has no document-embeddings-store")
|
||||||
|
|
||||||
|
de_q = flow["interfaces"]["document-embeddings-store"]["flow"]
|
||||||
|
|
||||||
|
await respond(
|
||||||
|
KnowledgeResponse(
|
||||||
|
error = None,
|
||||||
|
ids = None,
|
||||||
|
eos = False,
|
||||||
|
triples = None,
|
||||||
|
graph_embeddings = None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
de_pub = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
logger.debug(f"Document embeddings queue: {de_q}")
|
||||||
|
|
||||||
|
de_pub = Publisher(
|
||||||
|
self.flow_config.pubsub, de_q,
|
||||||
|
schema=DocumentEmbeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Starting publisher...")
|
||||||
|
|
||||||
|
await de_pub.start()
|
||||||
|
|
||||||
|
async def publish_de(de):
|
||||||
|
if hasattr(de, 'metadata') and hasattr(de.metadata, 'collection'):
|
||||||
|
de.metadata.collection = request.collection or "default"
|
||||||
|
await de_pub.send(None, de)
|
||||||
|
|
||||||
|
logger.debug("Publishing document embeddings...")
|
||||||
|
|
||||||
|
await self.table_store.get_document_embeddings(
|
||||||
|
workspace,
|
||||||
|
request.id,
|
||||||
|
publish_de,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Document embeddings core loading completed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
|
||||||
|
logger.error(f"Knowledge exception: {e}", exc_info=True)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
|
||||||
|
logger.debug("Stopping publisher...")
|
||||||
|
|
||||||
|
if de_pub: await de_pub.stop()
|
||||||
|
|
|
||||||
|
|
@ -187,6 +187,11 @@ class Processor(WorkspaceProcessor):
|
||||||
"put-kg-core": self.knowledge.put_kg_core,
|
"put-kg-core": self.knowledge.put_kg_core,
|
||||||
"load-kg-core": self.knowledge.load_kg_core,
|
"load-kg-core": self.knowledge.load_kg_core,
|
||||||
"unload-kg-core": self.knowledge.unload_kg_core,
|
"unload-kg-core": self.knowledge.unload_kg_core,
|
||||||
|
"list-de-cores": self.knowledge.list_de_cores,
|
||||||
|
"get-de-core": self.knowledge.get_de_core,
|
||||||
|
"delete-de-core": self.knowledge.delete_de_core,
|
||||||
|
"put-de-core": self.knowledge.put_de_core,
|
||||||
|
"load-de-core": self.knowledge.load_de_core,
|
||||||
}
|
}
|
||||||
|
|
||||||
if v.operation not in impls:
|
if v.operation not in impls:
|
||||||
|
|
|
||||||
|
|
@ -457,6 +457,12 @@ for _op in ("put-kg-core", "delete-kg-core",
|
||||||
"load-kg-core", "unload-kg-core"):
|
"load-kg-core", "unload-kg-core"):
|
||||||
_register_kind_op("knowledge", _op, "knowledge:write")
|
_register_kind_op("knowledge", _op, "knowledge:write")
|
||||||
|
|
||||||
|
# knowledge: document-embeddings core service.
|
||||||
|
for _op in ("get-de-core", "list-de-cores"):
|
||||||
|
_register_kind_op("knowledge", _op, "knowledge:read")
|
||||||
|
for _op in ("put-de-core", "delete-de-core", "load-de-core"):
|
||||||
|
_register_kind_op("knowledge", _op, "knowledge:write")
|
||||||
|
|
||||||
|
|
||||||
# collection-management: workspace collection lifecycle.
|
# collection-management: workspace collection lifecycle.
|
||||||
_register_kind_op("collection-management", "list-collections", "collections:read")
|
_register_kind_op("collection-management", "list-collections", "collections:read")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
|
|
||||||
from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings
|
from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings
|
||||||
from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings
|
from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings
|
||||||
|
from .. schema import DocumentEmbeddings, ChunkEmbeddings
|
||||||
|
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
|
|
||||||
|
|
@ -217,6 +218,16 @@ class KnowledgeTableStore:
|
||||||
WHERE workspace = ? AND document_id = ?
|
WHERE workspace = ? AND document_id = ?
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
self.delete_document_embeddings_stmt = self.cassandra.prepare("""
|
||||||
|
DELETE FROM document_embeddings
|
||||||
|
WHERE workspace = ? AND document_id = ?
|
||||||
|
""")
|
||||||
|
|
||||||
|
self.list_de_cores_stmt = self.cassandra.prepare("""
|
||||||
|
SELECT DISTINCT workspace, document_id FROM document_embeddings
|
||||||
|
WHERE workspace = ?
|
||||||
|
""")
|
||||||
|
|
||||||
async def add_triples(self, workspace, m):
|
async def add_triples(self, workspace, m):
|
||||||
|
|
||||||
when = int(time.time() * 1000)
|
when = int(time.time() * 1000)
|
||||||
|
|
@ -338,6 +349,50 @@ class KnowledgeTableStore:
|
||||||
logger.error("Exception occurred", exc_info=True)
|
logger.error("Exception occurred", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
try:
|
||||||
|
await async_execute(
|
||||||
|
self.cassandra,
|
||||||
|
self.delete_document_embeddings_stmt,
|
||||||
|
(workspace, document_id),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.error("Exception occurred", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def delete_document_embeddings(self, workspace, document_id):
    """Delete the document-embeddings rows stored for one document.

    Executes the prepared ``delete_document_embeddings_stmt`` bound to
    ``(workspace, document_id)``.

    Raises:
        Exception: any error raised while executing the statement is
            logged with its traceback and then re-raised unchanged.
    """
    logger.debug("Delete document embeddings...")

    bound_values = (workspace, document_id)

    try:
        await async_execute(
            self.cassandra,
            self.delete_document_embeddings_stmt,
            bound_values,
        )
    except Exception:
        # Log here so the failure is recorded even if the caller
        # swallows the re-raised exception.
        logger.error("Exception occurred", exc_info=True)
        raise
|
||||||
|
|
||||||
|
async def list_de_cores(self, workspace):
    """Return the document ids that have document-embeddings rows.

    Executes the prepared ``list_de_cores_stmt`` for ``workspace`` and
    collects the second column of every returned row.

    Returns:
        list: one entry per row, taken from column index 1.

    Raises:
        Exception: any error raised while executing the statement is
            logged with its traceback and then re-raised unchanged.
    """
    logger.debug("List DE cores...")

    try:
        result = await async_execute(
            self.cassandra,
            self.list_de_cores_stmt,
            (workspace,),
        )
    except Exception:
        logger.error("Exception occurred", exc_info=True)
        raise

    # The statement selects (workspace, document_id); keep only the id.
    document_ids = [record[1] for record in result]

    logger.debug("Done")
    return document_ids
|
||||||
|
|
||||||
async def get_triples(self, workspace, document_id, receiver):
|
async def get_triples(self, workspace, document_id, receiver):
|
||||||
|
|
||||||
logger.debug("Get triples...")
|
logger.debug("Get triples...")
|
||||||
|
|
@ -417,3 +472,42 @@ class KnowledgeTableStore:
|
||||||
|
|
||||||
logger.debug("Done")
|
logger.debug("Done")
|
||||||
|
|
||||||
|
async def get_document_embeddings(self, workspace, document_id, receiver):
    """Send a document's stored embeddings to ``receiver`` row by row.

    Executes the prepared ``get_document_embeddings_stmt`` bound to
    ``(workspace, document_id)``. For each returned row, a
    ``DocumentEmbeddings`` message is built and ``receiver`` is awaited
    with it.

    Args:
        workspace: first value bound to the statement.
        document_id: second value bound to the statement; also used as
            the ``id`` of each message's metadata.
        receiver: awaitable callable invoked once per row with the
            ``DocumentEmbeddings`` message.

    Raises:
        Exception: any error raised while executing the statement is
            logged with its traceback and then re-raised unchanged.
    """
    logger.debug("Get DE...")

    try:
        result = await async_execute(
            self.cassandra,
            self.get_document_embeddings_stmt,
            (workspace, document_id),
        )
    except Exception:
        logger.error("Exception occurred", exc_info=True)
        raise

    for record in result:

        # Column 3 carries the per-chunk entries; a null or empty value
        # yields a message with no chunks.
        stored_chunks = record[3] or ()

        # Each entry is indexed as (chunk id, vector).
        chunk_list = [
            ChunkEmbeddings(chunk_id=entry[0], vector=entry[1])
            for entry in stored_chunks
        ]

        # NOTE(review): the collection is always reported as "default";
        # no per-row collection value is read here.
        message = DocumentEmbeddings(
            metadata=Metadata(
                id=document_id,
                collection="default",
            ),
            chunks=chunk_list,
        )

        await receiver(message)

    logger.debug("Done")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue