This commit is contained in:
Adil Hafeez 2024-10-24 01:32:21 -07:00
parent 6982d0a575
commit 03a02455e8
11 changed files with 175 additions and 34 deletions

View file

@ -12,7 +12,7 @@ use common::llm_providers::LlmProviders;
use common::ratelimit::Header;
use common::{ratelimit, routing, tokenizer};
use http::StatusCode;
use log::debug;
use log::{debug, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::num::NonZero;
@ -32,6 +32,7 @@ pub struct StreamContext {
request_id: Option<String>,
}
#[derive(Debug)]
struct StreamingResponse {
bytes_read: usize,
}
@ -252,16 +253,20 @@ impl HttpContext for StreamContext {
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
debug!(
"recv [S={}] bytes={} end_stream={}",
"on_http_response_body [S={}] bytes={} end_stream={}",
self.context_id, body_size, end_of_stream
);
if !self.is_chat_completions_request {
debug!("non-chatgpt request");
if let Some(body_str) = self
.get_http_response_body(0, body_size)
.and_then(|bytes| String::from_utf8(bytes).ok())
{
debug!("recv [S={}] body_str={}", self.context_id, body_str);
debug!(
"on_http_response_body non-chatgpt request [S={}] body_str={}",
self.context_id, body_str
);
}
return Action::Continue;
}
@ -272,29 +277,68 @@ impl HttpContext for StreamContext {
let body = match self.streaming_response.take() {
Some(mut streaming_response) => {
let streaming_chunk = self
.get_http_response_body(streaming_response.bytes_read, body_size)
.expect("cant get response body");
streaming_response.bytes_read += body_size;
if end_of_stream && body_size == 0 {
return Action::Continue;
}
let chunk_start = 0;
let chunk_size = body_size;
debug!("streaming respose reading, {}..{}", chunk_start, chunk_size);
let streaming_chunk = match self.get_http_response_body(0, chunk_size) {
Some(chunk) => chunk,
None => {
warn!(
"response body empy, chunk_start: {}, chunk_size: {}",
chunk_start, chunk_size
);
return Action::Continue;
}
};
if streaming_chunk.len() != chunk_size {
warn!(
"chunk size mismatch: read: {} != requested: {}",
streaming_chunk.len(),
chunk_size
);
}
streaming_response.bytes_read += chunk_size;
// n.b: this funky take and replace of the streaming_response struct is done to appease the borrow
// checker which wouldn't let us take a mut ref of streaming_response, and then a ref for
// `get_http_response_body`
self.streaming_response = Some(streaming_response);
streaming_chunk
}
None => self
.get_http_response_body(0, body_size)
.expect("cant get response body"),
None => {
debug!("non streaming response bytes read: 0:{}", body_size);
match self.get_http_response_body(0, body_size) {
Some(body) => body,
None => {
warn!("non streaming response body empty");
return Action::Continue;
}
}
}
};
if self.streaming_response.is_some() {
let body_str = String::from_utf8(body).expect("body is not utf-8");
debug!("streaming response");
let body_utf8 = match String::from_utf8(body.to_vec()) {
Ok(body_utf8) => body_utf8,
Err(e) => {
debug!("could not convert to utf8: {}", e);
return Action::Continue;
}
};
debug!("chunk data: body str: {}", body_utf8);
if self.streaming_response.is_some() {
let chat_completions_chunk_response =
match ChatCompletionChunkResponse::try_from(body_str.as_str()) {
match ChatCompletionChunkResponse::try_from(body_utf8.as_str()) {
Ok(response) => response,
Err(e) => {
debug!(
"invalid streaming response: body str: {}, {:?}",
body_utf8, e
);
self.send_server_error(e.into(), None);
return Action::Pause;
}