refine code: use handle_exception function instead of in-function duplicate code frags

geekan 2023-12-19 16:16:52 +08:00
parent d3c135edff
commit f1c6a7ebfb
12 changed files with 159 additions and 130 deletions
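Every file below follows the same pattern: an in-function try/except block is removed and the new handle_exception decorator from metagpt.utils.exceptions is applied instead. A minimal before/after sketch of that pattern (the parse_* functions are illustrative, not taken from this diff):

import json

from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception


# Before: every call site duplicates the same error handling.
def parse_before(raw: str) -> dict:
    try:
        return json.loads(raw)
    except Exception as e:
        logger.error(f"parse failed: {e}")
        return {}


# After: the decorator centralizes the logging and supplies the fallback value.
@handle_exception(exception_type=Exception, default_return={})
def parse_after(raw: str) -> dict:
    return json.loads(raw)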

View file

@@ -43,7 +43,7 @@ Fill in the above nodes based on the format example.
"""
def dict_to_markdown(d, prefix="-", postfix="\n"):
def dict_to_markdown(d, prefix="###", postfix="\n"):
markdown_str = ""
for key, value in d.items():
markdown_str += f"{prefix} {key}: {value}{postfix}"
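The only change here is the default prefix, from "-" to "###", so each key renders as a markdown heading rather than a list item. A quick illustrative call, assuming the function returns the accumulated string (the return statement sits outside the diff context):

print(dict_to_markdown({"Language": "en_us", "Programming Language": "Python"}))
# ### Language: en_us
# ### Programming Language: Python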

View file

@@ -16,13 +16,13 @@
class.
"""
import subprocess
import traceback
from typing import Tuple
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.schema import RunCodeResult
from metagpt.utils.exceptions import handle_exception
PROMPT_TEMPLATE = """
Role: You are a senior development and QA engineer; your role is to summarize the code running result.
@@ -78,15 +78,12 @@ class RunCode(Action):
super().__init__(name, context, llm)
@classmethod
@handle_exception
async def run_text(cls, code) -> Tuple[str, str]:
try:
# We will document_store the result in this dictionary
namespace = {}
exec(code, namespace)
return namespace.get("result", ""), ""
except Exception:
# If there is an error in the code, return the error message
return "", traceback.format_exc()
# We will document_store the result in this dictionary
namespace = {}
exec(code, namespace)
return namespace.get("result", ""), ""
@classmethod
async def run_script(cls, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]:
@@ -145,18 +142,17 @@ class RunCode(Action):
rsp = await self._aask(prompt)
return RunCodeResult(summary=rsp, stdout=outs, stderr=errs)
@staticmethod
@handle_exception(exception_type=subprocess.CalledProcessError)
def _install_via_subprocess(cmd, check, cwd, env):
return subprocess.run(cmd, check=check, cwd=cwd, env=env)
@staticmethod
def _install_dependencies(working_directory, env):
install_command = ["python", "-m", "pip", "install", "-r", "requirements.txt"]
logger.info(" ".join(install_command))
try:
subprocess.run(install_command, check=True, cwd=working_directory, env=env)
except subprocess.CalledProcessError as e:
logger.warning(f"{e}")
RunCode._install_via_subprocess(install_command, check=True, cwd=working_directory, env=env)
install_pytest_command = ["python", "-m", "pip", "install", "pytest"]
logger.info(" ".join(install_pytest_command))
try:
subprocess.run(install_pytest_command, check=True, cwd=working_directory, env=env)
except subprocess.CalledProcessError as e:
logger.warning(f"{e}")
RunCode._install_via_subprocess(install_pytest_command, check=True, cwd=working_directory, env=env)
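Both pip invocations now go through the new _install_via_subprocess helper, whose decorator only catches subprocess.CalledProcessError and logs it. An illustrative call shape (the working directory is a placeholder):

import os

RunCode._install_via_subprocess(
    ["python", "-m", "pip", "install", "pytest"],
    check=True,
    cwd="/path/to/workspace",  # illustrative working directory
    env=os.environ.copy(),
)
# A non-zero exit status raises CalledProcessError, which handle_exception
# logs; the helper then returns None instead of propagating the error.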

View file

@@ -137,6 +137,7 @@ class Config(metaclass=Singleton):
continue
configs.update(yaml_data)
OPTIONS.set(configs)
logger.info(f"Default OpenAI API Model: {self.openai_api_model}")
@staticmethod
def _get(*args, **kwargs):

View file

@@ -15,17 +15,17 @@ from pydantic import BaseModel, Field
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
class RepoParser(BaseModel):
base_directory: Path = Field(default=None)
def parse_file(self, file_path):
@classmethod
@handle_exception(exception_type=Exception, default_return=[])
def _parse_file(cls, file_path: Path) -> list:
"""Parse a Python file in the repository."""
try:
return ast.parse(file_path.read_text()).body
except:
return []
return ast.parse(file_path.read_text()).body
def extract_class_and_function_info(self, tree, file_path):
"""Extract class, function, and global variable information from the AST."""
@@ -52,7 +52,7 @@ class RepoParser(BaseModel):
files_classes = []
directory = self.base_directory
for path in directory.rglob("*.py"):
tree = self.parse_file(path)
tree = self._parse_file(path)
file_info = self.extract_class_and_function_info(tree, path)
files_classes.append(file_info)
@@ -90,5 +90,10 @@ def main():
logger.info(pformat(symbols))
def error():
"""raise Exception and logs it"""
RepoParser._parse_file(Path("test.py"))
if __name__ == "__main__":
main()
error()

View file

@@ -21,7 +21,7 @@ import uuid
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Dict, List, Optional, Set, TypedDict
from typing import Dict, List, Optional, Set, Type, TypedDict, TypeVar
from pydantic import BaseModel, Field
@@ -36,6 +36,7 @@ from metagpt.const import (
)
from metagpt.logs import logger
from metagpt.utils.common import any_to_str, any_to_str_set
from metagpt.utils.exceptions import handle_exception
class RawMessage(TypedDict):
@@ -160,14 +161,11 @@ class Message(BaseModel):
return self.json(exclude_none=True)
@staticmethod
@handle_exception(exception_type=JSONDecodeError, default_return=None)
def load(val):
"""Convert the json string to object."""
try:
d = json.loads(val)
return Message(**d)
except JSONDecodeError as err:
logger.error(f"parse json failed: {val}, error:{err}")
return None
d = json.loads(val)
return Message(**d)
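Message.load keeps its contract of returning None on malformed input; the decorator now does the JSONDecodeError handling and logging. A hedged sketch (the payloads are illustrative):

msg = Message.load('{"content": "hello"}')  # a Message instance
bad = Message.load("not json")              # None, the parse failure is logged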
class UserMessage(Message):
@@ -249,50 +247,46 @@ class MessageQueue:
return json.dumps(lst)
@staticmethod
def load(self, v) -> "MessageQueue":
def load(i) -> "MessageQueue":
"""Convert the json string to the `MessageQueue` object."""
q = MessageQueue()
queue = MessageQueue()
try:
lst = json.loads(v)
lst = json.loads(i)
for i in lst:
msg = Message(**i)
q.push(msg)
queue.push(msg)
except JSONDecodeError as e:
logger.warning(f"JSON load failed: {v}, error:{e}")
logger.warning(f"JSON load failed: {i}, error:{e}")
return q
return queue
class CodingContext(BaseModel):
# Define a generic type variable
T = TypeVar("T", bound="BaseModel")
class BaseContext(BaseModel):
@staticmethod
@handle_exception
def loads(val: str, cls: Type[T]) -> Optional[T]:
m = json.loads(val)
return cls(**m)
class CodingContext(BaseContext):
filename: str
design_doc: Optional[Document]
task_doc: Optional[Document]
code_doc: Optional[Document]
@staticmethod
def loads(val: str) -> CodingContext | None:
try:
m = json.loads(val)
return CodingContext(**m)
except Exception:
return None
class TestingContext(BaseModel):
class TestingContext(BaseContext):
filename: str
code_doc: Document
test_doc: Optional[Document]
@staticmethod
def loads(val: str) -> TestingContext | None:
try:
m = json.loads(val)
return TestingContext(**m)
except Exception:
return None
class RunCodeContext(BaseModel):
class RunCodeContext(BaseContext):
mode: str = "script"
code: Optional[str]
code_filename: str = ""
@@ -304,28 +298,12 @@ class RunCodeContext(BaseModel):
output_filename: Optional[str]
output: Optional[str]
@staticmethod
def loads(val: str) -> RunCodeContext | None:
try:
m = json.loads(val)
return RunCodeContext(**m)
except Exception:
return None
class RunCodeResult(BaseModel):
class RunCodeResult(BaseContext):
summary: str
stdout: str
stderr: str
@staticmethod
def loads(val: str) -> RunCodeResult | None:
try:
m = json.loads(val)
return RunCodeResult(**m)
except Exception:
return None
class CodeSummarizeContext(BaseModel):
design_filename: str = ""
@@ -349,5 +327,5 @@ class CodeSummarizeContext(BaseModel):
return hash((self.design_filename, self.task_filename))
class BugFixContext(BaseModel):
class BugFixContext(BaseContext):
filename: str = ""

View file

@@ -11,6 +11,8 @@ from typing import List
import meilisearch
from meilisearch.index import Index
from metagpt.utils.exceptions import handle_exception
class DataSource:
def __init__(self, name: str, url: str):
@@ -34,11 +36,7 @@ class MeilisearchEngine:
index.add_documents(documents)
self.set_index(index)
@handle_exception(exception_type=Exception, default_return=[])
def search(self, query):
try:
search_results = self._index.search(query)
return search_results["hits"]
except Exception as e:
# Handle MeiliSearch API errors
print(f"MeiliSearch API error: {e}")
return []
search_results = self._index.search(query)
return search_results["hits"]

View file

@@ -20,11 +20,13 @@ import re
import typing
from typing import List, Tuple, Union
import aiofiles
import loguru
from tenacity import RetryCallState, _utils
from metagpt.const import MESSAGE_ROUTE_TO_ALL
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
def check_cmd_exists(command) -> int:
@@ -399,3 +401,11 @@ def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.C
)
return log_it
@handle_exception
async def aread(file_path: str) -> str:
"""Read file asynchronously."""
async with aiofiles.open(str(file_path), mode="r") as reader:
content = await reader.read()
return content
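aread becomes the shared async file reader; with the bare @handle_exception it logs and returns None instead of raising when the read fails. A small hedged usage sketch (the filename is illustrative):

import asyncio

async def demo():
    content = await aread("requirements.txt")  # file content, or None if reading fails
    print(content)

asyncio.run(demo())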

View file

@@ -25,7 +25,7 @@ def py_make_scanner(context):
except IndexError:
raise StopIteration(idx) from None
if nextchar == '"' or nextchar == "'":
if nextchar in ("'", '"'):
if idx + 2 < len(string) and string[idx + 1] == nextchar and string[idx + 2] == nextchar:
# Handle the case where the next two characters are the same as nextchar
return parse_string(string, idx + 3, strict, delimiter=nextchar * 3) # triple quote

View file

@@ -15,7 +15,8 @@ from typing import Set
import aiofiles
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.utils.common import aread
from metagpt.utils.exceptions import handle_exception
class DependencyFile:
@@ -36,21 +37,14 @@ class DependencyFile:
"""Load dependencies from the file asynchronously."""
if not self._filename.exists():
return
try:
async with aiofiles.open(str(self._filename), mode="r") as reader:
data = await reader.read()
self._dependencies = json.loads(data)
except Exception as e:
logger.error(f"Failed to load {str(self._filename)}, error:{e}")
self._dependencies = await aread(self._filename)
@handle_exception
async def save(self):
"""Save dependencies to the file asynchronously."""
try:
data = json.dumps(self._dependencies)
async with aiofiles.open(str(self._filename), mode="w") as writer:
await writer.write(data)
except Exception as e:
logger.error(f"Failed to save {str(self._filename)}, error:{e}")
data = json.dumps(self._dependencies)
async with aiofiles.open(str(self._filename), mode="w") as writer:
await writer.write(data)
async def update(self, filename: Path | str, dependencies: Set[Path | str], persist=True):
"""Update dependencies for a file asynchronously.

View file

@@ -0,0 +1,59 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19 14:46
@Author : alexanderwu
@File : exceptions.py
"""
import asyncio
import functools
import traceback
from typing import Any, Callable, Tuple, Type, TypeVar, Union
from metagpt.logs import logger
ReturnType = TypeVar("ReturnType")
def handle_exception(
_func: Callable[..., ReturnType] = None,
*,
exception_type: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception,
default_return: Any = None,
) -> Callable[..., ReturnType]:
"""handle exception, return default value"""
def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
@functools.wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
try:
return await func(*args, **kwargs)
except exception_type as e:
logger.opt(depth=1).error(
f"Calling {func.__name__} with args: {args}, kwargs: {kwargs} failed: {e}, "
f"stack: {traceback.format_exc()}"
)
return default_return
@functools.wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
try:
return func(*args, **kwargs)
except exception_type as e:
logger.opt(depth=1).error(
f"Calling {func.__name__} with args: {args}, kwargs: {kwargs} failed: {e}, "
f"stack: {traceback.format_exc()}"
)
return default_return
if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
if _func is None:
return decorator
else:
return decorator(_func)
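The decorator supports both a bare form and a parameterized form, and it installs a sync or async wrapper depending on the wrapped function. A minimal usage sketch (the function names are illustrative):

@handle_exception
def risky() -> int:
    return 1 // 0  # ZeroDivisionError is logged; the call returns None

@handle_exception(exception_type=ZeroDivisionError, default_return=-1)
def risky_with_default() -> int:
    return 1 // 0  # logged; the call returns -1

@handle_exception
async def risky_async() -> int:
    raise RuntimeError("boom")  # logged; awaiting the call returns None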

View file

@@ -11,6 +11,7 @@ from pathlib import Path
import aiofiles
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
class File:
@@ -19,6 +20,7 @@ class File:
CHUNK_SIZE = 64 * 1024
@classmethod
@handle_exception
async def write(cls, root_path: Path, filename: str, content: bytes) -> Path:
"""Write the file content to the local specified path.
@@ -33,18 +35,15 @@ class File:
Raises:
Exception: If an unexpected error occurs during the file writing process.
"""
try:
root_path.mkdir(parents=True, exist_ok=True)
full_path = root_path / filename
async with aiofiles.open(full_path, mode="wb") as writer:
await writer.write(content)
logger.debug(f"Successfully write file: {full_path}")
return full_path
except Exception as e:
logger.error(f"Error writing file: {e}")
raise e
root_path.mkdir(parents=True, exist_ok=True)
full_path = root_path / filename
async with aiofiles.open(full_path, mode="wb") as writer:
await writer.write(content)
logger.debug(f"Successfully write file: {full_path}")
return full_path
@classmethod
@handle_exception
async def read(cls, file_path: Path, chunk_size: int = None) -> bytes:
"""Partitioning read the file content from the local specified path.
@@ -58,18 +57,14 @@ class File:
Raises:
Exception: If an unexpected error occurs during the file reading process.
"""
try:
chunk_size = chunk_size or cls.CHUNK_SIZE
async with aiofiles.open(file_path, mode="rb") as reader:
chunks = list()
while True:
chunk = await reader.read(chunk_size)
if not chunk:
break
chunks.append(chunk)
content = b"".join(chunks)
logger.debug(f"Successfully read file, the path of file: {file_path}")
return content
except Exception as e:
logger.error(f"Error reading file: {e}")
raise e
chunk_size = chunk_size or cls.CHUNK_SIZE
async with aiofiles.open(file_path, mode="rb") as reader:
chunks = list()
while True:
chunk = await reader.read(chunk_size)
if not chunk:
break
chunks.append(chunk)
content = b"".join(chunks)
logger.debug(f"Successfully read file, the path of file: {file_path}")
return content
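With the decorator in place, File.write and File.read log the failure and return None rather than re-raising. A hedged usage sketch (the paths and content are illustrative):

import asyncio
from pathlib import Path

async def demo():
    saved = await File.write(Path("/tmp/demo"), "hello.txt", b"hello")  # Path, or None on failure
    if saved:
        print(await File.read(saved))  # bytes, or None on failure

asyncio.run(demo())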

View file

@@ -19,6 +19,7 @@ import aiofiles
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.schema import Document
from metagpt.utils.common import aread
from metagpt.utils.json_to_markdown import json_to_markdown
@@ -97,15 +98,7 @@ class FileRepository:
path_name = self.workdir / filename
if not path_name.exists():
return None
try:
async with aiofiles.open(str(path_name), mode="r") as reader:
doc.content = await reader.read()
except FileNotFoundError as e:
logger.info(f"open {str(path_name)} failed:{e}")
return None
except Exception as e:
logger.info(f"open {str(path_name)} failed:{e}")
return None
doc.content = await aread(path_name)
return doc
async def get_all(self) -> List[Document]: