diff --git a/containers/Containerfile.unstructured b/containers/Containerfile.unstructured index 22ee05b2..7284901e 100644 --- a/containers/Containerfile.unstructured +++ b/containers/Containerfile.unstructured @@ -7,7 +7,7 @@ FROM docker.io/fedora:42 AS base ENV PIP_BREAK_SYSTEM_PACKAGES=1 -RUN dnf install -y python3.13 && \ +RUN dnf install -y python3.13 libxcb mesa-libGL && \ alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ python -m ensurepip --upgrade && \ pip3 install --no-cache-dir build wheel aiohttp && \ @@ -38,6 +38,11 @@ RUN ls /root/wheels FROM base +# Pre-install CPU-only PyTorch so that unstructured[pdf]'s torch +# dependency is satisfied without pulling in CUDA (~190MB vs ~2GB+) +RUN pip3 install --no-cache-dir torch==2.11.0+cpu \ + --index-url https://download.pytorch.org/whl/cpu + COPY --from=build /root/wheels /root/wheels RUN \ diff --git a/docs/contributor-licence-agreement.md b/docs/contributor-licence-agreement.md new file mode 100644 index 00000000..48314516 --- /dev/null +++ b/docs/contributor-licence-agreement.md @@ -0,0 +1,19 @@ +# Contributor Licence Agreement (CLA) + +We ask every contributor to sign a Contributor Licence Agreement before +we can merge a pull request. The CLA does **not** transfer copyright — +you keep full ownership of your work. It simply grants the TrustGraph +project a perpetual, royalty-free licence to distribute your +contribution under the project's +[Apache 2.0 licence](https://www.apache.org/licenses/LICENSE-2.0), and +confirms that you have the right to make the contribution. This protects +both the project and its users by ensuring every contribution has a +clear legal footing. + +When you open a pull request, the CLA bot will post a comment asking you +to review and sign the appropriate agreement — it only takes a moment +and you only need to do it once across all TrustGraph repositories. + +- Contributing as an **individual**? 
Sign the [Individual CLA](https://github.com/trustgraph-ai/contributor-license-agreement/blob/main/Fiduciary-Contributor-License-Agreement.md) +- Contributing on behalf of a **company or organisation**? Sign the [Entity CLA](https://github.com/trustgraph-ai/contributor-license-agreement/blob/main/Entity-Fiduciary-Contributor-License-Agreement.md) + diff --git a/tests/integration/test_text_completion_integration.py b/tests/integration/test_text_completion_integration.py index 08e2a995..26fa0ab3 100644 --- a/tests/integration/test_text_completion_integration.py +++ b/tests/integration/test_text_completion_integration.py @@ -93,7 +93,7 @@ class TestTextCompletionIntegration: assert call_args.kwargs['model'] == "gpt-3.5-turbo" assert call_args.kwargs['temperature'] == 0.7 - assert call_args.kwargs['max_tokens'] == 1024 + assert call_args.kwargs['max_completion_tokens'] == 1024 assert len(call_args.kwargs['messages']) == 1 assert call_args.kwargs['messages'][0]['role'] == "user" assert "You are a helpful assistant." 
in call_args.kwargs['messages'][0]['content'][0]['text'] @@ -134,7 +134,7 @@ class TestTextCompletionIntegration: call_args = mock_openai_client.chat.completions.create.call_args assert call_args.kwargs['model'] == config['model'] assert call_args.kwargs['temperature'] == config['temperature'] - assert call_args.kwargs['max_tokens'] == config['max_output'] + assert call_args.kwargs['max_completion_tokens'] == config['max_output'] # Reset mock for next iteration mock_openai_client.reset_mock() @@ -286,7 +286,7 @@ class TestTextCompletionIntegration: # were removed in #561 as unnecessary parameters assert 'model' in call_args.kwargs assert 'temperature' in call_args.kwargs - assert 'max_tokens' in call_args.kwargs + assert 'max_completion_tokens' in call_args.kwargs # Verify result structure assert hasattr(result, 'text') @@ -362,7 +362,7 @@ class TestTextCompletionIntegration: call_args = mock_openai_client.chat.completions.create.call_args assert call_args.kwargs['model'] == "gpt-4" assert call_args.kwargs['temperature'] == 0.8 - assert call_args.kwargs['max_tokens'] == 2048 + assert call_args.kwargs['max_completion_tokens'] == 2048 # Note: top_p, frequency_penalty, and presence_penalty # were removed in #561 as unnecessary parameters diff --git a/tests/integration/test_text_completion_streaming_integration.py b/tests/integration/test_text_completion_streaming_integration.py index a70afb4c..6968affa 100644 --- a/tests/integration/test_text_completion_streaming_integration.py +++ b/tests/integration/test_text_completion_streaming_integration.py @@ -201,7 +201,7 @@ class TestTextCompletionStreaming: call_args = mock_streaming_openai_client.chat.completions.create.call_args assert call_args.kwargs['model'] == "gpt-4" assert call_args.kwargs['temperature'] == 0.5 - assert call_args.kwargs['max_tokens'] == 2048 + assert call_args.kwargs['max_completion_tokens'] == 2048 assert call_args.kwargs['stream'] is True # Verify chunks have correct model diff --git 
a/tests/unit/test_gateway/test_dispatch_mux.py b/tests/unit/test_gateway/test_dispatch_mux.py index b623a1b6..a0bc9460 100644 --- a/tests/unit/test_gateway/test_dispatch_mux.py +++ b/tests/unit/test_gateway/test_dispatch_mux.py @@ -121,7 +121,11 @@ class TestMux: # Based on the code, it seems to catch exceptions await mux.receive(mock_msg) - mock_ws.send_json.assert_called_once_with({"error": "Bad message"}) + mock_ws.send_json.assert_called_once_with({ + "error": {"message": "Bad message", "type": "error"}, + "complete": True, + "id": "test-id-123", + }) @pytest.mark.asyncio async def test_mux_receive_message_without_id(self): @@ -129,23 +133,26 @@ class TestMux: mock_dispatcher_manager = MagicMock() mock_ws = AsyncMock() mock_running = MagicMock() - + mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, running=mock_running ) - + # Mock message without id field mock_msg = MagicMock() mock_msg.json.return_value = { "request": {"type": "test"} } - + # receive method should handle the RuntimeError internally await mux.receive(mock_msg) - - mock_ws.send_json.assert_called_once_with({"error": "Bad message"}) + + mock_ws.send_json.assert_called_once_with({ + "error": {"message": "Bad message", "type": "error"}, + "complete": True, + }) @pytest.mark.asyncio async def test_mux_receive_invalid_json(self): @@ -153,19 +160,22 @@ class TestMux: mock_dispatcher_manager = MagicMock() mock_ws = AsyncMock() mock_running = MagicMock() - + mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, running=mock_running ) - + # Mock message with invalid JSON mock_msg = MagicMock() mock_msg.json.side_effect = ValueError("Invalid JSON") - + # receive method should handle the ValueError internally await mux.receive(mock_msg) - + mock_msg.json.assert_called_once() - mock_ws.send_json.assert_called_once_with({"error": "Invalid JSON"}) \ No newline at end of file + mock_ws.send_json.assert_called_once_with({ + "error": {"message": "Invalid JSON", "type": "error"}, + 
"complete": True, + }) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_azure_openai_processor.py b/tests/unit/test_text_completion/test_azure_openai_processor.py index 02c85d54..b00531ad 100644 --- a/tests/unit/test_text_completion/test_azure_openai_processor.py +++ b/tests/unit/test_text_completion/test_azure_openai_processor.py @@ -108,7 +108,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase): }] }], temperature=0.0, - max_tokens=4192, + max_completion_tokens=4192, top_p=1 ) @@ -399,7 +399,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase): # Verify other parameters assert call_args[1]['model'] == 'gpt-4' assert call_args[1]['temperature'] == 0.5 - assert call_args[1]['max_tokens'] == 1024 + assert call_args[1]['max_completion_tokens'] == 1024 assert call_args[1]['top_p'] == 1 @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') diff --git a/tests/unit/test_text_completion/test_openai_processor.py b/tests/unit/test_text_completion/test_openai_processor.py index 352af062..514d35da 100644 --- a/tests/unit/test_text_completion/test_openai_processor.py +++ b/tests/unit/test_text_completion/test_openai_processor.py @@ -102,7 +102,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase): }] }], temperature=0.0, - max_tokens=4096 + max_completion_tokens=4096 ) @patch('trustgraph.model.text_completion.openai.llm.OpenAI') @@ -380,7 +380,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase): # Verify other parameters assert call_args[1]['model'] == 'gpt-3.5-turbo' assert call_args[1]['temperature'] == 0.5 - assert call_args[1]['max_tokens'] == 1024 + assert call_args[1]['max_completion_tokens'] == 1024 @patch('trustgraph.model.text_completion.openai.llm.OpenAI') diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 3279609b..7a239b07 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ 
b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -1,5 +1,6 @@ import json +import asyncio import websockets from typing import Optional, Dict, Any, AsyncIterator, Union @@ -8,13 +9,29 @@ from . exceptions import ProtocolException, ApplicationException class AsyncSocketClient: - """Asynchronous WebSocket client""" + """Asynchronous WebSocket client with persistent connection. + + Maintains a single websocket connection and multiplexes requests + by ID, routing responses via a background reader task. + + Use as an async context manager for proper lifecycle management: + + async with AsyncSocketClient(url, timeout, token) as client: + result = await client._send_request(...) + + Or call connect()/aclose() manually. + """ def __init__(self, url: str, timeout: int, token: Optional[str]): self.url = self._convert_to_ws_url(url) self.timeout = timeout self.token = token self._request_counter = 0 + self._socket = None + self._connect_cm = None + self._reader_task = None + self._pending = {} # request_id -> asyncio.Queue + self._connected = False def _convert_to_ws_url(self, url: str) -> str: """Convert HTTP URL to WebSocket URL""" @@ -25,82 +42,123 @@ class AsyncSocketClient: elif url.startswith("ws://") or url.startswith("wss://"): return url else: - # Assume ws:// return f"ws://{url}" + def _build_ws_url(self): + ws_url = f"{self.url.rstrip('/')}/api/v1/socket" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + return ws_url + + async def connect(self): + """Establish the persistent websocket connection.""" + if self._connected: + return + + ws_url = self._build_ws_url() + self._connect_cm = websockets.connect( + ws_url, ping_interval=20, ping_timeout=self.timeout + ) + self._socket = await self._connect_cm.__aenter__() + self._connected = True + self._reader_task = asyncio.create_task(self._reader()) + + async def __aenter__(self): + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.aclose() + + 
async def _ensure_connected(self): + """Lazily connect if not already connected.""" + if not self._connected: + await self.connect() + + async def _reader(self): + """Background task to read responses and route by request ID.""" + try: + async for raw_message in self._socket: + response = json.loads(raw_message) + request_id = response.get("id") + + if request_id and request_id in self._pending: + await self._pending[request_id].put(response) + # Ignore messages for unknown request IDs + + except websockets.exceptions.ConnectionClosed: + pass + except Exception as e: + # Signal error to all pending requests + for queue in self._pending.values(): + try: + await queue.put({"error": str(e)}) + except: + pass + finally: + self._connected = False + + def _next_request_id(self): + self._request_counter += 1 + return f"req-{self._request_counter}" + def flow(self, flow_id: str): """Get async flow instance for WebSocket operations""" return AsyncSocketFlowInstance(self, flow_id) async def _send_request(self, service: str, flow: Optional[str], request: Dict[str, Any]): - """Async WebSocket request implementation (non-streaming)""" - # Generate unique request ID - self._request_counter += 1 - request_id = f"req-{self._request_counter}" + """Send a request and wait for a single response.""" + await self._ensure_connected() - # Build WebSocket URL with optional token - ws_url = f"{self.url}/api/v1/socket" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + request_id = self._next_request_id() + queue = asyncio.Queue() + self._pending[request_id] = queue - # Build request message - message = { - "id": request_id, - "service": service, - "request": request - } - if flow: - message["flow"] = flow + try: + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow - # Connect and send request - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: - await 
websocket.send(json.dumps(message)) + await self._socket.send(json.dumps(message)) - # Wait for single response - raw_message = await websocket.recv() - response = json.loads(raw_message) - - if response.get("id") != request_id: - raise ProtocolException(f"Response ID mismatch") + response = await queue.get() if "error" in response: raise ApplicationException(response["error"]) if "response" not in response: - raise ProtocolException(f"Missing response in message") + raise ProtocolException("Missing response in message") return response["response"] + finally: + self._pending.pop(request_id, None) + async def _send_request_streaming(self, service: str, flow: Optional[str], request: Dict[str, Any]): - """Async WebSocket request implementation (streaming)""" - # Generate unique request ID - self._request_counter += 1 - request_id = f"req-{self._request_counter}" + """Send a request and yield streaming response chunks.""" + await self._ensure_connected() - # Build WebSocket URL with optional token - ws_url = f"{self.url}/api/v1/socket" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + request_id = self._next_request_id() + queue = asyncio.Queue() + self._pending[request_id] = queue - # Build request message - message = { - "id": request_id, - "service": service, - "request": request - } - if flow: - message["flow"] = flow + try: + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow - # Connect and send request - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: - await websocket.send(json.dumps(message)) + await self._socket.send(json.dumps(message)) - # Yield chunks as they arrive - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != request_id: - continue # Ignore messages for other requests + while True: + response = await queue.get() if "error" in response: raise 
ApplicationException(response["error"]) @@ -108,18 +166,16 @@ class AsyncSocketClient: if "response" in response: resp = response["response"] - # Parse different chunk types chunk = self._parse_chunk(resp) - if chunk is not None: # Skip provenance messages in streaming + if chunk is not None: yield chunk - # Check if this is the final message - # end_of_session indicates entire session is complete (including provenance) - # end_of_dialog is for agent dialogs - # complete is from the gateway envelope if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"): break + finally: + self._pending.pop(request_id, None) + def _parse_chunk(self, resp: Dict[str, Any]): """Parse response chunk into appropriate type. Returns None for non-content messages.""" chunk_type = resp.get("chunk_type") @@ -127,7 +183,6 @@ class AsyncSocketClient: # Handle new GraphRAG message format with message_type if message_type == "provenance": - # Provenance messages are not yielded to user - they're metadata return None if chunk_type == "thought": @@ -147,25 +202,41 @@ class AsyncSocketClient: end_of_dialog=resp.get("end_of_dialog", False) ) elif chunk_type == "action": - # Agent action chunks - treat as thoughts for display purposes return AgentThought( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False) ) else: - # RAG-style chunk (or generic chunk with message_type="chunk") - # Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field content = resp.get("response", resp.get("chunk", resp.get("text", ""))) return RAGChunk( content=content, end_of_stream=resp.get("end_of_stream", False), - error=None # Errors are always thrown, never stored + error=None ) async def aclose(self): - """Close WebSocket connection""" - # Cleanup handled by context manager - pass + """Close the persistent WebSocket connection cleanly.""" + # Wait for reader to finish (socket close will cause it to exit) + if self._reader_task: + 
self._reader_task.cancel() + try: + await self._reader_task + except asyncio.CancelledError: + pass + self._reader_task = None + + # Exit the websockets context manager — this cleanly shuts down + # the connection and its keepalive task + if self._connect_cm: + try: + await self._connect_cm.__aexit__(None, None, None) + except Exception: + pass + self._connect_cm = None + + self._socket = None + self._connected = False + self._pending.clear() class AsyncSocketFlowInstance: @@ -292,7 +363,6 @@ class AsyncSocketFlowInstance: async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs): """Query graph embeddings for semantic search""" - # First convert text to embedding vector emb_result = await self.embeddings(texts=[text]) vector = emb_result.get("vectors", [[]])[0] @@ -362,7 +432,6 @@ class AsyncSocketFlowInstance: limit: int = 10, **kwargs ): """Query row embeddings for semantic search on structured data""" - # First convert text to embedding vector emb_result = await self.embeddings(texts=[text]) vector = emb_result.get("vectors", [[]])[0] diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 91db8b69..40769fa0 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -3,6 +3,9 @@ TrustGraph Synchronous WebSocket Client This module provides synchronous WebSocket-based access to TrustGraph services with streaming support for real-time responses from agents, RAG queries, and text completions. + +Uses a persistent WebSocket connection with a background reader task that +multiplexes requests by ID. """ import json @@ -74,43 +77,27 @@ def build_term(value: Any, term_type: Optional[str] = None, class SocketClient: """ - Synchronous WebSocket client for streaming operations. + Synchronous WebSocket client with persistent connection. 
- Provides a synchronous interface to WebSocket-based TrustGraph services, - wrapping async websockets library with synchronous generators for ease of use. - Supports streaming responses from agents, RAG queries, and text completions. - - Note: This is a synchronous wrapper around async WebSocket operations. For - true async support, use AsyncSocketClient instead. + Maintains a single websocket connection and multiplexes requests + by ID via a background reader task. Provides synchronous generators + for streaming responses. """ def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: - """ - Initialize synchronous WebSocket client. - - Args: - url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS) - timeout: WebSocket timeout in seconds - token: Optional bearer token for authentication - """ self.url: str = self._convert_to_ws_url(url) self.timeout: int = timeout self.token: Optional[str] = token - self._connection: Optional[Any] = None self._request_counter: int = 0 self._lock: Lock = Lock() self._loop: Optional[asyncio.AbstractEventLoop] = None + self._socket = None + self._connect_cm = None + self._reader_task = None + self._pending: Dict[str, asyncio.Queue] = {} + self._connected: bool = False def _convert_to_ws_url(self, url: str) -> str: - """ - Convert HTTP URL to WebSocket URL. 
- - Args: - url: HTTP/HTTPS or WS/WSS URL - - Returns: - str: WebSocket URL (ws:// or wss://) - """ if url.startswith("http://"): return url.replace("http://", "ws://", 1) elif url.startswith("https://"): @@ -118,29 +105,68 @@ class SocketClient: elif url.startswith("ws://") or url.startswith("wss://"): return url else: - # Assume ws:// return f"ws://{url}" + def _build_ws_url(self): + ws_url = f"{self.url.rstrip('/')}/api/v1/socket" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + return ws_url + + def _get_loop(self): + """Get or create the event loop, reusing across calls.""" + if self._loop is None or self._loop.is_closed(): + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self._loop = loop + return self._loop + + async def _ensure_connected(self): + """Lazily establish the persistent websocket connection.""" + if self._connected: + return + + ws_url = self._build_ws_url() + self._connect_cm = websockets.connect( + ws_url, ping_interval=20, ping_timeout=self.timeout + ) + self._socket = await self._connect_cm.__aenter__() + self._connected = True + self._reader_task = asyncio.create_task(self._reader()) + + async def _reader(self): + """Background task to read responses and route by request ID.""" + try: + async for raw_message in self._socket: + response = json.loads(raw_message) + request_id = response.get("id") + + if request_id and request_id in self._pending: + await self._pending[request_id].put(response) + + except websockets.exceptions.ConnectionClosed: + pass + except Exception as e: + for queue in self._pending.values(): + try: + await queue.put({"error": str(e)}) + except: + pass + finally: + self._connected = False + + def _next_request_id(self): + with self._lock: + self._request_counter += 1 + return f"req-{self._request_counter}" + def flow(self, flow_id: str) -> 
"SocketFlowInstance": - """ - Get a flow instance for WebSocket streaming operations. - - Args: - flow_id: Flow identifier - - Returns: - SocketFlowInstance: Flow instance with streaming methods - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Stream agent responses - for chunk in flow.agent(question="Hello", user="trustgraph", streaming=True): - print(chunk.content, end='', flush=True) - ``` - """ return SocketFlowInstance(self, flow_id) def _send_request_sync( @@ -152,34 +178,14 @@ class SocketClient: streaming_raw: bool = False, include_provenance: bool = False ) -> Union[Dict[str, Any], Iterator[StreamingChunk], Iterator[Dict[str, Any]]]: - """Synchronous wrapper around async WebSocket communication. - - Args: - service: Service name - flow: Flow ID (optional) - request: Request payload - streaming: Use parsed streaming (for agent/RAG chunk types) - streaming_raw: Use raw streaming (for data batches like triples) - """ - # Create event loop if needed - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - # If loop is running (e.g., in Jupyter), create new loop - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + """Synchronous wrapper around async WebSocket communication.""" + loop = self._get_loop() if streaming_raw: - # Raw streaming for data batches (triples, rows, etc.) 
return self._streaming_generator_raw(service, flow, request, loop) elif streaming: - # Parsed streaming for agent/RAG chunk types return self._streaming_generator(service, flow, request, loop, include_provenance) else: - # Non-streaming single response return loop.run_until_complete(self._send_request_async(service, flow, request)) def _streaming_generator( @@ -190,7 +196,7 @@ class SocketClient: loop: asyncio.AbstractEventLoop, include_provenance: bool = False ) -> Iterator[StreamingChunk]: - """Generator that yields streaming chunks (for agent/RAG responses)""" + """Generator that yields streaming chunks.""" async_gen = self._send_request_async_streaming(service, flow, request, include_provenance) try: @@ -201,7 +207,6 @@ class SocketClient: except StopAsyncIteration: break finally: - # Clean up async generator try: loop.run_until_complete(async_gen.aclose()) except: @@ -214,7 +219,7 @@ class SocketClient: request: Dict[str, Any], loop: asyncio.AbstractEventLoop ) -> Iterator[Dict[str, Any]]: - """Generator that yields raw response dicts (for data streaming like triples)""" + """Generator that yields raw response dicts.""" async_gen = self._send_request_async_streaming_raw(service, flow, request) try: @@ -236,35 +241,26 @@ class SocketClient: flow: Optional[str], request: Dict[str, Any] ) -> Iterator[Dict[str, Any]]: - """Async streaming that yields raw response dicts without parsing. + """Async streaming that yields raw response dicts.""" + await self._ensure_connected() - Used for data streaming (triples, rows, etc.) where responses are - just batches of data, not agent/RAG chunk types. 
- """ - with self._lock: - self._request_counter += 1 - request_id = f"req-{self._request_counter}" + request_id = self._next_request_id() + queue = asyncio.Queue() + self._pending[request_id] = queue - ws_url = f"{self.url}/api/v1/socket" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + try: + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow - message = { - "id": request_id, - "service": service, - "request": request - } - if flow: - message["flow"] = flow + await self._socket.send(json.dumps(message)) - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: - await websocket.send(json.dumps(message)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != request_id: - continue + while True: + response = await queue.get() if "error" in response: raise_from_error_dict(response["error"]) @@ -275,51 +271,46 @@ class SocketClient: if response.get("complete"): break + finally: + self._pending.pop(request_id, None) + async def _send_request_async( self, service: str, flow: Optional[str], request: Dict[str, Any] ) -> Dict[str, Any]: - """Async implementation of WebSocket request (non-streaming)""" - # Generate unique request ID - with self._lock: - self._request_counter += 1 - request_id = f"req-{self._request_counter}" + """Async non-streaming request over persistent connection.""" + await self._ensure_connected() - # Build WebSocket URL with optional token - ws_url = f"{self.url}/api/v1/socket" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + request_id = self._next_request_id() + queue = asyncio.Queue() + self._pending[request_id] = queue - # Build request message - message = { - "id": request_id, - "service": service, - "request": request - } - if flow: - message["flow"] = flow + try: + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + 
message["flow"] = flow - # Connect and send request - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: - await websocket.send(json.dumps(message)) + await self._socket.send(json.dumps(message)) - # Wait for single response - raw_message = await websocket.recv() - response = json.loads(raw_message) - - if response.get("id") != request_id: - raise ProtocolException(f"Response ID mismatch") + response = await queue.get() if "error" in response: raise_from_error_dict(response["error"]) if "response" not in response: - raise ProtocolException(f"Missing response in message") + raise ProtocolException("Missing response in message") return response["response"] + finally: + self._pending.pop(request_id, None) + async def _send_request_async_streaming( self, service: str, @@ -327,36 +318,26 @@ class SocketClient: request: Dict[str, Any], include_provenance: bool = False ) -> Iterator[StreamingChunk]: - """Async implementation of WebSocket request (streaming)""" - # Generate unique request ID - with self._lock: - self._request_counter += 1 - request_id = f"req-{self._request_counter}" + """Async streaming request over persistent connection.""" + await self._ensure_connected() - # Build WebSocket URL with optional token - ws_url = f"{self.url}/api/v1/socket" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + request_id = self._next_request_id() + queue = asyncio.Queue() + self._pending[request_id] = queue - # Build request message - message = { - "id": request_id, - "service": service, - "request": request - } - if flow: - message["flow"] = flow + try: + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow - # Connect and send request - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: - await websocket.send(json.dumps(message)) + await self._socket.send(json.dumps(message)) - # Yield chunks as they arrive - 
async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != request_id: - continue # Ignore messages for other requests + while True: + response = await queue.get() if "error" in response: raise_from_error_dict(response["error"]) @@ -364,22 +345,19 @@ class SocketClient: if "response" in response: resp = response["response"] - # Check for errors in response chunks if "error" in resp: raise_from_error_dict(resp["error"]) - # Parse different chunk types chunk = self._parse_chunk(resp, include_provenance=include_provenance) - if chunk is not None: # Skip provenance messages unless include_provenance + if chunk is not None: yield chunk - # Check if this is the final message - # end_of_session indicates entire session is complete (including provenance) - # end_of_dialog is for agent dialogs - # complete is from the gateway envelope if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"): break + finally: + self._pending.pop(request_id, None) + def _parse_chunk(self, resp: Dict[str, Any], include_provenance: bool = False) -> Optional[StreamingChunk]: """Parse response chunk into appropriate type. 
Returns None for non-content messages.""" chunk_type = resp.get("chunk_type") @@ -388,12 +366,10 @@ class SocketClient: # Handle GraphRAG/DocRAG message format with message_type if message_type == "explain": if include_provenance: - # Return provenance event for explainability return ProvenanceEvent( explain_id=resp.get("explain_id", ""), explain_graph=resp.get("explain_graph", "") ) - # Provenance messages are not yielded to user - they're metadata return None # Handle Agent message format with chunk_type="explain" @@ -422,7 +398,6 @@ class SocketClient: end_of_dialog=resp.get("end_of_dialog", False) ) elif chunk_type == "action": - # Agent action chunks - treat as thoughts for display purposes return AgentThought( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False) @@ -445,23 +420,43 @@ class SocketClient: end_of_dialog=resp.get("end_of_dialog", False) ) else: - # RAG-style chunk (or generic chunk with message_type="chunk") - # Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field content = resp.get("response", resp.get("chunk", resp.get("text", ""))) return RAGChunk( content=content, end_of_stream=resp.get("end_of_stream", False), - error=None # Errors are always thrown, never stored + error=None ) def close(self) -> None: - """ - Close WebSocket connections. + """Close the persistent WebSocket connection.""" + if self._loop and not self._loop.is_closed(): + try: + self._loop.run_until_complete(self._close_async()) + except: + pass - Note: Cleanup is handled automatically by context managers in async code. 
- """ - # Cleanup handled by context manager in async code - pass + async def _close_async(self): + # Cancel reader task + if self._reader_task: + self._reader_task.cancel() + try: + await self._reader_task + except asyncio.CancelledError: + pass + self._reader_task = None + + # Exit the websockets context manager — this cleanly shuts down + # the connection and its keepalive task + if self._connect_cm: + try: + await self._connect_cm.__aexit__(None, None, None) + except Exception: + pass + self._connect_cm = None + + self._socket = None + self._connected = False + self._pending.clear() class SocketFlowInstance: @@ -469,18 +464,10 @@ class SocketFlowInstance: Synchronous WebSocket flow instance for streaming operations. Provides the same interface as REST FlowInstance but with WebSocket-based - streaming support for real-time responses. All methods support an optional - `streaming` parameter to enable incremental result delivery. + streaming support for real-time responses. """ def __init__(self, client: SocketClient, flow_id: str) -> None: - """ - Initialize socket flow instance. - - Args: - client: Parent SocketClient - flow_id: Flow identifier - """ self.client: SocketClient = client self.flow_id: str = flow_id @@ -494,44 +481,7 @@ class SocketFlowInstance: streaming: bool = False, **kwargs: Any ) -> Union[Dict[str, Any], Iterator[StreamingChunk]]: - """ - Execute an agent operation with streaming support. - - Agents can perform multi-step reasoning with tool use. This method always - returns streaming chunks (thoughts, observations, answers) even when - streaming=False, to show the agent's reasoning process. 
- - Args: - question: User question or instruction - user: User identifier - state: Optional state dictionary for stateful conversations - group: Optional group identifier for multi-user contexts - history: Optional conversation history as list of message dicts - streaming: Enable streaming mode (default: False) - **kwargs: Additional parameters passed to the agent service - - Returns: - Iterator[StreamingChunk]: Stream of agent thoughts, observations, and answers - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Stream agent reasoning - for chunk in flow.agent( - question="What is quantum computing?", - user="trustgraph", - streaming=True - ): - if isinstance(chunk, AgentThought): - print(f"[Thinking] {chunk.content}") - elif isinstance(chunk, AgentObservation): - print(f"[Observation] {chunk.content}") - elif isinstance(chunk, AgentAnswer): - print(f"[Answer] {chunk.content}") - ``` - """ + """Execute an agent operation with streaming support.""" request = { "question": question, "user": user, @@ -545,8 +495,6 @@ class SocketFlowInstance: request["history"] = history request.update(kwargs) - # Agents always use multipart messaging (multiple complete messages) - # regardless of streaming flag, so always use the streaming code path return self.client._send_request_sync("agent", self.flow_id, request, streaming=True) def agent_explain( @@ -559,70 +507,12 @@ class SocketFlowInstance: history: Optional[List[Dict[str, Any]]] = None, **kwargs: Any ) -> Iterator[Union[StreamingChunk, ProvenanceEvent]]: - """ - Execute an agent operation with explainability support. - - Streams both content chunks (AgentThought, AgentObservation, AgentAnswer) - and provenance events (ProvenanceEvent). Provenance events contain URIs - that can be fetched using ExplainabilityClient to get detailed information - about the agent's reasoning process. 
- - Agent trace consists of: - - Session: The initial question and session metadata - - Iterations: Each thought/action/observation cycle - - Conclusion: The final answer - - Args: - question: User question or instruction - user: User identifier - collection: Collection identifier for provenance storage - state: Optional state dictionary for stateful conversations - group: Optional group identifier for multi-user contexts - history: Optional conversation history as list of message dicts - **kwargs: Additional parameters passed to the agent service - - Yields: - Union[StreamingChunk, ProvenanceEvent]: Agent chunks and provenance events - - Example: - ```python - from trustgraph.api import Api, ExplainabilityClient, ProvenanceEvent - from trustgraph.api import AgentThought, AgentObservation, AgentAnswer - - socket = api.socket() - flow = socket.flow("default") - explain_client = ExplainabilityClient(flow) - - provenance_ids = [] - for item in flow.agent_explain( - question="What is the capital of France?", - user="trustgraph", - collection="default" - ): - if isinstance(item, AgentThought): - print(f"[Thought] {item.content}") - elif isinstance(item, AgentObservation): - print(f"[Observation] {item.content}") - elif isinstance(item, AgentAnswer): - print(f"[Answer] {item.content}") - elif isinstance(item, ProvenanceEvent): - provenance_ids.append(item.explain_id) - - # Fetch session trace after completion - if provenance_ids: - trace = explain_client.fetch_agent_trace( - provenance_ids[0], # Session URI is first - graph="urn:graph:retrieval", - user="trustgraph", - collection="default" - ) - ``` - """ + """Execute an agent operation with explainability support.""" request = { "question": question, "user": user, "collection": collection, - "streaming": True # Always streaming for explain + "streaming": True } if state is not None: request["state"] = state @@ -632,47 +522,13 @@ class SocketFlowInstance: request["history"] = history request.update(kwargs) - # Use 
streaming with provenance enabled return self.client._send_request_sync( "agent", self.flow_id, request, streaming=True, include_provenance=True ) def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]: - """ - Execute text completion with optional streaming. - - Args: - system: System prompt defining the assistant's behavior - prompt: User prompt/question - streaming: Enable streaming mode (default: False) - **kwargs: Additional parameters passed to the service - - Returns: - Union[str, Iterator[str]]: Complete response or stream of text chunks - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Non-streaming - response = flow.text_completion( - system="You are helpful", - prompt="Explain quantum computing", - streaming=False - ) - print(response) - - # Streaming - for chunk in flow.text_completion( - system="You are helpful", - prompt="Explain quantum computing", - streaming=True - ): - print(chunk, end='', flush=True) - ``` - """ + """Execute text completion with optional streaming.""" request = { "system": system, "prompt": prompt, @@ -683,13 +539,11 @@ class SocketFlowInstance: result = self.client._send_request_sync("text-completion", self.flow_id, request, streaming) if streaming: - # For text completion, return generator that yields content return self._text_completion_generator(result) else: return result.get("response", "") def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: - """Generator for text completion streaming""" for chunk in result: if hasattr(chunk, 'content'): yield chunk.content @@ -708,43 +562,7 @@ class SocketFlowInstance: streaming: bool = False, **kwargs: Any ) -> Union[str, Iterator[str]]: - """ - Execute graph-based RAG query with optional streaming. - - Uses knowledge graph structure to find relevant context, then generates - a response using an LLM. Streaming mode delivers results incrementally. 
- - Args: - query: Natural language query - user: User/keyspace identifier - collection: Collection identifier - entity_limit: Maximum entities to retrieve (default: 50) - triple_limit: Maximum triples per entity (default: 30) - max_subgraph_size: Maximum total triples in subgraph (default: 1000) - max_path_length: Maximum traversal depth (default: 2) - edge_score_limit: Max edges for semantic pre-filter (default: 50) - edge_limit: Max edges after LLM scoring (default: 25) - streaming: Enable streaming mode (default: False) - **kwargs: Additional parameters passed to the service - - Returns: - Union[str, Iterator[str]]: Complete response or stream of text chunks - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Streaming graph RAG - for chunk in flow.graph_rag( - query="Tell me about Marie Curie", - user="trustgraph", - collection="scientists", - streaming=True - ): - print(chunk, end='', flush=True) - ``` - """ + """Execute graph-based RAG query with optional streaming.""" request = { "query": query, "user": user, @@ -779,61 +597,7 @@ class SocketFlowInstance: edge_limit: int = 25, **kwargs: Any ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: - """ - Execute graph-based RAG query with explainability support. - - Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent). - Provenance events contain URIs that can be fetched using ExplainabilityClient - to get detailed information about how the response was generated. 
- - Args: - query: Natural language query - user: User/keyspace identifier - collection: Collection identifier - entity_limit: Maximum entities to retrieve (default: 50) - triple_limit: Maximum triples per entity (default: 30) - max_subgraph_size: Maximum total triples in subgraph (default: 1000) - max_path_length: Maximum traversal depth (default: 2) - edge_score_limit: Max edges for semantic pre-filter (default: 50) - edge_limit: Max edges after LLM scoring (default: 25) - **kwargs: Additional parameters passed to the service - - Yields: - Union[RAGChunk, ProvenanceEvent]: Content chunks and provenance events - - Example: - ```python - from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent - - socket = api.socket() - flow = socket.flow("default") - explain_client = ExplainabilityClient(flow) - - provenance_ids = [] - response_text = "" - - for item in flow.graph_rag_explain( - query="Tell me about Marie Curie", - user="trustgraph", - collection="scientists" - ): - if isinstance(item, RAGChunk): - response_text += item.content - print(item.content, end='', flush=True) - elif isinstance(item, ProvenanceEvent): - provenance_ids.append(item.provenance_id) - - # Fetch explainability details - for prov_id in provenance_ids: - entity = explain_client.fetch_entity( - prov_id, - graph="urn:graph:retrieval", - user="trustgraph", - collection="scientists" - ) - print(f"Entity: {entity}") - ``` - """ + """Execute graph-based RAG query with explainability support.""" request = { "query": query, "user": user, @@ -849,7 +613,6 @@ class SocketFlowInstance: } request.update(kwargs) - # Use streaming with provenance events included return self.client._send_request_sync( "graph-rag", self.flow_id, request, streaming=True, include_provenance=True @@ -864,39 +627,7 @@ class SocketFlowInstance: streaming: bool = False, **kwargs: Any ) -> Union[str, Iterator[str]]: - """ - Execute document-based RAG query with optional streaming. 
- - Uses vector embeddings to find relevant document chunks, then generates - a response using an LLM. Streaming mode delivers results incrementally. - - Args: - query: Natural language query - user: User/keyspace identifier - collection: Collection identifier - doc_limit: Maximum document chunks to retrieve (default: 10) - streaming: Enable streaming mode (default: False) - **kwargs: Additional parameters passed to the service - - Returns: - Union[str, Iterator[str]]: Complete response or stream of text chunks - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Streaming document RAG - for chunk in flow.document_rag( - query="Summarize the key findings", - user="trustgraph", - collection="research-papers", - doc_limit=5, - streaming=True - ): - print(chunk, end='', flush=True) - ``` - """ + """Execute document-based RAG query with optional streaming.""" request = { "query": query, "user": user, @@ -921,55 +652,7 @@ class SocketFlowInstance: doc_limit: int = 10, **kwargs: Any ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: - """ - Execute document-based RAG query with explainability support. - - Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent). - Provenance events contain URIs that can be fetched using ExplainabilityClient - to get detailed information about how the response was generated. 
- - Document RAG trace consists of: - - Question: The user's query - - Exploration: Chunks retrieved from document store (chunk_count) - - Synthesis: The generated answer - - Args: - query: Natural language query - user: User/keyspace identifier - collection: Collection identifier - doc_limit: Maximum document chunks to retrieve (default: 10) - **kwargs: Additional parameters passed to the service - - Yields: - Union[RAGChunk, ProvenanceEvent]: Content chunks and provenance events - - Example: - ```python - from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent - - socket = api.socket() - flow = socket.flow("default") - explain_client = ExplainabilityClient(flow) - - for item in flow.document_rag_explain( - query="Summarize the key findings", - user="trustgraph", - collection="research-papers", - doc_limit=5 - ): - if isinstance(item, RAGChunk): - print(item.content, end='', flush=True) - elif isinstance(item, ProvenanceEvent): - # Fetch entity details - entity = explain_client.fetch_entity( - item.explain_id, - graph=item.explain_graph, - user="trustgraph", - collection="research-papers" - ) - print(f"Event: {entity}", file=sys.stderr) - ``` - """ + """Execute document-based RAG query with explainability support.""" request = { "query": query, "user": user, @@ -980,14 +663,12 @@ class SocketFlowInstance: } request.update(kwargs) - # Use streaming with provenance events included return self.client._send_request_sync( "document-rag", self.flow_id, request, streaming=True, include_provenance=True ) def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: - """Generator for RAG streaming (graph-rag and document-rag)""" for chunk in result: if hasattr(chunk, 'content'): yield chunk.content @@ -999,32 +680,7 @@ class SocketFlowInstance: streaming: bool = False, **kwargs: Any ) -> Union[str, Iterator[str]]: - """ - Execute a prompt template with optional streaming. 
- - Args: - id: Prompt template identifier - variables: Dictionary of variable name to value mappings - streaming: Enable streaming mode (default: False) - **kwargs: Additional parameters passed to the service - - Returns: - Union[str, Iterator[str]]: Complete response or stream of text chunks - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Streaming prompt execution - for chunk in flow.prompt( - id="summarize-template", - variables={"topic": "quantum computing", "length": "brief"}, - streaming=True - ): - print(chunk, end='', flush=True) - ``` - """ + """Execute a prompt template with optional streaming.""" request = { "id": id, "variables": variables, @@ -1047,33 +703,7 @@ class SocketFlowInstance: limit: int = 10, **kwargs: Any ) -> Dict[str, Any]: - """ - Query knowledge graph entities using semantic similarity. - - Args: - text: Query text for semantic search - user: User/keyspace identifier - collection: Collection identifier - limit: Maximum number of results (default: 10) - **kwargs: Additional parameters passed to the service - - Returns: - dict: Query results with similar entities - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - results = flow.graph_embeddings_query( - text="physicist who discovered radioactivity", - user="trustgraph", - collection="scientists", - limit=5 - ) - ``` - """ - # First convert text to embedding vector + """Query knowledge graph entities using semantic similarity.""" emb_result = self.embeddings(texts=[text]) vector = emb_result.get("vectors", [[]])[0] @@ -1095,34 +725,7 @@ class SocketFlowInstance: limit: int = 10, **kwargs: Any ) -> Dict[str, Any]: - """ - Query document chunks using semantic similarity. 
- - Args: - text: Query text for semantic search - user: User/keyspace identifier - collection: Collection identifier - limit: Maximum number of results (default: 10) - **kwargs: Additional parameters passed to the service - - Returns: - dict: Query results with chunk_ids of matching document chunks - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - results = flow.document_embeddings_query( - text="machine learning algorithms", - user="trustgraph", - collection="research-papers", - limit=5 - ) - # results contains {"chunks": [{"chunk_id": "...", "score": 0.95}, ...]} - ``` - """ - # First convert text to embedding vector + """Query document chunks using semantic similarity.""" emb_result = self.embeddings(texts=[text]) vector = emb_result.get("vectors", [[]])[0] @@ -1137,25 +740,7 @@ class SocketFlowInstance: return self.client._send_request_sync("document-embeddings", self.flow_id, request, False) def embeddings(self, texts: list, **kwargs: Any) -> Dict[str, Any]: - """ - Generate vector embeddings for one or more texts. - - Args: - texts: List of input texts to embed - **kwargs: Additional parameters passed to the service - - Returns: - dict: Response containing vectors (one set per input text) - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - result = flow.embeddings(["quantum computing"]) - vectors = result.get("vectors", []) - ``` - """ + """Generate vector embeddings for one or more texts.""" request = {"texts": texts} request.update(kwargs) @@ -1172,46 +757,9 @@ class SocketFlowInstance: limit: int = 100, **kwargs: Any ) -> List[Dict[str, Any]]: - """ - Query knowledge graph triples using pattern matching. 
- - Args: - s: Subject filter - URI string, Term dict, or None for wildcard - p: Predicate filter - URI string, Term dict, or None for wildcard - o: Object filter - URI/literal string, Term dict, or None for wildcard - g: Named graph filter - URI string or None for all graphs - user: User/keyspace identifier (optional) - collection: Collection identifier (optional) - limit: Maximum results to return (default: 100) - **kwargs: Additional parameters passed to the service - - Returns: - List[Dict]: List of matching triples in wire format - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Find all triples about a specific subject - triples = flow.triples_query( - s="http://example.org/person/marie-curie", - user="trustgraph", - collection="scientists" - ) - - # Query with named graph filter - triples = flow.triples_query( - s="urn:trustgraph:session:abc123", - g="urn:graph:retrieval", - user="trustgraph", - collection="default" - ) - ``` - """ + """Query knowledge graph triples using pattern matching.""" request = {"limit": limit} - # Build Term dicts for s/p/o (auto-converts strings) s_term = build_term(s) p_term = build_term(p) o_term = build_term(o) @@ -1231,7 +779,6 @@ class SocketFlowInstance: request.update(kwargs) result = self.client._send_request_sync("triples", self.flow_id, request, False) - # Return the triples list from the response if isinstance(result, dict) and "response" in result: return result["response"] return result @@ -1248,46 +795,13 @@ class SocketFlowInstance: batch_size: int = 20, **kwargs: Any ) -> Iterator[List[Dict[str, Any]]]: - """ - Query knowledge graph triples with streaming batches. - - Yields batches of triples as they arrive, reducing time-to-first-result - and memory overhead for large result sets. 
- - Args: - s: Subject filter - URI string, Term dict, or None for wildcard - p: Predicate filter - URI string, Term dict, or None for wildcard - o: Object filter - URI/literal string, Term dict, or None for wildcard - g: Named graph filter - URI string or None for all graphs - user: User/keyspace identifier (optional) - collection: Collection identifier (optional) - limit: Maximum results to return (default: 100) - batch_size: Triples per batch (default: 20) - **kwargs: Additional parameters passed to the service - - Yields: - List[Dict]: Batches of triples in wire format - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - for batch in flow.triples_query_stream( - user="trustgraph", - collection="default" - ): - for triple in batch: - print(triple["s"], triple["p"], triple["o"]) - ``` - """ + """Query knowledge graph triples with streaming batches.""" request = { "limit": limit, "streaming": True, "batch-size": batch_size, } - # Build Term dicts for s/p/o (auto-converts strings) s_term = build_term(s) p_term = build_term(p) o_term = build_term(o) @@ -1306,9 +820,7 @@ class SocketFlowInstance: request["collection"] = collection request.update(kwargs) - # Use raw streaming - yields response dicts directly without parsing for response in self.client._send_request_sync("triples", self.flow_id, request, streaming_raw=True): - # Response is {"response": [...triples...]} from translator if isinstance(response, dict) and "response" in response: yield response["response"] else: @@ -1323,41 +835,7 @@ class SocketFlowInstance: operation_name: Optional[str] = None, **kwargs: Any ) -> Dict[str, Any]: - """ - Execute a GraphQL query against structured rows. 
- - Args: - query: GraphQL query string - user: User/keyspace identifier - collection: Collection identifier - variables: Optional query variables dictionary - operation_name: Optional operation name for multi-operation documents - **kwargs: Additional parameters passed to the service - - Returns: - dict: GraphQL response with data, errors, and/or extensions - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - query = ''' - { - scientists(limit: 10) { - name - field - discoveries - } - } - ''' - result = flow.rows_query( - query=query, - user="trustgraph", - collection="scientists" - ) - ``` - """ + """Execute a GraphQL query against structured rows.""" request = { "query": query, "user": user, @@ -1377,28 +855,7 @@ class SocketFlowInstance: parameters: Dict[str, Any], **kwargs: Any ) -> Dict[str, Any]: - """ - Execute a Model Context Protocol (MCP) tool. - - Args: - name: Tool name/identifier - parameters: Tool parameters dictionary - **kwargs: Additional parameters passed to the service - - Returns: - dict: Tool execution result - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - result = flow.mcp_tool( - name="search-web", - parameters={"query": "latest AI news", "limit": 5} - ) - ``` - """ + """Execute a Model Context Protocol (MCP) tool.""" request = { "name": name, "parameters": parameters @@ -1417,50 +874,7 @@ class SocketFlowInstance: limit: int = 10, **kwargs: Any ) -> Dict[str, Any]: - """ - Query row data using semantic similarity on indexed fields. - - Finds rows whose indexed field values are semantically similar to the - input text, using vector embeddings. This enables fuzzy/semantic matching - on structured data. 
- - Args: - text: Query text for semantic search - schema_name: Schema name to search within - user: User/keyspace identifier (default: "trustgraph") - collection: Collection identifier (default: "default") - index_name: Optional index name to filter search to specific index - limit: Maximum number of results (default: 10) - **kwargs: Additional parameters passed to the service - - Returns: - dict: Query results with matches containing index_name, index_value, - text, and score - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Search for customers by name similarity - results = flow.row_embeddings_query( - text="John Smith", - schema_name="customers", - user="trustgraph", - collection="sales", - limit=5 - ) - - # Filter to specific index - results = flow.row_embeddings_query( - text="machine learning engineer", - schema_name="employees", - index_name="job_title", - limit=10 - ) - ``` - """ - # First convert text to embedding vector + """Query row data using semantic similarity on indexed fields.""" emb_result = self.embeddings(texts=[text]) vector = emb_result.get("vectors", [[]])[0] diff --git a/trustgraph-cli/trustgraph/cli/list_explain_traces.py b/trustgraph-cli/trustgraph/cli/list_explain_traces.py index f545c53f..e6d1e075 100644 --- a/trustgraph-cli/trustgraph/cli/list_explain_traces.py +++ b/trustgraph-cli/trustgraph/cli/list_explain_traces.py @@ -58,6 +58,14 @@ def print_json(sessions): print(json.dumps(sessions, indent=2)) +# Map type names for display +TYPE_DISPLAY = { + "graphrag": "GraphRAG", + "docrag": "DocRAG", + "agent": "Agent", +} + + def main(): parser = argparse.ArgumentParser( prog='tg-list-explain-traces', @@ -118,7 +126,7 @@ def main(): explain_client = ExplainabilityClient(flow) try: - # List all sessions using the API + # List all sessions — uses persistent websocket via SocketClient questions = explain_client.list_sessions( graph=RETRIEVAL_GRAPH, user=args.user, @@ -126,7 +134,8 @@ def main(): 
limit=args.limit, ) - # Convert to output format + # detect_session_type is mostly a fast URI pattern check, + # only falls back to network calls for unrecognised URIs sessions = [] for q in questions: session_type = explain_client.detect_session_type( @@ -136,16 +145,9 @@ def main(): collection=args.collection ) - # Map type names - type_display = { - "graphrag": "GraphRAG", - "docrag": "DocRAG", - "agent": "Agent", - }.get(session_type, session_type.title()) - sessions.append({ "id": q.uri, - "type": type_display, + "type": TYPE_DISPLAY.get(session_type, session_type.title()), "question": q.query, "time": q.timestamp, }) diff --git a/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py b/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py index ca1f5c83..8d16d098 100644 --- a/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py +++ b/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py @@ -3,31 +3,27 @@ Shows all defined flow blueprints. """ import argparse +import asyncio import os import tabulate -from trustgraph.api import Api, ConfigKey +from trustgraph.api import AsyncSocketClient import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def format_parameters(params_metadata, config_api): +def format_parameters(params_metadata, param_type_defs): """ - Format parameter metadata for display + Format parameter metadata for display. - Args: - params_metadata: Parameter definitions from flow blueprint - config_api: API client to get parameter type information - - Returns: - Formatted string describing parameters + param_type_defs is a dict of type_name -> parsed type definition, + pre-fetched concurrently. 
""" if not params_metadata: return "None" param_list = [] - # Sort parameters by order if available sorted_params = sorted( params_metadata.items(), key=lambda x: x[1].get("order", 999) @@ -37,41 +33,89 @@ def format_parameters(params_metadata, config_api): description = param_meta.get("description", param_name) param_type = param_meta.get("type", "unknown") - # Get type information if available type_info = param_type - if config_api: - try: - key = ConfigKey("parameter-type", param_type) - type_def_value = config_api.get([key])[0].value - param_type_def = json.loads(type_def_value) - - # Add default value if available - default = param_type_def.get("default") - if default is not None: - type_info = f"{param_type} (default: {default})" - - except: - # If we can't get type definition, just show the type name - pass + if param_type in param_type_defs: + param_type_def = param_type_defs[param_type] + default = param_type_def.get("default") + if default is not None: + type_info = f"{param_type} (default: {default})" param_list.append(f" {param_name}: {description} [{type_info}]") return "\n".join(param_list) +async def fetch_data(client): + """Fetch all data needed for show_flow_blueprints concurrently.""" + + # Round 1: list blueprints + resp = await client._send_request("flow", None, { + "operation": "list-blueprints", + }) + blueprint_names = resp.get("blueprint-names", []) + + if not blueprint_names: + return [], {}, {} + + # Round 2: get all blueprints in parallel + blueprint_tasks = [ + client._send_request("flow", None, { + "operation": "get-blueprint", + "blueprint-name": name, + }) + for name in blueprint_names + ] + blueprint_results = await asyncio.gather(*blueprint_tasks) + + blueprints = {} + for name, resp in zip(blueprint_names, blueprint_results): + bp_data = resp.get("blueprint-definition", "{}") + blueprints[name] = json.loads(bp_data) if isinstance(bp_data, str) else bp_data + + # Round 3: get all parameter type definitions in parallel + 
param_types_needed = set() + for bp in blueprints.values(): + for param_meta in bp.get("parameters", {}).values(): + pt = param_meta.get("type", "") + if pt: + param_types_needed.add(pt) + + param_type_defs = {} + if param_types_needed: + param_type_tasks = [ + client._send_request("config", None, { + "operation": "get", + "keys": [{"type": "parameter-type", "key": pt}], + }) + for pt in param_types_needed + ] + param_type_results = await asyncio.gather(*param_type_tasks) + + for pt, resp in zip(param_types_needed, param_type_results): + values = resp.get("values", []) + if values: + try: + param_type_defs[pt] = json.loads(values[0].get("value", "{}")) + except (json.JSONDecodeError, AttributeError): + pass + + return blueprint_names, blueprints, param_type_defs + +async def _show_flow_blueprints_async(url, token=None): + async with AsyncSocketClient(url, timeout=60, token=token) as client: + return await fetch_data(client) + def show_flow_blueprints(url, token=None): - api = Api(url, token=token) - flow_api = api.flow() - config_api = api.config() + blueprint_names, blueprints, param_type_defs = asyncio.run( + _show_flow_blueprints_async(url, token=token) + ) - blueprint_names = flow_api.list_blueprints() - - if len(blueprint_names) == 0: + if not blueprint_names: print("No flow blueprints.") return for blueprint_name in blueprint_names: - cls = flow_api.get_blueprint(blueprint_name) + cls = blueprints[blueprint_name] table = [] table.append(("name", blueprint_name)) @@ -81,10 +125,9 @@ def show_flow_blueprints(url, token=None): if tags: table.append(("tags", ", ".join(tags))) - # Show parameters if they exist parameters = cls.get("parameters", {}) if parameters: - param_str = format_parameters(parameters, config_api) + param_str = format_parameters(parameters, param_type_defs) table.append(("parameters", param_str)) print(tabulate.tabulate( diff --git a/trustgraph-cli/trustgraph/cli/show_flows.py b/trustgraph-cli/trustgraph/cli/show_flows.py index 
828c18f1..d1abf984 100644 --- a/trustgraph-cli/trustgraph/cli/show_flows.py +++ b/trustgraph-cli/trustgraph/cli/show_flows.py @@ -3,22 +3,15 @@ Shows configured flows. """ import argparse +import asyncio import os import tabulate -from trustgraph.api import Api, ConfigKey +from trustgraph.api import Api, AsyncSocketClient import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def get_interface(config_api, i): - - key = ConfigKey("interface-description", i) - - value = config_api.get([key])[0].value - - return json.loads(value) - def describe_interfaces(intdefs, flow): intfs = flow.get("interfaces", {}) @@ -34,7 +27,7 @@ def describe_interfaces(intdefs, flow): if kind == "request-response": req = intfs[k]["request"] - resp = intfs[k]["request"] + resp = intfs[k]["response"] lst.append(f"{k} request: {req}") lst.append(f"{k} response: {resp}") @@ -49,17 +42,9 @@ def describe_interfaces(intdefs, flow): def get_enum_description(param_value, param_type_def): """ Get the human-readable description for an enum value - - Args: - param_value: The actual parameter value (e.g., "gpt-4") - param_type_def: The parameter type definition containing enum objects - - Returns: - Human-readable description or the original value if not found """ enum_list = param_type_def.get("enum", []) - # Handle both old format (strings) and new format (objects with id/description) for enum_item in enum_list: if isinstance(enum_item, dict): if enum_item.get("id") == param_value: @@ -67,27 +52,20 @@ def get_enum_description(param_value, param_type_def): elif enum_item == param_value: return param_value - # If not found in enum, return original value return param_value -def format_parameters(flow_params, blueprint_params_metadata, config_api): +def format_parameters(flow_params, blueprint_params_metadata, param_type_defs): """ - Format flow parameters with their human-readable descriptions + Format flow parameters with 
their human-readable descriptions. - Args: - flow_params: The actual parameter values used in the flow - blueprint_params_metadata: The parameter metadata from the flow blueprint definition - config_api: API client to retrieve parameter type definitions - - Returns: - Formatted string of parameters with descriptions + param_type_defs is a dict of type_name -> parsed type definition, + pre-fetched concurrently. """ if not flow_params: return "None" param_list = [] - # Sort parameters by order if available sorted_params = sorted( blueprint_params_metadata.items(), key=lambda x: x[1].get("order", 999) @@ -100,80 +78,165 @@ def format_parameters(flow_params, blueprint_params_metadata, config_api): param_type = param_meta.get("type", "") controlled_by = param_meta.get("controlled-by", None) - # Try to get enum description if this parameter has a type definition display_value = value - if param_type and config_api: - try: - from trustgraph.api import ConfigKey - key = ConfigKey("parameter-type", param_type) - type_def_value = config_api.get([key])[0].value - param_type_def = json.loads(type_def_value) - display_value = get_enum_description(value, param_type_def) - except: - # If we can't get the type definition, just use the original value - display_value = value + if param_type and param_type in param_type_defs: + display_value = get_enum_description( + value, param_type_defs[param_type] + ) - # Format the parameter line line = f"• {description}: {display_value}" - # Add controlled-by indicator if present if controlled_by: line += f" (controlled by {controlled_by})" param_list.append(line) - # Add any parameters that aren't in the blueprint metadata (shouldn't happen normally) for param_name, value in flow_params.items(): if param_name not in blueprint_params_metadata: param_list.append(f"• {param_name}: {value} (undefined)") return "\n".join(param_list) if param_list else "None" +async def fetch_show_flows(client): + """Fetch all data needed for show_flows 
concurrently.""" + + # Round 1: list interfaces and list flows in parallel + interface_names_resp, flow_ids_resp = await asyncio.gather( + client._send_request("config", None, { + "operation": "list", + "type": "interface-description", + }), + client._send_request("flow", None, { + "operation": "list-flows", + }), + ) + + interface_names = interface_names_resp.get("directory", []) + flow_ids = flow_ids_resp.get("flow-ids", []) + + if not flow_ids: + return {}, [], {}, {}, {} + + # Round 2: get all interfaces + all flows in parallel + interface_tasks = [ + client._send_request("config", None, { + "operation": "get", + "keys": [{"type": "interface-description", "key": name}], + }) + for name in interface_names + ] + + flow_tasks = [ + client._send_request("flow", None, { + "operation": "get-flow", + "flow-id": fid, + }) + for fid in flow_ids + ] + + results = await asyncio.gather(*interface_tasks, *flow_tasks) + + # Split results + interface_results = results[:len(interface_names)] + flow_results = results[len(interface_names):] + + # Parse interfaces + interface_defs = {} + for name, resp in zip(interface_names, interface_results): + values = resp.get("values", []) + if values: + interface_defs[name] = json.loads(values[0].get("value", "{}")) + + # Parse flows + flows = {} + for fid, resp in zip(flow_ids, flow_results): + flow_data = resp.get("flow", "{}") + flows[fid] = json.loads(flow_data) if isinstance(flow_data, str) else flow_data + + # Round 3: get all blueprints in parallel + blueprint_names = set() + for flow in flows.values(): + bp = flow.get("blueprint-name", "") + if bp: + blueprint_names.add(bp) + + blueprint_tasks = [ + client._send_request("flow", None, { + "operation": "get-blueprint", + "blueprint-name": bp_name, + }) + for bp_name in blueprint_names + ] + + blueprint_results = await asyncio.gather(*blueprint_tasks) + + blueprints = {} + for bp_name, resp in zip(blueprint_names, blueprint_results): + bp_data = resp.get("blueprint-definition", "{}") + 
blueprints[bp_name] = json.loads(bp_data) if isinstance(bp_data, str) else bp_data + + # Round 4: get all parameter type definitions in parallel + param_types_needed = set() + for bp in blueprints.values(): + for param_meta in bp.get("parameters", {}).values(): + pt = param_meta.get("type", "") + if pt: + param_types_needed.add(pt) + + param_type_tasks = [ + client._send_request("config", None, { + "operation": "get", + "keys": [{"type": "parameter-type", "key": pt}], + }) + for pt in param_types_needed + ] + + param_type_results = await asyncio.gather(*param_type_tasks) + + param_type_defs = {} + for pt, resp in zip(param_types_needed, param_type_results): + values = resp.get("values", []) + if values: + try: + param_type_defs[pt] = json.loads(values[0].get("value", "{}")) + except (json.JSONDecodeError, AttributeError): + pass + + return interface_defs, flow_ids, flows, blueprints, param_type_defs + +async def _show_flows_async(url, token=None): + + async with AsyncSocketClient(url, timeout=60, token=token) as client: + return await fetch_show_flows(client) + def show_flows(url, token=None): - api = Api(url, token=token) - config_api = api.config() - flow_api = api.flow() + result = asyncio.run(_show_flows_async(url, token=token)) - interface_names = config_api.list("interface-description") + interface_defs, flow_ids, flows, blueprints, param_type_defs = result - interface_defs = { - i: get_interface(config_api, i) - for i in interface_names - } - - flow_ids = flow_api.list() - - if len(flow_ids) == 0: + if not flow_ids: print("No flows.") return - flows = [] + for fid in flow_ids: - for id in flow_ids: - - flow = flow_api.get(id) + flow = flows[fid] table = [] - table.append(("id", id)) + table.append(("id", fid)) table.append(("blueprint", flow.get("blueprint-name", ""))) table.append(("desc", flow.get("description", ""))) - # Display parameters with human-readable descriptions parameters = flow.get("parameters", {}) if parameters: - # Try to get the flow 
blueprint definition for parameter metadata blueprint_name = flow.get("blueprint-name", "") - if blueprint_name: - try: - flow_blueprint = flow_api.get_blueprint(blueprint_name) - blueprint_params_metadata = flow_blueprint.get("parameters", {}) - param_str = format_parameters(parameters, blueprint_params_metadata, config_api) - except Exception as e: - # Fallback to JSON if we can't get the blueprint definition - param_str = json.dumps(parameters, indent=2) + if blueprint_name and blueprint_name in blueprints: + blueprint_params_metadata = blueprints[blueprint_name].get("parameters", {}) + param_str = format_parameters( + parameters, blueprint_params_metadata, param_type_defs + ) else: - # No blueprint name, fallback to JSON param_str = json.dumps(parameters, indent=2) table.append(("parameters", param_str)) @@ -220,4 +283,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/show_parameter_types.py b/trustgraph-cli/trustgraph/cli/show_parameter_types.py index 2e0f1be3..67d6e823 100644 --- a/trustgraph-cli/trustgraph/cli/show_parameter_types.py +++ b/trustgraph-cli/trustgraph/cli/show_parameter_types.py @@ -7,9 +7,10 @@ valid enums, and validation rules. """ import argparse +import asyncio import os import tabulate -from trustgraph.api import Api, ConfigKey +from trustgraph.api import AsyncSocketClient import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') @@ -17,13 +18,7 @@ default_token = os.getenv("TRUSTGRAPH_TOKEN", None) def format_enum_values(enum_list): """ - Format enum values for display, handling both old and new formats - - Args: - enum_list: List of enum values (strings or objects with id/description) - - Returns: - Formatted string describing enum options + Format enum values for display, handling both old and new formats. 
""" if not enum_list: return "Any value" @@ -31,7 +26,6 @@ def format_enum_values(enum_list): enum_items = [] for item in enum_list: if isinstance(item, dict): - # New format: objects with id and description enum_id = item.get("id", "") description = item.get("description", "") if description: @@ -39,99 +33,146 @@ def format_enum_values(enum_list): else: enum_items.append(enum_id) else: - # Old format: simple strings enum_items.append(str(item)) return "\n".join(f"• {item}" for item in enum_items) def format_constraints(param_type_def): """ - Format validation constraints for display - - Args: - param_type_def: Parameter type definition - - Returns: - Formatted string describing constraints + Format validation constraints for display. """ constraints = [] - # Handle numeric constraints if "minimum" in param_type_def: constraints.append(f"min: {param_type_def['minimum']}") if "maximum" in param_type_def: constraints.append(f"max: {param_type_def['maximum']}") - - # Handle string constraints if "minLength" in param_type_def: constraints.append(f"min length: {param_type_def['minLength']}") if "maxLength" in param_type_def: constraints.append(f"max length: {param_type_def['maxLength']}") if "pattern" in param_type_def: constraints.append(f"pattern: {param_type_def['pattern']}") - - # Handle required field if param_type_def.get("required", False): constraints.append("required") return ", ".join(constraints) if constraints else "None" +def format_param_type(param_type_name, param_type_def): + """Format a single parameter type for display.""" + table = [] + table.append(("name", param_type_name)) + table.append(("description", param_type_def.get("description", ""))) + table.append(("type", param_type_def.get("type", "unknown"))) + + default = param_type_def.get("default") + if default is not None: + table.append(("default", str(default))) + + enum_list = param_type_def.get("enum") + if enum_list: + enum_str = format_enum_values(enum_list) + table.append(("valid values", 
enum_str)) + + constraints = format_constraints(param_type_def) + if constraints != "None": + table.append(("constraints", constraints)) + + return table + +async def fetch_all_param_types(client): + """Fetch all parameter types concurrently.""" + + # Round 1: list parameter types + resp = await client._send_request("config", None, { + "operation": "list", + "type": "parameter-type", + }) + param_type_names = resp.get("directory", []) + + if not param_type_names: + return [], {} + + # Round 2: get all parameter types in parallel + tasks = [ + client._send_request("config", None, { + "operation": "get", + "keys": [{"type": "parameter-type", "key": name}], + }) + for name in param_type_names + ] + results = await asyncio.gather(*tasks) + + param_type_defs = {} + for name, resp in zip(param_type_names, results): + values = resp.get("values", []) + if values: + try: + param_type_defs[name] = json.loads(values[0].get("value", "{}")) + except (json.JSONDecodeError, AttributeError): + pass + + return param_type_names, param_type_defs + +async def fetch_single_param_type(client, param_type_name): + """Fetch a single parameter type.""" + resp = await client._send_request("config", None, { + "operation": "get", + "keys": [{"type": "parameter-type", "key": param_type_name}], + }) + values = resp.get("values", []) + if values: + return json.loads(values[0].get("value", "{}")) + return None + def show_parameter_types(url, token=None): - """ - Show all parameter type definitions - """ - api = Api(url, token=token) - config_api = api.config() + """Show all parameter type definitions.""" - # Get list of all parameter types - try: - param_type_names = config_api.list("parameter-type") - except Exception as e: - print(f"Error retrieving parameter types: {e}") - return + async def _fetch(): + async with AsyncSocketClient(url, timeout=60, token=token) as client: + return await fetch_all_param_types(client) - if len(param_type_names) == 0: + param_type_names, param_type_defs = 
asyncio.run(_fetch()) + + if not param_type_names: print("No parameter types defined.") return - for param_type_name in param_type_names: - try: - # Get the parameter type definition - key = ConfigKey("parameter-type", param_type_name) - type_def_value = config_api.get([key])[0].value - param_type_def = json.loads(type_def_value) - - table = [] - table.append(("name", param_type_name)) - table.append(("description", param_type_def.get("description", ""))) - table.append(("type", param_type_def.get("type", "unknown"))) - - # Show default value if present - default = param_type_def.get("default") - if default is not None: - table.append(("default", str(default))) - - # Show enum values if present - enum_list = param_type_def.get("enum") - if enum_list: - enum_str = format_enum_values(enum_list) - table.append(("valid values", enum_str)) - - # Show constraints - constraints = format_constraints(param_type_def) - if constraints != "None": - table.append(("constraints", constraints)) - - print(tabulate.tabulate( - table, - tablefmt="pretty", - stralign="left", - )) + for name in param_type_names: + if name not in param_type_defs: + print(f"Error retrieving parameter type '{name}'") print() + continue - except Exception as e: - print(f"Error retrieving parameter type '{param_type_name}': {e}") - print() + table = format_param_type(name, param_type_defs[name]) + + print(tabulate.tabulate( + table, + tablefmt="pretty", + stralign="left", + )) + print() + +def show_specific_parameter_type(url, param_type_name, token=None): + """Show a specific parameter type definition.""" + + async def _fetch(): + async with AsyncSocketClient(url, timeout=60, token=token) as client: + return await fetch_single_param_type(client, param_type_name) + + param_type_def = asyncio.run(_fetch()) + + if param_type_def is None: + print(f"Error retrieving parameter type '{param_type_name}'") + return + + table = format_param_type(param_type_name, param_type_def) + + print(tabulate.tabulate( + table, 
+ tablefmt="pretty", + stralign="left", + )) def main(): parser = argparse.ArgumentParser( @@ -161,57 +202,12 @@ def main(): try: if args.type: - # Show specific parameter type show_specific_parameter_type(args.api_url, args.type, args.token) else: - # Show all parameter types show_parameter_types(args.api_url, args.token) except Exception as e: print("Exception:", e, flush=True) -def show_specific_parameter_type(url, param_type_name, token=None): - """ - Show a specific parameter type definition - """ - api = Api(url, token=token) - config_api = api.config() - - try: - # Get the parameter type definition - key = ConfigKey("parameter-type", param_type_name) - type_def_value = config_api.get([key])[0].value - param_type_def = json.loads(type_def_value) - - table = [] - table.append(("name", param_type_name)) - table.append(("description", param_type_def.get("description", ""))) - table.append(("type", param_type_def.get("type", "unknown"))) - - # Show default value if present - default = param_type_def.get("default") - if default is not None: - table.append(("default", str(default))) - - # Show enum values if present - enum_list = param_type_def.get("enum") - if enum_list: - enum_str = format_enum_values(enum_list) - table.append(("valid values", enum_str)) - - # Show constraints - constraints = format_constraints(param_type_def) - if constraints != "None": - table.append(("constraints", constraints)) - - print(tabulate.tabulate( - table, - tablefmt="pretty", - stralign="left", - )) - - except Exception as e: - print(f"Error retrieving parameter type '{param_type_name}': {e}") - if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/verify_system_status.py b/trustgraph-cli/trustgraph/cli/verify_system_status.py index 8cebc83f..5fea1bb0 100644 --- a/trustgraph-cli/trustgraph/cli/verify_system_status.py +++ b/trustgraph-cli/trustgraph/cli/verify_system_status.py @@ -178,7 +178,11 @@ def check_processors(url: 
str, min_processors: int, timeout: int, token: Optiona url += '/' metrics_url = f"{url}api/metrics/query?query=processor_info" - resp = requests.get(metrics_url, timeout=timeout) + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + + resp = requests.get(metrics_url, timeout=timeout, headers=headers) if resp.status_code == 200: data = resp.json() processor_count = len(data.get("data", {}).get("result", [])) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py index ddaa8ddf..fabd5c44 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py @@ -33,9 +33,12 @@ class Mux: async def receive(self, msg): + request_id = None + try: data = msg.json() + request_id = data.get("id") if "request" not in data: raise RuntimeError("Bad message") @@ -51,7 +54,13 @@ class Mux: except Exception as e: logger.error(f"Receive exception: {str(e)}", exc_info=True) - await self.ws.send_json({"error": str(e)}) + error_resp = { + "error": {"message": str(e), "type": "error"}, + "complete": True, + } + if request_id: + error_resp["id"] = request_id + await self.ws.send_json(error_resp) async def maybe_tidy_workers(self, workers): @@ -97,12 +106,12 @@ class Mux: }) worker = asyncio.create_task( - self.request_task(request, responder, flow, svc) + self.request_task(id, request, responder, flow, svc) ) workers.append(worker) - async def request_task(self, request, responder, flow, svc): + async def request_task(self, id, request, responder, flow, svc): try: @@ -119,7 +128,11 @@ class Mux: ) except Exception as e: - await self.ws.send_json({"error": str(e)}) + await self.ws.send_json({ + "id": id, + "error": {"message": str(e), "type": "error"}, + "complete": True, + }) async def run(self): @@ -143,7 +156,11 @@ class Mux: except Exception as e: # This is an internal working error, may not be recoverable logger.error(f"Run prepare exception: {e}", 
exc_info=True) - await self.ws.send_json({"id": id, "error": str(e)}) + await self.ws.send_json({ + "id": id, + "error": {"message": str(e), "type": "error"}, + "complete": True, + }) self.running.stop() if self.ws: @@ -160,7 +177,11 @@ class Mux: except Exception as e: logger.error(f"Exception in mux: {e}", exc_info=True) - await self.ws.send_json({"error": str(e)}) + await self.ws.send_json({ + "id": id, + "error": {"message": str(e), "type": "error"}, + "complete": True, + }) self.running.stop() diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py index 4ab0b302..def44bd4 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py @@ -20,7 +20,7 @@ default_ident = "text-completion" default_temperature = 0.0 default_max_output = 4192 -default_api = "2024-12-01-preview" +default_api = os.getenv("AZURE_API_VERSION", "2024-12-01-preview") default_endpoint = os.getenv("AZURE_ENDPOINT", None) default_token = os.getenv("AZURE_TOKEN", None) default_model = os.getenv("AZURE_MODEL", None) @@ -90,7 +90,7 @@ class Processor(LlmService): } ], temperature=effective_temperature, - max_tokens=self.max_output, + max_completion_tokens=self.max_output, top_p=1, ) @@ -159,7 +159,7 @@ class Processor(LlmService): } ], temperature=effective_temperature, - max_tokens=self.max_output, + max_completion_tokens=self.max_output, top_p=1, stream=True, stream_options={"include_usage": True} diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index d65e27bf..cdc8602a 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -86,7 +86,7 @@ class Processor(LlmService): } ], temperature=effective_temperature, - 
max_tokens=self.max_output, + max_completion_tokens=self.max_output, ) inputtokens = resp.usage.prompt_tokens @@ -152,7 +152,7 @@ class Processor(LlmService): } ], temperature=effective_temperature, - max_tokens=self.max_output, + max_completion_tokens=self.max_output, stream=True, stream_options={"include_usage": True} ) diff --git a/trustgraph-mcp/trustgraph/mcp_server/mcp.py b/trustgraph-mcp/trustgraph/mcp_server/mcp.py index e551ed5d..eadd841b 100755 --- a/trustgraph-mcp/trustgraph/mcp_server/mcp.py +++ b/trustgraph-mcp/trustgraph/mcp_server/mcp.py @@ -24,9 +24,10 @@ from . tg_socket import WebSocketManager class AppContext: sockets: dict[str, WebSocketManager] websocket_url: str + gateway_token: str @asynccontextmanager -async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket") -> AsyncIterator[AppContext]: +async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]: """ Manage application lifecycle with type-safe context @@ -36,7 +37,7 @@ async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8 sockets = {} try: - yield AppContext(sockets=sockets, websocket_url=websocket_url) + yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token) finally: # Cleanup on shutdown @@ -53,15 +54,16 @@ async def get_socket_manager(ctx, user): lifespan_context = ctx.request_context.lifespan_context sockets = lifespan_context.sockets websocket_url = lifespan_context.websocket_url + gateway_token = lifespan_context.gateway_token if user in sockets: logging.info("Return existing socket manager") return sockets[user] logging.info(f"Opening socket to {websocket_url}...") - + # Create manager with empty pending requests - manager = WebSocketManager(websocket_url) + manager = WebSocketManager(websocket_url, token=gateway_token) # Start reader task with the proper manager await manager.start() 
@@ -193,13 +195,14 @@ class GetSystemPromptResponse: prompt: str class McpServer: - def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket"): + def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""): self.host = host self.port = port self.websocket_url = websocket_url - + self.gateway_token = gateway_token + # Create a partial function to pass websocket_url to app_lifespan - lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url) + lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token) self.mcp = FastMCP( "TrustGraph", dependencies=["trustgraph-base"], @@ -2060,8 +2063,11 @@ def main(): # Setup logging before creating server setup_logging(vars(args)) + # Read gateway auth token from environment + gateway_token = os.environ.get("GATEWAY_SECRET", "") + # Create and run the MCP server - server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url) + server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token) server.run() def run(): diff --git a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py index 44f1bf2e..d255ae14 100644 --- a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py +++ b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from websockets.asyncio.client import connect +from urllib.parse import urlencode, urlparse, urlunparse, parse_qs import asyncio import logging import json @@ -9,12 +10,22 @@ import time class WebSocketManager: - def __init__(self, url): + def __init__(self, url, token=None): self.url = url + self.token = token self.socket = None + def _build_url(self): + if not self.token: + return self.url + parsed = urlparse(self.url) + params = 
parse_qs(parsed.query) + params["token"] = [self.token] + new_query = urlencode(params, doseq=True) + return urlunparse(parsed._replace(query=new_query)) + async def start(self): - self.socket = await connect(self.url) + self.socket = await connect(self._build_url()) self.pending_requests = {} self.running = True self.reader_task = asyncio.create_task(self.reader()) diff --git a/trustgraph-unstructured/pyproject.toml b/trustgraph-unstructured/pyproject.toml index 35597398..33265edb 100644 --- a/trustgraph-unstructured/pyproject.toml +++ b/trustgraph-unstructured/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "pulsar-client", "prometheus-client", "python-magic", - "unstructured[csv,docx,epub,md,odt,pptx,rst,rtf,tsv,xlsx]", + "unstructured[csv,docx,epub,md,odt,pdf,pptx,rst,rtf,tsv,xlsx]", ] classifiers = [ "Programming Language :: Python :: 3",