Fix streaming API niggles (#599)

* Fix end-of-stream anomaly with some graph-rag and document-rag

* Fix gateway translators dropping responses
This commit is contained in:
cybermaggedon 2026-01-06 16:41:35 +00:00 committed by GitHub
parent 3c675b8cfc
commit f0c95a4c5e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 29 additions and 42 deletions

View file

@ -49,20 +49,21 @@ class PromptClient(RequestResponse):
logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}") logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}")
raise RuntimeError(resp.error.message) raise RuntimeError(resp.error.message)
end_stream = getattr(resp, 'end_of_stream', False)
if resp.text: if resp.text:
last_text = resp.text last_text = resp.text
# Call chunk callback if provided # Call chunk callback if provided with both chunk and end_of_stream flag
if chunk_callback: if chunk_callback:
logger.info(f"DEBUG prompt_client: Calling chunk_callback") logger.info(f"DEBUG prompt_client: Calling chunk_callback with end_of_stream={end_stream}")
if asyncio.iscoroutinefunction(chunk_callback): if asyncio.iscoroutinefunction(chunk_callback):
await chunk_callback(resp.text) await chunk_callback(resp.text, end_stream)
else: else:
chunk_callback(resp.text) chunk_callback(resp.text, end_stream)
elif resp.object: elif resp.object:
logger.info(f"DEBUG prompt_client: Got object response") logger.info(f"DEBUG prompt_client: Got object response")
last_object = resp.object last_object = resp.object
end_stream = getattr(resp, 'end_of_stream', False)
logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}") logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}")
return end_stream return end_stream

View file

@ -42,12 +42,17 @@ class PromptResponseTranslator(MessageTranslator):
def from_pulsar(self, obj: PromptResponse) -> Dict[str, Any]: def from_pulsar(self, obj: PromptResponse) -> Dict[str, Any]:
result = {} result = {}
if obj.text: # Include text field if present (even if empty string)
if obj.text is not None:
result["text"] = obj.text result["text"] = obj.text
if obj.object: # Include object field if present
if obj.object is not None:
result["object"] = obj.object result["object"] = obj.object
# Always include end_of_stream flag for streaming support
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
return result return result
def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]: def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:

View file

@ -34,8 +34,8 @@ class DocumentRagResponseTranslator(MessageTranslator):
def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]: def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
result = {} result = {}
# Include response content (chunk or complete) # Include response content (even if empty string)
if obj.response: if obj.response is not None:
result["response"] = obj.response result["response"] = obj.response
# Include end_of_stream flag # Include end_of_stream flag
@ -90,8 +90,8 @@ class GraphRagResponseTranslator(MessageTranslator):
def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]: def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
result = {} result = {}
# Include response content (chunk or complete) # Include response content (even if empty string)
if obj.response: if obj.response is not None:
result["response"] = obj.response result["response"] = obj.response
# Include end_of_stream flag # Include end_of_stream flag

View file

@ -95,19 +95,20 @@ class Processor(FlowProcessor):
# Check if streaming is requested # Check if streaming is requested
if v.streaming: if v.streaming:
# Define async callback for streaming chunks # Define async callback for streaming chunks
async def send_chunk(chunk): # Receives chunk text and end_of_stream flag from prompt client
async def send_chunk(chunk, end_of_stream):
await flow("response").send( await flow("response").send(
DocumentRagResponse( DocumentRagResponse(
response=chunk, response=chunk,
end_of_stream=False, end_of_stream=end_of_stream,
error=None error=None
), ),
properties={"id": id} properties={"id": id}
) )
# Query with streaming enabled # Query with streaming enabled
# The query returns the last chunk (not accumulated text) # All chunks (including final one with end_of_stream=True) are sent via callback
final_response = await self.rag.query( await self.rag.query(
v.query, v.query,
user=v.user, user=v.user,
collection=v.collection, collection=v.collection,
@ -115,16 +116,6 @@ class Processor(FlowProcessor):
streaming=True, streaming=True,
chunk_callback=send_chunk, chunk_callback=send_chunk,
) )
# Send final message with last chunk
await flow("response").send(
DocumentRagResponse(
response=final_response if final_response else "",
end_of_stream=True,
error=None
),
properties={"id": id}
)
else: else:
# Non-streaming path (existing behavior) # Non-streaming path (existing behavior)
response = await self.rag.query( response = await self.rag.query(

View file

@ -138,19 +138,20 @@ class Processor(FlowProcessor):
# Check if streaming is requested # Check if streaming is requested
if v.streaming: if v.streaming:
# Define async callback for streaming chunks # Define async callback for streaming chunks
async def send_chunk(chunk): # Receives chunk text and end_of_stream flag from prompt client
async def send_chunk(chunk, end_of_stream):
await flow("response").send( await flow("response").send(
GraphRagResponse( GraphRagResponse(
response=chunk, response=chunk,
end_of_stream=False, end_of_stream=end_of_stream,
error=None error=None
), ),
properties={"id": id} properties={"id": id}
) )
# Query with streaming enabled # Query with streaming enabled
# The query will send chunks via callback AND return the complete text # All chunks (including final one with end_of_stream=True) are sent via callback
final_response = await rag.query( await rag.query(
query = v.query, user = v.user, collection = v.collection, query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit, entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size, max_subgraph_size = max_subgraph_size,
@ -158,17 +159,6 @@ class Processor(FlowProcessor):
streaming = True, streaming = True,
chunk_callback = send_chunk, chunk_callback = send_chunk,
) )
# Send final message - may have last chunk of content with end_of_stream=True
# (prompt service may send final chunk with text, so we pass through whatever we got)
await flow("response").send(
GraphRagResponse(
response=final_response if final_response else "",
end_of_stream=True,
error=None
),
properties={"id": id}
)
else: else:
# Non-streaming path (existing behavior) # Non-streaming path (existing behavior)
response = await rag.query( response = await rag.query(