trustgraph-base/trustgraph/base/text_completion_client.py

- New TextCompletionResult dataclass: text: Optional[str],
  in_token, out_token, model.
- text_completion(system, prompt, timeout=600) — non-streaming
  only. Returns TextCompletionResult with text set and
  tokens/model populated from the response.
- New text_completion_stream(system, prompt, handler,
  timeout=600) — streaming. Invokes handler(chunk) with each
  TextCompletionResponse as it arrives (including the final one,
  so the caller can see end_of_stream). Raises on
  resp.error. Returns a TextCompletionResult with text=None and
  in_token/out_token/model pulled from the final message.
- Old streaming= kwarg on text_completion() is gone.

trustgraph-base/trustgraph/base/__init__.py
- Export TextCompletionClient and TextCompletionResult alongside
  the spec.

trustgraph-flow/trustgraph/prompt/template/service.py
- Non-streaming path now uses await
  ...text_completion(system=..., prompt=...) and reads
  result.text. The llm() callback still returns a plain string to
  PromptManager, preserving its contract.
- Streaming path collapsed onto text_completion_stream(...,
  handler=forward_chunks). Removed the hand-rolled
  client.request(...) + TextCompletionRequest plumbing. The
  is_final / resp.error behavior is preserved (forward_chunks
  still checks end_of_stream; errors are now raised centrally by
  the client).
- Dropped the now-unused TextCompletionRequest and
  TextCompletionResponse imports.

Token counts are now available to any TextCompletionClient caller
via result.in_token / result.out_token / result.model, in both
modes. Nothing in the prompt service consumes them yet; that's
the next step on this branch.
This commit is contained in:
Cyber MacGeddon 2026-04-11 20:22:35 +01:00
parent ffe310af7c
commit 4930bc4d2b
3 changed files with 67 additions and 55 deletions

View file

@ -18,7 +18,9 @@ from . librarian_client import LibrarianClient
from . chunking_service import ChunkingService
from . embeddings_service import EmbeddingsService
from . embeddings_client import EmbeddingsClientSpec
from . text_completion_client import TextCompletionClientSpec
from . text_completion_client import (
TextCompletionClientSpec, TextCompletionClient, TextCompletionResult,
)
from . prompt_client import PromptClientSpec
from . triples_store_service import TriplesStoreService
from . graph_embeddings_store_service import GraphEmbeddingsStoreService

View file

@ -1,47 +1,71 @@
from dataclasses import dataclass
from typing import Optional
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import TextCompletionRequest, TextCompletionResponse
@dataclass
class TextCompletionResult:
    """Result record returned by the TextCompletionClient methods.

    Carries the completion text (when there is one) together with the
    token counts and model value copied from the service response.
    """

    # Completion text. The streaming client method sets this to None,
    # because the chunks are delivered to the caller's handler instead.
    text: Optional[str]

    # Token counts copied from the response's in_token / out_token
    # fields (presumably prompt and completion tokens -- confirm against
    # the TextCompletionResponse schema). 0 when the response has none.
    in_token: int = 0
    out_token: int = 0

    # Value of the response's model field; "" when the response has none.
    model: str = ""
class TextCompletionClient(RequestResponse):
    """Request/response client for the text-completion service.

    Offers a non-streaming call (`text_completion`) and a streaming call
    (`text_completion_stream`). Both return a TextCompletionResult whose
    token counts and model are copied from the service response.
    """

    async def text_completion(self, system, prompt, timeout=600):
        """
        Non-streaming text completion.

        Sends a single TextCompletionRequest with streaming disabled and
        waits for the response.

        Args:
            system: System prompt placed in the request.
            prompt: User prompt placed in the request.
            timeout: Passed through to self.request().

        Returns:
            A TextCompletionResult with `text` set to the response text
            and in_token / out_token / model taken from the response.

        Raises:
            RuntimeError: If the response carries an error; the message
                is the error's `message`.
        """

        resp = await self.request(
            TextCompletionRequest(
                system = system, prompt = prompt, streaming = False
            ),
            timeout=timeout
        )

        if resp.error:
            raise RuntimeError(resp.error.message)

        # getattr() covers responses lacking the field altogether; the
        # trailing `or` normalises None / empty values to the defaults.
        return TextCompletionResult(
            text = resp.response,
            in_token = getattr(resp, "in_token", 0) or 0,
            out_token = getattr(resp, "out_token", 0) or 0,
            model = getattr(resp, "model", "") or "",
        )

    async def text_completion_stream(
            self, system, prompt, handler, timeout=600,
    ):
        """
        Streaming text completion.

        Args:
            system: System prompt placed in the request.
            prompt: User prompt placed in the request.
            handler: Async callable awaited once per chunk with that
                chunk's TextCompletionResponse, including the chunk
                flagged end_of_stream.
            timeout: Passed through to self.request().

        Returns:
            A TextCompletionResult with text=None and token counts /
            model read from the value self.request() returns, which is
            expected to be the end_of_stream message.

        Raises:
            RuntimeError: If any chunk carries an error. It is raised
                before the handler sees that chunk.
        """

        async def on_chunk(resp):

            if resp.error:
                raise RuntimeError(resp.error.message)

            await handler(resp)

            # A true return value marks this chunk as the last one.
            return getattr(resp, "end_of_stream", False)

        final = await self.request(
            TextCompletionRequest(
                system = system, prompt = prompt, streaming = True
            ),
            recipient=on_chunk,
            timeout=timeout,
        )

        # Same defensive field access as the non-streaming path.
        return TextCompletionResult(
            text = None,
            in_token = getattr(final, "in_token", 0) or 0,
            out_token = getattr(final, "out_token", 0) or 0,
            model = getattr(final, "model", "") or "",
        )
class TextCompletionClientSpec(RequestResponseSpec):
def __init__(
@ -54,4 +78,3 @@ class TextCompletionClientSpec(RequestResponseSpec):
response_schema = TextCompletionResponse,
impl = TextCompletionClient,
)