This commit is contained in:
Adil Hafeez 2024-10-24 01:32:21 -07:00
parent 6982d0a575
commit 03a02455e8
11 changed files with 175 additions and 34 deletions

View file

@ -261,7 +261,10 @@ pub mod open_ai {
fn try_from(value: &str) -> Result<Self, Self::Error> {
let mut response_chunks: VecDeque<ChatCompletionChunkResponse> = value
.split("data: ")
.lines()
.filter(|line| line.starts_with("data: "))
.map(|line| line.get(6..).unwrap())
.filter(|data_chunk| *data_chunk != "[DONE]")
.map(|data_chunk| serde_json::from_str::<ChatCompletionChunkResponse>(data_chunk))
.collect::<Result<VecDeque<ChatCompletionChunkResponse>, _>>()?;
@ -272,10 +275,10 @@ pub mod open_ai {
.delta
.content
.take()
.ok_or(ChatCompletionChunkResponseError::EmptyContent)
.unwrap_or("".to_string())
})
.collect::<Result<Vec<String>, _>>()?
.join(" ");
.collect::<Vec<String>>()
.join("");
let mut response_chunk = response_chunks
.pop_front()
@ -489,4 +492,58 @@ mod test {
ParameterType::String
);
}
/// Parses a multi-line streaming payload made of `data: <json>` lines and
/// checks that it collapses into a single chunk response.
///
/// The expected value pins two behaviors: the `content` deltas of the
/// individual chunks are concatenated without any separator, and the first
/// chunk's empty `content` contributes nothing to the result.
#[test]
fn stream_chunk_parse() {
    // Only the response type is needed here; importing unused sibling types
    // would just produce `unused_imports` warnings.
    use super::open_ai::ChatCompletionChunkResponse;

    // NOTE: every payload line must start at column 0 with `data: `, so the
    // raw string below is intentionally not indented.
    const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}]}
"#;

    let chunk_response: ChatCompletionChunkResponse =
        ChatCompletionChunkResponse::try_from(CHUNK_RESPONSE).unwrap();

    assert_eq!(chunk_response.choices.len(), 1);
    // Compare the whole `Option` so a missing `content` shows up as a normal
    // assertion diff instead of an `unwrap` panic.
    assert_eq!(
        chunk_response.choices[0].delta.content.as_deref(),
        Some("Hello! How can")
    );
}
/// Parses a streaming payload that ends the way a finished stream does: a
/// final chunk whose `delta` is empty (`finish_reason: "stop"`) followed by
/// the `data: [DONE]` sentinel line.
///
/// The expected value pins that neither the empty delta nor the sentinel
/// breaks parsing, and that neither adds anything to the merged `content`.
#[test]
fn stream_chunk_parse_done() {
    // Only the response type is needed here; importing unused sibling types
    // would just produce `unused_imports` warnings.
    use super::open_ai::ChatCompletionChunkResponse;

    // NOTE: every payload line must start at column 0 with `data: `, so the
    // raw string below is intentionally not indented.
    const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
data: [DONE]
"#;

    let chunk_response: ChatCompletionChunkResponse =
        ChatCompletionChunkResponse::try_from(CHUNK_RESPONSE).unwrap();

    assert_eq!(chunk_response.choices.len(), 1);
    // Compare the whole `Option` so a missing `content` shows up as a normal
    // assertion diff instead of an `unwrap` panic.
    assert_eq!(
        chunk_response.choices[0].delta.content.as_deref(),
        Some(" I assist you today?")
    );
}
}

View file

@ -12,7 +12,7 @@ use common::llm_providers::LlmProviders;
use common::ratelimit::Header;
use common::{ratelimit, routing, tokenizer};
use http::StatusCode;
use log::debug;
use log::{debug, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::num::NonZero;
@ -32,6 +32,7 @@ pub struct StreamContext {
request_id: Option<String>,
}
/// Bookkeeping kept while a response body arrives in chunks.
///
/// `Debug` is derived so the value can be printed with `{:?}` in log output.
#[derive(Debug)]
struct StreamingResponse {
    /// Running total of body bytes handled so far; the response-body
    /// callback adds the size of each chunk it reads to this counter.
    bytes_read: usize,
}
@ -252,16 +253,20 @@ impl HttpContext for StreamContext {
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
debug!(
"recv [S={}] bytes={} end_stream={}",
"on_http_response_body [S={}] bytes={} end_stream={}",
self.context_id, body_size, end_of_stream
);
if !self.is_chat_completions_request {
debug!("non-chatgpt request");
if let Some(body_str) = self
.get_http_response_body(0, body_size)
.and_then(|bytes| String::from_utf8(bytes).ok())
{
debug!("recv [S={}] body_str={}", self.context_id, body_str);
debug!(
"on_http_response_body non-chatgpt request [S={}] body_str={}",
self.context_id, body_str
);
}
return Action::Continue;
}
@ -272,29 +277,68 @@ impl HttpContext for StreamContext {
let body = match self.streaming_response.take() {
Some(mut streaming_response) => {
let streaming_chunk = self
.get_http_response_body(streaming_response.bytes_read, body_size)
.expect("cant get response body");
streaming_response.bytes_read += body_size;
if end_of_stream && body_size == 0 {
return Action::Continue;
}
let chunk_start = 0;
let chunk_size = body_size;
debug!("streaming respose reading, {}..{}", chunk_start, chunk_size);
let streaming_chunk = match self.get_http_response_body(0, chunk_size) {
Some(chunk) => chunk,
None => {
warn!(
"response body empy, chunk_start: {}, chunk_size: {}",
chunk_start, chunk_size
);
return Action::Continue;
}
};
if streaming_chunk.len() != chunk_size {
warn!(
"chunk size mismatch: read: {} != requested: {}",
streaming_chunk.len(),
chunk_size
);
}
streaming_response.bytes_read += chunk_size;
// n.b: this funky take and replace of the streaming_response struct is done to appease the borrow
// checker which wouldn't let us take a mut ref of streaming_response, and then a ref for
// `get_http_response_body`
self.streaming_response = Some(streaming_response);
streaming_chunk
}
None => self
.get_http_response_body(0, body_size)
.expect("cant get response body"),
None => {
debug!("non streaming response bytes read: 0:{}", body_size);
match self.get_http_response_body(0, body_size) {
Some(body) => body,
None => {
warn!("non streaming response body empty");
return Action::Continue;
}
}
}
};
if self.streaming_response.is_some() {
let body_str = String::from_utf8(body).expect("body is not utf-8");
debug!("streaming response");
let body_utf8 = match String::from_utf8(body.to_vec()) {
Ok(body_utf8) => body_utf8,
Err(e) => {
debug!("could not convert to utf8: {}", e);
return Action::Continue;
}
};
debug!("chunk data: body str: {}", body_utf8);
if self.streaming_response.is_some() {
let chat_completions_chunk_response =
match ChatCompletionChunkResponse::try_from(body_str.as_str()) {
match ChatCompletionChunkResponse::try_from(body_utf8.as_str()) {
Ok(response) => response,
Err(e) => {
debug!(
"invalid streaming response: body str: {}, {:?}",
body_utf8, e
);
self.send_server_error(e.into(), None);
return Action::Pause;
}

View file

@ -40,8 +40,8 @@ pub fn extract_messages_for_hallucination(messages: &Vec<Message>) -> Vec<String
#[cfg(test)]
mod test {
use pretty_assertions::assert_eq;
use common::common_types::open_ai::Message;
use pretty_assertions::assert_eq;
use super::extract_messages_for_hallucination;
@ -158,7 +158,9 @@ mod test {
let messages_for_halluncination = extract_messages_for_hallucination(&messages);
println!("{:?}", messages_for_halluncination);
assert_eq!(messages_for_halluncination.len(), 3);
assert_eq!(["tell me about the weather", "Seattle", "7 days"], messages_for_halluncination.as_slice());
assert_eq!(
["tell me about the weather", "Seattle", "7 days"],
messages_for_halluncination.as_slice()
);
}
}

View file

@ -80,7 +80,10 @@ impl HttpContext for StreamContext {
}
};
debug!("developer => archgw: {}", String::from_utf8_lossy(&body_bytes));
debug!(
"developer => archgw: {}",
String::from_utf8_lossy(&body_bytes)
);
// Deserialize body into spec.
// Currently OpenAI API.