diff --git a/README.md b/README.md
index 921d309f..c5762f52 100644
--- a/README.md
+++ b/README.md
@@ -1,19 +1,21 @@
-Focus on what matters most. Arch is an **intelligent proxy server designed for prompts** - to help you protect, observe, and build agentic apps by simply connecting (existing) APIs.
+
+
+
+
+
+
+
+Arch is an **intelligent (edge and LLM) proxy designed for agentic applications** - to help you protect, observe, and build agentic tasks by simply connecting (existing) APIs.
Built by the contributors of [Envoy Proxy](https://www.envoyproxy.io/) with the belief that:
>Prompts are nuanced and opaque user requests, which require the same capabilities as traditional HTTP requests including secure handling, intelligent routing, robust observability, and integration with backend (API) systems for personalization – outside core business logic.*
-
-
-
-
[](https://github.com/katanemo/arch/actions/workflows/pre-commit.yml)
[](https://github.com/katanemo/arch/actions/workflows/rust_tests.yml)
[](https://github.com/katanemo/arch/actions/workflows/e2e_tests.yml)
[](https://github.com/katanemo/arch/actions/workflows/static.yml)
-
Arch is engineered with purpose-built LLMs to handle critical but undifferentiated tasks related to the handling and processing of prompts. This includes detecting and rejecting [jailbreak](https://github.com/verazuo/jailbreak_llms) attempts, intelligent task routing for improved accuracy, mapping user request into "backend" functions, and managing the observability of prompts and LLM API calls in a centralized way.
@@ -24,7 +26,7 @@ Arch is engineered with purpose-built LLMs to handle critical but undifferentiat
- Routing & Traffic Management: Arch centralizes calls to LLMs used by your applications, offering smart retries, automatic cutover, and resilient upstream connections for continuous availability.
- Observability: Arch uses the W3C Trace Context standard to enable complete request tracing across applications, ensuring compatibility with observability tools, and provides metrics to monitor latency, token usage, and error rates, helping optimize AI application performance.
-**High-Level Network Flow**:
+**High-Level Sequence Diagram**:

**Jump to our [docs](https://docs.archgw.com)** to learn how you can use Arch to improve the speed, security and personalization of your GenAI apps.
diff --git a/crates/common/src/routing.rs b/crates/common/src/routing.rs
index 1a440ee9..f4baf896 100644
--- a/crates/common/src/routing.rs
+++ b/crates/common/src/routing.rs
@@ -2,7 +2,6 @@ use std::rc::Rc;
use crate::{configuration, llm_providers::LlmProviders};
use configuration::LlmProvider;
-use log::debug;
use rand::{seq::IteratorRandom, thread_rng};
#[derive(Debug)]
@@ -35,11 +34,9 @@ pub fn get_llm_provider(
}
if llm_providers.default().is_some() {
- debug!("no llm provider found for hint, using default llm provider");
return llm_providers.default().unwrap();
}
- debug!("no default llm found, using random llm provider");
let mut rng = thread_rng();
llm_providers
.iter()
diff --git a/crates/common/src/tokenizer.rs b/crates/common/src/tokenizer.rs
index aa0870f2..c424e344 100644
--- a/crates/common/src/tokenizer.rs
+++ b/crates/common/src/tokenizer.rs
@@ -1,4 +1,4 @@
-use log::debug;
+use log::trace;
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
#[allow(dead_code)]
@@ -9,7 +9,7 @@ pub enum Error {
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result {
- debug!("getting token count model={}", model_name);
+ trace!("getting token count model={}", model_name);
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel {
model_name: model_name.to_string(),
diff --git a/crates/llm_gateway/src/filter_context.rs b/crates/llm_gateway/src/filter_context.rs
index 16580063..0edba456 100644
--- a/crates/llm_gateway/src/filter_context.rs
+++ b/crates/llm_gateway/src/filter_context.rs
@@ -9,7 +9,7 @@ use common::llm_providers::LlmProviders;
use common::ratelimit;
use common::stats::Gauge;
use common::tracing::TraceData;
-use log::debug;
+use log::trace;
use log::warn;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
@@ -79,7 +79,7 @@ impl RootContext for FilterContext {
}
fn create_http_context(&self, context_id: u32) -> Option> {
- debug!(
+ trace!(
"||| create_http_context called with context_id: {:?} |||",
context_id
);
@@ -108,10 +108,8 @@ impl RootContext for FilterContext {
fn on_tick(&mut self) {
let _ = self.traces_queue.try_lock().map(|mut traces_queue| {
while let Some(trace) = traces_queue.pop_front() {
- debug!("trace received: {:?}", trace);
-
let trace_str = serde_json::to_string(&trace).unwrap();
- debug!("trace: {}", trace_str);
+ trace!("trace details: {}", trace_str);
let call_args = CallArgs::new(
OTEL_COLLECTOR_HTTP,
OTEL_POST_PATH,
@@ -144,7 +142,7 @@ impl Context for FilterContext {
_body_size: usize,
_num_trailers: usize,
) {
- debug!(
+ trace!(
"||| on_http_call_response called with token_id: {:?} |||",
token_id
);
@@ -156,7 +154,7 @@ impl Context for FilterContext {
.expect("invalid token_id");
if let Some(status) = self.get_http_call_response_header(":status") {
- debug!("trace response status: {:?}", status);
+ trace!("trace response status: {:?}", status);
};
}
}
diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs
index 6939f1d8..528358a3 100644
--- a/crates/llm_gateway/src/stream_context.rs
+++ b/crates/llm_gateway/src/stream_context.rs
@@ -10,7 +10,6 @@ use common::consts::{
};
use common::errors::ServerError;
use common::llm_providers::LlmProviders;
-use common::pii::obfuscate_auth_header;
use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
@@ -82,12 +81,16 @@ impl StreamContext {
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
.map(|llm_name| llm_name.into());
- debug!("llm provider hint: {:?}", provider_hint);
self.llm_provider = Some(routing::get_llm_provider(
&self.llm_providers,
provider_hint,
));
- debug!("selected llm: {}", self.llm_provider.as_ref().unwrap().name);
+
+ debug!(
+ "request received: llm provider hint: {:?}, selected llm: {}",
+ self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER),
+ self.llm_provider.as_ref().unwrap().name
+ );
}
fn modify_auth_headers(&mut self) -> Result<(), ServerError> {
@@ -150,7 +153,7 @@ impl StreamContext {
self.metrics
.input_sequence_length
.record(token_count as u64);
- log::debug!("Recorded input token count: {}", token_count);
+ trace!("Recorded input token count: {}", token_count);
// Check if rate limiting needs to be applied.
if let Some(selector) = self.ratelimit_selector.take() {
@@ -161,7 +164,7 @@ impl StreamContext {
NonZero::new(token_count as u32).unwrap(),
)?;
} else {
- log::debug!("No rate limit applied for model: {}", model);
+ trace!("No rate limit applied for model: {}", model);
}
Ok(())
@@ -197,12 +200,6 @@ impl HttpContext for StreamContext {
self.is_chat_completions_request =
self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH;
- debug!(
- "on_http_request_headers S[{}] req_headers={:?}",
- self.context_id,
- obfuscate_auth_header(&mut self.get_http_request_headers())
- );
-
self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER);
@@ -310,9 +307,10 @@ impl HttpContext for StreamContext {
}
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
- debug!(
+ trace!(
"on_http_response_headers [S={}] end_stream={}",
- self.context_id, _end_of_stream
+ self.context_id,
+ _end_of_stream
);
self.set_property(
@@ -324,9 +322,11 @@ impl HttpContext for StreamContext {
}
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
- debug!(
+ trace!(
"on_http_response_body [S={}] bytes={} end_stream={}",
- self.context_id, body_size, end_of_stream
+ self.context_id,
+ body_size,
+ end_of_stream
);
if !self.is_chat_completions_request {
@@ -342,7 +342,7 @@ impl HttpContext for StreamContext {
Ok(duration) => {
// Convert the duration to milliseconds
let duration_ms = duration.as_millis();
- debug!("Total latency: {} milliseconds", duration_ms);
+ debug!("request latency: {}ms", duration_ms);
// Record the latency to the latency histogram
self.metrics.request_latency.record(duration_ms as u64);
@@ -350,11 +350,14 @@ impl HttpContext for StreamContext {
// Compute the time per output token
let tpot = duration_ms as u64 / self.response_tokens as u64;
- debug!("Time per output token: {} milliseconds", tpot);
// Record the time per output token
self.metrics.time_per_output_token.record(tpot);
- debug!("Tokens per second: {}", 1000 / tpot);
+ trace!(
+ "time per token: {}ms, tokens per second: {}",
+ tpot,
+ 1000 / tpot
+ );
// Record the tokens per second
self.metrics.tokens_per_second.record(1000 / tpot);
}
@@ -414,9 +417,10 @@ impl HttpContext for StreamContext {
let body = if self.streaming_response {
let chunk_start = 0;
let chunk_size = body_size;
- debug!(
+ trace!(
"streaming response reading, {}..{}",
- chunk_start, chunk_size
+ chunk_start,
+ chunk_size
);
let streaming_chunk = match self.get_http_response_body(0, chunk_size) {
Some(chunk) => chunk,
@@ -438,7 +442,7 @@ impl HttpContext for StreamContext {
}
streaming_chunk
} else {
- debug!("non streaming response bytes read: 0:{}", body_size);
+ trace!("non streaming response bytes read: 0:{}", body_size);
match self.get_http_response_body(0, body_size) {
Some(body) => body,
None => {
@@ -510,7 +514,7 @@ impl HttpContext for StreamContext {
match current_time.duration_since(self.start_time) {
Ok(duration) => {
let duration_ms = duration.as_millis();
- debug!("Time to First Token (TTFT): {} milliseconds", duration_ms);
+ debug!("time to first token: {}ms", duration_ms);
self.ttft_duration = Some(duration);
self.metrics.time_to_first_token.record(duration_ms as u64);
}
@@ -520,12 +524,15 @@ impl HttpContext for StreamContext {
}
}
} else {
- debug!("non streaming response");
+ trace!("non streaming response");
let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_str(body_utf8.as_str()) {
Ok(de) => de,
- Err(_e) => {
- debug!("invalid response: {}", body_utf8);
+ Err(err) => {
+ debug!(
+ "non chat-completion compliant response received err: {}, body: {}",
+ err, body_utf8
+ );
return Action::Continue;
}
};
@@ -539,9 +546,11 @@ impl HttpContext for StreamContext {
}
}
- debug!(
+ trace!(
"recv [S={}] total_tokens={} end_stream={}",
- self.context_id, self.response_tokens, end_of_stream
+ self.context_id,
+ self.response_tokens,
+ end_of_stream
);
Action::Continue
diff --git a/crates/llm_gateway/tests/integration.rs b/crates/llm_gateway/tests/integration.rs
index 7a74dfa8..0b28a175 100644
--- a/crates/llm_gateway/tests/integration.rs
+++ b/crates/llm_gateway/tests/integration.rs
@@ -22,12 +22,8 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider-hint"),
)
- .returning(Some("default"))
- .expect_log(
- Some(LogLevel::Debug),
- Some("llm provider hint: Some(Default)"),
- )
- .expect_log(Some(LogLevel::Debug), Some("selected llm: open-ai-gpt-4"))
+ .returning(None)
+ .expect_log(Some(LogLevel::Debug), Some("request received: llm provider hint: Some(\"default\"), selected llm: open-ai-gpt-4"))
.expect_add_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider"),
@@ -38,7 +34,11 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
Some("Authorization"),
Some("Bearer secret_key"),
)
- .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
+ .expect_get_header_map_value(
+ Some(MapType::HttpRequestHeaders),
+ Some("x-arch-llm-provider-hint"),
+ )
+ .returning(Some("default"))
.expect_get_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-ratelimit-selector"),
@@ -50,7 +50,6 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
.returning(Some("/v1/chat/completions"))
- .expect_log(Some(LogLevel::Debug), None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent"))
@@ -62,7 +61,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
module
.call_proxy_on_context_create(http_context, filter_context)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@@ -187,7 +186,10 @@ fn llm_gateway_successful_request_to_open_ai_chat_completions() {
module
.call_proxy_on_context_create(http_context, filter_context)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(
+ Some(LogLevel::Trace),
+ Some("||| create_http_context called with context_id: 2 |||"),
+ )
.execute_and_expect(ReturnType::None)
.unwrap();
@@ -218,9 +220,9 @@ fn llm_gateway_successful_request_to_open_ai_chat_completions() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Trace), None)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 21)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
@@ -251,7 +253,7 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
module
.call_proxy_on_context_create(http_context, filter_context)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@@ -339,9 +341,9 @@ fn llm_gateway_request_ratelimited() {
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Trace), None)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 107)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
@@ -405,9 +407,9 @@ fn llm_gateway_request_not_ratelimited() {
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Trace), None)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 29)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
diff --git a/crates/prompt_gateway/src/context.rs b/crates/prompt_gateway/src/context.rs
index 2b0f8d3f..89725e0d 100644
--- a/crates/prompt_gateway/src/context.rs
+++ b/crates/prompt_gateway/src/context.rs
@@ -3,7 +3,7 @@ use std::str::FromStr;
use common::errors::ServerError;
use common::stats::IncrementingMetric;
use http::StatusCode;
-use log::{debug, warn};
+use log::warn;
use proxy_wasm::traits::Context;
use crate::stream_context::{ResponseHandlerType, StreamContext};
@@ -25,27 +25,38 @@ impl Context for StreamContext {
let body = self
.get_http_call_response_body(0, body_size)
- .unwrap_or(vec![]);
+ .unwrap_or_default();
- let http_status = self
- .get_http_call_response_header(":status")
- .unwrap_or(StatusCode::OK.as_str().to_string());
- debug!("http call response code: {}", http_status);
- if http_status != StatusCode::OK.as_str() {
- let server_error = ServerError::Upstream {
- host: callout_context.upstream_cluster.unwrap(),
- path: callout_context.upstream_cluster_path.unwrap(),
- status: http_status.clone(),
- body: String::from_utf8(body).unwrap(),
- };
- warn!("filter received non 2xx code: {:?}", server_error);
- return self.send_server_error(
- server_error,
- Some(StatusCode::from_str(http_status.as_str()).unwrap()),
- );
+        if let Some(http_status) = self.get_http_call_response_header(":status") {
+            match StatusCode::from_str(http_status.as_str()) {
+                Ok(status_code) => {
+                    if !status_code.is_success() {
+                        let server_error = ServerError::Upstream {
+                            host: callout_context.upstream_cluster.unwrap(),
+                            path: callout_context.upstream_cluster_path.unwrap(),
+                            status: http_status.clone(),
+                            // Lossy decoding: a non-UTF-8 error body must not panic the filter.
+                            body: String::from_utf8_lossy(&body).into_owned(),
+                        };
+                        warn!("received non 2xx code: {:?}", server_error);
+                        // Forward the status parsed above instead of parsing it a second time.
+                        return self.send_server_error(server_error, Some(status_code));
+                    }
+                }
+                Err(_) => {
+                    // The header value is not a valid status code, so re-parsing it here
+                    // would panic on unwrap; report the malformed reply as a bad gateway.
+                    return self.send_server_error(
+                        ServerError::LogicError(format!("invalid status code: {}", http_status)),
+                        Some(StatusCode::BAD_GATEWAY),
+                    );
+                }
+            }
+        } else {
+            // :status header not found
+            warn!("missing :status header");
+        }
- debug!("http call response handler type: {:?}", callout_context.response_handler_type);
#[cfg_attr(any(), rustfmt::skip)]
match callout_context.response_handler_type {
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
diff --git a/crates/prompt_gateway/src/filter_context.rs b/crates/prompt_gateway/src/filter_context.rs
index 9780ed7d..f782dea2 100644
--- a/crates/prompt_gateway/src/filter_context.rs
+++ b/crates/prompt_gateway/src/filter_context.rs
@@ -3,7 +3,7 @@ use crate::stream_context::StreamContext;
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget, Tracing};
use common::http::Client;
use common::stats::Gauge;
-use log::debug;
+use log::trace;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::cell::RefCell;
@@ -84,7 +84,7 @@ impl RootContext for FilterContext {
}
fn create_http_context(&self, context_id: u32) -> Option> {
- debug!(
+ trace!(
"||| create_http_context called with context_id: {:?} |||",
context_id
);
diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs
index 76e137c1..e7d920f1 100644
--- a/crates/prompt_gateway/src/http_context.rs
+++ b/crates/prompt_gateway/src/http_context.rs
@@ -85,10 +85,7 @@ impl HttpContext for StreamContext {
}
};
- debug!(
- "developer => archgw: {}",
- String::from_utf8_lossy(&body_bytes)
- );
+ trace!("request body: {}", String::from_utf8_lossy(&body_bytes));
// Deserialize body into spec.
// Currently OpenAI API.
@@ -159,7 +156,8 @@ impl HttpContext for StreamContext {
}
};
- debug!("archgw => archfc: {}", json_data);
+ debug!("sending request to model server");
+ trace!("request body: {}", json_data);
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME),
diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs
index 98d230a5..41751312 100644
--- a/crates/prompt_gateway/src/stream_context.rs
+++ b/crates/prompt_gateway/src/stream_context.rs
@@ -14,7 +14,7 @@ use common::http::{CallArgs, Client};
use common::stats::Gauge;
use derivative::Derivative;
use http::StatusCode;
-use log::{debug, warn};
+use log::{debug, trace, warn};
use proxy_wasm::traits::*;
use serde_yaml::Value;
use std::cell::RefCell;
@@ -125,13 +125,14 @@ impl StreamContext {
mut callout_context: StreamCallContext,
) {
let body_str = String::from_utf8(body).unwrap();
- debug!("archgw <= archfc response: {}", body_str);
+ debug!("model server response received");
+ trace!("response body: {}", body_str);
let model_server_response: ModelServerResponse = match serde_json::from_str(&body_str) {
Ok(arch_fc_response) => arch_fc_response,
Err(e) => {
warn!(
- "error deserializing archfc response: {}, body: {}",
+ "error deserializing modelserver response: {}, body: {}",
e, body_str
);
return self.send_server_error(ServerError::Deserialization(e), None);
@@ -141,7 +142,7 @@ impl StreamContext {
let arch_fc_response = match model_server_response {
ModelServerResponse::ChatCompletionsResponse(response) => response,
ModelServerResponse::ModelServerErrorResponse(response) => {
- debug!("archgw <= archfc error response: {}", response.result);
+ debug!("archgw <= modelserver error response: {}", response.result);
if response.result == "No intent matched" {
if let Some(default_prompt_target) = self
.prompt_targets
@@ -263,7 +264,7 @@ impl StreamContext {
);
}
- // update prompt target name from the tool call
+ // update prompt target name from the tool call response
callout_context.prompt_target_name =
Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone());
@@ -352,11 +353,10 @@ impl StreamContext {
);
debug!(
- "archgw => api call, endpoint: {}{}, body: {}",
- endpoint.name.as_str(),
- path,
- tool_params_json_str
+ "dispatching api call to developer endpoint: {}, path: {}",
+ endpoint.name, path
);
+ trace!("request body: {}", tool_params_json_str);
callout_context.upstream_cluster = Some(endpoint.name.to_owned());
callout_context.upstream_cluster_path = Some(path.to_owned());
@@ -371,7 +371,10 @@ impl StreamContext {
let http_status = self
.get_http_call_response_header(":status")
.unwrap_or(StatusCode::OK.as_str().to_string());
- debug!("api_call_response_handler: http_status: {}", http_status);
+ debug!(
+ "developer api call response received: status code: {}",
+ http_status
+ );
if http_status != StatusCode::OK.as_str() {
warn!(
"api server responded with non 2xx status code: {}",
@@ -388,12 +391,12 @@ impl StreamContext {
);
}
self.tool_call_response = Some(String::from_utf8(body).unwrap());
- debug!(
- "archgw <= api call response: {}",
+ trace!(
+ "response body: {}",
self.tool_call_response.as_ref().unwrap()
);
- let mut messages = self.filter_out_arch_messages(&callout_context);
+ let mut messages = self.construct_llm_messages(&callout_context);
let user_message = match messages.pop() {
Some(user_message) => user_message,
@@ -439,7 +442,8 @@ impl StreamContext {
return self.send_server_error(ServerError::Serialization(e), None);
}
};
- debug!("archgw => llm request: {}", llm_request_str);
+ debug!("sending request to upstream llm");
+ trace!("request body: {}", llm_request_str);
self.start_upstream_llm_request_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
@@ -450,25 +454,39 @@ impl StreamContext {
self.resume_http_request();
}
- fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec {
- let mut messages: Vec = Vec::new();
- // add system prompt
+    /// Picks the system prompt for a request. A prompt target that carries
+    /// its own `system_prompt` wins; with no target, or a target that does not
+    /// define one, the value held in `self.system_prompt` is cloned and
+    /// returned instead.
+    fn get_system_prompt(&self, prompt_target: Option) -> Option {
+        prompt_target
+            .and_then(|target| target.system_prompt)
+            .or_else(|| self.system_prompt.as_ref().clone())
+    }
+    /// Returns clones of the messages that are not arch-internal. A message is
+    /// dropped when its role is `TOOL_ROLE`, when it has no content, or when it
+    /// carries a non-empty list of tool calls.
+    fn filter_out_arch_messages(&self, messages: &[Message]) -> Vec {
+        let is_arch_message = |m: &Message| {
+            m.role == TOOL_ROLE
+                || m.content.is_none()
+                || m.tool_calls.as_ref().map_or(false, |calls| !calls.is_empty())
+        };
+        messages.iter().filter(|&m| !is_arch_message(m)).cloned().collect()
+    }
+
+ fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec {
+ let mut messages: Vec = Vec::new();
+
+ // add system prompt
let system_prompt = match callout_context.prompt_target_name.as_ref() {
None => self.system_prompt.as_ref().clone(),
Some(prompt_target_name) => {
- let prompt_system_prompt = self
- .prompt_targets
- .get(prompt_target_name)
- .unwrap()
- .clone()
- .system_prompt;
- match prompt_system_prompt {
- None => self.system_prompt.as_ref().clone(),
- Some(system_prompt) => Some(system_prompt),
- }
+ self.get_system_prompt(self.prompt_targets.get(prompt_target_name).cloned())
}
};
+
if system_prompt.is_some() {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
@@ -480,18 +498,9 @@ impl StreamContext {
messages.push(system_prompt_message);
}
- // don't send tools message and api response to chat gpt
- for m in callout_context.request_body.messages.iter() {
- // don't send api response and tool calls to upstream LLMs
- if m.role == TOOL_ROLE
- || m.content.is_none()
- || (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
- {
- continue;
- }
- messages.push(m.clone());
- }
-
+ messages.append(
+ &mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
+ );
messages
}
diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs
index 1a6ed0e6..9fcaf74d 100644
--- a/crates/prompt_gateway/tests/integration.rs
+++ b/crates/prompt_gateway/tests/integration.rs
@@ -41,7 +41,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
module
.call_proxy_on_context_create(http_context, filter_context)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@@ -87,8 +87,9 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
None,
)
.returning(Some(1))
+ .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
@@ -203,7 +204,7 @@ fn prompt_gateway_successful_request_to_open_ai_chat_completions() {
module
.call_proxy_on_context_create(http_context, filter_context)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@@ -234,8 +235,9 @@ fn prompt_gateway_successful_request_to_open_ai_chat_completions() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Trace), None)
+ .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(Some("arch_internal"), None, None, None, None)
.returning(Some(4))
@@ -267,7 +269,7 @@ fn prompt_gateway_bad_request_to_open_ai_chat_completions() {
module
.call_proxy_on_context_create(http_context, filter_context)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@@ -302,7 +304,7 @@ fn prompt_gateway_bad_request_to_open_ai_chat_completions() {
None,
None,
)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
@@ -369,10 +371,11 @@ fn prompt_gateway_request_to_llm_gateway() {
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
+ .expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
- .expect_log(Some(LogLevel::Debug), None)
- .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
@@ -401,13 +404,12 @@ fn prompt_gateway_request_to_llm_gateway() {
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
+ .expect_log(Some(LogLevel::Debug), None)
+ .expect_log(Some(LogLevel::Trace), None)
+ .expect_log(Some(LogLevel::Trace), None)
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
.returning(Some("200"))
- .expect_log(Some(LogLevel::Debug), None)
- .expect_log(Some(LogLevel::Debug), None)
- .expect_log(Some(LogLevel::Debug), None)
- .expect_log(Some(LogLevel::Debug), None)
- .expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::None)
.unwrap();