mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Expose LLM token usage across all service layers (#782)
Expose LLM token usage (in_token, out_token, model) across all service layers Propagate token counts from LLM services through the prompt, text-completion, graph-RAG, document-RAG, and agent orchestrator pipelines to the API gateway and Python SDK. All fields are Optional — None means "not available", distinguishing from a real zero count. Key changes: - Schema: Add in_token/out_token/model to TextCompletionResponse, PromptResponse, GraphRagResponse, DocumentRagResponse, AgentResponse - TextCompletionClient: New TextCompletionResult return type. Split into text_completion() (non-streaming) and text_completion_stream() (streaming with per-chunk handler callback) - PromptClient: New PromptResult with response_type (text/json/jsonl), typed fields (text/object/objects), and token usage. All callers updated. - RAG services: Accumulate token usage across all prompt calls (extract-concepts, edge-scoring, edge-reasoning, synthesis). Non-streaming path sends single combined response instead of chunk + end_of_session. - Agent orchestrator: UsageTracker accumulates tokens across meta-router, pattern prompt calls, and react reasoning. Attached to end_of_dialog. - Translators: Encode token fields when not None (is not None, not truthy) - Python SDK: RAG and text-completion methods return TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with token fields (streaming) - CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt, tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent
This commit is contained in:
parent
67cfa80836
commit
14e49d83c7
60 changed files with 1252 additions and 577 deletions
|
|
@ -15,6 +15,7 @@ from trustgraph.agent.react.agent_manager import AgentManager
|
|||
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
|
||||
from trustgraph.agent.react.types import Action, Final, Tool, Argument
|
||||
from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -28,11 +29,14 @@ class TestAgentManagerIntegration:
|
|||
|
||||
# Mock prompt client
|
||||
prompt_client = AsyncMock()
|
||||
prompt_client.agent_react.return_value = """Thought: I need to search for information about machine learning
|
||||
prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search for information about machine learning
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is machine learning?"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Mock graph RAG client
|
||||
graph_rag_client = AsyncMock()
|
||||
|
|
@ -40,7 +44,10 @@ Args: {
|
|||
|
||||
# Mock text completion client
|
||||
text_completion_client = AsyncMock()
|
||||
text_completion_client.question.return_value = "Machine learning involves algorithms that improve through experience."
|
||||
text_completion_client.question.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="Machine learning involves algorithms that improve through experience."
|
||||
)
|
||||
|
||||
# Mock MCP tool client
|
||||
mcp_tool_client = AsyncMock()
|
||||
|
|
@ -147,8 +154,11 @@ Args: {
|
|||
async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager returning final answer"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have enough information to answer the question
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I have enough information to answer the question
|
||||
Final Answer: Machine learning is a field of AI that enables computers to learn from data."""
|
||||
)
|
||||
|
||||
question = "What is machine learning?"
|
||||
history = []
|
||||
|
|
@ -193,8 +203,11 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
|
|||
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
|
||||
"""Test ReAct cycle ending with final answer"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide a direct answer
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I can provide a direct answer
|
||||
Final Answer: Machine learning is a branch of artificial intelligence."""
|
||||
)
|
||||
|
||||
question = "What is machine learning?"
|
||||
history = []
|
||||
|
|
@ -254,11 +267,14 @@ Final Answer: Machine learning is a branch of artificial intelligence."""
|
|||
|
||||
for tool_name, expected_service in tool_scenarios:
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: I need to use {tool_name}
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=f"""Thought: I need to use {tool_name}
|
||||
Action: {tool_name}
|
||||
Args: {{
|
||||
"question": "test question"
|
||||
}}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -284,11 +300,14 @@ Args: {{
|
|||
async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager error handling for unknown tool"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to use an unknown tool
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to use an unknown tool
|
||||
Action: unknown_tool
|
||||
Args: {
|
||||
"param": "value"
|
||||
}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -321,11 +340,14 @@ Args: {
|
|||
question = "Find information about AI and summarize it"
|
||||
|
||||
# Mock multi-step reasoning
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search for AI information first
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search for AI information first
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is artificial intelligence?"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Act
|
||||
action = await agent_manager.reason(question, [], mock_flow_context)
|
||||
|
|
@ -372,9 +394,12 @@ Args: {
|
|||
# Format arguments as JSON
|
||||
import json
|
||||
args_json = json.dumps(test_case['arguments'], indent=4)
|
||||
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: Using {test_case['action']}
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=f"""Thought: Using {test_case['action']}
|
||||
Action: {test_case['action']}
|
||||
Args: {args_json}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -507,7 +532,10 @@ Args: {
|
|||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
mock_flow_context("prompt-request").agent_react.return_value = test_case["response"]
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=test_case["response"]
|
||||
)
|
||||
|
||||
if test_case["error_contains"]:
|
||||
# Should raise an error
|
||||
|
|
@ -527,13 +555,16 @@ Args: {
|
|||
async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context):
|
||||
"""Test edge cases in text parsing"""
|
||||
# Test response with markdown code blocks
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """```
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""```
|
||||
Thought: I need to search for information
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is AI?"
|
||||
}
|
||||
```"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -541,7 +572,9 @@ Args: {
|
|||
assert action.name == "knowledge_query"
|
||||
|
||||
# Test response with extra whitespace
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""
|
||||
|
||||
Thought: I need to think about this
|
||||
Action: knowledge_query
|
||||
|
|
@ -550,6 +583,7 @@ Args: {
|
|||
}
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -560,7 +594,9 @@ Args: {
|
|||
async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context):
|
||||
"""Test handling of multi-line thoughts and final answers"""
|
||||
# Multi-line thought
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to consider multiple factors:
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to consider multiple factors:
|
||||
1. The user's question is complex
|
||||
2. I should search for comprehensive information
|
||||
3. This requires using the knowledge query tool
|
||||
|
|
@ -568,6 +604,7 @@ Action: knowledge_query
|
|||
Args: {
|
||||
"question": "complex query"
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -575,13 +612,16 @@ Args: {
|
|||
assert "knowledge query tool" in action.thought
|
||||
|
||||
# Multi-line final answer
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have gathered enough information
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I have gathered enough information
|
||||
Final Answer: Here is a comprehensive answer:
|
||||
1. First point about the topic
|
||||
2. Second point with details
|
||||
3. Final conclusion
|
||||
|
||||
This covers all aspects of the question."""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Final)
|
||||
|
|
@ -593,13 +633,16 @@ This covers all aspects of the question."""
|
|||
async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context):
|
||||
"""Test JSON arguments with special characters and edge cases"""
|
||||
# Test with special characters in JSON (properly escaped)
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Processing special characters
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: Processing special characters
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What about \\"quotes\\" and 'apostrophes'?",
|
||||
"context": "Line 1\\nLine 2\\tTabbed",
|
||||
"special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?"
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -608,7 +651,9 @@ Args: {
|
|||
assert "@#$%^&*" in action.arguments["special"]
|
||||
|
||||
# Test with nested JSON
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Complex arguments
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: Complex arguments
|
||||
Action: web_search
|
||||
Args: {
|
||||
"query": "test",
|
||||
|
|
@ -621,6 +666,7 @@ Args: {
|
|||
}
|
||||
}
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -632,7 +678,9 @@ Args: {
|
|||
async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context):
|
||||
"""Test final answers that contain JSON-like content"""
|
||||
# Final answer with JSON content
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide the data in JSON format
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I can provide the data in JSON format
|
||||
Final Answer: {
|
||||
"result": "success",
|
||||
"data": {
|
||||
|
|
@ -642,6 +690,7 @@ Final Answer: {
|
|||
},
|
||||
"confidence": 0.95
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Final)
|
||||
|
|
@ -792,11 +841,14 @@ Final Answer: {
|
|||
agent = AgentManager(tools=custom_tools, additional_context="")
|
||||
|
||||
# Mock response for custom collection query
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search in the research papers
|
||||
Action: knowledge_query_custom
|
||||
Args: {
|
||||
"question": "Latest AI research?"
|
||||
}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.agent.react.agent_manager import AgentManager
|
||||
from trustgraph.agent.react.tools import KnowledgeQueryImpl
|
||||
from trustgraph.agent.react.types import Tool, Argument
|
||||
from trustgraph.base import PromptResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_agent_streaming_chunks,
|
||||
assert_streaming_chunks_valid,
|
||||
|
|
@ -51,10 +52,10 @@ Args: {
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
# Non-streaming response - same text
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
|
||||
client.agent_react.side_effect = agent_react_streaming
|
||||
return client
|
||||
|
|
@ -317,8 +318,8 @@ Final Answer: AI is the simulation of human intelligence in machines."""
|
|||
for i, chunk in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk + " ", is_final)
|
||||
return response
|
||||
return response
|
||||
return PromptResult(response_type="text", text=response)
|
||||
return PromptResult(response_type="text", text=response)
|
||||
|
||||
mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from trustgraph.schema import (
|
|||
Error
|
||||
)
|
||||
from trustgraph.agent.react.service import Processor
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -95,11 +96,14 @@ class TestAgentStructuredQueryIntegration:
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from New York using structured query
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to find customers from New York using structured query
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find all customers from New York"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -173,11 +177,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to query for a table that might not exist
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to query for a table that might not exist
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find data from a table that doesn't exist"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -250,11 +257,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from California first
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to find customers from California first
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find all customers from California"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -339,11 +349,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to query the sales data
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to query the sales data
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Query the sales data for recent transactions"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -447,11 +460,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to get customer information
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to get customer information
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Get customer information and format it nicely"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import ChunkMatch
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
# Sample chunk content for testing - maps chunk_id to content
|
||||
|
|
@ -61,11 +62,16 @@ class TestDocumentRagIntegration:
|
|||
def mock_prompt_client(self):
|
||||
"""Mock prompt client that generates realistic responses"""
|
||||
client = AsyncMock()
|
||||
client.document_prompt.return_value = (
|
||||
"Machine learning is a field of artificial intelligence that enables computers to learn "
|
||||
"and improve from experience without being explicitly programmed. It uses algorithms "
|
||||
"to find patterns in data and make predictions or decisions."
|
||||
client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=(
|
||||
"Machine learning is a field of artificial intelligence that enables computers to learn "
|
||||
"and improve from experience without being explicitly programmed. It uses algorithms "
|
||||
"to find patterns in data and make predictions or decisions."
|
||||
)
|
||||
)
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -119,6 +125,7 @@ class TestDocumentRagIntegration:
|
|||
)
|
||||
|
||||
# Verify final response
|
||||
result, usage = result
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert "machine learning" in result.lower()
|
||||
|
|
@ -131,7 +138,11 @@ class TestDocumentRagIntegration:
|
|||
"""Test DocumentRAG behavior when no documents are retrieved"""
|
||||
# Arrange
|
||||
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
|
||||
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="I couldn't find any relevant documents for your query."
|
||||
)
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -152,7 +163,8 @@ class TestDocumentRagIntegration:
|
|||
documents=[]
|
||||
)
|
||||
|
||||
assert result == "I couldn't find any relevant documents for your query."
|
||||
result_text, usage = result
|
||||
assert result_text == "I couldn't find any relevant documents for your query."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import ChunkMatch
|
||||
from trustgraph.base import PromptResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_callback_invoked,
|
||||
|
|
@ -74,12 +75,14 @@ class TestDocumentRagStreaming:
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
# Non-streaming response - same text
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
|
||||
client.document_prompt.side_effect = document_prompt_side_effect
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -119,11 +122,12 @@ class TestDocumentRagStreaming:
|
|||
collector.verify_streaming_protocol()
|
||||
|
||||
# Verify full response matches concatenated chunks
|
||||
result_text, usage = result
|
||||
full_from_chunks = collector.get_full_text()
|
||||
assert result == full_from_chunks
|
||||
assert result_text == full_from_chunks
|
||||
|
||||
# Verify content is reasonable
|
||||
assert len(result) > 0
|
||||
assert len(result_text) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming):
|
||||
|
|
@ -159,9 +163,11 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Results should be equivalent
|
||||
assert streaming_result == non_streaming_result
|
||||
non_streaming_text, _ = non_streaming_result
|
||||
streaming_text, _ = streaming_result
|
||||
assert streaming_text == non_streaming_text
|
||||
assert len(streaming_chunks) > 0
|
||||
assert "".join(streaming_chunks) == streaming_result
|
||||
assert "".join(streaming_chunks) == streaming_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming):
|
||||
|
|
@ -180,8 +186,9 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
result_text, usage = result
|
||||
assert callback.call_count > 0
|
||||
assert result is not None
|
||||
assert result_text is not None
|
||||
|
||||
# Verify all callback invocations had string arguments
|
||||
for call in callback.call_args_list:
|
||||
|
|
@ -202,7 +209,8 @@ class TestDocumentRagStreaming:
|
|||
|
||||
# Assert - Should complete without error
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
result_text, usage = result
|
||||
assert isinstance(result_text, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming,
|
||||
|
|
@ -223,7 +231,8 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Should still produce streamed response
|
||||
assert result is not None
|
||||
result_text, usage = result
|
||||
assert result_text is not None
|
||||
assert callback.call_count > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -271,7 +280,8 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
result_text, usage = result
|
||||
assert result_text is not None
|
||||
assert callback.call_count > 0
|
||||
|
||||
# Verify doc_limit was passed correctly
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.schema import EntityMatch, Term, IRI
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -93,18 +94,21 @@ class TestGraphRagIntegration:
|
|||
# 4. kg-synthesis returns the final answer
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return "" # No edges scored
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return "" # No reasoning
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return (
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=(
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
)
|
||||
)
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = mock_prompt
|
||||
return client
|
||||
|
|
@ -169,6 +173,7 @@ class TestGraphRagIntegration:
|
|||
assert mock_prompt_client.prompt.call_count == 4
|
||||
|
||||
# Verify final response
|
||||
response, usage = response
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
assert "machine learning" in response.lower()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.schema import EntityMatch, Term, IRI
|
||||
from trustgraph.base import PromptResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_rag_streaming_chunks,
|
||||
|
|
@ -61,12 +62,12 @@ class TestGraphRagStreaming:
|
|||
|
||||
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
|
||||
if prompt_id == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_id == "kg-edge-scoring":
|
||||
# Edge scoring returns JSONL with IDs and scores
|
||||
return '{"id": "abc12345", "score": 0.9}\n'
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
|
||||
elif prompt_id == "kg-edge-reasoning":
|
||||
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
|
||||
elif prompt_id == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate streaming chunks with end_of_stream flags
|
||||
|
|
@ -79,10 +80,10 @@ class TestGraphRagStreaming:
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
return full_text
|
||||
return ""
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
|
@ -123,6 +124,7 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
response, usage = response
|
||||
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
|
||||
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
|
||||
|
||||
|
|
@ -172,9 +174,11 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Results should be equivalent
|
||||
assert streaming_response == non_streaming_response
|
||||
non_streaming_text, _ = non_streaming_response
|
||||
streaming_text, _ = streaming_response
|
||||
assert streaming_text == non_streaming_text
|
||||
assert len(streaming_chunks) > 0
|
||||
assert "".join(streaming_chunks) == streaming_response
|
||||
assert "".join(streaming_chunks) == streaming_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
|
||||
|
|
@ -213,7 +217,8 @@ class TestGraphRagStreaming:
|
|||
|
||||
# Assert - Should complete without error
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
response_text, usage = response
|
||||
assert isinstance(response_text, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProces
|
|||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -31,32 +32,38 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
|
||||
# Mock prompt client for definitions extraction
|
||||
prompt_client = AsyncMock()
|
||||
prompt_client.extract_definitions.return_value = [
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
]
|
||||
prompt_client.extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Mock prompt client for relationships extraction
|
||||
prompt_client.extract_relationships.return_value = [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "is_used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
prompt_client.extract_relationships.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "is_used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Mock producers for output streams
|
||||
triples_producer = AsyncMock()
|
||||
|
|
@ -489,7 +496,10 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
"""Test handling of empty extraction results"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = []
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
|
|
@ -510,7 +520,10 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
"""Test handling of invalid extraction response format"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format" # Should be list
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="invalid format"
|
||||
) # Should be jsonl with objects list
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
|
|
@ -528,13 +541,16 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context):
|
||||
"""Test entity filtering and validation in extraction"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = [
|
||||
{"entity": "Valid Entity", "definition": "Valid definition"},
|
||||
{"entity": "", "definition": "Empty entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
|
||||
{"entity": None, "definition": "None entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
|
||||
]
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"entity": "Valid Entity", "definition": "Valid definition"},
|
||||
{"entity": "", "definition": "Empty entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
|
||||
{"entity": None, "definition": "None entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
|
||||
]
|
||||
)
|
||||
|
||||
sample_chunk = Chunk(
|
||||
metadata=Metadata(id="test", user="user", collection="collection"),
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from trustgraph.schema import (
|
|||
Chunk, ExtractedObject, Metadata, RowSchema, Field,
|
||||
PromptRequest, PromptResponse
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -114,49 +115,61 @@ class TestObjectExtractionServiceIntegration:
|
|||
schema_name = schema.get("name") if isinstance(schema, dict) else schema.name
|
||||
if schema_name == "customer_records":
|
||||
if "john" in text.lower():
|
||||
return [
|
||||
{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Smith",
|
||||
"email": "john.smith@email.com",
|
||||
"phone": "555-0123"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Smith",
|
||||
"email": "john.smith@email.com",
|
||||
"phone": "555-0123"
|
||||
}
|
||||
]
|
||||
)
|
||||
elif "jane" in text.lower():
|
||||
return [
|
||||
{
|
||||
"customer_id": "CUST002",
|
||||
"name": "Jane Doe",
|
||||
"email": "jane.doe@email.com",
|
||||
"phone": ""
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"customer_id": "CUST002",
|
||||
"name": "Jane Doe",
|
||||
"email": "jane.doe@email.com",
|
||||
"phone": ""
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
return []
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
elif schema_name == "product_catalog":
|
||||
if "laptop" in text.lower():
|
||||
return [
|
||||
{
|
||||
"product_id": "PROD001",
|
||||
"name": "Gaming Laptop",
|
||||
"price": "1299.99",
|
||||
"category": "electronics"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"product_id": "PROD001",
|
||||
"name": "Gaming Laptop",
|
||||
"price": "1299.99",
|
||||
"category": "electronics"
|
||||
}
|
||||
]
|
||||
)
|
||||
elif "book" in text.lower():
|
||||
return [
|
||||
{
|
||||
"product_id": "PROD002",
|
||||
"name": "Python Programming Guide",
|
||||
"price": "49.99",
|
||||
"category": "books"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"product_id": "PROD002",
|
||||
"name": "Python Programming Guide",
|
||||
"price": "49.99",
|
||||
"category": "books"
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
return []
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
return []
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
prompt_client.extract_objects.side_effect = mock_extract_objects
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.prompt.template.service import Processor
|
||||
from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse
|
||||
from trustgraph.base.text_completion_client import TextCompletionResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_callback_invoked,
|
||||
|
|
@ -27,34 +28,52 @@ class TestPromptStreaming:
|
|||
# Mock text completion client with streaming
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def streaming_request(request, recipient=None, timeout=600):
|
||||
"""Simulate streaming text completion"""
|
||||
if request.streaming and recipient:
|
||||
# Simulate streaming chunks
|
||||
chunks = [
|
||||
"Machine", " learning", " is", " a", " field",
|
||||
" of", " artificial", " intelligence", "."
|
||||
]
|
||||
# Streaming chunks to send
|
||||
chunks = [
|
||||
"Machine", " learning", " is", " a", " field",
|
||||
" of", " artificial", " intelligence", "."
|
||||
]
|
||||
|
||||
for i, chunk_text in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
response = TextCompletionResponse(
|
||||
response=chunk_text,
|
||||
error=None,
|
||||
end_of_stream=is_final
|
||||
)
|
||||
final = await recipient(response)
|
||||
if final:
|
||||
break
|
||||
|
||||
# Final empty chunk
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
async def streaming_text_completion_stream(system, prompt, handler, timeout=600):
|
||||
"""Simulate streaming text completion via text_completion_stream"""
|
||||
for i, chunk_text in enumerate(chunks):
|
||||
response = TextCompletionResponse(
|
||||
response=chunk_text,
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
end_of_stream=False
|
||||
)
|
||||
await handler(response)
|
||||
|
||||
text_completion_client.request = streaming_request
|
||||
# Send final empty chunk with end_of_stream
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
|
||||
return TextCompletionResult(
|
||||
text=None,
|
||||
in_token=10,
|
||||
out_token=9,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
async def non_streaming_text_completion(system, prompt, timeout=600):
|
||||
"""Simulate non-streaming text completion"""
|
||||
full_text = "Machine learning is a field of artificial intelligence."
|
||||
return TextCompletionResult(
|
||||
text=full_text,
|
||||
in_token=10,
|
||||
out_token=9,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=streaming_text_completion_stream
|
||||
)
|
||||
text_completion_client.text_completion = AsyncMock(
|
||||
side_effect=non_streaming_text_completion
|
||||
)
|
||||
|
||||
# Mock response producer
|
||||
response_producer = AsyncMock()
|
||||
|
|
@ -156,14 +175,6 @@ class TestPromptStreaming:
|
|||
|
||||
consumer = MagicMock()
|
||||
|
||||
# Mock non-streaming text completion
|
||||
text_completion_client = mock_flow_context_streaming("text-completion-request")
|
||||
|
||||
async def non_streaming_text_completion(system, prompt, streaming=False):
|
||||
return "AI is the simulation of human intelligence in machines."
|
||||
|
||||
text_completion_client.text_completion = non_streaming_text_completion
|
||||
|
||||
# Act
|
||||
await prompt_processor_streaming.on_request(
|
||||
message, consumer, mock_flow_context_streaming
|
||||
|
|
@ -218,17 +229,12 @@ class TestPromptStreaming:
|
|||
# Mock text completion client that raises an error
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def failing_request(request, recipient=None, timeout=600):
|
||||
if recipient:
|
||||
# Send error response with proper Error schema
|
||||
error_response = TextCompletionResponse(
|
||||
response="",
|
||||
error=Error(message="Text completion error", type="processing_error"),
|
||||
end_of_stream=True
|
||||
)
|
||||
await recipient(error_response)
|
||||
async def failing_stream(system, prompt, handler, timeout=600):
|
||||
raise RuntimeError("Text completion error")
|
||||
|
||||
text_completion_client.request = failing_request
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=failing_stream
|
||||
)
|
||||
|
||||
# Mock response producer to capture error response
|
||||
response_producer = AsyncMock()
|
||||
|
|
@ -255,22 +261,15 @@ class TestPromptStreaming:
|
|||
|
||||
consumer = MagicMock()
|
||||
|
||||
# Act - The service catches errors and sends error responses, doesn't raise
|
||||
# Act - The service catches errors and sends an error PromptResponse
|
||||
await prompt_processor_streaming.on_request(message, consumer, context)
|
||||
|
||||
# Assert - Verify error response was sent
|
||||
assert response_producer.send.call_count > 0
|
||||
|
||||
# Check that at least one response contains an error
|
||||
error_sent = False
|
||||
for call in response_producer.send.call_args_list:
|
||||
response = call.args[0]
|
||||
if hasattr(response, 'error') and response.error:
|
||||
error_sent = True
|
||||
assert "Text completion error" in response.error.message
|
||||
break
|
||||
|
||||
assert error_sent, "Expected error response to be sent"
|
||||
# Assert - error response was sent
|
||||
calls = response_producer.send.call_args_list
|
||||
assert len(calls) > 0
|
||||
error_response = calls[-1].args[0]
|
||||
assert error_response.error is not None
|
||||
assert "Text completion error" in error_response.error.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming,
|
||||
|
|
@ -315,21 +314,22 @@ class TestPromptStreaming:
|
|||
# Mock text completion that sends empty chunks
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def empty_streaming_request(request, recipient=None, timeout=600):
|
||||
if request.streaming and recipient:
|
||||
# Send empty chunk followed by final marker
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=False
|
||||
))
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
async def empty_streaming(system, prompt, handler, timeout=600):
|
||||
# Send empty chunk followed by final marker
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=False
|
||||
))
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
|
||||
text_completion_client.request = empty_streaming_request
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=empty_streaming
|
||||
)
|
||||
response_producer = AsyncMock()
|
||||
|
||||
def context_router(service_name):
|
||||
|
|
@ -401,4 +401,4 @@ class TestPromptStreaming:
|
|||
|
||||
# Verify chunks concatenate to expected result
|
||||
full_text = "".join(chunk_texts)
|
||||
assert full_text == "Machine learning is a field of artificial intelligence"
|
||||
assert full_text == "Machine learning is a field of artificial intelligence."
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, call
|
|||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
class TestGraphRagStreamingProtocol:
|
||||
|
|
@ -46,8 +47,7 @@ class TestGraphRagStreamingProtocol:
|
|||
|
||||
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
# Edge selection returns empty (no edges selected)
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
|
||||
|
|
@ -55,10 +55,10 @@ class TestGraphRagStreamingProtocol:
|
|||
await chunk_callback(" answer", False)
|
||||
await chunk_callback(" is here.", False)
|
||||
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
|
||||
return "" # Return value not used since callback handles everything
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "The answer is here."
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="The answer is here.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
|
@ -237,11 +237,13 @@ class TestDocumentRagStreamingProtocol:
|
|||
await chunk_callback("Document", False)
|
||||
await chunk_callback(" summary", False)
|
||||
await chunk_callback(".", True) # Non-empty final chunk
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "Document summary."
|
||||
return PromptResult(response_type="text", text="Document summary.")
|
||||
|
||||
client.document_prompt.side_effect = document_prompt_side_effect
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -334,17 +336,17 @@ class TestStreamingProtocolEdgeCases:
|
|||
|
||||
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
await chunk_callback("text", False)
|
||||
await chunk_callback("", False) # Empty but not final
|
||||
await chunk_callback("more", False)
|
||||
await chunk_callback("", True) # Empty and final
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "textmore"
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="textmore")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_with_empties
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.agent.orchestrator.meta_router import (
|
||||
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
def _make_config(patterns=None, task_types=None):
|
||||
|
|
@ -28,7 +29,9 @@ def _make_config(patterns=None, task_types=None):
|
|||
def _make_context(prompt_response):
|
||||
"""Build a mock context that returns a mock prompt client."""
|
||||
client = AsyncMock()
|
||||
client.prompt = AsyncMock(return_value=prompt_response)
|
||||
client.prompt = AsyncMock(
|
||||
return_value=PromptResult(response_type="text", text=prompt_response)
|
||||
)
|
||||
|
||||
def context(service_name):
|
||||
return client
|
||||
|
|
@ -274,8 +277,8 @@ class TestRoute:
|
|||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return "research" # task type
|
||||
return "plan-then-execute" # pattern
|
||||
return PromptResult(response_type="text", text="research")
|
||||
return PromptResult(response_type="text", text="plan-then-execute")
|
||||
|
||||
client.prompt = mock_prompt
|
||||
context = lambda name: client
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from dataclasses import dataclass, field
|
|||
from trustgraph.schema import (
|
||||
AgentRequest, AgentResponse, AgentStep, PlanStep,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
|
|
@ -183,7 +184,7 @@ class TestReactPatternProvenance:
|
|||
)
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
# Simulate the on_action callback before returning Final
|
||||
if on_action:
|
||||
await on_action(Action(
|
||||
|
|
@ -267,7 +268,7 @@ class TestReactPatternProvenance:
|
|||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
if on_action:
|
||||
await on_action(action)
|
||||
return action
|
||||
|
|
@ -309,7 +310,7 @@ class TestReactPatternProvenance:
|
|||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
if on_action:
|
||||
await on_action(Action(
|
||||
thought="done", name="final",
|
||||
|
|
@ -355,10 +356,13 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt client for plan creation
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = [
|
||||
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
|
||||
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
|
||||
]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
|
||||
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
|
||||
],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -418,10 +422,13 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt for step execution
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = {
|
||||
"tool": "knowledge-query",
|
||||
"arguments": {"question": "quantum computing"},
|
||||
}
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="json",
|
||||
object={
|
||||
"tool": "knowledge-query",
|
||||
"arguments": {"question": "quantum computing"},
|
||||
},
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -475,7 +482,7 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt for synthesis
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = "The synthesised answer."
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The synthesised answer.")
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -542,10 +549,13 @@ class TestSupervisorPatternProvenance:
|
|||
|
||||
# Mock prompt for decomposition
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = [
|
||||
"What is quantum computing?",
|
||||
"What are qubits?",
|
||||
]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
"What is quantum computing?",
|
||||
"What are qubits?",
|
||||
],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -590,7 +600,7 @@ class TestSupervisorPatternProvenance:
|
|||
|
||||
# Mock prompt for synthesis
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = "The combined answer."
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The combined answer.")
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -639,7 +649,10 @@ class TestSupervisorPatternProvenance:
|
|||
flow = make_mock_flow()
|
||||
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=["Goal A", "Goal B", "Goal C"],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.extract.kg.definitions.extract import (
|
||||
Processor, default_triples_batch_size, default_entity_batch_size,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
from trustgraph.schema import (
|
||||
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
|
||||
)
|
||||
|
|
@ -51,8 +52,12 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
|
|||
mock_triples_pub = AsyncMock()
|
||||
mock_ecs_pub = AsyncMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
if isinstance(prompt_result, list):
|
||||
wrapped = PromptResult(response_type="jsonl", objects=prompt_result)
|
||||
else:
|
||||
wrapped = PromptResult(response_type="text", text=prompt_result)
|
||||
mock_prompt_client.extract_definitions = AsyncMock(
|
||||
return_value=prompt_result
|
||||
return_value=wrapped
|
||||
)
|
||||
|
||||
def flow(name):
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from trustgraph.extract.kg.relationships.extract import (
|
|||
from trustgraph.schema import (
|
||||
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -58,7 +59,10 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
|
|||
mock_triples_pub = AsyncMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.extract_relationships = AsyncMock(
|
||||
return_value=prompt_result
|
||||
return_value=PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=prompt_result,
|
||||
)
|
||||
)
|
||||
|
||||
def flow(name):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
# Sample chunk content mapping for tests
|
||||
|
|
@ -132,7 +133,7 @@ class TestQuery:
|
|||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
# Mock the prompt response with concept lines
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence\ndata patterns")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -157,7 +158,7 @@ class TestQuery:
|
|||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
# Mock empty response
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -258,7 +259,7 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "test concept"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="test concept")
|
||||
|
||||
# Mock embeddings - one vector per concept
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
|
|
@ -273,7 +274,7 @@ class TestQuery:
|
|||
expected_response = "This is the document RAG response"
|
||||
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
||||
mock_prompt_client.document_prompt.return_value = expected_response
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=expected_response)
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -315,7 +316,8 @@ class TestQuery:
|
|||
assert "Relevant document content" in docs
|
||||
assert "Another document" in docs
|
||||
|
||||
assert result == expected_response
|
||||
result_text, usage = result
|
||||
assert result_text == expected_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
|
||||
|
|
@ -325,7 +327,7 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction fallback (empty → raw query)
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
|
||||
|
|
@ -333,7 +335,7 @@ class TestQuery:
|
|||
mock_match.chunk_id = "doc/c5"
|
||||
mock_match.score = 0.9
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||
mock_prompt_client.document_prompt.return_value = "Default response"
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Default response")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -352,7 +354,8 @@ class TestQuery:
|
|||
collection="default" # Default collection
|
||||
)
|
||||
|
||||
assert result == "Default response"
|
||||
result_text, usage = result
|
||||
assert result_text == "Default response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_verbose_output(self):
|
||||
|
|
@ -401,7 +404,7 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "verbose query test"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="verbose query test")
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
|
||||
|
|
@ -409,7 +412,7 @@ class TestQuery:
|
|||
mock_match.chunk_id = "doc/c7"
|
||||
mock_match.score = 0.92
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Verbose RAG response")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -428,7 +431,8 @@ class TestQuery:
|
|||
assert call_args.kwargs["query"] == "verbose query test"
|
||||
assert "Verbose doc content" in call_args.kwargs["documents"]
|
||||
|
||||
assert result == "Verbose RAG response"
|
||||
result_text, usage = result
|
||||
assert result_text == "Verbose RAG response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_empty_results(self):
|
||||
|
|
@ -469,11 +473,11 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "query with no matching docs"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="query with no matching docs")
|
||||
|
||||
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
|
||||
mock_doc_embeddings_client.query.return_value = []
|
||||
mock_prompt_client.document_prompt.return_value = "No documents found response"
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="No documents found response")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -490,7 +494,8 @@ class TestQuery:
|
|||
documents=[]
|
||||
)
|
||||
|
||||
assert result == "No documents found response"
|
||||
result_text, usage = result
|
||||
assert result_text == "No documents found response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vectors_with_verbose(self):
|
||||
|
|
@ -525,7 +530,7 @@ class TestQuery:
|
|||
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence")
|
||||
|
||||
# Mock embeddings - one vector per concept
|
||||
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
|
||||
|
|
@ -541,7 +546,7 @@ class TestQuery:
|
|||
MagicMock(chunk_id="doc/ml3", score=0.82),
|
||||
]
|
||||
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
|
||||
mock_prompt_client.document_prompt.return_value = final_response
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=final_response)
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -584,7 +589,8 @@ class TestQuery:
|
|||
assert "Common ML techniques include supervised and unsupervised learning..." in docs
|
||||
assert len(docs) == 3 # doc/ml2 deduplicated
|
||||
|
||||
assert result == final_response
|
||||
result_text, usage = result
|
||||
assert result_text == final_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_deduplicates_across_concepts(self):
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from unittest.mock import AsyncMock
|
|||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
|
|
@ -89,8 +90,8 @@ def build_mock_clients():
|
|||
# 1. Concept extraction
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return "return policy\nrefund"
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="return policy\nrefund")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
|
|
@ -113,8 +114,9 @@ def build_mock_clients():
|
|||
fetch_chunk.side_effect = mock_fetch
|
||||
|
||||
# 5. Synthesis
|
||||
prompt_client.document_prompt.return_value = (
|
||||
"Items can be returned within 30 days for a full refund."
|
||||
prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="Items can be returned within 30 days for a full refund.",
|
||||
)
|
||||
|
||||
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
|
||||
|
|
@ -340,12 +342,12 @@ class TestDocumentRagQueryProvenance:
|
|||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
result = await rag.query(
|
||||
result_text, usage = await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=AsyncMock(),
|
||||
)
|
||||
|
||||
assert result == "Items can be returned within 30 days for a full refund."
|
||||
assert result_text == "Items can be returned within 30 days for a full refund."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_explain_callback_still_works(self):
|
||||
|
|
@ -353,8 +355,8 @@ class TestDocumentRagQueryProvenance:
|
|||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
result = await rag.query(query="What is the return policy?")
|
||||
assert result == "Items can be returned within 30 days for a full refund."
|
||||
result_text, usage = await rag.query(query="What is the return policy?")
|
||||
assert result_text == "Items can be returned within 30 days for a full refund."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class TestDocumentRagService:
|
|||
# Setup mock DocumentRag instance
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_document_rag_class.return_value = mock_rag_instance
|
||||
mock_rag_instance.query.return_value = "test response"
|
||||
mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None})
|
||||
|
||||
# Setup message with custom user/collection
|
||||
msg = MagicMock()
|
||||
|
|
@ -97,7 +97,7 @@ class TestDocumentRagService:
|
|||
# Setup mock DocumentRag instance
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_document_rag_class.return_value = mock_rag_instance
|
||||
mock_rag_instance.query.return_value = "A document about cats."
|
||||
mock_rag_instance.query.return_value = ("A document about cats.", {"in_token": None, "out_token": None, "model": None})
|
||||
|
||||
# Setup message with non-streaming request
|
||||
msg = MagicMock()
|
||||
|
|
@ -130,4 +130,5 @@ class TestDocumentRagService:
|
|||
assert isinstance(sent_response, DocumentRagResponse)
|
||||
assert sent_response.response == "A document about cats."
|
||||
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
|
||||
assert sent_response.end_of_session is True
|
||||
assert sent_response.error is None
|
||||
|
|
@ -7,6 +7,7 @@ import unittest.mock
|
|||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
class TestGraphRag:
|
||||
|
|
@ -172,7 +173,7 @@ class TestQuery:
|
|||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nneural networks\n")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -196,7 +197,7 @@ class TestQuery:
|
|||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -220,7 +221,7 @@ class TestQuery:
|
|||
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
|
||||
|
||||
# extract_concepts returns empty -> falls back to [query]
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
# embed returns one vector set for the single concept
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
|
|
@ -565,14 +566,14 @@ class TestQuery:
|
|||
# Mock prompt responses for the multi-step process
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return json.dumps({"id": test_edge_id, "score": 0.9})
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return json.dumps({"id": test_edge_id, "reasoning": "relevant"})
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return expected_response
|
||||
return ""
|
||||
return PromptResult(response_type="text", text=expected_response)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
mock_prompt_client.prompt = mock_prompt
|
||||
|
||||
|
|
@ -607,7 +608,8 @@ class TestQuery:
|
|||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
assert response == expected_response
|
||||
response_text, usage = response
|
||||
assert response_text == expected_response
|
||||
|
||||
# 5 events: question, grounding, exploration, focus, synthesis
|
||||
assert len(provenance_events) == 5
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from dataclasses import dataclass
|
|||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
|
||||
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
|
|
@ -136,24 +137,36 @@ def build_mock_clients():
|
|||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return prompt_responses["extract-concepts"]
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=prompt_responses["extract-concepts"],
|
||||
)
|
||||
elif template_id == "kg-edge-scoring":
|
||||
# Score all edges highly, using the IDs that GraphRag computed
|
||||
edges = variables.get("knowledge", [])
|
||||
return [
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
# Provide reasoning for each edge
|
||||
edges = variables.get("knowledge", [])
|
||||
return [
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return synthesis_answer
|
||||
return ""
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=synthesis_answer,
|
||||
)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
|
|
@ -413,13 +426,13 @@ class TestGraphRagQueryProvenance:
|
|||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
result = await rag.query(
|
||||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
assert result == "Quantum computing applies physics principles to computation."
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parent_uri_links_question_to_parent(self):
|
||||
|
|
@ -450,12 +463,12 @@ class TestGraphRagQueryProvenance:
|
|||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
result = await rag.query(
|
||||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
assert result == "Quantum computing applies physics principles to computation."
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class TestGraphRagService:
|
|||
await explain_callback([], "urn:trustgraph:prov:retrieval:test")
|
||||
await explain_callback([], "urn:trustgraph:prov:selection:test")
|
||||
await explain_callback([], "urn:trustgraph:prov:answer:test")
|
||||
return "A small domesticated mammal."
|
||||
return "A small domesticated mammal.", {"in_token": None, "out_token": None, "model": None}
|
||||
|
||||
mock_rag_instance.query.side_effect = mock_query
|
||||
|
||||
|
|
@ -79,8 +79,8 @@ class TestGraphRagService:
|
|||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session)
|
||||
assert mock_response_producer.send.call_count == 6
|
||||
# Verify: 5 messages sent (4 provenance + 1 combined chunk with end_of_session)
|
||||
assert mock_response_producer.send.call_count == 5
|
||||
|
||||
# First 4 messages are explain (emitted in real-time during query)
|
||||
for i in range(4):
|
||||
|
|
@ -88,17 +88,12 @@ class TestGraphRagService:
|
|||
assert prov_msg.message_type == "explain"
|
||||
assert prov_msg.explain_id is not None
|
||||
|
||||
# 5th message is chunk with response
|
||||
# 5th message is chunk with response and end_of_session
|
||||
chunk_msg = mock_response_producer.send.call_args_list[4][0][0]
|
||||
assert chunk_msg.message_type == "chunk"
|
||||
assert chunk_msg.response == "A small domesticated mammal."
|
||||
assert chunk_msg.end_of_stream is True
|
||||
|
||||
# 6th message is empty chunk with end_of_session=True
|
||||
close_msg = mock_response_producer.send.call_args_list[5][0][0]
|
||||
assert close_msg.message_type == "chunk"
|
||||
assert close_msg.response == ""
|
||||
assert close_msg.end_of_session is True
|
||||
assert chunk_msg.end_of_session is True
|
||||
|
||||
# Verify provenance triples were sent to provenance queue
|
||||
assert mock_provenance_producer.send.call_count == 4
|
||||
|
|
@ -187,7 +182,7 @@ class TestGraphRagService:
|
|||
|
||||
async def mock_query(**kwargs):
|
||||
# Don't call explain_callback
|
||||
return "Response text"
|
||||
return "Response text", {"in_token": None, "out_token": None, "model": None}
|
||||
|
||||
mock_rag_instance.query.side_effect = mock_query
|
||||
|
||||
|
|
@ -218,17 +213,12 @@ class TestGraphRagService:
|
|||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: 2 messages (chunk + empty chunk to close)
|
||||
assert mock_response_producer.send.call_count == 2
|
||||
# Verify: 1 combined message (chunk with end_of_session)
|
||||
assert mock_response_producer.send.call_count == 1
|
||||
|
||||
# First is the response chunk
|
||||
# Single message has response and end_of_session
|
||||
chunk_msg = mock_response_producer.send.call_args_list[0][0][0]
|
||||
assert chunk_msg.message_type == "chunk"
|
||||
assert chunk_msg.response == "Response text"
|
||||
assert chunk_msg.end_of_stream is True
|
||||
|
||||
# Second is empty chunk to close session
|
||||
close_msg = mock_response_producer.send.call_args_list[1][0][0]
|
||||
assert close_msg.message_type == "chunk"
|
||||
assert close_msg.response == ""
|
||||
assert close_msg.end_of_session is True
|
||||
assert chunk_msg.end_of_session is True
|
||||
|
|
|
|||
|
|
@ -107,6 +107,7 @@ from .types import (
|
|||
AgentObservation,
|
||||
AgentAnswer,
|
||||
RAGChunk,
|
||||
TextCompletionResult,
|
||||
ProvenanceEvent,
|
||||
)
|
||||
|
||||
|
|
@ -185,6 +186,7 @@ __all__ = [
|
|||
"AgentObservation",
|
||||
"AgentAnswer",
|
||||
"RAGChunk",
|
||||
"TextCompletionResult",
|
||||
"ProvenanceEvent",
|
||||
|
||||
# Exceptions
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ import aiohttp
|
|||
import json
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from . types import TextCompletionResult
|
||||
|
||||
from . exceptions import ProtocolException, ApplicationException
|
||||
|
||||
|
||||
|
|
@ -434,12 +436,11 @@ class AsyncFlowInstance:
|
|||
|
||||
return await self.request("agent", request_data)
|
||||
|
||||
async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> str:
|
||||
async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> TextCompletionResult:
|
||||
"""
|
||||
Generate text completion (non-streaming).
|
||||
|
||||
Generates a text response from an LLM given a system prompt and user prompt.
|
||||
Returns the complete response text.
|
||||
|
||||
Note: This method does not support streaming. For streaming text generation,
|
||||
use AsyncSocketFlowInstance.text_completion() instead.
|
||||
|
|
@ -450,19 +451,19 @@ class AsyncFlowInstance:
|
|||
**kwargs: Additional service-specific parameters
|
||||
|
||||
Returns:
|
||||
str: Complete generated text response
|
||||
TextCompletionResult: Result with text, in_token, out_token, model
|
||||
|
||||
Example:
|
||||
```python
|
||||
async_flow = await api.async_flow()
|
||||
flow = async_flow.id("default")
|
||||
|
||||
# Generate text
|
||||
response = await flow.text_completion(
|
||||
result = await flow.text_completion(
|
||||
system="You are a helpful assistant.",
|
||||
prompt="Explain quantum computing in simple terms."
|
||||
)
|
||||
print(response)
|
||||
print(result.text)
|
||||
print(f"Tokens: {result.in_token} in, {result.out_token} out")
|
||||
```
|
||||
"""
|
||||
request_data = {
|
||||
|
|
@ -473,7 +474,12 @@ class AsyncFlowInstance:
|
|||
request_data.update(kwargs)
|
||||
|
||||
result = await self.request("text-completion", request_data)
|
||||
return result.get("response", "")
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
async def graph_rag(self, query: str, user: str, collection: str,
|
||||
max_subgraph_size: int = 1000, max_subgraph_count: int = 5,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import asyncio
|
|||
import websockets
|
||||
from typing import Optional, Dict, Any, AsyncIterator, Union
|
||||
|
||||
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk
|
||||
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, TextCompletionResult
|
||||
from . exceptions import ProtocolException, ApplicationException
|
||||
|
||||
|
||||
|
|
@ -199,7 +199,10 @@ class AsyncSocketClient:
|
|||
return AgentAnswer(
|
||||
content=resp.get("content", ""),
|
||||
end_of_message=resp.get("end_of_message", False),
|
||||
end_of_dialog=resp.get("end_of_dialog", False)
|
||||
end_of_dialog=resp.get("end_of_dialog", False),
|
||||
in_token=resp.get("in_token"),
|
||||
out_token=resp.get("out_token"),
|
||||
model=resp.get("model"),
|
||||
)
|
||||
elif chunk_type == "action":
|
||||
return AgentThought(
|
||||
|
|
@ -211,7 +214,10 @@ class AsyncSocketClient:
|
|||
return RAGChunk(
|
||||
content=content,
|
||||
end_of_stream=resp.get("end_of_stream", False),
|
||||
error=None
|
||||
error=None,
|
||||
in_token=resp.get("in_token"),
|
||||
out_token=resp.get("out_token"),
|
||||
model=resp.get("model"),
|
||||
)
|
||||
|
||||
async def aclose(self):
|
||||
|
|
@ -269,7 +275,11 @@ class AsyncSocketFlowInstance:
|
|||
return await self.client._send_request("agent", self.flow_id, request)
|
||||
|
||||
async def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs):
|
||||
"""Text completion with optional streaming"""
|
||||
"""Text completion with optional streaming.
|
||||
|
||||
Non-streaming: returns a TextCompletionResult with text and token counts.
|
||||
Streaming: returns an async iterator of RAGChunk (with token counts on the final chunk).
|
||||
"""
|
||||
request = {
|
||||
"system": system,
|
||||
"prompt": prompt,
|
||||
|
|
@ -281,13 +291,18 @@ class AsyncSocketFlowInstance:
|
|||
return self._text_completion_streaming(request)
|
||||
else:
|
||||
result = await self.client._send_request("text-completion", self.flow_id, request)
|
||||
return result.get("response", "")
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
async def _text_completion_streaming(self, request):
|
||||
"""Helper for streaming text completion"""
|
||||
"""Helper for streaming text completion. Yields RAGChunk objects."""
|
||||
async for chunk in self.client._send_request_streaming("text-completion", self.flow_id, request):
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
if isinstance(chunk, RAGChunk):
|
||||
yield chunk
|
||||
|
||||
async def graph_rag(self, query: str, user: str, collection: str,
|
||||
max_subgraph_size: int = 1000, max_subgraph_count: int = 5,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import base64
|
|||
|
||||
from .. knowledge import hash, Uri, Literal, QuotedTriple
|
||||
from .. schema import IRI, LITERAL, TRIPLE
|
||||
from . types import Triple
|
||||
from . types import Triple, TextCompletionResult
|
||||
from . exceptions import ProtocolException
|
||||
|
||||
|
||||
|
|
@ -360,16 +360,17 @@ class FlowInstance:
|
|||
prompt: User prompt/question
|
||||
|
||||
Returns:
|
||||
str: Generated response text
|
||||
TextCompletionResult: Result with text, in_token, out_token, model
|
||||
|
||||
Example:
|
||||
```python
|
||||
flow = api.flow().id("default")
|
||||
response = flow.text_completion(
|
||||
result = flow.text_completion(
|
||||
system="You are a helpful assistant",
|
||||
prompt="What is quantum computing?"
|
||||
)
|
||||
print(response)
|
||||
print(result.text)
|
||||
print(f"Tokens: {result.in_token} in, {result.out_token} out")
|
||||
```
|
||||
"""
|
||||
|
||||
|
|
@ -379,10 +380,17 @@ class FlowInstance:
|
|||
"prompt": prompt
|
||||
}
|
||||
|
||||
return self.request(
|
||||
result = self.request(
|
||||
"service/text-completion",
|
||||
input
|
||||
)["response"]
|
||||
)
|
||||
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
def agent(self, question, user="trustgraph", state=None, group=None, history=None):
|
||||
"""
|
||||
|
|
@ -498,10 +506,17 @@ class FlowInstance:
|
|||
"edge-limit": edge_limit,
|
||||
}
|
||||
|
||||
return self.request(
|
||||
result = self.request(
|
||||
"service/graph-rag",
|
||||
input
|
||||
)["response"]
|
||||
)
|
||||
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
def document_rag(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
|
|
@ -543,10 +558,17 @@ class FlowInstance:
|
|||
"doc-limit": doc_limit,
|
||||
}
|
||||
|
||||
return self.request(
|
||||
result = self.request(
|
||||
"service/document-rag",
|
||||
input
|
||||
)["response"]
|
||||
)
|
||||
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
def embeddings(self, texts):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import websockets
|
|||
from typing import Optional, Dict, Any, Iterator, Union, List
|
||||
from threading import Lock
|
||||
|
||||
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent
|
||||
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent, TextCompletionResult
|
||||
from . exceptions import ProtocolException, raise_from_error_dict
|
||||
|
||||
|
||||
|
|
@ -393,6 +393,9 @@ class SocketClient:
|
|||
end_of_message=resp.get("end_of_message", False),
|
||||
end_of_dialog=resp.get("end_of_dialog", False),
|
||||
message_id=resp.get("message_id", ""),
|
||||
in_token=resp.get("in_token"),
|
||||
out_token=resp.get("out_token"),
|
||||
model=resp.get("model"),
|
||||
)
|
||||
elif chunk_type == "action":
|
||||
return AgentThought(
|
||||
|
|
@ -404,7 +407,10 @@ class SocketClient:
|
|||
return RAGChunk(
|
||||
content=content,
|
||||
end_of_stream=resp.get("end_of_stream", False),
|
||||
error=None
|
||||
error=None,
|
||||
in_token=resp.get("in_token"),
|
||||
out_token=resp.get("out_token"),
|
||||
model=resp.get("model"),
|
||||
)
|
||||
|
||||
def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent:
|
||||
|
|
@ -543,8 +549,12 @@ class SocketFlowInstance:
|
|||
streaming=True, include_provenance=True
|
||||
)
|
||||
|
||||
def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]:
|
||||
"""Execute text completion with optional streaming."""
|
||||
def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
||||
"""Execute text completion with optional streaming.
|
||||
|
||||
Non-streaming: returns a TextCompletionResult with text and token counts.
|
||||
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
|
||||
"""
|
||||
request = {
|
||||
"system": system,
|
||||
"prompt": prompt,
|
||||
|
|
@ -557,12 +567,17 @@ class SocketFlowInstance:
|
|||
if streaming:
|
||||
return self._text_completion_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
|
||||
def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]:
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
if isinstance(chunk, RAGChunk):
|
||||
yield chunk
|
||||
|
||||
def graph_rag(
|
||||
self,
|
||||
|
|
@ -577,8 +592,12 @@ class SocketFlowInstance:
|
|||
edge_limit: int = 25,
|
||||
streaming: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[str, Iterator[str]]:
|
||||
"""Execute graph-based RAG query with optional streaming."""
|
||||
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
||||
"""Execute graph-based RAG query with optional streaming.
|
||||
|
||||
Non-streaming: returns a TextCompletionResult with text and token counts.
|
||||
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
|
||||
"""
|
||||
request = {
|
||||
"query": query,
|
||||
"user": user,
|
||||
|
|
@ -598,7 +617,12 @@ class SocketFlowInstance:
|
|||
if streaming:
|
||||
return self._rag_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
def graph_rag_explain(
|
||||
self,
|
||||
|
|
@ -642,8 +666,12 @@ class SocketFlowInstance:
|
|||
doc_limit: int = 10,
|
||||
streaming: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[str, Iterator[str]]:
|
||||
"""Execute document-based RAG query with optional streaming."""
|
||||
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
||||
"""Execute document-based RAG query with optional streaming.
|
||||
|
||||
Non-streaming: returns a TextCompletionResult with text and token counts.
|
||||
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
|
||||
"""
|
||||
request = {
|
||||
"query": query,
|
||||
"user": user,
|
||||
|
|
@ -658,7 +686,12 @@ class SocketFlowInstance:
|
|||
if streaming:
|
||||
return self._rag_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
def document_rag_explain(
|
||||
self,
|
||||
|
|
@ -684,10 +717,10 @@ class SocketFlowInstance:
|
|||
streaming=True, include_provenance=True
|
||||
)
|
||||
|
||||
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
|
||||
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]:
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
if isinstance(chunk, RAGChunk):
|
||||
yield chunk
|
||||
|
||||
def prompt(
|
||||
self,
|
||||
|
|
@ -695,8 +728,12 @@ class SocketFlowInstance:
|
|||
variables: Dict[str, str],
|
||||
streaming: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[str, Iterator[str]]:
|
||||
"""Execute a prompt template with optional streaming."""
|
||||
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
||||
"""Execute a prompt template with optional streaming.
|
||||
|
||||
Non-streaming: returns a TextCompletionResult with text and token counts.
|
||||
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
|
||||
"""
|
||||
request = {
|
||||
"id": id,
|
||||
"variables": variables,
|
||||
|
|
@ -709,7 +746,12 @@ class SocketFlowInstance:
|
|||
if streaming:
|
||||
return self._rag_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
return TextCompletionResult(
|
||||
text=result.get("text", result.get("response", "")),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
def graph_embeddings_query(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -189,6 +189,9 @@ class AgentAnswer(StreamingChunk):
|
|||
chunk_type: str = "final-answer"
|
||||
end_of_dialog: bool = False
|
||||
message_id: str = ""
|
||||
in_token: Optional[int] = None
|
||||
out_token: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RAGChunk(StreamingChunk):
|
||||
|
|
@ -202,11 +205,37 @@ class RAGChunk(StreamingChunk):
|
|||
content: Generated text content
|
||||
end_of_stream: True if this is the final chunk of the stream
|
||||
error: Optional error information if an error occurred
|
||||
in_token: Input token count (populated on the final chunk, 0 otherwise)
|
||||
out_token: Output token count (populated on the final chunk, 0 otherwise)
|
||||
model: Model identifier (populated on the final chunk, empty otherwise)
|
||||
chunk_type: Always "rag"
|
||||
"""
|
||||
chunk_type: str = "rag"
|
||||
end_of_stream: bool = False
|
||||
error: Optional[Dict[str, str]] = None
|
||||
in_token: Optional[int] = None
|
||||
out_token: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TextCompletionResult:
|
||||
"""
|
||||
Result from a text completion request.
|
||||
|
||||
Returned by text_completion() in both streaming and non-streaming modes.
|
||||
In streaming mode, text is None (chunks are delivered via the iterator).
|
||||
In non-streaming mode, text contains the complete response.
|
||||
|
||||
Attributes:
|
||||
text: Complete response text (None in streaming mode)
|
||||
in_token: Input token count (None if not available)
|
||||
out_token: Output token count (None if not available)
|
||||
model: Model identifier (None if not available)
|
||||
"""
|
||||
text: Optional[str]
|
||||
in_token: Optional[int] = None
|
||||
out_token: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProvenanceEvent:
|
||||
|
|
|
|||
|
|
@ -18,8 +18,10 @@ from . librarian_client import LibrarianClient
|
|||
from . chunking_service import ChunkingService
|
||||
from . embeddings_service import EmbeddingsService
|
||||
from . embeddings_client import EmbeddingsClientSpec
|
||||
from . text_completion_client import TextCompletionClientSpec
|
||||
from . prompt_client import PromptClientSpec
|
||||
from . text_completion_client import (
|
||||
TextCompletionClientSpec, TextCompletionClient, TextCompletionResult,
|
||||
)
|
||||
from . prompt_client import PromptClientSpec, PromptClient, PromptResult
|
||||
from . triples_store_service import TriplesStoreService
|
||||
from . graph_embeddings_store_service import GraphEmbeddingsStoreService
|
||||
from . document_embeddings_store_service import DocumentEmbeddingsStoreService
|
||||
|
|
|
|||
|
|
@ -1,10 +1,22 @@
|
|||
|
||||
import json
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import PromptRequest, PromptResponse
|
||||
|
||||
@dataclass
|
||||
class PromptResult:
|
||||
response_type: str # "text", "json", or "jsonl"
|
||||
text: Optional[str] = None # populated for "text"
|
||||
object: Any = None # populated for "json"
|
||||
objects: Optional[list] = None # populated for "jsonl"
|
||||
in_token: Optional[int] = None
|
||||
out_token: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
class PromptClient(RequestResponse):
|
||||
|
||||
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
|
||||
|
|
@ -26,17 +38,40 @@ class PromptClient(RequestResponse):
|
|||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
if resp.text: return resp.text
|
||||
if resp.text:
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=resp.text,
|
||||
in_token=resp.in_token,
|
||||
out_token=resp.out_token,
|
||||
model=resp.model,
|
||||
)
|
||||
|
||||
return json.loads(resp.object)
|
||||
parsed = json.loads(resp.object)
|
||||
|
||||
if isinstance(parsed, list):
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=parsed,
|
||||
in_token=resp.in_token,
|
||||
out_token=resp.out_token,
|
||||
model=resp.model,
|
||||
)
|
||||
|
||||
return PromptResult(
|
||||
response_type="json",
|
||||
object=parsed,
|
||||
in_token=resp.in_token,
|
||||
out_token=resp.out_token,
|
||||
model=resp.model,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
last_text = ""
|
||||
last_object = None
|
||||
last_resp = None
|
||||
|
||||
async def forward_chunks(resp):
|
||||
nonlocal last_text, last_object
|
||||
nonlocal last_resp
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
|
@ -44,14 +79,13 @@ class PromptClient(RequestResponse):
|
|||
end_stream = getattr(resp, 'end_of_stream', False)
|
||||
|
||||
if resp.text is not None:
|
||||
last_text = resp.text
|
||||
if chunk_callback:
|
||||
if asyncio.iscoroutinefunction(chunk_callback):
|
||||
await chunk_callback(resp.text, end_stream)
|
||||
else:
|
||||
chunk_callback(resp.text, end_stream)
|
||||
elif resp.object:
|
||||
last_object = resp.object
|
||||
|
||||
last_resp = resp
|
||||
|
||||
return end_stream
|
||||
|
||||
|
|
@ -70,10 +104,36 @@ class PromptClient(RequestResponse):
|
|||
timeout=timeout
|
||||
)
|
||||
|
||||
if last_text:
|
||||
return last_text
|
||||
if last_resp is None:
|
||||
return PromptResult(response_type="text")
|
||||
|
||||
return json.loads(last_object) if last_object else None
|
||||
if last_resp.object:
|
||||
parsed = json.loads(last_resp.object)
|
||||
|
||||
if isinstance(parsed, list):
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=parsed,
|
||||
in_token=last_resp.in_token,
|
||||
out_token=last_resp.out_token,
|
||||
model=last_resp.model,
|
||||
)
|
||||
|
||||
return PromptResult(
|
||||
response_type="json",
|
||||
object=parsed,
|
||||
in_token=last_resp.in_token,
|
||||
out_token=last_resp.out_token,
|
||||
model=last_resp.model,
|
||||
)
|
||||
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=last_resp.text,
|
||||
in_token=last_resp.in_token,
|
||||
out_token=last_resp.out_token,
|
||||
model=last_resp.model,
|
||||
)
|
||||
|
||||
async def extract_definitions(self, text, timeout=600):
|
||||
return await self.prompt(
|
||||
|
|
@ -152,4 +212,3 @@ class PromptClientSpec(RequestResponseSpec):
|
|||
response_schema = PromptResponse,
|
||||
impl = PromptClient,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,47 +1,71 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import TextCompletionRequest, TextCompletionResponse
|
||||
|
||||
@dataclass
|
||||
class TextCompletionResult:
|
||||
text: Optional[str]
|
||||
in_token: Optional[int] = None
|
||||
out_token: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
class TextCompletionClient(RequestResponse):
|
||||
async def text_completion(self, system, prompt, streaming=False, timeout=600):
|
||||
# If not streaming, use original behavior
|
||||
if not streaming:
|
||||
resp = await self.request(
|
||||
TextCompletionRequest(
|
||||
system = system, prompt = prompt, streaming = False
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
async def text_completion(self, system, prompt, timeout=600):
|
||||
|
||||
return resp.response
|
||||
|
||||
# For streaming: collect all chunks and return complete response
|
||||
full_response = ""
|
||||
|
||||
async def collect_chunks(resp):
|
||||
nonlocal full_response
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
if resp.response:
|
||||
full_response += resp.response
|
||||
|
||||
# Return True when end_of_stream is reached
|
||||
return getattr(resp, 'end_of_stream', False)
|
||||
|
||||
await self.request(
|
||||
resp = await self.request(
|
||||
TextCompletionRequest(
|
||||
system = system, prompt = prompt, streaming = True
|
||||
system = system, prompt = prompt, streaming = False
|
||||
),
|
||||
recipient=collect_chunks,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
return full_response
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return TextCompletionResult(
|
||||
text = resp.response,
|
||||
in_token = resp.in_token,
|
||||
out_token = resp.out_token,
|
||||
model = resp.model,
|
||||
)
|
||||
|
||||
async def text_completion_stream(
|
||||
self, system, prompt, handler, timeout=600,
|
||||
):
|
||||
"""
|
||||
Streaming text completion. `handler` is an async callable invoked
|
||||
once per chunk with the chunk's TextCompletionResponse. Returns a
|
||||
TextCompletionResult with text=None and token counts / model taken
|
||||
from the end_of_stream message.
|
||||
"""
|
||||
|
||||
async def on_chunk(resp):
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
await handler(resp)
|
||||
|
||||
return getattr(resp, "end_of_stream", False)
|
||||
|
||||
final = await self.request(
|
||||
TextCompletionRequest(
|
||||
system = system, prompt = prompt, streaming = True
|
||||
),
|
||||
recipient=on_chunk,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
return TextCompletionResult(
|
||||
text = None,
|
||||
in_token = final.in_token,
|
||||
out_token = final.out_token,
|
||||
model = final.model,
|
||||
)
|
||||
|
||||
class TextCompletionClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
@ -54,4 +78,3 @@ class TextCompletionClientSpec(RequestResponseSpec):
|
|||
response_schema = TextCompletionResponse,
|
||||
impl = TextCompletionClient,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -90,6 +90,13 @@ class AgentResponseTranslator(MessageTranslator):
|
|||
if hasattr(obj, 'error') and obj.error and obj.error.message:
|
||||
result["error"] = {"message": obj.error.message, "code": obj.error.code}
|
||||
|
||||
if obj.in_token is not None:
|
||||
result["in_token"] = obj.in_token
|
||||
if obj.out_token is not None:
|
||||
result["out_token"] = obj.out_token
|
||||
if obj.model is not None:
|
||||
result["model"] = obj.model
|
||||
|
||||
return result
|
||||
|
||||
def encode_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
|
|
|
|||
|
|
@ -53,6 +53,13 @@ class PromptResponseTranslator(MessageTranslator):
|
|||
# Always include end_of_stream flag for streaming support
|
||||
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
|
||||
|
||||
if obj.in_token is not None:
|
||||
result["in_token"] = obj.in_token
|
||||
if obj.out_token is not None:
|
||||
result["out_token"] = obj.out_token
|
||||
if obj.model is not None:
|
||||
result["model"] = obj.model
|
||||
|
||||
return result
|
||||
|
||||
def encode_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
|
|
|
|||
|
|
@ -74,6 +74,13 @@ class DocumentRagResponseTranslator(MessageTranslator):
|
|||
if hasattr(obj, 'error') and obj.error and obj.error.message:
|
||||
result["error"] = {"message": obj.error.message, "type": obj.error.type}
|
||||
|
||||
if obj.in_token is not None:
|
||||
result["in_token"] = obj.in_token
|
||||
if obj.out_token is not None:
|
||||
result["out_token"] = obj.out_token
|
||||
if obj.model is not None:
|
||||
result["model"] = obj.model
|
||||
|
||||
return result
|
||||
|
||||
def encode_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
|
|
@ -163,6 +170,13 @@ class GraphRagResponseTranslator(MessageTranslator):
|
|||
if hasattr(obj, 'error') and obj.error and obj.error.message:
|
||||
result["error"] = {"message": obj.error.message, "type": obj.error.type}
|
||||
|
||||
if obj.in_token is not None:
|
||||
result["in_token"] = obj.in_token
|
||||
if obj.out_token is not None:
|
||||
result["out_token"] = obj.out_token
|
||||
if obj.model is not None:
|
||||
result["model"] = obj.model
|
||||
|
||||
return result
|
||||
|
||||
def encode_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
|
|
|
|||
|
|
@ -29,11 +29,11 @@ class TextCompletionResponseTranslator(MessageTranslator):
|
|||
def encode(self, obj: TextCompletionResponse) -> Dict[str, Any]:
|
||||
result = {"response": obj.response}
|
||||
|
||||
if obj.in_token:
|
||||
if obj.in_token is not None:
|
||||
result["in_token"] = obj.in_token
|
||||
if obj.out_token:
|
||||
if obj.out_token is not None:
|
||||
result["out_token"] = obj.out_token
|
||||
if obj.model:
|
||||
if obj.model is not None:
|
||||
result["model"] = obj.model
|
||||
|
||||
# Always include end_of_stream flag for streaming support
|
||||
|
|
|
|||
|
|
@ -66,5 +66,10 @@ class AgentResponse:
|
|||
|
||||
error: Error | None = None
|
||||
|
||||
# Token usage (populated on end_of_dialog message)
|
||||
in_token: int | None = None
|
||||
out_token: int | None = None
|
||||
model: str | None = None
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
|
|||
|
|
@ -17,9 +17,9 @@ class TextCompletionRequest:
|
|||
class TextCompletionResponse:
|
||||
error: Error | None = None
|
||||
response: str = ""
|
||||
in_token: int = 0
|
||||
out_token: int = 0
|
||||
model: str = ""
|
||||
in_token: int | None = None
|
||||
out_token: int | None = None
|
||||
model: str | None = None
|
||||
end_of_stream: bool = False # Indicates final message in stream
|
||||
|
||||
############################################################################
|
||||
|
|
|
|||
|
|
@ -41,4 +41,9 @@ class PromptResponse:
|
|||
# Indicates final message in stream
|
||||
end_of_stream: bool = False
|
||||
|
||||
# Token usage from the underlying text completion
|
||||
in_token: int | None = None
|
||||
out_token: int | None = None
|
||||
model: str | None = None
|
||||
|
||||
############################################################################
|
||||
|
|
@ -29,6 +29,9 @@ class GraphRagResponse:
|
|||
explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step
|
||||
message_type: str = "" # "chunk" or "explain"
|
||||
end_of_session: bool = False # Entire session complete
|
||||
in_token: int | None = None
|
||||
out_token: int | None = None
|
||||
model: str | None = None
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
@ -52,3 +55,6 @@ class DocumentRagResponse:
|
|||
explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step
|
||||
message_type: str = "" # "chunk" or "explain"
|
||||
end_of_session: bool = False # Entire session complete
|
||||
in_token: int | None = None
|
||||
out_token: int | None = None
|
||||
model: str | None = None
|
||||
|
|
|
|||
|
|
@ -272,7 +272,8 @@ def question(
|
|||
url, question, flow_id, user, collection,
|
||||
plan=None, state=None, group=None, pattern=None,
|
||||
verbose=False, streaming=True,
|
||||
token=None, explainable=False, debug=False
|
||||
token=None, explainable=False, debug=False,
|
||||
show_usage=False
|
||||
):
|
||||
# Explainable mode uses the API to capture and process provenance events
|
||||
if explainable:
|
||||
|
|
@ -323,6 +324,7 @@ def question(
|
|||
# Track last chunk type and current outputter for streaming
|
||||
last_chunk_type = None
|
||||
current_outputter = None
|
||||
last_answer_chunk = None
|
||||
|
||||
for chunk in response:
|
||||
chunk_type = chunk.chunk_type
|
||||
|
|
@ -357,6 +359,7 @@ def question(
|
|||
current_outputter.word_buffer = ""
|
||||
elif chunk_type == "final-answer":
|
||||
print(content, end="", flush=True)
|
||||
last_answer_chunk = chunk
|
||||
|
||||
# Close any remaining outputter
|
||||
if current_outputter:
|
||||
|
|
@ -366,6 +369,14 @@ def question(
|
|||
elif last_chunk_type == "final-answer":
|
||||
print()
|
||||
|
||||
if show_usage and last_answer_chunk:
|
||||
print(
|
||||
f"Input tokens: {last_answer_chunk.in_token} "
|
||||
f"Output tokens: {last_answer_chunk.out_token} "
|
||||
f"Model: {last_answer_chunk.model}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
else:
|
||||
# Non-streaming response - but agents use multipart messaging
|
||||
# so we iterate through the chunks (which are complete messages, not text chunks)
|
||||
|
|
@ -477,6 +488,12 @@ def main():
|
|||
help='Show debug output for troubleshooting'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--show-usage',
|
||||
action='store_true',
|
||||
help='Show token usage and model on stderr'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -496,6 +513,7 @@ def main():
|
|||
token = args.token,
|
||||
explainable = args.explainable,
|
||||
debug = args.debug,
|
||||
show_usage = args.show_usage,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -99,7 +99,8 @@ def question_explainable(
|
|||
|
||||
def question(
|
||||
url, flow_id, question_text, user, collection, doc_limit,
|
||||
streaming=True, token=None, explainable=False, debug=False
|
||||
streaming=True, token=None, explainable=False, debug=False,
|
||||
show_usage=False
|
||||
):
|
||||
# Explainable mode uses the API to capture and process provenance events
|
||||
if explainable:
|
||||
|
|
@ -133,22 +134,40 @@ def question(
|
|||
)
|
||||
|
||||
# Stream output
|
||||
last_chunk = None
|
||||
for chunk in response:
|
||||
print(chunk, end="", flush=True)
|
||||
print(chunk.content, end="", flush=True)
|
||||
last_chunk = chunk
|
||||
print() # Final newline
|
||||
|
||||
if show_usage and last_chunk:
|
||||
print(
|
||||
f"Input tokens: {last_chunk.in_token} "
|
||||
f"Output tokens: {last_chunk.out_token} "
|
||||
f"Model: {last_chunk.model}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
finally:
|
||||
socket.close()
|
||||
else:
|
||||
# Use REST API for non-streaming
|
||||
flow = api.flow().id(flow_id)
|
||||
resp = flow.document_rag(
|
||||
result = flow.document_rag(
|
||||
query=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
)
|
||||
print(resp)
|
||||
print(result.text)
|
||||
|
||||
if show_usage:
|
||||
print(
|
||||
f"Input tokens: {result.in_token} "
|
||||
f"Output tokens: {result.out_token} "
|
||||
f"Model: {result.model}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -219,6 +238,12 @@ def main():
|
|||
help='Show debug output for troubleshooting'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--show-usage',
|
||||
action='store_true',
|
||||
help='Show token usage and model on stderr'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -234,6 +259,7 @@ def main():
|
|||
token=args.token,
|
||||
explainable=args.explainable,
|
||||
debug=args.debug,
|
||||
show_usage=args.show_usage,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -753,7 +753,7 @@ def question(
|
|||
url, flow_id, question, user, collection, entity_limit, triple_limit,
|
||||
max_subgraph_size, max_path_length, edge_score_limit=50,
|
||||
edge_limit=25, streaming=True, token=None,
|
||||
explainable=False, debug=False
|
||||
explainable=False, debug=False, show_usage=False
|
||||
):
|
||||
|
||||
# Explainable mode uses the API to capture and process provenance events
|
||||
|
|
@ -798,16 +798,26 @@ def question(
|
|||
)
|
||||
|
||||
# Stream output
|
||||
last_chunk = None
|
||||
for chunk in response:
|
||||
print(chunk, end="", flush=True)
|
||||
print(chunk.content, end="", flush=True)
|
||||
last_chunk = chunk
|
||||
print() # Final newline
|
||||
|
||||
if show_usage and last_chunk:
|
||||
print(
|
||||
f"Input tokens: {last_chunk.in_token} "
|
||||
f"Output tokens: {last_chunk.out_token} "
|
||||
f"Model: {last_chunk.model}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
finally:
|
||||
socket.close()
|
||||
else:
|
||||
# Use REST API for non-streaming
|
||||
flow = api.flow().id(flow_id)
|
||||
resp = flow.graph_rag(
|
||||
result = flow.graph_rag(
|
||||
query=question,
|
||||
user=user,
|
||||
collection=collection,
|
||||
|
|
@ -818,7 +828,15 @@ def question(
|
|||
edge_score_limit=edge_score_limit,
|
||||
edge_limit=edge_limit,
|
||||
)
|
||||
print(resp)
|
||||
print(result.text)
|
||||
|
||||
if show_usage:
|
||||
print(
|
||||
f"Input tokens: {result.in_token} "
|
||||
f"Output tokens: {result.out_token} "
|
||||
f"Model: {result.model}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
def main():
|
||||
|
||||
|
|
@ -923,6 +941,12 @@ def main():
|
|||
help='Show debug output for troubleshooting'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--show-usage',
|
||||
action='store_true',
|
||||
help='Show token usage and model on stderr'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -943,6 +967,7 @@ def main():
|
|||
token=args.token,
|
||||
explainable=args.explainable,
|
||||
debug=args.debug,
|
||||
show_usage=args.show_usage,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ from trustgraph.api import Api
|
|||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
|
||||
def query(url, flow_id, system, prompt, streaming=True, token=None):
|
||||
def query(url, flow_id, system, prompt, streaming=True, token=None,
|
||||
show_usage=False):
|
||||
|
||||
# Create API client
|
||||
api = Api(url=url, token=token)
|
||||
|
|
@ -26,14 +27,29 @@ def query(url, flow_id, system, prompt, streaming=True, token=None):
|
|||
)
|
||||
|
||||
if streaming:
|
||||
# Stream output to stdout without newline
|
||||
last_chunk = None
|
||||
for chunk in response:
|
||||
print(chunk, end="", flush=True)
|
||||
# Add final newline after streaming
|
||||
print(chunk.content, end="", flush=True)
|
||||
last_chunk = chunk
|
||||
print()
|
||||
|
||||
if show_usage and last_chunk:
|
||||
print(
|
||||
f"Input tokens: {last_chunk.in_token} "
|
||||
f"Output tokens: {last_chunk.out_token} "
|
||||
f"Model: {last_chunk.model}",
|
||||
file=__import__('sys').stderr,
|
||||
)
|
||||
else:
|
||||
# Non-streaming: print complete response
|
||||
print(response)
|
||||
print(response.text)
|
||||
|
||||
if show_usage:
|
||||
print(
|
||||
f"Input tokens: {response.in_token} "
|
||||
f"Output tokens: {response.out_token} "
|
||||
f"Model: {response.model}",
|
||||
file=__import__('sys').stderr,
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up socket connection
|
||||
|
|
@ -82,6 +98,12 @@ def main():
|
|||
help='Disable streaming (default: streaming enabled)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--show-usage',
|
||||
action='store_true',
|
||||
help='Show token usage and model on stderr'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -93,6 +115,7 @@ def main():
|
|||
prompt=args.prompt[0],
|
||||
streaming=not args.no_streaming,
|
||||
token=args.token,
|
||||
show_usage=args.show_usage,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ from trustgraph.api import Api
|
|||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
|
||||
def query(url, flow_id, template_id, variables, streaming=True, token=None):
|
||||
def query(url, flow_id, template_id, variables, streaming=True, token=None,
|
||||
show_usage=False):
|
||||
|
||||
# Create API client
|
||||
api = Api(url=url, token=token)
|
||||
|
|
@ -31,16 +32,30 @@ def query(url, flow_id, template_id, variables, streaming=True, token=None):
|
|||
)
|
||||
|
||||
if streaming:
|
||||
# Stream output (prompt yields strings directly)
|
||||
last_chunk = None
|
||||
for chunk in response:
|
||||
if chunk:
|
||||
print(chunk, end="", flush=True)
|
||||
# Add final newline after streaming
|
||||
if chunk.content:
|
||||
print(chunk.content, end="", flush=True)
|
||||
last_chunk = chunk
|
||||
print()
|
||||
|
||||
if show_usage and last_chunk:
|
||||
print(
|
||||
f"Input tokens: {last_chunk.in_token} "
|
||||
f"Output tokens: {last_chunk.out_token} "
|
||||
f"Model: {last_chunk.model}",
|
||||
file=__import__('sys').stderr,
|
||||
)
|
||||
else:
|
||||
# Non-streaming: print complete response
|
||||
print(response)
|
||||
print(response.text)
|
||||
|
||||
if show_usage:
|
||||
print(
|
||||
f"Input tokens: {response.in_token} "
|
||||
f"Output tokens: {response.out_token} "
|
||||
f"Model: {response.model}",
|
||||
file=__import__('sys').stderr,
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up socket connection
|
||||
|
|
@ -92,6 +107,12 @@ specified multiple times''',
|
|||
help='Disable streaming (default: streaming enabled for text responses)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--show-usage',
|
||||
action='store_true',
|
||||
help='Show token usage and model on stderr'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
variables = {}
|
||||
|
|
@ -113,6 +134,7 @@ specified multiple times''',
|
|||
variables=variables,
|
||||
streaming=not args.no_streaming,
|
||||
token=args.token,
|
||||
show_usage=args.show_usage,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class MetaRouter:
|
|||
"general": {"name": "general", "description": "General queries", "valid_patterns": ["react"], "framing": ""},
|
||||
}
|
||||
|
||||
async def identify_task_type(self, question, context):
|
||||
async def identify_task_type(self, question, context, usage=None):
|
||||
"""
|
||||
Use the LLM to classify the question into one of the known task types.
|
||||
|
||||
|
|
@ -71,7 +71,7 @@ class MetaRouter:
|
|||
|
||||
try:
|
||||
client = context("prompt-request")
|
||||
response = await client.prompt(
|
||||
result = await client.prompt(
|
||||
id="task-type-classify",
|
||||
variables={
|
||||
"question": question,
|
||||
|
|
@ -81,7 +81,9 @@ class MetaRouter:
|
|||
],
|
||||
},
|
||||
)
|
||||
selected = response.strip().lower().replace('"', '').replace("'", "")
|
||||
if usage:
|
||||
usage.track(result)
|
||||
selected = result.text.strip().lower().replace('"', '').replace("'", "")
|
||||
|
||||
if selected in self.task_types:
|
||||
framing = self.task_types[selected].get("framing", DEFAULT_FRAMING)
|
||||
|
|
@ -100,7 +102,7 @@ class MetaRouter:
|
|||
)
|
||||
return DEFAULT_TASK_TYPE, framing
|
||||
|
||||
async def select_pattern(self, question, task_type, context):
|
||||
async def select_pattern(self, question, task_type, context, usage=None):
|
||||
"""
|
||||
Use the LLM to select the best execution pattern for this task type.
|
||||
|
||||
|
|
@ -120,7 +122,7 @@ class MetaRouter:
|
|||
|
||||
try:
|
||||
client = context("prompt-request")
|
||||
response = await client.prompt(
|
||||
result = await client.prompt(
|
||||
id="pattern-select",
|
||||
variables={
|
||||
"question": question,
|
||||
|
|
@ -133,7 +135,9 @@ class MetaRouter:
|
|||
],
|
||||
},
|
||||
)
|
||||
selected = response.strip().lower().replace('"', '').replace("'", "")
|
||||
if usage:
|
||||
usage.track(result)
|
||||
selected = result.text.strip().lower().replace('"', '').replace("'", "")
|
||||
|
||||
if selected in valid_patterns:
|
||||
logger.info(f"MetaRouter: selected pattern '{selected}'")
|
||||
|
|
@ -148,19 +152,20 @@ class MetaRouter:
|
|||
logger.warning(f"MetaRouter: pattern selection failed: {e}")
|
||||
return valid_patterns[0] if valid_patterns else DEFAULT_PATTERN
|
||||
|
||||
async def route(self, question, context):
|
||||
async def route(self, question, context, usage=None):
|
||||
"""
|
||||
Full routing pipeline: identify task type, then select pattern.
|
||||
|
||||
Args:
|
||||
question: The user's query.
|
||||
context: UserAwareContext (flow wrapper).
|
||||
usage: Optional UsageTracker for token counting.
|
||||
|
||||
Returns:
|
||||
(pattern, task_type, framing) tuple.
|
||||
"""
|
||||
task_type, framing = await self.identify_task_type(question, context)
|
||||
pattern = await self.select_pattern(question, task_type, context)
|
||||
task_type, framing = await self.identify_task_type(question, context, usage=usage)
|
||||
pattern = await self.select_pattern(question, task_type, context, usage=usage)
|
||||
logger.info(
|
||||
f"MetaRouter: route result — "
|
||||
f"pattern={pattern}, task_type={task_type}, framing={framing!r}"
|
||||
|
|
|
|||
|
|
@ -65,6 +65,37 @@ class UserAwareContext:
|
|||
return client
|
||||
|
||||
|
||||
class UsageTracker:
|
||||
"""Accumulates token usage across multiple prompt calls."""
|
||||
|
||||
def __init__(self):
|
||||
self.total_in = 0
|
||||
self.total_out = 0
|
||||
self.last_model = None
|
||||
|
||||
def track(self, result):
|
||||
"""Track usage from a PromptResult."""
|
||||
if result is not None:
|
||||
if getattr(result, "in_token", None) is not None:
|
||||
self.total_in += result.in_token
|
||||
if getattr(result, "out_token", None) is not None:
|
||||
self.total_out += result.out_token
|
||||
if getattr(result, "model", None) is not None:
|
||||
self.last_model = result.model
|
||||
|
||||
@property
|
||||
def in_token(self):
|
||||
return self.total_in if self.total_in > 0 else None
|
||||
|
||||
@property
|
||||
def out_token(self):
|
||||
return self.total_out if self.total_out > 0 else None
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self.last_model
|
||||
|
||||
|
||||
class PatternBase:
|
||||
"""
|
||||
Shared infrastructure for all agent patterns.
|
||||
|
|
@ -571,7 +602,8 @@ class PatternBase:
|
|||
# ---- Response helpers ---------------------------------------------------
|
||||
|
||||
async def prompt_as_answer(self, client, prompt_id, variables,
|
||||
respond, streaming, message_id=""):
|
||||
respond, streaming, message_id="",
|
||||
usage=None):
|
||||
"""Call a prompt template, forwarding chunks as answer
|
||||
AgentResponse messages when streaming is enabled.
|
||||
|
||||
|
|
@ -591,22 +623,28 @@ class PatternBase:
|
|||
message_id=message_id,
|
||||
))
|
||||
|
||||
await client.prompt(
|
||||
result = await client.prompt(
|
||||
id=prompt_id,
|
||||
variables=variables,
|
||||
streaming=True,
|
||||
chunk_callback=on_chunk,
|
||||
)
|
||||
if usage:
|
||||
usage.track(result)
|
||||
|
||||
return "".join(accumulated)
|
||||
else:
|
||||
return await client.prompt(
|
||||
result = await client.prompt(
|
||||
id=prompt_id,
|
||||
variables=variables,
|
||||
)
|
||||
if usage:
|
||||
usage.track(result)
|
||||
return result.text
|
||||
|
||||
async def send_final_response(self, respond, streaming, answer_text,
|
||||
already_streamed=False, message_id=""):
|
||||
already_streamed=False, message_id="",
|
||||
usage=None):
|
||||
"""Send the answer content and end-of-dialog marker.
|
||||
|
||||
Args:
|
||||
|
|
@ -614,7 +652,16 @@ class PatternBase:
|
|||
via streaming callbacks (e.g. ReactPattern). Only the
|
||||
end-of-dialog marker is emitted.
|
||||
message_id: Provenance URI for the answer entity.
|
||||
usage: UsageTracker with accumulated token counts.
|
||||
"""
|
||||
usage_kwargs = {}
|
||||
if usage:
|
||||
usage_kwargs = {
|
||||
"in_token": usage.in_token,
|
||||
"out_token": usage.out_token,
|
||||
"model": usage.model,
|
||||
}
|
||||
|
||||
if streaming and not already_streamed:
|
||||
# Answer wasn't streamed yet — send it as a chunk first
|
||||
if answer_text:
|
||||
|
|
@ -626,13 +673,14 @@ class PatternBase:
|
|||
message_id=message_id,
|
||||
))
|
||||
if streaming:
|
||||
# End-of-dialog marker
|
||||
# End-of-dialog marker with usage
|
||||
await respond(AgentResponse(
|
||||
chunk_type="answer",
|
||||
content="",
|
||||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
message_id=message_id,
|
||||
**usage_kwargs,
|
||||
))
|
||||
else:
|
||||
await respond(AgentResponse(
|
||||
|
|
@ -641,6 +689,7 @@ class PatternBase:
|
|||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
message_id=message_id,
|
||||
**usage_kwargs,
|
||||
))
|
||||
|
||||
def build_next_request(self, request, history, session_id, collection,
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from trustgraph.provenance import (
|
|||
agent_synthesis_uri,
|
||||
)
|
||||
|
||||
from . pattern_base import PatternBase
|
||||
from . pattern_base import PatternBase, UsageTracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -35,7 +35,10 @@ class PlanThenExecutePattern(PatternBase):
|
|||
Subsequent calls execute the next pending plan step via ReACT.
|
||||
"""
|
||||
|
||||
async def iterate(self, request, respond, next, flow):
|
||||
async def iterate(self, request, respond, next, flow, usage=None):
|
||||
|
||||
if usage is None:
|
||||
usage = UsageTracker()
|
||||
|
||||
streaming = getattr(request, 'streaming', False)
|
||||
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
|
||||
|
|
@ -67,13 +70,13 @@ class PlanThenExecutePattern(PatternBase):
|
|||
await self._planning_iteration(
|
||||
request, respond, next, flow,
|
||||
session_id, collection, streaming, session_uri,
|
||||
iteration_num,
|
||||
iteration_num, usage=usage,
|
||||
)
|
||||
else:
|
||||
await self._execution_iteration(
|
||||
request, respond, next, flow,
|
||||
session_id, collection, streaming, session_uri,
|
||||
iteration_num, plan,
|
||||
iteration_num, plan, usage=usage,
|
||||
)
|
||||
|
||||
def _extract_plan(self, history):
|
||||
|
|
@ -98,7 +101,7 @@ class PlanThenExecutePattern(PatternBase):
|
|||
|
||||
async def _planning_iteration(self, request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num):
|
||||
session_uri, iteration_num, usage=None):
|
||||
"""Ask the LLM to produce a structured plan."""
|
||||
|
||||
think = self.make_think_callback(respond, streaming)
|
||||
|
|
@ -113,7 +116,7 @@ class PlanThenExecutePattern(PatternBase):
|
|||
client = context("prompt-request")
|
||||
|
||||
# Use the plan-create prompt template
|
||||
plan_steps = await client.prompt(
|
||||
result = await client.prompt(
|
||||
id="plan-create",
|
||||
variables={
|
||||
"question": request.question,
|
||||
|
|
@ -124,7 +127,10 @@ class PlanThenExecutePattern(PatternBase):
|
|||
],
|
||||
},
|
||||
)
|
||||
if usage:
|
||||
usage.track(result)
|
||||
|
||||
plan_steps = result.objects
|
||||
# Validate we got a list
|
||||
if not isinstance(plan_steps, list) or not plan_steps:
|
||||
logger.warning("plan-create returned invalid result, falling back to single step")
|
||||
|
|
@ -187,7 +193,8 @@ class PlanThenExecutePattern(PatternBase):
|
|||
|
||||
async def _execution_iteration(self, request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num, plan):
|
||||
session_uri, iteration_num, plan,
|
||||
usage=None):
|
||||
"""Execute the next pending plan step via single-shot tool call."""
|
||||
|
||||
pending_idx = self._find_next_pending_step(plan)
|
||||
|
|
@ -198,6 +205,7 @@ class PlanThenExecutePattern(PatternBase):
|
|||
request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num, plan,
|
||||
usage=usage,
|
||||
)
|
||||
return
|
||||
|
||||
|
|
@ -240,7 +248,7 @@ class PlanThenExecutePattern(PatternBase):
|
|||
client = context("prompt-request")
|
||||
|
||||
# Single-shot: ask LLM which tool + arguments to use for this goal
|
||||
tool_call = await client.prompt(
|
||||
result = await client.prompt(
|
||||
id="plan-step-execute",
|
||||
variables={
|
||||
"goal": goal,
|
||||
|
|
@ -258,7 +266,10 @@ class PlanThenExecutePattern(PatternBase):
|
|||
],
|
||||
},
|
||||
)
|
||||
if usage:
|
||||
usage.track(result)
|
||||
|
||||
tool_call = result.object
|
||||
tool_name = tool_call.get("tool", "")
|
||||
tool_arguments = tool_call.get("arguments", {})
|
||||
|
||||
|
|
@ -330,7 +341,8 @@ class PlanThenExecutePattern(PatternBase):
|
|||
|
||||
async def _synthesise(self, request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num, plan):
|
||||
session_uri, iteration_num, plan,
|
||||
usage=None):
|
||||
"""Synthesise a final answer from all completed plan step results."""
|
||||
|
||||
think = self.make_think_callback(respond, streaming)
|
||||
|
|
@ -365,6 +377,7 @@ class PlanThenExecutePattern(PatternBase):
|
|||
respond=respond,
|
||||
streaming=streaming,
|
||||
message_id=synthesis_msg_id,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
# Emit synthesis provenance (links back to last step result)
|
||||
|
|
@ -380,4 +393,5 @@ class PlanThenExecutePattern(PatternBase):
|
|||
await self.send_final_response(
|
||||
respond, streaming, response_text, already_streamed=streaming,
|
||||
message_id=synthesis_msg_id,
|
||||
usage=usage,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from ..react.agent_manager import AgentManager
|
|||
from ..react.types import Action, Final
|
||||
from ..tool_filter import get_next_state
|
||||
|
||||
from . pattern_base import PatternBase
|
||||
from . pattern_base import PatternBase, UsageTracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -37,7 +37,10 @@ class ReactPattern(PatternBase):
|
|||
result is appended to history and a next-request is emitted.
|
||||
"""
|
||||
|
||||
async def iterate(self, request, respond, next, flow):
|
||||
async def iterate(self, request, respond, next, flow, usage=None):
|
||||
|
||||
if usage is None:
|
||||
usage = UsageTracker()
|
||||
|
||||
streaming = getattr(request, 'streaming', False)
|
||||
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
|
||||
|
|
@ -121,6 +124,7 @@ class ReactPattern(PatternBase):
|
|||
context=context,
|
||||
streaming=streaming,
|
||||
on_action=on_action,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
logger.debug(f"Action: {act}")
|
||||
|
|
@ -144,6 +148,7 @@ class ReactPattern(PatternBase):
|
|||
await self.send_final_response(
|
||||
respond, streaming, f, already_streamed=streaming,
|
||||
message_id=answer_msg_id,
|
||||
usage=usage,
|
||||
)
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from ... base import Consumer, Producer
|
|||
from ... base import ConsumerMetrics, ProducerMetrics
|
||||
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
from ..orchestrator.pattern_base import UsageTracker
|
||||
from ... schema import Triples, Metadata
|
||||
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
|
|
@ -493,6 +494,8 @@ class Processor(AgentService):
|
|||
|
||||
async def agent_request(self, request, respond, next, flow):
|
||||
|
||||
usage = UsageTracker()
|
||||
|
||||
try:
|
||||
|
||||
# Intercept subagent completion messages
|
||||
|
|
@ -516,7 +519,7 @@ class Processor(AgentService):
|
|||
|
||||
if self.meta_router:
|
||||
pattern, task_type, framing = await self.meta_router.route(
|
||||
request.question, context,
|
||||
request.question, context, usage=usage,
|
||||
)
|
||||
else:
|
||||
pattern = "react"
|
||||
|
|
@ -536,16 +539,16 @@ class Processor(AgentService):
|
|||
# Dispatch to the selected pattern
|
||||
if pattern == "plan-then-execute":
|
||||
await self.plan_pattern.iterate(
|
||||
request, respond, next, flow,
|
||||
request, respond, next, flow, usage=usage,
|
||||
)
|
||||
elif pattern == "supervisor":
|
||||
await self.supervisor_pattern.iterate(
|
||||
request, respond, next, flow,
|
||||
request, respond, next, flow, usage=usage,
|
||||
)
|
||||
else:
|
||||
# Default to react
|
||||
await self.react_pattern.iterate(
|
||||
request, respond, next, flow,
|
||||
request, respond, next, flow, usage=usage,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from trustgraph.provenance import (
|
|||
agent_synthesis_uri,
|
||||
)
|
||||
|
||||
from . pattern_base import PatternBase
|
||||
from . pattern_base import PatternBase, UsageTracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -38,7 +38,10 @@ class SupervisorPattern(PatternBase):
|
|||
- "synthesise": triggered by aggregator with results in subagent_results
|
||||
"""
|
||||
|
||||
async def iterate(self, request, respond, next, flow):
|
||||
async def iterate(self, request, respond, next, flow, usage=None):
|
||||
|
||||
if usage is None:
|
||||
usage = UsageTracker()
|
||||
|
||||
streaming = getattr(request, 'streaming', False)
|
||||
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
|
||||
|
|
@ -72,17 +75,19 @@ class SupervisorPattern(PatternBase):
|
|||
request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num,
|
||||
usage=usage,
|
||||
)
|
||||
else:
|
||||
await self._decompose_and_fanout(
|
||||
request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def _decompose_and_fanout(self, request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num):
|
||||
session_uri, iteration_num, usage=None):
|
||||
"""Decompose the question into sub-goals and fan out subagents."""
|
||||
|
||||
decompose_msg_id = agent_decomposition_uri(session_id)
|
||||
|
|
@ -100,7 +105,7 @@ class SupervisorPattern(PatternBase):
|
|||
client = context("prompt-request")
|
||||
|
||||
# Use the supervisor-decompose prompt template
|
||||
goals = await client.prompt(
|
||||
result = await client.prompt(
|
||||
id="supervisor-decompose",
|
||||
variables={
|
||||
"question": request.question,
|
||||
|
|
@ -112,7 +117,10 @@ class SupervisorPattern(PatternBase):
|
|||
],
|
||||
},
|
||||
)
|
||||
if usage:
|
||||
usage.track(result)
|
||||
|
||||
goals = result.objects
|
||||
# Validate result
|
||||
if not isinstance(goals, list):
|
||||
goals = []
|
||||
|
|
@ -175,7 +183,7 @@ class SupervisorPattern(PatternBase):
|
|||
|
||||
async def _synthesise(self, request, respond, next, flow,
|
||||
session_id, collection, streaming,
|
||||
session_uri, iteration_num):
|
||||
session_uri, iteration_num, usage=None):
|
||||
"""Synthesise final answer from subagent results."""
|
||||
|
||||
synthesis_msg_id = agent_synthesis_uri(session_id)
|
||||
|
|
@ -216,6 +224,7 @@ class SupervisorPattern(PatternBase):
|
|||
respond=respond,
|
||||
streaming=streaming,
|
||||
message_id=synthesis_msg_id,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
# Emit synthesis provenance (links back to all findings)
|
||||
|
|
@ -231,4 +240,5 @@ class SupervisorPattern(PatternBase):
|
|||
await self.send_final_response(
|
||||
respond, streaming, response_text, already_streamed=streaming,
|
||||
message_id=synthesis_msg_id,
|
||||
usage=usage,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ class AgentManager:
|
|||
|
||||
raise ValueError(f"Could not parse response: {text}")
|
||||
|
||||
async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None):
|
||||
async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None, usage=None):
|
||||
|
||||
logger.debug(f"calling reason: {question}")
|
||||
|
||||
|
|
@ -255,11 +255,13 @@ class AgentManager:
|
|||
client = context("prompt-request")
|
||||
|
||||
# Get streaming response
|
||||
response_text = await client.agent_react(
|
||||
prompt_result = await client.agent_react(
|
||||
variables=variables,
|
||||
streaming=True,
|
||||
chunk_callback=on_chunk
|
||||
)
|
||||
if usage:
|
||||
usage.track(prompt_result)
|
||||
|
||||
# Finalize parser
|
||||
parser.finalize()
|
||||
|
|
@ -275,10 +277,13 @@ class AgentManager:
|
|||
# Non-streaming path - get complete text and parse
|
||||
client = context("prompt-request")
|
||||
|
||||
response_text = await client.agent_react(
|
||||
prompt_result = await client.agent_react(
|
||||
variables=variables,
|
||||
streaming=False
|
||||
)
|
||||
if usage:
|
||||
usage.track(prompt_result)
|
||||
response_text = prompt_result.text
|
||||
|
||||
logger.debug(f"Response text:\n{response_text}")
|
||||
|
||||
|
|
@ -292,7 +297,8 @@ class AgentManager:
|
|||
raise RuntimeError(f"Failed to parse agent response: {e}")
|
||||
|
||||
async def react(self, question, history, think, observe, context,
|
||||
streaming=False, answer=None, on_action=None):
|
||||
streaming=False, answer=None, on_action=None,
|
||||
usage=None):
|
||||
|
||||
act = await self.reason(
|
||||
question = question,
|
||||
|
|
@ -302,6 +308,7 @@ class AgentManager:
|
|||
think = think,
|
||||
observe = observe,
|
||||
answer = answer,
|
||||
usage = usage,
|
||||
)
|
||||
|
||||
if isinstance(act, Final):
|
||||
|
|
|
|||
|
|
@ -78,9 +78,10 @@ class TextCompletionImpl:
|
|||
async def invoke(self, **arguments):
|
||||
client = self.context("prompt-request")
|
||||
logger.debug("Prompt question...")
|
||||
return await client.question(
|
||||
result = await client.question(
|
||||
arguments.get("question")
|
||||
)
|
||||
return result.text
|
||||
|
||||
# This tool implementation knows how to do MCP tool invocation. This uses
|
||||
# the mcp-tool service.
|
||||
|
|
@ -227,10 +228,11 @@ class PromptImpl:
|
|||
async def invoke(self, **arguments):
|
||||
client = self.context("prompt-request")
|
||||
logger.debug(f"Prompt template invocation: {self.template_id}...")
|
||||
return await client.prompt(
|
||||
result = await client.prompt(
|
||||
id=self.template_id,
|
||||
variables=arguments
|
||||
)
|
||||
return result.text
|
||||
|
||||
|
||||
# This tool implementation invokes a dynamically configured tool service
|
||||
|
|
|
|||
|
|
@ -117,10 +117,11 @@ class Processor(FlowProcessor):
|
|||
|
||||
try:
|
||||
|
||||
defs = await flow("prompt-request").extract_definitions(
|
||||
result = await flow("prompt-request").extract_definitions(
|
||||
text = chunk
|
||||
)
|
||||
|
||||
defs = result.objects
|
||||
logger.debug(f"Definitions response: {defs}")
|
||||
|
||||
if type(defs) != list:
|
||||
|
|
|
|||
|
|
@ -376,10 +376,11 @@ class Processor(FlowProcessor):
|
|||
"""
|
||||
try:
|
||||
# Call prompt service with simplified format prompt
|
||||
extraction_response = await flow("prompt-request").prompt(
|
||||
result = await flow("prompt-request").prompt(
|
||||
id="extract-with-ontologies",
|
||||
variables=prompt_variables
|
||||
)
|
||||
extraction_response = result.object
|
||||
logger.debug(f"Simplified extraction response: {extraction_response}")
|
||||
|
||||
# Parse response into structured format
|
||||
|
|
|
|||
|
|
@ -100,10 +100,11 @@ class Processor(FlowProcessor):
|
|||
|
||||
try:
|
||||
|
||||
rels = await flow("prompt-request").extract_relationships(
|
||||
result = await flow("prompt-request").extract_relationships(
|
||||
text = chunk
|
||||
)
|
||||
|
||||
rels = result.objects
|
||||
logger.debug(f"Prompt response: {rels}")
|
||||
|
||||
if type(rels) != list:
|
||||
|
|
|
|||
|
|
@ -148,11 +148,12 @@ class Processor(FlowProcessor):
|
|||
schema_dict = row_schema_translator.encode(schema)
|
||||
|
||||
# Use prompt client to extract rows based on schema
|
||||
objects = await flow("prompt-request").extract_objects(
|
||||
result = await flow("prompt-request").extract_objects(
|
||||
schema=schema_dict,
|
||||
text=text
|
||||
)
|
||||
|
||||
objects = result.objects
|
||||
if not isinstance(objects, list):
|
||||
return []
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import logging
|
|||
from ...schema import Definition, Relationship, Triple
|
||||
from ...schema import Topic
|
||||
from ...schema import PromptRequest, PromptResponse, Error
|
||||
from ...schema import TextCompletionRequest, TextCompletionResponse
|
||||
|
||||
from ...base import FlowProcessor
|
||||
from ...base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
|
||||
|
|
@ -124,35 +123,26 @@ class Processor(FlowProcessor):
|
|||
logger.debug(f"System prompt: {system}")
|
||||
logger.debug(f"User prompt: {prompt}")
|
||||
|
||||
# Use the text completion client with recipient handler
|
||||
client = flow("text-completion-request")
|
||||
|
||||
async def forward_chunks(resp):
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
is_final = getattr(resp, 'end_of_stream', False)
|
||||
|
||||
# Always send a message if there's content OR if it's the final message
|
||||
if resp.response or is_final:
|
||||
# Forward each chunk immediately
|
||||
r = PromptResponse(
|
||||
text=resp.response if resp.response else "",
|
||||
object=None,
|
||||
error=None,
|
||||
end_of_stream=is_final,
|
||||
in_token=resp.in_token,
|
||||
out_token=resp.out_token,
|
||||
model=resp.model,
|
||||
)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
||||
# Return True when end_of_stream
|
||||
return is_final
|
||||
|
||||
await client.request(
|
||||
TextCompletionRequest(
|
||||
system=system, prompt=prompt, streaming=True
|
||||
),
|
||||
recipient=forward_chunks,
|
||||
timeout=600
|
||||
await flow("text-completion-request").text_completion_stream(
|
||||
system=system, prompt=prompt,
|
||||
handler=forward_chunks,
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
# Return empty string since we already sent all chunks
|
||||
|
|
@ -167,17 +157,21 @@ class Processor(FlowProcessor):
|
|||
return
|
||||
|
||||
# Non-streaming path (original behavior)
|
||||
usage = {}
|
||||
|
||||
async def llm(system, prompt):
|
||||
|
||||
logger.debug(f"System prompt: {system}")
|
||||
logger.debug(f"User prompt: {prompt}")
|
||||
|
||||
resp = await flow("text-completion-request").text_completion(
|
||||
system = system, prompt = prompt, streaming = False,
|
||||
)
|
||||
|
||||
try:
|
||||
return resp
|
||||
result = await flow("text-completion-request").text_completion(
|
||||
system = system, prompt = prompt,
|
||||
)
|
||||
usage["in_token"] = result.in_token
|
||||
usage["out_token"] = result.out_token
|
||||
usage["model"] = result.model
|
||||
return result.text
|
||||
except Exception as e:
|
||||
logger.error(f"LLM Exception: {e}", exc_info=True)
|
||||
return None
|
||||
|
|
@ -199,6 +193,9 @@ class Processor(FlowProcessor):
|
|||
object=None,
|
||||
error=None,
|
||||
end_of_stream=True,
|
||||
in_token=usage.get("in_token", 0),
|
||||
out_token=usage.get("out_token", 0),
|
||||
model=usage.get("model", ""),
|
||||
)
|
||||
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
|
@ -215,6 +212,9 @@ class Processor(FlowProcessor):
|
|||
object=json.dumps(resp),
|
||||
error=None,
|
||||
end_of_stream=True,
|
||||
in_token=usage.get("in_token", 0),
|
||||
out_token=usage.get("out_token", 0),
|
||||
model=usage.get("model", ""),
|
||||
)
|
||||
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
|
|
|||
|
|
@ -27,24 +27,27 @@ class Query:
|
|||
|
||||
def __init__(
|
||||
self, rag, user, collection, verbose,
|
||||
doc_limit=20
|
||||
doc_limit=20, track_usage=None,
|
||||
):
|
||||
self.rag = rag
|
||||
self.user = user
|
||||
self.collection = collection
|
||||
self.verbose = verbose
|
||||
self.doc_limit = doc_limit
|
||||
self.track_usage = track_usage
|
||||
|
||||
async def extract_concepts(self, query):
|
||||
"""Extract key concepts from query for independent embedding."""
|
||||
response = await self.rag.prompt_client.prompt(
|
||||
result = await self.rag.prompt_client.prompt(
|
||||
"extract-concepts",
|
||||
variables={"query": query}
|
||||
)
|
||||
if self.track_usage:
|
||||
self.track_usage(result)
|
||||
|
||||
concepts = []
|
||||
if isinstance(response, str):
|
||||
for line in response.strip().split('\n'):
|
||||
if result.text:
|
||||
for line in result.text.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if line:
|
||||
concepts.append(line)
|
||||
|
|
@ -167,8 +170,23 @@ class DocumentRag:
|
|||
save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian
|
||||
|
||||
Returns:
|
||||
str: The synthesized answer text
|
||||
tuple: (answer_text, usage) where usage is a dict with
|
||||
in_token, out_token, model
|
||||
"""
|
||||
total_in = 0
|
||||
total_out = 0
|
||||
last_model = None
|
||||
|
||||
def track_usage(result):
|
||||
nonlocal total_in, total_out, last_model
|
||||
if result is not None:
|
||||
if result.in_token is not None:
|
||||
total_in += result.in_token
|
||||
if result.out_token is not None:
|
||||
total_out += result.out_token
|
||||
if result.model is not None:
|
||||
last_model = result.model
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Constructing prompt...")
|
||||
|
||||
|
|
@ -191,7 +209,7 @@ class DocumentRag:
|
|||
|
||||
q = Query(
|
||||
rag=self, user=user, collection=collection, verbose=self.verbose,
|
||||
doc_limit=doc_limit
|
||||
doc_limit=doc_limit, track_usage=track_usage,
|
||||
)
|
||||
|
||||
# Extract concepts from query (grounding step)
|
||||
|
|
@ -228,19 +246,22 @@ class DocumentRag:
|
|||
accumulated_chunks.append(chunk)
|
||||
await chunk_callback(chunk, end_of_stream)
|
||||
|
||||
resp = await self.prompt_client.document_prompt(
|
||||
synthesis_result = await self.prompt_client.document_prompt(
|
||||
query=query,
|
||||
documents=docs,
|
||||
streaming=True,
|
||||
chunk_callback=accumulating_callback
|
||||
)
|
||||
track_usage(synthesis_result)
|
||||
# Combine all chunks into full response
|
||||
resp = "".join(accumulated_chunks)
|
||||
else:
|
||||
resp = await self.prompt_client.document_prompt(
|
||||
synthesis_result = await self.prompt_client.document_prompt(
|
||||
query=query,
|
||||
documents=docs
|
||||
)
|
||||
track_usage(synthesis_result)
|
||||
resp = synthesis_result.text
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Query processing complete")
|
||||
|
|
@ -273,5 +294,11 @@ class DocumentRag:
|
|||
if self.verbose:
|
||||
logger.debug(f"Emitted explain for session {session_id}")
|
||||
|
||||
return resp
|
||||
usage = {
|
||||
"in_token": total_in if total_in > 0 else None,
|
||||
"out_token": total_out if total_out > 0 else None,
|
||||
"model": last_model,
|
||||
}
|
||||
|
||||
return resp, usage
|
||||
|
||||
|
|
|
|||
|
|
@ -200,7 +200,7 @@ class Processor(FlowProcessor):
|
|||
|
||||
# Query with streaming enabled
|
||||
# All chunks (including final one with end_of_stream=True) are sent via callback
|
||||
await self.rag.query(
|
||||
response, usage = await self.rag.query(
|
||||
v.query,
|
||||
user=v.user,
|
||||
collection=v.collection,
|
||||
|
|
@ -217,12 +217,15 @@ class Processor(FlowProcessor):
|
|||
response=None,
|
||||
end_of_session=True,
|
||||
message_type="end",
|
||||
in_token=usage.get("in_token"),
|
||||
out_token=usage.get("out_token"),
|
||||
model=usage.get("model"),
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
else:
|
||||
# Non-streaming path (existing behavior)
|
||||
response = await self.rag.query(
|
||||
# Non-streaming path - single response with answer and token usage
|
||||
response, usage = await self.rag.query(
|
||||
v.query,
|
||||
user=v.user,
|
||||
collection=v.collection,
|
||||
|
|
@ -233,11 +236,15 @@ class Processor(FlowProcessor):
|
|||
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
response = response,
|
||||
end_of_stream = True,
|
||||
error = None
|
||||
response=response,
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
error=None,
|
||||
in_token=usage.get("in_token"),
|
||||
out_token=usage.get("out_token"),
|
||||
model=usage.get("model"),
|
||||
),
|
||||
properties = {"id": id}
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
logger.info("Request processing complete")
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ class Query:
|
|||
def __init__(
|
||||
self, rag, user, collection, verbose,
|
||||
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
|
||||
max_path_length=2,
|
||||
max_path_length=2, track_usage=None,
|
||||
):
|
||||
self.rag = rag
|
||||
self.user = user
|
||||
|
|
@ -131,17 +131,20 @@ class Query:
|
|||
self.triple_limit = triple_limit
|
||||
self.max_subgraph_size = max_subgraph_size
|
||||
self.max_path_length = max_path_length
|
||||
self.track_usage = track_usage
|
||||
|
||||
async def extract_concepts(self, query):
|
||||
"""Extract key concepts from query for independent embedding."""
|
||||
response = await self.rag.prompt_client.prompt(
|
||||
result = await self.rag.prompt_client.prompt(
|
||||
"extract-concepts",
|
||||
variables={"query": query}
|
||||
)
|
||||
if self.track_usage:
|
||||
self.track_usage(result)
|
||||
|
||||
concepts = []
|
||||
if isinstance(response, str):
|
||||
for line in response.strip().split('\n'):
|
||||
if result.text:
|
||||
for line in result.text.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if line:
|
||||
concepts.append(line)
|
||||
|
|
@ -609,8 +612,24 @@ class GraphRag:
|
|||
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
|
||||
|
||||
Returns:
|
||||
str: The synthesized answer text
|
||||
tuple: (answer_text, usage) where usage is a dict with
|
||||
in_token, out_token, model
|
||||
"""
|
||||
# Accumulate token usage across all prompt calls
|
||||
total_in = 0
|
||||
total_out = 0
|
||||
last_model = None
|
||||
|
||||
def track_usage(result):
|
||||
nonlocal total_in, total_out, last_model
|
||||
if result is not None:
|
||||
if result.in_token is not None:
|
||||
total_in += result.in_token
|
||||
if result.out_token is not None:
|
||||
total_out += result.out_token
|
||||
if result.model is not None:
|
||||
last_model = result.model
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Constructing prompt...")
|
||||
|
||||
|
|
@ -641,6 +660,7 @@ class GraphRag:
|
|||
triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
track_usage = track_usage,
|
||||
)
|
||||
|
||||
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
|
||||
|
|
@ -751,21 +771,22 @@ class GraphRag:
|
|||
logger.debug(f"Built edge map with {len(edge_map)} edges")
|
||||
|
||||
# Step 1a: Edge Scoring - LLM scores edges for relevance
|
||||
scoring_response = await self.prompt_client.prompt(
|
||||
scoring_result = await self.prompt_client.prompt(
|
||||
"kg-edge-scoring",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": edges_with_ids
|
||||
}
|
||||
)
|
||||
track_usage(scoring_result)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge scoring response: {scoring_response}")
|
||||
logger.debug(f"Edge scoring result: {scoring_result}")
|
||||
|
||||
# Parse scoring response to get edge IDs with scores
|
||||
# Parse scoring response (jsonl) to get edge IDs with scores
|
||||
scored_edges = []
|
||||
|
||||
def parse_scored_edge(obj):
|
||||
for obj in scoring_result.objects or []:
|
||||
if isinstance(obj, dict) and "id" in obj and "score" in obj:
|
||||
try:
|
||||
score = int(obj["score"])
|
||||
|
|
@ -773,21 +794,6 @@ class GraphRag:
|
|||
score = 0
|
||||
scored_edges.append({"id": obj["id"], "score": score})
|
||||
|
||||
if isinstance(scoring_response, list):
|
||||
for obj in scoring_response:
|
||||
parse_scored_edge(obj)
|
||||
elif isinstance(scoring_response, str):
|
||||
for line in scoring_response.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
parse_scored_edge(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse edge scoring line: {line}"
|
||||
)
|
||||
|
||||
# Select top N edges by score
|
||||
scored_edges.sort(key=lambda x: x["score"], reverse=True)
|
||||
top_edges = scored_edges[:edge_limit]
|
||||
|
|
@ -821,25 +827,30 @@ class GraphRag:
|
|||
]
|
||||
|
||||
# Run reasoning and document tracing concurrently
|
||||
reasoning_task = self.prompt_client.prompt(
|
||||
"kg-edge-reasoning",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edges_with_ids
|
||||
}
|
||||
)
|
||||
async def _get_reasoning():
|
||||
result = await self.prompt_client.prompt(
|
||||
"kg-edge-reasoning",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edges_with_ids
|
||||
}
|
||||
)
|
||||
track_usage(result)
|
||||
return result
|
||||
|
||||
reasoning_task = _get_reasoning()
|
||||
doc_trace_task = q.trace_source_documents(selected_edge_uris)
|
||||
|
||||
reasoning_response, source_documents = await asyncio.gather(
|
||||
reasoning_result, source_documents = await asyncio.gather(
|
||||
reasoning_task, doc_trace_task, return_exceptions=True
|
||||
)
|
||||
|
||||
# Handle exceptions from gather
|
||||
if isinstance(reasoning_response, Exception):
|
||||
if isinstance(reasoning_result, Exception):
|
||||
logger.warning(
|
||||
f"Edge reasoning failed: {reasoning_response}"
|
||||
f"Edge reasoning failed: {reasoning_result}"
|
||||
)
|
||||
reasoning_response = ""
|
||||
reasoning_result = None
|
||||
if isinstance(source_documents, Exception):
|
||||
logger.warning(
|
||||
f"Document tracing failed: {source_documents}"
|
||||
|
|
@ -848,29 +859,15 @@ class GraphRag:
|
|||
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge reasoning response: {reasoning_response}")
|
||||
logger.debug(f"Edge reasoning result: {reasoning_result}")
|
||||
|
||||
# Parse reasoning response and build explainability data
|
||||
# Parse reasoning response (jsonl) and build explainability data
|
||||
reasoning_map = {}
|
||||
|
||||
def parse_reasoning(obj):
|
||||
if isinstance(obj, dict) and "id" in obj:
|
||||
reasoning_map[obj["id"]] = obj.get("reasoning", "")
|
||||
|
||||
if isinstance(reasoning_response, list):
|
||||
for obj in reasoning_response:
|
||||
parse_reasoning(obj)
|
||||
elif isinstance(reasoning_response, str):
|
||||
for line in reasoning_response.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
parse_reasoning(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse edge reasoning line: {line}"
|
||||
)
|
||||
if reasoning_result is not None:
|
||||
for obj in reasoning_result.objects or []:
|
||||
if isinstance(obj, dict) and "id" in obj:
|
||||
reasoning_map[obj["id"]] = obj.get("reasoning", "")
|
||||
|
||||
selected_edges_with_reasoning = []
|
||||
for eid in selected_ids:
|
||||
|
|
@ -919,19 +916,22 @@ class GraphRag:
|
|||
accumulated_chunks.append(chunk)
|
||||
await chunk_callback(chunk, end_of_stream)
|
||||
|
||||
await self.prompt_client.prompt(
|
||||
synthesis_result = await self.prompt_client.prompt(
|
||||
"kg-synthesis",
|
||||
variables=synthesis_variables,
|
||||
streaming=True,
|
||||
chunk_callback=accumulating_callback
|
||||
)
|
||||
track_usage(synthesis_result)
|
||||
# Combine all chunks into full response
|
||||
resp = "".join(accumulated_chunks)
|
||||
else:
|
||||
resp = await self.prompt_client.prompt(
|
||||
synthesis_result = await self.prompt_client.prompt(
|
||||
"kg-synthesis",
|
||||
variables=synthesis_variables,
|
||||
)
|
||||
track_usage(synthesis_result)
|
||||
resp = synthesis_result.text
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Query processing complete")
|
||||
|
|
@ -964,5 +964,11 @@ class GraphRag:
|
|||
if self.verbose:
|
||||
logger.debug(f"Emitted explain for session {session_id}")
|
||||
|
||||
return resp
|
||||
usage = {
|
||||
"in_token": total_in if total_in > 0 else None,
|
||||
"out_token": total_out if total_out > 0 else None,
|
||||
"model": last_model,
|
||||
}
|
||||
|
||||
return resp, usage
|
||||
|
||||
|
|
|
|||
|
|
@ -332,7 +332,7 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
|
||||
# Query with streaming and real-time explain
|
||||
response = await rag.query(
|
||||
response, usage = await rag.query(
|
||||
query = v.query, user = v.user, collection = v.collection,
|
||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
|
|
@ -348,7 +348,7 @@ class Processor(FlowProcessor):
|
|||
|
||||
else:
|
||||
# Non-streaming path with real-time explain
|
||||
response = await rag.query(
|
||||
response, usage = await rag.query(
|
||||
query = v.query, user = v.user, collection = v.collection,
|
||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
|
|
@ -360,23 +360,30 @@ class Processor(FlowProcessor):
|
|||
parent_uri = v.parent_uri,
|
||||
)
|
||||
|
||||
# Send chunk with response
|
||||
# Send single response with answer and token usage
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response=response,
|
||||
end_of_stream=True,
|
||||
error=None,
|
||||
end_of_session=True,
|
||||
in_token=usage.get("in_token"),
|
||||
out_token=usage.get("out_token"),
|
||||
model=usage.get("model"),
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
return
|
||||
|
||||
# Send final message to close session
|
||||
# Streaming: send final message to close session with token usage
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="",
|
||||
end_of_session=True,
|
||||
in_token=usage.get("in_token"),
|
||||
out_token=usage.get("out_token"),
|
||||
model=usage.get("model"),
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue