mirror of
https://github.com/katanemo/plano.git
synced 2026-05-21 13:55:15 +02:00
refactor more
This commit is contained in:
parent
8c71acfc76
commit
efa677683a
5 changed files with 61 additions and 41 deletions
|
|
@ -172,7 +172,7 @@ impl AgentSelector {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::configuration::{AgentFilterChain, Listener};
|
||||
use common::configuration::{AgentFilterChain, Listener, ListenerType};
|
||||
|
||||
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
|
||||
Arc::new(OrchestratorService::new(
|
||||
|
|
@ -192,12 +192,12 @@ mod tests {
|
|||
|
||||
/// Test helper: builds a `Listener` of type `ListenerType::Agent` with the
/// given `name` and `agents`, a fixed port of 8080, and no router or
/// listener-level filter chain.
// NOTE(review): this span is taken from a rendered diff, so the
// `listener_type` and `filter_agents` initialisers may belong to different
// revisions of `Listener` — confirm against the actual file.
fn create_test_listener(name: &str, agents: Vec<AgentFilterChain>) -> Listener {
|
||||
Listener {
|
||||
listener_type: ListenerType::Agent,
|
||||
name: name.to_string(),
|
||||
agents: Some(agents),
|
||||
filter_chain: None,
|
||||
port: 8080,
|
||||
router: None,
|
||||
filter_agents: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ use hyper::StatusCode;
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::configuration::{Agent, AgentFilterChain, Listener};
|
||||
use common::configuration::{Agent, AgentFilterChain, Listener, ListenerType};
|
||||
|
||||
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
|
||||
Arc::new(OrchestratorService::new(
|
||||
|
|
@ -72,12 +72,12 @@ mod tests {
|
|||
};
|
||||
|
||||
let listener = Listener {
|
||||
listener_type: ListenerType::Agent,
|
||||
name: "test-listener".to_string(),
|
||||
agents: Some(vec![agent_pipeline.clone()]),
|
||||
filter_chain: None,
|
||||
port: 8080,
|
||||
router: None,
|
||||
filter_agents: None,
|
||||
};
|
||||
|
||||
let listeners = vec![listener];
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{AgentFilterChain, Listener, ModelAlias, SpanAttributes};
|
||||
use common::configuration::{Agent, AgentFilterChain, ModelAlias, SpanAttributes};
|
||||
use common::consts::{
|
||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
|
|
@ -45,7 +45,8 @@ pub async fn llm_chat(
|
|||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
span_attributes: Arc<Option<SpanAttributes>>,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
listeners: Arc<RwLock<Vec<Listener>>>,
|
||||
filter_chain: Arc<Option<Vec<String>>>,
|
||||
filter_agents: Arc<HashMap<String, Agent>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_path = request.uri().path().to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
|
|
@ -85,7 +86,8 @@ pub async fn llm_chat(
|
|||
request_id,
|
||||
request_path,
|
||||
request_headers,
|
||||
listeners,
|
||||
filter_chain,
|
||||
filter_agents,
|
||||
)
|
||||
.instrument(request_span)
|
||||
.await
|
||||
|
|
@ -103,7 +105,8 @@ async fn llm_chat_inner(
|
|||
request_id: String,
|
||||
request_path: String,
|
||||
mut request_headers: hyper::HeaderMap,
|
||||
listeners: Arc<RwLock<Vec<Listener>>>,
|
||||
filter_chain: Arc<Option<Vec<String>>>,
|
||||
filter_agents: Arc<HashMap<String, Agent>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
// Set service name for LLM operations
|
||||
set_service_name(operation_component::LLM);
|
||||
|
|
@ -257,20 +260,9 @@ async fn llm_chat_inner(
|
|||
debug!("removed plano_preference_config from metadata");
|
||||
}
|
||||
|
||||
// === Filter chain processing for model listeners ===
|
||||
// Check if any model listener (no agents) has a filter_chain configured
|
||||
// === Filter chain processing for model listener ===
|
||||
{
|
||||
let listeners_guard = listeners.read().await;
|
||||
let model_listener = listeners_guard
|
||||
.iter()
|
||||
.find(|l| l.agents.is_none() && l.filter_chain.is_some());
|
||||
|
||||
let filter_chain = model_listener.and_then(|l| l.filter_chain.clone());
|
||||
let agent_map = model_listener
|
||||
.and_then(|l| l.filter_agents.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
if let Some(ref fc) = filter_chain {
|
||||
if let Some(ref fc) = *filter_chain {
|
||||
if !fc.is_empty() {
|
||||
debug!(filter_chain = ?fc, "processing model listener filter chain");
|
||||
|
||||
|
|
@ -288,7 +280,7 @@ async fn llm_chat_inner(
|
|||
.process_filter_chain(
|
||||
&messages,
|
||||
&temp_filter_chain,
|
||||
&agent_map,
|
||||
&filter_agents,
|
||||
&request_headers,
|
||||
)
|
||||
.await
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
|||
use brightstaff::state::StateStorage;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, Configuration};
|
||||
use common::configuration::{Agent, Configuration, ListenerType};
|
||||
use common::consts::{
|
||||
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
|
||||
};
|
||||
|
|
@ -87,26 +87,41 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
.map(|a| (a.id.clone(), a.clone()))
|
||||
.collect();
|
||||
|
||||
// Resolve filter_agents on each listener at startup
|
||||
let mut listeners_resolved = plano_config.listeners.clone();
|
||||
for listener in &mut listeners_resolved {
|
||||
if let Some(ref fc) = listener.filter_chain {
|
||||
let filter_agents: HashMap<String, Agent> = fc
|
||||
.iter()
|
||||
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
|
||||
.collect();
|
||||
if !filter_agents.is_empty() {
|
||||
listener.filter_agents = Some(filter_agents);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create expanded provider list for /v1/models endpoint
|
||||
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
|
||||
.expect("Failed to create LlmProviders");
|
||||
let llm_providers = Arc::new(RwLock::new(llm_providers));
|
||||
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
||||
let listeners = Arc::new(RwLock::new(listeners_resolved));
|
||||
|
||||
// Resolve model listener filter chain and agents at startup
|
||||
let model_listener_count = plano_config
|
||||
.listeners
|
||||
.iter()
|
||||
.filter(|l| l.listener_type == ListenerType::Model)
|
||||
.count();
|
||||
assert!(
|
||||
model_listener_count <= 1,
|
||||
"only one model listener is allowed, found {}",
|
||||
model_listener_count
|
||||
);
|
||||
let model_listener = plano_config
|
||||
.listeners
|
||||
.iter()
|
||||
.find(|l| l.listener_type == ListenerType::Model);
|
||||
let model_filter_chain: Arc<Option<Vec<String>>> =
|
||||
Arc::new(model_listener.and_then(|l| l.filter_chain.clone()));
|
||||
let model_filter_agents: Arc<HashMap<String, Agent>> = Arc::new(
|
||||
model_filter_chain
|
||||
.as_ref()
|
||||
.as_ref()
|
||||
.map(|fc| {
|
||||
fc.iter()
|
||||
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
|
||||
let llm_provider_url =
|
||||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
||||
|
|
@ -200,6 +215,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let agents_list = combined_agents_filters_list.clone();
|
||||
let model_filter_chain = model_filter_chain.clone();
|
||||
let model_filter_agents = model_filter_agents.clone();
|
||||
let listeners = listeners.clone();
|
||||
let span_attributes = span_attributes.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
|
|
@ -211,6 +228,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let llm_providers = llm_providers.clone();
|
||||
let model_aliases = Arc::clone(&model_aliases);
|
||||
let agents_list = agents_list.clone();
|
||||
let model_filter_chain = model_filter_chain.clone();
|
||||
let model_filter_agents = model_filter_agents.clone();
|
||||
let listeners = listeners.clone();
|
||||
let span_attributes = span_attributes.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
|
|
@ -269,7 +288,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
llm_providers,
|
||||
span_attributes,
|
||||
state_storage,
|
||||
listeners,
|
||||
model_filter_chain,
|
||||
model_filter_agents,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
|
|
|
|||
|
|
@ -36,15 +36,23 @@ pub struct AgentFilterChain {
|
|||
pub filter_chain: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Kind of a configured listener.
///
/// Because of `rename_all = "lowercase"`, serde reads and writes the variants
/// as `model`, `agent` and `prompt`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ListenerType {
|
||||
/// Model listener. Startup code in this change looks listeners up by this
/// variant and asserts that at most one of them exists.
Model,
|
||||
/// Agent listener. The test helpers in this change pair it with
/// `agents: Some(..)`; presumably it fronts agent pipelines — confirm.
Agent,
|
||||
/// Prompt listener. Its usage is not visible in this view.
Prompt,
|
||||
}
|
||||
|
||||
/// A single listener entry of the configuration.
// NOTE(review): this span is taken from a rendered diff that interleaves old
// and new lines; `listener_type` and `filter_agents` may not coexist in the
// real struct — confirm against the actual file.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Listener {
|
||||
/// Kind of listener; stored under the key `type` in the serialized form.
#[serde(rename = "type")]
|
||||
pub listener_type: ListenerType,
|
||||
/// Listener name.
pub name: String,
|
||||
/// Optional router setting; its meaning is not visible in this view.
pub router: Option<String>,
|
||||
/// Optional agent filter chains attached to this listener.
pub agents: Option<Vec<AgentFilterChain>>,
|
||||
/// Optional listener-level filter chain; startup code treats each entry as
/// an agent id to look up.
pub filter_chain: Option<Vec<String>>,
|
||||
/// Port number of the listener.
pub port: u16,
|
||||
/// Agents keyed by id. `serde(skip)` keeps this field out of both
/// serialization and deserialization, so it is never read from config.
#[serde(skip)]
|
||||
pub filter_agents: Option<HashMap<String, Agent>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue