mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-23 15:48:11 +02:00
Merge branch 'dev' into dev
This commit is contained in:
commit
539e1c7dce
81 changed files with 1402 additions and 649 deletions
|
|
@ -13,7 +13,7 @@ from metagpt.actions.add_requirement import UserRequirement
|
|||
from metagpt.actions.debug_error import DebugError
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.actions.design_api_review import DesignReview
|
||||
from metagpt.actions.project_management import AssignTasks, WriteTasks
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch
|
||||
from metagpt.actions.run_code import RunCode
|
||||
from metagpt.actions.search_and_summarize import SearchAndSummarize
|
||||
|
|
@ -38,7 +38,6 @@ class ActionType(Enum):
|
|||
RUN_CODE = RunCode
|
||||
DEBUG_ERROR = DebugError
|
||||
WRITE_TASKS = WriteTasks
|
||||
ASSIGN_TASKS = AssignTasks
|
||||
SEARCH_AND_SUMMARIZE = SearchAndSummarize
|
||||
COLLECT_LINKS = CollectLinks
|
||||
WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize
|
||||
|
|
|
|||
|
|
@ -352,17 +352,3 @@ class ActionNode:
|
|||
cls = self.create_children_class()
|
||||
self.instruct_content = cls(**tmp)
|
||||
return self
|
||||
|
||||
|
||||
def action_node_example():
|
||||
node = ActionNode(key="key-0", expected_type=str, instruction="instruction-a", example="example-b")
|
||||
|
||||
logger.info(node.compile(context="123", schema="raw", mode="auto"))
|
||||
logger.info(node.compile(context="123", schema="json", mode="auto"))
|
||||
logger.info(node.compile(context="123", schema="markdown", mode="auto"))
|
||||
logger.info(node.to_dict())
|
||||
logger.info(node)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
action_node_example()
|
||||
|
|
|
|||
|
|
@ -10,6 +10,3 @@ from metagpt.actions import Action
|
|||
|
||||
class UserRequirement(Action):
|
||||
"""User Requirement without any implementation details"""
|
||||
|
||||
async def run(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
from typing import List
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.mermaid import MMC1, MMC2
|
||||
|
||||
IMPLEMENTATION_APPROACH = ActionNode(
|
||||
|
|
@ -63,12 +62,3 @@ NODES = [
|
|||
]
|
||||
|
||||
DESIGN_API_NODE = ActionNode.from_children("DesignAPI", NODES)
|
||||
|
||||
|
||||
def main():
|
||||
prompt = DESIGN_API_NODE.compile(context="")
|
||||
logger.info(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ class PrepareDocuments(Action):
|
|||
path = Path(CONFIG.project_path)
|
||||
if path.exists() and not CONFIG.inc:
|
||||
shutil.rmtree(path)
|
||||
CONFIG.project_path = path
|
||||
CONFIG.project_name = path.name
|
||||
CONFIG.git_repo = GitRepository(local_path=path, auto_init=True)
|
||||
|
||||
async def run(self, with_messages, **kwargs):
|
||||
|
|
|
|||
|
|
@ -123,9 +123,3 @@ class WriteTasks(Action):
|
|||
@staticmethod
|
||||
async def _save_pdf(task_doc):
|
||||
await FileRepository.save_as(doc=task_doc, with_suffix=".md", relative_path=TASK_PDF_FILE_REPO)
|
||||
|
||||
|
||||
class AssignTasks(Action):
|
||||
async def run(self, *args, **kwargs):
|
||||
# Here you should implement the actual action
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ class CollectLinks(Action):
|
|||
desc: str = "Collect links from a search engine."
|
||||
|
||||
search_engine: SearchEngine = Field(default_factory=SearchEngine)
|
||||
rank_func: Union[Callable[[list[str]], None], None] = None
|
||||
rank_func: Optional[Callable[[list[str]], None]] = None
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
|
@ -130,7 +130,8 @@ class CollectLinks(Action):
|
|||
if len(remove) == 0:
|
||||
break
|
||||
|
||||
prompt = reduce_message_length(gen_msg(), self.llm.model, system_text, CONFIG.max_tokens_rsp)
|
||||
model_name = CONFIG.get_model_name(CONFIG.get_default_llm_provider_enum())
|
||||
prompt = reduce_message_length(gen_msg(), model_name, system_text, CONFIG.max_tokens_rsp)
|
||||
logger.debug(prompt)
|
||||
queries = await self._aask(prompt, [system_text])
|
||||
try:
|
||||
|
|
@ -181,18 +182,18 @@ class WebBrowseAndSummarize(Action):
|
|||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
desc: str = "Explore the web and provide summaries of articles and webpages."
|
||||
browse_func: Union[Callable[[list[str]], None], None] = None
|
||||
web_browser_engine: WebBrowserEngine = Field(
|
||||
default_factory=lambda: WebBrowserEngine(
|
||||
engine=WebBrowserEngineType.CUSTOM if WebBrowseAndSummarize.browse_func else None,
|
||||
run_func=WebBrowseAndSummarize.browse_func,
|
||||
)
|
||||
)
|
||||
web_browser_engine: Optional[WebBrowserEngine] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if CONFIG.model_for_researcher_summary:
|
||||
self.llm.model = CONFIG.model_for_researcher_summary
|
||||
|
||||
self.web_browser_engine = WebBrowserEngine(
|
||||
engine=WebBrowserEngineType.CUSTOM if self.browse_func else None,
|
||||
run_func=self.browse_func,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
url: str,
|
||||
|
|
|
|||
|
|
@ -82,11 +82,13 @@ class RunCode(Action):
|
|||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@classmethod
|
||||
@handle_exception
|
||||
async def run_text(cls, code) -> Tuple[str, str]:
|
||||
# We will document_store the result in this dictionary
|
||||
namespace = {}
|
||||
exec(code, namespace)
|
||||
try:
|
||||
# We will document_store the result in this dictionary
|
||||
namespace = {}
|
||||
exec(code, namespace)
|
||||
except Exception as e:
|
||||
return "", str(e)
|
||||
return namespace.get("result", ""), ""
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -21,7 +21,10 @@ Example:
|
|||
This script uses the 'fire' library to create a command-line interface. It generates docstrings for the given Python code using
|
||||
the specified docstring style and adds them to the code.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
|
@ -29,7 +32,7 @@ from pydantic import Field
|
|||
from metagpt.actions.action import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.common import OutputParser
|
||||
from metagpt.utils.common import OutputParser, aread, awrite
|
||||
from metagpt.utils.pycst import merge_docstring
|
||||
|
||||
PYTHON_DOCSTRING_SYSTEM = """### Requirements
|
||||
|
|
@ -187,6 +190,16 @@ class WriteDocstring(Action):
|
|||
documented_code = OutputParser.parse_python_code(documented_code)
|
||||
return merge_docstring(code, documented_code)
|
||||
|
||||
@staticmethod
|
||||
async def write_docstring(
|
||||
filename: str | Path, overwrite: bool = False, style: Literal["google", "numpy", "sphinx"] = "google"
|
||||
) -> str:
|
||||
data = await aread(str(filename))
|
||||
code = await WriteDocstring().run(data, style=style)
|
||||
if overwrite:
|
||||
await awrite(filename, code)
|
||||
return code
|
||||
|
||||
|
||||
def _simplify_python_code(code: str) -> None:
|
||||
"""Simplifies the given Python code by removing expressions and the last if statement.
|
||||
|
|
@ -207,13 +220,4 @@ def _simplify_python_code(code: str) -> None:
|
|||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
async def run(filename: str, overwrite: bool = False, style: Literal["google", "numpy", "sphinx"] = "google"):
|
||||
with open(filename) as f:
|
||||
code = f.read()
|
||||
code = await WriteDocstring().run(code, style=style)
|
||||
if overwrite:
|
||||
with open(filename, "w") as f:
|
||||
f.write(code)
|
||||
return code
|
||||
|
||||
fire.Fire(run)
|
||||
fire.Fire(WriteDocstring.write_docstring)
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ NEW_REQ_TEMPLATE = """
|
|||
|
||||
|
||||
class WritePRD(Action):
|
||||
name: str = ""
|
||||
name: str = "WritePRD"
|
||||
content: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
|
|
@ -181,18 +181,13 @@ class WritePRD(Action):
|
|||
|
||||
@staticmethod
|
||||
async def _rename_workspace(prd):
|
||||
if CONFIG.project_path: # Updating on the old version has already been specified if it's valid. According to
|
||||
# Section 2.2.3.10 of RFC 135
|
||||
if not CONFIG.project_name:
|
||||
CONFIG.project_name = Path(CONFIG.project_path).name
|
||||
return
|
||||
|
||||
if not CONFIG.project_name:
|
||||
if isinstance(prd, (ActionOutput, ActionNode)):
|
||||
ws_name = prd.instruct_content.model_dump()["Project Name"]
|
||||
else:
|
||||
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
|
||||
CONFIG.project_name = ws_name
|
||||
if ws_name:
|
||||
CONFIG.project_name = ws_name
|
||||
CONFIG.git_repo.rename_root(CONFIG.project_name)
|
||||
|
||||
async def _is_bugfix(self, context) -> bool:
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ class Config(metaclass=Singleton):
|
|||
self.inc = False
|
||||
self.reqa_file = ""
|
||||
self.max_auto_summarize_code = 0
|
||||
self.git_reinit = False
|
||||
|
||||
self._init_with_config_files_and_env(yaml_file)
|
||||
# The agent needs to be billed per user, so billing information cannot be destroyed when the session ends.
|
||||
|
|
@ -110,11 +111,7 @@ class Config(metaclass=Singleton):
|
|||
|
||||
if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
|
||||
warnings.warn("Use Gemini requires Python >= 3.10")
|
||||
model_mappings = {
|
||||
LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
|
||||
LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
|
||||
}
|
||||
model_name = model_mappings.get(provider)
|
||||
model_name = self.get_model_name(provider=provider)
|
||||
if model_name:
|
||||
logger.info(f"{provider} Model: {model_name}")
|
||||
if provider:
|
||||
|
|
@ -122,6 +119,14 @@ class Config(metaclass=Singleton):
|
|||
return provider
|
||||
raise NotConfiguredException("You should config a LLM configuration first")
|
||||
|
||||
def get_model_name(self, provider=None) -> str:
|
||||
provider = provider or self.get_default_llm_provider_enum()
|
||||
model_mappings = {
|
||||
LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
|
||||
LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
|
||||
}
|
||||
return model_mappings.get(provider, "")
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_llm_key(k: str) -> bool:
|
||||
return bool(k and k != "YOUR_API_KEY")
|
||||
|
|
@ -142,7 +147,7 @@ class Config(metaclass=Singleton):
|
|||
if not self._get("DISABLE_LLM_PROVIDER_CHECK"):
|
||||
_ = self.get_default_llm_provider_enum()
|
||||
|
||||
# self.openai_base_url = self._get("OPENAI_BASE_URL")
|
||||
self.openai_base_url = self._get("OPENAI_BASE_URL")
|
||||
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
|
||||
self.openai_api_type = self._get("OPENAI_API_TYPE")
|
||||
self.openai_api_version = self._get("OPENAI_API_VERSION")
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
|
|||
|
||||
EXAMPLE_PATH = METAGPT_ROOT / "examples"
|
||||
DATA_PATH = METAGPT_ROOT / "data"
|
||||
TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
|
||||
RESEARCH_PATH = DATA_PATH / "research"
|
||||
TUTORIAL_PATH = DATA_PATH / "tutorial_docx"
|
||||
INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table"
|
||||
|
|
|
|||
|
|
@ -1,111 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/28 00:00
|
||||
@Author : alexanderwu
|
||||
@File : milvus_store.py
|
||||
"""
|
||||
from typing import TypedDict
|
||||
|
||||
import numpy as np
|
||||
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections
|
||||
|
||||
from metagpt.document_store.base_store import BaseStore
|
||||
|
||||
type_mapping = {int: DataType.INT64, str: DataType.VARCHAR, float: DataType.DOUBLE, np.ndarray: DataType.FLOAT_VECTOR}
|
||||
|
||||
|
||||
def columns_to_milvus_schema(columns: dict, primary_col_name: str = "", desc: str = ""):
|
||||
"""Assume the structure of columns is str: regular type"""
|
||||
fields = []
|
||||
for col, ctype in columns.items():
|
||||
if ctype == str:
|
||||
mcol = FieldSchema(name=col, dtype=type_mapping[ctype], max_length=100)
|
||||
elif ctype == np.ndarray:
|
||||
mcol = FieldSchema(name=col, dtype=type_mapping[ctype], dim=2)
|
||||
else:
|
||||
mcol = FieldSchema(name=col, dtype=type_mapping[ctype], is_primary=(col == primary_col_name))
|
||||
fields.append(mcol)
|
||||
schema = CollectionSchema(fields, description=desc)
|
||||
return schema
|
||||
|
||||
|
||||
class MilvusConnection(TypedDict):
|
||||
alias: str
|
||||
host: str
|
||||
port: str
|
||||
|
||||
|
||||
class MilvusStore(BaseStore):
|
||||
"""
|
||||
FIXME: ADD TESTS
|
||||
https://milvus.io/docs/v2.0.x/create_collection.md
|
||||
"""
|
||||
|
||||
def __init__(self, connection):
|
||||
connections.connect(**connection)
|
||||
self.collection = None
|
||||
|
||||
def _create_collection(self, name, schema):
|
||||
collection = Collection(name=name, schema=schema, using="default", shards_num=2, consistency_level="Strong")
|
||||
return collection
|
||||
|
||||
def create_collection(self, name, columns):
|
||||
schema = columns_to_milvus_schema(columns, "idx")
|
||||
self.collection = self._create_collection(name, schema)
|
||||
return self.collection
|
||||
|
||||
def drop(self, name):
|
||||
Collection(name).drop()
|
||||
|
||||
def load_collection(self):
|
||||
self.collection.load()
|
||||
|
||||
def build_index(self, field="emb"):
|
||||
self.collection.create_index(field, {"index_type": "FLAT", "metric_type": "L2", "params": {}})
|
||||
|
||||
def search(self, query: list[list[float]], *args, **kwargs):
|
||||
"""
|
||||
FIXME: ADD TESTS
|
||||
https://milvus.io/docs/v2.0.x/search.md
|
||||
All search and query operations within Milvus are executed in memory. Load the collection to memory before conducting a vector similarity search.
|
||||
Note the above description, is this logic serious? This should take a long time, right?
|
||||
"""
|
||||
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
|
||||
results = self.collection.search(
|
||||
data=query,
|
||||
anns_field=kwargs.get("field", "emb"),
|
||||
param=search_params,
|
||||
limit=10,
|
||||
expr=None,
|
||||
consistency_level="Strong",
|
||||
)
|
||||
# FIXME: results contain id, but to get the actual value from the id, we still need to call the query interface
|
||||
return results
|
||||
|
||||
def write(self, name, schema, *args, **kwargs):
|
||||
"""
|
||||
FIXME: ADD TESTS
|
||||
https://milvus.io/docs/v2.0.x/create_collection.md
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def add(self, data, *args, **kwargs):
|
||||
"""
|
||||
FIXME: ADD TESTS
|
||||
https://milvus.io/docs/v2.0.x/insert_data.md
|
||||
import random
|
||||
data = [
|
||||
[i for i in range(2000)],
|
||||
[i for i in range(10000, 12000)],
|
||||
[[random.random() for _ in range(2)] for _ in range(2000)],
|
||||
]
|
||||
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
self.collection.insert(data)
|
||||
|
|
@ -28,7 +28,7 @@ class SkillManager:
|
|||
:return:
|
||||
"""
|
||||
self._skills[skill.name] = skill
|
||||
self._store.add(skill.desc, {}, skill.name)
|
||||
self._store.add(skill.desc, {"name": skill.name, "desc": skill.desc}, skill.name)
|
||||
|
||||
def del_skill(self, skill_name: str):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -55,9 +55,9 @@ class BrainMemory(BaseModel):
|
|||
return "\n".join(texts)
|
||||
|
||||
@staticmethod
|
||||
async def loads(redis_key: str, redis_conf: Dict = None) -> "BrainMemory":
|
||||
redis = Redis(conf=redis_conf)
|
||||
if not redis.is_valid() or not redis_key:
|
||||
async def loads(redis_key: str) -> "BrainMemory":
|
||||
redis = Redis()
|
||||
if not redis.is_valid or not redis_key:
|
||||
return BrainMemory()
|
||||
v = await redis.get(key=redis_key)
|
||||
logger.debug(f"REDIS GET {redis_key} {v}")
|
||||
|
|
@ -67,11 +67,11 @@ class BrainMemory(BaseModel):
|
|||
return bm
|
||||
return BrainMemory()
|
||||
|
||||
async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60, redis_conf: Dict = None):
|
||||
async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60):
|
||||
if not self.is_dirty:
|
||||
return
|
||||
redis = Redis(conf=redis_conf)
|
||||
if not redis.is_valid() or not redis_key:
|
||||
redis = Redis()
|
||||
if not redis.is_valid or not redis_key:
|
||||
return False
|
||||
v = self.model_dump_json()
|
||||
if self.cacheable:
|
||||
|
|
@ -86,26 +86,27 @@ class BrainMemory(BaseModel):
|
|||
async def set_history_summary(self, history_summary, redis_key, redis_conf):
|
||||
if self.historical_summary == history_summary:
|
||||
if self.is_dirty:
|
||||
await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
|
||||
await self.dumps(redis_key=redis_key)
|
||||
self.is_dirty = False
|
||||
return
|
||||
|
||||
self.historical_summary = history_summary
|
||||
self.history = []
|
||||
await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
|
||||
await self.dumps(redis_key=redis_key)
|
||||
self.is_dirty = False
|
||||
|
||||
def add_history(self, msg: Message):
|
||||
if msg.id:
|
||||
if self.to_int(msg.id, 0) <= self.to_int(self.last_history_id, -1):
|
||||
return
|
||||
self.history.append(msg.model_dump())
|
||||
|
||||
self.history.append(msg)
|
||||
self.last_history_id = str(msg.id)
|
||||
self.is_dirty = True
|
||||
|
||||
def exists(self, text) -> bool:
|
||||
for m in reversed(self.history):
|
||||
if m.get("content") == text:
|
||||
if m.content == text:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
|
@ -163,7 +164,7 @@ class BrainMemory(BaseModel):
|
|||
msgs.reverse()
|
||||
self.history = msgs
|
||||
self.is_dirty = True
|
||||
await self.dumps(redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS_CONF)
|
||||
await self.dumps(redis_key=CONFIG.REDIS_KEY)
|
||||
self.is_dirty = False
|
||||
|
||||
return BrainMemory.to_metagpt_history_format(self.history)
|
||||
|
|
@ -217,7 +218,7 @@ class BrainMemory(BaseModel):
|
|||
return await self._openai_rewrite(sentence=sentence, context=context, llm=llm)
|
||||
|
||||
@staticmethod
|
||||
async def _metagpt_rewrite(sentence: str):
|
||||
async def _metagpt_rewrite(sentence: str, **kwargs):
|
||||
return sentence
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ class GeminiLLM(BaseLLM):
|
|||
genai.configure(api_key=config.gemini_api_key)
|
||||
|
||||
def _user_msg(self, msg: str) -> dict[str, str]:
|
||||
# Not to change BaseGPTAPI default functions but update with Gemini's conversation format.
|
||||
# Not to change BaseLLM default functions but update with Gemini's conversation format.
|
||||
# You should follow the format.
|
||||
return {"role": "user", "parts": [msg]}
|
||||
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class OpenAILLM(BaseLLM):
|
|||
self.aclient = AsyncOpenAI(**kwargs)
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = {"api_key": self.config.OPENAI_API_KEY, "base_url": self.config.OPENAI_BASE_URL}
|
||||
kwargs = {"api_key": self.config.openai_api_key, "base_url": self.config.openai_base_url}
|
||||
|
||||
# to use proxy, openai v1 needs http_client
|
||||
if proxy_params := self._get_proxy_params():
|
||||
|
|
@ -81,8 +81,8 @@ class OpenAILLM(BaseLLM):
|
|||
params = {}
|
||||
if self.config.openai_proxy:
|
||||
params = {"proxies": self.config.openai_proxy}
|
||||
if self.config.OPENAI_BASE_URL:
|
||||
params["base_url"] = self.config.OPENAI_BASE_URL
|
||||
if self.config.openai_base_url:
|
||||
params["base_url"] = self.config.openai_base_url
|
||||
|
||||
return params
|
||||
|
||||
|
|
|
|||
|
|
@ -40,10 +40,11 @@ class ProductManager(Role):
|
|||
|
||||
async def _think(self) -> bool:
|
||||
"""Decide what to do"""
|
||||
if CONFIG.git_repo:
|
||||
if CONFIG.git_repo and not CONFIG.git_reinit:
|
||||
self._set_state(1)
|
||||
else:
|
||||
self._set_state(0)
|
||||
CONFIG.git_reinit = False
|
||||
self.todo_action = any_to_name(WritePRD)
|
||||
return bool(self.rc.todo)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -107,9 +108,11 @@ class Researcher(Role):
|
|||
return msg
|
||||
|
||||
def write_report(self, topic: str, content: str):
|
||||
filename = re.sub(r'[\\/:"*?<>|]+', " ", topic)
|
||||
filename = filename.replace("\n", "")
|
||||
if not RESEARCH_PATH.exists():
|
||||
RESEARCH_PATH.mkdir(parents=True)
|
||||
filepath = RESEARCH_PATH / f"{topic}.md"
|
||||
filepath = RESEARCH_PATH / f"{filename}.md"
|
||||
filepath.write_text(content)
|
||||
|
||||
|
||||
|
|
|
|||
4
metagpt/strategy/__init__.py
Normal file
4
metagpt/strategy/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/23/2023 4:51 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
108
metagpt/strategy/base.py
Normal file
108
metagpt/strategy/base.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/25/2023 9:16 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
from typing import List
|
||||
|
||||
from anytree import Node, RenderTree
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseParser(BaseModel):
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def propose(self, current_state: str, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def sample(self, current_state: str, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def value(self, input: str, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseEvaluator(BaseModel):
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def status_verify(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ThoughtNode(Node):
|
||||
"""A node representing a thought in the thought tree."""
|
||||
|
||||
name: str = ""
|
||||
value: int = 0
|
||||
id: int = 0
|
||||
valid_status: bool = True
|
||||
|
||||
def update_value(self, value) -> None:
|
||||
"""Update the value of the thought node."""
|
||||
self.value = value
|
||||
|
||||
def update_valid_status(self, status) -> None:
|
||||
"""Update the validity status of the thought node."""
|
||||
self.valid_status = status
|
||||
|
||||
|
||||
class ThoughtTree(RenderTree):
|
||||
"""A tree structure to represent thoughts."""
|
||||
|
||||
@property
|
||||
def all_nodes(self) -> List[ThoughtNode]:
|
||||
"""
|
||||
Get a list of all nodes in the thought tree.
|
||||
|
||||
Returns:
|
||||
List[ThoughtNode]: A list containing all nodes in the thought tree.
|
||||
"""
|
||||
all_nodes = [node for _, _, node in self]
|
||||
return all_nodes
|
||||
|
||||
def update_node(self, thought: List[dict] = [], current_node: ThoughtNode = None) -> List[ThoughtNode]:
|
||||
"""
|
||||
Update the tree with new thoughts.
|
||||
|
||||
Args:
|
||||
thought (List[dict]): A list of dictionaries representing thought information.
|
||||
current_node (ThoughtNode): The current node under which new thoughts will be added.
|
||||
|
||||
Returns:
|
||||
List[ThoughtNode]: A list of ThoughtNode instances representing the updated tree nodes.
|
||||
"""
|
||||
nodes = []
|
||||
for node_info in thought:
|
||||
node = ThoughtNode(
|
||||
name=node_info["node_state_instruction"], parent=current_node, id=int(node_info["node_id"])
|
||||
)
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
|
||||
def parse_node_path(self, node) -> List[str]:
|
||||
"""
|
||||
Parse and retrieve the hierarchical path of the given thought node.
|
||||
|
||||
This method traverses the parent nodes of the provided 'node' and constructs
|
||||
the full path from the root node to the given node.
|
||||
|
||||
Args:
|
||||
node: The thought node for which the hierarchical path needs to be parsed.
|
||||
|
||||
Returns:
|
||||
List[str]: A list representing the full hierarchical path of the given thought node.
|
||||
The list is ordered from the root node to the provided node.
|
||||
"""
|
||||
full_node_path = []
|
||||
while node is not None:
|
||||
full_node_path.append(node.name)
|
||||
node = node.parent
|
||||
full_node_path.reverse()
|
||||
return full_node_path
|
||||
|
||||
def show(self) -> None:
|
||||
"""Print the updated tree."""
|
||||
print("\nUpdated Tree:")
|
||||
for pre, _, node in self:
|
||||
print(f"{pre}{node.name}, value: {node.value}, valid_status: {node.valid_status}")
|
||||
4
metagpt/strategy/examples/__init__.py
Normal file
4
metagpt/strategy/examples/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/26/2023 3:32 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
73
metagpt/strategy/examples/creative_writing.py
Normal file
73
metagpt/strategy/examples/creative_writing.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/25/2023 1:06 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import re
|
||||
|
||||
from metagpt.strategy.prompt_templates.creative_writing import cot_prompt, vote_prompt
|
||||
from metagpt.strategy.tot import TreeofThought
|
||||
from metagpt.strategy.tot_schema import (
|
||||
BaseEvaluator,
|
||||
BaseParser,
|
||||
Strategy,
|
||||
ThoughtSolverConfig,
|
||||
)
|
||||
|
||||
|
||||
class TextGenParser(BaseParser):
|
||||
propose_prompt: str = cot_prompt
|
||||
value_prompt: str = vote_prompt
|
||||
|
||||
def __call__(self, input_text: str) -> str:
|
||||
return input_text
|
||||
|
||||
def propose(self, current_state: str, **kwargs) -> str:
|
||||
return self.propose_prompt.format(input=current_state, **kwargs)
|
||||
|
||||
def value(self, input: str = "", **kwargs) -> str:
|
||||
# node_result = self(input)
|
||||
id = kwargs.get("node_id", "0")
|
||||
return self.value_prompt + f"Choice {id}:\n{input}\n"
|
||||
|
||||
|
||||
class TextGenEvaluator(BaseEvaluator):
|
||||
value_map = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
|
||||
status_map = {val: key for key, val in value_map.items()}
|
||||
|
||||
def __call__(self, evaluation: str, **kwargs) -> float:
|
||||
try:
|
||||
value = 0
|
||||
node_id = kwargs.get("node_id", "0")
|
||||
pattern = r".*best choice is .*(\d+).*"
|
||||
match = re.match(pattern, evaluation, re.DOTALL)
|
||||
|
||||
if match:
|
||||
vote = int(match.groups()[0])
|
||||
print(vote)
|
||||
if vote == int(node_id):
|
||||
value = 1
|
||||
except:
|
||||
value = 0
|
||||
return value
|
||||
|
||||
def status_verify(self, value):
|
||||
status = False
|
||||
if value in self.status_map:
|
||||
status_value = self.status_map[value]
|
||||
if status_value != "impossible":
|
||||
status = True
|
||||
return status
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
initial_prompt = """It isn't difficult to do a handstand if you just stand on your hands. It caught him off guard that space smelled of seared steak. When she didn’t like a guy who was trying to pick her up, she started using sign language. Each person who knows you has a different perception of who you are."""
|
||||
|
||||
parser = TextGenParser()
|
||||
evaluator = TextGenEvaluator()
|
||||
|
||||
config = ThoughtSolverConfig(n_generate_sample=3, parser=parser, evaluator=evaluator)
|
||||
|
||||
tot_base = TreeofThought(strategy=Strategy.BFS, config=config)
|
||||
asyncio.run(tot_base.solve(init_prompt=initial_prompt))
|
||||
64
metagpt/strategy/examples/game24.py
Normal file
64
metagpt/strategy/examples/game24.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/25/2023 1:36 AM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import re
|
||||
|
||||
from metagpt.strategy.prompt_templates.game24 import propose_prompt, value_prompt
|
||||
from metagpt.strategy.tot import TreeofThought
|
||||
from metagpt.strategy.tot_schema import (
|
||||
BaseEvaluator,
|
||||
BaseParser,
|
||||
Strategy,
|
||||
ThoughtSolverConfig,
|
||||
)
|
||||
|
||||
|
||||
class Game24Parser(BaseParser):
|
||||
propose_prompt: str = propose_prompt
|
||||
value_prompt: str = value_prompt
|
||||
|
||||
def __call__(self, input_text: str) -> str:
|
||||
last_line = input_text.strip().split("\n")[-1]
|
||||
return last_line.split("left: ")[-1].split(")")[0]
|
||||
|
||||
def propose(self, current_state: str, **kwargs) -> str:
|
||||
return self.propose_prompt.format(input=current_state, **kwargs)
|
||||
|
||||
def value(self, input: str = "", **kwargs) -> str:
|
||||
node_result = self(input)
|
||||
return self.value_prompt.format(input=node_result)
|
||||
|
||||
|
||||
class Game24Evaluator(BaseEvaluator):
|
||||
value_map = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
|
||||
status_map = {val: key for key, val in value_map.items()}
|
||||
|
||||
def __call__(self, evaluation: str, **kwargs) -> float:
|
||||
try:
|
||||
matches = re.findall(r"\b(impossible|sure|likely)\b", evaluation)
|
||||
value = self.value_map[matches[0]]
|
||||
except:
|
||||
value = 0.001
|
||||
return value
|
||||
|
||||
def status_verify(self, value):
|
||||
status = False
|
||||
if value in self.status_map:
|
||||
status_value = self.status_map[value]
|
||||
if status_value != "impossible":
|
||||
status = True
|
||||
return status
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
initial_prompt = """4 5 6 10"""
|
||||
parser = Game24Parser()
|
||||
evaluator = Game24Evaluator()
|
||||
|
||||
config = ThoughtSolverConfig(n_generate_sample=5, parser=parser, evaluator=evaluator)
|
||||
|
||||
tot = TreeofThought(strategy=Strategy.BFS, config=config)
|
||||
asyncio.run(tot.solve(init_prompt=initial_prompt))
|
||||
4
metagpt/strategy/prompt_templates/__init__.py
Normal file
4
metagpt/strategy/prompt_templates/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/23/2023 5:21 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
25
metagpt/strategy/prompt_templates/creative_writing.py
Normal file
25
metagpt/strategy/prompt_templates/creative_writing.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
standard_prompt = """
|
||||
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
|
||||
"""
|
||||
|
||||
cot_prompt = """
|
||||
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
|
||||
|
||||
Make a plan then write. Your output should be of the following format:
|
||||
|
||||
Plan:
|
||||
Your plan here.
|
||||
|
||||
Passage:
|
||||
Your passage here.
|
||||
"""
|
||||
|
||||
|
||||
vote_prompt = """Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The best choice is {s}", where s the integer id of the choice.
|
||||
"""
|
||||
|
||||
compare_prompt = """Briefly analyze the coherency of the following two passages. Conclude in the last line "The more coherent passage is 1", "The more coherent passage is 2", or "The two passages are similarly coherent".
|
||||
"""
|
||||
|
||||
score_prompt = """Analyze the following passage, then at the last line conclude "Thus the coherency score is {s}", where s is an integer from 1 to 10.
|
||||
"""
|
||||
139
metagpt/strategy/prompt_templates/game24.py
Normal file
139
metagpt/strategy/prompt_templates/game24.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
# 5-shot
|
||||
standard_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24.
|
||||
Input: 4 4 6 8
|
||||
Answer: (4 + 8) * (6 - 4) = 24
|
||||
Input: 2 9 10 12
|
||||
Answer: 2 * 12 * (10 - 9) = 24
|
||||
Input: 4 9 10 13
|
||||
Answer: (13 - 9) * (10 - 4) = 24
|
||||
Input: 1 4 8 8
|
||||
Answer: (8 / 4 + 1) * 8 = 24
|
||||
Input: 5 5 5 9
|
||||
Answer: 5 + 5 + 5 + 9 = 24
|
||||
Input: {input}
|
||||
"""
|
||||
|
||||
# 5-shot
|
||||
cot_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
|
||||
Input: 4 4 6 8
|
||||
Steps:
|
||||
4 + 8 = 12 (left: 4 6 12)
|
||||
6 - 4 = 2 (left: 2 12)
|
||||
2 * 12 = 24 (left: 24)
|
||||
Answer: (6 - 4) * (4 + 8) = 24
|
||||
Input: 2 9 10 12
|
||||
Steps:
|
||||
12 * 2 = 24 (left: 9 10 24)
|
||||
10 - 9 = 1 (left: 1 24)
|
||||
24 * 1 = 24 (left: 24)
|
||||
Answer: (12 * 2) * (10 - 9) = 24
|
||||
Input: 4 9 10 13
|
||||
Steps:
|
||||
13 - 10 = 3 (left: 3 4 9)
|
||||
9 - 3 = 6 (left: 4 6)
|
||||
4 * 6 = 24 (left: 24)
|
||||
Answer: 4 * (9 - (13 - 10)) = 24
|
||||
Input: 1 4 8 8
|
||||
Steps:
|
||||
8 / 4 = 2 (left: 1 2 8)
|
||||
1 + 2 = 3 (left: 3 8)
|
||||
3 * 8 = 24 (left: 24)
|
||||
Answer: (1 + 8 / 4) * 8 = 24
|
||||
Input: 5 5 5 9
|
||||
Steps:
|
||||
5 + 5 = 10 (left: 5 9 10)
|
||||
10 + 5 = 15 (left: 9 15)
|
||||
15 + 9 = 24 (left: 24)
|
||||
Answer: ((5 + 5) + 5) + 9 = 24
|
||||
Input: {input}
|
||||
"""
|
||||
|
||||
# 1-shot
|
||||
propose_prompt = """Here is an Example for 1 input and 8 possible thoughts:
|
||||
Input: 2 8 8 14
|
||||
Possible next steps:
|
||||
2 + 8 = 10 (left: 8 10 14)
|
||||
8 / 2 = 4 (left: 4 8 14)
|
||||
14 + 2 = 16 (left: 8 8 16)
|
||||
2 * 8 = 16 (left: 8 14 16)
|
||||
8 - 2 = 6 (left: 6 8 14)
|
||||
14 - 8 = 6 (left: 2 6 8)
|
||||
14 / 2 = 7 (left: 7 8 8)
|
||||
14 - 2 = 12 (left: 8 8 12)
|
||||
|
||||
Here is my task for 1 input and {n_generate_sample} possible thoughts:
|
||||
Input: {input}
|
||||
Possible next steps:
|
||||
|
||||
|
||||
"""
|
||||
|
||||
value_prompt = """Evaluate if given numbers can reach 24 (sure/likely/impossible)
|
||||
10 14
|
||||
10 + 14 = 24
|
||||
sure
|
||||
11 12
|
||||
11 + 12 = 23
|
||||
12 - 11 = 1
|
||||
11 * 12 = 132
|
||||
11 / 12 = 0.91
|
||||
impossible
|
||||
4 4 10
|
||||
4 + 4 + 10 = 8 + 10 = 18
|
||||
4 * 10 - 4 = 40 - 4 = 36
|
||||
(10 - 4) * 4 = 6 * 4 = 24
|
||||
sure
|
||||
4 9 11
|
||||
9 + 11 + 4 = 20 + 4 = 24
|
||||
sure
|
||||
5 7 8
|
||||
5 + 7 + 8 = 12 + 8 = 20
|
||||
(8 - 5) * 7 = 3 * 7 = 21
|
||||
I cannot obtain 24 now, but numbers are within a reasonable range
|
||||
likely
|
||||
5 6 6
|
||||
5 + 6 + 6 = 17
|
||||
(6 - 5) * 6 = 1 * 6 = 6
|
||||
I cannot obtain 24 now, but numbers are within a reasonable range
|
||||
likely
|
||||
10 10 11
|
||||
10 + 10 + 11 = 31
|
||||
(11 - 10) * 10 = 10
|
||||
10 10 10 are all too big
|
||||
impossible
|
||||
1 3 3
|
||||
1 * 3 * 3 = 9
|
||||
(1 + 3) * 3 = 12
|
||||
1 3 3 are all too small
|
||||
impossible
|
||||
{input}
|
||||
"""
|
||||
|
||||
value_last_step_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reach 24.
|
||||
Input: 4 4 6 8
|
||||
Answer: (4 + 8) * (6 - 4) = 24
|
||||
Judge:
|
||||
sure
|
||||
Input: 2 9 10 12
|
||||
Answer: 2 * 12 * (10 - 9) = 24
|
||||
Judge:
|
||||
sure
|
||||
Input: 4 9 10 13
|
||||
Answer: (13 - 9) * (10 - 4) = 24
|
||||
Judge:
|
||||
sure
|
||||
Input: 4 4 6 8
|
||||
Answer: (4 + 8) * (6 - 4) + 1 = 25
|
||||
Judge:
|
||||
impossible
|
||||
Input: 2 9 10 12
|
||||
Answer: 2 * (12 - 10) = 24
|
||||
Judge:
|
||||
impossible
|
||||
Input: 4 9 10 13
|
||||
Answer: (13 - 4) * (10 - 9) = 24
|
||||
Judge:
|
||||
impossible
|
||||
Input: {input}
|
||||
Answer: {answer}
|
||||
Judge:"""
|
||||
272
metagpt/strategy/tot.py
Normal file
272
metagpt/strategy/tot.py
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/23/2023 4:51 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import asyncio
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.strategy.base import ThoughtNode, ThoughtTree
|
||||
from metagpt.strategy.tot_schema import MethodSelect, Strategy, ThoughtSolverConfig
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
OUTPUT_FORMAT = """
|
||||
Output a list of jsons following the format:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"node_id": str = "unique identifier for a solution, can be an ordinal",
|
||||
"node_state_instruction": "specified sample of solution",
|
||||
},
|
||||
...
|
||||
]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class ThoughtSolverBase(BaseModel):
|
||||
thought_tree: str = ""
|
||||
llm: BaseLLM = Field(default_factory=LLM, exclude=True)
|
||||
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self.llm.use_system_prompt = False
|
||||
|
||||
async def solve(self, init_prompt):
|
||||
"""
|
||||
Solve method for subclasses to implement.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement the solve method")
|
||||
|
||||
async def generate_thoughts(self, current_state="", current_node=None) -> List[ThoughtNode]:
|
||||
"""
|
||||
Generate children thoughts based on the current state.
|
||||
|
||||
Args:
|
||||
current_state (str): The current state for which thoughts are generated.
|
||||
current_node (ThoughtNode): The current node in the thought tree.
|
||||
|
||||
Returns:
|
||||
List[ThoughtNode]: List of nodes representing the generated thoughts.
|
||||
"""
|
||||
state_prompt = self.config.parser.propose(
|
||||
current_state=current_state, **{"n_generate_sample": self.config.n_generate_sample}
|
||||
)
|
||||
rsp = await self.llm.aask(msg=state_prompt + "\n" + OUTPUT_FORMAT)
|
||||
thoughts = CodeParser.parse_code(block=None, text=rsp)
|
||||
thoughts = eval(thoughts)
|
||||
# fixme 避免不跟随,生成过多nodes
|
||||
# valid_thoughts = [_node for idx, _node in enumerate(thoughts) if idx < self.n_generate_sample]
|
||||
return self.thought_tree.update_node(thoughts, current_node=current_node)
|
||||
|
||||
async def evaluate_node(self, node, parent_value) -> None:
|
||||
"""
|
||||
Evaluate a node and update its status and value.
|
||||
|
||||
Args:
|
||||
node (ThoughtNode): The node to be evaluated.
|
||||
parent_value (float): The parent node's value.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
eval_prompt = self.config.parser.value(input=node.name, **{"node_id": node.id})
|
||||
evaluation = await self.llm.aask(msg=eval_prompt)
|
||||
|
||||
value = self.config.evaluator(evaluation, **{"node_id": node.id})
|
||||
status = self.config.evaluator.status_verify(value)
|
||||
|
||||
node.update_valid_status(status=status)
|
||||
# 累计分数
|
||||
node.update_value(parent_value + value)
|
||||
|
||||
def select_nodes(self, thought_nodes: List[ThoughtNode]) -> List[ThoughtNode]:
|
||||
"""
|
||||
Select nodes based on the configured selection method.
|
||||
|
||||
Args:
|
||||
thought_nodes (List[ThoughtNode]): List of nodes to be selected.
|
||||
|
||||
Returns:
|
||||
List[ThoughtNode]: List of selected nodes.
|
||||
"""
|
||||
# selection
|
||||
if self.config.method_select == MethodSelect.SAMPLE:
|
||||
raise NotImplementedError
|
||||
elif self.config.method_select == MethodSelect.GREEDY:
|
||||
select_nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample]
|
||||
for node in thought_nodes:
|
||||
if node not in select_nodes:
|
||||
node.parent = None # 从树中删除节点
|
||||
return select_nodes
|
||||
|
||||
def update_solution(self):
|
||||
"""
|
||||
Select the result with the highest score.
|
||||
|
||||
Returns:
|
||||
- List[ThoughtNode]: List of nodes representing the best solution.
|
||||
- List[str]: List of node names forming the best solution path.
|
||||
"""
|
||||
best_node = max(self.thought_tree.all_nodes, key=lambda x: x.value, default=None)
|
||||
best_solution_path = self.thought_tree.parse_node_path(best_node)
|
||||
return [best_node], best_solution_path
|
||||
|
||||
|
||||
class BFSSolver(ThoughtSolverBase):
|
||||
async def solve(self, init_prompt=""):
|
||||
"""
|
||||
Solve the problem using Breadth-First Search (BFS) strategy.
|
||||
|
||||
Args:
|
||||
init_prompt (str): The initial prompt for the solver.
|
||||
|
||||
Returns:
|
||||
List[str]: The best solution path obtained through BFS.
|
||||
"""
|
||||
root = ThoughtNode(init_prompt)
|
||||
self.thought_tree = ThoughtTree(root)
|
||||
current_nodes = [root]
|
||||
for step in range(self.config.max_steps):
|
||||
solutions = await self._bfs_build(current_nodes)
|
||||
|
||||
selected_nodes = self.select_nodes(solutions)
|
||||
current_nodes = selected_nodes
|
||||
|
||||
self.thought_tree.show()
|
||||
|
||||
best_solution, best_solution_path = self.update_solution()
|
||||
logger.info(f"best solution is: {best_solution_path}")
|
||||
return best_solution_path
|
||||
|
||||
async def _bfs_build(self, current_nodes):
|
||||
"""
|
||||
Build the thought tree using Breadth-First Search (BFS) strategy.
|
||||
|
||||
Args:
|
||||
current_nodes (List[ThoughtNode]): Current nodes to expand.
|
||||
|
||||
Returns:
|
||||
List[ThoughtNode]: The solutions obtained after expanding the current nodes.
|
||||
"""
|
||||
tasks = []
|
||||
for node in current_nodes:
|
||||
current_state = self.config.parser(node.name)
|
||||
current_value = node.value
|
||||
tasks.append(self.generate_and_evaluate_nodes(current_state, current_value, node))
|
||||
|
||||
thought_nodes_list = await asyncio.gather(*tasks)
|
||||
solutions = [child_node for thought_nodes in thought_nodes_list for child_node in thought_nodes]
|
||||
return solutions
|
||||
|
||||
async def generate_and_evaluate_nodes(self, current_state, current_value, node):
|
||||
thought_nodes = await self.generate_thoughts(current_state, current_node=node)
|
||||
await asyncio.gather(
|
||||
*(self.evaluate_node(child_node, parent_value=current_value) for child_node in thought_nodes)
|
||||
)
|
||||
return thought_nodes
|
||||
|
||||
|
||||
class DFSSolver(ThoughtSolverBase):
|
||||
async def _dfs(self, root_node):
|
||||
"""
|
||||
Perform Depth-First Search (DFS) on the thought tree.
|
||||
|
||||
Args:
|
||||
root_node (ThoughtNode): The root node of the thought tree.
|
||||
|
||||
Returns:
|
||||
List[str]: The solution path obtained through DFS.
|
||||
"""
|
||||
impossible_state_cnt = 0
|
||||
node = root_node
|
||||
for step in range(self.max_steps):
|
||||
current_state = self.config.parser(node.name)
|
||||
current_value = node.value
|
||||
thought_nodes = await self.generate_thoughts(current_state, current_node=node)
|
||||
await self.evaluate_node(thought_nodes[0], parent_value=current_value)
|
||||
if thought_nodes[0].valid_status is False:
|
||||
impossible_state_cnt += 1
|
||||
if impossible_state_cnt >= 2:
|
||||
logger.info("impossible state reached, break")
|
||||
break
|
||||
node = thought_nodes[0]
|
||||
_solution_path = self.thought_tree.parse_node_path(node)
|
||||
self.thought_tree.show()
|
||||
|
||||
return _solution_path
|
||||
|
||||
async def solve(self, init_prompt="", root=ThoughtNode("")):
|
||||
"""
|
||||
Solve the problem using Depth-First Search (DFS) strategy.
|
||||
|
||||
Args:
|
||||
init_prompt (str): The initial prompt for the solver.
|
||||
|
||||
Returns:
|
||||
List[str]: The best solution path obtained through DFS.
|
||||
"""
|
||||
root = ThoughtNode(init_prompt)
|
||||
self.thought_tree = ThoughtTree(root)
|
||||
for n in range(self.config.n_solution_sample):
|
||||
# fixme: 需要产生回退,当前节点不可用时回退到父节点,产生新的节点继续探索
|
||||
await self._dfs(root)
|
||||
|
||||
best_solution, best_solution_path = self.update_solution()
|
||||
logger.info(f"best solution is: {best_solution_path}")
|
||||
return best_solution_path
|
||||
|
||||
|
||||
class MCTSSolver(ThoughtSolverBase):
|
||||
async def solve(self, init_prompt=""):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TreeofThought(BaseModel):
|
||||
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
|
||||
solver: ThoughtSolverBase = Field(default_factory=ThoughtSolverBase)
|
||||
strategy: Strategy = Field(default=Strategy.BFS)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self._initialize_solver(self.strategy)
|
||||
|
||||
def _initialize_solver(self, strategy):
|
||||
"""
|
||||
Initialize the solver based on the chosen strategy.
|
||||
|
||||
Args:
|
||||
strategy (Strategy): The strategy to use for solving.
|
||||
|
||||
Returns:
|
||||
ThoughtSolverBase: An instance of the appropriate solver.
|
||||
"""
|
||||
if strategy == Strategy.BFS:
|
||||
self.solver = BFSSolver(config=self.config)
|
||||
elif strategy == Strategy.DFS:
|
||||
self.solver = DFSSolver(config=self.config)
|
||||
elif strategy == Strategy.MCTS:
|
||||
self.solver = MCTSSolver(config=self.config)
|
||||
else:
|
||||
raise NotImplementedError(f"Invalid strategy: {strategy}, only support BFS/DFS/MCTS currently!")
|
||||
|
||||
async def solve(self, init_prompt=""):
|
||||
"""
|
||||
Solve the problem using the specified strategy.
|
||||
|
||||
Args:
|
||||
init_prompt (str): The initial prompt for the solver.
|
||||
strategy (str): The strategy to use for solving.
|
||||
|
||||
Returns:
|
||||
Any: The solution obtained using the selected strategy.
|
||||
"""
|
||||
await self.solver.solve(init_prompt)
|
||||
30
metagpt/strategy/tot_schema.py
Normal file
30
metagpt/strategy/tot_schema.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/25/2023 9:14 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.strategy.base import BaseEvaluator, BaseParser
|
||||
|
||||
|
||||
class MethodSelect(Enum):
|
||||
SAMPLE = "sample"
|
||||
GREEDY = "greedy"
|
||||
|
||||
|
||||
class Strategy(Enum):
|
||||
BFS = "BFS"
|
||||
DFS = "DFS"
|
||||
MCTS = "MCTS"
|
||||
|
||||
|
||||
class ThoughtSolverConfig(BaseModel):
|
||||
max_steps: int = 3
|
||||
method_select: str = MethodSelect.GREEDY # ["sample"/"greedy"]
|
||||
n_generate_sample: int = 5 # per node
|
||||
n_select_sample: int = 3 # per path
|
||||
n_solution_sample: int = 5 # only for dfs
|
||||
parser: BaseParser = Field(default_factory=BaseParser)
|
||||
evaluator: BaseEvaluator = Field(default_factory=BaseEvaluator)
|
||||
|
|
@ -95,4 +95,4 @@ class SearchEngine:
|
|||
Returns:
|
||||
The search results as a string or a list of dictionaries.
|
||||
"""
|
||||
return await self.run_func(query, max_results=max_results, as_string=as_string)
|
||||
return await self.run_func(query, max_results, as_string)
|
||||
|
|
|
|||
|
|
@ -43,7 +43,8 @@ class SerpAPIWrapper(BaseModel):
|
|||
|
||||
async def run(self, query, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through SerpAPI and parse result async."""
|
||||
return self._process_response(await self.results(query, max_results), as_string=as_string)
|
||||
result = await self.results(query, max_results)
|
||||
return self._process_response(result, as_string=as_string)
|
||||
|
||||
async def results(self, query: str, max_results: int) -> dict:
|
||||
"""Use aiohttp to run query through SerpAPI and return the results async."""
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from typing import Literal
|
|||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.support.wait import WebDriverWait
|
||||
from webdriver_manager.core.download_manager import WDMDownloadManager
|
||||
from webdriver_manager.core.http import WDMHttpClient
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
|
@ -93,6 +95,13 @@ _webdriver_manager_types = {
|
|||
}
|
||||
|
||||
|
||||
class WDMHttpProxyClient(WDMHttpClient):
|
||||
def get(self, url, **kwargs):
|
||||
if "proxies" not in kwargs and CONFIG.global_proxy:
|
||||
kwargs["proxies"] = {"all_proxy": CONFIG.global_proxy}
|
||||
return super().get(url, **kwargs)
|
||||
|
||||
|
||||
def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
||||
WebDriver = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.webdriver"), "WebDriver")
|
||||
Service = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.service"), "Service")
|
||||
|
|
@ -101,7 +110,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
|||
if not executable_path:
|
||||
module_name, type_name = _webdriver_manager_types[browser_type]
|
||||
DriverManager = getattr(importlib.import_module(module_name), type_name)
|
||||
driver_manager = DriverManager()
|
||||
driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient()))
|
||||
# driver_manager.driver_cache.find_driver(driver_manager.driver))
|
||||
executable_path = driver_manager.install()
|
||||
|
||||
|
|
|
|||
|
|
@ -131,13 +131,11 @@ class OutputParser:
|
|||
try:
|
||||
content = cls.parse_code(text=content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 尝试解析list
|
||||
try:
|
||||
content = cls.parse_file_list(text=content)
|
||||
except Exception:
|
||||
pass
|
||||
# 尝试解析list
|
||||
try:
|
||||
content = cls.parse_file_list(text=content)
|
||||
except Exception:
|
||||
pass
|
||||
parsed_data[block] = content
|
||||
return parsed_data
|
||||
|
||||
|
|
|
|||
|
|
@ -63,5 +63,5 @@ class Redis:
|
|||
self._client = None
|
||||
|
||||
@property
|
||||
def is_valid(self):
|
||||
return bool(self._client)
|
||||
def is_valid(self) -> bool:
|
||||
return self._client is not None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue