update get_choice_function_arguments.

This commit is contained in:
刘棒棒 2024-01-15 21:17:01 +08:00
parent b430e2c88f
commit d1666c3307

View file

@ -9,6 +9,7 @@
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
"""
import re
import json
from typing import AsyncIterator, Union
@ -37,6 +38,7 @@ from metagpt.utils.token_counter import (
count_string_tokens,
get_max_completion_tokens,
)
from metagpt.utils.common import CodeParser
def log_and_reraise(retry_state):
@ -147,10 +149,7 @@ class OpenAILLM(BaseLLM):
def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict:
"""Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create"""
if "tools" not in kwargs:
configs = {
"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}],
"tool_choice": GENERAL_TOOL_CHOICE,
}
configs = {"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}]}
kwargs.update(configs)
return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)
@ -161,23 +160,7 @@ class OpenAILLM(BaseLLM):
self._update_costs(rsp.usage)
return rsp
def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
"""convert messages to list[dict]."""
if isinstance(messages, list):
messages = [Message(content=msg) if isinstance(msg, str) else msg for msg in messages]
return [msg if isinstance(msg, dict) else msg.to_dict() for msg in messages]
if isinstance(messages, Message):
messages = [messages.to_dict()]
elif isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
else:
raise ValueError(
f"Only support messages type are: str, Message, list[dict], but got {type(messages).__name__}!"
)
return messages
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
async def aask_code(self, messages: list[dict], **kwargs) -> dict:
"""Use function of tools to ask a code.
Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create
@ -187,18 +170,60 @@ class OpenAILLM(BaseLLM):
>>> rsp = await llm.aask_code(msg)
# -> {'language': 'python', 'code': "print('Hello, World!')"}
"""
messages = self._process_message(messages)
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
def _parse_arguments(self, arguments: str) -> dict:
"""parse arguments in openai function call"""
if 'langugae' not in arguments and 'code' not in arguments:
logger.warning(f"Not found `code`, `language`, We assume it is pure code:\n {arguments}\n. ")
return {'language': 'python', 'code': arguments}
# 匹配language
language_pattern = re.compile(r'[\"\']?language[\"\']?\s*:\s*["\']([^"\']+?)["\']', re.DOTALL)
language_match = language_pattern.search(arguments)
language_value = language_match.group(1) if language_match else None
# 匹配code
code_pattern = r'(["\'`]{3}|["\'`])([\s\S]*?)\1'
try:
code_value = re.findall(code_pattern, arguments)[-1][-1]
except Exception as e:
logger.error(f"{e}, when re.findall({code_pattern}, {arguments})")
code_value = None
if code_value is None:
raise ValueError(f"Parse code error for {arguments}")
# arguments只有code的情况
return {'language': language_value, 'code': code_value}
@handle_exception
def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
"""Required to provide the first function arguments of choice.
:param dict rsp: same as in self.get_choice_function(rsp)
:return dict: return the first function arguments of choice, for example,
{'language': 'python', 'code': "print('Hello, World!')"}
"""
return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
message = rsp.choices[0].message
if (
message.tool_calls is not None and
message.tool_calls[0].function is not None and
message.tool_calls[0].function.arguments is not None
):
# reponse is code
try:
return json.loads(message.tool_calls[0].function.arguments, strict=False)
except json.decoder.JSONDecodeError as e:
logger.debug(f"Got JSONDecodeError for {message.tool_calls[0].function.arguments},\
we will use RegExp to parse code, \n {e}")
return {'language': 'python', 'code': self._parse_arguments(message.tool_calls[0].function.arguments)}
elif message.tool_calls is None and message.content is not None:
# reponse is message
return {'language': 'markdown', 'code': self.get_choice_text(rsp)}
else:
logger.error(f"Failed to parse \n {rsp}\n")
raise Exception(f"Failed to parse \n {rsp}\n")
def get_choice_text(self, rsp: ChatCompletion) -> str:
"""Required to provide the first text of choice"""