Saving changes, although we will need a small refactor after this as well

This commit is contained in:
Salman Paracha 2025-08-09 11:19:23 -07:00
parent 203fc8f9a9
commit 63f23efda4
10 changed files with 414 additions and 259 deletions

View file

@ -10,9 +10,7 @@ use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::{
Provider, ProviderInstance, ProviderRequest, ProviderResponse, StreamChunk, TokenUsage,
};
use hermesllm::{ConversionMode, Provider, ProviderId, ProviderRequest};
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::hostcalls::get_current_time;
@ -76,8 +74,8 @@ impl StreamContext {
.expect("the provider should be set when asked for it")
}
fn get_provider_instance(&self) -> ProviderInstance {
self.llm_provider().create_provider_instance()
/// Returns the `Provider` obtained by calling `create_provider()` on the
/// value returned by `self.llm_provider()`.
///
/// NOTE(review): `llm_provider()` appears to panic via `expect` when no
/// provider has been selected yet — confirm callers only invoke this after
/// provider selection.
fn get_provider(&self) -> Provider {
self.llm_provider().create_provider()
}
fn select_llm_provider(&mut self) {
@ -295,9 +293,9 @@ impl HttpContext for StreamContext {
}
};
let provider_instance = self.get_provider_instance();
let provider = self.get_provider();
let mut deserialized_body = match provider_instance.parse_request(&body_bytes) {
let mut deserialized_body = match provider.parse_request(&body_bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
debug!(
@ -356,10 +354,11 @@ impl HttpContext for StreamContext {
}
let llm_provider_str = self.llm_provider().provider_interface.to_string();
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
let hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
// convert chat completion request to llm provider specific request
let deserialized_body_bytes = match deserialized_body.to_provider_bytes(hermes_llm_provider)
let deserialized_body_bytes = match deserialized_body
.to_provider_bytes(hermes_llm_provider_id, ConversionMode::Compatible)
{
Ok(bytes) => bytes,
Err(e) => {
@ -529,42 +528,16 @@ impl HttpContext for StreamContext {
}
let llm_provider_str = self.llm_provider().provider_interface.to_string();
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
let _provider_id = ProviderId::from(llm_provider_str.as_str());
if self.streaming_response {
// Use the provider instance to parse streaming response
let provider_instance = self.get_provider_instance();
// TODO: Implement streaming response parsing with new provider structure
warn!(
"Streaming response parsing not yet fully implemented with new provider structure"
);
let streaming_events =
match provider_instance.parse_streaming_response(&body, &hermes_llm_provider) {
Ok(events) => events,
Err(e) => {
warn!(
"could not parse response: {}, body str: {}",
e,
String::from_utf8_lossy(&body)
);
return Action::Continue;
}
};
for event_result in streaming_events {
match event_result {
Ok(event) => {
if let Some(usage) = event.usage() {
self.response_tokens += usage.completion_tokens();
}
}
Err(e) => {
warn!("error in response event: {}", e);
continue;
}
}
}
// Compute TTFT if not already recorded
// For now, just compute TTFT and continue
if self.ttft_duration.is_none() {
// if let Some(start_time) = self.start_time {
let current_time = get_current_time().unwrap();
self.ttft_time = Some(current_time_ns());
match current_time.duration_since(self.start_time) {
@ -584,9 +557,9 @@ impl HttpContext for StreamContext {
}
} else {
debug!("non streaming response");
let provider_instance = self.get_provider_instance();
let response = match provider_instance.parse_response(&body, &hermes_llm_provider) {
Ok(de) => de,
let provider = self.get_provider();
let _response = match provider.parse_response(&body, ConversionMode::Compatible) {
Ok(response_box) => response_box,
Err(e) => {
warn!(
"could not parse response: {}, body str: {}",
@ -606,9 +579,9 @@ impl HttpContext for StreamContext {
}
};
if let Some(usage) = response.usage() {
self.response_tokens += usage.completion_tokens();
}
// TODO: Extract usage information from the response box
// For now, we'll skip this until we have a better way to handle Any types
warn!("Response token counting not yet implemented with new provider structure");
}
debug!(