From 81f50911a0c01f64672a41eff891e6fa1ba32cc7 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 24 Oct 2024 15:32:51 -0700 Subject: [PATCH] more updates --- arch/arch_config_schema.yaml | 1 - arch/docker-compose.dev.yaml | 1 + arch/tools/cli/config_generator.py | 15 +- chatbot_ui/app/run_stream.py | 19 +- crates/common/src/common_types.rs | 187 ++++++++++++++++---- crates/common/src/configuration.rs | 10 +- crates/common/src/tokenizer.rs | 14 +- crates/llm_gateway/src/stream_context.rs | 65 ++++--- crates/prompt_gateway/src/filter_context.rs | 2 +- demos/llm_routing/arch_config.yaml | 31 ++++ demos/llm_routing/docker-compose.yaml | 12 ++ 11 files changed, 269 insertions(+), 88 deletions(-) create mode 100644 demos/llm_routing/arch_config.yaml create mode 100644 demos/llm_routing/docker-compose.yaml diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index 9b63840e..142fe338 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -160,4 +160,3 @@ required: - version - listener - llm_providers - - prompt_targets diff --git a/arch/docker-compose.dev.yaml b/arch/docker-compose.dev.yaml index c2dcb332..fdf024c6 100644 --- a/arch/docker-compose.dev.yaml +++ b/arch/docker-compose.dev.yaml @@ -19,3 +19,4 @@ services: - "host.docker.internal:host-gateway" environment: - OPENAI_API_KEY=${OPENAI_API_KEY:?error} + - MISTRAL_API_KEY=${MISTRAL_API_KEY:?error} diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index 33741ee9..1e5fd4a3 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -47,13 +47,14 @@ def validate_and_render_schema(): config_schema_yaml = yaml.safe_load(arch_config_schema) inferred_clusters = {} - for prompt_target in config_yaml["prompt_targets"]: - name = prompt_target.get("endpoint", {}).get("name", "") - if name not in inferred_clusters: - inferred_clusters[name] = { - "name": name, - "port": 80, # default port - } + if "prompt_targets" in config_yaml: + for prompt_target in config_yaml["prompt_targets"]: + name = prompt_target.get("endpoint", {}).get("name", "") + if name not in inferred_clusters: + inferred_clusters[name] = { + "name": name, + "port": 80, # default port + } print(inferred_clusters) endpoints = config_yaml.get("endpoints", {}) diff --git a/chatbot_ui/app/run_stream.py b/chatbot_ui/app/run_stream.py index 00cccd23..55c505e4 100644 --- a/chatbot_ui/app/run_stream.py +++ b/chatbot_ui/app/run_stream.py @@ -19,18 +19,25 @@ def predict(message, history): history_openai_format.append({"role": "assistant", "content": assistant}) history_openai_format.append({"role": "user", "content": message}) - response = client.chat.completions.create( + stream = True + raw_response = client.chat.completions.with_raw_response.create( model="gpt-3.5-turbo", messages=history_openai_format, temperature=1.0, - stream=True, + stream=stream, ) + response = raw_response.parse() + partial_message = "" - for chunk in response: - if chunk.choices[0].delta.content is not None: - partial_message = partial_message + chunk.choices[0].delta.content - yield partial_message + if stream: + for chunk in response: + if chunk.choices[0].delta.content is not None: + partial_message = partial_message + chunk.choices[0].delta.content + yield partial_message + else: + partial_message = response.choices[0].message.content + yield partial_message gr.ChatInterface(predict).launch(server_name="0.0.0.0", server_port=8080) diff --git a/crates/common/src/common_types.rs b/crates/common/src/common_types.rs index 16925c66..f9bf5921 100644 --- a/crates/common/src/common_types.rs +++ b/crates/common/src/common_types.rs @@ -34,7 +34,10 @@ pub struct SearchPointResult { } pub mod open_ai { - use std::collections::{HashMap, VecDeque}; + use std::{ + collections::{HashMap, VecDeque}, + fmt::Display, + }; use serde::{ser::SerializeMap, Deserialize, Serialize}; use serde_yaml::Value; @@ -256,37 +259,44 @@ pub mod open_ai { NoChunks, } - impl TryFrom<&str> for ChatCompletionChunkResponse { - type Error = ChatCompletionChunkResponseError; + pub struct ChatCompletionChunkResponseServerEvents { + pub events: Vec, + } - fn try_from(value: &str) -> Result { - let mut response_chunks: VecDeque = value - .lines() - .filter(|line| line.starts_with("data: ")) - .map(|line| line.get(6..).unwrap()) - .filter(|data_chunk| *data_chunk != "[DONE]") - .map(|data_chunk| serde_json::from_str::(data_chunk)) - .collect::, _>>()?; - - let new_contents: String = response_chunks - .iter_mut() + impl Display for ChatCompletionChunkResponseServerEvents { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let tokens_str = self + .events + .iter() .map(|response_chunk| { response_chunk.choices[0] .delta .content - .take() + .clone() .unwrap_or("".to_string()) }) .collect::>() .join(""); - let mut response_chunk = response_chunks - .pop_front() - .ok_or(ChatCompletionChunkResponseError::NoChunks)?; + write!(f, "{}", tokens_str) + } + } - response_chunk.choices[0].delta.content = Some(new_contents); + impl TryFrom<&str> for ChatCompletionChunkResponseServerEvents { + type Error = ChatCompletionChunkResponseError; - Ok(response_chunk) + fn try_from(value: &str) -> Result { + let response_chunks: VecDeque = value + .lines() + .filter(|line| line.starts_with("data: ")) + .map(|line| line.get(6..).unwrap()) + .filter(|data_chunk| *data_chunk != "[DONE]") + .map(serde_json::from_str::) + .collect::, _>>()?; + + Ok(ChatCompletionChunkResponseServerEvents { + events: response_chunks.into(), + }) } } @@ -357,7 +367,7 @@ pub struct PromptGuardResponse { #[cfg(test)] mod test { - use crate::common_types::open_ai::Message; + use crate::common_types::open_ai::{ChatCompletionChunkResponseServerEvents, Message}; use pretty_assertions::{assert_eq, assert_ne}; use std::collections::HashMap; @@ -510,13 +520,50 @@ data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.c "#; - let chunk_response: ChatCompletionChunkResponse = - ChatCompletionChunkResponse::try_from(CHUNK_RESPONSE).unwrap(); - assert_eq!(chunk_response.choices.len(), 1); + let sever_events = + ChatCompletionChunkResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap(); + assert_eq!(sever_events.events.len(), 5); assert_eq!( - chunk_response.choices[0].delta.content.as_ref().unwrap(), - "Hello! How can" + sever_events.events[0].choices[0] + .delta + .content + .as_ref() + .unwrap(), + "" ); + assert_eq!( + sever_events.events[1].choices[0] + .delta + .content + .as_ref() + .unwrap(), + "Hello" + ); + assert_eq!( + sever_events.events[2].choices[0] + .delta + .content + .as_ref() + .unwrap(), + "!" + ); + assert_eq!( + sever_events.events[3].choices[0] + .delta + .content + .as_ref() + .unwrap(), + " How" + ); + assert_eq!( + sever_events.events[4].choices[0] + .delta + .content + .as_ref() + .unwrap(), + " can" + ); + assert_eq!(sever_events.to_string(), "Hello! How can"); } #[test] @@ -538,12 +585,90 @@ data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.c data: [DONE] "#; - let chunk_response: ChatCompletionChunkResponse = - ChatCompletionChunkResponse::try_from(CHUNK_RESPONSE).unwrap(); - assert_eq!(chunk_response.choices.len(), 1); + let sever_events: ChatCompletionChunkResponseServerEvents = + ChatCompletionChunkResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap(); + assert_eq!(sever_events.events.len(), 6); assert_eq!( - chunk_response.choices[0].delta.content.as_ref().unwrap(), - " I assist you today?" + sever_events.events[0].choices[0] + .delta + .content + .as_ref() + .unwrap(), + " I" + ); + assert_eq!( + sever_events.events[1].choices[0] + .delta + .content + .as_ref() + .unwrap(), + " assist" + ); + assert_eq!( + sever_events.events[2].choices[0] + .delta + .content + .as_ref() + .unwrap(), + " you" + ); + assert_eq!( + sever_events.events[3].choices[0] + .delta + .content + .as_ref() + .unwrap(), + " today" + ); + assert_eq!( + sever_events.events[4].choices[0] + .delta + .content + .as_ref() + .unwrap(), + "?" + ); + assert_eq!(sever_events.events[5].choices[0].delta.content, None); + + assert_eq!(sever_events.to_string(), " I assist you today?"); + } + + #[test] + fn stream_chunk_parse_mistral() { + use super::open_ai::{ChatCompletionChunkResponse, ChunkChoice, Delta}; + + const CHUNK_RESPONSE: &str = r#"data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]} + +data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]} + +data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]} + +data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" How"},"finish_reason":null}]} + +data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" can"},"finish_reason":null}]} + +data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]} + +data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" assist"},"finish_reason":null}]} + +data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" you"},"finish_reason":null}]} + +data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" today"},"finish_reason":null}]} + +data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":null}]} + +data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"total_tokens":13,"completion_tokens":9}} + +data: [DONE] +"#; + + let sever_events: ChatCompletionChunkResponseServerEvents = + ChatCompletionChunkResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap(); + assert_eq!(sever_events.events.len(), 11); + + assert_eq!( + sever_events.to_string(), + "Hello! How can I assist you today?" ); } } diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 293dad09..ef57845a 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -27,12 +27,12 @@ pub enum GatewayMode { pub struct Configuration { pub version: String, pub listener: Listener, - pub endpoints: HashMap, + pub endpoints: Option>, pub llm_providers: Vec, pub overrides: Option, pub system_prompt: Option, pub prompt_guards: Option, - pub prompt_targets: Vec, + pub prompt_targets: Option>, pub error_target: Option, pub ratelimits: Option>, pub tracing: Option, @@ -246,8 +246,10 @@ mod test { ); let prompt_targets = &config.prompt_targets; - assert_eq!(prompt_targets.len(), 2); + assert_eq!(prompt_targets.as_ref().unwrap().len(), 2); let prompt_target = prompt_targets + .as_ref() + .unwrap() .iter() .find(|p| p.name == "reboot_network_device") .unwrap(); @@ -255,6 +257,8 @@ mod test { assert_eq!(prompt_target.default, None); let prompt_target = prompt_targets + .as_ref() + .unwrap() .iter() .find(|p| p.name == "information_extraction") .unwrap(); diff --git a/crates/common/src/tokenizer.rs b/crates/common/src/tokenizer.rs index 25ac924e..aa0870f2 100644 --- a/crates/common/src/tokenizer.rs +++ b/crates/common/src/tokenizer.rs @@ -1,17 +1,19 @@ use log::debug; -#[derive(Debug, PartialEq, Eq)] +#[derive(thiserror::Error, Debug, PartialEq, Eq)] #[allow(dead_code)] pub enum Error { - UnknownModel, - FailedToTokenize, + #[error("Unknown model: {model_name}")] + UnknownModel { model_name: String }, } #[allow(dead_code)] pub fn token_count(model_name: &str, text: &str) -> Result { debug!("getting token count model={}", model_name); // Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton? - let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel)?; + let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel { + model_name: model_name.to_string(), + })?; Ok(bpe.encode_ordinary(text).len()) } @@ -32,7 +34,9 @@ mod test { #[test] fn unrecognized_model() { assert_eq!( - Error::UnknownModel, + Error::UnknownModel { + model_name: "unknown".to_string() + }, token_count("unknown", "").expect_err("unknown model") ) } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 2bbc8101..f66aae7c 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,6 +1,7 @@ use crate::filter_context::WasmMetrics; use common::common_types::open_ai::{ - ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, StreamOptions, + ChatCompletionChunkResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse, + StreamOptions, }; use common::configuration::LlmProvider; use common::consts::{ @@ -258,23 +259,10 @@ impl HttpContext for StreamContext { ); if !self.is_chat_completions_request { - debug!("non-chatgpt request"); - if let Some(body_str) = self - .get_http_response_body(0, body_size) - .and_then(|bytes| String::from_utf8(bytes).ok()) - { - debug!( - "on_http_response_body non-chatgpt request [S={}] body_str={}", - self.context_id, body_str - ); - } + debug!("non-chatcompletion request"); return Action::Continue; } - if !end_of_stream && self.streaming_response.is_none() { - return Action::Pause; - } - let body = match self.streaming_response.take() { Some(mut streaming_response) => { if end_of_stream && body_size == 0 { @@ -320,7 +308,7 @@ impl HttpContext for StreamContext { } }; - let body_utf8 = match String::from_utf8(body.to_vec()) { + let body_utf8 = match String::from_utf8(body) { Ok(body_utf8) => body_utf8, Err(e) => { debug!("could not convert to utf8: {}", e); @@ -328,41 +316,51 @@ impl HttpContext for StreamContext { } }; - debug!("chunk data: body str: {}", body_utf8); - if self.streaming_response.is_some() { - let chat_completions_chunk_response = - match ChatCompletionChunkResponse::try_from(body_utf8.as_str()) { + let chat_completions_chunk_response_events = + match ChatCompletionChunkResponseServerEvents::try_from(body_utf8.as_str()) { Ok(response) => response, Err(e) => { debug!( "invalid streaming response: body str: {}, {:?}", body_utf8, e ); - self.send_server_error(e.into(), None); - return Action::Pause; + return Action::Continue; } }; - if let Some(content) = chat_completions_chunk_response - .choices + if chat_completions_chunk_response_events.events.is_empty() { + debug!("empty streaming response"); + return Action::Continue; + } + + let mut model = chat_completions_chunk_response_events + .events .first() .unwrap() - .delta - .content - .as_ref() - { - let model = &chat_completions_chunk_response.model; - let token_count = tokenizer::token_count(model, content).unwrap_or(0); - self.response_tokens += token_count; + .model + .clone(); + let tokens_str = chat_completions_chunk_response_events.to_string(); + //HACK: add support for tokenizing mistral and other models + //filed issue https://github.com/katanemo/arch/issues/222 + if model.starts_with("mistral") || model.starts_with("ministral") { + model = "gpt-4".to_string(); } + let token_count = match tokenizer::token_count(model.as_str(), tokens_str.as_str()) { + Ok(token_count) => token_count, + Err(e) => { + debug!("could not get token count: {:?}", e); + return Action::Continue; + } + }; + self.response_tokens += token_count; } else { debug!("non streaming response"); let chat_completions_response: ChatCompletionsResponse = - match serde_json::from_slice(&body) { + match serde_json::from_str(body_utf8.as_str()) { Ok(de) => de, Err(_e) => { - debug!("invalid response: {}", String::from_utf8_lossy(&body)); + debug!("invalid response: {}", body_utf8); return Action::Continue; } }; @@ -381,7 +379,6 @@ impl HttpContext for StreamContext { self.context_id, self.response_tokens, end_of_stream ); - // TODO:: ratelimit based on response tokens. Action::Continue } } diff --git a/crates/prompt_gateway/src/filter_context.rs b/crates/prompt_gateway/src/filter_context.rs index 3f1d3f0d..de120369 100644 --- a/crates/prompt_gateway/src/filter_context.rs +++ b/crates/prompt_gateway/src/filter_context.rs @@ -243,7 +243,7 @@ impl RootContext for FilterContext { self.overrides = Rc::new(config.overrides); let mut prompt_targets = HashMap::new(); - for pt in config.prompt_targets { + for pt in config.prompt_targets.unwrap_or_default() { prompt_targets.insert(pt.name.clone(), pt.clone()); } self.system_prompt = Rc::new(config.system_prompt); diff --git a/demos/llm_routing/arch_config.yaml b/demos/llm_routing/arch_config.yaml new file mode 100644 index 00000000..c5839bf4 --- /dev/null +++ b/demos/llm_routing/arch_config.yaml @@ -0,0 +1,31 @@ +version: "0.1-beta" + +listener: + address: 0.0.0.0 + port: 10000 + message_format: huggingface + connect_timeout: 0.005s + +llm_providers: + - name: gpt-3.5 + access_key: $OPENAI_API_KEY + provider: openai + model: gpt-3.5-turbo + + - name: gpt-4o + access_key: $OPENAI_API_KEY + provider: openai + model: gpt-4o + + - name: ministral-8b + access_key: $MISTRAL_API_KEY + provider: mistral + model: ministral-8b-latest + + - name: ministral-3b + access_key: $MISTRAL_API_KEY + provider: mistral + model: ministral-3b-latest + +tracing: + random_sampling: 100 diff --git a/demos/llm_routing/docker-compose.yaml b/demos/llm_routing/docker-compose.yaml new file mode 100644 index 00000000..f8200977 --- /dev/null +++ b/demos/llm_routing/docker-compose.yaml @@ -0,0 +1,12 @@ +services: + + chatbot_ui: + build: + context: ../../chatbot_ui + dockerfile: Dockerfile + ports: + - "18080:8080" + environment: + - CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:12000/v1 + extra_hosts: + - "host.docker.internal:host-gateway"