mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
more changes
This commit is contained in:
parent
4ab7665c30
commit
e2d49fb3f2
3 changed files with 47 additions and 41 deletions
|
|
@ -9,7 +9,7 @@ use common::llm_providers::LlmProviders;
|
||||||
use common::ratelimit;
|
use common::ratelimit;
|
||||||
use common::stats::Gauge;
|
use common::stats::Gauge;
|
||||||
use common::tracing::TraceData;
|
use common::tracing::TraceData;
|
||||||
use log::debug;
|
use log::trace;
|
||||||
use log::warn;
|
use log::warn;
|
||||||
use proxy_wasm::traits::*;
|
use proxy_wasm::traits::*;
|
||||||
use proxy_wasm::types::*;
|
use proxy_wasm::types::*;
|
||||||
|
|
@ -103,10 +103,8 @@ impl RootContext for FilterContext {
|
||||||
fn on_tick(&mut self) {
|
fn on_tick(&mut self) {
|
||||||
let _ = self.traces_queue.try_lock().map(|mut traces_queue| {
|
let _ = self.traces_queue.try_lock().map(|mut traces_queue| {
|
||||||
while let Some(trace) = traces_queue.pop_front() {
|
while let Some(trace) = traces_queue.pop_front() {
|
||||||
debug!("trace received: {:?}", trace);
|
|
||||||
|
|
||||||
let trace_str = serde_json::to_string(&trace).unwrap();
|
let trace_str = serde_json::to_string(&trace).unwrap();
|
||||||
debug!("trace: {}", trace_str);
|
trace!("trace details: {}", trace_str);
|
||||||
let call_args = CallArgs::new(
|
let call_args = CallArgs::new(
|
||||||
OTEL_COLLECTOR_HTTP,
|
OTEL_COLLECTOR_HTTP,
|
||||||
OTEL_POST_PATH,
|
OTEL_POST_PATH,
|
||||||
|
|
@ -139,7 +137,7 @@ impl Context for FilterContext {
|
||||||
_body_size: usize,
|
_body_size: usize,
|
||||||
_num_trailers: usize,
|
_num_trailers: usize,
|
||||||
) {
|
) {
|
||||||
debug!(
|
trace!(
|
||||||
"||| on_http_call_response called with token_id: {:?} |||",
|
"||| on_http_call_response called with token_id: {:?} |||",
|
||||||
token_id
|
token_id
|
||||||
);
|
);
|
||||||
|
|
@ -151,7 +149,7 @@ impl Context for FilterContext {
|
||||||
.expect("invalid token_id");
|
.expect("invalid token_id");
|
||||||
|
|
||||||
if let Some(status) = self.get_http_call_response_header(":status") {
|
if let Some(status) = self.get_http_call_response_header(":status") {
|
||||||
debug!("trace response status: {:?}", status);
|
trace!("trace response status: {:?}", status);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -153,7 +153,7 @@ impl StreamContext {
|
||||||
self.metrics
|
self.metrics
|
||||||
.input_sequence_length
|
.input_sequence_length
|
||||||
.record(token_count as u64);
|
.record(token_count as u64);
|
||||||
log::debug!("Recorded input token count: {}", token_count);
|
trace!("Recorded input token count: {}", token_count);
|
||||||
|
|
||||||
// Check if rate limiting needs to be applied.
|
// Check if rate limiting needs to be applied.
|
||||||
if let Some(selector) = self.ratelimit_selector.take() {
|
if let Some(selector) = self.ratelimit_selector.take() {
|
||||||
|
|
@ -164,7 +164,7 @@ impl StreamContext {
|
||||||
NonZero::new(token_count as u32).unwrap(),
|
NonZero::new(token_count as u32).unwrap(),
|
||||||
)?;
|
)?;
|
||||||
} else {
|
} else {
|
||||||
log::debug!("No rate limit applied for model: {}", model);
|
trace!("No rate limit applied for model: {}", model);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
@ -331,7 +331,7 @@ impl HttpContext for StreamContext {
|
||||||
Ok(duration) => {
|
Ok(duration) => {
|
||||||
// Convert the duration to milliseconds
|
// Convert the duration to milliseconds
|
||||||
let duration_ms = duration.as_millis();
|
let duration_ms = duration.as_millis();
|
||||||
debug!("Total latency: {} milliseconds", duration_ms);
|
debug!("request latency: {}ms", duration_ms);
|
||||||
// Record the latency to the latency histogram
|
// Record the latency to the latency histogram
|
||||||
self.metrics.request_latency.record(duration_ms as u64);
|
self.metrics.request_latency.record(duration_ms as u64);
|
||||||
|
|
||||||
|
|
@ -339,11 +339,14 @@ impl HttpContext for StreamContext {
|
||||||
// Compute the time per output token
|
// Compute the time per output token
|
||||||
let tpot = duration_ms as u64 / self.response_tokens as u64;
|
let tpot = duration_ms as u64 / self.response_tokens as u64;
|
||||||
|
|
||||||
debug!("Time per output token: {} milliseconds", tpot);
|
|
||||||
// Record the time per output token
|
// Record the time per output token
|
||||||
self.metrics.time_per_output_token.record(tpot);
|
self.metrics.time_per_output_token.record(tpot);
|
||||||
|
|
||||||
debug!("Tokens per second: {}", 1000 / tpot);
|
trace!(
|
||||||
|
"time per token: {}ms, tokens per second: {}",
|
||||||
|
tpot,
|
||||||
|
1000 / tpot
|
||||||
|
);
|
||||||
// Record the tokens per second
|
// Record the tokens per second
|
||||||
self.metrics.tokens_per_second.record(1000 / tpot);
|
self.metrics.tokens_per_second.record(1000 / tpot);
|
||||||
}
|
}
|
||||||
|
|
@ -499,7 +502,7 @@ impl HttpContext for StreamContext {
|
||||||
match current_time.duration_since(self.start_time) {
|
match current_time.duration_since(self.start_time) {
|
||||||
Ok(duration) => {
|
Ok(duration) => {
|
||||||
let duration_ms = duration.as_millis();
|
let duration_ms = duration.as_millis();
|
||||||
debug!("Time to First Token (TTFT): {} milliseconds", duration_ms);
|
debug!("time to first token: {}ms", duration_ms);
|
||||||
self.ttft_duration = Some(duration);
|
self.ttft_duration = Some(duration);
|
||||||
self.metrics.time_to_first_token.record(duration_ms as u64);
|
self.metrics.time_to_first_token.record(duration_ms as u64);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -263,7 +263,7 @@ impl StreamContext {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// update prompt target name from the tool call
|
// update prompt target name from the tool call response
|
||||||
callout_context.prompt_target_name =
|
callout_context.prompt_target_name =
|
||||||
Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone());
|
Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone());
|
||||||
|
|
||||||
|
|
@ -338,7 +338,8 @@ impl StreamContext {
|
||||||
ARCH_INTERNAL_CLUSTER_NAME,
|
ARCH_INTERNAL_CLUSTER_NAME,
|
||||||
&path,
|
&path,
|
||||||
headers,
|
headers,
|
||||||
Some(tool_params_json_str.as_bytes()),
|
// Some(tool_params_json_str.as_bytes()),
|
||||||
|
None,
|
||||||
vec![],
|
vec![],
|
||||||
Duration::from_secs(5),
|
Duration::from_secs(5),
|
||||||
);
|
);
|
||||||
|
|
@ -363,7 +364,6 @@ impl StreamContext {
|
||||||
let http_status = self
|
let http_status = self
|
||||||
.get_http_call_response_header(":status")
|
.get_http_call_response_header(":status")
|
||||||
.unwrap_or(StatusCode::OK.as_str().to_string());
|
.unwrap_or(StatusCode::OK.as_str().to_string());
|
||||||
debug!("api_call_response_handler: http_status: {}", http_status);
|
|
||||||
if http_status != StatusCode::OK.as_str() {
|
if http_status != StatusCode::OK.as_str() {
|
||||||
warn!(
|
warn!(
|
||||||
"api server responded with non 2xx status code: {}",
|
"api server responded with non 2xx status code: {}",
|
||||||
|
|
@ -385,7 +385,7 @@ impl StreamContext {
|
||||||
self.tool_call_response.as_ref().unwrap()
|
self.tool_call_response.as_ref().unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut messages = self.filter_out_arch_messages(&callout_context);
|
let mut messages = self.construct_llm_messages(&callout_context);
|
||||||
|
|
||||||
let user_message = match messages.pop() {
|
let user_message = match messages.pop() {
|
||||||
Some(user_message) => user_message,
|
Some(user_message) => user_message,
|
||||||
|
|
@ -442,25 +442,39 @@ impl StreamContext {
|
||||||
self.resume_http_request();
|
self.resume_http_request();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
|
fn get_system_prompt(&self, prompt_target: Option<PromptTarget>) -> Option<String> {
|
||||||
let mut messages: Vec<Message> = Vec::new();
|
match prompt_target {
|
||||||
// add system prompt
|
None => self.system_prompt.as_ref().clone(),
|
||||||
|
Some(prompt_target) => match prompt_target.system_prompt {
|
||||||
|
None => self.system_prompt.as_ref().clone(),
|
||||||
|
Some(system_prompt) => Some(system_prompt),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn filter_out_arch_messages(&self, messages: &[Message]) -> Vec<Message> {
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.filter(|m| {
|
||||||
|
!(m.role == TOOL_ROLE
|
||||||
|
|| m.content.is_none()
|
||||||
|
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty()))
|
||||||
|
})
|
||||||
|
.cloned()
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
|
||||||
|
let mut messages: Vec<Message> = Vec::new();
|
||||||
|
|
||||||
|
// add system prompt
|
||||||
let system_prompt = match callout_context.prompt_target_name.as_ref() {
|
let system_prompt = match callout_context.prompt_target_name.as_ref() {
|
||||||
None => self.system_prompt.as_ref().clone(),
|
None => self.system_prompt.as_ref().clone(),
|
||||||
Some(prompt_target_name) => {
|
Some(prompt_target_name) => {
|
||||||
let prompt_system_prompt = self
|
self.get_system_prompt(self.prompt_targets.get(prompt_target_name).cloned())
|
||||||
.prompt_targets
|
|
||||||
.get(prompt_target_name)
|
|
||||||
.unwrap()
|
|
||||||
.clone()
|
|
||||||
.system_prompt;
|
|
||||||
match prompt_system_prompt {
|
|
||||||
None => self.system_prompt.as_ref().clone(),
|
|
||||||
Some(system_prompt) => Some(system_prompt),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if system_prompt.is_some() {
|
if system_prompt.is_some() {
|
||||||
let system_prompt_message = Message {
|
let system_prompt_message = Message {
|
||||||
role: SYSTEM_ROLE.to_string(),
|
role: SYSTEM_ROLE.to_string(),
|
||||||
|
|
@ -472,18 +486,9 @@ impl StreamContext {
|
||||||
messages.push(system_prompt_message);
|
messages.push(system_prompt_message);
|
||||||
}
|
}
|
||||||
|
|
||||||
// don't send tools message and api response to chat gpt
|
messages.append(
|
||||||
for m in callout_context.request_body.messages.iter() {
|
&mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
|
||||||
// don't send api response and tool calls to upstream LLMs
|
);
|
||||||
if m.role == TOOL_ROLE
|
|
||||||
|| m.content.is_none()
|
|
||||||
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
|
|
||||||
{
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
messages.push(m.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
messages
|
messages
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue