diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index 523f5781..e58bebde 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -13,8 +13,11 @@ pub const MESSAGES_KEY: &str = "messages"; pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint"; pub const CHAT_COMPLETIONS_PATH: [&str; 2] = ["/v1/chat/completions", "/openai/v1/chat/completions"]; pub const HEALTHZ_PATH: &str = "/healthz"; -pub const ARCH_STATE_HEADER: &str = "x-arch-state"; -pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B"; +pub const X_ARCH_STATE_HEADER: &str = "x-arch-state"; +pub const X_ARCH_API_RESPONSE: &str = "x-arch-api-response-message"; +pub const X_ARCH_TOOL_CALL: &str = "x-arch-tool-call-message"; +pub const X_ARCH_FC_MODEL_RESPONSE: &str = "x-arch-fc-model-response"; +pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function"; pub const REQUEST_ID_HEADER: &str = "x-request-id"; pub const TRACE_PARENT_HEADER: &str = "traceparent"; pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal"; diff --git a/crates/common/src/http.rs b/crates/common/src/http.rs index 8ef176b2..e3120fc4 100644 --- a/crates/common/src/http.rs +++ b/crates/common/src/http.rs @@ -50,8 +50,7 @@ pub trait Client: Context { ) -> Result { debug!( "dispatching http call with args={:?} context={:?}", - call_args, - call_context + call_args, call_context ); match self.dispatch_http_call( diff --git a/crates/common/src/ratelimit.rs b/crates/common/src/ratelimit.rs index ef4e5d08..66c3facd 100644 --- a/crates/common/src/ratelimit.rs +++ b/crates/common/src/ratelimit.rs @@ -101,9 +101,7 @@ impl RatelimitMap { ) -> Result<(), Error> { debug!( "Checking limit for provider={}, with selector={:?}, consuming tokens={:?}", - provider, - selector, - tokens_used + provider, selector, tokens_used ); let provider_limits = match self.datastore.get(&provider) { diff --git a/crates/common/src/tokenizer.rs b/crates/common/src/tokenizer.rs index e9431521..11ce7295 100644 --- a/crates/common/src/tokenizer.rs +++ b/crates/common/src/tokenizer.rs @@ -1,4 +1,4 @@ -use log::{debug}; +use log::debug; #[allow(dead_code)] pub fn token_count(model_name: &str, text: &str) -> Result { diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 55ee53ce..5b741a43 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -428,7 +428,7 @@ impl HttpContext for StreamContext { ); if self.request_body_sent_time.is_none() { - debug!("on_http_response_body: request body not sent, no doing any processing in llm filter"); + debug!("on_http_response_body: request body not sent, not doing any processing in llm filter"); return Action::Continue; } diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index 6fc23921..c0e8df94 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -4,10 +4,11 @@ use common::{ self, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest, }, consts::{ - ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, + ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH, MODEL_SERVER_NAME, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER, TOOL_ROLE, - TRACE_PARENT_HEADER, USER_ROLE, + TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_API_RESPONSE, X_ARCH_FC_MODEL_RESPONSE, + 
X_ARCH_STATE_HEADER, X_ARCH_TOOL_CALL, }, errors::ServerError, http::{CallArgs, Client}, @@ -125,8 +126,8 @@ impl HttpContext for StreamContext { self.arch_state = match deserialized_body.metadata { Some(ref metadata) => { - if metadata.contains_key(ARCH_STATE_HEADER) { - let arch_state_str = metadata[ARCH_STATE_HEADER].clone(); + if metadata.contains_key(X_ARCH_STATE_HEADER) { + let arch_state_str = metadata[X_ARCH_STATE_HEADER].clone(); let arch_state: Vec = serde_json::from_str(&arch_state_str).unwrap(); Some(arch_state) } else { @@ -336,10 +337,10 @@ impl HttpContext for StreamContext { if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() { let chunks = vec![ ChatCompletionStreamResponse::new( - None, + self.arch_fc_response.clone(), Some(ASSISTANT_ROLE.to_string()), Some(ARCH_FC_MODEL_NAME.to_string()), - self.tool_calls.to_owned(), + None, ), ChatCompletionStreamResponse::new( self.tool_call_response.clone(), @@ -381,17 +382,39 @@ impl HttpContext for StreamContext { *metadata = Value::Object(serde_json::Map::new()); } - let fc_messages = vec![ - self.generate_toll_call_message(), - self.generate_api_response_message(), - ]; + let tool_call_message = self.generate_tool_call_message(); + let tool_call_message_str = serde_json::to_string(&tool_call_message).unwrap(); + metadata.as_object_mut().unwrap().insert( + X_ARCH_TOOL_CALL.to_string(), + serde_json::Value::String(tool_call_message_str), + ); + + let api_response_message = self.generate_api_response_message(); + let api_response_message_str = + serde_json::to_string(&api_response_message).unwrap(); + metadata.as_object_mut().unwrap().insert( + X_ARCH_API_RESPONSE.to_string(), + serde_json::Value::String(api_response_message_str), + ); + + let fc_messages = vec![tool_call_message, api_response_message]; + let fc_messages_str = serde_json::to_string(&fc_messages).unwrap(); let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]); let arch_state_str = serde_json::to_string(&arch_state).unwrap(); metadata.as_object_mut().unwrap().insert( - ARCH_STATE_HEADER.to_string(), + X_ARCH_STATE_HEADER.to_string(), serde_json::Value::String(arch_state_str), ); + + if let Some(arch_fc_response) = self.arch_fc_response.as_ref() { + metadata.as_object_mut().unwrap().insert( + X_ARCH_FC_MODEL_RESPONSE.to_string(), + serde_json::Value::String( + serde_json::to_string(arch_fc_response).unwrap(), + ), + ); + } let data_serialized = serde_json::to_string(&data).unwrap(); info!("archgw <= developer: {}", data_serialized); self.set_http_response_body(0, body_size, data_serialized.as_bytes()); diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index ec68c104..1cd2fa86 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -9,6 +9,7 @@ use common::consts::{ API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE, + X_ARCH_FC_MODEL_RESPONSE, }; use common::errors::ServerError; use common::http::{CallArgs, Client}; @@ -64,10 +65,10 @@ pub struct StreamContext { pub time_to_first_token: Option, pub traceparent: Option, pub _tracing: Rc>, + pub arch_fc_response: Option, } impl StreamContext { - #[allow(clippy::too_many_arguments)] pub fn new( context_id: u32, metrics: Rc, @@ -98,6 +99,7 @@ impl StreamContext { _tracing: tracing, 
start_upstream_llm_request_time: 0, time_to_first_token: None, + arch_fc_response: None, } } @@ -142,15 +144,16 @@ impl StreamContext { } }; - // intent was matched if we see function_latency in metadata - let intent_matched = model_server_response + let intent_matched = check_intent_matched(&model_server_response); + info!("intent matched: {}", intent_matched); + + self.arch_fc_response = model_server_response .metadata .as_ref() - .and_then(|metadata| metadata.get("function_latency")) - .is_some(); + .and_then(|metadata| metadata.get(X_ARCH_FC_MODEL_RESPONSE)) + .cloned(); if !intent_matched { - info!("intent not matched"); // check if we have a default prompt target if let Some(default_prompt_target) = self .prompt_targets @@ -278,9 +281,9 @@ impl StreamContext { let direct_response_str = if self.streaming_response { let chunks = vec![ ChatCompletionStreamResponse::new( - None, + self.arch_fc_response.clone(), Some(ASSISTANT_ROLE.to_string()), - Some(ARCH_FC_MODEL_NAME.to_owned()), + Some(ARCH_FC_MODEL_NAME.to_string()), None, ), ChatCompletionStreamResponse::new( self.tool_call_response.clone(), @@ -293,7 +296,7 @@ impl StreamContext { .clone(), ), None, - Some(ARCH_FC_MODEL_NAME.to_owned()), + Some(format!("{}-Chat", ARCH_FC_MODEL_NAME.to_owned())), None, ), ]; @@ -623,13 +626,24 @@ impl StreamContext { messages } - pub fn generate_toll_call_message(&mut self) -> Message { - Message { - role: ASSISTANT_ROLE.to_string(), - content: None, - model: Some(ARCH_FC_MODEL_NAME.to_string()), - tool_calls: self.tool_calls.clone(), - tool_call_id: None, + pub fn generate_tool_call_message(&mut self) -> Message { + if self.arch_fc_response.is_none() { + info!("arch_fc_response is none, generating tool call message"); + Message { + role: ASSISTANT_ROLE.to_string(), + content: None, + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: self.tool_calls.clone(), + tool_call_id: None, + } + } else { + Message { + role: ASSISTANT_ROLE.to_string(), + content: self.arch_fc_response.as_ref().cloned(), + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: None, + tool_call_id: None, + } } } @@ -761,6 +775,23 @@ impl StreamContext { } } +fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool { + let content = model_server_response + .choices.first() + .and_then(|choice| choice.message.content.as_ref()); + + let content_has_value = content.is_some() && !content.unwrap().is_empty(); + + let tool_calls = model_server_response + .choices.first() + .and_then(|choice| choice.message.tool_calls.as_ref()); + + // intent was matched if content has some value or tool_calls is not empty + + + content_has_value || (tool_calls.is_some() && !tool_calls.unwrap().is_empty()) +} + impl Client for StreamContext { type CallContext = StreamCallContext; @@ -772,3 +803,77 @@ impl Client for StreamContext { &self.metrics.active_http_calls } } + +#[cfg(test)] +mod test { + use common::api::open_ai::{ChatCompletionsResponse, Choice, Message, ToolCall}; + + use crate::stream_context::check_intent_matched; + + #[test] + fn test_intent_matched() { + let model_server_response = ChatCompletionsResponse { + choices: vec![Choice
{ + message: Message { + content: Some("hello".to_string()), + tool_calls: Some(vec![]), + role: "assistant".to_string(), + model: None, + tool_call_id: None, + }, + finish_reason: None, + index: None, + }], + usage: None, + model: "arch-fc".to_string(), + metadata: None, + }; + + assert!(check_intent_matched(&model_server_response)); + + let model_server_response = ChatCompletionsResponse { + choices: vec![Choice { + message: Message { + content: Some("".to_string()), + tool_calls: Some(vec![ToolCall { + id: "1".to_string(), + function: common::api::open_ai::FunctionCallDetail { + name: "test".to_string(), + arguments: None, + }, + tool_type: common::api::open_ai::ToolType::Function, + }]), + role: "assistant".to_string(), + model: None, + tool_call_id: None, + }, + finish_reason: None, + index: None, + }], + usage: None, + model: "arch-fc".to_string(), + metadata: None, + }; + + assert!(check_intent_matched(&model_server_response)); + } +} diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index bbde10b0..91b36c01 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -380,6 +380,7 @@ fn prompt_gateway_request_to_llm_gateway() { .expect_log(Some(LogLevel::Warn), None) .expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Debug), None) @@ -453,6 +454,7 @@ fn prompt_gateway_request_to_llm_gateway() { .expect_log(Some(LogLevel::Info), None) .expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None) .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); @@ -493,19 +495,9 @@ fn prompt_gateway_request_no_intent_match() { finish_reason: Some("test".to_string()), index: Some(0), message: Message { - role: "system".to_string(), + role: "assistant".to_string(), content: None, - tool_calls: Some(vec![ToolCall { - id: String::from("test"), - tool_type: ToolType::Function, - function: FunctionCallDetail { - name: String::from("weather_forecast"), - arguments: Some(HashMap::from([( - String::from("city"), - Value::String(String::from("seattle")), - )])), - }, - }]), + tool_calls: None, model: None, tool_call_id: None, }, @@ -523,7 +515,7 @@ fn prompt_gateway_request_no_intent_match() { .expect_log(Some(LogLevel::Warn), None) .expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Info), Some("intent not matched")) + .expect_log(Some(LogLevel::Info), Some("intent matched: false")) .expect_log( Some(LogLevel::Info), Some("no default prompt target found, forwarding request to upstream llm"), @@ -651,17 +643,7 @@ fn prompt_gateway_request_no_intent_match_default_target() { message: Message { role: "system".to_string(), content: None, - tool_calls: Some(vec![ToolCall { - id: String::from("test"), - tool_type: ToolType::Function, - function: FunctionCallDetail { - name: String::from("weather_forecast"), - arguments: Some(HashMap::from([( - String::from("city"), - Value::String(String::from("seattle")), - )])), - }, - }]), + tool_calls: None, model: None, tool_call_id: None, }, @@ -679,7 +661,7 @@ fn prompt_gateway_request_no_intent_match_default_target() { .expect_log(Some(LogLevel::Warn), None) .expect_log(Some(LogLevel::Info), None) 
.expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Info), Some("intent not matched")) + .expect_log(Some(LogLevel::Info), Some("intent matched: false")) .expect_log( Some(LogLevel::Info), Some("default prompt target found, forwarding request to default prompt target"), diff --git a/demos/samples_python/weather_forecast/arch_config.yaml b/demos/samples_python/weather_forecast/arch_config.yaml index db18eb85..b6463594 100644 --- a/demos/samples_python/weather_forecast/arch_config.yaml +++ b/demos/samples_python/weather_forecast/arch_config.yaml @@ -22,12 +22,12 @@ llm_providers: provider_interface: openai model: llama-3.2-3b-preview base_url: https://api.groq.com - default: true - name: gpt-4o access_key: $OPENAI_API_KEY provider_interface: openai model: gpt-4o + default: true system_prompt: | You are a helpful assistant. diff --git a/demos/samples_python/weather_forecast/main.py b/demos/samples_python/weather_forecast/main.py index 3be2f4da..84ee75e6 100644 --- a/demos/samples_python/weather_forecast/main.py +++ b/demos/samples_python/weather_forecast/main.py @@ -73,7 +73,7 @@ async def weather(req: WeatherRequest, res: Response): class DefaultTargetRequest(BaseModel): - messages: list + messages: list = [] @app.post("/default_target") @@ -86,12 +86,9 @@ async def default_target(req: DefaultTargetRequest, res: Response): "role": "assistant", "content": "I can help you with weather forecast", }, - "finish_reason": "completed", - "index": 0, } ], "model": "api_server", - "usage": {"completion_tokens": 0}, } logger.info(f"sending response: {json.dumps(resp)}") return resp diff --git a/demos/shared/chatbot_ui/.vscode/launch.json b/demos/shared/chatbot_ui/.vscode/launch.json index e7f91d36..6a470b7b 100644 --- a/demos/shared/chatbot_ui/.vscode/launch.json +++ b/demos/shared/chatbot_ui/.vscode/launch.json @@ -15,7 +15,7 @@ "LLM": "1", "CHAT_COMPLETION_ENDPOINT": "http://localhost:10000/v1", "STREAMING": "True", - "ARCH_CONFIG": "../../weather_forecast/arch_config.yaml" + "ARCH_CONFIG": "../../samples_python/weather_forecast/arch_config.yaml" } }, { @@ -29,7 +29,7 @@ "LLM": "1", "CHAT_COMPLETION_ENDPOINT": "http://localhost:12000/v1", "STREAMING": "True", - "ARCH_CONFIG": "../../llm_routing/arch_config.yaml" + "ARCH_CONFIG": "../../samples_python/weather_forecast/arch_config.yaml" } }, ] diff --git a/demos/shared/chatbot_ui/common.py b/demos/shared/chatbot_ui/common.py index 42e50bd4..1de8f94c 100644 --- a/demos/shared/chatbot_ui/common.py +++ b/demos/shared/chatbot_ui/common.py @@ -120,8 +120,11 @@ def process_stream_chunk(chunk, history): if delta.content: # append content to the last history item - history[-1]["content"] = history[-1].get("content", "") + delta.content + if history[-1]["model"] != "Arch-Function-Chat": + history[-1]["content"] = history[-1].get("content", "") + delta.content # yield content if it is from assistant + if history[-1]["model"] == "Arch-Function": + return None if history[-1]["role"] == "assistant": return delta.content diff --git a/demos/shared/chatbot_ui/run_stream.py b/demos/shared/chatbot_ui/run_stream.py index b406e147..89448355 100644 --- a/demos/shared/chatbot_ui/run_stream.py +++ b/demos/shared/chatbot_ui/run_stream.py @@ -88,6 +88,22 @@ def chat( yield "", conversation, history, debug_output, model_selector + # update assistant response to have correct format + # arch-fc 1.1 expects following format: + # { + # "response": "", + # } + # and this entire block needs to be encoded in ```json\n{json_encoded_content}\n``` + + if not 
history[-1]["model"].startswith("Arch"): + assistant_response = { + "response": history[-1]["content"], + } + history[-1]["content"] = "```json\n{}\n```".format( + json.dumps(assistant_response) + ) + log.info("history: {}".format(json.dumps(history))) + def main(): with gr.Blocks( diff --git a/demos/use_cases/llm_routing/arch_config.yaml b/demos/use_cases/llm_routing/arch_config.yaml index 289d8bf2..11087d3a 100644 --- a/demos/use_cases/llm_routing/arch_config.yaml +++ b/demos/use_cases/llm_routing/arch_config.yaml @@ -30,5 +30,11 @@ llm_providers: model: deepseek-reasoner base_url: https://api.deepseek.com/ + - name: groq + access_key: $GROQ_API_KEY + provider_interface: openai + model: llama-3.1-8b-instant + base_url: https://api.groq.com + tracing: random_sampling: 100 diff --git a/model_server/src/commons/globals.py b/model_server/src/commons/globals.py index 49dce5e7..a39be7c3 100644 --- a/model_server/src/commons/globals.py +++ b/model_server/src/commons/globals.py @@ -5,8 +5,6 @@ from src.core.guardrails import get_guardrail_handler from src.core.function_calling import ( ArchAgentConfig, ArchAgentHandler, - ArchIntentConfig, - ArchIntentHandler, ArchFunctionConfig, ArchFunctionHandler, ) @@ -17,7 +15,10 @@ logger = get_model_server_logger() # Define the client -ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1") +# ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1") +# use temporary endpoint until we deprecate archfc-v1.0 from archfc.katanemo.dev +# and officially release archfc-v1.1 on archfc.katanemo.dev +ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "http://34.72.123.163:8000/v1") ARCH_API_KEY = "EMPTY" ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY) ARCH_AGENT_CLIENT = ARCH_CLIENT @@ -30,9 +31,6 @@ ARCH_GUARD_MODEL_ALIAS = "katanemo/Arch-Guard" # Define model handlers handler_map = { - "Arch-Intent": ArchIntentHandler( - ARCH_CLIENT, ARCH_INTENT_MODEL_ALIAS, ArchIntentConfig - ), "Arch-Function": ArchFunctionHandler( ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig ), diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 8c01493a..bec5dfdf 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -3,7 +3,6 @@ import copy import json import random import builtins -import textwrap import src.commons.utils as utils from openai import OpenAI @@ -22,179 +21,25 @@ from src.core.utils.model_utils import ( logger = utils.get_model_server_logger() -class ArchIntentConfig: - TASK_PROMPT = textwrap.dedent( - """ - You are a helpful assistant. - """ - ).strip() - - TOOL_PROMPT_TEMPLATE = textwrap.dedent( - """ - You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below. - - - {tool_text} - - """ - ).strip() - - FORMAT_PROMPT = textwrap.dedent( - """ - Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation: - - First line must read 'Yes' or 'No'. - - If yes, a second line must include a comma-separated list of tool indexes. - """ - ).strip() - - EXTRA_INSTRUCTION = "Are there any tools can help?" - - GENERATION_PARAMS = { - "temperature": 0.01, - "max_tokens": 1, - "stop_token_ids": [151645], - } - - -class ArchIntentHandler(ArchBaseHandler): - def __init__(self, client: OpenAI, model_name: str, config: ArchIntentConfig): - """ - Initializes the intent handler. 
- - Args: - client (OpenAI): An OpenAI client instance. - model_name (str): Name of the model to use. - config (ArchIntentConfig): The configuration for Arch-Intent. - """ - - super().__init__( - client, - model_name, - config.TASK_PROMPT, - config.TOOL_PROMPT_TEMPLATE, - config.FORMAT_PROMPT, - config.GENERATION_PARAMS, - ) - - self.extra_instruction = config.EXTRA_INSTRUCTION - - @override - def _convert_tools(self, tools: List[Dict[str, Any]]) -> str: - """ - Converts a list of tools into a JSON-like format with indexed keys. - - Args: - tools (List[Dict[str, Any]]): A list of tools represented as dictionaries. - - Returns: - str: A string representation of converted tools. - """ - - converted = [ - json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools) - ] - return "\n".join(converted) - - def detect_intent(self, content: str) -> bool: - """ - Detect if any intent match with prompts - - Args: - content: str: Model response that contains intent detection results - - Returns: - bool: A boolean value to indicate if any intent match with prompts or not - """ - if hasattr(content.choices[0].message, "content"): - return content.choices[0].message.content == "Yes" - else: - return False - - @override - async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse: - """ - Generates a chat completion for a given request. - - Args: - req (ChatMessage): A chat message request object. - - Returns: - ChatCompletionResponse: The model's response to the chat request. - - Note: - Currently only support vllm inference - """ - logger.info("[Arch-Intent] - ChatCompletion") - - # In the case that no tools are available, simply return `No` to avoid making a call - if len(req.tools) == 0: - model_response = Message(content="No", tool_calls=[]) - logger.info("No tools found, return `No` as the model response.") - else: - messages = self._process_messages( - req.messages, req.tools, self.extra_instruction - ) - - logger.info(f"[request to arch-fc (intent)]: {json.dumps(messages)}") - - model_response = self.client.chat.completions.create( - messages=messages, - model=self.model_name, - stream=False, - extra_body=self.generation_params, - ) - - logger.info(f"[response]: {json.dumps(model_response.model_dump())}") - - model_response = Message( - content=model_response.choices[0].message.content, tool_calls=[] - ) - - chat_completion_response = ChatCompletionResponse( - choices=[Choice(message=model_response)], model=self.model_name - ) - - return chat_completion_response - - -# ============================================================================================================= +# ============================================================================================================================================== class ArchFunctionConfig: - TASK_PROMPT = textwrap.dedent( - """ - You are a helpful assistant. + TASK_PROMPT = ( + "You are a helpful assistant designed to assist with the user query by making one or more function calls if needed." + "\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n" + "\n\nYour task is to decide which functions are needed and collect missing parameters if necessary." + ) - Today's date: {} - """.format( - utils.get_today_date() - ) - ).strip() - - TOOL_PROMPT_TEMPLATE = textwrap.dedent( - """ - # Tools - - You may call one or more functions to assist with the user query. 
- - You are provided with function signatures within XML tags: - - {tool_text} - - """ - ).strip() - - FORMAT_PROMPT = textwrap.dedent( - """ - For each function call, return a json object with function name and arguments within XML tags: - - {"name": , "arguments": } - - """ - ).strip() + FORMAT_PROMPT = ( + "\n\nBased on your analysis, provide your response in one of the following JSON formats:" + '\n1. If no functions are needed:\n```json\n{"response": "Your response text here"}\n```' + '\n2. If functions are needed but some required parameters are missing:\n```json\n{"required_functions": ["func_name1", "func_name2", ...], "clarification": "Text asking for missing parameters"}\n```' + '\n3. If functions are needed and all required parameters are available:\n```json\n{"tool_calls": [{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},... (more tool calls as required)]}\n```' + ) GENERATION_PARAMS = { - "temperature": 0.6, + "temperature": 0.1, "top_p": 1.0, "top_k": 10, "max_tokens": 1024, @@ -203,34 +48,9 @@ class ArchFunctionConfig: "top_logprobs": 10, } - PREFILL_CONFIG = { - "prefill_params": { - "continue_final_message": True, - "add_generation_prompt": False, - }, - "prefill_prefix": [ - "May", - "Could", - "Sure", - "Definitely", - "Certainly", - "Of course", - "Can", - ], - } - SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"] -class ArchAgentConfig(ArchFunctionConfig): - GENERATION_PARAMS = { - "temperature": 0.01, - "stop_token_ids": [151645], - "logprobs": True, - "top_logprobs": 10, - } - - class ArchFunctionHandler(ArchBaseHandler): def __init__( self, @@ -251,13 +71,17 @@ class ArchFunctionHandler(ArchBaseHandler): client, model_name, config.TASK_PROMPT, - config.TOOL_PROMPT_TEMPLATE, config.FORMAT_PROMPT, config.GENERATION_PARAMS, ) - self.prefill_params = config.PREFILL_CONFIG["prefill_params"] - self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"] + self.generation_params = self.generation_params | { + "continue_final_message": True, + "add_generation_prompt": False, + } + + self.default_prefix = '```json\n{"' + self.clarify_prefix = '```json\n{"required_functions":' self.hallucination_state = None @@ -280,7 +104,7 @@ class ArchFunctionHandler(ArchBaseHandler): str: A string representation of converted tools. """ - converted = [json.dumps(tool) for tool in tools] + converted = [json.dumps(tool["function"], ensure_ascii=False) for tool in tools] return "\n".join(converted) def _fix_json_string(self, json_str: str) -> str: @@ -328,10 +152,14 @@ class ArchFunctionHandler(ArchBaseHandler): unmatched_opening = stack.pop() fixed_str += opening_bracket[unmatched_opening] - # Attempt to parse the corrected string to ensure it’s valid JSON - return fixed_str.replace("'", '"') + try: + fixed_str = json.loads(fixed_str) + except Exception: + fixed_str = json.loads(fixed_str.replace("'", '"')) - def _extract_tool_calls(self, content: str) -> Dict[str, any]: + return json.dumps(fixed_str) + + def _parse_model_response(self, content: str) -> Dict[str, any]: """ Extracts tool call information from a given string. @@ -340,49 +168,55 @@ class ArchFunctionHandler(ArchBaseHandler): Returns: Dict: A dictionary of extraction, including: - - "result": A list of tool call dictionaries. - - "status": A boolean indicating if the extraction was valid. - - "message": An error message or exception if extraction failed. + - "required_functions": A list of detected intents. 
+ - "clarification": Text to collect missing parameters + - "tool_calls": A list of tool call dictionaries. + - "is_valid": A boolean indicating if the extraction was valid. + - "error_message": An error message or exception if parsing failed. """ - tool_calls, is_valid, error_message = [], True, "" + response_dict = { + "raw_response": [], + "response": [], + "required_functions": [], + "clarification": "", + "tool_calls": [], + "is_valid": True, + "error_message": "", + } - flag = False - for line in content.split("\n"): - if not is_valid: - break + try: + if content.startswith("```") and content.endswith("```"): + content = content.strip("```").strip() + if content.startswith("json"): + content = content[4:].strip() - if "" == line: - flag = True - elif "" == line: - flag = False - else: - if flag: - try: - tool_content = json.loads(line) - except Exception as e: - fixed_content = self._fix_json_string(line) - try: - tool_content = json.loads(fixed_content) - except Exception: - is_valid, error_message = False, e - break + content = self._fix_json_string(content) + response_dict["raw_response"] = f"```json\n{content}\n```" - tool = { + model_response = json.loads(content) + response_dict["response"] = model_response.get("response", "") + response_dict["required_functions"] = model_response.get( + "required_functions", [] + ) + response_dict["clarification"] = model_response.get("clarification", "") + + for tool_call in model_response.get("tool_calls", []): + response_dict["tool_calls"].append( + { "id": f"call_{random.randint(1000, 10000)}", "type": "function", "function": { - "name": tool_content["name"], + "name": tool_call.get("name", ""), + "arguments": tool_call.get("arguments", {}), }, } - if "arguments" in tool_content: - tool["function"]["arguments"] = tool_content["arguments"] + ) + except Exception as e: + response_dict["is_valid"] = False + response_dict["error_message"] = f"Fail to parse model responses: {e}" - tool_calls.append(tool) - - flag = False - - return {"result": tool_calls, "status": is_valid, "message": error_message} + return response_dict def _convert_data_type(self, value: str, target_type: str): # TODO: Add more conversion rules as needed @@ -414,36 +248,37 @@ class ArchFunctionHandler(ArchBaseHandler): - "message": An error message. """ - is_valid, invalid_tool_call, error_message = True, None, "" + verification_dict = { + "is_valid": True, + "invalid_tool_call": {}, + "error_message": "", + } functions = {} for tool in tools: - if tool["type"] == "function": - functions[tool["function"]["name"]] = tool["function"]["parameters"] + functions[tool["function"]["name"]] = tool["function"]["parameters"] for tool_call in tool_calls: - if not is_valid: + if not verification_dict["is_valid"]: break func_name = tool_call["function"]["name"] - func_args = tool_call["function"].get("arguments") - if not func_args: - func_args = {} + func_args = tool_call["function"]["arguments"] # Check whether the function is available or not if func_name not in functions: - is_valid = False - invalid_tool_call = tool_call - error_message = f"{func_name} is not defined!" - break - + verification_dict["is_valid"] = False + verification_dict["invalid_tool_call"] = tool_call + verification_dict["error_message"] = f"{func_name} is not available!" 
else: # Check if all the requried parameters can be found in the tool calls for required_param in functions[func_name].get("required", []): if required_param not in func_args: - is_valid = False - invalid_tool_call = tool_call - error_message = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!" + verification_dict["is_valid"] = False + verification_dict["invalid_tool_call"] = tool_call + verification_dict[ + "error_message" + ] = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!" break # Verify the data type of each parameter in the tool calls @@ -453,9 +288,11 @@ class ArchFunctionHandler(ArchBaseHandler): logger.info(func_args) for param_name in func_args: if param_name not in function_properties: - is_valid = False - invalid_tool_call = tool_call - error_message = f"Parameter `{param_name}` is not defined in the function `{func_name}`." + verification_dict["is_valid"] = False + verification_dict["invalid_tool_call"] = tool_call + verification_dict[ + "error_message" + ] = f"Parameter `{param_name}` is not defined in the function `{func_name}`." break else: param_value = func_args[param_name] @@ -469,22 +306,22 @@ class ArchFunctionHandler(ArchBaseHandler): param_value, data_type ) if not isinstance(param_value, data_type): - is_valid = False - invalid_tool_call = tool_call - error_message = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`." + verification_dict["is_valid"] = False + verification_dict["invalid_tool_call"] = tool_call + verification_dict[ + "error_message" + ] = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`." break else: - error_message = ( - f"Data type `{target_type}` is not supported." - ) + verification_dict["is_valid"] = False + verification_dict["invalid_tool_call"] = tool_call + verification_dict[ + "error_message" + ] = f"Data type `{target_type}` is not supported." - return { - "status": is_valid, - "invalid_tool_call": invalid_tool_call, - "message": error_message, - } + return verification_dict - def _add_prefill_message(self, messages: List[Dict[str, str]]): + def _prefill_message(self, messages: List[Dict[str, str]], prefill_message): """ Update messages and generation params for prompt prefilling @@ -494,29 +331,7 @@ class ArchFunctionHandler(ArchBaseHandler): Returns: prefill_messages (List[Dict[str, str]]): A list of messages. 
""" - - return messages + [ - { - "role": "assistant", - "content": random.choice(self.prefill_prefix), - } - ] - - def _engage_parameter_gathering(self, messages: List[Dict[str, str]]): - """ - Engage parameter gathering for tool calls - """ - - # TODO: log enaging parameter gathering - prefill_response = self.client.chat.completions.create( - messages=self._add_prefill_message(messages), - model=self.model_name, - extra_body={ - **self.generation_params, - **self.prefill_params, - }, - ) - return prefill_response + return messages + [{"role": "assistant", "content": prefill_message}] @override async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse: @@ -544,7 +359,7 @@ class ArchFunctionHandler(ArchBaseHandler): # always enable `stream=True` to collect model responses response = self.client.chat.completions.create( - messages=messages, + messages=self._prefill_message(messages, self.default_prefix), model=self.model_name, stream=True, extra_body=self.generation_params, @@ -565,72 +380,114 @@ class ArchFunctionHandler(ArchBaseHandler): has_tool_calls, has_hallucination = None, False for _ in self.hallucination_state: - # check if the first token is - if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None: - if self.hallucination_state.tokens[0] == "": + # check if moodel response starts with tool calls, we do it after 5 tokens because we only check the first part of the response. + if len(self.hallucination_state.tokens) > 5 and has_tool_calls is None: + content = "".join(self.hallucination_state.tokens) + if "tool_calls" in content: has_tool_calls = True else: has_tool_calls = False - break # if the model is hallucinating, start parameter gathering if self.hallucination_state.hallucination is True: has_hallucination = True break - if has_tool_calls: - if has_hallucination: - # start prompt prefilling if hallcuination is found in tool calls - logger.info( - f"[Hallucination]: {self.hallucination_state.error_message}" - ) - prefill_response = self._engage_parameter_gathering(messages) - model_response = prefill_response.choices[0].message.content - else: - model_response = "".join(self.hallucination_state.tokens) + if has_tool_calls and has_hallucination: + # start prompt prefilling if hallcuination is found in tool calls + logger.info( + f"[Hallucination]: {self.hallucination_state.error_message}" + ) + response = self.client.chat.completions.create( + messages=self._prefill_message(messages, self.clarify_prefix), + model=self.model_name, + stream=False, + extra_body=self.generation_params, + ) + model_response = response.choices[0].message.content else: - # start parameter gathering if the model is not generating tool calls - prefill_response = self._engage_parameter_gathering(messages) - model_response = prefill_response.choices[0].message.content + model_response = "".join(self.hallucination_state.tokens) # Extract tool calls from model response - extracted = self._extract_tool_calls(model_response) + response_dict = self._parse_model_response(model_response) + logger.info(f"[arch-fc]: raw model response: {response_dict['raw_response']}") - if extracted["status"]: - # Response with tool calls - if len(extracted["result"]): - verified = {} - if use_agent_orchestrator: - # skip tool call verification if using agent orchestrator - verified = {"status": True, "message": ""} - else: - verified = self._verify_tool_calls( - tools=req.tools, tool_calls=extracted["result"] - ) - - if verified["status"]: - logger.info( - f"[Tool calls]: 
{json.dumps([tool_call['function'] for tool_call in extracted['result']])}" - ) - model_response = Message(content="", tool_calls=extracted["result"]) - else: - logger.error(f"Invalid tool call - {verified['message']}") - # Response without tool calls + # General model response + if response_dict.get("response", ""): + model_message = Message(content="", tool_calls=[]) + # Parameter gathering + elif response_dict.get("required_functions", []): + if not use_agent_orchestrator: + clarification = response_dict.get("clarification", "") + model_message = Message(content=clarification, tool_calls=[]) else: - model_response = Message(content=model_response, tool_calls=[]) - # Response with tool calls but contain errors + model_message = Message(content="", tool_calls=[]) + # Function Calling + elif response_dict.get("tool_calls", []): + if response_dict["is_valid"]: + if not use_agent_orchestrator: + verification_dict = self._verify_tool_calls( + tools=req.tools, tool_calls=response_dict["tool_calls"] + ) + + if verification_dict["is_valid"]: + logger.info( + f"[Tool calls]: {json.dumps([tool_call['function'] for tool_call in response_dict['tool_calls']])}" + ) + model_message = Message( + content="", tool_calls=response_dict["tool_calls"] + ) + else: + logger.error( + f"Invalid tool call - {verification_dict['error_message']}" + ) + model_message = Message(content="", tool_calls=[]) + else: + # skip tool call verification if using agent orchestrator + logger.info( + f"[Tool calls]: {json.dumps([tool_call['function'] for tool_call in response_dict['tool_calls']])}" + ) + model_message = Message( + content="", tool_calls=response_dict["tool_calls"] + ) + + else: + # Response with tool calls but invalid + model_message = Message(content="", tool_calls=[]) + # Response not in the desired format else: - logger.error(f"Tool call extraction error - {extracted['message']}") + logger.error(f"Invalid model response - {model_response}") + model_message = Message(content="", tool_calls=[]) chat_completion_response = ChatCompletionResponse( - choices=[Choice(message=model_response)], model=self.model_name + choices=[Choice(message=model_message)], + model=self.model_name, + metadata={"x-arch-fc-model-response": response_dict["raw_response"]}, + role="assistant", ) - logger.info(f"[response]: {json.dumps(chat_completion_response.model_dump())}") + logger.info( + f"[response arch-fc]: {json.dumps(chat_completion_response.model_dump(exclude_none=True))}" + ) return chat_completion_response +# ============================================================================================================================================== + + +class ArchAgentConfig(ArchFunctionConfig): + GENERATION_PARAMS = { + "temperature": 0.01, + "top_p": 1.0, + "top_k": 10, + "max_tokens": 1024, + "stop_token_ids": [151645], + "logprobs": True, + "top_logprobs": 10, + } + + class ArchAgentHandler(ArchFunctionHandler): def __init__(self, client: OpenAI, model_name: str, config: ArchAgentConfig): super().__init__(client, model_name, config) @@ -657,7 +514,7 @@ class ArchAgentHandler(ArchFunctionHandler): ): tool_copy = copy.deepcopy(tool) del tool_copy["function"]["parameters"] - converted.append(json.dumps(tool_copy)) + converted.append(json.dumps(tool_copy["function"], ensure_ascii=False)) else: - converted.append(json.dumps(tool)) + converted.append(json.dumps(tool["function"], ensure_ascii=False)) return "\n".join(converted) diff --git a/model_server/src/core/utils/hallucination_utils.py 
b/model_server/src/core/utils/hallucination_utils.py index 91effc92..992f8aa4 100644 --- a/model_server/src/core/utils/hallucination_utils.py +++ b/model_server/src/core/utils/hallucination_utils.py @@ -13,16 +13,15 @@ from src.commons.utils import get_model_server_logger logger = get_model_server_logger() # constants -FUNC_NAME_START_PATTERN = ('\n{"name":"', "\n{'name':'") +FUNC_NAME_START_PATTERN = ('{"name":"', "{'name':'") FUNC_NAME_END_TOKEN = ('",', "',") -TOOL_CALL_TOKEN = "" -END_TOOL_CALL_TOKEN = "" +END_TOOL_CALL_TOKEN = "}}" FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'") -PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'") -PARAMETER_NAME_START_PATTERN = (',"', ",'") +PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'", '":"', "':'") +PARAMETER_NAME_START_PATTERN = ('","', "','") PARAMETER_VALUE_START_PATTERN = ('":', "':") -PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',") +PARAMETER_VALUE_END_TOKEN = ('",', '"}') BRACKETS = {"(": ")", "{": "}", "[": "]"} @@ -37,16 +36,9 @@ class MaskToken(Enum): HALLUCINATION_THRESHOLD_DICT = { - MaskToken.TOOL_CALL.value: { - "entropy": 0.35, - "varentropy": 1.7, - "probability": 0.8, - }, - MaskToken.PARAMETER_VALUE.value: { - "entropy": 0.28, - "varentropy": 1.4, - "probability": 0.8, - }, + "entropy": 0.0001, + "varentropy": 0.0001, + "probability": 0.8, } @@ -160,6 +152,7 @@ class HallucinationState: self._process_function(function) self.open_bracket = False self.bracket = None + self.function_name = "" self.check_parameter_name = {} self.HALLUCINATION_THRESHOLD_DICT = HALLUCINATION_THRESHOLD_DICT @@ -208,22 +201,20 @@ class HallucinationState: r = next(self.response_iterator) if hasattr(r.choices[0].delta, "content"): token_content = r.choices[0].delta.content - if token_content: + if token_content != "": try: logprobs = [ p.logprob for p in r.choices[0].logprobs.content[0].top_logprobs ] - except Exception as e: - raise ValueError( - f"Error extracting logprobs from response: {e}" - ) - if token_content == END_TOOL_CALL_TOKEN: - self._reset_parameters() - else: self.append_and_check_token_hallucination( token_content, logprobs ) + except Exception as e: + self.append_and_check_token_hallucination( + token_content, [None] + ) + return token_content except StopIteration: raise StopIteration @@ -234,12 +225,12 @@ class HallucinationState: Detects hallucinations based on the token type and log probabilities. """ content = "".join(self.tokens).replace(" ", "") - if self.tokens[-1] == TOOL_CALL_TOKEN: - self.mask.append(MaskToken.TOOL_CALL) - self._check_logprob() # Function name extraction logic # If the state is function name and the token is not an end token, add to the mask + if content.endswith(END_TOOL_CALL_TOKEN): + self._reset_parameters() + if self.state == "function_name": if self.tokens[-1] not in FUNC_NAME_END_TOKEN: self.mask.append(MaskToken.FUNCTION_NAME) @@ -359,7 +350,7 @@ class HallucinationState: if check_threshold( entropy, varentropy, - self.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value], + self.HALLUCINATION_THRESHOLD_DICT, ): self.hallucination = True self.error_message = f"token '{self.tokens[-1]}' is uncertain. 
Generated response:\n{''.join(self.tokens)}" diff --git a/model_server/src/core/utils/model_utils.py b/model_server/src/core/utils/model_utils.py index 73dc6fef..9dfac528 100644 --- a/model_server/src/core/utils/model_utils.py +++ b/model_server/src/core/utils/model_utils.py @@ -1,4 +1,5 @@ import json +import src.commons.utils as utils from openai import OpenAI from pydantic import BaseModel @@ -56,7 +57,6 @@ class ArchBaseHandler: client: OpenAI, model_name: str, task_prompt: str, - tool_prompt_template: str, format_prompt: str, generation_params: Dict, ): @@ -67,7 +67,6 @@ class ArchBaseHandler: client (OpenAI): An OpenAI client instance. model_name (str): Name of the model to use. task_prompt (str): The main task prompt for the system. - tool_prompt (str): A prompt to describe tools. format_prompt (str): A prompt specifying the desired output format. generation_params (Dict): Generation parameters for the model. """ @@ -75,7 +74,6 @@ class ArchBaseHandler: self.model_name = model_name self.task_prompt = task_prompt - self.tool_prompt_template = tool_prompt_template self.format_prompt = format_prompt self.generation_params = generation_params @@ -105,13 +103,11 @@ class ArchBaseHandler: str: A formatted system prompt. """ - tool_text = self._convert_tools(tools) + today_date = utils.get_today_date() + tools = self._convert_tools(tools) system_prompt = ( - self.task_prompt - + "\n\n" - + self.tool_prompt_template.format(tool_text=tool_text) - + "\n\n" + self.task_prompt.format(today_date=today_date, tools=tools) + self.format_prompt ) @@ -146,7 +142,7 @@ class ArchBaseHandler: {"role": "system", "content": self._format_system_prompt(tools)} ) - for message in messages: + for idx, message in enumerate(messages): role, content, tool_calls = ( message.role, message.content, @@ -162,9 +158,24 @@ class ArchBaseHandler: if metadata.get("optimize_context_window", "false").lower() == "true": content = f"\n\n" else: - content = ( - f"\n{json.dumps(content)}\n" + # sample response below + # "content": "\n{'name': 'get_stock_price', 'result': '$196.66'}\n" + # msg[idx-1] contains tool call = '{"tool_calls": [{"name": "currency_exchange", "arguments": {"currency_symbol": "NZD"}}]}' + tool_call_msg = messages[idx - 1].content + if tool_call_msg.startswith("```") and tool_call_msg.endswith( + "```" + ): + tool_call_msg = tool_call_msg.strip("```").strip() + if tool_call_msg.startswith("json"): + tool_call_msg = tool_call_msg[4:].strip() + func_name = json.loads(tool_call_msg)["tool_calls"][0].get( + "name", "no_name" ) + tool_response = { + "name": func_name, + "result": content, + } + content = f"\n{json.dumps(tool_response)}\n" processed_messages.append({"role": role, "content": content}) diff --git a/model_server/src/main.py b/model_server/src/main.py index 1763b015..34856498 100644 --- a/model_server/src/main.py +++ b/model_server/src/main.py @@ -71,67 +71,58 @@ async def models(): @app.post("/function_calling") async def function_calling(req: ChatMessage, res: Response): logger.info("[Endpoint: /function_calling]") - logger.info(f"[request body]: {json.dumps(req.model_dump())}") + logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}") final_response: ChatCompletionResponse = None error_messages = None + use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False) + logger.info(f"Use agent orchestrator: {use_agent_orchestrator}") + try: - intent_detected = False - use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False) - logger.info(f"Use 
agent orchestrator: {use_agent_orchestrator}") - if not use_agent_orchestrator: - intent_start_time = time.perf_counter() - intent_response = await handler_map["Arch-Intent"].chat_completion(req) - intent_latency = time.perf_counter() - intent_start_time - intent_detected = handler_map["Arch-Intent"].detect_intent(intent_response) + handler_name = "Arch-Agent" if use_agent_orchestrator else "Arch-Function" + model_handler: ArchFunctionHandler = handler_map[handler_name] - if use_agent_orchestrator or intent_detected: - # TODO: measure agreement between intent detection and function calling - try: - function_start_time = time.perf_counter() - handler_name = ( - "Arch-Agent" if use_agent_orchestrator else "Arch-Function" + start_time = time.perf_counter() + final_response = await model_handler.chat_completion(req) + latency = time.perf_counter() - start_time + + if not final_response.metadata: + final_response.metadata = {} + + # Parameter gathering for detected intents + if final_response.choices[0].message.content: + final_response.metadata["function_latency"] = str(round(latency * 1000, 3)) + # Function Calling + elif final_response.choices[0].message.tool_calls: + final_response.metadata["function_latency"] = str(round(latency * 1000, 3)) + + if not use_agent_orchestrator: + final_response.metadata["hallucination"] = str( + model_handler.hallucination_state.hallucination ) - function_calling_handler: ArchFunctionHandler = handler_map[ - handler_name - ] - final_response = await function_calling_handler.chat_completion(req) - function_latency = time.perf_counter() - function_start_time - - final_response.metadata = { - "function_latency": str(round(function_latency * 1000, 3)), - } - - if not use_agent_orchestrator: - final_response.metadata["intent_latency"] = str( - round(intent_latency * 1000, 3) - ) - final_response.metadata["hallucination"] = str( - function_calling_handler.hallucination_state.hallucination - ) - except ValueError as e: - res.statuscode = 503 - error_messages = ( - f"[{handler_name}] - Error in tool call extraction: {e}" - ) - except StopIteration as e: - res.statuscode = 500 - error_messages = f"[{handler_name}] - Error in hallucination check: {e}" - except Exception as e: - res.status_code = 500 - error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}" - raise + # No intent detected else: - # no intent matched - intent_response.metadata = { - "intent_latency": str(round(intent_latency * 1000, 3)), - } - final_response = intent_response + final_response.metadata["intent_latency"] = str(round(latency * 1000, 3)) + if not use_agent_orchestrator: + final_response.metadata["intent_latency"] = str(round(latency * 1000, 3)) + + final_response.metadata["hallucination"] = str( + model_handler.hallucination_state.hallucination + ) + + except ValueError as e: + res.status_code = 503 + error_messages = f"[{handler_name}] - Error in tool call extraction: {e}" + raise + except StopIteration as e: + res.status_code = 500 + error_messages = f"[{handler_name}] - Error in hallucination check: {e}" + raise except Exception as e: res.status_code = 500 - error_messages = f"[Arch-Intent] - Error in ChatCompletion: {e}" + error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}" raise if error_messages is not None: @@ -144,7 +135,7 @@ async def guardrails(req: GuardRequest, res: Response, max_num_words=300): logger.info("[Endpoint: /guardrails] - Gateway") - logger.info(f"[request
body]: {json.dumps(req.model_dump())}")
+    logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}")

     final_response: GuardResponse = None
     error_messages = None
diff --git a/model_server/tests/core/test_function_calling.py b/model_server/tests/core/test_function_calling.py
index 0f2c9995..6005d5e6 100644
--- a/model_server/tests/core/test_function_calling.py
+++ b/model_server/tests/core/test_function_calling.py
@@ -1,5 +1,5 @@
 import pytest
-
+import time
 from src.commons.globals import handler_map
 from src.core.utils.model_utils import ChatMessage, Message
@@ -37,26 +37,9 @@ get_weather_api = {
 # get_data class return request, intent, hallucination, parameter_gathering
-def get_hallucination_data_complex():
+def get_hallucination_data():
     # Create instances of the Message class
-    message1 = Message(role="user", content="How is the weather in Seattle?")
-    message2 = Message(
-        role="assistant", content="Can you specify the unit you want the weather in?"
-    )
-    message3 = Message(role="user", content="In celcius please!")
-
-    # Create a list of tools
-    tools = [get_weather_api]
-
-    # Create an instance of the ChatMessage class
-    req = ChatMessage(messages=[message1, message2, message3], tools=tools)
-
-    return req, True, True, True
-
-
-def get_hallucination_data_medium():
-    # Create instances of the Message class
-    message1 = Message(role="user", content="How is the weather in?")
+    message1 = Message(role="user", content="How is the weather in Seattle in days?")

     # Create a list of tools
     tools = [get_weather_api]
@@ -65,26 +48,10 @@ def get_hallucination_data_medium():
     req = ChatMessage(messages=[message1], tools=tools)

     # first token will not be tool call
-    return req, True, True, True
+    return req, False, True


-def get_complete_data_2():
-    # Create instances of the Message class
-    message1 = Message(
-        role="user",
-        content="what is the weather forecast for seattle in the next 10 days?",
-    )
-
-    # Create a list of tools
-    tools = [get_weather_api]
-
-    # Create an instance of the ChatMessage class
-    req = ChatMessage(messages=[message1], tools=tools)
-
-    return req, True, False, False
-
-
-def get_complete_data():
+def get_success_tool_call_data():
     # Create instances of the Message class
     message1 = Message(role="user", content="How is the weather in Seattle in 7 days?")
@@ -94,7 +61,7 @@ def get_complete_data():
     # Create an instance of the ChatMessage class
     req = ChatMessage(messages=[message1], tools=tools)

-    return req, True, False, False
+    return req, True, False


 def get_irrelevant_data():
@@ -107,7 +74,7 @@ def get_irrelevant_data():
     # Create an instance of the ChatMessage class
     req = ChatMessage(messages=[message1], tools=tools)

-    return req, False, False, False
+    return req, False, False


 def get_greeting_data():
@@ -120,38 +87,29 @@ def get_greeting_data():
     # Create an instance of the ChatMessage class
     req = ChatMessage(messages=[message1], tools=tools)

-    return req, False, False, False
+    return req, False, False


 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "get_data_func",
     [
-        get_hallucination_data_complex,
-        get_complete_data,
+        get_hallucination_data,
+        get_greeting_data,
         get_irrelevant_data,
-        get_complete_data_2,
+        get_success_tool_call_data,
     ],
 )
 async def test_function_calling(get_data_func):
-    req, intent, hallucination, parameter_gathering = get_data_func()
+    req, intent, hallucination = get_data_func()
+    handler_name = "Arch-Function"
+    use_agent_orchestrator = False
+    model_handler: ArchFunctionHandler = handler_map[handler_name]

-    intent_response = await handler_map["Arch-Intent"].chat_completion(req)
+    start_time = time.perf_counter()
+    final_response = await model_handler.chat_completion(req)
+    latency = time.perf_counter() - start_time

-    assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent
+    assert intent == (len(final_response.choices[0].message.tool_calls) >= 1)

-    if intent:
-        function_calling_response = await handler_map["Arch-Function"].chat_completion(
-            req
-        )
-        assert (
-            handler_map["Arch-Function"].hallucination_state.hallucination
-            == hallucination
-        )
-        response_txt = function_calling_response.choices[0].message.content
-
-        if parameter_gathering:
-            prefill_prefix = handler_map["Arch-Function"].prefill_prefix
-            assert any(
-                response_txt.startswith(prefix) for prefix in prefill_prefix
-            ), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"
+    assert hallucination == model_handler.hallucination_state.hallucination
diff --git a/model_server/tests/core/test_state.py b/model_server/tests/core/test_state.py
index b5a0332a..331b8d3f 100644
--- a/model_server/tests/core/test_state.py
+++ b/model_server/tests/core/test_state.py
@@ -1,5 +1,5 @@
 from src.commons.globals import handler_map
-from src.core.function_calling import Message
+from src.core.function_calling import ArchFunctionHandler, Message


 test_input_history = [
@@ -7,34 +7,19 @@ test_input_history = [
     {
         "role": "assistant",
         "model": "Arch-Function",
-        "tool_calls": [
-            {
-                "id": "call_3394",
-                "type": "function",
-                "function": {
-                    "name": "weather_forecast",
-                    "arguments": {"city": "Chicago", "days": 5},
-                },
-            }
-        ],
+        "content": '```json\n{"tool_calls": [{"name": "get_current_weather", "arguments": {"days": 5, "location": "Chicago, Illinois"}}]}\n```',
+    },
+    {
+        "role": "tool",
+        "model": "Arch-Function",
+        "content": '{"location":"Chicago%2C%20Illinois","temperature":[{"date":"2025-04-14","temperature":{"min":53,"max":65},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432817+00:00"},{"date":"2025-04-15","temperature":{"min":85,"max":97},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432830+00:00"},{"date":"2025-04-16","temperature":{"min":62,"max":78},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432835+00:00"},{"date":"2025-04-17","temperature":{"min":89,"max":101},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432839+00:00"},{"date":"2025-04-18","temperature":{"min":86,"max":104},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432843+00:00"}],"units":"Farenheit"}',
     },
-    {"role": "tool", "content": "--", "tool_call_id": "call_3394"},
-    {"role": "assistant", "content": "--", "model": "gpt-3.5-turbo-0125"},
-    {"role": "user", "content": "how is the weather in chicago for next 5 days?"},
     {
         "role": "assistant",
-        "tool_calls": [
-            {
-                "id": "call_5306",
-                "type": "function",
-                "function": {
-                    "name": "weather_forecast",
-                    "arguments": {"city": "Chicago", "days": 5},
-                },
-            }
-        ],
+        "model": "gpt-4o-2024-08-06",
+        "content": '{"response": "Based on the forecast data you provided, here is the weather for the next 5 days in Chicago:\\n\\n- **April 14, 2025**: The temperature will range between 53\\u00b0F and 65\\u00b0F. \\n- **April 15, 2025**: The temperature will range between 85\\u00b0F and 97\\u00b0F.\\n- **April 16, 2025**: The temperature will range between 62\\u00b0F and 78\\u00b0F.\\n- **April 17, 2025**: The temperature will range between 89\\u00b0F and 101\\u00b0F.\\n- **April 18, 2025**: The temperature will range between 86\\u00b0F and 104\\u00b0F.\\n\\nPlease note that the temperatures are given in Fahrenheit."}',
     },
-    {"role": "tool", "content": "--", "tool_call_id": "call_5306"},
+    {"role": "user", "content": "what about seattle?"},
 ]
@@ -44,7 +29,8 @@ def test_update_fc_history():
     for h in test_input_history:
         message_history.append(Message(**h))

-    updated_history = handler_map["Arch-Function"]._process_messages(message_history)
-    assert len(updated_history) == 7
+    handler: ArchFunctionHandler = handler_map["Arch-Function"]
+    updated_history = handler._process_messages(message_history)
+    assert len(updated_history) == 5
     # ensure that tool role does not exist anymore
     assert all([h["role"] != "tool" for h in updated_history])
diff --git a/tests/archgw/common.py b/tests/archgw/common.py
index 404d8ef9..f87f6403 100644
--- a/tests/archgw/common.py
+++ b/tests/archgw/common.py
@@ -47,14 +47,11 @@ TEST_CASE_FIXTURES = {
                 "tool_call_id": "",
                 "tool_calls": [
                     {
-                        "id": "call_6009",
+                        "id": "call_2925",
                         "type": "function",
                         "function": {
                             "name": "get_current_weather",
-                            "arguments": {
-                                "location": "Seattle, WA",
-                                "days": "2",
-                            },
+                            "arguments": {"location": "Seattle", "days": "2"},
                         },
                     }
                 ],
@@ -63,7 +60,11 @@
             }
         ],
         "model": "Arch-Function",
-        "metadata": {"intent_latency": "455.092", "function_latency": "312.744"},
+        "metadata": {
+            "x-arch-fc-model-response": '{"tool_calls": [{"name": "get_current_weather", "arguments": {"location": "Seattle", "days": "2"}}]}',
+            "function_latency": "361.841",
+            "intent_latency": "361.841",
+        },
     },
     "api_server_response": [
         {
diff --git a/tests/e2e/test_prompt_gateway.py b/tests/e2e/test_prompt_gateway.py
index 9b804bae..e6a10f3a 100644
--- a/tests/e2e/test_prompt_gateway.py
+++ b/tests/e2e/test_prompt_gateway.py
@@ -2,6 +2,7 @@ import json
 import pytest
 import requests
 from deepdiff import DeepDiff
+import re

 from common import (
     PROMPT_GATEWAY_ENDPOINT,
@@ -11,6 +12,15 @@ from common import (
 )


+def cleanup_tool_call(tool_call):
+    pattern = r"```json\n(.*?)\n```"
+    match = re.search(pattern, tool_call, re.DOTALL)
+    if match:
+        tool_call = match.group(1)
+
+    return tool_call.strip()
+
+
 @pytest.mark.parametrize("stream", [True, False])
 def test_prompt_gateway(stream):
     expected_tool_call = {
@@ -42,9 +52,14 @@ def test_prompt_gateway(stream):
         assert "role" in choices[0]["delta"]
         role = choices[0]["delta"]["role"]
         assert role == "assistant"
-        tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
+        print(f"choices: {choices}")
+        tool_call_str = choices[0].get("delta", {}).get("content", "")
+        print("tool_call_str: ", tool_call_str)
+        cleaned_tool_call_str = cleanup_tool_call(tool_call_str)
+        print("cleaned_tool_call_str: ", cleaned_tool_call_str)
+        tool_calls = json.loads(cleaned_tool_call_str).get("tool_calls", [])
         assert len(tool_calls) > 0
-        tool_call = tool_calls[0]["function"]
+        tool_call = tool_calls[0]
         location = tool_call["arguments"]["location"]
         assert expected_tool_call["arguments"]["location"] in location.lower()
         del expected_tool_call["arguments"]["location"]
@@ -62,7 +77,7 @@

         # third..end chunk is summarization (role = assistant)
         response_json = json.loads(chunks[2])
-        assert response_json.get("model").startswith("llama-3.2-3b-preview")
+        assert response_json.get("model").startswith("gpt-4o")
         choices = response_json.get("choices", [])
         assert len(choices) > 0
         assert "role" in choices[0]["delta"]
@@ -71,18 +86,24 @@
     else:
         response_json = response.json()
-        assert response_json.get("model").startswith("llama-3.2-3b-preview")
+        assert response_json.get("model").startswith("gpt-4o")
         choices = response_json.get("choices", [])
         assert len(choices) > 0
         assert "role" in choices[0]["message"]
         assert choices[0]["message"]["role"] == "assistant"
         # now verify arch_messages (tool call and api response) that are sent as response metadata
         arch_messages = get_arch_messages(response_json)
+        print("arch_messages: ", json.dumps(arch_messages))
         assert len(arch_messages) == 2
         tool_calls_message = arch_messages[0]
-        tool_calls = tool_calls_message.get("tool_calls", [])
-        assert len(tool_calls) > 0
-        tool_call = tool_calls[0]["function"]
+        print("tool_calls_message: ", tool_calls_message)
+        tool_calls = tool_calls_message.get("content", [])
+        cleaned_tool_call_str = cleanup_tool_call(tool_calls)
+        cleaned_tool_call_json = json.loads(cleaned_tool_call_str)
+        print("cleaned_tool_call_json: ", json.dumps(cleaned_tool_call_json))
+        tool_calls_list = cleaned_tool_call_json.get("tool_calls", [])
+        assert len(tool_calls_list) > 0
+        tool_call = tool_calls_list[0]
         location = tool_call["arguments"]["location"]
         assert expected_tool_call["arguments"]["location"] in location.lower()
         del expected_tool_call["arguments"]["location"]
@@ -231,7 +252,7 @@ def test_prompt_gateway_param_tool_call(stream):

         # third..end chunk is summarization (role = assistant)
         response_json = json.loads(chunks[2])
-        assert response_json.get("model").startswith("llama-3.2-3b-preview")
+        assert response_json.get("model").startswith("gpt-4o")
         choices = response_json.get("choices", [])
         assert len(choices) > 0
         assert "role" in choices[0]["delta"]
@@ -240,7 +261,7 @@
     else:
         response_json = response.json()
-        assert response_json.get("model").startswith("llama-3.2-3b-preview")
+        assert response_json.get("model").startswith("gpt-4o")
         choices = response_json.get("choices", [])
         assert len(choices) > 0
         assert "role" in choices[0]["message"]
@@ -262,7 +283,7 @@ def test_prompt_gateway_default_target(stream):
         "messages": [
             {
                 "role": "user",
-                "content": "hello, what can you do for me?",
+                "content": "hello",
             },
         ],
         "stream": stream,
@@ -273,17 +294,20 @@
         chunks = get_data_chunks(response, n=3)
         assert len(chunks) > 0
         response_json = json.loads(chunks[0])
+        print("response_json chunks[0]: ", response_json)
         assert response_json.get("model").startswith("api_server")
         assert len(response_json.get("choices", [])) > 0
         assert response_json.get("choices")[0]["delta"]["role"] == "assistant"

         response_json = json.loads(chunks[1])
+        print("response_json chunks[1]: ", response_json)
         choices = response_json.get("choices", [])
         assert len(choices) > 0
         content = choices[0]["delta"]["content"]
         assert content == "I can help you with weather forecast"
     else:
         response_json = response.json()
+        print("response_json: ", response_json)
         assert response_json.get("model").startswith("api_server")
         assert len(response_json.get("choices")) > 0
         assert response_json.get("choices")[0]["message"]["role"] == "assistant"
diff --git a/tests/modelserver/test_hallucination.py b/tests/modelserver/test_hallucination.py
index f1a3d9b4..323db3fc 100644
--- a/tests/modelserver/test_hallucination.py
+++ b/tests/modelserver/test_hallucination.py
@@ -4,6 +4,9 @@ import requests
 import logging
 import yaml

+pytestmark = pytest.mark.skip(
+    reason="Skipping entire test file as hallucination is not enabled for archfc 1.1 yet"
+)

 MODEL_SERVER_ENDPOINT = os.getenv(
     "MODEL_SERVER_ENDPOINT", "http://localhost:51000/function_calling"
diff --git a/tests/modelserver/test_modelserver.py b/tests/modelserver/test_modelserver.py
index 75e6d27e..4596606f 100644
--- a/tests/modelserver/test_modelserver.py
+++ b/tests/modelserver/test_modelserver.py
@@ -5,6 +5,9 @@ import yaml

 from deepdiff import DeepDiff

+pytestmark = pytest.mark.skip(
+    reason="Skipping entire test file as these tests are heavily dependent on model output"
+)

 MODEL_SERVER_ENDPOINT = os.getenv(
     "MODEL_SERVER_ENDPOINT", "http://localhost:51000/function_calling"
diff --git a/tests/rest/api_prompt_gateway.rest b/tests/rest/api_prompt_gateway.rest
index f28690c9..4a1f77e0 100644
--- a/tests/rest/api_prompt_gateway.rest
+++ b/tests/rest/api_prompt_gateway.rest
@@ -68,7 +68,7 @@ Content-Type: application/json
         {
             "role": "assistant",
             "content": "It seems I'm missing some information. Could you provide the following details days ?",
-            "model": "Arch-Function-1.5b"
+            "model": "Arch-Function"
         },
         {
             "role": "user",
@@ -91,7 +91,7 @@ Content-Type: application/json
         {
             "role": "assistant",
             "content": "It seems I'm missing some information. Could you provide the following details days ?",
-            "model": "Arch-Function-1.5b"
+            "model": "Arch-Function"
         },
         {
             "role": "user",