mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Merge branch 'main' into adil/dashboard_update
This commit is contained in:
commit
e2d9b6a008
26 changed files with 320 additions and 257 deletions
6
.github/workflows/checks.yml
vendored
6
.github/workflows/checks.yml
vendored
|
|
@ -19,12 +19,12 @@ jobs:
|
|||
- name: Setup | Install wasm toolchain
|
||||
run: rustup target add wasm32-wasi
|
||||
|
||||
- name: Build wasm module for prompt_gateway
|
||||
run: cd crates/prompt_gateway && cargo build --release --target=wasm32-wasi
|
||||
|
||||
- name: Run Tests on common crate
|
||||
run: cd crates/common && cargo test
|
||||
|
||||
- name: Build wasm module for prompt_gateway
|
||||
run: cd crates/prompt_gateway && cargo build --release --target=wasm32-wasi
|
||||
|
||||
- name: Run Tests on prompt_gateway crate
|
||||
run: cd crates/prompt_gateway && cargo test
|
||||
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ FROM envoyproxy/envoy:v1.31-latest as envoy
|
|||
|
||||
#Build config generator, so that we have a single build image for both Rust and Python
|
||||
FROM python:3-slim as arch
|
||||
COPY --from=builder /arch/prompt_gateway/target/wasm32-wasi/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm
|
||||
COPY --from=builder /arch/llm_gateway/target/wasm32-wasi/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm
|
||||
COPY --from=builder /arch/target/wasm32-wasi/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm
|
||||
COPY --from=builder /arch/target/wasm32-wasi/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm
|
||||
COPY --from=envoy /usr/local/bin/envoy /usr/local/bin/envoy
|
||||
WORKDIR /config
|
||||
COPY arch/requirements.txt .
|
||||
|
|
|
|||
3
arch/tools/.vscode/settings.json
vendored
Normal file
3
arch/tools/.vscode/settings.json
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"python.defaultInterpreterPath": "${workspaceFolder}/venv/bin/python",
|
||||
}
|
||||
|
|
@ -5,19 +5,11 @@
|
|||
"path": "."
|
||||
},
|
||||
{
|
||||
"name": "common",
|
||||
"path": "crates/common"
|
||||
"name": "crates",
|
||||
"path": "crates"
|
||||
},
|
||||
{
|
||||
"name": "prompt_gateway",
|
||||
"path": "crates/prompt_gateway"
|
||||
},
|
||||
{
|
||||
"name": "llm_gateway",
|
||||
"path": "crates/llm_gateway"
|
||||
},
|
||||
{
|
||||
"name": "arch/tools",
|
||||
"name": "archgw_cli",
|
||||
"path": "arch/tools"
|
||||
},
|
||||
{
|
||||
|
|
@ -36,10 +28,15 @@
|
|||
"name": "demos/insurance_agent",
|
||||
"path": "./demos/insurance_agent",
|
||||
},
|
||||
{
|
||||
"name": "demos/function_calling/api_server",
|
||||
"path": "./demos/function_calling/api_server",
|
||||
},
|
||||
],
|
||||
"settings": {}
|
||||
"settings": {
|
||||
},
|
||||
"extensions": {
|
||||
"recommendations": [
|
||||
"ms-python.python",
|
||||
"ms-python.debugpy",
|
||||
"rust-lang.rust-analyzer",
|
||||
"humao.rest-client"
|
||||
]
|
||||
}
|
||||
}
|
||||
1
chatbot_ui/.vscode/launch.json
vendored
1
chatbot_ui/.vscode/launch.json
vendored
|
|
@ -5,6 +5,7 @@
|
|||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"python": "${workspaceFolder}/venv/bin/python",
|
||||
"name": "chatbot-ui",
|
||||
"cwd": "${workspaceFolder}/app",
|
||||
"type": "debugpy",
|
||||
|
|
|
|||
3
chatbot_ui/.vscode/settings.json
vendored
Normal file
3
chatbot_ui/.vscode/settings.json
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"python.defaultInterpreterPath": "${workspaceFolder}/venv/bin/python",
|
||||
}
|
||||
|
|
@ -2,14 +2,21 @@ import json
|
|||
import os
|
||||
from openai import OpenAI, DefaultHttpxClient
|
||||
import gradio as gr
|
||||
import logging as log
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
|
||||
ARCH_STATE_HEADER = "x-arch-state"
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
log.info(f"CHAT_COMPLETION_ENDPOINT: {CHAT_COMPLETION_ENDPOINT}")
|
||||
|
||||
client = OpenAI(
|
||||
api_key="--",
|
||||
|
|
@ -23,23 +30,19 @@ def predict(message, state):
|
|||
state["history"] = []
|
||||
history = state.get("history")
|
||||
history.append({"role": "user", "content": message})
|
||||
log.info("history: ", history)
|
||||
log.info(f"history: {history}")
|
||||
|
||||
# Custom headers
|
||||
custom_headers = {
|
||||
"x-arch-deterministic-provider": "openai",
|
||||
}
|
||||
|
||||
metadata = None
|
||||
if "arch_state" in state:
|
||||
metadata = {ARCH_STATE_HEADER: state["arch_state"]}
|
||||
|
||||
try:
|
||||
raw_response = client.chat.completions.with_raw_response.create(
|
||||
model="--",
|
||||
messages=history,
|
||||
temperature=1.0,
|
||||
metadata=metadata,
|
||||
# metadata=metadata,
|
||||
extra_headers=custom_headers,
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
@ -49,26 +52,35 @@ def predict(message, state):
|
|||
log.info("Error calling gateway API: {}".format(e.message))
|
||||
raise gr.Error("Error calling gateway API: {}".format(e.message))
|
||||
|
||||
log.info("raw_response: ", raw_response.text)
|
||||
log.error(f"raw_response: {raw_response.text}")
|
||||
response = raw_response.parse()
|
||||
|
||||
# extract arch_state from metadata and store it in gradio session state
|
||||
# this state must be passed back to the gateway in the next request
|
||||
response_json = json.loads(raw_response.text)
|
||||
arch_state = None
|
||||
if response_json:
|
||||
metadata = response_json.get("metadata", {})
|
||||
if metadata:
|
||||
arch_state = metadata.get(ARCH_STATE_HEADER, None)
|
||||
if arch_state:
|
||||
state["arch_state"] = arch_state
|
||||
# load arch_state from metadata
|
||||
arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}")
|
||||
# parse arch_state into json object
|
||||
arch_state = json.loads(arch_state_str)
|
||||
# load messages from arch_state
|
||||
arch_messages_str = arch_state.get("messages", "[]")
|
||||
# parse messages into json object
|
||||
arch_messages = json.loads(arch_messages_str)
|
||||
# append messages from arch gateway to history
|
||||
for message in arch_messages:
|
||||
history.append(message)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
|
||||
history.append({"role": "assistant", "content": content, "model": response.model})
|
||||
|
||||
# for gradio UI we don't want to show raw tool calls and messages from developer application
|
||||
# so we're filtering those out
|
||||
history_view = [h for h in history if h["role"] != "tool" and "content" in h]
|
||||
messages = [
|
||||
(history[i]["content"], history[i + 1]["content"])
|
||||
for i in range(0, len(history) - 1, 2)
|
||||
(history_view[i]["content"], history_view[i + 1]["content"])
|
||||
for i in range(0, len(history_view) - 1, 2)
|
||||
]
|
||||
return messages, state
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ derivative = "2.2.0"
|
|||
thiserror = "1.0.64"
|
||||
tiktoken-rs = "0.5.9"
|
||||
rand = "0.8.5"
|
||||
serde_json = "1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = "1.4.1"
|
||||
|
|
|
|||
|
|
@ -188,6 +188,8 @@ pub mod open_ai {
|
|||
pub model: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -381,6 +383,7 @@ mod test {
|
|||
content: Some("What city do you want to know the weather for?".to_string()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
tools: Some(vec![super::open_ai::ChatCompletionTool {
|
||||
tool_type: ToolType::Function,
|
||||
|
|
|
|||
|
|
@ -229,20 +229,6 @@ mod test {
|
|||
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
|
||||
assert_eq!(config.version, "v0.1");
|
||||
|
||||
let open_ai_provider = config
|
||||
.llm_providers
|
||||
.iter()
|
||||
.find(|p| p.name.to_lowercase() == "openai")
|
||||
.unwrap();
|
||||
assert_eq!(open_ai_provider.name.to_lowercase(), "openai");
|
||||
assert_eq!(
|
||||
open_ai_provider.access_key,
|
||||
Some("OPENAI_API_KEY".to_string())
|
||||
);
|
||||
assert_eq!(open_ai_provider.model, "gpt-4o");
|
||||
assert_eq!(open_ai_provider.default, Some(true));
|
||||
assert_eq!(open_ai_provider.stream, Some(true));
|
||||
|
||||
let prompt_guards = config.prompt_guards.as_ref().unwrap();
|
||||
let input_guards = &prompt_guards.input_guards;
|
||||
let jailbreak_guard = input_guards.get(&GuardType::Jailbreak).unwrap();
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25;
|
|||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
|
||||
pub const SYSTEM_ROLE: &str = "system";
|
||||
pub const USER_ROLE: &str = "user";
|
||||
pub const TOOL_ROLE: &str = "tool";
|
||||
pub const ASSISTANT_ROLE: &str = "assistant";
|
||||
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
|
||||
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
||||
pub const MODEL_SERVER_NAME: &str = "model_server";
|
||||
|
|
@ -16,7 +18,7 @@ pub const GUARD_INTERNAL_HOST: &str = "guard";
|
|||
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
||||
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
|
||||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B";
|
||||
pub const REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
|
|
|
|||
39
crates/common/src/errors.rs
Normal file
39
crates/common/src/errors.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
use proxy_wasm::types::Status;
|
||||
|
||||
use crate::ratelimit;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ClientError {
|
||||
#[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")]
|
||||
DispatchError {
|
||||
upstream_name: String,
|
||||
path: String,
|
||||
internal_status: Status,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ServerError {
|
||||
#[error(transparent)]
|
||||
HttpDispatch(ClientError),
|
||||
#[error(transparent)]
|
||||
Deserialization(serde_json::Error),
|
||||
#[error(transparent)]
|
||||
Serialization(serde_json::Error),
|
||||
#[error("{0}")]
|
||||
LogicError(String),
|
||||
#[error("upstream error response authority={authority}, path={path}, status={status}")]
|
||||
Upstream {
|
||||
authority: String,
|
||||
path: String,
|
||||
status: String,
|
||||
},
|
||||
#[error("jailbreak detected: {0}")]
|
||||
Jailbreak(String),
|
||||
#[error("{why}")]
|
||||
NoMessagesFound { why: String },
|
||||
#[error(transparent)]
|
||||
ExceededRatelimit(ratelimit::Error),
|
||||
#[error("{why}")]
|
||||
BadRequest { why: String },
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
use crate::stats::{Gauge, IncrementingMetric};
|
||||
use crate::{errors::ClientError, stats::{Gauge, IncrementingMetric}};
|
||||
use derivative::Derivative;
|
||||
use log::debug;
|
||||
use proxy_wasm::{traits::Context, types::Status};
|
||||
|
|
@ -37,16 +37,6 @@ impl<'a> CallArgs<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ClientError {
|
||||
#[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")]
|
||||
DispatchError {
|
||||
upstream_name: String,
|
||||
path: String,
|
||||
internal_status: Status,
|
||||
},
|
||||
}
|
||||
|
||||
pub trait Client: Context {
|
||||
type CallContext: Debug;
|
||||
|
||||
|
|
|
|||
|
|
@ -10,3 +10,4 @@ pub mod ratelimit;
|
|||
pub mod routing;
|
||||
pub mod stats;
|
||||
pub mod tokenizer;
|
||||
pub mod errors;
|
||||
|
|
|
|||
1
crates/llm_gateway/Cargo.lock
generated
1
crates/llm_gateway/Cargo.lock
generated
|
|
@ -228,6 +228,7 @@ dependencies = [
|
|||
"proxy-wasm",
|
||||
"rand",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_yaml",
|
||||
"thiserror",
|
||||
"tiktoken-rs",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use crate::llm_stream_context::LlmGatewayStreamContext;
|
||||
use crate::stream_context::StreamContext;
|
||||
use common::configuration::Configuration;
|
||||
use common::http::Client;
|
||||
use common::llm_providers::LlmProviders;
|
||||
|
|
@ -28,19 +28,19 @@ impl WasmMetrics {
|
|||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterCallContext {}
|
||||
pub struct CallContext {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LlmGatewayFilterContext {
|
||||
pub struct FilterContext {
|
||||
metrics: Rc<WasmMetrics>,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: RefCell<HashMap<u32, FilterCallContext>>,
|
||||
callouts: RefCell<HashMap<u32, CallContext>>,
|
||||
llm_providers: Option<Rc<LlmProviders>>,
|
||||
}
|
||||
|
||||
impl LlmGatewayFilterContext {
|
||||
pub fn new() -> LlmGatewayFilterContext {
|
||||
LlmGatewayFilterContext {
|
||||
impl FilterContext {
|
||||
pub fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: RefCell::new(HashMap::new()),
|
||||
metrics: Rc::new(WasmMetrics::new()),
|
||||
llm_providers: None,
|
||||
|
|
@ -48,8 +48,8 @@ impl LlmGatewayFilterContext {
|
|||
}
|
||||
}
|
||||
|
||||
impl Client for LlmGatewayFilterContext {
|
||||
type CallContext = FilterCallContext;
|
||||
impl Client for FilterContext {
|
||||
type CallContext = CallContext;
|
||||
|
||||
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
|
||||
&self.callouts
|
||||
|
|
@ -60,10 +60,10 @@ impl Client for LlmGatewayFilterContext {
|
|||
}
|
||||
}
|
||||
|
||||
impl Context for LlmGatewayFilterContext {}
|
||||
impl Context for FilterContext {}
|
||||
|
||||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for LlmGatewayFilterContext {
|
||||
impl RootContext for FilterContext {
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
let config_bytes = self
|
||||
.get_plugin_configuration()
|
||||
|
|
@ -90,8 +90,7 @@ impl RootContext for LlmGatewayFilterContext {
|
|||
context_id
|
||||
);
|
||||
|
||||
// No StreamContext can be created until the Embedding Store is fully initialized.
|
||||
Some(Box::new(LlmGatewayStreamContext::new(
|
||||
Some(Box::new(StreamContext::new(
|
||||
context_id,
|
||||
Rc::clone(&self.metrics),
|
||||
Rc::clone(
|
||||
|
|
@ -1,13 +1,13 @@
|
|||
use llm_filter_context::LlmGatewayFilterContext;
|
||||
use filter_context::FilterContext;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
mod llm_filter_context;
|
||||
mod llm_stream_context;
|
||||
mod filter_context;
|
||||
mod stream_context;
|
||||
|
||||
proxy_wasm::main! {{
|
||||
proxy_wasm::set_log_level(LogLevel::Trace);
|
||||
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
|
||||
Box::new(LlmGatewayFilterContext::new())
|
||||
Box::new(FilterContext::new())
|
||||
});
|
||||
}}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use crate::llm_filter_context::WasmMetrics;
|
||||
use crate::filter_context::WasmMetrics;
|
||||
use common::common_types::open_ai::{
|
||||
ArchState, ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse,
|
||||
Message, ToolCall, ToolCallState,
|
||||
|
|
@ -8,6 +8,7 @@ use common::consts::{
|
|||
ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, CHAT_COMPLETIONS_PATH,
|
||||
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, USER_ROLE,
|
||||
};
|
||||
use common::errors::ServerError;
|
||||
use common::llm_providers::LlmProviders;
|
||||
use common::ratelimit::Header;
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
|
|
@ -22,25 +23,12 @@ use std::rc::Rc;
|
|||
|
||||
use common::stats::IncrementingMetric;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ServerError {
|
||||
#[error(transparent)]
|
||||
Deserialization(serde_json::Error),
|
||||
#[error("{0}")]
|
||||
LogicError(String),
|
||||
#[error(transparent)]
|
||||
ExceededRatelimit(ratelimit::Error),
|
||||
#[error("{why}")]
|
||||
BadRequest { why: String },
|
||||
}
|
||||
|
||||
pub struct LlmGatewayStreamContext {
|
||||
pub struct StreamContext {
|
||||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
tool_call_response: Option<String>,
|
||||
arch_state: Option<Vec<ArchState>>,
|
||||
request_body_size: usize,
|
||||
ratelimit_selector: Option<Header>,
|
||||
streaming_response: bool,
|
||||
user_prompt: Option<Message>,
|
||||
|
|
@ -52,17 +40,15 @@ pub struct LlmGatewayStreamContext {
|
|||
request_id: Option<String>,
|
||||
}
|
||||
|
||||
impl LlmGatewayStreamContext {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
impl StreamContext {
|
||||
pub fn new(context_id: u32, metrics: Rc<WasmMetrics>, llm_providers: Rc<LlmProviders>) -> Self {
|
||||
LlmGatewayStreamContext {
|
||||
StreamContext {
|
||||
context_id,
|
||||
metrics,
|
||||
chat_completions_request: None,
|
||||
tool_calls: None,
|
||||
tool_call_response: None,
|
||||
arch_state: None,
|
||||
request_body_size: 0,
|
||||
ratelimit_selector: None,
|
||||
streaming_response: false,
|
||||
user_prompt: None,
|
||||
|
|
@ -160,7 +146,7 @@ impl LlmGatewayStreamContext {
|
|||
}
|
||||
|
||||
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
|
||||
impl HttpContext for LlmGatewayStreamContext {
|
||||
impl HttpContext for StreamContext {
|
||||
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
|
||||
// the lifecycle of the http request and response.
|
||||
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
|
|
@ -198,8 +184,6 @@ impl HttpContext for LlmGatewayStreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
self.request_body_size = body_size;
|
||||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let mut deserialized_body: ChatCompletionsRequest =
|
||||
|
|
@ -225,7 +209,6 @@ impl HttpContext for LlmGatewayStreamContext {
|
|||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
self.is_chat_completions_request = true;
|
||||
|
||||
// remove metadata from the request body
|
||||
deserialized_body.metadata = None;
|
||||
|
|
@ -333,10 +316,9 @@ impl HttpContext for LlmGatewayStreamContext {
|
|||
let chat_completions_response: ChatCompletionsResponse =
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(de) => de,
|
||||
Err(e) => {
|
||||
Err(_e) => {
|
||||
debug!("invalid response: {}", String::from_utf8_lossy(&body));
|
||||
self.send_server_error(ServerError::Deserialization(e), None);
|
||||
return Action::Pause;
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -418,4 +400,4 @@ impl HttpContext for LlmGatewayStreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
impl Context for LlmGatewayStreamContext {}
|
||||
impl Context for StreamContext {}
|
||||
1
crates/prompt_gateway/Cargo.lock
generated
1
crates/prompt_gateway/Cargo.lock
generated
|
|
@ -228,6 +228,7 @@ dependencies = [
|
|||
"proxy-wasm",
|
||||
"rand",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_yaml",
|
||||
"thiserror",
|
||||
"tiktoken-rs",
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use crate::prompt_stream_context::PromptStreamContext;
|
||||
use crate::stream_context::StreamContext;
|
||||
use common::common_types::EmbeddingType;
|
||||
use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget};
|
||||
use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST};
|
||||
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget};
|
||||
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
|
||||
use common::consts::DEFAULT_EMBEDDING_MODEL;
|
||||
use common::embeddings::{
|
||||
|
|
@ -9,7 +9,6 @@ use common::embeddings::{
|
|||
};
|
||||
use common::http::CallArgs;
|
||||
use common::http::Client;
|
||||
use common::llm_providers::LlmProviders;
|
||||
use common::stats::Gauge;
|
||||
use common::stats::IncrementingMetric;
|
||||
use log::debug;
|
||||
|
|
@ -44,31 +43,27 @@ pub struct FilterCallContext {
|
|||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PromptGatewayFilterContext {
|
||||
pub struct FilterContext {
|
||||
metrics: Rc<WasmMetrics>,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: RefCell<HashMap<u32, FilterCallContext>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
mode: GatewayMode,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
llm_providers: Option<Rc<LlmProviders>>,
|
||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
temp_embeddings_store: EmbeddingsStore,
|
||||
}
|
||||
|
||||
impl PromptGatewayFilterContext {
|
||||
pub fn new() -> PromptGatewayFilterContext {
|
||||
PromptGatewayFilterContext {
|
||||
impl FilterContext {
|
||||
pub fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: RefCell::new(HashMap::new()),
|
||||
metrics: Rc::new(WasmMetrics::new()),
|
||||
system_prompt: Rc::new(None),
|
||||
prompt_targets: Rc::new(HashMap::new()),
|
||||
overrides: Rc::new(None),
|
||||
prompt_guards: Rc::new(PromptGuards::default()),
|
||||
mode: GatewayMode::Prompt,
|
||||
llm_providers: None,
|
||||
embeddings_store: Some(Rc::new(HashMap::new())),
|
||||
temp_embeddings_store: HashMap::new(),
|
||||
}
|
||||
|
|
@ -116,7 +111,7 @@ impl PromptGatewayFilterContext {
|
|||
Duration::from_secs(60),
|
||||
);
|
||||
|
||||
let call_context = crate::prompt_filter_context::FilterCallContext {
|
||||
let call_context = crate::filter_context::FilterCallContext {
|
||||
prompt_target_name: String::from(prompt_target_name),
|
||||
embedding_type,
|
||||
};
|
||||
|
|
@ -193,7 +188,7 @@ impl PromptGatewayFilterContext {
|
|||
}
|
||||
}
|
||||
|
||||
impl Client for PromptGatewayFilterContext {
|
||||
impl Client for FilterContext {
|
||||
type CallContext = FilterCallContext;
|
||||
|
||||
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
|
||||
|
|
@ -205,7 +200,7 @@ impl Client for PromptGatewayFilterContext {
|
|||
}
|
||||
}
|
||||
|
||||
impl Context for PromptGatewayFilterContext {
|
||||
impl Context for FilterContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
|
|
@ -234,7 +229,7 @@ impl Context for PromptGatewayFilterContext {
|
|||
}
|
||||
|
||||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for PromptGatewayFilterContext {
|
||||
impl RootContext for FilterContext {
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
let config_bytes = self
|
||||
.get_plugin_configuration()
|
||||
|
|
@ -253,17 +248,11 @@ impl RootContext for PromptGatewayFilterContext {
|
|||
}
|
||||
self.system_prompt = Rc::new(config.system_prompt);
|
||||
self.prompt_targets = Rc::new(prompt_targets);
|
||||
self.mode = config.mode.unwrap_or_default();
|
||||
|
||||
if let Some(prompt_guards) = config.prompt_guards {
|
||||
self.prompt_guards = Rc::new(prompt_guards)
|
||||
}
|
||||
|
||||
match config.llm_providers.try_into() {
|
||||
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
|
||||
Err(err) => panic!("{err}"),
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
|
|
@ -273,12 +262,11 @@ impl RootContext for PromptGatewayFilterContext {
|
|||
context_id
|
||||
);
|
||||
|
||||
// No StreamContext can be created until the Embedding Store is fully initialized.
|
||||
let embedding_store = match self.mode {
|
||||
GatewayMode::Llm => None,
|
||||
GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())),
|
||||
let embedding_store = match self.embeddings_store.as_ref() {
|
||||
None => return None,
|
||||
Some(store) => Some(Rc::clone(store)),
|
||||
};
|
||||
Some(Box::new(PromptStreamContext::new(
|
||||
Some(Box::new(StreamContext::new(
|
||||
context_id,
|
||||
Rc::clone(&self.metrics),
|
||||
Rc::clone(&self.system_prompt),
|
||||
|
|
@ -299,11 +287,8 @@ impl RootContext for PromptGatewayFilterContext {
|
|||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
debug!("starting up arch filter in mode: {:?}", self.mode);
|
||||
if self.mode == GatewayMode::Prompt {
|
||||
self.process_prompt_targets();
|
||||
}
|
||||
|
||||
debug!("starting up arch filter in mode: prompt gateway mode");
|
||||
self.process_prompt_targets();
|
||||
self.set_tick_period(Duration::from_secs(0));
|
||||
}
|
||||
}
|
||||
|
|
@ -1,13 +1,13 @@
|
|||
use prompt_filter_context::PromptGatewayFilterContext;
|
||||
use filter_context::FilterContext;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
mod prompt_filter_context;
|
||||
mod prompt_stream_context;
|
||||
mod filter_context;
|
||||
mod stream_context;
|
||||
|
||||
proxy_wasm::main! {{
|
||||
proxy_wasm::set_log_level(LogLevel::Trace);
|
||||
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
|
||||
Box::new(PromptGatewayFilterContext::new())
|
||||
Box::new(FilterContext::new())
|
||||
});
|
||||
}}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
use crate::prompt_filter_context::{EmbeddingsStore, WasmMetrics};
|
||||
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
|
||||
use acap::cos;
|
||||
use common::common_types::open_ai::{
|
||||
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
|
||||
FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType,
|
||||
StreamOptions, ToolCall, ToolCallState, ToolType,
|
||||
StreamOptions, ToolCall, ToolType,
|
||||
};
|
||||
use common::common_types::{
|
||||
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
|
||||
|
|
@ -12,23 +12,25 @@ use common::common_types::{
|
|||
};
|
||||
use common::configuration::{Overrides, PromptGuards, PromptTarget};
|
||||
use common::consts::{
|
||||
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARCH_FC_INTERNAL_HOST, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, GUARD_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST
|
||||
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, GUARD_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST
|
||||
};
|
||||
use common::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use common::http::{CallArgs, Client, ClientError};
|
||||
use common::errors::ClientError;
|
||||
use common::http::{CallArgs, Client};
|
||||
use common::stats::Gauge;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use serde_json::Value;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use derivative::Derivative;
|
||||
|
||||
use common::stats::IncrementingMetric;
|
||||
|
||||
|
|
@ -43,11 +45,13 @@ enum ResponseHandlerType {
|
|||
DefaultTarget,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Clone, Derivative)]
|
||||
#[derivative(Debug)]
|
||||
pub struct StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType,
|
||||
user_message: Option<String>,
|
||||
prompt_target_name: Option<String>,
|
||||
#[derivative(Debug = "ignore")]
|
||||
request_body: ChatCompletionsRequest,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
similarity_scores: Option<Vec<(String, f64)>>,
|
||||
|
|
@ -65,11 +69,12 @@ pub enum ServerError {
|
|||
Serialization(serde_json::Error),
|
||||
#[error("{0}")]
|
||||
LogicError(String),
|
||||
#[error("upstream error response authority={authority}, path={path}, status={status}")]
|
||||
#[error("upstream application error host={host}, path={path}, status={status}, body={body}")]
|
||||
Upstream {
|
||||
authority: String,
|
||||
host: String,
|
||||
path: String,
|
||||
status: String,
|
||||
body: String,
|
||||
},
|
||||
#[error("jailbreak detected: {0}")]
|
||||
Jailbreak(String),
|
||||
|
|
@ -77,7 +82,7 @@ pub enum ServerError {
|
|||
NoMessagesFound { why: String },
|
||||
}
|
||||
|
||||
pub struct PromptStreamContext {
|
||||
pub struct StreamContext {
|
||||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
system_prompt: Rc<Option<String>>,
|
||||
|
|
@ -98,8 +103,7 @@ pub struct PromptStreamContext {
|
|||
request_id: Option<String>,
|
||||
}
|
||||
|
||||
impl PromptStreamContext {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
impl StreamContext {
|
||||
pub fn new(
|
||||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
|
|
@ -109,7 +113,7 @@ impl PromptStreamContext {
|
|||
overrides: Rc<Option<Overrides>>,
|
||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
) -> Self {
|
||||
PromptStreamContext {
|
||||
StreamContext {
|
||||
context_id,
|
||||
metrics,
|
||||
system_prompt,
|
||||
|
|
@ -145,7 +149,6 @@ impl PromptStreamContext {
|
|||
}
|
||||
|
||||
fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
|
||||
debug!("server error occurred: {}", error);
|
||||
self.send_http_response(
|
||||
override_status_code
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
|
|
@ -160,6 +163,7 @@ impl PromptStreamContext {
|
|||
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
|
||||
Ok(embedding_response) => embedding_response,
|
||||
Err(e) => {
|
||||
debug!("error deserializing embedding response: {}", e);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -230,6 +234,7 @@ impl PromptStreamContext {
|
|||
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
debug!("error serializing zero shot classification request: {}", error);
|
||||
return self.send_server_error(ServerError::Serialization(error), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -259,6 +264,7 @@ impl PromptStreamContext {
|
|||
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
debug!("error dispatching zero shot classification request: {}", e);
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
}
|
||||
|
|
@ -272,6 +278,7 @@ impl PromptStreamContext {
|
|||
match serde_json::from_slice(&body) {
|
||||
Ok(hallucination_response) => hallucination_response,
|
||||
Err(e) => {
|
||||
debug!("error deserializing hallucination response: {}", e);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -297,6 +304,7 @@ impl PromptStreamContext {
|
|||
content: Some(response),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
|
||||
let chat_completion_response = ChatCompletionsResponse {
|
||||
|
|
@ -335,6 +343,7 @@ impl PromptStreamContext {
|
|||
match serde_json::from_slice(&body) {
|
||||
Ok(zeroshot_response) => zeroshot_response,
|
||||
Err(e) => {
|
||||
debug!("error deserializing zero shot classification response: {}", e);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -446,6 +455,7 @@ impl PromptStreamContext {
|
|||
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
debug!("error dispatching default prompt target request: {}", e);
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
|
|
@ -461,6 +471,7 @@ impl PromptStreamContext {
|
|||
let prompt_target = match self.prompt_targets.get(&prompt_target_name) {
|
||||
Some(prompt_target) => prompt_target.clone(),
|
||||
None => {
|
||||
debug!("prompt target not found: {}", prompt_target_name);
|
||||
return self.send_server_error(
|
||||
ServerError::LogicError(format!(
|
||||
"Prompt target not found: {prompt_target_name}"
|
||||
|
|
@ -533,6 +544,7 @@ impl PromptStreamContext {
|
|||
msg_body
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("error serializing arch_fc request body: {}", e);
|
||||
return self.send_server_error(ServerError::Serialization(e), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -565,6 +577,7 @@ impl PromptStreamContext {
|
|||
callout_context.prompt_target_name = Some(prompt_target.name);
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
debug!("error dispatching arch_fc request: {}", e);
|
||||
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
}
|
||||
|
|
@ -576,6 +589,7 @@ impl PromptStreamContext {
|
|||
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
|
||||
Ok(arch_fc_response) => arch_fc_response,
|
||||
Err(e) => {
|
||||
debug!("error deserializing arch_fc response: {}", e);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -689,6 +703,7 @@ impl PromptStreamContext {
|
|||
match serde_json::to_string(&hallucination_classification_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
debug!("error serializing hallucination classification request: {}", error);
|
||||
return self.send_server_error(ServerError::Serialization(error), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -781,17 +796,19 @@ impl PromptStreamContext {
|
|||
fn function_call_response_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
mut callout_context: StreamCallContext,
|
||||
callout_context: StreamCallContext,
|
||||
) {
|
||||
if let Some(http_status) = self.get_http_call_response_header(":status") {
|
||||
if http_status != StatusCode::OK.as_str() {
|
||||
debug!("upstream error response: {}", http_status);
|
||||
return self.send_server_error(
|
||||
ServerError::Upstream {
|
||||
authority: callout_context.upstream_cluster.unwrap(),
|
||||
host: callout_context.upstream_cluster.unwrap(),
|
||||
path: callout_context.upstream_cluster_path.unwrap(),
|
||||
status: http_status,
|
||||
status: http_status.clone(),
|
||||
body: String::from_utf8(body).unwrap(),
|
||||
},
|
||||
None,
|
||||
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
|
|
@ -823,11 +840,18 @@ impl PromptStreamContext {
|
|||
content: system_prompt,
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
|
||||
messages.append(callout_context.request_body.messages.as_mut());
|
||||
// don't send tools message and api response to chat gpt
|
||||
for m in callout_context.request_body.messages.iter() {
|
||||
if m.role == TOOL_ROLE || m.content.is_none() {
|
||||
continue;
|
||||
}
|
||||
messages.push(m.clone());
|
||||
}
|
||||
|
||||
let user_message = match messages.pop() {
|
||||
Some(user_message) => user_message,
|
||||
|
|
@ -854,6 +878,7 @@ impl PromptStreamContext {
|
|||
content: Some(final_prompt),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -889,6 +914,7 @@ impl PromptStreamContext {
|
|||
.prompt_guards
|
||||
.jailbreak_on_exception_message()
|
||||
.unwrap_or("refrain from discussing jailbreaking.");
|
||||
debug!("jailbreak detected: {}", msg);
|
||||
return self.send_server_error(
|
||||
ServerError::Jailbreak(String::from(msg)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
|
|
@ -912,6 +938,7 @@ impl PromptStreamContext {
|
|||
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
debug!("error serializing get embeddings request: {}", error);
|
||||
return self.send_server_error(ServerError::Deserialization(error), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -948,6 +975,7 @@ impl PromptStreamContext {
|
|||
};
|
||||
|
||||
if let Err(e) = self.http_call(call_args, call_context) {
|
||||
debug!("error dispatching get embeddings request: {}", e);
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
}
|
||||
|
|
@ -981,6 +1009,7 @@ impl PromptStreamContext {
|
|||
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
|
||||
Ok(chat_completions_resp) => chat_completions_resp,
|
||||
Err(e) => {
|
||||
debug!("error deserializing default target response: {}", e);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -1000,6 +1029,7 @@ impl PromptStreamContext {
|
|||
content: Some(system_prompt.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
|
|
@ -1010,6 +1040,7 @@ impl PromptStreamContext {
|
|||
content: Some(api_resp.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
let chat_completion_request = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
|
|
@ -1027,7 +1058,7 @@ impl PromptStreamContext {
|
|||
}
|
||||
|
||||
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
|
||||
impl HttpContext for PromptStreamContext {
|
||||
impl HttpContext for StreamContext {
|
||||
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
|
||||
// the lifecycle of the http request and response.
|
||||
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
|
|
@ -1090,7 +1121,6 @@ impl HttpContext for PromptStreamContext {
|
|||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
self.is_chat_completions_request = true;
|
||||
|
||||
self.arch_state = match deserialized_body.metadata {
|
||||
Some(ref metadata) => {
|
||||
|
|
@ -1256,9 +1286,8 @@ impl HttpContext for PromptStreamContext {
|
|||
match serde_json::from_slice(&body) {
|
||||
Ok(de) => de,
|
||||
Err(e) => {
|
||||
debug!("invalid response: {}", String::from_utf8_lossy(&body));
|
||||
self.send_server_error(ServerError::Deserialization(e), None);
|
||||
return Action::Pause;
|
||||
debug!("invalid response: {}, {}", String::from_utf8_lossy(&body), e);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -1276,55 +1305,42 @@ impl HttpContext for PromptStreamContext {
|
|||
self.arch_state = Some(Vec::new());
|
||||
}
|
||||
|
||||
// compute sha hash from message history
|
||||
let mut hasher = Sha256::new();
|
||||
let prompts: Vec<String> = self
|
||||
.chat_completions_request
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role == USER_ROLE)
|
||||
.map(|msg| msg.content.clone().unwrap())
|
||||
.collect();
|
||||
let prompts_merged = prompts.join("#.#");
|
||||
hasher.update(prompts_merged.clone());
|
||||
let hash_key = hasher.finalize();
|
||||
// conver hash to hex string
|
||||
let hash_key_str = format!("{:x}", hash_key);
|
||||
debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged);
|
||||
|
||||
// create new tool call state
|
||||
let tool_call_state = ToolCallState {
|
||||
key: hash_key_str,
|
||||
message: self.user_prompt.clone(),
|
||||
tool_call: tool_calls[0].function.clone(),
|
||||
tool_response: self.tool_call_response.clone().unwrap(),
|
||||
};
|
||||
|
||||
// push tool call state to arch state
|
||||
self.arch_state
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.push(ArchState::ToolCall(vec![tool_call_state]));
|
||||
|
||||
let mut data: Value = serde_json::from_slice(&body).unwrap();
|
||||
// use serde::Value to manipulate the json object and ensure that we don't lose any data
|
||||
if let Value::Object(ref mut map) = data {
|
||||
// serialize arch state and add to metadata
|
||||
let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
|
||||
debug!("arch_state: {}", arch_state_str);
|
||||
let metadata = map
|
||||
.entry("metadata")
|
||||
.or_insert(Value::Object(serde_json::Map::new()));
|
||||
if metadata == &Value::Null {
|
||||
*metadata = Value::Object(serde_json::Map::new());
|
||||
}
|
||||
|
||||
// since arch gateway generates tool calls (using arch-fc) and calls upstream api to
|
||||
// get response, we will send these back to developer so they can see the api response
|
||||
// and tool call arch-fc generated
|
||||
let mut fc_messages = Vec::new();
|
||||
fc_messages.push(Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: None,
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
tool_call_id: None,
|
||||
});
|
||||
fc_messages.push(Message {
|
||||
role: TOOL_ROLE.to_string(),
|
||||
content: self.tool_call_response.clone(),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
|
||||
});
|
||||
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
|
||||
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
|
||||
let arch_state_str = serde_json::to_string(&arch_state).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
ARCH_STATE_HEADER.to_string(),
|
||||
serde_json::Value::String(arch_state_str),
|
||||
);
|
||||
|
||||
let data_serialized = serde_json::to_string(&data).unwrap();
|
||||
debug!("arch => user: {}", data_serialized);
|
||||
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
|
||||
|
|
@ -1342,7 +1358,7 @@ impl HttpContext for PromptStreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
impl Context for PromptStreamContext {
|
||||
impl Context for StreamContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
|
|
@ -1388,7 +1404,7 @@ impl Context for PromptStreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
impl Client for PromptStreamContext {
|
||||
impl Client for StreamContext {
|
||||
type CallContext = StreamCallContext;
|
||||
|
||||
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
|
||||
|
|
@ -487,7 +487,6 @@ fn bad_request_to_open_ai_chat_completions() {
|
|||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(incomplete_chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_send_local_response(
|
||||
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
||||
None,
|
||||
|
|
@ -547,6 +546,7 @@ fn request_to_llm_gateway() {
|
|||
},
|
||||
}]),
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
}],
|
||||
model: String::from("test"),
|
||||
|
|
@ -648,6 +648,7 @@ fn request_to_llm_gateway() {
|
|||
content: Some("hello from fake llm gateway".to_string()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
}],
|
||||
model: String::from("test"),
|
||||
|
|
@ -666,8 +667,6 @@ fn request_to_llm_gateway() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
|
|
|
|||
3
model_server/.vscode/settings.json
vendored
Normal file
3
model_server/.vscode/settings.json
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"python.defaultInterpreterPath": "${workspaceFolder}/venv/bin/python",
|
||||
}
|
||||
|
|
@ -13,62 +13,38 @@ logger = get_model_server_logger()
|
|||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
content: str = ""
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
tool_call_id: str = ""
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
messages: list[Message]
|
||||
tools: List[Dict[str, Any]]
|
||||
|
||||
# TODO: make it default none
|
||||
metadata: Dict[str, str] = {}
|
||||
|
||||
|
||||
def process_state(arch_state, history: list[Message]):
|
||||
logger.info("state: {}".format(arch_state))
|
||||
state_json = json.loads(arch_state)
|
||||
|
||||
state_map = {}
|
||||
if state_json:
|
||||
for tools_state in state_json:
|
||||
for tool_state in tools_state:
|
||||
state_map[tool_state["key"]] = tool_state
|
||||
|
||||
logger.info(f"state_map: {json.dumps(state_map)}")
|
||||
|
||||
sha_history = []
|
||||
def process_messages(history: list[Message]):
|
||||
updated_history = []
|
||||
for hist in history:
|
||||
updated_history.append({"role": hist.role, "content": hist.content})
|
||||
if hist.role == "user":
|
||||
sha_history.append(hist.content)
|
||||
sha256_hash = hashlib.sha256()
|
||||
joined_key_str = ("#.#").join(sha_history)
|
||||
sha256_hash.update(joined_key_str.encode())
|
||||
sha_key = sha256_hash.hexdigest()
|
||||
logger.info(f"sha_key: {sha_key}")
|
||||
if sha_key in state_map:
|
||||
tool_call_state = state_map[sha_key]
|
||||
if "tool_call" in tool_call_state:
|
||||
tool_call_str = json.dumps(tool_call_state["tool_call"])
|
||||
updated_history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
|
||||
}
|
||||
)
|
||||
if "tool_response" in tool_call_state:
|
||||
tool_resp = tool_call_state["tool_response"]
|
||||
# TODO: try with role = user as well
|
||||
updated_history.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"<tool_response>\n{tool_resp}\n</tool_response>",
|
||||
}
|
||||
)
|
||||
# we dont want to match this state with any other messages
|
||||
del state_map[sha_key]
|
||||
|
||||
if hist.tool_calls:
|
||||
if len(hist.tool_calls) > 1:
|
||||
raise ValueError("Only one tool call is supported")
|
||||
tool_call_str = json.dumps(hist.tool_calls[0]["function"])
|
||||
updated_history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
|
||||
}
|
||||
)
|
||||
elif hist.role == "tool":
|
||||
updated_history.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"<tool_response>\n{hist.content}\n</tool_response>",
|
||||
}
|
||||
)
|
||||
else:
|
||||
updated_history.append({"role": hist.role, "content": hist.content})
|
||||
return updated_history
|
||||
|
||||
|
||||
|
|
@ -79,10 +55,7 @@ async def chat_completion(req: ChatMessage, res: Response):
|
|||
|
||||
messages = [{"role": "system", "content": tools_encoded}]
|
||||
|
||||
metadata = req.metadata
|
||||
arch_state = metadata.get("x-arch-state", "[]")
|
||||
|
||||
updated_history = process_state(arch_state, req.messages)
|
||||
updated_history = process_messages(req.messages)
|
||||
for message in updated_history:
|
||||
messages.append({"role": message["role"], "content": message["content"]})
|
||||
|
||||
|
|
|
|||
66
model_server/app/tests/test_state.py
Normal file
66
model_server/app/tests/test_state.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
from typing import List
|
||||
import pytest
|
||||
import json
|
||||
from app.function_calling.model_utils import Message, process_messages
|
||||
|
||||
test_input_history = """
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "how is the weather in chicago for next 5 days?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"model": "Arch-Function-1.5B",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_3394",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": { "city": "Chicago", "days": 5 }
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "--",
|
||||
"tool_call_id": "call_3394"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "--",
|
||||
"model": "gpt-3.5-turbo-0125"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "how is the weather in chicago for next 5 days?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_5306",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": { "city": "Chicago", "days": 5 }
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
def test_update_fc_history():
|
||||
history = json.loads(test_input_history)
|
||||
message_history = []
|
||||
for h in history:
|
||||
message_history.append(Message(**h))
|
||||
|
||||
updated_history = process_messages(message_history)
|
||||
assert len(updated_history) == 6
|
||||
# ensure that tool role does not exist anymore
|
||||
assert all([h["role"] != "tool" for h in updated_history])
|
||||
Loading…
Add table
Add a link
Reference in a new issue