Serialize tool calls for Arch FC (#131)

* Serialize tool calls

* fix int tests
This commit is contained in:
Adil Hafeez 2024-10-07 00:03:25 -07:00 committed by GitHub
parent b43f687b85
commit 96686dc606
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 166 additions and 57 deletions

1
arch/Cargo.lock generated
View file

@ -759,6 +759,7 @@ dependencies = [
"serde_json",
"serde_yaml",
"serial_test",
"sha2",
"thiserror",
"tiktoken-rs",
]

View file

@ -21,6 +21,7 @@ tiktoken-rs = "0.5.9"
acap = "0.3.0"
rand = "0.8.5"
thiserror = "1.0.64"
sha2 = "0.10.8"
[dev-dependencies]
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }

View file

@ -10,7 +10,7 @@ COPY public_types /public_types
RUN cargo build --release --target wasm32-wasi
# copy built filter into envoy image
FROM envoyproxy/envoy:v1.30-latest as envoy
FROM envoyproxy/envoy:v1.31-latest as envoy
#Build config generator, so that we have a single build image for both Rust and Python
FROM python:3-slim as arch

View file

@ -12,4 +12,4 @@ pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
// pub const ARCH_STATE_HEADER: &str = "x-arch-state";
pub const ARCH_STATE_HEADER: &str = "x-arch-state";

View file

@ -1,7 +1,7 @@
use crate::consts::{
ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
ARCH_STATE_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
@ -15,9 +15,9 @@ use log::{debug, info, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::open_ai::{
ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
ParameterType, StreamOptions, ToolType,
ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType,
};
use public_types::common_types::{
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
@ -28,6 +28,8 @@ use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
use public_types::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use serde_json::Value;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::num::NonZero;
use std::rc::Rc;
@ -59,10 +61,16 @@ pub struct StreamContext {
embeddings_store: Rc<EmbeddingsStore>,
overrides: Rc<Option<Overrides>>,
callouts: HashMap<u32, CallContext>,
tool_calls: Option<Vec<ToolCall>>,
tool_call_response: Option<String>,
arch_state: Option<Vec<ArchState>>,
request_body_size: usize,
ratelimit_selector: Option<Header>,
streaming_response: bool,
user_prompt: Option<Message>,
response_tokens: usize,
chat_completions_request: bool,
is_chat_completions_request: bool,
chat_completions_request: Option<ChatCompletionsRequest>,
prompt_guards: Rc<PromptGuards>,
llm_providers: Rc<LlmProviders>,
llm_provider: Option<Rc<LlmProvider>>,
@ -83,11 +91,17 @@ impl StreamContext {
metrics,
prompt_targets,
embeddings_store,
chat_completions_request: None,
callouts: HashMap::new(),
tool_calls: None,
tool_call_response: None,
arch_state: None,
request_body_size: 0,
ratelimit_selector: None,
streaming_response: false,
user_prompt: None,
response_tokens: 0,
chat_completions_request: false,
is_chat_completions_request: false,
llm_providers,
llm_provider: None,
prompt_guards,
@ -463,13 +477,20 @@ impl StreamContext {
});
}
// archfc handler needs state so it can expand tool calls
let mut metadata = HashMap::new();
metadata.insert(
ARCH_STATE_HEADER.to_string(),
serde_json::to_string(&self.arch_state).unwrap(),
);
let chat_completions = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
messages: callout_context.request_body.messages.clone(),
tools: Some(chat_completion_tools),
stream: false,
stream_options: None,
metadata: None,
metadata: Some(metadata),
};
let msg_body = match serde_json::to_string(&chat_completions) {
@ -521,10 +542,8 @@ impl StreamContext {
}
fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
debug!("response received for function resolver");
let body_str = String::from_utf8(body).unwrap();
debug!("function_resolver response str: {}", body_str);
debug!("arch <= app response body: {}", body_str);
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
Ok(arch_fc_response) => arch_fc_response,
@ -559,7 +578,6 @@ impl StreamContext {
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
debug!("tool_call_details: {:?}", tool_calls);
// extract all tool names
let tool_names: Vec<String> = tool_calls
.iter()
@ -581,8 +599,10 @@ impl StreamContext {
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
debug!("prompt_target_name: {}", prompt_target.name);
debug!("tool_name(s): {:?}", tool_names);
debug!(
"prompt_target_name: {}, tool_name(s): {:?}",
prompt_target.name, tool_names
);
debug!("tool_params: {}", tool_params_json_str);
let endpoint = prompt_target.endpoint.unwrap();
@ -611,6 +631,7 @@ impl StreamContext {
}
};
self.tool_calls = Some(tool_calls.clone());
callout_context.upstream_cluster = Some(endpoint.name);
callout_context.upstream_cluster_path = Some(path);
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
@ -635,9 +656,9 @@ impl StreamContext {
} else {
warn!("http status code not found in api response");
}
debug!("response received for function call response");
let body_str: String = String::from_utf8(body).unwrap();
debug!("function_call_response response str: {}", body_str);
self.tool_call_response = Some(body_str.clone());
debug!("arch <= app response body: {}", body_str);
let prompt_target_name = callout_context.prompt_target_name.unwrap();
let prompt_target = self
.prompt_targets
@ -697,10 +718,7 @@ impl StreamContext {
.send_server_error(format!("Error serializing request_body: {:?}", e), None);
}
};
debug!(
"function_calling sending request to openai: msg {}",
json_string
);
debug!("arch => openai request body: {}", json_string);
// Tokenize and Ratelimit.
if let Some(selector) = self.ratelimit_selector.take() {
@ -725,7 +743,7 @@ impl StreamContext {
}
}
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes());
self.resume_http_request();
}
@ -881,7 +899,7 @@ impl StreamContext {
};
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
debug!("sending response back to default llm: {}", json_resp);
self.set_http_request_body(0, json_resp.len(), json_resp.as_bytes());
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
self.resume_http_request();
}
}
@ -899,7 +917,7 @@ impl HttpContext for StreamContext {
self.delete_content_length_header();
self.save_ratelimit_header();
self.chat_completions_request =
self.is_chat_completions_request =
self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH;
debug!(
@ -922,6 +940,8 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
self.request_body_size = body_size;
// Deserialize body into spec.
// Currently OpenAI API.
let mut deserialized_body: ChatCompletionsRequest =
@ -948,6 +968,20 @@ impl HttpContext for StreamContext {
}
};
self.arch_state = match deserialized_body.metadata {
Some(ref metadata) => {
if metadata.contains_key(ARCH_STATE_HEADER) {
let arch_state_str = metadata[ARCH_STATE_HEADER].clone();
let arch_state: Vec<ArchState> = serde_json::from_str(&arch_state_str).unwrap();
Some(arch_state)
} else {
None
}
}
None => None,
};
self.is_chat_completions_request = true;
// Set the model based on the chosen LLM Provider
deserialized_body.model = String::from(&self.llm_provider().model);
@ -958,10 +992,11 @@ impl HttpContext for StreamContext {
});
}
let user_message = match deserialized_body
let last_user_prompt = match deserialized_body
.messages
.iter()
.filter(|msg| msg.role == USER_ROLE)
.last()
.and_then(|last_message| last_message.content.clone())
{
Some(content) => content,
None => {
@ -970,17 +1005,24 @@ impl HttpContext for StreamContext {
}
};
self.user_prompt = Some(last_user_prompt.clone());
let user_message_str = self.user_prompt.as_ref().unwrap().content.clone();
let prompt_guard_jailbreak_task = self
.prompt_guards
.input_guards
.contains_key(&public_types::configuration::GuardType::Jailbreak);
self.chat_completions_request = Some(deserialized_body);
if !prompt_guard_jailbreak_task {
debug!("Missing input guard. Making inline call to retrieve");
let callout_context = CallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: Some(user_message),
user_message: user_message_str.clone(),
prompt_target_name: None,
request_body: deserialized_body,
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
@ -990,7 +1032,14 @@ impl HttpContext for StreamContext {
}
let get_prompt_guards_request = PromptGuardRequest {
input: user_message.clone(),
input: self
.user_prompt
.as_ref()
.unwrap()
.content
.as_ref()
.unwrap()
.clone(),
task: PromptGuardTask::Jailbreak,
};
@ -1032,9 +1081,9 @@ impl HttpContext for StreamContext {
let call_context = CallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: Some(user_message),
user_message: self.user_prompt.as_ref().unwrap().content.clone(),
prompt_target_name: None,
request_body: deserialized_body,
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
@ -1057,7 +1106,7 @@ impl HttpContext for StreamContext {
self.context_id, body_size, end_of_stream
);
if !self.chat_completions_request {
if !self.is_chat_completions_request {
if let Some(body_str) = self
.get_http_response_body(0, body_size)
.and_then(|bytes| String::from_utf8(bytes).ok())
@ -1067,7 +1116,7 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
if !end_of_stream && !self.streaming_response {
if !end_of_stream {
return Action::Pause;
}
@ -1075,9 +1124,8 @@ impl HttpContext for StreamContext {
.get_http_response_body(0, body_size)
.expect("cant get response body");
let body_str = String::from_utf8(body).expect("body is not utf-8");
if self.streaming_response {
let body_str = String::from_utf8(body).expect("body is not utf-8");
debug!("streaming response");
let chat_completions_data = match body_str.split_once("data: ") {
Some((_, chat_completions_data)) => chat_completions_data,
@ -1117,13 +1165,14 @@ impl HttpContext for StreamContext {
} else {
debug!("non streaming response");
let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_str(&body_str) {
match serde_json::from_slice(&body) {
Ok(de) => de,
Err(e) => {
self.send_server_error(
format!(
"error in non-streaming response: {}\n response was={}",
e, body_str
e,
String::from_utf8(body).unwrap()
),
None,
);
@ -1132,6 +1181,65 @@ impl HttpContext for StreamContext {
};
self.response_tokens += chat_completions_response.usage.completion_tokens;
if let Some(tool_calls) = self.tool_calls.as_ref() {
if !tool_calls.is_empty() {
if self.arch_state.is_none() {
self.arch_state = Some(Vec::new());
}
// compute sha hash from message history
let mut hasher = Sha256::new();
let prompts: Vec<String> = self
.chat_completions_request
.as_ref()
.unwrap()
.messages
.iter()
.filter(|msg| msg.role == USER_ROLE)
.map(|msg| msg.content.clone().unwrap())
.collect();
let prompts_merged = prompts.join("#.#");
hasher.update(prompts_merged.clone());
let hash_key = hasher.finalize();
// conver hash to hex string
let hash_key_str = format!("{:x}", hash_key);
debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged);
// create new tool call state
let tool_call_state = ToolCallState {
key: hash_key_str,
message: self.user_prompt.clone(),
tool_call: tool_calls[0].function.clone(),
tool_response: self.tool_call_response.clone().unwrap(),
};
// push tool call state to arch state
self.arch_state
.as_mut()
.unwrap()
.push(ArchState::ToolCall(vec![tool_call_state]));
let mut data: Value = serde_json::from_slice(&body).unwrap();
// use serde::Value to manipulate the json object and ensure that we don't lose any data
if let Value::Object(ref mut map) = data {
// serialize arch state and add to metadata
let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
debug!("arch_state: {}", arch_state_str);
let metadata = map
.entry("metadata")
.or_insert(Value::Object(serde_json::Map::new()));
metadata.as_object_mut().unwrap().insert(
ARCH_STATE_HEADER.to_string(),
serde_json::Value::String(arch_state_str),
);
let data_serialized = serde_json::to_string(&data).unwrap();
debug!("arch => user: {}", data_serialized);
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
};
}
}
}
debug!(

View file

@ -571,9 +571,6 @@ fn request_ratelimited() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("api_server"),
Some(vec![
@ -592,14 +589,15 @@ fn request_ratelimited() {
.execute_and_expect(ReturnType::None)
.unwrap();
let response_headers_with_200 = vec![(":status", "200"), ("content-type", "application/json")];
let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_header_map_pairs(Some(MapType::HttpCallResponseHeaders))
.returning(Some(response_headers_with_200))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
@ -612,10 +610,6 @@ fn request_ratelimited() {
None,
)
.expect_metric_increment("ratelimited_rq", 1)
.expect_log(
Some(LogLevel::Debug),
Some("server error occurred: Exceeded Ratelimit: Not allowed"),
)
.execute_and_expect(ReturnType::None)
.unwrap();
}
@ -685,9 +679,6 @@ fn request_not_ratelimited() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("api_server"),
Some(vec![
@ -706,15 +697,16 @@ fn request_not_ratelimited() {
.execute_and_expect(ReturnType::None)
.unwrap();
let response_headers_with_200 = vec![(":status", "200"), ("content-type", "application/json")];
let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_header_map_pairs(Some(MapType::HttpCallResponseHeaders))
.returning(Some(response_headers_with_200))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)

View file

@ -11,6 +11,7 @@ OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
ARCH_STATE_HEADER = 'x-arch-state'
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
@ -32,7 +33,7 @@ def predict(message, state):
metadata = None
if 'arch_state' in state:
metadata = {"x-arch-state": state['arch_state']}
metadata = {ARCH_STATE_HEADER: state['arch_state']}
try:
raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME,
@ -48,11 +49,12 @@ def predict(message, state):
log.info("Error calling gateway API: {}".format(e.message))
raise gr.Error("Error calling gateway API: {}".format(e.message))
log.debug("raw_response: ", raw_response.text)
response = raw_response.parse()
# extract arch_state from metadata and store it in gradio session state
# this state must be passed back to the gateway in the next request
arch_state = json.loads(raw_response.text).get('metadata', {}).get('x-arch-state', None)
arch_state = json.loads(raw_response.text).get('metadata', {}).get(ARCH_STATE_HEADER, None)
if arch_state:
state['arch_state'] = arch_state

View file

@ -10,6 +10,9 @@
"request": "launch",
"module": "uvicorn",
"args": ["app.main:app","--reload", "--port", "8000"],
"env": {
"MODE": "local-cpu",
}
}
]
}

View file

@ -69,7 +69,8 @@ def process_state(arch_state, history: list[Message]):
if hist.role == 'user':
sha_history.append(hist.content)
sha256_hash = hashlib.sha256()
sha256_hash.update(json.dumps(sha_history).encode())
joined_key_str = ('#.#').join(sha_history)
sha256_hash.update(joined_key_str.encode())
sha_key = sha256_hash.hexdigest()
print(f"sha_key: {sha_key}")
if sha_key in state_map:

View file

@ -4,14 +4,15 @@ from app.arch_fc.arch_fc import process_state
from app.arch_fc.common import ChatMessage, Message
# test process_state
arch_state = '[[{"key": "cafbda799879e1dce6cd3de3c3e8a40052a93addec457bda0b2f21f8c86b3424", "message": {"role": "user", "content": "how is the weather in chicago?"}, "tool_call": {"name": "weather_forecast", "arguments": {"city": "Chicago"}}, "tool_response": "{\\"city\\":\\"Chicago\\",\\"temperature\\":[{\\"date\\":\\"2024-10-05\\",\\"temperature\\":{\\"min\\":51,\\"max\\":70},\\"query_time\\":\\"2024-10-05 08:18:00.264171+00:00\\"},{\\"date\\":\\"2024-10-06\\",\\"temperature\\":{\\"min\\":77,\\"max\\":88},\\"query_time\\":\\"2024-10-05 08:18:00.264186+00:00\\"},{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":66,\\"max\\":84},\\"query_time\\":\\"2024-10-05 08:18:00.264190+00:00\\"},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":77,\\"max\\":94},\\"query_time\\":\\"2024-10-05 08:18:00.264209+00:00\\"},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":76,\\"max\\":92},\\"query_time\\":\\"2024-10-05 08:18:00.264518+00:00\\"},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":56,\\"max\\":68},\\"query_time\\":\\"2024-10-05 08:18:00.264550+00:00\\"},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":73,\\"max\\":88},\\"query_time\\":\\"2024-10-05 08:18:00.264559+00:00\\"}],\\"unit\\":\\"F\\"}"}]]'
arch_state = '[[{"key":"02ea8ec721b130dc30ec836b79ec675116cd5889bca7d63720bc64baed994fc1","message":{"role":"user","content":"how is the weather in new york?"},"tool_call":{"name":"weather_forecast","arguments":{"city":"new york"}},"tool_response":"{\\"city\\":\\"new york\\",\\"temperature\\":[{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":68,\\"max\\":79}},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":70,\\"max\\":76}},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":71,\\"max\\":84}},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":61,\\"max\\":79}},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":86,\\"max\\":91}},{\\"date\\":\\"2024-10-12\\",\\"temperature\\":{\\"min\\":85,\\"max\\":90}},{\\"date\\":\\"2024-10-13\\",\\"temperature\\":{\\"min\\":72,\\"max\\":89}}],\\"unit\\":\\"F\\"}"}],[{"key":"566b9a2197cba89f35c1e3fbeee55882772ae7627fcf4411dae90282f98a1067","message":{"role":"user","content":"how is the weather in chicago?"},"tool_call":{"name":"weather_forecast","arguments":{"city":"chicago"}},"tool_response":"{\\"city\\":\\"chicago\\",\\"temperature\\":[{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":54,\\"max\\":64}},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":84,\\"max\\":99}},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":85,\\"max\\":100}},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":50,\\"max\\":62}},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":79,\\"max\\":85}},{\\"date\\":\\"2024-10-12\\",\\"temperature\\":{\\"min\\":88,\\"max\\":100}},{\\"date\\":\\"2024-10-13\\",\\"temperature\\":{\\"min\\":56,\\"max\\":61}}],\\"unit\\":\\"F\\"}"}]]'
def test_process_state():
history = []
history.append(Message(role="user", content="how is the weather in new york?"))
history.append(Message(role="user", content="how is the weather in chicago?"))
updated_history = process_state(arch_state, history)
print(json.dumps(updated_history, indent=2))
if __name__ == "__main__":
pytest.main()