Mirror of https://github.com/trustgraph-ai/trustgraph.git (synced 2026-05-10 15:52:36 +02:00)
Expose LLM token usage across all service layers (#782)
Expose LLM token usage (in_token, out_token, model) across all service layers. Propagate token counts from LLM services through the prompt, text-completion, graph-RAG, document-RAG, and agent orchestrator pipelines to the API gateway and Python SDK. All fields are Optional: None means "not available", distinguishing it from a real zero count.

Key changes:
- Schema: Add in_token/out_token/model to TextCompletionResponse, PromptResponse, GraphRagResponse, DocumentRagResponse, AgentResponse
- TextCompletionClient: New TextCompletionResult return type. Split into text_completion() (non-streaming) and text_completion_stream() (streaming with per-chunk handler callback)
- PromptClient: New PromptResult with response_type (text/json/jsonl), typed fields (text/object/objects), and token usage. All callers updated.
- RAG services: Accumulate token usage across all prompt calls (extract-concepts, edge-scoring, edge-reasoning, synthesis). Non-streaming path sends a single combined response instead of chunk + end_of_session.
- Agent orchestrator: UsageTracker accumulates tokens across meta-router, pattern prompt calls, and ReAct reasoning. Attached to end_of_dialog.
- Translators: Encode token fields when not None (is not None, not truthy)
- Python SDK: RAG and text-completion methods return TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with token fields (streaming)
- CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt, tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent
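For orientation, here is a minimal sketch of what the two new result types might look like, inferred from the field names in the commit message and the constructor calls visible in the tests below. It is not the exact definitions in trustgraph.base or trustgraph.base.text_completion_client; defaults and any fields beyond those named above are assumptions.

```python
# Illustrative sketch only; the real definitions live in trustgraph.base and
# trustgraph.base.text_completion_client. Defaults are assumptions.
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class TextCompletionResult:
    # Completion text; may be None when the text was delivered via streaming chunks
    text: Optional[str] = None
    # Token usage reported by the LLM service; None means "not available",
    # which is distinct from a genuine zero count
    in_token: Optional[int] = None
    out_token: Optional[int] = None
    model: Optional[str] = None

@dataclass
class PromptResult:
    # One of "text", "json", "jsonl"
    response_type: str = "text"
    # Which of these is populated depends on response_type
    text: Optional[str] = None
    object: Optional[Any] = None
    objects: Optional[list] = None
    # Accumulated token usage for the underlying prompt call(s)
    in_token: Optional[int] = None
    out_token: Optional[int] = None
    model: Optional[str] = None
```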
Parent: 67cfa80836
Commit: 14e49d83c7
60 changed files with 1252 additions and 577 deletions
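Before the diff itself, a short sketch of the None-versus-zero accumulation semantics the RAG services and agent orchestrator apply to these Optional counts, as described in the commit message. The UsageAccumulator name, its method, and the model string are illustrative assumptions; only the behaviour (never turning "unknown" into 0) is taken from the message.

```python
# Hypothetical accumulator in the spirit of the UsageTracker mentioned above;
# the real class's API is not shown here, so names are assumptions.
from typing import Optional

class UsageAccumulator:
    def __init__(self):
        self.in_token: Optional[int] = None
        self.out_token: Optional[int] = None
        self.model: Optional[str] = None

    def add(self, in_token: Optional[int], out_token: Optional[int],
            model: Optional[str]) -> None:
        # Only fold in counts that were actually reported: None stays None
        # unless a real value arrives, so "unavailable" is never reported as 0.
        if in_token is not None:
            self.in_token = (self.in_token or 0) + in_token
        if out_token is not None:
            self.out_token = (self.out_token or 0) + out_token
        if model is not None:
            self.model = model

# e.g. summing usage over several prompt calls in one RAG request
usage = UsageAccumulator()
usage.add(120, 45, "example-model")
usage.add(None, None, None)   # a call whose backend reported no usage
usage.add(80, 30, "example-model")
print(usage.in_token, usage.out_token, usage.model)  # 200 75 example-model
```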
@@ -15,6 +15,7 @@ from trustgraph.agent.react.agent_manager import AgentManager
 from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
 from trustgraph.agent.react.types import Action, Final, Tool, Argument
 from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error
+from trustgraph.base import PromptResult


 @pytest.mark.integration
@ -28,19 +29,25 @@ class TestAgentManagerIntegration:
|
|||
|
||||
# Mock prompt client
|
||||
prompt_client = AsyncMock()
|
||||
prompt_client.agent_react.return_value = """Thought: I need to search for information about machine learning
|
||||
prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search for information about machine learning
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is machine learning?"
|
||||
}"""
|
||||
|
||||
)
|
||||
|
||||
# Mock graph RAG client
|
||||
graph_rag_client = AsyncMock()
|
||||
graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data."
|
||||
|
||||
|
||||
# Mock text completion client
|
||||
text_completion_client = AsyncMock()
|
||||
text_completion_client.question.return_value = "Machine learning involves algorithms that improve through experience."
|
||||
text_completion_client.question.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="Machine learning involves algorithms that improve through experience."
|
||||
)
|
||||
|
||||
# Mock MCP tool client
|
||||
mcp_tool_client = AsyncMock()
|
||||
|
|
@ -147,8 +154,11 @@ Args: {
|
|||
async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager returning final answer"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have enough information to answer the question
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I have enough information to answer the question
|
||||
Final Answer: Machine learning is a field of AI that enables computers to learn from data."""
|
||||
)
|
||||
|
||||
question = "What is machine learning?"
|
||||
history = []
|
||||
|
|
@ -193,8 +203,11 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
|
|||
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
|
||||
"""Test ReAct cycle ending with final answer"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide a direct answer
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I can provide a direct answer
|
||||
Final Answer: Machine learning is a branch of artificial intelligence."""
|
||||
)
|
||||
|
||||
question = "What is machine learning?"
|
||||
history = []
|
||||
|
|
@ -254,11 +267,14 @@ Final Answer: Machine learning is a branch of artificial intelligence."""
|
|||
|
||||
for tool_name, expected_service in tool_scenarios:
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: I need to use {tool_name}
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=f"""Thought: I need to use {tool_name}
|
||||
Action: {tool_name}
|
||||
Args: {{
|
||||
"question": "test question"
|
||||
}}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -284,11 +300,14 @@ Args: {{
|
|||
async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager error handling for unknown tool"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to use an unknown tool
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to use an unknown tool
|
||||
Action: unknown_tool
|
||||
Args: {
|
||||
"param": "value"
|
||||
}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -321,11 +340,14 @@ Args: {
|
|||
question = "Find information about AI and summarize it"
|
||||
|
||||
# Mock multi-step reasoning
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search for AI information first
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search for AI information first
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is artificial intelligence?"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Act
|
||||
action = await agent_manager.reason(question, [], mock_flow_context)
|
||||
|
|
@ -372,9 +394,12 @@ Args: {
|
|||
# Format arguments as JSON
|
||||
import json
|
||||
args_json = json.dumps(test_case['arguments'], indent=4)
|
||||
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: Using {test_case['action']}
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=f"""Thought: Using {test_case['action']}
|
||||
Action: {test_case['action']}
|
||||
Args: {args_json}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -507,7 +532,10 @@ Args: {
|
|||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
mock_flow_context("prompt-request").agent_react.return_value = test_case["response"]
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=test_case["response"]
|
||||
)
|
||||
|
||||
if test_case["error_contains"]:
|
||||
# Should raise an error
|
||||
|
|
@ -527,13 +555,16 @@ Args: {
|
|||
async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context):
|
||||
"""Test edge cases in text parsing"""
|
||||
# Test response with markdown code blocks
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """```
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""```
|
||||
Thought: I need to search for information
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is AI?"
|
||||
}
|
||||
```"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -541,15 +572,18 @@ Args: {
|
|||
assert action.name == "knowledge_query"
|
||||
|
||||
# Test response with extra whitespace
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""
|
||||
|
||||
Thought: I need to think about this
|
||||
Action: knowledge_query
|
||||
Thought: I need to think about this
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "test"
|
||||
}
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -560,7 +594,9 @@ Args: {
|
|||
async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context):
|
||||
"""Test handling of multi-line thoughts and final answers"""
|
||||
# Multi-line thought
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to consider multiple factors:
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to consider multiple factors:
|
||||
1. The user's question is complex
|
||||
2. I should search for comprehensive information
|
||||
3. This requires using the knowledge query tool
|
||||
|
|
@ -568,6 +604,7 @@ Action: knowledge_query
|
|||
Args: {
|
||||
"question": "complex query"
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -575,13 +612,16 @@ Args: {
|
|||
assert "knowledge query tool" in action.thought
|
||||
|
||||
# Multi-line final answer
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have gathered enough information
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I have gathered enough information
|
||||
Final Answer: Here is a comprehensive answer:
|
||||
1. First point about the topic
|
||||
2. Second point with details
|
||||
3. Final conclusion
|
||||
|
||||
This covers all aspects of the question."""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Final)
|
||||
|
|
@ -593,13 +633,16 @@ This covers all aspects of the question."""
|
|||
async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context):
|
||||
"""Test JSON arguments with special characters and edge cases"""
|
||||
# Test with special characters in JSON (properly escaped)
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Processing special characters
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: Processing special characters
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What about \\"quotes\\" and 'apostrophes'?",
|
||||
"context": "Line 1\\nLine 2\\tTabbed",
|
||||
"special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?"
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -608,7 +651,9 @@ Args: {
|
|||
assert "@#$%^&*" in action.arguments["special"]
|
||||
|
||||
# Test with nested JSON
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Complex arguments
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: Complex arguments
|
||||
Action: web_search
|
||||
Args: {
|
||||
"query": "test",
|
||||
|
|
@ -621,6 +666,7 @@ Args: {
|
|||
}
|
||||
}
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -632,7 +678,9 @@ Args: {
|
|||
async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context):
|
||||
"""Test final answers that contain JSON-like content"""
|
||||
# Final answer with JSON content
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide the data in JSON format
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I can provide the data in JSON format
|
||||
Final Answer: {
|
||||
"result": "success",
|
||||
"data": {
|
||||
|
|
@ -642,6 +690,7 @@ Final Answer: {
|
|||
},
|
||||
"confidence": 0.95
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Final)
|
||||
|
|
@ -792,11 +841,14 @@ Final Answer: {
|
|||
agent = AgentManager(tools=custom_tools, additional_context="")
|
||||
|
||||
# Mock response for custom collection query
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search in the research papers
|
||||
Action: knowledge_query_custom
|
||||
Args: {
|
||||
"question": "Latest AI research?"
|
||||
}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
|
|||
|
|
@@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock
 from trustgraph.agent.react.agent_manager import AgentManager
 from trustgraph.agent.react.tools import KnowledgeQueryImpl
 from trustgraph.agent.react.types import Tool, Argument
+from trustgraph.base import PromptResult
 from tests.utils.streaming_assertions import (
     assert_agent_streaming_chunks,
     assert_streaming_chunks_valid,
|
|
@ -51,10 +52,10 @@ Args: {
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
# Non-streaming response - same text
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
|
||||
client.agent_react.side_effect = agent_react_streaming
|
||||
return client
|
||||
|
|
@ -317,8 +318,8 @@ Final Answer: AI is the simulation of human intelligence in machines."""
|
|||
for i, chunk in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk + " ", is_final)
|
||||
return response
|
||||
return response
|
||||
return PromptResult(response_type="text", text=response)
|
||||
return PromptResult(response_type="text", text=response)
|
||||
|
||||
mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react
|
||||
|
||||
|
|
|
|||
|
|
@@ -16,6 +16,7 @@ from trustgraph.schema import (
     Error
 )
 from trustgraph.agent.react.service import Processor
+from trustgraph.base import PromptResult


 @pytest.mark.integration
|
|
@ -95,11 +96,14 @@ class TestAgentStructuredQueryIntegration:
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from New York using structured query
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to find customers from New York using structured query
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find all customers from New York"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -173,11 +177,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to query for a table that might not exist
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to query for a table that might not exist
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find data from a table that doesn't exist"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -250,11 +257,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from California first
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to find customers from California first
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find all customers from California"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -339,11 +349,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to query the sales data
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to query the sales data
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Query the sales data for recent transactions"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -447,11 +460,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to get customer information
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to get customer information
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Get customer information and format it nicely"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
|
|||
|
|
@@ -10,6 +10,7 @@ import pytest
 from unittest.mock import AsyncMock, MagicMock
 from trustgraph.retrieval.document_rag.document_rag import DocumentRag
 from trustgraph.schema import ChunkMatch
+from trustgraph.base import PromptResult


 # Sample chunk content for testing - maps chunk_id to content
|
|
@ -61,11 +62,16 @@ class TestDocumentRagIntegration:
|
|||
def mock_prompt_client(self):
|
||||
"""Mock prompt client that generates realistic responses"""
|
||||
client = AsyncMock()
|
||||
client.document_prompt.return_value = (
|
||||
"Machine learning is a field of artificial intelligence that enables computers to learn "
|
||||
"and improve from experience without being explicitly programmed. It uses algorithms "
|
||||
"to find patterns in data and make predictions or decisions."
|
||||
client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=(
|
||||
"Machine learning is a field of artificial intelligence that enables computers to learn "
|
||||
"and improve from experience without being explicitly programmed. It uses algorithms "
|
||||
"to find patterns in data and make predictions or decisions."
|
||||
)
|
||||
)
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -119,6 +125,7 @@ class TestDocumentRagIntegration:
|
|||
)
|
||||
|
||||
# Verify final response
|
||||
result, usage = result
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert "machine learning" in result.lower()
|
||||
|
|
@ -131,7 +138,11 @@ class TestDocumentRagIntegration:
|
|||
"""Test DocumentRAG behavior when no documents are retrieved"""
|
||||
# Arrange
|
||||
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
|
||||
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="I couldn't find any relevant documents for your query."
|
||||
)
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -152,7 +163,8 @@ class TestDocumentRagIntegration:
|
|||
documents=[]
|
||||
)
|
||||
|
||||
assert result == "I couldn't find any relevant documents for your query."
|
||||
result_text, usage = result
|
||||
assert result_text == "I couldn't find any relevant documents for your query."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
|
||||
|
|
|
|||
|
|
@@ -9,6 +9,7 @@ import pytest
 from unittest.mock import AsyncMock
 from trustgraph.retrieval.document_rag.document_rag import DocumentRag
 from trustgraph.schema import ChunkMatch
+from trustgraph.base import PromptResult
 from tests.utils.streaming_assertions import (
     assert_streaming_chunks_valid,
     assert_callback_invoked,
|
|
@ -74,12 +75,14 @@ class TestDocumentRagStreaming:
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
# Non-streaming response - same text
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
|
||||
client.document_prompt.side_effect = document_prompt_side_effect
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -119,11 +122,12 @@ class TestDocumentRagStreaming:
|
|||
collector.verify_streaming_protocol()
|
||||
|
||||
# Verify full response matches concatenated chunks
|
||||
result_text, usage = result
|
||||
full_from_chunks = collector.get_full_text()
|
||||
assert result == full_from_chunks
|
||||
assert result_text == full_from_chunks
|
||||
|
||||
# Verify content is reasonable
|
||||
assert len(result) > 0
|
||||
assert len(result_text) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming):
|
||||
|
|
@ -159,9 +163,11 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Results should be equivalent
|
||||
assert streaming_result == non_streaming_result
|
||||
non_streaming_text, _ = non_streaming_result
|
||||
streaming_text, _ = streaming_result
|
||||
assert streaming_text == non_streaming_text
|
||||
assert len(streaming_chunks) > 0
|
||||
assert "".join(streaming_chunks) == streaming_result
|
||||
assert "".join(streaming_chunks) == streaming_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming):
|
||||
|
|
@ -180,8 +186,9 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
result_text, usage = result
|
||||
assert callback.call_count > 0
|
||||
assert result is not None
|
||||
assert result_text is not None
|
||||
|
||||
# Verify all callback invocations had string arguments
|
||||
for call in callback.call_args_list:
|
||||
|
|
@ -202,7 +209,8 @@ class TestDocumentRagStreaming:
|
|||
|
||||
# Assert - Should complete without error
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
result_text, usage = result
|
||||
assert isinstance(result_text, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming,
|
||||
|
|
@ -223,7 +231,8 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Should still produce streamed response
|
||||
assert result is not None
|
||||
result_text, usage = result
|
||||
assert result_text is not None
|
||||
assert callback.call_count > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -271,7 +280,8 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
result_text, usage = result
|
||||
assert result_text is not None
|
||||
assert callback.call_count > 0
|
||||
|
||||
# Verify doc_limit was passed correctly
|
||||
|
|
|
|||
|
|
@@ -12,6 +12,7 @@ import pytest
 from unittest.mock import AsyncMock, MagicMock
 from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
 from trustgraph.schema import EntityMatch, Term, IRI
+from trustgraph.base import PromptResult


 @pytest.mark.integration
|
|
@ -93,18 +94,21 @@ class TestGraphRagIntegration:
|
|||
# 4. kg-synthesis returns the final answer
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return "" # No edges scored
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return "" # No reasoning
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return (
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=(
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
)
|
||||
)
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = mock_prompt
|
||||
return client
|
||||
|
|
@ -169,6 +173,7 @@ class TestGraphRagIntegration:
|
|||
assert mock_prompt_client.prompt.call_count == 4
|
||||
|
||||
# Verify final response
|
||||
response, usage = response
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
assert "machine learning" in response.lower()
|
||||
|
|
|
|||
|
|
@@ -9,6 +9,7 @@ import pytest
 from unittest.mock import AsyncMock, MagicMock
 from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
 from trustgraph.schema import EntityMatch, Term, IRI
+from trustgraph.base import PromptResult
 from tests.utils.streaming_assertions import (
     assert_streaming_chunks_valid,
     assert_rag_streaming_chunks,
|
|
@ -61,12 +62,12 @@ class TestGraphRagStreaming:
|
|||
|
||||
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
|
||||
if prompt_id == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_id == "kg-edge-scoring":
|
||||
# Edge scoring returns JSONL with IDs and scores
|
||||
return '{"id": "abc12345", "score": 0.9}\n'
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
|
||||
elif prompt_id == "kg-edge-reasoning":
|
||||
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
|
||||
elif prompt_id == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate streaming chunks with end_of_stream flags
|
||||
|
|
@ -79,10 +80,10 @@ class TestGraphRagStreaming:
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
return full_text
|
||||
return ""
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
|
@ -123,6 +124,7 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
response, usage = response
|
||||
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
|
||||
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
|
||||
|
||||
|
|
@ -172,9 +174,11 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Results should be equivalent
|
||||
assert streaming_response == non_streaming_response
|
||||
non_streaming_text, _ = non_streaming_response
|
||||
streaming_text, _ = streaming_response
|
||||
assert streaming_text == non_streaming_text
|
||||
assert len(streaming_chunks) > 0
|
||||
assert "".join(streaming_chunks) == streaming_response
|
||||
assert "".join(streaming_chunks) == streaming_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
|
||||
|
|
@ -213,7 +217,8 @@ class TestGraphRagStreaming:
|
|||
|
||||
# Assert - Should complete without error
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
response_text, usage = response
|
||||
assert isinstance(response_text, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,
|
||||
|
|
|
|||
|
|
@@ -18,6 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProces
 from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
 from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
 from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
+from trustgraph.base import PromptResult


 @pytest.mark.integration
|
|
@ -31,32 +32,38 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
|
||||
# Mock prompt client for definitions extraction
|
||||
prompt_client = AsyncMock()
|
||||
prompt_client.extract_definitions.return_value = [
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
]
|
||||
|
||||
prompt_client.extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Mock prompt client for relationships extraction
|
||||
prompt_client.extract_relationships.return_value = [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "is_used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
prompt_client.extract_relationships.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "is_used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Mock producers for output streams
|
||||
triples_producer = AsyncMock()
|
||||
|
|
@ -489,7 +496,10 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
"""Test handling of empty extraction results"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = []
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
|
|
@ -510,7 +520,10 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
"""Test handling of invalid extraction response format"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format" # Should be list
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="invalid format"
|
||||
) # Should be jsonl with objects list
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
|
|
@ -528,13 +541,16 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context):
|
||||
"""Test entity filtering and validation in extraction"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = [
|
||||
{"entity": "Valid Entity", "definition": "Valid definition"},
|
||||
{"entity": "", "definition": "Empty entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
|
||||
{"entity": None, "definition": "None entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
|
||||
]
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"entity": "Valid Entity", "definition": "Valid definition"},
|
||||
{"entity": "", "definition": "Empty entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
|
||||
{"entity": None, "definition": "None entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
|
||||
]
|
||||
)
|
||||
|
||||
sample_chunk = Chunk(
|
||||
metadata=Metadata(id="test", user="user", collection="collection"),
|
||||
|
|
|
|||
|
|
@@ -16,6 +16,7 @@ from trustgraph.schema import (
     Chunk, ExtractedObject, Metadata, RowSchema, Field,
     PromptRequest, PromptResponse
 )
+from trustgraph.base import PromptResult


 @pytest.mark.integration
|
|
@ -114,49 +115,61 @@ class TestObjectExtractionServiceIntegration:
|
|||
schema_name = schema.get("name") if isinstance(schema, dict) else schema.name
|
||||
if schema_name == "customer_records":
|
||||
if "john" in text.lower():
|
||||
return [
|
||||
{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Smith",
|
||||
"email": "john.smith@email.com",
|
||||
"phone": "555-0123"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Smith",
|
||||
"email": "john.smith@email.com",
|
||||
"phone": "555-0123"
|
||||
}
|
||||
]
|
||||
)
|
||||
elif "jane" in text.lower():
|
||||
return [
|
||||
{
|
||||
"customer_id": "CUST002",
|
||||
"name": "Jane Doe",
|
||||
"email": "jane.doe@email.com",
|
||||
"phone": ""
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"customer_id": "CUST002",
|
||||
"name": "Jane Doe",
|
||||
"email": "jane.doe@email.com",
|
||||
"phone": ""
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
return []
|
||||
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
elif schema_name == "product_catalog":
|
||||
if "laptop" in text.lower():
|
||||
return [
|
||||
{
|
||||
"product_id": "PROD001",
|
||||
"name": "Gaming Laptop",
|
||||
"price": "1299.99",
|
||||
"category": "electronics"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"product_id": "PROD001",
|
||||
"name": "Gaming Laptop",
|
||||
"price": "1299.99",
|
||||
"category": "electronics"
|
||||
}
|
||||
]
|
||||
)
|
||||
elif "book" in text.lower():
|
||||
return [
|
||||
{
|
||||
"product_id": "PROD002",
|
||||
"name": "Python Programming Guide",
|
||||
"price": "49.99",
|
||||
"category": "books"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"product_id": "PROD002",
|
||||
"name": "Python Programming Guide",
|
||||
"price": "49.99",
|
||||
"category": "books"
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
return []
|
||||
|
||||
return []
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
prompt_client.extract_objects.side_effect = mock_extract_objects
|
||||
|
||||
|
|
|
|||
|
|
@@ -9,6 +9,7 @@ import pytest
 from unittest.mock import AsyncMock, MagicMock
 from trustgraph.prompt.template.service import Processor
 from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse
+from trustgraph.base.text_completion_client import TextCompletionResult
 from tests.utils.streaming_assertions import (
     assert_streaming_chunks_valid,
     assert_callback_invoked,
|
|
@ -27,34 +28,52 @@ class TestPromptStreaming:
|
|||
# Mock text completion client with streaming
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def streaming_request(request, recipient=None, timeout=600):
|
||||
"""Simulate streaming text completion"""
|
||||
if request.streaming and recipient:
|
||||
# Simulate streaming chunks
|
||||
chunks = [
|
||||
"Machine", " learning", " is", " a", " field",
|
||||
" of", " artificial", " intelligence", "."
|
||||
]
|
||||
# Streaming chunks to send
|
||||
chunks = [
|
||||
"Machine", " learning", " is", " a", " field",
|
||||
" of", " artificial", " intelligence", "."
|
||||
]
|
||||
|
||||
for i, chunk_text in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
response = TextCompletionResponse(
|
||||
response=chunk_text,
|
||||
error=None,
|
||||
end_of_stream=is_final
|
||||
)
|
||||
final = await recipient(response)
|
||||
if final:
|
||||
break
|
||||
|
||||
# Final empty chunk
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
async def streaming_text_completion_stream(system, prompt, handler, timeout=600):
|
||||
"""Simulate streaming text completion via text_completion_stream"""
|
||||
for i, chunk_text in enumerate(chunks):
|
||||
response = TextCompletionResponse(
|
||||
response=chunk_text,
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
end_of_stream=False
|
||||
)
|
||||
await handler(response)
|
||||
|
||||
text_completion_client.request = streaming_request
|
||||
# Send final empty chunk with end_of_stream
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
|
||||
return TextCompletionResult(
|
||||
text=None,
|
||||
in_token=10,
|
||||
out_token=9,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
async def non_streaming_text_completion(system, prompt, timeout=600):
|
||||
"""Simulate non-streaming text completion"""
|
||||
full_text = "Machine learning is a field of artificial intelligence."
|
||||
return TextCompletionResult(
|
||||
text=full_text,
|
||||
in_token=10,
|
||||
out_token=9,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=streaming_text_completion_stream
|
||||
)
|
||||
text_completion_client.text_completion = AsyncMock(
|
||||
side_effect=non_streaming_text_completion
|
||||
)
|
||||
|
||||
# Mock response producer
|
||||
response_producer = AsyncMock()
|
||||
|
|
@ -156,14 +175,6 @@ class TestPromptStreaming:
|
|||
|
||||
consumer = MagicMock()
|
||||
|
||||
# Mock non-streaming text completion
|
||||
text_completion_client = mock_flow_context_streaming("text-completion-request")
|
||||
|
||||
async def non_streaming_text_completion(system, prompt, streaming=False):
|
||||
return "AI is the simulation of human intelligence in machines."
|
||||
|
||||
text_completion_client.text_completion = non_streaming_text_completion
|
||||
|
||||
# Act
|
||||
await prompt_processor_streaming.on_request(
|
||||
message, consumer, mock_flow_context_streaming
|
||||
|
|
@ -218,17 +229,12 @@ class TestPromptStreaming:
|
|||
# Mock text completion client that raises an error
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def failing_request(request, recipient=None, timeout=600):
|
||||
if recipient:
|
||||
# Send error response with proper Error schema
|
||||
error_response = TextCompletionResponse(
|
||||
response="",
|
||||
error=Error(message="Text completion error", type="processing_error"),
|
||||
end_of_stream=True
|
||||
)
|
||||
await recipient(error_response)
|
||||
async def failing_stream(system, prompt, handler, timeout=600):
|
||||
raise RuntimeError("Text completion error")
|
||||
|
||||
text_completion_client.request = failing_request
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=failing_stream
|
||||
)
|
||||
|
||||
# Mock response producer to capture error response
|
||||
response_producer = AsyncMock()
|
||||
|
|
@ -255,22 +261,15 @@ class TestPromptStreaming:
|
|||
|
||||
consumer = MagicMock()
|
||||
|
||||
# Act - The service catches errors and sends error responses, doesn't raise
|
||||
# Act - The service catches errors and sends an error PromptResponse
|
||||
await prompt_processor_streaming.on_request(message, consumer, context)
|
||||
|
||||
# Assert - Verify error response was sent
|
||||
assert response_producer.send.call_count > 0
|
||||
|
||||
# Check that at least one response contains an error
|
||||
error_sent = False
|
||||
for call in response_producer.send.call_args_list:
|
||||
response = call.args[0]
|
||||
if hasattr(response, 'error') and response.error:
|
||||
error_sent = True
|
||||
assert "Text completion error" in response.error.message
|
||||
break
|
||||
|
||||
assert error_sent, "Expected error response to be sent"
|
||||
# Assert - error response was sent
|
||||
calls = response_producer.send.call_args_list
|
||||
assert len(calls) > 0
|
||||
error_response = calls[-1].args[0]
|
||||
assert error_response.error is not None
|
||||
assert "Text completion error" in error_response.error.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming,
|
||||
|
|
@ -315,21 +314,22 @@ class TestPromptStreaming:
|
|||
# Mock text completion that sends empty chunks
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def empty_streaming_request(request, recipient=None, timeout=600):
|
||||
if request.streaming and recipient:
|
||||
# Send empty chunk followed by final marker
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=False
|
||||
))
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
async def empty_streaming(system, prompt, handler, timeout=600):
|
||||
# Send empty chunk followed by final marker
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=False
|
||||
))
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
|
||||
text_completion_client.request = empty_streaming_request
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=empty_streaming
|
||||
)
|
||||
response_producer = AsyncMock()
|
||||
|
||||
def context_router(service_name):
|
||||
|
|
@ -401,4 +401,4 @@ class TestPromptStreaming:
|
|||
|
||||
# Verify chunks concatenate to expected result
|
||||
full_text = "".join(chunk_texts)
|
||||
assert full_text == "Machine learning is a field of artificial intelligence"
|
||||
assert full_text == "Machine learning is a field of artificial intelligence."
|
||||
|
|
|
|||
|
|
@@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, call
 from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
 from trustgraph.retrieval.document_rag.document_rag import DocumentRag
 from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
+from trustgraph.base import PromptResult


 class TestGraphRagStreamingProtocol:
|
|
@ -46,8 +47,7 @@ class TestGraphRagStreamingProtocol:
|
|||
|
||||
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
# Edge selection returns empty (no edges selected)
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
|
||||
|
|
@ -55,10 +55,10 @@ class TestGraphRagStreamingProtocol:
|
|||
await chunk_callback(" answer", False)
|
||||
await chunk_callback(" is here.", False)
|
||||
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
|
||||
return "" # Return value not used since callback handles everything
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "The answer is here."
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="The answer is here.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
|
@ -237,11 +237,13 @@ class TestDocumentRagStreamingProtocol:
|
|||
await chunk_callback("Document", False)
|
||||
await chunk_callback(" summary", False)
|
||||
await chunk_callback(".", True) # Non-empty final chunk
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "Document summary."
|
||||
return PromptResult(response_type="text", text="Document summary.")
|
||||
|
||||
client.document_prompt.side_effect = document_prompt_side_effect
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -334,17 +336,17 @@ class TestStreamingProtocolEdgeCases:
|
|||
|
||||
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
await chunk_callback("text", False)
|
||||
await chunk_callback("", False) # Empty but not final
|
||||
await chunk_callback("more", False)
|
||||
await chunk_callback("", True) # Empty and final
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "textmore"
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="textmore")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_with_empties
|
||||
|
||||
|
|
|
|||