mirror of
https://github.com/katanemo/plano.git
synced 2026-05-21 13:55:15 +02:00
add support for openwebui (#487)
This commit is contained in:
parent
4899117876
commit
9c4733590f
8 changed files with 150 additions and 24 deletions
|
|
@ -52,6 +52,7 @@ def docker_start_archgw_detached(
|
|||
port_mappings = [
|
||||
f"{prompt_gateway_port}:{prompt_gateway_port}",
|
||||
f"{llm_gateway_port}:{llm_gateway_port}",
|
||||
f"{llm_gateway_port+1}:{llm_gateway_port+1}",
|
||||
"19901:9901",
|
||||
]
|
||||
port_mappings_args = [item for port in port_mappings for item in ("-p", port)]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ use std::sync::Arc;
|
|||
use bytes::Bytes;
|
||||
use common::api::open_ai::ChatCompletionsRequest;
|
||||
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
||||
use common::utils::shorten_string;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
|
|
@ -39,7 +38,7 @@ pub async fn chat_completions(
|
|||
let v: Value = serde_json::from_slice(&chat_request_bytes).unwrap();
|
||||
let err_msg = format!("Failed to parse request body: {}", err);
|
||||
warn!("{}", err_msg);
|
||||
warn!("request body: {}", v.to_string());
|
||||
warn!("arch-router request body: {}", v.to_string());
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
|
|
@ -47,8 +46,8 @@ pub async fn chat_completions(
|
|||
};
|
||||
|
||||
debug!(
|
||||
"request body: {}",
|
||||
shorten_string(&serde_json::to_string(&chat_completion_request).unwrap())
|
||||
"arch-router request body: {}",
|
||||
&serde_json::to_string(&chat_completion_request).unwrap()
|
||||
);
|
||||
|
||||
let trace_parent = request_headers
|
||||
|
|
@ -56,7 +55,7 @@ pub async fn chat_completions(
|
|||
.find(|(ty, _)| ty.as_str() == "traceparent")
|
||||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
let selected_llm = match router_service
|
||||
let mut selected_llm = match router_service
|
||||
.determine_route(&chat_completion_request.messages, trace_parent.clone())
|
||||
.await
|
||||
{
|
||||
|
|
@ -69,6 +68,11 @@ pub async fn chat_completions(
|
|||
}
|
||||
};
|
||||
|
||||
if selected_llm.is_none() {
|
||||
debug!("No LLM model selected, using default from request");
|
||||
selected_llm = Some(chat_completion_request.model.clone());
|
||||
}
|
||||
|
||||
info!(
|
||||
"sending request to llm provider: {} with llm model: {:?}",
|
||||
llm_provider_endpoint, selected_llm
|
||||
|
|
|
|||
|
|
@ -1 +1,2 @@
|
|||
pub mod chat_completions;
|
||||
pub mod models;
|
||||
|
|
|
|||
40
crates/brightstaff/src/handlers/models.rs
Normal file
40
crates/brightstaff/src/handlers/models.rs
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
use bytes::Bytes;
|
||||
use common::api::open_ai::Models;
|
||||
use common::configuration::LlmProvider;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
use hyper::{Response, StatusCode};
|
||||
use serde_json;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub async fn list_models(
|
||||
llm_providers: Arc<Vec<LlmProvider>>,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let prov = llm_providers.clone();
|
||||
let providers = (*prov).clone();
|
||||
let openai_models = Models::from(providers);
|
||||
|
||||
match serde_json::to_string(&openai_models) {
|
||||
Ok(json) => {
|
||||
let body = Full::new(Bytes::from(json))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.unwrap()
|
||||
}
|
||||
Err(_) => {
|
||||
let body = Full::new(Bytes::from_static(
|
||||
b"{\"error\":\"Failed to serialize models\"}",
|
||||
))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
use brightstaff::handlers::chat_completions::chat_completions;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
|
|
@ -52,6 +53,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let arch_config = Arc::new(config);
|
||||
|
||||
let llm_providers = Arc::new(arch_config.llm_providers.clone());
|
||||
|
||||
debug!(
|
||||
"arch_config: {:?}",
|
||||
&serde_json::to_string(arch_config.as_ref()).unwrap()
|
||||
|
|
@ -84,10 +87,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let router_service = Arc::clone(&router_service);
|
||||
let llm_provider_endpoint = llm_provider_endpoint.clone();
|
||||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let service = service_fn(move |req| {
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let parent_cx = extract_context_from_request(&req);
|
||||
let llm_provider_endpoint = llm_provider_endpoint.clone();
|
||||
let llm_providers = llm_providers.clone();
|
||||
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
|
|
@ -96,6 +101,35 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::GET, "/v1/models") => {
|
||||
Ok(list_models(llm_providers).await)
|
||||
}
|
||||
(&Method::OPTIONS, "/v1/models") => {
|
||||
let mut response = Response::new(empty());
|
||||
*response.status_mut() = StatusCode::NO_CONTENT;
|
||||
response.headers_mut().insert(
|
||||
"Allow",
|
||||
"GET, OPTIONS".parse().unwrap(),
|
||||
);
|
||||
response.headers_mut().insert(
|
||||
"Access-Control-Allow-Origin",
|
||||
"*".parse().unwrap(),
|
||||
);
|
||||
response.headers_mut().insert(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Authorization, Content-Type".parse().unwrap(),
|
||||
);
|
||||
response.headers_mut().insert(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET, POST, OPTIONS".parse().unwrap(),
|
||||
);
|
||||
response.headers_mut().insert(
|
||||
"Content-Type",
|
||||
"application/json".parse().unwrap(),
|
||||
);
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
_ => {
|
||||
let mut not_found = Response::new(empty());
|
||||
*not_found.status_mut() = StatusCode::NOT_FOUND;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
|
||||
use crate::{
|
||||
configuration::LlmProvider,
|
||||
consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE},
|
||||
};
|
||||
use core::{panic, str};
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use serde_yaml::Value;
|
||||
use core::panic;
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
fmt::Display,
|
||||
|
|
@ -420,6 +423,45 @@ pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
|
|||
response_str
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelDetail {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: usize,
|
||||
pub owned_by: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ModelObject {
|
||||
#[serde(rename = "list")]
|
||||
List,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Models {
|
||||
pub object: ModelObject,
|
||||
pub data: Vec<ModelDetail>,
|
||||
}
|
||||
|
||||
impl From<Vec<LlmProvider>> for Models {
|
||||
fn from(llm_providers: Vec<LlmProvider>) -> Self {
|
||||
let data = llm_providers
|
||||
.iter()
|
||||
.map(|provider| ModelDetail {
|
||||
id: provider.model.as_ref().unwrap().clone(),
|
||||
object: "model".to_string(),
|
||||
created: 1721172741,
|
||||
owned_by: "system".to_string(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Models {
|
||||
object: ModelObject::List,
|
||||
data,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::api::open_ai::{ChatCompletionsRequest, ContentType, MultiPartContentType};
|
||||
|
|
@ -775,7 +817,10 @@ data: [DONE]
|
|||
if let Some(ContentType::MultiPart(multi_part_content)) =
|
||||
chat_completions_request.messages[0].content.as_ref()
|
||||
{
|
||||
assert_eq!(multi_part_content[0].content_type, MultiPartContentType::Text);
|
||||
assert_eq!(
|
||||
multi_part_content[0].content_type,
|
||||
MultiPartContentType::Text
|
||||
);
|
||||
assert_eq!(
|
||||
multi_part_content[0].text,
|
||||
Some("What city do you want to know the weather for?".to_string())
|
||||
|
|
@ -815,22 +860,24 @@ data: [DONE]
|
|||
chat_completions_request.messages[0].content.as_ref()
|
||||
{
|
||||
assert_eq!(multi_part_content.len(), 2);
|
||||
assert_eq!(multi_part_content[0].content_type, MultiPartContentType::Text);
|
||||
assert_eq!(
|
||||
multi_part_content[0].content_type,
|
||||
MultiPartContentType::Text
|
||||
);
|
||||
assert_eq!(
|
||||
multi_part_content[0].text,
|
||||
Some("What city do you want to know the weather for?".to_string())
|
||||
);
|
||||
assert_eq!(multi_part_content[1].content_type, MultiPartContentType::Text);
|
||||
assert_eq!(
|
||||
multi_part_content[1].text,
|
||||
Some("hello world".to_string())
|
||||
multi_part_content[1].content_type,
|
||||
MultiPartContentType::Text
|
||||
);
|
||||
assert_eq!(multi_part_content[1].text, Some("hello world".to_string()));
|
||||
} else {
|
||||
panic!("Expected MultiPartContent");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse_claude() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"role":"assistant"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ listeners:
|
|||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
|
||||
- name: gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider_interface: openai
|
||||
|
|
@ -17,6 +18,7 @@ llm_providers:
|
|||
access_key: $OPENAI_API_KEY
|
||||
provider_interface: openai
|
||||
model: gpt-4o
|
||||
default: true
|
||||
|
||||
- name: ministral-3b
|
||||
access_key: $MISTRAL_API_KEY
|
||||
|
|
@ -27,7 +29,6 @@ llm_providers:
|
|||
access_key: $ANTHROPY_API_KEY
|
||||
provider_interface: claude
|
||||
model: claude-3-7-sonnet-latest
|
||||
default: true
|
||||
|
||||
- name: claude-sonnet-4
|
||||
access_key: $ANTHROPY_API_KEY
|
||||
|
|
|
|||
|
|
@ -1,17 +1,15 @@
|
|||
services:
|
||||
|
||||
chatbot_ui:
|
||||
build:
|
||||
context: ../../shared/chatbot_ui
|
||||
dockerfile: Dockerfile
|
||||
|
||||
open-web-ui:
|
||||
image: ghcr.io/open-webui/open-webui:main
|
||||
restart: always
|
||||
ports:
|
||||
- "18080:8080"
|
||||
- "8080:8080"
|
||||
environment:
|
||||
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:12000/v1
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- ./arch_config.yaml:/app/arch_config.yaml
|
||||
- DEFAULT_MODEL=gpt-4o-mini
|
||||
- ENABLE_OPENAI_API=true
|
||||
- OPENAI_API_BASE_URL=http://host.docker.internal:12000/v1
|
||||
|
||||
jaeger:
|
||||
build:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue