Add prem support for a2a agents

This commit is contained in:
Adil Hafeez 2025-04-25 00:57:13 -07:00
parent 2e346143dd
commit 299f183e66
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
23 changed files with 2544 additions and 16 deletions

View file

@ -4,8 +4,7 @@ RUN rustup -v target add wasm32-wasip1
WORKDIR /arch
COPY crates .
RUN cd prompt_gateway && cargo build --release --target wasm32-wasip1
RUN cd llm_gateway && cargo build --release --target wasm32-wasip1
RUN cargo build --release --target wasm32-wasip1
# copy built filter into envoy image
FROM docker.io/envoyproxy/envoy:v1.32-latest as envoy
@ -17,6 +16,7 @@ RUN apt-get update && apt-get install -y gettext-base curl && apt-get clean && r
COPY --from=builder /arch/target/wasm32-wasip1/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm
COPY --from=builder /arch/target/wasm32-wasip1/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm
COPY --from=builder /arch/target/wasm32-wasip1/release/agent_gateway.wasm /etc/envoy/proxy-wasm-plugins/agent_gateway.wasm
COPY --from=envoy /usr/local/bin/envoy /usr/local/bin/envoy
WORKDIR /app
COPY arch/requirements.txt .
@ -29,4 +29,4 @@ RUN pip install requests
RUN touch /var/log/envoy.log
# ENTRYPOINT ["sh","-c", "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --log-level trace 2>&1 | tee /var/log/envoy.log"]
ENTRYPOINT ["sh","-c", "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:info 2>&1 | tee /var/log/envoy.log"]
ENTRYPOINT ["sh","-c", "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:debug 2>&1 | tee /var/log/envoy.log"]

View file

@ -3,6 +3,10 @@ type: object
properties:
version:
type: string
tools:
type: object
agents:
type: object
listeners:
type: object
additionalProperties: false

View file

@ -452,6 +452,125 @@ static_resources:
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
- name: ingress_traffic_agent
address:
socket_address:
address: 0.0.0.0
port_value: 14000
traffic_direction: INBOUND
filter_chains:
- filters:
- name: envoy.filters.network.http_connection_manager
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
{% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
generate_request_id: true
tracing:
provider:
name: envoy.tracers.opentelemetry
typed_config:
"@type": type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig
grpc_service:
envoy_grpc:
cluster_name: opentelemetry_collector
timeout: 0.250s
service_name: ingress_traffic_agent
random_sampling:
value: {{ arch_tracing.random_sampling }}
{% endif %}
stat_prefix: ingress_traffic_agent
codec_type: AUTO
scheme_header_transformation:
scheme_to_overwrite: https
access_log:
- name: envoy.access_loggers.file
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
path: "/var/log/access_ingress_agent.log"
route_config:
name: local_routes
virtual_hosts:
- name: local_service
domains:
- "*"
routes:
{% for provider in arch_llm_providers %}
# if endpoint is set then use custom cluster for upstream llm
{% if provider.endpoint %}
{% set llm_cluster_name = provider.name %}
{% else %}
{% set llm_cluster_name = provider.provider_interface %}
{% endif %}
- match:
prefix: "/"
headers:
- name: "x-arch-llm-provider"
string_match:
exact: {{ llm_cluster_name }}
route:
auto_host_rewrite: true
cluster: {{ llm_cluster_name }}
timeout: 60s
{% endfor %}
http_filters:
- name: envoy.filters.http.compressor
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.compressor.v3.Compressor
compressor_library:
name: compress
typed_config:
"@type": type.googleapis.com/envoy.extensions.compression.gzip.compressor.v3.Gzip
memory_level: 3
window_bits: 10
- name: envoy.filters.http.wasm_agent
typed_config:
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
value:
config:
name: "http_config"
root_id: agent_gateway
configuration:
"@type": "type.googleapis.com/google.protobuf.StringValue"
value: |
{{ arch_config | indent(32) }}
vm_config:
runtime: "envoy.wasm.runtime.v8"
code:
local:
filename: "/etc/envoy/proxy-wasm-plugins/agent_gateway.wasm"
- name: envoy.filters.http.wasm_llm
typed_config:
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
value:
config:
name: "http_config"
root_id: llm_gateway
configuration:
"@type": "type.googleapis.com/google.protobuf.StringValue"
value: |
{{ arch_llm_config | indent(32) }}
vm_config:
runtime: "envoy.wasm.runtime.v8"
code:
local:
filename: "/etc/envoy/proxy-wasm-plugins/llm_gateway.wasm"
- name: envoy.filters.http.decompressor
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.decompressor.v3.Decompressor
decompressor_library:
name: decompress
typed_config:
"@type": "type.googleapis.com/envoy.extensions.compression.gzip.decompressor.v3.Gzip"
window_bits: 9
chunk_size: 8192
# If this ratio is set too low, then body data will not be decompressed completely.
max_inflate_ratio: 1000
- name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
clusters:
- name: openai
connect_timeout: 0.5s

View file

@ -49,6 +49,7 @@ def docker_start_archgw_detached(
port_mappings = [
f"{prompt_gateway_port}:{prompt_gateway_port}",
f"{llm_gateway_port}:{llm_gateway_port}",
"14000:14000",
"9901:19901",
]
port_mappings_args = [item for port in port_mappings for item in ("-p", port)]
@ -56,7 +57,7 @@ def docker_start_archgw_detached(
volume_mappings = [
f"{logs_path_abs}:/var/log:rw",
f"{arch_config_file}:/app/arch_config.yaml:ro",
# "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro",
"/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro",
]
volume_mappings_args = [
item for volume in volume_mappings for item in ("-v", volume)

23
crates/Cargo.lock generated
View file

@ -20,6 +20,29 @@ dependencies = [
"gimli",
]
[[package]]
name = "agent_gateway"
version = "0.1.0"
dependencies = [
"acap",
"common",
"derivative",
"governor",
"http",
"log",
"md5",
"pretty_assertions",
"proxy-wasm",
"proxy-wasm-test-framework",
"rand",
"serde",
"serde_json",
"serde_yaml",
"serial_test",
"sha2",
"thiserror",
]
[[package]]
name = "ahash"
version = "0.3.8"

View file

@ -1,3 +1,3 @@
[workspace]
resolver = "2"
members = ["llm_gateway", "prompt_gateway", "common"]
members = ["llm_gateway", "prompt_gateway", "agent_gateway", "common"]

View file

@ -0,0 +1,29 @@
[package]
name = "agent_gateway"
version = "0.1.0"
authors = ["Katanemo Inc <info@katanemo.com>"]
edition = "2021"
[lib]
crate-type = ["cdylib"]
[dependencies]
proxy-wasm = "0.2.1"
log = "0.4"
serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9.34"
serde_json = "1.0"
md5 = "0.7.0"
common = { path = "../common" }
http = "1.1.0"
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
acap = "0.3.0"
rand = "0.8.5"
thiserror = "1.0.64"
derivative = "2.2.0"
sha2 = "0.10.8"
[dev-dependencies]
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
serial_test = "3.1.1"
pretty_assertions = "1.4.1"

View file

@ -0,0 +1,66 @@
use std::str::FromStr;
use common::errors::ServerError;
use common::stats::IncrementingMetric;
use http::StatusCode;
use log::warn;
use proxy_wasm::traits::Context;
use crate::stream_context::{ResponseHandlerType, StreamContext};
impl Context for StreamContext {
fn on_http_call_response(
&mut self,
token_id: u32,
_num_headers: usize,
body_size: usize,
_num_trailers: usize,
) {
let callout_context = self
.callouts
.get_mut()
.remove(&token_id)
.expect("invalid token_id");
self.metrics.active_http_calls.increment(-1);
let body = self
.get_http_call_response_body(0, body_size)
.unwrap_or_default();
if let Some(http_status) = self.get_http_call_response_header(":status") {
match StatusCode::from_str(http_status.as_str()) {
Ok(status_code) => {
if !status_code.is_success() {
let server_error = ServerError::Upstream {
host: callout_context.upstream_cluster.unwrap(),
path: callout_context.upstream_cluster_path.unwrap(),
status: http_status.clone(),
body: String::from_utf8(body).unwrap(),
};
warn!("received non 2xx code: {:?}", server_error);
return self.send_server_error(
server_error,
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
);
}
}
Err(_) => {
// invalid status code (status code non numeric)
return self.send_server_error(
ServerError::LogicError(format!("invalid status code: {}", http_status)),
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
);
}
}
} else {
// :status header not found
warn!("missing :status header");
}
#[cfg_attr(any(), rustfmt::skip)]
match callout_context.response_handler_type {
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
}
}
}

View file

@ -0,0 +1,121 @@
use crate::metrics::Metrics;
use crate::stream_context::StreamContext;
use common::configuration::{
Agent, Configuration, Endpoint, Overrides, PromptGuards, PromptTarget, Tool, Tracing,
};
use common::http::Client;
use common::stats::Gauge;
use log::trace;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
#[derive(Debug)]
pub struct FilterCallContext {}
#[derive(Debug)]
pub struct FilterContext {
metrics: Rc<Metrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, FilterCallContext>>,
overrides: Rc<Option<Overrides>>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
agents: Rc<HashMap<String, Agent>>,
tools: Rc<HashMap<String, Tool>>,
endpoints: Rc<Option<HashMap<String, Endpoint>>>,
prompt_guards: Rc<PromptGuards>,
tracing: Rc<Option<Tracing>>,
}
impl FilterContext {
pub fn new() -> FilterContext {
FilterContext {
callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(Metrics::new()),
system_prompt: Rc::new(None),
prompt_targets: Rc::new(HashMap::new()),
overrides: Rc::new(None),
prompt_guards: Rc::new(PromptGuards::default()),
endpoints: Rc::new(None),
tracing: Rc::new(None),
agents: Rc::new(HashMap::new()),
tools: Rc::new(HashMap::new()),
}
}
}
impl Client for FilterContext {
type CallContext = FilterCallContext;
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
&self.callouts
}
fn active_http_calls(&self) -> &Gauge {
&self.metrics.active_http_calls
}
}
impl Context for FilterContext {}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool {
let config_bytes = self
.get_plugin_configuration()
.expect("Arch config cannot be empty");
let config: Configuration = match serde_yaml::from_slice(&config_bytes) {
Ok(config) => config,
Err(err) => panic!("Invalid arch config \"{:?}\"", err),
};
self.overrides = Rc::new(config.overrides);
let mut prompt_targets = HashMap::new();
for pt in config.prompt_targets.unwrap_or_default() {
prompt_targets.insert(pt.name.clone(), pt.clone());
}
self.system_prompt = Rc::new(config.system_prompt);
self.prompt_targets = Rc::new(prompt_targets);
self.endpoints = Rc::new(config.endpoints);
self.agents = Rc::new(config.agents);
self.tools = Rc::new(config.tools);
if let Some(prompt_guards) = config.prompt_guards {
self.prompt_guards = Rc::new(prompt_guards)
}
self.tracing = Rc::new(config.tracing);
true
}
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
trace!(
"||| create_http_context called with context_id: {:?} |||",
context_id
);
Some(Box::new(StreamContext::new(
context_id,
Rc::clone(&self.metrics),
Rc::clone(&self.endpoints),
Rc::clone(&self.overrides),
Rc::clone(&self.tracing),
Rc::clone(&self.agents),
Rc::clone(&self.tools),
)))
}
fn get_type(&self) -> Option<ContextType> {
Some(ContextType::HttpContext)
}
fn on_vm_start(&mut self, _: usize) -> bool {
true
}
}

View file

@ -0,0 +1,440 @@
use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext};
use common::{
api::open_ai::{self, ArchState, ChatCompletionTool, ChatCompletionsRequest, Message},
consts::{
ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_UPSTREAM_HOST_HEADER,
CHAT_COMPLETIONS_PATH, HEALTHZ_PATH, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER,
SYSTEM_ROLE, TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_STATE_HEADER,
},
errors::ServerError,
http::{CallArgs, Client},
pii::obfuscate_auth_header,
};
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::{traits::HttpContext, types::Action};
use serde_json::Value;
use std::{
collections::HashMap,
time::{Duration, SystemTime, UNIX_EPOCH},
};
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
impl HttpContext for StreamContext {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
// Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it.
// Server's generally throw away requests whose body length do not match the Content-Length header.
// However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
// manipulate the body in benign ways e.g., compression.
self.set_http_request_header("content-length", None);
if let Some(overrides) = self.overrides.as_ref() {
if overrides.use_agent_orchestrator.unwrap_or_default() {
// get endpoint that has agent_orchestrator set to true
if let Some(endpoints) = self.endpoints.as_ref() {
if endpoints.len() == 1 {
let (name, _) = endpoints.iter().next().unwrap();
info!("Setting ARCH_PROVIDER_HINT_HEADER to {}", name);
self.set_http_request_header(ARCH_ROUTING_HEADER, Some(name));
} else {
warn!("Need single endpoint when use_agent_orchestrator is set");
self.send_server_error(
ServerError::LogicError(
"Need single endpoint when use_agent_orchestrator is set"
.to_string(),
),
None,
);
}
}
}
}
let request_path = self.get_http_request_header(":path").unwrap_or_default();
if request_path == HEALTHZ_PATH {
self.send_http_response(200, vec![], None);
return Action::Continue;
}
self.is_chat_completions_request = CHAT_COMPLETIONS_PATH.contains(&request_path.as_str());
// check if agent name is in the request header
// if not, check if there is only one agent in the config
// if so, use that agent
// if there are multiple agents in the config, return an error
if let Some(agent_header_value) = self.get_http_request_header("x-agent-name") {
if let Some(agent) = self.agents.as_ref().get(&agent_header_value) {
self.agent = Some(agent.clone());
} else {
warn!("Agent not found in config");
self.send_server_error(
ServerError::LogicError(format!(
"Agent {} not found in config",
agent_header_value
)),
None,
);
return Action::Pause;
}
} else if self.agents.as_ref().len() == 1 {
let (name, agent) = self.agents.iter().next().unwrap();
info!("Setting agent to {}", name);
self.agent = Some(agent.clone());
} else {
warn!("Multiple agents found in config and no agent name in request header");
self.send_http_response(
400,
vec![],
Some(
"Multiple agents found in config and no agent name in request header"
.as_bytes(),
),
);
return Action::Pause;
}
debug!(
"on_http_request_headers S[{}] req_headers={:?}",
self.context_id,
obfuscate_auth_header(&mut self.get_http_request_headers())
);
self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER);
Action::Continue
}
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
// Let the client send the gateway all the data before sending to the LLM_provider.
// TODO: consider a streaming API.
if !end_of_stream {
return Action::Pause;
}
if body_size == 0 {
return Action::Continue;
}
self.request_body_size = body_size;
debug!(
"on_http_request_body S[{}] body_size={}",
self.context_id, body_size
);
let body_bytes = match self.get_http_request_body(0, body_size) {
Some(body_bytes) => body_bytes,
None => {
self.send_server_error(
ServerError::LogicError(format!(
"Failed to obtain body bytes even though body_size is {}",
body_size
)),
None,
);
return Action::Pause;
}
};
debug!("request body: {}", String::from_utf8_lossy(&body_bytes));
// Deserialize body into spec.
// Currently OpenAI API.
let deserialized_body: ChatCompletionsRequest = match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
self.send_server_error(
ServerError::Deserialization(e),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
};
self.arch_state = match deserialized_body.metadata {
Some(ref metadata) => {
if metadata.contains_key(X_ARCH_STATE_HEADER) {
let arch_state_str = metadata[X_ARCH_STATE_HEADER].clone();
let arch_state: Vec<ArchState> = serde_json::from_str(&arch_state_str).unwrap();
Some(arch_state)
} else {
None
}
}
None => None,
};
self.streaming_response = deserialized_body.stream;
let last_user_prompt: &open_ai::Message = match deserialized_body
.messages
.iter()
.filter(|msg| msg.role == USER_ROLE)
.last()
{
Some(content) => content,
None => {
warn!("No messages in the request body");
return Action::Continue;
}
};
self.user_prompt = Some(last_user_prompt.clone());
let mut tool_calls = Vec::new();
if let Some(agent) = self.agent.as_ref() {
if let Some(tools) = agent.tools.as_ref() {
for tool in tools {
if let Some(tool) = self.tools.as_ref().get(tool) {
info!("tool: {:?}", tool);
let tool_chat_completion_tool: ChatCompletionTool = tool.into();
info!("tool_chat_completion_tool: {:?}", tool_chat_completion_tool);
tool_calls.push(tool_chat_completion_tool);
}
}
}
}
let mut metadata = deserialized_body.metadata.clone();
if let Some(overrides) = self.overrides.as_ref() {
if overrides.optimize_context_window.unwrap_or_default() {
if metadata.is_none() {
metadata = Some(HashMap::new());
}
metadata
.as_mut()
.unwrap()
.insert("optimize_context_window".to_string(), "true".to_string());
}
}
let messages: Vec<Message> = match self.agent.as_ref().unwrap().agent_orchestrator_prompt {
Some(ref agent_orchestrator_prompt) => {
let mut messages = Vec::new();
messages.push(Message {
role: SYSTEM_ROLE.to_string(),
content: Some(agent_orchestrator_prompt.clone()),
model: None,
tool_calls: None,
tool_call_id: None,
});
messages.extend(deserialized_body.messages.clone());
messages
}
None => deserialized_body.messages.clone(),
};
let arch_fc_chat_completion_request = ChatCompletionsRequest {
messages,
metadata,
//HACK: adilhafeez: enable streaming for agent orchestrator
stream: false,
model: deserialized_body.model.clone(),
stream_options: deserialized_body.stream_options.clone(),
tools: Some(tool_calls),
};
self.chat_completions_request = Some(deserialized_body);
let json_data = match serde_json::to_string(&arch_fc_chat_completion_request) {
Ok(json_data) => json_data,
Err(error) => {
self.send_server_error(ServerError::Serialization(error), None);
return Action::Pause;
}
};
info!("on_http_request_body: sending request to model server");
debug!("request body: {}", json_data);
let timeout_str = MODEL_SERVER_REQUEST_TIMEOUT_MS.to_string();
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, "openai"),
(":method", "POST"),
(":path", "/v1/chat/completions"),
("content-type", "application/json"),
(":authority", "openai"),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
"arch_listener_llm",
"/v1/chat/completions",
headers,
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
);
let call_context = StreamCallContext {
response_handler_type: ResponseHandlerType::ArchFC,
user_message: self.user_prompt.as_ref().unwrap().content.clone(),
prompt_target_name: None,
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
similarity_scores: None,
upstream_cluster: Some(ARCH_INTERNAL_CLUSTER_NAME.to_string()),
upstream_cluster_path: Some("/function_calling".to_string()),
agent: self.agent.clone(),
};
if let Err(e) = self.http_call(call_args, call_context) {
warn!("http_call failed: {:?}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
Action::Pause
}
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
debug!(
"on_http_response_headers recv [S={}] headers={:?}",
self.context_id,
self.get_http_response_headers()
);
// delete content-lenght header let envoy calculate it, because we modify the response body
// that would result in a different content-length
self.set_http_response_header("content-length", None);
Action::Continue
}
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
debug!(
"on_http_response_body: recv [S={}] bytes={} end_stream={}",
self.context_id, body_size, end_of_stream
);
if !self.is_chat_completions_request {
info!("non-gpt request");
return Action::Continue;
}
if self.time_to_first_token.is_none() {
self.time_to_first_token = Some(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos(),
);
}
if end_of_stream && body_size == 0 {
return Action::Continue;
}
let body = if self.streaming_response {
let streaming_chunk = match self.get_http_response_body(0, body_size) {
Some(chunk) => chunk,
None => {
warn!(
"response body empty, chunk_start: {}, chunk_size: {}",
0, body_size
);
return Action::Continue;
}
};
if streaming_chunk.len() != body_size {
warn!(
"chunk size mismatch: read: {} != requested: {}",
streaming_chunk.len(),
body_size
);
}
streaming_chunk
} else {
info!("non streaming response bytes read: 0:{}", body_size);
match self.get_http_response_body(0, body_size) {
Some(body) => body,
None => {
warn!("non streaming response body empty");
return Action::Continue;
}
}
};
let body_utf8 = match String::from_utf8(body) {
Ok(body_utf8) => body_utf8,
Err(e) => {
info!("could not convert to utf8: {}", e);
return Action::Continue;
}
};
if self.streaming_response {
debug!("streaming response");
if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() {
let chunks = vec![
// ChatCompletionStreamResponse::new(
// self.arch_fc_response.clone(),
// Some(ASSISTANT_ROLE.to_string()),
// Some(ARCH_FC_MODEL_NAME.to_string()),
// None,
// ),
// ChatCompletionStreamResponse::new(
// self.tool_call_response.clone(),
// Some(TOOL_ROLE.to_string()),
// Some(ARCH_FC_MODEL_NAME.to_string()),
// None,
// ),
];
let mut response_str = open_ai::to_server_events(chunks);
// append the original response from the model to the stream
response_str.push_str(&body_utf8);
self.set_http_response_body(0, body_size, response_str.as_bytes());
self.tool_calls = None;
}
} else if let Some(tool_calls) = self.tool_calls.as_ref() {
if !tool_calls.is_empty() {
if self.arch_state.is_none() {
self.arch_state = Some(Vec::new());
}
let mut data = match serde_json::from_str(&body_utf8) {
Ok(data) => data,
Err(e) => {
warn!(
"could not deserialize response, sending data as it is: {}",
e
);
return Action::Continue;
}
};
// use serde::Value to manipulate the json object and ensure that we don't lose any data
if let Value::Object(ref mut map) = data {
// serialize arch state and add to metadata
let metadata = map
.entry("metadata")
.or_insert(Value::Object(serde_json::Map::new()));
if metadata == &Value::Null {
*metadata = Value::Object(serde_json::Map::new());
}
let data_serialized = serde_json::to_string(&data).unwrap();
info!("archgw <= developer: {}", data_serialized);
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
};
}
}
debug!("recv [S={}] end_stream={}", self.context_id, end_of_stream);
Action::Continue
}
}

View file

@ -0,0 +1,17 @@
use filter_context::FilterContext;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod context;
mod filter_context;
mod http_context;
mod metrics;
mod stream_context;
mod tools;
proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(FilterContext::new())
});
}}

View file

@ -0,0 +1,14 @@
use common::stats::Gauge;
#[derive(Copy, Clone, Debug)]
pub struct Metrics {
pub active_http_calls: Gauge,
}
impl Metrics {
pub fn new() -> Metrics {
Metrics {
active_http_calls: Gauge::new(String::from("active_http_calls")),
}
}
}

View file

@ -0,0 +1,528 @@
use crate::metrics::Metrics;
use crate::tools::compute_request_path_body;
use common::api::open_ai::{
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
ChatCompletionsResponse, Message, ToolCall,
};
use common::configuration::{Agent, Endpoint, Overrides, Tool, Tracing};
use common::consts::{
API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE,
TRACE_PARENT_HEADER, USER_ROLE,
};
use common::errors::ServerError;
use common::http::{CallArgs, Client};
use common::stats::Gauge;
use derivative::Derivative;
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::traits::*;
use serde_yaml::Value;
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use std::str::FromStr;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
#[derive(Debug, Clone)]
pub enum ResponseHandlerType {
ArchFC,
FunctionCall,
}
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub struct StreamCallContext {
pub response_handler_type: ResponseHandlerType,
pub user_message: Option<String>,
pub prompt_target_name: Option<String>,
#[derivative(Debug = "ignore")]
pub request_body: ChatCompletionsRequest,
pub similarity_scores: Option<Vec<(String, f64)>>,
pub upstream_cluster: Option<String>,
pub upstream_cluster_path: Option<String>,
pub agent: Option<Agent>,
}
pub struct StreamContext {
pub endpoints: Rc<Option<HashMap<String, Endpoint>>>,
pub overrides: Rc<Option<Overrides>>,
pub metrics: Rc<Metrics>,
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
pub context_id: u32,
pub tool_calls: Option<Vec<ToolCall>>,
pub tool_call_response: Option<String>,
pub arch_state: Option<Vec<ArchState>>,
pub request_body_size: usize,
pub user_prompt: Option<Message>,
pub streaming_response: bool,
pub is_chat_completions_request: bool,
pub chat_completions_request: Option<ChatCompletionsRequest>,
pub request_id: Option<String>,
pub start_upstream_llm_request_time: u128,
pub time_to_first_token: Option<u128>,
pub traceparent: Option<String>,
pub agents: Rc<HashMap<String, Agent>>,
pub agent: Option<Agent>,
pub tools: Rc<HashMap<String, Tool>>,
pub _tracing: Rc<Option<Tracing>>,
pub arch_fc_response: Option<String>,
}
impl StreamContext {
pub fn new(
context_id: u32,
metrics: Rc<Metrics>,
endpoints: Rc<Option<HashMap<String, Endpoint>>>,
overrides: Rc<Option<Overrides>>,
tracing: Rc<Option<Tracing>>,
agents: Rc<HashMap<String, Agent>>,
tools: Rc<HashMap<String, Tool>>,
) -> Self {
StreamContext {
context_id,
metrics,
endpoints,
callouts: RefCell::new(HashMap::new()),
chat_completions_request: None,
tool_calls: None,
tool_call_response: None,
arch_state: None,
request_body_size: 0,
streaming_response: false,
user_prompt: None,
is_chat_completions_request: false,
overrides,
request_id: None,
traceparent: None,
_tracing: tracing,
start_upstream_llm_request_time: 0,
time_to_first_token: None,
arch_fc_response: None,
agents,
tools,
agent: None,
}
}
pub fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
self.send_http_response(
override_status_code
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
.as_u16()
.into(),
vec![],
Some(format!("{error}").as_bytes()),
);
}
fn _trace_arch_internal(&self) -> bool {
match self._tracing.as_ref() {
Some(tracing) => match tracing.trace_arch_internal.as_ref() {
Some(trace_arch_internal) => *trace_arch_internal,
None => false,
},
None => false,
}
}
pub fn arch_fc_response_handler(
&mut self,
body: Vec<u8>,
mut callout_context: StreamCallContext,
) {
let body_str = String::from_utf8(body).unwrap();
info!("on_http_call_response: model server response received");
debug!("response body: {}", body_str);
let model_server_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
Ok(arch_fc_response) => arch_fc_response,
Err(e) => {
warn!(
"error deserializing llm response: {}, body: {}",
e, body_str
);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
//TODO: try to avoid clone
let message = model_server_response
.choices
.first()
.map(|choice| choice.message.clone())
.unwrap();
self.tool_calls = message.tool_calls;
if self.tool_calls.as_ref().is_some() && self.tool_calls.as_ref().unwrap().len() > 1 {
warn!(
"multiple tool calls not supported yet, tool_calls count found: {}",
self.tool_calls.as_ref().unwrap().len()
);
}
if self.tool_calls.is_none() || self.tool_calls.as_ref().unwrap().is_empty() {
// this means llm model didn't need additional data from tool calls and is ready to respond back to user
let direct_response_str = if self.streaming_response {
let chunks = vec![
ChatCompletionStreamResponse::new(
self.arch_fc_response.clone(),
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_string()),
None,
),
ChatCompletionStreamResponse::new(
Some(
model_server_response.choices[0]
.message
.content
.as_ref()
.unwrap()
.clone(),
),
None,
Some(format!("{}-Chat", ARCH_FC_MODEL_NAME.to_owned())),
None,
),
];
to_server_events(chunks)
} else {
body_str
};
self.tool_calls = None;
return self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![],
Some(direct_response_str.as_bytes()),
);
}
// update prompt target name from the tool call response
callout_context.prompt_target_name =
Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone());
self.schedule_api_call_request(callout_context);
}
fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
// Construct messages early to avoid mutable borrow conflicts
let tool_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
let tool = self.tools.get(&tool_name).unwrap().clone();
let tool_params = self.tool_calls.as_ref().unwrap()[0]
.function
.arguments
.clone();
let endpoint_details = tool.endpoint.as_ref().unwrap();
let endpoint_path: String = endpoint_details
.path
.as_ref()
.unwrap_or(&String::from("/"))
.to_string();
let http_method = endpoint_details.method.clone().unwrap_or_default();
let prompt_target_params = tool.parameters.clone().unwrap_or_default();
let mut tool_params_json: Option<HashMap<String, Value>> = None;
if let Some(params) = tool_params.as_ref() {
match serde_json::from_str::<HashMap<String, Value>>(params.as_str()) {
Ok(params_json) => tool_params_json = Some(params_json),
Err(e) => {
log::warn!(
"error deserializing tool params: {}, body str: {}",
e,
String::from_utf8(params.as_bytes().to_vec()).unwrap()
);
return self.send_server_error(
ServerError::Deserialization(e),
Some(StatusCode::BAD_REQUEST),
);
}
};
}
//TODO: fixme hack adilhafeez
let (path, api_call_body) = match compute_request_path_body(
&endpoint_path,
&tool_params_json,
&prompt_target_params,
&http_method,
) {
Ok((path, body)) => (path, body),
Err(e) => {
return self.send_server_error(
ServerError::BadRequest {
why: format!("error computing api request path or body: {}", e),
},
Some(StatusCode::BAD_REQUEST),
);
}
};
debug!("on_http_call_response: api call body {:?}", api_call_body);
let timeout_str = API_REQUEST_TIMEOUT_MS.to_string();
let http_method_str = http_method.to_string();
let mut headers: HashMap<_, _> = [
(ARCH_UPSTREAM_HOST_HEADER, endpoint_details.name.as_str()),
(":method", &http_method_str),
(":path", &path),
(":authority", endpoint_details.name.as_str()),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
]
.into_iter()
.collect();
if self.request_id.is_some() {
headers.insert(REQUEST_ID_HEADER, self.request_id.as_ref().unwrap());
}
if self.traceparent.is_some() {
headers.insert(TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap());
}
// override http headers that are set in the prompt target
let http_headers = endpoint_details.http_headers.clone().unwrap_or_default();
for (key, value) in http_headers.iter() {
headers.insert(key.as_str(), value.as_str());
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&path,
headers.into_iter().collect(),
api_call_body.as_deref().map(|s| s.as_bytes()),
vec![],
Duration::from_secs(5),
);
info!(
"on_http_call_response: dispatching api call to developer endpoint: {}, path: {}, method: {}",
endpoint_details.name, path, http_method_str
);
callout_context.upstream_cluster = Some(endpoint_details.name.to_owned());
callout_context.upstream_cluster_path = Some(path.to_owned());
callout_context.agent = self.agent.clone();
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
if let Err(e) = self.http_call(call_args, callout_context) {
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
}
}
pub fn api_call_response_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
let http_status = self
.get_http_call_response_header(":status")
.unwrap_or(StatusCode::OK.as_str().to_string());
info!(
"on_http_call_response: developer api call response received: status code: {}",
http_status
);
if http_status != StatusCode::OK.as_str() {
warn!(
"api server responded with non 2xx status code: {}",
http_status
);
return self.send_server_error(
ServerError::Upstream {
host: callout_context.upstream_cluster.unwrap(),
path: callout_context.upstream_cluster_path.unwrap(),
status: http_status.clone(),
body: String::from_utf8(body).unwrap(),
},
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
);
}
self.tool_call_response = Some(String::from_utf8(body).unwrap());
debug!(
"response body: {}",
self.tool_call_response.as_ref().unwrap()
);
let mut messages = self.construct_llm_messages(&callout_context);
let user_message = match messages.pop() {
Some(user_message) => user_message,
None => {
return self.send_server_error(
ServerError::NoMessagesFound {
why: "no user messages found".to_string(),
},
None,
);
}
};
let final_prompt = format!(
"{}\ncontext: {}",
user_message.content.unwrap(),
self.tool_call_response.as_ref().unwrap()
);
// add original user prompt
messages.push({
Message {
role: USER_ROLE.to_string(),
content: Some(final_prompt),
model: None,
tool_calls: None,
tool_call_id: None,
}
});
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: callout_context.request_body.model,
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
Ok(json_string) => json_string,
Err(e) => {
return self.send_server_error(ServerError::Serialization(e), None);
}
};
info!("on_http_call_response: sending request to upstream llm");
debug!("request body: {}", llm_request_str);
self.start_upstream_llm_request_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
self.set_http_request_body(0, self.request_body_size, &llm_request_str.into_bytes());
self.resume_http_request();
}
fn filter_out_arch_messages(&self, messages: &[Message]) -> Vec<Message> {
messages
.iter()
.filter(|m| {
!(m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty()))
})
.cloned()
.collect()
}
fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
let mut messages: Vec<Message> = Vec::new();
if let Some(agent) = callout_context.agent.as_ref() {
if let Some(system_prompt) = agent.system_prompt.as_ref() {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: Some(system_prompt.clone()),
model: None,
tool_calls: None,
tool_call_id: None,
};
messages.push(system_prompt_message);
}
}
messages.append(
&mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
);
messages
}
}
impl Client for StreamContext {
type CallContext = StreamCallContext;
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
&self.callouts
}
fn active_http_calls(&self) -> &Gauge {
&self.metrics.active_http_calls
}
}
#[cfg(test)]
mod test {
use common::api::open_ai::{ChatCompletionsResponse, Choice, Message, ToolCall};
use crate::stream_context::check_intent_matched;
#[test]
fn test_intent_matched() {
let model_server_response = ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
content: Some("".to_string()),
tool_calls: Some(vec![]),
role: "assistant".to_string(),
model: None,
tool_call_id: None,
},
finish_reason: None,
index: None,
}],
usage: None,
model: "arch-fc".to_string(),
metadata: None,
};
assert!(!check_intent_matched(&model_server_response));
let model_server_response = ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
content: Some("hello".to_string()),
tool_calls: Some(vec![]),
role: "assistant".to_string(),
model: None,
tool_call_id: None,
},
finish_reason: None,
index: None,
}],
usage: None,
model: "arch-fc".to_string(),
metadata: None,
};
assert!(check_intent_matched(&model_server_response));
let model_server_response = ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
content: Some("".to_string()),
tool_calls: Some(vec![ToolCall {
id: "1".to_string(),
function: common::api::open_ai::FunctionCallDetail {
name: "test".to_string(),
arguments: None,
},
tool_type: common::api::open_ai::ToolType::Function,
}]),
role: "assistant".to_string(),
model: None,
tool_call_id: None,
},
finish_reason: None,
index: None,
}],
usage: None,
model: "arch-fc".to_string(),
metadata: None,
};
assert!(check_intent_matched(&model_server_response));
}
}

View file

@ -0,0 +1,162 @@
use common::configuration::{HttpMethod, Parameter};
use std::collections::HashMap;
use serde_yaml::Value;
// only add params that are of string, number and bool type
pub fn filter_tool_params(tool_params: &Option<HashMap<String, Value>>) -> HashMap<String, String> {
if tool_params.is_none() {
return HashMap::new();
}
tool_params
.as_ref()
.unwrap()
.iter()
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
.map(|(key, value)| match value {
Value::Number(n) => (key.clone(), n.to_string()),
Value::String(s) => (key.clone(), s.clone()),
Value::Bool(b) => (key.clone(), b.to_string()),
Value::Null => todo!(),
Value::Sequence(_) => todo!(),
Value::Mapping(_) => todo!(),
Value::Tagged(_) => todo!(),
})
.collect::<HashMap<String, String>>()
}
pub fn compute_request_path_body(
endpoint_path: &str,
tool_params: &Option<HashMap<String, Value>>,
prompt_target_params: &[Parameter],
http_method: &HttpMethod,
) -> Result<(String, Option<String>), String> {
let tool_url_params = filter_tool_params(tool_params);
let (path_with_params, query_string, additional_params) = common::path::replace_params_in_path(
endpoint_path,
&tool_url_params,
prompt_target_params,
)?;
let (path, body) = match http_method {
HttpMethod::Get => (format!("{}?{}", path_with_params, query_string), None),
HttpMethod::Post => {
let mut additional_params = additional_params;
if !query_string.is_empty() {
query_string.split("&").for_each(|param| {
let mut parts = param.split("=");
let key = parts.next().unwrap();
let value = parts.next().unwrap();
additional_params.insert(key.to_string(), value.to_string());
});
}
let body = serde_json::to_string(&additional_params).unwrap();
(path_with_params, Some(body))
}
};
Ok((path, body))
}
#[cfg(test)]
mod test {
use common::configuration::{HttpMethod, Parameter};
#[test]
fn test_compute_request_path_body() {
let endpoint_path = "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}";
let tool_params = serde_yaml::from_str(
r#"
cluster_name: test1
hello: hello world
"#,
)
.unwrap();
let prompt_target_params = vec![Parameter {
name: "country".to_string(),
parameter_type: None,
description: "test target".to_string(),
required: None,
enum_values: None,
default: Some("US".to_string()),
in_path: None,
format: None,
}];
let http_method = HttpMethod::Get;
let (path, body) = super::compute_request_path_body(
endpoint_path,
&tool_params,
&prompt_target_params,
&http_method,
)
.unwrap();
assert_eq!(
path,
"/cluster.open-cluster-management.io/v1/managedclusters/test1?hello=hello%20world&country=US"
);
assert_eq!(body, None);
}
#[test]
fn test_compute_request_path_body_empty_params() {
let endpoint_path = "/cluster.open-cluster-management.io/v1/managedclusters/";
let tool_params = serde_yaml::from_str(r#"{}"#).unwrap();
let prompt_target_params = vec![Parameter {
name: "country".to_string(),
parameter_type: None,
description: "test target".to_string(),
required: None,
enum_values: None,
default: Some("US".to_string()),
in_path: None,
format: None,
}];
let http_method = HttpMethod::Get;
let (path, body) = super::compute_request_path_body(
endpoint_path,
&tool_params,
&prompt_target_params,
&http_method,
)
.unwrap();
assert_eq!(
path,
"/cluster.open-cluster-management.io/v1/managedclusters/?country=US"
);
assert_eq!(body, None);
}
#[test]
fn test_compute_request_path_body_override_default_val() {
let endpoint_path = "/cluster.open-cluster-management.io/v1/managedclusters/";
let tool_params = serde_yaml::from_str(
r#"
country: UK
"#,
)
.unwrap();
let prompt_target_params = vec![Parameter {
name: "country".to_string(),
parameter_type: None,
description: "test target".to_string(),
required: None,
enum_values: None,
default: Some("US".to_string()),
in_path: None,
format: None,
}];
let http_method = HttpMethod::Get;
let (path, body) = super::compute_request_path_body(
endpoint_path,
&tool_params,
&prompt_target_params,
&http_method,
)
.unwrap();
assert_eq!(
path,
"/cluster.open-cluster-management.io/v1/managedclusters/?country=UK"
);
assert_eq!(body, None);
}
}

View file

@ -0,0 +1,690 @@
use common::api::open_ai::{
ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage,
};
use common::configuration::Configuration;
use http::StatusCode;
use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
};
use serde_yaml::Value;
use serial_test::serial;
use std::collections::HashMap;
use std::path::Path;
fn wasm_module() -> String {
let wasm_file = Path::new("../target/wasm32-wasip1/release/prompt_gateway.wasm");
assert!(
wasm_file.exists(),
"Run `cargo build --release --target=wasm32-wasip1` first"
);
wasm_file.to_str().unwrap().to_string()
}
fn request_headers_expectations(module: &mut Tester, http_context: i32) {
module
.call_proxy_on_request_headers(http_context, 0, false)
.expect_log(Some(LogLevel::Debug), None)
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
.returning(Some("/v1/chat/completions"))
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent"))
.returning(None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
request_headers_expectations(module, http_context);
// Request Body
let chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
],\
\"model\": \"gpt-4\"\
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "model_server"),
(":method", "POST"),
(":path", "/function_calling"),
("content-type", "application/json"),
(":authority", "model_server"),
("x-envoy-upstream-rq-timeout-ms", "30000"),
]),
None,
None,
Some(5000),
)
.returning(Some(1))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
fn setup_filter(module: &mut Tester, config: &str) -> i32 {
let filter_context = 1;
module
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
filter_context
}
fn default_config() -> &'static str {
r#"
version: "0.1-beta"
listener:
address: 0.0.0.0
port: 10000
message_format: huggingface
connect_timeout: 0.005s
endpoints:
api_server:
endpoint: api_server:80
connect_timeout: 0.005s
llm_providers:
- name: open-ai-gpt-4
provider_interface: openai
access_key: secret_key
model: gpt-4
default: true
overrides:
# confidence threshold for prompt target intent matching
prompt_target_intent_matching_threshold: 0.0
system_prompt: |
You are a helpful assistant.
prompt_guards:
input_guards:
jailbreak:
on_exception:
message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
prompt_targets:
- name: weather_forecast
description: This function provides realtime weather forecast information for a given city.
parameters:
- name: city
required: true
description: The city for which the weather forecast is requested.
- name: days
description: The number of days for which the weather forecast is requested.
- name: units
description: The units in which the weather forecast is requested.
endpoint:
name: api_server
path: /weather
http_method: POST
system_prompt: |
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
ratelimits:
- model: gpt-4
selector:
key: selector-key
value: selector-value
limit:
tokens: 1
unit: minute
"#
}
#[test]
#[serial]
fn prompt_gateway_successful_request_to_open_ai_chat_completions() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
request_headers_expectations(&mut module, http_context);
// Request Body
let chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
],\
\"model\": \"gpt-4\"\
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("arch_internal"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
#[test]
#[serial]
fn prompt_gateway_bad_request_to_open_ai_chat_completions() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
request_headers_expectations(&mut module, http_context);
// Request Body
let incomplete_chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
]\
}";
module
.call_proxy_on_request_body(
http_context,
incomplete_chat_completions_request_body.len() as i32,
true,
)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(incomplete_chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response(
Some(StatusCode::BAD_REQUEST.as_u16().into()),
None,
None,
None,
)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
#[test]
#[serial]
fn prompt_gateway_request_to_llm_gateway() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap();
config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
let config_str = serde_json::to_string(&config).unwrap();
let filter_context = setup_filter(&mut module, &config_str);
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
let arch_fc_resp = ChatCompletionsResponse {
usage: Some(Usage {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: Some("test".to_string()),
index: Some(0),
message: Message {
role: "system".to_string(),
content: None,
tool_calls: Some(vec![ToolCall {
id: String::from("test"),
tool_type: ToolType::Function,
function: FunctionCallDetail {
name: String::from("weather_forecast"),
arguments: Some(HashMap::from([(
String::from("city"),
Value::String(String::from("seattle")),
)])),
},
}]),
model: None,
tool_call_id: None,
},
}],
model: String::from("test"),
metadata: {
let mut map: HashMap<String, String> = HashMap::new();
map.insert("function_latency".to_string(), "0.0".to_string());
Some(map)
},
};
let expected_body = "{\"city\":\"seattle\"}";
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
module
.call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-envoy-max-retries", "3"),
("x-arch-upstream", "api_server"),
("content-type", "application/json"),
("x-envoy-upstream-rq-timeout-ms", "30000"),
(":path", "/weather"),
(":method", "POST"),
(":authority", "api_server"),
]),
Some(expected_body),
None,
Some(5000),
)
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 2, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
.returning(Some("200"))
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
let chat_completion_response = ChatCompletionsResponse {
usage: Some(Usage {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: Some("test".to_string()),
index: Some(0),
message: Message {
role: "assistant".to_string(),
content: Some("hello from fake llm gateway".to_string()),
model: None,
tool_calls: None,
tool_call_id: None,
},
}],
model: String::from("test"),
metadata: None,
};
let chat_completion_response_str = serde_json::to_string(&chat_completion_response).unwrap();
module
.call_proxy_on_response_body(
http_context,
chat_completion_response_str.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpResponseBody))
.returning(Some(chat_completion_response_str.as_str()))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
#[test]
#[serial]
fn prompt_gateway_request_no_intent_match() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap();
config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
let config_str = serde_json::to_string(&config).unwrap();
let filter_context = setup_filter(&mut module, &config_str);
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
let arch_fc_resp = ChatCompletionsResponse {
usage: Some(Usage {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: Some("test".to_string()),
index: Some(0),
message: Message {
role: "assistant".to_string(),
content: None,
tool_calls: None,
model: None,
tool_call_id: None,
},
}],
model: String::from("test"),
metadata: None,
};
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
module
.call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), Some("intent matched: false"))
.expect_log(
Some(LogLevel::Info),
Some("no default prompt target found, forwarding request to upstream llm"),
)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Info), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::None)
.unwrap();
}
fn arch_config_default_target() -> &'static str {
r#"
version: "0.1-beta"
listener:
address: 0.0.0.0
port: 10000
message_format: huggingface
connect_timeout: 0.005s
endpoints:
api_server:
endpoint: api_server:80
connect_timeout: 0.005s
llm_providers:
- name: open-ai-gpt-4
provider_interface: openai
access_key: secret_key
model: gpt-4
default: true
overrides:
# confidence threshold for prompt target intent matching
prompt_target_intent_matching_threshold: 0.0
system_prompt: |
You are a helpful assistant.
prompt_guards:
input_guards:
jailbreak:
on_exception:
message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
prompt_targets:
- name: weather_forecast
description: This function provides realtime weather forecast information for a given city.
parameters:
- name: city
required: true
description: The city for which the weather forecast is requested.
- name: days
description: The number of days for which the weather forecast is requested.
- name: units
description: The units in which the weather forecast is requested.
endpoint:
name: api_server
path: /weather
http_method: POST
system_prompt: |
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
- name: default_target
default: true
description: This is the default target for all unmatched prompts.
endpoint:
name: weather_forecast_service
path: /default_target
http_method: POST
system_prompt: |
You are a helpful assistant! Summarize the user's request and provide a helpful response.
# if it is set to false arch will send response that it received from this prompt target to the user
# if true arch will forward the response to the default LLM
auto_llm_dispatch_on_response: false
ratelimits:
- model: gpt-4
selector:
key: selector-key
value: selector-value
limit:
tokens: 1
unit: minute
"#
}
#[test]
#[serial]
fn prompt_gateway_request_no_intent_match_default_target() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let mut config: Configuration = serde_yaml::from_str(arch_config_default_target()).unwrap();
config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
let config_str = serde_json::to_string(&config).unwrap();
let filter_context = setup_filter(&mut module, &config_str);
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
let arch_fc_resp = ChatCompletionsResponse {
usage: Some(Usage {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: Some("test".to_string()),
index: Some(0),
message: Message {
role: "system".to_string(),
content: None,
tool_calls: None,
model: None,
tool_call_id: None,
},
}],
model: String::from("test"),
metadata: None,
};
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
module
.call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), Some("intent matched: false"))
.expect_log(
Some(LogLevel::Info),
Some("default prompt target found, forwarding request to default prompt target"),
)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
(":method", "POST"),
("x-arch-upstream", "weather_forecast_service"),
(":path", "/default_target"),
(":authority", "weather_forecast_service"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "30000"),
]),
None,
None,
Some(5000),
)
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
}

View file

@ -1,6 +1,5 @@
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use serde_yaml::Value;
use std::{
collections::{HashMap, VecDeque},
fmt::Display,
@ -43,6 +42,8 @@ pub struct FunctionDefinition {
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionParameters {
#[serde(rename = "type")]
pub properties_type: String,
pub properties: HashMap<String, FunctionParameter>,
}
@ -51,7 +52,7 @@ impl Serialize for FunctionParameters {
where
S: serde::Serializer,
{
// select all requried parameters
// select all required parameters
let required: Vec<&String> = self
.properties
.iter()
@ -60,6 +61,7 @@ impl Serialize for FunctionParameters {
.collect();
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("properties", &self.properties)?;
map.serialize_entry("type", &self.properties_type)?;
if !required.is_empty() {
map.serialize_entry("required", &required)?;
}
@ -113,7 +115,7 @@ pub enum ParameterType {
Float,
#[serde(rename = "bool")]
Bool,
#[serde(rename = "str")]
#[serde(rename = "string")]
String,
#[serde(rename = "list")]
List,
@ -189,7 +191,7 @@ pub struct ToolCall {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDetail {
pub name: String,
pub arguments: Option<HashMap<String, Value>>,
pub arguments: Option<String>,
}
#[derive(Debug, Deserialize, Serialize)]

View file

@ -19,6 +19,37 @@ pub struct Configuration {
pub ratelimits: Option<Vec<Ratelimit>>,
pub tracing: Option<Tracing>,
pub mode: Option<GatewayMode>,
pub agents: HashMap<String, Agent>,
pub tools: HashMap<String, Tool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Agent {
pub name: String,
pub description: String,
pub default_input_modes: Option<Vec<String>>,
pub default_output_modes: Option<Vec<String>>,
pub skills: Option<Vec<Skill>>,
pub model: String,
pub agent_orchestrator_prompt: Option<String>,
pub system_prompt: Option<String>,
pub tools: Option<Vec<String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Skill {
pub id: String,
pub name: String,
pub description: String,
pub examples: Option<Vec<String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
pub name: String,
pub description: String,
pub endpoint: Option<EndpointDetails>,
pub parameters: Option<Vec<Parameter>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
@ -260,7 +291,48 @@ impl From<&PromptTarget> for ChatCompletionTool {
function: FunctionDefinition {
name: val.name.clone(),
description: val.description.clone(),
parameters: FunctionParameters { properties },
parameters: FunctionParameters {
properties,
properties_type: "object".to_string(),
},
},
}
}
}
// convert Tool to ChatCompletionTool
impl From<&Tool> for ChatCompletionTool {
fn from(val: &Tool) -> Self {
let properties: HashMap<String, FunctionParameter> = match val.parameters {
Some(ref entities) => {
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
for entity in entities.iter() {
let param = FunctionParameter {
parameter_type: ParameterType::from(
entity.parameter_type.clone().unwrap_or("str".to_string()),
),
description: entity.description.clone(),
required: entity.required,
enum_values: entity.enum_values.clone(),
default: entity.default.clone(),
format: entity.format.clone(),
};
properties.insert(entity.name.clone(), param);
}
properties
}
None => HashMap::new(),
};
ChatCompletionTool {
tool_type: crate::api::open_ai::ToolType::Function,
function: FunctionDefinition {
name: val.name.clone(),
description: val.description.clone(),
parameters: FunctionParameters {
properties,
properties_type: "object".to_string(),
},
},
}
}

View file

@ -11,7 +11,8 @@ pub const MODEL_SERVER_NAME: &str = "model_server";
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const MESSAGES_KEY: &str = "messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
pub const CHAT_COMPLETIONS_PATH: [&str; 2] = ["/v1/chat/completions", "/openai/v1/chat/completions"];
pub const CHAT_COMPLETIONS_PATH: [&str; 2] =
["/v1/chat/completions", "/openai/v1/chat/completions"];
pub const HEALTHZ_PATH: &str = "/healthz";
pub const X_ARCH_STATE_HEADER: &str = "x-arch-state";
pub const X_ARCH_API_RESPONSE: &str = "x-arch-api-response-message";

View file

@ -371,7 +371,6 @@ impl StreamContext {
let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
let tool_params = &self.tool_calls.as_ref().unwrap()[0].function.arguments;
let endpoint_details = prompt_target.endpoint.as_ref().unwrap();
let endpoint_path: String = endpoint_details
.path
@ -382,9 +381,10 @@ impl StreamContext {
let http_method = endpoint_details.method.clone().unwrap_or_default();
let prompt_target_params = prompt_target.parameters.clone().unwrap_or_default();
//TODO: fixme: adilhafeez hack
let (path, api_call_body) = match compute_request_path_body(
&endpoint_path,
tool_params,
&None,
&prompt_target_params,
&http_method,
) {
@ -777,18 +777,19 @@ impl StreamContext {
fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool {
let content = model_server_response
.choices.first()
.choices
.first()
.and_then(|choice| choice.message.content.as_ref());
let content_has_value = content.is_some() && !content.unwrap().is_empty();
let tool_calls = model_server_response
.choices.first()
.choices
.first()
.and_then(|choice| choice.message.tool_calls.as_ref());
// intent was matched if content has some value or tool_calls is empty
content_has_value || (tool_calls.is_some() && !tool_calls.unwrap().is_empty())
}

View file

@ -0,0 +1,76 @@
version: v0.1
listeners:
ingress_traffic:
address: 0.0.0.0
port: 10000
message_format: openai
timeout: 30s
llm_providers:
- name: gpt-4o
access_key: $OPENAI_API_KEY
provider_interface: openai
model: gpt-4o
endpoints:
frankfurther_api:
endpoint: api.frankfurter.dev
protocol: https
tools:
- name: get_exchange_rate
endpoint:
name: frankfurther_api
path: /v1/latest?base={currency_from}&symbols={currency_to}
params:
- name: currency_from
description: currency symbol to convert from
type: str
default: USD
- name: currency_to
description: currency symbol to convert to
type: str
default: EUR
prompt_guards:
input_guards:
jailbreak:
on_exception:
message: Looks like you're curious about my abilities, but I can only provide assistance for currency exchange.
agents:
- name: Currency Exchange Agent
description: Helps with exchange rates for currencies
default_input_modes:
- text
- text/plain
default_output_modes:
- text
- text/plain
skills:
- id: convert_currency
name: Currency Exchange Rates Tool
description: Helps with exchange values between various currencies
examples:
- What is exchange rate between USD and GBP?
capabilities:
streaming: true
push_notifications: true
model: gpt-4o
system_prompt: |
You are a specialized assistant for currency conversions.
Your sole purpose is to use the 'get_exchange_rate' tool to answer questions about currency exchange rates.
If the user asks about anything other than currency conversion or exchange rates,
politely state that you cannot help with that topic and can only assist with currency-related queries.
Do not attempt to answer unrelated questions or use tools for other purposes.
Set response status to input_required if the user needs to provide more information.
Set response status to error if there is an error while processing the request.
Set response status to completed if the request is complete.
tools:
- get_exchange_rate
tracing:
random_sampling: 100
trace_arch_internal: true

View file

@ -0,0 +1,84 @@
version: v0.1
listeners:
ingress_traffic:
address: 0.0.0.0
port: 10000
message_format: openai
timeout: 30s
llm_providers:
- name: gpt-4o
access_key: $OPENAI_API_KEY
provider_interface: openai
model: gpt-4o
endpoints:
frankfurther_api:
endpoint: api.frankfurter.dev
protocol: https
tools:
get_exchange_rate:
name: get_exchange_rate
description: Get the latest exchange rate for a given currency pair
endpoint:
name: frankfurther_api
path: /v1/latest?base={currency_from}&symbols={currency_to}
parameters:
- name: currency_from
description: currency symbol to convert from
type: string
default: USD
- name: currency_to
description: currency symbol to convert to
type: string
default: EUR
get_list_of_supported_currencies:
name: get_list_of_supported_currencies
description: Get the list of supported currencies
endpoint:
name: frankfurther_api
path: /v1/currencies
agents:
currency_exchange_agent:
name: Currency Exchange Agent
description: Helps with exchange rates for currencies
default_input_modes:
- text
- text/plain
default_output_modes:
- text
- text/plain
skills:
- id: convert_currency
name: Currency Exchange Rates Tool
description: Helps with exchange values between various currencies
examples:
- What is exchange rate between USD and GBP?
capabilities:
streaming: true
push_notifications: true
agent_orchestrator_model: gpt-4o
agent_orchestrator_prompt: |
You are a specialized assistant for currency conversions.
Your sole purpose is to use the 'get_exchange_rate' tool to answer questions about currency exchange rates.
If the user asks about anything other than currency conversion or exchange rates,
politely state that you cannot help with that topic and can only assist with currency-related queries.
Do not attempt to answer unrelated questions or use tools for other purposes.
Set response status to input_required if the user needs to provide more information.
Set response status to error if there is an error while processing the request.
Set response status to completed if the request is complete.
tools:
- get_exchange_rate
- get_list_of_supported_currencies
model: gpt-4o
system_prompt: |
You are a specialized currency exchange assistant.
Your task is to provide the user with the exchange rate between two currencies.
Keep the response concise and relevant to the user's query.
tracing:
random_sampling: 100
trace_arch_internal: true

View file

@ -0,0 +1,58 @@
version: v0.1
listeners:
ingress_traffic:
address: 0.0.0.0
port: 10000
message_format: openai
timeout: 30s
llm_providers:
- name: gpt-4o
access_key: $OPENAI_API_KEY
provider_interface: openai
model: gpt-4o
endpoints:
- name: frankfurther_api
endpoint: api.frankfurter.dev
protocol: https
- name: twelvedata_api
endpoint: api.twelvedata.com
protocol: https
mcp:
- name: get_currency_exchange_rate
- name: get_list_of_supported_currencies
- name: get_stock_quote
prompt_guards:
input_guards:
jailbreak:
on_exception:
message: Looks like you're curious about my abilities, but I can only provide assistance for currency exchange.
agents:
- name: currency_exchange_agent
description: Agent for handling currency exchange queries
llm_provider: gpt-4o
system_prompt: |
You are a helpful assistant. Only respond to queries related to currency exchange. If there are any other questions, I can't help you.
tools:
- get_currency_exchange_rate
- get_list_of_supported_currencies
- name: get_stock_quote_agent
description: Agent for handling stock quote queries
llm_provider: gpt-4o
system_prompt: |
You are a helpful stock exchange assistant. You are given stock symbol along with its exchange rate in json format. Your task is to parse the data and present it in a human-readable format. Keep the details to highlevel and be concise.
tools:
- get_stock_quote
tracing:
random_sampling: 100
trace_arch_internal: true

20
demos/ai_agent/test.hurl Normal file
View file

@ -0,0 +1,20 @@
POST http://localhost:14000/v1/chat/completions
Content-Type: application/json
{
"messages": [
{
"role": "user",
"content": "hello"
},
{
"role": "assistant",
"content": "Hello! I'm here to assist you with any questions related to currency conversions and exchange rates. How can I help you today?"
},
{
"role": "user",
"content": "what is the exchange rate of USD to EUR?"
}
],
"stream": true
}