Use large github action machine to run e2e tests (#230)

This commit is contained in:
Adil Hafeez 2024-10-30 17:54:51 -07:00 committed by GitHub
parent bb882fb59b
commit e462e393b1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 4725 additions and 441 deletions

View file

@ -446,61 +446,89 @@ impl StreamContext {
// it may be that arch fc is handling the conversation for parameter collection
if arch_assistant {
info!("arch fc is engaged in parameter collection");
} else {
if let Some(default_prompt_target) = self
.prompt_targets
.values()
.find(|pt| pt.default.unwrap_or(false))
{
debug!(
"default prompt target found, forwarding request to default prompt target"
);
let endpoint = default_prompt_target.endpoint.clone().unwrap();
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
} else if let Some(default_prompt_target) = self
.prompt_targets
.values()
.find(|pt| pt.default.unwrap_or(false))
{
debug!("default prompt target found, forwarding request to default prompt target");
let endpoint = default_prompt_target.endpoint.clone().unwrap();
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
let upstream_endpoint = endpoint.name;
let mut params = HashMap::new();
params.insert(
MESSAGES_KEY.to_string(),
callout_context.request_body.messages.clone(),
);
let arch_messages_json = serde_json::to_string(&params).unwrap();
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
let upstream_endpoint = endpoint.name;
let mut params = HashMap::new();
params.insert(
MESSAGES_KEY.to_string(),
callout_context.request_body.messages.clone(),
);
let arch_messages_json = serde_json::to_string(&params).unwrap();
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
let mut headers = vec![
(":method", "POST"),
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
(":path", &upstream_path),
(":authority", &upstream_endpoint),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
];
let mut headers = vec![
(":method", "POST"),
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
(":path", &upstream_path),
(":authority", &upstream_endpoint),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&upstream_path,
headers,
Some(arch_messages_json.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
if let Err(e) = self.http_call(call_args, callout_context) {
warn!("error dispatching default prompt target request: {}", e);
return self.send_server_error(
ServerError::HttpDispatch(e),
Some(StatusCode::BAD_REQUEST),
);
}
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&upstream_path,
headers,
Some(arch_messages_json.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
if let Err(e) = self.http_call(call_args, callout_context) {
warn!("error dispatching default prompt target request: {}", e);
return self.send_server_error(
ServerError::HttpDispatch(e),
Some(StatusCode::BAD_REQUEST),
);
}
return;
} else {
// if no default prompt target is found and similarity score is low send response to upstream llm
// removing tool calls and tool response
let messages = self.filter_out_arch_messages(&callout_context);
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: callout_context.request_body.model,
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
Ok(json_string) => json_string,
Err(e) => {
return self.send_server_error(ServerError::Serialization(e), None);
}
};
debug!(
"archgw (low similarity score) => llm request: {}",
llm_request_str
);
self.set_http_request_body(
0,
self.request_body_size,
&llm_request_str.into_bytes(),
);
self.resume_http_request();
return;
}
@ -873,42 +901,8 @@ impl StreamContext {
"archgw <= api call response: {}",
self.tool_call_response.as_ref().unwrap()
);
let prompt_target_name = callout_context.prompt_target_name.unwrap();
let prompt_target = self
.prompt_targets
.get(&prompt_target_name)
.unwrap()
.clone();
let mut messages: Vec<Message> = Vec::new();
// add system prompt
let system_prompt = match prompt_target.system_prompt.as_ref() {
None => self.system_prompt.as_ref().clone(),
Some(system_prompt) => Some(system_prompt.clone()),
};
if system_prompt.is_some() {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: system_prompt,
model: None,
tool_calls: None,
tool_call_id: None,
};
messages.push(system_prompt_message);
}
// don't send tools message and api response to chat gpt
for m in callout_context.request_body.messages.iter() {
// don't send api response and tool calls to upstream LLMs
if m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
{
continue;
}
messages.push(m.clone());
}
let mut messages = self.filter_out_arch_messages(&callout_context);
let user_message = match messages.pop() {
Some(user_message) => user_message,
@ -960,6 +954,51 @@ impl StreamContext {
self.resume_http_request();
}
fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
let mut messages: Vec<Message> = Vec::new();
// add system prompt
let system_prompt = match callout_context.prompt_target_name.as_ref() {
None => self.system_prompt.as_ref().clone(),
Some(prompt_target_name) => {
let prompt_system_prompt = self
.prompt_targets
.get(prompt_target_name)
.unwrap()
.clone()
.system_prompt;
match prompt_system_prompt {
None => self.system_prompt.as_ref().clone(),
Some(system_prompt) => Some(system_prompt),
}
}
};
if system_prompt.is_some() {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: system_prompt,
model: None,
tool_calls: None,
tool_call_id: None,
};
messages.push(system_prompt_message);
}
// don't send tools message and api response to chat gpt
for m in callout_context.request_body.messages.iter() {
// don't send api response and tool calls to upstream LLMs
if m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
{
continue;
}
messages.push(m.clone());
}
messages
}
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
debug!(