mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 16:36:21 +02:00
Feature/streaming llm phase 1 (#566)
* Tidy up duplicate tech specs in doc directory * Streaming LLM text-completion service tech spec. * text-completion and prompt interfaces * streaming change applied to all LLMs, so far tested with VertexAI * Skip Pinecone unit tests, upstream module issue is affecting things, tests are passing again * Added agent streaming, not working and has broken tests
This commit is contained in:
parent
943a9d83b0
commit
310a2deb06
44 changed files with 2684 additions and 937 deletions
|
|
@ -12,7 +12,7 @@ from . parameter_spec import ParameterSpec
|
|||
from . producer_spec import ProducerSpec
|
||||
from . subscriber_spec import SubscriberSpec
|
||||
from . request_response_spec import RequestResponseSpec
|
||||
from . llm_service import LlmService, LlmResult
|
||||
from . llm_service import LlmService, LlmResult, LlmChunk
|
||||
from . chunking_service import ChunkingService
|
||||
from . embeddings_service import EmbeddingsService
|
||||
from . embeddings_client import EmbeddingsClientSpec
|
||||
|
|
|
|||
|
|
@ -28,6 +28,19 @@ class LlmResult:
|
|||
self.model = model
|
||||
__slots__ = ["text", "in_token", "out_token", "model"]
|
||||
|
||||
class LlmChunk:
|
||||
"""Represents a streaming chunk from an LLM"""
|
||||
def __init__(
|
||||
self, text = None, in_token = None, out_token = None,
|
||||
model = None, is_final = False,
|
||||
):
|
||||
self.text = text
|
||||
self.in_token = in_token
|
||||
self.out_token = out_token
|
||||
self.model = model
|
||||
self.is_final = is_final
|
||||
__slots__ = ["text", "in_token", "out_token", "model", "is_final"]
|
||||
|
||||
class LlmService(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
|
@ -99,16 +112,57 @@ class LlmService(FlowProcessor):
|
|||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
with __class__.text_completion_metric.labels(
|
||||
id=self.id,
|
||||
flow=f"{flow.name}-{consumer.name}",
|
||||
).time():
|
||||
model = flow("model")
|
||||
temperature = flow("temperature")
|
||||
|
||||
model = flow("model")
|
||||
temperature = flow("temperature")
|
||||
# Check if streaming is requested and supported
|
||||
streaming = getattr(request, 'streaming', False)
|
||||
|
||||
response = await self.generate_content(
|
||||
request.system, request.prompt, model, temperature
|
||||
if streaming and self.supports_streaming():
|
||||
|
||||
# Streaming mode
|
||||
with __class__.text_completion_metric.labels(
|
||||
id=self.id,
|
||||
flow=f"{flow.name}-{consumer.name}",
|
||||
).time():
|
||||
|
||||
async for chunk in self.generate_content_stream(
|
||||
request.system, request.prompt, model, temperature
|
||||
):
|
||||
await flow("response").send(
|
||||
TextCompletionResponse(
|
||||
error=None,
|
||||
response=chunk.text,
|
||||
in_token=chunk.in_token,
|
||||
out_token=chunk.out_token,
|
||||
model=chunk.model,
|
||||
end_of_stream=chunk.is_final
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
# Non-streaming mode (original behavior)
|
||||
with __class__.text_completion_metric.labels(
|
||||
id=self.id,
|
||||
flow=f"{flow.name}-{consumer.name}",
|
||||
).time():
|
||||
|
||||
response = await self.generate_content(
|
||||
request.system, request.prompt, model, temperature
|
||||
)
|
||||
|
||||
await flow("response").send(
|
||||
TextCompletionResponse(
|
||||
error=None,
|
||||
response=response.text,
|
||||
in_token=response.in_token,
|
||||
out_token=response.out_token,
|
||||
model=response.model,
|
||||
end_of_stream=True
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
__class__.text_completion_model_metric.labels(
|
||||
|
|
@ -119,17 +173,6 @@ class LlmService(FlowProcessor):
|
|||
"temperature": str(temperature) if temperature is not None else "",
|
||||
})
|
||||
|
||||
await flow("response").send(
|
||||
TextCompletionResponse(
|
||||
error=None,
|
||||
response=response.text,
|
||||
in_token=response.in_token,
|
||||
out_token=response.out_token,
|
||||
model=response.model
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
except TooManyRequests as e:
|
||||
raise e
|
||||
|
||||
|
|
@ -151,10 +194,26 @@ class LlmService(FlowProcessor):
|
|||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
end_of_stream=True
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
def supports_streaming(self):
|
||||
"""
|
||||
Override in subclass to indicate streaming support.
|
||||
Returns False by default.
|
||||
"""
|
||||
return False
|
||||
|
||||
async def generate_content_stream(self, system, prompt, model=None, temperature=None):
|
||||
"""
|
||||
Override in subclass to implement streaming.
|
||||
Should yield LlmChunk objects.
|
||||
The final chunk should have is_final=True.
|
||||
"""
|
||||
raise NotImplementedError("Streaming not implemented for this provider")
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
|
|
|
|||
|
|
@ -1,30 +1,75 @@
|
|||
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import PromptRequest, PromptResponse
|
||||
|
||||
class PromptClient(RequestResponse):
|
||||
|
||||
async def prompt(self, id, variables, timeout=600):
|
||||
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
|
||||
|
||||
resp = await self.request(
|
||||
PromptRequest(
|
||||
id = id,
|
||||
terms = {
|
||||
k: json.dumps(v)
|
||||
for k, v in variables.items()
|
||||
}
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
if not streaming:
|
||||
# Non-streaming path
|
||||
resp = await self.request(
|
||||
PromptRequest(
|
||||
id = id,
|
||||
terms = {
|
||||
k: json.dumps(v)
|
||||
for k, v in variables.items()
|
||||
},
|
||||
streaming = False
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
if resp.text: return resp.text
|
||||
if resp.text: return resp.text
|
||||
|
||||
return json.loads(resp.object)
|
||||
return json.loads(resp.object)
|
||||
|
||||
else:
|
||||
# Streaming path - collect all chunks
|
||||
full_text = ""
|
||||
full_object = None
|
||||
|
||||
async def collect_chunks(resp):
|
||||
nonlocal full_text, full_object
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
if resp.text:
|
||||
full_text += resp.text
|
||||
# Call chunk callback if provided
|
||||
if chunk_callback:
|
||||
if asyncio.iscoroutinefunction(chunk_callback):
|
||||
await chunk_callback(resp.text)
|
||||
else:
|
||||
chunk_callback(resp.text)
|
||||
elif resp.object:
|
||||
full_object = resp.object
|
||||
|
||||
return getattr(resp, 'end_of_stream', False)
|
||||
|
||||
await self.request(
|
||||
PromptRequest(
|
||||
id = id,
|
||||
terms = {
|
||||
k: json.dumps(v)
|
||||
for k, v in variables.items()
|
||||
},
|
||||
streaming = True
|
||||
),
|
||||
recipient=collect_chunks,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if full_text: return full_text
|
||||
|
||||
return json.loads(full_object)
|
||||
|
||||
async def extract_definitions(self, text, timeout=600):
|
||||
return await self.prompt(
|
||||
|
|
@ -70,11 +115,13 @@ class PromptClient(RequestResponse):
|
|||
timeout = timeout,
|
||||
)
|
||||
|
||||
async def agent_react(self, variables, timeout=600):
|
||||
async def agent_react(self, variables, timeout=600, streaming=False, chunk_callback=None):
|
||||
return await self.prompt(
|
||||
id = "agent-react",
|
||||
variables = variables,
|
||||
timeout = timeout,
|
||||
streaming = streaming,
|
||||
chunk_callback = chunk_callback,
|
||||
)
|
||||
|
||||
async def question(self, question, timeout=600):
|
||||
|
|
|
|||
|
|
@ -3,18 +3,45 @@ from . request_response_spec import RequestResponse, RequestResponseSpec
|
|||
from .. schema import TextCompletionRequest, TextCompletionResponse
|
||||
|
||||
class TextCompletionClient(RequestResponse):
|
||||
async def text_completion(self, system, prompt, timeout=600):
|
||||
resp = await self.request(
|
||||
async def text_completion(self, system, prompt, streaming=False, timeout=600):
|
||||
# If not streaming, use original behavior
|
||||
if not streaming:
|
||||
resp = await self.request(
|
||||
TextCompletionRequest(
|
||||
system = system, prompt = prompt, streaming = False
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return resp.response
|
||||
|
||||
# For streaming: collect all chunks and return complete response
|
||||
full_response = ""
|
||||
|
||||
async def collect_chunks(resp):
|
||||
nonlocal full_response
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
if resp.response:
|
||||
full_response += resp.response
|
||||
|
||||
# Return True when end_of_stream is reached
|
||||
return getattr(resp, 'end_of_stream', False)
|
||||
|
||||
await self.request(
|
||||
TextCompletionRequest(
|
||||
system = system, prompt = prompt
|
||||
system = system, prompt = prompt, streaming = True
|
||||
),
|
||||
recipient=collect_chunks,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return resp.response
|
||||
return full_response
|
||||
|
||||
class TextCompletionClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue