add streaming

This commit is contained in:
Adil Hafeez 2025-09-17 09:39:10 -07:00
parent 4588787427
commit 08471d8adf
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
5 changed files with 274 additions and 59 deletions

View file

@ -29,6 +29,3 @@ jobs:
- name: Run unit tests
run: cargo test --lib
# - name: Run integration tests
# run: cargo test --test integration

View file

@ -1,19 +1,14 @@
use std::sync::Arc;
use bytes::Bytes;
use common::api::open_ai::{ChatCompletionsResponse, Choice};
use common::configuration::{AgentPipeline, ModelUsagePreference, RoutingPreference};
use common::consts::{ARCH_PROVIDER_HINT_HEADER, ARCH_UPSTREAM_HOST_HEADER};
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use hermesllm::apis::openai::ChatCompletionsRequest;
use hermesllm::apis::{Role, Usage};
use hermesllm::clients::SupportedAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::header::{self};
use hyper::{Request, Response, StatusCode, Uri};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use hyper::{Request, Response, StatusCode};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
@ -30,7 +25,7 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
pub async fn agent_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
_: String,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
@ -52,7 +47,6 @@ pub async fn agent_chat(
info!("Handling request for listener: {}", listener.name);
let request_path = request.uri().path().to_string();
let mut request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
@ -163,9 +157,12 @@ pub async fn agent_chat(
request_headers.remove(header::CONTENT_LENGTH);
for agent_name in agent_pipeline.filter_chain {
let filter_chain_without_terminal_agent =
&agent_pipeline.filter_chain[..agent_pipeline.filter_chain.len() - 1];
for agent_name in filter_chain_without_terminal_agent {
debug!("Processing agent: {}", agent_name);
let agent = agent_name_map.get(&agent_name).unwrap();
let agent = agent_name_map.get(agent_name).unwrap();
debug!("Agent details: {:?}", agent);
let mut request = chat_completions_request.clone();
@ -223,41 +220,88 @@ pub async fn agent_chat(
.clone()
.unwrap();
debug!(
"Received response from agent {}",
agent_name
);
debug!("Received response from agent {}", agent_name);
chat_completions_history = serde_json::from_str(response_str.as_str()).unwrap_or(vec![]);
}
let last_response: Option<String> = match chat_completions_history.last() {
Some(msg) => Some(msg.content.clone().to_string()),
None => None,
let terminal_agent_name = agent_pipeline.filter_chain.last().unwrap();
let terminal_agent = agent_name_map.get(terminal_agent_name).unwrap();
debug!("Processing terminal agent: {}", terminal_agent_name);
debug!("Terminal agent details: {:?}", terminal_agent);
let mut request = chat_completions_request.clone();
request.messages = chat_completions_history.clone();
let request_str = serde_json::to_string(&request).unwrap();
debug!("Sending request to agent {}", terminal_agent_name);
let mut agent_request_headers = request_headers.clone();
agent_request_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(terminal_agent.name.as_str()).unwrap(),
);
let llm_response = match reqwest::Client::new()
.post("http://localhost:11000/v1/chat/completions")
.headers(agent_request_headers)
.body(request_str)
.send()
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let chat_completion_response: hermesllm::apis::openai::ChatCompletionsResponse =
hermesllm::apis::openai::ChatCompletionsResponse {
model: "arch-agent".to_string(),
choices: vec![hermesllm::apis::openai::Choice {
message: {
hermesllm::apis::openai::ResponseMessage {
role: hermesllm::apis::openai::Role::Assistant,
content: last_response,
..Default::default()
}
},
..Default::default()
}],
usage: hermesllm::apis::openai::Usage {
..Default::default()
},
..Default::default()
};
// copy over the headers from the original response
let response_headers = llm_response.headers().clone();
let mut response = Response::builder();
let headers = response.headers_mut().unwrap();
for (header_name, header_value) in response_headers.iter() {
headers.insert(header_name, header_value.clone());
}
let response_body = serde_json::to_string(&chat_completion_response).unwrap();
// channel to create async stream
let (tx, rx) = mpsc::channel::<Bytes>(16);
return Ok(Response::new(full(response_body)));
// Spawn a task to send data as it becomes available
tokio::spawn(async move {
let mut byte_stream = llm_response.bytes_stream();
while let Some(item) = byte_stream.next().await {
let item = match item {
Ok(item) => item,
Err(err) => {
warn!("Error receiving chunk: {:?}", err);
break;
}
};
if tx.send(item).await.is_err() {
warn!("Receiver dropped");
break;
}
}
});
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
match response.body(stream_body) {
Ok(response) => Ok(response),
Err(err) => {
let err_msg = format!("Failed to create response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Ok(internal_error)
}
}
}
fn convert_agent_description_to_routing_preferences(

View file

@ -26,3 +26,11 @@ class ChatCompletionResponse(BaseModel):
model: str
choices: List[Dict[str, Any]]
usage: Dict[str, int]
class ChatCompletionStreamResponse(BaseModel):
    # Schema for one chunk of a streamed chat completion.
    # Plain comments are used instead of a docstring so the generated
    # model schema stays exactly as it was.

    # Identifier of the completion this chunk belongs to.
    id: str
    # Object-type marker; defaults to the streaming-chunk value.
    object: str = "chat.completion.chunk"
    # Creation timestamp as an integer.
    created: int
    # Model name reported back to the caller.
    model: str
    # List of choice dictionaries carried by this chunk.
    choices: List[Dict[str, Any]]

View file

@ -1,13 +1,19 @@
import json
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from openai import AsyncOpenAI
import os
import logging
import time
import uuid
import uvicorn
import asyncio
from .api import ChatCompletionRequest, ChatCompletionResponse
from .api import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
)
# Set up logging
logging.basicConfig(level=logging.INFO)
@ -17,6 +23,18 @@ logger = logging.getLogger(__name__)
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:9000/v1")
RESPONSE_MODEL = "gpt-4o"
# System prompt for response generation
SYSTEM_PROMPT = """You are a helpful assistant that generates coherent, contextual responses.
Given a conversation history, generate a helpful and relevant response based on all the context available in the messages.
Your response should:
1. Be contextually aware of the entire conversation
2. Address the user's needs appropriately
3. Be helpful and informative
4. Maintain a natural conversational tone
Generate a complete response to assist the user."""
# Initialize OpenAI client for archgw
archgw_client = AsyncOpenAI(
base_url=LLM_GATEWAY_ENDPOINT,
@ -27,6 +45,17 @@ archgw_client = AsyncOpenAI(
app = FastAPI(title="RAG Agent Response Generator", version="1.0.0")
def prepare_response_messages(request_body: ChatCompletionRequest):
    """Build the message list sent to the model.

    The result starts with a single system message holding SYSTEM_PROMPT and
    is followed by one {"role", "content"} dictionary per message in
    ``request_body.messages``, in their original order.
    """
    system_entry = {"role": "system", "content": SYSTEM_PROMPT}
    history = [
        {"role": entry.role, "content": entry.content}
        for entry in request_body.messages
    ]
    return [system_entry, *history]
@app.post("/v1/chat/completions")
async def chat_completions(request_body: ChatCompletionRequest, request: Request):
"""Chat completions endpoint that generates a coherent response based on all context."""
@ -41,24 +70,121 @@ async def chat_completions(request_body: ChatCompletionRequest, request: Request
else:
logger.info("No traceparent header found")
# Prepare the system prompt for response generation
system_prompt = """You are a helpful assistant that generates coherent, contextual responses.
# Check if streaming is requested
if request_body.stream:
return StreamingResponse(
stream_chat_completions(request_body, traceparent_header),
media_type="text/plain",
)
else:
return await non_streaming_chat_completions(request_body, traceparent_header)
Given a conversation history, generate a helpful and relevant response based on all the context available in the messages.
Your response should:
1. Be contextually aware of the entire conversation
2. Address the user's needs appropriately
3. Be helpful and informative
4. Maintain a natural conversational tone
Generate a complete response to assist the user."""
async def stream_chat_completions(
    request_body: ChatCompletionRequest, traceparent_header: str = None
):
    """Generate streaming chat completions.

    Async generator that calls the gateway with ``stream=True`` and yields one
    ``data: <json>`` line (followed by a blank line) per content delta, then a
    final chunk with ``finish_reason == "stop"`` and a ``data: [DONE]`` marker.

    Args:
        request_body: Incoming chat request; its messages, model, temperature
            and max_tokens are read here.
        traceparent_header: Optional trace header value; forwarded to the
            gateway as the ``traceparent`` header when truthy.

    Yields:
        str: Serialized ``ChatCompletionStreamResponse`` chunks in
        ``data: ...`` framing. Any exception is logged and reported to the
        client as a single apology chunk followed by ``[DONE]``; nothing is
        re-raised.
    """
    # Build the prompt exactly once. prepare_response_messages already adds the
    # system prompt and the full history, so appending the history again here
    # would send every message twice.
    response_messages = prepare_response_messages(request_body)

    # Use an explicit None check: `temperature or 0.7` would silently replace a
    # caller-supplied temperature of 0 with 0.7.
    temperature = (
        request_body.temperature if request_body.temperature is not None else 0.7
    )

    try:
        # Call archgw using OpenAI client for streaming
        logger.info(
            f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to generate streaming response"
        )

        # Prepare extra headers if traceparent is provided
        extra_headers = {}
        if traceparent_header:
            extra_headers["traceparent"] = traceparent_header

        response_stream = await archgw_client.chat.completions.create(
            model=RESPONSE_MODEL,
            messages=response_messages,
            temperature=temperature,
            max_tokens=request_body.max_tokens or 1000,
            stream=True,
            extra_headers=extra_headers,
        )

        # The same id and timestamp are reused for every chunk of this reply.
        completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
        created_time = int(time.time())

        collected_content = []

        async for chunk in response_stream:
            # Skip chunks that carry no choices or no text delta.
            if chunk.choices and chunk.choices[0].delta.content:
                content = chunk.choices[0].delta.content
                collected_content.append(content)

                # Create streaming response chunk
                stream_chunk = ChatCompletionStreamResponse(
                    id=completion_id,
                    created=created_time,
                    model=request_body.model,
                    choices=[
                        {
                            "index": 0,
                            "delta": {"content": content},
                            "finish_reason": None,
                        }
                    ],
                )

                yield f"data: {stream_chunk.model_dump_json()}\n\n"

        # Send final chunk with complete response in expected format
        full_response = "".join(collected_content)
        updated_history = [{"role": "assistant", "content": full_response}]

        final_chunk = ChatCompletionStreamResponse(
            id=completion_id,
            created=created_time,
            model=request_body.model,
            choices=[
                {
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "content": json.dumps(updated_history),
                    },
                }
            ],
        )

        yield f"data: {final_chunk.model_dump_json()}\n\n"
        yield "data: [DONE]\n\n"

    except Exception as e:
        logger.error(f"Error generating streaming response: {e}")

        # Send error as streaming response
        error_chunk = ChatCompletionStreamResponse(
            id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
            created=int(time.time()),
            model=request_body.model,
            choices=[
                {
                    "index": 0,
                    "delta": {
                        "content": "I apologize, but I'm having trouble generating a response right now. Please try again."
                    },
                    "finish_reason": "stop",
                }
            ],
        )

        yield f"data: {error_chunk.model_dump_json()}\n\n"
        yield "data: [DONE]\n\n"
async def non_streaming_chat_completions(
request_body: ChatCompletionRequest, traceparent_header: str = None
):
"""Generate non-streaming chat completions."""
# Prepare messages for response generation
response_messages = prepare_response_messages(request_body)
try:
# Call archgw using OpenAI client
@ -116,7 +242,7 @@ async def chat_completions(request_body: ChatCompletionRequest, request: Request
return ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
created=int(time.time()),
model=request.model,
model=request_body.model,
choices=[
{
"index": 0,
@ -126,11 +252,11 @@ async def chat_completions(request_body: ChatCompletionRequest, request: Request
],
usage={
"prompt_tokens": sum(
len(msg.content.split()) for msg in request.messages
len(msg.content.split()) for msg in request_body.messages
),
"completion_tokens": len(fallback_message.split()),
"total_tokens": sum(
len(msg.content.split()) for msg in request.messages
len(msg.content.split()) for msg in request_body.messages
)
+ len(fallback_message.split()),
},

View file

@ -0,0 +1,40 @@
# Manual smoke tests for the chat completions service.
# Each request is delimited by a `###` line; `{{baseUrl}}` and `{{model}}`
# are substituted from the variables declared below.
@baseUrl = http://0.0.0.0:10502
@model = gpt-4o

###
# Health Check
GET {{baseUrl}}/health

###
# Test 1: Simple Non-Streaming Chat Completion
# (no "stream" field is sent in this request)
POST {{baseUrl}}/v1/chat/completions
Content-Type: application/json

{
  "model": "{{model}}",
  "messages": [
    {
      "role": "user",
      "content": "Hello! Can you help me understand what machine learning is?"
    }
  ]
}

###
# Test 2: Simple Streaming Chat Completion
# ("stream": true requests the streaming response path)
POST {{baseUrl}}/v1/chat/completions
Content-Type: application/json

{
  "model": "{{model}}",
  "messages": [
    {
      "role": "user",
      "content": "Explain the concept of artificial intelligence in simple terms."
    }
  ],
  "stream": true
}