diff --git a/arch/docker-compose.dev.yaml b/arch/docker-compose.dev.yaml index b019f52f..8c8ce464 100644 --- a/arch/docker-compose.dev.yaml +++ b/arch/docker-compose.dev.yaml @@ -5,28 +5,9 @@ services: - "10000:10000" - "19901:9901" volumes: - - ${ARCH_CONFIG_FILE:-./demos/function_calling/arch_config.yaml}:/config/arch_config.yaml + - ${ARCH_CONFIG_FILE:-../demos/function_calling/arch_config.yaml}:/config/arch_config.yaml - /etc/ssl/cert.pem:/etc/ssl/cert.pem - ./envoy.template.dev.yaml:/config/envoy.template.yaml - ./target/wasm32-wasi/release/intelligent_prompt_gateway.wasm:/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm - depends_on: - model_server: - condition: service_healthy env_file: - stage.env - - model_server: - image: model_server:latest - ports: - - "18081:80" - healthcheck: - test: ["CMD", "curl" ,"http://localhost/healthz"] - interval: 5s - retries: 20 - volumes: - - ~/.cache/huggingface:/root/.cache/huggingface - environment: - - OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal} - - OLLAMA_MODEL=Arch-Function-Calling-3B-Q4_K_M - - MODE=${MODE:-cloud} - - FC_URL=${FC_URL:-https://arch-fc-free-trial-4mzywewe.uc.gateway.dev/v1} diff --git a/arch/envoy.template.dev.yaml b/arch/envoy.template.dev.yaml index 6b9d82e1..9e83cf44 100644 --- a/arch/envoy.template.dev.yaml +++ b/arch/envoy.template.dev.yaml @@ -13,7 +13,7 @@ static_resources: typed_config: "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager stat_prefix: arch_ingress_http - codec_type: HTTP1 + codec_type: AUTO scheme_header_transformation: scheme_to_overwrite: https access_log: @@ -72,11 +72,6 @@ static_resources: type: LOGICAL_DNS dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN - # typed_extension_protocol_options: - # envoy.extensions.upstreams.http.v3.HttpProtocolOptions: - # "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions - # explicit_http_config: - # http2_protocol_options: {} load_assignment: cluster_name: openai endpoints: @@ -129,7 +124,7 @@ static_resources: address: socket_address: address: host.docker.internal - port_value: 8000 + port_value: 51000 hostname: "model_server" - name: mistral_7b_instruct connect_timeout: 5s @@ -159,7 +154,7 @@ static_resources: address: socket_address: address: host.docker.internal - port_value: 8000 + port_value: 51000 hostname: "arch_fc" {% for _, cluster in arch_clusters.items() %} - name: {{ cluster.name }} diff --git a/arch/src/consts.rs b/arch/src/consts.rs index 5cf0478e..fe02a876 100644 --- a/arch/src/consts.rs +++ b/arch/src/consts.rs @@ -1,6 +1,7 @@ pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5"; pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli"; pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8; +pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.1; pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector"; pub const SYSTEM_ROLE: &str = "system"; pub const USER_ROLE: &str = "user"; @@ -13,3 +14,4 @@ pub const ARCH_MESSAGES_KEY: &str = "arch_messages"; pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint"; pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions"; pub const ARCH_STATE_HEADER: &str = "x-arch-state"; +pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B"; diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs index 9ca8e5fd..5f5b9e4d 100644 --- a/arch/src/stream_context.rs +++ b/arch/src/stream_context.rs @@ -1,7 +1,8 @@ use crate::consts::{ - ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, - ARCH_STATE_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, - DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, + ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, + ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, + DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, + DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE, }; use crate::filter_context::{EmbeddingsStore, WasmMetrics}; @@ -17,12 +18,13 @@ use proxy_wasm::traits::*; use proxy_wasm::types::*; use public_types::common_types::open_ai::{ ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest, - ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message, - ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType, + ChatCompletionsResponse, Choice, FunctionDefinition, FunctionParameter, FunctionParameters, + Message, ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType, }; use public_types::common_types::{ - EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask, - ZeroShotClassificationRequest, ZeroShotClassificationResponse, + EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse, + PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest, + ZeroShotClassificationResponse, }; use public_types::configuration::LlmProvider; use public_types::configuration::{Overrides, PromptGuards, PromptTarget}; @@ -37,22 +39,24 @@ use std::num::NonZero; use std::rc::Rc; use std::time::Duration; -#[derive(Debug)] +#[derive(Debug, Clone)] enum ResponseHandlerType { GetEmbeddings, FunctionResolver, FunctionCall, ZeroShotIntent, + HallucinationDetect, ArchGuard, DefaultTarget, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct StreamCallContext { response_handler_type: ResponseHandlerType, user_message: Option, prompt_target_name: Option, request_body: ChatCompletionsRequest, + tool_calls: Option>, similarity_scores: Option>, upstream_cluster: Option, upstream_cluster_path: Option, @@ -310,6 +314,69 @@ impl StreamContext { } } + fn hallucination_classification_resp_handler( + &mut self, + body: Vec, + callout_context: StreamCallContext, + ) { + let hallucination_response: HallucinationClassificationResponse = + match serde_json::from_slice(&body) { + Ok(hallucination_response) => hallucination_response, + Err(e) => { + return self.send_server_error(ServerError::Deserialization(e), None); + } + }; + let mut keys_with_low_score: Vec = Vec::new(); + for (key, value) in &hallucination_response.params_scores { + if *value < DEFAULT_HALLUCINATED_THRESHOLD { + debug!( + "hallucination detected: score for {} : {} is less than threshold {}", + key, value, DEFAULT_HALLUCINATED_THRESHOLD + ); + keys_with_low_score.push(key.clone().to_string()); + } + } + + if !keys_with_low_score.is_empty() { + let response = + "It seems I’m missing some information. Could you provide the following details: " + .to_string() + + &keys_with_low_score.join(", ") + + " ?"; + let message = Message { + role: SYSTEM_ROLE.to_string(), + content: Some(response), + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: None, + }; + + let chat_completion_response = ChatCompletionsResponse { + choices: vec![Choice { + message, + index: 0, + finish_reason: "done".to_string(), + }], + usage: None, + model: ARCH_FC_MODEL_NAME.to_string(), + metadata: None, + }; + + debug!("hallucination response: {:?}", chat_completion_response); + self.send_http_response( + StatusCode::OK.as_u16().into(), + vec![("Powered-By", "Katanemo")], + Some( + serde_json::to_string(&chat_completion_response) + .unwrap() + .as_bytes(), + ), + ); + } else { + // not a hallucination, resume the flow + self.schedule_api_call_request(callout_context); + } + } + fn zero_shot_intent_detection_resp_handler( &mut self, body: Vec, @@ -565,6 +632,9 @@ impl StreamContext { let tool_calls = model_resp.message.tool_calls.as_ref().unwrap(); + // TODO CO: pass nli check + // If hallucination, pass chat template to check parameters + // extract all tool names let tool_names: Vec = tool_calls .iter() @@ -581,10 +651,11 @@ impl StreamContext { String::from(ARCH_MESSAGES_KEY), serde_yaml::to_value(&callout_context.request_body.messages).unwrap(), ); + let tools_call_name = tool_calls[0].function.name.clone(); let tool_params_json_str = serde_json::to_string(&tool_params).unwrap(); - let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone(); + callout_context.tool_calls = Some(tool_calls.clone()); debug!( "prompt_target_name: {}, tool_name(s): {:?}", @@ -592,6 +663,81 @@ impl StreamContext { ); debug!("tool_params: {}", tool_params_json_str); + if model_resp.message.tool_calls.is_some() + && !model_resp.message.tool_calls.as_ref().unwrap().is_empty() + { + use serde_json::Value; + let v: Value = serde_json::from_str(&tool_params_json_str).unwrap(); + let tool_params_dict: HashMap = match v.as_object() { + Some(obj) => obj + .iter() + .filter_map(|(key, value)| { + value + .as_str() + .map(|str_value| (key.clone(), str_value.to_string())) + }) + .collect(), + None => HashMap::new(), // Return an empty HashMap if v is not an object + }; + + let hallucination_classification_request = HallucinationClassificationRequest { + prompt: callout_context.user_message.as_ref().unwrap().clone(), + model: String::from(DEFAULT_INTENT_MODEL), + parameters: tool_params_dict, + }; + + let json_data: String = + match serde_json::to_string(&hallucination_classification_request) { + Ok(json_data) => json_data, + Err(error) => { + return self.send_server_error(ServerError::Serialization(error), None); + } + }; + let call_args = CallArgs::new( + MODEL_SERVER_NAME, + "/hallucination", + vec![ + (":method", "POST"), + (":path", "/hallucination"), + (":authority", MODEL_SERVER_NAME), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", "60000"), + ], + Some(json_data.as_bytes()), + vec![], + Duration::from_secs(5), + ); + callout_context.response_handler_type = ResponseHandlerType::HallucinationDetect; + + if let Err(e) = self.http_call(call_args, callout_context) { + self.send_server_error(ServerError::HttpDispatch(e), None); + } + } else { + self.schedule_api_call_request(callout_context); + } + } + + fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) { + let tools_call_name = callout_context.tool_calls.as_ref().unwrap()[0] + .function + .name + .clone(); + + let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone(); + + //HACK: for now we only support one tool call, we will support multiple tool calls in the future + let mut tool_params = callout_context.tool_calls.as_ref().unwrap()[0] + .function + .arguments + .clone(); + tool_params.insert( + String::from(ARCH_MESSAGES_KEY), + serde_yaml::to_value(&callout_context.request_body.messages).unwrap(), + ); + + let tool_params_json_str = serde_json::to_string(&tool_params).unwrap(); + let endpoint = prompt_target.endpoint.unwrap(); let path: String = endpoint.path.unwrap_or(String::from("/")); let call_args = CallArgs::new( @@ -612,8 +758,6 @@ impl StreamContext { callout_context.upstream_cluster_path = Some(path.clone()); callout_context.response_handler_type = ResponseHandlerType::FunctionCall; - self.tool_calls = Some(tool_calls.clone()); - if let Err(e) = self.http_call(call_args, callout_context) { self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST)); } @@ -806,6 +950,7 @@ impl StreamContext { similarity_scores: None, upstream_cluster: None, upstream_cluster_path: None, + tool_calls: None, }; if let Err(e) = self.http_call(call_args, call_context) { @@ -1009,6 +1154,7 @@ impl HttpContext for StreamContext { similarity_scores: None, upstream_cluster: None, upstream_cluster_path: None, + tool_calls: None, }; self.get_embeddings(callout_context); return Action::Pause; @@ -1057,6 +1203,7 @@ impl HttpContext for StreamContext { similarity_scores: None, upstream_cluster: None, upstream_cluster_path: None, + tool_calls: None, }; if let Err(e) = self.http_call(call_args, call_context) { @@ -1144,7 +1291,13 @@ impl HttpContext for StreamContext { } }; - self.response_tokens += chat_completions_response.usage.completion_tokens; + if chat_completions_response.usage.is_some() { + self.response_tokens += chat_completions_response + .usage + .as_ref() + .unwrap() + .completion_tokens; + } if let Some(tool_calls) = self.tool_calls.as_ref() { if !tool_calls.is_empty() { @@ -1239,6 +1392,9 @@ impl Context for StreamContext { ResponseHandlerType::ZeroShotIntent => { self.zero_shot_intent_detection_resp_handler(body, callout_context) } + ResponseHandlerType::HallucinationDetect => { + self.hallucination_classification_resp_handler(body, callout_context) + } ResponseHandlerType::FunctionResolver => { self.function_resolver_handler(body, callout_context) } diff --git a/arch/tests/integration.rs b/arch/tests/integration.rs index a7f7ae1d..9016b617 100644 --- a/arch/tests/integration.rs +++ b/arch/tests/integration.rs @@ -5,7 +5,7 @@ use proxy_wasm_test_framework::types::{ }; use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage}; use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType}; -use public_types::common_types::PromptGuardResponse; +use public_types::common_types::{HallucinationClassificationResponse, PromptGuardResponse}; use public_types::embeddings::{ create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding, @@ -534,9 +534,9 @@ fn request_ratelimited() { normal_flow(&mut module, filter_context, http_context); let arch_fc_resp = ChatCompletionsResponse { - usage: Usage { + usage: Some(Usage { completion_tokens: 0, - }, + }), choices: vec![Choice { finish_reason: "test".to_string(), index: 0, @@ -572,6 +572,38 @@ fn request_ratelimited() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) + .expect_http_call( + Some("model_server"), + Some(vec![ + (":method", "POST"), + (":path", "/hallucination"), + (":authority", "model_server"), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", "60000"), + ]), + None, + None, + None, + ) + .returning(Some(5)) + .expect_metric_increment("active_http_calls", 1) + .execute_and_expect(ReturnType::None) + .unwrap(); + + let hallucatination_body = HallucinationClassificationResponse { + params_scores: HashMap::from([("city".to_string(), 0.99)]), + model: "nli-model".to_string(), + }; + + let body_text = serde_json::to_string(&hallucatination_body).unwrap(); + + module + .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0) + .expect_metric_increment("active_http_calls", -1) + .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) + .returning(Some(&body_text)) + .expect_log(Some(LogLevel::Debug), None) .expect_http_call( Some("api_server"), Some(vec![ @@ -585,14 +617,14 @@ fn request_ratelimited() { None, None, ) - .returning(Some(5)) + .returning(Some(6)) .expect_metric_increment("active_http_calls", 1) .execute_and_expect(ReturnType::None) .unwrap(); let body_text = String::from("test body"); module - .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0) + .call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0) .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&body_text)) @@ -642,9 +674,9 @@ fn request_not_ratelimited() { normal_flow(&mut module, filter_context, http_context); let arch_fc_resp = ChatCompletionsResponse { - usage: Usage { + usage: Some(Usage { completion_tokens: 0, - }, + }), choices: vec![Choice { finish_reason: "test".to_string(), index: 0, @@ -680,6 +712,43 @@ fn request_not_ratelimited() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) + .expect_http_call( + Some("model_server"), + Some(vec![ + (":method", "POST"), + (":path", "/hallucination"), + (":authority", "model_server"), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", "60000"), + ]), + None, + None, + None, + ) + .returning(Some(5)) + .expect_metric_increment("active_http_calls", 1) + .execute_and_expect(ReturnType::None) + .unwrap(); + + // hallucination should return that parameters were not halliucinated + // prompt: str + // parameters: dict + // model: str + + let hallucatination_body = HallucinationClassificationResponse { + params_scores: HashMap::from([("city".to_string(), 0.99)]), + model: "nli-model".to_string(), + }; + + let body_text = serde_json::to_string(&hallucatination_body).unwrap(); + + module + .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0) + .expect_metric_increment("active_http_calls", -1) + .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) + .returning(Some(&body_text)) + .expect_log(Some(LogLevel::Debug), None) .expect_http_call( Some("api_server"), Some(vec![ @@ -693,14 +762,14 @@ fn request_not_ratelimited() { None, None, ) - .returning(Some(5)) + .returning(Some(6)) .expect_metric_increment("active_http_calls", 1) .execute_and_expect(ReturnType::None) .unwrap(); let body_text = String::from("test body"); module - .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0) + .call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0) .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&body_text)) diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py index eba24016..63c034ae 100644 --- a/chatbot_ui/app/run.py +++ b/chatbot_ui/app/run.py @@ -49,12 +49,17 @@ def predict(message, state): log.info("Error calling gateway API: {}".format(e.message)) raise gr.Error("Error calling gateway API: {}".format(e.message)) - log.debug("raw_response: ", raw_response.text) + log.info("raw_response: ", raw_response.text) response = raw_response.parse() # extract arch_state from metadata and store it in gradio session state # this state must be passed back to the gateway in the next request - arch_state = json.loads(raw_response.text).get('metadata', {}).get(ARCH_STATE_HEADER, None) + response_json = json.loads(raw_response.text) + arch_state = None + if response_json: + metadata = response_json.get('metadata', {}) + if metadata: + arch_state = metadata.get(ARCH_STATE_HEADER, None) if arch_state: state['arch_state'] = arch_state diff --git a/demos/function_calling/arch_config.yaml b/demos/function_calling/arch_config.yaml index a24ac507..fa381c0d 100644 --- a/demos/function_calling/arch_config.yaml +++ b/demos/function_calling/arch_config.yaml @@ -19,7 +19,7 @@ llm_providers: - name: open-ai-gpt-4 access_key: OPENAI_API_KEY provider: openai - model: gpt-4 + model: gpt-4o default: true system_prompt: | diff --git a/model_server/app/main.py b/model_server/app/main.py index 90066e7e..9f219eda 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -63,6 +63,7 @@ async def models(): @app.post("/embeddings") async def embedding(req: EmbeddingRequest, res: Response): + print(f"Embedding Call Start Time: {time.time()}") if req.model not in transformers: raise HTTPException(status_code=400, detail="unknown model: " + req.model) start = time.time() @@ -77,7 +78,7 @@ async def embedding(req: EmbeddingRequest, res: Response): "prompt_tokens": 0, "total_tokens": 0, } - + print(f"Embedding Call Complete Time: {time.time()}") return {"data": data, "model": req.model, "object": "list", "usage": usage} @@ -217,10 +218,10 @@ async def hallucination(req: HallucinationRequest, res: Response): ) result_score = result["scores"] result_params = {k[0]: s for k, s in zip(req.parameters.items(), result_score)} + logger.info(f"hallucination result: {result_params}") return { "params_scores": result_params, - "raw_result": result, "model": req.model, } diff --git a/public_types/src/common_types.rs b/public_types/src/common_types.rs index 4b338fc3..fb0f902c 100644 --- a/public_types/src/common_types.rs +++ b/public_types/src/common_types.rs @@ -227,7 +227,7 @@ pub mod open_ai { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionsResponse { - pub usage: Usage, + pub usage: Option, pub choices: Vec, pub model: String, pub metadata: Option>, @@ -272,6 +272,19 @@ pub struct ZeroShotClassificationResponse { pub model: String, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HallucinationClassificationRequest { + pub prompt: String, + pub parameters: HashMap, + pub model: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HallucinationClassificationResponse { + pub params_scores: HashMap, + pub model: String, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub enum PromptGuardTask { #[serde(rename = "jailbreak")]