rename public_types => common and move common code there

This commit is contained in:
Adil Hafeez 2024-10-16 10:42:54 -07:00
parent db202578d3
commit 93ea6e1a3d
35 changed files with 458 additions and 791 deletions

668
crates/common/Cargo.lock generated Normal file
View file

@ -0,0 +1,668 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "ahash"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8fd72866655d1904d6b0997d0b07ba561047d070fbe29de039031c641b61217"
[[package]]
name = "ahash"
version = "0.8.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
dependencies = [
"cfg-if",
"once_cell",
"version_check",
"zerocopy",
]
[[package]]
name = "aho-corasick"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
dependencies = [
"memchr",
]
[[package]]
name = "allocator-api2"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f"
[[package]]
name = "anyhow"
version = "1.0.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6"
[[package]]
name = "autocfg"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "base64"
version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]]
name = "bit-set"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
dependencies = [
"bit-vec",
]
[[package]]
name = "bit-vec"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
[[package]]
name = "bitflags"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]]
name = "bstr"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c"
dependencies = [
"memchr",
"regex-automata",
"serde",
]
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "common"
version = "0.1.0"
dependencies = [
"derivative",
"duration-string",
"governor",
"log",
"pretty_assertions",
"proxy-wasm",
"rand",
"serde",
"serde_json",
"serde_yaml",
"thiserror",
"tiktoken-rs",
]
[[package]]
name = "derivative"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "diff"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"
[[package]]
name = "duration-string"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fcc1d9ae294a15ed05aeae8e11ee5f2b3fe971c077d45a42fb20825fba6ee13"
dependencies = [
"serde",
]
[[package]]
name = "equivalent"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "fancy-regex"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7493d4c459da9f84325ad297371a6b2b8a162800873a22e3b6b6512e61d18c05"
dependencies = [
"bit-set",
"regex",
]
[[package]]
name = "getrandom"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "governor"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b"
dependencies = [
"cfg-if",
"no-std-compat",
"nonzero_ext",
"portable-atomic",
"smallvec",
"spinning_top",
]
[[package]]
name = "hashbrown"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91b62f79061a0bc2e046024cb7ba44b08419ed238ecbd9adbd787434b9e8c25"
dependencies = [
"ahash 0.3.8",
"autocfg",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash 0.8.11",
"allocator-api2",
]
[[package]]
name = "indexmap"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5"
dependencies = [
"equivalent",
"hashbrown 0.14.5",
]
[[package]]
name = "itoa"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "libc"
version = "0.2.159"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5"
[[package]]
name = "lock_api"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
dependencies = [
"autocfg",
"scopeguard",
]
[[package]]
name = "log"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
[[package]]
name = "memchr"
version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
dependencies = [
"hashbrown 0.8.2",
]
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]]
name = "once_cell"
version = "1.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
[[package]]
name = "parking_lot"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
dependencies = [
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.9.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"smallvec",
"windows-targets",
]
[[package]]
name = "portable-atomic"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2"
[[package]]
name = "ppv-lite86"
version = "0.2.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04"
dependencies = [
"zerocopy",
]
[[package]]
name = "pretty_assertions"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d"
dependencies = [
"diff",
"yansi",
]
[[package]]
name = "proc-macro2"
version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
dependencies = [
"unicode-ident",
]
[[package]]
name = "proxy-wasm"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14a5a4df5a1ab77235e36a0a0f638687ee1586d21ee9774037693001e94d4e11"
dependencies = [
"hashbrown 0.14.5",
"log",
]
[[package]]
name = "quote"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]]
name = "redox_syscall"
version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f"
dependencies = [
"bitflags",
]
[[package]]
name = "regex"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "ryu"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "serde"
version = "1.0.210"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.210"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
]
[[package]]
name = "serde_json"
version = "1.0.128"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]]
name = "serde_yaml"
version = "0.9.34+deprecated"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
dependencies = [
"indexmap",
"itoa",
"ryu",
"serde",
"unsafe-libyaml",
]
[[package]]
name = "smallvec"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "spinning_top"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
dependencies = [
"lock_api",
]
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "syn"
version = "2.0.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "thiserror"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
]
[[package]]
name = "tiktoken-rs"
version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c314e7ce51440f9e8f5a497394682a57b7c323d0f4d0a6b1b13c429056e0e234"
dependencies = [
"anyhow",
"base64",
"bstr",
"fancy-regex",
"lazy_static",
"parking_lot",
"rustc-hash",
]
[[package]]
name = "unicode-ident"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
[[package]]
name = "version_check"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "windows-targets"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_gnullvm",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "yansi"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
[[package]]
name = "zerocopy"
version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
dependencies = [
"byteorder",
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
]

20
crates/common/Cargo.toml Normal file
View file

@ -0,0 +1,20 @@
[package]
name = "common"
version = "0.1.0"
edition = "2021"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9.34"
duration-string = { version = "0.3.0", features = ["serde"] }
proxy-wasm = "0.2.1"
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
log = "0.4"
derivative = "2.2.0"
thiserror = "1.0.64"
tiktoken-rs = "0.5.9"
rand = "0.8.5"
[dev-dependencies]
pretty_assertions = "1.4.1"
serde_json = "1.0.64"

View file

@ -0,0 +1,448 @@
use crate::configuration::PromptTarget;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
pub prompt_target: PromptTarget,
}
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum EmbeddingType {
Name,
Description,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPoint {
pub id: String,
pub payload: HashMap<String, String>,
pub vector: Vec<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoreVectorEmbeddingsRequest {
pub points: Vec<VectorPoint>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointResult {
pub id: String,
pub version: i32,
pub score: f64,
pub payload: HashMap<String, String>,
}
pub mod open_ai {
use std::collections::HashMap;
use serde::{ser::SerializeMap, Deserialize, Serialize};
use serde_yaml::Value;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsRequest {
#[serde(default)]
pub model: String,
pub messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ChatCompletionTool>>,
#[serde(default)]
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolType {
#[serde(rename = "function")]
Function,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionTool {
#[serde(rename = "type")]
pub tool_type: ToolType,
pub function: FunctionDefinition,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
pub name: String,
pub description: String,
pub parameters: FunctionParameters,
}
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionParameters {
pub properties: HashMap<String, FunctionParameter>,
}
impl Serialize for FunctionParameters {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
// select all requried parameters
let required: Vec<&String> = self
.properties
.iter()
.filter(|(_, v)| v.required.unwrap_or(false))
.map(|(k, _)| k)
.collect();
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("properties", &self.properties)?;
if !required.is_empty() {
map.serialize_entry("required", &required)?;
}
map.end()
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionParameter {
#[serde(rename = "type")]
#[serde(default = "ParameterType::string")]
pub parameter_type: ParameterType,
pub description: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "enum")]
pub enum_values: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<String>,
}
impl Serialize for FunctionParameter {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut map = serializer.serialize_map(Some(5))?;
map.serialize_entry("type", &self.parameter_type)?;
map.serialize_entry("description", &self.description)?;
if let Some(enum_values) = &self.enum_values {
map.serialize_entry("enum", enum_values)?;
}
if let Some(default) = &self.default {
map.serialize_entry("default", default)?;
}
map.end()
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ParameterType {
#[serde(rename = "int")]
Int,
#[serde(rename = "float")]
Float,
#[serde(rename = "bool")]
Bool,
#[serde(rename = "str")]
String,
#[serde(rename = "list")]
List,
#[serde(rename = "dict")]
Dict,
}
impl From<String> for ParameterType {
fn from(s: String) -> Self {
match s.as_str() {
"int" => ParameterType::Int,
"integer" => ParameterType::Int,
"float" => ParameterType::Float,
"bool" => ParameterType::Bool,
"boolean" => ParameterType::Bool,
"str" => ParameterType::String,
"string" => ParameterType::String,
"list" => ParameterType::List,
"array" => ParameterType::List,
"dict" => ParameterType::Dict,
"dictionary" => ParameterType::Dict,
_ => ParameterType::String,
}
}
}
impl ParameterType {
pub fn string() -> ParameterType {
ParameterType::String
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamOptions {
pub include_usage: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
pub finish_reason: String,
pub index: usize,
pub message: Message,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub tool_type: ToolType,
pub function: FunctionCallDetail,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDetail {
pub name: String,
pub arguments: HashMap<String, Value>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolCallState {
pub key: String,
pub message: Option<Message>,
pub tool_call: FunctionCallDetail,
pub tool_response: String,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ArchState {
ToolCall(Vec<ToolCallState>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsResponse {
pub usage: Option<Usage>,
pub choices: Vec<Choice>,
pub model: String,
pub metadata: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub completion_tokens: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionChunkResponse {
pub model: String,
pub choices: Vec<ChunkChoice>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkChoice {
pub delta: Delta,
// TODO: could this be an enum?
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
pub content: Option<String>,
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotClassificationRequest {
pub input: String,
pub labels: Vec<String>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotClassificationResponse {
pub predicted_class: String,
pub predicted_class_score: f64,
pub scores: HashMap<String, f64>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HallucinationClassificationRequest {
pub prompt: String,
pub parameters: HashMap<String, String>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HallucinationClassificationResponse {
pub params_scores: HashMap<String, f64>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PromptGuardTask {
#[serde(rename = "jailbreak")]
Jailbreak,
#[serde(rename = "toxicity")]
Toxicity,
#[serde(rename = "both")]
Both,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptGuardRequest {
pub input: String,
pub task: PromptGuardTask,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptGuardResponse {
pub toxic_prob: Option<f64>,
pub jailbreak_prob: Option<f64>,
pub toxic_verdict: Option<bool>,
pub jailbreak_verdict: Option<bool>,
}
#[cfg(test)]
mod test {
use crate::common_types::open_ai::Message;
use pretty_assertions::{assert_eq, assert_ne};
use std::collections::HashMap;
const TOOL_SERIALIZED: &str = r#"{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "What city do you want to know the weather for?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "weather_forecast",
"description": "function to retrieve weather forecast",
"parameters": {
"properties": {
"city": {
"type": "str",
"description": "city for weather forecast",
"default": "test"
}
},
"required": [
"city"
]
}
}
}
],
"stream": true,
"stream_options": {
"include_usage": true
}
}"#;
#[test]
fn test_tool_type_request() {
use super::open_ai::{
ChatCompletionsRequest, FunctionDefinition, FunctionParameter, ParameterType, ToolType,
};
let mut properties = HashMap::new();
properties.insert(
"city".to_string(),
FunctionParameter {
parameter_type: ParameterType::String,
description: "city for weather forecast".to_string(),
required: Some(true),
enum_values: None,
default: Some("test".to_string()),
},
);
let function_definition = FunctionDefinition {
name: "weather_forecast".to_string(),
description: "function to retrieve weather forecast".to_string(),
parameters: super::open_ai::FunctionParameters { properties },
};
let chat_completions_request = ChatCompletionsRequest {
model: "gpt-3.5-turbo".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: Some("What city do you want to know the weather for?".to_string()),
model: None,
tool_calls: None,
}],
tools: Some(vec![super::open_ai::ChatCompletionTool {
tool_type: ToolType::Function,
function: function_definition,
}]),
stream: true,
stream_options: Some(super::open_ai::StreamOptions {
include_usage: true,
}),
metadata: None,
};
let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
println!("{}", serialized);
assert_eq!(TOOL_SERIALIZED, serialized);
}
#[test]
fn test_parameter_types() {
use super::open_ai::{
ChatCompletionsRequest, FunctionDefinition, FunctionParameter, ParameterType, ToolType,
};
const PARAMETER_SERIALZIED: &str = r#"{
"city": {
"type": "str",
"description": "city for weather forecast",
"default": "test"
}
}"#;
let properties = HashMap::from([(
"city".to_string(),
FunctionParameter {
parameter_type: ParameterType::String,
description: "city for weather forecast".to_string(),
required: Some(true),
enum_values: None,
default: Some("test".to_string()),
},
)]);
let serialized = serde_json::to_string_pretty(&properties).unwrap();
assert_eq!(PARAMETER_SERIALZIED, serialized);
// ensure that if type is missing it is set to string
const PARAMETER_SERIALZIED_MISSING_TYPE: &str = r#"
{
"city": {
"description": "city for weather forecast"
}
}"#;
let missing_type_deserialized: HashMap<String, FunctionParameter> =
serde_json::from_str(PARAMETER_SERIALZIED_MISSING_TYPE).unwrap();
println!("{:?}", missing_type_deserialized);
assert_eq!(
missing_type_deserialized
.get("city")
.unwrap()
.parameter_type,
ParameterType::String
);
}
}

View file

@ -0,0 +1,302 @@
use duration_string::DurationString;
use serde::{Deserialize, Deserializer, Serialize};
use std::default;
use std::fmt::Display;
use std::{collections::HashMap, time::Duration};
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Overrides {
pub prompt_target_intent_matching_threshold: Option<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Tracing {
pub sampling_rate: Option<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub enum GatewayMode {
#[serde(rename = "llm")]
Llm,
#[default]
#[serde(rename = "prompt")]
Prompt,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
pub listener: Listener,
pub endpoints: HashMap<String, Endpoint>,
pub llm_providers: Vec<LlmProvider>,
pub overrides: Option<Overrides>,
pub system_prompt: Option<String>,
pub prompt_guards: Option<PromptGuards>,
pub prompt_targets: Vec<PromptTarget>,
pub error_target: Option<ErrorTargetDetail>,
pub ratelimits: Option<Vec<Ratelimit>>,
pub tracing: Option<Tracing>,
pub mode: Option<GatewayMode>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorTargetDetail {
pub endpoint: Option<EndpointDetails>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Listener {
pub address: String,
pub port: u16,
pub message_format: MessageFormat,
// pub connect_timeout: Option<DurationString>,
}
impl Default for Listener {
fn default() -> Self {
Listener {
address: "".to_string(),
port: 0,
message_format: MessageFormat::default(),
// connect_timeout: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum MessageFormat {
#[serde(rename = "huggingface")]
#[default]
Huggingface,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PromptGuards {
pub input_guards: HashMap<GuardType, GuardOptions>,
}
impl PromptGuards {
pub fn jailbreak_on_exception_message(&self) -> Option<&str> {
self.input_guards
.get(&GuardType::Jailbreak)?
.on_exception
.as_ref()?
.message
.as_ref()?
.as_str()
.into()
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum GuardType {
#[serde(rename = "jailbreak")]
Jailbreak,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GuardOptions {
pub on_exception: Option<OnExceptionDetails>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnExceptionDetails {
pub forward_to_error_target: Option<bool>,
pub error_handler: Option<String>,
pub message: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRatelimit {
pub selector: LlmRatelimitSelector,
pub limit: Limit,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRatelimitSelector {
pub http_header: Option<RatelimitHeader>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Header {
pub key: String,
pub value: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ratelimit {
pub model: String,
pub selector: Header,
pub limit: Limit,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Limit {
pub tokens: u32,
pub unit: TimeUnit,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TimeUnit {
#[serde(rename = "second")]
Second,
#[serde(rename = "minute")]
Minute,
#[serde(rename = "hour")]
Hour,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct RatelimitHeader {
pub name: String,
pub value: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
//TODO: use enum for model, but if there is a new model, we need to update the code
pub struct EmbeddingProviver {
pub name: String,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
//TODO: use enum for model, but if there is a new model, we need to update the code
pub struct LlmProvider {
pub name: String,
pub provider: String,
pub access_key: Option<String>,
pub model: String,
pub default: Option<bool>,
pub stream: Option<bool>,
pub rate_limits: Option<LlmRatelimit>,
}
impl Display for LlmProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.name)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
pub endpoint: Option<String>,
// pub connect_timeout: Option<DurationString>,
// pub timeout: Option<DurationString>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Parameter {
pub name: String,
#[serde(rename = "type")]
pub parameter_type: Option<String>,
pub description: String,
pub required: Option<bool>,
#[serde(rename = "enum")]
pub enum_values: Option<Vec<String>>,
pub default: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EndpointDetails {
pub name: String,
pub path: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTarget {
pub name: String,
pub default: Option<bool>,
pub description: String,
pub endpoint: Option<EndpointDetails>,
pub parameters: Option<Vec<Parameter>>,
pub system_prompt: Option<String>,
pub auto_llm_dispatch_on_response: Option<bool>,
}
#[cfg(test)]
mod test {
use std::fs;
use crate::configuration::GuardType;
#[test]
fn test_deserialize_configuration() {
let ref_config = fs::read_to_string(
"../../docs/source/resources/includes/arch_config_full_reference.yaml",
)
.expect("reference config file not found");
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
assert_eq!(config.version, "v0.1");
let open_ai_provider = config
.llm_providers
.iter()
.find(|p| p.name.to_lowercase() == "openai")
.unwrap();
assert_eq!(open_ai_provider.name.to_lowercase(), "openai");
assert_eq!(
open_ai_provider.access_key,
Some("OPENAI_API_KEY".to_string())
);
assert_eq!(open_ai_provider.model, "gpt-4o");
assert_eq!(open_ai_provider.default, Some(true));
assert_eq!(open_ai_provider.stream, Some(true));
let prompt_guards = config.prompt_guards.as_ref().unwrap();
let input_guards = &prompt_guards.input_guards;
let jailbreak_guard = input_guards.get(&GuardType::Jailbreak).unwrap();
assert_eq!(
jailbreak_guard
.on_exception
.as_ref()
.unwrap()
.forward_to_error_target,
None
);
assert_eq!(
jailbreak_guard.on_exception.as_ref().unwrap().error_handler,
None
);
let prompt_targets = &config.prompt_targets;
assert_eq!(prompt_targets.len(), 2);
let prompt_target = prompt_targets
.iter()
.find(|p| p.name == "reboot_network_device")
.unwrap();
assert_eq!(prompt_target.name, "reboot_network_device");
assert_eq!(prompt_target.default, None);
let prompt_target = prompt_targets
.iter()
.find(|p| p.name == "information_extraction")
.unwrap();
assert_eq!(prompt_target.name, "information_extraction");
assert_eq!(prompt_target.default, Some(true));
assert_eq!(
prompt_target.endpoint.as_ref().unwrap().name,
"app_server".to_string()
);
assert_eq!(
prompt_target.endpoint.as_ref().unwrap().path,
Some("/agent/summary".to_string())
);
let error_target = config.error_target.as_ref().unwrap();
assert_eq!(
error_target.endpoint.as_ref().unwrap().name,
"error_target_1".to_string()
);
assert_eq!(
error_target.endpoint.as_ref().unwrap().path,
Some("/error".to_string())
);
let tracing = config.tracing.as_ref().unwrap();
assert_eq!(tracing.sampling_rate.unwrap(), 0.1);
let mode = config.mode.as_ref().unwrap_or(&super::GatewayMode::Prompt);
assert_eq!(*mode, super::GatewayMode::Prompt);
}
}

View file

@ -0,0 +1,22 @@
pub const DEFAULT_EMBEDDING_MODEL: &str = "katanemo/bge-large-en-v1.5";
pub const DEFAULT_INTENT_MODEL: &str = "katanemo/bart-large-mnli";
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25;
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const ARC_FC_CLUSTER: &str = "arch_fc";
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const MODEL_SERVER_NAME: &str = "model_server";
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
pub const ARCH_STATE_HEADER: &str = "x-arch-state";
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B";
pub const REQUEST_ID_HEADER: &str = "x-request-id";
pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal";
pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream";
pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";
pub const ARCH_MODEL_PREFIX: &str = "Arch";

View file

@ -0,0 +1,59 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use crate::embeddings;
use serde::{Deserialize, Serialize};
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct CreateEmbeddingRequest {
#[serde(rename = "input")]
pub input: Box<embeddings::CreateEmbeddingRequestInput>,
/// ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.
#[serde(rename = "model")]
pub model: String,
/// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).
#[serde(rename = "encoding_format", skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<EncodingFormat>,
/// The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.
#[serde(rename = "dimensions", skip_serializing_if = "Option::is_none")]
pub dimensions: Option<i32>,
/// A unique identifier representing your end-user, which can help to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).
#[serde(rename = "user", skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}
impl CreateEmbeddingRequest {
pub fn new(
input: embeddings::CreateEmbeddingRequestInput,
model: String,
) -> CreateEmbeddingRequest {
CreateEmbeddingRequest {
input: Box::new(input),
model,
encoding_format: None,
dimensions: None,
user: None,
}
}
}
/// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub enum EncodingFormat {
#[serde(rename = "float")]
Float,
#[serde(rename = "base64")]
Base64,
}
impl Default for EncodingFormat {
fn default() -> EncodingFormat {
Self::Float
}
}

View file

@ -0,0 +1,29 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use crate::embeddings;
use serde::{Deserialize, Serialize};
/// CreateEmbeddingRequestInput : Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. for counting tokens.
/// Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. for counting tokens.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CreateEmbeddingRequestInput {
/// The string that will be turned into an embedding.
String(String),
/// The array of integers that will be turned into an embedding.
Array(Vec<i32>),
}
impl Default for CreateEmbeddingRequestInput {
fn default() -> Self {
Self::String(Default::default())
}
}

View file

@ -0,0 +1,55 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use crate::embeddings;
use serde::{Deserialize, Serialize};
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct CreateEmbeddingResponse {
/// The list of embeddings generated by the model.
#[serde(rename = "data")]
pub data: Vec<embeddings::Embedding>,
/// The name of the model used to generate the embedding.
#[serde(rename = "model")]
pub model: String,
/// The object type, which is always \"list\".
#[serde(rename = "object")]
pub object: Object,
#[serde(rename = "usage")]
pub usage: Box<embeddings::CreateEmbeddingResponseUsage>,
}
impl CreateEmbeddingResponse {
pub fn new(
data: Vec<embeddings::Embedding>,
model: String,
object: Object,
usage: embeddings::CreateEmbeddingResponseUsage,
) -> CreateEmbeddingResponse {
CreateEmbeddingResponse {
data,
model,
object,
usage: Box::new(usage),
}
}
}
/// The object type, which is always \"list\".
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub enum Object {
#[serde(rename = "list")]
List,
}
impl Default for Object {
fn default() -> Object {
Self::List
}
}

View file

@ -0,0 +1,33 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use crate::embeddings;
use serde::{Deserialize, Serialize};
/// CreateEmbeddingResponseUsage : The usage information for the request.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct CreateEmbeddingResponseUsage {
/// The number of tokens used by the prompt.
#[serde(rename = "prompt_tokens")]
pub prompt_tokens: i32,
/// The total number of tokens used by the request.
#[serde(rename = "total_tokens")]
pub total_tokens: i32,
}
impl CreateEmbeddingResponseUsage {
/// The usage information for the request.
pub fn new(prompt_tokens: i32, total_tokens: i32) -> CreateEmbeddingResponseUsage {
CreateEmbeddingResponseUsage {
prompt_tokens,
total_tokens,
}
}
}

View file

@ -0,0 +1,49 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use crate::embeddings;
use serde::{Deserialize, Serialize};
/// Embedding : Represents an embedding vector returned by embedding endpoint.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct Embedding {
/// The index of the embedding in the list of embeddings.
#[serde(rename = "index")]
pub index: i32,
/// The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings).
#[serde(rename = "embedding")]
pub embedding: Vec<f64>,
/// The object type, which is always \"embedding\"
#[serde(rename = "object")]
pub object: Object,
}
impl Embedding {
/// Represents an embedding vector returned by embedding endpoint.
pub fn new(index: i32, embedding: Vec<f64>, object: Object) -> Embedding {
Embedding {
index,
embedding,
object,
}
}
}
/// The object type, which is always \"embedding\"
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub enum Object {
#[serde(rename = "embedding")]
Embedding,
}
impl Default for Object {
fn default() -> Object {
Self::Embedding
}
}

View file

@ -0,0 +1,10 @@
pub mod create_embedding_request;
pub use self::create_embedding_request::CreateEmbeddingRequest;
pub mod create_embedding_request_input;
pub use self::create_embedding_request_input::CreateEmbeddingRequestInput;
pub mod create_embedding_response;
pub use self::create_embedding_response::CreateEmbeddingResponse;
pub mod create_embedding_response_usage;
pub use self::create_embedding_response_usage::CreateEmbeddingResponseUsage;
pub mod embedding;
pub use self::embedding::Embedding;

93
crates/common/src/http.rs Normal file
View file

@ -0,0 +1,93 @@
use crate::stats::{Gauge, IncrementingMetric};
use derivative::Derivative;
use log::debug;
use proxy_wasm::{traits::Context, types::Status};
use serde::Serialize;
use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration};
#[derive(Derivative, Serialize)]
#[derivative(Debug)]
pub struct CallArgs<'a> {
upstream: &'a str,
path: &'a str,
headers: Vec<(&'a str, &'a str)>,
#[derivative(Debug = "ignore")]
body: Option<&'a [u8]>,
trailers: Vec<(&'a str, &'a str)>,
timeout: Duration,
}
impl<'a> CallArgs<'a> {
pub fn new(
upstream: &'a str,
path: &'a str,
headers: Vec<(&'a str, &'a str)>,
body: Option<&'a [u8]>,
trailers: Vec<(&'a str, &'a str)>,
timeout: Duration,
) -> Self {
CallArgs {
upstream,
path,
headers,
body,
trailers,
timeout,
}
}
}
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
#[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")]
DispatchError {
upstream_name: String,
path: String,
internal_status: Status,
},
}
pub trait Client: Context {
type CallContext: Debug;
fn http_call(
&self,
call_args: CallArgs,
call_context: Self::CallContext,
) -> Result<u32, ClientError> {
debug!(
"dispatching http call with args={:?} context={:?}",
call_args, call_context
);
match self.dispatch_http_call(
call_args.upstream,
call_args.headers,
call_args.body,
call_args.trailers,
call_args.timeout,
) {
Ok(id) => {
self.add_call_context(id, call_context);
Ok(id)
}
Err(status) => Err(ClientError::DispatchError {
upstream_name: String::from(call_args.upstream),
path: String::from(call_args.path),
internal_status: status,
}),
}
}
fn add_call_context(&self, id: u32, call_context: Self::CallContext) {
let callouts = self.callouts();
if callouts.borrow_mut().insert(id, call_context).is_some() {
panic!("Duplicate http call with id={}", id);
}
self.active_http_calls().increment(1);
}
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>>;
fn active_http_calls(&self) -> &Gauge;
}

12
crates/common/src/lib.rs Normal file
View file

@ -0,0 +1,12 @@
#![allow(unused_imports)]
pub mod common_types;
pub mod configuration;
pub mod consts;
pub mod embeddings;
pub mod http;
pub mod llm_providers;
pub mod ratelimit;
pub mod routing;
pub mod stats;
pub mod tokenizer;

View file

@ -0,0 +1,69 @@
use crate::configuration::LlmProvider;
use std::collections::HashMap;
use std::rc::Rc;
#[derive(Debug)]
pub struct LlmProviders {
providers: HashMap<String, Rc<LlmProvider>>,
default: Option<Rc<LlmProvider>>,
}
impl LlmProviders {
pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Rc<LlmProvider>> {
self.providers.iter()
}
pub fn default(&self) -> Option<Rc<LlmProvider>> {
self.default.as_ref().map(|rc| rc.clone())
}
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
self.providers.get(name).cloned()
}
}
#[derive(thiserror::Error, Debug)]
pub enum LlmProvidersNewError {
#[error("There must be at least one LLM Provider")]
EmptySource,
#[error("There must be at most one default LLM Provider")]
MoreThanOneDefault,
#[error("\'{0}\' is not a unique name")]
DuplicateName(String),
}
impl TryFrom<Vec<LlmProvider>> for LlmProviders {
type Error = LlmProvidersNewError;
fn try_from(llm_providers_config: Vec<LlmProvider>) -> Result<Self, Self::Error> {
if llm_providers_config.is_empty() {
return Err(LlmProvidersNewError::EmptySource);
}
let mut llm_providers = LlmProviders {
providers: HashMap::new(),
default: None,
};
for llm_provider in llm_providers_config {
let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
if llm_provider.default.unwrap_or_default() {
match llm_providers.default {
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
None => llm_providers.default = Some(Rc::clone(&llm_provider)),
}
}
// Insert and check that there is no other provider with the same name.
let name = llm_provider.name.clone();
if llm_providers
.providers
.insert(name.clone(), llm_provider)
.is_some()
{
return Err(LlmProvidersNewError::DuplicateName(name));
}
}
Ok(llm_providers)
}
}

View file

@ -0,0 +1,451 @@
use crate::configuration;
use configuration::{Limit, Ratelimit, TimeUnit};
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
use log::debug;
use std::fmt::Display;
use std::num::{NonZero, NonZeroU32};
use std::sync::RwLock;
use std::{collections::HashMap, sync::OnceLock};
pub type RatelimitData = RwLock<RatelimitMap>;
pub fn ratelimits(ratelimits_config: Option<Vec<Ratelimit>>) -> &'static RatelimitData {
static RATELIMIT_DATA: OnceLock<RatelimitData> = OnceLock::new();
RATELIMIT_DATA.get_or_init(|| {
RwLock::new(RatelimitMap::new(
ratelimits_config.expect("The initialization call has to have passed a config"),
))
})
}
// The Data Structure is laid out in the following way:
// Provider -> Hash { Header -> Limit }.
// If the Header used to configure the given Limit:
// a) Has None value, then there will be N Limit keyed by the Header value.
// b) Has Some() value, then there will be 1 Limit keyed by the empty string.
// It would have been nicer to use a non-keyed limit for b). However, the type system made that option a nightmare.
pub struct RatelimitMap {
datastore: HashMap<String, HashMap<configuration::Header, DefaultKeyedRateLimiter<String>>>,
}
// This version of Header demands that the user passes a header value to match on.
#[derive(Debug, Clone)]
pub struct Header {
pub key: String,
pub value: String,
}
impl Display for Header {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
impl From<Header> for configuration::Header {
fn from(header: Header) -> Self {
Self {
key: header.key,
value: Some(header.value),
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("exceeded limit provider={provider}, selector={selector}, tokens_used={tokens_used}")]
ExceededLimit {
provider: String,
selector: Header,
tokens_used: NonZeroU32,
},
}
impl RatelimitMap {
// n.b new is private so that the only access to the Ratelimits can be done via the static
// reference inside a RwLock via ratelimit::ratelimits().
fn new(ratelimits_config: Vec<Ratelimit>) -> Self {
let mut new_ratelimit_map = RatelimitMap {
datastore: HashMap::new(),
};
for ratelimit_config in ratelimits_config {
let limit = DefaultKeyedRateLimiter::keyed(get_quota(ratelimit_config.limit));
match new_ratelimit_map.datastore.get_mut(&ratelimit_config.model) {
Some(limits) => match limits.get_mut(&ratelimit_config.selector) {
Some(_) => {
panic!("repeated selector. Selectors per provider must be unique")
}
None => {
limits.insert(ratelimit_config.selector, limit);
}
},
None => {
// The provider has not been seen before.
// Insert the provider and a new HashMap with the specified limit
let new_hash_map = HashMap::from([(ratelimit_config.selector, limit)]);
new_ratelimit_map
.datastore
.insert(ratelimit_config.model, new_hash_map);
}
}
}
new_ratelimit_map
}
#[allow(unused)]
pub fn check_limit(
&self,
provider: String,
selector: Header,
tokens_used: NonZeroU32,
) -> Result<(), Error> {
debug!(
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
provider, selector, tokens_used
);
let provider_limits = match self.datastore.get(&provider) {
None => {
// No limit configured for this provider, hence ok.
return Ok(());
}
Some(limit) => limit,
};
let mut config_selector = configuration::Header::from(selector.clone());
let (limit, limit_key) = match provider_limits.get(&config_selector) {
// This is a specific limit, i.e one that was configured with both key, and value.
// Therefore, the key for the internal limit does not matter, and hence the empty string is always returned.
Some(limit) => (limit, String::from("")),
None => {
// Unwrap is ok here because we _know_ the value exists.
let header_key = config_selector.value.take().unwrap();
// Search for less specific limit, i.e, one that was configured without a value, therefore every Header
// value has its own key in the internal limit.
match provider_limits.get(&config_selector) {
Some(limit) => (limit, header_key),
// No limit for that header key, value pair exists within that provider limits.
None => {
return Ok(());
}
}
}
};
match limit.check_key_n(&limit_key, tokens_used) {
Ok(Ok(())) => Ok(()),
Ok(Err(_)) | Err(InsufficientCapacity(_)) => Err(Error::ExceededLimit {
provider,
selector,
tokens_used,
}),
}
}
}
fn get_quota(limit: Limit) -> Quota {
let tokens = NonZero::new(limit.tokens).expect("Limit's tokens must be positive");
match limit.unit {
TimeUnit::Second => Quota::per_second(tokens),
TimeUnit::Minute => Quota::per_minute(tokens),
TimeUnit::Hour => Quota::per_hour(tokens),
}
}
// The following tests are inside the ratelimit module in order to access RatelimitMap::new() in order to provide
// different configuration values per test.
#[test]
fn non_existent_provider_is_ok() {
let ratelimits_config = vec![Ratelimit {
model: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
},
limit: Limit {
tokens: 100,
unit: TimeUnit::Minute,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("non-existent-provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(5000).unwrap(),
)
.is_ok())
}
#[test]
fn non_existent_key_is_ok() {
let ratelimits_config = vec![Ratelimit {
model: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
},
limit: Limit {
tokens: 100,
unit: TimeUnit::Minute,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(5000).unwrap(),
)
.is_ok())
}
#[test]
fn specific_limit_does_not_catch_non_specific_value() {
let ratelimits_config = vec![Ratelimit {
model: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 200,
unit: TimeUnit::Second,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("key"),
value: String::from("not-the-correct-value"),
},
NonZero::new(5000).unwrap(),
)
.is_ok())
}
#[test]
fn specific_limit_is_hit() {
let ratelimits_config = vec![Ratelimit {
model: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 200,
unit: TimeUnit::Hour,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(5000).unwrap(),
)
.is_err())
}
#[test]
fn non_specific_key_has_different_limits_for_different_values() {
let ratelimits_config = vec![Ratelimit {
model: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
},
limit: Limit {
tokens: 100,
unit: TimeUnit::Hour,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
// Value1 takes 50.
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("only-key"),
value: String::from("value1"),
},
NonZero::new(50).unwrap(),
)
.is_ok());
// value2 takes 60 because it has its own 100 limit
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("only-key"),
value: String::from("value2"),
},
NonZero::new(60).unwrap(),
)
.is_ok());
// However value1 cannot take more than 100 per hour which 50+70 = 120
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("only-key"),
value: String::from("value1"),
},
NonZero::new(70).unwrap(),
)
.is_err())
}
#[test]
fn different_provider_can_have_different_limits_with_the_same_keys() {
let ratelimits_config = vec![
Ratelimit {
model: String::from("first_provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 100,
unit: TimeUnit::Hour,
},
},
Ratelimit {
model: String::from("second_provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 200,
unit: TimeUnit::Hour,
},
},
];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("first_provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(100).unwrap(),
)
.is_ok());
assert!(ratelimits
.check_limit(
String::from("second_provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(200).unwrap(),
)
.is_ok());
assert!(ratelimits
.check_limit(
String::from("first_provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(1).unwrap(),
)
.is_err());
assert!(ratelimits
.check_limit(
String::from("second_provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(1).unwrap(),
)
.is_err());
}
// These tests use the publicly exposed static singleton, thus the same configuration is used in every test.
// If more tests are written here, move the initial call out of the test.
#[cfg(test)]
mod test {
use crate::configuration;
use super::ratelimits;
use configuration::{Limit, Ratelimit, TimeUnit};
use std::num::NonZero;
use std::thread;
#[test]
fn make_ratelimits_optional() {
let ratelimits_config = Vec::new();
// Initialize in the main thread.
ratelimits(Some(ratelimits_config));
}
#[test]
fn different_threads_have_same_ratelimit_data_structure() {
let ratelimits_config = Some(vec![Ratelimit {
model: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 200,
unit: TimeUnit::Hour,
},
}]);
// Initialize in the main thread.
ratelimits(ratelimits_config);
// Use the singleton in a different thread.
thread::spawn(|| {
let ratelimits = ratelimits(None);
assert!(ratelimits
.read()
.unwrap()
.check_limit(
String::from("provider"),
super::Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(5000).unwrap(),
)
.is_err())
});
}
}

View file

@ -0,0 +1,50 @@
use std::rc::Rc;
use crate::{configuration, llm_providers::LlmProviders};
use configuration::LlmProvider;
use log::debug;
use rand::{seq::IteratorRandom, thread_rng};
#[derive(Debug)]
pub enum ProviderHint {
Default,
Name(String),
}
impl From<String> for ProviderHint {
fn from(value: String) -> Self {
match value.as_str() {
"default" => ProviderHint::Default,
_ => ProviderHint::Name(value),
}
}
}
pub fn get_llm_provider(
llm_providers: &LlmProviders,
provider_hint: Option<ProviderHint>,
) -> Rc<LlmProvider> {
let maybe_provider = provider_hint.and_then(|hint| match hint {
ProviderHint::Default => llm_providers.default(),
// FIXME: should a non-existent name in the hint be more explicit? i.e, return a BAD_REQUEST?
ProviderHint::Name(name) => llm_providers.get(&name),
});
if let Some(provider) = maybe_provider {
return provider;
}
if llm_providers.default().is_some() {
debug!("no llm provider found for hint, using default llm provider");
return llm_providers.default().unwrap();
}
debug!("no default llm found, using random llm provider");
let mut rng = thread_rng();
llm_providers
.iter()
.choose(&mut rng)
.expect("There should always be at least one llm provider")
.1
.clone()
}

103
crates/common/src/stats.rs Normal file
View file

@ -0,0 +1,103 @@
use log::error;
use proxy_wasm::hostcalls;
use proxy_wasm::types::*;
#[allow(unused)]
pub trait Metric {
fn id(&self) -> u32;
fn value(&self) -> Result<u64, String> {
match hostcalls::get_metric(self.id()) {
Ok(value) => Ok(value),
Err(Status::NotFound) => Err(format!("metric not found: {}", self.id())),
Err(err) => Err(format!("unexpected status: {:?}", err)),
}
}
}
#[allow(unused)]
pub trait IncrementingMetric: Metric {
fn increment(&self, offset: i64) {
match hostcalls::increment_metric(self.id(), offset) {
Ok(_) => (),
Err(err) => error!("error incrementing metric: {:?}", err),
}
}
}
#[allow(unused)]
pub trait RecordingMetric: Metric {
fn record(&self, value: u64) {
match hostcalls::record_metric(self.id(), value) {
Ok(_) => (),
Err(err) => error!("error recording metric: {:?}", err),
}
}
}
#[derive(Copy, Clone, Debug)]
pub struct Counter {
id: u32,
}
#[allow(unused)]
impl Counter {
pub fn new(name: String) -> Counter {
let returned_id = hostcalls::define_metric(MetricType::Counter, &name)
.expect("failed to define counter '{}', name");
Counter { id: returned_id }
}
}
impl Metric for Counter {
fn id(&self) -> u32 {
self.id
}
}
impl IncrementingMetric for Counter {}
#[derive(Copy, Clone, Debug)]
pub struct Gauge {
id: u32,
}
impl Gauge {
pub fn new(name: String) -> Gauge {
let returned_id = hostcalls::define_metric(MetricType::Gauge, &name)
.expect("failed to define gauge '{}', name");
Gauge { id: returned_id }
}
}
impl Metric for Gauge {
fn id(&self) -> u32 {
self.id
}
}
/// For state of the world updates
impl RecordingMetric for Gauge {}
/// For offset deltas
impl IncrementingMetric for Gauge {}
#[derive(Copy, Clone)]
pub struct Histogram {
id: u32,
}
#[allow(unused)]
impl Histogram {
pub fn new(name: String) -> Histogram {
let returned_id = hostcalls::define_metric(MetricType::Histogram, &name)
.expect("failed to define histogram '{}', name");
Histogram { id: returned_id }
}
}
impl Metric for Histogram {
fn id(&self) -> u32 {
self.id
}
}
impl RecordingMetric for Histogram {}

View file

@ -0,0 +1,39 @@
use log::debug;
#[derive(Debug, PartialEq, Eq)]
#[allow(dead_code)]
pub enum Error {
UnknownModel,
FailedToTokenize,
}
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
debug!("getting token count model={}", model_name);
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel)?;
Ok(bpe.encode_ordinary(text).len())
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn encode_ordinary() {
let model_name = "gpt-3.5-turbo";
let text = "How many tokens does this sentence have?";
assert_eq!(
8,
token_count(model_name, text).expect("correct tokenization")
);
}
#[test]
fn unrecognized_model() {
assert_eq!(
Error::UnknownModel,
token_count("unknown", "").expect_err("unknown model")
)
}
}