diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index b0785e1e..9b63840e 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -150,6 +150,11 @@ properties: random_sampling: type: integer additionalProperties: false + mode: + type: string + enum: + - llm + - prompt additionalProperties: false required: - version diff --git a/arch/docker-compose.dev.yaml b/arch/docker-compose.dev.yaml index 33b692bb..36c364bb 100644 --- a/arch/docker-compose.dev.yaml +++ b/arch/docker-compose.dev.yaml @@ -4,6 +4,7 @@ services: ports: - "10000:10000" - "11000:11000" + - "12000:12000" - "19901:9901" volumes: - ${ARCH_CONFIG_FILE:-../demos/function_calling/arch_config.yaml}:/config/arch_config.yaml diff --git a/arch/docker-compose.yaml b/arch/docker-compose.yaml index 582e5a2f..3860fac0 100644 --- a/arch/docker-compose.yaml +++ b/arch/docker-compose.yaml @@ -3,10 +3,12 @@ services: image: archgw:latest ports: - "10000:10000" + - "11000:11000" + - "12000:12000" - "19901:9901" volumes: - ${ARCH_CONFIG_FILE:-./demos/function_calling/arch_confg.yaml}:/config/arch_config.yaml - /etc/ssl/cert.pem:/etc/ssl/cert.pem - - ~/archgw_logs/arch_logs:/var/log/ + - ~/archgw_logs:/var/log/ env_file: - stage.env diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml index 4dba952c..c6bcedba 100644 --- a/arch/envoy.template.yaml +++ b/arch/envoy.template.yaml @@ -37,7 +37,7 @@ static_resources: - name: envoy.access_loggers.file typed_config: "@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog - path: "/var/log/arch_access.log" + path: "/var/log/access_ingress.log" route_config: name: local_routes virtual_hosts: @@ -57,12 +57,22 @@ static_resources: cluster: {{ provider.provider }} timeout: 60s {% endfor %} + - match: + prefix: "/" + headers: + - name: "x-arch-upstream" + string_match: + exact: arch_llm_listener + route: + auto_host_rewrite: true + cluster: arch_llm_listener + timeout: 60s - match: prefix: "/" direct_response: status: 400 body: - inline_string: "x-arch-llm-provider header not set, cannot perform routing\n" + inline_string: "x-arch-llm-provider or x-arch-upstream header not set, cannot perform routing\n" http_filters: - name: envoy.filters.http.wasm typed_config: @@ -71,6 +81,7 @@ static_resources: value: config: name: "http_config" + root_id: prompt_gateway configuration: "@type": "type.googleapis.com/google.protobuf.StringValue" value: | @@ -118,7 +129,7 @@ static_resources: - name: envoy.access_loggers.file typed_config: "@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog - path: "/var/log/arch_access_internal.log" + path: "/var/log/access_internal.log" route_config: name: local_routes virtual_hosts: @@ -162,6 +173,88 @@ static_resources: - name: envoy.filters.http.router typed_config: "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router + + - name: arch_listener_llm + address: + socket_address: + address: 0.0.0.0 + port_value: 12000 + filter_chains: + - filters: + - name: envoy.filters.network.http_connection_manager + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager + {% if arch_tracing.random_sampling > 0 %} + generate_request_id: true + tracing: + provider: + name: envoy.tracers.opentelemetry + typed_config: + "@type": type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig + grpc_service: + envoy_grpc: + cluster_name: opentelemetry_collector + timeout: 0.250s + service_name: arch 
+ random_sampling: + value: {{ arch_tracing.random_sampling }} + {% endif %} + stat_prefix: arch_listener_http + codec_type: AUTO + scheme_header_transformation: + scheme_to_overwrite: https + access_log: + - name: envoy.access_loggers.file + typed_config: + "@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog + path: "/var/log/access_llm.log" + route_config: + name: local_routes + virtual_hosts: + - name: local_service + domains: + - "*" + routes: + {% for provider in arch_llm_providers %} + - match: + prefix: "/" + headers: + - name: "x-arch-llm-provider" + string_match: + exact: {{ provider.name }} + route: + auto_host_rewrite: true + cluster: {{ provider.provider }} + timeout: 60s + {% endfor %} + - match: + prefix: "/" + direct_response: + status: 400 + body: + inline_string: "x-arch-llm-provider header not set, cannot perform routing\n" + http_filters: + - name: envoy.filters.http.wasm + typed_config: + "@type": type.googleapis.com/udpa.type.v1.TypedStruct + type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm + value: + config: + name: "http_config" + root_id: llm_gateway + configuration: + "@type": "type.googleapis.com/google.protobuf.StringValue" + value: | + {{ arch_llm_config | indent(32) }} + vm_config: + runtime: "envoy.wasm.runtime.v8" + code: + local: + filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm" + - name: envoy.filters.http.router + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router + clusters: - name: openai connect_timeout: 5s @@ -289,6 +382,22 @@ static_resources: port_value: 11000 hostname: arch_internal + - name: arch_llm_listener + connect_timeout: 5s + type: LOGICAL_DNS + dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: arch_llm_listener + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: 0.0.0.0 + port_value: 12000 + hostname: arch_llm_listener + {% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %} - name: opentelemetry_collector type: STRICT_DNS diff --git a/arch/src/consts.rs b/arch/src/consts.rs index 07d38cf8..a3e8e428 100644 --- a/arch/src/consts.rs +++ b/arch/src/consts.rs @@ -18,3 +18,4 @@ pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B"; pub const REQUEST_ID_HEADER: &str = "x-request-id"; pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal"; pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream"; +pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener"; diff --git a/arch/src/filter_context.rs b/arch/src/filter_context.rs index 491484bb..09314ff5 100644 --- a/arch/src/filter_context.rs +++ b/arch/src/filter_context.rs @@ -11,7 +11,9 @@ use log::debug; use proxy_wasm::traits::*; use proxy_wasm::types::*; use public_types::common_types::EmbeddingType; -use public_types::configuration::{Configuration, Overrides, PromptGuards, PromptTarget}; +use public_types::configuration::{ + Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget, +}; use public_types::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; @@ -53,6 +55,7 @@ pub struct FilterContext { overrides: Rc>, system_prompt: Rc>, prompt_targets: Rc>, + mode: GatewayMode, prompt_guards: Rc, llm_providers: Option>, embeddings_store: Option>, @@ -68,8 +71,9 @@ impl FilterContext { prompt_targets: Rc::new(HashMap::new()), overrides: Rc::new(None), prompt_guards: Rc::new(PromptGuards::default()), + mode: 
GatewayMode::Prompt, llm_providers: None, - embeddings_store: None, + embeddings_store: Some(Rc::new(HashMap::new())), temp_embeddings_store: HashMap::new(), } } @@ -253,6 +257,7 @@ impl RootContext for FilterContext { } self.system_prompt = Rc::new(config.system_prompt); self.prompt_targets = Rc::new(prompt_targets); + self.mode = config.mode.unwrap_or_default(); ratelimit::ratelimits(config.ratelimits); @@ -275,8 +280,15 @@ impl RootContext for FilterContext { ); // No StreamContext can be created until the Embedding Store is fully initialized. - self.embeddings_store.as_ref()?; - + let embedding_store; + match self.mode { + GatewayMode::Llm => { + embedding_store = None; + } + GatewayMode::Prompt => { + embedding_store = Some(Rc::clone(self.embeddings_store.as_ref().unwrap())) + } + } Some(Box::new(StreamContext::new( context_id, Rc::clone(&self.metrics), @@ -289,11 +301,8 @@ impl RootContext for FilterContext { .as_ref() .expect("LLM Providers must exist when Streams are being created"), ), - Rc::clone( - self.embeddings_store - .as_ref() - .expect("Embeddings Store must exist when StreamContext is being constructed"), - ), + embedding_store, + self.mode.clone(), ))) } @@ -307,7 +316,11 @@ impl RootContext for FilterContext { } fn on_tick(&mut self) { - self.process_prompt_targets(); + debug!("starting up arch filter in mode: {:?}", self.mode); + if self.mode == GatewayMode::Prompt { + self.process_prompt_targets(); + } + self.set_tick_period(Duration::from_secs(0)); } } diff --git a/arch/src/routing.rs b/arch/src/routing.rs index 1f23f383..a372537e 100644 --- a/arch/src/routing.rs +++ b/arch/src/routing.rs @@ -1,9 +1,11 @@ use std::rc::Rc; use crate::llm_providers::LlmProviders; +use log::debug; use public_types::configuration::LlmProvider; use rand::{seq::IteratorRandom, thread_rng}; +#[derive(Debug)] pub enum ProviderHint { Default, Name(String), @@ -32,6 +34,12 @@ pub fn get_llm_provider( return provider; } + if llm_providers.default().is_some() { + debug!("no llm provider found for hint, using default llm provider"); + return llm_providers.default().unwrap(); + } + + debug!("no default llm found, using random llm provider"); let mut rng = thread_rng(); llm_providers .iter() diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs index 1979a183..0f9f9979 100644 --- a/arch/src/stream_context.rs +++ b/arch/src/stream_context.rs @@ -1,8 +1,9 @@ use crate::consts::{ - ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, - ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, - ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, - DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, + ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, + ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, + ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, + DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, + DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE, }; use crate::filter_context::{EmbeddingsStore, WasmMetrics}; @@ -26,7 +27,7 @@ use public_types::common_types::{ PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest, ZeroShotClassificationResponse, }; -use 
public_types::configuration::LlmProvider; +use public_types::configuration::{GatewayMode, LlmProvider}; use public_types::configuration::{Overrides, PromptGuards, PromptTarget}; use public_types::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, @@ -93,7 +94,7 @@ pub struct StreamContext { metrics: Rc, system_prompt: Rc>, prompt_targets: Rc>, - embeddings_store: Rc, + embeddings_store: Option>, overrides: Rc>, callouts: RefCell>, tool_calls: Option>, @@ -110,6 +111,7 @@ pub struct StreamContext { llm_providers: Rc, llm_provider: Option>, request_id: Option, + mode: GatewayMode, } impl StreamContext { @@ -122,7 +124,8 @@ impl StreamContext { prompt_guards: Rc, overrides: Rc>, llm_providers: Rc, - embeddings_store: Rc, + embeddings_store: Option>, + mode: GatewayMode, ) -> Self { StreamContext { context_id, @@ -146,6 +149,7 @@ impl StreamContext { prompt_guards, overrides, request_id: None, + mode, } } fn llm_provider(&self) -> &LlmProvider { @@ -154,19 +158,35 @@ impl StreamContext { .expect("the provider should be set when asked for it") } + fn embeddings_store(&self) -> &EmbeddingsStore { + self.embeddings_store + .as_ref() + .expect("embeddings store is not set") + } + fn select_llm_provider(&mut self) { let provider_hint = self .get_http_request_header(ARCH_PROVIDER_HINT_HEADER) .map(|provider_name| provider_name.into()); + debug!("llm provider hint: {:?}", provider_hint); self.llm_provider = Some(routing::get_llm_provider( &self.llm_providers, provider_hint, )); + debug!("selected llm: {}", self.llm_provider.as_ref().unwrap().name); } fn add_routing_header(&mut self) { - self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name); + match self.mode { + GatewayMode::Prompt => { + // in prompt gateway mode, we need to route to llm upstream listener + self.add_http_request_header(ARCH_UPSTREAM_HOST_HEADER, ARCH_LLM_UPSTREAM_LISTENER); + } + _ => { + self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name); + } + } } fn modify_auth_headers(&mut self) -> Result<(), ServerError> { @@ -247,7 +267,7 @@ impl StreamContext { // exclude default prompt target .filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false)) .map(|(prompt_name, _)| { - let pte = match self.embeddings_store.get(prompt_name) { + let pte = match self.embeddings_store().get(prompt_name) { Some(embeddings) => embeddings, None => { warn!( @@ -901,32 +921,37 @@ impl StreamContext { debug!("arch => openai request body: {}", json_string); // Tokenize and Ratelimit. 
- if let Some(selector) = self.ratelimit_selector.take() { - if let Ok(token_count) = - tokenizer::token_count(&chat_completions_request.model, &json_string) - { - match ratelimit::ratelimits(None).read().unwrap().check_limit( - chat_completions_request.model, - selector, - NonZero::new(token_count as u32).unwrap(), - ) { - Ok(_) => (), - Err(err) => { - self.send_server_error( - ServerError::ExceededRatelimit(err), - Some(StatusCode::TOO_MANY_REQUESTS), - ); - self.metrics.ratelimited_rq.increment(1); - return; - } - } - } + if let Err(e) = self.enforce_ratelimits(&chat_completions_request.model, &json_string) { + self.send_server_error( + ServerError::ExceededRatelimit(e), + Some(StatusCode::TOO_MANY_REQUESTS), + ); + self.metrics.ratelimited_rq.increment(1); + return; } self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes()); self.resume_http_request(); } + fn enforce_ratelimits( + &mut self, + model: &str, + json_string: &str, + ) -> Result<(), ratelimit::Error> { + if let Some(selector) = self.ratelimit_selector.take() { + // Tokenize and Ratelimit. + if let Ok(token_count) = tokenizer::token_count(model, &json_string) { + ratelimit::ratelimits(None).read().unwrap().check_limit( + model.to_owned(), + selector, + NonZero::new(token_count as u32).unwrap(), + )?; + } + } + Ok(()) + } + fn arch_guard_handler(&mut self, body: Vec, callout_context: StreamCallContext) { debug!("response received for arch guard"); let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap(); @@ -1140,6 +1165,41 @@ impl HttpContext for StreamContext { return Action::Pause; } }; + self.is_chat_completions_request = true; + + if self.mode == GatewayMode::Llm { + debug!("llm gateway mode, skipping over all prompt targets"); + + // remove metadata from the request body + deserialized_body.metadata = None; + // delete model key from message array + for message in deserialized_body.messages.iter_mut() { + message.model = None; + } + deserialized_body + .model + .clone_from(&self.llm_provider.as_ref().unwrap().model); + let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap(); + + // enforce ratelimits + if let Err(e) = + self.enforce_ratelimits(&deserialized_body.model, &chat_completion_request_str) + { + self.send_server_error( + ServerError::ExceededRatelimit(e), + Some(StatusCode::TOO_MANY_REQUESTS), + ); + self.metrics.ratelimited_rq.increment(1); + return Action::Continue; + } + + debug!( + "arch => {:?}, body: {}", + deserialized_body.model, chat_completion_request_str + ); + self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes()); + return Action::Continue; + } self.arch_state = match deserialized_body.metadata { Some(ref metadata) => { @@ -1154,7 +1214,6 @@ impl HttpContext for StreamContext { None => None, }; - self.is_chat_completions_request = true; // Set the model based on the chosen LLM Provider deserialized_body.model = String::from(&self.llm_provider().model); diff --git a/arch/tests/integration.rs b/arch/tests/integration.rs index 3fcf3fee..06918bdf 100644 --- a/arch/tests/integration.rs +++ b/arch/tests/integration.rs @@ -33,6 +33,12 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) { Some("x-arch-llm-provider-hint"), ) .returning(Some("default")) + .expect_log(Some(LogLevel::Debug), None) + .expect_add_header_map_value( + Some(MapType::HttpRequestHeaders), + Some("x-arch-upstream"), + Some("arch_llm_listener"), + ) .expect_add_header_map_value( 
Some(MapType::HttpRequestHeaders), Some("x-arch-llm-provider"), @@ -267,6 +273,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { module .call_proxy_on_tick(filter_context) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_http_call( Some("arch_internal"), Some(vec![ diff --git a/arch/tools/config_generator.py b/arch/tools/config_generator.py index fde60526..33741ee9 100644 --- a/arch/tools/config_generator.py +++ b/arch/tools/config_generator.py @@ -76,9 +76,12 @@ def validate_and_render_schema(): arch_llm_providers = config_yaml["llm_providers"] arch_tracing = config_yaml.get("tracing", {}) arch_config_string = yaml.dump(config_yaml) + config_yaml["mode"] = "llm" + arch_llm_config_string = yaml.dump(config_yaml) data = { "arch_config": arch_config_string, + "arch_llm_config": arch_llm_config_string, "arch_clusters": inferred_clusters, "arch_llm_providers": arch_llm_providers, "arch_tracing": arch_tracing, diff --git a/chatbot_ui/.vscode/launch.json b/chatbot_ui/.vscode/launch.json index d08bb1e4..47ee5a58 100644 --- a/chatbot_ui/.vscode/launch.json +++ b/chatbot_ui/.vscode/launch.json @@ -11,6 +11,22 @@ "request": "launch", "program": "run.py", "console": "integratedTerminal", + "env": { + "LLM": "1", + "CHAT_COMPLETION_ENDPOINT": "http://localhost:10000/v1" + } + }, + { + "name": "chatbot-ui llm", + "cwd": "${workspaceFolder}/app", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "env": { + "LLM": "1", + "CHAT_COMPLETION_ENDPOINT": "http://localhost:12000/v1" + } }, { "name": "chatbot-ui streaming", @@ -19,6 +35,10 @@ "request": "launch", "program": "run_stream.py", "console": "integratedTerminal", + "env": { + "LLM": "1", + "CHAT_COMPLETION_ENDPOINT": "http://localhost:10000/v1" + } } ] } diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py index 1fe10e12..f2e85231 100644 --- a/chatbot_ui/app/run.py +++ b/chatbot_ui/app/run.py @@ -7,16 +7,12 @@ from dotenv import load_dotenv load_dotenv() -OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") -MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY") CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT") -MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo") ARCH_STATE_HEADER = "x-arch-state" - log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT) client = OpenAI( - api_key=OPENAI_API_KEY, + api_key="--", base_url=CHAT_COMPLETION_ENDPOINT, http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}), ) @@ -31,8 +27,6 @@ def predict(message, state): # Custom headers custom_headers = { - "x-arch-openai-api-key": f"{OPENAI_API_KEY}", - "x-arch-mistral-api-key": f"{MISTRAL_API_KEY}", "x-arch-deterministic-provider": "openai", } @@ -42,7 +36,7 @@ def predict(message, state): try: raw_response = client.chat.completions.with_raw_response.create( - model=MODEL_NAME, + model="--", messages=history, temperature=1.0, metadata=metadata, diff --git a/demos/function_calling/arch_config.yaml b/demos/function_calling/arch_config.yaml index 3371b12b..5bde5dda 100644 --- a/demos/function_calling/arch_config.yaml +++ b/demos/function_calling/arch_config.yaml @@ -16,11 +16,15 @@ overrides: prompt_target_intent_matching_threshold: 0.6 llm_providers: - - name: open-ai-gpt-4 + - name: gpt-4o access_key: OPENAI_API_KEY provider: openai model: gpt-4o default: true + - name: mistral-large-latest + access_key: MISTRAL_API_KEY + provider: mistral + model: mistral-large-latest system_prompt: | You are a helpful assistant. 
diff --git a/demos/function_calling/docker-compose.yaml b/demos/function_calling/docker-compose.yaml index 34251b13..2611e743 100644 --- a/demos/function_calling/docker-compose.yaml +++ b/demos/function_calling/docker-compose.yaml @@ -15,7 +15,7 @@ services: context: ../../chatbot_ui dockerfile: Dockerfile ports: - - "18090:8080" + - "18080:8080" environment: - OPENAI_API_KEY=${OPENAI_API_KEY:?error} - MISTRAL_API_KEY=${MISTRAL_API_KEY:?error} diff --git a/gateway.code-workspace b/gateway.code-workspace index e864caad..617e49ec 100644 --- a/gateway.code-workspace +++ b/gateway.code-workspace @@ -8,6 +8,10 @@ "name": "arch", "path": "arch" }, + { + "name": "arch/tools", + "path": "arch/tools" + }, { "name": "model_server", "path": "model_server" diff --git a/public_types/src/configuration.rs b/public_types/src/configuration.rs index d7f2e543..74f58ab0 100644 --- a/public_types/src/configuration.rs +++ b/public_types/src/configuration.rs @@ -13,6 +13,20 @@ pub struct Tracing { pub sampling_rate: Option, } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub enum GatewayMode { + #[serde(rename = "llm")] + Llm, + #[serde(rename = "prompt")] + Prompt, +} + +impl Default for GatewayMode { + fn default() -> Self { + GatewayMode::Prompt + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Configuration { pub version: String, @@ -26,6 +40,7 @@ pub struct Configuration { pub error_target: Option, pub ratelimits: Option>, pub tracing: Option, + pub mode: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -283,5 +298,11 @@ mod test { let tracing = config.tracing.as_ref().unwrap(); assert_eq!(tracing.sampling_rate.unwrap(), 0.1); + + let mode = config + .mode + .as_ref() + .unwrap_or(&super::GatewayMode::Prompt); + assert_eq!(*mode, super::GatewayMode::Prompt); } }
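Note: the `GatewayMode` plumbing above ultimately changes one thing per request — which header the wasm filter sets, and therefore which Envoy route fires. A minimal sketch of that branch follows, assuming the constant values match arch/src/consts.rs and the routes in envoy.template.yaml; this is not the plugin code itself.

```rust
// Illustrative only: mirrors the branch added to StreamContext::add_routing_header.
// In prompt mode the ingress listener (10000) hands the request to the internal
// LLM listener (12000); in llm mode the filter targets the provider cluster directly.

const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream";
const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";

#[derive(Debug, Clone, Copy, PartialEq)]
enum GatewayMode {
    Llm,
    Prompt,
}

/// Returns the (header, value) pair the filter adds before resuming the request.
fn routing_header(mode: GatewayMode, provider_name: &str) -> (&'static str, String) {
    match mode {
        // Prompt gateway: forward to the llm listener, matched by the new
        // x-arch-upstream route in envoy.template.yaml.
        GatewayMode::Prompt => (
            ARCH_UPSTREAM_HOST_HEADER,
            ARCH_LLM_UPSTREAM_LISTENER.to_string(),
        ),
        // LLM gateway: route straight to the selected provider's cluster.
        GatewayMode::Llm => (ARCH_ROUTING_HEADER, provider_name.to_string()),
    }
}

fn main() {
    assert_eq!(
        routing_header(GatewayMode::Prompt, "gpt-4o"),
        ("x-arch-upstream", "arch_llm_listener".to_string())
    );
    assert_eq!(
        routing_header(GatewayMode::Llm, "mistral-large-latest"),
        ("x-arch-llm-provider", "mistral-large-latest".to_string())
    );
}
```

With this split, the same intelligent_prompt_gateway.wasm binary is loaded twice — `root_id: prompt_gateway` on the ingress listener and `root_id: llm_gateway` on the new port-12000 listener — and config_generator.py renders the second configuration blob by injecting `mode: llm` into the same YAML.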