split wasm filter (#186)

* split wasm filter

* fix int and unit tests

* rename public_types => common and move common code there

* rename

* fix int test
This commit is contained in:
Adil Hafeez 2024-10-16 14:20:26 -07:00 committed by GitHub
parent b1746b38b4
commit 3bd2ffe9fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 5755 additions and 351 deletions

1
.dockerignore Normal file
View file

@ -0,0 +1 @@
crates/*/target*

View file

@ -12,13 +12,24 @@ jobs:
steps:
- name: Setup | Checkout
uses: actions/checkout@v4
- name: Setup | Rust
run: rustup toolchain install stable --profile minimal
- name: Setup | Install wasm toolchain
run: rustup target add wasm32-wasi
- name: Build wasm module
run: cd arch && cargo build --release --target=wasm32-wasi
- name: Run Tests on arch
run: cd arch && cargo test
- name: Run Tests on public_types
run: cd public_types && cargo test
- name: Build wasm module for prompt_gateway
run: cd crates/prompt_gateway && cargo build --release --target=wasm32-wasi
- name: Run Tests on common crate
run: cd crates/common && cargo test
- name: Run Tests on prompt_gateway crate
run: cd crates/prompt_gateway && cargo test
- name: Build wasm module for llm_gateway
run: cd crates/llm_gateway && cargo build --release --target=wasm32-wasi
- name: Run Tests on llm_gateway crate
run: cd crates/llm_gateway && cargo test

3
.gitignore vendored
View file

@ -1,6 +1,4 @@
arch/target
arch/qdrant_data/
public_types/target
/venv/
__pycache__
grafana-data
@ -31,3 +29,4 @@ model_server/build
model_server/dist
arch_logs/
dist/
crates/*/target/

View file

@ -8,23 +8,27 @@ repos:
- id: trailing-whitespace
- repo: local
hooks:
- id: cargo-fmt
name: cargo-fmt
language: system
types: [file, rust]
entry: bash -c "cd arch && cargo fmt -p intelligent-prompt-gateway -- --check"
entry: bash -c "cd crates/llm_gateway && cargo fmt -- --check"
- id: cargo-clippy
name: cargo-clippy
language: system
types: [file, rust]
entry: bash -c "cd arch && cargo clippy -p intelligent-prompt-gateway --all"
entry: bash -c "cd crates/llm_gateway && cargo clippy --all"
- id: cargo-test
name: cargo-test
language: system
types: [file, rust]
# --lib is to only test the library, since when integration tests are made,
# they will be in a separate tests directory
entry: bash -c "cd arch && cargo test -p intelligent-prompt-gateway --lib"
entry: bash -c "cd crates/llm_gateway && cargo test --lib"
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:

View file

@ -2,19 +2,18 @@
FROM rust:1.80.0 as builder
RUN rustup -v target add wasm32-wasi
WORKDIR /arch
COPY arch/src /arch/src
COPY arch/Cargo.toml /arch/
COPY arch/Cargo.lock /arch/
COPY public_types /public_types
COPY crates .
RUN cargo build --release --target wasm32-wasi
RUN cd prompt_gateway && cargo build --release --target wasm32-wasi
RUN cd llm_gateway && cargo build --release --target wasm32-wasi
# copy built filter into envoy image
FROM envoyproxy/envoy:v1.31-latest as envoy
#Build config generator, so that we have a single build image for both Rust and Python
FROM python:3-slim as arch
COPY --from=builder /arch/target/wasm32-wasi/release/intelligent_prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm
COPY --from=builder /arch/prompt_gateway/target/wasm32-wasi/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm
COPY --from=builder /arch/llm_gateway/target/wasm32-wasi/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm
COPY --from=envoy /usr/local/bin/envoy /usr/local/bin/envoy
WORKDIR /config
COPY arch/requirements.txt .

View file

@ -90,7 +90,7 @@ static_resources:
runtime: "envoy.wasm.runtime.v8"
code:
local:
filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm"
filename: "/etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm"
- name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
@ -250,7 +250,7 @@ static_resources:
runtime: "envoy.wasm.runtime.v8"
code:
local:
filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm"
filename: "/etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm"
- name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router

View file

@ -1,9 +0,0 @@
apiVersion: 1
datasources:
- name: Prometheus
type: prometheus
url: http://prometheus:9090
isDefault: true
access: proxy
editable: true

View file

@ -1,23 +0,0 @@
global:
scrape_interval: 15s
scrape_timeout: 10s
evaluation_interval: 15s
alerting:
alertmanagers:
- static_configs:
- targets: []
scheme: http
timeout: 10s
api_version: v1
scrape_configs:
- job_name: envoy
honor_timestamps: true
scrape_interval: 15s
scrape_timeout: 10s
metrics_path: /stats
scheme: http
static_configs:
- targets:
- envoy:9901
params:
format: ['prometheus']

668
crates/common/Cargo.lock generated Normal file
View file

@ -0,0 +1,668 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "ahash"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8fd72866655d1904d6b0997d0b07ba561047d070fbe29de039031c641b61217"
[[package]]
name = "ahash"
version = "0.8.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
dependencies = [
"cfg-if",
"once_cell",
"version_check",
"zerocopy",
]
[[package]]
name = "aho-corasick"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
dependencies = [
"memchr",
]
[[package]]
name = "allocator-api2"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f"
[[package]]
name = "anyhow"
version = "1.0.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6"
[[package]]
name = "autocfg"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "base64"
version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]]
name = "bit-set"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
dependencies = [
"bit-vec",
]
[[package]]
name = "bit-vec"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
[[package]]
name = "bitflags"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]]
name = "bstr"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c"
dependencies = [
"memchr",
"regex-automata",
"serde",
]
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "common"
version = "0.1.0"
dependencies = [
"derivative",
"duration-string",
"governor",
"log",
"pretty_assertions",
"proxy-wasm",
"rand",
"serde",
"serde_json",
"serde_yaml",
"thiserror",
"tiktoken-rs",
]
[[package]]
name = "derivative"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "diff"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"
[[package]]
name = "duration-string"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fcc1d9ae294a15ed05aeae8e11ee5f2b3fe971c077d45a42fb20825fba6ee13"
dependencies = [
"serde",
]
[[package]]
name = "equivalent"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "fancy-regex"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7493d4c459da9f84325ad297371a6b2b8a162800873a22e3b6b6512e61d18c05"
dependencies = [
"bit-set",
"regex",
]
[[package]]
name = "getrandom"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "governor"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b"
dependencies = [
"cfg-if",
"no-std-compat",
"nonzero_ext",
"portable-atomic",
"smallvec",
"spinning_top",
]
[[package]]
name = "hashbrown"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91b62f79061a0bc2e046024cb7ba44b08419ed238ecbd9adbd787434b9e8c25"
dependencies = [
"ahash 0.3.8",
"autocfg",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash 0.8.11",
"allocator-api2",
]
[[package]]
name = "indexmap"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5"
dependencies = [
"equivalent",
"hashbrown 0.14.5",
]
[[package]]
name = "itoa"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "libc"
version = "0.2.159"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5"
[[package]]
name = "lock_api"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
dependencies = [
"autocfg",
"scopeguard",
]
[[package]]
name = "log"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
[[package]]
name = "memchr"
version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
dependencies = [
"hashbrown 0.8.2",
]
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]]
name = "once_cell"
version = "1.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
[[package]]
name = "parking_lot"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
dependencies = [
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.9.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"smallvec",
"windows-targets",
]
[[package]]
name = "portable-atomic"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2"
[[package]]
name = "ppv-lite86"
version = "0.2.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04"
dependencies = [
"zerocopy",
]
[[package]]
name = "pretty_assertions"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d"
dependencies = [
"diff",
"yansi",
]
[[package]]
name = "proc-macro2"
version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
dependencies = [
"unicode-ident",
]
[[package]]
name = "proxy-wasm"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14a5a4df5a1ab77235e36a0a0f638687ee1586d21ee9774037693001e94d4e11"
dependencies = [
"hashbrown 0.14.5",
"log",
]
[[package]]
name = "quote"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]]
name = "redox_syscall"
version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f"
dependencies = [
"bitflags",
]
[[package]]
name = "regex"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "ryu"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "serde"
version = "1.0.210"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.210"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
]
[[package]]
name = "serde_json"
version = "1.0.128"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]]
name = "serde_yaml"
version = "0.9.34+deprecated"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
dependencies = [
"indexmap",
"itoa",
"ryu",
"serde",
"unsafe-libyaml",
]
[[package]]
name = "smallvec"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "spinning_top"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
dependencies = [
"lock_api",
]
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "syn"
version = "2.0.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "thiserror"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
]
[[package]]
name = "tiktoken-rs"
version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c314e7ce51440f9e8f5a497394682a57b7c323d0f4d0a6b1b13c429056e0e234"
dependencies = [
"anyhow",
"base64",
"bstr",
"fancy-regex",
"lazy_static",
"parking_lot",
"rustc-hash",
]
[[package]]
name = "unicode-ident"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
[[package]]
name = "version_check"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "windows-targets"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_gnullvm",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "yansi"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
[[package]]
name = "zerocopy"
version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
dependencies = [
"byteorder",
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
]

View file

@ -1,5 +1,5 @@
[package]
name = "public_types"
name = "common"
version = "0.1.0"
edition = "2021"
@ -7,6 +7,13 @@ edition = "2021"
serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9.34"
duration-string = { version = "0.3.0", features = ["serde"] }
proxy-wasm = "0.2.1"
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
log = "0.4"
derivative = "2.2.0"
thiserror = "1.0.64"
tiktoken-rs = "0.5.9"
rand = "0.8.5"
[dev-dependencies]
pretty_assertions = "1.4.1"

View file

@ -1,5 +1,6 @@
use duration_string::DurationString;
use serde::{Deserialize, Deserializer, Serialize};
use std::default;
use std::fmt::Display;
use std::{collections::HashMap, time::Duration};
@ -13,20 +14,15 @@ pub struct Tracing {
pub sampling_rate: Option<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub enum GatewayMode {
#[serde(rename = "llm")]
Llm,
#[default]
#[serde(rename = "prompt")]
Prompt,
}
impl Default for GatewayMode {
fn default() -> Self {
GatewayMode::Prompt
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
@ -225,9 +221,10 @@ mod test {
#[test]
fn test_deserialize_configuration() {
let ref_config =
fs::read_to_string("../docs/source/resources/includes/arch_config_full_reference.yaml")
.expect("reference config file not found");
let ref_config = fs::read_to_string(
"../../docs/source/resources/includes/arch_config_full_reference.yaml",
)
.expect("reference config file not found");
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
assert_eq!(config.version, "v0.1");
@ -299,10 +296,7 @@ mod test {
let tracing = config.tracing.as_ref().unwrap();
assert_eq!(tracing.sampling_rate.unwrap(), 0.1);
let mode = config
.mode
.as_ref()
.unwrap_or(&super::GatewayMode::Prompt);
let mode = config.mode.as_ref().unwrap_or(&super::GatewayMode::Prompt);
assert_eq!(*mode, super::GatewayMode::Prompt);
}
}

12
crates/common/src/lib.rs Normal file
View file

@ -0,0 +1,12 @@
//! Crate root of the `common` crate.
//!
//! Declares the public modules that the gateway crates pull in through
//! `common::...` paths (for example `common::configuration` or
//! `common::ratelimit`).

// Disables the `unused_imports` lint for every module in this crate.
#![allow(unused_imports)]
/// Shared request/response types, including the `open_ai` chat-completion types.
pub mod common_types;
/// Gateway configuration types such as `Configuration`, `GatewayMode` and `LlmProvider`.
pub mod configuration;
/// Shared constants (e.g. `ARCH_INTERNAL_CLUSTER_NAME`, `MODEL_SERVER_NAME`).
pub mod consts;
/// Embedding request/response types (`CreateEmbeddingRequest`, `CreateEmbeddingResponse`, ...).
pub mod embeddings;
/// HTTP call helpers: `Client`, `CallArgs` and `ClientError`.
pub mod http;
/// Home of `LlmProviders`.
pub mod llm_providers;
/// Rate limiting; exposes `Header` (NOTE(review): appears to be built on `governor` - confirm).
pub mod ratelimit;
/// Routing helpers (NOTE(review): presumably picks an LLM provider - confirm against the module).
pub mod routing;
/// Metric types: `Counter`, `Gauge` and `IncrementingMetric`.
pub mod stats;
/// Tokenizer helpers (NOTE(review): presumably backed by `tiktoken-rs` - confirm).
pub mod tokenizer;

View file

@ -1,4 +1,4 @@
use public_types::configuration::LlmProvider;
use crate::configuration::LlmProvider;
use std::collections::HashMap;
use std::rc::Rc;

View file

@ -1,7 +1,7 @@
use crate::configuration;
use configuration::{Limit, Ratelimit, TimeUnit};
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
use log::debug;
use public_types::configuration;
use public_types::configuration::{Limit, Ratelimit, TimeUnit};
use std::fmt::Display;
use std::num::{NonZero, NonZeroU32};
use std::sync::RwLock;
@ -398,9 +398,10 @@ fn different_provider_can_have_different_limits_with_the_same_keys() {
// If more tests are written here, move the initial call out of the test.
#[cfg(test)]
mod test {
use crate::configuration;
use super::ratelimits;
use configuration::{Limit, Ratelimit, TimeUnit};
use public_types::configuration;
use std::num::NonZero;
use std::thread;

View file

@ -1,8 +1,8 @@
use std::rc::Rc;
use crate::llm_providers::LlmProviders;
use crate::{configuration, llm_providers::LlmProviders};
use configuration::LlmProvider;
use log::debug;
use public_types::configuration::LlmProvider;
use rand::{seq::IteratorRandom, thread_rng};
#[derive(Debug)]

View file

@ -217,6 +217,22 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15"
[[package]]
name = "common"
version = "0.1.0"
dependencies = [
"derivative",
"duration-string",
"governor",
"log",
"proxy-wasm",
"rand",
"serde",
"serde_yaml",
"thiserror",
"tiktoken-rs",
]
[[package]]
name = "cpp_demangle"
version = "0.4.4"
@ -753,29 +769,6 @@ dependencies = [
"serde",
]
[[package]]
name = "intelligent-prompt-gateway"
version = "0.1.0"
dependencies = [
"acap",
"derivative",
"governor",
"http",
"log",
"md5",
"proxy-wasm",
"proxy-wasm-test-framework",
"public_types",
"rand",
"serde",
"serde_json",
"serde_yaml",
"serial_test",
"sha2",
"thiserror",
"tiktoken-rs",
]
[[package]]
name = "itertools"
version = "0.12.1"
@ -860,6 +853,28 @@ version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
[[package]]
name = "llm_gateway"
version = "0.1.0"
dependencies = [
"acap",
"common",
"derivative",
"governor",
"http",
"log",
"md5",
"proxy-wasm",
"proxy-wasm-test-framework",
"rand",
"serde",
"serde_json",
"serde_yaml",
"serial_test",
"sha2",
"thiserror",
]
[[package]]
name = "lock_api"
version = "0.4.12"
@ -1094,15 +1109,6 @@ dependencies = [
"cc",
]
[[package]]
name = "public_types"
version = "0.1.0"
dependencies = [
"duration-string",
"serde",
"serde_yaml",
]
[[package]]
name = "quote"
version = "1.0.37"
@ -1197,9 +1203,9 @@ dependencies = [
[[package]]
name = "regex"
version = "1.10.6"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8"
dependencies = [
"aho-corasick",
"memchr",
@ -1209,9 +1215,9 @@ dependencies = [
[[package]]
name = "regex-automata"
version = "0.4.7"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
dependencies = [
"aho-corasick",
"memchr",
@ -1220,9 +1226,9 @@ dependencies = [
[[package]]
name = "regex-syntax"
version = "0.8.4"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "rustc-demangle"

View file

@ -1,5 +1,5 @@
[package]
name = "intelligent-prompt-gateway"
name = "llm_gateway"
version = "0.1.0"
authors = ["Katanemo Inc <info@katanemo.com>"]
edition = "2021"
@ -14,10 +14,9 @@ serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9.34"
serde_json = "1.0"
md5 = "0.7.0"
public_types = { path = "../public_types" }
common = { path = "../common" }
http = "1.1.0"
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
tiktoken-rs = "0.5.9"
acap = "0.3.0"
rand = "0.8.5"
thiserror = "1.0.64"

View file

@ -1,22 +1,23 @@
use crate::consts::{
ARCH_INTERNAL_CLUSTER_NAME, ARCH_UPSTREAM_HOST_HEADER, DEFAULT_EMBEDDING_MODEL,
MODEL_SERVER_NAME,
};
use crate::http::{CallArgs, Client};
use crate::llm_providers::LlmProviders;
use crate::ratelimit;
use crate::stats::{Counter, Gauge, IncrementingMetric};
use crate::stream_context::StreamContext;
use common::common_types::EmbeddingType;
use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget};
use common::consts::ARCH_INTERNAL_CLUSTER_NAME;
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use common::consts::DEFAULT_EMBEDDING_MODEL;
use common::consts::MODEL_SERVER_NAME;
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use common::http::CallArgs;
use common::http::Client;
use common::llm_providers::LlmProviders;
use common::ratelimit;
use common::stats::Counter;
use common::stats::Gauge;
use common::stats::IncrementingMetric;
use log::debug;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::EmbeddingType;
use public_types::configuration::{
Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget,
};
use public_types::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::collections::HashMap;

View file

@ -2,15 +2,8 @@ use filter_context::FilterContext;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod consts;
mod filter_context;
mod http;
mod llm_providers;
mod ratelimit;
mod routing;
mod stats;
mod stream_context;
mod tokenizer;
proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);

View file

@ -1,4 +1,18 @@
use crate::consts::{
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use acap::cos;
use common::common_types::open_ai::{
ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
ChatCompletionsResponse, Choice, FunctionDefinition, FunctionParameter, FunctionParameters,
Message, ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType,
};
use common::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest,
ZeroShotClassificationResponse,
};
use common::configuration::{GatewayMode, LlmProvider};
use common::configuration::{Overrides, PromptGuards, PromptTarget};
use common::consts::{
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_PROVIDER_HINT_HEADER,
ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER,
@ -6,32 +20,18 @@ use crate::consts::{
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use crate::http::{CallArgs, Client, ClientError};
use crate::llm_providers::LlmProviders;
use crate::ratelimit::Header;
use crate::stats::IncrementingMetric;
use crate::{ratelimit, routing, tokenizer};
use acap::cos;
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use common::http::{CallArgs, Client, ClientError};
use common::llm_providers::LlmProviders;
use common::ratelimit::Header;
use common::stats::Gauge;
use common::{ratelimit, routing, tokenizer};
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::open_ai::{
ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
ChatCompletionsResponse, Choice, FunctionDefinition, FunctionParameter, FunctionParameters,
Message, ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType,
};
use public_types::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest,
ZeroShotClassificationResponse,
};
use public_types::configuration::{GatewayMode, LlmProvider};
use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
use public_types::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use serde_json::Value;
use sha2::{Digest, Sha256};
use std::cell::RefCell;
@ -40,6 +40,8 @@ use std::num::NonZero;
use std::rc::Rc;
use std::time::Duration;
use common::stats::IncrementingMetric;
#[derive(Debug, Clone)]
enum ResponseHandlerType {
GetEmbeddings,
@ -753,10 +755,8 @@ impl StreamContext {
}
}
}
} else {
if let Some(user_message) = callout_context.user_message.as_ref() {
user_messages.push(user_message.clone());
}
} else if let Some(user_message) = callout_context.user_message.as_ref() {
user_messages.push(user_message.clone());
}
let user_messages_str = user_messages.join(", ");
debug!("user messages: {}", user_messages_str);
@ -1280,7 +1280,7 @@ impl HttpContext for StreamContext {
let prompt_guard_jailbreak_task = self
.prompt_guards
.input_guards
.contains_key(&public_types::configuration::GuardType::Jailbreak);
.contains_key(&common::configuration::GuardType::Jailbreak);
self.chat_completions_request = Some(deserialized_body);
@ -1570,7 +1570,7 @@ impl Client for StreamContext {
&self.callouts
}
fn active_http_calls(&self) -> &crate::stats::Gauge {
fn active_http_calls(&self) -> &Gauge {
&self.metrics.active_http_calls
}
}

View file

@ -1,23 +1,23 @@
use common::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
use common::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
use common::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
use common::embeddings::{
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
Embedding,
};
use common::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
use http::StatusCode;
use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
};
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
use public_types::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
use public_types::embeddings::{
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
Embedding,
};
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
use serde_yaml::Value;
use serial_test::serial;
use std::collections::HashMap;
use std::path::Path;
fn wasm_module() -> String {
let wasm_file = Path::new("target/wasm32-wasi/release/intelligent_prompt_gateway.wasm");
let wasm_file = Path::new("target/wasm32-wasi/release/llm_gateway.wasm");
assert!(
wasm_file.exists(),
"Run `cargo build --release --target=wasm32-wasi` first"

2165
crates/prompt_gateway/Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,28 @@
# Manifest for the prompt_gateway crate.
[package]
name = "prompt_gateway"
version = "0.1.0"
authors = ["Katanemo Inc <info@katanemo.com>"]
edition = "2021"
# Built as a cdylib; the integration tests load the resulting
# target/wasm32-wasi/release/prompt_gateway.wasm artifact.
[lib]
crate-type = ["cdylib"]
[dependencies]
proxy-wasm = "0.2.1"
log = "0.4"
serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9.34"
serde_json = "1.0"
md5 = "0.7.0"
# Sibling crate holding the code shared between the gateway crates
# (configuration, http client helpers, stats, ratelimit, ...).
common = { path = "../common" }
http = "1.1.0"
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
acap = "0.3.0"
rand = "0.8.5"
thiserror = "1.0.64"
derivative = "2.2.0"
sha2 = "0.10.8"
[dev-dependencies]
# NOTE(review): tracks a branch rather than a pinned revision, so the test
# framework version can change between builds.
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
serial_test = "3.1.1"

View file

@ -0,0 +1,322 @@
use crate::stream_context::StreamContext;
use common::common_types::EmbeddingType;
use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget};
use common::consts::ARCH_INTERNAL_CLUSTER_NAME;
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use common::consts::DEFAULT_EMBEDDING_MODEL;
use common::consts::MODEL_SERVER_NAME;
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use common::http::CallArgs;
use common::http::Client;
use common::llm_providers::LlmProviders;
use common::ratelimit;
use common::stats::Counter;
use common::stats::Gauge;
use common::stats::IncrementingMetric;
use log::debug;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::rc::Rc;
use std::time::Duration;
/// Metric handles owned by the filter context and shared (via `Rc`) with every
/// stream context it creates.
#[derive(Copy, Clone, Debug)]
pub struct WasmMetrics {
/// Gauge created under the name "active_http_calls"; decremented by one when
/// an HTTP call response is handled.
pub active_http_calls: Gauge,
/// Counter created under the name "ratelimited_rq".
pub ratelimited_rq: Counter,
}
impl WasmMetrics {
    /// Creates the gauge and the counter, in that order, under their fixed
    /// metric names.
    fn new() -> WasmMetrics {
        // The gauge is created before the counter, matching the original
        // struct-literal evaluation order.
        let active_http_calls = Gauge::new(String::from("active_http_calls"));
        let ratelimited_rq = Counter::new(String::from("ratelimited_rq"));
        Self {
            active_http_calls,
            ratelimited_rq,
        }
    }
}
/// Embedding vectors of a single prompt target, keyed by the kind of text that
/// was embedded.
pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
/// All embeddings known to the filter, keyed by prompt target name.
pub type EmbeddingsStore = HashMap<String, EmbeddingTypeMap>;
/// Per-call bookkeeping stored while an embeddings request is in flight, so the
/// response can be attributed to the right prompt target and embedding type.
#[derive(Debug)]
pub struct FilterCallContext {
/// Name of the prompt target whose text was sent for embedding.
pub prompt_target_name: String,
/// Which kind of embedding was requested for that target.
pub embedding_type: EmbeddingType,
}
/// Root (per-plugin) context: holds the parsed configuration and the embeddings
/// of the configured prompt targets, and hands shared handles to each stream
/// context it creates.
#[derive(Debug)]
pub struct FilterContext {
// Metric handles, shared with every stream context.
metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, FilterCallContext>>,
// Optional overrides taken from the configuration in on_configure.
overrides: Rc<Option<Overrides>>,
// Optional system prompt taken from the configuration in on_configure.
system_prompt: Rc<Option<String>>,
// Configured prompt targets, keyed by target name.
prompt_targets: Rc<HashMap<String, PromptTarget>>,
// Gateway mode; starts as Prompt and is replaced by the configured value.
mode: GatewayMode,
// Prompt guards; stay at their default unless the configuration sets them.
prompt_guards: Rc<PromptGuards>,
// None until on_configure has converted the configured providers.
llm_providers: Option<Rc<LlmProviders>>,
// Store handed to stream contexts in prompt mode. Starts as an empty map and
// is replaced once an embedding exists for every prompt target.
embeddings_store: Option<Rc<EmbeddingsStore>>,
// Staging area that collects embeddings responses until all have arrived.
temp_embeddings_store: EmbeddingsStore,
}
impl FilterContext {
pub fn new() -> FilterContext {
FilterContext {
callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(WasmMetrics::new()),
system_prompt: Rc::new(None),
prompt_targets: Rc::new(HashMap::new()),
overrides: Rc::new(None),
prompt_guards: Rc::new(PromptGuards::default()),
mode: GatewayMode::Prompt,
llm_providers: None,
embeddings_store: Some(Rc::new(HashMap::new())),
temp_embeddings_store: HashMap::new(),
}
}
fn process_prompt_targets(&self) {
for values in self.prompt_targets.iter() {
let prompt_target = values.1;
self.schedule_embeddings_call(
&prompt_target.name,
&prompt_target.description,
EmbeddingType::Description,
);
}
}
fn schedule_embeddings_call(
&self,
prompt_target_name: &str,
input: &str,
embedding_type: EmbeddingType,
) {
let embeddings_input = CreateEmbeddingRequest {
input: Box::new(CreateEmbeddingRequestInput::String(String::from(input))),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
let json_data = serde_json::to_string(&embeddings_input).unwrap();
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/embeddings",
vec![
(ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME),
(":method", "POST"),
(":path", "/embeddings"),
(":authority", MODEL_SERVER_NAME),
("content-type", "application/json"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(60),
);
let call_context = crate::filter_context::FilterCallContext {
prompt_target_name: String::from(prompt_target_name),
embedding_type,
};
if let Err(error) = self.http_call(call_args, call_context) {
panic!("{error}")
}
}
fn embedding_response_handler(
&mut self,
body_size: usize,
embedding_type: EmbeddingType,
prompt_target_name: String,
) {
let prompt_target = self
.prompt_targets
.get(&prompt_target_name)
.unwrap_or_else(|| {
panic!(
"Received embeddings response for unknown prompt target name={}",
prompt_target_name
)
});
let body = self
.get_http_call_response_body(0, body_size)
.expect("No body in response");
if !body.is_empty() {
let mut embedding_response: CreateEmbeddingResponse =
match serde_json::from_slice(&body) {
Ok(response) => response,
Err(e) => {
panic!(
"Error deserializing embedding response. body: {:?}: {:?}",
String::from_utf8(body).unwrap(),
e
);
}
};
let embeddings = embedding_response.data.remove(0).embedding;
debug!(
"Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
prompt_target.name,
prompt_target.description,
embedding_type
);
let entry = self.temp_embeddings_store.entry(prompt_target_name);
match entry {
Entry::Occupied(_) => {
entry.and_modify(|e| {
if let Entry::Vacant(e) = e.entry(embedding_type) {
e.insert(embeddings);
} else {
panic!(
"Duplicate {:?} for prompt target with name=\"{}\"",
&embedding_type, prompt_target.name
)
}
});
}
Entry::Vacant(_) => {
entry.or_insert(HashMap::from([(embedding_type, embeddings)]));
}
}
if self.prompt_targets.len() == self.temp_embeddings_store.len() {
self.embeddings_store =
Some(Rc::new(std::mem::take(&mut self.temp_embeddings_store)))
}
}
}
}
/// Exposes the filter context's call bookkeeping to the shared `Client` trait
/// from the `common` crate.
// NOTE(review): `http_call`, used when scheduling embeddings calls, is
// presumably provided by this trait on top of these accessors — confirm in
// common::http.
impl Client for FilterContext {
// Context recorded for each outstanding call.
type CallContext = FilterCallContext;
// Returns the token_id -> call context map of outstanding calls.
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
&self.callouts
}
// Returns the gauge that tracks in-flight HTTP calls.
fn active_http_calls(&self) -> &Gauge {
&self.metrics.active_http_calls
}
}
impl Context for FilterContext {
fn on_http_call_response(
&mut self,
token_id: u32,
_num_headers: usize,
body_size: usize,
_num_trailers: usize,
) {
debug!(
"filter_context: on_http_call_response called with token_id: {:?}",
token_id
);
let callout_data = self
.callouts
.borrow_mut()
.remove(&token_id)
.expect("invalid token_id");
self.metrics.active_http_calls.increment(-1);
self.embedding_response_handler(
body_size,
callout_data.embedding_type,
callout_data.prompt_target_name,
)
}
}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for FilterContext {
    /// Parses the plugin configuration (YAML) and stores its pieces on the
    /// context: overrides, system prompt, prompt targets (keyed by name),
    /// gateway mode, ratelimits, prompt guards and LLM providers.
    ///
    /// Always returns `true`; a missing or invalid configuration, or LLM
    /// providers that fail to convert, cause a panic instead.
    fn on_configure(&mut self, _: usize) -> bool {
        let config_bytes = self
            .get_plugin_configuration()
            .expect("Arch config cannot be empty");

        let config: Configuration = match serde_yaml::from_slice(&config_bytes) {
            Ok(config) => config,
            Err(err) => panic!("Invalid arch config \"{:?}\"", err),
        };

        self.overrides = Rc::new(config.overrides);

        // Each target is owned by the loop, so it is moved into the map; only
        // the name has to be cloned for the key.
        let mut prompt_targets = HashMap::new();
        for pt in config.prompt_targets {
            prompt_targets.insert(pt.name.clone(), pt);
        }
        self.system_prompt = Rc::new(config.system_prompt);
        self.prompt_targets = Rc::new(prompt_targets);
        self.mode = config.mode.unwrap_or_default();

        ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default()));

        if let Some(prompt_guards) = config.prompt_guards {
            self.prompt_guards = Rc::new(prompt_guards)
        }

        match config.llm_providers.try_into() {
            Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
            Err(err) => panic!("{err}"),
        }

        true
    }

    /// Creates the per-request `StreamContext`, sharing the configuration and
    /// metric handles held by this context.
    ///
    /// Panics if LLM providers have not been configured yet.
    fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
        debug!(
            "||| create_http_context called with context_id: {:?} |||",
            context_id
        );

        // In prompt mode the stream receives whatever store is currently
        // published. `new()` seeds it with an empty map, so a stream created
        // before every embeddings response has arrived sees an empty store
        // rather than being refused.
        let embedding_store = match self.mode {
            GatewayMode::Llm => None,
            GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())),
        };
        Some(Box::new(StreamContext::new(
            context_id,
            Rc::clone(&self.metrics),
            Rc::clone(&self.system_prompt),
            Rc::clone(&self.prompt_targets),
            Rc::clone(&self.prompt_guards),
            Rc::clone(&self.overrides),
            Rc::clone(
                self.llm_providers
                    .as_ref()
                    .expect("LLM Providers must exist when Streams are being created"),
            ),
            embedding_store,
            self.mode.clone(),
        )))
    }

    /// Declares that this root context creates HTTP contexts.
    fn get_type(&self) -> Option<ContextType> {
        Some(ContextType::HttpContext)
    }

    /// Arms a one-second tick so the embeddings calls are issued from
    /// `on_tick` rather than during VM start.
    fn on_vm_start(&mut self, _: usize) -> bool {
        self.set_tick_period(Duration::from_secs(1));
        true
    }

    /// First tick: in prompt mode, request embeddings for all prompt targets;
    /// then set the tick period to zero.
    fn on_tick(&mut self) {
        debug!("starting up arch filter in mode: {:?}", self.mode);
        if self.mode == GatewayMode::Prompt {
            self.process_prompt_targets();
        }

        self.set_tick_period(Duration::from_secs(0));
    }
}

View file

@ -0,0 +1,13 @@
use filter_context::FilterContext;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod filter_context;
mod stream_context;
// Module entry point: sets the log level to Trace and registers a factory
// that builds a fresh FilterContext as the root context.
proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(FilterContext::new())
});
}}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,805 @@
use common::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
use common::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
use common::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
use common::embeddings::{
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
Embedding,
};
use common::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
use http::StatusCode;
use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
};
use serde_yaml::Value;
use serial_test::serial;
use std::collections::HashMap;
use std::path::Path;
/// Returns the relative path of the compiled prompt_gateway wasm module.
///
/// Panics with a build hint when the module has not been built yet.
fn wasm_module() -> String {
    const WASM_PATH: &str = "target/wasm32-wasi/release/prompt_gateway.wasm";

    let module_is_built = Path::new(WASM_PATH).exists();
    assert!(
        module_is_built,
        "Run `cargo build --release --target=wasm32-wasi` first"
    );

    // The path is built from a UTF-8 literal, so it is returned as-is.
    WASM_PATH.to_string()
}
/// Drives `proxy_on_request_headers` for `http_context` and declares, in order,
/// the host calls the filter is expected to make (with canned return values),
/// ending with the filter returning `Action::Continue`.
fn request_headers_expectations(module: &mut Tester, http_context: i32) {
module
.call_proxy_on_request_headers(http_context, 0, false)
// Provider hint lookup; the test answers with "default".
.expect_get_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider-hint"),
)
.returning(Some("default"))
.expect_log(Some(LogLevel::Debug), None)
// Routing headers the filter is expected to add.
.expect_add_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-upstream"),
Some("arch_llm_listener"),
)
.expect_add_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider"),
Some("open-ai-gpt-4"),
)
// Authorization is replaced using the provider's access_key from the config.
.expect_replace_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("Authorization"),
Some("Bearer secret_key"),
)
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
// Ratelimit selector: first the header naming the selector key, then its value.
.expect_get_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-ratelimit-selector"),
)
.returning(Some("selector-key"))
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("selector-key"))
.returning(Some("selector-value"))
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
.returning(Some("/v1/chat/completions"))
.expect_log(Some(LogLevel::Debug), None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
.returning(None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
/// Walks a request through the common prefix of the prompt-gateway flow:
/// context creation, request headers, request body, then the chain of
/// callouts /guard (token 1) -> /embeddings (token 2) -> /zeroshot (token 3)
/// -> arch_fc /v1/chat/completions (token 4).
///
/// Returns with call 4 still outstanding so the caller can supply the
/// function-calling response.
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
request_headers_expectations(module, http_context);
// Request Body
let chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
],\
\"model\": \"gpt-4\"\
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "model_server"),
(":method", "POST"),
(":path", "/guard"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(1))
.expect_log(Some(LogLevel::Debug), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
// Answer the /guard call (token 1) with a response that carries no verdicts;
// the filter is then expected to request embeddings (token 2).
let prompt_guard_response = PromptGuardResponse {
toxic_prob: None,
toxic_verdict: None,
jailbreak_prob: None,
jailbreak_verdict: None,
};
let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap();
module
.call_proxy_on_http_call_response(
http_context,
1,
0,
prompt_guard_response_buffer.len() as i32,
0,
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&prompt_guard_response_buffer))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "model_server"),
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
// Answer the /embeddings call (token 2) with a single empty embedding; the
// filter is then expected to call /zeroshot (token 3).
let embedding_response = CreateEmbeddingResponse {
data: vec![Embedding {
index: 0,
embedding: vec![],
object: embedding::Object::default(),
}],
model: String::from("test"),
object: create_embedding_response::Object::default(),
usage: Box::new(CreateEmbeddingResponseUsage::new(0, 0)),
};
let embeddings_response_buffer = serde_json::to_string(&embedding_response).unwrap();
module
.call_proxy_on_http_call_response(
http_context,
2,
0,
embeddings_response_buffer.len() as i32,
0,
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embeddings_response_buffer))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "model_server"),
(":method", "POST"),
(":path", "/zeroshot"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(3))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
// Answer the /zeroshot call (token 3) with the weather_forecast class; the
// filter is then expected to call arch_fc (token 4).
let zero_shot_response = ZeroShotClassificationResponse {
predicted_class: "weather_forecast".to_string(),
predicted_class_score: 0.1,
scores: HashMap::new(),
model: "test-model".to_string(),
};
let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap();
module
.call_proxy_on_http_call_response(
http_context,
3,
0,
zeroshot_intent_detection_buffer.len() as i32,
0,
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&zeroshot_intent_detection_buffer))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
(":method", "POST"),
("x-arch-upstream", "arch_fc"),
(":path", "/v1/chat/completions"),
(":authority", "arch_fc"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "120000"),
]),
None,
None,
None,
)
.returning(Some(4))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
}
/// Creates and configures the root (filter) context with `config`, runs its
/// first tick, and answers the resulting embeddings call.
///
/// Returns the id of the filter context (always 1).
fn setup_filter(module: &mut Tester, config: &str) -> i32 {
let filter_context = 1;
// Creating the root context is expected to create both metrics.
module
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
// The tick is expected to issue one /embeddings call (token 101) and then
// set the tick period to 0.
module
.call_proxy_on_tick(filter_context)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "model_server"),
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(101))
.expect_metric_increment("active_http_calls", 1)
.expect_set_tick_period_millis(Some(0))
.execute_and_expect(ReturnType::None)
.unwrap();
// Answer call 101 with a single empty embedding.
let embedding_response = CreateEmbeddingResponse {
data: vec![Embedding {
embedding: vec![],
index: 0,
object: embedding::Object::default(),
}],
model: String::from("test"),
object: create_embedding_response::Object::default(),
usage: Box::new(CreateEmbeddingResponseUsage {
prompt_tokens: 0,
total_tokens: 0,
}),
};
let embedding_response_str = serde_json::to_string(&embedding_response).unwrap();
module
.call_proxy_on_http_call_response(
filter_context,
101,
0,
embedding_response_str.len() as i32,
0,
)
.expect_log(
Some(LogLevel::Debug),
Some(
format!(
"filter_context: on_http_call_response called with token_id: {:?}",
101
)
.as_str(),
),
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embedding_response_str))
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
filter_context
}
/// Returns the YAML plugin configuration shared by the tests: one OpenAI
/// provider (`open-ai-gpt-4`, default), a jailbreak input guard, a single
/// `weather_forecast` prompt target, and a ratelimit of 1 token per minute for
/// `gpt-4` under the selector `selector-key: selector-value`.
fn default_config() -> &'static str {
r#"
version: "0.1-beta"
listener:
address: 0.0.0.0
port: 10000
message_format: huggingface
connect_timeout: 0.005s
endpoints:
api_server:
endpoint: api_server:80
connect_timeout: 0.005s
llm_providers:
- name: open-ai-gpt-4
provider: openai
access_key: secret_key
model: gpt-4
default: true
overrides:
# confidence threshold for prompt target intent matching
prompt_target_intent_matching_threshold: 0.6
system_prompt: |
You are a helpful assistant.
prompt_guards:
input_guards:
jailbreak:
on_exception:
message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
prompt_targets:
- name: weather_forecast
description: This function provides realtime weather forecast information for a given city.
parameters:
- name: city
required: true
description: The city for which the weather forecast is requested.
- name: days
description: The number of days for which the weather forecast is requested.
- name: units
description: The units in which the weather forecast is requested.
endpoint:
name: api_server
path: /weather
system_prompt: |
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
ratelimits:
- model: gpt-4
selector:
key: selector-key
value: selector-value
limit:
tokens: 1
unit: minute
"#
}
/// A well-formed chat completions request is expected to trigger one callout to
/// the `arch_internal` cluster and pause the request.
#[test]
#[serial]
fn successful_request_to_open_ai_chat_completions() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
request_headers_expectations(&mut module, http_context);
// Request Body
let chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
],\
\"model\": \"gpt-4\"\
}";
// Only the target cluster of the callout is checked here; headers and body
// are left unconstrained.
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("arch_internal"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
/// A request body that cannot be accepted as a chat completions request is
/// expected to be answered locally with 400 Bad Request.
#[test]
#[serial]
fn bad_request_to_open_ai_chat_completions() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
request_headers_expectations(&mut module, http_context);
// Request Body
// The system message has no content and there is no "model" field.
let incomplete_chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
]\
}";
module
.call_proxy_on_request_body(
http_context,
incomplete_chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(incomplete_chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response(
Some(StatusCode::BAD_REQUEST.as_u16().into()),
None,
None,
None,
)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
/// With the default config (ratelimit of 1 token per minute), the request is
/// expected to end in a local 429 response and one `ratelimited_rq` increment
/// after the function-calling, hallucination and API-server callouts complete.
#[test]
#[serial]
fn request_ratelimited() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
// Response to the arch_fc call (token 4): a single tool call to
// weather_forecast with city=seattle.
let arch_fc_resp = ChatCompletionsResponse {
usage: Some(Usage {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: "test".to_string(),
index: 0,
message: Message {
role: "system".to_string(),
content: None,
tool_calls: Some(vec![ToolCall {
id: String::from("test"),
tool_type: ToolType::Function,
function: FunctionCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
Value::String(String::from("seattle")),
)]),
},
}]),
model: None,
},
}],
model: String::from("test"),
metadata: None,
};
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
module
.call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "model_server"),
(":method", "POST"),
(":path", "/hallucination"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(5))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
// Response to the /hallucination call (token 5); the filter is then expected
// to call the prompt target's endpoint (/weather on api_server, token 6).
let hallucatination_body = HallucinationClassificationResponse {
params_scores: HashMap::from([("city".to_string(), 0.99)]),
model: "nli-model".to_string(),
};
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
module
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "api_server"),
(":method", "POST"),
(":path", "/weather"),
(":authority", "api_server"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
]),
None,
None,
None,
)
.returning(Some(6))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
// Response to the API-server call (token 6) with status 200; the filter is
// expected to reject the request with 429 and count it as ratelimited.
let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
.returning(Some("200"))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response(
Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()),
None,
None,
None,
)
.expect_metric_increment("ratelimited_rq", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
}
/// Drives the filter through the function-calling flow with the first
/// configured ratelimit raised by 1000 tokens, and expects the final upstream
/// API response to result in the HTTP request body being rewritten.
///
/// The three callout responses are replayed in order: the function-calling
/// response (callout 4), the hallucination check (callout 5) and the API
/// server call (callout 6).
#[test]
#[serial]
fn request_not_ratelimited() {
    let settings = tester::MockSettings {
        wasm_path: wasm_module(),
        quiet: false,
        allow_unexpected: false,
    };
    let mut module = tester::mock(settings).unwrap();

    module
        .call_start()
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Setup Filter: start from the default config and bump the token budget of
    // the first ratelimit entry before handing the config to the filter.
    let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap();
    let ratelimits = config.ratelimits.as_mut().unwrap();
    ratelimits[0].limit.tokens += 1000;
    let config_json = serde_json::to_string(&config).unwrap();
    let filter_context = setup_filter(&mut module, &config_json);

    // Setup HTTP Stream
    let http_context = 2;
    normal_flow(&mut module, filter_context, http_context);

    // Function-calling response carrying a single `weather_forecast` tool call
    // with one argument, `city`.
    let weather_call = ToolCall {
        id: String::from("test"),
        tool_type: ToolType::Function,
        function: FunctionCallDetail {
            name: String::from("weather_forecast"),
            arguments: HashMap::from([(
                String::from("city"),
                Value::String(String::from("seattle")),
            )]),
        },
    };
    let tool_call_message = Message {
        role: "system".to_string(),
        content: None,
        tool_calls: Some(vec![weather_call]),
        model: None,
    };
    let fc_response = ChatCompletionsResponse {
        usage: Some(Usage {
            completion_tokens: 0,
        }),
        choices: vec![Choice {
            finish_reason: "test".to_string(),
            index: 0,
            message: tool_call_message,
        }],
        model: String::from("test"),
        metadata: None,
    };
    let fc_response_json = serde_json::to_string(&fc_response).unwrap();

    // Callout 4 completes; the filter is expected to dispatch the
    // hallucination check, which is assigned callout id 5.
    let hallucination_request_headers = vec![
        ("x-arch-upstream", "model_server"),
        (":method", "POST"),
        (":path", "/hallucination"),
        (":authority", "model_server"),
        ("content-type", "application/json"),
        ("x-envoy-max-retries", "3"),
        ("x-envoy-upstream-rq-timeout-ms", "60000"),
    ];
    module
        .call_proxy_on_http_call_response(http_context, 4, 0, fc_response_json.len() as i32, 0)
        .expect_metric_increment("active_http_calls", -1)
        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
        .returning(Some(&fc_response_json))
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_http_call(
            Some("arch_internal"),
            Some(hallucination_request_headers),
            None,
            None,
            None,
        )
        .returning(Some(5))
        .expect_metric_increment("active_http_calls", 1)
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // The hallucination endpoint should report that the parameters were not
    // hallucinated.
    // prompt: str
    // parameters: dict
    // model: str
    let hallucination_response = HallucinationClassificationResponse {
        params_scores: HashMap::from([("city".to_string(), 0.99)]),
        model: "nli-model".to_string(),
    };
    let hallucination_response_json = serde_json::to_string(&hallucination_response).unwrap();

    // Callout 5 completes; the filter is expected to call the API server's
    // `/weather` path, which is assigned callout id 6.
    let api_request_headers = vec![
        ("x-arch-upstream", "api_server"),
        (":method", "POST"),
        (":path", "/weather"),
        (":authority", "api_server"),
        ("content-type", "application/json"),
        ("x-envoy-max-retries", "3"),
    ];
    module
        .call_proxy_on_http_call_response(
            http_context,
            5,
            0,
            hallucination_response_json.len() as i32,
            0,
        )
        .expect_metric_increment("active_http_calls", -1)
        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
        .returning(Some(&hallucination_response_json))
        .expect_log(Some(LogLevel::Debug), None)
        .expect_http_call(
            Some("arch_internal"),
            Some(api_request_headers),
            None,
            None,
            None,
        )
        .returning(Some(6))
        .expect_metric_increment("active_http_calls", 1)
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Callout 6 completes with a 200 status; the filter is expected to write
    // the HTTP request body buffer.
    let api_response_body = String::from("test body");
    module
        .call_proxy_on_http_call_response(http_context, 6, 0, api_response_body.len() as i32, 0)
        .expect_metric_increment("active_http_calls", -1)
        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
        .returning(Some(&api_response_body))
        .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
        .returning(Some("200"))
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
        .execute_and_expect(ReturnType::None)
        .unwrap();
}

View file

@ -5,8 +5,16 @@
"path": "."
},
{
"name": "arch",
"path": "arch"
"name": "common",
"path": "crates/common"
},
{
"name": "prompt_gateway",
"path": "crates/prompt_gateway"
},
{
"name": "llm_gateway",
"path": "crates/llm_gateway"
},
{
"name": "arch/tools",

171
public_types/Cargo.lock generated
View file

@ -1,171 +0,0 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "diff"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"
[[package]]
name = "duration-string"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fcc1d9ae294a15ed05aeae8e11ee5f2b3fe971c077d45a42fb20825fba6ee13"
dependencies = [
"serde",
]
[[package]]
name = "equivalent"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
name = "indexmap"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5"
dependencies = [
"equivalent",
"hashbrown",
]
[[package]]
name = "itoa"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
[[package]]
name = "memchr"
version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "pretty_assertions"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d"
dependencies = [
"diff",
"yansi",
]
[[package]]
name = "proc-macro2"
version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
dependencies = [
"unicode-ident",
]
[[package]]
name = "public_types"
version = "0.1.0"
dependencies = [
"duration-string",
"pretty_assertions",
"serde",
"serde_json",
"serde_yaml",
]
[[package]]
name = "quote"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
dependencies = [
"proc-macro2",
]
[[package]]
name = "ryu"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "serde"
version = "1.0.210"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.210"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.128"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]]
name = "serde_yaml"
version = "0.9.34+deprecated"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
dependencies = [
"indexmap",
"itoa",
"ryu",
"serde",
"unsafe-libyaml",
]
[[package]]
name = "syn"
version = "2.0.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "unicode-ident"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
[[package]]
name = "yansi"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"

View file

@ -1,5 +0,0 @@
#![allow(unused_imports)]
pub mod common_types;
pub mod configuration;
pub mod embeddings;