2023-06-30 17:10:48 +08:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
"""
|
|
|
|
|
@Time : 2023/5/8 22:12
|
|
|
|
|
@Author : alexanderwu
|
|
|
|
|
@File : schema.py
|
2023-11-03 11:53:47 +08:00
|
|
|
@Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116:
|
|
|
|
|
Replanned the distribution of responsibilities and functional positioning of `Message` class attributes.
|
2023-11-27 16:15:55 +08:00
|
|
|
@Modified By: mashenquan, 2023/11/22.
|
|
|
|
|
1. Add `Document` and `Documents` for `FileRepository` in Section 2.2.3.4 of RFC 135.
|
|
|
|
|
2. Encapsulate the common key-values set to pydantic structures to standardize and unify parameter passing
|
|
|
|
|
between actions.
|
2023-11-29 10:14:04 +08:00
|
|
|
3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135.
|
2023-06-30 17:10:48 +08:00
|
|
|
"""
|
2023-07-22 11:28:22 +08:00
|
|
|
|
2023-06-30 17:10:48 +08:00
|
|
|
from __future__ import annotations
|
2023-07-22 11:28:22 +08:00
|
|
|
|
2023-11-01 20:08:58 +08:00
|
|
|
import asyncio
|
2023-10-31 15:23:37 +08:00
|
|
|
import json
|
2023-11-22 17:08:00 +08:00
|
|
|
import os.path
|
2023-11-29 10:14:04 +08:00
|
|
|
import uuid
|
2023-12-19 23:53:04 +08:00
|
|
|
from abc import ABC
|
2023-11-01 20:08:58 +08:00
|
|
|
from asyncio import Queue, QueueEmpty, wait_for
|
2023-10-31 15:23:37 +08:00
|
|
|
from json import JSONDecodeError
|
2023-11-28 18:16:50 +08:00
|
|
|
from pathlib import Path
|
2024-01-15 16:37:42 +08:00
|
|
|
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
|
2023-12-27 14:00:54 +08:00
|
|
|
|
|
|
|
|
from pydantic import (
|
|
|
|
|
BaseModel,
|
|
|
|
|
ConfigDict,
|
|
|
|
|
Field,
|
|
|
|
|
PrivateAttr,
|
|
|
|
|
field_serializer,
|
|
|
|
|
field_validator,
|
2024-01-08 22:15:56 +08:00
|
|
|
model_serializer,
|
|
|
|
|
model_validator,
|
2023-12-27 14:00:54 +08:00
|
|
|
)
|
2023-06-30 17:10:48 +08:00
|
|
|
|
2023-11-01 20:08:58 +08:00
|
|
|
from metagpt.const import (
|
|
|
|
|
MESSAGE_ROUTE_CAUSE_BY,
|
|
|
|
|
MESSAGE_ROUTE_FROM,
|
|
|
|
|
MESSAGE_ROUTE_TO,
|
2023-11-06 22:38:43 +08:00
|
|
|
MESSAGE_ROUTE_TO_ALL,
|
2024-01-23 19:11:58 +08:00
|
|
|
PRDS_FILE_REPO,
|
2023-11-28 18:16:50 +08:00
|
|
|
SYSTEM_DESIGN_FILE_REPO,
|
|
|
|
|
TASK_FILE_REPO,
|
2023-11-01 20:08:58 +08:00
|
|
|
)
|
2023-07-22 11:28:22 +08:00
|
|
|
from metagpt.logs import logger
|
2024-01-26 19:39:06 +08:00
|
|
|
from metagpt.repo_parser import DotClassInfo
|
2023-12-20 10:54:49 +08:00
|
|
|
from metagpt.utils.common import any_to_str, any_to_str_set, import_class
|
2023-12-19 16:16:52 +08:00
|
|
|
from metagpt.utils.exceptions import handle_exception
|
2023-12-21 10:48:46 +08:00
|
|
|
from metagpt.utils.serialize import (
|
|
|
|
|
actionoutout_schema_to_mapping,
|
|
|
|
|
actionoutput_mapping_to_str,
|
|
|
|
|
actionoutput_str_to_mapping,
|
|
|
|
|
)
|
2023-06-30 17:10:48 +08:00
|
|
|
|
|
|
|
|
|
2024-01-08 22:15:56 +08:00
|
|
|
class SerializationMixin(BaseModel, extra="forbid"):
    """
    PolyMorphic subclasses Serialization / Deserialization Mixin

    - First of all, we need to know that pydantic is not designed for polymorphism.
    - If Engineer is subclass of Role, it would be serialized as Role. If we want to serialize it as Engineer, we need
      to add `class name` to Engineer. So we need Engineer inherit SerializationMixin.

    More details:
    - https://docs.pydantic.dev/latest/concepts/serialization/
    - https://github.com/pydantic/pydantic/discussions/7008 discuss about avoid `__get_pydantic_core_schema__`
    """

    # NOTE: double-underscore names are name-mangled per class; each subclass that
    # passes `is_polymorphic_base=True` gets its own flag via `__init_subclass__`.
    __is_polymorphic_base = False
    # Registry of every subclass, keyed by "module.QualName"; shared across the hierarchy.
    __subclasses_map__ = {}

    @model_serializer(mode="wrap")
    def __serialize_with_class_type__(self, default_serializer) -> Any:
        """Serialize normally, then tag the payload with the concrete class path."""
        # default serializer, then append the `__module_class_name` field and return
        ret = default_serializer(self)
        ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        return ret

    @model_validator(mode="wrap")
    @classmethod
    def __convert_to_real_type__(cls, value: Any, handler):
        """Route a tagged dict to the concrete subclass recorded in `__module_class_name`.

        Non-dict input is delegated to the default handler unchanged.
        """
        if isinstance(value, dict) is False:
            return handler(value)

        # it is a dict so make sure to remove the __module_class_name
        # because we don't allow extra keywords but want to ensure
        # e.g Cat.model_validate(cat.model_dump()) works
        class_full_name = value.pop("__module_class_name", None)

        # if it's not the polymorphic base we construct via default handler
        if not cls.__is_polymorphic_base:
            if class_full_name is None:
                return handler(value)
            elif str(cls) == f"<class '{class_full_name}'>":
                # The tag matches this very class; no dispatch needed.
                return handler(value)
            else:
                # Tag names a different class: fall through to the registry lookup below.
                pass

        # otherwise we lookup the correct polymorphic type and construct that
        # instead
        if class_full_name is None:
            raise ValueError("Missing __module_class_name field")

        class_type = cls.__subclasses_map__.get(class_full_name, None)

        if class_type is None:
            # TODO could try dynamic import
            # BUGFIX: the message was a plain string missing the `f` prefix, so
            # `{class_full_name}` was emitted literally instead of interpolated.
            raise TypeError(f"Trying to instantiate {class_full_name}, which has not yet been defined!")

        return class_type(**value)

    def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
        """Register every subclass in the shared map so it can be dispatched to on load."""
        cls.__is_polymorphic_base = is_polymorphic_base
        cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
        super().__init_subclass__(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
2023-12-25 22:39:03 +08:00
|
|
|
class SimpleMessage(BaseModel):
    """A minimal chat message: plain content plus the speaking role."""

    content: str  # message text
    role: str  # speaker role, e.g. "user" / "assistant"
|
|
|
|
|
|
|
|
|
|
|
2023-11-22 17:08:00 +08:00
|
|
|
class Document(BaseModel):
    """A workspace document: its textual content plus where it lives (root_path / filename)."""

    root_path: str = ""
    filename: str = ""
    content: str = ""

    def get_meta(self) -> Document:
        """Return a content-free copy carrying only the location metadata.

        :return: A new Document instance with the same root path and filename.
        """
        meta = Document(root_path=self.root_path, filename=self.filename)
        return meta

    @property
    def root_relative_path(self):
        """Path of this document relative to the root of the git repository.

        :return: relative path from root of git repository.
        """
        return os.path.join(self.root_path, self.filename)

    def __str__(self):
        # Render the document as its raw content.
        return self.content

    def __repr__(self):
        # Same rendering as __str__ (content only), delegated for consistency.
        return self.__str__()
|
|
|
|
|
|
2023-11-22 17:08:00 +08:00
|
|
|
|
|
|
|
|
class Documents(BaseModel):
    """A class representing a collection of documents.

    Attributes:
        docs (Dict[str, Document]): A dictionary mapping document names to Document instances.
    """

    docs: Dict[str, Document] = Field(default_factory=dict)

    @classmethod
    def from_iterable(cls, documents: Iterable[Document]) -> Documents:
        """Create a Documents instance from a list of Document instances.

        Each document is keyed by its filename.

        :param documents: A list of Document instances.
        :return: A Documents instance.
        """
        return Documents(docs={document.filename: document for document in documents})

    def to_action_output(self) -> "ActionOutput":
        """Wrap this collection as an ActionOutput.

        :return: An ActionOutput whose content is this collection's json dump.
        """
        from metagpt.actions.action_output import ActionOutput  # local import avoids a circular dependency

        return ActionOutput(content=self.model_dump_json(), instruct_content=self)
|
|
|
|
|
|
2023-11-22 17:08:00 +08:00
|
|
|
|
2023-10-31 15:23:37 +08:00
|
|
|
class Message(BaseModel):
    """The routed message unit, rendered as `<role>: <content>`, with routing metadata.

    Routing fields (`cause_by`, `sent_from`, `send_to`) are normalized to strings /
    string sets by the validators below and by `__setattr__`.
    """

    id: str = Field(default="", validate_default=True)  # According to Section 2.2.3.1.1 of RFC 135
    content: str  # plain-text payload
    instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True)  # structured payload, if any
    role: str = "user"  # system / user / assistant
    cause_by: str = Field(default="", validate_default=True)  # str name of the Action that produced this message
    sent_from: str = Field(default="", validate_default=True)  # str name of the sender
    send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True)  # recipient labels

    @field_validator("id", mode="before")
    @classmethod
    def check_id(cls, id: str) -> str:
        # Assign a fresh uuid hex when no id was provided.
        return id if id else uuid.uuid4().hex

    @field_validator("instruct_content", mode="before")
    @classmethod
    def check_instruct_content(cls, ic: Any) -> BaseModel:
        """Rebuild `instruct_content` from the dict form produced by `ser_instruct_content`.

        Non-dict (or untagged) input is passed through unchanged.
        """
        if ic and isinstance(ic, dict) and "class" in ic:
            if "mapping" in ic:
                # compatible with custom-defined ActionOutput
                mapping = actionoutput_str_to_mapping(ic["mapping"])
                actionnode_class = import_class("ActionNode", "metagpt.actions.action_node")  # avoid circular import
                ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
            elif "module" in ic:
                # subclasses of BaseModel
                ic_obj = import_class(ic["class"], ic["module"])
            else:
                raise KeyError("missing required key to init Message.instruct_content from dict")
            ic = ic_obj(**ic["value"])
        return ic

    @field_validator("cause_by", mode="before")
    @classmethod
    def check_cause_by(cls, cause_by: Any) -> str:
        # Default to UserRequirement; imported lazily to avoid a circular import.
        return any_to_str(cause_by if cause_by else import_class("UserRequirement", "metagpt.actions.add_requirement"))

    @field_validator("sent_from", mode="before")
    @classmethod
    def check_sent_from(cls, sent_from: Any) -> str:
        # Normalize any sender object to its string name; empty stays "".
        return any_to_str(sent_from if sent_from else "")

    @field_validator("send_to", mode="before")
    @classmethod
    def check_send_to(cls, send_to: Any) -> set:
        # Normalize recipients to a set of strings; default is broadcast-to-all.
        return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})

    @field_serializer("send_to", mode="plain")
    def ser_send_to(self, send_to: set) -> list:
        # Sets are not JSON-serializable; emit a list instead.
        return list(send_to)

    @field_serializer("instruct_content", mode="plain")
    def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]:
        """Serialize `instruct_content` with enough metadata for `check_instruct_content` to rebuild it."""
        ic_dict = None
        if ic:
            # compatible with custom-defined ActionOutput
            schema = ic.model_json_schema()
            ic_type = str(type(ic))
            if "<class 'metagpt.actions.action_node" in ic_type:
                # instruct_content from AutoNode.create_model_class, for now, it's single level structure.
                mapping = actionoutout_schema_to_mapping(schema)
                mapping = actionoutput_mapping_to_str(mapping)

                ic_dict = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
            else:
                # due to instruct_content can be assigned by subclasses of BaseModel
                ic_dict = {"class": schema["title"], "module": ic.__module__, "value": ic.model_dump()}
        return ic_dict

    def __init__(self, content: str = "", **data: Any):
        # Allow positional `content` while still accepting it as a keyword.
        data["content"] = data.get("content", content)
        super().__init__(**data)

    def __setattr__(self, key, val):
        """Override `@property.setter`, convert non-string parameters into string parameters."""
        if key == MESSAGE_ROUTE_CAUSE_BY:
            new_val = any_to_str(val)
        elif key == MESSAGE_ROUTE_FROM:
            new_val = any_to_str(val)
        elif key == MESSAGE_ROUTE_TO:
            new_val = any_to_str_set(val)
        else:
            new_val = val
        super().__setattr__(key, new_val)

    def __str__(self):
        # Prefer the structured payload when present.
        if self.instruct_content:
            return f"{self.role}: {self.instruct_content.model_dump()}"
        return f"{self.role}: {self.content}"

    def __repr__(self):
        return self.__str__()

    def rag_key(self) -> str:
        """For search"""
        return self.content

    def to_dict(self) -> dict:
        """Return a dict containing `role` and `content` for the LLM call."""
        return {"role": self.role, "content": self.content}

    def dump(self) -> str:
        """Convert the object to json string"""
        return self.model_dump_json(exclude_none=True, warnings=False)

    @staticmethod
    @handle_exception(exception_type=JSONDecodeError, default_return=None)
    def load(val):
        """Convert the json string to object.

        Returns None on parse failure (handled both by the inner except and the
        `handle_exception` decorator — the decorator is redundant here but harmless).
        """
        try:
            m = json.loads(val)
            # `id` is popped before construction and reassigned afterwards,
            # presumably to bypass field validation on a persisted id — TODO confirm.
            id = m.get("id")
            if "id" in m:
                del m["id"]
            msg = Message(**m)
            if id:
                msg.id = id
            return msg
        except JSONDecodeError as err:
            logger.error(f"parse json failed: {val}, error:{err}")
        return None
|
2023-06-30 17:10:48 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class UserMessage(Message):
    """Facilitate support for OpenAI messages: a Message with role fixed to "user"."""

    def __init__(self, content: str):
        super().__init__(content=content, role="user")
|
2023-06-30 17:10:48 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class SystemMessage(Message):
    """Facilitate support for OpenAI messages: a Message with role fixed to "system"."""

    def __init__(self, content: str):
        super().__init__(content=content, role="system")
|
2023-06-30 17:10:48 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class AIMessage(Message):
    """Facilitate support for OpenAI messages: a Message with role fixed to "assistant"."""

    def __init__(self, content: str):
        super().__init__(content=content, role="assistant")
|
|
|
|
|
|
|
|
|
|
|
2023-11-23 21:59:25 +08:00
|
|
|
class Task(BaseModel):
    """A single unit of work in a Plan, tracking its code, result and status."""

    task_id: str = ""  # unique id within a Plan
    dependent_task_ids: list[str] = []  # Tasks prerequisite to this Task
    instruction: str = ""  # what this task should accomplish
    task_type: str = ""  # category of the task
    code: str = ""  # code produced/executed for this task
    result: str = ""  # textual outcome of execution
    is_success: bool = False  # whether the last execution succeeded
    is_finished: bool = False  # whether the task is done

    def reset(self):
        """Clear execution artifacts and mark the task unfinished."""
        self.code = ""
        self.result = ""
        self.is_success = False
        self.is_finished = False

    def update_task_result(self, task_result: TaskResult):
        """Copy an execution outcome into this task (does not change `is_finished`)."""
        self.code = task_result.code
        self.result = task_result.result
        self.is_success = task_result.is_success
|
|
|
|
|
|
2023-11-23 21:59:25 +08:00
|
|
|
|
2024-01-09 16:54:36 +08:00
|
|
|
class TaskResult(BaseModel):
    """Result of taking a task, with result and is_success required to be filled"""

    code: str = ""  # code that was executed, if any
    result: str  # textual outcome of the execution
    is_success: bool  # whether the execution succeeded
|
|
|
|
|
|
|
|
|
|
|
2023-11-23 21:59:25 +08:00
|
|
|
class Plan(BaseModel):
    """A goal plus a dependency-ordered task list, with a cursor on the current task."""

    goal: str  # overall objective of the plan
    context: str = ""  # supplementary context for the goal
    tasks: list[Task] = []  # linearized (topologically sorted) task sequence
    task_map: dict[str, Task] = {}  # task_id -> Task, for O(1) lookup
    current_task_id: str = ""  # id of first unfinished task; "" when all finished

    def _topological_sort(self, tasks: list[Task]):
        """Return `tasks` reordered so every task follows its dependencies (DFS post-order).

        NOTE(review): dependency cycles are not detected, and a dependency id that is
        not among `tasks` would raise KeyError at `task_map[task_id]` — confirm callers
        only pass self-contained task lists.
        """
        task_map = {task.task_id: task for task in tasks}
        dependencies = {task.task_id: set(task.dependent_task_ids) for task in tasks}
        sorted_tasks = []
        visited = set()

        def visit(task_id):
            # Depth-first visit: emit dependencies before the task itself.
            if task_id in visited:
                return
            visited.add(task_id)
            for dependent_id in dependencies.get(task_id, []):
                visit(dependent_id)
            sorted_tasks.append(task_map[task_id])

        for task in tasks:
            visit(task.task_id)

        return sorted_tasks

    def add_tasks(self, tasks: list[Task]):
        """
        Integrates new tasks into the existing plan, ensuring dependency order is maintained.

        This method performs two primary functions based on the current state of the task list:
        1. If there are no existing tasks, it topologically sorts the provided tasks to ensure
        correct execution order based on dependencies, and sets these as the current tasks.
        2. If there are existing tasks, it merges the new tasks with the existing ones. It maintains
        any common prefix of tasks (based on task_id and instruction) and appends the remainder
        of the new tasks. The current task is updated to the first unfinished task in this merged list.

        Args:
            tasks (list[Task]): A list of tasks (may be unordered) to add to the plan.

        Returns:
            None: The method updates the internal state of the plan but does not return anything.
        """
        if not tasks:
            return

        # Topologically sort the new tasks to ensure correct dependency order
        new_tasks = self._topological_sort(tasks)

        if not self.tasks:
            # If there are no existing tasks, set the new tasks as the current tasks
            self.tasks = new_tasks

        else:
            # Find the length of the common prefix between existing and new tasks
            prefix_length = 0
            for old_task, new_task in zip(self.tasks, new_tasks):
                if old_task.task_id != new_task.task_id or old_task.instruction != new_task.instruction:
                    break
                prefix_length += 1

            # Combine the common prefix with the remainder of the new tasks
            final_tasks = self.tasks[:prefix_length] + new_tasks[prefix_length:]
            self.tasks = final_tasks

        # Update current_task_id to the first unfinished task in the merged list
        self._update_current_task()

        # Update the task map for quick access to tasks by ID
        self.task_map = {task.task_id: task for task in self.tasks}

    def reset_task(self, task_id: str):
        """
        Clear code and result of the task based on task_id, and set the task as unfinished.

        Args:
            task_id (str): The ID of the task to be reset.

        Returns:
            None
        """
        # Unknown ids are silently ignored.
        if task_id in self.task_map:
            task = self.task_map[task_id]
            task.reset()

    def replace_task(self, new_task: Task):
        """
        Replace an existing task with the new input task based on task_id, and reset all tasks depending on it.

        Args:
            new_task (Task): The new task that will replace an existing one.

        Returns:
            None
        """
        # NOTE(review): `assert` is stripped under `python -O`; callers rely on the
        # AssertionError for a missing id, so it is kept as-is.
        assert new_task.task_id in self.task_map
        # Replace the task in the task map and the task list
        self.task_map[new_task.task_id] = new_task
        for i, task in enumerate(self.tasks):
            if task.task_id == new_task.task_id:
                self.tasks[i] = new_task
                break

        # Reset dependent tasks
        for task in self.tasks:
            if new_task.task_id in task.dependent_task_ids:
                self.reset_task(task.task_id)

    def append_task(self, new_task: Task):
        """
        Append a new task to the end of existing task sequences

        Args:
            new_task (Task): The new task to be appended to the existing task sequence

        Returns:
            None
        """
        assert not self.has_task_id(new_task.task_id), "Task already in current plan, use replace_task instead"

        assert all(
            [self.has_task_id(dep_id) for dep_id in new_task.dependent_task_ids]
        ), "New task has unknown dependencies"

        # Existing tasks do not depend on the new task, it's fine to put it to the end of the sorted task sequence
        self.tasks.append(new_task)
        self.task_map[new_task.task_id] = new_task
        self._update_current_task()

    def has_task_id(self, task_id: str) -> bool:
        """Return True if `task_id` is already part of this plan."""
        return task_id in self.task_map

    def _update_current_task(self):
        """Point `current_task_id` at the first unfinished task ("" if none remain)."""
        current_task_id = ""
        for task in self.tasks:
            if not task.is_finished:
                current_task_id = task.task_id
                break
        self.current_task_id = current_task_id  # all tasks finished

    @property
    def current_task(self) -> Task:
        """Find current task to execute

        Returns:
            Task: the current task to be executed
        """
        return self.task_map.get(self.current_task_id, None)

    def finish_current_task(self):
        """Finish current task, set Task.is_finished=True, set current task to next task"""
        if self.current_task_id:
            self.current_task.is_finished = True
            self._update_current_task()  # set to next task

    def get_finished_tasks(self) -> list[Task]:
        """return all finished tasks in correct linearized order

        Returns:
            list[Task]: list of finished tasks
        """
        return [task for task in self.tasks if task.is_finished]
|
|
|
|
|
|
|
|
|
|
|
2023-12-19 14:22:52 +08:00
|
|
|
class MessageQueue(BaseModel):
    """Message queue which supports asynchronous updates."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Underlying asyncio queue; private, so it is excluded from pydantic serialization.
    _queue: Queue = PrivateAttr(default_factory=Queue)

    def pop(self) -> Message | None:
        """Pop one message from the queue; return None if the queue is empty."""
        try:
            item = self._queue.get_nowait()
            if item:
                self._queue.task_done()
            return item
        except QueueEmpty:
            return None

    def pop_all(self) -> List[Message]:
        """Pop all messages from the queue."""
        ret = []
        while True:
            msg = self.pop()
            if not msg:
                break
            ret.append(msg)
        return ret

    def push(self, msg: Message):
        """Push a message into the queue."""
        self._queue.put_nowait(msg)

    def empty(self):
        """Return true if the queue is empty."""
        return self._queue.empty()

    async def dump(self) -> str:
        """Convert the `MessageQueue` object to a json string.

        The queue is drained to build the snapshot and then restored in the
        `finally` block, so dumping does not consume the queue.
        """
        if self.empty():
            return "[]"

        lst = []
        msgs = []
        try:
            while True:
                # 1-second timeout doubles as the "queue drained" signal.
                item = await wait_for(self._queue.get(), timeout=1.0)
                if item is None:
                    break
                msgs.append(item)
                lst.append(item.dump())
                self._queue.task_done()
        except asyncio.TimeoutError:
            logger.debug("Queue is empty, exiting...")
        finally:
            # Put the drained messages back in their original order.
            for m in msgs:
                self._queue.put_nowait(m)
        return json.dumps(lst, ensure_ascii=False)

    @staticmethod
    def load(data) -> "MessageQueue":
        """Convert the json string to the `MessageQueue` object."""
        queue = MessageQueue()
        try:
            lst = json.loads(data)
            for i in lst:
                msg = Message.load(i)
                if not msg:
                    # BUGFIX: `Message.load` returns None on a per-element parse
                    # failure; previously None was pushed into the queue.
                    continue
                queue.push(msg)
        except JSONDecodeError as e:
            logger.warning(f"JSON load failed: {data}, error:{e}")

        return queue
|
2023-11-23 17:49:38 +08:00
|
|
|
|
|
|
|
|
|
2023-12-19 16:16:52 +08:00
|
|
|
# Generic type variable bound to BaseModel, used to type `BaseContext.loads`.
T = TypeVar("T", bound="BaseModel")
|
|
|
|
|
|
|
|
|
|
|
2023-12-19 23:53:04 +08:00
|
|
|
class BaseContext(BaseModel, ABC):
    """Abstract base for action contexts, providing json deserialization."""

    @classmethod
    @handle_exception
    def loads(cls: Type[T], val: str) -> Optional[T]:
        """Deserialize a json string into an instance of `cls`.

        Returns None on failure via the `handle_exception` decorator.
        """
        i = json.loads(val)
        return cls(**i)
|
2023-11-23 17:49:38 +08:00
|
|
|
|
|
|
|
|
|
2023-12-19 16:16:52 +08:00
|
|
|
class CodingContext(BaseContext):
    """Context passed between coding-related actions."""

    filename: str  # target source file name
    design_doc: Optional[Document] = None  # system design document
    task_doc: Optional[Document] = None  # task breakdown document
    code_doc: Optional[Document] = None  # generated code document
    code_plan_and_change_doc: Optional[Document] = None  # incremental code plan & change document
|
2023-11-23 22:41:44 +08:00
|
|
|
|
|
|
|
|
|
2023-12-19 16:16:52 +08:00
|
|
|
class TestingContext(BaseContext):
    """Context passed between testing-related actions."""

    filename: str  # test file name
    code_doc: Document  # code under test
    test_doc: Optional[Document] = None  # generated test document
|
2023-11-23 22:41:44 +08:00
|
|
|
|
|
|
|
|
|
2023-12-19 16:16:52 +08:00
|
|
|
class RunCodeContext(BaseContext):
    """Context describing how to run a piece of code and where its output goes."""

    mode: str = "script"  # execution mode
    code: Optional[str] = None  # code to run
    code_filename: str = ""  # file the code is saved to
    test_code: Optional[str] = None  # test code to run, if any
    test_filename: str = ""  # file the test code is saved to
    command: List[str] = Field(default_factory=list)  # command line to execute
    working_directory: str = ""  # cwd for the command
    additional_python_paths: List[str] = Field(default_factory=list)  # extra PYTHONPATH entries
    output_filename: Optional[str] = None  # file to persist the run output
    output: Optional[str] = None  # captured run output
|
2023-11-24 13:30:00 +08:00
|
|
|
|
2023-11-24 19:56:27 +08:00
|
|
|
|
2023-12-19 16:16:52 +08:00
|
|
|
class RunCodeResult(BaseContext):
    """Outcome of a code run: summary plus captured streams."""

    summary: str  # human-readable summary of the run
    stdout: str  # captured standard output
    stderr: str  # captured standard error
|
|
|
|
|
|
2023-11-28 18:16:50 +08:00
|
|
|
|
|
|
|
|
class CodeSummarizeContext(BaseModel):
    """Context for summarizing code: the relevant design/task docs and code files."""

    design_filename: str = ""
    task_filename: str = ""
    codes_filenames: List[str] = Field(default_factory=list)
    reason: str = ""

    @staticmethod
    def loads(filenames: List) -> CodeSummarizeContext:
        """Build a context from filenames, routing each to its field by repo directory."""
        ctx = CodeSummarizeContext()
        for filename in filenames:
            path = Path(filename)
            if path.is_relative_to(SYSTEM_DESIGN_FILE_REPO):
                ctx.design_filename = str(filename)
            elif path.is_relative_to(TASK_FILE_REPO):
                ctx.task_filename = str(filename)
        return ctx

    def __hash__(self):
        # Hash only the identifying filenames so equal contexts bucket together.
        return hash((self.design_filename, self.task_filename))
|
2023-12-12 21:32:03 +08:00
|
|
|
|
|
|
|
|
|
2023-12-19 16:16:52 +08:00
|
|
|
class BugFixContext(BaseContext):
    """Context for a bug-fix action."""

    filename: str = ""  # file containing the bug to fix
|
2024-01-02 23:09:09 +08:00
|
|
|
|
|
|
|
|
|
2024-01-23 19:11:58 +08:00
|
|
|
class CodePlanAndChangeContext(BaseModel):
    """Context for incremental development: requirement/issue plus related doc filenames."""

    requirement: str = ""
    issue: str = ""
    prd_filename: str = ""
    design_filename: str = ""
    task_filename: str = ""

    @staticmethod
    def loads(filenames: List, **kwargs) -> CodePlanAndChangeContext:
        """Build a context from filenames, routing each to its field by repo directory."""
        ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", ""), issue=kwargs.get("issue", ""))
        for raw_name in filenames:
            path = Path(raw_name)
            if path.is_relative_to(PRDS_FILE_REPO):
                ctx.prd_filename = path.name
            elif path.is_relative_to(SYSTEM_DESIGN_FILE_REPO):
                ctx.design_filename = path.name
            elif path.is_relative_to(TASK_FILE_REPO):
                ctx.task_filename = path.name
        return ctx
|
2024-01-19 19:53:17 +08:00
|
|
|
|
|
|
|
|
|
2024-01-02 23:09:09 +08:00
|
|
|
# mermaid class view
|
2024-01-22 22:49:46 +08:00
|
|
|
class UMLClassMeta(BaseModel):
    """Base for UML class-diagram elements: a name plus a mermaid visibility mark.

    Attributes:
        name: Identifier of the element (class, attribute, or method name).
        visibility: Mermaid visibility prefix: "+", "-", or "#".
    """

    name: str = ""
    visibility: str = ""

    @staticmethod
    def name_to_visibility(name: str) -> str:
        """Map a Python identifier to a mermaid visibility marker.

        "__init__" and names without a leading underscore are public ("+");
        other double-underscore names are private ("-"); single-underscore
        names are protected ("#").
        """
        if name == "__init__" or not name.startswith("_"):
            return "+"
        return "-" if name.startswith("__") else "#"
|
|
|
|
|
|
2024-01-02 23:09:09 +08:00
|
|
|
|
2024-01-22 22:49:46 +08:00
|
|
|
class UMLClassAttribute(UMLClassMeta):
    """UML attribute, renderable as one mermaid classDiagram member line.

    Attributes:
        value_type: Declared type of the attribute ("" when untyped).
        default_value: Default value as text ("" when there is none).
    """

    value_type: str = ""
    default_value: str = ""

    def get_mermaid(self, align=1) -> str:
        """Render this attribute as a mermaid class-member line.

        Args:
            align: Indent level; emitted as that many leading tab characters.

        Returns:
            A line such as ``\\t+int count=0`` (visibility, type, name, default).
        """
        # "\t" * align replaces the original join-over-range idiom.
        content = "\t" * align + self.visibility
        if self.value_type:
            # Strip spaces so the type reads as a single mermaid token.
            content += self.value_type.replace(" ", "") + " "
        # Some upstream names arrive as "qualifier:name"; keep only the name part.
        name = self.name.split(":", 1)[1] if ":" in self.name else self.name
        content += name
        if self.default_value:
            content += "="
            if self.value_type not in ["str", "string", "String"]:
                content += self.default_value
            else:
                # String defaults are quoted; embedded quotes are dropped to
                # keep the mermaid syntax valid.
                content += '"' + self.default_value.replace('"', "") + '"'
        return content
|
|
|
|
|
|
|
|
|
|
|
2024-01-22 22:49:46 +08:00
|
|
|
class UMLClassMethod(UMLClassMeta):
    """UML method/operation, renderable as one mermaid classDiagram member line.

    Attributes:
        args: Formal parameters, rendered inline inside the parentheses.
        return_type: Return type ("" when absent).
    """

    args: List[UMLClassAttribute] = Field(default_factory=list)
    return_type: str = ""

    def get_mermaid(self, align=1) -> str:
        """Render this method as a mermaid class-member line.

        Args:
            align: Indent level; emitted as that many leading tab characters.

        Returns:
            A line such as ``\\t+run(x:int) str``.
        """
        # "\t" * align replaces the original join-over-range idiom.
        content = "\t" * align + self.visibility
        # Some upstream names arrive as "qualifier:name"; keep only the name part.
        name = self.name.split(":", 1)[1] if ":" in self.name else self.name
        content += name + "(" + ",".join(v.get_mermaid(align=0) for v in self.args) + ")"
        if self.return_type:
            # Strip spaces so the type reads as a single mermaid token.
            content += " " + self.return_type.replace(" ", "")
        return content
|
|
|
|
|
|
|
|
|
|
|
2024-01-22 22:49:46 +08:00
|
|
|
class UMLClassView(UMLClassMeta):
    """UML class with its attributes and methods, renderable as a mermaid class block.

    Attributes:
        attributes: The class's attributes, in declaration order.
        methods: The class's methods, in declaration order.
    """

    attributes: List[UMLClassAttribute] = Field(default_factory=list)
    methods: List[UMLClassMethod] = Field(default_factory=list)

    def get_mermaid(self, align=1) -> str:
        """Render the full ``class Name {\\n ... \\n}`` mermaid block.

        Args:
            align: Indent level of the block; members are indented one deeper.

        Returns:
            The multi-line mermaid classDiagram text for this class.
        """
        indent = "\t" * align  # replaces the original join-over-range idiom
        content = indent + "class " + self.name + "{\n"
        for member in self.attributes:
            content += member.get_mermaid(align=align + 1) + "\n"
        for member in self.methods:
            content += member.get_mermaid(align=align + 1) + "\n"
        content += indent + "}\n"
        return content

    @classmethod
    def load_dot_class_info(cls, dot_class_info: DotClassInfo) -> UMLClassView:
        """Build a UMLClassView from a parsed DotClassInfo class description.

        Visibility of the class and of each member is derived from its name via
        `name_to_visibility`; attribute/method types and defaults are copied over.
        """
        visibility = UMLClassView.name_to_visibility(dot_class_info.name)
        class_view = cls(name=dot_class_info.name, visibility=visibility)
        for attr_info in dot_class_info.attributes.values():
            attr = UMLClassAttribute(
                name=attr_info.name,
                visibility=UMLClassAttribute.name_to_visibility(attr_info.name),
                value_type=attr_info.type_,
                default_value=attr_info.default_,
            )
            class_view.attributes.append(attr)
        for method_info in dot_class_info.methods.values():
            # return_type is set once here; the original redundantly re-assigned
            # the same value after the args loop.
            method = UMLClassMethod(
                name=method_info.name,
                visibility=UMLClassMethod.name_to_visibility(method_info.name),
                return_type=method_info.return_args.type_,
            )
            for arg_info in method_info.args:
                method.args.append(
                    UMLClassAttribute(name=arg_info.name, value_type=arg_info.type_, default_value=arg_info.default_)
                )
            class_view.methods.append(method)
        return class_view
|