integrate arch with model_server

This commit is contained in:
Adil Hafeez 2024-12-10 15:12:31 -08:00
parent 791ce0a7ed
commit 44872107a8
13 changed files with 240 additions and 1222 deletions

View file

@ -21,7 +21,7 @@ pub struct ChatCompletionsRequest {
pub metadata: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ToolType {
#[serde(rename = "function")]
Function,
@ -165,8 +165,8 @@ pub struct Message {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
pub finish_reason: String,
pub index: usize,
pub finish_reason: Option<String>,
pub index: Option<usize>,
pub message: Message,
}
@ -217,8 +217,8 @@ impl ChatCompletionsResponse {
tool_calls: None,
tool_call_id: None,
},
index: 0,
finish_reason: "done".to_string(),
index: Some(0),
finish_reason: Some("done".to_string()),
}],
usage: None,
model: ARCH_FC_MODEL_NAME.to_string(),

View file

@ -2,6 +2,10 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Display;
use crate::api::open_ai::{
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
@ -231,11 +235,46 @@ pub struct PromptTarget {
pub auto_llm_dispatch_on_response: Option<bool>,
}
// convert PromptTarget to ChatCompletionTool
impl Into<ChatCompletionTool> for &PromptTarget {
fn into(self) -> ChatCompletionTool {
let properties: HashMap<String, FunctionParameter> = match self.parameters {
Some(ref entities) => {
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
for entity in entities.iter() {
let param = FunctionParameter {
parameter_type: ParameterType::from(
entity.parameter_type.clone().unwrap_or("str".to_string()),
),
description: entity.description.clone(),
required: entity.required,
enum_values: entity.enum_values.clone(),
default: entity.default.clone(),
};
properties.insert(entity.name.clone(), param);
}
properties
}
None => HashMap::new(),
};
ChatCompletionTool {
tool_type: crate::api::open_ai::ToolType::Function,
function: FunctionDefinition {
name: self.name.clone(),
description: self.description.clone(),
parameters: FunctionParameters { properties },
},
}
}
}
#[cfg(test)]
mod test {
use pretty_assertions::assert_eq;
use std::fs;
use crate::configuration::GuardType;
use crate::{api::open_ai::ToolType, configuration::GuardType};
#[test]
fn test_deserialize_configuration() {
@ -307,4 +346,76 @@ mod test {
let mode = config.mode.as_ref().unwrap_or(&super::GatewayMode::Prompt);
assert_eq!(*mode, super::GatewayMode::Prompt);
}
#[test]
fn test_tool_conversion() {
let ref_config = fs::read_to_string(
"../../docs/source/resources/includes/arch_config_full_reference.yaml",
)
.expect("reference config file not found");
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
let prompt_targets = &config.prompt_targets;
let prompt_target = prompt_targets
.as_ref()
.unwrap()
.iter()
.find(|p| p.name == "reboot_network_device")
.unwrap();
let chat_completion_tool: super::ChatCompletionTool = prompt_target.into();
assert_eq!(chat_completion_tool.tool_type, ToolType::Function);
assert_eq!(chat_completion_tool.function.name, "reboot_network_device");
assert_eq!(
chat_completion_tool.function.description,
"Reboot a specific network device"
);
assert_eq!(chat_completion_tool.function.parameters.properties.len(), 2);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.contains_key("device_id"),
true
);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.get("device_id")
.unwrap()
.parameter_type,
crate::api::open_ai::ParameterType::String
);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.get("device_id")
.unwrap()
.description,
"Identifier of the network device to reboot.".to_string()
);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.get("device_id")
.unwrap()
.required,
Some(true)
);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.get("confirmation")
.unwrap()
.parameter_type,
crate::api::open_ai::ParameterType::Bool
);
}
}

View file

@ -5,10 +5,10 @@ pub mod embeddings;
pub mod errors;
pub mod http;
pub mod llm_providers;
pub mod path;
pub mod pii;
pub mod ratelimit;
pub mod routing;
pub mod stats;
pub mod tokenizer;
pub mod tracing;
pub mod path;

View file

@ -1,6 +1,9 @@
use std::collections::HashMap;
pub fn replace_params_in_path(path: &str, params: &HashMap<String, String>) -> Result<String, String> {
pub fn replace_params_in_path(
path: &str,
params: &HashMap<String, String>,
) -> Result<String, String> {
let mut result = String::new();
let mut in_param = false;
let mut current_param = String::new();