Expose LLM token usage across all service layers (#782)

Expose LLM token usage (in_token, out_token, model) across all
service layers

Propagate token counts from LLM services through the prompt,
text-completion, graph-RAG, document-RAG, and agent orchestrator
pipelines to the API gateway and Python SDK. All token fields are
Optional: None means "not available", as distinct from a genuine zero
count.
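
A minimal consumer-side sketch of that distinction (the field names are
from this change; resp stands for any response that carries them,
e.g. a GraphRagResponse):

    def show_usage(resp):
        if resp.in_token is None:
            # The provider did not report usage for this call
            print("token usage not available")
        else:
            # A real count, which may legitimately be zero
            print(f"in={resp.in_token} out={resp.out_token} model={resp.model}")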

Key changes:

- Schema: Add in_token/out_token/model to TextCompletionResponse,
  PromptResponse, GraphRagResponse, DocumentRagResponse,
  AgentResponse

- TextCompletionClient: New TextCompletionResult return type. Split
  into text_completion() (non-streaming) and
  text_completion_stream() (streaming with per-chunk handler
  callback); caller-side usage is sketched after this list

- PromptClient: New PromptResult with response_type
  (text/json/jsonl), typed fields (text/object/objects), and token
  usage. All callers updated.

- RAG services: Accumulate token usage across all prompt calls
  (extract-concepts, edge-scoring, edge-reasoning,
  synthesis). Non-streaming path sends single combined response
  instead of chunk + end_of_session.

- Agent orchestrator: UsageTracker accumulates tokens across
  meta-router, pattern prompt calls, and react reasoning. Attached
  to end_of_dialog.

- Translators: Encode token fields only when the value is not None
  (an explicit is-not-None check rather than a truthiness test, so a
  zero count is still encoded)

- Python SDK: RAG and text-completion methods return
  TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with
  token fields (streaming)

- CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt,
  tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent
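
A rough caller-side sketch of the new client surface, pieced together
from the diffs below. The flow() wiring and variable names are
illustrative, and UsageTracker is the orchestrator helper added in
this commit, not a published SDK class:

    async def example(flow, system, prompt, query):
        # Non-streaming text completion: the result carries text plus usage
        result = await flow("text-completion-request").text_completion(
            system=system, prompt=prompt,
        )
        print(result.text, result.in_token, result.out_token, result.model)

        # Streaming text completion: chunks go to a per-chunk handler; the
        # chunk responses carry the token fields when available
        async def on_chunk(resp):
            ...  # forward or accumulate each chunk

        await flow("text-completion-request").text_completion_stream(
            system=system, prompt=prompt, handler=on_chunk, timeout=600,
        )

        # Prompt templates now return a PromptResult; the orchestrator's
        # UsageTracker (added below) accumulates usage across several calls
        usage = UsageTracker()
        client = flow("prompt-request")
        result = await client.prompt(
            id="extract-concepts", variables={"query": query},
        )
        usage.track(result)
        answer = result.text  # or result.object / result.objects for json/jsonl
        print(answer, usage.in_token, usage.out_token, usage.model)
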
cybermaggedon 2026-04-13 14:38:34 +01:00 committed by GitHub
parent 67cfa80836
commit 14e49d83c7
GPG key ID: B5690EEEBB952194
60 changed files with 1252 additions and 577 deletions

View file

@ -53,7 +53,7 @@ class MetaRouter:
"general": {"name": "general", "description": "General queries", "valid_patterns": ["react"], "framing": ""},
}
async def identify_task_type(self, question, context):
async def identify_task_type(self, question, context, usage=None):
"""
Use the LLM to classify the question into one of the known task types.
@ -71,7 +71,7 @@ class MetaRouter:
try:
client = context("prompt-request")
response = await client.prompt(
result = await client.prompt(
id="task-type-classify",
variables={
"question": question,
@ -81,7 +81,9 @@ class MetaRouter:
],
},
)
selected = response.strip().lower().replace('"', '').replace("'", "")
if usage:
usage.track(result)
selected = result.text.strip().lower().replace('"', '').replace("'", "")
if selected in self.task_types:
framing = self.task_types[selected].get("framing", DEFAULT_FRAMING)
@ -100,7 +102,7 @@ class MetaRouter:
)
return DEFAULT_TASK_TYPE, framing
async def select_pattern(self, question, task_type, context):
async def select_pattern(self, question, task_type, context, usage=None):
"""
Use the LLM to select the best execution pattern for this task type.
@ -120,7 +122,7 @@ class MetaRouter:
try:
client = context("prompt-request")
response = await client.prompt(
result = await client.prompt(
id="pattern-select",
variables={
"question": question,
@ -133,7 +135,9 @@ class MetaRouter:
],
},
)
selected = response.strip().lower().replace('"', '').replace("'", "")
if usage:
usage.track(result)
selected = result.text.strip().lower().replace('"', '').replace("'", "")
if selected in valid_patterns:
logger.info(f"MetaRouter: selected pattern '{selected}'")
@ -148,19 +152,20 @@ class MetaRouter:
logger.warning(f"MetaRouter: pattern selection failed: {e}")
return valid_patterns[0] if valid_patterns else DEFAULT_PATTERN
async def route(self, question, context):
async def route(self, question, context, usage=None):
"""
Full routing pipeline: identify task type, then select pattern.
Args:
question: The user's query.
context: UserAwareContext (flow wrapper).
usage: Optional UsageTracker for token counting.
Returns:
(pattern, task_type, framing) tuple.
"""
task_type, framing = await self.identify_task_type(question, context)
pattern = await self.select_pattern(question, task_type, context)
task_type, framing = await self.identify_task_type(question, context, usage=usage)
pattern = await self.select_pattern(question, task_type, context, usage=usage)
logger.info(
f"MetaRouter: route result — "
f"pattern={pattern}, task_type={task_type}, framing={framing!r}"

View file

@ -65,6 +65,37 @@ class UserAwareContext:
return client
class UsageTracker:
"""Accumulates token usage across multiple prompt calls."""
def __init__(self):
self.total_in = 0
self.total_out = 0
self.last_model = None
def track(self, result):
"""Track usage from a PromptResult."""
if result is not None:
if getattr(result, "in_token", None) is not None:
self.total_in += result.in_token
if getattr(result, "out_token", None) is not None:
self.total_out += result.out_token
if getattr(result, "model", None) is not None:
self.last_model = result.model
@property
def in_token(self):
return self.total_in if self.total_in > 0 else None
@property
def out_token(self):
return self.total_out if self.total_out > 0 else None
@property
def model(self):
return self.last_model
class PatternBase:
"""
Shared infrastructure for all agent patterns.
@ -571,7 +602,8 @@ class PatternBase:
# ---- Response helpers ---------------------------------------------------
async def prompt_as_answer(self, client, prompt_id, variables,
respond, streaming, message_id=""):
respond, streaming, message_id="",
usage=None):
"""Call a prompt template, forwarding chunks as answer
AgentResponse messages when streaming is enabled.
@ -591,22 +623,28 @@ class PatternBase:
message_id=message_id,
))
await client.prompt(
result = await client.prompt(
id=prompt_id,
variables=variables,
streaming=True,
chunk_callback=on_chunk,
)
if usage:
usage.track(result)
return "".join(accumulated)
else:
return await client.prompt(
result = await client.prompt(
id=prompt_id,
variables=variables,
)
if usage:
usage.track(result)
return result.text
async def send_final_response(self, respond, streaming, answer_text,
already_streamed=False, message_id=""):
already_streamed=False, message_id="",
usage=None):
"""Send the answer content and end-of-dialog marker.
Args:
@ -614,7 +652,16 @@ class PatternBase:
via streaming callbacks (e.g. ReactPattern). Only the
end-of-dialog marker is emitted.
message_id: Provenance URI for the answer entity.
usage: UsageTracker with accumulated token counts.
"""
usage_kwargs = {}
if usage:
usage_kwargs = {
"in_token": usage.in_token,
"out_token": usage.out_token,
"model": usage.model,
}
if streaming and not already_streamed:
# Answer wasn't streamed yet — send it as a chunk first
if answer_text:
@ -626,13 +673,14 @@ class PatternBase:
message_id=message_id,
))
if streaming:
# End-of-dialog marker
# End-of-dialog marker with usage
await respond(AgentResponse(
chunk_type="answer",
content="",
end_of_message=True,
end_of_dialog=True,
message_id=message_id,
**usage_kwargs,
))
else:
await respond(AgentResponse(
@ -641,6 +689,7 @@ class PatternBase:
end_of_message=True,
end_of_dialog=True,
message_id=message_id,
**usage_kwargs,
))
def build_next_request(self, request, history, session_id, collection,

View file

@ -18,7 +18,7 @@ from trustgraph.provenance import (
agent_synthesis_uri,
)
from . pattern_base import PatternBase
from . pattern_base import PatternBase, UsageTracker
logger = logging.getLogger(__name__)
@ -35,7 +35,10 @@ class PlanThenExecutePattern(PatternBase):
Subsequent calls execute the next pending plan step via ReACT.
"""
async def iterate(self, request, respond, next, flow):
async def iterate(self, request, respond, next, flow, usage=None):
if usage is None:
usage = UsageTracker()
streaming = getattr(request, 'streaming', False)
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@ -67,13 +70,13 @@ class PlanThenExecutePattern(PatternBase):
await self._planning_iteration(
request, respond, next, flow,
session_id, collection, streaming, session_uri,
iteration_num,
iteration_num, usage=usage,
)
else:
await self._execution_iteration(
request, respond, next, flow,
session_id, collection, streaming, session_uri,
iteration_num, plan,
iteration_num, plan, usage=usage,
)
def _extract_plan(self, history):
@ -98,7 +101,7 @@ class PlanThenExecutePattern(PatternBase):
async def _planning_iteration(self, request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num):
session_uri, iteration_num, usage=None):
"""Ask the LLM to produce a structured plan."""
think = self.make_think_callback(respond, streaming)
@ -113,7 +116,7 @@ class PlanThenExecutePattern(PatternBase):
client = context("prompt-request")
# Use the plan-create prompt template
plan_steps = await client.prompt(
result = await client.prompt(
id="plan-create",
variables={
"question": request.question,
@ -124,7 +127,10 @@ class PlanThenExecutePattern(PatternBase):
],
},
)
if usage:
usage.track(result)
plan_steps = result.objects
# Validate we got a list
if not isinstance(plan_steps, list) or not plan_steps:
logger.warning("plan-create returned invalid result, falling back to single step")
@ -187,7 +193,8 @@ class PlanThenExecutePattern(PatternBase):
async def _execution_iteration(self, request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num, plan):
session_uri, iteration_num, plan,
usage=None):
"""Execute the next pending plan step via single-shot tool call."""
pending_idx = self._find_next_pending_step(plan)
@ -198,6 +205,7 @@ class PlanThenExecutePattern(PatternBase):
request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num, plan,
usage=usage,
)
return
@ -240,7 +248,7 @@ class PlanThenExecutePattern(PatternBase):
client = context("prompt-request")
# Single-shot: ask LLM which tool + arguments to use for this goal
tool_call = await client.prompt(
result = await client.prompt(
id="plan-step-execute",
variables={
"goal": goal,
@ -258,7 +266,10 @@ class PlanThenExecutePattern(PatternBase):
],
},
)
if usage:
usage.track(result)
tool_call = result.object
tool_name = tool_call.get("tool", "")
tool_arguments = tool_call.get("arguments", {})
@ -330,7 +341,8 @@ class PlanThenExecutePattern(PatternBase):
async def _synthesise(self, request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num, plan):
session_uri, iteration_num, plan,
usage=None):
"""Synthesise a final answer from all completed plan step results."""
think = self.make_think_callback(respond, streaming)
@ -365,6 +377,7 @@ class PlanThenExecutePattern(PatternBase):
respond=respond,
streaming=streaming,
message_id=synthesis_msg_id,
usage=usage,
)
# Emit synthesis provenance (links back to last step result)
@ -380,4 +393,5 @@ class PlanThenExecutePattern(PatternBase):
await self.send_final_response(
respond, streaming, response_text, already_streamed=streaming,
message_id=synthesis_msg_id,
usage=usage,
)

View file

@ -23,7 +23,7 @@ from ..react.agent_manager import AgentManager
from ..react.types import Action, Final
from ..tool_filter import get_next_state
from . pattern_base import PatternBase
from . pattern_base import PatternBase, UsageTracker
logger = logging.getLogger(__name__)
@ -37,7 +37,10 @@ class ReactPattern(PatternBase):
result is appended to history and a next-request is emitted.
"""
async def iterate(self, request, respond, next, flow):
async def iterate(self, request, respond, next, flow, usage=None):
if usage is None:
usage = UsageTracker()
streaming = getattr(request, 'streaming', False)
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@ -121,6 +124,7 @@ class ReactPattern(PatternBase):
context=context,
streaming=streaming,
on_action=on_action,
usage=usage,
)
logger.debug(f"Action: {act}")
@ -144,6 +148,7 @@ class ReactPattern(PatternBase):
await self.send_final_response(
respond, streaming, f, already_streamed=streaming,
message_id=answer_msg_id,
usage=usage,
)
return

View file

@ -23,6 +23,7 @@ from ... base import Consumer, Producer
from ... base import ConsumerMetrics, ProducerMetrics
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from ..orchestrator.pattern_base import UsageTracker
from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
@ -493,6 +494,8 @@ class Processor(AgentService):
async def agent_request(self, request, respond, next, flow):
usage = UsageTracker()
try:
# Intercept subagent completion messages
@ -516,7 +519,7 @@ class Processor(AgentService):
if self.meta_router:
pattern, task_type, framing = await self.meta_router.route(
request.question, context,
request.question, context, usage=usage,
)
else:
pattern = "react"
@ -536,16 +539,16 @@ class Processor(AgentService):
# Dispatch to the selected pattern
if pattern == "plan-then-execute":
await self.plan_pattern.iterate(
request, respond, next, flow,
request, respond, next, flow, usage=usage,
)
elif pattern == "supervisor":
await self.supervisor_pattern.iterate(
request, respond, next, flow,
request, respond, next, flow, usage=usage,
)
else:
# Default to react
await self.react_pattern.iterate(
request, respond, next, flow,
request, respond, next, flow, usage=usage,
)
except Exception as e:

View file

@ -22,7 +22,7 @@ from trustgraph.provenance import (
agent_synthesis_uri,
)
from . pattern_base import PatternBase
from . pattern_base import PatternBase, UsageTracker
logger = logging.getLogger(__name__)
@ -38,7 +38,10 @@ class SupervisorPattern(PatternBase):
- "synthesise": triggered by aggregator with results in subagent_results
"""
async def iterate(self, request, respond, next, flow):
async def iterate(self, request, respond, next, flow, usage=None):
if usage is None:
usage = UsageTracker()
streaming = getattr(request, 'streaming', False)
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@ -72,17 +75,19 @@ class SupervisorPattern(PatternBase):
request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num,
usage=usage,
)
else:
await self._decompose_and_fanout(
request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num,
usage=usage,
)
async def _decompose_and_fanout(self, request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num):
session_uri, iteration_num, usage=None):
"""Decompose the question into sub-goals and fan out subagents."""
decompose_msg_id = agent_decomposition_uri(session_id)
@ -100,7 +105,7 @@ class SupervisorPattern(PatternBase):
client = context("prompt-request")
# Use the supervisor-decompose prompt template
goals = await client.prompt(
result = await client.prompt(
id="supervisor-decompose",
variables={
"question": request.question,
@ -112,7 +117,10 @@ class SupervisorPattern(PatternBase):
],
},
)
if usage:
usage.track(result)
goals = result.objects
# Validate result
if not isinstance(goals, list):
goals = []
@ -175,7 +183,7 @@ class SupervisorPattern(PatternBase):
async def _synthesise(self, request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num):
session_uri, iteration_num, usage=None):
"""Synthesise final answer from subagent results."""
synthesis_msg_id = agent_synthesis_uri(session_id)
@ -216,6 +224,7 @@ class SupervisorPattern(PatternBase):
respond=respond,
streaming=streaming,
message_id=synthesis_msg_id,
usage=usage,
)
# Emit synthesis provenance (links back to all findings)
@ -231,4 +240,5 @@ class SupervisorPattern(PatternBase):
await self.send_final_response(
respond, streaming, response_text, already_streamed=streaming,
message_id=synthesis_msg_id,
usage=usage,
)

View file

@ -170,7 +170,7 @@ class AgentManager:
raise ValueError(f"Could not parse response: {text}")
async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None):
async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None, usage=None):
logger.debug(f"calling reason: {question}")
@ -255,11 +255,13 @@ class AgentManager:
client = context("prompt-request")
# Get streaming response
response_text = await client.agent_react(
prompt_result = await client.agent_react(
variables=variables,
streaming=True,
chunk_callback=on_chunk
)
if usage:
usage.track(prompt_result)
# Finalize parser
parser.finalize()
@ -275,10 +277,13 @@ class AgentManager:
# Non-streaming path - get complete text and parse
client = context("prompt-request")
response_text = await client.agent_react(
prompt_result = await client.agent_react(
variables=variables,
streaming=False
)
if usage:
usage.track(prompt_result)
response_text = prompt_result.text
logger.debug(f"Response text:\n{response_text}")
@ -292,7 +297,8 @@ class AgentManager:
raise RuntimeError(f"Failed to parse agent response: {e}")
async def react(self, question, history, think, observe, context,
streaming=False, answer=None, on_action=None):
streaming=False, answer=None, on_action=None,
usage=None):
act = await self.reason(
question = question,
@ -302,6 +308,7 @@ class AgentManager:
think = think,
observe = observe,
answer = answer,
usage = usage,
)
if isinstance(act, Final):

View file

@ -78,9 +78,10 @@ class TextCompletionImpl:
async def invoke(self, **arguments):
client = self.context("prompt-request")
logger.debug("Prompt question...")
return await client.question(
result = await client.question(
arguments.get("question")
)
return result.text
# This tool implementation knows how to do MCP tool invocation. This uses
# the mcp-tool service.
@ -227,10 +228,11 @@ class PromptImpl:
async def invoke(self, **arguments):
client = self.context("prompt-request")
logger.debug(f"Prompt template invocation: {self.template_id}...")
return await client.prompt(
result = await client.prompt(
id=self.template_id,
variables=arguments
)
return result.text
# This tool implementation invokes a dynamically configured tool service

View file

@ -117,10 +117,11 @@ class Processor(FlowProcessor):
try:
defs = await flow("prompt-request").extract_definitions(
result = await flow("prompt-request").extract_definitions(
text = chunk
)
defs = result.objects
logger.debug(f"Definitions response: {defs}")
if type(defs) != list:

View file

@ -376,10 +376,11 @@ class Processor(FlowProcessor):
"""
try:
# Call prompt service with simplified format prompt
extraction_response = await flow("prompt-request").prompt(
result = await flow("prompt-request").prompt(
id="extract-with-ontologies",
variables=prompt_variables
)
extraction_response = result.object
logger.debug(f"Simplified extraction response: {extraction_response}")
# Parse response into structured format

View file

@ -100,10 +100,11 @@ class Processor(FlowProcessor):
try:
rels = await flow("prompt-request").extract_relationships(
result = await flow("prompt-request").extract_relationships(
text = chunk
)
rels = result.objects
logger.debug(f"Prompt response: {rels}")
if type(rels) != list:

View file

@ -148,11 +148,12 @@ class Processor(FlowProcessor):
schema_dict = row_schema_translator.encode(schema)
# Use prompt client to extract rows based on schema
objects = await flow("prompt-request").extract_objects(
result = await flow("prompt-request").extract_objects(
schema=schema_dict,
text=text
)
objects = result.objects
if not isinstance(objects, list):
return []

View file

@ -11,7 +11,6 @@ import logging
from ...schema import Definition, Relationship, Triple
from ...schema import Topic
from ...schema import PromptRequest, PromptResponse, Error
from ...schema import TextCompletionRequest, TextCompletionResponse
from ...base import FlowProcessor
from ...base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
@ -124,35 +123,26 @@ class Processor(FlowProcessor):
logger.debug(f"System prompt: {system}")
logger.debug(f"User prompt: {prompt}")
# Use the text completion client with recipient handler
client = flow("text-completion-request")
async def forward_chunks(resp):
if resp.error:
raise RuntimeError(resp.error.message)
is_final = getattr(resp, 'end_of_stream', False)
# Always send a message if there's content OR if it's the final message
if resp.response or is_final:
# Forward each chunk immediately
r = PromptResponse(
text=resp.response if resp.response else "",
object=None,
error=None,
end_of_stream=is_final,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
await flow("response").send(r, properties={"id": id})
# Return True when end_of_stream
return is_final
await client.request(
TextCompletionRequest(
system=system, prompt=prompt, streaming=True
),
recipient=forward_chunks,
timeout=600
await flow("text-completion-request").text_completion_stream(
system=system, prompt=prompt,
handler=forward_chunks,
timeout=600,
)
# Return empty string since we already sent all chunks
@ -167,17 +157,21 @@ class Processor(FlowProcessor):
return
# Non-streaming path (original behavior)
usage = {}
async def llm(system, prompt):
logger.debug(f"System prompt: {system}")
logger.debug(f"User prompt: {prompt}")
resp = await flow("text-completion-request").text_completion(
system = system, prompt = prompt, streaming = False,
)
try:
return resp
result = await flow("text-completion-request").text_completion(
system = system, prompt = prompt,
)
usage["in_token"] = result.in_token
usage["out_token"] = result.out_token
usage["model"] = result.model
return result.text
except Exception as e:
logger.error(f"LLM Exception: {e}", exc_info=True)
return None
@ -199,6 +193,9 @@ class Processor(FlowProcessor):
object=None,
error=None,
end_of_stream=True,
in_token=usage.get("in_token", 0),
out_token=usage.get("out_token", 0),
model=usage.get("model", ""),
)
await flow("response").send(r, properties={"id": id})
@ -215,6 +212,9 @@ class Processor(FlowProcessor):
object=json.dumps(resp),
error=None,
end_of_stream=True,
in_token=usage.get("in_token", 0),
out_token=usage.get("out_token", 0),
model=usage.get("model", ""),
)
await flow("response").send(r, properties={"id": id})

View file

@ -27,24 +27,27 @@ class Query:
def __init__(
self, rag, user, collection, verbose,
doc_limit=20
doc_limit=20, track_usage=None,
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.doc_limit = doc_limit
self.track_usage = track_usage
async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding."""
response = await self.rag.prompt_client.prompt(
result = await self.rag.prompt_client.prompt(
"extract-concepts",
variables={"query": query}
)
if self.track_usage:
self.track_usage(result)
concepts = []
if isinstance(response, str):
for line in response.strip().split('\n'):
if result.text:
for line in result.text.strip().split('\n'):
line = line.strip()
if line:
concepts.append(line)
@ -167,8 +170,23 @@ class DocumentRag:
save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian
Returns:
str: The synthesized answer text
tuple: (answer_text, usage) where usage is a dict with
in_token, out_token, model
"""
total_in = 0
total_out = 0
last_model = None
def track_usage(result):
nonlocal total_in, total_out, last_model
if result is not None:
if result.in_token is not None:
total_in += result.in_token
if result.out_token is not None:
total_out += result.out_token
if result.model is not None:
last_model = result.model
if self.verbose:
logger.debug("Constructing prompt...")
@ -191,7 +209,7 @@ class DocumentRag:
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose,
doc_limit=doc_limit
doc_limit=doc_limit, track_usage=track_usage,
)
# Extract concepts from query (grounding step)
@ -228,19 +246,22 @@ class DocumentRag:
accumulated_chunks.append(chunk)
await chunk_callback(chunk, end_of_stream)
resp = await self.prompt_client.document_prompt(
synthesis_result = await self.prompt_client.document_prompt(
query=query,
documents=docs,
streaming=True,
chunk_callback=accumulating_callback
)
track_usage(synthesis_result)
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
resp = await self.prompt_client.document_prompt(
synthesis_result = await self.prompt_client.document_prompt(
query=query,
documents=docs
)
track_usage(synthesis_result)
resp = synthesis_result.text
if self.verbose:
logger.debug("Query processing complete")
@ -273,5 +294,11 @@ class DocumentRag:
if self.verbose:
logger.debug(f"Emitted explain for session {session_id}")
return resp
usage = {
"in_token": total_in if total_in > 0 else None,
"out_token": total_out if total_out > 0 else None,
"model": last_model,
}
return resp, usage

View file

@ -200,7 +200,7 @@ class Processor(FlowProcessor):
# Query with streaming enabled
# All chunks (including final one with end_of_stream=True) are sent via callback
await self.rag.query(
response, usage = await self.rag.query(
v.query,
user=v.user,
collection=v.collection,
@ -217,12 +217,15 @@ class Processor(FlowProcessor):
response=None,
end_of_session=True,
message_type="end",
in_token=usage.get("in_token"),
out_token=usage.get("out_token"),
model=usage.get("model"),
),
properties={"id": id}
)
else:
# Non-streaming path (existing behavior)
response = await self.rag.query(
# Non-streaming path - single response with answer and token usage
response, usage = await self.rag.query(
v.query,
user=v.user,
collection=v.collection,
@ -233,11 +236,15 @@ class Processor(FlowProcessor):
await flow("response").send(
DocumentRagResponse(
response = response,
end_of_stream = True,
error = None
response=response,
end_of_stream=True,
end_of_session=True,
error=None,
in_token=usage.get("in_token"),
out_token=usage.get("out_token"),
model=usage.get("model"),
),
properties = {"id": id}
properties={"id": id}
)
logger.info("Request processing complete")

View file

@ -121,7 +121,7 @@ class Query:
def __init__(
self, rag, user, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2,
max_path_length=2, track_usage=None,
):
self.rag = rag
self.user = user
@ -131,17 +131,20 @@ class Query:
self.triple_limit = triple_limit
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
self.track_usage = track_usage
async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding."""
response = await self.rag.prompt_client.prompt(
result = await self.rag.prompt_client.prompt(
"extract-concepts",
variables={"query": query}
)
if self.track_usage:
self.track_usage(result)
concepts = []
if isinstance(response, str):
for line in response.strip().split('\n'):
if result.text:
for line in result.text.strip().split('\n'):
line = line.strip()
if line:
concepts.append(line)
@ -609,8 +612,24 @@ class GraphRag:
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
Returns:
str: The synthesized answer text
tuple: (answer_text, usage) where usage is a dict with
in_token, out_token, model
"""
# Accumulate token usage across all prompt calls
total_in = 0
total_out = 0
last_model = None
def track_usage(result):
nonlocal total_in, total_out, last_model
if result is not None:
if result.in_token is not None:
total_in += result.in_token
if result.out_token is not None:
total_out += result.out_token
if result.model is not None:
last_model = result.model
if self.verbose:
logger.debug("Constructing prompt...")
@ -641,6 +660,7 @@ class GraphRag:
triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
track_usage = track_usage,
)
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
@ -751,21 +771,22 @@ class GraphRag:
logger.debug(f"Built edge map with {len(edge_map)} edges")
# Step 1a: Edge Scoring - LLM scores edges for relevance
scoring_response = await self.prompt_client.prompt(
scoring_result = await self.prompt_client.prompt(
"kg-edge-scoring",
variables={
"query": query,
"knowledge": edges_with_ids
}
)
track_usage(scoring_result)
if self.verbose:
logger.debug(f"Edge scoring response: {scoring_response}")
logger.debug(f"Edge scoring result: {scoring_result}")
# Parse scoring response to get edge IDs with scores
# Parse scoring response (jsonl) to get edge IDs with scores
scored_edges = []
def parse_scored_edge(obj):
for obj in scoring_result.objects or []:
if isinstance(obj, dict) and "id" in obj and "score" in obj:
try:
score = int(obj["score"])
@ -773,21 +794,6 @@ class GraphRag:
score = 0
scored_edges.append({"id": obj["id"], "score": score})
if isinstance(scoring_response, list):
for obj in scoring_response:
parse_scored_edge(obj)
elif isinstance(scoring_response, str):
for line in scoring_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
parse_scored_edge(json.loads(line))
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge scoring line: {line}"
)
# Select top N edges by score
scored_edges.sort(key=lambda x: x["score"], reverse=True)
top_edges = scored_edges[:edge_limit]
@ -821,25 +827,30 @@ class GraphRag:
]
# Run reasoning and document tracing concurrently
reasoning_task = self.prompt_client.prompt(
"kg-edge-reasoning",
variables={
"query": query,
"knowledge": selected_edges_with_ids
}
)
async def _get_reasoning():
result = await self.prompt_client.prompt(
"kg-edge-reasoning",
variables={
"query": query,
"knowledge": selected_edges_with_ids
}
)
track_usage(result)
return result
reasoning_task = _get_reasoning()
doc_trace_task = q.trace_source_documents(selected_edge_uris)
reasoning_response, source_documents = await asyncio.gather(
reasoning_result, source_documents = await asyncio.gather(
reasoning_task, doc_trace_task, return_exceptions=True
)
# Handle exceptions from gather
if isinstance(reasoning_response, Exception):
if isinstance(reasoning_result, Exception):
logger.warning(
f"Edge reasoning failed: {reasoning_response}"
f"Edge reasoning failed: {reasoning_result}"
)
reasoning_response = ""
reasoning_result = None
if isinstance(source_documents, Exception):
logger.warning(
f"Document tracing failed: {source_documents}"
@ -848,29 +859,15 @@ class GraphRag:
if self.verbose:
logger.debug(f"Edge reasoning response: {reasoning_response}")
logger.debug(f"Edge reasoning result: {reasoning_result}")
# Parse reasoning response and build explainability data
# Parse reasoning response (jsonl) and build explainability data
reasoning_map = {}
def parse_reasoning(obj):
if isinstance(obj, dict) and "id" in obj:
reasoning_map[obj["id"]] = obj.get("reasoning", "")
if isinstance(reasoning_response, list):
for obj in reasoning_response:
parse_reasoning(obj)
elif isinstance(reasoning_response, str):
for line in reasoning_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
parse_reasoning(json.loads(line))
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge reasoning line: {line}"
)
if reasoning_result is not None:
for obj in reasoning_result.objects or []:
if isinstance(obj, dict) and "id" in obj:
reasoning_map[obj["id"]] = obj.get("reasoning", "")
selected_edges_with_reasoning = []
for eid in selected_ids:
@ -919,19 +916,22 @@ class GraphRag:
accumulated_chunks.append(chunk)
await chunk_callback(chunk, end_of_stream)
await self.prompt_client.prompt(
synthesis_result = await self.prompt_client.prompt(
"kg-synthesis",
variables=synthesis_variables,
streaming=True,
chunk_callback=accumulating_callback
)
track_usage(synthesis_result)
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
resp = await self.prompt_client.prompt(
synthesis_result = await self.prompt_client.prompt(
"kg-synthesis",
variables=synthesis_variables,
)
track_usage(synthesis_result)
resp = synthesis_result.text
if self.verbose:
logger.debug("Query processing complete")
@ -964,5 +964,11 @@ class GraphRag:
if self.verbose:
logger.debug(f"Emitted explain for session {session_id}")
return resp
usage = {
"in_token": total_in if total_in > 0 else None,
"out_token": total_out if total_out > 0 else None,
"model": last_model,
}
return resp, usage

View file

@ -332,7 +332,7 @@ class Processor(FlowProcessor):
)
# Query with streaming and real-time explain
response = await rag.query(
response, usage = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
@ -348,7 +348,7 @@ class Processor(FlowProcessor):
else:
# Non-streaming path with real-time explain
response = await rag.query(
response, usage = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
@ -360,23 +360,30 @@ class Processor(FlowProcessor):
parent_uri = v.parent_uri,
)
# Send chunk with response
# Send single response with answer and token usage
await flow("response").send(
GraphRagResponse(
message_type="chunk",
response=response,
end_of_stream=True,
error=None,
end_of_session=True,
in_token=usage.get("in_token"),
out_token=usage.get("out_token"),
model=usage.get("model"),
),
properties={"id": id}
)
return
# Send final message to close session
# Streaming: send final message to close session with token usage
await flow("response").send(
GraphRagResponse(
message_type="chunk",
response="",
end_of_session=True,
in_token=usage.get("in_token"),
out_token=usage.get("out_token"),
model=usage.get("model"),
),
properties={"id": id}
)