Send back developer error correctly

This commit is contained in:
Adil Hafeez 2024-10-17 20:52:04 -07:00
parent 6cd05572c4
commit b18c8d313c
3 changed files with 32 additions and 12 deletions

View file

@ -335,8 +335,8 @@ impl HttpContext for LlmGatewayStreamContext {
Ok(de) => de,
Err(e) => {
debug!("invalid response: {}", String::from_utf8_lossy(&body));
self.send_server_error(ServerError::Deserialization(e), None);
return Action::Pause;
// self.send_server_error(ServerError::Deserialization(e), None);
return Action::Continue;
}
};

View file

@ -32,6 +32,7 @@ use sha2::{Digest, Sha256};
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use std::str::FromStr;
use std::time::Duration;
use common::stats::IncrementingMetric;
@ -69,11 +70,12 @@ pub enum ServerError {
Serialization(serde_json::Error),
#[error("{0}")]
LogicError(String),
#[error("upstream error response authority={authority}, path={path}, status={status}")]
#[error("upstream application error host={host}, path={path}, status={status}, body={body}")]
Upstream {
authority: String,
host: String,
path: String,
status: String,
body: String,
},
#[error("jailbreak detected: {0}")]
Jailbreak(String),
@ -149,7 +151,6 @@ impl PromptStreamContext {
}
fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
debug!("server error occurred: {}", error);
self.send_http_response(
override_status_code
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
@ -164,6 +165,7 @@ impl PromptStreamContext {
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
Ok(embedding_response) => embedding_response,
Err(e) => {
debug!("error deserializing embedding response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -234,6 +236,7 @@ impl PromptStreamContext {
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
debug!("error serializing zero shot classification request: {}", error);
return self.send_server_error(ServerError::Serialization(error), None);
}
};
@ -263,6 +266,7 @@ impl PromptStreamContext {
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching zero shot classification request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
@ -276,6 +280,7 @@ impl PromptStreamContext {
match serde_json::from_slice(&body) {
Ok(hallucination_response) => hallucination_response,
Err(e) => {
debug!("error deserializing hallucination response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -339,6 +344,7 @@ impl PromptStreamContext {
match serde_json::from_slice(&body) {
Ok(zeroshot_response) => zeroshot_response,
Err(e) => {
debug!("error deserializing zero shot classification response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -450,6 +456,7 @@ impl PromptStreamContext {
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching default prompt target request: {}", e);
return self.send_server_error(
ServerError::HttpDispatch(e),
Some(StatusCode::BAD_REQUEST),
@ -465,6 +472,7 @@ impl PromptStreamContext {
let prompt_target = match self.prompt_targets.get(&prompt_target_name) {
Some(prompt_target) => prompt_target.clone(),
None => {
debug!("prompt target not found: {}", prompt_target_name);
return self.send_server_error(
ServerError::LogicError(format!(
"Prompt target not found: {prompt_target_name}"
@ -537,6 +545,7 @@ impl PromptStreamContext {
msg_body
}
Err(e) => {
debug!("error serializing arch_fc request body: {}", e);
return self.send_server_error(ServerError::Serialization(e), None);
}
};
@ -569,6 +578,7 @@ impl PromptStreamContext {
callout_context.prompt_target_name = Some(prompt_target.name);
if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching arch_fc request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
}
}
@ -580,6 +590,7 @@ impl PromptStreamContext {
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
Ok(arch_fc_response) => arch_fc_response,
Err(e) => {
debug!("error deserializing arch_fc response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -693,6 +704,7 @@ impl PromptStreamContext {
match serde_json::to_string(&hallucination_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
debug!("error serializing hallucination classification request: {}", error);
return self.send_server_error(ServerError::Serialization(error), None);
}
};
@ -789,13 +801,15 @@ impl PromptStreamContext {
) {
if let Some(http_status) = self.get_http_call_response_header(":status") {
if http_status != StatusCode::OK.as_str() {
debug!("upstream error response: {}", http_status);
return self.send_server_error(
ServerError::Upstream {
authority: callout_context.upstream_cluster.unwrap(),
host: callout_context.upstream_cluster.unwrap(),
path: callout_context.upstream_cluster_path.unwrap(),
status: http_status,
status: http_status.clone(),
body: String::from_utf8(body).unwrap(),
},
None,
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
);
}
} else {
@ -893,6 +907,7 @@ impl PromptStreamContext {
.prompt_guards
.jailbreak_on_exception_message()
.unwrap_or("refrain from discussing jailbreaking.");
debug!("jailbreak detected: {}", msg);
return self.send_server_error(
ServerError::Jailbreak(String::from(msg)),
Some(StatusCode::BAD_REQUEST),
@ -916,6 +931,7 @@ impl PromptStreamContext {
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
debug!("error serializing get embeddings request: {}", error);
return self.send_server_error(ServerError::Deserialization(error), None);
}
};
@ -952,6 +968,7 @@ impl PromptStreamContext {
};
if let Err(e) = self.http_call(call_args, call_context) {
debug!("error dispatching get embeddings request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
@ -985,6 +1002,7 @@ impl PromptStreamContext {
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
Ok(chat_completions_resp) => chat_completions_resp,
Err(e) => {
debug!("error deserializing default target response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -1260,9 +1278,9 @@ impl HttpContext for PromptStreamContext {
match serde_json::from_slice(&body) {
Ok(de) => de,
Err(e) => {
debug!("invalid response: {}", String::from_utf8_lossy(&body));
self.send_server_error(ServerError::Deserialization(e), None);
return Action::Pause;
debug!("invalid response: {}, {}", String::from_utf8_lossy(&body), e);
// self.send_server_error(ServerError::Deserialization(e), None);
return Action::Continue;
}
};

View file

@ -1,3 +1,4 @@
from fastapi import FastAPI, HTTPException
import json
import random
from fastapi import FastAPI, Response
@ -45,7 +46,8 @@ async def weather(req: WeatherRequest, res: Response):
}
)
return weather_forecast
raise HTTPException(status_code=404, detail="some error")
# return weather_forecast
class InsuranceClaimDetailsRequest(BaseModel):