mirror of
https://github.com/katanemo/plano.git
synced 2026-06-26 15:39:40 +02:00
more changes
This commit is contained in:
parent
1d314c8cb7
commit
2179b5a162
14 changed files with 467 additions and 56 deletions
|
|
@ -160,7 +160,7 @@ pub struct LlmProvider {
|
|||
pub name: String,
|
||||
pub provider_interface: LlmProviderType,
|
||||
pub access_key: Option<String>,
|
||||
pub model: String,
|
||||
pub model: Option<String>,
|
||||
pub default: Option<bool>,
|
||||
pub stream: Option<bool>,
|
||||
pub endpoint: Option<String>,
|
||||
|
|
@ -177,6 +177,7 @@ impl Display for LlmProvider {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Endpoint {
|
||||
pub endpoint: Option<String>,
|
||||
pub agent_orchestrator: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use crate::metrics::Metrics;
|
||||
use crate::stream_context::StreamContext;
|
||||
use common::configuration::Configuration;
|
||||
use common::configuration::Overrides;
|
||||
use common::consts::OTEL_COLLECTOR_HTTP;
|
||||
use common::consts::OTEL_POST_PATH;
|
||||
use common::http::CallArgs;
|
||||
|
|
@ -31,6 +32,7 @@ pub struct FilterContext {
|
|||
callouts: RefCell<HashMap<u32, CallContext>>,
|
||||
llm_providers: Option<Rc<LlmProviders>>,
|
||||
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
}
|
||||
|
||||
impl FilterContext {
|
||||
|
|
@ -40,6 +42,7 @@ impl FilterContext {
|
|||
metrics: Rc::new(Metrics::new()),
|
||||
llm_providers: None,
|
||||
traces_queue: Arc::new(Mutex::new(VecDeque::new())),
|
||||
overrides: Rc::new(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -69,6 +72,7 @@ impl RootContext for FilterContext {
|
|||
};
|
||||
|
||||
ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default()));
|
||||
self.overrides = Rc::new(config.overrides);
|
||||
|
||||
match config.llm_providers.try_into() {
|
||||
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
|
||||
|
|
@ -93,6 +97,7 @@ impl RootContext for FilterContext {
|
|||
.expect("LLM Providers must exist when Streams are being created"),
|
||||
),
|
||||
Arc::clone(&self.traces_queue),
|
||||
Rc::clone(&self.overrides),
|
||||
)))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ use common::api::open_ai::{
|
|||
ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
|
||||
Message, StreamOptions,
|
||||
};
|
||||
use common::configuration::LlmProvider;
|
||||
use common::configuration::{LlmProvider, LlmProviderType, Overrides};
|
||||
use common::consts::{
|
||||
ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
|
||||
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||
|
|
@ -42,6 +42,7 @@ pub struct StreamContext {
|
|||
request_body_sent_time: Option<u128>,
|
||||
user_message: Option<Message>,
|
||||
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -50,10 +51,12 @@ impl StreamContext {
|
|||
metrics: Rc<Metrics>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
) -> Self {
|
||||
StreamContext {
|
||||
context_id,
|
||||
metrics,
|
||||
overrides,
|
||||
ratelimit_selector: None,
|
||||
streaming_response: false,
|
||||
response_tokens: 0,
|
||||
|
|
@ -91,7 +94,12 @@ impl StreamContext {
|
|||
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
.unwrap_or_default(),
|
||||
self.llm_provider.as_ref().unwrap().name,
|
||||
self.llm_provider.as_ref().unwrap().model
|
||||
self.llm_provider
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.model
|
||||
.as_ref()
|
||||
.unwrap_or(&String::new())
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -184,24 +192,42 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
self.select_llm_provider();
|
||||
let routing_header_value = self.get_http_request_header(ARCH_ROUTING_HEADER);
|
||||
|
||||
// if endpoint is not set then use provider name as routing header so envoy can resolve the cluster name
|
||||
if self.llm_provider().endpoint.is_none() {
|
||||
let use_agent_orchestrator = match self.overrides.as_ref() {
|
||||
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
|
||||
None => false,
|
||||
};
|
||||
|
||||
if let Some(routing_header_value) = routing_header_value.as_ref() {
|
||||
debug!("routing header already set: {}", routing_header_value);
|
||||
self.llm_provider = Some(Rc::new(LlmProvider {
|
||||
name: routing_header_value.to_string(),
|
||||
provider_interface: LlmProviderType::OpenAI,
|
||||
access_key: None,
|
||||
endpoint: None,
|
||||
model: None,
|
||||
default: None,
|
||||
stream: None,
|
||||
port: None,
|
||||
rate_limits: None,
|
||||
}));
|
||||
} else {
|
||||
self.select_llm_provider();
|
||||
debug!("setting routing header to: {}", self.llm_provider().name);
|
||||
self.add_http_request_header(
|
||||
ARCH_ROUTING_HEADER,
|
||||
&self.llm_provider().provider_interface.to_string(),
|
||||
);
|
||||
} else {
|
||||
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
|
||||
}
|
||||
|
||||
if let Err(error) = self.modify_auth_headers() {
|
||||
// ensure that the provider has an endpoint if the access key is missing else return a bad request
|
||||
if self.llm_provider.as_ref().unwrap().endpoint.is_none() {
|
||||
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
|
||||
if let Err(error) = self.modify_auth_headers() {
|
||||
// ensure that the provider has an endpoint if the access key is missing else return a bad request
|
||||
if self.llm_provider.as_ref().unwrap().endpoint.is_none() && !use_agent_orchestrator
|
||||
{
|
||||
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.delete_content_length_header();
|
||||
self.save_ratelimit_header();
|
||||
|
||||
|
|
@ -267,7 +293,7 @@ impl HttpContext for StreamContext {
|
|||
|
||||
// remove metadata from the request body
|
||||
//TODO: move this to prompt gateway
|
||||
deserialized_body.metadata = None;
|
||||
// deserialized_body.metadata = None;
|
||||
// delete model key from message array
|
||||
for message in deserialized_body.messages.iter_mut() {
|
||||
message.model = None;
|
||||
|
|
@ -280,10 +306,24 @@ impl HttpContext for StreamContext {
|
|||
.last()
|
||||
.cloned();
|
||||
|
||||
// override model name from the llm provider
|
||||
deserialized_body
|
||||
.model
|
||||
.clone_from(&self.llm_provider.as_ref().unwrap().model);
|
||||
let model_name = match self.llm_provider.as_ref() {
|
||||
Some(llm_provider) => match llm_provider.model.as_ref() {
|
||||
Some(model) => model,
|
||||
None => "--",
|
||||
},
|
||||
None => "--",
|
||||
};
|
||||
|
||||
deserialized_body.model = model_name.to_string();
|
||||
|
||||
// if use_agent_orchestrator || self.llm_provider.as_ref().unwrap().model.is_none() {
|
||||
// deserialized_body.model = "None".to_string()
|
||||
// } else {
|
||||
// // override model name from the llm provider
|
||||
// deserialized_body
|
||||
// .model
|
||||
// .clone_from(&self.llm_provider.as_ref().unwrap().model.as_ref().unwrap());
|
||||
// }
|
||||
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
|
||||
|
||||
trace!(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
use crate::metrics::Metrics;
|
||||
use crate::stream_context::StreamContext;
|
||||
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget, Tracing};
|
||||
use common::configuration::{
|
||||
Configuration, Endpoint, Overrides, PromptGuards, PromptTarget, Tracing,
|
||||
};
|
||||
use common::http::Client;
|
||||
use common::stats::Gauge;
|
||||
use log::trace;
|
||||
|
|
@ -21,6 +23,7 @@ pub struct FilterContext {
|
|||
overrides: Rc<Option<Overrides>>,
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
endpoints: Rc<Option<HashMap<String, Endpoint>>>,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
tracing: Rc<Option<Tracing>>,
|
||||
}
|
||||
|
|
@ -34,6 +37,7 @@ impl FilterContext {
|
|||
prompt_targets: Rc::new(HashMap::new()),
|
||||
overrides: Rc::new(None),
|
||||
prompt_guards: Rc::new(PromptGuards::default()),
|
||||
endpoints: Rc::new(None),
|
||||
tracing: Rc::new(None),
|
||||
}
|
||||
}
|
||||
|
|
@ -73,6 +77,7 @@ impl RootContext for FilterContext {
|
|||
}
|
||||
self.system_prompt = Rc::new(config.system_prompt);
|
||||
self.prompt_targets = Rc::new(prompt_targets);
|
||||
self.endpoints = Rc::new(config.endpoints);
|
||||
|
||||
if let Some(prompt_guards) = config.prompt_guards {
|
||||
self.prompt_guards = Rc::new(prompt_guards)
|
||||
|
|
@ -94,6 +99,7 @@ impl RootContext for FilterContext {
|
|||
Rc::clone(&self.metrics),
|
||||
Rc::clone(&self.system_prompt),
|
||||
Rc::clone(&self.prompt_targets),
|
||||
Rc::clone(&self.endpoints),
|
||||
Rc::clone(&self.overrides),
|
||||
Rc::clone(&self.tracing),
|
||||
)))
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use common::{
|
|||
self, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
|
||||
},
|
||||
consts::{
|
||||
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER,
|
||||
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
|
||||
MODEL_SERVER_NAME, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER, TOOL_ROLE,
|
||||
TRACE_PARENT_HEADER, USER_ROLE,
|
||||
|
|
@ -33,6 +33,28 @@ impl HttpContext for StreamContext {
|
|||
// manipulate the body in benign ways e.g., compression.
|
||||
self.set_http_request_header("content-length", None);
|
||||
|
||||
if let Some(overrides) = self.overrides.as_ref() {
|
||||
if overrides.use_agent_orchestrator.unwrap_or_default() {
|
||||
// get endpoint that has agent_orchestrator set to true
|
||||
if let Some(endpoints) = self.endpoints.as_ref() {
|
||||
let agent_orchestrator = endpoints
|
||||
.iter()
|
||||
.find(|(_, endpoint)| endpoint.agent_orchestrator.unwrap_or_default())
|
||||
.map(|(name, _)| name.clone());
|
||||
if let Some(agent_orchestrator_name) = agent_orchestrator {
|
||||
debug!(
|
||||
"Setting ARCH_PROVIDER_HINT_HEADER to {}",
|
||||
agent_orchestrator_name
|
||||
);
|
||||
self.set_http_request_header(
|
||||
ARCH_ROUTING_HEADER,
|
||||
Some(&agent_orchestrator_name),
|
||||
);
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let request_path = self.get_http_request_header(":path").unwrap_or_default();
|
||||
if request_path == HEALTHZ_PATH {
|
||||
self.send_http_response(200, vec![], None);
|
||||
|
|
@ -49,6 +71,7 @@ impl HttpContext for StreamContext {
|
|||
|
||||
self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
|
||||
self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER);
|
||||
|
||||
Action::Continue
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use common::api::open_ai::{
|
|||
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
|
||||
ChatCompletionsResponse, Message, ToolCall,
|
||||
};
|
||||
use common::configuration::{Overrides, PromptTarget, Tracing};
|
||||
use common::configuration::{Endpoint, Overrides, PromptTarget, Tracing};
|
||||
use common::consts::{
|
||||
API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY,
|
||||
|
|
@ -46,6 +46,7 @@ pub struct StreamCallContext {
|
|||
pub struct StreamContext {
|
||||
system_prompt: Rc<Option<String>>,
|
||||
pub prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
pub endpoints: Rc<Option<HashMap<String, Endpoint>>>,
|
||||
pub overrides: Rc<Option<Overrides>>,
|
||||
pub metrics: Rc<Metrics>,
|
||||
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
|
||||
|
|
@ -72,6 +73,7 @@ impl StreamContext {
|
|||
metrics: Rc<Metrics>,
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
endpoints: Rc<Option<HashMap<String, Endpoint>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
tracing: Rc<Option<Tracing>>,
|
||||
) -> Self {
|
||||
|
|
@ -80,6 +82,7 @@ impl StreamContext {
|
|||
metrics,
|
||||
system_prompt,
|
||||
prompt_targets,
|
||||
endpoints,
|
||||
callouts: RefCell::new(HashMap::new()),
|
||||
chat_completions_request: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -312,6 +315,51 @@ impl StreamContext {
|
|||
callout_context.prompt_target_name =
|
||||
Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone());
|
||||
|
||||
if let Some(overrides) = self.overrides.as_ref() {
|
||||
if overrides.use_agent_orchestrator.unwrap_or_default() {
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert("use_agent_orchestrator".to_string(), "true".to_string());
|
||||
|
||||
metadata.insert(
|
||||
"Agent-Name".to_string(),
|
||||
callout_context
|
||||
.prompt_target_name
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.to_string(),
|
||||
);
|
||||
|
||||
if let Some(overrides) = self.overrides.as_ref() {
|
||||
if overrides.optimize_context_window.unwrap_or_default() {
|
||||
metadata.insert("optimize_context_window".to_string(), "true".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(overrides) = self.overrides.as_ref() {
|
||||
if overrides.use_agent_orchestrator.unwrap_or_default() {
|
||||
metadata.insert("use_agent_orchestrator".to_string(), "true".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let messages = self.construct_llm_messages(&callout_context);
|
||||
|
||||
let chat_completion_request = ChatCompletionsRequest {
|
||||
model: callout_context.request_body.model.clone(),
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options.clone(),
|
||||
metadata: Some(metadata),
|
||||
};
|
||||
|
||||
let body_str = serde_json::to_string(&chat_completion_request).unwrap();
|
||||
debug!("sending request to llm agent: {}", body_str);
|
||||
self.set_http_request_body(0, self.request_body_size, body_str.as_bytes());
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
self.schedule_api_call_request(callout_context);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,9 @@ pub fn filter_tool_params(tool_params: &Option<HashMap<String, Value>>) -> HashM
|
|||
if tool_params.is_none() {
|
||||
return HashMap::new();
|
||||
}
|
||||
tool_params.as_ref().unwrap()
|
||||
tool_params
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
|
||||
.map(|(key, value)| match value {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue