diff --git a/metagpt/schema.py b/metagpt/schema.py index e7b2e5ce9..071518d62 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -1,209 +1,787 @@ -"""RAG schemas.""" +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/8 22:12 +@Author : alexanderwu +@File : schema.py +@Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116: + Replanned the distribution of responsibilities and functional positioning of `Message` class attributes. +@Modified By: mashenquan, 2023/11/22. + 1. Add `Document` and `Documents` for `FileRepository` in Section 2.2.3.4 of RFC 135. + 2. Encapsulate the common key-values set to pydantic structures to standardize and unify parameter passing + between actions. + 3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135. +""" +from __future__ import annotations + +import asyncio +import json +import os.path +import uuid +from abc import ABC +from asyncio import Queue, QueueEmpty, wait_for +from json import JSONDecodeError from pathlib import Path -from typing import Any, ClassVar, Literal, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union -from chromadb.api.types import CollectionMetadata -from llama_index.core.embeddings import BaseEmbedding -from llama_index.core.indices.base import BaseIndex -from llama_index.core.schema import TextNode -from llama_index.core.vector_stores.types import VectorStoreQueryMode -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + field_serializer, + field_validator, + model_serializer, + model_validator, +) -from metagpt.config2 import config -from metagpt.configs.embedding_config import EmbeddingType -from metagpt.rag.interface import RAGObject +from metagpt.const import ( + MESSAGE_ROUTE_CAUSE_BY, + MESSAGE_ROUTE_FROM, + MESSAGE_ROUTE_TO, + MESSAGE_ROUTE_TO_ALL, + PRDS_FILE_REPO, + SYSTEM_DESIGN_FILE_REPO, + 
class SerializationMixin(BaseModel, extra="forbid"):
    """Mixin that lets polymorphic pydantic subclasses round-trip through dump/validate.

    pydantic is not designed for polymorphism: if ``Engineer`` is a subclass of
    ``Role``, it is serialized as ``Role``. This mixin appends the fully
    qualified class name to the serialized dict (key ``__module_class_name``)
    and uses it on validation to pick the class to instantiate.

    More details:
    - https://docs.pydantic.dev/latest/concepts/serialization/
    - https://github.com/pydantic/pydantic/discussions/7008 discuss about avoid `__get_pydantic_core_schema__`
    """

    __is_polymorphic_base = False
    # Registry shared by every subclass: "module.QualName" -> class object.
    __subclasses_map__ = {}

    @model_serializer(mode="wrap")
    def __serialize_with_class_type__(self, default_serializer) -> Any:
        """Run the default serializer, then tag the result with the concrete class name."""
        ret = default_serializer(self)
        ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        return ret

    @model_validator(mode="wrap")
    @classmethod
    def __convert_to_real_type__(cls, value: Any, handler):
        """Validate ``value``, dispatching to the class named by ``__module_class_name``.

        :param value: Raw input; anything that is not a dict goes straight to ``handler``.
        :param handler: The default pydantic validation handler for ``cls``.
        :raises ValueError: The tag is missing and ``cls`` is a polymorphic base.
        :raises TypeError: The tag names a class that is not in the registry.
        """
        if not isinstance(value, dict):
            return handler(value)

        # Work on a shallow copy so the caller's dict is not mutated. The tag
        # must be removed because extra keys are forbidden, and we still want
        # e.g. Cat.model_validate(cat.model_dump()) to work.
        value = dict(value)
        class_full_name = value.pop("__module_class_name", None)

        # If it's not the polymorphic base we construct via the default handler.
        if not cls.__is_polymorphic_base:
            if class_full_name is None:
                return handler(value)
            # str(cls) renders as "<class 'module.QualName'>".
            if str(cls) == f"<class '{class_full_name}'>":
                return handler(value)
            # Otherwise the tag names a different class: fall through to the
            # registry lookup below instead of failing here.

        # Look up the correct polymorphic type and construct that instead.
        if class_full_name is None:
            raise ValueError("Missing __module_class_name field")

        class_type = cls.__subclasses_map__.get(class_full_name, None)

        if class_type is None:
            # TODO could try dynamic import
            raise TypeError(f"Trying to instantiate {class_full_name}, which has not yet been defined!")

        return class_type(**value)

    def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
        """Record the polymorphic-base flag and register the subclass by qualified name."""
        cls.__is_polymorphic_base = is_polymorphic_base
        cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
        super().__init_subclass__(**kwargs)
PrivateAttr(default=True) - - -class ChromaRetrieverConfig(IndexRetrieverConfig): - """Config for Chroma-based retrievers.""" - - persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.") - collection_name: str = Field(default="metagpt", description="The name of the collection.") - metadata: Optional[CollectionMetadata] = Field( - default=None, description="Optional metadata to associate with the collection" - ) - - -class ElasticsearchStoreConfig(BaseModel): - index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.") - es_url: str = Field(default=None, description="Elasticsearch URL.") - es_cloud_id: str = Field(default=None, description="Elasticsearch cloud ID.") - es_api_key: str = Field(default=None, description="Elasticsearch API key.") - es_user: str = Field(default=None, description="Elasticsearch username.") - es_password: str = Field(default=None, description="Elasticsearch password.") - batch_size: int = Field(default=200, description="Batch size for bulk indexing.") - distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.") - - -class ElasticsearchRetrieverConfig(IndexRetrieverConfig): - """Config for Elasticsearch-based retrievers. Support both vector and text.""" - - store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.") - vector_store_query_mode: VectorStoreQueryMode = Field( - default=VectorStoreQueryMode.DEFAULT, description="default is vector query." - ) - - -class ElasticsearchKeywordRetrieverConfig(ElasticsearchRetrieverConfig): - """Config for Elasticsearch-based retrievers. Support text only.""" - - _no_embedding: bool = PrivateAttr(default=True) - vector_store_query_mode: Literal[VectorStoreQueryMode.TEXT_SEARCH] = Field( - default=VectorStoreQueryMode.TEXT_SEARCH, description="text query only." - ) - - -class BaseRankerConfig(BaseModel): - """Common config for rankers. 
class Document(BaseModel):
    """A document described by a root path, a filename and its text content."""

    root_path: str = ""
    filename: str = ""
    content: str = ""

    def get_meta(self) -> Document:
        """Return a new Document that keeps the location but carries no content.

        :return: A Document with the same root path and filename and the default (empty) content.
        """
        meta = Document(root_path=self.root_path, filename=self.filename)
        return meta

    @property
    def root_relative_path(self):
        """Join ``root_path`` and ``filename`` into a single path string.

        :return: ``os.path.join(root_path, filename)``.
        """
        location = (self.root_path, self.filename)
        return os.path.join(*location)

    def __str__(self):
        # The string form of a document is its raw content.
        return self.content

    def __repr__(self):
        # Deliberately identical to __str__: the content, not a constructor-like repr.
        return self.content
using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1.", - ) - - -class ColbertRerankConfig(BaseRankerConfig): - model: str = Field(default="colbert-ir/colbertv2.0", description="Colbert model name.") - device: str = Field(default="cpu", description="Device to use for sentence transformer.") - keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.") - - -class CohereRerankConfig(BaseRankerConfig): - model: str = Field(default="rerank-english-v3.0") - api_key: str = Field(default="YOUR_COHERE_API") - - -class BGERerankConfig(BaseRankerConfig): - model: str = Field(default="BAAI/bge-reranker-large", description="BAAI Reranker model name.") - use_fp16: bool = Field(default=True, description="Whether to use fp16 for inference.") - - -class ObjectRankerConfig(BaseRankerConfig): - field_name: str = Field(..., description="field name of the object, field's value must can be compared.") - order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.") - - -class BaseIndexConfig(BaseModel): - """Common config for index. - - If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index. + Attributes: + docs (Dict[str, Document]): A dictionary mapping document names to Document instances. """ - model_config = ConfigDict(arbitrary_types_allowed=True) - persist_path: Union[str, Path] = Field(description="The directory of saved data.") + docs: Dict[str, Document] = Field(default_factory=dict) + + @classmethod + def from_iterable(cls, documents: Iterable[Document]) -> Documents: + """Create a Documents instance from a list of Document instances. + + :param documents: A list of Document instances. + :return: A Documents instance. + """ + + docs = {doc.filename: doc for doc in documents} + return Documents(docs=docs) + + def to_action_output(self) -> "ActionOutput": + """Convert to action output string. 
+ + :return: A string representing action output. + """ + from metagpt.actions.action_output import ActionOutput + + return ActionOutput(content=self.model_dump_json(), instruct_content=self) -class VectorIndexConfig(BaseIndexConfig): - """Config for vector-based index.""" +class Message(BaseModel): + """list[: ]""" - embed_model: BaseEmbedding = Field(default=None, description="Embed model.") + id: str = Field(default="", validate_default=True) # According to Section 2.2.3.1.1 of RFC 135 + content: str + instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True) + role: str = "user" # system / user / assistant + cause_by: str = Field(default="", validate_default=True) + sent_from: str = Field(default="", validate_default=True) + send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True) + @field_validator("id", mode="before") + @classmethod + def check_id(cls, id: str) -> str: + return id if id else uuid.uuid4().hex -class FAISSIndexConfig(VectorIndexConfig): - """Config for faiss-based index.""" + @field_validator("instruct_content", mode="before") + @classmethod + def check_instruct_content(cls, ic: Any) -> BaseModel: + if ic and isinstance(ic, dict) and "class" in ic: + if "mapping" in ic: + # compatible with custom-defined ActionOutput + mapping = actionoutput_str_to_mapping(ic["mapping"]) + actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import + ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping) + elif "module" in ic: + # subclasses of BaseModel + ic_obj = import_class(ic["class"], ic["module"]) + else: + raise KeyError("missing required key to init Message.instruct_content from dict") + ic = ic_obj(**ic["value"]) + return ic + @field_validator("cause_by", mode="before") + @classmethod + def check_cause_by(cls, cause_by: Any) -> str: + return any_to_str(cause_by if cause_by else import_class("UserRequirement", 
"metagpt.actions.add_requirement")) -class ChromaIndexConfig(VectorIndexConfig): - """Config for chroma-based index.""" + @field_validator("sent_from", mode="before") + @classmethod + def check_sent_from(cls, sent_from: Any) -> str: + return any_to_str(sent_from if sent_from else "") - collection_name: str = Field(default="metagpt", description="The name of the collection.") - metadata: Optional[CollectionMetadata] = Field( - default=None, description="Optional metadata to associate with the collection" - ) + @field_validator("send_to", mode="before") + @classmethod + def check_send_to(cls, send_to: Any) -> set: + return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL}) + @field_serializer("send_to", mode="plain") + def ser_send_to(self, send_to: set) -> list: + return list(send_to) -class BM25IndexConfig(BaseIndexConfig): - """Config for bm25-based index.""" + @field_serializer("instruct_content", mode="plain") + def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]: + ic_dict = None + if ic: + # compatible with custom-defined ActionOutput + schema = ic.model_json_schema() + ic_type = str(type(ic)) + if " str: + """For search""" + return self.content - _no_embedding: bool = PrivateAttr(default=True) + def to_dict(self) -> dict: + """Return a dict containing `role` and `content` for the LLM call.l""" + return {"role": self.role, "content": self.content} - -class ObjectNodeMetadata(BaseModel): - """Metadata of ObjectNode.""" - - is_obj: bool = Field(default=True) - obj: Any = Field(default=None, description="When rag retrieve, will reconstruct obj from obj_json") - obj_json: str = Field(..., description="The json of object, e.g. obj.model_dump_json()") - obj_cls_name: str = Field(..., description="The class name of object, e.g. obj.__class__.__name__") - obj_mod_name: str = Field(..., description="The module name of class, e.g. 
obj.__class__.__module__") - - -class ObjectNode(TextNode): - """RAG add object.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys()) - self.excluded_embed_metadata_keys = self.excluded_llm_metadata_keys + def dump(self) -> str: + """Convert the object to json string""" + return self.model_dump_json(exclude_none=True, warnings=False) @staticmethod - def get_obj_metadata(obj: RAGObject) -> dict: - metadata = ObjectNodeMetadata( - obj_json=obj.model_dump_json(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__ - ) + @handle_exception(exception_type=JSONDecodeError, default_return=None) + def load(val): + """Convert the json string to object.""" - return metadata.model_dump() + try: + m = json.loads(val) + id = m.get("id") + if "id" in m: + del m["id"] + msg = Message(**m) + if id: + msg.id = id + return msg + except JSONDecodeError as err: + logger.error(f"parse json failed: {val}, error:{err}") + return None + + +class UserMessage(Message): + """便于支持OpenAI的消息 + Facilitate support for OpenAI messages + """ + + def __init__(self, content: str): + super().__init__(content=content, role="user") + + +class SystemMessage(Message): + """便于支持OpenAI的消息 + Facilitate support for OpenAI messages + """ + + def __init__(self, content: str): + super().__init__(content=content, role="system") + + +class AIMessage(Message): + """便于支持OpenAI的消息 + Facilitate support for OpenAI messages + """ + + def __init__(self, content: str): + super().__init__(content=content, role="assistant") + + +class Task(BaseModel): + task_id: str = "" + dependent_task_ids: list[str] = [] # Tasks prerequisite to this Task + instruction: str = "" + task_type: str = "" + code: str = "" + result: str = "" + is_success: bool = False + is_finished: bool = False + + def reset(self): + self.code = "" + self.result = "" + self.is_success = False + self.is_finished = False + + def 
update_task_result(self, task_result: TaskResult): + self.code = task_result.code + self.result = task_result.result + self.is_success = task_result.is_success + + +class TaskResult(BaseModel): + """Result of taking a task, with result and is_success required to be filled""" + + code: str = "" + result: str + is_success: bool + + +class Plan(BaseModel): + goal: str + context: str = "" + tasks: list[Task] = [] + task_map: dict[str, Task] = {} + current_task_id: str = "" + + def _topological_sort(self, tasks: list[Task]): + task_map = {task.task_id: task for task in tasks} + dependencies = {task.task_id: set(task.dependent_task_ids) for task in tasks} + sorted_tasks = [] + visited = set() + + def visit(task_id): + if task_id in visited: + return + visited.add(task_id) + for dependent_id in dependencies.get(task_id, []): + visit(dependent_id) + sorted_tasks.append(task_map[task_id]) + + for task in tasks: + visit(task.task_id) + + return sorted_tasks + + def add_tasks(self, tasks: list[Task]): + """ + Integrates new tasks into the existing plan, ensuring dependency order is maintained. + + This method performs two primary functions based on the current state of the task list: + 1. If there are no existing tasks, it topologically sorts the provided tasks to ensure + correct execution order based on dependencies, and sets these as the current tasks. + 2. If there are existing tasks, it merges the new tasks with the existing ones. It maintains + any common prefix of tasks (based on task_id and instruction) and appends the remainder + of the new tasks. The current task is updated to the first unfinished task in this merged list. + + Args: + tasks (list[Task]): A list of tasks (may be unordered) to add to the plan. + + Returns: + None: The method updates the internal state of the plan but does not return anything. 
+ """ + if not tasks: + return + + # Topologically sort the new tasks to ensure correct dependency order + new_tasks = self._topological_sort(tasks) + + if not self.tasks: + # If there are no existing tasks, set the new tasks as the current tasks + self.tasks = new_tasks + + else: + # Find the length of the common prefix between existing and new tasks + prefix_length = 0 + for old_task, new_task in zip(self.tasks, new_tasks): + if old_task.task_id != new_task.task_id or old_task.instruction != new_task.instruction: + break + prefix_length += 1 + + # Combine the common prefix with the remainder of the new tasks + final_tasks = self.tasks[:prefix_length] + new_tasks[prefix_length:] + self.tasks = final_tasks + + # Update current_task_id to the first unfinished task in the merged list + self._update_current_task() + + # Update the task map for quick access to tasks by ID + self.task_map = {task.task_id: task for task in self.tasks} + + def reset_task(self, task_id: str): + """ + Clear code and result of the task based on task_id, and set the task as unfinished. + + Args: + task_id (str): The ID of the task to be reset. + + Returns: + None + """ + if task_id in self.task_map: + task = self.task_map[task_id] + task.reset() + + def replace_task(self, new_task: Task): + """ + Replace an existing task with the new input task based on task_id, and reset all tasks depending on it. + + Args: + new_task (Task): The new task that will replace an existing one. 
+ + Returns: + None + """ + assert new_task.task_id in self.task_map + # Replace the task in the task map and the task list + self.task_map[new_task.task_id] = new_task + for i, task in enumerate(self.tasks): + if task.task_id == new_task.task_id: + self.tasks[i] = new_task + break + + # Reset dependent tasks + for task in self.tasks: + if new_task.task_id in task.dependent_task_ids: + self.reset_task(task.task_id) + + def append_task(self, new_task: Task): + """ + Append a new task to the end of existing task sequences + + Args: + new_task (Task): The new task to be appended to the existing task sequence + + Returns: + None + """ + assert not self.has_task_id(new_task.task_id), "Task already in current plan, use replace_task instead" + + assert all( + [self.has_task_id(dep_id) for dep_id in new_task.dependent_task_ids] + ), "New task has unknown dependencies" + + # Existing tasks do not depend on the new task, it's fine to put it to the end of the sorted task sequence + self.tasks.append(new_task) + self.task_map[new_task.task_id] = new_task + self._update_current_task() + + def has_task_id(self, task_id: str) -> bool: + return task_id in self.task_map + + def _update_current_task(self): + current_task_id = "" + for task in self.tasks: + if not task.is_finished: + current_task_id = task.task_id + break + self.current_task_id = current_task_id # all tasks finished + + @property + def current_task(self) -> Task: + """Find current task to execute + + Returns: + Task: the current task to be executed + """ + return self.task_map.get(self.current_task_id, None) + + def finish_current_task(self): + """Finish current task, set Task.is_finished=True, set current task to next task""" + if self.current_task_id: + self.current_task.is_finished = True + self._update_current_task() # set to next task + + def get_finished_tasks(self) -> list[Task]: + """return all finished tasks in correct linearized order + + Returns: + list[Task]: list of finished tasks + """ + return [task 
class MessageQueue(BaseModel):
    """Message queue which supports asynchronous updates."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    _queue: Queue = PrivateAttr(default_factory=Queue)

    def pop(self) -> Message | None:
        """Pop one message from the queue.

        :return: The next queued item, or ``None`` when the queue is empty.
        """
        try:
            item = self._queue.get_nowait()
        except QueueEmpty:
            return None
        # Every successful get is acknowledged, whatever the item's truthiness,
        # so the queue's unfinished-task counter stays balanced.
        self._queue.task_done()
        return item

    def pop_all(self) -> List[Message]:
        """Pop messages until the queue is empty or a falsy item is returned."""
        ret = []
        while True:
            msg = self.pop()
            if not msg:
                break
            ret.append(msg)
        return ret

    def push(self, msg: Message):
        """Push a message into the queue."""
        self._queue.put_nowait(msg)

    def empty(self):
        """Return true if the queue is empty."""
        return self._queue.empty()

    async def dump(self) -> str:
        """Convert the `MessageQueue` object to a json string.

        The queue is drained with ``get_nowait`` (so there is no wait on an
        empty queue), each item is serialized with its own ``dump()``, and the
        drained items are put back afterwards, even if serialization fails.

        :return: A JSON array of the per-message JSON strings; "[]" when empty.
        """
        if self.empty():
            return "[]"

        lst = []
        msgs = []
        try:
            while True:
                try:
                    item = self._queue.get_nowait()
                except QueueEmpty:
                    logger.debug("Queue is empty, exiting...")
                    break
                if item is None:
                    # A None item ends the dump and is not put back.
                    break
                msgs.append(item)
                lst.append(item.dump())
                self._queue.task_done()
        finally:
            # Restore the drained messages so dumping does not consume them.
            for m in msgs:
                self._queue.put_nowait(m)
        return json.dumps(lst, ensure_ascii=False)

    @staticmethod
    def load(data) -> "MessageQueue":
        """Convert the json string to the `MessageQueue` object.

        Entries that ``Message.load`` cannot parse (it returns ``None``) are
        skipped instead of being queued as ``None``.

        :param data: JSON array of per-message JSON strings, as produced by ``dump``.
        :return: A new queue; empty when ``data`` is not valid JSON.
        """
        queue = MessageQueue()
        try:
            lst = json.loads(data)
            for i in lst:
                msg = Message.load(i)
                if msg is not None:
                    queue.push(msg)
        except JSONDecodeError as e:
            logger.warning(f"JSON load failed: {data}, error:{e}")

        return queue
class CodeSummarizeContext(BaseModel):
    """Filenames and reason that make up the context of a code-summarize step."""

    design_filename: str = ""
    task_filename: str = ""
    codes_filenames: List[str] = Field(default_factory=list)
    reason: str = ""

    @staticmethod
    def loads(filenames: List) -> CodeSummarizeContext:
        """Build a context by classifying each filename by the repo directory it lives under.

        A filename under ``SYSTEM_DESIGN_FILE_REPO`` becomes ``design_filename``;
        otherwise one under ``TASK_FILE_REPO`` becomes ``task_filename``. Later
        matches overwrite earlier ones; other filenames are ignored.
        """
        ctx = CodeSummarizeContext()
        for filename in filenames:
            path = Path(filename)
            if path.is_relative_to(SYSTEM_DESIGN_FILE_REPO):
                ctx.design_filename = str(filename)
            elif path.is_relative_to(TASK_FILE_REPO):
                ctx.task_filename = str(filename)
        return ctx

    def __hash__(self):
        # Hash only the design and task filenames.
        key = (self.design_filename, self.task_filename)
        return hash(key)
continue + if filename.is_relative_to(TASK_FILE_REPO): + ctx.task_filename = filename.name + continue + return ctx + + +# mermaid class view +class UMLClassMeta(BaseModel): + name: str = "" + visibility: str = "" + + @staticmethod + def name_to_visibility(name: str) -> str: + if name == "__init__": + return "+" + if name.startswith("__"): + return "-" + elif name.startswith("_"): + return "#" + return "+" + + +class UMLClassAttribute(UMLClassMeta): + value_type: str = "" + default_value: str = "" + + def get_mermaid(self, align=1) -> str: + content = "".join(["\t" for i in range(align)]) + self.visibility + if self.value_type: + content += self.value_type.replace(" ", "") + " " + name = self.name.split(":", 1)[1] if ":" in self.name else self.name + content += name + if self.default_value: + content += "=" + if self.value_type not in ["str", "string", "String"]: + content += self.default_value + else: + content += '"' + self.default_value.replace('"', "") + '"' + # if self.abstraction: + # content += "*" + # if self.static: + # content += "$" + return content + + +class UMLClassMethod(UMLClassMeta): + args: List[UMLClassAttribute] = Field(default_factory=list) + return_type: str = "" + + def get_mermaid(self, align=1) -> str: + content = "".join(["\t" for i in range(align)]) + self.visibility + name = self.name.split(":", 1)[1] if ":" in self.name else self.name + content += name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")" + if self.return_type: + content += " " + self.return_type.replace(" ", "") + # if self.abstraction: + # content += "*" + # if self.static: + # content += "$" + return content + + +class UMLClassView(UMLClassMeta): + attributes: List[UMLClassAttribute] = Field(default_factory=list) + methods: List[UMLClassMethod] = Field(default_factory=list) + + def get_mermaid(self, align=1) -> str: + content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n" + for v in self.attributes: + content += 
v.get_mermaid(align=align + 1) + "\n" + for v in self.methods: + content += v.get_mermaid(align=align + 1) + "\n" + content += "".join(["\t" for i in range(align)]) + "}\n" + return content + + @classmethod + def load_dot_class_info(cls, dot_class_info: DotClassInfo) -> UMLClassView: + visibility = UMLClassView.name_to_visibility(dot_class_info.name) + class_view = cls(name=dot_class_info.name, visibility=visibility) + for i in dot_class_info.attributes.values(): + visibility = UMLClassAttribute.name_to_visibility(i.name) + attr = UMLClassAttribute(name=i.name, visibility=visibility, value_type=i.type_, default_value=i.default_) + class_view.attributes.append(attr) + for i in dot_class_info.methods.values(): + visibility = UMLClassMethod.name_to_visibility(i.name) + method = UMLClassMethod(name=i.name, visibility=visibility, return_type=i.return_args.type_) + for j in i.args: + arg = UMLClassAttribute(name=j.name, value_type=j.type_, default_value=j.default_) + method.args.append(arg) + method.return_type = i.return_args.type_ + class_view.methods.append(method) + return class_view