add preliminary support for llm agents (#432)

Adil Hafeez 2025-03-19 15:21:34 -07:00 committed by GitHub
parent 8d66fefded
commit 84cd1df7bf
29 changed files with 1388 additions and 121 deletions


@@ -3,7 +3,7 @@ use common::api::open_ai::{
     ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
     Message, StreamOptions,
 };
-use common::configuration::LlmProvider;
+use common::configuration::{LlmProvider, LlmProviderType, Overrides};
 use common::consts::{
     ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
     RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
@@ -42,6 +42,7 @@ pub struct StreamContext {
     request_body_sent_time: Option<u128>,
     user_message: Option<Message>,
     traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
+    overrides: Rc<Option<Overrides>>,
 }

 impl StreamContext {
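The new `overrides` field threads gateway-level configuration into each per-request stream context. A minimal sketch of that plumbing, assuming `Overrides` carries only the `use_agent_orchestrator` flag visible in this diff (the real struct almost certainly has more fields):

use std::rc::Rc;

#[derive(Debug, Default)]
struct Overrides {
    use_agent_orchestrator: Option<bool>, // None = not configured
}

struct StreamContext {
    overrides: Rc<Option<Overrides>>,
}

fn main() {
    // One parsed copy of the config is shared across all request contexts.
    let overrides = Rc::new(Some(Overrides { use_agent_orchestrator: Some(true) }));
    let ctx = StreamContext { overrides: Rc::clone(&overrides) };

    // Mirrors the lookup added later in this commit.
    let use_agent_orchestrator = match ctx.overrides.as_ref() {
        Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
        None => false,
    };
    println!("agent orchestrator: {}", use_agent_orchestrator);
}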
@@ -50,10 +51,12 @@ impl StreamContext {
         metrics: Rc<Metrics>,
         llm_providers: Rc<LlmProviders>,
         traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
+        overrides: Rc<Option<Overrides>>,
     ) -> Self {
         StreamContext {
             context_id,
             metrics,
+            overrides,
             ratelimit_selector: None,
             streaming_response: false,
             response_tokens: 0,
@@ -91,7 +94,12 @@ impl StreamContext {
             self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
                 .unwrap_or_default(),
             self.llm_provider.as_ref().unwrap().name,
-            self.llm_provider.as_ref().unwrap().model
+            self.llm_provider
+                .as_ref()
+                .unwrap()
+                .model
+                .as_ref()
+                .unwrap_or(&String::new())
         );
     }

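Because the provider's `model` is now optional (a header-routed agent cluster may not pin one), the hint logging borrows the model with an empty-string fallback instead of unwrapping. A standalone sketch of the pattern, with illustrative values:

struct LlmProvider {
    name: String,
    model: Option<String>,
}

fn main() {
    let provider = LlmProvider { name: "openai".to_string(), model: None };
    let empty = String::new();
    // Same shape as the diff: Option<String> -> &String with a fallback,
    // so a provider without a configured model no longer panics.
    let model = provider.model.as_ref().unwrap_or(&empty);
    println!("provider={} model={:?}", provider.name, model);
}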
@@ -151,11 +159,11 @@ impl StreamContext {
         // Tokenize and record token count.
         let token_count = tokenizer::token_count(model, json_string).unwrap_or(0);
+        trace!("Recorded input token count: {}", token_count);

         // Record the token count to metrics.
         self.metrics
             .input_sequence_length
             .record(token_count as u64);
-        trace!("Recorded input token count: {}", token_count);

         // Check if rate limiting needs to be applied.
         if let Some(selector) = self.ratelimit_selector.take() {
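Behaviorally this hunk only moves the trace call so the count is logged as soon as it is computed; the metric itself is unchanged. A sketch of the resulting order, with a toy stand-in for the gateway's tokenizer:

fn token_count(_model: &str, text: &str) -> Option<usize> {
    // Toy stand-in: whitespace splitting, not the real tokenizer.
    Some(text.split_whitespace().count())
}

fn main() {
    let json_string = r#"{"messages":[{"role":"user","content":"hello there"}]}"#;
    let token_count = token_count("some-model", json_string).unwrap_or(0);
    // 1) log immediately after computing...
    println!("Recorded input token count: {}", token_count);
    // 2) ...then record to the metrics sink (a histogram in the gateway).
    let _ = token_count as u64;
}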
@@ -184,24 +192,41 @@ impl HttpContext for StreamContext {
             return Action::Continue;
         }

-        self.select_llm_provider();
+        let routing_header_value = self.get_http_request_header(ARCH_ROUTING_HEADER);

-        // if endpoint is not set then use provider name as routing header so envoy can resolve the cluster name
-        if self.llm_provider().endpoint.is_none() {
+        let use_agent_orchestrator = match self.overrides.as_ref() {
+            Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
+            None => false,
+        };
+
+        if let Some(routing_header_value) = routing_header_value.as_ref() {
+            debug!("routing header already set: {}", routing_header_value);
+            self.llm_provider = Some(Rc::new(LlmProvider {
+                name: routing_header_value.to_string(),
+                provider_interface: LlmProviderType::OpenAI,
+                access_key: None,
+                endpoint: None,
+                model: None,
+                default: None,
+                stream: None,
+                port: None,
+                rate_limits: None,
+            }));
+        } else {
+            self.select_llm_provider();
             self.add_http_request_header(
                 ARCH_ROUTING_HEADER,
                 &self.llm_provider().provider_interface.to_string(),
             );
-        } else {
-            self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
-        }

-        if let Err(error) = self.modify_auth_headers() {
-            // ensure that the provider has an endpoint if the access key is missing else return a bad request
-            if self.llm_provider.as_ref().unwrap().endpoint.is_none() {
-                self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
-            }
+            if let Err(error) = self.modify_auth_headers() {
+                // ensure that the provider has an endpoint if the access key is missing else return a bad request
+                if self.llm_provider.as_ref().unwrap().endpoint.is_none() && !use_agent_orchestrator
+                {
+                    self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
+                }
+            }
         }

         self.delete_content_length_header();
         self.save_ratelimit_header();
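This is the core of the agent change: if an earlier hop already set ARCH_ROUTING_HEADER, the gateway trusts it and synthesizes a minimal OpenAI-compatible provider named after that cluster; otherwise it selects a provider itself, and a missing access key is only fatal when there is no endpoint and no agent orchestrator to defer to. A condensed sketch of that decision, with the types and cluster name invented for illustration:

#[derive(Debug)]
struct LlmProvider {
    name: String,
    endpoint: Option<String>,
}

fn route(routing_header: Option<&str>, use_agent_orchestrator: bool) -> LlmProvider {
    match routing_header {
        // Upstream already chose the target cluster: trust the header and
        // treat the target as OpenAI-compatible.
        Some(cluster) => LlmProvider { name: cluster.to_string(), endpoint: None },
        // No header: pick a provider locally (stubbed here) and apply the
        // relaxed auth check from the diff.
        None => {
            let provider = LlmProvider { name: "selected-provider".to_string(), endpoint: None };
            if provider.endpoint.is_none() && !use_agent_orchestrator {
                eprintln!("would send 400: no endpoint and auth headers failed");
            }
            provider
        }
    }
}

fn main() {
    println!("{:?}", route(Some("agent-cluster"), true));
    println!("{:?}", route(None, false));
}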
@@ -230,34 +255,38 @@ impl HttpContext for StreamContext {
             return Action::Continue;
         }
+
+        let body_bytes = match self.get_http_request_body(0, body_size) {
+            Some(body_bytes) => body_bytes,
+            None => {
+                self.send_server_error(
+                    ServerError::LogicError(format!(
+                        "Failed to obtain body bytes even though body_size is {}",
+                        body_size
+                    )),
+                    None,
+                );
+                return Action::Pause;
+            }
+        };

         // Deserialize body into spec.
         // Currently OpenAI API.
         let mut deserialized_body: ChatCompletionsRequest =
-            match self.get_http_request_body(0, body_size) {
-                Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
-                    Ok(deserialized) => deserialized,
-                    Err(e) => {
-                        self.send_server_error(
-                            ServerError::Deserialization(e),
-                            Some(StatusCode::BAD_REQUEST),
-                        );
-                        return Action::Pause;
-                    }
-                },
-                None => {
-                    self.send_server_error(
-                        ServerError::LogicError(format!(
-                            "Failed to obtain body bytes even though body_size is {}",
-                            body_size
-                        )),
-                        None,
-                    );
-                    return Action::Pause;
-                }
-            };
+            match serde_json::from_slice(&body_bytes) {
+                Ok(deserialized) => deserialized,
+                Err(e) => {
+                    debug!("body str: {}", String::from_utf8_lossy(&body_bytes));
+                    self.send_server_error(
+                        ServerError::Deserialization(e),
+                        Some(StatusCode::BAD_REQUEST),
+                    );
+                    return Action::Pause;
+                }
+            };

-        // remove metadata from the request body
-        deserialized_body.metadata = None;
+        //TODO: move this to prompt gateway
+        // deserialized_body.metadata = None;

         // delete model key from message array
         for message in deserialized_body.messages.iter_mut() {
             message.model = None;
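Body handling is now two explicit steps: fetch the raw bytes (pausing with a logic error if they cannot be produced), then deserialize, logging the raw body on a parse failure before returning 400. A self-contained sketch of the parse step, with a cut-down request type standing in for ChatCompletionsRequest:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ChatCompletionsRequest {
    model: Option<String>,
    messages: Vec<Message>,
}

#[derive(Debug, Deserialize)]
struct Message {
    role: String,
    content: String,
}

fn parse_request(body_bytes: &[u8]) -> Result<ChatCompletionsRequest, serde_json::Error> {
    serde_json::from_slice(body_bytes).map_err(|e| {
        // New in this commit: dump the offending body for debugging.
        eprintln!("body str: {}", String::from_utf8_lossy(body_bytes));
        e
    })
}

fn main() {
    let ok = br#"{"messages":[{"role":"user","content":"hi"}]}"#;
    let bad = b"not json";
    println!("{:?}", parse_request(ok));
    println!("parse of bad body ok: {}", parse_request(bad).is_ok());
}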
@@ -270,10 +299,16 @@ impl HttpContext for StreamContext {
                 .last()
                 .cloned();

         // override model name from the llm provider
-        deserialized_body
-            .model
-            .clone_from(&self.llm_provider.as_ref().unwrap().model);
+        let model_name = match self.llm_provider.as_ref() {
+            Some(llm_provider) => match llm_provider.model.as_ref() {
+                Some(model) => model,
+                None => "--",
+            },
+            None => "--",
+        };
+
+        deserialized_body.model = model_name.to_string();

         let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
         trace!(
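The model override no longer unwraps: it falls back to a "--" placeholder when either the provider or its model is unset, as with a header-routed agent cluster. A sketch mirroring the nested match from the diff (the model name is illustrative):

struct LlmProvider {
    model: Option<String>,
}

fn effective_model(provider: Option<&LlmProvider>) -> String {
    // Mirrors the diff: "--" is the visible placeholder when nothing is set.
    match provider {
        Some(llm_provider) => match llm_provider.model.as_ref() {
            Some(model) => model.clone(),
            None => "--".to_string(),
        },
        None => "--".to_string(),
    }
}

fn main() {
    let p = LlmProvider { model: Some("gpt-4o-mini".to_string()) };
    assert_eq!(effective_model(Some(&p)), "gpt-4o-mini");
    assert_eq!(effective_model(None), "--");
    println!("ok");
}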
@@ -469,6 +504,10 @@ impl HttpContext for StreamContext {
         };

         if self.streaming_response {
+            if body_utf8 == "data: [DONE]\n" {
+                return Action::Continue;
+            }
+
             let chat_completions_chunk_response_events =
                 match ChatCompletionStreamResponseServerEvents::try_from(body_utf8.as_str()) {
                     Ok(response) => response,
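OpenAI-style SSE streams terminate with a literal "data: [DONE]" frame that is not valid JSON, so the new guard passes it through before any event parsing is attempted. A tiny sketch of the guard:

fn handle_chunk(body_utf8: &str) -> &'static str {
    // The terminal sentinel is plain text, not JSON: forward it untouched.
    if body_utf8 == "data: [DONE]\n" {
        return "continue";
    }
    "parse stream events"
}

fn main() {
    println!("{}", handle_chunk("data: [DONE]\n"));
    println!("{}", handle_chunk("data: {\"choices\":[]}\n"));
}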
@@ -482,7 +521,10 @@ impl HttpContext for StreamContext {
             };

             if chat_completions_chunk_response_events.events.is_empty() {
-                debug!("empty streaming response");
+                debug!(
+                    "couldn't parse any streaming events: body str: {}",
+                    body_utf8
+                );
                 return Action::Continue;
             }

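When a chunk yields no parsable events, the log now includes the raw body instead of the old generic message, which makes malformed upstream frames debuggable. A sketch of the framing this diagnostic sits behind, assuming simple "data: "-prefixed SSE lines:

fn parse_events(body_utf8: &str) -> Vec<String> {
    body_utf8
        .lines()
        .filter_map(|line| line.strip_prefix("data: "))
        .filter(|payload| !payload.is_empty() && *payload != "[DONE]")
        .map(str::to_string)
        .collect()
}

fn main() {
    let body = "garbage that is not an SSE frame";
    let events = parse_events(body);
    if events.is_empty() {
        // Matches the new diagnostic added in this hunk.
        println!("couldn't parse any streaming events: body str: {}", body);
    }
}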