mirror of
https://github.com/katanemo/plano.git
synced 2026-06-26 15:39:40 +02:00
more refactoring changes to avoid unecessary re-direction and duplication
This commit is contained in:
parent
58028bb7ae
commit
9c09a18fd0
12 changed files with 809 additions and 225 deletions
|
|
@ -10,6 +10,10 @@ use common::ratelimit::Header;
|
|||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::apis::openai::{ContentPart, MessageContent};
|
||||
use hermesllm::providers::traits::{
|
||||
ProviderRequest, ProviderResponse, StreamChunk, StreamingResponse, TokenUsage,
|
||||
};
|
||||
use hermesllm::{ConversionMode, Provider, ProviderId};
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
|
|
@ -39,6 +43,7 @@ pub struct StreamContext {
|
|||
request_body_sent_time: Option<u128>,
|
||||
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
user_message: Option<String>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -66,6 +71,7 @@ impl StreamContext {
|
|||
ttft_time: None,
|
||||
traces_queue,
|
||||
request_body_sent_time: None,
|
||||
user_message: None,
|
||||
}
|
||||
}
|
||||
fn llm_provider(&self) -> &LlmProvider {
|
||||
|
|
@ -295,7 +301,7 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let provider = self.get_provider();
|
||||
|
||||
let mut deserialized_body = match provider.interface().parse_request(&body_bytes) {
|
||||
let mut deserialized_body = match ProviderRequest::try_from_bytes(&provider, &body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
|
|
@ -324,9 +330,29 @@ impl HttpContext for StreamContext {
|
|||
};
|
||||
|
||||
// Use the provider interface methods for cleaner interaction
|
||||
let model_requested = provider
|
||||
.interface()
|
||||
.extract_model_from_request(&deserialized_body);
|
||||
let model_requested = provider.extract_model(&deserialized_body).to_string(); // Convert to owned string
|
||||
|
||||
// Extract user message for tracing
|
||||
self.user_message = deserialized_body.messages.last().and_then(|msg| {
|
||||
match &msg.content {
|
||||
MessageContent::Text(text) => Some(text.clone()),
|
||||
MessageContent::Parts(parts) => {
|
||||
// Extract text from content parts, ignoring images
|
||||
let text_parts: Vec<String> = parts
|
||||
.iter()
|
||||
.filter_map(|part| match part {
|
||||
ContentPart::Text { text } => Some(text.clone()),
|
||||
ContentPart::ImageUrl { .. } => None,
|
||||
})
|
||||
.collect();
|
||||
if text_parts.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(text_parts.join(" "))
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
info!(
|
||||
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
||||
|
|
@ -336,20 +362,15 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
|
||||
// Use provider interface for streaming detection and setup
|
||||
if provider
|
||||
.interface()
|
||||
.is_request_streaming(&deserialized_body)
|
||||
{
|
||||
self.streaming_response = true;
|
||||
provider
|
||||
.interface()
|
||||
.prepare_request_for_streaming(&mut deserialized_body);
|
||||
self.streaming_response = provider.is_streaming(&deserialized_body);
|
||||
|
||||
// Set streaming options if needed
|
||||
if self.streaming_response {
|
||||
provider.set_streaming_options(&mut deserialized_body);
|
||||
}
|
||||
|
||||
// Use provider interface for text extraction
|
||||
let input_tokens_str = provider
|
||||
.interface()
|
||||
.extract_text_for_tokenization(&deserialized_body);
|
||||
// Use provider interface for text extraction (after potential mutation)
|
||||
let input_tokens_str = provider.extract_messages_text(&deserialized_body);
|
||||
// enforce ratelimits on ingress
|
||||
if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) {
|
||||
self.send_server_error(
|
||||
|
|
@ -364,7 +385,7 @@ impl HttpContext for StreamContext {
|
|||
let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
|
||||
|
||||
// Convert chat completion request to llm provider specific request using provider interface
|
||||
let deserialized_body_bytes = match provider.interface().request_to_bytes(
|
||||
let deserialized_body_bytes = match provider.to_provider_bytes(
|
||||
&deserialized_body,
|
||||
provider.id(),
|
||||
ConversionMode::Compatible,
|
||||
|
|
@ -473,6 +494,11 @@ impl HttpContext for StreamContext {
|
|||
self.llm_provider().name.to_string(),
|
||||
);
|
||||
|
||||
if let Some(user_message) = &self.user_message {
|
||||
llm_span
|
||||
.add_attribute("user_message".to_string(), user_message.clone());
|
||||
}
|
||||
|
||||
if self.ttft_time.is_some() {
|
||||
llm_span.add_event(Event::new(
|
||||
"time_to_first_token".to_string(),
|
||||
|
|
@ -540,36 +566,74 @@ impl HttpContext for StreamContext {
|
|||
let _provider_id = ProviderId::from(llm_provider_str.as_str());
|
||||
|
||||
if self.streaming_response {
|
||||
// TODO: Implement streaming response parsing with new provider structure
|
||||
warn!(
|
||||
"Streaming response parsing not yet fully implemented with new provider structure"
|
||||
);
|
||||
debug!("processing streaming response");
|
||||
|
||||
// For now, just compute TTFT and continue
|
||||
if self.ttft_duration.is_none() {
|
||||
let current_time = get_current_time().unwrap();
|
||||
self.ttft_time = Some(current_time_ns());
|
||||
match current_time.duration_since(self.start_time) {
|
||||
Ok(duration) => {
|
||||
let duration_ms = duration.as_millis();
|
||||
info!(
|
||||
"on_http_response_body: time to first token: {}ms",
|
||||
duration_ms
|
||||
);
|
||||
self.ttft_duration = Some(duration);
|
||||
self.metrics.time_to_first_token.record(duration_ms as u64);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("SystemTime error: {:?}", e);
|
||||
// Parse streaming response using OpenAI-compatible format
|
||||
// Since all providers use OpenAI-compatible streaming format
|
||||
let provider = self.get_provider();
|
||||
let provider_id =
|
||||
ProviderId::from(self.llm_provider().provider_interface.to_string().as_str());
|
||||
|
||||
match StreamingResponse::try_from_bytes(
|
||||
&provider,
|
||||
&body,
|
||||
&provider_id,
|
||||
ConversionMode::Compatible,
|
||||
) {
|
||||
Ok(mut streaming_response) => {
|
||||
// Process each streaming chunk
|
||||
while let Some(chunk_result) = streaming_response.next() {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
// Compute TTFT on first chunk
|
||||
if self.ttft_duration.is_none() {
|
||||
let current_time = get_current_time().unwrap();
|
||||
self.ttft_time = Some(current_time_ns());
|
||||
match current_time.duration_since(self.start_time) {
|
||||
Ok(duration) => {
|
||||
let duration_ms = duration.as_millis();
|
||||
info!(
|
||||
"on_http_response_body: time to first token: {}ms",
|
||||
duration_ms
|
||||
);
|
||||
self.ttft_duration = Some(duration);
|
||||
self.metrics
|
||||
.time_to_first_token
|
||||
.record(duration_ms as u64);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("SystemTime error: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract usage information if available
|
||||
if let Some(usage) = chunk.usage() {
|
||||
let completion_tokens = usage.completion_tokens();
|
||||
self.response_tokens += completion_tokens;
|
||||
debug!(
|
||||
"Streaming chunk completion tokens: {}",
|
||||
completion_tokens
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error processing streaming chunk: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse streaming response: {}", e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("non streaming response");
|
||||
let provider = self.get_provider();
|
||||
let response = match provider.interface().parse_response(
|
||||
let response = match ProviderResponse::try_from_bytes(
|
||||
&provider,
|
||||
&body,
|
||||
provider.id(),
|
||||
&provider.id(),
|
||||
ConversionMode::Compatible,
|
||||
) {
|
||||
Ok(response) => response,
|
||||
|
|
@ -594,7 +658,7 @@ impl HttpContext for StreamContext {
|
|||
|
||||
// Use provider interface to extract usage information
|
||||
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
|
||||
provider.interface().extract_usage_from_response(&response)
|
||||
provider.extract_usage_counts(&response)
|
||||
{
|
||||
debug!(
|
||||
"Response usage: prompt={}, completion={}, total={}",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue