diff --git a/arch/Dockerfile b/arch/Dockerfile index 0d96713c..e1c49f1d 100644 --- a/arch/Dockerfile +++ b/arch/Dockerfile @@ -13,7 +13,7 @@ FROM envoyproxy/envoy:v1.32-latest as envoy #Build config generator, so that we have a single build image for both Rust and Python FROM python:3.12-slim as arch -RUN apt-get update && apt-get install -y gettext-base curl supervisor && apt-get clean && rm -rf /var/lib/apt/lists/* +RUN apt-get update && apt-get install -y gettext-base curl && apt-get clean && rm -rf /var/lib/apt/lists/* COPY --from=builder /arch/target/wasm32-wasip1/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm COPY --from=builder /arch/target/wasm32-wasip1/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm @@ -24,10 +24,8 @@ RUN pip install -r requirements.txt COPY arch/tools/cli/config_generator.py . COPY arch/envoy.template.yaml . COPY arch/arch_config_schema.yaml . -COPY arch/supervisord.conf /etc/supervisor/conf.d/supervisord.conf -COPY arch/stream_traces.py . RUN pip install requests RUN touch /var/log/envoy.log -ENTRYPOINT ["supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"] +ENTRYPOINT ["sh","-c", "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:debug 2>&1 | tee /var/log/envoy.log"] diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml index 52671f99..4ae2c00b 100644 --- a/arch/envoy.template.yaml +++ b/arch/envoy.template.yaml @@ -575,10 +575,6 @@ static_resources: dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN typed_extension_protocol_options: - envoy.extensions.upstreams.http.v3.HttpProtocolOptions: - "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions - explicit_http_config: - http2_protocol_options: {} load_assignment: cluster_name: opentelemetry_collector_http endpoints: diff --git a/arch/stream_traces.py b/arch/stream_traces.py deleted file mode 100644 index 1a165a8a..00000000 --- a/arch/stream_traces.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -import sys -import time -import requests -import logging - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) - - -otel_tracing_endpoint = os.getenv( - "OTEL_TRACING_HTTP_ENDPOINT", "http://localhost:4318/v1/traces" -) -envoy_log_path = os.getenv("ENVOY_LOG_PATH", "/var/log/envoy.log") - -logging.info(f"Using otel-tracing host: {otel_tracing_endpoint}") -logging.info(f"Using envoy log path: {envoy_log_path}") - - -def process_log_line(line): - try: - response = requests.post( - url=otel_tracing_endpoint, - data=line, - headers={"Content-Type": "application/json"}, - ) - logging.info(f"Sent trace to otel-tracing: {response.status_code}") - except Exception as e: - logging.error(f"Failed to send trace to otel-tracing: {e}") - - -for line in sys.stdin: - if line: - tokens = line.split("gateway: upstream_llm trace details: ") - if len(tokens) > 1: - process_log_line(tokens[1]) diff --git a/arch/supervisord.conf b/arch/supervisord.conf deleted file mode 100644 index da659e65..00000000 --- a/arch/supervisord.conf +++ /dev/null @@ -1,25 +0,0 @@ -[supervisord] -nodaemon=true - -[program:trace_streamer] -command=sh -c "tail -F /var/log/envoy.log | python stream_traces.py" -autostart=true -autorestart=false -startretries=3 -priority=1 -stdout_logfile=/dev/stdout -stderr_logfile=/dev/stderr -stdout_logfile_maxbytes = 0 -stderr_logfile_maxbytes = 0 - - -[program:envoy] -command=sh -c "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:debug 2>&1 | tee /var/log/envoy.log" -autostart=true -autorestart=true -startretries=3 -priority=2 -stdout_logfile=/dev/stdout -stderr_logfile=/dev/stderr -stdout_logfile_maxbytes = 0 -stderr_logfile_maxbytes = 0 diff --git a/archgw.code-workspace b/archgw.code-workspace index 07b23996..5e04c71a 100644 --- a/archgw.code-workspace +++ b/archgw.code-workspace @@ -40,6 +40,7 @@ "github.copilot", "eamodio.gitlens", "ms-python.black-formatter", + "tamasfe.even-better-toml", ] } } diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index 7ac5ea1c..28db9fb8 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -29,3 +29,5 @@ pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener"; pub const ARCH_MODEL_PREFIX: &str = "Arch"; pub const HALLUCINATION_TEMPLATE: &str = "It seems I'm missing some information. Could you provide the following details "; +pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http"; +pub const OTEL_POST_PATH: &str = "/v1/traces"; diff --git a/crates/common/src/tracing.rs b/crates/common/src/tracing.rs index 0d2bb978..e71ac03c 100644 --- a/crates/common/src/tracing.rs +++ b/crates/common/src/tracing.rs @@ -1,3 +1,5 @@ +use std::path::Display; + use rand::RngCore; use serde::{Deserialize, Serialize}; @@ -47,14 +49,18 @@ pub struct Span { impl Span { pub fn new( name: String, - parent_trace_id: String, + trace_id: Option, parent_span_id: Option, start_time_unix_nano: u128, end_time_unix_nano: u128, ) -> Self { + let trace_id = match trace_id { + Some(trace_id) => trace_id, + None => Span::get_random_trace_id(), + }; Span { - trace_id: parent_trace_id, - span_id: get_random_span_id(), + trace_id, + span_id: Span::get_random_span_id(), parent_span_id, name, start_time_unix_nano: format!("{}", start_time_unix_nano), @@ -80,6 +86,22 @@ impl Span { } self.events.as_mut().unwrap().push(event); } + + fn get_random_span_id() -> String { + let mut rng = rand::thread_rng(); + let mut random_bytes = [0u8; 8]; + rng.fill_bytes(&mut random_bytes); + + hex::encode(random_bytes) + } + + fn get_random_trace_id() -> String { + let mut rng = rand::thread_rng(); + let mut random_bytes = [0u8; 16]; + rng.fill_bytes(&mut random_bytes); + + hex::encode(random_bytes) + } } #[derive(Serialize, Deserialize, Debug)] @@ -168,10 +190,42 @@ impl TraceData { } } -pub fn get_random_span_id() -> String { - let mut rng = rand::thread_rng(); - let mut random_bytes = [0u8; 8]; - rng.fill_bytes(&mut random_bytes); - - hex::encode(random_bytes) +pub struct Traceparent { + pub version: String, + pub trace_id: String, + pub parent_id: String, + pub flags: String, +} + +impl std::fmt::Display for Traceparent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}-{}-{}-{}", + self.version, self.trace_id, self.parent_id, self.flags + ) + } +} + +#[derive(thiserror::Error, Debug)] +pub enum TraceparentNewError { + #[error("Invalid traceparent: \'{0}\'")] + InvalidTraceparent(String), +} + +impl TryFrom for Traceparent { + type Error = TraceparentNewError; + + fn try_from(traceparent: String) -> Result { + let traceparent_tokens: Vec<&str> = traceparent.split("-").collect::>(); + if traceparent_tokens.len() != 4 { + return Err(TraceparentNewError::InvalidTraceparent(traceparent)); + } + Ok(Traceparent { + version: traceparent_tokens[0].to_string(), + trace_id: traceparent_tokens[1].to_string(), + parent_id: traceparent_tokens[2].to_string(), + flags: traceparent_tokens[3].to_string(), + }) + } } diff --git a/crates/llm_gateway/src/filter_context.rs b/crates/llm_gateway/src/filter_context.rs index 9a34fe98..b5b279de 100644 --- a/crates/llm_gateway/src/filter_context.rs +++ b/crates/llm_gateway/src/filter_context.rs @@ -1,17 +1,26 @@ use crate::stream_context::StreamContext; use common::configuration::Configuration; +use common::consts::OTEL_COLLECTOR_HTTP; +use common::consts::OTEL_POST_PATH; +use common::http::CallArgs; use common::http::Client; use common::llm_providers::LlmProviders; use common::ratelimit; use common::stats::Counter; use common::stats::Gauge; use common::stats::Histogram; +use common::tracing::TraceData; use log::debug; +use log::warn; use proxy_wasm::traits::*; use proxy_wasm::types::*; use std::cell::RefCell; use std::collections::HashMap; +use std::collections::VecDeque; use std::rc::Rc; +use std::time::Duration; + +use std::sync::{Arc, Mutex}; #[derive(Copy, Clone, Debug)] pub struct WasmMetrics { @@ -49,6 +58,7 @@ pub struct FilterContext { // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. callouts: RefCell>, llm_providers: Option>, + traces_queue: Arc>>, } impl FilterContext { @@ -57,6 +67,7 @@ impl FilterContext { callouts: RefCell::new(HashMap::new()), metrics: Rc::new(WasmMetrics::new()), llm_providers: None, + traces_queue: Arc::new(Mutex::new(VecDeque::new())), } } } @@ -73,8 +84,6 @@ impl Client for FilterContext { } } -impl Context for FilterContext {} - // RootContext allows the Rust code to reach into the Envoy Config impl RootContext for FilterContext { fn on_configure(&mut self, _: usize) -> bool { @@ -111,10 +120,71 @@ impl RootContext for FilterContext { .as_ref() .expect("LLM Providers must exist when Streams are being created"), ), + Arc::clone(&self.traces_queue), ))) } fn get_type(&self) -> Option { Some(ContextType::HttpContext) } + + fn on_vm_start(&mut self, _vm_configuration_size: usize) -> bool { + self.set_tick_period(Duration::from_secs(1)); + true + } + + fn on_tick(&mut self) { + let _ = self.traces_queue.try_lock().map(|mut traces_queue| { + while let Some(trace) = traces_queue.pop_front() { + debug!("trace received: {:?}", trace); + + let trace_str = serde_json::to_string(&trace).unwrap(); + debug!("trace: {}", trace_str); + let call_args = CallArgs::new( + OTEL_COLLECTOR_HTTP, + OTEL_POST_PATH, + vec![ + (":method", http::Method::POST.as_str()), + (":path", OTEL_POST_PATH), + (":authority", OTEL_COLLECTOR_HTTP), + ("content-type", "application/json"), + ], + Some(trace_str.as_bytes()), + vec![], + Duration::from_secs(60), + ); + if let Err(error) = self.http_call(call_args, CallContext {}) { + warn!( + "failed to schedule http call to otel-collector: {:?}", + error + ); + } + } + }); + } +} + +impl Context for FilterContext { + fn on_http_call_response( + &mut self, + token_id: u32, + _num_headers: usize, + _body_size: usize, + _num_trailers: usize, + ) { + debug!( + "||| on_http_call_response called with token_id: {:?} |||", + token_id + ); + + let _callout_data = self + .callouts + .borrow_mut() + .remove(&token_id) + .expect("invalid token_id"); + + if let Some(status) = self.get_http_call_response_header(":status") { + debug!("trace response status: {:?}", status); + }; + } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index c7e42c82..9bb81f11 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -12,14 +12,16 @@ use common::errors::ServerError; use common::llm_providers::LlmProviders; use common::pii::obfuscate_auth_header; use common::ratelimit::Header; -use common::tracing::{Event, Span}; +use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; use http::StatusCode; use log::{debug, trace, warn}; use proxy_wasm::traits::*; use proxy_wasm::types::*; +use std::collections::VecDeque; use std::num::NonZero; use std::rc::Rc; +use std::sync::{Arc, Mutex}; use common::stats::{IncrementingMetric, RecordingMetric}; @@ -36,16 +38,22 @@ pub struct StreamContext { llm_providers: Rc, llm_provider: Option>, request_id: Option, - start_time: Option, + start_time: SystemTime, ttft_duration: Option, - ttft_time: Option, - pub traceparent: Option, - request_body_sent_time: Option, + ttft_time: Option, + traceparent: Option, + request_body_sent_time: Option, user_message: Option, + traces_queue: Arc>>, } impl StreamContext { - pub fn new(context_id: u32, metrics: Rc, llm_providers: Rc) -> Self { + pub fn new( + context_id: u32, + metrics: Rc, + llm_providers: Rc, + traces_queue: Arc>>, + ) -> Self { StreamContext { context_id, metrics, @@ -56,11 +64,12 @@ impl StreamContext { llm_providers, llm_provider: None, request_id: None, - start_time: None, + start_time: SystemTime::now(), ttft_duration: None, traceparent: None, ttft_time: None, user_message: None, + traces_queue, request_body_sent_time: None, } } @@ -187,11 +196,6 @@ impl HttpContext for StreamContext { self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER); - //start the timing for the request using get_current_time() - let current_time: SystemTime = get_current_time().unwrap(); - self.start_time = Some(current_time); - self.ttft_duration = None; - Action::Continue } @@ -200,7 +204,7 @@ impl HttpContext for StreamContext { // TODO: consider a streaming API. if self.request_body_sent_time.is_none() { - self.request_body_sent_time = Some(get_current_time().unwrap()); + self.request_body_sent_time = Some(current_time_ns()); } if !end_of_stream { @@ -295,6 +299,20 @@ impl HttpContext for StreamContext { Action::Continue } + fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { + debug!( + "on_http_response_headers [S={}] end_stream={}", + self.context_id, _end_of_stream + ); + + self.set_property( + vec!["metadata", "filter_metadata", "llm_filter", "user_prompt"], + Some("hello world from filter".as_bytes()), + ); + + Action::Continue + } + fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { debug!( "on_http_response_body [S={}] bytes={} end_stream={}", @@ -310,29 +328,27 @@ impl HttpContext for StreamContext { if end_of_stream && body_size == 0 { // All streaming responses end with bytes=0 and end_stream=true // Record the latency for the request - if let Some(start_time) = self.start_time { - match current_time.duration_since(start_time) { - Ok(duration) => { - // Convert the duration to milliseconds - let duration_ms = duration.as_millis(); - debug!("Total latency: {} milliseconds", duration_ms); - // Record the latency to the latency histogram - self.metrics.request_latency.record(duration_ms as u64); + match current_time.duration_since(self.start_time) { + Ok(duration) => { + // Convert the duration to milliseconds + let duration_ms = duration.as_millis(); + debug!("Total latency: {} milliseconds", duration_ms); + // Record the latency to the latency histogram + self.metrics.request_latency.record(duration_ms as u64); - // Compute the time per output token - let tpot = duration_ms as u64 / self.response_tokens as u64; + // Compute the time per output token + let tpot = duration_ms as u64 / self.response_tokens as u64; - debug!("Time per output token: {} milliseconds", tpot); - // Record the time per output token - self.metrics.time_per_output_token.record(tpot); + debug!("Time per output token: {} milliseconds", tpot); + // Record the time per output token + self.metrics.time_per_output_token.record(tpot); - debug!("Tokens per second: {}", 1000 / tpot); - // Record the tokens per second - self.metrics.tokens_per_second.record(1000 / tpot); - } - Err(e) => { - warn!("SystemTime error: {:?}", e); - } + debug!("Tokens per second: {}", 1000 / tpot); + // Record the tokens per second + self.metrics.tokens_per_second.record(1000 / tpot); + } + Err(e) => { + warn!("SystemTime error: {:?}", e); } } // Record the output sequence length @@ -341,49 +357,41 @@ impl HttpContext for StreamContext { .record(self.response_tokens as u64); if let Some(traceparent) = self.traceparent.as_ref() { - let since_the_epoch_ns = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_nanos(); + let current_time_ns = current_time_ns(); - let traceparent_tokens = traceparent.split("-").collect::>(); - if traceparent_tokens.len() != 4 { - warn!("traceparent header is invalid: {}", traceparent); - return Action::Continue; - } - let parent_trace_id = traceparent_tokens[1]; - let parent_span_id = traceparent_tokens[2]; - let mut trace_data = common::tracing::TraceData::new(); - let mut llm_span = Span::new( - "upstream_llm_time".to_string(), - parent_trace_id.to_string(), - Some(parent_span_id.to_string()), - self.request_body_sent_time - .unwrap() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_nanos(), - since_the_epoch_ns, - ); - if let Some(user_message) = self.user_message.as_ref() { - if let Some(prompt) = user_message.content.as_ref() { - llm_span.add_attribute("user_prompt".to_string(), prompt.to_string()); + match Traceparent::try_from(traceparent.to_string()) { + Err(e) => { + warn!("traceparent header is invalid: {}", e); } - } - llm_span.add_attribute("model".to_string(), self.llm_provider().name.to_string()); - llm_span.add_event(Event::new( - "time_to_first_token".to_string(), - self.ttft_time - .unwrap() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_nanos(), - )); - trace_data.add_span(llm_span); + Ok(traceparent) => { + let mut trace_data = common::tracing::TraceData::new(); + let mut llm_span = Span::new( + "upstream_llm_time".to_string(), + Some(traceparent.trace_id), + Some(traceparent.parent_id), + self.request_body_sent_time.unwrap(), + current_time_ns, + ); + if let Some(user_message) = self.user_message.as_ref() { + if let Some(prompt) = user_message.content.as_ref() { + llm_span + .add_attribute("user_prompt".to_string(), prompt.to_string()); + } + } + llm_span.add_attribute( + "model".to_string(), + self.llm_provider().name.to_string(), + ); - let trace_data_str = serde_json::to_string(&trace_data).unwrap(); - debug!("upstream_llm trace details: {}", trace_data_str); - // send trace_data to http tracing endpoint + llm_span.add_event(Event::new( + "time_to_first_token".to_string(), + self.ttft_time.unwrap(), + )); + trace_data.add_span(llm_span); + + self.traces_queue.lock().unwrap().push_back(trace_data); + } + }; } return Action::Continue; @@ -479,22 +487,19 @@ impl HttpContext for StreamContext { // Compute TTFT if not already recorded if self.ttft_duration.is_none() { - if let Some(start_time) = self.start_time { - let current_time = get_current_time().unwrap(); - self.ttft_time = Some(current_time); - match current_time.duration_since(start_time) { - Ok(duration) => { - let duration_ms = duration.as_millis(); - debug!("Time to First Token (TTFT): {} milliseconds", duration_ms); - self.ttft_duration = Some(duration); - self.metrics.time_to_first_token.record(duration_ms as u64); - } - Err(e) => { - warn!("SystemTime error: {:?}", e); - } + // if let Some(start_time) = self.start_time { + let current_time = get_current_time().unwrap(); + self.ttft_time = Some(current_time_ns()); + match current_time.duration_since(self.start_time) { + Ok(duration) => { + let duration_ms = duration.as_millis(); + debug!("Time to First Token (TTFT): {} milliseconds", duration_ms); + self.ttft_duration = Some(duration); + self.metrics.time_to_first_token.record(duration_ms as u64); + } + Err(e) => { + warn!("SystemTime error: {:?}", e); } - } else { - warn!("Start time was not recorded"); } } } else { @@ -526,4 +531,11 @@ impl HttpContext for StreamContext { } } +fn current_time_ns() -> u128 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos() +} + impl Context for StreamContext {} diff --git a/crates/llm_gateway/tests/integration.rs b/crates/llm_gateway/tests/integration.rs index c39debd6..cbaedad3 100644 --- a/crates/llm_gateway/tests/integration.rs +++ b/crates/llm_gateway/tests/integration.rs @@ -53,8 +53,6 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) { .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) - .expect_get_current_time_nanos() - .returning(Some(0)) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); } @@ -217,8 +215,6 @@ fn llm_gateway_successful_request_to_open_ai_chat_completions() { ) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) - .expect_get_current_time_nanos() - .returning(Some(0)) .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) .expect_metric_record("input_sequence_length", 21) @@ -281,8 +277,6 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() { ) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(incomplete_chat_completions_request_body)) - .expect_get_current_time_nanos() - .returning(Some(0)) .expect_log(Some(LogLevel::Debug), None) .expect_send_local_response( Some(StatusCode::BAD_REQUEST.as_u16().into()), @@ -341,8 +335,6 @@ fn llm_gateway_request_ratelimited() { ) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) - .expect_get_current_time_nanos() - .returning(Some(0)) // The actual call is not important in this test, we just need to grab the token_id .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) @@ -409,8 +401,6 @@ fn llm_gateway_request_not_ratelimited() { ) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) - .expect_get_current_time_nanos() - .returning(Some(0)) // The actual call is not important in this test, we just need to grab the token_id .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None)