more fixes

Adil Hafeez 2024-10-17 18:15:16 -07:00
parent 3e7f7be838
commit 2923930944
7 changed files with 8 additions and 32 deletions

View file

@@ -14,6 +14,7 @@ derivative = "2.2.0"
thiserror = "1.0.64"
tiktoken-rs = "0.5.9"
rand = "0.8.5"
+serde_json = "1.0"

[dev-dependencies]
pretty_assertions = "1.4.1"

View file

@@ -1,4 +1,4 @@
-use crate::stats::{Gauge, IncrementingMetric};
+use crate::{errors::ClientError, stats::{Gauge, IncrementingMetric}};
use derivative::Derivative;
use log::debug;
use proxy_wasm::{traits::Context, types::Status};
@@ -37,16 +37,6 @@ impl<'a> CallArgs<'a> {
    }
}

-#[derive(thiserror::Error, Debug)]
-pub enum ClientError {
-    #[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")]
-    DispatchError {
-        upstream_name: String,
-        path: String,
-        internal_status: Status,
-    },
-}
-
pub trait Client: Context {
    type CallContext: Debug;

View file

@@ -10,3 +10,4 @@ pub mod ratelimit;
pub mod routing;
pub mod stats;
pub mod tokenizer;
+pub mod errors;
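
The new errors module registered here presumably collects the ClientError enum deleted from http.rs above and the ServerError enum deleted from the prompt gateway's stream_context.rs below. A minimal sketch of common/src/errors.rs, assuming both definitions moved over verbatim:

```rust
// common/src/errors.rs — hypothetical reconstruction assembled from the
// two enums deleted elsewhere in this commit; the real file may differ.
use proxy_wasm::types::Status;

use crate::ratelimit;

#[derive(thiserror::Error, Debug)]
pub enum ClientError {
    #[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")]
    DispatchError {
        upstream_name: String,
        path: String,
        internal_status: Status,
    },
}

#[derive(thiserror::Error, Debug)]
pub enum ServerError {
    #[error(transparent)]
    Deserialization(serde_json::Error),
    #[error("{0}")]
    LogicError(String),
    #[error(transparent)]
    ExceededRatelimit(ratelimit::Error),
    #[error("{why}")]
    BadRequest { why: String },
}
```

Housing ServerError's Deserialization(serde_json::Error) variant in common is also the likely reason serde_json now appears as a direct dependency in common's Cargo.toml and in the Cargo.lock entries elsewhere in this commit.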

View file

@@ -228,6 +228,7 @@ dependencies = [
"proxy-wasm",
"rand",
"serde",
"serde_json",
"serde_yaml",
"thiserror",
"tiktoken-rs",

View file

@@ -8,6 +8,7 @@ use common::consts::{
    ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, CHAT_COMPLETIONS_PATH,
    RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, USER_ROLE,
};
+use common::errors::ServerError;
use common::llm_providers::LlmProviders;
use common::ratelimit::Header;
use common::{ratelimit, routing, tokenizer};
@@ -22,25 +22,12 @@ use std::rc::Rc;
use common::stats::IncrementingMetric;

-#[derive(thiserror::Error, Debug)]
-pub enum ServerError {
-    #[error(transparent)]
-    Deserialization(serde_json::Error),
-    #[error("{0}")]
-    LogicError(String),
-    #[error(transparent)]
-    ExceededRatelimit(ratelimit::Error),
-    #[error("{why}")]
-    BadRequest { why: String },
-}
-
pub struct StreamContext {
    context_id: u32,
    metrics: Rc<WasmMetrics>,
    tool_calls: Option<Vec<ToolCall>>,
    tool_call_response: Option<String>,
    arch_state: Option<Vec<ArchState>>,
-    request_body_size: usize,
    ratelimit_selector: Option<Header>,
    streaming_response: bool,
    user_prompt: Option<Message>,
@@ -53,7 +41,6 @@ pub struct StreamContext {
}

impl StreamContext {
-    #[allow(clippy::too_many_arguments)]
    pub fn new(context_id: u32, metrics: Rc<WasmMetrics>, llm_providers: Rc<LlmProviders>) -> Self {
        StreamContext {
            context_id,
@@ -62,7 +49,6 @@ impl StreamContext {
            tool_calls: None,
            tool_call_response: None,
            arch_state: None,
-            request_body_size: 0,
            ratelimit_selector: None,
            streaming_response: false,
            user_prompt: None,
@@ -198,8 +184,6 @@ impl HttpContext for StreamContext {
            return Action::Continue;
        }

-        self.request_body_size = body_size;
-
        // Deserialize body into spec.
        // Currently OpenAI API.
        let mut deserialized_body: ChatCompletionsRequest =
@@ -225,7 +209,6 @@
                return Action::Pause;
            }
        };

-        self.is_chat_completions_request = true;
        // remove metadata from the request body
        deserialized_body.metadata = None;

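With the enum relocated, the deserialization path in the hunks above would map a failed body parse onto common::errors::ServerError. A small hypothetical sketch (the helper name and return shape are assumptions; only the ServerError::Deserialization variant comes from the diff):

```rust
// Hypothetical helper, not the project's exact call site: turn a
// serde_json failure into the relocated ServerError variant.
use common::errors::ServerError;

fn parse_chat_completions(body: &[u8]) -> Result<serde_json::Value, ServerError> {
    // The gateway deserializes into ChatCompletionsRequest rather than
    // a generic Value; Value keeps this sketch self-contained.
    serde_json::from_slice(body).map_err(ServerError::Deserialization)
}
```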
View file

@@ -228,6 +228,7 @@ dependencies = [
"proxy-wasm",
"rand",
"serde",
"serde_json",
"serde_yaml",
"thiserror",
"tiktoken-rs",

View file

@@ -21,7 +21,8 @@ use common::consts::{
use common::embeddings::{
    CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
-use common::http::{CallArgs, Client, ClientError};
+use common::errors::ClientError;
+use common::http::{CallArgs, Client};
use common::stats::Gauge;
use http::StatusCode;
use log::{debug, info, warn};
@@ -103,7 +104,6 @@ pub struct StreamContext {
}

impl StreamContext {
-    #[allow(clippy::too_many_arguments)]
    pub fn new(
        context_id: u32,
        metrics: Rc<WasmMetrics>,
@@ -1094,7 +1094,6 @@ impl HttpContext for StreamContext {
                return Action::Pause;
            }
        };

-        self.is_chat_completions_request = true;
        self.arch_state = match deserialized_body.metadata {
            Some(ref metadata) => {