llm listener split (#155)

This commit is contained in:
Adil Hafeez 2024-10-09 15:47:32 -07:00 committed by GitHub
parent 8b5db45507
commit e81ca8d5cf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 305 additions and 54 deletions

View file

@ -150,6 +150,11 @@ properties:
random_sampling:
type: integer
additionalProperties: false
# Gateway operating mode. Only the two values listed under `enum` are accepted.
mode:
type: string
enum:
- llm
- prompt
additionalProperties: false
required:
- version

View file

@ -4,6 +4,7 @@ services:
ports:
- "10000:10000"
- "11000:11000"
- "12000:12000"
- "19901:9901"
volumes:
- ${ARCH_CONFIG_FILE:-../demos/function_calling/arch_config.yaml}:/config/arch_config.yaml

View file

@ -3,10 +3,12 @@ services:
image: archgw:latest
ports:
- "10000:10000"
- "11000:11000"
- "12000:12000"
- "19901:9901"
volumes:
# Mount the gateway config; when ARCH_CONFIG_FILE is unset, fall back to the
# function_calling demo config (file name is arch_config.yaml).
- ${ARCH_CONFIG_FILE:-./demos/function_calling/arch_config.yaml}:/config/arch_config.yaml
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
- ~/archgw_logs/arch_logs:/var/log/
- ~/archgw_logs:/var/log/
env_file:
- stage.env

View file

@ -37,7 +37,7 @@ static_resources:
- name: envoy.access_loggers.file
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
path: "/var/log/arch_access.log"
path: "/var/log/access_ingress.log"
route_config:
name: local_routes
virtual_hosts:
@ -57,12 +57,22 @@ static_resources:
cluster: {{ provider.provider }}
timeout: 60s
{% endfor %}
# Requests carrying "x-arch-upstream: arch_llm_listener" are routed to the
# arch_llm_listener cluster. This entry sits ahead of the catch-all route that
# answers 400 when no routing header matched.
- match:
prefix: "/"
headers:
- name: "x-arch-upstream"
string_match:
exact: arch_llm_listener
route:
auto_host_rewrite: true
cluster: arch_llm_listener
timeout: 60s
- match:
prefix: "/"
direct_response:
status: 400
body:
inline_string: "x-arch-llm-provider header not set, cannot perform routing\n"
inline_string: "x-arch-llm-provider or x-arch-upstream header not set, cannot perform routing\n"
http_filters:
- name: envoy.filters.http.wasm
typed_config:
@ -71,6 +81,7 @@ static_resources:
value:
config:
name: "http_config"
root_id: prompt_gateway
configuration:
"@type": "type.googleapis.com/google.protobuf.StringValue"
value: |
@ -118,7 +129,7 @@ static_resources:
- name: envoy.access_loggers.file
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
path: "/var/log/arch_access_internal.log"
path: "/var/log/access_internal.log"
route_config:
name: local_routes
virtual_hosts:
@ -162,6 +173,88 @@ static_resources:
- name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
- name: arch_listener_llm
address:
socket_address:
address: 0.0.0.0
port_value: 12000
filter_chains:
- filters:
- name: envoy.filters.network.http_connection_manager
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
# Tracing is optional: check that the key exists before comparing it, so the
# template still renders when no tracing section is configured. This is the same
# guard used for the opentelemetry_collector cluster.
{% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
generate_request_id: true
tracing:
  provider:
    name: envoy.tracers.opentelemetry
    typed_config:
      "@type": type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig
      grpc_service:
        envoy_grpc:
          cluster_name: opentelemetry_collector
        timeout: 0.250s
      service_name: arch
  random_sampling:
    value: {{ arch_tracing.random_sampling }}
{% endif %}
stat_prefix: arch_listener_http
codec_type: AUTO
scheme_header_transformation:
scheme_to_overwrite: https
access_log:
- name: envoy.access_loggers.file
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
path: "/var/log/access_llm.log"
route_config:
name: local_routes
virtual_hosts:
- name: local_service
domains:
- "*"
routes:
{% for provider in arch_llm_providers %}
- match:
prefix: "/"
headers:
- name: "x-arch-llm-provider"
string_match:
exact: {{ provider.name }}
route:
auto_host_rewrite: true
cluster: {{ provider.provider }}
timeout: 60s
{% endfor %}
- match:
prefix: "/"
direct_response:
status: 400
body:
inline_string: "x-arch-llm-provider header not set, cannot perform routing\n"
http_filters:
- name: envoy.filters.http.wasm
typed_config:
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
value:
config:
name: "http_config"
root_id: llm_gateway
configuration:
"@type": "type.googleapis.com/google.protobuf.StringValue"
value: |
{{ arch_llm_config | indent(32) }}
vm_config:
runtime: "envoy.wasm.runtime.v8"
code:
local:
filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm"
- name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
clusters:
- name: openai
connect_timeout: 5s
@ -289,6 +382,22 @@ static_resources:
port_value: 11000
hostname: arch_internal
# Upstream cluster targeted by the "x-arch-upstream: arch_llm_listener" route.
# Its single endpoint uses port 12000, the port the arch_listener_llm listener
# binds to.
- name: arch_llm_listener
connect_timeout: 5s
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: arch_llm_listener
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
# NOTE(review): 0.0.0.0 is a bind-all address used here as a destination;
# confirm it reaches the local listener on every target platform.
address: 0.0.0.0
port_value: 12000
hostname: arch_llm_listener
{% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
- name: opentelemetry_collector
type: STRICT_DNS

View file

@ -18,3 +18,4 @@ pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B";
pub const REQUEST_ID_HEADER: &str = "x-request-id";
pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal";
pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream";
pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";

View file

@ -11,7 +11,9 @@ use log::debug;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::EmbeddingType;
use public_types::configuration::{Configuration, Overrides, PromptGuards, PromptTarget};
use public_types::configuration::{
Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget,
};
use public_types::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
@ -53,6 +55,7 @@ pub struct FilterContext {
overrides: Rc<Option<Overrides>>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
mode: GatewayMode,
prompt_guards: Rc<PromptGuards>,
llm_providers: Option<Rc<LlmProviders>>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
@ -68,8 +71,9 @@ impl FilterContext {
prompt_targets: Rc::new(HashMap::new()),
overrides: Rc::new(None),
prompt_guards: Rc::new(PromptGuards::default()),
mode: GatewayMode::Prompt,
llm_providers: None,
embeddings_store: None,
embeddings_store: Some(Rc::new(HashMap::new())),
temp_embeddings_store: HashMap::new(),
}
}
@ -253,6 +257,7 @@ impl RootContext for FilterContext {
}
self.system_prompt = Rc::new(config.system_prompt);
self.prompt_targets = Rc::new(prompt_targets);
self.mode = config.mode.unwrap_or_default();
ratelimit::ratelimits(config.ratelimits);
@ -275,8 +280,15 @@ impl RootContext for FilterContext {
);
// No StreamContext can be created until the Embedding Store is fully initialized.
self.embeddings_store.as_ref()?;
let embedding_store;
match self.mode {
GatewayMode::Llm => {
embedding_store = None;
}
GatewayMode::Prompt => {
embedding_store = Some(Rc::clone(self.embeddings_store.as_ref().unwrap()))
}
}
Some(Box::new(StreamContext::new(
context_id,
Rc::clone(&self.metrics),
@ -289,11 +301,8 @@ impl RootContext for FilterContext {
.as_ref()
.expect("LLM Providers must exist when Streams are being created"),
),
Rc::clone(
self.embeddings_store
.as_ref()
.expect("Embeddings Store must exist when StreamContext is being constructed"),
),
embedding_store,
self.mode.clone(),
)))
}
@ -307,7 +316,11 @@ impl RootContext for FilterContext {
}
fn on_tick(&mut self) {
self.process_prompt_targets();
debug!("starting up arch filter in mode: {:?}", self.mode);
if self.mode == GatewayMode::Prompt {
self.process_prompt_targets();
}
self.set_tick_period(Duration::from_secs(0));
}
}

View file

@ -1,9 +1,11 @@
use std::rc::Rc;
use crate::llm_providers::LlmProviders;
use log::debug;
use public_types::configuration::LlmProvider;
use rand::{seq::IteratorRandom, thread_rng};
#[derive(Debug)]
pub enum ProviderHint {
Default,
Name(String),
@ -32,6 +34,12 @@ pub fn get_llm_provider(
return provider;
}
if llm_providers.default().is_some() {
debug!("no llm provider found for hint, using default llm provider");
return llm_providers.default().unwrap();
}
debug!("no default llm found, using random llm provider");
let mut rng = thread_rng();
llm_providers
.iter()

View file

@ -1,8 +1,9 @@
use crate::consts::{
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY,
ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER,
ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH,
DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
@ -26,7 +27,7 @@ use public_types::common_types::{
PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest,
ZeroShotClassificationResponse,
};
use public_types::configuration::LlmProvider;
use public_types::configuration::{GatewayMode, LlmProvider};
use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
use public_types::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@ -93,7 +94,7 @@ pub struct StreamContext {
metrics: Rc<WasmMetrics>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
embeddings_store: Rc<EmbeddingsStore>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
overrides: Rc<Option<Overrides>>,
callouts: RefCell<HashMap<u32, StreamCallContext>>,
tool_calls: Option<Vec<ToolCall>>,
@ -110,6 +111,7 @@ pub struct StreamContext {
llm_providers: Rc<LlmProviders>,
llm_provider: Option<Rc<LlmProvider>>,
request_id: Option<String>,
mode: GatewayMode,
}
impl StreamContext {
@ -122,7 +124,8 @@ impl StreamContext {
prompt_guards: Rc<PromptGuards>,
overrides: Rc<Option<Overrides>>,
llm_providers: Rc<LlmProviders>,
embeddings_store: Rc<EmbeddingsStore>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
mode: GatewayMode,
) -> Self {
StreamContext {
context_id,
@ -146,6 +149,7 @@ impl StreamContext {
prompt_guards,
overrides,
request_id: None,
mode,
}
}
fn llm_provider(&self) -> &LlmProvider {
@ -154,19 +158,35 @@ impl StreamContext {
.expect("the provider should be set when asked for it")
}
/// Borrows the embeddings store attached to this stream.
///
/// Panics with "embeddings store is not set" when the stream was created
/// without a store.
fn embeddings_store(&self) -> &EmbeddingsStore {
    let store = self.embeddings_store.as_deref();
    store.expect("embeddings store is not set")
}
/// Picks the LLM provider for this request and stores it in `self.llm_provider`.
///
/// The optional provider-hint request header is converted into a hint and handed
/// to `routing::get_llm_provider` together with the configured providers. Both
/// the hint and the selected provider's name are logged at debug level.
fn select_llm_provider(&mut self) {
    let hint_header = self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER);
    let provider_hint = hint_header.map(|name| name.into());
    debug!("llm provider hint: {:?}", provider_hint);

    let selected = routing::get_llm_provider(&self.llm_providers, provider_hint);
    debug!("selected llm: {}", selected.name);
    self.llm_provider = Some(selected);
}
fn add_routing_header(&mut self) {
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
match self.mode {
GatewayMode::Prompt => {
// in prompt gateway mode, we need to route to llm upstream listener
self.add_http_request_header(ARCH_UPSTREAM_HOST_HEADER, ARCH_LLM_UPSTREAM_LISTENER);
}
_ => {
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
}
}
}
fn modify_auth_headers(&mut self) -> Result<(), ServerError> {
@ -247,7 +267,7 @@ impl StreamContext {
// exclude default prompt target
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
.map(|(prompt_name, _)| {
let pte = match self.embeddings_store.get(prompt_name) {
let pte = match self.embeddings_store().get(prompt_name) {
Some(embeddings) => embeddings,
None => {
warn!(
@ -901,32 +921,37 @@ impl StreamContext {
debug!("arch => openai request body: {}", json_string);
// Tokenize and Ratelimit.
if let Some(selector) = self.ratelimit_selector.take() {
if let Ok(token_count) =
tokenizer::token_count(&chat_completions_request.model, &json_string)
{
match ratelimit::ratelimits(None).read().unwrap().check_limit(
chat_completions_request.model,
selector,
NonZero::new(token_count as u32).unwrap(),
) {
Ok(_) => (),
Err(err) => {
self.send_server_error(
ServerError::ExceededRatelimit(err),
Some(StatusCode::TOO_MANY_REQUESTS),
);
self.metrics.ratelimited_rq.increment(1);
return;
}
}
}
if let Err(e) = self.enforce_ratelimits(&chat_completions_request.model, &json_string) {
self.send_server_error(
ServerError::ExceededRatelimit(e),
Some(StatusCode::TOO_MANY_REQUESTS),
);
self.metrics.ratelimited_rq.increment(1);
return;
}
self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes());
self.resume_http_request();
}
/// Applies token-based rate limiting for `model` to the serialized request body.
///
/// Takes the pending `ratelimit_selector` (leaving `None` behind), counts the
/// tokens in `json_string`, and checks that count against the shared rate limits.
///
/// Returns `Ok(())` when the request is within limits, and also when there is
/// nothing to check: no selector is set, tokenization fails, or the body
/// tokenizes to zero tokens. Otherwise returns the error from `check_limit`.
fn enforce_ratelimits(
    &mut self,
    model: &str,
    json_string: &str,
) -> Result<(), ratelimit::Error> {
    let Some(selector) = self.ratelimit_selector.take() else {
        return Ok(());
    };
    // Tokenization is best-effort: a failure skips rate limiting instead of
    // failing the request.
    let Ok(token_count) = tokenizer::token_count(model, json_string) else {
        return Ok(());
    };
    // `NonZero::new` yields `None` for 0; skip the check rather than panic.
    let Some(tokens) = NonZero::new(token_count as u32) else {
        return Ok(());
    };
    ratelimit::ratelimits(None)
        .read()
        .unwrap()
        .check_limit(model.to_owned(), selector, tokens)?;
    Ok(())
}
fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
debug!("response received for arch guard");
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
@ -1140,6 +1165,41 @@ impl HttpContext for StreamContext {
return Action::Pause;
}
};
self.is_chat_completions_request = true;
// LLM gateway mode: no prompt-target processing. The body is normalized for the
// selected provider, rate limits are enforced, and the request is forwarded.
if self.mode == GatewayMode::Llm {
    debug!("llm gateway mode, skipping over all prompt targets");

    // remove metadata from the request body
    deserialized_body.metadata = None;
    // delete model key from message array
    for message in deserialized_body.messages.iter_mut() {
        message.model = None;
    }
    // Overwrite the requested model with the selected provider's model.
    deserialized_body
        .model
        .clone_from(&self.llm_provider.as_ref().unwrap().model);
    let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();

    // enforce ratelimits
    if let Err(e) =
        self.enforce_ratelimits(&deserialized_body.model, &chat_completion_request_str)
    {
        self.send_server_error(
            ServerError::ExceededRatelimit(e),
            Some(StatusCode::TOO_MANY_REQUESTS),
        );
        self.metrics.ratelimited_rq.increment(1);
        // An error response has been issued for this request; pause (as the
        // earlier error path in this handler does) so the rejected request is
        // not also forwarded upstream.
        return Action::Pause;
    }

    debug!(
        "arch => {:?}, body: {}",
        deserialized_body.model, chat_completion_request_str
    );
    self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
    return Action::Continue;
}
self.arch_state = match deserialized_body.metadata {
Some(ref metadata) => {
@ -1154,7 +1214,6 @@ impl HttpContext for StreamContext {
None => None,
};
self.is_chat_completions_request = true;
// Set the model based on the chosen LLM Provider
deserialized_body.model = String::from(&self.llm_provider().model);

View file

@ -33,6 +33,12 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
Some("x-arch-llm-provider-hint"),
)
.returning(Some("default"))
.expect_log(Some(LogLevel::Debug), None)
.expect_add_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-upstream"),
Some("arch_llm_listener"),
)
.expect_add_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider"),
@ -267,6 +273,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
module
.call_proxy_on_tick(filter_context)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![

View file

@ -76,9 +76,12 @@ def validate_and_render_schema():
arch_llm_providers = config_yaml["llm_providers"]
arch_tracing = config_yaml.get("tracing", {})
arch_config_string = yaml.dump(config_yaml)
config_yaml["mode"] = "llm"
arch_llm_config_string = yaml.dump(config_yaml)
data = {
"arch_config": arch_config_string,
"arch_llm_config": arch_llm_config_string,
"arch_clusters": inferred_clusters,
"arch_llm_providers": arch_llm_providers,
"arch_tracing": arch_tracing,

View file

@ -11,6 +11,22 @@
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"env": {
"LLM": "1",
"CHAT_COMPLETION_ENDPOINT": "http://localhost:10000/v1"
}
},
{
"name": "chatbot-ui llm",
"cwd": "${workspaceFolder}/app",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"env": {
"LLM": "1",
"CHAT_COMPLETION_ENDPOINT": "http://localhost:12000/v1"
}
},
{
"name": "chatbot-ui streaming",
@ -19,6 +35,10 @@
"request": "launch",
"program": "run_stream.py",
"console": "integratedTerminal",
"env": {
"LLM": "1",
"CHAT_COMPLETION_ENDPOINT": "http://localhost:10000/v1"
}
}
]
}

View file

@ -7,16 +7,12 @@ from dotenv import load_dotenv
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
ARCH_STATE_HEADER = "x-arch-state"
# Pass the value through a %s placeholder; without one the extra argument has
# nowhere to go and the endpoint is never logged.
# NOTE(review): assumes `log` follows the stdlib logging call convention — confirm.
log.info("CHAT_COMPLETION_ENDPOINT: %s", CHAT_COMPLETION_ENDPOINT)
client = OpenAI(
api_key=OPENAI_API_KEY,
api_key="--",
base_url=CHAT_COMPLETION_ENDPOINT,
http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}),
)
@ -31,8 +27,6 @@ def predict(message, state):
# Custom headers
custom_headers = {
"x-arch-openai-api-key": f"{OPENAI_API_KEY}",
"x-arch-mistral-api-key": f"{MISTRAL_API_KEY}",
"x-arch-deterministic-provider": "openai",
}
@ -42,7 +36,7 @@ def predict(message, state):
try:
raw_response = client.chat.completions.with_raw_response.create(
model=MODEL_NAME,
model="--",
messages=history,
temperature=1.0,
metadata=metadata,

View file

@ -16,11 +16,15 @@ overrides:
prompt_target_intent_matching_threshold: 0.6
llm_providers:
- name: open-ai-gpt-4
- name: gpt-4o
access_key: OPENAI_API_KEY
provider: openai
model: gpt-4o
default: true
- name: mistral-large-latest
access_key: MISTRAL_API_KEY
provider: mistral
model: mistral-large-latest
system_prompt: |
You are a helpful assistant.

View file

@ -15,7 +15,7 @@ services:
context: ../../chatbot_ui
dockerfile: Dockerfile
ports:
- "18090:8080"
- "18080:8080"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
- MISTRAL_API_KEY=${MISTRAL_API_KEY:?error}

View file

@ -8,6 +8,10 @@
"name": "arch",
"path": "arch"
},
{
"name": "arch/tools",
"path": "arch/tools"
},
{
"name": "model_server",
"path": "model_server"

View file

@ -13,6 +13,20 @@ pub struct Tracing {
pub sampling_rate: Option<f64>,
}
/// Operating mode of the gateway filter.
///
/// Serialized as the strings `"llm"` and `"prompt"`. `Prompt` is the default,
/// which is what a configuration without a `mode` key resolves to via
/// `unwrap_or_default()`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum GatewayMode {
    #[serde(rename = "llm")]
    Llm,
    #[default]
    #[serde(rename = "prompt")]
    Prompt,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
@ -26,6 +40,7 @@ pub struct Configuration {
pub error_target: Option<ErrorTargetDetail>,
pub ratelimits: Option<Vec<Ratelimit>>,
pub tracing: Option<Tracing>,
pub mode: Option<GatewayMode>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -283,5 +298,11 @@ mod test {
let tracing = config.tracing.as_ref().unwrap();
assert_eq!(tracing.sampling_rate.unwrap(), 0.1);
let mode = config
.mode
.as_ref()
.unwrap_or(&super::GatewayMode::Prompt);
assert_eq!(*mode, super::GatewayMode::Prompt);
}
}