From b31a7a569ac7bb977a6ee6957cd8a6a1dc8ecf18 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Fri, 28 Mar 2025 03:04:12 -0700 Subject: [PATCH 1/8] update rest and other parts of the code to work with arch fc 1.1 --- crates/common/src/consts.rs | 7 +- crates/llm_gateway/src/stream_context.rs | 2 +- crates/prompt_gateway/src/http_context.rs | 45 +++++-- crates/prompt_gateway/src/stream_context.rs | 139 +++++++++++++++++--- demos/shared/chatbot_ui/common.py | 5 +- model_server/src/core/function_calling.py | 14 +- model_server/src/core/utils/model_utils.py | 16 ++- model_server/src/main.py | 15 +-- 8 files changed, 196 insertions(+), 47 deletions(-) diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index cd52220e..8a12f1d1 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -13,8 +13,11 @@ pub const MESSAGES_KEY: &str = "messages"; pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint"; pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; pub const HEALTHZ_PATH: &str = "/healthz"; -pub const ARCH_STATE_HEADER: &str = "x-arch-state"; -pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B"; +pub const X_ARCH_STATE_HEADER: &str = "x-arch-state"; +pub const X_ARCH_API_RESPONSE: &str = "x-arch-api-response-message"; +pub const X_ARCH_TOOL_CALL: &str = "x-arch-tool-call-message"; +pub const X_ARCH_FC_MODEL_RESPONSE: &str = "x-arch-fc-model-response"; +pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function"; pub const REQUEST_ID_HEADER: &str = "x-request-id"; pub const TRACE_PARENT_HEADER: &str = "traceparent"; pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal"; diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 78d7e21e..9a6e3985 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -411,7 +411,7 @@ impl HttpContext for StreamContext { ); if self.request_body_sent_time.is_none() { - debug!("on_http_response_body: request body not sent, no doing any processing in llm filter"); + debug!("on_http_response_body: request body not sent, not doing any processing in llm filter"); return Action::Continue; } diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index 3a7dc7d9..ac747a6b 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -4,10 +4,11 @@ use common::{ self, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest, }, consts::{ - ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, + ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH, MODEL_SERVER_NAME, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER, TOOL_ROLE, - TRACE_PARENT_HEADER, USER_ROLE, + TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_API_RESPONSE, X_ARCH_FC_MODEL_RESPONSE, + X_ARCH_STATE_HEADER, X_ARCH_TOOL_CALL, }, errors::ServerError, http::{CallArgs, Client}, @@ -125,8 +126,8 @@ impl HttpContext for StreamContext { self.arch_state = match deserialized_body.metadata { Some(ref metadata) => { - if metadata.contains_key(ARCH_STATE_HEADER) { - let arch_state_str = metadata[ARCH_STATE_HEADER].clone(); + if metadata.contains_key(X_ARCH_STATE_HEADER) { + let arch_state_str = metadata[X_ARCH_STATE_HEADER].clone(); let arch_state: Vec = serde_json::from_str(&arch_state_str).unwrap(); Some(arch_state) } else { @@ -336,10 +337,10 @@ impl HttpContext for StreamContext { if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() { let chunks = vec![ ChatCompletionStreamResponse::new( - None, + self.arch_fc_response.clone(), Some(ASSISTANT_ROLE.to_string()), Some(ARCH_FC_MODEL_NAME.to_string()), - self.tool_calls.to_owned(), + None, ), ChatCompletionStreamResponse::new( self.tool_call_response.clone(), @@ -381,17 +382,39 @@ impl HttpContext for StreamContext { *metadata = Value::Object(serde_json::Map::new()); } - let fc_messages = vec![ - self.generate_toll_call_message(), - self.generate_api_response_message(), - ]; + let tool_call_message = self.generate_toll_call_message(); + let tool_call_message_str = serde_json::to_string(&tool_call_message).unwrap(); + metadata.as_object_mut().unwrap().insert( + X_ARCH_TOOL_CALL.to_string(), + serde_json::Value::String(tool_call_message_str), + ); + + let api_response_message = self.generate_api_response_message(); + let api_response_message_str = + serde_json::to_string(&api_response_message).unwrap(); + metadata.as_object_mut().unwrap().insert( + X_ARCH_API_RESPONSE.to_string(), + serde_json::Value::String(api_response_message_str), + ); + + let fc_messages = vec![tool_call_message, api_response_message]; + let fc_messages_str = serde_json::to_string(&fc_messages).unwrap(); let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]); let arch_state_str = serde_json::to_string(&arch_state).unwrap(); metadata.as_object_mut().unwrap().insert( - ARCH_STATE_HEADER.to_string(), + X_ARCH_STATE_HEADER.to_string(), serde_json::Value::String(arch_state_str), ); + + if let Some(arch_fc_response) = self.arch_fc_response.as_ref() { + metadata.as_object_mut().unwrap().insert( + X_ARCH_FC_MODEL_RESPONSE.to_string(), + serde_json::Value::String( + serde_json::to_string(arch_fc_response).unwrap(), + ), + ); + } let data_serialized = serde_json::to_string(&data).unwrap(); info!("archgw <= developer: {}", data_serialized); self.set_http_response_body(0, body_size, data_serialized.as_bytes()); diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index ec68c104..b2950d2f 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -9,6 +9,7 @@ use common::consts::{ API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE, + X_ARCH_FC_MODEL_RESPONSE, }; use common::errors::ServerError; use common::http::{CallArgs, Client}; @@ -64,10 +65,10 @@ pub struct StreamContext { pub time_to_first_token: Option, pub traceparent: Option, pub _tracing: Rc>, + pub arch_fc_response: Option, } impl StreamContext { - #[allow(clippy::too_many_arguments)] pub fn new( context_id: u32, metrics: Rc, @@ -98,6 +99,7 @@ impl StreamContext { _tracing: tracing, start_upstream_llm_request_time: 0, time_to_first_token: None, + arch_fc_response: None, } } @@ -142,15 +144,17 @@ impl StreamContext { } }; - // intent was matched if we see function_latency in metadata - let intent_matched = model_server_response + let intent_matched = check_intent_matched(&model_server_response); + info!("intent matched: {}", intent_matched); + + self.arch_fc_response = model_server_response .metadata .as_ref() - .and_then(|metadata| metadata.get("function_latency")) - .is_some(); + .and_then(|metadata| metadata.get(X_ARCH_FC_MODEL_RESPONSE)) + .cloned(); + if !intent_matched { - info!("intent not matched"); // check if we have a default prompt target if let Some(default_prompt_target) = self .prompt_targets @@ -278,9 +282,9 @@ impl StreamContext { let direct_response_str = if self.streaming_response { let chunks = vec![ ChatCompletionStreamResponse::new( - None, + self.arch_fc_response.clone(), Some(ASSISTANT_ROLE.to_string()), - Some(ARCH_FC_MODEL_NAME.to_owned()), + Some(ARCH_FC_MODEL_NAME.to_string()), None, ), ChatCompletionStreamResponse::new( @@ -293,7 +297,7 @@ impl StreamContext { .clone(), ), None, - Some(ARCH_FC_MODEL_NAME.to_owned()), + Some(format!("{}-Chat", ARCH_FC_MODEL_NAME.to_owned())), None, ), ]; @@ -624,12 +628,23 @@ impl StreamContext { } pub fn generate_toll_call_message(&mut self) -> Message { - Message { - role: ASSISTANT_ROLE.to_string(), - content: None, - model: Some(ARCH_FC_MODEL_NAME.to_string()), - tool_calls: self.tool_calls.clone(), - tool_call_id: None, + if self.arch_fc_response.is_none() { + info!("arch_fc_response is none, generating tool call message"); + Message { + role: ASSISTANT_ROLE.to_string(), + content: None, + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: self.tool_calls.clone(), + tool_call_id: None, + } + } else { + Message { + role: ASSISTANT_ROLE.to_string(), + content: self.arch_fc_response.as_ref().cloned(), + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: None, + tool_call_id: None, + } } } @@ -761,6 +776,26 @@ impl StreamContext { } } +fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool { + let content = model_server_response + .choices + .get(0) + .and_then(|choice| choice.message.content.as_ref()); + + let content_has_value = content.is_some() && !content.unwrap().is_empty(); + + let tool_calls = model_server_response + .choices + .get(0) + .and_then(|choice| choice.message.tool_calls.as_ref()); + + // intent was matched if content has some value or tool_calls is empty + let intent_matched = + content_has_value || (tool_calls.is_some() && !tool_calls.unwrap().is_empty()); + + return intent_matched; +} + impl Client for StreamContext { type CallContext = StreamCallContext; @@ -772,3 +807,77 @@ impl Client for StreamContext { &self.metrics.active_http_calls } } + +#[cfg(test)] +mod test { + use common::api::open_ai::{ChatCompletionsResponse, Choice, Message, ToolCall}; + + use crate::stream_context::check_intent_matched; + + #[test] + fn test_intent_matched() { + let model_server_response = ChatCompletionsResponse { + choices: vec![Choice { + message: Message { + content: Some("".to_string()), + tool_calls: Some(vec![]), + role: "assistant".to_string(), + model: None, + tool_call_id: None, + }, + finish_reason: None, + index: None, + }], + usage: None, + model: "arch-fc".to_string(), + metadata: None, + }; + + assert_eq!(check_intent_matched(&model_server_response), false); + + let model_server_response = ChatCompletionsResponse { + choices: vec![Choice { + message: Message { + content: Some("hello".to_string()), + tool_calls: Some(vec![]), + role: "assistant".to_string(), + model: None, + tool_call_id: None, + }, + finish_reason: None, + index: None, + }], + usage: None, + model: "arch-fc".to_string(), + metadata: None, + }; + + assert_eq!(check_intent_matched(&model_server_response), true); + + let model_server_response = ChatCompletionsResponse { + choices: vec![Choice { + message: Message { + content: Some("".to_string()), + tool_calls: Some(vec![ToolCall { + id: "1".to_string(), + function: common::api::open_ai::FunctionCallDetail { + name: "test".to_string(), + arguments: None, + }, + tool_type: common::api::open_ai::ToolType::Function, + }]), + role: "assistant".to_string(), + model: None, + tool_call_id: None, + }, + finish_reason: None, + index: None, + }], + usage: None, + model: "arch-fc".to_string(), + metadata: None, + }; + + assert_eq!(check_intent_matched(&model_server_response), true); + } +} diff --git a/demos/shared/chatbot_ui/common.py b/demos/shared/chatbot_ui/common.py index 42e50bd4..1de8f94c 100644 --- a/demos/shared/chatbot_ui/common.py +++ b/demos/shared/chatbot_ui/common.py @@ -120,8 +120,11 @@ def process_stream_chunk(chunk, history): if delta.content: # append content to the last history item - history[-1]["content"] = history[-1].get("content", "") + delta.content + if history[-1]["model"] != "Arch-Function-Chat": + history[-1]["content"] = history[-1].get("content", "") + delta.content # yield content if it is from assistant + if history[-1]["model"] == "Arch-Function": + return None if history[-1]["role"] == "assistant": return delta.content diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index f5b2cc44..75673f66 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -197,12 +197,12 @@ class ArchFunctionHandler(ArchBaseHandler): response_dict["response"] = model_response.get("response", "") response_dict["required_functions"] = model_response.get( - "required_functions", "" + "required_functions", [] ) response_dict["clarification"] = model_response.get("clarification", "") for tool_call in model_response.get("tool_calls", []): - response_dict["tool_call"].append( + response_dict["tool_calls"].append( { "id": f"call_{random.randint(1000, 10000)}", "type": "function", @@ -448,6 +448,7 @@ class ArchFunctionHandler(ArchBaseHandler): if len(chunk.choices) > 0 and chunk.choices[0].delta.content: model_response += chunk.choices[0].delta.content + logger.info(f"[arch-fc]: raw model response: {model_response}") # Extract tool calls from model response response_dict = self._parse_model_resonse(model_response) @@ -499,10 +500,15 @@ class ArchFunctionHandler(ArchBaseHandler): model_message = Message(content="", tool_calls=[]) chat_completion_response = ChatCompletionResponse( - choices=[Choice(message=model_message)], model=self.model_name + choices=[Choice(message=model_message)], + model=self.model_name, + metadata={"x-arch-fc-model-response": model_response}, + role="assistant", ) - logger.info(f"[response]: {json.dumps(chat_completion_response.model_dump())}") + logger.info( + f"[response arch-fc]: {json.dumps(chat_completion_response.model_dump())}" + ) return chat_completion_response diff --git a/model_server/src/core/utils/model_utils.py b/model_server/src/core/utils/model_utils.py index 0c42c796..2dfd2237 100644 --- a/model_server/src/core/utils/model_utils.py +++ b/model_server/src/core/utils/model_utils.py @@ -142,7 +142,7 @@ class ArchBaseHandler: {"role": "system", "content": self._format_system_prompt(tools)} ) - for message in messages: + for idx, message in enumerate(messages): role, content, tool_calls = ( message.role, message.content, @@ -158,9 +158,17 @@ class ArchBaseHandler: if metadata.get("optimize_context_window", "false").lower() == "true": content = f"\n\n" else: - content = ( - f"\n{json.dumps(content)}\n" - ) + # sample response below + # "content": "\n{'name': 'get_stock_price', 'result': '$196.66'}\n" + # msg[idx-1] contains tool call = '{"tool_calls": [{"name": "currency_exchange", "arguments": {"currency_symbol": "NZD"}}]}' + func_name = json.loads(messages[idx - 1].content)["tool_calls"][ + 0 + ].get("name", "no_name") + tool_response = { + "name": func_name, + "result": content, + } + content = f"\n{json.dumps(tool_response)}\n" processed_messages.append({"role": role, "content": content}) diff --git a/model_server/src/main.py b/model_server/src/main.py index b0434222..c3070392 100644 --- a/model_server/src/main.py +++ b/model_server/src/main.py @@ -87,16 +87,15 @@ async def function_calling(req: ChatMessage, res: Response): final_response = await model_handler.chat_completion(req) latency = time.perf_counter() - start_time + if not final_response.metadata: + final_response.metadata = {} + # Parameter gathering for detected intents if final_response.choices[0].message.content: - final_response.metadata = { - "function_latency": str(round(latency * 1000, 3)), - } + final_response.metadata["function_latency"] = str(round(latency * 1000, 3)) # Function Calling elif final_response.choices[0].message.tool_calls: - final_response.metadata = { - "function_latency": str(round(latency * 1000, 3)), - } + final_response.metadata["function_latency"] = str(round(latency * 1000, 3)) # ********************************************************************************************* # TODO: Put the following code back when hallucination check is ready @@ -107,9 +106,7 @@ async def function_calling(req: ChatMessage, res: Response): # ) # No intent detected else: - final_response.metadata = { - "intent_latency": str(round(latency * 1000, 3)), - } + final_response.metadata["intent_latency"] = str(round(latency * 1000, 3)) if not use_agent_orchestrator: final_response.metadata["intent_latency"] = str(round(latency * 1000, 3)) From 066182391f4971a9f65f5a73f40424ed70c49a5a Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Fri, 28 Mar 2025 03:34:34 -0700 Subject: [PATCH 2/8] fix tests --- tests/archgw/common.py | 13 +++++++------ tests/modelserver/test_hallucination.py | 3 +++ tests/modelserver/test_modelserver.py | 3 +++ 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/archgw/common.py b/tests/archgw/common.py index 404d8ef9..f87f6403 100644 --- a/tests/archgw/common.py +++ b/tests/archgw/common.py @@ -47,14 +47,11 @@ TEST_CASE_FIXTURES = { "tool_call_id": "", "tool_calls": [ { - "id": "call_6009", + "id": "call_2925", "type": "function", "function": { "name": "get_current_weather", - "arguments": { - "location": "Seattle, WA", - "days": "2", - }, + "arguments": {"location": "Seattle", "days": "2"}, }, } ], @@ -63,7 +60,11 @@ TEST_CASE_FIXTURES = { } ], "model": "Arch-Function", - "metadata": {"intent_latency": "455.092", "function_latency": "312.744"}, + "metadata": { + "x-arch-fc-model-response": '{"tool_calls": [{"name": "get_current_weather", "arguments": {"location": "Seattle", "days": "2"}}]}', + "function_latency": "361.841", + "intent_latency": "361.841", + }, }, "api_server_response": [ { diff --git a/tests/modelserver/test_hallucination.py b/tests/modelserver/test_hallucination.py index f1a3d9b4..323db3fc 100644 --- a/tests/modelserver/test_hallucination.py +++ b/tests/modelserver/test_hallucination.py @@ -4,6 +4,9 @@ import requests import logging import yaml +pytestmark = pytest.mark.skip( + reason="Skipping entire test file as hallucination is not enabled for archfc 1.1 yet" +) MODEL_SERVER_ENDPOINT = os.getenv( "MODEL_SERVER_ENDPOINT", "http://localhost:51000/function_calling" diff --git a/tests/modelserver/test_modelserver.py b/tests/modelserver/test_modelserver.py index 75e6d27e..4596606f 100644 --- a/tests/modelserver/test_modelserver.py +++ b/tests/modelserver/test_modelserver.py @@ -5,6 +5,9 @@ import yaml from deepdiff import DeepDiff +pytestmark = pytest.mark.skip( + reason="Skipping entire test file as this these tests are heavily dependent on model output" +) MODEL_SERVER_ENDPOINT = os.getenv( "MODEL_SERVER_ENDPOINT", "http://localhost:51000/function_calling" From e2edeedf555348ea4c05917c24d0228a4f73aa19 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Fri, 28 Mar 2025 03:47:09 -0700 Subject: [PATCH 3/8] fix rust tests --- crates/prompt_gateway/tests/integration.rs | 32 +++++----------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index bbde10b0..91b36c01 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -380,6 +380,7 @@ fn prompt_gateway_request_to_llm_gateway() { .expect_log(Some(LogLevel::Warn), None) .expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Debug), None) @@ -453,6 +454,7 @@ fn prompt_gateway_request_to_llm_gateway() { .expect_log(Some(LogLevel::Info), None) .expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None) .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); @@ -493,19 +495,9 @@ fn prompt_gateway_request_no_intent_match() { finish_reason: Some("test".to_string()), index: Some(0), message: Message { - role: "system".to_string(), + role: "assistant".to_string(), content: None, - tool_calls: Some(vec![ToolCall { - id: String::from("test"), - tool_type: ToolType::Function, - function: FunctionCallDetail { - name: String::from("weather_forecast"), - arguments: Some(HashMap::from([( - String::from("city"), - Value::String(String::from("seattle")), - )])), - }, - }]), + tool_calls: None, model: None, tool_call_id: None, }, @@ -523,7 +515,7 @@ fn prompt_gateway_request_no_intent_match() { .expect_log(Some(LogLevel::Warn), None) .expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Info), Some("intent not matched")) + .expect_log(Some(LogLevel::Info), Some("intent matched: false")) .expect_log( Some(LogLevel::Info), Some("no default prompt target found, forwarding request to upstream llm"), @@ -651,17 +643,7 @@ fn prompt_gateway_request_no_intent_match_default_target() { message: Message { role: "system".to_string(), content: None, - tool_calls: Some(vec![ToolCall { - id: String::from("test"), - tool_type: ToolType::Function, - function: FunctionCallDetail { - name: String::from("weather_forecast"), - arguments: Some(HashMap::from([( - String::from("city"), - Value::String(String::from("seattle")), - )])), - }, - }]), + tool_calls: None, model: None, tool_call_id: None, }, @@ -679,7 +661,7 @@ fn prompt_gateway_request_no_intent_match_default_target() { .expect_log(Some(LogLevel::Warn), None) .expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Info), Some("intent not matched")) + .expect_log(Some(LogLevel::Info), Some("intent matched: false")) .expect_log( Some(LogLevel::Info), Some("default prompt target found, forwarding request to default prompt target"), From e48918259eececa26f9f2527b80cd8c082c23ebb Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Fri, 28 Mar 2025 03:58:57 -0700 Subject: [PATCH 4/8] fix tests --- tests/e2e/test_prompt_gateway.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/e2e/test_prompt_gateway.py b/tests/e2e/test_prompt_gateway.py index f122ad30..d65aae08 100644 --- a/tests/e2e/test_prompt_gateway.py +++ b/tests/e2e/test_prompt_gateway.py @@ -42,9 +42,11 @@ def test_prompt_gateway(stream): assert "role" in choices[0]["delta"] role = choices[0]["delta"]["role"] assert role == "assistant" - tool_calls = choices[0].get("delta", {}).get("tool_calls", []) + print(f"choices: {choices}") + tool_call_str = choices[0].get("delta", {}).get("content", "") + tool_calls = json.loads(tool_call_str).get("tool_calls", []) assert len(tool_calls) > 0 - tool_call = tool_calls[0]["function"] + tool_call = tool_calls[0] location = tool_call["arguments"]["location"] assert expected_tool_call["arguments"]["location"] in location.lower() del expected_tool_call["arguments"]["location"] From a3f2b3cef928605865c1e2a7efe3a2ee4e5580e8 Mon Sep 17 00:00:00 2001 From: CTran Date: Fri, 28 Mar 2025 09:49:20 -0700 Subject: [PATCH 5/8] add hallucination modification (#455) * add hallucination modification * disable test --- model_server/src/core/function_calling.py | 66 +++++++++---------- .../src/core/utils/hallucination_utils.py | 45 +++++-------- model_server/src/main.py | 14 ++-- .../tests/core/test_function_calling.py | 56 ++++++++-------- 4 files changed, 86 insertions(+), 95 deletions(-) diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 75673f66..847861e9 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -406,47 +406,47 @@ class ArchFunctionHandler(ArchBaseHandler): # ********************************************************************************************* # initialize the hallucination handler, which is an iterator + self.hallucination_state = HallucinationState( + response_iterator=response, function=req.tools + ) - # self.hallucination_state = HallucinationState( - # response_iterator=response, function=req.tools - # ) + has_tool_calls, has_hallucination = None, False + for _ in self.hallucination_state: + # check if the first token is + if len(self.hallucination_state.tokens) > 2 and has_tool_calls is None: + content = ''.join(self.hallucination_state.tokens) + if "tool_calls" in content: + has_tool_calls = True + else: + has_tool_calls = False + break - # has_tool_calls, has_hallucination = None, False - # for _ in self.hallucination_state: - # # check if the first token is - # if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None: - # if self.hallucination_state.tokens[0] == "": - # has_tool_calls = True - # else: - # has_tool_calls = False - # break + # if the model is hallucinating, start parameter gathering + if self.hallucination_state.hallucination is True: + has_hallucination = True + break - # # if the model is hallucinating, start parameter gathering - # if self.hallucination_state.hallucination is True: - # has_hallucination = True - # break - - # if has_tool_calls: - # if has_hallucination: - # # start prompt prefilling if hallcuination is found in tool calls - # logger.info( - # f"[Hallucination]: {self.hallucination_state.error_message}" - # ) - # prefill_response = self._engage_parameter_gathering(messages) - # model_response = prefill_response.choices[0].message.content - # else: - # model_response = "".join(self.hallucination_state.tokens) + if has_tool_calls: + if has_hallucination: + # start prompt prefilling if hallcuination is found in tool calls + logger.info( + f"[Hallucination]: {self.hallucination_state.error_message}" + ) + prefill_response = self._engage_parameter_gathering(messages) + model_response = prefill_response.choices[0].message.content + else: + model_response = "".join(self.hallucination_state.tokens) # else: # # start parameter gathering if the model is not generating tool calls # prefill_response = self._engage_parameter_gathering(messages) # model_response = prefill_response.choices[0].message.content - # *********************************************************************************************\ - # TODO: Remove the following for loop after updating hallucination check - # ********************************************************************************************* - for chunk in response: - if len(chunk.choices) > 0 and chunk.choices[0].delta.content: - model_response += chunk.choices[0].delta.content + # # *********************************************************************************************\ + # # TODO: Remove the following for loop after updating hallucination check + # # ********************************************************************************************* + # for chunk in response: + # if len(chunk.choices) > 0 and chunk.choices[0].delta.content: + # model_response += chunk.choices[0].delta.content logger.info(f"[arch-fc]: raw model response: {model_response}") # Extract tool calls from model response diff --git a/model_server/src/core/utils/hallucination_utils.py b/model_server/src/core/utils/hallucination_utils.py index 91effc92..05432710 100644 --- a/model_server/src/core/utils/hallucination_utils.py +++ b/model_server/src/core/utils/hallucination_utils.py @@ -13,16 +13,15 @@ from src.commons.utils import get_model_server_logger logger = get_model_server_logger() # constants -FUNC_NAME_START_PATTERN = ('\n{"name":"', "\n{'name':'") +FUNC_NAME_START_PATTERN = ('{"name":"', "{'name':'") FUNC_NAME_END_TOKEN = ('",', "',") -TOOL_CALL_TOKEN = "" -END_TOOL_CALL_TOKEN = "" +END_TOOL_CALL_TOKEN = "}}" FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'") -PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'") -PARAMETER_NAME_START_PATTERN = (',"', ",'") +PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'", '":"', "':'") +PARAMETER_NAME_START_PATTERN = ('","', "','") PARAMETER_VALUE_START_PATTERN = ('":', "':") -PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',") +PARAMETER_VALUE_END_TOKEN = ('",', '"}') BRACKETS = {"(": ")", "{": "}", "[": "]"} @@ -37,16 +36,9 @@ class MaskToken(Enum): HALLUCINATION_THRESHOLD_DICT = { - MaskToken.TOOL_CALL.value: { - "entropy": 0.35, - "varentropy": 1.7, - "probability": 0.8, - }, - MaskToken.PARAMETER_VALUE.value: { - "entropy": 0.28, - "varentropy": 1.4, - "probability": 0.8, - }, + "entropy": 0.28, + "varentropy": 1.4, + "probability": 0.8, } @@ -160,6 +152,7 @@ class HallucinationState: self._process_function(function) self.open_bracket = False self.bracket = None + self.function_name = "" self.check_parameter_name = {} self.HALLUCINATION_THRESHOLD_DICT = HALLUCINATION_THRESHOLD_DICT @@ -218,12 +211,10 @@ class HallucinationState: raise ValueError( f"Error extracting logprobs from response: {e}" ) - if token_content == END_TOOL_CALL_TOKEN: - self._reset_parameters() - else: - self.append_and_check_token_hallucination( - token_content, logprobs - ) + + self.append_and_check_token_hallucination( + token_content, logprobs + ) return token_content except StopIteration: raise StopIteration @@ -233,13 +224,13 @@ class HallucinationState: Processes the current token and updates the state and mask accordingly. Detects hallucinations based on the token type and log probabilities. """ - content = "".join(self.tokens).replace(" ", "") - if self.tokens[-1] == TOOL_CALL_TOKEN: - self.mask.append(MaskToken.TOOL_CALL) - self._check_logprob() + content = "".join(self.tokens).replace(" ", "").replace("Ġ",'') # Function name extraction logic # If the state is function name and the token is not an end token, add to the mask + if content.endswith(END_TOOL_CALL_TOKEN): + self._reset_parameters() + if self.state == "function_name": if self.tokens[-1] not in FUNC_NAME_END_TOKEN: self.mask.append(MaskToken.FUNCTION_NAME) @@ -359,7 +350,7 @@ class HallucinationState: if check_threshold( entropy, varentropy, - self.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value], + self.HALLUCINATION_THRESHOLD_DICT, ): self.hallucination = True self.error_message = f"token '{self.tokens[-1]}' is uncertain. Generated response:\n{''.join(self.tokens)}" diff --git a/model_server/src/main.py b/model_server/src/main.py index c3070392..ac29a743 100644 --- a/model_server/src/main.py +++ b/model_server/src/main.py @@ -100,10 +100,10 @@ async def function_calling(req: ChatMessage, res: Response): # ********************************************************************************************* # TODO: Put the following code back when hallucination check is ready # ********************************************************************************************* - # if not use_agent_orchestrator: - # final_response.metadata["hallucination"] = str( - # model_handler.hallucination_state.hallucination - # ) + if not use_agent_orchestrator: + final_response.metadata["hallucination"] = str( + model_handler.hallucination_state.hallucination + ) # No intent detected else: final_response.metadata["intent_latency"] = str(round(latency * 1000, 3)) @@ -114,9 +114,9 @@ async def function_calling(req: ChatMessage, res: Response): # ********************************************************************************************* # TODO: Put the following code back when hallucination check is ready # ********************************************************************************************* - # final_response.metadata["hallucination"] = str( - # model_handler.hallucination_state.hallucination - # ) + final_response.metadata["hallucination"] = str( + model_handler.hallucination_state.hallucination + ) except ValueError as e: res.statuscode = 503 diff --git a/model_server/tests/core/test_function_calling.py b/model_server/tests/core/test_function_calling.py index 0f2c9995..01b9ad95 100644 --- a/model_server/tests/core/test_function_calling.py +++ b/model_server/tests/core/test_function_calling.py @@ -123,35 +123,35 @@ def get_greeting_data(): return req, False, False, False -@pytest.mark.asyncio -@pytest.mark.parametrize( - "get_data_func", - [ - get_hallucination_data_complex, - get_complete_data, - get_irrelevant_data, - get_complete_data_2, - ], -) -async def test_function_calling(get_data_func): - req, intent, hallucination, parameter_gathering = get_data_func() +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# "get_data_func", +# [ +# get_hallucination_data_complex, +# get_complete_data, +# get_irrelevant_data, +# get_complete_data_2, +# ], +# ) +# async def test_function_calling(get_data_func): +# req, intent, hallucination, parameter_gathering = get_data_func() - intent_response = await handler_map["Arch-Intent"].chat_completion(req) +# intent_response = await handler_map["Arch-Intent"].chat_completion(req) - assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent +# assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent - if intent: - function_calling_response = await handler_map["Arch-Function"].chat_completion( - req - ) - assert ( - handler_map["Arch-Function"].hallucination_state.hallucination - == hallucination - ) - response_txt = function_calling_response.choices[0].message.content +# if intent: +# function_calling_response = await handler_map["Arch-Function"].chat_completion( +# req +# ) +# assert ( +# handler_map["Arch-Function"].hallucination_state.hallucination +# == hallucination +# ) +# response_txt = function_calling_response.choices[0].message.content - if parameter_gathering: - prefill_prefix = handler_map["Arch-Function"].prefill_prefix - assert any( - response_txt.startswith(prefix) for prefix in prefill_prefix - ), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}" +# if parameter_gathering: +# prefill_prefix = handler_map["Arch-Function"].prefill_prefix +# assert any( +# response_txt.startswith(prefix) for prefix in prefill_prefix +# ), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}" From 8290d1969f68b7ab2f18dcd61e4c43cc96bfba69 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Fri, 28 Mar 2025 12:38:44 -0700 Subject: [PATCH 6/8] use public endpoint for arch v1.1 --- model_server/src/commons/globals.py | 1 + 1 file changed, 1 insertion(+) diff --git a/model_server/src/commons/globals.py b/model_server/src/commons/globals.py index 97da3e20..3712178e 100644 --- a/model_server/src/commons/globals.py +++ b/model_server/src/commons/globals.py @@ -16,6 +16,7 @@ logger = get_model_server_logger() # Define the client ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1") +ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "http://35.225.55.128:8000/v1") ARCH_API_KEY = "EMPTY" ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY) ARCH_AGENT_CLIENT = ARCH_CLIENT From 425f9b0dd54370c12034829ca4a28d2a06d96f92 Mon Sep 17 00:00:00 2001 From: Shuguang Chen <54548843+nehcgs@users.noreply.github.com> Date: Fri, 28 Mar 2025 15:10:51 -0700 Subject: [PATCH 7/8] Update model usage --- model_server/src/core/function_calling.py | 20 ++++++++++++-------- model_server/src/core/utils/model_utils.py | 4 ++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 847861e9..ac6c2605 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -27,16 +27,15 @@ logger = utils.get_model_server_logger() class ArchFunctionConfig: TASK_PROMPT = ( "You are a helpful assistant designed to assist with the user query by making one or more function calls if needed." - "\nToday's date: {today_date}" - "\n\nYou are provided with function signatures within XML tags:\n{tool_text}\n" - "\n\nYour task is to decide which functions are needed and collect missing parameters if necessary.\n\n" + "\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n" + "\n\nYour task is to decide which functions are needed and collect missing parameters if necessary." ) FORMAT_PROMPT = ( - "Based on your analysis, provide your response in one of the following JSON formats:" - '\n1. If no functions are needed:\n```\n{"response": "Your response text here"}\n```' - '\n2. If functions are needed but some required parameters are missing:\n```\n{"required_functions": ["func_name1", "func_name2", ...], "clarification": "Text asking for missing parameters"}\n```' - '\n3. If functions are needed and all required parameters are available:\n```\n{"tool_calls": [{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},... (more tool calls as required)]}\n```' + "\n\nBased on your analysis, provide your response in one of the following JSON formats:" + '\n1. If no functions are needed:\n```json\n{"response": "Your response text here"}\n```' + '\n2. If functions are needed but some required parameters are missing:\n```json\n{"required_functions": ["func_name1", "func_name2", ...], "clarification": "Text asking for missing parameters"}\n```' + '\n3. If functions are needed and all required parameters are available:\n```json\n{"tool_calls": [{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},... (more tool calls as required)]}\n```' ) GENERATION_PARAMS = { @@ -193,6 +192,11 @@ class ArchFunctionHandler(ArchBaseHandler): } try: + if content.startswith("```") and content.endswith("```"): + content = content.strip("```").strip() + if content.startswith("json"): + content = content[4:].strip() + model_response = json.loads(self._fix_json_string(content)) response_dict["response"] = model_response.get("response", "") @@ -414,7 +418,7 @@ class ArchFunctionHandler(ArchBaseHandler): for _ in self.hallucination_state: # check if the first token is if len(self.hallucination_state.tokens) > 2 and has_tool_calls is None: - content = ''.join(self.hallucination_state.tokens) + content = "".join(self.hallucination_state.tokens) if "tool_calls" in content: has_tool_calls = True else: diff --git a/model_server/src/core/utils/model_utils.py b/model_server/src/core/utils/model_utils.py index 2dfd2237..0ce75333 100644 --- a/model_server/src/core/utils/model_utils.py +++ b/model_server/src/core/utils/model_utils.py @@ -104,10 +104,10 @@ class ArchBaseHandler: """ today_date = utils.get_today_date() - tool_text = self._convert_tools(tools) + tools = self._convert_tools(tools) system_prompt = ( - self.task_prompt.format(today_date=today_date, tool_text=tool_text) + self.task_prompt.format(today_date=today_date, tools=tools) + self.format_prompt ) From f035d166c88f126c7c7ab88fe7ad1e884493f6f1 Mon Sep 17 00:00:00 2001 From: Shuguang Chen <54548843+nehcgs@users.noreply.github.com> Date: Fri, 28 Mar 2025 16:30:03 -0700 Subject: [PATCH 8/8] Fix hallucination check --- model_server/src/commons/globals.py | 2 +- model_server/src/core/function_calling.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/model_server/src/commons/globals.py b/model_server/src/commons/globals.py index 3712178e..a477a939 100644 --- a/model_server/src/commons/globals.py +++ b/model_server/src/commons/globals.py @@ -15,7 +15,7 @@ logger = get_model_server_logger() # Define the client -ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1") +# ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1") ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "http://35.225.55.128:8000/v1") ARCH_API_KEY = "EMPTY" ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY) diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index ac6c2605..1aa9ad46 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -417,7 +417,7 @@ class ArchFunctionHandler(ArchBaseHandler): has_tool_calls, has_hallucination = None, False for _ in self.hallucination_state: # check if the first token is - if len(self.hallucination_state.tokens) > 2 and has_tool_calls is None: + if len(self.hallucination_state.tokens) > 5 and has_tool_calls is None: content = "".join(self.hallucination_state.tokens) if "tool_calls" in content: has_tool_calls = True