Mirror of https://github.com/trustgraph-ai/trustgraph.git, synced 2026-04-25 00:16:23 +02:00
Feature/streaming llm phase 1 (#566)
* Tidy up duplicate tech specs in doc directory
* Streaming LLM text-completion service tech spec
* text-completion and prompt interfaces
* Streaming change applied to all LLMs; so far tested with VertexAI
* Skip Pinecone unit tests: an upstream module issue was affecting things; tests are passing again
* Added agent streaming; not working yet and has broken tests
This commit is contained in:
parent 943a9d83b0
commit 310a2deb06
44 changed files with 2684 additions and 937 deletions
Below, excerpts from the diff to the Bedrock text-completion module:

```diff
@@ -11,7 +11,7 @@ import enum
 import logging
 
 from .... exceptions import TooManyRequests
-from .... base import LlmService, LlmResult
+from .... base import LlmService, LlmResult, LlmChunk
 
 # Module logger
 logger = logging.getLogger(__name__)
```
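`LlmChunk` is imported here but defined elsewhere in trustgraph's base module. Judging only from how the streaming code later in this diff constructs it, a minimal sketch of its shape might be the following; the field types are assumptions, not the actual definition:

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical sketch of LlmChunk, inferred from the constructor calls
# in generate_content_stream below; the real definition may differ.
@dataclass
class LlmChunk:
    text: str                 # incremental text carried by this chunk
    in_token: Optional[int]   # input token count, set on the final chunk
    out_token: Optional[int]  # output token count, set on the final chunk
    model: str                # model that produced the chunk
    is_final: bool            # True only on the terminating chunk
```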
```diff
@@ -52,6 +52,8 @@ class ModelHandler:
         raise RuntimeError("format_request not implemented")
 
     def decode_response(self, response):
         raise RuntimeError("decode_response not implemented")
 
+    def decode_stream_chunk(self, chunk):
+        raise RuntimeError("decode_stream_chunk not implemented")
+
 class Mistral(ModelHandler):
 
     def __init__(self):
```
```diff
@@ -68,6 +70,11 @@ class Mistral(ModelHandler):
     def decode_response(self, response):
         response_body = json.loads(response.get("body").read())
         return response_body['outputs'][0]['text']
 
+    def decode_stream_chunk(self, chunk):
+        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
+        if 'outputs' in chunk_obj and len(chunk_obj['outputs']) > 0:
+            return chunk_obj['outputs'][0].get('text', '')
+        return ''
 
 # Llama 3
 class Meta(ModelHandler):
```
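Bedrock's `invoke_model_with_response_stream` wraps each model payload in a `chunk.bytes` envelope, which is what these decoders unpack. A fabricated Mistral-style event showing the shape the method above expects:

```python
import json

# Fabricated stream event: Bedrock's chunk/bytes envelope around a
# Mistral-style payload carrying an 'outputs' list.
event = {
    "chunk": {"bytes": json.dumps({"outputs": [{"text": "Hello"}]}).encode()}
}

assert Mistral().decode_stream_chunk(event) == "Hello"
```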
```diff
@@ -83,6 +90,9 @@ class Meta(ModelHandler):
     def decode_response(self, response):
         model_response = json.loads(response["body"].read())
         return model_response["generation"]
 
+    def decode_stream_chunk(self, chunk):
+        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
+        return chunk_obj.get('generation', '')
 
 class Anthropic(ModelHandler):
 
     def __init__(self):
```
```diff
@@ -108,6 +118,12 @@ class Anthropic(ModelHandler):
     def decode_response(self, response):
         model_response = json.loads(response["body"].read())
         return model_response['content'][0]['text']
 
+    def decode_stream_chunk(self, chunk):
+        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
+        if chunk_obj.get('type') == 'content_block_delta':
+            if 'delta' in chunk_obj and 'text' in chunk_obj['delta']:
+                return chunk_obj['delta']['text']
+        return ''
 
 class Ai21(ModelHandler):
 
     def __init__(self):
```
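Anthropic's streaming responses interleave several event types (`message_start`, `content_block_start`, `content_block_delta`, `message_delta`, `message_stop`); only the delta events carry text, which is why the decoder filters on `type`. Two fabricated events:

```python
import json

def bedrock_event(payload):
    # Wrap a model payload in the Bedrock stream-event envelope.
    return {"chunk": {"bytes": json.dumps(payload).encode()}}

anthropic = Anthropic()

# Text arrives only in content_block_delta events...
delta = bedrock_event(
    {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "Hi"}}
)
assert anthropic.decode_stream_chunk(delta) == "Hi"

# ...while bookkeeping events decode to the empty string.
assert anthropic.decode_stream_chunk(bedrock_event({"type": "message_stop"})) == ""
```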
```diff
@@ -129,6 +145,12 @@ class Ai21(ModelHandler):
         content_str = content.decode('utf-8')
         content_json = json.loads(content_str)
         return content_json['choices'][0]['message']['content']
 
+    def decode_stream_chunk(self, chunk):
+        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
+        if 'choices' in chunk_obj and len(chunk_obj['choices']) > 0:
+            delta = chunk_obj['choices'][0].get('delta', {})
+            return delta.get('content', '')
+        return ''
 
 class Cohere(ModelHandler):
 
     def encode_request(self, system, prompt):
```
```diff
@@ -142,6 +164,9 @@ class Cohere(ModelHandler):
         content_str = content.decode('utf-8')
         content_json = json.loads(content_str)
         return content_json['text']
 
+    def decode_stream_chunk(self, chunk):
+        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
+        return chunk_obj.get('text', '')
 
 Default = Mistral
```
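The remaining decoders follow the same envelope pattern with provider-specific payloads: Llama 3 returns a flat `generation` field, AI21 an OpenAI-style `choices[0].delta.content`, and Cohere a flat `text`. A quick sanity check with fabricated payloads (shapes inferred from the field names each decoder reads, not from provider documentation):

```python
import json

def bedrock_event(payload):
    # Wrap a model payload in the Bedrock stream-event envelope.
    return {"chunk": {"bytes": json.dumps(payload).encode()}}

cases = [
    (Meta(), {"generation": " world"}, " world"),
    (Ai21(), {"choices": [{"delta": {"content": "!"}}]}, "!"),
    (Cohere(), {"text": "done"}, "done"),
]

for handler, payload, expected in cases:
    assert handler.decode_stream_chunk(bedrock_event(payload)) == expected
```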
```diff
@@ -309,6 +334,78 @@ class Processor(LlmService):
             logger.error(f"Bedrock LLM exception ({type(e).__name__}): {e}", exc_info=True)
             raise e
 
+    def supports_streaming(self):
+        """Bedrock supports streaming"""
+        return True
+
+    async def generate_content_stream(self, system, prompt, model=None, temperature=None):
+        """Stream content generation from Bedrock"""
+        model_name = model or self.default_model
+        effective_temperature = temperature if temperature is not None else self.temperature
+
+        logger.debug(f"Using model (streaming): {model_name}")
+        logger.debug(f"Using temperature: {effective_temperature}")
+
+        try:
+            variant = self._get_or_create_variant(model_name, effective_temperature)
+            promptbody = variant.encode_request(system, prompt)
+
+            accept = 'application/json'
+            contentType = 'application/json'
+
+            response = self.bedrock.invoke_model_with_response_stream(
+                body=promptbody,
+                modelId=model_name,
+                accept=accept,
+                contentType=contentType
+            )
+
+            total_input_tokens = 0
+            total_output_tokens = 0
+
+            stream = response.get('body')
+            if stream:
+                for event in stream:
+                    chunk = event.get('chunk')
+                    if chunk:
+                        # Decode the chunk text
+                        text = variant.decode_stream_chunk(event)
+                        if text:
+                            yield LlmChunk(
+                                text=text,
+                                in_token=None,
+                                out_token=None,
+                                model=model_name,
+                                is_final=False
+                            )
+
+                    # Try to extract metadata from the event
+                    metadata = event.get('metadata')
+                    if metadata:
+                        usage = metadata.get('usage')
+                        if usage:
+                            total_input_tokens = usage.get('inputTokens', 0)
+                            total_output_tokens = usage.get('outputTokens', 0)
+
+            # Send final chunk with token counts
+            yield LlmChunk(
+                text="",
+                in_token=total_input_tokens,
+                out_token=total_output_tokens,
+                model=model_name,
+                is_final=True
+            )
+
+            logger.debug("Streaming complete")
+
+        except self.bedrock.exceptions.ThrottlingException as e:
+            logger.warning(f"Hit rate limit during streaming: {e}")
+            raise TooManyRequests()
+
+        except Exception as e:
+            logger.error(f"Bedrock streaming exception ({type(e).__name__}): {e}", exc_info=True)
+            raise e
+
     @staticmethod
     def add_args(parser):
```
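From the caller's side, the generator yields incremental `LlmChunk`s and terminates with a single empty chunk carrying the token counts. A hedged sketch of a consumer, assuming an already-configured `Processor` instance (construction arguments omitted):

```python
# Hypothetical consumer; `processor` stands in for a configured Processor.
async def collect(processor, system, prompt):
    parts = []
    async for chunk in processor.generate_content_stream(system, prompt):
        if chunk.is_final:
            # The terminating chunk carries usage figures, not text.
            print(f"tokens: in={chunk.in_token} out={chunk.out_token}")
        else:
            parts.append(chunk.text)
    return "".join(parts)
```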