rename envoyfilter => arch (#91)

* rename envoyfilter => arch

* fix more files

* more fixes

* more renames
This commit is contained in:
Adil Hafeez 2024-09-27 16:41:39 -07:00 committed by GitHub
parent 7168b14ed3
commit ea86f73605
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 91 additions and 99 deletions

2135
arch/Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

26
arch/Cargo.toml Normal file
View file

@ -0,0 +1,26 @@
[package]
name = "intelligent-prompt-gateway"
version = "0.1.0"
authors = ["Katanemo Inc <info@katanemo.com>"]
edition = "2021"
[lib]
crate-type = ["cdylib"]
[dependencies]
proxy-wasm = "0.2.1"
log = "0.4"
serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9.34"
serde_json = "1.0"
md5 = "0.7.0"
public_types = { path = "../public_types" }
http = "1.1.0"
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
tiktoken-rs = "0.5.9"
acap = "0.3.0"
rand = "0.8.5"
[dev-dependencies]
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" }
serial_test = "3.1.1"

17
arch/Dockerfile Normal file
View file

@ -0,0 +1,17 @@
# build filter using rust toolchain
FROM rust:1.80.0 as builder
RUN rustup -v target add wasm32-wasi
WORKDIR /arch
COPY arch/src /arch/src
COPY arch/Cargo.toml /arch/
COPY arch/Cargo.lock /arch/
COPY public_types /public_types
RUN cargo build --release --target wasm32-wasi
# copy built filter into envoy image
FROM envoyproxy/envoy:v1.30-latest
COPY --from=builder /arch/target/wasm32-wasi/release/intelligent_prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm
# CMD ["envoy", "-c", "/etc/envoy/envoy.yaml"]
# CMD ["envoy", "-c", "/etc/envoy/envoy.yaml", "--log-level", "debug"]
CMD ["envoy", "-c", "/etc/envoy/envoy.yaml", "--component-log-level", "wasm:debug"]

71
arch/README.md Normal file
View file

@ -0,0 +1,71 @@
# Envoy filter code for gateway
## Add toolchain
```sh
$ rustup target add wasm32-wasi
```
## Building
```sh
$ cargo build --target wasm32-wasi --release
```
## Testing
```sh
$ cargo test
```
## Using in Envoy
This example can be run with [`docker compose`](https://docs.docker.com/compose/install/)
and has a matching Envoy configuration.
```sh
$ docker compose up
```
## Examples
### Direct response.
Send HTTP request to `localhost:10000/hello`:
```sh
$ curl localhost:10000/hello
```
Expected response:
```console
HTTP/1.1 200 OK
content-length: 40
content-type: text/plain
custom-header: katanemo filter
date: Wed, 10 Jul 2024 16:59:43 GMT
server: envoy
```
### Inline call.
Send HTTP request to `localhost:10000/inline`:
```sh
$ curl localhost:10000/hello
{
"headers": {
"Accept": "*/*",
"Host": "localhost",
"User-Agent": "curl/7.81.0",
"X-Amzn-Trace-Id": "Root=1-637c4767-6e31776a0b407a0219b5b570",
"X-Envoy-Expected-Rq-Timeout-Ms": "15000"
}
}
```
Expected Envoy logs:
```console
[...] wasm log http_auth_random: Access granted.
```

3
arch/build_filter.sh Normal file
View file

@ -0,0 +1,3 @@
RUST_VERSION=1.80.0
docker run --rm -v rustup_cache:/usr/local/rustup/ rust:$RUST_VERSION rustup -v target add wasm32-wasi
docker run --rm -v $PWD/../open-message-format:/code/open-message-format -v ~/.cargo:/root/.cargo -v $(pwd):/code/arch -w /code/arch -v rustup_cache:/usr/local/rustup/ rust:$RUST_VERSION cargo build --release --target wasm32-wasi

43
arch/docker-compose.yaml Normal file
View file

@ -0,0 +1,43 @@
services:
envoy:
image: envoyproxy/envoy:v1.30-latest
hostname: envoy
ports:
- "10000:10000"
- "19901:9901"
volumes:
- ./envoy.yaml:/etc/envoy/envoy.yaml
- ./target/wasm32-wasi/release:/etc/envoy/proxy-wasm-plugins
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
depends_on:
qdrant:
condition: service_started
embeddingserver:
condition: service_healthy
embeddingserver:
build:
context: ../embedding-server
dockerfile: Dockerfile
ports:
- "18080:80"
healthcheck:
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
interval: 5s
retries: 20
qdrant:
image: qdrant/qdrant
hostname: vector-db
ports:
- 16333:6333
- 16334:6334
chatbot-ui:
build:
context: ../chatbot-ui
dockerfile: Dockerfile
ports:
- "18080:8080"
environment:
- CHAT_COMPLETION_ENDPOINT=http://envoy:10000/v1

View file

@ -0,0 +1 @@
huggingface-cli download TheBloke/Mistral-7B-Instruct-v0.2-GGUF mistral-7b-instruct-v0.2.Q4_K_M.gguf --local-dir . --local-dir-use-symlinks False

192
arch/envoy.template.yaml Normal file
View file

@ -0,0 +1,192 @@
admin:
address:
socket_address: { address: 0.0.0.0, port_value: 9901 }
static_resources:
listeners:
address:
socket_address:
address: 0.0.0.0
port_value: 10000
filter_chains:
- filters:
- name: envoy.filters.network.http_connection_manager
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
stat_prefix: ingress_http
codec_type: AUTO
scheme_header_transformation:
scheme_to_overwrite: https
access_log:
- name: envoy.access_loggers.file
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
path: "/var/log/arch_access.log"
route_config:
name: local_routes
virtual_hosts:
- name: local_service
domains:
- "*"
routes:
- match:
prefix: "/mistral/v1/chat/completions"
route:
auto_host_rewrite: true
cluster: mistral_7b_instruct
timeout: 60s
- match:
prefix: "/v1/chat/completions"
headers:
- name: "x-arch-llm-provider"
string_match:
exact: openai
route:
auto_host_rewrite: true
cluster: openai
timeout: 60s
- match:
prefix: "/v1/chat/completions"
headers:
- name: "x-arch-llm-provider"
string_match:
exact: mistral
route:
auto_host_rewrite: true
cluster: mistral
timeout: 60s
http_filters:
- name: envoy.filters.http.wasm
typed_config:
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
value:
config:
name: "http_config"
configuration:
"@type": "type.googleapis.com/google.protobuf.StringValue"
value: |
{{ katanemo_config | indent(30) }}
vm_config:
runtime: "envoy.wasm.runtime.v8"
code:
local:
filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm"
- name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
clusters:
# LLM Host
# Embedding Providers
# External LLM Providers
- name: openai
connect_timeout: 5s
dns_lookup_family: V4_ONLY
type: LOGICAL_DNS
lb_policy: ROUND_ROBIN
typed_extension_protocol_options:
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
explicit_http_config:
http2_protocol_options: {}
load_assignment:
cluster_name: openai
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: api.openai.com
port_value: 443
hostname: "api.openai.com"
transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
sni: api.openai.com
common_tls_context:
tls_params:
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
- name: mistral
connect_timeout: 5s
dns_lookup_family: V4_ONLY
type: LOGICAL_DNS
lb_policy: ROUND_ROBIN
typed_extension_protocol_options:
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
explicit_http_config:
http2_protocol_options: {}
load_assignment:
cluster_name: mistral
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: api.mistral.ai
port_value: 443
hostname: "api.mistral.ai"
transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
sni: api.mistral.ai
- name: model_server
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: model_server
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: model_server
port_value: 80
hostname: "model_server"
- name: mistral_7b_instruct
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: mistral_7b_instruct
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: mistral_7b_instruct
port_value: 10001
hostname: "mistral_7b_instruct"
- name: arch_fc
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: arch_fc
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: function_resolver
port_value: 80
hostname: "arch_fc"
{% for _, cluster in arch_clusters.items() %}
- name: {{ cluster.name }}
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: {{ cluster.name }}
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: {{ cluster.address }}
port_value: {{ cluster.port }}
hostname: {{ cluster.address }}
{% endfor %}

233
arch/envoy.yaml Normal file
View file

@ -0,0 +1,233 @@
admin:
address:
socket_address: { address: 0.0.0.0, port_value: 9901 }
static_resources:
listeners:
address:
socket_address:
address: 0.0.0.0
port_value: 10000
filter_chains:
- filters:
- name: envoy.filters.network.http_connection_manager
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
stat_prefix: ingress_http
codec_type: AUTO
scheme_header_transformation:
scheme_to_overwrite: https
route_config:
- name: arch
domains:
- "*"
routes:
- match:
headers:
- name: "x-arch-llm-provider"
string_match:
exact: openai
route:
auto_host_rewrite: true
cluster: openai
timeout: 60s
- match:
headers:
- name: "x-arch-llm-provider"
string_match:
exact: mistral
route:
auto_host_rewrite: true
cluster: mistral
timeout: 60s
- match:
prefix: "/embeddings"
route:
cluster: embeddingserver
http_filters:
- name: envoy.filters.http.wasm
typed_config:
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
value:
config:
name: "http_config"
configuration:
"@type": "type.googleapis.com/google.protobuf.StringValue"
value: |
default_prompt_endpoint: "127.0.0.1"
load_balancing: "round_robin"
timeout_ms: 5000
embedding_provider:
name: "SentenceTransformer"
model: "all-MiniLM-L6-v2"
llm_providers:
- name: open-ai-gpt-4
api_key: "$OPEN_AI_API_KEY"
model: gpt-4
- name: mistral_7b_instruct
model: mistral-7b-instruct
endpoint: http://mistral_7b_instruct:10001/v1/chat/completions
default: true
prompt_targets:
- type: context_resolver
name: weather_forecast
few_shot_examples:
- what is the weather in New York?
- how is the weather in San Francisco?
- what is the forecast in Seattle?
entities:
- name: city
required: true
- name: days
endpoint:
cluster: weatherhost
path: /weather
system_prompt: |
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
vm_config:
runtime: "envoy.wasm.runtime.v8"
code:
local:
filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm"
- name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
clusters:
# LLM Host
# Embedding Providers
# External LLM Providers
- name: openai
connect_timeout: 5s
type: LOGICAL_DNS
lb_policy: ROUND_ROBIN
typed_extension_protocol_options:
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
explicit_http_config:
http2_protocol_options: {}
load_assignment:
cluster_name: openai
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: api.openai.com
port_value: 443
hostname: "api.openai.com"
transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
sni: api.openai.com
common_tls_context:
tls_params:
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
- name: mistral
connect_timeout: 5s
type: LOGICAL_DNS
lb_policy: ROUND_ROBIN
typed_extension_protocol_options:
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
explicit_http_config:
http2_protocol_options: {}
load_assignment:
cluster_name: mistral
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: api.mistral.ai
port_value: 443
hostname: "api.mistral.ai"
transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
sni: api.mistral.ai
common_tls_context:
tls_params:
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
- name: embeddingserver
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: embeddingserver
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: host.docker.internal
port_value: 8000
hostname: "embeddingserver"
- name: weatherhost
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: weatherhost
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: host.docker.internal
port_value: 8000
hostname: "embeddingserver"
- name: nerhost
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: nerhost
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: host.docker.internal
port_value: 8000
hostname: "embeddingserver"
- name: qdrant
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: qdrant
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: qdrant
port_value: 6333
hostname: "qdrant"
- name: mistral_7b_instruct
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: qdrant
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: mistral_7b_instruct
port_value: 10001
hostname: "mistral_7b_instruct"

View file

@ -0,0 +1,9 @@
apiVersion: 1
datasources:
- name: Prometheus
type: prometheus
url: http://prometheus:9090
isDefault: true
access: proxy
editable: true

16
arch/init_vector_store.sh Normal file
View file

@ -0,0 +1,16 @@
#!/bin/sh
echo 'Deleting prompt_vector_store collection'
curl -X DELETE http://localhost:16333/collections/prompt_vector_store
echo
echo 'Creating prompt_vector_store collection'
curl -X PUT 'http://localhost:16333/collections/prompt_vector_store' \
-H 'Content-Type: application/json' \
--data-raw '{
"vectors": {
"size": 1024,
"distance": "Cosine"
}
}'
echo
echo 'Created prompt_vector_store collection'

37
arch/katanemo-config.yaml Normal file
View file

@ -0,0 +1,37 @@
default_prompt_endpoint: "127.0.0.1"
load_balancing: "round_robin"
timeout_ms: 5000
llm_providers:
- name: "open-ai-gpt-4"
api_key: "$OPEN_AI_API_KEY"
model: gpt-4
prompt_targets:
- type: context_resolver
name: weather_forecast
few_shot_examples:
- what is the weather in New York?
- how is the weather in San Francisco?
- what is the forecast in Chicago?
entities:
- name: city
required: true
- name: days
endpoint:
cluster: weatherhost
path: /weather
system_prompt: |
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
#TODO: add support for adding custom clusters e.g.
# clusters:
# qdrant:
# options:
# - address: "qdrant"
# - address: "weatherhost"
# - port: 6333

View file

@ -0,0 +1,23 @@
global:
scrape_interval: 15s
scrape_timeout: 10s
evaluation_interval: 15s
alerting:
alertmanagers:
- static_configs:
- targets: []
scheme: http
timeout: 10s
api_version: v1
scrape_configs:
- job_name: envoy
honor_timestamps: true
scrape_interval: 15s
scrape_timeout: 10s
metrics_path: /stats
scheme: http
static_configs:
- targets:
- envoy:9901
params:
format: ['prometheus']

11
arch/src/consts.rs Normal file
View file

@ -0,0 +1,11 @@
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli";
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const ARC_FC_CLUSTER: &str = "arch_fc";
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const MODEL_SERVER_NAME: &str = "model_server";
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";

285
arch/src/filter_context.rs Normal file
View file

@ -0,0 +1,285 @@
use crate::consts::{DEFAULT_EMBEDDING_MODEL, MODEL_SERVER_NAME};
use crate::ratelimit;
use crate::stats::{Counter, Gauge, RecordingMetric};
use crate::stream_context::StreamContext;
use log::debug;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::EmbeddingType;
use public_types::configuration::{Configuration, Overrides, PromptGuards, PromptTarget};
use public_types::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use serde_json::to_string;
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::{OnceLock, RwLock};
use std::time::Duration;
#[derive(Copy, Clone, Debug)]
pub struct WasmMetrics {
pub active_http_calls: Gauge,
pub ratelimited_rq: Counter,
}
impl WasmMetrics {
fn new() -> WasmMetrics {
WasmMetrics {
active_http_calls: Gauge::new(String::from("active_http_calls")),
ratelimited_rq: Counter::new(String::from("ratelimited_rq")),
}
}
}
#[derive(Debug)]
struct CallContext {
prompt_target: String,
embedding_type: EmbeddingType,
}
pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
#[derive(Debug)]
pub struct FilterContext {
metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, CallContext>,
config: Option<Configuration>,
overrides: Rc<Option<Overrides>>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
prompt_guards: Rc<Option<PromptGuards>>,
}
pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
static EMBEDDINGS: OnceLock<RwLock<HashMap<String, EmbeddingTypeMap>>> = OnceLock::new();
EMBEDDINGS.get_or_init(|| {
let embeddings: HashMap<String, EmbeddingTypeMap> = HashMap::new();
RwLock::new(embeddings)
})
}
impl FilterContext {
pub fn new() -> FilterContext {
FilterContext {
callouts: HashMap::new(),
config: None,
metrics: Rc::new(WasmMetrics::new()),
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
overrides: Rc::new(None),
prompt_guards: Rc::new(Some(PromptGuards::default())),
}
}
fn process_prompt_targets(&mut self) {
let prompt_targets = match self.prompt_targets.read() {
Ok(prompt_targets) => prompt_targets,
Err(e) => {
panic!("Error reading prompt targets: {:?}", e);
}
};
for values in prompt_targets.iter() {
let prompt_target = &values.1;
// schedule embeddings call for prompt target name
let token_id = self.schedule_embeddings_call(prompt_target.name.clone());
if self
.callouts
.insert(token_id, {
CallContext {
prompt_target: prompt_target.name.clone(),
embedding_type: EmbeddingType::Name,
}
})
.is_some()
{
panic!("duplicate token_id")
}
// schedule embeddings call for prompt target description
let token_id = self.schedule_embeddings_call(prompt_target.description.clone());
if self
.callouts
.insert(token_id, {
CallContext {
prompt_target: prompt_target.name.clone(),
embedding_type: EmbeddingType::Description,
}
})
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
}
fn schedule_embeddings_call(&self, input: String) -> u32 {
let embeddings_input = CreateEmbeddingRequest {
input: Box::new(CreateEmbeddingRequestInput::String(input)),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
let json_data = to_string(&embeddings_input).unwrap();
let token_id = match self.dispatch_http_call(
MODEL_SERVER_NAME,
vec![
(":method", "POST"),
(":path", "/embeddings"),
(":authority", MODEL_SERVER_NAME),
("content-type", "application/json"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(60),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call: {:?}", e);
}
};
token_id
}
fn embedding_response_handler(
&mut self,
body_size: usize,
embedding_type: EmbeddingType,
prompt_target_name: String,
) {
let prompt_targets = self.prompt_targets.read().unwrap();
let prompt_target = prompt_targets.get(&prompt_target_name).unwrap();
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
let mut embedding_response: CreateEmbeddingResponse =
match serde_json::from_slice(&body) {
Ok(response) => response,
Err(e) => {
panic!(
"Error deserializing embedding response. body: {:?}: {:?}",
String::from_utf8(body).unwrap(),
e
);
}
};
let embeddings = embedding_response.data.remove(0).embedding;
log::info!(
"Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
prompt_target.name,
prompt_target.description,
embedding_type
);
embeddings_store().write().unwrap().insert(
prompt_target.name.clone(),
HashMap::from([(embedding_type, embeddings)]),
);
}
} else {
panic!("No body in response");
}
}
}
impl Context for FilterContext {
fn on_http_call_response(
&mut self,
token_id: u32,
_num_headers: usize,
body_size: usize,
_num_trailers: usize,
) {
debug!(
"filter_context: on_http_call_response called with token_id: {:?}",
token_id
);
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
self.embedding_response_handler(
body_size,
callout_data.embedding_type,
callout_data.prompt_target,
)
}
}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool {
if let Some(config_bytes) = self.get_plugin_configuration() {
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
if let Some(overrides_config) = self
.config
.as_mut()
.and_then(|config| config.overrides.as_mut())
{
self.overrides = Rc::new(Some(std::mem::take(overrides_config)));
}
for pt in self.config.clone().unwrap().prompt_targets {
self.prompt_targets
.write()
.unwrap()
.insert(pt.name.clone(), pt.clone());
}
debug!("set configuration object");
if let Some(ratelimits_config) = self
.config
.as_mut()
.and_then(|config| config.ratelimits.as_mut())
{
ratelimit::ratelimits(Some(std::mem::take(ratelimits_config)));
}
if let Some(prompt_guards) = self
.config
.as_mut()
.and_then(|config| config.prompt_guards.as_mut())
{
self.prompt_guards = Rc::new(Some(std::mem::take(prompt_guards)));
}
}
true
}
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
debug!(
"||| create_http_context called with context_id: {:?} |||",
context_id
);
Some(Box::new(StreamContext::new(
context_id,
Rc::clone(&self.metrics),
Rc::clone(&self.prompt_targets),
Rc::clone(&self.prompt_guards),
Rc::clone(&self.overrides),
)))
}
fn get_type(&self) -> Option<ContextType> {
Some(ContextType::HttpContext)
}
fn on_vm_start(&mut self, _: usize) -> bool {
self.set_tick_period(Duration::from_secs(1));
true
}
fn on_tick(&mut self) {
self.process_prompt_targets();
self.set_tick_period(Duration::from_secs(0));
}
}

19
arch/src/lib.rs Normal file
View file

@ -0,0 +1,19 @@
use filter_context::FilterContext;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod consts;
mod filter_context;
mod llm_providers;
mod ratelimit;
mod routing;
mod stats;
mod stream_context;
mod tokenizer;
proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(FilterContext::new())
});
}}

47
arch/src/llm_providers.rs Normal file
View file

@ -0,0 +1,47 @@
#[non_exhaustive]
pub struct LlmProviders;
impl LlmProviders {
pub const OPENAI_PROVIDER: LlmProvider<'static> = LlmProvider {
name: "openai",
api_key_header: "x-arch-openai-api-key",
model: "gpt-3.5-turbo",
};
pub const MISTRAL_PROVIDER: LlmProvider<'static> = LlmProvider {
name: "mistral",
api_key_header: "x-arch-mistral-api-key",
model: "mistral-large-latest",
};
pub const VARIANTS: &'static [LlmProvider<'static>] =
&[Self::OPENAI_PROVIDER, Self::MISTRAL_PROVIDER];
}
pub struct LlmProvider<'prov> {
name: &'prov str,
api_key_header: &'prov str,
model: &'prov str,
}
impl AsRef<str> for LlmProvider<'_> {
fn as_ref(&self) -> &str {
self.name
}
}
impl std::fmt::Display for LlmProvider<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.name)
}
}
impl LlmProvider<'_> {
pub fn api_key_header(&self) -> &str {
self.api_key_header
}
pub fn choose_model(&self) -> &str {
// In the future this can be a more complex function balancing reliability, cost, performance, etc.
self.model
}
}

426
arch/src/ratelimit.rs Normal file
View file

@ -0,0 +1,426 @@
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
use log::debug;
use public_types::configuration;
use public_types::configuration::{Limit, Ratelimit, TimeUnit};
use std::num::{NonZero, NonZeroU32};
use std::sync::RwLock;
use std::{collections::HashMap, sync::OnceLock};
pub type RatelimitData = RwLock<RatelimitMap>;
pub fn ratelimits(ratelimits_config: Option<Vec<Ratelimit>>) -> &'static RatelimitData {
static RATELIMIT_DATA: OnceLock<RatelimitData> = OnceLock::new();
RATELIMIT_DATA.get_or_init(|| {
RwLock::new(RatelimitMap::new(
ratelimits_config.expect("The initialization call has to have passed a config"),
))
})
}
// The Data Structure is laid out in the following way:
// Provider -> Hash { Header -> Limit }.
// If the Header used to configure the given Limit:
// a) Has None value, then there will be N Limit keyed by the Header value.
// b) Has Some() value, then there will be 1 Limit keyed by the empty string.
// It would have been nicer to use a non-keyed limit for b). However, the type system made that option a nightmare.
pub struct RatelimitMap {
datastore: HashMap<String, HashMap<configuration::Header, DefaultKeyedRateLimiter<String>>>,
}
// This version of Header demands that the user passes a header value to match on.
#[allow(unused)]
#[derive(Debug)]
pub struct Header {
pub key: String,
pub value: String,
}
impl From<Header> for configuration::Header {
fn from(header: Header) -> Self {
Self {
key: header.key,
value: Some(header.value),
}
}
}
impl RatelimitMap {
// n.b new is private so that the only access to the Ratelimits can be done via the static
// reference inside a RwLock via ratelimit::ratelimits().
fn new(ratelimits_config: Vec<Ratelimit>) -> Self {
let mut new_ratelimit_map = RatelimitMap {
datastore: HashMap::new(),
};
for ratelimit_config in ratelimits_config {
let limit = DefaultKeyedRateLimiter::keyed(get_quota(ratelimit_config.limit));
match new_ratelimit_map
.datastore
.get_mut(&ratelimit_config.provider)
{
Some(limits) => match limits.get_mut(&ratelimit_config.selector) {
Some(_) => {
panic!("repeated selector. Selectors per provider must be unique")
}
None => {
limits.insert(ratelimit_config.selector, limit);
}
},
None => {
// The provider has not been seen before.
// Insert the provider and a new HashMap with the specified limit
let new_hash_map = HashMap::from([(ratelimit_config.selector, limit)]);
new_ratelimit_map
.datastore
.insert(ratelimit_config.provider, new_hash_map);
}
}
}
new_ratelimit_map
}
#[allow(unused)]
pub fn check_limit(
&self,
provider: String,
selector: Header,
tokens_used: NonZeroU32,
) -> Result<(), String> {
debug!(
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
provider, selector, tokens_used
);
let provider_limits = match self.datastore.get(&provider) {
None => {
// No limit configured for this provider, hence ok.
return Ok(());
}
Some(limit) => limit,
};
let mut config_selector = configuration::Header::from(selector);
let (limit, limit_key) = match provider_limits.get(&config_selector) {
// This is a specific limit, i.e one that was configured with both key, and value.
// Therefore, the key for the internal limit does not matter, and hence the empty string is always returned.
Some(limit) => (limit, String::from("")),
None => {
// Unwrap is ok here because we _know_ the value exists.
let header_key = config_selector.value.take().unwrap();
// Search for less specific limit, i.e, one that was configured without a value, therefore every Header
// value has its own key in the internal limit.
match provider_limits.get(&config_selector) {
Some(limit) => (limit, header_key),
// No limit for that header key, value pair exists within that provider limits.
None => {
return Ok(());
}
}
}
};
match limit.check_key_n(&limit_key, tokens_used) {
Ok(Ok(())) => Ok(()),
Ok(Err(_)) => Err(String::from("Not allowed")),
Err(InsufficientCapacity(_)) => Err(String::from("Not allowed")),
}
}
}
fn get_quota(limit: Limit) -> Quota {
let tokens = NonZero::new(limit.tokens).expect("Limit's tokens must be positive");
match limit.unit {
TimeUnit::Second => Quota::per_second(tokens),
TimeUnit::Minute => Quota::per_minute(tokens),
TimeUnit::Hour => Quota::per_hour(tokens),
}
}
// The following tests are inside the ratelimit module in order to access RatelimitMap::new() in order to provide
// different configuration values per test.
#[test]
fn non_existent_provider_is_ok() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
},
limit: Limit {
tokens: 100,
unit: TimeUnit::Minute,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("non-existent-provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(5000).unwrap(),
)
.is_ok())
}
#[test]
fn non_existent_key_is_ok() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
},
limit: Limit {
tokens: 100,
unit: TimeUnit::Minute,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(5000).unwrap(),
)
.is_ok())
}
#[test]
fn specific_limit_does_not_catch_non_specific_value() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 200,
unit: TimeUnit::Second,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("key"),
value: String::from("not-the-correct-value"),
},
NonZero::new(5000).unwrap(),
)
.is_ok())
}
#[test]
fn specific_limit_is_hit() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 200,
unit: TimeUnit::Hour,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(5000).unwrap(),
)
.is_err())
}
#[test]
fn non_specific_key_has_different_limits_for_different_values() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
},
limit: Limit {
tokens: 100,
unit: TimeUnit::Hour,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
// Value1 takes 50.
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("only-key"),
value: String::from("value1"),
},
NonZero::new(50).unwrap(),
)
.is_ok());
// value2 takes 60 because it has its own 100 limit
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("only-key"),
value: String::from("value2"),
},
NonZero::new(60).unwrap(),
)
.is_ok());
// However value1 cannot take more than 100 per hour which 50+70 = 120
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("only-key"),
value: String::from("value1"),
},
NonZero::new(70).unwrap(),
)
.is_err())
}
#[test]
fn different_provider_can_have_different_limits_with_the_same_keys() {
let ratelimits_config = vec![
Ratelimit {
provider: String::from("first_provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 100,
unit: TimeUnit::Hour,
},
},
Ratelimit {
provider: String::from("second_provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 200,
unit: TimeUnit::Hour,
},
},
];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("first_provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(100).unwrap(),
)
.is_ok());
assert!(ratelimits
.check_limit(
String::from("second_provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(200).unwrap(),
)
.is_ok());
assert!(ratelimits
.check_limit(
String::from("first_provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(1).unwrap(),
)
.is_err());
assert!(ratelimits
.check_limit(
String::from("second_provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(1).unwrap(),
)
.is_err());
}
// These tests use the publicly exposed static singleton, thus the same configuration is used in every test.
// If more tests are written here, move the initial call out of the test.
#[cfg(test)]
mod test {
use super::ratelimits;
use configuration::{Limit, Ratelimit, TimeUnit};
use public_types::configuration;
use std::num::NonZero;
use std::thread;
#[test]
fn different_threads_have_same_ratelimit_data_structure() {
let ratelimits_config = Some(vec![Ratelimit {
provider: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 200,
unit: TimeUnit::Hour,
},
}]);
// Initialize in the main thread.
ratelimits(ratelimits_config);
// Use the singleton in a different thread.
thread::spawn(|| {
let ratelimits = ratelimits(None);
assert!(ratelimits
.read()
.unwrap()
.check_limit(
String::from("provider"),
super::Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(5000).unwrap(),
)
.is_err())
});
}
}

13
arch/src/routing.rs Normal file
View file

@ -0,0 +1,13 @@
use crate::llm_providers::{LlmProvider, LlmProviders};
use rand::{seq::SliceRandom, thread_rng};
pub fn get_llm_provider<'hostname>(deterministic: bool) -> &'static LlmProvider<'hostname> {
if deterministic {
&LlmProviders::OPENAI_PROVIDER
} else {
let mut rng = thread_rng();
LlmProviders::VARIANTS
.choose(&mut rng)
.expect("There should always be at least one llm provider")
}
}

102
arch/src/stats.rs Normal file
View file

@ -0,0 +1,102 @@
use log::error;
use proxy_wasm::hostcalls;
use proxy_wasm::types::*;
#[allow(unused)]
pub trait Metric {
fn id(&self) -> u32;
fn value(&self) -> Result<u64, String> {
match hostcalls::get_metric(self.id()) {
Ok(value) => Ok(value),
Err(Status::NotFound) => Err(format!("metric not found: {}", self.id())),
Err(err) => Err(format!("unexpected status: {:?}", err)),
}
}
}
#[allow(unused)]
pub trait IncrementingMetric: Metric {
fn increment(&self, offset: i64) {
match hostcalls::increment_metric(self.id(), offset) {
Ok(_) => (),
Err(err) => error!("error incrementing metric: {:?}", err),
}
}
}
pub trait RecordingMetric: Metric {
fn record(&self, value: u64) {
match hostcalls::record_metric(self.id(), value) {
Ok(_) => (),
Err(err) => error!("error recording metric: {:?}", err),
}
}
}
#[derive(Copy, Clone, Debug)]
pub struct Counter {
id: u32,
}
#[allow(unused)]
impl Counter {
pub fn new(name: String) -> Counter {
let returned_id = hostcalls::define_metric(MetricType::Counter, &name)
.expect("failed to define counter '{}', name");
Counter { id: returned_id }
}
}
impl Metric for Counter {
fn id(&self) -> u32 {
self.id
}
}
impl IncrementingMetric for Counter {}
#[derive(Copy, Clone, Debug)]
pub struct Gauge {
id: u32,
}
impl Gauge {
pub fn new(name: String) -> Gauge {
let returned_id = hostcalls::define_metric(MetricType::Gauge, &name)
.expect("failed to define gauge '{}', name");
Gauge { id: returned_id }
}
}
impl Metric for Gauge {
fn id(&self) -> u32 {
self.id
}
}
/// For state of the world updates
impl RecordingMetric for Gauge {}
/// For offset deltas
impl IncrementingMetric for Gauge {}
#[derive(Copy, Clone)]
pub struct Histogram {
id: u32,
}
#[allow(unused)]
impl Histogram {
pub fn new(name: String) -> Histogram {
let returned_id = hostcalls::define_metric(MetricType::Histogram, &name)
.expect("failed to define histogram '{}', name");
Histogram { id: returned_id }
}
}
impl Metric for Histogram {
fn id(&self) -> u32 {
self.id
}
}
impl RecordingMetric for Histogram {}

1098
arch/src/stream_context.rs Normal file

File diff suppressed because it is too large Load diff

39
arch/src/tokenizer.rs Normal file
View file

@ -0,0 +1,39 @@
use log::debug;
#[derive(Debug, PartialEq, Eq)]
#[allow(dead_code)]
pub enum Error {
UnknownModel,
FailedToTokenize,
}
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
debug!("getting token count model={}", model_name);
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel)?;
Ok(bpe.encode_ordinary(text).len())
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn encode_ordinary() {
let model_name = "gpt-3.5-turbo";
let text = "How many tokens does this sentence have?";
assert_eq!(
8,
token_count(model_name, text).expect("correct tokenization")
);
}
#[test]
fn unrecognized_model() {
assert_eq!(
Error::UnknownModel,
token_count("unknown", "").expect_err("unknown model")
)
}
}

582
arch/tests/integration.rs Normal file
View file

@ -0,0 +1,582 @@
use http::StatusCode;
use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
};
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
use public_types::embeddings::embedding::Object;
use public_types::embeddings::{
create_embedding_response, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding,
};
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
use serde_yaml::Value;
use serial_test::serial;
use std::collections::HashMap;
use std::path::Path;
fn wasm_module() -> String {
let wasm_file = Path::new("target/wasm32-wasi/release/intelligent_prompt_gateway.wasm");
assert!(
wasm_file.exists(),
"Run `cargo build --release --target=wasm32-wasi` first"
);
wasm_file.to_str().unwrap().to_string()
}
fn request_headers_expectations(module: &mut Tester, http_context: i32) {
module
.call_proxy_on_request_headers(http_context, 0, false)
.expect_get_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-deterministic-provider"),
)
.returning(Some("true"))
.expect_add_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider"),
Some("openai"),
)
.expect_get_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-openai-api-key"),
)
.returning(Some("api-key"))
.expect_replace_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("Authorization"),
Some("Bearer api-key"),
)
.expect_remove_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-openai-api-key"),
)
.expect_remove_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-mistral-api-key"),
)
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
.expect_get_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-ratelimit-selector"),
)
.returning(Some("selector-key"))
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("selector-key"))
.returning(Some("selector-value"))
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
request_headers_expectations(module, http_context);
// Request Body
let chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
],\
\"model\": \"gpt-4\"\
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_http_call(Some("model_server"), None, None, None, None)
.returning(Some(1))
.expect_log(Some(LogLevel::Debug), None)
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Info), None)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
let embedding_response = CreateEmbeddingResponse {
data: vec![Embedding {
index: 0,
embedding: vec![],
object: Object::default(),
}],
model: String::from("test"),
object: create_embedding_response::Object::default(),
usage: Box::new(CreateEmbeddingResponseUsage::new(0, 0)),
};
let embeddings_response_buffer = serde_json::to_string(&embedding_response).unwrap();
module
.call_proxy_on_http_call_response(
http_context,
1,
0,
embeddings_response_buffer.len() as i32,
0,
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embeddings_response_buffer))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("model_server"), None, None, None, None)
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
let zero_shot_response = ZeroShotClassificationResponse {
predicted_class: "weather_forecast".to_string(),
predicted_class_score: 0.1,
scores: HashMap::new(),
model: "test-model".to_string(),
};
let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap();
module
.call_proxy_on_http_call_response(
http_context,
2,
0,
zeroshot_intent_detection_buffer.len() as i32,
0,
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&zeroshot_intent_detection_buffer))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_http_call(Some("arch_fc"), None, None, None, None)
.returning(Some(3))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
}
fn default_config() -> Configuration {
let config: &str = r#"
default_prompt_endpoint: "127.0.0.1"
load_balancing: "round_robin"
timeout_ms: 5000
llm_providers:
- name: "open-ai-gpt-4"
api_key: "$OPEN_AI_API_KEY"
model: gpt-4
system_prompt: |
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
prompt_targets:
- type: function_resolver
name: weather_forecast
description: This resolver provides weather forecast information.
endpoint:
cluster: weatherhost
path: /weather
parameters:
- name: city
required: true
description: The city for which the weather forecast is requested.
- name: days
description: The number of days for which the weather forecast is requested.
- name: units
description: The units in which the weather forecast is requested.
- type: function_resolver
name: weather_forecast_2
description: This resolver provides weather forecast information.
endpoint:
cluster: weatherhost
path: /weather
entities:
- name: city
ratelimits:
- provider: gpt-3.5-turbo
selector:
key: selector-key
value: selector-value
limit:
tokens: 1
unit: minute
"#;
serde_yaml::from_str(config).unwrap()
}
#[test]
#[serial]
fn successful_request_to_open_ai_chat_completions() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let root_context = 1;
module
.call_proxy_on_context_create(root_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup HTTP Stream
let http_context = 2;
module
.call_proxy_on_context_create(http_context, root_context)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
request_headers_expectations(&mut module, http_context);
// Request Body
let chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
],\
\"model\": \"gpt-4\"\
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_http_call(Some("model_server"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
#[test]
#[serial]
fn bad_request_to_open_ai_chat_completions() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let root_context = 1;
module
.call_proxy_on_context_create(root_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup HTTP Stream
let http_context = 2;
module
.call_proxy_on_context_create(http_context, root_context)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
request_headers_expectations(&mut module, http_context);
// Request Body
let incomplete_chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
]\
}";
module
.call_proxy_on_request_body(
http_context,
incomplete_chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(incomplete_chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response(
Some(StatusCode::BAD_REQUEST.as_u16().into()),
None,
None,
None,
)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
#[test]
#[serial]
fn request_ratelimited() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = 1;
let config = serde_json::to_string(&default_config()).unwrap();
module
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(&config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
let arch_fc_resp = ChatCompletionsResponse {
usage: Usage {
completion_tokens: 0,
},
choices: vec![Choice {
finish_reason: "test".to_string(),
index: 0,
message: Message {
role: "system".to_string(),
content: None,
tool_calls: Some(vec![ToolCall {
id: String::from("test"),
tool_type: ToolType::Function,
function: FunctionCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
Value::String(String::from("seattle")),
)]),
},
}]),
model: None,
},
}],
model: String::from("test"),
};
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
module
.call_proxy_on_http_call_response(http_context, 3, 0, arch_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("weatherhost"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 4, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response(
Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()),
None,
None,
None,
)
.expect_metric_increment("ratelimited_rq", 1)
.expect_log(
Some(LogLevel::Debug),
Some("server error occurred: Exceeded Ratelimit: Not allowed"),
)
.execute_and_expect(ReturnType::None)
.unwrap();
}
#[test]
#[serial]
fn request_not_ratelimited() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = 1;
let mut config = default_config();
config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
let config_str = serde_json::to_string(&config).unwrap();
module
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_configure(filter_context, config_str.len() as i32)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(&config_str))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
let arch_fc_resp = ChatCompletionsResponse {
usage: Usage {
completion_tokens: 0,
},
choices: vec![Choice {
finish_reason: "test".to_string(),
index: 0,
message: Message {
role: "system".to_string(),
content: None,
tool_calls: Some(vec![ToolCall {
id: String::from("test"),
tool_type: ToolType::Function,
function: FunctionCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
Value::String(String::from("seattle")),
)]),
},
}]),
model: None,
},
}],
model: String::from("test"),
};
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
module
.call_proxy_on_http_call_response(http_context, 3, 0, arch_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("weatherhost"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 4, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::None)
.unwrap();
}