mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-27 01:16:22 +02:00
Compare commits
38 commits
c737e8c356
...
dc72ed3cca
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc72ed3cca | ||
|
|
e899370d98 | ||
|
|
c20e6540ec | ||
|
|
ddd4bd7790 | ||
|
|
2f8d6a3ffb | ||
|
|
f0c9039b76 | ||
|
|
4acd853023 | ||
|
|
d4723566cb | ||
|
|
10a931f04c | ||
|
|
ee65d90fdd | ||
|
|
d9dc4cbab5 | ||
|
|
62c30a3a50 | ||
|
|
24f0190ce7 | ||
|
|
4fb0b4d8e8 | ||
|
|
dbf8daa74a | ||
|
|
3ba6a3238f | ||
|
|
2bcf375103 | ||
|
|
153ae9ad30 | ||
|
|
89e13a756a | ||
|
|
816a8cfcf6 | ||
|
|
7b734148b3 | ||
|
|
e65ea217a2 | ||
|
|
81ca7bbc11 | ||
|
|
0781d3e6a7 | ||
|
|
849987f0e6 | ||
|
|
7af1d60db8 | ||
|
|
5a9db2da50 | ||
|
|
687a9e08fe | ||
|
|
413f917676 | ||
|
|
20204d87c3 | ||
|
|
a634520509 | ||
|
|
ea33620fb2 | ||
|
|
9c55a0a0ff | ||
|
|
1ec081f42f | ||
|
|
f02bbdb442 | ||
|
|
4164ef1c47 | ||
|
|
97f5645ea0 | ||
|
|
1f67fc2312 |
260 changed files with 16757 additions and 4051 deletions
4
Makefile
4
Makefile
|
|
@ -77,8 +77,8 @@ some-containers:
|
|||
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.flow \
|
||||
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.unstructured \
|
||||
-t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.unstructured \
|
||||
# -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.vertexai \
|
||||
# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.mcp \
|
||||
|
|
|
|||
237
dev-tools/library_client.py
Normal file
237
dev-tools/library_client.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Client utility for browsing and loading documents from the TrustGraph
|
||||
public document library.
|
||||
|
||||
Usage:
|
||||
python library_client.py list
|
||||
python library_client.py search <text>
|
||||
python library_client.py load-all
|
||||
python library_client.py load-doc <id>
|
||||
python library_client.py load-match <text>
|
||||
"""
|
||||
|
||||
import json
|
||||
import urllib.request
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from trustgraph.api import Api
|
||||
from trustgraph.api.types import Uri, Literal, Triple
|
||||
|
||||
BUCKET_URL = "https://storage.googleapis.com/trustgraph-library"
|
||||
INDEX_URL = f"{BUCKET_URL}/index.json"
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
|
||||
default_user = "trustgraph"
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
|
||||
|
||||
def fetch_index():
|
||||
with urllib.request.urlopen(INDEX_URL) as resp:
|
||||
return json.loads(resp.read())
|
||||
|
||||
|
||||
def fetch_document_metadata(doc_id):
|
||||
url = f"{BUCKET_URL}/{doc_id}.json"
|
||||
with urllib.request.urlopen(url) as resp:
|
||||
return json.loads(resp.read())
|
||||
|
||||
|
||||
def fetch_document_content(doc_id):
|
||||
url = f"{BUCKET_URL}/{doc_id}.epub"
|
||||
with urllib.request.urlopen(url) as resp:
|
||||
return resp.read()
|
||||
|
||||
|
||||
def search_index(index, query):
|
||||
query = query.lower()
|
||||
results = []
|
||||
for doc in index:
|
||||
title = doc.get("title", "").lower()
|
||||
comments = doc.get("comments", "").lower()
|
||||
tags = [t.lower() for t in doc.get("tags", [])]
|
||||
if (query in title or query in comments or
|
||||
any(query in t for t in tags)):
|
||||
results.append(doc)
|
||||
return results
|
||||
|
||||
|
||||
def print_index(index):
|
||||
if not index:
|
||||
return
|
||||
|
||||
# Calculate column widths
|
||||
id_width = max(len(str(doc.get("id", ""))) for doc in index)
|
||||
title_width = max(len(doc.get("title", "")) for doc in index)
|
||||
|
||||
# Cap title width for readability
|
||||
title_width = min(title_width, 60)
|
||||
id_width = max(id_width, 2)
|
||||
|
||||
try:
|
||||
term_width = os.get_terminal_size().columns
|
||||
except OSError:
|
||||
term_width = 120
|
||||
|
||||
tags_width = max(term_width - id_width - title_width - 6, 20)
|
||||
|
||||
header = f"{'ID':<{id_width}} {'Title':<{title_width}} {'Tags':<{tags_width}}"
|
||||
print(header)
|
||||
print("-" * len(header))
|
||||
|
||||
for doc in index:
|
||||
eid = str(doc.get("id", ""))
|
||||
title = doc.get("title", "")
|
||||
if len(title) > title_width:
|
||||
title = title[:title_width - 3] + "..."
|
||||
tags = ", ".join(doc.get("tags", []))
|
||||
if len(tags) > tags_width:
|
||||
tags = tags[:tags_width - 3] + "..."
|
||||
print(f"{eid:<{id_width}} {title:<{title_width}} {tags}")
|
||||
|
||||
|
||||
def convert_value(v):
|
||||
"""Convert a JSON triple value to a Uri or Literal."""
|
||||
if v["type"] == "uri":
|
||||
return Uri(v["value"])
|
||||
else:
|
||||
return Literal(v["value"])
|
||||
|
||||
|
||||
def convert_metadata(metadata_json):
|
||||
"""Convert JSON metadata triples to Triple objects."""
|
||||
triples = []
|
||||
for t in metadata_json:
|
||||
triples.append(Triple(
|
||||
s=convert_value(t["s"]),
|
||||
p=convert_value(t["p"]),
|
||||
o=convert_value(t["o"]),
|
||||
))
|
||||
return triples
|
||||
|
||||
|
||||
def load_document(api, user, doc_entry):
|
||||
"""Fetch metadata and content for a document, then load into TrustGraph."""
|
||||
doc_id = doc_entry["id"]
|
||||
title = doc_entry["title"]
|
||||
|
||||
print(f" [{doc_id}] {title}")
|
||||
|
||||
print(f" fetching metadata...")
|
||||
doc_json = fetch_document_metadata(doc_id)
|
||||
doc = doc_json[0]
|
||||
|
||||
print(f" fetching content...")
|
||||
content = fetch_document_content(doc_id)
|
||||
|
||||
print(f" loading into TrustGraph ({len(content) // 1024}KB)...")
|
||||
metadata = convert_metadata(doc["metadata"])
|
||||
|
||||
api.add_document(
|
||||
id=doc["id"],
|
||||
metadata=metadata,
|
||||
user=user,
|
||||
kind=doc["kind"],
|
||||
title=doc["title"],
|
||||
comments=doc["comments"],
|
||||
tags=doc["tags"],
|
||||
document=content,
|
||||
)
|
||||
|
||||
print(f" done.")
|
||||
|
||||
|
||||
def load_documents(api, user, docs):
|
||||
"""Load a list of documents."""
|
||||
print(f"Loading {len(docs)} document(s)...\n")
|
||||
for doc in docs:
|
||||
try:
|
||||
load_document(api, user, doc)
|
||||
except Exception as e:
|
||||
print(f" FAILED: {e}", file=sys.stderr)
|
||||
print()
|
||||
print("Complete.")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Browse and load documents from the TrustGraph public document library.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-u", "--url", default=default_url,
|
||||
help=f"TrustGraph API URL (default: {default_url})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-U", "--user", default=default_user,
|
||||
help=f"User ID (default: {default_user})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t", "--token", default=default_token,
|
||||
help="Authentication token (default: $TRUSTGRAPH_TOKEN)",
|
||||
)
|
||||
|
||||
sub = parser.add_subparsers(dest="command")
|
||||
|
||||
sub.add_parser("list", help="List all documents")
|
||||
|
||||
search_parser = sub.add_parser("search", help="Search documents")
|
||||
search_parser.add_argument("query", help="Text to search for")
|
||||
|
||||
sub.add_parser("load-all", help="Load all documents into TrustGraph")
|
||||
|
||||
load_doc_parser = sub.add_parser("load-doc", help="Load a document by ID")
|
||||
load_doc_parser.add_argument("id", help="Document ID (ebook number)")
|
||||
|
||||
load_match_parser = sub.add_parser(
|
||||
"load-match", help="Load all documents matching a search term",
|
||||
)
|
||||
load_match_parser.add_argument("query", help="Text to search for")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command is None:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
index = fetch_index()
|
||||
|
||||
if args.command in ("list", "search"):
|
||||
if args.command == "list":
|
||||
print_index(index)
|
||||
else:
|
||||
results = search_index(index, args.query)
|
||||
if results:
|
||||
print_index(results)
|
||||
else:
|
||||
print("No matches found.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
return
|
||||
|
||||
# Load commands need the API
|
||||
api = Api(args.url, token=args.token).library()
|
||||
|
||||
if args.command == "load-all":
|
||||
load_documents(api, args.user, index)
|
||||
|
||||
elif args.command == "load-doc":
|
||||
matches = [d for d in index if str(d.get("id")) == args.id]
|
||||
if not matches:
|
||||
print(f"No document with ID '{args.id}' found.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
load_documents(api, args.user, matches)
|
||||
|
||||
elif args.command == "load-match":
|
||||
results = search_index(index, args.query)
|
||||
if results:
|
||||
load_documents(api, args.user, results)
|
||||
else:
|
||||
print("No matches found.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
319
dev-tools/tests/agent_dag/analyse_trace.py
Normal file
319
dev-tools/tests/agent_dag/analyse_trace.py
Normal file
|
|
@ -0,0 +1,319 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Analyse a captured agent trace JSON file and check DAG integrity.
|
||||
|
||||
Usage:
|
||||
python analyse_trace.py react.json
|
||||
python analyse_trace.py -u http://localhost:8088/ react.json
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import websockets
|
||||
|
||||
DEFAULT_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
|
||||
DEFAULT_USER = "trustgraph"
|
||||
DEFAULT_COLLECTION = "default"
|
||||
DEFAULT_FLOW = "default"
|
||||
GRAPH = "urn:graph:retrieval"
|
||||
|
||||
# Namespace prefixes
|
||||
PROV = "http://www.w3.org/ns/prov#"
|
||||
RDF = "http://www.w3.org/1999/02/22-rdf-syntax-ns#"
|
||||
RDFS = "http://www.w3.org/2000/01/rdf-schema#"
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
|
||||
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
|
||||
RDF_TYPE = RDF + "type"
|
||||
|
||||
TG_ANALYSIS = TG + "Analysis"
|
||||
TG_TOOL_USE = TG + "ToolUse"
|
||||
TG_OBSERVATION_TYPE = TG + "Observation"
|
||||
TG_CONCLUSION = TG + "Conclusion"
|
||||
TG_SYNTHESIS = TG + "Synthesis"
|
||||
TG_QUESTION = TG + "Question"
|
||||
|
||||
|
||||
def shorten(uri):
|
||||
"""Shorten a URI for display."""
|
||||
for prefix, short in [
|
||||
(PROV, "prov:"), (RDF, "rdf:"), (RDFS, "rdfs:"), (TG, "tg:"),
|
||||
]:
|
||||
if isinstance(uri, str) and uri.startswith(prefix):
|
||||
return short + uri[len(prefix):]
|
||||
return str(uri)
|
||||
|
||||
|
||||
async def fetch_triples(ws, flow, subject, user, collection, request_counter):
|
||||
"""Query triples for a given subject URI."""
|
||||
request_counter[0] += 1
|
||||
req_id = f"q-{request_counter[0]}"
|
||||
|
||||
msg = {
|
||||
"id": req_id,
|
||||
"service": "triples",
|
||||
"flow": flow,
|
||||
"request": {
|
||||
"s": {"t": "i", "i": subject},
|
||||
"g": GRAPH,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": 100,
|
||||
},
|
||||
}
|
||||
|
||||
await ws.send(json.dumps(msg))
|
||||
|
||||
while True:
|
||||
raw = await ws.recv()
|
||||
resp = json.loads(raw)
|
||||
if resp.get("id") == req_id:
|
||||
inner = resp.get("response", {})
|
||||
if isinstance(inner, dict):
|
||||
return inner.get("response", [])
|
||||
return inner
|
||||
|
||||
|
||||
def extract_term(term):
|
||||
"""Extract value from wire-format term."""
|
||||
if not term:
|
||||
return ""
|
||||
t = term.get("t", "")
|
||||
if t == "i":
|
||||
return term.get("i", "")
|
||||
elif t == "l":
|
||||
return term.get("v", "")
|
||||
elif t == "t":
|
||||
tr = term.get("tr", {})
|
||||
return {
|
||||
"s": extract_term(tr.get("s", {})),
|
||||
"p": extract_term(tr.get("p", {})),
|
||||
"o": extract_term(tr.get("o", {})),
|
||||
}
|
||||
return str(term)
|
||||
|
||||
|
||||
def parse_triples(wire_triples):
|
||||
"""Convert wire triples to (s, p, o) tuples."""
|
||||
result = []
|
||||
for t in wire_triples:
|
||||
s = extract_term(t.get("s", {}))
|
||||
p = extract_term(t.get("p", {}))
|
||||
o = extract_term(t.get("o", {}))
|
||||
result.append((s, p, o))
|
||||
return result
|
||||
|
||||
|
||||
def get_types(tuples):
|
||||
"""Get rdf:type values from parsed triples."""
|
||||
return {o for s, p, o in tuples if p == RDF_TYPE}
|
||||
|
||||
|
||||
def get_derived_from(tuples):
|
||||
"""Get prov:wasDerivedFrom targets from parsed triples."""
|
||||
return [o for s, p, o in tuples if p == PROV_WAS_DERIVED_FROM]
|
||||
|
||||
|
||||
async def analyse(path, url, flow, user, collection):
|
||||
with open(path) as f:
|
||||
messages = json.load(f)
|
||||
|
||||
print(f"Total messages: {len(messages)}")
|
||||
print()
|
||||
|
||||
# ---- Pass 1: collect explain IDs and check streaming chunks ----
|
||||
|
||||
explain_ids = []
|
||||
errors = []
|
||||
|
||||
for i, msg in enumerate(messages):
|
||||
resp = msg.get("response", {})
|
||||
chunk_type = resp.get("chunk_type", "?")
|
||||
|
||||
if chunk_type == "explain":
|
||||
explain_id = resp.get("explain_id", "")
|
||||
explain_ids.append(explain_id)
|
||||
print(f" {i:3d} {chunk_type} {explain_id}")
|
||||
else:
|
||||
print(f" {i:3d} {chunk_type}")
|
||||
|
||||
# Rule 7: message_id on content chunks
|
||||
if chunk_type in ("thought", "observation", "answer"):
|
||||
mid = resp.get("message_id", "")
|
||||
if not mid:
|
||||
errors.append(
|
||||
f"[msg {i}] {chunk_type} chunk missing message_id"
|
||||
)
|
||||
|
||||
print()
|
||||
print(f"Explain IDs ({len(explain_ids)}):")
|
||||
for eid in explain_ids:
|
||||
print(f" {eid}")
|
||||
|
||||
# ---- Pass 2: fetch triples for each explain ID ----
|
||||
|
||||
ws_url = url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
ws_url = f"{ws_url.rstrip('/')}/api/v1/socket"
|
||||
|
||||
request_counter = [0]
|
||||
# entity_id -> parsed triples [(s, p, o), ...]
|
||||
entities = {}
|
||||
|
||||
print()
|
||||
print("Fetching triples...")
|
||||
print()
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=60) as ws:
|
||||
for eid in explain_ids:
|
||||
wire = await fetch_triples(
|
||||
ws, flow, eid, user, collection, request_counter,
|
||||
)
|
||||
|
||||
tuples = parse_triples(wire) if isinstance(wire, list) else []
|
||||
entities[eid] = tuples
|
||||
|
||||
print(f" {eid}")
|
||||
for s, p, o in tuples:
|
||||
o_short = str(o)
|
||||
if len(o_short) > 80:
|
||||
o_short = o_short[:77] + "..."
|
||||
print(f" {shorten(p)} = {o_short}")
|
||||
print()
|
||||
|
||||
# ---- Pass 3: check rules ----
|
||||
|
||||
all_ids = set(entities.keys())
|
||||
|
||||
# Collect entity metadata
|
||||
roots = [] # entities with no wasDerivedFrom
|
||||
conclusions = [] # tg:Conclusion entities
|
||||
analyses = [] # tg:Analysis entities
|
||||
observations = [] # tg:Observation entities
|
||||
|
||||
for eid, tuples in entities.items():
|
||||
types = get_types(tuples)
|
||||
parents = get_derived_from(tuples)
|
||||
|
||||
if not tuples:
|
||||
errors.append(f"[{eid}] entity has no triples in store")
|
||||
|
||||
if not parents:
|
||||
roots.append(eid)
|
||||
|
||||
if TG_CONCLUSION in types:
|
||||
conclusions.append(eid)
|
||||
if TG_ANALYSIS in types:
|
||||
analyses.append(eid)
|
||||
if TG_OBSERVATION_TYPE in types:
|
||||
observations.append(eid)
|
||||
|
||||
# Rule 4: every non-root entity has wasDerivedFrom
|
||||
if parents:
|
||||
for parent in parents:
|
||||
# Rule 5: parent exists in known entities
|
||||
if parent not in all_ids:
|
||||
errors.append(
|
||||
f"[{eid}] wasDerivedFrom target not in explain set: "
|
||||
f"{parent}"
|
||||
)
|
||||
|
||||
# Rule 6: Analysis entities must have ToolUse type
|
||||
if TG_ANALYSIS in types and TG_TOOL_USE not in types:
|
||||
errors.append(
|
||||
f"[{eid}] Analysis entity missing tg:ToolUse type"
|
||||
)
|
||||
|
||||
# Rule 1: exactly one root
|
||||
if len(roots) == 0:
|
||||
errors.append("No root entity found (all have wasDerivedFrom)")
|
||||
elif len(roots) > 1:
|
||||
errors.append(
|
||||
f"Multiple roots ({len(roots)}) — expected exactly 1:"
|
||||
)
|
||||
for r in roots:
|
||||
types = get_types(entities[r])
|
||||
type_labels = ", ".join(shorten(t) for t in types)
|
||||
errors.append(f" root: {r} [{type_labels}]")
|
||||
|
||||
# Rule 2: exactly one terminal node (nothing derives from it)
|
||||
# Build set of entities that are parents of something
|
||||
has_children = set()
|
||||
for eid, tuples in entities.items():
|
||||
for parent in get_derived_from(tuples):
|
||||
has_children.add(parent)
|
||||
|
||||
terminals = [eid for eid in all_ids if eid not in has_children]
|
||||
if len(terminals) == 0:
|
||||
errors.append("No terminal entity found (cycle?)")
|
||||
elif len(terminals) > 1:
|
||||
errors.append(
|
||||
f"Multiple terminal entities ({len(terminals)}) — expected exactly 1:"
|
||||
)
|
||||
for t in terminals:
|
||||
types = get_types(entities[t])
|
||||
type_labels = ", ".join(shorten(ty) for ty in types)
|
||||
errors.append(f" terminal: {t} [{type_labels}]")
|
||||
|
||||
# Rule 8: Observation should not derive from Analysis if a sub-trace
|
||||
# exists as a sibling. Check: if an Analysis has both a Question child
|
||||
# and an Observation child, the Observation should derive from the
|
||||
# sub-trace's Synthesis, not from the Analysis.
|
||||
for obs_id in observations:
|
||||
obs_parents = get_derived_from(entities[obs_id])
|
||||
for parent in obs_parents:
|
||||
if parent in entities:
|
||||
parent_types = get_types(entities[parent])
|
||||
if TG_ANALYSIS in parent_types:
|
||||
# Check if this Analysis also has a Question child
|
||||
# (i.e. a sub-trace exists)
|
||||
has_subtrace = False
|
||||
for other_id, other_tuples in entities.items():
|
||||
if other_id == obs_id:
|
||||
continue
|
||||
other_parents = get_derived_from(other_tuples)
|
||||
other_types = get_types(other_tuples)
|
||||
if (parent in other_parents
|
||||
and TG_QUESTION in other_types):
|
||||
has_subtrace = True
|
||||
break
|
||||
if has_subtrace:
|
||||
errors.append(
|
||||
f"[{obs_id}] Observation derives from Analysis "
|
||||
f"{parent} which has a sub-trace — should derive "
|
||||
f"from the sub-trace's Synthesis instead"
|
||||
)
|
||||
|
||||
# ---- Report ----
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
if errors:
|
||||
print(f"ERRORS ({len(errors)}):")
|
||||
print()
|
||||
for err in errors:
|
||||
print(f" !! {err}")
|
||||
else:
|
||||
print("ALL CHECKS PASSED")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("input", help="JSON trace file")
|
||||
parser.add_argument("-u", "--url", default=DEFAULT_URL)
|
||||
parser.add_argument("-f", "--flow", default=DEFAULT_FLOW)
|
||||
parser.add_argument("-U", "--user", default=DEFAULT_USER)
|
||||
parser.add_argument("-C", "--collection", default=DEFAULT_COLLECTION)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(analyse(
|
||||
args.input, args.url, args.flow,
|
||||
args.user, args.collection,
|
||||
))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
81
dev-tools/tests/agent_dag/ws_capture.py
Normal file
81
dev-tools/tests/agent_dag/ws_capture.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Connect to TrustGraph websocket, run an agent query, capture all
|
||||
response messages to a JSON file.
|
||||
|
||||
Usage:
|
||||
python ws_capture.py -q "What is the document about?" -o trace.json
|
||||
python ws_capture.py -q "..." -u http://localhost:8088/ -o out.json
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import websockets
|
||||
|
||||
DEFAULT_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
|
||||
DEFAULT_USER = "trustgraph"
|
||||
DEFAULT_COLLECTION = "default"
|
||||
DEFAULT_FLOW = "default"
|
||||
|
||||
|
||||
async def capture(url, flow, question, user, collection, output):
|
||||
|
||||
# Convert to ws URL
|
||||
ws_url = url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
ws_url = f"{ws_url.rstrip('/')}/api/v1/socket"
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=120) as ws:
|
||||
|
||||
request = {
|
||||
"id": "capture",
|
||||
"service": "agent",
|
||||
"flow": flow,
|
||||
"request": {
|
||||
"question": question,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"streaming": True,
|
||||
},
|
||||
}
|
||||
|
||||
await ws.send(json.dumps(request))
|
||||
|
||||
messages = []
|
||||
|
||||
async for raw in ws:
|
||||
msg = json.loads(raw)
|
||||
|
||||
if msg.get("id") != "capture":
|
||||
continue
|
||||
|
||||
messages.append(msg)
|
||||
|
||||
if msg.get("complete"):
|
||||
break
|
||||
|
||||
with open(output, "w") as f:
|
||||
json.dump(messages, f, indent=2)
|
||||
|
||||
print(f"Captured {len(messages)} messages to {output}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("-q", "--question", required=True)
|
||||
parser.add_argument("-o", "--output", default="trace.json")
|
||||
parser.add_argument("-u", "--url", default=DEFAULT_URL)
|
||||
parser.add_argument("-f", "--flow", default=DEFAULT_FLOW)
|
||||
parser.add_argument("-U", "--user", default=DEFAULT_USER)
|
||||
parser.add_argument("-C", "--collection", default=DEFAULT_COLLECTION)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(capture(
|
||||
args.url, args.flow, args.question,
|
||||
args.user, args.collection, args.output,
|
||||
))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
67
dev-tools/tests/librarian/simple_text_download.py
Normal file
67
dev-tools/tests/librarian/simple_text_download.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Minimal example: download a text document in tiny chunks via websocket API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
import websockets
|
||||
|
||||
async def main():
|
||||
url = "ws://localhost:8088/api/v1/socket"
|
||||
|
||||
document_id = "test-chunked-doc-001"
|
||||
chunk_size = 10 # Tiny chunks!
|
||||
|
||||
request_id = 0
|
||||
|
||||
async def send_request(ws, request):
|
||||
nonlocal request_id
|
||||
request_id += 1
|
||||
msg = {
|
||||
"id": f"req-{request_id}",
|
||||
"service": "librarian",
|
||||
"request": request
|
||||
}
|
||||
await ws.send(json.dumps(msg))
|
||||
response = json.loads(await ws.recv())
|
||||
if "error" in response:
|
||||
raise Exception(response["error"])
|
||||
return response.get("response", {})
|
||||
|
||||
async with websockets.connect(url) as ws:
|
||||
|
||||
print(f"Fetching document: {document_id}")
|
||||
print(f"Chunk size: {chunk_size} bytes")
|
||||
print()
|
||||
|
||||
chunk_index = 0
|
||||
all_content = b""
|
||||
|
||||
while True:
|
||||
resp = await send_request(ws, {
|
||||
"operation": "stream-document",
|
||||
"user": "trustgraph",
|
||||
"document-id": document_id,
|
||||
"chunk-index": chunk_index,
|
||||
"chunk-size": chunk_size,
|
||||
})
|
||||
|
||||
chunk_data = base64.b64decode(resp["content"])
|
||||
total_chunks = resp["total-chunks"]
|
||||
total_bytes = resp["total-bytes"]
|
||||
|
||||
print(f"Chunk {chunk_index}: {chunk_data}")
|
||||
|
||||
all_content += chunk_data
|
||||
chunk_index += 1
|
||||
|
||||
if chunk_index >= total_chunks:
|
||||
break
|
||||
|
||||
print()
|
||||
print(f"Complete: {all_content}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
56
dev-tools/tests/librarian/simple_text_upload.py
Normal file
56
dev-tools/tests/librarian/simple_text_upload.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Minimal example: upload a small text document via websocket API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
import time
|
||||
import websockets
|
||||
|
||||
async def main():
|
||||
url = "ws://localhost:8088/api/v1/socket"
|
||||
|
||||
# Small text content
|
||||
content = b"AAAAAAAAAABBBBBBBBBBCCCCCCCCCC"
|
||||
|
||||
request_id = 0
|
||||
|
||||
async def send_request(ws, request):
|
||||
nonlocal request_id
|
||||
request_id += 1
|
||||
msg = {
|
||||
"id": f"req-{request_id}",
|
||||
"service": "librarian",
|
||||
"request": request
|
||||
}
|
||||
await ws.send(json.dumps(msg))
|
||||
response = json.loads(await ws.recv())
|
||||
if "error" in response:
|
||||
raise Exception(response["error"])
|
||||
return response.get("response", {})
|
||||
|
||||
async with websockets.connect(url) as ws:
|
||||
|
||||
print(f"Uploading {len(content)} bytes...")
|
||||
|
||||
resp = await send_request(ws, {
|
||||
"operation": "add-document",
|
||||
"document-metadata": {
|
||||
"id": "test-chunked-doc-001",
|
||||
"time": int(time.time()),
|
||||
"kind": "text/plain",
|
||||
"title": "My Test Document",
|
||||
"comments": "Small doc for chunk testing",
|
||||
"user": "trustgraph",
|
||||
"tags": ["test"],
|
||||
"metadata": [],
|
||||
},
|
||||
"content": base64.b64encode(content).decode("utf-8"),
|
||||
})
|
||||
|
||||
print("Done!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
237
dev-tools/tests/relay/test_rev_gateway.py
Normal file
237
dev-tools/tests/relay/test_rev_gateway.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
WebSocket Test Client
|
||||
|
||||
A simple client to test the reverse gateway through the relay.
|
||||
Connects to the relay's /in endpoint and allows sending test messages.
|
||||
|
||||
Usage:
|
||||
python test_client.py [--uri URI] [--interactive]
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import argparse
|
||||
import uuid
|
||||
from aiohttp import ClientSession, WSMsgType
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger("test_client")
|
||||
|
||||
class TestClient:
|
||||
"""Simple WebSocket test client"""
|
||||
|
||||
def __init__(self, uri: str):
|
||||
self.uri = uri
|
||||
self.session = None
|
||||
self.ws = None
|
||||
self.running = False
|
||||
self.message_counter = 0
|
||||
self.client_id = str(uuid.uuid4())[:8]
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the WebSocket"""
|
||||
self.session = ClientSession()
|
||||
logger.info(f"Connecting to {self.uri}")
|
||||
self.ws = await self.session.ws_connect(self.uri)
|
||||
logger.info("Connected successfully")
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect from WebSocket"""
|
||||
if self.ws and not self.ws.closed:
|
||||
await self.ws.close()
|
||||
if self.session and not self.session.closed:
|
||||
await self.session.close()
|
||||
logger.info("Disconnected")
|
||||
|
||||
async def send_message(self, service: str, request_data: dict, flow: str = "default"):
|
||||
"""Send a properly formatted TrustGraph message"""
|
||||
self.message_counter += 1
|
||||
message = {
|
||||
"id": f"{self.client_id}-{self.message_counter}",
|
||||
"service": service,
|
||||
"request": request_data,
|
||||
"flow": flow
|
||||
}
|
||||
|
||||
message_json = json.dumps(message, indent=2)
|
||||
logger.info(f"Sending message:\n{message_json}")
|
||||
await self.ws.send_str(json.dumps(message))
|
||||
|
||||
async def listen_for_responses(self):
|
||||
"""Listen for incoming messages"""
|
||||
logger.info("Listening for responses...")
|
||||
|
||||
async for msg in self.ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
try:
|
||||
response = json.loads(msg.data)
|
||||
logger.info(f"Received response:\n{json.dumps(response, indent=2)}")
|
||||
except json.JSONDecodeError:
|
||||
logger.info(f"Received text: {msg.data}")
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
logger.info(f"Received binary data: {len(msg.data)} bytes")
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
logger.error(f"WebSocket error: {self.ws.exception()}")
|
||||
break
|
||||
else:
|
||||
logger.info(f"Connection closed: {msg.type}")
|
||||
break
|
||||
|
||||
async def interactive_mode(self):
|
||||
"""Interactive mode for manual testing"""
|
||||
print("\n=== Interactive Test Client ===")
|
||||
print("Available commands:")
|
||||
print(" text-completion - Test text completion service")
|
||||
print(" agent - Test agent service")
|
||||
print(" embeddings - Test embeddings service")
|
||||
print(" custom - Send custom message")
|
||||
print(" quit - Exit")
|
||||
print()
|
||||
|
||||
# Start response listener
|
||||
listen_task = asyncio.create_task(self.listen_for_responses())
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
command = input("Command> ").strip().lower()
|
||||
|
||||
if command == "quit":
|
||||
break
|
||||
elif command == "text-completion":
|
||||
await self.send_message("text-completion", {
|
||||
"system": "You are a helpful assistant.",
|
||||
"prompt": "What is 2+2?"
|
||||
})
|
||||
elif command == "agent":
|
||||
await self.send_message("agent", {
|
||||
"question": "What is the capital of France?"
|
||||
})
|
||||
elif command == "embeddings":
|
||||
await self.send_message("embeddings", {
|
||||
"text": "Hello world"
|
||||
})
|
||||
elif command == "custom":
|
||||
service = input("Service name> ").strip()
|
||||
request_json = input("Request JSON> ").strip()
|
||||
try:
|
||||
request_data = json.loads(request_json)
|
||||
await self.send_message(service, request_data)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Invalid JSON: {e}")
|
||||
elif command == "":
|
||||
continue
|
||||
else:
|
||||
print(f"Unknown command: {command}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
except EOFError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in interactive mode: {e}")
|
||||
|
||||
finally:
|
||||
listen_task.cancel()
|
||||
try:
|
||||
await listen_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def run_predefined_tests(self):
|
||||
"""Run a series of predefined tests"""
|
||||
print("\n=== Running Predefined Tests ===")
|
||||
|
||||
# Start response listener
|
||||
listen_task = asyncio.create_task(self.listen_for_responses())
|
||||
|
||||
try:
|
||||
# Test 1: Text completion
|
||||
print("\n1. Testing text-completion service...")
|
||||
await self.send_message("text-completion", {
|
||||
"system": "You are a helpful assistant.",
|
||||
"prompt": "What is 2+2?"
|
||||
})
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Test 2: Agent
|
||||
print("\n2. Testing agent service...")
|
||||
await self.send_message("agent", {
|
||||
"question": "What is the capital of France?"
|
||||
})
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Test 3: Embeddings
|
||||
print("\n3. Testing embeddings service...")
|
||||
await self.send_message("embeddings", {
|
||||
"text": "Hello world"
|
||||
})
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Test 4: Invalid service
|
||||
print("\n4. Testing invalid service...")
|
||||
await self.send_message("nonexistent-service", {
|
||||
"test": "data"
|
||||
})
|
||||
await asyncio.sleep(2)
|
||||
|
||||
print("\nTests completed. Waiting for any remaining responses...")
|
||||
await asyncio.sleep(3)
|
||||
|
||||
finally:
|
||||
listen_task.cancel()
|
||||
try:
|
||||
await listen_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="WebSocket Test Client for Reverse Gateway"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--uri',
|
||||
default='ws://localhost:8080/in',
|
||||
help='WebSocket URI to connect to (default: ws://localhost:8080/in)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--interactive', '-i',
|
||||
action='store_true',
|
||||
help='Run in interactive mode'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--verbose', '-v',
|
||||
action='store_true',
|
||||
help='Enable verbose logging'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.verbose:
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
client = TestClient(args.uri)
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
|
||||
if args.interactive:
|
||||
await client.interactive_mode()
|
||||
else:
|
||||
await client.run_predefined_tests()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nShutdown requested by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Client error: {e}")
|
||||
finally:
|
||||
await client.disconnect()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
210
dev-tools/tests/relay/websocket_relay.py
Normal file
210
dev-tools/tests/relay/websocket_relay.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
WebSocket Relay Test Harness
|
||||
|
||||
This script creates a relay server with two WebSocket endpoints:
|
||||
- /in - for test clients to connect to
|
||||
- /out - for reverse gateway to connect to
|
||||
|
||||
Messages are bidirectionally relayed between the two connections.
|
||||
|
||||
Usage:
|
||||
python websocket_relay.py [--port PORT] [--host HOST]
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import argparse
|
||||
from aiohttp import web, WSMsgType
|
||||
import weakref
|
||||
from typing import Optional, Set
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger("websocket_relay")
|
||||
|
||||
class WebSocketRelay:
|
||||
"""WebSocket relay that forwards messages between 'in' and 'out' connections"""
|
||||
|
||||
def __init__(self):
|
||||
self.in_connections: Set = weakref.WeakSet()
|
||||
self.out_connections: Set = weakref.WeakSet()
|
||||
|
||||
async def handle_in_connection(self, request):
|
||||
"""Handle incoming connections on /in endpoint"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
self.in_connections.add(ws)
|
||||
logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
||||
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
data = msg.data
|
||||
logger.info(f"IN → OUT: {data}")
|
||||
await self._forward_to_out(data)
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
data = msg.data
|
||||
logger.info(f"IN → OUT: {len(data)} bytes (binary)")
|
||||
await self._forward_to_out(data, binary=True)
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
logger.error(f"WebSocket error on 'in' connection: {ws.exception()}")
|
||||
break
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in 'in' connection handler: {e}")
|
||||
finally:
|
||||
logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
||||
|
||||
return ws
|
||||
|
||||
async def handle_out_connection(self, request):
|
||||
"""Handle outgoing connections on /out endpoint"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
self.out_connections.add(ws)
|
||||
logger.info(f"New 'out' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
||||
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
data = msg.data
|
||||
logger.info(f"OUT → IN: {data}")
|
||||
await self._forward_to_in(data)
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
data = msg.data
|
||||
logger.info(f"OUT → IN: {len(data)} bytes (binary)")
|
||||
await self._forward_to_in(data, binary=True)
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
logger.error(f"WebSocket error on 'out' connection: {ws.exception()}")
|
||||
break
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in 'out' connection handler: {e}")
|
||||
finally:
|
||||
logger.info(f"'out' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")
|
||||
|
||||
return ws
|
||||
|
||||
async def _forward_to_out(self, data, binary=False):
|
||||
"""Forward message from 'in' to all 'out' connections"""
|
||||
if not self.out_connections:
|
||||
logger.warning("No 'out' connections available to forward message")
|
||||
return
|
||||
|
||||
closed_connections = []
|
||||
for ws in list(self.out_connections):
|
||||
try:
|
||||
if ws.closed:
|
||||
closed_connections.append(ws)
|
||||
continue
|
||||
|
||||
if binary:
|
||||
await ws.send_bytes(data)
|
||||
else:
|
||||
await ws.send_str(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forwarding to 'out' connection: {e}")
|
||||
closed_connections.append(ws)
|
||||
|
||||
# Clean up closed connections
|
||||
for ws in closed_connections:
|
||||
if ws in self.out_connections:
|
||||
self.out_connections.discard(ws)
|
||||
|
||||
async def _forward_to_in(self, data, binary=False):
|
||||
"""Forward message from 'out' to all 'in' connections"""
|
||||
if not self.in_connections:
|
||||
logger.warning("No 'in' connections available to forward message")
|
||||
return
|
||||
|
||||
closed_connections = []
|
||||
for ws in list(self.in_connections):
|
||||
try:
|
||||
if ws.closed:
|
||||
closed_connections.append(ws)
|
||||
continue
|
||||
|
||||
if binary:
|
||||
await ws.send_bytes(data)
|
||||
else:
|
||||
await ws.send_str(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forwarding to 'in' connection: {e}")
|
||||
closed_connections.append(ws)
|
||||
|
||||
# Clean up closed connections
|
||||
for ws in closed_connections:
|
||||
if ws in self.in_connections:
|
||||
self.in_connections.discard(ws)
|
||||
|
||||
async def create_app(relay):
|
||||
"""Create the web application with routes"""
|
||||
app = web.Application()
|
||||
|
||||
# Add routes
|
||||
app.router.add_get('/in', relay.handle_in_connection)
|
||||
app.router.add_get('/out', relay.handle_out_connection)
|
||||
|
||||
# Add a simple status endpoint
|
||||
async def status(request):
|
||||
status_info = {
|
||||
'in_connections': len(relay.in_connections),
|
||||
'out_connections': len(relay.out_connections),
|
||||
'status': 'running'
|
||||
}
|
||||
return web.json_response(status_info)
|
||||
|
||||
app.router.add_get('/status', status)
|
||||
app.router.add_get('/', status) # Root also shows status
|
||||
|
||||
return app
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="WebSocket Relay Test Harness"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--host',
|
||||
default='localhost',
|
||||
help='Host to bind to (default: localhost)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--port',
|
||||
type=int,
|
||||
default=8080,
|
||||
help='Port to bind to (default: 8080)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--verbose', '-v',
|
||||
action='store_true',
|
||||
help='Enable verbose logging'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.verbose:
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
relay = WebSocketRelay()
|
||||
|
||||
print(f"Starting WebSocket Relay on {args.host}:{args.port}")
|
||||
print(f" 'in' endpoint: ws://{args.host}:{args.port}/in")
|
||||
print(f" 'out' endpoint: ws://{args.host}:{args.port}/out")
|
||||
print(f" Status: http://{args.host}:{args.port}/status")
|
||||
print()
|
||||
print("Usage:")
|
||||
print(f" Test client connects to: ws://{args.host}:{args.port}/in")
|
||||
print(f" Reverse gateway connects to: ws://{args.host}:{args.port}/out")
|
||||
|
||||
web.run_app(create_app(relay), host=args.host, port=args.port)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
227
dev-tools/tests/triples/load_test_triples.py
Executable file
227
dev-tools/tests/triples/load_test_triples.py
Executable file
|
|
@ -0,0 +1,227 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Load test triples into the triple store for testing tg-query-graph.
|
||||
|
||||
Tests all graph features:
|
||||
- SPO with IRI objects
|
||||
- SPO with literal objects
|
||||
- Literals with XML datatypes
|
||||
- Literals with language tags
|
||||
- Quoted triples (RDF-star)
|
||||
- Named graphs
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import websockets
|
||||
|
||||
# Configuration
|
||||
API_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
|
||||
TOKEN = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
FLOW = "default"
|
||||
USER = "trustgraph"
|
||||
COLLECTION = "default"
|
||||
DOCUMENT_ID = "test-triples-001"
|
||||
|
||||
# Namespaces
|
||||
EX = "http://example.org/"
|
||||
RDF = "http://www.w3.org/1999/02/22-rdf-syntax-ns#"
|
||||
RDFS = "http://www.w3.org/2000/01/rdf-schema#"
|
||||
XSD = "http://www.w3.org/2001/XMLSchema#"
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
|
||||
|
||||
def iri(value):
|
||||
"""Build IRI term."""
|
||||
return {"t": "i", "i": value}
|
||||
|
||||
|
||||
def literal(value, datatype=None, language=None):
|
||||
"""Build literal term with optional datatype or language."""
|
||||
term = {"t": "l", "v": value}
|
||||
if datatype:
|
||||
term["dt"] = datatype
|
||||
if language:
|
||||
term["ln"] = language
|
||||
return term
|
||||
|
||||
|
||||
def quoted_triple(s, p, o):
|
||||
"""Build quoted triple term (RDF-star)."""
|
||||
return {
|
||||
"t": "t",
|
||||
"tr": {"s": s, "p": p, "o": o}
|
||||
}
|
||||
|
||||
|
||||
def triple(s, p, o, g=None):
|
||||
"""Build a complete triple dict."""
|
||||
t = {"s": s, "p": p, "o": o}
|
||||
if g:
|
||||
t["g"] = g
|
||||
return t
|
||||
|
||||
|
||||
# Test triples covering all features
|
||||
TEST_TRIPLES = [
|
||||
# 1. Basic SPO with IRI object
|
||||
triple(
|
||||
iri(f"{EX}marie-curie"),
|
||||
iri(f"{RDF}type"),
|
||||
iri(f"{EX}Scientist")
|
||||
),
|
||||
|
||||
# 2. SPO with IRI object (relationship)
|
||||
triple(
|
||||
iri(f"{EX}marie-curie"),
|
||||
iri(f"{EX}discovered"),
|
||||
iri(f"{EX}radium")
|
||||
),
|
||||
|
||||
# 3. Simple literal (no datatype/language)
|
||||
triple(
|
||||
iri(f"{EX}marie-curie"),
|
||||
iri(f"{RDFS}label"),
|
||||
literal("Marie Curie")
|
||||
),
|
||||
|
||||
# 4. Literal with language tag (English)
|
||||
triple(
|
||||
iri(f"{EX}marie-curie"),
|
||||
iri(f"{RDFS}label"),
|
||||
literal("Marie Curie", language="en")
|
||||
),
|
||||
|
||||
# 5. Literal with language tag (French)
|
||||
triple(
|
||||
iri(f"{EX}marie-curie"),
|
||||
iri(f"{RDFS}label"),
|
||||
literal("Marie Curie", language="fr")
|
||||
),
|
||||
|
||||
# 6. Literal with language tag (Polish)
|
||||
triple(
|
||||
iri(f"{EX}marie-curie"),
|
||||
iri(f"{RDFS}label"),
|
||||
literal("Maria Sk\u0142odowska-Curie", language="pl")
|
||||
),
|
||||
|
||||
# 7. Literal with xsd:integer datatype
|
||||
triple(
|
||||
iri(f"{EX}marie-curie"),
|
||||
iri(f"{EX}birthYear"),
|
||||
literal("1867", datatype=f"{XSD}integer")
|
||||
),
|
||||
|
||||
# 8. Literal with xsd:date datatype
|
||||
triple(
|
||||
iri(f"{EX}marie-curie"),
|
||||
iri(f"{EX}birthDate"),
|
||||
literal("1867-11-07", datatype=f"{XSD}date")
|
||||
),
|
||||
|
||||
# 9. Literal with xsd:boolean datatype
|
||||
triple(
|
||||
iri(f"{EX}marie-curie"),
|
||||
iri(f"{EX}nobelLaureate"),
|
||||
literal("true", datatype=f"{XSD}boolean")
|
||||
),
|
||||
|
||||
# 10. Quoted triple in object position (RDF 1.2 style)
|
||||
# "Wikipedia asserts that Marie Curie discovered radium"
|
||||
triple(
|
||||
iri(f"{EX}wikipedia"),
|
||||
iri(f"{TG}asserts"),
|
||||
quoted_triple(
|
||||
iri(f"{EX}marie-curie"),
|
||||
iri(f"{EX}discovered"),
|
||||
iri(f"{EX}radium")
|
||||
)
|
||||
),
|
||||
|
||||
# 11. Quoted triple with literal inside (object position)
|
||||
# "NLP-v1.0 extracted that Marie Curie has label Marie Curie"
|
||||
triple(
|
||||
iri(f"{EX}nlp-v1"),
|
||||
iri(f"{TG}extracted"),
|
||||
quoted_triple(
|
||||
iri(f"{EX}marie-curie"),
|
||||
iri(f"{RDFS}label"),
|
||||
literal("Marie Curie")
|
||||
)
|
||||
),
|
||||
|
||||
# 12. Triple in a named graph (g is plain string, not Term)
|
||||
triple(
|
||||
iri(f"{EX}radium"),
|
||||
iri(f"{RDF}type"),
|
||||
iri(f"{EX}Element"),
|
||||
g=f"{EX}chemistry-graph"
|
||||
),
|
||||
|
||||
# 13. Another triple in the same named graph
|
||||
triple(
|
||||
iri(f"{EX}radium"),
|
||||
iri(f"{EX}atomicNumber"),
|
||||
literal("88", datatype=f"{XSD}integer"),
|
||||
g=f"{EX}chemistry-graph"
|
||||
),
|
||||
|
||||
# 14. Triple in a different named graph
|
||||
triple(
|
||||
iri(f"{EX}pierre-curie"),
|
||||
iri(f"{EX}spouseOf"),
|
||||
iri(f"{EX}marie-curie"),
|
||||
g=f"{EX}biography-graph"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def load_triples():
|
||||
"""Load test triples via WebSocket bulk import."""
|
||||
# Convert HTTP URL to WebSocket URL
|
||||
ws_url = API_URL.replace("http://", "ws://").replace("https://", "wss://")
|
||||
ws_url = f"{ws_url.rstrip('/')}/api/v1/flow/{FLOW}/import/triples"
|
||||
if TOKEN:
|
||||
ws_url = f"{ws_url}?token={TOKEN}"
|
||||
|
||||
metadata = {
|
||||
"id": DOCUMENT_ID,
|
||||
"metadata": [],
|
||||
"user": USER,
|
||||
"collection": COLLECTION
|
||||
}
|
||||
|
||||
print(f"Connecting to {ws_url}...")
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=60) as websocket:
|
||||
message = {
|
||||
"metadata": metadata,
|
||||
"triples": TEST_TRIPLES
|
||||
}
|
||||
print(f"Sending {len(TEST_TRIPLES)} test triples...")
|
||||
await websocket.send(json.dumps(message))
|
||||
print("Triples sent successfully!")
|
||||
|
||||
print("\nTest triples loaded:")
|
||||
print(" - 2 basic IRI triples (type, relationship)")
|
||||
print(" - 4 literal triples (plain + 3 languages: en, fr, pl)")
|
||||
print(" - 3 typed literal triples (xsd:integer, xsd:date, xsd:boolean)")
|
||||
print(" - 2 quoted triples (RDF-star provenance)")
|
||||
print(" - 3 triples in named graphs (chemistry-graph, biography-graph)")
|
||||
print(f"\nTotal: {len(TEST_TRIPLES)} triples")
|
||||
print(f"User: {USER}, Collection: {COLLECTION}")
|
||||
|
||||
|
||||
def main():
|
||||
print("Loading test triples for tg-query-graph testing\n")
|
||||
asyncio.run(load_triples())
|
||||
print("\nDone! Now test with:")
|
||||
print(" tg-query-graph -s http://example.org/marie-curie")
|
||||
print(" tg-query-graph -p http://www.w3.org/2000/01/rdf-schema#label")
|
||||
print(" tg-query-graph -o 'Marie Curie' --object-language en")
|
||||
print(" tg-query-graph --format json | jq .")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,108 +0,0 @@
|
|||
# API Gateway Changes: v1.8 to v2.1
|
||||
|
||||
## Summary
|
||||
|
||||
The API gateway gained new WebSocket service dispatchers for embeddings
|
||||
queries, a new REST streaming endpoint for document content, and underwent
|
||||
a significant wire format change from `Value` to `Term`. The "objects"
|
||||
service was renamed to "rows".
|
||||
|
||||
---
|
||||
|
||||
## New WebSocket Service Dispatchers
|
||||
|
||||
These are new request/response services available through the WebSocket
|
||||
multiplexer at `/api/v1/socket` (flow-scoped):
|
||||
|
||||
| Service Key | Description |
|
||||
|-------------|-------------|
|
||||
| `document-embeddings` | Queries document chunks by text similarity. Request/response uses `DocumentEmbeddingsRequest`/`DocumentEmbeddingsResponse` schemas. |
|
||||
| `row-embeddings` | Queries structured data rows by text similarity on indexed fields. Request/response uses `RowEmbeddingsRequest`/`RowEmbeddingsResponse` schemas. |
|
||||
|
||||
These join the existing `graph-embeddings` dispatcher (which was already
|
||||
present in v1.8 but may have been updated).
|
||||
|
||||
### Full list of WebSocket flow service dispatchers (v2.1)
|
||||
|
||||
Request/response services (via `/api/v1/flow/{flow}/service/{kind}` or
|
||||
WebSocket mux):
|
||||
|
||||
- `agent`, `text-completion`, `prompt`, `mcp-tool`
|
||||
- `graph-rag`, `document-rag`
|
||||
- `embeddings`, `graph-embeddings`, `document-embeddings`
|
||||
- `triples`, `rows`, `nlp-query`, `structured-query`, `structured-diag`
|
||||
- `row-embeddings`
|
||||
|
||||
---
|
||||
|
||||
## New REST Endpoint
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| `GET` | `/api/v1/document-stream` | Streams document content from the library as raw bytes. Query parameters: `user` (required), `document-id` (required), `chunk-size` (optional, default 1MB). Returns the document content in chunked transfer encoding, decoded from base64 internally. |
|
||||
|
||||
---
|
||||
|
||||
## Renamed Service: "objects" to "rows"
|
||||
|
||||
| v1.8 | v2.1 | Notes |
|
||||
|------|------|-------|
|
||||
| `objects_query.py` / `ObjectsQueryRequestor` | `rows_query.py` / `RowsQueryRequestor` | Schema changed from `ObjectsQueryRequest`/`ObjectsQueryResponse` to `RowsQueryRequest`/`RowsQueryResponse`. |
|
||||
| `objects_import.py` / `ObjectsImport` | `rows_import.py` / `RowsImport` | Import dispatcher for structured data. |
|
||||
|
||||
The WebSocket service key changed from `"objects"` to `"rows"`, and the
|
||||
import dispatcher key similarly changed from `"objects"` to `"rows"`.
|
||||
|
||||
---
|
||||
|
||||
## Wire Format Change: Value to Term
|
||||
|
||||
The serialization layer (`serialize.py`) was rewritten to use the new `Term`
|
||||
type instead of the old `Value` type.
|
||||
|
||||
### Old format (v1.8 — `Value`)
|
||||
|
||||
```json
|
||||
{"v": "http://example.org/entity", "e": true}
|
||||
```
|
||||
|
||||
- `v`: the value (string)
|
||||
- `e`: boolean flag indicating whether the value is a URI
|
||||
|
||||
### New format (v2.1 — `Term`)
|
||||
|
||||
IRIs:
|
||||
```json
|
||||
{"t": "i", "i": "http://example.org/entity"}
|
||||
```
|
||||
|
||||
Literals:
|
||||
```json
|
||||
{"t": "l", "v": "some text", "d": "datatype-uri", "l": "en"}
|
||||
```
|
||||
|
||||
Quoted triples (RDF-star):
|
||||
```json
|
||||
{"t": "r", "r": {"s": {...}, "p": {...}, "o": {...}}}
|
||||
```
|
||||
|
||||
- `t`: type discriminator — `"i"` (IRI), `"l"` (literal), `"r"` (quoted triple), `"b"` (blank node)
|
||||
- Serialization now delegates to `TermTranslator` and `TripleTranslator` from `trustgraph.messaging.translators.primitives`
|
||||
|
||||
### Other serialization changes
|
||||
|
||||
| Field | v1.8 | v2.1 |
|
||||
|-------|------|------|
|
||||
| Metadata | `metadata.metadata` (subgraph) | `metadata.root` (simple value) |
|
||||
| Graph embeddings entity | `entity.vectors` (plural) | `entity.vector` (singular) |
|
||||
| Document embeddings chunk | `chunk.vectors` + `chunk.chunk` (text) | `chunk.vector` + `chunk.chunk_id` (ID reference) |
|
||||
|
||||
---
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
- **`Value` to `Term` wire format**: All clients sending/receiving triples, embeddings, or entity contexts through the gateway must update to the new Term format.
|
||||
- **`objects` to `rows` rename**: WebSocket service key and import key changed.
|
||||
- **Metadata field change**: `metadata.metadata` (a serialized subgraph) replaced by `metadata.root` (a simple value).
|
||||
- **Embeddings field changes**: `vectors` (plural) became `vector` (singular); document embeddings now reference `chunk_id` instead of inline `chunk` text.
|
||||
- **New `/api/v1/document-stream` endpoint**: Additive, not breaking.
|
||||
176
docs/api.html
176
docs/api.html
File diff suppressed because one or more lines are too long
|
|
@ -1,112 +0,0 @@
|
|||
# CLI Changes: v1.8 to v2.1
|
||||
|
||||
## Summary
|
||||
|
||||
The CLI (`trustgraph-cli`) has significant additions focused on three themes:
|
||||
**explainability/provenance**, **embeddings access**, and **graph querying**.
|
||||
Two legacy tools were removed, one was renamed, and several existing tools
|
||||
gained new capabilities.
|
||||
|
||||
---
|
||||
|
||||
## New CLI Tools
|
||||
|
||||
### Explainability & Provenance
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `tg-list-explain-traces` | Lists all explainability sessions (GraphRAG and Agent) in a collection, showing session IDs, type, question text, and timestamps. |
|
||||
| `tg-show-explain-trace` | Displays the full explainability trace for a session. For GraphRAG: Question, Exploration, Focus, Synthesis stages. For Agent: Session, Iterations (thought/action/observation), Final Answer. Auto-detects trace type. Supports `--show-provenance` to trace edges back to source documents. |
|
||||
| `tg-show-extraction-provenance` | Given a document ID, traverses the provenance chain: Document -> Pages -> Chunks -> Edges, using `prov:wasDerivedFrom` relationships. Supports `--show-content` and `--max-content` options. |
|
||||
|
||||
### Embeddings
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `tg-invoke-embeddings` | Converts text to a vector embedding via the embeddings service. Accepts one or more text inputs, returns vectors as lists of floats. |
|
||||
| `tg-invoke-graph-embeddings` | Queries graph entities by text similarity using vector embeddings. Returns matching entities with similarity scores. |
|
||||
| `tg-invoke-document-embeddings` | Queries document chunks by text similarity using vector embeddings. Returns matching chunk IDs with similarity scores. |
|
||||
| `tg-invoke-row-embeddings` | Queries structured data rows by text similarity on indexed fields. Returns matching rows with index values and scores. Requires `--schema-name` and supports `--index-name`. |
|
||||
|
||||
### Graph Querying
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `tg-query-graph` | Pattern-based triple store query. Unlike `tg-show-graph` (which dumps everything), this allows selective queries by any combination of subject, predicate, object, and graph. Auto-detects value types: IRIs (`http://...`, `urn:...`, `<...>`), quoted triples (`<<s p o>>`), and literals. |
|
||||
| `tg-get-document-content` | Retrieves document content from the library by document ID. Can output to file or stdout, handles both text and binary content. |
|
||||
|
||||
---
|
||||
|
||||
## Removed CLI Tools
|
||||
|
||||
| Command | Notes |
|
||||
|---------|-------|
|
||||
| `tg-load-pdf` | Removed. Document loading is now handled through the library/processing pipeline. |
|
||||
| `tg-load-text` | Removed. Document loading is now handled through the library/processing pipeline. |
|
||||
|
||||
---
|
||||
|
||||
## Renamed CLI Tools
|
||||
|
||||
| Old Name | New Name | Notes |
|
||||
|----------|----------|-------|
|
||||
| `tg-invoke-objects-query` | `tg-invoke-rows-query` | Reflects the terminology rename from "objects" to "rows" for structured data. |
|
||||
|
||||
---
|
||||
|
||||
## Significant Changes to Existing Tools
|
||||
|
||||
### `tg-invoke-graph-rag`
|
||||
|
||||
- **Explainability support**: Now supports a 4-stage explainability pipeline (Question, Grounding/Exploration, Focus, Synthesis) with inline provenance event display.
|
||||
- **Streaming**: Uses WebSocket streaming for real-time output.
|
||||
- **Provenance tracing**: Can trace selected edges back to source documents via reification and `prov:wasDerivedFrom` chains.
|
||||
- Grew from ~30 lines to ~760 lines to accommodate the full explainability pipeline.
|
||||
|
||||
### `tg-invoke-document-rag`
|
||||
|
||||
- **Explainability support**: Added `question_explainable()` mode that streams Document RAG responses with inline provenance events (Question, Grounding, Exploration, Synthesis stages).
|
||||
|
||||
### `tg-invoke-agent`
|
||||
|
||||
- **Explainability support**: Added `question_explainable()` mode showing provenance events inline during agent execution (Question, Analysis, Conclusion, AgentThought, AgentObservation, AgentAnswer).
|
||||
- Verbose mode shows thought/observation streams with emoji prefixes.
|
||||
|
||||
### `tg-show-graph`
|
||||
|
||||
- **Streaming mode**: Now uses `triples_query_stream()` with configurable batch sizes for lower time-to-first-result and reduced memory overhead.
|
||||
- **Named graph support**: New `--graph` filter option. Recognises named graphs:
|
||||
- Default graph (empty): Core knowledge facts
|
||||
- `urn:graph:source`: Extraction provenance
|
||||
- `urn:graph:retrieval`: Query-time explainability
|
||||
- **Show graph column**: New `--show-graph` flag to display the named graph for each triple.
|
||||
- **Configurable limits**: New `--limit` and `--batch-size` options.
|
||||
|
||||
### `tg-graph-to-turtle`
|
||||
|
||||
- **RDF-star support**: Now handles quoted triples (RDF-star reification).
|
||||
- **Streaming mode**: Uses streaming for lower time-to-first-processing.
|
||||
- **Wire format handling**: Updated to use the new term wire format (`{"t": "i", "i": uri}` for IRIs, `{"t": "l", "v": value}` for literals, `{"t": "r", "r": {...}}` for quoted triples).
|
||||
- **Named graph support**: New `--graph` filter option.
|
||||
|
||||
### `tg-set-tool`
|
||||
|
||||
- **New tool type**: `row-embeddings-query` for semantic search on structured data indexes.
|
||||
- **New options**: `--schema-name`, `--index-name`, `--limit` for configuring row embeddings query tools.
|
||||
|
||||
### `tg-show-tools`
|
||||
|
||||
- Displays the new `row-embeddings-query` tool type with its `schema-name`, `index-name`, and `limit` fields.
|
||||
|
||||
### `tg-load-knowledge`
|
||||
|
||||
- **Progress reporting**: Now counts and reports triples and entity contexts loaded per file and in total.
|
||||
- **Term format update**: Entity contexts now use the new Term format (`{"t": "i", "i": uri}`) instead of the old Value format (`{"v": entity, "e": True}`).
|
||||
|
||||
---
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
- **Terminology rename**: The `Value` schema was renamed to `Term` across the system (PR #622). This affects the wire format used by CLI tools that interact with the graph store. The new format uses `{"t": "i", "i": uri}` for IRIs and `{"t": "l", "v": value}` for literals, replacing the old `{"v": ..., "e": ...}` format.
|
||||
- **`tg-invoke-objects-query` renamed** to `tg-invoke-rows-query`.
|
||||
- **`tg-load-pdf` and `tg-load-text` removed**.
|
||||
|
|
@ -911,7 +911,7 @@ results = flow.graph_embeddings_query(
|
|||
# results contains {"entities": [{"entity": {...}, "score": 0.95}, ...]}
|
||||
```
|
||||
|
||||
### `graph_rag(self, query, user='trustgraph', collection='default', entity_limit=50, triple_limit=30, max_subgraph_size=150, max_path_length=2)`
|
||||
### `graph_rag(self, query, user='trustgraph', collection='default', entity_limit=50, triple_limit=30, max_subgraph_size=150, max_path_length=2, edge_score_limit=30, edge_limit=25)`
|
||||
|
||||
Execute graph-based Retrieval-Augmented Generation (RAG) query.
|
||||
|
||||
|
|
@ -927,6 +927,8 @@ traversing entity relationships, then generates a response using an LLM.
|
|||
- `triple_limit`: Maximum triples per entity (default: 30)
|
||||
- `max_subgraph_size`: Maximum total triples in subgraph (default: 150)
|
||||
- `max_path_length`: Maximum traversal depth (default: 2)
|
||||
- `edge_score_limit`: Max edges for semantic pre-filter (default: 50)
|
||||
- `edge_limit`: Max edges after LLM scoring (default: 25)
|
||||
|
||||
**Returns:** str: Generated response incorporating graph context
|
||||
|
||||
|
|
@ -1216,6 +1218,23 @@ Select matching schemas for a data sample using prompt analysis.
|
|||
|
||||
**Returns:** dict with schema_matches array and metadata
|
||||
|
||||
### `sparql_query(self, query, user='trustgraph', collection='default', limit=10000)`
|
||||
|
||||
Execute a SPARQL query against the knowledge graph.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `query`: SPARQL 1.1 query string
|
||||
- `user`: User/keyspace identifier (default: "trustgraph")
|
||||
- `collection`: Collection identifier (default: "default")
|
||||
- `limit`: Safety limit on results (default: 10000)
|
||||
|
||||
**Returns:** dict with query results. Structure depends on query type: - SELECT: {"query-type": "select", "variables": [...], "bindings": [...]} - ASK: {"query-type": "ask", "ask-result": bool} - CONSTRUCT/DESCRIBE: {"query-type": "construct", "triples": [...]}
|
||||
|
||||
**Raises:**
|
||||
|
||||
- `ProtocolException`: If an error occurs
|
||||
|
||||
### `structured_query(self, question, user='trustgraph', collection='default')`
|
||||
|
||||
Execute a natural language question against structured data.
|
||||
|
|
@ -1937,54 +1956,24 @@ for triple in results.get("triples", []):
|
|||
from trustgraph.api import SocketClient
|
||||
```
|
||||
|
||||
Synchronous WebSocket client for streaming operations.
|
||||
Synchronous WebSocket client with persistent connection.
|
||||
|
||||
Provides a synchronous interface to WebSocket-based TrustGraph services,
|
||||
wrapping async websockets library with synchronous generators for ease of use.
|
||||
Supports streaming responses from agents, RAG queries, and text completions.
|
||||
|
||||
Note: This is a synchronous wrapper around async WebSocket operations. For
|
||||
true async support, use AsyncSocketClient instead.
|
||||
Maintains a single websocket connection and multiplexes requests
|
||||
by ID via a background reader task. Provides synchronous generators
|
||||
for streaming responses.
|
||||
|
||||
### Methods
|
||||
|
||||
### `__init__(self, url: str, timeout: int, token: str | None) -> None`
|
||||
|
||||
Initialize synchronous WebSocket client.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `url`: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS)
|
||||
- `timeout`: WebSocket timeout in seconds
|
||||
- `token`: Optional bearer token for authentication
|
||||
Initialize self. See help(type(self)) for accurate signature.
|
||||
|
||||
### `close(self) -> None`
|
||||
|
||||
Close WebSocket connections.
|
||||
|
||||
Note: Cleanup is handled automatically by context managers in async code.
|
||||
Close the persistent WebSocket connection.
|
||||
|
||||
### `flow(self, flow_id: str) -> 'SocketFlowInstance'`
|
||||
|
||||
Get a flow instance for WebSocket streaming operations.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `flow_id`: Flow identifier
|
||||
|
||||
**Returns:** SocketFlowInstance: Flow instance with streaming methods
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
# Stream agent responses
|
||||
for chunk in flow.agent(question="Hello", user="trustgraph", streaming=True):
|
||||
print(chunk.content, end='', flush=True)
|
||||
```
|
||||
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -1997,618 +1986,82 @@ from trustgraph.api import SocketFlowInstance
|
|||
Synchronous WebSocket flow instance for streaming operations.
|
||||
|
||||
Provides the same interface as REST FlowInstance but with WebSocket-based
|
||||
streaming support for real-time responses. All methods support an optional
|
||||
`streaming` parameter to enable incremental result delivery.
|
||||
streaming support for real-time responses.
|
||||
|
||||
### Methods
|
||||
|
||||
### `__init__(self, client: trustgraph.api.socket_client.SocketClient, flow_id: str) -> None`
|
||||
|
||||
Initialize socket flow instance.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `client`: Parent SocketClient
|
||||
- `flow_id`: Flow identifier
|
||||
Initialize self. See help(type(self)) for accurate signature.
|
||||
|
||||
### `agent(self, question: str, user: str, state: Dict[str, Any] | None = None, group: str | None = None, history: List[Dict[str, Any]] | None = None, streaming: bool = False, **kwargs: Any) -> Dict[str, Any] | Iterator[trustgraph.api.types.StreamingChunk]`
|
||||
|
||||
Execute an agent operation with streaming support.
|
||||
|
||||
Agents can perform multi-step reasoning with tool use. This method always
|
||||
returns streaming chunks (thoughts, observations, answers) even when
|
||||
streaming=False, to show the agent's reasoning process.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `question`: User question or instruction
|
||||
- `user`: User identifier
|
||||
- `state`: Optional state dictionary for stateful conversations
|
||||
- `group`: Optional group identifier for multi-user contexts
|
||||
- `history`: Optional conversation history as list of message dicts
|
||||
- `streaming`: Enable streaming mode (default: False)
|
||||
- `**kwargs`: Additional parameters passed to the agent service
|
||||
|
||||
**Returns:** Iterator[StreamingChunk]: Stream of agent thoughts, observations, and answers
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
# Stream agent reasoning
|
||||
for chunk in flow.agent(
|
||||
question="What is quantum computing?",
|
||||
user="trustgraph",
|
||||
streaming=True
|
||||
):
|
||||
if isinstance(chunk, AgentThought):
|
||||
print(f"[Thinking] {chunk.content}")
|
||||
elif isinstance(chunk, AgentObservation):
|
||||
print(f"[Observation] {chunk.content}")
|
||||
elif isinstance(chunk, AgentAnswer):
|
||||
print(f"[Answer] {chunk.content}")
|
||||
```
|
||||
|
||||
### `agent_explain(self, question: str, user: str, collection: str, state: Dict[str, Any] | None = None, group: str | None = None, history: List[Dict[str, Any]] | None = None, **kwargs: Any) -> Iterator[trustgraph.api.types.StreamingChunk | trustgraph.api.types.ProvenanceEvent]`
|
||||
|
||||
Execute an agent operation with explainability support.
|
||||
|
||||
Streams both content chunks (AgentThought, AgentObservation, AgentAnswer)
|
||||
and provenance events (ProvenanceEvent). Provenance events contain URIs
|
||||
that can be fetched using ExplainabilityClient to get detailed information
|
||||
about the agent's reasoning process.
|
||||
|
||||
Agent trace consists of:
|
||||
- Session: The initial question and session metadata
|
||||
- Iterations: Each thought/action/observation cycle
|
||||
- Conclusion: The final answer
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `question`: User question or instruction
|
||||
- `user`: User identifier
|
||||
- `collection`: Collection identifier for provenance storage
|
||||
- `state`: Optional state dictionary for stateful conversations
|
||||
- `group`: Optional group identifier for multi-user contexts
|
||||
- `history`: Optional conversation history as list of message dicts
|
||||
- `**kwargs`: Additional parameters passed to the agent service
|
||||
- `Yields`:
|
||||
- `Union[StreamingChunk, ProvenanceEvent]`: Agent chunks and provenance events
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
from trustgraph.api import Api, ExplainabilityClient, ProvenanceEvent
|
||||
from trustgraph.api import AgentThought, AgentObservation, AgentAnswer
|
||||
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
explain_client = ExplainabilityClient(flow)
|
||||
|
||||
provenance_ids = []
|
||||
for item in flow.agent_explain(
|
||||
question="What is the capital of France?",
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
):
|
||||
if isinstance(item, AgentThought):
|
||||
print(f"[Thought] {item.content}")
|
||||
elif isinstance(item, AgentObservation):
|
||||
print(f"[Observation] {item.content}")
|
||||
elif isinstance(item, AgentAnswer):
|
||||
print(f"[Answer] {item.content}")
|
||||
elif isinstance(item, ProvenanceEvent):
|
||||
provenance_ids.append(item.explain_id)
|
||||
|
||||
# Fetch session trace after completion
|
||||
if provenance_ids:
|
||||
trace = explain_client.fetch_agent_trace(
|
||||
provenance_ids[0], # Session URI is first
|
||||
graph="urn:graph:retrieval",
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
)
|
||||
```
|
||||
|
||||
### `document_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs: Any) -> Dict[str, Any]`
|
||||
|
||||
Query document chunks using semantic similarity.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `text`: Query text for semantic search
|
||||
- `user`: User/keyspace identifier
|
||||
- `collection`: Collection identifier
|
||||
- `limit`: Maximum number of results (default: 10)
|
||||
- `**kwargs`: Additional parameters passed to the service
|
||||
|
||||
**Returns:** dict: Query results with chunk_ids of matching document chunks
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
results = flow.document_embeddings_query(
|
||||
text="machine learning algorithms",
|
||||
user="trustgraph",
|
||||
collection="research-papers",
|
||||
limit=5
|
||||
)
|
||||
# results contains {"chunks": [{"chunk_id": "...", "score": 0.95}, ...]}
|
||||
```
|
||||
|
||||
### `document_rag(self, query: str, user: str, collection: str, doc_limit: int = 10, streaming: bool = False, **kwargs: Any) -> str | Iterator[str]`
|
||||
|
||||
Execute document-based RAG query with optional streaming.
|
||||
|
||||
Uses vector embeddings to find relevant document chunks, then generates
|
||||
a response using an LLM. Streaming mode delivers results incrementally.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `query`: Natural language query
|
||||
- `user`: User/keyspace identifier
|
||||
- `collection`: Collection identifier
|
||||
- `doc_limit`: Maximum document chunks to retrieve (default: 10)
|
||||
- `streaming`: Enable streaming mode (default: False)
|
||||
- `**kwargs`: Additional parameters passed to the service
|
||||
|
||||
**Returns:** Union[str, Iterator[str]]: Complete response or stream of text chunks
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
# Streaming document RAG
|
||||
for chunk in flow.document_rag(
|
||||
query="Summarize the key findings",
|
||||
user="trustgraph",
|
||||
collection="research-papers",
|
||||
doc_limit=5,
|
||||
streaming=True
|
||||
):
|
||||
print(chunk, end='', flush=True)
|
||||
```
|
||||
|
||||
### `document_rag_explain(self, query: str, user: str, collection: str, doc_limit: int = 10, **kwargs: Any) -> Iterator[trustgraph.api.types.RAGChunk | trustgraph.api.types.ProvenanceEvent]`
|
||||
|
||||
Execute document-based RAG query with explainability support.
|
||||
|
||||
Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent).
|
||||
Provenance events contain URIs that can be fetched using ExplainabilityClient
|
||||
to get detailed information about how the response was generated.
|
||||
|
||||
Document RAG trace consists of:
|
||||
- Question: The user's query
|
||||
- Exploration: Chunks retrieved from document store (chunk_count)
|
||||
- Synthesis: The generated answer
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `query`: Natural language query
|
||||
- `user`: User/keyspace identifier
|
||||
- `collection`: Collection identifier
|
||||
- `doc_limit`: Maximum document chunks to retrieve (default: 10)
|
||||
- `**kwargs`: Additional parameters passed to the service
|
||||
- `Yields`:
|
||||
- `Union[RAGChunk, ProvenanceEvent]`: Content chunks and provenance events
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent
|
||||
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
explain_client = ExplainabilityClient(flow)
|
||||
|
||||
for item in flow.document_rag_explain(
|
||||
query="Summarize the key findings",
|
||||
user="trustgraph",
|
||||
collection="research-papers",
|
||||
doc_limit=5
|
||||
):
|
||||
if isinstance(item, RAGChunk):
|
||||
print(item.content, end='', flush=True)
|
||||
elif isinstance(item, ProvenanceEvent):
|
||||
# Fetch entity details
|
||||
entity = explain_client.fetch_entity(
|
||||
item.explain_id,
|
||||
graph=item.explain_graph,
|
||||
user="trustgraph",
|
||||
collection="research-papers"
|
||||
)
|
||||
print(f"Event: {entity}", file=sys.stderr)
|
||||
```
|
||||
|
||||
### `embeddings(self, texts: list, **kwargs: Any) -> Dict[str, Any]`
|
||||
|
||||
Generate vector embeddings for one or more texts.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `texts`: List of input texts to embed
|
||||
- `**kwargs`: Additional parameters passed to the service
|
||||
|
||||
**Returns:** dict: Response containing vectors (one set per input text)
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
result = flow.embeddings(["quantum computing"])
|
||||
vectors = result.get("vectors", [])
|
||||
```
|
||||
|
||||
### `graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs: Any) -> Dict[str, Any]`
|
||||
|
||||
Query knowledge graph entities using semantic similarity.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `text`: Query text for semantic search
|
||||
- `user`: User/keyspace identifier
|
||||
- `collection`: Collection identifier
|
||||
- `limit`: Maximum number of results (default: 10)
|
||||
- `**kwargs`: Additional parameters passed to the service
|
||||
|
||||
**Returns:** dict: Query results with similar entities
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
results = flow.graph_embeddings_query(
|
||||
text="physicist who discovered radioactivity",
|
||||
user="trustgraph",
|
||||
collection="scientists",
|
||||
limit=5
|
||||
)
|
||||
```
|
||||
|
||||
### `graph_rag(self, query: str, user: str, collection: str, max_subgraph_size: int = 1000, max_subgraph_count: int = 5, max_entity_distance: int = 3, streaming: bool = False, **kwargs: Any) -> str | Iterator[str]`
|
||||
### `graph_rag(self, query: str, user: str, collection: str, entity_limit: int = 50, triple_limit: int = 30, max_subgraph_size: int = 1000, max_path_length: int = 2, edge_score_limit: int = 30, edge_limit: int = 25, streaming: bool = False, **kwargs: Any) -> str | Iterator[str]`
|
||||
|
||||
Execute graph-based RAG query with optional streaming.
|
||||
|
||||
Uses knowledge graph structure to find relevant context, then generates
|
||||
a response using an LLM. Streaming mode delivers results incrementally.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `query`: Natural language query
|
||||
- `user`: User/keyspace identifier
|
||||
- `collection`: Collection identifier
|
||||
- `max_subgraph_size`: Maximum total triples in subgraph (default: 1000)
|
||||
- `max_subgraph_count`: Maximum number of subgraphs (default: 5)
|
||||
- `max_entity_distance`: Maximum traversal depth (default: 3)
|
||||
- `streaming`: Enable streaming mode (default: False)
|
||||
- `**kwargs`: Additional parameters passed to the service
|
||||
|
||||
**Returns:** Union[str, Iterator[str]]: Complete response or stream of text chunks
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
# Streaming graph RAG
|
||||
for chunk in flow.graph_rag(
|
||||
query="Tell me about Marie Curie",
|
||||
user="trustgraph",
|
||||
collection="scientists",
|
||||
streaming=True
|
||||
):
|
||||
print(chunk, end='', flush=True)
|
||||
```
|
||||
|
||||
### `graph_rag_explain(self, query: str, user: str, collection: str, max_subgraph_size: int = 1000, max_subgraph_count: int = 5, max_entity_distance: int = 3, **kwargs: Any) -> Iterator[trustgraph.api.types.RAGChunk | trustgraph.api.types.ProvenanceEvent]`
|
||||
### `graph_rag_explain(self, query: str, user: str, collection: str, entity_limit: int = 50, triple_limit: int = 30, max_subgraph_size: int = 1000, max_path_length: int = 2, edge_score_limit: int = 30, edge_limit: int = 25, **kwargs: Any) -> Iterator[trustgraph.api.types.RAGChunk | trustgraph.api.types.ProvenanceEvent]`
|
||||
|
||||
Execute graph-based RAG query with explainability support.
|
||||
|
||||
Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent).
|
||||
Provenance events contain URIs that can be fetched using ExplainabilityClient
|
||||
to get detailed information about how the response was generated.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `query`: Natural language query
|
||||
- `user`: User/keyspace identifier
|
||||
- `collection`: Collection identifier
|
||||
- `max_subgraph_size`: Maximum total triples in subgraph (default: 1000)
|
||||
- `max_subgraph_count`: Maximum number of subgraphs (default: 5)
|
||||
- `max_entity_distance`: Maximum traversal depth (default: 3)
|
||||
- `**kwargs`: Additional parameters passed to the service
|
||||
- `Yields`:
|
||||
- `Union[RAGChunk, ProvenanceEvent]`: Content chunks and provenance events
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent
|
||||
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
explain_client = ExplainabilityClient(flow)
|
||||
|
||||
provenance_ids = []
|
||||
response_text = ""
|
||||
|
||||
for item in flow.graph_rag_explain(
|
||||
query="Tell me about Marie Curie",
|
||||
user="trustgraph",
|
||||
collection="scientists"
|
||||
):
|
||||
if isinstance(item, RAGChunk):
|
||||
response_text += item.content
|
||||
print(item.content, end='', flush=True)
|
||||
elif isinstance(item, ProvenanceEvent):
|
||||
provenance_ids.append(item.provenance_id)
|
||||
|
||||
# Fetch explainability details
|
||||
for prov_id in provenance_ids:
|
||||
entity = explain_client.fetch_entity(
|
||||
prov_id,
|
||||
graph="urn:graph:retrieval",
|
||||
user="trustgraph",
|
||||
collection="scientists"
|
||||
)
|
||||
print(f"Entity: {entity}")
|
||||
```
|
||||
|
||||
### `mcp_tool(self, name: str, parameters: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]`
|
||||
|
||||
Execute a Model Context Protocol (MCP) tool.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `name`: Tool name/identifier
|
||||
- `parameters`: Tool parameters dictionary
|
||||
- `**kwargs`: Additional parameters passed to the service
|
||||
|
||||
**Returns:** dict: Tool execution result
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
result = flow.mcp_tool(
|
||||
name="search-web",
|
||||
parameters={"query": "latest AI news", "limit": 5}
|
||||
)
|
||||
```
|
||||
|
||||
### `prompt(self, id: str, variables: Dict[str, str], streaming: bool = False, **kwargs: Any) -> str | Iterator[str]`
|
||||
|
||||
Execute a prompt template with optional streaming.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `id`: Prompt template identifier
|
||||
- `variables`: Dictionary of variable name to value mappings
|
||||
- `streaming`: Enable streaming mode (default: False)
|
||||
- `**kwargs`: Additional parameters passed to the service
|
||||
|
||||
**Returns:** Union[str, Iterator[str]]: Complete response or stream of text chunks
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
# Streaming prompt execution
|
||||
for chunk in flow.prompt(
|
||||
id="summarize-template",
|
||||
variables={"topic": "quantum computing", "length": "brief"},
|
||||
streaming=True
|
||||
):
|
||||
print(chunk, end='', flush=True)
|
||||
```
|
||||
|
||||
### `row_embeddings_query(self, text: str, schema_name: str, user: str = 'trustgraph', collection: str = 'default', index_name: str | None = None, limit: int = 10, **kwargs: Any) -> Dict[str, Any]`
|
||||
|
||||
Query row data using semantic similarity on indexed fields.
|
||||
|
||||
Finds rows whose indexed field values are semantically similar to the
|
||||
input text, using vector embeddings. This enables fuzzy/semantic matching
|
||||
on structured data.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `text`: Query text for semantic search
|
||||
- `schema_name`: Schema name to search within
|
||||
- `user`: User/keyspace identifier (default: "trustgraph")
|
||||
- `collection`: Collection identifier (default: "default")
|
||||
- `index_name`: Optional index name to filter search to specific index
|
||||
- `limit`: Maximum number of results (default: 10)
|
||||
- `**kwargs`: Additional parameters passed to the service
|
||||
|
||||
**Returns:** dict: Query results with matches containing index_name, index_value, text, and score
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
# Search for customers by name similarity
|
||||
results = flow.row_embeddings_query(
|
||||
text="John Smith",
|
||||
schema_name="customers",
|
||||
user="trustgraph",
|
||||
collection="sales",
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Filter to specific index
|
||||
results = flow.row_embeddings_query(
|
||||
text="machine learning engineer",
|
||||
schema_name="employees",
|
||||
index_name="job_title",
|
||||
limit=10
|
||||
)
|
||||
```
|
||||
|
||||
### `rows_query(self, query: str, user: str, collection: str, variables: Dict[str, Any] | None = None, operation_name: str | None = None, **kwargs: Any) -> Dict[str, Any]`
|
||||
|
||||
Execute a GraphQL query against structured rows.
|
||||
|
||||
**Arguments:**
|
||||
### `sparql_query_stream(self, query: str, user: str = 'trustgraph', collection: str = 'default', limit: int = 10000, batch_size: int = 20, **kwargs: Any) -> Iterator[Dict[str, Any]]`
|
||||
|
||||
- `query`: GraphQL query string
|
||||
- `user`: User/keyspace identifier
|
||||
- `collection`: Collection identifier
|
||||
- `variables`: Optional query variables dictionary
|
||||
- `operation_name`: Optional operation name for multi-operation documents
|
||||
- `**kwargs`: Additional parameters passed to the service
|
||||
|
||||
**Returns:** dict: GraphQL response with data, errors, and/or extensions
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
query = '''
|
||||
{
|
||||
scientists(limit: 10) {
|
||||
name
|
||||
field
|
||||
discoveries
|
||||
}
|
||||
}
|
||||
'''
|
||||
result = flow.rows_query(
|
||||
query=query,
|
||||
user="trustgraph",
|
||||
collection="scientists"
|
||||
)
|
||||
```
|
||||
Execute a SPARQL query with streaming batches.
|
||||
|
||||
### `text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> str | Iterator[str]`
|
||||
|
||||
Execute text completion with optional streaming.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `system`: System prompt defining the assistant's behavior
|
||||
- `prompt`: User prompt/question
|
||||
- `streaming`: Enable streaming mode (default: False)
|
||||
- `**kwargs`: Additional parameters passed to the service
|
||||
|
||||
**Returns:** Union[str, Iterator[str]]: Complete response or stream of text chunks
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
# Non-streaming
|
||||
response = flow.text_completion(
|
||||
system="You are helpful",
|
||||
prompt="Explain quantum computing",
|
||||
streaming=False
|
||||
)
|
||||
print(response)
|
||||
|
||||
# Streaming
|
||||
for chunk in flow.text_completion(
|
||||
system="You are helpful",
|
||||
prompt="Explain quantum computing",
|
||||
streaming=True
|
||||
):
|
||||
print(chunk, end='', flush=True)
|
||||
```
|
||||
|
||||
### `triples_query(self, s: str | Dict[str, Any] | None = None, p: str | Dict[str, Any] | None = None, o: str | Dict[str, Any] | None = None, g: str | None = None, user: str | None = None, collection: str | None = None, limit: int = 100, **kwargs: Any) -> List[Dict[str, Any]]`
|
||||
|
||||
Query knowledge graph triples using pattern matching.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `s`: Subject filter - URI string, Term dict, or None for wildcard
|
||||
- `p`: Predicate filter - URI string, Term dict, or None for wildcard
|
||||
- `o`: Object filter - URI/literal string, Term dict, or None for wildcard
|
||||
- `g`: Named graph filter - URI string or None for all graphs
|
||||
- `user`: User/keyspace identifier (optional)
|
||||
- `collection`: Collection identifier (optional)
|
||||
- `limit`: Maximum results to return (default: 100)
|
||||
- `**kwargs`: Additional parameters passed to the service
|
||||
|
||||
**Returns:** List[Dict]: List of matching triples in wire format
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
# Find all triples about a specific subject
|
||||
triples = flow.triples_query(
|
||||
s="http://example.org/person/marie-curie",
|
||||
user="trustgraph",
|
||||
collection="scientists"
|
||||
)
|
||||
|
||||
# Query with named graph filter
|
||||
triples = flow.triples_query(
|
||||
s="urn:trustgraph:session:abc123",
|
||||
g="urn:graph:retrieval",
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
)
|
||||
```
|
||||
|
||||
### `triples_query_stream(self, s: str | Dict[str, Any] | None = None, p: str | Dict[str, Any] | None = None, o: str | Dict[str, Any] | None = None, g: str | None = None, user: str | None = None, collection: str | None = None, limit: int = 100, batch_size: int = 20, **kwargs: Any) -> Iterator[List[Dict[str, Any]]]`
|
||||
|
||||
Query knowledge graph triples with streaming batches.
|
||||
|
||||
Yields batches of triples as they arrive, reducing time-to-first-result
|
||||
and memory overhead for large result sets.
|
||||
|
||||
**Arguments:**
|
||||
|
||||
- `s`: Subject filter - URI string, Term dict, or None for wildcard
|
||||
- `p`: Predicate filter - URI string, Term dict, or None for wildcard
|
||||
- `o`: Object filter - URI/literal string, Term dict, or None for wildcard
|
||||
- `g`: Named graph filter - URI string or None for all graphs
|
||||
- `user`: User/keyspace identifier (optional)
|
||||
- `collection`: Collection identifier (optional)
|
||||
- `limit`: Maximum results to return (default: 100)
|
||||
- `batch_size`: Triples per batch (default: 20)
|
||||
- `**kwargs`: Additional parameters passed to the service
|
||||
- `Yields`:
|
||||
- `List[Dict]`: Batches of triples in wire format
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
for batch in flow.triples_query_stream(
|
||||
user="trustgraph",
|
||||
collection="default"
|
||||
):
|
||||
for triple in batch:
|
||||
print(triple["s"], triple["p"], triple["o"])
|
||||
```
|
||||
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -2618,17 +2071,35 @@ for batch in flow.triples_query_stream(
|
|||
from trustgraph.api import AsyncSocketClient
|
||||
```
|
||||
|
||||
Asynchronous WebSocket client
|
||||
Asynchronous WebSocket client with persistent connection.
|
||||
|
||||
Maintains a single websocket connection and multiplexes requests
|
||||
by ID, routing responses via a background reader task.
|
||||
|
||||
Use as an async context manager for proper lifecycle management:
|
||||
|
||||
async with AsyncSocketClient(url, timeout, token) as client:
|
||||
result = await client._send_request(...)
|
||||
|
||||
Or call connect()/aclose() manually.
|
||||
|
||||
### Methods
|
||||
|
||||
### `__aenter__(self)`
|
||||
|
||||
### `__aexit__(self, exc_type, exc_val, exc_tb)`
|
||||
|
||||
### `__init__(self, url: str, timeout: int, token: str | None)`
|
||||
|
||||
Initialize self. See help(type(self)) for accurate signature.
|
||||
|
||||
### `aclose(self)`
|
||||
|
||||
Close WebSocket connection
|
||||
Close the persistent WebSocket connection cleanly.
|
||||
|
||||
### `connect(self)`
|
||||
|
||||
Establish the persistent websocket connection.
|
||||
|
||||
### `flow(self, flow_id: str)`
|
||||
|
||||
|
|
@ -3151,7 +2622,10 @@ Detect whether a session is GraphRAG or Agent type.
|
|||
|
||||
Fetch the complete Agent trace starting from a session URI.
|
||||
|
||||
Follows the provenance chain: Question -> Analysis(s) -> Conclusion
|
||||
Follows the provenance chain for all patterns:
|
||||
- ReAct: Question -> Analysis(s) -> Conclusion
|
||||
- Supervisor: Question -> Decomposition -> Finding(s) -> Synthesis
|
||||
- Plan-then-Execute: Question -> Plan -> StepResult(s) -> Synthesis
|
||||
|
||||
**Arguments:**
|
||||
|
||||
|
|
@ -3162,7 +2636,7 @@ Follows the provenance chain: Question -> Analysis(s) -> Conclusion
|
|||
- `api`: TrustGraph Api instance for librarian access (optional)
|
||||
- `max_content`: Maximum content length for conclusion
|
||||
|
||||
**Returns:** Dict with question, iterations (Analysis list), conclusion entities
|
||||
**Returns:** Dict with question, steps (mixed entity list), conclusion/synthesis
|
||||
|
||||
### `fetch_docrag_trace(self, question_uri: str, graph: str | None = None, user: str | None = None, collection: str | None = None, api: Any = None, max_content: int = 10000) -> Dict[str, Any]`
|
||||
|
||||
|
|
@ -3423,7 +2897,7 @@ Initialize self. See help(type(self)) for accurate signature.
|
|||
from trustgraph.api import Analysis
|
||||
```
|
||||
|
||||
Analysis entity - one think/act/observe cycle (Agent only).
|
||||
Analysis+ToolUse entity - decision + tool call (Agent only).
|
||||
|
||||
**Fields:**
|
||||
|
||||
|
|
@ -3432,11 +2906,33 @@ Analysis entity - one think/act/observe cycle (Agent only).
|
|||
- `action`: <class 'str'>
|
||||
- `arguments`: <class 'str'>
|
||||
- `thought`: <class 'str'>
|
||||
- `observation`: <class 'str'>
|
||||
|
||||
### Methods
|
||||
|
||||
### `__init__(self, uri: str, entity_type: str = '', action: str = '', arguments: str = '', thought: str = '', observation: str = '') -> None`
|
||||
### `__init__(self, uri: str, entity_type: str = '', action: str = '', arguments: str = '', thought: str = '') -> None`
|
||||
|
||||
Initialize self. See help(type(self)) for accurate signature.
|
||||
|
||||
|
||||
---
|
||||
|
||||
## `Observation`
|
||||
|
||||
```python
|
||||
from trustgraph.api import Observation
|
||||
```
|
||||
|
||||
Observation entity - standalone tool result (Agent only).
|
||||
|
||||
**Fields:**
|
||||
|
||||
- `uri`: <class 'str'>
|
||||
- `entity_type`: <class 'str'>
|
||||
- `document`: <class 'str'>
|
||||
|
||||
### Methods
|
||||
|
||||
### `__init__(self, uri: str, entity_type: str = '', document: str = '') -> None`
|
||||
|
||||
Initialize self. See help(type(self)) for accurate signature.
|
||||
|
||||
|
|
@ -3761,10 +3257,11 @@ These chunks show how the agent is thinking about the problem.
|
|||
- `content`: <class 'str'>
|
||||
- `end_of_message`: <class 'bool'>
|
||||
- `chunk_type`: <class 'str'>
|
||||
- `message_id`: <class 'str'>
|
||||
|
||||
### Methods
|
||||
|
||||
### `__init__(self, content: str, end_of_message: bool = False, chunk_type: str = 'thought') -> None`
|
||||
### `__init__(self, content: str, end_of_message: bool = False, chunk_type: str = 'thought', message_id: str = '') -> None`
|
||||
|
||||
Initialize self. See help(type(self)) for accurate signature.
|
||||
|
||||
|
|
@ -3787,10 +3284,11 @@ These chunks show what the agent learned from using tools.
|
|||
- `content`: <class 'str'>
|
||||
- `end_of_message`: <class 'bool'>
|
||||
- `chunk_type`: <class 'str'>
|
||||
- `message_id`: <class 'str'>
|
||||
|
||||
### Methods
|
||||
|
||||
### `__init__(self, content: str, end_of_message: bool = False, chunk_type: str = 'observation') -> None`
|
||||
### `__init__(self, content: str, end_of_message: bool = False, chunk_type: str = 'observation', message_id: str = '') -> None`
|
||||
|
||||
Initialize self. See help(type(self)) for accurate signature.
|
||||
|
||||
|
|
@ -3818,10 +3316,11 @@ its reasoning and tool use.
|
|||
- `end_of_message`: <class 'bool'>
|
||||
- `chunk_type`: <class 'str'>
|
||||
- `end_of_dialog`: <class 'bool'>
|
||||
- `message_id`: <class 'str'>
|
||||
|
||||
### Methods
|
||||
|
||||
### `__init__(self, content: str, end_of_message: bool = False, chunk_type: str = 'final-answer', end_of_dialog: bool = False) -> None`
|
||||
### `__init__(self, content: str, end_of_message: bool = False, chunk_type: str = 'final-answer', end_of_dialog: bool = False, message_id: str = '') -> None`
|
||||
|
||||
Initialize self. See help(type(self)) for accurate signature.
|
||||
|
||||
|
|
@ -3864,7 +3363,7 @@ from trustgraph.api import ProvenanceEvent
|
|||
|
||||
Provenance event for explainability.
|
||||
|
||||
Emitted during GraphRAG queries when explainable mode is enabled.
|
||||
Emitted during retrieval queries when explainable mode is enabled.
|
||||
Each event represents a provenance node created during query processing.
|
||||
|
||||
**Fields:**
|
||||
|
|
@ -3872,10 +3371,12 @@ Each event represents a provenance node created during query processing.
|
|||
- `explain_id`: <class 'str'>
|
||||
- `explain_graph`: <class 'str'>
|
||||
- `event_type`: <class 'str'>
|
||||
- `entity`: <class 'object'>
|
||||
- `triples`: <class 'list'>
|
||||
|
||||
### Methods
|
||||
|
||||
### `__init__(self, explain_id: str, explain_graph: str = '', event_type: str = '') -> None`
|
||||
### `__init__(self, explain_id: str, explain_graph: str = '', event_type: str = '', entity: object = None, triples: list = <factory>) -> None`
|
||||
|
||||
Initialize self. See help(type(self)) for accurate signature.
|
||||
|
||||
|
|
|
|||
|
|
@ -219,8 +219,8 @@ TG_ANSWER = TG + "answer"
|
|||
| `trustgraph-base/trustgraph/provenance/triples.py` | Add TG types to GraphRAG triple builders, add Document RAG triple builders |
|
||||
| `trustgraph-base/trustgraph/provenance/uris.py` | Add Document RAG URI generators |
|
||||
| `trustgraph-base/trustgraph/provenance/__init__.py` | Export new types, predicates, and Document RAG functions |
|
||||
| `trustgraph-base/trustgraph/schema/services/retrieval.py` | Add explain_id and explain_graph to DocumentRagResponse |
|
||||
| `trustgraph-base/trustgraph/messaging/translators/retrieval.py` | Update DocumentRagResponseTranslator for explainability fields |
|
||||
| `trustgraph-base/trustgraph/schema/services/retrieval.py` | Add explain_id, explain_graph, and explain_triples to DocumentRagResponse |
|
||||
| `trustgraph-base/trustgraph/messaging/translators/retrieval.py` | Update DocumentRagResponseTranslator for explainability fields including inline triples |
|
||||
| `trustgraph-flow/trustgraph/agent/react/service.py` | Add explainability producer + recording logic |
|
||||
| `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` | Add explainability callback and emit provenance triples |
|
||||
| `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` | Add explainability producer and wire up callback |
|
||||
|
|
|
|||
939
docs/tech-specs/agent-orchestration.md
Normal file
939
docs/tech-specs/agent-orchestration.md
Normal file
|
|
@ -0,0 +1,939 @@
|
|||
# TrustGraph Agent Orchestration — Technical Specification
|
||||
|
||||
## Overview
|
||||
|
||||
This specification describes the extension of TrustGraph's agent architecture
|
||||
from a single ReACT execution pattern to a multi-pattern orchestration
|
||||
model. The existing Pulsar-based self-queuing loop is pattern-agnostic — the
|
||||
same infrastructure supports ReACT, Plan-then-Execute, Supervisor/Subagent
|
||||
fan-out, and other execution strategies without changes to the message
|
||||
transport. The extension adds a routing layer that selects the appropriate
|
||||
pattern for each task, a set of pattern implementations that share common
|
||||
iteration infrastructure, and a fan-out/fan-in mechanism for multi-agent
|
||||
coordination.
|
||||
|
||||
The central design principle is that
|
||||
**trust and explainability are structural properties of the architecture**,
|
||||
achieved by constraining LLM decisions to
|
||||
graph-defined option sets and recording those constraints in the execution
|
||||
trace.
|
||||
|
||||
---
|
||||
|
||||
## Background
|
||||
|
||||
### Existing Architecture
|
||||
|
||||
The current agent manager is built on the ReACT pattern (Reasoning + Acting)
|
||||
with these properties:
|
||||
|
||||
- **Self-queuing loop**: Each iteration emits a new Pulsar message carrying
|
||||
the accumulated history. The agent manager picks this up and runs the next
|
||||
iteration.
|
||||
- **Stateless agent manager**: No in-process state. All state lives in the
|
||||
message payload.
|
||||
- **Natural parallelism**: Multiple independent agent requests are handled
|
||||
concurrently across Pulsar consumers.
|
||||
- **Durability**: Crash recovery is inherent — the message survives process
|
||||
failure.
|
||||
- **Real-time feedback**: Streaming thought, action, observation and answer
|
||||
chunks are emitted as iterations complete.
|
||||
- **Tool calling and MCP invocation**: Tool calls into knowledge graphs,
|
||||
external services, and MCP-connected systems.
|
||||
- **Decision traces written to the knowledge graph**: Every iteration records
|
||||
PROV-O triples — session, analysis, and conclusion entities — forming the
|
||||
basis of explainability.
|
||||
|
||||
### Current Message Flow
|
||||
|
||||
```
|
||||
AgentRequest arrives (question, history=[], state, group, session_id)
|
||||
│
|
||||
▼
|
||||
Filter tools by group/state
|
||||
│
|
||||
▼
|
||||
AgentManager.react() → LLM call → parse → Action or Final
|
||||
│ │
|
||||
│ [Action] │ [Final]
|
||||
▼ ▼
|
||||
Execute tool, capture observation Emit conclusion triples
|
||||
Emit iteration triples Send AgentResponse
|
||||
Append to history (end_of_dialog=True)
|
||||
Emit new AgentRequest → "next" topic
|
||||
│
|
||||
└── (picked up again by consumer, loop continues)
|
||||
```
|
||||
|
||||
The key insight is that this loop structure is not ReACT-specific. The
|
||||
plumbing — receive message, do work, emit next message — is the same
|
||||
regardless of what the "work" step does. The payload and the pattern logic
|
||||
define the behaviour; the infrastructure remains constant.
|
||||
|
||||
### Current Limitations
|
||||
|
||||
- Only one execution pattern (ReACT) is available regardless of task
|
||||
characteristics.
|
||||
- No mechanism for one agent to spawn and coordinate subagents.
|
||||
- Pattern selection is implicit — every task gets the same treatment.
|
||||
- The provenance model assumes a linear iteration chain (analysis N derives
|
||||
from analysis N-1), with no support for parallel branches.
|
||||
|
||||
---
|
||||
|
||||
## Design Goals
|
||||
|
||||
- **Pattern-agnostic iteration infrastructure**: The self-queuing loop, tool
|
||||
filtering, provenance emission, and streaming feedback should be shared
|
||||
across all patterns.
|
||||
- **Graph-constrained pattern selection**: The LLM selects patterns from a
|
||||
graph-defined set, not from unconstrained reasoning. This makes the
|
||||
selection auditable and explainable.
|
||||
- **Genuinely parallel fan-out**: Subagent tasks execute concurrently on the
|
||||
Pulsar queue, not sequentially in a single process.
|
||||
- **Stateless coordination**: Fan-in uses the knowledge graph as coordination
|
||||
substrate. The agent manager remains stateless.
|
||||
- **Additive change**: The existing ReACT flow continues to work
|
||||
unchanged. New patterns are added alongside it, not in place of it.
|
||||
|
||||
---
|
||||
|
||||
## Patterns
|
||||
|
||||
### ReACT as One Pattern Among Many
|
||||
|
||||
ReACT is one point in a wider space of agent execution strategies:
|
||||
|
||||
| Pattern | Structure | Strengths |
|
||||
|---|---|---|
|
||||
| **ReACT** | Interleaved reasoning and action | Adaptive, good for open-ended tasks |
|
||||
| **Plan-then-Execute** | Decompose into a step DAG, then execute | More predictable, auditable plan |
|
||||
| **Reflexion** | ReACT + self-critique after each action | Agents improve within the episode |
|
||||
| **Supervisor/Subagent** | One agent orchestrates others | Parallel decomposition, synthesis |
|
||||
| **Debate/Ensemble** | Multiple agents reason independently | Diverse perspectives, reconciliation |
|
||||
| **LLM-as-router** | No reasoning loop, pure dispatch | Fast classification and routing |
|
||||
|
||||
Not all of these need to be implemented at once. The architecture should
|
||||
support them; the initial implementation delivers ReACT (already exists),
|
||||
Plan-then-Execute, and Supervisor/Subagent.
|
||||
|
||||
### Pattern Storage
|
||||
|
||||
Patterns are stored as configuration items via the config API. They are
|
||||
finite in number, mechanically well-defined, have enumerable properties,
|
||||
and change slowly. Each pattern is a JSON object stored under the
|
||||
`agent-pattern` config type.
|
||||
|
||||
```json
|
||||
Config type: "agent-pattern"
|
||||
Config key: "react"
|
||||
Value: {
|
||||
"name": "react",
|
||||
"description": "ReACT — Reasoning + Acting",
|
||||
"when_to_use": "Adaptive, good for open-ended tasks"
|
||||
}
|
||||
```
|
||||
|
||||
These are written at deployment time and change rarely. If the architecture
|
||||
later benefits from graph-based pattern storage (e.g. for richer ontological
|
||||
relationships), the config items can be migrated to graph nodes — the
|
||||
meta-router's selection logic is the same regardless of backend.
|
||||
|
||||
---
|
||||
|
||||
## Task Types
|
||||
|
||||
### What a Task Type Represents
|
||||
|
||||
A **task type** characterises the problem domain — what the agent is being
|
||||
asked to accomplish, and how a domain expert would frame it analytically.
|
||||
|
||||
- Carries domain-specific methodology (e.g. "intelligence analysis always
|
||||
applies structured analytic techniques")
|
||||
- Pre-populates initial reasoning context via a framing prompt
|
||||
- Constrains which patterns are valid for this class of problem
|
||||
- Can define domain-specific termination criteria
|
||||
|
||||
### Identification
|
||||
|
||||
Task types are identified from plain-text task descriptions by the
|
||||
LLM. Building a formal ontology over task descriptions is premature — natural
|
||||
language is too varied and context-dependent. The LLM reads the description;
|
||||
the graph provides the structure downstream.
|
||||
|
||||
### Task Type Storage
|
||||
|
||||
Task types are stored as configuration items via the config API under the
|
||||
`agent-task-type` config type. Each task type is a JSON object that
|
||||
references valid patterns by name.
|
||||
|
||||
```json
|
||||
Config type: "agent-task-type"
|
||||
Config key: "risk-assessment"
|
||||
Value: {
|
||||
"name": "risk-assessment",
|
||||
"description": "Due Diligence / Risk Assessment",
|
||||
"framing_prompt": "Analyse across financial, reputational, legal and operational dimensions using structured analytic techniques.",
|
||||
"valid_patterns": ["supervisor", "plan-then-execute", "react"],
|
||||
"when_to_use": "Multi-dimensional analysis requiring structured assessment"
|
||||
}
|
||||
```
|
||||
|
||||
The `valid_patterns` list defines the constrained decision space — the LLM
|
||||
can only select patterns that the task type's configuration says are valid.
|
||||
This is the many-to-many relationship between task types and patterns,
|
||||
expressed as configuration rather than graph edges.
|
||||
|
||||
### Selection Flow
|
||||
|
||||
```
|
||||
Task Description (plain text, from AgentRequest.question)
|
||||
│
|
||||
│ [LLM interprets, constrained by available task types from config]
|
||||
▼
|
||||
Task Type (config item — domain framing and methodology)
|
||||
│
|
||||
│ [config lookup — valid_patterns list]
|
||||
▼
|
||||
Pattern Candidates (config items)
|
||||
│
|
||||
│ [LLM selects within constrained set,
|
||||
│ informed by task description signals:
|
||||
│ complexity, urgency, scope]
|
||||
▼
|
||||
Selected Pattern
|
||||
```
|
||||
|
||||
The task description may carry modulating signals (complexity, urgency, scope)
|
||||
that influence which pattern is selected within the constrained set. But the
|
||||
raw description never directly selects a pattern — it always passes through
|
||||
the task type layer first.
|
||||
|
||||
---
|
||||
|
||||
## Explainability Through Constrained Decision Spaces
|
||||
|
||||
A central principle of TrustGraph's explainability architecture is that
|
||||
**explainability comes from constrained decision spaces**.
|
||||
|
||||
When a decision is made from an unconstrained space — a raw LLM call with no
|
||||
guardrails — the reasoning is opaque even if the LLM produces a rationale,
|
||||
because that rationale is post-hoc and unverifiable.
|
||||
|
||||
When a decision is made from a **constrained set defined in configuration**,
|
||||
you can always answer:
|
||||
- What valid options were available
|
||||
- What criteria narrowed the set
|
||||
- What signal made the final selection within that set
|
||||
|
||||
This principle already governs the existing decision trace architecture and
|
||||
extends naturally to pattern selection. The routing decision — which task type
|
||||
and which pattern — is itself recorded as a provenance node, making the first
|
||||
decision in the execution trace auditable.
|
||||
|
||||
**Trust becomes a structural property of the architecture, not a claimed
|
||||
property of the model.**
|
||||
|
||||
---
|
||||
|
||||
## Orchestration Architecture
|
||||
|
||||
### The Meta-Router
|
||||
|
||||
The meta-router is the entry point for all agent requests. It runs as a
|
||||
pre-processing step before the pattern-specific iteration loop begins. Its
|
||||
job is to determine the task type and select the execution pattern.
|
||||
|
||||
**When it runs**: On receipt of an `AgentRequest` with empty history (i.e. a
|
||||
new task, not a continuation). Requests with non-empty history are already
|
||||
mid-iteration and bypass the meta-router.
|
||||
|
||||
**What it does**:
|
||||
|
||||
1. Lists all available task types from the config API
|
||||
(`config.list("agent-task-type")`).
|
||||
2. Presents these to the LLM alongside the task description. The LLM
|
||||
identifies which task type applies (or "general" as a fallback).
|
||||
3. Reads the selected task type's configuration to get the `valid_patterns`
|
||||
list.
|
||||
4. Loads the candidate pattern definitions from config and presents them to
|
||||
the LLM. The LLM selects one, influenced by signals in the task
|
||||
description (complexity, number of independent dimensions, urgency).
|
||||
5. Records the routing decision as a provenance node (see Provenance Model
|
||||
below).
|
||||
6. Populates the `AgentRequest` with the selected pattern, task type framing
|
||||
prompt, and any pattern-specific configuration, then emits it onto the
|
||||
queue.
|
||||
|
||||
**Where it lives**: The meta-router is a phase within the agent-orchestrator,
|
||||
not a separate service. The agent-orchestrator is a new executable that
|
||||
uses the same service identity as the existing agent-manager-react, making
|
||||
it a drop-in replacement on the same Pulsar queues. It includes the full
|
||||
ReACT implementation alongside the new orchestration patterns. The
|
||||
distinction between "route" and "iterate" is determined by whether the
|
||||
request already has a pattern set.
|
||||
|
||||
### Pattern Dispatch
|
||||
|
||||
Once the meta-router has annotated the request with a pattern, the agent
|
||||
manager dispatches to the appropriate pattern implementation. This is a
|
||||
straightforward branch on the pattern field:
|
||||
|
||||
```
|
||||
request arrives
|
||||
│
|
||||
├── history is empty → meta-router → annotate with pattern → re-emit
|
||||
│
|
||||
└── history is non-empty (or pattern is set)
|
||||
│
|
||||
├── pattern = "react" → ReACT iteration
|
||||
├── pattern = "plan-then-execute" → PtE iteration
|
||||
├── pattern = "supervisor" → Supervisor iteration
|
||||
└── (no pattern) → ReACT iteration (default)
|
||||
```
|
||||
|
||||
Each pattern implementation follows the same contract: receive a request, do
|
||||
one iteration of work, then either emit a "next" message (continue) or emit a
|
||||
response (done). The self-queuing loop doesn't change.
|
||||
|
||||
### Pattern Implementations
|
||||
|
||||
#### ReACT (Existing)
|
||||
|
||||
No changes. The existing `AgentManager.react()` path continues to work
|
||||
as-is.
|
||||
|
||||
#### Plan-then-Execute
|
||||
|
||||
Two-phase pattern:
|
||||
|
||||
**Planning phase** (first iteration):
|
||||
- LLM receives the question plus task type framing.
|
||||
- Produces a structured plan: an ordered list of steps, each with a goal,
|
||||
expected tool, and dependencies on prior steps.
|
||||
- The plan is recorded in the history as a special "plan" step.
|
||||
- Emits a "next" message to begin execution.
|
||||
|
||||
**Execution phase** (subsequent iterations):
|
||||
- Reads the plan from history.
|
||||
- Identifies the next unexecuted step.
|
||||
- Executes that step (tool call + observation), similar to a single ReACT
|
||||
action.
|
||||
- Records the result against the plan step.
|
||||
- If all steps complete, synthesises a final answer.
|
||||
- If a step fails or produces unexpected results, the LLM can revise the
|
||||
remaining plan (bounded re-planning, not a full restart).
|
||||
|
||||
The plan lives in the history, so it travels with the message. No external
|
||||
state is needed.
|
||||
|
||||
#### Supervisor/Subagent
|
||||
|
||||
The supervisor pattern introduces fan-out and fan-in. This is the most
|
||||
architecturally significant addition.
|
||||
|
||||
**Supervisor planning iteration**:
|
||||
- LLM receives the question plus task type framing.
|
||||
- Decomposes the task into independent subagent goals.
|
||||
- For each subagent, emits a new `AgentRequest` with:
|
||||
- A focused question (the subagent's specific goal)
|
||||
- A shared correlation ID tying it to the parent task
|
||||
- The subagent's own pattern (typically ReACT, but could be anything)
|
||||
- Relevant context sliced from the parent request
|
||||
|
||||
**Subagent execution**:
|
||||
- Each subagent request is picked up by an agent manager instance and runs its
|
||||
own independent iteration loop.
|
||||
- Subagents are ordinary agent executions — they self-queue, use tools, emit
|
||||
provenance, stream feedback.
|
||||
- When a subagent reaches a Final answer, it writes a completion record to the
|
||||
knowledge graph under the shared correlation ID.
|
||||
|
||||
**Fan-in and synthesis**:
|
||||
- An aggregator detects when all sibling subagents for a correlation ID have
|
||||
completed.
|
||||
- It emits a synthesis request to the supervisor carrying the correlation ID.
|
||||
- The supervisor queries the graph for subagent results, reasons across them,
|
||||
and decides whether to emit a final answer or iterate again.
|
||||
|
||||
**Supervisor re-iteration**:
|
||||
- After synthesis, the supervisor may determine that the results are
|
||||
incomplete, contradictory, or reveal gaps requiring further investigation.
|
||||
- Rather than emitting a final answer, it can fan out again with new or
|
||||
refined subagent goals under a new correlation ID. This is the same
|
||||
self-queuing loop — the supervisor emits new subagent requests and stops,
|
||||
the aggregator detects completion, and synthesis runs again.
|
||||
- The supervisor's iteration count (planning + synthesis rounds) is bounded
|
||||
to prevent unbounded looping.
|
||||
|
||||
This is detailed further in the Fan-Out / Fan-In section below.
|
||||
|
||||
---
|
||||
|
||||
## Message Schema Evolution
|
||||
|
||||
### Shared Schema Principle
|
||||
|
||||
The `AgentRequest` and `AgentResponse` schemas are the shared contract
|
||||
between the agent-manager (existing ReACT execution) and the
|
||||
agent-orchestrator (meta-routing, supervisor, plan-then-execute). Both
|
||||
services consume from the same *agent request* topic using the same
|
||||
schema. Any schema changes must be reflected in both — the schema is
|
||||
the integration point, not the service implementation.
|
||||
|
||||
This means the orchestrator does not introduce separate message types for
|
||||
its own use. Subagent requests, synthesis triggers, and meta-router
|
||||
outputs are all `AgentRequest` messages with different field values. The
|
||||
agent-manager ignores orchestration fields it doesn't use.
|
||||
|
||||
### New Fields
|
||||
|
||||
The `AgentRequest` schema needs new fields to carry orchestration
|
||||
metadata.
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class AgentRequest:
|
||||
# Existing fields (unchanged)
|
||||
question: str = ""
|
||||
state: str = ""
|
||||
group: list[str] | None = None
|
||||
history: list[AgentStep] = field(default_factory=list)
|
||||
user: str = ""
|
||||
collection: str = "default"
|
||||
streaming: bool = False
|
||||
session_id: str = ""
|
||||
|
||||
# New orchestration fields
|
||||
conversation_id: str = "" # Optional caller-generated ID grouping related requests
|
||||
pattern: str = "" # "react", "plan-then-execute", "supervisor", ""
|
||||
task_type: str = "" # Identified task type name
|
||||
framing: str = "" # Task type framing prompt injected into LLM context
|
||||
correlation_id: str = "" # Shared ID linking subagents to parent
|
||||
parent_session_id: str = "" # Parent's session_id (for subagents)
|
||||
subagent_goal: str = "" # Focused goal for this subagent
|
||||
expected_siblings: int = 0 # How many sibling subagents exist
|
||||
```
|
||||
|
||||
The `AgentStep` schema also extends to accommodate non-ReACT iteration types:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class AgentStep:
|
||||
# Existing fields (unchanged)
|
||||
thought: str = ""
|
||||
action: str = ""
|
||||
arguments: dict[str, str] = field(default_factory=dict)
|
||||
observation: str = ""
|
||||
user: str = ""
|
||||
|
||||
# New fields
|
||||
step_type: str = "" # "react", "plan", "execute", "supervise", "synthesise"
|
||||
plan: list[PlanStep] | None = None # For plan-then-execute: the structured plan
|
||||
subagent_results: dict | None = None # For supervisor: collected subagent outputs
|
||||
```
|
||||
|
||||
The `PlanStep` structure for Plan-then-Execute:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class PlanStep:
|
||||
goal: str = "" # What this step should accomplish
|
||||
tool_hint: str = "" # Suggested tool (advisory, not binding)
|
||||
depends_on: list[int] = field(default_factory=list) # Indices of prerequisite steps
|
||||
status: str = "pending" # "pending", "complete", "failed", "revised"
|
||||
result: str = "" # Observation from execution
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Fan-Out and Fan-In
|
||||
|
||||
### Why This Matters
|
||||
|
||||
Fan-out is the mechanism that makes multi-agent coordination genuinely
|
||||
parallel rather than simulated. With Pulsar, emitting multiple messages means
|
||||
multiple consumers can pick them up concurrently. This is not threading or
|
||||
async simulation — it is real distributed parallelism across agent manager
|
||||
instances.
|
||||
|
||||
### Fan-Out: Supervisor Emits Subagent Requests
|
||||
|
||||
When a supervisor iteration decides to decompose a task, it:
|
||||
|
||||
1. Generates a **correlation ID** — a UUID that groups the sibling subagents.
|
||||
2. For each subagent, constructs a new `AgentRequest`:
|
||||
- `question` = the subagent's focused goal (from `subagent_goal`)
|
||||
- `correlation_id` = the shared correlation ID
|
||||
- `parent_session_id` = the supervisor's session_id
|
||||
- `pattern` = typically "react", but the supervisor can specify any pattern
|
||||
- `session_id` = a new unique ID for this subagent's own provenance chain
|
||||
- `expected_siblings` = total number of sibling subagents
|
||||
- `history` = empty (fresh start, but framing context inherited)
|
||||
- `group`, `user`, `collection` = inherited from parent
|
||||
3. Emits each subagent request onto the agent request topic.
|
||||
4. Records the fan-out decision in the provenance graph (see below).
|
||||
|
||||
The supervisor then **stops**. It does not wait. It does not poll. It has
|
||||
emitted its messages and its iteration is complete. The graph and the
|
||||
aggregator handle the rest.
|
||||
|
||||
### Fan-In: Graph-Based Completion Detection
|
||||
|
||||
When a subagent reaches its Final answer, it writes a **completion node** to
|
||||
the knowledge graph:
|
||||
|
||||
```
|
||||
Completion node:
|
||||
rdf:type tg:SubagentCompletion
|
||||
tg:correlationId <shared correlation ID>
|
||||
tg:subagentSessionId <this subagent's session_id>
|
||||
tg:parentSessionId <supervisor's session_id>
|
||||
tg:subagentGoal <what this subagent was asked to do>
|
||||
tg:result → <document URI in librarian>
|
||||
prov:wasGeneratedBy → <this subagent's conclusion entity>
|
||||
```
|
||||
|
||||
The **aggregator** is a component that watches for completion nodes. When it
|
||||
detects that all expected siblings for a correlation ID have written
|
||||
completion nodes, it:
|
||||
|
||||
1. Collects all sibling results from the graph and librarian.
|
||||
2. Constructs a **synthesis request** — a new `AgentRequest` addressed to the supervisor flow:
|
||||
- `session_id` = the original supervisor's session_id
|
||||
- `pattern` = "supervisor"
|
||||
- `step_type` = "synthesise" (carried in history)
|
||||
- `subagent_results` = the collected findings
|
||||
- `history` = the supervisor's history up to the fan-out point, plus the synthesis step
|
||||
3. Emits this onto the agent request topic.
|
||||
|
||||
The supervisor picks this up, reasons across the aggregated findings, and
|
||||
produces its final answer.
|
||||
|
||||
### Aggregator Design
|
||||
|
||||
The aggregator is event-driven, consistent with TrustGraph's Pulsar-based
|
||||
architecture. Polling would be an anti-pattern in a system where all
|
||||
coordination is message-driven.
|
||||
|
||||
**Mechanism**: The aggregator is a Pulsar consumer on the explainability
|
||||
topic. Subagent completion nodes are emitted as triples on this topic as
|
||||
part of the existing provenance flow. When the aggregator receives a
|
||||
`tg:SubagentCompletion` triple, it:
|
||||
|
||||
1. Extracts the `tg:correlationId` from the completion node.
|
||||
2. Queries the graph to count how many siblings for that correlation ID
|
||||
have completed.
|
||||
3. If all `expected_siblings` are present, triggers fan-in immediately —
|
||||
collects results and emits the synthesis request.
|
||||
|
||||
**State**: The aggregator is stateless in the same sense as the agent
|
||||
manager — it holds no essential in-memory state. The graph is the source
|
||||
of truth for completion counts. If the aggregator restarts, it can
|
||||
re-process unacknowledged completion messages from Pulsar and re-check the
|
||||
graph. No coordination state is lost.
|
||||
|
||||
**Consistency**: Because the completion check queries the graph rather than
|
||||
relying on an in-memory counter, the aggregator is tolerant of duplicate
|
||||
messages, out-of-order delivery, and restarts. The graph query is
|
||||
idempotent — asking "are all siblings complete?" gives the same answer
|
||||
regardless of how many times or in what order the events arrive.
|
||||
|
||||
### Timeout and Failure
|
||||
|
||||
- **Subagent timeout**: The aggregator records the timestamp of the first
|
||||
sibling completion (from the graph). A periodic timeout check (the one
|
||||
concession to polling — but over local state, not the graph) detects
|
||||
stalled correlation IDs. If `expected_siblings` completions are not
|
||||
reached within a configurable timeout, the aggregator emits a partial
|
||||
synthesis request with whatever results are available, flagging the
|
||||
incomplete subagents.
|
||||
- **Subagent failure**: If a subagent errors out, it writes an error
|
||||
completion node (with `tg:status = "error"` and an error message). The
|
||||
aggregator treats this as a completion — the supervisor receives the error
|
||||
in its synthesis input and can reason about partial results.
|
||||
- **Supervisor iteration limit**: The supervisor's own iteration count
|
||||
(planning + synthesis) is bounded by `max_iterations` just like any other
|
||||
pattern.
|
||||
|
||||
---
|
||||
|
||||
## Provenance Model Extensions
|
||||
|
||||
### Routing Decision
|
||||
|
||||
The meta-router's task type and pattern selection is recorded as the first
|
||||
provenance node in the session:
|
||||
|
||||
```
|
||||
Routing node:
|
||||
rdf:type prov:Entity, tg:RoutingDecision
|
||||
prov:wasGeneratedBy → session (Question) activity
|
||||
tg:taskType → TaskType node URI
|
||||
tg:selectedPattern → Pattern node URI
|
||||
tg:candidatePatterns → [Pattern node URIs] (what was available)
|
||||
tg:routingRationale → document URI in librarian (LLM's reasoning)
|
||||
```
|
||||
|
||||
This captures the constrained decision space: what candidates existed, which
|
||||
was selected, and why. The candidates are graph-derived; the rationale is
|
||||
LLM-generated but verifiable against the candidates.
|
||||
|
||||
### Fan-Out Provenance
|
||||
|
||||
When a supervisor fans out, the provenance records the decomposition:
|
||||
|
||||
```
|
||||
FanOut node:
|
||||
rdf:type prov:Entity, tg:FanOut
|
||||
prov:wasDerivedFrom → supervisor's routing or planning iteration
|
||||
tg:correlationId <correlation ID>
|
||||
tg:subagentGoals → [document URIs for each subagent goal]
|
||||
tg:expectedSiblings <count>
|
||||
```
|
||||
|
||||
Each subagent's provenance chain is independent (its own session, iterations,
|
||||
conclusion) but linked back to the parent via:
|
||||
|
||||
```
|
||||
Subagent session:
|
||||
rdf:type prov:Activity, tg:Question, tg:AgentQuestion
|
||||
tg:parentCorrelationId <correlation ID>
|
||||
tg:parentSessionId <supervisor session URI>
|
||||
```
|
||||
|
||||
### Fan-In Provenance
|
||||
|
||||
The synthesis step links back to all subagent conclusions:
|
||||
|
||||
```
|
||||
Synthesis node:
|
||||
rdf:type prov:Entity, tg:Synthesis
|
||||
prov:wasDerivedFrom → [all subagent Conclusion entities]
|
||||
tg:correlationId <correlation ID>
|
||||
```
|
||||
|
||||
This creates a DAG in the provenance graph: the supervisor's routing fans out
|
||||
to N parallel subagent chains, which fan back in to a synthesis node. The
|
||||
entire multi-agent execution is traceable from a single correlation ID.
|
||||
|
||||
### URI Scheme
|
||||
|
||||
Extending the existing `urn:trustgraph:agent:{session_id}` pattern:
|
||||
|
||||
| Entity | URI Pattern |
|
||||
|---|---|
|
||||
| Session (existing) | `urn:trustgraph:agent:{session_id}` |
|
||||
| Iteration (existing) | `urn:trustgraph:agent:{session_id}/i{n}` |
|
||||
| Conclusion (existing) | `urn:trustgraph:agent:{session_id}/answer` |
|
||||
| Routing decision | `urn:trustgraph:agent:{session_id}/routing` |
|
||||
| Fan-out record | `urn:trustgraph:agent:{session_id}/fanout/{correlation_id}` |
|
||||
| Subagent completion | `urn:trustgraph:agent:{session_id}/completion` |
|
||||
|
||||
---
|
||||
|
||||
## Storage Responsibilities
|
||||
|
||||
Pattern and task type definitions live in the config API. Runtime state and
|
||||
provenance live in the knowledge graph. The division is:
|
||||
|
||||
| Role | Storage | When Written | Content |
|
||||
|---|---|---|---|
|
||||
| Pattern definitions | Config API | At design time | Pattern properties, descriptions |
|
||||
| Task type definitions | Config API | At design time | Domain framing, valid pattern lists |
|
||||
| Routing decision trace | Knowledge graph | At request arrival | Why this task type and pattern were selected |
|
||||
| Iteration decision trace | Knowledge graph | During execution | Each think/act/observe cycle, per existing model |
|
||||
| Fan-out coordination | Knowledge graph | During fan-out | Subagent goals, correlation ID, expected count |
|
||||
| Subagent completion | Knowledge graph | During fan-in | Per-subagent results under shared correlation ID |
|
||||
| Execution audit trail | Knowledge graph | Post-execution | Full multi-agent reasoning trace as a DAG |
|
||||
|
||||
The config API holds the definitions that constrain decisions. The knowledge
|
||||
graph holds the runtime decisions and their provenance. The fan-in
|
||||
coordination state is part of the provenance automatically — subagent
|
||||
completion nodes are both coordination signals and audit trail entries.
|
||||
|
||||
---
|
||||
|
||||
## Worked Example: Partner Risk Assessment
|
||||
|
||||
**Request**: "Assess the risk profile of Company X as a potential partner"
|
||||
|
||||
**1. Request arrives** on the *agent request* topic with empty history.
|
||||
The agent manager picks it up.
|
||||
|
||||
**2. Meta-router**:
|
||||
- Queries config API, finds task types: *Risk Assessment*, *Research*,
|
||||
*Summarisation*, *General*.
|
||||
- LLM identifies *Risk Assessment*. Framing prompt loaded: "analyse across
|
||||
financial, reputational, legal and operational dimensions using structured
|
||||
analytic techniques."
|
||||
- Valid patterns for *Risk Assessment*: [*Supervisor/Subagent*,
|
||||
*Plan-then-Execute*, *ReACT*].
|
||||
- LLM selects *Supervisor/Subagent* — task has four independent investigative
|
||||
dimensions, well-suited to parallel decomposition.
|
||||
- Routing decision written to graph. Request re-emitted on the
|
||||
*agent request* topic with `pattern="supervisor"`, framing populated.
|
||||
|
||||
**3. Supervisor iteration** (picked up from *agent request* topic):
|
||||
- LLM receives question + framing. Reasons that four independent investigative
|
||||
threads are required.
|
||||
- Generates correlation ID `corr-abc123`.
|
||||
- Emits four subagent requests on the *agent request* topic:
|
||||
- Financial analysis (pattern="react", subagent_goal="Analyse financial
|
||||
health and stability of Company X")
|
||||
- Legal analysis (pattern="react", subagent_goal="Review regulatory filings,
|
||||
sanctions, and legal exposure for Company X")
|
||||
- Reputational analysis (pattern="react", subagent_goal="Analyse news
|
||||
sentiment and public reputation of Company X")
|
||||
- Operational analysis (pattern="react", subagent_goal="Assess supply chain
|
||||
dependencies and operational risks for Company X")
|
||||
- Fan-out node written to graph.
|
||||
|
||||
**4. Four subagents run in parallel** (each picked up from the *agent
|
||||
request* topic by agent manager instances), each as an independent ReACT
|
||||
loop:
|
||||
- Financial — queries financial data services and knowledge graph
|
||||
relationships
|
||||
- Legal — searches regulatory filings and sanctions lists
|
||||
- Reputational — searches news, analyses sentiment
|
||||
- Operational — queries supply chain databases
|
||||
|
||||
Each self-queues its iterations on the *agent request* topic. Each writes
|
||||
its own decision trace to the graph as it progresses. Each completes
|
||||
independently.
|
||||
|
||||
**5. Fan-in**:
|
||||
- Each subagent writes a `tg:SubagentCompletion` node to the graph on
|
||||
completion, emitted on the *explainability* topic. The completion node
|
||||
references the subagent's result document in the librarian.
|
||||
- Aggregator (consuming the *explainability* topic) sees each completion
|
||||
event. It queries the graph for the fan-out node to get the expected
|
||||
sibling count, then checks how many completions exist for
|
||||
`corr-abc123`.
|
||||
- When all four siblings are complete, the aggregator emits a synthesis
|
||||
request on the *agent request* topic with the correlation ID. It does
|
||||
not fetch or bundle subagent results — the supervisor will query the
|
||||
graph for those.
|
||||
|
||||
**6. Supervisor synthesis** (picked up from *agent request* topic):
|
||||
- Receives the synthesis trigger carrying the correlation ID.
|
||||
- Queries the graph for `tg:SubagentCompletion` nodes under
|
||||
`corr-abc123`, retrieving each subagent's goal and result document
|
||||
reference.
|
||||
- Fetches the result documents from the librarian.
|
||||
- Reasons across all four dimensions, produces a structured risk
|
||||
assessment with confidence scores.
|
||||
- Emits final answer on the *agent response* topic and writes conclusion
|
||||
provenance to the graph.
|
||||
|
||||
**7. Response delivered** — the supervisor's synthesis streams on the
|
||||
*agent response* topic as the LLM generates it, with `end_of_dialog`
|
||||
on the final chunk. The collated answer is saved to the librarian and
|
||||
referenced from conclusion provenance in the graph. The graph now holds
|
||||
a complete, human-readable trace of the entire multi-agent execution —
|
||||
from pattern selection through four parallel investigations to final
|
||||
synthesis.
|
||||
|
||||
---
|
||||
|
||||
## Class Hierarchy
|
||||
|
||||
The agent-orchestrator executable (`agent-orchestrator`) uses the same
|
||||
service identity as agent-manager-react, making it a drop-in replacement.
|
||||
The pattern dispatch model suggests a class hierarchy where shared iteration
|
||||
infrastructure lives in a base class and pattern-specific logic is in
|
||||
subclasses:
|
||||
|
||||
```
|
||||
AgentService (base — Pulsar consumer/producer specs, request handling)
|
||||
│
|
||||
└── Processor (agent-orchestrator service)
|
||||
│
|
||||
├── MetaRouter — task type identification, pattern selection
|
||||
│
|
||||
├── PatternBase — shared: tool filtering, provenance, streaming, history
|
||||
│ ├── ReactPattern — existing ReACT logic (extract from current AgentManager)
|
||||
│ ├── PlanThenExecutePattern — plan phase + execute phase
|
||||
│ └── SupervisorPattern — fan-out, synthesis
|
||||
│
|
||||
└── Aggregator — fan-in completion detection
|
||||
```
|
||||
|
||||
`PatternBase` captures what is currently spread across `Processor` and
|
||||
`AgentManager`: tool filtering, LLM invocation, provenance triple emission,
|
||||
streaming callbacks, history management. The pattern subclasses implement only
|
||||
the decision logic specific to their execution strategy — what to do with the
|
||||
LLM output, when to terminate, whether to fan out.
|
||||
|
||||
This refactoring is not strictly necessary for the first iteration — the
|
||||
meta-router and pattern dispatch could be added as branches within the
|
||||
existing `Processor.agent_request()` method. But the class hierarchy clarifies
|
||||
where shared vs. pattern-specific logic lives and will prevent duplication as
|
||||
more patterns are added.
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
### Config API Seeding
|
||||
|
||||
Pattern and task type definitions are stored via the config API and need to
|
||||
be seeded at deployment time. This is analogous to how flow blueprints and
|
||||
parameter types are loaded — a bootstrap step that writes the initial
|
||||
configuration.
|
||||
|
||||
The initial seed includes:
|
||||
|
||||
**Patterns** (config type `agent-pattern`):
|
||||
- `react` — interleaved reasoning and action
|
||||
- `plan-then-execute` — structured plan followed by step execution
|
||||
- `supervisor` — decomposition, fan-out to subagents, synthesis
|
||||
|
||||
**Task types** (config type `agent-task-type`, initial set, expected to grow):
|
||||
- `general` — no specific domain framing, all patterns valid
|
||||
- `research` — open-ended investigation, valid patterns: react, plan-then-execute
|
||||
- `risk-assessment` — multi-dimensional analysis, valid patterns: supervisor,
|
||||
plan-then-execute, react
|
||||
- `summarisation` — condense information, valid patterns: react
|
||||
|
||||
The seed data is configuration, not code. It can be extended via the config
|
||||
API (or the configuration UI) without redeploying the agent manager.
|
||||
|
||||
### Migration Path
|
||||
|
||||
The config API provides a practical starting point. If richer ontological
|
||||
relationships between patterns, task types, and domain knowledge become
|
||||
valuable, the definitions can be migrated to graph storage. The meta-router's
|
||||
selection logic queries an abstract set of task types and patterns — the
|
||||
storage backend is an implementation detail.
|
||||
|
||||
### Fallback Behaviour
|
||||
|
||||
If the config contains no patterns or task types:
|
||||
- Task type defaults to `general`.
|
||||
- Pattern defaults to `react`.
|
||||
- The system degrades gracefully to existing behaviour.
|
||||
|
||||
---
|
||||
|
||||
## Design Decisions
|
||||
|
||||
| Decision | Resolution | Rationale |
|
||||
|---|---|---|
|
||||
| Task type identification | LLM interprets from plain text | Natural language too varied to formalise prematurely |
|
||||
| Pattern/task type storage | Config API initially, graph later if needed | Avoids graph model complexity upfront; config API already has UI support; migration path is straightforward |
|
||||
| Meta-router location | Phase within agent manager, not separate service | Avoids an extra network hop; routing is fast |
|
||||
| Fan-in mechanism | Event-driven via explainability topic | Consistent with Pulsar-based architecture; graph query for completion count is idempotent and restart-safe |
|
||||
| Aggregator deployment | Separate lightweight process | Decoupled from agent manager lifecycle |
|
||||
| Subagent pattern selection | Supervisor specifies per-subagent | Supervisor has task context to make this choice |
|
||||
| Plan storage | In message history | No external state needed; plan travels with message |
|
||||
| Default pattern | Empty pattern field → ReACT | Sensible default when meta-router is not configured |
|
||||
|
||||
---
|
||||
|
||||
## Streaming Protocol
|
||||
|
||||
### Current Model
|
||||
|
||||
The existing agent response schema has two levels:
|
||||
|
||||
- **`end_of_message`** — marks the end of a complete thought, observation,
|
||||
or answer. Chunks belonging to the same message arrive sequentially.
|
||||
- **`end_of_dialog`** — marks the end of the entire agent execution. No
|
||||
more messages will follow.
|
||||
|
||||
This works because the current system produces messages serially — one
|
||||
thought at a time, one agent at a time.
|
||||
|
||||
### Fan-Out Breaks Serial Assumptions
|
||||
|
||||
With supervisor/subagent fan-out, multiple subagents stream chunks
|
||||
concurrently on the same *agent response* topic. The caller receives
|
||||
interleaved chunks from different sources and needs to demultiplex them.
|
||||
|
||||
### Resolution: Message ID
|
||||
|
||||
Each chunk carries a `message_id` — a per-message UUID generated when
|
||||
the agent begins streaming a new thought, observation, or answer. The
|
||||
caller groups chunks by `message_id` and assembles each message
|
||||
independently.
|
||||
|
||||
```
|
||||
Response chunk fields:
|
||||
message_id UUID for this message (groups chunks)
|
||||
session_id Which agent session produced this chunk
|
||||
chunk_type "thought" | "observation" | "answer" | ...
|
||||
content The chunk text
|
||||
end_of_message True on the final chunk of this message
|
||||
end_of_dialog True on the final message of the entire execution
|
||||
```
|
||||
|
||||
A single subagent emits multiple messages (thought, observation, thought,
|
||||
answer), each with a distinct `message_id`. The `session_id` identifies
|
||||
which subagent the message belongs to. The caller can display, group, or
|
||||
filter by either.
|
||||
|
||||
### Provenance Trigger
|
||||
|
||||
`end_of_message` is the trigger for provenance storage. When a complete
|
||||
message has been assembled from its chunks:
|
||||
|
||||
1. The collated text is saved to the librarian as a single document.
|
||||
2. A provenance node is written to the graph referencing the document URI.
|
||||
|
||||
This follows the pattern established by GraphRAG, where streaming synthesis
|
||||
chunks are delivered live but the stored provenance references the collated
|
||||
answer text. Streaming is for the caller; provenance needs complete messages.
|
||||
|
||||
---
|
||||
|
||||
## Open Questions
|
||||
|
||||
- **Re-planning depth** (resolved): Runtime parameter on the
|
||||
agent-orchestrator executable, default 2. Bounds how many times
|
||||
Plan-then-Execute can revise its plan before forcing termination.
|
||||
- **Nested fan-out** (phase B): A subagent can itself be a supervisor
|
||||
that fans out further. The architecture supports this — correlation IDs
|
||||
are independent and the aggregator is stateless. The protocols and
|
||||
message schema should not preclude nested fan-out, but implementation
|
||||
is deferred. Depth limits will need to be enforced to prevent runaway
|
||||
decomposition.
|
||||
- **Task type evolution** (resolved): Manually curated for now. See
|
||||
Future Directions below for automated discovery.
|
||||
- **Cost attribution** (deferred): Costs are measured at the
|
||||
text-completion queue level as they are today. Per-request attribution
|
||||
across subagents is not yet implemented and is not a blocker for
|
||||
orchestration.
|
||||
- **Conversation ID** (resolved): An optional `conversation_id` field on
|
||||
`AgentRequest`, generated by the caller. When present, all objects
|
||||
created during the execution (provenance nodes, librarian documents,
|
||||
subagent completion records) are tagged with the conversation ID. This
|
||||
enables querying all interactions in a conversation with a single
|
||||
lookup, and provides the foundation for conversation-scoped memory.
|
||||
No explicit open/close — the first request with a new conversation ID
|
||||
implicitly starts the conversation. Omit for one-shot queries.
|
||||
- **Tool scoping per subagent** (resolved): Subagents inherit the
|
||||
parent's tool group by default. The supervisor can optionally override
|
||||
the group per subagent to constrain capabilities (e.g. financial
|
||||
subagent gets only financial tools). The `group` field on
|
||||
`AgentRequest` already supports this — the supervisor just sets it
|
||||
when constructing subagent requests.
|
||||
|
||||
---
|
||||
|
||||
## Future Directions
|
||||
|
||||
### Automated Task Type Discovery
|
||||
|
||||
Task types are manually curated in the initial implementation. However,
|
||||
the architecture is well-suited to automated discovery because all agent
|
||||
requests and their execution traces flow through Pulsar topics. A
|
||||
learning service could consume these messages and analyse patterns in
|
||||
how tasks are framed, which patterns are selected, and how successfully
|
||||
they execute. Over time, it could propose new task types based on
|
||||
clusters of similar requests that don't map well to existing types, or
|
||||
suggest refinements to framing prompts based on which framings lead to
|
||||
better outcomes. This service would write proposed task types to the
|
||||
config API for human review — automated discovery, manual approval. The
|
||||
agent-orchestrator does not need to change; it always reads task types
|
||||
from config regardless of how they got there.
|
||||
282
docs/tech-specs/config-push-poke.md
Normal file
282
docs/tech-specs/config-push-poke.md
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
# Config Push "Notify" Pattern Technical Specification
|
||||
|
||||
## Overview
|
||||
|
||||
Replace the current config push mechanism — which broadcasts the full config
|
||||
blob on a `state` class queue — with a lightweight "notify" notification
|
||||
containing only the version number and affected types. Processors that care
|
||||
about those types fetch the full config via the existing request/response
|
||||
interface.
|
||||
|
||||
This solves the RabbitMQ late-subscriber problem: when a process restarts,
|
||||
its fresh queue has no historical messages, so it never receives the current
|
||||
config state. With the notify pattern, the push queue is only a signal — the
|
||||
source of truth is the config service's request/response API, which is
|
||||
always available.
|
||||
|
||||
## Problem
|
||||
|
||||
On Pulsar, `state` class queues are persistent topics. A new subscriber
|
||||
with `InitialPosition.Earliest` reads from message 0 and receives the
|
||||
last config push. On RabbitMQ, each subscriber gets a fresh per-subscriber
|
||||
queue (named with a new UUID). Messages published before the queue existed
|
||||
are gone. A restarting processor never gets the current config.
|
||||
|
||||
## Design
|
||||
|
||||
### The Notify Message
|
||||
|
||||
The `ConfigPush` schema changes from carrying the full config to carrying
|
||||
just a version number and the list of affected config types:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class ConfigPush:
|
||||
version: int = 0
|
||||
types: list[str] = field(default_factory=list)
|
||||
```
|
||||
|
||||
When the config service handles a `put` or `delete`, it knows which types
|
||||
were affected (from the request's `values[].type` or `keys[].type`). It
|
||||
includes those in the notify. On startup, the config service sends a notify
|
||||
with an empty types list (meaning "everything").
|
||||
|
||||
### Subscribe-then-Fetch Startup (No Race Condition)
|
||||
|
||||
The critical ordering to avoid missing an update:
|
||||
|
||||
1. **Subscribe** to the config push queue. Buffer incoming notify messages.
|
||||
2. **Fetch** the full config via request/response (`operation: "config"`).
|
||||
This returns the config dict and a version number.
|
||||
3. **Apply** the fetched config to all registered handlers.
|
||||
4. **Process** buffered notifys. For any notify with `version > fetched_version`,
|
||||
re-fetch and re-apply. Discard notifys with `version <= fetched_version`.
|
||||
5. **Enter steady state**. Process future notifys as they arrive.
|
||||
|
||||
This is safe because:
|
||||
- If an update happens before the subscription, the fetch picks it up.
|
||||
- If an update happens between subscribe and fetch, it's in the buffer.
|
||||
- If an update happens after the fetch, it arrives on the queue normally.
|
||||
- Version comparison ensures no duplicate processing.
|
||||
|
||||
### Processor API
|
||||
|
||||
The current API requires processors to understand the full config dict
|
||||
structure. The new API should be cleaner — processors declare which config
|
||||
types they care about and provide a handler that receives only the relevant
|
||||
config subset.
|
||||
|
||||
#### Current API
|
||||
|
||||
```python
|
||||
# In processor __init__:
|
||||
self.register_config_handler(self.on_configure_flows)
|
||||
|
||||
# Handler receives the entire config dict:
|
||||
async def on_configure_flows(self, config, version):
|
||||
if "active-flow" not in config:
|
||||
return
|
||||
if self.id in config["active-flow"]:
|
||||
flow_config = json.loads(config["active-flow"][self.id])
|
||||
# ...
|
||||
```
|
||||
|
||||
#### New API
|
||||
|
||||
```python
|
||||
# In processor __init__:
|
||||
self.register_config_handler(
|
||||
handler=self.on_configure_flows,
|
||||
types=["active-flow"],
|
||||
)
|
||||
|
||||
# Handler receives only the relevant config subset, same signature:
|
||||
async def on_configure_flows(self, config, version):
|
||||
# config still contains the full dict, but handler is only called
|
||||
# when "active-flow" type changes (or on startup)
|
||||
if "active-flow" not in config:
|
||||
return
|
||||
# ...
|
||||
```
|
||||
|
||||
The `types` parameter is optional. If omitted, the handler is called for
|
||||
every config change (backward compatible). If specified, the handler is
|
||||
only invoked when the notify's `types` list intersects with the handler's
|
||||
types, or on startup (empty types list = everything).
|
||||
|
||||
#### Internal Registration Structure
|
||||
|
||||
```python
|
||||
# In AsyncProcessor:
|
||||
def register_config_handler(self, handler, types=None):
|
||||
self.config_handlers.append({
|
||||
"handler": handler,
|
||||
"types": set(types) if types else None, # None = all types
|
||||
})
|
||||
```
|
||||
|
||||
#### Notify Processing Logic
|
||||
|
||||
```python
|
||||
async def on_config_notify(self, message, consumer, flow):
|
||||
notify_version = message.value().version
|
||||
notify_types = set(message.value().types)
|
||||
|
||||
# Skip if we already have this version or newer
|
||||
if notify_version <= self.config_version:
|
||||
return
|
||||
|
||||
# Fetch full config from config service
|
||||
config, version = await self.config_client.config()
|
||||
self.config_version = version
|
||||
|
||||
# Determine which handlers to invoke
|
||||
for entry in self.config_handlers:
|
||||
handler_types = entry["types"]
|
||||
if handler_types is None:
|
||||
# Handler cares about everything
|
||||
await entry["handler"](config, version)
|
||||
elif not notify_types or notify_types & handler_types:
|
||||
# notify_types empty = startup (invoke all),
|
||||
# or intersection with handler's types
|
||||
await entry["handler"](config, version)
|
||||
```
|
||||
|
||||
### Config Service Changes
|
||||
|
||||
#### Push Method
|
||||
|
||||
The `push()` method changes to send only version + types:
|
||||
|
||||
```python
|
||||
async def push(self, types=None):
|
||||
version = await self.config.get_version()
|
||||
resp = ConfigPush(
|
||||
version=version,
|
||||
types=types or [],
|
||||
)
|
||||
await self.config_push_producer.send(resp)
|
||||
```
|
||||
|
||||
#### Put/Delete Handlers
|
||||
|
||||
Extract affected types and pass to push:
|
||||
|
||||
```python
|
||||
async def handle_put(self, v):
|
||||
types = list(set(k.type for k in v.values))
|
||||
for k in v.values:
|
||||
await self.table_store.put_config(k.type, k.key, k.value)
|
||||
await self.inc_version()
|
||||
await self.push(types=types)
|
||||
|
||||
async def handle_delete(self, v):
|
||||
types = list(set(k.type for k in v.keys))
|
||||
for k in v.keys:
|
||||
await self.table_store.delete_key(k.type, k.key)
|
||||
await self.inc_version()
|
||||
await self.push(types=types)
|
||||
```
|
||||
|
||||
#### Queue Class Change
|
||||
|
||||
The config push queue changes from `state` class to `flow` class. The push
|
||||
is now a transient signal — the source of truth is the config service's
|
||||
request/response API, not the queue. `flow` class is persistent (survives
|
||||
broker restarts) but doesn't require last-message retention, which was the
|
||||
root cause of the RabbitMQ problem.
|
||||
|
||||
```python
|
||||
config_push_queue = queue('config', cls='flow') # was cls='state'
|
||||
```
|
||||
|
||||
#### Startup Push
|
||||
|
||||
On startup, the config service sends a notify with empty types list
|
||||
(signalling "everything changed"):
|
||||
|
||||
```python
|
||||
async def start(self):
|
||||
await self.push(types=[]) # Empty = all types
|
||||
await self.config_request_consumer.start()
|
||||
```
|
||||
|
||||
### AsyncProcessor Changes
|
||||
|
||||
The `AsyncProcessor` needs a config request/response client alongside the
|
||||
push consumer. The startup sequence becomes:
|
||||
|
||||
```python
|
||||
async def start(self):
|
||||
# 1. Start the push consumer (begins buffering notifys)
|
||||
await self.config_sub_task.start()
|
||||
|
||||
# 2. Fetch current config via request/response
|
||||
config, version = await self.config_client.config()
|
||||
self.config_version = version
|
||||
|
||||
# 3. Apply to all handlers (startup = all handlers invoked)
|
||||
for entry in self.config_handlers:
|
||||
await entry["handler"](config, version)
|
||||
|
||||
# 4. Buffered notifys are now processed by on_config_notify,
|
||||
# which skips versions <= self.config_version
|
||||
```
|
||||
|
||||
The config client needs to be created in `__init__` using the existing
|
||||
request/response queue infrastructure. The `ConfigClient` from
|
||||
`trustgraph.clients.config_client` already exists but uses a synchronous
|
||||
blocking pattern. An async variant or integration with the processor's
|
||||
pub/sub backend is needed.
|
||||
|
||||
### Existing Config Handler Types
|
||||
|
||||
For reference, the config types currently used by handlers:
|
||||
|
||||
| Handler | Type(s) | Used By |
|
||||
|---------|---------|---------|
|
||||
| `on_configure_flows` | `active-flow` | All FlowProcessor subclasses |
|
||||
| `on_collection_config` | `collection` | Storage services (triples, embeddings, rows) |
|
||||
| `on_prompt_config` | `prompt` | Prompt template service, agent extract |
|
||||
| `on_schema_config` | `schema` | Rows storage, row embeddings, NLP query, structured diag |
|
||||
| `on_cost_config` | `token-costs` | Metering service |
|
||||
| `on_ontology_config` | `ontology` | Ontology extraction |
|
||||
| `on_librarian_config` | `librarian` | Librarian service |
|
||||
| `on_mcp_config` | `mcp-tool` | MCP tool service |
|
||||
| `on_knowledge_config` | `kg-core` | Cores service |
|
||||
|
||||
## Implementation Order
|
||||
|
||||
1. **Update ConfigPush schema** — change `config` field to `types` field.
|
||||
|
||||
2. **Update config service** — modify `push()` to send version + types.
|
||||
Modify `handle_put`/`handle_delete` to extract affected types.
|
||||
|
||||
3. **Add async config query to AsyncProcessor** — create a
|
||||
request/response client for config queries within the processor's
|
||||
event loop.
|
||||
|
||||
4. **Implement subscribe-then-fetch startup** — reorder
|
||||
`AsyncProcessor.start()` to subscribe first, then fetch, then
|
||||
process buffered notifys with version comparison.
|
||||
|
||||
5. **Update register_config_handler** — add optional `types` parameter.
|
||||
Update `on_config_notify` to filter by type intersection.
|
||||
|
||||
6. **Update existing handlers** — add `types` parameter to all
|
||||
`register_config_handler` calls across the codebase.
|
||||
|
||||
7. **Backward compatibility** — handlers without `types` parameter
|
||||
continue to work (invoked for all changes).
|
||||
|
||||
## Risks
|
||||
|
||||
- **Thundering herd**: if many processors restart simultaneously, they
|
||||
all hit the config service API at once. Mitigated by the config service
|
||||
already being designed for request/response load, and the number of
|
||||
processors being small (tens, not thousands).
|
||||
|
||||
- **Config service availability**: processors now depend on the config
|
||||
service being up at startup, not just having received a push. This is
|
||||
already the case in practice — without config, processors can't do
|
||||
anything useful.
|
||||
551
docs/tech-specs/pubsub-abstraction.md
Normal file
551
docs/tech-specs/pubsub-abstraction.md
Normal file
|
|
@ -0,0 +1,551 @@
|
|||
# Pub/Sub Abstraction: Broker-Independent Messaging
|
||||
|
||||
## Problem
|
||||
|
||||
TrustGraph's messaging infrastructure is deeply coupled to Apache Pulsar in ways that go beyond the transport layer. This coupling creates several concrete problems.
|
||||
|
||||
### 1. Schema system is Pulsar-native
|
||||
|
||||
Every message type in the system is defined as a `pulsar.schema.Record` subclass using Pulsar field types (`String()`, `Integer()`, `Boolean()`, etc.). This means:
|
||||
|
||||
- The `pulsar` Python package is a build dependency for `trustgraph-base`, even though `trustgraph-base` contains no transport logic
|
||||
- Any code that imports a message schema transitively depends on Pulsar
|
||||
- The schema definitions cannot be reused with a different broker without the Pulsar library installed
|
||||
- What's actually happening on the wire is JSON serialisation — the Pulsar schema machinery adds complexity without adding value over plain JSON encode/decode
|
||||
|
||||
### 2. Translators are named after the broker
|
||||
|
||||
The translator layer that converts between internal Python objects and wire format uses methods called `to_pulsar()` and `from_pulsar()`. These are really just JSON encode/decode operations — they have nothing to do with Pulsar specifically. The naming creates a false impression that the translation is broker-specific, when in reality any broker that carries JSON payloads would use identical logic.
|
||||
|
||||
### 3. Queue names use Pulsar URI format
|
||||
|
||||
Queue identifiers throughout the codebase use Pulsar's `persistent://tenant/namespace/topic` or `non-persistent://tenant/namespace/topic` URI format. These are hardcoded in schema definitions and referenced across services. RabbitMQ, Redis Streams, or any other broker would use completely different naming conventions. There is no abstraction between the logical identity of a queue and its broker-specific address.
|
||||
|
||||
### 4. Broker selection is not configurable
|
||||
|
||||
There is no mechanism to select a different pub/sub backend at deployment time. The Pulsar client is instantiated directly in the gateway and via `PulsarClient` in the base processor. Switching to a different broker would require code changes across multiple packages, not a configuration change.
|
||||
|
||||
### 5. Architectural requirements are implicit
|
||||
|
||||
TrustGraph relies on specific pub/sub behaviours — shared subscriptions for load balancing, message acknowledgement for reliability, message properties for correlation — but these requirements are not documented. This makes it difficult to evaluate whether a candidate broker (RabbitMQ, Redis Streams, NATS, etc.) actually satisfies the system's needs, or where the gaps would be.
|
||||
|
||||
## Design Goals
|
||||
|
||||
### Goal 1: Remove the link between Pulsar schemas and application code
|
||||
|
||||
Message types should be plain Python objects (dataclasses) that know how to serialise to and from JSON. The `pulsar.schema.Record` base class and Pulsar field types should not appear in schema definitions. The pub/sub transport layer sends and receives JSON bytes; the schema layer handles the mapping between JSON and typed Python objects independently.
|
||||
|
||||
### Goal 2: Remove `to_pulsar` / `from_pulsar` naming
|
||||
|
||||
The translator methods should reflect what they actually do: encode a Python object to a JSON-compatible dict, and decode a JSON-compatible dict back to a Python object. The naming should be broker-neutral (e.g. `encode` / `decode`, or `to_dict` / `from_dict`).
|
||||
|
||||
### Goal 3: Schema objects provide encode/decode
|
||||
|
||||
Each message type should be a Python dataclass (or similar) with a well-defined mapping to and from JSON. For example:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class TextCompletionRequest:
|
||||
system: str
|
||||
prompt: str
|
||||
streaming: bool = False
|
||||
```
|
||||
|
||||
Given `{"system": "You are helpful", "prompt": "Hello", "streaming": false}` on the wire, decoding produces an object where `request.system` is `"You are helpful"`, `request.prompt` is `"Hello"`, and `request.streaming` is `False`. Encoding does the reverse. This is the schema's concern, not the broker's.
|
||||
|
||||
### Goal 4: Abstract queue naming
|
||||
|
||||
Queue identifiers should not use Pulsar URI format (`persistent://tg/flow/topic`). A broker-neutral naming scheme is needed so that each backend can map logical queue names to its native format. The right approach here is not yet clear and needs to be worked through — considerations include how to express quality-of-service, multi-tenancy, and namespace separation without leaking broker concepts.
|
||||
|
||||
### Goal 5: Document pub/sub architectural requirements
|
||||
|
||||
TrustGraph's actual requirements from the pub/sub layer need to be formally specified. This includes:
|
||||
|
||||
- **Delivery semantics**: Which queues need at-least-once delivery? Are any fire-and-forget?
|
||||
- **Consumer patterns**: Shared subscriptions (competing consumers for load balancing), exclusive subscriptions, fan-out/broadcast
|
||||
- **Message acknowledgement**: Positive ack, negative ack (redelivery), timeout-based redelivery
|
||||
- **Message properties**: Key-value metadata on messages used for correlation (e.g. request IDs, flow routing)
|
||||
- **Ordering guarantees**: Per-topic ordering, per-key ordering, or no ordering required
|
||||
- **Message size**: Typical and maximum message sizes (some payloads include base64-encoded documents)
|
||||
- **Persistence**: Which messages must survive broker restarts
|
||||
- **Consumer positioning**: Ability to consume from earliest (replay) vs latest (live tail)
|
||||
- **Connection model**: Long-lived connections with reconnection, or transient
|
||||
|
||||
Documenting these requirements makes it possible to evaluate RabbitMQ or any other candidate against concrete criteria rather than discovering gaps during implementation.
|
||||
|
||||
## Pub/Sub Architectural Requirements (As-Is)
|
||||
|
||||
This section documents what TrustGraph currently needs from its pub/sub layer. These are the as-is requirements — some may be revisited or relaxed in a future design if it makes broker portability easier.
|
||||
|
||||
### Consumer model
|
||||
|
||||
All consumers use **shared subscriptions** (competing consumers). Multiple instances of the same processor read from the same subscription, and each message is delivered to exactly one instance. This is the load-balancing mechanism.
|
||||
|
||||
No exclusive or failover subscriptions are used anywhere in the codebase, despite infrastructure support for them.
|
||||
|
||||
Consumers support configurable concurrency — multiple async tasks within a single process can independently call `receive()` on the same subscription.
|
||||
|
||||
### Delivery semantics
|
||||
|
||||
Almost all queues are **non-persistent / best-effort (q0)**. The only persistent queue is `config_push_queue` (q2, exactly-once), which pushes full configuration state to processors. Since config pushes are idempotent (full state, not deltas), the persistence requirement here is about surviving broker restarts, not about exactly-once semantics per se.
|
||||
|
||||
Flow processing queues (request/response pairs for LLM, RAG, agent, etc.) are all non-persistent. Messages in flight are lost on broker restart. This is acceptable because:
|
||||
|
||||
- Requests originate from a client that will time out and retry
|
||||
- There is no durable work-in-progress that would be corrupted by message loss
|
||||
- The system is designed for real-time query processing, not batch pipelines
|
||||
|
||||
### Message acknowledgement
|
||||
|
||||
**Positive acknowledgement**: After successful handler execution, the message is acknowledged. This removes it from the subscription.
|
||||
|
||||
**Negative acknowledgement**: On handler failure (unhandled exception or rate-limit timeout), the message is negatively acknowledged, which triggers redelivery by the broker. Rate-limited messages retry for up to 7200 seconds before giving up and negatively acknowledging.
|
||||
|
||||
**Orphaned messages**: In the request-response subscriber pattern, messages that arrive with no matching waiter (e.g. the requester timed out) are positively acknowledged and discarded. This prevents redelivery storms.
|
||||
|
||||
### Message properties
|
||||
|
||||
Messages carry a small set of key-value string properties as metadata, separate from the payload. The primary use is a `"id"` property for request-response correlation — the requester generates a unique ID, attaches it as a property, and the responder echoes it back so the subscriber can match responses to waiters.
|
||||
|
||||
Agent orchestration correlation (`correlation_id`, `parent_session_id`) is carried in the message payload, not in properties.
|
||||
|
||||
### Consumer positioning
|
||||
|
||||
Two modes are used:
|
||||
|
||||
- **Earliest**: The configuration consumer starts from the beginning of the topic to receive full configuration history on startup. This is the only use of earliest positioning.
|
||||
- **Latest** (default): All flow consumers start from the current position, processing only new messages.
|
||||
|
||||
### Message ordering
|
||||
|
||||
**Not required.** The codebase explicitly does not depend on message ordering:
|
||||
|
||||
- Shared subscriptions distribute messages across consumers without ordering guarantees
|
||||
- Concurrent handler tasks within a consumer process messages in arbitrary order
|
||||
- Request-response correlation uses IDs, not positional ordering
|
||||
- The supervisor fan-out/fan-in pattern collects results in a dictionary, order-independent
|
||||
- Configuration pushes are full state snapshots, not ordered deltas
|
||||
|
||||
### Message sizes
|
||||
|
||||
Most messages are small JSON payloads (< 10KB). The exceptions:
|
||||
|
||||
- **Document content**: Large documents (PDFs, text files) can be sent through the chunking service with base64 encoding. Pulsar's chunking feature (`chunking_enabled`) handles automatic splitting of oversized messages.
|
||||
- **Agent observations**: LLM-generated text can be several KB but rarely exceeds typical message size limits.
|
||||
|
||||
A replacement broker needs to either support large messages natively or provide a chunking/streaming mechanism. Alternatively, the large-document path could be refactored to use a side-channel (e.g. object store reference) instead of inline payload.
|
||||
|
||||
### Fan-out patterns
|
||||
|
||||
**Supervisor fan-out**: One supervisor request decomposes into N independent sub-agent requests, each emitted as a separate message on the agent request queue. Different agent instances pick them up via the shared subscription. A correlation ID links the completions back to the original decomposition. This is not pub/sub fan-out (one message to many consumers) — it's application-level fan-out (many messages to one queue).
|
||||
|
||||
**Request-response isolation**: Each client creates a unique subscription name on response queues so it only receives its own responses. This means the response queue effectively has many independent subscribers, each seeing a filtered subset of messages based on the `"id"` property match.
|
||||
|
||||
### Reconnection and resilience
|
||||
|
||||
Reconnection logic lives in the Consumer/Producer/Publisher/Subscriber classes, not in the broker client. These classes handle:
|
||||
|
||||
- Automatic reconnection on connection loss
|
||||
- Retry loops with backoff
|
||||
- Graceful shutdown (unsubscribe, close)
|
||||
|
||||
The broker client itself is expected to provide a basic connection that can fail, and the wrapper classes handle recovery. This is important for the abstraction — the backend interface can be simple because resilience is handled above it.
|
||||
|
||||
### Queue inventory
|
||||
|
||||
| Queue | Persistence | Purpose |
|
||||
|-------|-------------|---------|
|
||||
| config push | Persistent (q2) | Full configuration state broadcast |
|
||||
| config request/response | Non-persistent | Configuration queries |
|
||||
| flow request/response | Non-persistent | Flow management |
|
||||
| knowledge request/response | Non-persistent | Knowledge graph operations |
|
||||
| librarian request/response | Non-persistent | Document storage operations |
|
||||
| document embeddings request/response | Non-persistent | Document vector queries |
|
||||
| row embeddings request/response | Non-persistent | Row vector queries |
|
||||
| collection request/response | Non-persistent | Collection management |
|
||||
|
||||
Additionally, each processing service (LLM, RAG, agent, prompt, embeddings, etc.) has dynamically defined request/response queue pairs configured at deployment time.
|
||||
|
||||
### Summary of hard requirements for a replacement broker
|
||||
|
||||
1. **Shared subscription / competing consumers** — multiple consumers on one queue, each message delivered to exactly one
|
||||
2. **Message acknowledgement** — positive ack (remove from queue) and negative ack (trigger redelivery)
|
||||
3. **Message properties** — key-value metadata on messages, at minimum a string `"id"` field
|
||||
4. **Two consumer start positions** — from beginning of topic and from current position
|
||||
5. **Persistence for at least one queue** — config state must survive broker restart
|
||||
6. **Messages up to several MB** — or a chunking mechanism for large payloads
|
||||
7. **No ordering requirement** — simplifies broker selection significantly
|
||||
|
||||
## Candidate Brokers
|
||||
|
||||
A quick assessment of alternatives against the hard requirements above.
|
||||
|
||||
### RabbitMQ
|
||||
|
||||
The primary candidate. Mature, widely deployed, well understood.
|
||||
|
||||
- **Competing consumers**: Yes — multiple consumers on a queue, round-robin delivery. This is RabbitMQ's native model.
|
||||
- **Acknowledgement**: Yes — `basic.ack` and `basic.nack` with requeue flag.
|
||||
- **Message properties**: Yes — headers and properties on every message. The `correlation_id` and `message_id` fields are first-class concepts.
|
||||
- **Consumer positioning**: Yes, via RabbitMQ Streams (3.9+). Streams are append-only logs that support reading from any offset — beginning, end, or timestamp. Classic queues are consumed destructively (no replay), but streams solve this cleanly. The `state` queue class maps to a RabbitMQ stream. Additionally, the Last Value Cache Exchange plugin can retain the most recent message per routing key for new consumers.
|
||||
- **Persistence**: Yes — durable queues and persistent messages survive broker restart.
|
||||
- **Large messages**: No hard limit but not designed for very large payloads. Practical limit around 128MB with default config. Adequate for current use.
|
||||
- **Ordering**: FIFO per queue (stronger than required).
|
||||
- **Operational complexity**: Low. Single binary, no ZooKeeper/BookKeeper dependencies. Significantly simpler to operate than Pulsar.
|
||||
- **Ecosystem**: Excellent client libraries, management UI, mature tooling.
|
||||
|
||||
**Gaps**: None significant. RabbitMQ Streams cover the replay/earliest positioning requirement.
|
||||
|
||||
### Apache Kafka
|
||||
|
||||
High-throughput distributed log. More infrastructure than TrustGraph likely needs.
|
||||
|
||||
- **Competing consumers**: Yes — consumer groups with partition assignment.
|
||||
- **Acknowledgement**: Yes — offset commits. No per-message negative ack; failed messages require application-level retry or dead-letter handling.
|
||||
- **Message properties**: Yes — message headers (key-value byte arrays).
|
||||
- **Consumer positioning**: Yes — seek to earliest or latest offset. Supports full replay.
|
||||
- **Persistence**: Yes — all messages are persisted to the log by default.
|
||||
- **Large messages**: Configurable (`max.message.bytes`), default 1MB, can be increased. Large payloads are discouraged by design.
|
||||
- **Ordering**: Per-partition ordering (stronger than required).
|
||||
- **Operational complexity**: High. Requires ZooKeeper (or KRaft), partition management, replication config. Overkill for typical TrustGraph deployments.
|
||||
- **Ecosystem**: Excellent client libraries, schema registry, Connect framework.
|
||||
|
||||
**Gaps**: No native negative acknowledgement. Operational complexity is high for small-to-medium deployments. Partition count must be planned upfront for parallelism.
|
||||
|
||||
### Redis Streams
|
||||
|
||||
Lightweight option using Redis as a message broker.
|
||||
|
||||
- **Competing consumers**: Yes — consumer groups with `XREADGROUP`.
|
||||
- **Acknowledgement**: Yes — `XACK`. Pending entries list tracks unacknowledged messages. No explicit negative ack but unacknowledged messages can be claimed after timeout via `XAUTOCLAIM`.
|
||||
- **Message properties**: No native separation between properties and payload. Would need to encode properties as fields within the stream entry or in the payload.
|
||||
- **Consumer positioning**: Yes — `0` (earliest) or `$` (latest) on group creation.
|
||||
- **Persistence**: Yes — Redis persistence (RDB/AOF), though Redis is primarily an in-memory system.
|
||||
- **Large messages**: Practical limit tied to Redis memory. Not suited for large payloads.
|
||||
- **Ordering**: Per-stream ordering (stronger than required).
|
||||
- **Operational complexity**: Low if Redis is already in the stack. No additional infrastructure.
|
||||
|
||||
**Gaps**: No native message properties. Memory-bound. Persistence depends on Redis configuration. Not a natural fit for message broker patterns.
|
||||
|
||||
### NATS / NATS JetStream
|
||||
|
||||
Lightweight, high-performance messaging. JetStream adds persistence.
|
||||
|
||||
- **Competing consumers**: Yes — queue groups in core NATS; consumer groups in JetStream.
|
||||
- **Acknowledgement**: JetStream only — `Ack`, `Nak` (with redelivery), `InProgress` (extend timeout).
|
||||
- **Message properties**: Yes — message headers (key-value).
|
||||
- **Consumer positioning**: JetStream — deliver all, deliver last, deliver new, deliver by sequence/time.
|
||||
- **Persistence**: JetStream only. Core NATS is fire-and-forget.
|
||||
- **Large messages**: Default 1MB, configurable up to 64MB.
|
||||
- **Ordering**: Per-subject ordering.
|
||||
- **Operational complexity**: Very low. Single binary, no dependencies. Clustering is straightforward.
|
||||
|
||||
**Gaps**: Requires JetStream for persistence and acknowledgement. Smaller ecosystem than RabbitMQ/Kafka.
|
||||
|
||||
### Assessment Summary
|
||||
|
||||
| Requirement | RabbitMQ | Kafka | Redis Streams | NATS JetStream |
|
||||
|---|---|---|---|---|
|
||||
| Competing consumers | Yes | Yes | Yes | Yes |
|
||||
| Positive/negative ack | Yes | Partial | Partial | Yes |
|
||||
| Message properties | Yes | Yes | No | Yes |
|
||||
| Earliest positioning | Yes (Streams) | Yes | Yes | Yes |
|
||||
| Persistence | Yes | Yes | Partial | Yes |
|
||||
| Large messages | Yes | Configurable | No | Configurable |
|
||||
| Operational simplicity | Good | Poor | Good | Good |
|
||||
|
||||
**RabbitMQ** is the strongest candidate given TrustGraph's requirements and deployment profile. The only gap (earliest consumer positioning for config) has known workarounds. Operational simplicity is a significant advantage over Pulsar.
|
||||
|
||||
## Approach
|
||||
|
||||
### Current state
|
||||
|
||||
The codebase has already undergone a partial abstraction. The picture is better than the problem statement might suggest:
|
||||
|
||||
- **Backend abstraction exists**: `backend.py` defines Protocol-based interfaces (`PubSubBackend`, `BackendProducer`, `BackendConsumer`, `Message`). The Pulsar implementation lives in `pulsar_backend.py`.
|
||||
- **Schemas are already dataclasses**: Message types in `schema/services/*.py` are plain Python dataclasses with type hints, not Pulsar `Record` subclasses. This was the hardest part of the old spec and it's done.
|
||||
- **Serialization is JSON-based**: `pulsar_backend.py` contains `dataclass_to_dict()` and `dict_to_dataclass()` helpers that handle the round-trip. The wire format is JSON.
|
||||
- **Factory pattern exists**: `pubsub.py` has `get_pubsub()` which creates a backend from configuration. Currently only Pulsar is implemented.
|
||||
- **Consumer/Producer/Publisher/Subscriber are backend-agnostic**: These classes accept a `backend` parameter and delegate transport operations to it. They own retry, reconnection, metrics, and concurrency.
|
||||
|
||||
What remains is cleanup, not a rewrite.
|
||||
|
||||
### What needs to change
|
||||
|
||||
#### 1. Rename translator methods
|
||||
|
||||
The translator base class (`messaging/translators/base.py`) defines `to_pulsar()` and `from_pulsar()` as abstract methods. Every translator implements these. The methods convert between external API dicts and internal dataclass objects — nothing Pulsar-specific happens in them.
|
||||
|
||||
**Change**: Rename to `decode()` (external dict → dataclass) and `encode()` (dataclass → external dict). Update all translator subclasses and all call sites.
|
||||
|
||||
This is a mechanical rename. The method bodies don't change.
|
||||
|
||||
#### 2. Rename translator base classes
|
||||
|
||||
The base classes `Translator`, `MessageTranslator`, and `SendTranslator` reference "pulsar" in docstrings and parameter names. Clean these up so the naming reflects what the layer actually does: translating between the external API representation (JSON dicts from HTTP/WebSocket) and the internal schema (dataclasses).
|
||||
|
||||
#### 3. Move serialization out of the Pulsar backend
|
||||
|
||||
`dataclass_to_dict()` and `dict_to_dataclass()` currently live in `pulsar_backend.py` but are not Pulsar-specific. They handle the conversion between dataclasses and JSON-compatible dicts, which every backend needs.
|
||||
|
||||
**Change**: Move these to a shared location (e.g. `trustgraph/base/serialization.py` or alongside the schema definitions). The backend interface sends and receives dicts; serialization to/from dataclasses happens at a layer above.
|
||||
|
||||
This means the backend Protocol simplifies: `send()` accepts a dict and properties, `value()` returns a dict. The Consumer/Producer layer handles dataclass ↔ dict conversion using the shared serializers.
|
||||
|
||||
#### 4. Abstract queue naming
|
||||
|
||||
Queue names currently use the format `q0/tg/flow/queue-name` or `q2/tg/config/queue-name`, which the Pulsar backend maps to `non-persistent://tg/flow/queue-name` or `persistent://tg/config/queue-name`.
|
||||
|
||||
This is an open design question. Options:
|
||||
|
||||
**Option A: Simple string names.** Queues are just strings like `"text-completion-request"`. The backend is responsible for mapping to its native format (Pulsar adds `persistent://tg/flow/` prefix, RabbitMQ uses the string as-is or adds a vhost prefix). Persistence and namespace are configuration concerns, not embedded in the name.
|
||||
|
||||
**Option B: Structured queue descriptor.** A small object that carries the logical name plus metadata:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class QueueDescriptor:
|
||||
name: str # e.g. "text-completion-request"
|
||||
namespace: str = "flow" # logical grouping
|
||||
persistent: bool = False # must survive broker restart
|
||||
```
|
||||
|
||||
The backend maps this to its native format.
|
||||
|
||||
**Option C: Keep the current format** (`q0/tg/flow/name`) but document it as a TrustGraph convention, not a Pulsar convention. Backends parse it.
|
||||
|
||||
Option B is the most explicit. Option A is the simplest. Either is workable. The key constraint is that persistence is a property of the queue definition, not a runtime choice — the config push queue is persistent, everything else is not.
|
||||
|
||||
#### 5. Implement RabbitMQ backend
|
||||
|
||||
Write `rabbitmq_backend.py` implementing the `PubSubBackend` Protocol:
|
||||
|
||||
- **`create_producer()`**: Creates a channel and declares the target queue. `send()` publishes to the default exchange with the queue name as routing key. Properties map to AMQP basic properties (specifically `message_id` for the `"id"` property).
|
||||
- **`create_consumer()`**: Declares the queue and starts consuming with `basic_consume`. Shared subscription is the default RabbitMQ model — multiple consumers on one queue get round-robin delivery. `acknowledge()` maps to `basic_ack`, `negative_acknowledge()` maps to `basic_nack` with `requeue=True`.
|
||||
- **Persistence**: For persistent queues, declare as durable with `delivery_mode=2` on messages. For non-persistent queues, declare as non-durable.
|
||||
- **Consumer positioning**: RabbitMQ queues are consumed destructively, so "earliest" doesn't apply in the Pulsar sense. For the config push use case, use a **fanout exchange with per-consumer exclusive queues** — each new processor gets its own queue that receives all config publishes, plus the last-value can be handled by having the config service re-publish on startup.
|
||||
- **Large messages**: RabbitMQ handles messages up to `rabbit.max_message_size` (default 128MB). No chunking needed.
|
||||
|
||||
The factory in `pubsub.py` gets a new branch:
|
||||
|
||||
```python
|
||||
if backend_type == 'rabbitmq':
|
||||
return RabbitMQBackend(
|
||||
host=config.get('rabbitmq_host'),
|
||||
port=config.get('rabbitmq_port'),
|
||||
username=config.get('rabbitmq_username'),
|
||||
password=config.get('rabbitmq_password'),
|
||||
vhost=config.get('rabbitmq_vhost', '/'),
|
||||
)
|
||||
```
|
||||
|
||||
Backend selection via `PUBSUB_BACKEND=rabbitmq` environment variable or `--pubsub-backend rabbitmq` CLI flag.
|
||||
|
||||
#### 6. Clean up remaining Pulsar references
|
||||
|
||||
After the above changes, Pulsar-specific code should be confined to:
|
||||
|
||||
- `pulsar_backend.py` — the Pulsar implementation
|
||||
- `pubsub.py` — the factory that imports it
|
||||
|
||||
Audit and remove any remaining Pulsar imports, Pulsar exception handling, or Pulsar-specific concepts from:
|
||||
|
||||
- `async_processor.py` (currently catches `_pulsar.Interrupted`)
|
||||
- `consumer.py`, `subscriber.py` (if any Pulsar exceptions leak through)
|
||||
- Schema files (should be clean already, but verify)
|
||||
- Gateway service (currently instantiates Pulsar client directly)
|
||||
|
||||
The gateway is a special case — it currently bypasses the abstraction layer and creates a Pulsar client directly for dispatching API requests. It should use the same `get_pubsub()` factory as everything else.
|
||||
|
||||
### What stays the same
|
||||
|
||||
- **Schema definitions**: Already dataclasses. No changes needed.
|
||||
- **Consumer/Producer/Publisher/Subscriber**: Already backend-agnostic. No changes to their core logic.
|
||||
- **FlowProcessor and spec wiring**: Already uses `processor.pubsub` to create backend instances. No changes.
|
||||
- **Backend Protocol**: The interface in `backend.py` is sound. Minor refinement possible (dict vs dataclass at the boundary) but the shape is right.
|
||||
|
||||
### Concrete cleanups
|
||||
|
||||
The following files have Pulsar-specific imports that should not be there after the abstraction is complete. Pulsar imports should be confined to `pulsar_backend.py` and the factory in `pubsub.py`.
|
||||
|
||||
**Dead imports (unused, can just be removed):**
|
||||
|
||||
- `trustgraph-base/trustgraph/base/pubsub.py` — `from pulsar.schema import JsonSchema`, `import pulsar`, `import _pulsar`. The `JsonSchema` import is unused since the switch to `BytesSchema`. The `pulsar`/`_pulsar` imports are only used by the legacy `PulsarClient` class which should be removed (superseded by `PulsarBackend`).
|
||||
- `trustgraph-base/trustgraph/base/flow_processor.py` — `from pulsar.schema import JsonSchema`. Unused.
|
||||
|
||||
**Legacy `PulsarClient` class:**
|
||||
|
||||
- `trustgraph-base/trustgraph/base/pubsub.py` — The `PulsarClient` class is a leftover from before the backend abstraction. `get_pubsub()` still references `PulsarClient.default_pulsar_host` for defaults. Move the defaults to `PulsarBackend` or to environment variable reads in the factory, then delete `PulsarClient`.
|
||||
|
||||
**Client libraries using Pulsar directly:**
|
||||
|
||||
- `trustgraph-base/trustgraph/clients/base.py` — `import pulsar`, `import _pulsar`, `from pulsar.schema import JsonSchema`. This is the base class for the old synchronous client library. These clients predate the backend abstraction and use Pulsar directly.
|
||||
- `trustgraph-base/trustgraph/clients/embeddings_client.py` — `from pulsar.schema import JsonSchema`, `import _pulsar`.
|
||||
- `trustgraph-base/trustgraph/clients/*.py` (agent, config, document_embeddings, document_rag, graph_embeddings, graph_rag, llm, prompt, row_embeddings, triples_query) — all import `_pulsar` for exception handling.
|
||||
|
||||
These clients are the internal request-response clients used by processors. They need to be migrated to use the backend abstraction or their Pulsar exception handling needs to be wrapped behind a backend-agnostic exception type.
|
||||
|
||||
**Translator base class:**
|
||||
|
||||
- `trustgraph-base/trustgraph/messaging/translators/base.py` — `from pulsar.schema import Record`. Used in type hints. Should be removed when `to_pulsar`/`from_pulsar` are renamed.
|
||||
|
||||
**Gateway service (bypasses abstraction):**
|
||||
|
||||
- `trustgraph-flow/trustgraph/gateway/service.py` — `import pulsar`. Creates a Pulsar client directly.
|
||||
- `trustgraph-flow/trustgraph/gateway/config/receiver.py` — `import pulsar`. Direct Pulsar usage.
|
||||
|
||||
The gateway should use `get_pubsub()` like everything else.
|
||||
|
||||
**Storage writers:**
|
||||
|
||||
- `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py` — `import pulsar`
|
||||
- `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py` — `import pulsar`
|
||||
- `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py` — `import pulsar`
|
||||
- `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py` — `import pulsar`
|
||||
|
||||
These need investigation — likely Pulsar exception handling or direct client usage that should go through the abstraction.
|
||||
|
||||
**Log level:**
|
||||
|
||||
- `trustgraph-base/trustgraph/log_level.py` — `import _pulsar`. Used to set Pulsar's log level. Should be moved into `pulsar_backend.py`.
|
||||
|
||||
### Queue naming
|
||||
|
||||
The current scheme encodes QoS, tenant, namespace, and queue name into a slash-separated string (`q0/tg/request/config`) which the Pulsar backend parses and maps to a Pulsar URI (`non-persistent://tg/request/config`). This was an attempt at abstraction but it has problems:
|
||||
|
||||
- QoS in the name was a mistake — it's a property of the queue definition, not something that belongs in the name. A queue is either persistent or it isn't; that's decided once when the queue is defined.
|
||||
- The tenant/namespace structure mirrors Pulsar's model. RabbitMQ doesn't use this — it has vhosts and exchange/queue names. Pretending the naming isn't TrustGraph-specific just leaks Pulsar concepts.
|
||||
- The `topic()` helper generates these strings, and the backend parses them apart. This is unnecessary indirection.
|
||||
|
||||
There are two categories of queue in TrustGraph:
|
||||
|
||||
**Infrastructure queues** — defined in code, used for system services. These are fixed and well-known:
|
||||
|
||||
| Queue | Persistent | Purpose |
|
||||
|-------|------------|---------|
|
||||
| `config-request` | No | Config queries |
|
||||
| `config-response` | No | Config query responses |
|
||||
| `config-push` | Yes | Config state broadcast |
|
||||
| `flow-request` | No | Flow management queries |
|
||||
| `flow-response` | No | Flow management responses |
|
||||
| `librarian-request` | No | Document storage operations |
|
||||
| `librarian-response` | No | Document storage responses |
|
||||
| `knowledge-request` | No | Knowledge graph operations |
|
||||
| `knowledge-response` | No | Knowledge graph responses |
|
||||
| `document-embeddings-request` | No | Document vector queries |
|
||||
| `document-embeddings-response` | No | Document vector responses |
|
||||
| `row-embeddings-request` | No | Row vector queries |
|
||||
| `row-embeddings-response` | No | Row vector responses |
|
||||
| `collection-request` | No | Collection management |
|
||||
| `collection-response` | No | Collection management responses |
|
||||
|
||||
**Flow queues** — defined in configuration, created dynamically per flow. The queue names come from the config service (e.g. `text-completion-request`, `graph-rag-request`, `agent-request`). Each flow instance has its own set of these queues.
|
||||
|
||||
For infrastructure queues, the name is just a string. Persistence is a property of the queue definition, not encoded in the name. The backend maps the name to whatever its native format requires.
|
||||
|
||||
For flow queues, the name comes from configuration. The config service already distributes queue names as strings — the backend just needs to be able to use them.
|
||||
|
||||
#### Proposed scheme: CLASS:TOPICSPACE:TOPIC
|
||||
|
||||
A queue name has three parts separated by colons:
|
||||
|
||||
- **CLASS** — a small enum that defines the queue's operational characteristics. The backend knows what each class means in terms of persistence, TTL, memory limits, etc. There are only four classes:
|
||||
|
||||
| Class | Persistent | TTL | Behaviour |
|
||||
|-------|------------|-----|-----------|
|
||||
| `flow` | Yes | Long | Processing pipeline queues. Messages survive broker restart. |
|
||||
| `request` | No | Short | Transient request-response. Low TTL, no persistence needed — clients retry on failure. |
|
||||
| `response` | No | Short | Same as request, for the response side. |
|
||||
| `state` | Yes | Retained | Last-value state broadcast. Consumers need the most recent value on startup, plus any future updates. Config push is the primary example. |
|
||||
|
||||
- **TOPICSPACE** — deployment isolation. Keeps different TrustGraph deployments separate when sharing the same pub/sub infrastructure. Most deployments just use `tg`. Avoids the overloaded terms "tenant" and "namespace".
|
||||
|
||||
- **TOPIC** — the logical queue identity. What the queue is for.
|
||||
|
||||
**Examples:**
|
||||
|
||||
```
|
||||
flow:tg:text-completion-request
|
||||
flow:tg:graph-rag-request
|
||||
flow:tg:agent-request
|
||||
request:tg:librarian
|
||||
response:tg:librarian
|
||||
request:tg:config
|
||||
response:tg:config
|
||||
state:tg:config
|
||||
request:tg:flow
|
||||
response:tg:flow
|
||||
```
|
||||
|
||||
**Backend mapping:**
|
||||
|
||||
Each backend parses the three parts and maps them to its native concepts:
|
||||
|
||||
- **Pulsar**: `flow:tg:text-completion-request` → `persistent://tg/flow/text-completion-request`. Class maps to persistent/non-persistent and namespace. State class uses persistent topic with earliest consumer positioning.
|
||||
- **RabbitMQ**: Topicspace maps to vhost. Class determines queue durability and TTL policy. State class uses a last-value queue (via plugin) or a fanout exchange pattern where each consumer gets the retained state on connect.
|
||||
- **Kafka**: `flow.tg.text-completion-request` as topic name. Class determines retention and compaction policy. State class maps to a compacted topic (last value per key).
|
||||
|
||||
**Why this works:**
|
||||
|
||||
- The class enum is small and stable — adding a new class is rare and deliberate
|
||||
- Queue properties (persistence, TTL) are implied by class, not encoded in the name
|
||||
- Dynamic registration works naturally — the config service publishes `flow:tg:text-completion-request` and the backend knows how to declare it from the `flow` class
|
||||
- The colon separator is unambiguous, easy to split, doesn't conflict with URIs or path separators that backends use internally
|
||||
- No pretence of being generic — this is a TrustGraph convention, and that's fine
|
||||
|
||||
### Serialization boundary
|
||||
|
||||
**Decision: the backend owns the wire format.**
|
||||
|
||||
The contract between the Consumer/Producer layer and the backend is dataclass objects in, dataclass objects out:
|
||||
|
||||
- `send()` accepts a dataclass instance and properties dict
|
||||
- `receive()` returns a message whose `value()` is a dataclass instance
|
||||
|
||||
What happens on the wire is the backend's concern. The Pulsar backend uses JSON (via `dataclass_to_dict` / `dict_to_dataclass`). A RabbitMQ backend would likely also use JSON. A future backend could use Protobuf, MessagePack, or Avro if the broker benefits from it.
|
||||
|
||||
The serialization helpers stay inside the backend that uses them — they are not shared infrastructure. Each backend brings its own serialization strategy. The Consumer/Producer layer never thinks about wire format.
|
||||
|
||||
### Gateway service
|
||||
|
||||
**Decision: the gateway uses the backend abstraction like any other component.**
|
||||
|
||||
The gateway currently bridges WebSocket/REST to Pulsar directly, bypassing the abstraction layer. It translates incoming API JSON to Pulsar schema objects, sends them, receives responses as Pulsar schema objects, and translates back to API JSON. Since the wire format is JSON in both directions, this is effectively a no-op round trip through the schema machinery.
|
||||
|
||||
With the backend abstraction, the gateway follows the same pattern as every other component:
|
||||
|
||||
1. Incoming API JSON → translator `decode()` → dataclass
|
||||
2. Dataclass → backend `send()` (backend handles wire format)
|
||||
3. Backend `receive()` → dataclass
|
||||
4. Dataclass → translator `encode()` → API JSON → WebSocket/REST client
|
||||
|
||||
This is architecturally simple — one code path, no special cases. The gateway depends on the schema dataclasses and the translator layer, which it already does. The overhead of deserialize-then-reserialize is negligible for the message sizes involved. And it keeps all options open — if a future backend uses a non-JSON wire format, the gateway still works without changes.
|
||||
|
||||
## Implementation Order
|
||||
|
||||
### Phase 1: Rename translators
|
||||
|
||||
Rename `to_pulsar()` → `decode()`, `from_pulsar()` → `encode()` across all translator classes and call sites. Remove `from pulsar.schema import Record` from the translator base class. Mechanical find-and-replace, no behavioural changes.
|
||||
|
||||
### Phase 2: Queue naming
|
||||
|
||||
Replace the `topic()` helper with the CLASS:TOPICSPACE:TOPIC scheme. Update all queue definitions in `schema/services/*.py` and `schema/knowledge/*.py`. Update `PulsarBackend.map_topic()` to parse the new format. Verify all existing functionality still works with Pulsar.
|
||||
|
||||
### Phase 3: Clean up Pulsar leaks
|
||||
|
||||
Work through the concrete cleanups list: remove dead imports, delete the legacy `PulsarClient` class, migrate the client libraries and gateway to use the backend abstraction. After this phase, `pulsar` imports exist only in `pulsar_backend.py`.
|
||||
|
||||
### Phase 4: RabbitMQ backend
|
||||
|
||||
Implement `rabbitmq_backend.py` against the existing `PubSubBackend` Protocol. Map queue classes to RabbitMQ concepts: `flow` → durable queues, `request`/`response` → non-durable queues with TTL, `state` → RabbitMQ streams. Add `rabbitmq` as a backend option in the factory. Test end-to-end with `PUBSUB_BACKEND=rabbitmq`.
|
||||
|
||||
Phases 1-3 are safe to do on main — they don't change behaviour, just clean up. Phase 4 is additive — it adds a new backend without touching the existing one.
|
||||
|
||||
### Config distribution on RabbitMQ
|
||||
|
||||
The `state` queue class needs "start from earliest" semantics — a newly started processor must receive the current configuration state.
|
||||
|
||||
RabbitMQ Streams (available since 3.9) solve this directly. Streams are persistent, append-only logs that support consumer offset positioning. The RabbitMQ backend maps the `state` class to a stream, and consumers attach with offset `first` to read from the beginning, or `last` to read the most recent entry plus future updates.
|
||||
|
||||
Since config pushes are full state snapshots (not deltas), a consumer only needs the most recent entry. The RabbitMQ backend can use `last` offset positioning for `state` class consumers, which delivers the last message in the stream followed by any new messages. This matches the current behaviour where processors read config on startup and then react to updates.
|
||||
|
||||
|
|
@ -63,7 +63,11 @@ Explainability events stream to client as the query executes:
|
|||
3. Edges selected with reasoning → event emitted
|
||||
4. Answer synthesized → event emitted
|
||||
|
||||
Client receives `explain_id` and `explain_collection` to fetch full details.
|
||||
Client receives `explain_id`, `explain_graph`, and `explain_triples` inline
|
||||
in each explain message. The triples contain the full provenance data for
|
||||
that step — no follow-up graph query needed. The `explain_id` serves as
|
||||
the root entity URI within the triples. Data is also written to the
|
||||
knowledge graph for later audit/analysis.
|
||||
|
||||
## URI Structure
|
||||
|
||||
|
|
@ -144,7 +148,8 @@ class GraphRagResponse:
|
|||
response: str = ""
|
||||
end_of_stream: bool = False
|
||||
explain_id: str | None = None
|
||||
explain_collection: str | None = None
|
||||
explain_graph: str | None = None
|
||||
explain_triples: list[Triple] = field(default_factory=list)
|
||||
message_type: str = "" # "chunk" or "explain"
|
||||
end_of_session: bool = False
|
||||
```
|
||||
|
|
@ -154,7 +159,7 @@ class GraphRagResponse:
|
|||
| message_type | Purpose |
|
||||
|--------------|---------|
|
||||
| `chunk` | Response text (streaming or final) |
|
||||
| `explain` | Explainability event with IRI reference |
|
||||
| `explain` | Explainability event with inline provenance triples |
|
||||
|
||||
### Session Lifecycle
|
||||
|
||||
|
|
|
|||
268
docs/tech-specs/sparql-query.md
Normal file
268
docs/tech-specs/sparql-query.md
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
# SPARQL Query Service Technical Specification
|
||||
|
||||
## Overview
|
||||
|
||||
A pub/sub-hosted SPARQL query service that accepts SPARQL queries, decomposes
|
||||
them into triple pattern lookups via the existing triples query pub/sub
|
||||
interface, performs in-memory joins/filters/projections, and returns SPARQL
|
||||
result bindings.
|
||||
|
||||
This makes the triple store queryable using a standard graph query language
|
||||
without coupling to any specific backend (Neo4j, Cassandra, FalkorDB, etc.).
|
||||
|
||||
## Goals
|
||||
|
||||
- **SPARQL 1.1 support**: SELECT, ASK, CONSTRUCT, DESCRIBE queries
|
||||
- **Backend-agnostic**: query via the pub/sub triples interface, not direct
|
||||
database access
|
||||
- **Standard service pattern**: FlowProcessor with ConsumerSpec/ProducerSpec,
|
||||
using TriplesClientSpec to call the triples query service
|
||||
- **Correct SPARQL semantics**: proper BGP evaluation, joins, OPTIONAL, UNION,
|
||||
FILTER, BIND, aggregation, solution modifiers (ORDER BY, LIMIT, OFFSET,
|
||||
DISTINCT)
|
||||
|
||||
## Background
|
||||
|
||||
The triples query service provides a single-pattern lookup: given optional
|
||||
(s, p, o) values, return matching triples. This is the equivalent of one
|
||||
triple pattern in a SPARQL Basic Graph Pattern.
|
||||
|
||||
To evaluate a full SPARQL query, we need to:
|
||||
1. Parse the SPARQL string into an algebra tree
|
||||
2. Walk the algebra tree, issuing triple pattern lookups for each BGP pattern
|
||||
3. Join results across patterns (nested-loop or hash join)
|
||||
4. Apply filters, optionals, unions, and aggregations in-memory
|
||||
5. Project and return the requested variables
|
||||
|
||||
rdflib (already a dependency) provides a SPARQL 1.1 parser and algebra
|
||||
compiler. We use rdflib to parse queries into algebra trees, then evaluate
|
||||
the algebra ourselves using the triples query client as the data source.
|
||||
|
||||
## Technical Design
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
pub/sub
|
||||
[Client] ──request──> [SPARQL Query Service] ──triples-request──> [Triples Query Service]
|
||||
[Client] <─response── [SPARQL Query Service] <─triples-response── [Triples Query Service]
|
||||
```
|
||||
|
||||
The service is a FlowProcessor that:
|
||||
- Consumes SPARQL query requests
|
||||
- Uses TriplesClientSpec to issue triple pattern lookups
|
||||
- Evaluates the SPARQL algebra in-memory
|
||||
- Produces result responses
|
||||
|
||||
### Components
|
||||
|
||||
1. **SPARQL Query Service (FlowProcessor)**
|
||||
- ConsumerSpec for incoming SPARQL requests
|
||||
- ProducerSpec for outgoing results
|
||||
- TriplesClientSpec for calling the triples query service
|
||||
- Delegates parsing and evaluation to the components below
|
||||
|
||||
Module: `trustgraph-flow/trustgraph/query/sparql/service.py`
|
||||
|
||||
2. **SPARQL Parser (rdflib wrapper)**
|
||||
- Uses `rdflib.plugins.sparql.prepareQuery` / `parseQuery` and
|
||||
`rdflib.plugins.sparql.algebra.translateQuery` to produce an algebra tree
|
||||
- Extracts PREFIX declarations, query type (SELECT/ASK/CONSTRUCT/DESCRIBE),
|
||||
and the algebra root
|
||||
|
||||
Module: `trustgraph-flow/trustgraph/query/sparql/parser.py`
|
||||
|
||||
3. **Algebra Evaluator**
|
||||
- Recursive evaluator over the rdflib algebra tree
|
||||
- Each algebra node type maps to an evaluation function
|
||||
- BGP nodes issue triple pattern queries via TriplesClient
|
||||
- Join/Filter/Optional/Union etc. operate on in-memory solution sequences
|
||||
|
||||
Module: `trustgraph-flow/trustgraph/query/sparql/algebra.py`
|
||||
|
||||
4. **Solution Sequence**
|
||||
- A solution is a dict mapping variable names to Term values
|
||||
- Solution sequences are lists of solutions
|
||||
- Join: hash join on shared variables
|
||||
- LeftJoin (OPTIONAL): hash join preserving unmatched left rows
|
||||
- Union: concatenation
|
||||
- Filter: evaluate SPARQL expressions against each solution
|
||||
- Projection/Distinct/Order/Slice: standard post-processing
|
||||
|
||||
Module: `trustgraph-flow/trustgraph/query/sparql/solutions.py`
|
||||
|
||||
### Data Models
|
||||
|
||||
#### Request
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class SparqlQueryRequest:
|
||||
user: str = ""
|
||||
collection: str = ""
|
||||
query: str = "" # SPARQL query string
|
||||
limit: int = 10000 # Safety limit on results
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class SparqlQueryResponse:
|
||||
error: Error | None = None
|
||||
query_type: str = "" # "select", "ask", "construct", "describe"
|
||||
|
||||
# For SELECT queries
|
||||
variables: list[str] = field(default_factory=list)
|
||||
bindings: list[SparqlBinding] = field(default_factory=list)
|
||||
|
||||
# For ASK queries
|
||||
ask_result: bool = False
|
||||
|
||||
# For CONSTRUCT/DESCRIBE queries
|
||||
triples: list[Triple] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class SparqlBinding:
|
||||
values: list[Term | None] = field(default_factory=list)
|
||||
```
|
||||
|
||||
### BGP Evaluation Strategy
|
||||
|
||||
For each triple pattern in a BGP:
|
||||
- Extract bound terms (concrete IRIs/literals) and variables
|
||||
- Call `TriplesClient.query_stream(s, p, o)` with bound terms, None for
|
||||
variables
|
||||
- Map returned triples back to variable bindings
|
||||
|
||||
For multi-pattern BGPs, join solutions incrementally:
|
||||
- Order patterns by selectivity (patterns with more bound terms first)
|
||||
- For each subsequent pattern, substitute bound variables from the current
|
||||
solution sequence before querying
|
||||
- This avoids full cross-products and reduces the number of triples queries
|
||||
|
||||
### Streaming and Early Termination
|
||||
|
||||
The triples query service supports streaming responses (batched delivery via
|
||||
`TriplesClient.query_stream`). The SPARQL evaluator should use streaming
|
||||
from the start, not as an optimisation. This is important because:
|
||||
|
||||
- **Early termination**: when the SPARQL query has a LIMIT, or when only one
|
||||
solution is needed (ASK queries), we can stop consuming triples as soon as
|
||||
we have enough results. Without streaming, a wildcard pattern like
|
||||
`?s ?p ?o` would fetch the entire graph before we could apply the limit.
|
||||
- **Memory efficiency**: results are processed batch-by-batch rather than
|
||||
materialising the full result set in memory before joining.
|
||||
|
||||
The batch callback in `query_stream` returns a boolean to signal completion.
|
||||
The evaluator should signal completion (return True) as soon as sufficient
|
||||
solutions have been produced, allowing the underlying pub/sub consumer to
|
||||
stop pulling batches.
|
||||
|
||||
### Parallel BGP Execution (Phase 2 Optimisation)
|
||||
|
||||
Within a BGP, patterns that share variables benefit from sequential
|
||||
evaluation with bound-variable substitution (query results from earlier
|
||||
patterns narrow later queries). However, patterns with no shared variables
|
||||
are independent and could be issued concurrently via `asyncio.gather`.
|
||||
|
||||
A practical approach for a future optimisation pass:
|
||||
- Analyse BGP patterns and identify connected components (groups of
|
||||
patterns linked by shared variables)
|
||||
- Execute independent components in parallel
|
||||
- Within each component, evaluate patterns sequentially with substitution
|
||||
|
||||
This is not needed for correctness -- the sequential approach works for all
|
||||
cases -- but could significantly reduce latency for queries with independent
|
||||
pattern groups. Flagged as a phase 2 optimisation.
|
||||
|
||||
### FILTER Expression Evaluation
|
||||
|
||||
rdflib's algebra represents FILTER expressions as expression trees. We
|
||||
evaluate these against each solution row, supporting:
|
||||
- Comparison operators (=, !=, <, >, <=, >=)
|
||||
- Logical operators (&&, ||, !)
|
||||
- SPARQL built-in functions (isIRI, isLiteral, isBlank, str, lang,
|
||||
datatype, bound, regex, etc.)
|
||||
- Arithmetic operators (+, -, *, /)
|
||||
|
||||
## Implementation Order
|
||||
|
||||
1. **Schema and service skeleton** -- define SparqlQueryRequest/Response
|
||||
dataclasses, create the FlowProcessor subclass with ConsumerSpec,
|
||||
ProducerSpec, and TriplesClientSpec wired up. Verify it starts and
|
||||
connects.
|
||||
|
||||
2. **SPARQL parsing** -- wrap rdflib's parser to produce algebra trees from
|
||||
SPARQL strings. Handle parse errors gracefully. Unit test with a range of
|
||||
query shapes.
|
||||
|
||||
3. **BGP evaluation** -- implement single-pattern and multi-pattern BGP
|
||||
evaluation using TriplesClient. This is the core building block. Test
|
||||
with simple SELECT WHERE { ?s ?p ?o } queries.
|
||||
|
||||
4. **Joins and solution sequences** -- implement hash join, left join (for
|
||||
OPTIONAL), and union. Test with multi-pattern queries.
|
||||
|
||||
5. **FILTER evaluation** -- implement the expression evaluator for FILTER
|
||||
clauses. Start with comparisons and logical operators, then add built-in
|
||||
functions incrementally.
|
||||
|
||||
6. **Solution modifiers** -- DISTINCT, ORDER BY, LIMIT, OFFSET, projection.
|
||||
|
||||
7. **ASK / CONSTRUCT / DESCRIBE** -- extend beyond SELECT. ASK is trivial
|
||||
(non-empty result = true). CONSTRUCT builds triples from a template.
|
||||
DESCRIBE fetches all triples for matched resources.
|
||||
|
||||
8. **Aggregation** -- GROUP BY, HAVING, COUNT, SUM, AVG, MIN, MAX,
|
||||
GROUP_CONCAT, SAMPLE.
|
||||
|
||||
9. **BIND, VALUES, subqueries** -- remaining SPARQL 1.1 features.
|
||||
|
||||
10. **API gateway integration** -- add SparqlQueryRequestor dispatcher,
|
||||
request/response translators, and API endpoint so that the SPARQL
|
||||
service is accessible via the HTTP gateway.
|
||||
|
||||
11. **SDK support** -- add `sparql_query()` method to FlowInstance in the
|
||||
Python API SDK, following the same pattern as `triples_query()`.
|
||||
|
||||
12. **CLI command** -- add a `tg-sparql-query` CLI command that takes a
|
||||
SPARQL query string (or reads from a file/stdin), submits it via the
|
||||
SDK, and prints results in a readable format (table for SELECT,
|
||||
true/false for ASK, Turtle for CONSTRUCT/DESCRIBE).
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
In-memory join over pub/sub round-trips will be slower than native SPARQL on
|
||||
a graph database. Key mitigations:
|
||||
|
||||
- **Streaming with early termination**: use `query_stream` so that
|
||||
limit-bound queries don't fetch entire result sets. A `SELECT ... LIMIT 1`
|
||||
against a wildcard pattern fetches one batch, not the whole graph.
|
||||
- **Bound-variable substitution**: when evaluating BGP patterns sequentially,
|
||||
substitute known bindings into subsequent patterns to issue narrow queries
|
||||
rather than broad ones followed by in-memory filtering.
|
||||
- **Parallel independent patterns** (phase 2): patterns with no shared
|
||||
variables can be issued concurrently.
|
||||
- **Query complexity limits**: may need a cap on the number of triple pattern
|
||||
queries issued per SPARQL query to prevent runaway evaluation.
|
||||
|
||||
### Named Graph Mapping
|
||||
|
||||
SPARQL's `GRAPH ?g { ... }` and `GRAPH <uri> { ... }` clauses map to the
|
||||
triples query service's graph filter parameter:
|
||||
|
||||
- `GRAPH <uri> { ?s ?p ?o }` — pass `g=uri` to the triples query
|
||||
- Patterns outside any GRAPH clause — pass `g=""` (default graph only)
|
||||
- `GRAPH ?g { ?s ?p ?o }` — pass `g="*"` (all graphs), then bind `?g` from
|
||||
the returned triple's graph field
|
||||
|
||||
The triples query interface does not support a wildcard graph natively in
|
||||
the SPARQL sense, but `g="*"` (all graphs) combined with client-side
|
||||
filtering on the returned graph values achieves the same effect.
|
||||
|
||||
## Open Questions
|
||||
|
||||
- **SPARQL 1.2**: rdflib's parser support for 1.2 features (property paths
|
||||
are already in 1.1; 1.2 adds lateral joins, ADJUST, etc.). Start with
|
||||
1.1 and extend as rdflib support matures.
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -10,6 +10,7 @@ properties:
|
|||
- observation
|
||||
- answer
|
||||
- final-answer
|
||||
- explain
|
||||
- error
|
||||
example: answer
|
||||
content:
|
||||
|
|
@ -29,6 +30,11 @@ properties:
|
|||
type: string
|
||||
description: Named graph containing the explainability data
|
||||
example: urn:graph:retrieval
|
||||
explain_triples:
|
||||
type: array
|
||||
description: Provenance triples for this explain event (inline, no follow-up query needed)
|
||||
items:
|
||||
$ref: '../common/Triple.yaml'
|
||||
end-of-message:
|
||||
type: boolean
|
||||
description: Current chunk type is complete (streaming mode)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@ description: |
|
|||
Librarian service request for document library management.
|
||||
|
||||
Operations: add-document, remove-document, list-documents,
|
||||
get-document-metadata, stream-document, add-child-document,
|
||||
list-children, begin-upload, upload-chunk, complete-upload,
|
||||
abort-upload, get-upload-status, list-uploads,
|
||||
start-processing, stop-processing, list-processing
|
||||
required:
|
||||
- operation
|
||||
|
|
@ -13,6 +16,17 @@ properties:
|
|||
- add-document
|
||||
- remove-document
|
||||
- list-documents
|
||||
- get-document-metadata
|
||||
- get-document-content
|
||||
- stream-document
|
||||
- add-child-document
|
||||
- list-children
|
||||
- begin-upload
|
||||
- upload-chunk
|
||||
- complete-upload
|
||||
- abort-upload
|
||||
- get-upload-status
|
||||
- list-uploads
|
||||
- start-processing
|
||||
- stop-processing
|
||||
- list-processing
|
||||
|
|
@ -21,6 +35,21 @@ properties:
|
|||
- `add-document`: Add document to library
|
||||
- `remove-document`: Remove document from library
|
||||
- `list-documents`: List documents in library
|
||||
- `get-document-metadata`: Get document metadata
|
||||
- `get-document-content`: Get full document content in a single response.
|
||||
**Deprecated** — use `stream-document` instead. Fails for documents
|
||||
exceeding the broker's max message size.
|
||||
- `stream-document`: Stream document content in chunks. Each response
|
||||
includes `chunk_index` and `is_final`. Preferred over `get-document-content`
|
||||
for all document sizes.
|
||||
- `add-child-document`: Add a child document (e.g. page, chunk)
|
||||
- `list-children`: List child documents of a parent
|
||||
- `begin-upload`: Start a chunked upload session
|
||||
- `upload-chunk`: Upload a chunk of data
|
||||
- `complete-upload`: Finalize a chunked upload
|
||||
- `abort-upload`: Cancel a chunked upload
|
||||
- `get-upload-status`: Check upload progress
|
||||
- `list-uploads`: List active upload sessions
|
||||
- `start-processing`: Start processing library documents
|
||||
- `stop-processing`: Stop library processing
|
||||
- `list-processing`: List processing status
|
||||
|
|
|
|||
|
|
@ -8,8 +8,7 @@ required:
|
|||
properties:
|
||||
text:
|
||||
type: string
|
||||
description: Text content (base64 encoded)
|
||||
format: byte
|
||||
description: Text content, either raw text or base64 encoded for compatibility with older clients
|
||||
example: VGhpcyBpcyB0aGUgZG9jdW1lbnQgdGV4dC4uLg==
|
||||
id:
|
||||
type: string
|
||||
|
|
|
|||
|
|
@ -18,6 +18,11 @@ properties:
|
|||
type: string
|
||||
description: Named graph containing the explainability data
|
||||
example: urn:graph:retrieval
|
||||
explain_triples:
|
||||
type: array
|
||||
description: Provenance triples for this explain event (inline, no follow-up query needed)
|
||||
items:
|
||||
$ref: '../common/Triple.yaml'
|
||||
end-of-stream:
|
||||
type: boolean
|
||||
description: Indicates LLM response stream is complete
|
||||
|
|
|
|||
|
|
@ -18,6 +18,11 @@ properties:
|
|||
type: string
|
||||
description: Named graph containing the explainability data
|
||||
example: urn:graph:retrieval
|
||||
explain_triples:
|
||||
type: array
|
||||
description: Provenance triples for this explain event (inline, no follow-up query needed)
|
||||
items:
|
||||
$ref: '../common/Triple.yaml'
|
||||
end_of_stream:
|
||||
type: boolean
|
||||
description: Indicates LLM response stream is complete
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ openapi: 3.1.0
|
|||
|
||||
info:
|
||||
title: TrustGraph API Gateway
|
||||
version: "2.1"
|
||||
version: "2.2"
|
||||
description: |
|
||||
REST API for TrustGraph - an AI-powered knowledge graph and RAG system.
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ info:
|
|||
Require running flow instance, accessed via `/api/v1/flow/{flow}/service/{kind}`:
|
||||
- AI services: agent, text-completion, prompt, RAG (document/graph)
|
||||
- Embeddings: embeddings, graph-embeddings, document-embeddings
|
||||
- Query: triples, rows, nlp-query, structured-query, row-embeddings
|
||||
- Query: triples, rows, nlp-query, structured-query, sparql-query, row-embeddings
|
||||
- Data loading: text-load, document-load
|
||||
- Utilities: mcp-tool, structured-diag
|
||||
|
||||
|
|
@ -139,6 +139,8 @@ paths:
|
|||
$ref: './paths/flow/text-load.yaml'
|
||||
/api/v1/flow/{flow}/service/document-load:
|
||||
$ref: './paths/flow/document-load.yaml'
|
||||
/api/v1/flow/{flow}/service/sparql-query:
|
||||
$ref: './paths/flow/sparql-query.yaml'
|
||||
|
||||
# Document streaming
|
||||
/api/v1/document-stream:
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ post:
|
|||
- `action`: Action being taken
|
||||
- `observation`: Result from action
|
||||
- `answer`: Final response to user
|
||||
- `explain`: Provenance event with inline triples (`explain_triples`)
|
||||
- `error`: Error occurred
|
||||
|
||||
Each chunk may have multiple messages. Check flags:
|
||||
|
|
@ -116,6 +117,22 @@ post:
|
|||
content: ""
|
||||
end-of-message: true
|
||||
end-of-dialog: true
|
||||
explainEvent:
|
||||
summary: Explain event with inline provenance triples
|
||||
value:
|
||||
chunk-type: explain
|
||||
content: ""
|
||||
explain_id: urn:trustgraph:agent:abc123
|
||||
explain_graph: urn:graph:retrieval
|
||||
explain_triples:
|
||||
- s: {t: i, i: "urn:trustgraph:agent:abc123"}
|
||||
p: {t: i, i: "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}
|
||||
o: {t: i, i: "https://trustgraph.ai/ns/AgentSession"}
|
||||
- s: {t: i, i: "urn:trustgraph:agent:abc123"}
|
||||
p: {t: i, i: "https://trustgraph.ai/ns/query"}
|
||||
o: {t: l, v: "Explain quantum computing"}
|
||||
end-of-message: true
|
||||
end-of-dialog: false
|
||||
legacyResponse:
|
||||
summary: Legacy non-streaming response
|
||||
value:
|
||||
|
|
|
|||
|
|
@ -24,8 +24,13 @@ post:
|
|||
## Streaming
|
||||
|
||||
Enable `streaming: true` to receive the answer as it's generated:
|
||||
- Multiple messages with `response` content
|
||||
- Multiple `chunk` messages with `response` content
|
||||
- `explain` messages with inline provenance triples (`explain_triples`)
|
||||
- Final message with `end-of-stream: true`
|
||||
- Session ends with `end_of_session: true`
|
||||
|
||||
Explain events carry `explain_id`, `explain_graph`, and `explain_triples`
|
||||
inline in the stream, so no follow-up knowledge graph query is needed.
|
||||
|
||||
Without streaming, returns complete answer in single response.
|
||||
|
||||
|
|
@ -96,6 +101,21 @@ post:
|
|||
value:
|
||||
response: "The research papers present three"
|
||||
end-of-stream: false
|
||||
explainEvent:
|
||||
summary: Explain event with inline provenance triples
|
||||
value:
|
||||
message_type: explain
|
||||
explain_id: urn:trustgraph:question:abc123
|
||||
explain_graph: urn:graph:retrieval
|
||||
explain_triples:
|
||||
- s: {t: i, i: "urn:trustgraph:question:abc123"}
|
||||
p: {t: i, i: "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}
|
||||
o: {t: i, i: "https://trustgraph.ai/ns/DocumentRagQuestion"}
|
||||
- s: {t: i, i: "urn:trustgraph:question:abc123"}
|
||||
p: {t: i, i: "https://trustgraph.ai/ns/query"}
|
||||
o: {t: l, v: "What are the key findings in the research papers?"}
|
||||
end-of-stream: false
|
||||
end_of_session: false
|
||||
streamingComplete:
|
||||
summary: Streaming complete marker
|
||||
value:
|
||||
|
|
|
|||
|
|
@ -25,8 +25,13 @@ post:
|
|||
## Streaming
|
||||
|
||||
Enable `streaming: true` to receive the answer as it's generated:
|
||||
- Multiple messages with `response` content
|
||||
- Multiple `chunk` messages with `response` content
|
||||
- `explain` messages with inline provenance triples (`explain_triples`)
|
||||
- Final message with `end-of-stream: true`
|
||||
- Session ends with `end_of_session: true`
|
||||
|
||||
Explain events carry `explain_id`, `explain_graph`, and `explain_triples`
|
||||
inline in the stream, so no follow-up knowledge graph query is needed.
|
||||
|
||||
Without streaming, returns complete answer in single response.
|
||||
|
||||
|
|
@ -116,6 +121,21 @@ post:
|
|||
value:
|
||||
response: "Quantum physics and computer science intersect"
|
||||
end-of-stream: false
|
||||
explainEvent:
|
||||
summary: Explain event with inline provenance triples
|
||||
value:
|
||||
message_type: explain
|
||||
explain_id: urn:trustgraph:question:abc123
|
||||
explain_graph: urn:graph:retrieval
|
||||
explain_triples:
|
||||
- s: {t: i, i: "urn:trustgraph:question:abc123"}
|
||||
p: {t: i, i: "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}
|
||||
o: {t: i, i: "https://trustgraph.ai/ns/GraphRagQuestion"}
|
||||
- s: {t: i, i: "urn:trustgraph:question:abc123"}
|
||||
p: {t: i, i: "https://trustgraph.ai/ns/query"}
|
||||
o: {t: l, v: "What connections exist between quantum physics and computer science?"}
|
||||
end_of_stream: false
|
||||
end_of_session: false
|
||||
streamingComplete:
|
||||
summary: Streaming complete marker
|
||||
value:
|
||||
|
|
|
|||
145
specs/api/paths/flow/sparql-query.yaml
Normal file
145
specs/api/paths/flow/sparql-query.yaml
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
post:
|
||||
tags:
|
||||
- Flow Services
|
||||
summary: SPARQL query - execute SPARQL 1.1 queries against the knowledge graph
|
||||
description: |
|
||||
Execute a SPARQL 1.1 query against the knowledge graph.
|
||||
|
||||
## Supported Query Types
|
||||
|
||||
- **SELECT**: Returns variable bindings as a table of results
|
||||
- **ASK**: Returns true/false for existence checks
|
||||
- **CONSTRUCT**: Returns a set of triples built from a template
|
||||
- **DESCRIBE**: Returns triples describing matched resources
|
||||
|
||||
## SPARQL Features
|
||||
|
||||
Supports standard SPARQL 1.1 features including:
|
||||
- Basic Graph Patterns (BGPs) with triple pattern matching
|
||||
- OPTIONAL, UNION, FILTER
|
||||
- BIND, VALUES
|
||||
- ORDER BY, LIMIT, OFFSET, DISTINCT
|
||||
- GROUP BY with aggregates (COUNT, SUM, AVG, MIN, MAX, GROUP_CONCAT)
|
||||
- Built-in functions (isIRI, STR, REGEX, CONTAINS, etc.)
|
||||
|
||||
## Query Examples
|
||||
|
||||
Find all entities of a type:
|
||||
```sparql
|
||||
SELECT ?s ?label WHERE {
|
||||
?s <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> <http://example.com/Person> .
|
||||
?s <http://www.w3.org/2000/01/rdf-schema#label> ?label .
|
||||
}
|
||||
LIMIT 10
|
||||
```
|
||||
|
||||
Check if an entity exists:
|
||||
```sparql
|
||||
ASK { <http://example.com/alice> ?p ?o }
|
||||
```
|
||||
|
||||
operationId: sparqlQueryService
|
||||
security:
|
||||
- bearerAuth: []
|
||||
parameters:
|
||||
- name: flow
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: Flow instance ID
|
||||
example: my-flow
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required:
|
||||
- query
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: SPARQL 1.1 query string
|
||||
user:
|
||||
type: string
|
||||
default: trustgraph
|
||||
description: User/keyspace identifier
|
||||
collection:
|
||||
type: string
|
||||
default: default
|
||||
description: Collection identifier
|
||||
limit:
|
||||
type: integer
|
||||
default: 10000
|
||||
description: Safety limit on number of results
|
||||
examples:
|
||||
selectQuery:
|
||||
summary: SELECT query
|
||||
value:
|
||||
query: "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10"
|
||||
user: trustgraph
|
||||
collection: default
|
||||
askQuery:
|
||||
summary: ASK query
|
||||
value:
|
||||
query: "ASK { <http://example.com/alice> ?p ?o }"
|
||||
responses:
|
||||
'200':
|
||||
description: Successful response
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
query-type:
|
||||
type: string
|
||||
enum: [select, ask, construct, describe]
|
||||
variables:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: Variable names (SELECT only)
|
||||
bindings:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
values:
|
||||
type: array
|
||||
items:
|
||||
$ref: '../../components/schemas/common/RdfValue.yaml'
|
||||
description: Result rows (SELECT only)
|
||||
ask-result:
|
||||
type: boolean
|
||||
description: Boolean result (ASK only)
|
||||
triples:
|
||||
type: array
|
||||
description: Result triples (CONSTRUCT/DESCRIBE only)
|
||||
error:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
message:
|
||||
type: string
|
||||
examples:
|
||||
selectResult:
|
||||
summary: SELECT result
|
||||
value:
|
||||
query-type: select
|
||||
variables: [s, p, o]
|
||||
bindings:
|
||||
- values:
|
||||
- {t: i, i: "http://example.com/alice"}
|
||||
- {t: i, i: "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}
|
||||
- {t: i, i: "http://example.com/Person"}
|
||||
askResult:
|
||||
summary: ASK result
|
||||
value:
|
||||
query-type: ask
|
||||
ask-result: true
|
||||
'401':
|
||||
$ref: '../../components/responses/Unauthorized.yaml'
|
||||
'500':
|
||||
$ref: '../../components/responses/Error.yaml'
|
||||
|
|
@ -8,7 +8,7 @@ post:
|
|||
## Text Load Overview
|
||||
|
||||
Fire-and-forget document loading:
|
||||
- **Input**: Text content (base64 encoded)
|
||||
- **Input**: Text content (raw UTF-8 or base64 encoded)
|
||||
- **Process**: Chunk, embed, store
|
||||
- **Output**: None (202 Accepted)
|
||||
|
||||
|
|
@ -26,7 +26,14 @@ post:
|
|||
|
||||
## Text Format
|
||||
|
||||
Text must be base64 encoded:
|
||||
Text may be sent as raw UTF-8 text:
|
||||
```
|
||||
{
|
||||
"text": "Cancer survival: 2.74× higher hazard ratio"
|
||||
}
|
||||
```
|
||||
|
||||
Older clients may still send base64 encoded text:
|
||||
```
|
||||
text_content = "This is the document..."
|
||||
encoded = base64.b64encode(text_content.encode('utf-8'))
|
||||
|
|
@ -78,12 +85,12 @@ post:
|
|||
simpleLoad:
|
||||
summary: Load text document
|
||||
value:
|
||||
text: VGhpcyBpcyB0aGUgZG9jdW1lbnQgdGV4dC4uLg==
|
||||
text: This is the document text...
|
||||
id: doc-123
|
||||
user: alice
|
||||
collection: research
|
||||
withMetadata:
|
||||
summary: Load with RDF metadata
|
||||
summary: Load with RDF metadata using base64 text
|
||||
value:
|
||||
text: UXVhbnR1bSBjb21wdXRpbmcgdXNlcyBxdWFudHVtIG1lY2hhbmljcyBwcmluY2lwbGVzLi4u
|
||||
id: doc-456
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ asyncapi: 3.0.0
|
|||
|
||||
info:
|
||||
title: TrustGraph WebSocket API
|
||||
version: "2.1"
|
||||
version: "2.2"
|
||||
description: |
|
||||
WebSocket API for TrustGraph - providing multiplexed, asynchronous access to all services.
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ info:
|
|||
**Flow-Hosted Services** (require `flow` parameter):
|
||||
- agent, text-completion, prompt, document-rag, graph-rag
|
||||
- embeddings, graph-embeddings, document-embeddings
|
||||
- triples, rows, nlp-query, structured-query, structured-diag, row-embeddings
|
||||
- triples, rows, nlp-query, structured-query, sparql-query, structured-diag, row-embeddings
|
||||
- text-load, document-load, mcp-tool
|
||||
|
||||
## Schema Reuse
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ payload:
|
|||
- $ref: './requests/RowEmbeddingsRequest.yaml'
|
||||
- $ref: './requests/TextLoadRequest.yaml'
|
||||
- $ref: './requests/DocumentLoadRequest.yaml'
|
||||
- $ref: './requests/SparqlQueryRequest.yaml'
|
||||
|
||||
examples:
|
||||
- name: Config service request
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
type: object
|
||||
description: WebSocket request for sparql-query service (flow-hosted service)
|
||||
required:
|
||||
- id
|
||||
- service
|
||||
- flow
|
||||
- request
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
description: Unique request identifier
|
||||
service:
|
||||
type: string
|
||||
const: sparql-query
|
||||
description: Service identifier for sparql-query service
|
||||
flow:
|
||||
type: string
|
||||
description: Flow ID
|
||||
request:
|
||||
type: object
|
||||
required:
|
||||
- query
|
||||
properties:
|
||||
query:
|
||||
type: string
|
||||
description: SPARQL 1.1 query string
|
||||
user:
|
||||
type: string
|
||||
default: trustgraph
|
||||
description: User/keyspace identifier
|
||||
collection:
|
||||
type: string
|
||||
default: default
|
||||
description: Collection identifier
|
||||
limit:
|
||||
type: integer
|
||||
default: 10000
|
||||
description: Safety limit on number of results
|
||||
examples:
|
||||
- id: req-1
|
||||
service: sparql-query
|
||||
flow: my-flow
|
||||
request:
|
||||
query: "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10"
|
||||
user: trustgraph
|
||||
collection: default
|
||||
|
|
@ -87,10 +87,11 @@ def sample_message_data():
|
|||
"history": []
|
||||
},
|
||||
"AgentResponse": {
|
||||
"answer": "Machine learning is a subset of AI.",
|
||||
"chunk_type": "answer",
|
||||
"content": "Machine learning is a subset of AI.",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": True,
|
||||
"error": None,
|
||||
"thought": "I need to provide information about machine learning.",
|
||||
"observation": None
|
||||
},
|
||||
"Metadata": {
|
||||
"id": "test-doc-123",
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
assert request.user == "test_user"
|
||||
assert request.collection == "test_collection"
|
||||
|
||||
def test_request_translator_to_pulsar(self):
|
||||
def test_request_translator_decode(self):
|
||||
"""Test request translator converts dict to Pulsar schema"""
|
||||
translator = DocumentEmbeddingsRequestTranslator()
|
||||
|
||||
|
|
@ -49,7 +49,7 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
"collection": "custom_collection"
|
||||
}
|
||||
|
||||
result = translator.to_pulsar(data)
|
||||
result = translator.decode(data)
|
||||
|
||||
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||
assert result.vector == [0.1, 0.2, 0.3, 0.4]
|
||||
|
|
@ -57,7 +57,7 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
assert result.user == "custom_user"
|
||||
assert result.collection == "custom_collection"
|
||||
|
||||
def test_request_translator_to_pulsar_with_defaults(self):
|
||||
def test_request_translator_decode_with_defaults(self):
|
||||
"""Test request translator uses correct defaults"""
|
||||
translator = DocumentEmbeddingsRequestTranslator()
|
||||
|
||||
|
|
@ -66,7 +66,7 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
# No limit, user, or collection provided
|
||||
}
|
||||
|
||||
result = translator.to_pulsar(data)
|
||||
result = translator.decode(data)
|
||||
|
||||
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||
assert result.vector == [0.1, 0.2]
|
||||
|
|
@ -74,7 +74,7 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
assert result.user == "trustgraph" # Default
|
||||
assert result.collection == "default" # Default
|
||||
|
||||
def test_request_translator_from_pulsar(self):
|
||||
def test_request_translator_encode(self):
|
||||
"""Test request translator converts Pulsar schema to dict"""
|
||||
translator = DocumentEmbeddingsRequestTranslator()
|
||||
|
||||
|
|
@ -85,7 +85,7 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
collection="test_collection"
|
||||
)
|
||||
|
||||
result = translator.from_pulsar(request)
|
||||
result = translator.encode(request)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["vector"] == [0.5, 0.6]
|
||||
|
|
@ -134,7 +134,7 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
assert response.error == error
|
||||
assert response.chunks == []
|
||||
|
||||
def test_response_translator_from_pulsar_with_chunks(self):
|
||||
def test_response_translator_encode_with_chunks(self):
|
||||
"""Test response translator converts Pulsar schema with chunks to dict"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
|
|
@ -147,7 +147,7 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
]
|
||||
)
|
||||
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "chunks" in result
|
||||
|
|
@ -155,7 +155,7 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
assert result["chunks"][0]["chunk_id"] == "doc1/c1"
|
||||
assert result["chunks"][0]["score"] == 0.95
|
||||
|
||||
def test_response_translator_from_pulsar_with_empty_chunks(self):
|
||||
def test_response_translator_encode_with_empty_chunks(self):
|
||||
"""Test response translator handles empty chunks list"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
|
|
@ -164,25 +164,25 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
chunks=[]
|
||||
)
|
||||
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "chunks" in result
|
||||
assert result["chunks"] == []
|
||||
|
||||
def test_response_translator_from_pulsar_with_none_chunks(self):
|
||||
def test_response_translator_encode_with_none_chunks(self):
|
||||
"""Test response translator handles None chunks"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
response = MagicMock()
|
||||
response.chunks = None
|
||||
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "chunks" not in result or result.get("chunks") is None
|
||||
|
||||
def test_response_translator_from_response_with_completion(self):
|
||||
def test_response_translator_encode_with_completion(self):
|
||||
"""Test response translator with completion flag"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
|
|
@ -194,7 +194,7 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
]
|
||||
)
|
||||
|
||||
result, is_final = translator.from_response_with_completion(response)
|
||||
result, is_final = translator.encode_with_completion(response)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "chunks" in result
|
||||
|
|
@ -202,12 +202,12 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
assert result["chunks"][0]["chunk_id"] == "chunk1"
|
||||
assert is_final is True # Document embeddings responses are always final
|
||||
|
||||
def test_response_translator_to_pulsar_not_implemented(self):
|
||||
"""Test that to_pulsar raises NotImplementedError for responses"""
|
||||
def test_response_translator_decode_not_implemented(self):
|
||||
"""Test that decode raises NotImplementedError for responses"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]})
|
||||
translator.decode({"chunks": [{"chunk_id": "test", "score": 0.9}]})
|
||||
|
||||
|
||||
class TestDocumentEmbeddingsMessageCompatibility:
|
||||
|
|
@ -225,7 +225,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
|||
|
||||
# Convert to Pulsar request
|
||||
req_translator = DocumentEmbeddingsRequestTranslator()
|
||||
pulsar_request = req_translator.to_pulsar(request_data)
|
||||
pulsar_request = req_translator.decode(request_data)
|
||||
|
||||
# Simulate service processing and creating response
|
||||
response = DocumentEmbeddingsResponse(
|
||||
|
|
@ -238,7 +238,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
|||
|
||||
# Convert response back to dict
|
||||
resp_translator = DocumentEmbeddingsResponseTranslator()
|
||||
response_data = resp_translator.from_pulsar(response)
|
||||
response_data = resp_translator.encode(response)
|
||||
|
||||
# Verify data integrity
|
||||
assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
|
||||
|
|
@ -261,7 +261,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
|||
|
||||
# Convert response to dict
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
response_data = translator.from_pulsar(response)
|
||||
response_data = translator.encode(response)
|
||||
|
||||
# Verify error handling
|
||||
assert isinstance(response_data, dict)
|
||||
|
|
|
|||
|
|
@ -212,10 +212,11 @@ class TestAgentMessageContracts:
|
|||
|
||||
# Test required fields
|
||||
response = AgentResponse(**response_data)
|
||||
assert hasattr(response, 'answer')
|
||||
assert hasattr(response, 'chunk_type')
|
||||
assert hasattr(response, 'content')
|
||||
assert hasattr(response, 'end_of_message')
|
||||
assert hasattr(response, 'end_of_dialog')
|
||||
assert hasattr(response, 'error')
|
||||
assert hasattr(response, 'thought')
|
||||
assert hasattr(response, 'observation')
|
||||
|
||||
def test_agent_step_schema_contract(self):
|
||||
"""Test AgentStep schema contract"""
|
||||
|
|
|
|||
177
tests/contract/test_orchestrator_contracts.py
Normal file
177
tests/contract/test_orchestrator_contracts.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
"""
|
||||
Contract tests for orchestrator message schemas.
|
||||
|
||||
Verifies that AgentRequest/AgentStep with orchestration fields
|
||||
serialise and deserialise correctly through the Pulsar schema layer.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
|
||||
from trustgraph.schema import AgentRequest, AgentStep, PlanStep
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestOrchestrationFieldContracts:
|
||||
"""Contract tests for orchestration fields on AgentRequest."""
|
||||
|
||||
def test_agent_request_orchestration_fields_roundtrip(self):
|
||||
req = AgentRequest(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
correlation_id="corr-123",
|
||||
parent_session_id="parent-sess",
|
||||
subagent_goal="What is X?",
|
||||
expected_siblings=4,
|
||||
pattern="react",
|
||||
task_type="research",
|
||||
framing="Focus on accuracy",
|
||||
conversation_id="conv-456",
|
||||
)
|
||||
|
||||
assert req.correlation_id == "corr-123"
|
||||
assert req.parent_session_id == "parent-sess"
|
||||
assert req.subagent_goal == "What is X?"
|
||||
assert req.expected_siblings == 4
|
||||
assert req.pattern == "react"
|
||||
assert req.task_type == "research"
|
||||
assert req.framing == "Focus on accuracy"
|
||||
assert req.conversation_id == "conv-456"
|
||||
|
||||
def test_agent_request_orchestration_fields_default_empty(self):
|
||||
req = AgentRequest(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
)
|
||||
|
||||
assert req.correlation_id == ""
|
||||
assert req.parent_session_id == ""
|
||||
assert req.subagent_goal == ""
|
||||
assert req.expected_siblings == 0
|
||||
assert req.pattern == ""
|
||||
assert req.task_type == ""
|
||||
assert req.framing == ""
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestSubagentCompletionStepContract:
|
||||
"""Contract tests for subagent-completion step type."""
|
||||
|
||||
def test_subagent_completion_step_fields(self):
|
||||
step = AgentStep(
|
||||
thought="Subagent completed",
|
||||
action="complete",
|
||||
arguments={},
|
||||
observation="The answer text",
|
||||
step_type="subagent-completion",
|
||||
)
|
||||
|
||||
assert step.step_type == "subagent-completion"
|
||||
assert step.observation == "The answer text"
|
||||
assert step.thought == "Subagent completed"
|
||||
assert step.action == "complete"
|
||||
|
||||
def test_subagent_completion_in_request_history(self):
|
||||
step = AgentStep(
|
||||
thought="Subagent completed",
|
||||
action="complete",
|
||||
arguments={},
|
||||
observation="answer",
|
||||
step_type="subagent-completion",
|
||||
)
|
||||
req = AgentRequest(
|
||||
question="goal",
|
||||
user="testuser",
|
||||
correlation_id="corr-123",
|
||||
history=[step],
|
||||
)
|
||||
|
||||
assert len(req.history) == 1
|
||||
assert req.history[0].step_type == "subagent-completion"
|
||||
assert req.history[0].observation == "answer"
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestSynthesisStepContract:
|
||||
"""Contract tests for synthesis step type with subagent_results."""
|
||||
|
||||
def test_synthesis_step_with_results(self):
|
||||
results = {"goal-a": "answer-a", "goal-b": "answer-b"}
|
||||
step = AgentStep(
|
||||
thought="All subagents completed",
|
||||
action="aggregate",
|
||||
arguments={},
|
||||
observation=json.dumps(results),
|
||||
step_type="synthesise",
|
||||
subagent_results=results,
|
||||
)
|
||||
|
||||
assert step.step_type == "synthesise"
|
||||
assert step.subagent_results == results
|
||||
assert json.loads(step.observation) == results
|
||||
|
||||
def test_synthesis_request_matches_supervisor_expectations(self):
|
||||
"""The synthesis request built by the aggregator must be
|
||||
recognisable by SupervisorPattern._synthesise()."""
|
||||
results = {"goal-a": "answer-a", "goal-b": "answer-b"}
|
||||
step = AgentStep(
|
||||
thought="All subagents completed",
|
||||
action="aggregate",
|
||||
arguments={},
|
||||
observation=json.dumps(results),
|
||||
step_type="synthesise",
|
||||
subagent_results=results,
|
||||
)
|
||||
|
||||
req = AgentRequest(
|
||||
question="Original question",
|
||||
user="testuser",
|
||||
pattern="supervisor",
|
||||
correlation_id="",
|
||||
session_id="parent-sess",
|
||||
history=[step],
|
||||
)
|
||||
|
||||
# SupervisorPattern checks for step_type='synthesise' with
|
||||
# subagent_results
|
||||
has_results = bool(
|
||||
req.history
|
||||
and any(
|
||||
getattr(h, 'step_type', '') == 'synthesise'
|
||||
and getattr(h, 'subagent_results', None)
|
||||
for h in req.history
|
||||
)
|
||||
)
|
||||
assert has_results
|
||||
|
||||
# Pattern must be supervisor
|
||||
assert req.pattern == "supervisor"
|
||||
|
||||
# Correlation ID must be empty (not re-intercepted)
|
||||
assert req.correlation_id == ""
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestPlanStepContract:
|
||||
"""Contract tests for plan steps in history."""
|
||||
|
||||
def test_plan_step_in_history(self):
|
||||
plan = [
|
||||
PlanStep(goal="Step 1", tool_hint="knowledge-query",
|
||||
depends_on=[], status="completed", result="done"),
|
||||
PlanStep(goal="Step 2", tool_hint="",
|
||||
depends_on=[0], status="pending", result=""),
|
||||
]
|
||||
step = AgentStep(
|
||||
thought="Created plan",
|
||||
action="plan",
|
||||
step_type="plan",
|
||||
plan=plan,
|
||||
)
|
||||
|
||||
assert step.step_type == "plan"
|
||||
assert len(step.plan) == 2
|
||||
assert step.plan[0].goal == "Step 1"
|
||||
assert step.plan[0].status == "completed"
|
||||
assert step.plan[1].depends_on == [0]
|
||||
129
tests/contract/test_provenance_wire_format.py
Normal file
129
tests/contract/test_provenance_wire_format.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""
|
||||
Contract tests for provenance triple wire format — verifies that triples
|
||||
built by the provenance library can be parsed by the explainability API
|
||||
through the wire format conversion.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
|
||||
from trustgraph.provenance import (
|
||||
agent_decomposition_triples,
|
||||
agent_finding_triples,
|
||||
agent_plan_triples,
|
||||
agent_step_result_triples,
|
||||
agent_synthesis_triples,
|
||||
)
|
||||
|
||||
from trustgraph.api.explainability import (
|
||||
ExplainEntity,
|
||||
Decomposition,
|
||||
Finding,
|
||||
Plan,
|
||||
StepResult,
|
||||
Synthesis,
|
||||
wire_triples_to_tuples,
|
||||
)
|
||||
|
||||
|
||||
def _triples_to_wire(triples):
|
||||
"""Convert provenance Triple objects to the wire format dicts
|
||||
that the gateway/socket client would produce."""
|
||||
wire = []
|
||||
for t in triples:
|
||||
entry = {
|
||||
"s": _term_to_wire(t.s),
|
||||
"p": _term_to_wire(t.p),
|
||||
"o": _term_to_wire(t.o),
|
||||
}
|
||||
wire.append(entry)
|
||||
return wire
|
||||
|
||||
|
||||
def _term_to_wire(term):
|
||||
"""Convert a Term to wire format dict."""
|
||||
if term.type == IRI:
|
||||
return {"t": "i", "i": term.iri}
|
||||
elif term.type == LITERAL:
|
||||
return {"t": "l", "v": term.value}
|
||||
return {"t": "l", "v": str(term)}
|
||||
|
||||
|
||||
def _roundtrip(triples, uri):
|
||||
"""Convert triples through wire format and parse via from_triples."""
|
||||
wire = _triples_to_wire(triples)
|
||||
tuples = wire_triples_to_tuples(wire)
|
||||
return ExplainEntity.from_triples(uri, tuples)
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestDecompositionWireFormat:
|
||||
|
||||
def test_roundtrip(self):
|
||||
triples = agent_decomposition_triples(
|
||||
"urn:decompose", "urn:session",
|
||||
["What is X?", "What is Y?"],
|
||||
)
|
||||
entity = _roundtrip(triples, "urn:decompose")
|
||||
|
||||
assert isinstance(entity, Decomposition)
|
||||
assert set(entity.goals) == {"What is X?", "What is Y?"}
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestFindingWireFormat:
|
||||
|
||||
def test_roundtrip(self):
|
||||
triples = agent_finding_triples(
|
||||
"urn:finding", "urn:decompose", "What is X?",
|
||||
document_id="urn:doc/finding",
|
||||
)
|
||||
entity = _roundtrip(triples, "urn:finding")
|
||||
|
||||
assert isinstance(entity, Finding)
|
||||
assert entity.goal == "What is X?"
|
||||
assert entity.document == "urn:doc/finding"
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestPlanWireFormat:
|
||||
|
||||
def test_roundtrip(self):
|
||||
triples = agent_plan_triples(
|
||||
"urn:plan", "urn:session",
|
||||
["Step 1", "Step 2", "Step 3"],
|
||||
)
|
||||
entity = _roundtrip(triples, "urn:plan")
|
||||
|
||||
assert isinstance(entity, Plan)
|
||||
assert set(entity.steps) == {"Step 1", "Step 2", "Step 3"}
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestStepResultWireFormat:
|
||||
|
||||
def test_roundtrip(self):
|
||||
triples = agent_step_result_triples(
|
||||
"urn:step", "urn:plan", "Define X",
|
||||
document_id="urn:doc/step",
|
||||
)
|
||||
entity = _roundtrip(triples, "urn:step")
|
||||
|
||||
assert isinstance(entity, StepResult)
|
||||
assert entity.step == "Define X"
|
||||
assert entity.document == "urn:doc/step"
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestSynthesisWireFormat:
|
||||
|
||||
def test_roundtrip(self):
|
||||
triples = agent_synthesis_triples(
|
||||
"urn:synthesis", "urn:previous",
|
||||
document_id="urn:doc/synthesis",
|
||||
)
|
||||
entity = _roundtrip(triples, "urn:synthesis")
|
||||
|
||||
assert isinstance(entity, Synthesis)
|
||||
assert entity.document == "urn:doc/synthesis"
|
||||
|
|
@ -33,7 +33,7 @@ class TestRAGTranslatorCompletionFlags:
|
|||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True, "is_final must be True when end_of_session=True"
|
||||
|
|
@ -57,7 +57,7 @@ class TestRAGTranslatorCompletionFlags:
|
|||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "is_final must be False when end_of_session=False"
|
||||
|
|
@ -80,7 +80,7 @@ class TestRAGTranslatorCompletionFlags:
|
|||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False
|
||||
|
|
@ -103,7 +103,7 @@ class TestRAGTranslatorCompletionFlags:
|
|||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
|
||||
|
|
@ -125,7 +125,7 @@ class TestRAGTranslatorCompletionFlags:
|
|||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True, "is_final must be True when end_of_session=True"
|
||||
|
|
@ -147,7 +147,7 @@ class TestRAGTranslatorCompletionFlags:
|
|||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
|
||||
|
|
@ -168,7 +168,7 @@ class TestRAGTranslatorCompletionFlags:
|
|||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "is_final must be False when end_of_stream=False"
|
||||
|
|
@ -188,20 +188,18 @@ class TestAgentTranslatorCompletionFlags:
|
|||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("agent")
|
||||
response = AgentResponse(
|
||||
answer="4",
|
||||
error=None,
|
||||
thought=None,
|
||||
observation=None,
|
||||
chunk_type="answer",
|
||||
content="4",
|
||||
end_of_message=True,
|
||||
end_of_dialog=True
|
||||
end_of_dialog=True,
|
||||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True, "is_final must be True when end_of_dialog=True"
|
||||
assert response_dict["answer"] == "4"
|
||||
assert response_dict["content"] == "4"
|
||||
assert response_dict["end_of_dialog"] is True
|
||||
|
||||
def test_agent_translator_is_final_with_end_of_dialog_false(self):
|
||||
|
|
@ -212,44 +210,20 @@ class TestAgentTranslatorCompletionFlags:
|
|||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("agent")
|
||||
response = AgentResponse(
|
||||
answer=None,
|
||||
error=None,
|
||||
thought="I need to solve this.",
|
||||
observation=None,
|
||||
chunk_type="thought",
|
||||
content="I need to solve this.",
|
||||
end_of_message=True,
|
||||
end_of_dialog=False
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "is_final must be False when end_of_dialog=False"
|
||||
assert response_dict["thought"] == "I need to solve this."
|
||||
assert response_dict["content"] == "I need to solve this."
|
||||
assert response_dict["end_of_dialog"] is False
|
||||
|
||||
def test_agent_translator_is_final_fallback_with_answer(self):
|
||||
"""
|
||||
Test that AgentResponseTranslator returns is_final=True
|
||||
when answer is present (fallback for legacy responses).
|
||||
"""
|
||||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("agent")
|
||||
# Legacy response without end_of_dialog flag
|
||||
response = AgentResponse(
|
||||
answer="4",
|
||||
error=None,
|
||||
thought=None,
|
||||
observation=None
|
||||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True, "is_final must be True when answer is present (legacy fallback)"
|
||||
assert response_dict["answer"] == "4"
|
||||
|
||||
def test_agent_translator_intermediate_message_is_not_final(self):
|
||||
"""
|
||||
Test that intermediate messages (thought/observation) return is_final=False.
|
||||
|
|
@ -259,32 +233,28 @@ class TestAgentTranslatorCompletionFlags:
|
|||
|
||||
# Test thought message
|
||||
thought_response = AgentResponse(
|
||||
answer=None,
|
||||
error=None,
|
||||
thought="Processing...",
|
||||
observation=None,
|
||||
chunk_type="thought",
|
||||
content="Processing...",
|
||||
end_of_message=True,
|
||||
end_of_dialog=False
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
||||
# Act
|
||||
thought_dict, thought_is_final = translator.from_response_with_completion(thought_response)
|
||||
thought_dict, thought_is_final = translator.encode_with_completion(thought_response)
|
||||
|
||||
# Assert
|
||||
assert thought_is_final is False, "Thought message must not be final"
|
||||
|
||||
# Test observation message
|
||||
observation_response = AgentResponse(
|
||||
answer=None,
|
||||
error=None,
|
||||
thought=None,
|
||||
observation="Result found",
|
||||
chunk_type="observation",
|
||||
content="Result found",
|
||||
end_of_message=True,
|
||||
end_of_dialog=False
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
||||
# Act
|
||||
obs_dict, obs_is_final = translator.from_response_with_completion(observation_response)
|
||||
obs_dict, obs_is_final = translator.encode_with_completion(observation_response)
|
||||
|
||||
# Assert
|
||||
assert obs_is_final is False, "Observation message must not be final"
|
||||
|
|
@ -302,14 +272,10 @@ class TestAgentTranslatorCompletionFlags:
|
|||
content="",
|
||||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
answer=None,
|
||||
error=None,
|
||||
thought=None,
|
||||
observation=None
|
||||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
response_dict, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True, "Streaming format must use end_of_dialog for is_final"
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ Following the TEST_STRATEGY.md approach for integration testing.
|
|||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, ANY, patch
|
||||
|
||||
from trustgraph.agent.react.agent_manager import AgentManager
|
||||
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
|
||||
|
|
@ -187,7 +187,7 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
|
|||
|
||||
# Verify tool was executed
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="default")
|
||||
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="default", explain_callback=ANY, parent_uri=ANY)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
|
||||
|
|
@ -272,7 +272,7 @@ Args: {{
|
|||
|
||||
# Verify correct service was called
|
||||
if tool_name == "knowledge_query":
|
||||
mock_flow_context("graph-rag-request").rag.assert_called_with("test question", collection="default")
|
||||
mock_flow_context("graph-rag-request").rag.assert_called_with("test question", collection="default", explain_callback=ANY, parent_uri=ANY)
|
||||
elif tool_name == "text_completion":
|
||||
mock_flow_context("prompt-request").question.assert_called()
|
||||
|
||||
|
|
@ -726,7 +726,7 @@ Final Answer: {
|
|||
|
||||
# Assert
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("What is AI?", collection="default")
|
||||
graph_rag_client.rag.assert_called_once_with("What is AI?", collection="default", explain_callback=ANY, parent_uri=ANY)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_query_with_custom_collection(self, mock_flow_context):
|
||||
|
|
@ -739,7 +739,7 @@ Final Answer: {
|
|||
|
||||
# Assert
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="custom_collection")
|
||||
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="custom_collection", explain_callback=ANY, parent_uri=ANY)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_query_with_none_collection(self, mock_flow_context):
|
||||
|
|
@ -752,7 +752,7 @@ Final Answer: {
|
|||
|
||||
# Assert
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("Explain neural networks", collection="default")
|
||||
graph_rag_client.rag.assert_called_once_with("Explain neural networks", collection="default", explain_callback=ANY, parent_uri=ANY)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_knowledge_query_collection_integration(self, mock_flow_context):
|
||||
|
|
@ -810,7 +810,7 @@ Args: {
|
|||
|
||||
# Verify the custom collection was used
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("Latest AI research?", collection="research_papers")
|
||||
graph_rag_client.rag.assert_called_once_with("Latest AI research?", collection="research_papers", explain_callback=ANY, parent_uri=ANY)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_query_multiple_collections(self, mock_flow_context):
|
||||
|
|
@ -840,4 +840,4 @@ Args: {
|
|||
|
||||
# Verify correct collection was used
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with(question, collection=expected_collection)
|
||||
graph_rag_client.rag.assert_called_once_with(question, collection=expected_collection, explain_callback=ANY, parent_uri=ANY)
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class TestAgentServiceNonStreaming:
|
|||
mock_agent_manager_class.return_value = mock_agent_instance
|
||||
|
||||
# Mock react to call think and observe callbacks
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming):
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
|
||||
await think("I need to solve this.", is_final=True)
|
||||
await observe("The answer is 4.", is_final=True)
|
||||
return Final(thought="Final answer", final="4")
|
||||
|
|
@ -76,22 +76,33 @@ class TestAgentServiceNonStreaming:
|
|||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: should have 3 responses (thought, observation, answer)
|
||||
assert len(sent_responses) == 3, f"Expected 3 responses, got {len(sent_responses)}"
|
||||
# Filter out explain events — those are always sent now
|
||||
content_responses = [
|
||||
r for r in sent_responses if r.chunk_type != "explain"
|
||||
]
|
||||
explain_responses = [
|
||||
r for r in sent_responses if r.chunk_type == "explain"
|
||||
]
|
||||
|
||||
# Should have explain events for session, iteration, observation, and final
|
||||
assert len(explain_responses) >= 1, "Expected at least 1 explain event"
|
||||
|
||||
# Should have 3 content responses (thought, observation, answer)
|
||||
assert len(content_responses) == 3, f"Expected 3 content responses, got {len(content_responses)}"
|
||||
|
||||
# Check thought message
|
||||
thought_response = sent_responses[0]
|
||||
thought_response = content_responses[0]
|
||||
assert isinstance(thought_response, AgentResponse)
|
||||
assert thought_response.thought == "I need to solve this."
|
||||
assert thought_response.answer is None
|
||||
assert thought_response.chunk_type == "thought"
|
||||
assert thought_response.content == "I need to solve this."
|
||||
assert thought_response.end_of_message is True, "Thought message must have end_of_message=True"
|
||||
assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False"
|
||||
|
||||
# Check observation message
|
||||
observation_response = sent_responses[1]
|
||||
observation_response = content_responses[1]
|
||||
assert isinstance(observation_response, AgentResponse)
|
||||
assert observation_response.observation == "The answer is 4."
|
||||
assert observation_response.answer is None
|
||||
assert observation_response.chunk_type == "observation"
|
||||
assert observation_response.content == "The answer is 4."
|
||||
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True"
|
||||
assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False"
|
||||
|
||||
|
|
@ -120,7 +131,7 @@ class TestAgentServiceNonStreaming:
|
|||
mock_agent_manager_class.return_value = mock_agent_instance
|
||||
|
||||
# Mock react to return Final directly
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming):
|
||||
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
|
||||
return Final(thought="Final answer", final="4")
|
||||
|
||||
mock_agent_instance.react = mock_react
|
||||
|
|
@ -155,15 +166,25 @@ class TestAgentServiceNonStreaming:
|
|||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: should have 1 response (final answer)
|
||||
assert len(sent_responses) == 1, f"Expected 1 response, got {len(sent_responses)}"
|
||||
# Filter out explain events — those are always sent now
|
||||
content_responses = [
|
||||
r for r in sent_responses if r.chunk_type != "explain"
|
||||
]
|
||||
explain_responses = [
|
||||
r for r in sent_responses if r.chunk_type == "explain"
|
||||
]
|
||||
|
||||
# Should have explain events for session and final
|
||||
assert len(explain_responses) >= 1, "Expected at least 1 explain event"
|
||||
|
||||
# Should have 1 content response (final answer)
|
||||
assert len(content_responses) == 1, f"Expected 1 content response, got {len(content_responses)}"
|
||||
|
||||
# Check final answer message
|
||||
answer_response = sent_responses[0]
|
||||
answer_response = content_responses[0]
|
||||
assert isinstance(answer_response, AgentResponse)
|
||||
assert answer_response.answer == "4"
|
||||
assert answer_response.thought is None
|
||||
assert answer_response.observation is None
|
||||
assert answer_response.chunk_type == "answer"
|
||||
assert answer_response.content == "4"
|
||||
assert answer_response.end_of_message is True, "Final answer must have end_of_message=True"
|
||||
assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True"
|
||||
|
||||
|
|
|
|||
216
tests/unit/test_agent/test_aggregator.py
Normal file
216
tests/unit/test_agent/test_aggregator.py
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
"""
|
||||
Unit tests for the Aggregator — tracks fan-out correlations and triggers
|
||||
synthesis when all subagents complete.
|
||||
"""
|
||||
|
||||
import time
|
||||
import pytest
|
||||
|
||||
from trustgraph.schema import AgentRequest, AgentStep
|
||||
|
||||
from trustgraph.agent.orchestrator.aggregator import Aggregator
|
||||
|
||||
|
||||
def _make_request(question="Test question", user="testuser",
|
||||
collection="default", streaming=False,
|
||||
session_id="parent-session", task_type="research",
|
||||
framing="test framing", conversation_id="conv-1"):
|
||||
return AgentRequest(
|
||||
question=question,
|
||||
user=user,
|
||||
collection=collection,
|
||||
streaming=streaming,
|
||||
session_id=session_id,
|
||||
task_type=task_type,
|
||||
framing=framing,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
|
||||
class TestRegisterFanout:
|
||||
|
||||
def test_stores_correlation_entry(self):
|
||||
agg = Aggregator()
|
||||
agg.register_fanout("corr-1", "parent-1", 3)
|
||||
|
||||
assert "corr-1" in agg.correlations
|
||||
entry = agg.correlations["corr-1"]
|
||||
assert entry["parent_session_id"] == "parent-1"
|
||||
assert entry["expected"] == 3
|
||||
assert entry["results"] == {}
|
||||
|
||||
def test_stores_request_template(self):
|
||||
agg = Aggregator()
|
||||
template = _make_request()
|
||||
agg.register_fanout("corr-1", "parent-1", 2,
|
||||
request_template=template)
|
||||
|
||||
entry = agg.correlations["corr-1"]
|
||||
assert entry["request_template"] is template
|
||||
|
||||
def test_records_creation_time(self):
|
||||
agg = Aggregator()
|
||||
before = time.time()
|
||||
agg.register_fanout("corr-1", "parent-1", 2)
|
||||
after = time.time()
|
||||
|
||||
created = agg.correlations["corr-1"]["created_at"]
|
||||
assert before <= created <= after
|
||||
|
||||
|
||||
class TestRecordCompletion:
|
||||
|
||||
def test_returns_false_until_all_done(self):
|
||||
agg = Aggregator()
|
||||
agg.register_fanout("corr-1", "parent-1", 3)
|
||||
|
||||
assert agg.record_completion("corr-1", "goal-a", "answer-a") is False
|
||||
assert agg.record_completion("corr-1", "goal-b", "answer-b") is False
|
||||
assert agg.record_completion("corr-1", "goal-c", "answer-c") is True
|
||||
|
||||
def test_returns_none_for_unknown_correlation(self):
|
||||
agg = Aggregator()
|
||||
result = agg.record_completion("unknown", "goal", "answer")
|
||||
assert result is None
|
||||
|
||||
def test_stores_results_by_goal(self):
|
||||
agg = Aggregator()
|
||||
agg.register_fanout("corr-1", "parent-1", 2)
|
||||
|
||||
agg.record_completion("corr-1", "goal-a", "answer-a")
|
||||
agg.record_completion("corr-1", "goal-b", "answer-b")
|
||||
|
||||
results = agg.correlations["corr-1"]["results"]
|
||||
assert results["goal-a"] == "answer-a"
|
||||
assert results["goal-b"] == "answer-b"
|
||||
|
||||
def test_single_subagent(self):
|
||||
agg = Aggregator()
|
||||
agg.register_fanout("corr-1", "parent-1", 1)
|
||||
|
||||
assert agg.record_completion("corr-1", "goal-a", "answer") is True
|
||||
|
||||
|
||||
class TestGetOriginalRequest:
|
||||
|
||||
def test_peeks_without_consuming(self):
|
||||
agg = Aggregator()
|
||||
template = _make_request()
|
||||
agg.register_fanout("corr-1", "parent-1", 2,
|
||||
request_template=template)
|
||||
|
||||
result = agg.get_original_request("corr-1")
|
||||
assert result is template
|
||||
# Entry still exists
|
||||
assert "corr-1" in agg.correlations
|
||||
|
||||
def test_returns_none_for_unknown(self):
|
||||
agg = Aggregator()
|
||||
assert agg.get_original_request("unknown") is None
|
||||
|
||||
|
||||
class TestBuildSynthesisRequest:
|
||||
|
||||
def test_builds_correct_request(self):
|
||||
agg = Aggregator()
|
||||
template = _make_request(
|
||||
question="Original question",
|
||||
streaming=True,
|
||||
task_type="risk-assessment",
|
||||
framing="Assess risks",
|
||||
)
|
||||
agg.register_fanout("corr-1", "parent-1", 2,
|
||||
request_template=template)
|
||||
agg.record_completion("corr-1", "goal-a", "answer-a")
|
||||
agg.record_completion("corr-1", "goal-b", "answer-b")
|
||||
|
||||
req = agg.build_synthesis_request(
|
||||
"corr-1",
|
||||
original_question="Original question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
|
||||
assert req.question == "Original question"
|
||||
assert req.pattern == "supervisor"
|
||||
assert req.session_id == "parent-1"
|
||||
assert req.correlation_id == "" # Must be empty
|
||||
assert req.streaming == True
|
||||
assert req.task_type == "risk-assessment"
|
||||
assert req.framing == "Assess risks"
|
||||
|
||||
def test_synthesis_step_in_history(self):
|
||||
agg = Aggregator()
|
||||
template = _make_request()
|
||||
agg.register_fanout("corr-1", "parent-1", 2,
|
||||
request_template=template)
|
||||
agg.record_completion("corr-1", "goal-a", "answer-a")
|
||||
agg.record_completion("corr-1", "goal-b", "answer-b")
|
||||
|
||||
req = agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
)
|
||||
|
||||
# Last history step should be the synthesis step
|
||||
assert len(req.history) >= 1
|
||||
synth_step = req.history[-1]
|
||||
assert synth_step.step_type == "synthesise"
|
||||
assert synth_step.subagent_results == {
|
||||
"goal-a": "answer-a",
|
||||
"goal-b": "answer-b",
|
||||
}
|
||||
|
||||
def test_consumes_correlation_entry(self):
|
||||
agg = Aggregator()
|
||||
template = _make_request()
|
||||
agg.register_fanout("corr-1", "parent-1", 1,
|
||||
request_template=template)
|
||||
agg.record_completion("corr-1", "goal-a", "answer-a")
|
||||
|
||||
agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
)
|
||||
|
||||
# Entry should be removed
|
||||
assert "corr-1" not in agg.correlations
|
||||
|
||||
def test_raises_for_unknown_correlation(self):
|
||||
agg = Aggregator()
|
||||
with pytest.raises(RuntimeError, match="No results"):
|
||||
agg.build_synthesis_request(
|
||||
"unknown", "question", "user", "default",
|
||||
)
|
||||
|
||||
|
||||
class TestCleanupStale:
|
||||
|
||||
def test_removes_entries_older_than_timeout(self):
|
||||
agg = Aggregator(timeout=1)
|
||||
agg.register_fanout("corr-1", "parent-1", 2)
|
||||
|
||||
# Backdate the creation time
|
||||
agg.correlations["corr-1"]["created_at"] = time.time() - 2
|
||||
|
||||
stale = agg.cleanup_stale()
|
||||
assert "corr-1" in stale
|
||||
assert "corr-1" not in agg.correlations
|
||||
|
||||
def test_keeps_recent_entries(self):
|
||||
agg = Aggregator(timeout=300)
|
||||
agg.register_fanout("corr-1", "parent-1", 2)
|
||||
|
||||
stale = agg.cleanup_stale()
|
||||
assert stale == []
|
||||
assert "corr-1" in agg.correlations
|
||||
|
||||
def test_mixed_stale_and_fresh(self):
|
||||
agg = Aggregator(timeout=1)
|
||||
agg.register_fanout("stale", "parent-1", 2)
|
||||
agg.register_fanout("fresh", "parent-2", 2)
|
||||
|
||||
agg.correlations["stale"]["created_at"] = time.time() - 2
|
||||
|
||||
stale = agg.cleanup_stale()
|
||||
assert "stale" in stale
|
||||
assert "stale" not in agg.correlations
|
||||
assert "fresh" in agg.correlations
|
||||
122
tests/unit/test_agent/test_callback_message_id.py
Normal file
122
tests/unit/test_agent/test_callback_message_id.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""
|
||||
Tests that streaming callbacks set message_id on AgentResponse.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.agent.orchestrator.pattern_base import PatternBase
|
||||
from trustgraph.schema import AgentResponse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pattern():
|
||||
processor = MagicMock()
|
||||
return PatternBase(processor)
|
||||
|
||||
|
||||
class TestThinkCallbackMessageId:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_think_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/i1/thought"
|
||||
think = pattern.make_think_callback(capture, streaming=True, message_id=msg_id)
|
||||
await think("hello", is_final=False)
|
||||
|
||||
assert len(responses) == 1
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "thought"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_think_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/i1/thought"
|
||||
think = pattern.make_think_callback(capture, streaming=False, message_id=msg_id)
|
||||
await think("hello")
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].end_of_message is True
|
||||
|
||||
|
||||
class TestObserveCallbackMessageId:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_observe_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/i1/observation"
|
||||
observe = pattern.make_observe_callback(capture, streaming=True, message_id=msg_id)
|
||||
await observe("result", is_final=True)
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "observation"
|
||||
|
||||
|
||||
class TestAnswerCallbackMessageId:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_answer_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/final"
|
||||
answer = pattern.make_answer_callback(capture, streaming=True, message_id=msg_id)
|
||||
await answer("the answer")
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "answer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_message_id_default(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
answer = pattern.make_answer_callback(capture, streaming=True)
|
||||
await answer("the answer")
|
||||
|
||||
assert responses[0].message_id == ""
|
||||
|
||||
|
||||
class TestSendFinalResponseMessageId:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_final_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/final"
|
||||
await pattern.send_final_response(
|
||||
capture, streaming=True, answer_text="answer",
|
||||
message_id=msg_id,
|
||||
)
|
||||
|
||||
# Should get content chunk + end-of-dialog marker
|
||||
assert all(r.message_id == msg_id for r in responses)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_final_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/final"
|
||||
await pattern.send_final_response(
|
||||
capture, streaming=False, answer_text="answer",
|
||||
message_id=msg_id,
|
||||
)
|
||||
|
||||
assert len(responses) == 1
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].end_of_dialog is True
|
||||
174
tests/unit/test_agent/test_completion_dispatch.py
Normal file
174
tests/unit/test_agent/test_completion_dispatch.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
"""
|
||||
Unit tests for completion dispatch — verifies that agent_request() in the
|
||||
orchestrator service correctly intercepts subagent completion messages and
|
||||
routes them to _handle_subagent_completion.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.schema import AgentRequest, AgentStep
|
||||
|
||||
from trustgraph.agent.orchestrator.aggregator import Aggregator
|
||||
|
||||
|
||||
def _make_request(**kwargs):
|
||||
defaults = dict(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return AgentRequest(**defaults)
|
||||
|
||||
|
||||
def _make_completion_request(correlation_id, goal, answer):
|
||||
"""Build a completion request as emit_subagent_completion would."""
|
||||
step = AgentStep(
|
||||
thought="Subagent completed",
|
||||
action="complete",
|
||||
arguments={},
|
||||
observation=answer,
|
||||
step_type="subagent-completion",
|
||||
)
|
||||
return _make_request(
|
||||
correlation_id=correlation_id,
|
||||
parent_session_id="parent-sess",
|
||||
subagent_goal=goal,
|
||||
expected_siblings=2,
|
||||
history=[step],
|
||||
)
|
||||
|
||||
|
||||
class TestCompletionDetection:
|
||||
"""Test that completion messages are correctly identified."""
|
||||
|
||||
def test_is_completion_when_correlation_id_and_step_type(self):
|
||||
req = _make_completion_request("corr-1", "goal-a", "answer-a")
|
||||
|
||||
has_correlation = bool(getattr(req, 'correlation_id', ''))
|
||||
is_completion = any(
|
||||
getattr(h, 'step_type', '') == 'subagent-completion'
|
||||
for h in req.history
|
||||
)
|
||||
|
||||
assert has_correlation
|
||||
assert is_completion
|
||||
|
||||
def test_not_completion_without_correlation_id(self):
|
||||
step = AgentStep(
|
||||
step_type="subagent-completion",
|
||||
observation="answer",
|
||||
)
|
||||
req = _make_request(
|
||||
correlation_id="",
|
||||
history=[step],
|
||||
)
|
||||
|
||||
has_correlation = bool(getattr(req, 'correlation_id', ''))
|
||||
assert not has_correlation
|
||||
|
||||
def test_not_completion_without_step_type(self):
|
||||
step = AgentStep(
|
||||
step_type="react",
|
||||
observation="answer",
|
||||
)
|
||||
req = _make_request(
|
||||
correlation_id="corr-1",
|
||||
history=[step],
|
||||
)
|
||||
|
||||
is_completion = any(
|
||||
getattr(h, 'step_type', '') == 'subagent-completion'
|
||||
for h in req.history
|
||||
)
|
||||
assert not is_completion
|
||||
|
||||
def test_not_completion_with_empty_history(self):
|
||||
req = _make_request(
|
||||
correlation_id="corr-1",
|
||||
history=[],
|
||||
)
|
||||
assert not req.history
|
||||
|
||||
|
||||
class TestAggregatorIntegration:
|
||||
"""Test the aggregator flow as used by _handle_subagent_completion."""
|
||||
|
||||
def test_full_completion_flow(self):
|
||||
"""Simulates the flow: register, record completions, build synthesis."""
|
||||
agg = Aggregator()
|
||||
template = _make_request(
|
||||
question="Original question",
|
||||
streaming=True,
|
||||
task_type="risk-assessment",
|
||||
framing="Assess risks",
|
||||
session_id="parent-sess",
|
||||
)
|
||||
|
||||
# Register fan-out
|
||||
agg.register_fanout("corr-1", "parent-sess", 2,
|
||||
request_template=template)
|
||||
|
||||
# First completion — not all done
|
||||
all_done = agg.record_completion(
|
||||
"corr-1", "goal-a", "answer-a",
|
||||
)
|
||||
assert all_done is False
|
||||
|
||||
# Second completion — all done
|
||||
all_done = agg.record_completion(
|
||||
"corr-1", "goal-b", "answer-b",
|
||||
)
|
||||
assert all_done is True
|
||||
|
||||
# Peek at template
|
||||
peeked = agg.get_original_request("corr-1")
|
||||
assert peeked.question == "Original question"
|
||||
|
||||
# Build synthesis request
|
||||
synth = agg.build_synthesis_request(
|
||||
"corr-1",
|
||||
original_question="Original question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
|
||||
# Verify synthesis request
|
||||
assert synth.pattern == "supervisor"
|
||||
assert synth.correlation_id == ""
|
||||
assert synth.session_id == "parent-sess"
|
||||
assert synth.streaming is True
|
||||
|
||||
# Verify synthesis history has results
|
||||
synth_steps = [
|
||||
s for s in synth.history
|
||||
if getattr(s, 'step_type', '') == 'synthesise'
|
||||
]
|
||||
assert len(synth_steps) == 1
|
||||
assert synth_steps[0].subagent_results == {
|
||||
"goal-a": "answer-a",
|
||||
"goal-b": "answer-b",
|
||||
}
|
||||
|
||||
def test_synthesis_request_not_detected_as_completion(self):
|
||||
"""The synthesis request must not be intercepted as a completion."""
|
||||
agg = Aggregator()
|
||||
template = _make_request(session_id="parent-sess")
|
||||
agg.register_fanout("corr-1", "parent-sess", 1,
|
||||
request_template=template)
|
||||
agg.record_completion("corr-1", "goal", "answer")
|
||||
|
||||
synth = agg.build_synthesis_request(
|
||||
"corr-1", "question", "user", "default",
|
||||
)
|
||||
|
||||
# correlation_id must be empty so it's not intercepted
|
||||
assert synth.correlation_id == ""
|
||||
|
||||
# Even if we check for completion step, shouldn't match
|
||||
is_completion = any(
|
||||
getattr(h, 'step_type', '') == 'subagent-completion'
|
||||
for h in synth.history
|
||||
)
|
||||
assert not is_completion
|
||||
177
tests/unit/test_agent/test_explainability_parsing.py
Normal file
177
tests/unit/test_agent/test_explainability_parsing.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
"""
|
||||
Unit tests for explainability API parsing — verifies that from_triples()
|
||||
correctly dispatches and parses the new orchestrator entity types.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.api.explainability import (
|
||||
ExplainEntity,
|
||||
Decomposition,
|
||||
Finding,
|
||||
Plan,
|
||||
StepResult,
|
||||
Synthesis,
|
||||
Analysis,
|
||||
Observation,
|
||||
Conclusion,
|
||||
TG_DECOMPOSITION,
|
||||
TG_FINDING,
|
||||
TG_PLAN_TYPE,
|
||||
TG_STEP_RESULT,
|
||||
TG_SYNTHESIS,
|
||||
TG_ANSWER_TYPE,
|
||||
TG_OBSERVATION_TYPE,
|
||||
TG_TOOL_USE,
|
||||
TG_ANALYSIS,
|
||||
TG_CONCLUSION,
|
||||
TG_DOCUMENT,
|
||||
TG_SUBAGENT_GOAL,
|
||||
TG_PLAN_STEP,
|
||||
RDF_TYPE,
|
||||
)
|
||||
|
||||
PROV_ENTITY = "http://www.w3.org/ns/prov#Entity"
|
||||
|
||||
|
||||
def _make_triples(uri, types, extras=None):
|
||||
"""Build a list of (s, p, o) tuples for testing."""
|
||||
triples = [(uri, RDF_TYPE, t) for t in types]
|
||||
if extras:
|
||||
triples.extend((uri, p, o) for p, o in extras)
|
||||
return triples
|
||||
|
||||
|
||||
class TestFromTriplesDispatch:
|
||||
|
||||
def test_dispatches_decomposition(self):
|
||||
triples = _make_triples("urn:d", [PROV_ENTITY, TG_DECOMPOSITION])
|
||||
entity = ExplainEntity.from_triples("urn:d", triples)
|
||||
assert isinstance(entity, Decomposition)
|
||||
|
||||
def test_dispatches_finding(self):
|
||||
triples = _make_triples("urn:f",
|
||||
[PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:f", triples)
|
||||
assert isinstance(entity, Finding)
|
||||
|
||||
def test_dispatches_plan(self):
|
||||
triples = _make_triples("urn:p", [PROV_ENTITY, TG_PLAN_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:p", triples)
|
||||
assert isinstance(entity, Plan)
|
||||
|
||||
def test_dispatches_step_result(self):
|
||||
triples = _make_triples("urn:sr",
|
||||
[PROV_ENTITY, TG_STEP_RESULT, TG_ANSWER_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:sr", triples)
|
||||
assert isinstance(entity, StepResult)
|
||||
|
||||
def test_dispatches_synthesis(self):
|
||||
triples = _make_triples("urn:s",
|
||||
[PROV_ENTITY, TG_SYNTHESIS, TG_ANSWER_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:s", triples)
|
||||
assert isinstance(entity, Synthesis)
|
||||
|
||||
def test_dispatches_analysis_unchanged(self):
|
||||
triples = _make_triples("urn:a", [PROV_ENTITY, TG_ANALYSIS])
|
||||
entity = ExplainEntity.from_triples("urn:a", triples)
|
||||
assert isinstance(entity, Analysis)
|
||||
|
||||
def test_dispatches_analysis_with_tooluse(self):
|
||||
"""Analysis+ToolUse mixin still dispatches to Analysis."""
|
||||
triples = _make_triples("urn:a",
|
||||
[PROV_ENTITY, TG_ANALYSIS, TG_TOOL_USE])
|
||||
entity = ExplainEntity.from_triples("urn:a", triples)
|
||||
assert isinstance(entity, Analysis)
|
||||
|
||||
def test_dispatches_observation(self):
|
||||
triples = _make_triples("urn:o", [PROV_ENTITY, TG_OBSERVATION_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:o", triples)
|
||||
assert isinstance(entity, Observation)
|
||||
|
||||
def test_dispatches_conclusion_unchanged(self):
|
||||
triples = _make_triples("urn:c",
|
||||
[PROV_ENTITY, TG_CONCLUSION, TG_ANSWER_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:c", triples)
|
||||
assert isinstance(entity, Conclusion)
|
||||
|
||||
def test_finding_takes_precedence_over_synthesis(self):
|
||||
"""Finding has Answer mixin but should dispatch to Finding, not
|
||||
Synthesis, because Finding is checked first."""
|
||||
triples = _make_triples("urn:f",
|
||||
[PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:f", triples)
|
||||
assert isinstance(entity, Finding)
|
||||
assert not isinstance(entity, Synthesis)
|
||||
|
||||
|
||||
class TestDecompositionParsing:
|
||||
|
||||
def test_parses_goals(self):
|
||||
triples = _make_triples("urn:d", [TG_DECOMPOSITION], [
|
||||
(TG_SUBAGENT_GOAL, "What is X?"),
|
||||
(TG_SUBAGENT_GOAL, "What is Y?"),
|
||||
])
|
||||
entity = Decomposition.from_triples("urn:d", triples)
|
||||
assert set(entity.goals) == {"What is X?", "What is Y?"}
|
||||
|
||||
def test_entity_type_field(self):
|
||||
triples = _make_triples("urn:d", [TG_DECOMPOSITION])
|
||||
entity = Decomposition.from_triples("urn:d", triples)
|
||||
assert entity.entity_type == "decomposition"
|
||||
|
||||
def test_empty_goals(self):
|
||||
triples = _make_triples("urn:d", [TG_DECOMPOSITION])
|
||||
entity = Decomposition.from_triples("urn:d", triples)
|
||||
assert entity.goals == []
|
||||
|
||||
|
||||
class TestFindingParsing:
|
||||
|
||||
def test_parses_goal_and_document(self):
|
||||
triples = _make_triples("urn:f", [TG_FINDING, TG_ANSWER_TYPE], [
|
||||
(TG_SUBAGENT_GOAL, "What is X?"),
|
||||
(TG_DOCUMENT, "urn:doc/finding"),
|
||||
])
|
||||
entity = Finding.from_triples("urn:f", triples)
|
||||
assert entity.goal == "What is X?"
|
||||
assert entity.document == "urn:doc/finding"
|
||||
|
||||
def test_entity_type_field(self):
|
||||
triples = _make_triples("urn:f", [TG_FINDING])
|
||||
entity = Finding.from_triples("urn:f", triples)
|
||||
assert entity.entity_type == "finding"
|
||||
|
||||
|
||||
class TestPlanParsing:
|
||||
|
||||
def test_parses_steps(self):
|
||||
triples = _make_triples("urn:p", [TG_PLAN_TYPE], [
|
||||
(TG_PLAN_STEP, "Define X"),
|
||||
(TG_PLAN_STEP, "Research Y"),
|
||||
(TG_PLAN_STEP, "Analyse Z"),
|
||||
])
|
||||
entity = Plan.from_triples("urn:p", triples)
|
||||
assert set(entity.steps) == {"Define X", "Research Y", "Analyse Z"}
|
||||
|
||||
def test_entity_type_field(self):
|
||||
triples = _make_triples("urn:p", [TG_PLAN_TYPE])
|
||||
entity = Plan.from_triples("urn:p", triples)
|
||||
assert entity.entity_type == "plan"
|
||||
|
||||
|
||||
class TestStepResultParsing:
|
||||
|
||||
def test_parses_step_and_document(self):
|
||||
triples = _make_triples("urn:sr", [TG_STEP_RESULT, TG_ANSWER_TYPE], [
|
||||
(TG_PLAN_STEP, "Define X"),
|
||||
(TG_DOCUMENT, "urn:doc/step"),
|
||||
])
|
||||
entity = StepResult.from_triples("urn:sr", triples)
|
||||
assert entity.step == "Define X"
|
||||
assert entity.document == "urn:doc/step"
|
||||
|
||||
def test_entity_type_field(self):
|
||||
triples = _make_triples("urn:sr", [TG_STEP_RESULT])
|
||||
entity = StepResult.from_triples("urn:sr", triples)
|
||||
assert entity.entity_type == "step-result"
|
||||
289
tests/unit/test_agent/test_meta_router.py
Normal file
289
tests/unit/test_agent/test_meta_router.py
Normal file
|
|
@ -0,0 +1,289 @@
|
|||
"""
|
||||
Unit tests for the MetaRouter — task type identification and pattern selection.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.agent.orchestrator.meta_router import (
|
||||
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
|
||||
)
|
||||
|
||||
|
||||
def _make_config(patterns=None, task_types=None):
|
||||
"""Build a config dict as the config service would provide."""
|
||||
config = {}
|
||||
if patterns:
|
||||
config["agent-pattern"] = {
|
||||
pid: json.dumps(pdata) for pid, pdata in patterns.items()
|
||||
}
|
||||
if task_types:
|
||||
config["agent-task-type"] = {
|
||||
tid: json.dumps(tdata) for tid, tdata in task_types.items()
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def _make_context(prompt_response):
|
||||
"""Build a mock context that returns a mock prompt client."""
|
||||
client = AsyncMock()
|
||||
client.prompt = AsyncMock(return_value=prompt_response)
|
||||
|
||||
def context(service_name):
|
||||
return client
|
||||
|
||||
return context
|
||||
|
||||
|
||||
SAMPLE_PATTERNS = {
|
||||
"react": {"name": "react", "description": "ReAct pattern"},
|
||||
"plan-then-execute": {"name": "plan-then-execute", "description": "Plan pattern"},
|
||||
"supervisor": {"name": "supervisor", "description": "Supervisor pattern"},
|
||||
}
|
||||
|
||||
SAMPLE_TASK_TYPES = {
|
||||
"general": {
|
||||
"name": "general",
|
||||
"description": "General queries",
|
||||
"valid_patterns": ["react", "plan-then-execute", "supervisor"],
|
||||
"framing": "",
|
||||
},
|
||||
"research": {
|
||||
"name": "research",
|
||||
"description": "Research queries",
|
||||
"valid_patterns": ["react", "plan-then-execute"],
|
||||
"framing": "Focus on gathering information.",
|
||||
},
|
||||
"summarisation": {
|
||||
"name": "summarisation",
|
||||
"description": "Summarisation queries",
|
||||
"valid_patterns": ["react"],
|
||||
"framing": "Focus on concise synthesis.",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestMetaRouterInit:
|
||||
|
||||
def test_defaults_when_no_config(self):
|
||||
router = MetaRouter()
|
||||
assert "react" in router.patterns
|
||||
assert "general" in router.task_types
|
||||
|
||||
def test_loads_patterns_from_config(self):
|
||||
config = _make_config(patterns=SAMPLE_PATTERNS)
|
||||
router = MetaRouter(config=config)
|
||||
assert set(router.patterns.keys()) == {"react", "plan-then-execute", "supervisor"}
|
||||
|
||||
def test_loads_task_types_from_config(self):
|
||||
config = _make_config(task_types=SAMPLE_TASK_TYPES)
|
||||
router = MetaRouter(config=config)
|
||||
assert set(router.task_types.keys()) == {"general", "research", "summarisation"}
|
||||
|
||||
def test_handles_invalid_json_in_config(self):
|
||||
config = {
|
||||
"agent-pattern": {"react": "not valid json"},
|
||||
}
|
||||
router = MetaRouter(config=config)
|
||||
assert "react" in router.patterns
|
||||
assert router.patterns["react"]["name"] == "react"
|
||||
|
||||
|
||||
class TestIdentifyTaskType:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_llm_when_single_task_type(self):
|
||||
router = MetaRouter() # Only "general"
|
||||
context = _make_context("should not be called")
|
||||
|
||||
task_type, framing = await router.identify_task_type(
|
||||
"test question", context,
|
||||
)
|
||||
|
||||
assert task_type == "general"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_llm_when_multiple_task_types(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
context = _make_context("research")
|
||||
|
||||
task_type, framing = await router.identify_task_type(
|
||||
"Research the topic", context,
|
||||
)
|
||||
|
||||
assert task_type == "research"
|
||||
assert framing == "Focus on gathering information."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_llm_returning_quoted_type(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
context = _make_context('"summarisation"')
|
||||
|
||||
task_type, _ = await router.identify_task_type(
|
||||
"Summarise this", context,
|
||||
)
|
||||
|
||||
assert task_type == "summarisation"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_on_unknown_type(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
context = _make_context("nonexistent-type")
|
||||
|
||||
task_type, _ = await router.identify_task_type(
|
||||
"test question", context,
|
||||
)
|
||||
|
||||
assert task_type == DEFAULT_TASK_TYPE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_on_llm_error(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
|
||||
client = AsyncMock()
|
||||
client.prompt = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
context = lambda name: client
|
||||
|
||||
task_type, _ = await router.identify_task_type(
|
||||
"test question", context,
|
||||
)
|
||||
|
||||
assert task_type == DEFAULT_TASK_TYPE
|
||||
|
||||
|
||||
class TestSelectPattern:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_llm_when_single_valid_pattern(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
context = _make_context("should not be called")
|
||||
|
||||
# summarisation only has ["react"]
|
||||
pattern = await router.select_pattern(
|
||||
"Summarise this", "summarisation", context,
|
||||
)
|
||||
|
||||
assert pattern == "react"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_llm_when_multiple_valid_patterns(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
context = _make_context("plan-then-execute")
|
||||
|
||||
# research has ["react", "plan-then-execute"]
|
||||
pattern = await router.select_pattern(
|
||||
"Research this", "research", context,
|
||||
)
|
||||
|
||||
assert pattern == "plan-then-execute"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respects_valid_patterns_constraint(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
# LLM returns supervisor, but research doesn't allow it
|
||||
context = _make_context("supervisor")
|
||||
|
||||
pattern = await router.select_pattern(
|
||||
"Research this", "research", context,
|
||||
)
|
||||
|
||||
# Should fall back to first valid pattern
|
||||
assert pattern == "react"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_on_llm_error(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
|
||||
client = AsyncMock()
|
||||
client.prompt = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
context = lambda name: client
|
||||
|
||||
# general has ["react", "plan-then-execute", "supervisor"]
|
||||
pattern = await router.select_pattern(
|
||||
"test", "general", context,
|
||||
)
|
||||
|
||||
# Falls back to first valid pattern
|
||||
assert pattern == "react"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_default_for_unknown_task_type(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
context = _make_context("react")
|
||||
|
||||
# Unknown task type — valid_patterns falls back to all patterns
|
||||
pattern = await router.select_pattern(
|
||||
"test", "unknown-type", context,
|
||||
)
|
||||
|
||||
assert pattern == "react"
|
||||
|
||||
|
||||
class TestRoute:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_routing_pipeline(self):
|
||||
config = _make_config(
|
||||
patterns=SAMPLE_PATTERNS,
|
||||
task_types=SAMPLE_TASK_TYPES,
|
||||
)
|
||||
router = MetaRouter(config=config)
|
||||
|
||||
# Mock context where prompt returns different values per call
|
||||
client = AsyncMock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_prompt(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return "research" # task type
|
||||
return "plan-then-execute" # pattern
|
||||
|
||||
client.prompt = mock_prompt
|
||||
context = lambda name: client
|
||||
|
||||
pattern, task_type, framing = await router.route(
|
||||
"Research the relationships", context,
|
||||
)
|
||||
|
||||
assert task_type == "research"
|
||||
assert pattern == "plan-then-execute"
|
||||
assert framing == "Focus on gathering information."
|
||||
132
tests/unit/test_agent/test_on_action_callback.py
Normal file
132
tests/unit/test_agent/test_on_action_callback.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
"""
|
||||
Tests for the on_action callback in react() — verifies that it fires
|
||||
after action selection but before tool execution.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.agent.react.agent_manager import AgentManager
|
||||
from trustgraph.agent.react.types import Action, Final, Tool, Argument
|
||||
|
||||
|
||||
class TestOnActionCallback:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_action_called_for_tool_use(self):
|
||||
"""on_action fires when react() selects a tool (not Final)."""
|
||||
call_log = []
|
||||
|
||||
async def fake_on_action(act):
|
||||
call_log.append(("on_action", act.name))
|
||||
|
||||
# Tool that records when it's invoked
|
||||
async def tool_invoke(**kwargs):
|
||||
call_log.append(("tool_invoke",))
|
||||
return "tool result"
|
||||
|
||||
tool_impl = MagicMock()
|
||||
tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke)
|
||||
|
||||
tools = {
|
||||
"search": Tool(
|
||||
name="search",
|
||||
description="Search",
|
||||
implementation=tool_impl,
|
||||
arguments=[Argument(name="query", type="string", description="q")],
|
||||
config={},
|
||||
),
|
||||
}
|
||||
|
||||
agent = AgentManager(tools=tools)
|
||||
|
||||
# Mock reason() to return an Action
|
||||
action = Action(thought="thinking", name="search", arguments={"query": "test"}, observation="")
|
||||
agent.reason = AsyncMock(return_value=action)
|
||||
|
||||
think = AsyncMock()
|
||||
observe = AsyncMock()
|
||||
context = MagicMock()
|
||||
|
||||
await agent.react(
|
||||
question="test",
|
||||
history=[],
|
||||
think=think,
|
||||
observe=observe,
|
||||
context=context,
|
||||
on_action=fake_on_action,
|
||||
)
|
||||
|
||||
# on_action should fire before tool_invoke
|
||||
assert len(call_log) == 2
|
||||
assert call_log[0] == ("on_action", "search")
|
||||
assert call_log[1] == ("tool_invoke",)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_action_not_called_for_final(self):
|
||||
"""on_action does not fire when react() returns Final."""
|
||||
called = []
|
||||
|
||||
async def fake_on_action(act):
|
||||
called.append(act)
|
||||
|
||||
agent = AgentManager(tools={})
|
||||
agent.reason = AsyncMock(
|
||||
return_value=Final(thought="done", final="answer")
|
||||
)
|
||||
|
||||
think = AsyncMock()
|
||||
observe = AsyncMock()
|
||||
context = MagicMock()
|
||||
|
||||
result = await agent.react(
|
||||
question="test",
|
||||
history=[],
|
||||
think=think,
|
||||
observe=observe,
|
||||
context=context,
|
||||
on_action=fake_on_action,
|
||||
)
|
||||
|
||||
assert isinstance(result, Final)
|
||||
assert len(called) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_action_none_accepted(self):
|
||||
"""react() works fine when on_action is None (default)."""
|
||||
async def tool_invoke(**kwargs):
|
||||
return "result"
|
||||
|
||||
tool_impl = MagicMock()
|
||||
tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke)
|
||||
|
||||
tools = {
|
||||
"search": Tool(
|
||||
name="search",
|
||||
description="Search",
|
||||
implementation=tool_impl,
|
||||
arguments=[],
|
||||
config={},
|
||||
),
|
||||
}
|
||||
|
||||
agent = AgentManager(tools=tools)
|
||||
agent.reason = AsyncMock(
|
||||
return_value=Action(thought="t", name="search", arguments={}, observation="")
|
||||
)
|
||||
|
||||
think = AsyncMock()
|
||||
observe = AsyncMock()
|
||||
context = MagicMock()
|
||||
|
||||
result = await agent.react(
|
||||
question="test",
|
||||
history=[],
|
||||
think=think,
|
||||
observe=observe,
|
||||
context=context,
|
||||
# on_action not passed — defaults to None
|
||||
)
|
||||
|
||||
assert isinstance(result, Action)
|
||||
assert result.observation == "result"
|
||||
74
tests/unit/test_agent/test_parse_chunk_message_id.py
Normal file
74
tests/unit/test_agent/test_parse_chunk_message_id.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
"""
|
||||
Tests that _parse_chunk propagates message_id from wire format
|
||||
to AgentThought, AgentObservation, and AgentAnswer.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.api.socket_client import SocketClient
|
||||
from trustgraph.api.types import AgentThought, AgentObservation, AgentAnswer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
# We only need _parse_chunk — don't connect
|
||||
c = object.__new__(SocketClient)
|
||||
return c
|
||||
|
||||
|
||||
class TestParseChunkMessageId:
|
||||
|
||||
def test_thought_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "thought",
|
||||
"content": "thinking...",
|
||||
"end_of_message": False,
|
||||
"message_id": "urn:trustgraph:agent:sess/i1/thought",
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentThought)
|
||||
assert chunk.message_id == "urn:trustgraph:agent:sess/i1/thought"
|
||||
|
||||
def test_observation_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "observation",
|
||||
"content": "result",
|
||||
"end_of_message": True,
|
||||
"message_id": "urn:trustgraph:agent:sess/i1/observation",
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentObservation)
|
||||
assert chunk.message_id == "urn:trustgraph:agent:sess/i1/observation"
|
||||
|
||||
def test_answer_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "answer",
|
||||
"content": "the answer",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False,
|
||||
"message_id": "urn:trustgraph:agent:sess/final",
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentAnswer)
|
||||
assert chunk.message_id == "urn:trustgraph:agent:sess/final"
|
||||
|
||||
def test_thought_missing_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "thought",
|
||||
"content": "thinking...",
|
||||
"end_of_message": False,
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentThought)
|
||||
assert chunk.message_id == ""
|
||||
|
||||
def test_answer_missing_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "answer",
|
||||
"content": "answer",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": True,
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentAnswer)
|
||||
assert chunk.message_id == ""
|
||||
144
tests/unit/test_agent/test_pattern_base_subagent.py
Normal file
144
tests/unit/test_agent/test_pattern_base_subagent.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
"""
|
||||
Unit tests for PatternBase subagent helpers — is_subagent() and
|
||||
emit_subagent_completion().
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.schema import AgentRequest
|
||||
|
||||
from trustgraph.agent.orchestrator.pattern_base import PatternBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockProcessor:
|
||||
"""Minimal processor mock for PatternBase."""
|
||||
pass
|
||||
|
||||
|
||||
def _make_request(**kwargs):
|
||||
defaults = dict(
|
||||
question="Test question",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return AgentRequest(**defaults)
|
||||
|
||||
|
||||
def _make_pattern():
|
||||
return PatternBase(MockProcessor())
|
||||
|
||||
|
||||
class TestIsSubagent:
|
||||
|
||||
def test_returns_true_when_correlation_id_set(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request(correlation_id="corr-123")
|
||||
assert pattern.is_subagent(request) is True
|
||||
|
||||
def test_returns_false_when_correlation_id_empty(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request(correlation_id="")
|
||||
assert pattern.is_subagent(request) is False
|
||||
|
||||
def test_returns_false_when_correlation_id_missing(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request()
|
||||
assert pattern.is_subagent(request) is False
|
||||
|
||||
|
||||
class TestEmitSubagentCompletion:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_next_with_completion_request(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request(
|
||||
correlation_id="corr-123",
|
||||
parent_session_id="parent-sess",
|
||||
subagent_goal="What is X?",
|
||||
expected_siblings=4,
|
||||
)
|
||||
next_fn = AsyncMock()
|
||||
|
||||
await pattern.emit_subagent_completion(
|
||||
request, next_fn, "The answer is Y",
|
||||
)
|
||||
|
||||
next_fn.assert_called_once()
|
||||
completion_req = next_fn.call_args[0][0]
|
||||
assert isinstance(completion_req, AgentRequest)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_has_correct_step_type(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request(
|
||||
correlation_id="corr-123",
|
||||
subagent_goal="What is X?",
|
||||
)
|
||||
next_fn = AsyncMock()
|
||||
|
||||
await pattern.emit_subagent_completion(
|
||||
request, next_fn, "answer text",
|
||||
)
|
||||
|
||||
completion_req = next_fn.call_args[0][0]
|
||||
assert len(completion_req.history) == 1
|
||||
step = completion_req.history[0]
|
||||
assert step.step_type == "subagent-completion"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_carries_answer_in_observation(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request(
|
||||
correlation_id="corr-123",
|
||||
subagent_goal="What is X?",
|
||||
)
|
||||
next_fn = AsyncMock()
|
||||
|
||||
await pattern.emit_subagent_completion(
|
||||
request, next_fn, "The answer is Y",
|
||||
)
|
||||
|
||||
completion_req = next_fn.call_args[0][0]
|
||||
step = completion_req.history[0]
|
||||
assert step.observation == "The answer is Y"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_preserves_correlation_fields(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request(
|
||||
correlation_id="corr-123",
|
||||
parent_session_id="parent-sess",
|
||||
subagent_goal="What is X?",
|
||||
expected_siblings=4,
|
||||
)
|
||||
next_fn = AsyncMock()
|
||||
|
||||
await pattern.emit_subagent_completion(
|
||||
request, next_fn, "answer",
|
||||
)
|
||||
|
||||
completion_req = next_fn.call_args[0][0]
|
||||
assert completion_req.correlation_id == "corr-123"
|
||||
assert completion_req.parent_session_id == "parent-sess"
|
||||
assert completion_req.subagent_goal == "What is X?"
|
||||
assert completion_req.expected_siblings == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_has_empty_pattern(self):
|
||||
pattern = _make_pattern()
|
||||
request = _make_request(
|
||||
correlation_id="corr-123",
|
||||
subagent_goal="goal",
|
||||
)
|
||||
next_fn = AsyncMock()
|
||||
|
||||
await pattern.emit_subagent_completion(
|
||||
request, next_fn, "answer",
|
||||
)
|
||||
|
||||
completion_req = next_fn.call_args[0][0]
|
||||
assert completion_req.pattern == ""
|
||||
226
tests/unit/test_agent/test_provenance_triples.py
Normal file
226
tests/unit/test_agent/test_provenance_triples.py
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
"""
|
||||
Unit tests for orchestrator provenance triple builders.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.provenance import (
|
||||
agent_decomposition_triples,
|
||||
agent_finding_triples,
|
||||
agent_plan_triples,
|
||||
agent_step_result_triples,
|
||||
agent_synthesis_triples,
|
||||
)
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
|
||||
TG_SYNTHESIS, TG_ANSWER_TYPE, TG_DOCUMENT,
|
||||
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
|
||||
)
|
||||
|
||||
|
||||
def _triple_set(triples):
|
||||
"""Convert triples to a set of (s_iri, p_iri, o_value) for easy assertion."""
|
||||
result = set()
|
||||
for t in triples:
|
||||
s = t.s.iri
|
||||
p = t.p.iri
|
||||
o = t.o.iri if t.o.iri else t.o.value
|
||||
result.add((s, p, o))
|
||||
return result
|
||||
|
||||
|
||||
def _has_type(triples, uri, rdf_type):
|
||||
"""Check if a URI has a given rdf:type in the triples."""
|
||||
return (uri, RDF_TYPE, rdf_type) in _triple_set(triples)
|
||||
|
||||
|
||||
def _get_values(triples, uri, predicate):
|
||||
"""Get all object values for a given subject + predicate."""
|
||||
ts = _triple_set(triples)
|
||||
return [o for s, p, o in ts if s == uri and p == predicate]
|
||||
|
||||
|
||||
class TestDecompositionTriples:
|
||||
|
||||
def test_has_correct_types(self):
|
||||
triples = agent_decomposition_triples(
|
||||
"urn:decompose", "urn:session", ["goal-a", "goal-b"],
|
||||
)
|
||||
assert _has_type(triples, "urn:decompose", PROV_ENTITY)
|
||||
assert _has_type(triples, "urn:decompose", TG_DECOMPOSITION)
|
||||
|
||||
def test_not_answer_type(self):
|
||||
triples = agent_decomposition_triples(
|
||||
"urn:decompose", "urn:session", ["goal-a"],
|
||||
)
|
||||
assert not _has_type(triples, "urn:decompose", TG_ANSWER_TYPE)
|
||||
|
||||
def test_links_to_session(self):
|
||||
triples = agent_decomposition_triples(
|
||||
"urn:decompose", "urn:session", ["goal-a"],
|
||||
)
|
||||
ts = _triple_set(triples)
|
||||
assert ("urn:decompose", PROV_WAS_DERIVED_FROM, "urn:session") in ts
|
||||
|
||||
def test_includes_goals(self):
|
||||
goals = ["What is X?", "What is Y?", "What is Z?"]
|
||||
triples = agent_decomposition_triples(
|
||||
"urn:decompose", "urn:session", goals,
|
||||
)
|
||||
values = _get_values(triples, "urn:decompose", TG_SUBAGENT_GOAL)
|
||||
assert set(values) == set(goals)
|
||||
|
||||
def test_label_includes_count(self):
|
||||
triples = agent_decomposition_triples(
|
||||
"urn:decompose", "urn:session", ["a", "b", "c"],
|
||||
)
|
||||
labels = _get_values(triples, "urn:decompose", RDFS_LABEL)
|
||||
assert any("3" in label for label in labels)
|
||||
|
||||
|
||||
class TestFindingTriples:
|
||||
|
||||
def test_has_correct_types(self):
|
||||
triples = agent_finding_triples(
|
||||
"urn:finding", "urn:decompose", "What is X?",
|
||||
)
|
||||
assert _has_type(triples, "urn:finding", PROV_ENTITY)
|
||||
assert _has_type(triples, "urn:finding", TG_FINDING)
|
||||
assert _has_type(triples, "urn:finding", TG_ANSWER_TYPE)
|
||||
|
||||
def test_links_to_decomposition(self):
|
||||
triples = agent_finding_triples(
|
||||
"urn:finding", "urn:decompose", "What is X?",
|
||||
)
|
||||
ts = _triple_set(triples)
|
||||
assert ("urn:finding", PROV_WAS_DERIVED_FROM, "urn:decompose") in ts
|
||||
|
||||
def test_includes_goal(self):
|
||||
triples = agent_finding_triples(
|
||||
"urn:finding", "urn:decompose", "What is X?",
|
||||
)
|
||||
values = _get_values(triples, "urn:finding", TG_SUBAGENT_GOAL)
|
||||
assert "What is X?" in values
|
||||
|
||||
def test_includes_document_when_provided(self):
|
||||
triples = agent_finding_triples(
|
||||
"urn:finding", "urn:decompose", "goal",
|
||||
document_id="urn:doc/1",
|
||||
)
|
||||
values = _get_values(triples, "urn:finding", TG_DOCUMENT)
|
||||
assert "urn:doc/1" in values
|
||||
|
||||
def test_no_document_when_none(self):
|
||||
triples = agent_finding_triples(
|
||||
"urn:finding", "urn:decompose", "goal",
|
||||
)
|
||||
values = _get_values(triples, "urn:finding", TG_DOCUMENT)
|
||||
assert values == []
|
||||
|
||||
|
||||
class TestPlanTriples:
|
||||
|
||||
def test_has_correct_types(self):
|
||||
triples = agent_plan_triples(
|
||||
"urn:plan", "urn:session", ["step-a"],
|
||||
)
|
||||
assert _has_type(triples, "urn:plan", PROV_ENTITY)
|
||||
assert _has_type(triples, "urn:plan", TG_PLAN_TYPE)
|
||||
|
||||
def test_not_answer_type(self):
|
||||
triples = agent_plan_triples(
|
||||
"urn:plan", "urn:session", ["step-a"],
|
||||
)
|
||||
assert not _has_type(triples, "urn:plan", TG_ANSWER_TYPE)
|
||||
|
||||
def test_links_to_session(self):
|
||||
triples = agent_plan_triples(
|
||||
"urn:plan", "urn:session", ["step-a"],
|
||||
)
|
||||
ts = _triple_set(triples)
|
||||
assert ("urn:plan", PROV_WAS_DERIVED_FROM, "urn:session") in ts
|
||||
|
||||
def test_includes_steps(self):
|
||||
steps = ["Define X", "Research Y", "Analyse Z"]
|
||||
triples = agent_plan_triples(
|
||||
"urn:plan", "urn:session", steps,
|
||||
)
|
||||
values = _get_values(triples, "urn:plan", TG_PLAN_STEP)
|
||||
assert set(values) == set(steps)
|
||||
|
||||
def test_label_includes_count(self):
|
||||
triples = agent_plan_triples(
|
||||
"urn:plan", "urn:session", ["a", "b"],
|
||||
)
|
||||
labels = _get_values(triples, "urn:plan", RDFS_LABEL)
|
||||
assert any("2" in label for label in labels)
|
||||
|
||||
|
||||
class TestStepResultTriples:
|
||||
|
||||
def test_has_correct_types(self):
|
||||
triples = agent_step_result_triples(
|
||||
"urn:step", "urn:plan", "Define X",
|
||||
)
|
||||
assert _has_type(triples, "urn:step", PROV_ENTITY)
|
||||
assert _has_type(triples, "urn:step", TG_STEP_RESULT)
|
||||
assert _has_type(triples, "urn:step", TG_ANSWER_TYPE)
|
||||
|
||||
def test_links_to_plan(self):
|
||||
triples = agent_step_result_triples(
|
||||
"urn:step", "urn:plan", "Define X",
|
||||
)
|
||||
ts = _triple_set(triples)
|
||||
assert ("urn:step", PROV_WAS_DERIVED_FROM, "urn:plan") in ts
|
||||
|
||||
def test_includes_goal(self):
|
||||
triples = agent_step_result_triples(
|
||||
"urn:step", "urn:plan", "Define X",
|
||||
)
|
||||
values = _get_values(triples, "urn:step", TG_PLAN_STEP)
|
||||
assert "Define X" in values
|
||||
|
||||
def test_includes_document_when_provided(self):
|
||||
triples = agent_step_result_triples(
|
||||
"urn:step", "urn:plan", "goal",
|
||||
document_id="urn:doc/step",
|
||||
)
|
||||
values = _get_values(triples, "urn:step", TG_DOCUMENT)
|
||||
assert "urn:doc/step" in values
|
||||
|
||||
|
||||
class TestSynthesisTriples:
|
||||
|
||||
def test_has_correct_types(self):
|
||||
triples = agent_synthesis_triples(
|
||||
"urn:synthesis", "urn:previous",
|
||||
)
|
||||
assert _has_type(triples, "urn:synthesis", PROV_ENTITY)
|
||||
assert _has_type(triples, "urn:synthesis", TG_SYNTHESIS)
|
||||
assert _has_type(triples, "urn:synthesis", TG_ANSWER_TYPE)
|
||||
|
||||
def test_links_to_previous(self):
|
||||
triples = agent_synthesis_triples(
|
||||
"urn:synthesis", "urn:last-finding",
|
||||
)
|
||||
ts = _triple_set(triples)
|
||||
assert ("urn:synthesis", PROV_WAS_DERIVED_FROM,
|
||||
"urn:last-finding") in ts
|
||||
|
||||
def test_includes_document_when_provided(self):
|
||||
triples = agent_synthesis_triples(
|
||||
"urn:synthesis", "urn:previous",
|
||||
document_id="urn:doc/synthesis",
|
||||
)
|
||||
values = _get_values(triples, "urn:synthesis", TG_DOCUMENT)
|
||||
assert "urn:doc/synthesis" in values
|
||||
|
||||
def test_label_is_synthesis(self):
|
||||
triples = agent_synthesis_triples(
|
||||
"urn:synthesis", "urn:previous",
|
||||
)
|
||||
labels = _get_values(triples, "urn:synthesis", RDFS_LABEL)
|
||||
assert "Synthesis" in labels
|
||||
323
tests/unit/test_base/test_async_processor_config.py
Normal file
323
tests/unit/test_base/test_async_processor_config.py
Normal file
|
|
@ -0,0 +1,323 @@
|
|||
"""
|
||||
Tests for AsyncProcessor config notify pattern:
|
||||
- register_config_handler with types filtering
|
||||
- on_config_notify version comparison and type matching
|
||||
- fetch_config with short-lived client
|
||||
- fetch_and_apply_config retry logic
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, Mock
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
||||
|
||||
# Patch heavy dependencies before importing AsyncProcessor
|
||||
@pytest.fixture
|
||||
def processor():
|
||||
"""Create an AsyncProcessor with mocked dependencies."""
|
||||
with patch('trustgraph.base.async_processor.get_pubsub') as mock_pubsub, \
|
||||
patch('trustgraph.base.async_processor.Consumer') as mock_consumer, \
|
||||
patch('trustgraph.base.async_processor.ProcessorMetrics') as mock_pm, \
|
||||
patch('trustgraph.base.async_processor.ConsumerMetrics') as mock_cm:
|
||||
|
||||
mock_pubsub.return_value = MagicMock()
|
||||
mock_consumer.return_value = MagicMock()
|
||||
mock_pm.return_value = MagicMock()
|
||||
mock_cm.return_value = MagicMock()
|
||||
|
||||
from trustgraph.base.async_processor import AsyncProcessor
|
||||
p = AsyncProcessor(
|
||||
id="test-processor",
|
||||
taskgroup=AsyncMock(),
|
||||
)
|
||||
return p
|
||||
|
||||
|
||||
class TestRegisterConfigHandler:
|
||||
|
||||
def test_register_without_types(self, processor):
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler)
|
||||
|
||||
assert len(processor.config_handlers) == 1
|
||||
assert processor.config_handlers[0]["handler"] is handler
|
||||
assert processor.config_handlers[0]["types"] is None
|
||||
|
||||
def test_register_with_types(self, processor):
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
assert processor.config_handlers[0]["types"] == {"prompt"}
|
||||
|
||||
def test_register_multiple_types(self, processor):
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(
|
||||
handler, types=["schema", "collection"]
|
||||
)
|
||||
|
||||
assert processor.config_handlers[0]["types"] == {
|
||||
"schema", "collection"
|
||||
}
|
||||
|
||||
def test_register_multiple_handlers(self, processor):
|
||||
h1 = AsyncMock()
|
||||
h2 = AsyncMock()
|
||||
processor.register_config_handler(h1, types=["prompt"])
|
||||
processor.register_config_handler(h2, types=["schema"])
|
||||
|
||||
assert len(processor.config_handlers) == 2
|
||||
|
||||
|
||||
class TestOnConfigNotify:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_old_version(self, processor):
|
||||
processor.config_version = 5
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=3, types=["prompt"])
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_same_version(self, processor):
|
||||
processor.config_version = 5
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=5, types=["prompt"])
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_irrelevant_types(self, processor):
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["schema"])
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
# Version should still be updated
|
||||
assert processor.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_on_relevant_type(self, processor):
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler, types=["prompt"])
|
||||
|
||||
# Mock fetch_config
|
||||
mock_config = {"prompt": {"key": "value"}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_called_once_with(mock_config, 2)
|
||||
assert processor.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_without_types_always_called(self, processor):
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler) # No types = all
|
||||
|
||||
mock_config = {"anything": {}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["whatever"])
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_called_once_with(mock_config, 2)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_handlers_type_filtering(self, processor):
|
||||
processor.config_version = 1
|
||||
|
||||
prompt_handler = AsyncMock()
|
||||
schema_handler = AsyncMock()
|
||||
all_handler = AsyncMock()
|
||||
|
||||
processor.register_config_handler(prompt_handler, types=["prompt"])
|
||||
processor.register_config_handler(schema_handler, types=["schema"])
|
||||
processor.register_config_handler(all_handler)
|
||||
|
||||
mock_config = {"prompt": {}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
prompt_handler.assert_called_once()
|
||||
schema_handler.assert_not_called()
|
||||
all_handler.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_types_invokes_all(self, processor):
|
||||
"""Empty types list (startup signal) should invoke all handlers."""
|
||||
processor.config_version = 1
|
||||
|
||||
h1 = AsyncMock()
|
||||
h2 = AsyncMock()
|
||||
processor.register_config_handler(h1, types=["prompt"])
|
||||
processor.register_config_handler(h2, types=["schema"])
|
||||
|
||||
mock_config = {}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 2)
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=[])
|
||||
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
h1.assert_called_once()
|
||||
h2.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_failure_handled(self, processor):
|
||||
processor.config_version = 1
|
||||
|
||||
handler = AsyncMock()
|
||||
processor.register_config_handler(handler)
|
||||
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("Connection failed")
|
||||
):
|
||||
msg = Mock()
|
||||
msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
# Should not raise
|
||||
await processor.on_config_notify(msg, None, None)
|
||||
|
||||
handler.assert_not_called()
|
||||
|
||||
|
||||
class TestFetchConfig:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_returns_config_and_version(self, processor):
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.config = {"prompt": {"key": "val"}}
|
||||
mock_resp.version = 42
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
config, version = await processor.fetch_config()
|
||||
|
||||
assert config == {"prompt": {"key": "val"}}
|
||||
assert version == 42
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_raises_on_error_response(self, processor):
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = Mock(message="not found")
|
||||
mock_resp.config = {}
|
||||
mock_resp.version = 0
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="Config error"):
|
||||
await processor.fetch_config()
|
||||
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_stops_client_on_exception(self, processor):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.side_effect = TimeoutError("timeout")
|
||||
|
||||
with patch.object(
|
||||
processor, '_create_config_client', return_value=mock_client
|
||||
):
|
||||
with pytest.raises(TimeoutError):
|
||||
await processor.fetch_config()
|
||||
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestFetchAndApplyConfig:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_applies_config_to_all_handlers(self, processor):
|
||||
h1 = AsyncMock()
|
||||
h2 = AsyncMock()
|
||||
processor.register_config_handler(h1, types=["prompt"])
|
||||
processor.register_config_handler(h2, types=["schema"])
|
||||
|
||||
mock_config = {"prompt": {}, "schema": {}}
|
||||
with patch.object(
|
||||
processor, 'fetch_config',
|
||||
new_callable=AsyncMock,
|
||||
return_value=(mock_config, 10)
|
||||
):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
# On startup, all handlers are invoked regardless of type
|
||||
h1.assert_called_once_with(mock_config, 10)
|
||||
h2.assert_called_once_with(mock_config, 10)
|
||||
assert processor.config_version == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_failure(self, processor):
|
||||
call_count = 0
|
||||
mock_config = {"prompt": {}}
|
||||
|
||||
async def mock_fetch():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
raise RuntimeError("not ready")
|
||||
return mock_config, 5
|
||||
|
||||
with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \
|
||||
patch('asyncio.sleep', new_callable=AsyncMock):
|
||||
await processor.fetch_and_apply_config()
|
||||
|
||||
assert call_count == 3
|
||||
assert processor.config_version == 5
|
||||
|
|
@ -35,7 +35,9 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
|||
mock_async_init.assert_called_once()
|
||||
|
||||
# Verify register_config_handler was called with the correct handler
|
||||
mock_register_config.assert_called_once_with(processor.on_configure_flows)
|
||||
mock_register_config.assert_called_once_with(
|
||||
processor.on_configure_flows, types=["active-flow"]
|
||||
)
|
||||
|
||||
# Verify FlowProcessor-specific initialization
|
||||
assert hasattr(processor, 'flows')
|
||||
|
|
|
|||
|
|
@ -61,23 +61,21 @@ async def test_subscriber_deferred_acknowledgment_success():
|
|||
max_size=10,
|
||||
backpressure_strategy="block"
|
||||
)
|
||||
|
||||
# Start subscriber to initialize consumer
|
||||
await subscriber.start()
|
||||
|
||||
subscriber.consumer = mock_consumer
|
||||
|
||||
# Create queue for subscription
|
||||
queue = await subscriber.subscribe("test-queue")
|
||||
|
||||
|
||||
# Create mock message with matching queue name
|
||||
msg = create_mock_message("test-queue", {"data": "test"})
|
||||
|
||||
|
||||
# Process message
|
||||
await subscriber._process_message(msg)
|
||||
|
||||
|
||||
# Should acknowledge successful delivery
|
||||
mock_consumer.acknowledge.assert_called_once_with(msg)
|
||||
mock_consumer.negative_acknowledge.assert_not_called()
|
||||
|
||||
|
||||
# Message should be in queue
|
||||
assert not queue.empty()
|
||||
received_msg = await queue.get()
|
||||
|
|
@ -108,9 +106,7 @@ async def test_subscriber_dropped_message_still_acks():
|
|||
max_size=1, # Very small queue
|
||||
backpressure_strategy="drop_new"
|
||||
)
|
||||
|
||||
# Start subscriber to initialize consumer
|
||||
await subscriber.start()
|
||||
subscriber.consumer = mock_consumer
|
||||
|
||||
# Create queue and fill it
|
||||
queue = await subscriber.subscribe("test-queue")
|
||||
|
|
@ -151,9 +147,7 @@ async def test_subscriber_orphaned_message_acks():
|
|||
max_size=10,
|
||||
backpressure_strategy="block"
|
||||
)
|
||||
|
||||
# Start subscriber to initialize consumer
|
||||
await subscriber.start()
|
||||
subscriber.consumer = mock_consumer
|
||||
|
||||
# Don't create any queues - message will be orphaned
|
||||
# This simulates a response arriving after the waiter has unsubscribed
|
||||
|
|
@ -189,9 +183,7 @@ async def test_subscriber_backpressure_strategies():
|
|||
max_size=2,
|
||||
backpressure_strategy="drop_oldest"
|
||||
)
|
||||
|
||||
# Start subscriber to initialize consumer
|
||||
await subscriber.start()
|
||||
subscriber.consumer = mock_consumer
|
||||
|
||||
queue = await subscriber.subscribe("test-queue")
|
||||
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ class MockAsyncProcessor:
|
|||
class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Recursive chunker functionality"""
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_processor_initialization_basic(self, mock_producer, mock_consumer):
|
||||
"""Test basic processor initialization"""
|
||||
|
|
@ -51,8 +51,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
|
||||
assert len(param_specs) == 2
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with chunk-size parameter override"""
|
||||
|
|
@ -71,7 +71,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": 2000, # Override chunk size
|
||||
"chunk-overlap": None # Use default chunk overlap
|
||||
}.get(param)
|
||||
|
|
@ -85,8 +85,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 2000 # Should use overridden value
|
||||
assert chunk_overlap == 100 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with chunk-overlap parameter override"""
|
||||
|
|
@ -105,7 +105,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": None, # Use default chunk size
|
||||
"chunk-overlap": 200 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
|
@ -119,8 +119,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 1000 # Should use default value
|
||||
assert chunk_overlap == 200 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
|
||||
|
|
@ -139,7 +139,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": 1500, # Override chunk size
|
||||
"chunk-overlap": 150 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
|
@ -153,8 +153,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 1500 # Should use overridden value
|
||||
assert chunk_overlap == 150 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
|
||||
|
|
@ -177,7 +177,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Mock save_child_document to avoid waiting for librarian response
|
||||
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
|
||||
# Mock message with TextDocument
|
||||
mock_message = MagicMock()
|
||||
|
|
@ -196,12 +196,14 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_producer = AsyncMock()
|
||||
mock_triples_producer = AsyncMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": 1500,
|
||||
"chunk-overlap": 150,
|
||||
}.get(param)
|
||||
mock_flow.side_effect = lambda name: {
|
||||
"output": mock_producer,
|
||||
"triples": mock_triples_producer,
|
||||
}.get(param)
|
||||
}.get(name)
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_message, mock_consumer, mock_flow)
|
||||
|
|
@ -219,8 +221,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
sent_chunk = mock_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_chunk, Chunk)
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document when no parameters are overridden (flow returns None)"""
|
||||
|
|
@ -239,7 +241,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.return_value = None # No overrides
|
||||
mock_flow.parameters.get.return_value = None # No overrides
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ class MockAsyncProcessor:
|
|||
class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Token chunker functionality"""
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_processor_initialization_basic(self, mock_producer, mock_consumer):
|
||||
"""Test basic processor initialization"""
|
||||
|
|
@ -51,8 +51,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
|
||||
assert len(param_specs) == 2
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with chunk-size parameter override"""
|
||||
|
|
@ -71,7 +71,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": 400, # Override chunk size
|
||||
"chunk-overlap": None # Use default chunk overlap
|
||||
}.get(param)
|
||||
|
|
@ -85,8 +85,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 400 # Should use overridden value
|
||||
assert chunk_overlap == 15 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with chunk-overlap parameter override"""
|
||||
|
|
@ -105,7 +105,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": None, # Use default chunk size
|
||||
"chunk-overlap": 25 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
|
@ -119,8 +119,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 250 # Should use default value
|
||||
assert chunk_overlap == 25 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
|
||||
|
|
@ -139,7 +139,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": 350, # Override chunk size
|
||||
"chunk-overlap": 30 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
|
@ -153,8 +153,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 350 # Should use overridden value
|
||||
assert chunk_overlap == 30 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.chunking.token.chunker.TokenTextSplitter')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
|
||||
|
|
@ -177,7 +177,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Mock save_child_document to avoid librarian producer interactions
|
||||
processor.save_child_document = AsyncMock(return_value="chunk-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="chunk-id")
|
||||
|
||||
# Mock message with TextDocument
|
||||
mock_message = MagicMock()
|
||||
|
|
@ -196,12 +196,14 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_producer = AsyncMock()
|
||||
mock_triples_producer = AsyncMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
"chunk-size": 400,
|
||||
"chunk-overlap": 40,
|
||||
}.get(param)
|
||||
mock_flow.side_effect = lambda name: {
|
||||
"output": mock_producer,
|
||||
"triples": mock_triples_producer,
|
||||
}.get(param)
|
||||
}.get(name)
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_message, mock_consumer, mock_flow)
|
||||
|
|
@ -223,8 +225,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
sent_chunk = mock_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_chunk, Chunk)
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document when no parameters are overridden (flow returns None)"""
|
||||
|
|
@ -243,7 +245,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.return_value = None # No overrides
|
||||
mock_flow.parameters.get.return_value = None # No overrides
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
@ -254,8 +256,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 250 # Should use default value
|
||||
assert chunk_overlap == 15 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_token_chunker_uses_different_defaults(self, mock_producer, mock_consumer):
|
||||
"""Test that token chunker has different defaults than recursive chunker"""
|
||||
|
|
|
|||
|
|
@ -21,17 +21,15 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
|
||||
# Act
|
||||
client = DocumentEmbeddingsClient(
|
||||
log_level=1,
|
||||
subscriber="test-subscriber",
|
||||
input_queue="test-input",
|
||||
output_queue="test-output",
|
||||
pulsar_host="pulsar://test:6650",
|
||||
pulsar_api_key="test-key"
|
||||
)
|
||||
|
||||
|
||||
# Assert
|
||||
mock_base_init.assert_called_once_with(
|
||||
log_level=1,
|
||||
subscriber="test-subscriber",
|
||||
input_queue="test-input",
|
||||
output_queue="test-output",
|
||||
|
|
|
|||
|
|
@ -81,9 +81,8 @@ class TestTaskGroupConcurrency:
|
|||
|
||||
# Track how many consume_from_queue calls are made
|
||||
call_count = 0
|
||||
original_running = True
|
||||
|
||||
async def mock_consume():
|
||||
async def mock_consume(backend_consumer, executor=None):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# Wait a bit to let all tasks start, then signal stop
|
||||
|
|
@ -107,7 +106,7 @@ class TestTaskGroupConcurrency:
|
|||
consumer = _make_consumer(concurrency=1)
|
||||
call_count = 0
|
||||
|
||||
async def mock_consume():
|
||||
async def mock_consume(backend_consumer, executor=None):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -147,7 +146,7 @@ class TestRateLimitRetry:
|
|||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
|
||||
|
||||
assert call_count == 2
|
||||
consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
|
||||
|
|
@ -166,7 +165,7 @@ class TestRateLimitRetry:
|
|||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
|
||||
|
||||
consumer.consumer.negative_acknowledge.assert_called_with(mock_msg)
|
||||
consumer.consumer.acknowledge.assert_not_called()
|
||||
|
|
@ -185,7 +184,7 @@ class TestRateLimitRetry:
|
|||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
|
||||
|
||||
assert call_count == 1
|
||||
consumer.consumer.negative_acknowledge.assert_called_once_with(mock_msg)
|
||||
|
|
@ -197,7 +196,7 @@ class TestRateLimitRetry:
|
|||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
|
||||
|
||||
consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
|
||||
|
||||
|
|
@ -219,7 +218,7 @@ class TestMetricsIntegration:
|
|||
mock_metrics.record_time.return_value.__exit__ = MagicMock()
|
||||
consumer.metrics = mock_metrics
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
|
||||
|
||||
mock_metrics.process.assert_called_once_with("success")
|
||||
|
||||
|
|
@ -235,7 +234,7 @@ class TestMetricsIntegration:
|
|||
mock_metrics = MagicMock()
|
||||
consumer.metrics = mock_metrics
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
|
||||
|
||||
mock_metrics.process.assert_called_once_with("error")
|
||||
|
||||
|
|
@ -261,7 +260,7 @@ class TestMetricsIntegration:
|
|||
mock_metrics.record_time.return_value.__exit__ = MagicMock(return_value=False)
|
||||
consumer.metrics = mock_metrics
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
|
||||
|
||||
mock_metrics.rate_limit.assert_called_once()
|
||||
|
||||
|
|
@ -294,9 +293,8 @@ class TestPollTimeout:
|
|||
raise type('Timeout', (Exception,), {})("timeout")
|
||||
|
||||
mock_pulsar_consumer.receive = capture_receive
|
||||
consumer.consumer = mock_pulsar_consumer
|
||||
|
||||
await consumer.consume_from_queue()
|
||||
await consumer.consume_from_queue(mock_pulsar_consumer)
|
||||
|
||||
assert received_kwargs.get("timeout_millis") == 100
|
||||
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ class MockAsyncProcessor:
|
|||
class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
||||
"""Test Mistral OCR processor functionality"""
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_initialization_with_api_key(
|
||||
|
|
@ -51,8 +51,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
|||
assert consumer_specs[0].name == "input"
|
||||
assert consumer_specs[0].schema == Document
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_initialization_without_api_key(
|
||||
self, mock_producer, mock_consumer
|
||||
|
|
@ -66,8 +66,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
|||
with pytest.raises(RuntimeError, match="Mistral API key not specified"):
|
||||
Processor(**config)
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_ocr_single_chunk(
|
||||
|
|
@ -131,8 +131,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
|||
)
|
||||
mock_mistral.ocr.process.assert_called_once()
|
||||
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_success(
|
||||
|
|
@ -172,7 +172,7 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
|
|||
]
|
||||
|
||||
# Mock save_child_document
|
||||
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
|
||||
with patch.object(processor, 'ocr', return_value=ocr_result):
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
|
|
|||
|
|
@ -24,12 +24,10 @@ class MockAsyncProcessor:
|
|||
class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
||||
"""Test PDF decoder processor functionality"""
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_initialization(self, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
|
||||
async def test_processor_initialization(self, mock_producer, mock_consumer):
|
||||
"""Test PDF decoder processor initialization"""
|
||||
config = {
|
||||
'id': 'test-pdf-decoder',
|
||||
|
|
@ -44,13 +42,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
|||
assert consumer_specs[0].name == "input"
|
||||
assert consumer_specs[0].schema == Document
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
|
||||
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
||||
"""Test successful PDF processing"""
|
||||
# Mock PDF content
|
||||
pdf_content = b"fake pdf content"
|
||||
|
|
@ -85,7 +81,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Mock save_child_document to avoid waiting for librarian response
|
||||
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
|
|
@ -94,13 +90,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
|||
# Verify triples were sent for each page (provenance)
|
||||
assert mock_triples_flow.send.call_count == 2
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
|
||||
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
||||
"""Test handling of empty PDF"""
|
||||
pdf_content = b"fake pdf content"
|
||||
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
||||
|
|
@ -128,13 +122,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
|||
|
||||
mock_output_flow.send.assert_not_called()
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
|
||||
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
||||
"""Test handling of unicode content in PDF"""
|
||||
pdf_content = b"fake pdf content"
|
||||
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
||||
|
|
@ -165,7 +157,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Mock save_child_document to avoid waiting for librarian response
|
||||
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
|
|
|
|||
|
|
@ -142,8 +142,8 @@ class TestPageBasedFormats:
|
|||
class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
||||
"""Test universal decoder processor."""
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_initialization(
|
||||
self, mock_producer, mock_consumer
|
||||
|
|
@ -169,8 +169,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
assert consumer_specs[0].name == "input"
|
||||
assert consumer_specs[0].schema == Document
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_custom_strategy(
|
||||
self, mock_producer, mock_consumer
|
||||
|
|
@ -188,8 +188,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
assert processor.partition_strategy == "hi_res"
|
||||
assert processor.section_strategy_name == "heading"
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_group_by_page(self, mock_producer, mock_consumer):
|
||||
"""Test page grouping of elements."""
|
||||
|
|
@ -214,8 +214,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
assert result[1][0] == 2
|
||||
assert len(result[1][1]) == 1
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.universal.processor.partition')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_inline_non_page(
|
||||
|
|
@ -255,7 +255,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
}.get(name))
|
||||
|
||||
# Mock save_child_document and magic
|
||||
processor.save_child_document = AsyncMock(return_value="mock-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
|
||||
|
||||
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
|
||||
mock_magic.from_buffer.return_value = "text/markdown"
|
||||
|
|
@ -271,8 +271,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
assert call_args.document_id.startswith("urn:section:")
|
||||
assert call_args.text == b""
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.universal.processor.partition')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_page_based(
|
||||
|
|
@ -310,7 +310,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
"triples": mock_triples_flow,
|
||||
}.get(name))
|
||||
|
||||
processor.save_child_document = AsyncMock(return_value="mock-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
|
||||
|
||||
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
|
||||
mock_magic.from_buffer.return_value = "application/pdf"
|
||||
|
|
@ -323,8 +323,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
call_args = mock_output_flow.send.call_args_list[0][0][0]
|
||||
assert call_args.document_id.startswith("urn:page:")
|
||||
|
||||
@patch('trustgraph.decoding.universal.processor.Consumer')
|
||||
@patch('trustgraph.decoding.universal.processor.Producer')
|
||||
@patch('trustgraph.base.librarian_client.Consumer')
|
||||
@patch('trustgraph.base.librarian_client.Producer')
|
||||
@patch('trustgraph.decoding.universal.processor.partition')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_images_stored_not_emitted(
|
||||
|
|
@ -361,7 +361,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
"triples": mock_triples_flow,
|
||||
}.get(name))
|
||||
|
||||
processor.save_child_document = AsyncMock(return_value="mock-id")
|
||||
processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
|
||||
|
||||
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
|
||||
mock_magic.from_buffer.return_value = "application/pdf"
|
||||
|
|
@ -374,7 +374,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
|
|||
assert mock_triples_flow.send.call_count == 2
|
||||
|
||||
# save_child_document called twice (page + image)
|
||||
assert processor.save_child_document.call_count == 2
|
||||
assert processor.librarian.save_child_document.call_count == 2
|
||||
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
|
||||
def test_add_args(self, mock_parent_add_args):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ Tests for Gateway Config Receiver
|
|||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import Mock, patch, Mock, MagicMock
|
||||
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
||||
import uuid
|
||||
|
||||
from trustgraph.gateway.config.receiver import ConfigReceiver
|
||||
|
|
@ -23,174 +23,237 @@ class TestConfigReceiver:
|
|||
def test_config_receiver_initialization(self):
|
||||
"""Test ConfigReceiver initialization"""
|
||||
mock_backend = Mock()
|
||||
|
||||
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
|
||||
assert config_receiver.backend == mock_backend
|
||||
assert config_receiver.flow_handlers == []
|
||||
assert config_receiver.flows == {}
|
||||
assert config_receiver.config_version == 0
|
||||
|
||||
def test_add_handler(self):
|
||||
"""Test adding flow handlers"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
|
||||
handler1 = Mock()
|
||||
handler2 = Mock()
|
||||
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
|
||||
assert len(config_receiver.flow_handlers) == 2
|
||||
assert handler1 in config_receiver.flow_handlers
|
||||
assert handler2 in config_receiver.flow_handlers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_new_flows(self):
|
||||
"""Test on_config method with new flows"""
|
||||
async def test_on_config_notify_new_version(self):
|
||||
"""Test on_config_notify triggers fetch for newer version"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Track calls manually instead of using AsyncMock
|
||||
start_flow_calls = []
|
||||
|
||||
async def mock_start_flow(*args):
|
||||
start_flow_calls.append(args)
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
# Create mock message with flows
|
||||
config_receiver.config_version = 1
|
||||
|
||||
# Mock fetch_and_apply
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with newer version
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1", "steps": []}',
|
||||
"flow2": '{"name": "test_flow_2", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flows were added
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert config_receiver.flows["flow1"] == {"name": "test_flow_1", "steps": []}
|
||||
assert config_receiver.flows["flow2"] == {"name": "test_flow_2", "steps": []}
|
||||
|
||||
# Verify start_flow was called for each new flow
|
||||
assert len(start_flow_calls) == 2
|
||||
assert ("flow1", {"name": "test_flow_1", "steps": []}) in start_flow_calls
|
||||
assert ("flow2", {"name": "test_flow_2", "steps": []}) in start_flow_calls
|
||||
mock_msg.value.return_value = Mock(version=2, types=["flow"])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
assert len(fetch_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_removed_flows(self):
|
||||
"""Test on_config method with removed flows"""
|
||||
async def test_on_config_notify_old_version_ignored(self):
|
||||
"""Test on_config_notify ignores older versions"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1", "steps": []},
|
||||
"flow2": {"name": "test_flow_2", "steps": []}
|
||||
}
|
||||
|
||||
# Track calls manually instead of using AsyncMock
|
||||
stop_flow_calls = []
|
||||
|
||||
async def mock_stop_flow(*args):
|
||||
stop_flow_calls.append(args)
|
||||
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
# Create mock message with only flow1 (flow2 removed)
|
||||
config_receiver.config_version = 5
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with older version
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flow2 was removed
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" not in config_receiver.flows
|
||||
|
||||
# Verify stop_flow was called for removed flow
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0] == ("flow2", {"name": "test_flow_2", "steps": []})
|
||||
mock_msg.value.return_value = Mock(version=3, types=["flow"])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
assert len(fetch_calls) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_no_flows(self):
|
||||
"""Test on_config method with no flows in config"""
|
||||
async def test_on_config_notify_irrelevant_types_ignored(self):
|
||||
"""Test on_config_notify ignores types the gateway doesn't care about"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock the start_flow and stop_flow methods with async functions
|
||||
async def mock_start_flow(*args):
|
||||
pass
|
||||
async def mock_stop_flow(*args):
|
||||
pass
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
# Create mock message without flows
|
||||
config_receiver.config_version = 1
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with non-flow type
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify no flows were added
|
||||
assert config_receiver.flows == {}
|
||||
|
||||
# Since no flows were in the config, the flow methods shouldn't be called
|
||||
# (We can't easily assert this with simple async functions, but the test
|
||||
# passes if no exceptions are thrown)
|
||||
mock_msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
# Version should be updated but no fetch
|
||||
assert len(fetch_calls) == 0
|
||||
assert config_receiver.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_exception_handling(self):
|
||||
"""Test on_config method handles exceptions gracefully"""
|
||||
async def test_on_config_notify_flow_type_triggers_fetch(self):
|
||||
"""Test on_config_notify fetches for flow-related types"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Create mock message that will cause an exception
|
||||
config_receiver.config_version = 1
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
for type_name in ["flow", "active-flow"]:
|
||||
fetch_calls.clear()
|
||||
config_receiver.config_version = 1
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=[type_name])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_exception_handling(self):
|
||||
"""Test on_config_notify handles exceptions gracefully"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Create notify message that causes an exception
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.side_effect = Exception("Test exception")
|
||||
|
||||
# This should not raise an exception
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flows remain empty
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_new_flows(self):
|
||||
"""Test fetch_and_apply starts new flows"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock config_client
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1"}',
|
||||
"flow2": '{"name": "test_flow_2"}'
|
||||
}
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
config_receiver.config_client = mock_client
|
||||
|
||||
start_flow_calls = []
|
||||
async def mock_start_flow(id, flow):
|
||||
start_flow_calls.append((id, flow))
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
|
||||
assert config_receiver.config_version == 5
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert len(start_flow_calls) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_removed_flows(self):
|
||||
"""Test fetch_and_apply stops removed flows"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"}
|
||||
}
|
||||
|
||||
# Config now only has flow1
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1"}'
|
||||
}
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
config_receiver.config_client = mock_client
|
||||
|
||||
stop_flow_calls = []
|
||||
async def mock_stop_flow(id, flow):
|
||||
stop_flow_calls.append((id, flow))
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" not in config_receiver.flows
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0][0] == "flow2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_no_flows(self):
|
||||
"""Test fetch_and_apply with empty config"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 1
|
||||
mock_resp.config = {}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
config_receiver.config_client = mock_client
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
|
||||
assert config_receiver.flows == {}
|
||||
assert config_receiver.config_version == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handlers(self):
|
||||
"""Test start_flow method with multiple handlers"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Add mock handlers
|
||||
|
||||
handler1 = Mock()
|
||||
handler1.start_flow = Mock()
|
||||
handler2 = Mock()
|
||||
handler2.start_flow = Mock()
|
||||
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
|
||||
# Verify all handlers were called
|
||||
|
||||
handler1.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
|
|
@ -199,19 +262,17 @@ class TestConfigReceiver:
|
|||
"""Test start_flow method handles handler exceptions"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Add mock handler that raises exception
|
||||
|
||||
handler = Mock()
|
||||
handler.start_flow = Mock(side_effect=Exception("Handler error"))
|
||||
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# This should not raise an exception
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
|
||||
# Verify handler was called
|
||||
|
||||
handler.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -219,21 +280,19 @@ class TestConfigReceiver:
|
|||
"""Test stop_flow method with multiple handlers"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Add mock handlers
|
||||
|
||||
handler1 = Mock()
|
||||
handler1.stop_flow = Mock()
|
||||
handler2 = Mock()
|
||||
handler2.stop_flow = Mock()
|
||||
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
|
||||
# Verify all handlers were called
|
||||
|
||||
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
|
|
@ -242,167 +301,77 @@ class TestConfigReceiver:
|
|||
"""Test stop_flow method handles handler exceptions"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Add mock handler that raises exception
|
||||
|
||||
handler = Mock()
|
||||
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
|
||||
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# This should not raise an exception
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
|
||||
# Verify handler was called
|
||||
|
||||
handler.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_loader_creates_consumer(self):
|
||||
"""Test config_loader method creates Pulsar consumer"""
|
||||
mock_backend = Mock()
|
||||
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
# Temporarily restore the real config_loader for this test
|
||||
config_receiver.config_loader = _real_config_loader.__get__(config_receiver)
|
||||
|
||||
# Mock Consumer class
|
||||
with patch('trustgraph.gateway.config.receiver.Consumer') as mock_consumer_class, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
mock_consumer = Mock()
|
||||
async def mock_start():
|
||||
pass
|
||||
mock_consumer.start = mock_start
|
||||
mock_consumer_class.return_value = mock_consumer
|
||||
|
||||
# Create a task that will complete quickly
|
||||
async def quick_task():
|
||||
await config_receiver.config_loader()
|
||||
|
||||
# Run the task with a timeout to prevent hanging
|
||||
try:
|
||||
await asyncio.wait_for(quick_task(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
# This is expected since the method runs indefinitely
|
||||
pass
|
||||
|
||||
# Verify Consumer was created with correct parameters
|
||||
mock_consumer_class.assert_called_once()
|
||||
call_args = mock_consumer_class.call_args
|
||||
|
||||
assert call_args[1]['backend'] == mock_backend
|
||||
assert call_args[1]['subscriber'] == "gateway-test-uuid"
|
||||
assert call_args[1]['handler'] == config_receiver.on_config
|
||||
assert call_args[1]['start_of_messages'] is True
|
||||
|
||||
@patch('asyncio.create_task')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_creates_config_loader_task(self, mock_create_task):
|
||||
"""Test start method creates config loader task"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock create_task to avoid actually creating tasks with real coroutines
|
||||
|
||||
mock_task = Mock()
|
||||
mock_create_task.return_value = mock_task
|
||||
|
||||
|
||||
await config_receiver.start()
|
||||
|
||||
# Verify task was created
|
||||
|
||||
mock_create_task.assert_called_once()
|
||||
|
||||
# Verify the argument passed to create_task is a coroutine
|
||||
call_args = mock_create_task.call_args[0]
|
||||
assert len(call_args) == 1 # Should have one argument (the coroutine)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_mixed_flow_operations(self):
|
||||
"""Test on_config with mixed add/remove operations"""
|
||||
async def test_fetch_and_apply_mixed_flow_operations(self):
|
||||
"""Test fetch_and_apply with mixed add/remove operations"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
|
||||
# Pre-populate
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1", "steps": []},
|
||||
"flow2": {"name": "test_flow_2", "steps": []}
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"}
|
||||
}
|
||||
|
||||
# Track calls manually instead of using Mock
|
||||
start_flow_calls = []
|
||||
stop_flow_calls = []
|
||||
|
||||
async def mock_start_flow(*args):
|
||||
start_flow_calls.append(args)
|
||||
|
||||
async def mock_stop_flow(*args):
|
||||
stop_flow_calls.append(args)
|
||||
|
||||
# Directly assign to avoid patch.object detecting async methods
|
||||
original_start_flow = config_receiver.start_flow
|
||||
original_stop_flow = config_receiver.stop_flow
|
||||
|
||||
# Config removes flow1, keeps flow2, adds flow3
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow2": '{"name": "test_flow_2"}',
|
||||
"flow3": '{"name": "test_flow_3"}'
|
||||
}
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
config_receiver.config_client = mock_client
|
||||
|
||||
start_calls = []
|
||||
stop_calls = []
|
||||
|
||||
async def mock_start_flow(id, flow):
|
||||
start_calls.append((id, flow))
|
||||
async def mock_stop_flow(id, flow):
|
||||
stop_calls.append((id, flow))
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
try:
|
||||
|
||||
# Create mock message with flow1 removed and flow3 added
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flow": {
|
||||
"flow2": '{"name": "test_flow_2", "steps": []}',
|
||||
"flow3": '{"name": "test_flow_3", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify final state
|
||||
assert "flow1" not in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow3" in config_receiver.flows
|
||||
|
||||
# Verify operations
|
||||
assert len(start_flow_calls) == 1
|
||||
assert start_flow_calls[0] == ("flow3", {"name": "test_flow_3", "steps": []})
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0] == ("flow1", {"name": "test_flow_1", "steps": []})
|
||||
|
||||
finally:
|
||||
# Restore original methods
|
||||
config_receiver.start_flow = original_start_flow
|
||||
config_receiver.stop_flow = original_stop_flow
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_invalid_json_flow_data(self):
|
||||
"""Test on_config handles invalid JSON in flow data"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock the start_flow method with an async function
|
||||
async def mock_start_flow(*args):
|
||||
pass
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
# Create mock message with invalid JSON
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flow": {
|
||||
"flow1": '{"invalid": json}', # Invalid JSON
|
||||
"flow2": '{"name": "valid_flow", "steps": []}' # Valid JSON
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# This should handle the exception gracefully
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# The entire operation should fail due to JSON parsing error
|
||||
# So no flows should be added
|
||||
assert config_receiver.flows == {}
|
||||
await config_receiver.fetch_and_apply()
|
||||
|
||||
assert "flow1" not in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow3" in config_receiver.flows
|
||||
assert len(start_calls) == 1
|
||||
assert start_calls[0][0] == "flow3"
|
||||
assert len(stop_calls) == 1
|
||||
assert stop_calls[0][0] == "flow1"
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class TestConfigRequestor:
|
|||
mock_translator_registry.get_response_translator.return_value = Mock()
|
||||
|
||||
# Setup translator response
|
||||
mock_request_translator.to_pulsar.return_value = "translated_request"
|
||||
mock_request_translator.decode.return_value = "translated_request"
|
||||
|
||||
# Patch ServiceRequestor async methods with regular mocks (not AsyncMock)
|
||||
with patch.object(ServiceRequestor, 'start', return_value=None), \
|
||||
|
|
@ -64,7 +64,7 @@ class TestConfigRequestor:
|
|||
result = requestor.to_request({"test": "body"})
|
||||
|
||||
# Verify translator was called correctly
|
||||
mock_request_translator.to_pulsar.assert_called_once_with({"test": "body"})
|
||||
mock_request_translator.decode.assert_called_once_with({"test": "body"})
|
||||
assert result == "translated_request"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.config.TranslatorRegistry')
|
||||
|
|
@ -76,7 +76,7 @@ class TestConfigRequestor:
|
|||
mock_translator_registry.get_response_translator.return_value = mock_response_translator
|
||||
|
||||
# Setup translator response
|
||||
mock_response_translator.from_response_with_completion.return_value = "translated_response"
|
||||
mock_response_translator.encode_with_completion.return_value = "translated_response"
|
||||
|
||||
requestor = ConfigRequestor(
|
||||
backend=Mock(),
|
||||
|
|
@ -89,5 +89,5 @@ class TestConfigRequestor:
|
|||
result = requestor.from_response(mock_message)
|
||||
|
||||
# Verify translator was called correctly
|
||||
mock_response_translator.from_response_with_completion.assert_called_once_with(mock_message)
|
||||
mock_response_translator.encode_with_completion.assert_called_once_with(mock_message)
|
||||
assert result == "translated_response"
|
||||
359
tests/unit/test_gateway/test_explain_triples.py
Normal file
359
tests/unit/test_gateway/test_explain_triples.py
Normal file
|
|
@ -0,0 +1,359 @@
|
|||
"""
|
||||
Tests for inline explainability triples in response translators
|
||||
and ProvenanceEvent parsing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from trustgraph.schema import (
|
||||
GraphRagResponse, DocumentRagResponse, AgentResponse,
|
||||
Term, Triple, IRI, LITERAL, Error,
|
||||
)
|
||||
from trustgraph.messaging.translators.retrieval import (
|
||||
GraphRagResponseTranslator,
|
||||
DocumentRagResponseTranslator,
|
||||
)
|
||||
from trustgraph.messaging.translators.agent import (
|
||||
AgentResponseTranslator,
|
||||
)
|
||||
from trustgraph.api.types import ProvenanceEvent
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
def make_triple(s_iri, p_iri, o_value, o_type=LITERAL):
|
||||
"""Create a Triple with IRI subject/predicate and typed object."""
|
||||
o = Term(type=IRI, iri=o_value) if o_type == IRI else Term(type=LITERAL, value=o_value)
|
||||
return Triple(
|
||||
s=Term(type=IRI, iri=s_iri),
|
||||
p=Term(type=IRI, iri=p_iri),
|
||||
o=o,
|
||||
)
|
||||
|
||||
|
||||
def sample_triples():
|
||||
"""A few provenance triples for a question entity."""
|
||||
return [
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc123",
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
"https://trustgraph.ai/ns/GraphRagQuestion",
|
||||
o_type=IRI,
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc123",
|
||||
"https://trustgraph.ai/ns/query",
|
||||
"What is the internet?",
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc123",
|
||||
"http://www.w3.org/ns/prov#startedAtTime",
|
||||
"2026-04-07T09:00:00Z",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# --- GraphRag Translator ---
|
||||
|
||||
class TestGraphRagExplainTriples:
|
||||
|
||||
def test_explain_triples_encoded(self):
|
||||
translator = GraphRagResponseTranslator()
|
||||
triples = sample_triples()
|
||||
|
||||
response = GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=triples,
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
|
||||
assert "explain_triples" in result
|
||||
assert len(result["explain_triples"]) == 3
|
||||
|
||||
# Check first triple is properly encoded
|
||||
t = result["explain_triples"][0]
|
||||
assert t["s"]["t"] == "i"
|
||||
assert t["s"]["i"] == "urn:trustgraph:question:abc123"
|
||||
assert t["p"]["t"] == "i"
|
||||
|
||||
def test_explain_triples_empty_not_included(self):
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
||||
response = GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="Some answer text",
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
|
||||
assert "explain_triples" not in result
|
||||
|
||||
def test_explain_with_completion_returns_not_final(self):
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
||||
response = GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
explain_triples=sample_triples(),
|
||||
end_of_session=False,
|
||||
)
|
||||
|
||||
result, is_final = translator.encode_with_completion(response)
|
||||
assert is_final is False
|
||||
|
||||
def test_explain_id_and_graph_included(self):
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
||||
response = GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=sample_triples(),
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
assert result["explain_id"] == "urn:trustgraph:question:abc123"
|
||||
assert result["explain_graph"] == "urn:graph:retrieval"
|
||||
|
||||
|
||||
# --- DocumentRag Translator ---
|
||||
|
||||
class TestDocumentRagExplainTriples:
|
||||
|
||||
def test_explain_triples_encoded(self):
|
||||
translator = DocumentRagResponseTranslator()
|
||||
|
||||
response = DocumentRagResponse(
|
||||
response=None,
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:docrag:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=sample_triples(),
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
|
||||
assert "explain_triples" in result
|
||||
assert len(result["explain_triples"]) == 3
|
||||
|
||||
def test_explain_triples_empty_not_included(self):
|
||||
translator = DocumentRagResponseTranslator()
|
||||
|
||||
response = DocumentRagResponse(
|
||||
response="Answer text",
|
||||
message_type="chunk",
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
assert "explain_triples" not in result
|
||||
|
||||
|
||||
# --- Agent Translator ---
|
||||
|
||||
class TestAgentExplainTriples:
|
||||
|
||||
def test_explain_triples_encoded(self):
|
||||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="explain",
|
||||
content="",
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=sample_triples(),
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
|
||||
assert "explain_triples" in result
|
||||
assert len(result["explain_triples"]) == 3
|
||||
|
||||
t = result["explain_triples"][1]
|
||||
assert t["p"]["i"] == "https://trustgraph.ai/ns/query"
|
||||
assert t["o"]["t"] == "l"
|
||||
assert t["o"]["v"] == "What is the internet?"
|
||||
|
||||
def test_explain_triples_empty_not_included(self):
|
||||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="thought",
|
||||
content="I need to think...",
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
assert "explain_triples" not in result
|
||||
|
||||
def test_explain_with_completion_not_final(self):
|
||||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="explain",
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
explain_triples=sample_triples(),
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
||||
result, is_final = translator.encode_with_completion(response)
|
||||
assert is_final is False
|
||||
|
||||
def test_explain_with_completion_final(self):
|
||||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="answer",
|
||||
content="The answer is...",
|
||||
end_of_dialog=True,
|
||||
)
|
||||
|
||||
result, is_final = translator.encode_with_completion(response)
|
||||
assert is_final is True
|
||||
|
||||
|
||||
# --- ProvenanceEvent ---
|
||||
|
||||
class TestProvenanceEvent:
|
||||
|
||||
def test_question_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
)
|
||||
assert event.event_type == "question"
|
||||
|
||||
def test_exploration_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:exploration:abc123",
|
||||
)
|
||||
assert event.event_type == "exploration"
|
||||
|
||||
def test_focus_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:focus:abc123",
|
||||
)
|
||||
assert event.event_type == "focus"
|
||||
|
||||
def test_synthesis_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:synthesis:abc123",
|
||||
)
|
||||
assert event.event_type == "synthesis"
|
||||
|
||||
def test_grounding_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:grounding:abc123",
|
||||
)
|
||||
assert event.event_type == "grounding"
|
||||
|
||||
def test_session_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
)
|
||||
assert event.event_type == "session"
|
||||
|
||||
def test_iteration_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:iteration:abc123:1",
|
||||
)
|
||||
assert event.event_type == "iteration"
|
||||
|
||||
def test_observation_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:observation:abc123:1",
|
||||
)
|
||||
assert event.event_type == "observation"
|
||||
|
||||
def test_conclusion_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:conclusion:abc123",
|
||||
)
|
||||
assert event.event_type == "conclusion"
|
||||
|
||||
def test_decomposition_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:decomposition:abc123",
|
||||
)
|
||||
assert event.event_type == "decomposition"
|
||||
|
||||
def test_finding_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:finding:abc123:0",
|
||||
)
|
||||
assert event.event_type == "finding"
|
||||
|
||||
def test_plan_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:plan:abc123",
|
||||
)
|
||||
assert event.event_type == "plan"
|
||||
|
||||
def test_step_result_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:step-result:abc123:0",
|
||||
)
|
||||
assert event.event_type == "step-result"
|
||||
|
||||
def test_defaults(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
)
|
||||
assert event.entity is None
|
||||
assert event.triples == []
|
||||
assert event.explain_graph == ""
|
||||
|
||||
def test_with_triples(self):
|
||||
raw = [{"s": {"t": "i", "i": "urn:x"}, "p": {"t": "i", "i": "urn:y"}, "o": {"t": "l", "v": "z"}}]
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
triples=raw,
|
||||
)
|
||||
assert len(event.triples) == 1
|
||||
|
||||
|
||||
# --- Build ProvenanceEvent with entity parsing ---
|
||||
|
||||
class TestBuildProvenanceEvent:
|
||||
|
||||
def _make_client(self):
|
||||
"""Create a minimal WebSocketClient-like object with _build_provenance_event."""
|
||||
from trustgraph.api.socket_client import WebSocketClient
|
||||
# We can't instantiate WebSocketClient easily, so test the method logic directly
|
||||
return None
|
||||
|
||||
def test_entity_parsed_from_wire_triples(self):
|
||||
"""Test that wire-format triples are parsed into an ExplainEntity."""
|
||||
from trustgraph.api.explainability import ExplainEntity
|
||||
|
||||
wire_triples = [
|
||||
{
|
||||
"s": {"t": "i", "i": "urn:trustgraph:question:abc123"},
|
||||
"p": {"t": "i", "i": "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"},
|
||||
"o": {"t": "i", "i": "https://trustgraph.ai/ns/GraphRagQuestion"},
|
||||
},
|
||||
{
|
||||
"s": {"t": "i", "i": "urn:trustgraph:question:abc123"},
|
||||
"p": {"t": "i", "i": "https://trustgraph.ai/ns/query"},
|
||||
"o": {"t": "l", "v": "What is the internet?"},
|
||||
},
|
||||
]
|
||||
|
||||
# Parse triples the same way _build_provenance_event does
|
||||
parsed = []
|
||||
for t in wire_triples:
|
||||
s = t.get("s", {}).get("i", "")
|
||||
p = t.get("p", {}).get("i", "")
|
||||
o_term = t.get("o", {})
|
||||
if o_term.get("t") == "i":
|
||||
o = o_term.get("i", "")
|
||||
else:
|
||||
o = o_term.get("v", "")
|
||||
parsed.append((s, p, o))
|
||||
|
||||
entity = ExplainEntity.from_triples(
|
||||
"urn:trustgraph:question:abc123", parsed
|
||||
)
|
||||
|
||||
assert entity.entity_type == "question"
|
||||
assert entity.query == "What is the internet?"
|
||||
assert entity.question_type == "graph-rag"
|
||||
|
|
@ -25,7 +25,7 @@ from trustgraph.schema import (
|
|||
class TestGraphRagResponseTranslator:
|
||||
"""Test GraphRagResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_with_empty_response(self):
|
||||
def test_encode_with_empty_response(self):
|
||||
"""Test that empty response strings are preserved"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
|
@ -36,14 +36,14 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert - Empty string should be included in result
|
||||
assert "response" in result
|
||||
assert result["response"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_with_non_empty_response(self):
|
||||
def test_encode_with_non_empty_response(self):
|
||||
"""Test that non-empty responses work correctly"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
|
@ -54,13 +54,13 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert result["response"] == "Some text"
|
||||
assert result["end_of_stream"] is False
|
||||
|
||||
def test_from_pulsar_with_none_response(self):
|
||||
def test_encode_with_none_response(self):
|
||||
"""Test that None response is handled correctly"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
|
@ -71,14 +71,14 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert - None should not be included
|
||||
assert "response" not in result
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_response_with_completion_returns_correct_flag(self):
|
||||
"""Test that from_response_with_completion returns correct is_final flag"""
|
||||
def test_encode_with_completion_returns_correct_flag(self):
|
||||
"""Test that encode_with_completion returns correct is_final flag"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
||||
|
|
@ -90,7 +90,7 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(response_chunk)
|
||||
result, is_final = translator.encode_with_completion(response_chunk)
|
||||
|
||||
# Assert
|
||||
assert is_final is False
|
||||
|
|
@ -105,7 +105,7 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(final_response)
|
||||
result, is_final = translator.encode_with_completion(final_response)
|
||||
|
||||
# Assert - is_final is based on end_of_session, not end_of_stream
|
||||
assert is_final is True
|
||||
|
|
@ -116,7 +116,7 @@ class TestGraphRagResponseTranslator:
|
|||
class TestDocumentRagResponseTranslator:
|
||||
"""Test DocumentRagResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_with_empty_response(self):
|
||||
def test_encode_with_empty_response(self):
|
||||
"""Test that empty response strings are preserved"""
|
||||
# Arrange
|
||||
translator = DocumentRagResponseTranslator()
|
||||
|
|
@ -127,14 +127,14 @@ class TestDocumentRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "response" in result
|
||||
assert result["response"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_with_non_empty_response(self):
|
||||
def test_encode_with_non_empty_response(self):
|
||||
"""Test that non-empty responses work correctly"""
|
||||
# Arrange
|
||||
translator = DocumentRagResponseTranslator()
|
||||
|
|
@ -145,7 +145,7 @@ class TestDocumentRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert result["response"] == "Document content"
|
||||
|
|
@ -155,7 +155,7 @@ class TestDocumentRagResponseTranslator:
|
|||
class TestPromptResponseTranslator:
|
||||
"""Test PromptResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_with_empty_text(self):
|
||||
def test_encode_with_empty_text(self):
|
||||
"""Test that empty text strings are preserved"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
|
|
@ -167,14 +167,14 @@ class TestPromptResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "text" in result
|
||||
assert result["text"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_with_non_empty_text(self):
|
||||
def test_encode_with_non_empty_text(self):
|
||||
"""Test that non-empty text works correctly"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
|
|
@ -186,13 +186,13 @@ class TestPromptResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert result["text"] == "Some prompt response"
|
||||
assert result["end_of_stream"] is False
|
||||
|
||||
def test_from_pulsar_with_none_text(self):
|
||||
def test_encode_with_none_text(self):
|
||||
"""Test that None text is handled correctly"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
|
|
@ -204,14 +204,14 @@ class TestPromptResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "text" not in result
|
||||
assert "object" in result
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_includes_end_of_stream(self):
|
||||
def test_encode_includes_end_of_stream(self):
|
||||
"""Test that end_of_stream flag is always included"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
|
|
@ -225,7 +225,7 @@ class TestPromptResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "end_of_stream" in result
|
||||
|
|
@ -235,7 +235,7 @@ class TestPromptResponseTranslator:
|
|||
class TestTextCompletionResponseTranslator:
|
||||
"""Test TextCompletionResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_always_includes_response(self):
|
||||
def test_encode_always_includes_response(self):
|
||||
"""Test that response field is always included, even if empty"""
|
||||
# Arrange
|
||||
translator = TextCompletionResponseTranslator()
|
||||
|
|
@ -249,13 +249,13 @@ class TestTextCompletionResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert - Response should always be present
|
||||
assert "response" in result
|
||||
assert result["response"] == ""
|
||||
|
||||
def test_from_response_with_completion_with_empty_final(self):
|
||||
def test_encode_with_completion_with_empty_final(self):
|
||||
"""Test that empty final response is handled correctly"""
|
||||
# Arrange
|
||||
translator = TextCompletionResponseTranslator()
|
||||
|
|
@ -269,7 +269,7 @@ class TestTextCompletionResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(response)
|
||||
result, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True
|
||||
|
|
@ -297,7 +297,7 @@ class TestStreamingProtocolCompliance:
|
|||
response = response_class(**kwargs)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert field_name in result, f"{translator_class.__name__} should include '{field_name}' field even when empty"
|
||||
|
|
@ -320,7 +320,7 @@ class TestStreamingProtocolCompliance:
|
|||
response = response_class(**kwargs)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "end_of_stream" in result, f"{translator_class.__name__} should include 'end_of_stream' flag"
|
||||
|
|
|
|||
54
tests/unit/test_gateway/test_text_document_translator.py
Normal file
54
tests/unit/test_gateway/test_text_document_translator.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
Unit tests for text document gateway translation compatibility.
|
||||
"""
|
||||
|
||||
import base64
|
||||
|
||||
from trustgraph.messaging.translators.document_loading import TextDocumentTranslator
|
||||
|
||||
|
||||
class TestTextDocumentTranslator:
|
||||
def test_decode_decodes_base64_text(self):
|
||||
translator = TextDocumentTranslator()
|
||||
payload = "Cancer survival: 2.74× higher hazard ratio"
|
||||
|
||||
msg = translator.decode(
|
||||
{
|
||||
"id": "doc-1",
|
||||
"user": "alice",
|
||||
"collection": "research",
|
||||
"charset": "utf-8",
|
||||
"text": base64.b64encode(payload.encode("utf-8")).decode("ascii"),
|
||||
}
|
||||
)
|
||||
|
||||
assert msg.metadata.id == "doc-1"
|
||||
assert msg.metadata.user == "alice"
|
||||
assert msg.metadata.collection == "research"
|
||||
assert msg.text == payload.encode("utf-8")
|
||||
|
||||
def test_decode_accepts_raw_utf8_text(self):
|
||||
translator = TextDocumentTranslator()
|
||||
payload = "Cancer survival: 2.74× higher hazard ratio"
|
||||
|
||||
msg = translator.decode(
|
||||
{
|
||||
"charset": "utf-8",
|
||||
"text": payload,
|
||||
}
|
||||
)
|
||||
|
||||
assert msg.text == payload.encode("utf-8")
|
||||
|
||||
def test_decode_falls_back_to_raw_non_base64_ascii(self):
|
||||
translator = TextDocumentTranslator()
|
||||
payload = "plain-text payload"
|
||||
|
||||
msg = translator.decode(
|
||||
{
|
||||
"charset": "utf-8",
|
||||
"text": payload,
|
||||
}
|
||||
)
|
||||
|
||||
assert msg.text == payload.encode("utf-8")
|
||||
|
|
@ -10,16 +10,19 @@ from trustgraph.schema import Triple, Term, IRI, LITERAL
|
|||
from trustgraph.provenance.agent import (
|
||||
agent_session_triples,
|
||||
agent_iteration_triples,
|
||||
agent_observation_triples,
|
||||
agent_final_triples,
|
||||
agent_synthesis_triples,
|
||||
)
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME,
|
||||
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
|
||||
PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
PROV_STARTED_AT_TIME,
|
||||
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS,
|
||||
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
|
||||
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_TOOL_USE, TG_SYNTHESIS,
|
||||
TG_AGENT_QUESTION,
|
||||
)
|
||||
|
||||
|
|
@ -63,7 +66,7 @@ class TestAgentSessionTriples:
|
|||
triples = agent_session_triples(
|
||||
self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z"
|
||||
)
|
||||
assert has_type(triples, self.SESSION_URI, PROV_ACTIVITY)
|
||||
assert has_type(triples, self.SESSION_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.SESSION_URI, TG_QUESTION)
|
||||
assert has_type(triples, self.SESSION_URI, TG_AGENT_QUESTION)
|
||||
|
||||
|
|
@ -103,6 +106,25 @@ class TestAgentSessionTriples:
|
|||
)
|
||||
assert len(triples) == 6
|
||||
|
||||
def test_session_parent_uri(self):
|
||||
"""Subagent sessions derive from a parent entity (e.g. Decomposition)."""
|
||||
parent = "urn:trustgraph:agent:parent/decompose"
|
||||
triples = agent_session_triples(
|
||||
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z",
|
||||
parent_uri=parent,
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == parent
|
||||
|
||||
def test_session_no_parent_uri(self):
|
||||
"""Top-level sessions have no wasDerivedFrom."""
|
||||
triples = agent_session_triples(
|
||||
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI)
|
||||
assert derived is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# agent_iteration_triples
|
||||
|
|
@ -121,19 +143,17 @@ class TestAgentIterationTriples:
|
|||
)
|
||||
assert has_type(triples, self.ITER_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.ITER_URI, TG_ANALYSIS)
|
||||
assert has_type(triples, self.ITER_URI, TG_TOOL_USE)
|
||||
|
||||
def test_first_iteration_generated_by_question(self):
|
||||
"""First iteration uses wasGeneratedBy to link to question activity."""
|
||||
def test_first_iteration_derived_from_question(self):
|
||||
"""First iteration uses wasDerivedFrom to link to question entity."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.SESSION_URI
|
||||
# Should NOT have wasDerivedFrom
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
|
||||
assert derived is None
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.SESSION_URI
|
||||
|
||||
def test_subsequent_iteration_derived_from_previous(self):
|
||||
"""Subsequent iterations use wasDerivedFrom to link to previous iteration."""
|
||||
|
|
@ -144,9 +164,6 @@ class TestAgentIterationTriples:
|
|||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.PREV_URI
|
||||
# Should NOT have wasGeneratedBy
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
|
||||
assert gen is None
|
||||
|
||||
def test_iteration_label_includes_action(self):
|
||||
triples = agent_iteration_triples(
|
||||
|
|
@ -174,40 +191,24 @@ class TestAgentIterationTriples:
|
|||
# Thought has correct types
|
||||
assert has_type(triples, thought_uri, TG_REFLECTION_TYPE)
|
||||
assert has_type(triples, thought_uri, TG_THOUGHT_TYPE)
|
||||
# Thought was generated by iteration
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, thought_uri)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.ITER_URI
|
||||
# Thought was derived from iteration
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, thought_uri)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.ITER_URI
|
||||
# Thought has document reference
|
||||
doc = find_triple(triples, TG_DOCUMENT, thought_uri)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == thought_doc
|
||||
|
||||
def test_iteration_observation_sub_entity(self):
|
||||
"""Observation is a sub-entity with Reflection and Observation types."""
|
||||
obs_uri = "urn:trustgraph:agent:test-session/i1/observation"
|
||||
obs_doc = "urn:doc:obs-1"
|
||||
def test_iteration_no_observation_sub_entity(self):
|
||||
"""Iteration no longer embeds observation — it's a separate entity."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
observation_uri=obs_uri,
|
||||
observation_document_id=obs_doc,
|
||||
)
|
||||
# Iteration links to observation sub-entity
|
||||
obs_link = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
|
||||
assert obs_link is not None
|
||||
assert obs_link.o.iri == obs_uri
|
||||
# Observation has correct types
|
||||
assert has_type(triples, obs_uri, TG_REFLECTION_TYPE)
|
||||
assert has_type(triples, obs_uri, TG_OBSERVATION_TYPE)
|
||||
# Observation was generated by iteration
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, obs_uri)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.ITER_URI
|
||||
# Observation has document reference
|
||||
doc = find_triple(triples, TG_DOCUMENT, obs_uri)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == obs_doc
|
||||
# No TG_OBSERVATION predicate on the iteration
|
||||
for t in triples:
|
||||
assert "observation" not in t.p.iri.lower() or "Observation" not in t.p.iri
|
||||
|
||||
def test_iteration_action_recorded(self):
|
||||
triples = agent_iteration_triples(
|
||||
|
|
@ -240,19 +241,17 @@ class TestAgentIterationTriples:
|
|||
parsed = json.loads(arguments.o.value)
|
||||
assert parsed == {}
|
||||
|
||||
def test_iteration_no_thought_or_observation(self):
|
||||
"""Minimal iteration with just action — no thought or observation triples."""
|
||||
def test_iteration_no_thought(self):
|
||||
"""Minimal iteration with just action — no thought triples."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="noop",
|
||||
)
|
||||
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
|
||||
obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
|
||||
assert thought is None
|
||||
assert obs is None
|
||||
|
||||
def test_iteration_chaining(self):
|
||||
"""First iteration uses wasGeneratedBy, second uses wasDerivedFrom."""
|
||||
"""Both first and second iterations use wasDerivedFrom."""
|
||||
iter1_uri = "urn:trustgraph:agent:sess/i1"
|
||||
iter2_uri = "urn:trustgraph:agent:sess/i2"
|
||||
|
||||
|
|
@ -263,13 +262,62 @@ class TestAgentIterationTriples:
|
|||
iter2_uri, previous_uri=iter1_uri, action="step2",
|
||||
)
|
||||
|
||||
gen1 = find_triple(triples1, PROV_WAS_GENERATED_BY, iter1_uri)
|
||||
assert gen1.o.iri == self.SESSION_URI
|
||||
derived1 = find_triple(triples1, PROV_WAS_DERIVED_FROM, iter1_uri)
|
||||
assert derived1.o.iri == self.SESSION_URI
|
||||
|
||||
derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri)
|
||||
assert derived2.o.iri == iter1_uri
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# agent_observation_triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAgentObservationTriples:
|
||||
|
||||
OBS_URI = "urn:trustgraph:agent:test-session/i1/observation"
|
||||
ITER_URI = "urn:trustgraph:agent:test-session/i1"
|
||||
|
||||
def test_observation_types(self):
|
||||
triples = agent_observation_triples(
|
||||
self.OBS_URI, self.ITER_URI,
|
||||
)
|
||||
assert has_type(triples, self.OBS_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.OBS_URI, TG_OBSERVATION_TYPE)
|
||||
|
||||
def test_observation_derived_from_iteration(self):
|
||||
triples = agent_observation_triples(
|
||||
self.OBS_URI, self.ITER_URI,
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.OBS_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.ITER_URI
|
||||
|
||||
def test_observation_label(self):
|
||||
triples = agent_observation_triples(
|
||||
self.OBS_URI, self.ITER_URI,
|
||||
)
|
||||
label = find_triple(triples, RDFS_LABEL, self.OBS_URI)
|
||||
assert label is not None
|
||||
assert label.o.value == "Observation"
|
||||
|
||||
def test_observation_document(self):
|
||||
doc_id = "urn:doc:obs-1"
|
||||
triples = agent_observation_triples(
|
||||
self.OBS_URI, self.ITER_URI, document_id=doc_id,
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.OBS_URI)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == doc_id
|
||||
|
||||
def test_observation_no_document(self):
|
||||
triples = agent_observation_triples(
|
||||
self.OBS_URI, self.ITER_URI,
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.OBS_URI)
|
||||
assert doc is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# agent_final_triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -296,19 +344,15 @@ class TestAgentFinalTriples:
|
|||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.PREV_URI
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
|
||||
assert gen is None
|
||||
|
||||
def test_final_generated_by_question_when_no_iterations(self):
|
||||
"""When agent answers immediately, final uses wasGeneratedBy."""
|
||||
def test_final_derived_from_question_when_no_iterations(self):
|
||||
"""When agent answers immediately, final uses wasDerivedFrom to question."""
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, question_uri=self.SESSION_URI,
|
||||
)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.SESSION_URI
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
|
||||
assert derived is None
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.SESSION_URI
|
||||
|
||||
def test_final_label(self):
|
||||
triples = agent_final_triples(
|
||||
|
|
@ -334,3 +378,59 @@ class TestAgentFinalTriples:
|
|||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
|
||||
assert doc is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# agent_synthesis_triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAgentSynthesisTriples:
|
||||
|
||||
SYNTH_URI = "urn:trustgraph:agent:test-session/synthesis"
|
||||
FINDING_0 = "urn:trustgraph:agent:test-session/finding/0"
|
||||
FINDING_1 = "urn:trustgraph:agent:test-session/finding/1"
|
||||
FINDING_2 = "urn:trustgraph:agent:test-session/finding/2"
|
||||
|
||||
def test_synthesis_types(self):
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
|
||||
assert has_type(triples, self.SYNTH_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.SYNTH_URI, TG_SYNTHESIS)
|
||||
assert has_type(triples, self.SYNTH_URI, TG_ANSWER_TYPE)
|
||||
|
||||
def test_synthesis_single_parent_string(self):
|
||||
"""Single parent passed as string."""
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
|
||||
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
|
||||
assert len(derived) == 1
|
||||
assert derived[0].o.iri == self.FINDING_0
|
||||
|
||||
def test_synthesis_multiple_parents(self):
|
||||
"""Multiple parents for supervisor fan-in."""
|
||||
parents = [self.FINDING_0, self.FINDING_1, self.FINDING_2]
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, parents)
|
||||
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
|
||||
assert len(derived) == 3
|
||||
derived_uris = {t.o.iri for t in derived}
|
||||
assert derived_uris == set(parents)
|
||||
|
||||
def test_synthesis_single_parent_as_list(self):
|
||||
"""Single parent passed as list."""
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, [self.FINDING_0])
|
||||
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
|
||||
assert len(derived) == 1
|
||||
assert derived[0].o.iri == self.FINDING_0
|
||||
|
||||
def test_synthesis_document(self):
|
||||
triples = agent_synthesis_triples(
|
||||
self.SYNTH_URI, self.FINDING_0,
|
||||
document_id="urn:doc:synth",
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYNTH_URI)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == "urn:doc:synth"
|
||||
|
||||
def test_synthesis_label(self):
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
|
||||
label = find_triple(triples, RDFS_LABEL, self.SYNTH_URI)
|
||||
assert label is not None
|
||||
assert label.o.value == "Synthesis"
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from trustgraph.api.explainability import (
|
|||
Synthesis,
|
||||
Reflection,
|
||||
Analysis,
|
||||
Observation,
|
||||
Conclusion,
|
||||
parse_edge_selection_triples,
|
||||
extract_term_value,
|
||||
|
|
@ -23,12 +24,12 @@ from trustgraph.api.explainability import (
|
|||
ExplainabilityClient,
|
||||
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_DOCUMENT, TG_CHUNK_COUNT, TG_CONCEPT, TG_ENTITY,
|
||||
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
|
||||
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS,
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANALYSIS, TG_CONCLUSION,
|
||||
TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
|
||||
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
|
||||
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM,
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
)
|
||||
|
||||
|
|
@ -180,14 +181,30 @@ class TestExplainEntityFromTriples:
|
|||
("urn:ana:1", TG_ACTION, "graph-rag-query"),
|
||||
("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'),
|
||||
("urn:ana:1", TG_THOUGHT, "urn:ref:thought-1"),
|
||||
("urn:ana:1", TG_OBSERVATION, "urn:ref:obs-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:ana:1", triples)
|
||||
assert isinstance(entity, Analysis)
|
||||
assert entity.action == "graph-rag-query"
|
||||
assert entity.arguments == '{"query": "test"}'
|
||||
assert entity.thought == "urn:ref:thought-1"
|
||||
assert entity.observation == "urn:ref:obs-1"
|
||||
|
||||
def test_observation(self):
|
||||
triples = [
|
||||
("urn:obs:1", RDF_TYPE, TG_OBSERVATION_TYPE),
|
||||
("urn:obs:1", TG_DOCUMENT, "urn:doc:obs-content"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:obs:1", triples)
|
||||
assert isinstance(entity, Observation)
|
||||
assert entity.document == "urn:doc:obs-content"
|
||||
assert entity.entity_type == "observation"
|
||||
|
||||
def test_observation_no_document(self):
|
||||
triples = [
|
||||
("urn:obs:2", RDF_TYPE, TG_OBSERVATION_TYPE),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:obs:2", triples)
|
||||
assert isinstance(entity, Observation)
|
||||
assert entity.document == ""
|
||||
|
||||
def test_conclusion_with_document(self):
|
||||
triples = [
|
||||
|
|
@ -541,3 +558,96 @@ class TestExplainabilityClientDetectSessionType:
|
|||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
assert client.detect_session_type("urn:trustgraph:docrag:abc") == "docrag"
|
||||
|
||||
|
||||
class TestChainWalkerFollowsSubTraceTerminal:
|
||||
"""Test that _follow_provenance_chain continues from a sub-trace's
|
||||
Synthesis to find downstream entities like Observation."""
|
||||
|
||||
def test_observation_found_via_subtrace_synthesis(self):
|
||||
"""
|
||||
DAG: Question -> Analysis -> GraphRAG Question -> Synthesis -> Observation
|
||||
The walker should find Analysis, the sub-trace, then follow from
|
||||
Synthesis to discover Observation.
|
||||
"""
|
||||
# Entity triples (s, p, o)
|
||||
entity_data = {
|
||||
"urn:agent:q": [
|
||||
("urn:agent:q", RDF_TYPE, TG_AGENT_QUESTION),
|
||||
("urn:agent:q", TG_QUERY, "test"),
|
||||
],
|
||||
"urn:agent:analysis": [
|
||||
("urn:agent:analysis", RDF_TYPE, TG_ANALYSIS),
|
||||
("urn:agent:analysis", PROV_WAS_DERIVED_FROM, "urn:agent:q"),
|
||||
],
|
||||
"urn:graphrag:q": [
|
||||
("urn:graphrag:q", RDF_TYPE, TG_QUESTION),
|
||||
("urn:graphrag:q", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
|
||||
("urn:graphrag:q", TG_QUERY, "test"),
|
||||
("urn:graphrag:q", PROV_WAS_DERIVED_FROM, "urn:agent:analysis"),
|
||||
],
|
||||
"urn:graphrag:synth": [
|
||||
("urn:graphrag:synth", RDF_TYPE, TG_SYNTHESIS),
|
||||
("urn:graphrag:synth", PROV_WAS_DERIVED_FROM, "urn:graphrag:q"),
|
||||
],
|
||||
"urn:agent:obs": [
|
||||
("urn:agent:obs", RDF_TYPE, TG_OBSERVATION_TYPE),
|
||||
("urn:agent:obs", PROV_WAS_DERIVED_FROM, "urn:graphrag:synth"),
|
||||
],
|
||||
"urn:agent:conclusion": [
|
||||
("urn:agent:conclusion", RDF_TYPE, TG_CONCLUSION),
|
||||
("urn:agent:conclusion", PROV_WAS_DERIVED_FROM, "urn:agent:obs"),
|
||||
],
|
||||
}
|
||||
|
||||
# Build a mock flow that answers triples queries
|
||||
# Query by s= returns that entity's triples
|
||||
# Query by p=wasDerivedFrom, o=X returns entities derived from X
|
||||
def mock_triples_query(s=None, p=None, o=None, **kwargs):
|
||||
if s and not p:
|
||||
# Fetch entity triples
|
||||
tuples = entity_data.get(s, [])
|
||||
return _make_wire_triples(tuples)
|
||||
elif p == PROV_WAS_DERIVED_FROM and o:
|
||||
# Find entities derived from o
|
||||
results = []
|
||||
for uri, tuples in entity_data.items():
|
||||
for _, pred, obj in tuples:
|
||||
if pred == PROV_WAS_DERIVED_FROM and obj == o:
|
||||
results.append((uri, pred, obj))
|
||||
return _make_wire_triples(results)
|
||||
return []
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.triples_query.side_effect = mock_triples_query
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0, max_retries=2)
|
||||
|
||||
# Mock fetch_graphrag_trace to return a trace with a synthesis
|
||||
synth_entity = Synthesis(uri="urn:graphrag:synth", entity_type="synthesis")
|
||||
client.fetch_graphrag_trace = MagicMock(return_value={
|
||||
"question": Question(uri="urn:graphrag:q", entity_type="question",
|
||||
question_type="graph-rag"),
|
||||
"synthesis": synth_entity,
|
||||
})
|
||||
|
||||
trace = client.fetch_agent_trace(
|
||||
"urn:agent:q",
|
||||
graph="urn:graph:retrieval",
|
||||
)
|
||||
|
||||
# Should have found all steps
|
||||
step_types = [
|
||||
type(s).__name__ if not isinstance(s, dict) else s.get("type")
|
||||
for s in trace["steps"]
|
||||
]
|
||||
|
||||
assert "Analysis" in step_types, f"Missing Analysis in {step_types}"
|
||||
assert "sub-trace" in step_types, f"Missing sub-trace in {step_types}"
|
||||
assert "Observation" in step_types, f"Missing Observation in {step_types}"
|
||||
assert "Conclusion" in step_types, f"Missing Conclusion in {step_types}"
|
||||
|
||||
# Observation should come after the sub-trace
|
||||
subtrace_idx = step_types.index("sub-trace")
|
||||
obs_idx = step_types.index("Observation")
|
||||
assert obs_idx > subtrace_idx, "Observation should appear after sub-trace"
|
||||
|
|
|
|||
|
|
@ -500,7 +500,7 @@ class TestQuestionTriples:
|
|||
|
||||
def test_question_types(self):
|
||||
triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z")
|
||||
assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
|
||||
assert has_type(triples, self.Q_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.Q_URI, TG_QUESTION)
|
||||
assert has_type(triples, self.Q_URI, TG_GRAPH_RAG_QUESTION)
|
||||
|
||||
|
|
@ -543,11 +543,11 @@ class TestGroundingTriples:
|
|||
assert has_type(triples, self.GND_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.GND_URI, TG_GROUNDING)
|
||||
|
||||
def test_grounding_generated_by_question(self):
|
||||
def test_grounding_derived_from_question(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI"])
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.GND_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.Q_URI
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.GND_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.Q_URI
|
||||
|
||||
def test_grounding_concepts(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML", "robots"])
|
||||
|
|
@ -730,7 +730,7 @@ class TestDocRagQuestionTriples:
|
|||
|
||||
def test_docrag_question_types(self):
|
||||
triples = docrag_question_triples(self.Q_URI, "Find info", "2024-01-01T00:00:00Z")
|
||||
assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
|
||||
assert has_type(triples, self.Q_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.Q_URI, TG_QUESTION)
|
||||
assert has_type(triples, self.Q_URI, TG_DOC_RAG_QUESTION)
|
||||
|
||||
|
|
|
|||
164
tests/unit/test_pubsub/test_queue_naming.py
Normal file
164
tests/unit/test_pubsub/test_queue_naming.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
"""
|
||||
Tests for queue naming and topic mapping.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import argparse
|
||||
|
||||
from trustgraph.schema.core.topic import queue
|
||||
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
|
||||
from trustgraph.base.pulsar_backend import PulsarBackend
|
||||
|
||||
|
||||
class TestQueueFunction:
|
||||
|
||||
def test_flow_default(self):
|
||||
assert queue('text-completion-request') == 'flow:tg:text-completion-request'
|
||||
|
||||
def test_request_class(self):
|
||||
assert queue('config', cls='request') == 'request:tg:config'
|
||||
|
||||
def test_response_class(self):
|
||||
assert queue('config', cls='response') == 'response:tg:config'
|
||||
|
||||
def test_state_class(self):
|
||||
assert queue('config', cls='state') == 'state:tg:config'
|
||||
|
||||
def test_custom_topicspace(self):
|
||||
assert queue('config', cls='request', topicspace='prod') == 'request:prod:config'
|
||||
|
||||
def test_default_class_is_flow(self):
|
||||
result = queue('something')
|
||||
assert result.startswith('flow:')
|
||||
|
||||
|
||||
class TestPulsarMapTopic:
|
||||
|
||||
@pytest.fixture
|
||||
def backend(self):
|
||||
"""Create a PulsarBackend without connecting."""
|
||||
b = object.__new__(PulsarBackend)
|
||||
return b
|
||||
|
||||
def test_flow_maps_to_persistent(self, backend):
|
||||
assert backend.map_topic('flow:tg:text-completion-request') == \
|
||||
'persistent://tg/flow/text-completion-request'
|
||||
|
||||
def test_state_maps_to_persistent(self, backend):
|
||||
assert backend.map_topic('state:tg:config') == \
|
||||
'persistent://tg/state/config'
|
||||
|
||||
def test_request_maps_to_non_persistent(self, backend):
|
||||
assert backend.map_topic('request:tg:config') == \
|
||||
'non-persistent://tg/request/config'
|
||||
|
||||
def test_response_maps_to_non_persistent(self, backend):
|
||||
assert backend.map_topic('response:tg:librarian') == \
|
||||
'non-persistent://tg/response/librarian'
|
||||
|
||||
def test_passthrough_pulsar_uri(self, backend):
|
||||
uri = 'persistent://tg/flow/something'
|
||||
assert backend.map_topic(uri) == uri
|
||||
|
||||
def test_invalid_format_raises(self, backend):
|
||||
with pytest.raises(ValueError, match="Invalid queue format"):
|
||||
backend.map_topic('bad-format')
|
||||
|
||||
def test_invalid_class_raises(self, backend):
|
||||
with pytest.raises(ValueError, match="Invalid queue class"):
|
||||
backend.map_topic('unknown:tg:topic')
|
||||
|
||||
def test_custom_topicspace(self, backend):
|
||||
assert backend.map_topic('flow:prod:my-queue') == \
|
||||
'persistent://prod/flow/my-queue'
|
||||
|
||||
|
||||
class TestGetPubsubDispatch:
|
||||
|
||||
def test_unknown_backend_raises(self):
|
||||
with pytest.raises(ValueError, match="Unknown pub/sub backend"):
|
||||
get_pubsub(pubsub_backend='redis')
|
||||
|
||||
|
||||
class TestAddPubsubArgs:
|
||||
|
||||
def test_standalone_defaults_to_localhost(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser, standalone=True)
|
||||
args = parser.parse_args([])
|
||||
assert args.pulsar_host == 'pulsar://localhost:6650'
|
||||
assert args.pulsar_listener == 'localhost'
|
||||
|
||||
def test_non_standalone_defaults_to_container(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser, standalone=False)
|
||||
args = parser.parse_args([])
|
||||
assert 'pulsar:6650' in args.pulsar_host
|
||||
assert args.pulsar_listener is None
|
||||
|
||||
def test_cli_override_respected(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser, standalone=True)
|
||||
args = parser.parse_args(['--pulsar-host', 'pulsar://custom:6650'])
|
||||
assert args.pulsar_host == 'pulsar://custom:6650'
|
||||
|
||||
def test_pubsub_backend_default(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser)
|
||||
args = parser.parse_args([])
|
||||
assert args.pubsub_backend == 'pulsar'
|
||||
|
||||
|
||||
class TestAddPubsubArgsRabbitMQ:
|
||||
|
||||
def test_rabbitmq_args_present(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser)
|
||||
args = parser.parse_args([
|
||||
'--pubsub-backend', 'rabbitmq',
|
||||
'--rabbitmq-host', 'myhost',
|
||||
'--rabbitmq-port', '5673',
|
||||
])
|
||||
assert args.pubsub_backend == 'rabbitmq'
|
||||
assert args.rabbitmq_host == 'myhost'
|
||||
assert args.rabbitmq_port == 5673
|
||||
|
||||
def test_rabbitmq_defaults_container(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser)
|
||||
args = parser.parse_args([])
|
||||
assert args.rabbitmq_host == 'rabbitmq'
|
||||
assert args.rabbitmq_port == 5672
|
||||
assert args.rabbitmq_username == 'guest'
|
||||
assert args.rabbitmq_password == 'guest'
|
||||
assert args.rabbitmq_vhost == '/'
|
||||
|
||||
def test_rabbitmq_standalone_defaults_to_localhost(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser, standalone=True)
|
||||
args = parser.parse_args([])
|
||||
assert args.rabbitmq_host == 'localhost'
|
||||
|
||||
|
||||
class TestQueueDefinitions:
|
||||
"""Verify the actual queue constants produce correct names."""
|
||||
|
||||
def test_config_request(self):
|
||||
from trustgraph.schema.services.config import config_request_queue
|
||||
assert config_request_queue == 'request:tg:config'
|
||||
|
||||
def test_config_response(self):
|
||||
from trustgraph.schema.services.config import config_response_queue
|
||||
assert config_response_queue == 'response:tg:config'
|
||||
|
||||
def test_config_push(self):
|
||||
from trustgraph.schema.services.config import config_push_queue
|
||||
assert config_push_queue == 'flow:tg:config'
|
||||
|
||||
def test_librarian_request(self):
|
||||
from trustgraph.schema.services.library import librarian_request_queue
|
||||
assert librarian_request_queue == 'request:tg:librarian'
|
||||
|
||||
def test_knowledge_request(self):
|
||||
from trustgraph.schema.knowledge.knowledge import knowledge_request_queue
|
||||
assert knowledge_request_queue == 'request:tg:knowledge'
|
||||
107
tests/unit/test_pubsub/test_rabbitmq_backend.py
Normal file
107
tests/unit/test_pubsub/test_rabbitmq_backend.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
"""
|
||||
Unit tests for RabbitMQ backend — queue name mapping and factory dispatch.
|
||||
Does not require a running RabbitMQ instance.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import argparse
|
||||
|
||||
pika = pytest.importorskip("pika", reason="pika not installed")
|
||||
|
||||
from trustgraph.base.rabbitmq_backend import RabbitMQBackend
|
||||
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
|
||||
|
||||
|
||||
class TestRabbitMQMapQueueName:
|
||||
|
||||
@pytest.fixture
|
||||
def backend(self):
|
||||
b = object.__new__(RabbitMQBackend)
|
||||
return b
|
||||
|
||||
def test_flow_is_durable(self, backend):
|
||||
name, durable = backend.map_queue_name('flow:tg:text-completion-request')
|
||||
assert durable is True
|
||||
assert name == 'tg.flow.text-completion-request'
|
||||
|
||||
def test_state_is_durable(self, backend):
|
||||
name, durable = backend.map_queue_name('state:tg:config')
|
||||
assert durable is True
|
||||
assert name == 'tg.state.config'
|
||||
|
||||
def test_request_is_not_durable(self, backend):
|
||||
name, durable = backend.map_queue_name('request:tg:config')
|
||||
assert durable is False
|
||||
assert name == 'tg.request.config'
|
||||
|
||||
def test_response_is_not_durable(self, backend):
|
||||
name, durable = backend.map_queue_name('response:tg:librarian')
|
||||
assert durable is False
|
||||
assert name == 'tg.response.librarian'
|
||||
|
||||
def test_custom_topicspace(self, backend):
|
||||
name, durable = backend.map_queue_name('flow:prod:my-queue')
|
||||
assert name == 'prod.flow.my-queue'
|
||||
assert durable is True
|
||||
|
||||
def test_no_colon_defaults_to_flow(self, backend):
|
||||
name, durable = backend.map_queue_name('simple-queue')
|
||||
assert name == 'tg.simple-queue'
|
||||
assert durable is False
|
||||
|
||||
def test_invalid_class_raises(self, backend):
|
||||
with pytest.raises(ValueError, match="Invalid queue class"):
|
||||
backend.map_queue_name('unknown:tg:topic')
|
||||
|
||||
def test_flow_with_flow_suffix(self, backend):
|
||||
"""Queue names with flow suffix (e.g. :default) are preserved."""
|
||||
name, durable = backend.map_queue_name('request:tg:prompt:default')
|
||||
assert name == 'tg.request.prompt:default'
|
||||
|
||||
|
||||
class TestGetPubsubRabbitMQ:
|
||||
|
||||
def test_factory_creates_rabbitmq_backend(self):
|
||||
backend = get_pubsub(pubsub_backend='rabbitmq')
|
||||
assert isinstance(backend, RabbitMQBackend)
|
||||
|
||||
def test_factory_passes_config(self):
|
||||
backend = get_pubsub(
|
||||
pubsub_backend='rabbitmq',
|
||||
rabbitmq_host='myhost',
|
||||
rabbitmq_port=5673,
|
||||
rabbitmq_username='user',
|
||||
rabbitmq_password='pass',
|
||||
rabbitmq_vhost='/test',
|
||||
)
|
||||
assert isinstance(backend, RabbitMQBackend)
|
||||
# Verify connection params were set
|
||||
params = backend._connection_params
|
||||
assert params.host == 'myhost'
|
||||
assert params.port == 5673
|
||||
assert params.virtual_host == '/test'
|
||||
|
||||
|
||||
class TestAddPubsubArgsRabbitMQ:
|
||||
|
||||
def test_rabbitmq_args_present(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser)
|
||||
args = parser.parse_args([
|
||||
'--pubsub-backend', 'rabbitmq',
|
||||
'--rabbitmq-host', 'myhost',
|
||||
'--rabbitmq-port', '5673',
|
||||
])
|
||||
assert args.pubsub_backend == 'rabbitmq'
|
||||
assert args.rabbitmq_host == 'myhost'
|
||||
assert args.rabbitmq_port == 5673
|
||||
|
||||
def test_rabbitmq_defaults_container(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
add_pubsub_args(parser)
|
||||
args = parser.parse_args([])
|
||||
assert args.rabbitmq_host == 'rabbitmq'
|
||||
assert args.rabbitmq_port == 5672
|
||||
assert args.rabbitmq_username == 'guest'
|
||||
assert args.rabbitmq_password == 'guest'
|
||||
assert args.rabbitmq_vhost == '/'
|
||||
424
tests/unit/test_query/test_sparql_expressions.py
Normal file
424
tests/unit/test_query/test_sparql_expressions.py
Normal file
|
|
@ -0,0 +1,424 @@
|
|||
"""
|
||||
Tests for SPARQL FILTER expression evaluator.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from trustgraph.schema import Term, IRI, LITERAL, BLANK
|
||||
from trustgraph.query.sparql.expressions import (
|
||||
evaluate_expression, _effective_boolean, _to_string, _to_numeric,
|
||||
_comparable_value,
|
||||
)
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
def iri(v):
|
||||
return Term(type=IRI, iri=v)
|
||||
|
||||
def lit(v, datatype="", language=""):
|
||||
return Term(type=LITERAL, value=v, datatype=datatype, language=language)
|
||||
|
||||
def blank(v):
|
||||
return Term(type=BLANK, id=v)
|
||||
|
||||
XSD = "http://www.w3.org/2001/XMLSchema#"
|
||||
|
||||
|
||||
class TestEvaluateExpression:
|
||||
"""Test expression evaluation with rdflib algebra nodes."""
|
||||
|
||||
def test_variable_bound(self):
|
||||
from rdflib.term import Variable
|
||||
result = evaluate_expression(Variable("x"), {"x": lit("hello")})
|
||||
assert result.value == "hello"
|
||||
|
||||
def test_variable_unbound(self):
|
||||
from rdflib.term import Variable
|
||||
result = evaluate_expression(Variable("x"), {})
|
||||
assert result is None
|
||||
|
||||
def test_uriref_constant(self):
|
||||
from rdflib import URIRef
|
||||
result = evaluate_expression(
|
||||
URIRef("http://example.com/a"), {}
|
||||
)
|
||||
assert result.type == IRI
|
||||
assert result.iri == "http://example.com/a"
|
||||
|
||||
def test_literal_constant(self):
|
||||
from rdflib import Literal
|
||||
result = evaluate_expression(Literal("hello"), {})
|
||||
assert result.type == LITERAL
|
||||
assert result.value == "hello"
|
||||
|
||||
def test_boolean_constant(self):
|
||||
assert evaluate_expression(True, {}) is True
|
||||
assert evaluate_expression(False, {}) is False
|
||||
|
||||
def test_numeric_constant(self):
|
||||
assert evaluate_expression(42, {}) == 42
|
||||
assert evaluate_expression(3.14, {}) == 3.14
|
||||
|
||||
def test_none_returns_true(self):
|
||||
assert evaluate_expression(None, {}) is True
|
||||
|
||||
|
||||
class TestRelationalExpressions:
|
||||
"""Test comparison operators via CompValue nodes."""
|
||||
|
||||
def _make_relational(self, left, op, right):
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
return CompValue("RelationalExpression",
|
||||
expr=left, op=op, other=right)
|
||||
|
||||
def test_equal_literals(self):
|
||||
from rdflib import Literal
|
||||
expr = self._make_relational(Literal("a"), "=", Literal("a"))
|
||||
assert evaluate_expression(expr, {}) is True
|
||||
|
||||
def test_not_equal_literals(self):
|
||||
from rdflib import Literal
|
||||
expr = self._make_relational(Literal("a"), "!=", Literal("b"))
|
||||
assert evaluate_expression(expr, {}) is True
|
||||
|
||||
def test_less_than(self):
|
||||
from rdflib import Literal
|
||||
expr = self._make_relational(Literal("a"), "<", Literal("b"))
|
||||
assert evaluate_expression(expr, {}) is True
|
||||
|
||||
def test_greater_than(self):
|
||||
from rdflib import Literal
|
||||
expr = self._make_relational(Literal("b"), ">", Literal("a"))
|
||||
assert evaluate_expression(expr, {}) is True
|
||||
|
||||
def test_equal_with_variables(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_relational(Variable("x"), "=", Variable("y"))
|
||||
sol = {"x": lit("same"), "y": lit("same")}
|
||||
assert evaluate_expression(expr, sol) is True
|
||||
|
||||
def test_unequal_with_variables(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_relational(Variable("x"), "=", Variable("y"))
|
||||
sol = {"x": lit("one"), "y": lit("two")}
|
||||
assert evaluate_expression(expr, sol) is False
|
||||
|
||||
def test_none_operand_returns_false(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_relational(Variable("x"), "=", Literal("a"))
|
||||
assert evaluate_expression(expr, {}) is False
|
||||
|
||||
|
||||
class TestLogicalExpressions:
|
||||
|
||||
def _make_and(self, exprs):
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
return CompValue("ConditionalAndExpression",
|
||||
expr=exprs[0], other=exprs[1:])
|
||||
|
||||
def _make_or(self, exprs):
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
return CompValue("ConditionalOrExpression",
|
||||
expr=exprs[0], other=exprs[1:])
|
||||
|
||||
def _make_not(self, expr):
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
return CompValue("UnaryNot", expr=expr)
|
||||
|
||||
def test_and_true_true(self):
|
||||
result = evaluate_expression(self._make_and([True, True]), {})
|
||||
assert result is True
|
||||
|
||||
def test_and_true_false(self):
|
||||
result = evaluate_expression(self._make_and([True, False]), {})
|
||||
assert result is False
|
||||
|
||||
def test_or_false_true(self):
|
||||
result = evaluate_expression(self._make_or([False, True]), {})
|
||||
assert result is True
|
||||
|
||||
def test_or_false_false(self):
|
||||
result = evaluate_expression(self._make_or([False, False]), {})
|
||||
assert result is False
|
||||
|
||||
def test_not_true(self):
|
||||
result = evaluate_expression(self._make_not(True), {})
|
||||
assert result is False
|
||||
|
||||
def test_not_false(self):
|
||||
result = evaluate_expression(self._make_not(False), {})
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestBuiltinFunctions:
|
||||
|
||||
def _make_builtin(self, name, **kwargs):
|
||||
from rdflib.plugins.sparql.parserutils import CompValue
|
||||
return CompValue(f"Builtin_{name}", **kwargs)
|
||||
|
||||
def test_bound_true(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("BOUND", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("hi")}) is True
|
||||
|
||||
def test_bound_false(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("BOUND", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {}) is False
|
||||
|
||||
def test_isiri_true(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("isIRI", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": iri("http://x")}) is True
|
||||
|
||||
def test_isiri_false(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("isIRI", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is False
|
||||
|
||||
def test_isliteral_true(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("isLITERAL", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is True
|
||||
|
||||
def test_isliteral_false(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("isLITERAL", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": iri("http://x")}) is False
|
||||
|
||||
def test_isblank_true(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("isBLANK", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": blank("b1")}) is True
|
||||
|
||||
def test_isblank_false(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("isBLANK", arg=Variable("x"))
|
||||
assert evaluate_expression(expr, {"x": iri("http://x")}) is False
|
||||
|
||||
def test_str(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("STR", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": iri("http://example.com/a")})
|
||||
assert result.type == LITERAL
|
||||
assert result.value == "http://example.com/a"
|
||||
|
||||
def test_lang(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("LANG", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("hello", language="en")}
|
||||
)
|
||||
assert result.value == "en"
|
||||
|
||||
def test_lang_no_tag(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("LANG", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||
assert result.value == ""
|
||||
|
||||
def test_datatype(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("DATATYPE", arg=Variable("x"))
|
||||
result = evaluate_expression(
|
||||
expr, {"x": lit("42", datatype=XSD + "integer")}
|
||||
)
|
||||
assert result.type == IRI
|
||||
assert result.iri == XSD + "integer"
|
||||
|
||||
def test_strlen(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("STRLEN", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||
assert result == 5
|
||||
|
||||
def test_ucase(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("UCASE", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": lit("hello")})
|
||||
assert result.value == "HELLO"
|
||||
|
||||
def test_lcase(self):
|
||||
from rdflib.term import Variable
|
||||
expr = self._make_builtin("LCASE", arg=Variable("x"))
|
||||
result = evaluate_expression(expr, {"x": lit("HELLO")})
|
||||
assert result.value == "hello"
|
||||
|
||||
def test_contains_true(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("CONTAINS",
|
||||
arg1=Variable("x"), arg2=Literal("ell"))
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is True
|
||||
|
||||
def test_contains_false(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("CONTAINS",
|
||||
arg1=Variable("x"), arg2=Literal("xyz"))
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is False
|
||||
|
||||
def test_strstarts_true(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("STRSTARTS",
|
||||
arg1=Variable("x"), arg2=Literal("hel"))
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is True
|
||||
|
||||
def test_strends_true(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("STRENDS",
|
||||
arg1=Variable("x"), arg2=Literal("llo"))
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is True
|
||||
|
||||
def test_regex_match(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("REGEX",
|
||||
text=Variable("x"),
|
||||
pattern=Literal("^hel"),
|
||||
flags=None)
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is True
|
||||
|
||||
def test_regex_case_insensitive(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("REGEX",
|
||||
text=Variable("x"),
|
||||
pattern=Literal("HELLO"),
|
||||
flags=Literal("i"))
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is True
|
||||
|
||||
def test_regex_no_match(self):
|
||||
from rdflib.term import Variable
|
||||
from rdflib import Literal
|
||||
expr = self._make_builtin("REGEX",
|
||||
text=Variable("x"),
|
||||
pattern=Literal("^world"),
|
||||
flags=None)
|
||||
assert evaluate_expression(expr, {"x": lit("hello")}) is False
|
||||
|
||||
|
||||
class TestEffectiveBoolean:
|
||||
|
||||
def test_true(self):
|
||||
assert _effective_boolean(True) is True
|
||||
|
||||
def test_false(self):
|
||||
assert _effective_boolean(False) is False
|
||||
|
||||
def test_none(self):
|
||||
assert _effective_boolean(None) is False
|
||||
|
||||
def test_nonzero_int(self):
|
||||
assert _effective_boolean(42) is True
|
||||
|
||||
def test_zero_int(self):
|
||||
assert _effective_boolean(0) is False
|
||||
|
||||
def test_nonempty_string(self):
|
||||
assert _effective_boolean("hello") is True
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _effective_boolean("") is False
|
||||
|
||||
def test_iri_term(self):
|
||||
assert _effective_boolean(iri("http://x")) is True
|
||||
|
||||
def test_nonempty_literal(self):
|
||||
assert _effective_boolean(lit("hello")) is True
|
||||
|
||||
def test_empty_literal(self):
|
||||
assert _effective_boolean(lit("")) is False
|
||||
|
||||
def test_boolean_literal_true(self):
|
||||
assert _effective_boolean(
|
||||
lit("true", datatype=XSD + "boolean")
|
||||
) is True
|
||||
|
||||
def test_boolean_literal_false(self):
|
||||
assert _effective_boolean(
|
||||
lit("false", datatype=XSD + "boolean")
|
||||
) is False
|
||||
|
||||
def test_numeric_literal_nonzero(self):
|
||||
assert _effective_boolean(
|
||||
lit("42", datatype=XSD + "integer")
|
||||
) is True
|
||||
|
||||
def test_numeric_literal_zero(self):
|
||||
assert _effective_boolean(
|
||||
lit("0", datatype=XSD + "integer")
|
||||
) is False
|
||||
|
||||
|
||||
class TestToString:
|
||||
|
||||
def test_none(self):
|
||||
assert _to_string(None) == ""
|
||||
|
||||
def test_string(self):
|
||||
assert _to_string("hello") == "hello"
|
||||
|
||||
def test_iri_term(self):
|
||||
assert _to_string(iri("http://example.com")) == "http://example.com"
|
||||
|
||||
def test_literal_term(self):
|
||||
assert _to_string(lit("hello")) == "hello"
|
||||
|
||||
def test_blank_term(self):
|
||||
assert _to_string(blank("b1")) == "b1"
|
||||
|
||||
|
||||
class TestToNumeric:
|
||||
|
||||
def test_none(self):
|
||||
assert _to_numeric(None) is None
|
||||
|
||||
def test_int(self):
|
||||
assert _to_numeric(42) == 42
|
||||
|
||||
def test_float(self):
|
||||
assert _to_numeric(3.14) == 3.14
|
||||
|
||||
def test_integer_literal(self):
|
||||
assert _to_numeric(lit("42")) == 42
|
||||
|
||||
def test_decimal_literal(self):
|
||||
assert _to_numeric(lit("3.14")) == 3.14
|
||||
|
||||
def test_non_numeric_literal(self):
|
||||
assert _to_numeric(lit("hello")) is None
|
||||
|
||||
def test_numeric_string(self):
|
||||
assert _to_numeric("42") == 42
|
||||
|
||||
def test_non_numeric_string(self):
|
||||
assert _to_numeric("abc") is None
|
||||
|
||||
|
||||
class TestComparableValue:
|
||||
|
||||
def test_none(self):
|
||||
assert _comparable_value(None) == (0, "")
|
||||
|
||||
def test_int(self):
|
||||
assert _comparable_value(42) == (2, 42)
|
||||
|
||||
def test_iri(self):
|
||||
assert _comparable_value(iri("http://x")) == (4, "http://x")
|
||||
|
||||
def test_literal(self):
|
||||
assert _comparable_value(lit("hello")) == (3, "hello")
|
||||
|
||||
def test_numeric_literal(self):
|
||||
assert _comparable_value(lit("42")) == (2, 42)
|
||||
|
||||
def test_ordering(self):
|
||||
vals = [lit("b"), lit("a"), lit("c")]
|
||||
sorted_vals = sorted(vals, key=_comparable_value)
|
||||
assert sorted_vals[0].value == "a"
|
||||
assert sorted_vals[1].value == "b"
|
||||
assert sorted_vals[2].value == "c"
|
||||
205
tests/unit/test_query/test_sparql_parser.py
Normal file
205
tests/unit/test_query/test_sparql_parser.py
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
"""
|
||||
Tests for the SPARQL parser module.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from trustgraph.query.sparql.parser import (
|
||||
parse_sparql, ParseError, rdflib_term_to_term, term_to_rdflib,
|
||||
)
|
||||
from trustgraph.schema import Term, IRI, LITERAL, BLANK
|
||||
|
||||
|
||||
class TestParseSparql:
|
||||
"""Tests for parse_sparql function."""
|
||||
|
||||
def test_select_query_type(self):
|
||||
parsed = parse_sparql("SELECT ?s ?p ?o WHERE { ?s ?p ?o }")
|
||||
assert parsed.query_type == "select"
|
||||
|
||||
def test_select_variables(self):
|
||||
parsed = parse_sparql("SELECT ?s ?p ?o WHERE { ?s ?p ?o }")
|
||||
assert parsed.variables == ["s", "p", "o"]
|
||||
|
||||
def test_select_subset_variables(self):
|
||||
parsed = parse_sparql("SELECT ?s ?o WHERE { ?s ?p ?o }")
|
||||
assert parsed.variables == ["s", "o"]
|
||||
|
||||
def test_ask_query_type(self):
|
||||
parsed = parse_sparql(
|
||||
"ASK { <http://example.com/a> ?p ?o }"
|
||||
)
|
||||
assert parsed.query_type == "ask"
|
||||
|
||||
def test_ask_no_variables(self):
|
||||
parsed = parse_sparql(
|
||||
"ASK { <http://example.com/a> ?p ?o }"
|
||||
)
|
||||
assert parsed.variables == []
|
||||
|
||||
def test_construct_query_type(self):
|
||||
parsed = parse_sparql(
|
||||
"CONSTRUCT { ?s <http://example.com/knows> ?o } "
|
||||
"WHERE { ?s <http://example.com/friendOf> ?o }"
|
||||
)
|
||||
assert parsed.query_type == "construct"
|
||||
|
||||
def test_describe_query_type(self):
|
||||
parsed = parse_sparql(
|
||||
"DESCRIBE <http://example.com/alice>"
|
||||
)
|
||||
assert parsed.query_type == "describe"
|
||||
|
||||
def test_select_with_limit(self):
|
||||
parsed = parse_sparql(
|
||||
"SELECT ?s WHERE { ?s ?p ?o } LIMIT 10"
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
assert parsed.variables == ["s"]
|
||||
|
||||
def test_select_with_distinct(self):
|
||||
parsed = parse_sparql(
|
||||
"SELECT DISTINCT ?s WHERE { ?s ?p ?o }"
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
assert parsed.variables == ["s"]
|
||||
|
||||
def test_select_with_filter(self):
|
||||
parsed = parse_sparql(
|
||||
'SELECT ?s ?label WHERE { '
|
||||
' ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label . '
|
||||
' FILTER(CONTAINS(STR(?label), "test")) '
|
||||
'}'
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
assert parsed.variables == ["s", "label"]
|
||||
|
||||
def test_select_with_optional(self):
|
||||
parsed = parse_sparql(
|
||||
"SELECT ?s ?p ?o ?label WHERE { "
|
||||
" ?s ?p ?o . "
|
||||
" OPTIONAL { ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label } "
|
||||
"}"
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
assert set(parsed.variables) == {"s", "p", "o", "label"}
|
||||
|
||||
def test_select_with_union(self):
|
||||
parsed = parse_sparql(
|
||||
"SELECT ?s ?label WHERE { "
|
||||
" { ?s <http://example.com/name> ?label } "
|
||||
" UNION "
|
||||
" { ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label } "
|
||||
"}"
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
|
||||
def test_select_with_order_by(self):
|
||||
parsed = parse_sparql(
|
||||
"SELECT ?s ?label WHERE { ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label } "
|
||||
"ORDER BY ?label"
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
|
||||
def test_select_with_group_by(self):
|
||||
parsed = parse_sparql(
|
||||
"SELECT ?p (COUNT(?o) AS ?count) WHERE { ?s ?p ?o } "
|
||||
"GROUP BY ?p ORDER BY DESC(?count)"
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
|
||||
def test_select_with_prefixes(self):
|
||||
parsed = parse_sparql(
|
||||
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
|
||||
"SELECT ?s ?label WHERE { ?s rdfs:label ?label }"
|
||||
)
|
||||
assert parsed.query_type == "select"
|
||||
assert parsed.variables == ["s", "label"]
|
||||
|
||||
def test_algebra_not_none(self):
|
||||
parsed = parse_sparql("SELECT ?s WHERE { ?s ?p ?o }")
|
||||
assert parsed.algebra is not None
|
||||
|
||||
def test_parse_error_invalid_sparql(self):
|
||||
with pytest.raises(ParseError):
|
||||
parse_sparql("NOT VALID SPARQL AT ALL")
|
||||
|
||||
def test_parse_error_incomplete_query(self):
|
||||
with pytest.raises(ParseError):
|
||||
parse_sparql("SELECT ?s WHERE {")
|
||||
|
||||
def test_parse_error_message(self):
|
||||
with pytest.raises(ParseError, match="SPARQL parse error"):
|
||||
parse_sparql("GIBBERISH")
|
||||
|
||||
|
||||
class TestRdflibTermToTerm:
|
||||
"""Tests for rdflib-to-Term conversion."""
|
||||
|
||||
def test_uriref_to_term(self):
|
||||
from rdflib import URIRef
|
||||
term = rdflib_term_to_term(URIRef("http://example.com/alice"))
|
||||
assert term.type == IRI
|
||||
assert term.iri == "http://example.com/alice"
|
||||
|
||||
def test_literal_to_term(self):
|
||||
from rdflib import Literal
|
||||
term = rdflib_term_to_term(Literal("hello"))
|
||||
assert term.type == LITERAL
|
||||
assert term.value == "hello"
|
||||
|
||||
def test_typed_literal_to_term(self):
|
||||
from rdflib import Literal, URIRef
|
||||
term = rdflib_term_to_term(
|
||||
Literal("42", datatype=URIRef("http://www.w3.org/2001/XMLSchema#integer"))
|
||||
)
|
||||
assert term.type == LITERAL
|
||||
assert term.value == "42"
|
||||
assert term.datatype == "http://www.w3.org/2001/XMLSchema#integer"
|
||||
|
||||
def test_lang_literal_to_term(self):
|
||||
from rdflib import Literal
|
||||
term = rdflib_term_to_term(Literal("hello", lang="en"))
|
||||
assert term.type == LITERAL
|
||||
assert term.value == "hello"
|
||||
assert term.language == "en"
|
||||
|
||||
def test_bnode_to_term(self):
|
||||
from rdflib import BNode
|
||||
term = rdflib_term_to_term(BNode("b1"))
|
||||
assert term.type == BLANK
|
||||
assert term.id == "b1"
|
||||
|
||||
|
||||
class TestTermToRdflib:
|
||||
"""Tests for Term-to-rdflib conversion."""
|
||||
|
||||
def test_iri_term_to_uriref(self):
|
||||
from rdflib import URIRef
|
||||
result = term_to_rdflib(Term(type=IRI, iri="http://example.com/x"))
|
||||
assert isinstance(result, URIRef)
|
||||
assert str(result) == "http://example.com/x"
|
||||
|
||||
def test_literal_term_to_literal(self):
|
||||
from rdflib import Literal
|
||||
result = term_to_rdflib(Term(type=LITERAL, value="hello"))
|
||||
assert isinstance(result, Literal)
|
||||
assert str(result) == "hello"
|
||||
|
||||
def test_typed_literal_roundtrip(self):
|
||||
from rdflib import URIRef
|
||||
original = Term(
|
||||
type=LITERAL, value="42",
|
||||
datatype="http://www.w3.org/2001/XMLSchema#integer"
|
||||
)
|
||||
rdflib_term = term_to_rdflib(original)
|
||||
assert rdflib_term.datatype == URIRef("http://www.w3.org/2001/XMLSchema#integer")
|
||||
|
||||
def test_lang_literal_roundtrip(self):
|
||||
original = Term(type=LITERAL, value="bonjour", language="fr")
|
||||
rdflib_term = term_to_rdflib(original)
|
||||
assert rdflib_term.language == "fr"
|
||||
|
||||
def test_blank_term_to_bnode(self):
|
||||
from rdflib import BNode
|
||||
result = term_to_rdflib(Term(type=BLANK, id="b1"))
|
||||
assert isinstance(result, BNode)
|
||||
345
tests/unit/test_query/test_sparql_solutions.py
Normal file
345
tests/unit/test_query/test_sparql_solutions.py
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
"""
|
||||
Tests for SPARQL solution sequence operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
from trustgraph.query.sparql.solutions import (
|
||||
hash_join, left_join, union, project, distinct,
|
||||
order_by, slice_solutions, _terms_equal, _compatible,
|
||||
)
|
||||
|
||||
|
||||
# --- Test helpers ---
|
||||
|
||||
def iri(v):
|
||||
return Term(type=IRI, iri=v)
|
||||
|
||||
def lit(v):
|
||||
return Term(type=LITERAL, value=v)
|
||||
|
||||
|
||||
# --- Fixtures ---
|
||||
|
||||
@pytest.fixture
|
||||
def alice():
|
||||
return iri("http://example.com/alice")
|
||||
|
||||
@pytest.fixture
|
||||
def bob():
|
||||
return iri("http://example.com/bob")
|
||||
|
||||
@pytest.fixture
|
||||
def carol():
|
||||
return iri("http://example.com/carol")
|
||||
|
||||
@pytest.fixture
|
||||
def knows():
|
||||
return iri("http://example.com/knows")
|
||||
|
||||
@pytest.fixture
|
||||
def name_alice():
|
||||
return lit("Alice")
|
||||
|
||||
@pytest.fixture
|
||||
def name_bob():
|
||||
return lit("Bob")
|
||||
|
||||
|
||||
class TestTermsEqual:
|
||||
|
||||
def test_equal_iris(self):
|
||||
assert _terms_equal(iri("http://x.com/a"), iri("http://x.com/a"))
|
||||
|
||||
def test_unequal_iris(self):
|
||||
assert not _terms_equal(iri("http://x.com/a"), iri("http://x.com/b"))
|
||||
|
||||
def test_equal_literals(self):
|
||||
assert _terms_equal(lit("hello"), lit("hello"))
|
||||
|
||||
def test_unequal_literals(self):
|
||||
assert not _terms_equal(lit("hello"), lit("world"))
|
||||
|
||||
def test_iri_vs_literal(self):
|
||||
assert not _terms_equal(iri("hello"), lit("hello"))
|
||||
|
||||
def test_none_none(self):
|
||||
assert _terms_equal(None, None)
|
||||
|
||||
def test_none_vs_term(self):
|
||||
assert not _terms_equal(None, iri("http://x.com/a"))
|
||||
|
||||
|
||||
class TestCompatible:
|
||||
|
||||
def test_no_shared_variables(self):
|
||||
assert _compatible({"a": iri("http://x")}, {"b": iri("http://y")})
|
||||
|
||||
def test_shared_variable_same_value(self, alice):
|
||||
assert _compatible({"s": alice, "x": lit("1")}, {"s": alice, "y": lit("2")})
|
||||
|
||||
def test_shared_variable_different_value(self, alice, bob):
|
||||
assert not _compatible({"s": alice}, {"s": bob})
|
||||
|
||||
def test_empty_solutions(self):
|
||||
assert _compatible({}, {})
|
||||
|
||||
def test_empty_vs_nonempty(self, alice):
|
||||
assert _compatible({}, {"s": alice})
|
||||
|
||||
|
||||
class TestHashJoin:
|
||||
|
||||
def test_join_on_shared_variable(self, alice, bob, name_alice, name_bob):
|
||||
left = [
|
||||
{"s": alice, "p": iri("http://example.com/knows"), "o": bob},
|
||||
{"s": bob, "p": iri("http://example.com/knows"), "o": alice},
|
||||
]
|
||||
right = [
|
||||
{"s": alice, "label": name_alice},
|
||||
{"s": bob, "label": name_bob},
|
||||
]
|
||||
result = hash_join(left, right)
|
||||
assert len(result) == 2
|
||||
# Check that joined solutions have all variables
|
||||
for sol in result:
|
||||
assert "s" in sol
|
||||
assert "p" in sol
|
||||
assert "o" in sol
|
||||
assert "label" in sol
|
||||
|
||||
def test_join_no_shared_variables_cross_product(self, alice, bob):
|
||||
left = [{"a": alice}]
|
||||
right = [{"b": bob}, {"b": alice}]
|
||||
result = hash_join(left, right)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_join_no_matches(self, alice, bob):
|
||||
left = [{"s": alice}]
|
||||
right = [{"s": bob}]
|
||||
result = hash_join(left, right)
|
||||
assert len(result) == 0
|
||||
|
||||
def test_join_empty_left(self, alice):
|
||||
result = hash_join([], [{"s": alice}])
|
||||
assert len(result) == 0
|
||||
|
||||
def test_join_empty_right(self, alice):
|
||||
result = hash_join([{"s": alice}], [])
|
||||
assert len(result) == 0
|
||||
|
||||
def test_join_multiple_matches(self, alice, name_alice):
|
||||
left = [
|
||||
{"s": alice, "p": iri("http://e.com/a")},
|
||||
{"s": alice, "p": iri("http://e.com/b")},
|
||||
]
|
||||
right = [{"s": alice, "label": name_alice}]
|
||||
result = hash_join(left, right)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_join_preserves_values(self, alice, name_alice):
|
||||
left = [{"s": alice, "x": lit("1")}]
|
||||
right = [{"s": alice, "y": lit("2")}]
|
||||
result = hash_join(left, right)
|
||||
assert len(result) == 1
|
||||
assert result[0]["x"].value == "1"
|
||||
assert result[0]["y"].value == "2"
|
||||
|
||||
|
||||
class TestLeftJoin:
|
||||
|
||||
def test_left_join_with_matches(self, alice, bob, name_alice):
|
||||
left = [{"s": alice}, {"s": bob}]
|
||||
right = [{"s": alice, "label": name_alice}]
|
||||
result = left_join(left, right)
|
||||
assert len(result) == 2
|
||||
# Alice has label
|
||||
alice_sols = [s for s in result if s["s"].iri == "http://example.com/alice"]
|
||||
assert len(alice_sols) == 1
|
||||
assert "label" in alice_sols[0]
|
||||
# Bob preserved without label
|
||||
bob_sols = [s for s in result if s["s"].iri == "http://example.com/bob"]
|
||||
assert len(bob_sols) == 1
|
||||
assert "label" not in bob_sols[0]
|
||||
|
||||
def test_left_join_no_matches(self, alice, bob):
|
||||
left = [{"s": alice}]
|
||||
right = [{"s": bob, "label": lit("Bob")}]
|
||||
result = left_join(left, right)
|
||||
assert len(result) == 1
|
||||
assert result[0]["s"].iri == "http://example.com/alice"
|
||||
assert "label" not in result[0]
|
||||
|
||||
def test_left_join_empty_right(self, alice):
|
||||
left = [{"s": alice}]
|
||||
result = left_join(left, [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_left_join_empty_left(self):
|
||||
result = left_join([], [{"s": iri("http://x")}])
|
||||
assert len(result) == 0
|
||||
|
||||
def test_left_join_with_filter(self, alice, bob):
|
||||
left = [{"s": alice}, {"s": bob}]
|
||||
right = [
|
||||
{"s": alice, "val": lit("yes")},
|
||||
{"s": bob, "val": lit("no")},
|
||||
]
|
||||
# Filter: only keep joins where val == "yes"
|
||||
result = left_join(
|
||||
left, right,
|
||||
filter_fn=lambda sol: sol.get("val") and sol["val"].value == "yes"
|
||||
)
|
||||
assert len(result) == 2
|
||||
# Alice matches filter
|
||||
alice_sols = [s for s in result if s["s"].iri == "http://example.com/alice"]
|
||||
assert "val" in alice_sols[0]
|
||||
assert alice_sols[0]["val"].value == "yes"
|
||||
# Bob doesn't match filter, preserved without val
|
||||
bob_sols = [s for s in result if s["s"].iri == "http://example.com/bob"]
|
||||
assert "val" not in bob_sols[0]
|
||||
|
||||
|
||||
class TestUnion:
|
||||
|
||||
def test_union_concatenates(self, alice, bob):
|
||||
left = [{"s": alice}]
|
||||
right = [{"s": bob}]
|
||||
result = union(left, right)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_union_preserves_order(self, alice, bob):
|
||||
left = [{"s": alice}]
|
||||
right = [{"s": bob}]
|
||||
result = union(left, right)
|
||||
assert result[0]["s"].iri == "http://example.com/alice"
|
||||
assert result[1]["s"].iri == "http://example.com/bob"
|
||||
|
||||
def test_union_empty_left(self, alice):
|
||||
result = union([], [{"s": alice}])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_union_both_empty(self):
|
||||
result = union([], [])
|
||||
assert len(result) == 0
|
||||
|
||||
def test_union_allows_duplicates(self, alice):
|
||||
result = union([{"s": alice}], [{"s": alice}])
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestProject:
|
||||
|
||||
def test_project_keeps_selected(self, alice, name_alice):
|
||||
solutions = [{"s": alice, "label": name_alice, "extra": lit("x")}]
|
||||
result = project(solutions, ["s", "label"])
|
||||
assert len(result) == 1
|
||||
assert "s" in result[0]
|
||||
assert "label" in result[0]
|
||||
assert "extra" not in result[0]
|
||||
|
||||
def test_project_missing_variable(self, alice):
|
||||
solutions = [{"s": alice}]
|
||||
result = project(solutions, ["s", "missing"])
|
||||
assert len(result) == 1
|
||||
assert "s" in result[0]
|
||||
assert "missing" not in result[0]
|
||||
|
||||
def test_project_empty(self):
|
||||
result = project([], ["s"])
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
class TestDistinct:
|
||||
|
||||
def test_removes_duplicates(self, alice):
|
||||
solutions = [{"s": alice}, {"s": alice}, {"s": alice}]
|
||||
result = distinct(solutions)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_keeps_different(self, alice, bob):
|
||||
solutions = [{"s": alice}, {"s": bob}]
|
||||
result = distinct(solutions)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_empty(self):
|
||||
result = distinct([])
|
||||
assert len(result) == 0
|
||||
|
||||
def test_multi_variable_distinct(self, alice, bob):
|
||||
solutions = [
|
||||
{"s": alice, "o": bob},
|
||||
{"s": alice, "o": bob},
|
||||
{"s": alice, "o": alice},
|
||||
]
|
||||
result = distinct(solutions)
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestOrderBy:
|
||||
|
||||
def test_order_by_ascending(self):
|
||||
solutions = [
|
||||
{"label": lit("Charlie")},
|
||||
{"label": lit("Alice")},
|
||||
{"label": lit("Bob")},
|
||||
]
|
||||
key_fns = [(lambda sol: sol.get("label"), True)]
|
||||
result = order_by(solutions, key_fns)
|
||||
assert result[0]["label"].value == "Alice"
|
||||
assert result[1]["label"].value == "Bob"
|
||||
assert result[2]["label"].value == "Charlie"
|
||||
|
||||
def test_order_by_descending(self):
|
||||
solutions = [
|
||||
{"label": lit("Alice")},
|
||||
{"label": lit("Charlie")},
|
||||
{"label": lit("Bob")},
|
||||
]
|
||||
key_fns = [(lambda sol: sol.get("label"), False)]
|
||||
result = order_by(solutions, key_fns)
|
||||
assert result[0]["label"].value == "Charlie"
|
||||
assert result[1]["label"].value == "Bob"
|
||||
assert result[2]["label"].value == "Alice"
|
||||
|
||||
def test_order_by_empty(self):
|
||||
result = order_by([], [(lambda sol: sol.get("x"), True)])
|
||||
assert len(result) == 0
|
||||
|
||||
def test_order_by_no_keys(self, alice):
|
||||
solutions = [{"s": alice}]
|
||||
result = order_by(solutions, [])
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
class TestSlice:
|
||||
|
||||
def test_limit(self, alice, bob, carol):
|
||||
solutions = [{"s": alice}, {"s": bob}, {"s": carol}]
|
||||
result = slice_solutions(solutions, limit=2)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_offset(self, alice, bob, carol):
|
||||
solutions = [{"s": alice}, {"s": bob}, {"s": carol}]
|
||||
result = slice_solutions(solutions, offset=1)
|
||||
assert len(result) == 2
|
||||
assert result[0]["s"].iri == "http://example.com/bob"
|
||||
|
||||
def test_offset_and_limit(self, alice, bob, carol):
|
||||
solutions = [{"s": alice}, {"s": bob}, {"s": carol}]
|
||||
result = slice_solutions(solutions, offset=1, limit=1)
|
||||
assert len(result) == 1
|
||||
assert result[0]["s"].iri == "http://example.com/bob"
|
||||
|
||||
def test_limit_zero(self, alice):
|
||||
result = slice_solutions([{"s": alice}], limit=0)
|
||||
assert len(result) == 0
|
||||
|
||||
def test_offset_beyond_length(self, alice):
|
||||
result = slice_solutions([{"s": alice}], offset=10)
|
||||
assert len(result) == 0
|
||||
|
||||
def test_no_slice(self, alice, bob):
|
||||
solutions = [{"s": alice}, {"s": bob}]
|
||||
result = slice_solutions(solutions)
|
||||
assert len(result) == 2
|
||||
|
|
@ -28,21 +28,21 @@ def triple_tx():
|
|||
|
||||
class TestTermTranslatorIri:
|
||||
|
||||
def test_iri_to_pulsar(self, term_tx):
|
||||
def test_iri_decode(self, term_tx):
|
||||
data = {"t": "i", "i": "http://example.org/Alice"}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.type == IRI
|
||||
assert term.iri == "http://example.org/Alice"
|
||||
|
||||
def test_iri_from_pulsar(self, term_tx):
|
||||
def test_iri_encode(self, term_tx):
|
||||
term = Term(type=IRI, iri="http://example.org/Bob")
|
||||
wire = term_tx.from_pulsar(term)
|
||||
wire = term_tx.encode(term)
|
||||
assert wire == {"t": "i", "i": "http://example.org/Bob"}
|
||||
|
||||
def test_iri_round_trip(self, term_tx):
|
||||
original = Term(type=IRI, iri="http://example.org/round")
|
||||
wire = term_tx.from_pulsar(original)
|
||||
restored = term_tx.to_pulsar(wire)
|
||||
wire = term_tx.encode(original)
|
||||
restored = term_tx.decode(wire)
|
||||
assert restored == original
|
||||
|
||||
|
||||
|
|
@ -52,21 +52,21 @@ class TestTermTranslatorIri:
|
|||
|
||||
class TestTermTranslatorBlank:
|
||||
|
||||
def test_blank_to_pulsar(self, term_tx):
|
||||
def test_blank_decode(self, term_tx):
|
||||
data = {"t": "b", "d": "_:b42"}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.type == BLANK
|
||||
assert term.id == "_:b42"
|
||||
|
||||
def test_blank_from_pulsar(self, term_tx):
|
||||
def test_blank_encode(self, term_tx):
|
||||
term = Term(type=BLANK, id="_:node1")
|
||||
wire = term_tx.from_pulsar(term)
|
||||
wire = term_tx.encode(term)
|
||||
assert wire == {"t": "b", "d": "_:node1"}
|
||||
|
||||
def test_blank_round_trip(self, term_tx):
|
||||
original = Term(type=BLANK, id="_:x")
|
||||
wire = term_tx.from_pulsar(original)
|
||||
restored = term_tx.to_pulsar(wire)
|
||||
wire = term_tx.encode(original)
|
||||
restored = term_tx.decode(wire)
|
||||
assert restored == original
|
||||
|
||||
|
||||
|
|
@ -76,29 +76,29 @@ class TestTermTranslatorBlank:
|
|||
|
||||
class TestTermTranslatorTypedLiteral:
|
||||
|
||||
def test_plain_literal_to_pulsar(self, term_tx):
|
||||
def test_plain_literal_decode(self, term_tx):
|
||||
data = {"t": "l", "v": "hello"}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.type == LITERAL
|
||||
assert term.value == "hello"
|
||||
assert term.datatype == ""
|
||||
assert term.language == ""
|
||||
|
||||
def test_xsd_integer_to_pulsar(self, term_tx):
|
||||
def test_xsd_integer_decode(self, term_tx):
|
||||
data = {
|
||||
"t": "l", "v": "42",
|
||||
"dt": "http://www.w3.org/2001/XMLSchema#integer",
|
||||
}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.value == "42"
|
||||
assert term.datatype.endswith("#integer")
|
||||
|
||||
def test_typed_literal_from_pulsar(self, term_tx):
|
||||
def test_typed_literal_encode(self, term_tx):
|
||||
term = Term(
|
||||
type=LITERAL, value="3.14",
|
||||
datatype="http://www.w3.org/2001/XMLSchema#double",
|
||||
)
|
||||
wire = term_tx.from_pulsar(term)
|
||||
wire = term_tx.encode(term)
|
||||
assert wire["t"] == "l"
|
||||
assert wire["v"] == "3.14"
|
||||
assert wire["dt"] == "http://www.w3.org/2001/XMLSchema#double"
|
||||
|
|
@ -109,13 +109,13 @@ class TestTermTranslatorTypedLiteral:
|
|||
type=LITERAL, value="true",
|
||||
datatype="http://www.w3.org/2001/XMLSchema#boolean",
|
||||
)
|
||||
wire = term_tx.from_pulsar(original)
|
||||
restored = term_tx.to_pulsar(wire)
|
||||
wire = term_tx.encode(original)
|
||||
restored = term_tx.decode(wire)
|
||||
assert restored == original
|
||||
|
||||
def test_plain_literal_omits_dt_and_ln(self, term_tx):
|
||||
term = Term(type=LITERAL, value="x")
|
||||
wire = term_tx.from_pulsar(term)
|
||||
wire = term_tx.encode(term)
|
||||
assert "dt" not in wire
|
||||
assert "ln" not in wire
|
||||
|
||||
|
|
@ -126,22 +126,22 @@ class TestTermTranslatorTypedLiteral:
|
|||
|
||||
class TestTermTranslatorLangLiteral:
|
||||
|
||||
def test_language_tag_to_pulsar(self, term_tx):
|
||||
def test_language_tag_decode(self, term_tx):
|
||||
data = {"t": "l", "v": "bonjour", "ln": "fr"}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.value == "bonjour"
|
||||
assert term.language == "fr"
|
||||
|
||||
def test_language_tag_from_pulsar(self, term_tx):
|
||||
def test_language_tag_encode(self, term_tx):
|
||||
term = Term(type=LITERAL, value="colour", language="en-GB")
|
||||
wire = term_tx.from_pulsar(term)
|
||||
wire = term_tx.encode(term)
|
||||
assert wire["ln"] == "en-GB"
|
||||
assert "dt" not in wire # No datatype
|
||||
|
||||
def test_language_tag_round_trip(self, term_tx):
|
||||
original = Term(type=LITERAL, value="hola", language="es")
|
||||
wire = term_tx.from_pulsar(original)
|
||||
restored = term_tx.to_pulsar(wire)
|
||||
wire = term_tx.encode(original)
|
||||
restored = term_tx.decode(wire)
|
||||
assert restored == original
|
||||
|
||||
|
||||
|
|
@ -151,7 +151,7 @@ class TestTermTranslatorLangLiteral:
|
|||
|
||||
class TestTermTranslatorQuotedTriple:
|
||||
|
||||
def test_quoted_triple_to_pulsar(self, term_tx):
|
||||
def test_quoted_triple_decode(self, term_tx):
|
||||
data = {
|
||||
"t": "t",
|
||||
"tr": {
|
||||
|
|
@ -160,20 +160,20 @@ class TestTermTranslatorQuotedTriple:
|
|||
"o": {"t": "i", "i": "http://example.org/Bob"},
|
||||
},
|
||||
}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.type == TRIPLE
|
||||
assert term.triple is not None
|
||||
assert term.triple.s.iri == "http://example.org/Alice"
|
||||
assert term.triple.o.iri == "http://example.org/Bob"
|
||||
|
||||
def test_quoted_triple_from_pulsar(self, term_tx):
|
||||
def test_quoted_triple_encode(self, term_tx):
|
||||
inner = Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/s"),
|
||||
p=Term(type=IRI, iri="http://example.org/p"),
|
||||
o=Term(type=LITERAL, value="val"),
|
||||
)
|
||||
term = Term(type=TRIPLE, triple=inner)
|
||||
wire = term_tx.from_pulsar(term)
|
||||
wire = term_tx.encode(term)
|
||||
assert wire["t"] == "t"
|
||||
assert "tr" in wire
|
||||
assert wire["tr"]["s"]["i"] == "http://example.org/s"
|
||||
|
|
@ -186,18 +186,18 @@ class TestTermTranslatorQuotedTriple:
|
|||
o=Term(type=LITERAL, value="C", language="en"),
|
||||
)
|
||||
original = Term(type=TRIPLE, triple=inner)
|
||||
wire = term_tx.from_pulsar(original)
|
||||
restored = term_tx.to_pulsar(wire)
|
||||
wire = term_tx.encode(original)
|
||||
restored = term_tx.decode(wire)
|
||||
assert restored.type == TRIPLE
|
||||
assert restored.triple.s == original.triple.s
|
||||
assert restored.triple.o == original.triple.o
|
||||
|
||||
def test_quoted_triple_none_triple(self, term_tx):
|
||||
term = Term(type=TRIPLE, triple=None)
|
||||
wire = term_tx.from_pulsar(term)
|
||||
wire = term_tx.encode(term)
|
||||
assert wire == {"t": "t"}
|
||||
# And back
|
||||
restored = term_tx.to_pulsar(wire)
|
||||
restored = term_tx.decode(wire)
|
||||
assert restored.type == TRIPLE
|
||||
assert restored.triple is None
|
||||
|
||||
|
|
@ -210,7 +210,7 @@ class TestTermTranslatorQuotedTriple:
|
|||
"o": {"t": "l", "v": "A feeling of expectation"},
|
||||
},
|
||||
}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.triple.o.type == LITERAL
|
||||
assert term.triple.o.value == "A feeling of expectation"
|
||||
|
||||
|
|
@ -223,22 +223,22 @@ class TestTermTranslatorEdgeCases:
|
|||
|
||||
def test_unknown_type(self, term_tx):
|
||||
data = {"t": "z"}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.type == "z"
|
||||
|
||||
def test_empty_type(self, term_tx):
|
||||
data = {}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.type == ""
|
||||
|
||||
def test_missing_iri_field(self, term_tx):
|
||||
data = {"t": "i"}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.iri == ""
|
||||
|
||||
def test_missing_literal_fields(self, term_tx):
|
||||
data = {"t": "l"}
|
||||
term = term_tx.to_pulsar(data)
|
||||
term = term_tx.decode(data)
|
||||
assert term.value == ""
|
||||
assert term.datatype == ""
|
||||
assert term.language == ""
|
||||
|
|
@ -250,24 +250,24 @@ class TestTermTranslatorEdgeCases:
|
|||
|
||||
class TestTripleTranslator:
|
||||
|
||||
def test_triple_to_pulsar(self, triple_tx):
|
||||
def test_triple_decode(self, triple_tx):
|
||||
data = {
|
||||
"s": {"t": "i", "i": "http://example.org/s"},
|
||||
"p": {"t": "i", "i": "http://example.org/p"},
|
||||
"o": {"t": "l", "v": "object"},
|
||||
}
|
||||
triple = triple_tx.to_pulsar(data)
|
||||
triple = triple_tx.decode(data)
|
||||
assert triple.s.iri == "http://example.org/s"
|
||||
assert triple.o.value == "object"
|
||||
assert triple.g is None
|
||||
|
||||
def test_triple_from_pulsar(self, triple_tx):
|
||||
def test_triple_encode(self, triple_tx):
|
||||
triple = Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/A"),
|
||||
p=Term(type=IRI, iri="http://example.org/B"),
|
||||
o=Term(type=LITERAL, value="C"),
|
||||
)
|
||||
wire = triple_tx.from_pulsar(triple)
|
||||
wire = triple_tx.encode(triple)
|
||||
assert wire["s"]["t"] == "i"
|
||||
assert wire["o"]["v"] == "C"
|
||||
assert "g" not in wire
|
||||
|
|
@ -279,17 +279,17 @@ class TestTripleTranslator:
|
|||
"o": {"t": "l", "v": "val"},
|
||||
"g": "urn:graph:source",
|
||||
}
|
||||
quad = triple_tx.to_pulsar(data)
|
||||
quad = triple_tx.decode(data)
|
||||
assert quad.g == "urn:graph:source"
|
||||
|
||||
def test_quad_from_pulsar_includes_graph(self, triple_tx):
|
||||
def test_quad_encode_includes_graph(self, triple_tx):
|
||||
quad = Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/s"),
|
||||
p=Term(type=IRI, iri="http://example.org/p"),
|
||||
o=Term(type=LITERAL, value="v"),
|
||||
g="urn:graph:retrieval",
|
||||
)
|
||||
wire = triple_tx.from_pulsar(quad)
|
||||
wire = triple_tx.encode(quad)
|
||||
assert wire["g"] == "urn:graph:retrieval"
|
||||
|
||||
def test_quad_round_trip(self, triple_tx):
|
||||
|
|
@ -299,8 +299,8 @@ class TestTripleTranslator:
|
|||
o=Term(type=LITERAL, value="v"),
|
||||
g="urn:graph:source",
|
||||
)
|
||||
wire = triple_tx.from_pulsar(original)
|
||||
restored = triple_tx.to_pulsar(wire)
|
||||
wire = triple_tx.encode(original)
|
||||
restored = triple_tx.decode(wire)
|
||||
assert restored == original
|
||||
|
||||
def test_none_graph_omitted_from_wire(self, triple_tx):
|
||||
|
|
@ -310,12 +310,12 @@ class TestTripleTranslator:
|
|||
o=Term(type=LITERAL, value="v"),
|
||||
g=None,
|
||||
)
|
||||
wire = triple_tx.from_pulsar(triple)
|
||||
wire = triple_tx.encode(triple)
|
||||
assert "g" not in wire
|
||||
|
||||
def test_missing_terms_handled(self, triple_tx):
|
||||
data = {}
|
||||
triple = triple_tx.to_pulsar(data)
|
||||
triple = triple_tx.decode(data)
|
||||
assert triple.s is None
|
||||
assert triple.p is None
|
||||
assert triple.o is None
|
||||
|
|
@ -342,16 +342,16 @@ class TestSubgraphTranslator:
|
|||
g="urn:graph:source",
|
||||
),
|
||||
]
|
||||
wire_list = tx.from_pulsar(triples)
|
||||
wire_list = tx.encode(triples)
|
||||
assert len(wire_list) == 2
|
||||
assert wire_list[1]["g"] == "urn:graph:source"
|
||||
|
||||
restored = tx.to_pulsar(wire_list)
|
||||
restored = tx.decode(wire_list)
|
||||
assert len(restored) == 2
|
||||
assert restored[0] == triples[0]
|
||||
assert restored[1] == triples[1]
|
||||
|
||||
def test_empty_subgraph(self):
|
||||
tx = SubgraphTranslator()
|
||||
assert tx.to_pulsar([]) == []
|
||||
assert tx.from_pulsar([]) == []
|
||||
assert tx.decode([]) == []
|
||||
assert tx.encode([]) == []
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class TestDocumentMetadataTranslator:
|
|||
"parent-id": "doc-100",
|
||||
"document-type": "page",
|
||||
}
|
||||
obj = self.tx.to_pulsar(data)
|
||||
obj = self.tx.decode(data)
|
||||
assert obj.id == "doc-123"
|
||||
assert obj.time == 1710000000
|
||||
assert obj.kind == "application/pdf"
|
||||
|
|
@ -45,14 +45,14 @@ class TestDocumentMetadataTranslator:
|
|||
assert obj.parent_id == "doc-100"
|
||||
assert obj.document_type == "page"
|
||||
|
||||
wire = self.tx.from_pulsar(obj)
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["id"] == "doc-123"
|
||||
assert wire["user"] == "alice"
|
||||
assert wire["parent-id"] == "doc-100"
|
||||
assert wire["document-type"] == "page"
|
||||
|
||||
def test_defaults_for_missing_fields(self):
|
||||
obj = self.tx.to_pulsar({})
|
||||
obj = self.tx.decode({})
|
||||
assert obj.parent_id == ""
|
||||
assert obj.document_type == "source"
|
||||
|
||||
|
|
@ -63,25 +63,25 @@ class TestDocumentMetadataTranslator:
|
|||
"o": {"t": "i", "i": "http://example.org/o"},
|
||||
}]
|
||||
data = {"metadata": triple_wire}
|
||||
obj = self.tx.to_pulsar(data)
|
||||
obj = self.tx.decode(data)
|
||||
assert len(obj.metadata) == 1
|
||||
assert obj.metadata[0].s.iri == "http://example.org/s"
|
||||
|
||||
def test_none_metadata_handled(self):
|
||||
data = {"metadata": None}
|
||||
obj = self.tx.to_pulsar(data)
|
||||
obj = self.tx.decode(data)
|
||||
assert obj.metadata == []
|
||||
|
||||
def test_empty_tags_preserved(self):
|
||||
data = {"tags": []}
|
||||
obj = self.tx.to_pulsar(data)
|
||||
wire = self.tx.from_pulsar(obj)
|
||||
obj = self.tx.decode(data)
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["tags"] == []
|
||||
|
||||
def test_falsy_fields_omitted_from_wire(self):
|
||||
"""Empty string fields should be omitted from wire format."""
|
||||
obj = DocumentMetadata(id="", time=0, user="")
|
||||
wire = self.tx.from_pulsar(obj)
|
||||
wire = self.tx.encode(obj)
|
||||
assert "id" not in wire
|
||||
assert "user" not in wire
|
||||
|
||||
|
|
@ -105,7 +105,7 @@ class TestProcessingMetadataTranslator:
|
|||
"collection": "my-collection",
|
||||
"tags": ["tag1"],
|
||||
}
|
||||
obj = self.tx.to_pulsar(data)
|
||||
obj = self.tx.decode(data)
|
||||
assert obj.id == "proc-1"
|
||||
assert obj.document_id == "doc-123"
|
||||
assert obj.flow == "default"
|
||||
|
|
@ -113,32 +113,32 @@ class TestProcessingMetadataTranslator:
|
|||
assert obj.collection == "my-collection"
|
||||
assert obj.tags == ["tag1"]
|
||||
|
||||
wire = self.tx.from_pulsar(obj)
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["id"] == "proc-1"
|
||||
assert wire["document-id"] == "doc-123"
|
||||
assert wire["user"] == "alice"
|
||||
assert wire["collection"] == "my-collection"
|
||||
|
||||
def test_missing_fields_use_defaults(self):
|
||||
obj = self.tx.to_pulsar({})
|
||||
obj = self.tx.decode({})
|
||||
assert obj.id is None
|
||||
assert obj.user is None
|
||||
assert obj.collection is None
|
||||
|
||||
def test_tags_none_omitted(self):
|
||||
obj = ProcessingMetadata(tags=None)
|
||||
wire = self.tx.from_pulsar(obj)
|
||||
wire = self.tx.encode(obj)
|
||||
assert "tags" not in wire
|
||||
|
||||
def test_tags_empty_list_preserved(self):
|
||||
obj = ProcessingMetadata(tags=[])
|
||||
wire = self.tx.from_pulsar(obj)
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["tags"] == []
|
||||
|
||||
def test_user_and_collection_preserved(self):
|
||||
"""Core pipeline routing fields must survive round-trip."""
|
||||
data = {"user": "bob", "collection": "research"}
|
||||
obj = self.tx.to_pulsar(data)
|
||||
wire = self.tx.from_pulsar(obj)
|
||||
obj = self.tx.decode(data)
|
||||
wire = self.tx.encode(obj)
|
||||
assert wire["user"] == "bob"
|
||||
assert wire["collection"] == "research"
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class TestRequestTranslation:
|
|||
}
|
||||
|
||||
# Translate to Pulsar
|
||||
pulsar_msg = translator.to_pulsar(api_data)
|
||||
pulsar_msg = translator.decode(api_data)
|
||||
|
||||
assert pulsar_msg.operation == "schema-selection"
|
||||
assert pulsar_msg.sample == "test data sample"
|
||||
|
|
@ -46,7 +46,7 @@ class TestRequestTranslation:
|
|||
"options": {"delimiter": ","}
|
||||
}
|
||||
|
||||
pulsar_msg = translator.to_pulsar(api_data)
|
||||
pulsar_msg = translator.decode(api_data)
|
||||
|
||||
assert pulsar_msg.operation == "generate-descriptor"
|
||||
assert pulsar_msg.sample == "csv data"
|
||||
|
|
@ -70,7 +70,7 @@ class TestResponseTranslation:
|
|||
)
|
||||
|
||||
# Translate to API format
|
||||
api_data = translator.from_pulsar(pulsar_response)
|
||||
api_data = translator.encode(pulsar_response)
|
||||
|
||||
assert api_data["operation"] == "schema-selection"
|
||||
assert api_data["schema-matches"] == ["products", "inventory", "catalog"]
|
||||
|
|
@ -86,7 +86,7 @@ class TestResponseTranslation:
|
|||
error=None
|
||||
)
|
||||
|
||||
api_data = translator.from_pulsar(pulsar_response)
|
||||
api_data = translator.encode(pulsar_response)
|
||||
|
||||
assert api_data["operation"] == "schema-selection"
|
||||
assert api_data["schema-matches"] == []
|
||||
|
|
@ -103,7 +103,7 @@ class TestResponseTranslation:
|
|||
error=None
|
||||
)
|
||||
|
||||
api_data = translator.from_pulsar(pulsar_response)
|
||||
api_data = translator.encode(pulsar_response)
|
||||
|
||||
assert api_data["operation"] == "detect-type"
|
||||
assert api_data["detected-type"] == "xml"
|
||||
|
|
@ -123,7 +123,7 @@ class TestResponseTranslation:
|
|||
)
|
||||
)
|
||||
|
||||
api_data = translator.from_pulsar(pulsar_response)
|
||||
api_data = translator.encode(pulsar_response)
|
||||
|
||||
assert api_data["operation"] == "schema-selection"
|
||||
# Error objects are typically handled separately by the gateway
|
||||
|
|
@ -146,7 +146,7 @@ class TestResponseTranslation:
|
|||
error=None
|
||||
)
|
||||
|
||||
api_data = translator.from_pulsar(pulsar_response)
|
||||
api_data = translator.encode(pulsar_response)
|
||||
|
||||
assert api_data["operation"] == "diagnose"
|
||||
assert api_data["detected-type"] == "csv"
|
||||
|
|
@ -165,7 +165,7 @@ class TestResponseTranslation:
|
|||
error=None
|
||||
)
|
||||
|
||||
api_data, is_final = translator.from_response_with_completion(pulsar_response)
|
||||
api_data, is_final = translator.encode_with_completion(pulsar_response)
|
||||
|
||||
assert is_final is True # Structured-diag responses are always final
|
||||
assert api_data["operation"] == "schema-selection"
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ dependencies = [
|
|||
"prometheus-client",
|
||||
"requests",
|
||||
"python-logging-loki",
|
||||
"pika",
|
||||
]
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
|
|
|
|||
|
|
@ -81,7 +81,12 @@ from .explainability import (
|
|||
Synthesis,
|
||||
Reflection,
|
||||
Analysis,
|
||||
Observation,
|
||||
Conclusion,
|
||||
Decomposition,
|
||||
Finding,
|
||||
Plan,
|
||||
StepResult,
|
||||
EdgeSelection,
|
||||
wire_triples_to_tuples,
|
||||
extract_term_value,
|
||||
|
|
@ -160,6 +165,7 @@ __all__ = [
|
|||
"Focus",
|
||||
"Synthesis",
|
||||
"Analysis",
|
||||
"Observation",
|
||||
"Conclusion",
|
||||
"EdgeSelection",
|
||||
"wire_triples_to_tuples",
|
||||
|
|
|
|||
|
|
@ -40,15 +40,25 @@ TG_ANSWER_TYPE = TG + "Answer"
|
|||
TG_REFLECTION_TYPE = TG + "Reflection"
|
||||
TG_THOUGHT_TYPE = TG + "Thought"
|
||||
TG_OBSERVATION_TYPE = TG + "Observation"
|
||||
TG_TOOL_USE = TG + "ToolUse"
|
||||
TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion"
|
||||
TG_DOC_RAG_QUESTION = TG + "DocRagQuestion"
|
||||
TG_AGENT_QUESTION = TG + "AgentQuestion"
|
||||
|
||||
# Orchestrator entity types
|
||||
TG_DECOMPOSITION = TG + "Decomposition"
|
||||
TG_FINDING = TG + "Finding"
|
||||
TG_PLAN_TYPE = TG + "Plan"
|
||||
TG_STEP_RESULT = TG + "StepResult"
|
||||
|
||||
# Orchestrator predicates
|
||||
TG_SUBAGENT_GOAL = TG + "subagentGoal"
|
||||
TG_PLAN_STEP = TG + "planStep"
|
||||
|
||||
# PROV-O predicates
|
||||
PROV = "http://www.w3.org/ns/prov#"
|
||||
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
|
||||
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
|
||||
PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy"
|
||||
|
||||
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
|
@ -82,8 +92,18 @@ class ExplainEntity:
|
|||
return Exploration.from_triples(uri, triples)
|
||||
elif TG_FOCUS in types:
|
||||
return Focus.from_triples(uri, triples)
|
||||
elif TG_DECOMPOSITION in types:
|
||||
return Decomposition.from_triples(uri, triples)
|
||||
elif TG_FINDING in types:
|
||||
return Finding.from_triples(uri, triples)
|
||||
elif TG_PLAN_TYPE in types:
|
||||
return Plan.from_triples(uri, triples)
|
||||
elif TG_STEP_RESULT in types:
|
||||
return StepResult.from_triples(uri, triples)
|
||||
elif TG_SYNTHESIS in types:
|
||||
return Synthesis.from_triples(uri, triples)
|
||||
elif TG_OBSERVATION_TYPE in types and TG_REFLECTION_TYPE not in types:
|
||||
return Observation.from_triples(uri, triples)
|
||||
elif TG_REFLECTION_TYPE in types:
|
||||
return Reflection.from_triples(uri, triples)
|
||||
elif TG_ANALYSIS in types:
|
||||
|
|
@ -261,18 +281,16 @@ class Reflection(ExplainEntity):
|
|||
|
||||
@dataclass
|
||||
class Analysis(ExplainEntity):
|
||||
"""Analysis entity - one think/act/observe cycle (Agent only)."""
|
||||
"""Analysis+ToolUse entity - decision + tool call (Agent only)."""
|
||||
action: str = ""
|
||||
arguments: str = "" # JSON string
|
||||
thought: str = ""
|
||||
observation: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Analysis":
|
||||
action = ""
|
||||
arguments = ""
|
||||
thought = ""
|
||||
observation = ""
|
||||
|
||||
for s, p, o in triples:
|
||||
if p == TG_ACTION:
|
||||
|
|
@ -281,8 +299,6 @@ class Analysis(ExplainEntity):
|
|||
arguments = o
|
||||
elif p == TG_THOUGHT:
|
||||
thought = o
|
||||
elif p == TG_OBSERVATION:
|
||||
observation = o
|
||||
|
||||
return cls(
|
||||
uri=uri,
|
||||
|
|
@ -290,7 +306,26 @@ class Analysis(ExplainEntity):
|
|||
action=action,
|
||||
arguments=arguments,
|
||||
thought=thought,
|
||||
observation=observation
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Observation(ExplainEntity):
|
||||
"""Observation entity - standalone tool result (Agent only)."""
|
||||
document: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Observation":
|
||||
document = ""
|
||||
|
||||
for s, p, o in triples:
|
||||
if p == TG_DOCUMENT:
|
||||
document = o
|
||||
|
||||
return cls(
|
||||
uri=uri,
|
||||
entity_type="observation",
|
||||
document=document,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -314,6 +349,70 @@ class Conclusion(ExplainEntity):
|
|||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Decomposition(ExplainEntity):
|
||||
"""Decomposition entity - supervisor broke question into sub-goals."""
|
||||
goals: List[str] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Decomposition":
|
||||
goals = []
|
||||
for s, p, o in triples:
|
||||
if p == TG_SUBAGENT_GOAL:
|
||||
goals.append(o)
|
||||
return cls(uri=uri, entity_type="decomposition", goals=goals)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Finding(ExplainEntity):
|
||||
"""Finding entity - a subagent's result."""
|
||||
goal: str = ""
|
||||
document: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Finding":
|
||||
goal = ""
|
||||
document = ""
|
||||
for s, p, o in triples:
|
||||
if p == TG_SUBAGENT_GOAL:
|
||||
goal = o
|
||||
elif p == TG_DOCUMENT:
|
||||
document = o
|
||||
return cls(uri=uri, entity_type="finding", goal=goal, document=document)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Plan(ExplainEntity):
|
||||
"""Plan entity - a structured plan of steps."""
|
||||
steps: List[str] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Plan":
|
||||
steps = []
|
||||
for s, p, o in triples:
|
||||
if p == TG_PLAN_STEP:
|
||||
steps.append(o)
|
||||
return cls(uri=uri, entity_type="plan", steps=steps)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepResult(ExplainEntity):
|
||||
"""StepResult entity - a plan step's result."""
|
||||
step: str = ""
|
||||
document: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "StepResult":
|
||||
step = ""
|
||||
document = ""
|
||||
for s, p, o in triples:
|
||||
if p == TG_PLAN_STEP:
|
||||
step = o
|
||||
elif p == TG_DOCUMENT:
|
||||
document = o
|
||||
return cls(uri=uri, entity_type="step-result", step=step, document=document)
|
||||
|
||||
|
||||
def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSelection:
|
||||
"""Parse triples for an edge selection entity."""
|
||||
uri = triples[0][0] if triples else ""
|
||||
|
|
@ -675,9 +774,9 @@ class ExplainabilityClient:
|
|||
return trace
|
||||
trace["question"] = question
|
||||
|
||||
# Find grounding: ?grounding prov:wasGeneratedBy question_uri
|
||||
# Find grounding: ?grounding prov:wasDerivedFrom question_uri
|
||||
grounding_triples = self.flow.triples_query(
|
||||
p=PROV_WAS_GENERATED_BY,
|
||||
p=PROV_WAS_DERIVED_FROM,
|
||||
o=question_uri,
|
||||
g=graph,
|
||||
user=user,
|
||||
|
|
@ -812,9 +911,9 @@ class ExplainabilityClient:
|
|||
return trace
|
||||
trace["question"] = question
|
||||
|
||||
# Find grounding: ?grounding prov:wasGeneratedBy question_uri
|
||||
# Find grounding: ?grounding prov:wasDerivedFrom question_uri
|
||||
grounding_triples = self.flow.triples_query(
|
||||
p=PROV_WAS_GENERATED_BY,
|
||||
p=PROV_WAS_DERIVED_FROM,
|
||||
o=question_uri,
|
||||
g=graph,
|
||||
user=user,
|
||||
|
|
@ -895,7 +994,10 @@ class ExplainabilityClient:
|
|||
"""
|
||||
Fetch the complete Agent trace starting from a session URI.
|
||||
|
||||
Follows the provenance chain: Question -> Analysis(s) -> Conclusion
|
||||
Follows the provenance chain for all patterns:
|
||||
- ReAct: Question -> Analysis(s) -> Conclusion
|
||||
- Supervisor: Question -> Decomposition -> Finding(s) -> Synthesis
|
||||
- Plan-then-Execute: Question -> Plan -> StepResult(s) -> Synthesis
|
||||
|
||||
Args:
|
||||
session_uri: The agent session/question URI
|
||||
|
|
@ -906,15 +1008,14 @@ class ExplainabilityClient:
|
|||
max_content: Maximum content length for conclusion
|
||||
|
||||
Returns:
|
||||
Dict with question, iterations (Analysis list), conclusion entities
|
||||
Dict with question, steps (mixed entity list), conclusion/synthesis
|
||||
"""
|
||||
if graph is None:
|
||||
graph = "urn:graph:retrieval"
|
||||
|
||||
trace = {
|
||||
"question": None,
|
||||
"iterations": [],
|
||||
"conclusion": None,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
# Fetch question/session
|
||||
|
|
@ -923,65 +1024,89 @@ class ExplainabilityClient:
|
|||
return trace
|
||||
trace["question"] = question
|
||||
|
||||
# Follow the chain: wasGeneratedBy for first hop, wasDerivedFrom after
|
||||
current_uri = session_uri
|
||||
is_first = True
|
||||
max_iterations = 50 # Safety limit
|
||||
|
||||
for _ in range(max_iterations):
|
||||
# First hop uses wasGeneratedBy (entity←activity),
|
||||
# subsequent hops use wasDerivedFrom (entity←entity)
|
||||
if is_first:
|
||||
derived_triples = self.flow.triples_query(
|
||||
p=PROV_WAS_GENERATED_BY,
|
||||
o=current_uri,
|
||||
g=graph,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=10
|
||||
)
|
||||
# Fall back to wasDerivedFrom for backwards compatibility
|
||||
if not derived_triples:
|
||||
derived_triples = self.flow.triples_query(
|
||||
p=PROV_WAS_DERIVED_FROM,
|
||||
o=current_uri,
|
||||
g=graph,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=10
|
||||
)
|
||||
is_first = False
|
||||
else:
|
||||
derived_triples = self.flow.triples_query(
|
||||
p=PROV_WAS_DERIVED_FROM,
|
||||
o=current_uri,
|
||||
g=graph,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=10
|
||||
)
|
||||
|
||||
if not derived_triples:
|
||||
break
|
||||
|
||||
derived_uri = extract_term_value(derived_triples[0].get("s", {}))
|
||||
if not derived_uri:
|
||||
break
|
||||
|
||||
entity = self.fetch_entity(derived_uri, graph, user, collection)
|
||||
|
||||
if isinstance(entity, Analysis):
|
||||
trace["iterations"].append(entity)
|
||||
current_uri = derived_uri
|
||||
elif isinstance(entity, Conclusion):
|
||||
trace["conclusion"] = entity
|
||||
break
|
||||
else:
|
||||
# Unknown entity type, stop
|
||||
break
|
||||
# Follow the provenance chain from the question
|
||||
self._follow_provenance_chain(
|
||||
session_uri, trace, graph, user, collection,
|
||||
max_depth=50,
|
||||
)
|
||||
|
||||
return trace
|
||||
|
||||
def _follow_provenance_chain(
|
||||
self, current_uri, trace, graph, user, collection,
|
||||
max_depth=50,
|
||||
):
|
||||
"""Recursively follow the provenance chain, handling branches."""
|
||||
if max_depth <= 0:
|
||||
return
|
||||
|
||||
# Find entities derived from current_uri
|
||||
derived_triples = self.flow.triples_query(
|
||||
p=PROV_WAS_DERIVED_FROM,
|
||||
o=current_uri,
|
||||
g=graph, user=user, collection=collection,
|
||||
limit=20
|
||||
)
|
||||
|
||||
if not derived_triples:
|
||||
return
|
||||
|
||||
derived_uris = [
|
||||
extract_term_value(t.get("s", {}))
|
||||
for t in derived_triples
|
||||
]
|
||||
|
||||
for derived_uri in derived_uris:
|
||||
if not derived_uri:
|
||||
continue
|
||||
|
||||
entity = self.fetch_entity(derived_uri, graph, user, collection)
|
||||
if entity is None:
|
||||
continue
|
||||
|
||||
if isinstance(entity, (Analysis, Observation, Decomposition,
|
||||
Finding, Plan, StepResult)):
|
||||
trace["steps"].append(entity)
|
||||
|
||||
# Continue following from this entity
|
||||
self._follow_provenance_chain(
|
||||
derived_uri, trace, graph, user, collection,
|
||||
max_depth=max_depth - 1,
|
||||
)
|
||||
|
||||
elif isinstance(entity, Question):
|
||||
# Sub-trace: a RAG session linked to this agent step.
|
||||
# Fetch the full sub-trace and embed it.
|
||||
if entity.question_type == "graph-rag":
|
||||
sub_trace = self.fetch_graphrag_trace(
|
||||
derived_uri, graph, user, collection,
|
||||
)
|
||||
elif entity.question_type == "document-rag":
|
||||
sub_trace = self.fetch_docrag_trace(
|
||||
derived_uri, graph, user, collection,
|
||||
)
|
||||
else:
|
||||
sub_trace = None
|
||||
|
||||
if sub_trace:
|
||||
trace["steps"].append({
|
||||
"type": "sub-trace",
|
||||
"question": entity,
|
||||
"trace": sub_trace,
|
||||
})
|
||||
|
||||
# Continue from the sub-trace's terminal entity
|
||||
# (Observation may derive from Synthesis)
|
||||
terminal = sub_trace.get("synthesis")
|
||||
if terminal:
|
||||
self._follow_provenance_chain(
|
||||
terminal.uri, trace, graph, user, collection,
|
||||
max_depth=max_depth - 1,
|
||||
)
|
||||
|
||||
elif isinstance(entity, (Conclusion, Synthesis)):
|
||||
trace["steps"].append(entity)
|
||||
|
||||
def list_sessions(
|
||||
self,
|
||||
graph: Optional[str] = None,
|
||||
|
|
@ -1021,10 +1146,25 @@ class ExplainabilityClient:
|
|||
if isinstance(entity, Question):
|
||||
questions.append(entity)
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
questions.sort(key=lambda q: q.timestamp or "", reverse=True)
|
||||
# Filter out sub-traces: sessions that have a wasDerivedFrom link
|
||||
# (they are child sessions linked to a parent agent iteration)
|
||||
top_level = []
|
||||
for q in questions:
|
||||
parent_triples = self.flow.triples_query(
|
||||
s=q.uri,
|
||||
p=PROV_WAS_DERIVED_FROM,
|
||||
g=graph,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=1
|
||||
)
|
||||
if not parent_triples:
|
||||
top_level.append(q)
|
||||
|
||||
return questions
|
||||
# Sort by timestamp (newest first)
|
||||
top_level.sort(key=lambda q: q.timestamp or "", reverse=True)
|
||||
|
||||
return top_level
|
||||
|
||||
def detect_session_type(
|
||||
self,
|
||||
|
|
@ -1066,23 +1206,14 @@ class ExplainabilityClient:
|
|||
limit=5
|
||||
)
|
||||
|
||||
generated_triples = self.flow.triples_query(
|
||||
p=PROV_WAS_GENERATED_BY,
|
||||
o=session_uri,
|
||||
g=graph,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=5
|
||||
)
|
||||
|
||||
all_child_uris = [
|
||||
extract_term_value(t.get("s", {}))
|
||||
for t in (derived_triples + generated_triples)
|
||||
for t in derived_triples
|
||||
]
|
||||
|
||||
for child_uri in all_child_uris:
|
||||
entity = self.fetch_entity(child_uri, graph, user, collection)
|
||||
if isinstance(entity, Analysis):
|
||||
if isinstance(entity, (Analysis, Decomposition, Plan)):
|
||||
return "agent"
|
||||
if isinstance(entity, Exploration):
|
||||
return "graphrag"
|
||||
|
|
|
|||
|
|
@ -1122,6 +1122,45 @@ class FlowInstance:
|
|||
|
||||
return result
|
||||
|
||||
def sparql_query(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
limit=10000
|
||||
):
|
||||
"""
|
||||
Execute a SPARQL query against the knowledge graph.
|
||||
|
||||
Args:
|
||||
query: SPARQL 1.1 query string
|
||||
user: User/keyspace identifier (default: "trustgraph")
|
||||
collection: Collection identifier (default: "default")
|
||||
limit: Safety limit on results (default: 10000)
|
||||
|
||||
Returns:
|
||||
dict with query results. Structure depends on query type:
|
||||
- SELECT: {"query-type": "select", "variables": [...], "bindings": [...]}
|
||||
- ASK: {"query-type": "ask", "ask-result": bool}
|
||||
- CONSTRUCT/DESCRIBE: {"query-type": "construct", "triples": [...]}
|
||||
|
||||
Raises:
|
||||
ProtocolException: If an error occurs
|
||||
"""
|
||||
|
||||
input = {
|
||||
"query": query,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
response = self.request("service/sparql", input)
|
||||
|
||||
if "error" in response and response["error"]:
|
||||
error_type = response["error"].get("type", "unknown")
|
||||
error_message = response["error"].get("message", "Unknown error")
|
||||
raise ProtocolException(f"{error_type}: {error_message}")
|
||||
|
||||
return response
|
||||
|
||||
def nlp_query(self, question, max_results=100):
|
||||
"""
|
||||
Convert a natural language question to a GraphQL query.
|
||||
|
|
|
|||
|
|
@ -22,8 +22,9 @@ logger = logging.getLogger(__name__)
|
|||
# Lower threshold provides progress feedback and resumability on slower connections
|
||||
CHUNKED_UPLOAD_THRESHOLD = 2 * 1024 * 1024
|
||||
|
||||
# Default chunk size (5MB - S3 multipart minimum)
|
||||
DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024
|
||||
# Default chunk size (3MB - stays under broker message size limits
|
||||
# after base64 encoding ~4MB)
|
||||
DEFAULT_CHUNK_SIZE = 3 * 1024 * 1024
|
||||
|
||||
|
||||
def to_value(x):
|
||||
|
|
|
|||
|
|
@ -366,59 +366,39 @@ class SocketClient:
|
|||
# Handle GraphRAG/DocRAG message format with message_type
|
||||
if message_type == "explain":
|
||||
if include_provenance:
|
||||
return ProvenanceEvent(
|
||||
explain_id=resp.get("explain_id", ""),
|
||||
explain_graph=resp.get("explain_graph", "")
|
||||
)
|
||||
return self._build_provenance_event(resp)
|
||||
return None
|
||||
|
||||
# Handle Agent message format with chunk_type="explain"
|
||||
if chunk_type == "explain":
|
||||
if include_provenance:
|
||||
return ProvenanceEvent(
|
||||
explain_id=resp.get("explain_id", ""),
|
||||
explain_graph=resp.get("explain_graph", "")
|
||||
)
|
||||
return self._build_provenance_event(resp)
|
||||
return None
|
||||
|
||||
if chunk_type == "thought":
|
||||
return AgentThought(
|
||||
content=resp.get("content", ""),
|
||||
end_of_message=resp.get("end_of_message", False)
|
||||
end_of_message=resp.get("end_of_message", False),
|
||||
message_id=resp.get("message_id", ""),
|
||||
)
|
||||
elif chunk_type == "observation":
|
||||
return AgentObservation(
|
||||
content=resp.get("content", ""),
|
||||
end_of_message=resp.get("end_of_message", False)
|
||||
end_of_message=resp.get("end_of_message", False),
|
||||
message_id=resp.get("message_id", ""),
|
||||
)
|
||||
elif chunk_type == "answer" or chunk_type == "final-answer":
|
||||
return AgentAnswer(
|
||||
content=resp.get("content", ""),
|
||||
end_of_message=resp.get("end_of_message", False),
|
||||
end_of_dialog=resp.get("end_of_dialog", False)
|
||||
end_of_dialog=resp.get("end_of_dialog", False),
|
||||
message_id=resp.get("message_id", ""),
|
||||
)
|
||||
elif chunk_type == "action":
|
||||
return AgentThought(
|
||||
content=resp.get("content", ""),
|
||||
end_of_message=resp.get("end_of_message", False)
|
||||
)
|
||||
# Non-streaming agent format: chunk_type is empty but has thought/observation/answer fields
|
||||
elif resp.get("thought"):
|
||||
return AgentThought(
|
||||
content=resp.get("thought", ""),
|
||||
end_of_message=resp.get("end_of_message", False)
|
||||
)
|
||||
elif resp.get("observation"):
|
||||
return AgentObservation(
|
||||
content=resp.get("observation", ""),
|
||||
end_of_message=resp.get("end_of_message", False)
|
||||
)
|
||||
elif resp.get("answer"):
|
||||
return AgentAnswer(
|
||||
content=resp.get("answer", ""),
|
||||
end_of_message=resp.get("end_of_message", False),
|
||||
end_of_dialog=resp.get("end_of_dialog", False)
|
||||
)
|
||||
else:
|
||||
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
|
||||
return RAGChunk(
|
||||
|
|
@ -427,6 +407,42 @@ class SocketClient:
|
|||
error=None
|
||||
)
|
||||
|
||||
def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent:
|
||||
"""Build a ProvenanceEvent from a response dict, parsing inline triples
|
||||
into an ExplainEntity if available."""
|
||||
explain_id = resp.get("explain_id", "")
|
||||
explain_graph = resp.get("explain_graph", "")
|
||||
raw_triples = resp.get("explain_triples", [])
|
||||
|
||||
entity = None
|
||||
if raw_triples:
|
||||
try:
|
||||
from .explainability import ExplainEntity
|
||||
# Convert wire-format triple dicts to (s, p, o) tuples
|
||||
parsed = []
|
||||
for t in raw_triples:
|
||||
s = t.get("s", {}).get("i", "") if t.get("s") else ""
|
||||
p = t.get("p", {}).get("i", "") if t.get("p") else ""
|
||||
o_term = t.get("o", {})
|
||||
if o_term:
|
||||
if o_term.get("t") == "i":
|
||||
o = o_term.get("i", "")
|
||||
else:
|
||||
o = o_term.get("v", "")
|
||||
else:
|
||||
o = ""
|
||||
parsed.append((s, p, o))
|
||||
entity = ExplainEntity.from_triples(explain_id, parsed)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ProvenanceEvent(
|
||||
explain_id=explain_id,
|
||||
explain_graph=explain_graph,
|
||||
entity=entity,
|
||||
triples=raw_triples,
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the persistent WebSocket connection."""
|
||||
if self._loop and not self._loop.is_closed():
|
||||
|
|
@ -826,6 +842,31 @@ class SocketFlowInstance:
|
|||
else:
|
||||
yield response
|
||||
|
||||
def sparql_query_stream(
|
||||
self,
|
||||
query: str,
|
||||
user: str = "trustgraph",
|
||||
collection: str = "default",
|
||||
limit: int = 10000,
|
||||
batch_size: int = 20,
|
||||
**kwargs: Any
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
"""Execute a SPARQL query with streaming batches."""
|
||||
request = {
|
||||
"query": query,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit,
|
||||
"streaming": True,
|
||||
"batch-size": batch_size,
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
for response in self.client._send_request_sync(
|
||||
"sparql", self.flow_id, request, streaming_raw=True
|
||||
):
|
||||
yield response
|
||||
|
||||
def rows_query(
|
||||
self,
|
||||
query: str,
|
||||
|
|
|
|||
|
|
@ -150,8 +150,10 @@ class AgentThought(StreamingChunk):
|
|||
content: Agent's thought text
|
||||
end_of_message: True if this completes the current thought
|
||||
chunk_type: Always "thought"
|
||||
message_id: Provenance URI of the entity being built
|
||||
"""
|
||||
chunk_type: str = "thought"
|
||||
message_id: str = ""
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AgentObservation(StreamingChunk):
|
||||
|
|
@ -165,8 +167,10 @@ class AgentObservation(StreamingChunk):
|
|||
content: Observation text describing tool results
|
||||
end_of_message: True if this completes the current observation
|
||||
chunk_type: Always "observation"
|
||||
message_id: Provenance URI of the entity being built
|
||||
"""
|
||||
chunk_type: str = "observation"
|
||||
message_id: str = ""
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AgentAnswer(StreamingChunk):
|
||||
|
|
@ -184,6 +188,7 @@ class AgentAnswer(StreamingChunk):
|
|||
"""
|
||||
chunk_type: str = "final-answer"
|
||||
end_of_dialog: bool = False
|
||||
message_id: str = ""
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RAGChunk(StreamingChunk):
|
||||
|
|
@ -208,25 +213,47 @@ class ProvenanceEvent:
|
|||
"""
|
||||
Provenance event for explainability.
|
||||
|
||||
Emitted during GraphRAG queries when explainable mode is enabled.
|
||||
Emitted during retrieval queries when explainable mode is enabled.
|
||||
Each event represents a provenance node created during query processing.
|
||||
|
||||
Attributes:
|
||||
explain_id: URI of the provenance node (e.g., urn:trustgraph:question:abc123)
|
||||
explain_graph: Named graph where provenance triples are stored (e.g., urn:graph:retrieval)
|
||||
event_type: Type of provenance event (question, exploration, focus, synthesis)
|
||||
event_type: Type of provenance event (question, exploration, focus, synthesis, etc.)
|
||||
entity: Parsed ExplainEntity from inline triples (if available)
|
||||
triples: Raw triples from the response (wire format dicts)
|
||||
"""
|
||||
explain_id: str
|
||||
explain_graph: str = ""
|
||||
event_type: str = "" # Derived from explain_id
|
||||
entity: object = None # ExplainEntity (parsed from triples)
|
||||
triples: list = dataclasses.field(default_factory=list) # Raw wire-format triple dicts
|
||||
|
||||
def __post_init__(self):
|
||||
# Extract event type from explain_id
|
||||
if "question" in self.explain_id:
|
||||
self.event_type = "question"
|
||||
elif "grounding" in self.explain_id:
|
||||
self.event_type = "grounding"
|
||||
elif "exploration" in self.explain_id:
|
||||
self.event_type = "exploration"
|
||||
elif "focus" in self.explain_id:
|
||||
self.event_type = "focus"
|
||||
elif "synthesis" in self.explain_id:
|
||||
self.event_type = "synthesis"
|
||||
elif "iteration" in self.explain_id:
|
||||
self.event_type = "iteration"
|
||||
elif "observation" in self.explain_id:
|
||||
self.event_type = "observation"
|
||||
elif "conclusion" in self.explain_id:
|
||||
self.event_type = "conclusion"
|
||||
elif "decomposition" in self.explain_id:
|
||||
self.event_type = "decomposition"
|
||||
elif "finding" in self.explain_id:
|
||||
self.event_type = "finding"
|
||||
elif "plan" in self.explain_id:
|
||||
self.event_type = "plan"
|
||||
elif "step-result" in self.explain_id:
|
||||
self.event_type = "step-result"
|
||||
elif "session" in self.explain_id:
|
||||
self.event_type = "session"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
|
||||
from . pubsub import PulsarClient, get_pubsub
|
||||
from . pubsub import get_pubsub, add_pubsub_args
|
||||
from . async_processor import AsyncProcessor
|
||||
from . consumer import Consumer
|
||||
from . producer import Producer
|
||||
|
|
@ -14,6 +14,7 @@ from . producer_spec import ProducerSpec
|
|||
from . subscriber_spec import SubscriberSpec
|
||||
from . request_response_spec import RequestResponseSpec
|
||||
from . llm_service import LlmService, LlmResult, LlmChunk
|
||||
from . librarian_client import LibrarianClient
|
||||
from . chunking_service import ChunkingService
|
||||
from . embeddings_service import EmbeddingsService
|
||||
from . embeddings_client import EmbeddingsClientSpec
|
||||
|
|
|
|||
|
|
@ -57,8 +57,7 @@ class AgentClient(RequestResponse):
|
|||
await self.request(
|
||||
AgentRequest(
|
||||
question = question,
|
||||
plan = plan,
|
||||
state = state,
|
||||
state = state or "",
|
||||
history = history,
|
||||
),
|
||||
recipient=recipient,
|
||||
|
|
|
|||
|
|
@ -90,9 +90,6 @@ class AgentService(FlowProcessor):
|
|||
type = "agent-error",
|
||||
message = str(e),
|
||||
),
|
||||
thought = None,
|
||||
observation = None,
|
||||
answer = None,
|
||||
end_of_message = True,
|
||||
end_of_dialog = True,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -1,24 +1,29 @@
|
|||
|
||||
# Base class for processors. Implements:
|
||||
# - Pulsar client, subscribe and consume basic
|
||||
# - Pub/sub client, subscribe and consume basic
|
||||
# - the async startup logic
|
||||
# - Config notify handling with subscribe-then-fetch pattern
|
||||
# - Initialising metrics
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
import _pulsar
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
import os
|
||||
from prometheus_client import start_http_server, Info
|
||||
|
||||
from .. schema import ConfigPush, config_push_queue
|
||||
from .. schema import ConfigPush, ConfigRequest, ConfigResponse
|
||||
from .. schema import config_push_queue, config_request_queue
|
||||
from .. schema import config_response_queue
|
||||
from .. log_level import LogLevel
|
||||
from . pubsub import PulsarClient, get_pubsub
|
||||
from . pubsub import get_pubsub, add_pubsub_args
|
||||
from . producer import Producer
|
||||
from . consumer import Consumer
|
||||
from . metrics import ProcessorMetrics, ConsumerMetrics
|
||||
from . subscriber import Subscriber
|
||||
from . request_response_spec import RequestResponse
|
||||
from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
|
||||
from . metrics import SubscriberMetrics
|
||||
from . logging import add_logging_args, setup_logging
|
||||
|
||||
default_config_queue = config_push_queue
|
||||
|
|
@ -58,9 +63,13 @@ class AsyncProcessor:
|
|||
"config_push_queue", default_config_queue
|
||||
)
|
||||
|
||||
# This records registered configuration handlers
|
||||
# This records registered configuration handlers, each entry is:
|
||||
# { "handler": async_fn, "types": set_or_none }
|
||||
self.config_handlers = []
|
||||
|
||||
# Track the current config version for dedup
|
||||
self.config_version = 0
|
||||
|
||||
# Create a random ID for this subscription to the configuration
|
||||
# service
|
||||
config_subscriber_id = str(uuid.uuid4())
|
||||
|
|
@ -69,33 +78,104 @@ class AsyncProcessor:
|
|||
processor = self.id, flow = None, name = "config",
|
||||
)
|
||||
|
||||
# Subscribe to config queue
|
||||
# Subscribe to config notify queue
|
||||
self.config_sub_task = Consumer(
|
||||
|
||||
taskgroup = self.taskgroup,
|
||||
backend = self.pubsub_backend, # Changed from client to backend
|
||||
backend = self.pubsub_backend,
|
||||
subscriber = config_subscriber_id,
|
||||
flow = None,
|
||||
|
||||
topic = self.config_push_queue,
|
||||
schema = ConfigPush,
|
||||
|
||||
handler = self.on_config_change,
|
||||
handler = self.on_config_notify,
|
||||
|
||||
metrics = config_consumer_metrics,
|
||||
|
||||
# This causes new subscriptions to view the entire history of
|
||||
# configuration
|
||||
start_of_messages = True
|
||||
start_of_messages = False,
|
||||
)
|
||||
|
||||
self.running = True
|
||||
|
||||
# This is called to start dynamic behaviour. An over-ride point for
|
||||
# extra functionality
|
||||
def _create_config_client(self):
|
||||
"""Create a short-lived config request/response client."""
|
||||
config_rr_id = str(uuid.uuid4())
|
||||
|
||||
config_req_metrics = ProducerMetrics(
|
||||
processor = self.id, flow = None, name = "config-request",
|
||||
)
|
||||
config_resp_metrics = SubscriberMetrics(
|
||||
processor = self.id, flow = None, name = "config-response",
|
||||
)
|
||||
|
||||
return RequestResponse(
|
||||
backend = self.pubsub_backend,
|
||||
subscription = f"{self.id}--config--{config_rr_id}",
|
||||
consumer_name = self.id,
|
||||
request_topic = config_request_queue,
|
||||
request_schema = ConfigRequest,
|
||||
request_metrics = config_req_metrics,
|
||||
response_topic = config_response_queue,
|
||||
response_schema = ConfigResponse,
|
||||
response_metrics = config_resp_metrics,
|
||||
)
|
||||
|
||||
async def fetch_config(self):
|
||||
"""Fetch full config from config service using a short-lived
|
||||
request/response client. Returns (config, version) or raises."""
|
||||
client = self._create_config_client()
|
||||
try:
|
||||
await client.start()
|
||||
resp = await client.request(
|
||||
ConfigRequest(operation="config"),
|
||||
timeout=10,
|
||||
)
|
||||
if resp.error:
|
||||
raise RuntimeError(f"Config error: {resp.error.message}")
|
||||
return resp.config, resp.version
|
||||
finally:
|
||||
await client.stop()
|
||||
|
||||
# This is called to start dynamic behaviour.
|
||||
# Implements the subscribe-then-fetch pattern to avoid race conditions.
|
||||
async def start(self):
|
||||
|
||||
# 1. Start the notify consumer (begins buffering incoming notifys)
|
||||
await self.config_sub_task.start()
|
||||
|
||||
# 2. Fetch current config via request/response
|
||||
await self.fetch_and_apply_config()
|
||||
|
||||
# 3. Any buffered notifys with version > fetched version will be
|
||||
# processed by on_config_notify, which does the version check
|
||||
|
||||
async def fetch_and_apply_config(self):
|
||||
"""Fetch full config from config service and apply to all handlers.
|
||||
Retries until successful — config service may not be ready yet."""
|
||||
|
||||
while self.running:
|
||||
|
||||
try:
|
||||
config, version = await self.fetch_config()
|
||||
|
||||
logger.info(f"Fetched config version {version}")
|
||||
|
||||
self.config_version = version
|
||||
|
||||
# Apply to all handlers (startup = invoke all)
|
||||
for entry in self.config_handlers:
|
||||
await entry["handler"](config, version)
|
||||
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Config fetch failed: {e}, retrying in 2s...",
|
||||
exc_info=True
|
||||
)
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# This is called to stop all threads. An over-ride point for extra
|
||||
# functionality
|
||||
def stop(self):
|
||||
|
|
@ -111,20 +191,66 @@ class AsyncProcessor:
|
|||
def pulsar_host(self): return self._pulsar_host
|
||||
|
||||
# Register a new event handler for configuration change
|
||||
def register_config_handler(self, handler):
|
||||
self.config_handlers.append(handler)
|
||||
def register_config_handler(self, handler, types=None):
|
||||
self.config_handlers.append({
|
||||
"handler": handler,
|
||||
"types": set(types) if types else None,
|
||||
})
|
||||
|
||||
# Called when a new configuration message push occurs
|
||||
async def on_config_change(self, message, consumer, flow):
|
||||
# Called when a config notify message arrives
|
||||
async def on_config_notify(self, message, consumer, flow):
|
||||
|
||||
# Get configuration data and version number
|
||||
config = message.value().config
|
||||
version = message.value().version
|
||||
notify_version = message.value().version
|
||||
notify_types = set(message.value().types)
|
||||
|
||||
# Invoke message handlers
|
||||
logger.info(f"Config change event: version={version}")
|
||||
for ch in self.config_handlers:
|
||||
await ch(config, version)
|
||||
# Skip if we already have this version or newer
|
||||
if notify_version <= self.config_version:
|
||||
logger.debug(
|
||||
f"Ignoring config notify v{notify_version}, "
|
||||
f"already at v{self.config_version}"
|
||||
)
|
||||
return
|
||||
|
||||
# Check if any handler cares about the affected types
|
||||
if notify_types:
|
||||
any_interested = False
|
||||
for entry in self.config_handlers:
|
||||
handler_types = entry["types"]
|
||||
if handler_types is None or notify_types & handler_types:
|
||||
any_interested = True
|
||||
break
|
||||
|
||||
if not any_interested:
|
||||
logger.debug(
|
||||
f"Ignoring config notify v{notify_version}, "
|
||||
f"no handlers for types {notify_types}"
|
||||
)
|
||||
self.config_version = notify_version
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Config notify v{notify_version} types={list(notify_types)}, "
|
||||
f"fetching config..."
|
||||
)
|
||||
|
||||
# Fetch full config using short-lived client
|
||||
try:
|
||||
config, version = await self.fetch_config()
|
||||
|
||||
self.config_version = version
|
||||
|
||||
# Invoke handlers that care about the affected types
|
||||
for entry in self.config_handlers:
|
||||
handler_types = entry["types"]
|
||||
if handler_types is None:
|
||||
await entry["handler"](config, version)
|
||||
elif not notify_types or notify_types & handler_types:
|
||||
await entry["handler"](config, version)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to fetch config on notify: {e}", exc_info=True
|
||||
)
|
||||
|
||||
# This is the 'main' body of the handler. It is a point to override
|
||||
# if needed. By default does nothing. Processors are implemented
|
||||
|
|
@ -182,7 +308,7 @@ class AsyncProcessor:
|
|||
prog=ident,
|
||||
description=doc
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
'--id',
|
||||
default=ident,
|
||||
|
|
@ -223,8 +349,8 @@ class AsyncProcessor:
|
|||
logger.info("Keyboard interrupt.")
|
||||
return
|
||||
|
||||
except _pulsar.Interrupted:
|
||||
logger.info("Pulsar Interrupted.")
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Interrupted.")
|
||||
return
|
||||
|
||||
# Exceptions from a taskgroup come in as an exception group
|
||||
|
|
@ -250,15 +376,7 @@ class AsyncProcessor:
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
# Pub/sub backend selection
|
||||
parser.add_argument(
|
||||
'--pubsub-backend',
|
||||
default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
|
||||
choices=['pulsar', 'mqtt'],
|
||||
help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
|
||||
)
|
||||
|
||||
PulsarClient.add_args(parser)
|
||||
add_pubsub_args(parser)
|
||||
add_logging_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -280,4 +398,3 @@ class AsyncProcessor:
|
|||
default=8000,
|
||||
help=f'Pulsar host (default: 8000)',
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,23 +7,14 @@ fetching large document content.
|
|||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from .flow_processor import FlowProcessor
|
||||
from .parameter_spec import ParameterSpec
|
||||
from .consumer import Consumer
|
||||
from .producer import Producer
|
||||
from .metrics import ConsumerMetrics, ProducerMetrics
|
||||
|
||||
from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ..schema import librarian_request_queue, librarian_response_queue
|
||||
from .librarian_client import LibrarianClient
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
||||
|
||||
class ChunkingService(FlowProcessor):
|
||||
"""Base service for chunking processors with parameter specification support"""
|
||||
|
|
@ -44,155 +35,18 @@ class ChunkingService(FlowProcessor):
|
|||
ParameterSpec(name="chunk-overlap")
|
||||
)
|
||||
|
||||
# Librarian client for fetching document content
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", default_librarian_request_queue
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", default_librarian_response_queue
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor=id, flow=None, name="librarian-request"
|
||||
)
|
||||
|
||||
self.librarian_request_producer = Producer(
|
||||
# Librarian client
|
||||
self.librarian = LibrarianClient(
|
||||
id=id,
|
||||
backend=self.pubsub,
|
||||
topic=librarian_request_q,
|
||||
schema=LibrarianRequest,
|
||||
metrics=librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor=id, flow=None, name="librarian-response"
|
||||
)
|
||||
|
||||
self.librarian_response_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
backend=self.pubsub,
|
||||
flow=None,
|
||||
topic=librarian_response_q,
|
||||
subscriber=f"{id}-librarian",
|
||||
schema=LibrarianResponse,
|
||||
handler=self.on_librarian_response,
|
||||
metrics=librarian_response_metrics,
|
||||
)
|
||||
|
||||
# Pending librarian requests: request_id -> asyncio.Future
|
||||
self.pending_requests = {}
|
||||
|
||||
logger.debug("ChunkingService initialized with parameter specifications")
|
||||
|
||||
async def start(self):
|
||||
await super(ChunkingService, self).start()
|
||||
await self.librarian_request_producer.start()
|
||||
await self.librarian_response_consumer.start()
|
||||
|
||||
async def on_librarian_response(self, msg, consumer, flow):
|
||||
"""Handle responses from the librarian service."""
|
||||
response = msg.value()
|
||||
request_id = msg.properties().get("id")
|
||||
|
||||
if request_id and request_id in self.pending_requests:
|
||||
future = self.pending_requests.pop(request_id)
|
||||
future.set_result(response)
|
||||
|
||||
async def fetch_document_content(self, document_id, user, timeout=120):
|
||||
"""
|
||||
Fetch document content from librarian via Pulsar.
|
||||
"""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
request = LibrarianRequest(
|
||||
operation="get-document-content",
|
||||
document_id=document_id,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Create future for response
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# Send request
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error: {response.error.type}: {response.error.message}"
|
||||
)
|
||||
|
||||
return response.content
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout fetching document {document_id}")
|
||||
|
||||
async def save_child_document(self, doc_id, parent_id, user, content,
|
||||
document_type="chunk", title=None, timeout=120):
|
||||
"""
|
||||
Save a child document (chunk) to the librarian.
|
||||
|
||||
Args:
|
||||
doc_id: ID for the new child document
|
||||
parent_id: ID of the parent document
|
||||
user: User ID
|
||||
content: Document content (bytes or str)
|
||||
document_type: Type of document ("chunk", etc.)
|
||||
title: Optional title
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
The document ID on success
|
||||
"""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
kind="text/plain",
|
||||
title=title or doc_id,
|
||||
parent_id=parent_id,
|
||||
document_type=document_type,
|
||||
)
|
||||
|
||||
request = LibrarianRequest(
|
||||
operation="add-child-document",
|
||||
document_metadata=doc_metadata,
|
||||
content=base64.b64encode(content).decode("utf-8"),
|
||||
)
|
||||
|
||||
# Create future for response
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# Send request
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error saving chunk: {response.error.type}: {response.error.message}"
|
||||
)
|
||||
|
||||
return doc_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout saving chunk {doc_id}")
|
||||
await self.librarian.start()
|
||||
|
||||
async def get_document_text(self, doc):
|
||||
"""
|
||||
|
|
@ -206,14 +60,10 @@ class ChunkingService(FlowProcessor):
|
|||
"""
|
||||
if doc.document_id and not doc.text:
|
||||
logger.info(f"Fetching document {doc.document_id} from librarian...")
|
||||
content = await self.fetch_document_content(
|
||||
text = await self.librarian.fetch_document_text(
|
||||
document_id=doc.document_id,
|
||||
user=doc.metadata.user,
|
||||
)
|
||||
# Content is base64 encoded
|
||||
if isinstance(content, str):
|
||||
content = content.encode('utf-8')
|
||||
text = base64.b64decode(content).decode("utf-8")
|
||||
logger.info(f"Fetched {len(text)} characters from librarian")
|
||||
return text
|
||||
else:
|
||||
|
|
@ -224,41 +74,31 @@ class ChunkingService(FlowProcessor):
|
|||
Extract chunk parameters from flow and return effective values
|
||||
|
||||
Args:
|
||||
msg: The message containing the document to chunk
|
||||
consumer: The consumer spec
|
||||
flow: The flow context
|
||||
default_chunk_size: Default chunk size from processor config
|
||||
default_chunk_overlap: Default chunk overlap from processor config
|
||||
msg: The message being processed
|
||||
consumer: The consumer instance
|
||||
flow: The flow object containing parameters
|
||||
default_chunk_size: Default chunk size if not configured
|
||||
default_chunk_overlap: Default chunk overlap if not configured
|
||||
|
||||
Returns:
|
||||
tuple: (chunk_size, chunk_overlap) - effective values to use
|
||||
tuple: (chunk_size, chunk_overlap) effective values
|
||||
"""
|
||||
# Extract parameters from flow (flow-configurable parameters)
|
||||
chunk_size = flow("chunk-size")
|
||||
chunk_overlap = flow("chunk-overlap")
|
||||
|
||||
# Use provided values or fall back to defaults
|
||||
effective_chunk_size = chunk_size if chunk_size is not None else default_chunk_size
|
||||
effective_chunk_overlap = chunk_overlap if chunk_overlap is not None else default_chunk_overlap
|
||||
chunk_size = default_chunk_size
|
||||
chunk_overlap = default_chunk_overlap
|
||||
|
||||
logger.debug(f"Using chunk-size: {effective_chunk_size}")
|
||||
logger.debug(f"Using chunk-overlap: {effective_chunk_overlap}")
|
||||
try:
|
||||
cs = flow.parameters.get("chunk-size")
|
||||
if cs is not None:
|
||||
chunk_size = int(cs)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse chunk-size parameter: {e}")
|
||||
|
||||
return effective_chunk_size, effective_chunk_overlap
|
||||
try:
|
||||
co = flow.parameters.get("chunk-overlap")
|
||||
if co is not None:
|
||||
chunk_overlap = int(co)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse chunk-overlap parameter: {e}")
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add chunking service arguments to parser"""
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'--librarian-request-queue',
|
||||
default=default_librarian_request_queue,
|
||||
help=f'Librarian request queue (default: {default_librarian_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--librarian-response-queue',
|
||||
default=default_librarian_response_queue,
|
||||
help=f'Librarian response queue (default: {default_librarian_response_queue})',
|
||||
)
|
||||
return chunk_size, chunk_overlap
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from .. exceptions import TooManyRequests
|
||||
|
||||
|
|
@ -32,6 +33,7 @@ class Consumer:
|
|||
rate_limit_retry_time = 10, rate_limit_timeout = 7200,
|
||||
reconnect_time = 5,
|
||||
concurrency = 1, # Number of concurrent requests to handle
|
||||
consumer_type = 'shared',
|
||||
):
|
||||
|
||||
self.taskgroup = taskgroup
|
||||
|
|
@ -42,6 +44,8 @@ class Consumer:
|
|||
self.schema = schema
|
||||
self.handler = handler
|
||||
|
||||
self.consumer_type = consumer_type
|
||||
|
||||
self.rate_limit_retry_time = rate_limit_retry_time
|
||||
self.rate_limit_timeout = rate_limit_timeout
|
||||
|
||||
|
|
@ -93,33 +97,11 @@ class Consumer:
|
|||
if self.metrics:
|
||||
self.metrics.state("stopped")
|
||||
|
||||
try:
|
||||
|
||||
logger.info(f"Subscribing to topic: {self.topic}")
|
||||
|
||||
# Determine initial position
|
||||
if self.start_of_messages:
|
||||
initial_pos = 'earliest'
|
||||
else:
|
||||
initial_pos = 'latest'
|
||||
|
||||
# Create consumer via backend
|
||||
self.consumer = await asyncio.to_thread(
|
||||
self.backend.create_consumer,
|
||||
topic = self.topic,
|
||||
subscription = self.subscriber,
|
||||
schema = self.schema,
|
||||
initial_position = initial_pos,
|
||||
consumer_type = 'shared',
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
logger.error(f"Consumer subscription exception: {e}", exc_info=True)
|
||||
await asyncio.sleep(self.reconnect_time)
|
||||
continue
|
||||
|
||||
logger.info(f"Successfully subscribed to topic: {self.topic}")
|
||||
# Determine initial position
|
||||
if self.start_of_messages:
|
||||
initial_pos = 'earliest'
|
||||
else:
|
||||
initial_pos = 'latest'
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.state("running")
|
||||
|
|
@ -128,14 +110,38 @@ class Consumer:
|
|||
|
||||
logger.info(f"Starting {self.concurrency} receiver threads")
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
|
||||
tasks = []
|
||||
|
||||
for i in range(0, self.concurrency):
|
||||
tasks.append(
|
||||
tg.create_task(self.consume_from_queue())
|
||||
# Create one backend consumer per concurrent task.
|
||||
# Each gets its own connection and dedicated thread —
|
||||
# required for backends like RabbitMQ where connections
|
||||
# are not thread-safe (pika BlockingConnection must be
|
||||
# used from a single thread).
|
||||
consumers = []
|
||||
executors = []
|
||||
for i in range(self.concurrency):
|
||||
try:
|
||||
logger.info(f"Subscribing to topic: {self.topic} (worker {i})")
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
loop = asyncio.get_event_loop()
|
||||
c = await loop.run_in_executor(
|
||||
executor,
|
||||
lambda: self.backend.create_consumer(
|
||||
topic = self.topic,
|
||||
subscription = self.subscriber,
|
||||
schema = self.schema,
|
||||
initial_position = initial_pos,
|
||||
consumer_type = self.consumer_type,
|
||||
),
|
||||
)
|
||||
consumers.append(c)
|
||||
executors.append(executor)
|
||||
logger.info(f"Successfully subscribed to topic: {self.topic} (worker {i})")
|
||||
except Exception as e:
|
||||
logger.error(f"Consumer subscription exception (worker {i}): {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
for c, ex in zip(consumers, executors):
|
||||
tg.create_task(self.consume_from_queue(c, ex))
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.state("stopped")
|
||||
|
|
@ -143,24 +149,38 @@ class Consumer:
|
|||
except Exception as e:
|
||||
|
||||
logger.error(f"Consumer loop exception: {e}", exc_info=True)
|
||||
self.consumer.unsubscribe()
|
||||
self.consumer.close()
|
||||
self.consumer = None
|
||||
for c in consumers:
|
||||
try:
|
||||
c.unsubscribe()
|
||||
c.close()
|
||||
except Exception:
|
||||
pass
|
||||
for ex in executors:
|
||||
ex.shutdown(wait=False)
|
||||
consumers = []
|
||||
executors = []
|
||||
await asyncio.sleep(self.reconnect_time)
|
||||
continue
|
||||
|
||||
if self.consumer:
|
||||
self.consumer.unsubscribe()
|
||||
self.consumer.close()
|
||||
finally:
|
||||
for c in consumers:
|
||||
try:
|
||||
c.unsubscribe()
|
||||
c.close()
|
||||
except Exception:
|
||||
pass
|
||||
for ex in executors:
|
||||
ex.shutdown(wait=False)
|
||||
|
||||
async def consume_from_queue(self):
|
||||
async def consume_from_queue(self, consumer, executor=None):
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
while self.running:
|
||||
|
||||
try:
|
||||
msg = await asyncio.to_thread(
|
||||
self.consumer.receive,
|
||||
timeout_millis=100
|
||||
msg = await loop.run_in_executor(
|
||||
executor,
|
||||
lambda: consumer.receive(timeout_millis=100),
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle timeout from any backend
|
||||
|
|
@ -168,10 +188,11 @@ class Consumer:
|
|||
continue
|
||||
raise e
|
||||
|
||||
await self.handle_one_from_queue(msg)
|
||||
await self.handle_one_from_queue(msg, consumer, executor)
|
||||
|
||||
async def handle_one_from_queue(self, msg):
|
||||
async def handle_one_from_queue(self, msg, consumer, executor=None):
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
expiry = time.time() + self.rate_limit_timeout
|
||||
|
||||
# This loop is for retry on rate-limit / resource limits
|
||||
|
|
@ -182,8 +203,11 @@ class Consumer:
|
|||
logger.warning("Gave up waiting for rate-limit retry")
|
||||
|
||||
# Message failed to be processed, this causes it to
|
||||
# be retried
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
# be retried. Ack on the consumer's dedicated thread
|
||||
# (pika is not thread-safe).
|
||||
await loop.run_in_executor(
|
||||
executor, lambda: consumer.negative_acknowledge(msg)
|
||||
)
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.process("error")
|
||||
|
|
@ -205,8 +229,11 @@ class Consumer:
|
|||
|
||||
logger.debug("Message processed successfully")
|
||||
|
||||
# Acknowledge successful processing of the message
|
||||
self.consumer.acknowledge(msg)
|
||||
# Acknowledge on the consumer's dedicated thread
|
||||
# (pika is not thread-safe)
|
||||
await loop.run_in_executor(
|
||||
executor, lambda: consumer.acknowledge(msg)
|
||||
)
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.process("success")
|
||||
|
|
@ -232,8 +259,10 @@ class Consumer:
|
|||
logger.error(f"Message processing exception: {e}", exc_info=True)
|
||||
|
||||
# Message failed to be processed, this causes it to
|
||||
# be retried
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
# be retried. Ack on the consumer's dedicated thread.
|
||||
await loop.run_in_executor(
|
||||
executor, lambda: consumer.negative_acknowledge(msg)
|
||||
)
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.process("error")
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@
|
|||
import json
|
||||
import logging
|
||||
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from .. schema import Error
|
||||
from .. schema import config_request_queue, config_response_queue
|
||||
from .. schema import config_push_queue
|
||||
|
|
@ -28,7 +26,9 @@ class FlowProcessor(AsyncProcessor):
|
|||
super(FlowProcessor, self).__init__(**params)
|
||||
|
||||
# Register configuration handler
|
||||
self.register_config_handler(self.on_configure_flows)
|
||||
self.register_config_handler(
|
||||
self.on_configure_flows, types=["active-flow"]
|
||||
)
|
||||
|
||||
# Initialise flow information state
|
||||
self.flows = {}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from .. schema import GraphRagQuery, GraphRagResponse
|
|||
class GraphRagClient(RequestResponse):
|
||||
async def rag(self, query, user="trustgraph", collection="default",
|
||||
chunk_callback=None, explain_callback=None,
|
||||
parent_uri="",
|
||||
timeout=600):
|
||||
"""
|
||||
Execute a graph RAG query with optional streaming callbacks.
|
||||
|
|
@ -50,6 +51,7 @@ class GraphRagClient(RequestResponse):
|
|||
query = query,
|
||||
user = user,
|
||||
collection = collection,
|
||||
parent_uri = parent_uri,
|
||||
),
|
||||
timeout=timeout,
|
||||
recipient=recipient,
|
||||
|
|
|
|||
246
trustgraph-base/trustgraph/base/librarian_client.py
Normal file
246
trustgraph-base/trustgraph/base/librarian_client.py
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
"""
|
||||
Shared librarian client for services that need to communicate
|
||||
with the librarian via pub/sub.
|
||||
|
||||
Provides request-response and streaming operations over the message
|
||||
broker, with proper support for large documents via stream-document.
|
||||
|
||||
Usage:
|
||||
self.librarian = LibrarianClient(
|
||||
id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params
|
||||
)
|
||||
await self.librarian.start()
|
||||
content = await self.librarian.fetch_document_content(doc_id, user)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from .consumer import Consumer
|
||||
from .producer import Producer
|
||||
from .metrics import ConsumerMetrics, ProducerMetrics
|
||||
|
||||
from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ..schema import librarian_request_queue, librarian_response_queue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LibrarianClient:
|
||||
"""Client for librarian request-response over the message broker."""
|
||||
|
||||
def __init__(self, id, backend, taskgroup, **params):
|
||||
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", librarian_request_queue,
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", librarian_response_queue,
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor=id, flow=None, name="librarian-request",
|
||||
)
|
||||
|
||||
self._producer = Producer(
|
||||
backend=backend,
|
||||
topic=librarian_request_q,
|
||||
schema=LibrarianRequest,
|
||||
metrics=librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor=id, flow=None, name="librarian-response",
|
||||
)
|
||||
|
||||
self._consumer = Consumer(
|
||||
taskgroup=taskgroup,
|
||||
backend=backend,
|
||||
flow=None,
|
||||
topic=librarian_response_q,
|
||||
subscriber=f"{id}-librarian",
|
||||
schema=LibrarianResponse,
|
||||
handler=self._on_response,
|
||||
metrics=librarian_response_metrics,
|
||||
consumer_type='exclusive',
|
||||
)
|
||||
|
||||
# Single-response requests: request_id -> asyncio.Future
|
||||
self._pending = {}
|
||||
# Streaming requests: request_id -> asyncio.Queue
|
||||
self._streams = {}
|
||||
|
||||
async def start(self):
|
||||
"""Start the librarian producer and consumer."""
|
||||
await self._producer.start()
|
||||
await self._consumer.start()
|
||||
|
||||
async def _on_response(self, msg, consumer, flow):
|
||||
"""Route librarian responses to the right waiter."""
|
||||
response = msg.value()
|
||||
request_id = msg.properties().get("id")
|
||||
|
||||
if not request_id:
|
||||
return
|
||||
|
||||
if request_id in self._pending:
|
||||
future = self._pending.pop(request_id)
|
||||
future.set_result(response)
|
||||
elif request_id in self._streams:
|
||||
await self._streams[request_id].put(response)
|
||||
|
||||
async def request(self, request, timeout=120):
|
||||
"""Send a request to the librarian and wait for a single response."""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self._pending[request_id] = future
|
||||
|
||||
try:
|
||||
await self._producer.send(
|
||||
request, properties={"id": request_id},
|
||||
)
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error: {response.error.type}: "
|
||||
f"{response.error.message}"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self._pending.pop(request_id, None)
|
||||
raise RuntimeError("Timeout waiting for librarian response")
|
||||
|
||||
async def stream(self, request, timeout=120):
|
||||
"""Send a request and collect streamed response chunks."""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
q = asyncio.Queue()
|
||||
self._streams[request_id] = q
|
||||
|
||||
try:
|
||||
await self._producer.send(
|
||||
request, properties={"id": request_id},
|
||||
)
|
||||
|
||||
chunks = []
|
||||
while True:
|
||||
response = await asyncio.wait_for(q.get(), timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error: {response.error.type}: "
|
||||
f"{response.error.message}"
|
||||
)
|
||||
|
||||
chunks.append(response)
|
||||
|
||||
if response.is_final:
|
||||
break
|
||||
|
||||
return chunks
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self._streams.pop(request_id, None)
|
||||
raise RuntimeError("Timeout waiting for librarian stream")
|
||||
finally:
|
||||
self._streams.pop(request_id, None)
|
||||
|
||||
async def fetch_document_content(self, document_id, user, timeout=120):
|
||||
"""Fetch document content using streaming.
|
||||
|
||||
Returns base64-encoded content. Caller is responsible for decoding.
|
||||
"""
|
||||
req = LibrarianRequest(
|
||||
operation="stream-document",
|
||||
document_id=document_id,
|
||||
user=user,
|
||||
)
|
||||
chunks = await self.stream(req, timeout=timeout)
|
||||
|
||||
# Decode each chunk's base64 to raw bytes, concatenate,
|
||||
# re-encode for the caller.
|
||||
raw = b""
|
||||
for chunk in chunks:
|
||||
if chunk.content:
|
||||
if isinstance(chunk.content, bytes):
|
||||
raw += base64.b64decode(chunk.content)
|
||||
else:
|
||||
raw += base64.b64decode(
|
||||
chunk.content.encode("utf-8")
|
||||
)
|
||||
|
||||
return base64.b64encode(raw)
|
||||
|
||||
async def fetch_document_text(self, document_id, user, timeout=120):
|
||||
"""Fetch document content and decode as UTF-8 text."""
|
||||
content = await self.fetch_document_content(
|
||||
document_id, user, timeout=timeout,
|
||||
)
|
||||
return base64.b64decode(content).decode("utf-8")
|
||||
|
||||
async def fetch_document_metadata(self, document_id, user, timeout=120):
|
||||
"""Fetch document metadata from the librarian."""
|
||||
req = LibrarianRequest(
|
||||
operation="get-document-metadata",
|
||||
document_id=document_id,
|
||||
user=user,
|
||||
)
|
||||
response = await self.request(req, timeout=timeout)
|
||||
return response.document_metadata
|
||||
|
||||
async def save_child_document(self, doc_id, parent_id, user, content,
|
||||
document_type="chunk", title=None,
|
||||
kind="text/plain", timeout=120):
|
||||
"""Save a child document to the librarian."""
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
kind=kind,
|
||||
title=title or doc_id,
|
||||
parent_id=parent_id,
|
||||
document_type=document_type,
|
||||
)
|
||||
|
||||
req = LibrarianRequest(
|
||||
operation="add-child-document",
|
||||
document_metadata=doc_metadata,
|
||||
content=base64.b64encode(content).decode("utf-8"),
|
||||
)
|
||||
|
||||
await self.request(req, timeout=timeout)
|
||||
return doc_id
|
||||
|
||||
async def save_document(self, doc_id, user, content, title=None,
|
||||
document_type="answer", kind="text/plain",
|
||||
timeout=120):
|
||||
"""Save a document to the librarian."""
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
kind=kind,
|
||||
title=title or doc_id,
|
||||
document_type=document_type,
|
||||
)
|
||||
|
||||
req = LibrarianRequest(
|
||||
operation="add-document",
|
||||
document_id=doc_id,
|
||||
document_metadata=doc_metadata,
|
||||
content=base64.b64encode(content).decode("utf-8"),
|
||||
user=user,
|
||||
)
|
||||
|
||||
await self.request(req, timeout=timeout)
|
||||
return doc_id
|
||||
|
|
@ -1,21 +1,16 @@
|
|||
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import PromptRequest, PromptResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PromptClient(RequestResponse):
|
||||
|
||||
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
|
||||
logger.info(f"DEBUG prompt_client: prompt called, id={id}, streaming={streaming}, chunk_callback={chunk_callback is not None}")
|
||||
|
||||
if not streaming:
|
||||
logger.info("DEBUG prompt_client: Non-streaming path")
|
||||
# Non-streaming path
|
||||
|
||||
resp = await self.request(
|
||||
PromptRequest(
|
||||
id = id,
|
||||
|
|
@ -36,39 +31,30 @@ class PromptClient(RequestResponse):
|
|||
return json.loads(resp.object)
|
||||
|
||||
else:
|
||||
logger.info("DEBUG prompt_client: Streaming path")
|
||||
# Streaming path - just forward chunks, don't accumulate
|
||||
|
||||
last_text = ""
|
||||
last_object = None
|
||||
|
||||
async def forward_chunks(resp):
|
||||
nonlocal last_text, last_object
|
||||
logger.info(f"DEBUG prompt_client: forward_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}")
|
||||
|
||||
if resp.error:
|
||||
logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}")
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
end_stream = getattr(resp, 'end_of_stream', False)
|
||||
|
||||
# Always call callback if there's text OR if it's the final message
|
||||
if resp.text is not None:
|
||||
last_text = resp.text
|
||||
# Call chunk callback if provided with both chunk and end_of_stream flag
|
||||
if chunk_callback:
|
||||
logger.info(f"DEBUG prompt_client: Calling chunk_callback with end_of_stream={end_stream}")
|
||||
if asyncio.iscoroutinefunction(chunk_callback):
|
||||
await chunk_callback(resp.text, end_stream)
|
||||
else:
|
||||
chunk_callback(resp.text, end_stream)
|
||||
elif resp.object:
|
||||
logger.info(f"DEBUG prompt_client: Got object response")
|
||||
last_object = resp.object
|
||||
|
||||
logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}")
|
||||
return end_stream
|
||||
|
||||
logger.info("DEBUG prompt_client: Creating PromptRequest")
|
||||
req = PromptRequest(
|
||||
id = id,
|
||||
terms = {
|
||||
|
|
@ -77,19 +63,16 @@ class PromptClient(RequestResponse):
|
|||
},
|
||||
streaming = True
|
||||
)
|
||||
logger.info(f"DEBUG prompt_client: About to call self.request with recipient, timeout={timeout}")
|
||||
|
||||
await self.request(
|
||||
req,
|
||||
recipient=forward_chunks,
|
||||
timeout=timeout
|
||||
)
|
||||
logger.info(f"DEBUG prompt_client: self.request returned, last_text={last_text[:50] if last_text else None}")
|
||||
|
||||
if last_text:
|
||||
logger.info("DEBUG prompt_client: Returning last_text")
|
||||
return last_text
|
||||
|
||||
logger.info("DEBUG prompt_client: Returning parsed last_object")
|
||||
return json.loads(last_object) if last_object else None
|
||||
|
||||
async def extract_definitions(self, text, timeout=600):
|
||||
|
|
|
|||
|
|
@ -1,110 +1,121 @@
|
|||
|
||||
import os
|
||||
import pulsar
|
||||
import _pulsar
|
||||
import uuid
|
||||
from pulsar.schema import JsonSchema
|
||||
import logging
|
||||
|
||||
from .. log_level import LogLevel
|
||||
from .pulsar_backend import PulsarBackend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default connection settings from environment
|
||||
DEFAULT_PULSAR_HOST = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
|
||||
DEFAULT_PULSAR_API_KEY = os.getenv("PULSAR_API_KEY", None)
|
||||
|
||||
DEFAULT_RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", 'rabbitmq')
|
||||
DEFAULT_RABBITMQ_PORT = int(os.getenv("RABBITMQ_PORT", '5672'))
|
||||
DEFAULT_RABBITMQ_USERNAME = os.getenv("RABBITMQ_USERNAME", 'guest')
|
||||
DEFAULT_RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", 'guest')
|
||||
DEFAULT_RABBITMQ_VHOST = os.getenv("RABBITMQ_VHOST", '/')
|
||||
|
||||
|
||||
def get_pubsub(**config):
|
||||
"""
|
||||
Factory function to create a pub/sub backend based on configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary from command-line args
|
||||
Must include 'pubsub_backend' key
|
||||
config: Configuration dictionary from command-line args.
|
||||
Key 'pubsub_backend' selects the backend (default: 'pulsar').
|
||||
|
||||
Returns:
|
||||
Backend instance (PulsarBackend, MQTTBackend, etc.)
|
||||
|
||||
Example:
|
||||
backend = get_pubsub(
|
||||
pubsub_backend='pulsar',
|
||||
pulsar_host='pulsar://localhost:6650'
|
||||
)
|
||||
Backend instance implementing the PubSubBackend protocol.
|
||||
"""
|
||||
backend_type = config.get('pubsub_backend', 'pulsar')
|
||||
|
||||
if backend_type == 'pulsar':
|
||||
from .pulsar_backend import PulsarBackend
|
||||
return PulsarBackend(
|
||||
host=config.get('pulsar_host', PulsarClient.default_pulsar_host),
|
||||
api_key=config.get('pulsar_api_key', PulsarClient.default_pulsar_api_key),
|
||||
host=config.get('pulsar_host', DEFAULT_PULSAR_HOST),
|
||||
api_key=config.get('pulsar_api_key', DEFAULT_PULSAR_API_KEY),
|
||||
listener=config.get('pulsar_listener'),
|
||||
)
|
||||
elif backend_type == 'mqtt':
|
||||
# TODO: Implement MQTT backend
|
||||
raise NotImplementedError("MQTT backend not yet implemented")
|
||||
elif backend_type == 'rabbitmq':
|
||||
from .rabbitmq_backend import RabbitMQBackend
|
||||
return RabbitMQBackend(
|
||||
host=config.get('rabbitmq_host', DEFAULT_RABBITMQ_HOST),
|
||||
port=config.get('rabbitmq_port', DEFAULT_RABBITMQ_PORT),
|
||||
username=config.get('rabbitmq_username', DEFAULT_RABBITMQ_USERNAME),
|
||||
password=config.get('rabbitmq_password', DEFAULT_RABBITMQ_PASSWORD),
|
||||
vhost=config.get('rabbitmq_vhost', DEFAULT_RABBITMQ_VHOST),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown pub/sub backend: {backend_type}")
|
||||
|
||||
|
||||
class PulsarClient:
|
||||
STANDALONE_PULSAR_HOST = 'pulsar://localhost:6650'
|
||||
|
||||
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
|
||||
default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None)
|
||||
|
||||
def __init__(self, **params):
|
||||
def add_pubsub_args(parser, standalone=False):
|
||||
"""Add pub/sub CLI arguments to an argument parser.
|
||||
|
||||
self.client = None
|
||||
Args:
|
||||
parser: argparse.ArgumentParser
|
||||
standalone: If True, default host is localhost (for CLI tools
|
||||
that run outside containers)
|
||||
"""
|
||||
pulsar_host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST
|
||||
pulsar_listener = 'localhost' if standalone else None
|
||||
rabbitmq_host = 'localhost' if standalone else DEFAULT_RABBITMQ_HOST
|
||||
|
||||
pulsar_host = params.get("pulsar_host", self.default_pulsar_host)
|
||||
pulsar_listener = params.get("pulsar_listener", None)
|
||||
pulsar_api_key = params.get(
|
||||
"pulsar_api_key",
|
||||
self.default_pulsar_api_key
|
||||
)
|
||||
# Hard-code Pulsar logging to ERROR level to minimize noise
|
||||
parser.add_argument(
|
||||
'--pubsub-backend',
|
||||
default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
|
||||
help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
|
||||
)
|
||||
|
||||
self.pulsar_host = pulsar_host
|
||||
self.pulsar_api_key = pulsar_api_key
|
||||
# Pulsar options
|
||||
parser.add_argument(
|
||||
'-p', '--pulsar-host',
|
||||
default=pulsar_host,
|
||||
help=f'Pulsar host (default: {pulsar_host})',
|
||||
)
|
||||
|
||||
if pulsar_api_key:
|
||||
auth = pulsar.AuthenticationToken(pulsar_api_key)
|
||||
self.client = pulsar.Client(
|
||||
pulsar_host,
|
||||
authentication=auth,
|
||||
logger=pulsar.ConsoleLogger(_pulsar.LoggerLevel.Error)
|
||||
)
|
||||
else:
|
||||
self.client = pulsar.Client(
|
||||
pulsar_host,
|
||||
listener_name=pulsar_listener,
|
||||
logger=pulsar.ConsoleLogger(_pulsar.LoggerLevel.Error)
|
||||
)
|
||||
parser.add_argument(
|
||||
'--pulsar-api-key',
|
||||
default=DEFAULT_PULSAR_API_KEY,
|
||||
help='Pulsar API key',
|
||||
)
|
||||
|
||||
self.pulsar_listener = pulsar_listener
|
||||
parser.add_argument(
|
||||
'--pulsar-listener',
|
||||
default=pulsar_listener,
|
||||
help=f'Pulsar listener (default: {pulsar_listener or "none"})',
|
||||
)
|
||||
|
||||
def close(self):
|
||||
self.client.close()
|
||||
# RabbitMQ options
|
||||
parser.add_argument(
|
||||
'--rabbitmq-host',
|
||||
default=rabbitmq_host,
|
||||
help=f'RabbitMQ host (default: {rabbitmq_host})',
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
parser.add_argument(
|
||||
'--rabbitmq-port',
|
||||
type=int,
|
||||
default=DEFAULT_RABBITMQ_PORT,
|
||||
help=f'RabbitMQ port (default: {DEFAULT_RABBITMQ_PORT})',
|
||||
)
|
||||
|
||||
if hasattr(self, "client"):
|
||||
if self.client:
|
||||
self.client.close()
|
||||
parser.add_argument(
|
||||
'--rabbitmq-username',
|
||||
default=DEFAULT_RABBITMQ_USERNAME,
|
||||
help='RabbitMQ username',
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
parser.add_argument(
|
||||
'--rabbitmq-password',
|
||||
default=DEFAULT_RABBITMQ_PASSWORD,
|
||||
help='RabbitMQ password',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-p', '--pulsar-host',
|
||||
default=__class__.default_pulsar_host,
|
||||
help=f'Pulsar host (default: {__class__.default_pulsar_host})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--pulsar-api-key',
|
||||
default=__class__.default_pulsar_api_key,
|
||||
help=f'Pulsar API key',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--pulsar-listener',
|
||||
help=f'Pulsar listener (default: none)',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--rabbitmq-vhost',
|
||||
default=DEFAULT_RABBITMQ_VHOST,
|
||||
help=f'RabbitMQ vhost (default: {DEFAULT_RABBITMQ_VHOST})',
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,122 +9,14 @@ import pulsar
|
|||
import _pulsar
|
||||
import json
|
||||
import logging
|
||||
import base64
|
||||
import types
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from typing import Any, get_type_hints
|
||||
from typing import Any
|
||||
|
||||
from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
|
||||
from .serialization import dataclass_to_dict, dict_to_dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def dataclass_to_dict(obj: Any) -> dict:
|
||||
"""
|
||||
Recursively convert a dataclass to a dictionary, handling None values and bytes.
|
||||
|
||||
None values are excluded from the dictionary (not serialized).
|
||||
Bytes values are decoded as UTF-8 strings for JSON serialization (matching Pulsar behavior).
|
||||
Handles nested dataclasses, lists, and dictionaries recursively.
|
||||
"""
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
# Handle bytes - decode to UTF-8 for JSON serialization
|
||||
if isinstance(obj, bytes):
|
||||
return obj.decode('utf-8')
|
||||
|
||||
# Handle dataclass - convert to dict then recursively process all values
|
||||
if is_dataclass(obj):
|
||||
result = {}
|
||||
for key, value in asdict(obj).items():
|
||||
result[key] = dataclass_to_dict(value) if value is not None else None
|
||||
return result
|
||||
|
||||
# Handle list - recursively process all items
|
||||
if isinstance(obj, list):
|
||||
return [dataclass_to_dict(item) for item in obj]
|
||||
|
||||
# Handle dict - recursively process all values
|
||||
if isinstance(obj, dict):
|
||||
return {k: dataclass_to_dict(v) for k, v in obj.items()}
|
||||
|
||||
# Return primitive types as-is
|
||||
return obj
|
||||
|
||||
|
||||
def dict_to_dataclass(data: dict, cls: type) -> Any:
|
||||
"""
|
||||
Convert a dictionary back to a dataclass instance.
|
||||
|
||||
Handles nested dataclasses and missing fields.
|
||||
Uses get_type_hints() to resolve forward references (string annotations).
|
||||
"""
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
if not is_dataclass(cls):
|
||||
return data
|
||||
|
||||
# Get field types from the dataclass, resolving forward references
|
||||
# get_type_hints() evaluates string annotations like "Triple | None"
|
||||
try:
|
||||
field_types = get_type_hints(cls)
|
||||
except Exception:
|
||||
# Fallback if get_type_hints fails (shouldn't happen normally)
|
||||
field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
|
||||
kwargs = {}
|
||||
|
||||
for key, value in data.items():
|
||||
if key in field_types:
|
||||
field_type = field_types[key]
|
||||
|
||||
# Handle modern union types (X | Y)
|
||||
if isinstance(field_type, types.UnionType):
|
||||
# Check if it's Optional (X | None)
|
||||
if type(None) in field_type.__args__:
|
||||
# Get the non-None type
|
||||
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
|
||||
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
|
||||
kwargs[key] = dict_to_dataclass(value, actual_type)
|
||||
else:
|
||||
kwargs[key] = value
|
||||
else:
|
||||
kwargs[key] = value
|
||||
# Check if this is a generic type (list, dict, etc.)
|
||||
elif hasattr(field_type, '__origin__'):
|
||||
# Handle list[T]
|
||||
if field_type.__origin__ == list:
|
||||
item_type = field_type.__args__[0] if field_type.__args__ else None
|
||||
if item_type and is_dataclass(item_type) and isinstance(value, list):
|
||||
kwargs[key] = [
|
||||
dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
kwargs[key] = value
|
||||
# Handle old-style Optional[T] (which is Union[T, None])
|
||||
elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
|
||||
# Get the non-None type from Union
|
||||
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
|
||||
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
|
||||
kwargs[key] = dict_to_dataclass(value, actual_type)
|
||||
else:
|
||||
kwargs[key] = value
|
||||
else:
|
||||
kwargs[key] = value
|
||||
# Handle direct dataclass fields
|
||||
elif is_dataclass(field_type) and isinstance(value, dict):
|
||||
kwargs[key] = dict_to_dataclass(value, field_type)
|
||||
# Handle bytes fields (UTF-8 encoded strings from JSON)
|
||||
elif field_type == bytes and isinstance(value, str):
|
||||
kwargs[key] = value.encode('utf-8')
|
||||
else:
|
||||
kwargs[key] = value
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
class PulsarMessage:
|
||||
"""Wrapper for Pulsar messages to match Message protocol."""
|
||||
|
||||
|
|
@ -181,8 +73,11 @@ class PulsarBackendConsumer:
|
|||
self._schema_cls = schema_cls
|
||||
|
||||
def receive(self, timeout_millis: int = 2000) -> Message:
|
||||
"""Receive a message."""
|
||||
pulsar_msg = self._consumer.receive(timeout_millis=timeout_millis)
|
||||
"""Receive a message. Raises TimeoutError if no message available."""
|
||||
try:
|
||||
pulsar_msg = self._consumer.receive(timeout_millis=timeout_millis)
|
||||
except _pulsar.Timeout:
|
||||
raise TimeoutError("No message received within timeout")
|
||||
return PulsarMessage(pulsar_msg, self._schema_cls)
|
||||
|
||||
def acknowledge(self, message: Message) -> None:
|
||||
|
|
@ -237,38 +132,44 @@ class PulsarBackend:
|
|||
self.client = pulsar.Client(**client_args)
|
||||
logger.info(f"Pulsar client connected to {host}")
|
||||
|
||||
def map_topic(self, generic_topic: str) -> str:
|
||||
def map_topic(self, queue_id: str) -> str:
|
||||
"""
|
||||
Map generic topic format to Pulsar URI.
|
||||
Map queue identifier to Pulsar URI.
|
||||
|
||||
Format: qos/tenant/namespace/queue
|
||||
Example: q1/tg/flow/my-queue -> persistent://tg/flow/my-queue
|
||||
Format: class:topicspace:topic
|
||||
Example: flow:tg:text-completion-request -> persistent://tg/flow/text-completion-request
|
||||
|
||||
Args:
|
||||
generic_topic: Generic topic string or already-formatted Pulsar URI
|
||||
queue_id: Queue identifier string or already-formatted Pulsar URI
|
||||
|
||||
Returns:
|
||||
Pulsar topic URI
|
||||
"""
|
||||
# If already a Pulsar URI, return as-is
|
||||
if '://' in generic_topic:
|
||||
return generic_topic
|
||||
if '://' in queue_id:
|
||||
return queue_id
|
||||
|
||||
parts = generic_topic.split('/', 3)
|
||||
if len(parts) != 4:
|
||||
raise ValueError(f"Invalid topic format: {generic_topic}, expected qos/tenant/namespace/queue")
|
||||
parts = queue_id.split(':', 2)
|
||||
if len(parts) != 3:
|
||||
raise ValueError(
|
||||
f"Invalid queue format: {queue_id}, "
|
||||
f"expected class:topicspace:topic"
|
||||
)
|
||||
|
||||
qos, tenant, namespace, queue = parts
|
||||
cls, topicspace, topic = parts
|
||||
|
||||
# Map QoS to persistence
|
||||
if qos == 'q0':
|
||||
persistence = 'non-persistent'
|
||||
elif qos in ['q1', 'q2']:
|
||||
# Map class to Pulsar persistence and namespace
|
||||
if cls in ('flow', 'state'):
|
||||
persistence = 'persistent'
|
||||
elif cls in ('request', 'response'):
|
||||
persistence = 'non-persistent'
|
||||
else:
|
||||
raise ValueError(f"Invalid QoS level: {qos}, expected q0, q1, or q2")
|
||||
raise ValueError(
|
||||
f"Invalid queue class: {cls}, "
|
||||
f"expected flow, request, response, or state"
|
||||
)
|
||||
|
||||
return f"{persistence}://{tenant}/{namespace}/{queue}"
|
||||
return f"{persistence}://{topicspace}/{cls}/{topic}"
|
||||
|
||||
def create_producer(self, topic: str, schema: type, **options) -> BackendProducer:
|
||||
"""
|
||||
|
|
|
|||
391
trustgraph-base/trustgraph/base/rabbitmq_backend.py
Normal file
391
trustgraph-base/trustgraph/base/rabbitmq_backend.py
Normal file
|
|
@ -0,0 +1,391 @@
|
|||
"""
|
||||
RabbitMQ backend implementation for pub/sub abstraction.
|
||||
|
||||
Uses a single topic exchange per topicspace. The logical queue name
|
||||
becomes the routing key. Consumer behavior is determined by the
|
||||
subscription name:
|
||||
|
||||
- Same subscription + same topic = shared queue (competing consumers)
|
||||
- Different subscriptions = separate queues (broadcast / fan-out)
|
||||
|
||||
This mirrors Pulsar's subscription model using idiomatic RabbitMQ.
|
||||
|
||||
Architecture:
|
||||
Producer --> [tg exchange] --routing key--> [named queue] --> Consumer
|
||||
--routing key--> [named queue] --> Consumer
|
||||
--routing key--> [exclusive q] --> Subscriber
|
||||
|
||||
Uses basic_consume (push) instead of basic_get (polling) for
|
||||
efficient message delivery.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import pika
|
||||
from typing import Any
|
||||
|
||||
from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
|
||||
from .serialization import dataclass_to_dict, dict_to_dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RabbitMQMessage:
|
||||
"""Wrapper for RabbitMQ messages to match Message protocol."""
|
||||
|
||||
def __init__(self, method, properties, body, schema_cls):
|
||||
self._method = method
|
||||
self._properties = properties
|
||||
self._body = body
|
||||
self._schema_cls = schema_cls
|
||||
self._value = None
|
||||
|
||||
def value(self) -> Any:
|
||||
"""Deserialize and return the message value as a dataclass."""
|
||||
if self._value is None:
|
||||
data_dict = json.loads(self._body.decode('utf-8'))
|
||||
self._value = dict_to_dataclass(data_dict, self._schema_cls)
|
||||
return self._value
|
||||
|
||||
def properties(self) -> dict:
|
||||
"""Return message properties from AMQP headers."""
|
||||
headers = self._properties.headers or {}
|
||||
return dict(headers)
|
||||
|
||||
|
||||
class RabbitMQBackendProducer:
|
||||
"""Publishes messages to a topic exchange with a routing key.
|
||||
|
||||
Uses thread-local connections so each thread gets its own
|
||||
connection/channel. This avoids wire corruption from concurrent
|
||||
threads writing to the same socket (pika is not thread-safe).
|
||||
"""
|
||||
|
||||
def __init__(self, connection_params, exchange_name, routing_key,
|
||||
durable):
|
||||
self._connection_params = connection_params
|
||||
self._exchange_name = exchange_name
|
||||
self._routing_key = routing_key
|
||||
self._durable = durable
|
||||
self._local = threading.local()
|
||||
|
||||
def _get_channel(self):
|
||||
"""Get or create a thread-local connection and channel."""
|
||||
conn = getattr(self._local, 'connection', None)
|
||||
chan = getattr(self._local, 'channel', None)
|
||||
|
||||
if conn is None or not conn.is_open or chan is None or not chan.is_open:
|
||||
# Close stale connection if any
|
||||
if conn is not None:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
conn = pika.BlockingConnection(self._connection_params)
|
||||
chan = conn.channel()
|
||||
chan.exchange_declare(
|
||||
exchange=self._exchange_name,
|
||||
exchange_type='topic',
|
||||
durable=True,
|
||||
)
|
||||
self._local.connection = conn
|
||||
self._local.channel = chan
|
||||
|
||||
return chan
|
||||
|
||||
def send(self, message: Any, properties: dict = {}) -> None:
|
||||
data_dict = dataclass_to_dict(message)
|
||||
json_data = json.dumps(data_dict)
|
||||
|
||||
amqp_properties = pika.BasicProperties(
|
||||
delivery_mode=2 if self._durable else 1,
|
||||
content_type='application/json',
|
||||
headers=properties if properties else None,
|
||||
)
|
||||
|
||||
for attempt in range(2):
|
||||
try:
|
||||
channel = self._get_channel()
|
||||
channel.basic_publish(
|
||||
exchange=self._exchange_name,
|
||||
routing_key=self._routing_key,
|
||||
body=json_data.encode('utf-8'),
|
||||
properties=amqp_properties,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"RabbitMQ send failed (attempt {attempt + 1}): {e}"
|
||||
)
|
||||
# Force reconnect on next attempt
|
||||
self._local.connection = None
|
||||
self._local.channel = None
|
||||
if attempt == 1:
|
||||
raise
|
||||
|
||||
def flush(self) -> None:
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the thread-local connection if any."""
|
||||
conn = getattr(self._local, 'connection', None)
|
||||
if conn is not None:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._local.connection = None
|
||||
self._local.channel = None
|
||||
|
||||
|
||||
class RabbitMQBackendConsumer:
|
||||
"""Consumes from a queue bound to a topic exchange.
|
||||
|
||||
Uses basic_consume (push model) with messages delivered to an
|
||||
internal thread-safe queue. process_data_events() drives both
|
||||
message delivery and heartbeat processing.
|
||||
"""
|
||||
|
||||
def __init__(self, connection_params, exchange_name, routing_key,
|
||||
queue_name, schema_cls, durable, exclusive=False,
|
||||
auto_delete=False):
|
||||
self._connection_params = connection_params
|
||||
self._exchange_name = exchange_name
|
||||
self._routing_key = routing_key
|
||||
self._queue_name = queue_name
|
||||
self._schema_cls = schema_cls
|
||||
self._durable = durable
|
||||
self._exclusive = exclusive
|
||||
self._auto_delete = auto_delete
|
||||
self._connection = None
|
||||
self._channel = None
|
||||
self._consumer_tag = None
|
||||
self._incoming = queue.Queue()
|
||||
|
||||
def _connect(self):
|
||||
self._connection = pika.BlockingConnection(self._connection_params)
|
||||
self._channel = self._connection.channel()
|
||||
|
||||
# Declare the topic exchange
|
||||
self._channel.exchange_declare(
|
||||
exchange=self._exchange_name,
|
||||
exchange_type='topic',
|
||||
durable=True,
|
||||
)
|
||||
|
||||
# Declare the queue — anonymous if exclusive
|
||||
result = self._channel.queue_declare(
|
||||
queue=self._queue_name,
|
||||
durable=self._durable,
|
||||
exclusive=self._exclusive,
|
||||
auto_delete=self._auto_delete,
|
||||
)
|
||||
# Capture actual name (important for anonymous queues where name='')
|
||||
self._queue_name = result.method.queue
|
||||
|
||||
self._channel.queue_bind(
|
||||
queue=self._queue_name,
|
||||
exchange=self._exchange_name,
|
||||
routing_key=self._routing_key,
|
||||
)
|
||||
|
||||
self._channel.basic_qos(prefetch_count=1)
|
||||
|
||||
# Register push-based consumer
|
||||
self._consumer_tag = self._channel.basic_consume(
|
||||
queue=self._queue_name,
|
||||
on_message_callback=self._on_message,
|
||||
auto_ack=False,
|
||||
)
|
||||
|
||||
def _on_message(self, channel, method, properties, body):
|
||||
"""Callback invoked by pika when a message arrives."""
|
||||
self._incoming.put((method, properties, body))
|
||||
|
||||
def _is_alive(self):
|
||||
return (
|
||||
self._connection is not None
|
||||
and self._connection.is_open
|
||||
and self._channel is not None
|
||||
and self._channel.is_open
|
||||
)
|
||||
|
||||
def receive(self, timeout_millis: int = 2000) -> Message:
|
||||
"""Receive a message. Raises TimeoutError if none available."""
|
||||
if not self._is_alive():
|
||||
self._connect()
|
||||
|
||||
timeout_seconds = timeout_millis / 1000.0
|
||||
deadline = time.monotonic() + timeout_seconds
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
# Check if a message was already delivered
|
||||
try:
|
||||
method, properties, body = self._incoming.get_nowait()
|
||||
return RabbitMQMessage(
|
||||
method, properties, body, self._schema_cls,
|
||||
)
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
# Drive pika's I/O — delivers messages and processes heartbeats
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining > 0:
|
||||
self._connection.process_data_events(
|
||||
time_limit=min(0.1, remaining),
|
||||
)
|
||||
|
||||
raise TimeoutError("No message received within timeout")
|
||||
|
||||
def acknowledge(self, message: Message) -> None:
|
||||
if isinstance(message, RabbitMQMessage) and message._method:
|
||||
self._channel.basic_ack(
|
||||
delivery_tag=message._method.delivery_tag,
|
||||
)
|
||||
|
||||
def negative_acknowledge(self, message: Message) -> None:
|
||||
if isinstance(message, RabbitMQMessage) and message._method:
|
||||
self._channel.basic_nack(
|
||||
delivery_tag=message._method.delivery_tag,
|
||||
requeue=True,
|
||||
)
|
||||
|
||||
def unsubscribe(self) -> None:
|
||||
if self._consumer_tag and self._channel and self._channel.is_open:
|
||||
try:
|
||||
self._channel.basic_cancel(self._consumer_tag)
|
||||
except Exception:
|
||||
pass
|
||||
self._consumer_tag = None
|
||||
|
||||
def close(self) -> None:
|
||||
self.unsubscribe()
|
||||
try:
|
||||
if self._channel and self._channel.is_open:
|
||||
self._channel.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._connection and self._connection.is_open:
|
||||
self._connection.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._channel = None
|
||||
self._connection = None
|
||||
|
||||
|
||||
class RabbitMQBackend:
|
||||
"""RabbitMQ pub/sub backend using a topic exchange per topicspace."""
|
||||
|
||||
def __init__(self, host='localhost', port=5672, username='guest',
|
||||
password='guest', vhost='/'):
|
||||
self._connection_params = pika.ConnectionParameters(
|
||||
host=host,
|
||||
port=port,
|
||||
virtual_host=vhost,
|
||||
credentials=pika.PlainCredentials(username, password),
|
||||
heartbeat=0,
|
||||
)
|
||||
logger.info(f"RabbitMQ backend: {host}:{port} vhost={vhost}")
|
||||
|
||||
def _parse_queue_id(self, queue_id: str) -> tuple[str, str, str, bool]:
|
||||
"""
|
||||
Parse queue identifier into exchange, routing key, and durability.
|
||||
|
||||
Format: class:topicspace:topic
|
||||
Returns: (exchange_name, routing_key, class, durable)
|
||||
"""
|
||||
if ':' not in queue_id:
|
||||
return 'tg', queue_id, 'flow', False
|
||||
|
||||
parts = queue_id.split(':', 2)
|
||||
if len(parts) != 3:
|
||||
raise ValueError(
|
||||
f"Invalid queue format: {queue_id}, "
|
||||
f"expected class:topicspace:topic"
|
||||
)
|
||||
|
||||
cls, topicspace, topic = parts
|
||||
|
||||
if cls in ('flow', 'state'):
|
||||
durable = True
|
||||
elif cls in ('request', 'response'):
|
||||
durable = False
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid queue class: {cls}, "
|
||||
f"expected flow, request, response, or state"
|
||||
)
|
||||
|
||||
# Exchange per topicspace, routing key includes class
|
||||
exchange_name = topicspace
|
||||
routing_key = f"{cls}.{topic}"
|
||||
|
||||
return exchange_name, routing_key, cls, durable
|
||||
|
||||
# Keep map_queue_name for backward compatibility with tests
|
||||
def map_queue_name(self, queue_id: str) -> tuple[str, bool]:
|
||||
exchange, routing_key, cls, durable = self._parse_queue_id(queue_id)
|
||||
return f"{exchange}.{routing_key}", durable
|
||||
|
||||
def create_producer(self, topic: str, schema: type,
|
||||
**options) -> BackendProducer:
|
||||
exchange, routing_key, cls, durable = self._parse_queue_id(topic)
|
||||
logger.debug(
|
||||
f"Creating producer: exchange={exchange}, "
|
||||
f"routing_key={routing_key}"
|
||||
)
|
||||
return RabbitMQBackendProducer(
|
||||
self._connection_params, exchange, routing_key, durable,
|
||||
)
|
||||
|
||||
def create_consumer(self, topic: str, subscription: str, schema: type,
|
||||
initial_position: str = 'latest',
|
||||
consumer_type: str = 'shared',
|
||||
**options) -> BackendConsumer:
|
||||
"""Create a consumer with a queue bound to the topic exchange.
|
||||
|
||||
consumer_type='shared': Named durable queue. Multiple consumers
|
||||
with the same subscription compete (round-robin).
|
||||
consumer_type='exclusive': Anonymous ephemeral queue. Each
|
||||
consumer gets its own copy of every message (broadcast).
|
||||
"""
|
||||
exchange, routing_key, cls, durable = self._parse_queue_id(topic)
|
||||
|
||||
if consumer_type == 'exclusive' and cls == 'state':
|
||||
# State broadcast: named durable queue per subscriber.
|
||||
# Retains messages so late-starting processors see current state.
|
||||
queue_name = f"{exchange}.{routing_key}.{subscription}"
|
||||
queue_durable = True
|
||||
exclusive = False
|
||||
auto_delete = False
|
||||
elif consumer_type == 'exclusive':
|
||||
# Broadcast: anonymous queue, auto-deleted on disconnect
|
||||
queue_name = ''
|
||||
queue_durable = False
|
||||
exclusive = True
|
||||
auto_delete = True
|
||||
else:
|
||||
# Shared: named queue, competing consumers
|
||||
queue_name = f"{exchange}.{routing_key}.{subscription}"
|
||||
queue_durable = durable
|
||||
exclusive = False
|
||||
auto_delete = False
|
||||
|
||||
logger.debug(
|
||||
f"Creating consumer: exchange={exchange}, "
|
||||
f"routing_key={routing_key}, queue={queue_name or '(anonymous)'}, "
|
||||
f"type={consumer_type}"
|
||||
)
|
||||
|
||||
return RabbitMQBackendConsumer(
|
||||
self._connection_params, exchange, routing_key,
|
||||
queue_name, schema, queue_durable, exclusive, auto_delete,
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
115
trustgraph-base/trustgraph/base/serialization.py
Normal file
115
trustgraph-base/trustgraph/base/serialization.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
"""
|
||||
JSON serialization helpers for dataclass ↔ dict conversion.
|
||||
|
||||
Used by pub/sub backends that use JSON as their wire format.
|
||||
"""
|
||||
|
||||
import types
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from typing import Any, get_type_hints
|
||||
|
||||
|
||||
def dataclass_to_dict(obj: Any) -> dict:
|
||||
"""
|
||||
Recursively convert a dataclass to a dictionary, handling None values and bytes.
|
||||
|
||||
None values are excluded from the dictionary (not serialized).
|
||||
Bytes values are decoded as UTF-8 strings for JSON serialization.
|
||||
Handles nested dataclasses, lists, and dictionaries recursively.
|
||||
"""
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
# Handle bytes - decode to UTF-8 for JSON serialization
|
||||
if isinstance(obj, bytes):
|
||||
return obj.decode('utf-8')
|
||||
|
||||
# Handle dataclass - convert to dict then recursively process all values
|
||||
if is_dataclass(obj):
|
||||
result = {}
|
||||
for key, value in asdict(obj).items():
|
||||
result[key] = dataclass_to_dict(value) if value is not None else None
|
||||
return result
|
||||
|
||||
# Handle list - recursively process all items
|
||||
if isinstance(obj, list):
|
||||
return [dataclass_to_dict(item) for item in obj]
|
||||
|
||||
# Handle dict - recursively process all values
|
||||
if isinstance(obj, dict):
|
||||
return {k: dataclass_to_dict(v) for k, v in obj.items()}
|
||||
|
||||
# Return primitive types as-is
|
||||
return obj
|
||||
|
||||
|
||||
def dict_to_dataclass(data: dict, cls: type) -> Any:
|
||||
"""
|
||||
Convert a dictionary back to a dataclass instance.
|
||||
|
||||
Handles nested dataclasses and missing fields.
|
||||
Uses get_type_hints() to resolve forward references (string annotations).
|
||||
"""
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
if not is_dataclass(cls):
|
||||
return data
|
||||
|
||||
# Get field types from the dataclass, resolving forward references
|
||||
# get_type_hints() evaluates string annotations like "Triple | None"
|
||||
try:
|
||||
field_types = get_type_hints(cls)
|
||||
except Exception:
|
||||
# Fallback if get_type_hints fails (shouldn't happen normally)
|
||||
field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
|
||||
kwargs = {}
|
||||
|
||||
for key, value in data.items():
|
||||
if key in field_types:
|
||||
field_type = field_types[key]
|
||||
|
||||
# Handle modern union types (X | Y)
|
||||
if isinstance(field_type, types.UnionType):
|
||||
# Check if it's Optional (X | None)
|
||||
if type(None) in field_type.__args__:
|
||||
# Get the non-None type
|
||||
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
|
||||
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
|
||||
kwargs[key] = dict_to_dataclass(value, actual_type)
|
||||
else:
|
||||
kwargs[key] = value
|
||||
else:
|
||||
kwargs[key] = value
|
||||
# Check if this is a generic type (list, dict, etc.)
|
||||
elif hasattr(field_type, '__origin__'):
|
||||
# Handle list[T]
|
||||
if field_type.__origin__ == list:
|
||||
item_type = field_type.__args__[0] if field_type.__args__ else None
|
||||
if item_type and is_dataclass(item_type) and isinstance(value, list):
|
||||
kwargs[key] = [
|
||||
dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
kwargs[key] = value
|
||||
# Handle old-style Optional[T] (which is Union[T, None])
|
||||
elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
|
||||
# Get the non-None type from Union
|
||||
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
|
||||
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
|
||||
kwargs[key] = dict_to_dataclass(value, actual_type)
|
||||
else:
|
||||
kwargs[key] = value
|
||||
else:
|
||||
kwargs[key] = value
|
||||
# Handle direct dataclass fields
|
||||
elif is_dataclass(field_type) and isinstance(value, dict):
|
||||
kwargs[key] = dict_to_dataclass(value, field_type)
|
||||
# Handle bytes fields (UTF-8 encoded strings from JSON)
|
||||
elif field_type == bytes and isinstance(value, str):
|
||||
kwargs[key] = value.encode('utf-8')
|
||||
else:
|
||||
kwargs[key] = value
|
||||
|
||||
return cls(**kwargs)
|
||||
|
|
@ -7,6 +7,7 @@ import asyncio
|
|||
import time
|
||||
import logging
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -38,6 +39,7 @@ class Subscriber:
|
|||
self.pending_acks = {} # Track messages awaiting delivery
|
||||
|
||||
self.consumer = None
|
||||
self.executor = None
|
||||
|
||||
def __del__(self):
|
||||
|
||||
|
|
@ -45,15 +47,6 @@ class Subscriber:
|
|||
|
||||
async def start(self):
|
||||
|
||||
# Create consumer via backend
|
||||
self.consumer = await asyncio.to_thread(
|
||||
self.backend.create_consumer,
|
||||
topic=self.topic,
|
||||
subscription=self.subscription,
|
||||
schema=self.schema,
|
||||
consumer_type='shared',
|
||||
)
|
||||
|
||||
self.task = asyncio.create_task(self.run())
|
||||
|
||||
async def stop(self):
|
||||
|
|
@ -80,6 +73,21 @@ class Subscriber:
|
|||
|
||||
try:
|
||||
|
||||
# Create consumer and dedicated thread if needed
|
||||
# (first run or after failure)
|
||||
if self.consumer is None:
|
||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||
loop = asyncio.get_event_loop()
|
||||
self.consumer = await loop.run_in_executor(
|
||||
self.executor,
|
||||
lambda: self.backend.create_consumer(
|
||||
topic=self.topic,
|
||||
subscription=self.subscription,
|
||||
schema=self.schema,
|
||||
consumer_type='exclusive',
|
||||
),
|
||||
)
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.state("running")
|
||||
|
||||
|
|
@ -128,9 +136,12 @@ class Subscriber:
|
|||
# Process messages only if not draining
|
||||
if not self.draining:
|
||||
try:
|
||||
msg = await asyncio.to_thread(
|
||||
self.consumer.receive,
|
||||
timeout_millis=250
|
||||
loop = asyncio.get_event_loop()
|
||||
msg = await loop.run_in_executor(
|
||||
self.executor,
|
||||
lambda: self.consumer.receive(
|
||||
timeout_millis=250
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle timeout from any backend
|
||||
|
|
@ -172,15 +183,18 @@ class Subscriber:
|
|||
except Exception:
|
||||
pass # Already closed or error
|
||||
self.consumer = None
|
||||
|
||||
|
||||
|
||||
if self.executor:
|
||||
self.executor.shutdown(wait=False)
|
||||
self.executor = None
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.state("stopped")
|
||||
|
||||
if not self.running and not self.draining:
|
||||
return
|
||||
|
||||
# If handler drops out, sleep a retry
|
||||
|
||||
# Sleep before retry
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def subscribe(self, id):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
|
||||
import _pulsar
|
||||
|
||||
from .. schema import AgentRequest, AgentResponse
|
||||
from .. schema import agent_request_queue
|
||||
|
|
@ -7,15 +6,11 @@ from .. schema import agent_response_queue
|
|||
from . base import BaseClient
|
||||
|
||||
# Ugly
|
||||
ERROR=_pulsar.LoggerLevel.Error
|
||||
WARN=_pulsar.LoggerLevel.Warn
|
||||
INFO=_pulsar.LoggerLevel.Info
|
||||
DEBUG=_pulsar.LoggerLevel.Debug
|
||||
|
||||
class AgentClient(BaseClient):
|
||||
|
||||
def __init__(
|
||||
self, log_level=ERROR,
|
||||
self,
|
||||
subscriber=None,
|
||||
input_queue=None,
|
||||
output_queue=None,
|
||||
|
|
@ -27,7 +22,6 @@ class AgentClient(BaseClient):
|
|||
if output_queue is None: output_queue = agent_response_queue
|
||||
|
||||
super(AgentClient, self).__init__(
|
||||
log_level=log_level,
|
||||
subscriber=subscriber,
|
||||
input_queue=input_queue,
|
||||
output_queue=output_queue,
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue