Update arch stats (#250)

This commit is contained in:
Aayush 2024-11-12 15:03:26 -08:00 committed by GitHub
parent 30647fd508
commit 5993e36f22
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1443 additions and 17 deletions

View file

@ -20,7 +20,10 @@ use proxy_wasm::types::*;
use std::num::NonZero;
use std::rc::Rc;
use common::stats::IncrementingMetric;
use common::stats::{IncrementingMetric, RecordingMetric};
use proxy_wasm::hostcalls::get_current_time;
use std::time::{Duration, SystemTime};
pub struct StreamContext {
context_id: u32,
@ -32,6 +35,8 @@ pub struct StreamContext {
llm_providers: Rc<LlmProviders>,
llm_provider: Option<Rc<LlmProvider>>,
request_id: Option<String>,
start_time: Option<SystemTime>,
ttft_duration: Option<Duration>, // Store the duration directly
}
impl StreamContext {
@ -46,6 +51,8 @@ impl StreamContext {
llm_providers,
llm_provider: None,
request_id: None,
start_time: None,
ttft_duration: None,
}
}
fn llm_provider(&self) -> &LlmProvider {
@ -120,16 +127,27 @@ impl StreamContext {
model: &str,
json_string: &str,
) -> Result<(), ratelimit::Error> {
// Tokenize and record token count.
let token_count = tokenizer::token_count(model, json_string).unwrap_or(0);
// Record the token count to metrics.
self.metrics
.input_sequence_length
.record(token_count as u64);
log::debug!("Recorded input token count: {}", token_count);
// Check if rate limiting needs to be applied.
if let Some(selector) = self.ratelimit_selector.take() {
// Tokenize and Ratelimit.
if let Ok(token_count) = tokenizer::token_count(model, json_string) {
ratelimit::ratelimits(None).read().unwrap().check_limit(
model.to_owned(),
selector,
NonZero::new(token_count as u32).unwrap(),
)?;
}
log::debug!("Applying ratelimit for model: {}", model);
ratelimit::ratelimits(None).read().unwrap().check_limit(
model.to_owned(),
selector,
NonZero::new(token_count as u32).unwrap(),
)?;
} else {
log::debug!("No rate limit applied for model: {}", model);
}
Ok(())
}
}
@ -158,6 +176,12 @@ impl HttpContext for StreamContext {
);
self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
//start the timing for the request using get_current_time()
let current_time = get_current_time().unwrap();
self.start_time = Some(current_time);
self.ttft_duration = None;
Action::Continue
}
@ -226,9 +250,15 @@ impl HttpContext for StreamContext {
});
}
// only use the tokens from the messages, excluding the metadata and json tags
let input_tokens_str = deserialized_body
.messages
.iter()
.fold(String::new(), |acc, m| {
acc + " " + m.content.as_ref().unwrap_or(&String::new())
});
// enforce ratelimits on ingress
if let Err(e) =
self.enforce_ratelimits(&deserialized_body.model, &chat_completion_request_str)
if let Err(e) = self.enforce_ratelimits(&deserialized_body.model, input_tokens_str.as_str())
{
self.send_server_error(
ServerError::ExceededRatelimit(e),
@ -254,10 +284,33 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
let body = if self.streaming_response {
if end_of_stream && body_size == 0 {
return Action::Continue;
let current_time = get_current_time().unwrap();
if end_of_stream && body_size == 0 {
// All streaming responses end with bytes=0 and end_stream=true
// Record the latency for the request
if let Some(start_time) = self.start_time {
match current_time.duration_since(start_time) {
Ok(duration) => {
// Convert the duration to milliseconds
let duration_ms = duration.as_millis();
debug!("Total latency: {} milliseconds", duration_ms);
// Record the latency to the latency histogram
self.metrics.request_latency.record(duration_ms as u64);
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
}
}
}
// Record the output sequence length
self.metrics
.output_sequence_length
.record(self.response_tokens as u64);
return Action::Continue;
}
let body = if self.streaming_response {
let chunk_start = 0;
let chunk_size = body_size;
debug!(
@ -344,6 +397,26 @@ impl HttpContext for StreamContext {
}
};
self.response_tokens += token_count;
// Compute TTFT if not already recorded
if self.ttft_duration.is_none() {
if let Some(start_time) = self.start_time {
let current_time = get_current_time().unwrap();
match current_time.duration_since(start_time) {
Ok(duration) => {
let duration_ms = duration.as_millis();
debug!("Time to First Token (TTFT): {} milliseconds", duration_ms);
self.ttft_duration = Some(duration);
self.metrics.time_to_first_token.record(duration_ms as u64);
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
}
}
} else {
warn!("Start time was not recorded");
}
}
} else {
debug!("non streaming response");
let chat_completions_response: ChatCompletionsResponse =