more updates

This commit is contained in:
Adil Hafeez 2024-10-24 15:32:51 -07:00
parent 03a02455e8
commit 81f50911a0
11 changed files with 269 additions and 88 deletions

View file

@ -160,4 +160,3 @@ required:
- version
- listener
- llm_providers
- prompt_targets

View file

@ -19,3 +19,4 @@ services:
- "host.docker.internal:host-gateway"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
- MISTRAL_API_KEY=${MISTRAL_API_KEY:?error}

View file

@ -47,13 +47,14 @@ def validate_and_render_schema():
config_schema_yaml = yaml.safe_load(arch_config_schema)
inferred_clusters = {}
for prompt_target in config_yaml["prompt_targets"]:
name = prompt_target.get("endpoint", {}).get("name", "")
if name not in inferred_clusters:
inferred_clusters[name] = {
"name": name,
"port": 80, # default port
}
if "prompt_targets" in config_yaml:
for prompt_target in config_yaml["prompt_targets"]:
name = prompt_target.get("endpoint", {}).get("name", "")
if name not in inferred_clusters:
inferred_clusters[name] = {
"name": name,
"port": 80, # default port
}
print(inferred_clusters)
endpoints = config_yaml.get("endpoints", {})

View file

@ -19,18 +19,25 @@ def predict(message, history):
history_openai_format.append({"role": "assistant", "content": assistant})
history_openai_format.append({"role": "user", "content": message})
response = client.chat.completions.create(
stream = True
raw_response = client.chat.completions.with_raw_response.create(
model="gpt-3.5-turbo",
messages=history_openai_format,
temperature=1.0,
stream=True,
stream=stream,
)
response = raw_response.parse()
partial_message = ""
for chunk in response:
if chunk.choices[0].delta.content is not None:
partial_message = partial_message + chunk.choices[0].delta.content
yield partial_message
if stream:
for chunk in response:
if chunk.choices[0].delta.content is not None:
partial_message = partial_message + chunk.choices[0].delta.content
yield partial_message
else:
partial_message = response.choices[0].message.content
yield partial_message
gr.ChatInterface(predict).launch(server_name="0.0.0.0", server_port=8080)

View file

@ -34,7 +34,10 @@ pub struct SearchPointResult {
}
pub mod open_ai {
use std::collections::{HashMap, VecDeque};
use std::{
collections::{HashMap, VecDeque},
fmt::Display,
};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use serde_yaml::Value;
@ -256,37 +259,44 @@ pub mod open_ai {
NoChunks,
}
impl TryFrom<&str> for ChatCompletionChunkResponse {
type Error = ChatCompletionChunkResponseError;
pub struct ChatCompletionChunkResponseServerEvents {
pub events: Vec<ChatCompletionChunkResponse>,
}
fn try_from(value: &str) -> Result<Self, Self::Error> {
let mut response_chunks: VecDeque<ChatCompletionChunkResponse> = value
.lines()
.filter(|line| line.starts_with("data: "))
.map(|line| line.get(6..).unwrap())
.filter(|data_chunk| *data_chunk != "[DONE]")
.map(|data_chunk| serde_json::from_str::<ChatCompletionChunkResponse>(data_chunk))
.collect::<Result<VecDeque<ChatCompletionChunkResponse>, _>>()?;
let new_contents: String = response_chunks
.iter_mut()
impl Display for ChatCompletionChunkResponseServerEvents {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let tokens_str = self
.events
.iter()
.map(|response_chunk| {
response_chunk.choices[0]
.delta
.content
.take()
.clone()
.unwrap_or("".to_string())
})
.collect::<Vec<String>>()
.join("");
let mut response_chunk = response_chunks
.pop_front()
.ok_or(ChatCompletionChunkResponseError::NoChunks)?;
write!(f, "{}", tokens_str)
}
}
response_chunk.choices[0].delta.content = Some(new_contents);
impl TryFrom<&str> for ChatCompletionChunkResponseServerEvents {
type Error = ChatCompletionChunkResponseError;
Ok(response_chunk)
fn try_from(value: &str) -> Result<Self, Self::Error> {
let response_chunks: VecDeque<ChatCompletionChunkResponse> = value
.lines()
.filter(|line| line.starts_with("data: "))
.map(|line| line.get(6..).unwrap())
.filter(|data_chunk| *data_chunk != "[DONE]")
.map(serde_json::from_str::<ChatCompletionChunkResponse>)
.collect::<Result<VecDeque<ChatCompletionChunkResponse>, _>>()?;
Ok(ChatCompletionChunkResponseServerEvents {
events: response_chunks.into(),
})
}
}
@ -357,7 +367,7 @@ pub struct PromptGuardResponse {
#[cfg(test)]
mod test {
use crate::common_types::open_ai::Message;
use crate::common_types::open_ai::{ChatCompletionChunkResponseServerEvents, Message};
use pretty_assertions::{assert_eq, assert_ne};
use std::collections::HashMap;
@ -510,13 +520,50 @@ data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.c
"#;
let chunk_response: ChatCompletionChunkResponse =
ChatCompletionChunkResponse::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(chunk_response.choices.len(), 1);
let sever_events =
ChatCompletionChunkResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 5);
assert_eq!(
chunk_response.choices[0].delta.content.as_ref().unwrap(),
"Hello! How can"
sever_events.events[0].choices[0]
.delta
.content
.as_ref()
.unwrap(),
""
);
assert_eq!(
sever_events.events[1].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"Hello"
);
assert_eq!(
sever_events.events[2].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"!"
);
assert_eq!(
sever_events.events[3].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" How"
);
assert_eq!(
sever_events.events[4].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" can"
);
assert_eq!(sever_events.to_string(), "Hello! How can");
}
#[test]
@ -538,12 +585,90 @@ data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.c
data: [DONE]
"#;
let chunk_response: ChatCompletionChunkResponse =
ChatCompletionChunkResponse::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(chunk_response.choices.len(), 1);
let sever_events: ChatCompletionChunkResponseServerEvents =
ChatCompletionChunkResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 6);
assert_eq!(
chunk_response.choices[0].delta.content.as_ref().unwrap(),
" I assist you today?"
sever_events.events[0].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" I"
);
assert_eq!(
sever_events.events[1].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" assist"
);
assert_eq!(
sever_events.events[2].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" you"
);
assert_eq!(
sever_events.events[3].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" today"
);
assert_eq!(
sever_events.events[4].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"?"
);
assert_eq!(sever_events.events[5].choices[0].delta.content, None);
assert_eq!(sever_events.to_string(), " I assist you today?");
}
#[test]
fn stream_chunk_parse_mistral() {
use super::open_ai::{ChatCompletionChunkResponse, ChunkChoice, Delta};
const CHUNK_RESPONSE: &str = r#"data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" How"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" can"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" assist"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" you"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" today"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"total_tokens":13,"completion_tokens":9}}
data: [DONE]
"#;
let sever_events: ChatCompletionChunkResponseServerEvents =
ChatCompletionChunkResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 11);
assert_eq!(
sever_events.to_string(),
"Hello! How can I assist you today?"
);
}
}

View file

@ -27,12 +27,12 @@ pub enum GatewayMode {
pub struct Configuration {
pub version: String,
pub listener: Listener,
pub endpoints: HashMap<String, Endpoint>,
pub endpoints: Option<HashMap<String, Endpoint>>,
pub llm_providers: Vec<LlmProvider>,
pub overrides: Option<Overrides>,
pub system_prompt: Option<String>,
pub prompt_guards: Option<PromptGuards>,
pub prompt_targets: Vec<PromptTarget>,
pub prompt_targets: Option<Vec<PromptTarget>>,
pub error_target: Option<ErrorTargetDetail>,
pub ratelimits: Option<Vec<Ratelimit>>,
pub tracing: Option<Tracing>,
@ -246,8 +246,10 @@ mod test {
);
let prompt_targets = &config.prompt_targets;
assert_eq!(prompt_targets.len(), 2);
assert_eq!(prompt_targets.as_ref().unwrap().len(), 2);
let prompt_target = prompt_targets
.as_ref()
.unwrap()
.iter()
.find(|p| p.name == "reboot_network_device")
.unwrap();
@ -255,6 +257,8 @@ mod test {
assert_eq!(prompt_target.default, None);
let prompt_target = prompt_targets
.as_ref()
.unwrap()
.iter()
.find(|p| p.name == "information_extraction")
.unwrap();

View file

@ -1,17 +1,19 @@
use log::debug;
#[derive(Debug, PartialEq, Eq)]
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
#[allow(dead_code)]
pub enum Error {
UnknownModel,
FailedToTokenize,
#[error("Unknown model: {model_name}")]
UnknownModel { model_name: String },
}
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
debug!("getting token count model={}", model_name);
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel)?;
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel {
model_name: model_name.to_string(),
})?;
Ok(bpe.encode_ordinary(text).len())
}
@ -32,7 +34,9 @@ mod test {
#[test]
fn unrecognized_model() {
assert_eq!(
Error::UnknownModel,
Error::UnknownModel {
model_name: "unknown".to_string()
},
token_count("unknown", "").expect_err("unknown model")
)
}

View file

@ -1,6 +1,7 @@
use crate::filter_context::WasmMetrics;
use common::common_types::open_ai::{
ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, StreamOptions,
ChatCompletionChunkResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
StreamOptions,
};
use common::configuration::LlmProvider;
use common::consts::{
@ -258,23 +259,10 @@ impl HttpContext for StreamContext {
);
if !self.is_chat_completions_request {
debug!("non-chatgpt request");
if let Some(body_str) = self
.get_http_response_body(0, body_size)
.and_then(|bytes| String::from_utf8(bytes).ok())
{
debug!(
"on_http_response_body non-chatgpt request [S={}] body_str={}",
self.context_id, body_str
);
}
debug!("non-chatcompletion request");
return Action::Continue;
}
if !end_of_stream && self.streaming_response.is_none() {
return Action::Pause;
}
let body = match self.streaming_response.take() {
Some(mut streaming_response) => {
if end_of_stream && body_size == 0 {
@ -320,7 +308,7 @@ impl HttpContext for StreamContext {
}
};
let body_utf8 = match String::from_utf8(body.to_vec()) {
let body_utf8 = match String::from_utf8(body) {
Ok(body_utf8) => body_utf8,
Err(e) => {
debug!("could not convert to utf8: {}", e);
@ -328,41 +316,51 @@ impl HttpContext for StreamContext {
}
};
debug!("chunk data: body str: {}", body_utf8);
if self.streaming_response.is_some() {
let chat_completions_chunk_response =
match ChatCompletionChunkResponse::try_from(body_utf8.as_str()) {
let chat_completions_chunk_response_events =
match ChatCompletionChunkResponseServerEvents::try_from(body_utf8.as_str()) {
Ok(response) => response,
Err(e) => {
debug!(
"invalid streaming response: body str: {}, {:?}",
body_utf8, e
);
self.send_server_error(e.into(), None);
return Action::Pause;
return Action::Continue;
}
};
if let Some(content) = chat_completions_chunk_response
.choices
if chat_completions_chunk_response_events.events.is_empty() {
debug!("empty streaming response");
return Action::Continue;
}
let mut model = chat_completions_chunk_response_events
.events
.first()
.unwrap()
.delta
.content
.as_ref()
{
let model = &chat_completions_chunk_response.model;
let token_count = tokenizer::token_count(model, content).unwrap_or(0);
self.response_tokens += token_count;
.model
.clone();
let tokens_str = chat_completions_chunk_response_events.to_string();
//HACK: add support for tokenizing mistral and other models
//filed issue https://github.com/katanemo/arch/issues/222
if model.starts_with("mistral") || model.starts_with("ministral") {
model = "gpt-4".to_string();
}
let token_count = match tokenizer::token_count(model.as_str(), tokens_str.as_str()) {
Ok(token_count) => token_count,
Err(e) => {
debug!("could not get token count: {:?}", e);
return Action::Continue;
}
};
self.response_tokens += token_count;
} else {
debug!("non streaming response");
let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_slice(&body) {
match serde_json::from_str(body_utf8.as_str()) {
Ok(de) => de,
Err(_e) => {
debug!("invalid response: {}", String::from_utf8_lossy(&body));
debug!("invalid response: {}", body_utf8);
return Action::Continue;
}
};
@ -381,7 +379,6 @@ impl HttpContext for StreamContext {
self.context_id, self.response_tokens, end_of_stream
);
// TODO:: ratelimit based on response tokens.
Action::Continue
}
}

View file

@ -243,7 +243,7 @@ impl RootContext for FilterContext {
self.overrides = Rc::new(config.overrides);
let mut prompt_targets = HashMap::new();
for pt in config.prompt_targets {
for pt in config.prompt_targets.unwrap_or_default() {
prompt_targets.insert(pt.name.clone(), pt.clone());
}
self.system_prompt = Rc::new(config.system_prompt);

View file

@ -0,0 +1,31 @@
version: "0.1-beta"
listener:
address: 0.0.0.0
port: 10000
message_format: huggingface
connect_timeout: 0.005s
llm_providers:
- name: gpt-3.5
access_key: $OPENAI_API_KEY
provider: openai
model: gpt-3.5-turbo
- name: gpt-4o
access_key: $OPENAI_API_KEY
provider: openai
model: gpt-4o
- name: ministral-8b
access_key: $MISTRAL_API_KEY
provider: mistral
model: ministral-8b-latest
- name: ministral-3b
access_key: $MISTRAL_API_KEY
provider: mistral
model: ministral-3b-latest
tracing:
random_sampling: 100

View file

@ -0,0 +1,12 @@
services:
chatbot_ui:
build:
context: ../../chatbot_ui
dockerfile: Dockerfile
ports:
- "18080:8080"
environment:
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:12000/v1
extra_hosts:
- "host.docker.internal:host-gateway"