mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-28 18:06:21 +02:00
Fix Python streaming SDK issues (#580)
* Fix verify CLI issues * Fixing content mechanisms in API * Fixing error handling * Fixing invoke_prompt, invoke_llm, invoke_agent
This commit is contained in:
parent
52ca74bbbc
commit
664bce6182
9 changed files with 234 additions and 58 deletions
|
|
@ -34,7 +34,26 @@ from .types import (
|
||||||
)
|
)
|
||||||
|
|
||||||
# Exceptions
|
# Exceptions
|
||||||
from .exceptions import ProtocolException, ApplicationException
|
from .exceptions import (
|
||||||
|
ProtocolException,
|
||||||
|
TrustGraphException,
|
||||||
|
AgentError,
|
||||||
|
ConfigError,
|
||||||
|
DocumentRagError,
|
||||||
|
FlowError,
|
||||||
|
GatewayError,
|
||||||
|
GraphRagError,
|
||||||
|
LLMError,
|
||||||
|
LoadError,
|
||||||
|
LookupError,
|
||||||
|
NLPQueryError,
|
||||||
|
ObjectsQueryError,
|
||||||
|
RequestError,
|
||||||
|
StructuredQueryError,
|
||||||
|
UnexpectedError,
|
||||||
|
# Legacy alias
|
||||||
|
ApplicationException,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Core API
|
# Core API
|
||||||
|
|
@ -75,6 +94,21 @@ __all__ = [
|
||||||
|
|
||||||
# Exceptions
|
# Exceptions
|
||||||
"ProtocolException",
|
"ProtocolException",
|
||||||
"ApplicationException",
|
"TrustGraphException",
|
||||||
|
"AgentError",
|
||||||
|
"ConfigError",
|
||||||
|
"DocumentRagError",
|
||||||
|
"FlowError",
|
||||||
|
"GatewayError",
|
||||||
|
"GraphRagError",
|
||||||
|
"LLMError",
|
||||||
|
"LoadError",
|
||||||
|
"LookupError",
|
||||||
|
"NLPQueryError",
|
||||||
|
"ObjectsQueryError",
|
||||||
|
"RequestError",
|
||||||
|
"StructuredQueryError",
|
||||||
|
"UnexpectedError",
|
||||||
|
"ApplicationException", # Legacy alias
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -130,18 +130,26 @@ class AsyncSocketClient:
|
||||||
content=resp.get("content", ""),
|
content=resp.get("content", ""),
|
||||||
end_of_message=resp.get("end_of_message", False)
|
end_of_message=resp.get("end_of_message", False)
|
||||||
)
|
)
|
||||||
elif chunk_type == "final-answer":
|
elif chunk_type == "answer" or chunk_type == "final-answer":
|
||||||
return AgentAnswer(
|
return AgentAnswer(
|
||||||
content=resp.get("content", ""),
|
content=resp.get("content", ""),
|
||||||
end_of_message=resp.get("end_of_message", False),
|
end_of_message=resp.get("end_of_message", False),
|
||||||
end_of_dialog=resp.get("end_of_dialog", False)
|
end_of_dialog=resp.get("end_of_dialog", False)
|
||||||
)
|
)
|
||||||
|
elif chunk_type == "action":
|
||||||
|
# Agent action chunks - treat as thoughts for display purposes
|
||||||
|
return AgentThought(
|
||||||
|
content=resp.get("content", ""),
|
||||||
|
end_of_message=resp.get("end_of_message", False)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# RAG-style chunk (or generic chunk)
|
# RAG-style chunk (or generic chunk)
|
||||||
|
# Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
|
||||||
|
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
|
||||||
return RAGChunk(
|
return RAGChunk(
|
||||||
content=resp.get("chunk", ""),
|
content=content,
|
||||||
end_of_stream=resp.get("end_of_stream", False),
|
end_of_stream=resp.get("end_of_stream", False),
|
||||||
error=resp.get("error")
|
error=None # Errors are always thrown, never stored
|
||||||
)
|
)
|
||||||
|
|
||||||
async def aclose(self):
|
async def aclose(self):
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,134 @@
|
||||||
|
"""
|
||||||
|
TrustGraph API Exceptions
|
||||||
|
|
||||||
|
Exception hierarchy for errors returned by TrustGraph services.
|
||||||
|
Each service error type maps to a specific exception class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Protocol-level exceptions (communication errors)
|
||||||
class ProtocolException(Exception):
|
class ProtocolException(Exception):
|
||||||
|
"""Raised when WebSocket protocol errors occur"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class ApplicationException(Exception):
|
|
||||||
|
# Base class for all TrustGraph application errors
|
||||||
|
class TrustGraphException(Exception):
|
||||||
|
"""Base class for all TrustGraph service errors"""
|
||||||
|
def __init__(self, message: str, error_type: str = None):
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
self.error_type = error_type
|
||||||
|
|
||||||
|
|
||||||
|
# Service-specific exceptions
|
||||||
|
class AgentError(TrustGraphException):
|
||||||
|
"""Agent service error"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigError(TrustGraphException):
|
||||||
|
"""Configuration service error"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentRagError(TrustGraphException):
|
||||||
|
"""Document RAG retrieval error"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FlowError(TrustGraphException):
|
||||||
|
"""Flow management error"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GatewayError(TrustGraphException):
|
||||||
|
"""API Gateway error"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GraphRagError(TrustGraphException):
|
||||||
|
"""Graph RAG retrieval error"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LLMError(TrustGraphException):
|
||||||
|
"""LLM service error"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LoadError(TrustGraphException):
|
||||||
|
"""Data loading error"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LookupError(TrustGraphException):
|
||||||
|
"""Lookup/search error"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NLPQueryError(TrustGraphException):
|
||||||
|
"""NLP query service error"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectsQueryError(TrustGraphException):
|
||||||
|
"""Objects query service error"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RequestError(TrustGraphException):
|
||||||
|
"""Request processing error"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredQueryError(TrustGraphException):
|
||||||
|
"""Structured query service error"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UnexpectedError(TrustGraphException):
|
||||||
|
"""Unexpected/unknown error"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Mapping from error type string to exception class
|
||||||
|
ERROR_TYPE_MAPPING = {
|
||||||
|
"agent-error": AgentError,
|
||||||
|
"config-error": ConfigError,
|
||||||
|
"document-rag-error": DocumentRagError,
|
||||||
|
"flow-error": FlowError,
|
||||||
|
"gateway-error": GatewayError,
|
||||||
|
"graph-rag-error": GraphRagError,
|
||||||
|
"llm-error": LLMError,
|
||||||
|
"load-error": LoadError,
|
||||||
|
"lookup-error": LookupError,
|
||||||
|
"nlp-query-error": NLPQueryError,
|
||||||
|
"objects-query-error": ObjectsQueryError,
|
||||||
|
"request-error": RequestError,
|
||||||
|
"structured-query-error": StructuredQueryError,
|
||||||
|
"unexpected-error": UnexpectedError,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def raise_from_error_dict(error_dict: dict) -> None:
|
||||||
|
"""
|
||||||
|
Raise appropriate exception from TrustGraph error dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_dict: Dictionary with 'type' and 'message' keys
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Appropriate TrustGraphException subclass based on error type
|
||||||
|
"""
|
||||||
|
error_type = error_dict.get("type", "unexpected-error")
|
||||||
|
message = error_dict.get("message", "Unknown error")
|
||||||
|
|
||||||
|
# Look up exception class, default to UnexpectedError
|
||||||
|
exception_class = ERROR_TYPE_MAPPING.get(error_type, UnexpectedError)
|
||||||
|
|
||||||
|
# Raise the appropriate exception
|
||||||
|
raise exception_class(message, error_type)
|
||||||
|
|
||||||
|
|
||||||
|
# Legacy exception for backwards compatibility
|
||||||
|
ApplicationException = TrustGraphException
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from typing import Optional, Dict, Any, Iterator, Union, List
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk
|
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk
|
||||||
from . exceptions import ProtocolException, ApplicationException
|
from . exceptions import ProtocolException, raise_from_error_dict
|
||||||
|
|
||||||
|
|
||||||
class SocketClient:
|
class SocketClient:
|
||||||
|
|
@ -126,7 +126,7 @@ class SocketClient:
|
||||||
raise ProtocolException(f"Response ID mismatch")
|
raise ProtocolException(f"Response ID mismatch")
|
||||||
|
|
||||||
if "error" in response:
|
if "error" in response:
|
||||||
raise ApplicationException(response["error"])
|
raise_from_error_dict(response["error"])
|
||||||
|
|
||||||
if "response" not in response:
|
if "response" not in response:
|
||||||
raise ProtocolException(f"Missing response in message")
|
raise ProtocolException(f"Missing response in message")
|
||||||
|
|
@ -171,11 +171,15 @@ class SocketClient:
|
||||||
continue # Ignore messages for other requests
|
continue # Ignore messages for other requests
|
||||||
|
|
||||||
if "error" in response:
|
if "error" in response:
|
||||||
raise ApplicationException(response["error"])
|
raise_from_error_dict(response["error"])
|
||||||
|
|
||||||
if "response" in response:
|
if "response" in response:
|
||||||
resp = response["response"]
|
resp = response["response"]
|
||||||
|
|
||||||
|
# Check for errors in response chunks
|
||||||
|
if "error" in resp:
|
||||||
|
raise_from_error_dict(resp["error"])
|
||||||
|
|
||||||
# Parse different chunk types
|
# Parse different chunk types
|
||||||
chunk = self._parse_chunk(resp)
|
chunk = self._parse_chunk(resp)
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
@ -198,18 +202,26 @@ class SocketClient:
|
||||||
content=resp.get("content", ""),
|
content=resp.get("content", ""),
|
||||||
end_of_message=resp.get("end_of_message", False)
|
end_of_message=resp.get("end_of_message", False)
|
||||||
)
|
)
|
||||||
elif chunk_type == "final-answer":
|
elif chunk_type == "answer" or chunk_type == "final-answer":
|
||||||
return AgentAnswer(
|
return AgentAnswer(
|
||||||
content=resp.get("content", ""),
|
content=resp.get("content", ""),
|
||||||
end_of_message=resp.get("end_of_message", False),
|
end_of_message=resp.get("end_of_message", False),
|
||||||
end_of_dialog=resp.get("end_of_dialog", False)
|
end_of_dialog=resp.get("end_of_dialog", False)
|
||||||
)
|
)
|
||||||
|
elif chunk_type == "action":
|
||||||
|
# Agent action chunks - treat as thoughts for display purposes
|
||||||
|
return AgentThought(
|
||||||
|
content=resp.get("content", ""),
|
||||||
|
end_of_message=resp.get("end_of_message", False)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# RAG-style chunk (or generic chunk)
|
# RAG-style chunk (or generic chunk)
|
||||||
|
# Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
|
||||||
|
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
|
||||||
return RAGChunk(
|
return RAGChunk(
|
||||||
content=resp.get("chunk", ""),
|
content=content,
|
||||||
end_of_stream=resp.get("end_of_stream", False),
|
end_of_stream=resp.get("end_of_stream", False),
|
||||||
error=resp.get("error")
|
error=None # Errors are always thrown, never stored
|
||||||
)
|
)
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
|
|
|
||||||
|
|
@ -79,5 +79,6 @@ class AgentAnswer(StreamingChunk):
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class RAGChunk(StreamingChunk):
|
class RAGChunk(StreamingChunk):
|
||||||
"""RAG streaming chunk"""
|
"""RAG streaming chunk"""
|
||||||
|
chunk_type: str = "rag"
|
||||||
end_of_stream: bool = False
|
end_of_stream: bool = False
|
||||||
error: Optional[Dict[str, str]] = None
|
error: Optional[Dict[str, str]] = None
|
||||||
|
|
|
||||||
|
|
@ -161,6 +161,11 @@ def question(
|
||||||
# Output the chunk
|
# Output the chunk
|
||||||
if current_outputter:
|
if current_outputter:
|
||||||
current_outputter.output(content)
|
current_outputter.output(content)
|
||||||
|
# Flush word buffer after each chunk to avoid delay
|
||||||
|
if current_outputter.word_buffer:
|
||||||
|
print(current_outputter.word_buffer, end="", flush=True)
|
||||||
|
current_outputter.column += len(current_outputter.word_buffer)
|
||||||
|
current_outputter.word_buffer = ""
|
||||||
elif chunk_type == "final-answer":
|
elif chunk_type == "final-answer":
|
||||||
print(content, end="", flush=True)
|
print(content, end="", flush=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ def query(url, flow_id, system, prompt, streaming=True, token=None):
|
||||||
if streaming:
|
if streaming:
|
||||||
# Stream output to stdout without newline
|
# Stream output to stdout without newline
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
print(chunk.content, end="", flush=True)
|
print(chunk, end="", flush=True)
|
||||||
# Add final newline after streaming
|
# Add final newline after streaming
|
||||||
print()
|
print()
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -31,36 +31,16 @@ def query(url, flow_id, template_id, variables, streaming=True, token=None):
|
||||||
)
|
)
|
||||||
|
|
||||||
if streaming:
|
if streaming:
|
||||||
full_response = {"text": "", "object": ""}
|
# Stream output (prompt yields strings directly)
|
||||||
|
|
||||||
# Stream output
|
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
content = chunk.content
|
if chunk:
|
||||||
if content:
|
print(chunk, end="", flush=True)
|
||||||
print(content, end="", flush=True)
|
# Add final newline after streaming
|
||||||
full_response["text"] += content
|
|
||||||
|
|
||||||
# Check if this is an object response (JSON)
|
|
||||||
if hasattr(chunk, 'object') and chunk.object:
|
|
||||||
full_response["object"] = chunk.object
|
|
||||||
|
|
||||||
# Handle final output
|
|
||||||
if full_response["text"]:
|
|
||||||
# Add final newline after streaming text
|
|
||||||
print()
|
print()
|
||||||
elif full_response["object"]:
|
|
||||||
# Print JSON object (pretty-printed)
|
|
||||||
print(json.dumps(json.loads(full_response["object"]), indent=4))
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Non-streaming: handle response
|
# Non-streaming: print complete response
|
||||||
if isinstance(response, str):
|
|
||||||
print(response)
|
print(response)
|
||||||
elif isinstance(response, dict):
|
|
||||||
if "text" in response:
|
|
||||||
print(response["text"])
|
|
||||||
elif "object" in response:
|
|
||||||
print(json.dumps(json.loads(response["object"]), indent=4))
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up socket connection
|
# Clean up socket connection
|
||||||
|
|
|
||||||
|
|
@ -171,14 +171,12 @@ def check_api_gateway(url: str, timeout: int, token: Optional[str] = None) -> Tu
|
||||||
|
|
||||||
|
|
||||||
def check_processors(url: str, min_processors: int, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
|
def check_processors(url: str, min_processors: int, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
|
||||||
"""Check if processors are running via show-processor-state."""
|
"""Check if processors are running via metrics endpoint."""
|
||||||
try:
|
try:
|
||||||
api = Api(url, token=token)
|
# Construct metrics URL from API URL
|
||||||
|
if not url.endswith('/'):
|
||||||
# Use the metrics endpoint similar to show_processor_state
|
url += '/'
|
||||||
# This is a simplified check - we'll use requests to check the metrics
|
metrics_url = f"{url}api/metrics/query?query=processor_info"
|
||||||
metrics_url = url.replace('http://', '').replace('https://', '').split('/')[0]
|
|
||||||
metrics_url = f"http://{metrics_url}:8088/api/metrics/query?query=processor_info"
|
|
||||||
|
|
||||||
resp = requests.get(metrics_url, timeout=timeout)
|
resp = requests.get(metrics_url, timeout=timeout)
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
|
|
@ -199,7 +197,7 @@ def check_processors(url: str, min_processors: int, timeout: int, token: Optiona
|
||||||
def check_flow_classes(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
|
def check_flow_classes(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
|
||||||
"""Check if flow classes are loaded."""
|
"""Check if flow classes are loaded."""
|
||||||
try:
|
try:
|
||||||
api = Api(url, token=token)
|
api = Api(url, token=token, timeout=timeout)
|
||||||
flow_api = api.flow()
|
flow_api = api.flow()
|
||||||
|
|
||||||
classes = flow_api.list_classes()
|
classes = flow_api.list_classes()
|
||||||
|
|
@ -216,7 +214,7 @@ def check_flow_classes(url: str, timeout: int, token: Optional[str] = None) -> T
|
||||||
def check_flows(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
|
def check_flows(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
|
||||||
"""Check if flow manager is responding."""
|
"""Check if flow manager is responding."""
|
||||||
try:
|
try:
|
||||||
api = Api(url, token=token)
|
api = Api(url, token=token, timeout=timeout)
|
||||||
flow_api = api.flow()
|
flow_api = api.flow()
|
||||||
|
|
||||||
flows = flow_api.list()
|
flows = flow_api.list()
|
||||||
|
|
@ -231,12 +229,22 @@ def check_flows(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bo
|
||||||
def check_prompts(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
|
def check_prompts(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
|
||||||
"""Check if prompts are loaded."""
|
"""Check if prompts are loaded."""
|
||||||
try:
|
try:
|
||||||
api = Api(url, token=token)
|
api = Api(url, token=token, timeout=timeout)
|
||||||
|
config = api.config()
|
||||||
|
|
||||||
prompts = api.prompts().list()
|
# Import ConfigKey here to avoid top-level import issues
|
||||||
|
from trustgraph.api.types import ConfigKey
|
||||||
|
import json
|
||||||
|
|
||||||
if prompts and len(prompts) > 0:
|
# Get the template-index which lists all prompts
|
||||||
return True, f"Found {len(prompts)} prompt(s)"
|
values = config.get([
|
||||||
|
ConfigKey(type="prompt", key="template-index")
|
||||||
|
])
|
||||||
|
|
||||||
|
ix = json.loads(values[0].value)
|
||||||
|
|
||||||
|
if ix and len(ix) > 0:
|
||||||
|
return True, f"Found {len(ix)} prompt(s)"
|
||||||
else:
|
else:
|
||||||
return False, "No prompts found"
|
return False, "No prompts found"
|
||||||
|
|
||||||
|
|
@ -247,7 +255,7 @@ def check_prompts(url: str, timeout: int, token: Optional[str] = None) -> Tuple[
|
||||||
def check_library(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
|
def check_library(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
|
||||||
"""Check if library service is responding."""
|
"""Check if library service is responding."""
|
||||||
try:
|
try:
|
||||||
api = Api(url, token=token)
|
api = Api(url, token=token, timeout=timeout)
|
||||||
library_api = api.library()
|
library_api = api.library()
|
||||||
|
|
||||||
# Try to get documents (with default user)
|
# Try to get documents (with default user)
|
||||||
|
|
@ -365,10 +373,10 @@ def main():
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("TrustGraph System Status Verification")
|
print("TrustGraph System Status Verification")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print(f"Global timeout: {args.global_timeout}s")
|
# print(f"Global timeout: {args.global_timeout}s")
|
||||||
print(f"Check timeout: {args.check_timeout}s")
|
# print(f"Check timeout: {args.check_timeout}s")
|
||||||
print(f"Retry delay: {args.retry_delay}s")
|
# print(f"Retry delay: {args.retry_delay}s")
|
||||||
print("=" * 60)
|
# print("=" * 60)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Phase 1: Infrastructure
|
# Phase 1: Infrastructure
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue