mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
use overrides for custom routing and orchestration model
This commit is contained in:
parent
98038690b0
commit
6143b7ad54
9 changed files with 93 additions and 114 deletions
|
|
@ -8,13 +8,13 @@ from urllib.parse import urlparse
|
|||
from copy import deepcopy
|
||||
from planoai.consts import DEFAULT_OTEL_TRACING_GRPC_ENDPOINT
|
||||
|
||||
|
||||
SUPPORTED_PROVIDERS_WITH_BASE_URL = [
|
||||
"azure_openai",
|
||||
"ollama",
|
||||
"qwen",
|
||||
"amazon_bedrock",
|
||||
"arch",
|
||||
"plano",
|
||||
]
|
||||
|
||||
SUPPORTED_PROVIDERS_WITHOUT_BASE_URL = [
|
||||
|
|
@ -368,29 +368,25 @@ def validate_and_render_schema():
|
|||
llms_with_endpoint.append(model_provider)
|
||||
llms_with_endpoint_cluster_names.add(cluster_name)
|
||||
|
||||
if len(model_usage_name_keys) > 0:
|
||||
routing_model_provider = config_yaml.get("routing", {}).get(
|
||||
"model_provider", None
|
||||
overrides_config = config_yaml.get("overrides", {})
|
||||
# Build lookup of model names (already prefix-stripped by config processing)
|
||||
model_name_set = {mp.get("model") for mp in updated_model_providers}
|
||||
|
||||
# Auto-add arch-router provider if routing preferences exist and no provider matches the router model
|
||||
router_model = overrides_config.get("router_model", "Arch-Router")
|
||||
# Strip provider prefix for comparison since config processing strips prefixes from model names
|
||||
router_model_id = (
|
||||
router_model.split("/", 1)[1] if "/" in router_model else router_model
|
||||
)
|
||||
if len(model_usage_name_keys) > 0 and router_model_id not in model_name_set:
|
||||
updated_model_providers.append(
|
||||
{
|
||||
"name": "arch-router",
|
||||
"provider_interface": "arch",
|
||||
"model": router_model_id,
|
||||
"internal": True,
|
||||
}
|
||||
)
|
||||
if (
|
||||
routing_model_provider
|
||||
and routing_model_provider not in model_provider_name_set
|
||||
):
|
||||
raise Exception(
|
||||
f"Routing model_provider {routing_model_provider} is not defined in model_providers"
|
||||
)
|
||||
if (
|
||||
routing_model_provider is None
|
||||
and "arch-router" not in model_provider_name_set
|
||||
):
|
||||
updated_model_providers.append(
|
||||
{
|
||||
"name": "arch-router",
|
||||
"provider_interface": "arch",
|
||||
"model": config_yaml.get("routing", {}).get("model", "Arch-Router"),
|
||||
"internal": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Always add arch-function model provider if not already defined
|
||||
if "arch-function" not in model_provider_name_set:
|
||||
|
|
@ -403,26 +399,21 @@ def validate_and_render_schema():
|
|||
}
|
||||
)
|
||||
|
||||
orchestration_config = config_yaml.get("orchestration", {})
|
||||
orchestration_model_provider = orchestration_config.get("llm_provider", None)
|
||||
|
||||
if (
|
||||
orchestration_model_provider
|
||||
and orchestration_model_provider not in model_provider_name_set
|
||||
):
|
||||
raise Exception(
|
||||
f"Orchestration llm_provider {orchestration_model_provider} is not defined in model_providers"
|
||||
)
|
||||
|
||||
if (
|
||||
orchestration_model_provider is None
|
||||
and "plano-orchestrator" not in model_provider_name_set
|
||||
):
|
||||
# Auto-add plano-orchestrator provider if no provider matches the orchestrator model
|
||||
orchestrator_model = overrides_config.get(
|
||||
"orchestrator_model", "Plano-Orchestrator"
|
||||
)
|
||||
orchestrator_model_id = (
|
||||
orchestrator_model.split("/", 1)[1]
|
||||
if "/" in orchestrator_model
|
||||
else orchestrator_model
|
||||
)
|
||||
if orchestrator_model_id not in model_name_set:
|
||||
updated_model_providers.append(
|
||||
{
|
||||
"name": "plano-orchestrator",
|
||||
"provider_interface": "arch",
|
||||
"model": orchestration_config.get("model", "Plano-Orchestrator"),
|
||||
"model": orchestrator_model_id,
|
||||
"internal": True,
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -174,6 +174,7 @@ properties:
|
|||
type: string
|
||||
enum:
|
||||
- arch
|
||||
- plano
|
||||
- claude
|
||||
- deepseek
|
||||
- groq
|
||||
|
|
@ -221,6 +222,7 @@ properties:
|
|||
type: string
|
||||
enum:
|
||||
- arch
|
||||
- plano
|
||||
- claude
|
||||
- deepseek
|
||||
- groq
|
||||
|
|
@ -271,6 +273,12 @@ properties:
|
|||
upstream_tls_ca_path:
|
||||
type: string
|
||||
description: "Path to the trusted CA bundle for upstream TLS verification. Default is '/etc/ssl/certs/ca-certificates.crt'."
|
||||
router_model:
|
||||
type: string
|
||||
description: "Model name for the LLM router (e.g., 'Arch-Router'). Must match a model in model_providers."
|
||||
orchestrator_model:
|
||||
type: string
|
||||
description: "Model name for the agent orchestrator (e.g., 'Plano-Orchestrator'). Must match a model in model_providers."
|
||||
system_prompt:
|
||||
type: string
|
||||
prompt_targets:
|
||||
|
|
@ -408,22 +416,6 @@ properties:
|
|||
enum:
|
||||
- llm
|
||||
- prompt
|
||||
routing:
|
||||
type: object
|
||||
properties:
|
||||
llm_provider:
|
||||
type: string
|
||||
model:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
orchestration:
|
||||
type: object
|
||||
properties:
|
||||
llm_provider:
|
||||
type: string
|
||||
model:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
state_storage:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
|||
|
|
@ -90,16 +90,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
let routing_model_name: String = plano_config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.model.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ROUTING_MODEL_NAME.to_string());
|
||||
let overrides = plano_config.overrides.clone().unwrap_or_default();
|
||||
|
||||
// Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests
|
||||
let routing_model_name: String = overrides
|
||||
.router_model
|
||||
.as_deref()
|
||||
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
|
||||
.unwrap_or(DEFAULT_ROUTING_MODEL_NAME)
|
||||
.to_string();
|
||||
|
||||
let routing_llm_provider = plano_config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.model_provider.clone())
|
||||
.model_providers
|
||||
.iter()
|
||||
.find(|p| p.model.as_deref() == Some(routing_model_name.as_str()))
|
||||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||
|
|
@ -109,16 +114,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
routing_llm_provider,
|
||||
));
|
||||
|
||||
let orchestrator_model_name: String = plano_config
|
||||
.orchestration
|
||||
.as_ref()
|
||||
.and_then(|o| o.model.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_MODEL_NAME.to_string());
|
||||
// Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests
|
||||
let orchestrator_model_name: String = overrides
|
||||
.orchestrator_model
|
||||
.as_deref()
|
||||
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
|
||||
.unwrap_or(DEFAULT_ORCHESTRATOR_MODEL_NAME)
|
||||
.to_string();
|
||||
|
||||
let orchestrator_llm_provider: String = plano_config
|
||||
.orchestration
|
||||
.as_ref()
|
||||
.and_then(|o| o.model_provider.clone())
|
||||
.model_providers
|
||||
.iter()
|
||||
.find(|p| p.model.as_deref() == Some(orchestrator_model_name.as_str()))
|
||||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_LLM_PROVIDER.to_string());
|
||||
|
||||
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
|
||||
|
|
|
|||
|
|
@ -7,18 +7,6 @@ use crate::api::open_ai::{
|
|||
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Routing {
|
||||
pub model_provider: Option<String>,
|
||||
pub model: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Orchestration {
|
||||
pub model_provider: Option<String>,
|
||||
pub model: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelAlias {
|
||||
pub target: String,
|
||||
|
|
@ -78,8 +66,6 @@ pub struct Configuration {
|
|||
pub ratelimits: Option<Vec<Ratelimit>>,
|
||||
pub tracing: Option<Tracing>,
|
||||
pub mode: Option<GatewayMode>,
|
||||
pub routing: Option<Routing>,
|
||||
pub orchestration: Option<Orchestration>,
|
||||
pub agents: Option<Vec<Agent>>,
|
||||
pub filters: Option<Vec<Agent>>,
|
||||
pub listeners: Vec<Listener>,
|
||||
|
|
@ -91,6 +77,8 @@ pub struct Overrides {
|
|||
pub prompt_target_intent_matching_threshold: Option<f64>,
|
||||
pub optimize_context_window: Option<bool>,
|
||||
pub use_agent_orchestrator: Option<bool>,
|
||||
pub router_model: Option<String>,
|
||||
pub orchestrator_model: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
|
|
@ -244,6 +232,8 @@ pub enum LlmProviderType {
|
|||
Qwen,
|
||||
#[serde(rename = "amazon_bedrock")]
|
||||
AmazonBedrock,
|
||||
#[serde(rename = "plano")]
|
||||
Plano,
|
||||
}
|
||||
|
||||
impl Display for LlmProviderType {
|
||||
|
|
@ -264,6 +254,7 @@ impl Display for LlmProviderType {
|
|||
LlmProviderType::Zhipu => write!(f, "zhipu"),
|
||||
LlmProviderType::Qwen => write!(f, "qwen"),
|
||||
LlmProviderType::AmazonBedrock => write!(f, "amazon_bedrock"),
|
||||
LlmProviderType::Plano => write!(f, "plano"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -272,7 +263,15 @@ impl LlmProviderType {
|
|||
/// Get the ProviderId for this LlmProviderType
|
||||
/// Used with the new function-based hermesllm API
|
||||
pub fn to_provider_id(&self) -> hermesllm::ProviderId {
|
||||
hermesllm::ProviderId::try_from(self.to_string().as_str())
|
||||
// Plano provider uses the same interface as Arch
|
||||
let provider_str = match self {
|
||||
LlmProviderType::Plano => "arch",
|
||||
other => {
|
||||
return hermesllm::ProviderId::try_from(other.to_string().as_str())
|
||||
.expect("LlmProviderType should always map to a valid ProviderId")
|
||||
}
|
||||
};
|
||||
hermesllm::ProviderId::try_from(provider_str)
|
||||
.expect("LlmProviderType should always map to a valid ProviderId")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
version: v0.3.0
|
||||
|
||||
orchestration:
|
||||
model: Plano-Orchestrator
|
||||
llm_provider: plano-orchestrator
|
||||
overrides:
|
||||
orchestrator_model: arch/Plano-Orchestrator
|
||||
|
||||
agents:
|
||||
- id: weather_agent
|
||||
|
|
@ -11,8 +10,7 @@ agents:
|
|||
url: http://localhost:10520
|
||||
|
||||
model_providers:
|
||||
- name: plano-orchestrator
|
||||
model: Plano-Orchestrator
|
||||
- model: arch/Plano-Orchestrator
|
||||
base_url: http://localhost:8000
|
||||
|
||||
- model: openai/gpt-5.2
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
version: v0.1.0
|
||||
|
||||
routing:
|
||||
model: Arch-Router
|
||||
llm_provider: arch-router
|
||||
overrides:
|
||||
router_model: Arch-Router
|
||||
|
||||
listeners:
|
||||
egress_traffic:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
version: v0.3.0
|
||||
|
||||
routing:
|
||||
model: Arch-Router
|
||||
llm_provider: arch-router
|
||||
overrides:
|
||||
router_model: plano/hf.co/katanemo/Arch-Router-1.5B.gguf:Q4_K_M
|
||||
|
||||
listeners:
|
||||
- type: model
|
||||
|
|
@ -11,8 +10,7 @@ listeners:
|
|||
|
||||
model_providers:
|
||||
|
||||
- name: arch-router
|
||||
model: arch/hf.co/katanemo/Arch-Router-1.5B.gguf:Q4_K_M
|
||||
- model: plano/hf.co/katanemo/Arch-Router-1.5B.gguf:Q4_K_M
|
||||
base_url: http://localhost:11434
|
||||
|
||||
- model: openai/gpt-4o-mini
|
||||
|
|
|
|||
|
|
@ -253,13 +253,11 @@ Using Ollama (recommended for local development)
|
|||
|
||||
.. code-block:: yaml
|
||||
|
||||
routing:
|
||||
model: Arch-Router
|
||||
llm_provider: arch-router
|
||||
overrides:
|
||||
router_model: arch/hf.co/katanemo/Arch-Router-1.5B.gguf:Q4_K_M
|
||||
|
||||
model_providers:
|
||||
- name: arch-router
|
||||
model: arch/hf.co/katanemo/Arch-Router-1.5B.gguf:Q4_K_M
|
||||
- model: arch/hf.co/katanemo/Arch-Router-1.5B.gguf:Q4_K_M
|
||||
base_url: http://localhost:11434
|
||||
|
||||
- model: openai/gpt-5.2
|
||||
|
|
@ -324,13 +322,11 @@ vLLM provides higher throughput and GPU optimizations suitable for production de
|
|||
|
||||
.. code-block:: yaml
|
||||
|
||||
routing:
|
||||
model: Arch-Router
|
||||
llm_provider: arch-router
|
||||
overrides:
|
||||
router_model: Arch-Router
|
||||
|
||||
model_providers:
|
||||
- name: arch-router
|
||||
model: Arch-Router
|
||||
- model: Arch-Router
|
||||
base_url: http://<your-server-ip>:10000
|
||||
|
||||
- model: openai/gpt-5.2
|
||||
|
|
|
|||
|
|
@ -401,13 +401,11 @@ Using vLLM
|
|||
|
||||
.. code-block:: yaml
|
||||
|
||||
orchestration:
|
||||
model: Plano-Orchestrator
|
||||
llm_provider: plano-orchestrator
|
||||
overrides:
|
||||
orchestrator_model: arch/Plano-Orchestrator
|
||||
|
||||
model_providers:
|
||||
- name: plano-orchestrator
|
||||
model: Plano-Orchestrator
|
||||
- model: arch/Plano-Orchestrator
|
||||
base_url: http://<your-server-ip>:8000
|
||||
|
||||
5. **Verify the server is running**
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue