fixed bugs

This commit is contained in:
Salman Paracha 2025-08-10 23:26:57 -07:00
parent e503934df2
commit df32c7e278
4 changed files with 176 additions and 26 deletions

View file

@ -4,10 +4,11 @@ A Rust library for translating LLM (Large Language Model) API requests and respo
## Features
- Unified types for chat completions and model metadata across multiple LLM providers
- Builder-pattern API for constructing requests in an idiomatic Rust style
- Easy conversion between provider formats
- Unified traits for chat completions across multiple LLM providers
- Function-based API for runtime provider selection and conversion
- Direct trait implementations on concrete types (no wrapper types needed)
- Streaming and non-streaming response support
- Type-safe provider identification and conversion
## Supported Providers
@ -32,29 +33,147 @@ _Replace the path with the appropriate location if using as a workspace member o
## Usage
Construct a chat completion request using the builder pattern:
### Basic Request/Response Handling
```rust
use hermesllm::{create_provider, ProviderId};
use hermesllm::providers::openai::types::ChatCompletionsRequest;
use hermesllm::{ProviderId, try_request_from_bytes, try_response_from_bytes, ConversionMode};
let request = ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![Message::new("Hi".to_string())])
.build()
.expect("Failed to build OpenAIRequest");
// Parse a request from raw bytes with provider-specific handling
let provider_id = ProviderId::OpenAI;
let request_bytes = r#"{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello!"}]}"#;
let request = try_request_from_bytes(request_bytes.as_bytes(), &provider_id)?;
// Create a provider and convert request to bytes
let provider = create_provider(ProviderId::OpenAI);
let bytes = serde_json::to_vec(&request)?;
let parsed_request = provider.try_request_from_bytes(&bytes)?;
// Work with the request using trait methods
println!("Model: {}", request.model());
println!("Is streaming: {}", request.is_streaming());
if let Some(user_msg) = request.extract_user_message() {
println!("User message: {}", user_msg);
}
```
### Building Requests with Struct Literals
```rust
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent};
// Build a request with a struct literal; `..Default::default()` fills the fields not listed
let request = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![
Message {
role: Role::System,
content: MessageContent::Text("You are a helpful assistant".to_string()),
..Default::default()
},
Message {
role: Role::User,
content: MessageContent::Text("What is the capital of France?".to_string()),
..Default::default()
}
],
temperature: Some(0.7),
max_tokens: Some(150),
..Default::default()
};
// Convert to provider-specific format
let provider_bytes = request.to_provider_bytes(ConversionMode::Compatible)?;
```
### Handling Responses
```rust
// Parse responses from provider
let response_bytes = /* response JSON from LLM provider */;
let response = try_response_from_bytes(&response_bytes, &provider_id, ConversionMode::Compatible)?;
// Extract usage information
if let Some((prompt_tokens, completion_tokens, total_tokens)) = response.extract_usage_counts() {
println!("Token usage: {}/{}/{}", prompt_tokens, completion_tokens, total_tokens);
}
```
### Streaming Responses
```rust
// Handle streaming responses
let stream_data = /* SSE stream data */;
let streaming_iter = try_streaming_from_bytes(&stream_data, &provider_id, ConversionMode::Compatible)?;
for chunk_result in streaming_iter {
match chunk_result {
Ok(chunk) => {
if let Some(content) = chunk.content_delta() {
print!("{}", content);
}
if chunk.is_final() {
println!("\nStream completed");
break;
}
}
Err(e) => {
eprintln!("Streaming error: {}", e);
break;
}
}
}
```
### Provider Compatibility
```rust
use hermesllm::{ProviderId, has_compatible_api, supported_apis};
// Check if a provider supports a specific API
let provider_id = ProviderId::Groq;
if has_compatible_api(&provider_id, "/v1/chat/completions") {
println!("Groq supports chat completions API");
}
// Get all supported APIs for a provider
let apis = supported_apis(&provider_id);
println!("Groq supports: {:?}", apis);
// Runtime provider selection
let provider_name = "mistral"; // Could come from config or request header
let provider_id = ProviderId::from(provider_name);
```
## API Overview
- `Provider`: Enum listing all supported LLM providers.
- `ChatCompletionsRequest`: Builder-pattern struct for creating chat completion requests.
- `ChatCompletionsResponse`: Struct for parsing responses.
- Streaming support via `SseChatCompletionIter`.
- Error handling via `OpenAIError`.
### Core Functions
- `try_request_from_bytes()`: Parse requests from bytes with provider-specific handling
- `try_response_from_bytes()`: Parse responses from bytes with provider-specific handling
- `try_streaming_from_bytes()`: Create streaming response iterators
- `has_compatible_api()`: Check API compatibility for providers
- `supported_apis()`: Get supported API endpoints for providers
### Core Types
- `ProviderId`: Enum for identifying providers (OpenAI, Mistral, Groq, etc.)
- `ConversionMode`: Controls conversion behavior (Compatible, Passthrough)
### Traits
- `ProviderRequest`: Common interface for all request types
- `ProviderResponse`: Common interface for all response types
- `ProviderStreamResponse`: Interface for streaming response chunks
- `ProviderStreamResponseIter`: Iterator trait for streaming responses
- `TokenUsage`: Interface for token usage information
### Concrete Types
- `ChatCompletionsRequest`: OpenAI-compatible chat completion requests
- `ChatCompletionsResponse`: OpenAI-compatible chat completion responses
- `SseChatCompletionIter`: Streaming response iterator for SSE format
## Architecture
This library uses free functions that take a `ProviderId`, rather than requiring callers to construct a provider object first, to enable:
- **Dynamic Provider Selection**: Runtime provider selection based on request headers or configuration
- **No Wrapper Types**: Direct trait implementations on concrete types like `ChatCompletionsRequest`
- **Type Erasure**: Functions return `Box<dyn ProviderRequest>` for polymorphic usage
- **Parameterized Conversion**: `TryFrom<(&[u8], &ProviderId)>` pattern for provider-specific parsing
The function-based design solves trait object limitations while maintaining clean abstractions and runtime flexibility.
## Contributing

View file

@ -433,6 +433,10 @@ impl ProviderRequest for ChatCompletionsRequest {
&self.model
}
/// Overwrites the request's `model` field with `model`, taking ownership of the string.
fn set_model(&mut self, model: String) {
self.model = model;
}
/// Reports whether this request is a streaming request.
///
/// Reads `self.stream` and falls back to the default value (`false` for the
/// `bool` returned here) when the field is unset.
fn is_streaming(&self) -> bool {
self.stream.unwrap_or_default()
}

View file

@ -64,6 +64,9 @@ pub trait ProviderRequest: Send + Sync {
/// Extract the model name from the request
fn model(&self) -> &str;
/// Set the model name for the request
fn set_model(&mut self, model: String);
/// Check if this is a streaming request
fn is_streaming(&self) -> bool;

View file

@ -315,30 +315,54 @@ impl HttpContext for StreamContext {
}
};
// TODO: For now, we'll work with the concrete ChatCompletionsRequest type
// In the future, this could be made more generic using trait objects
let model_name = match self.llm_provider.as_ref() {
Some(llm_provider) => llm_provider.model.as_ref(),
None => None,
};
let _use_agent_orchestrator = match self.overrides.as_ref() {
let use_agent_orchestrator = match self.overrides.as_ref() {
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
None => false,
};
// Use the provider interface methods for cleaner interaction
let model_requested = deserialized_body.model().to_string(); // Convert to owned string
// Store the original model for logging
let model_requested = deserialized_body.model().to_string();
// Apply model name resolution logic using the trait method
let resolved_model = match model_name {
Some(model_name) => model_name.clone(),
None => {
if use_agent_orchestrator {
"agent_orchestrator".to_string()
} else {
self.send_server_error(
ServerError::BadRequest {
why: format!(
"No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}",
model_requested,
self.llm_provider().name,
self.llm_provider().model
),
},
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
}
};
// Set the resolved model using the trait method
deserialized_body.set_model(resolved_model);
// Extract user message for tracing
self.user_message = deserialized_body.extract_user_message();
info!(
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}, final model: {}",
self.llm_provider().name,
model_requested,
model_name.unwrap_or(&"None".to_string()),
deserialized_body.model(),
);
// Use provider interface for streaming detection and setup
@ -595,13 +619,13 @@ impl HttpContext for StreamContext {
}
Err(e) => {
warn!("Error processing streaming chunk: {}", e);
return Action::Continue;
}
}
}
}
Err(e) => {
warn!("Failed to parse streaming response: {}", e);
return Action::Continue;
}
}
} else {