From 299f183e667a33f891309e5e9cfc50bebdc7dc04 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Fri, 25 Apr 2025 00:57:13 -0700 Subject: [PATCH] Add prem support for a2a agents --- arch/Dockerfile | 6 +- arch/arch_config_schema.yaml | 4 + arch/envoy.template.yaml | 119 ++++ arch/tools/cli/docker_cli.py | 3 +- crates/Cargo.lock | 23 + crates/Cargo.toml | 2 +- crates/agent_gateway/Cargo.toml | 29 + crates/agent_gateway/src/context.rs | 66 ++ crates/agent_gateway/src/filter_context.rs | 121 ++++ crates/agent_gateway/src/http_context.rs | 440 +++++++++++++ crates/agent_gateway/src/lib.rs | 17 + crates/agent_gateway/src/metrics.rs | 14 + crates/agent_gateway/src/stream_context.rs | 528 +++++++++++++++ crates/agent_gateway/src/tools.rs | 162 +++++ crates/agent_gateway/tests/integration.rs | 690 ++++++++++++++++++++ crates/common/src/api/open_ai.rs | 10 +- crates/common/src/configuration.rs | 74 ++- crates/common/src/consts.rs | 3 +- crates/prompt_gateway/src/stream_context.rs | 11 +- demos/ai_agent/agent_config copy.yaml | 76 +++ demos/ai_agent/agent_config.yaml | 84 +++ demos/ai_agent/agent_config_mcp.yaml | 58 ++ demos/ai_agent/test.hurl | 20 + 23 files changed, 2544 insertions(+), 16 deletions(-) create mode 100644 crates/agent_gateway/Cargo.toml create mode 100644 crates/agent_gateway/src/context.rs create mode 100644 crates/agent_gateway/src/filter_context.rs create mode 100644 crates/agent_gateway/src/http_context.rs create mode 100644 crates/agent_gateway/src/lib.rs create mode 100644 crates/agent_gateway/src/metrics.rs create mode 100644 crates/agent_gateway/src/stream_context.rs create mode 100644 crates/agent_gateway/src/tools.rs create mode 100644 crates/agent_gateway/tests/integration.rs create mode 100644 demos/ai_agent/agent_config copy.yaml create mode 100644 demos/ai_agent/agent_config.yaml create mode 100644 demos/ai_agent/agent_config_mcp.yaml create mode 100644 demos/ai_agent/test.hurl diff --git a/arch/Dockerfile b/arch/Dockerfile index 7f933da5..f0854d37 100644 --- a/arch/Dockerfile +++ b/arch/Dockerfile @@ -4,8 +4,7 @@ RUN rustup -v target add wasm32-wasip1 WORKDIR /arch COPY crates . -RUN cd prompt_gateway && cargo build --release --target wasm32-wasip1 -RUN cd llm_gateway && cargo build --release --target wasm32-wasip1 +RUN cargo build --release --target wasm32-wasip1 # copy built filter into envoy image FROM docker.io/envoyproxy/envoy:v1.32-latest as envoy @@ -17,6 +16,7 @@ RUN apt-get update && apt-get install -y gettext-base curl && apt-get clean && r COPY --from=builder /arch/target/wasm32-wasip1/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm COPY --from=builder /arch/target/wasm32-wasip1/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm +COPY --from=builder /arch/target/wasm32-wasip1/release/agent_gateway.wasm /etc/envoy/proxy-wasm-plugins/agent_gateway.wasm COPY --from=envoy /usr/local/bin/envoy /usr/local/bin/envoy WORKDIR /app COPY arch/requirements.txt . @@ -29,4 +29,4 @@ RUN pip install requests RUN touch /var/log/envoy.log # ENTRYPOINT ["sh","-c", "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --log-level trace 2>&1 | tee /var/log/envoy.log"] -ENTRYPOINT ["sh","-c", "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:info 2>&1 | tee /var/log/envoy.log"] +ENTRYPOINT ["sh","-c", "python config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:debug 2>&1 | tee /var/log/envoy.log"] diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index 59276589..02014a45 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -3,6 +3,10 @@ type: object properties: version: type: string + tools: + type: object + agents: + type: object listeners: type: object additionalProperties: false diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml index cac17187..623d20de 100644 --- a/arch/envoy.template.yaml +++ b/arch/envoy.template.yaml @@ -452,6 +452,125 @@ static_resources: typed_config: "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router + - name: ingress_traffic_agent + address: + socket_address: + address: 0.0.0.0 + port_value: 14000 + traffic_direction: INBOUND + filter_chains: + - filters: + - name: envoy.filters.network.http_connection_manager + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager + {% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %} + generate_request_id: true + tracing: + provider: + name: envoy.tracers.opentelemetry + typed_config: + "@type": type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig + grpc_service: + envoy_grpc: + cluster_name: opentelemetry_collector + timeout: 0.250s + service_name: ingress_traffic_agent + random_sampling: + value: {{ arch_tracing.random_sampling }} + {% endif %} + stat_prefix: ingress_traffic_agent + codec_type: AUTO + scheme_header_transformation: + scheme_to_overwrite: https + access_log: + - name: envoy.access_loggers.file + typed_config: + "@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog + path: "/var/log/access_ingress_agent.log" + route_config: + name: local_routes + virtual_hosts: + - name: local_service + domains: + - "*" + routes: + {% for provider in arch_llm_providers %} + # if endpoint is set then use custom cluster for upstream llm + {% if provider.endpoint %} + {% set llm_cluster_name = provider.name %} + {% else %} + {% set llm_cluster_name = provider.provider_interface %} + {% endif %} + - match: + prefix: "/" + headers: + - name: "x-arch-llm-provider" + string_match: + exact: {{ llm_cluster_name }} + route: + auto_host_rewrite: true + cluster: {{ llm_cluster_name }} + timeout: 60s + {% endfor %} + http_filters: + - name: envoy.filters.http.compressor + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.compressor.v3.Compressor + compressor_library: + name: compress + typed_config: + "@type": type.googleapis.com/envoy.extensions.compression.gzip.compressor.v3.Gzip + memory_level: 3 + window_bits: 10 + - name: envoy.filters.http.wasm_agent + typed_config: + "@type": type.googleapis.com/udpa.type.v1.TypedStruct + type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm + value: + config: + name: "http_config" + root_id: agent_gateway + configuration: + "@type": "type.googleapis.com/google.protobuf.StringValue" + value: | + {{ arch_config | indent(32) }} + vm_config: + runtime: "envoy.wasm.runtime.v8" + code: + local: + filename: "/etc/envoy/proxy-wasm-plugins/agent_gateway.wasm" + - name: envoy.filters.http.wasm_llm + typed_config: + "@type": type.googleapis.com/udpa.type.v1.TypedStruct + type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm + value: + config: + name: "http_config" + root_id: llm_gateway + configuration: + "@type": "type.googleapis.com/google.protobuf.StringValue" + value: | + {{ arch_llm_config | indent(32) }} + vm_config: + runtime: "envoy.wasm.runtime.v8" + code: + local: + filename: "/etc/envoy/proxy-wasm-plugins/llm_gateway.wasm" + - name: envoy.filters.http.decompressor + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.decompressor.v3.Decompressor + decompressor_library: + name: decompress + typed_config: + "@type": "type.googleapis.com/envoy.extensions.compression.gzip.decompressor.v3.Gzip" + window_bits: 9 + chunk_size: 8192 + # If this ratio is set too low, then body data will not be decompressed completely. + max_inflate_ratio: 1000 + - name: envoy.filters.http.router + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router + clusters: - name: openai connect_timeout: 0.5s diff --git a/arch/tools/cli/docker_cli.py b/arch/tools/cli/docker_cli.py index 6edfb8dc..d4264b35 100644 --- a/arch/tools/cli/docker_cli.py +++ b/arch/tools/cli/docker_cli.py @@ -49,6 +49,7 @@ def docker_start_archgw_detached( port_mappings = [ f"{prompt_gateway_port}:{prompt_gateway_port}", f"{llm_gateway_port}:{llm_gateway_port}", + "14000:14000", "9901:19901", ] port_mappings_args = [item for port in port_mappings for item in ("-p", port)] @@ -56,7 +57,7 @@ def docker_start_archgw_detached( volume_mappings = [ f"{logs_path_abs}:/var/log:rw", f"{arch_config_file}:/app/arch_config.yaml:ro", - # "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro", + "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro", ] volume_mappings_args = [ item for volume in volume_mappings for item in ("-v", volume) diff --git a/crates/Cargo.lock b/crates/Cargo.lock index b585ef6e..313c2599 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -20,6 +20,29 @@ dependencies = [ "gimli", ] +[[package]] +name = "agent_gateway" +version = "0.1.0" +dependencies = [ + "acap", + "common", + "derivative", + "governor", + "http", + "log", + "md5", + "pretty_assertions", + "proxy-wasm", + "proxy-wasm-test-framework", + "rand", + "serde", + "serde_json", + "serde_yaml", + "serial_test", + "sha2", + "thiserror", +] + [[package]] name = "ahash" version = "0.3.8" diff --git a/crates/Cargo.toml b/crates/Cargo.toml index 3ba99280..c170f2b7 100644 --- a/crates/Cargo.toml +++ b/crates/Cargo.toml @@ -1,3 +1,3 @@ [workspace] resolver = "2" -members = ["llm_gateway", "prompt_gateway", "common"] +members = ["llm_gateway", "prompt_gateway", "agent_gateway", "common"] diff --git a/crates/agent_gateway/Cargo.toml b/crates/agent_gateway/Cargo.toml new file mode 100644 index 00000000..6af7a927 --- /dev/null +++ b/crates/agent_gateway/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "agent_gateway" +version = "0.1.0" +authors = ["Katanemo Inc "] +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +proxy-wasm = "0.2.1" +log = "0.4" +serde = { version = "1.0", features = ["derive"] } +serde_yaml = "0.9.34" +serde_json = "1.0" +md5 = "0.7.0" +common = { path = "../common" } +http = "1.1.0" +governor = { version = "0.6.3", default-features = false, features = ["no_std"]} +acap = "0.3.0" +rand = "0.8.5" +thiserror = "1.0.64" +derivative = "2.2.0" +sha2 = "0.10.8" + +[dev-dependencies] +proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" } +serial_test = "3.1.1" +pretty_assertions = "1.4.1" diff --git a/crates/agent_gateway/src/context.rs b/crates/agent_gateway/src/context.rs new file mode 100644 index 00000000..8a3d2c89 --- /dev/null +++ b/crates/agent_gateway/src/context.rs @@ -0,0 +1,66 @@ +use std::str::FromStr; + +use common::errors::ServerError; +use common::stats::IncrementingMetric; +use http::StatusCode; +use log::warn; +use proxy_wasm::traits::Context; + +use crate::stream_context::{ResponseHandlerType, StreamContext}; + +impl Context for StreamContext { + fn on_http_call_response( + &mut self, + token_id: u32, + _num_headers: usize, + body_size: usize, + _num_trailers: usize, + ) { + let callout_context = self + .callouts + .get_mut() + .remove(&token_id) + .expect("invalid token_id"); + self.metrics.active_http_calls.increment(-1); + + let body = self + .get_http_call_response_body(0, body_size) + .unwrap_or_default(); + + if let Some(http_status) = self.get_http_call_response_header(":status") { + match StatusCode::from_str(http_status.as_str()) { + Ok(status_code) => { + if !status_code.is_success() { + let server_error = ServerError::Upstream { + host: callout_context.upstream_cluster.unwrap(), + path: callout_context.upstream_cluster_path.unwrap(), + status: http_status.clone(), + body: String::from_utf8(body).unwrap(), + }; + warn!("received non 2xx code: {:?}", server_error); + return self.send_server_error( + server_error, + Some(StatusCode::from_str(http_status.as_str()).unwrap()), + ); + } + } + Err(_) => { + // invalid status code (status code non numeric) + return self.send_server_error( + ServerError::LogicError(format!("invalid status code: {}", http_status)), + Some(StatusCode::from_str(http_status.as_str()).unwrap()), + ); + } + } + } else { + // :status header not found + warn!("missing :status header"); + } + + #[cfg_attr(any(), rustfmt::skip)] + match callout_context.response_handler_type { + ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context), + ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context), + } + } +} diff --git a/crates/agent_gateway/src/filter_context.rs b/crates/agent_gateway/src/filter_context.rs new file mode 100644 index 00000000..7320584a --- /dev/null +++ b/crates/agent_gateway/src/filter_context.rs @@ -0,0 +1,121 @@ +use crate::metrics::Metrics; +use crate::stream_context::StreamContext; +use common::configuration::{ + Agent, Configuration, Endpoint, Overrides, PromptGuards, PromptTarget, Tool, Tracing, +}; +use common::http::Client; +use common::stats::Gauge; +use log::trace; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; +use std::cell::RefCell; +use std::collections::HashMap; +use std::rc::Rc; + +#[derive(Debug)] +pub struct FilterCallContext {} + +#[derive(Debug)] +pub struct FilterContext { + metrics: Rc, + // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. + callouts: RefCell>, + overrides: Rc>, + system_prompt: Rc>, + prompt_targets: Rc>, + agents: Rc>, + tools: Rc>, + endpoints: Rc>>, + prompt_guards: Rc, + tracing: Rc>, +} + +impl FilterContext { + pub fn new() -> FilterContext { + FilterContext { + callouts: RefCell::new(HashMap::new()), + metrics: Rc::new(Metrics::new()), + system_prompt: Rc::new(None), + prompt_targets: Rc::new(HashMap::new()), + overrides: Rc::new(None), + prompt_guards: Rc::new(PromptGuards::default()), + endpoints: Rc::new(None), + tracing: Rc::new(None), + agents: Rc::new(HashMap::new()), + tools: Rc::new(HashMap::new()), + } + } +} + +impl Client for FilterContext { + type CallContext = FilterCallContext; + + fn callouts(&self) -> &RefCell> { + &self.callouts + } + + fn active_http_calls(&self) -> &Gauge { + &self.metrics.active_http_calls + } +} + +impl Context for FilterContext {} + +// RootContext allows the Rust code to reach into the Envoy Config +impl RootContext for FilterContext { + fn on_configure(&mut self, _: usize) -> bool { + let config_bytes = self + .get_plugin_configuration() + .expect("Arch config cannot be empty"); + + let config: Configuration = match serde_yaml::from_slice(&config_bytes) { + Ok(config) => config, + Err(err) => panic!("Invalid arch config \"{:?}\"", err), + }; + + self.overrides = Rc::new(config.overrides); + + let mut prompt_targets = HashMap::new(); + for pt in config.prompt_targets.unwrap_or_default() { + prompt_targets.insert(pt.name.clone(), pt.clone()); + } + self.system_prompt = Rc::new(config.system_prompt); + self.prompt_targets = Rc::new(prompt_targets); + self.endpoints = Rc::new(config.endpoints); + self.agents = Rc::new(config.agents); + self.tools = Rc::new(config.tools); + + if let Some(prompt_guards) = config.prompt_guards { + self.prompt_guards = Rc::new(prompt_guards) + } + + self.tracing = Rc::new(config.tracing); + + true + } + + fn create_http_context(&self, context_id: u32) -> Option> { + trace!( + "||| create_http_context called with context_id: {:?} |||", + context_id + ); + + Some(Box::new(StreamContext::new( + context_id, + Rc::clone(&self.metrics), + Rc::clone(&self.endpoints), + Rc::clone(&self.overrides), + Rc::clone(&self.tracing), + Rc::clone(&self.agents), + Rc::clone(&self.tools), + ))) + } + + fn get_type(&self) -> Option { + Some(ContextType::HttpContext) + } + + fn on_vm_start(&mut self, _: usize) -> bool { + true + } +} diff --git a/crates/agent_gateway/src/http_context.rs b/crates/agent_gateway/src/http_context.rs new file mode 100644 index 00000000..8d122a90 --- /dev/null +++ b/crates/agent_gateway/src/http_context.rs @@ -0,0 +1,440 @@ +use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext}; +use common::{ + api::open_ai::{self, ArchState, ChatCompletionTool, ChatCompletionsRequest, Message}, + consts::{ + ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_UPSTREAM_HOST_HEADER, + CHAT_COMPLETIONS_PATH, HEALTHZ_PATH, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER, + SYSTEM_ROLE, TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_STATE_HEADER, + }, + errors::ServerError, + http::{CallArgs, Client}, + pii::obfuscate_auth_header, +}; +use http::StatusCode; +use log::{debug, info, warn}; +use proxy_wasm::{traits::HttpContext, types::Action}; +use serde_json::Value; +use std::{ + collections::HashMap, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; + +// HttpContext is the trait that allows the Rust code to interact with HTTP objects. +impl HttpContext for StreamContext { + // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto + // the lifecycle of the http request and response. + fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { + // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it. + // Server's generally throw away requests whose body length do not match the Content-Length header. + // However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could + // manipulate the body in benign ways e.g., compression. + self.set_http_request_header("content-length", None); + + if let Some(overrides) = self.overrides.as_ref() { + if overrides.use_agent_orchestrator.unwrap_or_default() { + // get endpoint that has agent_orchestrator set to true + if let Some(endpoints) = self.endpoints.as_ref() { + if endpoints.len() == 1 { + let (name, _) = endpoints.iter().next().unwrap(); + info!("Setting ARCH_PROVIDER_HINT_HEADER to {}", name); + self.set_http_request_header(ARCH_ROUTING_HEADER, Some(name)); + } else { + warn!("Need single endpoint when use_agent_orchestrator is set"); + self.send_server_error( + ServerError::LogicError( + "Need single endpoint when use_agent_orchestrator is set" + .to_string(), + ), + None, + ); + } + } + } + } + + let request_path = self.get_http_request_header(":path").unwrap_or_default(); + if request_path == HEALTHZ_PATH { + self.send_http_response(200, vec![], None); + return Action::Continue; + } + + self.is_chat_completions_request = CHAT_COMPLETIONS_PATH.contains(&request_path.as_str()); + + // check if agent name is in the request header + // if not, check if there is only one agent in the config + // if so, use that agent + // if there are multiple agents in the config, return an error + if let Some(agent_header_value) = self.get_http_request_header("x-agent-name") { + if let Some(agent) = self.agents.as_ref().get(&agent_header_value) { + self.agent = Some(agent.clone()); + } else { + warn!("Agent not found in config"); + self.send_server_error( + ServerError::LogicError(format!( + "Agent {} not found in config", + agent_header_value + )), + None, + ); + return Action::Pause; + } + } else if self.agents.as_ref().len() == 1 { + let (name, agent) = self.agents.iter().next().unwrap(); + info!("Setting agent to {}", name); + self.agent = Some(agent.clone()); + } else { + warn!("Multiple agents found in config and no agent name in request header"); + self.send_http_response( + 400, + vec![], + Some( + "Multiple agents found in config and no agent name in request header" + .as_bytes(), + ), + ); + return Action::Pause; + } + + debug!( + "on_http_request_headers S[{}] req_headers={:?}", + self.context_id, + obfuscate_auth_header(&mut self.get_http_request_headers()) + ); + + self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); + self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER); + + Action::Continue + } + + fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { + // Let the client send the gateway all the data before sending to the LLM_provider. + // TODO: consider a streaming API. + + if !end_of_stream { + return Action::Pause; + } + + if body_size == 0 { + return Action::Continue; + } + + self.request_body_size = body_size; + + debug!( + "on_http_request_body S[{}] body_size={}", + self.context_id, body_size + ); + + let body_bytes = match self.get_http_request_body(0, body_size) { + Some(body_bytes) => body_bytes, + None => { + self.send_server_error( + ServerError::LogicError(format!( + "Failed to obtain body bytes even though body_size is {}", + body_size + )), + None, + ); + return Action::Pause; + } + }; + + debug!("request body: {}", String::from_utf8_lossy(&body_bytes)); + + // Deserialize body into spec. + // Currently OpenAI API. + let deserialized_body: ChatCompletionsRequest = match serde_json::from_slice(&body_bytes) { + Ok(deserialized) => deserialized, + Err(e) => { + self.send_server_error( + ServerError::Deserialization(e), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + }; + + self.arch_state = match deserialized_body.metadata { + Some(ref metadata) => { + if metadata.contains_key(X_ARCH_STATE_HEADER) { + let arch_state_str = metadata[X_ARCH_STATE_HEADER].clone(); + let arch_state: Vec = serde_json::from_str(&arch_state_str).unwrap(); + Some(arch_state) + } else { + None + } + } + None => None, + }; + + self.streaming_response = deserialized_body.stream; + + let last_user_prompt: &open_ai::Message = match deserialized_body + .messages + .iter() + .filter(|msg| msg.role == USER_ROLE) + .last() + { + Some(content) => content, + None => { + warn!("No messages in the request body"); + return Action::Continue; + } + }; + + self.user_prompt = Some(last_user_prompt.clone()); + + let mut tool_calls = Vec::new(); + if let Some(agent) = self.agent.as_ref() { + if let Some(tools) = agent.tools.as_ref() { + for tool in tools { + if let Some(tool) = self.tools.as_ref().get(tool) { + info!("tool: {:?}", tool); + let tool_chat_completion_tool: ChatCompletionTool = tool.into(); + info!("tool_chat_completion_tool: {:?}", tool_chat_completion_tool); + tool_calls.push(tool_chat_completion_tool); + } + } + } + } + + let mut metadata = deserialized_body.metadata.clone(); + + if let Some(overrides) = self.overrides.as_ref() { + if overrides.optimize_context_window.unwrap_or_default() { + if metadata.is_none() { + metadata = Some(HashMap::new()); + } + metadata + .as_mut() + .unwrap() + .insert("optimize_context_window".to_string(), "true".to_string()); + } + } + + let messages: Vec = match self.agent.as_ref().unwrap().agent_orchestrator_prompt { + Some(ref agent_orchestrator_prompt) => { + let mut messages = Vec::new(); + messages.push(Message { + role: SYSTEM_ROLE.to_string(), + content: Some(agent_orchestrator_prompt.clone()), + model: None, + tool_calls: None, + tool_call_id: None, + }); + messages.extend(deserialized_body.messages.clone()); + messages + } + None => deserialized_body.messages.clone(), + }; + + let arch_fc_chat_completion_request = ChatCompletionsRequest { + messages, + metadata, + //HACK: adilhafeez: enable streaming for agent orchestrator + stream: false, + model: deserialized_body.model.clone(), + stream_options: deserialized_body.stream_options.clone(), + tools: Some(tool_calls), + }; + + self.chat_completions_request = Some(deserialized_body); + + let json_data = match serde_json::to_string(&arch_fc_chat_completion_request) { + Ok(json_data) => json_data, + Err(error) => { + self.send_server_error(ServerError::Serialization(error), None); + return Action::Pause; + } + }; + + info!("on_http_request_body: sending request to model server"); + debug!("request body: {}", json_data); + + let timeout_str = MODEL_SERVER_REQUEST_TIMEOUT_MS.to_string(); + + let mut headers = vec![ + (ARCH_UPSTREAM_HOST_HEADER, "openai"), + (":method", "POST"), + (":path", "/v1/chat/completions"), + ("content-type", "application/json"), + (":authority", "openai"), + ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()), + ]; + + if self.request_id.is_some() { + headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); + } + + if self.traceparent.is_some() { + headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap())); + } + + let call_args = CallArgs::new( + "arch_listener_llm", + "/v1/chat/completions", + headers, + Some(json_data.as_bytes()), + vec![], + Duration::from_secs(5), + ); + + let call_context = StreamCallContext { + response_handler_type: ResponseHandlerType::ArchFC, + user_message: self.user_prompt.as_ref().unwrap().content.clone(), + prompt_target_name: None, + request_body: self.chat_completions_request.as_ref().unwrap().clone(), + similarity_scores: None, + upstream_cluster: Some(ARCH_INTERNAL_CLUSTER_NAME.to_string()), + upstream_cluster_path: Some("/function_calling".to_string()), + agent: self.agent.clone(), + }; + + if let Err(e) = self.http_call(call_args, call_context) { + warn!("http_call failed: {:?}", e); + self.send_server_error(ServerError::HttpDispatch(e), None); + } + + Action::Pause + } + + fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { + debug!( + "on_http_response_headers recv [S={}] headers={:?}", + self.context_id, + self.get_http_response_headers() + ); + // delete content-lenght header let envoy calculate it, because we modify the response body + // that would result in a different content-length + self.set_http_response_header("content-length", None); + Action::Continue + } + + fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { + debug!( + "on_http_response_body: recv [S={}] bytes={} end_stream={}", + self.context_id, body_size, end_of_stream + ); + + if !self.is_chat_completions_request { + info!("non-gpt request"); + return Action::Continue; + } + + if self.time_to_first_token.is_none() { + self.time_to_first_token = Some( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(), + ); + } + + if end_of_stream && body_size == 0 { + return Action::Continue; + } + + let body = if self.streaming_response { + let streaming_chunk = match self.get_http_response_body(0, body_size) { + Some(chunk) => chunk, + None => { + warn!( + "response body empty, chunk_start: {}, chunk_size: {}", + 0, body_size + ); + return Action::Continue; + } + }; + + if streaming_chunk.len() != body_size { + warn!( + "chunk size mismatch: read: {} != requested: {}", + streaming_chunk.len(), + body_size + ); + } + + streaming_chunk + } else { + info!("non streaming response bytes read: 0:{}", body_size); + match self.get_http_response_body(0, body_size) { + Some(body) => body, + None => { + warn!("non streaming response body empty"); + return Action::Continue; + } + } + }; + + let body_utf8 = match String::from_utf8(body) { + Ok(body_utf8) => body_utf8, + Err(e) => { + info!("could not convert to utf8: {}", e); + return Action::Continue; + } + }; + + if self.streaming_response { + debug!("streaming response"); + + if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() { + let chunks = vec![ + // ChatCompletionStreamResponse::new( + // self.arch_fc_response.clone(), + // Some(ASSISTANT_ROLE.to_string()), + // Some(ARCH_FC_MODEL_NAME.to_string()), + // None, + // ), + // ChatCompletionStreamResponse::new( + // self.tool_call_response.clone(), + // Some(TOOL_ROLE.to_string()), + // Some(ARCH_FC_MODEL_NAME.to_string()), + // None, + // ), + ]; + + let mut response_str = open_ai::to_server_events(chunks); + // append the original response from the model to the stream + response_str.push_str(&body_utf8); + self.set_http_response_body(0, body_size, response_str.as_bytes()); + self.tool_calls = None; + } + } else if let Some(tool_calls) = self.tool_calls.as_ref() { + if !tool_calls.is_empty() { + if self.arch_state.is_none() { + self.arch_state = Some(Vec::new()); + } + + let mut data = match serde_json::from_str(&body_utf8) { + Ok(data) => data, + Err(e) => { + warn!( + "could not deserialize response, sending data as it is: {}", + e + ); + return Action::Continue; + } + }; + // use serde::Value to manipulate the json object and ensure that we don't lose any data + if let Value::Object(ref mut map) = data { + // serialize arch state and add to metadata + let metadata = map + .entry("metadata") + .or_insert(Value::Object(serde_json::Map::new())); + if metadata == &Value::Null { + *metadata = Value::Object(serde_json::Map::new()); + } + + let data_serialized = serde_json::to_string(&data).unwrap(); + info!("archgw <= developer: {}", data_serialized); + self.set_http_response_body(0, body_size, data_serialized.as_bytes()); + }; + } + } + + debug!("recv [S={}] end_stream={}", self.context_id, end_of_stream); + + Action::Continue + } +} diff --git a/crates/agent_gateway/src/lib.rs b/crates/agent_gateway/src/lib.rs new file mode 100644 index 00000000..7e7a24f9 --- /dev/null +++ b/crates/agent_gateway/src/lib.rs @@ -0,0 +1,17 @@ +use filter_context::FilterContext; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; + +mod context; +mod filter_context; +mod http_context; +mod metrics; +mod stream_context; +mod tools; + +proxy_wasm::main! {{ + proxy_wasm::set_log_level(LogLevel::Trace); + proxy_wasm::set_root_context(|_| -> Box { + Box::new(FilterContext::new()) + }); +}} diff --git a/crates/agent_gateway/src/metrics.rs b/crates/agent_gateway/src/metrics.rs new file mode 100644 index 00000000..ff891636 --- /dev/null +++ b/crates/agent_gateway/src/metrics.rs @@ -0,0 +1,14 @@ +use common::stats::Gauge; + +#[derive(Copy, Clone, Debug)] +pub struct Metrics { + pub active_http_calls: Gauge, +} + +impl Metrics { + pub fn new() -> Metrics { + Metrics { + active_http_calls: Gauge::new(String::from("active_http_calls")), + } + } +} diff --git a/crates/agent_gateway/src/stream_context.rs b/crates/agent_gateway/src/stream_context.rs new file mode 100644 index 00000000..99d46f0f --- /dev/null +++ b/crates/agent_gateway/src/stream_context.rs @@ -0,0 +1,528 @@ +use crate::metrics::Metrics; +use crate::tools::compute_request_path_body; +use common::api::open_ai::{ + to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest, + ChatCompletionsResponse, Message, ToolCall, +}; +use common::configuration::{Agent, Endpoint, Overrides, Tool, Tracing}; +use common::consts::{ + API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, + ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, + TRACE_PARENT_HEADER, USER_ROLE, +}; +use common::errors::ServerError; +use common::http::{CallArgs, Client}; +use common::stats::Gauge; +use derivative::Derivative; +use http::StatusCode; +use log::{debug, info, warn}; +use proxy_wasm::traits::*; +use serde_yaml::Value; +use std::cell::RefCell; +use std::collections::HashMap; +use std::rc::Rc; +use std::str::FromStr; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +#[derive(Debug, Clone)] +pub enum ResponseHandlerType { + ArchFC, + FunctionCall, +} + +#[derive(Clone, Derivative)] +#[derivative(Debug)] +pub struct StreamCallContext { + pub response_handler_type: ResponseHandlerType, + pub user_message: Option, + pub prompt_target_name: Option, + #[derivative(Debug = "ignore")] + pub request_body: ChatCompletionsRequest, + pub similarity_scores: Option>, + pub upstream_cluster: Option, + pub upstream_cluster_path: Option, + pub agent: Option, +} + +pub struct StreamContext { + pub endpoints: Rc>>, + pub overrides: Rc>, + pub metrics: Rc, + pub callouts: RefCell>, + pub context_id: u32, + pub tool_calls: Option>, + pub tool_call_response: Option, + pub arch_state: Option>, + pub request_body_size: usize, + pub user_prompt: Option, + pub streaming_response: bool, + pub is_chat_completions_request: bool, + pub chat_completions_request: Option, + pub request_id: Option, + pub start_upstream_llm_request_time: u128, + pub time_to_first_token: Option, + pub traceparent: Option, + pub agents: Rc>, + pub agent: Option, + pub tools: Rc>, + pub _tracing: Rc>, + pub arch_fc_response: Option, +} + +impl StreamContext { + pub fn new( + context_id: u32, + metrics: Rc, + endpoints: Rc>>, + overrides: Rc>, + tracing: Rc>, + agents: Rc>, + tools: Rc>, + ) -> Self { + StreamContext { + context_id, + metrics, + endpoints, + callouts: RefCell::new(HashMap::new()), + chat_completions_request: None, + tool_calls: None, + tool_call_response: None, + arch_state: None, + request_body_size: 0, + streaming_response: false, + user_prompt: None, + is_chat_completions_request: false, + overrides, + request_id: None, + traceparent: None, + _tracing: tracing, + start_upstream_llm_request_time: 0, + time_to_first_token: None, + arch_fc_response: None, + agents, + tools, + agent: None, + } + } + + pub fn send_server_error(&self, error: ServerError, override_status_code: Option) { + self.send_http_response( + override_status_code + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) + .as_u16() + .into(), + vec![], + Some(format!("{error}").as_bytes()), + ); + } + + fn _trace_arch_internal(&self) -> bool { + match self._tracing.as_ref() { + Some(tracing) => match tracing.trace_arch_internal.as_ref() { + Some(trace_arch_internal) => *trace_arch_internal, + None => false, + }, + None => false, + } + } + + pub fn arch_fc_response_handler( + &mut self, + body: Vec, + mut callout_context: StreamCallContext, + ) { + let body_str = String::from_utf8(body).unwrap(); + info!("on_http_call_response: model server response received"); + debug!("response body: {}", body_str); + + let model_server_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) { + Ok(arch_fc_response) => arch_fc_response, + Err(e) => { + warn!( + "error deserializing llm response: {}, body: {}", + e, body_str + ); + return self.send_server_error(ServerError::Deserialization(e), None); + } + }; + + //TODO: try to avoid clone + let message = model_server_response + .choices + .first() + .map(|choice| choice.message.clone()) + .unwrap(); + + self.tool_calls = message.tool_calls; + + if self.tool_calls.as_ref().is_some() && self.tool_calls.as_ref().unwrap().len() > 1 { + warn!( + "multiple tool calls not supported yet, tool_calls count found: {}", + self.tool_calls.as_ref().unwrap().len() + ); + } + + if self.tool_calls.is_none() || self.tool_calls.as_ref().unwrap().is_empty() { + // this means llm model didn't need additional data from tool calls and is ready to respond back to user + + let direct_response_str = if self.streaming_response { + let chunks = vec![ + ChatCompletionStreamResponse::new( + self.arch_fc_response.clone(), + Some(ASSISTANT_ROLE.to_string()), + Some(ARCH_FC_MODEL_NAME.to_string()), + None, + ), + ChatCompletionStreamResponse::new( + Some( + model_server_response.choices[0] + .message + .content + .as_ref() + .unwrap() + .clone(), + ), + None, + Some(format!("{}-Chat", ARCH_FC_MODEL_NAME.to_owned())), + None, + ), + ]; + + to_server_events(chunks) + } else { + body_str + }; + + self.tool_calls = None; + return self.send_http_response( + StatusCode::OK.as_u16().into(), + vec![], + Some(direct_response_str.as_bytes()), + ); + } + + // update prompt target name from the tool call response + callout_context.prompt_target_name = + Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone()); + + self.schedule_api_call_request(callout_context); + } + + fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) { + // Construct messages early to avoid mutable borrow conflicts + + let tool_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone(); + let tool = self.tools.get(&tool_name).unwrap().clone(); + let tool_params = self.tool_calls.as_ref().unwrap()[0] + .function + .arguments + .clone(); + let endpoint_details = tool.endpoint.as_ref().unwrap(); + let endpoint_path: String = endpoint_details + .path + .as_ref() + .unwrap_or(&String::from("/")) + .to_string(); + + let http_method = endpoint_details.method.clone().unwrap_or_default(); + let prompt_target_params = tool.parameters.clone().unwrap_or_default(); + + let mut tool_params_json: Option> = None; + + if let Some(params) = tool_params.as_ref() { + match serde_json::from_str::>(params.as_str()) { + Ok(params_json) => tool_params_json = Some(params_json), + Err(e) => { + log::warn!( + "error deserializing tool params: {}, body str: {}", + e, + String::from_utf8(params.as_bytes().to_vec()).unwrap() + ); + return self.send_server_error( + ServerError::Deserialization(e), + Some(StatusCode::BAD_REQUEST), + ); + } + }; + } + + //TODO: fixme hack adilhafeez + let (path, api_call_body) = match compute_request_path_body( + &endpoint_path, + &tool_params_json, + &prompt_target_params, + &http_method, + ) { + Ok((path, body)) => (path, body), + Err(e) => { + return self.send_server_error( + ServerError::BadRequest { + why: format!("error computing api request path or body: {}", e), + }, + Some(StatusCode::BAD_REQUEST), + ); + } + }; + + debug!("on_http_call_response: api call body {:?}", api_call_body); + + let timeout_str = API_REQUEST_TIMEOUT_MS.to_string(); + + let http_method_str = http_method.to_string(); + let mut headers: HashMap<_, _> = [ + (ARCH_UPSTREAM_HOST_HEADER, endpoint_details.name.as_str()), + (":method", &http_method_str), + (":path", &path), + (":authority", endpoint_details.name.as_str()), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()), + ] + .into_iter() + .collect(); + + if self.request_id.is_some() { + headers.insert(REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()); + } + + if self.traceparent.is_some() { + headers.insert(TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()); + } + + // override http headers that are set in the prompt target + let http_headers = endpoint_details.http_headers.clone().unwrap_or_default(); + for (key, value) in http_headers.iter() { + headers.insert(key.as_str(), value.as_str()); + } + + let call_args = CallArgs::new( + ARCH_INTERNAL_CLUSTER_NAME, + &path, + headers.into_iter().collect(), + api_call_body.as_deref().map(|s| s.as_bytes()), + vec![], + Duration::from_secs(5), + ); + + info!( + "on_http_call_response: dispatching api call to developer endpoint: {}, path: {}, method: {}", + endpoint_details.name, path, http_method_str + ); + + callout_context.upstream_cluster = Some(endpoint_details.name.to_owned()); + callout_context.upstream_cluster_path = Some(path.to_owned()); + callout_context.agent = self.agent.clone(); + callout_context.response_handler_type = ResponseHandlerType::FunctionCall; + + if let Err(e) = self.http_call(call_args, callout_context) { + self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST)); + } + } + + pub fn api_call_response_handler(&mut self, body: Vec, callout_context: StreamCallContext) { + let http_status = self + .get_http_call_response_header(":status") + .unwrap_or(StatusCode::OK.as_str().to_string()); + info!( + "on_http_call_response: developer api call response received: status code: {}", + http_status + ); + if http_status != StatusCode::OK.as_str() { + warn!( + "api server responded with non 2xx status code: {}", + http_status + ); + return self.send_server_error( + ServerError::Upstream { + host: callout_context.upstream_cluster.unwrap(), + path: callout_context.upstream_cluster_path.unwrap(), + status: http_status.clone(), + body: String::from_utf8(body).unwrap(), + }, + Some(StatusCode::from_str(http_status.as_str()).unwrap()), + ); + } + self.tool_call_response = Some(String::from_utf8(body).unwrap()); + debug!( + "response body: {}", + self.tool_call_response.as_ref().unwrap() + ); + + let mut messages = self.construct_llm_messages(&callout_context); + + let user_message = match messages.pop() { + Some(user_message) => user_message, + None => { + return self.send_server_error( + ServerError::NoMessagesFound { + why: "no user messages found".to_string(), + }, + None, + ); + } + }; + + let final_prompt = format!( + "{}\ncontext: {}", + user_message.content.unwrap(), + self.tool_call_response.as_ref().unwrap() + ); + + // add original user prompt + messages.push({ + Message { + role: USER_ROLE.to_string(), + content: Some(final_prompt), + model: None, + tool_calls: None, + tool_call_id: None, + } + }); + + let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest { + model: callout_context.request_body.model, + messages, + tools: None, + stream: callout_context.request_body.stream, + stream_options: callout_context.request_body.stream_options, + metadata: None, + }; + + let llm_request_str = match serde_json::to_string(&chat_completions_request) { + Ok(json_string) => json_string, + Err(e) => { + return self.send_server_error(ServerError::Serialization(e), None); + } + }; + info!("on_http_call_response: sending request to upstream llm"); + debug!("request body: {}", llm_request_str); + + self.start_upstream_llm_request_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(); + + self.set_http_request_body(0, self.request_body_size, &llm_request_str.into_bytes()); + self.resume_http_request(); + } + + fn filter_out_arch_messages(&self, messages: &[Message]) -> Vec { + messages + .iter() + .filter(|m| { + !(m.role == TOOL_ROLE + || m.content.is_none() + || (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())) + }) + .cloned() + .collect() + } + + fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec { + let mut messages: Vec = Vec::new(); + + if let Some(agent) = callout_context.agent.as_ref() { + if let Some(system_prompt) = agent.system_prompt.as_ref() { + let system_prompt_message = Message { + role: SYSTEM_ROLE.to_string(), + content: Some(system_prompt.clone()), + model: None, + tool_calls: None, + tool_call_id: None, + }; + messages.push(system_prompt_message); + } + } + + messages.append( + &mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()), + ); + messages + } +} + +impl Client for StreamContext { + type CallContext = StreamCallContext; + + fn callouts(&self) -> &RefCell> { + &self.callouts + } + + fn active_http_calls(&self) -> &Gauge { + &self.metrics.active_http_calls + } +} + +#[cfg(test)] +mod test { + use common::api::open_ai::{ChatCompletionsResponse, Choice, Message, ToolCall}; + + use crate::stream_context::check_intent_matched; + + #[test] + fn test_intent_matched() { + let model_server_response = ChatCompletionsResponse { + choices: vec![Choice { + message: Message { + content: Some("".to_string()), + tool_calls: Some(vec![]), + role: "assistant".to_string(), + model: None, + tool_call_id: None, + }, + finish_reason: None, + index: None, + }], + usage: None, + model: "arch-fc".to_string(), + metadata: None, + }; + + assert!(!check_intent_matched(&model_server_response)); + + let model_server_response = ChatCompletionsResponse { + choices: vec![Choice { + message: Message { + content: Some("hello".to_string()), + tool_calls: Some(vec![]), + role: "assistant".to_string(), + model: None, + tool_call_id: None, + }, + finish_reason: None, + index: None, + }], + usage: None, + model: "arch-fc".to_string(), + metadata: None, + }; + + assert!(check_intent_matched(&model_server_response)); + + let model_server_response = ChatCompletionsResponse { + choices: vec![Choice { + message: Message { + content: Some("".to_string()), + tool_calls: Some(vec![ToolCall { + id: "1".to_string(), + function: common::api::open_ai::FunctionCallDetail { + name: "test".to_string(), + arguments: None, + }, + tool_type: common::api::open_ai::ToolType::Function, + }]), + role: "assistant".to_string(), + model: None, + tool_call_id: None, + }, + finish_reason: None, + index: None, + }], + usage: None, + model: "arch-fc".to_string(), + metadata: None, + }; + + assert!(check_intent_matched(&model_server_response)); + } +} diff --git a/crates/agent_gateway/src/tools.rs b/crates/agent_gateway/src/tools.rs new file mode 100644 index 00000000..c909a2dd --- /dev/null +++ b/crates/agent_gateway/src/tools.rs @@ -0,0 +1,162 @@ +use common::configuration::{HttpMethod, Parameter}; +use std::collections::HashMap; + +use serde_yaml::Value; + +// only add params that are of string, number and bool type +pub fn filter_tool_params(tool_params: &Option>) -> HashMap { + if tool_params.is_none() { + return HashMap::new(); + } + tool_params + .as_ref() + .unwrap() + .iter() + .filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool()) + .map(|(key, value)| match value { + Value::Number(n) => (key.clone(), n.to_string()), + Value::String(s) => (key.clone(), s.clone()), + Value::Bool(b) => (key.clone(), b.to_string()), + Value::Null => todo!(), + Value::Sequence(_) => todo!(), + Value::Mapping(_) => todo!(), + Value::Tagged(_) => todo!(), + }) + .collect::>() +} + +pub fn compute_request_path_body( + endpoint_path: &str, + tool_params: &Option>, + prompt_target_params: &[Parameter], + http_method: &HttpMethod, +) -> Result<(String, Option), String> { + let tool_url_params = filter_tool_params(tool_params); + let (path_with_params, query_string, additional_params) = common::path::replace_params_in_path( + endpoint_path, + &tool_url_params, + prompt_target_params, + )?; + + let (path, body) = match http_method { + HttpMethod::Get => (format!("{}?{}", path_with_params, query_string), None), + HttpMethod::Post => { + let mut additional_params = additional_params; + if !query_string.is_empty() { + query_string.split("&").for_each(|param| { + let mut parts = param.split("="); + let key = parts.next().unwrap(); + let value = parts.next().unwrap(); + additional_params.insert(key.to_string(), value.to_string()); + }); + } + let body = serde_json::to_string(&additional_params).unwrap(); + (path_with_params, Some(body)) + } + }; + + Ok((path, body)) +} + +#[cfg(test)] +mod test { + use common::configuration::{HttpMethod, Parameter}; + + #[test] + fn test_compute_request_path_body() { + let endpoint_path = "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}"; + let tool_params = serde_yaml::from_str( + r#" + cluster_name: test1 + hello: hello world + "#, + ) + .unwrap(); + let prompt_target_params = vec![Parameter { + name: "country".to_string(), + parameter_type: None, + description: "test target".to_string(), + required: None, + enum_values: None, + default: Some("US".to_string()), + in_path: None, + format: None, + }]; + let http_method = HttpMethod::Get; + let (path, body) = super::compute_request_path_body( + endpoint_path, + &tool_params, + &prompt_target_params, + &http_method, + ) + .unwrap(); + assert_eq!( + path, + "/cluster.open-cluster-management.io/v1/managedclusters/test1?hello=hello%20world&country=US" + ); + assert_eq!(body, None); + } + + #[test] + fn test_compute_request_path_body_empty_params() { + let endpoint_path = "/cluster.open-cluster-management.io/v1/managedclusters/"; + let tool_params = serde_yaml::from_str(r#"{}"#).unwrap(); + let prompt_target_params = vec![Parameter { + name: "country".to_string(), + parameter_type: None, + description: "test target".to_string(), + required: None, + enum_values: None, + default: Some("US".to_string()), + in_path: None, + format: None, + }]; + let http_method = HttpMethod::Get; + let (path, body) = super::compute_request_path_body( + endpoint_path, + &tool_params, + &prompt_target_params, + &http_method, + ) + .unwrap(); + assert_eq!( + path, + "/cluster.open-cluster-management.io/v1/managedclusters/?country=US" + ); + assert_eq!(body, None); + } + + #[test] + fn test_compute_request_path_body_override_default_val() { + let endpoint_path = "/cluster.open-cluster-management.io/v1/managedclusters/"; + let tool_params = serde_yaml::from_str( + r#" + country: UK + "#, + ) + .unwrap(); + let prompt_target_params = vec![Parameter { + name: "country".to_string(), + parameter_type: None, + description: "test target".to_string(), + required: None, + enum_values: None, + default: Some("US".to_string()), + in_path: None, + format: None, + }]; + let http_method = HttpMethod::Get; + let (path, body) = super::compute_request_path_body( + endpoint_path, + &tool_params, + &prompt_target_params, + &http_method, + ) + .unwrap(); + assert_eq!( + path, + "/cluster.open-cluster-management.io/v1/managedclusters/?country=UK" + ); + assert_eq!(body, None); + } +} diff --git a/crates/agent_gateway/tests/integration.rs b/crates/agent_gateway/tests/integration.rs new file mode 100644 index 00000000..91b36c01 --- /dev/null +++ b/crates/agent_gateway/tests/integration.rs @@ -0,0 +1,690 @@ +use common::api::open_ai::{ + ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage, +}; +use common::configuration::Configuration; +use http::StatusCode; +use proxy_wasm_test_framework::tester::{self, Tester}; +use proxy_wasm_test_framework::types::{ + Action, BufferType, LogLevel, MapType, MetricType, ReturnType, +}; +use serde_yaml::Value; +use serial_test::serial; +use std::collections::HashMap; +use std::path::Path; + +fn wasm_module() -> String { + let wasm_file = Path::new("../target/wasm32-wasip1/release/prompt_gateway.wasm"); + assert!( + wasm_file.exists(), + "Run `cargo build --release --target=wasm32-wasip1` first" + ); + wasm_file.to_str().unwrap().to_string() +} + +fn request_headers_expectations(module: &mut Tester, http_context: i32) { + module + .call_proxy_on_request_headers(http_context, 0, false) + .expect_log(Some(LogLevel::Debug), None) + .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length")) + .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path")) + .returning(Some("/v1/chat/completions")) + .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) + .returning(None) + .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id")) + .returning(None) + .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) + .returning(None) + .execute_and_expect(ReturnType::Action(Action::Continue)) + .unwrap(); +} + +fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { + module + .call_proxy_on_context_create(http_context, filter_context) + .expect_log(Some(LogLevel::Trace), None) + .execute_and_expect(ReturnType::None) + .unwrap(); + + request_headers_expectations(module, http_context); + + // Request Body + let chat_completions_request_body = "\ +{\ + \"messages\": [\ + {\ + \"role\": \"system\",\ + \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ + },\ + {\ + \"role\": \"user\",\ + \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ + }\ + ],\ + \"model\": \"gpt-4\"\ +}"; + + module + .call_proxy_on_request_body( + http_context, + chat_completions_request_body.len() as i32, + true, + ) + .expect_log(Some(LogLevel::Debug), None) + .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) + .returning(Some(chat_completions_request_body)) + // The actual call is not important in this test, we just need to grab the token_id + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_http_call( + Some("arch_internal"), + Some(vec![ + ("x-arch-upstream", "model_server"), + (":method", "POST"), + (":path", "/function_calling"), + ("content-type", "application/json"), + (":authority", "model_server"), + ("x-envoy-upstream-rq-timeout-ms", "30000"), + ]), + None, + None, + Some(5000), + ) + .returning(Some(1)) + .expect_metric_increment("active_http_calls", 1) + .execute_and_expect(ReturnType::Action(Action::Pause)) + .unwrap(); +} + +fn setup_filter(module: &mut Tester, config: &str) -> i32 { + let filter_context = 1; + + module + .call_proxy_on_context_create(filter_context, 0) + .expect_metric_creation(MetricType::Gauge, "active_http_calls") + .execute_and_expect(ReturnType::None) + .unwrap(); + + module + .call_proxy_on_configure(filter_context, config.len() as i32) + .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration)) + .returning(Some(config)) + .execute_and_expect(ReturnType::Bool(true)) + .unwrap(); + + filter_context +} + +fn default_config() -> &'static str { + r#" +version: "0.1-beta" + +listener: + address: 0.0.0.0 + port: 10000 + message_format: huggingface + connect_timeout: 0.005s + +endpoints: + api_server: + endpoint: api_server:80 + connect_timeout: 0.005s + +llm_providers: + - name: open-ai-gpt-4 + provider_interface: openai + access_key: secret_key + model: gpt-4 + default: true + +overrides: + # confidence threshold for prompt target intent matching + prompt_target_intent_matching_threshold: 0.0 + +system_prompt: | + You are a helpful assistant. + +prompt_guards: + input_guards: + jailbreak: + on_exception: + message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters." + +prompt_targets: + - name: weather_forecast + description: This function provides realtime weather forecast information for a given city. + parameters: + - name: city + required: true + description: The city for which the weather forecast is requested. + - name: days + description: The number of days for which the weather forecast is requested. + - name: units + description: The units in which the weather forecast is requested. + endpoint: + name: api_server + path: /weather + http_method: POST + system_prompt: | + You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries: + - Use farenheight for temperature + - Use miles per hour for wind speed + +ratelimits: + - model: gpt-4 + selector: + key: selector-key + value: selector-value + limit: + tokens: 1 + unit: minute +"# +} + +#[test] +#[serial] +fn prompt_gateway_successful_request_to_open_ai_chat_completions() { + let args = tester::MockSettings { + wasm_path: wasm_module(), + quiet: false, + allow_unexpected: false, + }; + let mut module = tester::mock(args).unwrap(); + + module + .call_start() + .execute_and_expect(ReturnType::None) + .unwrap(); + + // Setup Filter + let filter_context = setup_filter(&mut module, default_config()); + + // Setup HTTP Stream + let http_context = 2; + + module + .call_proxy_on_context_create(http_context, filter_context) + .expect_log(Some(LogLevel::Trace), None) + .execute_and_expect(ReturnType::None) + .unwrap(); + + request_headers_expectations(&mut module, http_context); + + // Request Body + let chat_completions_request_body = "\ + {\ + \"messages\": [\ + {\ + \"role\": \"system\",\ + \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ + },\ + {\ + \"role\": \"user\",\ + \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ + }\ + ],\ + \"model\": \"gpt-4\"\ + }"; + + module + .call_proxy_on_request_body( + http_context, + chat_completions_request_body.len() as i32, + true, + ) + .expect_log(Some(LogLevel::Debug), None) + .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) + .returning(Some(chat_completions_request_body)) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_http_call(Some("arch_internal"), None, None, None, None) + .returning(Some(4)) + .expect_metric_increment("active_http_calls", 1) + .execute_and_expect(ReturnType::Action(Action::Pause)) + .unwrap(); +} + +#[test] +#[serial] +fn prompt_gateway_bad_request_to_open_ai_chat_completions() { + let args = tester::MockSettings { + wasm_path: wasm_module(), + quiet: false, + allow_unexpected: false, + }; + let mut module = tester::mock(args).unwrap(); + + module + .call_start() + .execute_and_expect(ReturnType::None) + .unwrap(); + + // Setup Filter + let filter_context = setup_filter(&mut module, default_config()); + + // Setup HTTP Stream + let http_context = 2; + + module + .call_proxy_on_context_create(http_context, filter_context) + .expect_log(Some(LogLevel::Trace), None) + .execute_and_expect(ReturnType::None) + .unwrap(); + + request_headers_expectations(&mut module, http_context); + + // Request Body + let incomplete_chat_completions_request_body = "\ + {\ + \"messages\": [\ + {\ + \"role\": \"system\",\ + },\ + {\ + \"role\": \"user\",\ + \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ + }\ + ]\ + }"; + + module + .call_proxy_on_request_body( + http_context, + incomplete_chat_completions_request_body.len() as i32, + true, + ) + .expect_log(Some(LogLevel::Debug), None) + .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) + .returning(Some(incomplete_chat_completions_request_body)) + .expect_log(Some(LogLevel::Debug), None) + .expect_send_local_response( + Some(StatusCode::BAD_REQUEST.as_u16().into()), + None, + None, + None, + ) + .execute_and_expect(ReturnType::Action(Action::Pause)) + .unwrap(); +} + +#[test] +#[serial] +fn prompt_gateway_request_to_llm_gateway() { + let args = tester::MockSettings { + wasm_path: wasm_module(), + quiet: false, + allow_unexpected: false, + }; + let mut module = tester::mock(args).unwrap(); + + module + .call_start() + .execute_and_expect(ReturnType::None) + .unwrap(); + + // Setup Filter + let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap(); + config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000; + let config_str = serde_json::to_string(&config).unwrap(); + + let filter_context = setup_filter(&mut module, &config_str); + + // Setup HTTP Stream + let http_context = 2; + + normal_flow(&mut module, filter_context, http_context); + + let arch_fc_resp = ChatCompletionsResponse { + usage: Some(Usage { + completion_tokens: 0, + }), + choices: vec![Choice { + finish_reason: Some("test".to_string()), + index: Some(0), + message: Message { + role: "system".to_string(), + content: None, + tool_calls: Some(vec![ToolCall { + id: String::from("test"), + tool_type: ToolType::Function, + function: FunctionCallDetail { + name: String::from("weather_forecast"), + arguments: Some(HashMap::from([( + String::from("city"), + Value::String(String::from("seattle")), + )])), + }, + }]), + model: None, + tool_call_id: None, + }, + }], + model: String::from("test"), + metadata: { + let mut map: HashMap = HashMap::new(); + map.insert("function_latency".to_string(), "0.0".to_string()); + Some(map) + }, + }; + + let expected_body = "{\"city\":\"seattle\"}"; + let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap(); + module + .call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0) + .expect_metric_increment("active_http_calls", -1) + .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) + .returning(Some(&arch_fc_resp_str)) + .expect_log(Some(LogLevel::Warn), None) + .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_http_call( + Some("arch_internal"), + Some(vec![ + ("x-envoy-max-retries", "3"), + ("x-arch-upstream", "api_server"), + ("content-type", "application/json"), + ("x-envoy-upstream-rq-timeout-ms", "30000"), + (":path", "/weather"), + (":method", "POST"), + (":authority", "api_server"), + ]), + Some(expected_body), + None, + Some(5000), + ) + .returning(Some(2)) + .expect_metric_increment("active_http_calls", 1) + .expect_log(Some(LogLevel::Trace), None) + .execute_and_expect(ReturnType::None) + .unwrap(); + + let body_text = String::from("test body"); + module + .call_proxy_on_http_call_response(http_context, 2, 0, body_text.len() as i32, 0) + .expect_metric_increment("active_http_calls", -1) + .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) + .returning(Some(&body_text)) + .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status")) + .returning(Some("200")) + .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) + .expect_log(Some(LogLevel::Debug), None) + .execute_and_expect(ReturnType::None) + .unwrap(); + + let chat_completion_response = ChatCompletionsResponse { + usage: Some(Usage { + completion_tokens: 0, + }), + choices: vec![Choice { + finish_reason: Some("test".to_string()), + index: Some(0), + message: Message { + role: "assistant".to_string(), + content: Some("hello from fake llm gateway".to_string()), + model: None, + tool_calls: None, + tool_call_id: None, + }, + }], + model: String::from("test"), + metadata: None, + }; + + let chat_completion_response_str = serde_json::to_string(&chat_completion_response).unwrap(); + module + .call_proxy_on_response_body( + http_context, + chat_completion_response_str.len() as i32, + true, + ) + .expect_get_buffer_bytes(Some(BufferType::HttpResponseBody)) + .returning(Some(chat_completion_response_str.as_str())) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), None) + .expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None) + .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Debug), None) + .execute_and_expect(ReturnType::Action(Action::Continue)) + .unwrap(); +} + +#[test] +#[serial] +fn prompt_gateway_request_no_intent_match() { + let args = tester::MockSettings { + wasm_path: wasm_module(), + quiet: false, + allow_unexpected: false, + }; + let mut module = tester::mock(args).unwrap(); + + module + .call_start() + .execute_and_expect(ReturnType::None) + .unwrap(); + + // Setup Filter + let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap(); + config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000; + let config_str = serde_json::to_string(&config).unwrap(); + + let filter_context = setup_filter(&mut module, &config_str); + + // Setup HTTP Stream + let http_context = 2; + + normal_flow(&mut module, filter_context, http_context); + + let arch_fc_resp = ChatCompletionsResponse { + usage: Some(Usage { + completion_tokens: 0, + }), + choices: vec![Choice { + finish_reason: Some("test".to_string()), + index: Some(0), + message: Message { + role: "assistant".to_string(), + content: None, + tool_calls: None, + model: None, + tool_call_id: None, + }, + }], + model: String::from("test"), + metadata: None, + }; + + let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap(); + module + .call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0) + .expect_metric_increment("active_http_calls", -1) + .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) + .returning(Some(&arch_fc_resp_str)) + .expect_log(Some(LogLevel::Warn), None) + .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), Some("intent matched: false")) + .expect_log( + Some(LogLevel::Info), + Some("no default prompt target found, forwarding request to upstream llm"), + ) + .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Info), None) + .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) + .execute_and_expect(ReturnType::None) + .unwrap(); +} + +fn arch_config_default_target() -> &'static str { + r#" +version: "0.1-beta" + +listener: + address: 0.0.0.0 + port: 10000 + message_format: huggingface + connect_timeout: 0.005s + +endpoints: + api_server: + endpoint: api_server:80 + connect_timeout: 0.005s + +llm_providers: + - name: open-ai-gpt-4 + provider_interface: openai + access_key: secret_key + model: gpt-4 + default: true + +overrides: + # confidence threshold for prompt target intent matching + prompt_target_intent_matching_threshold: 0.0 + +system_prompt: | + You are a helpful assistant. + +prompt_guards: + input_guards: + jailbreak: + on_exception: + message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters." + +prompt_targets: + - name: weather_forecast + description: This function provides realtime weather forecast information for a given city. + parameters: + - name: city + required: true + description: The city for which the weather forecast is requested. + - name: days + description: The number of days for which the weather forecast is requested. + - name: units + description: The units in which the weather forecast is requested. + endpoint: + name: api_server + path: /weather + http_method: POST + system_prompt: | + You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries: + - Use farenheight for temperature + - Use miles per hour for wind speed + + - name: default_target + default: true + description: This is the default target for all unmatched prompts. + endpoint: + name: weather_forecast_service + path: /default_target + http_method: POST + system_prompt: | + You are a helpful assistant! Summarize the user's request and provide a helpful response. + # if it is set to false arch will send response that it received from this prompt target to the user + # if true arch will forward the response to the default LLM + auto_llm_dispatch_on_response: false + +ratelimits: + - model: gpt-4 + selector: + key: selector-key + value: selector-value + limit: + tokens: 1 + unit: minute +"# +} + +#[test] +#[serial] +fn prompt_gateway_request_no_intent_match_default_target() { + let args = tester::MockSettings { + wasm_path: wasm_module(), + quiet: false, + allow_unexpected: false, + }; + let mut module = tester::mock(args).unwrap(); + + module + .call_start() + .execute_and_expect(ReturnType::None) + .unwrap(); + + // Setup Filter + let mut config: Configuration = serde_yaml::from_str(arch_config_default_target()).unwrap(); + config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000; + let config_str = serde_json::to_string(&config).unwrap(); + + let filter_context = setup_filter(&mut module, &config_str); + + // Setup HTTP Stream + let http_context = 2; + + normal_flow(&mut module, filter_context, http_context); + + let arch_fc_resp = ChatCompletionsResponse { + usage: Some(Usage { + completion_tokens: 0, + }), + choices: vec![Choice { + finish_reason: Some("test".to_string()), + index: Some(0), + message: Message { + role: "system".to_string(), + content: None, + tool_calls: None, + model: None, + tool_call_id: None, + }, + }], + model: String::from("test"), + metadata: None, + }; + + let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap(); + module + .call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0) + .expect_metric_increment("active_http_calls", -1) + .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) + .returning(Some(&arch_fc_resp_str)) + .expect_log(Some(LogLevel::Warn), None) + .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), Some("intent matched: false")) + .expect_log( + Some(LogLevel::Info), + Some("default prompt target found, forwarding request to default prompt target"), + ) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), None) + .expect_http_call( + Some("arch_internal"), + Some(vec![ + (":method", "POST"), + ("x-arch-upstream", "weather_forecast_service"), + (":path", "/default_target"), + (":authority", "weather_forecast_service"), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", "30000"), + ]), + None, + None, + Some(5000), + ) + .returning(Some(2)) + .expect_metric_increment("active_http_calls", 1) + .execute_and_expect(ReturnType::None) + .unwrap(); +} diff --git a/crates/common/src/api/open_ai.rs b/crates/common/src/api/open_ai.rs index d71b0d58..99a1d840 100644 --- a/crates/common/src/api/open_ai.rs +++ b/crates/common/src/api/open_ai.rs @@ -1,6 +1,5 @@ use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE}; use serde::{ser::SerializeMap, Deserialize, Serialize}; -use serde_yaml::Value; use std::{ collections::{HashMap, VecDeque}, fmt::Display, @@ -43,6 +42,8 @@ pub struct FunctionDefinition { #[derive(Debug, Clone, Deserialize)] pub struct FunctionParameters { + #[serde(rename = "type")] + pub properties_type: String, pub properties: HashMap, } @@ -51,7 +52,7 @@ impl Serialize for FunctionParameters { where S: serde::Serializer, { - // select all requried parameters + // select all required parameters let required: Vec<&String> = self .properties .iter() @@ -60,6 +61,7 @@ impl Serialize for FunctionParameters { .collect(); let mut map = serializer.serialize_map(Some(2))?; map.serialize_entry("properties", &self.properties)?; + map.serialize_entry("type", &self.properties_type)?; if !required.is_empty() { map.serialize_entry("required", &required)?; } @@ -113,7 +115,7 @@ pub enum ParameterType { Float, #[serde(rename = "bool")] Bool, - #[serde(rename = "str")] + #[serde(rename = "string")] String, #[serde(rename = "list")] List, @@ -189,7 +191,7 @@ pub struct ToolCall { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FunctionCallDetail { pub name: String, - pub arguments: Option>, + pub arguments: Option, } #[derive(Debug, Deserialize, Serialize)] diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 2065b1aa..3e72d4cf 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -19,6 +19,37 @@ pub struct Configuration { pub ratelimits: Option>, pub tracing: Option, pub mode: Option, + pub agents: HashMap, + pub tools: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Agent { + pub name: String, + pub description: String, + pub default_input_modes: Option>, + pub default_output_modes: Option>, + pub skills: Option>, + pub model: String, + pub agent_orchestrator_prompt: Option, + pub system_prompt: Option, + pub tools: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Skill { + pub id: String, + pub name: String, + pub description: String, + pub examples: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Tool { + pub name: String, + pub description: String, + pub endpoint: Option, + pub parameters: Option>, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] @@ -260,7 +291,48 @@ impl From<&PromptTarget> for ChatCompletionTool { function: FunctionDefinition { name: val.name.clone(), description: val.description.clone(), - parameters: FunctionParameters { properties }, + parameters: FunctionParameters { + properties, + properties_type: "object".to_string(), + }, + }, + } + } +} + +// convert Tool to ChatCompletionTool +impl From<&Tool> for ChatCompletionTool { + fn from(val: &Tool) -> Self { + let properties: HashMap = match val.parameters { + Some(ref entities) => { + let mut properties: HashMap = HashMap::new(); + for entity in entities.iter() { + let param = FunctionParameter { + parameter_type: ParameterType::from( + entity.parameter_type.clone().unwrap_or("str".to_string()), + ), + description: entity.description.clone(), + required: entity.required, + enum_values: entity.enum_values.clone(), + default: entity.default.clone(), + format: entity.format.clone(), + }; + properties.insert(entity.name.clone(), param); + } + properties + } + None => HashMap::new(), + }; + + ChatCompletionTool { + tool_type: crate::api::open_ai::ToolType::Function, + function: FunctionDefinition { + name: val.name.clone(), + description: val.description.clone(), + parameters: FunctionParameters { + properties, + properties_type: "object".to_string(), + }, }, } } diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index e58bebde..4bcacc62 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -11,7 +11,8 @@ pub const MODEL_SERVER_NAME: &str = "model_server"; pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider"; pub const MESSAGES_KEY: &str = "messages"; pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint"; -pub const CHAT_COMPLETIONS_PATH: [&str; 2] = ["/v1/chat/completions", "/openai/v1/chat/completions"]; +pub const CHAT_COMPLETIONS_PATH: [&str; 2] = + ["/v1/chat/completions", "/openai/v1/chat/completions"]; pub const HEALTHZ_PATH: &str = "/healthz"; pub const X_ARCH_STATE_HEADER: &str = "x-arch-state"; pub const X_ARCH_API_RESPONSE: &str = "x-arch-api-response-message"; diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 1cd2fa86..25b25ecc 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -371,7 +371,6 @@ impl StreamContext { let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone(); let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone(); - let tool_params = &self.tool_calls.as_ref().unwrap()[0].function.arguments; let endpoint_details = prompt_target.endpoint.as_ref().unwrap(); let endpoint_path: String = endpoint_details .path @@ -382,9 +381,10 @@ impl StreamContext { let http_method = endpoint_details.method.clone().unwrap_or_default(); let prompt_target_params = prompt_target.parameters.clone().unwrap_or_default(); + //TODO: fixme: adilhafeez hack let (path, api_call_body) = match compute_request_path_body( &endpoint_path, - tool_params, + &None, &prompt_target_params, &http_method, ) { @@ -777,18 +777,19 @@ impl StreamContext { fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool { let content = model_server_response - .choices.first() + .choices + .first() .and_then(|choice| choice.message.content.as_ref()); let content_has_value = content.is_some() && !content.unwrap().is_empty(); let tool_calls = model_server_response - .choices.first() + .choices + .first() .and_then(|choice| choice.message.tool_calls.as_ref()); // intent was matched if content has some value or tool_calls is empty - content_has_value || (tool_calls.is_some() && !tool_calls.unwrap().is_empty()) } diff --git a/demos/ai_agent/agent_config copy.yaml b/demos/ai_agent/agent_config copy.yaml new file mode 100644 index 00000000..c6eafde6 --- /dev/null +++ b/demos/ai_agent/agent_config copy.yaml @@ -0,0 +1,76 @@ +version: v0.1 + + +listeners: + ingress_traffic: + address: 0.0.0.0 + port: 10000 + message_format: openai + timeout: 30s + +llm_providers: + - name: gpt-4o + access_key: $OPENAI_API_KEY + provider_interface: openai + model: gpt-4o + +endpoints: + frankfurther_api: + endpoint: api.frankfurter.dev + protocol: https + +tools: + - name: get_exchange_rate + endpoint: + name: frankfurther_api + path: /v1/latest?base={currency_from}&symbols={currency_to} + params: + - name: currency_from + description: currency symbol to convert from + type: str + default: USD + - name: currency_to + description: currency symbol to convert to + type: str + default: EUR + +prompt_guards: + input_guards: + jailbreak: + on_exception: + message: Looks like you're curious about my abilities, but I can only provide assistance for currency exchange. + +agents: + - name: Currency Exchange Agent + description: Helps with exchange rates for currencies + default_input_modes: + - text + - text/plain + default_output_modes: + - text + - text/plain + skills: + - id: convert_currency + name: Currency Exchange Rates Tool + description: Helps with exchange values between various currencies + examples: + - What is exchange rate between USD and GBP? + capabilities: + streaming: true + push_notifications: true + model: gpt-4o + system_prompt: | + You are a specialized assistant for currency conversions. + Your sole purpose is to use the 'get_exchange_rate' tool to answer questions about currency exchange rates. + If the user asks about anything other than currency conversion or exchange rates, + politely state that you cannot help with that topic and can only assist with currency-related queries. + Do not attempt to answer unrelated questions or use tools for other purposes. + Set response status to input_required if the user needs to provide more information. + Set response status to error if there is an error while processing the request. + Set response status to completed if the request is complete. + tools: + - get_exchange_rate + +tracing: + random_sampling: 100 + trace_arch_internal: true diff --git a/demos/ai_agent/agent_config.yaml b/demos/ai_agent/agent_config.yaml new file mode 100644 index 00000000..52f62400 --- /dev/null +++ b/demos/ai_agent/agent_config.yaml @@ -0,0 +1,84 @@ +version: v0.1 + + +listeners: + ingress_traffic: + address: 0.0.0.0 + port: 10000 + message_format: openai + timeout: 30s + +llm_providers: + - name: gpt-4o + access_key: $OPENAI_API_KEY + provider_interface: openai + model: gpt-4o + +endpoints: + frankfurther_api: + endpoint: api.frankfurter.dev + protocol: https + +tools: + get_exchange_rate: + name: get_exchange_rate + description: Get the latest exchange rate for a given currency pair + endpoint: + name: frankfurther_api + path: /v1/latest?base={currency_from}&symbols={currency_to} + parameters: + - name: currency_from + description: currency symbol to convert from + type: string + default: USD + - name: currency_to + description: currency symbol to convert to + type: string + default: EUR + get_list_of_supported_currencies: + name: get_list_of_supported_currencies + description: Get the list of supported currencies + endpoint: + name: frankfurther_api + path: /v1/currencies + +agents: + currency_exchange_agent: + name: Currency Exchange Agent + description: Helps with exchange rates for currencies + default_input_modes: + - text + - text/plain + default_output_modes: + - text + - text/plain + skills: + - id: convert_currency + name: Currency Exchange Rates Tool + description: Helps with exchange values between various currencies + examples: + - What is exchange rate between USD and GBP? + capabilities: + streaming: true + push_notifications: true + agent_orchestrator_model: gpt-4o + agent_orchestrator_prompt: | + You are a specialized assistant for currency conversions. + Your sole purpose is to use the 'get_exchange_rate' tool to answer questions about currency exchange rates. + If the user asks about anything other than currency conversion or exchange rates, + politely state that you cannot help with that topic and can only assist with currency-related queries. + Do not attempt to answer unrelated questions or use tools for other purposes. + Set response status to input_required if the user needs to provide more information. + Set response status to error if there is an error while processing the request. + Set response status to completed if the request is complete. + tools: + - get_exchange_rate + - get_list_of_supported_currencies + model: gpt-4o + system_prompt: | + You are a specialized currency exchange assistant. + Your task is to provide the user with the exchange rate between two currencies. + Keep the response concise and relevant to the user's query. +tracing: + random_sampling: 100 + trace_arch_internal: true diff --git a/demos/ai_agent/agent_config_mcp.yaml b/demos/ai_agent/agent_config_mcp.yaml new file mode 100644 index 00000000..54dd43d7 --- /dev/null +++ b/demos/ai_agent/agent_config_mcp.yaml @@ -0,0 +1,58 @@ +version: v0.1 + +listeners: + ingress_traffic: + address: 0.0.0.0 + port: 10000 + message_format: openai + timeout: 30s + +llm_providers: + - name: gpt-4o + access_key: $OPENAI_API_KEY + provider_interface: openai + model: gpt-4o + +endpoints: + + - name: frankfurther_api + endpoint: api.frankfurter.dev + protocol: https + + - name: twelvedata_api + endpoint: api.twelvedata.com + protocol: https + +mcp: + + - name: get_currency_exchange_rate + - name: get_list_of_supported_currencies + - name: get_stock_quote + +prompt_guards: + input_guards: + jailbreak: + on_exception: + message: Looks like you're curious about my abilities, but I can only provide assistance for currency exchange. + +agents: + - name: currency_exchange_agent + description: Agent for handling currency exchange queries + llm_provider: gpt-4o + system_prompt: | + You are a helpful assistant. Only respond to queries related to currency exchange. If there are any other questions, I can't help you. + tools: + - get_currency_exchange_rate + - get_list_of_supported_currencies + + - name: get_stock_quote_agent + description: Agent for handling stock quote queries + llm_provider: gpt-4o + system_prompt: | + You are a helpful stock exchange assistant. You are given stock symbol along with its exchange rate in json format. Your task is to parse the data and present it in a human-readable format. Keep the details to highlevel and be concise. + tools: + - get_stock_quote + +tracing: + random_sampling: 100 + trace_arch_internal: true diff --git a/demos/ai_agent/test.hurl b/demos/ai_agent/test.hurl new file mode 100644 index 00000000..215de034 --- /dev/null +++ b/demos/ai_agent/test.hurl @@ -0,0 +1,20 @@ +POST http://localhost:14000/v1/chat/completions +Content-Type: application/json + +{ + "messages": [ + { + "role": "user", + "content": "hello" + }, + { + "role": "assistant", + "content": "Hello! I'm here to assist you with any questions related to currency conversions and exchange rates. How can I help you today?" + }, + { + "role": "user", + "content": "what is the exchange rate of USD to EUR?" + } + ], + "stream": true +}