Integrate Arch-Function-Calling-1.5B model (#85)

* add arch support

* add missing file

* e2e tests

* delete old files and fix response

* fmt
This commit is contained in:
Adil Hafeez 2024-09-25 23:30:50 -07:00 committed by GitHub
parent 9ea6bb0d73
commit 3511798fa8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 203 additions and 427 deletions

View file

@ -3,15 +3,14 @@ use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
};
use public_types::common_types::{
open_ai::{ChatCompletionsResponse, Choice, Message, Usage},
BoltFCToolsCall, IntOrString, ToolCallDetail,
};
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
use public_types::embeddings::embedding::Object;
use public_types::embeddings::{
create_embedding_response, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding,
};
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
use serde_yaml::Value;
use serial_test::serial;
use std::collections::HashMap;
use std::path::Path;
@ -403,18 +402,6 @@ fn request_ratelimited() {
normal_flow(&mut module, filter_context, http_context);
let tool_call_detail = vec![ToolCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
IntOrString::Text(String::from("seattle")),
)]),
}];
let boltfc_tools_call = BoltFCToolsCall {
tool_calls: tool_call_detail,
};
let bolt_fc_resp = ChatCompletionsResponse {
usage: Usage {
completion_tokens: 0,
@ -424,7 +411,18 @@ fn request_ratelimited() {
index: 0,
message: Message {
role: "system".to_string(),
content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()),
content: None,
tool_calls: Some(vec![ToolCall {
id: String::from("test"),
tool_type: ToolType::Function,
function: FunctionCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
Value::String(String::from("seattle")),
)]),
},
}]),
model: None,
},
}],
@ -519,18 +517,6 @@ fn request_not_ratelimited() {
normal_flow(&mut module, filter_context, http_context);
let tool_call_detail = vec![ToolCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
IntOrString::Text(String::from("seattle")),
)]),
}];
let boltfc_tools_call = BoltFCToolsCall {
tool_calls: tool_call_detail,
};
let bolt_fc_resp = ChatCompletionsResponse {
usage: Usage {
completion_tokens: 0,
@ -540,7 +526,18 @@ fn request_not_ratelimited() {
index: 0,
message: Message {
role: "system".to_string(),
content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()),
content: None,
tool_calls: Some(vec![ToolCall {
id: String::from("test"),
tool_type: ToolType::Function,
function: FunctionCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
Value::String(String::from("seattle")),
)]),
},
}]),
model: None,
},
}],