Clean up Embeddings Store (#121)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-10-04 19:33:52 -07:00 committed by GitHub
parent 10b5c5b42c
commit 2a9b9486f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 366 additions and 238 deletions

View file

@ -4,7 +4,7 @@ use crate::consts::{
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::{embeddings_store, WasmMetrics};
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use crate::llm_providers::LlmProviders;
use crate::ratelimit::Header;
use crate::stats::IncrementingMetric;
@ -31,7 +31,6 @@ use public_types::embeddings::{
use std::collections::HashMap;
use std::num::NonZero;
use std::rc::Rc;
use std::sync::RwLock;
use std::time::Duration;
enum ResponseHandlerType {
@ -56,7 +55,8 @@ pub struct CallContext {
pub struct StreamContext {
context_id: u32,
metrics: Rc<WasmMetrics>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
embeddings_store: Rc<EmbeddingsStore>,
overrides: Rc<Option<Overrides>>,
callouts: HashMap<u32, CallContext>,
ratelimit_selector: Option<Header>,
@ -72,15 +72,17 @@ impl StreamContext {
pub fn new(
context_id: u32,
metrics: Rc<WasmMetrics>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
prompt_guards: Rc<PromptGuards>,
overrides: Rc<Option<Overrides>>,
llm_providers: Rc<LlmProviders>,
embeddings_store: Rc<EmbeddingsStore>,
) -> Self {
StreamContext {
context_id,
metrics,
prompt_targets,
embeddings_store,
callouts: HashMap::new(),
ratelimit_selector: None,
streaming_response: false,
@ -174,35 +176,21 @@ impl StreamContext {
prompt_embeddings_vector.len()
);
let prompt_target_embeddings = match embeddings_store().read() {
Ok(embeddings) => embeddings,
Err(e) => {
return self
.send_server_error(format!("Error reading embeddings store: {:?}", e), None);
}
};
let prompt_targets = match self.prompt_targets.read() {
Ok(prompt_targets) => prompt_targets,
Err(e) => {
self.send_server_error(format!("Error reading prompt targets: {:?}", e), None);
return;
}
};
let prompt_target_names = prompt_targets
let prompt_target_names = self
.prompt_targets
.iter()
// exclude default target
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
.map(|(name, _)| name.clone())
.collect();
let similarity_scores: Vec<(String, f64)> = prompt_targets
let similarity_scores: Vec<(String, f64)> = self
.prompt_targets
.iter()
// exclude default prompt target
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
.map(|(prompt_name, _)| {
let pte = match prompt_target_embeddings.get(prompt_name) {
let pte = match self.embeddings_store.get(prompt_name) {
Some(embeddings) => embeddings,
None => {
warn!(
@ -373,8 +361,6 @@ impl StreamContext {
debug!("checking for default prompt target");
if let Some(default_prompt_target) = self
.prompt_targets
.read()
.unwrap()
.values()
.find(|pt| pt.default.unwrap_or(false))
{
@ -429,7 +415,7 @@ impl StreamContext {
}
}
let prompt_target = match self.prompt_targets.read().unwrap().get(&prompt_target_name) {
let prompt_target = match self.prompt_targets.get(&prompt_target_name) {
Some(prompt_target) => prompt_target.clone(),
None => {
return self.send_server_error(
@ -441,7 +427,7 @@ impl StreamContext {
info!("prompt_target name: {:?}", prompt_target_name);
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
for pt in self.prompt_targets.read().unwrap().values() {
for pt in self.prompt_targets.values() {
// only extract entity names
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
@ -592,13 +578,7 @@ impl StreamContext {
let tools_call_name = tool_calls[0].function.name.clone();
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
let prompt_target = self
.prompt_targets
.read()
.unwrap()
.get(&tools_call_name)
.unwrap()
.clone();
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
debug!("prompt_target_name: {}", prompt_target.name);
debug!("tool_name(s): {:?}", tool_names);
@ -660,8 +640,6 @@ impl StreamContext {
let prompt_target_name = callout_context.prompt_target_name.unwrap();
let prompt_target = self
.prompt_targets
.read()
.unwrap()
.get(&prompt_target_name)
.unwrap()
.clone();
@ -832,8 +810,6 @@ impl StreamContext {
fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
let prompt_target = self
.prompt_targets
.read()
.unwrap()
.get(callout_context.prompt_target_name.as_ref().unwrap())
.unwrap()
.clone();