Add support for json based content types in Message (#480)

This commit is contained in:
Adil Hafeez 2025-05-23 00:51:53 -07:00 committed by GitHub
parent f5e77bbe65
commit 218e9c540d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 314 additions and 121 deletions

View file

@ -9,10 +9,11 @@ use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use serde_json::Value;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{info, warn};
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
@ -30,19 +31,23 @@ pub async fn chat_completions(
let mut request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
let chat_completion_request: ChatCompletionsRequest =
match serde_json::from_slice(&chat_request_bytes) {
Ok(request) => request,
Err(err) => {
let v: Value = serde_json::from_slice(&chat_request_bytes).unwrap();
let err_msg = format!("Failed to parse request body: {}", err);
warn!("{}", err_msg);
warn!("request body: {}", v.to_string());
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
};
info!(
"request body received: {}",
debug!(
"request body: {}",
shorten_string(&serde_json::to_string(&chat_completion_request).unwrap())
);

View file

@ -2,38 +2,27 @@ use brightstaff::handlers::chat_completions::chat_completions;
use brightstaff::router::llm_router::RouterService;
use bytes::Bytes;
use common::configuration::Configuration;
use common::utils::shorten_string;
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use opentelemetry::global::BoxedTracer;
use opentelemetry::trace::FutureExt;
use opentelemetry::{
global,
trace::{SpanKind, Tracer},
Context,
};
use opentelemetry::{global, Context};
use opentelemetry_http::HeaderExtractor;
use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider};
use opentelemetry_stdout::SpanExporter;
use std::sync::{Arc, OnceLock};
use std::sync::Arc;
use std::{env, fs};
use tokio::net::TcpListener;
use tracing::info;
use tracing::{debug, info};
use tracing_subscriber::EnvFilter;
pub mod router;
const BIND_ADDRESS: &str = "0.0.0.0:9091";
fn get_tracer() -> &'static BoxedTracer {
static TRACER: OnceLock<BoxedTracer> = OnceLock::new();
TRACER.get_or_init(|| global::tracer("archgw/router"))
}
// Utility function to extract the context from the incoming request headers
fn extract_context_from_request(req: &Request<Incoming>) -> Context {
global::get_text_map_propagator(|propagator| {
@ -83,24 +72,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let arch_config = Arc::new(config);
info!(
debug!(
"arch_config: {:?}",
shorten_string(&serde_json::to_string(arch_config.as_ref()).unwrap())
&serde_json::to_string(arch_config.as_ref()).unwrap()
);
let llm_provider_endpoint = env::var("LLM_PROVIDER_ENDPOINT")
.unwrap_or_else(|_| "http://localhost:12001/v1/chat/completions".to_string());
info!("llm provider endpoint: {}", llm_provider_endpoint);
info!("Listening on http://{}", bind_address);
info!("listening on http://{}", bind_address);
let listener = TcpListener::bind(bind_address).await?;
// if routing is null then return gpt-4o as model name
let model = arch_config.routing.as_ref().map_or_else(
|| "gpt-4o".to_string(),
|routing| routing.model.clone(),
);
//TODO: fail if routing is null
let model = arch_config
.routing
.as_ref()
.map_or_else(|| "gpt-4o".to_string(), |routing| routing.model.clone());
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
arch_config.llm_providers.clone(),
@ -119,12 +108,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
let parent_cx = extract_context_from_request(&req);
info!("parent_cx: {:?}", parent_cx);
let tracer = get_tracer();
let _span = tracer
.span_builder("request")
.with_kind(SpanKind::Server)
.start_with_context(tracer, &parent_cx);
let llm_provider_endpoint = llm_provider_endpoint.clone();
async move {

View file

@ -1,14 +1,13 @@
use std::sync::Arc;
use common::{
api::open_ai::{ChatCompletionsResponse, Message},
api::open_ai::{ChatCompletionsResponse, ContentType, Message},
configuration::LlmProvider,
consts::ARCH_PROVIDER_HINT_HEADER,
utils::shorten_string,
};
use hyper::header;
use thiserror::Error;
use tracing::{info, warn};
use tracing::{debug, info, warn};
use super::router_model::RouterModel;
@ -59,9 +58,9 @@ impl RouterService {
.collect::<Vec<String>>()
.join("\n");
info!(
debug!(
"llm_providers from config with usage: {}...",
shorten_string(&llm_providers_with_usage_yaml.replace("\n", "\\n"))
llm_providers_with_usage_yaml.replace("\n", "\\n")
);
let router_model = Arc::new(super::router_model_v1::RouterModelV1::new(
@ -83,7 +82,6 @@ impl RouterService {
messages: &[Message],
trace_parent: Option<String>,
) -> Result<Option<String>> {
if !self.llm_usage_defined {
return Ok(None);
}
@ -91,8 +89,14 @@ impl RouterService {
let router_request = self.router_model.generate_request(messages);
info!(
"router_request: {}",
shorten_string(&serde_json::to_string(&router_request).unwrap()),
"sending request to arch-router model: {}, endpoint: {}",
self.router_model.get_model_name(),
self.router_url
);
debug!(
"arch request body: {}",
&serde_json::to_string(&router_request).unwrap(),
);
let mut llm_route_request_headers = header::HeaderMap::new();
@ -113,6 +117,7 @@ impl RouterService {
);
}
let start_time = std::time::Instant::now();
let res = self
.client
.post(&self.router_url)
@ -122,6 +127,7 @@ impl RouterService {
.await?;
let body = res.text().await?;
let router_response_time = start_time.elapsed();
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
Ok(response) => response,
@ -138,14 +144,18 @@ impl RouterService {
}
};
let selected_llm = self.router_model.parse_response(
chat_completion_response.choices[0]
.message
.content
.as_ref()
.unwrap(),
)?;
Ok(selected_llm)
if let Some(ContentType::Text(content)) =
&chat_completion_response.choices[0].message.content
{
info!(
"router response: {}, response time: {}ms",
content.replace("\n", "\\n"),
router_response_time.as_millis()
);
let selected_llm = self.router_model.parse_response(content)?;
Ok(selected_llm)
} else {
Ok(None)
}
}
}

View file

@ -12,4 +12,5 @@ pub type Result<T> = std::result::Result<T, RoutingModelError>;
pub trait RouterModel: Send + Sync {
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest;
fn parse_response(&self, content: &str) -> Result<Option<String>>;
fn get_model_name(&self) -> String;
}

View file

@ -1,9 +1,8 @@
use common::{
api::open_ai::{ChatCompletionsRequest, Message},
api::open_ai::{ChatCompletionsRequest, ContentType, Message},
consts::{SYSTEM_ROLE, USER_ROLE},
};
use serde::{Deserialize, Serialize};
use tracing::info;
use super::router_model::{RouterModel, RoutingModelError};
@ -68,7 +67,7 @@ impl RouterModel for RouterModelV1 {
ChatCompletionsRequest {
model: self.routing_model.clone(),
messages: vec![Message {
content: Some(message),
content: Some(ContentType::Text(message)),
role: USER_ROLE.to_string(),
model: None,
tool_calls: None,
@ -86,10 +85,6 @@ impl RouterModel for RouterModelV1 {
return Ok(None);
}
let router_resp_fixed = fix_json_response(content);
info!(
"router response (fixed): {}",
router_resp_fixed.replace("\n", "\\n")
);
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
let selected_llm = router_response.route.unwrap_or_default().to_string();
@ -100,6 +95,10 @@ impl RouterModel for RouterModelV1 {
Ok(Some(selected_llm))
}
fn get_model_name(&self) -> String {
self.routing_model.clone()
}
}
fn fix_json_response(body: &str) -> String {
@ -172,22 +171,28 @@ user: "seattle"
let messages = vec![
Message {
role: "system".to_string(),
content: Some("You are a helpful assistant.".to_string()),
content: Some(ContentType::Text(
"You are a helpful assistant.".to_string(),
)),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some("Hello, I want to book a flight.".to_string()),
content: Some(ContentType::Text(
"Hello, I want to book a flight.".to_string(),
)),
..Default::default()
},
Message {
role: "assistant".to_string(),
content: Some("Sure, where would you like to go?".to_string()),
content: Some(ContentType::Text(
"Sure, where would you like to go?".to_string(),
)),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some("seattle".to_string()),
content: Some(ContentType::Text("seattle".to_string())),
..Default::default()
},
];
@ -198,7 +203,7 @@ user: "seattle"
println!("Prompt: {}", prompt);
assert_eq!(expected_prompt, prompt);
assert_eq!(expected_prompt, prompt.to_string());
}
}