From 21a5bdd561d8c6640d38fb79a88d2c7d88ca1f5e Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Mon, 24 Mar 2025 16:53:36 -0700 Subject: [PATCH] omni llm --- arch/tools/cli/docker_cli.py | 2 +- crates/Cargo.lock | 56 +++++---- crates/Cargo.toml | 2 +- crates/common/src/api/open_ai.rs | 2 +- crates/llm_gateway/Cargo.toml | 1 + crates/llm_gateway/src/stream_context.rs | 61 +++++---- crates/omnillm/Cargo.toml | 9 ++ crates/omnillm/src/lib.rs | 153 +++++++++++++++++++++++ 8 files changed, 239 insertions(+), 47 deletions(-) create mode 100644 crates/omnillm/Cargo.toml create mode 100644 crates/omnillm/src/lib.rs diff --git a/arch/tools/cli/docker_cli.py b/arch/tools/cli/docker_cli.py index 6edfb8dc..0bc4ee54 100644 --- a/arch/tools/cli/docker_cli.py +++ b/arch/tools/cli/docker_cli.py @@ -56,7 +56,7 @@ def docker_start_archgw_detached( volume_mappings = [ f"{logs_path_abs}:/var/log:rw", f"{arch_config_file}:/app/arch_config.yaml:ro", - # "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro", + "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro", ] volume_mappings_args = [ item for volume in volume_mappings for item in ("-v", volume) diff --git a/crates/Cargo.lock b/crates/Cargo.lock index b585ef6e..ae66c599 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -64,9 +64,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.90" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37bf3594c4c988a53154954629820791dde498571819ae4ca50ca811e060cc95" +checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" [[package]] name = "arbitrary" @@ -82,7 +82,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.87", ] [[package]] @@ -487,7 +487,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.87", ] [[package]] @@ -919,7 +919,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.87", ] [[package]] @@ -1061,6 +1061,7 @@ dependencies = [ "http", "log", "md5", + "omnillm", "proxy-wasm", "proxy-wasm-test-framework", "rand", @@ -1160,6 +1161,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "omnillm" +version = "0.1.0" +dependencies = [ + "anyhow", + "serde", + "serde_json", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -1529,29 +1539,29 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.210" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.87", ] [[package]] name = "serde_json" -version = "1.0.130" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "610f75ff4a8e3cb29b85da56eabdd1bff5b06739059a4b8e2967fef32e5d9944" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ "itoa", "memchr", @@ -1603,7 +1613,7 @@ checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.87", ] [[package]] @@ -1711,9 +1721,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.79" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -1728,7 +1738,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.87", ] [[package]] @@ -1772,7 +1782,7 @@ checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.87", ] [[package]] @@ -2071,7 +2081,7 @@ dependencies = [ "anyhow", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.87", "wasmtime-component-util", "wasmtime-wit-bindgen", "wit-parser", @@ -2201,7 +2211,7 @@ checksum = "a2bde986038b819bc43a21fef0610aeb47aabfe3ea09ca3533a7b81023b84ec6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.87", ] [[package]] @@ -2450,7 +2460,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.87", "synstructure", ] @@ -2472,7 +2482,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.87", ] [[package]] @@ -2492,7 +2502,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.87", "synstructure", ] @@ -2515,7 +2525,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.87", ] [[package]] diff --git a/crates/Cargo.toml b/crates/Cargo.toml index 3ba99280..2a62c18d 100644 --- a/crates/Cargo.toml +++ b/crates/Cargo.toml @@ -1,3 +1,3 @@ [workspace] resolver = "2" -members = ["llm_gateway", "prompt_gateway", "common"] +members = ["llm_gateway", "prompt_gateway", "common", "omnillm"] diff --git a/crates/common/src/api/open_ai.rs b/crates/common/src/api/open_ai.rs index d71b0d58..e7b330bf 100644 --- a/crates/common/src/api/open_ai.rs +++ b/crates/common/src/api/open_ai.rs @@ -51,7 +51,7 @@ impl Serialize for FunctionParameters { where S: serde::Serializer, { - // select all requried parameters + // select all required parameters let required: Vec<&String> = self .properties .iter() diff --git a/crates/llm_gateway/Cargo.toml b/crates/llm_gateway/Cargo.toml index 73d62c3d..c7e4df38 100644 --- a/crates/llm_gateway/Cargo.toml +++ b/crates/llm_gateway/Cargo.toml @@ -22,6 +22,7 @@ rand = "0.8.5" thiserror = "1.0.64" derivative = "2.2.0" sha2 = "0.10.8" +omnillm = { version = "0.1.0", path = "../omnillm" } [dev-dependencies] proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 20ca9d62..c83cd9ea 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -16,12 +16,14 @@ use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; use http::StatusCode; use log::{debug, trace, warn}; +use omnillm::{ChatRequest, LlmProvider as LlmProviderTrait, LlmProviders as OmniLlmProviders}; use proxy_wasm::hostcalls::get_current_time; use proxy_wasm::traits::*; use proxy_wasm::types::*; use std::collections::VecDeque; use std::num::NonZero; use std::rc::Rc; +use std::str::FromStr; use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; @@ -40,9 +42,10 @@ pub struct StreamContext { ttft_time: Option, traceparent: Option, request_body_sent_time: Option, - user_message: Option, + user_message: Option, traces_queue: Arc>>, overrides: Rc>, + omni_llm_providers: Rc, } impl StreamContext { @@ -71,6 +74,7 @@ impl StreamContext { user_message: None, traces_queue, request_body_sent_time: None, + omni_llm_providers: Rc::new(OmniLlmProviders::new()), } } fn llm_provider(&self) -> &LlmProvider { @@ -271,18 +275,17 @@ impl HttpContext for StreamContext { // Deserialize body into spec. // Currently OpenAI API. - let mut deserialized_body: ChatCompletionsRequest = - match serde_json::from_slice(&body_bytes) { - Ok(deserialized) => deserialized, - Err(e) => { - debug!("body str: {}", String::from_utf8_lossy(&body_bytes)); - self.send_server_error( - ServerError::Deserialization(e), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Pause; - } - }; + let mut deserialized_body: ChatRequest = match serde_json::from_slice(&body_bytes) { + Ok(deserialized) => deserialized, + Err(e) => { + debug!("body str: {}", String::from_utf8_lossy(&body_bytes)); + self.send_server_error( + ServerError::Deserialization(e), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + }; // remove metadata from the request body //TODO: move this to prompt gateway @@ -295,7 +298,7 @@ impl HttpContext for StreamContext { self.user_message = deserialized_body .messages .iter() - .filter(|m| m.role == "user") + .filter(|m| m.role == omnillm::Role::User) .last() .cloned(); @@ -336,16 +339,32 @@ impl HttpContext for StreamContext { model_name, ); - let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap(); + let omn_provider_type = + omnillm::Provider::from_str(&self.llm_provider().provider_interface.to_string()) + .unwrap(); - trace!("request body: {}", chat_completion_request_str); + let chat_request_bytes = self + .omni_llm_providers + .as_ref() + .providers + .get(&omn_provider_type) + .unwrap() + .translate_request(&deserialized_body) + .unwrap(); - if deserialized_body.stream { + trace!( + "request body str: {}", + String::from_utf8_lossy(&chat_request_bytes) + ); + + if deserialized_body.stream.unwrap_or_default() { self.streaming_response = true; } - if deserialized_body.stream && deserialized_body.stream_options.is_none() { - deserialized_body.stream_options = Some(StreamOptions { - include_usage: true, + if deserialized_body.stream.unwrap_or_default() + && deserialized_body.stream_options.is_none() + { + deserialized_body.stream_options = Some(omnillm::StreamOptions { + include_usage: Some(true), }); } @@ -367,7 +386,7 @@ impl HttpContext for StreamContext { return Action::Continue; } - self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes()); + self.set_http_request_body(0, body_size, &chat_request_bytes); Action::Continue } diff --git a/crates/omnillm/Cargo.toml b/crates/omnillm/Cargo.toml new file mode 100644 index 00000000..edb2aeca --- /dev/null +++ b/crates/omnillm/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "omnillm" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = "1.0.97" +serde = { version = "1.0.219", features = ["derive"] } +serde_json = "1.0.140" diff --git a/crates/omnillm/src/lib.rs b/crates/omnillm/src/lib.rs new file mode 100644 index 00000000..d04bd3da --- /dev/null +++ b/crates/omnillm/src/lib.rs @@ -0,0 +1,153 @@ +use std::{collections::HashMap, str::FromStr}; + +use anyhow::Result; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Hash, PartialEq, Eq)] +pub enum Provider { + OpenAI, + Mistral, +} + +impl FromStr for Provider { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "openai" => Ok(Provider::OpenAI), + "mistral" => Ok(Provider::Mistral), + _ => Err(anyhow::anyhow!(format!("Invalid provider: {}", s))), + } + } +} + +#[derive(Serialize, Deserialize)] +pub struct ChatRequest { + pub model: String, + pub messages: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, +} + +#[derive(Serialize, Deserialize)] +pub struct StreamOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub include_usage: Option, +} + +#[derive(Serialize, Deserialize, PartialEq, Clone)] +pub enum Role { + #[serde(rename = "system")] + System, + #[serde(rename = "user")] + User, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct Message { + pub role: Role, + pub content: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, +} + +#[derive(Serialize, Deserialize)] +pub struct Usage { + pub prompt_tokens: i64, + pub completion_tokens: i64, + pub total: i64, +} + +#[derive(Serialize, Deserialize)] +pub struct ChatResponse { + pub messages: Vec, + pub usage: Option, +} + +pub trait LlmProvider { + fn translate_request(&self, request: &ChatRequest) -> Result>; + fn translate_response(&self, response: &Vec) -> Result; +} + +pub struct LlmProviders { + pub providers: HashMap>, +} + +impl LlmProviders { + pub fn new() -> LlmProviders { + LlmProviders { + providers: HashMap::from([ + (Provider::OpenAI, Box::new(OpenAI) as Box), + (Provider::Mistral, Box::new(Mistral) as Box), + ]), + } + } +} + +pub struct OpenAI; +impl LlmProvider for OpenAI { + fn translate_request(&self, request: &ChatRequest) -> Result> { + serde_json::to_string(request) + .map(|s| s.into_bytes()) + .map_err(Into::into) + } + + fn translate_response(&self, response: &Vec) -> Result { + serde_json::from_slice(response).map_err(Into::into) + } +} + +pub struct Mistral; +impl LlmProvider for Mistral { + fn translate_request(&self, request: &ChatRequest) -> Result> { + serde_json::to_string(request) + .map(|s| s.into_bytes()) + .map_err(Into::into) + } + + fn translate_response(&self, response: &Vec) -> Result { + serde_json::from_slice(response).map_err(Into::into) + } +} + +#[cfg(test)] +mod test { + #[test] + fn test_translate_request_openai() { + use super::{ChatRequest, LlmProvider, Message, OpenAI, Role}; + let request = ChatRequest { + model: "gpt-3.5-turbo".to_string(), + messages: vec![ + Message { + role: Role::System, + content: Some("You are a helpful assistant.".to_string()), + model: None, + }, + Message { + role: Role::User, + content: Some("I need help with my computer.".to_string()), + model: None, + }, + ], + temperature: None, + max_tokens: None, + stream: None, + stream_options: None, + }; + let openai = OpenAI; + let result = openai.translate_request(&request).unwrap(); + let expected = r#"{"model":"gpt-3.5-turbo","messages":[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"I need help with my computer."}]}"#; + assert_eq!(String::from_utf8(result).unwrap(), expected); + } +}