mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Feature/streaming llm phase 1 (#566)
* Tidy up duplicate tech specs in doc directory
* Streaming LLM text-completion service tech spec
* text-completion and prompt interfaces
* Streaming change applied to all LLMs; so far tested with VertexAI
* Skip Pinecone unit tests; an upstream module issue is affecting things. Tests are passing again
* Added agent streaming; not working yet and has broken tests
This commit is contained in:
parent
943a9d83b0
commit
310a2deb06
44 changed files with 2684 additions and 937 deletions
|
|
@ -12,7 +12,7 @@ from . parameter_spec import ParameterSpec
|
|||
from . producer_spec import ProducerSpec
|
||||
from . subscriber_spec import SubscriberSpec
|
||||
from . request_response_spec import RequestResponseSpec
|
||||
from . llm_service import LlmService, LlmResult
|
||||
from . llm_service import LlmService, LlmResult, LlmChunk
|
||||
from . chunking_service import ChunkingService
|
||||
from . embeddings_service import EmbeddingsService
|
||||
from . embeddings_client import EmbeddingsClientSpec
|
||||
|
|
|
|||
|
|
@ -28,6 +28,19 @@ class LlmResult:
|
|||
self.model = model
|
||||
__slots__ = ["text", "in_token", "out_token", "model"]
|
||||
|
||||
class LlmChunk:
    """Represents a streaming chunk from an LLM.

    A plain value holder: every constructor argument is stored unchanged on
    the attribute of the same name.

    Attributes:
        text: Text carried by this chunk (None when absent).
        in_token: presumably the input token count -- confirm with producers.
        out_token: presumably the output token count -- confirm with producers.
        model: Model identifier reported for this chunk (None when absent).
        is_final: True on the chunk that closes the stream.
    """

    # No per-instance __dict__: only these five attributes may be set.
    __slots__ = ["text", "in_token", "out_token", "model", "is_final"]

    def __init__(
            self, text = None, in_token = None, out_token = None,
            model = None, is_final = False,
    ):
        self.text = text
        self.in_token = in_token
        self.out_token = out_token
        self.model = model
        self.is_final = is_final

    def __repr__(self):
        # Slotted objects otherwise print as a bare address, which makes
        # logged or asserted chunks unreadable.
        fields = ", ".join(
            f"{name}={getattr(self, name)!r}" for name in self.__slots__
        )
        return f"{type(self).__name__}({fields})"
|
||||
|
||||
class LlmService(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
|
@ -99,16 +112,57 @@ class LlmService(FlowProcessor):
|
|||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
with __class__.text_completion_metric.labels(
|
||||
id=self.id,
|
||||
flow=f"{flow.name}-{consumer.name}",
|
||||
).time():
|
||||
model = flow("model")
|
||||
temperature = flow("temperature")
|
||||
|
||||
model = flow("model")
|
||||
temperature = flow("temperature")
|
||||
# Check if streaming is requested and supported
|
||||
streaming = getattr(request, 'streaming', False)
|
||||
|
||||
response = await self.generate_content(
|
||||
request.system, request.prompt, model, temperature
|
||||
if streaming and self.supports_streaming():
|
||||
|
||||
# Streaming mode
|
||||
with __class__.text_completion_metric.labels(
|
||||
id=self.id,
|
||||
flow=f"{flow.name}-{consumer.name}",
|
||||
).time():
|
||||
|
||||
async for chunk in self.generate_content_stream(
|
||||
request.system, request.prompt, model, temperature
|
||||
):
|
||||
await flow("response").send(
|
||||
TextCompletionResponse(
|
||||
error=None,
|
||||
response=chunk.text,
|
||||
in_token=chunk.in_token,
|
||||
out_token=chunk.out_token,
|
||||
model=chunk.model,
|
||||
end_of_stream=chunk.is_final
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
# Non-streaming mode (original behavior)
|
||||
with __class__.text_completion_metric.labels(
|
||||
id=self.id,
|
||||
flow=f"{flow.name}-{consumer.name}",
|
||||
).time():
|
||||
|
||||
response = await self.generate_content(
|
||||
request.system, request.prompt, model, temperature
|
||||
)
|
||||
|
||||
await flow("response").send(
|
||||
TextCompletionResponse(
|
||||
error=None,
|
||||
response=response.text,
|
||||
in_token=response.in_token,
|
||||
out_token=response.out_token,
|
||||
model=response.model,
|
||||
end_of_stream=True
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
__class__.text_completion_model_metric.labels(
|
||||
|
|
@ -119,17 +173,6 @@ class LlmService(FlowProcessor):
|
|||
"temperature": str(temperature) if temperature is not None else "",
|
||||
})
|
||||
|
||||
await flow("response").send(
|
||||
TextCompletionResponse(
|
||||
error=None,
|
||||
response=response.text,
|
||||
in_token=response.in_token,
|
||||
out_token=response.out_token,
|
||||
model=response.model
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
except TooManyRequests as e:
|
||||
raise e
|
||||
|
||||
|
|
@ -151,10 +194,26 @@ class LlmService(FlowProcessor):
|
|||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
end_of_stream=True
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
def supports_streaming(self) -> bool:
    """
    Report whether this service can stream completions.

    The base implementation always answers False; a subclass that
    implements generate_content_stream() overrides this to return True.
    """
    return False
|
||||
|
||||
async def generate_content_stream(self, system, prompt, model=None, temperature=None):
    """
    Override in subclass to implement streaming.

    Should yield LlmChunk objects.
    The final chunk should have is_final=True.

    Raises:
        NotImplementedError: on first iteration of the base implementation.
    """
    raise NotImplementedError("Streaming not implemented for this provider")
    # Never reached, but the presence of `yield` makes this an async
    # generator function.  Without it the call returns a coroutine, and a
    # caller's `async for` fails with an opaque TypeError instead of the
    # NotImplementedError above.
    yield  # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
|
|
|
|||
|
|
@ -1,30 +1,75 @@
|
|||
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import PromptRequest, PromptResponse
|
||||
|
||||
class PromptClient(RequestResponse):
|
||||
|
||||
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
    """
    Invoke the prompt service and return its result.

    Args:
        id: Prompt template identifier placed in the request.
        variables: Mapping of template variables; each value is
            JSON-encoded before being sent as a request term.
        timeout: Passed through to self.request().
        streaming: When true, the request is sent with streaming=True and
            the responses are accumulated before returning.
        chunk_callback: Optional callable invoked with each non-empty text
            chunk in streaming mode.  If the call returns an awaitable it
            is awaited, so plain functions, coroutine functions and
            callable objects with an async __call__ all work.

    Returns:
        The response text when non-empty, otherwise the JSON-decoded
        `object` field of the response.

    Raises:
        RuntimeError: if any response carries an error.
    """
    # Function-scope import: only the streaming callback path needs it.
    import inspect

    request = PromptRequest(
        id = id,
        terms = {
            k: json.dumps(v)
            for k, v in variables.items()
        },
        streaming = bool(streaming),
    )

    if not streaming:

        # Non-streaming path: a single response carries the whole result
        resp = await self.request(request, timeout=timeout)

        if resp.error:
            raise RuntimeError(resp.error.message)

        if resp.text: return resp.text

        return json.loads(resp.object)

    # Streaming path - collect all chunks
    text_parts = []
    full_object = None

    async def collect_chunks(resp):
        nonlocal full_object

        if resp.error:
            raise RuntimeError(resp.error.message)

        if resp.text:
            text_parts.append(resp.text)
            # Call chunk callback if provided.  Deciding on the *result*
            # rather than on the callable means async callables that are
            # not plain coroutine functions are still awaited.
            if chunk_callback:
                outcome = chunk_callback(resp.text)
                if inspect.isawaitable(outcome):
                    await outcome
        elif resp.object:
            full_object = resp.object

        # The end_of_stream flag is handed back to self.request();
        # presumably a true value ends the wait -- confirm in RequestResponse.
        return getattr(resp, 'end_of_stream', False)

    await self.request(
        request,
        recipient=collect_chunks,
        timeout=timeout
    )

    full_text = "".join(text_parts)

    if full_text: return full_text

    return json.loads(full_object)
|
||||
|
||||
async def extract_definitions(self, text, timeout=600):
|
||||
return await self.prompt(
|
||||
|
|
@ -70,11 +115,13 @@ class PromptClient(RequestResponse):
|
|||
timeout = timeout,
|
||||
)
|
||||
|
||||
async def agent_react(self, variables, timeout=600, streaming=False, chunk_callback=None):
    """
    Run the "agent-react" prompt.

    Thin wrapper: forwards variables, timeout, streaming and chunk_callback
    unchanged to self.prompt() and returns whatever it returns.
    """
    forwarded = {
        "id": "agent-react",
        "variables": variables,
        "timeout": timeout,
        "streaming": streaming,
        "chunk_callback": chunk_callback,
    }
    return await self.prompt(**forwarded)
|
||||
|
||||
async def question(self, question, timeout=600):
|
||||
|
|
|
|||
|
|
@ -3,18 +3,45 @@ from . request_response_spec import RequestResponse, RequestResponseSpec
|
|||
from .. schema import TextCompletionRequest, TextCompletionResponse
|
||||
|
||||
class TextCompletionClient(RequestResponse):
|
||||
async def text_completion(self, system, prompt, streaming=False, timeout=600):
    """
    Request a text completion and return the full response text.

    With streaming false, one request is sent and the `response` field of
    the reply is returned.  With streaming true, the request is sent with
    streaming=True and the `response` of every chunk is concatenated until
    a chunk reports end_of_stream.

    Raises:
        RuntimeError: if a reply (or any streamed chunk) carries an error.
    """
    if streaming:

        # Streaming: gather every chunk, then hand back the joined text
        pieces = []

        async def collect_chunks(resp):

            if resp.error:
                raise RuntimeError(resp.error.message)

            if resp.response:
                pieces.append(resp.response)

            # Hand the end_of_stream flag back to self.request()
            return getattr(resp, 'end_of_stream', False)

        await self.request(
            TextCompletionRequest(
                system = system, prompt = prompt, streaming = True
            ),
            recipient=collect_chunks,
            timeout=timeout
        )

        return "".join(pieces)

    # Non-streaming: original single request/response behaviour
    reply = await self.request(
        TextCompletionRequest(
            system = system, prompt = prompt, streaming = False
        ),
        timeout=timeout
    )

    if reply.error:
        raise RuntimeError(reply.error.message)

    return reply.response
|
||||
|
||||
class TextCompletionClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from .. schema import TextCompletionRequest, TextCompletionResponse
|
|||
from .. schema import text_completion_request_queue
|
||||
from .. schema import text_completion_response_queue
|
||||
from . base import BaseClient
|
||||
from .. exceptions import LlmError
|
||||
|
||||
# Ugly
|
||||
ERROR=_pulsar.LoggerLevel.Error
|
||||
|
|
@ -37,8 +38,68 @@ class LlmClient(BaseClient):
|
|||
output_schema=TextCompletionResponse,
|
||||
)
|
||||
|
||||
def request(self, system, prompt, timeout=300, streaming=False):
    """
    Non-streaming request (backward compatible).

    Delegates to self.call() with streaming=False and returns the
    `response` attribute of its result.

    Raises:
        ValueError: if streaming is requested; request_stream() is the
            entry point for that.
    """
    if streaming:
        raise ValueError("Use request_stream() for streaming requests")

    result = self.call(
        system=system, prompt=prompt, streaming=False, timeout=timeout
    )
    return result.response
|
||||
|
||||
def request_stream(self, system, prompt, timeout=300):
    """
    Streaming request generator.

    Sends a TextCompletionRequest with streaming=True tagged with a fresh
    id, then yields each response whose "id" property matches, until a
    response reports end_of_stream.  Responses with a different id are
    acknowledged and ignored.

    Usage:
        for chunk in client.request_stream(system, prompt):
            print(chunk.response, end='', flush=True)

    Raises:
        LlmError: a response carried an error of type "llm-error".
        RuntimeError: a response carried any other error.
        TimeoutError: the stream did not finish within `timeout` seconds.
    """
    import time
    import uuid

    request_id = str(uuid.uuid4())
    request = TextCompletionRequest(
        system=system, prompt=prompt, streaming=True
    )

    # Monotonic clock: the deadline must not move if the wall clock is
    # adjusted while the stream is in flight.
    deadline = time.monotonic() + timeout
    self.producer.send(request, properties={"id": request_id})

    # Collect responses until end_of_stream
    while time.monotonic() < deadline:

        try:
            msg = self.consumer.receive(timeout_millis=2500)
        except Exception:
            # NOTE(review): meant to absorb the receive timeout, but this
            # also hides every other failure and retries until the
            # deadline -- consider narrowing to the client's timeout
            # exception.
            continue

        if msg.properties()["id"] != request_id:
            # Ignore messages with wrong ID
            self.consumer.acknowledge(msg)
            continue

        value = msg.value()

        # Acknowledged on both the error and the success path.
        self.consumer.acknowledge(msg)

        # Handle errors
        if value.error:
            if value.error.type == "llm-error":
                raise LlmError(value.error.message)
            raise RuntimeError(
                f"{value.error.type}: {value.error.message}"
            )

        yield value

        # Final chunk: finish here.  Returning (rather than breaking out
        # to a deadline check) guarantees a stream that completed just as
        # the deadline passed is not reported as a timeout.
        if getattr(value, 'end_of_stream', True):
            return

    raise TimeoutError("Timed out waiting for response")
|
||||
|
||||
|
|
|
|||
|
|
@ -12,16 +12,18 @@ class AgentRequestTranslator(MessageTranslator):
|
|||
state=data.get("state", None),
|
||||
group=data.get("group", None),
|
||||
history=data.get("history", []),
|
||||
user=data.get("user", "trustgraph")
|
||||
user=data.get("user", "trustgraph"),
|
||||
streaming=data.get("streaming", False)
|
||||
)
|
||||
|
||||
|
||||
def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]:
    """Convert an AgentRequest into a plain dict.

    Copies question, state, group, history and user from the object, and
    adds "streaming", which falls back to False when the attribute is
    missing.
    """
    copied = ("question", "state", "group", "history", "user")
    result = {name: getattr(obj, name) for name in copied}
    result["streaming"] = getattr(obj, "streaming", False)
    return result
|
||||
|
||||
|
||||
|
|
@ -33,14 +35,36 @@ class AgentResponseTranslator(MessageTranslator):
|
|||
|
||||
def from_pulsar(self, obj: AgentResponse) -> Dict[str, Any]:
    """Convert an AgentResponse into a plain dict.

    Streaming responses (truthy chunk_type) produce chunk_type, content
    (when non-empty), end_of_message and end_of_dialog.  Otherwise the
    legacy answer / thought / observation fields are copied when non-empty.
    An "error" entry is added in either case when the response carries an
    error with a message.
    """
    result = {}

    # Check if this is a streaming response (has chunk_type)
    if getattr(obj, 'chunk_type', None):
        result["chunk_type"] = obj.chunk_type
        if obj.content:
            result["content"] = obj.content
        result["end_of_message"] = getattr(obj, "end_of_message", False)
        result["end_of_dialog"] = getattr(obj, "end_of_dialog", False)
    else:
        # Legacy format
        if obj.answer:
            result["answer"] = obj.answer
        if obj.thought:
            result["thought"] = obj.thought
        if obj.observation:
            result["observation"] = obj.observation

    # Always include error if present
    error = getattr(obj, 'error', None)
    if error and error.message:
        result["error"] = {
            "message": error.message,
            # NOTE(review): other readers of this Error type use
            # `.type` / `.message`; `code` may not exist on it.  Read it
            # defensively so reporting an error cannot itself raise
            # AttributeError -- confirm the intended field.
            "code": getattr(error, "code", None),
        }

    return result
|
||||
|
||||
|
||||
def from_response_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]:
    """Returns (response_dict, is_final)"""
    is_streaming = hasattr(obj, 'chunk_type') and obj.chunk_type

    # Streaming responses finish on end_of_dialog; legacy responses
    # finish once an answer is present.
    is_final = (
        getattr(obj, 'end_of_dialog', False)
        if is_streaming
        else obj.answer is not None
    )

    return self.from_pulsar(obj), is_final
|
||||
|
|
@ -16,10 +16,11 @@ class PromptRequestTranslator(MessageTranslator):
|
|||
k: json.dumps(v)
|
||||
for k, v in data["variables"].items()
|
||||
}
|
||||
|
||||
|
||||
return PromptRequest(
|
||||
id=data.get("id"),
|
||||
terms=terms
|
||||
terms=terms,
|
||||
streaming=data.get("streaming", False)
|
||||
)
|
||||
|
||||
def from_pulsar(self, obj: PromptRequest) -> Dict[str, Any]:
|
||||
|
|
@ -51,4 +52,6 @@ class PromptResponseTranslator(MessageTranslator):
|
|||
|
||||
def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:
    """Returns (response_dict, is_final)"""
    # end_of_stream marks the last message; an object without that
    # attribute is treated as final.
    finished = getattr(obj, 'end_of_stream', True)
    payload = self.from_pulsar(obj)
    return payload, finished
|
||||
|
|
@ -5,11 +5,12 @@ from .base import MessageTranslator
|
|||
|
||||
class TextCompletionRequestTranslator(MessageTranslator):
|
||||
"""Translator for TextCompletionRequest schema objects"""
|
||||
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> TextCompletionRequest:
    """Build a TextCompletionRequest from a dict.

    "system" and "prompt" are required keys (KeyError if missing);
    "streaming" is optional and defaults to False.
    """
    system = data["system"]
    prompt = data["prompt"]
    streaming = data.get("streaming", False)
    return TextCompletionRequest(
        system=system, prompt=prompt, streaming=streaming
    )
|
||||
|
||||
def from_pulsar(self, obj: TextCompletionRequest) -> Dict[str, Any]:
|
||||
|
|
@ -39,4 +40,6 @@ class TextCompletionResponseTranslator(MessageTranslator):
|
|||
|
||||
def from_response_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]:
    """Returns (response_dict, is_final)"""
    # end_of_stream marks the last message; an object without that
    # attribute is treated as final.
    finished = getattr(obj, 'end_of_stream', True)
    payload = self.from_pulsar(obj)
    return payload, finished
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
|
||||
from pulsar.schema import Record, String, Array, Map
|
||||
from pulsar.schema import Record, String, Array, Map, Boolean
|
||||
|
||||
from ..core.topic import topic
|
||||
from ..core.primitives import Error
|
||||
|
|
@ -21,8 +21,16 @@ class AgentRequest(Record):
|
|||
group = Array(String())
|
||||
history = Array(AgentStep())
|
||||
user = String() # User context for multi-tenancy
|
||||
streaming = Boolean() # NEW: Enable streaming response delivery (default false)
|
||||
|
||||
class AgentResponse(Record):
|
||||
# Streaming-first design
|
||||
chunk_type = String() # "thought", "action", "observation", "answer", "error"
|
||||
content = String() # The actual content (interpretation depends on chunk_type)
|
||||
end_of_message = Boolean() # Current chunk type (thought/action/etc.) is complete
|
||||
end_of_dialog = Boolean() # Entire agent dialog is complete
|
||||
|
||||
# Legacy fields (deprecated but kept for backward compatibility)
|
||||
answer = String()
|
||||
error = Error()
|
||||
thought = String()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
|
||||
from pulsar.schema import Record, String, Array, Double, Integer
|
||||
from pulsar.schema import Record, String, Array, Double, Integer, Boolean
|
||||
|
||||
from ..core.topic import topic
|
||||
from ..core.primitives import Error
|
||||
|
|
@ -11,6 +11,7 @@ from ..core.primitives import Error
|
|||
class TextCompletionRequest(Record):
|
||||
system = String()
|
||||
prompt = String()
|
||||
streaming = Boolean() # Default false for backward compatibility
|
||||
|
||||
class TextCompletionResponse(Record):
|
||||
error = Error()
|
||||
|
|
@ -18,6 +19,7 @@ class TextCompletionResponse(Record):
|
|||
in_token = Integer()
|
||||
out_token = Integer()
|
||||
model = String()
|
||||
end_of_stream = Boolean() # Indicates final message in stream
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from pulsar.schema import Record, String, Map
|
||||
from pulsar.schema import Record, String, Map, Boolean
|
||||
|
||||
from ..core.primitives import Error
|
||||
from ..core.topic import topic
|
||||
|
|
@ -24,6 +24,9 @@ class PromptRequest(Record):
|
|||
# JSON encoded values
|
||||
terms = Map(String())
|
||||
|
||||
# Streaming support (default false for backward compatibility)
|
||||
streaming = Boolean()
|
||||
|
||||
class PromptResponse(Record):
|
||||
|
||||
# Error case
|
||||
|
|
@ -35,4 +38,7 @@ class PromptResponse(Record):
|
|||
# JSON encoded
|
||||
object = String()
|
||||
|
||||
# Indicates final message in stream
|
||||
end_of_stream = Boolean()
|
||||
|
||||
############################################################################
|
||||
Loading…
Add table
Add a link
Reference in a new issue