Fix Python streaming SDK issues (#580)

* Fix verify CLI issues

* Fixing content mechanisms in API

* Fixing error handling

* Fixing invoke_prompt, invoke_llm, invoke_agent
This commit is contained in:
cybermaggedon 2025-12-04 20:42:25 +00:00 committed by GitHub
parent 52ca74bbbc
commit 664bce6182
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 234 additions and 58 deletions

View file

@ -34,7 +34,26 @@ from .types import (
) )
# Exceptions # Exceptions
from .exceptions import ProtocolException, ApplicationException from .exceptions import (
ProtocolException,
TrustGraphException,
AgentError,
ConfigError,
DocumentRagError,
FlowError,
GatewayError,
GraphRagError,
LLMError,
LoadError,
LookupError,
NLPQueryError,
ObjectsQueryError,
RequestError,
StructuredQueryError,
UnexpectedError,
# Legacy alias
ApplicationException,
)
__all__ = [ __all__ = [
# Core API # Core API
@ -75,6 +94,21 @@ __all__ = [
# Exceptions # Exceptions
"ProtocolException", "ProtocolException",
"ApplicationException", "TrustGraphException",
"AgentError",
"ConfigError",
"DocumentRagError",
"FlowError",
"GatewayError",
"GraphRagError",
"LLMError",
"LoadError",
"LookupError",
"NLPQueryError",
"ObjectsQueryError",
"RequestError",
"StructuredQueryError",
"UnexpectedError",
"ApplicationException", # Legacy alias
] ]

View file

@ -130,18 +130,26 @@ class AsyncSocketClient:
content=resp.get("content", ""), content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False) end_of_message=resp.get("end_of_message", False)
) )
elif chunk_type == "final-answer": elif chunk_type == "answer" or chunk_type == "final-answer":
return AgentAnswer( return AgentAnswer(
content=resp.get("content", ""), content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False), end_of_message=resp.get("end_of_message", False),
end_of_dialog=resp.get("end_of_dialog", False) end_of_dialog=resp.get("end_of_dialog", False)
) )
elif chunk_type == "action":
# Agent action chunks - treat as thoughts for display purposes
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
)
else: else:
# RAG-style chunk (or generic chunk) # RAG-style chunk (or generic chunk)
# Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
return RAGChunk( return RAGChunk(
content=resp.get("chunk", ""), content=content,
end_of_stream=resp.get("end_of_stream", False), end_of_stream=resp.get("end_of_stream", False),
error=resp.get("error") error=None # Errors are always thrown, never stored
) )
async def aclose(self): async def aclose(self):

View file

@ -1,6 +1,134 @@
"""
TrustGraph API Exceptions
Exception hierarchy for errors returned by TrustGraph services.
Each service error type maps to a specific exception class.
"""
# Protocol-level exceptions (communication errors)
class ProtocolException(Exception): class ProtocolException(Exception):
"""Raised when WebSocket protocol errors occur"""
pass pass
class ApplicationException(Exception):
# Base class for all TrustGraph application errors
class TrustGraphException(Exception):
"""Base class for all TrustGraph service errors"""
def __init__(self, message: str, error_type: str = None):
super().__init__(message)
self.message = message
self.error_type = error_type
# Service-specific exceptions
class AgentError(TrustGraphException):
"""Agent service error"""
pass pass
class ConfigError(TrustGraphException):
"""Configuration service error"""
pass
class DocumentRagError(TrustGraphException):
"""Document RAG retrieval error"""
pass
class FlowError(TrustGraphException):
"""Flow management error"""
pass
class GatewayError(TrustGraphException):
"""API Gateway error"""
pass
class GraphRagError(TrustGraphException):
"""Graph RAG retrieval error"""
pass
class LLMError(TrustGraphException):
"""LLM service error"""
pass
class LoadError(TrustGraphException):
"""Data loading error"""
pass
class LookupError(TrustGraphException):
"""Lookup/search error"""
pass
class NLPQueryError(TrustGraphException):
"""NLP query service error"""
pass
class ObjectsQueryError(TrustGraphException):
"""Objects query service error"""
pass
class RequestError(TrustGraphException):
"""Request processing error"""
pass
class StructuredQueryError(TrustGraphException):
"""Structured query service error"""
pass
class UnexpectedError(TrustGraphException):
"""Unexpected/unknown error"""
pass
# Mapping from error type string to exception class
ERROR_TYPE_MAPPING = {
"agent-error": AgentError,
"config-error": ConfigError,
"document-rag-error": DocumentRagError,
"flow-error": FlowError,
"gateway-error": GatewayError,
"graph-rag-error": GraphRagError,
"llm-error": LLMError,
"load-error": LoadError,
"lookup-error": LookupError,
"nlp-query-error": NLPQueryError,
"objects-query-error": ObjectsQueryError,
"request-error": RequestError,
"structured-query-error": StructuredQueryError,
"unexpected-error": UnexpectedError,
}
def raise_from_error_dict(error_dict: dict) -> None:
"""
Raise appropriate exception from TrustGraph error dictionary.
Args:
error_dict: Dictionary with 'type' and 'message' keys
Raises:
Appropriate TrustGraphException subclass based on error type
"""
error_type = error_dict.get("type", "unexpected-error")
message = error_dict.get("message", "Unknown error")
# Look up exception class, default to UnexpectedError
exception_class = ERROR_TYPE_MAPPING.get(error_type, UnexpectedError)
# Raise the appropriate exception
raise exception_class(message, error_type)
# Legacy exception for backwards compatibility
ApplicationException = TrustGraphException

View file

@ -6,7 +6,7 @@ from typing import Optional, Dict, Any, Iterator, Union, List
from threading import Lock from threading import Lock
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk
from . exceptions import ProtocolException, ApplicationException from . exceptions import ProtocolException, raise_from_error_dict
class SocketClient: class SocketClient:
@ -126,7 +126,7 @@ class SocketClient:
raise ProtocolException(f"Response ID mismatch") raise ProtocolException(f"Response ID mismatch")
if "error" in response: if "error" in response:
raise ApplicationException(response["error"]) raise_from_error_dict(response["error"])
if "response" not in response: if "response" not in response:
raise ProtocolException(f"Missing response in message") raise ProtocolException(f"Missing response in message")
@ -171,11 +171,15 @@ class SocketClient:
continue # Ignore messages for other requests continue # Ignore messages for other requests
if "error" in response: if "error" in response:
raise ApplicationException(response["error"]) raise_from_error_dict(response["error"])
if "response" in response: if "response" in response:
resp = response["response"] resp = response["response"]
# Check for errors in response chunks
if "error" in resp:
raise_from_error_dict(resp["error"])
# Parse different chunk types # Parse different chunk types
chunk = self._parse_chunk(resp) chunk = self._parse_chunk(resp)
yield chunk yield chunk
@ -198,18 +202,26 @@ class SocketClient:
content=resp.get("content", ""), content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False) end_of_message=resp.get("end_of_message", False)
) )
elif chunk_type == "final-answer": elif chunk_type == "answer" or chunk_type == "final-answer":
return AgentAnswer( return AgentAnswer(
content=resp.get("content", ""), content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False), end_of_message=resp.get("end_of_message", False),
end_of_dialog=resp.get("end_of_dialog", False) end_of_dialog=resp.get("end_of_dialog", False)
) )
elif chunk_type == "action":
# Agent action chunks - treat as thoughts for display purposes
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
)
else: else:
# RAG-style chunk (or generic chunk) # RAG-style chunk (or generic chunk)
# Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
return RAGChunk( return RAGChunk(
content=resp.get("chunk", ""), content=content,
end_of_stream=resp.get("end_of_stream", False), end_of_stream=resp.get("end_of_stream", False),
error=resp.get("error") error=None # Errors are always thrown, never stored
) )
def close(self) -> None: def close(self) -> None:

View file

@ -79,5 +79,6 @@ class AgentAnswer(StreamingChunk):
@dataclasses.dataclass @dataclasses.dataclass
class RAGChunk(StreamingChunk): class RAGChunk(StreamingChunk):
"""RAG streaming chunk""" """RAG streaming chunk"""
chunk_type: str = "rag"
end_of_stream: bool = False end_of_stream: bool = False
error: Optional[Dict[str, str]] = None error: Optional[Dict[str, str]] = None

View file

@ -161,6 +161,11 @@ def question(
# Output the chunk # Output the chunk
if current_outputter: if current_outputter:
current_outputter.output(content) current_outputter.output(content)
# Flush word buffer after each chunk to avoid delay
if current_outputter.word_buffer:
print(current_outputter.word_buffer, end="", flush=True)
current_outputter.column += len(current_outputter.word_buffer)
current_outputter.word_buffer = ""
elif chunk_type == "final-answer": elif chunk_type == "final-answer":
print(content, end="", flush=True) print(content, end="", flush=True)

View file

@ -28,7 +28,7 @@ def query(url, flow_id, system, prompt, streaming=True, token=None):
if streaming: if streaming:
# Stream output to stdout without newline # Stream output to stdout without newline
for chunk in response: for chunk in response:
print(chunk.content, end="", flush=True) print(chunk, end="", flush=True)
# Add final newline after streaming # Add final newline after streaming
print() print()
else: else:

View file

@ -31,36 +31,16 @@ def query(url, flow_id, template_id, variables, streaming=True, token=None):
) )
if streaming: if streaming:
full_response = {"text": "", "object": ""} # Stream output (prompt yields strings directly)
# Stream output
for chunk in response: for chunk in response:
content = chunk.content if chunk:
if content: print(chunk, end="", flush=True)
print(content, end="", flush=True) # Add final newline after streaming
full_response["text"] += content
# Check if this is an object response (JSON)
if hasattr(chunk, 'object') and chunk.object:
full_response["object"] = chunk.object
# Handle final output
if full_response["text"]:
# Add final newline after streaming text
print() print()
elif full_response["object"]:
# Print JSON object (pretty-printed)
print(json.dumps(json.loads(full_response["object"]), indent=4))
else: else:
# Non-streaming: handle response # Non-streaming: print complete response
if isinstance(response, str):
print(response) print(response)
elif isinstance(response, dict):
if "text" in response:
print(response["text"])
elif "object" in response:
print(json.dumps(json.loads(response["object"]), indent=4))
finally: finally:
# Clean up socket connection # Clean up socket connection

View file

@ -171,14 +171,12 @@ def check_api_gateway(url: str, timeout: int, token: Optional[str] = None) -> Tu
def check_processors(url: str, min_processors: int, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]: def check_processors(url: str, min_processors: int, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
"""Check if processors are running via show-processor-state.""" """Check if processors are running via metrics endpoint."""
try: try:
api = Api(url, token=token) # Construct metrics URL from API URL
if not url.endswith('/'):
# Use the metrics endpoint similar to show_processor_state url += '/'
# This is a simplified check - we'll use requests to check the metrics metrics_url = f"{url}api/metrics/query?query=processor_info"
metrics_url = url.replace('http://', '').replace('https://', '').split('/')[0]
metrics_url = f"http://{metrics_url}:8088/api/metrics/query?query=processor_info"
resp = requests.get(metrics_url, timeout=timeout) resp = requests.get(metrics_url, timeout=timeout)
if resp.status_code == 200: if resp.status_code == 200:
@ -199,7 +197,7 @@ def check_processors(url: str, min_processors: int, timeout: int, token: Optiona
def check_flow_classes(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]: def check_flow_classes(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
"""Check if flow classes are loaded.""" """Check if flow classes are loaded."""
try: try:
api = Api(url, token=token) api = Api(url, token=token, timeout=timeout)
flow_api = api.flow() flow_api = api.flow()
classes = flow_api.list_classes() classes = flow_api.list_classes()
@ -216,7 +214,7 @@ def check_flow_classes(url: str, timeout: int, token: Optional[str] = None) -> T
def check_flows(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]: def check_flows(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
"""Check if flow manager is responding.""" """Check if flow manager is responding."""
try: try:
api = Api(url, token=token) api = Api(url, token=token, timeout=timeout)
flow_api = api.flow() flow_api = api.flow()
flows = flow_api.list() flows = flow_api.list()
@ -231,12 +229,22 @@ def check_flows(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bo
def check_prompts(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]: def check_prompts(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
"""Check if prompts are loaded.""" """Check if prompts are loaded."""
try: try:
api = Api(url, token=token) api = Api(url, token=token, timeout=timeout)
config = api.config()
prompts = api.prompts().list() # Import ConfigKey here to avoid top-level import issues
from trustgraph.api.types import ConfigKey
import json
if prompts and len(prompts) > 0: # Get the template-index which lists all prompts
return True, f"Found {len(prompts)} prompt(s)" values = config.get([
ConfigKey(type="prompt", key="template-index")
])
ix = json.loads(values[0].value)
if ix and len(ix) > 0:
return True, f"Found {len(ix)} prompt(s)"
else: else:
return False, "No prompts found" return False, "No prompts found"
@ -247,7 +255,7 @@ def check_prompts(url: str, timeout: int, token: Optional[str] = None) -> Tuple[
def check_library(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]: def check_library(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]:
"""Check if library service is responding.""" """Check if library service is responding."""
try: try:
api = Api(url, token=token) api = Api(url, token=token, timeout=timeout)
library_api = api.library() library_api = api.library()
# Try to get documents (with default user) # Try to get documents (with default user)
@ -365,10 +373,10 @@ def main():
print("=" * 60) print("=" * 60)
print("TrustGraph System Status Verification") print("TrustGraph System Status Verification")
print("=" * 60) print("=" * 60)
print(f"Global timeout: {args.global_timeout}s") # print(f"Global timeout: {args.global_timeout}s")
print(f"Check timeout: {args.check_timeout}s") # print(f"Check timeout: {args.check_timeout}s")
print(f"Retry delay: {args.retry_delay}s") # print(f"Retry delay: {args.retry_delay}s")
print("=" * 60) # print("=" * 60)
print() print()
# Phase 1: Infrastructure # Phase 1: Infrastructure