refine code: use handle_exception function instead of in-function duplicate code frags

This commit is contained in:
geekan 2023-12-19 16:16:52 +08:00
parent d3c135edff
commit f1c6a7ebfb
12 changed files with 159 additions and 130 deletions

View file

@ -21,7 +21,7 @@ import uuid
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Dict, List, Optional, Set, TypedDict
from typing import Dict, List, Optional, Set, Type, TypedDict, TypeVar
from pydantic import BaseModel, Field
@ -36,6 +36,7 @@ from metagpt.const import (
)
from metagpt.logs import logger
from metagpt.utils.common import any_to_str, any_to_str_set
from metagpt.utils.exceptions import handle_exception
class RawMessage(TypedDict):
@ -160,14 +161,11 @@ class Message(BaseModel):
return self.json(exclude_none=True)
@staticmethod
@handle_exception(exception_type=JSONDecodeError, default_return=None)
def load(val):
"""Convert the json string to object."""
try:
d = json.loads(val)
return Message(**d)
except JSONDecodeError as err:
logger.error(f"parse json failed: {val}, error:{err}")
return None
d = json.loads(val)
return Message(**d)
class UserMessage(Message):
@ -249,50 +247,46 @@ class MessageQueue:
return json.dumps(lst)
@staticmethod
def load(self, v) -> "MessageQueue":
def load(i) -> "MessageQueue":
"""Convert the json string to the `MessageQueue` object."""
q = MessageQueue()
queue = MessageQueue()
try:
lst = json.loads(v)
lst = json.loads(i)
for i in lst:
msg = Message(**i)
q.push(msg)
queue.push(msg)
except JSONDecodeError as e:
logger.warning(f"JSON load failed: {v}, error:{e}")
logger.warning(f"JSON load failed: {i}, error:{e}")
return q
return queue
class CodingContext(BaseModel):
# 定义一个泛型类型变量
T = TypeVar("T", bound="BaseModel")
class BaseContext(BaseModel):
@staticmethod
@handle_exception
def loads(val: str, cls: Type[T]) -> Optional[T]:
m = json.loads(val)
return cls(**m)
class CodingContext(BaseContext):
filename: str
design_doc: Optional[Document]
task_doc: Optional[Document]
code_doc: Optional[Document]
@staticmethod
def loads(val: str) -> CodingContext | None:
try:
m = json.loads(val)
return CodingContext(**m)
except Exception:
return None
class TestingContext(BaseModel):
class TestingContext(BaseContext):
filename: str
code_doc: Document
test_doc: Optional[Document]
@staticmethod
def loads(val: str) -> TestingContext | None:
try:
m = json.loads(val)
return TestingContext(**m)
except Exception:
return None
class RunCodeContext(BaseModel):
class RunCodeContext(BaseContext):
mode: str = "script"
code: Optional[str]
code_filename: str = ""
@ -304,28 +298,12 @@ class RunCodeContext(BaseModel):
output_filename: Optional[str]
output: Optional[str]
@staticmethod
def loads(val: str) -> RunCodeContext | None:
try:
m = json.loads(val)
return RunCodeContext(**m)
except Exception:
return None
class RunCodeResult(BaseModel):
class RunCodeResult(BaseContext):
summary: str
stdout: str
stderr: str
@staticmethod
def loads(val: str) -> RunCodeResult | None:
try:
m = json.loads(val)
return RunCodeResult(**m)
except Exception:
return None
class CodeSummarizeContext(BaseModel):
design_filename: str = ""
@ -349,5 +327,5 @@ class CodeSummarizeContext(BaseModel):
return hash((self.design_filename, self.task_filename))
class BugFixContext(BaseModel):
class BugFixContext(BaseContext):
filename: str = ""