use provider/model to identify models

This commit is contained in:
Adil Hafeez 2025-07-08 13:47:56 -07:00
parent 5f18fee089
commit bcd7f9be45
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
7 changed files with 33 additions and 68 deletions

View file

@ -103,7 +103,8 @@ properties:
type: string type: string
additionalProperties: false additionalProperties: false
required: required:
- name - model
- provider_interface
overrides: overrides:
type: object type: object
properties: properties:

View file

@ -9,6 +9,9 @@ ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
"ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml" "ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
) )
ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/app/arch_config.yaml") ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/app/arch_config.yaml")
ARCH_CONFIG_FILE_RENDERED = os.getenv(
"ARCH_CONFIG_FILE_RENDERED", "/app/arch_config_rendered.yaml"
)
ENVOY_CONFIG_FILE_RENDERED = os.getenv( ENVOY_CONFIG_FILE_RENDERED = os.getenv(
"ENVOY_CONFIG_FILE_RENDERED", "/etc/envoy/envoy.yaml" "ENVOY_CONFIG_FILE_RENDERED", "/etc/envoy/envoy.yaml"
) )
@ -90,9 +93,9 @@ def validate_and_render_schema():
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider" f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
) )
if llm_provider.get("name") is None: if llm_provider.get("name") is None:
raise Exception( provider_interface = llm_provider.get("provider_interface", "unknown")
f"llm_provider name is required, please provide name for llm_provider" model_name = llm_provider.get("model", "unknown")
) llm_provider["name"] = f"{provider_interface}/{model_name}"
llm_provider_name_set.add(llm_provider.get("name")) llm_provider_name_set.add(llm_provider.get("name"))
provider = None provider = None
if llm_provider.get("provider") and llm_provider.get("provider_interface"): if llm_provider.get("provider") and llm_provider.get("provider_interface"):
@ -216,6 +219,9 @@ def validate_and_render_schema():
with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file: with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
file.write(rendered) file.write(rendered)
with open(ARCH_CONFIG_FILE_RENDERED, "w") as file:
file.write(arch_config_string)
def validate_prompt_config(arch_config_file, arch_config_schema_file): def validate_prompt_config(arch_config_file, arch_config_schema_file):
with open(arch_config_file, "r") as file: with open(arch_config_file, "r") as file:

View file

@ -104,7 +104,7 @@ pub async fn chat_completions(
debug!("usage preferences: {:?}", usage_preferences); debug!("usage preferences: {:?}", usage_preferences);
let mut selected_llm = match router_service let mut determined_route = match router_service
.determine_route( .determine_route(
&chat_completion_request.messages, &chat_completion_request.messages,
trace_parent.clone(), trace_parent.clone(),
@ -121,14 +121,14 @@ pub async fn chat_completions(
} }
}; };
if selected_llm.is_none() { if determined_route.is_none() {
debug!("No LLM model selected, using default from request"); debug!("No LLM model selected, using default from request");
selected_llm = Some(chat_completion_request.model.clone()); determined_route = Some(chat_completion_request.model.clone());
} }
info!( info!(
"sending request to llm provider: {} with llm model: {:?}", "sending request to llm provider: {} with llm model: {:?}",
llm_provider_endpoint, selected_llm llm_provider_endpoint, determined_route
); );
if let Some(trace_parent) = trace_parent { if let Some(trace_parent) = trace_parent {
@ -138,10 +138,10 @@ pub async fn chat_completions(
); );
} }
if let Some(selected_llm) = selected_llm { if let Some(selected_route) = determined_route {
request_headers.insert( request_headers.insert(
ARCH_PROVIDER_HINT_HEADER, ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&selected_llm).unwrap(), header::HeaderValue::from_str(&selected_route).unwrap(),
); );
} }

View file

@ -45,8 +45,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string()); let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
// loading arch_config.yaml file // loading arch_config.yaml file
let arch_config_path = let arch_config_path = env::var("ARCH_CONFIG_PATH_RENDERED")
env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "./arch_config.yaml".to_string()); .unwrap_or_else(|_| "./arch_config_rendered.yaml".to_string());
info!("Loading arch_config.yaml from {}", arch_config_path); info!("Loading arch_config.yaml from {}", arch_config_path);
let config_contents = let config_contents =

View file

@ -1,4 +1,4 @@
use std::{collections::HashMap, sync::Arc}; use std::sync::Arc;
use common::{ use common::{
configuration::{LlmProvider, LlmRoute, ModelUsagePreference}, configuration::{LlmProvider, LlmRoute, ModelUsagePreference},
@ -19,7 +19,6 @@ pub struct RouterService {
router_model: Arc<dyn RouterModel>, router_model: Arc<dyn RouterModel>,
routing_provider_name: String, routing_provider_name: String,
llm_usage_defined: bool, llm_usage_defined: bool,
llm_provider_map: HashMap<String, LlmProvider>,
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
@ -57,18 +56,12 @@ impl RouterService {
router_model_v1::MAX_TOKEN_LEN, router_model_v1::MAX_TOKEN_LEN,
)); ));
let llm_provider_map: HashMap<String, LlmProvider> = providers
.into_iter()
.map(|provider| (provider.name.clone(), provider))
.collect();
RouterService { RouterService {
router_url, router_url,
client: reqwest::Client::new(), client: reqwest::Client::new(),
router_model, router_model,
routing_provider_name, routing_provider_name,
llm_usage_defined: !providers_with_usage.is_empty(), llm_usage_defined: !providers_with_usage.is_empty(),
llm_provider_map,
} }
} }
@ -155,40 +148,15 @@ impl RouterService {
if let Some(ContentType::Text(content)) = if let Some(ContentType::Text(content)) =
&chat_completion_response.choices[0].message.content &chat_completion_response.choices[0].message.content
{ {
let mut selected_model: Option<String> = None; let route_name = self.router_model.parse_response(content)?;
if let Some(selected_llm_name) = self.router_model.parse_response(content)? {
if selected_llm_name != "other" {
if let Some(usage_preferences) = usage_preferences {
for usage in usage_preferences {
if usage.name == selected_llm_name {
selected_model = Some(usage.model);
break;
}
}
if selected_model.is_none() {
warn!(
"Selected LLM model not found in usage preferences: {}",
selected_llm_name
);
}
} else if let Some(provider) = self.llm_provider_map.get(&selected_llm_name) {
selected_model = provider.model.clone();
} else {
warn!(
"Selected LLM model not found in provider map: {}",
selected_llm_name
);
}
}
}
info!( info!(
"router response: {}, selected_model: {:?}, response time: {}ms", "router response: {}, selected_model: {:?}, response time: {}ms",
content.replace("\n", "\\n"), content.replace("\n", "\\n"),
selected_model, route_name,
router_response_time.as_millis() router_response_time.as_millis()
); );
Ok(selected_model) Ok(route_name)
} else { } else {
Ok(None) Ok(None)
} }

View file

@ -9,44 +9,36 @@ listeners:
llm_providers: llm_providers:
- name: openai/gpt-4o-mini - access_key: $OPENAI_API_KEY
access_key: $OPENAI_API_KEY
provider_interface: openai provider_interface: openai
model: gpt-4o-mini model: gpt-4o-mini
- name: openai/gpt-4o - access_key: $OPENAI_API_KEY
access_key: $OPENAI_API_KEY
provider_interface: openai provider_interface: openai
model: gpt-4o model: gpt-4o
default: true default: true
- name: mistral/ministral-3b - access_key: $MISTRAL_API_KEY
access_key: $MISTRAL_API_KEY
provider_interface: mistral provider_interface: mistral
model: ministral-3b-latest model: ministral-3b-latest
- name: claude/claude-sonnet - access_key: $ANTHROPIC_API_KEY
access_key: $ANTHROPIC_API_KEY
provider_interface: claude provider_interface: claude
model: claude-3-7-sonnet-latest model: claude-3-7-sonnet-latest
- name: claude/claude-sonnet-4 - access_key: $ANTHROPIC_API_KEY
access_key: $ANTHROPIC_API_KEY
provider_interface: claude provider_interface: claude
model: claude-sonnet-4-0 model: claude-sonnet-4-0
- name: deepseek/deepseek-reasoner - access_key: $DEEPSEEK_API_KEY
access_key: $DEEPSEEK_API_KEY
provider_interface: deepseek provider_interface: deepseek
model: deepseek-reasoner model: deepseek-reasoner
- name: groq/llama-3.1-8b-instant - access_key: $GROQ_API_KEY
access_key: $GROQ_API_KEY
provider_interface: groq provider_interface: groq
model: llama-3.1-8b-instant model: llama-3.1-8b-instant
- name: gemini/gemini-1.5-pro-latest - access_key: $GEMINI_API_KEY
access_key: $GEMINI_API_KEY
provider_interface: gemini provider_interface: gemini
model: gemini-1.5-pro-latest model: gemini-1.5-pro-latest

View file

@ -9,13 +9,11 @@ listeners:
llm_providers: llm_providers:
- name: openai/gpt-4o-mini - provider_interface: openai
provider_interface: openai
access_key: $OPENAI_API_KEY access_key: $OPENAI_API_KEY
model: gpt-4o-mini model: gpt-4o-mini
- name: openai/gpt-4.1 - provider_interface: openai
provider_interface: openai
access_key: $OPENAI_API_KEY access_key: $OPENAI_API_KEY
model: gpt-4.1 model: gpt-4.1
default: true default: true