From d1666c3307289edd0fcb53c8ba881574ee0dca19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Mon, 15 Jan 2024 21:17:01 +0800 Subject: [PATCH] update get_choice_function_arguments. --- metagpt/provider/openai_api.py | 71 +++++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 23 deletions(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 747e36480..66d215eda 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -9,6 +9,7 @@ @Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x. """ +import re import json from typing import AsyncIterator, Union @@ -37,6 +38,7 @@ from metagpt.utils.token_counter import ( count_string_tokens, get_max_completion_tokens, ) +from metagpt.utils.common import CodeParser def log_and_reraise(retry_state): @@ -147,10 +149,7 @@ class OpenAILLM(BaseLLM): def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict: """Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create""" if "tools" not in kwargs: - configs = { - "tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}], - "tool_choice": GENERAL_TOOL_CHOICE, - } + configs = {"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}]} kwargs.update(configs) return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs) @@ -161,23 +160,7 @@ class OpenAILLM(BaseLLM): self._update_costs(rsp.usage) return rsp - def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: - """convert messages to list[dict].""" - if isinstance(messages, list): - messages = [Message(content=msg) if isinstance(msg, str) else msg for msg in messages] - return [msg if isinstance(msg, dict) else msg.to_dict() for msg in messages] - - if isinstance(messages, Message): - messages = [messages.to_dict()] - elif isinstance(messages, str): - messages = [{"role": "user", "content": messages}] - else: - raise ValueError( - f"Only support messages type are: str, Message, list[dict], but got {type(messages).__name__}!" - ) - return messages - - async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict: + async def aask_code(self, messages: list[dict], **kwargs) -> dict: """Use function of tools to ask a code. Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create @@ -187,18 +170,60 @@ class OpenAILLM(BaseLLM): >>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} """ - messages = self._process_message(messages) rsp = await self._achat_completion_function(messages, **kwargs) return self.get_choice_function_arguments(rsp) + def _parse_arguments(self, arguments: str) -> dict: + """parse arguments in openai function call""" + if 'langugae' not in arguments and 'code' not in arguments: + logger.warning(f"Not found `code`, `language`, We assume it is pure code:\n {arguments}\n. ") + return {'language': 'python', 'code': arguments} + + # 匹配language + language_pattern = re.compile(r'[\"\']?language[\"\']?\s*:\s*["\']([^"\']+?)["\']', re.DOTALL) + language_match = language_pattern.search(arguments) + language_value = language_match.group(1) if language_match else None + + # 匹配code + code_pattern = r'(["\'`]{3}|["\'`])([\s\S]*?)\1' + try: + code_value = re.findall(code_pattern, arguments)[-1][-1] + except Exception as e: + logger.error(f"{e}, when re.findall({code_pattern}, {arguments})") + code_value = None + + if code_value is None: + raise ValueError(f"Parse code error for {arguments}") + # arguments只有code的情况 + return {'language': language_value, 'code': code_value} + @handle_exception def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict: """Required to provide the first function arguments of choice. + :param dict rsp: same as in self.get_choice_function(rsp) :return dict: return the first function arguments of choice, for example, {'language': 'python', 'code': "print('Hello, World!')"} """ - return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments) + message = rsp.choices[0].message + if ( + message.tool_calls is not None and + message.tool_calls[0].function is not None and + message.tool_calls[0].function.arguments is not None + ): + # reponse is code + try: + return json.loads(message.tool_calls[0].function.arguments, strict=False) + except json.decoder.JSONDecodeError as e: + logger.debug(f"Got JSONDecodeError for {message.tool_calls[0].function.arguments},\ + we will use RegExp to parse code, \n {e}") + return {'language': 'python', 'code': self._parse_arguments(message.tool_calls[0].function.arguments)} + elif message.tool_calls is None and message.content is not None: + # reponse is message + return {'language': 'markdown', 'code': self.get_choice_text(rsp)} + else: + logger.error(f"Failed to parse \n {rsp}\n") + raise Exception(f"Failed to parse \n {rsp}\n") def get_choice_text(self, rsp: ChatCompletion) -> str: """Required to provide the first text of choice"""