Code refactor and some improvements - see description (#194)

This commit is contained in:
Adil Hafeez 2024-10-18 12:53:44 -07:00 committed by GitHub
parent aa30353c85
commit c6ba28dfcc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 100 additions and 115 deletions

View file

@ -229,20 +229,6 @@ mod test {
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
assert_eq!(config.version, "v0.1");
let open_ai_provider = config
.llm_providers
.iter()
.find(|p| p.name.to_lowercase() == "openai")
.unwrap();
assert_eq!(open_ai_provider.name.to_lowercase(), "openai");
assert_eq!(
open_ai_provider.access_key,
Some("OPENAI_API_KEY".to_string())
);
assert_eq!(open_ai_provider.model, "gpt-4o");
assert_eq!(open_ai_provider.default, Some(true));
assert_eq!(open_ai_provider.stream, Some(true));
let prompt_guards = config.prompt_guards.as_ref().unwrap();
let input_guards = &prompt_guards.input_guards;
let jailbreak_guard = input_guards.get(&GuardType::Jailbreak).unwrap();

View file

@ -12,7 +12,7 @@ pub const MODEL_SERVER_NAME: &str = "model_server";
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
pub const ARCH_STATE_HEADER: &str = "x-arch-state";
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B";
pub const REQUEST_ID_HEADER: &str = "x-request-id";

View file

@ -0,0 +1,39 @@
use proxy_wasm::types::Status;
use crate::ratelimit;
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
#[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")]
DispatchError {
upstream_name: String,
path: String,
internal_status: Status,
},
}
#[derive(thiserror::Error, Debug)]
pub enum ServerError {
#[error(transparent)]
HttpDispatch(ClientError),
#[error(transparent)]
Deserialization(serde_json::Error),
#[error(transparent)]
Serialization(serde_json::Error),
#[error("{0}")]
LogicError(String),
#[error("upstream error response authority={authority}, path={path}, status={status}")]
Upstream {
authority: String,
path: String,
status: String,
},
#[error("jailbreak detected: {0}")]
Jailbreak(String),
#[error("{why}")]
NoMessagesFound { why: String },
#[error(transparent)]
ExceededRatelimit(ratelimit::Error),
#[error("{why}")]
BadRequest { why: String },
}

View file

@ -1,4 +1,4 @@
use crate::stats::{Gauge, IncrementingMetric};
use crate::{errors::ClientError, stats::{Gauge, IncrementingMetric}};
use derivative::Derivative;
use log::debug;
use proxy_wasm::{traits::Context, types::Status};
@ -37,16 +37,6 @@ impl<'a> CallArgs<'a> {
}
}
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
#[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")]
DispatchError {
upstream_name: String,
path: String,
internal_status: Status,
},
}
pub trait Client: Context {
type CallContext: Debug;

View file

@ -10,3 +10,4 @@ pub mod ratelimit;
pub mod routing;
pub mod stats;
pub mod tokenizer;
pub mod errors;