Feature/streaming llm phase 1 (#566)

* Tidy up duplicate tech specs in doc directory

* Streaming LLM text-completion service tech spec.

* text-completion and prompt interfaces

* streaming change applied to all LLMs, so far tested with VertexAI

* Skip Pinecone unit tests, upstream module issue is affecting things, tests are passing again

* Added agent streaming, not working and has broken tests
This commit is contained in:
cybermaggedon 2025-11-26 09:59:10 +00:00 committed by GitHub
parent 943a9d83b0
commit 310a2deb06
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
44 changed files with 2684 additions and 937 deletions

View file

@ -12,7 +12,7 @@ from . parameter_spec import ParameterSpec
from . producer_spec import ProducerSpec
from . subscriber_spec import SubscriberSpec
from . request_response_spec import RequestResponseSpec
from . llm_service import LlmService, LlmResult
from . llm_service import LlmService, LlmResult, LlmChunk
from . chunking_service import ChunkingService
from . embeddings_service import EmbeddingsService
from . embeddings_client import EmbeddingsClientSpec

View file

@ -28,6 +28,19 @@ class LlmResult:
self.model = model
__slots__ = ["text", "in_token", "out_token", "model"]
class LlmChunk:
"""Represents a streaming chunk from an LLM"""
def __init__(
self, text = None, in_token = None, out_token = None,
model = None, is_final = False,
):
self.text = text
self.in_token = in_token
self.out_token = out_token
self.model = model
self.is_final = is_final
__slots__ = ["text", "in_token", "out_token", "model", "is_final"]
class LlmService(FlowProcessor):
def __init__(self, **params):
@ -99,16 +112,57 @@ class LlmService(FlowProcessor):
id = msg.properties()["id"]
with __class__.text_completion_metric.labels(
id=self.id,
flow=f"{flow.name}-{consumer.name}",
).time():
model = flow("model")
temperature = flow("temperature")
model = flow("model")
temperature = flow("temperature")
# Check if streaming is requested and supported
streaming = getattr(request, 'streaming', False)
response = await self.generate_content(
request.system, request.prompt, model, temperature
if streaming and self.supports_streaming():
# Streaming mode
with __class__.text_completion_metric.labels(
id=self.id,
flow=f"{flow.name}-{consumer.name}",
).time():
async for chunk in self.generate_content_stream(
request.system, request.prompt, model, temperature
):
await flow("response").send(
TextCompletionResponse(
error=None,
response=chunk.text,
in_token=chunk.in_token,
out_token=chunk.out_token,
model=chunk.model,
end_of_stream=chunk.is_final
),
properties={"id": id}
)
else:
# Non-streaming mode (original behavior)
with __class__.text_completion_metric.labels(
id=self.id,
flow=f"{flow.name}-{consumer.name}",
).time():
response = await self.generate_content(
request.system, request.prompt, model, temperature
)
await flow("response").send(
TextCompletionResponse(
error=None,
response=response.text,
in_token=response.in_token,
out_token=response.out_token,
model=response.model,
end_of_stream=True
),
properties={"id": id}
)
__class__.text_completion_model_metric.labels(
@ -119,17 +173,6 @@ class LlmService(FlowProcessor):
"temperature": str(temperature) if temperature is not None else "",
})
await flow("response").send(
TextCompletionResponse(
error=None,
response=response.text,
in_token=response.in_token,
out_token=response.out_token,
model=response.model
),
properties={"id": id}
)
except TooManyRequests as e:
raise e
@ -151,10 +194,26 @@ class LlmService(FlowProcessor):
in_token=None,
out_token=None,
model=None,
end_of_stream=True
),
properties={"id": id}
)
def supports_streaming(self):
"""
Override in subclass to indicate streaming support.
Returns False by default.
"""
return False
async def generate_content_stream(self, system, prompt, model=None, temperature=None):
"""
Override in subclass to implement streaming.
Should yield LlmChunk objects.
The final chunk should have is_final=True.
"""
raise NotImplementedError("Streaming not implemented for this provider")
@staticmethod
def add_args(parser):

View file

@ -1,30 +1,75 @@
import json
import asyncio
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import PromptRequest, PromptResponse
class PromptClient(RequestResponse):
async def prompt(self, id, variables, timeout=600):
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
resp = await self.request(
PromptRequest(
id = id,
terms = {
k: json.dumps(v)
for k, v in variables.items()
}
),
timeout=timeout
)
if not streaming:
# Non-streaming path
resp = await self.request(
PromptRequest(
id = id,
terms = {
k: json.dumps(v)
for k, v in variables.items()
},
streaming = False
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
if resp.error:
raise RuntimeError(resp.error.message)
if resp.text: return resp.text
if resp.text: return resp.text
return json.loads(resp.object)
return json.loads(resp.object)
else:
# Streaming path - collect all chunks
full_text = ""
full_object = None
async def collect_chunks(resp):
nonlocal full_text, full_object
if resp.error:
raise RuntimeError(resp.error.message)
if resp.text:
full_text += resp.text
# Call chunk callback if provided
if chunk_callback:
if asyncio.iscoroutinefunction(chunk_callback):
await chunk_callback(resp.text)
else:
chunk_callback(resp.text)
elif resp.object:
full_object = resp.object
return getattr(resp, 'end_of_stream', False)
await self.request(
PromptRequest(
id = id,
terms = {
k: json.dumps(v)
for k, v in variables.items()
},
streaming = True
),
recipient=collect_chunks,
timeout=timeout
)
if full_text: return full_text
return json.loads(full_object)
async def extract_definitions(self, text, timeout=600):
return await self.prompt(
@ -70,11 +115,13 @@ class PromptClient(RequestResponse):
timeout = timeout,
)
async def agent_react(self, variables, timeout=600):
async def agent_react(self, variables, timeout=600, streaming=False, chunk_callback=None):
return await self.prompt(
id = "agent-react",
variables = variables,
timeout = timeout,
streaming = streaming,
chunk_callback = chunk_callback,
)
async def question(self, question, timeout=600):

View file

@ -3,18 +3,45 @@ from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import TextCompletionRequest, TextCompletionResponse
class TextCompletionClient(RequestResponse):
async def text_completion(self, system, prompt, timeout=600):
resp = await self.request(
async def text_completion(self, system, prompt, streaming=False, timeout=600):
# If not streaming, use original behavior
if not streaming:
resp = await self.request(
TextCompletionRequest(
system = system, prompt = prompt, streaming = False
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.response
# For streaming: collect all chunks and return complete response
full_response = ""
async def collect_chunks(resp):
nonlocal full_response
if resp.error:
raise RuntimeError(resp.error.message)
if resp.response:
full_response += resp.response
# Return True when end_of_stream is reached
return getattr(resp, 'end_of_stream', False)
await self.request(
TextCompletionRequest(
system = system, prompt = prompt
system = system, prompt = prompt, streaming = True
),
recipient=collect_chunks,
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.response
return full_response
class TextCompletionClientSpec(RequestResponseSpec):
def __init__(

View file

@ -5,6 +5,7 @@ from .. schema import TextCompletionRequest, TextCompletionResponse
from .. schema import text_completion_request_queue
from .. schema import text_completion_response_queue
from . base import BaseClient
from .. exceptions import LlmError
# Ugly
ERROR=_pulsar.LoggerLevel.Error
@ -37,8 +38,68 @@ class LlmClient(BaseClient):
output_schema=TextCompletionResponse,
)
def request(self, system, prompt, timeout=300):
def request(self, system, prompt, timeout=300, streaming=False):
"""
Non-streaming request (backward compatible).
Returns complete response string.
"""
if streaming:
raise ValueError("Use request_stream() for streaming requests")
return self.call(
system=system, prompt=prompt, timeout=timeout
system=system, prompt=prompt, streaming=False, timeout=timeout
).response
def request_stream(self, system, prompt, timeout=300):
"""
Streaming request generator.
Yields response chunks as they arrive.
Usage:
for chunk in client.request_stream(system, prompt):
print(chunk.response, end='', flush=True)
"""
import time
import uuid
id = str(uuid.uuid4())
request = TextCompletionRequest(
system=system, prompt=prompt, streaming=True
)
end_time = time.time() + timeout
self.producer.send(request, properties={"id": id})
# Collect responses until end_of_stream
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=2500)
except Exception:
continue
mid = msg.properties()["id"]
if mid == id:
value = msg.value()
# Handle errors
if value.error:
self.consumer.acknowledge(msg)
if value.error.type == "llm-error":
raise LlmError(value.error.message)
else:
raise RuntimeError(
f"{value.error.type}: {value.error.message}"
)
self.consumer.acknowledge(msg)
yield value
# Check if this is the final chunk
if getattr(value, 'end_of_stream', True):
break
else:
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
if time.time() >= end_time:
raise TimeoutError("Timed out waiting for response")

View file

@ -12,16 +12,18 @@ class AgentRequestTranslator(MessageTranslator):
state=data.get("state", None),
group=data.get("group", None),
history=data.get("history", []),
user=data.get("user", "trustgraph")
user=data.get("user", "trustgraph"),
streaming=data.get("streaming", False)
)
def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]:
return {
"question": obj.question,
"state": obj.state,
"group": obj.group,
"history": obj.history,
"user": obj.user
"user": obj.user,
"streaming": getattr(obj, "streaming", False)
}
@ -33,14 +35,36 @@ class AgentResponseTranslator(MessageTranslator):
def from_pulsar(self, obj: AgentResponse) -> Dict[str, Any]:
result = {}
if obj.answer:
result["answer"] = obj.answer
if obj.thought:
result["thought"] = obj.thought
if obj.observation:
result["observation"] = obj.observation
# Check if this is a streaming response (has chunk_type)
if hasattr(obj, 'chunk_type') and obj.chunk_type:
result["chunk_type"] = obj.chunk_type
if obj.content:
result["content"] = obj.content
result["end_of_message"] = getattr(obj, "end_of_message", False)
result["end_of_dialog"] = getattr(obj, "end_of_dialog", False)
else:
# Legacy format
if obj.answer:
result["answer"] = obj.answer
if obj.thought:
result["thought"] = obj.thought
if obj.observation:
result["observation"] = obj.observation
# Always include error if present
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "code": obj.error.code}
return result
def from_response_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), (obj.answer is not None)
# For streaming responses, check end_of_dialog
if hasattr(obj, 'chunk_type') and obj.chunk_type:
is_final = getattr(obj, 'end_of_dialog', False)
else:
# For legacy responses, check if answer is present
is_final = (obj.answer is not None)
return self.from_pulsar(obj), is_final

View file

@ -16,10 +16,11 @@ class PromptRequestTranslator(MessageTranslator):
k: json.dumps(v)
for k, v in data["variables"].items()
}
return PromptRequest(
id=data.get("id"),
terms=terms
terms=terms,
streaming=data.get("streaming", False)
)
def from_pulsar(self, obj: PromptRequest) -> Dict[str, Any]:
@ -51,4 +52,6 @@ class PromptResponseTranslator(MessageTranslator):
def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
# Check end_of_stream field to determine if this is the final message
is_final = getattr(obj, 'end_of_stream', True)
return self.from_pulsar(obj), is_final

View file

@ -5,11 +5,12 @@ from .base import MessageTranslator
class TextCompletionRequestTranslator(MessageTranslator):
"""Translator for TextCompletionRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> TextCompletionRequest:
return TextCompletionRequest(
system=data["system"],
prompt=data["prompt"]
prompt=data["prompt"],
streaming=data.get("streaming", False)
)
def from_pulsar(self, obj: TextCompletionRequest) -> Dict[str, Any]:
@ -39,4 +40,6 @@ class TextCompletionResponseTranslator(MessageTranslator):
def from_response_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
# Check end_of_stream field to determine if this is the final message
is_final = getattr(obj, 'end_of_stream', True)
return self.from_pulsar(obj), is_final

View file

@ -1,5 +1,5 @@
from pulsar.schema import Record, String, Array, Map
from pulsar.schema import Record, String, Array, Map, Boolean
from ..core.topic import topic
from ..core.primitives import Error
@ -21,8 +21,16 @@ class AgentRequest(Record):
group = Array(String())
history = Array(AgentStep())
user = String() # User context for multi-tenancy
streaming = Boolean() # NEW: Enable streaming response delivery (default false)
class AgentResponse(Record):
# Streaming-first design
chunk_type = String() # "thought", "action", "observation", "answer", "error"
content = String() # The actual content (interpretation depends on chunk_type)
end_of_message = Boolean() # Current chunk type (thought/action/etc.) is complete
end_of_dialog = Boolean() # Entire agent dialog is complete
# Legacy fields (deprecated but kept for backward compatibility)
answer = String()
error = Error()
thought = String()

View file

@ -1,5 +1,5 @@
from pulsar.schema import Record, String, Array, Double, Integer
from pulsar.schema import Record, String, Array, Double, Integer, Boolean
from ..core.topic import topic
from ..core.primitives import Error
@ -11,6 +11,7 @@ from ..core.primitives import Error
class TextCompletionRequest(Record):
system = String()
prompt = String()
streaming = Boolean() # Default false for backward compatibility
class TextCompletionResponse(Record):
error = Error()
@ -18,6 +19,7 @@ class TextCompletionResponse(Record):
in_token = Integer()
out_token = Integer()
model = String()
end_of_stream = Boolean() # Indicates final message in stream
############################################################################

View file

@ -1,4 +1,4 @@
from pulsar.schema import Record, String, Map
from pulsar.schema import Record, String, Map, Boolean
from ..core.primitives import Error
from ..core.topic import topic
@ -24,6 +24,9 @@ class PromptRequest(Record):
# JSON encoded values
terms = Map(String())
# Streaming support (default false for backward compatibility)
streaming = Boolean()
class PromptResponse(Record):
# Error case
@ -35,4 +38,7 @@ class PromptResponse(Record):
# JSON encoded
object = String()
# Indicates final message in stream
end_of_stream = Boolean()
############################################################################