From 2923930944fd77aeeece1c93243a7e86584cdeea Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 17 Oct 2024 18:15:16 -0700 Subject: [PATCH] more fixes --- crates/common/Cargo.toml | 1 + crates/common/src/http.rs | 12 +----------- crates/common/src/lib.rs | 1 + crates/llm_gateway/Cargo.lock | 1 + crates/llm_gateway/src/stream_context.rs | 19 +------------------ crates/prompt_gateway/Cargo.lock | 1 + crates/prompt_gateway/src/stream_context.rs | 5 ++--- 7 files changed, 8 insertions(+), 32 deletions(-) diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index a362da9c..4651c610 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -14,6 +14,7 @@ derivative = "2.2.0" thiserror = "1.0.64" tiktoken-rs = "0.5.9" rand = "0.8.5" +serde_json = "1.0" [dev-dependencies] pretty_assertions = "1.4.1" diff --git a/crates/common/src/http.rs b/crates/common/src/http.rs index 21380b0f..842818e2 100644 --- a/crates/common/src/http.rs +++ b/crates/common/src/http.rs @@ -1,4 +1,4 @@ -use crate::stats::{Gauge, IncrementingMetric}; +use crate::{errors::ClientError, stats::{Gauge, IncrementingMetric}}; use derivative::Derivative; use log::debug; use proxy_wasm::{traits::Context, types::Status}; @@ -37,16 +37,6 @@ impl<'a> CallArgs<'a> { } } -#[derive(thiserror::Error, Debug)] -pub enum ClientError { - #[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")] - DispatchError { - upstream_name: String, - path: String, - internal_status: Status, - }, -} - pub trait Client: Context { type CallContext: Debug; diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index 27a51803..c23443ca 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -10,3 +10,4 @@ pub mod ratelimit; pub mod routing; pub mod stats; pub mod tokenizer; +pub mod errors; diff --git a/crates/llm_gateway/Cargo.lock b/crates/llm_gateway/Cargo.lock index 35182863..19ce3747 100644 --- a/crates/llm_gateway/Cargo.lock +++ b/crates/llm_gateway/Cargo.lock @@ -228,6 +228,7 @@ dependencies = [ "proxy-wasm", "rand", "serde", + "serde_json", "serde_yaml", "thiserror", "tiktoken-rs", diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index e1790552..655f76ff 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -8,6 +8,7 @@ use common::consts::{ ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, CHAT_COMPLETIONS_PATH, RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, USER_ROLE, }; +use common::errors::ServerError; use common::llm_providers::LlmProviders; use common::ratelimit::Header; use common::{ratelimit, routing, tokenizer}; @@ -22,25 +23,12 @@ use std::rc::Rc; use common::stats::IncrementingMetric; -#[derive(thiserror::Error, Debug)] -pub enum ServerError { - #[error(transparent)] - Deserialization(serde_json::Error), - #[error("{0}")] - LogicError(String), - #[error(transparent)] - ExceededRatelimit(ratelimit::Error), - #[error("{why}")] - BadRequest { why: String }, -} - pub struct StreamContext { context_id: u32, metrics: Rc, tool_calls: Option>, tool_call_response: Option, arch_state: Option>, - request_body_size: usize, ratelimit_selector: Option
, streaming_response: bool, user_prompt: Option, @@ -53,7 +41,6 @@ pub struct StreamContext { } impl StreamContext { - #[allow(clippy::too_many_arguments)] pub fn new(context_id: u32, metrics: Rc, llm_providers: Rc) -> Self { StreamContext { context_id, @@ -62,7 +49,6 @@ impl StreamContext { tool_calls: None, tool_call_response: None, arch_state: None, - request_body_size: 0, ratelimit_selector: None, streaming_response: false, user_prompt: None, @@ -198,8 +184,6 @@ impl HttpContext for StreamContext { return Action::Continue; } - self.request_body_size = body_size; - // Deserialize body into spec. // Currently OpenAI API. let mut deserialized_body: ChatCompletionsRequest = @@ -225,7 +209,6 @@ impl HttpContext for StreamContext { return Action::Pause; } }; - self.is_chat_completions_request = true; // remove metadata from the request body deserialized_body.metadata = None; diff --git a/crates/prompt_gateway/Cargo.lock b/crates/prompt_gateway/Cargo.lock index 63de3b3f..7679b301 100644 --- a/crates/prompt_gateway/Cargo.lock +++ b/crates/prompt_gateway/Cargo.lock @@ -228,6 +228,7 @@ dependencies = [ "proxy-wasm", "rand", "serde", + "serde_json", "serde_yaml", "thiserror", "tiktoken-rs", diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 97c67974..602f1629 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -21,7 +21,8 @@ use common::consts::{ use common::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; -use common::http::{CallArgs, Client, ClientError}; +use common::errors::ClientError; +use common::http::{CallArgs, Client}; use common::stats::Gauge; use http::StatusCode; use log::{debug, info, warn}; @@ -103,7 +104,6 @@ pub struct StreamContext { } impl StreamContext { - #[allow(clippy::too_many_arguments)] pub fn new( context_id: u32, metrics: Rc, @@ -1094,7 +1094,6 @@ impl HttpContext for StreamContext { return Action::Pause; } }; - self.is_chat_completions_request = true; self.arch_state = match deserialized_body.metadata { Some(ref metadata) => {