From b18c8d313cc85b3f0d9c304cbb0ed1e74ae592d8 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 17 Oct 2024 20:52:04 -0700 Subject: [PATCH] Send back developer error correctly --- crates/llm_gateway/src/llm_stream_context.rs | 4 +-- .../src/prompt_stream_context.rs | 36 ++++++++++++++----- demos/function_calling/api_server/app/main.py | 4 ++- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/crates/llm_gateway/src/llm_stream_context.rs b/crates/llm_gateway/src/llm_stream_context.rs index 6c585a72..ea6357f3 100644 --- a/crates/llm_gateway/src/llm_stream_context.rs +++ b/crates/llm_gateway/src/llm_stream_context.rs @@ -335,8 +335,8 @@ impl HttpContext for LlmGatewayStreamContext { Ok(de) => de, Err(e) => { debug!("invalid response: {}", String::from_utf8_lossy(&body)); - self.send_server_error(ServerError::Deserialization(e), None); - return Action::Pause; + // self.send_server_error(ServerError::Deserialization(e), None); + return Action::Continue; } }; diff --git a/crates/prompt_gateway/src/prompt_stream_context.rs b/crates/prompt_gateway/src/prompt_stream_context.rs index d208f5e8..0055ba92 100644 --- a/crates/prompt_gateway/src/prompt_stream_context.rs +++ b/crates/prompt_gateway/src/prompt_stream_context.rs @@ -32,6 +32,7 @@ use sha2::{Digest, Sha256}; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; +use std::str::FromStr; use std::time::Duration; use common::stats::IncrementingMetric; @@ -69,11 +70,12 @@ pub enum ServerError { Serialization(serde_json::Error), #[error("{0}")] LogicError(String), - #[error("upstream error response authority={authority}, path={path}, status={status}")] + #[error("upstream application error host={host}, path={path}, status={status}, body={body}")] Upstream { - authority: String, + host: String, path: String, status: String, + body: String, }, #[error("jailbreak detected: {0}")] Jailbreak(String), @@ -149,7 +151,6 @@ impl PromptStreamContext { } fn send_server_error(&self, error: ServerError, override_status_code: Option) { - debug!("server error occurred: {}", error); self.send_http_response( override_status_code .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) @@ -164,6 +165,7 @@ impl PromptStreamContext { let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) { Ok(embedding_response) => embedding_response, Err(e) => { + debug!("error deserializing embedding response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -234,6 +236,7 @@ impl PromptStreamContext { let json_data: String = match serde_json::to_string(&zero_shot_classification_request) { Ok(json_data) => json_data, Err(error) => { + debug!("error serializing zero shot classification request: {}", error); return self.send_server_error(ServerError::Serialization(error), None); } }; @@ -263,6 +266,7 @@ impl PromptStreamContext { callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent; if let Err(e) = self.http_call(call_args, callout_context) { + debug!("error dispatching zero shot classification request: {}", e); self.send_server_error(ServerError::HttpDispatch(e), None); } } @@ -276,6 +280,7 @@ impl PromptStreamContext { match serde_json::from_slice(&body) { Ok(hallucination_response) => hallucination_response, Err(e) => { + debug!("error deserializing hallucination response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -339,6 +344,7 @@ impl PromptStreamContext { match serde_json::from_slice(&body) { Ok(zeroshot_response) => zeroshot_response, Err(e) => { + debug!("error deserializing zero shot classification response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -450,6 +456,7 @@ impl PromptStreamContext { callout_context.prompt_target_name = Some(default_prompt_target.name.clone()); if let Err(e) = self.http_call(call_args, callout_context) { + debug!("error dispatching default prompt target request: {}", e); return self.send_server_error( ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST), @@ -465,6 +472,7 @@ impl PromptStreamContext { let prompt_target = match self.prompt_targets.get(&prompt_target_name) { Some(prompt_target) => prompt_target.clone(), None => { + debug!("prompt target not found: {}", prompt_target_name); return self.send_server_error( ServerError::LogicError(format!( "Prompt target not found: {prompt_target_name}" @@ -537,6 +545,7 @@ impl PromptStreamContext { msg_body } Err(e) => { + debug!("error serializing arch_fc request body: {}", e); return self.send_server_error(ServerError::Serialization(e), None); } }; @@ -569,6 +578,7 @@ impl PromptStreamContext { callout_context.prompt_target_name = Some(prompt_target.name); if let Err(e) = self.http_call(call_args, callout_context) { + debug!("error dispatching arch_fc request: {}", e); self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST)); } } @@ -580,6 +590,7 @@ impl PromptStreamContext { let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) { Ok(arch_fc_response) => arch_fc_response, Err(e) => { + debug!("error deserializing arch_fc response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -693,6 +704,7 @@ impl PromptStreamContext { match serde_json::to_string(&hallucination_classification_request) { Ok(json_data) => json_data, Err(error) => { + debug!("error serializing hallucination classification request: {}", error); return self.send_server_error(ServerError::Serialization(error), None); } }; @@ -789,13 +801,15 @@ impl PromptStreamContext { ) { if let Some(http_status) = self.get_http_call_response_header(":status") { if http_status != StatusCode::OK.as_str() { + debug!("upstream error response: {}", http_status); return self.send_server_error( ServerError::Upstream { - authority: callout_context.upstream_cluster.unwrap(), + host: callout_context.upstream_cluster.unwrap(), path: callout_context.upstream_cluster_path.unwrap(), - status: http_status, + status: http_status.clone(), + body: String::from_utf8(body).unwrap(), }, - None, + Some(StatusCode::from_str(http_status.as_str()).unwrap()), ); } } else { @@ -893,6 +907,7 @@ impl PromptStreamContext { .prompt_guards .jailbreak_on_exception_message() .unwrap_or("refrain from discussing jailbreaking."); + debug!("jailbreak detected: {}", msg); return self.send_server_error( ServerError::Jailbreak(String::from(msg)), Some(StatusCode::BAD_REQUEST), @@ -916,6 +931,7 @@ impl PromptStreamContext { let json_data: String = match serde_json::to_string(&get_embeddings_input) { Ok(json_data) => json_data, Err(error) => { + debug!("error serializing get embeddings request: {}", error); return self.send_server_error(ServerError::Deserialization(error), None); } }; @@ -952,6 +968,7 @@ impl PromptStreamContext { }; if let Err(e) = self.http_call(call_args, call_context) { + debug!("error dispatching get embeddings request: {}", e); self.send_server_error(ServerError::HttpDispatch(e), None); } } @@ -985,6 +1002,7 @@ impl PromptStreamContext { let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) { Ok(chat_completions_resp) => chat_completions_resp, Err(e) => { + debug!("error deserializing default target response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -1260,9 +1278,9 @@ impl HttpContext for PromptStreamContext { match serde_json::from_slice(&body) { Ok(de) => de, Err(e) => { - debug!("invalid response: {}", String::from_utf8_lossy(&body)); - self.send_server_error(ServerError::Deserialization(e), None); - return Action::Pause; + debug!("invalid response: {}, {}", String::from_utf8_lossy(&body), e); + // self.send_server_error(ServerError::Deserialization(e), None); + return Action::Continue; } }; diff --git a/demos/function_calling/api_server/app/main.py b/demos/function_calling/api_server/app/main.py index 041a921d..b692ef70 100644 --- a/demos/function_calling/api_server/app/main.py +++ b/demos/function_calling/api_server/app/main.py @@ -1,3 +1,4 @@ +from fastapi import FastAPI, HTTPException import json import random from fastapi import FastAPI, Response @@ -45,7 +46,8 @@ async def weather(req: WeatherRequest, res: Response): } ) - return weather_forecast + raise HTTPException(status_code=404, detail="some error") + # return weather_forecast class InsuranceClaimDetailsRequest(BaseModel):