Mirror of https://github.com/katanemo/plano.git (synced 2026-04-25 00:36:34 +02:00)
updating the implementation of /v1/chat/completions to use the generic provider interfaces (#548)
* updating the implementation of /v1/chat/completions to use the generic provider interfaces
* saving changes, although we will need a small refactor after this as well
* more refactoring changes, getting close
* more refactoring changes to avoid unnecessary redirection and duplication
* more clean up
* more refactoring
* more refactoring to clean code and make stream_context.rs work
* removing unnecessary trait implementations
* some more clean-up
* fixed bugs
* fixing test cases, and making sure all references to the ChatCompletions* objects point to the new types
* refactored changes to support enum dispatch
* removed the dependency on try_streaming_from_bytes in favor of a TryFrom trait implementation
* updated readme based on new usage
* updated code based on code review comments

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-2.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-4.local>
Parent: 1fdde8181a
Commit: 89ab51697a
22 changed files with 1044 additions and 972 deletions
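The refactor described above replaces the OpenAI-specific ChatCompletionsRequest/ChatCompletionsResponse types with provider-agnostic enums (ProviderRequestType, ProviderResponseType) that are built via TryFrom from the raw body bytes plus a ProviderId and then dispatched per variant. Below is a minimal, self-contained sketch of that enum-dispatch pattern; the names mirror the diff, but the bodies are illustrative and assume serde/serde_json, not the actual hermesllm implementation.

    use serde::Deserialize;

    // Illustrative provider id and request shape; the real definitions live in hermesllm.
    #[derive(Debug)]
    enum ProviderId { OpenAI }

    #[derive(Debug, Deserialize)]
    struct ChatCompletionsRequest {
        model: String,
        #[serde(default)]
        stream: bool,
    }

    // Provider-agnostic wrapper: one enum variant per upstream request format,
    // so call sites dispatch on the variant instead of provider-specific types.
    #[derive(Debug)]
    enum ProviderRequestType {
        OpenAI(ChatCompletionsRequest),
    }

    impl ProviderRequestType {
        fn model(&self) -> &str {
            match self { ProviderRequestType::OpenAI(r) => &r.model }
        }
        fn set_model(&mut self, model: String) {
            match self { ProviderRequestType::OpenAI(r) => r.model = model }
        }
        fn is_streaming(&self) -> bool {
            match self { ProviderRequestType::OpenAI(r) => r.stream }
        }
    }

    // Build the request from raw body bytes plus the provider id, mirroring
    // ProviderRequestType::try_from((&body_bytes[..], &provider_id)) in the diff.
    impl TryFrom<(&[u8], &ProviderId)> for ProviderRequestType {
        type Error = serde_json::Error;
        fn try_from((bytes, provider): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
            match provider {
                ProviderId::OpenAI => Ok(ProviderRequestType::OpenAI(serde_json::from_slice(bytes)?)),
            }
        }
    }

    fn main() {
        let body = br#"{"model": "gpt-4o-mini", "messages": [], "stream": true}"#;
        let provider_id = ProviderId::OpenAI;
        let mut req = ProviderRequestType::try_from((&body[..], &provider_id)).unwrap();
        req.set_model("agent_orchestrator".to_string());
        println!("model={}, streaming={}", req.model(), req.is_streaming());
    }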
@@ -10,11 +10,10 @@ use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::providers::openai::types::{ChatCompletionsRequest, SseChatCompletionIter};
use hermesllm::providers::openai::types::{
    ChatCompletionsResponse, ContentType, Message, StreamOptions,
use hermesllm::providers::response::ProviderStreamResponseIter;
use hermesllm::{
    ProviderId, ProviderRequest, ProviderRequestType, ProviderResponse, ProviderResponseType,
};
use hermesllm::Provider;
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::hostcalls::get_current_time;
@@ -41,9 +40,9 @@ pub struct StreamContext {
    ttft_time: Option<u128>,
    traceparent: Option<String>,
    request_body_sent_time: Option<u128>,
    user_message: Option<Message>,
    traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
    overrides: Rc<Option<Overrides>>,
    user_message: Option<String>,
}

impl StreamContext {
@@ -69,9 +68,9 @@ impl StreamContext {
            ttft_duration: None,
            traceparent: None,
            ttft_time: None,
            user_message: None,
            traces_queue,
            request_body_sent_time: None,
            user_message: None,
        }
    }

    fn llm_provider(&self) -> &LlmProvider {
@@ -80,6 +79,10 @@ impl StreamContext {
            .expect("the provider should be set when asked for it")
    }

    fn get_provider_id(&self) -> ProviderId {
        self.llm_provider().to_provider_id()
    }

    fn select_llm_provider(&mut self) {
        let provider_hint = self
            .get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
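The new get_provider_id helper maps the configured LlmProvider onto a hermesllm ProviderId via to_provider_id. A hypothetical sketch of such a mapping (the LlmProvider shape and the interface strings below are stand-ins, not the crate's real types):

    #[derive(Debug, Clone, Copy, PartialEq)]
    enum ProviderId { OpenAI, Mistral, Deepseek }

    // Hypothetical stand-in for the gateway's configured provider entry.
    struct LlmProvider { name: String, provider_interface: String }

    impl LlmProvider {
        // Map the configured interface string onto a ProviderId, defaulting to OpenAI.
        fn to_provider_id(&self) -> ProviderId {
            match self.provider_interface.as_str() {
                "mistral" => ProviderId::Mistral,
                "deepseek" => ProviderId::Deepseek,
                _ => ProviderId::OpenAI,
            }
        }
    }

    fn main() {
        let provider = LlmProvider { name: "primary".into(), provider_interface: "mistral".into() };
        assert_eq!(provider.to_provider_id(), ProviderId::Mistral);
        println!("{} -> {:?}", provider.name, provider.to_provider_id());
    }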
@@ -295,24 +298,23 @@ impl HttpContext for StreamContext {
            }
        };

        let mut deserialized_body = match ChatCompletionsRequest::try_from(body_bytes.as_slice()) {
            Ok(deserialized) => deserialized,
            Err(e) => {
                debug!(
                    "on_http_request_body: request body: {}",
                    String::from_utf8_lossy(&body_bytes)
                );
                self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST));
                return Action::Pause;
            }
        };
        let provider_id = self.get_provider_id();

        self.user_message = deserialized_body
            .messages
            .iter()
            .filter(|m| m.role == "user")
            .last()
            .cloned();
        let mut deserialized_body =
            match ProviderRequestType::try_from((&body_bytes[..], &provider_id)) {
                Ok(deserialized) => deserialized,
                Err(e) => {
                    debug!(
                        "on_http_request_body: request body: {}",
                        String::from_utf8_lossy(&body_bytes)
                    );
                    self.send_server_error(
                        ServerError::LogicError(format!("Request parsing error: {}", e)),
                        Some(StatusCode::BAD_REQUEST),
                    );
                    return Action::Pause;
                }
            };

        let model_name = match self.llm_provider.as_ref() {
            Some(llm_provider) => llm_provider.model.as_ref(),
@@ -324,24 +326,38 @@ impl HttpContext for StreamContext {
            None => false,
        };

        let model_requested = deserialized_body.model.clone();
        deserialized_body.model = match model_name {
        // Store the original model for logging
        let model_requested = deserialized_body.model().to_string();

        // Apply model name resolution logic using the trait method
        let resolved_model = match model_name {
            Some(model_name) => model_name.clone(),
            None => {
                if use_agent_orchestrator {
                    "agent_orchestrator".to_string()
                } else {
                    self.send_server_error(
                        ServerError::BadRequest {
                            why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
                        },
                        Some(StatusCode::BAD_REQUEST),
                    );
                        ServerError::BadRequest {
                            why: format!(
                                "No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}",
                                model_requested,
                                self.llm_provider().name,
                                self.llm_provider().model
                            ),
                        },
                        Some(StatusCode::BAD_REQUEST),
                    );
                    return Action::Continue;
                }
            }
        };

        // Set the resolved model using the trait method
        deserialized_body.set_model(resolved_model.clone());

        // Extract user message for tracing
        self.user_message = deserialized_body.get_recent_user_message();

        info!(
            "on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
            self.llm_provider().name,
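The resolution order in this hunk is: use the model configured in arch_config when one is set, otherwise fall back to "agent_orchestrator" when the orchestrator is enabled, otherwise reject the request. The same precedence in isolation (a sketch; the Result error stands in for the BadRequest path, and the function name is illustrative):

    // Resolve the model to send upstream, mirroring the precedence in the diff:
    // configured model > agent orchestrator fallback > error.
    fn resolve_model(
        configured_model: Option<&str>,
        use_agent_orchestrator: bool,
        model_requested: &str,
    ) -> Result<String, String> {
        match configured_model {
            Some(model) => Ok(model.to_string()),
            None if use_agent_orchestrator => Ok("agent_orchestrator".to_string()),
            None => Err(format!(
                "No model specified in request and couldn't determine model name from arch_config. Model name in req: {}",
                model_requested
            )),
        }
    }

    fn main() {
        assert_eq!(resolve_model(Some("gpt-4o-mini"), false, "none").unwrap(), "gpt-4o-mini");
        assert_eq!(resolve_model(None, true, "none").unwrap(), "agent_orchestrator");
        assert!(resolve_model(None, false, "none").is_err());
    }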
@@ -349,32 +365,13 @@ impl HttpContext for StreamContext {
            model_name.unwrap_or(&"None".to_string()),
        );

        if deserialized_body.stream.unwrap_or_default() {
            self.streaming_response = true;
        }
        if deserialized_body.stream.unwrap_or_default()
            && deserialized_body.stream_options.is_none()
        {
            deserialized_body.stream_options = Some(StreamOptions {
                include_usage: true,
            });
        }
        // Use provider interface for streaming detection and setup
        self.streaming_response = deserialized_body.is_streaming();

        // only use the tokens from the messages, excluding the metadata and json tags
        let input_tokens_str = deserialized_body
            .messages
            .iter()
            .fold(String::new(), |acc, m| {
                acc + " "
                    + m.content
                        .as_ref()
                        .unwrap_or(&ContentType::Text(String::new()))
                        .to_string()
                        .as_str()
            });
        // Use provider interface for text extraction (after potential mutation)
        let input_tokens_str = deserialized_body.extract_messages_text();
        // enforce ratelimits on ingress
        if let Err(e) = self.enforce_ratelimits(&deserialized_body.model, input_tokens_str.as_str())
        {
        if let Err(e) = self.enforce_ratelimits(&resolved_model, input_tokens_str.as_str()) {
            self.send_server_error(
                ServerError::ExceededRatelimit(e),
                Some(StatusCode::TOO_MANY_REQUESTS),
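The removed code folded over messages and ContentType inline; the new code hides that behind extract_messages_text() (and, earlier, get_recent_user_message()) on the request type. A self-contained sketch of what those helpers can look like, using simplified stand-ins rather than hermesllm's actual message types:

    // Simplified stand-ins for the request's message content.
    enum ContentType { Text(String) }

    struct Message { role: String, content: Option<ContentType> }

    fn text_of(m: &Message) -> &str {
        match &m.content {
            Some(ContentType::Text(t)) => t.as_str(),
            None => "",
        }
    }

    // Concatenate the text of every message, the way the removed fold did,
    // so the result can be fed to the ingress rate limiter.
    fn extract_messages_text(messages: &[Message]) -> String {
        messages.iter().fold(String::new(), |acc, m| acc + " " + text_of(m))
    }

    // Most recent user turn, used for the tracing span's user_message attribute.
    fn get_recent_user_message(messages: &[Message]) -> Option<String> {
        messages.iter().rev().find(|m| m.role == "user").map(|m| text_of(m).to_string())
    }

    fn main() {
        let messages = vec![
            Message { role: "system".into(), content: Some(ContentType::Text("be brief".into())) },
            Message { role: "user".into(), content: Some(ContentType::Text("hello".into())) },
        ];
        println!("all: {}", extract_messages_text(&messages).trim());
        println!("last user: {:?}", get_recent_user_message(&messages));
    }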
@@ -383,15 +380,15 @@ impl HttpContext for StreamContext {
            return Action::Continue;
        }

        let llm_provider_str = self.llm_provider().provider_interface.to_string();
        let hermes_llm_provider = Provider::from(llm_provider_str.as_str());

        // convert chat completion request to llm provider specific request
        let deserialized_body_bytes = match deserialized_body.to_bytes(hermes_llm_provider) {
        // Convert chat completion request to llm provider specific request using provider interface
        let deserialized_body_bytes = match deserialized_body.to_bytes() {
            Ok(bytes) => bytes,
            Err(e) => {
                warn!("Failed to serialize request body: {}", e);
                self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST));
                self.send_server_error(
                    ServerError::LogicError(format!("Request serialization error: {}", e)),
                    Some(StatusCode::BAD_REQUEST),
                );
                return Action::Pause;
            }
        };
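After mutation, the request is serialized back to bytes for the upstream call; with enum dispatch, to_bytes() no longer needs an explicit Provider argument because the variant already encodes the target format. A sketch of that serialization step, assuming serde/serde_json (the struct is illustrative; the real serialization lives in hermesllm):

    use serde::Serialize;

    // Illustrative request shape; the real types live in hermesllm.
    #[derive(Serialize)]
    struct ChatCompletionsRequest { model: String, stream: bool }

    enum ProviderRequestType { OpenAI(ChatCompletionsRequest) }

    impl ProviderRequestType {
        // Serialize to the wire format of the variant's provider.
        fn to_bytes(&self) -> Result<Vec<u8>, serde_json::Error> {
            match self {
                ProviderRequestType::OpenAI(r) => serde_json::to_vec(r),
            }
        }
    }

    fn main() {
        let req = ProviderRequestType::OpenAI(ChatCompletionsRequest {
            model: "agent_orchestrator".into(),
            stream: true,
        });
        let bytes = req.to_bytes().expect("serialization should not fail");
        println!("{}", String::from_utf8_lossy(&bytes));
    }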
@@ -484,17 +481,16 @@ impl HttpContext for StreamContext {
                        self.request_body_sent_time.unwrap(),
                        current_time_ns,
                    );
                    if let Some(user_message) = self.user_message.as_ref() {
                        if let Some(prompt) = user_message.content.as_ref() {
                            llm_span
                                .add_attribute("user_prompt".to_string(), prompt.to_string());
                        }
                    }
                    llm_span.add_attribute(
                        "model".to_string(),
                        self.llm_provider().name.to_string(),
                    );

                    if let Some(user_message) = &self.user_message {
                        llm_span
                            .add_attribute("user_message".to_string(), user_message.clone());
                    }

                    if self.ttft_time.is_some() {
                        llm_span.add_event(Event::new(
                            "time_to_first_token".to_string(),
@@ -558,62 +554,69 @@ impl HttpContext for StreamContext {
            );
        }

        let llm_provider_str = self.llm_provider().provider_interface.to_string();
        let hermes_llm_provider = Provider::from(llm_provider_str.as_str());

        if self.streaming_response {
            let chat_completions_chunk_response_events =
                match SseChatCompletionIter::try_from((body.as_slice(), &hermes_llm_provider)) {
                    Ok(events) => events,
                    Err(e) => {
                        warn!(
                            "could not parse response: {}, body str: {}",
                            e,
                            String::from_utf8_lossy(&body)
                        );
                        return Action::Continue;
                    }
                };
            debug!("processing streaming response");
            match ProviderStreamResponseIter::try_from((&body[..], &self.get_provider_id())) {
                Ok(mut streaming_response) => {
                    // Process each streaming chunk
                    while let Some(chunk_result) = streaming_response.next() {
                        match chunk_result {
                            Ok(chunk) => {
                                // Compute TTFT on first chunk
                                if self.ttft_duration.is_none() {
                                    let current_time = get_current_time().unwrap();
                                    self.ttft_time = Some(current_time_ns());
                                    match current_time.duration_since(self.start_time) {
                                        Ok(duration) => {
                                            let duration_ms = duration.as_millis();
                                            info!(
                                                "on_http_response_body: time to first token: {}ms",
                                                duration_ms
                                            );
                                            self.ttft_duration = Some(duration);
                                            self.metrics
                                                .time_to_first_token
                                                .record(duration_ms as u64);
                                        }
                                        Err(e) => {
                                            warn!("SystemTime error: {:?}", e);
                                        }
                                    }
                                }

            for event in chat_completions_chunk_response_events {
                match event {
                    Ok(event) => {
                        if let Some(usage) = event.usage.as_ref() {
                            self.response_tokens += usage.completion_tokens;
                                // For streaming responses, we handle token counting differently
                                // The ProviderStreamResponse trait provides content_delta, is_final, and role
                                // Token counting for streaming responses typically happens with final usage chunk
                                if chunk.is_final() {
                                    // For now, we'll implement basic token estimation
                                    // In a complete implementation, the final chunk would contain usage information
                                    debug!("Received final streaming chunk");
                                }

                                // For now, estimate tokens from content delta
                                if let Some(content) = chunk.content_delta() {
                                    // Rough estimation: ~4 characters per token
                                    let estimated_tokens = content.len() / 4;
                                    self.response_tokens += estimated_tokens.max(1);
                                }
                            }
                            Err(e) => {
                                warn!("Error processing streaming chunk: {}", e);
                                return Action::Continue;
                            }
                        }
                    }
                    Err(e) => {
                        warn!("error in response event: {}", e);
                        continue;
                    }
                }
            }

            // Compute TTFT if not already recorded
            if self.ttft_duration.is_none() {
                // if let Some(start_time) = self.start_time {
                let current_time = get_current_time().unwrap();
                self.ttft_time = Some(current_time_ns());
                match current_time.duration_since(self.start_time) {
                    Ok(duration) => {
                        let duration_ms = duration.as_millis();
                        info!(
                            "on_http_response_body: time to first token: {}ms",
                            duration_ms
                        );
                        self.ttft_duration = Some(duration);
                        self.metrics.time_to_first_token.record(duration_ms as u64);
                    }
                    Err(e) => {
                        warn!("SystemTime error: {:?}", e);
                    }
                Err(e) => {
                    warn!("Failed to parse streaming response: {}", e);
                }
            }
        } else {
            debug!("non streaming response");
            let chat_completions_response =
                match ChatCompletionsResponse::try_from((body.as_slice(), &hermes_llm_provider)) {
                    Ok(de) => de,
            let provider_id = self.get_provider_id();
            let response: ProviderResponseType =
                match ProviderResponseType::try_from((&body[..], provider_id)) {
                    Ok(response) => response,
                    Err(e) => {
                        warn!(
                            "could not parse response: {}, body str: {}",
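For streaming bodies, the new path iterates provider-agnostic chunks and, until real usage data is plumbed through, estimates tokens from each content delta at roughly 4 characters per token, counting at least 1 token per non-empty delta. The same accounting in isolation (the chunk struct is a stand-in for the ProviderStreamResponse trait):

    // Stand-in for one SSE chunk exposed through the provider stream interface.
    struct StreamChunk { content_delta: Option<String>, is_final: bool }

    // Accumulate an estimated completion-token count across streaming chunks,
    // mirroring the ~4-characters-per-token heuristic in the diff.
    fn estimate_response_tokens(chunks: &[StreamChunk]) -> usize {
        let mut response_tokens = 0;
        for chunk in chunks {
            if let Some(content) = &chunk.content_delta {
                let estimated = content.len() / 4;
                response_tokens += estimated.max(1);
            }
            if chunk.is_final {
                // A complete implementation would read the usage block of the
                // final chunk here instead of relying on the estimate.
            }
        }
        response_tokens
    }

    fn main() {
        let chunks = vec![
            StreamChunk { content_delta: Some("Hello".into()), is_final: false },
            StreamChunk { content_delta: Some(", world!".into()), is_final: false },
            StreamChunk { content_delta: None, is_final: true },
        ];
        println!("estimated tokens: {}", estimate_response_tokens(&chunks));
    }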
@@ -626,15 +629,24 @@ impl HttpContext for StreamContext {
                            String::from_utf8_lossy(&body)
                        );
                        self.send_server_error(
                            ServerError::OpenAIPError(e),
                            ServerError::LogicError(format!("Response parsing error: {}", e)),
                            Some(StatusCode::BAD_REQUEST),
                        );
                        return Action::Continue;
                    }
                };

            if let Some(usage) = chat_completions_response.usage {
                self.response_tokens += usage.completion_tokens;
            // Use provider interface to extract usage information
            if let Some((prompt_tokens, completion_tokens, total_tokens)) =
                response.extract_usage_counts()
            {
                debug!(
                    "Response usage: prompt={}, completion={}, total={}",
                    prompt_tokens, completion_tokens, total_tokens
                );
                self.response_tokens = completion_tokens;
            } else {
                warn!("No usage information found in response");
            }
        }
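For non-streaming responses, ProviderResponseType exposes usage as a (prompt, completion, total) tuple via extract_usage_counts() instead of the caller reaching into the OpenAI usage struct. A hedged sketch of that accessor, with simplified types in place of hermesllm's:

    // Simplified usage block as returned by a chat completions response.
    struct Usage { prompt_tokens: usize, completion_tokens: usize, total_tokens: usize }

    struct ChatCompletionsResponse { usage: Option<Usage> }

    enum ProviderResponseType { OpenAI(ChatCompletionsResponse) }

    impl ProviderResponseType {
        // Flatten the provider-specific usage struct into plain counts.
        fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
            match self {
                ProviderResponseType::OpenAI(r) => r
                    .usage
                    .as_ref()
                    .map(|u| (u.prompt_tokens, u.completion_tokens, u.total_tokens)),
            }
        }
    }

    fn main() {
        let response = ProviderResponseType::OpenAI(ChatCompletionsResponse {
            usage: Some(Usage { prompt_tokens: 12, completion_tokens: 34, total_tokens: 46 }),
        });
        if let Some((prompt, completion, total)) = response.extract_usage_counts() {
            println!("prompt={}, completion={}, total={}", prompt, completion, total);
        }
    }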