diff --git a/arch/Cargo.lock b/arch/Cargo.lock
index b72c678c..201f2e68 100644
--- a/arch/Cargo.lock
+++ b/arch/Cargo.lock
@@ -441,6 +441,15 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "duration-string"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6fcc1d9ae294a15ed05aeae8e11ee5f2b3fe971c077d45a42fb20825fba6ee13"
+dependencies = [
+ "serde",
+]
+
 [[package]]
 name = "either"
 version = "1.13.0"
@@ -1075,6 +1084,7 @@ dependencies = [
 name = "public_types"
 version = "0.1.0"
 dependencies = [
+ "duration-string",
  "serde",
  "serde_yaml",
 ]
diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml
index 99965503..fb6a4f3b 100644
--- a/arch/envoy.template.yaml
+++ b/arch/envoy.template.yaml
@@ -176,7 +176,11 @@ static_resources:
             hostname: "arch_fc"
 {% for _, cluster in arch_clusters.items() %}
   - name: {{ cluster.name }}
+    {% if cluster.connect_timeout -%}
+    connect_timeout: {{ cluster.connect_timeout }}
+    {% else -%}
     connect_timeout: 5s
+    {% endif -%}
     type: STRICT_DNS
     lb_policy: ROUND_ROBIN
     load_assignment:
@@ -186,7 +190,7 @@ static_resources:
         - endpoint:
             address:
               socket_address:
-                address: {{ cluster.address }}
+                address: {{ cluster.endpoint }}
                 port_value: {{ cluster.port }}
-            hostname: {{ cluster.address }}
+            hostname: {{ cluster.name }}
 {% endfor %}
diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs
index 69c65092..7a099afb 100644
--- a/arch/src/stream_context.rs
+++ b/arch/src/stream_context.rs
@@ -23,7 +23,7 @@ use public_types::common_types::{
     EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
     ZeroShotClassificationRequest, ZeroShotClassificationResponse,
 };
-use public_types::configuration::{Overrides, PromptGuards, PromptTarget, PromptType};
+use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
 use public_types::embeddings::{
     CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
 };
@@ -358,103 +358,97 @@ impl StreamContext {
 
         info!("prompt_target name: {:?}", prompt_target_name);
 
-        match prompt_target.prompt_type {
-            PromptType::FunctionResolver => {
-                let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
-                for pt in self.prompt_targets.read().unwrap().values() {
-                    // only extract entity names
-                    let properties: HashMap<String, FunctionParameter> = match pt.parameters {
-                        // Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
-                        Some(ref entities) => {
-                            let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
-                            for entity in entities.iter() {
-                                let param = FunctionParameter {
-                                    parameter_type: ParameterType::from(
-                                        entity.parameter_type.clone().unwrap_or("str".to_string()),
-                                    ),
-                                    description: entity.description.clone(),
-                                    required: entity.required,
-                                    enum_values: entity.enum_values.clone(),
-                                    default: entity.default.clone(),
-                                };
-                                properties.insert(entity.name.clone(), param);
-                            }
-                            properties
-                        }
-                        None => HashMap::new(),
-                    };
-                    let tools_parameters = FunctionParameters { properties };
-
-                    chat_completion_tools.push({
-                        ChatCompletionTool {
-                            tool_type: ToolType::Function,
-                            function: FunctionDefinition {
-                                name: pt.name.clone(),
-                                description: pt.description.clone(),
-                                parameters: tools_parameters,
-                            },
-                        }
-                    });
-                }
-
-                let chat_completions = ChatCompletionsRequest {
-                    model: GPT_35_TURBO.to_string(),
-                    messages: callout_context.request_body.messages.clone(),
-                    tools: Some(chat_completion_tools),
-                    stream: false,
-                    stream_options: None,
-                };
-
-                let msg_body = match serde_json::to_string(&chat_completions) {
-                    Ok(msg_body) => {
-                        debug!("arch_fc request body content: {}", msg_body);
-                        msg_body
-                    }
-                    Err(e) => {
-                        return self.send_server_error(
-                            format!("Error serializing request_params: {:?}", e),
-                            None,
-                        );
-                    }
-                };
-
-                let token_id = match self.dispatch_http_call(
-                    ARC_FC_CLUSTER,
-                    vec![
-                        (":method", "POST"),
-                        (":path", "/v1/chat/completions"),
-                        (":authority", ARC_FC_CLUSTER),
-                        ("content-type", "application/json"),
-                        ("x-envoy-max-retries", "3"),
-                        (
-                            "x-envoy-upstream-rq-timeout-ms",
-                            ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
-                        ),
-                    ],
-                    Some(msg_body.as_bytes()),
-                    vec![],
-                    Duration::from_secs(5),
-                ) {
-                    Ok(token_id) => token_id,
-                    Err(e) => {
-                        let error_msg =
-                            format!("Error dispatching HTTP call for function-call: {:?}", e);
-                        return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
-                    }
-                };
-
-                debug!(
-                    "dispatched call to function {} token_id={}",
-                    ARC_FC_CLUSTER, token_id
-                );
-
-                self.metrics.active_http_calls.increment(1);
-                callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
-                callout_context.prompt_target_name = Some(prompt_target.name);
-                if self.callouts.insert(token_id, callout_context).is_some() {
-                    panic!("duplicate token_id")
+        //TODO: handle default function resolver type
+        let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
+        for pt in self.prompt_targets.read().unwrap().values() {
+            // only extract entity names
+            let properties: HashMap<String, FunctionParameter> = match pt.parameters {
+                // Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
+                Some(ref entities) => {
+                    let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
+                    for entity in entities.iter() {
+                        let param = FunctionParameter {
+                            parameter_type: ParameterType::from(
+                                entity.parameter_type.clone().unwrap_or("str".to_string()),
+                            ),
+                            description: entity.description.clone(),
+                            required: entity.required,
+                            enum_values: entity.enum_values.clone(),
+                            default: entity.default.clone(),
+                        };
+                        properties.insert(entity.name.clone(), param);
+                    }
+                    properties
                 }
+                None => HashMap::new(),
+            };
+            let tools_parameters = FunctionParameters { properties };
+
+            chat_completion_tools.push({
+                ChatCompletionTool {
+                    tool_type: ToolType::Function,
+                    function: FunctionDefinition {
+                        name: pt.name.clone(),
+                        description: pt.description.clone(),
+                        parameters: tools_parameters,
+                    },
+                }
+            });
+        }
+
+        let chat_completions = ChatCompletionsRequest {
+            model: GPT_35_TURBO.to_string(),
+            messages: callout_context.request_body.messages.clone(),
+            tools: Some(chat_completion_tools),
+            stream: false,
+            stream_options: None,
+        };
+
+        let msg_body = match serde_json::to_string(&chat_completions) {
+            Ok(msg_body) => {
+                debug!("arch_fc request body content: {}", msg_body);
+                msg_body
             }
+            Err(e) => {
+                return self
+                    .send_server_error(format!("Error serializing request_params: {:?}", e), None);
+            }
+        };
+
+        let token_id = match self.dispatch_http_call(
+            ARC_FC_CLUSTER,
+            vec![
+                (":method", "POST"),
+                (":path", "/v1/chat/completions"),
+                (":authority", ARC_FC_CLUSTER),
+                ("content-type", "application/json"),
+                ("x-envoy-max-retries", "3"),
+                (
+                    "x-envoy-upstream-rq-timeout-ms",
+                    ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
+                ),
+            ],
+            Some(msg_body.as_bytes()),
+            vec![],
+            Duration::from_secs(5),
+        ) {
+            Ok(token_id) => token_id,
+            Err(e) => {
+                let error_msg = format!("Error dispatching HTTP call for function-call: {:?}", e);
+                return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
+            }
+        };
+
+        debug!(
+            "dispatched call to function {} token_id={}",
+            ARC_FC_CLUSTER, token_id
+        );
+
+        self.metrics.active_http_calls.increment(1);
+        callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
+        callout_context.prompt_target_name = Some(prompt_target.name);
+        if self.callouts.insert(token_id, callout_context).is_some() {
+            panic!("duplicate token_id")
         }
     }
 
@@ -530,17 +524,32 @@ impl StreamContext {
         debug!("tool_params: {}", tool_params_json_str);
 
         let endpoint = prompt_target.endpoint.unwrap();
-        let path = endpoint.path.unwrap_or(String::from("/"));
+        let mut path = endpoint.path.unwrap_or(String::from("/"));
+        let method = endpoint
+            .method
+            .unwrap_or(public_types::configuration::Method::Post);
+        let mut body = Some(tool_params_json_str.as_bytes());
+        if method == public_types::configuration::Method::Get {
+            // for GET requests, pass the tool parameters in the query string and drop the body
+            let mut query_params = vec![];
+            for (key, value) in tool_params {
+                query_params.push(format!("{}={}", key, format!("{:?}", value)));
+            }
+            let path_args = &query_params.join("&");
+            path.push_str("?");
+            path.push_str(path_args);
+            body = None;
+        }
         let token_id = match self.dispatch_http_call(
-            &endpoint.cluster,
+            &endpoint.name,
             vec![
-                (":method", "POST"),
+                (":method", method.to_string().as_str()),
                 (":path", path.as_ref()),
-                (":authority", endpoint.cluster.as_str()),
+                (":authority", endpoint.name.as_str()),
                 ("content-type", "application/json"),
                 ("x-envoy-max-retries", "3"),
             ],
-            Some(tool_params_json_str.as_bytes()),
+            body,
             vec![],
             Duration::from_secs(5),
         ) {
@@ -548,14 +557,14 @@ impl StreamContext {
             Err(e) => {
                 let error_msg = format!(
                     "Error dispatching call to cluster: {}, path: {}, err: {:?}",
-                    &endpoint.cluster, path, e
+                    &endpoint.name, path, e
                 );
                 debug!("{}", error_msg);
                 return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
             }
         };
 
-        callout_context.up_stream_cluster = Some(endpoint.cluster);
+        callout_context.up_stream_cluster = Some(endpoint.name);
         callout_context.up_stream_cluster_path = Some(path);
         callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
         if self.callouts.insert(token_id, callout_context).is_some() {
@@ -682,27 +691,18 @@ impl StreamContext {
         if prompt_guard_resp.jailbreak_verdict.is_some()
             && prompt_guard_resp.jailbreak_verdict.unwrap()
         {
+            //TODO: handle other scenarios like forward to error target
             let default_err = "Jailbreak detected. Please refrain from discussing jailbreaking.";
             let error_msg = match self.prompt_guards.as_ref() {
-                Some(prompt_guards) => match prompt_guards.input_guards.jailbreak.as_ref() {
-                    Some(jailbreak) => match jailbreak.on_exception_message.as_ref() {
-                        Some(error_msg) => error_msg,
-                        None => default_err,
-                    },
-                    None => default_err,
-                },
-                None => default_err,
-            };
-
-            return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
-        }
-
-        if prompt_guard_resp.toxic_verdict.is_some() && prompt_guard_resp.toxic_verdict.unwrap() {
-            let default_err = "Toxicity detected. Please refrain from using toxic language.";
-            let error_msg = match self.prompt_guards.as_ref() {
-                Some(prompt_guards) => match prompt_guards.input_guards.toxicity.as_ref() {
-                    Some(toxicity) => match toxicity.on_exception_message.as_ref() {
-                        Some(error_msg) => error_msg,
+                Some(prompt_guards) => match prompt_guards
+                    .input_guards
+                    .get(&public_types::configuration::GuardType::Jailbreak)
+                {
+                    Some(jailbreak) => match jailbreak.on_exception.as_ref() {
+                        Some(on_exception_details) => match on_exception_details.message.as_ref() {
+                            Some(error_msg) => error_msg,
+                            None => default_err,
+                        },
                         None => default_err,
                     },
                     None => default_err,
@@ -883,32 +883,27 @@ impl HttpContext for StreamContext {
                 }
             };
 
-            let prompt_guard_task = match (
-                prompt_guards.input_guards.toxicity.is_some(),
-                prompt_guards.input_guards.jailbreak.is_some(),
-            ) {
-                (true, true) => PromptGuardTask::Both,
-                (true, false) => PromptGuardTask::Toxicity,
-                (false, true) => PromptGuardTask::Jailbreak,
-                (false, false) => {
-                    info!("Input guards set but no prompt guards were found");
-                    let callout_context = CallContext {
-                        response_handler_type: ResponseHandlerType::ArchGuard,
-                        user_message: Some(user_message),
-                        prompt_target_name: None,
-                        request_body: deserialized_body,
-                        similarity_scores: None,
-                        up_stream_cluster: None,
-                        up_stream_cluster_path: None,
-                    };
-                    self.get_embeddings(callout_context);
-                    return Action::Pause;
-                }
-            };
+            let prompt_guard_jailbreak_task = prompt_guards
+                .input_guards
+                .contains_key(&public_types::configuration::GuardType::Jailbreak);
+            if !prompt_guard_jailbreak_task {
+                info!("Input guards set but no prompt guards were found");
+                let callout_context = CallContext {
+                    response_handler_type: ResponseHandlerType::ArchGuard,
+                    user_message: Some(user_message),
+                    prompt_target_name: None,
+                    request_body: deserialized_body,
+                    similarity_scores: None,
+                    up_stream_cluster: None,
+                    up_stream_cluster_path: None,
+                };
+                self.get_embeddings(callout_context);
+                return Action::Pause;
+            }
 
             let get_prompt_guards_request = PromptGuardRequest {
                 input: user_message.clone(),
-                task: prompt_guard_task,
+                task: PromptGuardTask::Jailbreak,
             };
 
             let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
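Worth noting: `format!("{:?}", value)` Debug-formats each parameter value, so string parameters arrive quoted in the query string, which appears to be why the demo API server strips `"` from incoming values. A standalone sketch of the resulting path (hypothetical helper, assuming string-valued parameters; not code from this patch):

```rust
use std::collections::HashMap;

// Mirror of the query-string construction above, assuming `tool_params`
// maps string keys to string values. Debug formatting ({:?}) wraps each
// value in double quotes.
fn build_query_path(mut path: String, tool_params: &HashMap<String, String>) -> String {
    let query_params: Vec<String> = tool_params
        .iter()
        .map(|(key, value)| format!("{}={}", key, format!("{:?}", value)))
        .collect();
    path.push('?');
    path.push_str(&query_params.join("&"));
    path
}

fn main() {
    let mut params = HashMap::new();
    params.insert("timezone".to_string(), "US/Pacific".to_string());
    // Prints: /current_time?timezone="US/Pacific"
    println!("{}", build_query_path("/current_time".to_string(), &params));
}
```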
diff --git a/arch/tests/integration.rs b/arch/tests/integration.rs
index 21ce8979..336454e0 100644
--- a/arch/tests/integration.rs
+++ b/arch/tests/integration.rs
@@ -175,27 +175,36 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
 
 fn default_config() -> Configuration {
     let config: &str = r#"
-default_prompt_endpoint: "127.0.0.1"
-load_balancing: "round_robin"
-timeout_ms: 5000
+version: "0.1-beta"
+
+listener:
+  address: 0.0.0.0
+  port: 10000
+  message_format: huggingface
+  connect_timeout: 0.005s
+
+endpoints:
+  api_server:
+    endpoint: api_server:80
+    connect_timeout: 0.005s
 
 llm_providers:
-  - name: "open-ai-gpt-4"
-    api_key: "$OPEN_AI_API_KEY"
+  - name: open-ai-gpt-4
+    access_key: $OPEN_AI_API_KEY
     model: gpt-4
+    default: true
+
+overrides:
+  # confidence threshold for prompt target intent matching
+  prompt_target_intent_matching_threshold: 0.6
 
 system_prompt: |
-  You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
-  - Use farenheight for temperature
-  - Use miles per hour for wind speed
+  You are a helpful assistant.
 
 prompt_targets:
-  - type: function_resolver
-    name: weather_forecast
-    description: This resolver provides weather forecast information.
-    endpoint:
-      cluster: weatherhost
-      path: /weather
+
+  - name: weather_forecast
+    description: This function provides realtime weather forecast information for a given city.
     parameters:
       - name: city
         required: true
@@ -204,16 +213,32 @@ prompt_targets:
         description: The number of days for which the weather forecast is requested.
       - name: units
         description: The units in which the weather forecast is requested.
-
-  - type: function_resolver
-    name: weather_forecast_2
-    description: This resolver provides weather forecast information.
     endpoint:
-      cluster: weatherhost
+      name: api_server
       path: /weather
-    entities:
-      - name: city
+    system_prompt: |
+      You are a helpful weather forecaster. Use weather data that is provided to you. Please follow these guidelines when responding to user queries:
+      - Use Fahrenheit for temperature
+      - Use miles per hour for wind speed
+
+  - name: insurance_claim_details
+    type: function_resolver
+    description: This function resolver provides insurance claim details for a given policy number.
+    parameters:
+      - name: policy_number
+        required: true
+        description: The policy number for which the insurance claim details are requested.
+        type: string
+      - name: include_expired
+        description: whether to include expired insurance claims in the response.
+        type: bool
+        required: true
+    endpoint:
+      name: api_server
+      path: /insurance_claim_details
+    system_prompt: |
+      You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please follow these guidelines when responding to user queries:
+      - Use policy number to retrieve insurance claim details
 
 ratelimits:
   - provider: gpt-3.5-turbo
     selector:
@@ -222,7 +247,7 @@ ratelimits:
     limit:
       tokens: 1
       unit: minute
-    "#;
+"#;
 
     serde_yaml::from_str(config).unwrap()
 }
@@ -442,7 +467,7 @@ fn request_ratelimited() {
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
-        .expect_http_call(Some("weatherhost"), None, None, None, None)
+        .expect_http_call(Some("api_server"), None, None, None, None)
        .returning(Some(4))
         .expect_metric_increment("active_http_calls", 1)
         .execute_and_expect(ReturnType::None)
@@ -557,7 +582,7 @@ fn request_not_ratelimited() {
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
-        .expect_http_call(Some("weatherhost"), None, None, None, None)
+        .expect_http_call(Some("api_server"), None, None, None, None)
         .returning(Some(4))
         .expect_metric_increment("active_http_calls", 1)
         .execute_and_expect(ReturnType::None)
diff --git a/config_generator/config_generator.py b/config_generator/config_generator.py
index 386806f1..3b0b048a 100644
--- a/config_generator/config_generator.py
+++ b/config_generator/config_generator.py
@@ -17,25 +17,28 @@ config_yaml = yaml.safe_load(katanemo_config)
 
 inferred_clusters = {}
 for prompt_target in config_yaml["prompt_targets"]:
-    cluster = prompt_target.get("endpoint", {}).get("cluster", "")
-    if cluster not in inferred_clusters:
-        inferred_clusters[cluster] = {
-            "name": cluster,
-            "address": cluster,
+    name = prompt_target.get("endpoint", {}).get("name", "")
+    if name not in inferred_clusters:
+        inferred_clusters[name] = {
+            "name": name,
             "port": 80,  # default port
         }
 
 print(inferred_clusters)
 
-clusters = config_yaml.get("clusters", {})
+endpoints = config_yaml.get("endpoints", {})
 
 # override the inferred clusters with the ones defined in the config
-for name, cluster in clusters.items():
+for name, endpoint_details in endpoints.items():
     if name in inferred_clusters:
-        print("updating cluster", cluster)
-        inferred_clusters[name].update(cluster)
+        print("updating cluster", endpoint_details)
+        inferred_clusters[name].update(endpoint_details)
+        # split "host:port" endpoint values into separate host and port fields
+        endpoint = inferred_clusters[name]['endpoint']
+        if len(endpoint.split(':')) > 1:
+            inferred_clusters[name]['endpoint'] = endpoint.split(':')[0]
+            inferred_clusters[name]['port'] = int(endpoint.split(':')[1])
     else:
-        inferred_clusters[name] = cluster
+        inferred_clusters[name] = endpoint_details
 
 print("updated clusters", inferred_clusters)
diff --git a/demos/function_calling/api_server/app/main.py b/demos/function_calling/api_server/app/main.py
index f4ca6fa4..a2d03853 100644
--- a/demos/function_calling/api_server/app/main.py
+++ b/demos/function_calling/api_server/app/main.py
@@ -3,6 +3,7 @@ from fastapi import FastAPI, Response
 from datetime import datetime, date, timedelta, timezone
 import logging
 from pydantic import BaseModel
+import pytz
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
@@ -56,3 +57,20 @@ async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Response):
     }
 
     return claim_details
+
+@app.get("/current_time")
+async def current_time(timezone: str):
+    tz = None
+    try:
+        # strip the quotes that Arch adds when serializing the query string
+        timezone = timezone.strip('"')
+        tz = pytz.timezone(timezone)
+    except pytz.exceptions.UnknownTimeZoneError:
+        return {
+            "error": "Invalid timezone: {}".format(timezone)
+        }
+    current_time = datetime.now(tz)
+    return {
+        "timezone": timezone,
+        "current_time": current_time.strftime("%Y-%m-%d %H:%M:%S %Z")
+    }
diff --git a/demos/function_calling/api_server/requirements.txt b/demos/function_calling/api_server/requirements.txt
index 97dc7cd8..566bcbcb 100644
--- a/demos/function_calling/api_server/requirements.txt
+++ b/demos/function_calling/api_server/requirements.txt
@@ -1,2 +1,3 @@
 fastapi
 uvicorn
+pytz
diff --git a/demos/function_calling/arch_config.yaml b/demos/function_calling/arch_config.yaml
index 065beb03..fc7fffb8 100644
--- a/demos/function_calling/arch_config.yaml
+++ b/demos/function_calling/arch_config.yaml
@@ -1,22 +1,32 @@
-default_prompt_endpoint: "127.0.0.1"
-load_balancing: "round_robin"
-timeout_ms: 5000
+version: "0.1-beta"
+
+listener:
+  address: 0.0.0.0
+  port: 10000
+  message_format: huggingface
+  connect_timeout: 0.005s
+
+endpoints:
+  api_server:
+    endpoint: api_server:80
+    connect_timeout: 0.005s
+
+llm_providers:
+  - name: open-ai-gpt-4
+    access_key: $OPEN_AI_API_KEY
+    model: gpt-4
+    default: true
 
 overrides:
   # confidence threshold for prompt target intent matching
   prompt_target_intent_matching_threshold: 0.6
 
-llm_providers:
-
-  - name: open-ai-gpt-4
-    api_key: $OPEN_AI_API_KEY
-    model: gpt-4
-    default: true
+system_prompt: |
+  You are a helpful assistant.
 
 prompt_targets:
-  - type: function_resolver
-    name: weather_forecast
+  - name: weather_forecast
     description: This function provides realtime weather forecast information for a given city.
     parameters:
       - name: city
@@ -27,14 +37,30 @@ prompt_targets:
       - name: units
         description: The units in which the weather forecast is requested.
     endpoint:
-      cluster: api_server
+      name: api_server
       path: /weather
     system_prompt: |
       You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
       - Use farenheight for temperature
       - Use miles per hour for wind speed
-  - type: function_resolver
-    name: insurance_claim_details
+
+  - name: system_time
+    description: This function provides the current system time.
+    parameters:
+      - name: timezone
+        description: The timezone for which the current system time is requested.
+        default: US/Pacific
+    endpoint:
+      name: api_server
+      path: /current_time
+      method: Get
+    system_prompt: |
+      You are a helpful system time provider. Use system time data that is provided to you. Please follow these guidelines when responding to user queries:
+      - Use 12 hour time format
+      - Use AM/PM for time
+
+  - name: insurance_claim_details
+    type: function_resolver
     description: This function resolver provides insurance claim details for a given policy number.
     parameters:
       - name: policy_number
@@ -46,8 +72,16 @@ prompt_targets:
         type: bool
         required: true
     endpoint:
-      cluster: api_server
+      name: api_server
       path: /insurance_claim_details
     system_prompt: |
       You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please following following guidelines when responding to user queries:
       - Use policy number to retrieve insurance claim details
+ratelimits:
+  - provider: gpt-3.5-turbo
+    selector:
+      key: selector-key
+      value: selector-value
+    limit:
+      tokens: 1
+      unit: minute
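The new `endpoints` map replaces the old `clusters` section, and each entry's `endpoint` value packs a host and an optional port into one string, which the config generator splits apart for the Envoy template. A standalone sketch of that split (illustrative only, not code from this patch):

```rust
// Split an `endpoint` value like "api_server:80" into a host and a port,
// falling back to port 80 when none is given -- mirroring what
// config_generator.py does with endpoint.split(':').
fn split_endpoint(endpoint: &str) -> (String, u16) {
    match endpoint.split_once(':') {
        Some((host, port)) => (host.to_string(), port.parse().unwrap_or(80)),
        None => (endpoint.to_string(), 80),
    }
}

fn main() {
    assert_eq!(split_endpoint("api_server:80"), ("api_server".to_string(), 80));
    assert_eq!(split_endpoint("mistral_local"), ("mistral_local".to_string(), 80));
    println!("ok");
}
```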
diff --git a/docs/source/_config/prompt-config-full-reference.yml b/docs/source/_config/prompt-config-full-reference.yml
index 4a697b7a..58bfb1ed 100644
--- a/docs/source/_config/prompt-config-full-reference.yml
+++ b/docs/source/_config/prompt-config-full-reference.yml
@@ -1,78 +1,109 @@
 version: "0.1-beta"
 
 listener:
-  address: 0.0.0.0 # or 127.0.0.1
-  port_value: 8080
-  messages: "hugging-face-messages-json" # Defines how Arch should parse the content from application/json or text/pain Content-type in the http request
+  address: 0.0.0.0 # or 127.0.0.1
+  port: 10000
+  # Defines how Arch should parse the content from application/json or text/plain Content-type in the HTTP request
+  message_format: huggingface
   common_tls_context: # If you configure port 443, you'll need to update the listener with your TLS certificates
     tls_certificates:
     - certificate_chain:
-        filename: "/etc/arch/certs/cert.pem"
+        filename: "/etc/certs/cert.pem"
       private_key:
-        filename: "/etc/arch/certs/key.pem"
+        filename: "/etc/certs/key.pem"
 
-system_prompts:
-  - name: "network_assistant"
-    content: |
-      You are a network assistant that just offers facts; not advice on manufacturers or purchasing decisions.
+# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
+endpoints:
+  app_server:
+    # value could be ip address or a hostname with port
+    # this could also be a list of endpoints for load balancing
+    # for example endpoint: [ ip1:port, ip2:port ]
+    endpoint: "127.0.0.1:80"
+    # max time to wait for a connection to be established
+    connect_timeout: 500ms
+    # max time to wait for a response
+    timeout: 10000ms
 
-llm_providers: #Centralized way to manage LLMs, manage keys, retry logic, failover and limits in a central way
+  mistral_local:
+    endpoint: "127.0.0.1:8001"
+
+  error_target:
+    endpoint: "error_target_1"
+
+# Centralized way to manage LLMs, manage keys, retry logic, failover and limits in a central way
+llm_providers:
   - name: "OpenAI"
     access_key: $OPENAI_API_KEY
     model: gpt-4o
     default: true
     stream: true
-    rate_limit:
+    rate_limits:
       selector: #optional headers, to add rate limiting based on http headers like JWT tokens or API keys
-        http-header:
+        http_header:
           name: "Authorization"
           value: "" # Empty value means each separate value has a separate limit
       limit:
         tokens: 100000 # Tokens per unit
         unit: "minute"
-  - name: "Mistral"
-    access_key: $MISTRAL_API_KEY
-    model: "mistral-7B"
 
-prompt_endpoints: #Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
-  - "http://127.0.0.2" #assumes port 8000, unless port is specified with :5000
-  - "http://127.0.0.1:5000"
+  - name: "Mistral8x7b"
+    access_key: $MISTRAL_API_KEY
+    model: "mistral-8x7b"
+
+  - name: "MistralLocal7b"
+    model: "mistral-7b-instruct"
+    endpoint: "mistral_local"
+
+# provides a way to override default settings for the arch system
+overrides:
+  # By default Arch uses an NLI + embedding approach to match an incoming prompt to a prompt target.
+  # The intent matching threshold is kept at 0.80; you can override this behavior if you would like
+  prompt_target_intent_matching_threshold: 0.60
+
+# default system prompt used by all prompt targets
+system_prompt: |
+  You are a network assistant that just offers facts; not advice on manufacturers or purchasing decisions.
 
 prompt_guards:
-  input_guard:
-    - name: "jailbreak"
-      on_exception:
-        forward_to_error_target: true
-    - name: "toxicity"
+  input_guards:
+    jailbreak:
       on_exception:
         message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
 
 prompt_targets:
-  - name: "information_extraction"
-    type: "default"
-    description: "This prompt handles all scenarios that are question and answer in nature. Like summarization, information extraction, etc."
-    path: "/agent/summary"
-    auto-llm-dispatch-on-response: true #Arch uses the default LLM and treats the response from the endpoint as the prompt to send to the LLM
-
   - name: "reboot_network_device"
-    path: "/agent/action"
     description: "Helps network operators perform device operations like rebooting a device."
+    endpoint:
+      name: app_server
+      path: "/agent/action"
     parameters:
       - name: "device_id"
-        type: "string" # additional type options include: integer | float | list | dictionary | set
+        # additional type options include: int | float | bool | string | list | dict
+        type: "string"
         description: "Identifier of the network device to reboot."
-        default_value: ""
         required: true
      - name: "confirmation"
-        type: "integer" # additional type options include: integer | float | list | dictionary | set
+        type: "string"
         description: "Confirmation flag to proceed with reboot."
-        required: true
+        default: "no"
+        enum: ["yes", "no"]
+
+  - name: "information_extraction"
+    default: true
+    description: "This prompt handles all scenarios that are question and answer in nature. Like summarization, information extraction, etc."
+    endpoint:
+      name: app_server
+      path: "/agent/summary"
+      method: Post
+    # Arch uses the default LLM and treats the response from the endpoint as the prompt to send to the LLM
+    auto_llm_dispatch_on_response: true
+    # override system prompt for this prompt target
+    system_prompt: |
+      You are a helpful information extraction assistant. Use the information that is provided to you.
 
 error_target:
-  name: "error_handler"
-  path: "/errors"
+  endpoint:
+    name: error_target_1
+    path: /error
 
 tracing: 100 #sampling rate. Note by default Arch works on OpenTelemetry compatible tracing.
-
-intent-detection-threshold-override: 0.60 # By default Arch uses an NLI + embedding approach to match an incomming prompt to a prompt target.
-                                          # The intent matching threshold is kept at 0.80, you can overide this behavior if you would like
diff --git a/public_types/Cargo.lock b/public_types/Cargo.lock
index 5a176a11..b253445b 100644
--- a/public_types/Cargo.lock
+++ b/public_types/Cargo.lock
@@ -8,6 +8,15 @@ version = "0.1.13"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"
 
+[[package]]
+name = "duration-string"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6fcc1d9ae294a15ed05aeae8e11ee5f2b3fe971c077d45a42fb20825fba6ee13"
+dependencies = [
+ "serde",
+]
+
 [[package]]
 name = "equivalent"
 version = "1.0.1"
@@ -65,6 +74,7 @@ dependencies = [
 name = "public_types"
 version = "0.1.0"
 dependencies = [
+ "duration-string",
  "pretty_assertions",
  "serde",
  "serde_json",
diff --git a/public_types/Cargo.toml b/public_types/Cargo.toml
index 3c94335a..94a1725d 100644
--- a/public_types/Cargo.toml
+++ b/public_types/Cargo.toml
@@ -6,6 +6,7 @@ edition = "2021"
 [dependencies]
 serde = { version = "1.0", features = ["derive"] }
 serde_yaml = "0.9.34"
+duration-string = { version = "0.3.0", features = ["serde"] }
 
 [dev-dependencies]
 pretty_assertions = "1.4.1"
diff --git a/public_types/src/common_types.rs b/public_types/src/common_types.rs
index df89a66f..9b3e3968 100644
--- a/public_types/src/common_types.rs
+++ b/public_types/src/common_types.rs
@@ -151,11 +151,16 @@ pub mod open_ai {
         fn from(s: String) -> Self {
             match s.as_str() {
                 "int" => ParameterType::Int,
+                "integer" => ParameterType::Int,
                 "float" => ParameterType::Float,
                 "bool" => ParameterType::Bool,
+                "boolean" => ParameterType::Bool,
+                "str" => ParameterType::String,
                 "string" => ParameterType::String,
                 "list" => ParameterType::List,
+                "array" => ParameterType::List,
                 "dict" => ParameterType::Dict,
+                "dictionary" => ParameterType::Dict,
                 _ => ParameterType::String,
             }
         }
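For reference, the alias handling added to `ParameterType::from` can be condensed with or-patterns; a standalone sketch with a trimmed-down copy of the enum (the real type lives in `public_types::common_types::open_ai`):

```rust
// Trimmed, self-contained copy of the enum for illustration only.
#[derive(Debug, PartialEq)]
enum ParameterType { Int, Float, Bool, String, List, Dict }

impl From<String> for ParameterType {
    fn from(s: String) -> Self {
        match s.as_str() {
            "int" | "integer" => ParameterType::Int,
            "float" => ParameterType::Float,
            "bool" | "boolean" => ParameterType::Bool,
            "str" | "string" => ParameterType::String,
            "list" | "array" => ParameterType::List,
            "dict" | "dictionary" => ParameterType::Dict,
            _ => ParameterType::String, // unknown type names fall back to String
        }
    }
}

fn main() {
    assert_eq!(ParameterType::from("integer".to_string()), ParameterType::Int);
    assert_eq!(ParameterType::from("array".to_string()), ParameterType::List);
    assert_eq!(ParameterType::from("unknown".to_string()), ParameterType::String);
}
```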
= "huggingface")] + #[default] + Huggingface, +} + #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct PromptGuards { - pub input_guards: InputGuards, + pub input_guards: HashMap, } -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct InputGuards { - pub jailbreak: Option, - pub toxicity: Option, +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub enum GuardType { + #[serde(rename = "jailbreak")] + Jailbreak, } -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct GuardOptions { - pub on_exception_message: Option, + pub on_exception: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OnExceptionDetails { + pub forward_to_error_target: Option, + pub error_handler: Option, + pub message: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LlmRatelimit { + pub selector: LlmRatelimitSelector, + pub limit: Limit, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LlmRatelimitSelector { + pub http_header: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Header { + pub key: String, + pub value: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -58,19 +118,11 @@ pub enum TimeUnit { } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub struct Header { - pub key: String, +pub struct RatelimitHeader { + pub name: String, pub value: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum LoadBalancing { - #[serde(rename = "round_robin")] - RoundRobin, - #[serde(rename = "random")] - Random, -} - #[derive(Debug, Clone, Serialize, Deserialize)] //TODO: use enum for model, but if there is a new model, we need to update the code pub struct EmbeddingProviver { @@ -82,23 +134,19 @@ pub struct EmbeddingProviver { //TODO: use enum for model, but if there is a new model, we need to update the code pub struct LlmProvider { pub name: String, - pub api_key: Option, + //TODO: handle env var replacement + pub access_key: Option, pub model: String, pub default: Option, - pub endpoint: Option, -} -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] -pub enum EnpointType { - String(String), - Struct(Endpoint), + pub stream: Option, + pub rate_limits: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Endpoint { - pub cluster: String, - pub path: Option, - pub method: Option, + pub endpoint: Option, + // pub connect_timeout: Option, + // pub timeout: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -114,82 +162,144 @@ pub struct Parameter { } #[derive(Debug, Clone, Serialize, Deserialize)] -pub enum PromptType { - #[serde(rename = "function_resolver")] - FunctionResolver, +pub struct EndpointDetails { + pub name: String, + pub path: Option, + pub method: Option, +} + + +#[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash)] +#[serde(rename_all = "UPPERCASE")] +pub enum Method { + Get, + Post, + Put, + Delete, +} + +impl ToString for Method { + fn to_string(&self) -> String { + match self { + Method::Get => "GET".to_string(), + Method::Post => "POST".to_string(), + Method::Put => "PUT".to_string(), + Method::Delete => "DELETE".to_string(), + } + } +} + +impl<'de> Deserialize<'de> for Method { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + match s.to_uppercase().as_str() { + "GET" => Ok(Method::Get), + "POST" => 
Ok(Method::Post), + "PUT" => Ok(Method::Put), + "DELETE" => Ok(Method::Delete), + _ => Err(serde::de::Error::custom(format!("Invalid enum variant: {}", s))), + } + } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PromptTarget { - #[serde(rename = "type")] - pub prompt_type: PromptType, pub name: String, + pub default: Option, pub description: String, + pub endpoint: Option, pub parameters: Option>, - pub endpoint: Option, pub system_prompt: Option, + pub auto_llm_dispatch_on_response: Option, } #[cfg(test)] mod test { - pub const CONFIGURATION: &str = r#" -default_prompt_endpoint: "127.0.0.1" -load_balancing: "round_robin" -timeout_ms: 5000 + use std::fs; -llm_providers: - - name: "open-ai-gpt-4" - api_key: "$OPEN_AI_API_KEY" - model: gpt-4 - -system_prompt: | - You are a helpful weather forecaster. Please following following guidelines when responding to user queries: - - Use farenheight for temperature - - Use miles per hour for wind speed - -prompt_guards: - input_guards: - jailbreak: - on_exception_message: Looks like you are curious about my abilities… - toxicity: - on_exception_message: Looks like you are curious about my abilities… - -prompt_targets: - - - type: function_resolver - name: weather_forecast - description: Get the weather forecast for a location - endpoint: - cluster: weatherhost - path: /weather - parameters: - - name: location - required: true - description: "The location for which the weather is requested" - - - type: function_resolver - name: weather_forecast_2 - description: Get the weather forecast for a location - few_shot_examples: - - what is the weather in New York? - endpoint: - cluster: weatherhost - path: /weather - parameters: - - name: city - description: "The location for which the weather is requested" - -ratelimits: - - provider: open-ai-gpt-4 - selector: - key: x-katanemo-openai-limit-id - limit: - tokens: 100 - unit: minute - "#; + use crate::configuration::GuardType; #[test] fn test_deserialize_configuration() { - let _: super::Configuration = serde_yaml::from_str(CONFIGURATION).unwrap(); + let ref_config = + fs::read_to_string("../docs/source/_config/prompt-config-full-reference.yml") + .expect("reference config file not found"); + + let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap(); + assert_eq!(config.version, "0.1-beta"); + + let open_ai_provider = config + .llm_providers + .iter() + .find(|p| p.name.to_lowercase() == "openai") + .unwrap(); + assert_eq!(open_ai_provider.name.to_lowercase(), "openai"); + assert_eq!( + open_ai_provider.access_key, + Some("$OPENAI_API_KEY".to_string()) + ); + assert_eq!(open_ai_provider.model, "gpt-4o"); + assert_eq!(open_ai_provider.default, Some(true)); + assert_eq!(open_ai_provider.stream, Some(true)); + + let prompt_guards = config.prompt_guards.as_ref().unwrap(); + let input_guards = &prompt_guards.input_guards; + let jailbreak_guard = input_guards.get(&GuardType::Jailbreak).unwrap(); + assert_eq!( + jailbreak_guard + .on_exception + .as_ref() + .unwrap() + .forward_to_error_target, + None + ); + assert_eq!( + jailbreak_guard.on_exception.as_ref().unwrap().error_handler, + None + ); + + let prompt_targets = &config.prompt_targets; + assert_eq!(prompt_targets.len(), 2); + let prompt_target = prompt_targets + .iter() + .find(|p| p.name == "reboot_network_device") + .unwrap(); + assert_eq!(prompt_target.name, "reboot_network_device"); + assert_eq!(prompt_target.default, None); + + let prompt_target = prompt_targets + .iter() + .find(|p| p.name == "information_extraction") 
+ .unwrap(); + assert_eq!(prompt_target.name, "information_extraction"); + assert_eq!(prompt_target.default, Some(true)); + assert_eq!( + prompt_target.endpoint.as_ref().unwrap().name, + "app_server".to_string() + ); + assert_eq!( + prompt_target.endpoint.as_ref().unwrap().path, + Some("/agent/summary".to_string()) + ); + assert_eq!( + prompt_target.endpoint.as_ref().unwrap().method.as_ref().unwrap().to_string(), + "POST".to_string() + ); + + let error_target = config.error_target.as_ref().unwrap(); + assert_eq!( + error_target.endpoint.as_ref().unwrap().name, + "error_target_1".to_string() + ); + assert_eq!( + error_target.endpoint.as_ref().unwrap().path, + Some("/error".to_string()) + ); + + let tracing = config.tracing.as_ref().unwrap(); + assert_eq!(*tracing, 100); } }
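The custom `Deserialize` impl above is what lets configs write `method: Get`, `method: GET`, or `method: get` interchangeably while serialization stays uppercase. A self-contained sketch of that behavior (trimmed copy of the enum; assumes `serde` and `serde_yaml` as dependencies):

```rust
use serde::{Deserialize, Deserializer};

// Trimmed, illustration-only copy; the real type is public_types::configuration::Method.
#[derive(Debug, PartialEq)]
enum Method { Get, Post, Put, Delete }

impl<'de> Deserialize<'de> for Method {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Accept any casing by uppercasing the raw string before matching.
        let s = String::deserialize(deserializer)?;
        match s.to_uppercase().as_str() {
            "GET" => Ok(Method::Get),
            "POST" => Ok(Method::Post),
            "PUT" => Ok(Method::Put),
            "DELETE" => Ok(Method::Delete),
            _ => Err(serde::de::Error::custom(format!("Invalid enum variant: {}", s))),
        }
    }
}

fn main() {
    // "Get", "GET", and "get" all deserialize to Method::Get.
    for raw in ["Get", "GET", "get"] {
        let m: Method = serde_yaml::from_str(raw).unwrap();
        assert_eq!(m, Method::Get);
    }
    println!("ok");
}
```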