Clean up Embeddings Store (#121)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-10-04 19:33:52 -07:00 committed by GitHub
parent 10b5c5b42c
commit 2a9b9486f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 366 additions and 238 deletions

View file

@ -1,7 +1,8 @@
use crate::consts::{DEFAULT_EMBEDDING_MODEL, MODEL_SERVER_NAME};
use crate::http::{CallArgs, Client};
use crate::llm_providers::LlmProviders;
use crate::ratelimit;
use crate::stats::{Counter, Gauge, RecordingMetric};
use crate::stats::{Counter, Gauge, IncrementingMetric};
use crate::stream_context::StreamContext;
use log::debug;
use proxy_wasm::traits::*;
@ -11,10 +12,10 @@ use public_types::configuration::{Configuration, Overrides, PromptGuards, Prompt
use public_types::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use serde_json::to_string;
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::{OnceLock, RwLock};
use std::time::Duration;
#[derive(Copy, Clone, Debug)]
@ -32,102 +33,74 @@ impl WasmMetrics {
}
}
#[derive(Debug)]
struct CallContext {
prompt_target: String,
embedding_type: EmbeddingType,
}
/// Embedding vectors of a single prompt target, keyed by the kind of embedding.
pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
/// Embeddings of all prompt targets, keyed by prompt target name.
pub type EmbeddingsStore = HashMap<String, EmbeddingTypeMap>;
/// Context stored for every embeddings call dispatched by the filter context,
/// so the response can be attributed to the prompt target and embedding type
/// it was requested for.
#[derive(Debug)]
pub struct FilterCallContext {
/// Name of the prompt target the embedding was requested for.
pub prompt_target_name: String,
/// Which embedding of the prompt target was requested.
pub embedding_type: EmbeddingType,
}
#[derive(Debug)]
pub struct FilterContext {
metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, CallContext>,
callouts: RefCell<HashMap<u32, FilterCallContext>>,
overrides: Rc<Option<Overrides>>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
prompt_guards: Rc<PromptGuards>,
llm_providers: Option<Rc<LlmProviders>>,
}
pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
static EMBEDDINGS: OnceLock<RwLock<HashMap<String, EmbeddingTypeMap>>> = OnceLock::new();
EMBEDDINGS.get_or_init(|| {
let embeddings: HashMap<String, EmbeddingTypeMap> = HashMap::new();
RwLock::new(embeddings)
})
embeddings_store: Option<Rc<EmbeddingsStore>>,
temp_embeddings_store: EmbeddingsStore,
}
impl FilterContext {
pub fn new() -> FilterContext {
FilterContext {
callouts: HashMap::new(),
callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(WasmMetrics::new()),
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
prompt_targets: Rc::new(HashMap::new()),
overrides: Rc::new(None),
prompt_guards: Rc::new(PromptGuards::default()),
llm_providers: None,
embeddings_store: None,
temp_embeddings_store: HashMap::new(),
}
}
fn process_prompt_targets(&mut self) {
let prompt_targets = match self.prompt_targets.read() {
Ok(prompt_targets) => prompt_targets,
Err(e) => {
panic!("Error reading prompt targets: {:?}", e);
}
};
for values in prompt_targets.iter() {
let prompt_target = &values.1;
// schedule embeddings call for prompt target name
let token_id = self.schedule_embeddings_call(prompt_target.name.clone());
if self
.callouts
.insert(token_id, {
CallContext {
prompt_target: prompt_target.name.clone(),
embedding_type: EmbeddingType::Name,
}
})
.is_some()
{
panic!("duplicate token_id")
}
// schedule embeddings call for prompt target description
let token_id = self.schedule_embeddings_call(prompt_target.description.clone());
if self
.callouts
.insert(token_id, {
CallContext {
prompt_target: prompt_target.name.clone(),
embedding_type: EmbeddingType::Description,
}
})
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
fn process_prompt_targets(&self) {
for values in self.prompt_targets.iter() {
let prompt_target = values.1;
self.schedule_embeddings_call(
&prompt_target.name,
&prompt_target.name,
EmbeddingType::Name,
);
self.schedule_embeddings_call(
&prompt_target.name,
&prompt_target.description,
EmbeddingType::Description,
);
}
}
fn schedule_embeddings_call(&self, input: String) -> u32 {
fn schedule_embeddings_call(
&self,
prompt_target_name: &str,
input: &str,
embedding_type: EmbeddingType,
) {
let embeddings_input = CreateEmbeddingRequest {
input: Box::new(CreateEmbeddingRequestInput::String(input)),
input: Box::new(CreateEmbeddingRequestInput::String(String::from(input))),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
let json_data = serde_json::to_string(&embeddings_input).unwrap();
let json_data = to_string(&embeddings_input).unwrap();
let token_id = match self.dispatch_http_call(
let call_args = CallArgs::new(
MODEL_SERVER_NAME,
vec![
(":method", "POST"),
@ -139,16 +112,16 @@ impl FilterContext {
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(60),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!(
"Error dispatching HTTP call: {}, error: {:?}",
MODEL_SERVER_NAME, e
);
}
);
let call_context = crate::filter_context::FilterCallContext {
prompt_target_name: String::from(prompt_target_name),
embedding_type,
};
token_id
if let Err(error) = self.http_call(call_args, call_context) {
panic!("{error}")
}
}
fn embedding_response_handler(
@ -157,40 +130,79 @@ impl FilterContext {
embedding_type: EmbeddingType,
prompt_target_name: String,
) {
let prompt_targets = self.prompt_targets.read().unwrap();
let prompt_target = prompt_targets.get(&prompt_target_name).unwrap();
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
let mut embedding_response: CreateEmbeddingResponse =
match serde_json::from_slice(&body) {
Ok(response) => response,
Err(e) => {
panic!(
"Error deserializing embedding response. body: {:?}: {:?}",
String::from_utf8(body).unwrap(),
e
);
}
};
let prompt_target = self
.prompt_targets
.get(&prompt_target_name)
.unwrap_or_else(|| {
panic!(
"Received embeddings response for unknown prompt target name={}",
prompt_target_name
)
});
let embeddings = embedding_response.data.remove(0).embedding;
log::info!(
let body = self
.get_http_call_response_body(0, body_size)
.expect("No body in response");
if !body.is_empty() {
let mut embedding_response: CreateEmbeddingResponse =
match serde_json::from_slice(&body) {
Ok(response) => response,
Err(e) => {
panic!(
"Error deserializing embedding response. body: {:?}: {:?}",
String::from_utf8(body).unwrap(),
e
);
}
};
let embeddings = embedding_response.data.remove(0).embedding;
debug!(
"Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
prompt_target.name,
prompt_target.description,
embedding_type
);
embeddings_store().write().unwrap().insert(
prompt_target.name.clone(),
HashMap::from([(embedding_type, embeddings)]),
);
let entry = self.temp_embeddings_store.entry(prompt_target_name);
match entry {
Entry::Occupied(_) => {
entry.and_modify(|e| {
if let Entry::Vacant(e) = e.entry(embedding_type) {
e.insert(embeddings);
} else {
panic!(
"Duplicate {:?} for prompt target with name=\"{}\"",
&embedding_type, prompt_target.name
)
}
});
}
Entry::Vacant(_) => {
entry.or_insert(HashMap::from([(embedding_type, embeddings)]));
}
}
if self.prompt_targets.len() == self.temp_embeddings_store.len() {
self.embeddings_store =
Some(Rc::new(std::mem::take(&mut self.temp_embeddings_store)))
}
} else {
panic!("No body in response");
}
}
}
/// Hooks `FilterContext` into the shared `Client` helper: dispatched calls are
/// recorded in `callouts` and counted by the `active_http_calls` gauge.
impl Client for FilterContext {
type CallContext = FilterCallContext;
/// Map from call token id to the context stored when that call was dispatched.
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
&self.callouts
}
/// Gauge that is incremented per dispatched call and decremented per response.
fn active_http_calls(&self) -> &Gauge {
&self.metrics.active_http_calls
}
}
impl Context for FilterContext {
fn on_http_call_response(
&mut self,
@ -203,16 +215,18 @@ impl Context for FilterContext {
"filter_context: on_http_call_response called with token_id: {:?}",
token_id
);
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
let callout_data = self
.callouts
.borrow_mut()
.remove(&token_id)
.expect("invalid token_id");
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
self.metrics.active_http_calls.increment(-1);
self.embedding_response_handler(
body_size,
callout_data.embedding_type,
callout_data.prompt_target,
callout_data.prompt_target_name,
)
}
}
@ -231,12 +245,11 @@ impl RootContext for FilterContext {
self.overrides = Rc::new(config.overrides);
let mut prompt_targets = HashMap::new();
for pt in config.prompt_targets {
self.prompt_targets
.write()
.unwrap()
.insert(pt.name.clone(), pt.clone());
prompt_targets.insert(pt.name.clone(), pt.clone());
}
self.prompt_targets = Rc::new(prompt_targets);
ratelimit::ratelimits(config.ratelimits);
@ -257,6 +270,10 @@ impl RootContext for FilterContext {
"||| create_http_context called with context_id: {:?} |||",
context_id
);
// No StreamContext can be created until the Embedding Store is fully initialized.
self.embeddings_store.as_ref()?;
Some(Box::new(StreamContext::new(
context_id,
Rc::clone(&self.metrics),
@ -268,6 +285,11 @@ impl RootContext for FilterContext {
.as_ref()
.expect("LLM Providers must exist when Streams are being created"),
),
Rc::clone(
self.embeddings_store
.as_ref()
.expect("Embeddings Store must exist when StreamContext is being constructed"),
),
)))
}

84
arch/src/http.rs Normal file
View file

@ -0,0 +1,84 @@
use crate::stats::{Gauge, IncrementingMetric};
use log::debug;
use proxy_wasm::{traits::Context, types::Status};
use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration};
/// Borrowed arguments describing one outgoing HTTP call.
///
/// All string and byte data is borrowed for `'a`; the value only needs to
/// live until the call has been handed to `Client::http_call`.
#[derive(Debug)]
pub struct CallArgs<'a> {
    /// Name of the upstream the call is dispatched to.
    upstream: &'a str,
    /// Request headers as `(name, value)` pairs.
    headers: Vec<(&'a str, &'a str)>,
    /// Optional request body.
    body: Option<&'a [u8]>,
    /// Request trailers as `(name, value)` pairs.
    trailers: Vec<(&'a str, &'a str)>,
    /// Timeout passed along with the dispatched call.
    timeout: Duration,
}

impl<'a> CallArgs<'a> {
    /// Bundles the given pieces into a `CallArgs` without copying any of the
    /// borrowed data.
    pub fn new(
        upstream: &'a str,
        headers: Vec<(&'a str, &'a str)>,
        body: Option<&'a [u8]>,
        trailers: Vec<(&'a str, &'a str)>,
        timeout: Duration,
    ) -> Self {
        Self {
            upstream,
            headers,
            body,
            trailers,
            timeout,
        }
    }
}
/// Errors returned by `Client::http_call`.
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
/// `dispatch_http_call` returned an error status for the given upstream.
#[error("Error dispatching HTTP call to `{upstream_name}`, error: {internal_status:?}")]
DispatchError {
/// Name of the upstream the call was addressed to.
upstream_name: String,
/// Status returned by the failed `dispatch_http_call`.
internal_status: Status,
},
}
pub trait Client: Context {
type CallContext: Debug;
fn http_call(
&self,
call_args: CallArgs,
call_context: Self::CallContext,
) -> Result<(), ClientError> {
debug!(
"dispatching http call with args={:?} context={:?}",
call_args, call_context
);
match self.dispatch_http_call(
call_args.upstream,
call_args.headers,
call_args.body,
call_args.trailers,
call_args.timeout,
) {
Ok(id) => {
self.add_call_context(id, call_context);
Ok(())
}
Err(status) => Err(ClientError::DispatchError {
upstream_name: String::from(call_args.upstream),
internal_status: status.clone(),
}),
}
}
fn add_call_context(&self, id: u32, call_context: Self::CallContext) {
let callouts = self.callouts();
if callouts.borrow_mut().insert(id, call_context).is_some() {
panic!("Duplicate http call with id={}", id);
}
self.active_http_calls().increment(1);
}
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>>;
fn active_http_calls(&self) -> &Gauge;
}

View file

@ -4,6 +4,7 @@ use proxy_wasm::types::*;
mod consts;
mod filter_context;
mod http;
mod llm_providers;
mod ratelimit;
mod routing;

View file

@ -18,7 +18,7 @@ impl LlmProviders {
}
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
self.providers.get(name).map(|rc| rc.clone())
self.providers.get(name).cloned()
}
}

View file

@ -24,6 +24,7 @@ pub trait IncrementingMetric: Metric {
}
}
#[allow(unused)]
pub trait RecordingMetric: Metric {
fn record(&self, value: u64) {
match hostcalls::record_metric(self.id(), value) {

View file

@ -4,7 +4,7 @@ use crate::consts::{
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::{embeddings_store, WasmMetrics};
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use crate::llm_providers::LlmProviders;
use crate::ratelimit::Header;
use crate::stats::IncrementingMetric;
@ -31,7 +31,6 @@ use public_types::embeddings::{
use std::collections::HashMap;
use std::num::NonZero;
use std::rc::Rc;
use std::sync::RwLock;
use std::time::Duration;
enum ResponseHandlerType {
@ -56,7 +55,8 @@ pub struct CallContext {
pub struct StreamContext {
context_id: u32,
metrics: Rc<WasmMetrics>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
embeddings_store: Rc<EmbeddingsStore>,
overrides: Rc<Option<Overrides>>,
callouts: HashMap<u32, CallContext>,
ratelimit_selector: Option<Header>,
@ -72,15 +72,17 @@ impl StreamContext {
pub fn new(
context_id: u32,
metrics: Rc<WasmMetrics>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
prompt_guards: Rc<PromptGuards>,
overrides: Rc<Option<Overrides>>,
llm_providers: Rc<LlmProviders>,
embeddings_store: Rc<EmbeddingsStore>,
) -> Self {
StreamContext {
context_id,
metrics,
prompt_targets,
embeddings_store,
callouts: HashMap::new(),
ratelimit_selector: None,
streaming_response: false,
@ -174,35 +176,21 @@ impl StreamContext {
prompt_embeddings_vector.len()
);
let prompt_target_embeddings = match embeddings_store().read() {
Ok(embeddings) => embeddings,
Err(e) => {
return self
.send_server_error(format!("Error reading embeddings store: {:?}", e), None);
}
};
let prompt_targets = match self.prompt_targets.read() {
Ok(prompt_targets) => prompt_targets,
Err(e) => {
self.send_server_error(format!("Error reading prompt targets: {:?}", e), None);
return;
}
};
let prompt_target_names = prompt_targets
let prompt_target_names = self
.prompt_targets
.iter()
// exclude default target
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
.map(|(name, _)| name.clone())
.collect();
let similarity_scores: Vec<(String, f64)> = prompt_targets
let similarity_scores: Vec<(String, f64)> = self
.prompt_targets
.iter()
// exclude default prompt target
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
.map(|(prompt_name, _)| {
let pte = match prompt_target_embeddings.get(prompt_name) {
let pte = match self.embeddings_store.get(prompt_name) {
Some(embeddings) => embeddings,
None => {
warn!(
@ -373,8 +361,6 @@ impl StreamContext {
debug!("checking for default prompt target");
if let Some(default_prompt_target) = self
.prompt_targets
.read()
.unwrap()
.values()
.find(|pt| pt.default.unwrap_or(false))
{
@ -429,7 +415,7 @@ impl StreamContext {
}
}
let prompt_target = match self.prompt_targets.read().unwrap().get(&prompt_target_name) {
let prompt_target = match self.prompt_targets.get(&prompt_target_name) {
Some(prompt_target) => prompt_target.clone(),
None => {
return self.send_server_error(
@ -441,7 +427,7 @@ impl StreamContext {
info!("prompt_target name: {:?}", prompt_target_name);
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
for pt in self.prompt_targets.read().unwrap().values() {
for pt in self.prompt_targets.values() {
// only extract entity names
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
@ -592,13 +578,7 @@ impl StreamContext {
let tools_call_name = tool_calls[0].function.name.clone();
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
let prompt_target = self
.prompt_targets
.read()
.unwrap()
.get(&tools_call_name)
.unwrap()
.clone();
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
debug!("prompt_target_name: {}", prompt_target.name);
debug!("tool_name(s): {:?}", tool_names);
@ -660,8 +640,6 @@ impl StreamContext {
let prompt_target_name = callout_context.prompt_target_name.unwrap();
let prompt_target = self
.prompt_targets
.read()
.unwrap()
.get(&prompt_target_name)
.unwrap()
.clone();
@ -832,8 +810,6 @@ impl StreamContext {
fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
let prompt_target = self
.prompt_targets
.read()
.unwrap()
.get(callout_context.prompt_target_name.as_ref().unwrap())
.unwrap()
.clone();