mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-18 11:55:12 +02:00
Expose LLM token usage (in_token, out_token, model) across all
service layers Propagate token counts from LLM services through the prompt, text-completion, graph-RAG, document-RAG, and agent orchestrator pipelines to the API gateway and Python SDK. All fields are Optional — None means "not available", distinguishing from a real zero count. Key changes: - Schema: Add in_token/out_token/model to TextCompletionResponse, PromptResponse, GraphRagResponse, DocumentRagResponse, AgentResponse - TextCompletionClient: New TextCompletionResult return type. Split into text_completion() (non-streaming) and text_completion_stream() (streaming with per-chunk handler callback) - PromptClient: New PromptResult with response_type (text/json/jsonl), typed fields (text/object/objects), and token usage. All callers updated. - RAG services: Accumulate token usage across all prompt calls (extract-concepts, edge-scoring, edge-reasoning, synthesis). Non-streaming path sends single combined response instead of chunk + end_of_session. - Agent orchestrator: UsageTracker accumulates tokens across meta-router, pattern prompt calls, and react reasoning. Attached to end_of_dialog. - Translators: Encode token fields when not None (is not None, not truthy) - Python SDK: RAG and text-completion methods return TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with token fields (streaming) - CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt, tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent
This commit is contained in:
parent
ffe310af7c
commit
56d700f301
60 changed files with 1252 additions and 577 deletions
|
|
@ -15,6 +15,7 @@ from trustgraph.agent.react.agent_manager import AgentManager
|
|||
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
|
||||
from trustgraph.agent.react.types import Action, Final, Tool, Argument
|
||||
from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -28,19 +29,25 @@ class TestAgentManagerIntegration:
|
|||
|
||||
# Mock prompt client
|
||||
prompt_client = AsyncMock()
|
||||
prompt_client.agent_react.return_value = """Thought: I need to search for information about machine learning
|
||||
prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search for information about machine learning
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is machine learning?"
|
||||
}"""
|
||||
|
||||
)
|
||||
|
||||
# Mock graph RAG client
|
||||
graph_rag_client = AsyncMock()
|
||||
graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data."
|
||||
|
||||
|
||||
# Mock text completion client
|
||||
text_completion_client = AsyncMock()
|
||||
text_completion_client.question.return_value = "Machine learning involves algorithms that improve through experience."
|
||||
text_completion_client.question.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="Machine learning involves algorithms that improve through experience."
|
||||
)
|
||||
|
||||
# Mock MCP tool client
|
||||
mcp_tool_client = AsyncMock()
|
||||
|
|
@ -147,8 +154,11 @@ Args: {
|
|||
async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager returning final answer"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have enough information to answer the question
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I have enough information to answer the question
|
||||
Final Answer: Machine learning is a field of AI that enables computers to learn from data."""
|
||||
)
|
||||
|
||||
question = "What is machine learning?"
|
||||
history = []
|
||||
|
|
@ -193,8 +203,11 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
|
|||
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
|
||||
"""Test ReAct cycle ending with final answer"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide a direct answer
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I can provide a direct answer
|
||||
Final Answer: Machine learning is a branch of artificial intelligence."""
|
||||
)
|
||||
|
||||
question = "What is machine learning?"
|
||||
history = []
|
||||
|
|
@ -254,11 +267,14 @@ Final Answer: Machine learning is a branch of artificial intelligence."""
|
|||
|
||||
for tool_name, expected_service in tool_scenarios:
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: I need to use {tool_name}
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=f"""Thought: I need to use {tool_name}
|
||||
Action: {tool_name}
|
||||
Args: {{
|
||||
"question": "test question"
|
||||
}}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -284,11 +300,14 @@ Args: {{
|
|||
async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager error handling for unknown tool"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to use an unknown tool
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to use an unknown tool
|
||||
Action: unknown_tool
|
||||
Args: {
|
||||
"param": "value"
|
||||
}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -321,11 +340,14 @@ Args: {
|
|||
question = "Find information about AI and summarize it"
|
||||
|
||||
# Mock multi-step reasoning
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search for AI information first
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search for AI information first
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is artificial intelligence?"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Act
|
||||
action = await agent_manager.reason(question, [], mock_flow_context)
|
||||
|
|
@ -372,9 +394,12 @@ Args: {
|
|||
# Format arguments as JSON
|
||||
import json
|
||||
args_json = json.dumps(test_case['arguments'], indent=4)
|
||||
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: Using {test_case['action']}
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=f"""Thought: Using {test_case['action']}
|
||||
Action: {test_case['action']}
|
||||
Args: {args_json}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -507,7 +532,10 @@ Args: {
|
|||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
mock_flow_context("prompt-request").agent_react.return_value = test_case["response"]
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=test_case["response"]
|
||||
)
|
||||
|
||||
if test_case["error_contains"]:
|
||||
# Should raise an error
|
||||
|
|
@ -527,13 +555,16 @@ Args: {
|
|||
async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context):
|
||||
"""Test edge cases in text parsing"""
|
||||
# Test response with markdown code blocks
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """```
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""```
|
||||
Thought: I need to search for information
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is AI?"
|
||||
}
|
||||
```"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -541,15 +572,18 @@ Args: {
|
|||
assert action.name == "knowledge_query"
|
||||
|
||||
# Test response with extra whitespace
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""
|
||||
|
||||
Thought: I need to think about this
|
||||
Action: knowledge_query
|
||||
Thought: I need to think about this
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "test"
|
||||
}
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -560,7 +594,9 @@ Args: {
|
|||
async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context):
|
||||
"""Test handling of multi-line thoughts and final answers"""
|
||||
# Multi-line thought
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to consider multiple factors:
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to consider multiple factors:
|
||||
1. The user's question is complex
|
||||
2. I should search for comprehensive information
|
||||
3. This requires using the knowledge query tool
|
||||
|
|
@ -568,6 +604,7 @@ Action: knowledge_query
|
|||
Args: {
|
||||
"question": "complex query"
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -575,13 +612,16 @@ Args: {
|
|||
assert "knowledge query tool" in action.thought
|
||||
|
||||
# Multi-line final answer
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have gathered enough information
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I have gathered enough information
|
||||
Final Answer: Here is a comprehensive answer:
|
||||
1. First point about the topic
|
||||
2. Second point with details
|
||||
3. Final conclusion
|
||||
|
||||
This covers all aspects of the question."""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Final)
|
||||
|
|
@ -593,13 +633,16 @@ This covers all aspects of the question."""
|
|||
async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context):
|
||||
"""Test JSON arguments with special characters and edge cases"""
|
||||
# Test with special characters in JSON (properly escaped)
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Processing special characters
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: Processing special characters
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What about \\"quotes\\" and 'apostrophes'?",
|
||||
"context": "Line 1\\nLine 2\\tTabbed",
|
||||
"special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?"
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -608,7 +651,9 @@ Args: {
|
|||
assert "@#$%^&*" in action.arguments["special"]
|
||||
|
||||
# Test with nested JSON
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Complex arguments
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: Complex arguments
|
||||
Action: web_search
|
||||
Args: {
|
||||
"query": "test",
|
||||
|
|
@ -621,6 +666,7 @@ Args: {
|
|||
}
|
||||
}
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -632,7 +678,9 @@ Args: {
|
|||
async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context):
|
||||
"""Test final answers that contain JSON-like content"""
|
||||
# Final answer with JSON content
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide the data in JSON format
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I can provide the data in JSON format
|
||||
Final Answer: {
|
||||
"result": "success",
|
||||
"data": {
|
||||
|
|
@ -642,6 +690,7 @@ Final Answer: {
|
|||
},
|
||||
"confidence": 0.95
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Final)
|
||||
|
|
@ -792,11 +841,14 @@ Final Answer: {
|
|||
agent = AgentManager(tools=custom_tools, additional_context="")
|
||||
|
||||
# Mock response for custom collection query
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search in the research papers
|
||||
Action: knowledge_query_custom
|
||||
Args: {
|
||||
"question": "Latest AI research?"
|
||||
}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.agent.react.agent_manager import AgentManager
|
||||
from trustgraph.agent.react.tools import KnowledgeQueryImpl
|
||||
from trustgraph.agent.react.types import Tool, Argument
|
||||
from trustgraph.base import PromptResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_agent_streaming_chunks,
|
||||
assert_streaming_chunks_valid,
|
||||
|
|
@ -51,10 +52,10 @@ Args: {
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
# Non-streaming response - same text
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
|
||||
client.agent_react.side_effect = agent_react_streaming
|
||||
return client
|
||||
|
|
@ -317,8 +318,8 @@ Final Answer: AI is the simulation of human intelligence in machines."""
|
|||
for i, chunk in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk + " ", is_final)
|
||||
return response
|
||||
return response
|
||||
return PromptResult(response_type="text", text=response)
|
||||
return PromptResult(response_type="text", text=response)
|
||||
|
||||
mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from trustgraph.schema import (
|
|||
Error
|
||||
)
|
||||
from trustgraph.agent.react.service import Processor
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -95,11 +96,14 @@ class TestAgentStructuredQueryIntegration:
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from New York using structured query
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to find customers from New York using structured query
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find all customers from New York"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -173,11 +177,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to query for a table that might not exist
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to query for a table that might not exist
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find data from a table that doesn't exist"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -250,11 +257,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from California first
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to find customers from California first
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find all customers from California"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -339,11 +349,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to query the sales data
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to query the sales data
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Query the sales data for recent transactions"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -447,11 +460,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to get customer information
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to get customer information
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Get customer information and format it nicely"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import ChunkMatch
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
# Sample chunk content for testing - maps chunk_id to content
|
||||
|
|
@ -61,11 +62,16 @@ class TestDocumentRagIntegration:
|
|||
def mock_prompt_client(self):
|
||||
"""Mock prompt client that generates realistic responses"""
|
||||
client = AsyncMock()
|
||||
client.document_prompt.return_value = (
|
||||
"Machine learning is a field of artificial intelligence that enables computers to learn "
|
||||
"and improve from experience without being explicitly programmed. It uses algorithms "
|
||||
"to find patterns in data and make predictions or decisions."
|
||||
client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=(
|
||||
"Machine learning is a field of artificial intelligence that enables computers to learn "
|
||||
"and improve from experience without being explicitly programmed. It uses algorithms "
|
||||
"to find patterns in data and make predictions or decisions."
|
||||
)
|
||||
)
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -119,6 +125,7 @@ class TestDocumentRagIntegration:
|
|||
)
|
||||
|
||||
# Verify final response
|
||||
result, usage = result
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert "machine learning" in result.lower()
|
||||
|
|
@ -131,7 +138,11 @@ class TestDocumentRagIntegration:
|
|||
"""Test DocumentRAG behavior when no documents are retrieved"""
|
||||
# Arrange
|
||||
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
|
||||
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="I couldn't find any relevant documents for your query."
|
||||
)
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -152,7 +163,8 @@ class TestDocumentRagIntegration:
|
|||
documents=[]
|
||||
)
|
||||
|
||||
assert result == "I couldn't find any relevant documents for your query."
|
||||
result_text, usage = result
|
||||
assert result_text == "I couldn't find any relevant documents for your query."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import ChunkMatch
|
||||
from trustgraph.base import PromptResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_callback_invoked,
|
||||
|
|
@ -74,12 +75,14 @@ class TestDocumentRagStreaming:
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
# Non-streaming response - same text
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
|
||||
client.document_prompt.side_effect = document_prompt_side_effect
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -119,11 +122,12 @@ class TestDocumentRagStreaming:
|
|||
collector.verify_streaming_protocol()
|
||||
|
||||
# Verify full response matches concatenated chunks
|
||||
result_text, usage = result
|
||||
full_from_chunks = collector.get_full_text()
|
||||
assert result == full_from_chunks
|
||||
assert result_text == full_from_chunks
|
||||
|
||||
# Verify content is reasonable
|
||||
assert len(result) > 0
|
||||
assert len(result_text) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming):
|
||||
|
|
@ -159,9 +163,11 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Results should be equivalent
|
||||
assert streaming_result == non_streaming_result
|
||||
non_streaming_text, _ = non_streaming_result
|
||||
streaming_text, _ = streaming_result
|
||||
assert streaming_text == non_streaming_text
|
||||
assert len(streaming_chunks) > 0
|
||||
assert "".join(streaming_chunks) == streaming_result
|
||||
assert "".join(streaming_chunks) == streaming_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming):
|
||||
|
|
@ -180,8 +186,9 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
result_text, usage = result
|
||||
assert callback.call_count > 0
|
||||
assert result is not None
|
||||
assert result_text is not None
|
||||
|
||||
# Verify all callback invocations had string arguments
|
||||
for call in callback.call_args_list:
|
||||
|
|
@ -202,7 +209,8 @@ class TestDocumentRagStreaming:
|
|||
|
||||
# Assert - Should complete without error
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
result_text, usage = result
|
||||
assert isinstance(result_text, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming,
|
||||
|
|
@ -223,7 +231,8 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Should still produce streamed response
|
||||
assert result is not None
|
||||
result_text, usage = result
|
||||
assert result_text is not None
|
||||
assert callback.call_count > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -271,7 +280,8 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
result_text, usage = result
|
||||
assert result_text is not None
|
||||
assert callback.call_count > 0
|
||||
|
||||
# Verify doc_limit was passed correctly
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.schema import EntityMatch, Term, IRI
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -93,18 +94,21 @@ class TestGraphRagIntegration:
|
|||
# 4. kg-synthesis returns the final answer
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return "" # No edges scored
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return "" # No reasoning
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return (
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=(
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
)
|
||||
)
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = mock_prompt
|
||||
return client
|
||||
|
|
@ -169,6 +173,7 @@ class TestGraphRagIntegration:
|
|||
assert mock_prompt_client.prompt.call_count == 4
|
||||
|
||||
# Verify final response
|
||||
response, usage = response
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
assert "machine learning" in response.lower()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.schema import EntityMatch, Term, IRI
|
||||
from trustgraph.base import PromptResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_rag_streaming_chunks,
|
||||
|
|
@ -61,12 +62,12 @@ class TestGraphRagStreaming:
|
|||
|
||||
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
|
||||
if prompt_id == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_id == "kg-edge-scoring":
|
||||
# Edge scoring returns JSONL with IDs and scores
|
||||
return '{"id": "abc12345", "score": 0.9}\n'
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
|
||||
elif prompt_id == "kg-edge-reasoning":
|
||||
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
|
||||
elif prompt_id == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate streaming chunks with end_of_stream flags
|
||||
|
|
@ -79,10 +80,10 @@ class TestGraphRagStreaming:
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
return full_text
|
||||
return ""
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
|
@ -123,6 +124,7 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
response, usage = response
|
||||
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
|
||||
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
|
||||
|
||||
|
|
@ -172,9 +174,11 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Results should be equivalent
|
||||
assert streaming_response == non_streaming_response
|
||||
non_streaming_text, _ = non_streaming_response
|
||||
streaming_text, _ = streaming_response
|
||||
assert streaming_text == non_streaming_text
|
||||
assert len(streaming_chunks) > 0
|
||||
assert "".join(streaming_chunks) == streaming_response
|
||||
assert "".join(streaming_chunks) == streaming_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
|
||||
|
|
@ -213,7 +217,8 @@ class TestGraphRagStreaming:
|
|||
|
||||
# Assert - Should complete without error
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
response_text, usage = response
|
||||
assert isinstance(response_text, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProces
|
|||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -31,32 +32,38 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
|
||||
# Mock prompt client for definitions extraction
|
||||
prompt_client = AsyncMock()
|
||||
prompt_client.extract_definitions.return_value = [
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
]
|
||||
|
||||
prompt_client.extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Mock prompt client for relationships extraction
|
||||
prompt_client.extract_relationships.return_value = [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "is_used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
prompt_client.extract_relationships.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "is_used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Mock producers for output streams
|
||||
triples_producer = AsyncMock()
|
||||
|
|
@ -489,7 +496,10 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
"""Test handling of empty extraction results"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = []
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
|
|
@ -510,7 +520,10 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
"""Test handling of invalid extraction response format"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format" # Should be list
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="invalid format"
|
||||
) # Should be jsonl with objects list
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
|
|
@ -528,13 +541,16 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context):
|
||||
"""Test entity filtering and validation in extraction"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = [
|
||||
{"entity": "Valid Entity", "definition": "Valid definition"},
|
||||
{"entity": "", "definition": "Empty entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
|
||||
{"entity": None, "definition": "None entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
|
||||
]
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"entity": "Valid Entity", "definition": "Valid definition"},
|
||||
{"entity": "", "definition": "Empty entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
|
||||
{"entity": None, "definition": "None entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
|
||||
]
|
||||
)
|
||||
|
||||
sample_chunk = Chunk(
|
||||
metadata=Metadata(id="test", user="user", collection="collection"),
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from trustgraph.schema import (
|
|||
Chunk, ExtractedObject, Metadata, RowSchema, Field,
|
||||
PromptRequest, PromptResponse
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -114,49 +115,61 @@ class TestObjectExtractionServiceIntegration:
|
|||
schema_name = schema.get("name") if isinstance(schema, dict) else schema.name
|
||||
if schema_name == "customer_records":
|
||||
if "john" in text.lower():
|
||||
return [
|
||||
{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Smith",
|
||||
"email": "john.smith@email.com",
|
||||
"phone": "555-0123"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Smith",
|
||||
"email": "john.smith@email.com",
|
||||
"phone": "555-0123"
|
||||
}
|
||||
]
|
||||
)
|
||||
elif "jane" in text.lower():
|
||||
return [
|
||||
{
|
||||
"customer_id": "CUST002",
|
||||
"name": "Jane Doe",
|
||||
"email": "jane.doe@email.com",
|
||||
"phone": ""
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"customer_id": "CUST002",
|
||||
"name": "Jane Doe",
|
||||
"email": "jane.doe@email.com",
|
||||
"phone": ""
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
return []
|
||||
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
elif schema_name == "product_catalog":
|
||||
if "laptop" in text.lower():
|
||||
return [
|
||||
{
|
||||
"product_id": "PROD001",
|
||||
"name": "Gaming Laptop",
|
||||
"price": "1299.99",
|
||||
"category": "electronics"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"product_id": "PROD001",
|
||||
"name": "Gaming Laptop",
|
||||
"price": "1299.99",
|
||||
"category": "electronics"
|
||||
}
|
||||
]
|
||||
)
|
||||
elif "book" in text.lower():
|
||||
return [
|
||||
{
|
||||
"product_id": "PROD002",
|
||||
"name": "Python Programming Guide",
|
||||
"price": "49.99",
|
||||
"category": "books"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"product_id": "PROD002",
|
||||
"name": "Python Programming Guide",
|
||||
"price": "49.99",
|
||||
"category": "books"
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
return []
|
||||
|
||||
return []
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
prompt_client.extract_objects.side_effect = mock_extract_objects
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.prompt.template.service import Processor
|
||||
from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse
|
||||
from trustgraph.base.text_completion_client import TextCompletionResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_callback_invoked,
|
||||
|
|
@ -27,34 +28,52 @@ class TestPromptStreaming:
|
|||
# Mock text completion client with streaming
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def streaming_request(request, recipient=None, timeout=600):
|
||||
"""Simulate streaming text completion"""
|
||||
if request.streaming and recipient:
|
||||
# Simulate streaming chunks
|
||||
chunks = [
|
||||
"Machine", " learning", " is", " a", " field",
|
||||
" of", " artificial", " intelligence", "."
|
||||
]
|
||||
# Streaming chunks to send
|
||||
chunks = [
|
||||
"Machine", " learning", " is", " a", " field",
|
||||
" of", " artificial", " intelligence", "."
|
||||
]
|
||||
|
||||
for i, chunk_text in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
response = TextCompletionResponse(
|
||||
response=chunk_text,
|
||||
error=None,
|
||||
end_of_stream=is_final
|
||||
)
|
||||
final = await recipient(response)
|
||||
if final:
|
||||
break
|
||||
|
||||
# Final empty chunk
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
async def streaming_text_completion_stream(system, prompt, handler, timeout=600):
|
||||
"""Simulate streaming text completion via text_completion_stream"""
|
||||
for i, chunk_text in enumerate(chunks):
|
||||
response = TextCompletionResponse(
|
||||
response=chunk_text,
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
end_of_stream=False
|
||||
)
|
||||
await handler(response)
|
||||
|
||||
text_completion_client.request = streaming_request
|
||||
# Send final empty chunk with end_of_stream
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
|
||||
return TextCompletionResult(
|
||||
text=None,
|
||||
in_token=10,
|
||||
out_token=9,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
async def non_streaming_text_completion(system, prompt, timeout=600):
|
||||
"""Simulate non-streaming text completion"""
|
||||
full_text = "Machine learning is a field of artificial intelligence."
|
||||
return TextCompletionResult(
|
||||
text=full_text,
|
||||
in_token=10,
|
||||
out_token=9,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=streaming_text_completion_stream
|
||||
)
|
||||
text_completion_client.text_completion = AsyncMock(
|
||||
side_effect=non_streaming_text_completion
|
||||
)
|
||||
|
||||
# Mock response producer
|
||||
response_producer = AsyncMock()
|
||||
|
|
@ -156,14 +175,6 @@ class TestPromptStreaming:
|
|||
|
||||
consumer = MagicMock()
|
||||
|
||||
# Mock non-streaming text completion
|
||||
text_completion_client = mock_flow_context_streaming("text-completion-request")
|
||||
|
||||
async def non_streaming_text_completion(system, prompt, streaming=False):
|
||||
return "AI is the simulation of human intelligence in machines."
|
||||
|
||||
text_completion_client.text_completion = non_streaming_text_completion
|
||||
|
||||
# Act
|
||||
await prompt_processor_streaming.on_request(
|
||||
message, consumer, mock_flow_context_streaming
|
||||
|
|
@ -218,17 +229,12 @@ class TestPromptStreaming:
|
|||
# Mock text completion client that raises an error
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def failing_request(request, recipient=None, timeout=600):
|
||||
if recipient:
|
||||
# Send error response with proper Error schema
|
||||
error_response = TextCompletionResponse(
|
||||
response="",
|
||||
error=Error(message="Text completion error", type="processing_error"),
|
||||
end_of_stream=True
|
||||
)
|
||||
await recipient(error_response)
|
||||
async def failing_stream(system, prompt, handler, timeout=600):
|
||||
raise RuntimeError("Text completion error")
|
||||
|
||||
text_completion_client.request = failing_request
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=failing_stream
|
||||
)
|
||||
|
||||
# Mock response producer to capture error response
|
||||
response_producer = AsyncMock()
|
||||
|
|
@ -255,22 +261,15 @@ class TestPromptStreaming:
|
|||
|
||||
consumer = MagicMock()
|
||||
|
||||
# Act - The service catches errors and sends error responses, doesn't raise
|
||||
# Act - The service catches errors and sends an error PromptResponse
|
||||
await prompt_processor_streaming.on_request(message, consumer, context)
|
||||
|
||||
# Assert - Verify error response was sent
|
||||
assert response_producer.send.call_count > 0
|
||||
|
||||
# Check that at least one response contains an error
|
||||
error_sent = False
|
||||
for call in response_producer.send.call_args_list:
|
||||
response = call.args[0]
|
||||
if hasattr(response, 'error') and response.error:
|
||||
error_sent = True
|
||||
assert "Text completion error" in response.error.message
|
||||
break
|
||||
|
||||
assert error_sent, "Expected error response to be sent"
|
||||
# Assert - error response was sent
|
||||
calls = response_producer.send.call_args_list
|
||||
assert len(calls) > 0
|
||||
error_response = calls[-1].args[0]
|
||||
assert error_response.error is not None
|
||||
assert "Text completion error" in error_response.error.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming,
|
||||
|
|
@ -315,21 +314,22 @@ class TestPromptStreaming:
|
|||
# Mock text completion that sends empty chunks
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def empty_streaming_request(request, recipient=None, timeout=600):
|
||||
if request.streaming and recipient:
|
||||
# Send empty chunk followed by final marker
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=False
|
||||
))
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
async def empty_streaming(system, prompt, handler, timeout=600):
|
||||
# Send empty chunk followed by final marker
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=False
|
||||
))
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
|
||||
text_completion_client.request = empty_streaming_request
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=empty_streaming
|
||||
)
|
||||
response_producer = AsyncMock()
|
||||
|
||||
def context_router(service_name):
|
||||
|
|
@ -401,4 +401,4 @@ class TestPromptStreaming:
|
|||
|
||||
# Verify chunks concatenate to expected result
|
||||
full_text = "".join(chunk_texts)
|
||||
assert full_text == "Machine learning is a field of artificial intelligence"
|
||||
assert full_text == "Machine learning is a field of artificial intelligence."
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, call
|
|||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
class TestGraphRagStreamingProtocol:
|
||||
|
|
@ -46,8 +47,7 @@ class TestGraphRagStreamingProtocol:
|
|||
|
||||
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
# Edge selection returns empty (no edges selected)
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
|
||||
|
|
@ -55,10 +55,10 @@ class TestGraphRagStreamingProtocol:
|
|||
await chunk_callback(" answer", False)
|
||||
await chunk_callback(" is here.", False)
|
||||
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
|
||||
return "" # Return value not used since callback handles everything
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "The answer is here."
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="The answer is here.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
|
@ -237,11 +237,13 @@ class TestDocumentRagStreamingProtocol:
|
|||
await chunk_callback("Document", False)
|
||||
await chunk_callback(" summary", False)
|
||||
await chunk_callback(".", True) # Non-empty final chunk
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "Document summary."
|
||||
return PromptResult(response_type="text", text="Document summary.")
|
||||
|
||||
client.document_prompt.side_effect = document_prompt_side_effect
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -334,17 +336,17 @@ class TestStreamingProtocolEdgeCases:
|
|||
|
||||
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
await chunk_callback("text", False)
|
||||
await chunk_callback("", False) # Empty but not final
|
||||
await chunk_callback("more", False)
|
||||
await chunk_callback("", True) # Empty and final
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "textmore"
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="textmore")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_with_empties
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.agent.orchestrator.meta_router import (
|
||||
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
def _make_config(patterns=None, task_types=None):
|
||||
|
|
@ -28,7 +29,9 @@ def _make_config(patterns=None, task_types=None):
|
|||
def _make_context(prompt_response):
|
||||
"""Build a mock context that returns a mock prompt client."""
|
||||
client = AsyncMock()
|
||||
client.prompt = AsyncMock(return_value=prompt_response)
|
||||
client.prompt = AsyncMock(
|
||||
return_value=PromptResult(response_type="text", text=prompt_response)
|
||||
)
|
||||
|
||||
def context(service_name):
|
||||
return client
|
||||
|
|
@ -274,8 +277,8 @@ class TestRoute:
|
|||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return "research" # task type
|
||||
return "plan-then-execute" # pattern
|
||||
return PromptResult(response_type="text", text="research")
|
||||
return PromptResult(response_type="text", text="plan-then-execute")
|
||||
|
||||
client.prompt = mock_prompt
|
||||
context = lambda name: client
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from dataclasses import dataclass, field
|
|||
from trustgraph.schema import (
|
||||
AgentRequest, AgentResponse, AgentStep, PlanStep,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
|
|
@ -183,7 +184,7 @@ class TestReactPatternProvenance:
|
|||
)
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
# Simulate the on_action callback before returning Final
|
||||
if on_action:
|
||||
await on_action(Action(
|
||||
|
|
@ -267,7 +268,7 @@ class TestReactPatternProvenance:
|
|||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
if on_action:
|
||||
await on_action(action)
|
||||
return action
|
||||
|
|
@ -309,7 +310,7 @@ class TestReactPatternProvenance:
|
|||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
if on_action:
|
||||
await on_action(Action(
|
||||
thought="done", name="final",
|
||||
|
|
@ -355,10 +356,13 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt client for plan creation
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = [
|
||||
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
|
||||
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
|
||||
]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
|
||||
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
|
||||
],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -418,10 +422,13 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt for step execution
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = {
|
||||
"tool": "knowledge-query",
|
||||
"arguments": {"question": "quantum computing"},
|
||||
}
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="json",
|
||||
object={
|
||||
"tool": "knowledge-query",
|
||||
"arguments": {"question": "quantum computing"},
|
||||
},
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -475,7 +482,7 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt for synthesis
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = "The synthesised answer."
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The synthesised answer.")
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -542,10 +549,13 @@ class TestSupervisorPatternProvenance:
|
|||
|
||||
# Mock prompt for decomposition
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = [
|
||||
"What is quantum computing?",
|
||||
"What are qubits?",
|
||||
]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
"What is quantum computing?",
|
||||
"What are qubits?",
|
||||
],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -590,7 +600,7 @@ class TestSupervisorPatternProvenance:
|
|||
|
||||
# Mock prompt for synthesis
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = "The combined answer."
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The combined answer.")
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -639,7 +649,10 @@ class TestSupervisorPatternProvenance:
|
|||
flow = make_mock_flow()
|
||||
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=["Goal A", "Goal B", "Goal C"],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.extract.kg.definitions.extract import (
|
||||
Processor, default_triples_batch_size, default_entity_batch_size,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
from trustgraph.schema import (
|
||||
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
|
||||
)
|
||||
|
|
@ -51,8 +52,12 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
|
|||
mock_triples_pub = AsyncMock()
|
||||
mock_ecs_pub = AsyncMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
if isinstance(prompt_result, list):
|
||||
wrapped = PromptResult(response_type="jsonl", objects=prompt_result)
|
||||
else:
|
||||
wrapped = PromptResult(response_type="text", text=prompt_result)
|
||||
mock_prompt_client.extract_definitions = AsyncMock(
|
||||
return_value=prompt_result
|
||||
return_value=wrapped
|
||||
)
|
||||
|
||||
def flow(name):
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from trustgraph.extract.kg.relationships.extract import (
|
|||
from trustgraph.schema import (
|
||||
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -58,7 +59,10 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
|
|||
mock_triples_pub = AsyncMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.extract_relationships = AsyncMock(
|
||||
return_value=prompt_result
|
||||
return_value=PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=prompt_result,
|
||||
)
|
||||
)
|
||||
|
||||
def flow(name):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
# Sample chunk content mapping for tests
|
||||
|
|
@ -132,7 +133,7 @@ class TestQuery:
|
|||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
# Mock the prompt response with concept lines
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence\ndata patterns")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -157,7 +158,7 @@ class TestQuery:
|
|||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
# Mock empty response
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -258,7 +259,7 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "test concept"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="test concept")
|
||||
|
||||
# Mock embeddings - one vector per concept
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
|
|
@ -273,7 +274,7 @@ class TestQuery:
|
|||
expected_response = "This is the document RAG response"
|
||||
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
||||
mock_prompt_client.document_prompt.return_value = expected_response
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=expected_response)
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -315,7 +316,8 @@ class TestQuery:
|
|||
assert "Relevant document content" in docs
|
||||
assert "Another document" in docs
|
||||
|
||||
assert result == expected_response
|
||||
result_text, usage = result
|
||||
assert result_text == expected_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
|
||||
|
|
@ -325,7 +327,7 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction fallback (empty → raw query)
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
|
||||
|
|
@ -333,7 +335,7 @@ class TestQuery:
|
|||
mock_match.chunk_id = "doc/c5"
|
||||
mock_match.score = 0.9
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||
mock_prompt_client.document_prompt.return_value = "Default response"
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Default response")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -352,7 +354,8 @@ class TestQuery:
|
|||
collection="default" # Default collection
|
||||
)
|
||||
|
||||
assert result == "Default response"
|
||||
result_text, usage = result
|
||||
assert result_text == "Default response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_verbose_output(self):
|
||||
|
|
@ -401,7 +404,7 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "verbose query test"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="verbose query test")
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
|
||||
|
|
@ -409,7 +412,7 @@ class TestQuery:
|
|||
mock_match.chunk_id = "doc/c7"
|
||||
mock_match.score = 0.92
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Verbose RAG response")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -428,7 +431,8 @@ class TestQuery:
|
|||
assert call_args.kwargs["query"] == "verbose query test"
|
||||
assert "Verbose doc content" in call_args.kwargs["documents"]
|
||||
|
||||
assert result == "Verbose RAG response"
|
||||
result_text, usage = result
|
||||
assert result_text == "Verbose RAG response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_empty_results(self):
|
||||
|
|
@ -469,11 +473,11 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "query with no matching docs"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="query with no matching docs")
|
||||
|
||||
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
|
||||
mock_doc_embeddings_client.query.return_value = []
|
||||
mock_prompt_client.document_prompt.return_value = "No documents found response"
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="No documents found response")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -490,7 +494,8 @@ class TestQuery:
|
|||
documents=[]
|
||||
)
|
||||
|
||||
assert result == "No documents found response"
|
||||
result_text, usage = result
|
||||
assert result_text == "No documents found response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vectors_with_verbose(self):
|
||||
|
|
@ -525,7 +530,7 @@ class TestQuery:
|
|||
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence")
|
||||
|
||||
# Mock embeddings - one vector per concept
|
||||
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
|
||||
|
|
@ -541,7 +546,7 @@ class TestQuery:
|
|||
MagicMock(chunk_id="doc/ml3", score=0.82),
|
||||
]
|
||||
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
|
||||
mock_prompt_client.document_prompt.return_value = final_response
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=final_response)
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -584,7 +589,8 @@ class TestQuery:
|
|||
assert "Common ML techniques include supervised and unsupervised learning..." in docs
|
||||
assert len(docs) == 3 # doc/ml2 deduplicated
|
||||
|
||||
assert result == final_response
|
||||
result_text, usage = result
|
||||
assert result_text == final_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_deduplicates_across_concepts(self):
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from unittest.mock import AsyncMock
|
|||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
|
|
@ -89,8 +90,8 @@ def build_mock_clients():
|
|||
# 1. Concept extraction
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return "return policy\nrefund"
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="return policy\nrefund")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
|
|
@ -113,8 +114,9 @@ def build_mock_clients():
|
|||
fetch_chunk.side_effect = mock_fetch
|
||||
|
||||
# 5. Synthesis
|
||||
prompt_client.document_prompt.return_value = (
|
||||
"Items can be returned within 30 days for a full refund."
|
||||
prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="Items can be returned within 30 days for a full refund.",
|
||||
)
|
||||
|
||||
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
|
||||
|
|
@ -340,12 +342,12 @@ class TestDocumentRagQueryProvenance:
|
|||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
result = await rag.query(
|
||||
result_text, usage = await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=AsyncMock(),
|
||||
)
|
||||
|
||||
assert result == "Items can be returned within 30 days for a full refund."
|
||||
assert result_text == "Items can be returned within 30 days for a full refund."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_explain_callback_still_works(self):
|
||||
|
|
@ -353,8 +355,8 @@ class TestDocumentRagQueryProvenance:
|
|||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
result = await rag.query(query="What is the return policy?")
|
||||
assert result == "Items can be returned within 30 days for a full refund."
|
||||
result_text, usage = await rag.query(query="What is the return policy?")
|
||||
assert result_text == "Items can be returned within 30 days for a full refund."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class TestDocumentRagService:
|
|||
# Setup mock DocumentRag instance
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_document_rag_class.return_value = mock_rag_instance
|
||||
mock_rag_instance.query.return_value = "test response"
|
||||
mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None})
|
||||
|
||||
# Setup message with custom user/collection
|
||||
msg = MagicMock()
|
||||
|
|
@ -97,7 +97,7 @@ class TestDocumentRagService:
|
|||
# Setup mock DocumentRag instance
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_document_rag_class.return_value = mock_rag_instance
|
||||
mock_rag_instance.query.return_value = "A document about cats."
|
||||
mock_rag_instance.query.return_value = ("A document about cats.", {"in_token": None, "out_token": None, "model": None})
|
||||
|
||||
# Setup message with non-streaming request
|
||||
msg = MagicMock()
|
||||
|
|
@ -130,4 +130,5 @@ class TestDocumentRagService:
|
|||
assert isinstance(sent_response, DocumentRagResponse)
|
||||
assert sent_response.response == "A document about cats."
|
||||
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
|
||||
assert sent_response.end_of_session is True
|
||||
assert sent_response.error is None
|
||||
|
|
@ -7,6 +7,7 @@ import unittest.mock
|
|||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
class TestGraphRag:
|
||||
|
|
@ -172,7 +173,7 @@ class TestQuery:
|
|||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nneural networks\n")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -196,7 +197,7 @@ class TestQuery:
|
|||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -220,7 +221,7 @@ class TestQuery:
|
|||
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
|
||||
|
||||
# extract_concepts returns empty -> falls back to [query]
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
# embed returns one vector set for the single concept
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
|
|
@ -565,14 +566,14 @@ class TestQuery:
|
|||
# Mock prompt responses for the multi-step process
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return json.dumps({"id": test_edge_id, "score": 0.9})
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return json.dumps({"id": test_edge_id, "reasoning": "relevant"})
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return expected_response
|
||||
return ""
|
||||
return PromptResult(response_type="text", text=expected_response)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
mock_prompt_client.prompt = mock_prompt
|
||||
|
||||
|
|
@ -607,7 +608,8 @@ class TestQuery:
|
|||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
assert response == expected_response
|
||||
response_text, usage = response
|
||||
assert response_text == expected_response
|
||||
|
||||
# 5 events: question, grounding, exploration, focus, synthesis
|
||||
assert len(provenance_events) == 5
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from dataclasses import dataclass
|
|||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
|
||||
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
|
|
@ -136,24 +137,36 @@ def build_mock_clients():
|
|||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return prompt_responses["extract-concepts"]
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=prompt_responses["extract-concepts"],
|
||||
)
|
||||
elif template_id == "kg-edge-scoring":
|
||||
# Score all edges highly, using the IDs that GraphRag computed
|
||||
edges = variables.get("knowledge", [])
|
||||
return [
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
# Provide reasoning for each edge
|
||||
edges = variables.get("knowledge", [])
|
||||
return [
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return synthesis_answer
|
||||
return ""
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=synthesis_answer,
|
||||
)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
|
|
@ -413,13 +426,13 @@ class TestGraphRagQueryProvenance:
|
|||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
result = await rag.query(
|
||||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
assert result == "Quantum computing applies physics principles to computation."
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parent_uri_links_question_to_parent(self):
|
||||
|
|
@ -450,12 +463,12 @@ class TestGraphRagQueryProvenance:
|
|||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
result = await rag.query(
|
||||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
assert result == "Quantum computing applies physics principles to computation."
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class TestGraphRagService:
|
|||
await explain_callback([], "urn:trustgraph:prov:retrieval:test")
|
||||
await explain_callback([], "urn:trustgraph:prov:selection:test")
|
||||
await explain_callback([], "urn:trustgraph:prov:answer:test")
|
||||
return "A small domesticated mammal."
|
||||
return "A small domesticated mammal.", {"in_token": None, "out_token": None, "model": None}
|
||||
|
||||
mock_rag_instance.query.side_effect = mock_query
|
||||
|
||||
|
|
@ -79,8 +79,8 @@ class TestGraphRagService:
|
|||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session)
|
||||
assert mock_response_producer.send.call_count == 6
|
||||
# Verify: 5 messages sent (4 provenance + 1 combined chunk with end_of_session)
|
||||
assert mock_response_producer.send.call_count == 5
|
||||
|
||||
# First 4 messages are explain (emitted in real-time during query)
|
||||
for i in range(4):
|
||||
|
|
@ -88,17 +88,12 @@ class TestGraphRagService:
|
|||
assert prov_msg.message_type == "explain"
|
||||
assert prov_msg.explain_id is not None
|
||||
|
||||
# 5th message is chunk with response
|
||||
# 5th message is chunk with response and end_of_session
|
||||
chunk_msg = mock_response_producer.send.call_args_list[4][0][0]
|
||||
assert chunk_msg.message_type == "chunk"
|
||||
assert chunk_msg.response == "A small domesticated mammal."
|
||||
assert chunk_msg.end_of_stream is True
|
||||
|
||||
# 6th message is empty chunk with end_of_session=True
|
||||
close_msg = mock_response_producer.send.call_args_list[5][0][0]
|
||||
assert close_msg.message_type == "chunk"
|
||||
assert close_msg.response == ""
|
||||
assert close_msg.end_of_session is True
|
||||
assert chunk_msg.end_of_session is True
|
||||
|
||||
# Verify provenance triples were sent to provenance queue
|
||||
assert mock_provenance_producer.send.call_count == 4
|
||||
|
|
@ -187,7 +182,7 @@ class TestGraphRagService:
|
|||
|
||||
async def mock_query(**kwargs):
|
||||
# Don't call explain_callback
|
||||
return "Response text"
|
||||
return "Response text", {"in_token": None, "out_token": None, "model": None}
|
||||
|
||||
mock_rag_instance.query.side_effect = mock_query
|
||||
|
||||
|
|
@ -218,17 +213,12 @@ class TestGraphRagService:
|
|||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: 2 messages (chunk + empty chunk to close)
|
||||
assert mock_response_producer.send.call_count == 2
|
||||
# Verify: 1 combined message (chunk with end_of_session)
|
||||
assert mock_response_producer.send.call_count == 1
|
||||
|
||||
# First is the response chunk
|
||||
# Single message has response and end_of_session
|
||||
chunk_msg = mock_response_producer.send.call_args_list[0][0][0]
|
||||
assert chunk_msg.message_type == "chunk"
|
||||
assert chunk_msg.response == "Response text"
|
||||
assert chunk_msg.end_of_stream is True
|
||||
|
||||
# Second is empty chunk to close session
|
||||
close_msg = mock_response_producer.send.call_args_list[1][0][0]
|
||||
assert close_msg.message_type == "chunk"
|
||||
assert close_msg.response == ""
|
||||
assert close_msg.end_of_session is True
|
||||
assert chunk_msg.end_of_session is True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue