mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
making sure that we convert the raw bytes to the correct provider type upstream
This commit is contained in:
parent
c55979307e
commit
d4dfbe600f
2 changed files with 110 additions and 49 deletions
|
|
@ -8,39 +8,6 @@ pub enum ProviderRequestType {
|
||||||
MessagesRequest(MessagesRequest),
|
MessagesRequest(MessagesRequest),
|
||||||
//add more request types here
|
//add more request types here
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<&[u8]> for ProviderRequestType {
|
|
||||||
type Error = std::io::Error;
|
|
||||||
|
|
||||||
// if passing bytes without provider id we assume the request is in OpenAI format
|
|
||||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
|
||||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
|
||||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
|
||||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Parse request based on endpoint and provider information
|
|
||||||
impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
|
||||||
type Error = std::io::Error;
|
|
||||||
|
|
||||||
fn try_from((bytes, endpoint): (&[u8], &SupportedAPIs)) -> Result<Self, Self::Error> {
|
|
||||||
// Use SupportedApi to determine the appropriate request type
|
|
||||||
match endpoint {
|
|
||||||
SupportedAPIs::OpenAIChatCompletions(_) => {
|
|
||||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
|
||||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
|
||||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
|
||||||
}
|
|
||||||
SupportedAPIs::AnthropicMessagesAPI(_) => {
|
|
||||||
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
|
|
||||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
|
||||||
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait ProviderRequest: Send + Sync {
|
pub trait ProviderRequest: Send + Sync {
|
||||||
/// Extract the model name from the request
|
/// Extract the model name from the request
|
||||||
fn model(&self) -> &str;
|
fn model(&self) -> &str;
|
||||||
|
|
@ -61,6 +28,74 @@ pub trait ProviderRequest: Send + Sync {
|
||||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError>;
|
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
impl TryFrom<&[u8]> for ProviderRequestType {
|
||||||
|
type Error = std::io::Error;
|
||||||
|
|
||||||
|
// if passing bytes without provider id we assume the request is in OpenAI format
|
||||||
|
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||||
|
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse request based on api
|
||||||
|
impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
||||||
|
type Error = std::io::Error;
|
||||||
|
|
||||||
|
fn try_from((bytes, the_api_type): (&[u8], &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||||
|
// Use SupportedApi to determine the appropriate request type
|
||||||
|
match the_api_type {
|
||||||
|
SupportedAPIs::OpenAIChatCompletions(_) => {
|
||||||
|
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||||
|
}
|
||||||
|
SupportedAPIs::AnthropicMessagesAPI(_) => {
|
||||||
|
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<(&ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
|
||||||
|
type Error = ProviderRequestError;
|
||||||
|
|
||||||
|
fn try_from((r, target_api): (&ProviderRequestType, &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||||
|
match (r, target_api) {
|
||||||
|
// Same API - no conversion needed, just clone the reference
|
||||||
|
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||||
|
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req.clone()))
|
||||||
|
}
|
||||||
|
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||||
|
Ok(ProviderRequestType::MessagesRequest(messages_req.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cross-API conversion - cloning is necessary for transformation
|
||||||
|
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||||
|
let messages_req = MessagesRequest::try_from(chat_req.clone())
|
||||||
|
.map_err(|e| ProviderRequestError {
|
||||||
|
message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e),
|
||||||
|
source: Some(Box::new(e))
|
||||||
|
})?;
|
||||||
|
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||||
|
}
|
||||||
|
|
||||||
|
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||||
|
let chat_req = ChatCompletionsRequest::try_from(messages_req.clone())
|
||||||
|
.map_err(|e| ProviderRequestError {
|
||||||
|
message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e),
|
||||||
|
source: Some(Box::new(e))
|
||||||
|
})?;
|
||||||
|
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ProviderRequest for ProviderRequestType {
|
impl ProviderRequest for ProviderRequestType {
|
||||||
fn model(&self) -> &str {
|
fn model(&self) -> &str {
|
||||||
match self {
|
match self {
|
||||||
|
|
|
||||||
|
|
@ -289,7 +289,7 @@ impl StreamContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_response_body(&mut self, body_size: usize) -> Result<Vec<u8>, Action> {
|
fn read_raw_response_body(&mut self, body_size: usize) -> Result<Vec<u8>, Action> {
|
||||||
if self.streaming_response {
|
if self.streaming_response {
|
||||||
let chunk_start = 0;
|
let chunk_start = 0;
|
||||||
let chunk_size = body_size;
|
let chunk_size = body_size;
|
||||||
|
|
@ -583,9 +583,10 @@ impl HttpContext for StreamContext {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut deserialized_body = match self.resolved_api.as_ref() {
|
//We need to deserialize the request body based on the resolved API
|
||||||
Some(resolved_api) => {
|
let mut deserialized_client_request: ProviderRequestType = match self.client_api.as_ref() {
|
||||||
match ProviderRequestType::try_from((&body_bytes[..], resolved_api)) {
|
Some(the_client_api) => {
|
||||||
|
match ProviderRequestType::try_from((&body_bytes[..], the_client_api)) {
|
||||||
Ok(deserialized) => deserialized,
|
Ok(deserialized) => deserialized,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!(
|
debug!(
|
||||||
|
|
@ -620,7 +621,7 @@ impl HttpContext for StreamContext {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Store the original model for logging
|
// Store the original model for logging
|
||||||
let model_requested = deserialized_body.model().to_string();
|
let model_requested = deserialized_client_request.model().to_string();
|
||||||
|
|
||||||
// Apply model name resolution logic using the trait method
|
// Apply model name resolution logic using the trait method
|
||||||
let resolved_model = match model_name {
|
let resolved_model = match model_name {
|
||||||
|
|
@ -646,10 +647,10 @@ impl HttpContext for StreamContext {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Set the resolved model using the trait method
|
// Set the resolved model using the trait method
|
||||||
deserialized_body.set_model(resolved_model.clone());
|
deserialized_client_request.set_model(resolved_model.clone());
|
||||||
|
|
||||||
// Extract user message for tracing
|
// Extract user message for tracing
|
||||||
self.user_message = deserialized_body.get_recent_user_message();
|
self.user_message = deserialized_client_request.get_recent_user_message();
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
||||||
|
|
@ -659,10 +660,10 @@ impl HttpContext for StreamContext {
|
||||||
);
|
);
|
||||||
|
|
||||||
// Use provider interface for streaming detection and setup
|
// Use provider interface for streaming detection and setup
|
||||||
self.streaming_response = deserialized_body.is_streaming();
|
self.streaming_response = deserialized_client_request.is_streaming();
|
||||||
|
|
||||||
// Use provider interface for text extraction (after potential mutation)
|
// Use provider interface for text extraction (after potential mutation)
|
||||||
let input_tokens_str = deserialized_body.extract_messages_text();
|
let input_tokens_str = deserialized_client_request.extract_messages_text();
|
||||||
// enforce ratelimits on ingress
|
// enforce ratelimits on ingress
|
||||||
if let Err(e) = self.enforce_ratelimits(&resolved_model, input_tokens_str.as_str()) {
|
if let Err(e) = self.enforce_ratelimits(&resolved_model, input_tokens_str.as_str()) {
|
||||||
self.send_server_error(
|
self.send_server_error(
|
||||||
|
|
@ -674,19 +675,44 @@ impl HttpContext for StreamContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert chat completion request to llm provider specific request using provider interface
|
// Convert chat completion request to llm provider specific request using provider interface
|
||||||
let deserialized_body_bytes = match deserialized_body.to_bytes() {
|
let serialized_body_bytes_upstream = match self.resolved_api.as_ref() {
|
||||||
Ok(bytes) => bytes,
|
Some(upstream) => {
|
||||||
Err(e) => {
|
match ProviderRequestType::try_from((&deserialized_client_request, upstream)) {
|
||||||
warn!("Failed to serialize request body: {}", e);
|
Ok(request) => match request.to_bytes() {
|
||||||
|
Ok(bytes) => bytes,
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to serialize request body: {}", e);
|
||||||
|
self.send_server_error(
|
||||||
|
ServerError::LogicError(format!(
|
||||||
|
"Request serialization error: {}",
|
||||||
|
e
|
||||||
|
)),
|
||||||
|
Some(StatusCode::BAD_REQUEST),
|
||||||
|
);
|
||||||
|
return Action::Pause;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to create provider request: {}", e);
|
||||||
|
self.send_server_error(
|
||||||
|
ServerError::LogicError(format!("Provider request error: {}", e)),
|
||||||
|
Some(StatusCode::BAD_REQUEST),
|
||||||
|
);
|
||||||
|
return Action::Pause;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
warn!("No upstream API resolved");
|
||||||
self.send_server_error(
|
self.send_server_error(
|
||||||
ServerError::LogicError(format!("Request serialization error: {}", e)),
|
ServerError::LogicError("No upstream API resolved".into()),
|
||||||
Some(StatusCode::BAD_REQUEST),
|
Some(StatusCode::BAD_REQUEST),
|
||||||
);
|
);
|
||||||
return Action::Pause;
|
return Action::Pause;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
self.set_http_request_body(0, body_size, &deserialized_body_bytes);
|
self.set_http_request_body(0, body_size, &serialized_body_bytes_upstream);
|
||||||
Action::Continue
|
Action::Continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -734,7 +760,7 @@ impl HttpContext for StreamContext {
|
||||||
return Action::Continue;
|
return Action::Continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let body = match self.read_response_body(body_size) {
|
let body = match self.read_raw_response_body(body_size) {
|
||||||
Ok(bytes) => bytes,
|
Ok(bytes) => bytes,
|
||||||
Err(action) => return action,
|
Err(action) => return action,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue