more changes

This commit is contained in:
Adil Hafeez 2025-03-13 18:05:58 -07:00
parent 1d314c8cb7
commit 2179b5a162
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
14 changed files with 467 additions and 56 deletions

View file

@ -160,7 +160,7 @@ pub struct LlmProvider {
pub name: String,
pub provider_interface: LlmProviderType,
pub access_key: Option<String>,
pub model: String,
pub model: Option<String>,
pub default: Option<bool>,
pub stream: Option<bool>,
pub endpoint: Option<String>,
@ -177,6 +177,7 @@ impl Display for LlmProvider {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
pub endpoint: Option<String>,
pub agent_orchestrator: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -1,6 +1,7 @@
use crate::metrics::Metrics;
use crate::stream_context::StreamContext;
use common::configuration::Configuration;
use common::configuration::Overrides;
use common::consts::OTEL_COLLECTOR_HTTP;
use common::consts::OTEL_POST_PATH;
use common::http::CallArgs;
@ -31,6 +32,7 @@ pub struct FilterContext {
callouts: RefCell<HashMap<u32, CallContext>>,
llm_providers: Option<Rc<LlmProviders>>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
overrides: Rc<Option<Overrides>>,
}
impl FilterContext {
@ -40,6 +42,7 @@ impl FilterContext {
metrics: Rc::new(Metrics::new()),
llm_providers: None,
traces_queue: Arc::new(Mutex::new(VecDeque::new())),
overrides: Rc::new(None),
}
}
}
@ -69,6 +72,7 @@ impl RootContext for FilterContext {
};
ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default()));
self.overrides = Rc::new(config.overrides);
match config.llm_providers.try_into() {
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
@ -93,6 +97,7 @@ impl RootContext for FilterContext {
.expect("LLM Providers must exist when Streams are being created"),
),
Arc::clone(&self.traces_queue),
Rc::clone(&self.overrides),
)))
}

View file

@ -3,7 +3,7 @@ use common::api::open_ai::{
ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
Message, StreamOptions,
};
use common::configuration::LlmProvider;
use common::configuration::{LlmProvider, LlmProviderType, Overrides};
use common::consts::{
ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
@ -42,6 +42,7 @@ pub struct StreamContext {
request_body_sent_time: Option<u128>,
user_message: Option<Message>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
overrides: Rc<Option<Overrides>>,
}
impl StreamContext {
@ -50,10 +51,12 @@ impl StreamContext {
metrics: Rc<Metrics>,
llm_providers: Rc<LlmProviders>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
overrides: Rc<Option<Overrides>>,
) -> Self {
StreamContext {
context_id,
metrics,
overrides,
ratelimit_selector: None,
streaming_response: false,
response_tokens: 0,
@ -91,7 +94,12 @@ impl StreamContext {
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
.unwrap_or_default(),
self.llm_provider.as_ref().unwrap().name,
self.llm_provider.as_ref().unwrap().model
self.llm_provider
.as_ref()
.unwrap()
.model
.as_ref()
.unwrap_or(&String::new())
);
}
@ -184,24 +192,42 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
self.select_llm_provider();
let routing_header_value = self.get_http_request_header(ARCH_ROUTING_HEADER);
// if endpoint is not set then use provider name as routing header so envoy can resolve the cluster name
if self.llm_provider().endpoint.is_none() {
let use_agent_orchestrator = match self.overrides.as_ref() {
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
None => false,
};
if let Some(routing_header_value) = routing_header_value.as_ref() {
debug!("routing header already set: {}", routing_header_value);
self.llm_provider = Some(Rc::new(LlmProvider {
name: routing_header_value.to_string(),
provider_interface: LlmProviderType::OpenAI,
access_key: None,
endpoint: None,
model: None,
default: None,
stream: None,
port: None,
rate_limits: None,
}));
} else {
self.select_llm_provider();
debug!("setting routing header to: {}", self.llm_provider().name);
self.add_http_request_header(
ARCH_ROUTING_HEADER,
&self.llm_provider().provider_interface.to_string(),
);
} else {
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
}
if let Err(error) = self.modify_auth_headers() {
// ensure that the provider has an endpoint if the access key is missing else return a bad request
if self.llm_provider.as_ref().unwrap().endpoint.is_none() {
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
if let Err(error) = self.modify_auth_headers() {
// ensure that the provider has an endpoint if the access key is missing else return a bad request
if self.llm_provider.as_ref().unwrap().endpoint.is_none() && !use_agent_orchestrator
{
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
}
}
}
self.delete_content_length_header();
self.save_ratelimit_header();
@ -267,7 +293,7 @@ impl HttpContext for StreamContext {
// remove metadata from the request body
//TODO: move this to prompt gateway
deserialized_body.metadata = None;
// deserialized_body.metadata = None;
// delete model key from message array
for message in deserialized_body.messages.iter_mut() {
message.model = None;
@ -280,10 +306,24 @@ impl HttpContext for StreamContext {
.last()
.cloned();
// override model name from the llm provider
deserialized_body
.model
.clone_from(&self.llm_provider.as_ref().unwrap().model);
let model_name = match self.llm_provider.as_ref() {
Some(llm_provider) => match llm_provider.model.as_ref() {
Some(model) => model,
None => "--",
},
None => "--",
};
deserialized_body.model = model_name.to_string();
// if use_agent_orchestrator || self.llm_provider.as_ref().unwrap().model.is_none() {
// deserialized_body.model = "None".to_string()
// } else {
// // override model name from the llm provider
// deserialized_body
// .model
// .clone_from(&self.llm_provider.as_ref().unwrap().model.as_ref().unwrap());
// }
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
trace!(

View file

@ -1,6 +1,8 @@
use crate::metrics::Metrics;
use crate::stream_context::StreamContext;
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget, Tracing};
use common::configuration::{
Configuration, Endpoint, Overrides, PromptGuards, PromptTarget, Tracing,
};
use common::http::Client;
use common::stats::Gauge;
use log::trace;
@ -21,6 +23,7 @@ pub struct FilterContext {
overrides: Rc<Option<Overrides>>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
endpoints: Rc<Option<HashMap<String, Endpoint>>>,
prompt_guards: Rc<PromptGuards>,
tracing: Rc<Option<Tracing>>,
}
@ -34,6 +37,7 @@ impl FilterContext {
prompt_targets: Rc::new(HashMap::new()),
overrides: Rc::new(None),
prompt_guards: Rc::new(PromptGuards::default()),
endpoints: Rc::new(None),
tracing: Rc::new(None),
}
}
@ -73,6 +77,7 @@ impl RootContext for FilterContext {
}
self.system_prompt = Rc::new(config.system_prompt);
self.prompt_targets = Rc::new(prompt_targets);
self.endpoints = Rc::new(config.endpoints);
if let Some(prompt_guards) = config.prompt_guards {
self.prompt_guards = Rc::new(prompt_guards)
@ -94,6 +99,7 @@ impl RootContext for FilterContext {
Rc::clone(&self.metrics),
Rc::clone(&self.system_prompt),
Rc::clone(&self.prompt_targets),
Rc::clone(&self.endpoints),
Rc::clone(&self.overrides),
Rc::clone(&self.tracing),
)))

View file

@ -4,7 +4,7 @@ use common::{
self, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
},
consts::{
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER,
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
MODEL_SERVER_NAME, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER, TOOL_ROLE,
TRACE_PARENT_HEADER, USER_ROLE,
@ -33,6 +33,28 @@ impl HttpContext for StreamContext {
// manipulate the body in benign ways e.g., compression.
self.set_http_request_header("content-length", None);
if let Some(overrides) = self.overrides.as_ref() {
if overrides.use_agent_orchestrator.unwrap_or_default() {
// get endpoint that has agent_orchestrator set to true
if let Some(endpoints) = self.endpoints.as_ref() {
let agent_orchestrator = endpoints
.iter()
.find(|(_, endpoint)| endpoint.agent_orchestrator.unwrap_or_default())
.map(|(name, _)| name.clone());
if let Some(agent_orchestrator_name) = agent_orchestrator {
debug!(
"Setting ARCH_PROVIDER_HINT_HEADER to {}",
agent_orchestrator_name
);
self.set_http_request_header(
ARCH_ROUTING_HEADER,
Some(&agent_orchestrator_name),
);
};
}
}
}
let request_path = self.get_http_request_header(":path").unwrap_or_default();
if request_path == HEALTHZ_PATH {
self.send_http_response(200, vec![], None);
@ -49,6 +71,7 @@ impl HttpContext for StreamContext {
self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER);
Action::Continue
}

View file

@ -4,7 +4,7 @@ use common::api::open_ai::{
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
ChatCompletionsResponse, Message, ToolCall,
};
use common::configuration::{Overrides, PromptTarget, Tracing};
use common::configuration::{Endpoint, Overrides, PromptTarget, Tracing};
use common::consts::{
API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY,
@ -46,6 +46,7 @@ pub struct StreamCallContext {
pub struct StreamContext {
system_prompt: Rc<Option<String>>,
pub prompt_targets: Rc<HashMap<String, PromptTarget>>,
pub endpoints: Rc<Option<HashMap<String, Endpoint>>>,
pub overrides: Rc<Option<Overrides>>,
pub metrics: Rc<Metrics>,
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
@ -72,6 +73,7 @@ impl StreamContext {
metrics: Rc<Metrics>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
endpoints: Rc<Option<HashMap<String, Endpoint>>>,
overrides: Rc<Option<Overrides>>,
tracing: Rc<Option<Tracing>>,
) -> Self {
@ -80,6 +82,7 @@ impl StreamContext {
metrics,
system_prompt,
prompt_targets,
endpoints,
callouts: RefCell::new(HashMap::new()),
chat_completions_request: None,
tool_calls: None,
@ -312,6 +315,51 @@ impl StreamContext {
callout_context.prompt_target_name =
Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone());
if let Some(overrides) = self.overrides.as_ref() {
if overrides.use_agent_orchestrator.unwrap_or_default() {
let mut metadata = HashMap::new();
metadata.insert("use_agent_orchestrator".to_string(), "true".to_string());
metadata.insert(
"Agent-Name".to_string(),
callout_context
.prompt_target_name
.as_ref()
.unwrap()
.to_string(),
);
if let Some(overrides) = self.overrides.as_ref() {
if overrides.optimize_context_window.unwrap_or_default() {
metadata.insert("optimize_context_window".to_string(), "true".to_string());
}
}
if let Some(overrides) = self.overrides.as_ref() {
if overrides.use_agent_orchestrator.unwrap_or_default() {
metadata.insert("use_agent_orchestrator".to_string(), "true".to_string());
}
}
let messages = self.construct_llm_messages(&callout_context);
let chat_completion_request = ChatCompletionsRequest {
model: callout_context.request_body.model.clone(),
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options.clone(),
metadata: Some(metadata),
};
let body_str = serde_json::to_string(&chat_completion_request).unwrap();
debug!("sending request to llm agent: {}", body_str);
self.set_http_request_body(0, self.request_body_size, body_str.as_bytes());
self.resume_http_request();
return;
}
}
self.schedule_api_call_request(callout_context);
}

View file

@ -8,7 +8,9 @@ pub fn filter_tool_params(tool_params: &Option<HashMap<String, Value>>) -> HashM
if tool_params.is_none() {
return HashMap::new();
}
tool_params.as_ref().unwrap()
tool_params
.as_ref()
.unwrap()
.iter()
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
.map(|(key, value)| match value {