Don't compute embeddings for names, plus other fixes; see description (#126)

* serialize tools - 2

* fix int tests

* fix int test

* fix unit tests
This commit is contained in:
Adil Hafeez 2024-10-05 19:25:16 -07:00 committed by GitHub
parent 0e5ea3d6db
commit 2a747df7c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 125 additions and 86 deletions

View file

@ -69,14 +69,9 @@ static_resources:
clusters:
- name: openai
connect_timeout: 5s
dns_lookup_family: V4_ONLY
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
typed_extension_protocol_options:
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
explicit_http_config:
http2_protocol_options: {}
load_assignment:
cluster_name: openai
endpoints:
@ -98,14 +93,9 @@ static_resources:
tls_maximum_protocol_version: TLSv1_3
- name: mistral
connect_timeout: 5s
dns_lookup_family: V4_ONLY
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
typed_extension_protocol_options:
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
explicit_http_config:
http2_protocol_options: {}
load_assignment:
cluster_name: mistral
endpoints:
@ -124,6 +114,7 @@ static_resources:
- name: model_server
connect_timeout: 5s
type: STRICT_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: model_server
@ -138,6 +129,7 @@ static_resources:
- name: mistral_7b_instruct
connect_timeout: 5s
type: STRICT_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: mistral_7b_instruct
@ -152,6 +144,7 @@ static_resources:
- name: arch_fc
connect_timeout: 5s
type: STRICT_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: arch_fc

View file

@ -12,3 +12,4 @@ pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
// pub const ARCH_STATE_HEADER: &str = "x-arch-state";

View file

@ -72,11 +72,6 @@ impl FilterContext {
fn process_prompt_targets(&self) {
for values in self.prompt_targets.iter() {
let prompt_target = values.1;
self.schedule_embeddings_call(
&prompt_target.name,
&prompt_target.name,
EmbeddingType::Name,
);
self.schedule_embeddings_call(
&prompt_target.name,
&prompt_target.description,

View file

@ -65,7 +65,7 @@ pub trait Client: Context {
}
Err(status) => Err(ClientError::DispatchError {
upstream_name: String::from(call_args.upstream),
internal_status: status.clone(),
internal_status: status,
}),
}
}

View file

@ -469,6 +469,7 @@ impl StreamContext {
tools: Some(chat_completion_tools),
stream: false,
stream_options: None,
metadata: None,
};
let msg_body = match serde_json::to_string(&chat_completions) {
@ -686,6 +687,7 @@ impl StreamContext {
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let json_string = match serde_json::to_string(&chat_completions_request) {
@ -875,6 +877,7 @@ impl StreamContext {
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
debug!("sending response back to default llm: {}", json_resp);

View file

@ -254,7 +254,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(&config))
.returning(Some(config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
@ -276,22 +276,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
)
.returning(Some(101))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("model_server"),
Some(vec![
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(102))
.expect_metric_increment("active_http_calls", 1)
.expect_set_tick_period_millis(Some(0))
.execute_and_expect(ReturnType::None)
.unwrap();
@ -335,31 +319,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_http_call_response(
filter_context,
102,
0,
embedding_response_str.len() as i32,
0,
)
.expect_log(
Some(LogLevel::Debug),
Some(
format!(
"filter_context: on_http_call_response called with token_id: {:?}",
102
)
.as_str(),
),
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embedding_response_str))
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
filter_context
}
@ -599,6 +558,7 @@ fn request_ratelimited() {
},
}],
model: String::from("test"),
metadata: None,
};
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
@ -712,6 +672,7 @@ fn request_not_ratelimited() {
},
}],
model: String::from("test"),
metadata: None,
};
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();