mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fixed the dup cloning issue and cleaned up the ProviderRequestType enum and traits
This commit is contained in:
parent
c8b59aeda7
commit
e8881c7b8a
3 changed files with 61 additions and 97 deletions
|
|
@ -38,31 +38,6 @@ mod tests {
|
|||
assert!(has_compatible_api(&ProviderId::OpenAI, "/v1/chat/completions"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_request_parsing() {
|
||||
// Test with a sample JSON request
|
||||
let json_request = r#"{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let result: Result<ProviderRequestType, std::io::Error> = ProviderRequestType::try_from(json_request.as_bytes());
|
||||
assert!(result.is_ok());
|
||||
|
||||
let request = result.unwrap();
|
||||
assert_eq!(request.model(), "gpt-4");
|
||||
assert_eq!(request.get_recent_user_message(), Some("Hello!".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_streaming_response() {
|
||||
// Test streaming response parsing with sample SSE data
|
||||
|
|
|
|||
|
|
@ -28,74 +28,6 @@ pub trait ProviderRequest: Send + Sync {
|
|||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError>;
|
||||
}
|
||||
|
||||
|
||||
impl TryFrom<&[u8]> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
// if passing bytes without provider id we assume the request is in OpenAI format
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse request based on api
|
||||
impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, the_api_type): (&[u8], &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
// Use SupportedApi to determine the appropriate request type
|
||||
match the_api_type {
|
||||
SupportedAPIs::OpenAIChatCompletions(_) => {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(_) => {
|
||||
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<(&ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
|
||||
type Error = ProviderRequestError;
|
||||
|
||||
fn try_from((r, target_api): (&ProviderRequestType, &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
match (r, target_api) {
|
||||
// Same API - no conversion needed, just clone the reference
|
||||
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req.clone()))
|
||||
}
|
||||
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req.clone()))
|
||||
}
|
||||
|
||||
// Cross-API conversion - cloning is necessary for transformation
|
||||
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let messages_req = MessagesRequest::try_from(chat_req.clone())
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
|
||||
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(messages_req.clone())
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderRequest for ProviderRequestType {
|
||||
fn model(&self) -> &str {
|
||||
match self {
|
||||
|
|
@ -140,6 +72,64 @@ impl ProviderRequest for ProviderRequestType {
|
|||
}
|
||||
}
|
||||
|
||||
/// Parse the client API from a byte slice.
|
||||
impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, client_api): (&[u8], &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
// Use SupportedApi to determine the appropriate request type
|
||||
match client_api {
|
||||
SupportedAPIs::OpenAIChatCompletions(_) => {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(_) => {
|
||||
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Conversion from one ProviderRequestType to a different ProviderRequestType (SupportedAPIs)
|
||||
impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
|
||||
type Error = ProviderRequestError;
|
||||
|
||||
fn try_from((request, upstream_api): (ProviderRequestType, &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
match (request, upstream_api) {
|
||||
// Same API - no conversion needed, just clone the reference
|
||||
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
|
||||
// Cross-API conversion - cloning is necessary for transformation
|
||||
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let messages_req = MessagesRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
|
||||
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(messages_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Error types for provider operations
|
||||
#[derive(Debug)]
|
||||
|
|
@ -161,8 +151,6 @@ impl Error for ProviderRequestError {
|
|||
}
|
||||
|
||||
|
||||
// ...existing code...
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -184,7 +172,8 @@ mod tests {
|
|||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
let result = ProviderRequestType::try_from(bytes.as_slice());
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &api));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderRequestType::ChatCompletionsRequest(r) => {
|
||||
|
|
|
|||
|
|
@ -811,7 +811,7 @@ impl HttpContext for StreamContext {
|
|||
self.request_identifier(), self.client_api, upstream
|
||||
);
|
||||
|
||||
match ProviderRequestType::try_from((&deserialized_client_request, upstream)) {
|
||||
match ProviderRequestType::try_from((deserialized_client_request, upstream)) {
|
||||
Ok(request) => {
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_REQUEST_PAYLOAD: {}",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue