2023-06-30 17:10:48 +08:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
"""
|
|
|
|
|
@Time : 2023/5/8 22:12
|
|
|
|
|
@Author : alexanderwu
|
|
|
|
|
@File : schema.py
|
2023-11-03 11:53:47 +08:00
|
|
|
@Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116:
|
|
|
|
|
Replanned the distribution of responsibilities and functional positioning of `Message` class attributes.
|
2023-11-27 16:15:55 +08:00
|
|
|
@Modified By: mashenquan, 2023/11/22.
|
|
|
|
|
1. Add `Document` and `Documents` for `FileRepository` in Section 2.2.3.4 of RFC 135.
|
|
|
|
|
2. Encapsulate commonly used key-value sets into pydantic structures to standardize and unify parameter passing
    between actions.
|
2023-11-29 10:14:04 +08:00
|
|
|
3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135.
|
2023-06-30 17:10:48 +08:00
|
|
|
"""
|
2023-07-22 11:28:22 +08:00
|
|
|
|
2023-06-30 17:10:48 +08:00
|
|
|
from __future__ import annotations
|
2023-07-22 11:28:22 +08:00
|
|
|
|
2023-11-01 20:08:58 +08:00
|
|
|
import asyncio
|
2023-10-31 15:23:37 +08:00
|
|
|
import json
|
2023-11-22 17:08:00 +08:00
|
|
|
import os.path
|
2023-11-29 10:14:04 +08:00
|
|
|
import uuid
|
2023-12-19 23:53:04 +08:00
|
|
|
from abc import ABC
|
2023-11-01 20:08:58 +08:00
|
|
|
from asyncio import Queue, QueueEmpty, wait_for
|
2023-10-31 15:23:37 +08:00
|
|
|
from json import JSONDecodeError
|
2023-11-28 18:16:50 +08:00
|
|
|
from pathlib import Path
|
2024-01-08 22:15:56 +08:00
|
|
|
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
|
2023-12-27 14:00:54 +08:00
|
|
|
|
|
|
|
|
from pydantic import (
|
|
|
|
|
BaseModel,
|
|
|
|
|
ConfigDict,
|
|
|
|
|
Field,
|
|
|
|
|
PrivateAttr,
|
|
|
|
|
field_serializer,
|
|
|
|
|
field_validator,
|
2024-01-08 22:15:56 +08:00
|
|
|
model_serializer,
|
|
|
|
|
model_validator,
|
2023-12-27 14:00:54 +08:00
|
|
|
)
|
2023-06-30 17:10:48 +08:00
|
|
|
|
2023-11-01 20:08:58 +08:00
|
|
|
from metagpt.const import (
|
|
|
|
|
MESSAGE_ROUTE_CAUSE_BY,
|
|
|
|
|
MESSAGE_ROUTE_FROM,
|
|
|
|
|
MESSAGE_ROUTE_TO,
|
2023-11-06 22:38:43 +08:00
|
|
|
MESSAGE_ROUTE_TO_ALL,
|
2023-11-28 18:16:50 +08:00
|
|
|
SYSTEM_DESIGN_FILE_REPO,
|
|
|
|
|
TASK_FILE_REPO,
|
2023-11-01 20:08:58 +08:00
|
|
|
)
|
2023-07-22 11:28:22 +08:00
|
|
|
from metagpt.logs import logger
|
2023-12-20 10:54:49 +08:00
|
|
|
from metagpt.utils.common import any_to_str, any_to_str_set, import_class
|
2023-12-19 16:16:52 +08:00
|
|
|
from metagpt.utils.exceptions import handle_exception
|
2023-12-21 10:48:46 +08:00
|
|
|
from metagpt.utils.serialize import (
|
|
|
|
|
actionoutout_schema_to_mapping,
|
|
|
|
|
actionoutput_mapping_to_str,
|
|
|
|
|
actionoutput_str_to_mapping,
|
|
|
|
|
)
|
2023-06-30 17:10:48 +08:00
|
|
|
|
|
|
|
|
|
2024-01-08 22:15:56 +08:00
|
|
|
class SerializationMixin(BaseModel, extra="forbid"):
    """Polymorphic serialization / deserialization mixin for pydantic models.

    Pydantic is not designed for polymorphism: if ``Engineer`` is a subclass of ``Role``, a value declared as
    ``Role`` would be (de)serialized as a plain ``Role``. To round-trip the concrete subclass, every model that
    inherits this mixin writes its fully-qualified class name into the dumped dict under ``__module_class_name``.
    A class declared with ``is_polymorphic_base=True`` uses that name to instantiate the registered subclass when
    validating.

    More details:
    - https://docs.pydantic.dev/latest/concepts/serialization/
    - https://github.com/pydantic/pydantic/discussions/7008 (discussion about avoiding `__get_pydantic_core_schema__`)
    """

    # Name-mangled to ``_SerializationMixin__is_polymorphic_base``; __init_subclass__ sets it on every subclass.
    __is_polymorphic_base = False
    # Registry shared by all subclasses: "<module>.<qualname>" -> class.
    __subclasses_map__ = {}

    @model_serializer(mode="wrap")
    def __serialize_with_class_type__(self, default_serializer) -> Any:
        """Serialize with pydantic's default serializer, then tag the result with the concrete class name."""
        ret = default_serializer(self)
        ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        return ret

    @model_validator(mode="wrap")
    @classmethod
    def __convert_to_real_type__(cls, value: Any, handler):
        """Validate ``value``, instantiating the subclass named by ``__module_class_name`` when required.

        :param value: Raw input; only dicts are inspected, anything else goes straight to ``handler``.
        :param handler: Pydantic's default validation handler for ``cls``.
        :raises ValueError: ``cls`` is a polymorphic base and the dict carries no class name.
        :raises TypeError: The named class is not registered, i.e. it has not been defined/imported yet.
        """
        if not isinstance(value, dict):
            return handler(value)

        # Work on a shallow copy. Popping the marker from the caller's dict would make a second validation of
        # the same dict against the polymorphic base fail with "Missing __module_class_name field".
        value = dict(value)
        # Remove the marker because extra keys are forbidden, yet Cat.model_validate(cat.model_dump()) must work.
        class_full_name = value.pop("__module_class_name", None)

        # If it's not the polymorphic base, construct via the default handler.
        if not cls.__is_polymorphic_base:
            if class_full_name is None or str(cls) == f"<class '{class_full_name}'>":
                return handler(value)
            # The dict names a different class than ``cls``: fall through and build the registered class.

        # Otherwise look up the correct polymorphic type and construct that instead.
        if class_full_name is None:
            raise ValueError("Missing __module_class_name field")

        class_type = cls.__subclasses_map__.get(class_full_name, None)
        if class_type is None:
            # TODO could try dynamic import
            raise TypeError(f"Trying to instantiate {class_full_name}, which has not yet been defined!")

        return class_type(**value)

    def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
        """Record the polymorphic-base flag on the new subclass and register it under its fully-qualified name."""
        cls.__is_polymorphic_base = is_polymorphic_base
        cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
        super().__init_subclass__(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
2023-12-25 22:39:03 +08:00
|
|
|
class SimpleMessage(BaseModel):
    """Minimal message consisting only of a ``content`` string and a ``role`` string."""

    content: str
    role: str
|
|
|
|
|
|
|
|
|
|
|
2023-11-22 17:08:00 +08:00
|
|
|
class Document(BaseModel):
    """
    A document located at ``root_path``/``filename`` whose text is ``content``.
    """

    root_path: str = ""
    filename: str = ""
    content: str = ""

    def get_meta(self) -> Document:
        """Return the location of this document without its text.

        :return: A new Document with the same ``root_path`` and ``filename`` and the default (empty) ``content``.
        """
        return Document(filename=self.filename, root_path=self.root_path)

    @property
    def root_relative_path(self):
        """Relative path of the document from the root of the git repository.

        :return: ``root_path`` and ``filename`` joined with ``os.path.join``.
        """
        root, name = self.root_path, self.filename
        return os.path.join(root, name)

    def __str__(self):
        # A document's text form is just its content.
        return self.content

    def __repr__(self):
        # Deliberately identical to __str__: the content itself.
        return self.content
|
|
|
|
|
|
2023-11-22 17:08:00 +08:00
|
|
|
|
|
|
|
|
class Documents(BaseModel):
    """Collection of documents keyed by name.

    Attributes:
        docs (Dict[str, Document]): Maps each document name to its Document instance; starts empty.
    """

    docs: Dict[str, Document] = Field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
2023-10-31 15:23:37 +08:00
|
|
|
class Message(BaseModel):
    """list[<role>: <content>]

    A message carrying ``content`` (optionally with structured ``instruct_content``) plus routing metadata.
    ``cause_by`` and ``sent_from`` are normalized to strings and ``send_to`` to a set of strings. This happens
    both in the field validators at construction time and in ``__setattr__`` on assignment.
    """

    id: str = Field(default="", validate_default=True)  # According to Section 2.2.3.1.1 of RFC 135
    content: str
    instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True)
    role: str = "user"  # system / user / assistant
    cause_by: str = Field(default="", validate_default=True)
    sent_from: str = Field(default="", validate_default=True)
    send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True)

    @field_validator("id", mode="before")
    @classmethod
    def check_id(cls, msg_id: str) -> str:
        """Keep a non-empty id; otherwise generate a random hex UUID."""
        return msg_id if msg_id else uuid.uuid4().hex

    @field_validator("instruct_content", mode="before")
    @classmethod
    def check_instruct_content(cls, ic: Any) -> Any:
        """Rebuild ``instruct_content`` from the dict form written by ``ser_instruct_content``.

        A dict with ``"class"`` and ``"mapping"`` is rebuilt through ``ActionNode.create_model_class``. A dict with
        ``"class"`` and ``"module"`` is rebuilt by importing that class. In both cases the class is then
        instantiated with ``ic["value"]``. Any other value is returned unchanged.

        :raises KeyError: A dict with ``"class"`` has neither ``"mapping"`` nor ``"module"``.
        """
        if ic and isinstance(ic, dict) and "class" in ic:
            if "mapping" in ic:
                # compatible with custom-defined ActionOutput
                mapping = actionoutput_str_to_mapping(ic["mapping"])
                actionnode_class = import_class("ActionNode", "metagpt.actions.action_node")  # avoid circular import
                ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
            elif "module" in ic:
                # subclasses of BaseModel
                ic_obj = import_class(ic["class"], ic["module"])
            else:
                raise KeyError("missing required key to init Message.instruct_content from dict")
            ic = ic_obj(**ic["value"])
        return ic

    @field_validator("cause_by", mode="before")
    @classmethod
    def check_cause_by(cls, cause_by: Any) -> str:
        """Convert ``cause_by`` to a string, defaulting to the ``UserRequirement`` action class when empty."""
        # Imported lazily, like ActionNode above, presumably to avoid a circular import.
        return any_to_str(cause_by if cause_by else import_class("UserRequirement", "metagpt.actions.add_requirement"))

    @field_validator("sent_from", mode="before")
    @classmethod
    def check_sent_from(cls, sent_from: Any) -> str:
        """Convert ``sent_from`` to a string; empty values become ``""``."""
        return any_to_str(sent_from if sent_from else "")

    @field_validator("send_to", mode="before")
    @classmethod
    def check_send_to(cls, send_to: Any) -> set:
        """Convert ``send_to`` to a set of strings; empty values become ``{MESSAGE_ROUTE_TO_ALL}``."""
        return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})

    @field_serializer("instruct_content", mode="plain")
    def ser_instruct_content(self, ic: BaseModel) -> Union[Dict[str, Any], None]:
        """Serialize ``instruct_content`` into a dict that ``check_instruct_content`` can turn back into a model.

        - Instances of classes from ``metagpt.actions.action_node`` (built dynamically by ActionNode) become
          ``{"class", "mapping", "value"}``, so the class can be re-created from its field mapping.
        - Any other BaseModel becomes ``{"class", "module", "value"}``, so its class can be re-imported.

        The return annotation is a dict because pydantic infers the serializer's return schema from it. Declaring
        ``str`` here made pydantic emit "unexpected value" serialization warnings for the returned dict.

        :return: The dict described above, or ``None`` when there is no instruct content.
        """
        ic_dict = None
        if ic:
            # compatible with custom-defined ActionOutput
            schema = ic.model_json_schema()
            ic_type = str(type(ic))
            if "<class 'metagpt.actions.action_node" in ic_type:
                # instruct_content from AutoNode.create_model_class, for now, it's single level structure.
                mapping = actionoutout_schema_to_mapping(schema)
                mapping = actionoutput_mapping_to_str(mapping)
                ic_dict = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
            else:
                # due to instruct_content can be assigned by subclasses of BaseModel
                ic_dict = {"class": schema["title"], "module": ic.__module__, "value": ic.model_dump()}
        return ic_dict

    def __init__(self, content: str = "", **data: Any):
        # Allows Message("text") positionally; content is forwarded to pydantic as a keyword field.
        data["content"] = data.get("content", content)
        super().__init__(**data)

    def __setattr__(self, key, val):
        """Override `@property.setter`, convert non-string parameters into string parameters.

        The keys named by MESSAGE_ROUTE_CAUSE_BY and MESSAGE_ROUTE_FROM are converted with ``any_to_str``. The key
        named by MESSAGE_ROUTE_TO is converted with ``any_to_str_set``. Other attributes are stored unchanged.
        """
        if key in (MESSAGE_ROUTE_CAUSE_BY, MESSAGE_ROUTE_FROM):
            new_val = any_to_str(val)
        elif key == MESSAGE_ROUTE_TO:
            new_val = any_to_str_set(val)
        else:
            new_val = val
        super().__setattr__(key, new_val)

    def __str__(self):
        # prefix = '-'.join([self.role, str(self.cause_by)])
        if self.instruct_content:
            return f"{self.role}: {self.instruct_content.model_dump()}"
        return f"{self.role}: {self.content}"

    def __repr__(self):
        return self.__str__()

    def to_dict(self) -> dict:
        """Return a dict containing `role` and `content` for the LLM call."""
        return {"role": self.role, "content": self.content}

    def dump(self) -> str:
        """Convert the object to a JSON string, omitting fields whose value is None."""
        return self.model_dump_json(exclude_none=True, warnings=False)

    @staticmethod
    @handle_exception(exception_type=JSONDecodeError, default_return=None)
    def load(val):
        """Convert the json string to object.

        :param val: JSON text, e.g. as produced by ``dump``.
        :return: The reconstructed Message, keeping the serialized ``id`` when it is non-empty, or ``None`` when
            ``val`` is not valid JSON.
        """
        try:
            m = json.loads(val)
            msg_id = m.pop("id", None)
            msg = Message(**m)
            if msg_id:
                msg.id = msg_id
            return msg
        except JSONDecodeError as err:
            logger.error(f"parse json failed: {val}, error:{err}")
            return None
|
2023-06-30 17:10:48 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class UserMessage(Message):
    """Message whose role is fixed to "user", to support OpenAI-style chat messages."""

    def __init__(self, content: str):
        fixed_role = "user"
        super().__init__(role=fixed_role, content=content)
|
2023-06-30 17:10:48 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class SystemMessage(Message):
    """Message whose role is fixed to "system", to support OpenAI-style chat messages."""

    def __init__(self, content: str):
        fixed_role = "system"
        super().__init__(role=fixed_role, content=content)
|
2023-06-30 17:10:48 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class AIMessage(Message):
    """Message whose role is fixed to "assistant", to support OpenAI-style chat messages."""

    def __init__(self, content: str):
        fixed_role = "assistant"
        super().__init__(role=fixed_role, content=content)
|
|
|
|
|
|
|
|
|
|
|
2023-12-19 14:22:52 +08:00
|
|
|
class MessageQueue(BaseModel):
    """Message queue which supports asynchronous updates.

    Messages are held in an ``asyncio.Queue`` stored as a pydantic private attribute. ``dump`` and ``load``
    convert the queue to and from a JSON list of serialized messages.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    _queue: Queue = PrivateAttr(default_factory=Queue)

    def pop(self) -> Message | None:
        """Pop one message from the queue without waiting.

        :return: The oldest queued item, or ``None`` if the queue is empty.
        """
        try:
            item = self._queue.get_nowait()
        except QueueEmpty:
            return None
        # Balance every removed item with task_done() (not only truthy ones), so Queue.join() can never wait on
        # an item that is already gone.
        self._queue.task_done()
        return item

    def pop_all(self) -> List[Message]:
        """Pop all messages from the queue.

        Stops at the first falsy result of ``pop`` (``None`` signals an empty queue).
        """
        ret = []
        while True:
            msg = self.pop()
            if not msg:
                break
            ret.append(msg)
        return ret

    def push(self, msg: Message):
        """Push a message into the queue."""
        self._queue.put_nowait(msg)

    def empty(self):
        """Return true if the queue is empty."""
        return self._queue.empty()

    async def dump(self) -> str:
        """Convert the `MessageQueue` object to a json string.

        Every queued item is drained, serialized with ``Message.dump``, and put back in its original order. The
        queue content is therefore unchanged afterwards, even if serialization raises.

        :return: A JSON list of serialized messages (``"[]"`` for an empty queue).
        """
        if self.empty():
            return "[]"

        lst = []
        msgs = []
        try:
            # Drain synchronously. Awaiting get() with a timeout always ended on that timeout, which stalled
            # every dump for a full second after the last item.
            while True:
                try:
                    item = self._queue.get_nowait()
                except QueueEmpty:
                    logger.debug("Queue is empty, exiting...")
                    break
                msgs.append(item)
                self._queue.task_done()
                if item is not None:
                    lst.append(item.dump())
        finally:
            # Restore the drained items in their original order.
            for m in msgs:
                self._queue.put_nowait(m)
        return json.dumps(lst, ensure_ascii=False)

    @staticmethod
    def load(data) -> "MessageQueue":
        """Convert the json string to the `MessageQueue` object.

        :param data: A JSON list of serialized messages, as produced by ``dump``.
        :return: A new queue holding every entry that could be parsed. Malformed entries are skipped. If
            ``data`` itself is not valid JSON, a warning is logged and the queue holds whatever was loaded before
            the error.
        """
        queue = MessageQueue()
        try:
            lst = json.loads(data)
            for i in lst:
                msg = Message.load(i)
                if msg is None:
                    # Message.load returns None for unparsable entries. Pushing that None would make pop_all()
                    # stop early and strand the messages queued after it.
                    continue
                queue.push(msg)
        except JSONDecodeError as e:
            logger.warning(f"JSON load failed: {data}, error:{e}")

        return queue
|
2023-11-23 17:49:38 +08:00
|
|
|
|
|
|
|
|
|
2023-12-19 16:16:52 +08:00
|
|
|
# Generic type variable bound to pydantic ``BaseModel`` (given as a forward-reference string).
T = TypeVar("T", bound="BaseModel")


class BaseContext(BaseModel, ABC):
    """Base class for contexts that can be rebuilt from a JSON object string via ``loads``."""

    @classmethod
    @handle_exception
    def loads(cls: Type[T], val: str) -> Optional[T]:
        """Parse ``val`` as JSON and pass the resulting mapping to ``cls`` as keyword arguments.

        Errors are handled by the ``handle_exception`` decorator, which is defined elsewhere. It presumably logs
        and returns None, matching the ``Optional`` return type; confirm against its implementation.
        """
        return cls(**json.loads(val))
|
2023-11-23 17:49:38 +08:00
|
|
|
|
|
|
|
|
|
2023-12-19 16:16:52 +08:00
|
|
|
class CodingContext(BaseContext):
    """Context for coding the file ``filename``.

    The optional documents presumably hold the system design, the task, and the code of that file (inferred
    from the field names).
    """

    filename: str
    design_doc: Optional[Document] = Field(default=None)
    task_doc: Optional[Document] = Field(default=None)
    code_doc: Optional[Document] = Field(default=None)
|
2023-11-23 22:41:44 +08:00
|
|
|
|
|
|
|
|
|
2023-12-19 16:16:52 +08:00
|
|
|
class TestingContext(BaseContext):
    """Context for testing the file ``filename``: its required code document and an optional test document."""

    filename: str
    code_doc: Document
    test_doc: Optional[Document] = Field(default=None)
|
2023-11-23 22:41:44 +08:00
|
|
|
|
|
|
|
|
|
2023-12-19 16:16:52 +08:00
|
|
|
class RunCodeContext(BaseContext):
    """Parameters for running code and/or its tests, plus slots for the output.

    Field meanings are inferred from their names. ``mode`` defaults to "script"; other accepted values are not
    visible here.
    """

    mode: str = Field(default="script")
    code: Optional[str] = Field(default=None)
    code_filename: str = Field(default="")
    test_code: Optional[str] = Field(default=None)
    test_filename: str = Field(default="")
    command: List[str] = Field(default_factory=list)
    working_directory: str = Field(default="")
    additional_python_paths: List[str] = Field(default_factory=list)
    output_filename: Optional[str] = Field(default=None)
    output: Optional[str] = Field(default=None)
|
2023-11-24 13:30:00 +08:00
|
|
|
|
2023-11-24 19:56:27 +08:00
|
|
|
|
2023-12-19 16:16:52 +08:00
|
|
|
class RunCodeResult(BaseContext):
    """Outcome of running code: a summary text plus the captured stdout and stderr text."""

    summary: str
    stdout: str
    stderr: str
|
|
|
|
|
|
2023-11-28 18:16:50 +08:00
|
|
|
|
|
|
|
|
class CodeSummarizeContext(BaseModel):
    """Inputs for summarizing code: a design file, a task file, and the code file names involved.

    Instances hash on ``(design_filename, task_filename)`` only.
    """

    design_filename: str = ""
    task_filename: str = ""
    codes_filenames: List[str] = Field(default_factory=list)
    reason: str = ""

    @staticmethod
    def loads(filenames: List) -> CodeSummarizeContext:
        """Build a context by classifying file names by their repository directory.

        A path under ``SYSTEM_DESIGN_FILE_REPO`` becomes ``design_filename``. Otherwise, a path under
        ``TASK_FILE_REPO`` becomes ``task_filename``. Any other path is ignored. When several paths match the
        same directory, the last one wins.

        :param filenames: File names (strings or path-like objects).
        :return: A new CodeSummarizeContext.
        """
        ctx = CodeSummarizeContext()
        for name in filenames:
            path = Path(name)
            if path.is_relative_to(SYSTEM_DESIGN_FILE_REPO):
                ctx.design_filename = str(name)
            elif path.is_relative_to(TASK_FILE_REPO):
                ctx.task_filename = str(name)
        return ctx

    def __hash__(self):
        key = (self.design_filename, self.task_filename)
        return hash(key)
|
2023-12-12 21:32:03 +08:00
|
|
|
|
|
|
|
|
|
2023-12-19 16:16:52 +08:00
|
|
|
class BugFixContext(BaseContext):
    """Context for a bug-fix action; records only the name of the file involved."""

    filename: str = Field(default="")
|
2024-01-02 23:09:09 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Models describing a Mermaid class-diagram view.
class ClassMeta(BaseModel):
    """Metadata shared by class-diagram elements (class, attribute, method)."""

    name: str = Field(default="")
    # ClassAttribute/ClassMethod render this as a trailing "*".
    abstraction: bool = Field(default=False)
    # ClassAttribute/ClassMethod render this as a trailing "$".
    static: bool = Field(default=False)
    # Emitted verbatim as a prefix (presumably Mermaid markers such as "+", "-", "#"; confirm with callers).
    visibility: str = Field(default="")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ClassAttribute(ClassMeta):
    """An attribute (or method argument) in a Mermaid class diagram."""

    value_type: str = ""
    default_value: str = ""

    def get_mermaid(self, align=1) -> str:
        """Render this attribute as a single Mermaid line.

        Format: ``<tabs><visibility>[<value_type> ]<name>[=<default>][*][$]``. Defaults of string types
        ("str", "string", "String") are wrapped in double quotes after any embedded double quotes are removed.

        :param align: Number of leading tab characters.
        :return: The rendered line, without a trailing newline.
        """
        content = "\t" * align + self.visibility
        if self.value_type:
            content += self.value_type + " "
        content += self.name
        if self.default_value:
            content += "="
            if self.value_type not in ["str", "string", "String"]:
                content += self.default_value
            else:
                content += '"' + self.default_value.replace('"', "") + '"'
        if self.abstraction:
            content += "*"
        if self.static:
            content += "$"
        return content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ClassMethod(ClassMeta):
    """A method in a Mermaid class diagram."""

    args: List[ClassAttribute] = Field(default_factory=list)
    return_type: str = ""

    def get_mermaid(self, align=1) -> str:
        """Render this method as a single Mermaid line.

        Format: ``<tabs><visibility><name>(<arg>,<arg>...)[:<return_type>][*][$]``. Each argument is rendered
        with ``ClassAttribute.get_mermaid(align=0)``.

        :param align: Number of leading tab characters.
        :return: The rendered line, without a trailing newline.
        """
        content = "\t" * align + self.visibility
        content += self.name + "(" + ",".join(v.get_mermaid(align=0) for v in self.args) + ")"
        if self.return_type:
            content += ":" + self.return_type
        if self.abstraction:
            content += "*"
        if self.static:
            content += "$"
        return content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ClassView(ClassMeta):
    """A class in a Mermaid class diagram, with its attributes and methods."""

    attributes: List[ClassAttribute] = Field(default_factory=list)
    methods: List[ClassMethod] = Field(default_factory=list)

    def get_mermaid(self, align=1) -> str:
        """Render this class as a Mermaid ``class Name{ ... }`` block.

        Attributes are emitted before methods, one per line, each indented one tab deeper than the class.

        :param align: Number of leading tab characters for the ``class`` line and the closing brace.
        :return: The rendered block, ending with a newline.
        """
        indent = "\t" * align
        member_lines = [m.get_mermaid(align=align + 1) + "\n" for m in [*self.attributes, *self.methods]]
        return indent + "class " + self.name + "{\n" + "".join(member_lines) + indent + "}\n"
|