mirror of
https://github.com/katanemo/plano.git
synced 2026-06-05 14:45:15 +02:00
Introduce hermesllm library to handle llm message translation (#501)
This commit is contained in:
parent
96b583c819
commit
6c53510f49
33 changed files with 1693 additions and 690 deletions
|
|
@ -1,8 +1,4 @@
|
|||
use crate::metrics::Metrics;
|
||||
use common::api::open_ai::{
|
||||
ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
|
||||
ContentType, Message, StreamOptions,
|
||||
};
|
||||
use common::configuration::{LlmProvider, LlmProviderType, Overrides};
|
||||
use common::consts::{
|
||||
ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
|
||||
|
|
@ -14,6 +10,11 @@ use common::ratelimit::Header;
|
|||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::providers::openai::types::{ChatCompletionsRequest, SseChatCompletionIter};
|
||||
use hermesllm::providers::openai::types::{
|
||||
ChatCompletionsResponse, ContentType, Message, StreamOptions,
|
||||
};
|
||||
use hermesllm::Provider;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
|
|
@ -201,14 +202,15 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
let routing_header_value = self.get_http_request_header(ARCH_ROUTING_HEADER);
|
||||
|
||||
let use_agent_orchestrator = match self.overrides.as_ref() {
|
||||
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
|
||||
None => false,
|
||||
};
|
||||
|
||||
if let Some(routing_header_value) = routing_header_value.as_ref() {
|
||||
let routing_header_value = self.get_http_request_header(ARCH_ROUTING_HEADER);
|
||||
|
||||
if routing_header_value.is_some() && !routing_header_value.as_ref().unwrap().is_empty() {
|
||||
let routing_header_value = routing_header_value.as_ref().unwrap();
|
||||
info!("routing header already set: {}", routing_header_value);
|
||||
self.llm_provider = Some(Rc::new(LlmProvider {
|
||||
name: routing_header_value.to_string(),
|
||||
|
|
@ -284,27 +286,17 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let mut deserialized_body: ChatCompletionsRequest =
|
||||
match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"on_http_request_body: request body: {}",
|
||||
String::from_utf8_lossy(&body_bytes)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::Deserialization(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
for message in deserialized_body.messages.iter_mut() {
|
||||
message.model = None;
|
||||
}
|
||||
let mut deserialized_body = match ChatCompletionsRequest::try_from(body_bytes.as_slice()) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"on_http_request_body: request body: {}",
|
||||
String::from_utf8_lossy(&body_bytes)
|
||||
);
|
||||
self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST));
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
self.user_message = deserialized_body
|
||||
.messages
|
||||
|
|
@ -348,17 +340,12 @@ impl HttpContext for StreamContext {
|
|||
model_name.unwrap_or(&"None".to_string()),
|
||||
);
|
||||
|
||||
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
|
||||
|
||||
debug!(
|
||||
"on_http_request_body: request body: {}",
|
||||
chat_completion_request_str
|
||||
);
|
||||
|
||||
if deserialized_body.stream {
|
||||
if deserialized_body.stream.unwrap_or_default() {
|
||||
self.streaming_response = true;
|
||||
}
|
||||
if deserialized_body.stream && deserialized_body.stream_options.is_none() {
|
||||
if deserialized_body.stream.unwrap_or_default()
|
||||
&& deserialized_body.stream_options.is_none()
|
||||
{
|
||||
deserialized_body.stream_options = Some(StreamOptions {
|
||||
include_usage: true,
|
||||
});
|
||||
|
|
@ -387,7 +374,20 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
|
||||
let llm_provider_str = self.llm_provider().provider_interface.to_string();
|
||||
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
|
||||
|
||||
// convert chat completion request to llm provider specific request
|
||||
let deserialized_body_bytes = match deserialized_body.to_bytes(hermes_llm_provider) {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize request body: {}", e);
|
||||
self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST));
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
self.set_http_request_body(0, body_size, &deserialized_body_bytes);
|
||||
|
||||
Action::Continue
|
||||
}
|
||||
|
|
@ -542,58 +542,33 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let body_utf8 = match String::from_utf8(body) {
|
||||
Ok(body_utf8) => body_utf8,
|
||||
Err(e) => {
|
||||
warn!("could not convert to utf8: {}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
let llm_provider_str = self.llm_provider().provider_interface.to_string();
|
||||
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
|
||||
|
||||
if self.streaming_response {
|
||||
if body_utf8 == "data: [DONE]\n" {
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
let chat_completions_chunk_response_events =
|
||||
match ChatCompletionStreamResponseServerEvents::try_from(body_utf8.as_str()) {
|
||||
Ok(response) => response,
|
||||
match SseChatCompletionIter::try_from((body.as_slice(), &hermes_llm_provider)) {
|
||||
Ok(events) => events,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"invalid streaming response: body str: {}, {:?}",
|
||||
body_utf8, e
|
||||
);
|
||||
warn!("could not parse response: {}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
if chat_completions_chunk_response_events.events.is_empty() {
|
||||
warn!(
|
||||
"couldn't parse any streaming events: body str: {}",
|
||||
body_utf8
|
||||
);
|
||||
return Action::Continue;
|
||||
for event in chat_completions_chunk_response_events {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
if let Some(usage) = event.usage.as_ref() {
|
||||
self.response_tokens += usage.completion_tokens;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("error in response event: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let model = chat_completions_chunk_response_events
|
||||
.events
|
||||
.first()
|
||||
.unwrap()
|
||||
.model
|
||||
.clone();
|
||||
let tokens_str = chat_completions_chunk_response_events.to_string();
|
||||
|
||||
let token_count =
|
||||
match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str())
|
||||
{
|
||||
Ok(token_count) => token_count,
|
||||
Err(e) => {
|
||||
warn!("could not get token count: {:?}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
self.response_tokens += token_count;
|
||||
|
||||
// Compute TTFT if not already recorded
|
||||
if self.ttft_duration.is_none() {
|
||||
// if let Some(start_time) = self.start_time {
|
||||
|
|
@ -616,24 +591,26 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
} else {
|
||||
debug!("non streaming response");
|
||||
let chat_completions_response: ChatCompletionsResponse =
|
||||
match serde_json::from_str(body_utf8.as_str()) {
|
||||
let chat_completions_response =
|
||||
match ChatCompletionsResponse::try_from((body.as_slice(), &hermes_llm_provider)) {
|
||||
Ok(de) => de,
|
||||
Err(err) => {
|
||||
info!(
|
||||
"non chat-completion compliant response received err: {}, body: {}",
|
||||
err, body_utf8
|
||||
Err(e) => {
|
||||
warn!("could not parse response: {}", e);
|
||||
debug!(
|
||||
"on_http_response_body: S[{}], response body: {}",
|
||||
self.context_id,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::OpenAIPError(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
if chat_completions_response.usage.is_some() {
|
||||
self.response_tokens += chat_completions_response
|
||||
.usage
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.completion_tokens;
|
||||
if let Some(usage) = chat_completions_response.usage {
|
||||
self.response_tokens += usage.completion_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue