Merge branch 'main' into adil/add_endpoint_http_headers

This commit is contained in:
Adil Hafeez 2025-01-31 10:38:20 -08:00
commit 586c860bce
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
11 changed files with 169 additions and 141 deletions

View file

@ -1,19 +1,21 @@
Focus on what matters most. Arch is an **intelligent proxy server designed for prompts** - to help you protect, observe, and build agentic apps by simply connecting (existing) APIs.
<p align="center">
<img src="docs/source/_static/img/arch-logo.png" alt="Arch Logo" width="75%" heigh=auto>
</p>
<p align="center">
<a href="https://www.producthunt.com/posts/arch-3?embed=true&utm_source=badge-top-post-badge&utm_medium=badge&utm_souce=badge-arch&#0045;3" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/top-post-badge.svg?post_id=565761&theme=light&period=daily" alt="Arch - Build&#0032;fast&#0044;&#0032;hyper&#0045;personalized&#0032;agents&#0032;with&#0032;intelligent&#0032;infra | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</p>
Arch is an **intelligent (edge and LLM) proxy designed for agentic applications** - to help you protect, observe, and build agentic tasks by simply connecting (existing) APIs.
Built by the contributors of [Envoy Proxy](https://www.envoyproxy.io/) with the belief that:
>Prompts are nuanced and opaque user requests, which require the same capabilities as traditional HTTP requests including secure handling, intelligent routing, robust observability, and integration with backend (API) systems for personalization outside core business logic.*
![alt text](docs/source/_static/img/arch-logo.png)
<a href="https://www.producthunt.com/posts/arch-3?embed=true&utm_source=badge-top-post-badge&utm_medium=badge&utm_souce=badge-arch&#0045;3" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/top-post-badge.svg?post_id=565761&theme=light&period=daily" alt="Arch - Build&#0032;fast&#0044;&#0032;hyper&#0045;personalized&#0032;agents&#0032;with&#0032;intelligent&#0032;infra | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>
[![pre-commit](https://github.com/katanemo/arch/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/katanemo/arch/actions/workflows/pre-commit.yml)
[![rust tests (prompt and llm gateway)](https://github.com/katanemo/arch/actions/workflows/rust_tests.yml/badge.svg)](https://github.com/katanemo/arch/actions/workflows/rust_tests.yml)
[![e2e tests](https://github.com/katanemo/arch/actions/workflows/e2e_tests.yml/badge.svg)](https://github.com/katanemo/arch/actions/workflows/e2e_tests.yml)
[![Build and Deploy Documentation](https://github.com/katanemo/arch/actions/workflows/static.yml/badge.svg)](https://github.com/katanemo/arch/actions/workflows/static.yml)
Arch is engineered with purpose-built LLMs to handle critical but undifferentiated tasks related to the handling and processing of prompts. This includes detecting and rejecting [jailbreak](https://github.com/verazuo/jailbreak_llms) attempts, intelligent task routing for improved accuracy, mapping user request into "backend" functions, and managing the observability of prompts and LLM API calls in a centralized way.
@ -24,7 +26,7 @@ Arch is engineered with purpose-built LLMs to handle critical but undifferentiat
- Routing & Traffic Management: Arch centralizes calls to LLMs used by your applications, offering smart retries, automatic cutover, and resilient upstream connections for continuous availability.
- Observability: Arch uses the W3C Trace Context standard to enable complete request tracing across applications, ensuring compatibility with observability tools, and provides metrics to monitor latency, token usage, and error rates, helping optimize AI application performance.
**High-Level Network Flow**:
**High-Level Sequence Diagram**:
![alt text](docs/source/_static/img/arch_network_diagram_high_level.png)
**Jump to our [docs](https://docs.archgw.com)** to learn how you can use Arch to improve the speed, security and personalization of your GenAI apps.

View file

@ -2,7 +2,6 @@ use std::rc::Rc;
use crate::{configuration, llm_providers::LlmProviders};
use configuration::LlmProvider;
use log::debug;
use rand::{seq::IteratorRandom, thread_rng};
#[derive(Debug)]
@ -35,11 +34,9 @@ pub fn get_llm_provider(
}
if llm_providers.default().is_some() {
debug!("no llm provider found for hint, using default llm provider");
return llm_providers.default().unwrap();
}
debug!("no default llm found, using random llm provider");
let mut rng = thread_rng();
llm_providers
.iter()

View file

@ -1,4 +1,4 @@
use log::debug;
use log::trace;
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
#[allow(dead_code)]
@ -9,7 +9,7 @@ pub enum Error {
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
debug!("getting token count model={}", model_name);
trace!("getting token count model={}", model_name);
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel {
model_name: model_name.to_string(),

View file

@ -9,7 +9,7 @@ use common::llm_providers::LlmProviders;
use common::ratelimit;
use common::stats::Gauge;
use common::tracing::TraceData;
use log::debug;
use log::trace;
use log::warn;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
@ -79,7 +79,7 @@ impl RootContext for FilterContext {
}
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
debug!(
trace!(
"||| create_http_context called with context_id: {:?} |||",
context_id
);
@ -108,10 +108,8 @@ impl RootContext for FilterContext {
fn on_tick(&mut self) {
let _ = self.traces_queue.try_lock().map(|mut traces_queue| {
while let Some(trace) = traces_queue.pop_front() {
debug!("trace received: {:?}", trace);
let trace_str = serde_json::to_string(&trace).unwrap();
debug!("trace: {}", trace_str);
trace!("trace details: {}", trace_str);
let call_args = CallArgs::new(
OTEL_COLLECTOR_HTTP,
OTEL_POST_PATH,
@ -144,7 +142,7 @@ impl Context for FilterContext {
_body_size: usize,
_num_trailers: usize,
) {
debug!(
trace!(
"||| on_http_call_response called with token_id: {:?} |||",
token_id
);
@ -156,7 +154,7 @@ impl Context for FilterContext {
.expect("invalid token_id");
if let Some(status) = self.get_http_call_response_header(":status") {
debug!("trace response status: {:?}", status);
trace!("trace response status: {:?}", status);
};
}
}

View file

@ -10,7 +10,6 @@ use common::consts::{
};
use common::errors::ServerError;
use common::llm_providers::LlmProviders;
use common::pii::obfuscate_auth_header;
use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
@ -82,12 +81,16 @@ impl StreamContext {
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
.map(|llm_name| llm_name.into());
debug!("llm provider hint: {:?}", provider_hint);
self.llm_provider = Some(routing::get_llm_provider(
&self.llm_providers,
provider_hint,
));
debug!("selected llm: {}", self.llm_provider.as_ref().unwrap().name);
debug!(
"request received: llm provider hint: {:?}, selected llm: {}",
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER),
self.llm_provider.as_ref().unwrap().name
);
}
fn modify_auth_headers(&mut self) -> Result<(), ServerError> {
@ -150,7 +153,7 @@ impl StreamContext {
self.metrics
.input_sequence_length
.record(token_count as u64);
log::debug!("Recorded input token count: {}", token_count);
trace!("Recorded input token count: {}", token_count);
// Check if rate limiting needs to be applied.
if let Some(selector) = self.ratelimit_selector.take() {
@ -161,7 +164,7 @@ impl StreamContext {
NonZero::new(token_count as u32).unwrap(),
)?;
} else {
log::debug!("No rate limit applied for model: {}", model);
trace!("No rate limit applied for model: {}", model);
}
Ok(())
@ -197,12 +200,6 @@ impl HttpContext for StreamContext {
self.is_chat_completions_request =
self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH;
debug!(
"on_http_request_headers S[{}] req_headers={:?}",
self.context_id,
obfuscate_auth_header(&mut self.get_http_request_headers())
);
self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER);
@ -310,9 +307,10 @@ impl HttpContext for StreamContext {
}
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
debug!(
trace!(
"on_http_response_headers [S={}] end_stream={}",
self.context_id, _end_of_stream
self.context_id,
_end_of_stream
);
self.set_property(
@ -324,9 +322,11 @@ impl HttpContext for StreamContext {
}
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
debug!(
trace!(
"on_http_response_body [S={}] bytes={} end_stream={}",
self.context_id, body_size, end_of_stream
self.context_id,
body_size,
end_of_stream
);
if !self.is_chat_completions_request {
@ -342,7 +342,7 @@ impl HttpContext for StreamContext {
Ok(duration) => {
// Convert the duration to milliseconds
let duration_ms = duration.as_millis();
debug!("Total latency: {} milliseconds", duration_ms);
debug!("request latency: {}ms", duration_ms);
// Record the latency to the latency histogram
self.metrics.request_latency.record(duration_ms as u64);
@ -350,11 +350,14 @@ impl HttpContext for StreamContext {
// Compute the time per output token
let tpot = duration_ms as u64 / self.response_tokens as u64;
debug!("Time per output token: {} milliseconds", tpot);
// Record the time per output token
self.metrics.time_per_output_token.record(tpot);
debug!("Tokens per second: {}", 1000 / tpot);
trace!(
"time per token: {}ms, tokens per second: {}",
tpot,
1000 / tpot
);
// Record the tokens per second
self.metrics.tokens_per_second.record(1000 / tpot);
}
@ -414,9 +417,10 @@ impl HttpContext for StreamContext {
let body = if self.streaming_response {
let chunk_start = 0;
let chunk_size = body_size;
debug!(
trace!(
"streaming response reading, {}..{}",
chunk_start, chunk_size
chunk_start,
chunk_size
);
let streaming_chunk = match self.get_http_response_body(0, chunk_size) {
Some(chunk) => chunk,
@ -438,7 +442,7 @@ impl HttpContext for StreamContext {
}
streaming_chunk
} else {
debug!("non streaming response bytes read: 0:{}", body_size);
trace!("non streaming response bytes read: 0:{}", body_size);
match self.get_http_response_body(0, body_size) {
Some(body) => body,
None => {
@ -510,7 +514,7 @@ impl HttpContext for StreamContext {
match current_time.duration_since(self.start_time) {
Ok(duration) => {
let duration_ms = duration.as_millis();
debug!("Time to First Token (TTFT): {} milliseconds", duration_ms);
debug!("time to first token: {}ms", duration_ms);
self.ttft_duration = Some(duration);
self.metrics.time_to_first_token.record(duration_ms as u64);
}
@ -520,12 +524,15 @@ impl HttpContext for StreamContext {
}
}
} else {
debug!("non streaming response");
trace!("non streaming response");
let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_str(body_utf8.as_str()) {
Ok(de) => de,
Err(_e) => {
debug!("invalid response: {}", body_utf8);
Err(err) => {
debug!(
"non chat-completion compliant response received err: {}, body: {}",
err, body_utf8
);
return Action::Continue;
}
};
@ -539,9 +546,11 @@ impl HttpContext for StreamContext {
}
}
debug!(
trace!(
"recv [S={}] total_tokens={} end_stream={}",
self.context_id, self.response_tokens, end_of_stream
self.context_id,
self.response_tokens,
end_of_stream
);
Action::Continue

View file

@ -22,12 +22,8 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider-hint"),
)
.returning(Some("default"))
.expect_log(
Some(LogLevel::Debug),
Some("llm provider hint: Some(Default)"),
)
.expect_log(Some(LogLevel::Debug), Some("selected llm: open-ai-gpt-4"))
.returning(None)
.expect_log(Some(LogLevel::Debug), Some("request received: llm provider hint: Some(\"default\"), selected llm: open-ai-gpt-4"))
.expect_add_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider"),
@ -38,7 +34,11 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
Some("Authorization"),
Some("Bearer secret_key"),
)
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
.expect_get_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider-hint"),
)
.returning(Some("default"))
.expect_get_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-ratelimit-selector"),
@ -50,7 +50,6 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
.returning(Some("/v1/chat/completions"))
.expect_log(Some(LogLevel::Debug), None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent"))
@ -62,7 +61,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -187,7 +186,10 @@ fn llm_gateway_successful_request_to_open_ai_chat_completions() {
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(
Some(LogLevel::Trace),
Some("||| create_http_context called with context_id: 2 |||"),
)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -218,9 +220,9 @@ fn llm_gateway_successful_request_to_open_ai_chat_completions() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 21)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
@ -251,7 +253,7 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -339,9 +341,9 @@ fn llm_gateway_request_ratelimited() {
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 107)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
@ -405,9 +407,9 @@ fn llm_gateway_request_not_ratelimited() {
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 29)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)

View file

@ -3,7 +3,7 @@ use std::str::FromStr;
use common::errors::ServerError;
use common::stats::IncrementingMetric;
use http::StatusCode;
use log::{debug, warn};
use log::warn;
use proxy_wasm::traits::Context;
use crate::stream_context::{ResponseHandlerType, StreamContext};
@ -25,27 +25,38 @@ impl Context for StreamContext {
let body = self
.get_http_call_response_body(0, body_size)
.unwrap_or(vec![]);
.unwrap_or_default();
let http_status = self
.get_http_call_response_header(":status")
.unwrap_or(StatusCode::OK.as_str().to_string());
debug!("http call response code: {}", http_status);
if http_status != StatusCode::OK.as_str() {
let server_error = ServerError::Upstream {
host: callout_context.upstream_cluster.unwrap(),
path: callout_context.upstream_cluster_path.unwrap(),
status: http_status.clone(),
body: String::from_utf8(body).unwrap(),
};
warn!("filter received non 2xx code: {:?}", server_error);
return self.send_server_error(
server_error,
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
);
if let Some(http_status) = self.get_http_call_response_header(":status") {
match StatusCode::from_str(http_status.as_str()) {
Ok(status_code) => {
if !status_code.is_success() {
let server_error = ServerError::Upstream {
host: callout_context.upstream_cluster.unwrap(),
path: callout_context.upstream_cluster_path.unwrap(),
status: http_status.clone(),
body: String::from_utf8(body).unwrap(),
};
warn!("received non 2xx code: {:?}", server_error);
return self.send_server_error(
server_error,
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
);
}
}
Err(_) => {
// invalid status code (status code non numeric)
return self.send_server_error(
ServerError::LogicError(format!("invalid status code: {}", http_status)),
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
);
}
}
} else {
// :status header not found
warn!("missing :status header");
}
debug!("http call response handler type: {:?}", callout_context.response_handler_type);
#[cfg_attr(any(), rustfmt::skip)]
match callout_context.response_handler_type {
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),

View file

@ -3,7 +3,7 @@ use crate::stream_context::StreamContext;
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget, Tracing};
use common::http::Client;
use common::stats::Gauge;
use log::debug;
use log::trace;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::cell::RefCell;
@ -84,7 +84,7 @@ impl RootContext for FilterContext {
}
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
debug!(
trace!(
"||| create_http_context called with context_id: {:?} |||",
context_id
);

View file

@ -85,10 +85,7 @@ impl HttpContext for StreamContext {
}
};
debug!(
"developer => archgw: {}",
String::from_utf8_lossy(&body_bytes)
);
trace!("request body: {}", String::from_utf8_lossy(&body_bytes));
// Deserialize body into spec.
// Currently OpenAI API.
@ -159,7 +156,8 @@ impl HttpContext for StreamContext {
}
};
debug!("archgw => archfc: {}", json_data);
debug!("sending request to model server");
trace!("request body: {}", json_data);
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME),

View file

@ -14,7 +14,7 @@ use common::http::{CallArgs, Client};
use common::stats::Gauge;
use derivative::Derivative;
use http::StatusCode;
use log::{debug, warn};
use log::{debug, trace, warn};
use proxy_wasm::traits::*;
use serde_yaml::Value;
use std::cell::RefCell;
@ -125,13 +125,14 @@ impl StreamContext {
mut callout_context: StreamCallContext,
) {
let body_str = String::from_utf8(body).unwrap();
debug!("archgw <= archfc response: {}", body_str);
debug!("model server response received");
trace!("response body: {}", body_str);
let model_server_response: ModelServerResponse = match serde_json::from_str(&body_str) {
Ok(arch_fc_response) => arch_fc_response,
Err(e) => {
warn!(
"error deserializing archfc response: {}, body: {}",
"error deserializing modelserver response: {}, body: {}",
e, body_str
);
return self.send_server_error(ServerError::Deserialization(e), None);
@ -141,7 +142,7 @@ impl StreamContext {
let arch_fc_response = match model_server_response {
ModelServerResponse::ChatCompletionsResponse(response) => response,
ModelServerResponse::ModelServerErrorResponse(response) => {
debug!("archgw <= archfc error response: {}", response.result);
debug!("archgw <= modelserver error response: {}", response.result);
if response.result == "No intent matched" {
if let Some(default_prompt_target) = self
.prompt_targets
@ -263,7 +264,7 @@ impl StreamContext {
);
}
// update prompt target name from the tool call
// update prompt target name from the tool call response
callout_context.prompt_target_name =
Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone());
@ -352,11 +353,10 @@ impl StreamContext {
);
debug!(
"archgw => api call, endpoint: {}{}, body: {}",
endpoint.name.as_str(),
path,
tool_params_json_str
"dispatching api call to developer endpoint: {}, path: {}",
endpoint.name, path
);
trace!("request body: {}", tool_params_json_str);
callout_context.upstream_cluster = Some(endpoint.name.to_owned());
callout_context.upstream_cluster_path = Some(path.to_owned());
@ -371,7 +371,10 @@ impl StreamContext {
let http_status = self
.get_http_call_response_header(":status")
.unwrap_or(StatusCode::OK.as_str().to_string());
debug!("api_call_response_handler: http_status: {}", http_status);
debug!(
"developer api call response received: status code: {}",
http_status
);
if http_status != StatusCode::OK.as_str() {
warn!(
"api server responded with non 2xx status code: {}",
@ -388,12 +391,12 @@ impl StreamContext {
);
}
self.tool_call_response = Some(String::from_utf8(body).unwrap());
debug!(
"archgw <= api call response: {}",
trace!(
"response body: {}",
self.tool_call_response.as_ref().unwrap()
);
let mut messages = self.filter_out_arch_messages(&callout_context);
let mut messages = self.construct_llm_messages(&callout_context);
let user_message = match messages.pop() {
Some(user_message) => user_message,
@ -439,7 +442,8 @@ impl StreamContext {
return self.send_server_error(ServerError::Serialization(e), None);
}
};
debug!("archgw => llm request: {}", llm_request_str);
debug!("sending request to upstream llm");
trace!("request body: {}", llm_request_str);
self.start_upstream_llm_request_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
@ -450,25 +454,39 @@ impl StreamContext {
self.resume_http_request();
}
fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
let mut messages: Vec<Message> = Vec::new();
// add system prompt
fn get_system_prompt(&self, prompt_target: Option<PromptTarget>) -> Option<String> {
match prompt_target {
None => self.system_prompt.as_ref().clone(),
Some(prompt_target) => match prompt_target.system_prompt {
None => self.system_prompt.as_ref().clone(),
Some(system_prompt) => Some(system_prompt),
},
}
}
fn filter_out_arch_messages(&self, messages: &[Message]) -> Vec<Message> {
messages
.iter()
.filter(|m| {
!(m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty()))
})
.cloned()
.collect()
}
fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
let mut messages: Vec<Message> = Vec::new();
// add system prompt
let system_prompt = match callout_context.prompt_target_name.as_ref() {
None => self.system_prompt.as_ref().clone(),
Some(prompt_target_name) => {
let prompt_system_prompt = self
.prompt_targets
.get(prompt_target_name)
.unwrap()
.clone()
.system_prompt;
match prompt_system_prompt {
None => self.system_prompt.as_ref().clone(),
Some(system_prompt) => Some(system_prompt),
}
self.get_system_prompt(self.prompt_targets.get(prompt_target_name).cloned())
}
};
if system_prompt.is_some() {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
@ -480,18 +498,9 @@ impl StreamContext {
messages.push(system_prompt_message);
}
// don't send tools message and api response to chat gpt
for m in callout_context.request_body.messages.iter() {
// don't send api response and tool calls to upstream LLMs
if m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
{
continue;
}
messages.push(m.clone());
}
messages.append(
&mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
);
messages
}

View file

@ -41,7 +41,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -87,8 +87,9 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
None,
)
.returning(Some(1))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
@ -203,7 +204,7 @@ fn prompt_gateway_successful_request_to_open_ai_chat_completions() {
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -234,8 +235,9 @@ fn prompt_gateway_successful_request_to_open_ai_chat_completions() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(Some("arch_internal"), None, None, None, None)
.returning(Some(4))
@ -267,7 +269,7 @@ fn prompt_gateway_bad_request_to_open_ai_chat_completions() {
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -302,7 +304,7 @@ fn prompt_gateway_bad_request_to_open_ai_chat_completions() {
None,
None,
)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
@ -369,10 +371,11 @@ fn prompt_gateway_request_to_llm_gateway() {
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
@ -401,13 +404,12 @@ fn prompt_gateway_request_to_llm_gateway() {
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
.returning(Some("200"))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::None)
.unwrap();