updating the implementation of /v1/chat/completions to use the generi… (#548)

* updating the implementation of /v1/chat/completions to use the generic provider interfaces

* saving changes, although we will need a small refactor after this as well

* more refactoring changes, getting close

* more refactoring changes to avoid unecessary re-direction and duplication

* more clean up

* more refactoring

* more refactoring to clean up the code and make stream_context.rs work

* removing unnecessary trait implementations

* some more clean-up

* fixed bugs

* fixing test cases and making sure all references to the ChatCompletions* objects point to the new types

* refactored changes to support enum dispatch over the generic provider interface (see the sketch after this list)

* moved try_streaming_from_bytes into a TryFrom trait implementation (see the sketch after this list)

* updated README based on the new usage

* updated code based on code review comments
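
For context on the "generic provider interfaces" and "enum dispatch" items above, here is a minimal sketch of the pattern; the trait, enum, and method names are illustrative assumptions, not the actual hermesllm API.

```rust
// Minimal sketch of the generic-provider-interface + enum-dispatch pattern.
// All names here (ProviderInterface, Provider, the concrete providers) are
// illustrative assumptions, not the actual hermesllm types.
use std::error::Error;

trait ProviderInterface {
    fn chat_completions_path(&self) -> &'static str;
    fn transform_request_body(&self, body: &[u8]) -> Result<Vec<u8>, Box<dyn Error>>;
}

struct OpenAiProvider;
struct OtherProvider;

impl ProviderInterface for OpenAiProvider {
    fn chat_completions_path(&self) -> &'static str {
        "/v1/chat/completions"
    }
    fn transform_request_body(&self, body: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
        // OpenAI-compatible payloads pass through unchanged in this sketch.
        Ok(body.to_vec())
    }
}

impl ProviderInterface for OtherProvider {
    fn chat_completions_path(&self) -> &'static str {
        "/v1/chat/completions"
    }
    fn transform_request_body(&self, body: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
        // Provider-specific request rewriting would go here.
        Ok(body.to_vec())
    }
}

// Enum dispatch: one variant per provider, matched statically instead of
// routing every call through Box<dyn ProviderInterface>.
enum Provider {
    OpenAi(OpenAiProvider),
    Other(OtherProvider),
}

impl ProviderInterface for Provider {
    fn chat_completions_path(&self) -> &'static str {
        match self {
            Provider::OpenAi(p) => p.chat_completions_path(),
            Provider::Other(p) => p.chat_completions_path(),
        }
    }
    fn transform_request_body(&self, body: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
        match self {
            Provider::OpenAi(p) => p.transform_request_body(body),
            Provider::Other(p) => p.transform_request_body(body),
        }
    }
}
```

This keeps the /v1/chat/completions handler generic over whichever provider is configured while avoiding a layer of dynamic dispatch.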
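
The try_streaming_from_bytes item refers to replacing a bespoke parsing helper with a standard TryFrom conversion. A sketch of the shape of that change, with an assumed ServerSentEvents type and error type (not necessarily the real hermesllm names):

```rust
// Sketch only: ServerSentEvents and StreamingParseError are assumed names used
// to illustrate the TryFrom-based conversion, not the real hermesllm types.
use std::convert::TryFrom;

#[derive(Debug)]
struct StreamingParseError(String);

struct ServerSentEvents {
    data_lines: Vec<String>,
}

impl TryFrom<&[u8]> for ServerSentEvents {
    type Error = StreamingParseError;

    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        let text = std::str::from_utf8(bytes)
            .map_err(|e| StreamingParseError(e.to_string()))?;
        // Keep the "data: ..." payload of each SSE event in the chunk.
        let data_lines = text
            .lines()
            .filter_map(|line| line.strip_prefix("data: "))
            .map(|data| data.trim().to_string())
            .collect();
        Ok(ServerSentEvents { data_lines })
    }
}

// Call sites then use the idiomatic conversion instead of a named helper:
// let events = ServerSentEvents::try_from(chunk_bytes)?;
```

Using TryFrom keeps the fallible conversion discoverable through the standard trait and lets call sites chain it with the ? operator.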

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-2.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-4.local>
Salman Paracha committed on 2025-08-20 12:55:29 -07:00 (via GitHub)
parent 1fdde8181a
commit 89ab51697a
22 changed files with 1044 additions and 972 deletions


@@ -3,7 +3,7 @@ use std::sync::Arc;
use bytes::Bytes;
use common::configuration::ModelUsagePreference;
use common::consts::ARCH_PROVIDER_HINT_HEADER;
use hermesllm::providers::openai::types::ChatCompletionsRequest;
use hermesllm::apis::openai::ChatCompletionsRequest;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
@@ -93,7 +93,7 @@ pub async fn chat_completions(
chat_completion_request.metadata.and_then(|metadata| {
metadata
.get("archgw_preference_config")
.and_then(|value| value.as_str().map(String::from))
.map(|value| value.to_string())
});
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
@@ -105,9 +105,7 @@ pub async fn chat_completions(
.messages
.last()
.map_or("None".to_string(), |msg| {
msg.content.as_ref().map_or("None".to_string(), |content| {
content.to_string().replace('\n', "\\n")
})
msg.content.to_string().replace('\n', "\\n")
});
const MAX_MESSAGE_LENGTH: usize = 50;


@@ -1,6 +1,6 @@
use bytes::Bytes;
use common::configuration::{IntoModels, LlmProvider};
use hermesllm::providers::openai::types::Models;
use hermesllm::apis::openai::Models;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::{Response, StatusCode};
use serde_json;


@@ -98,7 +98,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let peer_addr = stream.peer_addr()?;
let io = TokioIo::new(stream);
let router_service = Arc::clone(&router_service);
let router_service: Arc<RouterService> = Arc::clone(&router_service);
let llm_provider_endpoint = llm_provider_endpoint.clone();
let llm_providers = llm_providers.clone();


@@ -4,7 +4,7 @@ use common::{
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
consts::ARCH_PROVIDER_HINT_HEADER,
};
use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message};
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
use hyper::header;
use thiserror::Error;
use tracing::{debug, info, warn};
@@ -153,9 +153,7 @@ impl RouterService {
return Ok(None);
}
if let Some(ContentType::Text(content)) =
&chat_completion_response.choices[0].message.content
{
if let Some(content) = &chat_completion_response.choices[0].message.content {
let parsed_response = self
.router_model
.parse_response(content, &usage_preferences)?;


@@ -1,5 +1,5 @@
use common::configuration::ModelUsagePreference;
use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
use thiserror::Error;
#[derive(Debug, Error)]


@@ -2,9 +2,8 @@ use std::collections::HashMap;
use common::{
configuration::{ModelUsagePreference, RoutingPreference},
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
};
use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message};
use hermesllm::apis::openai::{ChatCompletionsRequest, MessageContent, Message, Role};
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
@@ -80,7 +79,9 @@ impl RouterModel for RouterModelV1 {
// when role == tool its tool call response
let messages_vec = messages
.iter()
.filter(|m| m.role != SYSTEM_ROLE && m.role != TOOL_ROLE && m.content.is_some())
.filter(|m| {
m.role != Role::System && m.role != Role::Tool && !m.content.to_string().is_empty()
})
.collect::<Vec<&Message>>();
// Following code is to ensure that the conversation does not exceed max token length
@@ -88,13 +89,7 @@ impl RouterModel for RouterModelV1 {
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
let message_token_count = message
.content
.as_ref()
.unwrap_or(&ContentType::Text("".to_string()))
.to_string()
.len()
/ TOKEN_LENGTH_DIVISOR;
let message_token_count = message.content.to_string().len() / TOKEN_LENGTH_DIVISOR;
token_count += message_token_count;
if token_count > self.max_token_length {
debug!(
@@ -104,7 +99,7 @@ impl RouterModel for RouterModelV1 {
, selected_messsage_count,
messages_vec.len()
);
if message.role == USER_ROLE {
if message.role == Role::User {
// If message that exceeds max token length is from user, we need to keep it
selected_messages_list_reversed.push(message);
}
@@ -125,12 +120,12 @@ impl RouterModel for RouterModelV1 {
// ensure that first and last selected message is from user
if let Some(first_message) = selected_messages_list_reversed.first() {
if first_message.role != USER_ROLE {
if first_message.role != Role::User {
warn!("RouterModelV1: last message in the conversation is not from user, this may lead to incorrect routing");
}
}
if let Some(last_message) = selected_messages_list_reversed.last() {
if last_message.role != USER_ROLE {
if last_message.role != Role::User {
warn!("RouterModelV1: first message in the conversation is not from user, this may lead to incorrect routing");
}
}
@@ -143,9 +138,10 @@ impl RouterModel for RouterModelV1 {
Message {
role: message.role.clone(),
// we can unwrap here because we have already filtered out messages without content
content: Some(ContentType::Text(
message.content.as_ref().unwrap().to_string(),
)),
content: MessageContent::Text(message.content.to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
}
})
.collect::<Vec<Message>>();
@@ -160,8 +156,11 @@ impl RouterModel for RouterModelV1 {
ChatCompletionsRequest {
model: self.routing_model.clone(),
messages: vec![Message {
content: Some(ContentType::Text(router_message)),
role: USER_ROLE.to_string(),
content: MessageContent::Text(router_message),
role: Role::User,
name: None,
tool_calls: None,
tool_call_id: None,
}],
temperature: Some(0.01),
..Default::default()
@@ -347,9 +346,9 @@ Based on your analysis, provide your response in the following JSON formats if y
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
let prompt = req.messages[0].content.to_string();
assert_eq!(expected_prompt, prompt.to_string());
assert_eq!(expected_prompt, prompt);
}
#[test]
@@ -412,9 +411,9 @@ Based on your analysis, provide your response in the following JSON formats if y
}]);
let req = router.generate_request(&conversation, &usage_preferences);
let prompt = req.messages[0].content.as_ref().unwrap();
let prompt = req.messages[0].content.to_string();
assert_eq!(expected_prompt, prompt.to_string());
assert_eq!(expected_prompt, prompt);
}
#[test]
@@ -472,9 +471,9 @@ Based on your analysis, provide your response in the following JSON formats if y
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
let prompt = req.messages[0].content.to_string();
assert_eq!(expected_prompt, prompt.to_string());
assert_eq!(expected_prompt, prompt);
}
#[test]
@@ -533,9 +532,9 @@ Based on your analysis, provide your response in the following JSON formats if y
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
let prompt = req.messages[0].content.to_string();
assert_eq!(expected_prompt, prompt.to_string());
assert_eq!(expected_prompt, prompt);
}
#[test]
@@ -601,9 +600,9 @@ Based on your analysis, provide your response in the following JSON formats if y
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
let prompt = req.messages[0].content.to_string();
assert_eq!(expected_prompt, prompt.to_string());
assert_eq!(expected_prompt, prompt);
}
#[test]
@@ -670,9 +669,9 @@ Based on your analysis, provide your response in the following JSON formats if y
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
let prompt = req.messages[0].content.to_string();
assert_eq!(expected_prompt, prompt.to_string());
assert_eq!(expected_prompt, prompt);
}
#[test]
@@ -716,14 +715,14 @@ Based on your analysis, provide your response in the following JSON formats if y
},
{
"role": "assistant",
"content": null,
"content": "",
"tool_calls": [
{
"id": "toolcall-abc123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": { "location": "Tokyo" }
"arguments": "{ \"location\": \"Tokyo\" }"
}
}
]
@@ -763,11 +762,11 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation, &None);
let req: ChatCompletionsRequest = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
let prompt = req.messages[0].content.to_string();
assert_eq!(expected_prompt, prompt.to_string());
assert_eq!(expected_prompt, prompt);
}
#[test]