move custom tracer to llm filter (#267)

This commit is contained in:
Adil Hafeez 2024-11-15 10:44:01 -08:00 committed by GitHub
parent 1d229cba8f
commit d3c17c7abd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 335 additions and 133 deletions

View file

@ -1,17 +1,18 @@
use crate::filter_context::WasmMetrics;
use common::common_types::open_ai::{
ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
StreamOptions,
Message, StreamOptions,
};
use common::configuration::LlmProvider;
use common::consts::{
ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, CHAT_COMPLETIONS_PATH,
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER,
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
use common::errors::ServerError;
use common::llm_providers::LlmProviders;
use common::pii::obfuscate_auth_header;
use common::ratelimit::Header;
use common::tracing::{Event, Span};
use common::{ratelimit, routing, tokenizer};
use http::StatusCode;
use log::{debug, trace, warn};
@ -23,7 +24,7 @@ use std::rc::Rc;
use common::stats::{IncrementingMetric, RecordingMetric};
use proxy_wasm::hostcalls::get_current_time;
use std::time::{Duration, SystemTime};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
pub struct StreamContext {
context_id: u32,
@ -36,7 +37,10 @@ pub struct StreamContext {
llm_provider: Option<Rc<LlmProvider>>,
request_id: Option<String>,
start_time: Option<SystemTime>,
ttft_duration: Option<Duration>, // Store the duration directly
ttft_duration: Option<Duration>,
ttft_time: Option<SystemTime>,
pub traceparent: Option<String>,
user_message: Option<Message>,
}
impl StreamContext {
@ -53,6 +57,9 @@ impl StreamContext {
request_id: None,
start_time: None,
ttft_duration: None,
traceparent: None,
ttft_time: None,
user_message: None,
}
}
fn llm_provider(&self) -> &LlmProvider {
@ -176,9 +183,10 @@ impl HttpContext for StreamContext {
);
self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER);
//start the timing for the request using get_current_time()
let current_time = get_current_time().unwrap();
let current_time: SystemTime = get_current_time().unwrap();
self.start_time = Some(current_time);
self.ttft_duration = None;
@ -229,6 +237,13 @@ impl HttpContext for StreamContext {
message.model = None;
}
self.user_message = deserialized_body
.messages
.iter()
.filter(|m| m.role == "user")
.last()
.cloned();
// override model name from the llm provider
deserialized_body
.model
@ -318,6 +333,52 @@ impl HttpContext for StreamContext {
.output_sequence_length
.record(self.response_tokens as u64);
if let Some(traceparent) = self.traceparent.as_ref() {
let since_the_epoch_ns = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
let traceparent_tokens = traceparent.split("-").collect::<Vec<&str>>();
if traceparent_tokens.len() != 4 {
warn!("traceparent header is invalid: {}", traceparent);
return Action::Continue;
}
let parent_trace_id = traceparent_tokens[1];
let parent_span_id = traceparent_tokens[2];
let mut trace_data = common::tracing::TraceData::new();
let mut llm_span = Span::new(
"upstream_llm_time".to_string(),
parent_trace_id.to_string(),
Some(parent_span_id.to_string()),
self.start_time
.unwrap()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos(),
since_the_epoch_ns,
);
if let Some(user_message) = self.user_message.as_ref() {
if let Some(prompt) = user_message.content.as_ref() {
llm_span.add_attribute("user_prompt".to_string(), prompt.to_string());
}
}
llm_span.add_attribute("model".to_string(), self.llm_provider().name.to_string());
llm_span.add_event(Event::new(
"time_to_first_token".to_string(),
self.ttft_time
.unwrap()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos(),
));
trace_data.add_span(llm_span);
let trace_data_str = serde_json::to_string(&trace_data).unwrap();
debug!("upstream_llm trace details: {}", trace_data_str);
// send trace_data to http tracing endpoint
}
return Action::Continue;
}
@ -413,6 +474,7 @@ impl HttpContext for StreamContext {
if self.ttft_duration.is_none() {
if let Some(start_time) = self.start_time {
let current_time = get_current_time().unwrap();
self.ttft_time = Some(current_time);
match current_time.duration_since(start_time) {
Ok(duration) => {
let duration_ms = duration.as_millis();