pre-commit.

This commit is contained in:
刘棒棒 2024-01-17 18:17:52 +08:00
parent ff10c9bdda
commit 20f31fa027

View file

@ -9,8 +9,8 @@
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
"""
import re
import json
import re
from typing import AsyncIterator, Union
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
@ -28,7 +28,7 @@ from tenacity import (
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.cost_manager import Costs
@ -38,7 +38,6 @@ from metagpt.utils.token_counter import (
count_string_tokens,
get_max_completion_tokens,
)
from metagpt.utils.common import CodeParser
def log_and_reraise(retry_state):
@ -166,7 +165,7 @@ class OpenAILLM(BaseLLM):
if isinstance(msg, str):
processed_messages.append({"role": "user", "content": msg})
elif isinstance(msg, dict):
assert set(msg.keys()) == set(['role', 'content'])
assert set(msg.keys()) == set(["role", "content"])
processed_messages.append(msg)
elif isinstance(msg, Message):
processed_messages.append(msg.to_dict())
@ -198,9 +197,9 @@ class OpenAILLM(BaseLLM):
def _parse_arguments(self, arguments: str) -> dict:
"""parse arguments in openai function call"""
if 'langugae' not in arguments and 'code' not in arguments:
if "langugae" not in arguments and "code" not in arguments:
logger.warning(f"Not found `code`, `language`, We assume it is pure code:\n {arguments}\n. ")
return {'language': 'python', 'code': arguments}
return {"language": "python", "code": arguments}
# 匹配language
language_pattern = re.compile(r'[\"\']?language[\"\']?\s*:\s*["\']([^"\']+?)["\']', re.DOTALL)
@ -218,7 +217,7 @@ class OpenAILLM(BaseLLM):
if code_value is None:
raise ValueError(f"Parse code error for {arguments}")
# arguments只有code的情况
return {'language': language_value, 'code': code_value}
return {"language": language_value, "code": code_value}
@handle_exception
def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
@ -230,20 +229,22 @@ class OpenAILLM(BaseLLM):
"""
message = rsp.choices[0].message
if (
message.tool_calls is not None and
message.tool_calls[0].function is not None and
message.tool_calls[0].function.arguments is not None
message.tool_calls is not None
and message.tool_calls[0].function is not None
and message.tool_calls[0].function.arguments is not None
):
# reponse is code
try:
return json.loads(message.tool_calls[0].function.arguments, strict=False)
except json.decoder.JSONDecodeError as e:
logger.debug(f"Got JSONDecodeError for {message.tool_calls[0].function.arguments},\
we will use RegExp to parse code, \n {e}")
return {'language': 'python', 'code': self._parse_arguments(message.tool_calls[0].function.arguments)}
logger.debug(
f"Got JSONDecodeError for {message.tool_calls[0].function.arguments},\
we will use RegExp to parse code, \n {e}"
)
return {"language": "python", "code": self._parse_arguments(message.tool_calls[0].function.arguments)}
elif message.tool_calls is None and message.content is not None:
# reponse is message
return {'language': 'markdown', 'code': self.get_choice_text(rsp)}
return {"language": "markdown", "code": self.get_choice_text(rsp)}
else:
logger.error(f"Failed to parse \n {rsp}\n")
raise Exception(f"Failed to parse \n {rsp}\n")