diff --git a/api_llm_gateway.rest b/api_llm_gateway.rest
new file mode 100644
index 00000000..b40c229b
--- /dev/null
+++ b/api_llm_gateway.rest
@@ -0,0 +1,76 @@
+@llm_endpoint = http://localhost:12000
+@openai_endpoint = https://api.openai.com
+@access_key = {{$dotenv OPENAI_API_KEY}}
+
+### openai request
+POST {{openai_endpoint}}/v1/chat/completions HTTP/1.1
+Content-Type: application/json
+Authorization: Bearer {{access_key}}
+
+{
+ "messages": [
+ {
+ "role": "user",
+ "content": "hello"
+ }
+ ],
+ "model": "gpt-4o-mini"
+}
+
+### openai request (streaming)
+POST {{openai_endpoint}}/v1/chat/completions HTTP/1.1
+Content-Type: application/json
+Authorization: Bearer {{access_key}}
+
+{
+ "messages": [
+ {
+ "role": "user",
+ "content": "hello"
+ }
+ ],
+ "model": "gpt-4o-mini",
+ "stream": true
+}
+
+
+### llm gateway request
+POST {{llm_endpoint}}/v1/chat/completions HTTP/1.1
+Content-Type: application/json
+
+{
+ "messages": [
+ {
+ "role": "user",
+ "content": "hello"
+ }
+ ]
+}
+
+### llm gateway request (streaming)
+POST {{llm_endpoint}}/v1/chat/completions HTTP/1.1
+Content-Type: application/json
+
+{
+ "messages": [
+ {
+ "role": "user",
+ "content": "hello"
+ }
+ ],
+ "stream": true
+}
+
+### llm gateway request (provider hint)
+POST {{llm_endpoint}}/v1/chat/completions HTTP/1.1
+Content-Type: application/json
+x-arch-llm-provider-hint: gpt-3.5-turbo-0125
+
+{
+ "messages": [
+ {
+ "role": "user",
+ "content": "hello"
+ }
+ ]
+}
diff --git a/api_model_server.rest b/api_model_server.rest
new file mode 100644
index 00000000..9102786a
--- /dev/null
+++ b/api_model_server.rest
@@ -0,0 +1,44 @@
+@model_server_endpoint = http://localhost:51000
+@archfc_endpoint = https://api.fc.archgw.com
+
+### talk to model_server for completion
+POST {{model_server_endpoint}}/v1/chat/completions HTTP/1.1
+Content-Type: application/json
+
+{
+ "messages": [
+ {
+ "role": "user",
+ "content": "how is the weather in seattle for next 10 days"
+ }
+ ],
+ "tools": [
+ {
+ "id": "weather-112",
+ "tool_type": "function",
+ "function": {
+ "name": "weather_forecast",
+ "arguments": {"city": "str", "days": "int"}
+ }
+ }
+ ]
+}
+
+
+### talk to arch_fc directly for completion
+POST {{archfc_endpoint}}/v1/chat/completions HTTP/1.1
+Content-Type: application/json
+
+{
+ "model": "Arch-Function",
+ "messages": [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"id\": \"weather-112\", \"tool_type\": \"function\", \"function\": {\"name\": \"weather_forecast\", \"arguments\": {\"city\": \"str\", \"days\": \"int\"}}}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n"
+ },
+ { "role": "user", "content": "how is the weather in seattle?" },
+ { "role": "assistant", "content": "Of course! " }
+ ],
+ "continue_final_message": true,
+ "add_generation_prompt": false
+}
diff --git a/api_prompt_gateway.rest b/api_prompt_gateway.rest
new file mode 100644
index 00000000..b79b4230
--- /dev/null
+++ b/api_prompt_gateway.rest
@@ -0,0 +1,87 @@
+@prompt_endpoint = http://localhost:10000
+
+### prompt gateway request
+POST {{prompt_endpoint}}/v1/chat/completions HTTP/1.1
+Content-Type: application/json
+
+{
+ "messages": [
+ {
+ "role": "user",
+ "content": "how is the weather in seattle for next 10 days"
+ }
+ ]
+}
+
+### prompt gateway request (streaming)
+POST {{prompt_endpoint}}/v1/chat/completions HTTP/1.1
+Content-Type: application/json
+
+{
+ "messages": [
+ {
+ "role": "user",
+ "content": "how is the weather in seattle for next 10 days"
+ }
+ ],
+ "stream": true
+}
+
+
+### prompt gateway request param gathering
+POST {{prompt_endpoint}}/v1/chat/completions HTTP/1.1
+Content-Type: application/json
+
+{
+ "messages": [
+ {
+ "role": "user",
+ "content": "how is the weather in seattle"
+ }
+ ]
+}
+
+### prompt gateway request param gathering and function calling
+POST {{prompt_endpoint}}/v1/chat/completions HTTP/1.1
+Content-Type: application/json
+
+{
+ "messages": [
+ {
+ "role": "user",
+ "content": "how is the weather in seattle"
+ },
+ {
+ "role": "assistant",
+ "content": "It seems I'm missing some information. Could you provide the following details days ?",
+ "model": "Arch-Function-1.5b"
+ },
+ {
+ "role": "user",
+ "content": "for next 10 days"
+ }
+ ]
+}
+
+### prompt gateway request param gathering and function calling (streaming)
+POST {{prompt_endpoint}}/v1/chat/completions HTTP/1.1
+Content-Type: application/json
+
+{
+ "messages": [
+ {
+ "role": "user",
+ "content": "how is the weather in seattle"
+ },
+ {
+ "role": "assistant",
+ "content": "It seems I'm missing some information. Could you provide the following details days ?",
+ "model": "Arch-Function-1.5b"
+ },
+ {
+ "role": "user",
+ "content": "for next 10 days"
+ }
+ ],
+ "stream": true
+}
diff --git a/crates/prompt_gateway/src/filter_context.rs b/crates/prompt_gateway/src/filter_context.rs
index de120369..0b44fac9 100644
--- a/crates/prompt_gateway/src/filter_context.rs
+++ b/crates/prompt_gateway/src/filter_context.rs
@@ -11,7 +11,8 @@ use common::http::CallArgs;
use common::http::Client;
use common::stats::Gauge;
use common::stats::IncrementingMetric;
-use log::debug;
+use http::StatusCode;
+use log::{debug, info, trace, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::cell::RefCell;
@@ -53,6 +54,7 @@ pub struct FilterContext {
prompt_guards: Rc,
embeddings_store: Option>,
temp_embeddings_store: EmbeddingsStore,
+ active_embedding_calls_count: u32,
}
impl FilterContext {
@@ -66,22 +68,26 @@ impl FilterContext {
prompt_guards: Rc::new(PromptGuards::default()),
embeddings_store: Some(Rc::new(HashMap::new())),
temp_embeddings_store: HashMap::new(),
+ active_embedding_calls_count: 0,
}
}
- fn process_prompt_targets(&self) {
- for values in self.prompt_targets.iter() {
- let prompt_target = values.1;
- self.schedule_embeddings_call(
- &prompt_target.name,
- &prompt_target.description,
- EmbeddingType::Description,
- );
- }
+ /// Schedules one description-embedding callout for every configured prompt target.
+ fn process_prompt_targets(&mut self) {
+ // Copy the (name, description) pairs out first: scheduling a callout needs
+ // `&mut self`, which cannot coexist with a live borrow of `self.prompt_targets`.
+ let pending: Vec<(String, String)> = self
+ .prompt_targets
+ .iter()
+ .map(|(target_name, target)| (target_name.clone(), target.description.clone()))
+ .collect();
+
+ for (target_name, target_description) in &pending {
+ self.schedule_embeddings_call(
+ target_name,
+ target_description,
+ EmbeddingType::Description,
+ );
+ }
}
fn schedule_embeddings_call(
- &self,
+ &mut self,
prompt_target_name: &str,
input: &str,
embedding_type: EmbeddingType,
@@ -116,6 +122,7 @@ impl FilterContext {
embedding_type,
};
+ self.active_embedding_calls_count += 1;
if let Err(error) = self.http_call(call_args, call_context) {
panic!("{error}")
}
@@ -123,9 +130,9 @@ impl FilterContext {
fn embedding_response_handler(
&mut self,
- body_size: usize,
embedding_type: EmbeddingType,
prompt_target_name: String,
+ body: Vec,
) {
let prompt_target = self
.prompt_targets
@@ -137,9 +144,6 @@ impl FilterContext {
)
});
- let body = self
- .get_http_call_response_body(0, body_size)
- .expect("No body in response");
if !body.is_empty() {
let mut embedding_response: CreateEmbeddingResponse =
match serde_json::from_slice(&body) {
@@ -208,7 +212,7 @@ impl Context for FilterContext {
body_size: usize,
_num_trailers: usize,
) {
- debug!(
+ trace!(
"filter_context: on_http_call_response called with token_id: {:?}",
token_id
);
@@ -218,13 +222,26 @@ impl Context for FilterContext {
.remove(&token_id)
.expect("invalid token_id");
+ // This callout is no longer in flight; on_tick only re-schedules embedding
+ // requests when this counter is back at zero.
+ self.active_embedding_calls_count -= 1;
self.metrics.active_http_calls.increment(-1);
+ // The body is optional (e.g. an error response without a payload); fall back
+ // to an empty buffer instead of panicking so the retry path stays alive.
+ let body_bytes = self
+ .get_http_call_response_body(0, body_size)
+ .unwrap_or_default();
- self.embedding_response_handler(
- body_size,
- callout_data.embedding_type,
- callout_data.prompt_target_name,
- )
+ if let Some(status_code) = self.get_http_call_response_header(":status") {
+ if status_code == StatusCode::OK.as_str() {
+ self.embedding_response_handler(
+ callout_data.embedding_type,
+ callout_data.prompt_target_name,
+ body_bytes,
+ );
+ } else {
+ // Error bodies are not guaranteed to be valid UTF-8; decode lossily so
+ // that logging a failure can never itself panic.
+ warn!(
+ "Received non-200 status code: {} for callout with token_id: {}: body_str: {}",
+ status_code,
+ token_id,
+ String::from_utf8_lossy(&body_bytes)
+ );
+ }
+ } else {
+ // Without a status the response cannot be trusted; record it rather than
+ // dropping it silently.
+ warn!(
+ "Missing :status header for callout with token_id: {}",
+ token_id
+ );
+ }
}
}
@@ -262,10 +279,7 @@ impl RootContext for FilterContext {
context_id
);
- let embedding_store = match self.embeddings_store.as_ref() {
- None => return None,
- Some(store) => Some(Rc::clone(store)),
- };
+ let embedding_store = self.embeddings_store.as_ref().map(Rc::clone);
Some(Box::new(StreamContext::new(
context_id,
Rc::clone(&self.metrics),
@@ -287,8 +301,20 @@ impl RootContext for FilterContext {
}
fn on_tick(&mut self) {
- debug!("starting up arch filter in mode: prompt gateway mode");
- self.process_prompt_targets();
- self.set_tick_period(Duration::from_secs(0));
+ // The store counts as initialized once it exists and holds as many entries
+ // as there are configured prompt targets.
+ let store_ready = self
+ .embeddings_store
+ .as_ref()
+ .is_some_and(|store| store.len() == self.prompt_targets.len());
+
+ if store_ready {
+ info!("embeddings store initialized");
+ // NOTE(review): a zero period presumably disables further ticks — confirm
+ // against the proxy-wasm SDK documentation.
+ self.set_tick_period(Duration::from_secs(0));
+ } else {
+ // Only (re)issue embedding requests when none are still in flight;
+ // otherwise keep waiting for the outstanding responses.
+ if self.active_embedding_calls_count == 0 {
+ info!("retrieving embeddings from embedding server");
+ self.process_prompt_targets();
+ } else {
+ info!("waiting for embeddings store to be initialized");
+ }
+
+ // Check again in five seconds.
+ self.set_tick_period(Duration::from_secs(5));
+ }
}
}
diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs
index 0ea27cfb..ed5a23c8 100644
--- a/crates/prompt_gateway/src/http_context.rs
+++ b/crates/prompt_gateway/src/http_context.rs
@@ -37,10 +37,10 @@ impl HttpContext for StreamContext {
let request_path = self.get_http_request_header(":path").unwrap_or_default();
if request_path == HEALTHZ_PATH {
- if self.embeddings_store.is_none() {
- self.send_http_response(503, vec![], None);
- } else {
+ if self.is_embedding_store_initialized() {
self.send_http_response(200, vec![], None);
+ } else {
+ self.send_http_response(503, vec![], None);
}
return Action::Continue;
}
diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs
index c068072b..c7865a75 100644
--- a/crates/prompt_gateway/src/stream_context.rs
+++ b/crates/prompt_gateway/src/stream_context.rs
@@ -61,7 +61,7 @@ pub struct StreamCallContext {
pub struct StreamContext {
system_prompt: Rc