mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-29 10:26:21 +02:00
release/v2.2 -> master (#733)
This commit is contained in:
parent
3ed71a5620
commit
2449392896
20 changed files with 774 additions and 1111 deletions
|
|
@ -7,7 +7,7 @@ FROM docker.io/fedora:42 AS base
|
||||||
|
|
||||||
ENV PIP_BREAK_SYSTEM_PACKAGES=1
|
ENV PIP_BREAK_SYSTEM_PACKAGES=1
|
||||||
|
|
||||||
RUN dnf install -y python3.13 && \
|
RUN dnf install -y python3.13 libxcb mesa-libGL && \
|
||||||
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
|
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
|
||||||
python -m ensurepip --upgrade && \
|
python -m ensurepip --upgrade && \
|
||||||
pip3 install --no-cache-dir build wheel aiohttp && \
|
pip3 install --no-cache-dir build wheel aiohttp && \
|
||||||
|
|
@ -38,6 +38,11 @@ RUN ls /root/wheels
|
||||||
|
|
||||||
FROM base
|
FROM base
|
||||||
|
|
||||||
|
# Pre-install CPU-only PyTorch so that unstructured[pdf]'s torch
|
||||||
|
# dependency is satisfied without pulling in CUDA (~190MB vs ~2GB+)
|
||||||
|
RUN pip3 install --no-cache-dir torch==2.11.0+cpu \
|
||||||
|
--index-url https://download.pytorch.org/whl/cpu
|
||||||
|
|
||||||
COPY --from=build /root/wheels /root/wheels
|
COPY --from=build /root/wheels /root/wheels
|
||||||
|
|
||||||
RUN \
|
RUN \
|
||||||
|
|
|
||||||
19
docs/contributor-licence-agreement.md
Normal file
19
docs/contributor-licence-agreement.md
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Contributor Licence Agreement (CLA)
|
||||||
|
|
||||||
|
We ask every contributor to sign a Contributor Licence Agreement before
|
||||||
|
we can merge a pull request. The CLA does **not** transfer copyright —
|
||||||
|
you keep full ownership of your work. It simply grants the TrustGraph
|
||||||
|
project a perpetual, royalty-free licence to distribute your
|
||||||
|
contribution under the project's
|
||||||
|
[Apache 2.0 licence](https://www.apache.org/licenses/LICENSE-2.0), and
|
||||||
|
confirms that you have the right to make the contribution. This protects
|
||||||
|
both the project and its users by ensuring every contribution has a
|
||||||
|
clear legal footing.
|
||||||
|
|
||||||
|
When you open a pull request, the CLA bot will post a comment asking you
|
||||||
|
to review and sign the appropriate agreement — it only takes a moment
|
||||||
|
and you only need to do it once across all TrustGraph repositories.
|
||||||
|
|
||||||
|
- Contributing as an **individual**? Sign the [Individual CLA](https://github.com/trustgraph-ai/contributor-license-agreement/blob/main/Fiduciary-Contributor-License-Agreement.md)
|
||||||
|
- Contributing on behalf of a **company or organisation**? Sign the [Entity CLA](https://github.com/trustgraph-ai/contributor-license-agreement/blob/main/Entity-Fiduciary-Contributor-License-Agreement.md)
|
||||||
|
|
||||||
|
|
@ -93,7 +93,7 @@ class TestTextCompletionIntegration:
|
||||||
|
|
||||||
assert call_args.kwargs['model'] == "gpt-3.5-turbo"
|
assert call_args.kwargs['model'] == "gpt-3.5-turbo"
|
||||||
assert call_args.kwargs['temperature'] == 0.7
|
assert call_args.kwargs['temperature'] == 0.7
|
||||||
assert call_args.kwargs['max_tokens'] == 1024
|
assert call_args.kwargs['max_completion_tokens'] == 1024
|
||||||
assert len(call_args.kwargs['messages']) == 1
|
assert len(call_args.kwargs['messages']) == 1
|
||||||
assert call_args.kwargs['messages'][0]['role'] == "user"
|
assert call_args.kwargs['messages'][0]['role'] == "user"
|
||||||
assert "You are a helpful assistant." in call_args.kwargs['messages'][0]['content'][0]['text']
|
assert "You are a helpful assistant." in call_args.kwargs['messages'][0]['content'][0]['text']
|
||||||
|
|
@ -134,7 +134,7 @@ class TestTextCompletionIntegration:
|
||||||
call_args = mock_openai_client.chat.completions.create.call_args
|
call_args = mock_openai_client.chat.completions.create.call_args
|
||||||
assert call_args.kwargs['model'] == config['model']
|
assert call_args.kwargs['model'] == config['model']
|
||||||
assert call_args.kwargs['temperature'] == config['temperature']
|
assert call_args.kwargs['temperature'] == config['temperature']
|
||||||
assert call_args.kwargs['max_tokens'] == config['max_output']
|
assert call_args.kwargs['max_completion_tokens'] == config['max_output']
|
||||||
|
|
||||||
# Reset mock for next iteration
|
# Reset mock for next iteration
|
||||||
mock_openai_client.reset_mock()
|
mock_openai_client.reset_mock()
|
||||||
|
|
@ -286,7 +286,7 @@ class TestTextCompletionIntegration:
|
||||||
# were removed in #561 as unnecessary parameters
|
# were removed in #561 as unnecessary parameters
|
||||||
assert 'model' in call_args.kwargs
|
assert 'model' in call_args.kwargs
|
||||||
assert 'temperature' in call_args.kwargs
|
assert 'temperature' in call_args.kwargs
|
||||||
assert 'max_tokens' in call_args.kwargs
|
assert 'max_completion_tokens' in call_args.kwargs
|
||||||
|
|
||||||
# Verify result structure
|
# Verify result structure
|
||||||
assert hasattr(result, 'text')
|
assert hasattr(result, 'text')
|
||||||
|
|
@ -362,7 +362,7 @@ class TestTextCompletionIntegration:
|
||||||
call_args = mock_openai_client.chat.completions.create.call_args
|
call_args = mock_openai_client.chat.completions.create.call_args
|
||||||
assert call_args.kwargs['model'] == "gpt-4"
|
assert call_args.kwargs['model'] == "gpt-4"
|
||||||
assert call_args.kwargs['temperature'] == 0.8
|
assert call_args.kwargs['temperature'] == 0.8
|
||||||
assert call_args.kwargs['max_tokens'] == 2048
|
assert call_args.kwargs['max_completion_tokens'] == 2048
|
||||||
# Note: top_p, frequency_penalty, and presence_penalty
|
# Note: top_p, frequency_penalty, and presence_penalty
|
||||||
# were removed in #561 as unnecessary parameters
|
# were removed in #561 as unnecessary parameters
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -201,7 +201,7 @@ class TestTextCompletionStreaming:
|
||||||
call_args = mock_streaming_openai_client.chat.completions.create.call_args
|
call_args = mock_streaming_openai_client.chat.completions.create.call_args
|
||||||
assert call_args.kwargs['model'] == "gpt-4"
|
assert call_args.kwargs['model'] == "gpt-4"
|
||||||
assert call_args.kwargs['temperature'] == 0.5
|
assert call_args.kwargs['temperature'] == 0.5
|
||||||
assert call_args.kwargs['max_tokens'] == 2048
|
assert call_args.kwargs['max_completion_tokens'] == 2048
|
||||||
assert call_args.kwargs['stream'] is True
|
assert call_args.kwargs['stream'] is True
|
||||||
|
|
||||||
# Verify chunks have correct model
|
# Verify chunks have correct model
|
||||||
|
|
|
||||||
|
|
@ -121,7 +121,11 @@ class TestMux:
|
||||||
# Based on the code, it seems to catch exceptions
|
# Based on the code, it seems to catch exceptions
|
||||||
await mux.receive(mock_msg)
|
await mux.receive(mock_msg)
|
||||||
|
|
||||||
mock_ws.send_json.assert_called_once_with({"error": "Bad message"})
|
mock_ws.send_json.assert_called_once_with({
|
||||||
|
"error": {"message": "Bad message", "type": "error"},
|
||||||
|
"complete": True,
|
||||||
|
"id": "test-id-123",
|
||||||
|
})
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mux_receive_message_without_id(self):
|
async def test_mux_receive_message_without_id(self):
|
||||||
|
|
@ -145,7 +149,10 @@ class TestMux:
|
||||||
# receive method should handle the RuntimeError internally
|
# receive method should handle the RuntimeError internally
|
||||||
await mux.receive(mock_msg)
|
await mux.receive(mock_msg)
|
||||||
|
|
||||||
mock_ws.send_json.assert_called_once_with({"error": "Bad message"})
|
mock_ws.send_json.assert_called_once_with({
|
||||||
|
"error": {"message": "Bad message", "type": "error"},
|
||||||
|
"complete": True,
|
||||||
|
})
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mux_receive_invalid_json(self):
|
async def test_mux_receive_invalid_json(self):
|
||||||
|
|
@ -168,4 +175,7 @@ class TestMux:
|
||||||
await mux.receive(mock_msg)
|
await mux.receive(mock_msg)
|
||||||
|
|
||||||
mock_msg.json.assert_called_once()
|
mock_msg.json.assert_called_once()
|
||||||
mock_ws.send_json.assert_called_once_with({"error": "Invalid JSON"})
|
mock_ws.send_json.assert_called_once_with({
|
||||||
|
"error": {"message": "Invalid JSON", "type": "error"},
|
||||||
|
"complete": True,
|
||||||
|
})
|
||||||
|
|
@ -108,7 +108,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
}]
|
}]
|
||||||
}],
|
}],
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
max_tokens=4192,
|
max_completion_tokens=4192,
|
||||||
top_p=1
|
top_p=1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -399,7 +399,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
# Verify other parameters
|
# Verify other parameters
|
||||||
assert call_args[1]['model'] == 'gpt-4'
|
assert call_args[1]['model'] == 'gpt-4'
|
||||||
assert call_args[1]['temperature'] == 0.5
|
assert call_args[1]['temperature'] == 0.5
|
||||||
assert call_args[1]['max_tokens'] == 1024
|
assert call_args[1]['max_completion_tokens'] == 1024
|
||||||
assert call_args[1]['top_p'] == 1
|
assert call_args[1]['top_p'] == 1
|
||||||
|
|
||||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||||
|
|
|
||||||
|
|
@ -102,7 +102,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
}]
|
}]
|
||||||
}],
|
}],
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
max_tokens=4096
|
max_completion_tokens=4096
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||||
|
|
@ -380,7 +380,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
# Verify other parameters
|
# Verify other parameters
|
||||||
assert call_args[1]['model'] == 'gpt-3.5-turbo'
|
assert call_args[1]['model'] == 'gpt-3.5-turbo'
|
||||||
assert call_args[1]['temperature'] == 0.5
|
assert call_args[1]['temperature'] == 0.5
|
||||||
assert call_args[1]['max_tokens'] == 1024
|
assert call_args[1]['max_completion_tokens'] == 1024
|
||||||
|
|
||||||
|
|
||||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import asyncio
|
||||||
import websockets
|
import websockets
|
||||||
from typing import Optional, Dict, Any, AsyncIterator, Union
|
from typing import Optional, Dict, Any, AsyncIterator, Union
|
||||||
|
|
||||||
|
|
@ -8,13 +9,29 @@ from . exceptions import ProtocolException, ApplicationException
|
||||||
|
|
||||||
|
|
||||||
class AsyncSocketClient:
|
class AsyncSocketClient:
|
||||||
"""Asynchronous WebSocket client"""
|
"""Asynchronous WebSocket client with persistent connection.
|
||||||
|
|
||||||
|
Maintains a single websocket connection and multiplexes requests
|
||||||
|
by ID, routing responses via a background reader task.
|
||||||
|
|
||||||
|
Use as an async context manager for proper lifecycle management:
|
||||||
|
|
||||||
|
async with AsyncSocketClient(url, timeout, token) as client:
|
||||||
|
result = await client._send_request(...)
|
||||||
|
|
||||||
|
Or call connect()/aclose() manually.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, url: str, timeout: int, token: Optional[str]):
|
def __init__(self, url: str, timeout: int, token: Optional[str]):
|
||||||
self.url = self._convert_to_ws_url(url)
|
self.url = self._convert_to_ws_url(url)
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.token = token
|
self.token = token
|
||||||
self._request_counter = 0
|
self._request_counter = 0
|
||||||
|
self._socket = None
|
||||||
|
self._connect_cm = None
|
||||||
|
self._reader_task = None
|
||||||
|
self._pending = {} # request_id -> asyncio.Queue
|
||||||
|
self._connected = False
|
||||||
|
|
||||||
def _convert_to_ws_url(self, url: str) -> str:
|
def _convert_to_ws_url(self, url: str) -> str:
|
||||||
"""Convert HTTP URL to WebSocket URL"""
|
"""Convert HTTP URL to WebSocket URL"""
|
||||||
|
|
@ -25,82 +42,123 @@ class AsyncSocketClient:
|
||||||
elif url.startswith("ws://") or url.startswith("wss://"):
|
elif url.startswith("ws://") or url.startswith("wss://"):
|
||||||
return url
|
return url
|
||||||
else:
|
else:
|
||||||
# Assume ws://
|
|
||||||
return f"ws://{url}"
|
return f"ws://{url}"
|
||||||
|
|
||||||
|
def _build_ws_url(self):
|
||||||
|
ws_url = f"{self.url.rstrip('/')}/api/v1/socket"
|
||||||
|
if self.token:
|
||||||
|
ws_url = f"{ws_url}?token={self.token}"
|
||||||
|
return ws_url
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
"""Establish the persistent websocket connection."""
|
||||||
|
if self._connected:
|
||||||
|
return
|
||||||
|
|
||||||
|
ws_url = self._build_ws_url()
|
||||||
|
self._connect_cm = websockets.connect(
|
||||||
|
ws_url, ping_interval=20, ping_timeout=self.timeout
|
||||||
|
)
|
||||||
|
self._socket = await self._connect_cm.__aenter__()
|
||||||
|
self._connected = True
|
||||||
|
self._reader_task = asyncio.create_task(self._reader())
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
await self.connect()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
await self.aclose()
|
||||||
|
|
||||||
|
async def _ensure_connected(self):
|
||||||
|
"""Lazily connect if not already connected."""
|
||||||
|
if not self._connected:
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
async def _reader(self):
|
||||||
|
"""Background task to read responses and route by request ID."""
|
||||||
|
try:
|
||||||
|
async for raw_message in self._socket:
|
||||||
|
response = json.loads(raw_message)
|
||||||
|
request_id = response.get("id")
|
||||||
|
|
||||||
|
if request_id and request_id in self._pending:
|
||||||
|
await self._pending[request_id].put(response)
|
||||||
|
# Ignore messages for unknown request IDs
|
||||||
|
|
||||||
|
except websockets.exceptions.ConnectionClosed:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
# Signal error to all pending requests
|
||||||
|
for queue in self._pending.values():
|
||||||
|
try:
|
||||||
|
await queue.put({"error": str(e)})
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self._connected = False
|
||||||
|
|
||||||
|
def _next_request_id(self):
|
||||||
|
self._request_counter += 1
|
||||||
|
return f"req-{self._request_counter}"
|
||||||
|
|
||||||
def flow(self, flow_id: str):
|
def flow(self, flow_id: str):
|
||||||
"""Get async flow instance for WebSocket operations"""
|
"""Get async flow instance for WebSocket operations"""
|
||||||
return AsyncSocketFlowInstance(self, flow_id)
|
return AsyncSocketFlowInstance(self, flow_id)
|
||||||
|
|
||||||
async def _send_request(self, service: str, flow: Optional[str], request: Dict[str, Any]):
|
async def _send_request(self, service: str, flow: Optional[str], request: Dict[str, Any]):
|
||||||
"""Async WebSocket request implementation (non-streaming)"""
|
"""Send a request and wait for a single response."""
|
||||||
# Generate unique request ID
|
await self._ensure_connected()
|
||||||
self._request_counter += 1
|
|
||||||
request_id = f"req-{self._request_counter}"
|
|
||||||
|
|
||||||
# Build WebSocket URL with optional token
|
request_id = self._next_request_id()
|
||||||
ws_url = f"{self.url}/api/v1/socket"
|
queue = asyncio.Queue()
|
||||||
if self.token:
|
self._pending[request_id] = queue
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
# Build request message
|
try:
|
||||||
message = {
|
message = {
|
||||||
"id": request_id,
|
"id": request_id,
|
||||||
"service": service,
|
"service": service,
|
||||||
"request": request
|
"request": request
|
||||||
}
|
}
|
||||||
if flow:
|
if flow:
|
||||||
message["flow"] = flow
|
message["flow"] = flow
|
||||||
|
|
||||||
# Connect and send request
|
await self._socket.send(json.dumps(message))
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
|
||||||
await websocket.send(json.dumps(message))
|
|
||||||
|
|
||||||
# Wait for single response
|
response = await queue.get()
|
||||||
raw_message = await websocket.recv()
|
|
||||||
response = json.loads(raw_message)
|
|
||||||
|
|
||||||
if response.get("id") != request_id:
|
|
||||||
raise ProtocolException(f"Response ID mismatch")
|
|
||||||
|
|
||||||
if "error" in response:
|
if "error" in response:
|
||||||
raise ApplicationException(response["error"])
|
raise ApplicationException(response["error"])
|
||||||
|
|
||||||
if "response" not in response:
|
if "response" not in response:
|
||||||
raise ProtocolException(f"Missing response in message")
|
raise ProtocolException("Missing response in message")
|
||||||
|
|
||||||
return response["response"]
|
return response["response"]
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self._pending.pop(request_id, None)
|
||||||
|
|
||||||
async def _send_request_streaming(self, service: str, flow: Optional[str], request: Dict[str, Any]):
|
async def _send_request_streaming(self, service: str, flow: Optional[str], request: Dict[str, Any]):
|
||||||
"""Async WebSocket request implementation (streaming)"""
|
"""Send a request and yield streaming response chunks."""
|
||||||
# Generate unique request ID
|
await self._ensure_connected()
|
||||||
self._request_counter += 1
|
|
||||||
request_id = f"req-{self._request_counter}"
|
|
||||||
|
|
||||||
# Build WebSocket URL with optional token
|
request_id = self._next_request_id()
|
||||||
ws_url = f"{self.url}/api/v1/socket"
|
queue = asyncio.Queue()
|
||||||
if self.token:
|
self._pending[request_id] = queue
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
# Build request message
|
try:
|
||||||
message = {
|
message = {
|
||||||
"id": request_id,
|
"id": request_id,
|
||||||
"service": service,
|
"service": service,
|
||||||
"request": request
|
"request": request
|
||||||
}
|
}
|
||||||
if flow:
|
if flow:
|
||||||
message["flow"] = flow
|
message["flow"] = flow
|
||||||
|
|
||||||
# Connect and send request
|
await self._socket.send(json.dumps(message))
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
|
||||||
await websocket.send(json.dumps(message))
|
|
||||||
|
|
||||||
# Yield chunks as they arrive
|
while True:
|
||||||
async for raw_message in websocket:
|
response = await queue.get()
|
||||||
response = json.loads(raw_message)
|
|
||||||
|
|
||||||
if response.get("id") != request_id:
|
|
||||||
continue # Ignore messages for other requests
|
|
||||||
|
|
||||||
if "error" in response:
|
if "error" in response:
|
||||||
raise ApplicationException(response["error"])
|
raise ApplicationException(response["error"])
|
||||||
|
|
@ -108,18 +166,16 @@ class AsyncSocketClient:
|
||||||
if "response" in response:
|
if "response" in response:
|
||||||
resp = response["response"]
|
resp = response["response"]
|
||||||
|
|
||||||
# Parse different chunk types
|
|
||||||
chunk = self._parse_chunk(resp)
|
chunk = self._parse_chunk(resp)
|
||||||
if chunk is not None: # Skip provenance messages in streaming
|
if chunk is not None:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
# Check if this is the final message
|
|
||||||
# end_of_session indicates entire session is complete (including provenance)
|
|
||||||
# end_of_dialog is for agent dialogs
|
|
||||||
# complete is from the gateway envelope
|
|
||||||
if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"):
|
if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self._pending.pop(request_id, None)
|
||||||
|
|
||||||
def _parse_chunk(self, resp: Dict[str, Any]):
|
def _parse_chunk(self, resp: Dict[str, Any]):
|
||||||
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
|
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
|
||||||
chunk_type = resp.get("chunk_type")
|
chunk_type = resp.get("chunk_type")
|
||||||
|
|
@ -127,7 +183,6 @@ class AsyncSocketClient:
|
||||||
|
|
||||||
# Handle new GraphRAG message format with message_type
|
# Handle new GraphRAG message format with message_type
|
||||||
if message_type == "provenance":
|
if message_type == "provenance":
|
||||||
# Provenance messages are not yielded to user - they're metadata
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if chunk_type == "thought":
|
if chunk_type == "thought":
|
||||||
|
|
@ -147,25 +202,41 @@ class AsyncSocketClient:
|
||||||
end_of_dialog=resp.get("end_of_dialog", False)
|
end_of_dialog=resp.get("end_of_dialog", False)
|
||||||
)
|
)
|
||||||
elif chunk_type == "action":
|
elif chunk_type == "action":
|
||||||
# Agent action chunks - treat as thoughts for display purposes
|
|
||||||
return AgentThought(
|
return AgentThought(
|
||||||
content=resp.get("content", ""),
|
content=resp.get("content", ""),
|
||||||
end_of_message=resp.get("end_of_message", False)
|
end_of_message=resp.get("end_of_message", False)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# RAG-style chunk (or generic chunk with message_type="chunk")
|
|
||||||
# Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
|
|
||||||
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
|
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
|
||||||
return RAGChunk(
|
return RAGChunk(
|
||||||
content=content,
|
content=content,
|
||||||
end_of_stream=resp.get("end_of_stream", False),
|
end_of_stream=resp.get("end_of_stream", False),
|
||||||
error=None # Errors are always thrown, never stored
|
error=None
|
||||||
)
|
)
|
||||||
|
|
||||||
async def aclose(self):
|
async def aclose(self):
|
||||||
"""Close WebSocket connection"""
|
"""Close the persistent WebSocket connection cleanly."""
|
||||||
# Cleanup handled by context manager
|
# Wait for reader to finish (socket close will cause it to exit)
|
||||||
pass
|
if self._reader_task:
|
||||||
|
self._reader_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._reader_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._reader_task = None
|
||||||
|
|
||||||
|
# Exit the websockets context manager — this cleanly shuts down
|
||||||
|
# the connection and its keepalive task
|
||||||
|
if self._connect_cm:
|
||||||
|
try:
|
||||||
|
await self._connect_cm.__aexit__(None, None, None)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._connect_cm = None
|
||||||
|
|
||||||
|
self._socket = None
|
||||||
|
self._connected = False
|
||||||
|
self._pending.clear()
|
||||||
|
|
||||||
|
|
||||||
class AsyncSocketFlowInstance:
|
class AsyncSocketFlowInstance:
|
||||||
|
|
@ -292,7 +363,6 @@ class AsyncSocketFlowInstance:
|
||||||
|
|
||||||
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
|
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
|
||||||
"""Query graph embeddings for semantic search"""
|
"""Query graph embeddings for semantic search"""
|
||||||
# First convert text to embedding vector
|
|
||||||
emb_result = await self.embeddings(texts=[text])
|
emb_result = await self.embeddings(texts=[text])
|
||||||
vector = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
|
|
@ -362,7 +432,6 @@ class AsyncSocketFlowInstance:
|
||||||
limit: int = 10, **kwargs
|
limit: int = 10, **kwargs
|
||||||
):
|
):
|
||||||
"""Query row embeddings for semantic search on structured data"""
|
"""Query row embeddings for semantic search on structured data"""
|
||||||
# First convert text to embedding vector
|
|
||||||
emb_result = await self.embeddings(texts=[text])
|
emb_result = await self.embeddings(texts=[text])
|
||||||
vector = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -58,6 +58,14 @@ def print_json(sessions):
|
||||||
print(json.dumps(sessions, indent=2))
|
print(json.dumps(sessions, indent=2))
|
||||||
|
|
||||||
|
|
||||||
|
# Map type names for display
|
||||||
|
TYPE_DISPLAY = {
|
||||||
|
"graphrag": "GraphRAG",
|
||||||
|
"docrag": "DocRAG",
|
||||||
|
"agent": "Agent",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
prog='tg-list-explain-traces',
|
prog='tg-list-explain-traces',
|
||||||
|
|
@ -118,7 +126,7 @@ def main():
|
||||||
explain_client = ExplainabilityClient(flow)
|
explain_client = ExplainabilityClient(flow)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# List all sessions using the API
|
# List all sessions — uses persistent websocket via SocketClient
|
||||||
questions = explain_client.list_sessions(
|
questions = explain_client.list_sessions(
|
||||||
graph=RETRIEVAL_GRAPH,
|
graph=RETRIEVAL_GRAPH,
|
||||||
user=args.user,
|
user=args.user,
|
||||||
|
|
@ -126,7 +134,8 @@ def main():
|
||||||
limit=args.limit,
|
limit=args.limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert to output format
|
# detect_session_type is mostly a fast URI pattern check,
|
||||||
|
# only falls back to network calls for unrecognised URIs
|
||||||
sessions = []
|
sessions = []
|
||||||
for q in questions:
|
for q in questions:
|
||||||
session_type = explain_client.detect_session_type(
|
session_type = explain_client.detect_session_type(
|
||||||
|
|
@ -136,16 +145,9 @@ def main():
|
||||||
collection=args.collection
|
collection=args.collection
|
||||||
)
|
)
|
||||||
|
|
||||||
# Map type names
|
|
||||||
type_display = {
|
|
||||||
"graphrag": "GraphRAG",
|
|
||||||
"docrag": "DocRAG",
|
|
||||||
"agent": "Agent",
|
|
||||||
}.get(session_type, session_type.title())
|
|
||||||
|
|
||||||
sessions.append({
|
sessions.append({
|
||||||
"id": q.uri,
|
"id": q.uri,
|
||||||
"type": type_display,
|
"type": TYPE_DISPLAY.get(session_type, session_type.title()),
|
||||||
"question": q.query,
|
"question": q.query,
|
||||||
"time": q.timestamp,
|
"time": q.timestamp,
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -3,31 +3,27 @@ Shows all defined flow blueprints.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import tabulate
|
import tabulate
|
||||||
from trustgraph.api import Api, ConfigKey
|
from trustgraph.api import AsyncSocketClient
|
||||||
import json
|
import json
|
||||||
|
|
||||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||||
|
|
||||||
def format_parameters(params_metadata, config_api):
|
def format_parameters(params_metadata, param_type_defs):
|
||||||
"""
|
"""
|
||||||
Format parameter metadata for display
|
Format parameter metadata for display.
|
||||||
|
|
||||||
Args:
|
param_type_defs is a dict of type_name -> parsed type definition,
|
||||||
params_metadata: Parameter definitions from flow blueprint
|
pre-fetched concurrently.
|
||||||
config_api: API client to get parameter type information
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted string describing parameters
|
|
||||||
"""
|
"""
|
||||||
if not params_metadata:
|
if not params_metadata:
|
||||||
return "None"
|
return "None"
|
||||||
|
|
||||||
param_list = []
|
param_list = []
|
||||||
|
|
||||||
# Sort parameters by order if available
|
|
||||||
sorted_params = sorted(
|
sorted_params = sorted(
|
||||||
params_metadata.items(),
|
params_metadata.items(),
|
||||||
key=lambda x: x[1].get("order", 999)
|
key=lambda x: x[1].get("order", 999)
|
||||||
|
|
@ -37,41 +33,89 @@ def format_parameters(params_metadata, config_api):
|
||||||
description = param_meta.get("description", param_name)
|
description = param_meta.get("description", param_name)
|
||||||
param_type = param_meta.get("type", "unknown")
|
param_type = param_meta.get("type", "unknown")
|
||||||
|
|
||||||
# Get type information if available
|
|
||||||
type_info = param_type
|
type_info = param_type
|
||||||
if config_api:
|
if param_type in param_type_defs:
|
||||||
try:
|
param_type_def = param_type_defs[param_type]
|
||||||
key = ConfigKey("parameter-type", param_type)
|
default = param_type_def.get("default")
|
||||||
type_def_value = config_api.get([key])[0].value
|
if default is not None:
|
||||||
param_type_def = json.loads(type_def_value)
|
type_info = f"{param_type} (default: {default})"
|
||||||
|
|
||||||
# Add default value if available
|
|
||||||
default = param_type_def.get("default")
|
|
||||||
if default is not None:
|
|
||||||
type_info = f"{param_type} (default: {default})"
|
|
||||||
|
|
||||||
except:
|
|
||||||
# If we can't get type definition, just show the type name
|
|
||||||
pass
|
|
||||||
|
|
||||||
param_list.append(f" {param_name}: {description} [{type_info}]")
|
param_list.append(f" {param_name}: {description} [{type_info}]")
|
||||||
|
|
||||||
return "\n".join(param_list)
|
return "\n".join(param_list)
|
||||||
|
|
||||||
|
async def fetch_data(client):
|
||||||
|
"""Fetch all data needed for show_flow_blueprints concurrently."""
|
||||||
|
|
||||||
|
# Round 1: list blueprints
|
||||||
|
resp = await client._send_request("flow", None, {
|
||||||
|
"operation": "list-blueprints",
|
||||||
|
})
|
||||||
|
blueprint_names = resp.get("blueprint-names", [])
|
||||||
|
|
||||||
|
if not blueprint_names:
|
||||||
|
return [], {}, {}
|
||||||
|
|
||||||
|
# Round 2: get all blueprints in parallel
|
||||||
|
blueprint_tasks = [
|
||||||
|
client._send_request("flow", None, {
|
||||||
|
"operation": "get-blueprint",
|
||||||
|
"blueprint-name": name,
|
||||||
|
})
|
||||||
|
for name in blueprint_names
|
||||||
|
]
|
||||||
|
blueprint_results = await asyncio.gather(*blueprint_tasks)
|
||||||
|
|
||||||
|
blueprints = {}
|
||||||
|
for name, resp in zip(blueprint_names, blueprint_results):
|
||||||
|
bp_data = resp.get("blueprint-definition", "{}")
|
||||||
|
blueprints[name] = json.loads(bp_data) if isinstance(bp_data, str) else bp_data
|
||||||
|
|
||||||
|
# Round 3: get all parameter type definitions in parallel
|
||||||
|
param_types_needed = set()
|
||||||
|
for bp in blueprints.values():
|
||||||
|
for param_meta in bp.get("parameters", {}).values():
|
||||||
|
pt = param_meta.get("type", "")
|
||||||
|
if pt:
|
||||||
|
param_types_needed.add(pt)
|
||||||
|
|
||||||
|
param_type_defs = {}
|
||||||
|
if param_types_needed:
|
||||||
|
param_type_tasks = [
|
||||||
|
client._send_request("config", None, {
|
||||||
|
"operation": "get",
|
||||||
|
"keys": [{"type": "parameter-type", "key": pt}],
|
||||||
|
})
|
||||||
|
for pt in param_types_needed
|
||||||
|
]
|
||||||
|
param_type_results = await asyncio.gather(*param_type_tasks)
|
||||||
|
|
||||||
|
for pt, resp in zip(param_types_needed, param_type_results):
|
||||||
|
values = resp.get("values", [])
|
||||||
|
if values:
|
||||||
|
try:
|
||||||
|
param_type_defs[pt] = json.loads(values[0].get("value", "{}"))
|
||||||
|
except (json.JSONDecodeError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return blueprint_names, blueprints, param_type_defs
|
||||||
|
|
||||||
|
async def _show_flow_blueprints_async(url, token=None):
|
||||||
|
async with AsyncSocketClient(url, timeout=60, token=token) as client:
|
||||||
|
return await fetch_data(client)
|
||||||
|
|
||||||
def show_flow_blueprints(url, token=None):
|
def show_flow_blueprints(url, token=None):
|
||||||
|
|
||||||
api = Api(url, token=token)
|
blueprint_names, blueprints, param_type_defs = asyncio.run(
|
||||||
flow_api = api.flow()
|
_show_flow_blueprints_async(url, token=token)
|
||||||
config_api = api.config()
|
)
|
||||||
|
|
||||||
blueprint_names = flow_api.list_blueprints()
|
if not blueprint_names:
|
||||||
|
|
||||||
if len(blueprint_names) == 0:
|
|
||||||
print("No flow blueprints.")
|
print("No flow blueprints.")
|
||||||
return
|
return
|
||||||
|
|
||||||
for blueprint_name in blueprint_names:
|
for blueprint_name in blueprint_names:
|
||||||
cls = flow_api.get_blueprint(blueprint_name)
|
cls = blueprints[blueprint_name]
|
||||||
|
|
||||||
table = []
|
table = []
|
||||||
table.append(("name", blueprint_name))
|
table.append(("name", blueprint_name))
|
||||||
|
|
@ -81,10 +125,9 @@ def show_flow_blueprints(url, token=None):
|
||||||
if tags:
|
if tags:
|
||||||
table.append(("tags", ", ".join(tags)))
|
table.append(("tags", ", ".join(tags)))
|
||||||
|
|
||||||
# Show parameters if they exist
|
|
||||||
parameters = cls.get("parameters", {})
|
parameters = cls.get("parameters", {})
|
||||||
if parameters:
|
if parameters:
|
||||||
param_str = format_parameters(parameters, config_api)
|
param_str = format_parameters(parameters, param_type_defs)
|
||||||
table.append(("parameters", param_str))
|
table.append(("parameters", param_str))
|
||||||
|
|
||||||
print(tabulate.tabulate(
|
print(tabulate.tabulate(
|
||||||
|
|
|
||||||
|
|
@ -3,22 +3,15 @@ Shows configured flows.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import tabulate
|
import tabulate
|
||||||
from trustgraph.api import Api, ConfigKey
|
from trustgraph.api import Api, AsyncSocketClient
|
||||||
import json
|
import json
|
||||||
|
|
||||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||||
|
|
||||||
def get_interface(config_api, i):
|
|
||||||
|
|
||||||
key = ConfigKey("interface-description", i)
|
|
||||||
|
|
||||||
value = config_api.get([key])[0].value
|
|
||||||
|
|
||||||
return json.loads(value)
|
|
||||||
|
|
||||||
def describe_interfaces(intdefs, flow):
|
def describe_interfaces(intdefs, flow):
|
||||||
|
|
||||||
intfs = flow.get("interfaces", {})
|
intfs = flow.get("interfaces", {})
|
||||||
|
|
@ -34,7 +27,7 @@ def describe_interfaces(intdefs, flow):
|
||||||
|
|
||||||
if kind == "request-response":
|
if kind == "request-response":
|
||||||
req = intfs[k]["request"]
|
req = intfs[k]["request"]
|
||||||
resp = intfs[k]["request"]
|
resp = intfs[k]["response"]
|
||||||
|
|
||||||
lst.append(f"{k} request: {req}")
|
lst.append(f"{k} request: {req}")
|
||||||
lst.append(f"{k} response: {resp}")
|
lst.append(f"{k} response: {resp}")
|
||||||
|
|
@ -49,17 +42,9 @@ def describe_interfaces(intdefs, flow):
|
||||||
def get_enum_description(param_value, param_type_def):
|
def get_enum_description(param_value, param_type_def):
|
||||||
"""
|
"""
|
||||||
Get the human-readable description for an enum value
|
Get the human-readable description for an enum value
|
||||||
|
|
||||||
Args:
|
|
||||||
param_value: The actual parameter value (e.g., "gpt-4")
|
|
||||||
param_type_def: The parameter type definition containing enum objects
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Human-readable description or the original value if not found
|
|
||||||
"""
|
"""
|
||||||
enum_list = param_type_def.get("enum", [])
|
enum_list = param_type_def.get("enum", [])
|
||||||
|
|
||||||
# Handle both old format (strings) and new format (objects with id/description)
|
|
||||||
for enum_item in enum_list:
|
for enum_item in enum_list:
|
||||||
if isinstance(enum_item, dict):
|
if isinstance(enum_item, dict):
|
||||||
if enum_item.get("id") == param_value:
|
if enum_item.get("id") == param_value:
|
||||||
|
|
@ -67,27 +52,20 @@ def get_enum_description(param_value, param_type_def):
|
||||||
elif enum_item == param_value:
|
elif enum_item == param_value:
|
||||||
return param_value
|
return param_value
|
||||||
|
|
||||||
# If not found in enum, return original value
|
|
||||||
return param_value
|
return param_value
|
||||||
|
|
||||||
def format_parameters(flow_params, blueprint_params_metadata, config_api):
|
def format_parameters(flow_params, blueprint_params_metadata, param_type_defs):
|
||||||
"""
|
"""
|
||||||
Format flow parameters with their human-readable descriptions
|
Format flow parameters with their human-readable descriptions.
|
||||||
|
|
||||||
Args:
|
param_type_defs is a dict of type_name -> parsed type definition,
|
||||||
flow_params: The actual parameter values used in the flow
|
pre-fetched concurrently.
|
||||||
blueprint_params_metadata: The parameter metadata from the flow blueprint definition
|
|
||||||
config_api: API client to retrieve parameter type definitions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted string of parameters with descriptions
|
|
||||||
"""
|
"""
|
||||||
if not flow_params:
|
if not flow_params:
|
||||||
return "None"
|
return "None"
|
||||||
|
|
||||||
param_list = []
|
param_list = []
|
||||||
|
|
||||||
# Sort parameters by order if available
|
|
||||||
sorted_params = sorted(
|
sorted_params = sorted(
|
||||||
blueprint_params_metadata.items(),
|
blueprint_params_metadata.items(),
|
||||||
key=lambda x: x[1].get("order", 999)
|
key=lambda x: x[1].get("order", 999)
|
||||||
|
|
@ -100,80 +78,165 @@ def format_parameters(flow_params, blueprint_params_metadata, config_api):
|
||||||
param_type = param_meta.get("type", "")
|
param_type = param_meta.get("type", "")
|
||||||
controlled_by = param_meta.get("controlled-by", None)
|
controlled_by = param_meta.get("controlled-by", None)
|
||||||
|
|
||||||
# Try to get enum description if this parameter has a type definition
|
|
||||||
display_value = value
|
display_value = value
|
||||||
if param_type and config_api:
|
if param_type and param_type in param_type_defs:
|
||||||
try:
|
display_value = get_enum_description(
|
||||||
from trustgraph.api import ConfigKey
|
value, param_type_defs[param_type]
|
||||||
key = ConfigKey("parameter-type", param_type)
|
)
|
||||||
type_def_value = config_api.get([key])[0].value
|
|
||||||
param_type_def = json.loads(type_def_value)
|
|
||||||
display_value = get_enum_description(value, param_type_def)
|
|
||||||
except:
|
|
||||||
# If we can't get the type definition, just use the original value
|
|
||||||
display_value = value
|
|
||||||
|
|
||||||
# Format the parameter line
|
|
||||||
line = f"• {description}: {display_value}"
|
line = f"• {description}: {display_value}"
|
||||||
|
|
||||||
# Add controlled-by indicator if present
|
|
||||||
if controlled_by:
|
if controlled_by:
|
||||||
line += f" (controlled by {controlled_by})"
|
line += f" (controlled by {controlled_by})"
|
||||||
|
|
||||||
param_list.append(line)
|
param_list.append(line)
|
||||||
|
|
||||||
# Add any parameters that aren't in the blueprint metadata (shouldn't happen normally)
|
|
||||||
for param_name, value in flow_params.items():
|
for param_name, value in flow_params.items():
|
||||||
if param_name not in blueprint_params_metadata:
|
if param_name not in blueprint_params_metadata:
|
||||||
param_list.append(f"• {param_name}: {value} (undefined)")
|
param_list.append(f"• {param_name}: {value} (undefined)")
|
||||||
|
|
||||||
return "\n".join(param_list) if param_list else "None"
|
return "\n".join(param_list) if param_list else "None"
|
||||||
|
|
||||||
|
async def fetch_show_flows(client):
|
||||||
|
"""Fetch all data needed for show_flows concurrently."""
|
||||||
|
|
||||||
|
# Round 1: list interfaces and list flows in parallel
|
||||||
|
interface_names_resp, flow_ids_resp = await asyncio.gather(
|
||||||
|
client._send_request("config", None, {
|
||||||
|
"operation": "list",
|
||||||
|
"type": "interface-description",
|
||||||
|
}),
|
||||||
|
client._send_request("flow", None, {
|
||||||
|
"operation": "list-flows",
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
interface_names = interface_names_resp.get("directory", [])
|
||||||
|
flow_ids = flow_ids_resp.get("flow-ids", [])
|
||||||
|
|
||||||
|
if not flow_ids:
|
||||||
|
return {}, [], {}, {}
|
||||||
|
|
||||||
|
# Round 2: get all interfaces + all flows in parallel
|
||||||
|
interface_tasks = [
|
||||||
|
client._send_request("config", None, {
|
||||||
|
"operation": "get",
|
||||||
|
"keys": [{"type": "interface-description", "key": name}],
|
||||||
|
})
|
||||||
|
for name in interface_names
|
||||||
|
]
|
||||||
|
|
||||||
|
flow_tasks = [
|
||||||
|
client._send_request("flow", None, {
|
||||||
|
"operation": "get-flow",
|
||||||
|
"flow-id": fid,
|
||||||
|
})
|
||||||
|
for fid in flow_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
results = await asyncio.gather(*interface_tasks, *flow_tasks)
|
||||||
|
|
||||||
|
# Split results
|
||||||
|
interface_results = results[:len(interface_names)]
|
||||||
|
flow_results = results[len(interface_names):]
|
||||||
|
|
||||||
|
# Parse interfaces
|
||||||
|
interface_defs = {}
|
||||||
|
for name, resp in zip(interface_names, interface_results):
|
||||||
|
values = resp.get("values", [])
|
||||||
|
if values:
|
||||||
|
interface_defs[name] = json.loads(values[0].get("value", "{}"))
|
||||||
|
|
||||||
|
# Parse flows
|
||||||
|
flows = {}
|
||||||
|
for fid, resp in zip(flow_ids, flow_results):
|
||||||
|
flow_data = resp.get("flow", "{}")
|
||||||
|
flows[fid] = json.loads(flow_data) if isinstance(flow_data, str) else flow_data
|
||||||
|
|
||||||
|
# Round 3: get all blueprints in parallel
|
||||||
|
blueprint_names = set()
|
||||||
|
for flow in flows.values():
|
||||||
|
bp = flow.get("blueprint-name", "")
|
||||||
|
if bp:
|
||||||
|
blueprint_names.add(bp)
|
||||||
|
|
||||||
|
blueprint_tasks = [
|
||||||
|
client._send_request("flow", None, {
|
||||||
|
"operation": "get-blueprint",
|
||||||
|
"blueprint-name": bp_name,
|
||||||
|
})
|
||||||
|
for bp_name in blueprint_names
|
||||||
|
]
|
||||||
|
|
||||||
|
blueprint_results = await asyncio.gather(*blueprint_tasks)
|
||||||
|
|
||||||
|
blueprints = {}
|
||||||
|
for bp_name, resp in zip(blueprint_names, blueprint_results):
|
||||||
|
bp_data = resp.get("blueprint-definition", "{}")
|
||||||
|
blueprints[bp_name] = json.loads(bp_data) if isinstance(bp_data, str) else bp_data
|
||||||
|
|
||||||
|
# Round 4: get all parameter type definitions in parallel
|
||||||
|
param_types_needed = set()
|
||||||
|
for bp in blueprints.values():
|
||||||
|
for param_meta in bp.get("parameters", {}).values():
|
||||||
|
pt = param_meta.get("type", "")
|
||||||
|
if pt:
|
||||||
|
param_types_needed.add(pt)
|
||||||
|
|
||||||
|
param_type_tasks = [
|
||||||
|
client._send_request("config", None, {
|
||||||
|
"operation": "get",
|
||||||
|
"keys": [{"type": "parameter-type", "key": pt}],
|
||||||
|
})
|
||||||
|
for pt in param_types_needed
|
||||||
|
]
|
||||||
|
|
||||||
|
param_type_results = await asyncio.gather(*param_type_tasks)
|
||||||
|
|
||||||
|
param_type_defs = {}
|
||||||
|
for pt, resp in zip(param_types_needed, param_type_results):
|
||||||
|
values = resp.get("values", [])
|
||||||
|
if values:
|
||||||
|
try:
|
||||||
|
param_type_defs[pt] = json.loads(values[0].get("value", "{}"))
|
||||||
|
except (json.JSONDecodeError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return interface_defs, flow_ids, flows, blueprints, param_type_defs
|
||||||
|
|
||||||
|
async def _show_flows_async(url, token=None):
|
||||||
|
|
||||||
|
async with AsyncSocketClient(url, timeout=60, token=token) as client:
|
||||||
|
return await fetch_show_flows(client)
|
||||||
|
|
||||||
def show_flows(url, token=None):
|
def show_flows(url, token=None):
|
||||||
|
|
||||||
api = Api(url, token=token)
|
result = asyncio.run(_show_flows_async(url, token=token))
|
||||||
config_api = api.config()
|
|
||||||
flow_api = api.flow()
|
|
||||||
|
|
||||||
interface_names = config_api.list("interface-description")
|
interface_defs, flow_ids, flows, blueprints, param_type_defs = result
|
||||||
|
|
||||||
interface_defs = {
|
if not flow_ids:
|
||||||
i: get_interface(config_api, i)
|
|
||||||
for i in interface_names
|
|
||||||
}
|
|
||||||
|
|
||||||
flow_ids = flow_api.list()
|
|
||||||
|
|
||||||
if len(flow_ids) == 0:
|
|
||||||
print("No flows.")
|
print("No flows.")
|
||||||
return
|
return
|
||||||
|
|
||||||
flows = []
|
for fid in flow_ids:
|
||||||
|
|
||||||
for id in flow_ids:
|
flow = flows[fid]
|
||||||
|
|
||||||
flow = flow_api.get(id)
|
|
||||||
|
|
||||||
table = []
|
table = []
|
||||||
table.append(("id", id))
|
table.append(("id", fid))
|
||||||
table.append(("blueprint", flow.get("blueprint-name", "")))
|
table.append(("blueprint", flow.get("blueprint-name", "")))
|
||||||
table.append(("desc", flow.get("description", "")))
|
table.append(("desc", flow.get("description", "")))
|
||||||
|
|
||||||
# Display parameters with human-readable descriptions
|
|
||||||
parameters = flow.get("parameters", {})
|
parameters = flow.get("parameters", {})
|
||||||
if parameters:
|
if parameters:
|
||||||
# Try to get the flow blueprint definition for parameter metadata
|
|
||||||
blueprint_name = flow.get("blueprint-name", "")
|
blueprint_name = flow.get("blueprint-name", "")
|
||||||
if blueprint_name:
|
if blueprint_name and blueprint_name in blueprints:
|
||||||
try:
|
blueprint_params_metadata = blueprints[blueprint_name].get("parameters", {})
|
||||||
flow_blueprint = flow_api.get_blueprint(blueprint_name)
|
param_str = format_parameters(
|
||||||
blueprint_params_metadata = flow_blueprint.get("parameters", {})
|
parameters, blueprint_params_metadata, param_type_defs
|
||||||
param_str = format_parameters(parameters, blueprint_params_metadata, config_api)
|
)
|
||||||
except Exception as e:
|
|
||||||
# Fallback to JSON if we can't get the blueprint definition
|
|
||||||
param_str = json.dumps(parameters, indent=2)
|
|
||||||
else:
|
else:
|
||||||
# No blueprint name, fallback to JSON
|
|
||||||
param_str = json.dumps(parameters, indent=2)
|
param_str = json.dumps(parameters, indent=2)
|
||||||
|
|
||||||
table.append(("parameters", param_str))
|
table.append(("parameters", param_str))
|
||||||
|
|
|
||||||
|
|
@ -7,9 +7,10 @@ valid enums, and validation rules.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import tabulate
|
import tabulate
|
||||||
from trustgraph.api import Api, ConfigKey
|
from trustgraph.api import AsyncSocketClient
|
||||||
import json
|
import json
|
||||||
|
|
||||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||||
|
|
@ -17,13 +18,7 @@ default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||||
|
|
||||||
def format_enum_values(enum_list):
|
def format_enum_values(enum_list):
|
||||||
"""
|
"""
|
||||||
Format enum values for display, handling both old and new formats
|
Format enum values for display, handling both old and new formats.
|
||||||
|
|
||||||
Args:
|
|
||||||
enum_list: List of enum values (strings or objects with id/description)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted string describing enum options
|
|
||||||
"""
|
"""
|
||||||
if not enum_list:
|
if not enum_list:
|
||||||
return "Any value"
|
return "Any value"
|
||||||
|
|
@ -31,7 +26,6 @@ def format_enum_values(enum_list):
|
||||||
enum_items = []
|
enum_items = []
|
||||||
for item in enum_list:
|
for item in enum_list:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
# New format: objects with id and description
|
|
||||||
enum_id = item.get("id", "")
|
enum_id = item.get("id", "")
|
||||||
description = item.get("description", "")
|
description = item.get("description", "")
|
||||||
if description:
|
if description:
|
||||||
|
|
@ -39,99 +33,146 @@ def format_enum_values(enum_list):
|
||||||
else:
|
else:
|
||||||
enum_items.append(enum_id)
|
enum_items.append(enum_id)
|
||||||
else:
|
else:
|
||||||
# Old format: simple strings
|
|
||||||
enum_items.append(str(item))
|
enum_items.append(str(item))
|
||||||
|
|
||||||
return "\n".join(f"• {item}" for item in enum_items)
|
return "\n".join(f"• {item}" for item in enum_items)
|
||||||
|
|
||||||
def format_constraints(param_type_def):
|
def format_constraints(param_type_def):
|
||||||
"""
|
"""
|
||||||
Format validation constraints for display
|
Format validation constraints for display.
|
||||||
|
|
||||||
Args:
|
|
||||||
param_type_def: Parameter type definition
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted string describing constraints
|
|
||||||
"""
|
"""
|
||||||
constraints = []
|
constraints = []
|
||||||
|
|
||||||
# Handle numeric constraints
|
|
||||||
if "minimum" in param_type_def:
|
if "minimum" in param_type_def:
|
||||||
constraints.append(f"min: {param_type_def['minimum']}")
|
constraints.append(f"min: {param_type_def['minimum']}")
|
||||||
if "maximum" in param_type_def:
|
if "maximum" in param_type_def:
|
||||||
constraints.append(f"max: {param_type_def['maximum']}")
|
constraints.append(f"max: {param_type_def['maximum']}")
|
||||||
|
|
||||||
# Handle string constraints
|
|
||||||
if "minLength" in param_type_def:
|
if "minLength" in param_type_def:
|
||||||
constraints.append(f"min length: {param_type_def['minLength']}")
|
constraints.append(f"min length: {param_type_def['minLength']}")
|
||||||
if "maxLength" in param_type_def:
|
if "maxLength" in param_type_def:
|
||||||
constraints.append(f"max length: {param_type_def['maxLength']}")
|
constraints.append(f"max length: {param_type_def['maxLength']}")
|
||||||
if "pattern" in param_type_def:
|
if "pattern" in param_type_def:
|
||||||
constraints.append(f"pattern: {param_type_def['pattern']}")
|
constraints.append(f"pattern: {param_type_def['pattern']}")
|
||||||
|
|
||||||
# Handle required field
|
|
||||||
if param_type_def.get("required", False):
|
if param_type_def.get("required", False):
|
||||||
constraints.append("required")
|
constraints.append("required")
|
||||||
|
|
||||||
return ", ".join(constraints) if constraints else "None"
|
return ", ".join(constraints) if constraints else "None"
|
||||||
|
|
||||||
|
def format_param_type(param_type_name, param_type_def):
|
||||||
|
"""Format a single parameter type for display."""
|
||||||
|
table = []
|
||||||
|
table.append(("name", param_type_name))
|
||||||
|
table.append(("description", param_type_def.get("description", "")))
|
||||||
|
table.append(("type", param_type_def.get("type", "unknown")))
|
||||||
|
|
||||||
|
default = param_type_def.get("default")
|
||||||
|
if default is not None:
|
||||||
|
table.append(("default", str(default)))
|
||||||
|
|
||||||
|
enum_list = param_type_def.get("enum")
|
||||||
|
if enum_list:
|
||||||
|
enum_str = format_enum_values(enum_list)
|
||||||
|
table.append(("valid values", enum_str))
|
||||||
|
|
||||||
|
constraints = format_constraints(param_type_def)
|
||||||
|
if constraints != "None":
|
||||||
|
table.append(("constraints", constraints))
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
async def fetch_all_param_types(client):
|
||||||
|
"""Fetch all parameter types concurrently."""
|
||||||
|
|
||||||
|
# Round 1: list parameter types
|
||||||
|
resp = await client._send_request("config", None, {
|
||||||
|
"operation": "list",
|
||||||
|
"type": "parameter-type",
|
||||||
|
})
|
||||||
|
param_type_names = resp.get("directory", [])
|
||||||
|
|
||||||
|
if not param_type_names:
|
||||||
|
return [], {}
|
||||||
|
|
||||||
|
# Round 2: get all parameter types in parallel
|
||||||
|
tasks = [
|
||||||
|
client._send_request("config", None, {
|
||||||
|
"operation": "get",
|
||||||
|
"keys": [{"type": "parameter-type", "key": name}],
|
||||||
|
})
|
||||||
|
for name in param_type_names
|
||||||
|
]
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
param_type_defs = {}
|
||||||
|
for name, resp in zip(param_type_names, results):
|
||||||
|
values = resp.get("values", [])
|
||||||
|
if values:
|
||||||
|
try:
|
||||||
|
param_type_defs[name] = json.loads(values[0].get("value", "{}"))
|
||||||
|
except (json.JSONDecodeError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return param_type_names, param_type_defs
|
||||||
|
|
||||||
|
async def fetch_single_param_type(client, param_type_name):
|
||||||
|
"""Fetch a single parameter type."""
|
||||||
|
resp = await client._send_request("config", None, {
|
||||||
|
"operation": "get",
|
||||||
|
"keys": [{"type": "parameter-type", "key": param_type_name}],
|
||||||
|
})
|
||||||
|
values = resp.get("values", [])
|
||||||
|
if values:
|
||||||
|
return json.loads(values[0].get("value", "{}"))
|
||||||
|
return None
|
||||||
|
|
||||||
def show_parameter_types(url, token=None):
|
def show_parameter_types(url, token=None):
|
||||||
"""
|
"""Show all parameter type definitions."""
|
||||||
Show all parameter type definitions
|
|
||||||
"""
|
|
||||||
api = Api(url, token=token)
|
|
||||||
config_api = api.config()
|
|
||||||
|
|
||||||
# Get list of all parameter types
|
async def _fetch():
|
||||||
try:
|
async with AsyncSocketClient(url, timeout=60, token=token) as client:
|
||||||
param_type_names = config_api.list("parameter-type")
|
return await fetch_all_param_types(client)
|
||||||
except Exception as e:
|
|
||||||
print(f"Error retrieving parameter types: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
if len(param_type_names) == 0:
|
param_type_names, param_type_defs = asyncio.run(_fetch())
|
||||||
|
|
||||||
|
if not param_type_names:
|
||||||
print("No parameter types defined.")
|
print("No parameter types defined.")
|
||||||
return
|
return
|
||||||
|
|
||||||
for param_type_name in param_type_names:
|
for name in param_type_names:
|
||||||
try:
|
if name not in param_type_defs:
|
||||||
# Get the parameter type definition
|
print(f"Error retrieving parameter type '{name}'")
|
||||||
key = ConfigKey("parameter-type", param_type_name)
|
|
||||||
type_def_value = config_api.get([key])[0].value
|
|
||||||
param_type_def = json.loads(type_def_value)
|
|
||||||
|
|
||||||
table = []
|
|
||||||
table.append(("name", param_type_name))
|
|
||||||
table.append(("description", param_type_def.get("description", "")))
|
|
||||||
table.append(("type", param_type_def.get("type", "unknown")))
|
|
||||||
|
|
||||||
# Show default value if present
|
|
||||||
default = param_type_def.get("default")
|
|
||||||
if default is not None:
|
|
||||||
table.append(("default", str(default)))
|
|
||||||
|
|
||||||
# Show enum values if present
|
|
||||||
enum_list = param_type_def.get("enum")
|
|
||||||
if enum_list:
|
|
||||||
enum_str = format_enum_values(enum_list)
|
|
||||||
table.append(("valid values", enum_str))
|
|
||||||
|
|
||||||
# Show constraints
|
|
||||||
constraints = format_constraints(param_type_def)
|
|
||||||
if constraints != "None":
|
|
||||||
table.append(("constraints", constraints))
|
|
||||||
|
|
||||||
print(tabulate.tabulate(
|
|
||||||
table,
|
|
||||||
tablefmt="pretty",
|
|
||||||
stralign="left",
|
|
||||||
))
|
|
||||||
print()
|
print()
|
||||||
|
continue
|
||||||
|
|
||||||
except Exception as e:
|
table = format_param_type(name, param_type_defs[name])
|
||||||
print(f"Error retrieving parameter type '{param_type_name}': {e}")
|
|
||||||
print()
|
print(tabulate.tabulate(
|
||||||
|
table,
|
||||||
|
tablefmt="pretty",
|
||||||
|
stralign="left",
|
||||||
|
))
|
||||||
|
print()
|
||||||
|
|
||||||
|
def show_specific_parameter_type(url, param_type_name, token=None):
|
||||||
|
"""Show a specific parameter type definition."""
|
||||||
|
|
||||||
|
async def _fetch():
|
||||||
|
async with AsyncSocketClient(url, timeout=60, token=token) as client:
|
||||||
|
return await fetch_single_param_type(client, param_type_name)
|
||||||
|
|
||||||
|
param_type_def = asyncio.run(_fetch())
|
||||||
|
|
||||||
|
if param_type_def is None:
|
||||||
|
print(f"Error retrieving parameter type '{param_type_name}'")
|
||||||
|
return
|
||||||
|
|
||||||
|
table = format_param_type(param_type_name, param_type_def)
|
||||||
|
|
||||||
|
print(tabulate.tabulate(
|
||||||
|
table,
|
||||||
|
tablefmt="pretty",
|
||||||
|
stralign="left",
|
||||||
|
))
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
|
@ -161,57 +202,12 @@ def main():
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if args.type:
|
if args.type:
|
||||||
# Show specific parameter type
|
|
||||||
show_specific_parameter_type(args.api_url, args.type, args.token)
|
show_specific_parameter_type(args.api_url, args.type, args.token)
|
||||||
else:
|
else:
|
||||||
# Show all parameter types
|
|
||||||
show_parameter_types(args.api_url, args.token)
|
show_parameter_types(args.api_url, args.token)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Exception:", e, flush=True)
|
print("Exception:", e, flush=True)
|
||||||
|
|
||||||
def show_specific_parameter_type(url, param_type_name, token=None):
|
|
||||||
"""
|
|
||||||
Show a specific parameter type definition
|
|
||||||
"""
|
|
||||||
api = Api(url, token=token)
|
|
||||||
config_api = api.config()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get the parameter type definition
|
|
||||||
key = ConfigKey("parameter-type", param_type_name)
|
|
||||||
type_def_value = config_api.get([key])[0].value
|
|
||||||
param_type_def = json.loads(type_def_value)
|
|
||||||
|
|
||||||
table = []
|
|
||||||
table.append(("name", param_type_name))
|
|
||||||
table.append(("description", param_type_def.get("description", "")))
|
|
||||||
table.append(("type", param_type_def.get("type", "unknown")))
|
|
||||||
|
|
||||||
# Show default value if present
|
|
||||||
default = param_type_def.get("default")
|
|
||||||
if default is not None:
|
|
||||||
table.append(("default", str(default)))
|
|
||||||
|
|
||||||
# Show enum values if present
|
|
||||||
enum_list = param_type_def.get("enum")
|
|
||||||
if enum_list:
|
|
||||||
enum_str = format_enum_values(enum_list)
|
|
||||||
table.append(("valid values", enum_str))
|
|
||||||
|
|
||||||
# Show constraints
|
|
||||||
constraints = format_constraints(param_type_def)
|
|
||||||
if constraints != "None":
|
|
||||||
table.append(("constraints", constraints))
|
|
||||||
|
|
||||||
print(tabulate.tabulate(
|
|
||||||
table,
|
|
||||||
tablefmt="pretty",
|
|
||||||
stralign="left",
|
|
||||||
))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error retrieving parameter type '{param_type_name}': {e}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
@ -178,7 +178,11 @@ def check_processors(url: str, min_processors: int, timeout: int, token: Optiona
|
||||||
url += '/'
|
url += '/'
|
||||||
metrics_url = f"{url}api/metrics/query?query=processor_info"
|
metrics_url = f"{url}api/metrics/query?query=processor_info"
|
||||||
|
|
||||||
resp = requests.get(metrics_url, timeout=timeout)
|
headers = {}
|
||||||
|
if token:
|
||||||
|
headers["Authorization"] = f"Bearer {token}"
|
||||||
|
|
||||||
|
resp = requests.get(metrics_url, timeout=timeout, headers=headers)
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
processor_count = len(data.get("data", {}).get("result", []))
|
processor_count = len(data.get("data", {}).get("result", []))
|
||||||
|
|
|
||||||
|
|
@ -33,9 +33,12 @@ class Mux:
|
||||||
|
|
||||||
async def receive(self, msg):
|
async def receive(self, msg):
|
||||||
|
|
||||||
|
request_id = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
data = msg.json()
|
data = msg.json()
|
||||||
|
request_id = data.get("id")
|
||||||
|
|
||||||
if "request" not in data:
|
if "request" not in data:
|
||||||
raise RuntimeError("Bad message")
|
raise RuntimeError("Bad message")
|
||||||
|
|
@ -51,7 +54,13 @@ class Mux:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Receive exception: {str(e)}", exc_info=True)
|
logger.error(f"Receive exception: {str(e)}", exc_info=True)
|
||||||
await self.ws.send_json({"error": str(e)})
|
error_resp = {
|
||||||
|
"error": {"message": str(e), "type": "error"},
|
||||||
|
"complete": True,
|
||||||
|
}
|
||||||
|
if request_id:
|
||||||
|
error_resp["id"] = request_id
|
||||||
|
await self.ws.send_json(error_resp)
|
||||||
|
|
||||||
async def maybe_tidy_workers(self, workers):
|
async def maybe_tidy_workers(self, workers):
|
||||||
|
|
||||||
|
|
@ -97,12 +106,12 @@ class Mux:
|
||||||
})
|
})
|
||||||
|
|
||||||
worker = asyncio.create_task(
|
worker = asyncio.create_task(
|
||||||
self.request_task(request, responder, flow, svc)
|
self.request_task(id, request, responder, flow, svc)
|
||||||
)
|
)
|
||||||
|
|
||||||
workers.append(worker)
|
workers.append(worker)
|
||||||
|
|
||||||
async def request_task(self, request, responder, flow, svc):
|
async def request_task(self, id, request, responder, flow, svc):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
|
@ -119,7 +128,11 @@ class Mux:
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self.ws.send_json({"error": str(e)})
|
await self.ws.send_json({
|
||||||
|
"id": id,
|
||||||
|
"error": {"message": str(e), "type": "error"},
|
||||||
|
"complete": True,
|
||||||
|
})
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
|
||||||
|
|
@ -143,7 +156,11 @@ class Mux:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# This is an internal working error, may not be recoverable
|
# This is an internal working error, may not be recoverable
|
||||||
logger.error(f"Run prepare exception: {e}", exc_info=True)
|
logger.error(f"Run prepare exception: {e}", exc_info=True)
|
||||||
await self.ws.send_json({"id": id, "error": str(e)})
|
await self.ws.send_json({
|
||||||
|
"id": id,
|
||||||
|
"error": {"message": str(e), "type": "error"},
|
||||||
|
"complete": True,
|
||||||
|
})
|
||||||
self.running.stop()
|
self.running.stop()
|
||||||
|
|
||||||
if self.ws:
|
if self.ws:
|
||||||
|
|
@ -160,7 +177,11 @@ class Mux:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception in mux: {e}", exc_info=True)
|
logger.error(f"Exception in mux: {e}", exc_info=True)
|
||||||
await self.ws.send_json({"error": str(e)})
|
await self.ws.send_json({
|
||||||
|
"id": id,
|
||||||
|
"error": {"message": str(e), "type": "error"},
|
||||||
|
"complete": True,
|
||||||
|
})
|
||||||
|
|
||||||
self.running.stop()
|
self.running.stop()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ default_ident = "text-completion"
|
||||||
|
|
||||||
default_temperature = 0.0
|
default_temperature = 0.0
|
||||||
default_max_output = 4192
|
default_max_output = 4192
|
||||||
default_api = "2024-12-01-preview"
|
default_api = os.getenv("AZURE_API_VERSION", "2024-12-01-preview")
|
||||||
default_endpoint = os.getenv("AZURE_ENDPOINT", None)
|
default_endpoint = os.getenv("AZURE_ENDPOINT", None)
|
||||||
default_token = os.getenv("AZURE_TOKEN", None)
|
default_token = os.getenv("AZURE_TOKEN", None)
|
||||||
default_model = os.getenv("AZURE_MODEL", None)
|
default_model = os.getenv("AZURE_MODEL", None)
|
||||||
|
|
@ -90,7 +90,7 @@ class Processor(LlmService):
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
temperature=effective_temperature,
|
temperature=effective_temperature,
|
||||||
max_tokens=self.max_output,
|
max_completion_tokens=self.max_output,
|
||||||
top_p=1,
|
top_p=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -159,7 +159,7 @@ class Processor(LlmService):
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
temperature=effective_temperature,
|
temperature=effective_temperature,
|
||||||
max_tokens=self.max_output,
|
max_completion_tokens=self.max_output,
|
||||||
top_p=1,
|
top_p=1,
|
||||||
stream=True,
|
stream=True,
|
||||||
stream_options={"include_usage": True}
|
stream_options={"include_usage": True}
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,7 @@ class Processor(LlmService):
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
temperature=effective_temperature,
|
temperature=effective_temperature,
|
||||||
max_tokens=self.max_output,
|
max_completion_tokens=self.max_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
inputtokens = resp.usage.prompt_tokens
|
inputtokens = resp.usage.prompt_tokens
|
||||||
|
|
@ -152,7 +152,7 @@ class Processor(LlmService):
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
temperature=effective_temperature,
|
temperature=effective_temperature,
|
||||||
max_tokens=self.max_output,
|
max_completion_tokens=self.max_output,
|
||||||
stream=True,
|
stream=True,
|
||||||
stream_options={"include_usage": True}
|
stream_options={"include_usage": True}
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -24,9 +24,10 @@ from . tg_socket import WebSocketManager
|
||||||
class AppContext:
|
class AppContext:
|
||||||
sockets: dict[str, WebSocketManager]
|
sockets: dict[str, WebSocketManager]
|
||||||
websocket_url: str
|
websocket_url: str
|
||||||
|
gateway_token: str
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket") -> AsyncIterator[AppContext]:
|
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Manage application lifecycle with type-safe context
|
Manage application lifecycle with type-safe context
|
||||||
|
|
@ -36,7 +37,7 @@ async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8
|
||||||
sockets = {}
|
sockets = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield AppContext(sockets=sockets, websocket_url=websocket_url)
|
yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token)
|
||||||
finally:
|
finally:
|
||||||
|
|
||||||
# Cleanup on shutdown
|
# Cleanup on shutdown
|
||||||
|
|
@ -53,6 +54,7 @@ async def get_socket_manager(ctx, user):
|
||||||
lifespan_context = ctx.request_context.lifespan_context
|
lifespan_context = ctx.request_context.lifespan_context
|
||||||
sockets = lifespan_context.sockets
|
sockets = lifespan_context.sockets
|
||||||
websocket_url = lifespan_context.websocket_url
|
websocket_url = lifespan_context.websocket_url
|
||||||
|
gateway_token = lifespan_context.gateway_token
|
||||||
|
|
||||||
if user in sockets:
|
if user in sockets:
|
||||||
logging.info("Return existing socket manager")
|
logging.info("Return existing socket manager")
|
||||||
|
|
@ -61,7 +63,7 @@ async def get_socket_manager(ctx, user):
|
||||||
logging.info(f"Opening socket to {websocket_url}...")
|
logging.info(f"Opening socket to {websocket_url}...")
|
||||||
|
|
||||||
# Create manager with empty pending requests
|
# Create manager with empty pending requests
|
||||||
manager = WebSocketManager(websocket_url)
|
manager = WebSocketManager(websocket_url, token=gateway_token)
|
||||||
|
|
||||||
# Start reader task with the proper manager
|
# Start reader task with the proper manager
|
||||||
await manager.start()
|
await manager.start()
|
||||||
|
|
@ -193,13 +195,14 @@ class GetSystemPromptResponse:
|
||||||
prompt: str
|
prompt: str
|
||||||
|
|
||||||
class McpServer:
|
class McpServer:
|
||||||
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket"):
|
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.websocket_url = websocket_url
|
self.websocket_url = websocket_url
|
||||||
|
self.gateway_token = gateway_token
|
||||||
|
|
||||||
# Create a partial function to pass websocket_url to app_lifespan
|
# Create a partial function to pass websocket_url to app_lifespan
|
||||||
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url)
|
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token)
|
||||||
|
|
||||||
self.mcp = FastMCP(
|
self.mcp = FastMCP(
|
||||||
"TrustGraph", dependencies=["trustgraph-base"],
|
"TrustGraph", dependencies=["trustgraph-base"],
|
||||||
|
|
@ -2060,8 +2063,11 @@ def main():
|
||||||
# Setup logging before creating server
|
# Setup logging before creating server
|
||||||
setup_logging(vars(args))
|
setup_logging(vars(args))
|
||||||
|
|
||||||
|
# Read gateway auth token from environment
|
||||||
|
gateway_token = os.environ.get("GATEWAY_SECRET", "")
|
||||||
|
|
||||||
# Create and run the MCP server
|
# Create and run the MCP server
|
||||||
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url)
|
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token)
|
||||||
server.run()
|
server.run()
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from websockets.asyncio.client import connect
|
from websockets.asyncio.client import connect
|
||||||
|
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
|
|
@ -9,12 +10,22 @@ import time
|
||||||
|
|
||||||
class WebSocketManager:
|
class WebSocketManager:
|
||||||
|
|
||||||
def __init__(self, url):
|
def __init__(self, url, token=None):
|
||||||
self.url = url
|
self.url = url
|
||||||
|
self.token = token
|
||||||
self.socket = None
|
self.socket = None
|
||||||
|
|
||||||
|
def _build_url(self):
|
||||||
|
if not self.token:
|
||||||
|
return self.url
|
||||||
|
parsed = urlparse(self.url)
|
||||||
|
params = parse_qs(parsed.query)
|
||||||
|
params["token"] = [self.token]
|
||||||
|
new_query = urlencode(params, doseq=True)
|
||||||
|
return urlunparse(parsed._replace(query=new_query))
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
self.socket = await connect(self.url)
|
self.socket = await connect(self._build_url())
|
||||||
self.pending_requests = {}
|
self.pending_requests = {}
|
||||||
self.running = True
|
self.running = True
|
||||||
self.reader_task = asyncio.create_task(self.reader())
|
self.reader_task = asyncio.create_task(self.reader())
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ dependencies = [
|
||||||
"pulsar-client",
|
"pulsar-client",
|
||||||
"prometheus-client",
|
"prometheus-client",
|
||||||
"python-magic",
|
"python-magic",
|
||||||
"unstructured[csv,docx,epub,md,odt,pptx,rst,rtf,tsv,xlsx]",
|
"unstructured[csv,docx,epub,md,odt,pdf,pptx,rst,rtf,tsv,xlsx]",
|
||||||
]
|
]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue