mirror of
https://github.com/katanemo/plano.git
synced 2026-06-02 14:35:14 +02:00
llm listener split (#155)
This commit is contained in:
parent
8b5db45507
commit
e81ca8d5cf
16 changed files with 305 additions and 54 deletions
|
|
@ -150,6 +150,11 @@ properties:
|
|||
random_sampling:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
mode:
|
||||
type: string
|
||||
enum:
|
||||
- llm
|
||||
- prompt
|
||||
additionalProperties: false
|
||||
required:
|
||||
- version
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ services:
|
|||
ports:
|
||||
- "10000:10000"
|
||||
- "11000:11000"
|
||||
- "12000:12000"
|
||||
- "19901:9901"
|
||||
volumes:
|
||||
- ${ARCH_CONFIG_FILE:-../demos/function_calling/arch_config.yaml}:/config/arch_config.yaml
|
||||
|
|
|
|||
|
|
@ -3,10 +3,12 @@ services:
|
|||
image: archgw:latest
|
||||
ports:
|
||||
- "10000:10000"
|
||||
- "11000:11000"
|
||||
- "12000:12000"
|
||||
- "19901:9901"
|
||||
volumes:
|
||||
# Mount the gateway config; falls back to the function_calling demo config when ARCH_CONFIG_FILE is unset or empty.
- ${ARCH_CONFIG_FILE:-./demos/function_calling/arch_config.yaml}:/config/arch_config.yaml
|
||||
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
|
||||
- ~/archgw_logs/arch_logs:/var/log/
|
||||
- ~/archgw_logs:/var/log/
|
||||
env_file:
|
||||
- stage.env
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ static_resources:
|
|||
- name: envoy.access_loggers.file
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
|
||||
path: "/var/log/arch_access.log"
|
||||
path: "/var/log/access_ingress.log"
|
||||
route_config:
|
||||
name: local_routes
|
||||
virtual_hosts:
|
||||
|
|
@ -57,12 +57,22 @@ static_resources:
|
|||
cluster: {{ provider.provider }}
|
||||
timeout: 60s
|
||||
{% endfor %}
|
||||
- match:
|
||||
prefix: "/"
|
||||
headers:
|
||||
- name: "x-arch-upstream"
|
||||
string_match:
|
||||
exact: arch_llm_listener
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: arch_llm_listener
|
||||
timeout: 60s
|
||||
- match:
|
||||
prefix: "/"
|
||||
direct_response:
|
||||
status: 400
|
||||
body:
|
||||
inline_string: "x-arch-llm-provider header not set, cannot perform routing\n"
|
||||
inline_string: "x-arch-llm-provider or x-arch-upstream header not set, cannot perform routing\n"
|
||||
http_filters:
|
||||
- name: envoy.filters.http.wasm
|
||||
typed_config:
|
||||
|
|
@ -71,6 +81,7 @@ static_resources:
|
|||
value:
|
||||
config:
|
||||
name: "http_config"
|
||||
root_id: prompt_gateway
|
||||
configuration:
|
||||
"@type": "type.googleapis.com/google.protobuf.StringValue"
|
||||
value: |
|
||||
|
|
@ -118,7 +129,7 @@ static_resources:
|
|||
- name: envoy.access_loggers.file
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
|
||||
path: "/var/log/arch_access_internal.log"
|
||||
path: "/var/log/access_internal.log"
|
||||
route_config:
|
||||
name: local_routes
|
||||
virtual_hosts:
|
||||
|
|
@ -162,6 +173,88 @@ static_resources:
|
|||
- name: envoy.filters.http.router
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
|
||||
|
||||
- name: arch_listener_llm
|
||||
address:
|
||||
socket_address:
|
||||
address: 0.0.0.0
|
||||
port_value: 12000
|
||||
filter_chains:
|
||||
- filters:
|
||||
- name: envoy.filters.network.http_connection_manager
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
|
||||
{% if arch_tracing.random_sampling > 0 %}
|
||||
generate_request_id: true
|
||||
tracing:
|
||||
provider:
|
||||
name: envoy.tracers.opentelemetry
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig
|
||||
grpc_service:
|
||||
envoy_grpc:
|
||||
cluster_name: opentelemetry_collector
|
||||
timeout: 0.250s
|
||||
service_name: arch
|
||||
random_sampling:
|
||||
value: {{ arch_tracing.random_sampling }}
|
||||
{% endif %}
|
||||
stat_prefix: arch_listener_http
|
||||
codec_type: AUTO
|
||||
scheme_header_transformation:
|
||||
scheme_to_overwrite: https
|
||||
access_log:
|
||||
- name: envoy.access_loggers.file
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
|
||||
path: "/var/log/access_llm.log"
|
||||
route_config:
|
||||
name: local_routes
|
||||
virtual_hosts:
|
||||
- name: local_service
|
||||
domains:
|
||||
- "*"
|
||||
routes:
|
||||
{% for provider in arch_llm_providers %}
|
||||
- match:
|
||||
prefix: "/"
|
||||
headers:
|
||||
- name: "x-arch-llm-provider"
|
||||
string_match:
|
||||
exact: {{ provider.name }}
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: {{ provider.provider }}
|
||||
timeout: 60s
|
||||
{% endfor %}
|
||||
- match:
|
||||
prefix: "/"
|
||||
direct_response:
|
||||
status: 400
|
||||
body:
|
||||
inline_string: "x-arch-llm-provider header not set, cannot perform routing\n"
|
||||
http_filters:
|
||||
- name: envoy.filters.http.wasm
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
|
||||
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
|
||||
value:
|
||||
config:
|
||||
name: "http_config"
|
||||
root_id: llm_gateway
|
||||
configuration:
|
||||
"@type": "type.googleapis.com/google.protobuf.StringValue"
|
||||
value: |
|
||||
{{ arch_llm_config | indent(32) }}
|
||||
vm_config:
|
||||
runtime: "envoy.wasm.runtime.v8"
|
||||
code:
|
||||
local:
|
||||
filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm"
|
||||
- name: envoy.filters.http.router
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
|
||||
|
||||
clusters:
|
||||
- name: openai
|
||||
connect_timeout: 5s
|
||||
|
|
@ -289,6 +382,22 @@ static_resources:
|
|||
port_value: 11000
|
||||
hostname: arch_internal
|
||||
|
||||
- name: arch_llm_listener
|
||||
connect_timeout: 5s
|
||||
type: LOGICAL_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: arch_llm_listener
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: 0.0.0.0
|
||||
port_value: 12000
|
||||
hostname: arch_llm_listener
|
||||
|
||||
{% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
|
||||
- name: opentelemetry_collector
|
||||
type: STRICT_DNS
|
||||
|
|
|
|||
|
|
@ -18,3 +18,4 @@ pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B";
|
|||
pub const REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal";
|
||||
pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream";
|
||||
pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";
|
||||
|
|
|
|||
|
|
@ -11,7 +11,9 @@ use log::debug;
|
|||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::EmbeddingType;
|
||||
use public_types::configuration::{Configuration, Overrides, PromptGuards, PromptTarget};
|
||||
use public_types::configuration::{
|
||||
Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget,
|
||||
};
|
||||
use public_types::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
|
|
@ -53,6 +55,7 @@ pub struct FilterContext {
|
|||
overrides: Rc<Option<Overrides>>,
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
mode: GatewayMode,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
llm_providers: Option<Rc<LlmProviders>>,
|
||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
|
|
@ -68,8 +71,9 @@ impl FilterContext {
|
|||
prompt_targets: Rc::new(HashMap::new()),
|
||||
overrides: Rc::new(None),
|
||||
prompt_guards: Rc::new(PromptGuards::default()),
|
||||
mode: GatewayMode::Prompt,
|
||||
llm_providers: None,
|
||||
embeddings_store: None,
|
||||
embeddings_store: Some(Rc::new(HashMap::new())),
|
||||
temp_embeddings_store: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
|
@ -253,6 +257,7 @@ impl RootContext for FilterContext {
|
|||
}
|
||||
self.system_prompt = Rc::new(config.system_prompt);
|
||||
self.prompt_targets = Rc::new(prompt_targets);
|
||||
self.mode = config.mode.unwrap_or_default();
|
||||
|
||||
ratelimit::ratelimits(config.ratelimits);
|
||||
|
||||
|
|
@ -275,8 +280,15 @@ impl RootContext for FilterContext {
|
|||
);
|
||||
|
||||
// No StreamContext can be created until the Embedding Store is fully initialized.
|
||||
self.embeddings_store.as_ref()?;
|
||||
|
||||
let embedding_store;
|
||||
match self.mode {
|
||||
GatewayMode::Llm => {
|
||||
embedding_store = None;
|
||||
}
|
||||
GatewayMode::Prompt => {
|
||||
embedding_store = Some(Rc::clone(self.embeddings_store.as_ref().unwrap()))
|
||||
}
|
||||
}
|
||||
Some(Box::new(StreamContext::new(
|
||||
context_id,
|
||||
Rc::clone(&self.metrics),
|
||||
|
|
@ -289,11 +301,8 @@ impl RootContext for FilterContext {
|
|||
.as_ref()
|
||||
.expect("LLM Providers must exist when Streams are being created"),
|
||||
),
|
||||
Rc::clone(
|
||||
self.embeddings_store
|
||||
.as_ref()
|
||||
.expect("Embeddings Store must exist when StreamContext is being constructed"),
|
||||
),
|
||||
embedding_store,
|
||||
self.mode.clone(),
|
||||
)))
|
||||
}
|
||||
|
||||
|
|
@ -307,7 +316,11 @@ impl RootContext for FilterContext {
|
|||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
self.process_prompt_targets();
|
||||
debug!("starting up arch filter in mode: {:?}", self.mode);
|
||||
if self.mode == GatewayMode::Prompt {
|
||||
self.process_prompt_targets();
|
||||
}
|
||||
|
||||
self.set_tick_period(Duration::from_secs(0));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
use std::rc::Rc;
|
||||
|
||||
use crate::llm_providers::LlmProviders;
|
||||
use log::debug;
|
||||
use public_types::configuration::LlmProvider;
|
||||
use rand::{seq::IteratorRandom, thread_rng};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ProviderHint {
|
||||
Default,
|
||||
Name(String),
|
||||
|
|
@ -32,6 +34,12 @@ pub fn get_llm_provider(
|
|||
return provider;
|
||||
}
|
||||
|
||||
if llm_providers.default().is_some() {
|
||||
debug!("no llm provider found for hint, using default llm provider");
|
||||
return llm_providers.default().unwrap();
|
||||
}
|
||||
|
||||
debug!("no default llm found, using random llm provider");
|
||||
let mut rng = thread_rng();
|
||||
llm_providers
|
||||
.iter()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
use crate::consts::{
|
||||
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY,
|
||||
ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER,
|
||||
ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
|
||||
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
||||
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME,
|
||||
ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
|
||||
ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH,
|
||||
DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
||||
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
|
||||
|
|
@ -26,7 +27,7 @@ use public_types::common_types::{
|
|||
PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest,
|
||||
ZeroShotClassificationResponse,
|
||||
};
|
||||
use public_types::configuration::LlmProvider;
|
||||
use public_types::configuration::{GatewayMode, LlmProvider};
|
||||
use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
|
||||
use public_types::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
|
|
@ -93,7 +94,7 @@ pub struct StreamContext {
|
|||
metrics: Rc<WasmMetrics>,
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
embeddings_store: Rc<EmbeddingsStore>,
|
||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
callouts: RefCell<HashMap<u32, StreamCallContext>>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
|
|
@ -110,6 +111,7 @@ pub struct StreamContext {
|
|||
llm_providers: Rc<LlmProviders>,
|
||||
llm_provider: Option<Rc<LlmProvider>>,
|
||||
request_id: Option<String>,
|
||||
mode: GatewayMode,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -122,7 +124,8 @@ impl StreamContext {
|
|||
prompt_guards: Rc<PromptGuards>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
embeddings_store: Rc<EmbeddingsStore>,
|
||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
mode: GatewayMode,
|
||||
) -> Self {
|
||||
StreamContext {
|
||||
context_id,
|
||||
|
|
@ -146,6 +149,7 @@ impl StreamContext {
|
|||
prompt_guards,
|
||||
overrides,
|
||||
request_id: None,
|
||||
mode,
|
||||
}
|
||||
}
|
||||
fn llm_provider(&self) -> &LlmProvider {
|
||||
|
|
@ -154,19 +158,35 @@ impl StreamContext {
|
|||
.expect("the provider should be set when asked for it")
|
||||
}
|
||||
|
||||
/// Borrows the embeddings store held by this stream context.
///
/// # Panics
///
/// Panics with "embeddings store is not set" when the context holds no store.
fn embeddings_store(&self) -> &EmbeddingsStore {
    match self.embeddings_store.as_deref() {
        Some(store) => store,
        None => panic!("embeddings store is not set"),
    }
}
|
||||
|
||||
fn select_llm_provider(&mut self) {
|
||||
let provider_hint = self
|
||||
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
.map(|provider_name| provider_name.into());
|
||||
|
||||
debug!("llm provider hint: {:?}", provider_hint);
|
||||
self.llm_provider = Some(routing::get_llm_provider(
|
||||
&self.llm_providers,
|
||||
provider_hint,
|
||||
));
|
||||
debug!("selected llm: {}", self.llm_provider.as_ref().unwrap().name);
|
||||
}
|
||||
|
||||
fn add_routing_header(&mut self) {
|
||||
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
|
||||
match self.mode {
|
||||
GatewayMode::Prompt => {
|
||||
// in prompt gateway mode, we need to route to llm upstream listener
|
||||
self.add_http_request_header(ARCH_UPSTREAM_HOST_HEADER, ARCH_LLM_UPSTREAM_LISTENER);
|
||||
}
|
||||
_ => {
|
||||
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn modify_auth_headers(&mut self) -> Result<(), ServerError> {
|
||||
|
|
@ -247,7 +267,7 @@ impl StreamContext {
|
|||
// exclude default prompt target
|
||||
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
|
||||
.map(|(prompt_name, _)| {
|
||||
let pte = match self.embeddings_store.get(prompt_name) {
|
||||
let pte = match self.embeddings_store().get(prompt_name) {
|
||||
Some(embeddings) => embeddings,
|
||||
None => {
|
||||
warn!(
|
||||
|
|
@ -901,32 +921,37 @@ impl StreamContext {
|
|||
debug!("arch => openai request body: {}", json_string);
|
||||
|
||||
// Tokenize and Ratelimit.
|
||||
if let Some(selector) = self.ratelimit_selector.take() {
|
||||
if let Ok(token_count) =
|
||||
tokenizer::token_count(&chat_completions_request.model, &json_string)
|
||||
{
|
||||
match ratelimit::ratelimits(None).read().unwrap().check_limit(
|
||||
chat_completions_request.model,
|
||||
selector,
|
||||
NonZero::new(token_count as u32).unwrap(),
|
||||
) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
self.send_server_error(
|
||||
ServerError::ExceededRatelimit(err),
|
||||
Some(StatusCode::TOO_MANY_REQUESTS),
|
||||
);
|
||||
self.metrics.ratelimited_rq.increment(1);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Err(e) = self.enforce_ratelimits(&chat_completions_request.model, &json_string) {
|
||||
self.send_server_error(
|
||||
ServerError::ExceededRatelimit(e),
|
||||
Some(StatusCode::TOO_MANY_REQUESTS),
|
||||
);
|
||||
self.metrics.ratelimited_rq.increment(1);
|
||||
return;
|
||||
}
|
||||
|
||||
self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
|
||||
fn enforce_ratelimits(
|
||||
&mut self,
|
||||
model: &str,
|
||||
json_string: &str,
|
||||
) -> Result<(), ratelimit::Error> {
|
||||
if let Some(selector) = self.ratelimit_selector.take() {
|
||||
// Tokenize and Ratelimit.
|
||||
if let Ok(token_count) = tokenizer::token_count(model, &json_string) {
|
||||
ratelimit::ratelimits(None).read().unwrap().check_limit(
|
||||
model.to_owned(),
|
||||
selector,
|
||||
NonZero::new(token_count as u32).unwrap(),
|
||||
)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||
debug!("response received for arch guard");
|
||||
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
|
||||
|
|
@ -1140,6 +1165,41 @@ impl HttpContext for StreamContext {
|
|||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
self.is_chat_completions_request = true;
|
||||
|
||||
if self.mode == GatewayMode::Llm {
|
||||
debug!("llm gateway mode, skipping over all prompt targets");
|
||||
|
||||
// remove metadata from the request body
|
||||
deserialized_body.metadata = None;
|
||||
// delete model key from message array
|
||||
for message in deserialized_body.messages.iter_mut() {
|
||||
message.model = None;
|
||||
}
|
||||
deserialized_body
|
||||
.model
|
||||
.clone_from(&self.llm_provider.as_ref().unwrap().model);
|
||||
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
|
||||
|
||||
// enforce ratelimits
|
||||
if let Err(e) =
|
||||
self.enforce_ratelimits(&deserialized_body.model, &chat_completion_request_str)
|
||||
{
|
||||
self.send_server_error(
|
||||
ServerError::ExceededRatelimit(e),
|
||||
Some(StatusCode::TOO_MANY_REQUESTS),
|
||||
);
|
||||
self.metrics.ratelimited_rq.increment(1);
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"arch => {:?}, body: {}",
|
||||
deserialized_body.model, chat_completion_request_str
|
||||
);
|
||||
self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
self.arch_state = match deserialized_body.metadata {
|
||||
Some(ref metadata) => {
|
||||
|
|
@ -1154,7 +1214,6 @@ impl HttpContext for StreamContext {
|
|||
None => None,
|
||||
};
|
||||
|
||||
self.is_chat_completions_request = true;
|
||||
// Set the model based on the chosen LLM Provider
|
||||
deserialized_body.model = String::from(&self.llm_provider().model);
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,12 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
|||
Some("x-arch-llm-provider-hint"),
|
||||
)
|
||||
.returning(Some("default"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_add_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-upstream"),
|
||||
Some("arch_llm_listener"),
|
||||
)
|
||||
.expect_add_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-llm-provider"),
|
||||
|
|
@ -267,6 +273,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
|
|||
module
|
||||
.call_proxy_on_tick(filter_context)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
|
|
|
|||
|
|
@ -76,9 +76,12 @@ def validate_and_render_schema():
|
|||
arch_llm_providers = config_yaml["llm_providers"]
|
||||
arch_tracing = config_yaml.get("tracing", {})
|
||||
arch_config_string = yaml.dump(config_yaml)
|
||||
config_yaml["mode"] = "llm"
|
||||
arch_llm_config_string = yaml.dump(config_yaml)
|
||||
|
||||
data = {
|
||||
"arch_config": arch_config_string,
|
||||
"arch_llm_config": arch_llm_config_string,
|
||||
"arch_clusters": inferred_clusters,
|
||||
"arch_llm_providers": arch_llm_providers,
|
||||
"arch_tracing": arch_tracing,
|
||||
|
|
|
|||
20
chatbot_ui/.vscode/launch.json
vendored
20
chatbot_ui/.vscode/launch.json
vendored
|
|
@ -11,6 +11,22 @@
|
|||
"request": "launch",
|
||||
"program": "run.py",
|
||||
"console": "integratedTerminal",
|
||||
"env": {
|
||||
"LLM": "1",
|
||||
"CHAT_COMPLETION_ENDPOINT": "http://localhost:10000/v1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "chatbot-ui llm",
|
||||
"cwd": "${workspaceFolder}/app",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "run.py",
|
||||
"console": "integratedTerminal",
|
||||
"env": {
|
||||
"LLM": "1",
|
||||
"CHAT_COMPLETION_ENDPOINT": "http://localhost:12000/v1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "chatbot-ui streaming",
|
||||
|
|
@ -19,6 +35,10 @@
|
|||
"request": "launch",
|
||||
"program": "run_stream.py",
|
||||
"console": "integratedTerminal",
|
||||
"env": {
|
||||
"LLM": "1",
|
||||
"CHAT_COMPLETION_ENDPOINT": "http://localhost:10000/v1"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,16 +7,12 @@ from dotenv import load_dotenv
|
|||
|
||||
load_dotenv()
|
||||
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
||||
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
|
||||
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
|
||||
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
|
||||
ARCH_STATE_HEADER = "x-arch-state"
|
||||
|
||||
# Use a %s placeholder: logging applies %-formatting to the message with the extra
# arguments, so an argument without a placeholder triggers a logging error.
log.info("CHAT_COMPLETION_ENDPOINT: %s", CHAT_COMPLETION_ENDPOINT)
|
||||
|
||||
client = OpenAI(
|
||||
api_key=OPENAI_API_KEY,
|
||||
api_key="--",
|
||||
base_url=CHAT_COMPLETION_ENDPOINT,
|
||||
http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}),
|
||||
)
|
||||
|
|
@ -31,8 +27,6 @@ def predict(message, state):
|
|||
|
||||
# Custom headers
|
||||
custom_headers = {
|
||||
"x-arch-openai-api-key": f"{OPENAI_API_KEY}",
|
||||
"x-arch-mistral-api-key": f"{MISTRAL_API_KEY}",
|
||||
"x-arch-deterministic-provider": "openai",
|
||||
}
|
||||
|
||||
|
|
@ -42,7 +36,7 @@ def predict(message, state):
|
|||
|
||||
try:
|
||||
raw_response = client.chat.completions.with_raw_response.create(
|
||||
model=MODEL_NAME,
|
||||
model="--",
|
||||
messages=history,
|
||||
temperature=1.0,
|
||||
metadata=metadata,
|
||||
|
|
|
|||
|
|
@ -16,11 +16,15 @@ overrides:
|
|||
prompt_target_intent_matching_threshold: 0.6
|
||||
|
||||
llm_providers:
|
||||
- name: open-ai-gpt-4
|
||||
- name: gpt-4o
|
||||
access_key: OPENAI_API_KEY
|
||||
provider: openai
|
||||
model: gpt-4o
|
||||
default: true
|
||||
- name: mistral-large-latest
|
||||
access_key: MISTRAL_API_KEY
|
||||
provider: mistral
|
||||
model: mistral-large-latest
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ services:
|
|||
context: ../../chatbot_ui
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "18090:8080"
|
||||
- "18080:8080"
|
||||
environment:
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
|
||||
- MISTRAL_API_KEY=${MISTRAL_API_KEY:?error}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,10 @@
|
|||
"name": "arch",
|
||||
"path": "arch"
|
||||
},
|
||||
{
|
||||
"name": "arch/tools",
|
||||
"path": "arch/tools"
|
||||
},
|
||||
{
|
||||
"name": "model_server",
|
||||
"path": "model_server"
|
||||
|
|
|
|||
|
|
@ -13,6 +13,20 @@ pub struct Tracing {
|
|||
pub sampling_rate: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub enum GatewayMode {
|
||||
#[serde(rename = "llm")]
|
||||
Llm,
|
||||
#[serde(rename = "prompt")]
|
||||
Prompt,
|
||||
}
|
||||
|
||||
impl Default for GatewayMode {
|
||||
fn default() -> Self {
|
||||
GatewayMode::Prompt
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub version: String,
|
||||
|
|
@ -26,6 +40,7 @@ pub struct Configuration {
|
|||
pub error_target: Option<ErrorTargetDetail>,
|
||||
pub ratelimits: Option<Vec<Ratelimit>>,
|
||||
pub tracing: Option<Tracing>,
|
||||
pub mode: Option<GatewayMode>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -283,5 +298,11 @@ mod test {
|
|||
|
||||
let tracing = config.tracing.as_ref().unwrap();
|
||||
assert_eq!(tracing.sampling_rate.unwrap(), 0.1);
|
||||
|
||||
let mode = config
|
||||
.mode
|
||||
.as_ref()
|
||||
.unwrap_or(&super::GatewayMode::Prompt);
|
||||
assert_eq!(*mode, super::GatewayMode::Prompt);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue