mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Add prem support for a2a agents
This commit is contained in:
parent
2e346143dd
commit
299f183e66
23 changed files with 2544 additions and 16 deletions
|
|
@ -4,8 +4,7 @@ RUN rustup -v target add wasm32-wasip1
|
|||
WORKDIR /arch
|
||||
COPY crates .
|
||||
|
||||
RUN cd prompt_gateway && cargo build --release --target wasm32-wasip1
|
||||
RUN cd llm_gateway && cargo build --release --target wasm32-wasip1
|
||||
RUN cargo build --release --target wasm32-wasip1
|
||||
|
||||
# copy built filter into envoy image
|
||||
FROM docker.io/envoyproxy/envoy:v1.32-latest as envoy
|
||||
|
|
@ -17,6 +16,7 @@ RUN apt-get update && apt-get install -y gettext-base curl && apt-get clean && r
|
|||
|
||||
COPY --from=builder /arch/target/wasm32-wasip1/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm
|
||||
COPY --from=builder /arch/target/wasm32-wasip1/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm
|
||||
COPY --from=builder /arch/target/wasm32-wasip1/release/agent_gateway.wasm /etc/envoy/proxy-wasm-plugins/agent_gateway.wasm
|
||||
COPY --from=envoy /usr/local/bin/envoy /usr/local/bin/envoy
|
||||
WORKDIR /app
|
||||
COPY arch/requirements.txt .
|
||||
|
|
@ -29,4 +29,4 @@ RUN pip install requests
|
|||
RUN touch /var/log/envoy.log
|
||||
|
||||
# ENTRYPOINT ["sh","-c", "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --log-level trace 2>&1 | tee /var/log/envoy.log"]
|
||||
ENTRYPOINT ["sh","-c", "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:info 2>&1 | tee /var/log/envoy.log"]
|
||||
ENTRYPOINT ["sh","-c", "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:debug 2>&1 | tee /var/log/envoy.log"]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,10 @@ type: object
|
|||
properties:
|
||||
version:
|
||||
type: string
|
||||
tools:
|
||||
type: object
|
||||
agents:
|
||||
type: object
|
||||
listeners:
|
||||
type: object
|
||||
additionalProperties: false
|
||||
|
|
|
|||
|
|
@ -452,6 +452,125 @@ static_resources:
|
|||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
|
||||
|
||||
- name: ingress_traffic_agent
|
||||
address:
|
||||
socket_address:
|
||||
address: 0.0.0.0
|
||||
port_value: 14000
|
||||
traffic_direction: INBOUND
|
||||
filter_chains:
|
||||
- filters:
|
||||
- name: envoy.filters.network.http_connection_manager
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
|
||||
{% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
|
||||
generate_request_id: true
|
||||
tracing:
|
||||
provider:
|
||||
name: envoy.tracers.opentelemetry
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig
|
||||
grpc_service:
|
||||
envoy_grpc:
|
||||
cluster_name: opentelemetry_collector
|
||||
timeout: 0.250s
|
||||
service_name: ingress_traffic_agent
|
||||
random_sampling:
|
||||
value: {{ arch_tracing.random_sampling }}
|
||||
{% endif %}
|
||||
stat_prefix: ingress_traffic_agent
|
||||
codec_type: AUTO
|
||||
scheme_header_transformation:
|
||||
scheme_to_overwrite: https
|
||||
access_log:
|
||||
- name: envoy.access_loggers.file
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
|
||||
path: "/var/log/access_ingress_agent.log"
|
||||
route_config:
|
||||
name: local_routes
|
||||
virtual_hosts:
|
||||
- name: local_service
|
||||
domains:
|
||||
- "*"
|
||||
routes:
|
||||
{% for provider in arch_llm_providers %}
|
||||
# if endpoint is set then use custom cluster for upstream llm
|
||||
{% if provider.endpoint %}
|
||||
{% set llm_cluster_name = provider.name %}
|
||||
{% else %}
|
||||
{% set llm_cluster_name = provider.provider_interface %}
|
||||
{% endif %}
|
||||
- match:
|
||||
prefix: "/"
|
||||
headers:
|
||||
- name: "x-arch-llm-provider"
|
||||
string_match:
|
||||
exact: {{ llm_cluster_name }}
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: {{ llm_cluster_name }}
|
||||
timeout: 60s
|
||||
{% endfor %}
|
||||
http_filters:
|
||||
- name: envoy.filters.http.compressor
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.compressor.v3.Compressor
|
||||
compressor_library:
|
||||
name: compress
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.compression.gzip.compressor.v3.Gzip
|
||||
memory_level: 3
|
||||
window_bits: 10
|
||||
- name: envoy.filters.http.wasm_agent
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
|
||||
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
|
||||
value:
|
||||
config:
|
||||
name: "http_config"
|
||||
root_id: agent_gateway
|
||||
configuration:
|
||||
"@type": "type.googleapis.com/google.protobuf.StringValue"
|
||||
value: |
|
||||
{{ arch_config | indent(32) }}
|
||||
vm_config:
|
||||
runtime: "envoy.wasm.runtime.v8"
|
||||
code:
|
||||
local:
|
||||
filename: "/etc/envoy/proxy-wasm-plugins/agent_gateway.wasm"
|
||||
- name: envoy.filters.http.wasm_llm
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
|
||||
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
|
||||
value:
|
||||
config:
|
||||
name: "http_config"
|
||||
root_id: llm_gateway
|
||||
configuration:
|
||||
"@type": "type.googleapis.com/google.protobuf.StringValue"
|
||||
value: |
|
||||
{{ arch_llm_config | indent(32) }}
|
||||
vm_config:
|
||||
runtime: "envoy.wasm.runtime.v8"
|
||||
code:
|
||||
local:
|
||||
filename: "/etc/envoy/proxy-wasm-plugins/llm_gateway.wasm"
|
||||
- name: envoy.filters.http.decompressor
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.decompressor.v3.Decompressor
|
||||
decompressor_library:
|
||||
name: decompress
|
||||
typed_config:
|
||||
"@type": "type.googleapis.com/envoy.extensions.compression.gzip.decompressor.v3.Gzip"
|
||||
window_bits: 9
|
||||
chunk_size: 8192
|
||||
# If this ratio is set too low, then body data will not be decompressed completely.
|
||||
max_inflate_ratio: 1000
|
||||
- name: envoy.filters.http.router
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
|
||||
|
||||
clusters:
|
||||
- name: openai
|
||||
connect_timeout: 0.5s
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ def docker_start_archgw_detached(
|
|||
port_mappings = [
|
||||
f"{prompt_gateway_port}:{prompt_gateway_port}",
|
||||
f"{llm_gateway_port}:{llm_gateway_port}",
|
||||
"14000:14000",
|
||||
"9901:19901",
|
||||
]
|
||||
port_mappings_args = [item for port in port_mappings for item in ("-p", port)]
|
||||
|
|
@ -56,7 +57,7 @@ def docker_start_archgw_detached(
|
|||
volume_mappings = [
|
||||
f"{logs_path_abs}:/var/log:rw",
|
||||
f"{arch_config_file}:/app/arch_config.yaml:ro",
|
||||
# "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro",
|
||||
"/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro",
|
||||
]
|
||||
volume_mappings_args = [
|
||||
item for volume in volume_mappings for item in ("-v", volume)
|
||||
|
|
|
|||
23
crates/Cargo.lock
generated
23
crates/Cargo.lock
generated
|
|
@ -20,6 +20,29 @@ dependencies = [
|
|||
"gimli",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent_gateway"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"acap",
|
||||
"common",
|
||||
"derivative",
|
||||
"governor",
|
||||
"http",
|
||||
"log",
|
||||
"md5",
|
||||
"pretty_assertions",
|
||||
"proxy-wasm",
|
||||
"proxy-wasm-test-framework",
|
||||
"rand",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_yaml",
|
||||
"serial_test",
|
||||
"sha2",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.3.8"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
[workspace]
|
||||
resolver = "2"
|
||||
members = ["llm_gateway", "prompt_gateway", "common"]
|
||||
members = ["llm_gateway", "prompt_gateway", "agent_gateway", "common"]
|
||||
|
|
|
|||
29
crates/agent_gateway/Cargo.toml
Normal file
29
crates/agent_gateway/Cargo.toml
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
[package]
|
||||
name = "agent_gateway"
|
||||
version = "0.1.0"
|
||||
authors = ["Katanemo Inc <info@katanemo.com>"]
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
proxy-wasm = "0.2.1"
|
||||
log = "0.4"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_yaml = "0.9.34"
|
||||
serde_json = "1.0"
|
||||
md5 = "0.7.0"
|
||||
common = { path = "../common" }
|
||||
http = "1.1.0"
|
||||
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
|
||||
acap = "0.3.0"
|
||||
rand = "0.8.5"
|
||||
thiserror = "1.0.64"
|
||||
derivative = "2.2.0"
|
||||
sha2 = "0.10.8"
|
||||
|
||||
[dev-dependencies]
|
||||
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
|
||||
serial_test = "3.1.1"
|
||||
pretty_assertions = "1.4.1"
|
||||
66
crates/agent_gateway/src/context.rs
Normal file
66
crates/agent_gateway/src/context.rs
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
use std::str::FromStr;
|
||||
|
||||
use common::errors::ServerError;
|
||||
use common::stats::IncrementingMetric;
|
||||
use http::StatusCode;
|
||||
use log::warn;
|
||||
use proxy_wasm::traits::Context;
|
||||
|
||||
use crate::stream_context::{ResponseHandlerType, StreamContext};
|
||||
|
||||
impl Context for StreamContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
let callout_context = self
|
||||
.callouts
|
||||
.get_mut()
|
||||
.remove(&token_id)
|
||||
.expect("invalid token_id");
|
||||
self.metrics.active_http_calls.increment(-1);
|
||||
|
||||
let body = self
|
||||
.get_http_call_response_body(0, body_size)
|
||||
.unwrap_or_default();
|
||||
|
||||
if let Some(http_status) = self.get_http_call_response_header(":status") {
|
||||
match StatusCode::from_str(http_status.as_str()) {
|
||||
Ok(status_code) => {
|
||||
if !status_code.is_success() {
|
||||
let server_error = ServerError::Upstream {
|
||||
host: callout_context.upstream_cluster.unwrap(),
|
||||
path: callout_context.upstream_cluster_path.unwrap(),
|
||||
status: http_status.clone(),
|
||||
body: String::from_utf8(body).unwrap(),
|
||||
};
|
||||
warn!("received non 2xx code: {:?}", server_error);
|
||||
return self.send_server_error(
|
||||
server_error,
|
||||
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// invalid status code (status code non numeric)
|
||||
return self.send_server_error(
|
||||
ServerError::LogicError(format!("invalid status code: {}", http_status)),
|
||||
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// :status header not found
|
||||
warn!("missing :status header");
|
||||
}
|
||||
|
||||
#[cfg_attr(any(), rustfmt::skip)]
|
||||
match callout_context.response_handler_type {
|
||||
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
|
||||
ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
|
||||
}
|
||||
}
|
||||
}
|
||||
121
crates/agent_gateway/src/filter_context.rs
Normal file
121
crates/agent_gateway/src/filter_context.rs
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
use crate::metrics::Metrics;
|
||||
use crate::stream_context::StreamContext;
|
||||
use common::configuration::{
|
||||
Agent, Configuration, Endpoint, Overrides, PromptGuards, PromptTarget, Tool, Tracing,
|
||||
};
|
||||
use common::http::Client;
|
||||
use common::stats::Gauge;
|
||||
use log::trace;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterCallContext {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterContext {
|
||||
metrics: Rc<Metrics>,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: RefCell<HashMap<u32, FilterCallContext>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
agents: Rc<HashMap<String, Agent>>,
|
||||
tools: Rc<HashMap<String, Tool>>,
|
||||
endpoints: Rc<Option<HashMap<String, Endpoint>>>,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
tracing: Rc<Option<Tracing>>,
|
||||
}
|
||||
|
||||
impl FilterContext {
|
||||
pub fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: RefCell::new(HashMap::new()),
|
||||
metrics: Rc::new(Metrics::new()),
|
||||
system_prompt: Rc::new(None),
|
||||
prompt_targets: Rc::new(HashMap::new()),
|
||||
overrides: Rc::new(None),
|
||||
prompt_guards: Rc::new(PromptGuards::default()),
|
||||
endpoints: Rc::new(None),
|
||||
tracing: Rc::new(None),
|
||||
agents: Rc::new(HashMap::new()),
|
||||
tools: Rc::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Client for FilterContext {
|
||||
type CallContext = FilterCallContext;
|
||||
|
||||
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
|
||||
&self.callouts
|
||||
}
|
||||
|
||||
fn active_http_calls(&self) -> &Gauge {
|
||||
&self.metrics.active_http_calls
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for FilterContext {}
|
||||
|
||||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for FilterContext {
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
let config_bytes = self
|
||||
.get_plugin_configuration()
|
||||
.expect("Arch config cannot be empty");
|
||||
|
||||
let config: Configuration = match serde_yaml::from_slice(&config_bytes) {
|
||||
Ok(config) => config,
|
||||
Err(err) => panic!("Invalid arch config \"{:?}\"", err),
|
||||
};
|
||||
|
||||
self.overrides = Rc::new(config.overrides);
|
||||
|
||||
let mut prompt_targets = HashMap::new();
|
||||
for pt in config.prompt_targets.unwrap_or_default() {
|
||||
prompt_targets.insert(pt.name.clone(), pt.clone());
|
||||
}
|
||||
self.system_prompt = Rc::new(config.system_prompt);
|
||||
self.prompt_targets = Rc::new(prompt_targets);
|
||||
self.endpoints = Rc::new(config.endpoints);
|
||||
self.agents = Rc::new(config.agents);
|
||||
self.tools = Rc::new(config.tools);
|
||||
|
||||
if let Some(prompt_guards) = config.prompt_guards {
|
||||
self.prompt_guards = Rc::new(prompt_guards)
|
||||
}
|
||||
|
||||
self.tracing = Rc::new(config.tracing);
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
|
||||
trace!(
|
||||
"||| create_http_context called with context_id: {:?} |||",
|
||||
context_id
|
||||
);
|
||||
|
||||
Some(Box::new(StreamContext::new(
|
||||
context_id,
|
||||
Rc::clone(&self.metrics),
|
||||
Rc::clone(&self.endpoints),
|
||||
Rc::clone(&self.overrides),
|
||||
Rc::clone(&self.tracing),
|
||||
Rc::clone(&self.agents),
|
||||
Rc::clone(&self.tools),
|
||||
)))
|
||||
}
|
||||
|
||||
fn get_type(&self) -> Option<ContextType> {
|
||||
Some(ContextType::HttpContext)
|
||||
}
|
||||
|
||||
fn on_vm_start(&mut self, _: usize) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
440
crates/agent_gateway/src/http_context.rs
Normal file
440
crates/agent_gateway/src/http_context.rs
Normal file
|
|
@ -0,0 +1,440 @@
|
|||
use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext};
|
||||
use common::{
|
||||
api::open_ai::{self, ArchState, ChatCompletionTool, ChatCompletionsRequest, Message},
|
||||
consts::{
|
||||
ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_UPSTREAM_HOST_HEADER,
|
||||
CHAT_COMPLETIONS_PATH, HEALTHZ_PATH, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER,
|
||||
SYSTEM_ROLE, TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_STATE_HEADER,
|
||||
},
|
||||
errors::ServerError,
|
||||
http::{CallArgs, Client},
|
||||
pii::obfuscate_auth_header,
|
||||
};
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::{traits::HttpContext, types::Action};
|
||||
use serde_json::Value;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
time::{Duration, SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
|
||||
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
|
||||
impl HttpContext for StreamContext {
|
||||
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
|
||||
// the lifecycle of the http request and response.
|
||||
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
// Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it.
|
||||
// Server's generally throw away requests whose body length do not match the Content-Length header.
|
||||
// However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
|
||||
// manipulate the body in benign ways e.g., compression.
|
||||
self.set_http_request_header("content-length", None);
|
||||
|
||||
if let Some(overrides) = self.overrides.as_ref() {
|
||||
if overrides.use_agent_orchestrator.unwrap_or_default() {
|
||||
// get endpoint that has agent_orchestrator set to true
|
||||
if let Some(endpoints) = self.endpoints.as_ref() {
|
||||
if endpoints.len() == 1 {
|
||||
let (name, _) = endpoints.iter().next().unwrap();
|
||||
info!("Setting ARCH_PROVIDER_HINT_HEADER to {}", name);
|
||||
self.set_http_request_header(ARCH_ROUTING_HEADER, Some(name));
|
||||
} else {
|
||||
warn!("Need single endpoint when use_agent_orchestrator is set");
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(
|
||||
"Need single endpoint when use_agent_orchestrator is set"
|
||||
.to_string(),
|
||||
),
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let request_path = self.get_http_request_header(":path").unwrap_or_default();
|
||||
if request_path == HEALTHZ_PATH {
|
||||
self.send_http_response(200, vec![], None);
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
self.is_chat_completions_request = CHAT_COMPLETIONS_PATH.contains(&request_path.as_str());
|
||||
|
||||
// check if agent name is in the request header
|
||||
// if not, check if there is only one agent in the config
|
||||
// if so, use that agent
|
||||
// if there are multiple agents in the config, return an error
|
||||
if let Some(agent_header_value) = self.get_http_request_header("x-agent-name") {
|
||||
if let Some(agent) = self.agents.as_ref().get(&agent_header_value) {
|
||||
self.agent = Some(agent.clone());
|
||||
} else {
|
||||
warn!("Agent not found in config");
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!(
|
||||
"Agent {} not found in config",
|
||||
agent_header_value
|
||||
)),
|
||||
None,
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
} else if self.agents.as_ref().len() == 1 {
|
||||
let (name, agent) = self.agents.iter().next().unwrap();
|
||||
info!("Setting agent to {}", name);
|
||||
self.agent = Some(agent.clone());
|
||||
} else {
|
||||
warn!("Multiple agents found in config and no agent name in request header");
|
||||
self.send_http_response(
|
||||
400,
|
||||
vec![],
|
||||
Some(
|
||||
"Multiple agents found in config and no agent name in request header"
|
||||
.as_bytes(),
|
||||
),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"on_http_request_headers S[{}] req_headers={:?}",
|
||||
self.context_id,
|
||||
obfuscate_auth_header(&mut self.get_http_request_headers())
|
||||
);
|
||||
|
||||
self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
|
||||
self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER);
|
||||
|
||||
Action::Continue
|
||||
}
|
||||
|
||||
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
|
||||
// Let the client send the gateway all the data before sending to the LLM_provider.
|
||||
// TODO: consider a streaming API.
|
||||
|
||||
if !end_of_stream {
|
||||
return Action::Pause;
|
||||
}
|
||||
|
||||
if body_size == 0 {
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
self.request_body_size = body_size;
|
||||
|
||||
debug!(
|
||||
"on_http_request_body S[{}] body_size={}",
|
||||
self.context_id, body_size
|
||||
);
|
||||
|
||||
let body_bytes = match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => body_bytes,
|
||||
None => {
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
)),
|
||||
None,
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
debug!("request body: {}", String::from_utf8_lossy(&body_bytes));
|
||||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let deserialized_body: ChatCompletionsRequest = match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
self.send_server_error(
|
||||
ServerError::Deserialization(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
self.arch_state = match deserialized_body.metadata {
|
||||
Some(ref metadata) => {
|
||||
if metadata.contains_key(X_ARCH_STATE_HEADER) {
|
||||
let arch_state_str = metadata[X_ARCH_STATE_HEADER].clone();
|
||||
let arch_state: Vec<ArchState> = serde_json::from_str(&arch_state_str).unwrap();
|
||||
Some(arch_state)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
self.streaming_response = deserialized_body.stream;
|
||||
|
||||
let last_user_prompt: &open_ai::Message = match deserialized_body
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role == USER_ROLE)
|
||||
.last()
|
||||
{
|
||||
Some(content) => content,
|
||||
None => {
|
||||
warn!("No messages in the request body");
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
self.user_prompt = Some(last_user_prompt.clone());
|
||||
|
||||
let mut tool_calls = Vec::new();
|
||||
if let Some(agent) = self.agent.as_ref() {
|
||||
if let Some(tools) = agent.tools.as_ref() {
|
||||
for tool in tools {
|
||||
if let Some(tool) = self.tools.as_ref().get(tool) {
|
||||
info!("tool: {:?}", tool);
|
||||
let tool_chat_completion_tool: ChatCompletionTool = tool.into();
|
||||
info!("tool_chat_completion_tool: {:?}", tool_chat_completion_tool);
|
||||
tool_calls.push(tool_chat_completion_tool);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut metadata = deserialized_body.metadata.clone();
|
||||
|
||||
if let Some(overrides) = self.overrides.as_ref() {
|
||||
if overrides.optimize_context_window.unwrap_or_default() {
|
||||
if metadata.is_none() {
|
||||
metadata = Some(HashMap::new());
|
||||
}
|
||||
metadata
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.insert("optimize_context_window".to_string(), "true".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let messages: Vec<Message> = match self.agent.as_ref().unwrap().agent_orchestrator_prompt {
|
||||
Some(ref agent_orchestrator_prompt) => {
|
||||
let mut messages = Vec::new();
|
||||
messages.push(Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(agent_orchestrator_prompt.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
messages.extend(deserialized_body.messages.clone());
|
||||
messages
|
||||
}
|
||||
None => deserialized_body.messages.clone(),
|
||||
};
|
||||
|
||||
let arch_fc_chat_completion_request = ChatCompletionsRequest {
|
||||
messages,
|
||||
metadata,
|
||||
//HACK: adilhafeez: enable streaming for agent orchestrator
|
||||
stream: false,
|
||||
model: deserialized_body.model.clone(),
|
||||
stream_options: deserialized_body.stream_options.clone(),
|
||||
tools: Some(tool_calls),
|
||||
};
|
||||
|
||||
self.chat_completions_request = Some(deserialized_body);
|
||||
|
||||
let json_data = match serde_json::to_string(&arch_fc_chat_completion_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
self.send_server_error(ServerError::Serialization(error), None);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
info!("on_http_request_body: sending request to model server");
|
||||
debug!("request body: {}", json_data);
|
||||
|
||||
let timeout_str = MODEL_SERVER_REQUEST_TIMEOUT_MS.to_string();
|
||||
|
||||
let mut headers = vec![
|
||||
(ARCH_UPSTREAM_HOST_HEADER, "openai"),
|
||||
(":method", "POST"),
|
||||
(":path", "/v1/chat/completions"),
|
||||
("content-type", "application/json"),
|
||||
(":authority", "openai"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
];
|
||||
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
if self.traceparent.is_some() {
|
||||
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
"arch_listener_llm",
|
||||
"/v1/chat/completions",
|
||||
headers,
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
|
||||
let call_context = StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchFC,
|
||||
user_message: self.user_prompt.as_ref().unwrap().content.clone(),
|
||||
prompt_target_name: None,
|
||||
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
|
||||
similarity_scores: None,
|
||||
upstream_cluster: Some(ARCH_INTERNAL_CLUSTER_NAME.to_string()),
|
||||
upstream_cluster_path: Some("/function_calling".to_string()),
|
||||
agent: self.agent.clone(),
|
||||
};
|
||||
|
||||
if let Err(e) = self.http_call(call_args, call_context) {
|
||||
warn!("http_call failed: {:?}", e);
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
|
||||
Action::Pause
|
||||
}
|
||||
|
||||
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
debug!(
|
||||
"on_http_response_headers recv [S={}] headers={:?}",
|
||||
self.context_id,
|
||||
self.get_http_response_headers()
|
||||
);
|
||||
// delete content-lenght header let envoy calculate it, because we modify the response body
|
||||
// that would result in a different content-length
|
||||
self.set_http_response_header("content-length", None);
|
||||
Action::Continue
|
||||
}
|
||||
|
||||
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
|
||||
debug!(
|
||||
"on_http_response_body: recv [S={}] bytes={} end_stream={}",
|
||||
self.context_id, body_size, end_of_stream
|
||||
);
|
||||
|
||||
if !self.is_chat_completions_request {
|
||||
info!("non-gpt request");
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
if self.time_to_first_token.is_none() {
|
||||
self.time_to_first_token = Some(
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos(),
|
||||
);
|
||||
}
|
||||
|
||||
if end_of_stream && body_size == 0 {
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
let body = if self.streaming_response {
|
||||
let streaming_chunk = match self.get_http_response_body(0, body_size) {
|
||||
Some(chunk) => chunk,
|
||||
None => {
|
||||
warn!(
|
||||
"response body empty, chunk_start: {}, chunk_size: {}",
|
||||
0, body_size
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
if streaming_chunk.len() != body_size {
|
||||
warn!(
|
||||
"chunk size mismatch: read: {} != requested: {}",
|
||||
streaming_chunk.len(),
|
||||
body_size
|
||||
);
|
||||
}
|
||||
|
||||
streaming_chunk
|
||||
} else {
|
||||
info!("non streaming response bytes read: 0:{}", body_size);
|
||||
match self.get_http_response_body(0, body_size) {
|
||||
Some(body) => body,
|
||||
None => {
|
||||
warn!("non streaming response body empty");
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let body_utf8 = match String::from_utf8(body) {
|
||||
Ok(body_utf8) => body_utf8,
|
||||
Err(e) => {
|
||||
info!("could not convert to utf8: {}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
if self.streaming_response {
|
||||
debug!("streaming response");
|
||||
|
||||
if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() {
|
||||
let chunks = vec![
|
||||
// ChatCompletionStreamResponse::new(
|
||||
// self.arch_fc_response.clone(),
|
||||
// Some(ASSISTANT_ROLE.to_string()),
|
||||
// Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
// None,
|
||||
// ),
|
||||
// ChatCompletionStreamResponse::new(
|
||||
// self.tool_call_response.clone(),
|
||||
// Some(TOOL_ROLE.to_string()),
|
||||
// Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
// None,
|
||||
// ),
|
||||
];
|
||||
|
||||
let mut response_str = open_ai::to_server_events(chunks);
|
||||
// append the original response from the model to the stream
|
||||
response_str.push_str(&body_utf8);
|
||||
self.set_http_response_body(0, body_size, response_str.as_bytes());
|
||||
self.tool_calls = None;
|
||||
}
|
||||
} else if let Some(tool_calls) = self.tool_calls.as_ref() {
|
||||
if !tool_calls.is_empty() {
|
||||
if self.arch_state.is_none() {
|
||||
self.arch_state = Some(Vec::new());
|
||||
}
|
||||
|
||||
let mut data = match serde_json::from_str(&body_utf8) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not deserialize response, sending data as it is: {}",
|
||||
e
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
// use serde::Value to manipulate the json object and ensure that we don't lose any data
|
||||
if let Value::Object(ref mut map) = data {
|
||||
// serialize arch state and add to metadata
|
||||
let metadata = map
|
||||
.entry("metadata")
|
||||
.or_insert(Value::Object(serde_json::Map::new()));
|
||||
if metadata == &Value::Null {
|
||||
*metadata = Value::Object(serde_json::Map::new());
|
||||
}
|
||||
|
||||
let data_serialized = serde_json::to_string(&data).unwrap();
|
||||
info!("archgw <= developer: {}", data_serialized);
|
||||
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
debug!("recv [S={}] end_stream={}", self.context_id, end_of_stream);
|
||||
|
||||
Action::Continue
|
||||
}
|
||||
}
|
||||
17
crates/agent_gateway/src/lib.rs
Normal file
17
crates/agent_gateway/src/lib.rs
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
use filter_context::FilterContext;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
mod context;
|
||||
mod filter_context;
|
||||
mod http_context;
|
||||
mod metrics;
|
||||
mod stream_context;
|
||||
mod tools;
|
||||
|
||||
proxy_wasm::main! {{
|
||||
proxy_wasm::set_log_level(LogLevel::Trace);
|
||||
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
|
||||
Box::new(FilterContext::new())
|
||||
});
|
||||
}}
|
||||
14
crates/agent_gateway/src/metrics.rs
Normal file
14
crates/agent_gateway/src/metrics.rs
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
use common::stats::Gauge;
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Metrics {
|
||||
pub active_http_calls: Gauge,
|
||||
}
|
||||
|
||||
impl Metrics {
|
||||
pub fn new() -> Metrics {
|
||||
Metrics {
|
||||
active_http_calls: Gauge::new(String::from("active_http_calls")),
|
||||
}
|
||||
}
|
||||
}
|
||||
528
crates/agent_gateway/src/stream_context.rs
Normal file
528
crates/agent_gateway/src/stream_context.rs
Normal file
|
|
@ -0,0 +1,528 @@
|
|||
use crate::metrics::Metrics;
|
||||
use crate::tools::compute_request_path_body;
|
||||
use common::api::open_ai::{
|
||||
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
|
||||
ChatCompletionsResponse, Message, ToolCall,
|
||||
};
|
||||
use common::configuration::{Agent, Endpoint, Overrides, Tool, Tracing};
|
||||
use common::consts::{
|
||||
API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE,
|
||||
TRACE_PARENT_HEADER, USER_ROLE,
|
||||
};
|
||||
use common::errors::ServerError;
|
||||
use common::http::{CallArgs, Client};
|
||||
use common::stats::Gauge;
|
||||
use derivative::Derivative;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::traits::*;
|
||||
use serde_yaml::Value;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::str::FromStr;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ResponseHandlerType {
|
||||
ArchFC,
|
||||
FunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Clone, Derivative)]
|
||||
#[derivative(Debug)]
|
||||
pub struct StreamCallContext {
|
||||
pub response_handler_type: ResponseHandlerType,
|
||||
pub user_message: Option<String>,
|
||||
pub prompt_target_name: Option<String>,
|
||||
#[derivative(Debug = "ignore")]
|
||||
pub request_body: ChatCompletionsRequest,
|
||||
pub similarity_scores: Option<Vec<(String, f64)>>,
|
||||
pub upstream_cluster: Option<String>,
|
||||
pub upstream_cluster_path: Option<String>,
|
||||
pub agent: Option<Agent>,
|
||||
}
|
||||
|
||||
pub struct StreamContext {
|
||||
pub endpoints: Rc<Option<HashMap<String, Endpoint>>>,
|
||||
pub overrides: Rc<Option<Overrides>>,
|
||||
pub metrics: Rc<Metrics>,
|
||||
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
|
||||
pub context_id: u32,
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
pub tool_call_response: Option<String>,
|
||||
pub arch_state: Option<Vec<ArchState>>,
|
||||
pub request_body_size: usize,
|
||||
pub user_prompt: Option<Message>,
|
||||
pub streaming_response: bool,
|
||||
pub is_chat_completions_request: bool,
|
||||
pub chat_completions_request: Option<ChatCompletionsRequest>,
|
||||
pub request_id: Option<String>,
|
||||
pub start_upstream_llm_request_time: u128,
|
||||
pub time_to_first_token: Option<u128>,
|
||||
pub traceparent: Option<String>,
|
||||
pub agents: Rc<HashMap<String, Agent>>,
|
||||
pub agent: Option<Agent>,
|
||||
pub tools: Rc<HashMap<String, Tool>>,
|
||||
pub _tracing: Rc<Option<Tracing>>,
|
||||
pub arch_fc_response: Option<String>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
pub fn new(
|
||||
context_id: u32,
|
||||
metrics: Rc<Metrics>,
|
||||
endpoints: Rc<Option<HashMap<String, Endpoint>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
tracing: Rc<Option<Tracing>>,
|
||||
agents: Rc<HashMap<String, Agent>>,
|
||||
tools: Rc<HashMap<String, Tool>>,
|
||||
) -> Self {
|
||||
StreamContext {
|
||||
context_id,
|
||||
metrics,
|
||||
endpoints,
|
||||
callouts: RefCell::new(HashMap::new()),
|
||||
chat_completions_request: None,
|
||||
tool_calls: None,
|
||||
tool_call_response: None,
|
||||
arch_state: None,
|
||||
request_body_size: 0,
|
||||
streaming_response: false,
|
||||
user_prompt: None,
|
||||
is_chat_completions_request: false,
|
||||
overrides,
|
||||
request_id: None,
|
||||
traceparent: None,
|
||||
_tracing: tracing,
|
||||
start_upstream_llm_request_time: 0,
|
||||
time_to_first_token: None,
|
||||
arch_fc_response: None,
|
||||
agents,
|
||||
tools,
|
||||
agent: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
|
||||
self.send_http_response(
|
||||
override_status_code
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.as_u16()
|
||||
.into(),
|
||||
vec![],
|
||||
Some(format!("{error}").as_bytes()),
|
||||
);
|
||||
}
|
||||
|
||||
fn _trace_arch_internal(&self) -> bool {
|
||||
match self._tracing.as_ref() {
|
||||
Some(tracing) => match tracing.trace_arch_internal.as_ref() {
|
||||
Some(trace_arch_internal) => *trace_arch_internal,
|
||||
None => false,
|
||||
},
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn arch_fc_response_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
mut callout_context: StreamCallContext,
|
||||
) {
|
||||
let body_str = String::from_utf8(body).unwrap();
|
||||
info!("on_http_call_response: model server response received");
|
||||
debug!("response body: {}", body_str);
|
||||
|
||||
let model_server_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
|
||||
Ok(arch_fc_response) => arch_fc_response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"error deserializing llm response: {}, body: {}",
|
||||
e, body_str
|
||||
);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
||||
//TODO: try to avoid clone
|
||||
let message = model_server_response
|
||||
.choices
|
||||
.first()
|
||||
.map(|choice| choice.message.clone())
|
||||
.unwrap();
|
||||
|
||||
self.tool_calls = message.tool_calls;
|
||||
|
||||
if self.tool_calls.as_ref().is_some() && self.tool_calls.as_ref().unwrap().len() > 1 {
|
||||
warn!(
|
||||
"multiple tool calls not supported yet, tool_calls count found: {}",
|
||||
self.tool_calls.as_ref().unwrap().len()
|
||||
);
|
||||
}
|
||||
|
||||
if self.tool_calls.is_none() || self.tool_calls.as_ref().unwrap().is_empty() {
|
||||
// this means llm model didn't need additional data from tool calls and is ready to respond back to user
|
||||
|
||||
let direct_response_str = if self.streaming_response {
|
||||
let chunks = vec![
|
||||
ChatCompletionStreamResponse::new(
|
||||
self.arch_fc_response.clone(),
|
||||
Some(ASSISTANT_ROLE.to_string()),
|
||||
Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
None,
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
Some(
|
||||
model_server_response.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone(),
|
||||
),
|
||||
None,
|
||||
Some(format!("{}-Chat", ARCH_FC_MODEL_NAME.to_owned())),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
||||
to_server_events(chunks)
|
||||
} else {
|
||||
body_str
|
||||
};
|
||||
|
||||
self.tool_calls = None;
|
||||
return self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![],
|
||||
Some(direct_response_str.as_bytes()),
|
||||
);
|
||||
}
|
||||
|
||||
// update prompt target name from the tool call response
|
||||
callout_context.prompt_target_name =
|
||||
Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone());
|
||||
|
||||
self.schedule_api_call_request(callout_context);
|
||||
}
|
||||
|
||||
fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
|
||||
// Construct messages early to avoid mutable borrow conflicts
|
||||
|
||||
let tool_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
|
||||
let tool = self.tools.get(&tool_name).unwrap().clone();
|
||||
let tool_params = self.tool_calls.as_ref().unwrap()[0]
|
||||
.function
|
||||
.arguments
|
||||
.clone();
|
||||
let endpoint_details = tool.endpoint.as_ref().unwrap();
|
||||
let endpoint_path: String = endpoint_details
|
||||
.path
|
||||
.as_ref()
|
||||
.unwrap_or(&String::from("/"))
|
||||
.to_string();
|
||||
|
||||
let http_method = endpoint_details.method.clone().unwrap_or_default();
|
||||
let prompt_target_params = tool.parameters.clone().unwrap_or_default();
|
||||
|
||||
let mut tool_params_json: Option<HashMap<String, Value>> = None;
|
||||
|
||||
if let Some(params) = tool_params.as_ref() {
|
||||
match serde_json::from_str::<HashMap<String, Value>>(params.as_str()) {
|
||||
Ok(params_json) => tool_params_json = Some(params_json),
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"error deserializing tool params: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8(params.as_bytes().to_vec()).unwrap()
|
||||
);
|
||||
return self.send_server_error(
|
||||
ServerError::Deserialization(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
//TODO: fixme hack adilhafeez
|
||||
let (path, api_call_body) = match compute_request_path_body(
|
||||
&endpoint_path,
|
||||
&tool_params_json,
|
||||
&prompt_target_params,
|
||||
&http_method,
|
||||
) {
|
||||
Ok((path, body)) => (path, body),
|
||||
Err(e) => {
|
||||
return self.send_server_error(
|
||||
ServerError::BadRequest {
|
||||
why: format!("error computing api request path or body: {}", e),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
debug!("on_http_call_response: api call body {:?}", api_call_body);
|
||||
|
||||
let timeout_str = API_REQUEST_TIMEOUT_MS.to_string();
|
||||
|
||||
let http_method_str = http_method.to_string();
|
||||
let mut headers: HashMap<_, _> = [
|
||||
(ARCH_UPSTREAM_HOST_HEADER, endpoint_details.name.as_str()),
|
||||
(":method", &http_method_str),
|
||||
(":path", &path),
|
||||
(":authority", endpoint_details.name.as_str()),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
if self.request_id.is_some() {
|
||||
headers.insert(REQUEST_ID_HEADER, self.request_id.as_ref().unwrap());
|
||||
}
|
||||
|
||||
if self.traceparent.is_some() {
|
||||
headers.insert(TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap());
|
||||
}
|
||||
|
||||
// override http headers that are set in the prompt target
|
||||
let http_headers = endpoint_details.http_headers.clone().unwrap_or_default();
|
||||
for (key, value) in http_headers.iter() {
|
||||
headers.insert(key.as_str(), value.as_str());
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&path,
|
||||
headers.into_iter().collect(),
|
||||
api_call_body.as_deref().map(|s| s.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
|
||||
info!(
|
||||
"on_http_call_response: dispatching api call to developer endpoint: {}, path: {}, method: {}",
|
||||
endpoint_details.name, path, http_method_str
|
||||
);
|
||||
|
||||
callout_context.upstream_cluster = Some(endpoint_details.name.to_owned());
|
||||
callout_context.upstream_cluster_path = Some(path.to_owned());
|
||||
callout_context.agent = self.agent.clone();
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn api_call_response_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||
let http_status = self
|
||||
.get_http_call_response_header(":status")
|
||||
.unwrap_or(StatusCode::OK.as_str().to_string());
|
||||
info!(
|
||||
"on_http_call_response: developer api call response received: status code: {}",
|
||||
http_status
|
||||
);
|
||||
if http_status != StatusCode::OK.as_str() {
|
||||
warn!(
|
||||
"api server responded with non 2xx status code: {}",
|
||||
http_status
|
||||
);
|
||||
return self.send_server_error(
|
||||
ServerError::Upstream {
|
||||
host: callout_context.upstream_cluster.unwrap(),
|
||||
path: callout_context.upstream_cluster_path.unwrap(),
|
||||
status: http_status.clone(),
|
||||
body: String::from_utf8(body).unwrap(),
|
||||
},
|
||||
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
|
||||
);
|
||||
}
|
||||
self.tool_call_response = Some(String::from_utf8(body).unwrap());
|
||||
debug!(
|
||||
"response body: {}",
|
||||
self.tool_call_response.as_ref().unwrap()
|
||||
);
|
||||
|
||||
let mut messages = self.construct_llm_messages(&callout_context);
|
||||
|
||||
let user_message = match messages.pop() {
|
||||
Some(user_message) => user_message,
|
||||
None => {
|
||||
return self.send_server_error(
|
||||
ServerError::NoMessagesFound {
|
||||
why: "no user messages found".to_string(),
|
||||
},
|
||||
None,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let final_prompt = format!(
|
||||
"{}\ncontext: {}",
|
||||
user_message.content.unwrap(),
|
||||
self.tool_call_response.as_ref().unwrap()
|
||||
);
|
||||
|
||||
// add original user prompt
|
||||
messages.push({
|
||||
Message {
|
||||
role: USER_ROLE.to_string(),
|
||||
content: Some(final_prompt),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
});
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: callout_context.request_body.model,
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
|
||||
Ok(json_string) => json_string,
|
||||
Err(e) => {
|
||||
return self.send_server_error(ServerError::Serialization(e), None);
|
||||
}
|
||||
};
|
||||
info!("on_http_call_response: sending request to upstream llm");
|
||||
debug!("request body: {}", llm_request_str);
|
||||
|
||||
self.start_upstream_llm_request_time = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
|
||||
self.set_http_request_body(0, self.request_body_size, &llm_request_str.into_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
|
||||
fn filter_out_arch_messages(&self, messages: &[Message]) -> Vec<Message> {
|
||||
messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
!(m.role == TOOL_ROLE
|
||||
|| m.content.is_none()
|
||||
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty()))
|
||||
})
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
|
||||
let mut messages: Vec<Message> = Vec::new();
|
||||
|
||||
if let Some(agent) = callout_context.agent.as_ref() {
|
||||
if let Some(system_prompt) = agent.system_prompt.as_ref() {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(system_prompt.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
&mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
|
||||
);
|
||||
messages
|
||||
}
|
||||
}
|
||||
|
||||
impl Client for StreamContext {
|
||||
type CallContext = StreamCallContext;
|
||||
|
||||
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
|
||||
&self.callouts
|
||||
}
|
||||
|
||||
fn active_http_calls(&self) -> &Gauge {
|
||||
&self.metrics.active_http_calls
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use common::api::open_ai::{ChatCompletionsResponse, Choice, Message, ToolCall};
|
||||
|
||||
use crate::stream_context::check_intent_matched;
|
||||
|
||||
#[test]
|
||||
fn test_intent_matched() {
|
||||
let model_server_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
content: Some("".to_string()),
|
||||
tool_calls: Some(vec![]),
|
||||
role: "assistant".to_string(),
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
index: None,
|
||||
}],
|
||||
usage: None,
|
||||
model: "arch-fc".to_string(),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
assert!(!check_intent_matched(&model_server_response));
|
||||
|
||||
let model_server_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
content: Some("hello".to_string()),
|
||||
tool_calls: Some(vec![]),
|
||||
role: "assistant".to_string(),
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
index: None,
|
||||
}],
|
||||
usage: None,
|
||||
model: "arch-fc".to_string(),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
assert!(check_intent_matched(&model_server_response));
|
||||
|
||||
let model_server_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
content: Some("".to_string()),
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: "1".to_string(),
|
||||
function: common::api::open_ai::FunctionCallDetail {
|
||||
name: "test".to_string(),
|
||||
arguments: None,
|
||||
},
|
||||
tool_type: common::api::open_ai::ToolType::Function,
|
||||
}]),
|
||||
role: "assistant".to_string(),
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
index: None,
|
||||
}],
|
||||
usage: None,
|
||||
model: "arch-fc".to_string(),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
assert!(check_intent_matched(&model_server_response));
|
||||
}
|
||||
}
|
||||
162
crates/agent_gateway/src/tools.rs
Normal file
162
crates/agent_gateway/src/tools.rs
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
use common::configuration::{HttpMethod, Parameter};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde_yaml::Value;
|
||||
|
||||
// only add params that are of string, number and bool type
|
||||
pub fn filter_tool_params(tool_params: &Option<HashMap<String, Value>>) -> HashMap<String, String> {
|
||||
if tool_params.is_none() {
|
||||
return HashMap::new();
|
||||
}
|
||||
tool_params
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
|
||||
.map(|(key, value)| match value {
|
||||
Value::Number(n) => (key.clone(), n.to_string()),
|
||||
Value::String(s) => (key.clone(), s.clone()),
|
||||
Value::Bool(b) => (key.clone(), b.to_string()),
|
||||
Value::Null => todo!(),
|
||||
Value::Sequence(_) => todo!(),
|
||||
Value::Mapping(_) => todo!(),
|
||||
Value::Tagged(_) => todo!(),
|
||||
})
|
||||
.collect::<HashMap<String, String>>()
|
||||
}
|
||||
|
||||
pub fn compute_request_path_body(
|
||||
endpoint_path: &str,
|
||||
tool_params: &Option<HashMap<String, Value>>,
|
||||
prompt_target_params: &[Parameter],
|
||||
http_method: &HttpMethod,
|
||||
) -> Result<(String, Option<String>), String> {
|
||||
let tool_url_params = filter_tool_params(tool_params);
|
||||
let (path_with_params, query_string, additional_params) = common::path::replace_params_in_path(
|
||||
endpoint_path,
|
||||
&tool_url_params,
|
||||
prompt_target_params,
|
||||
)?;
|
||||
|
||||
let (path, body) = match http_method {
|
||||
HttpMethod::Get => (format!("{}?{}", path_with_params, query_string), None),
|
||||
HttpMethod::Post => {
|
||||
let mut additional_params = additional_params;
|
||||
if !query_string.is_empty() {
|
||||
query_string.split("&").for_each(|param| {
|
||||
let mut parts = param.split("=");
|
||||
let key = parts.next().unwrap();
|
||||
let value = parts.next().unwrap();
|
||||
additional_params.insert(key.to_string(), value.to_string());
|
||||
});
|
||||
}
|
||||
let body = serde_json::to_string(&additional_params).unwrap();
|
||||
(path_with_params, Some(body))
|
||||
}
|
||||
};
|
||||
|
||||
Ok((path, body))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use common::configuration::{HttpMethod, Parameter};
|
||||
|
||||
#[test]
|
||||
fn test_compute_request_path_body() {
|
||||
let endpoint_path = "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}";
|
||||
let tool_params = serde_yaml::from_str(
|
||||
r#"
|
||||
cluster_name: test1
|
||||
hello: hello world
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
let prompt_target_params = vec![Parameter {
|
||||
name: "country".to_string(),
|
||||
parameter_type: None,
|
||||
description: "test target".to_string(),
|
||||
required: None,
|
||||
enum_values: None,
|
||||
default: Some("US".to_string()),
|
||||
in_path: None,
|
||||
format: None,
|
||||
}];
|
||||
let http_method = HttpMethod::Get;
|
||||
let (path, body) = super::compute_request_path_body(
|
||||
endpoint_path,
|
||||
&tool_params,
|
||||
&prompt_target_params,
|
||||
&http_method,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
path,
|
||||
"/cluster.open-cluster-management.io/v1/managedclusters/test1?hello=hello%20world&country=US"
|
||||
);
|
||||
assert_eq!(body, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_request_path_body_empty_params() {
|
||||
let endpoint_path = "/cluster.open-cluster-management.io/v1/managedclusters/";
|
||||
let tool_params = serde_yaml::from_str(r#"{}"#).unwrap();
|
||||
let prompt_target_params = vec![Parameter {
|
||||
name: "country".to_string(),
|
||||
parameter_type: None,
|
||||
description: "test target".to_string(),
|
||||
required: None,
|
||||
enum_values: None,
|
||||
default: Some("US".to_string()),
|
||||
in_path: None,
|
||||
format: None,
|
||||
}];
|
||||
let http_method = HttpMethod::Get;
|
||||
let (path, body) = super::compute_request_path_body(
|
||||
endpoint_path,
|
||||
&tool_params,
|
||||
&prompt_target_params,
|
||||
&http_method,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
path,
|
||||
"/cluster.open-cluster-management.io/v1/managedclusters/?country=US"
|
||||
);
|
||||
assert_eq!(body, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_request_path_body_override_default_val() {
|
||||
let endpoint_path = "/cluster.open-cluster-management.io/v1/managedclusters/";
|
||||
let tool_params = serde_yaml::from_str(
|
||||
r#"
|
||||
country: UK
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
let prompt_target_params = vec![Parameter {
|
||||
name: "country".to_string(),
|
||||
parameter_type: None,
|
||||
description: "test target".to_string(),
|
||||
required: None,
|
||||
enum_values: None,
|
||||
default: Some("US".to_string()),
|
||||
in_path: None,
|
||||
format: None,
|
||||
}];
|
||||
let http_method = HttpMethod::Get;
|
||||
let (path, body) = super::compute_request_path_body(
|
||||
endpoint_path,
|
||||
&tool_params,
|
||||
&prompt_target_params,
|
||||
&http_method,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
path,
|
||||
"/cluster.open-cluster-management.io/v1/managedclusters/?country=UK"
|
||||
);
|
||||
assert_eq!(body, None);
|
||||
}
|
||||
}
|
||||
690
crates/agent_gateway/tests/integration.rs
Normal file
690
crates/agent_gateway/tests/integration.rs
Normal file
|
|
@ -0,0 +1,690 @@
|
|||
use common::api::open_ai::{
|
||||
ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage,
|
||||
};
|
||||
use common::configuration::Configuration;
|
||||
use http::StatusCode;
|
||||
use proxy_wasm_test_framework::tester::{self, Tester};
|
||||
use proxy_wasm_test_framework::types::{
|
||||
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
|
||||
};
|
||||
use serde_yaml::Value;
|
||||
use serial_test::serial;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
fn wasm_module() -> String {
|
||||
let wasm_file = Path::new("../target/wasm32-wasip1/release/prompt_gateway.wasm");
|
||||
assert!(
|
||||
wasm_file.exists(),
|
||||
"Run `cargo build --release --target=wasm32-wasip1` first"
|
||||
);
|
||||
wasm_file.to_str().unwrap().to_string()
|
||||
}
|
||||
|
||||
fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
||||
module
|
||||
.call_proxy_on_request_headers(http_context, 0, false)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
||||
.returning(Some("/v1/chat/completions"))
|
||||
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
|
||||
.returning(None)
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
|
||||
.returning(None)
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent"))
|
||||
.returning(None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, filter_context)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
request_headers_expectations(module, http_context);
|
||||
|
||||
// Request Body
|
||||
let chat_completions_request_body = "\
|
||||
{\
|
||||
\"messages\": [\
|
||||
{\
|
||||
\"role\": \"system\",\
|
||||
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
|
||||
},\
|
||||
{\
|
||||
\"role\": \"user\",\
|
||||
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||
}\
|
||||
],\
|
||||
\"model\": \"gpt-4\"\
|
||||
}";
|
||||
|
||||
module
|
||||
.call_proxy_on_request_body(
|
||||
http_context,
|
||||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/function_calling"),
|
||||
("content-type", "application/json"),
|
||||
(":authority", "model_server"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "30000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
Some(5000),
|
||||
)
|
||||
.returning(Some(1))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn setup_filter(module: &mut Tester, config: &str) -> i32 {
|
||||
let filter_context = 1;
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(filter_context, 0)
|
||||
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_configure(filter_context, config.len() as i32)
|
||||
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
|
||||
.returning(Some(config))
|
||||
.execute_and_expect(ReturnType::Bool(true))
|
||||
.unwrap();
|
||||
|
||||
filter_context
|
||||
}
|
||||
|
||||
fn default_config() -> &'static str {
|
||||
r#"
|
||||
version: "0.1-beta"
|
||||
|
||||
listener:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: huggingface
|
||||
connect_timeout: 0.005s
|
||||
|
||||
endpoints:
|
||||
api_server:
|
||||
endpoint: api_server:80
|
||||
connect_timeout: 0.005s
|
||||
|
||||
llm_providers:
|
||||
- name: open-ai-gpt-4
|
||||
provider_interface: openai
|
||||
access_key: secret_key
|
||||
model: gpt-4
|
||||
default: true
|
||||
|
||||
overrides:
|
||||
# confidence threshold for prompt target intent matching
|
||||
prompt_target_intent_matching_threshold: 0.0
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
|
||||
prompt_guards:
|
||||
input_guards:
|
||||
jailbreak:
|
||||
on_exception:
|
||||
message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
|
||||
|
||||
prompt_targets:
|
||||
- name: weather_forecast
|
||||
description: This function provides realtime weather forecast information for a given city.
|
||||
parameters:
|
||||
- name: city
|
||||
required: true
|
||||
description: The city for which the weather forecast is requested.
|
||||
- name: days
|
||||
description: The number of days for which the weather forecast is requested.
|
||||
- name: units
|
||||
description: The units in which the weather forecast is requested.
|
||||
endpoint:
|
||||
name: api_server
|
||||
path: /weather
|
||||
http_method: POST
|
||||
system_prompt: |
|
||||
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
- Use miles per hour for wind speed
|
||||
|
||||
ratelimits:
|
||||
- model: gpt-4
|
||||
selector:
|
||||
key: selector-key
|
||||
value: selector-value
|
||||
limit:
|
||||
tokens: 1
|
||||
unit: minute
|
||||
"#
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn prompt_gateway_successful_request_to_open_ai_chat_completions() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let filter_context = setup_filter(&mut module, default_config());
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, filter_context)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
request_headers_expectations(&mut module, http_context);
|
||||
|
||||
// Request Body
|
||||
let chat_completions_request_body = "\
|
||||
{\
|
||||
\"messages\": [\
|
||||
{\
|
||||
\"role\": \"system\",\
|
||||
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
|
||||
},\
|
||||
{\
|
||||
\"role\": \"user\",\
|
||||
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||
}\
|
||||
],\
|
||||
\"model\": \"gpt-4\"\
|
||||
}";
|
||||
|
||||
module
|
||||
.call_proxy_on_request_body(
|
||||
http_context,
|
||||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("arch_internal"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn prompt_gateway_bad_request_to_open_ai_chat_completions() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let filter_context = setup_filter(&mut module, default_config());
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, filter_context)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
request_headers_expectations(&mut module, http_context);
|
||||
|
||||
// Request Body
|
||||
let incomplete_chat_completions_request_body = "\
|
||||
{\
|
||||
\"messages\": [\
|
||||
{\
|
||||
\"role\": \"system\",\
|
||||
},\
|
||||
{\
|
||||
\"role\": \"user\",\
|
||||
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||
}\
|
||||
]\
|
||||
}";
|
||||
|
||||
module
|
||||
.call_proxy_on_request_body(
|
||||
http_context,
|
||||
incomplete_chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(incomplete_chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_send_local_response(
|
||||
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn prompt_gateway_request_to_llm_gateway() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap();
|
||||
config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
|
||||
let config_str = serde_json::to_string(&config).unwrap();
|
||||
|
||||
let filter_context = setup_filter(&mut module, &config_str);
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let arch_fc_resp = ChatCompletionsResponse {
|
||||
usage: Some(Usage {
|
||||
completion_tokens: 0,
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: Some("test".to_string()),
|
||||
index: Some(0),
|
||||
message: Message {
|
||||
role: "system".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: String::from("test"),
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: Some(HashMap::from([(
|
||||
String::from("city"),
|
||||
Value::String(String::from("seattle")),
|
||||
)])),
|
||||
},
|
||||
}]),
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
}],
|
||||
model: String::from("test"),
|
||||
metadata: {
|
||||
let mut map: HashMap<String, String> = HashMap::new();
|
||||
map.insert("function_latency".to_string(), "0.0".to_string());
|
||||
Some(map)
|
||||
},
|
||||
};
|
||||
|
||||
let expected_body = "{\"city\":\"seattle\"}";
|
||||
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&arch_fc_resp_str))
|
||||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-arch-upstream", "api_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "30000"),
|
||||
(":path", "/weather"),
|
||||
(":method", "POST"),
|
||||
(":authority", "api_server"),
|
||||
]),
|
||||
Some(expected_body),
|
||||
None,
|
||||
Some(5000),
|
||||
)
|
||||
.returning(Some(2))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 2, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
|
||||
.returning(Some("200"))
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let chat_completion_response = ChatCompletionsResponse {
|
||||
usage: Some(Usage {
|
||||
completion_tokens: 0,
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: Some("test".to_string()),
|
||||
index: Some(0),
|
||||
message: Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("hello from fake llm gateway".to_string()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
}],
|
||||
model: String::from("test"),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let chat_completion_response_str = serde_json::to_string(&chat_completion_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_response_body(
|
||||
http_context,
|
||||
chat_completion_response_str.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpResponseBody))
|
||||
.returning(Some(chat_completion_response_str.as_str()))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn prompt_gateway_request_no_intent_match() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap();
|
||||
config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
|
||||
let config_str = serde_json::to_string(&config).unwrap();
|
||||
|
||||
let filter_context = setup_filter(&mut module, &config_str);
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let arch_fc_resp = ChatCompletionsResponse {
|
||||
usage: Some(Usage {
|
||||
completion_tokens: 0,
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: Some("test".to_string()),
|
||||
index: Some(0),
|
||||
message: Message {
|
||||
role: "assistant".to_string(),
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
}],
|
||||
model: String::from("test"),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&arch_fc_resp_str))
|
||||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), Some("intent matched: false"))
|
||||
.expect_log(
|
||||
Some(LogLevel::Info),
|
||||
Some("no default prompt target found, forwarding request to upstream llm"),
|
||||
)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn arch_config_default_target() -> &'static str {
|
||||
r#"
|
||||
version: "0.1-beta"
|
||||
|
||||
listener:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: huggingface
|
||||
connect_timeout: 0.005s
|
||||
|
||||
endpoints:
|
||||
api_server:
|
||||
endpoint: api_server:80
|
||||
connect_timeout: 0.005s
|
||||
|
||||
llm_providers:
|
||||
- name: open-ai-gpt-4
|
||||
provider_interface: openai
|
||||
access_key: secret_key
|
||||
model: gpt-4
|
||||
default: true
|
||||
|
||||
overrides:
|
||||
# confidence threshold for prompt target intent matching
|
||||
prompt_target_intent_matching_threshold: 0.0
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
|
||||
prompt_guards:
|
||||
input_guards:
|
||||
jailbreak:
|
||||
on_exception:
|
||||
message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
|
||||
|
||||
prompt_targets:
|
||||
- name: weather_forecast
|
||||
description: This function provides realtime weather forecast information for a given city.
|
||||
parameters:
|
||||
- name: city
|
||||
required: true
|
||||
description: The city for which the weather forecast is requested.
|
||||
- name: days
|
||||
description: The number of days for which the weather forecast is requested.
|
||||
- name: units
|
||||
description: The units in which the weather forecast is requested.
|
||||
endpoint:
|
||||
name: api_server
|
||||
path: /weather
|
||||
http_method: POST
|
||||
system_prompt: |
|
||||
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
- Use miles per hour for wind speed
|
||||
|
||||
- name: default_target
|
||||
default: true
|
||||
description: This is the default target for all unmatched prompts.
|
||||
endpoint:
|
||||
name: weather_forecast_service
|
||||
path: /default_target
|
||||
http_method: POST
|
||||
system_prompt: |
|
||||
You are a helpful assistant! Summarize the user's request and provide a helpful response.
|
||||
# if it is set to false arch will send response that it received from this prompt target to the user
|
||||
# if true arch will forward the response to the default LLM
|
||||
auto_llm_dispatch_on_response: false
|
||||
|
||||
ratelimits:
|
||||
- model: gpt-4
|
||||
selector:
|
||||
key: selector-key
|
||||
value: selector-value
|
||||
limit:
|
||||
tokens: 1
|
||||
unit: minute
|
||||
"#
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn prompt_gateway_request_no_intent_match_default_target() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let mut config: Configuration = serde_yaml::from_str(arch_config_default_target()).unwrap();
|
||||
config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
|
||||
let config_str = serde_json::to_string(&config).unwrap();
|
||||
|
||||
let filter_context = setup_filter(&mut module, &config_str);
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let arch_fc_resp = ChatCompletionsResponse {
|
||||
usage: Some(Usage {
|
||||
completion_tokens: 0,
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: Some("test".to_string()),
|
||||
index: Some(0),
|
||||
message: Message {
|
||||
role: "system".to_string(),
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
}],
|
||||
model: String::from("test"),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&arch_fc_resp_str))
|
||||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), Some("intent matched: false"))
|
||||
.expect_log(
|
||||
Some(LogLevel::Info),
|
||||
Some("default prompt target found, forwarding request to default prompt target"),
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
("x-arch-upstream", "weather_forecast_service"),
|
||||
(":path", "/default_target"),
|
||||
(":authority", "weather_forecast_service"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "30000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
Some(5000),
|
||||
)
|
||||
.returning(Some(2))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use serde_yaml::Value;
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
fmt::Display,
|
||||
|
|
@ -43,6 +42,8 @@ pub struct FunctionDefinition {
|
|||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct FunctionParameters {
|
||||
#[serde(rename = "type")]
|
||||
pub properties_type: String,
|
||||
pub properties: HashMap<String, FunctionParameter>,
|
||||
}
|
||||
|
||||
|
|
@ -51,7 +52,7 @@ impl Serialize for FunctionParameters {
|
|||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
// select all requried parameters
|
||||
// select all required parameters
|
||||
let required: Vec<&String> = self
|
||||
.properties
|
||||
.iter()
|
||||
|
|
@ -60,6 +61,7 @@ impl Serialize for FunctionParameters {
|
|||
.collect();
|
||||
let mut map = serializer.serialize_map(Some(2))?;
|
||||
map.serialize_entry("properties", &self.properties)?;
|
||||
map.serialize_entry("type", &self.properties_type)?;
|
||||
if !required.is_empty() {
|
||||
map.serialize_entry("required", &required)?;
|
||||
}
|
||||
|
|
@ -113,7 +115,7 @@ pub enum ParameterType {
|
|||
Float,
|
||||
#[serde(rename = "bool")]
|
||||
Bool,
|
||||
#[serde(rename = "str")]
|
||||
#[serde(rename = "string")]
|
||||
String,
|
||||
#[serde(rename = "list")]
|
||||
List,
|
||||
|
|
@ -189,7 +191,7 @@ pub struct ToolCall {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionCallDetail {
|
||||
pub name: String,
|
||||
pub arguments: Option<HashMap<String, Value>>,
|
||||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
|
|
|
|||
|
|
@ -19,6 +19,37 @@ pub struct Configuration {
|
|||
pub ratelimits: Option<Vec<Ratelimit>>,
|
||||
pub tracing: Option<Tracing>,
|
||||
pub mode: Option<GatewayMode>,
|
||||
pub agents: HashMap<String, Agent>,
|
||||
pub tools: HashMap<String, Tool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Agent {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub default_input_modes: Option<Vec<String>>,
|
||||
pub default_output_modes: Option<Vec<String>>,
|
||||
pub skills: Option<Vec<Skill>>,
|
||||
pub model: String,
|
||||
pub agent_orchestrator_prompt: Option<String>,
|
||||
pub system_prompt: Option<String>,
|
||||
pub tools: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Skill {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub examples: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Tool {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub endpoint: Option<EndpointDetails>,
|
||||
pub parameters: Option<Vec<Parameter>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
|
|
@ -260,7 +291,48 @@ impl From<&PromptTarget> for ChatCompletionTool {
|
|||
function: FunctionDefinition {
|
||||
name: val.name.clone(),
|
||||
description: val.description.clone(),
|
||||
parameters: FunctionParameters { properties },
|
||||
parameters: FunctionParameters {
|
||||
properties,
|
||||
properties_type: "object".to_string(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// convert Tool to ChatCompletionTool
|
||||
impl From<&Tool> for ChatCompletionTool {
|
||||
fn from(val: &Tool) -> Self {
|
||||
let properties: HashMap<String, FunctionParameter> = match val.parameters {
|
||||
Some(ref entities) => {
|
||||
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
|
||||
for entity in entities.iter() {
|
||||
let param = FunctionParameter {
|
||||
parameter_type: ParameterType::from(
|
||||
entity.parameter_type.clone().unwrap_or("str".to_string()),
|
||||
),
|
||||
description: entity.description.clone(),
|
||||
required: entity.required,
|
||||
enum_values: entity.enum_values.clone(),
|
||||
default: entity.default.clone(),
|
||||
format: entity.format.clone(),
|
||||
};
|
||||
properties.insert(entity.name.clone(), param);
|
||||
}
|
||||
properties
|
||||
}
|
||||
None => HashMap::new(),
|
||||
};
|
||||
|
||||
ChatCompletionTool {
|
||||
tool_type: crate::api::open_ai::ToolType::Function,
|
||||
function: FunctionDefinition {
|
||||
name: val.name.clone(),
|
||||
description: val.description.clone(),
|
||||
parameters: FunctionParameters {
|
||||
properties,
|
||||
properties_type: "object".to_string(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,8 @@ pub const MODEL_SERVER_NAME: &str = "model_server";
|
|||
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
||||
pub const MESSAGES_KEY: &str = "messages";
|
||||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
pub const CHAT_COMPLETIONS_PATH: [&str; 2] = ["/v1/chat/completions", "/openai/v1/chat/completions"];
|
||||
pub const CHAT_COMPLETIONS_PATH: [&str; 2] =
|
||||
["/v1/chat/completions", "/openai/v1/chat/completions"];
|
||||
pub const HEALTHZ_PATH: &str = "/healthz";
|
||||
pub const X_ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
pub const X_ARCH_API_RESPONSE: &str = "x-arch-api-response-message";
|
||||
|
|
|
|||
|
|
@ -371,7 +371,6 @@ impl StreamContext {
|
|||
|
||||
let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
|
||||
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
|
||||
let tool_params = &self.tool_calls.as_ref().unwrap()[0].function.arguments;
|
||||
let endpoint_details = prompt_target.endpoint.as_ref().unwrap();
|
||||
let endpoint_path: String = endpoint_details
|
||||
.path
|
||||
|
|
@ -382,9 +381,10 @@ impl StreamContext {
|
|||
let http_method = endpoint_details.method.clone().unwrap_or_default();
|
||||
let prompt_target_params = prompt_target.parameters.clone().unwrap_or_default();
|
||||
|
||||
//TODO: fixme: adilhafeez hack
|
||||
let (path, api_call_body) = match compute_request_path_body(
|
||||
&endpoint_path,
|
||||
tool_params,
|
||||
&None,
|
||||
&prompt_target_params,
|
||||
&http_method,
|
||||
) {
|
||||
|
|
@ -777,18 +777,19 @@ impl StreamContext {
|
|||
|
||||
fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool {
|
||||
let content = model_server_response
|
||||
.choices.first()
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref());
|
||||
|
||||
let content_has_value = content.is_some() && !content.unwrap().is_empty();
|
||||
|
||||
let tool_calls = model_server_response
|
||||
.choices.first()
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.tool_calls.as_ref());
|
||||
|
||||
// intent was matched if content has some value or tool_calls is empty
|
||||
|
||||
|
||||
content_has_value || (tool_calls.is_some() && !tool_calls.unwrap().is_empty())
|
||||
}
|
||||
|
||||
|
|
|
|||
76
demos/ai_agent/agent_config copy.yaml
Normal file
76
demos/ai_agent/agent_config copy.yaml
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
version: v0.1
|
||||
|
||||
|
||||
listeners:
|
||||
ingress_traffic:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
- name: gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider_interface: openai
|
||||
model: gpt-4o
|
||||
|
||||
endpoints:
|
||||
frankfurther_api:
|
||||
endpoint: api.frankfurter.dev
|
||||
protocol: https
|
||||
|
||||
tools:
|
||||
- name: get_exchange_rate
|
||||
endpoint:
|
||||
name: frankfurther_api
|
||||
path: /v1/latest?base={currency_from}&symbols={currency_to}
|
||||
params:
|
||||
- name: currency_from
|
||||
description: currency symbol to convert from
|
||||
type: str
|
||||
default: USD
|
||||
- name: currency_to
|
||||
description: currency symbol to convert to
|
||||
type: str
|
||||
default: EUR
|
||||
|
||||
prompt_guards:
|
||||
input_guards:
|
||||
jailbreak:
|
||||
on_exception:
|
||||
message: Looks like you're curious about my abilities, but I can only provide assistance for currency exchange.
|
||||
|
||||
agents:
|
||||
- name: Currency Exchange Agent
|
||||
description: Helps with exchange rates for currencies
|
||||
default_input_modes:
|
||||
- text
|
||||
- text/plain
|
||||
default_output_modes:
|
||||
- text
|
||||
- text/plain
|
||||
skills:
|
||||
- id: convert_currency
|
||||
name: Currency Exchange Rates Tool
|
||||
description: Helps with exchange values between various currencies
|
||||
examples:
|
||||
- What is exchange rate between USD and GBP?
|
||||
capabilities:
|
||||
streaming: true
|
||||
push_notifications: true
|
||||
model: gpt-4o
|
||||
system_prompt: |
|
||||
You are a specialized assistant for currency conversions.
|
||||
Your sole purpose is to use the 'get_exchange_rate' tool to answer questions about currency exchange rates.
|
||||
If the user asks about anything other than currency conversion or exchange rates,
|
||||
politely state that you cannot help with that topic and can only assist with currency-related queries.
|
||||
Do not attempt to answer unrelated questions or use tools for other purposes.
|
||||
Set response status to input_required if the user needs to provide more information.
|
||||
Set response status to error if there is an error while processing the request.
|
||||
Set response status to completed if the request is complete.
|
||||
tools:
|
||||
- get_exchange_rate
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
trace_arch_internal: true
|
||||
84
demos/ai_agent/agent_config.yaml
Normal file
84
demos/ai_agent/agent_config.yaml
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
version: v0.1
|
||||
|
||||
|
||||
listeners:
|
||||
ingress_traffic:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
- name: gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider_interface: openai
|
||||
model: gpt-4o
|
||||
|
||||
endpoints:
|
||||
frankfurther_api:
|
||||
endpoint: api.frankfurter.dev
|
||||
protocol: https
|
||||
|
||||
tools:
|
||||
get_exchange_rate:
|
||||
name: get_exchange_rate
|
||||
description: Get the latest exchange rate for a given currency pair
|
||||
endpoint:
|
||||
name: frankfurther_api
|
||||
path: /v1/latest?base={currency_from}&symbols={currency_to}
|
||||
parameters:
|
||||
- name: currency_from
|
||||
description: currency symbol to convert from
|
||||
type: string
|
||||
default: USD
|
||||
- name: currency_to
|
||||
description: currency symbol to convert to
|
||||
type: string
|
||||
default: EUR
|
||||
get_list_of_supported_currencies:
|
||||
name: get_list_of_supported_currencies
|
||||
description: Get the list of supported currencies
|
||||
endpoint:
|
||||
name: frankfurther_api
|
||||
path: /v1/currencies
|
||||
|
||||
agents:
|
||||
currency_exchange_agent:
|
||||
name: Currency Exchange Agent
|
||||
description: Helps with exchange rates for currencies
|
||||
default_input_modes:
|
||||
- text
|
||||
- text/plain
|
||||
default_output_modes:
|
||||
- text
|
||||
- text/plain
|
||||
skills:
|
||||
- id: convert_currency
|
||||
name: Currency Exchange Rates Tool
|
||||
description: Helps with exchange values between various currencies
|
||||
examples:
|
||||
- What is exchange rate between USD and GBP?
|
||||
capabilities:
|
||||
streaming: true
|
||||
push_notifications: true
|
||||
agent_orchestrator_model: gpt-4o
|
||||
agent_orchestrator_prompt: |
|
||||
You are a specialized assistant for currency conversions.
|
||||
Your sole purpose is to use the 'get_exchange_rate' tool to answer questions about currency exchange rates.
|
||||
If the user asks about anything other than currency conversion or exchange rates,
|
||||
politely state that you cannot help with that topic and can only assist with currency-related queries.
|
||||
Do not attempt to answer unrelated questions or use tools for other purposes.
|
||||
Set response status to input_required if the user needs to provide more information.
|
||||
Set response status to error if there is an error while processing the request.
|
||||
Set response status to completed if the request is complete.
|
||||
tools:
|
||||
- get_exchange_rate
|
||||
- get_list_of_supported_currencies
|
||||
model: gpt-4o
|
||||
system_prompt: |
|
||||
You are a specialized currency exchange assistant.
|
||||
Your task is to provide the user with the exchange rate between two currencies.
|
||||
Keep the response concise and relevant to the user's query.
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
trace_arch_internal: true
|
||||
58
demos/ai_agent/agent_config_mcp.yaml
Normal file
58
demos/ai_agent/agent_config_mcp.yaml
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
version: v0.1
|
||||
|
||||
listeners:
|
||||
ingress_traffic:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
- name: gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider_interface: openai
|
||||
model: gpt-4o
|
||||
|
||||
endpoints:
|
||||
|
||||
- name: frankfurther_api
|
||||
endpoint: api.frankfurter.dev
|
||||
protocol: https
|
||||
|
||||
- name: twelvedata_api
|
||||
endpoint: api.twelvedata.com
|
||||
protocol: https
|
||||
|
||||
mcp:
|
||||
|
||||
- name: get_currency_exchange_rate
|
||||
- name: get_list_of_supported_currencies
|
||||
- name: get_stock_quote
|
||||
|
||||
prompt_guards:
|
||||
input_guards:
|
||||
jailbreak:
|
||||
on_exception:
|
||||
message: Looks like you're curious about my abilities, but I can only provide assistance for currency exchange.
|
||||
|
||||
agents:
|
||||
- name: currency_exchange_agent
|
||||
description: Agent for handling currency exchange queries
|
||||
llm_provider: gpt-4o
|
||||
system_prompt: |
|
||||
You are a helpful assistant. Only respond to queries related to currency exchange. If there are any other questions, I can't help you.
|
||||
tools:
|
||||
- get_currency_exchange_rate
|
||||
- get_list_of_supported_currencies
|
||||
|
||||
- name: get_stock_quote_agent
|
||||
description: Agent for handling stock quote queries
|
||||
llm_provider: gpt-4o
|
||||
system_prompt: |
|
||||
You are a helpful stock exchange assistant. You are given stock symbol along with its exchange rate in json format. Your task is to parse the data and present it in a human-readable format. Keep the details to highlevel and be concise.
|
||||
tools:
|
||||
- get_stock_quote
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
trace_arch_internal: true
|
||||
20
demos/ai_agent/test.hurl
Normal file
20
demos/ai_agent/test.hurl
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
POST http://localhost:14000/v1/chat/completions
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! I'm here to assist you with any questions related to currency conversions and exchange rates. How can I help you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what is the exchange rate of USD to EUR?"
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue