mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
updating the implementation of /v1/chat/completions to use the generic provider interfaces
This commit is contained in:
parent
93ff4d7b1f
commit
203fc8f9a9
8 changed files with 441 additions and 89 deletions
|
|
@ -177,6 +177,18 @@ impl Display for LlmProviderType {
|
|||
}
|
||||
}
|
||||
|
||||
impl LlmProviderType {
|
||||
/// Create a ProviderInstance from this LlmProviderType
|
||||
/// This is the main method for stream_context to get provider-specific interfaces
|
||||
pub fn create_provider_instance(&self) -> hermesllm::ProviderInstance {
|
||||
use hermesllm::ProviderInstance;
|
||||
|
||||
// For now, all providers use OpenAI-compatible APIs
|
||||
// TODO: Return specific provider instances when implementing different APIs
|
||||
ProviderInstance::from_name(&self.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ModelUsagePreference {
|
||||
pub model: String,
|
||||
|
|
@ -252,6 +264,14 @@ impl Display for LlmProvider {
|
|||
}
|
||||
}
|
||||
|
||||
impl LlmProvider {
|
||||
/// Create a ProviderInstance for this LlmProvider
|
||||
/// This is a convenience method that delegates to the provider_interface
|
||||
pub fn create_provider_instance(&self) -> hermesllm::ProviderInstance {
|
||||
self.provider_interface.create_provider_instance()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Endpoint {
|
||||
pub endpoint: Option<String>,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,11 @@ pub mod providers;
|
|||
pub mod apis;
|
||||
pub mod clients;
|
||||
|
||||
// Re-export important traits
|
||||
pub use providers::traits::*;
|
||||
pub use providers::openai::provider::OpenAIProvider;
|
||||
pub use providers::provider_enum::ProviderInstance;
|
||||
|
||||
|
||||
use std::fmt::Display;
|
||||
pub enum Provider {
|
||||
|
|
@ -34,6 +39,45 @@ impl From<&str> for Provider {
|
|||
}
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
/// Get the API endpoint path for this provider
|
||||
pub fn api_path(&self) -> &'static str {
|
||||
match self {
|
||||
Provider::OpenAI => "/v1/chat/completions",
|
||||
Provider::Groq => "/openai/v1/chat/completions", // Groq maps to OpenAI-compatible endpoint
|
||||
Provider::Gemini => "/v1/models", // TODO: Update with correct Gemini path
|
||||
Provider::Claude => "/v1/messages", // TODO: Update with correct Claude path
|
||||
Provider::Mistral => "/v1/chat/completions", // Mistral uses OpenAI-compatible API
|
||||
Provider::Deepseek => "/v1/chat/completions", // DeepSeek uses OpenAI-compatible API
|
||||
Provider::Arch => "/v1/chat/completions", // Arch gateway endpoint
|
||||
Provider::Github => "/models", // TODO: Update with correct GitHub models path
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this provider uses OpenAI-compatible API format
|
||||
pub fn uses_openai_format(&self) -> bool {
|
||||
match self {
|
||||
Provider::OpenAI | Provider::Groq | Provider::Mistral | Provider::Deepseek | Provider::Arch => true,
|
||||
Provider::Gemini | Provider::Claude | Provider::Github => false, // These have their own formats
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a provider implementation instance for this provider
|
||||
pub fn create_provider_instance(&self) -> ProviderInstance {
|
||||
match self {
|
||||
Provider::OpenAI => ProviderInstance::OpenAI(OpenAIProvider),
|
||||
Provider::Groq => ProviderInstance::OpenAI(OpenAIProvider), // Groq uses OpenAI-compatible API
|
||||
Provider::Mistral => ProviderInstance::OpenAI(OpenAIProvider), // Mistral uses OpenAI-compatible API
|
||||
Provider::Deepseek => ProviderInstance::OpenAI(OpenAIProvider), // Deepseek uses OpenAI-compatible API
|
||||
Provider::Arch => ProviderInstance::OpenAI(OpenAIProvider), // Arch gateway uses OpenAI-compatible API
|
||||
// TODO: Implement specific providers for these when they have different APIs
|
||||
Provider::Gemini => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible
|
||||
Provider::Claude => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible
|
||||
Provider::Github => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Provider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
|
|
|
|||
|
|
@ -1 +1,3 @@
|
|||
pub mod openai;
|
||||
pub mod traits;
|
||||
pub mod provider_enum;
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
pub mod builder;
|
||||
pub mod types;
|
||||
pub mod provider;
|
||||
|
|
|
|||
171
crates/hermesllm/src/providers/openai/provider.rs
Normal file
171
crates/hermesllm/src/providers/openai/provider.rs
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
//! OpenAI provider interface implementations
|
||||
|
||||
use crate::apis::openai::*;
|
||||
use crate::providers::traits::*;
|
||||
use crate::Provider;
|
||||
|
||||
// Simple error type for OpenAI API operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OpenAIApiError {
|
||||
#[error("JSON parsing error: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
#[error("UTF-8 parsing error: {0}")]
|
||||
Utf8Error(#[from] std::str::Utf8Error),
|
||||
#[error("Invalid streaming data: {0}")]
|
||||
InvalidStreamingData(String),
|
||||
#[error("Request conversion error: {0}")]
|
||||
RequestConversionError(String),
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// OpenAI Provider Definition
|
||||
// ============================================================================
|
||||
|
||||
/// Marker type selecting the OpenAI request/response/streaming types through
/// its `ProviderInterface` implementation.
///
/// It carries no data, so it is cheap to copy, compare and print.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct OpenAIProvider;
|
||||
|
||||
// Create a concrete streaming response type to avoid lifetime issues
|
||||
/// Owned, line-oriented view over a streaming response body.
///
/// The body is stored as owned lines so the value carries no borrowed
/// lifetime; `current_index` is the position of the next line to inspect.
#[derive(Debug, Clone)]
pub struct OpenAIStreamingResponse {
    /// Lines of the streaming body, in order.
    lines: Vec<String>,
    /// Index into `lines` of the next line to be consumed.
    current_index: usize,
}
|
||||
|
||||
impl OpenAIStreamingResponse {
|
||||
fn new(data: String) -> Self {
|
||||
let lines: Vec<String> = data.lines().map(|s| s.to_string()).collect();
|
||||
Self {
|
||||
lines,
|
||||
current_index: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for OpenAIStreamingResponse {
|
||||
type Item = Result<ChatCompletionsStreamResponse, OpenAIApiError>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
while self.current_index < self.lines.len() {
|
||||
let line = &self.lines[self.current_index];
|
||||
self.current_index += 1;
|
||||
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
let data = data.trim();
|
||||
if data == "[DONE]" {
|
||||
return None;
|
||||
}
|
||||
|
||||
if data == r#"{"type": "ping"}"# {
|
||||
continue; // Skip ping messages
|
||||
}
|
||||
|
||||
return Some(
|
||||
serde_json::from_str::<ChatCompletionsStreamResponse>(data).map_err(|e| {
|
||||
OpenAIApiError::InvalidStreamingData(format!("Error parsing: {}, data: {}", e, data))
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for OpenAIProvider {
|
||||
type Request = ChatCompletionsRequest;
|
||||
type Response = ChatCompletionsResponse;
|
||||
type StreamingResponse = OpenAIStreamingResponse;
|
||||
type Usage = Usage;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Trait Implementations for OpenAI Types
|
||||
// ============================================================================
|
||||
|
||||
impl ProviderRequest for ChatCompletionsRequest {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(serde_json::from_str(s)?)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, _provider: Provider) -> Result<Vec<u8>, Self::Error> {
|
||||
Ok(serde_json::to_vec(self)?)
|
||||
}
|
||||
|
||||
fn extract_model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self) -> bool {
|
||||
self.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&mut self) {
|
||||
if self.stream_options.is_none() {
|
||||
self.stream_options = Some(StreamOptions {
|
||||
include_usage: Some(true),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
self.messages
|
||||
.iter()
|
||||
.fold(String::new(), |acc, m| {
|
||||
acc + " " + &match &m.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => {
|
||||
parts.iter().map(|part| match part {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
|
||||
}).collect::<Vec<_>>().join(" ")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Exposes the OpenAI usage counters through the generic `TokenUsage` trait.
///
/// Each counter is converted with an `as usize` cast; the field types are
/// declared outside this block.
impl TokenUsage for Usage {
    /// Tokens generated in the completion.
    fn completion_tokens(&self) -> usize {
        let count = self.completion_tokens;
        count as usize
    }

    /// Tokens consumed by the prompt.
    fn prompt_tokens(&self) -> usize {
        let count = self.prompt_tokens;
        count as usize
    }

    /// Total tokens reported for the exchange.
    fn total_tokens(&self) -> usize {
        let count = self.total_tokens;
        count as usize
    }
}
|
||||
|
||||
impl ProviderResponse for ChatCompletionsResponse {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(bytes: &[u8], _provider: &Provider) -> Result<Self, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(serde_json::from_str(s)?)
|
||||
}
|
||||
|
||||
fn usage(&self) -> Option<&Self::Usage> {
|
||||
Some(&self.usage)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamChunk for ChatCompletionsStreamResponse {
|
||||
type Usage = Usage;
|
||||
|
||||
fn usage(&self) -> Option<&Self::Usage> {
|
||||
self.usage.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for OpenAIStreamingResponse {
|
||||
type Error = OpenAIApiError;
|
||||
type Chunk = ChatCompletionsStreamResponse;
|
||||
|
||||
fn try_from_bytes(bytes: &[u8], _provider: &Provider) -> Result<Self, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(OpenAIStreamingResponse::new(s.to_string()))
|
||||
}
|
||||
}
|
||||
67
crates/hermesllm/src/providers/provider_enum.rs
Normal file
67
crates/hermesllm/src/providers/provider_enum.rs
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
use crate::providers::traits::*;
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
|
||||
/// Enum that wraps all possible providers for dynamic dispatch
|
||||
pub enum ProviderInstance {
|
||||
OpenAI(OpenAIProvider),
|
||||
// TODO: Add other providers as they are implemented
|
||||
// Anthropic(AnthropicProvider),
|
||||
// Mistral(MistralProvider),
|
||||
// etc.
|
||||
}
|
||||
|
||||
impl ProviderInstance {
|
||||
/// Creates a provider from a provider name string
|
||||
pub fn from_name(name: &str) -> Self {
|
||||
match name.to_lowercase().as_str() {
|
||||
"openai" | "groq" | "gemini" | "mistral" | "deepseek" | "arch" | "claude" => {
|
||||
ProviderInstance::OpenAI(OpenAIProvider)
|
||||
}
|
||||
// TODO: Add other providers when implemented
|
||||
// "claude" | "anthropic" => ProviderInstance::Anthropic(AnthropicProvider),
|
||||
// "mistral" => ProviderInstance::Mistral(MistralProvider),
|
||||
_ => {
|
||||
// Default to OpenAI for unknown providers
|
||||
ProviderInstance::OpenAI(OpenAIProvider)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse request from bytes using the appropriate provider
|
||||
pub fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match self {
|
||||
ProviderInstance::OpenAI(_) => {
|
||||
ChatCompletionsRequest::try_from_bytes(bytes).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
}
|
||||
// TODO: Add other provider cases when implemented
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse response from bytes using the appropriate provider
|
||||
pub fn parse_response(&self, bytes: &[u8], provider: &crate::Provider) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match self {
|
||||
ProviderInstance::OpenAI(_) => {
|
||||
ChatCompletionsResponse::try_from_bytes(bytes, provider).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
}
|
||||
// TODO: Add other provider cases when implemented
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse streaming response from bytes using the appropriate provider
|
||||
pub fn parse_streaming_response(&self, bytes: &[u8], provider: &crate::Provider) -> Result<OpenAIStreamingResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match self {
|
||||
ProviderInstance::OpenAI(_) => {
|
||||
OpenAIStreamingResponse::try_from_bytes(bytes, provider).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
}
|
||||
// TODO: Add other provider cases when implemented
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for ProviderInstance {
|
||||
type Request = ChatCompletionsRequest;
|
||||
type Response = ChatCompletionsResponse;
|
||||
type StreamingResponse = OpenAIStreamingResponse;
|
||||
type Usage = Usage;
|
||||
}
|
||||
74
crates/hermesllm/src/providers/traits.rs
Normal file
74
crates/hermesllm/src/providers/traits.rs
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
//! Provider traits for generic request/response handling
|
||||
//!
|
||||
//! This module defines the core traits that enable provider-agnostic
|
||||
//! handling of LLM requests and responses in the gateway.
|
||||
|
||||
use std::error::Error;
|
||||
use crate::Provider;
|
||||
|
||||
/// Trait for provider-specific request types
|
||||
pub trait ProviderRequest: Sized {
|
||||
type Error: Error + Send + Sync + 'static;
|
||||
|
||||
/// Parse request from raw bytes
|
||||
fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>;
|
||||
|
||||
/// Convert to provider-specific format
|
||||
fn to_provider_bytes(&self, provider: Provider) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Extract the model name from the request
|
||||
fn extract_model(&self) -> &str;
|
||||
|
||||
/// Check if this is a streaming request
|
||||
fn is_streaming(&self) -> bool;
|
||||
|
||||
/// Set streaming options (e.g., include_usage)
|
||||
fn set_streaming_options(&mut self);
|
||||
|
||||
/// Extract text content from messages for token counting
|
||||
fn extract_messages_text(&self) -> String;
|
||||
}
|
||||
|
||||
/// Read-only access to the token counters reported by a provider.
pub trait TokenUsage {
    /// Number of tokens in the generated completion.
    fn completion_tokens(&self) -> usize;

    /// Number of tokens in the prompt.
    fn prompt_tokens(&self) -> usize;

    /// Total number of tokens reported.
    fn total_tokens(&self) -> usize;
}
|
||||
|
||||
/// Trait for provider-specific response types
|
||||
pub trait ProviderResponse: Sized {
|
||||
type Error: Error + Send + Sync + 'static;
|
||||
type Usage: TokenUsage;
|
||||
|
||||
/// Parse response from raw bytes
|
||||
fn try_from_bytes(bytes: &[u8], provider: &Provider) -> Result<Self, Self::Error>;
|
||||
|
||||
/// Get usage information if available
|
||||
fn usage(&self) -> Option<&Self::Usage>;
|
||||
}
|
||||
|
||||
/// Trait for streaming response chunks
|
||||
pub trait StreamChunk {
|
||||
type Usage: TokenUsage;
|
||||
|
||||
/// Get usage information if available
|
||||
fn usage(&self) -> Option<&Self::Usage>;
|
||||
}
|
||||
|
||||
/// Trait for streaming response iterators
|
||||
pub trait StreamingResponse: Iterator<Item = Result<Self::Chunk, Self::Error>> + Sized {
|
||||
type Error: Error + Send + Sync + 'static;
|
||||
type Chunk: StreamChunk;
|
||||
|
||||
/// Parse streaming response from raw bytes
|
||||
fn try_from_bytes(bytes: &[u8], provider: &Provider) -> Result<Self, Self::Error>;
|
||||
}
|
||||
|
||||
/// Main provider interface trait
|
||||
pub trait ProviderInterface {
|
||||
type Request: ProviderRequest;
|
||||
type Response: ProviderResponse;
|
||||
type StreamingResponse: StreamingResponse;
|
||||
type Usage: TokenUsage;
|
||||
}
|
||||
|
|
@ -10,11 +10,9 @@ use common::ratelimit::Header;
|
|||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::providers::openai::types::{ChatCompletionsRequest, SseChatCompletionIter};
|
||||
use hermesllm::providers::openai::types::{
|
||||
ChatCompletionsResponse, ContentType, Message, StreamOptions,
|
||||
use hermesllm::{
|
||||
Provider, ProviderInstance, ProviderRequest, ProviderResponse, StreamChunk, TokenUsage,
|
||||
};
|
||||
use hermesllm::Provider;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
|
|
@ -41,7 +39,6 @@ pub struct StreamContext {
|
|||
ttft_time: Option<u128>,
|
||||
traceparent: Option<String>,
|
||||
request_body_sent_time: Option<u128>,
|
||||
user_message: Option<Message>,
|
||||
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
}
|
||||
|
|
@ -69,7 +66,6 @@ impl StreamContext {
|
|||
ttft_duration: None,
|
||||
traceparent: None,
|
||||
ttft_time: None,
|
||||
user_message: None,
|
||||
traces_queue,
|
||||
request_body_sent_time: None,
|
||||
}
|
||||
|
|
@ -80,6 +76,10 @@ impl StreamContext {
|
|||
.expect("the provider should be set when asked for it")
|
||||
}
|
||||
|
||||
/// Resolves the provider instance for the LLM provider selected for this stream.
///
/// NOTE(review): relies on `llm_provider()`, which appears to expect that a
/// provider has already been selected — confirm against its definition.
fn get_provider_instance(&self) -> ProviderInstance {
    let selected_provider = self.llm_provider();
    selected_provider.create_provider_instance()
}
|
||||
|
||||
fn select_llm_provider(&mut self) {
|
||||
let provider_hint = self
|
||||
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
|
|
@ -295,52 +295,39 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let mut deserialized_body = match ChatCompletionsRequest::try_from(body_bytes.as_slice()) {
|
||||
let provider_instance = self.get_provider_instance();
|
||||
|
||||
let mut deserialized_body = match provider_instance.parse_request(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"on_http_request_body: request body: {}",
|
||||
String::from_utf8_lossy(&body_bytes)
|
||||
);
|
||||
self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST));
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Request parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
self.user_message = deserialized_body
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|m| m.role == "user")
|
||||
.last()
|
||||
.cloned();
|
||||
// TODO: For now, we'll need to handle user_message extraction differently since it's OpenAI-specific
|
||||
// This could be made generic by adding a trait method later
|
||||
|
||||
let model_name = match self.llm_provider.as_ref() {
|
||||
Some(llm_provider) => llm_provider.model.as_ref(),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let use_agent_orchestrator = match self.overrides.as_ref() {
|
||||
let _use_agent_orchestrator = match self.overrides.as_ref() {
|
||||
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
|
||||
None => false,
|
||||
};
|
||||
|
||||
let model_requested = deserialized_body.model.clone();
|
||||
deserialized_body.model = match model_name {
|
||||
Some(model_name) => model_name.clone(),
|
||||
None => {
|
||||
if use_agent_orchestrator {
|
||||
"agent_orchestrator".to_string()
|
||||
} else {
|
||||
self.send_server_error(
|
||||
ServerError::BadRequest {
|
||||
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
let model_requested = deserialized_body.extract_model().to_string();
|
||||
// Note: We can't directly modify the model field through the trait,
|
||||
// this would need to be handled differently in a full generic implementation
|
||||
|
||||
info!(
|
||||
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
||||
|
|
@ -349,32 +336,17 @@ impl HttpContext for StreamContext {
|
|||
model_name.unwrap_or(&"None".to_string()),
|
||||
);
|
||||
|
||||
if deserialized_body.stream.unwrap_or_default() {
|
||||
if deserialized_body.is_streaming() {
|
||||
self.streaming_response = true;
|
||||
}
|
||||
if deserialized_body.stream.unwrap_or_default()
|
||||
&& deserialized_body.stream_options.is_none()
|
||||
{
|
||||
deserialized_body.stream_options = Some(StreamOptions {
|
||||
include_usage: true,
|
||||
});
|
||||
if deserialized_body.is_streaming() {
|
||||
deserialized_body.set_streaming_options();
|
||||
}
|
||||
|
||||
// only use the tokens from the messages, excluding the metadata and json tags
|
||||
let input_tokens_str = deserialized_body
|
||||
.messages
|
||||
.iter()
|
||||
.fold(String::new(), |acc, m| {
|
||||
acc + " "
|
||||
+ m.content
|
||||
.as_ref()
|
||||
.unwrap_or(&ContentType::Text(String::new()))
|
||||
.to_string()
|
||||
.as_str()
|
||||
});
|
||||
let input_tokens_str = deserialized_body.extract_messages_text();
|
||||
// enforce ratelimits on ingress
|
||||
if let Err(e) = self.enforce_ratelimits(&deserialized_body.model, input_tokens_str.as_str())
|
||||
{
|
||||
if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) {
|
||||
self.send_server_error(
|
||||
ServerError::ExceededRatelimit(e),
|
||||
Some(StatusCode::TOO_MANY_REQUESTS),
|
||||
|
|
@ -387,11 +359,15 @@ impl HttpContext for StreamContext {
|
|||
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
|
||||
|
||||
// convert chat completion request to llm provider specific request
|
||||
let deserialized_body_bytes = match deserialized_body.to_bytes(hermes_llm_provider) {
|
||||
let deserialized_body_bytes = match deserialized_body.to_provider_bytes(hermes_llm_provider)
|
||||
{
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize request body: {}", e);
|
||||
self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST));
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Request serialization error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
|
@ -484,12 +460,6 @@ impl HttpContext for StreamContext {
|
|||
self.request_body_sent_time.unwrap(),
|
||||
current_time_ns,
|
||||
);
|
||||
if let Some(user_message) = self.user_message.as_ref() {
|
||||
if let Some(prompt) = user_message.content.as_ref() {
|
||||
llm_span
|
||||
.add_attribute("user_prompt".to_string(), prompt.to_string());
|
||||
}
|
||||
}
|
||||
llm_span.add_attribute(
|
||||
"model".to_string(),
|
||||
self.llm_provider().name.to_string(),
|
||||
|
|
@ -562,8 +532,11 @@ impl HttpContext for StreamContext {
|
|||
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
|
||||
|
||||
if self.streaming_response {
|
||||
let chat_completions_chunk_response_events =
|
||||
match SseChatCompletionIter::try_from((body.as_slice(), &hermes_llm_provider)) {
|
||||
// Use the provider instance to parse streaming response
|
||||
let provider_instance = self.get_provider_instance();
|
||||
|
||||
let streaming_events =
|
||||
match provider_instance.parse_streaming_response(&body, &hermes_llm_provider) {
|
||||
Ok(events) => events,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
|
|
@ -575,11 +548,11 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
for event in chat_completions_chunk_response_events {
|
||||
match event {
|
||||
for event_result in streaming_events {
|
||||
match event_result {
|
||||
Ok(event) => {
|
||||
if let Some(usage) = event.usage.as_ref() {
|
||||
self.response_tokens += usage.completion_tokens;
|
||||
if let Some(usage) = event.usage() {
|
||||
self.response_tokens += usage.completion_tokens();
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
|
|
@ -611,30 +584,30 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
} else {
|
||||
debug!("non streaming response");
|
||||
let chat_completions_response =
|
||||
match ChatCompletionsResponse::try_from((body.as_slice(), &hermes_llm_provider)) {
|
||||
Ok(de) => de,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
debug!(
|
||||
"on_http_response_body: S[{}], response body: {}",
|
||||
self.context_id,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::OpenAIPError(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
let provider_instance = self.get_provider_instance();
|
||||
let response = match provider_instance.parse_response(&body, &hermes_llm_provider) {
|
||||
Ok(de) => de,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
debug!(
|
||||
"on_http_response_body: S[{}], response body: {}",
|
||||
self.context_id,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(usage) = chat_completions_response.usage {
|
||||
self.response_tokens += usage.completion_tokens;
|
||||
if let Some(usage) = response.usage() {
|
||||
self.response_tokens += usage.completion_tokens();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue