mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Expose LLM token usage across all service layers (#782)
Expose LLM token usage (in_token, out_token, model) across all service layers Propagate token counts from LLM services through the prompt, text-completion, graph-RAG, document-RAG, and agent orchestrator pipelines to the API gateway and Python SDK. All fields are Optional — None means "not available", distinguishing from a real zero count. Key changes: - Schema: Add in_token/out_token/model to TextCompletionResponse, PromptResponse, GraphRagResponse, DocumentRagResponse, AgentResponse - TextCompletionClient: New TextCompletionResult return type. Split into text_completion() (non-streaming) and text_completion_stream() (streaming with per-chunk handler callback) - PromptClient: New PromptResult with response_type (text/json/jsonl), typed fields (text/object/objects), and token usage. All callers updated. - RAG services: Accumulate token usage across all prompt calls (extract-concepts, edge-scoring, edge-reasoning, synthesis). Non-streaming path sends single combined response instead of chunk + end_of_session. - Agent orchestrator: UsageTracker accumulates tokens across meta-router, pattern prompt calls, and react reasoning. Attached to end_of_dialog. - Translators: Encode token fields when not None (is not None, not truthy) - Python SDK: RAG and text-completion methods return TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with token fields (streaming) - CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt, tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent
This commit is contained in:
parent
67cfa80836
commit
14e49d83c7
60 changed files with 1252 additions and 577 deletions
|
|
@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.agent.orchestrator.meta_router import (
|
||||
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
def _make_config(patterns=None, task_types=None):
|
||||
|
|
@ -28,7 +29,9 @@ def _make_config(patterns=None, task_types=None):
|
|||
def _make_context(prompt_response):
|
||||
"""Build a mock context that returns a mock prompt client."""
|
||||
client = AsyncMock()
|
||||
client.prompt = AsyncMock(return_value=prompt_response)
|
||||
client.prompt = AsyncMock(
|
||||
return_value=PromptResult(response_type="text", text=prompt_response)
|
||||
)
|
||||
|
||||
def context(service_name):
|
||||
return client
|
||||
|
|
@ -274,8 +277,8 @@ class TestRoute:
|
|||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return "research" # task type
|
||||
return "plan-then-execute" # pattern
|
||||
return PromptResult(response_type="text", text="research")
|
||||
return PromptResult(response_type="text", text="plan-then-execute")
|
||||
|
||||
client.prompt = mock_prompt
|
||||
context = lambda name: client
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from dataclasses import dataclass, field
|
|||
from trustgraph.schema import (
|
||||
AgentRequest, AgentResponse, AgentStep, PlanStep,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
|
|
@ -183,7 +184,7 @@ class TestReactPatternProvenance:
|
|||
)
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
# Simulate the on_action callback before returning Final
|
||||
if on_action:
|
||||
await on_action(Action(
|
||||
|
|
@ -267,7 +268,7 @@ class TestReactPatternProvenance:
|
|||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
if on_action:
|
||||
await on_action(action)
|
||||
return action
|
||||
|
|
@ -309,7 +310,7 @@ class TestReactPatternProvenance:
|
|||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
if on_action:
|
||||
await on_action(Action(
|
||||
thought="done", name="final",
|
||||
|
|
@ -355,10 +356,13 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt client for plan creation
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = [
|
||||
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
|
||||
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
|
||||
]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
|
||||
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
|
||||
],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -418,10 +422,13 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt for step execution
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = {
|
||||
"tool": "knowledge-query",
|
||||
"arguments": {"question": "quantum computing"},
|
||||
}
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="json",
|
||||
object={
|
||||
"tool": "knowledge-query",
|
||||
"arguments": {"question": "quantum computing"},
|
||||
},
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -475,7 +482,7 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt for synthesis
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = "The synthesised answer."
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The synthesised answer.")
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -542,10 +549,13 @@ class TestSupervisorPatternProvenance:
|
|||
|
||||
# Mock prompt for decomposition
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = [
|
||||
"What is quantum computing?",
|
||||
"What are qubits?",
|
||||
]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
"What is quantum computing?",
|
||||
"What are qubits?",
|
||||
],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -590,7 +600,7 @@ class TestSupervisorPatternProvenance:
|
|||
|
||||
# Mock prompt for synthesis
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = "The combined answer."
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The combined answer.")
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -639,7 +649,10 @@ class TestSupervisorPatternProvenance:
|
|||
flow = make_mock_flow()
|
||||
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=["Goal A", "Goal B", "Goal C"],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.extract.kg.definitions.extract import (
|
||||
Processor, default_triples_batch_size, default_entity_batch_size,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
from trustgraph.schema import (
|
||||
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
|
||||
)
|
||||
|
|
@ -51,8 +52,12 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
|
|||
mock_triples_pub = AsyncMock()
|
||||
mock_ecs_pub = AsyncMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
if isinstance(prompt_result, list):
|
||||
wrapped = PromptResult(response_type="jsonl", objects=prompt_result)
|
||||
else:
|
||||
wrapped = PromptResult(response_type="text", text=prompt_result)
|
||||
mock_prompt_client.extract_definitions = AsyncMock(
|
||||
return_value=prompt_result
|
||||
return_value=wrapped
|
||||
)
|
||||
|
||||
def flow(name):
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from trustgraph.extract.kg.relationships.extract import (
|
|||
from trustgraph.schema import (
|
||||
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -58,7 +59,10 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
|
|||
mock_triples_pub = AsyncMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.extract_relationships = AsyncMock(
|
||||
return_value=prompt_result
|
||||
return_value=PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=prompt_result,
|
||||
)
|
||||
)
|
||||
|
||||
def flow(name):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
# Sample chunk content mapping for tests
|
||||
|
|
@ -132,7 +133,7 @@ class TestQuery:
|
|||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
# Mock the prompt response with concept lines
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence\ndata patterns")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -157,7 +158,7 @@ class TestQuery:
|
|||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
# Mock empty response
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -258,7 +259,7 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "test concept"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="test concept")
|
||||
|
||||
# Mock embeddings - one vector per concept
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
|
|
@ -273,7 +274,7 @@ class TestQuery:
|
|||
expected_response = "This is the document RAG response"
|
||||
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
||||
mock_prompt_client.document_prompt.return_value = expected_response
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=expected_response)
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -315,7 +316,8 @@ class TestQuery:
|
|||
assert "Relevant document content" in docs
|
||||
assert "Another document" in docs
|
||||
|
||||
assert result == expected_response
|
||||
result_text, usage = result
|
||||
assert result_text == expected_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
|
||||
|
|
@ -325,7 +327,7 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction fallback (empty → raw query)
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
|
||||
|
|
@ -333,7 +335,7 @@ class TestQuery:
|
|||
mock_match.chunk_id = "doc/c5"
|
||||
mock_match.score = 0.9
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||
mock_prompt_client.document_prompt.return_value = "Default response"
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Default response")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -352,7 +354,8 @@ class TestQuery:
|
|||
collection="default" # Default collection
|
||||
)
|
||||
|
||||
assert result == "Default response"
|
||||
result_text, usage = result
|
||||
assert result_text == "Default response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_verbose_output(self):
|
||||
|
|
@ -401,7 +404,7 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "verbose query test"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="verbose query test")
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
|
||||
|
|
@ -409,7 +412,7 @@ class TestQuery:
|
|||
mock_match.chunk_id = "doc/c7"
|
||||
mock_match.score = 0.92
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Verbose RAG response")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -428,7 +431,8 @@ class TestQuery:
|
|||
assert call_args.kwargs["query"] == "verbose query test"
|
||||
assert "Verbose doc content" in call_args.kwargs["documents"]
|
||||
|
||||
assert result == "Verbose RAG response"
|
||||
result_text, usage = result
|
||||
assert result_text == "Verbose RAG response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_empty_results(self):
|
||||
|
|
@ -469,11 +473,11 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "query with no matching docs"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="query with no matching docs")
|
||||
|
||||
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
|
||||
mock_doc_embeddings_client.query.return_value = []
|
||||
mock_prompt_client.document_prompt.return_value = "No documents found response"
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="No documents found response")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -490,7 +494,8 @@ class TestQuery:
|
|||
documents=[]
|
||||
)
|
||||
|
||||
assert result == "No documents found response"
|
||||
result_text, usage = result
|
||||
assert result_text == "No documents found response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vectors_with_verbose(self):
|
||||
|
|
@ -525,7 +530,7 @@ class TestQuery:
|
|||
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence")
|
||||
|
||||
# Mock embeddings - one vector per concept
|
||||
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
|
||||
|
|
@ -541,7 +546,7 @@ class TestQuery:
|
|||
MagicMock(chunk_id="doc/ml3", score=0.82),
|
||||
]
|
||||
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
|
||||
mock_prompt_client.document_prompt.return_value = final_response
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=final_response)
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -584,7 +589,8 @@ class TestQuery:
|
|||
assert "Common ML techniques include supervised and unsupervised learning..." in docs
|
||||
assert len(docs) == 3 # doc/ml2 deduplicated
|
||||
|
||||
assert result == final_response
|
||||
result_text, usage = result
|
||||
assert result_text == final_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_deduplicates_across_concepts(self):
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from unittest.mock import AsyncMock
|
|||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
|
|
@ -89,8 +90,8 @@ def build_mock_clients():
|
|||
# 1. Concept extraction
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return "return policy\nrefund"
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="return policy\nrefund")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
|
|
@ -113,8 +114,9 @@ def build_mock_clients():
|
|||
fetch_chunk.side_effect = mock_fetch
|
||||
|
||||
# 5. Synthesis
|
||||
prompt_client.document_prompt.return_value = (
|
||||
"Items can be returned within 30 days for a full refund."
|
||||
prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="Items can be returned within 30 days for a full refund.",
|
||||
)
|
||||
|
||||
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
|
||||
|
|
@ -340,12 +342,12 @@ class TestDocumentRagQueryProvenance:
|
|||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
result = await rag.query(
|
||||
result_text, usage = await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=AsyncMock(),
|
||||
)
|
||||
|
||||
assert result == "Items can be returned within 30 days for a full refund."
|
||||
assert result_text == "Items can be returned within 30 days for a full refund."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_explain_callback_still_works(self):
|
||||
|
|
@ -353,8 +355,8 @@ class TestDocumentRagQueryProvenance:
|
|||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
result = await rag.query(query="What is the return policy?")
|
||||
assert result == "Items can be returned within 30 days for a full refund."
|
||||
result_text, usage = await rag.query(query="What is the return policy?")
|
||||
assert result_text == "Items can be returned within 30 days for a full refund."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class TestDocumentRagService:
|
|||
# Setup mock DocumentRag instance
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_document_rag_class.return_value = mock_rag_instance
|
||||
mock_rag_instance.query.return_value = "test response"
|
||||
mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None})
|
||||
|
||||
# Setup message with custom user/collection
|
||||
msg = MagicMock()
|
||||
|
|
@ -97,7 +97,7 @@ class TestDocumentRagService:
|
|||
# Setup mock DocumentRag instance
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_document_rag_class.return_value = mock_rag_instance
|
||||
mock_rag_instance.query.return_value = "A document about cats."
|
||||
mock_rag_instance.query.return_value = ("A document about cats.", {"in_token": None, "out_token": None, "model": None})
|
||||
|
||||
# Setup message with non-streaming request
|
||||
msg = MagicMock()
|
||||
|
|
@ -130,4 +130,5 @@ class TestDocumentRagService:
|
|||
assert isinstance(sent_response, DocumentRagResponse)
|
||||
assert sent_response.response == "A document about cats."
|
||||
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
|
||||
assert sent_response.end_of_session is True
|
||||
assert sent_response.error is None
|
||||
|
|
@ -7,6 +7,7 @@ import unittest.mock
|
|||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
class TestGraphRag:
|
||||
|
|
@ -172,7 +173,7 @@ class TestQuery:
|
|||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nneural networks\n")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -196,7 +197,7 @@ class TestQuery:
|
|||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -220,7 +221,7 @@ class TestQuery:
|
|||
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
|
||||
|
||||
# extract_concepts returns empty -> falls back to [query]
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
# embed returns one vector set for the single concept
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
|
|
@ -565,14 +566,14 @@ class TestQuery:
|
|||
# Mock prompt responses for the multi-step process
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return json.dumps({"id": test_edge_id, "score": 0.9})
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return json.dumps({"id": test_edge_id, "reasoning": "relevant"})
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return expected_response
|
||||
return ""
|
||||
return PromptResult(response_type="text", text=expected_response)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
mock_prompt_client.prompt = mock_prompt
|
||||
|
||||
|
|
@ -607,7 +608,8 @@ class TestQuery:
|
|||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
assert response == expected_response
|
||||
response_text, usage = response
|
||||
assert response_text == expected_response
|
||||
|
||||
# 5 events: question, grounding, exploration, focus, synthesis
|
||||
assert len(provenance_events) == 5
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from dataclasses import dataclass
|
|||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
|
||||
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
|
|
@ -136,24 +137,36 @@ def build_mock_clients():
|
|||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return prompt_responses["extract-concepts"]
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=prompt_responses["extract-concepts"],
|
||||
)
|
||||
elif template_id == "kg-edge-scoring":
|
||||
# Score all edges highly, using the IDs that GraphRag computed
|
||||
edges = variables.get("knowledge", [])
|
||||
return [
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
# Provide reasoning for each edge
|
||||
edges = variables.get("knowledge", [])
|
||||
return [
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return synthesis_answer
|
||||
return ""
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=synthesis_answer,
|
||||
)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
|
|
@ -413,13 +426,13 @@ class TestGraphRagQueryProvenance:
|
|||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
result = await rag.query(
|
||||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
assert result == "Quantum computing applies physics principles to computation."
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parent_uri_links_question_to_parent(self):
|
||||
|
|
@ -450,12 +463,12 @@ class TestGraphRagQueryProvenance:
|
|||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
result = await rag.query(
|
||||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
assert result == "Quantum computing applies physics principles to computation."
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class TestGraphRagService:
|
|||
await explain_callback([], "urn:trustgraph:prov:retrieval:test")
|
||||
await explain_callback([], "urn:trustgraph:prov:selection:test")
|
||||
await explain_callback([], "urn:trustgraph:prov:answer:test")
|
||||
return "A small domesticated mammal."
|
||||
return "A small domesticated mammal.", {"in_token": None, "out_token": None, "model": None}
|
||||
|
||||
mock_rag_instance.query.side_effect = mock_query
|
||||
|
||||
|
|
@ -79,8 +79,8 @@ class TestGraphRagService:
|
|||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session)
|
||||
assert mock_response_producer.send.call_count == 6
|
||||
# Verify: 5 messages sent (4 provenance + 1 combined chunk with end_of_session)
|
||||
assert mock_response_producer.send.call_count == 5
|
||||
|
||||
# First 4 messages are explain (emitted in real-time during query)
|
||||
for i in range(4):
|
||||
|
|
@ -88,17 +88,12 @@ class TestGraphRagService:
|
|||
assert prov_msg.message_type == "explain"
|
||||
assert prov_msg.explain_id is not None
|
||||
|
||||
# 5th message is chunk with response
|
||||
# 5th message is chunk with response and end_of_session
|
||||
chunk_msg = mock_response_producer.send.call_args_list[4][0][0]
|
||||
assert chunk_msg.message_type == "chunk"
|
||||
assert chunk_msg.response == "A small domesticated mammal."
|
||||
assert chunk_msg.end_of_stream is True
|
||||
|
||||
# 6th message is empty chunk with end_of_session=True
|
||||
close_msg = mock_response_producer.send.call_args_list[5][0][0]
|
||||
assert close_msg.message_type == "chunk"
|
||||
assert close_msg.response == ""
|
||||
assert close_msg.end_of_session is True
|
||||
assert chunk_msg.end_of_session is True
|
||||
|
||||
# Verify provenance triples were sent to provenance queue
|
||||
assert mock_provenance_producer.send.call_count == 4
|
||||
|
|
@ -187,7 +182,7 @@ class TestGraphRagService:
|
|||
|
||||
async def mock_query(**kwargs):
|
||||
# Don't call explain_callback
|
||||
return "Response text"
|
||||
return "Response text", {"in_token": None, "out_token": None, "model": None}
|
||||
|
||||
mock_rag_instance.query.side_effect = mock_query
|
||||
|
||||
|
|
@ -218,17 +213,12 @@ class TestGraphRagService:
|
|||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: 2 messages (chunk + empty chunk to close)
|
||||
assert mock_response_producer.send.call_count == 2
|
||||
# Verify: 1 combined message (chunk with end_of_session)
|
||||
assert mock_response_producer.send.call_count == 1
|
||||
|
||||
# First is the response chunk
|
||||
# Single message has response and end_of_session
|
||||
chunk_msg = mock_response_producer.send.call_args_list[0][0][0]
|
||||
assert chunk_msg.message_type == "chunk"
|
||||
assert chunk_msg.response == "Response text"
|
||||
assert chunk_msg.end_of_stream is True
|
||||
|
||||
# Second is empty chunk to close session
|
||||
close_msg = mock_response_producer.send.call_args_list[1][0][0]
|
||||
assert close_msg.message_type == "chunk"
|
||||
assert close_msg.response == ""
|
||||
assert close_msg.end_of_session is True
|
||||
assert chunk_msg.end_of_session is True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue