mirror of
https://github.com/katanemo/plano.git
synced 2026-06-02 14:35:14 +02:00
Serialize tool calls for Arch FC (#131)
* Serialize tool calls * fix int tests
This commit is contained in:
parent
b43f687b85
commit
96686dc606
10 changed files with 166 additions and 57 deletions
1
arch/Cargo.lock
generated
1
arch/Cargo.lock
generated
|
|
@ -759,6 +759,7 @@ dependencies = [
|
|||
"serde_json",
|
||||
"serde_yaml",
|
||||
"serial_test",
|
||||
"sha2",
|
||||
"thiserror",
|
||||
"tiktoken-rs",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ tiktoken-rs = "0.5.9"
|
|||
acap = "0.3.0"
|
||||
rand = "0.8.5"
|
||||
thiserror = "1.0.64"
|
||||
sha2 = "0.10.8"
|
||||
|
||||
[dev-dependencies]
|
||||
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ COPY public_types /public_types
|
|||
RUN cargo build --release --target wasm32-wasi
|
||||
|
||||
# copy built filter into envoy image
|
||||
FROM envoyproxy/envoy:v1.30-latest as envoy
|
||||
FROM envoyproxy/envoy:v1.31-latest as envoy
|
||||
|
||||
#Build config generator, so that we have a single build image for both Rust and Python
|
||||
FROM python:3-slim as arch
|
||||
|
|
|
|||
|
|
@ -12,4 +12,4 @@ pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
|||
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
|
||||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
|
||||
// pub const ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
pub const ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use crate::consts::{
|
||||
ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
|
||||
ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
||||
ARCH_STATE_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL,
|
||||
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
||||
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
|
||||
|
|
@ -15,9 +15,9 @@ use log::{debug, info, warn};
|
|||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::open_ai::{
|
||||
ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
|
||||
ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
|
||||
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
|
||||
ParameterType, StreamOptions, ToolType,
|
||||
ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType,
|
||||
};
|
||||
use public_types::common_types::{
|
||||
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
|
||||
|
|
@ -28,6 +28,8 @@ use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
|
|||
use public_types::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use serde_json::Value;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::HashMap;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
|
|
@ -59,10 +61,16 @@ pub struct StreamContext {
|
|||
embeddings_store: Rc<EmbeddingsStore>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
tool_call_response: Option<String>,
|
||||
arch_state: Option<Vec<ArchState>>,
|
||||
request_body_size: usize,
|
||||
ratelimit_selector: Option<Header>,
|
||||
streaming_response: bool,
|
||||
user_prompt: Option<Message>,
|
||||
response_tokens: usize,
|
||||
chat_completions_request: bool,
|
||||
is_chat_completions_request: bool,
|
||||
chat_completions_request: Option<ChatCompletionsRequest>,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
llm_provider: Option<Rc<LlmProvider>>,
|
||||
|
|
@ -83,11 +91,17 @@ impl StreamContext {
|
|||
metrics,
|
||||
prompt_targets,
|
||||
embeddings_store,
|
||||
chat_completions_request: None,
|
||||
callouts: HashMap::new(),
|
||||
tool_calls: None,
|
||||
tool_call_response: None,
|
||||
arch_state: None,
|
||||
request_body_size: 0,
|
||||
ratelimit_selector: None,
|
||||
streaming_response: false,
|
||||
user_prompt: None,
|
||||
response_tokens: 0,
|
||||
chat_completions_request: false,
|
||||
is_chat_completions_request: false,
|
||||
llm_providers,
|
||||
llm_provider: None,
|
||||
prompt_guards,
|
||||
|
|
@ -463,13 +477,20 @@ impl StreamContext {
|
|||
});
|
||||
}
|
||||
|
||||
// archfc handler needs state so it can expand tool calls
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert(
|
||||
ARCH_STATE_HEADER.to_string(),
|
||||
serde_json::to_string(&self.arch_state).unwrap(),
|
||||
);
|
||||
|
||||
let chat_completions = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages: callout_context.request_body.messages.clone(),
|
||||
tools: Some(chat_completion_tools),
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
metadata: None,
|
||||
metadata: Some(metadata),
|
||||
};
|
||||
|
||||
let msg_body = match serde_json::to_string(&chat_completions) {
|
||||
|
|
@ -521,10 +542,8 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
debug!("response received for function resolver");
|
||||
|
||||
let body_str = String::from_utf8(body).unwrap();
|
||||
debug!("function_resolver response str: {}", body_str);
|
||||
debug!("arch <= app response body: {}", body_str);
|
||||
|
||||
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
|
||||
Ok(arch_fc_response) => arch_fc_response,
|
||||
|
|
@ -559,7 +578,6 @@ impl StreamContext {
|
|||
|
||||
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
|
||||
|
||||
debug!("tool_call_details: {:?}", tool_calls);
|
||||
// extract all tool names
|
||||
let tool_names: Vec<String> = tool_calls
|
||||
.iter()
|
||||
|
|
@ -581,8 +599,10 @@ impl StreamContext {
|
|||
|
||||
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
|
||||
|
||||
debug!("prompt_target_name: {}", prompt_target.name);
|
||||
debug!("tool_name(s): {:?}", tool_names);
|
||||
debug!(
|
||||
"prompt_target_name: {}, tool_name(s): {:?}",
|
||||
prompt_target.name, tool_names
|
||||
);
|
||||
debug!("tool_params: {}", tool_params_json_str);
|
||||
|
||||
let endpoint = prompt_target.endpoint.unwrap();
|
||||
|
|
@ -611,6 +631,7 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
self.tool_calls = Some(tool_calls.clone());
|
||||
callout_context.upstream_cluster = Some(endpoint.name);
|
||||
callout_context.upstream_cluster_path = Some(path);
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
||||
|
|
@ -635,9 +656,9 @@ impl StreamContext {
|
|||
} else {
|
||||
warn!("http status code not found in api response");
|
||||
}
|
||||
debug!("response received for function call response");
|
||||
let body_str: String = String::from_utf8(body).unwrap();
|
||||
debug!("function_call_response response str: {}", body_str);
|
||||
self.tool_call_response = Some(body_str.clone());
|
||||
debug!("arch <= app response body: {}", body_str);
|
||||
let prompt_target_name = callout_context.prompt_target_name.unwrap();
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
|
|
@ -697,10 +718,7 @@ impl StreamContext {
|
|||
.send_server_error(format!("Error serializing request_body: {:?}", e), None);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"function_calling sending request to openai: msg {}",
|
||||
json_string
|
||||
);
|
||||
debug!("arch => openai request body: {}", json_string);
|
||||
|
||||
// Tokenize and Ratelimit.
|
||||
if let Some(selector) = self.ratelimit_selector.take() {
|
||||
|
|
@ -725,7 +743,7 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
|
||||
self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
|
||||
|
|
@ -881,7 +899,7 @@ impl StreamContext {
|
|||
};
|
||||
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
|
||||
debug!("sending response back to default llm: {}", json_resp);
|
||||
self.set_http_request_body(0, json_resp.len(), json_resp.as_bytes());
|
||||
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
}
|
||||
|
|
@ -899,7 +917,7 @@ impl HttpContext for StreamContext {
|
|||
self.delete_content_length_header();
|
||||
self.save_ratelimit_header();
|
||||
|
||||
self.chat_completions_request =
|
||||
self.is_chat_completions_request =
|
||||
self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH;
|
||||
|
||||
debug!(
|
||||
|
|
@ -922,6 +940,8 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
self.request_body_size = body_size;
|
||||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let mut deserialized_body: ChatCompletionsRequest =
|
||||
|
|
@ -948,6 +968,20 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
self.arch_state = match deserialized_body.metadata {
|
||||
Some(ref metadata) => {
|
||||
if metadata.contains_key(ARCH_STATE_HEADER) {
|
||||
let arch_state_str = metadata[ARCH_STATE_HEADER].clone();
|
||||
let arch_state: Vec<ArchState> = serde_json::from_str(&arch_state_str).unwrap();
|
||||
Some(arch_state)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
self.is_chat_completions_request = true;
|
||||
// Set the model based on the chosen LLM Provider
|
||||
deserialized_body.model = String::from(&self.llm_provider().model);
|
||||
|
||||
|
|
@ -958,10 +992,11 @@ impl HttpContext for StreamContext {
|
|||
});
|
||||
}
|
||||
|
||||
let user_message = match deserialized_body
|
||||
let last_user_prompt = match deserialized_body
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role == USER_ROLE)
|
||||
.last()
|
||||
.and_then(|last_message| last_message.content.clone())
|
||||
{
|
||||
Some(content) => content,
|
||||
None => {
|
||||
|
|
@ -970,17 +1005,24 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
self.user_prompt = Some(last_user_prompt.clone());
|
||||
|
||||
let user_message_str = self.user_prompt.as_ref().unwrap().content.clone();
|
||||
|
||||
let prompt_guard_jailbreak_task = self
|
||||
.prompt_guards
|
||||
.input_guards
|
||||
.contains_key(&public_types::configuration::GuardType::Jailbreak);
|
||||
|
||||
self.chat_completions_request = Some(deserialized_body);
|
||||
|
||||
if !prompt_guard_jailbreak_task {
|
||||
debug!("Missing input guard. Making inline call to retrieve");
|
||||
let callout_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
user_message: user_message_str.clone(),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
|
||||
similarity_scores: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
|
|
@ -990,7 +1032,14 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
|
||||
let get_prompt_guards_request = PromptGuardRequest {
|
||||
input: user_message.clone(),
|
||||
input: self
|
||||
.user_prompt
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone(),
|
||||
task: PromptGuardTask::Jailbreak,
|
||||
};
|
||||
|
||||
|
|
@ -1032,9 +1081,9 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let call_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
user_message: self.user_prompt.as_ref().unwrap().content.clone(),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
|
||||
similarity_scores: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
|
|
@ -1057,7 +1106,7 @@ impl HttpContext for StreamContext {
|
|||
self.context_id, body_size, end_of_stream
|
||||
);
|
||||
|
||||
if !self.chat_completions_request {
|
||||
if !self.is_chat_completions_request {
|
||||
if let Some(body_str) = self
|
||||
.get_http_response_body(0, body_size)
|
||||
.and_then(|bytes| String::from_utf8(bytes).ok())
|
||||
|
|
@ -1067,7 +1116,7 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
if !end_of_stream && !self.streaming_response {
|
||||
if !end_of_stream {
|
||||
return Action::Pause;
|
||||
}
|
||||
|
||||
|
|
@ -1075,9 +1124,8 @@ impl HttpContext for StreamContext {
|
|||
.get_http_response_body(0, body_size)
|
||||
.expect("cant get response body");
|
||||
|
||||
let body_str = String::from_utf8(body).expect("body is not utf-8");
|
||||
|
||||
if self.streaming_response {
|
||||
let body_str = String::from_utf8(body).expect("body is not utf-8");
|
||||
debug!("streaming response");
|
||||
let chat_completions_data = match body_str.split_once("data: ") {
|
||||
Some((_, chat_completions_data)) => chat_completions_data,
|
||||
|
|
@ -1117,13 +1165,14 @@ impl HttpContext for StreamContext {
|
|||
} else {
|
||||
debug!("non streaming response");
|
||||
let chat_completions_response: ChatCompletionsResponse =
|
||||
match serde_json::from_str(&body_str) {
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(de) => de,
|
||||
Err(e) => {
|
||||
self.send_server_error(
|
||||
format!(
|
||||
"error in non-streaming response: {}\n response was={}",
|
||||
e, body_str
|
||||
e,
|
||||
String::from_utf8(body).unwrap()
|
||||
),
|
||||
None,
|
||||
);
|
||||
|
|
@ -1132,6 +1181,65 @@ impl HttpContext for StreamContext {
|
|||
};
|
||||
|
||||
self.response_tokens += chat_completions_response.usage.completion_tokens;
|
||||
|
||||
if let Some(tool_calls) = self.tool_calls.as_ref() {
|
||||
if !tool_calls.is_empty() {
|
||||
if self.arch_state.is_none() {
|
||||
self.arch_state = Some(Vec::new());
|
||||
}
|
||||
|
||||
// compute SHA-256 hash over the user messages in the message history
|
||||
let mut hasher = Sha256::new();
|
||||
let prompts: Vec<String> = self
|
||||
.chat_completions_request
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role == USER_ROLE)
|
||||
.map(|msg| msg.content.clone().unwrap())
|
||||
.collect();
|
||||
let prompts_merged = prompts.join("#.#");
|
||||
hasher.update(prompts_merged.clone());
|
||||
let hash_key = hasher.finalize();
|
||||
// convert hash to hex string
|
||||
let hash_key_str = format!("{:x}", hash_key);
|
||||
debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged);
|
||||
|
||||
// create new tool call state
|
||||
let tool_call_state = ToolCallState {
|
||||
key: hash_key_str,
|
||||
message: self.user_prompt.clone(),
|
||||
tool_call: tool_calls[0].function.clone(),
|
||||
tool_response: self.tool_call_response.clone().unwrap(),
|
||||
};
|
||||
|
||||
// push tool call state to arch state
|
||||
self.arch_state
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.push(ArchState::ToolCall(vec![tool_call_state]));
|
||||
|
||||
let mut data: Value = serde_json::from_slice(&body).unwrap();
|
||||
// use serde_json::Value to manipulate the JSON object so that we don't lose any data
|
||||
if let Value::Object(ref mut map) = data {
|
||||
// serialize arch state and add to metadata
|
||||
let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
|
||||
debug!("arch_state: {}", arch_state_str);
|
||||
let metadata = map
|
||||
.entry("metadata")
|
||||
.or_insert(Value::Object(serde_json::Map::new()));
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
ARCH_STATE_HEADER.to_string(),
|
||||
serde_json::Value::String(arch_state_str),
|
||||
);
|
||||
|
||||
let data_serialized = serde_json::to_string(&data).unwrap();
|
||||
debug!("arch => user: {}", data_serialized);
|
||||
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
|
|
|
|||
|
|
@ -571,9 +571,6 @@ fn request_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("api_server"),
|
||||
Some(vec![
|
||||
|
|
@ -592,14 +589,15 @@ fn request_ratelimited() {
|
|||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let response_headers_with_200 = vec![(":status", "200"), ("content-type", "application/json")];
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_get_header_map_pairs(Some(MapType::HttpCallResponseHeaders))
|
||||
.returning(Some(response_headers_with_200))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
|
|
@ -612,10 +610,6 @@ fn request_ratelimited() {
|
|||
None,
|
||||
)
|
||||
.expect_metric_increment("ratelimited_rq", 1)
|
||||
.expect_log(
|
||||
Some(LogLevel::Debug),
|
||||
Some("server error occurred: Exceeded Ratelimit: Not allowed"),
|
||||
)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -685,9 +679,6 @@ fn request_not_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("api_server"),
|
||||
Some(vec![
|
||||
|
|
@ -706,15 +697,16 @@ fn request_not_ratelimited() {
|
|||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let response_headers_with_200 = vec![(":status", "200"), ("content-type", "application/json")];
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_get_header_map_pairs(Some(MapType::HttpCallResponseHeaders))
|
||||
.returning(Some(response_headers_with_200))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
|
|||
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
|
||||
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
|
||||
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
|
||||
ARCH_STATE_HEADER = 'x-arch-state'
|
||||
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
|
||||
|
|
@ -32,7 +33,7 @@ def predict(message, state):
|
|||
|
||||
metadata = None
|
||||
if 'arch_state' in state:
|
||||
metadata = {"x-arch-state": state['arch_state']}
|
||||
metadata = {ARCH_STATE_HEADER: state['arch_state']}
|
||||
|
||||
try:
|
||||
raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME,
|
||||
|
|
@ -48,11 +49,12 @@ def predict(message, state):
|
|||
log.info("Error calling gateway API: {}".format(e.message))
|
||||
raise gr.Error("Error calling gateway API: {}".format(e.message))
|
||||
|
||||
log.debug("raw_response: ", raw_response.text)
|
||||
response = raw_response.parse()
|
||||
|
||||
# extract arch_state from metadata and store it in gradio session state
|
||||
# this state must be passed back to the gateway in the next request
|
||||
arch_state = json.loads(raw_response.text).get('metadata', {}).get('x-arch-state', None)
|
||||
arch_state = json.loads(raw_response.text).get('metadata', {}).get(ARCH_STATE_HEADER, None)
|
||||
if arch_state:
|
||||
state['arch_state'] = arch_state
|
||||
|
||||
|
|
|
|||
3
model_server/.vscode/launch.json
vendored
3
model_server/.vscode/launch.json
vendored
|
|
@ -10,6 +10,9 @@
|
|||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"args": ["app.main:app","--reload", "--port", "8000"],
|
||||
"env": {
|
||||
"MODE": "local-cpu",
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,7 +69,8 @@ def process_state(arch_state, history: list[Message]):
|
|||
if hist.role == 'user':
|
||||
sha_history.append(hist.content)
|
||||
sha256_hash = hashlib.sha256()
|
||||
sha256_hash.update(json.dumps(sha_history).encode())
|
||||
joined_key_str = ('#.#').join(sha_history)
|
||||
sha256_hash.update(joined_key_str.encode())
|
||||
sha_key = sha256_hash.hexdigest()
|
||||
print(f"sha_key: {sha_key}")
|
||||
if sha_key in state_map:
|
||||
|
|
|
|||
|
|
@ -4,14 +4,15 @@ from app.arch_fc.arch_fc import process_state
|
|||
from app.arch_fc.common import ChatMessage, Message
|
||||
# test process_state
|
||||
|
||||
arch_state = '[[{"key": "cafbda799879e1dce6cd3de3c3e8a40052a93addec457bda0b2f21f8c86b3424", "message": {"role": "user", "content": "how is the weather in chicago?"}, "tool_call": {"name": "weather_forecast", "arguments": {"city": "Chicago"}}, "tool_response": "{\\"city\\":\\"Chicago\\",\\"temperature\\":[{\\"date\\":\\"2024-10-05\\",\\"temperature\\":{\\"min\\":51,\\"max\\":70},\\"query_time\\":\\"2024-10-05 08:18:00.264171+00:00\\"},{\\"date\\":\\"2024-10-06\\",\\"temperature\\":{\\"min\\":77,\\"max\\":88},\\"query_time\\":\\"2024-10-05 08:18:00.264186+00:00\\"},{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":66,\\"max\\":84},\\"query_time\\":\\"2024-10-05 08:18:00.264190+00:00\\"},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":77,\\"max\\":94},\\"query_time\\":\\"2024-10-05 08:18:00.264209+00:00\\"},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":76,\\"max\\":92},\\"query_time\\":\\"2024-10-05 08:18:00.264518+00:00\\"},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":56,\\"max\\":68},\\"query_time\\":\\"2024-10-05 08:18:00.264550+00:00\\"},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":73,\\"max\\":88},\\"query_time\\":\\"2024-10-05 08:18:00.264559+00:00\\"}],\\"unit\\":\\"F\\"}"}]]'
|
||||
|
||||
arch_state = '[[{"key":"02ea8ec721b130dc30ec836b79ec675116cd5889bca7d63720bc64baed994fc1","message":{"role":"user","content":"how is the weather in new york?"},"tool_call":{"name":"weather_forecast","arguments":{"city":"new york"}},"tool_response":"{\\"city\\":\\"new york\\",\\"temperature\\":[{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":68,\\"max\\":79}},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":70,\\"max\\":76}},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":71,\\"max\\":84}},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":61,\\"max\\":79}},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":86,\\"max\\":91}},{\\"date\\":\\"2024-10-12\\",\\"temperature\\":{\\"min\\":85,\\"max\\":90}},{\\"date\\":\\"2024-10-13\\",\\"temperature\\":{\\"min\\":72,\\"max\\":89}}],\\"unit\\":\\"F\\"}"}],[{"key":"566b9a2197cba89f35c1e3fbeee55882772ae7627fcf4411dae90282f98a1067","message":{"role":"user","content":"how is the weather in chicago?"},"tool_call":{"name":"weather_forecast","arguments":{"city":"chicago"}},"tool_response":"{\\"city\\":\\"chicago\\",\\"temperature\\":[{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":54,\\"max\\":64}},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":84,\\"max\\":99}},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":85,\\"max\\":100}},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":50,\\"max\\":62}},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":79,\\"max\\":85}},{\\"date\\":\\"2024-10-12\\",\\"temperature\\":{\\"min\\":88,\\"max\\":100}},{\\"date\\":\\"2024-10-13\\",\\"temperature\\":{\\"min\\":56,\\"max\\":61}}],\\"unit\\":\\"F\\"}"}]]'
|
||||
|
||||
def test_process_state():
|
||||
history = []
|
||||
history.append(Message(role="user", content="how is the weather in new york?"))
|
||||
history.append(Message(role="user", content="how is the weather in chicago?"))
|
||||
updated_history = process_state(arch_state, history)
|
||||
print(json.dumps(updated_history, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue