mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
Hallucination integration with rust (#122)
This commit is contained in:
parent
43dc2a0a73
commit
b1fa127704
9 changed files with 278 additions and 56 deletions
|
|
@ -5,28 +5,9 @@ services:
|
|||
- "10000:10000"
|
||||
- "19901:9901"
|
||||
volumes:
|
||||
- ${ARCH_CONFIG_FILE:-./demos/function_calling/arch_config.yaml}:/config/arch_config.yaml
|
||||
- ${ARCH_CONFIG_FILE:-../demos/function_calling/arch_config.yaml}:/config/arch_config.yaml
|
||||
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
|
||||
- ./envoy.template.dev.yaml:/config/envoy.template.yaml
|
||||
- ./target/wasm32-wasi/release/intelligent_prompt_gateway.wasm:/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm
|
||||
depends_on:
|
||||
model_server:
|
||||
condition: service_healthy
|
||||
env_file:
|
||||
- stage.env
|
||||
|
||||
model_server:
|
||||
image: model_server:latest
|
||||
ports:
|
||||
- "18081:80"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl" ,"http://localhost/healthz"]
|
||||
interval: 5s
|
||||
retries: 20
|
||||
volumes:
|
||||
- ~/.cache/huggingface:/root/.cache/huggingface
|
||||
environment:
|
||||
- OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal}
|
||||
- OLLAMA_MODEL=Arch-Function-Calling-3B-Q4_K_M
|
||||
- MODE=${MODE:-cloud}
|
||||
- FC_URL=${FC_URL:-https://arch-fc-free-trial-4mzywewe.uc.gateway.dev/v1}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ static_resources:
|
|||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
|
||||
stat_prefix: arch_ingress_http
|
||||
codec_type: HTTP1
|
||||
codec_type: AUTO
|
||||
scheme_header_transformation:
|
||||
scheme_to_overwrite: https
|
||||
access_log:
|
||||
|
|
@ -72,11 +72,6 @@ static_resources:
|
|||
type: LOGICAL_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
# typed_extension_protocol_options:
|
||||
# envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
|
||||
# "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
|
||||
# explicit_http_config:
|
||||
# http2_protocol_options: {}
|
||||
load_assignment:
|
||||
cluster_name: openai
|
||||
endpoints:
|
||||
|
|
@ -129,7 +124,7 @@ static_resources:
|
|||
address:
|
||||
socket_address:
|
||||
address: host.docker.internal
|
||||
port_value: 8000
|
||||
port_value: 51000
|
||||
hostname: "model_server"
|
||||
- name: mistral_7b_instruct
|
||||
connect_timeout: 5s
|
||||
|
|
@ -159,7 +154,7 @@ static_resources:
|
|||
address:
|
||||
socket_address:
|
||||
address: host.docker.internal
|
||||
port_value: 8000
|
||||
port_value: 51000
|
||||
hostname: "arch_fc"
|
||||
{% for _, cluster in arch_clusters.items() %}
|
||||
- name: {{ cluster.name }}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
|
||||
pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
|
||||
pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.1;
|
||||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
|
||||
pub const SYSTEM_ROLE: &str = "system";
|
||||
pub const USER_ROLE: &str = "user";
|
||||
|
|
@ -13,3 +14,4 @@ pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
|
|||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
|
||||
pub const ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B";
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
use crate::consts::{
|
||||
ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
|
||||
ARCH_STATE_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL,
|
||||
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
||||
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER,
|
||||
ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH,
|
||||
DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
||||
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
|
||||
|
|
@ -17,12 +18,13 @@ use proxy_wasm::traits::*;
|
|||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::open_ai::{
|
||||
ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
|
||||
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
|
||||
ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType,
|
||||
ChatCompletionsResponse, Choice, FunctionDefinition, FunctionParameter, FunctionParameters,
|
||||
Message, ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType,
|
||||
};
|
||||
use public_types::common_types::{
|
||||
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
|
||||
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
|
||||
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
|
||||
PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest,
|
||||
ZeroShotClassificationResponse,
|
||||
};
|
||||
use public_types::configuration::LlmProvider;
|
||||
use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
|
||||
|
|
@ -37,22 +39,24 @@ use std::num::NonZero;
|
|||
use std::rc::Rc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Tags an in-flight HTTP callout with the handler that must process its response.
/// The tag is stored in `StreamCallContext::response_handler_type` and matched on when
/// the callout response arrives.
enum ResponseHandlerType {
|
||||
// NOTE(review): handler for this variant is not visible in this diff — presumably the
// embeddings callout started by `get_embeddings`; confirm.
GetEmbeddings,
|
||||
// Response is dispatched to `function_resolver_handler`.
FunctionResolver,
|
||||
// Set by `schedule_api_call_request` right before it issues its callout.
FunctionCall,
|
||||
// Response is dispatched to `zero_shot_intent_detection_resp_handler`.
ZeroShotIntent,
|
||||
// Set before the POST to `/hallucination`; the response is dispatched to
// `hallucination_classification_resp_handler`.
HallucinationDetect,
|
||||
// NOTE(review): handler not visible in this diff — presumably the prompt-guard
// callout; confirm.
ArchGuard,
|
||||
// NOTE(review): handler not visible in this diff; confirm which callout uses it.
DefaultTarget,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType,
|
||||
user_message: Option<String>,
|
||||
prompt_target_name: Option<String>,
|
||||
request_body: ChatCompletionsRequest,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
similarity_scores: Option<Vec<(String, f64)>>,
|
||||
upstream_cluster: Option<String>,
|
||||
upstream_cluster_path: Option<String>,
|
||||
|
|
@ -310,6 +314,69 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
fn hallucination_classification_resp_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
callout_context: StreamCallContext,
|
||||
) {
|
||||
let hallucination_response: HallucinationClassificationResponse =
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(hallucination_response) => hallucination_response,
|
||||
Err(e) => {
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
let mut keys_with_low_score: Vec<String> = Vec::new();
|
||||
for (key, value) in &hallucination_response.params_scores {
|
||||
if *value < DEFAULT_HALLUCINATED_THRESHOLD {
|
||||
debug!(
|
||||
"hallucination detected: score for {} : {} is less than threshold {}",
|
||||
key, value, DEFAULT_HALLUCINATED_THRESHOLD
|
||||
);
|
||||
keys_with_low_score.push(key.clone().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if !keys_with_low_score.is_empty() {
|
||||
let response =
|
||||
"It seems I’m missing some information. Could you provide the following details: "
|
||||
.to_string()
|
||||
+ &keys_with_low_score.join(", ")
|
||||
+ " ?";
|
||||
let message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(response),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
};
|
||||
|
||||
let chat_completion_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message,
|
||||
index: 0,
|
||||
finish_reason: "done".to_string(),
|
||||
}],
|
||||
usage: None,
|
||||
model: ARCH_FC_MODEL_NAME.to_string(),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
debug!("hallucination response: {:?}", chat_completion_response);
|
||||
self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![("Powered-By", "Katanemo")],
|
||||
Some(
|
||||
serde_json::to_string(&chat_completion_response)
|
||||
.unwrap()
|
||||
.as_bytes(),
|
||||
),
|
||||
);
|
||||
} else {
|
||||
// not a hallucination, resume the flow
|
||||
self.schedule_api_call_request(callout_context);
|
||||
}
|
||||
}
|
||||
|
||||
fn zero_shot_intent_detection_resp_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
|
|
@ -565,6 +632,9 @@ impl StreamContext {
|
|||
|
||||
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
|
||||
|
||||
// TODO CO: pass nli check
|
||||
// If hallucination, pass chat template to check parameters
|
||||
|
||||
// extract all tool names
|
||||
let tool_names: Vec<String> = tool_calls
|
||||
.iter()
|
||||
|
|
@ -581,10 +651,11 @@ impl StreamContext {
|
|||
String::from(ARCH_MESSAGES_KEY),
|
||||
serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
|
||||
);
|
||||
|
||||
let tools_call_name = tool_calls[0].function.name.clone();
|
||||
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
|
||||
|
||||
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
|
||||
callout_context.tool_calls = Some(tool_calls.clone());
|
||||
|
||||
debug!(
|
||||
"prompt_target_name: {}, tool_name(s): {:?}",
|
||||
|
|
@ -592,6 +663,81 @@ impl StreamContext {
|
|||
);
|
||||
debug!("tool_params: {}", tool_params_json_str);
|
||||
|
||||
if model_resp.message.tool_calls.is_some()
|
||||
&& !model_resp.message.tool_calls.as_ref().unwrap().is_empty()
|
||||
{
|
||||
use serde_json::Value;
|
||||
let v: Value = serde_json::from_str(&tool_params_json_str).unwrap();
|
||||
let tool_params_dict: HashMap<String, String> = match v.as_object() {
|
||||
Some(obj) => obj
|
||||
.iter()
|
||||
.filter_map(|(key, value)| {
|
||||
value
|
||||
.as_str()
|
||||
.map(|str_value| (key.clone(), str_value.to_string()))
|
||||
})
|
||||
.collect(),
|
||||
None => HashMap::new(), // Return an empty HashMap if v is not an object
|
||||
};
|
||||
|
||||
let hallucination_classification_request = HallucinationClassificationRequest {
|
||||
prompt: callout_context.user_message.as_ref().unwrap().clone(),
|
||||
model: String::from(DEFAULT_INTENT_MODEL),
|
||||
parameters: tool_params_dict,
|
||||
};
|
||||
|
||||
let json_data: String =
|
||||
match serde_json::to_string(&hallucination_classification_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
return self.send_server_error(ServerError::Serialization(error), None);
|
||||
}
|
||||
};
|
||||
let call_args = CallArgs::new(
|
||||
MODEL_SERVER_NAME,
|
||||
"/hallucination",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/hallucination"),
|
||||
(":authority", MODEL_SERVER_NAME),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::HallucinationDetect;
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
} else {
|
||||
self.schedule_api_call_request(callout_context);
|
||||
}
|
||||
}
|
||||
|
||||
fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
|
||||
let tools_call_name = callout_context.tool_calls.as_ref().unwrap()[0]
|
||||
.function
|
||||
.name
|
||||
.clone();
|
||||
|
||||
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
|
||||
|
||||
//HACK: for now we only support one tool call, we will support multiple tool calls in the future
|
||||
let mut tool_params = callout_context.tool_calls.as_ref().unwrap()[0]
|
||||
.function
|
||||
.arguments
|
||||
.clone();
|
||||
tool_params.insert(
|
||||
String::from(ARCH_MESSAGES_KEY),
|
||||
serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
|
||||
);
|
||||
|
||||
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
|
||||
|
||||
let endpoint = prompt_target.endpoint.unwrap();
|
||||
let path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
let call_args = CallArgs::new(
|
||||
|
|
@ -612,8 +758,6 @@ impl StreamContext {
|
|||
callout_context.upstream_cluster_path = Some(path.clone());
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
||||
|
||||
self.tool_calls = Some(tool_calls.clone());
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
|
|
@ -806,6 +950,7 @@ impl StreamContext {
|
|||
similarity_scores: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
|
||||
if let Err(e) = self.http_call(call_args, call_context) {
|
||||
|
|
@ -1009,6 +1154,7 @@ impl HttpContext for StreamContext {
|
|||
similarity_scores: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
self.get_embeddings(callout_context);
|
||||
return Action::Pause;
|
||||
|
|
@ -1057,6 +1203,7 @@ impl HttpContext for StreamContext {
|
|||
similarity_scores: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
|
||||
if let Err(e) = self.http_call(call_args, call_context) {
|
||||
|
|
@ -1144,7 +1291,13 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
self.response_tokens += chat_completions_response.usage.completion_tokens;
|
||||
if chat_completions_response.usage.is_some() {
|
||||
self.response_tokens += chat_completions_response
|
||||
.usage
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.completion_tokens;
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = self.tool_calls.as_ref() {
|
||||
if !tool_calls.is_empty() {
|
||||
|
|
@ -1239,6 +1392,9 @@ impl Context for StreamContext {
|
|||
ResponseHandlerType::ZeroShotIntent => {
|
||||
self.zero_shot_intent_detection_resp_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::HallucinationDetect => {
|
||||
self.hallucination_classification_resp_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::FunctionResolver => {
|
||||
self.function_resolver_handler(body, callout_context)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use proxy_wasm_test_framework::types::{
|
|||
};
|
||||
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
|
||||
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
|
||||
use public_types::common_types::PromptGuardResponse;
|
||||
use public_types::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
|
||||
use public_types::embeddings::{
|
||||
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
|
||||
Embedding,
|
||||
|
|
@ -534,9 +534,9 @@ fn request_ratelimited() {
|
|||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let arch_fc_resp = ChatCompletionsResponse {
|
||||
usage: Usage {
|
||||
usage: Some(Usage {
|
||||
completion_tokens: 0,
|
||||
},
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: "test".to_string(),
|
||||
index: 0,
|
||||
|
|
@ -572,6 +572,38 @@ fn request_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("model_server"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/hallucination"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let hallucatination_body = HallucinationClassificationResponse {
|
||||
params_scores: HashMap::from([("city".to_string(), 0.99)]),
|
||||
model: "nli-model".to_string(),
|
||||
};
|
||||
|
||||
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("api_server"),
|
||||
Some(vec![
|
||||
|
|
@ -585,14 +617,14 @@ fn request_ratelimited() {
|
|||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.returning(Some(6))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
|
|
@ -642,9 +674,9 @@ fn request_not_ratelimited() {
|
|||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let arch_fc_resp = ChatCompletionsResponse {
|
||||
usage: Usage {
|
||||
usage: Some(Usage {
|
||||
completion_tokens: 0,
|
||||
},
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: "test".to_string(),
|
||||
index: 0,
|
||||
|
|
@ -680,6 +712,43 @@ fn request_not_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("model_server"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/hallucination"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// the hallucination check should report that the parameters were not hallucinated
|
||||
// prompt: str
|
||||
// parameters: dict
|
||||
// model: str
|
||||
|
||||
let hallucatination_body = HallucinationClassificationResponse {
|
||||
params_scores: HashMap::from([("city".to_string(), 0.99)]),
|
||||
model: "nli-model".to_string(),
|
||||
};
|
||||
|
||||
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("api_server"),
|
||||
Some(vec![
|
||||
|
|
@ -693,14 +762,14 @@ fn request_not_ratelimited() {
|
|||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.returning(Some(6))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
|
|
|
|||
|
|
@ -49,12 +49,17 @@ def predict(message, state):
|
|||
log.info("Error calling gateway API: {}".format(e.message))
|
||||
raise gr.Error("Error calling gateway API: {}".format(e.message))
|
||||
|
||||
log.debug("raw_response: ", raw_response.text)
|
||||
log.info("raw_response: ", raw_response.text)
|
||||
response = raw_response.parse()
|
||||
|
||||
# extract arch_state from metadata and store it in gradio session state
|
||||
# this state must be passed back to the gateway in the next request
|
||||
arch_state = json.loads(raw_response.text).get('metadata', {}).get(ARCH_STATE_HEADER, None)
|
||||
response_json = json.loads(raw_response.text)
|
||||
arch_state = None
|
||||
if response_json:
|
||||
metadata = response_json.get('metadata', {})
|
||||
if metadata:
|
||||
arch_state = metadata.get(ARCH_STATE_HEADER, None)
|
||||
if arch_state:
|
||||
state['arch_state'] = arch_state
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ llm_providers:
|
|||
- name: open-ai-gpt-4
|
||||
access_key: OPENAI_API_KEY
|
||||
provider: openai
|
||||
model: gpt-4
|
||||
model: gpt-4o
|
||||
default: true
|
||||
|
||||
system_prompt: |
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ async def models():
|
|||
|
||||
@app.post("/embeddings")
|
||||
async def embedding(req: EmbeddingRequest, res: Response):
|
||||
print(f"Embedding Call Start Time: {time.time()}")
|
||||
if req.model not in transformers:
|
||||
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
|
||||
start = time.time()
|
||||
|
|
@ -77,7 +78,7 @@ async def embedding(req: EmbeddingRequest, res: Response):
|
|||
"prompt_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
|
||||
print(f"Embedding Call Complete Time: {time.time()}")
|
||||
return {"data": data, "model": req.model, "object": "list", "usage": usage}
|
||||
|
||||
|
||||
|
|
@ -217,10 +218,10 @@ async def hallucination(req: HallucinationRequest, res: Response):
|
|||
)
|
||||
result_score = result["scores"]
|
||||
result_params = {k[0]: s for k, s in zip(req.parameters.items(), result_score)}
|
||||
logger.info(f"hallucination result: {result_params}")
|
||||
|
||||
return {
|
||||
"params_scores": result_params,
|
||||
"raw_result": result,
|
||||
"model": req.model,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -227,7 +227,7 @@ pub mod open_ai {
|
|||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsResponse {
|
||||
pub usage: Usage,
|
||||
pub usage: Option<Usage>,
|
||||
pub choices: Vec<Choice>,
|
||||
pub model: String,
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
|
|
@ -272,6 +272,19 @@ pub struct ZeroShotClassificationResponse {
|
|||
pub model: String,
|
||||
}
|
||||
|
||||
/// JSON request body for the model server's `/hallucination` endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationRequest {
|
||||
// Text the parameters are checked against; the gateway fills this with the user message.
pub prompt: String,
|
||||
// Parameter name -> parameter value. The gateway only forwards tool-call arguments
// whose JSON value is a string.
pub parameters: HashMap<String, String>,
|
||||
// Name of the model the server should use for the check.
pub model: String,
|
||||
}
|
||||
|
||||
/// JSON response body returned by the model server's `/hallucination` endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationResponse {
|
||||
// Parameter name -> score. The gateway treats a score below
// `DEFAULT_HALLUCINATED_THRESHOLD` as a hallucinated parameter.
pub params_scores: HashMap<String, f64>,
|
||||
// Name of the model reported by the server.
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum PromptGuardTask {
|
||||
#[serde(rename = "jailbreak")]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue