mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fixed bugs
This commit is contained in:
parent
e503934df2
commit
df32c7e278
4 changed files with 176 additions and 26 deletions
|
|
@ -4,10 +4,11 @@ A Rust library for translating LLM (Large Language Model) API requests and respo
|
|||
|
||||
## Features
|
||||
|
||||
- Unified types for chat completions and model metadata across multiple LLM providers
|
||||
- Builder-pattern API for constructing requests in an idiomatic Rust style
|
||||
- Easy conversion between provider formats
|
||||
- Unified traits for chat completions across multiple LLM providers
|
||||
- Function-based API for runtime provider selection and conversion
|
||||
- Direct trait implementations on concrete types (no wrapper types needed)
|
||||
- Streaming and non-streaming response support
|
||||
- Type-safe provider identification and conversion
|
||||
|
||||
## Supported Providers
|
||||
|
||||
|
|
@ -32,29 +33,147 @@ _Replace the path with the appropriate location if using as a workspace member o
|
|||
|
||||
## Usage
|
||||
|
||||
Construct a chat completion request using the builder pattern:
|
||||
### Basic Request/Response Handling
|
||||
|
||||
```rust
|
||||
use hermesllm::{create_provider, ProviderId};
|
||||
use hermesllm::providers::openai::types::ChatCompletionsRequest;
|
||||
use hermesllm::{ProviderId, try_request_from_bytes, try_response_from_bytes, ConversionMode};
|
||||
|
||||
let request = ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![Message::new("Hi".to_string())])
|
||||
.build()
|
||||
.expect("Failed to build OpenAIRequest");
|
||||
// Parse a request from raw bytes with provider-specific handling
|
||||
let provider_id = ProviderId::OpenAI;
|
||||
let request_bytes = r#"{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello!"}]}"#;
|
||||
let request = try_request_from_bytes(request_bytes.as_bytes(), &provider_id)?;
|
||||
|
||||
// Create a provider and convert request to bytes
|
||||
let provider = create_provider(ProviderId::OpenAI);
|
||||
let bytes = serde_json::to_vec(&request)?;
|
||||
let parsed_request = provider.try_request_from_bytes(&bytes)?;
|
||||
// Work with the request using trait methods
|
||||
println!("Model: {}", request.model());
|
||||
println!("Is streaming: {}", request.is_streaming());
|
||||
if let Some(user_msg) = request.extract_user_message() {
|
||||
println!("User message: {}", user_msg);
|
||||
}
|
||||
```
|
||||
|
||||
### Building Requests with the Builder Pattern
|
||||
|
||||
```rust
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent};
|
||||
|
||||
// Build a request with struct-literal syntax; fields not listed fall back to `Default`
|
||||
let request = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text("You are a helpful assistant".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("What is the capital of France?".to_string()),
|
||||
..Default::default()
|
||||
}
|
||||
],
|
||||
temperature: Some(0.7),
|
||||
max_tokens: Some(150),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Convert to provider-specific format
|
||||
let provider_bytes = request.to_provider_bytes(ConversionMode::Compatible)?;
|
||||
```
|
||||
|
||||
### Handling Responses
|
||||
|
||||
```rust
|
||||
// Parse responses from provider
|
||||
let response_bytes = /* response JSON from LLM provider */;
|
||||
let response = try_response_from_bytes(&response_bytes, &provider_id, ConversionMode::Compatible)?;
|
||||
|
||||
// Extract usage information
|
||||
if let Some((prompt_tokens, completion_tokens, total_tokens)) = response.extract_usage_counts() {
|
||||
println!("Token usage: {}/{}/{}", prompt_tokens, completion_tokens, total_tokens);
|
||||
}
|
||||
```
|
||||
|
||||
### Streaming Responses
|
||||
|
||||
```rust
|
||||
// Handle streaming responses
|
||||
let stream_data = /* SSE stream data */;
|
||||
let streaming_iter = try_streaming_from_bytes(&stream_data, &provider_id, ConversionMode::Compatible)?;
|
||||
|
||||
for chunk_result in streaming_iter {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if let Some(content) = chunk.content_delta() {
|
||||
print!("{}", content);
|
||||
}
|
||||
if chunk.is_final() {
|
||||
println!("\nStream completed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Streaming error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Provider Compatibility
|
||||
|
||||
```rust
|
||||
use hermesllm::{ProviderId, has_compatible_api, supported_apis};
|
||||
|
||||
// Check if a provider supports a specific API
|
||||
let provider_id = ProviderId::Groq;
|
||||
if has_compatible_api(&provider_id, "/v1/chat/completions") {
|
||||
println!("Groq supports chat completions API");
|
||||
}
|
||||
|
||||
// Get all supported APIs for a provider
|
||||
let apis = supported_apis(&provider_id);
|
||||
println!("Groq supports: {:?}", apis);
|
||||
|
||||
// Runtime provider selection
|
||||
let provider_name = "mistral"; // Could come from config or request header
|
||||
let provider_id = ProviderId::from(provider_name);
|
||||
```
|
||||
|
||||
## API Overview
|
||||
|
||||
- `Provider`: Enum listing all supported LLM providers.
|
||||
- `ChatCompletionsRequest`: Builder-pattern struct for creating chat completion requests.
|
||||
- `ChatCompletionsResponse`: Struct for parsing responses.
|
||||
- Streaming support via `SseChatCompletionIter`.
|
||||
- Error handling via `OpenAIError`.
|
||||
### Core Functions
|
||||
- `try_request_from_bytes()`: Parse requests from bytes with provider-specific handling
|
||||
- `try_response_from_bytes()`: Parse responses from bytes with provider-specific handling
|
||||
- `try_streaming_from_bytes()`: Create streaming response iterators
|
||||
- `has_compatible_api()`: Check API compatibility for providers
|
||||
- `supported_apis()`: Get supported API endpoints for providers
|
||||
|
||||
### Core Types
|
||||
- `ProviderId`: Enum for identifying providers (OpenAI, Mistral, Groq, etc.)
|
||||
- `ConversionMode`: Controls conversion behavior (Compatible, Passthrough)
|
||||
|
||||
### Traits
|
||||
- `ProviderRequest`: Common interface for all request types
|
||||
- `ProviderResponse`: Common interface for all response types
|
||||
- `ProviderStreamResponse`: Interface for streaming response chunks
|
||||
- `ProviderStreamResponseIter`: Iterator trait for streaming responses
|
||||
- `TokenUsage`: Interface for token usage information
|
||||
|
||||
### Concrete Types
|
||||
- `ChatCompletionsRequest`: OpenAI-compatible chat completion requests
|
||||
- `ChatCompletionsResponse`: OpenAI-compatible chat completion responses
|
||||
- `SseChatCompletionIter`: Streaming response iterator for SSE format
|
||||
|
||||
## Architecture
|
||||
|
||||
This library exposes free functions that take a `ProviderId` as its entry points, rather than requiring callers to construct a provider object first, to enable:
|
||||
|
||||
- **Dynamic Provider Selection**: Runtime provider selection based on request headers or configuration
|
||||
- **No Wrapper Types**: Direct trait implementations on concrete types like `ChatCompletionsRequest`
|
||||
- **Type Erasure**: Functions return `Box<dyn ProviderRequest>` for polymorphic usage
|
||||
- **Parameterized Conversion**: `TryFrom<(&[u8], &ProviderId)>` pattern for provider-specific parsing
|
||||
|
||||
The function-based design solves trait object limitations while maintaining clean abstractions and runtime flexibility.
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
|
|||
|
|
@ -433,6 +433,10 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
&self.model
|
||||
}
|
||||
|
||||
/// Overwrite the model name stored on this request with `model`.
fn set_model(&mut self, model: String) {
|
||||
self.model = model;
|
||||
}
|
||||
|
||||
fn is_streaming(&self) -> bool {
|
||||
self.stream.unwrap_or_default()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -64,6 +64,9 @@ pub trait ProviderRequest: Send + Sync {
|
|||
/// Extract the model name from the request
|
||||
fn model(&self) -> &str;
|
||||
|
||||
/// Set the model name for the request
|
||||
fn set_model(&mut self, model: String);
|
||||
|
||||
/// Check if this is a streaming request
|
||||
fn is_streaming(&self) -> bool;
|
||||
|
||||
|
|
|
|||
|
|
@ -315,30 +315,54 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
// TODO: For now, we'll work with the concrete ChatCompletionsRequest type
|
||||
// In the future, this could be made more generic using trait objects
|
||||
|
||||
let model_name = match self.llm_provider.as_ref() {
|
||||
Some(llm_provider) => llm_provider.model.as_ref(),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let _use_agent_orchestrator = match self.overrides.as_ref() {
|
||||
let use_agent_orchestrator = match self.overrides.as_ref() {
|
||||
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
|
||||
None => false,
|
||||
};
|
||||
|
||||
// Use the provider interface methods for cleaner interaction
|
||||
let model_requested = deserialized_body.model().to_string(); // Convert to owned string
|
||||
// Store the original model for logging
|
||||
let model_requested = deserialized_body.model().to_string();
|
||||
|
||||
// Resolve the model: prefer the provider's configured model, fall back to "agent_orchestrator" when that override is enabled, otherwise reject the request
|
||||
let resolved_model = match model_name {
|
||||
Some(model_name) => model_name.clone(),
|
||||
None => {
|
||||
if use_agent_orchestrator {
|
||||
"agent_orchestrator".to_string()
|
||||
} else {
|
||||
self.send_server_error(
|
||||
ServerError::BadRequest {
|
||||
why: format!(
|
||||
"No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}",
|
||||
model_requested,
|
||||
self.llm_provider().name,
|
||||
self.llm_provider().model
|
||||
),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Set the resolved model using the trait method
|
||||
deserialized_body.set_model(resolved_model);
|
||||
|
||||
// Extract user message for tracing
|
||||
self.user_message = deserialized_body.extract_user_message();
|
||||
|
||||
info!(
|
||||
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
||||
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}, final model: {}",
|
||||
self.llm_provider().name,
|
||||
model_requested,
|
||||
model_name.unwrap_or(&"None".to_string()),
|
||||
deserialized_body.model(),
|
||||
);
|
||||
|
||||
// Use provider interface for streaming detection and setup
|
||||
|
|
@ -595,13 +619,13 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
Err(e) => {
|
||||
warn!("Error processing streaming chunk: {}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse streaming response: {}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue