mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
update config (#93)
This commit is contained in:
parent
4182879717
commit
cc35eb0cd7
13 changed files with 575 additions and 329 deletions
10
arch/Cargo.lock
generated
10
arch/Cargo.lock
generated
|
|
@ -441,6 +441,15 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "duration-string"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6fcc1d9ae294a15ed05aeae8e11ee5f2b3fe971c077d45a42fb20825fba6ee13"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.13.0"
|
||||
|
|
@ -1075,6 +1084,7 @@ dependencies = [
|
|||
name = "public_types"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"duration-string",
|
||||
"serde",
|
||||
"serde_yaml",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -176,7 +176,11 @@ static_resources:
|
|||
hostname: "arch_fc"
|
||||
{% for _, cluster in arch_clusters.items() %}
|
||||
- name: {{ cluster.name }}
|
||||
{% if cluster.connect_timeout -%}
|
||||
connect_timeout: {{ cluster.connect_timeout }}
|
||||
{% else -%}
|
||||
connect_timeout: 5s
|
||||
{% endif -%}
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
|
|
@ -186,7 +190,7 @@ static_resources:
|
|||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: {{ cluster.address }}
|
||||
address: {{ cluster.endpoint }}
|
||||
port_value: {{ cluster.port }}
|
||||
hostname: {{ cluster.address }}
|
||||
hostname: {{ cluster.name }}
|
||||
{% endfor %}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ use public_types::common_types::{
|
|||
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
|
||||
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
|
||||
};
|
||||
use public_types::configuration::{Overrides, PromptGuards, PromptTarget, PromptType};
|
||||
use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
|
||||
use public_types::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
|
|
@ -358,103 +358,97 @@ impl StreamContext {
|
|||
|
||||
info!("prompt_target name: {:?}", prompt_target_name);
|
||||
|
||||
match prompt_target.prompt_type {
|
||||
PromptType::FunctionResolver => {
|
||||
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
|
||||
for pt in self.prompt_targets.read().unwrap().values() {
|
||||
// only extract entity names
|
||||
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
|
||||
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
|
||||
Some(ref entities) => {
|
||||
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
|
||||
for entity in entities.iter() {
|
||||
let param = FunctionParameter {
|
||||
parameter_type: ParameterType::from(
|
||||
entity.parameter_type.clone().unwrap_or("str".to_string()),
|
||||
),
|
||||
description: entity.description.clone(),
|
||||
required: entity.required,
|
||||
enum_values: entity.enum_values.clone(),
|
||||
default: entity.default.clone(),
|
||||
};
|
||||
properties.insert(entity.name.clone(), param);
|
||||
}
|
||||
properties
|
||||
}
|
||||
None => HashMap::new(),
|
||||
};
|
||||
let tools_parameters = FunctionParameters { properties };
|
||||
|
||||
chat_completion_tools.push({
|
||||
ChatCompletionTool {
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionDefinition {
|
||||
name: pt.name.clone(),
|
||||
description: pt.description.clone(),
|
||||
parameters: tools_parameters,
|
||||
},
|
||||
}
|
||||
});
|
||||
//TODO: handle default function resolver type
|
||||
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
|
||||
for pt in self.prompt_targets.read().unwrap().values() {
|
||||
// only extract entity names
|
||||
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
|
||||
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
|
||||
Some(ref entities) => {
|
||||
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
|
||||
for entity in entities.iter() {
|
||||
let param = FunctionParameter {
|
||||
parameter_type: ParameterType::from(
|
||||
entity.parameter_type.clone().unwrap_or("str".to_string()),
|
||||
),
|
||||
description: entity.description.clone(),
|
||||
required: entity.required,
|
||||
enum_values: entity.enum_values.clone(),
|
||||
default: entity.default.clone(),
|
||||
};
|
||||
properties.insert(entity.name.clone(), param);
|
||||
}
|
||||
properties
|
||||
}
|
||||
None => HashMap::new(),
|
||||
};
|
||||
let tools_parameters = FunctionParameters { properties };
|
||||
|
||||
let chat_completions = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages: callout_context.request_body.messages.clone(),
|
||||
tools: Some(chat_completion_tools),
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
};
|
||||
|
||||
let msg_body = match serde_json::to_string(&chat_completions) {
|
||||
Ok(msg_body) => {
|
||||
debug!("arch_fc request body content: {}", msg_body);
|
||||
msg_body
|
||||
}
|
||||
Err(e) => {
|
||||
return self.send_server_error(
|
||||
format!("Error serializing request_params: {:?}", e),
|
||||
None,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
ARC_FC_CLUSTER,
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/v1/chat/completions"),
|
||||
(":authority", ARC_FC_CLUSTER),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
(
|
||||
"x-envoy-upstream-rq-timeout-ms",
|
||||
ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
|
||||
),
|
||||
],
|
||||
Some(msg_body.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
let error_msg =
|
||||
format!("Error dispatching HTTP call for function-call: {:?}", e);
|
||||
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"dispatched call to function {} token_id={}",
|
||||
ARC_FC_CLUSTER, token_id
|
||||
);
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
|
||||
callout_context.prompt_target_name = Some(prompt_target.name);
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
chat_completion_tools.push({
|
||||
ChatCompletionTool {
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionDefinition {
|
||||
name: pt.name.clone(),
|
||||
description: pt.description.clone(),
|
||||
parameters: tools_parameters,
|
||||
},
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let chat_completions = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages: callout_context.request_body.messages.clone(),
|
||||
tools: Some(chat_completion_tools),
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
};
|
||||
|
||||
let msg_body = match serde_json::to_string(&chat_completions) {
|
||||
Ok(msg_body) => {
|
||||
debug!("arch_fc request body content: {}", msg_body);
|
||||
msg_body
|
||||
}
|
||||
Err(e) => {
|
||||
return self
|
||||
.send_server_error(format!("Error serializing request_params: {:?}", e), None);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
ARC_FC_CLUSTER,
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/v1/chat/completions"),
|
||||
(":authority", ARC_FC_CLUSTER),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
(
|
||||
"x-envoy-upstream-rq-timeout-ms",
|
||||
ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
|
||||
),
|
||||
],
|
||||
Some(msg_body.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
let error_msg = format!("Error dispatching HTTP call for function-call: {:?}", e);
|
||||
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"dispatched call to function {} token_id={}",
|
||||
ARC_FC_CLUSTER, token_id
|
||||
);
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
|
||||
callout_context.prompt_target_name = Some(prompt_target.name);
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -530,17 +524,32 @@ impl StreamContext {
|
|||
debug!("tool_params: {}", tool_params_json_str);
|
||||
|
||||
let endpoint = prompt_target.endpoint.unwrap();
|
||||
let path = endpoint.path.unwrap_or(String::from("/"));
|
||||
let mut path = endpoint.path.unwrap_or(String::from("/"));
|
||||
let method = endpoint
|
||||
.method
|
||||
.unwrap_or(public_types::configuration::Method::Post);
|
||||
let mut body = Some(tool_params_json_str.as_bytes());
|
||||
if method == public_types::configuration::Method::Post {
|
||||
let mut query_params = vec![];
|
||||
for (key, value) in tool_params {
|
||||
query_params.push(format!("{}={}", key, format!("{:?}", value)));
|
||||
}
|
||||
let path_args = &query_params.join("&");
|
||||
path.push_str("?");
|
||||
path.push_str(path_args);
|
||||
} else {
|
||||
body = None;
|
||||
}
|
||||
let token_id = match self.dispatch_http_call(
|
||||
&endpoint.cluster,
|
||||
&endpoint.name,
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":method", method.to_string().as_str()),
|
||||
(":path", path.as_ref()),
|
||||
(":authority", endpoint.cluster.as_str()),
|
||||
(":authority", endpoint.name.as_str()),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(tool_params_json_str.as_bytes()),
|
||||
body,
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
|
|
@ -548,14 +557,14 @@ impl StreamContext {
|
|||
Err(e) => {
|
||||
let error_msg = format!(
|
||||
"Error dispatching call to cluster: {}, path: {}, err: {:?}",
|
||||
&endpoint.cluster, path, e
|
||||
&endpoint.name, path, e
|
||||
);
|
||||
debug!("{}", error_msg);
|
||||
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
callout_context.up_stream_cluster = Some(endpoint.cluster);
|
||||
callout_context.up_stream_cluster = Some(endpoint.name);
|
||||
callout_context.up_stream_cluster_path = Some(path);
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
|
|
@ -682,27 +691,18 @@ impl StreamContext {
|
|||
if prompt_guard_resp.jailbreak_verdict.is_some()
|
||||
&& prompt_guard_resp.jailbreak_verdict.unwrap()
|
||||
{
|
||||
//TODO: handle other scenarios like forward to error target
|
||||
let default_err = "Jailbreak detected. Please refrain from discussing jailbreaking.";
|
||||
let error_msg = match self.prompt_guards.as_ref() {
|
||||
Some(prompt_guards) => match prompt_guards.input_guards.jailbreak.as_ref() {
|
||||
Some(jailbreak) => match jailbreak.on_exception_message.as_ref() {
|
||||
Some(error_msg) => error_msg,
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
};
|
||||
|
||||
return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
|
||||
if prompt_guard_resp.toxic_verdict.is_some() && prompt_guard_resp.toxic_verdict.unwrap() {
|
||||
let default_err = "Toxicity detected. Please refrain from using toxic language.";
|
||||
let error_msg = match self.prompt_guards.as_ref() {
|
||||
Some(prompt_guards) => match prompt_guards.input_guards.toxicity.as_ref() {
|
||||
Some(toxicity) => match toxicity.on_exception_message.as_ref() {
|
||||
Some(error_msg) => error_msg,
|
||||
Some(prompt_guards) => match prompt_guards
|
||||
.input_guards
|
||||
.get(&public_types::configuration::GuardType::Jailbreak)
|
||||
{
|
||||
Some(jailbreak) => match jailbreak.on_exception.as_ref() {
|
||||
Some(on_exception_details) => match on_exception_details.message.as_ref() {
|
||||
Some(error_msg) => error_msg,
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
|
|
@ -883,32 +883,27 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let prompt_guard_task = match (
|
||||
prompt_guards.input_guards.toxicity.is_some(),
|
||||
prompt_guards.input_guards.jailbreak.is_some(),
|
||||
) {
|
||||
(true, true) => PromptGuardTask::Both,
|
||||
(true, false) => PromptGuardTask::Toxicity,
|
||||
(false, true) => PromptGuardTask::Jailbreak,
|
||||
(false, false) => {
|
||||
info!("Input guards set but no prompt guards were found");
|
||||
let callout_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
up_stream_cluster: None,
|
||||
up_stream_cluster_path: None,
|
||||
};
|
||||
self.get_embeddings(callout_context);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
let prompt_guard_jailbreak_task = prompt_guards
|
||||
.input_guards
|
||||
.contains_key(&public_types::configuration::GuardType::Jailbreak);
|
||||
if !prompt_guard_jailbreak_task {
|
||||
info!("Input guards set but no prompt guards were found");
|
||||
let callout_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
up_stream_cluster: None,
|
||||
up_stream_cluster_path: None,
|
||||
};
|
||||
self.get_embeddings(callout_context);
|
||||
return Action::Pause;
|
||||
}
|
||||
|
||||
let get_prompt_guards_request = PromptGuardRequest {
|
||||
input: user_message.clone(),
|
||||
task: prompt_guard_task,
|
||||
task: PromptGuardTask::Jailbreak,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
|
||||
|
|
|
|||
|
|
@ -175,27 +175,36 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
|
||||
fn default_config() -> Configuration {
|
||||
let config: &str = r#"
|
||||
default_prompt_endpoint: "127.0.0.1"
|
||||
load_balancing: "round_robin"
|
||||
timeout_ms: 5000
|
||||
version: "0.1-beta"
|
||||
|
||||
listener:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: huggingface
|
||||
connect_timeout: 0.005s
|
||||
|
||||
endpoints:
|
||||
api_server:
|
||||
endpoint: api_server:80
|
||||
connect_timeout: 0.005s
|
||||
|
||||
llm_providers:
|
||||
- name: "open-ai-gpt-4"
|
||||
api_key: "$OPEN_AI_API_KEY"
|
||||
- name: open-ai-gpt-4
|
||||
access_key: $OPEN_AI_API_KEY
|
||||
model: gpt-4
|
||||
default: true
|
||||
|
||||
overrides:
|
||||
# confidence threshold for prompt target intent matching
|
||||
prompt_target_intent_matching_threshold: 0.6
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
- Use miles per hour for wind speed
|
||||
You are a helpful assistant.
|
||||
|
||||
prompt_targets:
|
||||
- type: function_resolver
|
||||
name: weather_forecast
|
||||
description: This resolver provides weather forecast information.
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
|
||||
- name: weather_forecast
|
||||
description: This function provides realtime weather forecast information for a given city.
|
||||
parameters:
|
||||
- name: city
|
||||
required: true
|
||||
|
|
@ -204,16 +213,32 @@ prompt_targets:
|
|||
description: The number of days for which the weather forecast is requested.
|
||||
- name: units
|
||||
description: The units in which the weather forecast is requested.
|
||||
|
||||
- type: function_resolver
|
||||
name: weather_forecast_2
|
||||
description: This resolver provides weather forecast information.
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
name: api_server
|
||||
path: /weather
|
||||
entities:
|
||||
- name: city
|
||||
system_prompt: |
|
||||
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
- Use miles per hour for wind speed
|
||||
|
||||
- name: insurance_claim_details
|
||||
type: function_resolver
|
||||
description: This function resolver provides insurance claim details for a given policy number.
|
||||
parameters:
|
||||
- name: policy_number
|
||||
required: true
|
||||
description: The policy number for which the insurance claim details are requested.
|
||||
type: string
|
||||
- name: include_expired
|
||||
description: whether to include expired insurance claims in the response.
|
||||
type: bool
|
||||
required: true
|
||||
endpoint:
|
||||
name: api_server
|
||||
path: /insurance_claim_details
|
||||
system_prompt: |
|
||||
You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please following following guidelines when responding to user queries:
|
||||
- Use policy number to retrieve insurance claim details
|
||||
ratelimits:
|
||||
- provider: gpt-3.5-turbo
|
||||
selector:
|
||||
|
|
@ -222,7 +247,7 @@ ratelimits:
|
|||
limit:
|
||||
tokens: 1
|
||||
unit: minute
|
||||
"#;
|
||||
"#;
|
||||
serde_yaml::from_str(config).unwrap()
|
||||
}
|
||||
|
||||
|
|
@ -442,7 +467,7 @@ fn request_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("weatherhost"), None, None, None, None)
|
||||
.expect_http_call(Some("api_server"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
|
|
@ -557,7 +582,7 @@ fn request_not_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("weatherhost"), None, None, None, None)
|
||||
.expect_http_call(Some("api_server"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
|
|
|
|||
|
|
@ -17,25 +17,28 @@ config_yaml = yaml.safe_load(katanemo_config)
|
|||
inferred_clusters = {}
|
||||
|
||||
for prompt_target in config_yaml["prompt_targets"]:
|
||||
cluster = prompt_target.get("endpoint", {}).get("cluster", "")
|
||||
if cluster not in inferred_clusters:
|
||||
inferred_clusters[cluster] = {
|
||||
"name": cluster,
|
||||
"address": cluster,
|
||||
name = prompt_target.get("endpoint", {}).get("name", "")
|
||||
if name not in inferred_clusters:
|
||||
inferred_clusters[name] = {
|
||||
"name": name,
|
||||
"port": 80, # default port
|
||||
}
|
||||
|
||||
print(inferred_clusters)
|
||||
|
||||
clusters = config_yaml.get("clusters", {})
|
||||
endpoints = config_yaml.get("endpoints", {})
|
||||
|
||||
# override the inferred clusters with the ones defined in the config
|
||||
for name, cluster in clusters.items():
|
||||
for name, endpoint_details in endpoints.items():
|
||||
if name in inferred_clusters:
|
||||
print("updating cluster", cluster)
|
||||
inferred_clusters[name].update(cluster)
|
||||
print("updating cluster", endpoint_details)
|
||||
inferred_clusters[name].update(endpoint_details)
|
||||
endpoint = inferred_clusters[name]['endpoint']
|
||||
if len(endpoint.split(':')) > 1:
|
||||
inferred_clusters[name]['endpoint'] = endpoint.split(':')[0]
|
||||
inferred_clusters[name]['port'] = int(endpoint.split(':')[1])
|
||||
else:
|
||||
inferred_clusters[name] = cluster
|
||||
inferred_clusters[name] = endpoint_details
|
||||
|
||||
print("updated clusters", inferred_clusters)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from fastapi import FastAPI, Response
|
|||
from datetime import datetime, date, timedelta, timezone
|
||||
import logging
|
||||
from pydantic import BaseModel
|
||||
import pytz
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -56,3 +57,19 @@ async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Respon
|
|||
}
|
||||
|
||||
return claim_details
|
||||
|
||||
@app.get("/current_time")
|
||||
async def current_time(timezone: str):
|
||||
tz = None
|
||||
try:
|
||||
timezone.strip('"')
|
||||
tz = pytz.timezone(timezone)
|
||||
except pytz.exceptions.UnknownTimeZoneError:
|
||||
return {
|
||||
"error": "Invalid timezone: {}".format(timezone)
|
||||
}
|
||||
current_time = datetime.now(tz)
|
||||
return {
|
||||
"timezone": timezone,
|
||||
"current_time": current_time.strftime("%Y-%m-%d %H:%M:%S %Z")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
fastapi
|
||||
uvicorn
|
||||
pytz
|
||||
|
|
|
|||
|
|
@ -1,22 +1,32 @@
|
|||
default_prompt_endpoint: "127.0.0.1"
|
||||
load_balancing: "round_robin"
|
||||
timeout_ms: 5000
|
||||
version: "0.1-beta"
|
||||
|
||||
listener:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: huggingface
|
||||
connect_timeout: 0.005s
|
||||
|
||||
endpoints:
|
||||
api_server:
|
||||
endpoint: api_server:80
|
||||
connect_timeout: 0.005s
|
||||
|
||||
llm_providers:
|
||||
- name: open-ai-gpt-4
|
||||
access_key: $OPEN_AI_API_KEY
|
||||
model: gpt-4
|
||||
default: true
|
||||
|
||||
overrides:
|
||||
# confidence threshold for prompt target intent matching
|
||||
prompt_target_intent_matching_threshold: 0.6
|
||||
|
||||
llm_providers:
|
||||
|
||||
- name: open-ai-gpt-4
|
||||
api_key: $OPEN_AI_API_KEY
|
||||
model: gpt-4
|
||||
default: true
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
|
||||
prompt_targets:
|
||||
|
||||
- type: function_resolver
|
||||
name: weather_forecast
|
||||
- name: weather_forecast
|
||||
description: This function provides realtime weather forecast information for a given city.
|
||||
parameters:
|
||||
- name: city
|
||||
|
|
@ -27,14 +37,30 @@ prompt_targets:
|
|||
- name: units
|
||||
description: The units in which the weather forecast is requested.
|
||||
endpoint:
|
||||
cluster: api_server
|
||||
name: api_server
|
||||
path: /weather
|
||||
system_prompt: |
|
||||
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
- Use miles per hour for wind speed
|
||||
- type: function_resolver
|
||||
name: insurance_claim_details
|
||||
|
||||
- name: system_time
|
||||
description: This function provides the current system time.
|
||||
parameters:
|
||||
- name: timezone
|
||||
description: The city for which the weather forecast is requested.
|
||||
default: US/Pacific
|
||||
endpoint:
|
||||
name: api_server
|
||||
path: /current_time
|
||||
method: Get
|
||||
system_prompt: |
|
||||
You are a helpful system time provider. Use system time data that is provided to you. Please following following guidelines when responding to user queries:
|
||||
- Use 12 hour time format
|
||||
- Use AM/PM for time
|
||||
|
||||
- name: insurance_claim_details
|
||||
type: function_resolver
|
||||
description: This function resolver provides insurance claim details for a given policy number.
|
||||
parameters:
|
||||
- name: policy_number
|
||||
|
|
@ -46,8 +72,16 @@ prompt_targets:
|
|||
type: bool
|
||||
required: true
|
||||
endpoint:
|
||||
cluster: api_server
|
||||
name: api_server
|
||||
path: /insurance_claim_details
|
||||
system_prompt: |
|
||||
You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please following following guidelines when responding to user queries:
|
||||
- Use policy number to retrieve insurance claim details
|
||||
ratelimits:
|
||||
- provider: gpt-3.5-turbo
|
||||
selector:
|
||||
key: selector-key
|
||||
value: selector-value
|
||||
limit:
|
||||
tokens: 1
|
||||
unit: minute
|
||||
|
|
|
|||
|
|
@ -1,78 +1,109 @@
|
|||
version: "0.1-beta"
|
||||
|
||||
listener:
|
||||
address: 0.0.0.0 # or 127.0.0.1
|
||||
port_value: 8080
|
||||
messages: "hugging-face-messages-json" # Defines how Arch should parse the content from application/json or text/pain Content-type in the http request
|
||||
address: 0.0.0.0 # or 127.0.0.1
|
||||
port: 10000
|
||||
# Defines how Arch should parse the content from application/json or text/pain Content-type in the http request
|
||||
message_format: huggingface
|
||||
common_tls_context: # If you configure port 443, you'll need to update the listener with your TLS certificates
|
||||
tls_certificates:
|
||||
- certificate_chain:
|
||||
filename: "/etc/arch/certs/cert.pem"
|
||||
filename: "/etc/certs/cert.pem"
|
||||
private_key:
|
||||
filename: "/etc/arch/certs/key.pem"
|
||||
filename: "/etc/certs/key.pem"
|
||||
|
||||
system_prompts:
|
||||
- name: "network_assistant"
|
||||
content: |
|
||||
You are a network assistant that just offers facts; not advice on manufacturers or purchasing decisions.
|
||||
# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
|
||||
endpoints:
|
||||
app_server:
|
||||
# value could be ip address or a hostname with port
|
||||
# this could also be a list of endpoints for load balancing
|
||||
# for example endpoint: [ ip1:port, ip2:port ]
|
||||
endpoint: "127.0.0.1:80"
|
||||
# max time to wait for a connection to be established
|
||||
connect_timeout: 500ms
|
||||
# max time to wait for a response
|
||||
timeout: 10000ms
|
||||
|
||||
llm_providers: #Centralized way to manage LLMs, manage keys, retry logic, failover and limits in a central way
|
||||
mistral_local:
|
||||
endpoint: "127.0.0.1:8001"
|
||||
|
||||
error_target:
|
||||
endpoint: "error_target_1"
|
||||
|
||||
# Centralized way to manage LLMs, manage keys, retry logic, failover and limits in a central way
|
||||
llm_providers:
|
||||
- name: "OpenAI"
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
default: true
|
||||
stream: true
|
||||
rate_limit:
|
||||
rate_limits:
|
||||
selector: #optional headers, to add rate limiting based on http headers like JWT tokens or API keys
|
||||
http-header:
|
||||
http_header:
|
||||
name: "Authorization"
|
||||
value: "" # Empty value means each separate value has a separate limit
|
||||
limit:
|
||||
tokens: 100000 # Tokens per unit
|
||||
tokens: 100000 # Tokens per unit
|
||||
unit: "minute"
|
||||
- name: "Mistral"
|
||||
access_key: $MISTRAL_API_KEY
|
||||
model: "mistral-7B"
|
||||
|
||||
prompt_endpoints: #Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
|
||||
- "http://127.0.0.2" #assumes port 8000, unless port is specified with :5000
|
||||
- "http://127.0.0.1:5000"
|
||||
- name: "Mistral8x7b"
|
||||
access_key: $MISTRAL_API_KEY
|
||||
model: "mistral-8x7b"
|
||||
|
||||
- name: "MistralLocal7b"
|
||||
model: "mistral-7b-instruct"
|
||||
endpoint: "mistral_local"
|
||||
|
||||
# provides a way to override default settings for the arch system
|
||||
overrides:
|
||||
# By default Arch uses an NLI + embedding approach to match an incomming prompt to a prompt target.
|
||||
# The intent matching threshold is kept at 0.80, you can overide this behavior if you would like
|
||||
prompt_target_intent_matching_threshold: 0.60
|
||||
|
||||
# default system prompt used by all prompt targets
|
||||
system_prompt: |
|
||||
You are a network assistant that just offers facts; not advice on manufacturers or purchasing decisions.
|
||||
|
||||
prompt_guards:
|
||||
input_guard:
|
||||
- name: "jailbreak"
|
||||
on_exception:
|
||||
forward_to_error_target: true
|
||||
- name: "toxicity"
|
||||
input_guards:
|
||||
jailbreak:
|
||||
on_exception:
|
||||
message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
|
||||
|
||||
prompt_targets:
|
||||
- name: "information_extraction"
|
||||
type: "default"
|
||||
description: "This prompt handles all scenarios that are question and answer in nature. Like summarization, information extraction, etc."
|
||||
path: "/agent/summary"
|
||||
auto-llm-dispatch-on-response: true #Arch uses the default LLM and treats the response from the endpoint as the prompt to send to the LLM
|
||||
|
||||
- name: "reboot_network_device"
|
||||
path: "/agent/action"
|
||||
description: "Helps network operators perform device operations like rebooting a device."
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: "/agent/action"
|
||||
parameters:
|
||||
- name: "device_id"
|
||||
type: "string" # additional type options include: integer | float | list | dictionary | set
|
||||
# additional type options include: int | float | bool | string | list | dict
|
||||
type: "string"
|
||||
description: "Identifier of the network device to reboot."
|
||||
default_value: ""
|
||||
required: true
|
||||
- name: "confirmation"
|
||||
type: "integer" # additional type options include: integer | float | list | dictionary | set
|
||||
type: "string"
|
||||
description: "Confirmation flag to proceed with reboot."
|
||||
required: true
|
||||
default: "no"
|
||||
enum: [yes, no]
|
||||
|
||||
- name: "information_extraction"
|
||||
default: true
|
||||
description: "This prompt handles all scenarios that are question and answer in nature. Like summarization, information extraction, etc."
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: "/agent/summary"
|
||||
method: Post
|
||||
# Arch uses the default LLM and treats the response from the endpoint as the prompt to send to the LLM
|
||||
auto_llm_dispatch_on_response: true
|
||||
# override system prompt for this prompt target
|
||||
system_prompt: |
|
||||
You are a helpful information extraction assistant. Use the information that is provided to you.
|
||||
|
||||
error_target:
|
||||
name: "error_handler"
|
||||
path: "/errors"
|
||||
endpoint:
|
||||
name: error_target_1
|
||||
path: /error
|
||||
|
||||
tracing: 100 #sampling rate. Note by default Arch works on OpenTelemetry compatible tracing.
|
||||
|
||||
intent-detection-threshold-override: 0.60 # By default Arch uses an NLI + embedding approach to match an incomming prompt to a prompt target.
|
||||
# The intent matching threshold is kept at 0.80, you can overide this behavior if you would like
|
||||
|
|
|
|||
10
public_types/Cargo.lock
generated
10
public_types/Cargo.lock
generated
|
|
@ -8,6 +8,15 @@ version = "0.1.13"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"
|
||||
|
||||
[[package]]
|
||||
name = "duration-string"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6fcc1d9ae294a15ed05aeae8e11ee5f2b3fe971c077d45a42fb20825fba6ee13"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.1"
|
||||
|
|
@ -65,6 +74,7 @@ dependencies = [
|
|||
name = "public_types"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"duration-string",
|
||||
"pretty_assertions",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ edition = "2021"
|
|||
[dependencies]
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_yaml = "0.9.34"
|
||||
duration-string = { version = "0.3.0", features = ["serde"] }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = "1.4.1"
|
||||
|
|
|
|||
|
|
@ -151,11 +151,16 @@ pub mod open_ai {
|
|||
fn from(s: String) -> Self {
|
||||
match s.as_str() {
|
||||
"int" => ParameterType::Int,
|
||||
"integer" => ParameterType::Int,
|
||||
"float" => ParameterType::Float,
|
||||
"bool" => ParameterType::Bool,
|
||||
"boolean" => ParameterType::Bool,
|
||||
"str" => ParameterType::String,
|
||||
"string" => ParameterType::String,
|
||||
"list" => ParameterType::List,
|
||||
"array" => ParameterType::List,
|
||||
"dict" => ParameterType::Dict,
|
||||
"dictionary" => ParameterType::Dict,
|
||||
_ => ParameterType::String,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::HashMap, time::Duration};
|
||||
|
||||
use duration_string::DurationString;
|
||||
use serde::{Deserialize, Serialize, Deserializer};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct Overrides {
|
||||
|
|
@ -7,31 +10,88 @@ pub struct Overrides {
|
|||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub default_prompt_endpoint: String,
|
||||
pub load_balancing: LoadBalancing,
|
||||
pub timeout_ms: u64,
|
||||
pub overrides: Option<Overrides>,
|
||||
pub version: String,
|
||||
pub listener: Listener,
|
||||
pub endpoints: HashMap<String, Endpoint>,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub prompt_guards: Option<PromptGuards>,
|
||||
pub overrides: Option<Overrides>,
|
||||
pub system_prompt: Option<String>,
|
||||
pub prompt_guards: Option<PromptGuards>,
|
||||
pub prompt_targets: Vec<PromptTarget>,
|
||||
pub error_target: Option<ErrorTargetDetail>,
|
||||
pub tracing: Option<i16>,
|
||||
pub ratelimits: Option<Vec<Ratelimit>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ErrorTargetDetail {
|
||||
pub endpoint: Option<EndpointDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Listener {
|
||||
pub address: String,
|
||||
pub port: u16,
|
||||
pub message_format: MessageFormat,
|
||||
// pub connect_timeout: Option<DurationString>,
|
||||
}
|
||||
|
||||
impl Default for Listener {
|
||||
fn default() -> Self {
|
||||
Listener {
|
||||
address: "".to_string(),
|
||||
port: 0,
|
||||
message_format: MessageFormat::default(),
|
||||
// connect_timeout: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub enum MessageFormat {
|
||||
#[serde(rename = "huggingface")]
|
||||
#[default]
|
||||
Huggingface,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct PromptGuards {
|
||||
pub input_guards: InputGuards,
|
||||
pub input_guards: HashMap<GuardType, GuardOptions>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct InputGuards {
|
||||
pub jailbreak: Option<GuardOptions>,
|
||||
pub toxicity: Option<GuardOptions>,
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub enum GuardType {
|
||||
#[serde(rename = "jailbreak")]
|
||||
Jailbreak,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GuardOptions {
|
||||
pub on_exception_message: Option<String>,
|
||||
pub on_exception: Option<OnExceptionDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OnExceptionDetails {
|
||||
pub forward_to_error_target: Option<bool>,
|
||||
pub error_handler: Option<String>,
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmRatelimit {
|
||||
pub selector: LlmRatelimitSelector,
|
||||
pub limit: Limit,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmRatelimitSelector {
|
||||
pub http_header: Option<RatelimitHeader>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub struct Header {
|
||||
pub key: String,
|
||||
pub value: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -58,19 +118,11 @@ pub enum TimeUnit {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub struct Header {
|
||||
pub key: String,
|
||||
pub struct RatelimitHeader {
|
||||
pub name: String,
|
||||
pub value: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum LoadBalancing {
|
||||
#[serde(rename = "round_robin")]
|
||||
RoundRobin,
|
||||
#[serde(rename = "random")]
|
||||
Random,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
//TODO: use enum for model, but if there is a new model, we need to update the code
|
||||
pub struct EmbeddingProviver {
|
||||
|
|
@ -82,23 +134,19 @@ pub struct EmbeddingProviver {
|
|||
//TODO: use enum for model, but if there is a new model, we need to update the code
|
||||
pub struct LlmProvider {
|
||||
pub name: String,
|
||||
pub api_key: Option<String>,
|
||||
//TODO: handle env var replacement
|
||||
pub access_key: Option<String>,
|
||||
pub model: String,
|
||||
pub default: Option<bool>,
|
||||
pub endpoint: Option<EnpointType>,
|
||||
}
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum EnpointType {
|
||||
String(String),
|
||||
Struct(Endpoint),
|
||||
pub stream: Option<bool>,
|
||||
pub rate_limits: Option<LlmRatelimit>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Endpoint {
|
||||
pub cluster: String,
|
||||
pub path: Option<String>,
|
||||
pub method: Option<String>,
|
||||
pub endpoint: Option<String>,
|
||||
// pub connect_timeout: Option<DurationString>,
|
||||
// pub timeout: Option<DurationString>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -114,82 +162,144 @@ pub struct Parameter {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum PromptType {
|
||||
#[serde(rename = "function_resolver")]
|
||||
FunctionResolver,
|
||||
pub struct EndpointDetails {
|
||||
pub name: String,
|
||||
pub path: Option<String>,
|
||||
pub method: Option<Method>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
pub enum Method {
|
||||
Get,
|
||||
Post,
|
||||
Put,
|
||||
Delete,
|
||||
}
|
||||
|
||||
impl ToString for Method {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
Method::Get => "GET".to_string(),
|
||||
Method::Post => "POST".to_string(),
|
||||
Method::Put => "PUT".to_string(),
|
||||
Method::Delete => "DELETE".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Method {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
match s.to_uppercase().as_str() {
|
||||
"GET" => Ok(Method::Get),
|
||||
"POST" => Ok(Method::Post),
|
||||
"PUT" => Ok(Method::Put),
|
||||
"DELETE" => Ok(Method::Delete),
|
||||
_ => Err(serde::de::Error::custom(format!("Invalid enum variant: {}", s))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptTarget {
|
||||
#[serde(rename = "type")]
|
||||
pub prompt_type: PromptType,
|
||||
pub name: String,
|
||||
pub default: Option<bool>,
|
||||
pub description: String,
|
||||
pub endpoint: Option<EndpointDetails>,
|
||||
pub parameters: Option<Vec<Parameter>>,
|
||||
pub endpoint: Option<Endpoint>,
|
||||
pub system_prompt: Option<String>,
|
||||
pub auto_llm_dispatch_on_response: Option<bool>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
pub const CONFIGURATION: &str = r#"
|
||||
default_prompt_endpoint: "127.0.0.1"
|
||||
load_balancing: "round_robin"
|
||||
timeout_ms: 5000
|
||||
use std::fs;
|
||||
|
||||
llm_providers:
|
||||
- name: "open-ai-gpt-4"
|
||||
api_key: "$OPEN_AI_API_KEY"
|
||||
model: gpt-4
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
- Use miles per hour for wind speed
|
||||
|
||||
prompt_guards:
|
||||
input_guards:
|
||||
jailbreak:
|
||||
on_exception_message: Looks like you are curious about my abilities…
|
||||
toxicity:
|
||||
on_exception_message: Looks like you are curious about my abilities…
|
||||
|
||||
prompt_targets:
|
||||
|
||||
- type: function_resolver
|
||||
name: weather_forecast
|
||||
description: Get the weather forecast for a location
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
parameters:
|
||||
- name: location
|
||||
required: true
|
||||
description: "The location for which the weather is requested"
|
||||
|
||||
- type: function_resolver
|
||||
name: weather_forecast_2
|
||||
description: Get the weather forecast for a location
|
||||
few_shot_examples:
|
||||
- what is the weather in New York?
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
parameters:
|
||||
- name: city
|
||||
description: "The location for which the weather is requested"
|
||||
|
||||
ratelimits:
|
||||
- provider: open-ai-gpt-4
|
||||
selector:
|
||||
key: x-katanemo-openai-limit-id
|
||||
limit:
|
||||
tokens: 100
|
||||
unit: minute
|
||||
"#;
|
||||
use crate::configuration::GuardType;
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_configuration() {
|
||||
let _: super::Configuration = serde_yaml::from_str(CONFIGURATION).unwrap();
|
||||
let ref_config =
|
||||
fs::read_to_string("../docs/source/_config/prompt-config-full-reference.yml")
|
||||
.expect("reference config file not found");
|
||||
|
||||
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
|
||||
assert_eq!(config.version, "0.1-beta");
|
||||
|
||||
let open_ai_provider = config
|
||||
.llm_providers
|
||||
.iter()
|
||||
.find(|p| p.name.to_lowercase() == "openai")
|
||||
.unwrap();
|
||||
assert_eq!(open_ai_provider.name.to_lowercase(), "openai");
|
||||
assert_eq!(
|
||||
open_ai_provider.access_key,
|
||||
Some("$OPENAI_API_KEY".to_string())
|
||||
);
|
||||
assert_eq!(open_ai_provider.model, "gpt-4o");
|
||||
assert_eq!(open_ai_provider.default, Some(true));
|
||||
assert_eq!(open_ai_provider.stream, Some(true));
|
||||
|
||||
let prompt_guards = config.prompt_guards.as_ref().unwrap();
|
||||
let input_guards = &prompt_guards.input_guards;
|
||||
let jailbreak_guard = input_guards.get(&GuardType::Jailbreak).unwrap();
|
||||
assert_eq!(
|
||||
jailbreak_guard
|
||||
.on_exception
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.forward_to_error_target,
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
jailbreak_guard.on_exception.as_ref().unwrap().error_handler,
|
||||
None
|
||||
);
|
||||
|
||||
let prompt_targets = &config.prompt_targets;
|
||||
assert_eq!(prompt_targets.len(), 2);
|
||||
let prompt_target = prompt_targets
|
||||
.iter()
|
||||
.find(|p| p.name == "reboot_network_device")
|
||||
.unwrap();
|
||||
assert_eq!(prompt_target.name, "reboot_network_device");
|
||||
assert_eq!(prompt_target.default, None);
|
||||
|
||||
let prompt_target = prompt_targets
|
||||
.iter()
|
||||
.find(|p| p.name == "information_extraction")
|
||||
.unwrap();
|
||||
assert_eq!(prompt_target.name, "information_extraction");
|
||||
assert_eq!(prompt_target.default, Some(true));
|
||||
assert_eq!(
|
||||
prompt_target.endpoint.as_ref().unwrap().name,
|
||||
"app_server".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
prompt_target.endpoint.as_ref().unwrap().path,
|
||||
Some("/agent/summary".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
prompt_target.endpoint.as_ref().unwrap().method.as_ref().unwrap().to_string(),
|
||||
"POST".to_string()
|
||||
);
|
||||
|
||||
let error_target = config.error_target.as_ref().unwrap();
|
||||
assert_eq!(
|
||||
error_target.endpoint.as_ref().unwrap().name,
|
||||
"error_target_1".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
error_target.endpoint.as_ref().unwrap().path,
|
||||
Some("/error".to_string())
|
||||
);
|
||||
|
||||
let tracing = config.tracing.as_ref().unwrap();
|
||||
assert_eq!(*tracing, 100);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue