diff --git a/arch/Cargo.lock b/arch/Cargo.lock
index 4886134a..388a329d 100644
--- a/arch/Cargo.lock
+++ b/arch/Cargo.lock
@@ -759,6 +759,7 @@ dependencies = [
  "serde_json",
  "serde_yaml",
  "serial_test",
+ "sha2",
  "thiserror",
  "tiktoken-rs",
 ]
diff --git a/arch/Cargo.toml b/arch/Cargo.toml
index e7a5d721..15fe482b 100644
--- a/arch/Cargo.toml
+++ b/arch/Cargo.toml
@@ -21,6 +21,7 @@ tiktoken-rs = "0.5.9"
 acap = "0.3.0"
 rand = "0.8.5"
 thiserror = "1.0.64"
+sha2 = "0.10.8"

 [dev-dependencies]
 proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
diff --git a/arch/Dockerfile b/arch/Dockerfile
index 60526054..af557211 100644
--- a/arch/Dockerfile
+++ b/arch/Dockerfile
@@ -10,7 +10,7 @@ COPY public_types /public_types
 RUN cargo build --release --target wasm32-wasi

 # copy built filter into envoy image
-FROM envoyproxy/envoy:v1.30-latest as envoy
+FROM envoyproxy/envoy:v1.31-latest as envoy

 #Build config generator, so that we have a single build image for both Rust and Python
 FROM python:3-slim as arch
diff --git a/arch/src/consts.rs b/arch/src/consts.rs
index 9b14c532..5cf0478e 100644
--- a/arch/src/consts.rs
+++ b/arch/src/consts.rs
@@ -12,4 +12,4 @@ pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
 pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
 pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
 pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
-// pub const ARCH_STATE_HEADER: &str = "x-arch-state";
+pub const ARCH_STATE_HEADER: &str = "x-arch-state";
diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs
index fdfe5be0..26ba4858 100644
--- a/arch/src/stream_context.rs
+++ b/arch/src/stream_context.rs
@@ -1,7 +1,7 @@
 use crate::consts::{
     ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
-    ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
-    DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
+    ARCH_STATE_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL,
+    DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
     RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
 };
 use crate::filter_context::{EmbeddingsStore, WasmMetrics};
@@ -15,9 +15,9 @@ use log::{debug, info, warn};
 use proxy_wasm::traits::*;
 use proxy_wasm::types::*;
 use public_types::common_types::open_ai::{
-    ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
+    ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
     ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
-    ParameterType, StreamOptions, ToolType,
+    ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType,
 };
 use public_types::common_types::{
     EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
@@ -28,6 +28,8 @@ use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
 use public_types::embeddings::{
     CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
 };
+use serde_json::Value;
+use sha2::{Digest, Sha256};
 use std::collections::HashMap;
 use std::num::NonZero;
 use std::rc::Rc;
@@ -59,10 +61,16 @@ pub struct StreamContext {
     embeddings_store: Rc<EmbeddingsStore>,
     overrides: Rc<Option<Overrides>>,
     callouts: HashMap<u32, CallContext>,
+    tool_calls: Option<Vec<ToolCall>>,
+    tool_call_response: Option<String>,
+    arch_state: Option<Vec<ArchState>>,
+    request_body_size: usize,
     ratelimit_selector: Option<Header>,
     streaming_response: bool,
+    user_prompt: Option<Message>,
     response_tokens: usize,
-    chat_completions_request: bool,
+    is_chat_completions_request: bool,
+    chat_completions_request: Option<ChatCompletionsRequest>,
     prompt_guards: Rc<PromptGuards>,
     llm_providers: Rc<LlmProviders>,
     llm_provider: Option<Rc<LlmProvider>>,
@@ -83,11 +91,17 @@ impl StreamContext {
             metrics,
             prompt_targets,
             embeddings_store,
+            chat_completions_request: None,
             callouts: HashMap::new(),
+            tool_calls: None,
+            tool_call_response: None,
+            arch_state: None,
+            request_body_size: 0,
             ratelimit_selector: None,
             streaming_response: false,
+            user_prompt: None,
             response_tokens: 0,
-            chat_completions_request: false,
+            is_chat_completions_request: false,
             llm_providers,
             llm_provider: None,
             prompt_guards,
@@ -463,13 +477,20 @@ impl StreamContext {
             });
         }

+        // archfc handler needs state so it can expand tool calls
+        let mut metadata = HashMap::new();
+        metadata.insert(
+            ARCH_STATE_HEADER.to_string(),
+            serde_json::to_string(&self.arch_state).unwrap(),
+        );
+
         let chat_completions = ChatCompletionsRequest {
             model: GPT_35_TURBO.to_string(),
             messages: callout_context.request_body.messages.clone(),
             tools: Some(chat_completion_tools),
             stream: false,
             stream_options: None,
-            metadata: None,
+            metadata: Some(metadata),
         };

         let msg_body = match serde_json::to_string(&chat_completions) {
@@ -521,10 +542,8 @@ impl StreamContext {
     }

     fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
-        debug!("response received for function resolver");
-
         let body_str = String::from_utf8(body).unwrap();
-        debug!("function_resolver response str: {}", body_str);
+        debug!("arch <= app response body: {}", body_str);

         let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
             Ok(arch_fc_response) => arch_fc_response,
@@ -559,7 +578,6 @@ impl StreamContext {
         let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
-        debug!("tool_call_details: {:?}", tool_calls);
         // extract all tool names
         let tool_names: Vec<String> = tool_calls
             .iter()
             .map(|tool_call| tool_call.function.name.clone())
             .collect();
@@ -581,8 +599,10 @@ impl StreamContext {
         let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();

-        debug!("prompt_target_name: {}", prompt_target.name);
-        debug!("tool_name(s): {:?}", tool_names);
+        debug!(
+            "prompt_target_name: {}, tool_name(s): {:?}",
+            prompt_target.name, tool_names
+        );
         debug!("tool_params: {}", tool_params_json_str);

         let endpoint = prompt_target.endpoint.unwrap();
@@ -611,6 +631,7 @@ impl StreamContext {
             }
         };

+        self.tool_calls = Some(tool_calls.clone());
         callout_context.upstream_cluster = Some(endpoint.name);
         callout_context.upstream_cluster_path = Some(path);
         callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
@@ -635,9 +656,9 @@ impl StreamContext {
         } else {
             warn!("http status code not found in api response");
         }
-        debug!("response received for function call response");
         let body_str: String = String::from_utf8(body).unwrap();
-        debug!("function_call_response response str: {}", body_str);
+        self.tool_call_response = Some(body_str.clone());
+        debug!("arch <= app response body: {}", body_str);
         let prompt_target_name = callout_context.prompt_target_name.unwrap();
         let prompt_target = self
             .prompt_targets
@@ -697,10 +718,7 @@ impl StreamContext {
                     .send_server_error(format!("Error serializing request_body: {:?}", e), None);
             }
         };
-        debug!(
-            "function_calling sending request to openai: msg {}",
-            json_string
-        );
+        debug!("arch => openai request body: {}", json_string);

         // Tokenize and Ratelimit.
         if let Some(selector) = self.ratelimit_selector.take() {
@@ -725,7 +743,7 @@ impl StreamContext {
             }
         }

-        self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
+        self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes());
         self.resume_http_request();
     }

@@ -881,7 +899,7 @@ impl StreamContext {
         };
         let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
         debug!("sending response back to default llm: {}", json_resp);
-        self.set_http_request_body(0, json_resp.len(), json_resp.as_bytes());
+        self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
         self.resume_http_request();
     }
 }
@@ -899,7 +917,7 @@ impl HttpContext for StreamContext {
         self.delete_content_length_header();
         self.save_ratelimit_header();

-        self.chat_completions_request =
+        self.is_chat_completions_request =
             self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH;

         debug!(
@@ -922,6 +940,8 @@ impl HttpContext for StreamContext {
             return Action::Continue;
         }

+        self.request_body_size = body_size;
+
         // Deserialize body into spec.
         // Currently OpenAI API.
         let mut deserialized_body: ChatCompletionsRequest =
@@ -948,6 +968,20 @@ impl HttpContext for StreamContext {
             }
         };

+        self.arch_state = match deserialized_body.metadata {
+            Some(ref metadata) => {
+                if metadata.contains_key(ARCH_STATE_HEADER) {
+                    let arch_state_str = metadata[ARCH_STATE_HEADER].clone();
+                    let arch_state: Vec<ArchState> = serde_json::from_str(&arch_state_str).unwrap();
+                    Some(arch_state)
+                } else {
+                    None
+                }
+            }
+            None => None,
+        };
+
+        self.is_chat_completions_request = true;
         // Set the model based on the chosen LLM Provider
         deserialized_body.model = String::from(&self.llm_provider().model);

@@ -958,10 +992,11 @@ impl HttpContext for StreamContext {
             });
         }

-        let user_message = match deserialized_body
+        let last_user_prompt = match deserialized_body
             .messages
+            .iter()
+            .filter(|msg| msg.role == USER_ROLE)
             .last()
-            .and_then(|last_message| last_message.content.clone())
         {
             Some(content) => content,
             None => {
@@ -970,17 +1005,24 @@ impl HttpContext for StreamContext {
             }
         };

+        self.user_prompt = Some(last_user_prompt.clone());
+
+        let user_message_str = self.user_prompt.as_ref().unwrap().content.clone();
+
         let prompt_guard_jailbreak_task = self
             .prompt_guards
             .input_guards
             .contains_key(&public_types::configuration::GuardType::Jailbreak);
+
+        self.chat_completions_request = Some(deserialized_body);
+
         if !prompt_guard_jailbreak_task {
Making inline call to retrieve"); let callout_context = CallContext { response_handler_type: ResponseHandlerType::ArchGuard, - user_message: Some(user_message), + user_message: user_message_str.clone(), prompt_target_name: None, - request_body: deserialized_body, + request_body: self.chat_completions_request.as_ref().unwrap().clone(), similarity_scores: None, upstream_cluster: None, upstream_cluster_path: None, @@ -990,7 +1032,14 @@ impl HttpContext for StreamContext { } let get_prompt_guards_request = PromptGuardRequest { - input: user_message.clone(), + input: self + .user_prompt + .as_ref() + .unwrap() + .content + .as_ref() + .unwrap() + .clone(), task: PromptGuardTask::Jailbreak, }; @@ -1032,9 +1081,9 @@ impl HttpContext for StreamContext { let call_context = CallContext { response_handler_type: ResponseHandlerType::ArchGuard, - user_message: Some(user_message), + user_message: self.user_prompt.as_ref().unwrap().content.clone(), prompt_target_name: None, - request_body: deserialized_body, + request_body: self.chat_completions_request.as_ref().unwrap().clone(), similarity_scores: None, upstream_cluster: None, upstream_cluster_path: None, @@ -1057,7 +1106,7 @@ impl HttpContext for StreamContext { self.context_id, body_size, end_of_stream ); - if !self.chat_completions_request { + if !self.is_chat_completions_request { if let Some(body_str) = self .get_http_response_body(0, body_size) .and_then(|bytes| String::from_utf8(bytes).ok()) @@ -1067,7 +1116,7 @@ impl HttpContext for StreamContext { return Action::Continue; } - if !end_of_stream && !self.streaming_response { + if !end_of_stream { return Action::Pause; } @@ -1075,9 +1124,8 @@ impl HttpContext for StreamContext { .get_http_response_body(0, body_size) .expect("cant get response body"); - let body_str = String::from_utf8(body).expect("body is not utf-8"); - if self.streaming_response { + let body_str = String::from_utf8(body).expect("body is not utf-8"); debug!("streaming response"); let chat_completions_data = match body_str.split_once("data: ") { Some((_, chat_completions_data)) => chat_completions_data, @@ -1117,13 +1165,14 @@ impl HttpContext for StreamContext { } else { debug!("non streaming response"); let chat_completions_response: ChatCompletionsResponse = - match serde_json::from_str(&body_str) { + match serde_json::from_slice(&body) { Ok(de) => de, Err(e) => { self.send_server_error( format!( "error in non-streaming response: {}\n response was={}", - e, body_str + e, + String::from_utf8(body).unwrap() ), None, ); @@ -1132,6 +1181,65 @@ impl HttpContext for StreamContext { }; self.response_tokens += chat_completions_response.usage.completion_tokens; + + if let Some(tool_calls) = self.tool_calls.as_ref() { + if !tool_calls.is_empty() { + if self.arch_state.is_none() { + self.arch_state = Some(Vec::new()); + } + + // compute sha hash from message history + let mut hasher = Sha256::new(); + let prompts: Vec = self + .chat_completions_request + .as_ref() + .unwrap() + .messages + .iter() + .filter(|msg| msg.role == USER_ROLE) + .map(|msg| msg.content.clone().unwrap()) + .collect(); + let prompts_merged = prompts.join("#.#"); + hasher.update(prompts_merged.clone()); + let hash_key = hasher.finalize(); + // conver hash to hex string + let hash_key_str = format!("{:x}", hash_key); + debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged); + + // create new tool call state + let tool_call_state = ToolCallState { + key: hash_key_str, + message: self.user_prompt.clone(), + tool_call: tool_calls[0].function.clone(), + 
+                        tool_response: self.tool_call_response.clone().unwrap(),
+                    };
+
+                    // push tool call state to arch state
+                    self.arch_state
+                        .as_mut()
+                        .unwrap()
+                        .push(ArchState::ToolCall(vec![tool_call_state]));
+
+                    let mut data: Value = serde_json::from_slice(&body).unwrap();
+                    // use serde_json::Value to manipulate the json object and ensure that we don't lose any data
+                    if let Value::Object(ref mut map) = data {
+                        // serialize arch state and add to metadata
+                        let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
+                        debug!("arch_state: {}", arch_state_str);
+                        let metadata = map
+                            .entry("metadata")
+                            .or_insert(Value::Object(serde_json::Map::new()));
+                        metadata.as_object_mut().unwrap().insert(
+                            ARCH_STATE_HEADER.to_string(),
+                            serde_json::Value::String(arch_state_str),
+                        );
+
+                        let data_serialized = serde_json::to_string(&data).unwrap();
+                        debug!("arch => user: {}", data_serialized);
+                        self.set_http_response_body(0, body_size, data_serialized.as_bytes());
+                    };
+                }
+            }
         }

         debug!(
diff --git a/arch/tests/integration.rs b/arch/tests/integration.rs
index 53fbf215..6b36347f 100644
--- a/arch/tests/integration.rs
+++ b/arch/tests/integration.rs
@@ -571,9 +571,6 @@ fn request_ratelimited() {
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
-        .expect_log(Some(LogLevel::Debug), None)
-        .expect_log(Some(LogLevel::Debug), None)
-        .expect_log(Some(LogLevel::Debug), None)
         .expect_http_call(
             Some("api_server"),
             Some(vec![
@@ -592,14 +589,15 @@ fn request_ratelimited() {
         .execute_and_expect(ReturnType::None)
         .unwrap();

+    let response_headers_with_200 = vec![(":status", "200"), ("content-type", "application/json")];
     let body_text = String::from("test body");
     module
         .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
         .expect_metric_increment("active_http_calls", -1)
         .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
         .returning(Some(&body_text))
-        .expect_log(Some(LogLevel::Warn), None)
-        .expect_log(Some(LogLevel::Debug), None)
+        .expect_get_header_map_pairs(Some(MapType::HttpCallResponseHeaders))
+        .returning(Some(response_headers_with_200))
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
@@ -612,10 +610,6 @@ fn request_ratelimited() {
             None,
         )
         .expect_metric_increment("ratelimited_rq", 1)
-        .expect_log(
-            Some(LogLevel::Debug),
-            Some("server error occurred: Exceeded Ratelimit: Not allowed"),
-        )
         .execute_and_expect(ReturnType::None)
         .unwrap();
 }
@@ -685,9 +679,6 @@ fn request_not_ratelimited() {
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
-        .expect_log(Some(LogLevel::Debug), None)
-        .expect_log(Some(LogLevel::Debug), None)
-        .expect_log(Some(LogLevel::Debug), None)
         .expect_http_call(
             Some("api_server"),
             Some(vec![
@@ -706,15 +697,16 @@ fn request_not_ratelimited() {
         .execute_and_expect(ReturnType::None)
         .unwrap();

+    let response_headers_with_200 = vec![(":status", "200"), ("content-type", "application/json")];
+
     let body_text = String::from("test body");
     module
         .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
         .expect_metric_increment("active_http_calls", -1)
         .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
         .returning(Some(&body_text))
-        .expect_log(Some(LogLevel::Warn), None)
-        .expect_log(Some(LogLevel::Debug), None)
-        .expect_log(Some(LogLevel::Debug), None)
+        .expect_get_header_map_pairs(Some(MapType::HttpCallResponseHeaders))
+        .returning(Some(response_headers_with_200))
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py
index 75f5f295..eba24016 100644
--- a/chatbot_ui/app/run.py
+++ b/chatbot_ui/app/run.py
@@ -11,6 +11,7 @@ OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
 MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
 CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
 MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
+ARCH_STATE_HEADER = 'x-arch-state'

 log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)

@@ -32,7 +33,7 @@ def predict(message, state):
     metadata = None
     if 'arch_state' in state:
-        metadata = {"x-arch-state": state['arch_state']}
+        metadata = {ARCH_STATE_HEADER: state['arch_state']}

     try:
         raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME,
@@ -48,11 +49,12 @@ def predict(message, state):
         log.info("Error calling gateway API: {}".format(e.message))
         raise gr.Error("Error calling gateway API: {}".format(e.message))

+    log.debug("raw_response: ", raw_response.text)
     response = raw_response.parse()

     # extract arch_state from metadata and store it in gradio session state
     # this state must be passed back to the gateway in the next request
-    arch_state = json.loads(raw_response.text).get('metadata', {}).get('x-arch-state', None)
+    arch_state = json.loads(raw_response.text).get('metadata', {}).get(ARCH_STATE_HEADER, None)
     if arch_state:
         state['arch_state'] = arch_state
diff --git a/model_server/.vscode/launch.json b/model_server/.vscode/launch.json
index 23828ad7..3e8a10a7 100644
--- a/model_server/.vscode/launch.json
+++ b/model_server/.vscode/launch.json
@@ -10,6 +10,9 @@
             "request": "launch",
             "module": "uvicorn",
             "args": ["app.main:app","--reload", "--port", "8000"],
+            "env": {
+                "MODE": "local-cpu",
+            }
         }
     ]
 }
diff --git a/model_server/app/arch_fc/arch_fc.py b/model_server/app/arch_fc/arch_fc.py
index 8b1a3335..e90b6531 100644
--- a/model_server/app/arch_fc/arch_fc.py
+++ b/model_server/app/arch_fc/arch_fc.py
@@ -69,7 +69,8 @@ def process_state(arch_state, history: list[Message]):
         if hist.role == 'user':
             sha_history.append(hist.content)
     sha256_hash = hashlib.sha256()
-    sha256_hash.update(json.dumps(sha_history).encode())
+    joined_key_str = ('#.#').join(sha_history)
+    sha256_hash.update(joined_key_str.encode())
     sha_key = sha256_hash.hexdigest()
     print(f"sha_key: {sha_key}")
     if sha_key in state_map:
diff --git a/model_server/app/arch_fc/test_arch_fc.py b/model_server/app/arch_fc/test_arch_fc.py
index 0eb409d0..caf5550c 100644
--- a/model_server/app/arch_fc/test_arch_fc.py
+++ b/model_server/app/arch_fc/test_arch_fc.py
@@ -4,14 +4,15 @@
 from app.arch_fc.arch_fc import process_state
 from app.arch_fc.common import ChatMessage, Message

 # test process_state
-arch_state = '[[{"key": "cafbda799879e1dce6cd3de3c3e8a40052a93addec457bda0b2f21f8c86b3424", "message": {"role": "user", "content": "how is the weather in chicago?"}, "tool_call": {"name": "weather_forecast", "arguments": {"city": "Chicago"}}, "tool_response": "{\\"city\\":\\"Chicago\\",\\"temperature\\":[{\\"date\\":\\"2024-10-05\\",\\"temperature\\":{\\"min\\":51,\\"max\\":70},\\"query_time\\":\\"2024-10-05 08:18:00.264171+00:00\\"},{\\"date\\":\\"2024-10-06\\",\\"temperature\\":{\\"min\\":77,\\"max\\":88},\\"query_time\\":\\"2024-10-05 08:18:00.264186+00:00\\"},{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":66,\\"max\\":84},\\"query_time\\":\\"2024-10-05 08:18:00.264190+00:00\\"},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":77,\\"max\\":94},\\"query_time\\":\\"2024-10-05 08:18:00.264209+00:00\\"},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":76,\\"max\\":92},\\"query_time\\":\\"2024-10-05 08:18:00.264518+00:00\\"},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":56,\\"max\\":68},\\"query_time\\":\\"2024-10-05 08:18:00.264550+00:00\\"},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":73,\\"max\\":88},\\"query_time\\":\\"2024-10-05 08:18:00.264559+00:00\\"}],\\"unit\\":\\"F\\"}"}]]'
-
+arch_state = '[[{"key":"02ea8ec721b130dc30ec836b79ec675116cd5889bca7d63720bc64baed994fc1","message":{"role":"user","content":"how is the weather in new york?"},"tool_call":{"name":"weather_forecast","arguments":{"city":"new york"}},"tool_response":"{\\"city\\":\\"new york\\",\\"temperature\\":[{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":68,\\"max\\":79}},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":70,\\"max\\":76}},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":71,\\"max\\":84}},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":61,\\"max\\":79}},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":86,\\"max\\":91}},{\\"date\\":\\"2024-10-12\\",\\"temperature\\":{\\"min\\":85,\\"max\\":90}},{\\"date\\":\\"2024-10-13\\",\\"temperature\\":{\\"min\\":72,\\"max\\":89}}],\\"unit\\":\\"F\\"}"}],[{"key":"566b9a2197cba89f35c1e3fbeee55882772ae7627fcf4411dae90282f98a1067","message":{"role":"user","content":"how is the weather in chicago?"},"tool_call":{"name":"weather_forecast","arguments":{"city":"chicago"}},"tool_response":"{\\"city\\":\\"chicago\\",\\"temperature\\":[{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":54,\\"max\\":64}},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":84,\\"max\\":99}},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":85,\\"max\\":100}},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":50,\\"max\\":62}},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":79,\\"max\\":85}},{\\"date\\":\\"2024-10-12\\",\\"temperature\\":{\\"min\\":88,\\"max\\":100}},{\\"date\\":\\"2024-10-13\\",\\"temperature\\":{\\"min\\":56,\\"max\\":61}}],\\"unit\\":\\"F\\"}"}]]'
 def test_process_state():
     history = []
+    history.append(Message(role="user", content="how is the weather in new york?"))
     history.append(Message(role="user", content="how is the weather in chicago?"))
     updated_history = process_state(arch_state, history)
     print(json.dumps(updated_history, indent=2))
+

 if __name__ == "__main__":
     pytest.main()