mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Remove unnecessary clones (#26)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com> Co-authored-by: Adil Hafeez <adil@katanemo.com>
This commit is contained in:
parent
c13682a03b
commit
b8ea65d858
2 changed files with 25 additions and 23 deletions
|
|
@ -80,6 +80,7 @@ impl FilterContext {
|
|||
};
|
||||
let embedding_request = EmbeddingRequest {
|
||||
create_embedding_request: embeddings_input,
|
||||
// Need to clone prompt target to leave config string intact.
|
||||
prompt_target: prompt_target.clone(),
|
||||
};
|
||||
if self
|
||||
|
|
@ -106,7 +107,7 @@ impl FilterContext {
|
|||
) {
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
let embedding_response: CreateEmbeddingResponse =
|
||||
let mut embedding_response: CreateEmbeddingResponse =
|
||||
serde_json::from_slice(&body).unwrap();
|
||||
|
||||
let mut payload: HashMap<String, String> = HashMap::new();
|
||||
|
|
@ -127,7 +128,7 @@ impl FilterContext {
|
|||
points: vec![VectorPoint {
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
vector: embedding_response.data.remove(0).embedding,
|
||||
}],
|
||||
};
|
||||
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ enum RequestType {
|
|||
|
||||
pub struct CallContext {
|
||||
request_type: RequestType,
|
||||
user_message: String,
|
||||
user_message: Option<String>,
|
||||
prompt_target: Option<PromptTarget>,
|
||||
request_body: ChatCompletions,
|
||||
}
|
||||
|
|
@ -160,12 +160,11 @@ impl StreamContext {
|
|||
|
||||
// only extract entity names
|
||||
let entity_names = get_entity_details(&prompt_target)
|
||||
.iter()
|
||||
.map(|entity| entity.name.clone())
|
||||
.into_iter()
|
||||
.map(|entity| entity.name)
|
||||
.collect();
|
||||
let user_message = callout_context.user_message.clone();
|
||||
let ner_request = NERRequest {
|
||||
input: user_message,
|
||||
input: callout_context.user_message.take().unwrap(),
|
||||
labels: entity_names,
|
||||
model: DEFAULT_NER_MODEL.to_string(),
|
||||
};
|
||||
|
|
@ -216,7 +215,7 @@ impl StreamContext {
|
|||
info!("ner_response: {:?}", ner_response);
|
||||
|
||||
let mut request_params: HashMap<String, String> = HashMap::new();
|
||||
for entity in ner_response.data.iter() {
|
||||
for entity in ner_response.data.into_iter() {
|
||||
if entity.score < DEFAULT_NER_THRESHOLD {
|
||||
warn!(
|
||||
"score of entity was too low entity name: {}, score: {}",
|
||||
|
|
@ -224,7 +223,7 @@ impl StreamContext {
|
|||
);
|
||||
continue;
|
||||
}
|
||||
request_params.insert(entity.label.clone(), entity.text.clone());
|
||||
request_params.insert(entity.label, entity.text);
|
||||
}
|
||||
|
||||
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
|
||||
|
|
@ -240,8 +239,8 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
let req_param_str = match serde_json::to_string(&request_params) {
|
||||
Ok(req_param_str) => req_param_str,
|
||||
let req_param_bytes = match serde_json::to_string(&request_params) {
|
||||
Ok(req_param_str) => req_param_str.as_bytes().to_owned(),
|
||||
Err(e) => {
|
||||
warn!("Error serializing request_params: {:?}", e);
|
||||
self.resume_http_request();
|
||||
|
|
@ -258,25 +257,25 @@ impl StreamContext {
|
|||
.unwrap();
|
||||
|
||||
let http_path = match &endpoint.path {
|
||||
Some(path) => path.clone(),
|
||||
None => "/".to_string(),
|
||||
Some(path) => path,
|
||||
None => "/",
|
||||
};
|
||||
|
||||
let http_method = match &endpoint.method {
|
||||
Some(method) => method.clone(),
|
||||
None => http::Method::POST.to_string(),
|
||||
Some(method) => method,
|
||||
None => http::Method::POST.as_str(),
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
&endpoint.cluster.clone(),
|
||||
&endpoint.cluster,
|
||||
vec![
|
||||
(":method", http_method.as_str()),
|
||||
(":path", http_path.as_str()),
|
||||
(":method", http_method),
|
||||
(":path", http_path),
|
||||
(":authority", endpoint.cluster.as_str()),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(req_param_str.as_bytes()),
|
||||
Some(&req_param_bytes),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
|
|
@ -361,7 +360,8 @@ impl HttpContext for StreamContext {
|
|||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size) {
|
||||
let mut deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size)
|
||||
{
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(msg) => {
|
||||
|
|
@ -389,8 +389,8 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let user_message = match deserialized_body
|
||||
.messages
|
||||
.last()
|
||||
.and_then(|last_message| last_message.content.as_ref())
|
||||
.pop()
|
||||
.and_then(|last_message| last_message.content)
|
||||
{
|
||||
Some(content) => content,
|
||||
None => {
|
||||
|
|
@ -400,6 +400,7 @@ impl HttpContext for StreamContext {
|
|||
};
|
||||
|
||||
let get_embeddings_input = CreateEmbeddingRequest {
|
||||
// Need to clone into input because user_message is used below.
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
|
|
@ -437,7 +438,7 @@ impl HttpContext for StreamContext {
|
|||
};
|
||||
let call_context = CallContext {
|
||||
request_type: RequestType::GetEmbedding,
|
||||
user_message: user_message.clone(),
|
||||
user_message: Some(user_message),
|
||||
prompt_target: None,
|
||||
request_body: deserialized_body,
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue