mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
latest
This commit is contained in:
parent
ad23fb81bc
commit
4c7a7df08c
3 changed files with 52 additions and 27 deletions
|
|
@ -34,7 +34,7 @@ pub struct SearchPointResult {
|
|||
}
|
||||
|
||||
pub mod open_ai {
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use serde_yaml::Value;
|
||||
|
|
@ -246,6 +246,47 @@ pub mod open_ai {
|
|||
pub choices: Vec<ChunkChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ChatCompletionChunkResponseError {
|
||||
#[error("failed to deserialize")]
|
||||
Deserialization(#[from] serde_json::Error),
|
||||
#[error("empty content in data chunk")]
|
||||
EmptyContent,
|
||||
#[error("no chunks present")]
|
||||
NoChunks,
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for ChatCompletionChunkResponse {
|
||||
type Error = ChatCompletionChunkResponseError;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
let mut response_chunks: VecDeque<ChatCompletionChunkResponse> = value
|
||||
.split("data: ")
|
||||
.map(|data_chunk| serde_json::from_str::<ChatCompletionChunkResponse>(data_chunk))
|
||||
.collect::<Result<VecDeque<ChatCompletionChunkResponse>, _>>()?;
|
||||
|
||||
let new_contents: String = response_chunks
|
||||
.iter_mut()
|
||||
.map(|response_chunk| {
|
||||
response_chunk.choices[0]
|
||||
.delta
|
||||
.content
|
||||
.take()
|
||||
.ok_or(ChatCompletionChunkResponseError::EmptyContent)
|
||||
})
|
||||
.collect::<Result<Vec<String>, _>>()?
|
||||
.join(" ");
|
||||
|
||||
let mut response_chunk = response_chunks
|
||||
.pop_front()
|
||||
.ok_or(ChatCompletionChunkResponseError::NoChunks)?;
|
||||
|
||||
response_chunk.choices[0].delta.content = Some(new_contents);
|
||||
|
||||
Ok(response_chunk)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChunkChoice {
|
||||
pub delta: Delta,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use proxy_wasm::types::Status;
|
||||
use serde_json::error;
|
||||
|
||||
use crate::ratelimit;
|
||||
use crate::{common_types::open_ai::ChatCompletionChunkResponseError, ratelimit};
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ClientError {
|
||||
|
|
@ -36,4 +37,6 @@ pub enum ServerError {
|
|||
ExceededRatelimit(ratelimit::Error),
|
||||
#[error("{why}")]
|
||||
BadRequest { why: String },
|
||||
#[error("error in streaming response")]
|
||||
Streaming(#[from] ChatCompletionChunkResponseError),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -290,32 +290,13 @@ impl HttpContext for StreamContext {
|
|||
if self.streaming_response.is_some() {
|
||||
let body_str = String::from_utf8(body).expect("body is not utf-8");
|
||||
debug!("streaming response");
|
||||
let chat_completions_data = match body_str.split_once("data: ") {
|
||||
Some((_, chat_completions_data)) => chat_completions_data,
|
||||
None => {
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(String::from("parsing error in streaming data")),
|
||||
None,
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
let chat_completions_chunk_response: ChatCompletionChunkResponse =
|
||||
match serde_json::from_str(chat_completions_data) {
|
||||
Ok(de) => de,
|
||||
Err(_) => {
|
||||
if chat_completions_data != "[NONE]" {
|
||||
debug!("received incorrect streaming data={chat_completions_data}");
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(String::from(
|
||||
"error in streaming response",
|
||||
)),
|
||||
None,
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
return Action::Continue;
|
||||
let chat_completions_chunk_response =
|
||||
match ChatCompletionChunkResponse::try_from(body_str.as_str()) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
self.send_server_error(e.into(), None);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue