diff --git a/crates/common/src/tracing.rs b/crates/common/src/tracing.rs index da67c032..e71ac03c 100644 --- a/crates/common/src/tracing.rs +++ b/crates/common/src/tracing.rs @@ -1,3 +1,5 @@ +use std::path::Display; + use rand::RngCore; use serde::{Deserialize, Serialize}; @@ -187,3 +189,43 @@ impl TraceData { self.resource_spans[0].scope_spans[0].spans.push(span); } } + +pub struct Traceparent { + pub version: String, + pub trace_id: String, + pub parent_id: String, + pub flags: String, +} + +impl std::fmt::Display for Traceparent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}-{}-{}-{}", + self.version, self.trace_id, self.parent_id, self.flags + ) + } +} + +#[derive(thiserror::Error, Debug)] +pub enum TraceparentNewError { + #[error("Invalid traceparent: \'{0}\'")] + InvalidTraceparent(String), +} + +impl TryFrom for Traceparent { + type Error = TraceparentNewError; + + fn try_from(traceparent: String) -> Result { + let traceparent_tokens: Vec<&str> = traceparent.split("-").collect::>(); + if traceparent_tokens.len() != 4 { + return Err(TraceparentNewError::InvalidTraceparent(traceparent)); + } + Ok(Traceparent { + version: traceparent_tokens[0].to_string(), + trace_id: traceparent_tokens[1].to_string(), + parent_id: traceparent_tokens[2].to_string(), + flags: traceparent_tokens[3].to_string(), + }) + } +} diff --git a/crates/llm_gateway/src/filter_context.rs b/crates/llm_gateway/src/filter_context.rs index 78945fbb..b5b279de 100644 --- a/crates/llm_gateway/src/filter_context.rs +++ b/crates/llm_gateway/src/filter_context.rs @@ -183,8 +183,8 @@ impl Context for FilterContext { .remove(&token_id) .expect("invalid token_id"); - self.get_http_call_response_header(":status").map(|status| { + if let Some(status) = self.get_http_call_response_header(":status") { debug!("trace response status: {:?}", status); - }); + }; } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 33072683..9bb81f11 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -12,7 +12,7 @@ use common::errors::ServerError; use common::llm_providers::LlmProviders; use common::pii::obfuscate_auth_header; use common::ratelimit::Header; -use common::tracing::{Event, Span, TraceData}; +use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; use http::StatusCode; use log::{debug, trace, warn}; @@ -359,42 +359,39 @@ impl HttpContext for StreamContext { if let Some(traceparent) = self.traceparent.as_ref() { let current_time_ns = current_time_ns(); - let (trace_id, parent_span_id) = { - let traceparent_tokens: Vec<&str> = - traceparent.split("-").collect::>(); - if traceparent_tokens.len() != 4 { - warn!("traceparent header is invalid: {}", traceparent); - (None, None) - } else { - ( - Some(traceparent_tokens[1].to_string()), - Some(traceparent_tokens[2].to_string()), - ) + match Traceparent::try_from(traceparent.to_string()) { + Err(e) => { + warn!("traceparent header is invalid: {}", e); + } + Ok(traceparent) => { + let mut trace_data = common::tracing::TraceData::new(); + let mut llm_span = Span::new( + "upstream_llm_time".to_string(), + Some(traceparent.trace_id), + Some(traceparent.parent_id), + self.request_body_sent_time.unwrap(), + current_time_ns, + ); + if let Some(user_message) = self.user_message.as_ref() { + if let Some(prompt) = user_message.content.as_ref() { + llm_span + .add_attribute("user_prompt".to_string(), prompt.to_string()); + } + } + llm_span.add_attribute( + "model".to_string(), + self.llm_provider().name.to_string(), + ); + + llm_span.add_event(Event::new( + "time_to_first_token".to_string(), + self.ttft_time.unwrap(), + )); + trace_data.add_span(llm_span); + + self.traces_queue.lock().unwrap().push_back(trace_data); } }; - - let mut trace_data = common::tracing::TraceData::new(); - let mut llm_span = Span::new( - "upstream_llm_time".to_string(), - trace_id, - parent_span_id, - self.request_body_sent_time.unwrap(), - current_time_ns, - ); - if let Some(user_message) = self.user_message.as_ref() { - if let Some(prompt) = user_message.content.as_ref() { - llm_span.add_attribute("user_prompt".to_string(), prompt.to_string()); - } - } - llm_span.add_attribute("model".to_string(), self.llm_provider().name.to_string()); - - llm_span.add_event(Event::new( - "time_to_first_token".to_string(), - self.ttft_time.unwrap(), - )); - trace_data.add_span(llm_span); - - self.traces_queue.lock().unwrap().push_back(trace_data); } return Action::Continue;