Add prem support for a2a agents

This commit is contained in:
Adil Hafeez 2025-04-25 00:57:13 -07:00
parent 2e346143dd
commit 299f183e66
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
23 changed files with 2544 additions and 16 deletions

23
crates/Cargo.lock generated
View file

@ -20,6 +20,29 @@ dependencies = [
"gimli",
]
[[package]]
name = "agent_gateway"
version = "0.1.0"
dependencies = [
"acap",
"common",
"derivative",
"governor",
"http",
"log",
"md5",
"pretty_assertions",
"proxy-wasm",
"proxy-wasm-test-framework",
"rand",
"serde",
"serde_json",
"serde_yaml",
"serial_test",
"sha2",
"thiserror",
]
[[package]]
name = "ahash"
version = "0.3.8"

View file

@ -1,3 +1,3 @@
[workspace]
resolver = "2"
members = ["llm_gateway", "prompt_gateway", "common"]
members = ["llm_gateway", "prompt_gateway", "agent_gateway", "common"]

View file

@ -0,0 +1,29 @@
[package]
name = "agent_gateway"
version = "0.1.0"
authors = ["Katanemo Inc <info@katanemo.com>"]
edition = "2021"
[lib]
crate-type = ["cdylib"]
[dependencies]
proxy-wasm = "0.2.1"
log = "0.4"
serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9.34"
serde_json = "1.0"
md5 = "0.7.0"
common = { path = "../common" }
http = "1.1.0"
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
acap = "0.3.0"
rand = "0.8.5"
thiserror = "1.0.64"
derivative = "2.2.0"
sha2 = "0.10.8"
[dev-dependencies]
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
serial_test = "3.1.1"
pretty_assertions = "1.4.1"

View file

@ -0,0 +1,66 @@
use std::str::FromStr;
use common::errors::ServerError;
use common::stats::IncrementingMetric;
use http::StatusCode;
use log::warn;
use proxy_wasm::traits::Context;
use crate::stream_context::{ResponseHandlerType, StreamContext};
impl Context for StreamContext {
    /// Handles the response of an HTTP callout previously dispatched by this stream.
    ///
    /// The callout context stored under `token_id` is removed from `callouts`, the
    /// response body is read, and the `:status` header is inspected:
    /// * a non-2xx status is turned into a `ServerError::Upstream` local reply that
    ///   mirrors the upstream status code;
    /// * a status that cannot be parsed is reported as a `ServerError::LogicError`
    ///   with the default (500) status;
    /// * a missing `:status` header is only logged.
    ///
    /// Otherwise the body is handed to the handler selected by
    /// `response_handler_type`.
    ///
    /// # Panics
    /// Panics if `token_id` does not belong to a callout registered on this context.
    fn on_http_call_response(
        &mut self,
        token_id: u32,
        _num_headers: usize,
        body_size: usize,
        _num_trailers: usize,
    ) {
        let callout_context = self
            .callouts
            .get_mut()
            .remove(&token_id)
            .expect("invalid token_id");
        self.metrics.active_http_calls.increment(-1);

        let body = self
            .get_http_call_response_body(0, body_size)
            .unwrap_or_default();

        if let Some(http_status) = self.get_http_call_response_header(":status") {
            match StatusCode::from_str(http_status.as_str()) {
                Ok(status_code) => {
                    if !status_code.is_success() {
                        let server_error = ServerError::Upstream {
                            host: callout_context.upstream_cluster.unwrap_or_default(),
                            path: callout_context.upstream_cluster_path.unwrap_or_default(),
                            status: http_status,
                            // The upstream body is untrusted; never panic on invalid UTF-8.
                            body: String::from_utf8_lossy(&body).into_owned(),
                        };
                        warn!("received non 2xx code: {:?}", server_error);
                        // Reuse the already parsed status instead of parsing it again.
                        return self.send_server_error(server_error, Some(status_code));
                    }
                }
                Err(_) => {
                    // invalid status code (status code non numeric); the string just failed
                    // to parse, so it cannot be used as the reply status. Fall back to the
                    // default status chosen by `send_server_error`.
                    return self.send_server_error(
                        ServerError::LogicError(format!("invalid status code: {}", http_status)),
                        None,
                    );
                }
            }
        } else {
            // :status header not found
            warn!("missing :status header");
        }

        #[cfg_attr(any(), rustfmt::skip)]
        match callout_context.response_handler_type {
            ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
            ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
        }
    }
}

View file

@ -0,0 +1,121 @@
use crate::metrics::Metrics;
use crate::stream_context::StreamContext;
use common::configuration::{
Agent, Configuration, Endpoint, Overrides, PromptGuards, PromptTarget, Tool, Tracing,
};
use common::http::Client;
use common::stats::Gauge;
use log::trace;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
// Call context type used for HTTP callouts issued by `FilterContext` (see its `Client` impl).
// It currently carries no fields.
#[derive(Debug)]
pub struct FilterCallContext {}
// Root (plugin-level) context. It is filled from the plugin configuration in `on_configure`
// and hands `Rc` clones of selected sections to every `StreamContext` it creates.
#[derive(Debug)]
pub struct FilterContext {
// Metrics handle; cloned by `Rc` into each `StreamContext`.
metrics: Rc<Metrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, FilterCallContext>>,
// `overrides` section of the configuration; passed to each `StreamContext`.
overrides: Rc<Option<Overrides>>,
// `system_prompt` from the configuration; stored here but not passed to `StreamContext::new`.
system_prompt: Rc<Option<String>>,
// Prompt targets keyed by their `name`; stored here but not passed to `StreamContext::new`.
prompt_targets: Rc<HashMap<String, PromptTarget>>,
// `agents` section of the configuration; passed to each `StreamContext`.
agents: Rc<HashMap<String, Agent>>,
// `tools` section of the configuration; passed to each `StreamContext`.
tools: Rc<HashMap<String, Tool>>,
// `endpoints` section of the configuration; passed to each `StreamContext`.
endpoints: Rc<Option<HashMap<String, Endpoint>>>,
// Prompt guards; keeps the default value unless the configuration provides one.
prompt_guards: Rc<PromptGuards>,
// `tracing` section of the configuration; passed to each `StreamContext`.
tracing: Rc<Option<Tracing>>,
}
impl FilterContext {
    /// Creates a root context with empty/default configuration.
    ///
    /// Every configuration-backed field starts out empty (`None`, an empty map, or
    /// `PromptGuards::default()`) and is replaced later by `on_configure`.
    pub fn new() -> FilterContext {
        let metrics = Rc::new(Metrics::new());
        let callouts = RefCell::new(HashMap::new());

        Self {
            metrics,
            callouts,
            overrides: Rc::new(None),
            system_prompt: Rc::new(None),
            prompt_targets: Rc::new(HashMap::new()),
            agents: Rc::new(HashMap::new()),
            tools: Rc::new(HashMap::new()),
            endpoints: Rc::new(None),
            prompt_guards: Rc::new(PromptGuards::default()),
            tracing: Rc::new(None),
        }
    }
}
// `Client` plumbing for the root context: exposes the callout table and the gauge that
// tracks in-flight HTTP calls.
impl Client for FilterContext {
type CallContext = FilterCallContext;
// Returns the token_id -> call-context map owned by this context.
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
&self.callouts
}
// Returns the `active_http_calls` gauge held in the shared metrics.
fn active_http_calls(&self) -> &Gauge {
&self.metrics.active_http_calls
}
}
// The root context overrides no `Context` callbacks; the trait's default methods apply.
impl Context for FilterContext {}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool {
let config_bytes = self
.get_plugin_configuration()
.expect("Arch config cannot be empty");
let config: Configuration = match serde_yaml::from_slice(&config_bytes) {
Ok(config) => config,
Err(err) => panic!("Invalid arch config \"{:?}\"", err),
};
self.overrides = Rc::new(config.overrides);
let mut prompt_targets = HashMap::new();
for pt in config.prompt_targets.unwrap_or_default() {
prompt_targets.insert(pt.name.clone(), pt.clone());
}
self.system_prompt = Rc::new(config.system_prompt);
self.prompt_targets = Rc::new(prompt_targets);
self.endpoints = Rc::new(config.endpoints);
self.agents = Rc::new(config.agents);
self.tools = Rc::new(config.tools);
if let Some(prompt_guards) = config.prompt_guards {
self.prompt_guards = Rc::new(prompt_guards)
}
self.tracing = Rc::new(config.tracing);
true
}
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
trace!(
"||| create_http_context called with context_id: {:?} |||",
context_id
);
Some(Box::new(StreamContext::new(
context_id,
Rc::clone(&self.metrics),
Rc::clone(&self.endpoints),
Rc::clone(&self.overrides),
Rc::clone(&self.tracing),
Rc::clone(&self.agents),
Rc::clone(&self.tools),
)))
}
fn get_type(&self) -> Option<ContextType> {
Some(ContextType::HttpContext)
}
fn on_vm_start(&mut self, _: usize) -> bool {
true
}
}

View file

@ -0,0 +1,440 @@
use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext};
use common::{
api::open_ai::{self, ArchState, ChatCompletionTool, ChatCompletionsRequest, Message},
consts::{
ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_UPSTREAM_HOST_HEADER,
CHAT_COMPLETIONS_PATH, HEALTHZ_PATH, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER,
SYSTEM_ROLE, TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_STATE_HEADER,
},
errors::ServerError,
http::{CallArgs, Client},
pii::obfuscate_auth_header,
};
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::{traits::HttpContext, types::Action};
use serde_json::Value;
use std::{
collections::HashMap,
time::{Duration, SystemTime, UNIX_EPOCH},
};
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
impl HttpContext for StreamContext {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response.
/// Processes the request headers: strips `content-length`, applies the agent
/// orchestrator routing override, answers health checks and selects the agent
/// for this stream.
///
/// Returns `Action::Pause` whenever a local response has been sent (error or
/// health check) so that the request is not processed any further, and
/// `Action::Continue` otherwise.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
    // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it.
    // Server's generally throw away requests whose body length do not match the Content-Length header.
    // However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
    // manipulate the body in benign ways e.g., compression.
    self.set_http_request_header("content-length", None);

    if let Some(overrides) = self.overrides.as_ref() {
        if overrides.use_agent_orchestrator.unwrap_or_default() {
            // The orchestrator mode routes to the single configured endpoint.
            if let Some(endpoints) = self.endpoints.as_ref() {
                match endpoints.keys().next() {
                    Some(name) if endpoints.len() == 1 => {
                        info!("Setting ARCH_PROVIDER_HINT_HEADER to {}", name);
                        self.set_http_request_header(ARCH_ROUTING_HEADER, Some(name));
                    }
                    _ => {
                        warn!("Need single endpoint when use_agent_orchestrator is set");
                        self.send_server_error(
                            ServerError::LogicError(
                                "Need single endpoint when use_agent_orchestrator is set"
                                    .to_string(),
                            ),
                            None,
                        );
                        // A local reply was sent; stop here instead of falling through
                        // and processing (or answering) the request a second time.
                        return Action::Pause;
                    }
                }
            }
        }
    }

    let request_path = self.get_http_request_header(":path").unwrap_or_default();
    if request_path == HEALTHZ_PATH {
        self.send_http_response(200, vec![], None);
        // The health check is answered locally, so the request must not continue upstream.
        return Action::Pause;
    }

    self.is_chat_completions_request = CHAT_COMPLETIONS_PATH.contains(&request_path.as_str());

    // check if agent name is in the request header
    // if not, check if there is only one agent in the config
    // if so, use that agent
    // if there are multiple agents in the config, return an error
    if let Some(agent_header_value) = self.get_http_request_header("x-agent-name") {
        if let Some(agent) = self.agents.as_ref().get(&agent_header_value) {
            self.agent = Some(agent.clone());
        } else {
            warn!("Agent not found in config");
            self.send_server_error(
                ServerError::LogicError(format!(
                    "Agent {} not found in config",
                    agent_header_value
                )),
                None,
            );
            return Action::Pause;
        }
    } else if self.agents.as_ref().len() == 1 {
        if let Some((name, agent)) = self.agents.iter().next() {
            info!("Setting agent to {}", name);
            self.agent = Some(agent.clone());
        }
    } else {
        warn!("Multiple agents found in config and no agent name in request header");
        self.send_http_response(
            400,
            vec![],
            Some(
                "Multiple agents found in config and no agent name in request header"
                    .as_bytes(),
            ),
        );
        return Action::Pause;
    }

    debug!(
        "on_http_request_headers S[{}] req_headers={:?}",
        self.context_id,
        obfuscate_auth_header(&mut self.get_http_request_headers())
    );

    // Remember correlation ids so they can be forwarded on callouts.
    self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
    self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER);

    Action::Continue
}
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
// Let the client send the gateway all the data before sending to the LLM_provider.
// TODO: consider a streaming API.
if !end_of_stream {
return Action::Pause;
}
if body_size == 0 {
return Action::Continue;
}
self.request_body_size = body_size;
debug!(
"on_http_request_body S[{}] body_size={}",
self.context_id, body_size
);
let body_bytes = match self.get_http_request_body(0, body_size) {
Some(body_bytes) => body_bytes,
None => {
self.send_server_error(
ServerError::LogicError(format!(
"Failed to obtain body bytes even though body_size is {}",
body_size
)),
None,
);
return Action::Pause;
}
};
debug!("request body: {}", String::from_utf8_lossy(&body_bytes));
// Deserialize body into spec.
// Currently OpenAI API.
let deserialized_body: ChatCompletionsRequest = match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
self.send_server_error(
ServerError::Deserialization(e),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
};
self.arch_state = match deserialized_body.metadata {
Some(ref metadata) => {
if metadata.contains_key(X_ARCH_STATE_HEADER) {
let arch_state_str = metadata[X_ARCH_STATE_HEADER].clone();
let arch_state: Vec<ArchState> = serde_json::from_str(&arch_state_str).unwrap();
Some(arch_state)
} else {
None
}
}
None => None,
};
self.streaming_response = deserialized_body.stream;
let last_user_prompt: &open_ai::Message = match deserialized_body
.messages
.iter()
.filter(|msg| msg.role == USER_ROLE)
.last()
{
Some(content) => content,
None => {
warn!("No messages in the request body");
return Action::Continue;
}
};
self.user_prompt = Some(last_user_prompt.clone());
let mut tool_calls = Vec::new();
if let Some(agent) = self.agent.as_ref() {
if let Some(tools) = agent.tools.as_ref() {
for tool in tools {
if let Some(tool) = self.tools.as_ref().get(tool) {
info!("tool: {:?}", tool);
let tool_chat_completion_tool: ChatCompletionTool = tool.into();
info!("tool_chat_completion_tool: {:?}", tool_chat_completion_tool);
tool_calls.push(tool_chat_completion_tool);
}
}
}
}
let mut metadata = deserialized_body.metadata.clone();
if let Some(overrides) = self.overrides.as_ref() {
if overrides.optimize_context_window.unwrap_or_default() {
if metadata.is_none() {
metadata = Some(HashMap::new());
}
metadata
.as_mut()
.unwrap()
.insert("optimize_context_window".to_string(), "true".to_string());
}
}
let messages: Vec<Message> = match self.agent.as_ref().unwrap().agent_orchestrator_prompt {
Some(ref agent_orchestrator_prompt) => {
let mut messages = Vec::new();
messages.push(Message {
role: SYSTEM_ROLE.to_string(),
content: Some(agent_orchestrator_prompt.clone()),
model: None,
tool_calls: None,
tool_call_id: None,
});
messages.extend(deserialized_body.messages.clone());
messages
}
None => deserialized_body.messages.clone(),
};
let arch_fc_chat_completion_request = ChatCompletionsRequest {
messages,
metadata,
//HACK: adilhafeez: enable streaming for agent orchestrator
stream: false,
model: deserialized_body.model.clone(),
stream_options: deserialized_body.stream_options.clone(),
tools: Some(tool_calls),
};
self.chat_completions_request = Some(deserialized_body);
let json_data = match serde_json::to_string(&arch_fc_chat_completion_request) {
Ok(json_data) => json_data,
Err(error) => {
self.send_server_error(ServerError::Serialization(error), None);
return Action::Pause;
}
};
info!("on_http_request_body: sending request to model server");
debug!("request body: {}", json_data);
let timeout_str = MODEL_SERVER_REQUEST_TIMEOUT_MS.to_string();
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, "openai"),
(":method", "POST"),
(":path", "/v1/chat/completions"),
("content-type", "application/json"),
(":authority", "openai"),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
"arch_listener_llm",
"/v1/chat/completions",
headers,
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
);
let call_context = StreamCallContext {
response_handler_type: ResponseHandlerType::ArchFC,
user_message: self.user_prompt.as_ref().unwrap().content.clone(),
prompt_target_name: None,
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
similarity_scores: None,
upstream_cluster: Some(ARCH_INTERNAL_CLUSTER_NAME.to_string()),
upstream_cluster_path: Some("/function_calling".to_string()),
agent: self.agent.clone(),
};
if let Err(e) = self.http_call(call_args, call_context) {
warn!("http_call failed: {:?}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
Action::Pause
}
/// Logs the response headers and drops `content-length` before passing them on.
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
    debug!(
        "on_http_response_headers recv [S={}] headers={:?}",
        self.context_id,
        self.get_http_response_headers()
    );

    // The response body may be rewritten later, which would make the upstream
    // content-length wrong; remove it and let Envoy compute the final value.
    self.set_http_response_header("content-length", None);

    Action::Continue
}
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
debug!(
"on_http_response_body: recv [S={}] bytes={} end_stream={}",
self.context_id, body_size, end_of_stream
);
if !self.is_chat_completions_request {
info!("non-gpt request");
return Action::Continue;
}
if self.time_to_first_token.is_none() {
self.time_to_first_token = Some(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos(),
);
}
if end_of_stream && body_size == 0 {
return Action::Continue;
}
let body = if self.streaming_response {
let streaming_chunk = match self.get_http_response_body(0, body_size) {
Some(chunk) => chunk,
None => {
warn!(
"response body empty, chunk_start: {}, chunk_size: {}",
0, body_size
);
return Action::Continue;
}
};
if streaming_chunk.len() != body_size {
warn!(
"chunk size mismatch: read: {} != requested: {}",
streaming_chunk.len(),
body_size
);
}
streaming_chunk
} else {
info!("non streaming response bytes read: 0:{}", body_size);
match self.get_http_response_body(0, body_size) {
Some(body) => body,
None => {
warn!("non streaming response body empty");
return Action::Continue;
}
}
};
let body_utf8 = match String::from_utf8(body) {
Ok(body_utf8) => body_utf8,
Err(e) => {
info!("could not convert to utf8: {}", e);
return Action::Continue;
}
};
if self.streaming_response {
debug!("streaming response");
if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() {
let chunks = vec![
// ChatCompletionStreamResponse::new(
// self.arch_fc_response.clone(),
// Some(ASSISTANT_ROLE.to_string()),
// Some(ARCH_FC_MODEL_NAME.to_string()),
// None,
// ),
// ChatCompletionStreamResponse::new(
// self.tool_call_response.clone(),
// Some(TOOL_ROLE.to_string()),
// Some(ARCH_FC_MODEL_NAME.to_string()),
// None,
// ),
];
let mut response_str = open_ai::to_server_events(chunks);
// append the original response from the model to the stream
response_str.push_str(&body_utf8);
self.set_http_response_body(0, body_size, response_str.as_bytes());
self.tool_calls = None;
}
} else if let Some(tool_calls) = self.tool_calls.as_ref() {
if !tool_calls.is_empty() {
if self.arch_state.is_none() {
self.arch_state = Some(Vec::new());
}
let mut data = match serde_json::from_str(&body_utf8) {
Ok(data) => data,
Err(e) => {
warn!(
"could not deserialize response, sending data as it is: {}",
e
);
return Action::Continue;
}
};
// use serde::Value to manipulate the json object and ensure that we don't lose any data
if let Value::Object(ref mut map) = data {
// serialize arch state and add to metadata
let metadata = map
.entry("metadata")
.or_insert(Value::Object(serde_json::Map::new()));
if metadata == &Value::Null {
*metadata = Value::Object(serde_json::Map::new());
}
let data_serialized = serde_json::to_string(&data).unwrap();
info!("archgw <= developer: {}", data_serialized);
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
};
}
}
debug!("recv [S={}] end_stream={}", self.context_id, end_of_stream);
Action::Continue
}
}

View file

@ -0,0 +1,17 @@
use filter_context::FilterContext;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod context;
mod filter_context;
mod http_context;
mod metrics;
mod stream_context;
mod tools;
// Wasm module entry point: sets the log level and registers the factory that creates
// the plugin's root context.
proxy_wasm::main! {{
// The log level is fixed to Trace here.
proxy_wasm::set_log_level(LogLevel::Trace);
// Every root context is a freshly constructed `FilterContext`.
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(FilterContext::new())
});
}}

View file

@ -0,0 +1,14 @@
use common::stats::Gauge;
// Metrics exposed by the agent gateway filter.
#[derive(Copy, Clone, Debug)]
pub struct Metrics {
// Gauge created under the name "active_http_calls" (see `Metrics::new`).
pub active_http_calls: Gauge,
}
impl Metrics {
    /// Builds the metric set, creating the gauge named `active_http_calls`.
    pub fn new() -> Metrics {
        let active_http_calls = Gauge::new("active_http_calls".to_string());
        Self { active_http_calls }
    }
}

View file

@ -0,0 +1,528 @@
use crate::metrics::Metrics;
use crate::tools::compute_request_path_body;
use common::api::open_ai::{
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
ChatCompletionsResponse, Message, ToolCall,
};
use common::configuration::{Agent, Endpoint, Overrides, Tool, Tracing};
use common::consts::{
API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE,
TRACE_PARENT_HEADER, USER_ROLE,
};
use common::errors::ServerError;
use common::http::{CallArgs, Client};
use common::stats::Gauge;
use derivative::Derivative;
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::traits::*;
use serde_yaml::Value;
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use std::str::FromStr;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
// Selects which handler processes a callout response in `on_http_call_response`.
#[derive(Debug, Clone)]
pub enum ResponseHandlerType {
// Response to the model-server request; dispatched to `arch_fc_response_handler`.
ArchFC,
// Response to a tool endpoint call; dispatched to `api_call_response_handler`.
FunctionCall,
}
// State attached to an in-flight HTTP callout and handed back to the response handler.
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub struct StreamCallContext {
// Which handler should process the callout response.
pub response_handler_type: ResponseHandlerType,
// Content of the user message that triggered the callout, if any.
pub user_message: Option<String>,
// Set to the name of the first tool call once the model server has answered.
pub prompt_target_name: Option<String>,
// The client's original chat completions request; left out of `Debug` output.
#[derivative(Debug = "ignore")]
pub request_body: ChatCompletionsRequest,
// Always `None` in the code visible in this crate.
pub similarity_scores: Option<Vec<(String, f64)>>,
// Host reported in `ServerError::Upstream` when the callout fails.
pub upstream_cluster: Option<String>,
// Path reported in `ServerError::Upstream` when the callout fails.
pub upstream_cluster_path: Option<String>,
// Agent selected for the stream at the time the callout was made.
pub agent: Option<Agent>,
}
// Per-HTTP-stream state of the agent gateway filter.
pub struct StreamContext {
// Configured endpoints, shared from the root context.
pub endpoints: Rc<Option<HashMap<String, Endpoint>>>,
// Configuration overrides, shared from the root context.
pub overrides: Rc<Option<Overrides>>,
// Shared metrics handle.
pub metrics: Rc<Metrics>,
// In-flight callouts keyed by the token id of the HTTP call.
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
// Context id assigned when this stream context was created; used in log lines.
pub context_id: u32,
// Tool calls returned by the model server for the current request.
pub tool_calls: Option<Vec<ToolCall>>,
// Body returned by the tool endpoint call.
pub tool_call_response: Option<String>,
// Arch state parsed from the request metadata.
pub arch_state: Option<Vec<ArchState>>,
// Size of the buffered request body; used when the body is replaced.
pub request_body_size: usize,
// Last message with the user role found in the request.
pub user_prompt: Option<Message>,
// Value of the `stream` flag of the client's request.
pub streaming_response: bool,
// True when the request path is one of the chat completions paths.
pub is_chat_completions_request: bool,
// The client's deserialized chat completions request.
pub chat_completions_request: Option<ChatCompletionsRequest>,
// Request id header value, forwarded on callouts.
pub request_id: Option<String>,
// Nanoseconds since the Unix epoch, taken just before the request is resumed upstream.
pub start_upstream_llm_request_time: u128,
// Nanoseconds since the Unix epoch at which the first response body chunk was seen.
pub time_to_first_token: Option<u128>,
// Trace parent header value, forwarded on callouts.
pub traceparent: Option<String>,
// All configured agents, keyed by name.
pub agents: Rc<HashMap<String, Agent>>,
// Agent selected for this stream while processing the request headers.
pub agent: Option<Agent>,
// All configured tools, keyed by name.
pub tools: Rc<HashMap<String, Tool>>,
// Tracing configuration; only read by `_trace_arch_internal`.
pub _tracing: Rc<Option<Tracing>>,
// Initialised to `None`; no assignment is visible in this file.
pub arch_fc_response: Option<String>,
}
impl StreamContext {
pub fn new(
context_id: u32,
metrics: Rc<Metrics>,
endpoints: Rc<Option<HashMap<String, Endpoint>>>,
overrides: Rc<Option<Overrides>>,
tracing: Rc<Option<Tracing>>,
agents: Rc<HashMap<String, Agent>>,
tools: Rc<HashMap<String, Tool>>,
) -> Self {
StreamContext {
context_id,
metrics,
endpoints,
callouts: RefCell::new(HashMap::new()),
chat_completions_request: None,
tool_calls: None,
tool_call_response: None,
arch_state: None,
request_body_size: 0,
streaming_response: false,
user_prompt: None,
is_chat_completions_request: false,
overrides,
request_id: None,
traceparent: None,
_tracing: tracing,
start_upstream_llm_request_time: 0,
time_to_first_token: None,
arch_fc_response: None,
agents,
tools,
agent: None,
}
}
/// Sends a local HTTP response whose body is the display text of `error`.
///
/// The status is `override_status_code` when given, otherwise 500 Internal Server Error.
/// No extra headers are attached.
pub fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
    let status = override_status_code.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    let message = error.to_string();
    self.send_http_response(status.as_u16().into(), vec![], Some(message.as_bytes()));
}
/// Returns the configured `trace_arch_internal` flag, or `false` when tracing or
/// the flag itself is not configured.
fn _trace_arch_internal(&self) -> bool {
    if let Some(tracing) = self._tracing.as_ref() {
        if let Some(enabled) = tracing.trace_arch_internal.as_ref() {
            return *enabled;
        }
    }
    false
}
/// Handles the model server's answer to the function-calling request.
///
/// When the answer contains no tool calls it is returned to the client directly
/// (as server-sent events when the client asked for streaming, verbatim otherwise).
/// When it does contain tool calls, the first one is dispatched to its endpoint via
/// `schedule_api_call_request`.
///
/// A body that is not UTF-8, cannot be deserialized, or contains no choices results
/// in an error reply instead of a panic.
pub fn arch_fc_response_handler(
    &mut self,
    body: Vec<u8>,
    mut callout_context: StreamCallContext,
) {
    let body_str = match String::from_utf8(body) {
        Ok(body_str) => body_str,
        Err(e) => {
            warn!("model server response is not valid utf-8: {}", e);
            return self.send_server_error(
                ServerError::LogicError(format!(
                    "model server response is not valid utf-8: {}",
                    e
                )),
                None,
            );
        }
    };
    info!("on_http_call_response: model server response received");
    debug!("response body: {}", body_str);

    let model_server_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
        Ok(arch_fc_response) => arch_fc_response,
        Err(e) => {
            warn!(
                "error deserializing llm response: {}, body: {}",
                e, body_str
            );
            return self.send_server_error(ServerError::Deserialization(e), None);
        }
    };

    //TODO: try to avoid clone
    let message = match model_server_response.choices.first() {
        Some(choice) => choice.message.clone(),
        None => {
            warn!("model server response contained no choices");
            return self.send_server_error(
                ServerError::LogicError("model server response contained no choices".to_string()),
                None,
            );
        }
    };

    let message_content = message.content;
    self.tool_calls = message.tool_calls;

    let tool_call_count = self.tool_calls.as_ref().map_or(0, |calls| calls.len());
    if tool_call_count > 1 {
        warn!(
            "multiple tool calls not supported yet, tool_calls count found: {}",
            tool_call_count
        );
    }

    let first_tool_name = self
        .tool_calls
        .as_ref()
        .and_then(|calls| calls.first())
        .map(|call| call.function.name.clone());

    let tool_name = match first_tool_name {
        Some(tool_name) => tool_name,
        None => {
            // this means llm model didn't need additional data from tool calls and is ready to respond back to user
            let direct_response_str = if self.streaming_response {
                let chunks = vec![
                    ChatCompletionStreamResponse::new(
                        self.arch_fc_response.clone(),
                        Some(ASSISTANT_ROLE.to_string()),
                        Some(ARCH_FC_MODEL_NAME.to_string()),
                        None,
                    ),
                    ChatCompletionStreamResponse::new(
                        // A missing content field is streamed as an empty string.
                        Some(message_content.unwrap_or_default()),
                        None,
                        Some(format!("{}-Chat", ARCH_FC_MODEL_NAME.to_owned())),
                        None,
                    ),
                ];
                to_server_events(chunks)
            } else {
                body_str
            };

            self.tool_calls = None;
            return self.send_http_response(
                StatusCode::OK.as_u16().into(),
                vec![],
                Some(direct_response_str.as_bytes()),
            );
        }
    };

    // update prompt target name from the tool call response
    callout_context.prompt_target_name = Some(tool_name);

    self.schedule_api_call_request(callout_context);
}
fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
// Construct messages early to avoid mutable borrow conflicts
let tool_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
let tool = self.tools.get(&tool_name).unwrap().clone();
let tool_params = self.tool_calls.as_ref().unwrap()[0]
.function
.arguments
.clone();
let endpoint_details = tool.endpoint.as_ref().unwrap();
let endpoint_path: String = endpoint_details
.path
.as_ref()
.unwrap_or(&String::from("/"))
.to_string();
let http_method = endpoint_details.method.clone().unwrap_or_default();
let prompt_target_params = tool.parameters.clone().unwrap_or_default();
let mut tool_params_json: Option<HashMap<String, Value>> = None;
if let Some(params) = tool_params.as_ref() {
match serde_json::from_str::<HashMap<String, Value>>(params.as_str()) {
Ok(params_json) => tool_params_json = Some(params_json),
Err(e) => {
log::warn!(
"error deserializing tool params: {}, body str: {}",
e,
String::from_utf8(params.as_bytes().to_vec()).unwrap()
);
return self.send_server_error(
ServerError::Deserialization(e),
Some(StatusCode::BAD_REQUEST),
);
}
};
}
//TODO: fixme hack adilhafeez
let (path, api_call_body) = match compute_request_path_body(
&endpoint_path,
&tool_params_json,
&prompt_target_params,
&http_method,
) {
Ok((path, body)) => (path, body),
Err(e) => {
return self.send_server_error(
ServerError::BadRequest {
why: format!("error computing api request path or body: {}", e),
},
Some(StatusCode::BAD_REQUEST),
);
}
};
debug!("on_http_call_response: api call body {:?}", api_call_body);
let timeout_str = API_REQUEST_TIMEOUT_MS.to_string();
let http_method_str = http_method.to_string();
let mut headers: HashMap<_, _> = [
(ARCH_UPSTREAM_HOST_HEADER, endpoint_details.name.as_str()),
(":method", &http_method_str),
(":path", &path),
(":authority", endpoint_details.name.as_str()),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
]
.into_iter()
.collect();
if self.request_id.is_some() {
headers.insert(REQUEST_ID_HEADER, self.request_id.as_ref().unwrap());
}
if self.traceparent.is_some() {
headers.insert(TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap());
}
// override http headers that are set in the prompt target
let http_headers = endpoint_details.http_headers.clone().unwrap_or_default();
for (key, value) in http_headers.iter() {
headers.insert(key.as_str(), value.as_str());
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&path,
headers.into_iter().collect(),
api_call_body.as_deref().map(|s| s.as_bytes()),
vec![],
Duration::from_secs(5),
);
info!(
"on_http_call_response: dispatching api call to developer endpoint: {}, path: {}, method: {}",
endpoint_details.name, path, http_method_str
);
callout_context.upstream_cluster = Some(endpoint_details.name.to_owned());
callout_context.upstream_cluster_path = Some(path.to_owned());
callout_context.agent = self.agent.clone();
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
if let Err(e) = self.http_call(call_args, callout_context) {
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
}
}
/// Handles the tool endpoint's response and resumes the original request.
///
/// The tool response is appended as `context:` to the last message of the filtered
/// conversation, the resulting chat completions request replaces the paused
/// request body, and the request is resumed towards the upstream LLM.
///
/// Any 2xx status counts as success; other (or unparsable) statuses are reported to
/// the client as `ServerError::Upstream`. A missing `:status` header is treated as 200.
pub fn api_call_response_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
    let http_status = self
        .get_http_call_response_header(":status")
        .unwrap_or_else(|| StatusCode::OK.as_str().to_string());
    info!(
        "on_http_call_response: developer api call response received: status code: {}",
        http_status
    );

    let status_code = StatusCode::from_str(http_status.as_str()).ok();
    let is_success = status_code.map_or(false, |code| code.is_success());
    if !is_success {
        warn!(
            "api server responded with non 2xx status code: {}",
            http_status
        );
        return self.send_server_error(
            ServerError::Upstream {
                host: callout_context.upstream_cluster.unwrap_or_default(),
                path: callout_context.upstream_cluster_path.unwrap_or_default(),
                status: http_status,
                body: String::from_utf8_lossy(&body).into_owned(),
            },
            // `None` (unparsable status) falls back to the default error status.
            status_code,
        );
    }

    // The API response is untrusted; invalid UTF-8 is replaced rather than panicking.
    self.tool_call_response = Some(String::from_utf8_lossy(&body).into_owned());
    debug!(
        "response body: {}",
        self.tool_call_response.as_deref().unwrap_or_default()
    );

    let mut messages = self.construct_llm_messages(&callout_context);

    let user_message = match messages.pop() {
        Some(user_message) => user_message,
        None => {
            return self.send_server_error(
                ServerError::NoMessagesFound {
                    why: "no user messages found".to_string(),
                },
                None,
            );
        }
    };

    let final_prompt = format!(
        "{}\ncontext: {}",
        user_message.content.unwrap_or_default(),
        self.tool_call_response.as_deref().unwrap_or_default()
    );

    // add original user prompt
    messages.push(Message {
        role: USER_ROLE.to_string(),
        content: Some(final_prompt),
        model: None,
        tool_calls: None,
        tool_call_id: None,
    });

    let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
        model: callout_context.request_body.model,
        messages,
        tools: None,
        stream: callout_context.request_body.stream,
        stream_options: callout_context.request_body.stream_options,
        metadata: None,
    };

    let llm_request_str = match serde_json::to_string(&chat_completions_request) {
        Ok(json_string) => json_string,
        Err(e) => {
            return self.send_server_error(ServerError::Serialization(e), None);
        }
    };
    info!("on_http_call_response: sending request to upstream llm");
    debug!("request body: {}", llm_request_str);

    self.start_upstream_llm_request_time = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_nanos();

    // Swap the paused request body for the augmented one and let it continue upstream.
    self.set_http_request_body(0, self.request_body_size, llm_request_str.as_bytes());
    self.resume_http_request();
}
/// Returns clones of the messages that are meant for the upstream LLM.
///
/// A message is kept only when it is not a tool-role message, has content, and
/// carries no (non-empty) list of tool calls.
fn filter_out_arch_messages(&self, messages: &[Message]) -> Vec<Message> {
    let mut kept = Vec::new();
    for message in messages {
        let carries_tool_calls = message
            .tool_calls
            .as_ref()
            .map_or(false, |calls| !calls.is_empty());
        let keep = message.role != TOOL_ROLE && message.content.is_some() && !carries_tool_calls;
        if keep {
            kept.push(message.clone());
        }
    }
    kept
}
/// Builds the message list sent to the upstream LLM: the agent's system prompt
/// (when the callout's agent has one) followed by the filtered request messages.
fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
    let system_prompt = callout_context
        .agent
        .as_ref()
        .and_then(|agent| agent.system_prompt.as_ref());

    let mut llm_messages: Vec<Message> = Vec::new();
    if let Some(prompt) = system_prompt {
        llm_messages.push(Message {
            role: SYSTEM_ROLE.to_string(),
            content: Some(prompt.clone()),
            model: None,
            tool_calls: None,
            tool_call_id: None,
        });
    }

    let conversation =
        self.filter_out_arch_messages(callout_context.request_body.messages.as_ref());
    llm_messages.extend(conversation);

    llm_messages
}
}
// `Client` plumbing for the stream context: exposes the callout table and the gauge that
// tracks in-flight HTTP calls.
impl Client for StreamContext {
type CallContext = StreamCallContext;
// Returns the token_id -> call-context map owned by this stream.
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
&self.callouts
}
// Returns the `active_http_calls` gauge held in the shared metrics.
fn active_http_calls(&self) -> &Gauge {
&self.metrics.active_http_calls
}
}
// Unit tests for intent matching on model server responses.
// NOTE(review): `check_intent_matched` is imported from `crate::stream_context` below, but no
// function with that name is visible in this file - confirm it exists, otherwise this module
// does not compile under `cargo test`.
#[cfg(test)]
mod test {
use common::api::open_ai::{ChatCompletionsResponse, Choice, Message, ToolCall};
use crate::stream_context::check_intent_matched;
#[test]
fn test_intent_matched() {
// Empty content and no tool calls: expected not to match.
let model_server_response = ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
content: Some("".to_string()),
tool_calls: Some(vec![]),
role: "assistant".to_string(),
model: None,
tool_call_id: None,
},
finish_reason: None,
index: None,
}],
usage: None,
model: "arch-fc".to_string(),
metadata: None,
};
assert!(!check_intent_matched(&model_server_response));
// Non-empty content and no tool calls: expected to match.
let model_server_response = ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
content: Some("hello".to_string()),
tool_calls: Some(vec![]),
role: "assistant".to_string(),
model: None,
tool_call_id: None,
},
finish_reason: None,
index: None,
}],
usage: None,
model: "arch-fc".to_string(),
metadata: None,
};
assert!(check_intent_matched(&model_server_response));
// Empty content but one tool call: expected to match.
let model_server_response = ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
content: Some("".to_string()),
tool_calls: Some(vec![ToolCall {
id: "1".to_string(),
function: common::api::open_ai::FunctionCallDetail {
name: "test".to_string(),
arguments: None,
},
tool_type: common::api::open_ai::ToolType::Function,
}]),
role: "assistant".to_string(),
model: None,
tool_call_id: None,
},
finish_reason: None,
index: None,
}],
usage: None,
model: "arch-fc".to_string(),
metadata: None,
};
assert!(check_intent_matched(&model_server_response));
}
}

View file

@ -0,0 +1,162 @@
use common::configuration::{HttpMethod, Parameter};
use std::collections::HashMap;
use serde_yaml::Value;
/// Renders a YAML value as a string when it is a scalar (number, string or bool).
///
/// Tagged values are unwrapped and judged by their inner value. Null,
/// sequences and mappings yield `None`.
fn scalar_to_string(value: &Value) -> Option<String> {
    match value {
        Value::Number(number) => Some(number.to_string()),
        Value::String(text) => Some(text.clone()),
        Value::Bool(flag) => Some(flag.to_string()),
        Value::Tagged(tagged) => scalar_to_string(&tagged.value),
        Value::Null | Value::Sequence(_) | Value::Mapping(_) => None,
    }
}

/// Keeps only the tool parameters whose value is a string, number or bool and
/// returns them as `name -> stringified value`.
///
/// `None` input and non-scalar values (null, sequences, mappings) produce no
/// entries. Tagged scalars are accepted and converted via their inner value
/// instead of reaching an unimplemented branch.
pub fn filter_tool_params(tool_params: &Option<HashMap<String, Value>>) -> HashMap<String, String> {
    tool_params
        .iter()
        .flatten()
        .filter_map(|(key, value)| scalar_to_string(value).map(|text| (key.clone(), text)))
        .collect()
}
/// Computes the request path and optional JSON body for a tool/endpoint call.
///
/// Scalar tool parameters are substituted into `endpoint_path` via
/// `common::path::replace_params_in_path`, which also returns a query string
/// and a map of additional parameters.
///
/// * `HttpMethod::Get`: returns `path?query` and no body. When the query
///   string is empty, returns just the path with no dangling `?`.
/// * `HttpMethod::Post`: returns the bare path, and a JSON object body built
///   from the additional parameters plus every `key=value` pair of the query
///   string.
///
/// # Errors
/// Returns the error from `replace_params_in_path`, or the serialization
/// error message if the body cannot be encoded as JSON.
pub fn compute_request_path_body(
    endpoint_path: &str,
    tool_params: &Option<HashMap<String, Value>>,
    prompt_target_params: &[Parameter],
    http_method: &HttpMethod,
) -> Result<(String, Option<String>), String> {
    let tool_url_params = filter_tool_params(tool_params);
    let (path_with_params, query_string, mut additional_params) =
        common::path::replace_params_in_path(
            endpoint_path,
            &tool_url_params,
            prompt_target_params,
        )?;

    match http_method {
        HttpMethod::Get if query_string.is_empty() => Ok((path_with_params, None)),
        HttpMethod::Get => Ok((format!("{}?{}", path_with_params, query_string), None)),
        HttpMethod::Post => {
            // NOTE(review): values are copied from the query string as-is; the unit
            // test for GET expects percent-encoded output ("hello%20world"), so
            // these may still be percent-encoded here. Confirm whether the JSON
            // body should carry decoded values.
            for param in query_string.split('&').filter(|param| !param.is_empty()) {
                // Split on the first '=' only, so values containing '=' stay intact.
                // A bare key with no '=' maps to an empty value rather than panicking.
                let (key, value) = param.split_once('=').unwrap_or((param, ""));
                additional_params.insert(key.to_string(), value.to_string());
            }
            let body = serde_json::to_string(&additional_params)
                .map_err(|err| err.to_string())?;
            Ok((path_with_params, Some(body)))
        }
    }
}
#[cfg(test)]
mod test {
    use common::configuration::{HttpMethod, Parameter};

    /// Prompt-target parameters shared by all tests: a single optional
    /// `country` parameter whose default is `US`.
    fn country_param_with_default() -> Vec<Parameter> {
        vec![Parameter {
            name: "country".to_string(),
            parameter_type: None,
            description: "test target".to_string(),
            required: None,
            enum_values: None,
            default: Some("US".to_string()),
            in_path: None,
            format: None,
        }]
    }

    /// Parses `tool_params_yaml` and runs `compute_request_path_body` as a GET
    /// against `endpoint_path` with the shared `country` parameter.
    fn resolve_get(endpoint_path: &str, tool_params_yaml: &str) -> (String, Option<String>) {
        let tool_params = serde_yaml::from_str(tool_params_yaml).unwrap();
        super::compute_request_path_body(
            endpoint_path,
            &tool_params,
            &country_param_with_default(),
            &HttpMethod::Get,
        )
        .unwrap()
    }

    #[test]
    fn test_compute_request_path_body() {
        // `cluster_name` fills the path placeholder; the rest go to the query string.
        let (path, body) = resolve_get(
            "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}",
            r#"
cluster_name: test1
hello: hello world
"#,
        );
        assert_eq!(
            path,
            "/cluster.open-cluster-management.io/v1/managedclusters/test1?hello=hello%20world&country=US"
        );
        assert_eq!(body, None);
    }

    #[test]
    fn test_compute_request_path_body_empty_params() {
        // With no tool params, only the default value shows up.
        let (path, body) = resolve_get(
            "/cluster.open-cluster-management.io/v1/managedclusters/",
            r#"{}"#,
        );
        assert_eq!(
            path,
            "/cluster.open-cluster-management.io/v1/managedclusters/?country=US"
        );
        assert_eq!(body, None);
    }

    #[test]
    fn test_compute_request_path_body_override_default_val() {
        // An explicit tool param wins over the configured default.
        let (path, body) = resolve_get(
            "/cluster.open-cluster-management.io/v1/managedclusters/",
            r#"
country: UK
"#,
        );
        assert_eq!(
            path,
            "/cluster.open-cluster-management.io/v1/managedclusters/?country=UK"
        );
        assert_eq!(body, None);
    }
}

View file

@ -0,0 +1,690 @@
use common::api::open_ai::{
ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage,
};
use common::configuration::Configuration;
use http::StatusCode;
use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
};
use serde_yaml::Value;
use serial_test::serial;
use std::collections::HashMap;
use std::path::Path;
/// Returns the path of the compiled wasm filter under test.
///
/// # Panics
/// Panics with a build hint when the wasm artifact does not exist.
// NOTE(review): the artifact name is `prompt_gateway.wasm`; confirm this is the
// module these tests are meant to load.
fn wasm_module() -> String {
    const WASM_PATH: &str = "../target/wasm32-wasip1/release/prompt_gateway.wasm";
    assert!(
        Path::new(WASM_PATH).exists(),
        "Run `cargo build --release --target=wasm32-wasip1` first"
    );
    WASM_PATH.to_string()
}
/// Invokes the request-headers callback on `http_context` and registers the
/// host calls the filter is expected to make, in order. The callback must
/// return `Action::Continue`.
fn request_headers_expectations(module: &mut Tester, http_context: i32) {
module
.call_proxy_on_request_headers(http_context, 0, false)
.expect_log(Some(LogLevel::Debug), None)
// content-length is expected to be removed from the request headers
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
// the mocked request targets the chat completions path
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
.returning(Some("/v1/chat/completions"))
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
// neither a request id nor a trace parent is supplied by the mock
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent"))
.returning(None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
/// Shared prefix of the happy-path tests: creates the HTTP context, replays the
/// request-headers expectations, then sends a chat completions request body and
/// expects the filter to dispatch a `/function_calling` call to `model_server`
/// (mocked to return token id 1) and pause the request.
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
request_headers_expectations(module, http_context);
// Request Body
let chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
],\
\"model\": \"gpt-4\"\
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "model_server"),
(":method", "POST"),
(":path", "/function_calling"),
("content-type", "application/json"),
(":authority", "model_server"),
("x-envoy-upstream-rq-timeout-ms", "30000"),
]),
None,
None,
Some(5000),
)
// token id handed back to the filter; callers answer this call with id 1
.returning(Some(1))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
/// Creates the filter (root) context with id 1, expecting the
/// `active_http_calls` gauge to be created, then configures it with `config`.
/// Returns the filter context id.
fn setup_filter(module: &mut Tester, config: &str) -> i32 {
let filter_context = 1;
module
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_configure(filter_context, config.len() as i32)
// the filter reads its configuration from the plugin configuration buffer
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
filter_context
}
/// Baseline gateway configuration (YAML) used by the tests: one `gpt-4` LLM
/// provider, a `weather_forecast` prompt target that POSTs to `/weather` on the
/// `api_server` endpoint, and a `gpt-4` rate limit of 1 token per minute.
fn default_config() -> &'static str {
r#"
version: "0.1-beta"
listener:
address: 0.0.0.0
port: 10000
message_format: huggingface
connect_timeout: 0.005s
endpoints:
api_server:
endpoint: api_server:80
connect_timeout: 0.005s
llm_providers:
- name: open-ai-gpt-4
provider_interface: openai
access_key: secret_key
model: gpt-4
default: true
overrides:
# confidence threshold for prompt target intent matching
prompt_target_intent_matching_threshold: 0.0
system_prompt: |
You are a helpful assistant.
prompt_guards:
input_guards:
jailbreak:
on_exception:
message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
prompt_targets:
- name: weather_forecast
description: This function provides realtime weather forecast information for a given city.
parameters:
- name: city
required: true
description: The city for which the weather forecast is requested.
- name: days
description: The number of days for which the weather forecast is requested.
- name: units
description: The units in which the weather forecast is requested.
endpoint:
name: api_server
path: /weather
http_method: POST
system_prompt: |
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
ratelimits:
- model: gpt-4
selector:
key: selector-key
value: selector-value
limit:
tokens: 1
unit: minute
"#
}
/// A well-formed chat completions request makes the filter issue an HTTP call
/// to the `arch_internal` cluster and pause the request.
#[test]
#[serial]
fn prompt_gateway_successful_request_to_open_ai_chat_completions() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
request_headers_expectations(&mut module, http_context);
// Request Body
let chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
],\
\"model\": \"gpt-4\"\
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
// only the target cluster is checked here; headers, body and timeout are left unconstrained
.expect_http_call(Some("arch_internal"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
/// A malformed chat completions body (system message without content, no
/// model, trailing comma) is answered with a local 400 Bad Request response.
#[test]
#[serial]
fn prompt_gateway_bad_request_to_open_ai_chat_completions() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
request_headers_expectations(&mut module, http_context);
// Request Body
let incomplete_chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
]\
}";
module
.call_proxy_on_request_body(
http_context,
incomplete_chat_completions_request_body.len() as i32,
true,
)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(incomplete_chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
// the filter must reject the request itself instead of calling upstream
.expect_send_local_response(
Some(StatusCode::BAD_REQUEST.as_u16().into()),
None,
None,
None,
)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
/// End-to-end happy path: the function-calling response selects the
/// `weather_forecast` target, the filter calls `api_server` at `/weather`, the
/// API response is written into the request body, and finally the upstream
/// chat completion response body is processed and the response continues.
#[test]
#[serial]
fn prompt_gateway_request_to_llm_gateway() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap();
// raise the token limit of the first rate limit above the default of 1
config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
let config_str = serde_json::to_string(&config).unwrap();
let filter_context = setup_filter(&mut module, &config_str);
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
// Mocked function-calling response: one tool call to `weather_forecast` with city=seattle.
// NOTE(review): the open_ai.rs hunk in this same change appears to turn
// `FunctionCallDetail::arguments` into `Option<String>`; confirm this
// `HashMap<String, Value>` value still type-checks against it.
let arch_fc_resp = ChatCompletionsResponse {
usage: Some(Usage {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: Some("test".to_string()),
index: Some(0),
message: Message {
role: "system".to_string(),
content: None,
tool_calls: Some(vec![ToolCall {
id: String::from("test"),
tool_type: ToolType::Function,
function: FunctionCallDetail {
name: String::from("weather_forecast"),
arguments: Some(HashMap::from([(
String::from("city"),
Value::String(String::from("seattle")),
)])),
},
}]),
model: None,
tool_call_id: None,
},
}],
model: String::from("test"),
metadata: {
let mut map: HashMap<String, String> = HashMap::new();
map.insert("function_latency".to_string(), "0.0".to_string());
Some(map)
},
};
let expected_body = "{\"city\":\"seattle\"}";
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
// Answer the call with token id 1 (the one issued in `normal_flow`).
module
.call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
// the tool call is turned into a POST to the prompt target's endpoint
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-envoy-max-retries", "3"),
("x-arch-upstream", "api_server"),
("content-type", "application/json"),
("x-envoy-upstream-rq-timeout-ms", "30000"),
(":path", "/weather"),
(":method", "POST"),
(":authority", "api_server"),
]),
Some(expected_body),
None,
Some(5000),
)
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
let body_text = String::from("test body");
// Answer the API call (token id 2) with a 200 and a plain-text body.
module
.call_proxy_on_http_call_response(http_context, 2, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
.returning(Some("200"))
// the filter rewrites the request body before the request proceeds
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
// Mocked upstream chat completion response delivered through the response-body callback.
let chat_completion_response = ChatCompletionsResponse {
usage: Some(Usage {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: Some("test".to_string()),
index: Some(0),
message: Message {
role: "assistant".to_string(),
content: Some("hello from fake llm gateway".to_string()),
model: None,
tool_calls: None,
tool_call_id: None,
},
}],
model: String::from("test"),
metadata: None,
};
let chat_completion_response_str = serde_json::to_string(&chat_completion_response).unwrap();
module
.call_proxy_on_response_body(
http_context,
chat_completion_response_str.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpResponseBody))
.returning(Some(chat_completion_response_str.as_str()))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
/// When the function-calling response carries neither content nor tool calls
/// and no default prompt target is configured, the filter logs that no intent
/// matched and rewrites the request body for the upstream LLM.
#[test]
#[serial]
fn prompt_gateway_request_no_intent_match() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap();
// raise the token limit of the first rate limit above the default of 1
config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
let config_str = serde_json::to_string(&config).unwrap();
let filter_context = setup_filter(&mut module, &config_str);
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
// Mocked function-calling response with no content and no tool calls.
let arch_fc_resp = ChatCompletionsResponse {
usage: Some(Usage {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: Some("test".to_string()),
index: Some(0),
message: Message {
role: "assistant".to_string(),
content: None,
tool_calls: None,
model: None,
tool_call_id: None,
},
}],
model: String::from("test"),
metadata: None,
};
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
// Answer the call with token id 1 (the one issued in `normal_flow`).
module
.call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), Some("intent matched: false"))
.expect_log(
Some(LogLevel::Info),
Some("no default prompt target found, forwarding request to upstream llm"),
)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Info), None)
// no further HTTP call is expected; only the request body is rewritten
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::None)
.unwrap();
}
/// Same configuration as `default_config`, plus a second prompt target named
/// `default_target` (marked `default: true`) that POSTs to `/default_target`
/// on `weather_forecast_service` with `auto_llm_dispatch_on_response: false`.
fn arch_config_default_target() -> &'static str {
r#"
version: "0.1-beta"
listener:
address: 0.0.0.0
port: 10000
message_format: huggingface
connect_timeout: 0.005s
endpoints:
api_server:
endpoint: api_server:80
connect_timeout: 0.005s
llm_providers:
- name: open-ai-gpt-4
provider_interface: openai
access_key: secret_key
model: gpt-4
default: true
overrides:
# confidence threshold for prompt target intent matching
prompt_target_intent_matching_threshold: 0.0
system_prompt: |
You are a helpful assistant.
prompt_guards:
input_guards:
jailbreak:
on_exception:
message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
prompt_targets:
- name: weather_forecast
description: This function provides realtime weather forecast information for a given city.
parameters:
- name: city
required: true
description: The city for which the weather forecast is requested.
- name: days
description: The number of days for which the weather forecast is requested.
- name: units
description: The units in which the weather forecast is requested.
endpoint:
name: api_server
path: /weather
http_method: POST
system_prompt: |
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
- name: default_target
default: true
description: This is the default target for all unmatched prompts.
endpoint:
name: weather_forecast_service
path: /default_target
http_method: POST
system_prompt: |
You are a helpful assistant! Summarize the user's request and provide a helpful response.
# if it is set to false arch will send response that it received from this prompt target to the user
# if true arch will forward the response to the default LLM
auto_llm_dispatch_on_response: false
ratelimits:
- model: gpt-4
selector:
key: selector-key
value: selector-value
limit:
tokens: 1
unit: minute
"#
}
/// When no intent matches but a default prompt target is configured, the
/// filter forwards the request to that target: a POST to `/default_target` on
/// `weather_forecast_service`.
#[test]
#[serial]
fn prompt_gateway_request_no_intent_match_default_target() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let mut config: Configuration = serde_yaml::from_str(arch_config_default_target()).unwrap();
// raise the token limit of the first rate limit above the default of 1
config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
let config_str = serde_json::to_string(&config).unwrap();
let filter_context = setup_filter(&mut module, &config_str);
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
// Mocked function-calling response with no content and no tool calls.
let arch_fc_resp = ChatCompletionsResponse {
usage: Some(Usage {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: Some("test".to_string()),
index: Some(0),
message: Message {
role: "system".to_string(),
content: None,
tool_calls: None,
model: None,
tool_call_id: None,
},
}],
model: String::from("test"),
metadata: None,
};
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
// Answer the call with token id 1 (the one issued in `normal_flow`).
module
.call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), Some("intent matched: false"))
.expect_log(
Some(LogLevel::Info),
Some("default prompt target found, forwarding request to default prompt target"),
)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
// the request is dispatched to the default target's endpoint; the body is not checked
.expect_http_call(
Some("arch_internal"),
Some(vec![
(":method", "POST"),
("x-arch-upstream", "weather_forecast_service"),
(":path", "/default_target"),
(":authority", "weather_forecast_service"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "30000"),
]),
None,
None,
Some(5000),
)
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
}

View file

@ -1,6 +1,5 @@
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use serde_yaml::Value;
use std::{
collections::{HashMap, VecDeque},
fmt::Display,
@ -43,6 +42,8 @@ pub struct FunctionDefinition {
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionParameters {
#[serde(rename = "type")]
pub properties_type: String,
pub properties: HashMap<String, FunctionParameter>,
}
@ -51,7 +52,7 @@ impl Serialize for FunctionParameters {
where
S: serde::Serializer,
{
// select all requried parameters
// select all required parameters
let required: Vec<&String> = self
.properties
.iter()
@ -60,6 +61,7 @@ impl Serialize for FunctionParameters {
.collect();
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("properties", &self.properties)?;
map.serialize_entry("type", &self.properties_type)?;
if !required.is_empty() {
map.serialize_entry("required", &required)?;
}
@ -113,7 +115,7 @@ pub enum ParameterType {
Float,
#[serde(rename = "bool")]
Bool,
#[serde(rename = "str")]
#[serde(rename = "string")]
String,
#[serde(rename = "list")]
List,
@ -189,7 +191,7 @@ pub struct ToolCall {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDetail {
pub name: String,
pub arguments: Option<HashMap<String, Value>>,
pub arguments: Option<String>,
}
#[derive(Debug, Deserialize, Serialize)]

View file

@ -19,6 +19,37 @@ pub struct Configuration {
pub ratelimits: Option<Vec<Ratelimit>>,
pub tracing: Option<Tracing>,
pub mode: Option<GatewayMode>,
pub agents: HashMap<String, Agent>,
pub tools: HashMap<String, Tool>,
}
/// Configuration entry for one agent; `Configuration::agents` maps a string key to these.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Agent {
/// Agent name.
pub name: String,
/// Free-text description of the agent.
pub description: String,
/// Optional list of default input modes (semantics not visible here; presumably content types, TODO confirm).
pub default_input_modes: Option<Vec<String>>,
/// Optional list of default output modes (same caveat as `default_input_modes`).
pub default_output_modes: Option<Vec<String>>,
/// Optional skills advertised by the agent.
pub skills: Option<Vec<Skill>>,
/// Model identifier associated with the agent.
pub model: String,
/// Optional orchestrator prompt (presumably used when routing between agents; verify against callers).
pub agent_orchestrator_prompt: Option<String>,
/// Optional system prompt for the agent.
pub system_prompt: Option<String>,
/// Optional list of tool names (presumably keys into `Configuration::tools`; TODO confirm).
pub tools: Option<Vec<String>>,
}
/// One skill entry listed under `Agent::skills`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Skill {
/// Skill identifier.
pub id: String,
/// Skill name.
pub name: String,
/// Free-text description of the skill.
pub description: String,
/// Optional example strings for the skill (presumably sample prompts; TODO confirm).
pub examples: Option<Vec<String>>,
}
/// Configuration entry for one tool; `Configuration::tools` maps a string key to these.
/// A `Tool` can be converted into a `ChatCompletionTool` via `From<&Tool>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
/// Tool name; used as the function name in the converted tool definition.
pub name: String,
/// Free-text description; used as the function description in the converted tool definition.
pub description: String,
/// Optional endpoint details (presumably where the tool call is sent; verify against callers).
pub endpoint: Option<EndpointDetails>,
/// Optional parameter declarations; each becomes a function parameter in the converted tool definition.
pub parameters: Option<Vec<Parameter>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
@ -260,7 +291,48 @@ impl From<&PromptTarget> for ChatCompletionTool {
function: FunctionDefinition {
name: val.name.clone(),
description: val.description.clone(),
parameters: FunctionParameters { properties },
parameters: FunctionParameters {
properties,
properties_type: "object".to_string(),
},
},
}
}
}
// convert Tool to ChatCompletionTool
impl From<&Tool> for ChatCompletionTool {
fn from(val: &Tool) -> Self {
let properties: HashMap<String, FunctionParameter> = match val.parameters {
Some(ref entities) => {
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
for entity in entities.iter() {
let param = FunctionParameter {
parameter_type: ParameterType::from(
entity.parameter_type.clone().unwrap_or("str".to_string()),
),
description: entity.description.clone(),
required: entity.required,
enum_values: entity.enum_values.clone(),
default: entity.default.clone(),
format: entity.format.clone(),
};
properties.insert(entity.name.clone(), param);
}
properties
}
None => HashMap::new(),
};
ChatCompletionTool {
tool_type: crate::api::open_ai::ToolType::Function,
function: FunctionDefinition {
name: val.name.clone(),
description: val.description.clone(),
parameters: FunctionParameters {
properties,
properties_type: "object".to_string(),
},
},
}
}

View file

@ -11,7 +11,8 @@ pub const MODEL_SERVER_NAME: &str = "model_server";
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const MESSAGES_KEY: &str = "messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
pub const CHAT_COMPLETIONS_PATH: [&str; 2] = ["/v1/chat/completions", "/openai/v1/chat/completions"];
pub const CHAT_COMPLETIONS_PATH: [&str; 2] =
["/v1/chat/completions", "/openai/v1/chat/completions"];
pub const HEALTHZ_PATH: &str = "/healthz";
pub const X_ARCH_STATE_HEADER: &str = "x-arch-state";
pub const X_ARCH_API_RESPONSE: &str = "x-arch-api-response-message";

View file

@ -371,7 +371,6 @@ impl StreamContext {
let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
let tool_params = &self.tool_calls.as_ref().unwrap()[0].function.arguments;
let endpoint_details = prompt_target.endpoint.as_ref().unwrap();
let endpoint_path: String = endpoint_details
.path
@ -382,9 +381,10 @@ impl StreamContext {
let http_method = endpoint_details.method.clone().unwrap_or_default();
let prompt_target_params = prompt_target.parameters.clone().unwrap_or_default();
//TODO: fixme: adilhafeez hack
let (path, api_call_body) = match compute_request_path_body(
&endpoint_path,
tool_params,
&None,
&prompt_target_params,
&http_method,
) {
@ -777,18 +777,19 @@ impl StreamContext {
fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool {
let content = model_server_response
.choices.first()
.choices
.first()
.and_then(|choice| choice.message.content.as_ref());
let content_has_value = content.is_some() && !content.unwrap().is_empty();
let tool_calls = model_server_response
.choices.first()
.choices
.first()
.and_then(|choice| choice.message.tool_calls.as_ref());
// intent was matched if content has some value or tool_calls is not empty
content_has_value || (tool_calls.is_some() && !tool_calls.unwrap().is_empty())
}