diff --git a/crates/common/src/tracing.rs b/crates/common/src/tracing.rs index 709ff032..da67c032 100644 --- a/crates/common/src/tracing.rs +++ b/crates/common/src/tracing.rs @@ -47,15 +47,18 @@ pub struct Span { impl Span { pub fn new( name: String, - trace_id: String, - span_id: String, + trace_id: Option, parent_span_id: Option, start_time_unix_nano: u128, end_time_unix_nano: u128, ) -> Self { + let trace_id = match trace_id { + Some(trace_id) => trace_id, + None => Span::get_random_trace_id(), + }; Span { trace_id, - span_id, + span_id: Span::get_random_span_id(), parent_span_id, name, start_time_unix_nano: format!("{}", start_time_unix_nano), @@ -81,6 +84,22 @@ impl Span { } self.events.as_mut().unwrap().push(event); } + + fn get_random_span_id() -> String { + let mut rng = rand::thread_rng(); + let mut random_bytes = [0u8; 8]; + rng.fill_bytes(&mut random_bytes); + + hex::encode(random_bytes) + } + + fn get_random_trace_id() -> String { + let mut rng = rand::thread_rng(); + let mut random_bytes = [0u8; 16]; + rng.fill_bytes(&mut random_bytes); + + hex::encode(random_bytes) + } } #[derive(Serialize, Deserialize, Debug)] @@ -168,19 +187,3 @@ impl TraceData { self.resource_spans[0].scope_spans[0].spans.push(span); } } - -pub fn get_random_span_id() -> String { - let mut rng = rand::thread_rng(); - let mut random_bytes = [0u8; 8]; - rng.fill_bytes(&mut random_bytes); - - hex::encode(random_bytes) -} - -pub fn get_random_trace_id() -> String { - let mut rng = rand::thread_rng(); - let mut random_bytes = [0u8; 16]; - rng.fill_bytes(&mut random_bytes); - - hex::encode(random_bytes) -} diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 660ac2ef..1060dddf 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -12,7 +12,7 @@ use common::errors::ServerError; use common::llm_providers::LlmProviders; use common::pii::obfuscate_auth_header; use common::ratelimit::Header; -use common::tracing::{get_random_span_id, get_random_trace_id, Event, Span, TraceData}; +use common::tracing::{Event, Span, TraceData}; use common::{ratelimit, routing, tokenizer}; use http::StatusCode; use log::{debug, trace, warn}; @@ -40,13 +40,9 @@ pub struct StreamContext { request_id: Option, start_time: SystemTime, ttft_duration: Option, - ttft_time: Option, - trace_id: String, - span_id: String, - parent_span_id: Option, - traceparent: String, - traceparent_present_in_request: bool, - request_body_sent_time: Option, + ttft_time: Option, + traceparent: Option, + request_body_sent_time: Option, user_message: Option, traces_queue: Arc>>, } @@ -58,8 +54,6 @@ impl StreamContext { llm_providers: Rc, traces_queue: Arc>>, ) -> Self { - let trace_id = get_random_trace_id(); - let span_id = get_random_span_id(); StreamContext { context_id, metrics, @@ -72,14 +66,10 @@ impl StreamContext { request_id: None, start_time: SystemTime::now(), ttft_duration: None, - traceparent: format!("00-{}-{}-01", trace_id, span_id), - trace_id, - parent_span_id: Some(span_id.clone()), - span_id, + traceparent: None, ttft_time: None, user_message: None, traces_queue, - traceparent_present_in_request: false, request_body_sent_time: None, } } @@ -204,24 +194,7 @@ impl HttpContext for StreamContext { ); self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); - // if traceparent is not present in the request, set it and add it to the response headers - if let Some(traceparent) = self.get_http_request_header(TRACE_PARENT_HEADER) { - debug!("traceparent set"); - self.traceparent = traceparent; - self.traceparent_present_in_request = true; - self.parent_span_id = { - let traceparent_tokens: Vec<&str> = - self.traceparent.split("-").collect::>(); - if traceparent_tokens.len() != 4 { - warn!("traceparent header is invalid: {}", self.traceparent); - None - } else { - Some(traceparent_tokens[2].to_string()) - } - }; - } else { - self.set_http_request_header(TRACE_PARENT_HEADER, Some(self.traceparent.as_str())); - } + self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER); Action::Continue } @@ -231,7 +204,7 @@ impl HttpContext for StreamContext { // TODO: consider a streaming API. if self.request_body_sent_time.is_none() { - self.request_body_sent_time = Some(get_current_time().unwrap()); + self.request_body_sent_time = Some(current_time_ns()); } if !end_of_stream { @@ -399,56 +372,46 @@ impl HttpContext for StreamContext { .output_sequence_length .record(self.response_tokens as u64); - // if let Some(traceparent) = self.traceparent.as_ref() { - let since_the_epoch_ns = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_nanos(); + if let Some(traceparent) = self.traceparent.as_ref() { + let current_time_ns = current_time_ns(); - let parent_span_id = { - if self.traceparent_present_in_request { - self.parent_span_id.clone() - } else { - None - } - }; + let (trace_id, parent_span_id) = { + let traceparent_tokens: Vec<&str> = + traceparent.split("-").collect::>(); + if traceparent_tokens.len() != 4 { + warn!("traceparent header is invalid: {}", traceparent); + (None, None) + } else { + ( + Some(traceparent_tokens[1].to_string()), + Some(traceparent_tokens[2].to_string()), + ) + } + }; - let mut trace_data = common::tracing::TraceData::new(); - let mut llm_span = Span::new( - "upstream_llm_time".to_string(), - self.trace_id.to_string(), - self.span_id.to_string(), - parent_span_id, - self.request_body_sent_time - .unwrap() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_nanos(), - since_the_epoch_ns, - ); - if let Some(user_message) = self.user_message.as_ref() { - if let Some(prompt) = user_message.content.as_ref() { - llm_span.add_attribute("user_prompt".to_string(), prompt.to_string()); + let mut trace_data = common::tracing::TraceData::new(); + let mut llm_span = Span::new( + "upstream_llm_time".to_string(), + trace_id, + parent_span_id, + self.request_body_sent_time.unwrap(), + current_time_ns, + ); + if let Some(user_message) = self.user_message.as_ref() { + if let Some(prompt) = user_message.content.as_ref() { + llm_span.add_attribute("user_prompt".to_string(), prompt.to_string()); + } } + llm_span.add_attribute("model".to_string(), self.llm_provider().name.to_string()); + + llm_span.add_event(Event::new( + "time_to_first_token".to_string(), + self.ttft_time.unwrap(), + )); + trace_data.add_span(llm_span); + + self.traces_queue.lock().unwrap().push_back(trace_data); } - llm_span.add_attribute("model".to_string(), self.llm_provider().name.to_string()); - - llm_span.add_event(Event::new( - "time_to_first_token".to_string(), - self.ttft_time - .unwrap() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_nanos(), - )); - trace_data.add_span(llm_span); - - // debug!("upstream_llm trace details: {:?}", trace_data); - self.traces_queue.lock().unwrap().push_back(trace_data); - - // let trace_data_str = serde_json::to_string(&trace_data).unwrap(); - // send trace_data to http tracing endpoint - // } return Action::Continue; } @@ -545,7 +508,7 @@ impl HttpContext for StreamContext { if self.ttft_duration.is_none() { // if let Some(start_time) = self.start_time { let current_time = get_current_time().unwrap(); - self.ttft_time = Some(current_time); + self.ttft_time = Some(current_time_ns()); match current_time.duration_since(self.start_time) { Ok(duration) => { let duration_ms = duration.as_millis(); @@ -590,4 +553,11 @@ impl HttpContext for StreamContext { } } +fn current_time_ns() -> u128 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos() +} + impl Context for StreamContext {}