mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
When router usage is defined ensure that router model is defined too (#481)
This commit is contained in:
parent
218e9c540d
commit
d050dfb85a
3 changed files with 13 additions and 7 deletions
|
|
@ -81,7 +81,10 @@ def validate_and_render_schema():
|
||||||
|
|
||||||
updated_llm_providers = []
|
updated_llm_providers = []
|
||||||
llm_provider_name_set = set()
|
llm_provider_name_set = set()
|
||||||
|
llms_with_usage = []
|
||||||
for llm_provider in config_yaml["llm_providers"]:
|
for llm_provider in config_yaml["llm_providers"]:
|
||||||
|
if llm_provider.get("usage", None):
|
||||||
|
llms_with_usage.append(llm_provider["name"])
|
||||||
if llm_provider.get("name") in llm_provider_name_set:
|
if llm_provider.get("name") in llm_provider_name_set:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
|
f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider"
|
||||||
|
|
@ -137,6 +140,15 @@ def validate_and_render_schema():
|
||||||
llm_provider["protocol"] = protocol
|
llm_provider["protocol"] = protocol
|
||||||
llms_with_endpoint.append(llm_provider)
|
llms_with_endpoint.append(llm_provider)
|
||||||
|
|
||||||
|
if (
|
||||||
|
len(llms_with_usage) > 0
|
||||||
|
and config_yaml.get("routing", {}).get("model", None) == None
|
||||||
|
):
|
||||||
|
llms_with_usage_names = ", ".join(llms_with_usage)
|
||||||
|
raise Exception(
|
||||||
|
f"LLMs with usage found ({llms_with_usage_names}), please provide model in routing section in your arch_config.yaml file"
|
||||||
|
)
|
||||||
|
|
||||||
config_yaml["llm_providers"] = updated_llm_providers
|
config_yaml["llm_providers"] = updated_llm_providers
|
||||||
|
|
||||||
arch_config_string = yaml.dump(config_yaml)
|
arch_config_string = yaml.dump(config_yaml)
|
||||||
|
|
|
||||||
|
|
@ -84,12 +84,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
info!("listening on http://{}", bind_address);
|
info!("listening on http://{}", bind_address);
|
||||||
let listener = TcpListener::bind(bind_address).await?;
|
let listener = TcpListener::bind(bind_address).await?;
|
||||||
|
|
||||||
// if routing is null then return gpt-4o as model name
|
let model = arch_config.routing.as_ref().unwrap().model.clone();
|
||||||
//TODO: fail if routing is null
|
|
||||||
let model = arch_config
|
|
||||||
.routing
|
|
||||||
.as_ref()
|
|
||||||
.map_or_else(|| "gpt-4o".to_string(), |routing| routing.model.clone());
|
|
||||||
|
|
||||||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||||
arch_config.llm_providers.clone(),
|
arch_config.llm_providers.clone(),
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ use crate::{
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use super::open_ai::ContentType;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct HallucinationClassificationRequest {
|
pub struct HallucinationClassificationRequest {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue