release/v2.2 -> master (#733)

This commit is contained in:
cybermaggedon 2026-03-29 20:27:25 +01:00 committed by GitHub
parent 3ed71a5620
commit 2449392896
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 774 additions and 1111 deletions

View file

@ -7,7 +7,7 @@ FROM docker.io/fedora:42 AS base
ENV PIP_BREAK_SYSTEM_PACKAGES=1 ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.13 && \ RUN dnf install -y python3.13 libxcb mesa-libGL && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \ python -m ensurepip --upgrade && \
pip3 install --no-cache-dir build wheel aiohttp && \ pip3 install --no-cache-dir build wheel aiohttp && \
@ -38,6 +38,11 @@ RUN ls /root/wheels
FROM base FROM base
# Pre-install CPU-only PyTorch so that unstructured[pdf]'s torch
# dependency is satisfied without pulling in CUDA (~190MB vs ~2GB+)
RUN pip3 install --no-cache-dir torch==2.11.0+cpu \
--index-url https://download.pytorch.org/whl/cpu
COPY --from=build /root/wheels /root/wheels COPY --from=build /root/wheels /root/wheels
RUN \ RUN \

View file

@ -0,0 +1,19 @@
# Contributor Licence Agreement (CLA)
We ask every contributor to sign a Contributor Licence Agreement before
we can merge a pull request. The CLA does **not** transfer copyright —
you keep full ownership of your work. It simply grants the TrustGraph
project a perpetual, royalty-free licence to distribute your
contribution under the project's
[Apache 2.0 licence](https://www.apache.org/licenses/LICENSE-2.0), and
confirms that you have the right to make the contribution. This protects
both the project and its users by ensuring every contribution has a
clear legal footing.
When you open a pull request, the CLA bot will post a comment asking you
to review and sign the appropriate agreement — it only takes a moment
and you only need to do it once across all TrustGraph repositories.
- Contributing as an **individual**? Sign the [Individual CLA](https://github.com/trustgraph-ai/contributor-license-agreement/blob/main/Fiduciary-Contributor-License-Agreement.md)
- Contributing on behalf of a **company or organisation**? Sign the [Entity CLA](https://github.com/trustgraph-ai/contributor-license-agreement/blob/main/Entity-Fiduciary-Contributor-License-Agreement.md)

View file

@ -93,7 +93,7 @@ class TestTextCompletionIntegration:
assert call_args.kwargs['model'] == "gpt-3.5-turbo" assert call_args.kwargs['model'] == "gpt-3.5-turbo"
assert call_args.kwargs['temperature'] == 0.7 assert call_args.kwargs['temperature'] == 0.7
assert call_args.kwargs['max_tokens'] == 1024 assert call_args.kwargs['max_completion_tokens'] == 1024
assert len(call_args.kwargs['messages']) == 1 assert len(call_args.kwargs['messages']) == 1
assert call_args.kwargs['messages'][0]['role'] == "user" assert call_args.kwargs['messages'][0]['role'] == "user"
assert "You are a helpful assistant." in call_args.kwargs['messages'][0]['content'][0]['text'] assert "You are a helpful assistant." in call_args.kwargs['messages'][0]['content'][0]['text']
@ -134,7 +134,7 @@ class TestTextCompletionIntegration:
call_args = mock_openai_client.chat.completions.create.call_args call_args = mock_openai_client.chat.completions.create.call_args
assert call_args.kwargs['model'] == config['model'] assert call_args.kwargs['model'] == config['model']
assert call_args.kwargs['temperature'] == config['temperature'] assert call_args.kwargs['temperature'] == config['temperature']
assert call_args.kwargs['max_tokens'] == config['max_output'] assert call_args.kwargs['max_completion_tokens'] == config['max_output']
# Reset mock for next iteration # Reset mock for next iteration
mock_openai_client.reset_mock() mock_openai_client.reset_mock()
@ -286,7 +286,7 @@ class TestTextCompletionIntegration:
# were removed in #561 as unnecessary parameters # were removed in #561 as unnecessary parameters
assert 'model' in call_args.kwargs assert 'model' in call_args.kwargs
assert 'temperature' in call_args.kwargs assert 'temperature' in call_args.kwargs
assert 'max_tokens' in call_args.kwargs assert 'max_completion_tokens' in call_args.kwargs
# Verify result structure # Verify result structure
assert hasattr(result, 'text') assert hasattr(result, 'text')
@ -362,7 +362,7 @@ class TestTextCompletionIntegration:
call_args = mock_openai_client.chat.completions.create.call_args call_args = mock_openai_client.chat.completions.create.call_args
assert call_args.kwargs['model'] == "gpt-4" assert call_args.kwargs['model'] == "gpt-4"
assert call_args.kwargs['temperature'] == 0.8 assert call_args.kwargs['temperature'] == 0.8
assert call_args.kwargs['max_tokens'] == 2048 assert call_args.kwargs['max_completion_tokens'] == 2048
# Note: top_p, frequency_penalty, and presence_penalty # Note: top_p, frequency_penalty, and presence_penalty
# were removed in #561 as unnecessary parameters # were removed in #561 as unnecessary parameters

View file

@ -201,7 +201,7 @@ class TestTextCompletionStreaming:
call_args = mock_streaming_openai_client.chat.completions.create.call_args call_args = mock_streaming_openai_client.chat.completions.create.call_args
assert call_args.kwargs['model'] == "gpt-4" assert call_args.kwargs['model'] == "gpt-4"
assert call_args.kwargs['temperature'] == 0.5 assert call_args.kwargs['temperature'] == 0.5
assert call_args.kwargs['max_tokens'] == 2048 assert call_args.kwargs['max_completion_tokens'] == 2048
assert call_args.kwargs['stream'] is True assert call_args.kwargs['stream'] is True
# Verify chunks have correct model # Verify chunks have correct model

View file

@ -121,7 +121,11 @@ class TestMux:
# Based on the code, it seems to catch exceptions # Based on the code, it seems to catch exceptions
await mux.receive(mock_msg) await mux.receive(mock_msg)
mock_ws.send_json.assert_called_once_with({"error": "Bad message"}) mock_ws.send_json.assert_called_once_with({
"error": {"message": "Bad message", "type": "error"},
"complete": True,
"id": "test-id-123",
})
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mux_receive_message_without_id(self): async def test_mux_receive_message_without_id(self):
@ -145,7 +149,10 @@ class TestMux:
# receive method should handle the RuntimeError internally # receive method should handle the RuntimeError internally
await mux.receive(mock_msg) await mux.receive(mock_msg)
mock_ws.send_json.assert_called_once_with({"error": "Bad message"}) mock_ws.send_json.assert_called_once_with({
"error": {"message": "Bad message", "type": "error"},
"complete": True,
})
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mux_receive_invalid_json(self): async def test_mux_receive_invalid_json(self):
@ -168,4 +175,7 @@ class TestMux:
await mux.receive(mock_msg) await mux.receive(mock_msg)
mock_msg.json.assert_called_once() mock_msg.json.assert_called_once()
mock_ws.send_json.assert_called_once_with({"error": "Invalid JSON"}) mock_ws.send_json.assert_called_once_with({
"error": {"message": "Invalid JSON", "type": "error"},
"complete": True,
})

View file

@ -108,7 +108,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
}] }]
}], }],
temperature=0.0, temperature=0.0,
max_tokens=4192, max_completion_tokens=4192,
top_p=1 top_p=1
) )
@ -399,7 +399,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
# Verify other parameters # Verify other parameters
assert call_args[1]['model'] == 'gpt-4' assert call_args[1]['model'] == 'gpt-4'
assert call_args[1]['temperature'] == 0.5 assert call_args[1]['temperature'] == 0.5
assert call_args[1]['max_tokens'] == 1024 assert call_args[1]['max_completion_tokens'] == 1024
assert call_args[1]['top_p'] == 1 assert call_args[1]['top_p'] == 1
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')

View file

@ -102,7 +102,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
}] }]
}], }],
temperature=0.0, temperature=0.0,
max_tokens=4096 max_completion_tokens=4096
) )
@patch('trustgraph.model.text_completion.openai.llm.OpenAI') @patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@ -380,7 +380,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
# Verify other parameters # Verify other parameters
assert call_args[1]['model'] == 'gpt-3.5-turbo' assert call_args[1]['model'] == 'gpt-3.5-turbo'
assert call_args[1]['temperature'] == 0.5 assert call_args[1]['temperature'] == 0.5
assert call_args[1]['max_tokens'] == 1024 assert call_args[1]['max_completion_tokens'] == 1024
@patch('trustgraph.model.text_completion.openai.llm.OpenAI') @patch('trustgraph.model.text_completion.openai.llm.OpenAI')

View file

@ -1,5 +1,6 @@
import json import json
import asyncio
import websockets import websockets
from typing import Optional, Dict, Any, AsyncIterator, Union from typing import Optional, Dict, Any, AsyncIterator, Union
@ -8,13 +9,29 @@ from . exceptions import ProtocolException, ApplicationException
class AsyncSocketClient: class AsyncSocketClient:
"""Asynchronous WebSocket client""" """Asynchronous WebSocket client with persistent connection.
Maintains a single websocket connection and multiplexes requests
by ID, routing responses via a background reader task.
Use as an async context manager for proper lifecycle management:
async with AsyncSocketClient(url, timeout, token) as client:
result = await client._send_request(...)
Or call connect()/aclose() manually.
"""
def __init__(self, url: str, timeout: int, token: Optional[str]): def __init__(self, url: str, timeout: int, token: Optional[str]):
self.url = self._convert_to_ws_url(url) self.url = self._convert_to_ws_url(url)
self.timeout = timeout self.timeout = timeout
self.token = token self.token = token
self._request_counter = 0 self._request_counter = 0
self._socket = None
self._connect_cm = None
self._reader_task = None
self._pending = {} # request_id -> asyncio.Queue
self._connected = False
def _convert_to_ws_url(self, url: str) -> str: def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL""" """Convert HTTP URL to WebSocket URL"""
@ -25,82 +42,123 @@ class AsyncSocketClient:
elif url.startswith("ws://") or url.startswith("wss://"): elif url.startswith("ws://") or url.startswith("wss://"):
return url return url
else: else:
# Assume ws://
return f"ws://{url}" return f"ws://{url}"
def _build_ws_url(self):
ws_url = f"{self.url.rstrip('/')}/api/v1/socket"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
return ws_url
async def connect(self):
    """Establish the persistent websocket connection.

    Does nothing if the client is already connected.  On success the
    socket and its context manager are recorded on the instance and the
    background reader task is started.

    Raises:
        Whatever ``websockets.connect`` raises if the connection cannot be
        established; in that case no connection state is recorded.
    """
    if self._connected:
        return

    connect_cm = websockets.connect(
        self._build_ws_url(), ping_interval=20, ping_timeout=self.timeout
    )
    socket = await connect_cm.__aenter__()

    # Record the context manager only after it has been entered, so a
    # failed connect never leaves aclose() with a manager to exit that
    # was never opened.
    self._connect_cm = connect_cm
    self._socket = socket
    self._connected = True
    self._reader_task = asyncio.create_task(self._reader())
# Async context-manager entry: open the connection via connect() and
# hand back this client as the `as` target.
async def __aenter__(self):
await self.connect()
return self
# Async context-manager exit: always close the connection via aclose().
# The exception arguments are ignored and nothing is returned, so an
# exception raised inside the `async with` body still propagates.
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.aclose()
# Called at the start of each request so the client also works when it
# was not entered as a context manager and connect() was never called.
async def _ensure_connected(self):
"""Lazily connect if not already connected."""
if not self._connected:
await self.connect()
async def _reader(self):
"""Background task to read responses and route by request ID."""
try:
async for raw_message in self._socket:
response = json.loads(raw_message)
request_id = response.get("id")
if request_id and request_id in self._pending:
await self._pending[request_id].put(response)
# Ignore messages for unknown request IDs
except websockets.exceptions.ConnectionClosed:
pass
except Exception as e:
# Signal error to all pending requests
for queue in self._pending.values():
try:
await queue.put({"error": str(e)})
except:
pass
finally:
self._connected = False
# Allocate the next request ID for this client ("req-1", "req-2", ...)
# by bumping the per-instance counter.  Callers use the ID both in the
# outgoing message and as the key for the reply queue in _pending.
def _next_request_id(self):
self._request_counter += 1
return f"req-{self._request_counter}"
def flow(self, flow_id: str): def flow(self, flow_id: str):
"""Get async flow instance for WebSocket operations""" """Get async flow instance for WebSocket operations"""
return AsyncSocketFlowInstance(self, flow_id) return AsyncSocketFlowInstance(self, flow_id)
async def _send_request(self, service: str, flow: Optional[str], request: Dict[str, Any]): async def _send_request(self, service: str, flow: Optional[str], request: Dict[str, Any]):
"""Async WebSocket request implementation (non-streaming)""" """Send a request and wait for a single response."""
# Generate unique request ID await self._ensure_connected()
self._request_counter += 1
request_id = f"req-{self._request_counter}"
# Build WebSocket URL with optional token request_id = self._next_request_id()
ws_url = f"{self.url}/api/v1/socket" queue = asyncio.Queue()
if self.token: self._pending[request_id] = queue
ws_url = f"{ws_url}?token={self.token}"
# Build request message try:
message = { message = {
"id": request_id, "id": request_id,
"service": service, "service": service,
"request": request "request": request
} }
if flow: if flow:
message["flow"] = flow message["flow"] = flow
# Connect and send request await self._socket.send(json.dumps(message))
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
await websocket.send(json.dumps(message))
# Wait for single response response = await queue.get()
raw_message = await websocket.recv()
response = json.loads(raw_message)
if response.get("id") != request_id:
raise ProtocolException(f"Response ID mismatch")
if "error" in response: if "error" in response:
raise ApplicationException(response["error"]) raise ApplicationException(response["error"])
if "response" not in response: if "response" not in response:
raise ProtocolException(f"Missing response in message") raise ProtocolException("Missing response in message")
return response["response"] return response["response"]
finally:
self._pending.pop(request_id, None)
async def _send_request_streaming(self, service: str, flow: Optional[str], request: Dict[str, Any]): async def _send_request_streaming(self, service: str, flow: Optional[str], request: Dict[str, Any]):
"""Async WebSocket request implementation (streaming)""" """Send a request and yield streaming response chunks."""
# Generate unique request ID await self._ensure_connected()
self._request_counter += 1
request_id = f"req-{self._request_counter}"
# Build WebSocket URL with optional token request_id = self._next_request_id()
ws_url = f"{self.url}/api/v1/socket" queue = asyncio.Queue()
if self.token: self._pending[request_id] = queue
ws_url = f"{ws_url}?token={self.token}"
# Build request message try:
message = { message = {
"id": request_id, "id": request_id,
"service": service, "service": service,
"request": request "request": request
} }
if flow: if flow:
message["flow"] = flow message["flow"] = flow
# Connect and send request await self._socket.send(json.dumps(message))
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
await websocket.send(json.dumps(message))
# Yield chunks as they arrive while True:
async for raw_message in websocket: response = await queue.get()
response = json.loads(raw_message)
if response.get("id") != request_id:
continue # Ignore messages for other requests
if "error" in response: if "error" in response:
raise ApplicationException(response["error"]) raise ApplicationException(response["error"])
@ -108,18 +166,16 @@ class AsyncSocketClient:
if "response" in response: if "response" in response:
resp = response["response"] resp = response["response"]
# Parse different chunk types
chunk = self._parse_chunk(resp) chunk = self._parse_chunk(resp)
if chunk is not None: # Skip provenance messages in streaming if chunk is not None:
yield chunk yield chunk
# Check if this is the final message
# end_of_session indicates entire session is complete (including provenance)
# end_of_dialog is for agent dialogs
# complete is from the gateway envelope
if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"): if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"):
break break
finally:
self._pending.pop(request_id, None)
def _parse_chunk(self, resp: Dict[str, Any]): def _parse_chunk(self, resp: Dict[str, Any]):
"""Parse response chunk into appropriate type. Returns None for non-content messages.""" """Parse response chunk into appropriate type. Returns None for non-content messages."""
chunk_type = resp.get("chunk_type") chunk_type = resp.get("chunk_type")
@ -127,7 +183,6 @@ class AsyncSocketClient:
# Handle new GraphRAG message format with message_type # Handle new GraphRAG message format with message_type
if message_type == "provenance": if message_type == "provenance":
# Provenance messages are not yielded to user - they're metadata
return None return None
if chunk_type == "thought": if chunk_type == "thought":
@ -147,25 +202,41 @@ class AsyncSocketClient:
end_of_dialog=resp.get("end_of_dialog", False) end_of_dialog=resp.get("end_of_dialog", False)
) )
elif chunk_type == "action": elif chunk_type == "action":
# Agent action chunks - treat as thoughts for display purposes
return AgentThought( return AgentThought(
content=resp.get("content", ""), content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False) end_of_message=resp.get("end_of_message", False)
) )
else: else:
# RAG-style chunk (or generic chunk with message_type="chunk")
# Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
content = resp.get("response", resp.get("chunk", resp.get("text", ""))) content = resp.get("response", resp.get("chunk", resp.get("text", "")))
return RAGChunk( return RAGChunk(
content=content, content=content,
end_of_stream=resp.get("end_of_stream", False), end_of_stream=resp.get("end_of_stream", False),
error=None # Errors are always thrown, never stored error=None
) )
async def aclose(self): async def aclose(self):
"""Close WebSocket connection""" """Close the persistent WebSocket connection cleanly."""
# Cleanup handled by context manager # Wait for reader to finish (socket close will cause it to exit)
pass if self._reader_task:
self._reader_task.cancel()
try:
await self._reader_task
except asyncio.CancelledError:
pass
self._reader_task = None
# Exit the websockets context manager — this cleanly shuts down
# the connection and its keepalive task
if self._connect_cm:
try:
await self._connect_cm.__aexit__(None, None, None)
except Exception:
pass
self._connect_cm = None
self._socket = None
self._connected = False
self._pending.clear()
class AsyncSocketFlowInstance: class AsyncSocketFlowInstance:
@ -292,7 +363,6 @@ class AsyncSocketFlowInstance:
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs): async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
"""Query graph embeddings for semantic search""" """Query graph embeddings for semantic search"""
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text]) emb_result = await self.embeddings(texts=[text])
vector = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
@ -362,7 +432,6 @@ class AsyncSocketFlowInstance:
limit: int = 10, **kwargs limit: int = 10, **kwargs
): ):
"""Query row embeddings for semantic search on structured data""" """Query row embeddings for semantic search on structured data"""
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text]) emb_result = await self.embeddings(texts=[text])
vector = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]

File diff suppressed because it is too large Load diff

View file

@ -58,6 +58,14 @@ def print_json(sessions):
print(json.dumps(sessions, indent=2)) print(json.dumps(sessions, indent=2))
# Map type names for display
TYPE_DISPLAY = {
"graphrag": "GraphRAG",
"docrag": "DocRAG",
"agent": "Agent",
}
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='tg-list-explain-traces', prog='tg-list-explain-traces',
@ -118,7 +126,7 @@ def main():
explain_client = ExplainabilityClient(flow) explain_client = ExplainabilityClient(flow)
try: try:
# List all sessions using the API # List all sessions — uses persistent websocket via SocketClient
questions = explain_client.list_sessions( questions = explain_client.list_sessions(
graph=RETRIEVAL_GRAPH, graph=RETRIEVAL_GRAPH,
user=args.user, user=args.user,
@ -126,7 +134,8 @@ def main():
limit=args.limit, limit=args.limit,
) )
# Convert to output format # detect_session_type is mostly a fast URI pattern check,
# only falls back to network calls for unrecognised URIs
sessions = [] sessions = []
for q in questions: for q in questions:
session_type = explain_client.detect_session_type( session_type = explain_client.detect_session_type(
@ -136,16 +145,9 @@ def main():
collection=args.collection collection=args.collection
) )
# Map type names
type_display = {
"graphrag": "GraphRAG",
"docrag": "DocRAG",
"agent": "Agent",
}.get(session_type, session_type.title())
sessions.append({ sessions.append({
"id": q.uri, "id": q.uri,
"type": type_display, "type": TYPE_DISPLAY.get(session_type, session_type.title()),
"question": q.query, "question": q.query,
"time": q.timestamp, "time": q.timestamp,
}) })

View file

@ -3,31 +3,27 @@ Shows all defined flow blueprints.
""" """
import argparse import argparse
import asyncio
import os import os
import tabulate import tabulate
from trustgraph.api import Api, ConfigKey from trustgraph.api import AsyncSocketClient
import json import json
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
def format_parameters(params_metadata, config_api): def format_parameters(params_metadata, param_type_defs):
""" """
Format parameter metadata for display Format parameter metadata for display.
Args: param_type_defs is a dict of type_name -> parsed type definition,
params_metadata: Parameter definitions from flow blueprint pre-fetched concurrently.
config_api: API client to get parameter type information
Returns:
Formatted string describing parameters
""" """
if not params_metadata: if not params_metadata:
return "None" return "None"
param_list = [] param_list = []
# Sort parameters by order if available
sorted_params = sorted( sorted_params = sorted(
params_metadata.items(), params_metadata.items(),
key=lambda x: x[1].get("order", 999) key=lambda x: x[1].get("order", 999)
@ -37,41 +33,89 @@ def format_parameters(params_metadata, config_api):
description = param_meta.get("description", param_name) description = param_meta.get("description", param_name)
param_type = param_meta.get("type", "unknown") param_type = param_meta.get("type", "unknown")
# Get type information if available
type_info = param_type type_info = param_type
if config_api: if param_type in param_type_defs:
try: param_type_def = param_type_defs[param_type]
key = ConfigKey("parameter-type", param_type) default = param_type_def.get("default")
type_def_value = config_api.get([key])[0].value if default is not None:
param_type_def = json.loads(type_def_value) type_info = f"{param_type} (default: {default})"
# Add default value if available
default = param_type_def.get("default")
if default is not None:
type_info = f"{param_type} (default: {default})"
except:
# If we can't get type definition, just show the type name
pass
param_list.append(f" {param_name}: {description} [{type_info}]") param_list.append(f" {param_name}: {description} [{type_info}]")
return "\n".join(param_list) return "\n".join(param_list)
async def fetch_data(client):
    """Fetch all data needed for show_flow_blueprints concurrently.

    Runs three rounds of requests over ``client``: list the blueprint
    names, fetch every blueprint in parallel, then fetch every parameter
    type definition those blueprints reference in parallel.

    Args:
        client: object exposing ``_send_request(service, flow, request)``
            as a coroutine.

    Returns:
        A tuple ``(blueprint_names, blueprints, param_type_defs)`` where
        ``blueprints`` maps name -> parsed blueprint definition and
        ``param_type_defs`` maps type name -> parsed type definition.
        Type definitions that cannot be fetched or parsed are omitted,
        so the display falls back to the bare type name.
    """

    # Round 1: list blueprints
    listing = await client._send_request("flow", None, {
        "operation": "list-blueprints",
    })
    blueprint_names = listing.get("blueprint-names", [])

    if not blueprint_names:
        return [], {}, {}

    # Round 2: get all blueprints in parallel
    blueprint_results = await asyncio.gather(*(
        client._send_request("flow", None, {
            "operation": "get-blueprint",
            "blueprint-name": name,
        })
        for name in blueprint_names
    ))

    blueprints = {}
    for name, resp in zip(blueprint_names, blueprint_results):
        bp_data = resp.get("blueprint-definition", "{}")
        # The definition may arrive JSON-encoded or already parsed.
        blueprints[name] = (
            json.loads(bp_data) if isinstance(bp_data, str) else bp_data
        )

    # Round 3: get all parameter type definitions in parallel
    param_types_needed = list({
        param_meta.get("type", "")
        for bp in blueprints.values()
        for param_meta in bp.get("parameters", {}).values()
        if param_meta.get("type", "")
    })

    param_type_defs = {}
    if param_types_needed:
        # Type definitions are optional decoration: one failed lookup
        # must not abort the whole listing, so failures are collected as
        # results and skipped below.
        param_type_results = await asyncio.gather(
            *(
                client._send_request("config", None, {
                    "operation": "get",
                    "keys": [{"type": "parameter-type", "key": pt}],
                })
                for pt in param_types_needed
            ),
            return_exceptions=True,
        )

        for pt, resp in zip(param_types_needed, param_type_results):
            if isinstance(resp, BaseException):
                continue
            values = resp.get("values", [])
            if not values:
                continue
            try:
                param_type_defs[pt] = json.loads(
                    values[0].get("value", "{}")
                )
            except (json.JSONDecodeError, AttributeError, TypeError):
                # Malformed or missing definition: show the type name only
                pass

    return blueprint_names, blueprints, param_type_defs
# Async entry point for show_flow_blueprints: opens an AsyncSocketClient
# for `url` (timeout=60, optional token) as a context manager, runs
# fetch_data over that one connection and returns its result unchanged.
async def _show_flow_blueprints_async(url, token=None):
async with AsyncSocketClient(url, timeout=60, token=token) as client:
return await fetch_data(client)
def show_flow_blueprints(url, token=None): def show_flow_blueprints(url, token=None):
api = Api(url, token=token) blueprint_names, blueprints, param_type_defs = asyncio.run(
flow_api = api.flow() _show_flow_blueprints_async(url, token=token)
config_api = api.config() )
blueprint_names = flow_api.list_blueprints() if not blueprint_names:
if len(blueprint_names) == 0:
print("No flow blueprints.") print("No flow blueprints.")
return return
for blueprint_name in blueprint_names: for blueprint_name in blueprint_names:
cls = flow_api.get_blueprint(blueprint_name) cls = blueprints[blueprint_name]
table = [] table = []
table.append(("name", blueprint_name)) table.append(("name", blueprint_name))
@ -81,10 +125,9 @@ def show_flow_blueprints(url, token=None):
if tags: if tags:
table.append(("tags", ", ".join(tags))) table.append(("tags", ", ".join(tags)))
# Show parameters if they exist
parameters = cls.get("parameters", {}) parameters = cls.get("parameters", {})
if parameters: if parameters:
param_str = format_parameters(parameters, config_api) param_str = format_parameters(parameters, param_type_defs)
table.append(("parameters", param_str)) table.append(("parameters", param_str))
print(tabulate.tabulate( print(tabulate.tabulate(

View file

@ -3,22 +3,15 @@ Shows configured flows.
""" """
import argparse import argparse
import asyncio
import os import os
import tabulate import tabulate
from trustgraph.api import Api, ConfigKey from trustgraph.api import Api, AsyncSocketClient
import json import json
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
def get_interface(config_api, i):
key = ConfigKey("interface-description", i)
value = config_api.get([key])[0].value
return json.loads(value)
def describe_interfaces(intdefs, flow): def describe_interfaces(intdefs, flow):
intfs = flow.get("interfaces", {}) intfs = flow.get("interfaces", {})
@ -34,7 +27,7 @@ def describe_interfaces(intdefs, flow):
if kind == "request-response": if kind == "request-response":
req = intfs[k]["request"] req = intfs[k]["request"]
resp = intfs[k]["request"] resp = intfs[k]["response"]
lst.append(f"{k} request: {req}") lst.append(f"{k} request: {req}")
lst.append(f"{k} response: {resp}") lst.append(f"{k} response: {resp}")
@ -49,17 +42,9 @@ def describe_interfaces(intdefs, flow):
def get_enum_description(param_value, param_type_def): def get_enum_description(param_value, param_type_def):
""" """
Get the human-readable description for an enum value Get the human-readable description for an enum value
Args:
param_value: The actual parameter value (e.g., "gpt-4")
param_type_def: The parameter type definition containing enum objects
Returns:
Human-readable description or the original value if not found
""" """
enum_list = param_type_def.get("enum", []) enum_list = param_type_def.get("enum", [])
# Handle both old format (strings) and new format (objects with id/description)
for enum_item in enum_list: for enum_item in enum_list:
if isinstance(enum_item, dict): if isinstance(enum_item, dict):
if enum_item.get("id") == param_value: if enum_item.get("id") == param_value:
@ -67,27 +52,20 @@ def get_enum_description(param_value, param_type_def):
elif enum_item == param_value: elif enum_item == param_value:
return param_value return param_value
# If not found in enum, return original value
return param_value return param_value
def format_parameters(flow_params, blueprint_params_metadata, config_api): def format_parameters(flow_params, blueprint_params_metadata, param_type_defs):
""" """
Format flow parameters with their human-readable descriptions Format flow parameters with their human-readable descriptions.
Args: param_type_defs is a dict of type_name -> parsed type definition,
flow_params: The actual parameter values used in the flow pre-fetched concurrently.
blueprint_params_metadata: The parameter metadata from the flow blueprint definition
config_api: API client to retrieve parameter type definitions
Returns:
Formatted string of parameters with descriptions
""" """
if not flow_params: if not flow_params:
return "None" return "None"
param_list = [] param_list = []
# Sort parameters by order if available
sorted_params = sorted( sorted_params = sorted(
blueprint_params_metadata.items(), blueprint_params_metadata.items(),
key=lambda x: x[1].get("order", 999) key=lambda x: x[1].get("order", 999)
@ -100,80 +78,165 @@ def format_parameters(flow_params, blueprint_params_metadata, config_api):
param_type = param_meta.get("type", "") param_type = param_meta.get("type", "")
controlled_by = param_meta.get("controlled-by", None) controlled_by = param_meta.get("controlled-by", None)
# Try to get enum description if this parameter has a type definition
display_value = value display_value = value
if param_type and config_api: if param_type and param_type in param_type_defs:
try: display_value = get_enum_description(
from trustgraph.api import ConfigKey value, param_type_defs[param_type]
key = ConfigKey("parameter-type", param_type) )
type_def_value = config_api.get([key])[0].value
param_type_def = json.loads(type_def_value)
display_value = get_enum_description(value, param_type_def)
except:
# If we can't get the type definition, just use the original value
display_value = value
# Format the parameter line
line = f"{description}: {display_value}" line = f"{description}: {display_value}"
# Add controlled-by indicator if present
if controlled_by: if controlled_by:
line += f" (controlled by {controlled_by})" line += f" (controlled by {controlled_by})"
param_list.append(line) param_list.append(line)
# Add any parameters that aren't in the blueprint metadata (shouldn't happen normally)
for param_name, value in flow_params.items(): for param_name, value in flow_params.items():
if param_name not in blueprint_params_metadata: if param_name not in blueprint_params_metadata:
param_list.append(f"{param_name}: {value} (undefined)") param_list.append(f"{param_name}: {value} (undefined)")
return "\n".join(param_list) if param_list else "None" return "\n".join(param_list) if param_list else "None"
async def fetch_show_flows(client):
    """Fetch all data needed for show_flows concurrently.

    Requests are issued in four rounds; the requests inside each round are
    run in parallel with ``asyncio.gather``:

    1. list interface-description names and list flow ids;
    2. get every interface description and every flow;
    3. get every blueprint named by a flow;
    4. get every parameter type named by a blueprint parameter.

    Args:
        client: Object exposing an awaitable
            ``_send_request(service, <second argument, always None here>,
            request_dict)`` that returns a dict response.

    Returns:
        A 5-tuple ``(interface_defs, flow_ids, flows, blueprints,
        param_type_defs)``. ``flow_ids`` is a list; the other four are
        dicts keyed by interface name, flow id, blueprint name and
        parameter type name respectively. When there are no flows the
        result is ``({}, [], {}, {}, {})``.

    Raises:
        json.JSONDecodeError: If an interface, flow or blueprint payload is
            a string that is not valid JSON. (Undecodable parameter type
            payloads are skipped instead.)
    """

    # Round 1: list interfaces and list flows in parallel
    interface_names_resp, flow_ids_resp = await asyncio.gather(
        client._send_request("config", None, {
            "operation": "list",
            "type": "interface-description",
        }),
        client._send_request("flow", None, {
            "operation": "list-flows",
        }),
    )

    interface_names = interface_names_resp.get("directory", [])
    flow_ids = flow_ids_resp.get("flow-ids", [])

    if not flow_ids:
        # Must have the same arity as the normal return below: the caller
        # unpacks five values before checking whether flow_ids is empty.
        return {}, [], {}, {}, {}

    # Round 2: get all interfaces + all flows in parallel
    interface_tasks = [
        client._send_request("config", None, {
            "operation": "get",
            "keys": [{"type": "interface-description", "key": name}],
        })
        for name in interface_names
    ]

    flow_tasks = [
        client._send_request("flow", None, {
            "operation": "get-flow",
            "flow-id": fid,
        })
        for fid in flow_ids
    ]

    results = await asyncio.gather(*interface_tasks, *flow_tasks)

    # gather preserves argument order, so the first len(interface_names)
    # results are interfaces and the rest are flows.
    interface_results = results[:len(interface_names)]
    flow_results = results[len(interface_names):]

    # Parse interfaces; names whose response has no values are left out
    interface_defs = {}
    for name, resp in zip(interface_names, interface_results):
        values = resp.get("values", [])
        if values:
            interface_defs[name] = json.loads(values[0].get("value", "{}"))

    # Parse flows; the payload may arrive JSON-encoded or already decoded
    flows = {}
    for fid, resp in zip(flow_ids, flow_results):
        flow_data = resp.get("flow", "{}")
        flows[fid] = (
            json.loads(flow_data) if isinstance(flow_data, str) else flow_data
        )

    # Round 3: get all blueprints in parallel (each distinct name once)
    blueprint_names = set()
    for flow in flows.values():
        bp = flow.get("blueprint-name", "")
        if bp:
            blueprint_names.add(bp)

    blueprint_tasks = [
        client._send_request("flow", None, {
            "operation": "get-blueprint",
            "blueprint-name": bp_name,
        })
        for bp_name in blueprint_names
    ]

    blueprint_results = await asyncio.gather(*blueprint_tasks)

    # The set is not modified between the two iterations, so zip pairs
    # each name with the response to its own request.
    blueprints = {}
    for bp_name, resp in zip(blueprint_names, blueprint_results):
        bp_data = resp.get("blueprint-definition", "{}")
        blueprints[bp_name] = (
            json.loads(bp_data) if isinstance(bp_data, str) else bp_data
        )

    # Round 4: get all parameter type definitions in parallel
    param_types_needed = set()
    for bp in blueprints.values():
        for param_meta in bp.get("parameters", {}).values():
            pt = param_meta.get("type", "")
            if pt:
                param_types_needed.add(pt)

    param_type_tasks = [
        client._send_request("config", None, {
            "operation": "get",
            "keys": [{"type": "parameter-type", "key": pt}],
        })
        for pt in param_types_needed
    ]

    param_type_results = await asyncio.gather(*param_type_tasks)

    param_type_defs = {}
    for pt, resp in zip(param_types_needed, param_type_results):
        values = resp.get("values", [])
        if values:
            try:
                param_type_defs[pt] = json.loads(values[0].get("value", "{}"))
            except (json.JSONDecodeError, AttributeError):
                # Best effort: a type whose definition cannot be decoded is
                # simply absent from the result.
                pass

    return interface_defs, flow_ids, flows, blueprints, param_type_defs
async def _show_flows_async(url, token=None):
    """Open an AsyncSocketClient on url and fetch the show_flows data with it.

    The client is created with a 60 second timeout and the given token, and
    is closed by the ``async with`` block before the result is returned.
    """
    async with AsyncSocketClient(url, token=token, timeout=60) as socket_client:
        fetched = await fetch_show_flows(socket_client)
    return fetched
def show_flows(url, token=None): def show_flows(url, token=None):
api = Api(url, token=token) result = asyncio.run(_show_flows_async(url, token=token))
config_api = api.config()
flow_api = api.flow()
interface_names = config_api.list("interface-description") interface_defs, flow_ids, flows, blueprints, param_type_defs = result
interface_defs = { if not flow_ids:
i: get_interface(config_api, i)
for i in interface_names
}
flow_ids = flow_api.list()
if len(flow_ids) == 0:
print("No flows.") print("No flows.")
return return
flows = [] for fid in flow_ids:
for id in flow_ids: flow = flows[fid]
flow = flow_api.get(id)
table = [] table = []
table.append(("id", id)) table.append(("id", fid))
table.append(("blueprint", flow.get("blueprint-name", ""))) table.append(("blueprint", flow.get("blueprint-name", "")))
table.append(("desc", flow.get("description", ""))) table.append(("desc", flow.get("description", "")))
# Display parameters with human-readable descriptions
parameters = flow.get("parameters", {}) parameters = flow.get("parameters", {})
if parameters: if parameters:
# Try to get the flow blueprint definition for parameter metadata
blueprint_name = flow.get("blueprint-name", "") blueprint_name = flow.get("blueprint-name", "")
if blueprint_name: if blueprint_name and blueprint_name in blueprints:
try: blueprint_params_metadata = blueprints[blueprint_name].get("parameters", {})
flow_blueprint = flow_api.get_blueprint(blueprint_name) param_str = format_parameters(
blueprint_params_metadata = flow_blueprint.get("parameters", {}) parameters, blueprint_params_metadata, param_type_defs
param_str = format_parameters(parameters, blueprint_params_metadata, config_api) )
except Exception as e:
# Fallback to JSON if we can't get the blueprint definition
param_str = json.dumps(parameters, indent=2)
else: else:
# No blueprint name, fallback to JSON
param_str = json.dumps(parameters, indent=2) param_str = json.dumps(parameters, indent=2)
table.append(("parameters", param_str)) table.append(("parameters", param_str))

View file

@ -7,9 +7,10 @@ valid enums, and validation rules.
""" """
import argparse import argparse
import asyncio
import os import os
import tabulate import tabulate
from trustgraph.api import Api, ConfigKey from trustgraph.api import AsyncSocketClient
import json import json
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
@ -17,13 +18,7 @@ default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
def format_enum_values(enum_list): def format_enum_values(enum_list):
""" """
Format enum values for display, handling both old and new formats Format enum values for display, handling both old and new formats.
Args:
enum_list: List of enum values (strings or objects with id/description)
Returns:
Formatted string describing enum options
""" """
if not enum_list: if not enum_list:
return "Any value" return "Any value"
@ -31,7 +26,6 @@ def format_enum_values(enum_list):
enum_items = [] enum_items = []
for item in enum_list: for item in enum_list:
if isinstance(item, dict): if isinstance(item, dict):
# New format: objects with id and description
enum_id = item.get("id", "") enum_id = item.get("id", "")
description = item.get("description", "") description = item.get("description", "")
if description: if description:
@ -39,99 +33,146 @@ def format_enum_values(enum_list):
else: else:
enum_items.append(enum_id) enum_items.append(enum_id)
else: else:
# Old format: simple strings
enum_items.append(str(item)) enum_items.append(str(item))
return "\n".join(f"{item}" for item in enum_items) return "\n".join(f"{item}" for item in enum_items)
def format_constraints(param_type_def): def format_constraints(param_type_def):
""" """
Format validation constraints for display Format validation constraints for display.
Args:
param_type_def: Parameter type definition
Returns:
Formatted string describing constraints
""" """
constraints = [] constraints = []
# Handle numeric constraints
if "minimum" in param_type_def: if "minimum" in param_type_def:
constraints.append(f"min: {param_type_def['minimum']}") constraints.append(f"min: {param_type_def['minimum']}")
if "maximum" in param_type_def: if "maximum" in param_type_def:
constraints.append(f"max: {param_type_def['maximum']}") constraints.append(f"max: {param_type_def['maximum']}")
# Handle string constraints
if "minLength" in param_type_def: if "minLength" in param_type_def:
constraints.append(f"min length: {param_type_def['minLength']}") constraints.append(f"min length: {param_type_def['minLength']}")
if "maxLength" in param_type_def: if "maxLength" in param_type_def:
constraints.append(f"max length: {param_type_def['maxLength']}") constraints.append(f"max length: {param_type_def['maxLength']}")
if "pattern" in param_type_def: if "pattern" in param_type_def:
constraints.append(f"pattern: {param_type_def['pattern']}") constraints.append(f"pattern: {param_type_def['pattern']}")
# Handle required field
if param_type_def.get("required", False): if param_type_def.get("required", False):
constraints.append("required") constraints.append("required")
return ", ".join(constraints) if constraints else "None" return ", ".join(constraints) if constraints else "None"
def format_param_type(param_type_name, param_type_def):
    """Build the table rows describing one parameter type.

    Args:
        param_type_name: Name shown in the "name" row.
        param_type_def: Dict holding the parameter type definition.

    Returns:
        A list of ``(label, value)`` tuples. The "name", "description" and
        "type" rows are always present; "default", "valid values" and
        "constraints" are added only when the definition supplies them.
    """
    rows = [
        ("name", param_type_name),
        ("description", param_type_def.get("description", "")),
        ("type", param_type_def.get("type", "unknown")),
    ]

    # A default of None means "no default"; falsy defaults such as 0 or ""
    # are still shown.
    default_value = param_type_def.get("default")
    if default_value is not None:
        rows.append(("default", str(default_value)))

    valid_values = param_type_def.get("enum")
    if valid_values:
        rows.append(("valid values", format_enum_values(valid_values)))

    # format_constraints reports the absence of constraints as "None"
    constraint_text = format_constraints(param_type_def)
    if constraint_text != "None":
        rows.append(("constraints", constraint_text))

    return rows
async def fetch_all_param_types(client):
    """List every parameter type, then fetch all definitions in parallel.

    Args:
        client: Object exposing an awaitable ``_send_request`` taking three
            positional arguments and returning a dict response.

    Returns:
        ``(names, definitions)``: the listed names, and a dict mapping each
        name to its decoded definition. Names whose response has no values,
        or whose payload cannot be decoded, are missing from the dict.
        ``([], {})`` when nothing is listed.
    """
    listing = await client._send_request("config", None, {
        "operation": "list",
        "type": "parameter-type",
    })

    names = listing.get("directory", [])
    if not names:
        return [], {}

    # One "get" request per listed name, all awaited together
    replies = await asyncio.gather(*(
        client._send_request("config", None, {
            "operation": "get",
            "keys": [{"type": "parameter-type", "key": type_name}],
        })
        for type_name in names
    ))

    definitions = {}
    for type_name, reply in zip(names, replies):
        entries = reply.get("values", [])
        if not entries:
            continue
        try:
            definitions[type_name] = json.loads(entries[0].get("value", "{}"))
        except (json.JSONDecodeError, AttributeError):
            # Best effort: skip entries that cannot be decoded
            continue

    return names, definitions
async def fetch_single_param_type(client, param_type_name):
    """Fetch and decode the definition of one parameter type.

    Args:
        client: Object exposing an awaitable ``_send_request`` taking three
            positional arguments and returning a dict response.
        param_type_name: Key of the parameter type to request.

    Returns:
        The JSON-decoded definition, or None when the response carries no
        values.
    """
    reply = await client._send_request("config", None, {
        "operation": "get",
        "keys": [{"type": "parameter-type", "key": param_type_name}],
    })

    entries = reply.get("values", [])
    if not entries:
        return None

    # Only the first returned value is used; a missing "value" decodes to {}
    first_entry = entries[0]
    return json.loads(first_entry.get("value", "{}"))
def show_parameter_types(url, token=None): def show_parameter_types(url, token=None):
""" """Show all parameter type definitions."""
Show all parameter type definitions
"""
api = Api(url, token=token)
config_api = api.config()
# Get list of all parameter types async def _fetch():
try: async with AsyncSocketClient(url, timeout=60, token=token) as client:
param_type_names = config_api.list("parameter-type") return await fetch_all_param_types(client)
except Exception as e:
print(f"Error retrieving parameter types: {e}")
return
if len(param_type_names) == 0: param_type_names, param_type_defs = asyncio.run(_fetch())
if not param_type_names:
print("No parameter types defined.") print("No parameter types defined.")
return return
for param_type_name in param_type_names: for name in param_type_names:
try: if name not in param_type_defs:
# Get the parameter type definition print(f"Error retrieving parameter type '{name}'")
key = ConfigKey("parameter-type", param_type_name)
type_def_value = config_api.get([key])[0].value
param_type_def = json.loads(type_def_value)
table = []
table.append(("name", param_type_name))
table.append(("description", param_type_def.get("description", "")))
table.append(("type", param_type_def.get("type", "unknown")))
# Show default value if present
default = param_type_def.get("default")
if default is not None:
table.append(("default", str(default)))
# Show enum values if present
enum_list = param_type_def.get("enum")
if enum_list:
enum_str = format_enum_values(enum_list)
table.append(("valid values", enum_str))
# Show constraints
constraints = format_constraints(param_type_def)
if constraints != "None":
table.append(("constraints", constraints))
print(tabulate.tabulate(
table,
tablefmt="pretty",
stralign="left",
))
print() print()
continue
except Exception as e: table = format_param_type(name, param_type_defs[name])
print(f"Error retrieving parameter type '{param_type_name}': {e}")
print() print(tabulate.tabulate(
table,
tablefmt="pretty",
stralign="left",
))
print()
def show_specific_parameter_type(url, param_type_name, token=None):
    """Fetch one parameter type over the socket API and print it as a table.

    Prints an error line instead of a table when no definition is returned
    for param_type_name.
    """

    async def _load_definition():
        # The client is opened only for the duration of this single fetch
        async with AsyncSocketClient(url, token=token, timeout=60) as socket_client:
            return await fetch_single_param_type(socket_client, param_type_name)

    definition = asyncio.run(_load_definition())

    if definition is not None:
        rows = format_param_type(param_type_name, definition)
        rendered = tabulate.tabulate(
            rows,
            tablefmt="pretty",
            stralign="left",
        )
        print(rendered)
    else:
        print(f"Error retrieving parameter type '{param_type_name}'")
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -161,57 +202,12 @@ def main():
try: try:
if args.type: if args.type:
# Show specific parameter type
show_specific_parameter_type(args.api_url, args.type, args.token) show_specific_parameter_type(args.api_url, args.type, args.token)
else: else:
# Show all parameter types
show_parameter_types(args.api_url, args.token) show_parameter_types(args.api_url, args.token)
except Exception as e: except Exception as e:
print("Exception:", e, flush=True) print("Exception:", e, flush=True)
def show_specific_parameter_type(url, param_type_name, token=None):
"""
Show a specific parameter type definition
"""
api = Api(url, token=token)
config_api = api.config()
try:
# Get the parameter type definition
key = ConfigKey("parameter-type", param_type_name)
type_def_value = config_api.get([key])[0].value
param_type_def = json.loads(type_def_value)
table = []
table.append(("name", param_type_name))
table.append(("description", param_type_def.get("description", "")))
table.append(("type", param_type_def.get("type", "unknown")))
# Show default value if present
default = param_type_def.get("default")
if default is not None:
table.append(("default", str(default)))
# Show enum values if present
enum_list = param_type_def.get("enum")
if enum_list:
enum_str = format_enum_values(enum_list)
table.append(("valid values", enum_str))
# Show constraints
constraints = format_constraints(param_type_def)
if constraints != "None":
table.append(("constraints", constraints))
print(tabulate.tabulate(
table,
tablefmt="pretty",
stralign="left",
))
except Exception as e:
print(f"Error retrieving parameter type '{param_type_name}': {e}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View file

@ -178,7 +178,11 @@ def check_processors(url: str, min_processors: int, timeout: int, token: Optiona
url += '/' url += '/'
metrics_url = f"{url}api/metrics/query?query=processor_info" metrics_url = f"{url}api/metrics/query?query=processor_info"
resp = requests.get(metrics_url, timeout=timeout) headers = {}
if token:
headers["Authorization"] = f"Bearer {token}"
resp = requests.get(metrics_url, timeout=timeout, headers=headers)
if resp.status_code == 200: if resp.status_code == 200:
data = resp.json() data = resp.json()
processor_count = len(data.get("data", {}).get("result", [])) processor_count = len(data.get("data", {}).get("result", []))

View file

@ -33,9 +33,12 @@ class Mux:
async def receive(self, msg): async def receive(self, msg):
request_id = None
try: try:
data = msg.json() data = msg.json()
request_id = data.get("id")
if "request" not in data: if "request" not in data:
raise RuntimeError("Bad message") raise RuntimeError("Bad message")
@ -51,7 +54,13 @@ class Mux:
except Exception as e: except Exception as e:
logger.error(f"Receive exception: {str(e)}", exc_info=True) logger.error(f"Receive exception: {str(e)}", exc_info=True)
await self.ws.send_json({"error": str(e)}) error_resp = {
"error": {"message": str(e), "type": "error"},
"complete": True,
}
if request_id:
error_resp["id"] = request_id
await self.ws.send_json(error_resp)
async def maybe_tidy_workers(self, workers): async def maybe_tidy_workers(self, workers):
@ -97,12 +106,12 @@ class Mux:
}) })
worker = asyncio.create_task( worker = asyncio.create_task(
self.request_task(request, responder, flow, svc) self.request_task(id, request, responder, flow, svc)
) )
workers.append(worker) workers.append(worker)
async def request_task(self, request, responder, flow, svc): async def request_task(self, id, request, responder, flow, svc):
try: try:
@ -119,7 +128,11 @@ class Mux:
) )
except Exception as e: except Exception as e:
await self.ws.send_json({"error": str(e)}) await self.ws.send_json({
"id": id,
"error": {"message": str(e), "type": "error"},
"complete": True,
})
async def run(self): async def run(self):
@ -143,7 +156,11 @@ class Mux:
except Exception as e: except Exception as e:
# This is an internal working error, may not be recoverable # This is an internal working error, may not be recoverable
logger.error(f"Run prepare exception: {e}", exc_info=True) logger.error(f"Run prepare exception: {e}", exc_info=True)
await self.ws.send_json({"id": id, "error": str(e)}) await self.ws.send_json({
"id": id,
"error": {"message": str(e), "type": "error"},
"complete": True,
})
self.running.stop() self.running.stop()
if self.ws: if self.ws:
@ -160,7 +177,11 @@ class Mux:
except Exception as e: except Exception as e:
logger.error(f"Exception in mux: {e}", exc_info=True) logger.error(f"Exception in mux: {e}", exc_info=True)
await self.ws.send_json({"error": str(e)}) await self.ws.send_json({
"id": id,
"error": {"message": str(e), "type": "error"},
"complete": True,
})
self.running.stop() self.running.stop()

View file

@ -20,7 +20,7 @@ default_ident = "text-completion"
default_temperature = 0.0 default_temperature = 0.0
default_max_output = 4192 default_max_output = 4192
default_api = "2024-12-01-preview" default_api = os.getenv("AZURE_API_VERSION", "2024-12-01-preview")
default_endpoint = os.getenv("AZURE_ENDPOINT", None) default_endpoint = os.getenv("AZURE_ENDPOINT", None)
default_token = os.getenv("AZURE_TOKEN", None) default_token = os.getenv("AZURE_TOKEN", None)
default_model = os.getenv("AZURE_MODEL", None) default_model = os.getenv("AZURE_MODEL", None)
@ -90,7 +90,7 @@ class Processor(LlmService):
} }
], ],
temperature=effective_temperature, temperature=effective_temperature,
max_tokens=self.max_output, max_completion_tokens=self.max_output,
top_p=1, top_p=1,
) )
@ -159,7 +159,7 @@ class Processor(LlmService):
} }
], ],
temperature=effective_temperature, temperature=effective_temperature,
max_tokens=self.max_output, max_completion_tokens=self.max_output,
top_p=1, top_p=1,
stream=True, stream=True,
stream_options={"include_usage": True} stream_options={"include_usage": True}

View file

@ -86,7 +86,7 @@ class Processor(LlmService):
} }
], ],
temperature=effective_temperature, temperature=effective_temperature,
max_tokens=self.max_output, max_completion_tokens=self.max_output,
) )
inputtokens = resp.usage.prompt_tokens inputtokens = resp.usage.prompt_tokens
@ -152,7 +152,7 @@ class Processor(LlmService):
} }
], ],
temperature=effective_temperature, temperature=effective_temperature,
max_tokens=self.max_output, max_completion_tokens=self.max_output,
stream=True, stream=True,
stream_options={"include_usage": True} stream_options={"include_usage": True}
) )

View file

@ -24,9 +24,10 @@ from . tg_socket import WebSocketManager
class AppContext: class AppContext:
sockets: dict[str, WebSocketManager] sockets: dict[str, WebSocketManager]
websocket_url: str websocket_url: str
gateway_token: str
@asynccontextmanager @asynccontextmanager
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket") -> AsyncIterator[AppContext]: async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]:
""" """
Manage application lifecycle with type-safe context Manage application lifecycle with type-safe context
@ -36,7 +37,7 @@ async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8
sockets = {} sockets = {}
try: try:
yield AppContext(sockets=sockets, websocket_url=websocket_url) yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token)
finally: finally:
# Cleanup on shutdown # Cleanup on shutdown
@ -53,6 +54,7 @@ async def get_socket_manager(ctx, user):
lifespan_context = ctx.request_context.lifespan_context lifespan_context = ctx.request_context.lifespan_context
sockets = lifespan_context.sockets sockets = lifespan_context.sockets
websocket_url = lifespan_context.websocket_url websocket_url = lifespan_context.websocket_url
gateway_token = lifespan_context.gateway_token
if user in sockets: if user in sockets:
logging.info("Return existing socket manager") logging.info("Return existing socket manager")
@ -61,7 +63,7 @@ async def get_socket_manager(ctx, user):
logging.info(f"Opening socket to {websocket_url}...") logging.info(f"Opening socket to {websocket_url}...")
# Create manager with empty pending requests # Create manager with empty pending requests
manager = WebSocketManager(websocket_url) manager = WebSocketManager(websocket_url, token=gateway_token)
# Start reader task with the proper manager # Start reader task with the proper manager
await manager.start() await manager.start()
@ -193,13 +195,14 @@ class GetSystemPromptResponse:
prompt: str prompt: str
class McpServer: class McpServer:
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket"): def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""):
self.host = host self.host = host
self.port = port self.port = port
self.websocket_url = websocket_url self.websocket_url = websocket_url
self.gateway_token = gateway_token
# Create a partial function to pass websocket_url to app_lifespan # Create a partial function to pass websocket_url to app_lifespan
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url) lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token)
self.mcp = FastMCP( self.mcp = FastMCP(
"TrustGraph", dependencies=["trustgraph-base"], "TrustGraph", dependencies=["trustgraph-base"],
@ -2060,8 +2063,11 @@ def main():
# Setup logging before creating server # Setup logging before creating server
setup_logging(vars(args)) setup_logging(vars(args))
# Read gateway auth token from environment
gateway_token = os.environ.get("GATEWAY_SECRET", "")
# Create and run the MCP server # Create and run the MCP server
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url) server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token)
server.run() server.run()
def run(): def run():

View file

@ -1,6 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from websockets.asyncio.client import connect from websockets.asyncio.client import connect
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
import asyncio import asyncio
import logging import logging
import json import json
@ -9,12 +10,22 @@ import time
class WebSocketManager: class WebSocketManager:
def __init__(self, url): def __init__(self, url, token=None):
self.url = url self.url = url
self.token = token
self.socket = None self.socket = None
def _build_url(self):
if not self.token:
return self.url
parsed = urlparse(self.url)
params = parse_qs(parsed.query)
params["token"] = [self.token]
new_query = urlencode(params, doseq=True)
return urlunparse(parsed._replace(query=new_query))
async def start(self): async def start(self):
self.socket = await connect(self.url) self.socket = await connect(self._build_url())
self.pending_requests = {} self.pending_requests = {}
self.running = True self.running = True
self.reader_task = asyncio.create_task(self.reader()) self.reader_task = asyncio.create_task(self.reader())

View file

@ -14,7 +14,7 @@ dependencies = [
"pulsar-client", "pulsar-client",
"prometheus-client", "prometheus-client",
"python-magic", "python-magic",
"unstructured[csv,docx,epub,md,odt,pptx,rst,rtf,tsv,xlsx]", "unstructured[csv,docx,epub,md,odt,pdf,pptx,rst,rtf,tsv,xlsx]",
] ]
classifiers = [ classifiers = [
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",