Merge branch 'shuguang/main' of https://github.com/katanemo/arch into shuguang/main

This commit is contained in:
cotran 2024-12-11 13:33:45 -08:00
commit 8cfef7bcd4
18 changed files with 98 additions and 518 deletions

View file

@ -99,6 +99,8 @@ properties:
type: string
in_path:
type: boolean
format:
type: string
additionalProperties: false
required:
- name

View file

@ -211,7 +211,7 @@ static_resources:
domains:
- "*"
routes:
{% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination", "model_server"] %}
{% for internal_clustrer in ["arch_fc", "model_server"] %}
- match:
prefix: "/"
headers:
@ -448,7 +448,7 @@ static_resources:
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
sni: api.mistral.ai
{% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination", "model_server"] %}
{% for internal_clustrer in ["arch_fc", "model_server"] %}
- name: {{ internal_clustrer }}
connect_timeout: 5s
type: STRICT_DNS

View file

@ -80,6 +80,8 @@ pub struct FunctionParameter {
pub enum_values: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub format: Option<String>,
}
impl Serialize for FunctionParameter {
@ -96,6 +98,9 @@ impl Serialize for FunctionParameter {
if let Some(default) = &self.default {
map.serialize_entry("default", default)?;
}
if let Some(format) = &self.format {
map.serialize_entry("format", format)?;
}
map.end()
}
}

View file

@ -196,6 +196,7 @@ pub struct Parameter {
pub enum_values: Option<Vec<String>>,
pub default: Option<String>,
pub in_path: Option<bool>,
pub format: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
@ -250,6 +251,7 @@ impl From<&PromptTarget> for ChatCompletionTool {
required: entity.required,
enum_values: entity.enum_values.clone(),
default: entity.default.clone(),
format: entity.format.clone(),
};
properties.insert(entity.name.clone(), param);
}

View file

@ -1,6 +1,3 @@
pub const DEFAULT_INTENT_MODEL: &str = "katanemo/bart-large-mnli";
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25;
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";
@ -8,11 +5,6 @@ pub const TOOL_ROLE: &str = "tool";
pub const ASSISTANT_ROLE: &str = "assistant";
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const MODEL_SERVER_NAME: &str = "model_server";
pub const ZEROSHOT_INTERNAL_HOST: &str = "zeroshot";
pub const ARCH_FC_INTERNAL_HOST: &str = "arch_fc";
pub const HALLUCINATION_INTERNAL_HOST: &str = "hallucination";
pub const EMBEDDINGS_INTERNAL_HOST: &str = "embeddings";
pub const GUARD_INTERNAL_HOST: &str = "guard";
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const MESSAGES_KEY: &str = "messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
@ -24,7 +16,6 @@ pub const REQUEST_ID_HEADER: &str = "x-request-id";
pub const TRACE_PARENT_HEADER: &str = "traceparent";
pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal";
pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream";
pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";
pub const ARCH_MODEL_PREFIX: &str = "Arch";
pub const HALLUCINATION_TEMPLATE: &str =
"It seems I'm missing some information. Could you provide the following details ";

View file

@ -1,59 +0,0 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use crate::embeddings;
use serde::{Deserialize, Serialize};
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct CreateEmbeddingRequest {
#[serde(rename = "input")]
pub input: Box<embeddings::CreateEmbeddingRequestInput>,
/// ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.
#[serde(rename = "model")]
pub model: String,
/// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).
#[serde(rename = "encoding_format", skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<EncodingFormat>,
/// The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.
#[serde(rename = "dimensions", skip_serializing_if = "Option::is_none")]
pub dimensions: Option<i32>,
/// A unique identifier representing your end-user, which can help to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).
#[serde(rename = "user", skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}
impl CreateEmbeddingRequest {
pub fn new(
input: embeddings::CreateEmbeddingRequestInput,
model: String,
) -> CreateEmbeddingRequest {
CreateEmbeddingRequest {
input: Box::new(input),
model,
encoding_format: None,
dimensions: None,
user: None,
}
}
}
/// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub enum EncodingFormat {
#[serde(rename = "float")]
Float,
#[serde(rename = "base64")]
Base64,
}
impl Default for EncodingFormat {
fn default() -> EncodingFormat {
Self::Float
}
}

View file

@ -1,28 +0,0 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use serde::{Deserialize, Serialize};
/// CreateEmbeddingRequestInput : Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. for counting tokens.
/// Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. for counting tokens.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CreateEmbeddingRequestInput {
/// The string that will be turned into an embedding.
String(String),
/// The array of integers that will be turned into an embedding.
Array(Vec<i32>),
}
impl Default for CreateEmbeddingRequestInput {
fn default() -> Self {
Self::String(Default::default())
}
}

View file

@ -1,55 +0,0 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use crate::embeddings;
use serde::{Deserialize, Serialize};
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct CreateEmbeddingResponse {
/// The list of embeddings generated by the model.
#[serde(rename = "data")]
pub data: Vec<embeddings::Embedding>,
/// The name of the model used to generate the embedding.
#[serde(rename = "model")]
pub model: String,
/// The object type, which is always \"list\".
#[serde(rename = "object")]
pub object: Object,
#[serde(rename = "usage")]
pub usage: Box<embeddings::CreateEmbeddingResponseUsage>,
}
impl CreateEmbeddingResponse {
pub fn new(
data: Vec<embeddings::Embedding>,
model: String,
object: Object,
usage: embeddings::CreateEmbeddingResponseUsage,
) -> CreateEmbeddingResponse {
CreateEmbeddingResponse {
data,
model,
object,
usage: Box::new(usage),
}
}
}
/// The object type, which is always \"list\".
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub enum Object {
#[serde(rename = "list")]
List,
}
impl Default for Object {
fn default() -> Object {
Self::List
}
}

View file

@ -1,32 +0,0 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use serde::{Deserialize, Serialize};
/// CreateEmbeddingResponseUsage : The usage information for the request.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct CreateEmbeddingResponseUsage {
/// The number of tokens used by the prompt.
#[serde(rename = "prompt_tokens")]
pub prompt_tokens: i32,
/// The total number of tokens used by the request.
#[serde(rename = "total_tokens")]
pub total_tokens: i32,
}
impl CreateEmbeddingResponseUsage {
/// The usage information for the request.
pub fn new(prompt_tokens: i32, total_tokens: i32) -> CreateEmbeddingResponseUsage {
CreateEmbeddingResponseUsage {
prompt_tokens,
total_tokens,
}
}
}

View file

@ -1,48 +0,0 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use serde::{Deserialize, Serialize};
/// Embedding : Represents an embedding vector returned by embedding endpoint.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct Embedding {
/// The index of the embedding in the list of embeddings.
#[serde(rename = "index")]
pub index: i32,
/// The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings).
#[serde(rename = "embedding")]
pub embedding: Vec<f64>,
/// The object type, which is always \"embedding\"
#[serde(rename = "object")]
pub object: Object,
}
impl Embedding {
/// Represents an embedding vector returned by embedding endpoint.
pub fn new(index: i32, embedding: Vec<f64>, object: Object) -> Embedding {
Embedding {
index,
embedding,
object,
}
}
}
/// The object type, which is always \"embedding\"
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub enum Object {
#[serde(rename = "embedding")]
Embedding,
}
impl Default for Object {
fn default() -> Object {
Self::Embedding
}
}

View file

@ -1,10 +0,0 @@
pub mod create_embedding_request;
pub use self::create_embedding_request::CreateEmbeddingRequest;
pub mod create_embedding_request_input;
pub use self::create_embedding_request_input::CreateEmbeddingRequestInput;
pub mod create_embedding_response;
pub use self::create_embedding_response::CreateEmbeddingResponse;
pub mod create_embedding_response_usage;
pub use self::create_embedding_response_usage::CreateEmbeddingResponseUsage;
pub mod embedding;
pub use self::embedding::Embedding;

View file

@ -1,7 +1,6 @@
pub mod api;
pub mod configuration;
pub mod consts;
pub mod embeddings;
pub mod errors;
pub mod http;
pub mod llm_providers;

View file

@ -1,14 +1,7 @@
use common::api::hallucination::HallucinationClassificationResponse;
use common::api::open_ai::{
ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage,
};
use common::api::prompt_guard::PromptGuardResponse;
use common::api::zero_shot::ZeroShotClassificationResponse;
use common::configuration::Configuration;
use common::embeddings::{
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
Embedding,
};
use http::StatusCode;
use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
@ -83,13 +76,11 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "guard"),
("x-arch-upstream", "model_server"),
(":method", "POST"),
(":path", "/guard"),
(":authority", "guard"),
(":path", "/function_calling"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
(":authority", "model_server"),
]),
None,
None,
@ -97,139 +88,11 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
)
.returning(Some(1))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
let prompt_guard_response = PromptGuardResponse {
toxic_prob: None,
toxic_verdict: None,
jailbreak_prob: None,
jailbreak_verdict: None,
};
let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap();
module
.call_proxy_on_http_call_response(
http_context,
1,
0,
prompt_guard_response_buffer.len() as i32,
0,
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&prompt_guard_response_buffer))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "embeddings"),
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "embeddings"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let embedding_response = CreateEmbeddingResponse {
data: vec![Embedding {
index: 0,
embedding: vec![],
object: embedding::Object::default(),
}],
model: String::from("test"),
object: create_embedding_response::Object::default(),
usage: Box::new(CreateEmbeddingResponseUsage::new(0, 0)),
};
let embeddings_response_buffer = serde_json::to_string(&embedding_response).unwrap();
module
.call_proxy_on_http_call_response(
http_context,
2,
0,
embeddings_response_buffer.len() as i32,
0,
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embeddings_response_buffer))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "zeroshot"),
(":method", "POST"),
(":path", "/zeroshot"),
(":authority", "zeroshot"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(3))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let zero_shot_response = ZeroShotClassificationResponse {
predicted_class: "weather_forecast".to_string(),
predicted_class_score: 0.1,
scores: HashMap::new(),
model: "test-model".to_string(),
};
let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap();
module
.call_proxy_on_http_call_response(
http_context,
3,
0,
zeroshot_intent_detection_buffer.len() as i32,
0,
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&zeroshot_intent_detection_buffer))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
(":method", "POST"),
("x-arch-upstream", "arch_fc"),
(":path", "/v1/chat/completions"),
(":authority", "arch_fc"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "120000"),
]),
None,
None,
None,
)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
}
fn setup_filter(module: &mut Tester, config: &str) -> i32 {
@ -248,69 +111,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
module
.call_proxy_on_tick(filter_context)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "embeddings"),
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "embeddings"),
("content-type", "application/json"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(101))
.expect_metric_increment("active_http_calls", 1)
.expect_set_tick_period_millis(Some(5000))
.execute_and_expect(ReturnType::None)
.unwrap();
let embedding_response = CreateEmbeddingResponse {
data: vec![Embedding {
embedding: vec![],
index: 0,
object: embedding::Object::default(),
}],
model: String::from("test"),
object: create_embedding_response::Object::default(),
usage: Box::new(CreateEmbeddingResponseUsage {
prompt_tokens: 0,
total_tokens: 0,
}),
};
let embedding_response_str = serde_json::to_string(&embedding_response).unwrap();
module
.call_proxy_on_http_call_response(
filter_context,
101,
0,
embedding_response_str.len() as i32,
0,
)
.expect_log(
Some(LogLevel::Trace),
Some(
format!(
"filter_context: on_http_call_response called with token_id: {:?}",
101
)
.as_str(),
),
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embedding_response_str))
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
filter_context
}
@ -435,6 +235,7 @@ fn prompt_gateway_successful_request_to_open_ai_chat_completions() {
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(Some("arch_internal"), None, None, None, None)
.returning(Some(4))
@ -538,8 +339,8 @@ fn prompt_gateway_request_to_llm_gateway() {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: "test".to_string(),
index: 0,
finish_reason: Some("test".to_string()),
index: Some(0),
message: Message {
role: "system".to_string(),
content: None,
@ -564,55 +365,12 @@ fn prompt_gateway_request_to_llm_gateway() {
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
module
.call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
.call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "hallucination"),
(":method", "POST"),
(":path", "/hallucination"),
(":authority", "hallucination"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(5))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
// hallucination should return that parameters were not halliucinated
// prompt: str
// parameters: dict
// model: str
let hallucatination_body = HallucinationClassificationResponse {
params_scores: HashMap::from([("city".to_string(), 0.99)]),
model: "nli-model".to_string(),
};
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
module
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
@ -628,14 +386,14 @@ fn prompt_gateway_request_to_llm_gateway() {
None,
None,
)
.returning(Some(6))
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
.call_proxy_on_http_call_response(http_context, 2, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
@ -652,8 +410,8 @@ fn prompt_gateway_request_to_llm_gateway() {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: "test".to_string(),
index: 0,
finish_reason: Some("test".to_string()),
index: Some(0),
message: Message {
role: "assistant".to_string(),
content: Some("hello from fake llm gateway".to_string()),

View file

@ -49,6 +49,7 @@ prompt_targets:
description: The location to get the weather for
required: true
type: string
format: city, state
- name: days
description: the number of days for the request
required: true

View file

@ -28,12 +28,6 @@ Content-Type: application/json
"description": "The location to get the weather for",
"format": "City, State"
},
"unit": {
"type": "str",
"description": "The unit to return the weather in.",
"enum": ["celsius", "fahrenheit"],
"default": "celsius"
},
"days": {
"type": "str",
"description": "the number of days for the request."
@ -236,7 +230,6 @@ Content-Type: application/json
}
### archgw to model_server 2
POST {{model_server_endpoint}}/function_calling HTTP/1.1
Content-Type: application/json
@ -292,3 +285,66 @@ Content-Type: application/json
],
"stream": false
}
### archgw to model_server 3
POST {{model_server_endpoint}}/function_calling HTTP/1.1
Content-Type: application/json
{
"model": "--",
"messages": [
{
"role": "user",
"content": "how is the weather in seattle"
},
{
"role": "assistant",
"content": "Of course, I can help with that. Could you please specify the days you want the weather forecast for?",
"model": "Arch-Function"
},
{
"role": "user",
"content": "for 2 days please"
}
],
"tools": [
{
"id": "weather-112",
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get current weather at a location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "str",
"description": "The location to get the weather for",
"format": "City, State"
},
"days": {
"type": "str",
"description": "the number of days for the request"
}
},
"required": [
"days", "location"
]
}
}
},
{
"type": "function",
"function": {
"name": "default_target",
"description": "This is the default target for all unmatched prompts.",
"parameters": {
"type": "object",
"properties": {}
}
}
}
],
"stream": true
}

View file

@ -73,15 +73,6 @@ Content-Type: application/json
{
"role": "user",
"content": "for next 10 days"
},
{
"role": "assistant",
"content": "Could you tell me what units you want the weather in? (For example: Celsius or Fahrenheit)",
"model": "Arch-Function-1.5b"
},
{
"role": "user",
"content": "Fahrenheit"
}
]
}

View file

@ -15,7 +15,7 @@ from common import (
def test_prompt_gateway(stream):
expected_tool_call = {
"name": "get_current_weather",
"arguments": {"location": "seattle", "days": "10"},
"arguments": {"location": "seattle, wa", "days": "10"},
}
body = {
@ -169,7 +169,7 @@ def test_prompt_gateway_param_gathering(stream):
def test_prompt_gateway_param_tool_call(stream):
expected_tool_call = {
"name": "get_current_weather",
"arguments": {"location": "seattle", "days": "2"},
"arguments": {"location": "seattle, wa", "days": "2"},
}
body = {
@ -181,11 +181,11 @@ def test_prompt_gateway_param_tool_call(stream):
{
"role": "assistant",
"content": "Of course, I can help with that. Could you please specify the days you want the weather forecast for?",
"model": "Arch-Function-1.5B",
"model": "Arch-Function",
},
{
"role": "user",
"content": "2 days",
"content": "for 2 days please",
},
],
"stream": stream,

View file

@ -42,3 +42,10 @@ archgw_modelserver = "src.cli:run_server"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
python_files = ["test*.py"]
addopts = ["-v", "-s"]
retries = 2
retry_delay = 0.5
cumulative_timing = false