mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
When router usage is defined ensure that router model is defined too (#481)
This commit is contained in:
parent
218e9c540d
commit
d050dfb85a
3 changed files with 13 additions and 7 deletions
|
|
@ -81,7 +81,10 @@ def validate_and_render_schema():
|
|||
|
||||
updated_llm_providers = []
|
||||
llm_provider_name_set = set()
|
||||
llms_with_usage = []
|
||||
for llm_provider in config_yaml["llm_providers"]:
|
||||
if llm_provider.get("usage", None):
|
||||
llms_with_usage.append(llm_provider["name"])
|
||||
if llm_provider.get("name") in llm_provider_name_set:
|
||||
raise Exception(
|
||||
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
|
||||
|
|
@ -137,6 +140,15 @@ def validate_and_render_schema():
|
|||
llm_provider["protocol"] = protocol
|
||||
llms_with_endpoint.append(llm_provider)
|
||||
|
||||
if (
|
||||
len(llms_with_usage) > 0
|
||||
and config_yaml.get("routing", {}).get("model", None) == None
|
||||
):
|
||||
llms_with_usage_names = ", ".join(llms_with_usage)
|
||||
raise Exception(
|
||||
f"LLMs with usage found ({llms_with_usage_names}), please provide model in routing section in your arch_config.yaml file"
|
||||
)
|
||||
|
||||
config_yaml["llm_providers"] = updated_llm_providers
|
||||
|
||||
arch_config_string = yaml.dump(config_yaml)
|
||||
|
|
|
|||
|
|
@ -84,12 +84,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
info!("listening on http://{}", bind_address);
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
|
||||
// if routing is null then return gpt-4o as model name
|
||||
//TODO: fail if routing is null
|
||||
let model = arch_config
|
||||
.routing
|
||||
.as_ref()
|
||||
.map_or_else(|| "gpt-4o".to_string(), |routing| routing.model.clone());
|
||||
let model = arch_config.routing.as_ref().unwrap().model.clone();
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||
arch_config.llm_providers.clone(),
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ use crate::{
|
|||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::open_ai::ContentType;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationRequest {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue