update ser&deser after env_refactor

This commit is contained in:
better629 2023-12-19 14:22:52 +08:00
parent 35ac28c30e
commit ebc4fe4b17
15 changed files with 152 additions and 200 deletions

View file

@ -13,6 +13,8 @@
3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135.
"""
from __future__ import annotations
import asyncio
import json
import os.path
@ -20,14 +22,9 @@ import uuid
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Dict, List, Optional, Set, TypedDict
from pydantic import BaseModel, Field
from dataclasses import dataclass, field
from typing import Type, TypedDict, Union, Optional
from typing import Dict, List, Set, TypedDict, Optional, Any
from pydantic import BaseModel, Field
from pydantic.main import ModelMetaclass
from metagpt.config import CONFIG
from metagpt.const import (
@ -39,15 +36,7 @@ from metagpt.const import (
TASK_FILE_REPO,
)
from metagpt.logs import logger
from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \
actionoutput_str_to_mapping
from metagpt.utils.utils import import_class
from metagpt.utils.common import any_to_str, any_to_str_set
# from metagpt.utils.serialize import actionoutout_schema_to_mapping
# from metagpt.actions.action_output import ActionOutput
# from metagpt.actions.action import Action
from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \
actionoutput_str_to_mapping
from metagpt.utils.utils import import_class
@ -58,7 +47,6 @@ class RawMessage(TypedDict):
role: str
class Document(BaseModel):
"""
Represents a document.
@ -68,7 +56,7 @@ class Document(BaseModel):
filename: str = ""
content: str = ""
def get_meta(self) -> "Document":
def get_meta(self) -> Document:
"""Get metadata of the document.
:return: A new Document instance with the same root path and filename.
@ -120,7 +108,6 @@ class Message(BaseModel):
def __init__(self, **kwargs):
instruct_content = kwargs.get("instruct_content", None)
cause_by = kwargs.get("cause_by", None)
if instruct_content and not isinstance(instruct_content, BaseModel):
ic = instruct_content
mapping = actionoutput_str_to_mapping(ic["mapping"])
@ -129,9 +116,11 @@ class Message(BaseModel):
ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping)
ic_new = ic_obj(**ic["value"])
kwargs["instruct_content"] = ic_new
if cause_by and not isinstance(cause_by, ModelMetaclass):
action_class = import_class("Action", "metagpt.actions.action")
kwargs["cause_by"] = action_class.deser_class(cause_by)
kwargs["id"] = uuid.uuid4().hex
kwargs["cause_by"] = any_to_str(kwargs.get("cause_by", ""))
kwargs["sent_from"] = any_to_str(kwargs.get("sent_from", ""))
kwargs["send_to"] = any_to_str_set(kwargs.get("send_to", {MESSAGE_ROUTE_TO_ALL}))
super(Message, self).__init__(**kwargs)
def __setattr__(self, key, val):
@ -156,9 +145,6 @@ class Message(BaseModel):
mapping = actionoutput_mapping_to_str(mapping)
obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
cb = self.cause_by
if cb:
obj_dict["cause_by"] = cb.ser_class()
return obj_dict
def __str__(self):
@ -214,11 +200,24 @@ class AIMessage(Message):
super().__init__(content=content, role="assistant")
class MessageQueue:
class MessageQueue(BaseModel):
"""Message queue which supports asynchronous updates."""
def __init__(self):
self._queue = Queue()
_queue: Queue = Field(default_factory=Queue)
_private_attributes = {
"_queue": Queue()
}
class Config:
arbitrary_types_allowed = True
def __init__(self, **kwargs: Any):
for key in self._private_attributes.keys():
if key in kwargs:
object.__setattr__(self, key, kwargs[key])
else:
object.__setattr__(self, key, self._private_attributes[key])
def pop(self) -> Message | None:
"""Pop one message from the queue."""
@ -266,7 +265,7 @@ class MessageQueue:
return json.dumps(lst)
@staticmethod
def load(self, v) -> "MessageQueue":
def load(self, v) -> MessageQueue:
"""Convert the json string to the `MessageQueue` object."""
q = MessageQueue()
try:
@ -287,7 +286,7 @@ class CodingContext(BaseModel):
code_doc: Optional[Document]
@staticmethod
def loads(val: str) -> "CodingContext" | None:
def loads(val: str) -> CodingContext | None:
try:
m = json.loads(val)
return CodingContext(**m)
@ -301,7 +300,7 @@ class TestingContext(BaseModel):
test_doc: Optional[Document]
@staticmethod
def loads(val: str) -> "TestingContext" | None:
def loads(val: str) -> TestingContext | None:
try:
m = json.loads(val)
return TestingContext(**m)
@ -322,7 +321,7 @@ class RunCodeContext(BaseModel):
output: Optional[str]
@staticmethod
def loads(val: str) -> "RunCodeContext" | None:
def loads(val: str) -> RunCodeContext | None:
try:
m = json.loads(val)
return RunCodeContext(**m)
@ -336,7 +335,7 @@ class RunCodeResult(BaseModel):
stderr: str
@staticmethod
def loads(val: str) -> "RunCodeResult" | None:
def loads(val: str) -> RunCodeResult | None:
try:
m = json.loads(val)
return RunCodeResult(**m)
@ -351,7 +350,7 @@ class CodeSummarizeContext(BaseModel):
reason: str = ""
@staticmethod
def loads(filenames: List) -> "CodeSummarizeContext":
def loads(filenames: List) -> CodeSummarizeContext:
ctx = CodeSummarizeContext()
for filename in filenames:
if Path(filename).is_relative_to(SYSTEM_DESIGN_FILE_REPO):