mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
parent
2bcf375103
commit
3ba6a3238f
6 changed files with 970 additions and 0 deletions
319
dev-tools/tests/agent_dag/analyse_trace.py
Normal file
319
dev-tools/tests/agent_dag/analyse_trace.py
Normal file
|
|
@ -0,0 +1,319 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Analyse a captured agent trace JSON file and check DAG integrity.
|
||||
|
||||
Usage:
|
||||
python analyse_trace.py react.json
|
||||
python analyse_trace.py -u http://localhost:8088/ react.json
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import websockets
|
||||
|
||||
DEFAULT_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
|
||||
DEFAULT_USER = "trustgraph"
|
||||
DEFAULT_COLLECTION = "default"
|
||||
DEFAULT_FLOW = "default"
|
||||
GRAPH = "urn:graph:retrieval"
|
||||
|
||||
# Namespace prefixes
|
||||
PROV = "http://www.w3.org/ns/prov#"
|
||||
RDF = "http://www.w3.org/1999/02/22-rdf-syntax-ns#"
|
||||
RDFS = "http://www.w3.org/2000/01/rdf-schema#"
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
|
||||
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
|
||||
RDF_TYPE = RDF + "type"
|
||||
|
||||
TG_ANALYSIS = TG + "Analysis"
|
||||
TG_TOOL_USE = TG + "ToolUse"
|
||||
TG_OBSERVATION_TYPE = TG + "Observation"
|
||||
TG_CONCLUSION = TG + "Conclusion"
|
||||
TG_SYNTHESIS = TG + "Synthesis"
|
||||
TG_QUESTION = TG + "Question"
|
||||
|
||||
|
||||
def shorten(uri):
    """Abbreviate a URI for display using the known namespace prefixes.

    Non-string input is simply stringified.
    """
    prefixes = (
        (PROV, "prov:"),
        (RDF, "rdf:"),
        (RDFS, "rdfs:"),
        (TG, "tg:"),
    )
    if isinstance(uri, str):
        for long_form, short_form in prefixes:
            if uri.startswith(long_form):
                return short_form + uri[len(long_form):]
    return str(uri)
|
||||
|
||||
|
||||
async def fetch_triples(ws, flow, subject, user, collection, request_counter):
    """Query the triples service for all triples with the given subject.

    `request_counter` is a single-element list used as a mutable counter so
    request ids stay unique across calls.  Returns the list of wire-format
    triples from the matching response.
    """
    request_counter[0] += 1
    req_id = f"q-{request_counter[0]}"

    payload = {
        "id": req_id,
        "service": "triples",
        "flow": flow,
        "request": {
            "s": {"t": "i", "i": subject},
            "g": GRAPH,
            "user": user,
            "collection": collection,
            "limit": 100,
        },
    }

    await ws.send(json.dumps(payload))

    # Skip any interleaved traffic until the reply for our id arrives.
    while True:
        reply = json.loads(await ws.recv())
        if reply.get("id") != req_id:
            continue
        body = reply.get("response", {})
        if isinstance(body, dict):
            return body.get("response", [])
        return body
|
||||
|
||||
|
||||
def extract_term(term):
    """Extract the plain value from a wire-format term.

    Wire terms are dicts tagged by "t":
      "i" -> IRI string stored under "i"
      "l" -> literal value stored under "v"
      "t" -> nested triple under "tr", extracted recursively into a dict
    Empty/None terms yield "".  Any unrecognised shape is stringified.
    """
    if not term:
        return ""
    # Robustness fix: the original called term.get() unconditionally and
    # raised AttributeError on truthy non-dict input, despite having a
    # str() fallback for unknown shapes.  Non-dicts now use that fallback.
    if not isinstance(term, dict):
        return str(term)
    t = term.get("t", "")
    if t == "i":
        return term.get("i", "")
    elif t == "l":
        return term.get("v", "")
    elif t == "t":
        tr = term.get("tr", {})
        return {
            "s": extract_term(tr.get("s", {})),
            "p": extract_term(tr.get("p", {})),
            "o": extract_term(tr.get("o", {})),
        }
    return str(term)
|
||||
|
||||
|
||||
def parse_triples(wire_triples):
    """Convert wire-format triples into plain (s, p, o) tuples."""
    return [
        (
            extract_term(triple.get("s", {})),
            extract_term(triple.get("p", {})),
            extract_term(triple.get("o", {})),
        )
        for triple in wire_triples
    ]
|
||||
|
||||
|
||||
def get_types(tuples):
    """Return the set of rdf:type objects present in parsed triples."""
    found = set()
    for _, pred, obj in tuples:
        if pred == RDF_TYPE:
            found.add(obj)
    return found
|
||||
|
||||
|
||||
def get_derived_from(tuples):
    """Return every prov:wasDerivedFrom object present in parsed triples."""
    parents = []
    for _, pred, obj in tuples:
        if pred == PROV_WAS_DERIVED_FROM:
            parents.append(obj)
    return parents
|
||||
|
||||
|
||||
async def analyse(path, url, flow, user, collection):
    """Load a captured trace file, fetch provenance triples for every
    explain id over the websocket API, and check the DAG integrity rules,
    printing a report and any rule violations."""
    with open(path) as f:
        messages = json.load(f)

    print(f"Total messages: {len(messages)}")
    print()

    # ---- Pass 1: collect explain IDs and check streaming chunks ----

    explain_ids = []
    errors = []  # accumulated rule-violation strings, reported at the end

    for i, msg in enumerate(messages):
        resp = msg.get("response", {})
        chunk_type = resp.get("chunk_type", "?")

        if chunk_type == "explain":
            explain_id = resp.get("explain_id", "")
            explain_ids.append(explain_id)
            print(f" {i:3d} {chunk_type} {explain_id}")
        else:
            print(f" {i:3d} {chunk_type}")

        # Rule 7: message_id on content chunks
        if chunk_type in ("thought", "observation", "answer"):
            mid = resp.get("message_id", "")
            if not mid:
                errors.append(
                    f"[msg {i}] {chunk_type} chunk missing message_id"
                )

    print()
    print(f"Explain IDs ({len(explain_ids)}):")
    for eid in explain_ids:
        print(f" {eid}")

    # ---- Pass 2: fetch triples for each explain ID ----

    # Derive the websocket endpoint from the HTTP base URL.
    ws_url = url.replace("http://", "ws://").replace("https://", "wss://")
    ws_url = f"{ws_url.rstrip('/')}/api/v1/socket"

    # Single-element list: mutable counter shared with fetch_triples().
    request_counter = [0]
    # entity_id -> parsed triples [(s, p, o), ...]
    entities = {}

    print()
    print("Fetching triples...")
    print()

    async with websockets.connect(ws_url, ping_interval=20, ping_timeout=60) as ws:
        for eid in explain_ids:
            wire = await fetch_triples(
                ws, flow, eid, user, collection, request_counter,
            )

            tuples = parse_triples(wire) if isinstance(wire, list) else []
            entities[eid] = tuples

            print(f" {eid}")
            for s, p, o in tuples:
                # Truncate long objects for readable output.
                o_short = str(o)
                if len(o_short) > 80:
                    o_short = o_short[:77] + "..."
                print(f" {shorten(p)} = {o_short}")
            print()

    # ---- Pass 3: check rules ----

    all_ids = set(entities.keys())

    # Collect entity metadata
    roots = []  # entities with no wasDerivedFrom
    # NOTE(review): conclusions and analyses are collected but never read
    # later in this function — possibly intended for future rules.
    conclusions = []  # tg:Conclusion entities
    analyses = []  # tg:Analysis entities
    observations = []  # tg:Observation entities

    for eid, tuples in entities.items():
        types = get_types(tuples)
        parents = get_derived_from(tuples)

        if not tuples:
            errors.append(f"[{eid}] entity has no triples in store")

        if not parents:
            roots.append(eid)

        if TG_CONCLUSION in types:
            conclusions.append(eid)
        if TG_ANALYSIS in types:
            analyses.append(eid)
        if TG_OBSERVATION_TYPE in types:
            observations.append(eid)

        # Rule 4: every non-root entity has wasDerivedFrom
        if parents:
            for parent in parents:
                # Rule 5: parent exists in known entities
                if parent not in all_ids:
                    errors.append(
                        f"[{eid}] wasDerivedFrom target not in explain set: "
                        f"{parent}"
                    )

        # Rule 6: Analysis entities must have ToolUse type
        if TG_ANALYSIS in types and TG_TOOL_USE not in types:
            errors.append(
                f"[{eid}] Analysis entity missing tg:ToolUse type"
            )

    # Rule 1: exactly one root
    if len(roots) == 0:
        errors.append("No root entity found (all have wasDerivedFrom)")
    elif len(roots) > 1:
        errors.append(
            f"Multiple roots ({len(roots)}) — expected exactly 1:"
        )
        for r in roots:
            types = get_types(entities[r])
            type_labels = ", ".join(shorten(t) for t in types)
            errors.append(f" root: {r} [{type_labels}]")

    # Rule 2: exactly one terminal node (nothing derives from it)
    # Build set of entities that are parents of something
    has_children = set()
    for eid, tuples in entities.items():
        for parent in get_derived_from(tuples):
            has_children.add(parent)

    terminals = [eid for eid in all_ids if eid not in has_children]
    if len(terminals) == 0:
        errors.append("No terminal entity found (cycle?)")
    elif len(terminals) > 1:
        errors.append(
            f"Multiple terminal entities ({len(terminals)}) — expected exactly 1:"
        )
        for t in terminals:
            types = get_types(entities[t])
            type_labels = ", ".join(shorten(ty) for ty in types)
            errors.append(f" terminal: {t} [{type_labels}]")

    # Rule 8: Observation should not derive from Analysis if a sub-trace
    # exists as a sibling. Check: if an Analysis has both a Question child
    # and an Observation child, the Observation should derive from the
    # sub-trace's Synthesis, not from the Analysis.
    for obs_id in observations:
        obs_parents = get_derived_from(entities[obs_id])
        for parent in obs_parents:
            if parent in entities:
                parent_types = get_types(entities[parent])
                if TG_ANALYSIS in parent_types:
                    # Check if this Analysis also has a Question child
                    # (i.e. a sub-trace exists)
                    has_subtrace = False
                    for other_id, other_tuples in entities.items():
                        if other_id == obs_id:
                            continue
                        other_parents = get_derived_from(other_tuples)
                        other_types = get_types(other_tuples)
                        if (parent in other_parents
                                and TG_QUESTION in other_types):
                            has_subtrace = True
                            break
                    if has_subtrace:
                        errors.append(
                            f"[{obs_id}] Observation derives from Analysis "
                            f"{parent} which has a sub-trace — should derive "
                            f"from the sub-trace's Synthesis instead"
                        )

    # ---- Report ----

    print()
    print("=" * 60)
    if errors:
        print(f"ERRORS ({len(errors)}):")
        print()
        for err in errors:
            print(f" !! {err}")
    else:
        print("ALL CHECKS PASSED")
    print("=" * 60)
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: parse arguments and run the analysis."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input", help="JSON trace file")
    # Option flags and their defaults, registered in display order.
    for flags, default in (
        (("-u", "--url"), DEFAULT_URL),
        (("-f", "--flow"), DEFAULT_FLOW),
        (("-U", "--user"), DEFAULT_USER),
        (("-C", "--collection"), DEFAULT_COLLECTION),
    ):
        parser.add_argument(*flags, default=default)
    args = parser.parse_args()

    coro = analyse(
        args.input, args.url, args.flow,
        args.user, args.collection,
    )
    asyncio.run(coro)


if __name__ == "__main__":
    main()
|
||||
81
dev-tools/tests/agent_dag/ws_capture.py
Normal file
81
dev-tools/tests/agent_dag/ws_capture.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Connect to TrustGraph websocket, run an agent query, capture all
|
||||
response messages to a JSON file.
|
||||
|
||||
Usage:
|
||||
python ws_capture.py -q "What is the document about?" -o trace.json
|
||||
python ws_capture.py -q "..." -u http://localhost:8088/ -o out.json
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import websockets
|
||||
|
||||
DEFAULT_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
|
||||
DEFAULT_USER = "trustgraph"
|
||||
DEFAULT_COLLECTION = "default"
|
||||
DEFAULT_FLOW = "default"
|
||||
|
||||
|
||||
async def capture(url, flow, question, user, collection, output):
    """Run one streaming agent query and record every response message.

    Sends a single agent request over the websocket API, collects all
    responses tagged with our request id until one is marked complete,
    and writes them to `output` as a JSON array.
    """

    # Convert to ws URL
    ws_url = url.replace("http://", "ws://").replace("https://", "wss://")
    ws_url = f"{ws_url.rstrip('/')}/api/v1/socket"

    captured = []

    async with websockets.connect(ws_url, ping_interval=20, ping_timeout=120) as ws:

        await ws.send(json.dumps({
            "id": "capture",
            "service": "agent",
            "flow": flow,
            "request": {
                "question": question,
                "user": user,
                "collection": collection,
                "streaming": True,
            },
        }))

        async for raw in ws:
            msg = json.loads(raw)

            # Ignore traffic belonging to other requests.
            if msg.get("id") != "capture":
                continue

            captured.append(msg)

            if msg.get("complete"):
                break

    with open(output, "w") as f:
        json.dump(captured, f, indent=2)

    print(f"Captured {len(captured)} messages to {output}")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: parse arguments and run the capture."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("-q", "--question", required=True)
    # Optional flags and their defaults, registered in display order.
    for flags, default in (
        (("-o", "--output"), "trace.json"),
        (("-u", "--url"), DEFAULT_URL),
        (("-f", "--flow"), DEFAULT_FLOW),
        (("-U", "--user"), DEFAULT_USER),
        (("-C", "--collection"), DEFAULT_COLLECTION),
    ):
        parser.add_argument(*flags, default=default)
    args = parser.parse_args()

    coro = capture(
        args.url, args.flow, args.question,
        args.user, args.collection, args.output,
    )
    asyncio.run(coro)


if __name__ == "__main__":
    main()
|
||||
67
dev-tools/tests/librarian/simple_text_download.py
Normal file
67
dev-tools/tests/librarian/simple_text_download.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Minimal example: download a text document in tiny chunks via websocket API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
import websockets
|
||||
|
||||
async def main():
    """Download a text document from the librarian service in tiny chunks
    over the websocket API and print the reassembled content."""
    url = "ws://localhost:8088/api/v1/socket"

    document_id = "test-chunked-doc-001"
    chunk_size = 10  # Tiny chunks!

    # Monotonic counter giving each request a unique id.
    request_id = 0

    async def send_request(ws, request):
        """Send one librarian request and return the response payload.

        Raises on an error reply from the service.
        """
        nonlocal request_id
        request_id += 1
        msg = {
            "id": f"req-{request_id}",
            "service": "librarian",
            "request": request
        }
        await ws.send(json.dumps(msg))
        response = json.loads(await ws.recv())
        if "error" in response:
            raise Exception(response["error"])
        return response.get("response", {})

    async with websockets.connect(url) as ws:

        print(f"Fetching document: {document_id}")
        print(f"Chunk size: {chunk_size} bytes")
        print()

        chunk_index = 0
        all_content = b""

        # Fetch chunk by chunk until the reported total is reached.
        while True:
            resp = await send_request(ws, {
                "operation": "stream-document",
                "user": "trustgraph",
                "document-id": document_id,
                "chunk-index": chunk_index,
                "chunk-size": chunk_size,
            })

            # Content arrives base64-encoded.
            chunk_data = base64.b64decode(resp["content"])
            total_chunks = resp["total-chunks"]
            # NOTE(review): total_bytes is read but never used below.
            total_bytes = resp["total-bytes"]

            print(f"Chunk {chunk_index}: {chunk_data}")

            all_content += chunk_data
            chunk_index += 1

            if chunk_index >= total_chunks:
                break

        print()
        print(f"Complete: {all_content}")

if __name__ == "__main__":
    asyncio.run(main())
|
||||
56
dev-tools/tests/librarian/simple_text_upload.py
Normal file
56
dev-tools/tests/librarian/simple_text_upload.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Minimal example: upload a small text document via websocket API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
import time
|
||||
import websockets
|
||||
|
||||
async def main():
    """Upload a small text document to the librarian service via the
    websocket API."""
    url = "ws://localhost:8088/api/v1/socket"

    # Small text content
    content = b"AAAAAAAAAABBBBBBBBBBCCCCCCCCCC"

    # Monotonic counter giving each request a unique id.
    request_id = 0

    async def send_request(ws, request):
        """Send one librarian request and return the response payload.

        Raises on an error reply from the service.
        """
        nonlocal request_id
        request_id += 1
        msg = {
            "id": f"req-{request_id}",
            "service": "librarian",
            "request": request
        }
        await ws.send(json.dumps(msg))
        response = json.loads(await ws.recv())
        if "error" in response:
            raise Exception(response["error"])
        return response.get("response", {})

    async with websockets.connect(url) as ws:

        print(f"Uploading {len(content)} bytes...")

        # NOTE(review): resp is not inspected; send_request already raises
        # if the service reports an error.
        resp = await send_request(ws, {
            "operation": "add-document",
            "document-metadata": {
                "id": "test-chunked-doc-001",
                "time": int(time.time()),
                "kind": "text/plain",
                "title": "My Test Document",
                "comments": "Small doc for chunk testing",
                "user": "trustgraph",
                "tags": ["test"],
                "metadata": [],
            },
            # Content is base64-encoded for the JSON transport.
            "content": base64.b64encode(content).decode("utf-8"),
        })

        print("Done!")

if __name__ == "__main__":
    asyncio.run(main())
|
||||
237
dev-tools/tests/relay/test_rev_gateway.py
Normal file
237
dev-tools/tests/relay/test_rev_gateway.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
WebSocket Test Client
|
||||
|
||||
A simple client to test the reverse gateway through the relay.
|
||||
Connects to the relay's /in endpoint and allows sending test messages.
|
||||
|
||||
Usage:
|
||||
python test_client.py [--uri URI] [--interactive]
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import argparse
|
||||
import uuid
|
||||
from aiohttp import ClientSession, WSMsgType
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger("test_client")
|
||||
|
||||
class TestClient:
    """Simple WebSocket test client.

    Connects to the relay's /in endpoint, sends TrustGraph-formatted
    request messages, and logs any responses received.
    """

    def __init__(self, uri: str):
        self.uri = uri                 # websocket endpoint to connect to
        self.session = None            # aiohttp ClientSession, set by connect()
        self.ws = None                 # websocket connection, set by connect()
        self.running = False           # NOTE(review): never read in this class
        self.message_counter = 0       # per-client message sequence number
        self.client_id = str(uuid.uuid4())[:8]  # short unique id for message ids

    async def connect(self):
        """Connect to the WebSocket"""
        self.session = ClientSession()
        logger.info(f"Connecting to {self.uri}")
        self.ws = await self.session.ws_connect(self.uri)
        logger.info("Connected successfully")

    async def disconnect(self):
        """Disconnect from WebSocket and close the HTTP session."""
        if self.ws and not self.ws.closed:
            await self.ws.close()
        if self.session and not self.session.closed:
            await self.session.close()
        logger.info("Disconnected")

    async def send_message(self, service: str, request_data: dict, flow: str = "default"):
        """Send a properly formatted TrustGraph message"""
        self.message_counter += 1
        message = {
            "id": f"{self.client_id}-{self.message_counter}",
            "service": service,
            "request": request_data,
            "flow": flow
        }

        # Pretty-printed copy for the log; compact form goes on the wire.
        message_json = json.dumps(message, indent=2)
        logger.info(f"Sending message:\n{message_json}")
        await self.ws.send_str(json.dumps(message))

    async def listen_for_responses(self):
        """Listen for incoming messages and log them until the socket closes."""
        logger.info("Listening for responses...")

        async for msg in self.ws:
            if msg.type == WSMsgType.TEXT:
                try:
                    response = json.loads(msg.data)
                    logger.info(f"Received response:\n{json.dumps(response, indent=2)}")
                except json.JSONDecodeError:
                    # Non-JSON text is still logged verbatim.
                    logger.info(f"Received text: {msg.data}")
            elif msg.type == WSMsgType.BINARY:
                logger.info(f"Received binary data: {len(msg.data)} bytes")
            elif msg.type == WSMsgType.ERROR:
                logger.error(f"WebSocket error: {self.ws.exception()}")
                break
            else:
                logger.info(f"Connection closed: {msg.type}")
                break

    async def interactive_mode(self):
        """Interactive mode for manual testing.

        NOTE(review): input() is a blocking call inside an async method, so
        the event loop (and the listener task) cannot run while waiting for
        user input — responses only appear after the next command.  Consider
        asyncio.to_thread(input, ...) if that matters.
        """
        print("\n=== Interactive Test Client ===")
        print("Available commands:")
        print(" text-completion - Test text completion service")
        print(" agent - Test agent service")
        print(" embeddings - Test embeddings service")
        print(" custom - Send custom message")
        print(" quit - Exit")
        print()

        # Start response listener
        listen_task = asyncio.create_task(self.listen_for_responses())

        try:
            while True:
                try:
                    command = input("Command> ").strip().lower()

                    if command == "quit":
                        break
                    elif command == "text-completion":
                        await self.send_message("text-completion", {
                            "system": "You are a helpful assistant.",
                            "prompt": "What is 2+2?"
                        })
                    elif command == "agent":
                        await self.send_message("agent", {
                            "question": "What is the capital of France?"
                        })
                    elif command == "embeddings":
                        await self.send_message("embeddings", {
                            "text": "Hello world"
                        })
                    elif command == "custom":
                        service = input("Service name> ").strip()
                        request_json = input("Request JSON> ").strip()
                        try:
                            request_data = json.loads(request_json)
                            await self.send_message(service, request_data)
                        except json.JSONDecodeError as e:
                            print(f"Invalid JSON: {e}")
                    elif command == "":
                        continue
                    else:
                        print(f"Unknown command: {command}")

                except KeyboardInterrupt:
                    break
                except EOFError:
                    break
                except Exception as e:
                    logger.error(f"Error in interactive mode: {e}")

        finally:
            # Always stop the listener task on exit.
            listen_task.cancel()
            try:
                await listen_task
            except asyncio.CancelledError:
                pass

    async def run_predefined_tests(self):
        """Run a series of predefined tests"""
        print("\n=== Running Predefined Tests ===")

        # Start response listener
        listen_task = asyncio.create_task(self.listen_for_responses())

        try:
            # Test 1: Text completion
            print("\n1. Testing text-completion service...")
            await self.send_message("text-completion", {
                "system": "You are a helpful assistant.",
                "prompt": "What is 2+2?"
            })
            await asyncio.sleep(2)

            # Test 2: Agent
            print("\n2. Testing agent service...")
            await self.send_message("agent", {
                "question": "What is the capital of France?"
            })
            await asyncio.sleep(2)

            # Test 3: Embeddings
            print("\n3. Testing embeddings service...")
            await self.send_message("embeddings", {
                "text": "Hello world"
            })
            await asyncio.sleep(2)

            # Test 4: Invalid service
            print("\n4. Testing invalid service...")
            await self.send_message("nonexistent-service", {
                "test": "data"
            })
            await asyncio.sleep(2)

            print("\nTests completed. Waiting for any remaining responses...")
            await asyncio.sleep(3)

        finally:
            # Always stop the listener task on exit.
            listen_task.cancel()
            try:
                await listen_task
            except asyncio.CancelledError:
                pass
|
||||
|
||||
async def main():
    """Parse arguments, connect the test client, and run the chosen mode."""
    parser = argparse.ArgumentParser(
        description="WebSocket Test Client for Reverse Gateway"
    )
    parser.add_argument(
        '--uri',
        default='ws://localhost:8080/in',
        help='WebSocket URI to connect to (default: ws://localhost:8080/in)'
    )
    parser.add_argument(
        '--interactive', '-i',
        action='store_true',
        help='Run in interactive mode'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Enable verbose logging'
    )

    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    client = TestClient(args.uri)

    try:
        await client.connect()

        # Pick the requested mode and run it.
        if args.interactive:
            runner = client.interactive_mode
        else:
            runner = client.run_predefined_tests
        await runner()

    except KeyboardInterrupt:
        print("\nShutdown requested by user")
    except Exception as e:
        logger.error(f"Client error: {e}")
    finally:
        await client.disconnect()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
210
dev-tools/tests/relay/websocket_relay.py
Normal file
210
dev-tools/tests/relay/websocket_relay.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
WebSocket Relay Test Harness
|
||||
|
||||
This script creates a relay server with two WebSocket endpoints:
|
||||
- /in - for test clients to connect to
|
||||
- /out - for reverse gateway to connect to
|
||||
|
||||
Messages are bidirectionally relayed between the two connections.
|
||||
|
||||
Usage:
|
||||
python websocket_relay.py [--port PORT] [--host HOST]
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import argparse
|
||||
from aiohttp import web, WSMsgType
|
||||
import weakref
|
||||
from typing import Optional, Set
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger("websocket_relay")
|
||||
|
||||
class WebSocketRelay:
    """WebSocket relay that forwards messages between 'in' and 'out' connections"""

    def __init__(self):
        # WeakSets: connections disappear automatically once aiohttp drops
        # its references; explicit discard in the forwarders handles the
        # rest.
        self.in_connections: Set = weakref.WeakSet()
        self.out_connections: Set = weakref.WeakSet()

    async def handle_in_connection(self, request):
        """Handle incoming connections on /in endpoint.

        Every TEXT/BINARY frame received here is forwarded to all current
        'out' connections.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        self.in_connections.add(ws)
        logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")

        try:
            async for msg in ws:
                if msg.type == WSMsgType.TEXT:
                    data = msg.data
                    logger.info(f"IN → OUT: {data}")
                    await self._forward_to_out(data)
                elif msg.type == WSMsgType.BINARY:
                    data = msg.data
                    logger.info(f"IN → OUT: {len(data)} bytes (binary)")
                    await self._forward_to_out(data, binary=True)
                elif msg.type == WSMsgType.ERROR:
                    logger.error(f"WebSocket error on 'in' connection: {ws.exception()}")
                    break
                else:
                    # Any other frame type (e.g. close) ends the loop.
                    break
        except Exception as e:
            logger.error(f"Error in 'in' connection handler: {e}")
        finally:
            logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")

        return ws

    async def handle_out_connection(self, request):
        """Handle outgoing connections on /out endpoint.

        Mirror of handle_in_connection: frames received here are forwarded
        to all current 'in' connections.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        self.out_connections.add(ws)
        logger.info(f"New 'out' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")

        try:
            async for msg in ws:
                if msg.type == WSMsgType.TEXT:
                    data = msg.data
                    logger.info(f"OUT → IN: {data}")
                    await self._forward_to_in(data)
                elif msg.type == WSMsgType.BINARY:
                    data = msg.data
                    logger.info(f"OUT → IN: {len(data)} bytes (binary)")
                    await self._forward_to_in(data, binary=True)
                elif msg.type == WSMsgType.ERROR:
                    logger.error(f"WebSocket error on 'out' connection: {ws.exception()}")
                    break
                else:
                    break
        except Exception as e:
            logger.error(f"Error in 'out' connection handler: {e}")
        finally:
            logger.info(f"'out' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")

        return ws

    async def _forward_to_out(self, data, binary=False):
        """Forward message from 'in' to all 'out' connections"""
        if not self.out_connections:
            logger.warning("No 'out' connections available to forward message")
            return

        closed_connections = []
        # Iterate a snapshot: the WeakSet may change during awaits.
        for ws in list(self.out_connections):
            try:
                if ws.closed:
                    closed_connections.append(ws)
                    continue

                if binary:
                    await ws.send_bytes(data)
                else:
                    await ws.send_str(data)
            except Exception as e:
                logger.error(f"Error forwarding to 'out' connection: {e}")
                closed_connections.append(ws)

        # Clean up closed connections
        for ws in closed_connections:
            if ws in self.out_connections:
                self.out_connections.discard(ws)

    async def _forward_to_in(self, data, binary=False):
        """Forward message from 'out' to all 'in' connections"""
        if not self.in_connections:
            logger.warning("No 'in' connections available to forward message")
            return

        closed_connections = []
        # Iterate a snapshot: the WeakSet may change during awaits.
        for ws in list(self.in_connections):
            try:
                if ws.closed:
                    closed_connections.append(ws)
                    continue

                if binary:
                    await ws.send_bytes(data)
                else:
                    await ws.send_str(data)
            except Exception as e:
                logger.error(f"Error forwarding to 'in' connection: {e}")
                closed_connections.append(ws)

        # Clean up closed connections
        for ws in closed_connections:
            if ws in self.in_connections:
                self.in_connections.discard(ws)
|
||||
|
||||
async def create_app(relay):
    """Build the aiohttp application with relay and status routes."""
    app = web.Application()

    # Relay endpoints.
    app.router.add_get('/in', relay.handle_in_connection)
    app.router.add_get('/out', relay.handle_out_connection)

    async def status(request):
        """Report the current connection counts as JSON."""
        return web.json_response({
            'in_connections': len(relay.in_connections),
            'out_connections': len(relay.out_connections),
            'status': 'running',
        })

    app.router.add_get('/status', status)
    app.router.add_get('/', status)  # Root also shows status

    return app
|
||||
|
||||
def main():
    """Parse arguments, print usage hints, and run the relay server."""
    parser = argparse.ArgumentParser(
        description="WebSocket Relay Test Harness"
    )
    parser.add_argument(
        '--host',
        default='localhost',
        help='Host to bind to (default: localhost)'
    )
    parser.add_argument(
        '--port',
        type=int,
        default=8080,
        help='Port to bind to (default: 8080)'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Enable verbose logging'
    )

    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    relay = WebSocketRelay()

    endpoint = f"{args.host}:{args.port}"
    print(f"Starting WebSocket Relay on {endpoint}")
    print(f" 'in' endpoint: ws://{endpoint}/in")
    print(f" 'out' endpoint: ws://{endpoint}/out")
    print(f" Status: http://{endpoint}/status")
    print()
    print("Usage:")
    print(f" Test client connects to: ws://{endpoint}/in")
    print(f" Reverse gateway connects to: ws://{endpoint}/out")

    web.run_app(create_app(relay), host=args.host, port=args.port)

if __name__ == "__main__":
    main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue