Merge branch 'adil/fix_prompt_target_name' into adil/add_acm_demo

This commit is contained in:
Adil Hafeez 2025-01-17 17:50:40 -08:00
commit c532a5f4c7
47 changed files with 1133 additions and 1419 deletions

View file

@ -162,15 +162,34 @@ pub struct EmbeddingProviver {
pub model: String,
}
/// Kind of LLM provider an entry is configured with.
///
/// Variants are (de)serialized under their lowercase names: `openai` and `mistral`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmProviderType {
    /// Serialized as `openai`.
    OpenAI,
    /// Serialized as `mistral`.
    Mistral,
}
impl Display for LlmProviderType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LlmProviderType::OpenAI => write!(f, "openai"),
LlmProviderType::Mistral => write!(f, "mistral"),
}
}
}
/// Configuration entry describing a single LLM provider.
///
/// The type derives serde `Serialize`/`Deserialize`, so the field names below are
/// also the keys of its serialized form.
// NOTE(review): this listing looks like a rendered diff, so `provider` and
// `provider_interface` may be the before/after versions of one field — confirm
// against the real file before relying on both being present.
#[derive(Debug, Clone, Serialize, Deserialize)]
//TODO: use enum for model, but if there is a new model, we need to update the code
pub struct LlmProvider {
/// Identifier of this provider entry (e.g. `open-ai-gpt-4`).
pub name: String,
/// Provider name as a free-form string (e.g. `openai`).
pub provider: String,
/// Provider type, one of the `LlmProviderType` variants.
pub provider_interface: LlmProviderType,
/// Optional access key; presumably the credential used for auth headers — TODO confirm.
pub access_key: Option<String>,
/// Model name as a plain string (e.g. `gpt-4`); see the TODO above about using an enum.
pub model: String,
/// Optional flag; presumably marks the provider used when no hint is given — TODO confirm.
pub default: Option<bool>,
/// Optional flag; exact semantics are not visible here — presumably enables streaming, verify.
pub stream: Option<bool>,
/// Optional endpoint for the provider; `None` when not configured.
pub endpoint: Option<String>,
/// Optional port; presumably paired with `endpoint` — TODO confirm.
pub port: Option<u16>,
/// Optional rate-limit settings for this provider.
pub rate_limits: Option<LlmRatelimit>,
}

View file

@ -80,7 +80,7 @@ impl StreamContext {
fn select_llm_provider(&mut self) {
let provider_hint = self
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
.map(|provider_name| provider_name.into());
.map(|llm_name| llm_name.into());
debug!("llm provider hint: {:?}", provider_hint);
self.llm_provider = Some(routing::get_llm_provider(
@ -174,10 +174,22 @@ impl HttpContext for StreamContext {
// the lifecycle of the http request and response.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
self.select_llm_provider();
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
// if endpoint is not set then use provider name as routing header so envoy can resolve the cluster name
if self.llm_provider().endpoint.is_none() {
self.add_http_request_header(
ARCH_ROUTING_HEADER,
&self.llm_provider().provider_interface.to_string(),
);
} else {
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
}
if let Err(error) = self.modify_auth_headers() {
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
// ensure that the provider has an endpoint if the access key is missing else return a bad request
if self.llm_provider.as_ref().unwrap().endpoint.is_none() {
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
}
}
self.delete_content_length_header();
self.save_ratelimit_header();
@ -385,11 +397,13 @@ impl HttpContext for StreamContext {
self.llm_provider().name.to_string(),
);
llm_span.add_event(Event::new(
"time_to_first_token".to_string(),
self.ttft_time.unwrap(),
));
trace_data.add_span(llm_span);
if self.ttft_time.is_some() {
llm_span.add_event(Event::new(
"time_to_first_token".to_string(),
self.ttft_time.unwrap(),
));
trace_data.add_span(llm_span);
}
self.traces_queue.lock().unwrap().push_back(trace_data);
}

View file

@ -23,11 +23,15 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
Some("x-arch-llm-provider-hint"),
)
.returning(Some("default"))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(
Some(LogLevel::Debug),
Some("llm provider hint: Some(Default)"),
)
.expect_log(Some(LogLevel::Debug), Some("selected llm: open-ai-gpt-4"))
.expect_add_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider"),
Some("open-ai-gpt-4"),
Some("openai"),
)
.expect_replace_header_map_value(
Some(MapType::HttpRequestHeaders),
@ -46,8 +50,6 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
.returning(Some("/v1/chat/completions"))
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
.returning(None)
@ -110,12 +112,12 @@ endpoints:
llm_providers:
- name: open-ai-gpt-4
provider: openai
provider_interface: openai
access_key: secret_key
model: gpt-4
default: true
- name: open-ai-gpt-4o
provider: openai
provider_interface: openai
access_key: secret_key
model: gpt-4o

View file

@ -263,6 +263,10 @@ impl StreamContext {
);
}
// update prompt target name from the tool call
callout_context.prompt_target_name =
Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone());
self.schedule_api_call_request(callout_context);
}
@ -359,8 +363,8 @@ impl StreamContext {
let http_status = self
.get_http_call_response_header(":status")
.unwrap_or(StatusCode::OK.as_str().to_string());
debug!("api_call_response_handler: http_status: {}", http_status);
if http_status != StatusCode::OK.as_str() {
debug!("api_call_response_handler: http_status: {}", http_status);
if http_status != StatusCode::OK.as_str() {
warn!(
"api server responded with non 2xx status code: {}",
http_status

View file

@ -131,7 +131,7 @@ endpoints:
llm_providers:
- name: open-ai-gpt-4
provider: openai
provider_interface: openai
access_key: secret_key
model: gpt-4
default: true