From 44872107a86c4125c0069c0ecb03fd963e5d935a Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 10 Dec 2024 15:12:31 -0800 Subject: [PATCH] integrate arch with model_server --- arch/envoy.template.yaml | 5 +- arch/tools/cli/config_generator.py | 1 + crates/common/src/api/open_ai.rs | 10 +- crates/common/src/configuration.rs | 113 ++- crates/common/src/lib.rs | 2 +- crates/common/src/path.rs | 5 +- crates/prompt_gateway/src/context.rs | 59 -- crates/prompt_gateway/src/embeddings.rs | 5 - crates/prompt_gateway/src/filter_context.rs | 210 +---- crates/prompt_gateway/src/http_context.rs | 79 +- crates/prompt_gateway/src/lib.rs | 1 - crates/prompt_gateway/src/stream_context.rs | 890 +------------------- e2e_tests/api_model_server.rest | 82 +- 13 files changed, 240 insertions(+), 1222 deletions(-) delete mode 100644 crates/prompt_gateway/src/embeddings.rs diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml index d5c7b95e..b4d4b999 100644 --- a/arch/envoy.template.yaml +++ b/arch/envoy.template.yaml @@ -211,8 +211,7 @@ static_resources: domains: - "*" routes: - - {% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination"] %} + {% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination", "model_server"] %} - match: prefix: "/" headers: @@ -449,7 +448,7 @@ static_resources: typed_config: "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext sni: api.mistral.ai - {% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination"] %} + {% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination", "model_server"] %} - name: {{ internal_clustrer }} connect_timeout: 5s type: STRICT_DNS diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index c32dae16..2f28f537 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -90,6 +90,7 @@ def validate_and_render_schema(): rendered = template.render(data) print(ENVOY_CONFIG_FILE_RENDERED) + print(rendered) with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file: file.write(rendered) diff --git a/crates/common/src/api/open_ai.rs b/crates/common/src/api/open_ai.rs index 20b550ae..e96906fa 100644 --- a/crates/common/src/api/open_ai.rs +++ b/crates/common/src/api/open_ai.rs @@ -21,7 +21,7 @@ pub struct ChatCompletionsRequest { pub metadata: Option>, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub enum ToolType { #[serde(rename = "function")] Function, @@ -165,8 +165,8 @@ pub struct Message { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Choice { - pub finish_reason: String, - pub index: usize, + pub finish_reason: Option, + pub index: Option, pub message: Message, } @@ -217,8 +217,8 @@ impl ChatCompletionsResponse { tool_calls: None, tool_call_id: None, }, - index: 0, - finish_reason: "done".to_string(), + index: Some(0), + finish_reason: Some("done".to_string()), }], usage: None, model: ARCH_FC_MODEL_NAME.to_string(), diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 0d9bea80..9a56e863 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -2,6 +2,10 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fmt::Display; +use crate::api::open_ai::{ + ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType, +}; + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Configuration { pub version: String, @@ -231,11 +235,46 @@ pub struct PromptTarget { pub auto_llm_dispatch_on_response: Option, } +// convert PromptTarget to ChatCompletionTool +impl Into for &PromptTarget { + fn into(self) -> ChatCompletionTool { + let properties: HashMap = match self.parameters { + Some(ref entities) => { + let mut properties: HashMap = HashMap::new(); + for entity in entities.iter() { + let param = FunctionParameter { + parameter_type: ParameterType::from( + entity.parameter_type.clone().unwrap_or("str".to_string()), + ), + description: entity.description.clone(), + required: entity.required, + enum_values: entity.enum_values.clone(), + default: entity.default.clone(), + }; + properties.insert(entity.name.clone(), param); + } + properties + } + None => HashMap::new(), + }; + + ChatCompletionTool { + tool_type: crate::api::open_ai::ToolType::Function, + function: FunctionDefinition { + name: self.name.clone(), + description: self.description.clone(), + parameters: FunctionParameters { properties }, + }, + } + } +} + #[cfg(test)] mod test { + use pretty_assertions::assert_eq; use std::fs; - use crate::configuration::GuardType; + use crate::{api::open_ai::ToolType, configuration::GuardType}; #[test] fn test_deserialize_configuration() { @@ -307,4 +346,76 @@ mod test { let mode = config.mode.as_ref().unwrap_or(&super::GatewayMode::Prompt); assert_eq!(*mode, super::GatewayMode::Prompt); } + + #[test] + fn test_tool_conversion() { + let ref_config = fs::read_to_string( + "../../docs/source/resources/includes/arch_config_full_reference.yaml", + ) + .expect("reference config file not found"); + let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap(); + let prompt_targets = &config.prompt_targets; + let prompt_target = prompt_targets + .as_ref() + .unwrap() + .iter() + .find(|p| p.name == "reboot_network_device") + .unwrap(); + let chat_completion_tool: super::ChatCompletionTool = prompt_target.into(); + assert_eq!(chat_completion_tool.tool_type, ToolType::Function); + assert_eq!(chat_completion_tool.function.name, "reboot_network_device"); + assert_eq!( + chat_completion_tool.function.description, + "Reboot a specific network device" + ); + assert_eq!(chat_completion_tool.function.parameters.properties.len(), 2); + assert_eq!( + chat_completion_tool + .function + .parameters + .properties + .contains_key("device_id"), + true + ); + assert_eq!( + chat_completion_tool + .function + .parameters + .properties + .get("device_id") + .unwrap() + .parameter_type, + crate::api::open_ai::ParameterType::String + ); + assert_eq!( + chat_completion_tool + .function + .parameters + .properties + .get("device_id") + .unwrap() + .description, + "Identifier of the network device to reboot.".to_string() + ); + assert_eq!( + chat_completion_tool + .function + .parameters + .properties + .get("device_id") + .unwrap() + .required, + Some(true) + ); + assert_eq!( + chat_completion_tool + .function + .parameters + .properties + .get("confirmation") + .unwrap() + .parameter_type, + crate::api::open_ai::ParameterType::Bool + ); + } } diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index cd5238a3..a7c881c6 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -5,10 +5,10 @@ pub mod embeddings; pub mod errors; pub mod http; pub mod llm_providers; +pub mod path; pub mod pii; pub mod ratelimit; pub mod routing; pub mod stats; pub mod tokenizer; pub mod tracing; -pub mod path; diff --git a/crates/common/src/path.rs b/crates/common/src/path.rs index 2b289c9d..17dfe2ce 100644 --- a/crates/common/src/path.rs +++ b/crates/common/src/path.rs @@ -1,6 +1,9 @@ use std::collections::HashMap; -pub fn replace_params_in_path(path: &str, params: &HashMap) -> Result { +pub fn replace_params_in_path( + path: &str, + params: &HashMap, +) -> Result { let mut result = String::new(); let mut in_param = false; let mut current_param = String::new(); diff --git a/crates/prompt_gateway/src/context.rs b/crates/prompt_gateway/src/context.rs index 567f41eb..4dba8588 100644 --- a/crates/prompt_gateway/src/context.rs +++ b/crates/prompt_gateway/src/context.rs @@ -19,70 +19,11 @@ impl Context for StreamContext { .expect("invalid token_id"); self.metrics.active_http_calls.increment(-1); - /* - state transition - - graph LR - - on_http_request_body --> prompt received - prompt received --> get embeddings & arch guard - arch guard --> get embeddings - get embeddings --> zeroshot intent - - ┌──────────────────────┐ ┌─────────────────┐ ┌────────────────┐ ┌─────────────────┐ - │ │ │ │ │ │ │ │ - │ on_http_request_body ├──►│ prompt received ├──►│ get embeddings ├──►│ zeroshot intent │ - │ │ │ │ │ │ │ │ - └──────────────────────┘ └────────┬────────┘ └────────────────┘ └─────────────────┘ - │ ▲ - │ │ - │ │ - │ ┌────────┴───────┐ - │ │ │ - └───────────►│ arch guard │ - │ │ - └────────────────┘ - - - continue from zeroshot intent - - graph LR - - zeroshot intent --> arch_fc - zeroshot intent --> default prompt target - arch_fc --> developer api call & hallucination check - hallucination check --> parameter gathering & developer api call - developer api call --> resume request to llm - - - ┌─────────────────┐ ┌───────────────────────┐ ┌─────────────────────┐ ┌───────────────────────┐ - │ │ │ │ │ │ │ │ - │ zeroshot intent ├──►│ arch_fc ├──►│ developer api call ├──►│ resume request to llm │ - │ │ │ │ │ │ │ │ - └────────┬────────┘ └───────────┬───────────┘ └─────────────────────┘ └───────────────────────┘ - │ │ ▲ - │ └─────────────┐ │ - │ │ │ - │ ┌───────────────────────┐ │ ┌──────────┴──────────┐ ┌───────────────────────┐ - │ │ │ │ │ │ │ │ - └───────────►│ default prompt target │ └▲│ hallucination check ├──►│ parameter gathering │ - │ │ │ │ │ │ - └───────────────────────┘ └─────────────────────┘ └───────────────────────┘ - - - using https://mermaid-ascii.art/ - */ - if let Some(body) = self.get_http_call_response_body(0, body_size) { #[cfg_attr(any(), rustfmt::skip)] match callout_context.response_handler_type { - ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context), - ResponseHandlerType::Embeddings => self.embeddings_handler(body, callout_context), - ResponseHandlerType::ZeroShotIntent => self.zero_shot_intent_detection_resp_handler(body, callout_context), ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context), - ResponseHandlerType::Hallucination => self.hallucination_classification_resp_handler(body, callout_context), ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context), - ResponseHandlerType::DefaultTarget =>self.default_target_handler(body, callout_context), } } else { self.send_server_error( diff --git a/crates/prompt_gateway/src/embeddings.rs b/crates/prompt_gateway/src/embeddings.rs deleted file mode 100644 index f2883682..00000000 --- a/crates/prompt_gateway/src/embeddings.rs +++ /dev/null @@ -1,5 +0,0 @@ -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub enum EmbeddingType { - Name, - Description, -} diff --git a/crates/prompt_gateway/src/filter_context.rs b/crates/prompt_gateway/src/filter_context.rs index 449be126..9780ed7d 100644 --- a/crates/prompt_gateway/src/filter_context.rs +++ b/crates/prompt_gateway/src/filter_context.rs @@ -1,35 +1,17 @@ -use crate::embeddings::EmbeddingType; use crate::metrics::Metrics; use crate::stream_context::StreamContext; use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget, Tracing}; -use common::consts::ARCH_UPSTREAM_HOST_HEADER; -use common::consts::DEFAULT_EMBEDDING_MODEL; -use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST}; -use common::embeddings::{ - CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, -}; -use common::http::CallArgs; use common::http::Client; use common::stats::Gauge; -use common::stats::IncrementingMetric; -use http::StatusCode; -use log::{debug, info, trace, warn}; +use log::debug; use proxy_wasm::traits::*; use proxy_wasm::types::*; use std::cell::RefCell; -use std::collections::hash_map::Entry; use std::collections::HashMap; use std::rc::Rc; -use std::time::Duration; - -pub type EmbeddingTypeMap = HashMap>; -pub type EmbeddingsStore = HashMap; #[derive(Debug)] -pub struct FilterCallContext { - pub prompt_target_name: String, - pub embedding_type: EmbeddingType, -} +pub struct FilterCallContext {} #[derive(Debug)] pub struct FilterContext { @@ -40,9 +22,6 @@ pub struct FilterContext { system_prompt: Rc>, prompt_targets: Rc>, prompt_guards: Rc, - embeddings_store: Option>, - temp_embeddings_store: EmbeddingsStore, - active_embedding_calls_count: u32, tracing: Rc>, } @@ -55,131 +34,9 @@ impl FilterContext { prompt_targets: Rc::new(HashMap::new()), overrides: Rc::new(None), prompt_guards: Rc::new(PromptGuards::default()), - embeddings_store: Some(Rc::new(HashMap::new())), - temp_embeddings_store: HashMap::new(), - active_embedding_calls_count: 0, tracing: Rc::new(None), } } - - fn process_prompt_targets(&mut self) { - let prompt_target_description: Vec<(String, String)> = self - .prompt_targets - .iter() - .map(|(k, v)| (k.clone(), v.description.clone())) - .collect(); - - prompt_target_description - .iter() - .for_each(|(name, description)| { - self.schedule_embeddings_call(name, description, EmbeddingType::Description); - }); - } - - fn schedule_embeddings_call( - &mut self, - prompt_target_name: &str, - input: &str, - embedding_type: EmbeddingType, - ) { - let embeddings_input = CreateEmbeddingRequest { - input: Box::new(CreateEmbeddingRequestInput::String(String::from(input))), - model: String::from(DEFAULT_EMBEDDING_MODEL), - encoding_format: None, - dimensions: None, - user: None, - }; - let json_data = serde_json::to_string(&embeddings_input).unwrap(); - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/embeddings", - vec![ - (ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST), - (":method", "POST"), - (":path", "/embeddings"), - (":authority", EMBEDDINGS_INTERNAL_HOST), - ("content-type", "application/json"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ], - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(60), - ); - - let call_context = crate::filter_context::FilterCallContext { - prompt_target_name: String::from(prompt_target_name), - embedding_type, - }; - - self.active_embedding_calls_count += 1; - if let Err(error) = self.http_call(call_args, call_context) { - panic!("{error}") - } - } - - fn embedding_response_handler( - &mut self, - embedding_type: EmbeddingType, - prompt_target_name: String, - body: Vec, - ) { - let prompt_target = self - .prompt_targets - .get(&prompt_target_name) - .unwrap_or_else(|| { - panic!( - "Received embeddings response for unknown prompt target name={}", - prompt_target_name - ) - }); - - if !body.is_empty() { - let mut embedding_response: CreateEmbeddingResponse = - match serde_json::from_slice(&body) { - Ok(response) => response, - Err(e) => { - panic!( - "Error deserializing embedding response. body: {:?}: {:?}", - String::from_utf8(body).unwrap(), - e - ); - } - }; - - let embeddings = embedding_response.data.remove(0).embedding; - debug!( - "Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}", - prompt_target.name, - prompt_target.description, - embedding_type - ); - - let entry = self.temp_embeddings_store.entry(prompt_target_name); - match entry { - Entry::Occupied(_) => { - entry.and_modify(|e| { - if let Entry::Vacant(e) = e.entry(embedding_type) { - e.insert(embeddings); - } else { - panic!( - "Duplicate {:?} for prompt target with name=\"{}\"", - &embedding_type, prompt_target.name - ) - } - }); - } - Entry::Vacant(_) => { - entry.or_insert(HashMap::from([(embedding_type, embeddings)])); - } - } - - if self.prompt_targets.len() == self.temp_embeddings_store.len() { - self.embeddings_store = - Some(Rc::new(std::mem::take(&mut self.temp_embeddings_store))) - } - } - } } impl Client for FilterContext { @@ -194,46 +51,7 @@ impl Client for FilterContext { } } -impl Context for FilterContext { - fn on_http_call_response( - &mut self, - token_id: u32, - _num_headers: usize, - body_size: usize, - _num_trailers: usize, - ) { - trace!( - "filter_context: on_http_call_response called with token_id: {:?}", - token_id - ); - let callout_data = self - .callouts - .borrow_mut() - .remove(&token_id) - .expect("invalid token_id"); - - self.active_embedding_calls_count -= 1; - self.metrics.active_http_calls.increment(-1); - let body_bytes = self.get_http_call_response_body(0, body_size).unwrap(); - - if let Some(status_code) = self.get_http_call_response_header(":status") { - if status_code == StatusCode::OK.as_str() { - self.embedding_response_handler( - callout_data.embedding_type, - callout_data.prompt_target_name, - body_bytes, - ); - } else { - warn!( - "Received non-200 status code: {} for callout with token_id: {}: body_str: {}", - status_code, - token_id, - String::from_utf8(body_bytes).unwrap() - ); - } - } - } -} +impl Context for FilterContext {} // RootContext allows the Rust code to reach into the Envoy Config impl RootContext for FilterContext { @@ -271,15 +89,12 @@ impl RootContext for FilterContext { context_id ); - let embedding_store = self.embeddings_store.as_ref().map(Rc::clone); Some(Box::new(StreamContext::new( context_id, Rc::clone(&self.metrics), Rc::clone(&self.system_prompt), Rc::clone(&self.prompt_targets), - Rc::clone(&self.prompt_guards), Rc::clone(&self.overrides), - embedding_store, Rc::clone(&self.tracing), ))) } @@ -289,25 +104,6 @@ impl RootContext for FilterContext { } fn on_vm_start(&mut self, _: usize) -> bool { - self.set_tick_period(Duration::from_secs(1)); true } - - fn on_tick(&mut self) { - if self.embeddings_store.is_some() - && self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len() - { - info!("embeddings store initialized"); - self.set_tick_period(Duration::from_secs(0)); - } else { - if self.active_embedding_calls_count == 0 { - info!("retrieving embeddings from embedding server"); - self.process_prompt_targets(); - } else { - info!("waiting for embeddings store to be initialized"); - } - - self.set_tick_period(Duration::from_secs(5)); - } - } } diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index 7508f852..7f4f010d 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -1,13 +1,12 @@ use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext}; use common::{ - api::{ - open_ai::{self, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest}, - prompt_guard::{PromptGuardRequest, PromptGuardTask}, + api::open_ai::{ + self, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest, }, consts::{ ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER, - ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, GUARD_INTERNAL_HOST, - HEALTHZ_PATH, REQUEST_ID_HEADER, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE, + ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH, + MODEL_SERVER_NAME, REQUEST_ID_HEADER, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE, }, errors::ServerError, http::{CallArgs, Client}, @@ -35,11 +34,7 @@ impl HttpContext for StreamContext { let request_path = self.get_http_request_header(":path").unwrap_or_default(); if request_path == HEALTHZ_PATH { - if self.is_embedding_store_initialized() { - self.send_http_response(200, vec![], None); - } else { - self.send_http_response(503, vec![], None); - } + self.send_http_response(200, vec![], None); return Action::Continue; } @@ -138,43 +133,25 @@ impl HttpContext for StreamContext { self.user_prompt = Some(last_user_prompt.clone()); - let user_message_str = self.user_prompt.as_ref().unwrap().content.clone(); + // convert prompt targets to ChatCompletionTool + let tool_calls: Vec = self + .prompt_targets + .iter() + .map(|(_, pt)| pt.into()) + .collect(); - let prompt_guard_jailbreak_task = self - .prompt_guards - .input_guards - .contains_key(&common::configuration::GuardType::Jailbreak); + let arch_fc_chat_completion_request = ChatCompletionsRequest { + messages: deserialized_body.messages.clone(), + metadata: deserialized_body.metadata.clone(), + stream: deserialized_body.stream, + model: "--".to_string(), + stream_options: deserialized_body.stream_options.clone(), + tools: Some(tool_calls), + }; self.chat_completions_request = Some(deserialized_body); - if !prompt_guard_jailbreak_task { - debug!("Missing input guard. Making inline call to retrieve embeddings"); - let callout_context = StreamCallContext { - response_handler_type: ResponseHandlerType::ArchGuard, - user_message: user_message_str.clone(), - prompt_target_name: None, - request_body: self.chat_completions_request.as_ref().unwrap().clone(), - similarity_scores: None, - upstream_cluster: None, - upstream_cluster_path: None, - }; - self.get_embeddings(callout_context); - return Action::Pause; - } - - let get_prompt_guards_request = PromptGuardRequest { - input: self - .user_prompt - .as_ref() - .unwrap() - .content - .as_ref() - .unwrap() - .clone(), - task: PromptGuardTask::Jailbreak, - }; - - let json_data: String = match serde_json::to_string(&get_prompt_guards_request) { + let json_data = match serde_json::to_string(&arch_fc_chat_completion_request) { Ok(json_data) => json_data, Err(error) => { self.send_server_error(ServerError::Serialization(error), None); @@ -182,14 +159,14 @@ impl HttpContext for StreamContext { } }; + debug!("archgw => archfc: {}", json_data); + let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, GUARD_INTERNAL_HOST), + (ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME), (":method", "POST"), - (":path", "/guard"), - (":authority", GUARD_INTERNAL_HOST), + (":path", "/function_calling"), ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), + (":authority", MODEL_SERVER_NAME), ]; if self.request_id.is_some() { @@ -202,14 +179,15 @@ impl HttpContext for StreamContext { let call_args = CallArgs::new( ARCH_INTERNAL_CLUSTER_NAME, - "/guard", + "/function_calling", headers, Some(json_data.as_bytes()), vec![], Duration::from_secs(5), ); + let call_context = StreamCallContext { - response_handler_type: ResponseHandlerType::ArchGuard, + response_handler_type: ResponseHandlerType::ArchFC, user_message: self.user_prompt.as_ref().unwrap().content.clone(), prompt_target_name: None, request_body: self.chat_completions_request.as_ref().unwrap().clone(), @@ -219,6 +197,7 @@ impl HttpContext for StreamContext { }; if let Err(e) = self.http_call(call_args, call_context) { + debug!("http_call failed: {:?}", e); self.send_server_error(ServerError::HttpDispatch(e), None); } diff --git a/crates/prompt_gateway/src/lib.rs b/crates/prompt_gateway/src/lib.rs index 9d828dac..1acd4d6d 100644 --- a/crates/prompt_gateway/src/lib.rs +++ b/crates/prompt_gateway/src/lib.rs @@ -3,7 +3,6 @@ use proxy_wasm::traits::*; use proxy_wasm::types::*; mod context; -mod embeddings; mod filter_context; mod http_context; mod metrics; diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 8671af63..e02f96c5 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -1,36 +1,19 @@ -use crate::embeddings::EmbeddingType; -use crate::filter_context::EmbeddingsStore; use crate::metrics::Metrics; -use acap::cos; -use common::api::hallucination::{ - extract_messages_for_hallucination, HallucinationClassificationRequest, - HallucinationClassificationResponse, -}; use common::api::open_ai::{ - to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, - ChatCompletionsRequest, ChatCompletionsResponse, FunctionDefinition, FunctionParameter, - FunctionParameters, Message, ParameterType, ToolCall, ToolType, + to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest, + ChatCompletionsResponse, Message, ToolCall, }; -use common::api::prompt_guard::PromptGuardResponse; -use common::api::zero_shot::{ZeroShotClassificationRequest, ZeroShotClassificationResponse}; -use common::configuration::{Overrides, PromptGuards, PromptTarget, Tracing}; +use common::configuration::{Overrides, PromptTarget, Tracing}; use common::consts::{ - ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, - ARCH_INTERNAL_CLUSTER_NAME, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, - ASSISTANT_ROLE, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, - DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, - HALLUCINATION_TEMPLATE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, - TRACE_PARENT_HEADER, USER_ROLE, ZEROSHOT_INTERNAL_HOST, -}; -use common::embeddings::{ - CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, + ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, + MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE, }; use common::errors::ServerError; use common::http::{CallArgs, Client}; use common::stats::Gauge; use derivative::Derivative; use http::StatusCode; -use log::{debug, info, trace, warn}; +use log::{debug, warn}; use proxy_wasm::traits::*; use serde_yaml::Value; use std::cell::RefCell; @@ -41,13 +24,8 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; #[derive(Debug, Clone)] pub enum ResponseHandlerType { - Embeddings, ArchFC, FunctionCall, - ZeroShotIntent, - Hallucination, - ArchGuard, - DefaultTarget, } #[derive(Clone, Derivative)] @@ -66,8 +44,7 @@ pub struct StreamCallContext { pub struct StreamContext { system_prompt: Rc>, pub prompt_targets: Rc>, - pub embeddings_store: Option>, - overrides: Rc>, + _overrides: Rc>, pub metrics: Rc, pub callouts: RefCell>, pub context_id: u32, @@ -79,12 +56,11 @@ pub struct StreamContext { pub streaming_response: bool, pub is_chat_completions_request: bool, pub chat_completions_request: Option, - pub prompt_guards: Rc, pub request_id: Option, pub start_upstream_llm_request_time: u128, pub time_to_first_token: Option, pub traceparent: Option, - pub tracing: Rc>, + pub _tracing: Rc>, } impl StreamContext { @@ -94,9 +70,7 @@ impl StreamContext { metrics: Rc, system_prompt: Rc>, prompt_targets: Rc>, - prompt_guards: Rc, overrides: Rc>, - embeddings_store: Option>, tracing: Rc>, ) -> Self { StreamContext { @@ -104,7 +78,6 @@ impl StreamContext { metrics, system_prompt, prompt_targets, - embeddings_store, callouts: RefCell::new(HashMap::new()), chat_completions_request: None, tool_calls: None, @@ -114,32 +87,15 @@ impl StreamContext { streaming_response: false, user_prompt: None, is_chat_completions_request: false, - prompt_guards, - overrides, + _overrides: overrides, request_id: None, traceparent: None, - tracing, + _tracing: tracing, start_upstream_llm_request_time: 0, time_to_first_token: None, } } - fn embeddings_store(&self) -> &EmbeddingsStore { - self.embeddings_store.as_ref().unwrap() - } - - pub fn is_embedding_store_initialized(&self) -> bool { - if self.embeddings_store.as_ref().is_none() { - return false; - } - - if self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len() { - return true; - } - - false - } - pub fn send_server_error(&self, error: ServerError, override_status_code: Option) { self.send_http_response( override_status_code @@ -151,190 +107,8 @@ impl StreamContext { ); } - pub fn get_embeddings(&mut self, callout_context: StreamCallContext) { - let user_message = callout_context.user_message.unwrap(); - let get_embeddings_input = CreateEmbeddingRequest { - // Need to clone into input because user_message is used below. - input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())), - model: String::from(DEFAULT_EMBEDDING_MODEL), - encoding_format: None, - dimensions: None, - user: None, - }; - - let embeddings_request_str: String = match serde_json::to_string(&get_embeddings_input) { - Ok(json_data) => json_data, - Err(error) => { - warn!("error serializing get embeddings request: {}", error); - return self.send_server_error(ServerError::Deserialization(error), None); - } - }; - - let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST), - (":method", "POST"), - (":path", "/embeddings"), - (":authority", EMBEDDINGS_INTERNAL_HOST), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - if self.trace_arch_internal() && self.traceparent.is_some() { - headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/embeddings", - headers, - Some(embeddings_request_str.as_bytes()), - vec![], - Duration::from_secs(5), - ); - let call_context = StreamCallContext { - response_handler_type: ResponseHandlerType::Embeddings, - user_message: Some(user_message), - prompt_target_name: None, - request_body: callout_context.request_body, - similarity_scores: None, - upstream_cluster: None, - upstream_cluster_path: None, - }; - - debug!( - "archgw => get embeddings request: {}", - embeddings_request_str - ); - if let Err(e) = self.http_call(call_args, call_context) { - warn!("error dispatching get embeddings request: {}", e); - self.send_server_error(ServerError::HttpDispatch(e), None); - } - } - - pub fn embeddings_handler(&mut self, body: Vec, mut callout_context: StreamCallContext) { - let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) { - Ok(embedding_response) => embedding_response, - Err(e) => { - warn!("error deserializing embedding response: {}", e); - return self.send_server_error(ServerError::Deserialization(e), None); - } - }; - - let prompt_embeddings_vector = &embedding_response.data[0].embedding; - - trace!( - "embedding model: {}, vector length: {:?}", - embedding_response.model, - prompt_embeddings_vector.len() - ); - - let prompt_target_names = self - .prompt_targets - .iter() - // exclude default target - .filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false)) - .map(|(name, _)| name.clone()) - .collect(); - - let similarity_scores: Vec<(String, f64)> = self - .prompt_targets - .iter() - // exclude default prompt target - .filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false)) - .map(|(prompt_name, _)| { - let pte = match self.embeddings_store().get(prompt_name) { - Some(embeddings) => embeddings, - None => { - warn!( - "embeddings not found for prompt target name: {}", - prompt_name - ); - return (prompt_name.clone(), 0.0); - } - }; - - let description_embeddings = match pte.get(&EmbeddingType::Description) { - Some(embeddings) => embeddings, - None => { - warn!( - "description embeddings not found for prompt target name: {}", - prompt_name - ); - return (prompt_name.clone(), 0.0); - } - }; - let similarity_score_description = - cos::cosine_similarity(&prompt_embeddings_vector, &description_embeddings); - (prompt_name.clone(), similarity_score_description) - }) - .collect(); - - debug!( - "similarity scores based on description embeddings match: {:?}", - similarity_scores - ); - - callout_context.similarity_scores = Some(similarity_scores); - - let zero_shot_classification_request = ZeroShotClassificationRequest { - // Need to clone into input because user_message is used below. - input: callout_context.user_message.as_ref().unwrap().clone(), - model: String::from(DEFAULT_INTENT_MODEL), - labels: prompt_target_names, - }; - - let json_data: String = match serde_json::to_string(&zero_shot_classification_request) { - Ok(json_data) => json_data, - Err(error) => { - debug!( - "error serializing zero shot classification request: {}", - error - ); - return self.send_server_error(ServerError::Serialization(error), None); - } - }; - - let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, ZEROSHOT_INTERNAL_HOST), - (":method", "POST"), - (":path", "/zeroshot"), - (":authority", ZEROSHOT_INTERNAL_HOST), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - if self.trace_arch_internal() && self.traceparent.is_some() { - headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/zeroshot", - headers, - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ); - callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent; - - if let Err(e) = self.http_call(call_args, callout_context) { - warn!("error dispatching zero shot classification request: {}", e); - self.send_server_error(ServerError::HttpDispatch(e), None); - } - } - - fn trace_arch_internal(&self) -> bool { - match self.tracing.as_ref() { + fn _trace_arch_internal(&self) -> bool { + match self._tracing.as_ref() { Some(tracing) => match tracing.trace_arch_internal.as_ref() { Some(trace_arch_internal) => *trace_arch_internal, None => false, @@ -343,364 +117,7 @@ impl StreamContext { } } - pub fn hallucination_classification_resp_handler( - &mut self, - body: Vec, - callout_context: StreamCallContext, - ) { - let body_str = String::from_utf8(body).expect("could not convert body to string"); - debug!("archgw <= hallucination response: {}", body_str); - let hallucination_response: HallucinationClassificationResponse = - match serde_json::from_str(body_str.as_str()) { - Ok(hallucination_response) => hallucination_response, - Err(e) => { - warn!( - "error deserializing hallucination response: {}, body: {}", - e, - body_str.as_str() - ); - return self.send_server_error(ServerError::Deserialization(e), None); - } - }; - let mut keys_with_low_score: Vec = Vec::new(); - for (key, value) in &hallucination_response.params_scores { - if *value < DEFAULT_HALLUCINATED_THRESHOLD { - debug!( - "hallucination detected: score for {} : {} is less than threshold {}", - key, value, DEFAULT_HALLUCINATED_THRESHOLD - ); - keys_with_low_score.push(key.clone().to_string()); - } - } - - if !keys_with_low_score.is_empty() { - let response = - HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?"; - - let response_str = if self.streaming_response { - let chunks = vec![ - ChatCompletionStreamResponse::new( - None, - Some(ASSISTANT_ROLE.to_string()), - Some(ARCH_FC_MODEL_NAME.to_owned()), - None, - ), - ChatCompletionStreamResponse::new( - Some(response), - None, - Some(ARCH_FC_MODEL_NAME.to_owned()), - None, - ), - ]; - - to_server_events(chunks) - } else { - let chat_completion_response = ChatCompletionsResponse::new(response); - serde_json::to_string(&chat_completion_response).unwrap() - }; - debug!("hallucination response: {:?}", response_str); - // make sure on_http_response_body does not attach tool calls and tool response to the response - self.tool_calls = None; - self.send_http_response( - StatusCode::OK.as_u16().into(), - vec![], - Some(response_str.as_bytes()), - ); - } else { - // not a hallucination, resume the flow - self.schedule_api_call_request(callout_context); - } - } - - pub fn zero_shot_intent_detection_resp_handler( - &mut self, - body: Vec, - mut callout_context: StreamCallContext, - ) { - let zeroshot_intent_response: ZeroShotClassificationResponse = - match serde_json::from_slice(&body) { - Ok(zeroshot_response) => zeroshot_response, - Err(e) => { - warn!( - "error deserializing zero shot classification response: {}", - e - ); - return self.send_server_error(ServerError::Deserialization(e), None); - } - }; - - trace!( - "zeroshot intent response: {}", - serde_json::to_string(&zeroshot_intent_response).unwrap() - ); - - let desc_emb_similarity_map: HashMap = callout_context - .similarity_scores - .clone() - .unwrap() - .into_iter() - .collect(); - - let pred_class_desc_emb_similarity = desc_emb_similarity_map - .get(&zeroshot_intent_response.predicted_class) - .unwrap(); - - let prompt_target_similarity_score = zeroshot_intent_response.predicted_class_score * 0.7 - + pred_class_desc_emb_similarity * 0.3; - - debug!( - "similarity score: {:.3}, intent score: {:.3}, description embedding score: {:.3}, prompt: {}", - prompt_target_similarity_score, - zeroshot_intent_response.predicted_class_score, - pred_class_desc_emb_similarity, - callout_context.user_message.as_ref().unwrap() - ); - - let prompt_target_name = zeroshot_intent_response.predicted_class.clone(); - - // Check to see who responded to user message. This will help us identify if control should be passed to Arch FC or not. - // If the last message was from Arch FC, then Arch FC is handling the conversation (possibly for parameter collection). - let mut arch_assistant = false; - let messages = &callout_context.request_body.messages; - if messages.len() >= 2 { - let latest_assistant_message = &messages[messages.len() - 2]; - if let Some(model) = latest_assistant_message.model.as_ref() { - if model.contains(ARCH_MODEL_PREFIX) { - arch_assistant = true; - } - } - } else { - debug!("no assistant message found, probably first interaction"); - } - - // get prompt target similarity thresold from overrides - let prompt_target_intent_matching_threshold = match self.overrides.as_ref() { - Some(overrides) => match overrides.prompt_target_intent_matching_threshold { - Some(threshold) => threshold, - None => DEFAULT_PROMPT_TARGET_THRESHOLD, - }, - None => DEFAULT_PROMPT_TARGET_THRESHOLD, - }; - - // check to ensure that the prompt target similarity score is above the threshold - if prompt_target_similarity_score < prompt_target_intent_matching_threshold - || arch_assistant - { - debug!("intent score is low or arch assistant is handling the conversation"); - // if arch fc responded to the user message, then we don't need to check the similarity score - // it may be that arch fc is handling the conversation for parameter collection - if arch_assistant { - info!("arch fc is engaged in parameter collection"); - } else if let Some(default_prompt_target) = self - .prompt_targets - .values() - .find(|pt| pt.default.unwrap_or(false)) - { - debug!("default prompt target found, forwarding request to default prompt target"); - let endpoint = default_prompt_target.endpoint.clone().unwrap(); - let upstream_path: String = endpoint.path.unwrap_or(String::from("/")); - - let upstream_endpoint = endpoint.name; - let mut params = HashMap::new(); - params.insert( - MESSAGES_KEY.to_string(), - callout_context.request_body.messages.clone(), - ); - let arch_messages_json = serde_json::to_string(¶ms).unwrap(); - let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string(); - - let mut headers = vec![ - (":method", "POST"), - (ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint), - (":path", &upstream_path), - (":authority", &upstream_endpoint), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - if self.trace_arch_internal() && self.traceparent.is_some() { - headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - &upstream_path, - headers, - Some(arch_messages_json.as_bytes()), - vec![], - Duration::from_secs(5), - ); - callout_context.response_handler_type = ResponseHandlerType::DefaultTarget; - callout_context.prompt_target_name = Some(default_prompt_target.name.clone()); - - if let Err(e) = self.http_call(call_args, callout_context) { - warn!("error dispatching default prompt target request: {}", e); - return self.send_server_error( - ServerError::HttpDispatch(e), - Some(StatusCode::BAD_REQUEST), - ); - } - return; - } else { - // if no default prompt target is found and similarity score is low send response to upstream llm - // removing tool calls and tool response - - let messages = self.filter_out_arch_messages(&callout_context); - - let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest { - model: callout_context.request_body.model, - messages, - tools: None, - stream: callout_context.request_body.stream, - stream_options: callout_context.request_body.stream_options, - metadata: None, - }; - - let llm_request_str = match serde_json::to_string(&chat_completions_request) { - Ok(json_string) => json_string, - Err(e) => { - return self.send_server_error(ServerError::Serialization(e), None); - } - }; - debug!( - "archgw (low similarity score) => llm request: {}", - llm_request_str - ); - - self.set_http_request_body( - 0, - self.request_body_size, - &llm_request_str.into_bytes(), - ); - - self.resume_http_request(); - return; - } - } - - let prompt_target = self - .prompt_targets - .get(&prompt_target_name) - .expect("prompt target not found") - .clone(); - - let mut chat_completion_tools: Vec = Vec::new(); - for pt in self.prompt_targets.values() { - if pt.default.unwrap_or_default() { - continue; - } - // only extract entity names - let properties: HashMap = match pt.parameters { - // Clone is unavoidable here because we don't want to move the values out of the prompt target struct. - Some(ref entities) => { - let mut properties: HashMap = HashMap::new(); - for entity in entities.iter() { - let param = FunctionParameter { - parameter_type: ParameterType::from( - entity.parameter_type.clone().unwrap_or("str".to_string()), - ), - description: entity.description.clone(), - required: entity.required, - enum_values: entity.enum_values.clone(), - default: entity.default.clone(), - }; - properties.insert(entity.name.clone(), param); - } - properties - } - None => HashMap::new(), - }; - let tools_parameters = FunctionParameters { properties }; - - chat_completion_tools.push({ - ChatCompletionTool { - tool_type: ToolType::Function, - function: FunctionDefinition { - name: pt.name.clone(), - description: pt.description.clone(), - parameters: tools_parameters, - }, - } - }); - } - - // archfc handler needs state so it can expand tool calls - let mut metadata = HashMap::new(); - metadata.insert( - ARCH_STATE_HEADER.to_string(), - serde_json::to_string(&self.arch_state).unwrap(), - ); - - let chat_completions = ChatCompletionsRequest { - model: self - .chat_completions_request - .as_ref() - .unwrap() - .model - .clone(), - messages: callout_context.request_body.messages.clone(), - tools: Some(chat_completion_tools), - stream: false, - stream_options: None, - metadata: Some(metadata), - }; - - let msg_body = match serde_json::to_string(&chat_completions) { - Ok(msg_body) => msg_body, - Err(e) => { - warn!("error serializing arch_fc request body: {}", e); - return self.send_server_error(ServerError::Serialization(e), None); - } - }; - - let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string(); - - let mut headers = vec![ - (":method", "POST"), - (ARCH_UPSTREAM_HOST_HEADER, ARCH_FC_INTERNAL_HOST), - (":path", "/v1/chat/completions"), - (":authority", ARCH_FC_INTERNAL_HOST), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - if self.trace_arch_internal() && self.traceparent.is_some() { - headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/v1/chat/completions", - headers, - Some(msg_body.as_bytes()), - vec![], - Duration::from_secs(5), - ); - callout_context.response_handler_type = ResponseHandlerType::ArchFC; - callout_context.prompt_target_name = Some(prompt_target.name); - - debug!("archgw => archfc request: {}", msg_body); - if let Err(e) = self.http_call(call_args, callout_context) { - debug!("error dispatching arch_fc request: {}", e); - self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST)); - } - } - - pub fn arch_fc_response_handler( - &mut self, - body: Vec, - mut callout_context: StreamCallContext, - ) { + pub fn arch_fc_response_handler(&mut self, body: Vec, callout_context: StreamCallContext) { let body_str = String::from_utf8(body).unwrap(); debug!("archgw <= archfc response: {}", body_str); @@ -767,114 +184,7 @@ impl StreamContext { ); } - // TODO CO: pass nli check - let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone(); - let prompt_target = self - .prompt_targets - .get(&tools_call_name) - .expect("prompt target not found for tool call") - .clone(); - - debug!( - "prompt_target_name: {}, tool_name(s): {:?}", - prompt_target.name, - self.tool_calls - .as_ref() - .unwrap() - .iter() - .map(|tc| tc.function.name.clone()) - .collect::>(), - ); - - // If hallucination, pass chat template to check parameters - //HACK: for now we only support one tool call, we will support multiple tool calls in the future - - let mut tool_params = self.tool_calls.as_ref().unwrap()[0] - .function - .arguments - .clone(); - let tool_params_json_str = serde_json::to_string(&tool_params).unwrap(); - debug!( - "tool_params (without messages history): {}", - tool_params_json_str - ); - tool_params.insert( - String::from(MESSAGES_KEY), - serde_yaml::to_value(&callout_context.request_body.messages).unwrap(), - ); - let tool_params_json_str = serde_json::to_string(&tool_params).unwrap(); - - use serde_json::Value; - let v: Value = serde_json::from_str(&tool_params_json_str).unwrap(); - let tool_params_dict: HashMap = match v.as_object() { - Some(obj) => obj - .iter() - .map(|(key, value)| { - // Convert each value to a string, regardless of its type - (key.clone(), value.to_string()) - }) - .collect(), - None => HashMap::new(), // Return an empty HashMap if v is not an object - }; - - let all_user_messages = - extract_messages_for_hallucination(&callout_context.request_body.messages); - let user_messages_str = all_user_messages.join(", "); - debug!("user messages: {}", user_messages_str); - - let hallucination_classification_request = HallucinationClassificationRequest { - prompt: user_messages_str, - model: String::from(DEFAULT_INTENT_MODEL), - parameters: tool_params_dict, - }; - - let hallucination_request_str: String = - match serde_json::to_string(&hallucination_classification_request) { - Ok(json_data) => json_data, - Err(error) => { - debug!( - "error serializing hallucination classification request: {}", - error - ); - return self.send_server_error(ServerError::Serialization(error), None); - } - }; - - let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, HALLUCINATION_INTERNAL_HOST), - (":method", "POST"), - (":path", "/hallucination"), - (":authority", HALLUCINATION_INTERNAL_HOST), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - if self.trace_arch_internal() && self.traceparent.is_some() { - headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/hallucination", - headers, - Some(hallucination_request_str.as_bytes()), - vec![], - Duration::from_secs(5), - ); - callout_context.response_handler_type = ResponseHandlerType::Hallucination; - - debug!( - "archgw => hallucination request: {}", - hallucination_request_str - ); - if let Err(e) = self.http_call(call_args, callout_context) { - self.send_server_error(ServerError::HttpDispatch(e), None); - } + self.schedule_api_call_request(callout_context); } fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) { @@ -1093,178 +403,6 @@ impl StreamContext { messages } - pub fn arch_guard_handler(&mut self, body: Vec, callout_context: StreamCallContext) { - let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap(); - debug!( - "archgw <= archguard response: {:?}", - serde_json::to_string(&prompt_guard_resp) - ); - - if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() { - //TODO: handle other scenarios like forward to error target - let msg = self - .prompt_guards - .jailbreak_on_exception_message() - .unwrap_or("refrain from discussing jailbreaking."); - info!("jailbreak detected: {}", msg); - - let response_str = if self.streaming_response { - let chunks = vec![ - ChatCompletionStreamResponse::new( - None, - Some(ASSISTANT_ROLE.to_string()), - Some(ARCH_FC_MODEL_NAME.to_owned()), - None, - ), - ChatCompletionStreamResponse::new( - Some(msg.to_string()), - None, - Some(ARCH_FC_MODEL_NAME.to_owned()), - None, - ), - ]; - - to_server_events(chunks) - } else { - let chat_completion_response = ChatCompletionsResponse::new(msg.to_string()); - serde_json::to_string(&chat_completion_response).unwrap() - }; - - self.send_http_response( - StatusCode::OK.as_u16().into(), - vec![], - Some(response_str.as_bytes()), - ); - - return self.send_server_error( - ServerError::Jailbreak(String::from(msg)), - Some(StatusCode::BAD_REQUEST), - ); - } - - self.get_embeddings(callout_context); - } - - pub fn default_target_handler(&self, body: Vec, mut callout_context: StreamCallContext) { - let prompt_target = self - .prompt_targets - .get(callout_context.prompt_target_name.as_ref().unwrap()) - .unwrap() - .clone(); - - // check if the default target should be dispatched to the LLM provider - if !prompt_target - .auto_llm_dispatch_on_response - .unwrap_or_default() - { - let default_target_response_str = if self.streaming_response { - let chat_completion_response = - match serde_json::from_slice::(&body) { - Ok(chat_completion_response) => chat_completion_response, - Err(e) => { - warn!( - "error deserializing default target response: {}, body str: {}", - e, - String::from_utf8(body).unwrap() - ); - return self.send_server_error(ServerError::Deserialization(e), None); - } - }; - - let chunks = vec![ - ChatCompletionStreamResponse::new( - None, - Some(ASSISTANT_ROLE.to_string()), - Some(chat_completion_response.model.clone()), - None, - ), - ChatCompletionStreamResponse::new( - chat_completion_response.choices[0].message.content.clone(), - None, - Some(chat_completion_response.model.clone()), - None, - ), - ]; - - to_server_events(chunks) - } else { - String::from_utf8(body).unwrap() - }; - - self.send_http_response( - StatusCode::OK.as_u16().into(), - vec![], - Some(default_target_response_str.as_bytes()), - ); - return; - } - - let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) { - Ok(chat_completions_resp) => chat_completions_resp, - Err(e) => { - warn!( - "error deserializing default target response: {}, body str: {}", - e, - String::from_utf8(body).unwrap() - ); - return self.send_server_error(ServerError::Deserialization(e), None); - } - }; - - let mut messages = Vec::new(); - // add system prompt - match prompt_target.system_prompt.as_ref() { - None => {} - Some(system_prompt) => { - let system_prompt_message = Message { - role: SYSTEM_ROLE.to_string(), - content: Some(system_prompt.clone()), - model: None, - tool_calls: None, - tool_call_id: None, - }; - messages.push(system_prompt_message); - } - } - - messages.append(&mut callout_context.request_body.messages); - - let api_resp = chat_completions_resp.choices[0] - .message - .content - .as_ref() - .unwrap(); - - let user_message = messages.pop().unwrap(); - let message = format!("{}\ncontext: {}", user_message.content.unwrap(), api_resp); - messages.push(Message { - role: USER_ROLE.to_string(), - content: Some(message), - model: None, - tool_calls: None, - tool_call_id: None, - }); - - let chat_completion_request = ChatCompletionsRequest { - model: self - .chat_completions_request - .as_ref() - .unwrap() - .model - .clone(), - messages, - tools: None, - stream: callout_context.request_body.stream, - stream_options: callout_context.request_body.stream_options, - metadata: None, - }; - - let json_resp = serde_json::to_string(&chat_completion_request).unwrap(); - debug!("archgw => (default target) llm request: {}", json_resp); - self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes()); - self.resume_http_request(); - } - pub fn generate_toll_call_message(&mut self) -> Message { Message { role: ASSISTANT_ROLE.to_string(), diff --git a/e2e_tests/api_model_server.rest b/e2e_tests/api_model_server.rest index e6ad1530..d92358d0 100644 --- a/e2e_tests/api_model_server.rest +++ b/e2e_tests/api_model_server.rest @@ -15,20 +15,76 @@ Content-Type: application/json ], "tools": [ { - "type": "function", - "function": { - "name": "weather_forecast", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "str" - }, - "days": { - "type": "int" + "id": "weather-112", + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get current weather at a location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "str", + "description": "The location to get the weather for", + "format": "City, State" + }, + "unit": { + "type": "str", + "description": "The unit to return the weather in.", + "enum": ["celsius", "fahrenheit"], + "default": "celsius" + }, + "days": { + "type": "str", + "description": "the number of days for the request." + } + }, + "required": ["location", "days"] + } + } + } + ] +} + +### talk to function calling endpoint +POST {{model_server_endpoint}}/function_calling HTTP/1.1 +Content-Type: application/json + +{ + "messages": [ + { + "role": "user", + "content": "how is the weather in seattle" + } + ], + "tools": [ + { + "id": "weather-112", + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get current weather at a location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "str", + "description": "The location to get the weather for", + "format": "City, State" + }, + "unit": { + "type": "str", + "description": "The unit to return the weather in.", + "enum": ["celsius", "fahrenheit"], + "default": "celsius" + }, + "days": { + "type": "str", + "description": "the number of days for the request." + } + }, + "required": ["location", "days"] } - }, - "required": ["city", "days"] } } }