mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Fix streaming API niggles (#599)
* Fix end-of-stream anomaly with some graph-rag and document-rag * Fix gateway translators dropping responses
This commit is contained in:
parent
3c675b8cfc
commit
f0c95a4c5e
5 changed files with 29 additions and 42 deletions
|
|
@ -49,20 +49,21 @@ class PromptClient(RequestResponse):
|
|||
logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}")
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
end_stream = getattr(resp, 'end_of_stream', False)
|
||||
|
||||
if resp.text:
|
||||
last_text = resp.text
|
||||
# Call chunk callback if provided
|
||||
# Call chunk callback if provided with both chunk and end_of_stream flag
|
||||
if chunk_callback:
|
||||
logger.info(f"DEBUG prompt_client: Calling chunk_callback")
|
||||
logger.info(f"DEBUG prompt_client: Calling chunk_callback with end_of_stream={end_stream}")
|
||||
if asyncio.iscoroutinefunction(chunk_callback):
|
||||
await chunk_callback(resp.text)
|
||||
await chunk_callback(resp.text, end_stream)
|
||||
else:
|
||||
chunk_callback(resp.text)
|
||||
chunk_callback(resp.text, end_stream)
|
||||
elif resp.object:
|
||||
logger.info(f"DEBUG prompt_client: Got object response")
|
||||
last_object = resp.object
|
||||
|
||||
end_stream = getattr(resp, 'end_of_stream', False)
|
||||
logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}")
|
||||
return end_stream
|
||||
|
||||
|
|
|
|||
|
|
@ -42,12 +42,17 @@ class PromptResponseTranslator(MessageTranslator):
|
|||
|
||||
def from_pulsar(self, obj: PromptResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
if obj.text:
|
||||
|
||||
# Include text field if present (even if empty string)
|
||||
if obj.text is not None:
|
||||
result["text"] = obj.text
|
||||
if obj.object:
|
||||
# Include object field if present
|
||||
if obj.object is not None:
|
||||
result["object"] = obj.object
|
||||
|
||||
|
||||
# Always include end_of_stream flag for streaming support
|
||||
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
|
||||
|
||||
return result
|
||||
|
||||
def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
|
|
|
|||
|
|
@ -34,8 +34,8 @@ class DocumentRagResponseTranslator(MessageTranslator):
|
|||
def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
# Include response content (chunk or complete)
|
||||
if obj.response:
|
||||
# Include response content (even if empty string)
|
||||
if obj.response is not None:
|
||||
result["response"] = obj.response
|
||||
|
||||
# Include end_of_stream flag
|
||||
|
|
@ -90,8 +90,8 @@ class GraphRagResponseTranslator(MessageTranslator):
|
|||
def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
# Include response content (chunk or complete)
|
||||
if obj.response:
|
||||
# Include response content (even if empty string)
|
||||
if obj.response is not None:
|
||||
result["response"] = obj.response
|
||||
|
||||
# Include end_of_stream flag
|
||||
|
|
|
|||
|
|
@ -95,19 +95,20 @@ class Processor(FlowProcessor):
|
|||
# Check if streaming is requested
|
||||
if v.streaming:
|
||||
# Define async callback for streaming chunks
|
||||
async def send_chunk(chunk):
|
||||
# Receives chunk text and end_of_stream flag from prompt client
|
||||
async def send_chunk(chunk, end_of_stream):
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
response=chunk,
|
||||
end_of_stream=False,
|
||||
end_of_stream=end_of_stream,
|
||||
error=None
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
# Query with streaming enabled
|
||||
# The query returns the last chunk (not accumulated text)
|
||||
final_response = await self.rag.query(
|
||||
# All chunks (including final one with end_of_stream=True) are sent via callback
|
||||
await self.rag.query(
|
||||
v.query,
|
||||
user=v.user,
|
||||
collection=v.collection,
|
||||
|
|
@ -115,16 +116,6 @@ class Processor(FlowProcessor):
|
|||
streaming=True,
|
||||
chunk_callback=send_chunk,
|
||||
)
|
||||
|
||||
# Send final message with last chunk
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
response=final_response if final_response else "",
|
||||
end_of_stream=True,
|
||||
error=None
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
else:
|
||||
# Non-streaming path (existing behavior)
|
||||
response = await self.rag.query(
|
||||
|
|
|
|||
|
|
@ -138,19 +138,20 @@ class Processor(FlowProcessor):
|
|||
# Check if streaming is requested
|
||||
if v.streaming:
|
||||
# Define async callback for streaming chunks
|
||||
async def send_chunk(chunk):
|
||||
# Receives chunk text and end_of_stream flag from prompt client
|
||||
async def send_chunk(chunk, end_of_stream):
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
response=chunk,
|
||||
end_of_stream=False,
|
||||
end_of_stream=end_of_stream,
|
||||
error=None
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
# Query with streaming enabled
|
||||
# The query will send chunks via callback AND return the complete text
|
||||
final_response = await rag.query(
|
||||
# All chunks (including final one with end_of_stream=True) are sent via callback
|
||||
await rag.query(
|
||||
query = v.query, user = v.user, collection = v.collection,
|
||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
|
|
@ -158,17 +159,6 @@ class Processor(FlowProcessor):
|
|||
streaming = True,
|
||||
chunk_callback = send_chunk,
|
||||
)
|
||||
|
||||
# Send final message - may have last chunk of content with end_of_stream=True
|
||||
# (prompt service may send final chunk with text, so we pass through whatever we got)
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
response=final_response if final_response else "",
|
||||
end_of_stream=True,
|
||||
error=None
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
else:
|
||||
# Non-streaming path (existing behavior)
|
||||
response = await rag.query(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue