Feature/streaming llm phase 1 (#566)

* Tidy up duplicate tech specs in doc directory

* Streaming LLM text-completion service tech spec.

* text-completion and prompt interfaces

* streaming change applied to all LLMs, so far tested with VertexAI

* Skip Pinecone unit tests while an upstream module issue affects them; tests are passing again

* Added agent streaming; not yet working and has broken tests
cybermaggedon 2025-11-26 09:59:10 +00:00 committed by GitHub
parent 943a9d83b0
commit 310a2deb06
44 changed files with 2684 additions and 937 deletions


@@ -11,7 +11,7 @@ import enum
import logging
from ....exceptions import TooManyRequests
from ....base import LlmService, LlmResult
from ....base import LlmService, LlmResult, LlmChunk
# Module logger
logger = logging.getLogger(__name__)
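
The changed import pulls `LlmChunk` into scope from the shared `base` module, whose definition is not part of this diff. Judging from how the streaming code below constructs it, it is presumably something along these lines (field names taken from the `yield LlmChunk(...)` calls; everything else is an assumption):

```python
# Hypothetical sketch of LlmChunk, inferred from its usage in this diff.
# The real definition lives in the package's base module and may differ.
from dataclasses import dataclass
from typing import Optional

@dataclass
class LlmChunk:
    text: str                        # incremental text for this chunk ("" on the final chunk)
    in_token: Optional[int] = None   # input token count, only set on the final chunk
    out_token: Optional[int] = None  # output token count, only set on the final chunk
    model: Optional[str] = None      # model ID that produced the chunk
    is_final: bool = False           # True for the terminating chunk carrying token counts
```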
@@ -52,6 +52,8 @@ class ModelHandler:
        raise RuntimeError("format_request not implemented")
    def decode_response(self, response):
        raise RuntimeError("decode_response not implemented")
    def decode_stream_chunk(self, chunk):
        raise RuntimeError("decode_stream_chunk not implemented")

class Mistral(ModelHandler):
    def __init__(self):
@@ -68,6 +70,11 @@ class Mistral(ModelHandler):
    def decode_response(self, response):
        response_body = json.loads(response.get("body").read())
        return response_body['outputs'][0]['text']
    def decode_stream_chunk(self, chunk):
        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
        if 'outputs' in chunk_obj and len(chunk_obj['outputs']) > 0:
            return chunk_obj['outputs'][0].get('text', '')
        return ''
# Llama 3
class Meta(ModelHandler):
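
Each handler's `decode_stream_chunk` is passed a whole event from the Bedrock response stream and extracts the model-specific text delta. A minimal sketch of how the Mistral handler above could be exercised, assuming the `{'chunk': {'bytes': ...}}` event envelope implied by the code (the payloads here are fabricated):

```python
import json

# A fabricated Bedrock stream event in the shape the code above expects;
# real events come from invoke_model_with_response_stream.
event = {
    "chunk": {
        "bytes": json.dumps({"outputs": [{"text": "Hello, "}]}).encode()
    }
}

handler = Mistral()                         # handler class from the diff above
print(handler.decode_stream_chunk(event))   # -> "Hello, "

# Events whose payload carries no 'outputs' decode to an empty string,
# so callers can simply skip empty results.
empty = {"chunk": {"bytes": json.dumps({"stop_reason": "stop"}).encode()}}
assert handler.decode_stream_chunk(empty) == ""
```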
@@ -83,6 +90,9 @@ class Meta(ModelHandler):
    def decode_response(self, response):
        model_response = json.loads(response["body"].read())
        return model_response["generation"]
    def decode_stream_chunk(self, chunk):
        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
        return chunk_obj.get('generation', '')

class Anthropic(ModelHandler):
    def __init__(self):
@@ -108,6 +118,12 @@ class Anthropic(ModelHandler):
    def decode_response(self, response):
        model_response = json.loads(response["body"].read())
        return model_response['content'][0]['text']
    def decode_stream_chunk(self, chunk):
        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
        if chunk_obj.get('type') == 'content_block_delta':
            if 'delta' in chunk_obj and 'text' in chunk_obj['delta']:
                return chunk_obj['delta']['text']
        return ''

class Ai21(ModelHandler):
    def __init__(self):
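
The Anthropic handler only yields text for `content_block_delta` events; other event types in Anthropic's message-stream format (such as `message_start` and `message_stop`) decode to an empty string and so produce no output chunk. A short illustrative sketch with fabricated events:

```python
import json

def make_event(payload):
    # Wrap a payload dict in the {'chunk': {'bytes': ...}} envelope the handler expects
    return {"chunk": {"bytes": json.dumps(payload).encode()}}

handler = Anthropic()

# Only delta events carry text to stream
delta = make_event({"type": "content_block_delta", "delta": {"type": "text_delta", "text": "Hi"}})
assert handler.decode_stream_chunk(delta) == "Hi"

# Start/stop events decode to "" and are effectively skipped by the caller
assert handler.decode_stream_chunk(make_event({"type": "message_start"})) == ""
assert handler.decode_stream_chunk(make_event({"type": "message_stop"})) == ""
```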
@@ -129,6 +145,12 @@ class Ai21(ModelHandler):
        content_str = content.decode('utf-8')
        content_json = json.loads(content_str)
        return content_json['choices'][0]['message']['content']
    def decode_stream_chunk(self, chunk):
        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
        if 'choices' in chunk_obj and len(chunk_obj['choices']) > 0:
            delta = chunk_obj['choices'][0].get('delta', {})
            return delta.get('content', '')
        return ''

class Cohere(ModelHandler):
    def encode_request(self, system, prompt):
@@ -142,6 +164,9 @@ class Cohere(ModelHandler):
        content_str = content.decode('utf-8')
        content_json = json.loads(content_str)
        return content_json['text']
    def decode_stream_chunk(self, chunk):
        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
        return chunk_obj.get('text', '')
Default=Mistral
@@ -309,6 +334,78 @@ class Processor(LlmService):
            logger.error(f"Bedrock LLM exception ({type(e).__name__}): {e}", exc_info=True)
            raise e

    def supports_streaming(self):
        """Bedrock supports streaming"""
        return True

    async def generate_content_stream(self, system, prompt, model=None, temperature=None):
        """Stream content generation from Bedrock"""

        model_name = model or self.default_model
        effective_temperature = temperature if temperature is not None else self.temperature

        logger.debug(f"Using model (streaming): {model_name}")
        logger.debug(f"Using temperature: {effective_temperature}")

        try:

            variant = self._get_or_create_variant(model_name, effective_temperature)
            promptbody = variant.encode_request(system, prompt)

            accept = 'application/json'
            contentType = 'application/json'

            response = self.bedrock.invoke_model_with_response_stream(
                body=promptbody,
                modelId=model_name,
                accept=accept,
                contentType=contentType
            )

            total_input_tokens = 0
            total_output_tokens = 0

            stream = response.get('body')
            if stream:
                for event in stream:
                    chunk = event.get('chunk')
                    if chunk:
                        # Decode the chunk text
                        text = variant.decode_stream_chunk(event)
                        if text:
                            yield LlmChunk(
                                text=text,
                                in_token=None,
                                out_token=None,
                                model=model_name,
                                is_final=False
                            )
                    # Try to extract metadata from the event
                    metadata = event.get('metadata')
                    if metadata:
                        usage = metadata.get('usage')
                        if usage:
                            total_input_tokens = usage.get('inputTokens', 0)
                            total_output_tokens = usage.get('outputTokens', 0)

            # Send final chunk with token counts
            yield LlmChunk(
                text="",
                in_token=total_input_tokens,
                out_token=total_output_tokens,
                model=model_name,
                is_final=True
            )

            logger.debug("Streaming complete")

        except self.bedrock.exceptions.ThrottlingException as e:
            logger.warning(f"Hit rate limit during streaming: {e}")
            raise TooManyRequests()

        except Exception as e:
            logger.error(f"Bedrock streaming exception ({type(e).__name__}): {e}", exc_info=True)
            raise e

    @staticmethod
    def add_args(parser):
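
To show how the new interface is meant to be consumed, here is a rough caller-side sketch. The `Processor()` construction is a placeholder (real configuration comes from the `add_args` CLI wiring), and only the streaming call itself is taken from the diff:

```python
import asyncio

async def main():
    # Hypothetical construction; actual arguments depend on the service configuration
    processor = Processor()

    if not processor.supports_streaming():
        raise RuntimeError("streaming not available for this LLM service")

    parts = []
    async for chunk in processor.generate_content_stream(
        system="You are a helpful assistant.",
        prompt="Summarise the Bedrock streaming API in one sentence.",
    ):
        if chunk.is_final:
            # The terminating chunk carries token usage rather than text
            print(f"\n[in={chunk.in_token} out={chunk.out_token} model={chunk.model}]")
        else:
            parts.append(chunk.text)
            print(chunk.text, end="", flush=True)

    return "".join(parts)

asyncio.run(main())
```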