Update arch_config and add tests for arch config file (#407)

This commit is contained in:
Adil Hafeez 2025-02-14 19:28:10 -08:00 committed by GitHub
parent d0a783cca8
commit e40b13be05
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 379 additions and 212 deletions

View file

@ -9,7 +9,6 @@ use crate::api::open_ai::{
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
pub listener: Listener,
pub endpoints: Option<HashMap<String, Endpoint>>,
pub llm_providers: Vec<LlmProvider>,
pub overrides: Option<Overrides>,
@ -48,32 +47,6 @@ pub struct ErrorTargetDetail {
pub endpoint: Option<EndpointDetails>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Listener {
pub address: String,
pub port: u16,
pub message_format: MessageFormat,
// pub connect_timeout: Option<DurationString>,
}
impl Default for Listener {
fn default() -> Self {
Listener {
address: "".to_string(),
port: 0,
message_format: MessageFormat::default(),
// connect_timeout: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum MessageFormat {
#[serde(rename = "huggingface")]
#[default]
Huggingface,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PromptGuards {
pub input_guards: HashMap<GuardType, GuardOptions>,
@ -353,16 +326,6 @@ mod test {
Some("/agent/summary".to_string())
);
let error_target = config.error_target.as_ref().unwrap();
assert_eq!(
error_target.endpoint.as_ref().unwrap().name,
"error_target_1".to_string()
);
assert_eq!(
error_target.endpoint.as_ref().unwrap().path,
Some("/error".to_string())
);
let tracing = config.tracing.as_ref().unwrap();
assert_eq!(tracing.sampling_rate.unwrap(), 0.1);

View file

@ -3,7 +3,10 @@ pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";
pub const TOOL_ROLE: &str = "tool";
pub const ASSISTANT_ROLE: &str = "assistant";
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 30000; // 30 seconds
pub const DEFAULT_TARGET_REQUEST_TIMEOUT_MS: u64 = 30000; // 30 seconds
pub const API_REQUEST_TIMEOUT_MS: u64 = 30000; // 30 seconds
pub const MODEL_SERVER_REQUEST_TIMEOUT_MS: u64 = 30000; // 30 seconds
pub const MODEL_SERVER_NAME: &str = "model_server";
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const MESSAGES_KEY: &str = "messages";

View file

@ -166,7 +166,7 @@ impl TraceData {
attributes: vec![Attribute {
key: "service.name".to_string(),
value: AttributeValue {
string_value: Some("upstream-llm".to_string()),
string_value: Some("egress_llm_traffic".to_string()),
},
}],
};

View file

@ -381,7 +381,7 @@ impl HttpContext for StreamContext {
Ok(traceparent) => {
let mut trace_data = common::tracing::TraceData::new();
let mut llm_span = Span::new(
"upstream_llm_time".to_string(),
"egress_traffic".to_string(),
Some(traceparent.trace_id),
Some(traceparent.parent_id),
self.request_body_sent_time.unwrap(),

View file

@ -6,7 +6,8 @@ use common::{
consts::{
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
MODEL_SERVER_NAME, REQUEST_ID_HEADER, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
MODEL_SERVER_NAME, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER, TOOL_ROLE,
TRACE_PARENT_HEADER, USER_ROLE,
},
errors::ServerError,
http::{CallArgs, Client},
@ -144,7 +145,10 @@ impl HttpContext for StreamContext {
if metadata.is_none() {
metadata = Some(HashMap::new());
}
metadata.as_mut().unwrap().insert("optimize_context_window".to_string(), "true".to_string());
metadata
.as_mut()
.unwrap()
.insert("optimize_context_window".to_string(), "true".to_string());
}
}
@ -170,12 +174,15 @@ impl HttpContext for StreamContext {
debug!("sending request to model server");
trace!("request body: {}", json_data);
let timeout_str = MODEL_SERVER_REQUEST_TIMEOUT_MS.to_string();
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME),
(":method", "POST"),
(":path", "/function_calling"),
("content-type", "application/json"),
(":authority", MODEL_SERVER_NAME),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
];
if self.request_id.is_some() {

View file

@ -6,9 +6,9 @@ use common::api::open_ai::{
};
use common::configuration::{Overrides, PromptTarget, Tracing};
use common::consts::{
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE,
TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY,
REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
};
use common::errors::ServerError;
use common::http::{CallArgs, Client};
@ -89,7 +89,7 @@ impl StreamContext {
streaming_response: false,
user_prompt: None,
is_chat_completions_request: false,
overrides: overrides,
overrides,
request_id: None,
traceparent: None,
_tracing: tracing,
@ -160,7 +160,7 @@ impl StreamContext {
callout_context.request_body.messages.clone(),
);
let arch_messages_json = serde_json::to_string(&params).unwrap();
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
let timeout_str = DEFAULT_TARGET_REQUEST_TIMEOUT_MS.to_string();
let mut headers = vec![
(":method", "POST"),
@ -302,6 +302,8 @@ impl StreamContext {
}
};
let timeout_str = API_REQUEST_TIMEOUT_MS.to_string();
let http_method_str = http_method.to_string();
let mut headers: HashMap<_, _> = [
(ARCH_UPSTREAM_HOST_HEADER, endpoint_details.name.as_str()),
@ -310,6 +312,7 @@ impl StreamContext {
(":authority", endpoint_details.name.as_str()),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
]
.into_iter()
.collect();

View file

@ -81,10 +81,11 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
(":path", "/function_calling"),
("content-type", "application/json"),
(":authority", "model_server"),
("x-envoy-upstream-rq-timeout-ms", "30000"),
]),
None,
None,
None,
Some(5000),
)
.returning(Some(1))
.expect_log(Some(LogLevel::Trace), None)
@ -387,10 +388,11 @@ fn prompt_gateway_request_to_llm_gateway() {
(":authority", "api_server"),
("x-envoy-max-retries", "3"),
(":path", "/weather"),
("x-envoy-upstream-rq-timeout-ms", "30000"),
]),
Some(expected_body),
None,
None,
Some(5000),
)
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)