mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Add support for json based content types in Message (#480)
This commit is contained in:
parent
f5e77bbe65
commit
218e9c540d
16 changed files with 314 additions and 121 deletions
|
|
@ -80,7 +80,17 @@ def validate_and_render_schema():
|
|||
llms_with_endpoint = []
|
||||
|
||||
updated_llm_providers = []
|
||||
llm_provider_name_set = set()
|
||||
for llm_provider in config_yaml["llm_providers"]:
|
||||
if llm_provider.get("name") in llm_provider_name_set:
|
||||
raise Exception(
|
||||
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
|
||||
)
|
||||
if llm_provider.get("name") is None:
|
||||
raise Exception(
|
||||
f"llm_provider name is required, please provide name for llm_provider"
|
||||
)
|
||||
llm_provider_name_set.add(llm_provider.get("name"))
|
||||
provider = None
|
||||
if llm_provider.get("provider") and llm_provider.get("provider_interface"):
|
||||
raise Exception(
|
||||
|
|
|
|||
|
|
@ -58,7 +58,6 @@ def docker_start_archgw_detached(
|
|||
|
||||
volume_mappings = [
|
||||
f"{arch_config_file}:/app/arch_config.yaml:ro",
|
||||
# "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro",
|
||||
]
|
||||
volume_mappings_args = [
|
||||
item for volume in volume_mappings for item in ("-v", volume)
|
||||
|
|
|
|||
|
|
@ -9,10 +9,11 @@ use http_body_util::{BodyExt, Full, StreamBody};
|
|||
use hyper::body::Frame;
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{info, warn};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
|
||||
|
|
@ -30,19 +31,23 @@ pub async fn chat_completions(
|
|||
let mut request_headers = request.headers().clone();
|
||||
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
let chat_completion_request: ChatCompletionsRequest =
|
||||
match serde_json::from_slice(&chat_request_bytes) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
let v: Value = serde_json::from_slice(&chat_request_bytes).unwrap();
|
||||
let err_msg = format!("Failed to parse request body: {}", err);
|
||||
warn!("{}", err_msg);
|
||||
warn!("request body: {}", v.to_string());
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
};
|
||||
|
||||
info!(
|
||||
"request body received: {}",
|
||||
debug!(
|
||||
"request body: {}",
|
||||
shorten_string(&serde_json::to_string(&chat_completion_request).unwrap())
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -2,38 +2,27 @@ use brightstaff::handlers::chat_completions::chat_completions;
|
|||
use brightstaff::router::llm_router::RouterService;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::Configuration;
|
||||
use common::utils::shorten_string;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::server::conn::http1;
|
||||
use hyper::service::service_fn;
|
||||
use hyper::{Method, Request, Response, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use opentelemetry::global::BoxedTracer;
|
||||
use opentelemetry::trace::FutureExt;
|
||||
use opentelemetry::{
|
||||
global,
|
||||
trace::{SpanKind, Tracer},
|
||||
Context,
|
||||
};
|
||||
use opentelemetry::{global, Context};
|
||||
use opentelemetry_http::HeaderExtractor;
|
||||
use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider};
|
||||
use opentelemetry_stdout::SpanExporter;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::sync::Arc;
|
||||
use std::{env, fs};
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::info;
|
||||
use tracing::{debug, info};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
pub mod router;
|
||||
|
||||
const BIND_ADDRESS: &str = "0.0.0.0:9091";
|
||||
|
||||
fn get_tracer() -> &'static BoxedTracer {
|
||||
static TRACER: OnceLock<BoxedTracer> = OnceLock::new();
|
||||
TRACER.get_or_init(|| global::tracer("archgw/router"))
|
||||
}
|
||||
|
||||
// Utility function to extract the context from the incoming request headers
|
||||
fn extract_context_from_request(req: &Request<Incoming>) -> Context {
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
|
|
@ -83,24 +72,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let arch_config = Arc::new(config);
|
||||
|
||||
info!(
|
||||
debug!(
|
||||
"arch_config: {:?}",
|
||||
shorten_string(&serde_json::to_string(arch_config.as_ref()).unwrap())
|
||||
&serde_json::to_string(arch_config.as_ref()).unwrap()
|
||||
);
|
||||
|
||||
let llm_provider_endpoint = env::var("LLM_PROVIDER_ENDPOINT")
|
||||
.unwrap_or_else(|_| "http://localhost:12001/v1/chat/completions".to_string());
|
||||
|
||||
info!("llm provider endpoint: {}", llm_provider_endpoint);
|
||||
info!("Listening on http://{}", bind_address);
|
||||
info!("listening on http://{}", bind_address);
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
|
||||
|
||||
// if routing is null then return gpt-4o as model name
|
||||
let model = arch_config.routing.as_ref().map_or_else(
|
||||
|| "gpt-4o".to_string(),
|
||||
|routing| routing.model.clone(),
|
||||
);
|
||||
//TODO: fail if routing is null
|
||||
let model = arch_config
|
||||
.routing
|
||||
.as_ref()
|
||||
.map_or_else(|| "gpt-4o".to_string(), |routing| routing.model.clone());
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||
arch_config.llm_providers.clone(),
|
||||
|
|
@ -119,12 +108,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let service = service_fn(move |req| {
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let parent_cx = extract_context_from_request(&req);
|
||||
info!("parent_cx: {:?}", parent_cx);
|
||||
let tracer = get_tracer();
|
||||
let _span = tracer
|
||||
.span_builder("request")
|
||||
.with_kind(SpanKind::Server)
|
||||
.start_with_context(tracer, &parent_cx);
|
||||
let llm_provider_endpoint = llm_provider_endpoint.clone();
|
||||
|
||||
async move {
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use common::{
|
||||
api::open_ai::{ChatCompletionsResponse, Message},
|
||||
api::open_ai::{ChatCompletionsResponse, ContentType, Message},
|
||||
configuration::LlmProvider,
|
||||
consts::ARCH_PROVIDER_HINT_HEADER,
|
||||
utils::shorten_string,
|
||||
};
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
use tracing::{info, warn};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::router_model::RouterModel;
|
||||
|
||||
|
|
@ -59,9 +58,9 @@ impl RouterService {
|
|||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
info!(
|
||||
debug!(
|
||||
"llm_providers from config with usage: {}...",
|
||||
shorten_string(&llm_providers_with_usage_yaml.replace("\n", "\\n"))
|
||||
llm_providers_with_usage_yaml.replace("\n", "\\n")
|
||||
);
|
||||
|
||||
let router_model = Arc::new(super::router_model_v1::RouterModelV1::new(
|
||||
|
|
@ -83,7 +82,6 @@ impl RouterService {
|
|||
messages: &[Message],
|
||||
trace_parent: Option<String>,
|
||||
) -> Result<Option<String>> {
|
||||
|
||||
if !self.llm_usage_defined {
|
||||
return Ok(None);
|
||||
}
|
||||
|
|
@ -91,8 +89,14 @@ impl RouterService {
|
|||
let router_request = self.router_model.generate_request(messages);
|
||||
|
||||
info!(
|
||||
"router_request: {}",
|
||||
shorten_string(&serde_json::to_string(&router_request).unwrap()),
|
||||
"sending request to arch-router model: {}, endpoint: {}",
|
||||
self.router_model.get_model_name(),
|
||||
self.router_url
|
||||
);
|
||||
|
||||
debug!(
|
||||
"arch request body: {}",
|
||||
&serde_json::to_string(&router_request).unwrap(),
|
||||
);
|
||||
|
||||
let mut llm_route_request_headers = header::HeaderMap::new();
|
||||
|
|
@ -113,6 +117,7 @@ impl RouterService {
|
|||
);
|
||||
}
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let res = self
|
||||
.client
|
||||
.post(&self.router_url)
|
||||
|
|
@ -122,6 +127,7 @@ impl RouterService {
|
|||
.await?;
|
||||
|
||||
let body = res.text().await?;
|
||||
let router_response_time = start_time.elapsed();
|
||||
|
||||
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
|
||||
Ok(response) => response,
|
||||
|
|
@ -138,14 +144,18 @@ impl RouterService {
|
|||
}
|
||||
};
|
||||
|
||||
let selected_llm = self.router_model.parse_response(
|
||||
chat_completion_response.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
)?;
|
||||
|
||||
Ok(selected_llm)
|
||||
if let Some(ContentType::Text(content)) =
|
||||
&chat_completion_response.choices[0].message.content
|
||||
{
|
||||
info!(
|
||||
"router response: {}, response time: {}ms",
|
||||
content.replace("\n", "\\n"),
|
||||
router_response_time.as_millis()
|
||||
);
|
||||
let selected_llm = self.router_model.parse_response(content)?;
|
||||
Ok(selected_llm)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,4 +12,5 @@ pub type Result<T> = std::result::Result<T, RoutingModelError>;
|
|||
pub trait RouterModel: Send + Sync {
|
||||
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest;
|
||||
fn parse_response(&self, content: &str) -> Result<Option<String>>;
|
||||
fn get_model_name(&self) -> String;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
use common::{
|
||||
api::open_ai::{ChatCompletionsRequest, Message},
|
||||
api::open_ai::{ChatCompletionsRequest, ContentType, Message},
|
||||
consts::{SYSTEM_ROLE, USER_ROLE},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
|
||||
use super::router_model::{RouterModel, RoutingModelError};
|
||||
|
||||
|
|
@ -68,7 +67,7 @@ impl RouterModel for RouterModelV1 {
|
|||
ChatCompletionsRequest {
|
||||
model: self.routing_model.clone(),
|
||||
messages: vec![Message {
|
||||
content: Some(message),
|
||||
content: Some(ContentType::Text(message)),
|
||||
role: USER_ROLE.to_string(),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -86,10 +85,6 @@ impl RouterModel for RouterModelV1 {
|
|||
return Ok(None);
|
||||
}
|
||||
let router_resp_fixed = fix_json_response(content);
|
||||
info!(
|
||||
"router response (fixed): {}",
|
||||
router_resp_fixed.replace("\n", "\\n")
|
||||
);
|
||||
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
|
||||
|
||||
let selected_llm = router_response.route.unwrap_or_default().to_string();
|
||||
|
|
@ -100,6 +95,10 @@ impl RouterModel for RouterModelV1 {
|
|||
|
||||
Ok(Some(selected_llm))
|
||||
}
|
||||
|
||||
fn get_model_name(&self) -> String {
|
||||
self.routing_model.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn fix_json_response(body: &str) -> String {
|
||||
|
|
@ -172,22 +171,28 @@ user: "seattle"
|
|||
let messages = vec![
|
||||
Message {
|
||||
role: "system".to_string(),
|
||||
content: Some("You are a helpful assistant.".to_string()),
|
||||
content: Some(ContentType::Text(
|
||||
"You are a helpful assistant.".to_string(),
|
||||
)),
|
||||
..Default::default()
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("Hello, I want to book a flight.".to_string()),
|
||||
content: Some(ContentType::Text(
|
||||
"Hello, I want to book a flight.".to_string(),
|
||||
)),
|
||||
..Default::default()
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Sure, where would you like to go?".to_string()),
|
||||
content: Some(ContentType::Text(
|
||||
"Sure, where would you like to go?".to_string(),
|
||||
)),
|
||||
..Default::default()
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("seattle".to_string()),
|
||||
content: Some(ContentType::Text("seattle".to_string())),
|
||||
..Default::default()
|
||||
},
|
||||
];
|
||||
|
|
@ -198,7 +203,7 @@ user: "seattle"
|
|||
|
||||
println!("Prompt: {}", prompt);
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
assert_eq!(expected_prompt, prompt.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ use crate::{
|
|||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::open_ai::ContentType;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationRequest {
|
||||
pub prompt: String,
|
||||
|
|
@ -21,7 +23,7 @@ pub struct HallucinationClassificationResponse {
|
|||
|
||||
pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
|
||||
let mut arch_assistant = false;
|
||||
let mut user_messages = Vec::new();
|
||||
let mut user_messages: Vec<String> = Vec::new();
|
||||
if messages.len() >= 2 {
|
||||
let latest_assistant_message = &messages[messages.len() - 2];
|
||||
if let Some(model) = latest_assistant_message.model.as_ref() {
|
||||
|
|
@ -35,7 +37,7 @@ pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
|
|||
if let Some(model) = message.model.as_ref() {
|
||||
if !model.starts_with(ARCH_MODEL_PREFIX) {
|
||||
if let Some(content) = &message.content {
|
||||
if !content.starts_with(HALLUCINATION_TEMPLATE) {
|
||||
if !content.to_string().starts_with(HALLUCINATION_TEMPLATE) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -43,13 +45,13 @@ pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
|
|||
}
|
||||
if message.role == USER_ROLE {
|
||||
if let Some(content) = &message.content {
|
||||
user_messages.push(content.clone());
|
||||
user_messages.push(content.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if let Some(message) = messages.last() {
|
||||
if let Some(content) = &message.content {
|
||||
user_messages.push(content.clone());
|
||||
user_messages.push(content.to_string());
|
||||
}
|
||||
}
|
||||
user_messages.reverse(); // Reverse to maintain the original order
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use serde_yaml::Value;
|
||||
use core::panic;
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
fmt::Display,
|
||||
|
|
@ -154,12 +155,54 @@ pub struct StreamOptions {
|
|||
pub include_usage: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum MultiPartContentType {
|
||||
#[serde(rename = "text")]
|
||||
Text,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct MultiPartContent {
|
||||
pub text: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
pub content_type: MultiPartContentType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum ContentType {
|
||||
Text(String),
|
||||
MultiPart(Vec<MultiPartContent>),
|
||||
}
|
||||
|
||||
impl Display for ContentType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ContentType::Text(text) => write!(f, "{}", text),
|
||||
ContentType::MultiPart(multi_part) => {
|
||||
let text_parts: Vec<String> = multi_part
|
||||
.iter()
|
||||
.filter_map(|part| {
|
||||
if part.content_type == MultiPartContentType::Text {
|
||||
part.text.clone()
|
||||
} else {
|
||||
panic!("Unsupported content type: {:?}", part.content_type);
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let combined_text = text_parts.join("\n");
|
||||
write!(f, "{}", combined_text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
pub content: Option<ContentType>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
|
|
@ -237,7 +280,7 @@ impl ChatCompletionsResponse {
|
|||
choices: vec![Choice {
|
||||
message: Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: Some(message),
|
||||
content: Some(ContentType::Text(message)),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -379,6 +422,8 @@ pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::api::open_ai::{ChatCompletionsRequest, ContentType, MultiPartContentType};
|
||||
|
||||
use super::{ChatCompletionStreamResponseServerEvents, Message};
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::collections::HashMap;
|
||||
|
|
@ -448,7 +493,9 @@ mod test {
|
|||
model: "gpt-3.5-turbo".to_string(),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What city do you want to know the weather for?".to_string()),
|
||||
content: Some(ContentType::Text(
|
||||
"What city do you want to know the weather for?".to_string(),
|
||||
)),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -679,6 +726,111 @@ data: [DONE]
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
fn test_chat_completions_request() {
    // A message whose `content` is a bare JSON string must deserialize into
    // `ContentType::Text`.
    const CHAT_COMPLETIONS_REQUEST: &str = r#"
{
  "model": "gpt-3.5-turbo",
  "messages": [
    {
      "role": "user",
      "content": "What city do you want to know the weather for?"
    }
  ]
}"#;

    let parsed: ChatCompletionsRequest = serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
    let expected_content =
        ContentType::Text("What city do you want to know the weather for?".to_string());

    assert_eq!(parsed.model, "gpt-3.5-turbo");
    assert_eq!(parsed.messages[0].content, Some(expected_content));
}
|
||||
|
||||
#[test]
fn test_chat_completions_request_text_type() {
    // A message whose `content` is an array with one text part must
    // deserialize into `ContentType::MultiPart`.
    const CHAT_COMPLETIONS_REQUEST: &str = r#"
{
  "model": "gpt-3.5-turbo",
  "messages": [
    {
      "role": "user",
      "content": [
        {
          "type": "text",
          "text": "What city do you want to know the weather for?"
        }
      ]
    }
  ]
}
"#;

    let parsed: ChatCompletionsRequest = serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
    assert_eq!(parsed.model, "gpt-3.5-turbo");

    match parsed.messages[0].content.as_ref() {
        Some(ContentType::MultiPart(parts)) => {
            let first = &parts[0];
            assert_eq!(first.content_type, MultiPartContentType::Text);
            assert_eq!(
                first.text,
                Some("What city do you want to know the weather for?".to_string())
            );
        }
        _ => panic!("Expected MultiPartContent"),
    }
}
|
||||
|
||||
#[test]
fn test_chat_completions_request_text_type_array() {
    // A message whose `content` array holds two text parts must keep both
    // parts, in order.
    const CHAT_COMPLETIONS_REQUEST: &str = r#"
{
  "model": "gpt-3.5-turbo",
  "messages": [
    {
      "role": "user",
      "content": [
        {
          "type": "text",
          "text": "What city do you want to know the weather for?"
        },
        {
          "type": "text",
          "text": "hello world"
        }
      ]
    }
  ]
}
"#;

    let parsed: ChatCompletionsRequest = serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
    assert_eq!(parsed.model, "gpt-3.5-turbo");

    match parsed.messages[0].content.as_ref() {
        Some(ContentType::MultiPart(parts)) => {
            assert_eq!(parts.len(), 2);

            let (first, second) = (&parts[0], &parts[1]);
            assert_eq!(first.content_type, MultiPartContentType::Text);
            assert_eq!(
                first.text,
                Some("What city do you want to know the weather for?".to_string())
            );
            assert_eq!(second.content_type, MultiPartContentType::Text);
            assert_eq!(second.text, Some("hello world".to_string()));
        }
        _ => panic!("Expected MultiPartContent"),
    }
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse_claude() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"role":"assistant"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use crate::metrics::Metrics;
|
||||
use common::api::open_ai::{
|
||||
ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
|
||||
Message, StreamOptions,
|
||||
ContentType, Message, StreamOptions,
|
||||
};
|
||||
use common::configuration::{LlmProvider, LlmProviderType, Overrides};
|
||||
use common::consts::{
|
||||
|
|
@ -369,7 +369,12 @@ impl HttpContext for StreamContext {
|
|||
.messages
|
||||
.iter()
|
||||
.fold(String::new(), |acc, m| {
|
||||
acc + " " + m.content.as_ref().unwrap_or(&String::new())
|
||||
acc + " "
|
||||
+ m.content
|
||||
.as_ref()
|
||||
.unwrap_or(&ContentType::Text(String::new()))
|
||||
.to_string()
|
||||
.as_str()
|
||||
});
|
||||
// enforce ratelimits on ingress
|
||||
if let Err(e) = self.enforce_ratelimits(&deserialized_body.model, input_tokens_str.as_str())
|
||||
|
|
|
|||
|
|
@ -237,22 +237,32 @@ impl HttpContext for StreamContext {
|
|||
Duration::from_secs(5),
|
||||
);
|
||||
|
||||
let call_context = StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchFC,
|
||||
user_message: self.user_prompt.as_ref().unwrap().content.clone(),
|
||||
prompt_target_name: None,
|
||||
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
|
||||
similarity_scores: None,
|
||||
upstream_cluster: Some(ARCH_INTERNAL_CLUSTER_NAME.to_string()),
|
||||
upstream_cluster_path: Some("/function_calling".to_string()),
|
||||
};
|
||||
if let Some(content) =
|
||||
self.user_prompt.as_ref().unwrap().content.as_ref()
|
||||
{
|
||||
let call_context = StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchFC,
|
||||
user_message: Some(content.to_string()),
|
||||
prompt_target_name: None,
|
||||
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
|
||||
similarity_scores: None,
|
||||
upstream_cluster: Some(ARCH_INTERNAL_CLUSTER_NAME.to_string()),
|
||||
upstream_cluster_path: Some("/function_calling".to_string()),
|
||||
};
|
||||
|
||||
if let Err(e) = self.http_call(call_args, call_context) {
|
||||
warn!("http_call failed: {:?}", e);
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
if let Err(e) = self.http_call(call_args, call_context) {
|
||||
warn!("http_call failed: {:?}", e);
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
} else {
|
||||
warn!("No content in the last user prompt");
|
||||
self.send_server_error(
|
||||
ServerError::LogicError("No content in the last user prompt".to_string()),
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
Action::Pause
|
||||
|
||||
}
|
||||
|
||||
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use crate::metrics::Metrics;
|
|||
use crate::tools::compute_request_path_body;
|
||||
use common::api::open_ai::{
|
||||
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
|
||||
ChatCompletionsResponse, Message, ToolCall,
|
||||
ChatCompletionsResponse, ContentType, Message, ToolCall,
|
||||
};
|
||||
use common::configuration::{Endpoint, Overrides, PromptTarget, Tracing};
|
||||
use common::consts::{
|
||||
|
|
@ -215,7 +215,7 @@ impl StreamContext {
|
|||
Some(system_prompt) => {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(system_prompt.clone()),
|
||||
content: Some(ContentType::Text(system_prompt.clone())),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -279,6 +279,13 @@ impl StreamContext {
|
|||
//TODO: add resolver name to the response so the client can send the response back to the correct resolver
|
||||
|
||||
let direct_response_str = if self.streaming_response {
|
||||
let content = model_server_response.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
||||
let chunks = vec![
|
||||
ChatCompletionStreamResponse::new(
|
||||
self.arch_fc_response.clone(),
|
||||
|
|
@ -287,14 +294,7 @@ impl StreamContext {
|
|||
None,
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
Some(
|
||||
model_server_response.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone(),
|
||||
),
|
||||
Some(content.to_string()),
|
||||
None,
|
||||
Some(format!("{}-Chat", ARCH_FC_MODEL_NAME.to_owned())),
|
||||
None,
|
||||
|
|
@ -542,7 +542,7 @@ impl StreamContext {
|
|||
messages.push({
|
||||
Message {
|
||||
role: USER_ROLE.to_string(),
|
||||
content: Some(final_prompt),
|
||||
content: Some(ContentType::Text(final_prompt)),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -612,7 +612,7 @@ impl StreamContext {
|
|||
if system_prompt.is_some() {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: system_prompt,
|
||||
content: Some(ContentType::Text(system_prompt.unwrap())),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -639,7 +639,9 @@ impl StreamContext {
|
|||
} else {
|
||||
Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: self.arch_fc_response.as_ref().cloned(),
|
||||
content: Some(ContentType::Text(
|
||||
self.arch_fc_response.as_ref().unwrap().clone(),
|
||||
)),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -650,7 +652,9 @@ impl StreamContext {
|
|||
pub fn generate_api_response_message(&mut self) -> Message {
|
||||
Message {
|
||||
role: TOOL_ROLE.to_string(),
|
||||
content: self.tool_call_response.clone(),
|
||||
content: Some(ContentType::Text(
|
||||
self.tool_call_response.as_ref().unwrap().clone(),
|
||||
)),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
|
||||
|
|
@ -688,7 +692,14 @@ impl StreamContext {
|
|||
None,
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
chat_completion_response.choices[0].message.content.clone(),
|
||||
Some(
|
||||
chat_completion_response.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.to_string(),
|
||||
),
|
||||
None,
|
||||
Some(chat_completion_response.model.clone()),
|
||||
None,
|
||||
|
|
@ -727,7 +738,7 @@ impl StreamContext {
|
|||
Some(system_prompt) => {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(system_prompt.clone()),
|
||||
content: Some(ContentType::Text(system_prompt.clone())),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -748,7 +759,7 @@ impl StreamContext {
|
|||
let message = format!("{}\ncontext: {}", user_message.content.unwrap(), api_resp);
|
||||
messages.push(Message {
|
||||
role: USER_ROLE.to_string(),
|
||||
content: Some(message),
|
||||
content: Some(ContentType::Text(message)),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -781,7 +792,7 @@ fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool
|
|||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref());
|
||||
|
||||
let content_has_value = content.is_some() && !content.unwrap().is_empty();
|
||||
let content_has_value = content.is_some() && !content.unwrap().to_string().is_empty();
|
||||
|
||||
let tool_calls = model_server_response
|
||||
.choices
|
||||
|
|
@ -807,7 +818,7 @@ impl Client for StreamContext {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use common::api::open_ai::{ChatCompletionsResponse, Choice, Message, ToolCall};
|
||||
use common::api::open_ai::{ChatCompletionsResponse, Choice, ContentType, Message, ToolCall};
|
||||
|
||||
use crate::stream_context::check_intent_matched;
|
||||
|
||||
|
|
@ -816,7 +827,7 @@ mod test {
|
|||
let model_server_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
content: Some("".to_string()),
|
||||
content: Some(ContentType::Text("".to_string())),
|
||||
tool_calls: Some(vec![]),
|
||||
role: "assistant".to_string(),
|
||||
model: None,
|
||||
|
|
@ -835,7 +846,7 @@ mod test {
|
|||
let model_server_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
content: Some("hello".to_string()),
|
||||
content: Some(ContentType::Text("hello".to_string())),
|
||||
tool_calls: Some(vec![]),
|
||||
role: "assistant".to_string(),
|
||||
model: None,
|
||||
|
|
@ -854,7 +865,7 @@ mod test {
|
|||
let model_server_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
content: Some("".to_string()),
|
||||
content: Some(ContentType::Text("".to_string())),
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: "1".to_string(),
|
||||
function: common::api::open_ai::FunctionCallDetail {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use common::api::open_ai::{
|
||||
ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage,
|
||||
ChatCompletionsResponse, Choice, ContentType, FunctionCallDetail, Message, ToolCall, ToolType, Usage
|
||||
};
|
||||
use common::configuration::Configuration;
|
||||
use http::StatusCode;
|
||||
|
|
@ -431,7 +431,7 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
index: Some(0),
|
||||
message: Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("hello from fake llm gateway".to_string()),
|
||||
content: Some(ContentType::Text("hello from fake llm gateway".to_string())),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
|
|||
|
|
@ -14,26 +14,26 @@ llm_providers:
|
|||
|
||||
- name: archgw-v1-router-model
|
||||
provider_interface: openai
|
||||
model: cotran2/llama-1b-4-26
|
||||
base_url: http://35.192.87.187:8000/v1
|
||||
|
||||
- name: gpt-4o-mini
|
||||
provider_interface: openai
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o-mini
|
||||
default: true
|
||||
model: cotran2/llama-4-epoch
|
||||
base_url: http://34.46.85.85:8000/v1
|
||||
|
||||
- name: gpt-4o
|
||||
provider_interface: openai
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
usage: Generating original content such as scripts, articles, or creative materials.
|
||||
default: true
|
||||
|
||||
- name: o4-mini
|
||||
- name: code_generation
|
||||
provider_interface: openai
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: o4-mini
|
||||
usage: Requesting topic ideas specifically related to personal finance and budgeting.
|
||||
model: gpt-4o
|
||||
usage: Generating new code snippets, functions, or boilerplate based on user prompts or requirements
|
||||
|
||||
- name: code_understanding
|
||||
provider_interface: openai
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4.1
|
||||
usage: understand and explain existing code snippets, functions, or libraries
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ Content-Type: application/json
|
|||
HTTP 200
|
||||
[Asserts]
|
||||
header "content-type" == "application/json"
|
||||
jsonpath "$.model" matches /^o4-mini/
|
||||
jsonpath "$.model" matches /^gpt-4o/
|
||||
jsonpath "$.usage" != null
|
||||
jsonpath "$.choices[0].message.content" != null
|
||||
jsonpath "$.choices[0].message.role" == "assistant"
|
||||
|
|
|
|||
|
|
@ -13,4 +13,4 @@ Content-Type: application/json
|
|||
HTTP 200
|
||||
[Asserts]
|
||||
header "content-type" matches /text\/event-stream/
|
||||
body matches /^data: .*?o4-mini.*?\n/
|
||||
body matches /^data: .*?gpt-4o.*?\n/
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue