mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
transformations are working. Now need to add some tests next
This commit is contained in:
parent
0aa9243093
commit
e73a9eb61c
6 changed files with 182 additions and 207 deletions
|
|
@ -31,14 +31,6 @@ pub enum SupportedApi {
|
|||
}
|
||||
|
||||
impl SupportedApi {
|
||||
/// Determine if a request/response conversion is required for the given model string
|
||||
pub fn requires_conversion_for_model(&self, model: &str) -> bool {
|
||||
use crate::providers::adapters::is_claude_family;
|
||||
match self {
|
||||
SupportedApi::Anthropic(AnthropicApi::Messages) => !is_claude_family(model),
|
||||
SupportedApi::OpenAI(OpenAIApi::ChatCompletions) => is_claude_family(model),
|
||||
}
|
||||
}
|
||||
/// Create a SupportedApi from an endpoint path
|
||||
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
|
||||
|
|
@ -60,14 +52,6 @@ impl SupportedApi {
|
|||
}
|
||||
}
|
||||
|
||||
/// Get the API family name
|
||||
pub fn api_family(&self) -> &'static str {
|
||||
match self {
|
||||
SupportedApi::OpenAI(_) => "openai",
|
||||
SupportedApi::Anthropic(_) => "anthropic",
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine the target endpoint for a given provider
|
||||
/// For /v1/messages: if provider is Anthropic, use /v1/messages; otherwise use /v1/chat/completions
|
||||
pub fn target_endpoint_for_provider(&self, provider: &str) -> &'static str {
|
||||
|
|
@ -83,23 +67,6 @@ impl SupportedApi {
|
|||
_ => self.endpoint()
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if request conversion is required for the given provider
|
||||
/// True if we need to convert between Anthropic and OpenAI formats
|
||||
pub fn requires_conversion(&self, provider: &str) -> bool {
|
||||
match self {
|
||||
SupportedApi::Anthropic(AnthropicApi::Messages) => {
|
||||
// If provider is not Anthropic/Claude, we need to convert to OpenAI format
|
||||
!(provider.to_lowercase().contains("anthropic") ||
|
||||
provider.to_lowercase().contains("claude"))
|
||||
}
|
||||
SupportedApi::OpenAI(OpenAIApi::ChatCompletions) => {
|
||||
// If provider is Anthropic/Claude but request is OpenAI format, need conversion
|
||||
provider.to_lowercase().contains("anthropic") ||
|
||||
provider.to_lowercase().contains("claude")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -66,28 +66,30 @@ mod tests {
|
|||
#[test]
|
||||
fn test_provider_streaming_response() {
|
||||
// Test streaming response parsing with sample SSE data
|
||||
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
|
||||
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let result = ProviderStreamResponseIter::try_from((sse_data.as_bytes(), &ProviderId::OpenAI));
|
||||
assert!(result.is_ok());
|
||||
use crate::clients::endpoints::SupportedApi;
|
||||
let api = SupportedApi::OpenAI(crate::apis::OpenAIApi::ChatCompletions);
|
||||
let result = ProviderStreamResponseIter::try_from((sse_data.as_bytes(), &api, &ProviderId::OpenAI));
|
||||
assert!(result.is_ok());
|
||||
|
||||
let mut streaming_response = result.unwrap();
|
||||
let mut streaming_response = result.unwrap();
|
||||
|
||||
// Test that we can iterate over chunks - it's just an iterator now!
|
||||
let first_chunk = streaming_response.next();
|
||||
assert!(first_chunk.is_some());
|
||||
// Test that we can iterate over chunks - it's just an iterator now!
|
||||
let first_chunk = streaming_response.next();
|
||||
assert!(first_chunk.is_some());
|
||||
|
||||
let chunk_result = first_chunk.unwrap();
|
||||
assert!(chunk_result.is_ok());
|
||||
let chunk_result = first_chunk.unwrap();
|
||||
assert!(chunk_result.is_ok());
|
||||
|
||||
let chunk = chunk_result.unwrap();
|
||||
assert_eq!(chunk.content_delta(), Some("Hello"));
|
||||
assert!(!chunk.is_final());
|
||||
let chunk = chunk_result.unwrap();
|
||||
assert_eq!(chunk.content_delta(), Some("Hello"));
|
||||
assert!(!chunk.is_final());
|
||||
|
||||
// Test that stream ends properly
|
||||
// Test that stream ends properly
|
||||
let final_chunk = streaming_response.next();
|
||||
assert!(final_chunk.is_none());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
use std::fmt::Display;
|
||||
use crate::clients::endpoints::SupportedApi;
|
||||
use crate::apis::{OpenAIApi, AnthropicApi};
|
||||
|
||||
/// Provider identifier enum - simple enum for identifying providers
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
|
|
@ -29,6 +31,22 @@ impl From<&str> for ProviderId {
|
|||
}
|
||||
}
|
||||
|
||||
impl ProviderId {
|
||||
/// Given a client API, return the compatible upstream API for this provider
|
||||
pub fn compatible_api_for_client(&self, client_api: &SupportedApi) -> SupportedApi {
|
||||
match (self, client_api) {
|
||||
// Claude/Anthropic providers natively support Anthropic APIs
|
||||
(ProviderId::Claude, SupportedApi::Anthropic(_)) => client_api.clone(),
|
||||
// Claude/Anthropic providers can also support OpenAI chat completions by mapping to Anthropic Messages
|
||||
(ProviderId::Claude, SupportedApi::OpenAI(OpenAIApi::ChatCompletions)) => SupportedApi::Anthropic(AnthropicApi::Messages),
|
||||
|
||||
// OpenAI-compatible providers only support OpenAI chat completions
|
||||
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedApi::Anthropic(_)) => SupportedApi::OpenAI(OpenAIApi::ChatCompletions),
|
||||
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedApi::OpenAI(_)) => SupportedApi::OpenAI(OpenAIApi::ChatCompletions),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ProviderId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::apis::anthropic::MessagesRequest;
|
||||
use crate::clients::endpoints::SupportedApi;
|
||||
use super::{ProviderId, get_provider_config, AdapterType};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
pub enum ProviderRequestType {
|
||||
|
|
@ -22,53 +21,23 @@ impl TryFrom<&[u8]> for ProviderRequestType {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<(&[u8], &ProviderId)> for ProviderRequestType {
|
||||
/// Parse request based on endpoint and provider information
|
||||
impl TryFrom<(&[u8], &SupportedApi)> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
|
||||
let config = get_provider_config(provider_id);
|
||||
match config.adapter_type {
|
||||
AdapterType::OpenAICompatible => {
|
||||
fn try_from((bytes, endpoint): (&[u8], &SupportedApi)) -> Result<Self, Self::Error> {
|
||||
// Use SupportedApi to determine the appropriate request type
|
||||
match endpoint {
|
||||
SupportedApi::OpenAI(_) => {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
AdapterType::AnthropicCompatible => {
|
||||
// For Anthropic providers, try to parse as MessagesRequest first, fallback to ChatCompletionsRequest
|
||||
if let Ok(messages_request) = MessagesRequest::try_from(bytes) {
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
||||
} else {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse request based on endpoint and provider information
|
||||
impl TryFrom<(&[u8], &str, &ProviderId)> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, endpoint, provider_id): (&[u8], &str, &ProviderId)) -> Result<Self, Self::Error> {
|
||||
// Use SupportedApi to determine the appropriate request type
|
||||
if let Some(api) = SupportedApi::from_endpoint(endpoint) {
|
||||
match api {
|
||||
SupportedApi::OpenAI(_) => {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
SupportedApi::Anthropic(_) => {
|
||||
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback to provider-based parsing for unsupported endpoints
|
||||
Self::try_from((bytes, provider_id))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,15 @@
|
|||
use crate::providers::id::ProviderId;
|
||||
|
||||
use serde::Serialize;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
use crate::apis::OpenAISseIter;
|
||||
use crate::providers::id::ProviderId;
|
||||
use crate::providers::adapters::{get_provider_config, AdapterType};
|
||||
use crate::clients::endpoints::SupportedApi;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub enum ProviderResponseType {
|
||||
ChatCompletionsResponse(ChatCompletionsResponse),
|
||||
//MessagesResponse(MessagesResponse),
|
||||
|
|
@ -16,51 +20,50 @@ pub enum ProviderStreamResponseIter {
|
|||
//MessagesStream(AnthropicSseIter<std::vec::IntoIter<String>>),
|
||||
}
|
||||
|
||||
impl TryFrom<(&[u8], ProviderId)> for ProviderResponseType {
|
||||
|
||||
// --- Response transformation logic for client API compatibility ---
|
||||
impl TryFrom<(&[u8], &SupportedApi, &ProviderId)> for ProviderResponseType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, provider_id): (&[u8], ProviderId)) -> Result<Self, Self::Error> {
|
||||
let config = get_provider_config(&provider_id);
|
||||
match config.adapter_type {
|
||||
AdapterType::OpenAICompatible => {
|
||||
let chat_completions_response: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedApi, &ProviderId)) -> Result<Self, Self::Error> {
|
||||
let upstream_api = provider_id.compatible_api_for_client(client_api);
|
||||
match (&upstream_api, client_api) {
|
||||
(SupportedApi::OpenAI(_), SupportedApi::OpenAI(_)) => {
|
||||
let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(chat_completions_response))
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(resp))
|
||||
}
|
||||
AdapterType::AnthropicCompatible => {
|
||||
// TODO: Implement MessagesResponse parsing for Anthropic-compatible providers
|
||||
todo!("AnthropicCompatible response parsing not yet implemented");
|
||||
(SupportedApi::OpenAI(_), SupportedApi::Anthropic(_)) => {
|
||||
// If you add a MessagesResponse variant, return it here. For now, just error or serialize as needed.
|
||||
Err(std::io::Error::new(std::io::ErrorKind::Other, "Anthropic response variant not implemented"))
|
||||
}
|
||||
_ => Err(std::io::Error::new(std::io::ErrorKind::Other, "Unsupported response transformation")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<(&[u8], &ProviderId)> for ProviderStreamResponseIter {
|
||||
impl TryFrom<(&[u8], &SupportedApi, &ProviderId)> for ProviderStreamResponseIter {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
|
||||
let config = get_provider_config(provider_id);
|
||||
|
||||
// Parse SSE (Server-Sent Events) streaming data - protocol layer
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||
|
||||
match config.adapter_type {
|
||||
AdapterType::OpenAICompatible => {
|
||||
// Delegate to OpenAI-specific iterator implementation
|
||||
let sse_container = SseStreamIter::new(lines.into_iter());
|
||||
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedApi, &ProviderId)) -> Result<Self, Self::Error> {
|
||||
let upstream_api = provider_id.compatible_api_for_client(client_api);
|
||||
match (&upstream_api, client_api) {
|
||||
(SupportedApi::OpenAI(_), SupportedApi::OpenAI(_)) => {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||
let sse_container = crate::providers::response::SseStreamIter::new(lines.into_iter());
|
||||
let iter = crate::apis::openai::OpenAISseIter::new(sse_container);
|
||||
Ok(ProviderStreamResponseIter::ChatCompletionsStream(iter))
|
||||
}
|
||||
AdapterType::AnthropicCompatible => {
|
||||
// TODO: Implement Anthropic streaming support
|
||||
todo!("AnthropicCompatible streaming not yet implemented");
|
||||
(SupportedApi::OpenAI(_), SupportedApi::Anthropic(_)) => {
|
||||
// TODO: Implement streaming transformation from OpenAI to Anthropic
|
||||
Err("Anthropic streaming response variant not implemented".into())
|
||||
}
|
||||
_ => Err("Unsupported streaming response transformation".into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl Iterator for ProviderStreamResponseIter {
|
||||
type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
|
|
|
|||
|
|
@ -12,9 +12,7 @@ use common::tracing::{Event, Span, TraceData, Traceparent};
|
|||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::clients::endpoints::SupportedApi;
|
||||
use hermesllm::providers::response::ProviderStreamResponseIter;
|
||||
use hermesllm::{
|
||||
ProviderId, ProviderRequest, ProviderRequestType, ProviderResponse, ProviderResponseType,
|
||||
};
|
||||
use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType};
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
|
|
@ -33,6 +31,8 @@ pub struct StreamContext {
|
|||
streaming_response: bool,
|
||||
response_tokens: usize,
|
||||
supported_api: Option<SupportedApi>,
|
||||
/// The API that should be used for the upstream provider (after compatibility mapping)
|
||||
resolved_api: Option<SupportedApi>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
llm_provider: Option<Rc<LlmProvider>>,
|
||||
request_id: Option<String>,
|
||||
|
|
@ -62,6 +62,7 @@ impl StreamContext {
|
|||
streaming_response: false,
|
||||
response_tokens: 0,
|
||||
supported_api: None,
|
||||
resolved_api: None,
|
||||
llm_providers,
|
||||
llm_provider: None,
|
||||
request_id: None,
|
||||
|
|
@ -223,6 +224,16 @@ impl HttpContext for StreamContext {
|
|||
let supported_api = SupportedApi::from_endpoint(&request_path);
|
||||
self.supported_api = supported_api;
|
||||
|
||||
// Determine the resolved (upstream) API using provider compatibility
|
||||
if let (Some(api), Some(provider)) =
|
||||
(self.supported_api.as_ref(), self.llm_provider.as_ref())
|
||||
{
|
||||
let provider_id = provider.to_provider_id();
|
||||
self.resolved_api = Some(provider_id.compatible_api_for_client(api));
|
||||
} else {
|
||||
self.resolved_api = None;
|
||||
}
|
||||
|
||||
let use_agent_orchestrator = match self.overrides.as_ref() {
|
||||
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
|
||||
None => false,
|
||||
|
|
@ -340,22 +351,26 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let provider_id = self.get_provider_id();
|
||||
let request_path = self.get_http_request_header(":path").unwrap_or_default();
|
||||
|
||||
let mut deserialized_body = match ProviderRequestType::try_from((
|
||||
&body_bytes[..],
|
||||
request_path.as_str(),
|
||||
&provider_id,
|
||||
)) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"on_http_request_body: request body: {}",
|
||||
String::from_utf8_lossy(&body_bytes)
|
||||
);
|
||||
let mut deserialized_body = match self.resolved_api.as_ref() {
|
||||
Some(resolved_api) => {
|
||||
match ProviderRequestType::try_from((&body_bytes[..], resolved_api)) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"on_http_request_body: request body: {}",
|
||||
String::from_utf8_lossy(&body_bytes)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Request parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Request parsing error: {}", e)),
|
||||
ServerError::LogicError("No resolved API for provider".to_string()),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
|
|
@ -603,99 +618,100 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
}
|
||||
|
||||
let provider_id = self.get_provider_id();
|
||||
let supported_api = self.supported_api.as_ref();
|
||||
|
||||
if self.streaming_response {
|
||||
debug!("processing streaming response");
|
||||
match ProviderStreamResponseIter::try_from((&body[..], &self.get_provider_id())) {
|
||||
Ok(mut streaming_response) => {
|
||||
// Process each streaming chunk
|
||||
while let Some(chunk_result) = streaming_response.next() {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
// Compute TTFT on first chunk
|
||||
if self.ttft_duration.is_none() {
|
||||
let current_time = get_current_time().unwrap();
|
||||
self.ttft_time = Some(current_time_ns());
|
||||
match current_time.duration_since(self.start_time) {
|
||||
Ok(duration) => {
|
||||
let duration_ms = duration.as_millis();
|
||||
info!(
|
||||
"on_http_response_body: time to first token: {}ms",
|
||||
duration_ms
|
||||
);
|
||||
self.ttft_duration = Some(duration);
|
||||
self.metrics
|
||||
.time_to_first_token
|
||||
.record(duration_ms as u64);
|
||||
match (supported_api, self.resolved_api.as_ref()) {
|
||||
(Some(supported_api), Some(_)) => {
|
||||
match ProviderStreamResponseIter::try_from((
|
||||
&body[..],
|
||||
supported_api,
|
||||
&provider_id,
|
||||
)) {
|
||||
Ok(mut streaming_response) => {
|
||||
while let Some(chunk_result) = streaming_response.next() {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if self.ttft_duration.is_none() {
|
||||
let current_time = get_current_time().unwrap();
|
||||
self.ttft_time = Some(current_time_ns());
|
||||
match current_time.duration_since(self.start_time) {
|
||||
Ok(duration) => {
|
||||
let duration_ms = duration.as_millis();
|
||||
info!("on_http_response_body: time to first token: {}ms", duration_ms);
|
||||
self.ttft_duration = Some(duration);
|
||||
self.metrics
|
||||
.time_to_first_token
|
||||
.record(duration_ms as u64);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("SystemTime error: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("SystemTime error: {:?}", e);
|
||||
if chunk.is_final() {
|
||||
debug!("Received final streaming chunk");
|
||||
}
|
||||
if let Some(content) = chunk.content_delta() {
|
||||
let estimated_tokens = content.len() / 4;
|
||||
self.response_tokens += estimated_tokens.max(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For streaming responses, we handle token counting differently
|
||||
// The ProviderStreamResponse trait provides content_delta, is_final, and role
|
||||
// Token counting for streaming responses typically happens with final usage chunk
|
||||
if chunk.is_final() {
|
||||
// For now, we'll implement basic token estimation
|
||||
// In a complete implementation, the final chunk would contain usage information
|
||||
debug!("Received final streaming chunk");
|
||||
}
|
||||
|
||||
// For now, estimate tokens from content delta
|
||||
if let Some(content) = chunk.content_delta() {
|
||||
// Rough estimation: ~4 characters per token
|
||||
let estimated_tokens = content.len() / 4;
|
||||
self.response_tokens += estimated_tokens.max(1);
|
||||
Err(e) => {
|
||||
warn!("Error processing streaming chunk: {}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error processing streaming chunk: {}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse streaming response: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse streaming response: {}", e);
|
||||
_ => {
|
||||
warn!("Missing supported_api or resolved_api for streaming response");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("non streaming response");
|
||||
let provider_id = self.get_provider_id();
|
||||
let response: ProviderResponseType =
|
||||
match ProviderResponseType::try_from((&body[..], provider_id)) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
debug!(
|
||||
"on_http_response_body: S[{}], response body: {}",
|
||||
self.context_id,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
match (supported_api, self.resolved_api.as_ref()) {
|
||||
(Some(supported_api), Some(_)) => {
|
||||
match ProviderResponseType::try_from((&body[..], supported_api, &provider_id)) {
|
||||
Ok(response) => match serde_json::to_vec(&response) {
|
||||
Ok(bytes) => {
|
||||
self.set_http_response_body(0, bytes.len(), &bytes);
|
||||
}
|
||||
Err(e) => {
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!(
|
||||
"Response serialization error: {}",
|
||||
e
|
||||
)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Use provider interface to extract usage information
|
||||
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
|
||||
response.extract_usage_counts()
|
||||
{
|
||||
debug!(
|
||||
"Response usage: prompt={}, completion={}, total={}",
|
||||
prompt_tokens, completion_tokens, total_tokens
|
||||
);
|
||||
self.response_tokens = completion_tokens;
|
||||
} else {
|
||||
warn!("No usage information found in response");
|
||||
}
|
||||
_ => {
|
||||
warn!("Missing supported_api or resolved_api for non-streaming response");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue