mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
updating the implementation of /v1/chat/completions to use the generi… (#548)
* updating the implementation of /v1/chat/completions to use the generic provider interfaces * saving changes, although we will need a small re-factor after this as well * more refactoring changes, getting close * more refactoring changes to avoid unecessary re-direction and duplication * more clean up * more refactoring * more refactoring to clean code and make stream_context.rs work * removing unecessary trait implemenations * some more clean-up * fixed bugs * fixing test cases, and making sure all references to the ChatCOmpletions* objects point to the new types * refactored changes to support enum dispatch * removed the dependency on try_streaming_from_bytes into a try_from trait implementation * updated readme based on new usage * updated code based on code review comments --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-2.local> Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-4.local>
This commit is contained in:
parent
1fdde8181a
commit
89ab51697a
22 changed files with 1044 additions and 972 deletions
|
|
@ -3,7 +3,7 @@ use std::sync::Arc;
|
|||
use bytes::Bytes;
|
||||
use common::configuration::ModelUsagePreference;
|
||||
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
||||
use hermesllm::providers::openai::types::ChatCompletionsRequest;
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
|
|
@ -93,7 +93,7 @@ pub async fn chat_completions(
|
|||
chat_completion_request.metadata.and_then(|metadata| {
|
||||
metadata
|
||||
.get("archgw_preference_config")
|
||||
.and_then(|value| value.as_str().map(String::from))
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
|
||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
|
||||
|
|
@ -105,9 +105,7 @@ pub async fn chat_completions(
|
|||
.messages
|
||||
.last()
|
||||
.map_or("None".to_string(), |msg| {
|
||||
msg.content.as_ref().map_or("None".to_string(), |content| {
|
||||
content.to_string().replace('\n', "\\n")
|
||||
})
|
||||
msg.content.to_string().replace('\n', "\\n")
|
||||
});
|
||||
|
||||
const MAX_MESSAGE_LENGTH: usize = 50;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{IntoModels, LlmProvider};
|
||||
use hermesllm::providers::openai::types::Models;
|
||||
use hermesllm::apis::openai::Models;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
use hyper::{Response, StatusCode};
|
||||
use serde_json;
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let peer_addr = stream.peer_addr()?;
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||
let llm_provider_endpoint = llm_provider_endpoint.clone();
|
||||
|
||||
let llm_providers = llm_providers.clone();
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use common::{
|
|||
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
|
||||
consts::ARCH_PROVIDER_HINT_HEADER,
|
||||
};
|
||||
use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message};
|
||||
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info, warn};
|
||||
|
|
@ -153,9 +153,7 @@ impl RouterService {
|
|||
return Ok(None);
|
||||
}
|
||||
|
||||
if let Some(ContentType::Text(content)) =
|
||||
&chat_completion_response.choices[0].message.content
|
||||
{
|
||||
if let Some(content) = &chat_completion_response.choices[0].message.content {
|
||||
let parsed_response = self
|
||||
.router_model
|
||||
.parse_response(content, &usage_preferences)?;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use common::configuration::ModelUsagePreference;
|
||||
use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
|
|
|||
|
|
@ -2,9 +2,8 @@ use std::collections::HashMap;
|
|||
|
||||
use common::{
|
||||
configuration::{ModelUsagePreference, RoutingPreference},
|
||||
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
|
||||
};
|
||||
use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, MessageContent, Message, Role};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
|
|
@ -80,7 +79,9 @@ impl RouterModel for RouterModelV1 {
|
|||
// when role == tool its tool call response
|
||||
let messages_vec = messages
|
||||
.iter()
|
||||
.filter(|m| m.role != SYSTEM_ROLE && m.role != TOOL_ROLE && m.content.is_some())
|
||||
.filter(|m| {
|
||||
m.role != Role::System && m.role != Role::Tool && !m.content.to_string().is_empty()
|
||||
})
|
||||
.collect::<Vec<&Message>>();
|
||||
|
||||
// Following code is to ensure that the conversation does not exceed max token length
|
||||
|
|
@ -88,13 +89,7 @@ impl RouterModel for RouterModelV1 {
|
|||
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||
let message_token_count = message
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap_or(&ContentType::Text("".to_string()))
|
||||
.to_string()
|
||||
.len()
|
||||
/ TOKEN_LENGTH_DIVISOR;
|
||||
let message_token_count = message.content.to_string().len() / TOKEN_LENGTH_DIVISOR;
|
||||
token_count += message_token_count;
|
||||
if token_count > self.max_token_length {
|
||||
debug!(
|
||||
|
|
@ -104,7 +99,7 @@ impl RouterModel for RouterModelV1 {
|
|||
, selected_messsage_count,
|
||||
messages_vec.len()
|
||||
);
|
||||
if message.role == USER_ROLE {
|
||||
if message.role == Role::User {
|
||||
// If message that exceeds max token length is from user, we need to keep it
|
||||
selected_messages_list_reversed.push(message);
|
||||
}
|
||||
|
|
@ -125,12 +120,12 @@ impl RouterModel for RouterModelV1 {
|
|||
|
||||
// ensure that first and last selected message is from user
|
||||
if let Some(first_message) = selected_messages_list_reversed.first() {
|
||||
if first_message.role != USER_ROLE {
|
||||
if first_message.role != Role::User {
|
||||
warn!("RouterModelV1: last message in the conversation is not from user, this may lead to incorrect routing");
|
||||
}
|
||||
}
|
||||
if let Some(last_message) = selected_messages_list_reversed.last() {
|
||||
if last_message.role != USER_ROLE {
|
||||
if last_message.role != Role::User {
|
||||
warn!("RouterModelV1: first message in the conversation is not from user, this may lead to incorrect routing");
|
||||
}
|
||||
}
|
||||
|
|
@ -143,9 +138,10 @@ impl RouterModel for RouterModelV1 {
|
|||
Message {
|
||||
role: message.role.clone(),
|
||||
// we can unwrap here because we have already filtered out messages without content
|
||||
content: Some(ContentType::Text(
|
||||
message.content.as_ref().unwrap().to_string(),
|
||||
)),
|
||||
content: MessageContent::Text(message.content.to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<Message>>();
|
||||
|
|
@ -160,8 +156,11 @@ impl RouterModel for RouterModelV1 {
|
|||
ChatCompletionsRequest {
|
||||
model: self.routing_model.clone(),
|
||||
messages: vec![Message {
|
||||
content: Some(ContentType::Text(router_message)),
|
||||
role: USER_ROLE.to_string(),
|
||||
content: MessageContent::Text(router_message),
|
||||
role: Role::User,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
temperature: Some(0.01),
|
||||
..Default::default()
|
||||
|
|
@ -347,9 +346,9 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
|
||||
assert_eq!(expected_prompt, prompt.to_string());
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -412,9 +411,9 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
}]);
|
||||
let req = router.generate_request(&conversation, &usage_preferences);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
|
||||
assert_eq!(expected_prompt, prompt.to_string());
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -472,9 +471,9 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
|
||||
assert_eq!(expected_prompt, prompt.to_string());
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -533,9 +532,9 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
|
||||
assert_eq!(expected_prompt, prompt.to_string());
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -601,9 +600,9 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
|
||||
assert_eq!(expected_prompt, prompt.to_string());
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -670,9 +669,9 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
|
||||
assert_eq!(expected_prompt, prompt.to_string());
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -716,14 +715,14 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "toolcall-abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": { "location": "Tokyo" }
|
||||
"arguments": "{ \"location\": \"Tokyo\" }"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
@ -763,11 +762,11 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
let req: ChatCompletionsRequest = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
|
||||
assert_eq!(expected_prompt, prompt.to_string());
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use hermesllm::providers::openai::types::{ModelDetail, ModelObject, Models};
|
||||
use hermesllm::apis::openai::{ModelDetail, ModelObject, Models};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
|
|
@ -177,6 +177,14 @@ impl Display for LlmProviderType {
|
|||
}
|
||||
}
|
||||
|
||||
impl LlmProviderType {
|
||||
/// Get the ProviderId for this LlmProviderType
|
||||
/// Used with the new function-based hermesllm API
|
||||
pub fn to_provider_id(&self) -> hermesllm::ProviderId {
|
||||
hermesllm::ProviderId::from(self.to_string().as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ModelUsagePreference {
|
||||
pub model: String,
|
||||
|
|
@ -252,6 +260,14 @@ impl Display for LlmProvider {
|
|||
}
|
||||
}
|
||||
|
||||
impl LlmProvider {
|
||||
/// Get the ProviderId for this LlmProvider
|
||||
/// Used with the new function-based hermesllm API
|
||||
pub fn to_provider_id(&self) -> hermesllm::ProviderId {
|
||||
self.provider_interface.to_provider_id()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Endpoint {
|
||||
pub endpoint: Option<String>,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use proxy_wasm::types::Status;
|
||||
|
||||
use crate::{api::open_ai::ChatCompletionChunkResponseError, ratelimit};
|
||||
use hermesllm::providers::openai::types::OpenAIError;
|
||||
use hermesllm::apis::openai::OpenAIError;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ClientError {
|
||||
|
|
|
|||
|
|
@ -1,63 +1,145 @@
|
|||
# hermesllm
|
||||
|
||||
A Rust library for translating LLM (Large Language Model) API requests and responses between Mistral, Groq, Gemini, Deepseek, OpenAI, and other provider-compliant formats.
|
||||
A Rust library for handling LLM (Large Language Model) API requests and responses with unified abstractions across multiple providers.
|
||||
|
||||
## Features
|
||||
|
||||
- Unified types for chat completions and model metadata across multiple LLM providers
|
||||
- Builder-pattern API for constructing requests in an idiomatic Rust style
|
||||
- Easy conversion between provider formats
|
||||
- Streaming and non-streaming response support
|
||||
- Unified request/response types with provider-specific parsing
|
||||
- Support for both streaming and non-streaming responses
|
||||
- Type-safe provider identification
|
||||
- OpenAI-compatible API structure with extensible provider support
|
||||
|
||||
## Supported Providers
|
||||
|
||||
- Mistral
|
||||
- Deepseek
|
||||
- Groq
|
||||
- Gemini
|
||||
- OpenAI
|
||||
- Mistral
|
||||
- Groq
|
||||
- Deepseek
|
||||
- Gemini
|
||||
- Claude
|
||||
- Github
|
||||
- GitHub
|
||||
|
||||
## Installation
|
||||
|
||||
Add the following to your `Cargo.toml`:
|
||||
Add to your `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
hermesllm = { git = "https://github.com/katanemo/archgw", subdir = "crates/hermesllm" }
|
||||
hermesllm = { path = "../hermesllm" } # or appropriate path in workspace
|
||||
```
|
||||
|
||||
_Replace the path with the appropriate location if using as a workspace member or published crate._
|
||||
|
||||
## Usage
|
||||
|
||||
Construct a chat completion request using the builder pattern:
|
||||
### Basic Request Parsing
|
||||
|
||||
```rust
|
||||
use hermesllm::Provider;
|
||||
use hermesllm::providers::openai::types::ChatCompletionsRequest;
|
||||
use hermesllm::providers::{ProviderRequestType, ProviderRequest, ProviderId};
|
||||
|
||||
let request = ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![Message::new("Hi".to_string())])
|
||||
.build()
|
||||
.expect("Failed to build OpenAIRequest");
|
||||
// Parse request from JSON bytes
|
||||
let request_bytes = r#"{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello!"}]}"#;
|
||||
|
||||
// Convert to bytes for a specific provider
|
||||
let bytes = request.to_bytes(Provider::OpenAI)?;
|
||||
// Parse with provider context
|
||||
let request = ProviderRequestType::try_from((request_bytes.as_bytes(), &ProviderId::OpenAI))?;
|
||||
|
||||
// Access request properties
|
||||
println!("Model: {}", request.model());
|
||||
println!("User message: {:?}", request.get_recent_user_message());
|
||||
println!("Is streaming: {}", request.is_streaming());
|
||||
```
|
||||
|
||||
## API Overview
|
||||
### Working with Responses
|
||||
|
||||
- `Provider`: Enum listing all supported LLM providers.
|
||||
- `ChatCompletionsRequest`: Builder-pattern struct for creating chat completion requests.
|
||||
- `ChatCompletionsResponse`: Struct for parsing responses.
|
||||
- Streaming support via `SseChatCompletionIter`.
|
||||
- Error handling via `OpenAIError`.
|
||||
```rust
|
||||
use hermesllm::providers::{ProviderResponseType, ProviderResponse};
|
||||
|
||||
## Contributing
|
||||
// Parse response from provider
|
||||
let response_bytes = /* JSON response from LLM */;
|
||||
let response = ProviderResponseType::try_from((response_bytes, ProviderId::OpenAI))?;
|
||||
|
||||
Contributions are welcome! Please open issues or pull requests for bug fixes, new features, or provider integrations.
|
||||
// Extract token usage
|
||||
if let Some((prompt, completion, total)) = response.extract_usage_counts() {
|
||||
println!("Tokens used: {}/{}/{}", prompt, completion, total);
|
||||
}
|
||||
```
|
||||
|
||||
### Handling Streaming Responses
|
||||
|
||||
```rust
|
||||
use hermesllm::providers::{ProviderStreamResponseIter, ProviderStreamResponse};
|
||||
|
||||
// Create streaming iterator from SSE data
|
||||
let sse_data = /* Server-Sent Events data */;
|
||||
let mut stream = ProviderStreamResponseIter::try_from((sse_data, &ProviderId::OpenAI))?;
|
||||
|
||||
// Process streaming chunks
|
||||
for chunk_result in stream {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if let Some(content) = chunk.content_delta() {
|
||||
print!("{}", content);
|
||||
}
|
||||
if chunk.is_final() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => eprintln!("Stream error: {}", e),
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Provider Compatibility
|
||||
|
||||
```rust
|
||||
use hermesllm::providers::{ProviderId, has_compatible_api, supported_apis};
|
||||
|
||||
// Check API compatibility
|
||||
let provider = ProviderId::Groq;
|
||||
if has_compatible_api(&provider, "/v1/chat/completions") {
|
||||
println!("Provider supports chat completions");
|
||||
}
|
||||
|
||||
// List supported APIs
|
||||
let apis = supported_apis(&provider);
|
||||
println!("Supported APIs: {:?}", apis);
|
||||
```
|
||||
|
||||
## Core Types
|
||||
|
||||
### Provider Types
|
||||
- `ProviderId` - Enum identifying supported providers (OpenAI, Mistral, Groq, etc.)
|
||||
- `ProviderRequestType` - Enum wrapping provider-specific request types
|
||||
- `ProviderResponseType` - Enum wrapping provider-specific response types
|
||||
- `ProviderStreamResponseIter` - Iterator for streaming response chunks
|
||||
|
||||
### Traits
|
||||
- `ProviderRequest` - Common interface for all request types
|
||||
- `ProviderResponse` - Common interface for all response types
|
||||
- `ProviderStreamResponse` - Interface for streaming response chunks
|
||||
- `TokenUsage` - Interface for token usage information
|
||||
|
||||
### OpenAI API Types
|
||||
- `ChatCompletionsRequest` - Chat completion request structure
|
||||
- `ChatCompletionsResponse` - Chat completion response structure
|
||||
- `Message`, `Role`, `MessageContent` - Message building blocks
|
||||
|
||||
## Architecture
|
||||
|
||||
The library uses a type-safe enum-based approach that:
|
||||
|
||||
- **Provides Type Safety**: All provider operations are checked at compile time
|
||||
- **Enables Runtime Provider Selection**: Provider can be determined from request headers or config
|
||||
- **Maintains Clean Abstractions**: Common traits hide provider-specific details
|
||||
- **Supports Extensibility**: New providers can be added by extending the enums
|
||||
|
||||
All requests are parsed into a common `ProviderRequestType` enum which implements the `ProviderRequest` trait, allowing uniform access to request properties regardless of the underlying provider format.
|
||||
|
||||
## Examples
|
||||
|
||||
See the `src/lib.rs` tests for complete working examples of:
|
||||
- Parsing requests with provider context
|
||||
- Handling streaming responses
|
||||
- Working with token usage information
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the terms of the [MIT License](../LICENSE).
|
||||
This project is licensed under the MIT License.
|
||||
|
|
|
|||
|
|
@ -2,7 +2,13 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::Value;
|
||||
use serde_with::skip_serializing_none;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
use thiserror::Error;
|
||||
|
||||
|
||||
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage, SseStreamIter};
|
||||
use super::ApiDefinition;
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -115,8 +121,8 @@ pub enum Role {
|
|||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Message {
|
||||
pub content: MessageContent,
|
||||
pub role: Role,
|
||||
pub content: MessageContent,
|
||||
pub name: Option<String>,
|
||||
/// Tool calls made by the assistant (only present for assistant role)
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
|
@ -124,8 +130,6 @@ pub struct Message {
|
|||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct ResponseMessage {
|
||||
|
|
@ -170,6 +174,28 @@ pub enum MessageContent {
|
|||
Parts(Vec<ContentPart>),
|
||||
}
|
||||
|
||||
impl Display for MessageContent {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MessageContent::Text(text) => write!(f, "{}", text),
|
||||
MessageContent::Parts(parts) => {
|
||||
let text_parts: Vec<String> = parts
|
||||
.iter()
|
||||
.filter_map(|part| match part {
|
||||
ContentPart::Text { text } => Some(text.clone()),
|
||||
ContentPart::ImageUrl { .. } => {
|
||||
// skip image URLs or their data in text representation
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let combined_text = text_parts.join("\n");
|
||||
write!(f, "{}", combined_text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Individual content part within a message (text or image)
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(tag = "type")]
|
||||
|
|
@ -424,6 +450,239 @@ pub struct StreamOptions {
|
|||
pub include_usage: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelDetail {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: usize,
|
||||
pub owned_by: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ModelObject {
|
||||
#[serde(rename = "list")]
|
||||
List,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Models {
|
||||
pub object: ModelObject,
|
||||
pub data: Vec<ModelDetail>,
|
||||
}
|
||||
|
||||
|
||||
// Error type for streaming operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OpenAIStreamError {
|
||||
#[error("JSON parsing error: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
#[error("UTF-8 parsing error: {0}")]
|
||||
Utf8Error(#[from] std::str::Utf8Error),
|
||||
#[error("Invalid streaming data: {0}")]
|
||||
InvalidStreamingData(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OpenAIError {
|
||||
#[error("json error: {0}")]
|
||||
JsonParseError(#[from] serde_json::Error),
|
||||
#[error("utf8 parsing error: {0}")]
|
||||
Utf8Error(#[from] std::str::Utf8Error),
|
||||
#[error("invalid streaming data err {source}, data: {data}")]
|
||||
InvalidStreamingData {
|
||||
source: serde_json::Error,
|
||||
data: String,
|
||||
},
|
||||
#[error("unsupported provider: {provider}")]
|
||||
UnsupportedProvider { provider: String },
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
/// Trait Implementations
|
||||
/// ===========================================================================
|
||||
|
||||
|
||||
/// Parameterized conversion for ChatCompletionsRequest
|
||||
impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
||||
type Error = OpenAIStreamError;
|
||||
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parameterized conversion for ChatCompletionsResponse
|
||||
impl TryFrom<&[u8]> for ChatCompletionsResponse {
|
||||
type Error = OpenAIStreamError;
|
||||
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of TokenUsage for OpenAI Usage type
|
||||
impl TokenUsage for Usage {
|
||||
fn completion_tokens(&self) -> usize {
|
||||
self.completion_tokens as usize
|
||||
}
|
||||
|
||||
fn prompt_tokens(&self) -> usize {
|
||||
self.prompt_tokens as usize
|
||||
}
|
||||
|
||||
fn total_tokens(&self) -> usize {
|
||||
self.total_tokens as usize
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of ProviderRequest for ChatCompletionsRequest
|
||||
impl ProviderRequest for ChatCompletionsRequest {
|
||||
fn model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
fn set_model(&mut self, model: String) {
|
||||
self.model = model;
|
||||
}
|
||||
|
||||
fn is_streaming(&self) -> bool {
|
||||
self.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
self.messages.iter().fold(String::new(), |acc, m| {
|
||||
acc + " " + &match &m.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts.iter().map(|part| match part {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
|
||||
}).collect::<Vec<_>>().join(" ")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn get_recent_user_message(&self) -> Option<String> {
|
||||
self.messages.last().and_then(|msg| {
|
||||
match &msg.content {
|
||||
MessageContent::Text(text) => Some(text.clone()),
|
||||
MessageContent::Parts(_) => None, // No user message in parts
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
|
||||
serde_json::to_vec(&self).map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to serialize OpenAI request: {}", e),
|
||||
source: Some(Box::new(e)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of ProviderResponse for ChatCompletionsResponse
|
||||
impl ProviderResponse for ChatCompletionsResponse {
|
||||
fn usage(&self) -> Option<&dyn TokenUsage> {
|
||||
Some(&self.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
Some((
|
||||
self.usage.prompt_tokens(),
|
||||
self.usage.completion_tokens(),
|
||||
self.usage.total_tokens(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// OPENAI SSE STREAMING ITERATOR
|
||||
// ============================================================================
|
||||
|
||||
/// OpenAI-specific SSE streaming iterator
|
||||
/// Handles OpenAI's specific SSE format and ChatCompletionsStreamResponse parsing
|
||||
pub struct OpenAISseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
sse_stream: SseStreamIter<I>,
|
||||
}
|
||||
|
||||
impl<I> OpenAISseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub fn new(sse_stream: SseStreamIter<I>) -> Self {
|
||||
Self { sse_stream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Iterator for OpenAISseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
for line in &mut self.sse_stream.lines {
|
||||
let line = line.as_ref();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if line.starts_with("data: ") {
|
||||
let data = &line[6..]; // Remove "data: " prefix
|
||||
if data == "[DONE]" {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Skip ping messages (usually from other providers, but handle gracefully)
|
||||
if data == r#"{"type": "ping"}"# {
|
||||
continue;
|
||||
}
|
||||
|
||||
// OpenAI-specific parsing of ChatCompletionsStreamResponse
|
||||
match serde_json::from_str::<ChatCompletionsStreamResponse>(data) {
|
||||
Ok(response) => return Some(Ok(Box::new(response))),
|
||||
Err(e) => return Some(Err(Box::new(
|
||||
OpenAIStreamError::InvalidStreamingData(format!("Error parsing OpenAI streaming data: {}, data: {}", e, data))
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse
|
||||
impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
||||
fn content_delta(&self) -> Option<&str> {
|
||||
self.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.content.as_deref())
|
||||
}
|
||||
|
||||
fn is_final(&self) -> bool {
|
||||
self.choices
|
||||
.first()
|
||||
.map(|choice| choice.finish_reason.is_some())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn role(&self) -> Option<&str> {
|
||||
self.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.role.as_ref().map(|r| match r {
|
||||
Role::System => "system",
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::Tool => "tool",
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
|
|
@ -13,14 +13,14 @@
|
|||
//!
|
||||
//! ```rust
|
||||
//! use hermesllm::apis::{
|
||||
//! AnthropicMessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage,
|
||||
//! MessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage,
|
||||
//! MessagesMessageContent, MessagesSystemPrompt,
|
||||
//! };
|
||||
//! use hermesllm::clients::TransformError;
|
||||
//! use std::convert::TryInto;
|
||||
//!
|
||||
//! // Transform Anthropic to OpenAI
|
||||
//! let anthropic_req = AnthropicMessagesRequest {
|
||||
//! let anthropic_req = MessagesRequest {
|
||||
//! model: "claude-3-sonnet".to_string(),
|
||||
//! system: None,
|
||||
//! messages: vec![],
|
||||
|
|
|
|||
|
|
@ -5,77 +5,90 @@ pub mod providers;
|
|||
pub mod apis;
|
||||
pub mod clients;
|
||||
|
||||
|
||||
use std::fmt::Display;
|
||||
pub enum Provider {
|
||||
Arch,
|
||||
Mistral,
|
||||
Deepseek,
|
||||
Groq,
|
||||
Gemini,
|
||||
OpenAI,
|
||||
Claude,
|
||||
Github,
|
||||
}
|
||||
|
||||
impl From<&str> for Provider {
|
||||
fn from(value: &str) -> Self {
|
||||
match value.to_lowercase().as_str() {
|
||||
"arch" => Provider::Arch,
|
||||
"mistral" => Provider::Mistral,
|
||||
"deepseek" => Provider::Deepseek,
|
||||
"groq" => Provider::Groq,
|
||||
"gemini" => Provider::Gemini,
|
||||
"openai" => Provider::OpenAI,
|
||||
"claude" => Provider::Claude,
|
||||
"github" => Provider::Github,
|
||||
_ => panic!("Unknown provider: {}", value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Provider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Provider::Arch => write!(f, "Arch"),
|
||||
Provider::Mistral => write!(f, "Mistral"),
|
||||
Provider::Deepseek => write!(f, "Deepseek"),
|
||||
Provider::Groq => write!(f, "Groq"),
|
||||
Provider::Gemini => write!(f, "Gemini"),
|
||||
Provider::OpenAI => write!(f, "OpenAI"),
|
||||
Provider::Claude => write!(f, "Claude"),
|
||||
Provider::Github => write!(f, "Github"),
|
||||
}
|
||||
}
|
||||
}
|
||||
// Re-export important types and traits
|
||||
pub use providers::request::{ProviderRequestType, ProviderRequest, ProviderRequestError};
|
||||
pub use providers::response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, ProviderResponseError, TokenUsage};
|
||||
pub use providers::id::ProviderId;
|
||||
pub use providers::adapters::{has_compatible_api, supported_apis};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::providers::openai::types::{ChatCompletionsRequest, Message};
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn openai_builder() {
|
||||
let request =
|
||||
ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![Message::new("Hi".to_string())])
|
||||
.temperature(0.7)
|
||||
.top_p(0.9)
|
||||
.n(1)
|
||||
.max_tokens(100)
|
||||
.stream(false)
|
||||
.stop(vec!["\n".to_string()])
|
||||
.presence_penalty(0.0)
|
||||
.frequency_penalty(0.0)
|
||||
.build()
|
||||
.expect("Failed to build OpenAIRequest");
|
||||
fn test_provider_id_conversion() {
|
||||
assert_eq!(ProviderId::from("openai"), ProviderId::OpenAI);
|
||||
assert_eq!(ProviderId::from("mistral"), ProviderId::Mistral);
|
||||
assert_eq!(ProviderId::from("groq"), ProviderId::Groq);
|
||||
assert_eq!(ProviderId::from("arch"), ProviderId::Arch);
|
||||
}
|
||||
|
||||
assert_eq!(request.model, "gpt-3.5-turbo");
|
||||
assert_eq!(request.temperature, Some(0.7));
|
||||
assert_eq!(request.top_p, Some(0.9));
|
||||
assert_eq!(request.n, Some(1));
|
||||
assert_eq!(request.max_tokens, Some(100));
|
||||
assert_eq!(request.stream, Some(false));
|
||||
assert_eq!(request.stop, Some(vec!["\n".to_string()]));
|
||||
assert_eq!(request.presence_penalty, Some(0.0));
|
||||
assert_eq!(request.frequency_penalty, Some(0.0));
|
||||
#[test]
|
||||
fn test_provider_api_compatibility() {
|
||||
assert!(has_compatible_api(&ProviderId::OpenAI, "/v1/chat/completions"));
|
||||
assert!(!has_compatible_api(&ProviderId::OpenAI, "/v1/embeddings"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_supported_apis() {
|
||||
let apis = supported_apis(&ProviderId::OpenAI);
|
||||
assert!(apis.contains(&"/v1/chat/completions"));
|
||||
|
||||
// Test that provider supports the expected API endpoints
|
||||
assert!(has_compatible_api(&ProviderId::OpenAI, "/v1/chat/completions"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_request_parsing() {
|
||||
// Test with a sample JSON request
|
||||
let json_request = r#"{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let result: Result<ProviderRequestType, std::io::Error> = ProviderRequestType::try_from(json_request.as_bytes());
|
||||
assert!(result.is_ok());
|
||||
|
||||
let request = result.unwrap();
|
||||
assert_eq!(request.model(), "gpt-4");
|
||||
assert_eq!(request.get_recent_user_message(), Some("Hello!".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_streaming_response() {
|
||||
// Test streaming response parsing with sample SSE data
|
||||
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let result = ProviderStreamResponseIter::try_from((sse_data.as_bytes(), &ProviderId::OpenAI));
|
||||
assert!(result.is_ok());
|
||||
|
||||
let mut streaming_response = result.unwrap();
|
||||
|
||||
// Test that we can iterate over chunks - it's just an iterator now!
|
||||
let first_chunk = streaming_response.next();
|
||||
assert!(first_chunk.is_some());
|
||||
|
||||
let chunk_result = first_chunk.unwrap();
|
||||
assert!(chunk_result.is_ok());
|
||||
|
||||
let chunk = chunk_result.unwrap();
|
||||
assert_eq!(chunk.content_delta(), Some("Hello"));
|
||||
assert!(!chunk.is_final());
|
||||
|
||||
// Test that stream ends properly
|
||||
let final_chunk = streaming_response.next();
|
||||
assert!(final_chunk.is_none());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
39
crates/hermesllm/src/providers/adapters.rs
Normal file
39
crates/hermesllm/src/providers/adapters.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
use crate::providers::id::ProviderId;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AdapterType {
|
||||
OpenAICompatible,
|
||||
// Future: Claude, Gemini, etc.
|
||||
}
|
||||
|
||||
/// Provider adapter configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProviderConfig {
|
||||
pub supported_apis: &'static [&'static str],
|
||||
pub adapter_type: AdapterType,
|
||||
}
|
||||
|
||||
/// Check if provider has compatible API
|
||||
pub fn has_compatible_api(provider_id: &ProviderId, api_path: &str) -> bool {
|
||||
let config = get_provider_config(provider_id);
|
||||
config.supported_apis.iter().any(|&supported| supported == api_path)
|
||||
}
|
||||
|
||||
/// Get supported APIs for provider
|
||||
pub fn supported_apis(provider_id: &ProviderId) -> Vec<&'static str> {
|
||||
let config = get_provider_config(provider_id);
|
||||
config.supported_apis.to_vec()
|
||||
}
|
||||
|
||||
/// Get provider configuration
|
||||
pub fn get_provider_config(provider_id: &ProviderId) -> ProviderConfig {
|
||||
match provider_id {
|
||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
||||
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||
ProviderConfig {
|
||||
supported_apis: &["/v1/chat/completions"],
|
||||
adapter_type: AdapterType::OpenAICompatible,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
45
crates/hermesllm/src/providers/id.rs
Normal file
45
crates/hermesllm/src/providers/id.rs
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
use std::fmt::Display;
|
||||
|
||||
/// Provider identifier enum - simple enum for identifying providers
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum ProviderId {
|
||||
OpenAI,
|
||||
Mistral,
|
||||
Deepseek,
|
||||
Groq,
|
||||
Gemini,
|
||||
Claude,
|
||||
GitHub,
|
||||
Arch,
|
||||
}
|
||||
|
||||
impl From<&str> for ProviderId {
|
||||
fn from(value: &str) -> Self {
|
||||
match value.to_lowercase().as_str() {
|
||||
"openai" => ProviderId::OpenAI,
|
||||
"mistral" => ProviderId::Mistral,
|
||||
"deepseek" => ProviderId::Deepseek,
|
||||
"groq" => ProviderId::Groq,
|
||||
"gemini" => ProviderId::Gemini,
|
||||
"claude" => ProviderId::Claude,
|
||||
"github" => ProviderId::GitHub,
|
||||
"arch" => ProviderId::Arch,
|
||||
_ => panic!("Unknown provider: {}", value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ProviderId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ProviderId::OpenAI => write!(f, "OpenAI"),
|
||||
ProviderId::Mistral => write!(f, "Mistral"),
|
||||
ProviderId::Deepseek => write!(f, "Deepseek"),
|
||||
ProviderId::Groq => write!(f, "Groq"),
|
||||
ProviderId::Gemini => write!(f, "Gemini"),
|
||||
ProviderId::Claude => write!(f, "Claude"),
|
||||
ProviderId::GitHub => write!(f, "GitHub"),
|
||||
ProviderId::Arch => write!(f, "Arch"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1 +1,14 @@
|
|||
pub mod openai;
|
||||
//! Provider implementations for different LLM APIs
|
||||
//!
|
||||
//! This module contains provider-specific implementations that handle
|
||||
//! request/response conversion for different LLM service APIs.
|
||||
//!
|
||||
pub mod id;
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
pub mod adapters;
|
||||
|
||||
pub use id::ProviderId;
|
||||
pub use request::{ProviderRequestType, ProviderRequest, ProviderRequestError} ;
|
||||
pub use response::{ProviderResponseType, ProviderStreamResponseIter, ProviderResponse, ProviderStreamResponse, TokenUsage };
|
||||
pub use adapters::*;
|
||||
|
|
|
|||
|
|
@ -1,114 +0,0 @@
|
|||
use serde_json::Value;
|
||||
|
||||
use crate::providers::openai::types::{ChatCompletionsRequest, Message, StreamOptions};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OpenAIRequestBuilder {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
temperature: Option<f32>,
|
||||
top_p: Option<f32>,
|
||||
n: Option<u32>,
|
||||
max_tokens: Option<u32>,
|
||||
stream: Option<bool>,
|
||||
stop: Option<Vec<String>>,
|
||||
presence_penalty: Option<f32>,
|
||||
frequency_penalty: Option<f32>,
|
||||
stream_options: Option<StreamOptions>,
|
||||
tools: Option<Vec<Value>>,
|
||||
}
|
||||
|
||||
impl OpenAIRequestBuilder {
|
||||
pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
|
||||
Self {
|
||||
model: model.into(),
|
||||
messages,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
n: None,
|
||||
max_tokens: None,
|
||||
stream: None,
|
||||
stop: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
stream_options: None,
|
||||
tools: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn temperature(mut self, temperature: f32) -> Self {
|
||||
self.temperature = Some(temperature);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn top_p(mut self, top_p: f32) -> Self {
|
||||
self.top_p = Some(top_p);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn n(mut self, n: u32) -> Self {
|
||||
self.n = Some(n);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
|
||||
self.max_tokens = Some(max_tokens);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stream(mut self, stream: bool) -> Self {
|
||||
self.stream = Some(stream);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stop(mut self, stop: Vec<String>) -> Self {
|
||||
self.stop = Some(stop);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
|
||||
self.presence_penalty = Some(presence_penalty);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
|
||||
self.frequency_penalty = Some(frequency_penalty);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stream_options(mut self, include_usage: bool) -> Self {
|
||||
self.stream = Some(true);
|
||||
self.stream_options = Some(StreamOptions { include_usage });
|
||||
self
|
||||
}
|
||||
|
||||
pub fn tools(mut self, tools: Vec<Value>) -> Self {
|
||||
self.tools = Some(tools);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<ChatCompletionsRequest, &'static str> {
|
||||
let request = ChatCompletionsRequest {
|
||||
model: self.model,
|
||||
messages: self.messages,
|
||||
temperature: self.temperature,
|
||||
top_p: self.top_p,
|
||||
n: self.n,
|
||||
max_tokens: self.max_tokens,
|
||||
stream: self.stream,
|
||||
stop: self.stop,
|
||||
presence_penalty: self.presence_penalty,
|
||||
frequency_penalty: self.frequency_penalty,
|
||||
stream_options: self.stream_options,
|
||||
tools: self.tools,
|
||||
metadata: None,
|
||||
};
|
||||
Ok(request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ChatCompletionsRequest {
|
||||
pub fn builder(model: impl Into<String>, messages: Vec<Message>) -> OpenAIRequestBuilder {
|
||||
OpenAIRequestBuilder::new(model, messages)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
pub mod builder;
|
||||
pub mod types;
|
||||
|
|
@ -1,563 +0,0 @@
|
|||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_with::skip_serializing_none;
|
||||
use std::convert::TryFrom;
|
||||
use std::str;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::Provider;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OpenAIError {
|
||||
#[error("json error: {0}")]
|
||||
JsonParseError(#[from] serde_json::Error),
|
||||
#[error("utf8 parsing error: {0}")]
|
||||
Utf8Error(#[from] std::str::Utf8Error),
|
||||
#[error("invalid streaming data err {source}, data: {data}")]
|
||||
InvalidStreamingData {
|
||||
source: serde_json::Error,
|
||||
data: String,
|
||||
},
|
||||
#[error("unsupported provider: {provider}")]
|
||||
UnsupportedProvider { provider: String },
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, OpenAIError>;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum MultiPartContentType {
|
||||
#[serde(rename = "text")]
|
||||
Text,
|
||||
#[serde(rename = "image_url")]
|
||||
ImageUrl,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ImageUrl {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct MultiPartContent {
|
||||
pub text: Option<String>,
|
||||
pub image_url: Option<ImageUrl>,
|
||||
#[serde(rename = "type")]
|
||||
pub content_type: MultiPartContentType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum ContentType {
|
||||
Text(String),
|
||||
MultiPart(Vec<MultiPartContent>),
|
||||
}
|
||||
|
||||
impl Display for ContentType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ContentType::Text(text) => write!(f, "{}", text),
|
||||
ContentType::MultiPart(multi_part) => {
|
||||
let text_parts: Vec<String> = multi_part
|
||||
.iter()
|
||||
.filter_map(|part| {
|
||||
if part.content_type == MultiPartContentType::Text {
|
||||
part.text.clone()
|
||||
} else if part.content_type == MultiPartContentType::ImageUrl {
|
||||
// skip image URLs or their data in text representation
|
||||
None
|
||||
} else {
|
||||
panic!("Unsupported content type: {:?}", part.content_type);
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let combined_text = text_parts.join("\n");
|
||||
write!(f, "{}", combined_text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
pub content: Option<ContentType>,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(content: String) -> Self {
|
||||
Self {
|
||||
role: "user".to_string(),
|
||||
content: Some(ContentType::Text(content)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StreamOptions {
|
||||
pub include_usage: bool,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ChatCompletionsRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
pub temperature: Option<f32>,
|
||||
pub top_p: Option<f32>,
|
||||
pub n: Option<u32>,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub stream: Option<bool>,
|
||||
pub stop: Option<Vec<String>>,
|
||||
pub presence_penalty: Option<f32>,
|
||||
pub frequency_penalty: Option<f32>,
|
||||
pub stream_options: Option<StreamOptions>,
|
||||
pub tools: Option<Vec<Value>>,
|
||||
pub metadata: Option<HashMap<String, Value>>,
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
||||
type Error = OpenAIError;
|
||||
fn try_from(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIError::from)
|
||||
}
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatCompletionsResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub choices: Vec<Choice>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for ChatCompletionsResponse {
|
||||
type Error = OpenAIError;
|
||||
fn try_from(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIError::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<(&'a [u8], &'a Provider)> for ChatCompletionsResponse {
|
||||
type Error = OpenAIError;
|
||||
|
||||
fn try_from(input: (&'a [u8], &'a Provider)) -> Result<Self> {
|
||||
// Use input.provider as needed, if necessary
|
||||
serde_json::from_slice(input.0).map_err(OpenAIError::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl ChatCompletionsRequest {
|
||||
pub fn to_bytes(&self, provider: Provider) -> Result<Vec<u8>> {
|
||||
match provider {
|
||||
Provider::OpenAI
|
||||
| Provider::Arch
|
||||
| Provider::Deepseek
|
||||
| Provider::Mistral
|
||||
| Provider::Groq
|
||||
| Provider::Gemini
|
||||
| Provider::Claude => serde_json::to_vec(self).map_err(OpenAIError::from),
|
||||
_ => Err(OpenAIError::UnsupportedProvider {
|
||||
provider: provider.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Choice {
|
||||
pub index: u32,
|
||||
pub message: Message,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: usize,
|
||||
pub completion_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DeltaMessage {
|
||||
pub role: Option<String>,
|
||||
pub content: Option<ContentType>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct StreamChoice {
|
||||
pub index: u32,
|
||||
pub delta: DeltaMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatCompletionStreamResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<StreamChoice>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
pub struct SseChatCompletionIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
lines: I,
|
||||
}
|
||||
|
||||
impl<I> SseChatCompletionIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub fn new(lines: I) -> Self {
|
||||
Self { lines }
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Iterator for SseChatCompletionIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
type Item = Result<ChatCompletionStreamResponse>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
for line in &mut self.lines {
|
||||
let line = line.as_ref();
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
let data = data.trim();
|
||||
if data == "[DONE]" {
|
||||
return None;
|
||||
}
|
||||
|
||||
if data == r#"{"type": "ping"}"# {
|
||||
continue; // Skip ping messages - that is usually from anthropic
|
||||
}
|
||||
|
||||
return Some(
|
||||
serde_json::from_str::<ChatCompletionStreamResponse>(data).map_err(|e| {
|
||||
OpenAIError::InvalidStreamingData {
|
||||
source: e,
|
||||
data: data.to_string(),
|
||||
}
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<(&'a [u8], &'a Provider)> for SseChatCompletionIter<str::Lines<'a>> {
|
||||
type Error = OpenAIError;
|
||||
|
||||
fn try_from(input: (&'a [u8], &'a Provider)) -> Result<Self> {
|
||||
let s = std::str::from_utf8(input.0)?;
|
||||
// Use input.provider as needed
|
||||
Ok(SseChatCompletionIter::new(s.lines()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<&'a [u8]> for SseChatCompletionIter<str::Lines<'a>> {
|
||||
type Error = OpenAIError;
|
||||
|
||||
fn try_from(bytes: &'a [u8]) -> Result<Self> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(SseChatCompletionIter::new(s.lines()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelDetail {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: usize,
|
||||
pub owned_by: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ModelObject {
|
||||
#[serde(rename = "list")]
|
||||
List,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Models {
|
||||
pub object: ModelObject,
|
||||
pub data: Vec<ModelDetail>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_content_type_display() {
|
||||
let text_content = ContentType::Text("Hello, world!".to_string());
|
||||
assert_eq!(text_content.to_string(), "Hello, world!");
|
||||
|
||||
let multi_part_content = ContentType::MultiPart(vec![
|
||||
MultiPartContent {
|
||||
text: Some("This is a text part.".to_string()),
|
||||
content_type: MultiPartContentType::Text,
|
||||
image_url: None,
|
||||
},
|
||||
MultiPartContent {
|
||||
text: Some("https://example.com/image.png".to_string()),
|
||||
content_type: MultiPartContentType::ImageUrl,
|
||||
image_url: None,
|
||||
},
|
||||
]);
|
||||
assert_eq!(multi_part_content.to_string(), "This is a text part.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_completions_request_text_type_array() {
|
||||
const CHAT_COMPLETIONS_REQUEST: &str = r#"
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What city do you want to know the weather for?"
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "hello world"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest =
|
||||
serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
|
||||
assert_eq!(chat_completions_request.model, "gpt-3.5-turbo");
|
||||
if let Some(ContentType::MultiPart(multi_part_content)) =
|
||||
chat_completions_request.messages[0].content.as_ref()
|
||||
{
|
||||
assert_eq!(multi_part_content.len(), 2);
|
||||
assert_eq!(
|
||||
multi_part_content[0].content_type,
|
||||
MultiPartContentType::Text
|
||||
);
|
||||
assert_eq!(
|
||||
multi_part_content[0].text,
|
||||
Some("What city do you want to know the weather for?".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
multi_part_content[1].content_type,
|
||||
MultiPartContentType::Text
|
||||
);
|
||||
assert_eq!(multi_part_content[1].text, Some("hello world".to_string()));
|
||||
} else {
|
||||
panic!("Expected MultiPartContent");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_completions_request_image_content() {
|
||||
const CHAT_COMPLETIONS_REQUEST: &str = r#"
|
||||
{
|
||||
"stream": true,
|
||||
"model": "openai/gpt-4o",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "describe this photo pls"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/...=="
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest =
|
||||
serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
|
||||
assert_eq!(chat_completions_request.model, "openai/gpt-4o");
|
||||
if let Some(ContentType::MultiPart(multi_part_content)) =
|
||||
chat_completions_request.messages[0].content.as_ref()
|
||||
{
|
||||
assert_eq!(multi_part_content.len(), 2);
|
||||
assert_eq!(
|
||||
multi_part_content[0].content_type,
|
||||
MultiPartContentType::Text
|
||||
);
|
||||
assert_eq!(
|
||||
multi_part_content[0].text,
|
||||
Some("describe this photo pls".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
multi_part_content[1].content_type,
|
||||
MultiPartContentType::ImageUrl
|
||||
);
|
||||
assert_eq!(
|
||||
multi_part_content[1].image_url,
|
||||
Some(ImageUrl {
|
||||
url: "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/...==".to_string(),
|
||||
})
|
||||
);
|
||||
} else {
|
||||
panic!("Expected MultiPartContent");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_streaming() {
|
||||
let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}
|
||||
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]}
|
||||
data: [DONE]"#;
|
||||
|
||||
let iter = SseChatCompletionIter::new(json_data.lines());
|
||||
|
||||
println!("Testing SSE Streaming");
|
||||
for item in iter {
|
||||
match item {
|
||||
Ok(response) => {
|
||||
println!("Received response: {:?}", response);
|
||||
if response.choices.is_empty() {
|
||||
continue;
|
||||
}
|
||||
for choice in response.choices {
|
||||
if let Some(content) = choice.delta.content {
|
||||
println!("Content: {}", content);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Error parsing JSON: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_streaming_try_from_bytes() {
|
||||
let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}
|
||||
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]}
|
||||
data: [DONE]"#;
|
||||
|
||||
let iter = SseChatCompletionIter::try_from(json_data.as_bytes())
|
||||
.expect("Failed to create SSE iterator");
|
||||
|
||||
println!("Testing SSE Streaming");
|
||||
for item in iter {
|
||||
match item {
|
||||
Ok(response) => {
|
||||
println!("Received response: {:?}", response);
|
||||
if response.choices.is_empty() {
|
||||
continue;
|
||||
}
|
||||
for choice in response.choices {
|
||||
if let Some(content) = choice.delta.content {
|
||||
println!("Content: {}", content);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Error parsing JSON: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_chat_completions_request() {
|
||||
const CHAT_COMPLETIONS_REQUEST: &str = r#"
|
||||
{
|
||||
"model": "None",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "how is the weather in seattle"
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
} "#;
|
||||
|
||||
let _chat_completions_request: ChatCompletionsRequest =
|
||||
ChatCompletionsRequest::try_from(CHAT_COMPLETIONS_REQUEST.as_bytes())
|
||||
.expect("Failed to parse ChatCompletionsRequest");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse_claude() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"role":"assistant"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
||||
|
||||
data: {"type": "ping"}
|
||||
|
||||
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":"Hello!"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
||||
|
||||
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" How can I assist you today? Whether"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
||||
|
||||
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" you have a question, need information"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
||||
|
||||
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":", or just want to chat about"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
||||
|
||||
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" something, I'm here to help. What woul"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
||||
|
||||
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":"d you like to talk about?"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
||||
|
||||
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let iter = SseChatCompletionIter::try_from(CHUNK_RESPONSE.as_bytes());
|
||||
|
||||
assert!(iter.is_ok(), "Failed to create SSE iterator");
|
||||
let iter: SseChatCompletionIter<str::Lines<'_>> = iter.unwrap();
|
||||
|
||||
let all_text: Vec<String> = iter
|
||||
.map(|item| {
|
||||
let response = item.expect("Failed to parse response");
|
||||
response
|
||||
.choices
|
||||
.into_iter()
|
||||
.filter_map(|choice| choice.delta.content)
|
||||
.map(|content| content.to_string())
|
||||
.collect::<String>()
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert_eq!(
|
||||
all_text.len(),
|
||||
8,
|
||||
"Expected 8 chunks of text, but got {}",
|
||||
all_text.len()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
all_text.join(""),
|
||||
"Hello! How can I assist you today? Whether you have a question, need information, or just want to chat about something, I'm here to help. What would you like to talk about?"
|
||||
);
|
||||
}
|
||||
}
|
||||
115
crates/hermesllm/src/providers/request.rs
Normal file
115
crates/hermesllm/src/providers/request.rs
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
|
||||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use super::{ProviderId, get_provider_config, AdapterType};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
pub enum ProviderRequestType {
|
||||
ChatCompletionsRequest(ChatCompletionsRequest),
|
||||
//MessagesRequest(MessagesRequest),
|
||||
//add more request types here
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
// if passing bytes without provider id we assume the request is in OpenAI format
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<(&[u8], &ProviderId)> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
|
||||
let config = get_provider_config(provider_id);
|
||||
match config.adapter_type {
|
||||
AdapterType::OpenAICompatible => {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
// Future: handle other adapter types like Claude
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ProviderRequest: Send + Sync {
|
||||
/// Extract the model name from the request
|
||||
fn model(&self) -> &str;
|
||||
|
||||
/// Set the model name for the request
|
||||
fn set_model(&mut self, model: String);
|
||||
|
||||
/// Check if this is a streaming request
|
||||
fn is_streaming(&self) -> bool;
|
||||
|
||||
/// Extract text content from messages for token counting
|
||||
fn extract_messages_text(&self) -> String;
|
||||
|
||||
/// Extract the user message for tracing/logging purposes
|
||||
fn get_recent_user_message(&self) -> Option<String>;
|
||||
|
||||
/// Convert the request to bytes for transmission
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError>;
|
||||
}
|
||||
|
||||
impl ProviderRequest for ProviderRequestType {
|
||||
fn model(&self) -> &str {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.model(),
|
||||
}
|
||||
}
|
||||
|
||||
fn set_model(&mut self, model: String) {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.set_model(model),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_streaming(&self) -> bool {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.is_streaming(),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.extract_messages_text(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_recent_user_message(&self) -> Option<String> {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.get_recent_user_message(),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.to_bytes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Error types for provider operations
|
||||
#[derive(Debug)]
|
||||
pub struct ProviderRequestError {
|
||||
pub message: String,
|
||||
pub source: Option<Box<dyn Error + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl fmt::Display for ProviderRequestError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Provider request error: {}", self.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for ProviderRequestError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
}
|
||||
}
|
||||
167
crates/hermesllm/src/providers/response.rs
Normal file
167
crates/hermesllm/src/providers/response.rs
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
use crate::apis::OpenAISseIter;
|
||||
use crate::providers::id::ProviderId;
|
||||
use crate::providers::adapters::{get_provider_config, AdapterType};
|
||||
|
||||
pub enum ProviderResponseType {
|
||||
ChatCompletionsResponse(ChatCompletionsResponse),
|
||||
//MessagesResponse(MessagesResponse),
|
||||
}
|
||||
|
||||
pub enum ProviderStreamResponseIter {
|
||||
ChatCompletionsStream(OpenAISseIter<std::vec::IntoIter<String>>),
|
||||
//MessagesStream(AnthropicSseIter<std::vec::IntoIter<String>>),
|
||||
}
|
||||
|
||||
impl TryFrom<(&[u8], ProviderId)> for ProviderResponseType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, provider_id): (&[u8], ProviderId)) -> Result<Self, Self::Error> {
|
||||
let config = get_provider_config(&provider_id);
|
||||
match config.adapter_type {
|
||||
AdapterType::OpenAICompatible => {
|
||||
let chat_completions_response: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(chat_completions_response))
|
||||
}
|
||||
// Future: handle other adapter types like Claude
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<(&[u8], &ProviderId)> for ProviderStreamResponseIter {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
|
||||
let config = get_provider_config(provider_id);
|
||||
|
||||
// Parse SSE (Server-Sent Events) streaming data - protocol layer
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||
|
||||
match config.adapter_type {
|
||||
AdapterType::OpenAICompatible => {
|
||||
// Delegate to OpenAI-specific iterator implementation
|
||||
let sse_container = SseStreamIter::new(lines.into_iter());
|
||||
let iter = crate::apis::openai::OpenAISseIter::new(sse_container);
|
||||
Ok(ProviderStreamResponseIter::ChatCompletionsStream(iter))
|
||||
}
|
||||
// Future: AdapterType::Claude => {
|
||||
// let sse_container = SseStreamIter::new(lines.into_iter());
|
||||
// let iter = crate::apis::anthropic::AnthropicSseIter::new(sse_container);
|
||||
// Ok(ProviderStreamResponseIter::MessagesStream(iter))
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl Iterator for ProviderStreamResponseIter {
|
||||
type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
match self {
|
||||
ProviderStreamResponseIter::ChatCompletionsStream(iter) => iter.next(),
|
||||
// Future: ProviderStreamResponseIter::MessagesStream(iter) => iter.next(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub trait ProviderResponse: Send + Sync {
|
||||
/// Get usage information if available - returns dynamic trait object
|
||||
fn usage(&self) -> Option<&dyn TokenUsage>;
|
||||
|
||||
/// Extract token counts for metrics
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ProviderStreamResponse: Send + Sync {
|
||||
/// Get the content delta for this chunk
|
||||
fn content_delta(&self) -> Option<&str>;
|
||||
|
||||
/// Check if this is the final chunk in the stream
|
||||
fn is_final(&self) -> bool;
|
||||
|
||||
/// Get role information if available
|
||||
fn role(&self) -> Option<&str>;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// ============================================================================
|
||||
// GENERIC SSE STREAMING ITERATOR (Container Only)
|
||||
// ============================================================================
|
||||
|
||||
/// Generic SSE (Server-Sent Events) streaming iterator container
|
||||
/// This is just a simple wrapper - actual Iterator implementation is delegated to provider-specific modules
|
||||
pub struct SseStreamIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub lines: I,
|
||||
}
|
||||
|
||||
impl<I> SseStreamIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub fn new(lines: I) -> Self {
|
||||
Self { lines }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl ProviderResponse for ProviderResponseType {
|
||||
fn usage(&self) -> Option<&dyn TokenUsage> {
|
||||
match self {
|
||||
ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(),
|
||||
// Future: ProviderResponseType::MessagesResponse(resp) => resp.usage(),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
match self {
|
||||
ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(),
|
||||
// Future: ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Implement Send + Sync for the enum to match the original trait requirements
|
||||
unsafe impl Send for ProviderStreamResponseIter {}
|
||||
unsafe impl Sync for ProviderStreamResponseIter {}
|
||||
|
||||
/// Trait for token usage information
|
||||
pub trait TokenUsage {
|
||||
fn completion_tokens(&self) -> usize;
|
||||
fn prompt_tokens(&self) -> usize;
|
||||
fn total_tokens(&self) -> usize;
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ProviderResponseError {
|
||||
pub message: String,
|
||||
pub source: Option<Box<dyn Error + Send + Sync>>,
|
||||
}
|
||||
|
||||
|
||||
impl fmt::Display for ProviderResponseError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Provider response error: {}", self.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for ProviderResponseError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
}
|
||||
}
|
||||
|
|
@ -10,11 +10,10 @@ use common::ratelimit::Header;
|
|||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::providers::openai::types::{ChatCompletionsRequest, SseChatCompletionIter};
|
||||
use hermesllm::providers::openai::types::{
|
||||
ChatCompletionsResponse, ContentType, Message, StreamOptions,
|
||||
use hermesllm::providers::response::ProviderStreamResponseIter;
|
||||
use hermesllm::{
|
||||
ProviderId, ProviderRequest, ProviderRequestType, ProviderResponse, ProviderResponseType,
|
||||
};
|
||||
use hermesllm::Provider;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
|
|
@ -41,9 +40,9 @@ pub struct StreamContext {
|
|||
ttft_time: Option<u128>,
|
||||
traceparent: Option<String>,
|
||||
request_body_sent_time: Option<u128>,
|
||||
user_message: Option<Message>,
|
||||
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
user_message: Option<String>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -69,9 +68,9 @@ impl StreamContext {
|
|||
ttft_duration: None,
|
||||
traceparent: None,
|
||||
ttft_time: None,
|
||||
user_message: None,
|
||||
traces_queue,
|
||||
request_body_sent_time: None,
|
||||
user_message: None,
|
||||
}
|
||||
}
|
||||
fn llm_provider(&self) -> &LlmProvider {
|
||||
|
|
@ -80,6 +79,10 @@ impl StreamContext {
|
|||
.expect("the provider should be set when asked for it")
|
||||
}
|
||||
|
||||
fn get_provider_id(&self) -> ProviderId {
|
||||
self.llm_provider().to_provider_id()
|
||||
}
|
||||
|
||||
fn select_llm_provider(&mut self) {
|
||||
let provider_hint = self
|
||||
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
|
|
@ -295,24 +298,23 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let mut deserialized_body = match ChatCompletionsRequest::try_from(body_bytes.as_slice()) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"on_http_request_body: request body: {}",
|
||||
String::from_utf8_lossy(&body_bytes)
|
||||
);
|
||||
self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST));
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
let provider_id = self.get_provider_id();
|
||||
|
||||
self.user_message = deserialized_body
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|m| m.role == "user")
|
||||
.last()
|
||||
.cloned();
|
||||
let mut deserialized_body =
|
||||
match ProviderRequestType::try_from((&body_bytes[..], &provider_id)) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"on_http_request_body: request body: {}",
|
||||
String::from_utf8_lossy(&body_bytes)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Request parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
let model_name = match self.llm_provider.as_ref() {
|
||||
Some(llm_provider) => llm_provider.model.as_ref(),
|
||||
|
|
@ -324,24 +326,38 @@ impl HttpContext for StreamContext {
|
|||
None => false,
|
||||
};
|
||||
|
||||
let model_requested = deserialized_body.model.clone();
|
||||
deserialized_body.model = match model_name {
|
||||
// Store the original model for logging
|
||||
let model_requested = deserialized_body.model().to_string();
|
||||
|
||||
// Apply model name resolution logic using the trait method
|
||||
let resolved_model = match model_name {
|
||||
Some(model_name) => model_name.clone(),
|
||||
None => {
|
||||
if use_agent_orchestrator {
|
||||
"agent_orchestrator".to_string()
|
||||
} else {
|
||||
self.send_server_error(
|
||||
ServerError::BadRequest {
|
||||
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
ServerError::BadRequest {
|
||||
why: format!(
|
||||
"No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}",
|
||||
model_requested,
|
||||
self.llm_provider().name,
|
||||
self.llm_provider().model
|
||||
),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Set the resolved model using the trait method
|
||||
deserialized_body.set_model(resolved_model.clone());
|
||||
|
||||
// Extract user message for tracing
|
||||
self.user_message = deserialized_body.get_recent_user_message();
|
||||
|
||||
info!(
|
||||
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
||||
self.llm_provider().name,
|
||||
|
|
@ -349,32 +365,13 @@ impl HttpContext for StreamContext {
|
|||
model_name.unwrap_or(&"None".to_string()),
|
||||
);
|
||||
|
||||
if deserialized_body.stream.unwrap_or_default() {
|
||||
self.streaming_response = true;
|
||||
}
|
||||
if deserialized_body.stream.unwrap_or_default()
|
||||
&& deserialized_body.stream_options.is_none()
|
||||
{
|
||||
deserialized_body.stream_options = Some(StreamOptions {
|
||||
include_usage: true,
|
||||
});
|
||||
}
|
||||
// Use provider interface for streaming detection and setup
|
||||
self.streaming_response = deserialized_body.is_streaming();
|
||||
|
||||
// only use the tokens from the messages, excluding the metadata and json tags
|
||||
let input_tokens_str = deserialized_body
|
||||
.messages
|
||||
.iter()
|
||||
.fold(String::new(), |acc, m| {
|
||||
acc + " "
|
||||
+ m.content
|
||||
.as_ref()
|
||||
.unwrap_or(&ContentType::Text(String::new()))
|
||||
.to_string()
|
||||
.as_str()
|
||||
});
|
||||
// Use provider interface for text extraction (after potential mutation)
|
||||
let input_tokens_str = deserialized_body.extract_messages_text();
|
||||
// enforce ratelimits on ingress
|
||||
if let Err(e) = self.enforce_ratelimits(&deserialized_body.model, input_tokens_str.as_str())
|
||||
{
|
||||
if let Err(e) = self.enforce_ratelimits(&resolved_model, input_tokens_str.as_str()) {
|
||||
self.send_server_error(
|
||||
ServerError::ExceededRatelimit(e),
|
||||
Some(StatusCode::TOO_MANY_REQUESTS),
|
||||
|
|
@ -383,15 +380,15 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
let llm_provider_str = self.llm_provider().provider_interface.to_string();
|
||||
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
|
||||
|
||||
// convert chat completion request to llm provider specific request
|
||||
let deserialized_body_bytes = match deserialized_body.to_bytes(hermes_llm_provider) {
|
||||
// Convert chat completion request to llm provider specific request using provider interface
|
||||
let deserialized_body_bytes = match deserialized_body.to_bytes() {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize request body: {}", e);
|
||||
self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST));
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Request serialization error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
|
@ -484,17 +481,16 @@ impl HttpContext for StreamContext {
|
|||
self.request_body_sent_time.unwrap(),
|
||||
current_time_ns,
|
||||
);
|
||||
if let Some(user_message) = self.user_message.as_ref() {
|
||||
if let Some(prompt) = user_message.content.as_ref() {
|
||||
llm_span
|
||||
.add_attribute("user_prompt".to_string(), prompt.to_string());
|
||||
}
|
||||
}
|
||||
llm_span.add_attribute(
|
||||
"model".to_string(),
|
||||
self.llm_provider().name.to_string(),
|
||||
);
|
||||
|
||||
if let Some(user_message) = &self.user_message {
|
||||
llm_span
|
||||
.add_attribute("user_message".to_string(), user_message.clone());
|
||||
}
|
||||
|
||||
if self.ttft_time.is_some() {
|
||||
llm_span.add_event(Event::new(
|
||||
"time_to_first_token".to_string(),
|
||||
|
|
@ -558,62 +554,69 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
}
|
||||
|
||||
let llm_provider_str = self.llm_provider().provider_interface.to_string();
|
||||
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
|
||||
|
||||
if self.streaming_response {
|
||||
let chat_completions_chunk_response_events =
|
||||
match SseChatCompletionIter::try_from((body.as_slice(), &hermes_llm_provider)) {
|
||||
Ok(events) => events,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
debug!("processing streaming response");
|
||||
match ProviderStreamResponseIter::try_from((&body[..], &self.get_provider_id())) {
|
||||
Ok(mut streaming_response) => {
|
||||
// Process each streaming chunk
|
||||
while let Some(chunk_result) = streaming_response.next() {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
// Compute TTFT on first chunk
|
||||
if self.ttft_duration.is_none() {
|
||||
let current_time = get_current_time().unwrap();
|
||||
self.ttft_time = Some(current_time_ns());
|
||||
match current_time.duration_since(self.start_time) {
|
||||
Ok(duration) => {
|
||||
let duration_ms = duration.as_millis();
|
||||
info!(
|
||||
"on_http_response_body: time to first token: {}ms",
|
||||
duration_ms
|
||||
);
|
||||
self.ttft_duration = Some(duration);
|
||||
self.metrics
|
||||
.time_to_first_token
|
||||
.record(duration_ms as u64);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("SystemTime error: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for event in chat_completions_chunk_response_events {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
if let Some(usage) = event.usage.as_ref() {
|
||||
self.response_tokens += usage.completion_tokens;
|
||||
// For streaming responses, we handle token counting differently
|
||||
// The ProviderStreamResponse trait provides content_delta, is_final, and role
|
||||
// Token counting for streaming responses typically happens with final usage chunk
|
||||
if chunk.is_final() {
|
||||
// For now, we'll implement basic token estimation
|
||||
// In a complete implementation, the final chunk would contain usage information
|
||||
debug!("Received final streaming chunk");
|
||||
}
|
||||
|
||||
// For now, estimate tokens from content delta
|
||||
if let Some(content) = chunk.content_delta() {
|
||||
// Rough estimation: ~4 characters per token
|
||||
let estimated_tokens = content.len() / 4;
|
||||
self.response_tokens += estimated_tokens.max(1);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error processing streaming chunk: {}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("error in response event: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute TTFT if not already recorded
|
||||
if self.ttft_duration.is_none() {
|
||||
// if let Some(start_time) = self.start_time {
|
||||
let current_time = get_current_time().unwrap();
|
||||
self.ttft_time = Some(current_time_ns());
|
||||
match current_time.duration_since(self.start_time) {
|
||||
Ok(duration) => {
|
||||
let duration_ms = duration.as_millis();
|
||||
info!(
|
||||
"on_http_response_body: time to first token: {}ms",
|
||||
duration_ms
|
||||
);
|
||||
self.ttft_duration = Some(duration);
|
||||
self.metrics.time_to_first_token.record(duration_ms as u64);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("SystemTime error: {:?}", e);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse streaming response: {}", e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("non streaming response");
|
||||
let chat_completions_response =
|
||||
match ChatCompletionsResponse::try_from((body.as_slice(), &hermes_llm_provider)) {
|
||||
Ok(de) => de,
|
||||
let provider_id = self.get_provider_id();
|
||||
let response: ProviderResponseType =
|
||||
match ProviderResponseType::try_from((&body[..], provider_id)) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
|
|
@ -626,15 +629,24 @@ impl HttpContext for StreamContext {
|
|||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::OpenAIPError(e),
|
||||
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(usage) = chat_completions_response.usage {
|
||||
self.response_tokens += usage.completion_tokens;
|
||||
// Use provider interface to extract usage information
|
||||
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
|
||||
response.extract_usage_counts()
|
||||
{
|
||||
debug!(
|
||||
"Response usage: prompt={}, completion={}, total={}",
|
||||
prompt_tokens, completion_tokens, total_tokens
|
||||
);
|
||||
self.response_tokens = completion_tokens;
|
||||
} else {
|
||||
warn!("No usage information found in response");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ fn wasm_module() -> String {
|
|||
wasm_file.exists(),
|
||||
"Run `cargo build --release --target=wasm32-wasip1` first"
|
||||
);
|
||||
wasm_file.to_str().unwrap().to_string()
|
||||
wasm_file.to_string_lossy().to_string()
|
||||
}
|
||||
|
||||
fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
||||
|
|
@ -267,17 +267,12 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
|
|||
.returning(Some(incomplete_chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4"))
|
||||
.expect_send_local_response(
|
||||
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 13"))
|
||||
.expect_metric_record("input_sequence_length", 13)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=13"#))
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -386,11 +381,11 @@ fn llm_gateway_request_not_ratelimited() {
|
|||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29"))
|
||||
.expect_metric_record("input_sequence_length", 29)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#))
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
|
@ -433,11 +428,11 @@ fn llm_gateway_override_model_name() {
|
|||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29"))
|
||||
.expect_metric_record("input_sequence_length", 29)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#))
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
|
@ -483,8 +478,8 @@ fn llm_gateway_override_use_default_model() {
|
|||
Some(LogLevel::Info),
|
||||
Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4"),
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29"))
|
||||
.expect_metric_record("input_sequence_length", 29)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#))
|
||||
|
|
@ -530,11 +525,11 @@ fn llm_gateway_override_use_model_name_none() {
|
|||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): none, model selected: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29"))
|
||||
.expect_metric_record("input_sequence_length", 29)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#))
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue