mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
even more fixes
This commit is contained in:
parent
6e2beaa968
commit
09e0f80204
3 changed files with 94 additions and 94 deletions
|
|
@ -446,95 +446,91 @@ impl StreamContext {
|
|||
// it may be that arch fc is handling the conversation for parameter collection
|
||||
if arch_assistant {
|
||||
info!("arch fc is engaged in parameter collection");
|
||||
} else {
|
||||
if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
.values()
|
||||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
debug!(
|
||||
"default prompt target found, forwarding request to default prompt target"
|
||||
);
|
||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
} else if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
.values()
|
||||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
debug!("default prompt target found, forwarding request to default prompt target");
|
||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
|
||||
let upstream_endpoint = endpoint.name;
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
MESSAGES_KEY.to_string(),
|
||||
callout_context.request_body.messages.clone(),
|
||||
);
|
||||
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
|
||||
let upstream_endpoint = endpoint.name;
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
MESSAGES_KEY.to_string(),
|
||||
callout_context.request_body.messages.clone(),
|
||||
);
|
||||
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
|
||||
|
||||
let mut headers = vec![
|
||||
(":method", "POST"),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
|
||||
(":path", &upstream_path),
|
||||
(":authority", &upstream_endpoint),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
];
|
||||
let mut headers = vec![
|
||||
(":method", "POST"),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
|
||||
(":path", &upstream_path),
|
||||
(":authority", &upstream_endpoint),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
];
|
||||
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&upstream_path,
|
||||
headers,
|
||||
Some(arch_messages_json.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
warn!("error dispatching default prompt target request: {}", e);
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
// if no default prompt target is found and similarity score is low send response to upstream llm
|
||||
// removing tool calls and tool response
|
||||
|
||||
let messages = self.filter_out_arch_messages(&callout_context);
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: callout_context.request_body.model,
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
|
||||
Ok(json_string) => json_string,
|
||||
Err(e) => {
|
||||
return self.send_server_error(ServerError::Serialization(e), None);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"archgw (low similarity score) => llm request: {}",
|
||||
llm_request_str
|
||||
);
|
||||
|
||||
self.set_http_request_body(
|
||||
0,
|
||||
self.request_body_size,
|
||||
&llm_request_str.into_bytes(),
|
||||
);
|
||||
|
||||
self.resume_http_request();
|
||||
return;
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&upstream_path,
|
||||
headers,
|
||||
Some(arch_messages_json.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
warn!("error dispatching default prompt target request: {}", e);
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
// if no default prompt target is found and similarity score is low send response to upstream llm
|
||||
// removing tool calls and tool response
|
||||
|
||||
let messages = self.filter_out_arch_messages(&callout_context);
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: callout_context.request_body.model,
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
|
||||
Ok(json_string) => json_string,
|
||||
Err(e) => {
|
||||
return self.send_server_error(ServerError::Serialization(e), None);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"archgw (low similarity score) => llm request: {}",
|
||||
llm_request_str
|
||||
);
|
||||
|
||||
self.set_http_request_body(
|
||||
0,
|
||||
self.request_body_size,
|
||||
&llm_request_str.into_bytes(),
|
||||
);
|
||||
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -959,17 +955,22 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
|
||||
|
||||
let mut messages: Vec<Message> = Vec::new();
|
||||
// add system prompt
|
||||
|
||||
let system_prompt = match callout_context.prompt_target_name.as_ref() {
|
||||
None => self.system_prompt.as_ref().clone(),
|
||||
Some(prompt_target_name) => {
|
||||
self.prompt_targets
|
||||
let prompt_system_prompt = self
|
||||
.prompt_targets
|
||||
.get(prompt_target_name)
|
||||
.unwrap()
|
||||
.clone()
|
||||
.system_prompt
|
||||
.system_prompt;
|
||||
match prompt_system_prompt {
|
||||
None => self.system_prompt.as_ref().clone(),
|
||||
Some(system_prompt) => Some(system_prompt),
|
||||
}
|
||||
}
|
||||
};
|
||||
if system_prompt.is_some() {
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ llm_providers:
|
|||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
model: gpt-4o-mini
|
||||
default: true
|
||||
|
||||
# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
|
||||
|
|
@ -26,8 +26,7 @@ endpoints:
|
|||
system_prompt: |
|
||||
You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workforce planning. Use following rules when responding,
|
||||
- when you get data in json format, offer some summary but don't be too verbose
|
||||
- be concise and to the point
|
||||
- if you don't have data, say so and offer to help with something else
|
||||
- be concise, to the point and do not over analyze the data
|
||||
|
||||
prompt_targets:
|
||||
- name: workforce
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ with open("workforce_data.json") as file:
|
|||
|
||||
|
||||
# Define the request model
|
||||
class WorkforceRequset(BaseModel):
|
||||
class WorkforceRequest(BaseModel):
|
||||
region: str
|
||||
staffing_type: str
|
||||
data_snapshot_days_ago: Optional[int] = None
|
||||
|
|
@ -74,7 +74,7 @@ def send_slack_message(request: SlackRequest):
|
|||
|
||||
# Post method for device summary
|
||||
@app.post("/agent/workforce")
|
||||
def get_workforce(request: WorkforceRequset):
|
||||
def get_workforce(request: WorkforceRequest):
|
||||
"""
|
||||
Endpoint to workforce data by region, staffing type at a given point in time.
|
||||
"""
|
||||
|
|
@ -90,7 +90,7 @@ def get_workforce(request: WorkforceRequset):
|
|||
"region": region,
|
||||
"staffing_type": f"Staffing agency: {staffing_type}",
|
||||
"headcount": f"Headcount: {int(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)][staffing_type].values[0])}",
|
||||
"satisfaction": f"Satisifaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)]['satisfaction'].values[0])}",
|
||||
"satisfaction": f"Satisfaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)]['satisfaction'].values[0])}",
|
||||
}
|
||||
return response
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue