From 2fd8a5a06d21af46c4bf160e7353f090660735d1 Mon Sep 17 00:00:00 2001 From: Shuguang Chen <54548843+nehcgs@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:41:18 -0800 Subject: [PATCH 1/3] Update Arch-Guard and corresponding e2e test --- e2e_tests/api_model_server.rest | 14 ++++++++++++-- model_server/app/model_handler/guardrails.py | 6 +++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/e2e_tests/api_model_server.rest b/e2e_tests/api_model_server.rest index c5fa0850..74bda508 100644 --- a/e2e_tests/api_model_server.rest +++ b/e2e_tests/api_model_server.rest @@ -1,7 +1,8 @@ @model_server_endpoint = http://localhost:51000 @archfc_endpoint = https://api.fc.archgw.com -### talk to model_server for completion + +# talk to function calling endpoint POST {{model_server_endpoint}}/function_calling HTTP/1.1 Content-Type: application/json @@ -41,7 +42,6 @@ Content-Type: application/json } - # talk to Arch-Function directly for completion POST {{archfc_endpoint}}/v1/chat/completions HTTP/1.1 Content-Type: application/json @@ -59,3 +59,13 @@ Content-Type: application/json "continue_final_message": true, "add_generation_prompt": false } + + +# talk to guardrails endpoint +POST {{model_server_endpoint}}/guardrails HTTP/1.1 +Content-Type: application/json + +{ + "input": "how is the weather in seattle for next 10 days", + "task": "jailbreak" +} diff --git a/model_server/app/model_handler/guardrails.py b/model_server/app/model_handler/guardrails.py index f733552e..4f6eaf0e 100644 --- a/model_server/app/model_handler/guardrails.py +++ b/model_server/app/model_handler/guardrails.py @@ -105,7 +105,7 @@ class ArchGuardHanlder: sentence = None return GuardResponse( - prob=prob.item(), verdict=verdict, sentence=sentence, latency=latency + prob=[prob.item()], verdict=verdict, sentence=[sentence], latency=latency ) def predict(self, req: GuardRequest, max_num_words=300) -> GuardResponse: @@ -138,9 +138,9 @@ class ArchGuardHanlder: chunk_result = self._predict_text(req.task, chunk) if chunk_result.verdict: - prob.append(chunk_result.prob) + prob.append(chunk_result.prob[0]) verdict = True - sentence.append(chunk_result.sentence) + sentence.append(chunk_result.sentence[0]) latency += chunk_result.latency return GuardResponse( From 79eafc02413cc61ec3801e08bab3daae49fe0550 Mon Sep 17 00:00:00 2001 From: Shuguang Chen <54548843+nehcgs@users.noreply.github.com> Date: Fri, 6 Dec 2024 14:07:01 -0800 Subject: [PATCH 2/3] Update `ArchBaseHandler` --- model_server/app/model_handler/base_handler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/model_server/app/model_handler/base_handler.py b/model_server/app/model_handler/base_handler.py index f6b811da..c6b88749 100644 --- a/model_server/app/model_handler/base_handler.py +++ b/model_server/app/model_handler/base_handler.py @@ -106,7 +106,7 @@ class ArchBaseHandler: self, messages: List[Message], tools: List[Dict[str, Any]] = None, - extra_instructions: str = None, + extra_instruction: str = None, ): """ Processes a list of messages and formats them appropriately. @@ -114,7 +114,7 @@ class ArchBaseHandler: Args: messages (List[Message]): A list of message objects. tools (List[Dict[str, Any]], optional): A list of tools to include in the system prompt. - extra_instructions (str, optional): Additional instructions to append to the last user message. + extra_instruction (str, optional): Additional instructions to append to the last user message. Returns: List[Dict[str, Any]]: A list of processed message dictionaries. @@ -148,8 +148,8 @@ class ArchBaseHandler: assert processed_messages[-1]["role"] == "user" - if extra_instructions: - processed_messages[-1]["content"] += extra_instructions + if extra_instruction: + processed_messages[-1]["content"] += extra_instruction return processed_messages From afec644789080fdd8e6e54f37f2ee9915470c29d Mon Sep 17 00:00:00 2001 From: Shuguang Chen <54548843+nehcgs@users.noreply.github.com> Date: Fri, 6 Dec 2024 14:14:44 -0800 Subject: [PATCH 3/3] Update the logic of intent detection --- model_server/app/main.py | 2 ++ .../app/model_handler/function_calling.py | 28 +++++++++++-------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/model_server/app/main.py b/model_server/app/main.py index d615c086..c798bd8d 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -69,6 +69,8 @@ async def function_calling(req: ChatMessage, res: Response): # logger.error(f"Error in chat_completion from `Arch-Function`: {e}") res.status_code = 500 return {"error": f"[Arch-Function] - {e}"} + # [TODO] Review: define the behavior if `Arch-Intent` doesn't detect an intent + # else: except Exception as e: # [TODO] Review: update how to collect debugging outputs diff --git a/model_server/app/model_handler/function_calling.py b/model_server/app/model_handler/function_calling.py index 897c785a..78b058a3 100644 --- a/model_server/app/model_handler/function_calling.py +++ b/model_server/app/model_handler/function_calling.py @@ -97,20 +97,24 @@ class ArchIntentHandler(ArchBaseHandler): Currently only support vllm inference """ - messages = self._process_messages( - req.messages, req.tools, self.extra_instruction - ) + # In the case that no tools are available, simply return `No` to avoid making a call + if len(req.tools) == 0: + model_response = Message(content="No", tool_calls=[]) + else: + messages = self._process_messages( + req.messages, req.tools, self.extra_instruction + ) - model_response = self.client.chat.completions.create( - messages=messages, - model=self.model_name, - stream=False, - extra_body=self.generation_params, - ) + model_response = self.client.chat.completions.create( + messages=messages, + model=self.model_name, + stream=False, + extra_body=self.generation_params, + ) - model_response = Message( - content=model_response.choices[0].message.content, tool_calls=[] - ) + model_response = Message( + content=model_response.choices[0].message.content, tool_calls=[] + ) chat_completion_response = ChatCompletionResponse( choices=[Choice(message=model_response)], model=self.model_name