mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add streaming
This commit is contained in:
parent
4588787427
commit
08471d8adf
5 changed files with 274 additions and 59 deletions
3
.github/workflows/rust_tests.yml
vendored
3
.github/workflows/rust_tests.yml
vendored
|
|
@ -29,6 +29,3 @@ jobs:
|
|||
|
||||
- name: Run unit tests
|
||||
run: cargo test --lib
|
||||
|
||||
# - name: Run integration tests
|
||||
# run: cargo test --test integration
|
||||
|
|
|
|||
|
|
@ -1,19 +1,14 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::api::open_ai::{ChatCompletionsResponse, Choice};
|
||||
use common::configuration::{AgentPipeline, ModelUsagePreference, RoutingPreference};
|
||||
use common::consts::{ARCH_PROVIDER_HINT_HEADER, ARCH_UPSTREAM_HOST_HEADER};
|
||||
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::apis::{Role, Usage};
|
||||
use hermesllm::clients::SupportedAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode, Uri};
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
|
|
@ -30,7 +25,7 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
|||
pub async fn agent_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
_: String,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
|
@ -52,7 +47,6 @@ pub async fn agent_chat(
|
|||
|
||||
info!("Handling request for listener: {}", listener.name);
|
||||
|
||||
let request_path = request.uri().path().to_string();
|
||||
let mut request_headers = request.headers().clone();
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
|
|
@ -163,9 +157,12 @@ pub async fn agent_chat(
|
|||
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
for agent_name in agent_pipeline.filter_chain {
|
||||
let filter_chain_without_terminal_agent =
|
||||
&agent_pipeline.filter_chain[..agent_pipeline.filter_chain.len() - 1];
|
||||
|
||||
for agent_name in filter_chain_without_terminal_agent {
|
||||
debug!("Processing agent: {}", agent_name);
|
||||
let agent = agent_name_map.get(&agent_name).unwrap();
|
||||
let agent = agent_name_map.get(agent_name).unwrap();
|
||||
debug!("Agent details: {:?}", agent);
|
||||
|
||||
let mut request = chat_completions_request.clone();
|
||||
|
|
@ -223,41 +220,88 @@ pub async fn agent_chat(
|
|||
.clone()
|
||||
.unwrap();
|
||||
|
||||
debug!(
|
||||
"Received response from agent {}",
|
||||
agent_name
|
||||
);
|
||||
debug!("Received response from agent {}", agent_name);
|
||||
|
||||
chat_completions_history = serde_json::from_str(response_str.as_str()).unwrap_or(vec![]);
|
||||
}
|
||||
|
||||
let last_response: Option<String> = match chat_completions_history.last() {
|
||||
Some(msg) => Some(msg.content.clone().to_string()),
|
||||
None => None,
|
||||
let terminal_agent_name = agent_pipeline.filter_chain.last().unwrap();
|
||||
let terminal_agent = agent_name_map.get(terminal_agent_name).unwrap();
|
||||
debug!("Processing terminal agent: {}", terminal_agent_name);
|
||||
debug!("Terminal agent details: {:?}", terminal_agent);
|
||||
|
||||
let mut request = chat_completions_request.clone();
|
||||
request.messages = chat_completions_history.clone();
|
||||
|
||||
let request_str = serde_json::to_string(&request).unwrap();
|
||||
debug!("Sending request to agent {}", terminal_agent_name);
|
||||
|
||||
let mut agent_request_headers = request_headers.clone();
|
||||
agent_request_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(terminal_agent.name.as_str()).unwrap(),
|
||||
);
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post("http://localhost:11000/v1/chat/completions")
|
||||
.headers(agent_request_headers)
|
||||
.body(request_str)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to send request: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
let chat_completion_response: hermesllm::apis::openai::ChatCompletionsResponse =
|
||||
hermesllm::apis::openai::ChatCompletionsResponse {
|
||||
model: "arch-agent".to_string(),
|
||||
choices: vec![hermesllm::apis::openai::Choice {
|
||||
message: {
|
||||
hermesllm::apis::openai::ResponseMessage {
|
||||
role: hermesllm::apis::openai::Role::Assistant,
|
||||
content: last_response,
|
||||
..Default::default()
|
||||
}
|
||||
},
|
||||
..Default::default()
|
||||
}],
|
||||
usage: hermesllm::apis::openai::Usage {
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
// copy over the headers from the original response
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let mut response = Response::builder();
|
||||
let headers = response.headers_mut().unwrap();
|
||||
for (header_name, header_value) in response_headers.iter() {
|
||||
headers.insert(header_name, header_value.clone());
|
||||
}
|
||||
|
||||
let response_body = serde_json::to_string(&chat_completion_response).unwrap();
|
||||
// channel to create async stream
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(16);
|
||||
|
||||
return Ok(Response::new(full(response_body)));
|
||||
// Spawn a task to send data as it becomes available
|
||||
tokio::spawn(async move {
|
||||
let mut byte_stream = llm_response.bytes_stream();
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let item = match item {
|
||||
Ok(item) => item,
|
||||
Err(err) => {
|
||||
warn!("Error receiving chunk: {:?}", err);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if tx.send(item).await.is_err() {
|
||||
warn!("Receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
|
||||
let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
||||
match response.body(stream_body) {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to create response: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
Ok(internal_error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_agent_description_to_routing_preferences(
|
||||
|
|
|
|||
|
|
@ -26,3 +26,11 @@ class ChatCompletionResponse(BaseModel):
|
|||
model: str
|
||||
choices: List[Dict[str, Any]]
|
||||
usage: Dict[str, int]
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "chat.completion.chunk"
|
||||
created: int
|
||||
model: str
|
||||
choices: List[Dict[str, Any]]
|
||||
|
|
|
|||
|
|
@ -1,13 +1,19 @@
|
|||
import json
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
import uvicorn
|
||||
import asyncio
|
||||
|
||||
from .api import ChatCompletionRequest, ChatCompletionResponse
|
||||
from .api import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionStreamResponse,
|
||||
)
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
|
@ -17,6 +23,18 @@ logger = logging.getLogger(__name__)
|
|||
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:9000/v1")
|
||||
RESPONSE_MODEL = "gpt-4o"
|
||||
|
||||
# System prompt for response generation
|
||||
SYSTEM_PROMPT = """You are a helpful assistant that generates coherent, contextual responses.
|
||||
|
||||
Given a conversation history, generate a helpful and relevant response based on all the context available in the messages.
|
||||
Your response should:
|
||||
1. Be contextually aware of the entire conversation
|
||||
2. Address the user's needs appropriately
|
||||
3. Be helpful and informative
|
||||
4. Maintain a natural conversational tone
|
||||
|
||||
Generate a complete response to assist the user."""
|
||||
|
||||
# Initialize OpenAI client for archgw
|
||||
archgw_client = AsyncOpenAI(
|
||||
base_url=LLM_GATEWAY_ENDPOINT,
|
||||
|
|
@ -27,6 +45,17 @@ archgw_client = AsyncOpenAI(
|
|||
app = FastAPI(title="RAG Agent Response Generator", version="1.0.0")
|
||||
|
||||
|
||||
def prepare_response_messages(request_body: ChatCompletionRequest):
|
||||
"""Prepare messages for response generation by adding system prompt."""
|
||||
response_messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
||||
|
||||
# Add conversation history
|
||||
for msg in request_body.messages:
|
||||
response_messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
return response_messages
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request_body: ChatCompletionRequest, request: Request):
|
||||
"""Chat completions endpoint that generates a coherent response based on all context."""
|
||||
|
|
@ -41,24 +70,121 @@ async def chat_completions(request_body: ChatCompletionRequest, request: Request
|
|||
else:
|
||||
logger.info("No traceparent header found")
|
||||
|
||||
# Prepare the system prompt for response generation
|
||||
system_prompt = """You are a helpful assistant that generates coherent, contextual responses.
|
||||
# Check if streaming is requested
|
||||
if request_body.stream:
|
||||
return StreamingResponse(
|
||||
stream_chat_completions(request_body, traceparent_header),
|
||||
media_type="text/plain",
|
||||
)
|
||||
else:
|
||||
return await non_streaming_chat_completions(request_body, traceparent_header)
|
||||
|
||||
Given a conversation history, generate a helpful and relevant response based on all the context available in the messages.
|
||||
Your response should:
|
||||
1. Be contextually aware of the entire conversation
|
||||
2. Address the user's needs appropriately
|
||||
3. Be helpful and informative
|
||||
4. Maintain a natural conversational tone
|
||||
|
||||
Generate a complete response to assist the user."""
|
||||
|
||||
async def stream_chat_completions(
|
||||
request_body: ChatCompletionRequest, traceparent_header: str = None
|
||||
):
|
||||
"""Generate streaming chat completions."""
|
||||
# Prepare messages for response generation
|
||||
response_messages = [{"role": "system", "content": system_prompt}]
|
||||
response_messages = prepare_response_messages(request_body)
|
||||
|
||||
# Add conversation history
|
||||
for msg in request_body.messages:
|
||||
response_messages.append({"role": msg.role, "content": msg.content})
|
||||
try:
|
||||
# Call archgw using OpenAI client for streaming
|
||||
logger.info(
|
||||
f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to generate streaming response"
|
||||
)
|
||||
|
||||
# Prepare extra headers if traceparent is provided
|
||||
extra_headers = {}
|
||||
if traceparent_header:
|
||||
extra_headers["traceparent"] = traceparent_header
|
||||
|
||||
response_stream = await archgw_client.chat.completions.create(
|
||||
model=RESPONSE_MODEL,
|
||||
messages=response_messages,
|
||||
temperature=request_body.temperature or 0.7,
|
||||
max_tokens=request_body.max_tokens or 1000,
|
||||
stream=True,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
|
||||
created_time = int(time.time())
|
||||
collected_content = []
|
||||
|
||||
async for chunk in response_stream:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
collected_content.append(content)
|
||||
|
||||
# Create streaming response chunk
|
||||
stream_chunk = ChatCompletionStreamResponse(
|
||||
id=completion_id,
|
||||
created=created_time,
|
||||
model=request_body.model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": content},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
yield f"data: {stream_chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Send final chunk with complete response in expected format
|
||||
full_response = "".join(collected_content)
|
||||
updated_history = [{"role": "assistant", "content": full_response}]
|
||||
|
||||
final_chunk = ChatCompletionStreamResponse(
|
||||
id=completion_id,
|
||||
created=created_time,
|
||||
model=request_body.model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": "stop",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": json.dumps(updated_history),
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
yield f"data: {final_chunk.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating streaming response: {e}")
|
||||
|
||||
# Send error as streaming response
|
||||
error_chunk = ChatCompletionStreamResponse(
|
||||
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
||||
created=int(time.time()),
|
||||
model=request_body.model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": "I apologize, but I'm having trouble generating a response right now. Please try again."
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
yield f"data: {error_chunk.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
async def non_streaming_chat_completions(
|
||||
request_body: ChatCompletionRequest, traceparent_header: str = None
|
||||
):
|
||||
"""Generate non-streaming chat completions."""
|
||||
# Prepare messages for response generation
|
||||
response_messages = prepare_response_messages(request_body)
|
||||
|
||||
try:
|
||||
# Call archgw using OpenAI client
|
||||
|
|
@ -116,7 +242,7 @@ async def chat_completions(request_body: ChatCompletionRequest, request: Request
|
|||
return ChatCompletionResponse(
|
||||
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
||||
created=int(time.time()),
|
||||
model=request.model,
|
||||
model=request_body.model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
|
|
@ -126,11 +252,11 @@ async def chat_completions(request_body: ChatCompletionRequest, request: Request
|
|||
],
|
||||
usage={
|
||||
"prompt_tokens": sum(
|
||||
len(msg.content.split()) for msg in request.messages
|
||||
len(msg.content.split()) for msg in request_body.messages
|
||||
),
|
||||
"completion_tokens": len(fallback_message.split()),
|
||||
"total_tokens": sum(
|
||||
len(msg.content.split()) for msg in request.messages
|
||||
len(msg.content.split()) for msg in request_body.messages
|
||||
)
|
||||
+ len(fallback_message.split()),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
@baseUrl = http://0.0.0.0:10502
|
||||
@model = gpt-4o
|
||||
|
||||
###
|
||||
|
||||
# Health Check
|
||||
GET {{baseUrl}}/health
|
||||
|
||||
###
|
||||
|
||||
# Test 1: Simple Non-Streaming Chat Completion
|
||||
POST {{baseUrl}}/v1/chat/completions
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "{{model}}",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello! Can you help me understand what machine learning is?"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
###
|
||||
|
||||
# Test 2: Simple Streaming Chat Completion
|
||||
POST {{baseUrl}}/v1/chat/completions
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "{{model}}",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Explain the concept of artificial intelligence in simple terms."
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue