Expose LLM token usage across all service layers (#782)

Expose LLM token usage (in_token, out_token, model) across all
service layers

Propagate token counts from LLM services through the prompt,
text-completion, graph-RAG, document-RAG, and agent orchestrator
pipelines to the API gateway and Python SDK. All fields are Optional
— None means "not available", which distinguishes a missing count
from a genuine zero.

Key changes:

- Schema: Add in_token/out_token/model to TextCompletionResponse,
  PromptResponse, GraphRagResponse, DocumentRagResponse,
  AgentResponse

- TextCompletionClient: New TextCompletionResult return type. Split
  into text_completion() (non-streaming) and
  text_completion_stream() (streaming with per-chunk handler
  callback)

- PromptClient: New PromptResult with response_type
  (text/json/jsonl), typed fields (text/object/objects), and token
  usage. All callers updated.

- RAG services: Accumulate token usage across all prompt calls
  (extract-concepts, edge-scoring, edge-reasoning,
  synthesis). Non-streaming path sends single combined response
  instead of chunk + end_of_session.

- Agent orchestrator: UsageTracker accumulates tokens across
  meta-router, pattern prompt calls, and react reasoning. Attached
  to end_of_dialog.

- Translators: Encode token fields whenever they are not None (an
  `is not None` check rather than truthiness, so a genuine zero count
  is still encoded)

- Python SDK: RAG and text-completion methods return
  TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with
  token fields (streaming)

- CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt,
  tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent
This commit is contained in:
cybermaggedon 2026-04-13 14:38:34 +01:00 committed by GitHub
parent 67cfa80836
commit 14e49d83c7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
60 changed files with 1252 additions and 577 deletions

View file

@ -18,8 +18,10 @@ from . librarian_client import LibrarianClient
from . chunking_service import ChunkingService
from . embeddings_service import EmbeddingsService
from . embeddings_client import EmbeddingsClientSpec
from . text_completion_client import TextCompletionClientSpec
from . prompt_client import PromptClientSpec
from . text_completion_client import (
TextCompletionClientSpec, TextCompletionClient, TextCompletionResult,
)
from . prompt_client import PromptClientSpec, PromptClient, PromptResult
from . triples_store_service import TriplesStoreService
from . graph_embeddings_store_service import GraphEmbeddingsStoreService
from . document_embeddings_store_service import DocumentEmbeddingsStoreService

View file

@ -1,10 +1,22 @@
import json
import asyncio
from dataclasses import dataclass
from typing import Optional, Any
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import PromptRequest, PromptResponse
@dataclass
class PromptResult:
    """
    Result of a prompt invocation, pairing the decoded payload with
    LLM token-usage metadata.

    Exactly one payload field is populated, selected by `response_type`:
      - "text"  -> `text`
      - "json"  -> `object`
      - "jsonl" -> `objects`

    The usage fields (`in_token`, `out_token`, `model`) are None when
    the service reported no usage information — distinct from a
    genuine zero count.
    """
    response_type: str
    text: Optional[str] = None
    # NOTE: this field name intentionally shadows the builtin `object`;
    # it is part of the public interface and cannot be renamed.
    object: Any = None
    objects: Optional[list] = None
    in_token: Optional[int] = None
    out_token: Optional[int] = None
    model: Optional[str] = None
class PromptClient(RequestResponse):
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
@ -26,17 +38,40 @@ class PromptClient(RequestResponse):
if resp.error:
raise RuntimeError(resp.error.message)
if resp.text: return resp.text
if resp.text:
return PromptResult(
response_type="text",
text=resp.text,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
return json.loads(resp.object)
parsed = json.loads(resp.object)
if isinstance(parsed, list):
return PromptResult(
response_type="jsonl",
objects=parsed,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
return PromptResult(
response_type="json",
object=parsed,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
else:
last_text = ""
last_object = None
last_resp = None
async def forward_chunks(resp):
nonlocal last_text, last_object
nonlocal last_resp
if resp.error:
raise RuntimeError(resp.error.message)
@ -44,14 +79,13 @@ class PromptClient(RequestResponse):
end_stream = getattr(resp, 'end_of_stream', False)
if resp.text is not None:
last_text = resp.text
if chunk_callback:
if asyncio.iscoroutinefunction(chunk_callback):
await chunk_callback(resp.text, end_stream)
else:
chunk_callback(resp.text, end_stream)
elif resp.object:
last_object = resp.object
last_resp = resp
return end_stream
@ -70,10 +104,36 @@ class PromptClient(RequestResponse):
timeout=timeout
)
if last_text:
return last_text
if last_resp is None:
return PromptResult(response_type="text")
return json.loads(last_object) if last_object else None
if last_resp.object:
parsed = json.loads(last_resp.object)
if isinstance(parsed, list):
return PromptResult(
response_type="jsonl",
objects=parsed,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
return PromptResult(
response_type="json",
object=parsed,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
return PromptResult(
response_type="text",
text=last_resp.text,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
async def extract_definitions(self, text, timeout=600):
return await self.prompt(
@ -152,4 +212,3 @@ class PromptClientSpec(RequestResponseSpec):
response_schema = PromptResponse,
impl = PromptClient,
)

View file

@ -1,47 +1,71 @@
from dataclasses import dataclass
from typing import Optional
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import TextCompletionRequest, TextCompletionResponse
@dataclass
class TextCompletionResult:
    """
    Outcome of a text-completion call.

    `text` carries the full completion for non-streaming calls; for
    streaming calls it is None, since the text was delivered
    chunk-by-chunk through a handler. Token counts and model name are
    None when the service reported no usage — distinct from an actual
    zero count.
    """
    text: Optional[str]
    in_token: Optional[int] = None
    out_token: Optional[int] = None
    model: Optional[str] = None
class TextCompletionClient(RequestResponse):
async def text_completion(self, system, prompt, streaming=False, timeout=600):
# If not streaming, use original behavior
if not streaming:
resp = await self.request(
TextCompletionRequest(
system = system, prompt = prompt, streaming = False
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
async def text_completion(self, system, prompt, timeout=600):
return resp.response
# For streaming: collect all chunks and return complete response
full_response = ""
async def collect_chunks(resp):
nonlocal full_response
if resp.error:
raise RuntimeError(resp.error.message)
if resp.response:
full_response += resp.response
# Return True when end_of_stream is reached
return getattr(resp, 'end_of_stream', False)
await self.request(
resp = await self.request(
TextCompletionRequest(
system = system, prompt = prompt, streaming = True
system = system, prompt = prompt, streaming = False
),
recipient=collect_chunks,
timeout=timeout
)
return full_response
if resp.error:
raise RuntimeError(resp.error.message)
return TextCompletionResult(
text = resp.response,
in_token = resp.in_token,
out_token = resp.out_token,
model = resp.model,
)
async def text_completion_stream(
    self, system, prompt, handler, timeout=600,
):
    """
    Run a streaming text completion.

    Every incoming TextCompletionResponse chunk — including the final
    end_of_stream message — is awaited through the async callable
    `handler`. A chunk carrying an error aborts the stream with
    RuntimeError. Returns a TextCompletionResult whose text is None
    (the content was already delivered per-chunk) and whose token
    counts / model come from the end_of_stream message.
    """
    async def relay(chunk):
        # Surface service-side errors immediately.
        if chunk.error:
            raise RuntimeError(chunk.error.message)
        await handler(chunk)
        # Returning True signals request() that the stream is done.
        return getattr(chunk, "end_of_stream", False)

    req = TextCompletionRequest(
        system=system, prompt=prompt, streaming=True,
    )
    last = await self.request(req, recipient=relay, timeout=timeout)

    return TextCompletionResult(
        text=None,
        in_token=last.in_token,
        out_token=last.out_token,
        model=last.model,
    )
class TextCompletionClientSpec(RequestResponseSpec):
def __init__(
@ -54,4 +78,3 @@ class TextCompletionClientSpec(RequestResponseSpec):
response_schema = TextCompletionResponse,
impl = TextCompletionClient,
)