From 93ea6e1a3d150cdd17fd3ee449a15f3d557668d9 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Wed, 16 Oct 2024 10:42:54 -0700 Subject: [PATCH] rename public_types => common and move common code there --- crates/{public_types => common}/Cargo.lock | 318 ++++++++++++- crates/{public_types => common}/Cargo.toml | 4 +- .../src/common_types.rs | 0 .../src/configuration.rs | 22 +- crates/{public_types => common}/src/consts.rs | 0 .../embeddings/create_embedding_request.rs | 0 .../create_embedding_request_input.rs | 0 .../embeddings/create_embedding_response.rs | 0 .../create_embedding_response_usage.rs | 0 .../src/embeddings/embedding.rs | 0 .../src/embeddings/mod.rs | 0 crates/{public_types => common}/src/http.rs | 0 crates/{public_types => common}/src/lib.rs | 6 +- .../src/llm_providers.rs | 2 +- .../src/ratelimit.rs | 7 +- .../{prompt_gateway => common}/src/routing.rs | 4 +- crates/{public_types => common}/src/stats.rs | 0 .../{llm_gateway => common}/src/tokenizer.rs | 0 crates/llm_gateway/Cargo.lock | 45 +- crates/llm_gateway/Cargo.toml | 3 +- crates/llm_gateway/src/filter_context.rs | 34 +- crates/llm_gateway/src/lib.rs | 4 - crates/llm_gateway/src/ratelimit.rs | 450 ------------------ crates/llm_gateway/src/routing.rs | 50 -- crates/llm_gateway/src/stream_context.rs | 34 +- crates/llm_gateway/tests/integration.rs | 18 +- crates/prompt_gateway/Cargo.lock | 45 +- crates/prompt_gateway/Cargo.toml | 3 +- crates/prompt_gateway/src/filter_context.rs | 34 +- crates/prompt_gateway/src/lib.rs | 4 - crates/prompt_gateway/src/llm_providers.rs | 69 --- crates/prompt_gateway/src/stream_context.rs | 34 +- crates/prompt_gateway/src/tokenizer.rs | 39 -- crates/prompt_gateway/tests/integration.rs | 16 +- gateway.code-workspace | 4 +- 35 files changed, 458 insertions(+), 791 deletions(-) rename crates/{public_types => common}/Cargo.lock (56%) rename crates/{public_types => common}/Cargo.toml (88%) rename crates/{public_types => common}/src/common_types.rs (100%) rename crates/{public_types => common}/src/configuration.rs (95%) rename crates/{public_types => common}/src/consts.rs (100%) rename crates/{public_types => common}/src/embeddings/create_embedding_request.rs (100%) rename crates/{public_types => common}/src/embeddings/create_embedding_request_input.rs (100%) rename crates/{public_types => common}/src/embeddings/create_embedding_response.rs (100%) rename crates/{public_types => common}/src/embeddings/create_embedding_response_usage.rs (100%) rename crates/{public_types => common}/src/embeddings/embedding.rs (100%) rename crates/{public_types => common}/src/embeddings/mod.rs (100%) rename crates/{public_types => common}/src/http.rs (100%) rename crates/{public_types => common}/src/lib.rs (63%) rename crates/{llm_gateway => common}/src/llm_providers.rs (97%) rename crates/{prompt_gateway => common}/src/ratelimit.rs (99%) rename crates/{prompt_gateway => common}/src/routing.rs (93%) rename crates/{public_types => common}/src/stats.rs (100%) rename crates/{llm_gateway => common}/src/tokenizer.rs (100%) delete mode 100644 crates/llm_gateway/src/ratelimit.rs delete mode 100644 crates/llm_gateway/src/routing.rs delete mode 100644 crates/prompt_gateway/src/llm_providers.rs delete mode 100644 crates/prompt_gateway/src/tokenizer.rs diff --git a/crates/public_types/Cargo.lock b/crates/common/Cargo.lock similarity index 56% rename from crates/public_types/Cargo.lock rename to crates/common/Cargo.lock index 7073cb20..8bdd2dec 100644 --- a/crates/public_types/Cargo.lock +++ b/crates/common/Cargo.lock @@ -20,24 +20,101 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + [[package]] name = "allocator-api2" version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +[[package]] +name = "anyhow" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" + [[package]] name = "autocfg" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + +[[package]] +name = "bstr" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +dependencies = [ + "memchr", + "regex-automata", + "serde", +] + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "common" +version = "0.1.0" +dependencies = [ + "derivative", + "duration-string", + "governor", + "log", + "pretty_assertions", + "proxy-wasm", + "rand", + "serde", + "serde_json", + "serde_yaml", + "thiserror", + "tiktoken-rs", +] + [[package]] name = "derivative" version = "2.2.0" @@ -70,6 +147,27 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "fancy-regex" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7493d4c459da9f84325ad297371a6b2b8a162800873a22e3b6b6512e61d18c05" +dependencies = [ + "bit-set", + "regex", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "governor" version = "0.6.3" @@ -120,6 +218,18 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.159" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" + [[package]] name = "lock_api" version = "0.4.12" @@ -163,12 +273,44 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + [[package]] name = "portable-atomic" version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + [[package]] name = "pretty_assertions" version = "1.4.1" @@ -198,22 +340,6 @@ dependencies = [ "log", ] -[[package]] -name = "public_types" -version = "0.1.0" -dependencies = [ - "derivative", - "duration-string", - "governor", - "log", - "pretty_assertions", - "proxy-wasm", - "serde", - "serde_json", - "serde_yaml", - "thiserror", -] - [[package]] name = "quote" version = "1.0.37" @@ -223,6 +349,80 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "redox_syscall" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "ryu" version = "1.0.18" @@ -337,6 +537,21 @@ dependencies = [ "syn 2.0.77", ] +[[package]] +name = "tiktoken-rs" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c314e7ce51440f9e8f5a497394682a57b7c323d0f4d0a6b1b13c429056e0e234" +dependencies = [ + "anyhow", + "base64", + "bstr", + "fancy-regex", + "lazy_static", + "parking_lot", + "rustc-hash", +] + [[package]] name = "unicode-ident" version = "1.0.13" @@ -355,6 +570,76 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "yansi" version = "1.0.1" @@ -367,6 +652,7 @@ version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ + "byteorder", "zerocopy-derive", ] diff --git a/crates/public_types/Cargo.toml b/crates/common/Cargo.toml similarity index 88% rename from crates/public_types/Cargo.toml rename to crates/common/Cargo.toml index d4251614..a362da9c 100644 --- a/crates/public_types/Cargo.toml +++ b/crates/common/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "public_types" +name = "common" version = "0.1.0" edition = "2021" @@ -12,6 +12,8 @@ governor = { version = "0.6.3", default-features = false, features = ["no_std"]} log = "0.4" derivative = "2.2.0" thiserror = "1.0.64" +tiktoken-rs = "0.5.9" +rand = "0.8.5" [dev-dependencies] pretty_assertions = "1.4.1" diff --git a/crates/public_types/src/common_types.rs b/crates/common/src/common_types.rs similarity index 100% rename from crates/public_types/src/common_types.rs rename to crates/common/src/common_types.rs diff --git a/crates/public_types/src/configuration.rs b/crates/common/src/configuration.rs similarity index 95% rename from crates/public_types/src/configuration.rs rename to crates/common/src/configuration.rs index c8bb72de..63ab156c 100644 --- a/crates/public_types/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -1,5 +1,6 @@ use duration_string::DurationString; use serde::{Deserialize, Deserializer, Serialize}; +use std::default; use std::fmt::Display; use std::{collections::HashMap, time::Duration}; @@ -13,20 +14,15 @@ pub struct Tracing { pub sampling_rate: Option, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)] pub enum GatewayMode { #[serde(rename = "llm")] Llm, + #[default] #[serde(rename = "prompt")] Prompt, } -impl Default for GatewayMode { - fn default() -> Self { - GatewayMode::Prompt - } -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Configuration { pub version: String, @@ -225,9 +221,10 @@ mod test { #[test] fn test_deserialize_configuration() { - let ref_config = - fs::read_to_string("../../docs/source/resources/includes/arch_config_full_reference.yaml") - .expect("reference config file not found"); + let ref_config = fs::read_to_string( + "../../docs/source/resources/includes/arch_config_full_reference.yaml", + ) + .expect("reference config file not found"); let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap(); assert_eq!(config.version, "v0.1"); @@ -299,10 +296,7 @@ mod test { let tracing = config.tracing.as_ref().unwrap(); assert_eq!(tracing.sampling_rate.unwrap(), 0.1); - let mode = config - .mode - .as_ref() - .unwrap_or(&super::GatewayMode::Prompt); + let mode = config.mode.as_ref().unwrap_or(&super::GatewayMode::Prompt); assert_eq!(*mode, super::GatewayMode::Prompt); } } diff --git a/crates/public_types/src/consts.rs b/crates/common/src/consts.rs similarity index 100% rename from crates/public_types/src/consts.rs rename to crates/common/src/consts.rs diff --git a/crates/public_types/src/embeddings/create_embedding_request.rs b/crates/common/src/embeddings/create_embedding_request.rs similarity index 100% rename from crates/public_types/src/embeddings/create_embedding_request.rs rename to crates/common/src/embeddings/create_embedding_request.rs diff --git a/crates/public_types/src/embeddings/create_embedding_request_input.rs b/crates/common/src/embeddings/create_embedding_request_input.rs similarity index 100% rename from crates/public_types/src/embeddings/create_embedding_request_input.rs rename to crates/common/src/embeddings/create_embedding_request_input.rs diff --git a/crates/public_types/src/embeddings/create_embedding_response.rs b/crates/common/src/embeddings/create_embedding_response.rs similarity index 100% rename from crates/public_types/src/embeddings/create_embedding_response.rs rename to crates/common/src/embeddings/create_embedding_response.rs diff --git a/crates/public_types/src/embeddings/create_embedding_response_usage.rs b/crates/common/src/embeddings/create_embedding_response_usage.rs similarity index 100% rename from crates/public_types/src/embeddings/create_embedding_response_usage.rs rename to crates/common/src/embeddings/create_embedding_response_usage.rs diff --git a/crates/public_types/src/embeddings/embedding.rs b/crates/common/src/embeddings/embedding.rs similarity index 100% rename from crates/public_types/src/embeddings/embedding.rs rename to crates/common/src/embeddings/embedding.rs diff --git a/crates/public_types/src/embeddings/mod.rs b/crates/common/src/embeddings/mod.rs similarity index 100% rename from crates/public_types/src/embeddings/mod.rs rename to crates/common/src/embeddings/mod.rs diff --git a/crates/public_types/src/http.rs b/crates/common/src/http.rs similarity index 100% rename from crates/public_types/src/http.rs rename to crates/common/src/http.rs diff --git a/crates/public_types/src/lib.rs b/crates/common/src/lib.rs similarity index 63% rename from crates/public_types/src/lib.rs rename to crates/common/src/lib.rs index a1d38925..27a51803 100644 --- a/crates/public_types/src/lib.rs +++ b/crates/common/src/lib.rs @@ -2,7 +2,11 @@ pub mod common_types; pub mod configuration; -pub mod embeddings; pub mod consts; +pub mod embeddings; pub mod http; +pub mod llm_providers; +pub mod ratelimit; +pub mod routing; pub mod stats; +pub mod tokenizer; diff --git a/crates/llm_gateway/src/llm_providers.rs b/crates/common/src/llm_providers.rs similarity index 97% rename from crates/llm_gateway/src/llm_providers.rs rename to crates/common/src/llm_providers.rs index 65cd0d04..8214f148 100644 --- a/crates/llm_gateway/src/llm_providers.rs +++ b/crates/common/src/llm_providers.rs @@ -1,4 +1,4 @@ -use public_types::configuration::LlmProvider; +use crate::configuration::LlmProvider; use std::collections::HashMap; use std::rc::Rc; diff --git a/crates/prompt_gateway/src/ratelimit.rs b/crates/common/src/ratelimit.rs similarity index 99% rename from crates/prompt_gateway/src/ratelimit.rs rename to crates/common/src/ratelimit.rs index 83a85e6c..66c3facd 100644 --- a/crates/prompt_gateway/src/ratelimit.rs +++ b/crates/common/src/ratelimit.rs @@ -1,7 +1,7 @@ +use crate::configuration; +use configuration::{Limit, Ratelimit, TimeUnit}; use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota}; use log::debug; -use public_types::configuration; -use public_types::configuration::{Limit, Ratelimit, TimeUnit}; use std::fmt::Display; use std::num::{NonZero, NonZeroU32}; use std::sync::RwLock; @@ -398,9 +398,10 @@ fn different_provider_can_have_different_limits_with_the_same_keys() { // If more tests are written here, move the initial call out of the test. #[cfg(test)] mod test { + use crate::configuration; + use super::ratelimits; use configuration::{Limit, Ratelimit, TimeUnit}; - use public_types::configuration; use std::num::NonZero; use std::thread; diff --git a/crates/prompt_gateway/src/routing.rs b/crates/common/src/routing.rs similarity index 93% rename from crates/prompt_gateway/src/routing.rs rename to crates/common/src/routing.rs index a372537e..1a440ee9 100644 --- a/crates/prompt_gateway/src/routing.rs +++ b/crates/common/src/routing.rs @@ -1,8 +1,8 @@ use std::rc::Rc; -use crate::llm_providers::LlmProviders; +use crate::{configuration, llm_providers::LlmProviders}; +use configuration::LlmProvider; use log::debug; -use public_types::configuration::LlmProvider; use rand::{seq::IteratorRandom, thread_rng}; #[derive(Debug)] diff --git a/crates/public_types/src/stats.rs b/crates/common/src/stats.rs similarity index 100% rename from crates/public_types/src/stats.rs rename to crates/common/src/stats.rs diff --git a/crates/llm_gateway/src/tokenizer.rs b/crates/common/src/tokenizer.rs similarity index 100% rename from crates/llm_gateway/src/tokenizer.rs rename to crates/common/src/tokenizer.rs diff --git a/crates/llm_gateway/Cargo.lock b/crates/llm_gateway/Cargo.lock index 3e9c51c9..35182863 100644 --- a/crates/llm_gateway/Cargo.lock +++ b/crates/llm_gateway/Cargo.lock @@ -217,6 +217,22 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15" +[[package]] +name = "common" +version = "0.1.0" +dependencies = [ + "derivative", + "duration-string", + "governor", + "log", + "proxy-wasm", + "rand", + "serde", + "serde_yaml", + "thiserror", + "tiktoken-rs", +] + [[package]] name = "cpp_demangle" version = "0.4.4" @@ -842,6 +858,7 @@ name = "llm_gateway" version = "0.1.0" dependencies = [ "acap", + "common", "derivative", "governor", "http", @@ -849,7 +866,6 @@ dependencies = [ "md5", "proxy-wasm", "proxy-wasm-test-framework", - "public_types", "rand", "serde", "serde_json", @@ -857,7 +873,6 @@ dependencies = [ "serial_test", "sha2", "thiserror", - "tiktoken-rs", ] [[package]] @@ -1094,20 +1109,6 @@ dependencies = [ "cc", ] -[[package]] -name = "public_types" -version = "0.1.0" -dependencies = [ - "derivative", - "duration-string", - "governor", - "log", - "proxy-wasm", - "serde", - "serde_yaml", - "thiserror", -] - [[package]] name = "quote" version = "1.0.37" @@ -1202,9 +1203,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.6" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" dependencies = [ "aho-corasick", "memchr", @@ -1214,9 +1215,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ "aho-corasick", "memchr", @@ -1225,9 +1226,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "rustc-demangle" diff --git a/crates/llm_gateway/Cargo.toml b/crates/llm_gateway/Cargo.toml index b752b888..73d62c3d 100644 --- a/crates/llm_gateway/Cargo.toml +++ b/crates/llm_gateway/Cargo.toml @@ -14,10 +14,9 @@ serde = { version = "1.0", features = ["derive"] } serde_yaml = "0.9.34" serde_json = "1.0" md5 = "0.7.0" -public_types = { path = "../public_types" } +common = { path = "../common" } http = "1.1.0" governor = { version = "0.6.3", default-features = false, features = ["no_std"]} -tiktoken-rs = "0.5.9" acap = "0.3.0" rand = "0.8.5" thiserror = "1.0.64" diff --git a/crates/llm_gateway/src/filter_context.rs b/crates/llm_gateway/src/filter_context.rs index 36c7cd57..5d0090a7 100644 --- a/crates/llm_gateway/src/filter_context.rs +++ b/crates/llm_gateway/src/filter_context.rs @@ -1,25 +1,23 @@ -use crate::llm_providers::LlmProviders; -use crate::ratelimit; use crate::stream_context::StreamContext; +use common::common_types::EmbeddingType; +use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget}; +use common::consts::ARCH_INTERNAL_CLUSTER_NAME; +use common::consts::ARCH_UPSTREAM_HOST_HEADER; +use common::consts::DEFAULT_EMBEDDING_MODEL; +use common::consts::MODEL_SERVER_NAME; +use common::embeddings::{ + CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, +}; +use common::http::CallArgs; +use common::http::Client; +use common::llm_providers::LlmProviders; +use common::ratelimit; +use common::stats::Counter; +use common::stats::Gauge; +use common::stats::IncrementingMetric; use log::debug; use proxy_wasm::traits::*; use proxy_wasm::types::*; -use public_types::common_types::EmbeddingType; -use public_types::configuration::{ - Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget, -}; -use public_types::consts::ARCH_INTERNAL_CLUSTER_NAME; -use public_types::consts::ARCH_UPSTREAM_HOST_HEADER; -use public_types::consts::DEFAULT_EMBEDDING_MODEL; -use public_types::consts::MODEL_SERVER_NAME; -use public_types::embeddings::{ - CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, -}; -use public_types::http::CallArgs; -use public_types::http::Client; -use public_types::stats::Counter; -use public_types::stats::Gauge; -use public_types::stats::IncrementingMetric; use std::cell::RefCell; use std::collections::hash_map::Entry; use std::collections::HashMap; diff --git a/crates/llm_gateway/src/lib.rs b/crates/llm_gateway/src/lib.rs index ae2f2545..e2ad9025 100644 --- a/crates/llm_gateway/src/lib.rs +++ b/crates/llm_gateway/src/lib.rs @@ -3,11 +3,7 @@ use proxy_wasm::traits::*; use proxy_wasm::types::*; mod filter_context; -mod llm_providers; -mod ratelimit; -mod routing; mod stream_context; -mod tokenizer; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); diff --git a/crates/llm_gateway/src/ratelimit.rs b/crates/llm_gateway/src/ratelimit.rs deleted file mode 100644 index 83a85e6c..00000000 --- a/crates/llm_gateway/src/ratelimit.rs +++ /dev/null @@ -1,450 +0,0 @@ -use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota}; -use log::debug; -use public_types::configuration; -use public_types::configuration::{Limit, Ratelimit, TimeUnit}; -use std::fmt::Display; -use std::num::{NonZero, NonZeroU32}; -use std::sync::RwLock; -use std::{collections::HashMap, sync::OnceLock}; - -pub type RatelimitData = RwLock; - -pub fn ratelimits(ratelimits_config: Option>) -> &'static RatelimitData { - static RATELIMIT_DATA: OnceLock = OnceLock::new(); - RATELIMIT_DATA.get_or_init(|| { - RwLock::new(RatelimitMap::new( - ratelimits_config.expect("The initialization call has to have passed a config"), - )) - }) -} - -// The Data Structure is laid out in the following way: -// Provider -> Hash { Header -> Limit }. -// If the Header used to configure the given Limit: -// a) Has None value, then there will be N Limit keyed by the Header value. -// b) Has Some() value, then there will be 1 Limit keyed by the empty string. -// It would have been nicer to use a non-keyed limit for b). However, the type system made that option a nightmare. -pub struct RatelimitMap { - datastore: HashMap>>, -} - -// This version of Header demands that the user passes a header value to match on. -#[derive(Debug, Clone)] -pub struct Header { - pub key: String, - pub value: String, -} - -impl Display for Header { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") - } -} - -impl From
for configuration::Header { - fn from(header: Header) -> Self { - Self { - key: header.key, - value: Some(header.value), - } - } -} - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("exceeded limit provider={provider}, selector={selector}, tokens_used={tokens_used}")] - ExceededLimit { - provider: String, - selector: Header, - tokens_used: NonZeroU32, - }, -} - -impl RatelimitMap { - // n.b new is private so that the only access to the Ratelimits can be done via the static - // reference inside a RwLock via ratelimit::ratelimits(). - fn new(ratelimits_config: Vec) -> Self { - let mut new_ratelimit_map = RatelimitMap { - datastore: HashMap::new(), - }; - for ratelimit_config in ratelimits_config { - let limit = DefaultKeyedRateLimiter::keyed(get_quota(ratelimit_config.limit)); - - match new_ratelimit_map.datastore.get_mut(&ratelimit_config.model) { - Some(limits) => match limits.get_mut(&ratelimit_config.selector) { - Some(_) => { - panic!("repeated selector. Selectors per provider must be unique") - } - None => { - limits.insert(ratelimit_config.selector, limit); - } - }, - None => { - // The provider has not been seen before. - // Insert the provider and a new HashMap with the specified limit - let new_hash_map = HashMap::from([(ratelimit_config.selector, limit)]); - new_ratelimit_map - .datastore - .insert(ratelimit_config.model, new_hash_map); - } - } - } - new_ratelimit_map - } - - #[allow(unused)] - pub fn check_limit( - &self, - provider: String, - selector: Header, - tokens_used: NonZeroU32, - ) -> Result<(), Error> { - debug!( - "Checking limit for provider={}, with selector={:?}, consuming tokens={:?}", - provider, selector, tokens_used - ); - - let provider_limits = match self.datastore.get(&provider) { - None => { - // No limit configured for this provider, hence ok. - return Ok(()); - } - Some(limit) => limit, - }; - - let mut config_selector = configuration::Header::from(selector.clone()); - - let (limit, limit_key) = match provider_limits.get(&config_selector) { - // This is a specific limit, i.e one that was configured with both key, and value. - // Therefore, the key for the internal limit does not matter, and hence the empty string is always returned. - Some(limit) => (limit, String::from("")), - None => { - // Unwrap is ok here because we _know_ the value exists. - let header_key = config_selector.value.take().unwrap(); - // Search for less specific limit, i.e, one that was configured without a value, therefore every Header - // value has its own key in the internal limit. - match provider_limits.get(&config_selector) { - Some(limit) => (limit, header_key), - // No limit for that header key, value pair exists within that provider limits. - None => { - return Ok(()); - } - } - } - }; - - match limit.check_key_n(&limit_key, tokens_used) { - Ok(Ok(())) => Ok(()), - Ok(Err(_)) | Err(InsufficientCapacity(_)) => Err(Error::ExceededLimit { - provider, - selector, - tokens_used, - }), - } - } -} - -fn get_quota(limit: Limit) -> Quota { - let tokens = NonZero::new(limit.tokens).expect("Limit's tokens must be positive"); - match limit.unit { - TimeUnit::Second => Quota::per_second(tokens), - TimeUnit::Minute => Quota::per_minute(tokens), - TimeUnit::Hour => Quota::per_hour(tokens), - } -} - -// The following tests are inside the ratelimit module in order to access RatelimitMap::new() in order to provide -// different configuration values per test. -#[test] -fn non_existent_provider_is_ok() { - let ratelimits_config = vec![Ratelimit { - model: String::from("provider"), - selector: configuration::Header { - key: String::from("only-key"), - value: None, - }, - limit: Limit { - tokens: 100, - unit: TimeUnit::Minute, - }, - }]; - - let ratelimits = RatelimitMap::new(ratelimits_config); - - assert!(ratelimits - .check_limit( - String::from("non-existent-provider"), - Header { - key: String::from("key"), - value: String::from("value"), - }, - NonZero::new(5000).unwrap(), - ) - .is_ok()) -} - -#[test] -fn non_existent_key_is_ok() { - let ratelimits_config = vec![Ratelimit { - model: String::from("provider"), - selector: configuration::Header { - key: String::from("only-key"), - value: None, - }, - limit: Limit { - tokens: 100, - unit: TimeUnit::Minute, - }, - }]; - - let ratelimits = RatelimitMap::new(ratelimits_config); - - assert!(ratelimits - .check_limit( - String::from("provider"), - Header { - key: String::from("key"), - value: String::from("value"), - }, - NonZero::new(5000).unwrap(), - ) - .is_ok()) -} - -#[test] -fn specific_limit_does_not_catch_non_specific_value() { - let ratelimits_config = vec![Ratelimit { - model: String::from("provider"), - selector: configuration::Header { - key: String::from("key"), - value: Some(String::from("value")), - }, - limit: Limit { - tokens: 200, - unit: TimeUnit::Second, - }, - }]; - - let ratelimits = RatelimitMap::new(ratelimits_config); - - assert!(ratelimits - .check_limit( - String::from("provider"), - Header { - key: String::from("key"), - value: String::from("not-the-correct-value"), - }, - NonZero::new(5000).unwrap(), - ) - .is_ok()) -} - -#[test] -fn specific_limit_is_hit() { - let ratelimits_config = vec![Ratelimit { - model: String::from("provider"), - selector: configuration::Header { - key: String::from("key"), - value: Some(String::from("value")), - }, - limit: Limit { - tokens: 200, - unit: TimeUnit::Hour, - }, - }]; - - let ratelimits = RatelimitMap::new(ratelimits_config); - - assert!(ratelimits - .check_limit( - String::from("provider"), - Header { - key: String::from("key"), - value: String::from("value"), - }, - NonZero::new(5000).unwrap(), - ) - .is_err()) -} - -#[test] -fn non_specific_key_has_different_limits_for_different_values() { - let ratelimits_config = vec![Ratelimit { - model: String::from("provider"), - selector: configuration::Header { - key: String::from("only-key"), - value: None, - }, - limit: Limit { - tokens: 100, - unit: TimeUnit::Hour, - }, - }]; - - let ratelimits = RatelimitMap::new(ratelimits_config); - - // Value1 takes 50. - assert!(ratelimits - .check_limit( - String::from("provider"), - Header { - key: String::from("only-key"), - value: String::from("value1"), - }, - NonZero::new(50).unwrap(), - ) - .is_ok()); - - // value2 takes 60 because it has its own 100 limit - assert!(ratelimits - .check_limit( - String::from("provider"), - Header { - key: String::from("only-key"), - value: String::from("value2"), - }, - NonZero::new(60).unwrap(), - ) - .is_ok()); - - // However value1 cannot take more than 100 per hour which 50+70 = 120 - assert!(ratelimits - .check_limit( - String::from("provider"), - Header { - key: String::from("only-key"), - value: String::from("value1"), - }, - NonZero::new(70).unwrap(), - ) - .is_err()) -} - -#[test] -fn different_provider_can_have_different_limits_with_the_same_keys() { - let ratelimits_config = vec![ - Ratelimit { - model: String::from("first_provider"), - selector: configuration::Header { - key: String::from("key"), - value: Some(String::from("value")), - }, - limit: Limit { - tokens: 100, - unit: TimeUnit::Hour, - }, - }, - Ratelimit { - model: String::from("second_provider"), - selector: configuration::Header { - key: String::from("key"), - value: Some(String::from("value")), - }, - limit: Limit { - tokens: 200, - unit: TimeUnit::Hour, - }, - }, - ]; - - let ratelimits = RatelimitMap::new(ratelimits_config); - - assert!(ratelimits - .check_limit( - String::from("first_provider"), - Header { - key: String::from("key"), - value: String::from("value"), - }, - NonZero::new(100).unwrap(), - ) - .is_ok()); - - assert!(ratelimits - .check_limit( - String::from("second_provider"), - Header { - key: String::from("key"), - value: String::from("value"), - }, - NonZero::new(200).unwrap(), - ) - .is_ok()); - - assert!(ratelimits - .check_limit( - String::from("first_provider"), - Header { - key: String::from("key"), - value: String::from("value"), - }, - NonZero::new(1).unwrap(), - ) - .is_err()); - - assert!(ratelimits - .check_limit( - String::from("second_provider"), - Header { - key: String::from("key"), - value: String::from("value"), - }, - NonZero::new(1).unwrap(), - ) - .is_err()); -} - -// These tests use the publicly exposed static singleton, thus the same configuration is used in every test. -// If more tests are written here, move the initial call out of the test. -#[cfg(test)] -mod test { - use super::ratelimits; - use configuration::{Limit, Ratelimit, TimeUnit}; - use public_types::configuration; - use std::num::NonZero; - use std::thread; - - #[test] - fn make_ratelimits_optional() { - let ratelimits_config = Vec::new(); - - // Initialize in the main thread. - ratelimits(Some(ratelimits_config)); - } - - #[test] - fn different_threads_have_same_ratelimit_data_structure() { - let ratelimits_config = Some(vec![Ratelimit { - model: String::from("provider"), - selector: configuration::Header { - key: String::from("key"), - value: Some(String::from("value")), - }, - limit: Limit { - tokens: 200, - unit: TimeUnit::Hour, - }, - }]); - - // Initialize in the main thread. - ratelimits(ratelimits_config); - - // Use the singleton in a different thread. - thread::spawn(|| { - let ratelimits = ratelimits(None); - - assert!(ratelimits - .read() - .unwrap() - .check_limit( - String::from("provider"), - super::Header { - key: String::from("key"), - value: String::from("value"), - }, - NonZero::new(5000).unwrap(), - ) - .is_err()) - }); - } -} diff --git a/crates/llm_gateway/src/routing.rs b/crates/llm_gateway/src/routing.rs deleted file mode 100644 index a372537e..00000000 --- a/crates/llm_gateway/src/routing.rs +++ /dev/null @@ -1,50 +0,0 @@ -use std::rc::Rc; - -use crate::llm_providers::LlmProviders; -use log::debug; -use public_types::configuration::LlmProvider; -use rand::{seq::IteratorRandom, thread_rng}; - -#[derive(Debug)] -pub enum ProviderHint { - Default, - Name(String), -} - -impl From for ProviderHint { - fn from(value: String) -> Self { - match value.as_str() { - "default" => ProviderHint::Default, - _ => ProviderHint::Name(value), - } - } -} - -pub fn get_llm_provider( - llm_providers: &LlmProviders, - provider_hint: Option, -) -> Rc { - let maybe_provider = provider_hint.and_then(|hint| match hint { - ProviderHint::Default => llm_providers.default(), - // FIXME: should a non-existent name in the hint be more explicit? i.e, return a BAD_REQUEST? - ProviderHint::Name(name) => llm_providers.get(&name), - }); - - if let Some(provider) = maybe_provider { - return provider; - } - - if llm_providers.default().is_some() { - debug!("no llm provider found for hint, using default llm provider"); - return llm_providers.default().unwrap(); - } - - debug!("no default llm found, using random llm provider"); - let mut rng = thread_rng(); - llm_providers - .iter() - .choose(&mut rng) - .expect("There should always be at least one llm provider") - .1 - .clone() -} diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 9d485dd7..5e4e6149 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,25 +1,18 @@ use crate::filter_context::{EmbeddingsStore, WasmMetrics}; -use crate::llm_providers::LlmProviders; -use crate::ratelimit::Header; -use crate::{ratelimit, routing, tokenizer}; use acap::cos; -use http::StatusCode; -use log::{debug, info, warn}; -use proxy_wasm::traits::*; -use proxy_wasm::types::*; -use public_types::common_types::open_ai::{ +use common::common_types::open_ai::{ ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice, FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType, }; -use public_types::common_types::{ +use common::common_types::{ EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse, PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest, ZeroShotClassificationResponse, }; -use public_types::configuration::{GatewayMode, LlmProvider}; -use public_types::configuration::{Overrides, PromptGuards, PromptTarget}; -use public_types::consts::{ +use common::configuration::{GatewayMode, LlmProvider}; +use common::configuration::{Overrides, PromptGuards, PromptTarget}; +use common::consts::{ ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER, @@ -27,11 +20,18 @@ use public_types::consts::{ DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE, }; -use public_types::embeddings::{ +use common::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; -use public_types::http::{CallArgs, Client, ClientError}; -use public_types::stats::Gauge; +use common::http::{CallArgs, Client, ClientError}; +use common::llm_providers::LlmProviders; +use common::ratelimit::Header; +use common::stats::Gauge; +use common::{ratelimit, routing, tokenizer}; +use http::StatusCode; +use log::{debug, info, warn}; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; use serde_json::Value; use sha2::{Digest, Sha256}; use std::cell::RefCell; @@ -40,7 +40,7 @@ use std::num::NonZero; use std::rc::Rc; use std::time::Duration; -use public_types::stats::IncrementingMetric; +use common::stats::IncrementingMetric; #[derive(Debug, Clone)] enum ResponseHandlerType { @@ -1280,7 +1280,7 @@ impl HttpContext for StreamContext { let prompt_guard_jailbreak_task = self .prompt_guards .input_guards - .contains_key(&public_types::configuration::GuardType::Jailbreak); + .contains_key(&common::configuration::GuardType::Jailbreak); self.chat_completions_request = Some(deserialized_body); diff --git a/crates/llm_gateway/tests/integration.rs b/crates/llm_gateway/tests/integration.rs index 5bc76c66..2e9e984e 100644 --- a/crates/llm_gateway/tests/integration.rs +++ b/crates/llm_gateway/tests/integration.rs @@ -1,23 +1,23 @@ +use common::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage}; +use common::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType}; +use common::common_types::{HallucinationClassificationResponse, PromptGuardResponse}; +use common::embeddings::{ + create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, + Embedding, +}; +use common::{common_types::ZeroShotClassificationResponse, configuration::Configuration}; use http::StatusCode; use proxy_wasm_test_framework::tester::{self, Tester}; use proxy_wasm_test_framework::types::{ Action, BufferType, LogLevel, MapType, MetricType, ReturnType, }; -use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage}; -use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType}; -use public_types::common_types::{HallucinationClassificationResponse, PromptGuardResponse}; -use public_types::embeddings::{ - create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, - Embedding, -}; -use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration}; use serde_yaml::Value; use serial_test::serial; use std::collections::HashMap; use std::path::Path; fn wasm_module() -> String { - let wasm_file = Path::new("target/wasm32-wasi/release/llm_gateway.wasm"); + let wasm_file = Path::new("target/wasm32-wasi/release/prompt_gateway.wasm"); assert!( wasm_file.exists(), "Run `cargo build --release --target=wasm32-wasi` first" diff --git a/crates/prompt_gateway/Cargo.lock b/crates/prompt_gateway/Cargo.lock index 3ed218b0..63de3b3f 100644 --- a/crates/prompt_gateway/Cargo.lock +++ b/crates/prompt_gateway/Cargo.lock @@ -217,6 +217,22 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15" +[[package]] +name = "common" +version = "0.1.0" +dependencies = [ + "derivative", + "duration-string", + "governor", + "log", + "proxy-wasm", + "rand", + "serde", + "serde_yaml", + "thiserror", + "tiktoken-rs", +] + [[package]] name = "cpp_demangle" version = "0.4.4" @@ -1043,6 +1059,7 @@ name = "prompt_gateway" version = "0.1.0" dependencies = [ "acap", + "common", "derivative", "governor", "http", @@ -1050,7 +1067,6 @@ dependencies = [ "md5", "proxy-wasm", "proxy-wasm-test-framework", - "public_types", "rand", "serde", "serde_json", @@ -1058,7 +1074,6 @@ dependencies = [ "serial_test", "sha2", "thiserror", - "tiktoken-rs", ] [[package]] @@ -1094,20 +1109,6 @@ dependencies = [ "cc", ] -[[package]] -name = "public_types" -version = "0.1.0" -dependencies = [ - "derivative", - "duration-string", - "governor", - "log", - "proxy-wasm", - "serde", - "serde_yaml", - "thiserror", -] - [[package]] name = "quote" version = "1.0.37" @@ -1202,9 +1203,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.6" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" dependencies = [ "aho-corasick", "memchr", @@ -1214,9 +1215,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ "aho-corasick", "memchr", @@ -1225,9 +1226,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "rustc-demangle" diff --git a/crates/prompt_gateway/Cargo.toml b/crates/prompt_gateway/Cargo.toml index 8d37387a..29d385b7 100644 --- a/crates/prompt_gateway/Cargo.toml +++ b/crates/prompt_gateway/Cargo.toml @@ -14,10 +14,9 @@ serde = { version = "1.0", features = ["derive"] } serde_yaml = "0.9.34" serde_json = "1.0" md5 = "0.7.0" -public_types = { path = "../public_types" } +common = { path = "../common" } http = "1.1.0" governor = { version = "0.6.3", default-features = false, features = ["no_std"]} -tiktoken-rs = "0.5.9" acap = "0.3.0" rand = "0.8.5" thiserror = "1.0.64" diff --git a/crates/prompt_gateway/src/filter_context.rs b/crates/prompt_gateway/src/filter_context.rs index 36c7cd57..5d0090a7 100644 --- a/crates/prompt_gateway/src/filter_context.rs +++ b/crates/prompt_gateway/src/filter_context.rs @@ -1,25 +1,23 @@ -use crate::llm_providers::LlmProviders; -use crate::ratelimit; use crate::stream_context::StreamContext; +use common::common_types::EmbeddingType; +use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget}; +use common::consts::ARCH_INTERNAL_CLUSTER_NAME; +use common::consts::ARCH_UPSTREAM_HOST_HEADER; +use common::consts::DEFAULT_EMBEDDING_MODEL; +use common::consts::MODEL_SERVER_NAME; +use common::embeddings::{ + CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, +}; +use common::http::CallArgs; +use common::http::Client; +use common::llm_providers::LlmProviders; +use common::ratelimit; +use common::stats::Counter; +use common::stats::Gauge; +use common::stats::IncrementingMetric; use log::debug; use proxy_wasm::traits::*; use proxy_wasm::types::*; -use public_types::common_types::EmbeddingType; -use public_types::configuration::{ - Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget, -}; -use public_types::consts::ARCH_INTERNAL_CLUSTER_NAME; -use public_types::consts::ARCH_UPSTREAM_HOST_HEADER; -use public_types::consts::DEFAULT_EMBEDDING_MODEL; -use public_types::consts::MODEL_SERVER_NAME; -use public_types::embeddings::{ - CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, -}; -use public_types::http::CallArgs; -use public_types::http::Client; -use public_types::stats::Counter; -use public_types::stats::Gauge; -use public_types::stats::IncrementingMetric; use std::cell::RefCell; use std::collections::hash_map::Entry; use std::collections::HashMap; diff --git a/crates/prompt_gateway/src/lib.rs b/crates/prompt_gateway/src/lib.rs index ae2f2545..e2ad9025 100644 --- a/crates/prompt_gateway/src/lib.rs +++ b/crates/prompt_gateway/src/lib.rs @@ -3,11 +3,7 @@ use proxy_wasm::traits::*; use proxy_wasm::types::*; mod filter_context; -mod llm_providers; -mod ratelimit; -mod routing; mod stream_context; -mod tokenizer; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); diff --git a/crates/prompt_gateway/src/llm_providers.rs b/crates/prompt_gateway/src/llm_providers.rs deleted file mode 100644 index 65cd0d04..00000000 --- a/crates/prompt_gateway/src/llm_providers.rs +++ /dev/null @@ -1,69 +0,0 @@ -use public_types::configuration::LlmProvider; -use std::collections::HashMap; -use std::rc::Rc; - -#[derive(Debug)] -pub struct LlmProviders { - providers: HashMap>, - default: Option>, -} - -impl LlmProviders { - pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Rc> { - self.providers.iter() - } - - pub fn default(&self) -> Option> { - self.default.as_ref().map(|rc| rc.clone()) - } - - pub fn get(&self, name: &str) -> Option> { - self.providers.get(name).cloned() - } -} - -#[derive(thiserror::Error, Debug)] -pub enum LlmProvidersNewError { - #[error("There must be at least one LLM Provider")] - EmptySource, - #[error("There must be at most one default LLM Provider")] - MoreThanOneDefault, - #[error("\'{0}\' is not a unique name")] - DuplicateName(String), -} - -impl TryFrom> for LlmProviders { - type Error = LlmProvidersNewError; - - fn try_from(llm_providers_config: Vec) -> Result { - if llm_providers_config.is_empty() { - return Err(LlmProvidersNewError::EmptySource); - } - - let mut llm_providers = LlmProviders { - providers: HashMap::new(), - default: None, - }; - - for llm_provider in llm_providers_config { - let llm_provider: Rc = Rc::new(llm_provider); - if llm_provider.default.unwrap_or_default() { - match llm_providers.default { - Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault), - None => llm_providers.default = Some(Rc::clone(&llm_provider)), - } - } - - // Insert and check that there is no other provider with the same name. - let name = llm_provider.name.clone(); - if llm_providers - .providers - .insert(name.clone(), llm_provider) - .is_some() - { - return Err(LlmProvidersNewError::DuplicateName(name)); - } - } - Ok(llm_providers) - } -} diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 9d485dd7..5e4e6149 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -1,25 +1,18 @@ use crate::filter_context::{EmbeddingsStore, WasmMetrics}; -use crate::llm_providers::LlmProviders; -use crate::ratelimit::Header; -use crate::{ratelimit, routing, tokenizer}; use acap::cos; -use http::StatusCode; -use log::{debug, info, warn}; -use proxy_wasm::traits::*; -use proxy_wasm::types::*; -use public_types::common_types::open_ai::{ +use common::common_types::open_ai::{ ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice, FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType, }; -use public_types::common_types::{ +use common::common_types::{ EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse, PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest, ZeroShotClassificationResponse, }; -use public_types::configuration::{GatewayMode, LlmProvider}; -use public_types::configuration::{Overrides, PromptGuards, PromptTarget}; -use public_types::consts::{ +use common::configuration::{GatewayMode, LlmProvider}; +use common::configuration::{Overrides, PromptGuards, PromptTarget}; +use common::consts::{ ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER, @@ -27,11 +20,18 @@ use public_types::consts::{ DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE, }; -use public_types::embeddings::{ +use common::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; -use public_types::http::{CallArgs, Client, ClientError}; -use public_types::stats::Gauge; +use common::http::{CallArgs, Client, ClientError}; +use common::llm_providers::LlmProviders; +use common::ratelimit::Header; +use common::stats::Gauge; +use common::{ratelimit, routing, tokenizer}; +use http::StatusCode; +use log::{debug, info, warn}; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; use serde_json::Value; use sha2::{Digest, Sha256}; use std::cell::RefCell; @@ -40,7 +40,7 @@ use std::num::NonZero; use std::rc::Rc; use std::time::Duration; -use public_types::stats::IncrementingMetric; +use common::stats::IncrementingMetric; #[derive(Debug, Clone)] enum ResponseHandlerType { @@ -1280,7 +1280,7 @@ impl HttpContext for StreamContext { let prompt_guard_jailbreak_task = self .prompt_guards .input_guards - .contains_key(&public_types::configuration::GuardType::Jailbreak); + .contains_key(&common::configuration::GuardType::Jailbreak); self.chat_completions_request = Some(deserialized_body); diff --git a/crates/prompt_gateway/src/tokenizer.rs b/crates/prompt_gateway/src/tokenizer.rs deleted file mode 100644 index 25ac924e..00000000 --- a/crates/prompt_gateway/src/tokenizer.rs +++ /dev/null @@ -1,39 +0,0 @@ -use log::debug; - -#[derive(Debug, PartialEq, Eq)] -#[allow(dead_code)] -pub enum Error { - UnknownModel, - FailedToTokenize, -} - -#[allow(dead_code)] -pub fn token_count(model_name: &str, text: &str) -> Result { - debug!("getting token count model={}", model_name); - // Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton? - let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel)?; - Ok(bpe.encode_ordinary(text).len()) -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn encode_ordinary() { - let model_name = "gpt-3.5-turbo"; - let text = "How many tokens does this sentence have?"; - assert_eq!( - 8, - token_count(model_name, text).expect("correct tokenization") - ); - } - - #[test] - fn unrecognized_model() { - assert_eq!( - Error::UnknownModel, - token_count("unknown", "").expect_err("unknown model") - ) - } -} diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index 49f8576b..2e9e984e 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -1,16 +1,16 @@ +use common::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage}; +use common::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType}; +use common::common_types::{HallucinationClassificationResponse, PromptGuardResponse}; +use common::embeddings::{ + create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, + Embedding, +}; +use common::{common_types::ZeroShotClassificationResponse, configuration::Configuration}; use http::StatusCode; use proxy_wasm_test_framework::tester::{self, Tester}; use proxy_wasm_test_framework::types::{ Action, BufferType, LogLevel, MapType, MetricType, ReturnType, }; -use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage}; -use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType}; -use public_types::common_types::{HallucinationClassificationResponse, PromptGuardResponse}; -use public_types::embeddings::{ - create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, - Embedding, -}; -use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration}; use serde_yaml::Value; use serial_test::serial; use std::collections::HashMap; diff --git a/gateway.code-workspace b/gateway.code-workspace index f78dca23..ed15406b 100644 --- a/gateway.code-workspace +++ b/gateway.code-workspace @@ -5,8 +5,8 @@ "path": "." }, { - "name": "public_types", - "path": "crates/public_types" + "name": "common", + "path": "crates/common" }, { "name": "prompt_gateway",