Mirror of https://github.com/trustgraph-ai/trustgraph.git, synced 2026-05-09 15:22:38 +02:00
Expose LLM token usage across all service layers (#782)
Expose LLM token usage (in_token, out_token, model) across all service layers. Propagate token counts from LLM services through the prompt, text-completion, graph-RAG, document-RAG, and agent orchestrator pipelines to the API gateway and Python SDK. All fields are Optional; None means "not available", distinguishing it from a real zero count.

Key changes:

- Schema: add in_token/out_token/model to TextCompletionResponse, PromptResponse, GraphRagResponse, DocumentRagResponse, AgentResponse.
- TextCompletionClient: new TextCompletionResult return type; split into text_completion() (non-streaming) and text_completion_stream() (streaming, with a per-chunk handler callback).
- PromptClient: new PromptResult with response_type (text/json/jsonl), typed fields (text/object/objects), and token usage. All callers updated.
- RAG services: accumulate token usage across all prompt calls (extract-concepts, edge-scoring, edge-reasoning, synthesis). The non-streaming path sends a single combined response instead of chunk + end_of_session.
- Agent orchestrator: UsageTracker accumulates tokens across meta-router, pattern prompt calls, and ReACT reasoning; attached to the end_of_dialog message.
- Translators: encode token fields when not None (an "is not None" test, not truthiness).
- Python SDK: RAG and text-completion methods return TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with token fields (streaming).
- CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt, tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent.
This commit is contained in:
  parent 67cfa80836
  commit 14e49d83c7

60 changed files with 1252 additions and 577 deletions
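Note: a minimal sketch of what the new fields look like to a non-streaming caller. The client setup is assumed for illustration; the field semantics (text, in_token, out_token, model, with None meaning "not available") are the ones this commit introduces.

    # Assumed client exposing the text-completion flow (illustrative only)
    result = await client.text_completion(
        system="You are a helpful assistant.",
        prompt="Summarise the report.",
    )
    print(result.text)
    if result.in_token is not None:
        # None means "not available"; a genuine zero count is reported as 0
        print(f"{result.in_token} in / {result.out_token} out ({result.model})")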
@@ -53,7 +53,7 @@ class MetaRouter:
         "general": {"name": "general", "description": "General queries", "valid_patterns": ["react"], "framing": ""},
     }
 
-    async def identify_task_type(self, question, context):
+    async def identify_task_type(self, question, context, usage=None):
         """
         Use the LLM to classify the question into one of the known task types.
 
@@ -71,7 +71,7 @@ class MetaRouter:
 
         try:
             client = context("prompt-request")
-            response = await client.prompt(
+            result = await client.prompt(
                 id="task-type-classify",
                 variables={
                     "question": question,
@@ -81,7 +81,9 @@ class MetaRouter:
                     ],
                 },
             )
-            selected = response.strip().lower().replace('"', '').replace("'", "")
+            if usage:
+                usage.track(result)
+            selected = result.text.strip().lower().replace('"', '').replace("'", "")
 
             if selected in self.task_types:
                 framing = self.task_types[selected].get("framing", DEFAULT_FRAMING)
@@ -100,7 +102,7 @@ class MetaRouter:
             )
             return DEFAULT_TASK_TYPE, framing
 
-    async def select_pattern(self, question, task_type, context):
+    async def select_pattern(self, question, task_type, context, usage=None):
         """
         Use the LLM to select the best execution pattern for this task type.
 
@@ -120,7 +122,7 @@ class MetaRouter:
 
         try:
             client = context("prompt-request")
-            response = await client.prompt(
+            result = await client.prompt(
                 id="pattern-select",
                 variables={
                     "question": question,
@@ -133,7 +135,9 @@ class MetaRouter:
                     ],
                 },
             )
-            selected = response.strip().lower().replace('"', '').replace("'", "")
+            if usage:
+                usage.track(result)
+            selected = result.text.strip().lower().replace('"', '').replace("'", "")
 
             if selected in valid_patterns:
                 logger.info(f"MetaRouter: selected pattern '{selected}'")
@@ -148,19 +152,20 @@ class MetaRouter:
             logger.warning(f"MetaRouter: pattern selection failed: {e}")
             return valid_patterns[0] if valid_patterns else DEFAULT_PATTERN
 
-    async def route(self, question, context):
+    async def route(self, question, context, usage=None):
         """
         Full routing pipeline: identify task type, then select pattern.
 
         Args:
            question: The user's query.
            context: UserAwareContext (flow wrapper).
+           usage: Optional UsageTracker for token counting.
 
        Returns:
            (pattern, task_type, framing) tuple.
        """
-        task_type, framing = await self.identify_task_type(question, context)
-        pattern = await self.select_pattern(question, task_type, context)
+        task_type, framing = await self.identify_task_type(question, context, usage=usage)
+        pattern = await self.select_pattern(question, task_type, context, usage=usage)
        logger.info(
            f"MetaRouter: route result — "
            f"pattern={pattern}, task_type={task_type}, framing={framing!r}"

@@ -65,6 +65,37 @@ class UserAwareContext:
         return client
 
 
+class UsageTracker:
+    """Accumulates token usage across multiple prompt calls."""
+
+    def __init__(self):
+        self.total_in = 0
+        self.total_out = 0
+        self.last_model = None
+
+    def track(self, result):
+        """Track usage from a PromptResult."""
+        if result is not None:
+            if getattr(result, "in_token", None) is not None:
+                self.total_in += result.in_token
+            if getattr(result, "out_token", None) is not None:
+                self.total_out += result.out_token
+            if getattr(result, "model", None) is not None:
+                self.last_model = result.model
+
+    @property
+    def in_token(self):
+        return self.total_in if self.total_in > 0 else None
+
+    @property
+    def out_token(self):
+        return self.total_out if self.total_out > 0 else None
+
+    @property
+    def model(self):
+        return self.last_model
+
+
 class PatternBase:
     """
     Shared infrastructure for all agent patterns.
@@ -571,7 +602,8 @@ class PatternBase:
     # ---- Response helpers ---------------------------------------------------
 
     async def prompt_as_answer(self, client, prompt_id, variables,
-                               respond, streaming, message_id=""):
+                               respond, streaming, message_id="",
+                               usage=None):
         """Call a prompt template, forwarding chunks as answer
         AgentResponse messages when streaming is enabled.
 
@@ -591,22 +623,28 @@ class PatternBase:
                     message_id=message_id,
                 ))
 
-            await client.prompt(
+            result = await client.prompt(
                 id=prompt_id,
                 variables=variables,
                 streaming=True,
                 chunk_callback=on_chunk,
             )
+            if usage:
+                usage.track(result)
 
             return "".join(accumulated)
         else:
-            return await client.prompt(
+            result = await client.prompt(
                 id=prompt_id,
                 variables=variables,
             )
+            if usage:
+                usage.track(result)
+            return result.text
 
     async def send_final_response(self, respond, streaming, answer_text,
-                                  already_streamed=False, message_id=""):
+                                  already_streamed=False, message_id="",
+                                  usage=None):
         """Send the answer content and end-of-dialog marker.
 
         Args:
@@ -614,7 +652,16 @@ class PatternBase:
                 via streaming callbacks (e.g. ReactPattern). Only the
                 end-of-dialog marker is emitted.
             message_id: Provenance URI for the answer entity.
+            usage: UsageTracker with accumulated token counts.
         """
+        usage_kwargs = {}
+        if usage:
+            usage_kwargs = {
+                "in_token": usage.in_token,
+                "out_token": usage.out_token,
+                "model": usage.model,
+            }
+
         if streaming and not already_streamed:
             # Answer wasn't streamed yet — send it as a chunk first
             if answer_text:
@@ -626,13 +673,14 @@ class PatternBase:
                     message_id=message_id,
                 ))
         if streaming:
-            # End-of-dialog marker
+            # End-of-dialog marker with usage
             await respond(AgentResponse(
                 chunk_type="answer",
                 content="",
                 end_of_message=True,
                 end_of_dialog=True,
                 message_id=message_id,
+                **usage_kwargs,
             ))
         else:
             await respond(AgentResponse(
@@ -641,6 +689,7 @@ class PatternBase:
                 end_of_message=True,
                 end_of_dialog=True,
                 message_id=message_id,
+                **usage_kwargs,
             ))
 
     def build_next_request(self, request, history, session_id, collection,

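Note: the UsageTracker above can be exercised in isolation; this sketch uses SimpleNamespace as a stand-in for a PromptResult, and the model name is made up.

    from types import SimpleNamespace

    tracker = UsageTracker()
    tracker.track(SimpleNamespace(in_token=120, out_token=45, model="some-model"))
    tracker.track(SimpleNamespace(in_token=80, out_token=30, model="some-model"))
    assert tracker.in_token == 200 and tracker.out_token == 75
    # With nothing tracked, the properties report None rather than 0
    assert UsageTracker().in_token is None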
@@ -18,7 +18,7 @@ from trustgraph.provenance import (
     agent_synthesis_uri,
 )
 
-from . pattern_base import PatternBase
+from . pattern_base import PatternBase, UsageTracker
 
 logger = logging.getLogger(__name__)
 
@@ -35,7 +35,10 @@ class PlanThenExecutePattern(PatternBase):
     Subsequent calls execute the next pending plan step via ReACT.
     """
 
-    async def iterate(self, request, respond, next, flow):
+    async def iterate(self, request, respond, next, flow, usage=None):
+
+        if usage is None:
+            usage = UsageTracker()
 
         streaming = getattr(request, 'streaming', False)
         session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@@ -67,13 +70,13 @@ class PlanThenExecutePattern(PatternBase):
             await self._planning_iteration(
                 request, respond, next, flow,
                 session_id, collection, streaming, session_uri,
-                iteration_num,
+                iteration_num, usage=usage,
             )
         else:
             await self._execution_iteration(
                 request, respond, next, flow,
                 session_id, collection, streaming, session_uri,
-                iteration_num, plan,
+                iteration_num, plan, usage=usage,
             )
 
     def _extract_plan(self, history):
@@ -98,7 +101,7 @@ class PlanThenExecutePattern(PatternBase):
 
     async def _planning_iteration(self, request, respond, next, flow,
                                   session_id, collection, streaming,
-                                  session_uri, iteration_num):
+                                  session_uri, iteration_num, usage=None):
         """Ask the LLM to produce a structured plan."""
 
         think = self.make_think_callback(respond, streaming)
@@ -113,7 +116,7 @@ class PlanThenExecutePattern(PatternBase):
         client = context("prompt-request")
 
         # Use the plan-create prompt template
-        plan_steps = await client.prompt(
+        result = await client.prompt(
             id="plan-create",
             variables={
                 "question": request.question,
@@ -124,7 +127,10 @@ class PlanThenExecutePattern(PatternBase):
                 ],
             },
         )
+        if usage:
+            usage.track(result)
 
+        plan_steps = result.objects
         # Validate we got a list
         if not isinstance(plan_steps, list) or not plan_steps:
             logger.warning("plan-create returned invalid result, falling back to single step")
@@ -187,7 +193,8 @@ class PlanThenExecutePattern(PatternBase):
 
     async def _execution_iteration(self, request, respond, next, flow,
                                    session_id, collection, streaming,
-                                   session_uri, iteration_num, plan):
+                                   session_uri, iteration_num, plan,
+                                   usage=None):
         """Execute the next pending plan step via single-shot tool call."""
 
         pending_idx = self._find_next_pending_step(plan)
@@ -198,6 +205,7 @@ class PlanThenExecutePattern(PatternBase):
                 request, respond, next, flow,
                 session_id, collection, streaming,
                 session_uri, iteration_num, plan,
+                usage=usage,
             )
             return
 
@@ -240,7 +248,7 @@ class PlanThenExecutePattern(PatternBase):
         client = context("prompt-request")
 
         # Single-shot: ask LLM which tool + arguments to use for this goal
-        tool_call = await client.prompt(
+        result = await client.prompt(
             id="plan-step-execute",
             variables={
                 "goal": goal,
@@ -258,7 +266,10 @@ class PlanThenExecutePattern(PatternBase):
                 ],
             },
         )
+        if usage:
+            usage.track(result)
 
+        tool_call = result.object
         tool_name = tool_call.get("tool", "")
         tool_arguments = tool_call.get("arguments", {})
 
@@ -330,7 +341,8 @@ class PlanThenExecutePattern(PatternBase):
 
     async def _synthesise(self, request, respond, next, flow,
                           session_id, collection, streaming,
-                          session_uri, iteration_num, plan):
+                          session_uri, iteration_num, plan,
+                          usage=None):
         """Synthesise a final answer from all completed plan step results."""
 
         think = self.make_think_callback(respond, streaming)
@@ -365,6 +377,7 @@ class PlanThenExecutePattern(PatternBase):
             respond=respond,
             streaming=streaming,
             message_id=synthesis_msg_id,
+            usage=usage,
         )
 
         # Emit synthesis provenance (links back to last step result)
@@ -380,4 +393,5 @@ class PlanThenExecutePattern(PatternBase):
         await self.send_final_response(
             respond, streaming, response_text, already_streamed=streaming,
             message_id=synthesis_msg_id,
+            usage=usage,
         )

@@ -23,7 +23,7 @@ from ..react.agent_manager import AgentManager
 from ..react.types import Action, Final
 from ..tool_filter import get_next_state
 
-from . pattern_base import PatternBase
+from . pattern_base import PatternBase, UsageTracker
 
 logger = logging.getLogger(__name__)
 
@@ -37,7 +37,10 @@ class ReactPattern(PatternBase):
     result is appended to history and a next-request is emitted.
     """
 
-    async def iterate(self, request, respond, next, flow):
+    async def iterate(self, request, respond, next, flow, usage=None):
+
+        if usage is None:
+            usage = UsageTracker()
 
         streaming = getattr(request, 'streaming', False)
         session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@@ -121,6 +124,7 @@ class ReactPattern(PatternBase):
             context=context,
             streaming=streaming,
             on_action=on_action,
+            usage=usage,
         )
 
         logger.debug(f"Action: {act}")
@@ -144,6 +148,7 @@ class ReactPattern(PatternBase):
             await self.send_final_response(
                 respond, streaming, f, already_streamed=streaming,
                 message_id=answer_msg_id,
+                usage=usage,
             )
             return

@@ -23,6 +23,7 @@ from ... base import Consumer, Producer
 from ... base import ConsumerMetrics, ProducerMetrics
 
 from ... schema import AgentRequest, AgentResponse, AgentStep, Error
+from ..orchestrator.pattern_base import UsageTracker
 from ... schema import Triples, Metadata
 from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
 from ... schema import librarian_request_queue, librarian_response_queue
@@ -493,6 +494,8 @@ class Processor(AgentService):
 
     async def agent_request(self, request, respond, next, flow):
 
+        usage = UsageTracker()
+
         try:
 
             # Intercept subagent completion messages
@@ -516,7 +519,7 @@ class Processor(AgentService):
 
             if self.meta_router:
                 pattern, task_type, framing = await self.meta_router.route(
-                    request.question, context,
+                    request.question, context, usage=usage,
                 )
             else:
                 pattern = "react"
@@ -536,16 +539,16 @@ class Processor(AgentService):
             # Dispatch to the selected pattern
             if pattern == "plan-then-execute":
                 await self.plan_pattern.iterate(
-                    request, respond, next, flow,
+                    request, respond, next, flow, usage=usage,
                 )
             elif pattern == "supervisor":
                 await self.supervisor_pattern.iterate(
-                    request, respond, next, flow,
+                    request, respond, next, flow, usage=usage,
                 )
             else:
                 # Default to react
                 await self.react_pattern.iterate(
-                    request, respond, next, flow,
+                    request, respond, next, flow, usage=usage,
                 )
 
         except Exception as e:

@@ -22,7 +22,7 @@ from trustgraph.provenance import (
     agent_synthesis_uri,
 )
 
-from . pattern_base import PatternBase
+from . pattern_base import PatternBase, UsageTracker
 
 logger = logging.getLogger(__name__)
 
@@ -38,7 +38,10 @@ class SupervisorPattern(PatternBase):
     - "synthesise": triggered by aggregator with results in subagent_results
     """
 
-    async def iterate(self, request, respond, next, flow):
+    async def iterate(self, request, respond, next, flow, usage=None):
+
+        if usage is None:
+            usage = UsageTracker()
 
         streaming = getattr(request, 'streaming', False)
         session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@@ -72,17 +75,19 @@ class SupervisorPattern(PatternBase):
                 request, respond, next, flow,
                 session_id, collection, streaming,
                 session_uri, iteration_num,
+                usage=usage,
             )
         else:
             await self._decompose_and_fanout(
                 request, respond, next, flow,
                 session_id, collection, streaming,
                 session_uri, iteration_num,
+                usage=usage,
             )
 
     async def _decompose_and_fanout(self, request, respond, next, flow,
                                     session_id, collection, streaming,
-                                    session_uri, iteration_num):
+                                    session_uri, iteration_num, usage=None):
         """Decompose the question into sub-goals and fan out subagents."""
 
         decompose_msg_id = agent_decomposition_uri(session_id)
@@ -100,7 +105,7 @@ class SupervisorPattern(PatternBase):
         client = context("prompt-request")
 
         # Use the supervisor-decompose prompt template
-        goals = await client.prompt(
+        result = await client.prompt(
             id="supervisor-decompose",
             variables={
                 "question": request.question,
@@ -112,7 +117,10 @@ class SupervisorPattern(PatternBase):
                 ],
             },
         )
+        if usage:
+            usage.track(result)
 
+        goals = result.objects
         # Validate result
         if not isinstance(goals, list):
             goals = []
@@ -175,7 +183,7 @@ class SupervisorPattern(PatternBase):
 
     async def _synthesise(self, request, respond, next, flow,
                           session_id, collection, streaming,
-                          session_uri, iteration_num):
+                          session_uri, iteration_num, usage=None):
         """Synthesise final answer from subagent results."""
 
         synthesis_msg_id = agent_synthesis_uri(session_id)
@@ -216,6 +224,7 @@ class SupervisorPattern(PatternBase):
             respond=respond,
             streaming=streaming,
             message_id=synthesis_msg_id,
+            usage=usage,
         )
 
         # Emit synthesis provenance (links back to all findings)
@@ -231,4 +240,5 @@ class SupervisorPattern(PatternBase):
         await self.send_final_response(
             respond, streaming, response_text, already_streamed=streaming,
             message_id=synthesis_msg_id,
+            usage=usage,
         )

@@ -170,7 +170,7 @@ class AgentManager:
 
         raise ValueError(f"Could not parse response: {text}")
 
-    async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None):
+    async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None, usage=None):
 
         logger.debug(f"calling reason: {question}")
 
@@ -255,11 +255,13 @@ class AgentManager:
             client = context("prompt-request")
 
             # Get streaming response
-            response_text = await client.agent_react(
+            prompt_result = await client.agent_react(
                 variables=variables,
                 streaming=True,
                 chunk_callback=on_chunk
             )
+            if usage:
+                usage.track(prompt_result)
 
             # Finalize parser
             parser.finalize()
@@ -275,10 +277,13 @@ class AgentManager:
             # Non-streaming path - get complete text and parse
             client = context("prompt-request")
 
-            response_text = await client.agent_react(
+            prompt_result = await client.agent_react(
                 variables=variables,
                 streaming=False
             )
+            if usage:
+                usage.track(prompt_result)
+            response_text = prompt_result.text
 
             logger.debug(f"Response text:\n{response_text}")
 
@@ -292,7 +297,8 @@ class AgentManager:
             raise RuntimeError(f"Failed to parse agent response: {e}")
 
     async def react(self, question, history, think, observe, context,
-                    streaming=False, answer=None, on_action=None):
+                    streaming=False, answer=None, on_action=None,
+                    usage=None):
 
         act = await self.reason(
             question = question,
@@ -302,6 +308,7 @@ class AgentManager:
             think = think,
             observe = observe,
             answer = answer,
+            usage = usage,
         )
 
         if isinstance(act, Final):

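Note: the usage tracker threads unchanged through the ReACT call chain, so one tracker passed at the top accumulates across every reasoning call. A sketch, with the manager construction and callbacks assumed:

    # Assumed: `manager` is an AgentManager, `context` a flow wrapper,
    # and think/observe are the usual callbacks.
    tracker = UsageTracker()
    act = await manager.react(
        question, history, think, observe, context,
        usage=tracker,  # reused across iterations to accumulate totals
    )
    # tracker.in_token / tracker.out_token now include this reasoning call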
@@ -78,9 +78,10 @@ class TextCompletionImpl:
     async def invoke(self, **arguments):
         client = self.context("prompt-request")
         logger.debug("Prompt question...")
-        return await client.question(
+        result = await client.question(
             arguments.get("question")
         )
+        return result.text
 
 # This tool implementation knows how to do MCP tool invocation. This uses
 # the mcp-tool service.
@@ -227,10 +228,11 @@ class PromptImpl:
     async def invoke(self, **arguments):
         client = self.context("prompt-request")
         logger.debug(f"Prompt template invocation: {self.template_id}...")
-        return await client.prompt(
+        result = await client.prompt(
             id=self.template_id,
             variables=arguments
         )
+        return result.text
 
 
 # This tool implementation invokes a dynamically configured tool service

@@ -117,10 +117,11 @@ class Processor(FlowProcessor):
 
         try:
 
-            defs = await flow("prompt-request").extract_definitions(
+            result = await flow("prompt-request").extract_definitions(
                 text = chunk
             )
 
+            defs = result.objects
             logger.debug(f"Definitions response: {defs}")
 
             if type(defs) != list:

@@ -376,10 +376,11 @@ class Processor(FlowProcessor):
         """
         try:
             # Call prompt service with simplified format prompt
-            extraction_response = await flow("prompt-request").prompt(
+            result = await flow("prompt-request").prompt(
                 id="extract-with-ontologies",
                 variables=prompt_variables
             )
+            extraction_response = result.object
             logger.debug(f"Simplified extraction response: {extraction_response}")
 
             # Parse response into structured format

@@ -100,10 +100,11 @@ class Processor(FlowProcessor):
 
         try:
 
-            rels = await flow("prompt-request").extract_relationships(
+            result = await flow("prompt-request").extract_relationships(
                 text = chunk
             )
 
+            rels = result.objects
             logger.debug(f"Prompt response: {rels}")
 
             if type(rels) != list:

@@ -148,11 +148,12 @@ class Processor(FlowProcessor):
         schema_dict = row_schema_translator.encode(schema)
 
         # Use prompt client to extract rows based on schema
-        objects = await flow("prompt-request").extract_objects(
+        result = await flow("prompt-request").extract_objects(
             schema=schema_dict,
             text=text
         )
 
+        objects = result.objects
         if not isinstance(objects, list):
             return []

@@ -11,7 +11,6 @@ import logging
 from ...schema import Definition, Relationship, Triple
 from ...schema import Topic
 from ...schema import PromptRequest, PromptResponse, Error
-from ...schema import TextCompletionRequest, TextCompletionResponse
 
 from ...base import FlowProcessor
 from ...base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
@@ -124,35 +123,26 @@ class Processor(FlowProcessor):
             logger.debug(f"System prompt: {system}")
             logger.debug(f"User prompt: {prompt}")
 
-            # Use the text completion client with recipient handler
-            client = flow("text-completion-request")
-
             async def forward_chunks(resp):
                 if resp.error:
                     raise RuntimeError(resp.error.message)
 
                 is_final = getattr(resp, 'end_of_stream', False)
 
                 # Always send a message if there's content OR if it's the final message
                 if resp.response or is_final:
                     # Forward each chunk immediately
                     r = PromptResponse(
                         text=resp.response if resp.response else "",
                         object=None,
                         error=None,
                         end_of_stream=is_final,
+                        in_token=resp.in_token,
+                        out_token=resp.out_token,
+                        model=resp.model,
                     )
                     await flow("response").send(r, properties={"id": id})
 
                 # Return True when end_of_stream
                 return is_final
 
-            await client.request(
-                TextCompletionRequest(
-                    system=system, prompt=prompt, streaming=True
-                ),
-                recipient=forward_chunks,
-                timeout=600
+            await flow("text-completion-request").text_completion_stream(
+                system=system, prompt=prompt,
+                handler=forward_chunks,
+                timeout=600,
             )
 
             # Return empty string since we already sent all chunks
@@ -167,17 +157,21 @@ class Processor(FlowProcessor):
             return
 
         # Non-streaming path (original behavior)
+        usage = {}
 
         async def llm(system, prompt):
 
             logger.debug(f"System prompt: {system}")
             logger.debug(f"User prompt: {prompt}")
 
-            resp = await flow("text-completion-request").text_completion(
-                system = system, prompt = prompt, streaming = False,
-            )
-
             try:
-                return resp
+                result = await flow("text-completion-request").text_completion(
+                    system = system, prompt = prompt,
+                )
+                usage["in_token"] = result.in_token
+                usage["out_token"] = result.out_token
+                usage["model"] = result.model
+                return result.text
             except Exception as e:
                 logger.error(f"LLM Exception: {e}", exc_info=True)
                 return None
@@ -199,6 +193,9 @@ class Processor(FlowProcessor):
                 object=None,
                 error=None,
                 end_of_stream=True,
+                in_token=usage.get("in_token", 0),
+                out_token=usage.get("out_token", 0),
+                model=usage.get("model", ""),
             )
 
             await flow("response").send(r, properties={"id": id})
@@ -215,6 +212,9 @@ class Processor(FlowProcessor):
                 object=json.dumps(resp),
                 error=None,
                 end_of_stream=True,
+                in_token=usage.get("in_token", 0),
+                out_token=usage.get("out_token", 0),
+                model=usage.get("model", ""),
             )
 
             await flow("response").send(r, properties={"id": id})

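Note: the commit message states that translators encode token fields only when they are not None (an "is not None" test rather than truthiness, so a genuine zero still round-trips). A hypothetical encoder illustrating that rule; the real translator classes live elsewhere in the changeset:

    def encode_usage(resp, obj):
        # "is not None", not truthiness: 0 is a real count and is kept
        if resp.in_token is not None:
            obj["in_token"] = resp.in_token
        if resp.out_token is not None:
            obj["out_token"] = resp.out_token
        if resp.model is not None:
            obj["model"] = resp.model
        return obj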
@@ -27,24 +27,27 @@ class Query:
 
     def __init__(
         self, rag, user, collection, verbose,
-        doc_limit=20
+        doc_limit=20, track_usage=None,
     ):
         self.rag = rag
         self.user = user
         self.collection = collection
         self.verbose = verbose
         self.doc_limit = doc_limit
+        self.track_usage = track_usage
 
     async def extract_concepts(self, query):
         """Extract key concepts from query for independent embedding."""
-        response = await self.rag.prompt_client.prompt(
+        result = await self.rag.prompt_client.prompt(
             "extract-concepts",
             variables={"query": query}
         )
+        if self.track_usage:
+            self.track_usage(result)
 
         concepts = []
-        if isinstance(response, str):
-            for line in response.strip().split('\n'):
+        if result.text:
+            for line in result.text.strip().split('\n'):
                 line = line.strip()
                 if line:
                     concepts.append(line)
@@ -167,8 +170,23 @@ class DocumentRag:
             save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian
 
         Returns:
-            str: The synthesized answer text
+            tuple: (answer_text, usage) where usage is a dict with
+            in_token, out_token, model
         """
+        total_in = 0
+        total_out = 0
+        last_model = None
+
+        def track_usage(result):
+            nonlocal total_in, total_out, last_model
+            if result is not None:
+                if result.in_token is not None:
+                    total_in += result.in_token
+                if result.out_token is not None:
+                    total_out += result.out_token
+                if result.model is not None:
+                    last_model = result.model
+
         if self.verbose:
             logger.debug("Constructing prompt...")
 
@@ -191,7 +209,7 @@ class DocumentRag:
 
         q = Query(
             rag=self, user=user, collection=collection, verbose=self.verbose,
-            doc_limit=doc_limit
+            doc_limit=doc_limit, track_usage=track_usage,
         )
 
         # Extract concepts from query (grounding step)
@@ -228,19 +246,22 @@ class DocumentRag:
                 accumulated_chunks.append(chunk)
                 await chunk_callback(chunk, end_of_stream)
 
-            resp = await self.prompt_client.document_prompt(
+            synthesis_result = await self.prompt_client.document_prompt(
                 query=query,
                 documents=docs,
                 streaming=True,
                 chunk_callback=accumulating_callback
             )
+            track_usage(synthesis_result)
             # Combine all chunks into full response
             resp = "".join(accumulated_chunks)
         else:
-            resp = await self.prompt_client.document_prompt(
+            synthesis_result = await self.prompt_client.document_prompt(
                 query=query,
                 documents=docs
             )
+            track_usage(synthesis_result)
+            resp = synthesis_result.text
 
         if self.verbose:
             logger.debug("Query processing complete")
@@ -273,5 +294,11 @@ class DocumentRag:
             if self.verbose:
                 logger.debug(f"Emitted explain for session {session_id}")
 
-        return resp
+        usage = {
+            "in_token": total_in if total_in > 0 else None,
+            "out_token": total_out if total_out > 0 else None,
+            "model": last_model,
+        }
 
+        return resp, usage

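Note: DocumentRag.query() now returns a (text, usage) tuple, so callers unpack it, as the service-layer hunks below do. A minimal sketch, assuming an already-constructed DocumentRag instance:

    response, usage = await rag.query(
        "What is TrustGraph?", user="trustgraph", collection="default",
    )
    print(response)
    # Each value may be None if the underlying LLM reported no usage
    print(usage["in_token"], usage["out_token"], usage["model"])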
@@ -200,7 +200,7 @@ class Processor(FlowProcessor):
 
             # Query with streaming enabled
             # All chunks (including final one with end_of_stream=True) are sent via callback
-            await self.rag.query(
+            response, usage = await self.rag.query(
                 v.query,
                 user=v.user,
                 collection=v.collection,
@@ -217,12 +217,15 @@ class Processor(FlowProcessor):
                     response=None,
                     end_of_session=True,
                     message_type="end",
+                    in_token=usage.get("in_token"),
+                    out_token=usage.get("out_token"),
+                    model=usage.get("model"),
                 ),
                 properties={"id": id}
             )
         else:
-            # Non-streaming path (existing behavior)
-            response = await self.rag.query(
+            # Non-streaming path - single response with answer and token usage
+            response, usage = await self.rag.query(
                 v.query,
                 user=v.user,
                 collection=v.collection,
@@ -233,11 +236,15 @@ class Processor(FlowProcessor):
 
             await flow("response").send(
                 DocumentRagResponse(
-                    response = response,
-                    end_of_stream = True,
-                    error = None
+                    response=response,
+                    end_of_stream=True,
+                    end_of_session=True,
+                    error=None,
+                    in_token=usage.get("in_token"),
+                    out_token=usage.get("out_token"),
+                    model=usage.get("model"),
                 ),
-                properties = {"id": id}
+                properties={"id": id}
             )
 
         logger.info("Request processing complete")

@@ -121,7 +121,7 @@ class Query:
     def __init__(
         self, rag, user, collection, verbose,
         entity_limit=50, triple_limit=30, max_subgraph_size=1000,
-        max_path_length=2,
+        max_path_length=2, track_usage=None,
     ):
         self.rag = rag
         self.user = user
@@ -131,17 +131,20 @@ class Query:
         self.triple_limit = triple_limit
         self.max_subgraph_size = max_subgraph_size
         self.max_path_length = max_path_length
+        self.track_usage = track_usage
 
     async def extract_concepts(self, query):
         """Extract key concepts from query for independent embedding."""
-        response = await self.rag.prompt_client.prompt(
+        result = await self.rag.prompt_client.prompt(
             "extract-concepts",
             variables={"query": query}
         )
+        if self.track_usage:
+            self.track_usage(result)
 
         concepts = []
-        if isinstance(response, str):
-            for line in response.strip().split('\n'):
+        if result.text:
+            for line in result.text.strip().split('\n'):
                 line = line.strip()
                 if line:
                     concepts.append(line)
@@ -609,8 +612,24 @@ class GraphRag:
             save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
 
         Returns:
-            str: The synthesized answer text
+            tuple: (answer_text, usage) where usage is a dict with
+            in_token, out_token, model
         """
+        # Accumulate token usage across all prompt calls
+        total_in = 0
+        total_out = 0
+        last_model = None
+
+        def track_usage(result):
+            nonlocal total_in, total_out, last_model
+            if result is not None:
+                if result.in_token is not None:
+                    total_in += result.in_token
+                if result.out_token is not None:
+                    total_out += result.out_token
+                if result.model is not None:
+                    last_model = result.model
+
         if self.verbose:
             logger.debug("Constructing prompt...")
 
@@ -641,6 +660,7 @@ class GraphRag:
             triple_limit = triple_limit,
             max_subgraph_size = max_subgraph_size,
             max_path_length = max_path_length,
+            track_usage = track_usage,
         )
 
         kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
@@ -751,21 +771,22 @@ class GraphRag:
         logger.debug(f"Built edge map with {len(edge_map)} edges")
 
         # Step 1a: Edge Scoring - LLM scores edges for relevance
-        scoring_response = await self.prompt_client.prompt(
+        scoring_result = await self.prompt_client.prompt(
             "kg-edge-scoring",
             variables={
                 "query": query,
                 "knowledge": edges_with_ids
             }
         )
+        track_usage(scoring_result)
 
         if self.verbose:
-            logger.debug(f"Edge scoring response: {scoring_response}")
+            logger.debug(f"Edge scoring result: {scoring_result}")
 
-        # Parse scoring response to get edge IDs with scores
+        # Parse scoring response (jsonl) to get edge IDs with scores
         scored_edges = []
 
-        def parse_scored_edge(obj):
+        for obj in scoring_result.objects or []:
             if isinstance(obj, dict) and "id" in obj and "score" in obj:
                 try:
                     score = int(obj["score"])
@@ -773,21 +794,6 @@ class GraphRag:
                     score = 0
                 scored_edges.append({"id": obj["id"], "score": score})
 
-        if isinstance(scoring_response, list):
-            for obj in scoring_response:
-                parse_scored_edge(obj)
-        elif isinstance(scoring_response, str):
-            for line in scoring_response.strip().split('\n'):
-                line = line.strip()
-                if not line:
-                    continue
-                try:
-                    parse_scored_edge(json.loads(line))
-                except json.JSONDecodeError:
-                    logger.warning(
-                        f"Failed to parse edge scoring line: {line}"
-                    )
-
         # Select top N edges by score
         scored_edges.sort(key=lambda x: x["score"], reverse=True)
         top_edges = scored_edges[:edge_limit]
@@ -821,25 +827,30 @@ class GraphRag:
         ]
 
         # Run reasoning and document tracing concurrently
-        reasoning_task = self.prompt_client.prompt(
-            "kg-edge-reasoning",
-            variables={
-                "query": query,
-                "knowledge": selected_edges_with_ids
-            }
-        )
+        async def _get_reasoning():
+            result = await self.prompt_client.prompt(
+                "kg-edge-reasoning",
+                variables={
+                    "query": query,
+                    "knowledge": selected_edges_with_ids
+                }
+            )
+            track_usage(result)
+            return result
+
+        reasoning_task = _get_reasoning()
         doc_trace_task = q.trace_source_documents(selected_edge_uris)
 
-        reasoning_response, source_documents = await asyncio.gather(
+        reasoning_result, source_documents = await asyncio.gather(
             reasoning_task, doc_trace_task, return_exceptions=True
         )
 
         # Handle exceptions from gather
-        if isinstance(reasoning_response, Exception):
+        if isinstance(reasoning_result, Exception):
             logger.warning(
-                f"Edge reasoning failed: {reasoning_response}"
+                f"Edge reasoning failed: {reasoning_result}"
             )
-            reasoning_response = ""
+            reasoning_result = None
         if isinstance(source_documents, Exception):
             logger.warning(
                 f"Document tracing failed: {source_documents}"
@@ -848,29 +859,15 @@ class GraphRag:
 
 
         if self.verbose:
-            logger.debug(f"Edge reasoning response: {reasoning_response}")
+            logger.debug(f"Edge reasoning result: {reasoning_result}")
 
-        # Parse reasoning response and build explainability data
+        # Parse reasoning response (jsonl) and build explainability data
         reasoning_map = {}
 
-        def parse_reasoning(obj):
-            if isinstance(obj, dict) and "id" in obj:
-                reasoning_map[obj["id"]] = obj.get("reasoning", "")
-
-        if isinstance(reasoning_response, list):
-            for obj in reasoning_response:
-                parse_reasoning(obj)
-        elif isinstance(reasoning_response, str):
-            for line in reasoning_response.strip().split('\n'):
-                line = line.strip()
-                if not line:
-                    continue
-                try:
-                    parse_reasoning(json.loads(line))
-                except json.JSONDecodeError:
-                    logger.warning(
-                        f"Failed to parse edge reasoning line: {line}"
-                    )
+        if reasoning_result is not None:
+            for obj in reasoning_result.objects or []:
+                if isinstance(obj, dict) and "id" in obj:
+                    reasoning_map[obj["id"]] = obj.get("reasoning", "")
 
         selected_edges_with_reasoning = []
         for eid in selected_ids:
@@ -919,19 +916,22 @@ class GraphRag:
                 accumulated_chunks.append(chunk)
                 await chunk_callback(chunk, end_of_stream)
 
-            await self.prompt_client.prompt(
+            synthesis_result = await self.prompt_client.prompt(
                 "kg-synthesis",
                 variables=synthesis_variables,
                 streaming=True,
                 chunk_callback=accumulating_callback
            )
+            track_usage(synthesis_result)
             # Combine all chunks into full response
             resp = "".join(accumulated_chunks)
         else:
-            resp = await self.prompt_client.prompt(
+            synthesis_result = await self.prompt_client.prompt(
                 "kg-synthesis",
                 variables=synthesis_variables,
             )
+            track_usage(synthesis_result)
+            resp = synthesis_result.text
 
         if self.verbose:
             logger.debug("Query processing complete")
@@ -964,5 +964,11 @@ class GraphRag:
             if self.verbose:
                 logger.debug(f"Emitted explain for session {session_id}")
 
-        return resp
+        usage = {
+            "in_token": total_in if total_in > 0 else None,
+            "out_token": total_out if total_out > 0 else None,
+            "model": last_model,
+        }
 
+        return resp, usage

@@ -332,7 +332,7 @@ class Processor(FlowProcessor):
             )
 
             # Query with streaming and real-time explain
-            response = await rag.query(
+            response, usage = await rag.query(
                 query = v.query, user = v.user, collection = v.collection,
                 entity_limit = entity_limit, triple_limit = triple_limit,
                 max_subgraph_size = max_subgraph_size,
@@ -348,7 +348,7 @@ class Processor(FlowProcessor):
 
         else:
             # Non-streaming path with real-time explain
-            response = await rag.query(
+            response, usage = await rag.query(
                 query = v.query, user = v.user, collection = v.collection,
                 entity_limit = entity_limit, triple_limit = triple_limit,
                 max_subgraph_size = max_subgraph_size,
@@ -360,23 +360,30 @@ class Processor(FlowProcessor):
                 parent_uri = v.parent_uri,
             )
 
-            # Send chunk with response
+            # Send single response with answer and token usage
             await flow("response").send(
                 GraphRagResponse(
                     message_type="chunk",
                     response=response,
                     end_of_stream=True,
                     error=None,
+                    end_of_session=True,
+                    in_token=usage.get("in_token"),
+                    out_token=usage.get("out_token"),
+                    model=usage.get("model"),
                 ),
                 properties={"id": id}
             )
             return
 
-        # Send final message to close session
+        # Streaming: send final message to close session with token usage
         await flow("response").send(
             GraphRagResponse(
                 message_type="chunk",
                 response="",
                 end_of_session=True,
+                in_token=usage.get("in_token"),
+                out_token=usage.get("out_token"),
+                model=usage.get("model"),
             ),
             properties={"id": id}
         )