mirror of
https://github.com/katanemo/plano.git
synced 2026-06-29 15:49:40 +02:00
more refactoring to clean code and make stream_context.rs work
This commit is contained in:
parent
d4ca70d177
commit
df3aa17d67
23 changed files with 545 additions and 1321 deletions
|
|
@ -10,10 +10,10 @@ use common::ratelimit::Header;
|
|||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::providers::traits::{
|
||||
ProviderRequest, ProviderResponse, StreamChunk, StreamingResponse, TokenUsage,
|
||||
use hermesllm::{
|
||||
try_request_from_bytes, try_response_from_bytes, try_streaming_from_bytes, ConversionMode,
|
||||
ProviderId,
|
||||
};
|
||||
use hermesllm::{ConversionMode, Provider, ProviderId};
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
|
|
@ -79,8 +79,8 @@ impl StreamContext {
|
|||
.expect("the provider should be set when asked for it")
|
||||
}
|
||||
|
||||
fn get_provider(&self) -> Provider {
|
||||
self.llm_provider().create_provider()
|
||||
fn get_provider_id(&self) -> ProviderId {
|
||||
self.llm_provider().to_provider_id()
|
||||
}
|
||||
|
||||
fn select_llm_provider(&mut self) {
|
||||
|
|
@ -298,9 +298,9 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let provider = self.get_provider();
|
||||
let provider_id = self.get_provider_id();
|
||||
|
||||
let mut deserialized_body = match ProviderRequest::try_from_bytes(&provider, &body_bytes) {
|
||||
let mut deserialized_body = match try_request_from_bytes(&body_bytes, &provider_id) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
|
|
@ -329,10 +329,10 @@ impl HttpContext for StreamContext {
|
|||
};
|
||||
|
||||
// Use the provider interface methods for cleaner interaction
|
||||
let model_requested = provider.extract_model(&deserialized_body).to_string(); // Convert to owned string
|
||||
let model_requested = deserialized_body.model().to_string(); // Convert to owned string
|
||||
|
||||
// Extract user message for tracing
|
||||
self.user_message = provider.extract_user_message(&deserialized_body);
|
||||
self.user_message = deserialized_body.extract_user_message();
|
||||
|
||||
info!(
|
||||
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
||||
|
|
@ -342,15 +342,15 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
|
||||
// Use provider interface for streaming detection and setup
|
||||
self.streaming_response = provider.is_streaming(&deserialized_body);
|
||||
self.streaming_response = deserialized_body.is_streaming();
|
||||
|
||||
// Set streaming options if needed
|
||||
if self.streaming_response {
|
||||
provider.set_streaming_options(&mut deserialized_body);
|
||||
deserialized_body.set_streaming_options();
|
||||
}
|
||||
|
||||
// Use provider interface for text extraction (after potential mutation)
|
||||
let input_tokens_str = provider.extract_messages_text(&deserialized_body);
|
||||
let input_tokens_str = deserialized_body.extract_messages_text();
|
||||
// enforce ratelimits on ingress
|
||||
if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) {
|
||||
self.send_server_error(
|
||||
|
|
@ -365,21 +365,18 @@ impl HttpContext for StreamContext {
|
|||
let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
|
||||
|
||||
// Convert chat completion request to llm provider specific request using provider interface
|
||||
let deserialized_body_bytes = match provider.to_provider_bytes(
|
||||
&deserialized_body,
|
||||
provider.id(),
|
||||
ConversionMode::Compatible,
|
||||
) {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize request body: {}", e);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Request serialization error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
let deserialized_body_bytes =
|
||||
match deserialized_body.to_provider_bytes(ConversionMode::Compatible) {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize request body: {}", e);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Request serialization error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
self.set_http_request_body(0, body_size, &deserialized_body_bytes);
|
||||
|
||||
|
|
@ -550,16 +547,9 @@ impl HttpContext for StreamContext {
|
|||
|
||||
// Parse streaming response using OpenAI-compatible format
|
||||
// Since all providers use OpenAI-compatible streaming format
|
||||
let provider = self.get_provider();
|
||||
let provider_id =
|
||||
ProviderId::from(self.llm_provider().provider_interface.to_string().as_str());
|
||||
let provider_id = self.get_provider_id();
|
||||
|
||||
match StreamingResponse::try_from_bytes(
|
||||
&provider,
|
||||
&body,
|
||||
&provider_id,
|
||||
ConversionMode::Compatible,
|
||||
) {
|
||||
match try_streaming_from_bytes(&body, &provider_id, ConversionMode::Compatible) {
|
||||
Ok(mut streaming_response) => {
|
||||
// Process each streaming chunk
|
||||
while let Some(chunk_result) = streaming_response.next() {
|
||||
|
|
@ -587,14 +577,20 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
// Extract usage information if available
|
||||
if let Some(usage) = chunk.usage() {
|
||||
let completion_tokens = usage.completion_tokens();
|
||||
self.response_tokens += completion_tokens;
|
||||
debug!(
|
||||
"Streaming chunk completion tokens: {}",
|
||||
completion_tokens
|
||||
);
|
||||
// For streaming responses, we handle token counting differently
|
||||
// The ProviderStreamResponse trait provides content_delta, is_final, and role
|
||||
// Token counting for streaming responses typically happens with final usage chunk
|
||||
if chunk.is_final() {
|
||||
// For now, we'll implement basic token estimation
|
||||
// In a complete implementation, the final chunk would contain usage information
|
||||
debug!("Received final streaming chunk");
|
||||
}
|
||||
|
||||
// For now, estimate tokens from content delta
|
||||
if let Some(content) = chunk.content_delta() {
|
||||
// Rough estimation: ~4 characters per token
|
||||
let estimated_tokens = content.len() / 4;
|
||||
self.response_tokens += estimated_tokens.max(1);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
|
|
@ -605,40 +601,37 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse streaming response: {}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("non streaming response");
|
||||
let provider = self.get_provider();
|
||||
let response = match ProviderResponse::try_from_bytes(
|
||||
&provider,
|
||||
&body,
|
||||
&provider.id(),
|
||||
ConversionMode::Compatible,
|
||||
) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
debug!(
|
||||
"on_http_response_body: S[{}], response body: {}",
|
||||
self.context_id,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
let provider_id = self.get_provider_id();
|
||||
let response =
|
||||
match try_response_from_bytes(&body, &provider_id, ConversionMode::Compatible) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
debug!(
|
||||
"on_http_response_body: S[{}], response body: {}",
|
||||
self.context_id,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Use provider interface to extract usage information
|
||||
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
|
||||
provider.extract_usage_counts(&response)
|
||||
response.extract_usage_counts()
|
||||
{
|
||||
debug!(
|
||||
"Response usage: prompt={}, completion={}, total={}",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue