Remove unnecessary clones (#26)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
Co-authored-by: Adil Hafeez <adil@katanemo.com>
This commit is contained in:
José Ulises Niño Rivera 2024-07-31 11:48:34 -07:00 committed by GitHub
parent c13682a03b
commit b8ea65d858
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 23 deletions

View file

@ -80,6 +80,7 @@ impl FilterContext {
};
let embedding_request = EmbeddingRequest {
create_embedding_request: embeddings_input,
// Need to clone prompt target to leave config string intact.
prompt_target: prompt_target.clone(),
};
if self
@ -106,7 +107,7 @@ impl FilterContext {
) {
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
let embedding_response: CreateEmbeddingResponse =
let mut embedding_response: CreateEmbeddingResponse =
serde_json::from_slice(&body).unwrap();
let mut payload: HashMap<String, String> = HashMap::new();
@ -127,7 +128,7 @@ impl FilterContext {
points: vec![VectorPoint {
id: format!("{:x}", id.unwrap()),
payload,
vector: embedding_response.data[0].embedding.clone(),
vector: embedding_response.data.remove(0).embedding,
}],
};
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors

View file

@ -30,7 +30,7 @@ enum RequestType {
/// State kept for a dispatched callout: the kind of request that was sent plus
/// the data needed when handling it further.
/// NOTE(review): purpose inferred from the construction site and field names — confirm.
pub struct CallContext {
request_type: RequestType,
// NOTE(review): this listing is a rendered diff; the `String` declaration below
// appears to be the pre-change line and the `Option<String>` one its replacement.
user_message: String,
// Wrapped in `Option` so the message can be moved out with `take()` (as done
// when building the NER request) instead of being cloned.
user_message: Option<String>,
// `None` when the context is created for the embedding request.
prompt_target: Option<PromptTarget>,
// The deserialized chat-completions request body.
request_body: ChatCompletions,
}
@ -160,12 +160,11 @@ impl StreamContext {
// only extract entity names
let entity_names = get_entity_details(&prompt_target)
.iter()
.map(|entity| entity.name.clone())
.into_iter()
.map(|entity| entity.name)
.collect();
let user_message = callout_context.user_message.clone();
let ner_request = NERRequest {
input: user_message,
input: callout_context.user_message.take().unwrap(),
labels: entity_names,
model: DEFAULT_NER_MODEL.to_string(),
};
@ -216,7 +215,7 @@ impl StreamContext {
info!("ner_response: {:?}", ner_response);
let mut request_params: HashMap<String, String> = HashMap::new();
for entity in ner_response.data.iter() {
for entity in ner_response.data.into_iter() {
if entity.score < DEFAULT_NER_THRESHOLD {
warn!(
"score of entity was too low entity name: {}, score: {}",
@ -224,7 +223,7 @@ impl StreamContext {
);
continue;
}
request_params.insert(entity.label.clone(), entity.text.clone());
request_params.insert(entity.label, entity.text);
}
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
@ -240,8 +239,8 @@ impl StreamContext {
}
}
let req_param_str = match serde_json::to_string(&request_params) {
Ok(req_param_str) => req_param_str,
let req_param_bytes = match serde_json::to_string(&request_params) {
Ok(req_param_str) => req_param_str.as_bytes().to_owned(),
Err(e) => {
warn!("Error serializing request_params: {:?}", e);
self.resume_http_request();
@ -258,25 +257,25 @@ impl StreamContext {
.unwrap();
let http_path = match &endpoint.path {
Some(path) => path.clone(),
None => "/".to_string(),
Some(path) => path,
None => "/",
};
let http_method = match &endpoint.method {
Some(method) => method.clone(),
None => http::Method::POST.to_string(),
Some(method) => method,
None => http::Method::POST.as_str(),
};
let token_id = match self.dispatch_http_call(
&endpoint.cluster.clone(),
&endpoint.cluster,
vec![
(":method", http_method.as_str()),
(":path", http_path.as_str()),
(":method", http_method),
(":path", http_path),
(":authority", endpoint.cluster.as_str()),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
],
Some(req_param_str.as_bytes()),
Some(&req_param_bytes),
vec![],
Duration::from_secs(5),
) {
@ -361,7 +360,8 @@ impl HttpContext for StreamContext {
// Deserialize body into spec.
// Currently OpenAI API.
let deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size) {
let mut deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size)
{
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(msg) => {
@ -389,8 +389,8 @@ impl HttpContext for StreamContext {
let user_message = match deserialized_body
.messages
.last()
.and_then(|last_message| last_message.content.as_ref())
.pop()
.and_then(|last_message| last_message.content)
{
Some(content) => content,
None => {
@ -400,6 +400,7 @@ impl HttpContext for StreamContext {
};
let get_embeddings_input = CreateEmbeddingRequest {
// Need to clone into input because user_message is used below.
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
@ -437,7 +438,7 @@ impl HttpContext for StreamContext {
};
let call_context = CallContext {
request_type: RequestType::GetEmbedding,
user_message: user_message.clone(),
user_message: Some(user_message),
prompt_target: None,
request_body: deserialized_body,
};