This commit is contained in:
Adil Hafeez 2025-01-10 12:52:29 -08:00
parent e55127d325
commit 42ab061971
6 changed files with 301 additions and 1608 deletions

View file

@ -1,4 +1,4 @@
use log::debug;
use log::trace;
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
#[allow(dead_code)]
@ -9,7 +9,7 @@ pub enum Error {
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
debug!("getting token count model={}", model_name);
trace!("getting token count model={}", model_name);
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel {
model_name: model_name.to_string(),

View file

@ -312,9 +312,11 @@ impl HttpContext for StreamContext {
}
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
debug!(
trace!(
"on_http_response_body [S={}] bytes={} end_stream={}",
self.context_id, body_size, end_of_stream
self.context_id,
body_size,
end_of_stream
);
if !self.is_chat_completions_request {
@ -334,16 +336,18 @@ impl HttpContext for StreamContext {
// Record the latency to the latency histogram
self.metrics.request_latency.record(duration_ms as u64);
// Compute the time per output token
let tpot = duration_ms as u64 / self.response_tokens as u64;
if self.response_tokens > 0 {
// Compute the time per output token
let tpot = duration_ms as u64 / self.response_tokens as u64;
debug!("Time per output token: {} milliseconds", tpot);
// Record the time per output token
self.metrics.time_per_output_token.record(tpot);
debug!("Time per output token: {} milliseconds", tpot);
// Record the time per output token
self.metrics.time_per_output_token.record(tpot);
debug!("Tokens per second: {}", 1000 / tpot);
// Record the tokens per second
self.metrics.tokens_per_second.record(1000 / tpot);
debug!("Tokens per second: {}", 1000 / tpot);
// Record the tokens per second
self.metrics.tokens_per_second.record(1000 / tpot);
}
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
@ -398,9 +402,10 @@ impl HttpContext for StreamContext {
let body = if self.streaming_response {
let chunk_start = 0;
let chunk_size = body_size;
debug!(
trace!(
"streaming response reading, {}..{}",
chunk_start, chunk_size
chunk_start,
chunk_size
);
let streaming_chunk = match self.get_http_response_body(0, chunk_size) {
Some(chunk) => chunk,
@ -520,9 +525,11 @@ impl HttpContext for StreamContext {
}
}
debug!(
trace!(
"recv [S={}] total_tokens={} end_stream={}",
self.context_id, self.response_tokens, end_of_stream
self.context_id,
self.response_tokens,
end_of_stream
);
Action::Continue

View file

@ -14,7 +14,7 @@ use common::http::{CallArgs, Client};
use common::stats::Gauge;
use derivative::Derivative;
use http::StatusCode;
use log::{info, debug, warn};
use log::{debug, info, warn};
use proxy_wasm::traits::*;
use serde_yaml::Value;
use std::cell::RefCell;
@ -465,6 +465,7 @@ impl StreamContext {
fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
let mut messages: Vec<Message> = Vec::new();
info!("prompt target: {:?}", callout_context.prompt_target_name);
// add system prompt
let system_prompt = match callout_context.prompt_target_name.as_ref() {
None => self.system_prompt.as_ref().clone(),
@ -473,7 +474,7 @@ impl StreamContext {
}
};
info!("messages 1: {:?}", callout_context.request_body.messages);
info!("system_prompt: {:?}", system_prompt);
if system_prompt.is_some() {
let system_prompt_message = Message {
@ -486,12 +487,9 @@ impl StreamContext {
messages.push(system_prompt_message);
}
info!("messages 2: {:?}", messages);
messages.append(
&mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
);
info!("messages 3: {:?}", messages);
messages
}