diff --git a/dev-tools/tests/agent_dag/analyse_trace.py b/dev-tools/tests/agent_dag/analyse_trace.py new file mode 100644 index 00000000..b71cdebe --- /dev/null +++ b/dev-tools/tests/agent_dag/analyse_trace.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +""" +Analyse a captured agent trace JSON file and check DAG integrity. + +Usage: + python analyse_trace.py react.json + python analyse_trace.py -u http://localhost:8088/ react.json +""" + +import argparse +import asyncio +import json +import os +import sys +import websockets + +DEFAULT_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") +DEFAULT_USER = "trustgraph" +DEFAULT_COLLECTION = "default" +DEFAULT_FLOW = "default" +GRAPH = "urn:graph:retrieval" + +# Namespace prefixes +PROV = "http://www.w3.org/ns/prov#" +RDF = "http://www.w3.org/1999/02/22-rdf-syntax-ns#" +RDFS = "http://www.w3.org/2000/01/rdf-schema#" +TG = "https://trustgraph.ai/ns/" + +PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" +RDF_TYPE = RDF + "type" + +TG_ANALYSIS = TG + "Analysis" +TG_TOOL_USE = TG + "ToolUse" +TG_OBSERVATION_TYPE = TG + "Observation" +TG_CONCLUSION = TG + "Conclusion" +TG_SYNTHESIS = TG + "Synthesis" +TG_QUESTION = TG + "Question" + + +def shorten(uri): + """Shorten a URI for display.""" + for prefix, short in [ + (PROV, "prov:"), (RDF, "rdf:"), (RDFS, "rdfs:"), (TG, "tg:"), + ]: + if isinstance(uri, str) and uri.startswith(prefix): + return short + uri[len(prefix):] + return str(uri) + + +async def fetch_triples(ws, flow, subject, user, collection, request_counter): + """Query triples for a given subject URI.""" + request_counter[0] += 1 + req_id = f"q-{request_counter[0]}" + + msg = { + "id": req_id, + "service": "triples", + "flow": flow, + "request": { + "s": {"t": "i", "i": subject}, + "g": GRAPH, + "user": user, + "collection": collection, + "limit": 100, + }, + } + + await ws.send(json.dumps(msg)) + + while True: + raw = await ws.recv() + resp = json.loads(raw) + if resp.get("id") == req_id: + inner = resp.get("response", {}) + if isinstance(inner, dict): + return inner.get("response", []) + return inner + + +def extract_term(term): + """Extract value from wire-format term.""" + if not term: + return "" + t = term.get("t", "") + if t == "i": + return term.get("i", "") + elif t == "l": + return term.get("v", "") + elif t == "t": + tr = term.get("tr", {}) + return { + "s": extract_term(tr.get("s", {})), + "p": extract_term(tr.get("p", {})), + "o": extract_term(tr.get("o", {})), + } + return str(term) + + +def parse_triples(wire_triples): + """Convert wire triples to (s, p, o) tuples.""" + result = [] + for t in wire_triples: + s = extract_term(t.get("s", {})) + p = extract_term(t.get("p", {})) + o = extract_term(t.get("o", {})) + result.append((s, p, o)) + return result + + +def get_types(tuples): + """Get rdf:type values from parsed triples.""" + return {o for s, p, o in tuples if p == RDF_TYPE} + + +def get_derived_from(tuples): + """Get prov:wasDerivedFrom targets from parsed triples.""" + return [o for s, p, o in tuples if p == PROV_WAS_DERIVED_FROM] + + +async def analyse(path, url, flow, user, collection): + with open(path) as f: + messages = json.load(f) + + print(f"Total messages: {len(messages)}") + print() + + # ---- Pass 1: collect explain IDs and check streaming chunks ---- + + explain_ids = [] + errors = [] + + for i, msg in enumerate(messages): + resp = msg.get("response", {}) + chunk_type = resp.get("chunk_type", "?") + + if chunk_type == "explain": + explain_id = resp.get("explain_id", "") + explain_ids.append(explain_id) + print(f" {i:3d} {chunk_type} {explain_id}") + else: + print(f" {i:3d} {chunk_type}") + + # Rule 7: message_id on content chunks + if chunk_type in ("thought", "observation", "answer"): + mid = resp.get("message_id", "") + if not mid: + errors.append( + f"[msg {i}] {chunk_type} chunk missing message_id" + ) + + print() + print(f"Explain IDs ({len(explain_ids)}):") + for eid in explain_ids: + print(f" {eid}") + + # ---- Pass 2: fetch triples for each explain ID ---- + + ws_url = url.replace("http://", "ws://").replace("https://", "wss://") + ws_url = f"{ws_url.rstrip('/')}/api/v1/socket" + + request_counter = [0] + # entity_id -> parsed triples [(s, p, o), ...] + entities = {} + + print() + print("Fetching triples...") + print() + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=60) as ws: + for eid in explain_ids: + wire = await fetch_triples( + ws, flow, eid, user, collection, request_counter, + ) + + tuples = parse_triples(wire) if isinstance(wire, list) else [] + entities[eid] = tuples + + print(f" {eid}") + for s, p, o in tuples: + o_short = str(o) + if len(o_short) > 80: + o_short = o_short[:77] + "..." + print(f" {shorten(p)} = {o_short}") + print() + + # ---- Pass 3: check rules ---- + + all_ids = set(entities.keys()) + + # Collect entity metadata + roots = [] # entities with no wasDerivedFrom + conclusions = [] # tg:Conclusion entities + analyses = [] # tg:Analysis entities + observations = [] # tg:Observation entities + + for eid, tuples in entities.items(): + types = get_types(tuples) + parents = get_derived_from(tuples) + + if not tuples: + errors.append(f"[{eid}] entity has no triples in store") + + if not parents: + roots.append(eid) + + if TG_CONCLUSION in types: + conclusions.append(eid) + if TG_ANALYSIS in types: + analyses.append(eid) + if TG_OBSERVATION_TYPE in types: + observations.append(eid) + + # Rule 4: every non-root entity has wasDerivedFrom + if parents: + for parent in parents: + # Rule 5: parent exists in known entities + if parent not in all_ids: + errors.append( + f"[{eid}] wasDerivedFrom target not in explain set: " + f"{parent}" + ) + + # Rule 6: Analysis entities must have ToolUse type + if TG_ANALYSIS in types and TG_TOOL_USE not in types: + errors.append( + f"[{eid}] Analysis entity missing tg:ToolUse type" + ) + + # Rule 1: exactly one root + if len(roots) == 0: + errors.append("No root entity found (all have wasDerivedFrom)") + elif len(roots) > 1: + errors.append( + f"Multiple roots ({len(roots)}) — expected exactly 1:" + ) + for r in roots: + types = get_types(entities[r]) + type_labels = ", ".join(shorten(t) for t in types) + errors.append(f" root: {r} [{type_labels}]") + + # Rule 2: exactly one terminal node (nothing derives from it) + # Build set of entities that are parents of something + has_children = set() + for eid, tuples in entities.items(): + for parent in get_derived_from(tuples): + has_children.add(parent) + + terminals = [eid for eid in all_ids if eid not in has_children] + if len(terminals) == 0: + errors.append("No terminal entity found (cycle?)") + elif len(terminals) > 1: + errors.append( + f"Multiple terminal entities ({len(terminals)}) — expected exactly 1:" + ) + for t in terminals: + types = get_types(entities[t]) + type_labels = ", ".join(shorten(ty) for ty in types) + errors.append(f" terminal: {t} [{type_labels}]") + + # Rule 8: Observation should not derive from Analysis if a sub-trace + # exists as a sibling. Check: if an Analysis has both a Question child + # and an Observation child, the Observation should derive from the + # sub-trace's Synthesis, not from the Analysis. + for obs_id in observations: + obs_parents = get_derived_from(entities[obs_id]) + for parent in obs_parents: + if parent in entities: + parent_types = get_types(entities[parent]) + if TG_ANALYSIS in parent_types: + # Check if this Analysis also has a Question child + # (i.e. a sub-trace exists) + has_subtrace = False + for other_id, other_tuples in entities.items(): + if other_id == obs_id: + continue + other_parents = get_derived_from(other_tuples) + other_types = get_types(other_tuples) + if (parent in other_parents + and TG_QUESTION in other_types): + has_subtrace = True + break + if has_subtrace: + errors.append( + f"[{obs_id}] Observation derives from Analysis " + f"{parent} which has a sub-trace — should derive " + f"from the sub-trace's Synthesis instead" + ) + + # ---- Report ---- + + print() + print("=" * 60) + if errors: + print(f"ERRORS ({len(errors)}):") + print() + for err in errors: + print(f" !! {err}") + else: + print("ALL CHECKS PASSED") + print("=" * 60) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("input", help="JSON trace file") + parser.add_argument("-u", "--url", default=DEFAULT_URL) + parser.add_argument("-f", "--flow", default=DEFAULT_FLOW) + parser.add_argument("-U", "--user", default=DEFAULT_USER) + parser.add_argument("-C", "--collection", default=DEFAULT_COLLECTION) + args = parser.parse_args() + + asyncio.run(analyse( + args.input, args.url, args.flow, + args.user, args.collection, + )) + + +if __name__ == "__main__": + main() diff --git a/dev-tools/tests/agent_dag/ws_capture.py b/dev-tools/tests/agent_dag/ws_capture.py new file mode 100644 index 00000000..3002d563 --- /dev/null +++ b/dev-tools/tests/agent_dag/ws_capture.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +""" +Connect to TrustGraph websocket, run an agent query, capture all +response messages to a JSON file. + +Usage: + python ws_capture.py -q "What is the document about?" -o trace.json + python ws_capture.py -q "..." -u http://localhost:8088/ -o out.json +""" + +import argparse +import asyncio +import json +import os +import websockets + +DEFAULT_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") +DEFAULT_USER = "trustgraph" +DEFAULT_COLLECTION = "default" +DEFAULT_FLOW = "default" + + +async def capture(url, flow, question, user, collection, output): + + # Convert to ws URL + ws_url = url.replace("http://", "ws://").replace("https://", "wss://") + ws_url = f"{ws_url.rstrip('/')}/api/v1/socket" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=120) as ws: + + request = { + "id": "capture", + "service": "agent", + "flow": flow, + "request": { + "question": question, + "user": user, + "collection": collection, + "streaming": True, + }, + } + + await ws.send(json.dumps(request)) + + messages = [] + + async for raw in ws: + msg = json.loads(raw) + + if msg.get("id") != "capture": + continue + + messages.append(msg) + + if msg.get("complete"): + break + + with open(output, "w") as f: + json.dump(messages, f, indent=2) + + print(f"Captured {len(messages)} messages to {output}") + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("-q", "--question", required=True) + parser.add_argument("-o", "--output", default="trace.json") + parser.add_argument("-u", "--url", default=DEFAULT_URL) + parser.add_argument("-f", "--flow", default=DEFAULT_FLOW) + parser.add_argument("-U", "--user", default=DEFAULT_USER) + parser.add_argument("-C", "--collection", default=DEFAULT_COLLECTION) + args = parser.parse_args() + + asyncio.run(capture( + args.url, args.flow, args.question, + args.user, args.collection, args.output, + )) + + +if __name__ == "__main__": + main() diff --git a/dev-tools/tests/librarian/simple_text_download.py b/dev-tools/tests/librarian/simple_text_download.py new file mode 100644 index 00000000..6af2a60d --- /dev/null +++ b/dev-tools/tests/librarian/simple_text_download.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +""" +Minimal example: download a text document in tiny chunks via websocket API +""" + +import asyncio +import json +import base64 +import websockets + +async def main(): + url = "ws://localhost:8088/api/v1/socket" + + document_id = "test-chunked-doc-001" + chunk_size = 10 # Tiny chunks! + + request_id = 0 + + async def send_request(ws, request): + nonlocal request_id + request_id += 1 + msg = { + "id": f"req-{request_id}", + "service": "librarian", + "request": request + } + await ws.send(json.dumps(msg)) + response = json.loads(await ws.recv()) + if "error" in response: + raise Exception(response["error"]) + return response.get("response", {}) + + async with websockets.connect(url) as ws: + + print(f"Fetching document: {document_id}") + print(f"Chunk size: {chunk_size} bytes") + print() + + chunk_index = 0 + all_content = b"" + + while True: + resp = await send_request(ws, { + "operation": "stream-document", + "user": "trustgraph", + "document-id": document_id, + "chunk-index": chunk_index, + "chunk-size": chunk_size, + }) + + chunk_data = base64.b64decode(resp["content"]) + total_chunks = resp["total-chunks"] + total_bytes = resp["total-bytes"] + + print(f"Chunk {chunk_index}: {chunk_data}") + + all_content += chunk_data + chunk_index += 1 + + if chunk_index >= total_chunks: + break + + print() + print(f"Complete: {all_content}") + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/dev-tools/tests/librarian/simple_text_upload.py b/dev-tools/tests/librarian/simple_text_upload.py new file mode 100644 index 00000000..e21bd185 --- /dev/null +++ b/dev-tools/tests/librarian/simple_text_upload.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +""" +Minimal example: upload a small text document via websocket API +""" + +import asyncio +import json +import base64 +import time +import websockets + +async def main(): + url = "ws://localhost:8088/api/v1/socket" + + # Small text content + content = b"AAAAAAAAAABBBBBBBBBBCCCCCCCCCC" + + request_id = 0 + + async def send_request(ws, request): + nonlocal request_id + request_id += 1 + msg = { + "id": f"req-{request_id}", + "service": "librarian", + "request": request + } + await ws.send(json.dumps(msg)) + response = json.loads(await ws.recv()) + if "error" in response: + raise Exception(response["error"]) + return response.get("response", {}) + + async with websockets.connect(url) as ws: + + print(f"Uploading {len(content)} bytes...") + + resp = await send_request(ws, { + "operation": "add-document", + "document-metadata": { + "id": "test-chunked-doc-001", + "time": int(time.time()), + "kind": "text/plain", + "title": "My Test Document", + "comments": "Small doc for chunk testing", + "user": "trustgraph", + "tags": ["test"], + "metadata": [], + }, + "content": base64.b64encode(content).decode("utf-8"), + }) + + print("Done!") + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/dev-tools/tests/relay/test_rev_gateway.py b/dev-tools/tests/relay/test_rev_gateway.py new file mode 100644 index 00000000..fe200e46 --- /dev/null +++ b/dev-tools/tests/relay/test_rev_gateway.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +""" +WebSocket Test Client + +A simple client to test the reverse gateway through the relay. +Connects to the relay's /in endpoint and allows sending test messages. + +Usage: + python test_client.py [--uri URI] [--interactive] +""" + +import asyncio +import json +import logging +import argparse +import uuid +from aiohttp import ClientSession, WSMsgType + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger("test_client") + +class TestClient: + """Simple WebSocket test client""" + + def __init__(self, uri: str): + self.uri = uri + self.session = None + self.ws = None + self.running = False + self.message_counter = 0 + self.client_id = str(uuid.uuid4())[:8] + + async def connect(self): + """Connect to the WebSocket""" + self.session = ClientSession() + logger.info(f"Connecting to {self.uri}") + self.ws = await self.session.ws_connect(self.uri) + logger.info("Connected successfully") + + async def disconnect(self): + """Disconnect from WebSocket""" + if self.ws and not self.ws.closed: + await self.ws.close() + if self.session and not self.session.closed: + await self.session.close() + logger.info("Disconnected") + + async def send_message(self, service: str, request_data: dict, flow: str = "default"): + """Send a properly formatted TrustGraph message""" + self.message_counter += 1 + message = { + "id": f"{self.client_id}-{self.message_counter}", + "service": service, + "request": request_data, + "flow": flow + } + + message_json = json.dumps(message, indent=2) + logger.info(f"Sending message:\n{message_json}") + await self.ws.send_str(json.dumps(message)) + + async def listen_for_responses(self): + """Listen for incoming messages""" + logger.info("Listening for responses...") + + async for msg in self.ws: + if msg.type == WSMsgType.TEXT: + try: + response = json.loads(msg.data) + logger.info(f"Received response:\n{json.dumps(response, indent=2)}") + except json.JSONDecodeError: + logger.info(f"Received text: {msg.data}") + elif msg.type == WSMsgType.BINARY: + logger.info(f"Received binary data: {len(msg.data)} bytes") + elif msg.type == WSMsgType.ERROR: + logger.error(f"WebSocket error: {self.ws.exception()}") + break + else: + logger.info(f"Connection closed: {msg.type}") + break + + async def interactive_mode(self): + """Interactive mode for manual testing""" + print("\n=== Interactive Test Client ===") + print("Available commands:") + print(" text-completion - Test text completion service") + print(" agent - Test agent service") + print(" embeddings - Test embeddings service") + print(" custom - Send custom message") + print(" quit - Exit") + print() + + # Start response listener + listen_task = asyncio.create_task(self.listen_for_responses()) + + try: + while True: + try: + command = input("Command> ").strip().lower() + + if command == "quit": + break + elif command == "text-completion": + await self.send_message("text-completion", { + "system": "You are a helpful assistant.", + "prompt": "What is 2+2?" + }) + elif command == "agent": + await self.send_message("agent", { + "question": "What is the capital of France?" + }) + elif command == "embeddings": + await self.send_message("embeddings", { + "text": "Hello world" + }) + elif command == "custom": + service = input("Service name> ").strip() + request_json = input("Request JSON> ").strip() + try: + request_data = json.loads(request_json) + await self.send_message(service, request_data) + except json.JSONDecodeError as e: + print(f"Invalid JSON: {e}") + elif command == "": + continue + else: + print(f"Unknown command: {command}") + + except KeyboardInterrupt: + break + except EOFError: + break + except Exception as e: + logger.error(f"Error in interactive mode: {e}") + + finally: + listen_task.cancel() + try: + await listen_task + except asyncio.CancelledError: + pass + + async def run_predefined_tests(self): + """Run a series of predefined tests""" + print("\n=== Running Predefined Tests ===") + + # Start response listener + listen_task = asyncio.create_task(self.listen_for_responses()) + + try: + # Test 1: Text completion + print("\n1. Testing text-completion service...") + await self.send_message("text-completion", { + "system": "You are a helpful assistant.", + "prompt": "What is 2+2?" + }) + await asyncio.sleep(2) + + # Test 2: Agent + print("\n2. Testing agent service...") + await self.send_message("agent", { + "question": "What is the capital of France?" + }) + await asyncio.sleep(2) + + # Test 3: Embeddings + print("\n3. Testing embeddings service...") + await self.send_message("embeddings", { + "text": "Hello world" + }) + await asyncio.sleep(2) + + # Test 4: Invalid service + print("\n4. Testing invalid service...") + await self.send_message("nonexistent-service", { + "test": "data" + }) + await asyncio.sleep(2) + + print("\nTests completed. Waiting for any remaining responses...") + await asyncio.sleep(3) + + finally: + listen_task.cancel() + try: + await listen_task + except asyncio.CancelledError: + pass + +async def main(): + parser = argparse.ArgumentParser( + description="WebSocket Test Client for Reverse Gateway" + ) + parser.add_argument( + '--uri', + default='ws://localhost:8080/in', + help='WebSocket URI to connect to (default: ws://localhost:8080/in)' + ) + parser.add_argument( + '--interactive', '-i', + action='store_true', + help='Run in interactive mode' + ) + parser.add_argument( + '--verbose', '-v', + action='store_true', + help='Enable verbose logging' + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + client = TestClient(args.uri) + + try: + await client.connect() + + if args.interactive: + await client.interactive_mode() + else: + await client.run_predefined_tests() + + except KeyboardInterrupt: + print("\nShutdown requested by user") + except Exception as e: + logger.error(f"Client error: {e}") + finally: + await client.disconnect() + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/dev-tools/tests/relay/websocket_relay.py b/dev-tools/tests/relay/websocket_relay.py new file mode 100644 index 00000000..d537f7da --- /dev/null +++ b/dev-tools/tests/relay/websocket_relay.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +""" +WebSocket Relay Test Harness + +This script creates a relay server with two WebSocket endpoints: +- /in - for test clients to connect to +- /out - for reverse gateway to connect to + +Messages are bidirectionally relayed between the two connections. + +Usage: + python websocket_relay.py [--port PORT] [--host HOST] +""" + +import asyncio +import logging +import argparse +from aiohttp import web, WSMsgType +import weakref +from typing import Optional, Set + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger("websocket_relay") + +class WebSocketRelay: + """WebSocket relay that forwards messages between 'in' and 'out' connections""" + + def __init__(self): + self.in_connections: Set = weakref.WeakSet() + self.out_connections: Set = weakref.WeakSet() + + async def handle_in_connection(self, request): + """Handle incoming connections on /in endpoint""" + ws = web.WebSocketResponse() + await ws.prepare(request) + + self.in_connections.add(ws) + logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}") + + try: + async for msg in ws: + if msg.type == WSMsgType.TEXT: + data = msg.data + logger.info(f"IN → OUT: {data}") + await self._forward_to_out(data) + elif msg.type == WSMsgType.BINARY: + data = msg.data + logger.info(f"IN → OUT: {len(data)} bytes (binary)") + await self._forward_to_out(data, binary=True) + elif msg.type == WSMsgType.ERROR: + logger.error(f"WebSocket error on 'in' connection: {ws.exception()}") + break + else: + break + except Exception as e: + logger.error(f"Error in 'in' connection handler: {e}") + finally: + logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}") + + return ws + + async def handle_out_connection(self, request): + """Handle outgoing connections on /out endpoint""" + ws = web.WebSocketResponse() + await ws.prepare(request) + + self.out_connections.add(ws) + logger.info(f"New 'out' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}") + + try: + async for msg in ws: + if msg.type == WSMsgType.TEXT: + data = msg.data + logger.info(f"OUT → IN: {data}") + await self._forward_to_in(data) + elif msg.type == WSMsgType.BINARY: + data = msg.data + logger.info(f"OUT → IN: {len(data)} bytes (binary)") + await self._forward_to_in(data, binary=True) + elif msg.type == WSMsgType.ERROR: + logger.error(f"WebSocket error on 'out' connection: {ws.exception()}") + break + else: + break + except Exception as e: + logger.error(f"Error in 'out' connection handler: {e}") + finally: + logger.info(f"'out' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}") + + return ws + + async def _forward_to_out(self, data, binary=False): + """Forward message from 'in' to all 'out' connections""" + if not self.out_connections: + logger.warning("No 'out' connections available to forward message") + return + + closed_connections = [] + for ws in list(self.out_connections): + try: + if ws.closed: + closed_connections.append(ws) + continue + + if binary: + await ws.send_bytes(data) + else: + await ws.send_str(data) + except Exception as e: + logger.error(f"Error forwarding to 'out' connection: {e}") + closed_connections.append(ws) + + # Clean up closed connections + for ws in closed_connections: + if ws in self.out_connections: + self.out_connections.discard(ws) + + async def _forward_to_in(self, data, binary=False): + """Forward message from 'out' to all 'in' connections""" + if not self.in_connections: + logger.warning("No 'in' connections available to forward message") + return + + closed_connections = [] + for ws in list(self.in_connections): + try: + if ws.closed: + closed_connections.append(ws) + continue + + if binary: + await ws.send_bytes(data) + else: + await ws.send_str(data) + except Exception as e: + logger.error(f"Error forwarding to 'in' connection: {e}") + closed_connections.append(ws) + + # Clean up closed connections + for ws in closed_connections: + if ws in self.in_connections: + self.in_connections.discard(ws) + +async def create_app(relay): + """Create the web application with routes""" + app = web.Application() + + # Add routes + app.router.add_get('/in', relay.handle_in_connection) + app.router.add_get('/out', relay.handle_out_connection) + + # Add a simple status endpoint + async def status(request): + status_info = { + 'in_connections': len(relay.in_connections), + 'out_connections': len(relay.out_connections), + 'status': 'running' + } + return web.json_response(status_info) + + app.router.add_get('/status', status) + app.router.add_get('/', status) # Root also shows status + + return app + +def main(): + parser = argparse.ArgumentParser( + description="WebSocket Relay Test Harness" + ) + parser.add_argument( + '--host', + default='localhost', + help='Host to bind to (default: localhost)' + ) + parser.add_argument( + '--port', + type=int, + default=8080, + help='Port to bind to (default: 8080)' + ) + parser.add_argument( + '--verbose', '-v', + action='store_true', + help='Enable verbose logging' + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + relay = WebSocketRelay() + + print(f"Starting WebSocket Relay on {args.host}:{args.port}") + print(f" 'in' endpoint: ws://{args.host}:{args.port}/in") + print(f" 'out' endpoint: ws://{args.host}:{args.port}/out") + print(f" Status: http://{args.host}:{args.port}/status") + print() + print("Usage:") + print(f" Test client connects to: ws://{args.host}:{args.port}/in") + print(f" Reverse gateway connects to: ws://{args.host}:{args.port}/out") + + web.run_app(create_app(relay), host=args.host, port=args.port) + +if __name__ == "__main__": + main() \ No newline at end of file