mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
updated the stream_context to update response bytes
This commit is contained in:
parent
9f6d2464f6
commit
e7238fb7fd
4 changed files with 314 additions and 234 deletions
|
|
@ -22,10 +22,10 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
|||
.boxed()
|
||||
}
|
||||
|
||||
pub async fn chat_completions(
|
||||
pub async fn chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
llm_provider_endpoint: String,
|
||||
full_qualified_llm_provider_url: String,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_path = request.uri().path().to_string();
|
||||
let mut request_headers = request.headers().clone();
|
||||
|
|
@ -152,7 +152,7 @@ pub async fn chat_completions(
|
|||
|
||||
debug!(
|
||||
"sending request to llm provider: {}, with model hint: {}",
|
||||
llm_provider_endpoint, model_name
|
||||
full_qualified_llm_provider_url, model_name
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
|
|
@ -174,7 +174,7 @@ pub async fn chat_completions(
|
|||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post(llm_provider_endpoint)
|
||||
.post(full_qualified_llm_provider_url)
|
||||
.headers(request_headers)
|
||||
.body(chat_request_parsed_bytes)
|
||||
.send()
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
use brightstaff::handlers::chat_completions::chat_completions;
|
||||
use brightstaff::handlers::chat_completions::chat;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::Configuration;
|
||||
use common::consts::CHAT_COMPLETIONS_PATH;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::server::conn::http1;
|
||||
|
|
@ -67,10 +68,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
&serde_json::to_string(arch_config.as_ref()).unwrap()
|
||||
);
|
||||
|
||||
let llm_provider_endpoint = env::var("LLM_PROVIDER_ENDPOINT")
|
||||
.unwrap_or_else(|_| "http://localhost:12001/v1/chat/completions".to_string());
|
||||
let llm_provider_url = env::var("LLM_PROVIDER_ENDPOINT")
|
||||
.unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
||||
info!("llm provider endpoint: {}", llm_provider_endpoint);
|
||||
info!("llm provider url: {}", llm_provider_url);
|
||||
info!("listening on http://{}", bind_address);
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
|
||||
|
|
@ -88,7 +89,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||
arch_config.llm_providers.clone(),
|
||||
llm_provider_endpoint.clone(),
|
||||
llm_provider_url.clone() + CHAT_COMPLETIONS_PATH,
|
||||
routing_model_name,
|
||||
routing_llm_provider,
|
||||
));
|
||||
|
|
@ -99,19 +100,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let io = TokioIo::new(stream);
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||
let llm_provider_endpoint = llm_provider_endpoint.clone();
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let service = service_fn(move |req| {
|
||||
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let parent_cx = extract_context_from_request(&req);
|
||||
let llm_provider_endpoint = llm_provider_endpoint.clone();
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
let llm_providers = llm_providers.clone();
|
||||
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::POST, "/v1/chat/completions") => {
|
||||
chat_completions(req, router_service, llm_provider_endpoint)
|
||||
(&Method::POST, "/v1/chat/completions" | "/v1/messages") => {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
|
||||
chat(req, router_service, fully_qualified_url)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
use crate::providers::id::ProviderId;
|
||||
|
||||
use serde::Serialize;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
|
|
|||
|
|
@ -1,3 +1,14 @@
|
|||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use std::collections::VecDeque;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
use common::configuration::{LlmProvider, LlmProviderType, Overrides};
|
||||
use common::consts::{
|
||||
|
|
@ -13,16 +24,6 @@ use common::{ratelimit, routing, tokenizer};
|
|||
use hermesllm::clients::endpoints::SupportedAPIs;
|
||||
use hermesllm::providers::response::{ProviderResponse, ProviderStreamResponseIter};
|
||||
use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType};
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use std::collections::VecDeque;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
pub struct StreamContext {
|
||||
context_id: u32,
|
||||
|
|
@ -30,6 +31,7 @@ pub struct StreamContext {
|
|||
ratelimit_selector: Option<Header>,
|
||||
streaming_response: bool,
|
||||
response_tokens: usize,
|
||||
/// The API that is requested by the client (before compatibility mapping)
|
||||
client_api: Option<SupportedAPIs>,
|
||||
/// The API that should be used for the upstream provider (after compatibility mapping)
|
||||
resolved_api: Option<SupportedAPIs>,
|
||||
|
|
@ -191,6 +193,270 @@ impl StreamContext {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// === Helper methods extracted from on_http_response_body (no behavior change) ===
|
||||
#[inline]
|
||||
fn record_ttft_if_needed(&mut self) {
|
||||
if self.ttft_duration.is_none() {
|
||||
let current_time = get_current_time().unwrap();
|
||||
self.ttft_time = Some(current_time_ns());
|
||||
match current_time.duration_since(self.start_time) {
|
||||
Ok(duration) => {
|
||||
let duration_ms = duration.as_millis();
|
||||
info!(
|
||||
"on_http_response_body: time to first token: {}ms",
|
||||
duration_ms
|
||||
);
|
||||
self.ttft_duration = Some(duration);
|
||||
self.metrics.time_to_first_token.record(duration_ms as u64);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("SystemTime error: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
fn handle_end_of_stream_metrics_and_traces(&mut self, current_time: SystemTime) {
|
||||
// All streaming responses end with bytes=0 and end_stream=true
|
||||
// Record the latency for the request
|
||||
match current_time.duration_since(self.start_time) {
|
||||
Ok(duration) => {
|
||||
// Convert the duration to milliseconds
|
||||
let duration_ms = duration.as_millis();
|
||||
info!("on_http_response_body: request latency: {}ms", duration_ms);
|
||||
// Record the latency to the latency histogram
|
||||
self.metrics.request_latency.record(duration_ms as u64);
|
||||
|
||||
if self.response_tokens > 0 {
|
||||
// Compute the time per output token
|
||||
let tpot = duration_ms as u64 / self.response_tokens as u64;
|
||||
|
||||
// Record the time per output token
|
||||
self.metrics.time_per_output_token.record(tpot);
|
||||
|
||||
debug!(
|
||||
"time per token: {}ms, tokens per second: {}",
|
||||
tpot,
|
||||
1000 / tpot
|
||||
);
|
||||
// Record the tokens per second
|
||||
self.metrics.tokens_per_second.record(1000 / tpot);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("SystemTime error: {:?}", e);
|
||||
}
|
||||
}
|
||||
// Record the output sequence length
|
||||
self.metrics
|
||||
.output_sequence_length
|
||||
.record(self.response_tokens as u64);
|
||||
|
||||
if let Some(traceparent) = self.traceparent.as_ref() {
|
||||
let current_time_ns = current_time_ns();
|
||||
|
||||
match Traceparent::try_from(traceparent.to_string()) {
|
||||
Err(e) => {
|
||||
warn!("traceparent header is invalid: {}", e);
|
||||
}
|
||||
Ok(traceparent) => {
|
||||
let mut trace_data = common::tracing::TraceData::new();
|
||||
let mut llm_span = Span::new(
|
||||
"egress_traffic".to_string(),
|
||||
Some(traceparent.trace_id),
|
||||
Some(traceparent.parent_id),
|
||||
self.request_body_sent_time.unwrap(),
|
||||
current_time_ns,
|
||||
);
|
||||
llm_span
|
||||
.add_attribute("model".to_string(), self.llm_provider().name.to_string());
|
||||
|
||||
if let Some(user_message) = &self.user_message {
|
||||
llm_span.add_attribute("user_message".to_string(), user_message.clone());
|
||||
}
|
||||
|
||||
if self.ttft_time.is_some() {
|
||||
llm_span.add_event(Event::new(
|
||||
"time_to_first_token".to_string(),
|
||||
self.ttft_time.unwrap(),
|
||||
));
|
||||
trace_data.add_span(llm_span);
|
||||
}
|
||||
|
||||
self.traces_queue.lock().unwrap().push_back(trace_data);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn read_response_body(&mut self, body_size: usize) -> Result<Vec<u8>, Action> {
|
||||
if self.streaming_response {
|
||||
let chunk_start = 0;
|
||||
let chunk_size = body_size;
|
||||
debug!(
|
||||
"on_http_response_body: streaming response reading, {}..{}",
|
||||
chunk_start, chunk_size
|
||||
);
|
||||
let streaming_chunk = match self.get_http_response_body(0, chunk_size) {
|
||||
Some(chunk) => chunk,
|
||||
None => {
|
||||
warn!(
|
||||
"response body empty, chunk_start: {}, chunk_size: {}",
|
||||
chunk_start, chunk_size
|
||||
);
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
};
|
||||
|
||||
if streaming_chunk.len() != chunk_size {
|
||||
warn!(
|
||||
"chunk size mismatch: read: {} != requested: {}",
|
||||
streaming_chunk.len(),
|
||||
chunk_size
|
||||
);
|
||||
}
|
||||
Ok(streaming_chunk)
|
||||
} else {
|
||||
if body_size == 0 {
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
debug!("non streaming response bytes read: 0:{}", body_size);
|
||||
match self.get_http_response_body(0, body_size) {
|
||||
Some(body) => Ok(body),
|
||||
None => {
|
||||
warn!("non streaming response body empty");
|
||||
Err(Action::Continue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn debug_log_body(&self, body: &[u8]) {
|
||||
if log::log_enabled!(log::Level::Debug) {
|
||||
debug!(
|
||||
"response data (converted to utf8): {}",
|
||||
String::from_utf8_lossy(body)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_streaming_response(
|
||||
&mut self,
|
||||
body: &[u8],
|
||||
supported_api: SupportedAPIs,
|
||||
provider_id: ProviderId,
|
||||
) -> Result<Vec<u8>, Action> {
|
||||
debug!("processing streaming response");
|
||||
match (Some(supported_api), self.resolved_api.as_ref()) {
|
||||
(Some(supported_api), Some(_)) => {
|
||||
match ProviderStreamResponseIter::try_from((body, &supported_api, &provider_id)) {
|
||||
Ok(mut streaming_response) => {
|
||||
while let Some(chunk_result) = streaming_response.next() {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
self.record_ttft_if_needed();
|
||||
|
||||
if chunk.is_final() {
|
||||
debug!("Received final streaming chunk");
|
||||
}
|
||||
if let Some(content) = chunk.content_delta() {
|
||||
let estimated_tokens = content.len() / 4;
|
||||
self.response_tokens += estimated_tokens.max(1);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error processing streaming chunk: {}", e);
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse streaming response: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("Missing supported_api or resolved_api for streaming response");
|
||||
}
|
||||
}
|
||||
// NOTE:
|
||||
// We currently pass-through the original SSE bytes for streaming responses.
|
||||
// Non-streaming responses are parsed into ProviderResponseType and re-serialized to
|
||||
// normalize the payload to the client API. Doing the same for streaming would require
|
||||
// a streaming serializer that emits normalized SSE events for the target client API.
|
||||
// That doesn't exist yet in hermesllm; implementing it is a follow-up.
|
||||
// TODO(salmanap): Add a normalized SSE serializer in hermesllm and use it here so both
|
||||
// streaming and non-streaming paths perform the same compatibility mapping.
|
||||
// Until then, we keep behavior unchanged and forward upstream SSE as-is.
|
||||
// For consistency of the method contract, still return Vec<u8>.
|
||||
Ok(body.to_vec())
|
||||
}
|
||||
|
||||
fn handle_non_streaming_response(
|
||||
&mut self,
|
||||
body: &[u8],
|
||||
supported_api: SupportedAPIs,
|
||||
provider_id: ProviderId,
|
||||
) -> Result<Vec<u8>, Action> {
|
||||
debug!("non streaming response");
|
||||
|
||||
let response: ProviderResponseType =
|
||||
match (Some(&supported_api), self.resolved_api.as_ref()) {
|
||||
(Some(supported_api), Some(_)) => {
|
||||
match ProviderResponseType::try_from((body, supported_api, &provider_id)) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(body)
|
||||
);
|
||||
debug!(
|
||||
"on_http_response_body: S[{}], response body: {}",
|
||||
self.context_id,
|
||||
String::from_utf8_lossy(body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("Missing supported_api or resolved_api for non-streaming response");
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
};
|
||||
|
||||
// Use provider interface to extract usage information
|
||||
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
|
||||
response.extract_usage_counts()
|
||||
{
|
||||
debug!(
|
||||
"Response usage: prompt={}, completion={}, total={}",
|
||||
prompt_tokens, completion_tokens, total_tokens
|
||||
);
|
||||
self.response_tokens = completion_tokens;
|
||||
} else {
|
||||
warn!("No usage information found in response");
|
||||
}
|
||||
|
||||
// Serialize the normalized response back to JSON bytes
|
||||
match serde_json::to_vec(&response) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize normalized response: {}", e);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Response serialization error: {}", e)),
|
||||
Some(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
);
|
||||
Err(Action::Continue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
|
||||
|
|
@ -457,232 +723,44 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let current_time = get_current_time().unwrap();
|
||||
if end_of_stream && body_size == 0 {
|
||||
// All streaming responses end with bytes=0 and end_stream=true
|
||||
// Record the latency for the request
|
||||
match current_time.duration_since(self.start_time) {
|
||||
Ok(duration) => {
|
||||
// Convert the duration to milliseconds
|
||||
let duration_ms = duration.as_millis();
|
||||
info!("on_http_response_body: request latency: {}ms", duration_ms);
|
||||
// Record the latency to the latency histogram
|
||||
self.metrics.request_latency.record(duration_ms as u64);
|
||||
|
||||
if self.response_tokens > 0 {
|
||||
// Compute the time per output token
|
||||
let tpot = duration_ms as u64 / self.response_tokens as u64;
|
||||
|
||||
// Record the time per output token
|
||||
self.metrics.time_per_output_token.record(tpot);
|
||||
|
||||
debug!(
|
||||
"time per token: {}ms, tokens per second: {}",
|
||||
tpot,
|
||||
1000 / tpot
|
||||
);
|
||||
// Record the tokens per second
|
||||
self.metrics.tokens_per_second.record(1000 / tpot);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("SystemTime error: {:?}", e);
|
||||
}
|
||||
}
|
||||
// Record the output sequence length
|
||||
self.metrics
|
||||
.output_sequence_length
|
||||
.record(self.response_tokens as u64);
|
||||
|
||||
if let Some(traceparent) = self.traceparent.as_ref() {
|
||||
let current_time_ns = current_time_ns();
|
||||
|
||||
match Traceparent::try_from(traceparent.to_string()) {
|
||||
Err(e) => {
|
||||
warn!("traceparent header is invalid: {}", e);
|
||||
}
|
||||
Ok(traceparent) => {
|
||||
let mut trace_data = common::tracing::TraceData::new();
|
||||
let mut llm_span = Span::new(
|
||||
"egress_traffic".to_string(),
|
||||
Some(traceparent.trace_id),
|
||||
Some(traceparent.parent_id),
|
||||
self.request_body_sent_time.unwrap(),
|
||||
current_time_ns,
|
||||
);
|
||||
llm_span.add_attribute(
|
||||
"model".to_string(),
|
||||
self.llm_provider().name.to_string(),
|
||||
);
|
||||
|
||||
if let Some(user_message) = &self.user_message {
|
||||
llm_span
|
||||
.add_attribute("user_message".to_string(), user_message.clone());
|
||||
}
|
||||
|
||||
if self.ttft_time.is_some() {
|
||||
llm_span.add_event(Event::new(
|
||||
"time_to_first_token".to_string(),
|
||||
self.ttft_time.unwrap(),
|
||||
));
|
||||
trace_data.add_span(llm_span);
|
||||
}
|
||||
|
||||
self.traces_queue.lock().unwrap().push_back(trace_data);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
self.handle_end_of_stream_metrics_and_traces(current_time);
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
let body = if self.streaming_response {
|
||||
let chunk_start = 0;
|
||||
let chunk_size = body_size;
|
||||
debug!(
|
||||
"on_http_response_body: streaming response reading, {}..{}",
|
||||
chunk_start, chunk_size
|
||||
);
|
||||
let streaming_chunk = match self.get_http_response_body(0, chunk_size) {
|
||||
Some(chunk) => chunk,
|
||||
None => {
|
||||
warn!(
|
||||
"response body empty, chunk_start: {}, chunk_size: {}",
|
||||
chunk_start, chunk_size
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
if streaming_chunk.len() != chunk_size {
|
||||
warn!(
|
||||
"chunk size mismatch: read: {} != requested: {}",
|
||||
streaming_chunk.len(),
|
||||
chunk_size
|
||||
);
|
||||
}
|
||||
streaming_chunk
|
||||
} else {
|
||||
if body_size == 0 {
|
||||
return Action::Continue;
|
||||
}
|
||||
debug!("non streaming response bytes read: 0:{}", body_size);
|
||||
match self.get_http_response_body(0, body_size) {
|
||||
Some(body) => body,
|
||||
None => {
|
||||
warn!("non streaming response body empty");
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
let body = match self.read_response_body(body_size) {
|
||||
Ok(b) => b,
|
||||
Err(action) => return action,
|
||||
};
|
||||
|
||||
if log::log_enabled!(log::Level::Debug) {
|
||||
debug!(
|
||||
"response data (converted to utf8): {}",
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
}
|
||||
self.debug_log_body(&body);
|
||||
|
||||
let provider_id = self.get_provider_id();
|
||||
let supported_api = self.client_api.as_ref();
|
||||
let supported_api_opt = self.client_api.clone();
|
||||
|
||||
if self.streaming_response {
|
||||
debug!("processing streaming response");
|
||||
match (supported_api, self.resolved_api.as_ref()) {
|
||||
(Some(supported_api), Some(_)) => {
|
||||
match ProviderStreamResponseIter::try_from((
|
||||
&body[..],
|
||||
supported_api,
|
||||
&provider_id,
|
||||
)) {
|
||||
Ok(mut streaming_response) => {
|
||||
while let Some(chunk_result) = streaming_response.next() {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if self.ttft_duration.is_none() {
|
||||
let current_time = get_current_time().unwrap();
|
||||
self.ttft_time = Some(current_time_ns());
|
||||
match current_time.duration_since(self.start_time) {
|
||||
Ok(duration) => {
|
||||
let duration_ms = duration.as_millis();
|
||||
info!("on_http_response_body: time to first token: {}ms", duration_ms);
|
||||
self.ttft_duration = Some(duration);
|
||||
self.metrics
|
||||
.time_to_first_token
|
||||
.record(duration_ms as u64);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("SystemTime error: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
if chunk.is_final() {
|
||||
debug!("Received final streaming chunk");
|
||||
}
|
||||
if let Some(content) = chunk.content_delta() {
|
||||
let estimated_tokens = content.len() / 4;
|
||||
self.response_tokens += estimated_tokens.max(1);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error processing streaming chunk: {}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse streaming response: {}", e);
|
||||
}
|
||||
if let Some(supported_api) = supported_api_opt {
|
||||
match self.handle_streaming_response(&body, supported_api, provider_id) {
|
||||
Ok(serialized_body) => {
|
||||
// Write the normalized body back to the wire using the original body_size
|
||||
self.set_http_response_body(0, body_size, &serialized_body);
|
||||
}
|
||||
Err(action) => return action,
|
||||
}
|
||||
_ => {
|
||||
warn!("Missing supported_api or resolved_api for streaming response");
|
||||
}
|
||||
} else {
|
||||
warn!("Missing supported_api or resolved_api for streaming response");
|
||||
}
|
||||
} else {
|
||||
debug!("non streaming response");
|
||||
let provider_id = self.get_provider_id();
|
||||
let supported_api = self.client_api.as_ref();
|
||||
|
||||
let response: ProviderResponseType = match (supported_api, self.resolved_api.as_ref()) {
|
||||
(Some(supported_api), Some(_)) => {
|
||||
match ProviderResponseType::try_from((&body[..], supported_api, &provider_id)) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
debug!(
|
||||
"on_http_response_body: S[{}], response body: {}",
|
||||
self.context_id,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
if let Some(supported_api) = supported_api_opt {
|
||||
match self.handle_non_streaming_response(&body, supported_api, provider_id) {
|
||||
Ok(serialized_body) => {
|
||||
// Write the normalized body back to the wire using the original body_size
|
||||
self.set_http_response_body(0, body_size, &serialized_body);
|
||||
}
|
||||
Err(action) => return action,
|
||||
}
|
||||
_ => {
|
||||
warn!("Missing supported_api or resolved_api for non-streaming response");
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Use provider interface to extract usage information
|
||||
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
|
||||
response.extract_usage_counts()
|
||||
{
|
||||
debug!(
|
||||
"Response usage: prompt={}, completion={}, total={}",
|
||||
prompt_tokens, completion_tokens, total_tokens
|
||||
);
|
||||
self.response_tokens = completion_tokens;
|
||||
} else {
|
||||
warn!("No usage information found in response");
|
||||
warn!("Missing supported_api or resolved_api for non-streaming response");
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue