diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index 24b6c1f0..9511c44d 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -18,7 +18,9 @@ from . librarian_client import LibrarianClient from . chunking_service import ChunkingService from . embeddings_service import EmbeddingsService from . embeddings_client import EmbeddingsClientSpec -from . text_completion_client import TextCompletionClientSpec +from . text_completion_client import ( + TextCompletionClientSpec, TextCompletionClient, TextCompletionResult, +) from . prompt_client import PromptClientSpec from . triples_store_service import TriplesStoreService from . graph_embeddings_store_service import GraphEmbeddingsStoreService diff --git a/trustgraph-base/trustgraph/base/text_completion_client.py b/trustgraph-base/trustgraph/base/text_completion_client.py index ae93e22e..0a1358dc 100644 --- a/trustgraph-base/trustgraph/base/text_completion_client.py +++ b/trustgraph-base/trustgraph/base/text_completion_client.py @@ -1,47 +1,71 @@ +from dataclasses import dataclass +from typing import Optional + from . request_response_spec import RequestResponse, RequestResponseSpec from .. schema import TextCompletionRequest, TextCompletionResponse +@dataclass +class TextCompletionResult: + text: Optional[str] + in_token: int = 0 + out_token: int = 0 + model: str = "" + class TextCompletionClient(RequestResponse): - async def text_completion(self, system, prompt, streaming=False, timeout=600): - # If not streaming, use original behavior - if not streaming: - resp = await self.request( - TextCompletionRequest( - system = system, prompt = prompt, streaming = False - ), - timeout=timeout - ) - if resp.error: - raise RuntimeError(resp.error.message) + async def text_completion(self, system, prompt, timeout=600): - return resp.response - - # For streaming: collect all chunks and return complete response - full_response = "" - - async def collect_chunks(resp): - nonlocal full_response - - if resp.error: - raise RuntimeError(resp.error.message) - - if resp.response: - full_response += resp.response - - # Return True when end_of_stream is reached - return getattr(resp, 'end_of_stream', False) - - await self.request( + resp = await self.request( TextCompletionRequest( - system = system, prompt = prompt, streaming = True + system = system, prompt = prompt, streaming = False ), - recipient=collect_chunks, timeout=timeout ) - return full_response + if resp.error: + raise RuntimeError(resp.error.message) + + return TextCompletionResult( + text = resp.response, + in_token = getattr(resp, "in_token", 0) or 0, + out_token = getattr(resp, "out_token", 0) or 0, + model = getattr(resp, "model", "") or "", + ) + + async def text_completion_stream( + self, system, prompt, handler, timeout=600, + ): + """ + Streaming text completion. `handler` is an async callable invoked + once per chunk with the chunk's TextCompletionResponse. Returns a + TextCompletionResult with text=None and token counts / model taken + from the end_of_stream message. + """ + + async def on_chunk(resp): + + if resp.error: + raise RuntimeError(resp.error.message) + + await handler(resp) + + return getattr(resp, "end_of_stream", False) + + final = await self.request( + TextCompletionRequest( + system = system, prompt = prompt, streaming = True + ), + recipient=on_chunk, + timeout=timeout, + ) + + return TextCompletionResult( + text = None, + in_token = getattr(final, "in_token", 0) or 0, + out_token = getattr(final, "out_token", 0) or 0, + model = getattr(final, "model", "") or "", + ) class TextCompletionClientSpec(RequestResponseSpec): def __init__( @@ -54,4 +78,3 @@ class TextCompletionClientSpec(RequestResponseSpec): response_schema = TextCompletionResponse, impl = TextCompletionClient, ) - diff --git a/trustgraph-flow/trustgraph/prompt/template/service.py b/trustgraph-flow/trustgraph/prompt/template/service.py index 97298e13..c9c0c87b 100755 --- a/trustgraph-flow/trustgraph/prompt/template/service.py +++ b/trustgraph-flow/trustgraph/prompt/template/service.py @@ -11,7 +11,6 @@ import logging from ...schema import Definition, Relationship, Triple from ...schema import Topic from ...schema import PromptRequest, PromptResponse, Error -from ...schema import TextCompletionRequest, TextCompletionResponse from ...base import FlowProcessor from ...base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec @@ -124,13 +123,7 @@ class Processor(FlowProcessor): logger.debug(f"System prompt: {system}") logger.debug(f"User prompt: {prompt}") - # Use the text completion client with recipient handler - client = flow("text-completion-request") - async def forward_chunks(resp): - if resp.error: - raise RuntimeError(resp.error.message) - is_final = getattr(resp, 'end_of_stream', False) # Always send a message if there's content OR if it's the final message @@ -144,15 +137,10 @@ class Processor(FlowProcessor): ) await flow("response").send(r, properties={"id": id}) - # Return True when end_of_stream - return is_final - - await client.request( - TextCompletionRequest( - system=system, prompt=prompt, streaming=True - ), - recipient=forward_chunks, - timeout=600 + await flow("text-completion-request").text_completion_stream( + system=system, prompt=prompt, + handler=forward_chunks, + timeout=600, ) # Return empty string since we already sent all chunks @@ -172,12 +160,11 @@ class Processor(FlowProcessor): logger.debug(f"System prompt: {system}") logger.debug(f"User prompt: {prompt}") - resp = await flow("text-completion-request").text_completion( - system = system, prompt = prompt, streaming = False, - ) - try: - return resp + result = await flow("text-completion-request").text_completion( + system = system, prompt = prompt, + ) + return result.text except Exception as e: logger.error(f"LLM Exception: {e}", exc_info=True) return None