updating the implementation of /v1/chat/completions to use the generic provider interfaces (#548)
* updating the implementation of /v1/chat/completions to use the generic provider interfaces
* saving changes, although we will need a small refactor after this as well
* more refactoring changes, getting close
* more refactoring changes to avoid unnecessary redirection and duplication
* more clean-up
* more refactoring
* more refactoring to clean code and make stream_context.rs work
* removing unnecessary trait implementations
* some more clean-up
* fixed bugs
* fixing test cases, and making sure all references to the ChatCompletions* objects point to the new types
* refactored changes to support enum dispatch
* removed the dependency on try_streaming_from_bytes in favor of a TryFrom trait implementation
* updated readme based on new usage
* updated code based on code review comments

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-2.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-4.local>
parent 1fdde8181a
commit 89ab51697a
22 changed files with 1044 additions and 972 deletions
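Two items in the commit message are structural rather than cosmetic: dispatching over providers through an enum, and replacing the bespoke try_streaming_from_bytes helper with a TryFrom implementation. A minimal sketch of both patterns follows; the names (Provider, OpenAIProvider, MistralProvider, the one-field ChatCompletionsResponse) are illustrative stubs, not hermesllm's actual API.

// Illustrative sketch only; hermesllm's real provider and response types differ.
use serde::Deserialize;

struct OpenAIProvider;
struct MistralProvider;

impl OpenAIProvider {
    fn chat_completions_path(&self) -> &'static str {
        "/v1/chat/completions"
    }
}

impl MistralProvider {
    fn chat_completions_path(&self) -> &'static str {
        "/v1/chat/completions"
    }
}

// Enum dispatch: callers hold one concrete enum instead of a Box<dyn Trait>,
// so every method call is a static match over a closed set of providers.
enum Provider {
    OpenAI(OpenAIProvider),
    Mistral(MistralProvider),
}

impl Provider {
    fn chat_completions_path(&self) -> &'static str {
        match self {
            Provider::OpenAI(p) => p.chat_completions_path(),
            Provider::Mistral(p) => p.chat_completions_path(),
        }
    }
}

// TryFrom in place of a bespoke try_streaming_from_bytes helper: parsing a
// response body becomes ChatCompletionsResponse::try_from(bytes)? at call sites.
#[derive(Deserialize)]
struct ChatCompletionsResponse {
    model: String,
}

impl TryFrom<&[u8]> for ChatCompletionsResponse {
    type Error = serde_json::Error;

    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        serde_json::from_slice(bytes)
    }
}

fn main() {
    let provider = Provider::OpenAI(OpenAIProvider);
    assert_eq!(provider.chat_completions_path(), "/v1/chat/completions");

    let resp = ChatCompletionsResponse::try_from(br#"{"model":"gpt-4o"}"#.as_slice()).unwrap();
    assert_eq!(resp.model, "gpt-4o");
}

Enum dispatch fits a gateway whose provider set is fixed at compile time: adding a provider is a new variant, and the compiler flags every match that then needs updating.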
@@ -3,7 +3,7 @@ use std::sync::Arc;
 use bytes::Bytes;
 use common::configuration::ModelUsagePreference;
 use common::consts::ARCH_PROVIDER_HINT_HEADER;
-use hermesllm::providers::openai::types::ChatCompletionsRequest;
+use hermesllm::apis::openai::ChatCompletionsRequest;
 use http_body_util::combinators::BoxBody;
 use http_body_util::{BodyExt, Full, StreamBody};
 use hyper::body::Frame;
@@ -93,7 +93,7 @@ pub async fn chat_completions(
     chat_completion_request.metadata.and_then(|metadata| {
         metadata
             .get("archgw_preference_config")
-            .and_then(|value| value.as_str().map(String::from))
+            .map(|value| value.to_string())
     });

     let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
@@ -105,9 +105,7 @@ pub async fn chat_completions(
         .messages
         .last()
         .map_or("None".to_string(), |msg| {
-            msg.content.as_ref().map_or("None".to_string(), |content| {
-                content.to_string().replace('\n', "\\n")
-            })
+            msg.content.to_string().replace('\n', "\\n")
         });

     const MAX_MESSAGE_LENGTH: usize = 50;
@@ -1,6 +1,6 @@
 use bytes::Bytes;
 use common::configuration::{IntoModels, LlmProvider};
-use hermesllm::providers::openai::types::Models;
+use hermesllm::apis::openai::Models;
 use http_body_util::{combinators::BoxBody, BodyExt, Full};
 use hyper::{Response, StatusCode};
 use serde_json;
@@ -98,7 +98,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
         let peer_addr = stream.peer_addr()?;
         let io = TokioIo::new(stream);

-        let router_service = Arc::clone(&router_service);
+        let router_service: Arc<RouterService> = Arc::clone(&router_service);
         let llm_provider_endpoint = llm_provider_endpoint.clone();

         let llm_providers = llm_providers.clone();
@@ -4,7 +4,7 @@ use common::{
     configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
     consts::ARCH_PROVIDER_HINT_HEADER,
 };
-use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message};
+use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
 use hyper::header;
 use thiserror::Error;
 use tracing::{debug, info, warn};
@@ -153,9 +153,7 @@ impl RouterService {
             return Ok(None);
         }

-        if let Some(ContentType::Text(content)) =
-            &chat_completion_response.choices[0].message.content
-        {
+        if let Some(content) = &chat_completion_response.choices[0].message.content {
             let parsed_response = self
                 .router_model
                 .parse_response(content, &usage_preferences)?;
@@ -1,5 +1,5 @@
 use common::configuration::ModelUsagePreference;
-use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message};
+use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
 use thiserror::Error;

 #[derive(Debug, Error)]
@@ -2,9 +2,8 @@ use std::collections::HashMap;

 use common::{
     configuration::{ModelUsagePreference, RoutingPreference},
-    consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
 };
-use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message};
+use hermesllm::apis::openai::{ChatCompletionsRequest, MessageContent, Message, Role};
 use serde::{Deserialize, Serialize};
 use tracing::{debug, warn};

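The hunks below swap the string role constants (SYSTEM_ROLE, TOOL_ROLE, USER_ROLE) and the old ContentType for hermesllm's Role enum and MessageContent. Inferring only from how this diff uses them, those types need at least the following surface; this sketch is reconstructed from usage here, not taken from hermesllm's source:

use std::fmt;

// Role: the diff compares roles with != (PartialEq) and calls
// message.role.clone() (Clone); the Assistant variant is implied by the
// "role": "assistant" test fixture further down.
#[derive(Clone, PartialEq)]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}

// MessageContent: the diff only exercises the text variant and calls
// m.content.to_string(), which the Display impl below provides.
pub enum MessageContent {
    Text(String),
}

impl fmt::Display for MessageContent {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MessageContent::Text(text) => f.write_str(text),
        }
    }
}

// Message mirrors the five fields the diff constructs; the real struct will
// also need serde derives for the serde_json::from_str calls in the tests.
pub struct ToolCall; // stub standing in for the real tool-call type
pub struct Message {
    pub role: Role,
    pub content: MessageContent,
    pub name: Option<String>,
    pub tool_calls: Option<Vec<ToolCall>>,
    pub tool_call_id: Option<String>,
}

Note that content is now non-optional (MessageContent rather than Option), which is why the old m.content.is_some() filter becomes an is_empty() check on the rendered text in the next hunk.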
@@ -80,7 +79,9 @@ impl RouterModel for RouterModelV1 {
         // when role == tool its tool call response
         let messages_vec = messages
             .iter()
-            .filter(|m| m.role != SYSTEM_ROLE && m.role != TOOL_ROLE && m.content.is_some())
+            .filter(|m| {
+                m.role != Role::System && m.role != Role::Tool && !m.content.to_string().is_empty()
+            })
             .collect::<Vec<&Message>>();

         // Following code is to ensure that the conversation does not exceed max token length
@@ -88,13 +89,7 @@ impl RouterModel for RouterModelV1 {
         let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
         let mut selected_messages_list_reversed: Vec<&Message> = vec![];
         for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
-            let message_token_count = message
-                .content
-                .as_ref()
-                .unwrap_or(&ContentType::Text("".to_string()))
-                .to_string()
-                .len()
-                / TOKEN_LENGTH_DIVISOR;
+            let message_token_count = message.content.to_string().len() / TOKEN_LENGTH_DIVISOR;
             token_count += message_token_count;
             if token_count > self.max_token_length {
                 debug!(
@@ -104,7 +99,7 @@ impl RouterModel for RouterModelV1 {
                     , selected_messsage_count,
                     messages_vec.len()
                 );
-                if message.role == USER_ROLE {
+                if message.role == Role::User {
                     // If message that exceeds max token length is from user, we need to keep it
                     selected_messages_list_reversed.push(message);
                 }
@@ -125,12 +120,12 @@ impl RouterModel for RouterModelV1 {

         // ensure that first and last selected message is from user
         if let Some(first_message) = selected_messages_list_reversed.first() {
-            if first_message.role != USER_ROLE {
+            if first_message.role != Role::User {
                 warn!("RouterModelV1: last message in the conversation is not from user, this may lead to incorrect routing");
             }
         }
         if let Some(last_message) = selected_messages_list_reversed.last() {
-            if last_message.role != USER_ROLE {
+            if last_message.role != Role::User {
                 warn!("RouterModelV1: first message in the conversation is not from user, this may lead to incorrect routing");
             }
         }
@@ -143,9 +138,10 @@ impl RouterModel for RouterModelV1 {
                 Message {
                     role: message.role.clone(),
-                    // we can unwrap here because we have already filtered out messages without content
-                    content: Some(ContentType::Text(
-                        message.content.as_ref().unwrap().to_string(),
-                    )),
+                    content: MessageContent::Text(message.content.to_string()),
+                    name: None,
+                    tool_calls: None,
+                    tool_call_id: None,
                 }
             })
             .collect::<Vec<Message>>();
@@ -160,8 +156,11 @@ impl RouterModel for RouterModelV1 {
         ChatCompletionsRequest {
             model: self.routing_model.clone(),
             messages: vec![Message {
-                content: Some(ContentType::Text(router_message)),
-                role: USER_ROLE.to_string(),
+                content: MessageContent::Text(router_message),
+                role: Role::User,
+                name: None,
+                tool_calls: None,
+                tool_call_id: None,
             }],
             temperature: Some(0.01),
             ..Default::default()
@@ -347,9 +346,9 @@ Based on your analysis, provide your response in the following JSON formats if y

         let req = router.generate_request(&conversation, &None);

-        let prompt = req.messages[0].content.as_ref().unwrap();
+        let prompt = req.messages[0].content.to_string();

-        assert_eq!(expected_prompt, prompt.to_string());
+        assert_eq!(expected_prompt, prompt);
     }

     #[test]
@@ -412,9 +411,9 @@ Based on your analysis, provide your response in the following JSON formats if y
         }]);
         let req = router.generate_request(&conversation, &usage_preferences);

-        let prompt = req.messages[0].content.as_ref().unwrap();
+        let prompt = req.messages[0].content.to_string();

-        assert_eq!(expected_prompt, prompt.to_string());
+        assert_eq!(expected_prompt, prompt);
     }

     #[test]
@@ -472,9 +471,9 @@ Based on your analysis, provide your response in the following JSON formats if y

         let req = router.generate_request(&conversation, &None);

-        let prompt = req.messages[0].content.as_ref().unwrap();
+        let prompt = req.messages[0].content.to_string();

-        assert_eq!(expected_prompt, prompt.to_string());
+        assert_eq!(expected_prompt, prompt);
     }

     #[test]
@@ -533,9 +532,9 @@ Based on your analysis, provide your response in the following JSON formats if y

         let req = router.generate_request(&conversation, &None);

-        let prompt = req.messages[0].content.as_ref().unwrap();
+        let prompt = req.messages[0].content.to_string();

-        assert_eq!(expected_prompt, prompt.to_string());
+        assert_eq!(expected_prompt, prompt);
     }

     #[test]
@@ -601,9 +600,9 @@ Based on your analysis, provide your response in the following JSON formats if y

         let req = router.generate_request(&conversation, &None);

-        let prompt = req.messages[0].content.as_ref().unwrap();
+        let prompt = req.messages[0].content.to_string();

-        assert_eq!(expected_prompt, prompt.to_string());
+        assert_eq!(expected_prompt, prompt);
     }

     #[test]
@@ -670,9 +669,9 @@ Based on your analysis, provide your response in the following JSON formats if y

         let req = router.generate_request(&conversation, &None);

-        let prompt = req.messages[0].content.as_ref().unwrap();
+        let prompt = req.messages[0].content.to_string();

-        assert_eq!(expected_prompt, prompt.to_string());
+        assert_eq!(expected_prompt, prompt);
     }

     #[test]
@@ -716,14 +715,14 @@ Based on your analysis, provide your response in the following JSON formats if y
             },
             {
                 "role": "assistant",
-                "content": null,
+                "content": "",
                 "tool_calls": [
                     {
                         "id": "toolcall-abc123",
                         "type": "function",
                         "function": {
                             "name": "get_weather",
-                            "arguments": { "location": "Tokyo" }
+                            "arguments": "{ \"location\": \"Tokyo\" }"
                         }
                     }
                 ]
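Two details in this fixture change track the new types: "content": null becomes "" because Message.content is no longer optional, and function.arguments becomes a JSON-encoded string rather than a nested object, matching the OpenAI chat-completions wire format where arguments is a string holding JSON. A small self-contained check of that second point (type names here are illustrative, not hermesllm's):

use serde::Deserialize;
use serde_json::Value;

// `arguments` deserializes as a String; the JSON payload inside it is
// parsed in a second step.
#[derive(Deserialize)]
struct FunctionCall {
    name: String,
    arguments: String, // "{ \"location\": \"Tokyo\" }", not a nested object
}

fn main() {
    let raw = r#"{ "name": "get_weather", "arguments": "{ \"location\": \"Tokyo\" }" }"#;
    let call: FunctionCall = serde_json::from_str(raw).unwrap();
    let args: Value = serde_json::from_str(&call.arguments).unwrap();
    assert_eq!(call.name, "get_weather");
    assert_eq!(args["location"], "Tokyo");
}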
@@ -763,11 +762,11 @@ Based on your analysis, provide your response in the following JSON formats if y

         let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();

-        let req = router.generate_request(&conversation, &None);
+        let req: ChatCompletionsRequest = router.generate_request(&conversation, &None);

-        let prompt = req.messages[0].content.as_ref().unwrap();
+        let prompt = req.messages[0].content.to_string();

-        assert_eq!(expected_prompt, prompt.to_string());
+        assert_eq!(expected_prompt, prompt);
     }

     #[test]