diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 26348378..c8167d89 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -38,31 +38,6 @@ mod tests { assert!(has_compatible_api(&ProviderId::OpenAI, "/v1/chat/completions")); } - #[test] - fn test_provider_request_parsing() { - // Test with a sample JSON request - let json_request = r#"{ - "model": "gpt-4", - "messages": [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": "Hello!" - } - ] - }"#; - - let result: Result = ProviderRequestType::try_from(json_request.as_bytes()); - assert!(result.is_ok()); - - let request = result.unwrap(); - assert_eq!(request.model(), "gpt-4"); - assert_eq!(request.get_recent_user_message(), Some("Hello!".to_string())); - } - #[test] fn test_provider_streaming_response() { // Test streaming response parsing with sample SSE data diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index 69da73dd..b3f46ee1 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -28,74 +28,6 @@ pub trait ProviderRequest: Send + Sync { fn to_bytes(&self) -> Result, ProviderRequestError>; } - -impl TryFrom<&[u8]> for ProviderRequestType { - type Error = std::io::Error; - - // if passing bytes without provider id we assume the request is in OpenAI format - fn try_from(bytes: &[u8]) -> Result { - let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) - } -} - -/// Parse request based on api -impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType { - type Error = std::io::Error; - - fn try_from((bytes, the_api_type): (&[u8], &SupportedAPIs)) -> Result { - // Use SupportedApi to determine the appropriate request type - match the_api_type { - SupportedAPIs::OpenAIChatCompletions(_) => { - let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) - } - SupportedAPIs::AnthropicMessagesAPI(_) => { - let messages_request: MessagesRequest = MessagesRequest::try_from(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderRequestType::MessagesRequest(messages_request)) - } - } - } -} - -impl TryFrom<(&ProviderRequestType, &SupportedAPIs)> for ProviderRequestType { - type Error = ProviderRequestError; - - fn try_from((r, target_api): (&ProviderRequestType, &SupportedAPIs)) -> Result { - match (r, target_api) { - // Same API - no conversion needed, just clone the reference - (ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::OpenAIChatCompletions(_)) => { - Ok(ProviderRequestType::ChatCompletionsRequest(chat_req.clone())) - } - (ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::AnthropicMessagesAPI(_)) => { - Ok(ProviderRequestType::MessagesRequest(messages_req.clone())) - } - - // Cross-API conversion - cloning is necessary for transformation - (ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::AnthropicMessagesAPI(_)) => { - let messages_req = MessagesRequest::try_from(chat_req.clone()) - .map_err(|e| ProviderRequestError { - message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e), - source: Some(Box::new(e)) - })?; - Ok(ProviderRequestType::MessagesRequest(messages_req)) - } - - (ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::OpenAIChatCompletions(_)) => { - let chat_req = ChatCompletionsRequest::try_from(messages_req.clone()) - .map_err(|e| ProviderRequestError { - message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e), - source: Some(Box::new(e)) - })?; - Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) - } - } - } -} - impl ProviderRequest for ProviderRequestType { fn model(&self) -> &str { match self { @@ -140,6 +72,64 @@ impl ProviderRequest for ProviderRequestType { } } +/// Parse the client API from a byte slice. +impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType { + type Error = std::io::Error; + + fn try_from((bytes, client_api): (&[u8], &SupportedAPIs)) -> Result { + // Use SupportedApi to determine the appropriate request type + match client_api { + SupportedAPIs::OpenAIChatCompletions(_) => { + let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) + } + SupportedAPIs::AnthropicMessagesAPI(_) => { + let messages_request: MessagesRequest = MessagesRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderRequestType::MessagesRequest(messages_request)) + } + } + } +} + +/// Conversion from one ProviderRequestType to a different ProviderRequestType (SupportedAPIs) +impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType { + type Error = ProviderRequestError; + + fn try_from((request, upstream_api): (ProviderRequestType, &SupportedAPIs)) -> Result { + match (request, upstream_api) { + // Same API - no conversion needed, just clone the reference + (ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::OpenAIChatCompletions(_)) => { + Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) + } + (ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::AnthropicMessagesAPI(_)) => { + Ok(ProviderRequestType::MessagesRequest(messages_req)) + } + + // Cross-API conversion - cloning is necessary for transformation + (ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::AnthropicMessagesAPI(_)) => { + let messages_req = MessagesRequest::try_from(chat_req) + .map_err(|e| ProviderRequestError { + message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e), + source: Some(Box::new(e)) + })?; + Ok(ProviderRequestType::MessagesRequest(messages_req)) + } + + (ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::OpenAIChatCompletions(_)) => { + let chat_req = ChatCompletionsRequest::try_from(messages_req) + .map_err(|e| ProviderRequestError { + message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e), + source: Some(Box::new(e)) + })?; + Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) + } + } + } +} + + /// Error types for provider operations #[derive(Debug)] @@ -161,8 +151,6 @@ impl Error for ProviderRequestError { } -// ...existing code... - #[cfg(test)] mod tests { use super::*; @@ -184,7 +172,8 @@ mod tests { ] }); let bytes = serde_json::to_vec(&req).unwrap(); - let result = ProviderRequestType::try_from(bytes.as_slice()); + let api = SupportedAPIs::OpenAIChatCompletions(ChatCompletions); + let result = ProviderRequestType::try_from((bytes.as_slice(), &api)); assert!(result.is_ok()); match result.unwrap() { ProviderRequestType::ChatCompletionsRequest(r) => { diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index b84d615c..46ef1ecf 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -811,7 +811,7 @@ impl HttpContext for StreamContext { self.request_identifier(), self.client_api, upstream ); - match ProviderRequestType::try_from((&deserialized_client_request, upstream)) { + match ProviderRequestType::try_from((deserialized_client_request, upstream)) { Ok(request) => { debug!( "[ARCHGW_REQ_ID:{}] UPSTREAM_REQUEST_PAYLOAD: {}",