This commit is contained in:
Adil Hafeez 2024-12-05 10:11:55 -08:00
parent af02807004
commit 5e182b6c09
8 changed files with 125 additions and 42 deletions

View file

@ -49,7 +49,9 @@ def validate_and_render_schema():
if "prompt_targets" in config_yaml:
for prompt_target in config_yaml["prompt_targets"]:
name = prompt_target.get("endpoint", {}).get("name", "")
name = prompt_target.get("endpoint", {}).get("name", None)
if not name:
continue
if name not in inferred_clusters:
inferred_clusters[name] = {
"name": name,

View file

@ -494,6 +494,53 @@ impl StreamContext {
.find(|pt| pt.default.unwrap_or(false))
{
debug!("default prompt target found, forwarding request to default prompt target");
if default_prompt_target.endpoint.is_none() {
info!("default prompt target endpoint not found");
let system_prompt = self.get_system_prompt(Some(default_prompt_target.clone()));
let messages = vec![
Message {
content: system_prompt,
role: SYSTEM_ROLE.to_string(),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
},
Message {
content: self.user_prompt.as_ref().unwrap().content.clone(),
role: ASSISTANT_ROLE.to_string(),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
},
];
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: callout_context.request_body.model,
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
Ok(json_string) => json_string,
Err(e) => {
return self.send_server_error(ServerError::Serialization(e), None);
}
};
self.set_http_request_body(
0,
self.request_body_size,
&llm_request_str.into_bytes(),
);
self.resume_http_request();
return;
}
let endpoint = default_prompt_target.endpoint.clone().unwrap();
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
@ -547,7 +594,7 @@ impl StreamContext {
// if no default prompt target is found and similarity score is low send response to upstream llm
// removing tool calls and tool response
let messages = self.filter_out_arch_messages(&callout_context);
let messages = self.construct_llm_messages(&callout_context);
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: callout_context.request_body.model,
@ -988,7 +1035,7 @@ impl StreamContext {
self.tool_call_response.as_ref().unwrap()
);
let mut messages = self.filter_out_arch_messages(&callout_context);
let mut messages = self.construct_llm_messages(&callout_context);
let user_message = match messages.pop() {
Some(user_message) => user_message,
@ -1045,25 +1092,43 @@ impl StreamContext {
self.resume_http_request();
}
fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
let mut messages: Vec<Message> = Vec::new();
// add system prompt
fn get_system_prompt(&self, prompt_target: Option<PromptTarget>) -> Option<String> {
match prompt_target {
None => self.system_prompt.as_ref().clone(),
Some(prompt_target) => prompt_target.system_prompt,
}
}
fn filter_out_arch_messages(&self, messages: &Vec<Message>) -> Vec<Message> {
messages
.into_iter()
.filter(|m| {
if m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
{
true
} else {
false
}
})
.cloned()
.collect()
}
fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
let mut messages: Vec<Message> = Vec::new();
// add system prompt
let system_prompt = match callout_context.prompt_target_name.as_ref() {
None => self.system_prompt.as_ref().clone(),
Some(prompt_target_name) => {
let prompt_system_prompt = self
.prompt_targets
.get(prompt_target_name)
.unwrap()
.clone()
.system_prompt;
match prompt_system_prompt {
None => self.system_prompt.as_ref().clone(),
Some(system_prompt) => Some(system_prompt),
}
self.get_system_prompt(self.prompt_targets.get(prompt_target_name).cloned())
}
};
info!("messages 1: {:?}", callout_context.request_body.messages);
if system_prompt.is_some() {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
@ -1075,18 +1140,12 @@ impl StreamContext {
messages.push(system_prompt_message);
}
// don't send tools message and api response to chat gpt
for m in callout_context.request_body.messages.iter() {
// don't send api response and tool calls to upstream LLMs
if m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
{
continue;
}
messages.push(m.clone());
}
info!("messages 2: {:?}", messages);
messages.append(
&mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
);
info!("messages 3: {:?}", messages);
messages
}

View file

@ -66,7 +66,6 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/ManagedCluster"
components:
schemas:
ManagedCluster:

View file

@ -10,9 +10,7 @@ def main():
app.app.json_encoder = encoder.JSONEncoder
app.add_api(
"openapi.yaml",
arguments={
"title": "ACM API for cluster management - https://docs.redhat.com/en/documentation/red_hat_advanced_cluster_management_for_kubernetes/2.12/html/apis/apis#tags"
},
arguments={"title": "ACM API for cluster management"},
pythonic_params=True,
)

View file

@ -4,7 +4,8 @@ info:
email: support@katanemo.com
name: Katanemo Labs Inc.
url: https://katanemo.com
title: ACM API for cluster management - https://docs.redhat.com/en/documentation/red_hat_advanced_cluster_management_for_kubernetes/2.12/html/apis/apis#tags
description: This is the API for managing clusters using ACM - https://docs.redhat.com/en/documentation/red_hat_advanced_cluster_management_for_kubernetes/2.12/html/apis/apis#tags
title: ACM API for cluster management
version: 2.12.0
servers:
- url: /

View file

@ -16,19 +16,16 @@ REQUIRES = ["connexion>=2.0.2", "swagger-ui-bundle>=0.0.2", "python_dateutil>=2.
setup(
name=NAME,
version=VERSION,
description="ACM API for cluster management - https://docs.redhat.com/en/documentation/red_hat_advanced_cluster_management_for_kubernetes/2.12/html/apis/apis#tags",
description="ACM API for cluster management",
author_email="support@katanemo.com",
url="",
keywords=[
"OpenAPI",
"ACM API for cluster management - https://docs.redhat.com/en/documentation/red_hat_advanced_cluster_management_for_kubernetes/2.12/html/apis/apis#tags",
],
keywords=["OpenAPI", "ACM API for cluster management"],
install_requires=REQUIRES,
packages=find_packages(),
package_data={"": ["openapi/openapi.yaml"]},
include_package_data=True,
entry_points={"console_scripts": ["openapi_server=openapi_server.__main__:main"]},
long_description="""\
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
This is the API for managing clusters using ACM - https://docs.redhat.com/en/documentation/red_hat_advanced_cluster_management_for_kubernetes/2.12/html/apis/apis#tags
""",
)

View file

@ -42,14 +42,14 @@ prompt_guards:
message: Looks like you're curious about my abilities, but I can only provide assistance for weather forecasting.
prompt_targets:
- name: listManagedClusters
- name: listAllClusterDetails
description: Query your clusters for more details.
http_method: GET
endpoint:
name: acm_service
path: /cluster.open-cluster-management.io/v1/managedclusters
- name: getCluster
- name: getClusterDetails
description: Query a single cluster for more details
http_method: GET
endpoint:
@ -62,6 +62,34 @@ prompt_targets:
required: true
type: str
- name: default_target
default: true
description: This is the default target for all unmatched prompts.
system_prompt: |
You are a helpful assistant that can help answer ACM queries. Following is a list of available commands user can ask. Based on question asked, tell me which command you want to execute.
- name: listAllClusterDetails
description: Query your clusters for more details.
http_method: GET
endpoint:
name: acm_service
path: /cluster.open-cluster-management.io/v1/managedclusters
- name: getClusterDetails
description: Query a single cluster for more details
http_method: GET
endpoint:
name: acm_service
path: /cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}
parameters:
- name: cluster_name
in_path: true
description: The name of the cluster to retrieve
required: true
type: str
auto_llm_dispatch_on_response: true
tracing:
random_sampling: 100
trace_arch_internal: true

View file

@ -2,5 +2,4 @@ docker run --rm -v "${PWD}:/local" openapitools/openapi-generator-cli:latest gen
--skip-validate-spec \
-i /local/acm_api.yaml \
-g python-flask \
-o /local/acm_server \
# --additional-properties=defaultController=your.package.YourController \
-o /local/acm_service \