From 26112dd1116cd59fee714bfd961e3ff6ea9cc019 Mon Sep 17 00:00:00 2001 From: YangQianli92 <108046369+YangQianli92@users.noreply.github.com> Date: Thu, 18 Apr 2024 11:42:43 +0800 Subject: [PATCH] Add files via upload --- metagpt/schema.py | 958 +++++++++------------------------------------- 1 file changed, 190 insertions(+), 768 deletions(-) diff --git a/metagpt/schema.py b/metagpt/schema.py index 071518d62..e7b2e5ce9 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -1,787 +1,209 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/8 22:12 -@Author : alexanderwu -@File : schema.py -@Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116: - Replanned the distribution of responsibilities and functional positioning of `Message` class attributes. -@Modified By: mashenquan, 2023/11/22. - 1. Add `Document` and `Documents` for `FileRepository` in Section 2.2.3.4 of RFC 135. - 2. Encapsulate the common key-values set to pydantic structures to standardize and unify parameter passing - between actions. - 3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135. -""" +"""RAG schemas.""" -from __future__ import annotations - -import asyncio -import json -import os.path -import uuid -from abc import ABC -from asyncio import Queue, QueueEmpty, wait_for -from json import JSONDecodeError from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Literal, Optional, Union -from pydantic import ( - BaseModel, - ConfigDict, - Field, - PrivateAttr, - field_serializer, - field_validator, - model_serializer, - model_validator, -) +from chromadb.api.types import CollectionMetadata +from llama_index.core.embeddings import BaseEmbedding +from llama_index.core.indices.base import BaseIndex +from llama_index.core.schema import TextNode +from llama_index.core.vector_stores.types import VectorStoreQueryMode +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator -from metagpt.const import ( - MESSAGE_ROUTE_CAUSE_BY, - MESSAGE_ROUTE_FROM, - MESSAGE_ROUTE_TO, - MESSAGE_ROUTE_TO_ALL, - PRDS_FILE_REPO, - SYSTEM_DESIGN_FILE_REPO, - TASK_FILE_REPO, -) -from metagpt.logs import logger -from metagpt.repo_parser import DotClassInfo -from metagpt.utils.common import any_to_str, any_to_str_set, import_class -from metagpt.utils.exceptions import handle_exception -from metagpt.utils.serialize import ( - actionoutout_schema_to_mapping, - actionoutput_mapping_to_str, - actionoutput_str_to_mapping, -) +from metagpt.config2 import config +from metagpt.configs.embedding_config import EmbeddingType +from metagpt.rag.interface import RAGObject -class SerializationMixin(BaseModel, extra="forbid"): +class BaseRetrieverConfig(BaseModel): + """Common config for retrievers. + + If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.retriever. """ - PolyMorphic subclasses Serialization / Deserialization Mixin - - First of all, we need to know that pydantic is not designed for polymorphism. - - If Engineer is subclass of Role, it would be serialized as Role. If we want to serialize it as Engineer, we need - to add `class name` to Engineer. So we need Engineer inherit SerializationMixin. - - More details: - - https://docs.pydantic.dev/latest/concepts/serialization/ - - https://github.com/pydantic/pydantic/discussions/7008 discuss about avoid `__get_pydantic_core_schema__` - """ - - __is_polymorphic_base = False - __subclasses_map__ = {} - - @model_serializer(mode="wrap") - def __serialize_with_class_type__(self, default_serializer) -> Any: - # default serializer, then append the `__module_class_name` field and return - ret = default_serializer(self) - ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}" - return ret - - @model_validator(mode="wrap") - @classmethod - def __convert_to_real_type__(cls, value: Any, handler): - if isinstance(value, dict) is False: - return handler(value) - - # it is a dict so make sure to remove the __module_class_name - # because we don't allow extra keywords but want to ensure - # e.g Cat.model_validate(cat.model_dump()) works - class_full_name = value.pop("__module_class_name", None) - - # if it's not the polymorphic base we construct via default handler - if not cls.__is_polymorphic_base: - if class_full_name is None: - return handler(value) - elif str(cls) == f"": - return handler(value) - else: - # f"Trying to instantiate {class_full_name} but this is not the polymorphic base class") - pass - - # otherwise we lookup the correct polymorphic type and construct that - # instead - if class_full_name is None: - raise ValueError("Missing __module_class_name field") - - class_type = cls.__subclasses_map__.get(class_full_name, None) - - if class_type is None: - # TODO could try dynamic import - raise TypeError("Trying to instantiate {class_full_name}, which has not yet been defined!") - - return class_type(**value) - - def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs): - cls.__is_polymorphic_base = is_polymorphic_base - cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls - super().__init_subclass__(**kwargs) - - -class SimpleMessage(BaseModel): - content: str - role: str - - -class Document(BaseModel): - """ - Represents a document. - """ - - root_path: str = "" - filename: str = "" - content: str = "" - - def get_meta(self) -> Document: - """Get metadata of the document. - - :return: A new Document instance with the same root path and filename. - """ - - return Document(root_path=self.root_path, filename=self.filename) - - @property - def root_relative_path(self): - """Get relative path from root of git repository. - - :return: relative path from root of git repository. - """ - return os.path.join(self.root_path, self.filename) - - def __str__(self): - return self.content - - def __repr__(self): - return self.content - - -class Documents(BaseModel): - """A class representing a collection of documents. - - Attributes: - docs (Dict[str, Document]): A dictionary mapping document names to Document instances. - """ - - docs: Dict[str, Document] = Field(default_factory=dict) - - @classmethod - def from_iterable(cls, documents: Iterable[Document]) -> Documents: - """Create a Documents instance from a list of Document instances. - - :param documents: A list of Document instances. - :return: A Documents instance. - """ - - docs = {doc.filename: doc for doc in documents} - return Documents(docs=docs) - - def to_action_output(self) -> "ActionOutput": - """Convert to action output string. - - :return: A string representing action output. - """ - from metagpt.actions.action_output import ActionOutput - - return ActionOutput(content=self.model_dump_json(), instruct_content=self) - - -class Message(BaseModel): - """list[: ]""" - - id: str = Field(default="", validate_default=True) # According to Section 2.2.3.1.1 of RFC 135 - content: str - instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True) - role: str = "user" # system / user / assistant - cause_by: str = Field(default="", validate_default=True) - sent_from: str = Field(default="", validate_default=True) - send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True) - - @field_validator("id", mode="before") - @classmethod - def check_id(cls, id: str) -> str: - return id if id else uuid.uuid4().hex - - @field_validator("instruct_content", mode="before") - @classmethod - def check_instruct_content(cls, ic: Any) -> BaseModel: - if ic and isinstance(ic, dict) and "class" in ic: - if "mapping" in ic: - # compatible with custom-defined ActionOutput - mapping = actionoutput_str_to_mapping(ic["mapping"]) - actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import - ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping) - elif "module" in ic: - # subclasses of BaseModel - ic_obj = import_class(ic["class"], ic["module"]) - else: - raise KeyError("missing required key to init Message.instruct_content from dict") - ic = ic_obj(**ic["value"]) - return ic - - @field_validator("cause_by", mode="before") - @classmethod - def check_cause_by(cls, cause_by: Any) -> str: - return any_to_str(cause_by if cause_by else import_class("UserRequirement", "metagpt.actions.add_requirement")) - - @field_validator("sent_from", mode="before") - @classmethod - def check_sent_from(cls, sent_from: Any) -> str: - return any_to_str(sent_from if sent_from else "") - - @field_validator("send_to", mode="before") - @classmethod - def check_send_to(cls, send_to: Any) -> set: - return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL}) - - @field_serializer("send_to", mode="plain") - def ser_send_to(self, send_to: set) -> list: - return list(send_to) - - @field_serializer("instruct_content", mode="plain") - def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]: - ic_dict = None - if ic: - # compatible with custom-defined ActionOutput - schema = ic.model_json_schema() - ic_type = str(type(ic)) - if " str: - """For search""" - return self.content - - def to_dict(self) -> dict: - """Return a dict containing `role` and `content` for the LLM call.l""" - return {"role": self.role, "content": self.content} - - def dump(self) -> str: - """Convert the object to json string""" - return self.model_dump_json(exclude_none=True, warnings=False) - - @staticmethod - @handle_exception(exception_type=JSONDecodeError, default_return=None) - def load(val): - """Convert the json string to object.""" - - try: - m = json.loads(val) - id = m.get("id") - if "id" in m: - del m["id"] - msg = Message(**m) - if id: - msg.id = id - return msg - except JSONDecodeError as err: - logger.error(f"parse json failed: {val}, error:{err}") - return None - - -class UserMessage(Message): - """便于支持OpenAI的消息 - Facilitate support for OpenAI messages - """ - - def __init__(self, content: str): - super().__init__(content=content, role="user") - - -class SystemMessage(Message): - """便于支持OpenAI的消息 - Facilitate support for OpenAI messages - """ - - def __init__(self, content: str): - super().__init__(content=content, role="system") - - -class AIMessage(Message): - """便于支持OpenAI的消息 - Facilitate support for OpenAI messages - """ - - def __init__(self, content: str): - super().__init__(content=content, role="assistant") - - -class Task(BaseModel): - task_id: str = "" - dependent_task_ids: list[str] = [] # Tasks prerequisite to this Task - instruction: str = "" - task_type: str = "" - code: str = "" - result: str = "" - is_success: bool = False - is_finished: bool = False - - def reset(self): - self.code = "" - self.result = "" - self.is_success = False - self.is_finished = False - - def update_task_result(self, task_result: TaskResult): - self.code = task_result.code - self.result = task_result.result - self.is_success = task_result.is_success - - -class TaskResult(BaseModel): - """Result of taking a task, with result and is_success required to be filled""" - - code: str = "" - result: str - is_success: bool - - -class Plan(BaseModel): - goal: str - context: str = "" - tasks: list[Task] = [] - task_map: dict[str, Task] = {} - current_task_id: str = "" - - def _topological_sort(self, tasks: list[Task]): - task_map = {task.task_id: task for task in tasks} - dependencies = {task.task_id: set(task.dependent_task_ids) for task in tasks} - sorted_tasks = [] - visited = set() - - def visit(task_id): - if task_id in visited: - return - visited.add(task_id) - for dependent_id in dependencies.get(task_id, []): - visit(dependent_id) - sorted_tasks.append(task_map[task_id]) - - for task in tasks: - visit(task.task_id) - - return sorted_tasks - - def add_tasks(self, tasks: list[Task]): - """ - Integrates new tasks into the existing plan, ensuring dependency order is maintained. - - This method performs two primary functions based on the current state of the task list: - 1. If there are no existing tasks, it topologically sorts the provided tasks to ensure - correct execution order based on dependencies, and sets these as the current tasks. - 2. If there are existing tasks, it merges the new tasks with the existing ones. It maintains - any common prefix of tasks (based on task_id and instruction) and appends the remainder - of the new tasks. The current task is updated to the first unfinished task in this merged list. - - Args: - tasks (list[Task]): A list of tasks (may be unordered) to add to the plan. - - Returns: - None: The method updates the internal state of the plan but does not return anything. - """ - if not tasks: - return - - # Topologically sort the new tasks to ensure correct dependency order - new_tasks = self._topological_sort(tasks) - - if not self.tasks: - # If there are no existing tasks, set the new tasks as the current tasks - self.tasks = new_tasks - - else: - # Find the length of the common prefix between existing and new tasks - prefix_length = 0 - for old_task, new_task in zip(self.tasks, new_tasks): - if old_task.task_id != new_task.task_id or old_task.instruction != new_task.instruction: - break - prefix_length += 1 - - # Combine the common prefix with the remainder of the new tasks - final_tasks = self.tasks[:prefix_length] + new_tasks[prefix_length:] - self.tasks = final_tasks - - # Update current_task_id to the first unfinished task in the merged list - self._update_current_task() - - # Update the task map for quick access to tasks by ID - self.task_map = {task.task_id: task for task in self.tasks} - - def reset_task(self, task_id: str): - """ - Clear code and result of the task based on task_id, and set the task as unfinished. - - Args: - task_id (str): The ID of the task to be reset. - - Returns: - None - """ - if task_id in self.task_map: - task = self.task_map[task_id] - task.reset() - - def replace_task(self, new_task: Task): - """ - Replace an existing task with the new input task based on task_id, and reset all tasks depending on it. - - Args: - new_task (Task): The new task that will replace an existing one. - - Returns: - None - """ - assert new_task.task_id in self.task_map - # Replace the task in the task map and the task list - self.task_map[new_task.task_id] = new_task - for i, task in enumerate(self.tasks): - if task.task_id == new_task.task_id: - self.tasks[i] = new_task - break - - # Reset dependent tasks - for task in self.tasks: - if new_task.task_id in task.dependent_task_ids: - self.reset_task(task.task_id) - - def append_task(self, new_task: Task): - """ - Append a new task to the end of existing task sequences - - Args: - new_task (Task): The new task to be appended to the existing task sequence - - Returns: - None - """ - assert not self.has_task_id(new_task.task_id), "Task already in current plan, use replace_task instead" - - assert all( - [self.has_task_id(dep_id) for dep_id in new_task.dependent_task_ids] - ), "New task has unknown dependencies" - - # Existing tasks do not depend on the new task, it's fine to put it to the end of the sorted task sequence - self.tasks.append(new_task) - self.task_map[new_task.task_id] = new_task - self._update_current_task() - - def has_task_id(self, task_id: str) -> bool: - return task_id in self.task_map - - def _update_current_task(self): - current_task_id = "" - for task in self.tasks: - if not task.is_finished: - current_task_id = task.task_id - break - self.current_task_id = current_task_id # all tasks finished - - @property - def current_task(self) -> Task: - """Find current task to execute - - Returns: - Task: the current task to be executed - """ - return self.task_map.get(self.current_task_id, None) - - def finish_current_task(self): - """Finish current task, set Task.is_finished=True, set current task to next task""" - if self.current_task_id: - self.current_task.is_finished = True - self._update_current_task() # set to next task - - def get_finished_tasks(self) -> list[Task]: - """return all finished tasks in correct linearized order - - Returns: - list[Task]: list of finished tasks - """ - return [task for task in self.tasks if task.is_finished] - - -class MessageQueue(BaseModel): - """Message queue which supports asynchronous updates.""" model_config = ConfigDict(arbitrary_types_allowed=True) + similarity_top_k: int = Field(default=5, description="Number of top-k similar results to return during retrieval.") - _queue: Queue = PrivateAttr(default_factory=Queue) - def pop(self) -> Message | None: - """Pop one message from the queue.""" - try: - item = self._queue.get_nowait() - if item: - self._queue.task_done() - return item - except QueueEmpty: - return None +class IndexRetrieverConfig(BaseRetrieverConfig): + """Config for Index-basd retrievers.""" - def pop_all(self) -> List[Message]: - """Pop all messages from the queue.""" - ret = [] - while True: - msg = self.pop() - if not msg: - break - ret.append(msg) - return ret + index: BaseIndex = Field(default=None, description="Index for retriver.") - def push(self, msg: Message): - """Push a message into the queue.""" - self._queue.put_nowait(msg) - def empty(self): - """Return true if the queue is empty.""" - return self._queue.empty() +class FAISSRetrieverConfig(IndexRetrieverConfig): + """Config for FAISS-based retrievers.""" - async def dump(self) -> str: - """Convert the `MessageQueue` object to a json string.""" - if self.empty(): - return "[]" + dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.") - lst = [] - msgs = [] - try: - while True: - item = await wait_for(self._queue.get(), timeout=1.0) - if item is None: - break - msgs.append(item) - lst.append(item.dump()) - self._queue.task_done() - except asyncio.TimeoutError: - logger.debug("Queue is empty, exiting...") - finally: - for m in msgs: - self._queue.put_nowait(m) - return json.dumps(lst, ensure_ascii=False) + _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = { + EmbeddingType.GEMINI: 768, + EmbeddingType.OLLAMA: 4096, + } + + @model_validator(mode="after") + def check_dimensions(self): + if self.dimensions == 0: + self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536) + + return self + + +class BM25RetrieverConfig(IndexRetrieverConfig): + """Config for BM25-based retrievers.""" + + _no_embedding: bool = PrivateAttr(default=True) + + +class ChromaRetrieverConfig(IndexRetrieverConfig): + """Config for Chroma-based retrievers.""" + + persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.") + collection_name: str = Field(default="metagpt", description="The name of the collection.") + metadata: Optional[CollectionMetadata] = Field( + default=None, description="Optional metadata to associate with the collection" + ) + + +class ElasticsearchStoreConfig(BaseModel): + index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.") + es_url: str = Field(default=None, description="Elasticsearch URL.") + es_cloud_id: str = Field(default=None, description="Elasticsearch cloud ID.") + es_api_key: str = Field(default=None, description="Elasticsearch API key.") + es_user: str = Field(default=None, description="Elasticsearch username.") + es_password: str = Field(default=None, description="Elasticsearch password.") + batch_size: int = Field(default=200, description="Batch size for bulk indexing.") + distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.") + + +class ElasticsearchRetrieverConfig(IndexRetrieverConfig): + """Config for Elasticsearch-based retrievers. Support both vector and text.""" + + store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.") + vector_store_query_mode: VectorStoreQueryMode = Field( + default=VectorStoreQueryMode.DEFAULT, description="default is vector query." + ) + + +class ElasticsearchKeywordRetrieverConfig(ElasticsearchRetrieverConfig): + """Config for Elasticsearch-based retrievers. Support text only.""" + + _no_embedding: bool = PrivateAttr(default=True) + vector_store_query_mode: Literal[VectorStoreQueryMode.TEXT_SEARCH] = Field( + default=VectorStoreQueryMode.TEXT_SEARCH, description="text query only." + ) + + +class BaseRankerConfig(BaseModel): + """Common config for rankers. + + If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.ranker. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + top_n: int = Field(default=5, description="The number of top results to return.") + + +class LLMRankerConfig(BaseRankerConfig): + """Config for LLM-based rankers.""" + + llm: Any = Field( + default=None, + description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1.", + ) + + +class ColbertRerankConfig(BaseRankerConfig): + model: str = Field(default="colbert-ir/colbertv2.0", description="Colbert model name.") + device: str = Field(default="cpu", description="Device to use for sentence transformer.") + keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.") + + +class CohereRerankConfig(BaseRankerConfig): + model: str = Field(default="rerank-english-v3.0") + api_key: str = Field(default="YOUR_COHERE_API") + + +class BGERerankConfig(BaseRankerConfig): + model: str = Field(default="BAAI/bge-reranker-large", description="BAAI Reranker model name.") + use_fp16: bool = Field(default=True, description="Whether to use fp16 for inference.") + + +class ObjectRankerConfig(BaseRankerConfig): + field_name: str = Field(..., description="field name of the object, field's value must can be compared.") + order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.") + + +class BaseIndexConfig(BaseModel): + """Common config for index. + + If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + persist_path: Union[str, Path] = Field(description="The directory of saved data.") + + +class VectorIndexConfig(BaseIndexConfig): + """Config for vector-based index.""" + + embed_model: BaseEmbedding = Field(default=None, description="Embed model.") + + +class FAISSIndexConfig(VectorIndexConfig): + """Config for faiss-based index.""" + + +class ChromaIndexConfig(VectorIndexConfig): + """Config for chroma-based index.""" + + collection_name: str = Field(default="metagpt", description="The name of the collection.") + metadata: Optional[CollectionMetadata] = Field( + default=None, description="Optional metadata to associate with the collection" + ) + + +class BM25IndexConfig(BaseIndexConfig): + """Config for bm25-based index.""" + + _no_embedding: bool = PrivateAttr(default=True) + + +class ElasticsearchIndexConfig(VectorIndexConfig): + """Config for es-based index.""" + + store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.") + persist_path: Union[str, Path] = "" + + +class ElasticsearchKeywordIndexConfig(ElasticsearchIndexConfig): + """Config for es-based index. no embedding.""" + + _no_embedding: bool = PrivateAttr(default=True) + + +class ObjectNodeMetadata(BaseModel): + """Metadata of ObjectNode.""" + + is_obj: bool = Field(default=True) + obj: Any = Field(default=None, description="When rag retrieve, will reconstruct obj from obj_json") + obj_json: str = Field(..., description="The json of object, e.g. obj.model_dump_json()") + obj_cls_name: str = Field(..., description="The class name of object, e.g. obj.__class__.__name__") + obj_mod_name: str = Field(..., description="The module name of class, e.g. obj.__class__.__module__") + + +class ObjectNode(TextNode): + """RAG add object.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys()) + self.excluded_embed_metadata_keys = self.excluded_llm_metadata_keys @staticmethod - def load(data) -> "MessageQueue": - """Convert the json string to the `MessageQueue` object.""" - queue = MessageQueue() - try: - lst = json.loads(data) - for i in lst: - msg = Message.load(i) - queue.push(msg) - except JSONDecodeError as e: - logger.warning(f"JSON load failed: {data}, error:{e}") + def get_obj_metadata(obj: RAGObject) -> dict: + metadata = ObjectNodeMetadata( + obj_json=obj.model_dump_json(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__ + ) - return queue - - -# 定义一个泛型类型变量 -T = TypeVar("T", bound="BaseModel") - - -class BaseContext(BaseModel, ABC): - @classmethod - @handle_exception - def loads(cls: Type[T], val: str) -> Optional[T]: - i = json.loads(val) - return cls(**i) - - -class CodingContext(BaseContext): - filename: str - design_doc: Optional[Document] = None - task_doc: Optional[Document] = None - code_doc: Optional[Document] = None - code_plan_and_change_doc: Optional[Document] = None - - -class TestingContext(BaseContext): - filename: str - code_doc: Document - test_doc: Optional[Document] = None - - -class RunCodeContext(BaseContext): - mode: str = "script" - code: Optional[str] = None - code_filename: str = "" - test_code: Optional[str] = None - test_filename: str = "" - command: List[str] = Field(default_factory=list) - working_directory: str = "" - additional_python_paths: List[str] = Field(default_factory=list) - output_filename: Optional[str] = None - output: Optional[str] = None - - -class RunCodeResult(BaseContext): - summary: str - stdout: str - stderr: str - - -class CodeSummarizeContext(BaseModel): - design_filename: str = "" - task_filename: str = "" - codes_filenames: List[str] = Field(default_factory=list) - reason: str = "" - - @staticmethod - def loads(filenames: List) -> CodeSummarizeContext: - ctx = CodeSummarizeContext() - for filename in filenames: - if Path(filename).is_relative_to(SYSTEM_DESIGN_FILE_REPO): - ctx.design_filename = str(filename) - continue - if Path(filename).is_relative_to(TASK_FILE_REPO): - ctx.task_filename = str(filename) - continue - return ctx - - def __hash__(self): - return hash((self.design_filename, self.task_filename)) - - -class BugFixContext(BaseContext): - filename: str = "" - - -class CodePlanAndChangeContext(BaseModel): - requirement: str = "" - issue: str = "" - prd_filename: str = "" - design_filename: str = "" - task_filename: str = "" - - @staticmethod - def loads(filenames: List, **kwargs) -> CodePlanAndChangeContext: - ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", ""), issue=kwargs.get("issue", "")) - for filename in filenames: - filename = Path(filename) - if filename.is_relative_to(PRDS_FILE_REPO): - ctx.prd_filename = filename.name - continue - if filename.is_relative_to(SYSTEM_DESIGN_FILE_REPO): - ctx.design_filename = filename.name - continue - if filename.is_relative_to(TASK_FILE_REPO): - ctx.task_filename = filename.name - continue - return ctx - - -# mermaid class view -class UMLClassMeta(BaseModel): - name: str = "" - visibility: str = "" - - @staticmethod - def name_to_visibility(name: str) -> str: - if name == "__init__": - return "+" - if name.startswith("__"): - return "-" - elif name.startswith("_"): - return "#" - return "+" - - -class UMLClassAttribute(UMLClassMeta): - value_type: str = "" - default_value: str = "" - - def get_mermaid(self, align=1) -> str: - content = "".join(["\t" for i in range(align)]) + self.visibility - if self.value_type: - content += self.value_type.replace(" ", "") + " " - name = self.name.split(":", 1)[1] if ":" in self.name else self.name - content += name - if self.default_value: - content += "=" - if self.value_type not in ["str", "string", "String"]: - content += self.default_value - else: - content += '"' + self.default_value.replace('"', "") + '"' - # if self.abstraction: - # content += "*" - # if self.static: - # content += "$" - return content - - -class UMLClassMethod(UMLClassMeta): - args: List[UMLClassAttribute] = Field(default_factory=list) - return_type: str = "" - - def get_mermaid(self, align=1) -> str: - content = "".join(["\t" for i in range(align)]) + self.visibility - name = self.name.split(":", 1)[1] if ":" in self.name else self.name - content += name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")" - if self.return_type: - content += " " + self.return_type.replace(" ", "") - # if self.abstraction: - # content += "*" - # if self.static: - # content += "$" - return content - - -class UMLClassView(UMLClassMeta): - attributes: List[UMLClassAttribute] = Field(default_factory=list) - methods: List[UMLClassMethod] = Field(default_factory=list) - - def get_mermaid(self, align=1) -> str: - content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n" - for v in self.attributes: - content += v.get_mermaid(align=align + 1) + "\n" - for v in self.methods: - content += v.get_mermaid(align=align + 1) + "\n" - content += "".join(["\t" for i in range(align)]) + "}\n" - return content - - @classmethod - def load_dot_class_info(cls, dot_class_info: DotClassInfo) -> UMLClassView: - visibility = UMLClassView.name_to_visibility(dot_class_info.name) - class_view = cls(name=dot_class_info.name, visibility=visibility) - for i in dot_class_info.attributes.values(): - visibility = UMLClassAttribute.name_to_visibility(i.name) - attr = UMLClassAttribute(name=i.name, visibility=visibility, value_type=i.type_, default_value=i.default_) - class_view.attributes.append(attr) - for i in dot_class_info.methods.values(): - visibility = UMLClassMethod.name_to_visibility(i.name) - method = UMLClassMethod(name=i.name, visibility=visibility, return_type=i.return_args.type_) - for j in i.args: - arg = UMLClassAttribute(name=j.name, value_type=j.type_, default_value=j.default_) - method.args.append(arg) - method.return_type = i.return_args.type_ - class_view.methods.append(method) - return class_view + return metadata.model_dump()