diff --git a/tests/integration/test_agent_manager_integration.py b/tests/integration/test_agent_manager_integration.py index 652894a2..a19f4c36 100644 --- a/tests/integration/test_agent_manager_integration.py +++ b/tests/integration/test_agent_manager_integration.py @@ -15,6 +15,7 @@ from trustgraph.agent.react.agent_manager import AgentManager from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl from trustgraph.agent.react.types import Action, Final, Tool, Argument from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error +from trustgraph.base import PromptResult @pytest.mark.integration @@ -28,19 +29,25 @@ class TestAgentManagerIntegration: # Mock prompt client prompt_client = AsyncMock() - prompt_client.agent_react.return_value = """Thought: I need to search for information about machine learning + prompt_client.agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to search for information about machine learning Action: knowledge_query Args: { "question": "What is machine learning?" }""" - + ) + # Mock graph RAG client graph_rag_client = AsyncMock() graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data." - + # Mock text completion client text_completion_client = AsyncMock() - text_completion_client.question.return_value = "Machine learning involves algorithms that improve through experience." + text_completion_client.question.return_value = PromptResult( + response_type="text", + text="Machine learning involves algorithms that improve through experience." 
+ ) # Mock MCP tool client mcp_tool_client = AsyncMock() @@ -147,8 +154,11 @@ Args: { async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context): """Test agent manager returning final answer""" # Arrange - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have enough information to answer the question + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I have enough information to answer the question Final Answer: Machine learning is a field of AI that enables computers to learn from data.""" + ) question = "What is machine learning?" history = [] @@ -193,8 +203,11 @@ Final Answer: Machine learning is a field of AI that enables computers to learn async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context): """Test ReAct cycle ending with final answer""" # Arrange - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide a direct answer + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I can provide a direct answer Final Answer: Machine learning is a branch of artificial intelligence.""" + ) question = "What is machine learning?" 
history = [] @@ -254,11 +267,14 @@ Final Answer: Machine learning is a branch of artificial intelligence.""" for tool_name, expected_service in tool_scenarios: # Arrange - mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: I need to use {tool_name} + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text=f"""Thought: I need to use {tool_name} Action: {tool_name} Args: {{ "question": "test question" }}""" + ) think_callback = AsyncMock() observe_callback = AsyncMock() @@ -284,11 +300,14 @@ Args: {{ async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context): """Test agent manager error handling for unknown tool""" # Arrange - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to use an unknown tool + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to use an unknown tool Action: unknown_tool Args: { "param": "value" }""" + ) think_callback = AsyncMock() observe_callback = AsyncMock() @@ -321,11 +340,14 @@ Args: { question = "Find information about AI and summarize it" # Mock multi-step reasoning - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search for AI information first + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to search for AI information first Action: knowledge_query Args: { "question": "What is artificial intelligence?" 
}""" + ) # Act action = await agent_manager.reason(question, [], mock_flow_context) @@ -372,9 +394,12 @@ Args: { # Format arguments as JSON import json args_json = json.dumps(test_case['arguments'], indent=4) - mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: Using {test_case['action']} + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text=f"""Thought: Using {test_case['action']} Action: {test_case['action']} Args: {args_json}""" + ) think_callback = AsyncMock() observe_callback = AsyncMock() @@ -507,7 +532,10 @@ Args: { ] for test_case in test_cases: - mock_flow_context("prompt-request").agent_react.return_value = test_case["response"] + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text=test_case["response"] + ) if test_case["error_contains"]: # Should raise an error @@ -527,13 +555,16 @@ Args: { async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context): """Test edge cases in text parsing""" # Test response with markdown code blocks - mock_flow_context("prompt-request").agent_react.return_value = """``` + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""``` Thought: I need to search for information Action: knowledge_query Args: { "question": "What is AI?" 
} ```""" + ) action = await agent_manager.reason("test", [], mock_flow_context) assert isinstance(action, Action) @@ -541,15 +572,18 @@ Args: { assert action.name == "knowledge_query" # Test response with extra whitespace - mock_flow_context("prompt-request").agent_react.return_value = """ + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text=""" -Thought: I need to think about this -Action: knowledge_query +Thought: I need to think about this +Action: knowledge_query Args: { "question": "test" } """ + ) action = await agent_manager.reason("test", [], mock_flow_context) assert isinstance(action, Action) @@ -560,7 +594,9 @@ Args: { async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context): """Test handling of multi-line thoughts and final answers""" # Multi-line thought - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to consider multiple factors: + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to consider multiple factors: 1. The user's question is complex 2. I should search for comprehensive information 3. This requires using the knowledge query tool @@ -568,6 +604,7 @@ Action: knowledge_query Args: { "question": "complex query" }""" + ) action = await agent_manager.reason("test", [], mock_flow_context) assert isinstance(action, Action) @@ -575,13 +612,16 @@ Args: { assert "knowledge query tool" in action.thought # Multi-line final answer - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have gathered enough information + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I have gathered enough information Final Answer: Here is a comprehensive answer: 1. First point about the topic 2. Second point with details 3. 
Final conclusion This covers all aspects of the question.""" + ) action = await agent_manager.reason("test", [], mock_flow_context) assert isinstance(action, Final) @@ -593,13 +633,16 @@ This covers all aspects of the question.""" async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context): """Test JSON arguments with special characters and edge cases""" # Test with special characters in JSON (properly escaped) - mock_flow_context("prompt-request").agent_react.return_value = """Thought: Processing special characters + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: Processing special characters Action: knowledge_query Args: { "question": "What about \\"quotes\\" and 'apostrophes'?", "context": "Line 1\\nLine 2\\tTabbed", "special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?" }""" + ) action = await agent_manager.reason("test", [], mock_flow_context) assert isinstance(action, Action) @@ -608,7 +651,9 @@ Args: { assert "@#$%^&*" in action.arguments["special"] # Test with nested JSON - mock_flow_context("prompt-request").agent_react.return_value = """Thought: Complex arguments + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: Complex arguments Action: web_search Args: { "query": "test", @@ -621,6 +666,7 @@ Args: { } } }""" + ) action = await agent_manager.reason("test", [], mock_flow_context) assert isinstance(action, Action) @@ -632,7 +678,9 @@ Args: { async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context): """Test final answers that contain JSON-like content""" # Final answer with JSON content - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide the data in JSON format + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I can provide the data in 
JSON format Final Answer: { "result": "success", "data": { @@ -642,6 +690,7 @@ Final Answer: { }, "confidence": 0.95 }""" + ) action = await agent_manager.reason("test", [], mock_flow_context) assert isinstance(action, Final) @@ -792,11 +841,14 @@ Final Answer: { agent = AgentManager(tools=custom_tools, additional_context="") # Mock response for custom collection query - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to search in the research papers Action: knowledge_query_custom Args: { "question": "Latest AI research?" }""" + ) think_callback = AsyncMock() observe_callback = AsyncMock() diff --git a/tests/integration/test_agent_streaming_integration.py b/tests/integration/test_agent_streaming_integration.py index d6004c21..5c82eb8b 100644 --- a/tests/integration/test_agent_streaming_integration.py +++ b/tests/integration/test_agent_streaming_integration.py @@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock from trustgraph.agent.react.agent_manager import AgentManager from trustgraph.agent.react.tools import KnowledgeQueryImpl from trustgraph.agent.react.types import Tool, Argument +from trustgraph.base import PromptResult from tests.utils.streaming_assertions import ( assert_agent_streaming_chunks, assert_streaming_chunks_valid, @@ -51,10 +52,10 @@ Args: { is_final = (i == len(chunks) - 1) await chunk_callback(chunk, is_final) - return full_text + return PromptResult(response_type="text", text=full_text) else: # Non-streaming response - same text - return full_text + return PromptResult(response_type="text", text=full_text) client.agent_react.side_effect = agent_react_streaming return client @@ -317,8 +318,8 @@ Final Answer: AI is the simulation of human intelligence in machines.""" for i, chunk in enumerate(chunks): is_final = (i == len(chunks) - 1) await 
chunk_callback(chunk + " ", is_final) - return response - return response + return PromptResult(response_type="text", text=response) + return PromptResult(response_type="text", text=response) mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react diff --git a/tests/integration/test_agent_structured_query_integration.py b/tests/integration/test_agent_structured_query_integration.py index 0fedd2b5..2442bf10 100644 --- a/tests/integration/test_agent_structured_query_integration.py +++ b/tests/integration/test_agent_structured_query_integration.py @@ -16,6 +16,7 @@ from trustgraph.schema import ( Error ) from trustgraph.agent.react.service import Processor +from trustgraph.base import PromptResult @pytest.mark.integration @@ -95,11 +96,14 @@ class TestAgentStructuredQueryIntegration: # Mock the prompt client that agent calls for reasoning mock_prompt_client = AsyncMock() - mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from New York using structured query + mock_prompt_client.agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to find customers from New York using structured query Action: structured-query Args: { "question": "Find all customers from New York" }""" + ) # Set up flow context routing def flow_context(service_name): @@ -173,11 +177,14 @@ Args: { # Mock the prompt client that agent calls for reasoning mock_prompt_client = AsyncMock() - mock_prompt_client.agent_react.return_value = """Thought: I need to query for a table that might not exist + mock_prompt_client.agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to query for a table that might not exist Action: structured-query Args: { "question": "Find data from a table that doesn't exist" }""" + ) # Set up flow context routing def flow_context(service_name): @@ -250,11 +257,14 @@ Args: { # Mock the prompt client that agent calls for reasoning mock_prompt_client = 
AsyncMock() - mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from California first + mock_prompt_client.agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to find customers from California first Action: structured-query Args: { "question": "Find all customers from California" }""" + ) # Set up flow context routing def flow_context(service_name): @@ -339,11 +349,14 @@ Args: { # Mock the prompt client that agent calls for reasoning mock_prompt_client = AsyncMock() - mock_prompt_client.agent_react.return_value = """Thought: I need to query the sales data + mock_prompt_client.agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to query the sales data Action: structured-query Args: { "question": "Query the sales data for recent transactions" }""" + ) # Set up flow context routing def flow_context(service_name): @@ -447,11 +460,14 @@ Args: { # Mock the prompt client that agent calls for reasoning mock_prompt_client = AsyncMock() - mock_prompt_client.agent_react.return_value = """Thought: I need to get customer information + mock_prompt_client.agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to get customer information Action: structured-query Args: { "question": "Get customer information and format it nicely" }""" + ) # Set up flow context routing def flow_context(service_name): diff --git a/tests/integration/test_document_rag_integration.py b/tests/integration/test_document_rag_integration.py index e9df05cf..8c165385 100644 --- a/tests/integration/test_document_rag_integration.py +++ b/tests/integration/test_document_rag_integration.py @@ -10,6 +10,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag from trustgraph.schema import ChunkMatch +from trustgraph.base import PromptResult # Sample chunk content for testing - maps chunk_id 
to content @@ -61,11 +62,16 @@ class TestDocumentRagIntegration: def mock_prompt_client(self): """Mock prompt client that generates realistic responses""" client = AsyncMock() - client.document_prompt.return_value = ( - "Machine learning is a field of artificial intelligence that enables computers to learn " - "and improve from experience without being explicitly programmed. It uses algorithms " - "to find patterns in data and make predictions or decisions." + client.document_prompt.return_value = PromptResult( + response_type="text", + text=( + "Machine learning is a field of artificial intelligence that enables computers to learn " + "and improve from experience without being explicitly programmed. It uses algorithms " + "to find patterns in data and make predictions or decisions." + ) ) + # Mock prompt() for extract-concepts call in DocumentRag + client.prompt.return_value = PromptResult(response_type="text", text="") return client @pytest.fixture @@ -119,6 +125,7 @@ class TestDocumentRagIntegration: ) # Verify final response + result, usage = result assert result is not None assert isinstance(result, str) assert "machine learning" in result.lower() @@ -131,7 +138,11 @@ class TestDocumentRagIntegration: """Test DocumentRAG behavior when no documents are retrieved""" # Arrange mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found - mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query." + mock_prompt_client.document_prompt.return_value = PromptResult( + response_type="text", + text="I couldn't find any relevant documents for your query." + ) + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="") document_rag = DocumentRag( embeddings_client=mock_embeddings_client, @@ -152,7 +163,8 @@ class TestDocumentRagIntegration: documents=[] ) - assert result == "I couldn't find any relevant documents for your query." 
+ result_text, usage = result + assert result_text == "I couldn't find any relevant documents for your query." @pytest.mark.asyncio async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client, diff --git a/tests/integration/test_document_rag_streaming_integration.py b/tests/integration/test_document_rag_streaming_integration.py index dad30a8f..e2c032ad 100644 --- a/tests/integration/test_document_rag_streaming_integration.py +++ b/tests/integration/test_document_rag_streaming_integration.py @@ -9,6 +9,7 @@ import pytest from unittest.mock import AsyncMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag from trustgraph.schema import ChunkMatch +from trustgraph.base import PromptResult from tests.utils.streaming_assertions import ( assert_streaming_chunks_valid, assert_callback_invoked, @@ -74,12 +75,14 @@ class TestDocumentRagStreaming: is_final = (i == len(chunks) - 1) await chunk_callback(chunk, is_final) - return full_text + return PromptResult(response_type="text", text=full_text) else: # Non-streaming response - same text - return full_text + return PromptResult(response_type="text", text=full_text) client.document_prompt.side_effect = document_prompt_side_effect + # Mock prompt() for extract-concepts call in DocumentRag + client.prompt.return_value = PromptResult(response_type="text", text="") return client @pytest.fixture @@ -119,11 +122,12 @@ class TestDocumentRagStreaming: collector.verify_streaming_protocol() # Verify full response matches concatenated chunks + result_text, usage = result full_from_chunks = collector.get_full_text() - assert result == full_from_chunks + assert result_text == full_from_chunks # Verify content is reasonable - assert len(result) > 0 + assert len(result_text) > 0 @pytest.mark.asyncio async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming): @@ -159,9 +163,11 @@ class TestDocumentRagStreaming: ) # Assert - Results should be equivalent - assert 
streaming_result == non_streaming_result + non_streaming_text, _ = non_streaming_result + streaming_text, _ = streaming_result + assert streaming_text == non_streaming_text assert len(streaming_chunks) > 0 - assert "".join(streaming_chunks) == streaming_result + assert "".join(streaming_chunks) == streaming_text @pytest.mark.asyncio async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming): @@ -180,8 +186,9 @@ class TestDocumentRagStreaming: ) # Assert + result_text, usage = result assert callback.call_count > 0 - assert result is not None + assert result_text is not None # Verify all callback invocations had string arguments for call in callback.call_args_list: @@ -202,7 +209,8 @@ class TestDocumentRagStreaming: # Assert - Should complete without error assert result is not None - assert isinstance(result, str) + result_text, usage = result + assert isinstance(result_text, str) @pytest.mark.asyncio async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming, @@ -223,7 +231,8 @@ class TestDocumentRagStreaming: ) # Assert - Should still produce streamed response - assert result is not None + result_text, usage = result + assert result_text is not None assert callback.call_count > 0 @pytest.mark.asyncio @@ -271,7 +280,8 @@ class TestDocumentRagStreaming: ) # Assert - assert result is not None + result_text, usage = result + assert result_text is not None assert callback.call_count > 0 # Verify doc_limit was passed correctly diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index 5e3279e3..9c3cdf45 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -12,6 +12,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.schema import EntityMatch, Term, IRI +from trustgraph.base import PromptResult 
@pytest.mark.integration @@ -93,18 +94,21 @@ class TestGraphRagIntegration: # 4. kg-synthesis returns the final answer async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): if prompt_name == "extract-concepts": - return "" # Falls back to raw query + return PromptResult(response_type="text", text="") elif prompt_name == "kg-edge-scoring": - return "" # No edges scored + return PromptResult(response_type="text", text="") elif prompt_name == "kg-edge-reasoning": - return "" # No reasoning + return PromptResult(response_type="text", text="") elif prompt_name == "kg-synthesis": - return ( - "Machine learning is a subset of artificial intelligence that enables computers " - "to learn from data without being explicitly programmed. It uses algorithms " - "and statistical models to find patterns in data." + return PromptResult( + response_type="text", + text=( + "Machine learning is a subset of artificial intelligence that enables computers " + "to learn from data without being explicitly programmed. It uses algorithms " + "and statistical models to find patterns in data." 
+ ) ) - return "" + return PromptResult(response_type="text", text="") client.prompt.side_effect = mock_prompt return client @@ -169,6 +173,7 @@ class TestGraphRagIntegration: assert mock_prompt_client.prompt.call_count == 4 # Verify final response + response, usage = response assert response is not None assert isinstance(response, str) assert "machine learning" in response.lower() diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index b66c5289..95c494bb 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -9,6 +9,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.schema import EntityMatch, Term, IRI +from trustgraph.base import PromptResult from tests.utils.streaming_assertions import ( assert_streaming_chunks_valid, assert_rag_streaming_chunks, @@ -61,12 +62,12 @@ class TestGraphRagStreaming: async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs): if prompt_id == "extract-concepts": - return "" # Falls back to raw query + return PromptResult(response_type="text", text="") elif prompt_id == "kg-edge-scoring": # Edge scoring returns JSONL with IDs and scores - return '{"id": "abc12345", "score": 0.9}\n' + return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n') elif prompt_id == "kg-edge-reasoning": - return '{"id": "abc12345", "reasoning": "Relevant to query"}\n' + return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n') elif prompt_id == "kg-synthesis": if streaming and chunk_callback: # Simulate streaming chunks with end_of_stream flags @@ -79,10 +80,10 @@ class TestGraphRagStreaming: is_final = (i == len(chunks) - 1) await chunk_callback(chunk, is_final) - return full_text + return 
PromptResult(response_type="text", text=full_text) else: - return full_text - return "" + return PromptResult(response_type="text", text=full_text) + return PromptResult(response_type="text", text="") client.prompt.side_effect = prompt_side_effect return client @@ -123,6 +124,7 @@ class TestGraphRagStreaming: ) # Assert + response, usage = response assert_streaming_chunks_valid(collector.chunks, min_chunks=1) assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1) @@ -172,9 +174,11 @@ class TestGraphRagStreaming: ) # Assert - Results should be equivalent - assert streaming_response == non_streaming_response + non_streaming_text, _ = non_streaming_response + streaming_text, _ = streaming_response + assert streaming_text == non_streaming_text assert len(streaming_chunks) > 0 - assert "".join(streaming_chunks) == streaming_response + assert "".join(streaming_chunks) == streaming_text @pytest.mark.asyncio async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming): @@ -213,7 +217,8 @@ class TestGraphRagStreaming: # Assert - Should complete without error assert response is not None - assert isinstance(response, str) + response_text, usage = response + assert isinstance(response_text, str) @pytest.mark.asyncio async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming, diff --git a/tests/integration/test_kg_extract_store_integration.py b/tests/integration/test_kg_extract_store_integration.py index 4d8b60ad..84c0905d 100644 --- a/tests/integration/test_kg_extract_store_integration.py +++ b/tests/integration/test_kg_extract_store_integration.py @@ -18,6 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProces from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL +from 
trustgraph.base import PromptResult @pytest.mark.integration @@ -31,32 +32,38 @@ class TestKnowledgeGraphPipelineIntegration: # Mock prompt client for definitions extraction prompt_client = AsyncMock() - prompt_client.extract_definitions.return_value = [ - { - "entity": "Machine Learning", - "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming." - }, - { - "entity": "Neural Networks", - "definition": "Computing systems inspired by biological neural networks that process information." - } - ] - + prompt_client.extract_definitions.return_value = PromptResult( + response_type="jsonl", + objects=[ + { + "entity": "Machine Learning", + "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming." + }, + { + "entity": "Neural Networks", + "definition": "Computing systems inspired by biological neural networks that process information." + } + ] + ) + # Mock prompt client for relationships extraction - prompt_client.extract_relationships.return_value = [ - { - "subject": "Machine Learning", - "predicate": "is_subset_of", - "object": "Artificial Intelligence", - "object-entity": True - }, - { - "subject": "Neural Networks", - "predicate": "is_used_in", - "object": "Machine Learning", - "object-entity": True - } - ] + prompt_client.extract_relationships.return_value = PromptResult( + response_type="jsonl", + objects=[ + { + "subject": "Machine Learning", + "predicate": "is_subset_of", + "object": "Artificial Intelligence", + "object-entity": True + }, + { + "subject": "Neural Networks", + "predicate": "is_used_in", + "object": "Machine Learning", + "object-entity": True + } + ] + ) # Mock producers for output streams triples_producer = AsyncMock() @@ -489,7 +496,10 @@ class TestKnowledgeGraphPipelineIntegration: async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk): """Test handling of 
empty extraction results""" # Arrange - mock_flow_context("prompt-request").extract_definitions.return_value = [] + mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult( + response_type="jsonl", + objects=[] + ) mock_msg = MagicMock() mock_msg.value.return_value = sample_chunk @@ -510,7 +520,10 @@ class TestKnowledgeGraphPipelineIntegration: async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk): """Test handling of invalid extraction response format""" # Arrange - mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format" # Should be list + mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult( + response_type="text", + text="invalid format" + ) # Should be jsonl with objects list mock_msg = MagicMock() mock_msg.value.return_value = sample_chunk @@ -528,13 +541,16 @@ class TestKnowledgeGraphPipelineIntegration: async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context): """Test entity filtering and validation in extraction""" # Arrange - mock_flow_context("prompt-request").extract_definitions.return_value = [ - {"entity": "Valid Entity", "definition": "Valid definition"}, - {"entity": "", "definition": "Empty entity"}, # Should be filtered - {"entity": "Valid Entity 2", "definition": ""}, # Should be filtered - {"entity": None, "definition": "None entity"}, # Should be filtered - {"entity": "Valid Entity 3", "definition": None}, # Should be filtered - ] + mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult( + response_type="jsonl", + objects=[ + {"entity": "Valid Entity", "definition": "Valid definition"}, + {"entity": "", "definition": "Empty entity"}, # Should be filtered + {"entity": "Valid Entity 2", "definition": ""}, # Should be filtered + {"entity": None, "definition": "None entity"}, # Should be filtered + {"entity": "Valid Entity 3", 
"definition": None}, # Should be filtered + ] + ) sample_chunk = Chunk( metadata=Metadata(id="test", user="user", collection="collection"), diff --git a/tests/integration/test_object_extraction_integration.py b/tests/integration/test_object_extraction_integration.py index faa63381..22ba9a3f 100644 --- a/tests/integration/test_object_extraction_integration.py +++ b/tests/integration/test_object_extraction_integration.py @@ -16,6 +16,7 @@ from trustgraph.schema import ( Chunk, ExtractedObject, Metadata, RowSchema, Field, PromptRequest, PromptResponse ) +from trustgraph.base import PromptResult @pytest.mark.integration @@ -114,49 +115,61 @@ class TestObjectExtractionServiceIntegration: schema_name = schema.get("name") if isinstance(schema, dict) else schema.name if schema_name == "customer_records": if "john" in text.lower(): - return [ - { - "customer_id": "CUST001", - "name": "John Smith", - "email": "john.smith@email.com", - "phone": "555-0123" - } - ] + return PromptResult( + response_type="jsonl", + objects=[ + { + "customer_id": "CUST001", + "name": "John Smith", + "email": "john.smith@email.com", + "phone": "555-0123" + } + ] + ) elif "jane" in text.lower(): - return [ - { - "customer_id": "CUST002", - "name": "Jane Doe", - "email": "jane.doe@email.com", - "phone": "" - } - ] + return PromptResult( + response_type="jsonl", + objects=[ + { + "customer_id": "CUST002", + "name": "Jane Doe", + "email": "jane.doe@email.com", + "phone": "" + } + ] + ) else: - return [] - + return PromptResult(response_type="jsonl", objects=[]) + elif schema_name == "product_catalog": if "laptop" in text.lower(): - return [ - { - "product_id": "PROD001", - "name": "Gaming Laptop", - "price": "1299.99", - "category": "electronics" - } - ] + return PromptResult( + response_type="jsonl", + objects=[ + { + "product_id": "PROD001", + "name": "Gaming Laptop", + "price": "1299.99", + "category": "electronics" + } + ] + ) elif "book" in text.lower(): - return [ - { - "product_id": "PROD002", 
- "name": "Python Programming Guide", - "price": "49.99", - "category": "books" - } - ] + return PromptResult( + response_type="jsonl", + objects=[ + { + "product_id": "PROD002", + "name": "Python Programming Guide", + "price": "49.99", + "category": "books" + } + ] + ) else: - return [] - - return [] + return PromptResult(response_type="jsonl", objects=[]) + + return PromptResult(response_type="jsonl", objects=[]) prompt_client.extract_objects.side_effect = mock_extract_objects diff --git a/tests/integration/test_prompt_streaming_integration.py b/tests/integration/test_prompt_streaming_integration.py index 9b1a06b6..a1414e2d 100644 --- a/tests/integration/test_prompt_streaming_integration.py +++ b/tests/integration/test_prompt_streaming_integration.py @@ -9,6 +9,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.prompt.template.service import Processor from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse +from trustgraph.base.text_completion_client import TextCompletionResult from tests.utils.streaming_assertions import ( assert_streaming_chunks_valid, assert_callback_invoked, @@ -27,34 +28,52 @@ class TestPromptStreaming: # Mock text completion client with streaming text_completion_client = AsyncMock() - async def streaming_request(request, recipient=None, timeout=600): - """Simulate streaming text completion""" - if request.streaming and recipient: - # Simulate streaming chunks - chunks = [ - "Machine", " learning", " is", " a", " field", - " of", " artificial", " intelligence", "." - ] + # Streaming chunks to send + chunks = [ + "Machine", " learning", " is", " a", " field", + " of", " artificial", " intelligence", "." 
+ ] - for i, chunk_text in enumerate(chunks): - is_final = (i == len(chunks) - 1) - response = TextCompletionResponse( - response=chunk_text, - error=None, - end_of_stream=is_final - ) - final = await recipient(response) - if final: - break - - # Final empty chunk - await recipient(TextCompletionResponse( - response="", + async def streaming_text_completion_stream(system, prompt, handler, timeout=600): + """Simulate streaming text completion via text_completion_stream""" + for i, chunk_text in enumerate(chunks): + response = TextCompletionResponse( + response=chunk_text, error=None, - end_of_stream=True - )) + end_of_stream=False + ) + await handler(response) - text_completion_client.request = streaming_request + # Send final empty chunk with end_of_stream + await handler(TextCompletionResponse( + response="", + error=None, + end_of_stream=True + )) + + return TextCompletionResult( + text=None, + in_token=10, + out_token=9, + model="test-model", + ) + + async def non_streaming_text_completion(system, prompt, timeout=600): + """Simulate non-streaming text completion""" + full_text = "Machine learning is a field of artificial intelligence." + return TextCompletionResult( + text=full_text, + in_token=10, + out_token=9, + model="test-model", + ) + + text_completion_client.text_completion_stream = AsyncMock( + side_effect=streaming_text_completion_stream + ) + text_completion_client.text_completion = AsyncMock( + side_effect=non_streaming_text_completion + ) # Mock response producer response_producer = AsyncMock() @@ -156,14 +175,6 @@ class TestPromptStreaming: consumer = MagicMock() - # Mock non-streaming text completion - text_completion_client = mock_flow_context_streaming("text-completion-request") - - async def non_streaming_text_completion(system, prompt, streaming=False): - return "AI is the simulation of human intelligence in machines." 
- - text_completion_client.text_completion = non_streaming_text_completion - # Act await prompt_processor_streaming.on_request( message, consumer, mock_flow_context_streaming @@ -218,17 +229,12 @@ class TestPromptStreaming: # Mock text completion client that raises an error text_completion_client = AsyncMock() - async def failing_request(request, recipient=None, timeout=600): - if recipient: - # Send error response with proper Error schema - error_response = TextCompletionResponse( - response="", - error=Error(message="Text completion error", type="processing_error"), - end_of_stream=True - ) - await recipient(error_response) + async def failing_stream(system, prompt, handler, timeout=600): + raise RuntimeError("Text completion error") - text_completion_client.request = failing_request + text_completion_client.text_completion_stream = AsyncMock( + side_effect=failing_stream + ) # Mock response producer to capture error response response_producer = AsyncMock() @@ -255,22 +261,15 @@ class TestPromptStreaming: consumer = MagicMock() - # Act - The service catches errors and sends error responses, doesn't raise + # Act - The service catches errors and sends an error PromptResponse await prompt_processor_streaming.on_request(message, consumer, context) - # Assert - Verify error response was sent - assert response_producer.send.call_count > 0 - - # Check that at least one response contains an error - error_sent = False - for call in response_producer.send.call_args_list: - response = call.args[0] - if hasattr(response, 'error') and response.error: - error_sent = True - assert "Text completion error" in response.error.message - break - - assert error_sent, "Expected error response to be sent" + # Assert - error response was sent + calls = response_producer.send.call_args_list + assert len(calls) > 0 + error_response = calls[-1].args[0] + assert error_response.error is not None + assert "Text completion error" in error_response.error.message @pytest.mark.asyncio async def 
test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming, @@ -315,21 +314,22 @@ class TestPromptStreaming: # Mock text completion that sends empty chunks text_completion_client = AsyncMock() - async def empty_streaming_request(request, recipient=None, timeout=600): - if request.streaming and recipient: - # Send empty chunk followed by final marker - await recipient(TextCompletionResponse( - response="", - error=None, - end_of_stream=False - )) - await recipient(TextCompletionResponse( - response="", - error=None, - end_of_stream=True - )) + async def empty_streaming(system, prompt, handler, timeout=600): + # Send empty chunk followed by final marker + await handler(TextCompletionResponse( + response="", + error=None, + end_of_stream=False + )) + await handler(TextCompletionResponse( + response="", + error=None, + end_of_stream=True + )) - text_completion_client.request = empty_streaming_request + text_completion_client.text_completion_stream = AsyncMock( + side_effect=empty_streaming + ) response_producer = AsyncMock() def context_router(service_name): @@ -401,4 +401,4 @@ class TestPromptStreaming: # Verify chunks concatenate to expected result full_text = "".join(chunk_texts) - assert full_text == "Machine learning is a field of artificial intelligence" + assert full_text == "Machine learning is a field of artificial intelligence." 
diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py index f5fe14b5..83a90412 100644 --- a/tests/integration/test_rag_streaming_protocol.py +++ b/tests/integration/test_rag_streaming_protocol.py @@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, call from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.retrieval.document_rag.document_rag import DocumentRag from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI +from trustgraph.base import PromptResult class TestGraphRagStreamingProtocol: @@ -46,8 +47,7 @@ class TestGraphRagStreamingProtocol: async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None): if prompt_name == "kg-edge-selection": - # Edge selection returns empty (no edges selected) - return "" + return PromptResult(response_type="text", text="") elif prompt_name == "kg-synthesis": if streaming and chunk_callback: # Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True @@ -55,10 +55,10 @@ class TestGraphRagStreamingProtocol: await chunk_callback(" answer", False) await chunk_callback(" is here.", False) await chunk_callback("", True) # Empty final chunk with end_of_stream=True - return "" # Return value not used since callback handles everything + return PromptResult(response_type="text", text="") else: - return "The answer is here." - return "" + return PromptResult(response_type="text", text="The answer is here.") + return PromptResult(response_type="text", text="") client.prompt.side_effect = prompt_side_effect return client @@ -237,11 +237,13 @@ class TestDocumentRagStreamingProtocol: await chunk_callback("Document", False) await chunk_callback(" summary", False) await chunk_callback(".", True) # Non-empty final chunk - return "" + return PromptResult(response_type="text", text="") else: - return "Document summary." 
+ return PromptResult(response_type="text", text="Document summary.") client.document_prompt.side_effect = document_prompt_side_effect + # Mock prompt() for extract-concepts call in DocumentRag + client.prompt.return_value = PromptResult(response_type="text", text="") return client @pytest.fixture @@ -334,17 +336,17 @@ class TestStreamingProtocolEdgeCases: async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None): if prompt_name == "kg-edge-selection": - return "" + return PromptResult(response_type="text", text="") elif prompt_name == "kg-synthesis": if streaming and chunk_callback: await chunk_callback("text", False) await chunk_callback("", False) # Empty but not final await chunk_callback("more", False) await chunk_callback("", True) # Empty and final - return "" + return PromptResult(response_type="text", text="") else: - return "textmore" - return "" + return PromptResult(response_type="text", text="textmore") + return PromptResult(response_type="text", text="") client.prompt.side_effect = prompt_with_empties diff --git a/tests/unit/test_agent/test_meta_router.py b/tests/unit/test_agent/test_meta_router.py index da0c634c..da8c6c79 100644 --- a/tests/unit/test_agent/test_meta_router.py +++ b/tests/unit/test_agent/test_meta_router.py @@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock from trustgraph.agent.orchestrator.meta_router import ( MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE, ) +from trustgraph.base import PromptResult def _make_config(patterns=None, task_types=None): @@ -28,7 +29,9 @@ def _make_config(patterns=None, task_types=None): def _make_context(prompt_response): """Build a mock context that returns a mock prompt client.""" client = AsyncMock() - client.prompt = AsyncMock(return_value=prompt_response) + client.prompt = AsyncMock( + return_value=PromptResult(response_type="text", text=prompt_response) + ) def context(service_name): return client @@ -274,8 +277,8 @@ class TestRoute: nonlocal 
call_count call_count += 1 if call_count == 1: - return "research" # task type - return "plan-then-execute" # pattern + return PromptResult(response_type="text", text="research") + return PromptResult(response_type="text", text="plan-then-execute") client.prompt = mock_prompt context = lambda name: client diff --git a/tests/unit/test_agent/test_orchestrator_provenance_integration.py b/tests/unit/test_agent/test_orchestrator_provenance_integration.py index 96d41259..05741cdc 100644 --- a/tests/unit/test_agent/test_orchestrator_provenance_integration.py +++ b/tests/unit/test_agent/test_orchestrator_provenance_integration.py @@ -18,6 +18,7 @@ from dataclasses import dataclass, field from trustgraph.schema import ( AgentRequest, AgentResponse, AgentStep, PlanStep, ) +from trustgraph.base import PromptResult from trustgraph.provenance.namespaces import ( RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM, @@ -183,7 +184,7 @@ class TestReactPatternProvenance: ) async def mock_react(question, history, think, observe, answer, - context, streaming, on_action): + context, streaming, on_action, **kwargs): # Simulate the on_action callback before returning Final if on_action: await on_action(Action( @@ -267,7 +268,7 @@ class TestReactPatternProvenance: MockAM.return_value = mock_am async def mock_react(question, history, think, observe, answer, - context, streaming, on_action): + context, streaming, on_action, **kwargs): if on_action: await on_action(action) return action @@ -309,7 +310,7 @@ class TestReactPatternProvenance: MockAM.return_value = mock_am async def mock_react(question, history, think, observe, answer, - context, streaming, on_action): + context, streaming, on_action, **kwargs): if on_action: await on_action(Action( thought="done", name="final", @@ -355,10 +356,13 @@ class TestPlanPatternProvenance: # Mock prompt client for plan creation mock_prompt_client = AsyncMock() - mock_prompt_client.prompt.return_value = [ - {"goal": "Find information", "tool_hint": 
"knowledge-query", "depends_on": []}, - {"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]}, - ] + mock_prompt_client.prompt.return_value = PromptResult( + response_type="jsonl", + objects=[ + {"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []}, + {"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]}, + ], + ) def flow_factory(name): if name == "prompt-request": @@ -418,10 +422,13 @@ class TestPlanPatternProvenance: # Mock prompt for step execution mock_prompt_client = AsyncMock() - mock_prompt_client.prompt.return_value = { - "tool": "knowledge-query", - "arguments": {"question": "quantum computing"}, - } + mock_prompt_client.prompt.return_value = PromptResult( + response_type="json", + object={ + "tool": "knowledge-query", + "arguments": {"question": "quantum computing"}, + }, + ) def flow_factory(name): if name == "prompt-request": @@ -475,7 +482,7 @@ class TestPlanPatternProvenance: # Mock prompt for synthesis mock_prompt_client = AsyncMock() - mock_prompt_client.prompt.return_value = "The synthesised answer." + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The synthesised answer.") def flow_factory(name): if name == "prompt-request": @@ -542,10 +549,13 @@ class TestSupervisorPatternProvenance: # Mock prompt for decomposition mock_prompt_client = AsyncMock() - mock_prompt_client.prompt.return_value = [ - "What is quantum computing?", - "What are qubits?", - ] + mock_prompt_client.prompt.return_value = PromptResult( + response_type="jsonl", + objects=[ + "What is quantum computing?", + "What are qubits?", + ], + ) def flow_factory(name): if name == "prompt-request": @@ -590,7 +600,7 @@ class TestSupervisorPatternProvenance: # Mock prompt for synthesis mock_prompt_client = AsyncMock() - mock_prompt_client.prompt.return_value = "The combined answer." 
+ mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The combined answer.") def flow_factory(name): if name == "prompt-request": @@ -639,7 +649,10 @@ class TestSupervisorPatternProvenance: flow = make_mock_flow() mock_prompt_client = AsyncMock() - mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"] + mock_prompt_client.prompt.return_value = PromptResult( + response_type="jsonl", + objects=["Goal A", "Goal B", "Goal C"], + ) def flow_factory(name): if name == "prompt-request": diff --git a/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py index b651b59e..cbc9a05a 100644 --- a/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py +++ b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py @@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock from trustgraph.extract.kg.definitions.extract import ( Processor, default_triples_batch_size, default_entity_batch_size, ) +from trustgraph.base import PromptResult from trustgraph.schema import ( Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL, ) @@ -51,8 +52,12 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"): mock_triples_pub = AsyncMock() mock_ecs_pub = AsyncMock() mock_prompt_client = AsyncMock() + if isinstance(prompt_result, list): + wrapped = PromptResult(response_type="jsonl", objects=prompt_result) + else: + wrapped = PromptResult(response_type="text", text=prompt_result) mock_prompt_client.extract_definitions = AsyncMock( - return_value=prompt_result + return_value=wrapped ) def flow(name): diff --git a/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py index cf3b1fb0..d9861cf3 100644 --- 
a/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py +++ b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py @@ -14,6 +14,7 @@ from trustgraph.extract.kg.relationships.extract import ( from trustgraph.schema import ( Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL, ) +from trustgraph.base import PromptResult # --------------------------------------------------------------------------- @@ -58,7 +59,10 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"): mock_triples_pub = AsyncMock() mock_prompt_client = AsyncMock() mock_prompt_client.extract_relationships = AsyncMock( - return_value=prompt_result + return_value=PromptResult( + response_type="jsonl", + objects=prompt_result, + ) ) def flow(name): diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py index 27508ba4..1ff85f5a 100644 --- a/tests/unit/test_retrieval/test_document_rag.py +++ b/tests/unit/test_retrieval/test_document_rag.py @@ -6,6 +6,7 @@ import pytest from unittest.mock import MagicMock, AsyncMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query +from trustgraph.base import PromptResult # Sample chunk content mapping for tests @@ -132,7 +133,7 @@ class TestQuery: mock_rag.prompt_client = mock_prompt_client # Mock the prompt response with concept lines - mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence\ndata patterns") query = Query( rag=mock_rag, @@ -157,7 +158,7 @@ class TestQuery: mock_rag.prompt_client = mock_prompt_client # Mock empty response - mock_prompt_client.prompt.return_value = "" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="") query = Query( rag=mock_rag, @@ -258,7 +259,7 @@ class TestQuery: 
mock_doc_embeddings_client = AsyncMock() # Mock concept extraction - mock_prompt_client.prompt.return_value = "test concept" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="test concept") # Mock embeddings - one vector per concept test_vectors = [[0.1, 0.2, 0.3]] @@ -273,7 +274,7 @@ class TestQuery: expected_response = "This is the document RAG response" mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2] - mock_prompt_client.document_prompt.return_value = expected_response + mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=expected_response) document_rag = DocumentRag( prompt_client=mock_prompt_client, @@ -315,7 +316,8 @@ class TestQuery: assert "Relevant document content" in docs assert "Another document" in docs - assert result == expected_response + result_text, usage = result + assert result_text == expected_response @pytest.mark.asyncio async def test_document_rag_query_with_defaults(self, mock_fetch_chunk): @@ -325,7 +327,7 @@ class TestQuery: mock_doc_embeddings_client = AsyncMock() # Mock concept extraction fallback (empty → raw query) - mock_prompt_client.prompt.return_value = "" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="") # Mock responses mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] @@ -333,7 +335,7 @@ class TestQuery: mock_match.chunk_id = "doc/c5" mock_match.score = 0.9 mock_doc_embeddings_client.query.return_value = [mock_match] - mock_prompt_client.document_prompt.return_value = "Default response" + mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Default response") document_rag = DocumentRag( prompt_client=mock_prompt_client, @@ -352,7 +354,8 @@ class TestQuery: collection="default" # Default collection ) - assert result == "Default response" + result_text, usage = result + assert result_text == "Default response" @pytest.mark.asyncio async def 
test_get_docs_with_verbose_output(self): @@ -401,7 +404,7 @@ class TestQuery: mock_doc_embeddings_client = AsyncMock() # Mock concept extraction - mock_prompt_client.prompt.return_value = "verbose query test" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="verbose query test") # Mock responses mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]] @@ -409,7 +412,7 @@ class TestQuery: mock_match.chunk_id = "doc/c7" mock_match.score = 0.92 mock_doc_embeddings_client.query.return_value = [mock_match] - mock_prompt_client.document_prompt.return_value = "Verbose RAG response" + mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Verbose RAG response") document_rag = DocumentRag( prompt_client=mock_prompt_client, @@ -428,7 +431,8 @@ class TestQuery: assert call_args.kwargs["query"] == "verbose query test" assert "Verbose doc content" in call_args.kwargs["documents"] - assert result == "Verbose RAG response" + result_text, usage = result + assert result_text == "Verbose RAG response" @pytest.mark.asyncio async def test_get_docs_with_empty_results(self): @@ -469,11 +473,11 @@ class TestQuery: mock_doc_embeddings_client = AsyncMock() # Mock concept extraction - mock_prompt_client.prompt.return_value = "query with no matching docs" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="query with no matching docs") mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]] mock_doc_embeddings_client.query.return_value = [] - mock_prompt_client.document_prompt.return_value = "No documents found response" + mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="No documents found response") document_rag = DocumentRag( prompt_client=mock_prompt_client, @@ -490,7 +494,8 @@ class TestQuery: documents=[] ) - assert result == "No documents found response" + result_text, usage = result + assert result_text == "No documents found response" 
@pytest.mark.asyncio async def test_get_vectors_with_verbose(self): @@ -525,7 +530,7 @@ class TestQuery: final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed." # Mock concept extraction - mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence") # Mock embeddings - one vector per concept query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]] @@ -541,7 +546,7 @@ class TestQuery: MagicMock(chunk_id="doc/ml3", score=0.82), ] mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2] - mock_prompt_client.document_prompt.return_value = final_response + mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=final_response) document_rag = DocumentRag( prompt_client=mock_prompt_client, @@ -584,7 +589,8 @@ class TestQuery: assert "Common ML techniques include supervised and unsupervised learning..." 
in docs assert len(docs) == 3 # doc/ml2 deduplicated - assert result == final_response + result_text, usage = result + assert result_text == final_response @pytest.mark.asyncio async def test_get_docs_deduplicates_across_concepts(self): diff --git a/tests/unit/test_retrieval/test_document_rag_provenance_integration.py b/tests/unit/test_retrieval/test_document_rag_provenance_integration.py index 74157285..8fa10642 100644 --- a/tests/unit/test_retrieval/test_document_rag_provenance_integration.py +++ b/tests/unit/test_retrieval/test_document_rag_provenance_integration.py @@ -12,6 +12,7 @@ from unittest.mock import AsyncMock from dataclasses import dataclass from trustgraph.retrieval.document_rag.document_rag import DocumentRag +from trustgraph.base import PromptResult from trustgraph.provenance.namespaces import ( RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM, @@ -89,8 +90,8 @@ def build_mock_clients(): # 1. Concept extraction async def mock_prompt(template_id, variables=None, **kwargs): if template_id == "extract-concepts": - return "return policy\nrefund" - return "" + return PromptResult(response_type="text", text="return policy\nrefund") + return PromptResult(response_type="text", text="") prompt_client.prompt.side_effect = mock_prompt @@ -113,8 +114,9 @@ def build_mock_clients(): fetch_chunk.side_effect = mock_fetch # 5. Synthesis - prompt_client.document_prompt.return_value = ( - "Items can be returned within 30 days for a full refund." 
+ prompt_client.document_prompt.return_value = PromptResult( + response_type="text", + text="Items can be returned within 30 days for a full refund.", ) return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk @@ -340,12 +342,12 @@ class TestDocumentRagQueryProvenance: clients = build_mock_clients() rag = DocumentRag(*clients) - result = await rag.query( + result_text, usage = await rag.query( query="What is the return policy?", explain_callback=AsyncMock(), ) - assert result == "Items can be returned within 30 days for a full refund." + assert result_text == "Items can be returned within 30 days for a full refund." @pytest.mark.asyncio async def test_no_explain_callback_still_works(self): @@ -353,8 +355,8 @@ class TestDocumentRagQueryProvenance: clients = build_mock_clients() rag = DocumentRag(*clients) - result = await rag.query(query="What is the return policy?") - assert result == "Items can be returned within 30 days for a full refund." + result_text, usage = await rag.query(query="What is the return policy?") + assert result_text == "Items can be returned within 30 days for a full refund." 
@pytest.mark.asyncio async def test_all_triples_in_retrieval_graph(self): diff --git a/tests/unit/test_retrieval/test_document_rag_service.py b/tests/unit/test_retrieval/test_document_rag_service.py index 05e1bb60..a5d42f3a 100644 --- a/tests/unit/test_retrieval/test_document_rag_service.py +++ b/tests/unit/test_retrieval/test_document_rag_service.py @@ -34,7 +34,7 @@ class TestDocumentRagService: # Setup mock DocumentRag instance mock_rag_instance = AsyncMock() mock_document_rag_class.return_value = mock_rag_instance - mock_rag_instance.query.return_value = "test response" + mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None}) # Setup message with custom user/collection msg = MagicMock() @@ -97,7 +97,7 @@ class TestDocumentRagService: # Setup mock DocumentRag instance mock_rag_instance = AsyncMock() mock_document_rag_class.return_value = mock_rag_instance - mock_rag_instance.query.return_value = "A document about cats." + mock_rag_instance.query.return_value = ("A document about cats.", {"in_token": None, "out_token": None, "model": None}) # Setup message with non-streaming request msg = MagicMock() @@ -130,4 +130,5 @@ class TestDocumentRagService: assert isinstance(sent_response, DocumentRagResponse) assert sent_response.response == "A document about cats." 
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True" + assert sent_response.end_of_session is True assert sent_response.error is None \ No newline at end of file diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index 00d8b72a..00a9551f 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -7,6 +7,7 @@ import unittest.mock from unittest.mock import MagicMock, AsyncMock from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query +from trustgraph.base import PromptResult class TestGraphRag: @@ -172,7 +173,7 @@ class TestQuery: mock_prompt_client = AsyncMock() mock_rag.prompt_client = mock_prompt_client - mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nneural networks\n") query = Query( rag=mock_rag, @@ -196,7 +197,7 @@ class TestQuery: mock_prompt_client = AsyncMock() mock_rag.prompt_client = mock_prompt_client - mock_prompt_client.prompt.return_value = "" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="") query = Query( rag=mock_rag, @@ -220,7 +221,7 @@ class TestQuery: mock_rag.graph_embeddings_client = mock_graph_embeddings_client # extract_concepts returns empty -> falls back to [query] - mock_prompt_client.prompt.return_value = "" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="") # embed returns one vector set for the single concept test_vectors = [[0.1, 0.2, 0.3]] @@ -565,14 +566,14 @@ class TestQuery: # Mock prompt responses for the multi-step process async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): if prompt_name == "extract-concepts": - return "" # Falls back to raw query + return PromptResult(response_type="text", text="") elif prompt_name == 
"kg-edge-scoring": - return json.dumps({"id": test_edge_id, "score": 0.9}) + return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}]) elif prompt_name == "kg-edge-reasoning": - return json.dumps({"id": test_edge_id, "reasoning": "relevant"}) + return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}]) elif prompt_name == "kg-synthesis": - return expected_response - return "" + return PromptResult(response_type="text", text=expected_response) + return PromptResult(response_type="text", text="") mock_prompt_client.prompt = mock_prompt @@ -607,7 +608,8 @@ class TestQuery: explain_callback=collect_provenance ) - assert response == expected_response + response_text, usage = response + assert response_text == expected_response # 5 events: question, grounding, exploration, focus, synthesis assert len(provenance_events) == 5 diff --git a/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py index 36536f7d..1eb0dd72 100644 --- a/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py +++ b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py @@ -13,6 +13,7 @@ from dataclasses import dataclass from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL +from trustgraph.base import PromptResult from trustgraph.provenance.namespaces import ( RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM, @@ -136,24 +137,36 @@ def build_mock_clients(): async def mock_prompt(template_id, variables=None, **kwargs): if template_id == "extract-concepts": - return prompt_responses["extract-concepts"] + return PromptResult( + response_type="text", + text=prompt_responses["extract-concepts"], + ) elif template_id == "kg-edge-scoring": # Score all edges highly, using the IDs that GraphRag computed edges = variables.get("knowledge", []) - return 
[ - {"id": e["id"], "score": 10 - i} - for i, e in enumerate(edges) - ] + return PromptResult( + response_type="jsonl", + objects=[ + {"id": e["id"], "score": 10 - i} + for i, e in enumerate(edges) + ], + ) elif template_id == "kg-edge-reasoning": # Provide reasoning for each edge edges = variables.get("knowledge", []) - return [ - {"id": e["id"], "reasoning": f"Relevant edge {i}"} - for i, e in enumerate(edges) - ] + return PromptResult( + response_type="jsonl", + objects=[ + {"id": e["id"], "reasoning": f"Relevant edge {i}"} + for i, e in enumerate(edges) + ], + ) elif template_id == "kg-synthesis": - return synthesis_answer - return "" + return PromptResult( + response_type="text", + text=synthesis_answer, + ) + return PromptResult(response_type="text", text="") prompt_client.prompt.side_effect = mock_prompt @@ -413,13 +426,13 @@ class TestGraphRagQueryProvenance: async def explain_callback(triples, explain_id): events.append({"triples": triples, "explain_id": explain_id}) - result = await rag.query( + result_text, usage = await rag.query( query="What is quantum computing?", explain_callback=explain_callback, edge_score_limit=0, ) - assert result == "Quantum computing applies physics principles to computation." + assert result_text == "Quantum computing applies physics principles to computation." @pytest.mark.asyncio async def test_parent_uri_links_question_to_parent(self): @@ -450,12 +463,12 @@ class TestGraphRagQueryProvenance: clients = build_mock_clients() rag = GraphRag(*clients) - result = await rag.query( + result_text, usage = await rag.query( query="What is quantum computing?", edge_score_limit=0, ) - assert result == "Quantum computing applies physics principles to computation." + assert result_text == "Quantum computing applies physics principles to computation." 
@pytest.mark.asyncio async def test_all_triples_in_retrieval_graph(self): diff --git a/tests/unit/test_retrieval/test_graph_rag_service.py b/tests/unit/test_retrieval/test_graph_rag_service.py index 2cd62286..606aa7fe 100644 --- a/tests/unit/test_retrieval/test_graph_rag_service.py +++ b/tests/unit/test_retrieval/test_graph_rag_service.py @@ -44,7 +44,7 @@ class TestGraphRagService: await explain_callback([], "urn:trustgraph:prov:retrieval:test") await explain_callback([], "urn:trustgraph:prov:selection:test") await explain_callback([], "urn:trustgraph:prov:answer:test") - return "A small domesticated mammal." + return "A small domesticated mammal.", {"in_token": None, "out_token": None, "model": None} mock_rag_instance.query.side_effect = mock_query @@ -79,8 +79,8 @@ class TestGraphRagService: # Execute await processor.on_request(msg, consumer, flow) - # Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session) - assert mock_response_producer.send.call_count == 6 + # Verify: 5 messages sent (4 provenance + 1 combined chunk with end_of_session) + assert mock_response_producer.send.call_count == 5 # First 4 messages are explain (emitted in real-time during query) for i in range(4): @@ -88,17 +88,12 @@ class TestGraphRagService: assert prov_msg.message_type == "explain" assert prov_msg.explain_id is not None - # 5th message is chunk with response + # 5th message is chunk with response and end_of_session chunk_msg = mock_response_producer.send.call_args_list[4][0][0] assert chunk_msg.message_type == "chunk" assert chunk_msg.response == "A small domesticated mammal." 
assert chunk_msg.end_of_stream is True - - # 6th message is empty chunk with end_of_session=True - close_msg = mock_response_producer.send.call_args_list[5][0][0] - assert close_msg.message_type == "chunk" - assert close_msg.response == "" - assert close_msg.end_of_session is True + assert chunk_msg.end_of_session is True # Verify provenance triples were sent to provenance queue assert mock_provenance_producer.send.call_count == 4 @@ -187,7 +182,7 @@ class TestGraphRagService: async def mock_query(**kwargs): # Don't call explain_callback - return "Response text" + return "Response text", {"in_token": None, "out_token": None, "model": None} mock_rag_instance.query.side_effect = mock_query @@ -218,17 +213,12 @@ class TestGraphRagService: # Execute await processor.on_request(msg, consumer, flow) - # Verify: 2 messages (chunk + empty chunk to close) - assert mock_response_producer.send.call_count == 2 + # Verify: 1 combined message (chunk with end_of_session) + assert mock_response_producer.send.call_count == 1 - # First is the response chunk + # Single message has response and end_of_session chunk_msg = mock_response_producer.send.call_args_list[0][0][0] assert chunk_msg.message_type == "chunk" assert chunk_msg.response == "Response text" assert chunk_msg.end_of_stream is True - - # Second is empty chunk to close session - close_msg = mock_response_producer.send.call_args_list[1][0][0] - assert close_msg.message_type == "chunk" - assert close_msg.response == "" - assert close_msg.end_of_session is True + assert chunk_msg.end_of_session is True diff --git a/trustgraph-base/trustgraph/api/__init__.py b/trustgraph-base/trustgraph/api/__init__.py index 8b703dc7..2f44aad0 100644 --- a/trustgraph-base/trustgraph/api/__init__.py +++ b/trustgraph-base/trustgraph/api/__init__.py @@ -107,6 +107,7 @@ from .types import ( AgentObservation, AgentAnswer, RAGChunk, + TextCompletionResult, ProvenanceEvent, ) @@ -185,6 +186,7 @@ __all__ = [ "AgentObservation", "AgentAnswer", 
"RAGChunk", + "TextCompletionResult", "ProvenanceEvent", # Exceptions diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py index 2ff37307..68899341 100644 --- a/trustgraph-base/trustgraph/api/async_flow.py +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -14,6 +14,8 @@ import aiohttp import json from typing import Optional, Dict, Any, List +from . types import TextCompletionResult + from . exceptions import ProtocolException, ApplicationException @@ -434,12 +436,11 @@ class AsyncFlowInstance: return await self.request("agent", request_data) - async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> str: + async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> TextCompletionResult: """ Generate text completion (non-streaming). Generates a text response from an LLM given a system prompt and user prompt. - Returns the complete response text. Note: This method does not support streaming. For streaming text generation, use AsyncSocketFlowInstance.text_completion() instead. @@ -450,19 +451,19 @@ class AsyncFlowInstance: **kwargs: Additional service-specific parameters Returns: - str: Complete generated text response + TextCompletionResult: Result with text, in_token, out_token, model Example: ```python async_flow = await api.async_flow() flow = async_flow.id("default") - # Generate text - response = await flow.text_completion( + result = await flow.text_completion( system="You are a helpful assistant.", prompt="Explain quantum computing in simple terms." 
) - print(response) + print(result.text) + print(f"Tokens: {result.in_token} in, {result.out_token} out") ``` """ request_data = { @@ -473,7 +474,12 @@ class AsyncFlowInstance: request_data.update(kwargs) result = await self.request("text-completion", request_data) - return result.get("response", "") + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) async def graph_rag(self, query: str, user: str, collection: str, max_subgraph_size: int = 1000, max_subgraph_count: int = 5, diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 7a239b07..e1007556 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -4,7 +4,7 @@ import asyncio import websockets from typing import Optional, Dict, Any, AsyncIterator, Union -from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk +from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, TextCompletionResult from . 
exceptions import ProtocolException, ApplicationException @@ -199,7 +199,10 @@ class AsyncSocketClient: return AgentAnswer( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False), - end_of_dialog=resp.get("end_of_dialog", False) + end_of_dialog=resp.get("end_of_dialog", False), + in_token=resp.get("in_token"), + out_token=resp.get("out_token"), + model=resp.get("model"), ) elif chunk_type == "action": return AgentThought( @@ -211,7 +214,10 @@ class AsyncSocketClient: return RAGChunk( content=content, end_of_stream=resp.get("end_of_stream", False), - error=None + error=None, + in_token=resp.get("in_token"), + out_token=resp.get("out_token"), + model=resp.get("model"), ) async def aclose(self): @@ -269,7 +275,11 @@ class AsyncSocketFlowInstance: return await self.client._send_request("agent", self.flow_id, request) async def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs): - """Text completion with optional streaming""" + """Text completion with optional streaming. + + Non-streaming: returns a TextCompletionResult with text and token counts. + Streaming: returns an async iterator of RAGChunk (with token counts on the final chunk). + """ request = { "system": system, "prompt": prompt, @@ -281,13 +291,18 @@ class AsyncSocketFlowInstance: return self._text_completion_streaming(request) else: result = await self.client._send_request("text-completion", self.flow_id, request) - return result.get("response", "") + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) async def _text_completion_streaming(self, request): - """Helper for streaming text completion""" + """Helper for streaming text completion. 
Yields RAGChunk objects.""" async for chunk in self.client._send_request_streaming("text-completion", self.flow_id, request): - if hasattr(chunk, 'content'): - yield chunk.content + if isinstance(chunk, RAGChunk): + yield chunk async def graph_rag(self, query: str, user: str, collection: str, max_subgraph_size: int = 1000, max_subgraph_count: int = 5, diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index 0aa55347..7ee32dad 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -11,7 +11,7 @@ import base64 from .. knowledge import hash, Uri, Literal, QuotedTriple from .. schema import IRI, LITERAL, TRIPLE -from . types import Triple +from . types import Triple, TextCompletionResult from . exceptions import ProtocolException @@ -360,16 +360,17 @@ class FlowInstance: prompt: User prompt/question Returns: - str: Generated response text + TextCompletionResult: Result with text, in_token, out_token, model Example: ```python flow = api.flow().id("default") - response = flow.text_completion( + result = flow.text_completion( system="You are a helpful assistant", prompt="What is quantum computing?" 
) - print(response) + print(result.text) + print(f"Tokens: {result.in_token} in, {result.out_token} out") ``` """ @@ -379,10 +380,17 @@ class FlowInstance: "prompt": prompt } - return self.request( + result = self.request( "service/text-completion", input - )["response"] + ) + + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) def agent(self, question, user="trustgraph", state=None, group=None, history=None): """ @@ -498,10 +506,17 @@ class FlowInstance: "edge-limit": edge_limit, } - return self.request( + result = self.request( "service/graph-rag", input - )["response"] + ) + + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) def document_rag( self, query, user="trustgraph", collection="default", @@ -543,10 +558,17 @@ class FlowInstance: "doc-limit": doc_limit, } - return self.request( + result = self.request( "service/document-rag", input - )["response"] + ) + + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) def embeddings(self, texts): """ diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index b6ceba00..fc238e36 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -14,7 +14,7 @@ import websockets from typing import Optional, Dict, Any, Iterator, Union, List from threading import Lock -from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent +from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent, TextCompletionResult from . 
exceptions import ProtocolException, raise_from_error_dict @@ -393,6 +393,9 @@ class SocketClient: end_of_message=resp.get("end_of_message", False), end_of_dialog=resp.get("end_of_dialog", False), message_id=resp.get("message_id", ""), + in_token=resp.get("in_token"), + out_token=resp.get("out_token"), + model=resp.get("model"), ) elif chunk_type == "action": return AgentThought( @@ -404,7 +407,10 @@ class SocketClient: return RAGChunk( content=content, end_of_stream=resp.get("end_of_stream", False), - error=None + error=None, + in_token=resp.get("in_token"), + out_token=resp.get("out_token"), + model=resp.get("model"), ) def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent: @@ -543,8 +549,12 @@ class SocketFlowInstance: streaming=True, include_provenance=True ) - def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]: - """Execute text completion with optional streaming.""" + def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[TextCompletionResult, Iterator[RAGChunk]]: + """Execute text completion with optional streaming. + + Non-streaming: returns a TextCompletionResult with text and token counts. + Streaming: returns an iterator of RAGChunk (with token counts on the final chunk). 
+ """ request = { "system": system, "prompt": prompt, @@ -557,12 +567,17 @@ class SocketFlowInstance: if streaming: return self._text_completion_generator(result) else: - return result.get("response", "") + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) - def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: + def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]: for chunk in result: - if hasattr(chunk, 'content'): - yield chunk.content + if isinstance(chunk, RAGChunk): + yield chunk def graph_rag( self, @@ -577,8 +592,12 @@ class SocketFlowInstance: edge_limit: int = 25, streaming: bool = False, **kwargs: Any - ) -> Union[str, Iterator[str]]: - """Execute graph-based RAG query with optional streaming.""" + ) -> Union[TextCompletionResult, Iterator[RAGChunk]]: + """Execute graph-based RAG query with optional streaming. + + Non-streaming: returns a TextCompletionResult with text and token counts. + Streaming: returns an iterator of RAGChunk (with token counts on the final chunk). + """ request = { "query": query, "user": user, @@ -598,7 +617,12 @@ class SocketFlowInstance: if streaming: return self._rag_generator(result) else: - return result.get("response", "") + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) def graph_rag_explain( self, @@ -642,8 +666,12 @@ class SocketFlowInstance: doc_limit: int = 10, streaming: bool = False, **kwargs: Any - ) -> Union[str, Iterator[str]]: - """Execute document-based RAG query with optional streaming.""" + ) -> Union[TextCompletionResult, Iterator[RAGChunk]]: + """Execute document-based RAG query with optional streaming. + + Non-streaming: returns a TextCompletionResult with text and token counts. 
+ Streaming: returns an iterator of RAGChunk (with token counts on the final chunk). + """ request = { "query": query, "user": user, @@ -658,7 +686,12 @@ class SocketFlowInstance: if streaming: return self._rag_generator(result) else: - return result.get("response", "") + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) def document_rag_explain( self, @@ -684,10 +717,10 @@ class SocketFlowInstance: streaming=True, include_provenance=True ) - def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: + def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]: for chunk in result: - if hasattr(chunk, 'content'): - yield chunk.content + if isinstance(chunk, RAGChunk): + yield chunk def prompt( self, @@ -695,8 +728,12 @@ class SocketFlowInstance: variables: Dict[str, str], streaming: bool = False, **kwargs: Any - ) -> Union[str, Iterator[str]]: - """Execute a prompt template with optional streaming.""" + ) -> Union[TextCompletionResult, Iterator[RAGChunk]]: + """Execute a prompt template with optional streaming. + + Non-streaming: returns a TextCompletionResult with text and token counts. + Streaming: returns an iterator of RAGChunk (with token counts on the final chunk). 
+ """ request = { "id": id, "variables": variables, @@ -709,7 +746,12 @@ class SocketFlowInstance: if streaming: return self._rag_generator(result) else: - return result.get("response", "") + return TextCompletionResult( + text=result.get("text", result.get("response", "")), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) def graph_embeddings_query( self, diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index 55635584..7b79c962 100644 --- a/trustgraph-base/trustgraph/api/types.py +++ b/trustgraph-base/trustgraph/api/types.py @@ -189,6 +189,9 @@ class AgentAnswer(StreamingChunk): chunk_type: str = "final-answer" end_of_dialog: bool = False message_id: str = "" + in_token: Optional[int] = None + out_token: Optional[int] = None + model: Optional[str] = None @dataclasses.dataclass class RAGChunk(StreamingChunk): @@ -202,11 +205,37 @@ class RAGChunk(StreamingChunk): content: Generated text content end_of_stream: True if this is the final chunk of the stream error: Optional error information if an error occurred + in_token: Input token count (populated on the final chunk, None otherwise) + out_token: Output token count (populated on the final chunk, None otherwise) + model: Model identifier (populated on the final chunk, None otherwise) chunk_type: Always "rag" """ chunk_type: str = "rag" end_of_stream: bool = False error: Optional[Dict[str, str]] = None + in_token: Optional[int] = None + out_token: Optional[int] = None + model: Optional[str] = None + +@dataclasses.dataclass +class TextCompletionResult: + """ + Result from a text completion request. + + Returned by text_completion() in both streaming and non-streaming modes. + In streaming mode, text is None (chunks are delivered via the iterator). + In non-streaming mode, text contains the complete response. 
+ + Attributes: + text: Complete response text (None in streaming mode) + in_token: Input token count (None if not available) + out_token: Output token count (None if not available) + model: Model identifier (None if not available) + """ + text: Optional[str] + in_token: Optional[int] = None + out_token: Optional[int] = None + model: Optional[str] = None @dataclasses.dataclass class ProvenanceEvent: diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index 24b6c1f0..ce17a585 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -18,8 +18,10 @@ from . librarian_client import LibrarianClient from . chunking_service import ChunkingService from . embeddings_service import EmbeddingsService from . embeddings_client import EmbeddingsClientSpec -from . text_completion_client import TextCompletionClientSpec -from . prompt_client import PromptClientSpec +from . text_completion_client import ( + TextCompletionClientSpec, TextCompletionClient, TextCompletionResult, +) +from . prompt_client import PromptClientSpec, PromptClient, PromptResult from . triples_store_service import TriplesStoreService from . graph_embeddings_store_service import GraphEmbeddingsStoreService from . document_embeddings_store_service import DocumentEmbeddingsStoreService diff --git a/trustgraph-base/trustgraph/base/prompt_client.py b/trustgraph-base/trustgraph/base/prompt_client.py index 6859a9f0..853e7e66 100644 --- a/trustgraph-base/trustgraph/base/prompt_client.py +++ b/trustgraph-base/trustgraph/base/prompt_client.py @@ -1,10 +1,22 @@ import json import asyncio +from dataclasses import dataclass +from typing import Optional, Any from . request_response_spec import RequestResponse, RequestResponseSpec from .. 
schema import PromptRequest, PromptResponse +@dataclass +class PromptResult: + response_type: str # "text", "json", or "jsonl" + text: Optional[str] = None # populated for "text" + object: Any = None # populated for "json" + objects: Optional[list] = None # populated for "jsonl" + in_token: Optional[int] = None + out_token: Optional[int] = None + model: Optional[str] = None + class PromptClient(RequestResponse): async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None): @@ -26,17 +38,40 @@ class PromptClient(RequestResponse): if resp.error: raise RuntimeError(resp.error.message) - if resp.text: return resp.text + if resp.text: + return PromptResult( + response_type="text", + text=resp.text, + in_token=resp.in_token, + out_token=resp.out_token, + model=resp.model, + ) - return json.loads(resp.object) + parsed = json.loads(resp.object) + + if isinstance(parsed, list): + return PromptResult( + response_type="jsonl", + objects=parsed, + in_token=resp.in_token, + out_token=resp.out_token, + model=resp.model, + ) + + return PromptResult( + response_type="json", + object=parsed, + in_token=resp.in_token, + out_token=resp.out_token, + model=resp.model, + ) else: - last_text = "" - last_object = None + last_resp = None async def forward_chunks(resp): - nonlocal last_text, last_object + nonlocal last_resp if resp.error: raise RuntimeError(resp.error.message) @@ -44,14 +79,13 @@ class PromptClient(RequestResponse): end_stream = getattr(resp, 'end_of_stream', False) if resp.text is not None: - last_text = resp.text if chunk_callback: if asyncio.iscoroutinefunction(chunk_callback): await chunk_callback(resp.text, end_stream) else: chunk_callback(resp.text, end_stream) - elif resp.object: - last_object = resp.object + + last_resp = resp return end_stream @@ -70,10 +104,36 @@ class PromptClient(RequestResponse): timeout=timeout ) - if last_text: - return last_text + if last_resp is None: + return PromptResult(response_type="text") - return 
json.loads(last_object) if last_object else None + if last_resp.object: + parsed = json.loads(last_resp.object) + + if isinstance(parsed, list): + return PromptResult( + response_type="jsonl", + objects=parsed, + in_token=last_resp.in_token, + out_token=last_resp.out_token, + model=last_resp.model, + ) + + return PromptResult( + response_type="json", + object=parsed, + in_token=last_resp.in_token, + out_token=last_resp.out_token, + model=last_resp.model, + ) + + return PromptResult( + response_type="text", + text=last_resp.text, + in_token=last_resp.in_token, + out_token=last_resp.out_token, + model=last_resp.model, + ) async def extract_definitions(self, text, timeout=600): return await self.prompt( @@ -152,4 +212,3 @@ class PromptClientSpec(RequestResponseSpec): response_schema = PromptResponse, impl = PromptClient, ) - diff --git a/trustgraph-base/trustgraph/base/text_completion_client.py b/trustgraph-base/trustgraph/base/text_completion_client.py index ae93e22e..876d71df 100644 --- a/trustgraph-base/trustgraph/base/text_completion_client.py +++ b/trustgraph-base/trustgraph/base/text_completion_client.py @@ -1,47 +1,71 @@ +from dataclasses import dataclass +from typing import Optional + from . request_response_spec import RequestResponse, RequestResponseSpec from .. 
schema import TextCompletionRequest, TextCompletionResponse +@dataclass +class TextCompletionResult: + text: Optional[str] + in_token: Optional[int] = None + out_token: Optional[int] = None + model: Optional[str] = None + class TextCompletionClient(RequestResponse): - async def text_completion(self, system, prompt, streaming=False, timeout=600): - # If not streaming, use original behavior - if not streaming: - resp = await self.request( - TextCompletionRequest( - system = system, prompt = prompt, streaming = False - ), - timeout=timeout - ) - if resp.error: - raise RuntimeError(resp.error.message) + async def text_completion(self, system, prompt, timeout=600): - return resp.response - - # For streaming: collect all chunks and return complete response - full_response = "" - - async def collect_chunks(resp): - nonlocal full_response - - if resp.error: - raise RuntimeError(resp.error.message) - - if resp.response: - full_response += resp.response - - # Return True when end_of_stream is reached - return getattr(resp, 'end_of_stream', False) - - await self.request( + resp = await self.request( TextCompletionRequest( - system = system, prompt = prompt, streaming = True + system = system, prompt = prompt, streaming = False ), - recipient=collect_chunks, timeout=timeout ) - return full_response + if resp.error: + raise RuntimeError(resp.error.message) + + return TextCompletionResult( + text = resp.response, + in_token = resp.in_token, + out_token = resp.out_token, + model = resp.model, + ) + + async def text_completion_stream( + self, system, prompt, handler, timeout=600, + ): + """ + Streaming text completion. `handler` is an async callable invoked + once per chunk with the chunk's TextCompletionResponse. Returns a + TextCompletionResult with text=None and token counts / model taken + from the end_of_stream message. 
+ """ + + async def on_chunk(resp): + + if resp.error: + raise RuntimeError(resp.error.message) + + await handler(resp) + + return getattr(resp, "end_of_stream", False) + + final = await self.request( + TextCompletionRequest( + system = system, prompt = prompt, streaming = True + ), + recipient=on_chunk, + timeout=timeout, + ) + + return TextCompletionResult( + text = None, + in_token = final.in_token, + out_token = final.out_token, + model = final.model, + ) class TextCompletionClientSpec(RequestResponseSpec): def __init__( @@ -54,4 +78,3 @@ class TextCompletionClientSpec(RequestResponseSpec): response_schema = TextCompletionResponse, impl = TextCompletionClient, ) - diff --git a/trustgraph-base/trustgraph/messaging/translators/agent.py b/trustgraph-base/trustgraph/messaging/translators/agent.py index 8cf525f5..b255ea2c 100644 --- a/trustgraph-base/trustgraph/messaging/translators/agent.py +++ b/trustgraph-base/trustgraph/messaging/translators/agent.py @@ -90,6 +90,13 @@ class AgentResponseTranslator(MessageTranslator): if hasattr(obj, 'error') and obj.error and obj.error.message: result["error"] = {"message": obj.error.message, "code": obj.error.code} + if obj.in_token is not None: + result["in_token"] = obj.in_token + if obj.out_token is not None: + result["out_token"] = obj.out_token + if obj.model is not None: + result["model"] = obj.model + return result def encode_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]: diff --git a/trustgraph-base/trustgraph/messaging/translators/prompt.py b/trustgraph-base/trustgraph/messaging/translators/prompt.py index 4345e6fd..7f76bf4a 100644 --- a/trustgraph-base/trustgraph/messaging/translators/prompt.py +++ b/trustgraph-base/trustgraph/messaging/translators/prompt.py @@ -53,6 +53,13 @@ class PromptResponseTranslator(MessageTranslator): # Always include end_of_stream flag for streaming support result["end_of_stream"] = getattr(obj, "end_of_stream", False) + if obj.in_token is not None: + 
result["in_token"] = obj.in_token + if obj.out_token is not None: + result["out_token"] = obj.out_token + if obj.model is not None: + result["model"] = obj.model + return result def encode_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]: diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index 849bee94..e37b76e1 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -74,6 +74,13 @@ class DocumentRagResponseTranslator(MessageTranslator): if hasattr(obj, 'error') and obj.error and obj.error.message: result["error"] = {"message": obj.error.message, "type": obj.error.type} + if obj.in_token is not None: + result["in_token"] = obj.in_token + if obj.out_token is not None: + result["out_token"] = obj.out_token + if obj.model is not None: + result["model"] = obj.model + return result def encode_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]: @@ -163,6 +170,13 @@ class GraphRagResponseTranslator(MessageTranslator): if hasattr(obj, 'error') and obj.error and obj.error.message: result["error"] = {"message": obj.error.message, "type": obj.error.type} + if obj.in_token is not None: + result["in_token"] = obj.in_token + if obj.out_token is not None: + result["out_token"] = obj.out_token + if obj.model is not None: + result["model"] = obj.model + return result def encode_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]: diff --git a/trustgraph-base/trustgraph/messaging/translators/text_completion.py b/trustgraph-base/trustgraph/messaging/translators/text_completion.py index 596ff744..62cc4afb 100644 --- a/trustgraph-base/trustgraph/messaging/translators/text_completion.py +++ b/trustgraph-base/trustgraph/messaging/translators/text_completion.py @@ -29,11 +29,11 @@ class 
TextCompletionResponseTranslator(MessageTranslator): def encode(self, obj: TextCompletionResponse) -> Dict[str, Any]: result = {"response": obj.response} - if obj.in_token: + if obj.in_token is not None: result["in_token"] = obj.in_token - if obj.out_token: + if obj.out_token is not None: result["out_token"] = obj.out_token - if obj.model: + if obj.model is not None: result["model"] = obj.model # Always include end_of_stream flag for streaming support diff --git a/trustgraph-base/trustgraph/schema/services/agent.py b/trustgraph-base/trustgraph/schema/services/agent.py index fbc0101c..3b3a6d01 100644 --- a/trustgraph-base/trustgraph/schema/services/agent.py +++ b/trustgraph-base/trustgraph/schema/services/agent.py @@ -66,5 +66,10 @@ class AgentResponse: error: Error | None = None + # Token usage (populated on end_of_dialog message) + in_token: int | None = None + out_token: int | None = None + model: str | None = None + ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/llm.py b/trustgraph-base/trustgraph/schema/services/llm.py index 0fd6ab90..89c0cd54 100644 --- a/trustgraph-base/trustgraph/schema/services/llm.py +++ b/trustgraph-base/trustgraph/schema/services/llm.py @@ -17,9 +17,9 @@ class TextCompletionRequest: class TextCompletionResponse: error: Error | None = None response: str = "" - in_token: int = 0 - out_token: int = 0 - model: str = "" + in_token: int | None = None + out_token: int | None = None + model: str | None = None end_of_stream: bool = False # Indicates final message in stream ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/prompt.py b/trustgraph-base/trustgraph/schema/services/prompt.py index f7388102..1696790b 100644 --- a/trustgraph-base/trustgraph/schema/services/prompt.py +++ b/trustgraph-base/trustgraph/schema/services/prompt.py @@ -41,4 +41,9 @@ class PromptResponse: # 
Indicates final message in stream end_of_stream: bool = False + # Token usage from the underlying text completion + in_token: int | None = None + out_token: int | None = None + model: str | None = None + ############################################################################ \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index 4b17733d..a1af9170 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -29,6 +29,9 @@ class GraphRagResponse: explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step message_type: str = "" # "chunk" or "explain" end_of_session: bool = False # Entire session complete + in_token: int | None = None + out_token: int | None = None + model: str | None = None ############################################################################ @@ -52,3 +55,6 @@ class DocumentRagResponse: explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step message_type: str = "" # "chunk" or "explain" end_of_session: bool = False # Entire session complete + in_token: int | None = None + out_token: int | None = None + model: str | None = None diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index 026286d0..ddaef4ca 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -272,7 +272,8 @@ def question( url, question, flow_id, user, collection, plan=None, state=None, group=None, pattern=None, verbose=False, streaming=True, - token=None, explainable=False, debug=False + token=None, explainable=False, debug=False, + show_usage=False ): # Explainable mode uses the API to capture and process provenance events if explainable: @@ -323,6 +324,7 @@ def question( # Track last chunk type and current 
outputter for streaming last_chunk_type = None current_outputter = None + last_answer_chunk = None for chunk in response: chunk_type = chunk.chunk_type @@ -357,6 +359,7 @@ def question( current_outputter.word_buffer = "" elif chunk_type == "final-answer": print(content, end="", flush=True) + last_answer_chunk = chunk # Close any remaining outputter if current_outputter: @@ -366,6 +369,14 @@ def question( elif last_chunk_type == "final-answer": print() + if show_usage and last_answer_chunk: + print( + f"Input tokens: {last_answer_chunk.in_token} " + f"Output tokens: {last_answer_chunk.out_token} " + f"Model: {last_answer_chunk.model}", + file=sys.stderr, + ) + else: # Non-streaming response - but agents use multipart messaging # so we iterate through the chunks (which are complete messages, not text chunks) @@ -477,6 +488,12 @@ def main(): help='Show debug output for troubleshooting' ) + parser.add_argument( + '--show-usage', + action='store_true', + help='Show token usage and model on stderr' + ) + args = parser.parse_args() try: @@ -496,6 +513,7 @@ def main(): token = args.token, explainable = args.explainable, debug = args.debug, + show_usage = args.show_usage, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py index 066b92f4..d566f51d 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py @@ -99,7 +99,8 @@ def question_explainable( def question( url, flow_id, question_text, user, collection, doc_limit, - streaming=True, token=None, explainable=False, debug=False + streaming=True, token=None, explainable=False, debug=False, + show_usage=False ): # Explainable mode uses the API to capture and process provenance events if explainable: @@ -133,22 +134,40 @@ def question( ) # Stream output + last_chunk = None for chunk in response: - print(chunk, end="", flush=True) + print(chunk.content, end="", 
flush=True) + last_chunk = chunk print() # Final newline + if show_usage and last_chunk: + print( + f"Input tokens: {last_chunk.in_token} " + f"Output tokens: {last_chunk.out_token} " + f"Model: {last_chunk.model}", + file=sys.stderr, + ) + finally: socket.close() else: # Use REST API for non-streaming flow = api.flow().id(flow_id) - resp = flow.document_rag( + result = flow.document_rag( query=question_text, user=user, collection=collection, doc_limit=doc_limit, ) - print(resp) + print(result.text) + + if show_usage: + print( + f"Input tokens: {result.in_token} " + f"Output tokens: {result.out_token} " + f"Model: {result.model}", + file=sys.stderr, + ) def main(): @@ -219,6 +238,12 @@ def main(): help='Show debug output for troubleshooting' ) + parser.add_argument( + '--show-usage', + action='store_true', + help='Show token usage and model on stderr' + ) + args = parser.parse_args() try: @@ -234,6 +259,7 @@ def main(): token=args.token, explainable=args.explainable, debug=args.debug, + show_usage=args.show_usage, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index 230cc54b..c9efe54d 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -753,7 +753,7 @@ def question( url, flow_id, question, user, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, edge_score_limit=50, edge_limit=25, streaming=True, token=None, - explainable=False, debug=False + explainable=False, debug=False, show_usage=False ): # Explainable mode uses the API to capture and process provenance events @@ -798,16 +798,26 @@ def question( ) # Stream output + last_chunk = None for chunk in response: - print(chunk, end="", flush=True) + print(chunk.content, end="", flush=True) + last_chunk = chunk print() # Final newline + if show_usage and last_chunk: + print( + f"Input tokens: {last_chunk.in_token} " + f"Output tokens: 
{last_chunk.out_token} " + f"Model: {last_chunk.model}", + file=sys.stderr, + ) + finally: socket.close() else: # Use REST API for non-streaming flow = api.flow().id(flow_id) - resp = flow.graph_rag( + result = flow.graph_rag( query=question, user=user, collection=collection, @@ -818,7 +828,15 @@ def question( edge_score_limit=edge_score_limit, edge_limit=edge_limit, ) - print(resp) + print(result.text) + + if show_usage: + print( + f"Input tokens: {result.in_token} " + f"Output tokens: {result.out_token} " + f"Model: {result.model}", + file=sys.stderr, + ) def main(): @@ -923,6 +941,12 @@ def main(): help='Show debug output for troubleshooting' ) + parser.add_argument( + '--show-usage', + action='store_true', + help='Show token usage and model on stderr' + ) + args = parser.parse_args() try: @@ -943,6 +967,7 @@ def main(): token=args.token, explainable=args.explainable, debug=args.debug, + show_usage=args.show_usage, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_llm.py b/trustgraph-cli/trustgraph/cli/invoke_llm.py index a1611625..3bf521f6 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_llm.py +++ b/trustgraph-cli/trustgraph/cli/invoke_llm.py @@ -10,7 +10,8 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def query(url, flow_id, system, prompt, streaming=True, token=None): +def query(url, flow_id, system, prompt, streaming=True, token=None, + show_usage=False): # Create API client api = Api(url=url, token=token) @@ -26,14 +27,29 @@ def query(url, flow_id, system, prompt, streaming=True, token=None): ) if streaming: - # Stream output to stdout without newline + last_chunk = None for chunk in response: - print(chunk, end="", flush=True) - # Add final newline after streaming + print(chunk.content, end="", flush=True) + last_chunk = chunk print() + + if show_usage and last_chunk: + print( + f"Input tokens: {last_chunk.in_token} " + 
f"Output tokens: {last_chunk.out_token} " + f"Model: {last_chunk.model}", + file=__import__('sys').stderr, + ) else: - # Non-streaming: print complete response - print(response) + print(response.text) + + if show_usage: + print( + f"Input tokens: {response.in_token} " + f"Output tokens: {response.out_token} " + f"Model: {response.model}", + file=__import__('sys').stderr, + ) finally: # Clean up socket connection @@ -82,6 +98,12 @@ def main(): help='Disable streaming (default: streaming enabled)' ) + parser.add_argument( + '--show-usage', + action='store_true', + help='Show token usage and model on stderr' + ) + args = parser.parse_args() try: @@ -93,6 +115,7 @@ def main(): prompt=args.prompt[0], streaming=not args.no_streaming, token=args.token, + show_usage=args.show_usage, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_prompt.py b/trustgraph-cli/trustgraph/cli/invoke_prompt.py index 09cc9043..86f7a024 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_prompt.py +++ b/trustgraph-cli/trustgraph/cli/invoke_prompt.py @@ -15,7 +15,8 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def query(url, flow_id, template_id, variables, streaming=True, token=None): +def query(url, flow_id, template_id, variables, streaming=True, token=None, + show_usage=False): # Create API client api = Api(url=url, token=token) @@ -31,16 +32,30 @@ def query(url, flow_id, template_id, variables, streaming=True, token=None): ) if streaming: - # Stream output (prompt yields strings directly) + last_chunk = None for chunk in response: - if chunk: - print(chunk, end="", flush=True) - # Add final newline after streaming + if chunk.content: + print(chunk.content, end="", flush=True) + last_chunk = chunk print() + if show_usage and last_chunk: + print( + f"Input tokens: {last_chunk.in_token} " + f"Output tokens: {last_chunk.out_token} " + f"Model: {last_chunk.model}", + 
file=__import__('sys').stderr, + ) else: - # Non-streaming: print complete response - print(response) + print(response.text) + + if show_usage: + print( + f"Input tokens: {response.in_token} " + f"Output tokens: {response.out_token} " + f"Model: {response.model}", + file=__import__('sys').stderr, + ) finally: # Clean up socket connection @@ -92,6 +107,12 @@ specified multiple times''', help='Disable streaming (default: streaming enabled for text responses)' ) + parser.add_argument( + '--show-usage', + action='store_true', + help='Show token usage and model on stderr' + ) + args = parser.parse_args() variables = {} @@ -113,6 +134,7 @@ specified multiple times''', variables=variables, streaming=not args.no_streaming, token=args.token, + show_usage=args.show_usage, ) except Exception as e: diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py b/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py index c3b1afa6..97b87134 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py @@ -53,7 +53,7 @@ class MetaRouter: "general": {"name": "general", "description": "General queries", "valid_patterns": ["react"], "framing": ""}, } - async def identify_task_type(self, question, context): + async def identify_task_type(self, question, context, usage=None): """ Use the LLM to classify the question into one of the known task types. 
@@ -71,7 +71,7 @@ class MetaRouter: try: client = context("prompt-request") - response = await client.prompt( + result = await client.prompt( id="task-type-classify", variables={ "question": question, @@ -81,7 +81,9 @@ class MetaRouter: ], }, ) - selected = response.strip().lower().replace('"', '').replace("'", "") + if usage: + usage.track(result) + selected = result.text.strip().lower().replace('"', '').replace("'", "") if selected in self.task_types: framing = self.task_types[selected].get("framing", DEFAULT_FRAMING) @@ -100,7 +102,7 @@ class MetaRouter: ) return DEFAULT_TASK_TYPE, framing - async def select_pattern(self, question, task_type, context): + async def select_pattern(self, question, task_type, context, usage=None): """ Use the LLM to select the best execution pattern for this task type. @@ -120,7 +122,7 @@ class MetaRouter: try: client = context("prompt-request") - response = await client.prompt( + result = await client.prompt( id="pattern-select", variables={ "question": question, @@ -133,7 +135,9 @@ class MetaRouter: ], }, ) - selected = response.strip().lower().replace('"', '').replace("'", "") + if usage: + usage.track(result) + selected = result.text.strip().lower().replace('"', '').replace("'", "") if selected in valid_patterns: logger.info(f"MetaRouter: selected pattern '{selected}'") @@ -148,19 +152,20 @@ class MetaRouter: logger.warning(f"MetaRouter: pattern selection failed: {e}") return valid_patterns[0] if valid_patterns else DEFAULT_PATTERN - async def route(self, question, context): + async def route(self, question, context, usage=None): """ Full routing pipeline: identify task type, then select pattern. Args: question: The user's query. context: UserAwareContext (flow wrapper). + usage: Optional UsageTracker for token counting. Returns: (pattern, task_type, framing) tuple. 
""" - task_type, framing = await self.identify_task_type(question, context) - pattern = await self.select_pattern(question, task_type, context) + task_type, framing = await self.identify_task_type(question, context, usage=usage) + pattern = await self.select_pattern(question, task_type, context, usage=usage) logger.info( f"MetaRouter: route result — " f"pattern={pattern}, task_type={task_type}, framing={framing!r}" diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py index c18c5bac..689d57e6 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py @@ -65,6 +65,37 @@ class UserAwareContext: return client +class UsageTracker: + """Accumulates token usage across multiple prompt calls.""" + + def __init__(self): + self.total_in = 0 + self.total_out = 0 + self.last_model = None + + def track(self, result): + """Track usage from a PromptResult.""" + if result is not None: + if getattr(result, "in_token", None) is not None: + self.total_in += result.in_token + if getattr(result, "out_token", None) is not None: + self.total_out += result.out_token + if getattr(result, "model", None) is not None: + self.last_model = result.model + + @property + def in_token(self): + return self.total_in if self.total_in > 0 else None + + @property + def out_token(self): + return self.total_out if self.total_out > 0 else None + + @property + def model(self): + return self.last_model + + class PatternBase: """ Shared infrastructure for all agent patterns. @@ -571,7 +602,8 @@ class PatternBase: # ---- Response helpers --------------------------------------------------- async def prompt_as_answer(self, client, prompt_id, variables, - respond, streaming, message_id=""): + respond, streaming, message_id="", + usage=None): """Call a prompt template, forwarding chunks as answer AgentResponse messages when streaming is enabled. 
@@ -591,22 +623,28 @@ class PatternBase: message_id=message_id, )) - await client.prompt( + result = await client.prompt( id=prompt_id, variables=variables, streaming=True, chunk_callback=on_chunk, ) + if usage: + usage.track(result) return "".join(accumulated) else: - return await client.prompt( + result = await client.prompt( id=prompt_id, variables=variables, ) + if usage: + usage.track(result) + return result.text async def send_final_response(self, respond, streaming, answer_text, - already_streamed=False, message_id=""): + already_streamed=False, message_id="", + usage=None): """Send the answer content and end-of-dialog marker. Args: @@ -614,7 +652,16 @@ class PatternBase: via streaming callbacks (e.g. ReactPattern). Only the end-of-dialog marker is emitted. message_id: Provenance URI for the answer entity. + usage: UsageTracker with accumulated token counts. """ + usage_kwargs = {} + if usage: + usage_kwargs = { + "in_token": usage.in_token, + "out_token": usage.out_token, + "model": usage.model, + } + if streaming and not already_streamed: # Answer wasn't streamed yet — send it as a chunk first if answer_text: @@ -626,13 +673,14 @@ class PatternBase: message_id=message_id, )) if streaming: - # End-of-dialog marker + # End-of-dialog marker with usage await respond(AgentResponse( chunk_type="answer", content="", end_of_message=True, end_of_dialog=True, message_id=message_id, + **usage_kwargs, )) else: await respond(AgentResponse( @@ -641,6 +689,7 @@ class PatternBase: end_of_message=True, end_of_dialog=True, message_id=message_id, + **usage_kwargs, )) def build_next_request(self, request, history, session_id, collection, diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py index 59d22929..8f5cdcdf 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py @@ -18,7 +18,7 @@ from 
trustgraph.provenance import ( agent_synthesis_uri, ) -from . pattern_base import PatternBase +from . pattern_base import PatternBase, UsageTracker logger = logging.getLogger(__name__) @@ -35,7 +35,10 @@ class PlanThenExecutePattern(PatternBase): Subsequent calls execute the next pending plan step via ReACT. """ - async def iterate(self, request, respond, next, flow): + async def iterate(self, request, respond, next, flow, usage=None): + + if usage is None: + usage = UsageTracker() streaming = getattr(request, 'streaming', False) session_id = getattr(request, 'session_id', '') or str(uuid.uuid4()) @@ -67,13 +70,13 @@ class PlanThenExecutePattern(PatternBase): await self._planning_iteration( request, respond, next, flow, session_id, collection, streaming, session_uri, - iteration_num, + iteration_num, usage=usage, ) else: await self._execution_iteration( request, respond, next, flow, session_id, collection, streaming, session_uri, - iteration_num, plan, + iteration_num, plan, usage=usage, ) def _extract_plan(self, history): @@ -98,7 +101,7 @@ class PlanThenExecutePattern(PatternBase): async def _planning_iteration(self, request, respond, next, flow, session_id, collection, streaming, - session_uri, iteration_num): + session_uri, iteration_num, usage=None): """Ask the LLM to produce a structured plan.""" think = self.make_think_callback(respond, streaming) @@ -113,7 +116,7 @@ class PlanThenExecutePattern(PatternBase): client = context("prompt-request") # Use the plan-create prompt template - plan_steps = await client.prompt( + result = await client.prompt( id="plan-create", variables={ "question": request.question, @@ -124,7 +127,10 @@ class PlanThenExecutePattern(PatternBase): ], }, ) + if usage: + usage.track(result) + plan_steps = result.objects # Validate we got a list if not isinstance(plan_steps, list) or not plan_steps: logger.warning("plan-create returned invalid result, falling back to single step") @@ -187,7 +193,8 @@ class 
PlanThenExecutePattern(PatternBase): async def _execution_iteration(self, request, respond, next, flow, session_id, collection, streaming, - session_uri, iteration_num, plan): + session_uri, iteration_num, plan, + usage=None): """Execute the next pending plan step via single-shot tool call.""" pending_idx = self._find_next_pending_step(plan) @@ -198,6 +205,7 @@ class PlanThenExecutePattern(PatternBase): request, respond, next, flow, session_id, collection, streaming, session_uri, iteration_num, plan, + usage=usage, ) return @@ -240,7 +248,7 @@ class PlanThenExecutePattern(PatternBase): client = context("prompt-request") # Single-shot: ask LLM which tool + arguments to use for this goal - tool_call = await client.prompt( + result = await client.prompt( id="plan-step-execute", variables={ "goal": goal, @@ -258,7 +266,10 @@ class PlanThenExecutePattern(PatternBase): ], }, ) + if usage: + usage.track(result) + tool_call = result.object tool_name = tool_call.get("tool", "") tool_arguments = tool_call.get("arguments", {}) @@ -330,7 +341,8 @@ class PlanThenExecutePattern(PatternBase): async def _synthesise(self, request, respond, next, flow, session_id, collection, streaming, - session_uri, iteration_num, plan): + session_uri, iteration_num, plan, + usage=None): """Synthesise a final answer from all completed plan step results.""" think = self.make_think_callback(respond, streaming) @@ -365,6 +377,7 @@ class PlanThenExecutePattern(PatternBase): respond=respond, streaming=streaming, message_id=synthesis_msg_id, + usage=usage, ) # Emit synthesis provenance (links back to last step result) @@ -380,4 +393,5 @@ class PlanThenExecutePattern(PatternBase): await self.send_final_response( respond, streaming, response_text, already_streamed=streaming, message_id=synthesis_msg_id, + usage=usage, ) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py index 67ded823..777f99c5 100644 --- 
a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py @@ -23,7 +23,7 @@ from ..react.agent_manager import AgentManager from ..react.types import Action, Final from ..tool_filter import get_next_state -from . pattern_base import PatternBase +from . pattern_base import PatternBase, UsageTracker logger = logging.getLogger(__name__) @@ -37,7 +37,10 @@ class ReactPattern(PatternBase): result is appended to history and a next-request is emitted. """ - async def iterate(self, request, respond, next, flow): + async def iterate(self, request, respond, next, flow, usage=None): + + if usage is None: + usage = UsageTracker() streaming = getattr(request, 'streaming', False) session_id = getattr(request, 'session_id', '') or str(uuid.uuid4()) @@ -121,6 +124,7 @@ class ReactPattern(PatternBase): context=context, streaming=streaming, on_action=on_action, + usage=usage, ) logger.debug(f"Action: {act}") @@ -144,6 +148,7 @@ class ReactPattern(PatternBase): await self.send_final_response( respond, streaming, f, already_streamed=streaming, message_id=answer_msg_id, + usage=usage, ) return diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/service.py b/trustgraph-flow/trustgraph/agent/orchestrator/service.py index 5bf8e2fd..9a3584da 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/service.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/service.py @@ -23,6 +23,7 @@ from ... base import Consumer, Producer from ... base import ConsumerMetrics, ProducerMetrics from ... schema import AgentRequest, AgentResponse, AgentStep, Error +from ..orchestrator.pattern_base import UsageTracker from ... schema import Triples, Metadata from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata from ... 
schema import librarian_request_queue, librarian_response_queue @@ -493,6 +494,8 @@ class Processor(AgentService): async def agent_request(self, request, respond, next, flow): + usage = UsageTracker() + try: # Intercept subagent completion messages @@ -516,7 +519,7 @@ class Processor(AgentService): if self.meta_router: pattern, task_type, framing = await self.meta_router.route( - request.question, context, + request.question, context, usage=usage, ) else: pattern = "react" @@ -536,16 +539,16 @@ class Processor(AgentService): # Dispatch to the selected pattern if pattern == "plan-then-execute": await self.plan_pattern.iterate( - request, respond, next, flow, + request, respond, next, flow, usage=usage, ) elif pattern == "supervisor": await self.supervisor_pattern.iterate( - request, respond, next, flow, + request, respond, next, flow, usage=usage, ) else: # Default to react await self.react_pattern.iterate( - request, respond, next, flow, + request, respond, next, flow, usage=usage, ) except Exception as e: diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py index d5537876..4b62e767 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py @@ -22,7 +22,7 @@ from trustgraph.provenance import ( agent_synthesis_uri, ) -from . pattern_base import PatternBase +from . 
pattern_base import PatternBase, UsageTracker logger = logging.getLogger(__name__) @@ -38,7 +38,10 @@ class SupervisorPattern(PatternBase): - "synthesise": triggered by aggregator with results in subagent_results """ - async def iterate(self, request, respond, next, flow): + async def iterate(self, request, respond, next, flow, usage=None): + + if usage is None: + usage = UsageTracker() streaming = getattr(request, 'streaming', False) session_id = getattr(request, 'session_id', '') or str(uuid.uuid4()) @@ -72,17 +75,19 @@ class SupervisorPattern(PatternBase): request, respond, next, flow, session_id, collection, streaming, session_uri, iteration_num, + usage=usage, ) else: await self._decompose_and_fanout( request, respond, next, flow, session_id, collection, streaming, session_uri, iteration_num, + usage=usage, ) async def _decompose_and_fanout(self, request, respond, next, flow, session_id, collection, streaming, - session_uri, iteration_num): + session_uri, iteration_num, usage=None): """Decompose the question into sub-goals and fan out subagents.""" decompose_msg_id = agent_decomposition_uri(session_id) @@ -100,7 +105,7 @@ class SupervisorPattern(PatternBase): client = context("prompt-request") # Use the supervisor-decompose prompt template - goals = await client.prompt( + result = await client.prompt( id="supervisor-decompose", variables={ "question": request.question, @@ -112,7 +117,10 @@ class SupervisorPattern(PatternBase): ], }, ) + if usage: + usage.track(result) + goals = result.objects # Validate result if not isinstance(goals, list): goals = [] @@ -175,7 +183,7 @@ class SupervisorPattern(PatternBase): async def _synthesise(self, request, respond, next, flow, session_id, collection, streaming, - session_uri, iteration_num): + session_uri, iteration_num, usage=None): """Synthesise final answer from subagent results.""" synthesis_msg_id = agent_synthesis_uri(session_id) @@ -216,6 +224,7 @@ class SupervisorPattern(PatternBase): respond=respond, 
streaming=streaming, message_id=synthesis_msg_id, + usage=usage, ) # Emit synthesis provenance (links back to all findings) @@ -231,4 +240,5 @@ class SupervisorPattern(PatternBase): await self.send_final_response( respond, streaming, response_text, already_streamed=streaming, message_id=synthesis_msg_id, + usage=usage, ) diff --git a/trustgraph-flow/trustgraph/agent/react/agent_manager.py b/trustgraph-flow/trustgraph/agent/react/agent_manager.py index e86a2d6c..73686f21 100644 --- a/trustgraph-flow/trustgraph/agent/react/agent_manager.py +++ b/trustgraph-flow/trustgraph/agent/react/agent_manager.py @@ -170,7 +170,7 @@ class AgentManager: raise ValueError(f"Could not parse response: {text}") - async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None): + async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None, usage=None): logger.debug(f"calling reason: {question}") @@ -255,11 +255,13 @@ class AgentManager: client = context("prompt-request") # Get streaming response - response_text = await client.agent_react( + prompt_result = await client.agent_react( variables=variables, streaming=True, chunk_callback=on_chunk ) + if usage: + usage.track(prompt_result) # Finalize parser parser.finalize() @@ -275,10 +277,13 @@ class AgentManager: # Non-streaming path - get complete text and parse client = context("prompt-request") - response_text = await client.agent_react( + prompt_result = await client.agent_react( variables=variables, streaming=False ) + if usage: + usage.track(prompt_result) + response_text = prompt_result.text logger.debug(f"Response text:\n{response_text}") @@ -292,7 +297,8 @@ class AgentManager: raise RuntimeError(f"Failed to parse agent response: {e}") async def react(self, question, history, think, observe, context, - streaming=False, answer=None, on_action=None): + streaming=False, answer=None, on_action=None, + usage=None): act = await self.reason( question 
= question, @@ -302,6 +308,7 @@ class AgentManager: think = think, observe = observe, answer = answer, + usage = usage, ) if isinstance(act, Final): diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index 6fd96ade..c474f740 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -78,9 +78,10 @@ class TextCompletionImpl: async def invoke(self, **arguments): client = self.context("prompt-request") logger.debug("Prompt question...") - return await client.question( + result = await client.question( arguments.get("question") ) + return result.text # This tool implementation knows how to do MCP tool invocation. This uses # the mcp-tool service. @@ -227,10 +228,11 @@ class PromptImpl: async def invoke(self, **arguments): client = self.context("prompt-request") logger.debug(f"Prompt template invocation: {self.template_id}...") - return await client.prompt( + result = await client.prompt( id=self.template_id, variables=arguments ) + return result.text # This tool implementation invokes a dynamically configured tool service diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py index 2bb88c8a..9b5bbb79 100755 --- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py @@ -117,10 +117,11 @@ class Processor(FlowProcessor): try: - defs = await flow("prompt-request").extract_definitions( + result = await flow("prompt-request").extract_definitions( text = chunk ) + defs = result.objects logger.debug(f"Definitions response: {defs}") if type(defs) != list: diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index 29808cae..bdb0e6e8 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ 
b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -376,10 +376,11 @@ class Processor(FlowProcessor): """ try: # Call prompt service with simplified format prompt - extraction_response = await flow("prompt-request").prompt( + result = await flow("prompt-request").prompt( id="extract-with-ontologies", variables=prompt_variables ) + extraction_response = result.object logger.debug(f"Simplified extraction response: {extraction_response}") # Parse response into structured format diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py index b557ec32..8068a23d 100755 --- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py @@ -100,10 +100,11 @@ class Processor(FlowProcessor): try: - rels = await flow("prompt-request").extract_relationships( + result = await flow("prompt-request").extract_relationships( text = chunk ) + rels = result.objects logger.debug(f"Prompt response: {rels}") if type(rels) != list: diff --git a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py index 8fd494b0..973bb3d7 100644 --- a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py +++ b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py @@ -148,11 +148,12 @@ class Processor(FlowProcessor): schema_dict = row_schema_translator.encode(schema) # Use prompt client to extract rows based on schema - objects = await flow("prompt-request").extract_objects( + result = await flow("prompt-request").extract_objects( schema=schema_dict, text=text ) - + + objects = result.objects if not isinstance(objects, list): return [] diff --git a/trustgraph-flow/trustgraph/prompt/template/service.py b/trustgraph-flow/trustgraph/prompt/template/service.py index 97298e13..c599ce77 100755 --- a/trustgraph-flow/trustgraph/prompt/template/service.py +++ 
b/trustgraph-flow/trustgraph/prompt/template/service.py @@ -11,7 +11,6 @@ import logging from ...schema import Definition, Relationship, Triple from ...schema import Topic from ...schema import PromptRequest, PromptResponse, Error -from ...schema import TextCompletionRequest, TextCompletionResponse from ...base import FlowProcessor from ...base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec @@ -124,35 +123,26 @@ class Processor(FlowProcessor): logger.debug(f"System prompt: {system}") logger.debug(f"User prompt: {prompt}") - # Use the text completion client with recipient handler - client = flow("text-completion-request") - async def forward_chunks(resp): - if resp.error: - raise RuntimeError(resp.error.message) - is_final = getattr(resp, 'end_of_stream', False) # Always send a message if there's content OR if it's the final message if resp.response or is_final: - # Forward each chunk immediately r = PromptResponse( text=resp.response if resp.response else "", object=None, error=None, end_of_stream=is_final, + in_token=resp.in_token, + out_token=resp.out_token, + model=resp.model, ) await flow("response").send(r, properties={"id": id}) - # Return True when end_of_stream - return is_final - - await client.request( - TextCompletionRequest( - system=system, prompt=prompt, streaming=True - ), - recipient=forward_chunks, - timeout=600 + await flow("text-completion-request").text_completion_stream( + system=system, prompt=prompt, + handler=forward_chunks, + timeout=600, ) # Return empty string since we already sent all chunks @@ -167,17 +157,21 @@ class Processor(FlowProcessor): return # Non-streaming path (original behavior) + usage = {} + async def llm(system, prompt): logger.debug(f"System prompt: {system}") logger.debug(f"User prompt: {prompt}") - resp = await flow("text-completion-request").text_completion( - system = system, prompt = prompt, streaming = False, - ) - try: - return resp + result = await flow("text-completion-request").text_completion( + 
system = system, prompt = prompt, + ) + usage["in_token"] = result.in_token + usage["out_token"] = result.out_token + usage["model"] = result.model + return result.text except Exception as e: logger.error(f"LLM Exception: {e}", exc_info=True) return None @@ -199,6 +193,9 @@ class Processor(FlowProcessor): object=None, error=None, end_of_stream=True, + in_token=usage.get("in_token", 0), + out_token=usage.get("out_token", 0), + model=usage.get("model", ""), ) await flow("response").send(r, properties={"id": id}) @@ -215,6 +212,9 @@ class Processor(FlowProcessor): object=json.dumps(resp), error=None, end_of_stream=True, + in_token=usage.get("in_token", 0), + out_token=usage.get("out_token", 0), + model=usage.get("model", ""), ) await flow("response").send(r, properties={"id": id}) diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index 730a7226..a2480862 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -27,24 +27,27 @@ class Query: def __init__( self, rag, user, collection, verbose, - doc_limit=20 + doc_limit=20, track_usage=None, ): self.rag = rag self.user = user self.collection = collection self.verbose = verbose self.doc_limit = doc_limit + self.track_usage = track_usage async def extract_concepts(self, query): """Extract key concepts from query for independent embedding.""" - response = await self.rag.prompt_client.prompt( + result = await self.rag.prompt_client.prompt( "extract-concepts", variables={"query": query} ) + if self.track_usage: + self.track_usage(result) concepts = [] - if isinstance(response, str): - for line in response.strip().split('\n'): + if result.text: + for line in result.text.strip().split('\n'): line = line.strip() if line: concepts.append(line) @@ -167,8 +170,23 @@ class DocumentRag: save_answer_callback: async def callback(doc_id, 
answer_text) to save answer to librarian Returns: - str: The synthesized answer text + tuple: (answer_text, usage) where usage is a dict with + in_token, out_token, model """ + total_in = 0 + total_out = 0 + last_model = None + + def track_usage(result): + nonlocal total_in, total_out, last_model + if result is not None: + if result.in_token is not None: + total_in += result.in_token + if result.out_token is not None: + total_out += result.out_token + if result.model is not None: + last_model = result.model + if self.verbose: logger.debug("Constructing prompt...") @@ -191,7 +209,7 @@ class DocumentRag: q = Query( rag=self, user=user, collection=collection, verbose=self.verbose, - doc_limit=doc_limit + doc_limit=doc_limit, track_usage=track_usage, ) # Extract concepts from query (grounding step) @@ -228,19 +246,22 @@ class DocumentRag: accumulated_chunks.append(chunk) await chunk_callback(chunk, end_of_stream) - resp = await self.prompt_client.document_prompt( + synthesis_result = await self.prompt_client.document_prompt( query=query, documents=docs, streaming=True, chunk_callback=accumulating_callback ) + track_usage(synthesis_result) # Combine all chunks into full response resp = "".join(accumulated_chunks) else: - resp = await self.prompt_client.document_prompt( + synthesis_result = await self.prompt_client.document_prompt( query=query, documents=docs ) + track_usage(synthesis_result) + resp = synthesis_result.text if self.verbose: logger.debug("Query processing complete") @@ -273,5 +294,11 @@ class DocumentRag: if self.verbose: logger.debug(f"Emitted explain for session {session_id}") - return resp + usage = { + "in_token": total_in if total_in > 0 else None, + "out_token": total_out if total_out > 0 else None, + "model": last_model, + } + + return resp, usage diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index 3b281fe3..dc7296ad 100755 --- 
a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -200,7 +200,7 @@ class Processor(FlowProcessor): # Query with streaming enabled # All chunks (including final one with end_of_stream=True) are sent via callback - await self.rag.query( + response, usage = await self.rag.query( v.query, user=v.user, collection=v.collection, @@ -217,12 +217,15 @@ class Processor(FlowProcessor): response=None, end_of_session=True, message_type="end", + in_token=usage.get("in_token"), + out_token=usage.get("out_token"), + model=usage.get("model"), ), properties={"id": id} ) else: - # Non-streaming path (existing behavior) - response = await self.rag.query( + # Non-streaming path - single response with answer and token usage + response, usage = await self.rag.query( v.query, user=v.user, collection=v.collection, @@ -233,11 +236,15 @@ class Processor(FlowProcessor): await flow("response").send( DocumentRagResponse( - response = response, - end_of_stream = True, - error = None + response=response, + end_of_stream=True, + end_of_session=True, + error=None, + in_token=usage.get("in_token"), + out_token=usage.get("out_token"), + model=usage.get("model"), ), - properties = {"id": id} + properties={"id": id} ) logger.info("Request processing complete") diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 5cf7b991..07654c64 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -121,7 +121,7 @@ class Query: def __init__( self, rag, user, collection, verbose, entity_limit=50, triple_limit=30, max_subgraph_size=1000, - max_path_length=2, + max_path_length=2, track_usage=None, ): self.rag = rag self.user = user @@ -131,17 +131,20 @@ class Query: self.triple_limit = triple_limit self.max_subgraph_size = max_subgraph_size self.max_path_length = 
max_path_length + self.track_usage = track_usage async def extract_concepts(self, query): """Extract key concepts from query for independent embedding.""" - response = await self.rag.prompt_client.prompt( + result = await self.rag.prompt_client.prompt( "extract-concepts", variables={"query": query} ) + if self.track_usage: + self.track_usage(result) concepts = [] - if isinstance(response, str): - for line in response.strip().split('\n'): + if result.text: + for line in result.text.strip().split('\n'): line = line.strip() if line: concepts.append(line) @@ -609,8 +612,24 @@ class GraphRag: save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian Returns: - str: The synthesized answer text + tuple: (answer_text, usage) where usage is a dict with + in_token, out_token, model """ + # Accumulate token usage across all prompt calls + total_in = 0 + total_out = 0 + last_model = None + + def track_usage(result): + nonlocal total_in, total_out, last_model + if result is not None: + if result.in_token is not None: + total_in += result.in_token + if result.out_token is not None: + total_out += result.out_token + if result.model is not None: + last_model = result.model + if self.verbose: logger.debug("Constructing prompt...") @@ -641,6 +660,7 @@ class GraphRag: triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, + track_usage = track_usage, ) kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query) @@ -751,21 +771,22 @@ class GraphRag: logger.debug(f"Built edge map with {len(edge_map)} edges") # Step 1a: Edge Scoring - LLM scores edges for relevance - scoring_response = await self.prompt_client.prompt( + scoring_result = await self.prompt_client.prompt( "kg-edge-scoring", variables={ "query": query, "knowledge": edges_with_ids } ) + track_usage(scoring_result) if self.verbose: - logger.debug(f"Edge scoring response: {scoring_response}") + logger.debug(f"Edge scoring result: 
{scoring_result}") - # Parse scoring response to get edge IDs with scores + # Parse scoring response (jsonl) to get edge IDs with scores scored_edges = [] - def parse_scored_edge(obj): + for obj in scoring_result.objects or []: if isinstance(obj, dict) and "id" in obj and "score" in obj: try: score = int(obj["score"]) @@ -773,21 +794,6 @@ class GraphRag: score = 0 scored_edges.append({"id": obj["id"], "score": score}) - if isinstance(scoring_response, list): - for obj in scoring_response: - parse_scored_edge(obj) - elif isinstance(scoring_response, str): - for line in scoring_response.strip().split('\n'): - line = line.strip() - if not line: - continue - try: - parse_scored_edge(json.loads(line)) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse edge scoring line: {line}" - ) - # Select top N edges by score scored_edges.sort(key=lambda x: x["score"], reverse=True) top_edges = scored_edges[:edge_limit] @@ -821,25 +827,30 @@ class GraphRag: ] # Run reasoning and document tracing concurrently - reasoning_task = self.prompt_client.prompt( - "kg-edge-reasoning", - variables={ - "query": query, - "knowledge": selected_edges_with_ids - } - ) + async def _get_reasoning(): + result = await self.prompt_client.prompt( + "kg-edge-reasoning", + variables={ + "query": query, + "knowledge": selected_edges_with_ids + } + ) + track_usage(result) + return result + + reasoning_task = _get_reasoning() doc_trace_task = q.trace_source_documents(selected_edge_uris) - reasoning_response, source_documents = await asyncio.gather( + reasoning_result, source_documents = await asyncio.gather( reasoning_task, doc_trace_task, return_exceptions=True ) # Handle exceptions from gather - if isinstance(reasoning_response, Exception): + if isinstance(reasoning_result, Exception): logger.warning( - f"Edge reasoning failed: {reasoning_response}" + f"Edge reasoning failed: {reasoning_result}" ) - reasoning_response = "" + reasoning_result = None if isinstance(source_documents, 
Exception): logger.warning( f"Document tracing failed: {source_documents}" @@ -848,29 +859,15 @@ class GraphRag: if self.verbose: - logger.debug(f"Edge reasoning response: {reasoning_response}") + logger.debug(f"Edge reasoning result: {reasoning_result}") - # Parse reasoning response and build explainability data + # Parse reasoning response (jsonl) and build explainability data reasoning_map = {} - def parse_reasoning(obj): - if isinstance(obj, dict) and "id" in obj: - reasoning_map[obj["id"]] = obj.get("reasoning", "") - - if isinstance(reasoning_response, list): - for obj in reasoning_response: - parse_reasoning(obj) - elif isinstance(reasoning_response, str): - for line in reasoning_response.strip().split('\n'): - line = line.strip() - if not line: - continue - try: - parse_reasoning(json.loads(line)) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse edge reasoning line: {line}" - ) + if reasoning_result is not None: + for obj in reasoning_result.objects or []: + if isinstance(obj, dict) and "id" in obj: + reasoning_map[obj["id"]] = obj.get("reasoning", "") selected_edges_with_reasoning = [] for eid in selected_ids: @@ -919,19 +916,22 @@ class GraphRag: accumulated_chunks.append(chunk) await chunk_callback(chunk, end_of_stream) - await self.prompt_client.prompt( + synthesis_result = await self.prompt_client.prompt( "kg-synthesis", variables=synthesis_variables, streaming=True, chunk_callback=accumulating_callback ) + track_usage(synthesis_result) # Combine all chunks into full response resp = "".join(accumulated_chunks) else: - resp = await self.prompt_client.prompt( + synthesis_result = await self.prompt_client.prompt( "kg-synthesis", variables=synthesis_variables, ) + track_usage(synthesis_result) + resp = synthesis_result.text if self.verbose: logger.debug("Query processing complete") @@ -964,5 +964,11 @@ class GraphRag: if self.verbose: logger.debug(f"Emitted explain for session {session_id}") - return resp + usage = { + "in_token": 
total_in if total_in > 0 else None, + "out_token": total_out if total_out > 0 else None, + "model": last_model, + } + + return resp, usage diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index abf10e90..15c30ba1 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -332,7 +332,7 @@ class Processor(FlowProcessor): ) # Query with streaming and real-time explain - response = await rag.query( + response, usage = await rag.query( query = v.query, user = v.user, collection = v.collection, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, @@ -348,7 +348,7 @@ class Processor(FlowProcessor): else: # Non-streaming path with real-time explain - response = await rag.query( + response, usage = await rag.query( query = v.query, user = v.user, collection = v.collection, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, @@ -360,23 +360,30 @@ class Processor(FlowProcessor): parent_uri = v.parent_uri, ) - # Send chunk with response + # Send single response with answer and token usage await flow("response").send( GraphRagResponse( message_type="chunk", response=response, end_of_stream=True, - error=None, + end_of_session=True, + in_token=usage.get("in_token"), + out_token=usage.get("out_token"), + model=usage.get("model"), ), properties={"id": id} ) + return - # Send final message to close session + # Streaming: send final message to close session with token usage await flow("response").send( GraphRagResponse( message_type="chunk", response="", end_of_session=True, + in_token=usage.get("in_token"), + out_token=usage.get("out_token"), + model=usage.get("model"), ), properties={"id": id} )