diff --git a/crates/hermesllm/README.md b/crates/hermesllm/README.md index 541ca09f..dc0efa61 100644 --- a/crates/hermesllm/README.md +++ b/crates/hermesllm/README.md @@ -4,10 +4,11 @@ A Rust library for translating LLM (Large Language Model) API requests and respo ## Features -- Unified types for chat completions and model metadata across multiple LLM providers -- Builder-pattern API for constructing requests in an idiomatic Rust style -- Easy conversion between provider formats +- Unified traits for chat completions across multiple LLM providers +- Function-based API for runtime provider selection and conversion +- Direct trait implementations on concrete types (no wrapper types needed) - Streaming and non-streaming response support +- Type-safe provider identification and conversion ## Supported Providers @@ -32,29 +33,147 @@ _Replace the path with the appropriate location if using as a workspace member o ## Usage -Construct a chat completion request using the builder pattern: +### Basic Request/Response Handling ```rust -use hermesllm::{create_provider, ProviderId}; -use hermesllm::providers::openai::types::ChatCompletionsRequest; +use hermesllm::{ProviderId, try_request_from_bytes, try_response_from_bytes, ConversionMode}; -let request = ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![Message::new("Hi".to_string())]) - .build() - .expect("Failed to build OpenAIRequest"); +// Parse a request from raw bytes with provider-specific handling +let provider_id = ProviderId::OpenAI; +let request_bytes = r#"{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello!"}]}"#; +let request = try_request_from_bytes(request_bytes.as_bytes(), &provider_id)?; -// Create a provider and convert request to bytes -let provider = create_provider(ProviderId::OpenAI); -let bytes = serde_json::to_vec(&request)?; -let parsed_request = provider.try_request_from_bytes(&bytes)?; +// Work with the request using trait methods +println!("Model: {}", request.model()); +println!("Is streaming: {}", request.is_streaming()); +if let Some(user_msg) = request.extract_user_message() { + println!("User message: {}", user_msg); +} +``` + +### Building Requests with the Builder Pattern + +```rust +use hermesllm::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent}; + +// Build a request using the builder pattern +let request = ChatCompletionsRequest { + model: "gpt-4".to_string(), + messages: vec![ + Message { + role: Role::System, + content: MessageContent::Text("You are a helpful assistant".to_string()), + ..Default::default() + }, + Message { + role: Role::User, + content: MessageContent::Text("What is the capital of France?".to_string()), + ..Default::default() + } + ], + temperature: Some(0.7), + max_tokens: Some(150), + ..Default::default() +}; + +// Convert to provider-specific format +let provider_bytes = request.to_provider_bytes(ConversionMode::Compatible)?; +``` + +### Handling Responses + +```rust +// Parse responses from provider +let response_bytes = /* response JSON from LLM provider */; +let response = try_response_from_bytes(&response_bytes, &provider_id, ConversionMode::Compatible)?; + +// Extract usage information +if let Some((prompt_tokens, completion_tokens, total_tokens)) = response.extract_usage_counts() { + println!("Token usage: {}/{}/{}", prompt_tokens, completion_tokens, total_tokens); +} +``` + +### Streaming Responses + +```rust +// Handle streaming responses +let stream_data = /* SSE stream data */; +let streaming_iter = try_streaming_from_bytes(&stream_data, &provider_id, ConversionMode::Compatible)?; + +for chunk_result in streaming_iter { + match chunk_result { + Ok(chunk) => { + if let Some(content) = chunk.content_delta() { + print!("{}", content); + } + if chunk.is_final() { + println!("\nStream completed"); + break; + } + } + Err(e) => { + eprintln!("Streaming error: {}", e); + break; + } + } +} +``` + +### Provider Compatibility + +```rust +use hermesllm::{ProviderId, has_compatible_api, supported_apis}; + +// Check if a provider supports a specific API +let provider_id = ProviderId::Groq; +if has_compatible_api(&provider_id, "/v1/chat/completions") { + println!("Groq supports chat completions API"); +} + +// Get all supported APIs for a provider +let apis = supported_apis(&provider_id); +println!("Groq supports: {:?}", apis); + +// Runtime provider selection +let provider_name = "mistral"; // Could come from config or request header +let provider_id = ProviderId::from(provider_name); ``` ## API Overview -- `Provider`: Enum listing all supported LLM providers. -- `ChatCompletionsRequest`: Builder-pattern struct for creating chat completion requests. -- `ChatCompletionsResponse`: Struct for parsing responses. -- Streaming support via `SseChatCompletionIter`. -- Error handling via `OpenAIError`. +### Core Functions +- `try_request_from_bytes()`: Parse requests from bytes with provider-specific handling +- `try_response_from_bytes()`: Parse responses from bytes with provider-specific handling +- `try_streaming_from_bytes()`: Create streaming response iterators +- `has_compatible_api()`: Check API compatibility for providers +- `supported_apis()`: Get supported API endpoints for providers + +### Core Types +- `ProviderId`: Enum for identifying providers (OpenAI, Mistral, Groq, etc.) +- `ConversionMode`: Controls conversion behavior (Compatible, Passthrough) + +### Traits +- `ProviderRequest`: Common interface for all request types +- `ProviderResponse`: Common interface for all response types +- `ProviderStreamResponse`: Interface for streaming response chunks +- `ProviderStreamResponseIter`: Iterator trait for streaming responses +- `TokenUsage`: Interface for token usage information + +### Concrete Types +- `ChatCompletionsRequest`: OpenAI-compatible chat completion requests +- `ChatCompletionsResponse`: OpenAI-compatible chat completion responses +- `SseChatCompletionIter`: Streaming response iterator for SSE format + +## Architecture + +This library uses a function-based approach instead of traditional trait objects to enable: + +- **Dynamic Provider Selection**: Runtime provider selection based on request headers or configuration +- **No Wrapper Types**: Direct trait implementations on concrete types like `ChatCompletionsRequest` +- **Type Erasure**: Functions return `Box` for polymorphic usage +- **Parameterized Conversion**: `TryFrom<(&[u8], &ProviderId)>` pattern for provider-specific parsing + +The function-based design solves trait object limitations while maintaining clean abstractions and runtime flexibility. ## Contributing diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index efe49a0e..8a18a441 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -433,6 +433,10 @@ impl ProviderRequest for ChatCompletionsRequest { &self.model } + fn set_model(&mut self, model: String) { + self.model = model; + } + fn is_streaming(&self) -> bool { self.stream.unwrap_or_default() } diff --git a/crates/hermesllm/src/providers/traits.rs b/crates/hermesllm/src/providers/traits.rs index c7c25034..2689ac8b 100644 --- a/crates/hermesllm/src/providers/traits.rs +++ b/crates/hermesllm/src/providers/traits.rs @@ -64,6 +64,9 @@ pub trait ProviderRequest: Send + Sync { /// Extract the model name from the request fn model(&self) -> &str; + /// Set the model name for the request + fn set_model(&mut self, model: String); + /// Check if this is a streaming request fn is_streaming(&self) -> bool; diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index a248a5e8..5552cc54 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -315,30 +315,54 @@ impl HttpContext for StreamContext { } }; - // TODO: For now, we'll work with the concrete ChatCompletionsRequest type - // In the future, this could be made more generic using trait objects - let model_name = match self.llm_provider.as_ref() { Some(llm_provider) => llm_provider.model.as_ref(), None => None, }; - let _use_agent_orchestrator = match self.overrides.as_ref() { + let use_agent_orchestrator = match self.overrides.as_ref() { Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(), None => false, }; - // Use the provider interface methods for cleaner interaction - let model_requested = deserialized_body.model().to_string(); // Convert to owned string + // Store the original model for logging + let model_requested = deserialized_body.model().to_string(); + + // Apply model name resolution logic using the trait method + let resolved_model = match model_name { + Some(model_name) => model_name.clone(), + None => { + if use_agent_orchestrator { + "agent_orchestrator".to_string() + } else { + self.send_server_error( + ServerError::BadRequest { + why: format!( + "No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", + model_requested, + self.llm_provider().name, + self.llm_provider().model + ), + }, + Some(StatusCode::BAD_REQUEST), + ); + return Action::Continue; + } + } + }; + + // Set the resolved model using the trait method + deserialized_body.set_model(resolved_model); // Extract user message for tracing self.user_message = deserialized_body.extract_user_message(); info!( - "on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}", + "on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}, final model: {}", self.llm_provider().name, model_requested, model_name.unwrap_or(&"None".to_string()), + deserialized_body.model(), ); // Use provider interface for streaming detection and setup @@ -595,13 +619,13 @@ impl HttpContext for StreamContext { } Err(e) => { warn!("Error processing streaming chunk: {}", e); + return Action::Continue; } } } } Err(e) => { warn!("Failed to parse streaming response: {}", e); - return Action::Continue; } } } else {