diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index 6daa29e8..0fe980dd 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -99,6 +99,8 @@ properties: type: string in_path: type: boolean + format: + type: string additionalProperties: false required: - name diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml index b4d4b999..96c3a955 100644 --- a/arch/envoy.template.yaml +++ b/arch/envoy.template.yaml @@ -211,7 +211,7 @@ static_resources: domains: - "*" routes: - {% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination", "model_server"] %} + {% for internal_clustrer in ["arch_fc", "model_server"] %} - match: prefix: "/" headers: @@ -448,7 +448,7 @@ static_resources: typed_config: "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext sni: api.mistral.ai - {% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination", "model_server"] %} + {% for internal_clustrer in ["arch_fc", "model_server"] %} - name: {{ internal_clustrer }} connect_timeout: 5s type: STRICT_DNS diff --git a/crates/common/src/api/open_ai.rs b/crates/common/src/api/open_ai.rs index b72185e0..7b42b139 100644 --- a/crates/common/src/api/open_ai.rs +++ b/crates/common/src/api/open_ai.rs @@ -80,6 +80,8 @@ pub struct FunctionParameter { pub enum_values: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub default: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub format: Option, } impl Serialize for FunctionParameter { @@ -96,6 +98,9 @@ impl Serialize for FunctionParameter { if let Some(default) = &self.default { map.serialize_entry("default", default)?; } + if let Some(format) = &self.format { + map.serialize_entry("format", format)?; + } map.end() } } diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 91982846..e83c1117 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -196,6 +196,7 @@ pub struct Parameter { pub enum_values: Option>, pub default: Option, pub in_path: Option, + pub format: Option, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)] @@ -250,6 +251,7 @@ impl From<&PromptTarget> for ChatCompletionTool { required: entity.required, enum_values: entity.enum_values.clone(), default: entity.default.clone(), + format: entity.format.clone(), }; properties.insert(entity.name.clone(), param); } diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index 87b661ca..561dbae3 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -1,6 +1,3 @@ -pub const DEFAULT_INTENT_MODEL: &str = "katanemo/bart-large-mnli"; -pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8; -pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25; pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector"; pub const SYSTEM_ROLE: &str = "system"; pub const USER_ROLE: &str = "user"; @@ -8,11 +5,6 @@ pub const TOOL_ROLE: &str = "tool"; pub const ASSISTANT_ROLE: &str = "assistant"; pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes pub const MODEL_SERVER_NAME: &str = "model_server"; -pub const ZEROSHOT_INTERNAL_HOST: &str = "zeroshot"; -pub const ARCH_FC_INTERNAL_HOST: &str = "arch_fc"; -pub const HALLUCINATION_INTERNAL_HOST: &str = "hallucination"; -pub const EMBEDDINGS_INTERNAL_HOST: &str = "embeddings"; -pub const GUARD_INTERNAL_HOST: &str = "guard"; pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider"; pub const MESSAGES_KEY: &str = "messages"; pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint"; @@ -24,7 +16,6 @@ pub const REQUEST_ID_HEADER: &str = "x-request-id"; pub const TRACE_PARENT_HEADER: &str = "traceparent"; pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal"; pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream"; -pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener"; pub const ARCH_MODEL_PREFIX: &str = "Arch"; pub const HALLUCINATION_TEMPLATE: &str = "It seems I'm missing some information. Could you provide the following details "; diff --git a/crates/common/src/embeddings/create_embedding_request.rs b/crates/common/src/embeddings/create_embedding_request.rs deleted file mode 100644 index 21e52f8a..00000000 --- a/crates/common/src/embeddings/create_embedding_request.rs +++ /dev/null @@ -1,59 +0,0 @@ -/* - * OMF Embeddings - * - * No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) - * - * The version of the OpenAPI document: 1.0.0 - * - * Generated by: https://openapi-generator.tech - */ - -use crate::embeddings; -use serde::{Deserialize, Serialize}; - -#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] -pub struct CreateEmbeddingRequest { - #[serde(rename = "input")] - pub input: Box, - /// ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them. - #[serde(rename = "model")] - pub model: String, - /// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/). - #[serde(rename = "encoding_format", skip_serializing_if = "Option::is_none")] - pub encoding_format: Option, - /// The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models. - #[serde(rename = "dimensions", skip_serializing_if = "Option::is_none")] - pub dimensions: Option, - /// A unique identifier representing your end-user, which can help to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids). - #[serde(rename = "user", skip_serializing_if = "Option::is_none")] - pub user: Option, -} - -impl CreateEmbeddingRequest { - pub fn new( - input: embeddings::CreateEmbeddingRequestInput, - model: String, - ) -> CreateEmbeddingRequest { - CreateEmbeddingRequest { - input: Box::new(input), - model, - encoding_format: None, - dimensions: None, - user: None, - } - } -} -/// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/). -#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] -pub enum EncodingFormat { - #[serde(rename = "float")] - Float, - #[serde(rename = "base64")] - Base64, -} - -impl Default for EncodingFormat { - fn default() -> EncodingFormat { - Self::Float - } -} diff --git a/crates/common/src/embeddings/create_embedding_request_input.rs b/crates/common/src/embeddings/create_embedding_request_input.rs deleted file mode 100644 index 83195ced..00000000 --- a/crates/common/src/embeddings/create_embedding_request_input.rs +++ /dev/null @@ -1,28 +0,0 @@ -/* - * OMF Embeddings - * - * No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) - * - * The version of the OpenAPI document: 1.0.0 - * - * Generated by: https://openapi-generator.tech - */ - -use serde::{Deserialize, Serialize}; - -/// CreateEmbeddingRequestInput : Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. for counting tokens. -/// Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. for counting tokens. -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -#[serde(untagged)] -pub enum CreateEmbeddingRequestInput { - /// The string that will be turned into an embedding. - String(String), - /// The array of integers that will be turned into an embedding. - Array(Vec), -} - -impl Default for CreateEmbeddingRequestInput { - fn default() -> Self { - Self::String(Default::default()) - } -} diff --git a/crates/common/src/embeddings/create_embedding_response.rs b/crates/common/src/embeddings/create_embedding_response.rs deleted file mode 100644 index 278929e0..00000000 --- a/crates/common/src/embeddings/create_embedding_response.rs +++ /dev/null @@ -1,55 +0,0 @@ -/* - * OMF Embeddings - * - * No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) - * - * The version of the OpenAPI document: 1.0.0 - * - * Generated by: https://openapi-generator.tech - */ - -use crate::embeddings; -use serde::{Deserialize, Serialize}; - -#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] -pub struct CreateEmbeddingResponse { - /// The list of embeddings generated by the model. - #[serde(rename = "data")] - pub data: Vec, - /// The name of the model used to generate the embedding. - #[serde(rename = "model")] - pub model: String, - /// The object type, which is always \"list\". - #[serde(rename = "object")] - pub object: Object, - #[serde(rename = "usage")] - pub usage: Box, -} - -impl CreateEmbeddingResponse { - pub fn new( - data: Vec, - model: String, - object: Object, - usage: embeddings::CreateEmbeddingResponseUsage, - ) -> CreateEmbeddingResponse { - CreateEmbeddingResponse { - data, - model, - object, - usage: Box::new(usage), - } - } -} -/// The object type, which is always \"list\". -#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] -pub enum Object { - #[serde(rename = "list")] - List, -} - -impl Default for Object { - fn default() -> Object { - Self::List - } -} diff --git a/crates/common/src/embeddings/create_embedding_response_usage.rs b/crates/common/src/embeddings/create_embedding_response_usage.rs deleted file mode 100644 index 2a4730ca..00000000 --- a/crates/common/src/embeddings/create_embedding_response_usage.rs +++ /dev/null @@ -1,32 +0,0 @@ -/* - * OMF Embeddings - * - * No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) - * - * The version of the OpenAPI document: 1.0.0 - * - * Generated by: https://openapi-generator.tech - */ - -use serde::{Deserialize, Serialize}; - -/// CreateEmbeddingResponseUsage : The usage information for the request. -#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] -pub struct CreateEmbeddingResponseUsage { - /// The number of tokens used by the prompt. - #[serde(rename = "prompt_tokens")] - pub prompt_tokens: i32, - /// The total number of tokens used by the request. - #[serde(rename = "total_tokens")] - pub total_tokens: i32, -} - -impl CreateEmbeddingResponseUsage { - /// The usage information for the request. - pub fn new(prompt_tokens: i32, total_tokens: i32) -> CreateEmbeddingResponseUsage { - CreateEmbeddingResponseUsage { - prompt_tokens, - total_tokens, - } - } -} diff --git a/crates/common/src/embeddings/embedding.rs b/crates/common/src/embeddings/embedding.rs deleted file mode 100644 index e36db376..00000000 --- a/crates/common/src/embeddings/embedding.rs +++ /dev/null @@ -1,48 +0,0 @@ -/* - * OMF Embeddings - * - * No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) - * - * The version of the OpenAPI document: 1.0.0 - * - * Generated by: https://openapi-generator.tech - */ - -use serde::{Deserialize, Serialize}; - -/// Embedding : Represents an embedding vector returned by embedding endpoint. -#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] -pub struct Embedding { - /// The index of the embedding in the list of embeddings. - #[serde(rename = "index")] - pub index: i32, - /// The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings). - #[serde(rename = "embedding")] - pub embedding: Vec, - /// The object type, which is always \"embedding\" - #[serde(rename = "object")] - pub object: Object, -} - -impl Embedding { - /// Represents an embedding vector returned by embedding endpoint. - pub fn new(index: i32, embedding: Vec, object: Object) -> Embedding { - Embedding { - index, - embedding, - object, - } - } -} -/// The object type, which is always \"embedding\" -#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] -pub enum Object { - #[serde(rename = "embedding")] - Embedding, -} - -impl Default for Object { - fn default() -> Object { - Self::Embedding - } -} diff --git a/crates/common/src/embeddings/mod.rs b/crates/common/src/embeddings/mod.rs deleted file mode 100644 index d7ef176b..00000000 --- a/crates/common/src/embeddings/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -pub mod create_embedding_request; -pub use self::create_embedding_request::CreateEmbeddingRequest; -pub mod create_embedding_request_input; -pub use self::create_embedding_request_input::CreateEmbeddingRequestInput; -pub mod create_embedding_response; -pub use self::create_embedding_response::CreateEmbeddingResponse; -pub mod create_embedding_response_usage; -pub use self::create_embedding_response_usage::CreateEmbeddingResponseUsage; -pub mod embedding; -pub use self::embedding::Embedding; diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index a7c881c6..32549893 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -1,7 +1,6 @@ pub mod api; pub mod configuration; pub mod consts; -pub mod embeddings; pub mod errors; pub mod http; pub mod llm_providers; diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index ac6009f8..0c3ccf60 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -1,14 +1,7 @@ -use common::api::hallucination::HallucinationClassificationResponse; use common::api::open_ai::{ ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage, }; -use common::api::prompt_guard::PromptGuardResponse; -use common::api::zero_shot::ZeroShotClassificationResponse; use common::configuration::Configuration; -use common::embeddings::{ - create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, - Embedding, -}; use http::StatusCode; use proxy_wasm_test_framework::tester::{self, Tester}; use proxy_wasm_test_framework::types::{ @@ -83,13 +76,11 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_http_call( Some("arch_internal"), Some(vec![ - ("x-arch-upstream", "guard"), + ("x-arch-upstream", "model_server"), (":method", "POST"), - (":path", "/guard"), - (":authority", "guard"), + (":path", "/function_calling"), ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), + (":authority", "model_server"), ]), None, None, @@ -97,139 +88,11 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { ) .returning(Some(1)) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None) .expect_metric_increment("active_http_calls", 1) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); - - let prompt_guard_response = PromptGuardResponse { - toxic_prob: None, - toxic_verdict: None, - jailbreak_prob: None, - jailbreak_verdict: None, - }; - let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap(); - module - .call_proxy_on_http_call_response( - http_context, - 1, - 0, - prompt_guard_response_buffer.len() as i32, - 0, - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&prompt_guard_response_buffer)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "embeddings"), - (":method", "POST"), - (":path", "/embeddings"), - (":authority", "embeddings"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(2)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let embedding_response = CreateEmbeddingResponse { - data: vec![Embedding { - index: 0, - embedding: vec![], - object: embedding::Object::default(), - }], - model: String::from("test"), - object: create_embedding_response::Object::default(), - usage: Box::new(CreateEmbeddingResponseUsage::new(0, 0)), - }; - let embeddings_response_buffer = serde_json::to_string(&embedding_response).unwrap(); - module - .call_proxy_on_http_call_response( - http_context, - 2, - 0, - embeddings_response_buffer.len() as i32, - 0, - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&embeddings_response_buffer)) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Warn), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "zeroshot"), - (":method", "POST"), - (":path", "/zeroshot"), - (":authority", "zeroshot"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(3)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let zero_shot_response = ZeroShotClassificationResponse { - predicted_class: "weather_forecast".to_string(), - predicted_class_score: 0.1, - scores: HashMap::new(), - model: "test-model".to_string(), - }; - let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap(); - module - .call_proxy_on_http_call_response( - http_context, - 3, - 0, - zeroshot_intent_detection_buffer.len() as i32, - 0, - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&zeroshot_intent_detection_buffer)) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - (":method", "POST"), - ("x-arch-upstream", "arch_fc"), - (":path", "/v1/chat/completions"), - (":authority", "arch_fc"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "120000"), - ]), - None, - None, - None, - ) - .returning(Some(4)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); } fn setup_filter(module: &mut Tester, config: &str) -> i32 { @@ -248,69 +111,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { .execute_and_expect(ReturnType::Bool(true)) .unwrap(); - module - .call_proxy_on_tick(filter_context) - .expect_log(Some(LogLevel::Info), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "embeddings"), - (":method", "POST"), - (":path", "/embeddings"), - (":authority", "embeddings"), - ("content-type", "application/json"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(101)) - .expect_metric_increment("active_http_calls", 1) - .expect_set_tick_period_millis(Some(5000)) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let embedding_response = CreateEmbeddingResponse { - data: vec![Embedding { - embedding: vec![], - index: 0, - object: embedding::Object::default(), - }], - model: String::from("test"), - object: create_embedding_response::Object::default(), - usage: Box::new(CreateEmbeddingResponseUsage { - prompt_tokens: 0, - total_tokens: 0, - }), - }; - let embedding_response_str = serde_json::to_string(&embedding_response).unwrap(); - module - .call_proxy_on_http_call_response( - filter_context, - 101, - 0, - embedding_response_str.len() as i32, - 0, - ) - .expect_log( - Some(LogLevel::Trace), - Some( - format!( - "filter_context: on_http_call_response called with token_id: {:?}", - 101 - ) - .as_str(), - ), - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&embedding_response_str)) - .expect_log(Some(LogLevel::Debug), None) - .execute_and_expect(ReturnType::None) - .unwrap(); - filter_context } @@ -435,6 +235,7 @@ fn prompt_gateway_successful_request_to_open_ai_chat_completions() { .returning(Some(chat_completions_request_body)) .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None) .expect_http_call(Some("arch_internal"), None, None, None, None) .returning(Some(4)) @@ -538,8 +339,8 @@ fn prompt_gateway_request_to_llm_gateway() { completion_tokens: 0, }), choices: vec![Choice { - finish_reason: "test".to_string(), - index: 0, + finish_reason: Some("test".to_string()), + index: Some(0), message: Message { role: "system".to_string(), content: None, @@ -564,55 +365,12 @@ fn prompt_gateway_request_to_llm_gateway() { let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap(); module - .call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0) + .call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0) .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&arch_fc_resp_str)) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "hallucination"), - (":method", "POST"), - (":path", "/hallucination"), - (":authority", "hallucination"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(5)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - // hallucination should return that parameters were not halliucinated - // prompt: str - // parameters: dict - // model: str - - let hallucatination_body = HallucinationClassificationResponse { - params_scores: HashMap::from([("city".to_string(), 0.99)]), - model: "nli-model".to_string(), - }; - - let body_text = serde_json::to_string(&hallucatination_body).unwrap(); - - module - .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&body_text)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), @@ -628,14 +386,14 @@ fn prompt_gateway_request_to_llm_gateway() { None, None, ) - .returning(Some(6)) + .returning(Some(2)) .expect_metric_increment("active_http_calls", 1) .execute_and_expect(ReturnType::None) .unwrap(); let body_text = String::from("test body"); module - .call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0) + .call_proxy_on_http_call_response(http_context, 2, 0, body_text.len() as i32, 0) .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&body_text)) @@ -652,8 +410,8 @@ fn prompt_gateway_request_to_llm_gateway() { completion_tokens: 0, }), choices: vec![Choice { - finish_reason: "test".to_string(), - index: 0, + finish_reason: Some("test".to_string()), + index: Some(0), message: Message { role: "assistant".to_string(), content: Some("hello from fake llm gateway".to_string()), diff --git a/demos/weather_forecast/arch_config.yaml b/demos/weather_forecast/arch_config.yaml index 935be68d..4238352b 100644 --- a/demos/weather_forecast/arch_config.yaml +++ b/demos/weather_forecast/arch_config.yaml @@ -49,6 +49,7 @@ prompt_targets: description: The location to get the weather for required: true type: string + format: city, state - name: days description: the number of days for the request required: true diff --git a/e2e_tests/api_model_server.rest b/e2e_tests/api_model_server.rest index 37f31c4f..d01a9a8c 100644 --- a/e2e_tests/api_model_server.rest +++ b/e2e_tests/api_model_server.rest @@ -28,12 +28,6 @@ Content-Type: application/json "description": "The location to get the weather for", "format": "City, State" }, - "unit": { - "type": "str", - "description": "The unit to return the weather in.", - "enum": ["celsius", "fahrenheit"], - "default": "celsius" - }, "days": { "type": "str", "description": "the number of days for the request." @@ -236,7 +230,6 @@ Content-Type: application/json } - ### archgw to model_server 2 POST {{model_server_endpoint}}/function_calling HTTP/1.1 Content-Type: application/json @@ -292,3 +285,66 @@ Content-Type: application/json ], "stream": false } + + +### archgw to model_server 3 +POST {{model_server_endpoint}}/function_calling HTTP/1.1 +Content-Type: application/json + +{ + "model": "--", + "messages": [ + { + "role": "user", + "content": "how is the weather in seattle" + }, + { + "role": "assistant", + "content": "Of course, I can help with that. Could you please specify the days you want the weather forecast for?", + "model": "Arch-Function" + }, + { + "role": "user", + "content": "for 2 days please" + } + ], + "tools": [ + { + "id": "weather-112", + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get current weather at a location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "str", + "description": "The location to get the weather for", + "format": "City, State" + }, + "days": { + "type": "str", + "description": "the number of days for the request" + } + }, + "required": [ + "days", "location" + ] + } + } + }, + { + "type": "function", + "function": { + "name": "default_target", + "description": "This is the default target for all unmatched prompts.", + "parameters": { + "type": "object", + "properties": {} + } + } + } + ], + "stream": true +} diff --git a/e2e_tests/api_prompt_gateway.rest b/e2e_tests/api_prompt_gateway.rest index c4ef844d..537b06ac 100644 --- a/e2e_tests/api_prompt_gateway.rest +++ b/e2e_tests/api_prompt_gateway.rest @@ -73,15 +73,6 @@ Content-Type: application/json { "role": "user", "content": "for next 10 days" - }, - { - "role": "assistant", - "content": "Could you tell me what units you want the weather in? (For example: Celsius or Fahrenheit)", - "model": "Arch-Function-1.5b" - }, - { - "role": "user", - "content": "Fahrenheit" } ] } diff --git a/e2e_tests/test_prompt_gateway.py b/e2e_tests/test_prompt_gateway.py index 01b2b80c..7fa4d8b6 100644 --- a/e2e_tests/test_prompt_gateway.py +++ b/e2e_tests/test_prompt_gateway.py @@ -15,7 +15,7 @@ from common import ( def test_prompt_gateway(stream): expected_tool_call = { "name": "get_current_weather", - "arguments": {"location": "seattle", "days": "10"}, + "arguments": {"location": "seattle, wa", "days": "10"}, } body = { @@ -169,7 +169,7 @@ def test_prompt_gateway_param_gathering(stream): def test_prompt_gateway_param_tool_call(stream): expected_tool_call = { "name": "get_current_weather", - "arguments": {"location": "seattle", "days": "2"}, + "arguments": {"location": "seattle, wa", "days": "2"}, } body = { @@ -181,11 +181,11 @@ def test_prompt_gateway_param_tool_call(stream): { "role": "assistant", "content": "Of course, I can help with that. Could you please specify the days you want the weather forecast for?", - "model": "Arch-Function-1.5B", + "model": "Arch-Function", }, { "role": "user", - "content": "2 days", + "content": "for 2 days please", }, ], "stream": stream, diff --git a/model_server/pyproject.toml b/model_server/pyproject.toml index 9fa447f0..23f8db6a 100644 --- a/model_server/pyproject.toml +++ b/model_server/pyproject.toml @@ -42,3 +42,10 @@ archgw_modelserver = "src.cli:run_server" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +python_files = ["test*.py"] +addopts = ["-v", "-s"] +retries = 2 +retry_delay = 0.5 +cumulative_timing = false