mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-12 08:15:14 +02:00
trustgraph-base/trustgraph/base/text_completion_client.py
- New TextCompletionResult dataclass: text: Optional[str], in_token, out_token, model. - text_completion(system, prompt, timeout=600) — non-streaming only. Returns TextCompletionResult with text set and tokens/model populated from the response. - New text_completion_stream(system, prompt, handler, timeout=600) — streaming. Invokes handler(chunk) with each TextCompletionResponse as it arrives (including the final one, so the caller can see end_of_stream). Raises on resp.error. Returns a TextCompletionResult with text=None and in_token/out_token/model pulled from the final message. - Old streaming= kwarg on text_completion() is gone. trustgraph-base/trustgraph/base/__init__.py - Export TextCompletionClient and TextCompletionResult alongside the spec. trustgraph-flow/trustgraph/prompt/template/service.py - Non-streaming path now uses await ...text_completion(system=..., prompt=...) and reads result.text. The llm() callback still returns a plain string to PromptManager, preserving its contract. - Streaming path collapsed onto text_completion_stream(..., handler=forward_chunks). Removed the hand-rolled client.request(...) + TextCompletionRequest plumbing. The is_final / resp.error behavior is preserved (forward_chunks still checks end_of_stream; errors are now raised centrally by the client). - Dropped the now-unused TextCompletionRequest, TextCompletionResponse import. Token counts are now available to any TextCompletionClient caller - result.in_token / result.out_token / result.model — in both modes. Nothing in the prompt service consumes them yet; that's the next step on this branch.
This commit is contained in:
parent
ffe310af7c
commit
4930bc4d2b
3 changed files with 67 additions and 55 deletions
|
|
@ -18,7 +18,9 @@ from . librarian_client import LibrarianClient
|
|||
from . chunking_service import ChunkingService
|
||||
from . embeddings_service import EmbeddingsService
|
||||
from . embeddings_client import EmbeddingsClientSpec
|
||||
from . text_completion_client import TextCompletionClientSpec
|
||||
from . text_completion_client import (
|
||||
TextCompletionClientSpec, TextCompletionClient, TextCompletionResult,
|
||||
)
|
||||
from . prompt_client import PromptClientSpec
|
||||
from . triples_store_service import TriplesStoreService
|
||||
from . graph_embeddings_store_service import GraphEmbeddingsStoreService
|
||||
|
|
|
|||
|
|
@ -1,47 +1,71 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import TextCompletionRequest, TextCompletionResponse
|
||||
|
||||
@dataclass
|
||||
class TextCompletionResult:
|
||||
text: Optional[str]
|
||||
in_token: int = 0
|
||||
out_token: int = 0
|
||||
model: str = ""
|
||||
|
||||
class TextCompletionClient(RequestResponse):
|
||||
async def text_completion(self, system, prompt, streaming=False, timeout=600):
|
||||
# If not streaming, use original behavior
|
||||
if not streaming:
|
||||
resp = await self.request(
|
||||
TextCompletionRequest(
|
||||
system = system, prompt = prompt, streaming = False
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
async def text_completion(self, system, prompt, timeout=600):
|
||||
|
||||
return resp.response
|
||||
|
||||
# For streaming: collect all chunks and return complete response
|
||||
full_response = ""
|
||||
|
||||
async def collect_chunks(resp):
|
||||
nonlocal full_response
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
if resp.response:
|
||||
full_response += resp.response
|
||||
|
||||
# Return True when end_of_stream is reached
|
||||
return getattr(resp, 'end_of_stream', False)
|
||||
|
||||
await self.request(
|
||||
resp = await self.request(
|
||||
TextCompletionRequest(
|
||||
system = system, prompt = prompt, streaming = True
|
||||
system = system, prompt = prompt, streaming = False
|
||||
),
|
||||
recipient=collect_chunks,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
return full_response
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return TextCompletionResult(
|
||||
text = resp.response,
|
||||
in_token = getattr(resp, "in_token", 0) or 0,
|
||||
out_token = getattr(resp, "out_token", 0) or 0,
|
||||
model = getattr(resp, "model", "") or "",
|
||||
)
|
||||
|
||||
async def text_completion_stream(
|
||||
self, system, prompt, handler, timeout=600,
|
||||
):
|
||||
"""
|
||||
Streaming text completion. `handler` is an async callable invoked
|
||||
once per chunk with the chunk's TextCompletionResponse. Returns a
|
||||
TextCompletionResult with text=None and token counts / model taken
|
||||
from the end_of_stream message.
|
||||
"""
|
||||
|
||||
async def on_chunk(resp):
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
await handler(resp)
|
||||
|
||||
return getattr(resp, "end_of_stream", False)
|
||||
|
||||
final = await self.request(
|
||||
TextCompletionRequest(
|
||||
system = system, prompt = prompt, streaming = True
|
||||
),
|
||||
recipient=on_chunk,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
return TextCompletionResult(
|
||||
text = None,
|
||||
in_token = getattr(final, "in_token", 0) or 0,
|
||||
out_token = getattr(final, "out_token", 0) or 0,
|
||||
model = getattr(final, "model", "") or "",
|
||||
)
|
||||
|
||||
class TextCompletionClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
@ -54,4 +78,3 @@ class TextCompletionClientSpec(RequestResponseSpec):
|
|||
response_schema = TextCompletionResponse,
|
||||
impl = TextCompletionClient,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import logging
|
|||
from ...schema import Definition, Relationship, Triple
|
||||
from ...schema import Topic
|
||||
from ...schema import PromptRequest, PromptResponse, Error
|
||||
from ...schema import TextCompletionRequest, TextCompletionResponse
|
||||
|
||||
from ...base import FlowProcessor
|
||||
from ...base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
|
||||
|
|
@ -124,13 +123,7 @@ class Processor(FlowProcessor):
|
|||
logger.debug(f"System prompt: {system}")
|
||||
logger.debug(f"User prompt: {prompt}")
|
||||
|
||||
# Use the text completion client with recipient handler
|
||||
client = flow("text-completion-request")
|
||||
|
||||
async def forward_chunks(resp):
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
is_final = getattr(resp, 'end_of_stream', False)
|
||||
|
||||
# Always send a message if there's content OR if it's the final message
|
||||
|
|
@ -144,15 +137,10 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
||||
# Return True when end_of_stream
|
||||
return is_final
|
||||
|
||||
await client.request(
|
||||
TextCompletionRequest(
|
||||
system=system, prompt=prompt, streaming=True
|
||||
),
|
||||
recipient=forward_chunks,
|
||||
timeout=600
|
||||
await flow("text-completion-request").text_completion_stream(
|
||||
system=system, prompt=prompt,
|
||||
handler=forward_chunks,
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
# Return empty string since we already sent all chunks
|
||||
|
|
@ -172,12 +160,11 @@ class Processor(FlowProcessor):
|
|||
logger.debug(f"System prompt: {system}")
|
||||
logger.debug(f"User prompt: {prompt}")
|
||||
|
||||
resp = await flow("text-completion-request").text_completion(
|
||||
system = system, prompt = prompt, streaming = False,
|
||||
)
|
||||
|
||||
try:
|
||||
return resp
|
||||
result = await flow("text-completion-request").text_completion(
|
||||
system = system, prompt = prompt,
|
||||
)
|
||||
return result.text
|
||||
except Exception as e:
|
||||
logger.error(f"LLM Exception: {e}", exc_info=True)
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue