Make sure that we convert the raw bytes to the correct provider type upstream

This commit is contained in:
Salman Paracha 2025-09-02 16:19:45 -07:00
parent c55979307e
commit d4dfbe600f
2 changed files with 110 additions and 49 deletions

View file

@ -289,7 +289,7 @@ impl StreamContext {
}
}
fn read_response_body(&mut self, body_size: usize) -> Result<Vec<u8>, Action> {
fn read_raw_response_body(&mut self, body_size: usize) -> Result<Vec<u8>, Action> {
if self.streaming_response {
let chunk_start = 0;
let chunk_size = body_size;
@ -583,9 +583,10 @@ impl HttpContext for StreamContext {
}
};
let mut deserialized_body = match self.resolved_api.as_ref() {
Some(resolved_api) => {
match ProviderRequestType::try_from((&body_bytes[..], resolved_api)) {
// We need to deserialize the request body based on the client API; it is converted to the resolved upstream API further below
let mut deserialized_client_request: ProviderRequestType = match self.client_api.as_ref() {
Some(the_client_api) => {
match ProviderRequestType::try_from((&body_bytes[..], the_client_api)) {
Ok(deserialized) => deserialized,
Err(e) => {
debug!(
@ -620,7 +621,7 @@ impl HttpContext for StreamContext {
};
// Store the original model for logging
let model_requested = deserialized_body.model().to_string();
let model_requested = deserialized_client_request.model().to_string();
// Apply model name resolution logic using the trait method
let resolved_model = match model_name {
@ -646,10 +647,10 @@ impl HttpContext for StreamContext {
};
// Set the resolved model using the trait method
deserialized_body.set_model(resolved_model.clone());
deserialized_client_request.set_model(resolved_model.clone());
// Extract user message for tracing
self.user_message = deserialized_body.get_recent_user_message();
self.user_message = deserialized_client_request.get_recent_user_message();
info!(
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
@ -659,10 +660,10 @@ impl HttpContext for StreamContext {
);
// Use provider interface for streaming detection and setup
self.streaming_response = deserialized_body.is_streaming();
self.streaming_response = deserialized_client_request.is_streaming();
// Use provider interface for text extraction (after potential mutation)
let input_tokens_str = deserialized_body.extract_messages_text();
let input_tokens_str = deserialized_client_request.extract_messages_text();
// enforce ratelimits on ingress
if let Err(e) = self.enforce_ratelimits(&resolved_model, input_tokens_str.as_str()) {
self.send_server_error(
@ -674,19 +675,44 @@ impl HttpContext for StreamContext {
}
// Convert chat completion request to llm provider specific request using provider interface
let deserialized_body_bytes = match deserialized_body.to_bytes() {
Ok(bytes) => bytes,
Err(e) => {
warn!("Failed to serialize request body: {}", e);
let serialized_body_bytes_upstream = match self.resolved_api.as_ref() {
Some(upstream) => {
match ProviderRequestType::try_from((&deserialized_client_request, upstream)) {
Ok(request) => match request.to_bytes() {
Ok(bytes) => bytes,
Err(e) => {
warn!("Failed to serialize request body: {}", e);
self.send_server_error(
ServerError::LogicError(format!(
"Request serialization error: {}",
e
)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
},
Err(e) => {
warn!("Failed to create provider request: {}", e);
self.send_server_error(
ServerError::LogicError(format!("Provider request error: {}", e)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
}
}
None => {
warn!("No upstream API resolved");
self.send_server_error(
ServerError::LogicError(format!("Request serialization error: {}", e)),
ServerError::LogicError("No upstream API resolved".into()),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
};
self.set_http_request_body(0, body_size, &deserialized_body_bytes);
self.set_http_request_body(0, body_size, &serialized_body_bytes_upstream);
Action::Continue
}
@ -734,7 +760,7 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
let body = match self.read_response_body(body_size) {
let body = match self.read_raw_response_body(body_size) {
Ok(bytes) => bytes,
Err(action) => return action,
};