mirror of
https://github.com/katanemo/plano.git
synced 2026-06-26 15:39:40 +02:00
making sure that we convert the raw bytes to the correct provider type upstream
This commit is contained in:
parent
c55979307e
commit
d4dfbe600f
2 changed files with 110 additions and 49 deletions
|
|
@ -8,39 +8,6 @@ pub enum ProviderRequestType {
|
|||
MessagesRequest(MessagesRequest),
|
||||
//add more request types here
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
// if passing bytes without provider id we assume the request is in OpenAI format
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse request based on endpoint and provider information
|
||||
impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, endpoint): (&[u8], &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
// Use SupportedApi to determine the appropriate request type
|
||||
match endpoint {
|
||||
SupportedAPIs::OpenAIChatCompletions(_) => {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(_) => {
|
||||
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ProviderRequest: Send + Sync {
|
||||
/// Extract the model name from the request
|
||||
fn model(&self) -> &str;
|
||||
|
|
@ -61,6 +28,74 @@ pub trait ProviderRequest: Send + Sync {
|
|||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError>;
|
||||
}
|
||||
|
||||
|
||||
impl TryFrom<&[u8]> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
// if passing bytes without provider id we assume the request is in OpenAI format
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse request based on api
|
||||
impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, the_api_type): (&[u8], &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
// Use SupportedApi to determine the appropriate request type
|
||||
match the_api_type {
|
||||
SupportedAPIs::OpenAIChatCompletions(_) => {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(_) => {
|
||||
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<(&ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
|
||||
type Error = ProviderRequestError;
|
||||
|
||||
fn try_from((r, target_api): (&ProviderRequestType, &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
match (r, target_api) {
|
||||
// Same API - no conversion needed, just clone the reference
|
||||
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req.clone()))
|
||||
}
|
||||
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req.clone()))
|
||||
}
|
||||
|
||||
// Cross-API conversion - cloning is necessary for transformation
|
||||
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let messages_req = MessagesRequest::try_from(chat_req.clone())
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
|
||||
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(messages_req.clone())
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderRequest for ProviderRequestType {
|
||||
fn model(&self) -> &str {
|
||||
match self {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue