From cc730d8b3dee6abeb60d55034ee2ba4ab6e03ffb Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Mon, 27 Jan 2025 15:44:41 -0800 Subject: [PATCH] fix logs a bit more and refactor code --- crates/llm_gateway/src/stream_context.rs | 7 +++- crates/prompt_gateway/src/context.rs | 42 ++++++++++++++------- crates/prompt_gateway/src/http_context.rs | 4 +- crates/prompt_gateway/src/stream_context.rs | 8 ++-- 4 files changed, 40 insertions(+), 21 deletions(-) diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index a0714e80..68a5a67c 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -517,8 +517,11 @@ impl HttpContext for StreamContext { let chat_completions_response: ChatCompletionsResponse = match serde_json::from_str(body_utf8.as_str()) { Ok(de) => de, - Err(_e) => { - debug!("invalid response: {}", body_utf8); + Err(err) => { + debug!( + "non chat-completion compliant response received err: {}, body: {}", + err, body_utf8 + ); return Action::Continue; } }; diff --git a/crates/prompt_gateway/src/context.rs b/crates/prompt_gateway/src/context.rs index c70937d7..5e8fcf0a 100644 --- a/crates/prompt_gateway/src/context.rs +++ b/crates/prompt_gateway/src/context.rs @@ -27,20 +27,36 @@ impl Context for StreamContext { .get_http_call_response_body(0, body_size) .unwrap_or_default(); - let http_status = self - .get_http_call_response_header(":status") - .unwrap_or(StatusCode::OK.as_str().to_string()); - if http_status != StatusCode::OK.as_str() { - let server_error = ServerError::Upstream { - host: callout_context.upstream_cluster.unwrap(), - path: callout_context.upstream_cluster_path.unwrap(), - status: http_status.clone(), - body: String::from_utf8(body).unwrap(), - }; - warn!("filter received non 2xx code: {:?}", server_error); + if let Some(http_status) = self.get_http_call_response_header(":status") { + match StatusCode::from_str(http_status.as_str()) { + Ok(status_code) => { + 
+ if !status_code.is_success() { + let server_error = ServerError::Upstream { + host: callout_context.upstream_cluster.unwrap(), + path: callout_context.upstream_cluster_path.unwrap(), + status: http_status.clone(), + body: String::from_utf8(body).unwrap(), + }; + warn!("received non 2xx code: {:?}", server_error); + return self.send_server_error( + server_error, + Some(status_code), + ); + } + } + Err(_) => { + // :status did not parse as a StatusCode, so it cannot be parsed again here (that would panic); reply 502 + return self.send_server_error( + ServerError::LogicError(format!("invalid status code: {}", http_status)), + Some(StatusCode::BAD_GATEWAY), + ); + } + } + } else { + // :status header not found return self.send_server_error( - server_error, - Some(StatusCode::from_str(http_status.as_str()).unwrap()), + ServerError::LogicError("missing :status header".to_string()), + Some(StatusCode::BAD_GATEWAY), ); } diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index f105fd7e..58a725b4 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -79,7 +79,7 @@ impl HttpContext for StreamContext { }; debug!( - "developer => archgw: {}", + "developer => archgw request body: {}", String::from_utf8_lossy(&body_bytes) ); @@ -152,7 +152,7 @@ impl HttpContext for StreamContext { } }; - debug!("archgw => archfc: {}", json_data); + debug!("archgw => modelserver request body: {}", json_data); let mut headers = vec![ (ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME), diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 678eedab..a0cc7a88 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -125,13 +125,13 @@ impl StreamContext { mut callout_context: StreamCallContext, ) { let body_str = String::from_utf8(body).unwrap(); - debug!("archgw <= archfc response: {}", 
body_str); + debug!("archgw <= modelserver response body: {}", body_str); let model_server_response: ModelServerResponse = match serde_json::from_str(&body_str) { Ok(arch_fc_response) => arch_fc_response, Err(e) => { warn!( - "error deserializing archfc response: {}, body: {}", + "error deserializing modelserver response: {}, body: {}", e, body_str ); return self.send_server_error(ServerError::Deserialization(e), None); @@ -141,7 +141,7 @@ impl StreamContext { let arch_fc_response = match model_server_response { ModelServerResponse::ChatCompletionsResponse(response) => response, ModelServerResponse::ModelServerErrorResponse(response) => { - debug!("archgw <= archfc error response: {}", response.result); + debug!("archgw <= modelserver error response: {}", response.result); if response.result == "No intent matched" { if let Some(default_prompt_target) = self .prompt_targets @@ -344,7 +344,7 @@ impl StreamContext { ); debug!( - "archgw => api call, endpoint: {}{}, body: {}", + "archgw => developer api call endpoint: {}, path: {}, body: {}", endpoint.name.as_str(), path, tool_params_json_str