mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
adding support for model aliases in archgw
This commit is contained in:
parent
1e8c81d8f6
commit
f13b420146
13 changed files with 1438 additions and 8 deletions
|
|
@ -1,7 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::configuration::ModelUsagePreference;
|
||||
use common::configuration::{Configuration, ModelUsagePreference};
|
||||
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::clients::SupportedAPIs;
|
||||
|
|
@ -28,6 +28,7 @@ pub async fn chat(
|
|||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
config: Arc<Configuration>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
||||
let request_path = request.uri().path().to_string();
|
||||
|
|
@ -35,6 +36,7 @@ pub async fn chat(
|
|||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
|
||||
|
||||
let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
|
|
@ -46,6 +48,24 @@ pub async fn chat(
|
|||
}
|
||||
};
|
||||
|
||||
// === Model alias resolution: update model field in client_request immediately ===
|
||||
// This ensures all downstream objects use the resolved model
|
||||
let original_model = client_request.model().to_string();
|
||||
let resolved_model = if let Some(model_aliases) = &config.model_aliases {
|
||||
if let Some(alias) = model_aliases.get(&original_model) {
|
||||
debug!(
|
||||
"[BRIGHTSTAFF] Model Alias: 'From {}' -> 'To{}'",
|
||||
original_model, alias.target
|
||||
);
|
||||
alias.target.clone()
|
||||
} else {
|
||||
original_model.clone()
|
||||
}
|
||||
} else {
|
||||
original_model.clone()
|
||||
};
|
||||
client_request.set_model(resolved_model.clone());
|
||||
|
||||
// Clone metadata for routing and remove archgw_preference_config from original
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
|
||||
|
|
@ -132,11 +152,12 @@ pub async fn chat(
|
|||
Ok(route) => match route {
|
||||
Some((_, model_name)) => model_name,
|
||||
None => {
|
||||
debug!(
|
||||
"No route determined, using default model from request: {}",
|
||||
debug!(
|
||||
"[BRIGHTSTAFF] No route determined, using default model from request: {}",
|
||||
chat_completions_request_for_arch_router.model
|
||||
);
|
||||
chat_completions_request_for_arch_router.model.clone()
|
||||
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
|
|
@ -148,7 +169,7 @@ pub async fn chat(
|
|||
};
|
||||
|
||||
debug!(
|
||||
"[BRIGHTSTAFF -> ARCH_ROUTER] URL: {}, Model Hint: {}",
|
||||
"[BRIGHTSTAFF -> ARCH_ROUTER] URL: {}, Final Model: {}",
|
||||
full_qualified_llm_provider_url, model_name
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -101,6 +101,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
let arch_config = Arc::clone(&arch_config);
|
||||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let service = service_fn(move |req| {
|
||||
|
|
@ -109,12 +110,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let parent_cx = extract_context_from_request(&req);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
let llm_providers = llm_providers.clone();
|
||||
let arch_config = Arc::clone(&arch_config);
|
||||
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
|
||||
chat(req, router_service, fully_qualified_url)
|
||||
chat(req, router_service, fully_qualified_url, arch_config)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,11 +13,17 @@ pub struct Routing {
|
|||
pub model: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelAlias {
|
||||
pub target: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub version: String,
|
||||
pub endpoints: Option<HashMap<String, Endpoint>>,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub model_aliases: Option<HashMap<String, ModelAlias>>,
|
||||
pub overrides: Option<Overrides>,
|
||||
pub system_prompt: Option<String>,
|
||||
pub prompt_guards: Option<PromptGuards>,
|
||||
|
|
|
|||
|
|
@ -104,6 +104,20 @@ pub struct ChatCompletionsRequest {
|
|||
// pub web_search: Option<bool>, // GOOD FIRST ISSUE: Future support for web search
|
||||
}
|
||||
|
||||
impl ChatCompletionsRequest {
|
||||
/// Suppress max_tokens if the model is o3, o3-*, openrouter/o3, or openrouter/o3-*
|
||||
pub fn suppress_max_tokens_if_o3(&mut self) {
|
||||
let model = self.model.as_str();
|
||||
let is_o3 = model == "o3"
|
||||
|| model.starts_with("o3-")
|
||||
|| model == "openrouter/o3"
|
||||
|| model.starts_with("openrouter/o3-");
|
||||
if is_o3 {
|
||||
self.max_tokens = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// CHAT COMPLETIONS API TYPES
|
||||
// ============================================================================
|
||||
|
|
@ -530,7 +544,10 @@ impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
|||
type Error = OpenAIStreamError;
|
||||
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
|
||||
let mut req: ChatCompletionsRequest = serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
|
||||
// Use the centralized suppression logic
|
||||
req.suppress_max_tokens_if_o3();
|
||||
Ok(req)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
|
|||
let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
|
||||
let (openai_tool_choice, parallel_tool_calls) = convert_anthropic_tool_choice(req.tool_choice);
|
||||
|
||||
Ok(ChatCompletionsRequest {
|
||||
let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: req.model,
|
||||
messages: openai_messages,
|
||||
temperature: req.temperature,
|
||||
|
|
@ -109,7 +109,9 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
|
|||
tool_choice: openai_tool_choice,
|
||||
parallel_tool_calls,
|
||||
..Default::default()
|
||||
})
|
||||
};
|
||||
_chat_completions_req.suppress_max_tokens_if_o3();
|
||||
Ok(_chat_completions_req)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue