mirror of
https://github.com/katanemo/plano.git
synced 2026-05-10 16:22:42 +02:00
more fixes
This commit is contained in:
parent
3e7f7be838
commit
2923930944
7 changed files with 8 additions and 32 deletions
|
|
@ -14,6 +14,7 @@ derivative = "2.2.0"
|
||||||
thiserror = "1.0.64"
|
thiserror = "1.0.64"
|
||||||
tiktoken-rs = "0.5.9"
|
tiktoken-rs = "0.5.9"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
|
serde_json = "1.0"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
pretty_assertions = "1.4.1"
|
pretty_assertions = "1.4.1"
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use crate::stats::{Gauge, IncrementingMetric};
|
use crate::{errors::ClientError, stats::{Gauge, IncrementingMetric}};
|
||||||
use derivative::Derivative;
|
use derivative::Derivative;
|
||||||
use log::debug;
|
use log::debug;
|
||||||
use proxy_wasm::{traits::Context, types::Status};
|
use proxy_wasm::{traits::Context, types::Status};
|
||||||
|
|
@ -37,16 +37,6 @@ impl<'a> CallArgs<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
|
||||||
pub enum ClientError {
|
|
||||||
#[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")]
|
|
||||||
DispatchError {
|
|
||||||
upstream_name: String,
|
|
||||||
path: String,
|
|
||||||
internal_status: Status,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Client: Context {
|
pub trait Client: Context {
|
||||||
type CallContext: Debug;
|
type CallContext: Debug;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,3 +10,4 @@ pub mod ratelimit;
|
||||||
pub mod routing;
|
pub mod routing;
|
||||||
pub mod stats;
|
pub mod stats;
|
||||||
pub mod tokenizer;
|
pub mod tokenizer;
|
||||||
|
pub mod errors;
|
||||||
|
|
|
||||||
1
crates/llm_gateway/Cargo.lock
generated
1
crates/llm_gateway/Cargo.lock
generated
|
|
@ -228,6 +228,7 @@ dependencies = [
|
||||||
"proxy-wasm",
|
"proxy-wasm",
|
||||||
"rand",
|
"rand",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_json",
|
||||||
"serde_yaml",
|
"serde_yaml",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tiktoken-rs",
|
"tiktoken-rs",
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ use common::consts::{
|
||||||
ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, CHAT_COMPLETIONS_PATH,
|
ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, CHAT_COMPLETIONS_PATH,
|
||||||
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, USER_ROLE,
|
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, USER_ROLE,
|
||||||
};
|
};
|
||||||
|
use common::errors::ServerError;
|
||||||
use common::llm_providers::LlmProviders;
|
use common::llm_providers::LlmProviders;
|
||||||
use common::ratelimit::Header;
|
use common::ratelimit::Header;
|
||||||
use common::{ratelimit, routing, tokenizer};
|
use common::{ratelimit, routing, tokenizer};
|
||||||
|
|
@ -22,25 +23,12 @@ use std::rc::Rc;
|
||||||
|
|
||||||
use common::stats::IncrementingMetric;
|
use common::stats::IncrementingMetric;
|
||||||
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
|
||||||
pub enum ServerError {
|
|
||||||
#[error(transparent)]
|
|
||||||
Deserialization(serde_json::Error),
|
|
||||||
#[error("{0}")]
|
|
||||||
LogicError(String),
|
|
||||||
#[error(transparent)]
|
|
||||||
ExceededRatelimit(ratelimit::Error),
|
|
||||||
#[error("{why}")]
|
|
||||||
BadRequest { why: String },
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct StreamContext {
|
pub struct StreamContext {
|
||||||
context_id: u32,
|
context_id: u32,
|
||||||
metrics: Rc<WasmMetrics>,
|
metrics: Rc<WasmMetrics>,
|
||||||
tool_calls: Option<Vec<ToolCall>>,
|
tool_calls: Option<Vec<ToolCall>>,
|
||||||
tool_call_response: Option<String>,
|
tool_call_response: Option<String>,
|
||||||
arch_state: Option<Vec<ArchState>>,
|
arch_state: Option<Vec<ArchState>>,
|
||||||
request_body_size: usize,
|
|
||||||
ratelimit_selector: Option<Header>,
|
ratelimit_selector: Option<Header>,
|
||||||
streaming_response: bool,
|
streaming_response: bool,
|
||||||
user_prompt: Option<Message>,
|
user_prompt: Option<Message>,
|
||||||
|
|
@ -53,7 +41,6 @@ pub struct StreamContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamContext {
|
impl StreamContext {
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn new(context_id: u32, metrics: Rc<WasmMetrics>, llm_providers: Rc<LlmProviders>) -> Self {
|
pub fn new(context_id: u32, metrics: Rc<WasmMetrics>, llm_providers: Rc<LlmProviders>) -> Self {
|
||||||
StreamContext {
|
StreamContext {
|
||||||
context_id,
|
context_id,
|
||||||
|
|
@ -62,7 +49,6 @@ impl StreamContext {
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
tool_call_response: None,
|
tool_call_response: None,
|
||||||
arch_state: None,
|
arch_state: None,
|
||||||
request_body_size: 0,
|
|
||||||
ratelimit_selector: None,
|
ratelimit_selector: None,
|
||||||
streaming_response: false,
|
streaming_response: false,
|
||||||
user_prompt: None,
|
user_prompt: None,
|
||||||
|
|
@ -198,8 +184,6 @@ impl HttpContext for StreamContext {
|
||||||
return Action::Continue;
|
return Action::Continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
self.request_body_size = body_size;
|
|
||||||
|
|
||||||
// Deserialize body into spec.
|
// Deserialize body into spec.
|
||||||
// Currently OpenAI API.
|
// Currently OpenAI API.
|
||||||
let mut deserialized_body: ChatCompletionsRequest =
|
let mut deserialized_body: ChatCompletionsRequest =
|
||||||
|
|
@ -225,7 +209,6 @@ impl HttpContext for StreamContext {
|
||||||
return Action::Pause;
|
return Action::Pause;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
self.is_chat_completions_request = true;
|
|
||||||
|
|
||||||
// remove metadata from the request body
|
// remove metadata from the request body
|
||||||
deserialized_body.metadata = None;
|
deserialized_body.metadata = None;
|
||||||
|
|
|
||||||
1
crates/prompt_gateway/Cargo.lock
generated
1
crates/prompt_gateway/Cargo.lock
generated
|
|
@ -228,6 +228,7 @@ dependencies = [
|
||||||
"proxy-wasm",
|
"proxy-wasm",
|
||||||
"rand",
|
"rand",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_json",
|
||||||
"serde_yaml",
|
"serde_yaml",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tiktoken-rs",
|
"tiktoken-rs",
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,8 @@ use common::consts::{
|
||||||
use common::embeddings::{
|
use common::embeddings::{
|
||||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||||
};
|
};
|
||||||
use common::http::{CallArgs, Client, ClientError};
|
use common::errors::ClientError;
|
||||||
|
use common::http::{CallArgs, Client};
|
||||||
use common::stats::Gauge;
|
use common::stats::Gauge;
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use log::{debug, info, warn};
|
use log::{debug, info, warn};
|
||||||
|
|
@ -103,7 +104,6 @@ pub struct StreamContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamContext {
|
impl StreamContext {
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn new(
|
pub fn new(
|
||||||
context_id: u32,
|
context_id: u32,
|
||||||
metrics: Rc<WasmMetrics>,
|
metrics: Rc<WasmMetrics>,
|
||||||
|
|
@ -1094,7 +1094,6 @@ impl HttpContext for StreamContext {
|
||||||
return Action::Pause;
|
return Action::Pause;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
self.is_chat_completions_request = true;
|
|
||||||
|
|
||||||
self.arch_state = match deserialized_body.metadata {
|
self.arch_state = match deserialized_body.metadata {
|
||||||
Some(ref metadata) => {
|
Some(ref metadata) => {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue