From 4c7a7df08cf6369f37939c3b7a5eaa6f6991af29 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Ulises=20Ni=C3=B1o=20Rivera?=
Date: Tue, 22 Oct 2024 21:47:10 -0400
Subject: [PATCH] latest

---
 crates/common/src/common_types.rs        | 66 +++++++++++++++++++++++-
 crates/common/src/errors.rs              |  4 +-
 crates/llm_gateway/src/stream_context.rs | 31 +++--------
 3 files changed, 74 insertions(+), 27 deletions(-)

diff --git a/crates/common/src/common_types.rs b/crates/common/src/common_types.rs
index c8f91e0f..45c7a1be 100644
--- a/crates/common/src/common_types.rs
+++ b/crates/common/src/common_types.rs
@@ -34,7 +34,7 @@ pub struct SearchPointResult {
 }
 
 pub mod open_ai {
-    use std::collections::HashMap;
+    use std::collections::{HashMap, VecDeque};
 
     use serde::{ser::SerializeMap, Deserialize, Serialize};
     use serde_yaml::Value;
@@ -246,6 +246,70 @@ pub mod open_ai {
         pub choices: Vec<ChunkChoice>,
     }
 
+    /// Errors produced while folding a streamed (`data: `-framed) response body
+    /// into a single [`ChatCompletionChunkResponse`].
+    #[derive(Debug, thiserror::Error)]
+    pub enum ChatCompletionChunkResponseError {
+        /// A `data: ` payload was not a valid JSON chunk.
+        #[error("failed to deserialize")]
+        Deserialization(#[from] serde_json::Error),
+        /// A chunk had no first choice, or that choice's delta carried no content.
+        #[error("empty content in data chunk")]
+        EmptyContent,
+        /// The body contained no JSON chunks at all.
+        #[error("no chunks present")]
+        NoChunks,
+    }
+
+    impl TryFrom<&str> for ChatCompletionChunkResponse {
+        type Error = ChatCompletionChunkResponseError;
+
+        /// Parses every `data: ` payload in `value` and merges them into the first
+        /// chunk, whose first choice receives all delta contents joined by `" "`.
+        ///
+        /// Blank segments (including the one `split` yields before the leading
+        /// `data: ` prefix) and the `[DONE]` stream terminator are skipped.
+        ///
+        /// # Errors
+        ///
+        /// Returns `Deserialization` for a payload that is not valid JSON,
+        /// `EmptyContent` when a chunk has no first choice or no delta content,
+        /// and `NoChunks` when no payload remains after skipping.
+        fn try_from(value: &str) -> Result<Self, Self::Error> {
+            let mut response_chunks: VecDeque<ChatCompletionChunkResponse> = value
+                .split("data: ")
+                .map(str::trim)
+                .filter(|data_chunk| !data_chunk.is_empty() && *data_chunk != "[DONE]")
+                .map(|data_chunk| serde_json::from_str::<ChatCompletionChunkResponse>(data_chunk))
+                .collect::<Result<VecDeque<ChatCompletionChunkResponse>, _>>()?;
+
+            let new_contents: String = response_chunks
+                .iter_mut()
+                .map(|response_chunk| {
+                    // `first_mut` instead of `choices[0]`: a chunk with no choices
+                    // must surface as an error, not a panic.
+                    response_chunk
+                        .choices
+                        .first_mut()
+                        .and_then(|choice| choice.delta.content.take())
+                        .ok_or(ChatCompletionChunkResponseError::EmptyContent)
+                })
+                .collect::<Result<Vec<String>, _>>()?
+                .join(" ");
+
+            let mut response_chunk = response_chunks
+                .pop_front()
+                .ok_or(ChatCompletionChunkResponseError::NoChunks)?;
+
+            // Every surviving chunk was verified above to have a first choice.
+            if let Some(choice) = response_chunk.choices.first_mut() {
+                choice.delta.content = Some(new_contents);
+            }
+
+            Ok(response_chunk)
+        }
+    }
+
     #[derive(Debug, Clone, Serialize, Deserialize)]
     pub struct ChunkChoice {
         pub delta: Delta,
diff --git a/crates/common/src/errors.rs b/crates/common/src/errors.rs
index fd634915..9f489663 100644
--- a/crates/common/src/errors.rs
+++ b/crates/common/src/errors.rs
@@ -1,6 +1,6 @@
 use proxy_wasm::types::Status;
 
-use crate::ratelimit;
+use crate::{common_types::open_ai::ChatCompletionChunkResponseError, ratelimit};
 
 #[derive(thiserror::Error, Debug)]
 pub enum ClientError {
@@ -36,4 +36,6 @@ pub enum ServerError {
     ExceededRatelimit(ratelimit::Error),
     #[error("{why}")]
     BadRequest { why: String },
+    #[error("error in streaming response")]
+    Streaming(#[from] ChatCompletionChunkResponseError),
 }
diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs
index 9eab8586..7ae841da 100644
--- a/crates/llm_gateway/src/stream_context.rs
+++ b/crates/llm_gateway/src/stream_context.rs
@@ -290,31 +290,12 @@ impl HttpContext for StreamContext {
         if self.streaming_response.is_some() {
             let body_str = String::from_utf8(body).expect("body is not utf-8");
             debug!("streaming response");
-            let chat_completions_data = match body_str.split_once("data: ") {
-                Some((_, chat_completions_data)) => chat_completions_data,
-                None => {
-                    self.send_server_error(
-                        ServerError::LogicError(String::from("parsing error in streaming data")),
-                        None,
-                    );
-                    return Action::Pause;
-                }
-            };
 
-            let chat_completions_chunk_response: ChatCompletionChunkResponse =
-                match serde_json::from_str(chat_completions_data) {
-                    Ok(de) => de,
-                    Err(_) => {
-                        if chat_completions_data != "[NONE]" {
-                            debug!("received incorrect streaming data={chat_completions_data}");
-                            self.send_server_error(
-                                ServerError::LogicError(String::from(
-                                    "error in streaming response",
-                                )),
-                                None,
-                            );
-                            return Action::Continue;
-                        }
-                        return Action::Continue;
+            let chat_completions_chunk_response =
+                match ChatCompletionChunkResponse::try_from(body_str.as_str()) {
+                    Ok(response) => response,
+                    Err(e) => {
+                        self.send_server_error(e.into(), None);
+                        return Action::Pause;
                     }
                 };