diff --git a/envoyfilter/src/filter_context.rs b/envoyfilter/src/filter_context.rs index 963e7719..52108680 100644 --- a/envoyfilter/src/filter_context.rs +++ b/envoyfilter/src/filter_context.rs @@ -80,6 +80,7 @@ impl FilterContext { }; let embedding_request = EmbeddingRequest { create_embedding_request: embeddings_input, + // Need to clone prompt target to leave config string intact. prompt_target: prompt_target.clone(), }; if self @@ -106,7 +107,7 @@ impl FilterContext { ) { if let Some(body) = self.get_http_call_response_body(0, body_size) { if !body.is_empty() { - let embedding_response: CreateEmbeddingResponse = + let mut embedding_response: CreateEmbeddingResponse = serde_json::from_slice(&body).unwrap(); let mut payload: HashMap = HashMap::new(); @@ -127,7 +128,7 @@ impl FilterContext { points: vec![VectorPoint { id: format!("{:x}", id.unwrap()), payload, - vector: embedding_response.data[0].embedding.clone(), + vector: embedding_response.data.remove(0).embedding, }], }; let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs index 9dfb7076..c119d234 100644 --- a/envoyfilter/src/stream_context.rs +++ b/envoyfilter/src/stream_context.rs @@ -30,7 +30,7 @@ enum RequestType { pub struct CallContext { request_type: RequestType, - user_message: String, + user_message: Option, prompt_target: Option, request_body: ChatCompletions, } @@ -160,12 +160,11 @@ impl StreamContext { // only extract entity names let entity_names = get_entity_details(&prompt_target) - .iter() - .map(|entity| entity.name.clone()) + .into_iter() + .map(|entity| entity.name) .collect(); - let user_message = callout_context.user_message.clone(); let ner_request = NERRequest { - input: user_message, + input: callout_context.user_message.take().unwrap(), labels: entity_names, model: DEFAULT_NER_MODEL.to_string(), }; @@ -216,7 +215,7 @@ impl StreamContext { info!("ner_response: {:?}", ner_response); let mut request_params: HashMap = HashMap::new(); - for entity in ner_response.data.iter() { + for entity in ner_response.data.into_iter() { if entity.score < DEFAULT_NER_THRESHOLD { warn!( "score of entity was too low entity name: {}, score: {}", @@ -224,7 +223,7 @@ impl StreamContext { ); continue; } - request_params.insert(entity.label.clone(), entity.text.clone()); + request_params.insert(entity.label, entity.text); } let prompt_target = callout_context.prompt_target.as_ref().unwrap(); @@ -240,8 +239,8 @@ impl StreamContext { } } - let req_param_str = match serde_json::to_string(&request_params) { - Ok(req_param_str) => req_param_str, + let req_param_bytes = match serde_json::to_string(&request_params) { + Ok(req_param_str) => req_param_str.as_bytes().to_owned(), Err(e) => { warn!("Error serializing request_params: {:?}", e); self.resume_http_request(); @@ -258,25 +257,25 @@ impl StreamContext { .unwrap(); let http_path = match &endpoint.path { - Some(path) => path.clone(), - None => "/".to_string(), + Some(path) => path, + None => "/", }; let http_method = match &endpoint.method { - Some(method) => method.clone(), - None => http::Method::POST.to_string(), + Some(method) => method, + None => http::Method::POST.as_str(), }; let token_id = match self.dispatch_http_call( - &endpoint.cluster.clone(), + &endpoint.cluster, vec![ - (":method", http_method.as_str()), - (":path", http_path.as_str()), + (":method", http_method), + (":path", http_path), (":authority", endpoint.cluster.as_str()), ("content-type", "application/json"), ("x-envoy-max-retries", "3"), ], - Some(req_param_str.as_bytes()), + Some(&req_param_bytes), vec![], Duration::from_secs(5), ) { @@ -361,7 +360,8 @@ impl HttpContext for StreamContext { // Deserialize body into spec. // Currently OpenAI API. - let deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size) { + let mut deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size) + { Some(body_bytes) => match serde_json::from_slice(&body_bytes) { Ok(deserialized) => deserialized, Err(msg) => { @@ -389,8 +389,8 @@ impl HttpContext for StreamContext { let user_message = match deserialized_body .messages - .last() - .and_then(|last_message| last_message.content.as_ref()) + .pop() + .and_then(|last_message| last_message.content) { Some(content) => content, None => { @@ -400,6 +400,7 @@ impl HttpContext for StreamContext { }; let get_embeddings_input = CreateEmbeddingRequest { + // Need to clone into input because user_message is used below. input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())), model: String::from(DEFAULT_EMBEDDING_MODEL), encoding_format: None, @@ -437,7 +438,7 @@ impl HttpContext for StreamContext { }; let call_context = CallContext { request_type: RequestType::GetEmbedding, - user_message: user_message.clone(), + user_message: Some(user_message), prompt_target: None, request_body: deserialized_body, };