diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index 561dbae3..630ed6d5 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -4,6 +4,9 @@ pub const USER_ROLE: &str = "user"; pub const TOOL_ROLE: &str = "tool"; pub const ASSISTANT_ROLE: &str = "assistant"; pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes +pub const DEFAULT_TARGET_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes +pub const API_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes +pub const MODEL_SERVER_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes pub const MODEL_SERVER_NAME: &str = "model_server"; pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider"; pub const MESSAGES_KEY: &str = "messages"; diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index 1ff7f91d..53a2d25b 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -6,7 +6,8 @@ use common::{ consts::{ ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH, - MODEL_SERVER_NAME, REQUEST_ID_HEADER, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE, + MODEL_SERVER_NAME, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER, TOOL_ROLE, + TRACE_PARENT_HEADER, USER_ROLE, }, errors::ServerError, http::{CallArgs, Client}, @@ -144,7 +145,10 @@ impl HttpContext for StreamContext { if metadata.is_none() { metadata = Some(HashMap::new()); } - metadata.as_mut().unwrap().insert("optimize_context_window".to_string(), "true".to_string()); + metadata + .as_mut() + .unwrap() + .insert("optimize_context_window".to_string(), "true".to_string()); } } @@ -170,12 +174,15 @@ impl HttpContext for StreamContext { debug!("sending request to model server"); trace!("request body: {}", json_data); + let timeout_str = MODEL_SERVER_REQUEST_TIMEOUT_MS.to_string(); + let mut headers = vec![ (ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME), 
(":method", "POST"), (":path", "/function_calling"), ("content-type", "application/json"), (":authority", MODEL_SERVER_NAME), + ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()), ]; if self.request_id.is_some() { diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index e6db7f59..d197b3e0 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -6,9 +6,9 @@ use common::api::open_ai::{ }; use common::configuration::{Overrides, PromptTarget, Tracing}; use common::consts::{ - ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, - ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, - TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE, + API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, + ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY, + REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE, }; use common::errors::ServerError; use common::http::{CallArgs, Client}; @@ -89,7 +89,7 @@ impl StreamContext { streaming_response: false, user_prompt: None, is_chat_completions_request: false, - overrides: overrides, + overrides, request_id: None, traceparent: None, _tracing: tracing, @@ -160,7 +160,7 @@ impl StreamContext { callout_context.request_body.messages.clone(), ); let arch_messages_json = serde_json::to_string(&params).unwrap(); - let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string(); + let timeout_str = DEFAULT_TARGET_REQUEST_TIMEOUT_MS.to_string(); let mut headers = vec![ (":method", "POST"), @@ -302,6 +302,8 @@ impl StreamContext { } }; + let timeout_str = API_REQUEST_TIMEOUT_MS.to_string(); + let http_method_str = http_method.to_string(); let mut headers: HashMap<_, _> = [ (ARCH_UPSTREAM_HOST_HEADER, endpoint_details.name.as_str()), @@ -310,6 +312,7 @@ impl StreamContext { (":authority", endpoint_details.name.as_str()), 
("content-type", "application/json"), ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()), ] .into_iter() .collect();