From 780c7cf7ad358633161452bc62ecb5e50246a733 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Sun, 17 Nov 2024 16:20:51 -0800 Subject: [PATCH] wip --- arch/envoy.template.yaml | 4 - arch/supervisord.conf | 21 +- crates/common/src/consts.rs | 1 + crates/common/src/tracing.rs | 15 +- crates/llm_gateway/src/filter_context.rs | 85 +++++++- crates/llm_gateway/src/stream_context.rs | 241 +++++++++++++---------- 6 files changed, 246 insertions(+), 121 deletions(-) diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml index 39735bc4..59c2c77d 100644 --- a/arch/envoy.template.yaml +++ b/arch/envoy.template.yaml @@ -614,10 +614,6 @@ static_resources: dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN typed_extension_protocol_options: - envoy.extensions.upstreams.http.v3.HttpProtocolOptions: - "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions - explicit_http_config: - http2_protocol_options: {} load_assignment: cluster_name: opentelemetry_collector_http endpoints: diff --git a/arch/supervisord.conf b/arch/supervisord.conf index da659e65..5af78121 100644 --- a/arch/supervisord.conf +++ b/arch/supervisord.conf @@ -1,20 +1,21 @@ [supervisord] nodaemon=true -[program:trace_streamer] -command=sh -c "tail -F /var/log/envoy.log | python stream_traces.py" -autostart=true -autorestart=false -startretries=3 -priority=1 -stdout_logfile=/dev/stdout -stderr_logfile=/dev/stderr -stdout_logfile_maxbytes = 0 -stderr_logfile_maxbytes = 0 +; [program:trace_streamer] +; command=sh -c "tail -F /var/log/envoy.log | python stream_traces.py" +; autostart=true +; autorestart=false +; startretries=3 +; priority=1 +; stdout_logfile=/dev/stdout +; stderr_logfile=/dev/stderr +; stdout_logfile_maxbytes = 0 +; stderr_logfile_maxbytes = 0 [program:envoy] command=sh -c "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:debug 2>&1 | tee /var/log/envoy.log" +; command=sh -c "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --log-level trace 2>&1 | tee /var/log/envoy.log" autostart=true autorestart=true startretries=3 diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index 7ac5ea1c..b8348ad4 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -29,3 +29,4 @@ pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener"; pub const ARCH_MODEL_PREFIX: &str = "Arch"; pub const HALLUCINATION_TEMPLATE: &str = "It seems I'm missing some information. Could you provide the following details "; +pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http"; diff --git a/crates/common/src/tracing.rs b/crates/common/src/tracing.rs index 0d2bb978..709ff032 100644 --- a/crates/common/src/tracing.rs +++ b/crates/common/src/tracing.rs @@ -47,14 +47,15 @@ pub struct Span { impl Span { pub fn new( name: String, - parent_trace_id: String, + trace_id: String, + span_id: String, parent_span_id: Option, start_time_unix_nano: u128, end_time_unix_nano: u128, ) -> Self { Span { - trace_id: parent_trace_id, - span_id: get_random_span_id(), + trace_id, + span_id, parent_span_id, name, start_time_unix_nano: format!("{}", start_time_unix_nano), @@ -175,3 +176,11 @@ pub fn get_random_span_id() -> String { hex::encode(random_bytes) } + +pub fn get_random_trace_id() -> String { + let mut rng = rand::thread_rng(); + let mut random_bytes = [0u8; 16]; + rng.fill_bytes(&mut random_bytes); + + hex::encode(random_bytes) +} diff --git a/crates/llm_gateway/src/filter_context.rs b/crates/llm_gateway/src/filter_context.rs index 9a34fe98..c058af5a 100644 --- a/crates/llm_gateway/src/filter_context.rs +++ b/crates/llm_gateway/src/filter_context.rs @@ -1,17 +1,25 @@ use crate::stream_context::StreamContext; use common::configuration::Configuration; +use common::consts::OTEL_COLLECTOR_HTTP; +use common::http::CallArgs; use common::http::Client; use common::llm_providers::LlmProviders; use common::ratelimit; use common::stats::Counter; use common::stats::Gauge; use common::stats::Histogram; +use common::tracing::TraceData; use log::debug; +use log::warn; use proxy_wasm::traits::*; use proxy_wasm::types::*; use std::cell::RefCell; use std::collections::HashMap; +use std::collections::VecDeque; use std::rc::Rc; +use std::time::Duration; + +use std::sync::{Arc, Mutex}; #[derive(Copy, Clone, Debug)] pub struct WasmMetrics { @@ -49,14 +57,31 @@ pub struct FilterContext { // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. callouts: RefCell>, llm_providers: Option>, + // traces: Rc>>, + traces_queue: Arc>>, + // trace_sender: Rc>, + // receiver: Receiver, } impl FilterContext { pub fn new() -> FilterContext { + // let (sender, receiver) = channel::(); + // thread::spawn(move || { + // while let Ok(trace) = receiver.recv() { + // debug!("received trace: {:?}", trace); + // } + // }); + // let queue: Arc>> = Arc::new(Mutex::new(Vec::new())); + // queue.lock().unwrap().push("foo".to_string()); + FilterContext { callouts: RefCell::new(HashMap::new()), metrics: Rc::new(WasmMetrics::new()), llm_providers: None, + // traces: Rc::new(RefCell::new(VecDeque::new())), + traces_queue: Arc::new(Mutex::new(VecDeque::new())), + // trace_sender: Rc::new(sender), + // receiver, } } } @@ -73,8 +98,6 @@ impl Client for FilterContext { } } -impl Context for FilterContext {} - // RootContext allows the Rust code to reach into the Envoy Config impl RootContext for FilterContext { fn on_configure(&mut self, _: usize) -> bool { @@ -111,10 +134,68 @@ impl RootContext for FilterContext { .as_ref() .expect("LLM Providers must exist when Streams are being created"), ), + Arc::clone(&self.traces_queue), ))) } fn get_type(&self) -> Option { Some(ContextType::HttpContext) } + + fn on_vm_start(&mut self, _vm_configuration_size: usize) -> bool { + self.set_tick_period(Duration::from_secs(1)); + true + } + + fn on_tick(&mut self) { + let _ = self.traces_queue.try_lock().map(|mut traces_queue| { + while let Some(trace) = traces_queue.pop_front() { + debug!("trace received: {:?}", trace); + + let trace_str = serde_json::to_string(&trace).unwrap(); + debug!("trace: {}", trace_str); + let call_args = CallArgs::new( + OTEL_COLLECTOR_HTTP, + "/v1/traces", + vec![ + (":method", "POST"), + (":path", "/v1/traces"), + (":authority", OTEL_COLLECTOR_HTTP), + ("content-type", "application/json"), + ], + Some(trace_str.as_bytes()), + vec![], + Duration::from_secs(60), + ); + if let Err(error) = self.http_call(call_args, CallContext {}) { + warn!("failed to schedule http call: {:?}", error); + } + } + }); + } +} + +impl Context for FilterContext { + fn on_http_call_response( + &mut self, + token_id: u32, + _num_headers: usize, + _body_size: usize, + _num_trailers: usize, + ) { + debug!( + "||| on_http_call_response called with token_id: {:?} |||", + token_id + ); + + let _callout_data = self + .callouts + .borrow_mut() + .remove(&token_id) + .expect("invalid token_id"); + + self.get_http_call_response_header(":status").map(|status| { + debug!("trace response status: {:?}", status); + }); + } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index cdfdbeb2..1a9df7ff 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -12,14 +12,16 @@ use common::errors::ServerError; use common::llm_providers::LlmProviders; use common::pii::obfuscate_auth_header; use common::ratelimit::Header; -use common::tracing::{Event, Span}; +use common::tracing::{get_random_span_id, get_random_trace_id, Event, Span, TraceData}; use common::{ratelimit, routing, tokenizer}; use http::StatusCode; use log::{debug, trace, warn}; use proxy_wasm::traits::*; use proxy_wasm::types::*; +use std::collections::VecDeque; use std::num::NonZero; use std::rc::Rc; +use std::sync::{Arc, Mutex}; use common::stats::{IncrementingMetric, RecordingMetric}; @@ -36,15 +38,27 @@ pub struct StreamContext { llm_providers: Rc, llm_provider: Option>, request_id: Option, - start_time: Option, + start_time: SystemTime, ttft_duration: Option, ttft_time: Option, - pub traceparent: Option, + trace_id: String, + span_id: String, + traceparent: String, + parent_span_id: Option, + traceparent_present_in_request: bool, user_message: Option, + traces_queue: Arc>>, } impl StreamContext { - pub fn new(context_id: u32, metrics: Rc, llm_providers: Rc) -> Self { + pub fn new( + context_id: u32, + metrics: Rc, + llm_providers: Rc, + traces_queue: Arc>>, + ) -> Self { + let trace_id = get_random_trace_id(); + let span_id = get_random_span_id(); StreamContext { context_id, metrics, @@ -55,11 +69,16 @@ impl StreamContext { llm_providers, llm_provider: None, request_id: None, - start_time: None, + start_time: SystemTime::now(), ttft_duration: None, - traceparent: None, + traceparent: format!("00-{}-{}-01", trace_id, span_id), + trace_id, + parent_span_id: Some(span_id.clone()), + span_id, ttft_time: None, user_message: None, + traces_queue, + traceparent_present_in_request: false, } } fn llm_provider(&self) -> &LlmProvider { @@ -183,12 +202,24 @@ impl HttpContext for StreamContext { ); self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); - self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER); - - //start the timing for the request using get_current_time() - let current_time: SystemTime = get_current_time().unwrap(); - self.start_time = Some(current_time); - self.ttft_duration = None; + // if traceparent is not present in the request, set it and add it to the response headers + if let Some(traceparent) = self.get_http_request_header(TRACE_PARENT_HEADER) { + debug!("traceparent set"); + self.traceparent = traceparent; + self.traceparent_present_in_request = true; + self.parent_span_id = { + let traceparent_tokens: Vec<&str> = + self.traceparent.split("-").collect::>(); + if traceparent_tokens.len() != 4 { + warn!("traceparent header is invalid: {}", self.traceparent); + None + } else { + Some(traceparent_tokens[2].to_string()) + } + }; + } else { + self.set_http_request_header(TRACE_PARENT_HEADER, Some(self.traceparent.as_str())); + } Action::Continue } @@ -294,21 +325,26 @@ impl HttpContext for StreamContext { self.context_id, _end_of_stream ); - if let Some(user_message) = self.user_message.as_ref() { - if let Some(prompt) = user_message.content.as_ref() { - debug!("setting user-message header: {}", prompt); - self.set_http_response_header("x-user-message", Some(&prompt)); - } - } + // if let Some(user_message) = self.user_message.as_ref() { + // if let Some(prompt) = user_message.content.as_ref() { + // debug!("setting user-message header: {}", prompt); + // self.set_http_response_header("x-user-message", Some(&prompt)); + // } + // } - let tftt_time_ms = get_current_time() - .unwrap() - .duration_since(self.start_time.unwrap()) - .unwrap() - .as_millis(); + // let tftt_time_ms = get_current_time() + // .unwrap() + // .duration_since(self.start_time.unwrap()) + // .unwrap() + // .as_millis(); - let tftt_time = tftt_time_ms.to_string(); - self.set_http_response_header("x-time-to-first-token", Some(&tftt_time)); + // let tftt_time = tftt_time_ms.to_string(); + // self.set_http_response_header("x-time-to-first-token", Some(&tftt_time)); + + self.set_property( + vec!["metadata", "filter_metadata", "llm_filter", "user_prompt"], + Some("hello world from filter".as_bytes()), + ); Action::Continue } @@ -328,29 +364,27 @@ impl HttpContext for StreamContext { if end_of_stream && body_size == 0 { // All streaming responses end with bytes=0 and end_stream=true // Record the latency for the request - if let Some(start_time) = self.start_time { - match current_time.duration_since(start_time) { - Ok(duration) => { - // Convert the duration to milliseconds - let duration_ms = duration.as_millis(); - debug!("Total latency: {} milliseconds", duration_ms); - // Record the latency to the latency histogram - self.metrics.request_latency.record(duration_ms as u64); + match current_time.duration_since(self.start_time) { + Ok(duration) => { + // Convert the duration to milliseconds + let duration_ms = duration.as_millis(); + debug!("Total latency: {} milliseconds", duration_ms); + // Record the latency to the latency histogram + self.metrics.request_latency.record(duration_ms as u64); - // Compute the time per output token - let tpot = duration_ms as u64 / self.response_tokens as u64; + // Compute the time per output token + let tpot = duration_ms as u64 / self.response_tokens as u64; - debug!("Time per output token: {} milliseconds", tpot); - // Record the time per output token - self.metrics.time_per_output_token.record(tpot); + debug!("Time per output token: {} milliseconds", tpot); + // Record the time per output token + self.metrics.time_per_output_token.record(tpot); - debug!("Tokens per second: {}", 1000 / tpot); - // Record the tokens per second - self.metrics.tokens_per_second.record(1000 / tpot); - } - Err(e) => { - warn!("SystemTime error: {:?}", e); - } + debug!("Tokens per second: {}", 1000 / tpot); + // Record the tokens per second + self.metrics.tokens_per_second.record(1000 / tpot); + } + Err(e) => { + warn!("SystemTime error: {:?}", e); } } // Record the output sequence length @@ -358,52 +392,55 @@ impl HttpContext for StreamContext { .output_sequence_length .record(self.response_tokens as u64); - if let Some(traceparent) = self.traceparent.as_ref() { - let since_the_epoch_ns = SystemTime::now() + // if let Some(traceparent) = self.traceparent.as_ref() { + let since_the_epoch_ns = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(); + + let parent_span_id = { + if self.traceparent_present_in_request { + self.parent_span_id.clone() + } else { + None + } + }; + + let mut trace_data = common::tracing::TraceData::new(); + let mut llm_span = Span::new( + "upstream_llm_time".to_string(), + self.trace_id.to_string(), + self.span_id.to_string(), + parent_span_id, + self.start_time .duration_since(UNIX_EPOCH) .unwrap() - .as_nanos(); - - let traceparent_tokens = traceparent.split("-").collect::>(); - if traceparent_tokens.len() != 4 { - warn!("traceparent header is invalid: {}", traceparent); - return Action::Continue; + .as_nanos(), + since_the_epoch_ns, + ); + if let Some(user_message) = self.user_message.as_ref() { + if let Some(prompt) = user_message.content.as_ref() { + llm_span.add_attribute("user_prompt".to_string(), prompt.to_string()); } - let parent_trace_id = traceparent_tokens[1]; - let parent_span_id = traceparent_tokens[2]; - let mut trace_data = common::tracing::TraceData::new(); - let mut llm_span = Span::new( - "upstream_llm_time".to_string(), - parent_trace_id.to_string(), - Some(parent_span_id.to_string()), - self.start_time - .unwrap() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_nanos(), - since_the_epoch_ns, - ); - if let Some(user_message) = self.user_message.as_ref() { - if let Some(prompt) = user_message.content.as_ref() { - llm_span.add_attribute("user_prompt".to_string(), prompt.to_string()); - } - } - llm_span.add_attribute("model".to_string(), self.llm_provider().name.to_string()); - - llm_span.add_event(Event::new( - "time_to_first_token".to_string(), - self.ttft_time - .unwrap() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_nanos(), - )); - trace_data.add_span(llm_span); - - let trace_data_str = serde_json::to_string(&trace_data).unwrap(); - debug!("upstream_llm trace details: {}", trace_data_str); - // send trace_data to http tracing endpoint } + llm_span.add_attribute("model".to_string(), self.llm_provider().name.to_string()); + + llm_span.add_event(Event::new( + "time_to_first_token".to_string(), + self.ttft_time + .unwrap() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(), + )); + trace_data.add_span(llm_span); + + // debug!("upstream_llm trace details: {:?}", trace_data); + self.traces_queue.lock().unwrap().push_back(trace_data); + + // let trace_data_str = serde_json::to_string(&trace_data).unwrap(); + // send trace_data to http tracing endpoint + // } return Action::Continue; } @@ -498,23 +535,23 @@ impl HttpContext for StreamContext { // Compute TTFT if not already recorded if self.ttft_duration.is_none() { - if let Some(start_time) = self.start_time { - let current_time = get_current_time().unwrap(); - self.ttft_time = Some(current_time); - match current_time.duration_since(start_time) { - Ok(duration) => { - let duration_ms = duration.as_millis(); - debug!("Time to First Token (TTFT): {} milliseconds", duration_ms); - self.ttft_duration = Some(duration); - self.metrics.time_to_first_token.record(duration_ms as u64); - } - Err(e) => { - warn!("SystemTime error: {:?}", e); - } + // if let Some(start_time) = self.start_time { + let current_time = get_current_time().unwrap(); + self.ttft_time = Some(current_time); + match current_time.duration_since(self.start_time) { + Ok(duration) => { + let duration_ms = duration.as_millis(); + debug!("Time to First Token (TTFT): {} milliseconds", duration_ms); + self.ttft_duration = Some(duration); + self.metrics.time_to_first_token.record(duration_ms as u64); + } + Err(e) => { + warn!("SystemTime error: {:?}", e); } - } else { - warn!("Start time was not recorded"); } + // } else { + // warn!("Start time was not recorded"); + // } } } else { debug!("non streaming response");