mirror of
https://github.com/katanemo/plano.git
synced 2026-06-05 14:45:15 +02:00
Clean up Embeddings Store (#121)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
10b5c5b42c
commit
2a9b9486f3
8 changed files with 366 additions and 238 deletions
|
|
@ -1,7 +1,8 @@
|
|||
use crate::consts::{DEFAULT_EMBEDDING_MODEL, MODEL_SERVER_NAME};
|
||||
use crate::http::{CallArgs, Client};
|
||||
use crate::llm_providers::LlmProviders;
|
||||
use crate::ratelimit;
|
||||
use crate::stats::{Counter, Gauge, RecordingMetric};
|
||||
use crate::stats::{Counter, Gauge, IncrementingMetric};
|
||||
use crate::stream_context::StreamContext;
|
||||
use log::debug;
|
||||
use proxy_wasm::traits::*;
|
||||
|
|
@ -11,10 +12,10 @@ use public_types::configuration::{Configuration, Overrides, PromptGuards, Prompt
|
|||
use public_types::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use serde_json::to_string;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{OnceLock, RwLock};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
|
|
@ -32,102 +33,74 @@ impl WasmMetrics {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CallContext {
|
||||
prompt_target: String,
|
||||
embedding_type: EmbeddingType,
|
||||
}
|
||||
|
||||
pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
|
||||
pub type EmbeddingsStore = HashMap<String, EmbeddingTypeMap>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterCallContext {
|
||||
pub prompt_target_name: String,
|
||||
pub embedding_type: EmbeddingType,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterContext {
|
||||
metrics: Rc<WasmMetrics>,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
callouts: RefCell<HashMap<u32, FilterCallContext>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
llm_providers: Option<Rc<LlmProviders>>,
|
||||
}
|
||||
|
||||
pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
|
||||
static EMBEDDINGS: OnceLock<RwLock<HashMap<String, EmbeddingTypeMap>>> = OnceLock::new();
|
||||
EMBEDDINGS.get_or_init(|| {
|
||||
let embeddings: HashMap<String, EmbeddingTypeMap> = HashMap::new();
|
||||
RwLock::new(embeddings)
|
||||
})
|
||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
temp_embeddings_store: EmbeddingsStore,
|
||||
}
|
||||
|
||||
impl FilterContext {
|
||||
pub fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: HashMap::new(),
|
||||
callouts: RefCell::new(HashMap::new()),
|
||||
metrics: Rc::new(WasmMetrics::new()),
|
||||
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
|
||||
prompt_targets: Rc::new(HashMap::new()),
|
||||
overrides: Rc::new(None),
|
||||
prompt_guards: Rc::new(PromptGuards::default()),
|
||||
llm_providers: None,
|
||||
embeddings_store: None,
|
||||
temp_embeddings_store: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_prompt_targets(&mut self) {
|
||||
let prompt_targets = match self.prompt_targets.read() {
|
||||
Ok(prompt_targets) => prompt_targets,
|
||||
Err(e) => {
|
||||
panic!("Error reading prompt targets: {:?}", e);
|
||||
}
|
||||
};
|
||||
for values in prompt_targets.iter() {
|
||||
let prompt_target = &values.1;
|
||||
|
||||
// schedule embeddings call for prompt target name
|
||||
let token_id = self.schedule_embeddings_call(prompt_target.name.clone());
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext {
|
||||
prompt_target: prompt_target.name.clone(),
|
||||
embedding_type: EmbeddingType::Name,
|
||||
}
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
|
||||
// schedule embeddings call for prompt target description
|
||||
let token_id = self.schedule_embeddings_call(prompt_target.description.clone());
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext {
|
||||
prompt_target: prompt_target.name.clone(),
|
||||
embedding_type: EmbeddingType::Description,
|
||||
}
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
fn process_prompt_targets(&self) {
|
||||
for values in self.prompt_targets.iter() {
|
||||
let prompt_target = values.1;
|
||||
self.schedule_embeddings_call(
|
||||
&prompt_target.name,
|
||||
&prompt_target.name,
|
||||
EmbeddingType::Name,
|
||||
);
|
||||
self.schedule_embeddings_call(
|
||||
&prompt_target.name,
|
||||
&prompt_target.description,
|
||||
EmbeddingType::Description,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn schedule_embeddings_call(&self, input: String) -> u32 {
|
||||
fn schedule_embeddings_call(
|
||||
&self,
|
||||
prompt_target_name: &str,
|
||||
input: &str,
|
||||
embedding_type: EmbeddingType,
|
||||
) {
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(input)),
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(String::from(input))),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
let json_data = serde_json::to_string(&embeddings_input).unwrap();
|
||||
|
||||
let json_data = to_string(&embeddings_input).unwrap();
|
||||
let token_id = match self.dispatch_http_call(
|
||||
let call_args = CallArgs::new(
|
||||
MODEL_SERVER_NAME,
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
|
|
@ -139,16 +112,16 @@ impl FilterContext {
|
|||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(60),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error dispatching HTTP call: {}, error: {:?}",
|
||||
MODEL_SERVER_NAME, e
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
let call_context = crate::filter_context::FilterCallContext {
|
||||
prompt_target_name: String::from(prompt_target_name),
|
||||
embedding_type,
|
||||
};
|
||||
token_id
|
||||
|
||||
if let Err(error) = self.http_call(call_args, call_context) {
|
||||
panic!("{error}")
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding_response_handler(
|
||||
|
|
@ -157,40 +130,79 @@ impl FilterContext {
|
|||
embedding_type: EmbeddingType,
|
||||
prompt_target_name: String,
|
||||
) {
|
||||
let prompt_targets = self.prompt_targets.read().unwrap();
|
||||
let prompt_target = prompt_targets.get(&prompt_target_name).unwrap();
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
let mut embedding_response: CreateEmbeddingResponse =
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error deserializing embedding response. body: {:?}: {:?}",
|
||||
String::from_utf8(body).unwrap(),
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(&prompt_target_name)
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"Received embeddings response for unknown prompt target name={}",
|
||||
prompt_target_name
|
||||
)
|
||||
});
|
||||
|
||||
let embeddings = embedding_response.data.remove(0).embedding;
|
||||
log::info!(
|
||||
let body = self
|
||||
.get_http_call_response_body(0, body_size)
|
||||
.expect("No body in response");
|
||||
if !body.is_empty() {
|
||||
let mut embedding_response: CreateEmbeddingResponse =
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error deserializing embedding response. body: {:?}: {:?}",
|
||||
String::from_utf8(body).unwrap(),
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let embeddings = embedding_response.data.remove(0).embedding;
|
||||
debug!(
|
||||
"Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
|
||||
prompt_target.name,
|
||||
prompt_target.description,
|
||||
embedding_type
|
||||
);
|
||||
|
||||
embeddings_store().write().unwrap().insert(
|
||||
prompt_target.name.clone(),
|
||||
HashMap::from([(embedding_type, embeddings)]),
|
||||
);
|
||||
let entry = self.temp_embeddings_store.entry(prompt_target_name);
|
||||
match entry {
|
||||
Entry::Occupied(_) => {
|
||||
entry.and_modify(|e| {
|
||||
if let Entry::Vacant(e) = e.entry(embedding_type) {
|
||||
e.insert(embeddings);
|
||||
} else {
|
||||
panic!(
|
||||
"Duplicate {:?} for prompt target with name=\"{}\"",
|
||||
&embedding_type, prompt_target.name
|
||||
)
|
||||
}
|
||||
});
|
||||
}
|
||||
Entry::Vacant(_) => {
|
||||
entry.or_insert(HashMap::from([(embedding_type, embeddings)]));
|
||||
}
|
||||
}
|
||||
|
||||
if self.prompt_targets.len() == self.temp_embeddings_store.len() {
|
||||
self.embeddings_store =
|
||||
Some(Rc::new(std::mem::take(&mut self.temp_embeddings_store)))
|
||||
}
|
||||
} else {
|
||||
panic!("No body in response");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Client for FilterContext {
|
||||
type CallContext = FilterCallContext;
|
||||
|
||||
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
|
||||
&self.callouts
|
||||
}
|
||||
|
||||
fn active_http_calls(&self) -> &Gauge {
|
||||
&self.metrics.active_http_calls
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for FilterContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
|
|
@ -203,16 +215,18 @@ impl Context for FilterContext {
|
|||
"filter_context: on_http_call_response called with token_id: {:?}",
|
||||
token_id
|
||||
);
|
||||
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
let callout_data = self
|
||||
.callouts
|
||||
.borrow_mut()
|
||||
.remove(&token_id)
|
||||
.expect("invalid token_id");
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
self.metrics.active_http_calls.increment(-1);
|
||||
|
||||
self.embedding_response_handler(
|
||||
body_size,
|
||||
callout_data.embedding_type,
|
||||
callout_data.prompt_target,
|
||||
callout_data.prompt_target_name,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -231,12 +245,11 @@ impl RootContext for FilterContext {
|
|||
|
||||
self.overrides = Rc::new(config.overrides);
|
||||
|
||||
let mut prompt_targets = HashMap::new();
|
||||
for pt in config.prompt_targets {
|
||||
self.prompt_targets
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(pt.name.clone(), pt.clone());
|
||||
prompt_targets.insert(pt.name.clone(), pt.clone());
|
||||
}
|
||||
self.prompt_targets = Rc::new(prompt_targets);
|
||||
|
||||
ratelimit::ratelimits(config.ratelimits);
|
||||
|
||||
|
|
@ -257,6 +270,10 @@ impl RootContext for FilterContext {
|
|||
"||| create_http_context called with context_id: {:?} |||",
|
||||
context_id
|
||||
);
|
||||
|
||||
// No StreamContext can be created until the Embedding Store is fully initialized.
|
||||
self.embeddings_store.as_ref()?;
|
||||
|
||||
Some(Box::new(StreamContext::new(
|
||||
context_id,
|
||||
Rc::clone(&self.metrics),
|
||||
|
|
@ -268,6 +285,11 @@ impl RootContext for FilterContext {
|
|||
.as_ref()
|
||||
.expect("LLM Providers must exist when Streams are being created"),
|
||||
),
|
||||
Rc::clone(
|
||||
self.embeddings_store
|
||||
.as_ref()
|
||||
.expect("Embeddings Store must exist when StreamContext is being constructed"),
|
||||
),
|
||||
)))
|
||||
}
|
||||
|
||||
|
|
|
|||
84
arch/src/http.rs
Normal file
84
arch/src/http.rs
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
use crate::stats::{Gauge, IncrementingMetric};
|
||||
use log::debug;
|
||||
use proxy_wasm::{traits::Context, types::Status};
|
||||
use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration};
|
||||
|
||||
/// Arguments for a single outbound HTTP callout.
///
/// Every field borrows from the caller for `'a`, so building a `CallArgs`
/// does not copy header, trailer, or body data.
#[derive(Debug)]
pub struct CallArgs<'a> {
    /// Name of the upstream the call is dispatched to.
    upstream: &'a str,
    /// Request headers as `(name, value)` pairs.
    headers: Vec<(&'a str, &'a str)>,
    /// Optional request body.
    body: Option<&'a [u8]>,
    /// Request trailers as `(name, value)` pairs.
    trailers: Vec<(&'a str, &'a str)>,
    /// Timeout handed to the dispatch call.
    timeout: Duration,
}

impl<'a> CallArgs<'a> {
    /// Bundles the given values into a `CallArgs`; nothing is copied or validated.
    pub fn new(
        upstream: &'a str,
        headers: Vec<(&'a str, &'a str)>,
        body: Option<&'a [u8]>,
        trailers: Vec<(&'a str, &'a str)>,
        timeout: Duration,
    ) -> Self {
        Self {
            upstream,
            headers,
            body,
            trailers,
            timeout,
        }
    }
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ClientError {
|
||||
#[error("Error dispatching HTTP call to `{upstream_name}`, error: {internal_status:?}")]
|
||||
DispatchError {
|
||||
upstream_name: String,
|
||||
internal_status: Status,
|
||||
},
|
||||
}
|
||||
|
||||
pub trait Client: Context {
|
||||
type CallContext: Debug;
|
||||
|
||||
fn http_call(
|
||||
&self,
|
||||
call_args: CallArgs,
|
||||
call_context: Self::CallContext,
|
||||
) -> Result<(), ClientError> {
|
||||
debug!(
|
||||
"dispatching http call with args={:?} context={:?}",
|
||||
call_args, call_context
|
||||
);
|
||||
|
||||
match self.dispatch_http_call(
|
||||
call_args.upstream,
|
||||
call_args.headers,
|
||||
call_args.body,
|
||||
call_args.trailers,
|
||||
call_args.timeout,
|
||||
) {
|
||||
Ok(id) => {
|
||||
self.add_call_context(id, call_context);
|
||||
Ok(())
|
||||
}
|
||||
Err(status) => Err(ClientError::DispatchError {
|
||||
upstream_name: String::from(call_args.upstream),
|
||||
internal_status: status.clone(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_call_context(&self, id: u32, call_context: Self::CallContext) {
|
||||
let callouts = self.callouts();
|
||||
if callouts.borrow_mut().insert(id, call_context).is_some() {
|
||||
panic!("Duplicate http call with id={}", id);
|
||||
}
|
||||
self.active_http_calls().increment(1);
|
||||
}
|
||||
|
||||
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>>;
|
||||
|
||||
fn active_http_calls(&self) -> &Gauge;
|
||||
}
|
||||
|
|
@ -4,6 +4,7 @@ use proxy_wasm::types::*;
|
|||
|
||||
mod consts;
|
||||
mod filter_context;
|
||||
mod http;
|
||||
mod llm_providers;
|
||||
mod ratelimit;
|
||||
mod routing;
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ impl LlmProviders {
|
|||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
|
||||
self.providers.get(name).map(|rc| rc.clone())
|
||||
self.providers.get(name).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ pub trait IncrementingMetric: Metric {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub trait RecordingMetric: Metric {
|
||||
fn record(&self, value: u64) {
|
||||
match hostcalls::record_metric(self.id(), value) {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use crate::consts::{
|
|||
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
||||
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::{embeddings_store, WasmMetrics};
|
||||
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
|
||||
use crate::llm_providers::LlmProviders;
|
||||
use crate::ratelimit::Header;
|
||||
use crate::stats::IncrementingMetric;
|
||||
|
|
@ -31,7 +31,6 @@ use public_types::embeddings::{
|
|||
use std::collections::HashMap;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
use std::sync::RwLock;
|
||||
use std::time::Duration;
|
||||
|
||||
enum ResponseHandlerType {
|
||||
|
|
@ -56,7 +55,8 @@ pub struct CallContext {
|
|||
pub struct StreamContext {
|
||||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
embeddings_store: Rc<EmbeddingsStore>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
ratelimit_selector: Option<Header>,
|
||||
|
|
@ -72,15 +72,17 @@ impl StreamContext {
|
|||
pub fn new(
|
||||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
embeddings_store: Rc<EmbeddingsStore>,
|
||||
) -> Self {
|
||||
StreamContext {
|
||||
context_id,
|
||||
metrics,
|
||||
prompt_targets,
|
||||
embeddings_store,
|
||||
callouts: HashMap::new(),
|
||||
ratelimit_selector: None,
|
||||
streaming_response: false,
|
||||
|
|
@ -174,35 +176,21 @@ impl StreamContext {
|
|||
prompt_embeddings_vector.len()
|
||||
);
|
||||
|
||||
let prompt_target_embeddings = match embeddings_store().read() {
|
||||
Ok(embeddings) => embeddings,
|
||||
Err(e) => {
|
||||
return self
|
||||
.send_server_error(format!("Error reading embeddings store: {:?}", e), None);
|
||||
}
|
||||
};
|
||||
|
||||
let prompt_targets = match self.prompt_targets.read() {
|
||||
Ok(prompt_targets) => prompt_targets,
|
||||
Err(e) => {
|
||||
self.send_server_error(format!("Error reading prompt targets: {:?}", e), None);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let prompt_target_names = prompt_targets
|
||||
let prompt_target_names = self
|
||||
.prompt_targets
|
||||
.iter()
|
||||
// exclude default target
|
||||
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
|
||||
.map(|(name, _)| name.clone())
|
||||
.collect();
|
||||
|
||||
let similarity_scores: Vec<(String, f64)> = prompt_targets
|
||||
let similarity_scores: Vec<(String, f64)> = self
|
||||
.prompt_targets
|
||||
.iter()
|
||||
// exclude default prompt target
|
||||
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
|
||||
.map(|(prompt_name, _)| {
|
||||
let pte = match prompt_target_embeddings.get(prompt_name) {
|
||||
let pte = match self.embeddings_store.get(prompt_name) {
|
||||
Some(embeddings) => embeddings,
|
||||
None => {
|
||||
warn!(
|
||||
|
|
@ -373,8 +361,6 @@ impl StreamContext {
|
|||
debug!("checking for default prompt target");
|
||||
if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
.read()
|
||||
.unwrap()
|
||||
.values()
|
||||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
|
|
@ -429,7 +415,7 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
let prompt_target = match self.prompt_targets.read().unwrap().get(&prompt_target_name) {
|
||||
let prompt_target = match self.prompt_targets.get(&prompt_target_name) {
|
||||
Some(prompt_target) => prompt_target.clone(),
|
||||
None => {
|
||||
return self.send_server_error(
|
||||
|
|
@ -441,7 +427,7 @@ impl StreamContext {
|
|||
|
||||
info!("prompt_target name: {:?}", prompt_target_name);
|
||||
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
|
||||
for pt in self.prompt_targets.read().unwrap().values() {
|
||||
for pt in self.prompt_targets.values() {
|
||||
// only extract entity names
|
||||
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
|
||||
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
|
||||
|
|
@ -592,13 +578,7 @@ impl StreamContext {
|
|||
let tools_call_name = tool_calls[0].function.name.clone();
|
||||
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
|
||||
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&tools_call_name)
|
||||
.unwrap()
|
||||
.clone();
|
||||
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
|
||||
|
||||
debug!("prompt_target_name: {}", prompt_target.name);
|
||||
debug!("tool_name(s): {:?}", tool_names);
|
||||
|
|
@ -660,8 +640,6 @@ impl StreamContext {
|
|||
let prompt_target_name = callout_context.prompt_target_name.unwrap();
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&prompt_target_name)
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
|
@ -832,8 +810,6 @@ impl StreamContext {
|
|||
fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue