diff --git a/crates/common/src/stats.rs b/crates/common/src/stats.rs index 527713f3..9479fadf 100644 --- a/crates/common/src/stats.rs +++ b/crates/common/src/stats.rs @@ -80,7 +80,7 @@ impl RecordingMetric for Gauge {} /// For offset deltas impl IncrementingMetric for Gauge {} -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] pub struct Histogram { id: u32, } diff --git a/crates/llm_gateway/src/filter_context.rs b/crates/llm_gateway/src/filter_context.rs index be80c390..b603103a 100644 --- a/crates/llm_gateway/src/filter_context.rs +++ b/crates/llm_gateway/src/filter_context.rs @@ -5,6 +5,7 @@ use common::llm_providers::LlmProviders; use common::ratelimit; use common::stats::Counter; use common::stats::Gauge; +use common::stats::Histogram; use log::debug; use proxy_wasm::traits::*; use proxy_wasm::types::*; @@ -16,6 +17,7 @@ use std::rc::Rc; pub struct WasmMetrics { pub active_http_calls: Gauge, pub ratelimited_rq: Counter, + pub time_to_first_token: Histogram, } impl WasmMetrics { @@ -23,6 +25,7 @@ impl WasmMetrics { WasmMetrics { active_http_calls: Gauge::new(String::from("active_http_calls")), ratelimited_rq: Counter::new(String::from("ratelimited_rq")), + time_to_first_token: Histogram::new(String::from("time_to_first_token")), } } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 9c3db01d..43d9f0aa 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -20,7 +20,10 @@ use proxy_wasm::types::*; use std::num::NonZero; use std::rc::Rc; -use common::stats::IncrementingMetric; +use common::stats::{IncrementingMetric, RecordingMetric}; + +use proxy_wasm::hostcalls::get_current_time; +use std::time::{Duration, SystemTime}; pub struct StreamContext { context_id: u32, @@ -32,6 +35,9 @@ pub struct StreamContext { llm_providers: Rc, llm_provider: Option>, request_id: Option, + start_time: Option, + ttft_recorded: bool, + ttft_duration: Option, // Store the duration directly } impl StreamContext { @@ -46,6 +52,9 @@ impl StreamContext { llm_providers, llm_provider: None, request_id: None, + start_time: None, + ttft_recorded: false, + ttft_duration: None, } } fn llm_provider(&self) -> &LlmProvider { @@ -344,6 +353,31 @@ impl HttpContext for StreamContext { } }; self.response_tokens += token_count; + + // Compute TFT if not already recorded + if !self.ttft_recorded { + if let Some(start_time) = self.start_time { + match get_current_time() { + Ok(current_time) => match current_time.duration_since(start_time) { + Ok(duration) => { + let duration_ms = duration.as_millis(); + debug!("Time to First Token (TFT): {} milliseconds", duration_ms); + self.ttft_duration = Some(duration); + self.metrics.time_to_first_token.record(duration_ms as u64); + } + Err(e) => { + warn!("SystemTime error: {:?}", e); + } + }, + Err(e) => { + warn!("Failed to get current time: {:?}", e); + } + } + self.ttft_recorded = true; + } else { + warn!("Start time was not recorded"); + } + } } else { debug!("non streaming response"); let chat_completions_response: ChatCompletionsResponse = @@ -362,6 +396,31 @@ impl HttpContext for StreamContext { .unwrap() .completion_tokens; } + + // Compute TFT if not already recorded + if !self.ttft_recorded { + if let Some(start_time) = self.start_time { + match get_current_time() { + Ok(current_time) => match current_time.duration_since(start_time) { + Ok(duration) => { + let duration_ms = duration.as_millis(); + debug!("Time to First Token (TFT): {} milliseconds", duration_ms); + self.ttft_duration = Some(duration); + self.metrics.time_to_first_token.record(duration_ms as u64); + } + Err(e) => { + warn!("SystemTime error: {:?}", e); + } + }, + Err(e) => { + warn!("Failed to get current time: {:?}", e); + } + } + self.ttft_recorded = true; + } else { + warn!("Start time was not recorded"); + } + } } debug!(