diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index 293ffc8a..c32dae16 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -49,7 +49,9 @@ def validate_and_render_schema(): if "prompt_targets" in config_yaml: for prompt_target in config_yaml["prompt_targets"]: - name = prompt_target.get("endpoint", {}).get("name", "") + name = prompt_target.get("endpoint", {}).get("name", None) + if not name: + continue if name not in inferred_clusters: inferred_clusters[name] = { "name": name, diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index ebcb0636..501e25a9 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -494,6 +494,53 @@ impl StreamContext { .find(|pt| pt.default.unwrap_or(false)) { debug!("default prompt target found, forwarding request to default prompt target"); + if default_prompt_target.endpoint.is_none() { + info!("default prompt target endpoint not found"); + + let system_prompt = self.get_system_prompt(Some(default_prompt_target.clone())); + + let messages = vec![ + Message { + content: system_prompt, + role: SYSTEM_ROLE.to_string(), + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: None, + tool_call_id: None, + }, + Message { + content: self.user_prompt.as_ref().unwrap().content.clone(), + role: ASSISTANT_ROLE.to_string(), + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: None, + tool_call_id: None, + }, + ]; + + let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest { + model: callout_context.request_body.model, + messages, + tools: None, + stream: callout_context.request_body.stream, + stream_options: callout_context.request_body.stream_options, + metadata: None, + }; + + let llm_request_str = match serde_json::to_string(&chat_completions_request) { + Ok(json_string) => json_string, + Err(e) => { + return self.send_server_error(ServerError::Serialization(e), None); + } + }; + + self.set_http_request_body( + 0, + self.request_body_size, + &llm_request_str.into_bytes(), + ); + + self.resume_http_request(); + return; + } let endpoint = default_prompt_target.endpoint.clone().unwrap(); let upstream_path: String = endpoint.path.unwrap_or(String::from("/")); @@ -547,7 +594,7 @@ impl StreamContext { // if no default prompt target is found and similarity score is low send response to upstream llm // removing tool calls and tool response - let messages = self.filter_out_arch_messages(&callout_context); + let messages = self.construct_llm_messages(&callout_context); let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest { model: callout_context.request_body.model, @@ -988,7 +1035,7 @@ impl StreamContext { self.tool_call_response.as_ref().unwrap() ); - let mut messages = self.filter_out_arch_messages(&callout_context); + let mut messages = self.construct_llm_messages(&callout_context); let user_message = match messages.pop() { Some(user_message) => user_message, @@ -1045,25 +1092,43 @@ impl StreamContext { self.resume_http_request(); } - fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec { - let mut messages: Vec = Vec::new(); - // add system prompt + fn get_system_prompt(&self, prompt_target: Option) -> Option { + match prompt_target { + None => self.system_prompt.as_ref().clone(), + Some(prompt_target) => prompt_target.system_prompt, + } + } + fn filter_out_arch_messages(&self, messages: &Vec) -> Vec { + messages + .into_iter() + .filter(|m| { + if m.role == TOOL_ROLE + || m.content.is_none() + || (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty()) + { + true + } else { + false + } + }) + .cloned() + .collect() + } + + fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec { + let mut messages: Vec = Vec::new(); + + // add system prompt let system_prompt = match callout_context.prompt_target_name.as_ref() { None => self.system_prompt.as_ref().clone(), Some(prompt_target_name) => { - let prompt_system_prompt = self - .prompt_targets - .get(prompt_target_name) - .unwrap() - .clone() - .system_prompt; - match prompt_system_prompt { - None => self.system_prompt.as_ref().clone(), - Some(system_prompt) => Some(system_prompt), - } + self.get_system_prompt(self.prompt_targets.get(prompt_target_name).cloned()) } }; + + info!("messages 1: {:?}", callout_context.request_body.messages); + if system_prompt.is_some() { let system_prompt_message = Message { role: SYSTEM_ROLE.to_string(), @@ -1075,18 +1140,12 @@ impl StreamContext { messages.push(system_prompt_message); } - // don't send tools message and api response to chat gpt - for m in callout_context.request_body.messages.iter() { - // don't send api response and tool calls to upstream LLMs - if m.role == TOOL_ROLE - || m.content.is_none() - || (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty()) - { - continue; - } - messages.push(m.clone()); - } + info!("messages 2: {:?}", messages); + messages.append( + &mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()), + ); + info!("messages 3: {:?}", messages); messages } diff --git a/demos/acm_k8s/acm_api.yaml b/demos/acm_k8s/acm_api.yaml index 4b4ec9fd..e7638523 100644 --- a/demos/acm_k8s/acm_api.yaml +++ b/demos/acm_k8s/acm_api.yaml @@ -66,7 +66,6 @@ paths: application/json: schema: $ref: "#/components/schemas/ManagedCluster" - components: schemas: ManagedCluster: diff --git a/demos/acm_k8s/acm_service/openapi_server/__main__.py b/demos/acm_k8s/acm_service/openapi_server/__main__.py index 417a9020..31ada589 100644 --- a/demos/acm_k8s/acm_service/openapi_server/__main__.py +++ b/demos/acm_k8s/acm_service/openapi_server/__main__.py @@ -10,9 +10,7 @@ def main(): app.app.json_encoder = encoder.JSONEncoder app.add_api( "openapi.yaml", - arguments={ - "title": "ACM API for cluster management - https://docs.redhat.com/en/documentation/red_hat_advanced_cluster_management_for_kubernetes/2.12/html/apis/apis#tags" - }, + arguments={"title": "ACM API for cluster management"}, pythonic_params=True, ) diff --git a/demos/acm_k8s/acm_service/openapi_server/openapi/openapi.yaml b/demos/acm_k8s/acm_service/openapi_server/openapi/openapi.yaml index 61202a34..46503e9c 100644 --- a/demos/acm_k8s/acm_service/openapi_server/openapi/openapi.yaml +++ b/demos/acm_k8s/acm_service/openapi_server/openapi/openapi.yaml @@ -4,7 +4,8 @@ info: email: support@katanemo.com name: Katanemo Labs Inc. url: https://katanemo.com - title: ACM API for cluster management - https://docs.redhat.com/en/documentation/red_hat_advanced_cluster_management_for_kubernetes/2.12/html/apis/apis#tags + description: This is the API for managing clusters using ACM - https://docs.redhat.com/en/documentation/red_hat_advanced_cluster_management_for_kubernetes/2.12/html/apis/apis#tags + title: ACM API for cluster management version: 2.12.0 servers: - url: / diff --git a/demos/acm_k8s/acm_service/setup.py b/demos/acm_k8s/acm_service/setup.py index 6eb19f15..36741ffd 100644 --- a/demos/acm_k8s/acm_service/setup.py +++ b/demos/acm_k8s/acm_service/setup.py @@ -16,19 +16,16 @@ REQUIRES = ["connexion>=2.0.2", "swagger-ui-bundle>=0.0.2", "python_dateutil>=2. setup( name=NAME, version=VERSION, - description="ACM API for cluster management - https://docs.redhat.com/en/documentation/red_hat_advanced_cluster_management_for_kubernetes/2.12/html/apis/apis#tags", + description="ACM API for cluster management", author_email="support@katanemo.com", url="", - keywords=[ - "OpenAPI", - "ACM API for cluster management - https://docs.redhat.com/en/documentation/red_hat_advanced_cluster_management_for_kubernetes/2.12/html/apis/apis#tags", - ], + keywords=["OpenAPI", "ACM API for cluster management"], install_requires=REQUIRES, packages=find_packages(), package_data={"": ["openapi/openapi.yaml"]}, include_package_data=True, entry_points={"console_scripts": ["openapi_server=openapi_server.__main__:main"]}, long_description="""\ - No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) + This is the API for managing clusters using ACM - https://docs.redhat.com/en/documentation/red_hat_advanced_cluster_management_for_kubernetes/2.12/html/apis/apis#tags """, ) diff --git a/demos/acm_k8s/arch_config.yaml b/demos/acm_k8s/arch_config.yaml index c7d315f8..cdfd6bee 100644 --- a/demos/acm_k8s/arch_config.yaml +++ b/demos/acm_k8s/arch_config.yaml @@ -42,14 +42,14 @@ prompt_guards: message: Looks like you're curious about my abilities, but I can only provide assistance for weather forecasting. prompt_targets: - - name: listManagedClusters + - name: listAllClusterDetails description: Query your clusters for more details. http_method: GET endpoint: name: acm_service path: /cluster.open-cluster-management.io/v1/managedclusters - - name: getCluster + - name: getClusterDetails description: Query a single cluster for more details http_method: GET endpoint: @@ -62,6 +62,34 @@ prompt_targets: required: true type: str + - name: default_target + default: true + description: This is the default target for all unmatched prompts. + system_prompt: | + You are a helpful assistant that can help answer ACM queries. Following is a list of available commands user can ask. Based on question asked, tell me which command you want to execute. + + - name: listAllClusterDetails + description: Query your clusters for more details. + http_method: GET + endpoint: + name: acm_service + path: /cluster.open-cluster-management.io/v1/managedclusters + + - name: getClusterDetails + description: Query a single cluster for more details + http_method: GET + endpoint: + name: acm_service + path: /cluster.open-cluster-management.io/v1/managedclusters/{cluster_name} + parameters: + - name: cluster_name + in_path: true + description: The name of the cluster to retrieve + required: true + type: str + + auto_llm_dispatch_on_response: true + tracing: random_sampling: 100 trace_arch_internal: true diff --git a/demos/acm_k8s/generate_acm_service_stub.sh b/demos/acm_k8s/generate_acm_service_stub.sh index a3ead749..2dedf44c 100644 --- a/demos/acm_k8s/generate_acm_service_stub.sh +++ b/demos/acm_k8s/generate_acm_service_stub.sh @@ -2,5 +2,4 @@ docker run --rm -v "${PWD}:/local" openapitools/openapi-generator-cli:latest gen --skip-validate-spec \ -i /local/acm_api.yaml \ -g python-flask \ - -o /local/acm_server \ - # --additional-properties=defaultController=your.package.YourController \ + -o /local/acm_service \