diff --git a/trustgraph-base/trustgraph/base/prompt_client.py b/trustgraph-base/trustgraph/base/prompt_client.py index 55c54cfc..370cf78a 100644 --- a/trustgraph-base/trustgraph/base/prompt_client.py +++ b/trustgraph-base/trustgraph/base/prompt_client.py @@ -49,20 +49,21 @@ class PromptClient(RequestResponse): logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}") raise RuntimeError(resp.error.message) + end_stream = getattr(resp, 'end_of_stream', False) + if resp.text: last_text = resp.text - # Call chunk callback if provided + # Call chunk callback if provided with both chunk and end_of_stream flag if chunk_callback: - logger.info(f"DEBUG prompt_client: Calling chunk_callback") + logger.info(f"DEBUG prompt_client: Calling chunk_callback with end_of_stream={end_stream}") if asyncio.iscoroutinefunction(chunk_callback): - await chunk_callback(resp.text) + await chunk_callback(resp.text, end_stream) else: - chunk_callback(resp.text) + chunk_callback(resp.text, end_stream) elif resp.object: logger.info(f"DEBUG prompt_client: Got object response") last_object = resp.object - end_stream = getattr(resp, 'end_of_stream', False) logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}") return end_stream diff --git a/trustgraph-base/trustgraph/messaging/translators/prompt.py b/trustgraph-base/trustgraph/messaging/translators/prompt.py index 8916a77c..5ff99fdc 100644 --- a/trustgraph-base/trustgraph/messaging/translators/prompt.py +++ b/trustgraph-base/trustgraph/messaging/translators/prompt.py @@ -42,12 +42,17 @@ class PromptResponseTranslator(MessageTranslator): def from_pulsar(self, obj: PromptResponse) -> Dict[str, Any]: result = {} - - if obj.text: + + # Include text field if present (even if empty string) + if obj.text is not None: result["text"] = obj.text - if obj.object: + # Include object field if present + if obj.object is not None: result["object"] = obj.object - + + # Always include end_of_stream flag for streaming support + result["end_of_stream"] = getattr(obj, "end_of_stream", False) + return result def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]: diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index d2161cff..22166bd9 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -34,8 +34,8 @@ class DocumentRagResponseTranslator(MessageTranslator): def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]: result = {} - # Include response content (chunk or complete) - if obj.response: + # Include response content (even if empty string) + if obj.response is not None: result["response"] = obj.response # Include end_of_stream flag @@ -90,8 +90,8 @@ class GraphRagResponseTranslator(MessageTranslator): def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]: result = {} - # Include response content (chunk or complete) - if obj.response: + # Include response content (even if empty string) + if obj.response is not None: result["response"] = obj.response # Include end_of_stream flag diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index ec67a072..14d71d97 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -95,19 +95,20 @@ class Processor(FlowProcessor): # Check if streaming is requested if v.streaming: # Define async callback for streaming chunks - async def send_chunk(chunk): + # Receives chunk text and end_of_stream flag from prompt client + async def send_chunk(chunk, end_of_stream): await flow("response").send( DocumentRagResponse( response=chunk, - end_of_stream=False, + end_of_stream=end_of_stream, error=None ), properties={"id": id} ) # Query with streaming enabled - # The query returns the last chunk (not accumulated text) - final_response = await self.rag.query( + # All chunks (including final one with end_of_stream=True) are sent via callback + await self.rag.query( v.query, user=v.user, collection=v.collection, @@ -115,16 +116,6 @@ class Processor(FlowProcessor): streaming=True, chunk_callback=send_chunk, ) - - # Send final message with last chunk - await flow("response").send( - DocumentRagResponse( - response=final_response if final_response else "", - end_of_stream=True, - error=None - ), - properties={"id": id} - ) else: # Non-streaming path (existing behavior) response = await self.rag.query( diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index de1f0e24..d159dbae 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -138,19 +138,20 @@ class Processor(FlowProcessor): # Check if streaming is requested if v.streaming: # Define async callback for streaming chunks - async def send_chunk(chunk): + # Receives chunk text and end_of_stream flag from prompt client + async def send_chunk(chunk, end_of_stream): await flow("response").send( GraphRagResponse( response=chunk, - end_of_stream=False, + end_of_stream=end_of_stream, error=None ), properties={"id": id} ) # Query with streaming enabled - # The query will send chunks via callback AND return the complete text - final_response = await rag.query( + # All chunks (including final one with end_of_stream=True) are sent via callback + await rag.query( query = v.query, user = v.user, collection = v.collection, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, @@ -158,17 +159,6 @@ class Processor(FlowProcessor): streaming = True, chunk_callback = send_chunk, ) - - # Send final message - may have last chunk of content with end_of_stream=True - # (prompt service may send final chunk with text, so we pass through whatever we got) - await flow("response").send( - GraphRagResponse( - response=final_response if final_response else "", - end_of_stream=True, - error=None - ), - properties={"id": id} - ) else: # Non-streaming path (existing behavior) response = await rag.query(