This commit is contained in:
Adil Hafeez 2024-11-17 16:20:51 -08:00
parent 1508743eeb
commit 780c7cf7ad
6 changed files with 246 additions and 121 deletions

View file

@ -614,10 +614,6 @@ static_resources:
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
typed_extension_protocol_options:
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
explicit_http_config:
http2_protocol_options: {}
load_assignment:
cluster_name: opentelemetry_collector_http
endpoints:

View file

@ -1,20 +1,21 @@
[supervisord]
nodaemon=true
[program:trace_streamer]
command=sh -c "tail -F /var/log/envoy.log | python stream_traces.py"
autostart=true
autorestart=false
startretries=3
priority=1
stdout_logfile=/dev/stdout
stderr_logfile=/dev/stderr
stdout_logfile_maxbytes = 0
stderr_logfile_maxbytes = 0
; [program:trace_streamer]
; command=sh -c "tail -F /var/log/envoy.log | python stream_traces.py"
; autostart=true
; autorestart=false
; startretries=3
; priority=1
; stdout_logfile=/dev/stdout
; stderr_logfile=/dev/stderr
; stdout_logfile_maxbytes = 0
; stderr_logfile_maxbytes = 0
[program:envoy]
command=sh -c "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:debug 2>&1 | tee /var/log/envoy.log"
; command=sh -c "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --log-level trace 2>&1 | tee /var/log/envoy.log"
autostart=true
autorestart=true
startretries=3

View file

@ -29,3 +29,4 @@ pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";
pub const ARCH_MODEL_PREFIX: &str = "Arch";
pub const HALLUCINATION_TEMPLATE: &str =
"It seems I'm missing some information. Could you provide the following details ";
pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http";

View file

@ -47,14 +47,15 @@ pub struct Span {
impl Span {
pub fn new(
name: String,
parent_trace_id: String,
trace_id: String,
span_id: String,
parent_span_id: Option<String>,
start_time_unix_nano: u128,
end_time_unix_nano: u128,
) -> Self {
Span {
trace_id: parent_trace_id,
span_id: get_random_span_id(),
trace_id,
span_id,
parent_span_id,
name,
start_time_unix_nano: format!("{}", start_time_unix_nano),
@ -175,3 +176,11 @@ pub fn get_random_span_id() -> String {
hex::encode(random_bytes)
}
pub fn get_random_trace_id() -> String {
let mut rng = rand::thread_rng();
let mut random_bytes = [0u8; 16];
rng.fill_bytes(&mut random_bytes);
hex::encode(random_bytes)
}

View file

@ -1,17 +1,25 @@
use crate::stream_context::StreamContext;
use common::configuration::Configuration;
use common::consts::OTEL_COLLECTOR_HTTP;
use common::http::CallArgs;
use common::http::Client;
use common::llm_providers::LlmProviders;
use common::ratelimit;
use common::stats::Counter;
use common::stats::Gauge;
use common::stats::Histogram;
use common::tracing::TraceData;
use log::debug;
use log::warn;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::cell::RefCell;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::rc::Rc;
use std::time::Duration;
use std::sync::{Arc, Mutex};
#[derive(Copy, Clone, Debug)]
pub struct WasmMetrics {
@ -49,14 +57,31 @@ pub struct FilterContext {
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, CallContext>>,
llm_providers: Option<Rc<LlmProviders>>,
// traces: Rc<RefCell<VecDeque<TraceData>>>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
// trace_sender: Rc<Sender<TraceData>>,
// receiver: Receiver<TraceData>,
}
impl FilterContext {
pub fn new() -> FilterContext {
// let (sender, receiver) = channel::<TraceData>();
// thread::spawn(move || {
// while let Ok(trace) = receiver.recv() {
// debug!("received trace: {:?}", trace);
// }
// });
// let queue: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
// queue.lock().unwrap().push("foo".to_string());
FilterContext {
callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(WasmMetrics::new()),
llm_providers: None,
// traces: Rc::new(RefCell::new(VecDeque::new())),
traces_queue: Arc::new(Mutex::new(VecDeque::new())),
// trace_sender: Rc::new(sender),
// receiver,
}
}
}
@ -73,8 +98,6 @@ impl Client for FilterContext {
}
}
impl Context for FilterContext {}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool {
@ -111,10 +134,68 @@ impl RootContext for FilterContext {
.as_ref()
.expect("LLM Providers must exist when Streams are being created"),
),
Arc::clone(&self.traces_queue),
)))
}
fn get_type(&self) -> Option<ContextType> {
Some(ContextType::HttpContext)
}
fn on_vm_start(&mut self, _vm_configuration_size: usize) -> bool {
self.set_tick_period(Duration::from_secs(1));
true
}
fn on_tick(&mut self) {
let _ = self.traces_queue.try_lock().map(|mut traces_queue| {
while let Some(trace) = traces_queue.pop_front() {
debug!("trace received: {:?}", trace);
let trace_str = serde_json::to_string(&trace).unwrap();
debug!("trace: {}", trace_str);
let call_args = CallArgs::new(
OTEL_COLLECTOR_HTTP,
"/v1/traces",
vec![
(":method", "POST"),
(":path", "/v1/traces"),
(":authority", OTEL_COLLECTOR_HTTP),
("content-type", "application/json"),
],
Some(trace_str.as_bytes()),
vec![],
Duration::from_secs(60),
);
if let Err(error) = self.http_call(call_args, CallContext {}) {
warn!("failed to schedule http call: {:?}", error);
}
}
});
}
}
impl Context for FilterContext {
fn on_http_call_response(
&mut self,
token_id: u32,
_num_headers: usize,
_body_size: usize,
_num_trailers: usize,
) {
debug!(
"||| on_http_call_response called with token_id: {:?} |||",
token_id
);
let _callout_data = self
.callouts
.borrow_mut()
.remove(&token_id)
.expect("invalid token_id");
self.get_http_call_response_header(":status").map(|status| {
debug!("trace response status: {:?}", status);
});
}
}

View file

@ -12,14 +12,16 @@ use common::errors::ServerError;
use common::llm_providers::LlmProviders;
use common::pii::obfuscate_auth_header;
use common::ratelimit::Header;
use common::tracing::{Event, Span};
use common::tracing::{get_random_span_id, get_random_trace_id, Event, Span, TraceData};
use common::{ratelimit, routing, tokenizer};
use http::StatusCode;
use log::{debug, trace, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::collections::VecDeque;
use std::num::NonZero;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use common::stats::{IncrementingMetric, RecordingMetric};
@ -36,15 +38,27 @@ pub struct StreamContext {
llm_providers: Rc<LlmProviders>,
llm_provider: Option<Rc<LlmProvider>>,
request_id: Option<String>,
start_time: Option<SystemTime>,
start_time: SystemTime,
ttft_duration: Option<Duration>,
ttft_time: Option<SystemTime>,
pub traceparent: Option<String>,
trace_id: String,
span_id: String,
traceparent: String,
parent_span_id: Option<String>,
traceparent_present_in_request: bool,
user_message: Option<Message>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
}
impl StreamContext {
pub fn new(context_id: u32, metrics: Rc<WasmMetrics>, llm_providers: Rc<LlmProviders>) -> Self {
pub fn new(
context_id: u32,
metrics: Rc<WasmMetrics>,
llm_providers: Rc<LlmProviders>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
) -> Self {
let trace_id = get_random_trace_id();
let span_id = get_random_span_id();
StreamContext {
context_id,
metrics,
@ -55,11 +69,16 @@ impl StreamContext {
llm_providers,
llm_provider: None,
request_id: None,
start_time: None,
start_time: SystemTime::now(),
ttft_duration: None,
traceparent: None,
traceparent: format!("00-{}-{}-01", trace_id, span_id),
trace_id,
parent_span_id: Some(span_id.clone()),
span_id,
ttft_time: None,
user_message: None,
traces_queue,
traceparent_present_in_request: false,
}
}
fn llm_provider(&self) -> &LlmProvider {
@ -183,12 +202,24 @@ impl HttpContext for StreamContext {
);
self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER);
//start the timing for the request using get_current_time()
let current_time: SystemTime = get_current_time().unwrap();
self.start_time = Some(current_time);
self.ttft_duration = None;
// if traceparent is not present in the request, set it and add it to the response headers
if let Some(traceparent) = self.get_http_request_header(TRACE_PARENT_HEADER) {
debug!("traceparent set");
self.traceparent = traceparent;
self.traceparent_present_in_request = true;
self.parent_span_id = {
let traceparent_tokens: Vec<&str> =
self.traceparent.split("-").collect::<Vec<&str>>();
if traceparent_tokens.len() != 4 {
warn!("traceparent header is invalid: {}", self.traceparent);
None
} else {
Some(traceparent_tokens[2].to_string())
}
};
} else {
self.set_http_request_header(TRACE_PARENT_HEADER, Some(self.traceparent.as_str()));
}
Action::Continue
}
@ -294,21 +325,26 @@ impl HttpContext for StreamContext {
self.context_id, _end_of_stream
);
if let Some(user_message) = self.user_message.as_ref() {
if let Some(prompt) = user_message.content.as_ref() {
debug!("setting user-message header: {}", prompt);
self.set_http_response_header("x-user-message", Some(&prompt));
}
}
// if let Some(user_message) = self.user_message.as_ref() {
// if let Some(prompt) = user_message.content.as_ref() {
// debug!("setting user-message header: {}", prompt);
// self.set_http_response_header("x-user-message", Some(&prompt));
// }
// }
let tftt_time_ms = get_current_time()
.unwrap()
.duration_since(self.start_time.unwrap())
.unwrap()
.as_millis();
// let tftt_time_ms = get_current_time()
// .unwrap()
// .duration_since(self.start_time.unwrap())
// .unwrap()
// .as_millis();
let tftt_time = tftt_time_ms.to_string();
self.set_http_response_header("x-time-to-first-token", Some(&tftt_time));
// let tftt_time = tftt_time_ms.to_string();
// self.set_http_response_header("x-time-to-first-token", Some(&tftt_time));
self.set_property(
vec!["metadata", "filter_metadata", "llm_filter", "user_prompt"],
Some("hello world from filter".as_bytes()),
);
Action::Continue
}
@ -328,29 +364,27 @@ impl HttpContext for StreamContext {
if end_of_stream && body_size == 0 {
// All streaming responses end with bytes=0 and end_stream=true
// Record the latency for the request
if let Some(start_time) = self.start_time {
match current_time.duration_since(start_time) {
Ok(duration) => {
// Convert the duration to milliseconds
let duration_ms = duration.as_millis();
debug!("Total latency: {} milliseconds", duration_ms);
// Record the latency to the latency histogram
self.metrics.request_latency.record(duration_ms as u64);
match current_time.duration_since(self.start_time) {
Ok(duration) => {
// Convert the duration to milliseconds
let duration_ms = duration.as_millis();
debug!("Total latency: {} milliseconds", duration_ms);
// Record the latency to the latency histogram
self.metrics.request_latency.record(duration_ms as u64);
// Compute the time per output token
let tpot = duration_ms as u64 / self.response_tokens as u64;
// Compute the time per output token
let tpot = duration_ms as u64 / self.response_tokens as u64;
debug!("Time per output token: {} milliseconds", tpot);
// Record the time per output token
self.metrics.time_per_output_token.record(tpot);
debug!("Time per output token: {} milliseconds", tpot);
// Record the time per output token
self.metrics.time_per_output_token.record(tpot);
debug!("Tokens per second: {}", 1000 / tpot);
// Record the tokens per second
self.metrics.tokens_per_second.record(1000 / tpot);
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
}
debug!("Tokens per second: {}", 1000 / tpot);
// Record the tokens per second
self.metrics.tokens_per_second.record(1000 / tpot);
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
}
}
// Record the output sequence length
@ -358,52 +392,55 @@ impl HttpContext for StreamContext {
.output_sequence_length
.record(self.response_tokens as u64);
if let Some(traceparent) = self.traceparent.as_ref() {
let since_the_epoch_ns = SystemTime::now()
// if let Some(traceparent) = self.traceparent.as_ref() {
let since_the_epoch_ns = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
let parent_span_id = {
if self.traceparent_present_in_request {
self.parent_span_id.clone()
} else {
None
}
};
let mut trace_data = common::tracing::TraceData::new();
let mut llm_span = Span::new(
"upstream_llm_time".to_string(),
self.trace_id.to_string(),
self.span_id.to_string(),
parent_span_id,
self.start_time
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
let traceparent_tokens = traceparent.split("-").collect::<Vec<&str>>();
if traceparent_tokens.len() != 4 {
warn!("traceparent header is invalid: {}", traceparent);
return Action::Continue;
.as_nanos(),
since_the_epoch_ns,
);
if let Some(user_message) = self.user_message.as_ref() {
if let Some(prompt) = user_message.content.as_ref() {
llm_span.add_attribute("user_prompt".to_string(), prompt.to_string());
}
let parent_trace_id = traceparent_tokens[1];
let parent_span_id = traceparent_tokens[2];
let mut trace_data = common::tracing::TraceData::new();
let mut llm_span = Span::new(
"upstream_llm_time".to_string(),
parent_trace_id.to_string(),
Some(parent_span_id.to_string()),
self.start_time
.unwrap()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos(),
since_the_epoch_ns,
);
if let Some(user_message) = self.user_message.as_ref() {
if let Some(prompt) = user_message.content.as_ref() {
llm_span.add_attribute("user_prompt".to_string(), prompt.to_string());
}
}
llm_span.add_attribute("model".to_string(), self.llm_provider().name.to_string());
llm_span.add_event(Event::new(
"time_to_first_token".to_string(),
self.ttft_time
.unwrap()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos(),
));
trace_data.add_span(llm_span);
let trace_data_str = serde_json::to_string(&trace_data).unwrap();
debug!("upstream_llm trace details: {}", trace_data_str);
// send trace_data to http tracing endpoint
}
llm_span.add_attribute("model".to_string(), self.llm_provider().name.to_string());
llm_span.add_event(Event::new(
"time_to_first_token".to_string(),
self.ttft_time
.unwrap()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos(),
));
trace_data.add_span(llm_span);
// debug!("upstream_llm trace details: {:?}", trace_data);
self.traces_queue.lock().unwrap().push_back(trace_data);
// let trace_data_str = serde_json::to_string(&trace_data).unwrap();
// send trace_data to http tracing endpoint
// }
return Action::Continue;
}
@ -498,23 +535,23 @@ impl HttpContext for StreamContext {
// Compute TTFT if not already recorded
if self.ttft_duration.is_none() {
if let Some(start_time) = self.start_time {
let current_time = get_current_time().unwrap();
self.ttft_time = Some(current_time);
match current_time.duration_since(start_time) {
Ok(duration) => {
let duration_ms = duration.as_millis();
debug!("Time to First Token (TTFT): {} milliseconds", duration_ms);
self.ttft_duration = Some(duration);
self.metrics.time_to_first_token.record(duration_ms as u64);
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
}
// if let Some(start_time) = self.start_time {
let current_time = get_current_time().unwrap();
self.ttft_time = Some(current_time);
match current_time.duration_since(self.start_time) {
Ok(duration) => {
let duration_ms = duration.as_millis();
debug!("Time to First Token (TTFT): {} milliseconds", duration_ms);
self.ttft_duration = Some(duration);
self.metrics.time_to_first_token.record(duration_ms as u64);
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
}
} else {
warn!("Start time was not recorded");
}
// } else {
// warn!("Start time was not recorded");
// }
}
} else {
debug!("non streaming response");