mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
update get_choice_function_arguments.
This commit is contained in:
parent
b430e2c88f
commit
d1666c3307
1 changed files with 48 additions and 23 deletions
|
|
@ -9,6 +9,7 @@
|
|||
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
from typing import AsyncIterator, Union
|
||||
|
||||
|
|
@ -37,6 +38,7 @@ from metagpt.utils.token_counter import (
|
|||
count_string_tokens,
|
||||
get_max_completion_tokens,
|
||||
)
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
|
||||
def log_and_reraise(retry_state):
|
||||
|
|
@ -147,10 +149,7 @@ class OpenAILLM(BaseLLM):
|
|||
def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict:
|
||||
"""Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create"""
|
||||
if "tools" not in kwargs:
|
||||
configs = {
|
||||
"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}],
|
||||
"tool_choice": GENERAL_TOOL_CHOICE,
|
||||
}
|
||||
configs = {"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}]}
|
||||
kwargs.update(configs)
|
||||
|
||||
return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)
|
||||
|
|
@ -161,23 +160,7 @@ class OpenAILLM(BaseLLM):
|
|||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
if isinstance(messages, list):
|
||||
messages = [Message(content=msg) if isinstance(msg, str) else msg for msg in messages]
|
||||
return [msg if isinstance(msg, dict) else msg.to_dict() for msg in messages]
|
||||
|
||||
if isinstance(messages, Message):
|
||||
messages = [messages.to_dict()]
|
||||
elif isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Only support messages type are: str, Message, list[dict], but got {type(messages).__name__}!"
|
||||
)
|
||||
return messages
|
||||
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
async def aask_code(self, messages: list[dict], **kwargs) -> dict:
|
||||
"""Use function of tools to ask a code.
|
||||
Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
||||
|
|
@ -187,18 +170,60 @@ class OpenAILLM(BaseLLM):
|
|||
>>> rsp = await llm.aask_code(msg)
|
||||
# -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
messages = self._process_message(messages)
|
||||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
def _parse_arguments(self, arguments: str) -> dict:
|
||||
"""parse arguments in openai function call"""
|
||||
if 'langugae' not in arguments and 'code' not in arguments:
|
||||
logger.warning(f"Not found `code`, `language`, We assume it is pure code:\n {arguments}\n. ")
|
||||
return {'language': 'python', 'code': arguments}
|
||||
|
||||
# 匹配language
|
||||
language_pattern = re.compile(r'[\"\']?language[\"\']?\s*:\s*["\']([^"\']+?)["\']', re.DOTALL)
|
||||
language_match = language_pattern.search(arguments)
|
||||
language_value = language_match.group(1) if language_match else None
|
||||
|
||||
# 匹配code
|
||||
code_pattern = r'(["\'`]{3}|["\'`])([\s\S]*?)\1'
|
||||
try:
|
||||
code_value = re.findall(code_pattern, arguments)[-1][-1]
|
||||
except Exception as e:
|
||||
logger.error(f"{e}, when re.findall({code_pattern}, {arguments})")
|
||||
code_value = None
|
||||
|
||||
if code_value is None:
|
||||
raise ValueError(f"Parse code error for {arguments}")
|
||||
# arguments只有code的情况
|
||||
return {'language': language_value, 'code': code_value}
|
||||
|
||||
@handle_exception
|
||||
def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
|
||||
"""Required to provide the first function arguments of choice.
|
||||
|
||||
:param dict rsp: same as in self.get_choice_function(rsp)
|
||||
:return dict: return the first function arguments of choice, for example,
|
||||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
|
||||
message = rsp.choices[0].message
|
||||
if (
|
||||
message.tool_calls is not None and
|
||||
message.tool_calls[0].function is not None and
|
||||
message.tool_calls[0].function.arguments is not None
|
||||
):
|
||||
# reponse is code
|
||||
try:
|
||||
return json.loads(message.tool_calls[0].function.arguments, strict=False)
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
logger.debug(f"Got JSONDecodeError for {message.tool_calls[0].function.arguments},\
|
||||
we will use RegExp to parse code, \n {e}")
|
||||
return {'language': 'python', 'code': self._parse_arguments(message.tool_calls[0].function.arguments)}
|
||||
elif message.tool_calls is None and message.content is not None:
|
||||
# reponse is message
|
||||
return {'language': 'markdown', 'code': self.get_choice_text(rsp)}
|
||||
else:
|
||||
logger.error(f"Failed to parse \n {rsp}\n")
|
||||
raise Exception(f"Failed to parse \n {rsp}\n")
|
||||
|
||||
def get_choice_text(self, rsp: ChatCompletion) -> str:
|
||||
"""Required to provide the first text of choice"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue