mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Fix streaming API niggles (#599)
* Fix end-of-stream anomaly with some graph-rag and document-rag * Fix gateway translators dropping responses
This commit is contained in:
parent
3c675b8cfc
commit
f0c95a4c5e
5 changed files with 29 additions and 42 deletions
|
|
@ -49,20 +49,21 @@ class PromptClient(RequestResponse):
|
||||||
logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}")
|
logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}")
|
||||||
raise RuntimeError(resp.error.message)
|
raise RuntimeError(resp.error.message)
|
||||||
|
|
||||||
|
end_stream = getattr(resp, 'end_of_stream', False)
|
||||||
|
|
||||||
if resp.text:
|
if resp.text:
|
||||||
last_text = resp.text
|
last_text = resp.text
|
||||||
# Call chunk callback if provided
|
# Call chunk callback if provided with both chunk and end_of_stream flag
|
||||||
if chunk_callback:
|
if chunk_callback:
|
||||||
logger.info(f"DEBUG prompt_client: Calling chunk_callback")
|
logger.info(f"DEBUG prompt_client: Calling chunk_callback with end_of_stream={end_stream}")
|
||||||
if asyncio.iscoroutinefunction(chunk_callback):
|
if asyncio.iscoroutinefunction(chunk_callback):
|
||||||
await chunk_callback(resp.text)
|
await chunk_callback(resp.text, end_stream)
|
||||||
else:
|
else:
|
||||||
chunk_callback(resp.text)
|
chunk_callback(resp.text, end_stream)
|
||||||
elif resp.object:
|
elif resp.object:
|
||||||
logger.info(f"DEBUG prompt_client: Got object response")
|
logger.info(f"DEBUG prompt_client: Got object response")
|
||||||
last_object = resp.object
|
last_object = resp.object
|
||||||
|
|
||||||
end_stream = getattr(resp, 'end_of_stream', False)
|
|
||||||
logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}")
|
logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}")
|
||||||
return end_stream
|
return end_stream
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,12 +42,17 @@ class PromptResponseTranslator(MessageTranslator):
|
||||||
|
|
||||||
def from_pulsar(self, obj: PromptResponse) -> Dict[str, Any]:
|
def from_pulsar(self, obj: PromptResponse) -> Dict[str, Any]:
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
if obj.text:
|
# Include text field if present (even if empty string)
|
||||||
|
if obj.text is not None:
|
||||||
result["text"] = obj.text
|
result["text"] = obj.text
|
||||||
if obj.object:
|
# Include object field if present
|
||||||
|
if obj.object is not None:
|
||||||
result["object"] = obj.object
|
result["object"] = obj.object
|
||||||
|
|
||||||
|
# Always include end_of_stream flag for streaming support
|
||||||
|
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:
|
def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
|
|
||||||
|
|
@ -34,8 +34,8 @@ class DocumentRagResponseTranslator(MessageTranslator):
|
||||||
def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
|
def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
# Include response content (chunk or complete)
|
# Include response content (even if empty string)
|
||||||
if obj.response:
|
if obj.response is not None:
|
||||||
result["response"] = obj.response
|
result["response"] = obj.response
|
||||||
|
|
||||||
# Include end_of_stream flag
|
# Include end_of_stream flag
|
||||||
|
|
@ -90,8 +90,8 @@ class GraphRagResponseTranslator(MessageTranslator):
|
||||||
def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
|
def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
# Include response content (chunk or complete)
|
# Include response content (even if empty string)
|
||||||
if obj.response:
|
if obj.response is not None:
|
||||||
result["response"] = obj.response
|
result["response"] = obj.response
|
||||||
|
|
||||||
# Include end_of_stream flag
|
# Include end_of_stream flag
|
||||||
|
|
|
||||||
|
|
@ -95,19 +95,20 @@ class Processor(FlowProcessor):
|
||||||
# Check if streaming is requested
|
# Check if streaming is requested
|
||||||
if v.streaming:
|
if v.streaming:
|
||||||
# Define async callback for streaming chunks
|
# Define async callback for streaming chunks
|
||||||
async def send_chunk(chunk):
|
# Receives chunk text and end_of_stream flag from prompt client
|
||||||
|
async def send_chunk(chunk, end_of_stream):
|
||||||
await flow("response").send(
|
await flow("response").send(
|
||||||
DocumentRagResponse(
|
DocumentRagResponse(
|
||||||
response=chunk,
|
response=chunk,
|
||||||
end_of_stream=False,
|
end_of_stream=end_of_stream,
|
||||||
error=None
|
error=None
|
||||||
),
|
),
|
||||||
properties={"id": id}
|
properties={"id": id}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Query with streaming enabled
|
# Query with streaming enabled
|
||||||
# The query returns the last chunk (not accumulated text)
|
# All chunks (including final one with end_of_stream=True) are sent via callback
|
||||||
final_response = await self.rag.query(
|
await self.rag.query(
|
||||||
v.query,
|
v.query,
|
||||||
user=v.user,
|
user=v.user,
|
||||||
collection=v.collection,
|
collection=v.collection,
|
||||||
|
|
@ -115,16 +116,6 @@ class Processor(FlowProcessor):
|
||||||
streaming=True,
|
streaming=True,
|
||||||
chunk_callback=send_chunk,
|
chunk_callback=send_chunk,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send final message with last chunk
|
|
||||||
await flow("response").send(
|
|
||||||
DocumentRagResponse(
|
|
||||||
response=final_response if final_response else "",
|
|
||||||
end_of_stream=True,
|
|
||||||
error=None
|
|
||||||
),
|
|
||||||
properties={"id": id}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Non-streaming path (existing behavior)
|
# Non-streaming path (existing behavior)
|
||||||
response = await self.rag.query(
|
response = await self.rag.query(
|
||||||
|
|
|
||||||
|
|
@ -138,19 +138,20 @@ class Processor(FlowProcessor):
|
||||||
# Check if streaming is requested
|
# Check if streaming is requested
|
||||||
if v.streaming:
|
if v.streaming:
|
||||||
# Define async callback for streaming chunks
|
# Define async callback for streaming chunks
|
||||||
async def send_chunk(chunk):
|
# Receives chunk text and end_of_stream flag from prompt client
|
||||||
|
async def send_chunk(chunk, end_of_stream):
|
||||||
await flow("response").send(
|
await flow("response").send(
|
||||||
GraphRagResponse(
|
GraphRagResponse(
|
||||||
response=chunk,
|
response=chunk,
|
||||||
end_of_stream=False,
|
end_of_stream=end_of_stream,
|
||||||
error=None
|
error=None
|
||||||
),
|
),
|
||||||
properties={"id": id}
|
properties={"id": id}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Query with streaming enabled
|
# Query with streaming enabled
|
||||||
# The query will send chunks via callback AND return the complete text
|
# All chunks (including final one with end_of_stream=True) are sent via callback
|
||||||
final_response = await rag.query(
|
await rag.query(
|
||||||
query = v.query, user = v.user, collection = v.collection,
|
query = v.query, user = v.user, collection = v.collection,
|
||||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||||
max_subgraph_size = max_subgraph_size,
|
max_subgraph_size = max_subgraph_size,
|
||||||
|
|
@ -158,17 +159,6 @@ class Processor(FlowProcessor):
|
||||||
streaming = True,
|
streaming = True,
|
||||||
chunk_callback = send_chunk,
|
chunk_callback = send_chunk,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send final message - may have last chunk of content with end_of_stream=True
|
|
||||||
# (prompt service may send final chunk with text, so we pass through whatever we got)
|
|
||||||
await flow("response").send(
|
|
||||||
GraphRagResponse(
|
|
||||||
response=final_response if final_response else "",
|
|
||||||
end_of_stream=True,
|
|
||||||
error=None
|
|
||||||
),
|
|
||||||
properties={"id": id}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Non-streaming path (existing behavior)
|
# Non-streaming path (existing behavior)
|
||||||
response = await rag.query(
|
response = await rag.query(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue