Mirror of https://github.com/trustgraph-ai/trustgraph.git, synced 2026-04-30 02:46:23 +02:00
Feature/streaming llm phase 1 (#566)
* Tidy up duplicate tech specs in doc directory
* Streaming LLM text-completion service tech spec.
* text-completion and prompt interfaces
* streaming change applied to all LLMs, so far tested with VertexAI
* Skip Pinecone unit tests, upstream module issue is affecting things, tests are passing again
* Added agent streaming, not working and has broken tests
parent 943a9d83b0
commit 310a2deb06
44 changed files with 2684 additions and 937 deletions
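The commit message says the streaming change is applied to all LLMs, but this diff only shows the VertexAI provider; how a caller chooses between the streaming and non-streaming paths is not part of the diff. The sketch below is a hypothetical illustration using only the method names that appear in the diff (supports_streaming, generate_content_stream, and the pre-existing generate_content), not the actual trustgraph request handler.

    # Hypothetical dispatch sketch -- not the actual trustgraph handler.
    # 'llm' is any provider processor; method names are taken from the diff below.
    async def handle_completion(llm, system: str, prompt: str, emit):
        if llm.supports_streaming():
            # Streaming path: forward each partial chunk as it arrives.
            async for chunk in llm.generate_content_stream(system, prompt):
                await emit(chunk)
        else:
            # Fallback: single non-streaming response.
            result = await llm.generate_content(system, prompt)
            await emit(result)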
@@ -32,7 +32,7 @@ from vertexai.generative_models import (
 from anthropic import AnthropicVertex, RateLimitError

 from .... exceptions import TooManyRequests
-from .... base import LlmService, LlmResult
+from .... base import LlmService, LlmResult, LlmChunk

 # Module logger
 logger = logging.getLogger(__name__)
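The only interface change in the imports is the new LlmChunk type. Its definition is not included in this diff; judging from how its fields are used in the hunk below, it is presumably something like the following dataclass sketch (field names taken from the diff, everything else assumed).

    # Assumed shape of LlmChunk, inferred from its use in the diff below.
    # The real definition lives in the package's base module and may differ.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class LlmChunk:
        text: str                  # incremental text, empty on the final chunk
        in_token: Optional[int]    # input token count, only set when is_final is True
        out_token: Optional[int]   # output token count, only set when is_final is True
        model: str                 # model name that produced the chunk
        is_final: bool             # True exactly once, on the terminating chunk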
@@ -239,6 +239,123 @@ class Processor(LlmService):
             logger.error(f"VertexAI LLM exception: {e}", exc_info=True)
             raise e

+    def supports_streaming(self):
+        """VertexAI supports streaming for both Gemini and Claude models"""
+        return True
+
+    async def generate_content_stream(self, system, prompt, model=None, temperature=None):
+        """
+        Stream content generation from VertexAI (Gemini or Claude).
+        Yields LlmChunk objects with is_final=True on the last chunk.
+        """
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+        # Use provided temperature or fall back to default
+        effective_temperature = temperature if temperature is not None else self.temperature
+
+        logger.debug(f"Using model (streaming): {model_name}")
+        logger.debug(f"Using temperature: {effective_temperature}")
+
+        try:
+            if 'claude' in model_name.lower():
+                # Claude/Anthropic streaming
+                logger.debug(f"Streaming request to Anthropic model '{model_name}'...")
+                client = self._get_anthropic_client()
+
+                total_in_tokens = 0
+                total_out_tokens = 0
+
+                with client.messages.stream(
+                    model=model_name,
+                    system=system,
+                    messages=[{"role": "user", "content": prompt}],
+                    max_tokens=self.api_params['max_output_tokens'],
+                    temperature=effective_temperature,
+                    top_p=self.api_params['top_p'],
+                    top_k=self.api_params['top_k'],
+                ) as stream:
+                    # Stream text chunks
+                    for text in stream.text_stream:
+                        yield LlmChunk(
+                            text=text,
+                            in_token=None,
+                            out_token=None,
+                            model=model_name,
+                            is_final=False
+                        )
+
+                    # Get final message with token counts
+                    final_message = stream.get_final_message()
+                    total_in_tokens = final_message.usage.input_tokens
+                    total_out_tokens = final_message.usage.output_tokens
+
+                # Send final chunk with token counts
+                yield LlmChunk(
+                    text="",
+                    in_token=total_in_tokens,
+                    out_token=total_out_tokens,
+                    model=model_name,
+                    is_final=True
+                )
+
+                logger.info(f"Input Tokens: {total_in_tokens}")
+                logger.info(f"Output Tokens: {total_out_tokens}")
+
+            else:
+                # Gemini streaming
+                logger.debug(f"Streaming request to Gemini model '{model_name}'...")
+                full_prompt = system + "\n\n" + prompt
+
+                llm, generation_config = self._get_gemini_model(model_name, effective_temperature)
+
+                response = llm.generate_content(
+                    full_prompt,
+                    generation_config=generation_config,
+                    safety_settings=self.safety_settings,
+                    stream=True  # Enable streaming
+                )
+
+                total_in_tokens = 0
+                total_out_tokens = 0
+
+                # Stream chunks
+                for chunk in response:
+                    if chunk.text:
+                        yield LlmChunk(
+                            text=chunk.text,
+                            in_token=None,
+                            out_token=None,
+                            model=model_name,
+                            is_final=False
+                        )
+
+                    # Accumulate token counts if available
+                    if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
+                        if hasattr(chunk.usage_metadata, 'prompt_token_count'):
+                            total_in_tokens = chunk.usage_metadata.prompt_token_count
+                        if hasattr(chunk.usage_metadata, 'candidates_token_count'):
+                            total_out_tokens = chunk.usage_metadata.candidates_token_count
+
+                # Send final chunk with token counts
+                yield LlmChunk(
+                    text="",
+                    in_token=total_in_tokens,
+                    out_token=total_out_tokens,
+                    model=model_name,
+                    is_final=True
+                )
+
+                logger.info(f"Input Tokens: {total_in_tokens}")
+                logger.info(f"Output Tokens: {total_out_tokens}")
+
+        except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e:
+            logger.warning(f"Hit rate limit during streaming: {e}")
+            raise TooManyRequests()
+
+        except Exception as e:
+            logger.error(f"VertexAI streaming exception: {e}", exc_info=True)
+            raise e
+
     @staticmethod
     def add_args(parser):

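For completeness, this is roughly how a consumer of the new generator might reassemble a full completion, assuming the chunk contract above (text on intermediate chunks, token counts on the final one). It is an illustrative sketch, not code from this commit.

    # Illustrative consumer sketch, not part of this commit.
    async def collect_stream(processor, system: str, prompt: str):
        parts = []
        in_tokens = out_tokens = None
        async for chunk in processor.generate_content_stream(system, prompt):
            if chunk.text:
                parts.append(chunk.text)        # accumulate partial text
            if chunk.is_final:
                in_tokens = chunk.in_token      # token counts arrive on the last chunk
                out_tokens = chunk.out_token
        return "".join(parts), in_tokens, out_tokens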