more clean up

This commit is contained in:
Salman Paracha 2025-08-09 21:52:31 -07:00
parent 9c09a18fd0
commit c4148a3d52
12 changed files with 107 additions and 213 deletions

View file

@ -60,4 +60,36 @@ mod tests {
// Test that provider supports the expected API endpoints
assert!(provider.has_compatible_api("/v1/chat/completions"));
}
#[test]
fn test_provider_extract_user_message() {
use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent};
let provider = Provider::new(ProviderId::OpenAI);
// Test with text message
let request = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![
Message {
role: crate::apis::openai::Role::System,
content: MessageContent::Text("You are a helpful assistant".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
},
Message {
role: crate::apis::openai::Role::User,
content: MessageContent::Text("Hello, world!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
},
],
..Default::default()
};
let user_message = provider.extract_user_message(&request);
assert_eq!(user_message, Some("Hello, world!".to_string()));
}
}

View file

@ -40,6 +40,11 @@ impl ProviderRequest for ArchProvider {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_messages_text(&openai_provider, request)
}
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_user_message(&openai_provider, request)
}
}
impl ProviderResponse for ArchProvider {
@ -79,25 +84,4 @@ impl ProviderInterface for ArchProvider {
fn supported_apis(&self) -> Vec<&'static str> {
vec!["/v1/chat/completions"]
}
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}
}
}

View file

@ -43,6 +43,11 @@ impl ProviderRequest for ClaudeProvider {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_messages_text(&openai_provider, request)
}
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_user_message(&openai_provider, request)
}
}
impl ProviderResponse for ClaudeProvider {
@ -84,28 +89,4 @@ impl ProviderInterface for ClaudeProvider {
// TODO: Update when Claude API is fully implemented
vec!["/v1/messages"]
}
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement Claude-specific request parsing
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement Claude-specific response parsing
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement Claude-specific request serialization
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}
}
}

View file

@ -40,6 +40,11 @@ impl ProviderRequest for DeepseekProvider {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_messages_text(&openai_provider, request)
}
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_user_message(&openai_provider, request)
}
}
impl ProviderResponse for DeepseekProvider {
@ -79,25 +84,4 @@ impl ProviderInterface for DeepseekProvider {
fn supported_apis(&self) -> Vec<&'static str> {
vec!["/v1/chat/completions"]
}
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}
}
}

View file

@ -43,6 +43,11 @@ impl ProviderRequest for GeminiProvider {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_messages_text(&openai_provider, request)
}
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_user_message(&openai_provider, request)
}
}
impl ProviderResponse for GeminiProvider {
@ -84,28 +89,4 @@ impl ProviderInterface for GeminiProvider {
// TODO: Update when Gemini API is fully implemented
vec!["/v1/models"]
}
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement Gemini-specific request parsing
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement Gemini-specific response parsing
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement Gemini-specific request serialization
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}
}
}

View file

@ -43,6 +43,11 @@ impl ProviderRequest for GitHubProvider {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_messages_text(&openai_provider, request)
}
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_user_message(&openai_provider, request)
}
}
impl ProviderResponse for GitHubProvider {
@ -84,28 +89,4 @@ impl ProviderInterface for GitHubProvider {
// TODO: Update when GitHub API is fully implemented
vec!["/models"]
}
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement GitHub-specific request parsing
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement GitHub-specific response parsing
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement GitHub-specific request serialization
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}
}
}

View file

@ -43,6 +43,11 @@ impl ProviderRequest for GroqProvider {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_messages_text(&openai_provider, request)
}
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_user_message(&openai_provider, request)
}
}
impl ProviderResponse for GroqProvider {
@ -82,25 +87,4 @@ impl ProviderInterface for GroqProvider {
fn supported_apis(&self) -> Vec<&'static str> {
vec!["/openai/v1/chat/completions"]
}
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}
}
}

View file

@ -40,6 +40,11 @@ impl ProviderRequest for MistralProvider {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_messages_text(&openai_provider, request)
}
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_user_message(&openai_provider, request)
}
}
impl ProviderResponse for MistralProvider {
@ -79,25 +84,4 @@ impl ProviderInterface for MistralProvider {
fn supported_apis(&self) -> Vec<&'static str> {
vec!["/v1/chat/completions"]
}
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}
}
}

View file

@ -201,6 +201,19 @@ impl ProviderRequest for Provider {
Provider::GitHub(provider, _) => ProviderRequest::extract_messages_text(provider, request),
}
}
fn extract_user_message(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> Option<String> {
match self {
Provider::OpenAI(provider, _) => ProviderRequest::extract_user_message(provider, request),
Provider::Groq(provider, _) => ProviderRequest::extract_user_message(provider, request),
Provider::Mistral(provider, _) => ProviderRequest::extract_user_message(provider, request),
Provider::Deepseek(provider, _) => ProviderRequest::extract_user_message(provider, request),
Provider::Arch(provider, _) => ProviderRequest::extract_user_message(provider, request),
Provider::Gemini(provider, _) => ProviderRequest::extract_user_message(provider, request),
Provider::Claude(provider, _) => ProviderRequest::extract_user_message(provider, request),
Provider::GitHub(provider, _) => ProviderRequest::extract_user_message(provider, request),
}
}
}
impl ProviderResponse for Provider {

View file

@ -75,27 +75,6 @@ impl ProviderInterface for OpenAIProvider {
fn supported_apis(&self) -> Vec<&'static str> {
vec!["/v1/chat/completions"]
}
fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}
}
}
// Direct trait implementations on OpenAIProvider
@ -142,6 +121,29 @@ impl ProviderRequest for OpenAIProvider {
}
})
}
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
request.messages.last().and_then(|msg| {
match &msg.content {
MessageContent::Text(text) => Some(text.clone()),
MessageContent::Parts(parts) => {
// Extract text from content parts, ignoring images
let text_parts: Vec<String> = parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
ContentPart::ImageUrl { .. } => None,
})
.collect();
if text_parts.is_empty() {
None
} else {
Some(text_parts.join(" "))
}
}
}
})
}
}
impl ProviderResponse for OpenAIProvider {

View file

@ -35,6 +35,9 @@ pub trait ProviderRequest {
/// Extract text content from messages for token counting
fn extract_messages_text(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String;
/// Extract the user message for tracing/logging purposes
fn extract_user_message(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> Option<String>;
}
/// Trait for token usage information
@ -85,19 +88,4 @@ pub trait ProviderInterface: ProviderRequest + ProviderResponse + StreamingRespo
/// Get supported API endpoints for this provider
fn supported_apis(&self) -> Vec<&'static str>;
/// Parse a request from raw bytes - delegates to ProviderRequest
fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
ProviderRequest::try_from_bytes(self, bytes).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
/// Parse a response from raw bytes - delegates to ProviderResponse
fn parse_response(&self, bytes: &[u8], provider_id: super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
/// Convert a request to bytes - delegates to ProviderRequest
fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
ProviderRequest::to_provider_bytes(self, request, provider_id, mode).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
}

View file

@ -10,7 +10,6 @@ use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::apis::openai::{ContentPart, MessageContent};
use hermesllm::providers::traits::{
ProviderRequest, ProviderResponse, StreamChunk, StreamingResponse, TokenUsage,
};
@ -333,26 +332,7 @@ impl HttpContext for StreamContext {
let model_requested = provider.extract_model(&deserialized_body).to_string(); // Convert to owned string
// Extract user message for tracing
self.user_message = deserialized_body.messages.last().and_then(|msg| {
match &msg.content {
MessageContent::Text(text) => Some(text.clone()),
MessageContent::Parts(parts) => {
// Extract text from content parts, ignoring images
let text_parts: Vec<String> = parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
ContentPart::ImageUrl { .. } => None,
})
.collect();
if text_parts.is_empty() {
None
} else {
Some(text_parts.join(" "))
}
}
}
});
self.user_message = provider.extract_user_message(&deserialized_body);
info!(
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",