From 3859e8eb43d73277ec7df51fabc82ed1a6efd49d Mon Sep 17 00:00:00 2001
From: Shuguang Chen <54548843+nehcgs@users.noreply.github.com>
Date: Mon, 9 Dec 2024 15:40:57 -0800
Subject: [PATCH 1/4] Fix bugs
---
model_server/src/core/function_calling.py | 30 +++++++++++++++--------
1 file changed, 20 insertions(+), 10 deletions(-)
diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py
index 57428235..d3fee240 100644
--- a/model_server/src/core/function_calling.py
+++ b/model_server/src/core/function_calling.py
@@ -43,7 +43,11 @@ class ArchIntentConfig:
EXTRA_INSTRUCTION = "Are there any tools can help?"
- GENERATION_PARAMS = {"max_tokens": 1, "stop_token_ids": [151645]}
+ GENERATION_PARAMS = {
+ "temperature": 0.01,
+ "max_tokens": 1,
+ "stop_token_ids": [151645],
+ }
class ArchIntentHandler(ArchBaseHandler):
@@ -318,6 +322,9 @@ class ArchFunctionHandler(ArchBaseHandler):
flag = False
for line in content.split("\n"):
+ if not is_valid:
+ break
+
if "" == line:
flag = True
elif "" == line:
@@ -332,7 +339,7 @@ class ArchFunctionHandler(ArchBaseHandler):
tool_content = json.loads(fixed_content)
except Exception:
tool_calls, is_valid, error_message = [], False, e
- return tool_calls, is_valid, error_message
+ break
tool_calls.append(
{
@@ -347,7 +354,7 @@ class ArchFunctionHandler(ArchBaseHandler):
flag = False
- return {"result": tool_calls, "status": is_valid, "message": "error_message"}
+ return {"result": tool_calls, "status": is_valid, "message": error_message}
def _verify_tool_calls(
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
@@ -374,16 +381,19 @@ class ArchFunctionHandler(ArchBaseHandler):
functions[tool["function"]["name"]] = tool["function"]["parameters"]
for tool_call in tool_calls:
- func_name, func_args = (
- tool_call["function"]["name"],
- tool_call["function"]["arguments"],
- )
+ if not is_valid:
+ break
+
+ func_name = tool_call["function"]["name"]
+ func_args = tool_call["function"]["arguments"]
# Check whether the function is available or not
if func_name not in functions:
is_valid = False
+ invalid_tool_call = tool_call
error_message = f"{func_name} is not defined!"
- return is_valid, error_message
+ break
+
else:
# Check if all the requried parameters can be found in the tool calls
for required_param in functions[func_name].get("required", []):
@@ -391,7 +401,7 @@ class ArchFunctionHandler(ArchBaseHandler):
is_valid = False
invalid_tool_call = tool_call
error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!"
- return is_valid, invalid_tool_call, error_message
+ break
# Verify the data type of each parameter in the tool calls
for param_name in func_args:
@@ -405,7 +415,7 @@ class ArchFunctionHandler(ArchBaseHandler):
is_valid = False
invalid_tool_call = tool_call
error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`."
- return is_valid, invalid_tool_call, error_message
+ break
return {
"status": is_valid,
From 44872107a86c4125c0069c0ecb03fd963e5d935a Mon Sep 17 00:00:00 2001
From: Adil Hafeez
Date: Tue, 10 Dec 2024 15:12:31 -0800
Subject: [PATCH 2/4] integrate arch with model_server
---
arch/envoy.template.yaml | 5 +-
arch/tools/cli/config_generator.py | 1 +
crates/common/src/api/open_ai.rs | 10 +-
crates/common/src/configuration.rs | 113 ++-
crates/common/src/lib.rs | 2 +-
crates/common/src/path.rs | 5 +-
crates/prompt_gateway/src/context.rs | 59 --
crates/prompt_gateway/src/embeddings.rs | 5 -
crates/prompt_gateway/src/filter_context.rs | 210 +----
crates/prompt_gateway/src/http_context.rs | 79 +-
crates/prompt_gateway/src/lib.rs | 1 -
crates/prompt_gateway/src/stream_context.rs | 890 +-------------------
e2e_tests/api_model_server.rest | 82 +-
13 files changed, 240 insertions(+), 1222 deletions(-)
delete mode 100644 crates/prompt_gateway/src/embeddings.rs
diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml
index d5c7b95e..b4d4b999 100644
--- a/arch/envoy.template.yaml
+++ b/arch/envoy.template.yaml
@@ -211,8 +211,7 @@ static_resources:
domains:
- "*"
routes:
-
- {% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination"] %}
+ {% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination", "model_server"] %}
- match:
prefix: "/"
headers:
@@ -449,7 +448,7 @@ static_resources:
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
sni: api.mistral.ai
- {% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination"] %}
+ {% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination", "model_server"] %}
- name: {{ internal_clustrer }}
connect_timeout: 5s
type: STRICT_DNS
diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py
index c32dae16..2f28f537 100644
--- a/arch/tools/cli/config_generator.py
+++ b/arch/tools/cli/config_generator.py
@@ -90,6 +90,7 @@ def validate_and_render_schema():
rendered = template.render(data)
print(ENVOY_CONFIG_FILE_RENDERED)
+ print(rendered)
with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
file.write(rendered)
diff --git a/crates/common/src/api/open_ai.rs b/crates/common/src/api/open_ai.rs
index 20b550ae..e96906fa 100644
--- a/crates/common/src/api/open_ai.rs
+++ b/crates/common/src/api/open_ai.rs
@@ -21,7 +21,7 @@ pub struct ChatCompletionsRequest {
pub metadata: Option>,
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ToolType {
#[serde(rename = "function")]
Function,
@@ -165,8 +165,8 @@ pub struct Message {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
- pub finish_reason: String,
- pub index: usize,
+ pub finish_reason: Option,
+ pub index: Option,
pub message: Message,
}
@@ -217,8 +217,8 @@ impl ChatCompletionsResponse {
tool_calls: None,
tool_call_id: None,
},
- index: 0,
- finish_reason: "done".to_string(),
+ index: Some(0),
+ finish_reason: Some("done".to_string()),
}],
usage: None,
model: ARCH_FC_MODEL_NAME.to_string(),
diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs
index 0d9bea80..9a56e863 100644
--- a/crates/common/src/configuration.rs
+++ b/crates/common/src/configuration.rs
@@ -2,6 +2,10 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Display;
+use crate::api::open_ai::{
+ ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
+};
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
@@ -231,11 +235,46 @@ pub struct PromptTarget {
pub auto_llm_dispatch_on_response: Option,
}
+// Convert a PromptTarget into a ChatCompletionTool of type `Function`; a parameter with no declared type defaults to "str".
+impl Into for &PromptTarget {
+ fn into(self) -> ChatCompletionTool {
+ let properties: HashMap = match self.parameters {
+ Some(ref entities) => {
+ let mut properties: HashMap = HashMap::new();
+ for entity in entities.iter() {
+ let param = FunctionParameter {
+ parameter_type: ParameterType::from(
+ entity.parameter_type.clone().unwrap_or("str".to_string()),
+ ),
+ description: entity.description.clone(),
+ required: entity.required,
+ enum_values: entity.enum_values.clone(),
+ default: entity.default.clone(),
+ };
+ properties.insert(entity.name.clone(), param);
+ }
+ properties
+ }
+ None => HashMap::new(),
+ };
+
+ ChatCompletionTool {
+ tool_type: crate::api::open_ai::ToolType::Function,
+ function: FunctionDefinition {
+ name: self.name.clone(),
+ description: self.description.clone(),
+ parameters: FunctionParameters { properties },
+ },
+ }
+ }
+}
+
#[cfg(test)]
mod test {
+ use pretty_assertions::assert_eq;
use std::fs;
- use crate::configuration::GuardType;
+ use crate::{api::open_ai::ToolType, configuration::GuardType};
#[test]
fn test_deserialize_configuration() {
@@ -307,4 +346,76 @@ mod test {
let mode = config.mode.as_ref().unwrap_or(&super::GatewayMode::Prompt);
assert_eq!(*mode, super::GatewayMode::Prompt);
}
+
+ #[test]
+ fn test_tool_conversion() {
+ let ref_config = fs::read_to_string(
+ "../../docs/source/resources/includes/arch_config_full_reference.yaml",
+ )
+ .expect("reference config file not found");
+ let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
+ let prompt_targets = &config.prompt_targets;
+ let prompt_target = prompt_targets
+ .as_ref()
+ .unwrap()
+ .iter()
+ .find(|p| p.name == "reboot_network_device")
+ .unwrap();
+ let chat_completion_tool: super::ChatCompletionTool = prompt_target.into();
+ assert_eq!(chat_completion_tool.tool_type, ToolType::Function);
+ assert_eq!(chat_completion_tool.function.name, "reboot_network_device");
+ assert_eq!(
+ chat_completion_tool.function.description,
+ "Reboot a specific network device"
+ );
+ assert_eq!(chat_completion_tool.function.parameters.properties.len(), 2);
+ assert_eq!(
+ chat_completion_tool
+ .function
+ .parameters
+ .properties
+ .contains_key("device_id"),
+ true
+ );
+ assert_eq!(
+ chat_completion_tool
+ .function
+ .parameters
+ .properties
+ .get("device_id")
+ .unwrap()
+ .parameter_type,
+ crate::api::open_ai::ParameterType::String
+ );
+ assert_eq!(
+ chat_completion_tool
+ .function
+ .parameters
+ .properties
+ .get("device_id")
+ .unwrap()
+ .description,
+ "Identifier of the network device to reboot.".to_string()
+ );
+ assert_eq!(
+ chat_completion_tool
+ .function
+ .parameters
+ .properties
+ .get("device_id")
+ .unwrap()
+ .required,
+ Some(true)
+ );
+ assert_eq!(
+ chat_completion_tool
+ .function
+ .parameters
+ .properties
+ .get("confirmation")
+ .unwrap()
+ .parameter_type,
+ crate::api::open_ai::ParameterType::Bool
+ );
+ }
}
diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs
index cd5238a3..a7c881c6 100644
--- a/crates/common/src/lib.rs
+++ b/crates/common/src/lib.rs
@@ -5,10 +5,10 @@ pub mod embeddings;
pub mod errors;
pub mod http;
pub mod llm_providers;
+pub mod path;
pub mod pii;
pub mod ratelimit;
pub mod routing;
pub mod stats;
pub mod tokenizer;
pub mod tracing;
-pub mod path;
diff --git a/crates/common/src/path.rs b/crates/common/src/path.rs
index 2b289c9d..17dfe2ce 100644
--- a/crates/common/src/path.rs
+++ b/crates/common/src/path.rs
@@ -1,6 +1,9 @@
use std::collections::HashMap;
-pub fn replace_params_in_path(path: &str, params: &HashMap) -> Result {
+pub fn replace_params_in_path(
+ path: &str,
+ params: &HashMap,
+) -> Result {
let mut result = String::new();
let mut in_param = false;
let mut current_param = String::new();
diff --git a/crates/prompt_gateway/src/context.rs b/crates/prompt_gateway/src/context.rs
index 567f41eb..4dba8588 100644
--- a/crates/prompt_gateway/src/context.rs
+++ b/crates/prompt_gateway/src/context.rs
@@ -19,70 +19,11 @@ impl Context for StreamContext {
.expect("invalid token_id");
self.metrics.active_http_calls.increment(-1);
- /*
- state transition
-
- graph LR
-
- on_http_request_body --> prompt received
- prompt received --> get embeddings & arch guard
- arch guard --> get embeddings
- get embeddings --> zeroshot intent
-
- ┌──────────────────────┐ ┌─────────────────┐ ┌────────────────┐ ┌─────────────────┐
- │ │ │ │ │ │ │ │
- │ on_http_request_body ├──►│ prompt received ├──►│ get embeddings ├──►│ zeroshot intent │
- │ │ │ │ │ │ │ │
- └──────────────────────┘ └────────┬────────┘ └────────────────┘ └─────────────────┘
- │ ▲
- │ │
- │ │
- │ ┌────────┴───────┐
- │ │ │
- └───────────►│ arch guard │
- │ │
- └────────────────┘
-
-
- continue from zeroshot intent
-
- graph LR
-
- zeroshot intent --> arch_fc
- zeroshot intent --> default prompt target
- arch_fc --> developer api call & hallucination check
- hallucination check --> parameter gathering & developer api call
- developer api call --> resume request to llm
-
-
- ┌─────────────────┐ ┌───────────────────────┐ ┌─────────────────────┐ ┌───────────────────────┐
- │ │ │ │ │ │ │ │
- │ zeroshot intent ├──►│ arch_fc ├──►│ developer api call ├──►│ resume request to llm │
- │ │ │ │ │ │ │ │
- └────────┬────────┘ └───────────┬───────────┘ └─────────────────────┘ └───────────────────────┘
- │ │ ▲
- │ └─────────────┐ │
- │ │ │
- │ ┌───────────────────────┐ │ ┌──────────┴──────────┐ ┌───────────────────────┐
- │ │ │ │ │ │ │ │
- └───────────►│ default prompt target │ └▲│ hallucination check ├──►│ parameter gathering │
- │ │ │ │ │ │
- └───────────────────────┘ └─────────────────────┘ └───────────────────────┘
-
-
- using https://mermaid-ascii.art/
- */
-
if let Some(body) = self.get_http_call_response_body(0, body_size) {
#[cfg_attr(any(), rustfmt::skip)]
match callout_context.response_handler_type {
- ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
- ResponseHandlerType::Embeddings => self.embeddings_handler(body, callout_context),
- ResponseHandlerType::ZeroShotIntent => self.zero_shot_intent_detection_resp_handler(body, callout_context),
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
- ResponseHandlerType::Hallucination => self.hallucination_classification_resp_handler(body, callout_context),
ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
- ResponseHandlerType::DefaultTarget =>self.default_target_handler(body, callout_context),
}
} else {
self.send_server_error(
diff --git a/crates/prompt_gateway/src/embeddings.rs b/crates/prompt_gateway/src/embeddings.rs
deleted file mode 100644
index f2883682..00000000
--- a/crates/prompt_gateway/src/embeddings.rs
+++ /dev/null
@@ -1,5 +0,0 @@
-#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
-pub enum EmbeddingType {
- Name,
- Description,
-}
diff --git a/crates/prompt_gateway/src/filter_context.rs b/crates/prompt_gateway/src/filter_context.rs
index 449be126..9780ed7d 100644
--- a/crates/prompt_gateway/src/filter_context.rs
+++ b/crates/prompt_gateway/src/filter_context.rs
@@ -1,35 +1,17 @@
-use crate::embeddings::EmbeddingType;
use crate::metrics::Metrics;
use crate::stream_context::StreamContext;
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget, Tracing};
-use common::consts::ARCH_UPSTREAM_HOST_HEADER;
-use common::consts::DEFAULT_EMBEDDING_MODEL;
-use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST};
-use common::embeddings::{
- CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
-};
-use common::http::CallArgs;
use common::http::Client;
use common::stats::Gauge;
-use common::stats::IncrementingMetric;
-use http::StatusCode;
-use log::{debug, info, trace, warn};
+use log::debug;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::cell::RefCell;
-use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::rc::Rc;
-use std::time::Duration;
-
-pub type EmbeddingTypeMap = HashMap>;
-pub type EmbeddingsStore = HashMap;
#[derive(Debug)]
-pub struct FilterCallContext {
- pub prompt_target_name: String,
- pub embedding_type: EmbeddingType,
-}
+pub struct FilterCallContext {}
#[derive(Debug)]
pub struct FilterContext {
@@ -40,9 +22,6 @@ pub struct FilterContext {
system_prompt: Rc