Introduce hermesllm library to handle LLM message translation (#501)

This commit is contained in:
Adil Hafeez 2025-06-10 12:53:27 -07:00 committed by GitHub
parent 96b583c819
commit 6c53510f49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 1693 additions and 690 deletions

View file

@ -1,8 +1,4 @@
use crate::metrics::Metrics;
use common::api::open_ai::{
ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
ContentType, Message, StreamOptions,
};
use common::configuration::{LlmProvider, LlmProviderType, Overrides};
use common::consts::{
ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
@ -14,6 +10,11 @@ use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::providers::openai::types::{ChatCompletionsRequest, SseChatCompletionIter};
use hermesllm::providers::openai::types::{
ChatCompletionsResponse, ContentType, Message, StreamOptions,
};
use hermesllm::Provider;
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::hostcalls::get_current_time;
@ -201,14 +202,15 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
let routing_header_value = self.get_http_request_header(ARCH_ROUTING_HEADER);
let use_agent_orchestrator = match self.overrides.as_ref() {
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
None => false,
};
if let Some(routing_header_value) = routing_header_value.as_ref() {
let routing_header_value = self.get_http_request_header(ARCH_ROUTING_HEADER);
if routing_header_value.is_some() && !routing_header_value.as_ref().unwrap().is_empty() {
let routing_header_value = routing_header_value.as_ref().unwrap();
info!("routing header already set: {}", routing_header_value);
self.llm_provider = Some(Rc::new(LlmProvider {
name: routing_header_value.to_string(),
@ -284,27 +286,17 @@ impl HttpContext for StreamContext {
}
};
// Deserialize body into spec.
// Currently OpenAI API.
let mut deserialized_body: ChatCompletionsRequest =
match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
debug!(
"on_http_request_body: request body: {}",
String::from_utf8_lossy(&body_bytes)
);
self.send_server_error(
ServerError::Deserialization(e),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
};
for message in deserialized_body.messages.iter_mut() {
message.model = None;
}
let mut deserialized_body = match ChatCompletionsRequest::try_from(body_bytes.as_slice()) {
Ok(deserialized) => deserialized,
Err(e) => {
debug!(
"on_http_request_body: request body: {}",
String::from_utf8_lossy(&body_bytes)
);
self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST));
return Action::Pause;
}
};
self.user_message = deserialized_body
.messages
@ -348,17 +340,12 @@ impl HttpContext for StreamContext {
model_name.unwrap_or(&"None".to_string()),
);
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
debug!(
"on_http_request_body: request body: {}",
chat_completion_request_str
);
if deserialized_body.stream {
if deserialized_body.stream.unwrap_or_default() {
self.streaming_response = true;
}
if deserialized_body.stream && deserialized_body.stream_options.is_none() {
if deserialized_body.stream.unwrap_or_default()
&& deserialized_body.stream_options.is_none()
{
deserialized_body.stream_options = Some(StreamOptions {
include_usage: true,
});
@ -387,7 +374,20 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
let llm_provider_str = self.llm_provider().provider_interface.to_string();
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
// Convert the chat completion request into the LLM-provider-specific request format.
let deserialized_body_bytes = match deserialized_body.to_bytes(hermes_llm_provider) {
Ok(bytes) => bytes,
Err(e) => {
warn!("Failed to serialize request body: {}", e);
self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST));
return Action::Pause;
}
};
self.set_http_request_body(0, body_size, &deserialized_body_bytes);
Action::Continue
}
@ -542,58 +542,33 @@ impl HttpContext for StreamContext {
}
};
let body_utf8 = match String::from_utf8(body) {
Ok(body_utf8) => body_utf8,
Err(e) => {
warn!("could not convert to utf8: {}", e);
return Action::Continue;
}
};
let llm_provider_str = self.llm_provider().provider_interface.to_string();
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
if self.streaming_response {
if body_utf8 == "data: [DONE]\n" {
return Action::Continue;
}
let chat_completions_chunk_response_events =
match ChatCompletionStreamResponseServerEvents::try_from(body_utf8.as_str()) {
Ok(response) => response,
match SseChatCompletionIter::try_from((body.as_slice(), &hermes_llm_provider)) {
Ok(events) => events,
Err(e) => {
warn!(
"invalid streaming response: body str: {}, {:?}",
body_utf8, e
);
warn!("could not parse response: {}", e);
return Action::Continue;
}
};
if chat_completions_chunk_response_events.events.is_empty() {
warn!(
"couldn't parse any streaming events: body str: {}",
body_utf8
);
return Action::Continue;
for event in chat_completions_chunk_response_events {
match event {
Ok(event) => {
if let Some(usage) = event.usage.as_ref() {
self.response_tokens += usage.completion_tokens;
}
}
Err(e) => {
warn!("error in response event: {}", e);
continue;
}
}
}
let model = chat_completions_chunk_response_events
.events
.first()
.unwrap()
.model
.clone();
let tokens_str = chat_completions_chunk_response_events.to_string();
let token_count =
match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str())
{
Ok(token_count) => token_count,
Err(e) => {
warn!("could not get token count: {:?}", e);
return Action::Continue;
}
};
self.response_tokens += token_count;
// Compute TTFT if not already recorded
if self.ttft_duration.is_none() {
// if let Some(start_time) = self.start_time {
@ -616,24 +591,26 @@ impl HttpContext for StreamContext {
}
} else {
debug!("non streaming response");
let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_str(body_utf8.as_str()) {
let chat_completions_response =
match ChatCompletionsResponse::try_from((body.as_slice(), &hermes_llm_provider)) {
Ok(de) => de,
Err(err) => {
info!(
"non chat-completion compliant response received err: {}, body: {}",
err, body_utf8
Err(e) => {
warn!("could not parse response: {}", e);
debug!(
"on_http_response_body: S[{}], response body: {}",
self.context_id,
String::from_utf8_lossy(&body)
);
self.send_server_error(
ServerError::OpenAIPError(e),
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
};
if chat_completions_response.usage.is_some() {
self.response_tokens += chat_completions_response
.usage
.as_ref()
.unwrap()
.completion_tokens;
if let Some(usage) = chat_completions_response.usage {
self.response_tokens += usage.completion_tokens;
}
}