2025-06-10 12:53:27 -07:00
|
|
|
use hermesllm::providers::openai::types::{ModelDetail, ModelObject, Models};
|
2024-11-25 19:19:06 -06:00
|
|
|
use serde::{Deserialize, Serialize};
|
2025-07-02 14:08:19 -07:00
|
|
|
use serde_with::skip_serializing_none;
|
2024-11-25 19:19:06 -06:00
|
|
|
use std::collections::HashMap;
|
2024-10-03 10:57:01 -07:00
|
|
|
use std::fmt::Display;
|
2024-07-16 14:50:32 -07:00
|
|
|
|
2024-12-20 13:25:01 -08:00
|
|
|
use crate::api::open_ai::{
|
|
|
|
|
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
|
|
|
|
|
};
|
|
|
|
|
|
2025-05-19 09:59:22 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
pub struct Routing {
|
2025-07-08 00:33:40 -07:00
|
|
|
pub llm_provider: Option<String>,
|
|
|
|
|
pub model: Option<String>,
|
2025-05-19 09:59:22 -07:00
|
|
|
}
|
|
|
|
|
|
2024-12-06 17:25:42 -08:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
pub struct Configuration {
|
|
|
|
|
pub version: String,
|
|
|
|
|
pub endpoints: Option<HashMap<String, Endpoint>>,
|
|
|
|
|
pub llm_providers: Vec<LlmProvider>,
|
|
|
|
|
pub overrides: Option<Overrides>,
|
|
|
|
|
pub system_prompt: Option<String>,
|
|
|
|
|
pub prompt_guards: Option<PromptGuards>,
|
|
|
|
|
pub prompt_targets: Option<Vec<PromptTarget>>,
|
|
|
|
|
pub error_target: Option<ErrorTargetDetail>,
|
|
|
|
|
pub ratelimits: Option<Vec<Ratelimit>>,
|
|
|
|
|
pub tracing: Option<Tracing>,
|
|
|
|
|
pub mode: Option<GatewayMode>,
|
2025-05-19 09:59:22 -07:00
|
|
|
pub routing: Option<Routing>,
|
2024-12-06 17:25:42 -08:00
|
|
|
}
|
|
|
|
|
|
2024-09-17 22:37:58 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
|
|
|
|
pub struct Overrides {
|
2024-09-23 12:07:31 -07:00
|
|
|
pub prompt_target_intent_matching_threshold: Option<f64>,
|
2025-02-07 19:14:15 -08:00
|
|
|
pub optimize_context_window: Option<bool>,
|
2025-03-19 15:21:34 -07:00
|
|
|
pub use_agent_orchestrator: Option<bool>,
|
2024-09-17 22:37:58 -07:00
|
|
|
}
|
|
|
|
|
|
2024-10-08 16:24:08 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
|
|
|
|
pub struct Tracing {
|
|
|
|
|
pub sampling_rate: Option<f64>,
|
2024-11-07 22:11:00 -06:00
|
|
|
pub trace_arch_internal: Option<bool>,
|
2024-10-08 16:24:08 -07:00
|
|
|
}
|
|
|
|
|
|
2024-10-16 14:20:26 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
|
2024-10-09 15:47:32 -07:00
|
|
|
pub enum GatewayMode {
|
|
|
|
|
#[serde(rename = "llm")]
|
|
|
|
|
Llm,
|
2024-10-16 14:20:26 -07:00
|
|
|
#[default]
|
2024-10-09 15:47:32 -07:00
|
|
|
#[serde(rename = "prompt")]
|
|
|
|
|
Prompt,
|
|
|
|
|
}
|
|
|
|
|
|
2024-09-30 17:49:05 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
pub struct ErrorTargetDetail {
|
|
|
|
|
pub endpoint: Option<EndpointDetails>,
|
|
|
|
|
}
|
|
|
|
|
|
2024-09-23 12:07:31 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
2024-09-30 17:49:05 -07:00
|
|
|
pub struct PromptGuards {
|
|
|
|
|
pub input_guards: HashMap<GuardType, GuardOptions>,
|
|
|
|
|
}
|
|
|
|
|
|
2024-10-03 10:57:01 -07:00
|
|
|
impl PromptGuards {
|
|
|
|
|
pub fn jailbreak_on_exception_message(&self) -> Option<&str> {
|
|
|
|
|
self.input_guards
|
|
|
|
|
.get(&GuardType::Jailbreak)?
|
|
|
|
|
.on_exception
|
|
|
|
|
.as_ref()?
|
|
|
|
|
.message
|
|
|
|
|
.as_ref()?
|
|
|
|
|
.as_str()
|
|
|
|
|
.into()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2024-09-30 17:49:05 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
|
|
|
|
pub enum GuardType {
|
|
|
|
|
#[serde(rename = "jailbreak")]
|
|
|
|
|
Jailbreak,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
2024-09-23 12:07:31 -07:00
|
|
|
pub struct GuardOptions {
|
2024-09-30 17:49:05 -07:00
|
|
|
pub on_exception: Option<OnExceptionDetails>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
pub struct OnExceptionDetails {
|
|
|
|
|
pub forward_to_error_target: Option<bool>,
|
|
|
|
|
pub error_handler: Option<String>,
|
|
|
|
|
pub message: Option<String>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
pub struct LlmRatelimit {
|
|
|
|
|
pub selector: LlmRatelimitSelector,
|
|
|
|
|
pub limit: Limit,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
pub struct LlmRatelimitSelector {
|
|
|
|
|
pub http_header: Option<RatelimitHeader>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
|
|
|
|
pub struct Header {
|
|
|
|
|
pub key: String,
|
|
|
|
|
pub value: Option<String>,
|
2024-09-17 10:59:50 -07:00
|
|
|
}
|
|
|
|
|
|
2024-08-07 14:15:26 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
pub struct Ratelimit {
|
2024-10-03 10:57:01 -07:00
|
|
|
pub model: String,
|
2024-08-28 11:11:05 -07:00
|
|
|
pub selector: Header,
|
|
|
|
|
pub limit: Limit,
|
2024-08-07 14:15:26 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
pub struct Limit {
|
2024-08-28 11:11:05 -07:00
|
|
|
pub tokens: u32,
|
|
|
|
|
pub unit: TimeUnit,
|
2024-08-07 14:15:26 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
pub enum TimeUnit {
|
2024-08-28 11:11:05 -07:00
|
|
|
#[serde(rename = "second")]
|
|
|
|
|
Second,
|
2024-08-07 14:15:26 -07:00
|
|
|
#[serde(rename = "minute")]
|
|
|
|
|
Minute,
|
2024-08-28 11:11:05 -07:00
|
|
|
#[serde(rename = "hour")]
|
|
|
|
|
Hour,
|
2024-08-07 14:15:26 -07:00
|
|
|
}
|
|
|
|
|
|
2024-08-28 11:11:05 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
2024-09-30 17:49:05 -07:00
|
|
|
pub struct RatelimitHeader {
|
|
|
|
|
pub name: String,
|
2024-08-28 11:11:05 -07:00
|
|
|
pub value: Option<String>,
|
2024-07-16 14:50:32 -07:00
|
|
|
}
|
|
|
|
|
|
2024-07-31 14:05:52 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
2024-07-16 14:50:32 -07:00
|
|
|
//TODO: use enum for model, but if there is a new model, we need to update the code
|
|
|
|
|
pub struct EmbeddingProviver {
|
|
|
|
|
pub name: String,
|
|
|
|
|
pub model: String,
|
|
|
|
|
}
|
|
|
|
|
|
2025-05-22 22:55:46 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
2025-01-17 18:25:55 -08:00
|
|
|
pub enum LlmProviderType {
|
2025-05-30 17:12:52 -07:00
|
|
|
#[serde(rename = "arch")]
|
|
|
|
|
Arch,
|
2025-05-22 22:55:46 -07:00
|
|
|
#[serde(rename = "claude")]
|
|
|
|
|
Claude,
|
|
|
|
|
#[serde(rename = "deepseek")]
|
|
|
|
|
Deepseek,
|
|
|
|
|
#[serde(rename = "groq")]
|
|
|
|
|
Groq,
|
2025-01-17 18:25:55 -08:00
|
|
|
#[serde(rename = "mistral")]
|
|
|
|
|
Mistral,
|
2025-05-22 22:55:46 -07:00
|
|
|
#[serde(rename = "openai")]
|
|
|
|
|
OpenAI,
|
2025-06-11 15:15:00 -07:00
|
|
|
#[serde(rename = "gemini")]
|
|
|
|
|
Gemini,
|
2025-01-17 18:25:55 -08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Display for LlmProviderType {
|
|
|
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
|
|
|
match self {
|
2025-05-30 17:12:52 -07:00
|
|
|
LlmProviderType::Arch => write!(f, "arch"),
|
2025-05-22 22:55:46 -07:00
|
|
|
LlmProviderType::Claude => write!(f, "claude"),
|
|
|
|
|
LlmProviderType::Deepseek => write!(f, "deepseek"),
|
|
|
|
|
LlmProviderType::Groq => write!(f, "groq"),
|
2025-06-11 15:15:00 -07:00
|
|
|
LlmProviderType::Gemini => write!(f, "gemini"),
|
2025-01-17 18:25:55 -08:00
|
|
|
LlmProviderType::Mistral => write!(f, "mistral"),
|
2025-05-22 22:55:46 -07:00
|
|
|
LlmProviderType::OpenAI => write!(f, "openai"),
|
2025-01-17 18:25:55 -08:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2025-07-02 14:08:19 -07:00
|
|
|
#[skip_serializing_none]
|
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
|
|
|
pub struct ModelUsagePreference {
|
|
|
|
|
pub name: String,
|
|
|
|
|
pub model: String,
|
|
|
|
|
pub usage: Option<String>,
|
|
|
|
|
}
|
|
|
|
|
|
2025-05-30 17:40:46 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
pub struct LlmRoute {
|
|
|
|
|
pub name: String,
|
|
|
|
|
pub description: String,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl From<&LlmProvider> for LlmRoute {
|
|
|
|
|
fn from(provider: &LlmProvider) -> Self {
|
|
|
|
|
Self {
|
|
|
|
|
name: provider.name.to_string(),
|
|
|
|
|
description: provider
|
|
|
|
|
.usage
|
|
|
|
|
.as_ref()
|
|
|
|
|
.cloned()
|
|
|
|
|
.unwrap_or_else(|| "No description available".to_string()),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2024-07-16 14:50:32 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
//TODO: use enum for model, but if there is a new model, we need to update the code
|
|
|
|
|
pub struct LlmProvider {
|
|
|
|
|
pub name: String,
|
2025-01-17 18:25:55 -08:00
|
|
|
pub provider_interface: LlmProviderType,
|
2024-09-30 17:49:05 -07:00
|
|
|
pub access_key: Option<String>,
|
2025-03-19 15:21:34 -07:00
|
|
|
pub model: Option<String>,
|
2024-08-06 23:40:06 -07:00
|
|
|
pub default: Option<bool>,
|
2024-09-30 17:49:05 -07:00
|
|
|
pub stream: Option<bool>,
|
2025-01-17 18:25:55 -08:00
|
|
|
pub endpoint: Option<String>,
|
|
|
|
|
pub port: Option<u16>,
|
2024-09-30 17:49:05 -07:00
|
|
|
pub rate_limits: Option<LlmRatelimit>,
|
2025-05-19 09:59:22 -07:00
|
|
|
pub usage: Option<String>,
|
2024-07-16 14:50:32 -07:00
|
|
|
}
|
|
|
|
|
|
2025-06-10 12:53:27 -07:00
|
|
|
pub trait IntoModels {
|
|
|
|
|
fn into_models(self) -> Models;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl IntoModels for Vec<LlmProvider> {
|
|
|
|
|
fn into_models(self) -> Models {
|
|
|
|
|
let data = self
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|provider| ModelDetail {
|
|
|
|
|
id: provider.name.clone(),
|
|
|
|
|
object: "model".to_string(),
|
|
|
|
|
created: 0,
|
|
|
|
|
owned_by: "system".to_string(),
|
|
|
|
|
})
|
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
|
|
Models {
|
|
|
|
|
object: ModelObject::List,
|
|
|
|
|
data,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2025-05-22 22:55:46 -07:00
|
|
|
impl Default for LlmProvider {
|
|
|
|
|
fn default() -> Self {
|
|
|
|
|
Self {
|
|
|
|
|
name: "openai".to_string(),
|
|
|
|
|
provider_interface: LlmProviderType::OpenAI,
|
|
|
|
|
access_key: None,
|
|
|
|
|
model: None,
|
|
|
|
|
default: Some(true),
|
|
|
|
|
stream: Some(false),
|
|
|
|
|
endpoint: None,
|
|
|
|
|
port: None,
|
|
|
|
|
rate_limits: None,
|
|
|
|
|
usage: None,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2024-10-03 10:57:01 -07:00
|
|
|
impl Display for LlmProvider {
|
|
|
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
|
|
|
write!(f, "{}", self.name)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2024-07-30 16:23:23 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
pub struct Endpoint {
|
2024-09-30 17:49:05 -07:00
|
|
|
pub endpoint: Option<String>,
|
2024-07-30 16:23:23 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
2024-09-10 14:24:46 -07:00
|
|
|
pub struct Parameter {
|
2024-07-30 16:23:23 -07:00
|
|
|
pub name: String,
|
2024-09-10 14:24:46 -07:00
|
|
|
#[serde(rename = "type")]
|
|
|
|
|
pub parameter_type: Option<String>,
|
|
|
|
|
pub description: String,
|
2024-07-30 16:23:23 -07:00
|
|
|
pub required: Option<bool>,
|
2024-09-17 12:03:21 -07:00
|
|
|
#[serde(rename = "enum")]
|
|
|
|
|
pub enum_values: Option<Vec<String>>,
|
2024-09-20 09:02:24 -07:00
|
|
|
pub default: Option<String>,
|
2024-12-06 14:37:33 -08:00
|
|
|
pub in_path: Option<bool>,
|
2024-12-20 13:25:01 -08:00
|
|
|
pub format: Option<String>,
|
2024-12-06 14:37:33 -08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
|
|
|
|
|
pub enum HttpMethod {
|
|
|
|
|
#[default]
|
|
|
|
|
#[serde(rename = "GET")]
|
|
|
|
|
Get,
|
|
|
|
|
#[serde(rename = "POST")]
|
|
|
|
|
Post,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Display for HttpMethod {
|
|
|
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
|
|
|
match self {
|
|
|
|
|
HttpMethod::Get => write!(f, "GET"),
|
|
|
|
|
HttpMethod::Post => write!(f, "POST"),
|
|
|
|
|
}
|
|
|
|
|
}
|
2024-09-10 14:24:46 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
2024-09-30 17:49:05 -07:00
|
|
|
pub struct EndpointDetails {
|
|
|
|
|
pub name: String,
|
|
|
|
|
pub path: Option<String>,
|
2024-12-06 14:37:33 -08:00
|
|
|
#[serde(rename = "http_method")]
|
|
|
|
|
pub method: Option<HttpMethod>,
|
2025-02-06 11:48:09 -08:00
|
|
|
pub http_headers: Option<HashMap<String, String>>,
|
2024-07-30 16:23:23 -07:00
|
|
|
}
|
|
|
|
|
|
2024-07-16 14:50:32 -07:00
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
pub struct PromptTarget {
|
|
|
|
|
pub name: String,
|
2024-09-30 17:49:05 -07:00
|
|
|
pub default: Option<bool>,
|
2024-09-16 19:20:07 -07:00
|
|
|
pub description: String,
|
2024-09-30 17:49:05 -07:00
|
|
|
pub endpoint: Option<EndpointDetails>,
|
2024-09-10 14:24:46 -07:00
|
|
|
pub parameters: Option<Vec<Parameter>>,
|
2024-07-30 16:23:23 -07:00
|
|
|
pub system_prompt: Option<String>,
|
2024-09-30 17:49:05 -07:00
|
|
|
pub auto_llm_dispatch_on_response: Option<bool>,
|
2024-07-16 14:50:32 -07:00
|
|
|
}
|
|
|
|
|
|
2024-12-20 13:25:01 -08:00
|
|
|
// convert PromptTarget to ChatCompletionTool
|
|
|
|
|
impl From<&PromptTarget> for ChatCompletionTool {
|
|
|
|
|
fn from(val: &PromptTarget) -> Self {
|
|
|
|
|
let properties: HashMap<String, FunctionParameter> = match val.parameters {
|
|
|
|
|
Some(ref entities) => {
|
|
|
|
|
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
|
|
|
|
|
for entity in entities.iter() {
|
|
|
|
|
let param = FunctionParameter {
|
|
|
|
|
parameter_type: ParameterType::from(
|
|
|
|
|
entity.parameter_type.clone().unwrap_or("str".to_string()),
|
|
|
|
|
),
|
|
|
|
|
description: entity.description.clone(),
|
|
|
|
|
required: entity.required,
|
|
|
|
|
enum_values: entity.enum_values.clone(),
|
|
|
|
|
default: entity.default.clone(),
|
|
|
|
|
format: entity.format.clone(),
|
|
|
|
|
};
|
|
|
|
|
properties.insert(entity.name.clone(), param);
|
|
|
|
|
}
|
|
|
|
|
properties
|
|
|
|
|
}
|
|
|
|
|
None => HashMap::new(),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
ChatCompletionTool {
|
|
|
|
|
tool_type: crate::api::open_ai::ToolType::Function,
|
|
|
|
|
function: FunctionDefinition {
|
|
|
|
|
name: val.name.clone(),
|
|
|
|
|
description: val.description.clone(),
|
|
|
|
|
parameters: FunctionParameters { properties },
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2024-07-16 14:50:32 -07:00
|
|
|
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;
    use std::fs;

    use crate::{api::open_ai::ToolType, configuration::GuardType};

    /// Reference configuration, relative to the directory `cargo test` runs in
    /// (the crate root).
    const REFERENCE_CONFIG_PATH: &str =
        "../../docs/source/resources/includes/arch_config_full_reference.yaml";

    /// Reads and deserializes the reference configuration used by these tests.
    fn load_reference_config() -> super::Configuration {
        let ref_config =
            fs::read_to_string(REFERENCE_CONFIG_PATH).expect("reference config file not found");
        serde_yaml::from_str(&ref_config).unwrap()
    }

    #[test]
    fn test_deserialize_configuration() {
        let config = load_reference_config();
        assert_eq!(config.version, "v0.1");

        let prompt_guards = config.prompt_guards.as_ref().unwrap();
        let jailbreak_guard = prompt_guards
            .input_guards
            .get(&GuardType::Jailbreak)
            .unwrap();
        let on_exception = jailbreak_guard.on_exception.as_ref().unwrap();
        assert_eq!(on_exception.forward_to_error_target, None);
        assert_eq!(on_exception.error_handler, None);

        let prompt_targets = config.prompt_targets.as_ref().unwrap();
        assert_eq!(prompt_targets.len(), 2);

        let prompt_target = prompt_targets
            .iter()
            .find(|p| p.name == "reboot_network_device")
            .unwrap();
        assert_eq!(prompt_target.name, "reboot_network_device");
        assert_eq!(prompt_target.default, None);

        let prompt_target = prompt_targets
            .iter()
            .find(|p| p.name == "information_extraction")
            .unwrap();
        assert_eq!(prompt_target.name, "information_extraction");
        assert_eq!(prompt_target.default, Some(true));
        let endpoint = prompt_target.endpoint.as_ref().unwrap();
        assert_eq!(endpoint.name, "app_server".to_string());
        assert_eq!(endpoint.path, Some("/agent/summary".to_string()));

        let tracing = config.tracing.as_ref().unwrap();
        assert_eq!(tracing.sampling_rate.unwrap(), 0.1);

        let mode = config.mode.as_ref().unwrap_or(&super::GatewayMode::Prompt);
        assert_eq!(*mode, super::GatewayMode::Prompt);
    }

    #[test]
    fn test_tool_conversion() {
        let config = load_reference_config();
        let prompt_target = config
            .prompt_targets
            .as_ref()
            .unwrap()
            .iter()
            .find(|p| p.name == "reboot_network_device")
            .unwrap();

        let chat_completion_tool: super::ChatCompletionTool = prompt_target.into();
        assert_eq!(chat_completion_tool.tool_type, ToolType::Function);
        assert_eq!(chat_completion_tool.function.name, "reboot_network_device");
        assert_eq!(
            chat_completion_tool.function.description,
            "Reboot a specific network device"
        );

        let properties = &chat_completion_tool.function.parameters.properties;
        assert_eq!(properties.len(), 2);
        // Plain `assert!` instead of `assert_eq!(.., true)` (clippy::bool_assert_comparison).
        assert!(properties.contains_key("device_id"));

        let device_id = properties.get("device_id").unwrap();
        assert_eq!(
            device_id.parameter_type,
            crate::api::open_ai::ParameterType::String
        );
        assert_eq!(
            device_id.description,
            "Identifier of the network device to reboot.".to_string()
        );
        assert_eq!(device_id.required, Some(true));

        assert_eq!(
            properties.get("confirmation").unwrap().parameter_type,
            crate::api::open_ai::ParameterType::Bool
        );
    }
}
|