mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 16:36:21 +02:00
Streaming LLM part 2 (#567)
* Updates for agent API with streaming support * Added tg-dump-queues tool to dump Pulsar queues to a log * Updated tg-invoke-agent, incremental output * Queue dumper CLI - might be useful for debug * Updating for tests
This commit is contained in:
parent
310a2deb06
commit
b1cc724f7d
8 changed files with 609 additions and 51 deletions
|
|
@ -1,15 +1,20 @@
|
|||
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import PromptRequest, PromptResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PromptClient(RequestResponse):
|
||||
|
||||
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
|
||||
logger.info(f"DEBUG prompt_client: prompt called, id={id}, streaming={streaming}, chunk_callback={chunk_callback is not None}")
|
||||
|
||||
if not streaming:
|
||||
logger.info("DEBUG prompt_client: Non-streaming path")
|
||||
# Non-streaming path
|
||||
resp = await self.request(
|
||||
PromptRequest(
|
||||
|
|
@ -31,44 +36,59 @@ class PromptClient(RequestResponse):
|
|||
return json.loads(resp.object)
|
||||
|
||||
else:
|
||||
logger.info("DEBUG prompt_client: Streaming path")
|
||||
# Streaming path - collect all chunks
|
||||
full_text = ""
|
||||
full_object = None
|
||||
|
||||
async def collect_chunks(resp):
|
||||
nonlocal full_text, full_object
|
||||
logger.info(f"DEBUG prompt_client: collect_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}")
|
||||
|
||||
if resp.error:
|
||||
logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}")
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
if resp.text:
|
||||
full_text += resp.text
|
||||
logger.info(f"DEBUG prompt_client: Accumulated {len(full_text)} chars")
|
||||
# Call chunk callback if provided
|
||||
if chunk_callback:
|
||||
logger.info(f"DEBUG prompt_client: Calling chunk_callback")
|
||||
if asyncio.iscoroutinefunction(chunk_callback):
|
||||
await chunk_callback(resp.text)
|
||||
else:
|
||||
chunk_callback(resp.text)
|
||||
elif resp.object:
|
||||
logger.info(f"DEBUG prompt_client: Got object response")
|
||||
full_object = resp.object
|
||||
|
||||
return getattr(resp, 'end_of_stream', False)
|
||||
end_stream = getattr(resp, 'end_of_stream', False)
|
||||
logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}")
|
||||
return end_stream
|
||||
|
||||
logger.info("DEBUG prompt_client: Creating PromptRequest")
|
||||
req = PromptRequest(
|
||||
id = id,
|
||||
terms = {
|
||||
k: json.dumps(v)
|
||||
for k, v in variables.items()
|
||||
},
|
||||
streaming = True
|
||||
)
|
||||
logger.info(f"DEBUG prompt_client: About to call self.request with recipient, timeout={timeout}")
|
||||
await self.request(
|
||||
PromptRequest(
|
||||
id = id,
|
||||
terms = {
|
||||
k: json.dumps(v)
|
||||
for k, v in variables.items()
|
||||
},
|
||||
streaming = True
|
||||
),
|
||||
req,
|
||||
recipient=collect_chunks,
|
||||
timeout=timeout
|
||||
)
|
||||
logger.info(f"DEBUG prompt_client: self.request returned, full_text has {len(full_text)} chars")
|
||||
|
||||
if full_text: return full_text
|
||||
if full_text:
|
||||
logger.info("DEBUG prompt_client: Returning full_text")
|
||||
return full_text
|
||||
|
||||
logger.info("DEBUG prompt_client: Returning parsed full_object")
|
||||
return json.loads(full_object)
|
||||
|
||||
async def extract_definitions(self, text, timeout=600):
|
||||
|
|
|
|||
|
|
@ -43,12 +43,18 @@ class Subscriber:
|
|||
|
||||
async def start(self):
    """Create the Pulsar consumer and launch the background receive task."""

    # Assemble subscribe() keyword arguments; the schema entry is
    # included only when one was configured, so a None schema is
    # simply omitted from the call.
    kwargs = dict(
        topic=self.topic,
        subscription_name=self.subscription,
        consumer_name=self.consumer_name,
    )

    if self.schema is not None:
        kwargs['schema'] = JsonSchema(self.schema)

    self.consumer = self.client.subscribe(**kwargs)

    self.task = asyncio.create_task(self.run())
|
||||
|
||||
|
|
@ -87,10 +93,14 @@ class Subscriber:
|
|||
if self.draining and drain_end_time is None:
|
||||
drain_end_time = time.time() + self.drain_timeout
|
||||
logger.info(f"Subscriber entering drain mode, timeout={self.drain_timeout}s")
|
||||
|
||||
|
||||
# Stop accepting new messages from Pulsar during drain
|
||||
if self.consumer:
|
||||
self.consumer.pause_message_listener()
|
||||
try:
|
||||
self.consumer.pause_message_listener()
|
||||
except _pulsar.InvalidConfiguration:
|
||||
# Not all consumers have message listeners (e.g., blocking receive mode)
|
||||
pass
|
||||
|
||||
# Check drain timeout
|
||||
if self.draining and drain_end_time and time.time() > drain_end_time:
|
||||
|
|
@ -145,12 +155,21 @@ class Subscriber:
|
|||
finally:
|
||||
# Negative acknowledge any pending messages
|
||||
for msg in self.pending_acks.values():
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
try:
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
except _pulsar.AlreadyClosed:
|
||||
pass # Consumer already closed
|
||||
self.pending_acks.clear()
|
||||
|
||||
if self.consumer:
|
||||
self.consumer.unsubscribe()
|
||||
self.consumer.close()
|
||||
try:
|
||||
self.consumer.unsubscribe()
|
||||
except _pulsar.AlreadyClosed:
|
||||
pass # Already closed
|
||||
try:
|
||||
self.consumer.close()
|
||||
except _pulsar.AlreadyClosed:
|
||||
pass # Already closed
|
||||
self.consumer = None
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue