diff --git a/crates/llm_gateway/src/filter_context.rs b/crates/llm_gateway/src/filter_context.rs index 4e44a9ff..56af01b5 100644 --- a/crates/llm_gateway/src/filter_context.rs +++ b/crates/llm_gateway/src/filter_context.rs @@ -9,7 +9,7 @@ use common::llm_providers::LlmProviders; use common::ratelimit; use common::stats::Gauge; use common::tracing::TraceData; -use log::debug; +use log::trace; use log::warn; use proxy_wasm::traits::*; use proxy_wasm::types::*; @@ -103,10 +103,8 @@ impl RootContext for FilterContext { fn on_tick(&mut self) { let _ = self.traces_queue.try_lock().map(|mut traces_queue| { while let Some(trace) = traces_queue.pop_front() { - debug!("trace received: {:?}", trace); - let trace_str = serde_json::to_string(&trace).unwrap(); - debug!("trace: {}", trace_str); + trace!("trace details: {}", trace_str); let call_args = CallArgs::new( OTEL_COLLECTOR_HTTP, OTEL_POST_PATH, @@ -139,7 +137,7 @@ impl Context for FilterContext { _body_size: usize, _num_trailers: usize, ) { - debug!( + trace!( "||| on_http_call_response called with token_id: {:?} |||", token_id ); @@ -151,7 +149,7 @@ impl Context for FilterContext { .expect("invalid token_id"); if let Some(status) = self.get_http_call_response_header(":status") { - debug!("trace response status: {:?}", status); + trace!("trace response status: {:?}", status); }; } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index b512147d..3c15366d 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -153,7 +153,7 @@ impl StreamContext { self.metrics .input_sequence_length .record(token_count as u64); - log::debug!("Recorded input token count: {}", token_count); + trace!("Recorded input token count: {}", token_count); // Check if rate limiting needs to be applied. if let Some(selector) = self.ratelimit_selector.take() { @@ -164,7 +164,7 @@ impl StreamContext { NonZero::new(token_count as u32).unwrap(), )?; } else { - log::debug!("No rate limit applied for model: {}", model); + trace!("No rate limit applied for model: {}", model); } Ok(()) @@ -331,7 +331,7 @@ impl HttpContext for StreamContext { Ok(duration) => { // Convert the duration to milliseconds let duration_ms = duration.as_millis(); - debug!("Total latency: {} milliseconds", duration_ms); + debug!("request latency: {}ms", duration_ms); // Record the latency to the latency histogram self.metrics.request_latency.record(duration_ms as u64); @@ -339,11 +339,14 @@ impl HttpContext for StreamContext { // Compute the time per output token let tpot = duration_ms as u64 / self.response_tokens as u64; - debug!("Time per output token: {} milliseconds", tpot); // Record the time per output token self.metrics.time_per_output_token.record(tpot); - debug!("Tokens per second: {}", 1000 / tpot); + trace!( + "time per token: {}ms, tokens per second: {}", + tpot, + 1000 / tpot + ); // Record the tokens per second self.metrics.tokens_per_second.record(1000 / tpot); } @@ -499,7 +502,7 @@ impl HttpContext for StreamContext { match current_time.duration_since(self.start_time) { Ok(duration) => { let duration_ms = duration.as_millis(); - debug!("Time to First Token (TTFT): {} milliseconds", duration_ms); + debug!("time to first token: {}ms", duration_ms); self.ttft_duration = Some(duration); self.metrics.time_to_first_token.record(duration_ms as u64); } diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 9782698e..fc35877f 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -263,7 +263,7 @@ impl StreamContext { ); } - // update prompt target name from the tool call + // update prompt target name from the tool call response callout_context.prompt_target_name = Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone()); @@ -338,7 +338,8 @@ impl StreamContext { ARCH_INTERNAL_CLUSTER_NAME, &path, headers, - Some(tool_params_json_str.as_bytes()), + // Some(tool_params_json_str.as_bytes()), + None, vec![], Duration::from_secs(5), ); @@ -363,7 +364,6 @@ impl StreamContext { let http_status = self .get_http_call_response_header(":status") .unwrap_or(StatusCode::OK.as_str().to_string()); - debug!("api_call_response_handler: http_status: {}", http_status); if http_status != StatusCode::OK.as_str() { warn!( "api server responded with non 2xx status code: {}", @@ -385,7 +385,7 @@ impl StreamContext { self.tool_call_response.as_ref().unwrap() ); - let mut messages = self.filter_out_arch_messages(&callout_context); + let mut messages = self.construct_llm_messages(&callout_context); let user_message = match messages.pop() { Some(user_message) => user_message, @@ -442,25 +442,39 @@ impl StreamContext { self.resume_http_request(); } - fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec { - let mut messages: Vec = Vec::new(); - // add system prompt + fn get_system_prompt(&self, prompt_target: Option) -> Option { + match prompt_target { + None => self.system_prompt.as_ref().clone(), + Some(prompt_target) => match prompt_target.system_prompt { + None => self.system_prompt.as_ref().clone(), + Some(system_prompt) => Some(system_prompt), + }, + } + } + fn filter_out_arch_messages(&self, messages: &[Message]) -> Vec { + messages + .iter() + .filter(|m| { + !(m.role == TOOL_ROLE + || m.content.is_none() + || (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())) + }) + .cloned() + .collect() + } + + fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec { + let mut messages: Vec = Vec::new(); + + // add system prompt let system_prompt = match callout_context.prompt_target_name.as_ref() { None => self.system_prompt.as_ref().clone(), Some(prompt_target_name) => { - let prompt_system_prompt = self - .prompt_targets - .get(prompt_target_name) - .unwrap() - .clone() - .system_prompt; - match prompt_system_prompt { - None => self.system_prompt.as_ref().clone(), - Some(system_prompt) => Some(system_prompt), - } + self.get_system_prompt(self.prompt_targets.get(prompt_target_name).cloned()) } }; + if system_prompt.is_some() { let system_prompt_message = Message { role: SYSTEM_ROLE.to_string(), @@ -472,18 +486,9 @@ impl StreamContext { messages.push(system_prompt_message); } - // don't send tools message and api response to chat gpt - for m in callout_context.request_body.messages.iter() { - // don't send api response and tool calls to upstream LLMs - if m.role == TOOL_ROLE - || m.content.is_none() - || (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty()) - { - continue; - } - messages.push(m.clone()); - } - + messages.append( + &mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()), + ); messages }