Merge branch 'main' into adil/dashboard_update

This commit is contained in:
Adil Hafeez 2024-10-18 14:28:50 -07:00
commit e2d9b6a008
26 changed files with 320 additions and 257 deletions

View file

@ -19,12 +19,12 @@ jobs:
- name: Setup | Install wasm toolchain
run: rustup target add wasm32-wasi
- name: Build wasm module for prompt_gateway
run: cd crates/prompt_gateway && cargo build --release --target=wasm32-wasi
- name: Run Tests on common crate
run: cd crates/common && cargo test
- name: Build wasm module for prompt_gateway
run: cd crates/prompt_gateway && cargo build --release --target=wasm32-wasi
- name: Run Tests on prompt_gateway crate
run: cd crates/prompt_gateway && cargo test

View file

@ -12,8 +12,8 @@ FROM envoyproxy/envoy:v1.31-latest as envoy
#Build config generator, so that we have a single build image for both Rust and Python
FROM python:3-slim as arch
COPY --from=builder /arch/prompt_gateway/target/wasm32-wasi/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm
COPY --from=builder /arch/llm_gateway/target/wasm32-wasi/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm
COPY --from=builder /arch/target/wasm32-wasi/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm
COPY --from=builder /arch/target/wasm32-wasi/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm
COPY --from=envoy /usr/local/bin/envoy /usr/local/bin/envoy
WORKDIR /config
COPY arch/requirements.txt .

3
arch/tools/.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,3 @@
{
"python.defaultInterpreterPath": "${workspaceFolder}/venv/bin/python",
}

View file

@ -5,19 +5,11 @@
"path": "."
},
{
"name": "common",
"path": "crates/common"
"name": "crates",
"path": "crates"
},
{
"name": "prompt_gateway",
"path": "crates/prompt_gateway"
},
{
"name": "llm_gateway",
"path": "crates/llm_gateway"
},
{
"name": "arch/tools",
"name": "archgw_cli",
"path": "arch/tools"
},
{
@ -36,10 +28,15 @@
"name": "demos/insurance_agent",
"path": "./demos/insurance_agent",
},
{
"name": "demos/function_calling/api_server",
"path": "./demos/function_calling/api_server",
},
],
"settings": {}
"settings": {
},
"extensions": {
"recommendations": [
"ms-python.python",
"ms-python.debugpy",
"rust-lang.rust-analyzer",
"humao.rest-client"
]
}
}

View file

@ -5,6 +5,7 @@
"version": "0.2.0",
"configurations": [
{
"python": "${workspaceFolder}/venv/bin/python",
"name": "chatbot-ui",
"cwd": "${workspaceFolder}/app",
"type": "debugpy",

3
chatbot_ui/.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,3 @@
{
"python.defaultInterpreterPath": "${workspaceFolder}/venv/bin/python",
}

View file

@ -2,14 +2,21 @@ import json
import os
from openai import OpenAI, DefaultHttpxClient
import gradio as gr
import logging as log
import logging
from dotenv import load_dotenv
load_dotenv()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
log = logging.getLogger(__name__)
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
ARCH_STATE_HEADER = "x-arch-state"
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
log.info(f"CHAT_COMPLETION_ENDPOINT: {CHAT_COMPLETION_ENDPOINT}")
client = OpenAI(
api_key="--",
@ -23,23 +30,19 @@ def predict(message, state):
state["history"] = []
history = state.get("history")
history.append({"role": "user", "content": message})
log.info("history: ", history)
log.info(f"history: {history}")
# Custom headers
custom_headers = {
"x-arch-deterministic-provider": "openai",
}
metadata = None
if "arch_state" in state:
metadata = {ARCH_STATE_HEADER: state["arch_state"]}
try:
raw_response = client.chat.completions.with_raw_response.create(
model="--",
messages=history,
temperature=1.0,
metadata=metadata,
# metadata=metadata,
extra_headers=custom_headers,
)
except Exception as e:
@ -49,26 +52,35 @@ def predict(message, state):
log.info("Error calling gateway API: {}".format(e.message))
raise gr.Error("Error calling gateway API: {}".format(e.message))
log.info("raw_response: ", raw_response.text)
log.error(f"raw_response: {raw_response.text}")
response = raw_response.parse()
# extract arch_state from metadata and store it in gradio session state
# this state must be passed back to the gateway in the next request
response_json = json.loads(raw_response.text)
arch_state = None
if response_json:
metadata = response_json.get("metadata", {})
if metadata:
arch_state = metadata.get(ARCH_STATE_HEADER, None)
if arch_state:
state["arch_state"] = arch_state
# load arch_state from metadata
arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}")
# parse arch_state into json object
arch_state = json.loads(arch_state_str)
# load messages from arch_state
arch_messages_str = arch_state.get("messages", "[]")
# parse messages into json object
arch_messages = json.loads(arch_messages_str)
# append messages from arch gateway to history
for message in arch_messages:
history.append(message)
content = response.choices[0].message.content
history.append({"role": "assistant", "content": content, "model": response.model})
# for gradio UI we don't want to show raw tool calls and messages from developer application
# so we're filtering those out
history_view = [h for h in history if h["role"] != "tool" and "content" in h]
messages = [
(history[i]["content"], history[i + 1]["content"])
for i in range(0, len(history) - 1, 2)
(history_view[i]["content"], history_view[i + 1]["content"])
for i in range(0, len(history_view) - 1, 2)
]
return messages, state

View file

@ -14,6 +14,7 @@ derivative = "2.2.0"
thiserror = "1.0.64"
tiktoken-rs = "0.5.9"
rand = "0.8.5"
serde_json = "1.0"
[dev-dependencies]
pretty_assertions = "1.4.1"

View file

@ -188,6 +188,8 @@ pub mod open_ai {
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -381,6 +383,7 @@ mod test {
content: Some("What city do you want to know the weather for?".to_string()),
model: None,
tool_calls: None,
tool_call_id: None,
}],
tools: Some(vec![super::open_ai::ChatCompletionTool {
tool_type: ToolType::Function,

View file

@ -229,20 +229,6 @@ mod test {
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
assert_eq!(config.version, "v0.1");
let open_ai_provider = config
.llm_providers
.iter()
.find(|p| p.name.to_lowercase() == "openai")
.unwrap();
assert_eq!(open_ai_provider.name.to_lowercase(), "openai");
assert_eq!(
open_ai_provider.access_key,
Some("OPENAI_API_KEY".to_string())
);
assert_eq!(open_ai_provider.model, "gpt-4o");
assert_eq!(open_ai_provider.default, Some(true));
assert_eq!(open_ai_provider.stream, Some(true));
let prompt_guards = config.prompt_guards.as_ref().unwrap();
let input_guards = &prompt_guards.input_guards;
let jailbreak_guard = input_guards.get(&GuardType::Jailbreak).unwrap();

View file

@ -5,6 +5,8 @@ pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25;
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";
pub const TOOL_ROLE: &str = "tool";
pub const ASSISTANT_ROLE: &str = "assistant";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const MODEL_SERVER_NAME: &str = "model_server";
@ -16,7 +18,7 @@ pub const GUARD_INTERNAL_HOST: &str = "guard";
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
pub const ARCH_STATE_HEADER: &str = "x-arch-state";
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B";
pub const REQUEST_ID_HEADER: &str = "x-request-id";

View file

@ -0,0 +1,39 @@
use proxy_wasm::types::Status;
use crate::ratelimit;
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
#[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")]
DispatchError {
upstream_name: String,
path: String,
internal_status: Status,
},
}
#[derive(thiserror::Error, Debug)]
pub enum ServerError {
#[error(transparent)]
HttpDispatch(ClientError),
#[error(transparent)]
Deserialization(serde_json::Error),
#[error(transparent)]
Serialization(serde_json::Error),
#[error("{0}")]
LogicError(String),
#[error("upstream error response authority={authority}, path={path}, status={status}")]
Upstream {
authority: String,
path: String,
status: String,
},
#[error("jailbreak detected: {0}")]
Jailbreak(String),
#[error("{why}")]
NoMessagesFound { why: String },
#[error(transparent)]
ExceededRatelimit(ratelimit::Error),
#[error("{why}")]
BadRequest { why: String },
}

View file

@ -1,4 +1,4 @@
use crate::stats::{Gauge, IncrementingMetric};
use crate::{errors::ClientError, stats::{Gauge, IncrementingMetric}};
use derivative::Derivative;
use log::debug;
use proxy_wasm::{traits::Context, types::Status};
@ -37,16 +37,6 @@ impl<'a> CallArgs<'a> {
}
}
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
#[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")]
DispatchError {
upstream_name: String,
path: String,
internal_status: Status,
},
}
pub trait Client: Context {
type CallContext: Debug;

View file

@ -10,3 +10,4 @@ pub mod ratelimit;
pub mod routing;
pub mod stats;
pub mod tokenizer;
pub mod errors;

View file

@ -228,6 +228,7 @@ dependencies = [
"proxy-wasm",
"rand",
"serde",
"serde_json",
"serde_yaml",
"thiserror",
"tiktoken-rs",

View file

@ -1,4 +1,4 @@
use crate::llm_stream_context::LlmGatewayStreamContext;
use crate::stream_context::StreamContext;
use common::configuration::Configuration;
use common::http::Client;
use common::llm_providers::LlmProviders;
@ -28,19 +28,19 @@ impl WasmMetrics {
}
#[derive(Debug)]
pub struct FilterCallContext {}
pub struct CallContext {}
#[derive(Debug)]
pub struct LlmGatewayFilterContext {
pub struct FilterContext {
metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, FilterCallContext>>,
callouts: RefCell<HashMap<u32, CallContext>>,
llm_providers: Option<Rc<LlmProviders>>,
}
impl LlmGatewayFilterContext {
pub fn new() -> LlmGatewayFilterContext {
LlmGatewayFilterContext {
impl FilterContext {
pub fn new() -> FilterContext {
FilterContext {
callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(WasmMetrics::new()),
llm_providers: None,
@ -48,8 +48,8 @@ impl LlmGatewayFilterContext {
}
}
impl Client for LlmGatewayFilterContext {
type CallContext = FilterCallContext;
impl Client for FilterContext {
type CallContext = CallContext;
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
&self.callouts
@ -60,10 +60,10 @@ impl Client for LlmGatewayFilterContext {
}
}
impl Context for LlmGatewayFilterContext {}
impl Context for FilterContext {}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for LlmGatewayFilterContext {
impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool {
let config_bytes = self
.get_plugin_configuration()
@ -90,8 +90,7 @@ impl RootContext for LlmGatewayFilterContext {
context_id
);
// No StreamContext can be created until the Embedding Store is fully initialized.
Some(Box::new(LlmGatewayStreamContext::new(
Some(Box::new(StreamContext::new(
context_id,
Rc::clone(&self.metrics),
Rc::clone(

View file

@ -1,13 +1,13 @@
use llm_filter_context::LlmGatewayFilterContext;
use filter_context::FilterContext;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod llm_filter_context;
mod llm_stream_context;
mod filter_context;
mod stream_context;
proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(LlmGatewayFilterContext::new())
Box::new(FilterContext::new())
});
}}

View file

@ -1,4 +1,4 @@
use crate::llm_filter_context::WasmMetrics;
use crate::filter_context::WasmMetrics;
use common::common_types::open_ai::{
ArchState, ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse,
Message, ToolCall, ToolCallState,
@ -8,6 +8,7 @@ use common::consts::{
ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, CHAT_COMPLETIONS_PATH,
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, USER_ROLE,
};
use common::errors::ServerError;
use common::llm_providers::LlmProviders;
use common::ratelimit::Header;
use common::{ratelimit, routing, tokenizer};
@ -22,25 +23,12 @@ use std::rc::Rc;
use common::stats::IncrementingMetric;
#[derive(thiserror::Error, Debug)]
pub enum ServerError {
#[error(transparent)]
Deserialization(serde_json::Error),
#[error("{0}")]
LogicError(String),
#[error(transparent)]
ExceededRatelimit(ratelimit::Error),
#[error("{why}")]
BadRequest { why: String },
}
pub struct LlmGatewayStreamContext {
pub struct StreamContext {
context_id: u32,
metrics: Rc<WasmMetrics>,
tool_calls: Option<Vec<ToolCall>>,
tool_call_response: Option<String>,
arch_state: Option<Vec<ArchState>>,
request_body_size: usize,
ratelimit_selector: Option<Header>,
streaming_response: bool,
user_prompt: Option<Message>,
@ -52,17 +40,15 @@ pub struct LlmGatewayStreamContext {
request_id: Option<String>,
}
impl LlmGatewayStreamContext {
#[allow(clippy::too_many_arguments)]
impl StreamContext {
pub fn new(context_id: u32, metrics: Rc<WasmMetrics>, llm_providers: Rc<LlmProviders>) -> Self {
LlmGatewayStreamContext {
StreamContext {
context_id,
metrics,
chat_completions_request: None,
tool_calls: None,
tool_call_response: None,
arch_state: None,
request_body_size: 0,
ratelimit_selector: None,
streaming_response: false,
user_prompt: None,
@ -160,7 +146,7 @@ impl LlmGatewayStreamContext {
}
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
impl HttpContext for LlmGatewayStreamContext {
impl HttpContext for StreamContext {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
@ -198,8 +184,6 @@ impl HttpContext for LlmGatewayStreamContext {
return Action::Continue;
}
self.request_body_size = body_size;
// Deserialize body into spec.
// Currently OpenAI API.
let mut deserialized_body: ChatCompletionsRequest =
@ -225,7 +209,6 @@ impl HttpContext for LlmGatewayStreamContext {
return Action::Pause;
}
};
self.is_chat_completions_request = true;
// remove metadata from the request body
deserialized_body.metadata = None;
@ -333,10 +316,9 @@ impl HttpContext for LlmGatewayStreamContext {
let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_slice(&body) {
Ok(de) => de,
Err(e) => {
Err(_e) => {
debug!("invalid response: {}", String::from_utf8_lossy(&body));
self.send_server_error(ServerError::Deserialization(e), None);
return Action::Pause;
return Action::Continue;
}
};
@ -418,4 +400,4 @@ impl HttpContext for LlmGatewayStreamContext {
}
}
impl Context for LlmGatewayStreamContext {}
impl Context for StreamContext {}

View file

@ -228,6 +228,7 @@ dependencies = [
"proxy-wasm",
"rand",
"serde",
"serde_json",
"serde_yaml",
"thiserror",
"tiktoken-rs",

View file

@ -1,7 +1,7 @@
use crate::prompt_stream_context::PromptStreamContext;
use crate::stream_context::StreamContext;
use common::common_types::EmbeddingType;
use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget};
use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST};
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget};
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use common::consts::DEFAULT_EMBEDDING_MODEL;
use common::embeddings::{
@ -9,7 +9,6 @@ use common::embeddings::{
};
use common::http::CallArgs;
use common::http::Client;
use common::llm_providers::LlmProviders;
use common::stats::Gauge;
use common::stats::IncrementingMetric;
use log::debug;
@ -44,31 +43,27 @@ pub struct FilterCallContext {
}
#[derive(Debug)]
pub struct PromptGatewayFilterContext {
pub struct FilterContext {
metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, FilterCallContext>>,
overrides: Rc<Option<Overrides>>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
mode: GatewayMode,
prompt_guards: Rc<PromptGuards>,
llm_providers: Option<Rc<LlmProviders>>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
temp_embeddings_store: EmbeddingsStore,
}
impl PromptGatewayFilterContext {
pub fn new() -> PromptGatewayFilterContext {
PromptGatewayFilterContext {
impl FilterContext {
pub fn new() -> FilterContext {
FilterContext {
callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(WasmMetrics::new()),
system_prompt: Rc::new(None),
prompt_targets: Rc::new(HashMap::new()),
overrides: Rc::new(None),
prompt_guards: Rc::new(PromptGuards::default()),
mode: GatewayMode::Prompt,
llm_providers: None,
embeddings_store: Some(Rc::new(HashMap::new())),
temp_embeddings_store: HashMap::new(),
}
@ -116,7 +111,7 @@ impl PromptGatewayFilterContext {
Duration::from_secs(60),
);
let call_context = crate::prompt_filter_context::FilterCallContext {
let call_context = crate::filter_context::FilterCallContext {
prompt_target_name: String::from(prompt_target_name),
embedding_type,
};
@ -193,7 +188,7 @@ impl PromptGatewayFilterContext {
}
}
impl Client for PromptGatewayFilterContext {
impl Client for FilterContext {
type CallContext = FilterCallContext;
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
@ -205,7 +200,7 @@ impl Client for PromptGatewayFilterContext {
}
}
impl Context for PromptGatewayFilterContext {
impl Context for FilterContext {
fn on_http_call_response(
&mut self,
token_id: u32,
@ -234,7 +229,7 @@ impl Context for PromptGatewayFilterContext {
}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for PromptGatewayFilterContext {
impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool {
let config_bytes = self
.get_plugin_configuration()
@ -253,17 +248,11 @@ impl RootContext for PromptGatewayFilterContext {
}
self.system_prompt = Rc::new(config.system_prompt);
self.prompt_targets = Rc::new(prompt_targets);
self.mode = config.mode.unwrap_or_default();
if let Some(prompt_guards) = config.prompt_guards {
self.prompt_guards = Rc::new(prompt_guards)
}
match config.llm_providers.try_into() {
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
Err(err) => panic!("{err}"),
}
true
}
@ -273,12 +262,11 @@ impl RootContext for PromptGatewayFilterContext {
context_id
);
// No StreamContext can be created until the Embedding Store is fully initialized.
let embedding_store = match self.mode {
GatewayMode::Llm => None,
GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())),
let embedding_store = match self.embeddings_store.as_ref() {
None => return None,
Some(store) => Some(Rc::clone(store)),
};
Some(Box::new(PromptStreamContext::new(
Some(Box::new(StreamContext::new(
context_id,
Rc::clone(&self.metrics),
Rc::clone(&self.system_prompt),
@ -299,11 +287,8 @@ impl RootContext for PromptGatewayFilterContext {
}
fn on_tick(&mut self) {
debug!("starting up arch filter in mode: {:?}", self.mode);
if self.mode == GatewayMode::Prompt {
self.process_prompt_targets();
}
debug!("starting up arch filter in mode: prompt gateway mode");
self.process_prompt_targets();
self.set_tick_period(Duration::from_secs(0));
}
}

View file

@ -1,13 +1,13 @@
use prompt_filter_context::PromptGatewayFilterContext;
use filter_context::FilterContext;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod prompt_filter_context;
mod prompt_stream_context;
mod filter_context;
mod stream_context;
proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(PromptGatewayFilterContext::new())
Box::new(FilterContext::new())
});
}}

View file

@ -1,9 +1,9 @@
use crate::prompt_filter_context::{EmbeddingsStore, WasmMetrics};
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use acap::cos;
use common::common_types::open_ai::{
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType,
StreamOptions, ToolCall, ToolCallState, ToolType,
StreamOptions, ToolCall, ToolType,
};
use common::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
@ -12,23 +12,25 @@ use common::common_types::{
};
use common::configuration::{Overrides, PromptGuards, PromptTarget};
use common::consts::{
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARCH_FC_INTERNAL_HOST, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, GUARD_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, GUARD_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST
};
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use common::http::{CallArgs, Client, ClientError};
use common::errors::ClientError;
use common::http::{CallArgs, Client};
use common::stats::Gauge;
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use serde_json::Value;
use sha2::{Digest, Sha256};
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use std::str::FromStr;
use std::time::Duration;
use derivative::Derivative;
use common::stats::IncrementingMetric;
@ -43,11 +45,13 @@ enum ResponseHandlerType {
DefaultTarget,
}
#[derive(Debug, Clone)]
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub struct StreamCallContext {
response_handler_type: ResponseHandlerType,
user_message: Option<String>,
prompt_target_name: Option<String>,
#[derivative(Debug = "ignore")]
request_body: ChatCompletionsRequest,
tool_calls: Option<Vec<ToolCall>>,
similarity_scores: Option<Vec<(String, f64)>>,
@ -65,11 +69,12 @@ pub enum ServerError {
Serialization(serde_json::Error),
#[error("{0}")]
LogicError(String),
#[error("upstream error response authority={authority}, path={path}, status={status}")]
#[error("upstream application error host={host}, path={path}, status={status}, body={body}")]
Upstream {
authority: String,
host: String,
path: String,
status: String,
body: String,
},
#[error("jailbreak detected: {0}")]
Jailbreak(String),
@ -77,7 +82,7 @@ pub enum ServerError {
NoMessagesFound { why: String },
}
pub struct PromptStreamContext {
pub struct StreamContext {
context_id: u32,
metrics: Rc<WasmMetrics>,
system_prompt: Rc<Option<String>>,
@ -98,8 +103,7 @@ pub struct PromptStreamContext {
request_id: Option<String>,
}
impl PromptStreamContext {
#[allow(clippy::too_many_arguments)]
impl StreamContext {
pub fn new(
context_id: u32,
metrics: Rc<WasmMetrics>,
@ -109,7 +113,7 @@ impl PromptStreamContext {
overrides: Rc<Option<Overrides>>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
) -> Self {
PromptStreamContext {
StreamContext {
context_id,
metrics,
system_prompt,
@ -145,7 +149,6 @@ impl PromptStreamContext {
}
fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
debug!("server error occurred: {}", error);
self.send_http_response(
override_status_code
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
@ -160,6 +163,7 @@ impl PromptStreamContext {
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
Ok(embedding_response) => embedding_response,
Err(e) => {
debug!("error deserializing embedding response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -230,6 +234,7 @@ impl PromptStreamContext {
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
debug!("error serializing zero shot classification request: {}", error);
return self.send_server_error(ServerError::Serialization(error), None);
}
};
@ -259,6 +264,7 @@ impl PromptStreamContext {
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching zero shot classification request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
@ -272,6 +278,7 @@ impl PromptStreamContext {
match serde_json::from_slice(&body) {
Ok(hallucination_response) => hallucination_response,
Err(e) => {
debug!("error deserializing hallucination response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -297,6 +304,7 @@ impl PromptStreamContext {
content: Some(response),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
};
let chat_completion_response = ChatCompletionsResponse {
@ -335,6 +343,7 @@ impl PromptStreamContext {
match serde_json::from_slice(&body) {
Ok(zeroshot_response) => zeroshot_response,
Err(e) => {
debug!("error deserializing zero shot classification response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -446,6 +455,7 @@ impl PromptStreamContext {
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching default prompt target request: {}", e);
return self.send_server_error(
ServerError::HttpDispatch(e),
Some(StatusCode::BAD_REQUEST),
@ -461,6 +471,7 @@ impl PromptStreamContext {
let prompt_target = match self.prompt_targets.get(&prompt_target_name) {
Some(prompt_target) => prompt_target.clone(),
None => {
debug!("prompt target not found: {}", prompt_target_name);
return self.send_server_error(
ServerError::LogicError(format!(
"Prompt target not found: {prompt_target_name}"
@ -533,6 +544,7 @@ impl PromptStreamContext {
msg_body
}
Err(e) => {
debug!("error serializing arch_fc request body: {}", e);
return self.send_server_error(ServerError::Serialization(e), None);
}
};
@ -565,6 +577,7 @@ impl PromptStreamContext {
callout_context.prompt_target_name = Some(prompt_target.name);
if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching arch_fc request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
}
}
@ -576,6 +589,7 @@ impl PromptStreamContext {
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
Ok(arch_fc_response) => arch_fc_response,
Err(e) => {
debug!("error deserializing arch_fc response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -689,6 +703,7 @@ impl PromptStreamContext {
match serde_json::to_string(&hallucination_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
debug!("error serializing hallucination classification request: {}", error);
return self.send_server_error(ServerError::Serialization(error), None);
}
};
@ -781,17 +796,19 @@ impl PromptStreamContext {
fn function_call_response_handler(
&mut self,
body: Vec<u8>,
mut callout_context: StreamCallContext,
callout_context: StreamCallContext,
) {
if let Some(http_status) = self.get_http_call_response_header(":status") {
if http_status != StatusCode::OK.as_str() {
debug!("upstream error response: {}", http_status);
return self.send_server_error(
ServerError::Upstream {
authority: callout_context.upstream_cluster.unwrap(),
host: callout_context.upstream_cluster.unwrap(),
path: callout_context.upstream_cluster_path.unwrap(),
status: http_status,
status: http_status.clone(),
body: String::from_utf8(body).unwrap(),
},
None,
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
);
}
} else {
@ -823,11 +840,18 @@ impl PromptStreamContext {
content: system_prompt,
model: None,
tool_calls: None,
tool_call_id: None,
};
messages.push(system_prompt_message);
}
messages.append(callout_context.request_body.messages.as_mut());
// don't send tools message and api response to chat gpt
for m in callout_context.request_body.messages.iter() {
if m.role == TOOL_ROLE || m.content.is_none() {
continue;
}
messages.push(m.clone());
}
let user_message = match messages.pop() {
Some(user_message) => user_message,
@ -854,6 +878,7 @@ impl PromptStreamContext {
content: Some(final_prompt),
model: None,
tool_calls: None,
tool_call_id: None,
}
});
@ -889,6 +914,7 @@ impl PromptStreamContext {
.prompt_guards
.jailbreak_on_exception_message()
.unwrap_or("refrain from discussing jailbreaking.");
debug!("jailbreak detected: {}", msg);
return self.send_server_error(
ServerError::Jailbreak(String::from(msg)),
Some(StatusCode::BAD_REQUEST),
@ -912,6 +938,7 @@ impl PromptStreamContext {
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
debug!("error serializing get embeddings request: {}", error);
return self.send_server_error(ServerError::Deserialization(error), None);
}
};
@ -948,6 +975,7 @@ impl PromptStreamContext {
};
if let Err(e) = self.http_call(call_args, call_context) {
debug!("error dispatching get embeddings request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
@ -981,6 +1009,7 @@ impl PromptStreamContext {
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
Ok(chat_completions_resp) => chat_completions_resp,
Err(e) => {
debug!("error deserializing default target response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -1000,6 +1029,7 @@ impl PromptStreamContext {
content: Some(system_prompt.clone()),
model: None,
tool_calls: None,
tool_call_id: None,
};
messages.push(system_prompt_message);
}
@ -1010,6 +1040,7 @@ impl PromptStreamContext {
content: Some(api_resp.clone()),
model: None,
tool_calls: None,
tool_call_id: None,
});
let chat_completion_request = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
@ -1027,7 +1058,7 @@ impl PromptStreamContext {
}
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
impl HttpContext for PromptStreamContext {
impl HttpContext for StreamContext {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
@ -1090,7 +1121,6 @@ impl HttpContext for PromptStreamContext {
return Action::Pause;
}
};
self.is_chat_completions_request = true;
self.arch_state = match deserialized_body.metadata {
Some(ref metadata) => {
@ -1256,9 +1286,8 @@ impl HttpContext for PromptStreamContext {
match serde_json::from_slice(&body) {
Ok(de) => de,
Err(e) => {
debug!("invalid response: {}", String::from_utf8_lossy(&body));
self.send_server_error(ServerError::Deserialization(e), None);
return Action::Pause;
debug!("invalid response: {}, {}", String::from_utf8_lossy(&body), e);
return Action::Continue;
}
};
@ -1276,55 +1305,42 @@ impl HttpContext for PromptStreamContext {
self.arch_state = Some(Vec::new());
}
// compute sha hash from message history
let mut hasher = Sha256::new();
let prompts: Vec<String> = self
.chat_completions_request
.as_ref()
.unwrap()
.messages
.iter()
.filter(|msg| msg.role == USER_ROLE)
.map(|msg| msg.content.clone().unwrap())
.collect();
let prompts_merged = prompts.join("#.#");
hasher.update(prompts_merged.clone());
let hash_key = hasher.finalize();
// conver hash to hex string
let hash_key_str = format!("{:x}", hash_key);
debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged);
// create new tool call state
let tool_call_state = ToolCallState {
key: hash_key_str,
message: self.user_prompt.clone(),
tool_call: tool_calls[0].function.clone(),
tool_response: self.tool_call_response.clone().unwrap(),
};
// push tool call state to arch state
self.arch_state
.as_mut()
.unwrap()
.push(ArchState::ToolCall(vec![tool_call_state]));
let mut data: Value = serde_json::from_slice(&body).unwrap();
// use serde::Value to manipulate the json object and ensure that we don't lose any data
if let Value::Object(ref mut map) = data {
// serialize arch state and add to metadata
let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
debug!("arch_state: {}", arch_state_str);
let metadata = map
.entry("metadata")
.or_insert(Value::Object(serde_json::Map::new()));
if metadata == &Value::Null {
*metadata = Value::Object(serde_json::Map::new());
}
// since arch gateway generates tool calls (using arch-fc) and calls upstream api to
// get response, we will send these back to developer so they can see the api response
// and tool call arch-fc generated
let mut fc_messages = Vec::new();
fc_messages.push(Message {
role: ASSISTANT_ROLE.to_string(),
content: None,
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: self.tool_calls.clone(),
tool_call_id: None,
});
fc_messages.push(Message {
role: TOOL_ROLE.to_string(),
content: self.tool_call_response.clone(),
model: None,
tool_calls: None,
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
});
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
let arch_state_str = serde_json::to_string(&arch_state).unwrap();
metadata.as_object_mut().unwrap().insert(
ARCH_STATE_HEADER.to_string(),
serde_json::Value::String(arch_state_str),
);
let data_serialized = serde_json::to_string(&data).unwrap();
debug!("arch => user: {}", data_serialized);
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
@ -1342,7 +1358,7 @@ impl HttpContext for PromptStreamContext {
}
}
impl Context for PromptStreamContext {
impl Context for StreamContext {
fn on_http_call_response(
&mut self,
token_id: u32,
@ -1388,7 +1404,7 @@ impl Context for PromptStreamContext {
}
}
impl Client for PromptStreamContext {
impl Client for StreamContext {
type CallContext = StreamCallContext;
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {

View file

@ -487,7 +487,6 @@ fn bad_request_to_open_ai_chat_completions() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(incomplete_chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response(
Some(StatusCode::BAD_REQUEST.as_u16().into()),
None,
@ -547,6 +546,7 @@ fn request_to_llm_gateway() {
},
}]),
model: None,
tool_call_id: None,
},
}],
model: String::from("test"),
@ -648,6 +648,7 @@ fn request_to_llm_gateway() {
content: Some("hello from fake llm gateway".to_string()),
model: None,
tool_calls: None,
tool_call_id: None,
},
}],
model: String::from("test"),
@ -666,8 +667,6 @@ fn request_to_llm_gateway() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::Action(Action::Continue))

3
model_server/.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,3 @@
{
"python.defaultInterpreterPath": "${workspaceFolder}/venv/bin/python",
}

View file

@ -13,62 +13,38 @@ logger = get_model_server_logger()
class Message(BaseModel):
role: str
content: str
content: str = ""
tool_calls: List[Dict[str, Any]] = []
tool_call_id: str = ""
class ChatMessage(BaseModel):
messages: list[Message]
tools: List[Dict[str, Any]]
# TODO: make it default none
metadata: Dict[str, str] = {}
def process_state(arch_state, history: list[Message]):
logger.info("state: {}".format(arch_state))
state_json = json.loads(arch_state)
state_map = {}
if state_json:
for tools_state in state_json:
for tool_state in tools_state:
state_map[tool_state["key"]] = tool_state
logger.info(f"state_map: {json.dumps(state_map)}")
sha_history = []
def process_messages(history: list[Message]):
updated_history = []
for hist in history:
updated_history.append({"role": hist.role, "content": hist.content})
if hist.role == "user":
sha_history.append(hist.content)
sha256_hash = hashlib.sha256()
joined_key_str = ("#.#").join(sha_history)
sha256_hash.update(joined_key_str.encode())
sha_key = sha256_hash.hexdigest()
logger.info(f"sha_key: {sha_key}")
if sha_key in state_map:
tool_call_state = state_map[sha_key]
if "tool_call" in tool_call_state:
tool_call_str = json.dumps(tool_call_state["tool_call"])
updated_history.append(
{
"role": "assistant",
"content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
}
)
if "tool_response" in tool_call_state:
tool_resp = tool_call_state["tool_response"]
# TODO: try with role = user as well
updated_history.append(
{
"role": "user",
"content": f"<tool_response>\n{tool_resp}\n</tool_response>",
}
)
# we dont want to match this state with any other messages
del state_map[sha_key]
if hist.tool_calls:
if len(hist.tool_calls) > 1:
raise ValueError("Only one tool call is supported")
tool_call_str = json.dumps(hist.tool_calls[0]["function"])
updated_history.append(
{
"role": "assistant",
"content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
}
)
elif hist.role == "tool":
updated_history.append(
{
"role": "user",
"content": f"<tool_response>\n{hist.content}\n</tool_response>",
}
)
else:
updated_history.append({"role": hist.role, "content": hist.content})
return updated_history
@ -79,10 +55,7 @@ async def chat_completion(req: ChatMessage, res: Response):
messages = [{"role": "system", "content": tools_encoded}]
metadata = req.metadata
arch_state = metadata.get("x-arch-state", "[]")
updated_history = process_state(arch_state, req.messages)
updated_history = process_messages(req.messages)
for message in updated_history:
messages.append({"role": message["role"], "content": message["content"]})

View file

@ -0,0 +1,66 @@
from typing import List
import pytest
import json
from app.function_calling.model_utils import Message, process_messages
test_input_history = """
[
{
"role": "user",
"content": "how is the weather in chicago for next 5 days?"
},
{
"role": "assistant",
"model": "Arch-Function-1.5B",
"tool_calls": [
{
"id": "call_3394",
"type": "function",
"function": {
"name": "weather_forecast",
"arguments": { "city": "Chicago", "days": 5 }
}
}
]
},
{
"role": "tool",
"content": "--",
"tool_call_id": "call_3394"
},
{
"role": "assistant",
"content": "--",
"model": "gpt-3.5-turbo-0125"
},
{
"role": "user",
"content": "how is the weather in chicago for next 5 days?"
},
{
"role": "assistant",
"tool_calls": [
{
"id": "call_5306",
"type": "function",
"function": {
"name": "weather_forecast",
"arguments": { "city": "Chicago", "days": 5 }
}
}
]
}
]
"""
def test_update_fc_history():
history = json.loads(test_input_history)
message_history = []
for h in history:
message_history.append(Message(**h))
updated_history = process_messages(message_history)
assert len(updated_history) == 6
# ensure that tool role does not exist anymore
assert all([h["role"] != "tool" for h in updated_history])