mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
more refactoring changes, getting close
This commit is contained in:
parent
63f23efda4
commit
58028bb7ae
21 changed files with 542 additions and 217 deletions
|
|
@ -13,14 +13,14 @@
|
|||
//!
|
||||
//! ```rust
|
||||
//! use hermesllm::apis::{
|
||||
//! AnthropicMessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage,
|
||||
//! MessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage,
|
||||
//! MessagesMessageContent, MessagesSystemPrompt,
|
||||
//! };
|
||||
//! use hermesllm::clients::TransformError;
|
||||
//! use std::convert::TryInto;
|
||||
//!
|
||||
//! // Transform Anthropic to OpenAI
|
||||
//! let anthropic_req = AnthropicMessagesRequest {
|
||||
//! let anthropic_req = MessagesRequest {
|
||||
//! model: "claude-3-sonnet".to_string(),
|
||||
//! system: None,
|
||||
//! messages: vec![],
|
||||
|
|
|
|||
|
|
@ -46,18 +46,18 @@ mod tests {
|
|||
#[test]
|
||||
fn test_provider_instance_creation() {
|
||||
let provider = Provider::new(ProviderId::OpenAI);
|
||||
assert!(provider.has_compatible_api("/v1/chat/completions"));
|
||||
assert!(!provider.has_compatible_api("/v1/embeddings"));
|
||||
assert!(provider.interface().has_compatible_api("/v1/chat/completions"));
|
||||
assert!(!provider.interface().has_compatible_api("/v1/embeddings"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversion_mode() {
|
||||
let provider = Provider::new(ProviderId::OpenAI);
|
||||
|
||||
let compatible_mode = provider.get_interface(false);
|
||||
let compatible_mode = provider.interface().get_interface(false);
|
||||
assert!(matches!(compatible_mode, ConversionMode::Compatible));
|
||||
|
||||
let passthrough_mode = provider.get_interface(true);
|
||||
let passthrough_mode = provider.interface().get_interface(true);
|
||||
assert!(matches!(passthrough_mode, ConversionMode::Passthrough));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
6
crates/hermesllm/src/providers/arch/mod.rs
Normal file
6
crates/hermesllm/src/providers/arch/mod.rs
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
//! Arch provider implementation
|
||||
//!
|
||||
//! Arch uses OpenAI-compatible API format
|
||||
|
||||
pub mod provider;
|
||||
pub use provider::ArchProvider;
|
||||
40
crates/hermesllm/src/providers/arch/provider.rs
Normal file
40
crates/hermesllm/src/providers/arch/provider.rs
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
//! Arch provider implementation
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse};
|
||||
|
||||
/// Arch provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ArchProvider;
|
||||
|
||||
impl ProviderInterface for ArchProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
matches!(api_path, "/v1/chat/completions")
|
||||
}
|
||||
|
||||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
vec!["/v1/chat/completions"]
|
||||
}
|
||||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
7
crates/hermesllm/src/providers/claude/mod.rs
Normal file
7
crates/hermesllm/src/providers/claude/mod.rs
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
//! Claude provider implementation
|
||||
//!
|
||||
//! Claude will use a different API format in the future (/v1/messages)
|
||||
//! For now, fallback to OpenAI-compatible format
|
||||
|
||||
pub mod provider;
|
||||
pub use provider::ClaudeProvider;
|
||||
48
crates/hermesllm/src/providers/claude/provider.rs
Normal file
48
crates/hermesllm/src/providers/claude/provider.rs
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
//! Claude provider implementation
|
||||
//!
|
||||
//! TODO: Implement Claude-specific API format (/v1/messages) when needed
|
||||
//! For now, uses OpenAI-compatible format as fallback
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse};
|
||||
|
||||
/// Claude provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClaudeProvider;
|
||||
|
||||
impl ProviderInterface for ClaudeProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
// TODO: Update when Claude API is fully implemented
|
||||
matches!(api_path, "/v1/chat/completions" | "/v1/messages")
|
||||
}
|
||||
|
||||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
// TODO: Update when Claude API is fully implemented
|
||||
vec!["/v1/messages"]
|
||||
}
|
||||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement Claude-specific request parsing
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement Claude-specific response parsing
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement Claude-specific request serialization
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
6
crates/hermesllm/src/providers/deepseek/mod.rs
Normal file
6
crates/hermesllm/src/providers/deepseek/mod.rs
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
//! Deepseek provider implementation
|
||||
//!
|
||||
//! Deepseek uses OpenAI-compatible API format
|
||||
|
||||
pub mod provider;
|
||||
pub use provider::DeepseekProvider;
|
||||
40
crates/hermesllm/src/providers/deepseek/provider.rs
Normal file
40
crates/hermesllm/src/providers/deepseek/provider.rs
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
//! Deepseek provider implementation
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse};
|
||||
|
||||
/// Deepseek provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DeepseekProvider;
|
||||
|
||||
impl ProviderInterface for DeepseekProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
matches!(api_path, "/v1/chat/completions")
|
||||
}
|
||||
|
||||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
vec!["/v1/chat/completions"]
|
||||
}
|
||||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
7
crates/hermesllm/src/providers/gemini/mod.rs
Normal file
7
crates/hermesllm/src/providers/gemini/mod.rs
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
//! Gemini provider implementation
|
||||
//!
|
||||
//! Gemini will use a different API format in the future
|
||||
//! For now, fallback to OpenAI-compatible format
|
||||
|
||||
pub mod provider;
|
||||
pub use provider::GeminiProvider;
|
||||
48
crates/hermesllm/src/providers/gemini/provider.rs
Normal file
48
crates/hermesllm/src/providers/gemini/provider.rs
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
//! Gemini provider implementation
|
||||
//!
|
||||
//! TODO: Implement Gemini-specific API format when needed
|
||||
//! For now, uses OpenAI-compatible format as fallback
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse};
|
||||
|
||||
/// Gemini provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GeminiProvider;
|
||||
|
||||
impl ProviderInterface for GeminiProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
// TODO: Update when Gemini API is fully implemented
|
||||
matches!(api_path, "/v1/chat/completions" | "/v1/models")
|
||||
}
|
||||
|
||||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
// TODO: Update when Gemini API is fully implemented
|
||||
vec!["/v1/models"]
|
||||
}
|
||||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement Gemini-specific request parsing
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement Gemini-specific response parsing
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement Gemini-specific request serialization
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
7
crates/hermesllm/src/providers/github/mod.rs
Normal file
7
crates/hermesllm/src/providers/github/mod.rs
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
//! GitHub provider implementation
|
||||
//!
|
||||
//! GitHub will use a different API format in the future (/models)
|
||||
//! For now, fallback to OpenAI-compatible format
|
||||
|
||||
pub mod provider;
|
||||
pub use provider::GitHubProvider;
|
||||
48
crates/hermesllm/src/providers/github/provider.rs
Normal file
48
crates/hermesllm/src/providers/github/provider.rs
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
//! GitHub provider implementation
|
||||
//!
|
||||
//! TODO: Implement GitHub-specific API format (/models) when needed
|
||||
//! For now, uses OpenAI-compatible format as fallback
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse};
|
||||
|
||||
/// GitHub provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GitHubProvider;
|
||||
|
||||
impl ProviderInterface for GitHubProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
// TODO: Update when GitHub API is fully implemented
|
||||
matches!(api_path, "/v1/chat/completions" | "/models")
|
||||
}
|
||||
|
||||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
// TODO: Update when GitHub API is fully implemented
|
||||
vec!["/models"]
|
||||
}
|
||||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement GitHub-specific request parsing
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement GitHub-specific response parsing
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement GitHub-specific request serialization
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
6
crates/hermesllm/src/providers/groq/mod.rs
Normal file
6
crates/hermesllm/src/providers/groq/mod.rs
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
//! Groq provider implementation
|
||||
//!
|
||||
//! Groq uses OpenAI-compatible API format but with different endpoints
|
||||
|
||||
pub mod provider;
|
||||
pub use provider::GroqProvider;
|
||||
43
crates/hermesllm/src/providers/groq/provider.rs
Normal file
43
crates/hermesllm/src/providers/groq/provider.rs
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
//! Groq provider implementation
|
||||
//!
|
||||
//! This module contains the Groq provider that handles Groq API format requests.
|
||||
//! Groq uses OpenAI-compatible format but may have provider-specific nuances.
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse};
|
||||
|
||||
/// Groq provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GroqProvider;
|
||||
|
||||
impl ProviderInterface for GroqProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
matches!(api_path, "/v1/chat/completions" | "/openai/v1/chat/completions")
|
||||
}
|
||||
|
||||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
vec!["/openai/v1/chat/completions"]
|
||||
}
|
||||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,109 +0,0 @@
|
|||
//! Provider interface trait definitions
|
||||
//!
|
||||
//! This module defines the core traits that all LLM providers must implement.
|
||||
//! The interface is designed around v1/chat/completions API for simplicity.
|
||||
|
||||
use std::error::Error;
|
||||
|
||||
/// Conversion mode for provider requests/responses.
///
/// Derives `PartialEq`/`Eq` so values can be compared with `==` and used in
/// `assert_eq!`, in addition to pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConversionMode {
    /// Compatible: Convert between different provider formats to ensure compatibility
    Compatible,
    /// Passthrough: Pass requests/responses through with minimal modification
    Passthrough,
}
|
||||
|
||||
/// Read-only access to the token counts of a usage record.
pub trait TokenUsage {
    /// Number of prompt tokens.
    fn prompt_tokens(&self) -> usize;

    /// Number of completion tokens.
    fn completion_tokens(&self) -> usize;

    /// Total number of tokens.
    fn total_tokens(&self) -> usize;
}
|
||||
|
||||
/// Marker trait for errors returned by provider operations.
///
/// It adds no methods; it only requires the implementor to be a
/// `std::error::Error` that is `Send`, `Sync` and `'static`.
pub trait ProviderError: Error + Send + Sync + 'static {}
|
||||
|
||||
/// Request type that can be converted to/from provider-specific formats
|
||||
pub trait ProviderRequest: Sized {
|
||||
type Error: ProviderError;
|
||||
|
||||
/// Parse request from raw bytes (typically JSON)
|
||||
fn from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>;
|
||||
|
||||
/// Convert to bytes for sending to upstream API
|
||||
fn to_bytes(&self, mode: ConversionMode) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Extract the model name from the request
|
||||
fn model(&self) -> &str;
|
||||
|
||||
/// Check if this is a streaming request
|
||||
fn is_streaming(&self) -> bool;
|
||||
|
||||
/// Set streaming options (e.g., include_usage)
|
||||
fn set_streaming_options(&mut self);
|
||||
|
||||
/// Extract text content from messages for token counting
|
||||
fn extract_text(&self) -> String;
|
||||
}
|
||||
|
||||
/// Response type that can be converted to/from provider-specific formats
|
||||
pub trait ProviderResponse: Sized {
|
||||
type Error: ProviderError;
|
||||
type Usage: TokenUsage;
|
||||
|
||||
/// Parse response from raw bytes (typically JSON)
|
||||
fn from_bytes(bytes: &[u8], mode: ConversionMode) -> Result<Self, Self::Error>;
|
||||
|
||||
/// Convert to bytes for sending to client
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Get usage information if available
|
||||
fn usage(&self) -> Option<&Self::Usage>;
|
||||
}
|
||||
|
||||
/// Streaming response chunk
|
||||
pub trait StreamChunk: Sized {
|
||||
type Error: ProviderError;
|
||||
type Usage: TokenUsage;
|
||||
|
||||
/// Parse chunk from a line of streaming data
|
||||
fn from_line(line: &str, mode: ConversionMode) -> Result<Option<Self>, Self::Error>;
|
||||
|
||||
/// Convert to line for sending to client
|
||||
fn to_line(&self) -> Result<String, Self::Error>;
|
||||
|
||||
/// Get usage information if available (usually only in final chunk)
|
||||
fn usage(&self) -> Option<&Self::Usage>;
|
||||
|
||||
/// Check if this is the final chunk in the stream
|
||||
fn is_final(&self) -> bool;
|
||||
}
|
||||
|
||||
/// Main provider interface
|
||||
pub trait LLMProvider {
|
||||
type Request: ProviderRequest;
|
||||
type Response: ProviderResponse;
|
||||
type StreamChunk: StreamChunk;
|
||||
type Error: ProviderError;
|
||||
|
||||
/// Create a new instance of this provider
|
||||
fn new() -> Self;
|
||||
|
||||
/// Get the supported API endpoints for this provider
|
||||
fn supported_apis(&self) -> Vec<&'static str>;
|
||||
|
||||
/// Check if the provider supports v1/chat/completions API
|
||||
fn supports_chat_completions(&self) -> bool {
|
||||
self.supported_apis().contains(&"/v1/chat/completions")
|
||||
}
|
||||
|
||||
/// Parse a request from raw bytes
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<Self::Request, Self::Error>;
|
||||
|
||||
/// Parse a response from raw bytes
|
||||
fn parse_response(&self, bytes: &[u8], mode: ConversionMode) -> Result<Self::Response, Self::Error>;
|
||||
|
||||
/// Parse streaming response chunks from raw data
|
||||
fn parse_stream_chunk(&self, line: &str, mode: ConversionMode) -> Result<Option<Self::StreamChunk>, Self::Error>;
|
||||
}
|
||||
6
crates/hermesllm/src/providers/mistral/mod.rs
Normal file
6
crates/hermesllm/src/providers/mistral/mod.rs
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
//! Mistral provider implementation
|
||||
//!
|
||||
//! Mistral uses OpenAI-compatible API format
|
||||
|
||||
pub mod provider;
|
||||
pub use provider::MistralProvider;
|
||||
40
crates/hermesllm/src/providers/mistral/provider.rs
Normal file
40
crates/hermesllm/src/providers/mistral/provider.rs
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
//! Mistral provider implementation
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse};
|
||||
|
||||
/// Mistral provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MistralProvider;
|
||||
|
||||
impl ProviderInterface for MistralProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
matches!(api_path, "/v1/chat/completions")
|
||||
}
|
||||
|
||||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
vec!["/v1/chat/completions"]
|
||||
}
|
||||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -5,10 +5,24 @@
|
|||
|
||||
pub mod traits;
|
||||
pub mod openai;
|
||||
pub mod groq;
|
||||
pub mod mistral;
|
||||
pub mod deepseek;
|
||||
pub mod arch;
|
||||
pub mod gemini;
|
||||
pub mod claude;
|
||||
pub mod github;
|
||||
|
||||
// Re-export the main interfaces
|
||||
pub use traits::*;
|
||||
pub use openai::OpenAIProvider;
|
||||
pub use groq::GroqProvider;
|
||||
pub use mistral::MistralProvider;
|
||||
pub use deepseek::DeepseekProvider;
|
||||
pub use arch::ArchProvider;
|
||||
pub use gemini::GeminiProvider;
|
||||
pub use claude::ClaudeProvider;
|
||||
pub use github::GitHubProvider;
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
|
|
@ -81,27 +95,29 @@ impl ProviderId {
|
|||
}
|
||||
|
||||
/// Enum for dynamic dispatch of provider instances
|
||||
/// For now, most providers use OpenAI-compatible format
|
||||
pub enum Provider {
|
||||
OpenAI(OpenAIProvider, ProviderId),
|
||||
// TODO: Add specific implementations when providers have different APIs
|
||||
// Mistral(MistralProvider, ProviderId),
|
||||
// Groq(GroqProvider, ProviderId),
|
||||
// etc.
|
||||
Groq(GroqProvider, ProviderId),
|
||||
Mistral(MistralProvider, ProviderId),
|
||||
Deepseek(DeepseekProvider, ProviderId),
|
||||
Arch(ArchProvider, ProviderId),
|
||||
Gemini(GeminiProvider, ProviderId),
|
||||
Claude(ClaudeProvider, ProviderId),
|
||||
GitHub(GitHubProvider, ProviderId),
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
/// Create a provider instance from a provider ID
|
||||
pub fn new(id: ProviderId) -> Self {
|
||||
match id {
|
||||
// For now, all providers that support v1/chat/completions use OpenAI format
|
||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch => {
|
||||
Provider::OpenAI(OpenAIProvider, id)
|
||||
}
|
||||
// TODO: Implement specific providers when they have different APIs
|
||||
ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||
Provider::OpenAI(OpenAIProvider, id) // Fallback to OpenAI for now
|
||||
}
|
||||
ProviderId::OpenAI => Provider::OpenAI(OpenAIProvider, id),
|
||||
ProviderId::Groq => Provider::Groq(GroqProvider, id),
|
||||
ProviderId::Mistral => Provider::Mistral(MistralProvider, id),
|
||||
ProviderId::Deepseek => Provider::Deepseek(DeepseekProvider, id),
|
||||
ProviderId::Arch => Provider::Arch(ArchProvider, id),
|
||||
ProviderId::Gemini => Provider::Gemini(GeminiProvider, id),
|
||||
ProviderId::Claude => Provider::Claude(ClaudeProvider, id),
|
||||
ProviderId::GitHub => Provider::GitHub(GitHubProvider, id),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -109,66 +125,27 @@ impl Provider {
|
|||
pub fn id(&self) -> ProviderId {
|
||||
match self {
|
||||
Provider::OpenAI(_, id) => *id,
|
||||
Provider::Groq(_, id) => *id,
|
||||
Provider::Mistral(_, id) => *id,
|
||||
Provider::Deepseek(_, id) => *id,
|
||||
Provider::Arch(_, id) => *id,
|
||||
Provider::Gemini(_, id) => *id,
|
||||
Provider::Claude(_, id) => *id,
|
||||
Provider::GitHub(_, id) => *id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this provider has a compatible API with the client request
|
||||
pub fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
/// Get the provider interface implementation
|
||||
pub fn interface(&self) -> &dyn ProviderInterface {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => provider.has_compatible_api(api_path),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the interface implementation for this provider
|
||||
pub fn get_interface(&self, passthrough: bool) -> ConversionMode {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => provider.get_interface(passthrough),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a request from raw bytes - returns the concrete OpenAI request type for now
|
||||
pub fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match self {
|
||||
Provider::OpenAI(_, _) => {
|
||||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::providers::traits::ProviderRequest;
|
||||
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a response from raw bytes - returns the concrete OpenAI response type for now
|
||||
pub fn parse_response(&self, bytes: &[u8], mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match self {
|
||||
Provider::OpenAI(_, _) => {
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
use crate::providers::traits::ProviderResponse;
|
||||
|
||||
let provider_id = self.id();
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a request to bytes for sending to upstream API
|
||||
pub fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match self {
|
||||
Provider::OpenAI(_, _) => {
|
||||
use crate::providers::traits::ProviderRequest;
|
||||
|
||||
let provider_id = self.id();
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
Provider::OpenAI(provider, _) => provider,
|
||||
Provider::Groq(provider, _) => provider,
|
||||
Provider::Mistral(provider, _) => provider,
|
||||
Provider::Deepseek(provider, _) => provider,
|
||||
Provider::Arch(provider, _) => provider,
|
||||
Provider::Gemini(provider, _) => provider,
|
||||
Provider::Claude(provider, _) => provider,
|
||||
Provider::GitHub(provider, _) => provider,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -68,11 +68,6 @@ impl Iterator for OpenAIStreamingResponse {
|
|||
}
|
||||
|
||||
impl ProviderInterface for OpenAIProvider {
|
||||
type Request = ChatCompletionsRequest;
|
||||
type Response = ChatCompletionsResponse;
|
||||
type StreamingResponse = OpenAIStreamingResponse;
|
||||
type Usage = Usage;
|
||||
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
api_path == "/v1/chat/completions"
|
||||
}
|
||||
|
|
@ -80,6 +75,30 @@ impl ProviderInterface for OpenAIProvider {
|
|||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
vec!["/v1/chat/completions"]
|
||||
}
|
||||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::providers::traits::ProviderRequest;
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::providers::traits::ProviderResponse;
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::providers::traits::ProviderRequest;
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -131,6 +150,9 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
}
|
||||
}
|
||||
|
||||
// Implement the helper trait for stream context integration
|
||||
impl crate::providers::traits::StreamContextHelpers for ChatCompletionsRequest {}
|
||||
|
||||
impl TokenUsage for Usage {
|
||||
fn completion_tokens(&self) -> usize {
|
||||
self.completion_tokens as usize
|
||||
|
|
|
|||
|
|
@ -54,6 +54,31 @@ pub trait ProviderResponse: Sized {
|
|||
|
||||
/// Get usage information if available
|
||||
fn usage(&self) -> Option<&Self::Usage>;
|
||||
|
||||
/// Extract token counts for metrics
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper trait for stream context integration
|
||||
pub trait StreamContextHelpers: ProviderRequest {
|
||||
/// Get the model name for routing and metrics
|
||||
fn get_model_for_routing(&self) -> String {
|
||||
self.extract_model().to_string()
|
||||
}
|
||||
|
||||
/// Get text for token counting and rate limiting
|
||||
fn get_text_for_tokenization(&self) -> String {
|
||||
self.extract_messages_text()
|
||||
}
|
||||
|
||||
/// Prepare for streaming by setting appropriate options
|
||||
fn prepare_for_streaming(&mut self) {
|
||||
if self.is_streaming() {
|
||||
self.set_streaming_options();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for streaming response chunks
|
||||
|
|
@ -75,11 +100,6 @@ pub trait StreamingResponse: Iterator<Item = Result<Self::Chunk, Self::Error>> +
|
|||
|
||||
/// Main provider interface trait
|
||||
pub trait ProviderInterface {
|
||||
type Request: ProviderRequest;
|
||||
type Response: ProviderResponse;
|
||||
type StreamingResponse: StreamingResponse;
|
||||
type Usage: TokenUsage;
|
||||
|
||||
/// Check if this provider has a compatible API with the client request
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool;
|
||||
|
||||
|
|
@ -93,6 +113,47 @@ pub trait ProviderInterface {
|
|||
}
|
||||
}
|
||||
|
||||
/// Parse a request from raw bytes - returns concrete ChatCompletionsRequest
|
||||
///
/// # Errors
/// Failures are reported as a boxed `Send + Sync` error; which conditions
/// produce one is decided by the implementing provider.
fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
/// Parse a response from raw bytes - returns concrete ChatCompletionsResponse
|
||||
///
/// `provider_id` and `mode` are supplied by the caller alongside the raw
/// `bytes`; presumably they select how the payload is interpreted — TODO
/// confirm against each implementor.
///
/// # Errors
/// Failures are reported as a boxed `Send + Sync` error.
fn parse_response(&self, bytes: &[u8], provider_id: super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
/// Convert a request to bytes for sending to upstream API
|
||||
///
/// The request is only borrowed; the serialized payload is returned as an
/// owned `Vec<u8>`.
///
/// # Errors
/// Failures are reported as a boxed `Send + Sync` error.
fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
/// Extract model name from request for routing (convenience method for stream_context)
|
||||
fn extract_model_from_request(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String {
|
||||
use ProviderRequest;
|
||||
request.extract_model().to_string()
|
||||
}
|
||||
|
||||
/// Check if request is streaming (convenience method for stream_context)
|
||||
fn is_request_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool {
|
||||
use ProviderRequest;
|
||||
request.is_streaming()
|
||||
}
|
||||
|
||||
/// Prepare request for streaming (convenience method for stream_context)
|
||||
fn prepare_request_for_streaming(&self, request: &mut crate::apis::openai::ChatCompletionsRequest) {
|
||||
use ProviderRequest;
|
||||
if request.is_streaming() {
|
||||
request.set_streaming_options();
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract text for tokenization (convenience method for stream_context)
|
||||
fn extract_text_for_tokenization(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String {
|
||||
use ProviderRequest;
|
||||
request.extract_messages_text()
|
||||
}
|
||||
|
||||
/// Extract usage information from response (convenience method for stream_context)
|
||||
fn extract_usage_from_response(&self, response: &crate::apis::openai::ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
use ProviderResponse;
|
||||
response.extract_usage_counts()
|
||||
}
|
||||
|
||||
/// Get supported API endpoints for this provider
|
||||
///
/// Each entry is a `'static` string slice, so the returned list does not
/// borrow from `self`.
fn supported_apis(&self) -> Vec<&'static str>;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use common::ratelimit::Header;
|
|||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::{ConversionMode, Provider, ProviderId, ProviderRequest};
|
||||
use hermesllm::{ConversionMode, Provider, ProviderId};
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
|
|
@ -295,7 +295,7 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let provider = self.get_provider();
|
||||
|
||||
let mut deserialized_body = match provider.parse_request(&body_bytes) {
|
||||
let mut deserialized_body = match provider.interface().parse_request(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
|
|
@ -310,8 +310,8 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
// TODO: For now, we'll need to handle user_message extraction differently since it's OpenAI-specific
|
||||
// This could be made generic by adding a trait method later
|
||||
// TODO: For now, we'll work with the concrete ChatCompletionsRequest type
|
||||
// In the future, this could be made more generic using trait objects
|
||||
|
||||
let model_name = match self.llm_provider.as_ref() {
|
||||
Some(llm_provider) => llm_provider.model.as_ref(),
|
||||
|
|
@ -323,9 +323,10 @@ impl HttpContext for StreamContext {
|
|||
None => false,
|
||||
};
|
||||
|
||||
let model_requested = deserialized_body.extract_model().to_string();
|
||||
// Note: We can't directly modify the model field through the trait,
|
||||
// this would need to be handled differently in a full generic implementation
|
||||
// Use the provider interface methods for cleaner interaction
|
||||
let model_requested = provider
|
||||
.interface()
|
||||
.extract_model_from_request(&deserialized_body);
|
||||
|
||||
info!(
|
||||
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
||||
|
|
@ -334,15 +335,21 @@ impl HttpContext for StreamContext {
|
|||
model_name.unwrap_or(&"None".to_string()),
|
||||
);
|
||||
|
||||
if deserialized_body.is_streaming() {
|
||||
// Use provider interface for streaming detection and setup
|
||||
if provider
|
||||
.interface()
|
||||
.is_request_streaming(&deserialized_body)
|
||||
{
|
||||
self.streaming_response = true;
|
||||
}
|
||||
if deserialized_body.is_streaming() {
|
||||
deserialized_body.set_streaming_options();
|
||||
provider
|
||||
.interface()
|
||||
.prepare_request_for_streaming(&mut deserialized_body);
|
||||
}
|
||||
|
||||
// only use the tokens from the messages, excluding the metadata and json tags
|
||||
let input_tokens_str = deserialized_body.extract_messages_text();
|
||||
// Use provider interface for text extraction
|
||||
let input_tokens_str = provider
|
||||
.interface()
|
||||
.extract_text_for_tokenization(&deserialized_body);
|
||||
// enforce ratelimits on ingress
|
||||
if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) {
|
||||
self.send_server_error(
|
||||
|
|
@ -354,12 +361,14 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
|
||||
let llm_provider_str = self.llm_provider().provider_interface.to_string();
|
||||
let hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
|
||||
let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
|
||||
|
||||
// convert chat completion request to llm provider specific request
|
||||
let deserialized_body_bytes = match deserialized_body
|
||||
.to_provider_bytes(hermes_llm_provider_id, ConversionMode::Compatible)
|
||||
{
|
||||
// Convert chat completion request to llm provider specific request using provider interface
|
||||
let deserialized_body_bytes = match provider.interface().request_to_bytes(
|
||||
&deserialized_body,
|
||||
provider.id(),
|
||||
ConversionMode::Compatible,
|
||||
) {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize request body: {}", e);
|
||||
|
|
@ -558,8 +567,12 @@ impl HttpContext for StreamContext {
|
|||
} else {
|
||||
debug!("non streaming response");
|
||||
let provider = self.get_provider();
|
||||
let _response = match provider.parse_response(&body, ConversionMode::Compatible) {
|
||||
Ok(response_box) => response_box,
|
||||
let response = match provider.interface().parse_response(
|
||||
&body,
|
||||
provider.id(),
|
||||
ConversionMode::Compatible,
|
||||
) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
|
|
@ -579,9 +592,18 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
// TODO: Extract usage information from the response box
|
||||
// For now, we'll skip this until we have a better way to handle Any types
|
||||
warn!("Response token counting not yet implemented with new provider structure");
|
||||
// Use provider interface to extract usage information
|
||||
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
|
||||
provider.interface().extract_usage_from_response(&response)
|
||||
{
|
||||
debug!(
|
||||
"Response usage: prompt={}, completion={}, total={}",
|
||||
prompt_tokens, completion_tokens, total_tokens
|
||||
);
|
||||
self.response_tokens = completion_tokens;
|
||||
} else {
|
||||
warn!("No usage information found in response");
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue