rename public_types => common and move common code there

This commit is contained in:
Adil Hafeez 2024-10-16 10:42:54 -07:00
parent db202578d3
commit 93ea6e1a3d
35 changed files with 458 additions and 791 deletions

View file

@ -217,6 +217,22 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15"
[[package]]
name = "common"
version = "0.1.0"
dependencies = [
"derivative",
"duration-string",
"governor",
"log",
"proxy-wasm",
"rand",
"serde",
"serde_yaml",
"thiserror",
"tiktoken-rs",
]
[[package]]
name = "cpp_demangle"
version = "0.4.4"
@ -842,6 +858,7 @@ name = "llm_gateway"
version = "0.1.0"
dependencies = [
"acap",
"common",
"derivative",
"governor",
"http",
@ -849,7 +866,6 @@ dependencies = [
"md5",
"proxy-wasm",
"proxy-wasm-test-framework",
"public_types",
"rand",
"serde",
"serde_json",
@ -857,7 +873,6 @@ dependencies = [
"serial_test",
"sha2",
"thiserror",
"tiktoken-rs",
]
[[package]]
@ -1094,20 +1109,6 @@ dependencies = [
"cc",
]
[[package]]
name = "public_types"
version = "0.1.0"
dependencies = [
"derivative",
"duration-string",
"governor",
"log",
"proxy-wasm",
"serde",
"serde_yaml",
"thiserror",
]
[[package]]
name = "quote"
version = "1.0.37"
@ -1202,9 +1203,9 @@ dependencies = [
[[package]]
name = "regex"
version = "1.10.6"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8"
dependencies = [
"aho-corasick",
"memchr",
@ -1214,9 +1215,9 @@ dependencies = [
[[package]]
name = "regex-automata"
version = "0.4.7"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
dependencies = [
"aho-corasick",
"memchr",
@ -1225,9 +1226,9 @@ dependencies = [
[[package]]
name = "regex-syntax"
version = "0.8.4"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "rustc-demangle"

View file

@ -14,10 +14,9 @@ serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9.34"
serde_json = "1.0"
md5 = "0.7.0"
public_types = { path = "../public_types" }
common = { path = "../common" }
http = "1.1.0"
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
tiktoken-rs = "0.5.9"
acap = "0.3.0"
rand = "0.8.5"
thiserror = "1.0.64"

View file

@ -1,25 +1,23 @@
use crate::llm_providers::LlmProviders;
use crate::ratelimit;
use crate::stream_context::StreamContext;
use common::common_types::EmbeddingType;
use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget};
use common::consts::ARCH_INTERNAL_CLUSTER_NAME;
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use common::consts::DEFAULT_EMBEDDING_MODEL;
use common::consts::MODEL_SERVER_NAME;
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use common::http::CallArgs;
use common::http::Client;
use common::llm_providers::LlmProviders;
use common::ratelimit;
use common::stats::Counter;
use common::stats::Gauge;
use common::stats::IncrementingMetric;
use log::debug;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::EmbeddingType;
use public_types::configuration::{
Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget,
};
use public_types::consts::ARCH_INTERNAL_CLUSTER_NAME;
use public_types::consts::ARCH_UPSTREAM_HOST_HEADER;
use public_types::consts::DEFAULT_EMBEDDING_MODEL;
use public_types::consts::MODEL_SERVER_NAME;
use public_types::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use public_types::http::CallArgs;
use public_types::http::Client;
use public_types::stats::Counter;
use public_types::stats::Gauge;
use public_types::stats::IncrementingMetric;
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::collections::HashMap;

View file

@ -3,11 +3,7 @@ use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod filter_context;
mod llm_providers;
mod ratelimit;
mod routing;
mod stream_context;
mod tokenizer;
proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);

View file

@ -1,69 +0,0 @@
use public_types::configuration::LlmProvider;
use std::collections::HashMap;
use std::rc::Rc;
#[derive(Debug)]
pub struct LlmProviders {
providers: HashMap<String, Rc<LlmProvider>>,
default: Option<Rc<LlmProvider>>,
}
impl LlmProviders {
pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Rc<LlmProvider>> {
self.providers.iter()
}
pub fn default(&self) -> Option<Rc<LlmProvider>> {
self.default.as_ref().map(|rc| rc.clone())
}
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
self.providers.get(name).cloned()
}
}
#[derive(thiserror::Error, Debug)]
pub enum LlmProvidersNewError {
#[error("There must be at least one LLM Provider")]
EmptySource,
#[error("There must be at most one default LLM Provider")]
MoreThanOneDefault,
#[error("\'{0}\' is not a unique name")]
DuplicateName(String),
}
impl TryFrom<Vec<LlmProvider>> for LlmProviders {
type Error = LlmProvidersNewError;
fn try_from(llm_providers_config: Vec<LlmProvider>) -> Result<Self, Self::Error> {
if llm_providers_config.is_empty() {
return Err(LlmProvidersNewError::EmptySource);
}
let mut llm_providers = LlmProviders {
providers: HashMap::new(),
default: None,
};
for llm_provider in llm_providers_config {
let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
if llm_provider.default.unwrap_or_default() {
match llm_providers.default {
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
None => llm_providers.default = Some(Rc::clone(&llm_provider)),
}
}
// Insert and check that there is no other provider with the same name.
let name = llm_provider.name.clone();
if llm_providers
.providers
.insert(name.clone(), llm_provider)
.is_some()
{
return Err(LlmProvidersNewError::DuplicateName(name));
}
}
Ok(llm_providers)
}
}

View file

@ -1,450 +0,0 @@
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
use log::debug;
use public_types::configuration;
use public_types::configuration::{Limit, Ratelimit, TimeUnit};
use std::fmt::Display;
use std::num::{NonZero, NonZeroU32};
use std::sync::RwLock;
use std::{collections::HashMap, sync::OnceLock};
pub type RatelimitData = RwLock<RatelimitMap>;
pub fn ratelimits(ratelimits_config: Option<Vec<Ratelimit>>) -> &'static RatelimitData {
static RATELIMIT_DATA: OnceLock<RatelimitData> = OnceLock::new();
RATELIMIT_DATA.get_or_init(|| {
RwLock::new(RatelimitMap::new(
ratelimits_config.expect("The initialization call has to have passed a config"),
))
})
}
// The Data Structure is laid out in the following way:
// Provider -> Hash { Header -> Limit }.
// If the Header used to configure the given Limit:
// a) Has None value, then there will be N Limit keyed by the Header value.
// b) Has Some() value, then there will be 1 Limit keyed by the empty string.
// It would have been nicer to use a non-keyed limit for b). However, the type system made that option a nightmare.
pub struct RatelimitMap {
datastore: HashMap<String, HashMap<configuration::Header, DefaultKeyedRateLimiter<String>>>,
}
// This version of Header demands that the user passes a header value to match on.
#[derive(Debug, Clone)]
pub struct Header {
pub key: String,
pub value: String,
}
impl Display for Header {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
impl From<Header> for configuration::Header {
fn from(header: Header) -> Self {
Self {
key: header.key,
value: Some(header.value),
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("exceeded limit provider={provider}, selector={selector}, tokens_used={tokens_used}")]
ExceededLimit {
provider: String,
selector: Header,
tokens_used: NonZeroU32,
},
}
impl RatelimitMap {
// n.b new is private so that the only access to the Ratelimits can be done via the static
// reference inside a RwLock via ratelimit::ratelimits().
fn new(ratelimits_config: Vec<Ratelimit>) -> Self {
let mut new_ratelimit_map = RatelimitMap {
datastore: HashMap::new(),
};
for ratelimit_config in ratelimits_config {
let limit = DefaultKeyedRateLimiter::keyed(get_quota(ratelimit_config.limit));
match new_ratelimit_map.datastore.get_mut(&ratelimit_config.model) {
Some(limits) => match limits.get_mut(&ratelimit_config.selector) {
Some(_) => {
panic!("repeated selector. Selectors per provider must be unique")
}
None => {
limits.insert(ratelimit_config.selector, limit);
}
},
None => {
// The provider has not been seen before.
// Insert the provider and a new HashMap with the specified limit
let new_hash_map = HashMap::from([(ratelimit_config.selector, limit)]);
new_ratelimit_map
.datastore
.insert(ratelimit_config.model, new_hash_map);
}
}
}
new_ratelimit_map
}
#[allow(unused)]
pub fn check_limit(
&self,
provider: String,
selector: Header,
tokens_used: NonZeroU32,
) -> Result<(), Error> {
debug!(
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
provider, selector, tokens_used
);
let provider_limits = match self.datastore.get(&provider) {
None => {
// No limit configured for this provider, hence ok.
return Ok(());
}
Some(limit) => limit,
};
let mut config_selector = configuration::Header::from(selector.clone());
let (limit, limit_key) = match provider_limits.get(&config_selector) {
// This is a specific limit, i.e one that was configured with both key, and value.
// Therefore, the key for the internal limit does not matter, and hence the empty string is always returned.
Some(limit) => (limit, String::from("")),
None => {
// Unwrap is ok here because we _know_ the value exists.
let header_key = config_selector.value.take().unwrap();
// Search for less specific limit, i.e, one that was configured without a value, therefore every Header
// value has its own key in the internal limit.
match provider_limits.get(&config_selector) {
Some(limit) => (limit, header_key),
// No limit for that header key, value pair exists within that provider limits.
None => {
return Ok(());
}
}
}
};
match limit.check_key_n(&limit_key, tokens_used) {
Ok(Ok(())) => Ok(()),
Ok(Err(_)) | Err(InsufficientCapacity(_)) => Err(Error::ExceededLimit {
provider,
selector,
tokens_used,
}),
}
}
}
fn get_quota(limit: Limit) -> Quota {
let tokens = NonZero::new(limit.tokens).expect("Limit's tokens must be positive");
match limit.unit {
TimeUnit::Second => Quota::per_second(tokens),
TimeUnit::Minute => Quota::per_minute(tokens),
TimeUnit::Hour => Quota::per_hour(tokens),
}
}
// The following tests are inside the ratelimit module in order to access RatelimitMap::new() in order to provide
// different configuration values per test.
#[test]
fn non_existent_provider_is_ok() {
let ratelimits_config = vec![Ratelimit {
model: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
},
limit: Limit {
tokens: 100,
unit: TimeUnit::Minute,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("non-existent-provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(5000).unwrap(),
)
.is_ok())
}
#[test]
fn non_existent_key_is_ok() {
let ratelimits_config = vec![Ratelimit {
model: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
},
limit: Limit {
tokens: 100,
unit: TimeUnit::Minute,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(5000).unwrap(),
)
.is_ok())
}
#[test]
fn specific_limit_does_not_catch_non_specific_value() {
let ratelimits_config = vec![Ratelimit {
model: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 200,
unit: TimeUnit::Second,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("key"),
value: String::from("not-the-correct-value"),
},
NonZero::new(5000).unwrap(),
)
.is_ok())
}
#[test]
fn specific_limit_is_hit() {
let ratelimits_config = vec![Ratelimit {
model: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 200,
unit: TimeUnit::Hour,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(5000).unwrap(),
)
.is_err())
}
#[test]
fn non_specific_key_has_different_limits_for_different_values() {
let ratelimits_config = vec![Ratelimit {
model: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
},
limit: Limit {
tokens: 100,
unit: TimeUnit::Hour,
},
}];
let ratelimits = RatelimitMap::new(ratelimits_config);
// Value1 takes 50.
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("only-key"),
value: String::from("value1"),
},
NonZero::new(50).unwrap(),
)
.is_ok());
// value2 takes 60 because it has its own 100 limit
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("only-key"),
value: String::from("value2"),
},
NonZero::new(60).unwrap(),
)
.is_ok());
// However value1 cannot take more than 100 per hour which 50+70 = 120
assert!(ratelimits
.check_limit(
String::from("provider"),
Header {
key: String::from("only-key"),
value: String::from("value1"),
},
NonZero::new(70).unwrap(),
)
.is_err())
}
#[test]
fn different_provider_can_have_different_limits_with_the_same_keys() {
let ratelimits_config = vec![
Ratelimit {
model: String::from("first_provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 100,
unit: TimeUnit::Hour,
},
},
Ratelimit {
model: String::from("second_provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 200,
unit: TimeUnit::Hour,
},
},
];
let ratelimits = RatelimitMap::new(ratelimits_config);
assert!(ratelimits
.check_limit(
String::from("first_provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(100).unwrap(),
)
.is_ok());
assert!(ratelimits
.check_limit(
String::from("second_provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(200).unwrap(),
)
.is_ok());
assert!(ratelimits
.check_limit(
String::from("first_provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(1).unwrap(),
)
.is_err());
assert!(ratelimits
.check_limit(
String::from("second_provider"),
Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(1).unwrap(),
)
.is_err());
}
// These tests use the publicly exposed static singleton, thus the same configuration is used in every test.
// If more tests are written here, move the initial call out of the test.
#[cfg(test)]
mod test {
use super::ratelimits;
use configuration::{Limit, Ratelimit, TimeUnit};
use public_types::configuration;
use std::num::NonZero;
use std::thread;
#[test]
fn make_ratelimits_optional() {
let ratelimits_config = Vec::new();
// Initialize in the main thread.
ratelimits(Some(ratelimits_config));
}
#[test]
fn different_threads_have_same_ratelimit_data_structure() {
let ratelimits_config = Some(vec![Ratelimit {
model: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
},
limit: Limit {
tokens: 200,
unit: TimeUnit::Hour,
},
}]);
// Initialize in the main thread.
ratelimits(ratelimits_config);
// Use the singleton in a different thread.
thread::spawn(|| {
let ratelimits = ratelimits(None);
assert!(ratelimits
.read()
.unwrap()
.check_limit(
String::from("provider"),
super::Header {
key: String::from("key"),
value: String::from("value"),
},
NonZero::new(5000).unwrap(),
)
.is_err())
});
}
}

View file

@ -1,50 +0,0 @@
use std::rc::Rc;
use crate::llm_providers::LlmProviders;
use log::debug;
use public_types::configuration::LlmProvider;
use rand::{seq::IteratorRandom, thread_rng};
#[derive(Debug)]
pub enum ProviderHint {
Default,
Name(String),
}
impl From<String> for ProviderHint {
fn from(value: String) -> Self {
match value.as_str() {
"default" => ProviderHint::Default,
_ => ProviderHint::Name(value),
}
}
}
pub fn get_llm_provider(
llm_providers: &LlmProviders,
provider_hint: Option<ProviderHint>,
) -> Rc<LlmProvider> {
let maybe_provider = provider_hint.and_then(|hint| match hint {
ProviderHint::Default => llm_providers.default(),
// FIXME: should a non-existent name in the hint be more explicit? i.e, return a BAD_REQUEST?
ProviderHint::Name(name) => llm_providers.get(&name),
});
if let Some(provider) = maybe_provider {
return provider;
}
if llm_providers.default().is_some() {
debug!("no llm provider found for hint, using default llm provider");
return llm_providers.default().unwrap();
}
debug!("no default llm found, using random llm provider");
let mut rng = thread_rng();
llm_providers
.iter()
.choose(&mut rng)
.expect("There should always be at least one llm provider")
.1
.clone()
}

View file

@ -1,25 +1,18 @@
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use crate::llm_providers::LlmProviders;
use crate::ratelimit::Header;
use crate::{ratelimit, routing, tokenizer};
use acap::cos;
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::open_ai::{
use common::common_types::open_ai::{
ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
ChatCompletionsResponse, Choice, FunctionDefinition, FunctionParameter, FunctionParameters,
Message, ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType,
};
use public_types::common_types::{
use common::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest,
ZeroShotClassificationResponse,
};
use public_types::configuration::{GatewayMode, LlmProvider};
use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
use public_types::consts::{
use common::configuration::{GatewayMode, LlmProvider};
use common::configuration::{Overrides, PromptGuards, PromptTarget};
use common::consts::{
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_PROVIDER_HINT_HEADER,
ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER,
@ -27,11 +20,18 @@ use public_types::consts::{
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE,
};
use public_types::embeddings::{
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use public_types::http::{CallArgs, Client, ClientError};
use public_types::stats::Gauge;
use common::http::{CallArgs, Client, ClientError};
use common::llm_providers::LlmProviders;
use common::ratelimit::Header;
use common::stats::Gauge;
use common::{ratelimit, routing, tokenizer};
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use serde_json::Value;
use sha2::{Digest, Sha256};
use std::cell::RefCell;
@ -40,7 +40,7 @@ use std::num::NonZero;
use std::rc::Rc;
use std::time::Duration;
use public_types::stats::IncrementingMetric;
use common::stats::IncrementingMetric;
#[derive(Debug, Clone)]
enum ResponseHandlerType {
@ -1280,7 +1280,7 @@ impl HttpContext for StreamContext {
let prompt_guard_jailbreak_task = self
.prompt_guards
.input_guards
.contains_key(&public_types::configuration::GuardType::Jailbreak);
.contains_key(&common::configuration::GuardType::Jailbreak);
self.chat_completions_request = Some(deserialized_body);

View file

@ -1,39 +0,0 @@
use log::debug;
#[derive(Debug, PartialEq, Eq)]
#[allow(dead_code)]
pub enum Error {
UnknownModel,
FailedToTokenize,
}
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
debug!("getting token count model={}", model_name);
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel)?;
Ok(bpe.encode_ordinary(text).len())
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn encode_ordinary() {
let model_name = "gpt-3.5-turbo";
let text = "How many tokens does this sentence have?";
assert_eq!(
8,
token_count(model_name, text).expect("correct tokenization")
);
}
#[test]
fn unrecognized_model() {
assert_eq!(
Error::UnknownModel,
token_count("unknown", "").expect_err("unknown model")
)
}
}

View file

@ -1,23 +1,23 @@
use common::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
use common::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
use common::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
use common::embeddings::{
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
Embedding,
};
use common::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
use http::StatusCode;
use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
};
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
use public_types::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
use public_types::embeddings::{
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
Embedding,
};
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
use serde_yaml::Value;
use serial_test::serial;
use std::collections::HashMap;
use std::path::Path;
fn wasm_module() -> String {
let wasm_file = Path::new("target/wasm32-wasi/release/llm_gateway.wasm");
let wasm_file = Path::new("target/wasm32-wasi/release/prompt_gateway.wasm");
assert!(
wasm_file.exists(),
"Run `cargo build --release --target=wasm32-wasi` first"