more changes

This commit is contained in:
Adil Hafeez 2025-01-23 17:36:56 -08:00
parent 4ab7665c30
commit e2d49fb3f2
3 changed files with 47 additions and 41 deletions

View file

@ -9,7 +9,7 @@ use common::llm_providers::LlmProviders;
use common::ratelimit; use common::ratelimit;
use common::stats::Gauge; use common::stats::Gauge;
use common::tracing::TraceData; use common::tracing::TraceData;
use log::debug; use log::trace;
use log::warn; use log::warn;
use proxy_wasm::traits::*; use proxy_wasm::traits::*;
use proxy_wasm::types::*; use proxy_wasm::types::*;
@ -103,10 +103,8 @@ impl RootContext for FilterContext {
fn on_tick(&mut self) { fn on_tick(&mut self) {
let _ = self.traces_queue.try_lock().map(|mut traces_queue| { let _ = self.traces_queue.try_lock().map(|mut traces_queue| {
while let Some(trace) = traces_queue.pop_front() { while let Some(trace) = traces_queue.pop_front() {
debug!("trace received: {:?}", trace);
let trace_str = serde_json::to_string(&trace).unwrap(); let trace_str = serde_json::to_string(&trace).unwrap();
debug!("trace: {}", trace_str); trace!("trace details: {}", trace_str);
let call_args = CallArgs::new( let call_args = CallArgs::new(
OTEL_COLLECTOR_HTTP, OTEL_COLLECTOR_HTTP,
OTEL_POST_PATH, OTEL_POST_PATH,
@ -139,7 +137,7 @@ impl Context for FilterContext {
_body_size: usize, _body_size: usize,
_num_trailers: usize, _num_trailers: usize,
) { ) {
debug!( trace!(
"||| on_http_call_response called with token_id: {:?} |||", "||| on_http_call_response called with token_id: {:?} |||",
token_id token_id
); );
@ -151,7 +149,7 @@ impl Context for FilterContext {
.expect("invalid token_id"); .expect("invalid token_id");
if let Some(status) = self.get_http_call_response_header(":status") { if let Some(status) = self.get_http_call_response_header(":status") {
debug!("trace response status: {:?}", status); trace!("trace response status: {:?}", status);
}; };
} }
} }

View file

@ -153,7 +153,7 @@ impl StreamContext {
self.metrics self.metrics
.input_sequence_length .input_sequence_length
.record(token_count as u64); .record(token_count as u64);
log::debug!("Recorded input token count: {}", token_count); trace!("Recorded input token count: {}", token_count);
// Check if rate limiting needs to be applied. // Check if rate limiting needs to be applied.
if let Some(selector) = self.ratelimit_selector.take() { if let Some(selector) = self.ratelimit_selector.take() {
@ -164,7 +164,7 @@ impl StreamContext {
NonZero::new(token_count as u32).unwrap(), NonZero::new(token_count as u32).unwrap(),
)?; )?;
} else { } else {
log::debug!("No rate limit applied for model: {}", model); trace!("No rate limit applied for model: {}", model);
} }
Ok(()) Ok(())
@ -331,7 +331,7 @@ impl HttpContext for StreamContext {
Ok(duration) => { Ok(duration) => {
// Convert the duration to milliseconds // Convert the duration to milliseconds
let duration_ms = duration.as_millis(); let duration_ms = duration.as_millis();
debug!("Total latency: {} milliseconds", duration_ms); debug!("request latency: {}ms", duration_ms);
// Record the latency to the latency histogram // Record the latency to the latency histogram
self.metrics.request_latency.record(duration_ms as u64); self.metrics.request_latency.record(duration_ms as u64);
@ -339,11 +339,14 @@ impl HttpContext for StreamContext {
// Compute the time per output token // Compute the time per output token
let tpot = duration_ms as u64 / self.response_tokens as u64; let tpot = duration_ms as u64 / self.response_tokens as u64;
debug!("Time per output token: {} milliseconds", tpot);
// Record the time per output token // Record the time per output token
self.metrics.time_per_output_token.record(tpot); self.metrics.time_per_output_token.record(tpot);
debug!("Tokens per second: {}", 1000 / tpot); trace!(
"time per token: {}ms, tokens per second: {}",
tpot,
1000 / tpot
);
// Record the tokens per second // Record the tokens per second
self.metrics.tokens_per_second.record(1000 / tpot); self.metrics.tokens_per_second.record(1000 / tpot);
} }
@ -499,7 +502,7 @@ impl HttpContext for StreamContext {
match current_time.duration_since(self.start_time) { match current_time.duration_since(self.start_time) {
Ok(duration) => { Ok(duration) => {
let duration_ms = duration.as_millis(); let duration_ms = duration.as_millis();
debug!("Time to First Token (TTFT): {} milliseconds", duration_ms); debug!("time to first token: {}ms", duration_ms);
self.ttft_duration = Some(duration); self.ttft_duration = Some(duration);
self.metrics.time_to_first_token.record(duration_ms as u64); self.metrics.time_to_first_token.record(duration_ms as u64);
} }

View file

@ -263,7 +263,7 @@ impl StreamContext {
); );
} }
// update prompt target name from the tool call // update prompt target name from the tool call response
callout_context.prompt_target_name = callout_context.prompt_target_name =
Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone()); Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone());
@ -338,7 +338,8 @@ impl StreamContext {
ARCH_INTERNAL_CLUSTER_NAME, ARCH_INTERNAL_CLUSTER_NAME,
&path, &path,
headers, headers,
Some(tool_params_json_str.as_bytes()), // Some(tool_params_json_str.as_bytes()),
None,
vec![], vec![],
Duration::from_secs(5), Duration::from_secs(5),
); );
@ -363,7 +364,6 @@ impl StreamContext {
let http_status = self let http_status = self
.get_http_call_response_header(":status") .get_http_call_response_header(":status")
.unwrap_or(StatusCode::OK.as_str().to_string()); .unwrap_or(StatusCode::OK.as_str().to_string());
debug!("api_call_response_handler: http_status: {}", http_status);
if http_status != StatusCode::OK.as_str() { if http_status != StatusCode::OK.as_str() {
warn!( warn!(
"api server responded with non 2xx status code: {}", "api server responded with non 2xx status code: {}",
@ -385,7 +385,7 @@ impl StreamContext {
self.tool_call_response.as_ref().unwrap() self.tool_call_response.as_ref().unwrap()
); );
let mut messages = self.filter_out_arch_messages(&callout_context); let mut messages = self.construct_llm_messages(&callout_context);
let user_message = match messages.pop() { let user_message = match messages.pop() {
Some(user_message) => user_message, Some(user_message) => user_message,
@ -442,25 +442,39 @@ impl StreamContext {
self.resume_http_request(); self.resume_http_request();
} }
fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> { fn get_system_prompt(&self, prompt_target: Option<PromptTarget>) -> Option<String> {
let mut messages: Vec<Message> = Vec::new(); match prompt_target {
// add system prompt None => self.system_prompt.as_ref().clone(),
Some(prompt_target) => match prompt_target.system_prompt {
None => self.system_prompt.as_ref().clone(),
Some(system_prompt) => Some(system_prompt),
},
}
}
fn filter_out_arch_messages(&self, messages: &[Message]) -> Vec<Message> {
messages
.iter()
.filter(|m| {
!(m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty()))
})
.cloned()
.collect()
}
fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
let mut messages: Vec<Message> = Vec::new();
// add system prompt
let system_prompt = match callout_context.prompt_target_name.as_ref() { let system_prompt = match callout_context.prompt_target_name.as_ref() {
None => self.system_prompt.as_ref().clone(), None => self.system_prompt.as_ref().clone(),
Some(prompt_target_name) => { Some(prompt_target_name) => {
let prompt_system_prompt = self self.get_system_prompt(self.prompt_targets.get(prompt_target_name).cloned())
.prompt_targets
.get(prompt_target_name)
.unwrap()
.clone()
.system_prompt;
match prompt_system_prompt {
None => self.system_prompt.as_ref().clone(),
Some(system_prompt) => Some(system_prompt),
}
} }
}; };
if system_prompt.is_some() { if system_prompt.is_some() {
let system_prompt_message = Message { let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(), role: SYSTEM_ROLE.to_string(),
@ -472,18 +486,9 @@ impl StreamContext {
messages.push(system_prompt_message); messages.push(system_prompt_message);
} }
// don't send tools message and api response to chat gpt messages.append(
for m in callout_context.request_body.messages.iter() { &mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
// don't send api response and tool calls to upstream LLMs );
if m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
{
continue;
}
messages.push(m.clone());
}
messages messages
} }