This commit is contained in:
José Ulises Niño Rivera 2024-10-22 21:47:10 -04:00
parent ad23fb81bc
commit 4c7a7df08c
3 changed files with 52 additions and 27 deletions

View file

@ -34,7 +34,7 @@ pub struct SearchPointResult {
}
pub mod open_ai {
use std::collections::HashMap;
use std::collections::{HashMap, VecDeque};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use serde_yaml::Value;
@ -246,6 +246,47 @@ pub mod open_ai {
pub choices: Vec<ChunkChoice>,
}
/// Errors produced while converting a raw streaming response body into a
/// single `ChatCompletionChunkResponse`.
#[derive(Debug, thiserror::Error)]
pub enum ChatCompletionChunkResponseError {
/// A data chunk could not be parsed as JSON; wraps the underlying
/// `serde_json::Error` (converted automatically through `#[from]`).
#[error("failed to deserialize")]
Deserialization(#[from] serde_json::Error),
/// A parsed chunk carried no `content` in its delta.
#[error("empty content in data chunk")]
EmptyContent,
/// The body did not contain any chunk to build a response from.
#[error("no chunks present")]
NoChunks,
}
/// Builds one `ChatCompletionChunkResponse` out of a streaming body that may
/// contain several `data: <json>` events.
///
/// Every event is deserialized, the `content` of each event's first choice is
/// collected, and the first event is returned with its first choice's content
/// replaced by all collected contents joined with a single space.
///
/// # Errors
///
/// * `Deserialization` - an event payload is not valid JSON for this type.
/// * `EmptyContent` - an event has no choices, or its first choice's delta has
///   no content.
/// * `NoChunks` - the body contains no event payloads at all.
impl TryFrom<&str> for ChatCompletionChunkResponse {
    type Error = ChatCompletionChunkResponseError;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let mut response_chunks = value
            .split("data: ")
            .map(str::trim)
            // `split` yields an empty leading segment when the body starts with
            // "data: ", and `[DONE]` is the end-of-stream sentinel rather than
            // JSON; neither may reach the deserializer.
            .filter(|data_chunk| !data_chunk.is_empty() && *data_chunk != "[DONE]")
            .map(|data_chunk| serde_json::from_str::<ChatCompletionChunkResponse>(data_chunk))
            .collect::<Result<VecDeque<ChatCompletionChunkResponse>, _>>()?;

        // NOTE(review): contents are joined with a space, which alters the text
        // when deltas are sub-word tokens - confirm this is what callers want.
        let new_contents: String = response_chunks
            .iter_mut()
            .map(|response_chunk| {
                response_chunk
                    .choices
                    // `first_mut` instead of `[0]`: a chunk without choices must
                    // surface as an error, not as a panic.
                    .first_mut()
                    .and_then(|choice| choice.delta.content.take())
                    .ok_or(ChatCompletionChunkResponseError::EmptyContent)
            })
            .collect::<Result<Vec<String>, _>>()?
            .join(" ");

        let mut response_chunk = response_chunks
            .pop_front()
            .ok_or(ChatCompletionChunkResponseError::NoChunks)?;

        // The pass above already proved every chunk has a first choice, so this
        // branch is always taken; `if let` merely avoids an indexing panic path.
        if let Some(choice) = response_chunk.choices.first_mut() {
            choice.delta.content = Some(new_contents);
        }

        Ok(response_chunk)
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkChoice {
pub delta: Delta,

View file

@ -1,6 +1,7 @@
use proxy_wasm::types::Status;
use serde_json::error;
use crate::ratelimit;
use crate::{common_types::open_ai::ChatCompletionChunkResponseError, ratelimit};
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
@ -36,4 +37,6 @@ pub enum ServerError {
ExceededRatelimit(ratelimit::Error),
#[error("{why}")]
BadRequest { why: String },
#[error("error in streaming response")]
Streaming(#[from] ChatCompletionChunkResponseError),
}

View file

@ -290,32 +290,13 @@ impl HttpContext for StreamContext {
if self.streaming_response.is_some() {
let body_str = String::from_utf8(body).expect("body is not utf-8");
debug!("streaming response");
let chat_completions_data = match body_str.split_once("data: ") {
Some((_, chat_completions_data)) => chat_completions_data,
None => {
self.send_server_error(
ServerError::LogicError(String::from("parsing error in streaming data")),
None,
);
return Action::Pause;
}
};
let chat_completions_chunk_response: ChatCompletionChunkResponse =
match serde_json::from_str(chat_completions_data) {
Ok(de) => de,
Err(_) => {
if chat_completions_data != "[NONE]" {
debug!("received incorrect streaming data={chat_completions_data}");
self.send_server_error(
ServerError::LogicError(String::from(
"error in streaming response",
)),
None,
);
return Action::Continue;
}
return Action::Continue;
let chat_completions_chunk_response =
match ChatCompletionChunkResponse::try_from(body_str.as_str()) {
Ok(response) => response,
Err(e) => {
self.send_server_error(e.into(), None);
return Action::Pause;
}
};