more updates

This commit is contained in:
Adil Hafeez 2024-10-24 15:32:51 -07:00
parent 03a02455e8
commit 81f50911a0
11 changed files with 269 additions and 88 deletions

View file

@ -34,7 +34,10 @@ pub struct SearchPointResult {
}
pub mod open_ai {
use std::collections::{HashMap, VecDeque};
use std::{
collections::{HashMap, VecDeque},
fmt::Display,
};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use serde_yaml::Value;
@ -256,37 +259,44 @@ pub mod open_ai {
NoChunks,
}
impl TryFrom<&str> for ChatCompletionChunkResponse {
type Error = ChatCompletionChunkResponseError;
/// The chunk responses parsed from one server-sent-events payload of a
/// streaming chat completion, kept in the order their `data: ` lines appeared.
pub struct ChatCompletionChunkResponseServerEvents {
/// One entry per `data: ` line of the payload, excluding the `[DONE]` sentinel.
pub events: Vec<ChatCompletionChunkResponse>,
}
fn try_from(value: &str) -> Result<Self, Self::Error> {
let mut response_chunks: VecDeque<ChatCompletionChunkResponse> = value
.lines()
.filter(|line| line.starts_with("data: "))
.map(|line| line.get(6..).unwrap())
.filter(|data_chunk| *data_chunk != "[DONE]")
.map(|data_chunk| serde_json::from_str::<ChatCompletionChunkResponse>(data_chunk))
.collect::<Result<VecDeque<ChatCompletionChunkResponse>, _>>()?;
let new_contents: String = response_chunks
.iter_mut()
/// Renders the parsed events as the text they carry.
impl Display for ChatCompletionChunkResponseServerEvents {
/// Writes the `delta.content` of the first choice of every event, concatenated
/// in event order with no separator. An event whose content is `None`
/// contributes an empty string.
///
/// # Panics
///
/// Indexing `choices[0]` panics if any event has no choices (assuming
/// `choices` is a `Vec`; its declaration is not visible here).
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let tokens_str = self
.events
.iter()
.map(|response_chunk| {
// Only the first choice is read; any further choices are ignored.
response_chunk.choices[0]
.delta
.content
.take()
.clone()
.unwrap_or("".to_string())
})
.collect::<Vec<String>>()
.join("");
let mut response_chunk = response_chunks
.pop_front()
.ok_or(ChatCompletionChunkResponseError::NoChunks)?;
write!(f, "{}", tokens_str)
}
}
response_chunk.choices[0].delta.content = Some(new_contents);
/// Parses the text of a server-sent-events payload into its chunk responses.
impl TryFrom<&str> for ChatCompletionChunkResponseServerEvents {
type Error = ChatCompletionChunkResponseError;
Ok(response_chunk)
/// Keeps only the lines that start with `data: `, strips that prefix, drops
/// the `[DONE]` sentinel and deserializes every remaining payload as a
/// `ChatCompletionChunkResponse`. Lines without the prefix are ignored.
///
/// # Errors
///
/// Stops at the first payload that `serde_json` cannot deserialize and returns
/// it through `?` (presumably via a `From<serde_json::Error>` impl on
/// `ChatCompletionChunkResponseError`, which is not visible here).
fn try_from(value: &str) -> Result<Self, Self::Error> {
let response_chunks: VecDeque<ChatCompletionChunkResponse> = value
.lines()
.filter(|line| line.starts_with("data: "))
// Cannot fail: the line was just checked to start with the 6-byte ASCII prefix `data: `.
.map(|line| line.get(6..).unwrap())
// `[DONE]` is a sentinel, not JSON, so it must be removed before deserializing.
.filter(|data_chunk| *data_chunk != "[DONE]")
.map(serde_json::from_str::<ChatCompletionChunkResponse>)
.collect::<Result<VecDeque<ChatCompletionChunkResponse>, _>>()?;
Ok(ChatCompletionChunkResponseServerEvents {
// Converts the `VecDeque` into the `Vec` stored in `events`.
events: response_chunks.into(),
})
}
}
@ -357,7 +367,7 @@ pub struct PromptGuardResponse {
#[cfg(test)]
mod test {
use crate::common_types::open_ai::Message;
use crate::common_types::open_ai::{ChatCompletionChunkResponseServerEvents, Message};
use pretty_assertions::{assert_eq, assert_ne};
use std::collections::HashMap;
@ -510,13 +520,50 @@ data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.c
"#;
let chunk_response: ChatCompletionChunkResponse =
ChatCompletionChunkResponse::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(chunk_response.choices.len(), 1);
let sever_events =
ChatCompletionChunkResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 5);
assert_eq!(
chunk_response.choices[0].delta.content.as_ref().unwrap(),
"Hello! How can"
sever_events.events[0].choices[0]
.delta
.content
.as_ref()
.unwrap(),
""
);
assert_eq!(
sever_events.events[1].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"Hello"
);
assert_eq!(
sever_events.events[2].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"!"
);
assert_eq!(
sever_events.events[3].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" How"
);
assert_eq!(
sever_events.events[4].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" can"
);
assert_eq!(sever_events.to_string(), "Hello! How can");
}
#[test]
@ -538,12 +585,90 @@ data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.c
data: [DONE]
"#;
let chunk_response: ChatCompletionChunkResponse =
ChatCompletionChunkResponse::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(chunk_response.choices.len(), 1);
let sever_events: ChatCompletionChunkResponseServerEvents =
ChatCompletionChunkResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 6);
assert_eq!(
chunk_response.choices[0].delta.content.as_ref().unwrap(),
" I assist you today?"
sever_events.events[0].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" I"
);
assert_eq!(
sever_events.events[1].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" assist"
);
assert_eq!(
sever_events.events[2].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" you"
);
assert_eq!(
sever_events.events[3].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" today"
);
assert_eq!(
sever_events.events[4].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"?"
);
assert_eq!(sever_events.events[5].choices[0].delta.content, None);
assert_eq!(sever_events.to_string(), " I assist you today?");
}
/// Parses a streamed response captured from the `ministral-8b-latest` model.
///
/// The payload has eleven `data: ` events followed by the `[DONE]` sentinel.
/// The first and last events carry an empty `content`, and the last one also
/// carries `finish_reason: "stop"` and a `usage` object. The test checks the
/// event count and that the contents concatenate into the full sentence.
#[test]
fn stream_chunk_parse_mistral() {
    const CHUNK_RESPONSE: &str = r#"data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" How"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" can"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" assist"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" you"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" today"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"total_tokens":13,"completion_tokens":9}}
data: [DONE]
"#;

    let server_events = ChatCompletionChunkResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();

    // `[DONE]` is filtered out by the parser, so only the eleven JSON events remain.
    assert_eq!(server_events.events.len(), 11);
    // The empty contents of the first and last events add nothing to the text.
    assert_eq!(
        server_events.to_string(),
        "Hello! How can I assist you today?"
    );
}
}

View file

@ -27,12 +27,12 @@ pub enum GatewayMode {
pub struct Configuration {
pub version: String,
pub listener: Listener,
pub endpoints: HashMap<String, Endpoint>,
pub endpoints: Option<HashMap<String, Endpoint>>,
pub llm_providers: Vec<LlmProvider>,
pub overrides: Option<Overrides>,
pub system_prompt: Option<String>,
pub prompt_guards: Option<PromptGuards>,
pub prompt_targets: Vec<PromptTarget>,
pub prompt_targets: Option<Vec<PromptTarget>>,
pub error_target: Option<ErrorTargetDetail>,
pub ratelimits: Option<Vec<Ratelimit>>,
pub tracing: Option<Tracing>,
@ -246,8 +246,10 @@ mod test {
);
let prompt_targets = &config.prompt_targets;
assert_eq!(prompt_targets.len(), 2);
assert_eq!(prompt_targets.as_ref().unwrap().len(), 2);
let prompt_target = prompt_targets
.as_ref()
.unwrap()
.iter()
.find(|p| p.name == "reboot_network_device")
.unwrap();
@ -255,6 +257,8 @@ mod test {
assert_eq!(prompt_target.default, None);
let prompt_target = prompt_targets
.as_ref()
.unwrap()
.iter()
.find(|p| p.name == "information_extraction")
.unwrap();

View file

@ -1,17 +1,19 @@
use log::debug;
/// Errors that can be returned while counting tokens.
#[derive(Debug, PartialEq, Eq)]
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
#[allow(dead_code)]
pub enum Error {
UnknownModel,
FailedToTokenize,
/// `tiktoken_rs::get_bpe_from_model` failed for `model_name`; the name is
/// kept so the error message can report which model was requested.
#[error("Unknown model: {model_name}")]
UnknownModel { model_name: String },
}
/// Counts the tokens in `text` using the BPE that `tiktoken_rs` resolves for
/// `model_name`.
///
/// Returns the length of the token sequence produced by `encode_ordinary`.
///
/// # Errors
///
/// Returns `Error::UnknownModel` when `tiktoken_rs::get_bpe_from_model` fails
/// for `model_name`; the underlying error is discarded.
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
debug!("getting token count model={}", model_name);
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel)?;
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel {
model_name: model_name.to_string(),
})?;
Ok(bpe.encode_ordinary(text).len())
}
@ -32,7 +34,9 @@ mod test {
/// A model name that cannot be resolved must yield `Error::UnknownModel`,
/// even when the text to count is empty.
#[test]
fn unrecognized_model() {
assert_eq!(
Error::UnknownModel,
Error::UnknownModel {
model_name: "unknown".to_string()
},
token_count("unknown", "").expect_err("unknown model")
)
}