mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-12 16:52:37 +02:00
Streaming LLM part 2 (#567)
* Updates for agent API with streaming support * Added tg-dump-queues tool to dump Pulsar queues to a log * Updated tg-invoke-agent, incremental output * Queue dumper CLI - might be useful for debug * Updating for tests
This commit is contained in:
parent
310a2deb06
commit
b1cc724f7d
8 changed files with 609 additions and 51 deletions
|
|
@ -135,10 +135,10 @@ Args: {
|
|||
# Verify prompt client was called correctly
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
prompt_client.agent_react.assert_called_once()
|
||||
|
||||
|
||||
# Verify the prompt variables passed to agent_react
|
||||
call_args = prompt_client.agent_react.call_args
|
||||
variables = call_args[0][0]
|
||||
variables = call_args.kwargs['variables']
|
||||
assert variables["question"] == question
|
||||
assert len(variables["tools"]) == 3 # knowledge_query, text_completion, web_search
|
||||
assert variables["context"] == "You are a helpful AI assistant with access to knowledge and tools."
|
||||
|
|
@ -237,7 +237,7 @@ Final Answer: Machine learning is a branch of artificial intelligence."""
|
|||
# Verify history was included in prompt variables
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
call_args = prompt_client.agent_react.call_args
|
||||
variables = call_args[0][0]
|
||||
variables = call_args.kwargs['variables']
|
||||
assert len(variables["history"]) == 1
|
||||
assert variables["history"][0]["thought"] == "I need to search for information about machine learning"
|
||||
assert variables["history"][0]["action"] == "knowledge_query"
|
||||
|
|
@ -337,7 +337,7 @@ Args: {
|
|||
# Verify tool information was passed to prompt
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
call_args = prompt_client.agent_react.call_args
|
||||
variables = call_args[0][0]
|
||||
variables = call_args.kwargs['variables']
|
||||
|
||||
# Should have all 3 tools available
|
||||
tool_names = [tool["name"] for tool in variables["tools"]]
|
||||
|
|
@ -408,7 +408,7 @@ Args: {args_json}"""
|
|||
# Assert
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
call_args = prompt_client.agent_react.call_args
|
||||
variables = call_args[0][0]
|
||||
variables = call_args.kwargs['variables']
|
||||
|
||||
assert variables["context"] == "You are an expert in machine learning research."
|
||||
assert variables["question"] == question
|
||||
|
|
@ -427,7 +427,7 @@ Args: {args_json}"""
|
|||
# Assert
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
call_args = prompt_client.agent_react.call_args
|
||||
variables = call_args[0][0]
|
||||
variables = call_args.kwargs['variables']
|
||||
|
||||
assert len(variables["tools"]) == 0
|
||||
assert variables["tool_names"] == ""
|
||||
|
|
@ -682,7 +682,7 @@ Final Answer: {
|
|||
# Verify history was processed correctly
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
call_args = prompt_client.agent_react.call_args
|
||||
variables = call_args[0][0]
|
||||
variables = call_args.kwargs['variables']
|
||||
assert len(variables["history"]) == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -709,7 +709,7 @@ Final Answer: {
|
|||
# Verify JSON was properly serialized in prompt
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
call_args = prompt_client.agent_react.call_args
|
||||
variables = call_args[0][0]
|
||||
variables = call_args.kwargs['variables']
|
||||
|
||||
# Should not raise JSON serialization errors
|
||||
json_str = json.dumps(variables, indent=4)
|
||||
|
|
|
|||
|
|
@ -1,15 +1,20 @@
|
|||
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import PromptRequest, PromptResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PromptClient(RequestResponse):
|
||||
|
||||
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
|
||||
logger.info(f"DEBUG prompt_client: prompt called, id={id}, streaming={streaming}, chunk_callback={chunk_callback is not None}")
|
||||
|
||||
if not streaming:
|
||||
logger.info("DEBUG prompt_client: Non-streaming path")
|
||||
# Non-streaming path
|
||||
resp = await self.request(
|
||||
PromptRequest(
|
||||
|
|
@ -31,44 +36,59 @@ class PromptClient(RequestResponse):
|
|||
return json.loads(resp.object)
|
||||
|
||||
else:
|
||||
logger.info("DEBUG prompt_client: Streaming path")
|
||||
# Streaming path - collect all chunks
|
||||
full_text = ""
|
||||
full_object = None
|
||||
|
||||
async def collect_chunks(resp):
|
||||
nonlocal full_text, full_object
|
||||
logger.info(f"DEBUG prompt_client: collect_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}")
|
||||
|
||||
if resp.error:
|
||||
logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}")
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
if resp.text:
|
||||
full_text += resp.text
|
||||
logger.info(f"DEBUG prompt_client: Accumulated {len(full_text)} chars")
|
||||
# Call chunk callback if provided
|
||||
if chunk_callback:
|
||||
logger.info(f"DEBUG prompt_client: Calling chunk_callback")
|
||||
if asyncio.iscoroutinefunction(chunk_callback):
|
||||
await chunk_callback(resp.text)
|
||||
else:
|
||||
chunk_callback(resp.text)
|
||||
elif resp.object:
|
||||
logger.info(f"DEBUG prompt_client: Got object response")
|
||||
full_object = resp.object
|
||||
|
||||
return getattr(resp, 'end_of_stream', False)
|
||||
end_stream = getattr(resp, 'end_of_stream', False)
|
||||
logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}")
|
||||
return end_stream
|
||||
|
||||
logger.info("DEBUG prompt_client: Creating PromptRequest")
|
||||
req = PromptRequest(
|
||||
id = id,
|
||||
terms = {
|
||||
k: json.dumps(v)
|
||||
for k, v in variables.items()
|
||||
},
|
||||
streaming = True
|
||||
)
|
||||
logger.info(f"DEBUG prompt_client: About to call self.request with recipient, timeout={timeout}")
|
||||
await self.request(
|
||||
PromptRequest(
|
||||
id = id,
|
||||
terms = {
|
||||
k: json.dumps(v)
|
||||
for k, v in variables.items()
|
||||
},
|
||||
streaming = True
|
||||
),
|
||||
req,
|
||||
recipient=collect_chunks,
|
||||
timeout=timeout
|
||||
)
|
||||
logger.info(f"DEBUG prompt_client: self.request returned, full_text has {len(full_text)} chars")
|
||||
|
||||
if full_text: return full_text
|
||||
if full_text:
|
||||
logger.info("DEBUG prompt_client: Returning full_text")
|
||||
return full_text
|
||||
|
||||
logger.info("DEBUG prompt_client: Returning parsed full_object")
|
||||
return json.loads(full_object)
|
||||
|
||||
async def extract_definitions(self, text, timeout=600):
|
||||
|
|
|
|||
|
|
@ -43,12 +43,18 @@ class Subscriber:
|
|||
|
||||
async def start(self):
|
||||
|
||||
self.consumer = self.client.subscribe(
|
||||
topic = self.topic,
|
||||
subscription_name = self.subscription,
|
||||
consumer_name = self.consumer_name,
|
||||
schema = JsonSchema(self.schema),
|
||||
)
|
||||
# Build subscribe arguments
|
||||
subscribe_args = {
|
||||
'topic': self.topic,
|
||||
'subscription_name': self.subscription,
|
||||
'consumer_name': self.consumer_name,
|
||||
}
|
||||
|
||||
# Only add schema if provided (omit if None)
|
||||
if self.schema is not None:
|
||||
subscribe_args['schema'] = JsonSchema(self.schema)
|
||||
|
||||
self.consumer = self.client.subscribe(**subscribe_args)
|
||||
|
||||
self.task = asyncio.create_task(self.run())
|
||||
|
||||
|
|
@ -87,10 +93,14 @@ class Subscriber:
|
|||
if self.draining and drain_end_time is None:
|
||||
drain_end_time = time.time() + self.drain_timeout
|
||||
logger.info(f"Subscriber entering drain mode, timeout={self.drain_timeout}s")
|
||||
|
||||
|
||||
# Stop accepting new messages from Pulsar during drain
|
||||
if self.consumer:
|
||||
self.consumer.pause_message_listener()
|
||||
try:
|
||||
self.consumer.pause_message_listener()
|
||||
except _pulsar.InvalidConfiguration:
|
||||
# Not all consumers have message listeners (e.g., blocking receive mode)
|
||||
pass
|
||||
|
||||
# Check drain timeout
|
||||
if self.draining and drain_end_time and time.time() > drain_end_time:
|
||||
|
|
@ -145,12 +155,21 @@ class Subscriber:
|
|||
finally:
|
||||
# Negative acknowledge any pending messages
|
||||
for msg in self.pending_acks.values():
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
try:
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
except _pulsar.AlreadyClosed:
|
||||
pass # Consumer already closed
|
||||
self.pending_acks.clear()
|
||||
|
||||
if self.consumer:
|
||||
self.consumer.unsubscribe()
|
||||
self.consumer.close()
|
||||
try:
|
||||
self.consumer.unsubscribe()
|
||||
except _pulsar.AlreadyClosed:
|
||||
pass # Already closed
|
||||
try:
|
||||
self.consumer.close()
|
||||
except _pulsar.AlreadyClosed:
|
||||
pass # Already closed
|
||||
self.consumer = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ tg-delete-mcp-tool = "trustgraph.cli.delete_mcp_tool:main"
|
|||
tg-delete-kg-core = "trustgraph.cli.delete_kg_core:main"
|
||||
tg-delete-tool = "trustgraph.cli.delete_tool:main"
|
||||
tg-dump-msgpack = "trustgraph.cli.dump_msgpack:main"
|
||||
tg-dump-queues = "trustgraph.cli.dump_queues:main"
|
||||
tg-get-flow-class = "trustgraph.cli.get_flow_class:main"
|
||||
tg-get-kg-core = "trustgraph.cli.get_kg_core:main"
|
||||
tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main"
|
||||
|
|
|
|||
362
trustgraph-cli/trustgraph/cli/dump_queues.py
Normal file
362
trustgraph-cli/trustgraph/cli/dump_queues.py
Normal file
|
|
@ -0,0 +1,362 @@
|
|||
"""
|
||||
Multi-queue Pulsar message dumper for debugging TrustGraph message flows.
|
||||
|
||||
This utility monitors multiple Pulsar queues simultaneously and logs all messages
|
||||
to a file with timestamps and pretty-printed formatting. Useful for debugging
|
||||
message flows, diagnosing stuck services, and understanding system behavior.
|
||||
|
||||
Uses TrustGraph's Subscriber abstraction for future-proof pub/sub compatibility.
|
||||
"""
|
||||
|
||||
import pulsar
|
||||
from pulsar.schema import BytesSchema
|
||||
import sys
|
||||
import json
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import argparse
|
||||
|
||||
from trustgraph.base.subscriber import Subscriber
|
||||
|
||||
def format_message(queue_name, msg):
    """Render one message as a timestamped, pretty-printed log entry."""
    stamp = datetime.now().isoformat()

    try:
        # Pulsar Message objects expose the payload via .value();
        # schema-less subscriptions hand us raw bytes directly.
        payload = msg.value() if hasattr(msg, 'value') else msg

        if isinstance(payload, bytes):
            payload = payload.decode('utf-8')

        if isinstance(payload, str):
            # Prefer pretty-printed JSON; fall back to the raw text.
            try:
                body = json.dumps(json.loads(payload), indent=2)
            except (json.JSONDecodeError, TypeError):
                body = payload
        else:
            # Schema object: dump its public attributes as JSON.
            try:
                if hasattr(payload, '__dict__'):
                    public = {
                        k: v for k, v in payload.__dict__.items()
                        if not k.startswith('_')
                    }
                else:
                    public = str(payload)
                body = json.dumps(public, indent=2, default=str)
            except (TypeError, AttributeError):
                body = str(payload)

    except Exception as e:
        # Never let a formatting failure kill the dumper.
        body = f"<Error formatting message: {e}>\n{str(msg)}"

    rule = '=' * 80
    header = f"\n{rule}\n[{stamp}] Queue: {queue_name}\n{rule}\n"
    return header + body + "\n"
|
||||
|
||||
|
||||
async def monitor_queue(subscriber, queue_name, central_queue, monitor_id, shutdown_event):
    """
    Drain one Subscriber and forward each formatted message to the
    central logging queue until shutdown is signalled.

    Args:
        subscriber: Subscriber instance for this queue
        queue_name: Name of the queue (for logging)
        central_queue: asyncio.Queue to forward messages to
        monitor_id: Unique ID for this monitor's subscription
        shutdown_event: asyncio.Event to signal shutdown
    """
    msg_queue = None
    try:
        # Receive every message this Subscriber sees.
        msg_queue = await subscriber.subscribe_all(monitor_id)

        while not shutdown_event.is_set():
            try:
                msg = await asyncio.wait_for(msg_queue.get(), timeout=0.5)
            except asyncio.TimeoutError:
                # Idle: loop around and re-check the shutdown flag.
                continue

            received_at = datetime.now()
            entry = format_message(queue_name, msg)

            # Hand off to the single writer task.
            await central_queue.put((received_at, queue_name, entry))

    except Exception as e:
        # Report monitor failures through the same log stream,
        # unless we are already shutting down.
        if not shutdown_event.is_set():
            rule = '=' * 80
            note = f"\n{rule}\n[{datetime.now().isoformat()}] ERROR in monitor for {queue_name}\n{rule}\n{e}\n"
            await central_queue.put((datetime.now(), queue_name, note))
    finally:
        # Best-effort unsubscribe on the way out.
        if msg_queue is not None:
            try:
                await subscriber.unsubscribe_all(monitor_id)
            except Exception:
                pass
|
||||
|
||||
|
||||
async def log_writer(central_queue, file_handle, shutdown_event, console_output=True):
    """
    Consume (timestamp, queue_name, formatted_msg) tuples from the
    central queue and append them to the log file.

    Args:
        central_queue: asyncio.Queue containing (timestamp, queue_name, formatted_msg) tuples
        file_handle: Open file handle to write to
        shutdown_event: asyncio.Event to signal shutdown
        console_output: Whether to print abbreviated messages to console
    """
    try:
        while not shutdown_event.is_set():
            try:
                stamp, source, entry = await asyncio.wait_for(
                    central_queue.get(), timeout=0.5
                )
            except asyncio.TimeoutError:
                # Nothing queued; re-check the shutdown flag.
                continue

            file_handle.write(entry)
            file_handle.flush()

            # One-line console notification per message.
            if console_output:
                print(f"[{stamp.strftime('%H:%M:%S')}] {source}: Message received")

    finally:
        # Drain anything still queued after shutdown was signalled.
        while not central_queue.empty():
            try:
                _, _, entry = central_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            file_handle.write(entry)
            file_handle.flush()
|
||||
|
||||
|
||||
async def async_main(queues, output_file, pulsar_host, listener_name, subscriber_name, append_mode):
    """
    Monitor several Pulsar queues concurrently, logging every message to a file.

    Args:
        queues: List of queue names to monitor
        output_file: Path to output file
        pulsar_host: Pulsar connection URL
        listener_name: Pulsar listener name
        subscriber_name: Base name for subscribers
        append_mode: Whether to append to existing file
    """
    # Startup banner.
    print(f"TrustGraph Queue Dumper")
    print(f"Monitoring {len(queues)} queue(s):")
    for q in queues:
        print(f"  - {q}")
    print(f"Output file: {output_file}")
    print(f"Mode: {'append' if append_mode else 'overwrite'}")
    print(f"Press Ctrl+C to stop\n")

    # Connect to Pulsar; a failed connection is fatal.
    try:
        client = pulsar.Client(pulsar_host, listener_name=listener_name)
    except Exception as e:
        print(f"Error connecting to Pulsar at {pulsar_host}: {e}", file=sys.stderr)
        sys.exit(1)

    # One Subscriber per queue; all feed a single central queue.
    central_queue = asyncio.Queue()
    subscribers = []

    for queue_name in queues:
        try:
            sub = Subscriber(
                client=client,
                topic=queue_name,
                subscription=subscriber_name,
                consumer_name=f"{subscriber_name}-{queue_name}",
                schema=None,  # No schema - accept any message type
            )
            await sub.start()
            subscribers.append((queue_name, sub))
            print(f"✓ Subscribed to: {queue_name}")
        except Exception as e:
            # A single bad queue shouldn't abort the rest.
            print(f"✗ Error subscribing to {queue_name}: {e}", file=sys.stderr)

    if not subscribers:
        print("\nNo subscribers created. Exiting.", file=sys.stderr)
        client.close()
        sys.exit(1)

    print(f"\nListening for messages...\n")

    mode = 'a' if append_mode else 'w'
    try:
        with open(output_file, mode) as f:
            # Session header so concatenated runs are distinguishable.
            f.write(f"\n{'#'*80}\n")
            f.write(f"# Session started: {datetime.now().isoformat()}\n")
            f.write(f"# Monitoring queues: {', '.join(queues)}\n")
            f.write(f"{'#'*80}\n")
            f.flush()

            # Single event coordinates shutdown across all tasks.
            shutdown_event = asyncio.Event()
            tasks = []
            try:
                # One monitor task per subscriber...
                for queue_name, sub in subscribers:
                    tasks.append(asyncio.create_task(
                        monitor_queue(
                            sub, queue_name, central_queue, "logger",
                            shutdown_event,
                        )
                    ))
                # ...and a single writer serialising file access.
                tasks.append(asyncio.create_task(
                    log_writer(central_queue, f, shutdown_event)
                ))

                await asyncio.gather(*tasks)

            except KeyboardInterrupt:
                print("\n\nStopping...")
            finally:
                # Ask every task to wind down, then give them a bounded
                # grace period to finish flushing.
                shutdown_event.set()
                try:
                    await asyncio.wait_for(
                        asyncio.gather(*tasks, return_exceptions=True),
                        timeout=2.0,
                    )
                except asyncio.TimeoutError:
                    print("Warning: Shutdown timeout", file=sys.stderr)

                # Session end marker.
                f.write(f"\n{'#'*80}\n")
                f.write(f"# Session ended: {datetime.now().isoformat()}\n")
                f.write(f"{'#'*80}\n")

    except IOError as e:
        print(f"Error writing to {output_file}: {e}", file=sys.stderr)
        sys.exit(1)
    finally:
        # Tear down Subscribers and the client even on error paths.
        for _, sub in subscribers:
            await sub.stop()
        client.close()

    print(f"\nMessages logged to: {output_file}")
|
||||
|
||||
def main():
    """CLI entry point: parse arguments and run the async queue dumper."""
    parser = argparse.ArgumentParser(
        prog='tg-dump-queues',
        description='Monitor and dump messages from multiple Pulsar queues',
        epilog="""
Examples:
  # Monitor agent and prompt queues
  tg-dump-queues non-persistent://tg/request/agent:default \\
                 non-persistent://tg/request/prompt:default

  # Monitor with custom output file
  tg-dump-queues non-persistent://tg/request/agent:default \\
                 --output debug.log

  # Append to existing log file
  tg-dump-queues non-persistent://tg/request/agent:default \\
                 --output queue.log --append

Common queue patterns:
  - Agent requests: non-persistent://tg/request/agent:default
  - Agent responses: non-persistent://tg/response/agent:default
  - Prompt requests: non-persistent://tg/request/prompt:default
  - Prompt responses: non-persistent://tg/response/prompt:default
  - LLM requests: non-persistent://tg/request/text-completion:default
  - LLM responses: non-persistent://tg/response/text-completion:default

IMPORTANT:
This tool subscribes to queues without a schema (schema-less mode). To avoid
schema conflicts, ensure that TrustGraph services and flows are already started
before running this tool. If this tool subscribes first, the real services may
encounter schema mismatch errors when they try to connect.

Best practice: Start services → Set up flows → Run tg-dump-queues
""",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument(
        'queues',
        nargs='+',
        help='Pulsar queue names to monitor'
    )

    # Optional flags, declared in display order.
    for flags, opts in [
        (('--output', '-o'),
         dict(default='queue.log',
              help='Output file (default: queue.log)')),
        (('--append', '-a'),
         dict(action='store_true',
              help='Append to output file instead of overwriting')),
        (('--pulsar-host',),
         dict(default='pulsar://localhost:6650',
              help='Pulsar host URL (default: pulsar://localhost:6650)')),
        (('--listener-name',),
         dict(default='localhost',
              help='Pulsar listener name (default: localhost)')),
        (('--subscriber',),
         dict(default='debug',
              help='Subscriber name for queue subscription (default: debug)')),
    ]:
        parser.add_argument(*flags, **opts)

    args = parser.parse_args()

    # Defensive: drop anything that looks like a stray flag.
    queues = [q for q in args.queues if not q.startswith('--')]
    if not queues:
        parser.error("No queues specified")

    try:
        asyncio.run(async_main(
            queues=queues,
            output_file=args.output,
            pulsar_host=args.pulsar_host,
            listener_name=args.listener_name,
            subscriber_name=args.subscriber,
            append_mode=args.append
        ))
    except KeyboardInterrupt:
        # Already handled in async_main
        pass
    except Exception as e:
        print(f"Fatal error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()
|
||||
|
|
@ -14,6 +14,78 @@ default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
|
|||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
|
||||
class Outputter:
    """Word-wrapping stdout printer that prefixes every line.

    Use as a context manager: the prefix is emitted on entry and a
    trailing newline on exit.  Text fed through output() is buffered a
    word at a time so wrapping never splits a word mid-stream.
    """

    def __init__(self, width=75, prefix="> "):
        self.width = width          # column at which to wrap
        self.prefix = prefix        # emitted at the start of every line
        self.column = 0             # current cursor column
        self.word_buffer = ""       # word accumulated so far
        self.just_wrapped = False   # suppresses a newline right after a wrap

    def __enter__(self):
        # Start the first line with the prefix.
        print(self.prefix, end="", flush=True)
        self.column = len(self.prefix)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Emit any word still buffered.
        if self.word_buffer:
            print(self.word_buffer, end="", flush=True)
            self.column += len(self.word_buffer)
            self.word_buffer = ""
        # Terminate the line unless we're already at column zero.
        if self.column > 0:
            print(flush=True)
            self.column = 0

    def _start_line(self):
        # Newline + prefix, resetting the column.
        print(flush=True)
        print(self.prefix, end="", flush=True)
        self.column = len(self.prefix)

    def _emit_word(self):
        # Wrap first if the word plus its trailing space won't fit.
        if self.column + len(self.word_buffer) + 1 > self.width:
            self._start_line()
            self.just_wrapped = True
        print(self.word_buffer, end="", flush=True)
        self.column += len(self.word_buffer)
        self.word_buffer = ""

    def output(self, text):
        for ch in text:
            if ch in (' ', '\t'):
                if self.word_buffer:
                    self._emit_word()
                # The whitespace itself is always printed.
                print(ch, end="", flush=True)
                self.column += 1
                self.just_wrapped = False
            elif ch == '\n':
                if self.just_wrapped:
                    # A wrap already moved to a fresh line; swallow this one.
                    self.just_wrapped = False
                else:
                    if self.word_buffer:
                        print(self.word_buffer, end="", flush=True)
                        self.word_buffer = ""
                    self._start_line()
                    self.just_wrapped = False
            else:
                # Ordinary character: grow the current word.
                self.word_buffer += ch
                self.just_wrapped = False
|
||||
|
||||
def wrap(text, width=75):
|
||||
if text is None: text = "n/a"
|
||||
out = textwrap.wrap(
|
||||
|
|
@ -41,6 +113,10 @@ async def question(
|
|||
output(wrap(question), "\U00002753 ")
|
||||
print()
|
||||
|
||||
# Track last chunk type and current outputter for streaming
|
||||
last_chunk_type = None
|
||||
current_outputter = None
|
||||
|
||||
def think(x):
|
||||
if verbose:
|
||||
output(wrap(x), "\U0001f914 ")
|
||||
|
|
@ -97,14 +173,30 @@ async def question(
|
|||
chunk_type = response["chunk_type"]
|
||||
content = response.get("content", "")
|
||||
|
||||
if chunk_type == "thought":
|
||||
think(content)
|
||||
elif chunk_type == "observation":
|
||||
observe(content)
|
||||
# Check if we're switching to a new message type
|
||||
if last_chunk_type != chunk_type:
|
||||
# Close previous outputter if exists
|
||||
if current_outputter:
|
||||
current_outputter.__exit__(None, None, None)
|
||||
current_outputter = None
|
||||
print() # Blank line between message types
|
||||
|
||||
# Create new outputter for new message type
|
||||
if chunk_type == "thought" and verbose:
|
||||
current_outputter = Outputter(width=78, prefix="\U0001f914 ")
|
||||
current_outputter.__enter__()
|
||||
elif chunk_type == "observation" and verbose:
|
||||
current_outputter = Outputter(width=78, prefix="\U0001f4a1 ")
|
||||
current_outputter.__enter__()
|
||||
# For answer, don't use Outputter - just print as-is
|
||||
|
||||
last_chunk_type = chunk_type
|
||||
|
||||
# Output the chunk
|
||||
if current_outputter:
|
||||
current_outputter.output(content)
|
||||
elif chunk_type == "answer":
|
||||
print(content)
|
||||
elif chunk_type == "error":
|
||||
raise RuntimeError(content)
|
||||
print(content, end="", flush=True)
|
||||
else:
|
||||
# Handle legacy format (backward compatibility)
|
||||
if "thought" in response:
|
||||
|
|
@ -119,7 +211,15 @@ async def question(
|
|||
if "error" in response:
|
||||
raise RuntimeError(response["error"])
|
||||
|
||||
if obj["complete"]: break
|
||||
if obj["complete"]:
|
||||
# Close any remaining outputter
|
||||
if current_outputter:
|
||||
current_outputter.__exit__(None, None, None)
|
||||
current_outputter = None
|
||||
# Add final newline if we were outputting answer
|
||||
elif last_chunk_type == "answer":
|
||||
print()
|
||||
break
|
||||
|
||||
await ws.close()
|
||||
|
||||
|
|
@ -212,4 +312,4 @@ def main():
|
|||
print("Exception:", e, flush=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -220,32 +220,72 @@ class AgentManager:
|
|||
|
||||
logger.info(f"prompt: {variables}")
|
||||
|
||||
logger.info(f"DEBUG: streaming={streaming}, think={think is not None}")
|
||||
|
||||
# Streaming path - use StreamingReActParser
|
||||
if streaming and think:
|
||||
logger.info("DEBUG: Entering streaming path")
|
||||
from .streaming_parser import StreamingReActParser
|
||||
|
||||
# Create parser with streaming callbacks
|
||||
# Thought chunks go to think(), answer chunks go to answer()
|
||||
logger.info("DEBUG: Creating StreamingReActParser")
|
||||
|
||||
# Collect chunks to send via async callbacks
|
||||
thought_chunks = []
|
||||
answer_chunks = []
|
||||
|
||||
# Create parser with synchronous callbacks that just collect chunks
|
||||
parser = StreamingReActParser(
|
||||
on_thought_chunk=lambda chunk: asyncio.create_task(think(chunk)),
|
||||
on_answer_chunk=lambda chunk: asyncio.create_task(answer(chunk) if answer else think(chunk)),
|
||||
on_thought_chunk=lambda chunk: thought_chunks.append(chunk),
|
||||
on_answer_chunk=lambda chunk: answer_chunks.append(chunk),
|
||||
)
|
||||
logger.info("DEBUG: StreamingReActParser created")
|
||||
|
||||
# Create async chunk callback that feeds parser
|
||||
# Create async chunk callback that feeds parser and sends collected chunks
|
||||
async def on_chunk(text):
|
||||
parser.feed(text)
|
||||
logger.info(f"DEBUG: on_chunk called with {len(text)} chars")
|
||||
|
||||
# Track what we had before
|
||||
prev_thought_count = len(thought_chunks)
|
||||
prev_answer_count = len(answer_chunks)
|
||||
|
||||
# Feed the parser (synchronous)
|
||||
logger.info(f"DEBUG: About to call parser.feed")
|
||||
parser.feed(text)
|
||||
logger.info(f"DEBUG: parser.feed returned")
|
||||
|
||||
# Send any new thought chunks
|
||||
for i in range(prev_thought_count, len(thought_chunks)):
|
||||
logger.info(f"DEBUG: Sending thought chunk {i}")
|
||||
await think(thought_chunks[i])
|
||||
|
||||
# Send any new answer chunks
|
||||
for i in range(prev_answer_count, len(answer_chunks)):
|
||||
logger.info(f"DEBUG: Sending answer chunk {i}")
|
||||
if answer:
|
||||
await answer(answer_chunks[i])
|
||||
else:
|
||||
await think(answer_chunks[i])
|
||||
|
||||
logger.info("DEBUG: Getting prompt-request client from context")
|
||||
client = context("prompt-request")
|
||||
logger.info(f"DEBUG: Got client: {client}")
|
||||
|
||||
logger.info("DEBUG: About to call agent_react with streaming=True")
|
||||
# Get streaming response
|
||||
response_text = await context("prompt-request").agent_react(
|
||||
response_text = await client.agent_react(
|
||||
variables=variables,
|
||||
streaming=True,
|
||||
chunk_callback=on_chunk
|
||||
)
|
||||
logger.info(f"DEBUG: agent_react returned, got {len(response_text) if response_text else 0} chars")
|
||||
|
||||
# Finalize parser
|
||||
logger.info("DEBUG: Finalizing parser")
|
||||
parser.finalize()
|
||||
logger.info("DEBUG: Parser finalized")
|
||||
|
||||
# Get result
|
||||
logger.info("DEBUG: Getting result from parser")
|
||||
result = parser.get_result()
|
||||
if result is None:
|
||||
raise RuntimeError("Parser failed to produce a result")
|
||||
|
|
@ -254,11 +294,18 @@ class AgentManager:
|
|||
return result
|
||||
|
||||
else:
|
||||
logger.info("DEBUG: Entering NON-streaming path")
|
||||
# Non-streaming path - get complete text and parse
|
||||
response_text = await context("prompt-request").agent_react(
|
||||
logger.info("DEBUG: Getting prompt-request client from context")
|
||||
client = context("prompt-request")
|
||||
logger.info(f"DEBUG: Got client: {client}")
|
||||
|
||||
logger.info("DEBUG: About to call agent_react with streaming=False")
|
||||
response_text = await client.agent_react(
|
||||
variables=variables,
|
||||
streaming=False
|
||||
)
|
||||
logger.info(f"DEBUG: agent_react returned, got response")
|
||||
|
||||
logger.debug(f"Response text:\n{response_text}")
|
||||
|
||||
|
|
|
|||
|
|
@ -118,7 +118,11 @@ class StreamingReActParser:
|
|||
self.line_buffer = re.sub(r'\n```$', '', self.line_buffer)
|
||||
|
||||
# Process based on current state
|
||||
# Track previous state to detect if we're making progress
|
||||
while self.line_buffer and self.state != ParserState.COMPLETE:
|
||||
prev_buffer_len = len(self.line_buffer)
|
||||
prev_state = self.state
|
||||
|
||||
if self.state == ParserState.INITIAL:
|
||||
self._process_initial()
|
||||
elif self.state == ParserState.THOUGHT:
|
||||
|
|
@ -130,14 +134,19 @@ class StreamingReActParser:
|
|||
elif self.state == ParserState.FINAL_ANSWER:
|
||||
self._process_final_answer()
|
||||
|
||||
# If no progress was made (buffer unchanged AND state unchanged), break
|
||||
# to avoid infinite loop. We'll process more when the next chunk arrives.
|
||||
if len(self.line_buffer) == prev_buffer_len and self.state == prev_state:
|
||||
break
|
||||
|
||||
def _process_initial(self) -> None:
|
||||
"""Process INITIAL state - looking for 'Thought:' delimiter"""
|
||||
idx = self.line_buffer.find(self.THOUGHT_DELIMITER)
|
||||
|
||||
if idx >= 0:
|
||||
# Found thought delimiter
|
||||
# Discard any content before it
|
||||
self.line_buffer = self.line_buffer[idx + len(self.THOUGHT_DELIMITER):]
|
||||
# Discard any content before it and strip leading whitespace after delimiter
|
||||
self.line_buffer = self.line_buffer[idx + len(self.THOUGHT_DELIMITER):].lstrip()
|
||||
self.state = ParserState.THOUGHT
|
||||
elif len(self.line_buffer) >= self.MAX_DELIMITER_BUFFER:
|
||||
# Buffer getting too large, probably junk before thought
|
||||
|
|
@ -171,7 +180,7 @@ class StreamingReActParser:
|
|||
if self.on_thought_chunk:
|
||||
self.on_thought_chunk(thought_chunk)
|
||||
|
||||
self.line_buffer = self.line_buffer[next_delimiter_idx + delimiter_len:]
|
||||
self.line_buffer = self.line_buffer[next_delimiter_idx + delimiter_len:].lstrip()
|
||||
self.state = next_state
|
||||
else:
|
||||
# No delimiter found yet
|
||||
|
|
@ -194,7 +203,7 @@ class StreamingReActParser:
|
|||
if args_idx >= 0 and (newline_idx < 0 or args_idx < newline_idx):
|
||||
# Args delimiter found first
|
||||
self.action_buffer = self.line_buffer[:args_idx].strip().strip('"')
|
||||
self.line_buffer = self.line_buffer[args_idx + len(self.ARGS_DELIMITER):]
|
||||
self.line_buffer = self.line_buffer[args_idx + len(self.ARGS_DELIMITER):].lstrip()
|
||||
self.state = ParserState.ARGS
|
||||
elif newline_idx >= 0:
|
||||
# Newline found, action name complete
|
||||
|
|
@ -204,7 +213,7 @@ class StreamingReActParser:
|
|||
# Actually, check if next line has Args:
|
||||
if self.line_buffer.lstrip().startswith(self.ARGS_DELIMITER):
|
||||
args_start = self.line_buffer.find(self.ARGS_DELIMITER)
|
||||
self.line_buffer = self.line_buffer[args_start + len(self.ARGS_DELIMITER):]
|
||||
self.line_buffer = self.line_buffer[args_start + len(self.ARGS_DELIMITER):].lstrip()
|
||||
self.state = ParserState.ARGS
|
||||
else:
|
||||
# Not enough content yet, keep buffering
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue