mirror of
https://github.com/katanemo/plano.git
synced 2026-05-21 13:55:15 +02:00
refactor more
This commit is contained in:
parent
8c71acfc76
commit
efa677683a
5 changed files with 61 additions and 41 deletions
|
|
@ -172,7 +172,7 @@ impl AgentSelector {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use common::configuration::{AgentFilterChain, Listener};
|
use common::configuration::{AgentFilterChain, Listener, ListenerType};
|
||||||
|
|
||||||
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
|
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
|
||||||
Arc::new(OrchestratorService::new(
|
Arc::new(OrchestratorService::new(
|
||||||
|
|
@ -192,12 +192,12 @@ mod tests {
|
||||||
|
|
||||||
fn create_test_listener(name: &str, agents: Vec<AgentFilterChain>) -> Listener {
|
fn create_test_listener(name: &str, agents: Vec<AgentFilterChain>) -> Listener {
|
||||||
Listener {
|
Listener {
|
||||||
|
listener_type: ListenerType::Agent,
|
||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
agents: Some(agents),
|
agents: Some(agents),
|
||||||
filter_chain: None,
|
filter_chain: None,
|
||||||
port: 8080,
|
port: 8080,
|
||||||
router: None,
|
router: None,
|
||||||
filter_agents: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ use hyper::StatusCode;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use common::configuration::{Agent, AgentFilterChain, Listener};
|
use common::configuration::{Agent, AgentFilterChain, Listener, ListenerType};
|
||||||
|
|
||||||
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
|
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
|
||||||
Arc::new(OrchestratorService::new(
|
Arc::new(OrchestratorService::new(
|
||||||
|
|
@ -72,12 +72,12 @@ mod tests {
|
||||||
};
|
};
|
||||||
|
|
||||||
let listener = Listener {
|
let listener = Listener {
|
||||||
|
listener_type: ListenerType::Agent,
|
||||||
name: "test-listener".to_string(),
|
name: "test-listener".to_string(),
|
||||||
agents: Some(vec![agent_pipeline.clone()]),
|
agents: Some(vec![agent_pipeline.clone()]),
|
||||||
filter_chain: None,
|
filter_chain: None,
|
||||||
port: 8080,
|
port: 8080,
|
||||||
router: None,
|
router: None,
|
||||||
filter_agents: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let listeners = vec![listener];
|
let listeners = vec![listener];
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use common::configuration::{AgentFilterChain, Listener, ModelAlias, SpanAttributes};
|
use common::configuration::{Agent, AgentFilterChain, ModelAlias, SpanAttributes};
|
||||||
use common::consts::{
|
use common::consts::{
|
||||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||||
};
|
};
|
||||||
|
|
@ -45,7 +45,8 @@ pub async fn llm_chat(
|
||||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||||
span_attributes: Arc<Option<SpanAttributes>>,
|
span_attributes: Arc<Option<SpanAttributes>>,
|
||||||
state_storage: Option<Arc<dyn StateStorage>>,
|
state_storage: Option<Arc<dyn StateStorage>>,
|
||||||
listeners: Arc<RwLock<Vec<Listener>>>,
|
filter_chain: Arc<Option<Vec<String>>>,
|
||||||
|
filter_agents: Arc<HashMap<String, Agent>>,
|
||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||||
let request_path = request.uri().path().to_string();
|
let request_path = request.uri().path().to_string();
|
||||||
let request_headers = request.headers().clone();
|
let request_headers = request.headers().clone();
|
||||||
|
|
@ -85,7 +86,8 @@ pub async fn llm_chat(
|
||||||
request_id,
|
request_id,
|
||||||
request_path,
|
request_path,
|
||||||
request_headers,
|
request_headers,
|
||||||
listeners,
|
filter_chain,
|
||||||
|
filter_agents,
|
||||||
)
|
)
|
||||||
.instrument(request_span)
|
.instrument(request_span)
|
||||||
.await
|
.await
|
||||||
|
|
@ -103,7 +105,8 @@ async fn llm_chat_inner(
|
||||||
request_id: String,
|
request_id: String,
|
||||||
request_path: String,
|
request_path: String,
|
||||||
mut request_headers: hyper::HeaderMap,
|
mut request_headers: hyper::HeaderMap,
|
||||||
listeners: Arc<RwLock<Vec<Listener>>>,
|
filter_chain: Arc<Option<Vec<String>>>,
|
||||||
|
filter_agents: Arc<HashMap<String, Agent>>,
|
||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||||
// Set service name for LLM operations
|
// Set service name for LLM operations
|
||||||
set_service_name(operation_component::LLM);
|
set_service_name(operation_component::LLM);
|
||||||
|
|
@ -257,20 +260,9 @@ async fn llm_chat_inner(
|
||||||
debug!("removed plano_preference_config from metadata");
|
debug!("removed plano_preference_config from metadata");
|
||||||
}
|
}
|
||||||
|
|
||||||
// === Filter chain processing for model listeners ===
|
// === Filter chain processing for model listener ===
|
||||||
// Check if any model listener (no agents) has a filter_chain configured
|
|
||||||
{
|
{
|
||||||
let listeners_guard = listeners.read().await;
|
if let Some(ref fc) = *filter_chain {
|
||||||
let model_listener = listeners_guard
|
|
||||||
.iter()
|
|
||||||
.find(|l| l.agents.is_none() && l.filter_chain.is_some());
|
|
||||||
|
|
||||||
let filter_chain = model_listener.and_then(|l| l.filter_chain.clone());
|
|
||||||
let agent_map = model_listener
|
|
||||||
.and_then(|l| l.filter_agents.clone())
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
if let Some(ref fc) = filter_chain {
|
|
||||||
if !fc.is_empty() {
|
if !fc.is_empty() {
|
||||||
debug!(filter_chain = ?fc, "processing model listener filter chain");
|
debug!(filter_chain = ?fc, "processing model listener filter chain");
|
||||||
|
|
||||||
|
|
@ -288,7 +280,7 @@ async fn llm_chat_inner(
|
||||||
.process_filter_chain(
|
.process_filter_chain(
|
||||||
&messages,
|
&messages,
|
||||||
&temp_filter_chain,
|
&temp_filter_chain,
|
||||||
&agent_map,
|
&filter_agents,
|
||||||
&request_headers,
|
&request_headers,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||||
use brightstaff::state::StateStorage;
|
use brightstaff::state::StateStorage;
|
||||||
use brightstaff::utils::tracing::init_tracer;
|
use brightstaff::utils::tracing::init_tracer;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use common::configuration::{Agent, Configuration};
|
use common::configuration::{Agent, Configuration, ListenerType};
|
||||||
use common::consts::{
|
use common::consts::{
|
||||||
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
|
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
|
||||||
};
|
};
|
||||||
|
|
@ -87,26 +87,41 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
.map(|a| (a.id.clone(), a.clone()))
|
.map(|a| (a.id.clone(), a.clone()))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Resolve filter_agents on each listener at startup
|
|
||||||
let mut listeners_resolved = plano_config.listeners.clone();
|
|
||||||
for listener in &mut listeners_resolved {
|
|
||||||
if let Some(ref fc) = listener.filter_chain {
|
|
||||||
let filter_agents: HashMap<String, Agent> = fc
|
|
||||||
.iter()
|
|
||||||
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
|
|
||||||
.collect();
|
|
||||||
if !filter_agents.is_empty() {
|
|
||||||
listener.filter_agents = Some(filter_agents);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create expanded provider list for /v1/models endpoint
|
// Create expanded provider list for /v1/models endpoint
|
||||||
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
|
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
|
||||||
.expect("Failed to create LlmProviders");
|
.expect("Failed to create LlmProviders");
|
||||||
let llm_providers = Arc::new(RwLock::new(llm_providers));
|
let llm_providers = Arc::new(RwLock::new(llm_providers));
|
||||||
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
||||||
let listeners = Arc::new(RwLock::new(listeners_resolved));
|
|
||||||
|
// Resolve model listener filter chain and agents at startup
|
||||||
|
let model_listener_count = plano_config
|
||||||
|
.listeners
|
||||||
|
.iter()
|
||||||
|
.filter(|l| l.listener_type == ListenerType::Model)
|
||||||
|
.count();
|
||||||
|
assert!(
|
||||||
|
model_listener_count <= 1,
|
||||||
|
"only one model listener is allowed, found {}",
|
||||||
|
model_listener_count
|
||||||
|
);
|
||||||
|
let model_listener = plano_config
|
||||||
|
.listeners
|
||||||
|
.iter()
|
||||||
|
.find(|l| l.listener_type == ListenerType::Model);
|
||||||
|
let model_filter_chain: Arc<Option<Vec<String>>> =
|
||||||
|
Arc::new(model_listener.and_then(|l| l.filter_chain.clone()));
|
||||||
|
let model_filter_agents: Arc<HashMap<String, Agent>> = Arc::new(
|
||||||
|
model_filter_chain
|
||||||
|
.as_ref()
|
||||||
|
.as_ref()
|
||||||
|
.map(|fc| {
|
||||||
|
fc.iter()
|
||||||
|
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.unwrap_or_default(),
|
||||||
|
);
|
||||||
|
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
|
||||||
let llm_provider_url =
|
let llm_provider_url =
|
||||||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||||
|
|
||||||
|
|
@ -200,6 +215,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
|
||||||
let llm_providers = llm_providers.clone();
|
let llm_providers = llm_providers.clone();
|
||||||
let agents_list = combined_agents_filters_list.clone();
|
let agents_list = combined_agents_filters_list.clone();
|
||||||
|
let model_filter_chain = model_filter_chain.clone();
|
||||||
|
let model_filter_agents = model_filter_agents.clone();
|
||||||
let listeners = listeners.clone();
|
let listeners = listeners.clone();
|
||||||
let span_attributes = span_attributes.clone();
|
let span_attributes = span_attributes.clone();
|
||||||
let state_storage = state_storage.clone();
|
let state_storage = state_storage.clone();
|
||||||
|
|
@ -211,6 +228,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let llm_providers = llm_providers.clone();
|
let llm_providers = llm_providers.clone();
|
||||||
let model_aliases = Arc::clone(&model_aliases);
|
let model_aliases = Arc::clone(&model_aliases);
|
||||||
let agents_list = agents_list.clone();
|
let agents_list = agents_list.clone();
|
||||||
|
let model_filter_chain = model_filter_chain.clone();
|
||||||
|
let model_filter_agents = model_filter_agents.clone();
|
||||||
let listeners = listeners.clone();
|
let listeners = listeners.clone();
|
||||||
let span_attributes = span_attributes.clone();
|
let span_attributes = span_attributes.clone();
|
||||||
let state_storage = state_storage.clone();
|
let state_storage = state_storage.clone();
|
||||||
|
|
@ -269,7 +288,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
llm_providers,
|
llm_providers,
|
||||||
span_attributes,
|
span_attributes,
|
||||||
state_storage,
|
state_storage,
|
||||||
listeners,
|
model_filter_chain,
|
||||||
|
model_filter_agents,
|
||||||
)
|
)
|
||||||
.with_context(parent_cx)
|
.with_context(parent_cx)
|
||||||
.await
|
.await
|
||||||
|
|
|
||||||
|
|
@ -36,15 +36,23 @@ pub struct AgentFilterChain {
|
||||||
pub filter_chain: Option<Vec<String>>,
|
pub filter_chain: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum ListenerType {
|
||||||
|
Model,
|
||||||
|
Agent,
|
||||||
|
Prompt,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Listener {
|
pub struct Listener {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub listener_type: ListenerType,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub router: Option<String>,
|
pub router: Option<String>,
|
||||||
pub agents: Option<Vec<AgentFilterChain>>,
|
pub agents: Option<Vec<AgentFilterChain>>,
|
||||||
pub filter_chain: Option<Vec<String>>,
|
pub filter_chain: Option<Vec<String>>,
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
#[serde(skip)]
|
|
||||||
pub filter_agents: Option<HashMap<String, Agent>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue