mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
saving changes, although we will need a small re-factor after this as well
This commit is contained in:
parent
203fc8f9a9
commit
63f23efda4
10 changed files with 414 additions and 259 deletions
|
|
@ -178,14 +178,13 @@ impl Display for LlmProviderType {
|
|||
}
|
||||
|
||||
impl LlmProviderType {
|
||||
/// Create a ProviderInstance from this LlmProviderType
|
||||
/// Create a Provider from this LlmProviderType
|
||||
/// This is the main method for stream_context to get provider-specific interfaces
|
||||
pub fn create_provider_instance(&self) -> hermesllm::ProviderInstance {
|
||||
use hermesllm::ProviderInstance;
|
||||
pub fn create_provider(&self) -> hermesllm::Provider {
|
||||
use hermesllm::{ProviderId, Provider};
|
||||
|
||||
// For now, all providers use OpenAI-compatible APIs
|
||||
// TODO: Return specific provider instances when implementing different APIs
|
||||
ProviderInstance::from_name(&self.to_string())
|
||||
let provider_id = ProviderId::from(self.to_string().as_str());
|
||||
Provider::new(provider_id)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -265,10 +264,10 @@ impl Display for LlmProvider {
|
|||
}
|
||||
|
||||
impl LlmProvider {
|
||||
/// Create a ProviderInstance for this LlmProvider
|
||||
/// Create a Provider for this LlmProvider
|
||||
/// This is a convenience method that delegates to the provider_interface
|
||||
pub fn create_provider_instance(&self) -> hermesllm::ProviderInstance {
|
||||
self.provider_interface.create_provider_instance()
|
||||
pub fn create_provider(&self) -> hermesllm::Provider {
|
||||
self.provider_interface.create_provider()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,121 +5,59 @@ pub mod providers;
|
|||
pub mod apis;
|
||||
pub mod clients;
|
||||
|
||||
// Re-export important traits
|
||||
pub use providers::traits::*;
|
||||
pub use providers::openai::provider::OpenAIProvider;
|
||||
pub use providers::provider_enum::ProviderInstance;
|
||||
|
||||
|
||||
use std::fmt::Display;
|
||||
pub enum Provider {
|
||||
Arch,
|
||||
Mistral,
|
||||
Deepseek,
|
||||
Groq,
|
||||
Gemini,
|
||||
OpenAI,
|
||||
Claude,
|
||||
Github,
|
||||
}
|
||||
|
||||
impl From<&str> for Provider {
|
||||
fn from(value: &str) -> Self {
|
||||
match value.to_lowercase().as_str() {
|
||||
"arch" => Provider::Arch,
|
||||
"mistral" => Provider::Mistral,
|
||||
"deepseek" => Provider::Deepseek,
|
||||
"groq" => Provider::Groq,
|
||||
"gemini" => Provider::Gemini,
|
||||
"openai" => Provider::OpenAI,
|
||||
"claude" => Provider::Claude,
|
||||
"github" => Provider::Github,
|
||||
_ => panic!("Unknown provider: {}", value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
/// Get the API endpoint path for this provider
|
||||
pub fn api_path(&self) -> &'static str {
|
||||
match self {
|
||||
Provider::OpenAI => "/v1/chat/completions",
|
||||
Provider::Groq => "/openai/v1/chat/completions", // Groq maps to OpenAI-compatible endpoint
|
||||
Provider::Gemini => "/v1/models", // TODO: Update with correct Gemini path
|
||||
Provider::Claude => "/v1/messages", // TODO: Update with correct Claude path
|
||||
Provider::Mistral => "/v1/chat/completions", // Mistral uses OpenAI-compatible API
|
||||
Provider::Deepseek => "/v1/chat/completions", // DeepSeek uses OpenAI-compatible API
|
||||
Provider::Arch => "/v1/chat/completions", // Arch gateway endpoint
|
||||
Provider::Github => "/models", // TODO: Update with correct GitHub models path
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this provider uses OpenAI-compatible API format
|
||||
pub fn uses_openai_format(&self) -> bool {
|
||||
match self {
|
||||
Provider::OpenAI | Provider::Groq | Provider::Mistral | Provider::Deepseek | Provider::Arch => true,
|
||||
Provider::Gemini | Provider::Claude | Provider::Github => false, // These have their own formats
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a provider implementation instance for this provider
|
||||
pub fn create_provider_instance(&self) -> ProviderInstance {
|
||||
match self {
|
||||
Provider::OpenAI => ProviderInstance::OpenAI(OpenAIProvider),
|
||||
Provider::Groq => ProviderInstance::OpenAI(OpenAIProvider), // Groq uses OpenAI-compatible API
|
||||
Provider::Mistral => ProviderInstance::OpenAI(OpenAIProvider), // Mistral uses OpenAI-compatible API
|
||||
Provider::Deepseek => ProviderInstance::OpenAI(OpenAIProvider), // Deepseek uses OpenAI-compatible API
|
||||
Provider::Arch => ProviderInstance::OpenAI(OpenAIProvider), // Arch gateway uses OpenAI-compatible API
|
||||
// TODO: Implement specific providers for these when they have different APIs
|
||||
Provider::Gemini => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible
|
||||
Provider::Claude => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible
|
||||
Provider::Github => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Provider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Provider::Arch => write!(f, "Arch"),
|
||||
Provider::Mistral => write!(f, "Mistral"),
|
||||
Provider::Deepseek => write!(f, "Deepseek"),
|
||||
Provider::Groq => write!(f, "Groq"),
|
||||
Provider::Gemini => write!(f, "Gemini"),
|
||||
Provider::OpenAI => write!(f, "OpenAI"),
|
||||
Provider::Claude => write!(f, "Claude"),
|
||||
Provider::Github => write!(f, "Github"),
|
||||
}
|
||||
}
|
||||
}
|
||||
// Re-export important types and traits
|
||||
pub use providers::{
|
||||
ProviderId, Provider, ConversionMode,
|
||||
ProviderInterface, ProviderRequest, ProviderResponse,
|
||||
TokenUsage, StreamChunk, StreamingResponse,
|
||||
OpenAIProvider
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::providers::openai::types::{ChatCompletionsRequest, Message};
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn openai_builder() {
|
||||
let request =
|
||||
ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![Message::new("Hi".to_string())])
|
||||
.temperature(0.7)
|
||||
.top_p(0.9)
|
||||
.n(1)
|
||||
.max_tokens(100)
|
||||
.stream(false)
|
||||
.stop(vec!["\n".to_string()])
|
||||
.presence_penalty(0.0)
|
||||
.frequency_penalty(0.0)
|
||||
.build()
|
||||
.expect("Failed to build OpenAIRequest");
|
||||
fn test_provider_id_conversion() {
|
||||
assert_eq!(ProviderId::from("openai"), ProviderId::OpenAI);
|
||||
assert_eq!(ProviderId::from("mistral"), ProviderId::Mistral);
|
||||
assert_eq!(ProviderId::from("groq"), ProviderId::Groq);
|
||||
assert_eq!(ProviderId::from("arch"), ProviderId::Arch);
|
||||
}
|
||||
|
||||
assert_eq!(request.model, "gpt-3.5-turbo");
|
||||
assert_eq!(request.temperature, Some(0.7));
|
||||
assert_eq!(request.top_p, Some(0.9));
|
||||
assert_eq!(request.n, Some(1));
|
||||
assert_eq!(request.max_tokens, Some(100));
|
||||
assert_eq!(request.stream, Some(false));
|
||||
assert_eq!(request.stop, Some(vec!["\n".to_string()]));
|
||||
assert_eq!(request.presence_penalty, Some(0.0));
|
||||
assert_eq!(request.frequency_penalty, Some(0.0));
|
||||
#[test]
|
||||
fn test_provider_api_paths() {
|
||||
assert_eq!(ProviderId::OpenAI.api_path(), "/v1/chat/completions");
|
||||
assert_eq!(ProviderId::Groq.api_path(), "/openai/v1/chat/completions");
|
||||
assert_eq!(ProviderId::Mistral.api_path(), "/v1/chat/completions");
|
||||
assert_eq!(ProviderId::Arch.api_path(), "/v1/chat/completions");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_openai_format_support() {
|
||||
assert!(ProviderId::OpenAI.supports_openai_format());
|
||||
assert!(ProviderId::Groq.supports_openai_format());
|
||||
assert!(ProviderId::Mistral.supports_openai_format());
|
||||
assert!(ProviderId::Arch.supports_openai_format());
|
||||
assert!(!ProviderId::Gemini.supports_openai_format());
|
||||
assert!(!ProviderId::Claude.supports_openai_format());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_instance_creation() {
|
||||
let provider = Provider::new(ProviderId::OpenAI);
|
||||
assert!(provider.has_compatible_api("/v1/chat/completions"));
|
||||
assert!(!provider.has_compatible_api("/v1/embeddings"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversion_mode() {
|
||||
let provider = Provider::new(ProviderId::OpenAI);
|
||||
|
||||
let compatible_mode = provider.get_interface(false);
|
||||
assert!(matches!(compatible_mode, ConversionMode::Compatible));
|
||||
|
||||
let passthrough_mode = provider.get_interface(true);
|
||||
assert!(matches!(passthrough_mode, ConversionMode::Passthrough));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
109
crates/hermesllm/src/providers/interface.rs
Normal file
109
crates/hermesllm/src/providers/interface.rs
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
//! Provider interface trait definitions
|
||||
//!
|
||||
//! This module defines the core traits that all LLM providers must implement.
|
||||
//! The interface is designed around v1/chat/completions API for simplicity.
|
||||
|
||||
use std::error::Error;
|
||||
|
||||
/// Conversion mode for provider requests/responses
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum ConversionMode {
|
||||
/// Compatible: Convert between different provider formats to ensure compatibility
|
||||
Compatible,
|
||||
/// Passthrough: Pass requests/responses through with minimal modification
|
||||
Passthrough,
|
||||
}
|
||||
|
||||
/// Token usage information
|
||||
pub trait TokenUsage {
|
||||
fn completion_tokens(&self) -> usize;
|
||||
fn prompt_tokens(&self) -> usize;
|
||||
fn total_tokens(&self) -> usize;
|
||||
}
|
||||
|
||||
/// Error type for provider operations
|
||||
pub trait ProviderError: Error + Send + Sync + 'static {}
|
||||
|
||||
/// Request type that can be converted to/from provider-specific formats
|
||||
pub trait ProviderRequest: Sized {
|
||||
type Error: ProviderError;
|
||||
|
||||
/// Parse request from raw bytes (typically JSON)
|
||||
fn from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>;
|
||||
|
||||
/// Convert to bytes for sending to upstream API
|
||||
fn to_bytes(&self, mode: ConversionMode) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Extract the model name from the request
|
||||
fn model(&self) -> &str;
|
||||
|
||||
/// Check if this is a streaming request
|
||||
fn is_streaming(&self) -> bool;
|
||||
|
||||
/// Set streaming options (e.g., include_usage)
|
||||
fn set_streaming_options(&mut self);
|
||||
|
||||
/// Extract text content from messages for token counting
|
||||
fn extract_text(&self) -> String;
|
||||
}
|
||||
|
||||
/// Response type that can be converted to/from provider-specific formats
|
||||
pub trait ProviderResponse: Sized {
|
||||
type Error: ProviderError;
|
||||
type Usage: TokenUsage;
|
||||
|
||||
/// Parse response from raw bytes (typically JSON)
|
||||
fn from_bytes(bytes: &[u8], mode: ConversionMode) -> Result<Self, Self::Error>;
|
||||
|
||||
/// Convert to bytes for sending to client
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Get usage information if available
|
||||
fn usage(&self) -> Option<&Self::Usage>;
|
||||
}
|
||||
|
||||
/// Streaming response chunk
|
||||
pub trait StreamChunk: Sized {
|
||||
type Error: ProviderError;
|
||||
type Usage: TokenUsage;
|
||||
|
||||
/// Parse chunk from a line of streaming data
|
||||
fn from_line(line: &str, mode: ConversionMode) -> Result<Option<Self>, Self::Error>;
|
||||
|
||||
/// Convert to line for sending to client
|
||||
fn to_line(&self) -> Result<String, Self::Error>;
|
||||
|
||||
/// Get usage information if available (usually only in final chunk)
|
||||
fn usage(&self) -> Option<&Self::Usage>;
|
||||
|
||||
/// Check if this is the final chunk in the stream
|
||||
fn is_final(&self) -> bool;
|
||||
}
|
||||
|
||||
/// Main provider interface
|
||||
pub trait LLMProvider {
|
||||
type Request: ProviderRequest;
|
||||
type Response: ProviderResponse;
|
||||
type StreamChunk: StreamChunk;
|
||||
type Error: ProviderError;
|
||||
|
||||
/// Create a new instance of this provider
|
||||
fn new() -> Self;
|
||||
|
||||
/// Get the supported API endpoints for this provider
|
||||
fn supported_apis(&self) -> Vec<&'static str>;
|
||||
|
||||
/// Check if the provider supports v1/chat/completions API
|
||||
fn supports_chat_completions(&self) -> bool {
|
||||
self.supported_apis().contains(&"/v1/chat/completions")
|
||||
}
|
||||
|
||||
/// Parse a request from raw bytes
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<Self::Request, Self::Error>;
|
||||
|
||||
/// Parse a response from raw bytes
|
||||
fn parse_response(&self, bytes: &[u8], mode: ConversionMode) -> Result<Self::Response, Self::Error>;
|
||||
|
||||
/// Parse streaming response chunks from raw data
|
||||
fn parse_stream_chunk(&self, line: &str, mode: ConversionMode) -> Result<Option<Self::StreamChunk>, Self::Error>;
|
||||
}
|
||||
|
|
@ -1,3 +1,174 @@
|
|||
pub mod openai;
|
||||
//! Provider implementations for different LLM APIs
|
||||
//!
|
||||
//! This module contains provider-specific implementations that handle
|
||||
//! request/response conversion for different LLM service APIs.
|
||||
|
||||
pub mod traits;
|
||||
pub mod provider_enum;
|
||||
pub mod openai;
|
||||
|
||||
// Re-export the main interfaces
|
||||
pub use traits::*;
|
||||
pub use openai::OpenAIProvider;
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
/// Provider identifier enum - simple enum for identifying providers
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum ProviderId {
|
||||
OpenAI,
|
||||
Mistral,
|
||||
Deepseek,
|
||||
Groq,
|
||||
Gemini,
|
||||
Claude,
|
||||
GitHub,
|
||||
Arch,
|
||||
}
|
||||
|
||||
impl From<&str> for ProviderId {
|
||||
fn from(value: &str) -> Self {
|
||||
match value.to_lowercase().as_str() {
|
||||
"openai" => ProviderId::OpenAI,
|
||||
"mistral" => ProviderId::Mistral,
|
||||
"deepseek" => ProviderId::Deepseek,
|
||||
"groq" => ProviderId::Groq,
|
||||
"gemini" => ProviderId::Gemini,
|
||||
"claude" => ProviderId::Claude,
|
||||
"github" => ProviderId::GitHub,
|
||||
"arch" => ProviderId::Arch,
|
||||
_ => panic!("Unknown provider: {}", value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ProviderId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ProviderId::OpenAI => write!(f, "OpenAI"),
|
||||
ProviderId::Mistral => write!(f, "Mistral"),
|
||||
ProviderId::Deepseek => write!(f, "Deepseek"),
|
||||
ProviderId::Groq => write!(f, "Groq"),
|
||||
ProviderId::Gemini => write!(f, "Gemini"),
|
||||
ProviderId::Claude => write!(f, "Claude"),
|
||||
ProviderId::GitHub => write!(f, "GitHub"),
|
||||
ProviderId::Arch => write!(f, "Arch"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderId {
|
||||
/// Get the API endpoint path for this provider
|
||||
pub fn api_path(&self) -> &'static str {
|
||||
match self {
|
||||
ProviderId::OpenAI => "/v1/chat/completions",
|
||||
ProviderId::Groq => "/openai/v1/chat/completions",
|
||||
ProviderId::Gemini => "/v1/models", // TODO: Update when Gemini API is implemented
|
||||
ProviderId::Claude => "/v1/messages", // TODO: Update when Claude API is implemented
|
||||
ProviderId::Mistral => "/v1/chat/completions",
|
||||
ProviderId::Deepseek => "/v1/chat/completions",
|
||||
ProviderId::GitHub => "/models", // TODO: Update when GitHub models API is implemented
|
||||
ProviderId::Arch => "/v1/chat/completions",
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this provider supports OpenAI v1/chat/completions API format
|
||||
pub fn supports_openai_format(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Enum for dynamic dispatch of provider instances
|
||||
/// For now, most providers use OpenAI-compatible format
|
||||
pub enum Provider {
|
||||
OpenAI(OpenAIProvider, ProviderId),
|
||||
// TODO: Add specific implementations when providers have different APIs
|
||||
// Mistral(MistralProvider, ProviderId),
|
||||
// Groq(GroqProvider, ProviderId),
|
||||
// etc.
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
/// Create a provider instance from a provider ID
|
||||
pub fn new(id: ProviderId) -> Self {
|
||||
match id {
|
||||
// For now, all providers that support v1/chat/completions use OpenAI format
|
||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch => {
|
||||
Provider::OpenAI(OpenAIProvider, id)
|
||||
}
|
||||
// TODO: Implement specific providers when they have different APIs
|
||||
ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||
Provider::OpenAI(OpenAIProvider, id) // Fallback to OpenAI for now
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the provider ID
|
||||
pub fn id(&self) -> ProviderId {
|
||||
match self {
|
||||
Provider::OpenAI(_, id) => *id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this provider has a compatible API with the client request
|
||||
pub fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => provider.has_compatible_api(api_path),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the interface implementation for this provider
|
||||
pub fn get_interface(&self, passthrough: bool) -> ConversionMode {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => provider.get_interface(passthrough),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a request from raw bytes - returns the concrete OpenAI request type for now
|
||||
pub fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match self {
|
||||
Provider::OpenAI(_, _) => {
|
||||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::providers::traits::ProviderRequest;
|
||||
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a response from raw bytes - returns the concrete OpenAI response type for now
|
||||
pub fn parse_response(&self, bytes: &[u8], mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match self {
|
||||
Provider::OpenAI(_, _) => {
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
use crate::providers::traits::ProviderResponse;
|
||||
|
||||
let provider_id = self.id();
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a request to bytes for sending to upstream API
|
||||
pub fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match self {
|
||||
Provider::OpenAI(_, _) => {
|
||||
use crate::providers::traits::ProviderRequest;
|
||||
|
||||
let provider_id = self.id();
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
pub mod builder;
|
||||
pub mod types;
|
||||
pub mod provider;
|
||||
|
||||
// Re-export the main provider
|
||||
pub use provider::OpenAIProvider;
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
use crate::apis::openai::*;
|
||||
use crate::providers::traits::*;
|
||||
use crate::Provider;
|
||||
|
||||
// Simple error type for OpenAI API operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
|
@ -73,6 +72,14 @@ impl ProviderInterface for OpenAIProvider {
|
|||
type Response = ChatCompletionsResponse;
|
||||
type StreamingResponse = OpenAIStreamingResponse;
|
||||
type Usage = Usage;
|
||||
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
api_path == "/v1/chat/completions"
|
||||
}
|
||||
|
||||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
vec!["/v1/chat/completions"]
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -87,7 +94,7 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
Ok(serde_json::from_str(s)?)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, _provider: Provider) -> Result<Vec<u8>, Self::Error> {
|
||||
fn to_provider_bytes(&self, _provider: super::super::ProviderId, _mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
Ok(serde_json::to_vec(self)?)
|
||||
}
|
||||
|
||||
|
|
@ -142,7 +149,7 @@ impl ProviderResponse for ChatCompletionsResponse {
|
|||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(bytes: &[u8], _provider: &Provider) -> Result<Self, Self::Error> {
|
||||
fn try_from_bytes(bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(serde_json::from_str(s)?)
|
||||
}
|
||||
|
|
@ -164,7 +171,7 @@ impl StreamingResponse for OpenAIStreamingResponse {
|
|||
type Error = OpenAIApiError;
|
||||
type Chunk = ChatCompletionsStreamResponse;
|
||||
|
||||
fn try_from_bytes(bytes: &[u8], _provider: &Provider) -> Result<Self, Self::Error> {
|
||||
fn try_from_bytes(bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(OpenAIStreamingResponse::new(s.to_string()))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ use std::convert::TryFrom;
|
|||
use std::str;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::Provider;
|
||||
use crate::providers::ProviderId;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OpenAIError {
|
||||
|
|
@ -144,28 +144,26 @@ impl TryFrom<&[u8]> for ChatCompletionsResponse {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<(&'a [u8], &'a Provider)> for ChatCompletionsResponse {
|
||||
impl<'a> TryFrom<(&'a [u8], &'a ProviderId)> for ChatCompletionsResponse {
|
||||
type Error = OpenAIError;
|
||||
|
||||
fn try_from(input: (&'a [u8], &'a Provider)) -> Result<Self> {
|
||||
fn try_from(input: (&'a [u8], &'a ProviderId)) -> Result<Self> {
|
||||
// Use input.provider as needed, if necessary
|
||||
serde_json::from_slice(input.0).map_err(OpenAIError::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl ChatCompletionsRequest {
|
||||
pub fn to_bytes(&self, provider: Provider) -> Result<Vec<u8>> {
|
||||
pub fn to_bytes(&self, provider: ProviderId) -> Result<Vec<u8>> {
|
||||
match provider {
|
||||
Provider::OpenAI
|
||||
| Provider::Arch
|
||||
| Provider::Deepseek
|
||||
| Provider::Mistral
|
||||
| Provider::Groq
|
||||
| Provider::Gemini
|
||||
| Provider::Claude => serde_json::to_vec(self).map_err(OpenAIError::from),
|
||||
_ => Err(OpenAIError::UnsupportedProvider {
|
||||
provider: provider.to_string(),
|
||||
}),
|
||||
ProviderId::OpenAI
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Mistral
|
||||
| ProviderId::Groq
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::Claude
|
||||
| ProviderId::GitHub => serde_json::to_vec(self).map_err(OpenAIError::from),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -262,10 +260,10 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<(&'a [u8], &'a Provider)> for SseChatCompletionIter<str::Lines<'a>> {
|
||||
impl<'a> TryFrom<(&'a [u8], &'a ProviderId)> for SseChatCompletionIter<str::Lines<'a>> {
|
||||
type Error = OpenAIError;
|
||||
|
||||
fn try_from(input: (&'a [u8], &'a Provider)) -> Result<Self> {
|
||||
fn try_from(input: (&'a [u8], &'a ProviderId)) -> Result<Self> {
|
||||
let s = std::str::from_utf8(input.0)?;
|
||||
// Use input.provider as needed
|
||||
Ok(SseChatCompletionIter::new(s.lines()))
|
||||
|
|
|
|||
|
|
@ -1,67 +0,0 @@
|
|||
use crate::providers::traits::*;
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
|
||||
/// Enum that wraps all possible providers for dynamic dispatch
|
||||
pub enum ProviderInstance {
|
||||
OpenAI(OpenAIProvider),
|
||||
// TODO: Add other providers as they are implemented
|
||||
// Anthropic(AnthropicProvider),
|
||||
// Mistral(MistralProvider),
|
||||
// etc.
|
||||
}
|
||||
|
||||
impl ProviderInstance {
|
||||
/// Creates a provider from a provider name string
|
||||
pub fn from_name(name: &str) -> Self {
|
||||
match name.to_lowercase().as_str() {
|
||||
"openai" | "groq" | "gemini" | "mistral" | "deepseek" | "arch" | "claude" => {
|
||||
ProviderInstance::OpenAI(OpenAIProvider)
|
||||
}
|
||||
// TODO: Add other providers when implemented
|
||||
// "claude" | "anthropic" => ProviderInstance::Anthropic(AnthropicProvider),
|
||||
// "mistral" => ProviderInstance::Mistral(MistralProvider),
|
||||
_ => {
|
||||
// Default to OpenAI for unknown providers
|
||||
ProviderInstance::OpenAI(OpenAIProvider)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse request from bytes using the appropriate provider
|
||||
pub fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match self {
|
||||
ProviderInstance::OpenAI(_) => {
|
||||
ChatCompletionsRequest::try_from_bytes(bytes).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
}
|
||||
// TODO: Add other provider cases when implemented
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse response from bytes using the appropriate provider
|
||||
pub fn parse_response(&self, bytes: &[u8], provider: &crate::Provider) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match self {
|
||||
ProviderInstance::OpenAI(_) => {
|
||||
ChatCompletionsResponse::try_from_bytes(bytes, provider).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
}
|
||||
// TODO: Add other provider cases when implemented
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse streaming response from bytes using the appropriate provider
|
||||
pub fn parse_streaming_response(&self, bytes: &[u8], provider: &crate::Provider) -> Result<OpenAIStreamingResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match self {
|
||||
ProviderInstance::OpenAI(_) => {
|
||||
OpenAIStreamingResponse::try_from_bytes(bytes, provider).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
}
|
||||
// TODO: Add other provider cases when implemented
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for ProviderInstance {
|
||||
type Request = ChatCompletionsRequest;
|
||||
type Response = ChatCompletionsResponse;
|
||||
type StreamingResponse = OpenAIStreamingResponse;
|
||||
type Usage = Usage;
|
||||
}
|
||||
|
|
@ -4,7 +4,15 @@
|
|||
//! handling of LLM requests and responses in the gateway.
|
||||
|
||||
use std::error::Error;
|
||||
use crate::Provider;
|
||||
|
||||
/// Conversion mode for provider requests/responses
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum ConversionMode {
|
||||
/// Compatible: Convert between different provider formats to ensure compatibility
|
||||
Compatible,
|
||||
/// Passthrough: Pass requests/responses through with minimal modification
|
||||
Passthrough,
|
||||
}
|
||||
|
||||
/// Trait for provider-specific request types
|
||||
pub trait ProviderRequest: Sized {
|
||||
|
|
@ -14,7 +22,7 @@ pub trait ProviderRequest: Sized {
|
|||
fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>;
|
||||
|
||||
/// Convert to provider-specific format
|
||||
fn to_provider_bytes(&self, provider: Provider) -> Result<Vec<u8>, Self::Error>;
|
||||
fn to_provider_bytes(&self, provider: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Extract the model name from the request
|
||||
fn extract_model(&self) -> &str;
|
||||
|
|
@ -42,7 +50,7 @@ pub trait ProviderResponse: Sized {
|
|||
type Usage: TokenUsage;
|
||||
|
||||
/// Parse response from raw bytes
|
||||
fn try_from_bytes(bytes: &[u8], provider: &Provider) -> Result<Self, Self::Error>;
|
||||
fn try_from_bytes(bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<Self, Self::Error>;
|
||||
|
||||
/// Get usage information if available
|
||||
fn usage(&self) -> Option<&Self::Usage>;
|
||||
|
|
@ -62,7 +70,7 @@ pub trait StreamingResponse: Iterator<Item = Result<Self::Chunk, Self::Error>> +
|
|||
type Chunk: StreamChunk;
|
||||
|
||||
/// Parse streaming response from raw bytes
|
||||
fn try_from_bytes(bytes: &[u8], provider: &Provider) -> Result<Self, Self::Error>;
|
||||
fn try_from_bytes(bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<Self, Self::Error>;
|
||||
}
|
||||
|
||||
/// Main provider interface trait
|
||||
|
|
@ -71,4 +79,20 @@ pub trait ProviderInterface {
|
|||
type Response: ProviderResponse;
|
||||
type StreamingResponse: StreamingResponse;
|
||||
type Usage: TokenUsage;
|
||||
|
||||
/// Check if this provider has a compatible API with the client request
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool;
|
||||
|
||||
/// Get the interface implementation for this provider
|
||||
/// passthrough: if true, use provider-specific format; if false, use compatible format
|
||||
fn get_interface(&self, passthrough: bool) -> ConversionMode {
|
||||
if passthrough {
|
||||
ConversionMode::Passthrough
|
||||
} else {
|
||||
ConversionMode::Compatible
|
||||
}
|
||||
}
|
||||
|
||||
/// Get supported API endpoints for this provider
|
||||
fn supported_apis(&self) -> Vec<&'static str>;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,9 +10,7 @@ use common::ratelimit::Header;
|
|||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::{
|
||||
Provider, ProviderInstance, ProviderRequest, ProviderResponse, StreamChunk, TokenUsage,
|
||||
};
|
||||
use hermesllm::{ConversionMode, Provider, ProviderId, ProviderRequest};
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
|
|
@ -76,8 +74,8 @@ impl StreamContext {
|
|||
.expect("the provider should be set when asked for it")
|
||||
}
|
||||
|
||||
fn get_provider_instance(&self) -> ProviderInstance {
|
||||
self.llm_provider().create_provider_instance()
|
||||
fn get_provider(&self) -> Provider {
|
||||
self.llm_provider().create_provider()
|
||||
}
|
||||
|
||||
fn select_llm_provider(&mut self) {
|
||||
|
|
@ -295,9 +293,9 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let provider_instance = self.get_provider_instance();
|
||||
let provider = self.get_provider();
|
||||
|
||||
let mut deserialized_body = match provider_instance.parse_request(&body_bytes) {
|
||||
let mut deserialized_body = match provider.parse_request(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
|
|
@ -356,10 +354,11 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
|
||||
let llm_provider_str = self.llm_provider().provider_interface.to_string();
|
||||
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
|
||||
let hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
|
||||
|
||||
// convert chat completion request to llm provider specific request
|
||||
let deserialized_body_bytes = match deserialized_body.to_provider_bytes(hermes_llm_provider)
|
||||
let deserialized_body_bytes = match deserialized_body
|
||||
.to_provider_bytes(hermes_llm_provider_id, ConversionMode::Compatible)
|
||||
{
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
|
|
@ -529,42 +528,16 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
|
||||
let llm_provider_str = self.llm_provider().provider_interface.to_string();
|
||||
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
|
||||
let _provider_id = ProviderId::from(llm_provider_str.as_str());
|
||||
|
||||
if self.streaming_response {
|
||||
// Use the provider instance to parse streaming response
|
||||
let provider_instance = self.get_provider_instance();
|
||||
// TODO: Implement streaming response parsing with new provider structure
|
||||
warn!(
|
||||
"Streaming response parsing not yet fully implemented with new provider structure"
|
||||
);
|
||||
|
||||
let streaming_events =
|
||||
match provider_instance.parse_streaming_response(&body, &hermes_llm_provider) {
|
||||
Ok(events) => events,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
for event_result in streaming_events {
|
||||
match event_result {
|
||||
Ok(event) => {
|
||||
if let Some(usage) = event.usage() {
|
||||
self.response_tokens += usage.completion_tokens();
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("error in response event: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute TTFT if not already recorded
|
||||
// For now, just compute TTFT and continue
|
||||
if self.ttft_duration.is_none() {
|
||||
// if let Some(start_time) = self.start_time {
|
||||
let current_time = get_current_time().unwrap();
|
||||
self.ttft_time = Some(current_time_ns());
|
||||
match current_time.duration_since(self.start_time) {
|
||||
|
|
@ -584,9 +557,9 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
} else {
|
||||
debug!("non streaming response");
|
||||
let provider_instance = self.get_provider_instance();
|
||||
let response = match provider_instance.parse_response(&body, &hermes_llm_provider) {
|
||||
Ok(de) => de,
|
||||
let provider = self.get_provider();
|
||||
let _response = match provider.parse_response(&body, ConversionMode::Compatible) {
|
||||
Ok(response_box) => response_box,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
|
|
@ -606,9 +579,9 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
if let Some(usage) = response.usage() {
|
||||
self.response_tokens += usage.completion_tokens();
|
||||
}
|
||||
// TODO: Extract usage information from the response box
|
||||
// For now, we'll skip this until we have a better way to handle Any types
|
||||
warn!("Response token counting not yet implemented with new provider structure");
|
||||
}
|
||||
|
||||
debug!(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue