mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-04 21:32:38 +02:00
Add files via upload
This commit is contained in:
parent
26112dd111
commit
d5d45114a4
1 changed files with 746 additions and 168 deletions
|
|
@ -1,209 +1,787 @@
|
|||
"""RAG schemas."""
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/8 22:12
|
||||
@Author : alexanderwu
|
||||
@File : schema.py
|
||||
@Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116:
|
||||
Replanned the distribution of responsibilities and functional positioning of `Message` class attributes.
|
||||
@Modified By: mashenquan, 2023/11/22.
|
||||
1. Add `Document` and `Documents` for `FileRepository` in Section 2.2.3.4 of RFC 135.
|
||||
2. Encapsulate the common key-values set to pydantic structures to standardize and unify parameter passing
|
||||
between actions.
|
||||
3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os.path
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from asyncio import Queue, QueueEmpty, wait_for
|
||||
from json import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Literal, Optional, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from chromadb.api.types import CollectionMetadata
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.core.vector_stores.types import VectorStoreQueryMode
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.rag.interface import RAGObject
|
||||
from metagpt.const import (
|
||||
MESSAGE_ROUTE_CAUSE_BY,
|
||||
MESSAGE_ROUTE_FROM,
|
||||
MESSAGE_ROUTE_TO,
|
||||
MESSAGE_ROUTE_TO_ALL,
|
||||
PRDS_FILE_REPO,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.repo_parser import DotClassInfo
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set, import_class
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.serialize import (
|
||||
actionoutout_schema_to_mapping,
|
||||
actionoutput_mapping_to_str,
|
||||
actionoutput_str_to_mapping,
|
||||
)
|
||||
|
||||
|
||||
class BaseRetrieverConfig(BaseModel):
    """Common config for retrievers.

    If add new subconfig, it is necessary to add the corresponding instance
    implementation in rag.factories.retriever.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    similarity_top_k: int = Field(default=5, description="Number of top-k similar results to return during retrieval.")


class SerializationMixin(BaseModel, extra="forbid"):
    """Polymorphic subclasses serialization / deserialization mixin.

    - First of all, we need to know that pydantic is not designed for polymorphism.
    - If Engineer is a subclass of Role, it would be serialized as Role. If we want to
      serialize it as Engineer, we need to add the class name to Engineer, so Engineer
      must inherit SerializationMixin.

    More details:
    - https://docs.pydantic.dev/latest/concepts/serialization/
    - https://github.com/pydantic/pydantic/discussions/7008 discuss about avoid `__get_pydantic_core_schema__`
    """

    __is_polymorphic_base = False
    __subclasses_map__ = {}

    @model_serializer(mode="wrap")
    def __serialize_with_class_type__(self, default_serializer) -> Any:
        # Run the default serializer, then append the `__module_class_name` field so
        # deserialization can reconstruct the concrete subclass.
        ret = default_serializer(self)
        ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        return ret

    @model_validator(mode="wrap")
    @classmethod
    def __convert_to_real_type__(cls, value: Any, handler):
        if isinstance(value, dict) is False:
            return handler(value)

        # It is a dict, so make sure to remove the __module_class_name,
        # because we don't allow extra keywords but want to ensure
        # e.g. Cat.model_validate(cat.model_dump()) works.
        class_full_name = value.pop("__module_class_name", None)

        # If it's not the polymorphic base, we construct via the default handler.
        if not cls.__is_polymorphic_base:
            if class_full_name is None:
                return handler(value)
            elif str(cls) == f"<class '{class_full_name}'>":
                return handler(value)
            else:
                # f"Trying to instantiate {class_full_name} but this is not the polymorphic base class")
                pass

        # Otherwise we look up the correct polymorphic type and construct that instead.
        if class_full_name is None:
            raise ValueError("Missing __module_class_name field")

        class_type = cls.__subclasses_map__.get(class_full_name, None)

        if class_type is None:
            # TODO could try dynamic import
            # BUGFIX: the original message lacked the f-prefix, so {class_full_name}
            # was emitted literally instead of being interpolated.
            raise TypeError(f"Trying to instantiate {class_full_name}, which has not yet been defined!")

        return class_type(**value)

    def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
        # Register every subclass by its fully qualified name so that
        # __convert_to_real_type__ can find it during deserialization.
        cls.__is_polymorphic_base = is_polymorphic_base
        cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
        super().__init_subclass__(**kwargs)


class IndexRetrieverConfig(BaseRetrieverConfig):
    """Config for index-based retrievers."""

    index: BaseIndex = Field(default=None, description="Index for retriever.")
|
||||
class SimpleMessage(BaseModel):
    """Minimal chat message: just a content string and a role string."""

    content: str
    role: str


class FAISSRetrieverConfig(IndexRetrieverConfig):
    """Config for FAISS-based retrievers."""

    dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.")

    # Known embedding dimensionalities per embedding backend; anything else falls
    # back to 1536 in check_dimensions (presumably the OpenAI embedding size — confirm).
    _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
        EmbeddingType.GEMINI: 768,
        EmbeddingType.OLLAMA: 4096,
    }

    @model_validator(mode="after")
    def check_dimensions(self):
        # 0 means "not specified": infer from the configured embedding api_type.
        if self.dimensions == 0:
            self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536)

        return self
|
||||
|
||||
|
||||
class BM25RetrieverConfig(IndexRetrieverConfig):
    """Config for BM25-based retrievers."""

    # BM25 is purely lexical; factories read this flag to skip embedding setup.
    _no_embedding: bool = PrivateAttr(default=True)


class ChromaRetrieverConfig(IndexRetrieverConfig):
    """Config for Chroma-based retrievers."""

    persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
    collection_name: str = Field(default="metagpt", description="The name of the collection.")
    metadata: Optional[CollectionMetadata] = Field(
        default=None, description="Optional metadata to associate with the collection"
    )
|
||||
|
||||
|
||||
class ElasticsearchStoreConfig(BaseModel):
    """Connection and indexing options passed through to an ElasticsearchStore."""

    index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.")
    # BUGFIX: these connection fields default to None, so the annotations must be
    # Optional[str] rather than plain str (widening is backward compatible).
    es_url: Optional[str] = Field(default=None, description="Elasticsearch URL.")
    es_cloud_id: Optional[str] = Field(default=None, description="Elasticsearch cloud ID.")
    es_api_key: Optional[str] = Field(default=None, description="Elasticsearch API key.")
    es_user: Optional[str] = Field(default=None, description="Elasticsearch username.")
    es_password: Optional[str] = Field(default=None, description="Elasticsearch password.")
    batch_size: int = Field(default=200, description="Batch size for bulk indexing.")
    distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.")
|
||||
|
||||
|
||||
class ElasticsearchRetrieverConfig(IndexRetrieverConfig):
    """Config for Elasticsearch-based retrievers. Support both vector and text."""

    store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
    vector_store_query_mode: VectorStoreQueryMode = Field(
        default=VectorStoreQueryMode.DEFAULT, description="default is vector query."
    )


class ElasticsearchKeywordRetrieverConfig(ElasticsearchRetrieverConfig):
    """Config for Elasticsearch-based retrievers. Support text only."""

    # Text-only search: no embedding model is needed.
    _no_embedding: bool = PrivateAttr(default=True)
    # Literal pins the query mode so it cannot be switched back to vector search.
    vector_store_query_mode: Literal[VectorStoreQueryMode.TEXT_SEARCH] = Field(
        default=VectorStoreQueryMode.TEXT_SEARCH, description="text query only."
    )
|
||||
|
||||
|
||||
class BaseRankerConfig(BaseModel):
    """Common config for rankers.

    If add new subconfig, it is necessary to add the corresponding instance
    implementation in rag.factories.ranker.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    top_n: int = Field(default=5, description="The number of top results to return.")


class Document(BaseModel):
    """Represents a document.

    A document is identified by (root_path, filename) relative to a repository
    root, and carries its text in `content`.
    """

    root_path: str = ""  # directory of the file, relative to the repo root
    filename: str = ""
    content: str = ""

    def get_meta(self) -> Document:
        """Get metadata of the document.

        :return: A new Document instance with the same root path and filename
                 (content deliberately omitted).
        """
        return Document(root_path=self.root_path, filename=self.filename)

    @property
    def root_relative_path(self):
        """Get relative path from root of git repository.

        :return: relative path from root of git repository.
        """
        return os.path.join(self.root_path, self.filename)

    def __str__(self):
        return self.content

    def __repr__(self):
        return self.content
|
||||
|
||||
|
||||
class LLMRankerConfig(BaseRankerConfig):
    """Config for LLM-based rankers."""

    llm: Any = Field(
        default=None,
        description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1.",
    )
|
||||
|
||||
|
||||
class ColbertRerankConfig(BaseRankerConfig):
    """Config for ColBERT cross-encoder reranking."""

    model: str = Field(default="colbert-ir/colbertv2.0", description="Colbert model name.")
    device: str = Field(default="cpu", description="Device to use for sentence transformer.")
    keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.")


class CohereRerankConfig(BaseRankerConfig):
    """Config for the Cohere rerank API."""

    model: str = Field(default="rerank-english-v3.0")
    api_key: str = Field(default="YOUR_COHERE_API")


class BGERerankConfig(BaseRankerConfig):
    """Config for BAAI BGE reranking."""

    model: str = Field(default="BAAI/bge-reranker-large", description="BAAI Reranker model name.")
    use_fp16: bool = Field(default=True, description="Whether to use fp16 for inference.")


class ObjectRankerConfig(BaseRankerConfig):
    """Config for ranking retrieved objects by one of their comparable fields."""

    field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
    order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")
|
||||
|
||||
|
||||
class BaseIndexConfig(BaseModel):
    """Common config for index.

    If add new subconfig, it is necessary to add the corresponding instance
    implementation in rag.factories.index.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    persist_path: Union[str, Path] = Field(description="The directory of saved data.")
|
||||
|
||||
class Documents(BaseModel):
    """A class representing a collection of documents.

    Attributes:
        docs (Dict[str, Document]): A dictionary mapping document names to Document instances.
    """

    docs: Dict[str, Document] = Field(default_factory=dict)

    @classmethod
    def from_iterable(cls, documents: Iterable[Document]) -> Documents:
        """Create a Documents instance from a list of Document instances.

        :param documents: A list of Document instances.
        :return: A Documents instance keyed by each document's filename.
        """
        docs = {doc.filename: doc for doc in documents}
        return Documents(docs=docs)

    def to_action_output(self) -> "ActionOutput":
        """Convert to action output string.

        :return: An ActionOutput whose content is this collection serialized to JSON.
        """
        # Local import to avoid a circular dependency with metagpt.actions.
        from metagpt.actions.action_output import ActionOutput

        return ActionOutput(content=self.model_dump_json(), instruct_content=self)
|
||||
|
||||
|
||||
class VectorIndexConfig(BaseIndexConfig):
    """Config for vector-based index."""

    embed_model: BaseEmbedding = Field(default=None, description="Embed model.")


class FAISSIndexConfig(VectorIndexConfig):
    """Config for faiss-based index."""


class ChromaIndexConfig(VectorIndexConfig):
    """Config for chroma-based index."""

    collection_name: str = Field(default="metagpt", description="The name of the collection.")
    metadata: Optional[CollectionMetadata] = Field(
        default=None, description="Optional metadata to associate with the collection"
    )


class BM25IndexConfig(BaseIndexConfig):
    """Config for bm25-based index."""

    # BM25 is lexical: no embedding model is required.
    _no_embedding: bool = PrivateAttr(default=True)


class ElasticsearchIndexConfig(VectorIndexConfig):
    """Config for es-based index."""

    store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
    # Elasticsearch persists remotely, so no local persist directory is needed.
    persist_path: Union[str, Path] = ""


class ElasticsearchKeywordIndexConfig(ElasticsearchIndexConfig):
    """Config for es-based index. no embedding."""

    _no_embedding: bool = PrivateAttr(default=True)


class Message(BaseModel):
    """list[<role>: <content>]"""

    id: str = Field(default="", validate_default=True)  # According to Section 2.2.3.1.1 of RFC 135
    content: str
    instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True)
    role: str = "user"  # system / user / assistant
    cause_by: str = Field(default="", validate_default=True)
    sent_from: str = Field(default="", validate_default=True)
    send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True)

    @field_validator("id", mode="before")
    @classmethod
    def check_id(cls, id: str) -> str:
        # An empty id means "not assigned yet": mint a fresh uuid.
        return id if id else uuid.uuid4().hex

    @field_validator("instruct_content", mode="before")
    @classmethod
    def check_instruct_content(cls, ic: Any) -> BaseModel:
        """Rebuild instruct_content from the dict form produced by ser_instruct_content."""
        if ic and isinstance(ic, dict) and "class" in ic:
            if "mapping" in ic:
                # compatible with custom-defined ActionOutput
                mapping = actionoutput_str_to_mapping(ic["mapping"])
                actionnode_class = import_class("ActionNode", "metagpt.actions.action_node")  # avoid circular import
                ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
            elif "module" in ic:
                # subclasses of BaseModel
                ic_obj = import_class(ic["class"], ic["module"])
            else:
                raise KeyError("missing required key to init Message.instruct_content from dict")
            ic = ic_obj(**ic["value"])
        return ic

    @field_validator("cause_by", mode="before")
    @classmethod
    def check_cause_by(cls, cause_by: Any) -> str:
        # Default cause is the UserRequirement action (imported lazily to avoid cycles).
        return any_to_str(cause_by if cause_by else import_class("UserRequirement", "metagpt.actions.add_requirement"))

    @field_validator("sent_from", mode="before")
    @classmethod
    def check_sent_from(cls, sent_from: Any) -> str:
        return any_to_str(sent_from if sent_from else "")

    @field_validator("send_to", mode="before")
    @classmethod
    def check_send_to(cls, send_to: Any) -> set:
        return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})

    @field_serializer("send_to", mode="plain")
    def ser_send_to(self, send_to: set) -> list:
        # Sets are not JSON-serializable; emit a list instead.
        return list(send_to)

    @field_serializer("instruct_content", mode="plain")
    def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]:
        """Serialize instruct_content to a dict that check_instruct_content can rebuild."""
        ic_dict = None
        if ic:
            # compatible with custom-defined ActionOutput
            schema = ic.model_json_schema()
            ic_type = str(type(ic))
            if "<class 'metagpt.actions.action_node" in ic_type:
                # instruct_content from ActionNode.create_model_class; for now, it's a single-level structure.
                mapping = actionoutout_schema_to_mapping(schema)
                mapping = actionoutput_mapping_to_str(mapping)

                ic_dict = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
            else:
                # instruct_content can also be assigned by arbitrary subclasses of BaseModel
                ic_dict = {"class": schema["title"], "module": ic.__module__, "value": ic.model_dump()}
        return ic_dict

    def __init__(self, content: str = "", **data: Any):
        # Allow both Message("text") and Message(content="text"); keyword wins.
        data["content"] = data.get("content", content)
        super().__init__(**data)

    def __setattr__(self, key, val):
        """Override `@property.setter`, convert non-string parameters into string parameters."""
        if key == MESSAGE_ROUTE_CAUSE_BY:
            new_val = any_to_str(val)
        elif key == MESSAGE_ROUTE_FROM:
            new_val = any_to_str(val)
        elif key == MESSAGE_ROUTE_TO:
            new_val = any_to_str_set(val)
        else:
            new_val = val
        super().__setattr__(key, new_val)

    def __str__(self):
        # prefix = '-'.join([self.role, str(self.cause_by)])
        if self.instruct_content:
            return f"{self.role}: {self.instruct_content.model_dump()}"
        return f"{self.role}: {self.content}"

    def __repr__(self):
        return self.__str__()

    def rag_key(self) -> str:
        """For search"""
        return self.content

    def to_dict(self) -> dict:
        """Return a dict containing `role` and `content` for the LLM call."""
        return {"role": self.role, "content": self.content}

    def dump(self) -> str:
        """Convert the object to json string"""
        return self.model_dump_json(exclude_none=True, warnings=False)

    @staticmethod
    @handle_exception(exception_type=JSONDecodeError, default_return=None)
    def load(val):
        """Convert the json string to a Message object; returns None on parse failure."""
        try:
            m = json.loads(val)
            id = m.get("id")
            if "id" in m:
                del m["id"]
            # id is validated separately so a stored id survives the default-uuid validator.
            msg = Message(**m)
            if id:
                msg.id = id
            return msg
        except JSONDecodeError as err:
            logger.error(f"parse json failed: {val}, error:{err}")
        return None
|
||||
|
||||
|
||||
class ObjectNodeMetadata(BaseModel):
    """Metadata of ObjectNode."""

    is_obj: bool = Field(default=True)
    obj: Any = Field(default=None, description="When rag retrieve, will reconstruct obj from obj_json")
    obj_json: str = Field(..., description="The json of object, e.g. obj.model_dump_json()")
    obj_cls_name: str = Field(..., description="The class name of object, e.g. obj.__class__.__name__")
    obj_mod_name: str = Field(..., description="The module name of class, e.g. obj.__class__.__module__")


class ObjectNode(TextNode):
    """RAG add object."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Keep the raw object metadata out of LLM prompts and embedding input.
        self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys())
        self.excluded_embed_metadata_keys = self.excluded_llm_metadata_keys

    @staticmethod
    def get_obj_metadata(obj: RAGObject) -> dict:
        """Build the ObjectNodeMetadata dict for `obj` (class/module names allow later reconstruction)."""
        metadata = ObjectNodeMetadata(
            obj_json=obj.model_dump_json(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
        )

        return metadata.model_dump()
|
||||
|
||||
|
||||
class UserMessage(Message):
    """Message with role fixed to "user".

    Facilitate support for OpenAI messages.
    """

    def __init__(self, content: str):
        super().__init__(content=content, role="user")


class SystemMessage(Message):
    """Message with role fixed to "system".

    Facilitate support for OpenAI messages.
    """

    def __init__(self, content: str):
        super().__init__(content=content, role="system")


class AIMessage(Message):
    """Message with role fixed to "assistant".

    Facilitate support for OpenAI messages.
    """

    def __init__(self, content: str):
        super().__init__(content=content, role="assistant")
|
||||
|
||||
|
||||
class Task(BaseModel):
    """A single plan step: an instruction plus its execution code, result, and state."""

    task_id: str = ""
    dependent_task_ids: list[str] = []  # Tasks prerequisite to this Task
    instruction: str = ""
    task_type: str = ""
    code: str = ""
    result: str = ""
    is_success: bool = False
    is_finished: bool = False

    def reset(self):
        """Clear execution state (code/result/flags) so the task can be re-run."""
        self.code = ""
        self.result = ""
        self.is_success = False
        self.is_finished = False

    def update_task_result(self, task_result: TaskResult):
        """Copy code, result and success flag from a TaskResult; is_finished is set elsewhere (by Plan)."""
        self.code = task_result.code
        self.result = task_result.result
        self.is_success = task_result.is_success
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
    """Result of taking a task, with result and is_success required to be filled"""

    code: str = ""
    result: str
    is_success: bool
|
||||
|
||||
|
||||
class Plan(BaseModel):
    """An ordered set of Tasks working toward a goal, with one current task at a time."""

    goal: str
    context: str = ""
    tasks: list[Task] = []
    task_map: dict[str, Task] = {}
    current_task_id: str = ""

    def _topological_sort(self, tasks: list[Task]):
        """Order tasks so that every task comes after all of its dependencies.

        NOTE(review): plain DFS with no cycle detection — a dependency cycle would
        recurse forever; assumes input is a DAG.
        """
        task_map = {task.task_id: task for task in tasks}
        dependencies = {task.task_id: set(task.dependent_task_ids) for task in tasks}
        sorted_tasks = []
        visited = set()

        def visit(task_id):
            if task_id in visited:
                return
            visited.add(task_id)
            for dependent_id in dependencies.get(task_id, []):
                visit(dependent_id)
            sorted_tasks.append(task_map[task_id])

        for task in tasks:
            visit(task.task_id)

        return sorted_tasks

    def add_tasks(self, tasks: list[Task]):
        """
        Integrates new tasks into the existing plan, ensuring dependency order is maintained.

        This method performs two primary functions based on the current state of the task list:
        1. If there are no existing tasks, it topologically sorts the provided tasks to ensure
        correct execution order based on dependencies, and sets these as the current tasks.
        2. If there are existing tasks, it merges the new tasks with the existing ones. It maintains
        any common prefix of tasks (based on task_id and instruction) and appends the remainder
        of the new tasks. The current task is updated to the first unfinished task in this merged list.

        Args:
            tasks (list[Task]): A list of tasks (may be unordered) to add to the plan.

        Returns:
            None: The method updates the internal state of the plan but does not return anything.
        """
        if not tasks:
            return

        # Topologically sort the new tasks to ensure correct dependency order
        new_tasks = self._topological_sort(tasks)

        if not self.tasks:
            # If there are no existing tasks, set the new tasks as the current tasks
            self.tasks = new_tasks

        else:
            # Find the length of the common prefix between existing and new tasks
            prefix_length = 0
            for old_task, new_task in zip(self.tasks, new_tasks):
                if old_task.task_id != new_task.task_id or old_task.instruction != new_task.instruction:
                    break
                prefix_length += 1

            # Combine the common prefix with the remainder of the new tasks
            final_tasks = self.tasks[:prefix_length] + new_tasks[prefix_length:]
            self.tasks = final_tasks

        # Update current_task_id to the first unfinished task in the merged list
        self._update_current_task()

        # Update the task map for quick access to tasks by ID
        self.task_map = {task.task_id: task for task in self.tasks}

    def reset_task(self, task_id: str):
        """
        Clear code and result of the task based on task_id, and set the task as unfinished.

        Args:
            task_id (str): The ID of the task to be reset.

        Returns:
            None
        """
        if task_id in self.task_map:
            task = self.task_map[task_id]
            task.reset()

    def replace_task(self, new_task: Task):
        """
        Replace an existing task with the new input task based on task_id, and reset all tasks depending on it.

        Args:
            new_task (Task): The new task that will replace an existing one.

        Returns:
            None
        """
        assert new_task.task_id in self.task_map
        # Replace the task in the task map and the task list
        self.task_map[new_task.task_id] = new_task
        for i, task in enumerate(self.tasks):
            if task.task_id == new_task.task_id:
                self.tasks[i] = new_task
                break

        # Reset dependent tasks
        for task in self.tasks:
            if new_task.task_id in task.dependent_task_ids:
                self.reset_task(task.task_id)

    def append_task(self, new_task: Task):
        """
        Append a new task to the end of existing task sequences

        Args:
            new_task (Task): The new task to be appended to the existing task sequence

        Returns:
            None
        """
        assert not self.has_task_id(new_task.task_id), "Task already in current plan, use replace_task instead"

        assert all(
            [self.has_task_id(dep_id) for dep_id in new_task.dependent_task_ids]
        ), "New task has unknown dependencies"

        # Existing tasks do not depend on the new task, it's fine to put it to the end of the sorted task sequence
        self.tasks.append(new_task)
        self.task_map[new_task.task_id] = new_task
        self._update_current_task()

    def has_task_id(self, task_id: str) -> bool:
        return task_id in self.task_map

    def _update_current_task(self):
        # First unfinished task becomes current; "" means all tasks are finished.
        current_task_id = ""
        for task in self.tasks:
            if not task.is_finished:
                current_task_id = task.task_id
                break
        self.current_task_id = current_task_id  # all tasks finished

    @property
    def current_task(self) -> Optional[Task]:
        """Find current task to execute

        Returns:
            Optional[Task]: the current task to be executed, or None when all tasks are finished.
        """
        # BUGFIX: annotation widened to Optional[Task] — dict.get returns None
        # when current_task_id is "" (all tasks finished).
        return self.task_map.get(self.current_task_id, None)

    def finish_current_task(self):
        """Finish current task, set Task.is_finished=True, set current task to next task"""
        if self.current_task_id:
            self.current_task.is_finished = True
            self._update_current_task()  # set to next task

    def get_finished_tasks(self) -> list[Task]:
        """return all finished tasks in correct linearized order

        Returns:
            list[Task]: list of finished tasks
        """
        return [task for task in self.tasks if task.is_finished]
|
||||
|
||||
|
||||
class MessageQueue(BaseModel):
    """Message queue which supports asynchronous updates."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    _queue: Queue = PrivateAttr(default_factory=Queue)

    def pop(self) -> Message | None:
        """Pop one message from the queue; return None when the queue is empty."""
        try:
            item = self._queue.get_nowait()
            if item:
                self._queue.task_done()
            return item
        except QueueEmpty:
            return None

    def pop_all(self) -> List[Message]:
        """Pop all messages from the queue."""
        ret = []
        while True:
            msg = self.pop()
            if not msg:
                break
            ret.append(msg)
        return ret

    def push(self, msg: Message):
        """Push a message into the queue."""
        self._queue.put_nowait(msg)

    def empty(self):
        """Return true if the queue is empty."""
        return self._queue.empty()

    async def dump(self) -> str:
        """Convert the `MessageQueue` object to a json string.

        Drains the queue to serialize each message, then re-enqueues everything,
        so the queue content is unchanged afterwards.
        """
        if self.empty():
            return "[]"

        lst = []
        msgs = []
        try:
            while True:
                # The timeout bounds the drain: an empty queue raises TimeoutError.
                item = await wait_for(self._queue.get(), timeout=1.0)
                if item is None:
                    break
                msgs.append(item)
                lst.append(item.dump())
                self._queue.task_done()
        except asyncio.TimeoutError:
            logger.debug("Queue is empty, exiting...")
        finally:
            # Restore the drained messages so dump() is non-destructive.
            for m in msgs:
                self._queue.put_nowait(m)
        return json.dumps(lst, ensure_ascii=False)

    @staticmethod
    def load(data) -> "MessageQueue":
        """Convert the json string to the `MessageQueue` object."""
        queue = MessageQueue()
        try:
            lst = json.loads(data)
            for i in lst:
                msg = Message.load(i)
                queue.push(msg)
        except JSONDecodeError as e:
            logger.warning(f"JSON load failed: {data}, error:{e}")

        return queue
|
||||
|
||||
|
||||
# Generic type variable (bound to BaseModel) used by BaseContext.loads below.
|
||||
T = TypeVar("T", bound="BaseModel")
|
||||
|
||||
|
||||
class BaseContext(BaseModel, ABC):
    """Base class for action contexts that can be rebuilt from a JSON string."""

    @classmethod
    @handle_exception
    def loads(cls: Type[T], val: str) -> Optional[T]:
        """Parse `val` as JSON and build an instance; handle_exception turns failures into None."""
        i = json.loads(val)
        return cls(**i)
|
||||
|
||||
|
||||
class CodingContext(BaseContext):
    """Everything the coding action needs for one file: design/task/code/plan docs."""

    filename: str
    design_doc: Optional[Document] = None
    task_doc: Optional[Document] = None
    code_doc: Optional[Document] = None
    code_plan_and_change_doc: Optional[Document] = None


class TestingContext(BaseContext):
    """Inputs/outputs of the testing action: the code under test and its test doc."""

    filename: str
    code_doc: Document
    test_doc: Optional[Document] = None


class RunCodeContext(BaseContext):
    """Parameters for running code (either a script file or inline code) plus its output."""

    mode: str = "script"
    code: Optional[str] = None
    code_filename: str = ""
    test_code: Optional[str] = None
    test_filename: str = ""
    command: List[str] = Field(default_factory=list)
    working_directory: str = ""
    additional_python_paths: List[str] = Field(default_factory=list)
    output_filename: Optional[str] = None
    output: Optional[str] = None


class RunCodeResult(BaseContext):
    """Captured outcome of a code run: summary plus raw stdout/stderr."""

    summary: str
    stdout: str
    stderr: str
|
||||
|
||||
|
||||
class CodeSummarizeContext(BaseModel):
    """Links a code summary to its originating design/task files."""

    design_filename: str = ""
    task_filename: str = ""
    codes_filenames: List[str] = Field(default_factory=list)
    reason: str = ""

    @staticmethod
    def loads(filenames: List) -> CodeSummarizeContext:
        """Build a context by classifying each filename into design vs task by its repo directory."""
        ctx = CodeSummarizeContext()
        for filename in filenames:
            if Path(filename).is_relative_to(SYSTEM_DESIGN_FILE_REPO):
                ctx.design_filename = str(filename)
                continue
            if Path(filename).is_relative_to(TASK_FILE_REPO):
                ctx.task_filename = str(filename)
                continue
        return ctx

    def __hash__(self):
        # Hash only the identifying pair so contexts can be used in sets/dict keys.
        return hash((self.design_filename, self.task_filename))
|
||||
|
||||
|
||||
class BugFixContext(BaseContext):
    """Context for a bug-fix action: the file to fix."""

    filename: str = ""


class CodePlanAndChangeContext(BaseModel):
    """Inputs for the code-plan-and-change action: requirement, issue and related doc filenames."""

    requirement: str = ""
    issue: str = ""
    prd_filename: str = ""
    design_filename: str = ""
    task_filename: str = ""

    @staticmethod
    def loads(filenames: List, **kwargs) -> CodePlanAndChangeContext:
        """Build a context, classifying each filename into prd/design/task by its repo directory."""
        ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", ""), issue=kwargs.get("issue", ""))
        for filename in filenames:
            filename = Path(filename)
            if filename.is_relative_to(PRDS_FILE_REPO):
                ctx.prd_filename = filename.name
                continue
            if filename.is_relative_to(SYSTEM_DESIGN_FILE_REPO):
                ctx.design_filename = filename.name
                continue
            if filename.is_relative_to(TASK_FILE_REPO):
                ctx.task_filename = filename.name
                continue
        return ctx
|
||||
|
||||
|
||||
# Mermaid class-diagram view models
|
||||
class UMLClassMeta(BaseModel):
    """Base for UML class-diagram elements: a name plus a mermaid visibility marker."""

    # Element name (class, attribute, or method identifier).
    name: str = ""
    # Mermaid visibility prefix: "+", "-", or "#".
    visibility: str = ""

    @staticmethod
    def name_to_visibility(name: str) -> str:
        """Map a Python identifier to a mermaid visibility marker.

        ``__init__`` counts as public; other double-underscore names are
        private ("-"), single-underscore names protected ("#"), and
        everything else public ("+").
        """
        if name == "__init__" or not name.startswith("_"):
            return "+"
        return "-" if name.startswith("__") else "#"
class UMLClassAttribute(UMLClassMeta):
    """A UML attribute (field) of a class, rendered to mermaid syntax."""

    # Declared type of the attribute (e.g. "str"); empty when unknown.
    value_type: str = ""
    # Default value rendered after "="; empty when there is none.
    default_value: str = ""

    def get_mermaid(self, align=1) -> str:
        """Render this attribute as one mermaid class-diagram member line.

        Args:
            align: Number of leading tab characters (0 for inline use in
                method argument lists).

        Returns:
            A string such as ``\\t+str name="x"``.
        """
        content = "\t" * align + self.visibility
        if self.value_type:
            # Mermaid rejects spaces inside type expressions, so strip them.
            content += self.value_type.replace(" ", "") + " "
        # Drop any "prefix:" qualifier on the name, keeping the bare identifier.
        name = self.name.split(":", 1)[1] if ":" in self.name else self.name
        content += name
        if self.default_value:
            content += "="
            if self.value_type not in ["str", "string", "String"]:
                content += self.default_value
            else:
                # String defaults are quoted; inner quotes are removed so the
                # mermaid line stays parseable.
                content += '"' + self.default_value.replace('"', "") + '"'
        # TODO: emit "*" (abstract) / "$" (static) markers when supported.
        return content
class UMLClassMethod(UMLClassMeta):
    """A UML method of a class, rendered to mermaid syntax."""

    # Formal parameters, each rendered inline (align=0) inside the parentheses.
    args: List[UMLClassAttribute] = Field(default_factory=list)
    # Return type; empty when none/unknown.
    return_type: str = ""

    def get_mermaid(self, align=1) -> str:
        """Render this method as one mermaid class-diagram member line.

        Args:
            align: Number of leading tab characters for indentation.

        Returns:
            A string such as ``\\t+name(arg1,arg2) ReturnType``.
        """
        content = "\t" * align + self.visibility
        # Drop any "prefix:" qualifier on the name, keeping the bare identifier.
        name = self.name.split(":", 1)[1] if ":" in self.name else self.name
        content += name + "(" + ",".join(v.get_mermaid(align=0) for v in self.args) + ")"
        if self.return_type:
            # Mermaid rejects spaces inside type expressions, so strip them.
            content += " " + self.return_type.replace(" ", "")
        # TODO: emit "*" (abstract) / "$" (static) markers when supported.
        return content
class UMLClassView(UMLClassMeta):
    """A complete UML class (attributes + methods), rendered to mermaid."""

    # Fields of the class.
    attributes: List[UMLClassAttribute] = Field(default_factory=list)
    # Methods of the class.
    methods: List[UMLClassMethod] = Field(default_factory=list)

    def get_mermaid(self, align=1) -> str:
        """Render the whole class as a mermaid ``class Name { ... }`` block.

        Args:
            align: Number of leading tabs for the ``class`` line; members are
                indented one level deeper.

        Returns:
            Multi-line mermaid source ending with a newline.
        """
        indent = "\t" * align
        content = indent + "class " + self.name + "{\n"
        for member in self.attributes:
            content += member.get_mermaid(align=align + 1) + "\n"
        for member in self.methods:
            content += member.get_mermaid(align=align + 1) + "\n"
        content += indent + "}\n"
        return content

    @classmethod
    def load_dot_class_info(cls, dot_class_info: DotClassInfo) -> UMLClassView:
        """Build a UMLClassView from a parsed dot-format class description.

        Args:
            dot_class_info: Parsed class info exposing ``name``, ``attributes``
                (mapping of attribute infos), ``methods`` (mapping of method
                infos with ``args`` and ``return_args``).

        Returns:
            A populated UMLClassView mirroring ``dot_class_info``.
        """
        visibility = UMLClassView.name_to_visibility(dot_class_info.name)
        class_view = cls(name=dot_class_info.name, visibility=visibility)
        for i in dot_class_info.attributes.values():
            visibility = UMLClassAttribute.name_to_visibility(i.name)
            attr = UMLClassAttribute(name=i.name, visibility=visibility, value_type=i.type_, default_value=i.default_)
            class_view.attributes.append(attr)
        for i in dot_class_info.methods.values():
            visibility = UMLClassMethod.name_to_visibility(i.name)
            # return_type is set once via the constructor (the old code
            # redundantly re-assigned the same value after the args loop).
            method = UMLClassMethod(name=i.name, visibility=visibility, return_type=i.return_args.type_)
            for j in i.args:
                arg = UMLClassAttribute(name=j.name, value_type=j.type_, default_value=j.default_)
                method.args.append(arg)
            class_view.methods.append(method)
        return class_view
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue