diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 655f76ff..bd2fba5e 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -316,10 +316,9 @@ impl HttpContext for StreamContext { let chat_completions_response: ChatCompletionsResponse = match serde_json::from_slice(&body) { Ok(de) => de, - Err(e) => { + Err(_e) => { debug!("invalid response: {}", String::from_utf8_lossy(&body)); - self.send_server_error(ServerError::Deserialization(e), None); - return Action::Pause; + return Action::Continue; } }; diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 602f1629..da4d344f 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -33,6 +33,7 @@ use sha2::{Digest, Sha256}; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; +use std::str::FromStr; use std::time::Duration; use common::stats::IncrementingMetric; @@ -70,11 +71,12 @@ pub enum ServerError { Serialization(serde_json::Error), #[error("{0}")] LogicError(String), - #[error("upstream error response authority={authority}, path={path}, status={status}")] + #[error("upstream application error host={host}, path={path}, status={status}, body={body}")] Upstream { - authority: String, + host: String, path: String, status: String, + body: String, }, #[error("jailbreak detected: {0}")] Jailbreak(String), @@ -149,7 +151,6 @@ impl StreamContext { } fn send_server_error(&self, error: ServerError, override_status_code: Option) { - debug!("server error occurred: {}", error); self.send_http_response( override_status_code .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) @@ -164,6 +165,7 @@ impl StreamContext { let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) { Ok(embedding_response) => embedding_response, Err(e) => { + debug!("error deserializing embedding response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -234,6 +236,7 @@ impl StreamContext { let json_data: String = match serde_json::to_string(&zero_shot_classification_request) { Ok(json_data) => json_data, Err(error) => { + debug!("error serializing zero shot classification request: {}", error); return self.send_server_error(ServerError::Serialization(error), None); } }; @@ -263,6 +266,7 @@ impl StreamContext { callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent; if let Err(e) = self.http_call(call_args, callout_context) { + debug!("error dispatching zero shot classification request: {}", e); self.send_server_error(ServerError::HttpDispatch(e), None); } } @@ -276,6 +280,7 @@ impl StreamContext { match serde_json::from_slice(&body) { Ok(hallucination_response) => hallucination_response, Err(e) => { + debug!("error deserializing hallucination response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -339,6 +344,7 @@ impl StreamContext { match serde_json::from_slice(&body) { Ok(zeroshot_response) => zeroshot_response, Err(e) => { + debug!("error deserializing zero shot classification response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -450,6 +456,7 @@ impl StreamContext { callout_context.prompt_target_name = Some(default_prompt_target.name.clone()); if let Err(e) = self.http_call(call_args, callout_context) { + debug!("error dispatching default prompt target request: {}", e); return self.send_server_error( ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST), @@ -465,6 +472,7 @@ impl StreamContext { let prompt_target = match self.prompt_targets.get(&prompt_target_name) { Some(prompt_target) => prompt_target.clone(), None => { + debug!("prompt target not found: {}", prompt_target_name); return self.send_server_error( ServerError::LogicError(format!( "Prompt target not found: {prompt_target_name}" @@ -537,6 +545,7 @@ impl StreamContext { msg_body } Err(e) => { + debug!("error serializing arch_fc request body: {}", e); return self.send_server_error(ServerError::Serialization(e), None); } }; @@ -569,6 +578,7 @@ impl StreamContext { callout_context.prompt_target_name = Some(prompt_target.name); if let Err(e) = self.http_call(call_args, callout_context) { + debug!("error dispatching arch_fc request: {}", e); self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST)); } } @@ -580,6 +590,7 @@ impl StreamContext { let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) { Ok(arch_fc_response) => arch_fc_response, Err(e) => { + debug!("error deserializing arch_fc response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -693,6 +704,7 @@ impl StreamContext { match serde_json::to_string(&hallucination_classification_request) { Ok(json_data) => json_data, Err(error) => { + debug!("error serializing hallucination classification request: {}", error); return self.send_server_error(ServerError::Serialization(error), None); } }; @@ -789,13 +801,15 @@ impl StreamContext { ) { if let Some(http_status) = self.get_http_call_response_header(":status") { if http_status != StatusCode::OK.as_str() { + debug!("upstream error response: {}", http_status); return self.send_server_error( ServerError::Upstream { - authority: callout_context.upstream_cluster.unwrap(), + host: callout_context.upstream_cluster.unwrap(), path: callout_context.upstream_cluster_path.unwrap(), - status: http_status, + status: http_status.clone(), + body: String::from_utf8(body).unwrap(), }, - None, + Some(StatusCode::from_str(http_status.as_str()).unwrap()), ); } } else { @@ -893,6 +907,7 @@ impl StreamContext { .prompt_guards .jailbreak_on_exception_message() .unwrap_or("refrain from discussing jailbreaking."); + debug!("jailbreak detected: {}", msg); return self.send_server_error( ServerError::Jailbreak(String::from(msg)), Some(StatusCode::BAD_REQUEST), @@ -916,6 +931,7 @@ impl StreamContext { let json_data: String = match serde_json::to_string(&get_embeddings_input) { Ok(json_data) => json_data, Err(error) => { + debug!("error serializing get embeddings request: {}", error); return self.send_server_error(ServerError::Deserialization(error), None); } }; @@ -952,6 +968,7 @@ impl StreamContext { }; if let Err(e) = self.http_call(call_args, call_context) { + debug!("error dispatching get embeddings request: {}", e); self.send_server_error(ServerError::HttpDispatch(e), None); } } @@ -985,6 +1002,7 @@ impl StreamContext { let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) { Ok(chat_completions_resp) => chat_completions_resp, Err(e) => { + debug!("error deserializing default target response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -1259,9 +1277,8 @@ impl HttpContext for StreamContext { match serde_json::from_slice(&body) { Ok(de) => de, Err(e) => { - debug!("invalid response: {}", String::from_utf8_lossy(&body)); - self.send_server_error(ServerError::Deserialization(e), None); - return Action::Pause; + debug!("invalid response: {}, {}", String::from_utf8_lossy(&body), e); + return Action::Continue; } }; diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index 5f27adc3..0338f23b 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -487,7 +487,6 @@ fn bad_request_to_open_ai_chat_completions() { .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(incomplete_chat_completions_request_body)) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) .expect_send_local_response( Some(StatusCode::BAD_REQUEST.as_u16().into()), None,