Send back developer error correctly (#195)

This commit is contained in:
Adil Hafeez 2024-10-18 13:14:18 -07:00 committed by GitHub
parent 28421353fd
commit 1719b7d5f8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 28 additions and 13 deletions

View file

@ -316,10 +316,9 @@ impl HttpContext for StreamContext {
let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_slice(&body) {
Ok(de) => de,
Err(e) => {
Err(_e) => {
debug!("invalid response: {}", String::from_utf8_lossy(&body));
self.send_server_error(ServerError::Deserialization(e), None);
return Action::Pause;
return Action::Continue;
}
};

View file

@ -33,6 +33,7 @@ use sha2::{Digest, Sha256};
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use std::str::FromStr;
use std::time::Duration;
use common::stats::IncrementingMetric;
@ -70,11 +71,12 @@ pub enum ServerError {
Serialization(serde_json::Error),
#[error("{0}")]
LogicError(String),
#[error("upstream error response authority={authority}, path={path}, status={status}")]
#[error("upstream application error host={host}, path={path}, status={status}, body={body}")]
Upstream {
authority: String,
host: String,
path: String,
status: String,
body: String,
},
#[error("jailbreak detected: {0}")]
Jailbreak(String),
@ -149,7 +151,6 @@ impl StreamContext {
}
fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
debug!("server error occurred: {}", error);
self.send_http_response(
override_status_code
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
@ -164,6 +165,7 @@ impl StreamContext {
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
Ok(embedding_response) => embedding_response,
Err(e) => {
debug!("error deserializing embedding response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -234,6 +236,7 @@ impl StreamContext {
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
debug!("error serializing zero shot classification request: {}", error);
return self.send_server_error(ServerError::Serialization(error), None);
}
};
@ -263,6 +266,7 @@ impl StreamContext {
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching zero shot classification request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
@ -276,6 +280,7 @@ impl StreamContext {
match serde_json::from_slice(&body) {
Ok(hallucination_response) => hallucination_response,
Err(e) => {
debug!("error deserializing hallucination response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -339,6 +344,7 @@ impl StreamContext {
match serde_json::from_slice(&body) {
Ok(zeroshot_response) => zeroshot_response,
Err(e) => {
debug!("error deserializing zero shot classification response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -450,6 +456,7 @@ impl StreamContext {
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching default prompt target request: {}", e);
return self.send_server_error(
ServerError::HttpDispatch(e),
Some(StatusCode::BAD_REQUEST),
@ -465,6 +472,7 @@ impl StreamContext {
let prompt_target = match self.prompt_targets.get(&prompt_target_name) {
Some(prompt_target) => prompt_target.clone(),
None => {
debug!("prompt target not found: {}", prompt_target_name);
return self.send_server_error(
ServerError::LogicError(format!(
"Prompt target not found: {prompt_target_name}"
@ -537,6 +545,7 @@ impl StreamContext {
msg_body
}
Err(e) => {
debug!("error serializing arch_fc request body: {}", e);
return self.send_server_error(ServerError::Serialization(e), None);
}
};
@ -569,6 +578,7 @@ impl StreamContext {
callout_context.prompt_target_name = Some(prompt_target.name);
if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching arch_fc request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
}
}
@ -580,6 +590,7 @@ impl StreamContext {
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
Ok(arch_fc_response) => arch_fc_response,
Err(e) => {
debug!("error deserializing arch_fc response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -693,6 +704,7 @@ impl StreamContext {
match serde_json::to_string(&hallucination_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
debug!("error serializing hallucination classification request: {}", error);
return self.send_server_error(ServerError::Serialization(error), None);
}
};
@ -789,13 +801,15 @@ impl StreamContext {
) {
if let Some(http_status) = self.get_http_call_response_header(":status") {
if http_status != StatusCode::OK.as_str() {
debug!("upstream error response: {}", http_status);
return self.send_server_error(
ServerError::Upstream {
authority: callout_context.upstream_cluster.unwrap(),
host: callout_context.upstream_cluster.unwrap(),
path: callout_context.upstream_cluster_path.unwrap(),
status: http_status,
status: http_status.clone(),
body: String::from_utf8(body).unwrap(),
},
None,
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
);
}
} else {
@ -893,6 +907,7 @@ impl StreamContext {
.prompt_guards
.jailbreak_on_exception_message()
.unwrap_or("refrain from discussing jailbreaking.");
debug!("jailbreak detected: {}", msg);
return self.send_server_error(
ServerError::Jailbreak(String::from(msg)),
Some(StatusCode::BAD_REQUEST),
@ -916,6 +931,7 @@ impl StreamContext {
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
debug!("error serializing get embeddings request: {}", error);
return self.send_server_error(ServerError::Deserialization(error), None);
}
};
@ -952,6 +968,7 @@ impl StreamContext {
};
if let Err(e) = self.http_call(call_args, call_context) {
debug!("error dispatching get embeddings request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
@ -985,6 +1002,7 @@ impl StreamContext {
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
Ok(chat_completions_resp) => chat_completions_resp,
Err(e) => {
debug!("error deserializing default target response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -1259,9 +1277,8 @@ impl HttpContext for StreamContext {
match serde_json::from_slice(&body) {
Ok(de) => de,
Err(e) => {
debug!("invalid response: {}", String::from_utf8_lossy(&body));
self.send_server_error(ServerError::Deserialization(e), None);
return Action::Pause;
debug!("invalid response: {}, {}", String::from_utf8_lossy(&body), e);
return Action::Continue;
}
};

View file

@ -487,7 +487,6 @@ fn bad_request_to_open_ai_chat_completions() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(incomplete_chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response(
Some(StatusCode::BAD_REQUEST.as_u16().into()),
None,