use crate::metrics::Metrics; use common::configuration::{LlmProvider, LlmProviderType, Overrides}; use common::consts::{ ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH, RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, }; use common::errors::ServerError; use common::llm_providers::LlmProviders; use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; use hermesllm::providers::openai::types::{ChatCompletionsRequest, SseChatCompletionIter}; use hermesllm::providers::openai::types::{ ChatCompletionsResponse, ContentType, Message, StreamOptions, }; use hermesllm::Provider; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::hostcalls::get_current_time; use proxy_wasm::traits::*; use proxy_wasm::types::*; use std::collections::VecDeque; use std::num::NonZero; use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; pub struct StreamContext { context_id: u32, metrics: Rc, ratelimit_selector: Option
, streaming_response: bool, response_tokens: usize, is_chat_completions_request: bool, llm_providers: Rc, llm_provider: Option>, request_id: Option, start_time: SystemTime, ttft_duration: Option, ttft_time: Option, traceparent: Option, request_body_sent_time: Option, user_message: Option, traces_queue: Arc>>, overrides: Rc>, } impl StreamContext { pub fn new( context_id: u32, metrics: Rc, llm_providers: Rc, traces_queue: Arc>>, overrides: Rc>, ) -> Self { StreamContext { context_id, metrics, overrides, ratelimit_selector: None, streaming_response: false, response_tokens: 0, is_chat_completions_request: false, llm_providers, llm_provider: None, request_id: None, start_time: SystemTime::now(), ttft_duration: None, traceparent: None, ttft_time: None, user_message: None, traces_queue, request_body_sent_time: None, } } fn llm_provider(&self) -> &LlmProvider { self.llm_provider .as_ref() .expect("the provider should be set when asked for it") } fn select_llm_provider(&mut self) { let provider_hint = self .get_http_request_header(ARCH_PROVIDER_HINT_HEADER) .map(|llm_name| llm_name.into()); self.llm_provider = Some(routing::get_llm_provider( &self.llm_providers, provider_hint, )); match self.llm_provider.as_ref().unwrap().provider_interface { LlmProviderType::Groq => { if let Some(path) = self.get_http_request_header(":path") { if path.starts_with("/v1/") { let new_path = format!("/openai{}", path); self.set_http_request_header(":path", Some(new_path.as_str())); } } } LlmProviderType::Gemini => { if let Some(path) = self.get_http_request_header(":path") { if path == "/v1/chat/completions" { self.set_http_request_header( ":path", Some("/v1beta/openai/chat/completions"), ); } } } _ => {} } debug!( "request received: llm provider hint: {}, selected llm: {}, model: {}", self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER) .unwrap_or_default(), self.llm_provider.as_ref().unwrap().name, self.llm_provider .as_ref() .unwrap() .model .as_ref() .unwrap_or(&String::new()) ); } fn modify_auth_headers(&mut self) -> Result<(), ServerError> { let llm_provider_api_key_value = self.llm_provider() .access_key .as_ref() .ok_or(ServerError::BadRequest { why: format!( "No access key configured for selected LLM Provider \"{}\"", self.llm_provider() ), })?; let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value); self.set_http_request_header("Authorization", Some(&authorization_header_value)); Ok(()) } fn delete_content_length_header(&mut self) { // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it. // Server's generally throw away requests whose body length do not match the Content-Length header. // However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could // manipulate the body in benign ways e.g., compression. self.set_http_request_header("content-length", None); } fn save_ratelimit_header(&mut self) { self.ratelimit_selector = self .get_http_request_header(RATELIMIT_SELECTOR_HEADER_KEY) .and_then(|key| { self.get_http_request_header(&key) .map(|value| Header { key, value }) }); } fn send_server_error(&self, error: ServerError, override_status_code: Option) { warn!("server error occurred: {}", error); self.send_http_response( override_status_code .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) .as_u16() .into(), vec![], Some(format!("{error}").as_bytes()), ); } fn enforce_ratelimits( &mut self, model: &str, json_string: &str, ) -> Result<(), ratelimit::Error> { // Tokenize and record token count. let token_count = tokenizer::token_count(model, json_string).unwrap_or(0); debug!("Recorded input token count: {}", token_count); // Record the token count to metrics. self.metrics .input_sequence_length .record(token_count as u64); // Check if rate limiting needs to be applied. if let Some(selector) = self.ratelimit_selector.take() { log::debug!("Applying ratelimit for model: {}", model); ratelimit::ratelimits(None).read().unwrap().check_limit( model.to_owned(), selector, NonZero::new(token_count as u32).unwrap(), )?; } else { debug!("No rate limit applied for model: {}", model); } Ok(()) } } // HttpContext is the trait that allows the Rust code to interact with HTTP objects. impl HttpContext for StreamContext { // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto // the lifecycle of the http request and response. fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { let request_path = self.get_http_request_header(":path").unwrap_or_default(); if request_path == HEALTHZ_PATH { self.send_http_response(200, vec![], None); return Action::Continue; } self.is_chat_completions_request = CHAT_COMPLETIONS_PATH == request_path; let use_agent_orchestrator = match self.overrides.as_ref() { Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(), None => false, }; let routing_header_value = self.get_http_request_header(ARCH_ROUTING_HEADER); if routing_header_value.is_some() && !routing_header_value.as_ref().unwrap().is_empty() { let routing_header_value = routing_header_value.as_ref().unwrap(); info!("routing header already set: {}", routing_header_value); self.llm_provider = Some(Rc::new(LlmProvider { name: routing_header_value.to_string(), provider_interface: LlmProviderType::OpenAI, ..Default::default() })); } else { self.select_llm_provider(); if self.llm_provider().endpoint.is_some() { self.add_http_request_header( ARCH_ROUTING_HEADER, &self.llm_provider().name.to_string(), ); } else { self.add_http_request_header( ARCH_ROUTING_HEADER, &self.llm_provider().provider_interface.to_string(), ); } if let Err(error) = self.modify_auth_headers() { // ensure that the provider has an endpoint if the access key is missing else return a bad request if self.llm_provider.as_ref().unwrap().endpoint.is_none() && !use_agent_orchestrator && self.llm_provider.as_ref().unwrap().provider_interface != LlmProviderType::Arch { self.send_server_error(error, Some(StatusCode::BAD_REQUEST)); } } } self.delete_content_length_header(); self.save_ratelimit_header(); self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER); Action::Continue } fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { debug!( "on_http_request_body [S={}] bytes={} end_stream={}", self.context_id, body_size, end_of_stream ); // Let the client send the gateway all the data before sending to the LLM_provider. // TODO: consider a streaming API. if self.request_body_sent_time.is_none() { self.request_body_sent_time = Some(current_time_ns()); } if !end_of_stream { return Action::Pause; } if body_size == 0 { return Action::Continue; } let body_bytes = match self.get_http_request_body(0, body_size) { Some(body_bytes) => body_bytes, None => { self.send_server_error( ServerError::LogicError(format!( "Failed to obtain body bytes even though body_size is {}", body_size )), None, ); return Action::Pause; } }; let mut deserialized_body = match ChatCompletionsRequest::try_from(body_bytes.as_slice()) { Ok(deserialized) => deserialized, Err(e) => { debug!( "on_http_request_body: request body: {}", String::from_utf8_lossy(&body_bytes) ); self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST)); return Action::Pause; } }; self.user_message = deserialized_body .messages .iter() .filter(|m| m.role == "user") .last() .cloned(); let model_name = match self.llm_provider.as_ref() { Some(llm_provider) => llm_provider.model.as_ref(), None => None, }; let use_agent_orchestrator = match self.overrides.as_ref() { Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(), None => false, }; let model_requested = deserialized_body.model.clone(); deserialized_body.model = match model_name { Some(model_name) => model_name.clone(), None => { if use_agent_orchestrator { "agent_orchestrator".to_string() } else { self.send_server_error( ServerError::BadRequest { why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(), }, Some(StatusCode::BAD_REQUEST), ); return Action::Continue; } } }; info!( "on_http_request_body: provider: {}, model requested: {}, model selected: {}", self.llm_provider().name, model_requested, model_name.unwrap_or(&"None".to_string()), ); if deserialized_body.stream.unwrap_or_default() { self.streaming_response = true; } if deserialized_body.stream.unwrap_or_default() && deserialized_body.stream_options.is_none() { deserialized_body.stream_options = Some(StreamOptions { include_usage: true, }); } // only use the tokens from the messages, excluding the metadata and json tags let input_tokens_str = deserialized_body .messages .iter() .fold(String::new(), |acc, m| { acc + " " + m.content .as_ref() .unwrap_or(&ContentType::Text(String::new())) .to_string() .as_str() }); // enforce ratelimits on ingress if let Err(e) = self.enforce_ratelimits(&deserialized_body.model, input_tokens_str.as_str()) { self.send_server_error( ServerError::ExceededRatelimit(e), Some(StatusCode::TOO_MANY_REQUESTS), ); self.metrics.ratelimited_rq.increment(1); return Action::Continue; } let llm_provider_str = self.llm_provider().provider_interface.to_string(); let hermes_llm_provider = Provider::from(llm_provider_str.as_str()); // convert chat completion request to llm provider specific request let deserialized_body_bytes = match deserialized_body.to_bytes(hermes_llm_provider) { Ok(bytes) => bytes, Err(e) => { warn!("Failed to serialize request body: {}", e); self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST)); return Action::Pause; } }; self.set_http_request_body(0, body_size, &deserialized_body_bytes); Action::Continue } fn on_http_response_headers(&mut self, _num_headers: usize, end_of_stream: bool) -> Action { debug!( "on_http_response_headers [S={}] end_stream={}", self.context_id, end_of_stream ); self.set_property( vec!["metadata", "filter_metadata", "llm_filter", "user_prompt"], Some("hello world from filter".as_bytes()), ); Action::Continue } fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { debug!( "on_http_response_body [S={}] bytes={} end_stream={}", self.context_id, body_size, end_of_stream ); if self.request_body_sent_time.is_none() { debug!("on_http_response_body: request body not sent, not doing any processing in llm filter"); return Action::Continue; } if !self.is_chat_completions_request { info!("on_http_response_body: non-chatcompletion request"); return Action::Continue; } let current_time = get_current_time().unwrap(); if end_of_stream && body_size == 0 { // All streaming responses end with bytes=0 and end_stream=true // Record the latency for the request match current_time.duration_since(self.start_time) { Ok(duration) => { // Convert the duration to milliseconds let duration_ms = duration.as_millis(); info!("on_http_response_body: request latency: {}ms", duration_ms); // Record the latency to the latency histogram self.metrics.request_latency.record(duration_ms as u64); if self.response_tokens > 0 { // Compute the time per output token let tpot = duration_ms as u64 / self.response_tokens as u64; // Record the time per output token self.metrics.time_per_output_token.record(tpot); debug!( "time per token: {}ms, tokens per second: {}", tpot, 1000 / tpot ); // Record the tokens per second self.metrics.tokens_per_second.record(1000 / tpot); } } Err(e) => { warn!("SystemTime error: {:?}", e); } } // Record the output sequence length self.metrics .output_sequence_length .record(self.response_tokens as u64); if let Some(traceparent) = self.traceparent.as_ref() { let current_time_ns = current_time_ns(); match Traceparent::try_from(traceparent.to_string()) { Err(e) => { warn!("traceparent header is invalid: {}", e); } Ok(traceparent) => { let mut trace_data = common::tracing::TraceData::new(); let mut llm_span = Span::new( "egress_traffic".to_string(), Some(traceparent.trace_id), Some(traceparent.parent_id), self.request_body_sent_time.unwrap(), current_time_ns, ); if let Some(user_message) = self.user_message.as_ref() { if let Some(prompt) = user_message.content.as_ref() { llm_span .add_attribute("user_prompt".to_string(), prompt.to_string()); } } llm_span.add_attribute( "model".to_string(), self.llm_provider().name.to_string(), ); if self.ttft_time.is_some() { llm_span.add_event(Event::new( "time_to_first_token".to_string(), self.ttft_time.unwrap(), )); trace_data.add_span(llm_span); } self.traces_queue.lock().unwrap().push_back(trace_data); } }; } return Action::Continue; } let body = if self.streaming_response { let chunk_start = 0; let chunk_size = body_size; debug!( "on_http_response_body: streaming response reading, {}..{}", chunk_start, chunk_size ); let streaming_chunk = match self.get_http_response_body(0, chunk_size) { Some(chunk) => chunk, None => { warn!( "response body empty, chunk_start: {}, chunk_size: {}", chunk_start, chunk_size ); return Action::Continue; } }; if streaming_chunk.len() != chunk_size { warn!( "chunk size mismatch: read: {} != requested: {}", streaming_chunk.len(), chunk_size ); } streaming_chunk } else { if body_size == 0 { return Action::Continue; } debug!("non streaming response bytes read: 0:{}", body_size); match self.get_http_response_body(0, body_size) { Some(body) => body, None => { warn!("non streaming response body empty"); return Action::Continue; } } }; if log::log_enabled!(log::Level::Debug) { debug!( "response data (converted to utf8): {}", String::from_utf8_lossy(&body) ); } let llm_provider_str = self.llm_provider().provider_interface.to_string(); let hermes_llm_provider = Provider::from(llm_provider_str.as_str()); if self.streaming_response { let chat_completions_chunk_response_events = match SseChatCompletionIter::try_from((body.as_slice(), &hermes_llm_provider)) { Ok(events) => events, Err(e) => { warn!( "could not parse response: {}, body str: {}", e, String::from_utf8_lossy(&body) ); return Action::Continue; } }; for event in chat_completions_chunk_response_events { match event { Ok(event) => { if let Some(usage) = event.usage.as_ref() { self.response_tokens += usage.completion_tokens; } } Err(e) => { warn!("error in response event: {}", e); continue; } } } // Compute TTFT if not already recorded if self.ttft_duration.is_none() { // if let Some(start_time) = self.start_time { let current_time = get_current_time().unwrap(); self.ttft_time = Some(current_time_ns()); match current_time.duration_since(self.start_time) { Ok(duration) => { let duration_ms = duration.as_millis(); info!( "on_http_response_body: time to first token: {}ms", duration_ms ); self.ttft_duration = Some(duration); self.metrics.time_to_first_token.record(duration_ms as u64); } Err(e) => { warn!("SystemTime error: {:?}", e); } } } } else { debug!("non streaming response"); let chat_completions_response = match ChatCompletionsResponse::try_from((body.as_slice(), &hermes_llm_provider)) { Ok(de) => de, Err(e) => { warn!( "could not parse response: {}, body str: {}", e, String::from_utf8_lossy(&body) ); debug!( "on_http_response_body: S[{}], response body: {}", self.context_id, String::from_utf8_lossy(&body) ); self.send_server_error( ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST), ); return Action::Continue; } }; if let Some(usage) = chat_completions_response.usage { self.response_tokens += usage.completion_tokens; } } debug!( "recv [S={}] total_tokens={} end_stream={}", self.context_id, self.response_tokens, end_of_stream ); Action::Continue } } fn current_time_ns() -> u128 { SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_nanos() } impl Context for StreamContext {}