mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-03 21:02:38 +02:00
Merge branch 'geekan/dev' into feature/rebuild
This commit is contained in:
commit
5f88e12a7d
62 changed files with 1109 additions and 565 deletions
49
metagpt/actions/action_graph.py
Normal file
49
metagpt/actions/action_graph.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/30 13:52
|
||||
@Author : alexanderwu
|
||||
@File : action_graph.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
# from metagpt.actions.action_node import ActionNode
|
||||
|
||||
|
||||
class ActionGraph:
    """A directed graph representing dependencies between actions.

    Nodes are ActionNode-like objects indexed by their ``key`` attribute;
    ``edges`` maps a node key to the list of successor keys.
    """

    def __init__(self):
        self.nodes = {}  # key -> node object
        self.edges = {}  # key -> list of successor keys
        self.execution_order = []  # keys in topological order; filled by topological_sort()

    def add_node(self, node):
        """Add a node to the graph, indexed by ``node.key``."""
        self.nodes[node.key] = node

    def add_edge(self, from_node: "ActionNode", to_node: "ActionNode"):
        """Add an edge ``from_node -> to_node`` and cross-link the two nodes."""
        # setdefault replaces the explicit membership check before appending.
        self.edges.setdefault(from_node.key, []).append(to_node.key)
        from_node.add_next(to_node)
        to_node.add_prev(from_node)

    def topological_sort(self):
        """Topologically sort the graph into ``self.execution_order``.

        Depth-first post-order traversal: each key is appended once all of its
        successors are placed, and the list is reversed at the end. This is
        O(V + E) overall, replacing the original ``list.insert(0, ...)`` which
        cost O(n) per insertion (O(n^2) total) while producing the same order.
        NOTE(review): assumes the graph is acyclic — a cycle recurses without
        bound (RecursionError), same as the original.
        """
        visited = set()
        order = []

        def visit(key):
            if key in visited:
                return
            visited.add(key)
            for successor in self.edges.get(key, ()):
                visit(successor)
            order.append(key)  # post-order: all successors already emitted

        for key in self.nodes:
            visit(key)

        order.reverse()
        self.execution_order = order
|
||||
|
|
@ -9,6 +9,7 @@ NOTE: You should use typing.List instead of list to do type annotation. Because
|
|||
we can use typing to extract the type of the node, but we cannot use built-in list to extract.
|
||||
"""
|
||||
import json
|
||||
import typing
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
|
|
@ -39,7 +40,6 @@ TAG = "CONTENT"
|
|||
LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT."
|
||||
FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else."
|
||||
|
||||
|
||||
SIMPLE_TEMPLATE = """
|
||||
## context
|
||||
{context}
|
||||
|
|
@ -131,6 +131,8 @@ class ActionNode:
|
|||
|
||||
# Action Input
|
||||
key: str # Product Requirement / File list / Code
|
||||
func: typing.Callable # 与节点相关联的函数或LLM调用
|
||||
params: Dict[str, Type] # 输入参数的字典,键为参数名,值为参数类型
|
||||
expected_type: Type # such as str / int / float etc.
|
||||
# context: str # everything in the history.
|
||||
instruction: str # the instructions should be followed.
|
||||
|
|
@ -140,6 +142,10 @@ class ActionNode:
|
|||
content: str
|
||||
instruct_content: BaseModel
|
||||
|
||||
# For ActionGraph
|
||||
prevs: List["ActionNode"] # previous nodes
|
||||
nexts: List["ActionNode"] # next nodes
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
|
|
@ -157,6 +163,8 @@ class ActionNode:
|
|||
self.content = content
|
||||
self.children = children if children is not None else {}
|
||||
self.schema = schema
|
||||
self.prevs = []
|
||||
self.nexts = []
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
|
|
@ -167,6 +175,14 @@ class ActionNode:
|
|||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def add_prev(self, node: "ActionNode"):
|
||||
"""增加前置ActionNode"""
|
||||
self.prevs.append(node)
|
||||
|
||||
def add_next(self, node: "ActionNode"):
|
||||
"""增加后置ActionNode"""
|
||||
self.nexts.append(node)
|
||||
|
||||
def add_child(self, node: "ActionNode"):
|
||||
"""增加子ActionNode"""
|
||||
self.children[node.key] = node
|
||||
|
|
@ -186,41 +202,38 @@ class ActionNode:
|
|||
obj.add_children(nodes)
|
||||
return obj
|
||||
|
||||
def get_children_mapping_old(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""获得子ActionNode的字典,以key索引"""
|
||||
def _get_children_mapping(self, exclude=None) -> Dict[str, Any]:
|
||||
"""获得子ActionNode的字典,以key索引,支持多级结构。"""
|
||||
exclude = exclude or []
|
||||
return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude}
|
||||
|
||||
def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""获得子ActionNode的字典,以key索引,支持多级结构"""
|
||||
exclude = exclude or []
|
||||
mapping = {}
|
||||
|
||||
def _get_mapping(node: "ActionNode", prefix: str = ""):
|
||||
def _get_mapping(node: "ActionNode") -> Dict[str, Any]:
|
||||
mapping = {}
|
||||
for key, child in node.children.items():
|
||||
if key in exclude:
|
||||
continue
|
||||
full_key = f"{prefix}{key}"
|
||||
mapping[full_key] = (child.expected_type, ...)
|
||||
_get_mapping(child, prefix=f"{full_key}.")
|
||||
# 对于嵌套的子节点,递归调用 _get_mapping
|
||||
if child.children:
|
||||
mapping[key] = _get_mapping(child)
|
||||
else:
|
||||
mapping[key] = (child.expected_type, Field(default=child.example, description=child.instruction))
|
||||
return mapping
|
||||
|
||||
_get_mapping(self)
|
||||
return mapping
|
||||
return _get_mapping(self)
|
||||
|
||||
def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
|
||||
def _get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""get self key: type mapping"""
|
||||
return {self.key: (self.expected_type, ...)}
|
||||
|
||||
def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""get key: type mapping under mode"""
|
||||
if mode == "children" or (mode == "auto" and self.children):
|
||||
return self.get_children_mapping(exclude=exclude)
|
||||
return {} if exclude and self.key in exclude else self.get_self_mapping()
|
||||
return self._get_children_mapping(exclude=exclude)
|
||||
return {} if exclude and self.key in exclude else self._get_self_mapping()
|
||||
|
||||
@classmethod
|
||||
@register_action_outcls
|
||||
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
|
||||
"""基于pydantic v1的模型动态生成,用来检验结果类型正确性"""
|
||||
"""基于pydantic v2的模型动态生成,用来检验结果类型正确性"""
|
||||
|
||||
def check_fields(cls, values):
|
||||
required_fields = set(mapping.keys())
|
||||
|
|
@ -235,7 +248,17 @@ class ActionNode:
|
|||
|
||||
validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)}
|
||||
|
||||
new_class = create_model(class_name, __validators__=validators, **mapping)
|
||||
new_fields = {}
|
||||
for field_name, field_value in mapping.items():
|
||||
if isinstance(field_value, dict):
|
||||
# 对于嵌套结构,递归创建模型类
|
||||
nested_class_name = f"{class_name}_{field_name}"
|
||||
nested_class = cls.create_model_class(nested_class_name, field_value)
|
||||
new_fields[field_name] = (nested_class, ...)
|
||||
else:
|
||||
new_fields[field_name] = field_value
|
||||
|
||||
new_class = create_model(class_name, __validators__=validators, **new_fields)
|
||||
return new_class
|
||||
|
||||
def create_class(self, mode: str = "auto", class_name: str = None, exclude=None):
|
||||
|
|
@ -243,39 +266,48 @@ class ActionNode:
|
|||
mapping = self.get_mapping(mode=mode, exclude=exclude)
|
||||
return self.create_model_class(class_name, mapping)
|
||||
|
||||
def create_children_class(self, exclude=None):
|
||||
def _create_children_class(self, exclude=None):
|
||||
"""使用object内有的字段直接生成model_class"""
|
||||
class_name = f"{self.key}_AN"
|
||||
mapping = self.get_children_mapping(exclude=exclude)
|
||||
mapping = self._get_children_mapping(exclude=exclude)
|
||||
return self.create_model_class(class_name, mapping)
|
||||
|
||||
def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict:
|
||||
"""将当前节点与子节点都按照node: format的格式组织成字典"""
|
||||
nodes = self._to_dict(format_func=format_func, mode=mode, exclude=exclude)
|
||||
if not isinstance(nodes, dict):
|
||||
nodes = {self.key: nodes}
|
||||
return nodes
|
||||
|
||||
# 如果没有提供格式化函数,使用默认的格式化方式
|
||||
def _to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict:
|
||||
"""将当前节点与子节点都按照node: format的格式组织成字典"""
|
||||
|
||||
# 如果没有提供格式化函数,则使用默认的格式化函数
|
||||
if format_func is None:
|
||||
format_func = lambda node: f"{node.instruction}"
|
||||
format_func = lambda node: node.instruction
|
||||
|
||||
# 使用提供的格式化函数来格式化当前节点的值
|
||||
formatted_value = format_func(self)
|
||||
|
||||
# 创建当前节点的键值对
|
||||
if mode == "children" or (mode == "auto" and self.children):
|
||||
node_dict = {}
|
||||
if (mode == "children" or mode == "auto") and self.children:
|
||||
node_value = {}
|
||||
else:
|
||||
node_dict = {self.key: formatted_value}
|
||||
node_value = formatted_value
|
||||
|
||||
if mode == "root":
|
||||
return node_dict
|
||||
return {self.key: node_value}
|
||||
|
||||
# 遍历子节点并递归调用 to_dict 方法
|
||||
# 递归处理子节点
|
||||
exclude = exclude or []
|
||||
for _, child_node in self.children.items():
|
||||
if child_node.key in exclude:
|
||||
for child_key, child_node in self.children.items():
|
||||
if child_key in exclude:
|
||||
continue
|
||||
node_dict.update(child_node.to_dict(format_func))
|
||||
# 递归调用 to_dict 方法并更新节点字典
|
||||
child_dict = child_node._to_dict(format_func, mode, exclude)
|
||||
node_value[child_key] = child_dict
|
||||
|
||||
return node_dict
|
||||
return node_value
|
||||
|
||||
def update_instruct_content(self, incre_data: dict[str, Any]):
|
||||
assert self.instruct_content
|
||||
|
|
@ -344,6 +376,17 @@ class ActionNode:
|
|||
if schema == "raw":
|
||||
return context + "\n\n## Actions\n" + LANGUAGE_CONSTRAINT + "\n" + self.instruction
|
||||
|
||||
### 直接使用 pydantic BaseModel 生成 instruction 与 example,仅限 JSON
|
||||
# child_class = self._create_children_class()
|
||||
# node_schema = child_class.model_json_schema()
|
||||
# defaults = {
|
||||
# k: str(v)
|
||||
# for k, v in child_class.model_fields.items()
|
||||
# if k not in exclude
|
||||
# }
|
||||
# instruction = node_schema
|
||||
# example = json.dumps(defaults, indent=4)
|
||||
|
||||
# FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线",
|
||||
# compile example暂时不支持markdown
|
||||
instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude)
|
||||
|
|
@ -454,7 +497,7 @@ class ActionNode:
|
|||
continue
|
||||
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
|
||||
tmp.update(child.instruct_content.model_dump())
|
||||
cls = self.create_children_class()
|
||||
cls = self._create_children_class()
|
||||
self.instruct_content = cls(**tmp)
|
||||
return self
|
||||
|
||||
|
|
@ -645,49 +688,19 @@ class ActionNode:
|
|||
ActionNode: The root node of the created ActionNode tree.
|
||||
"""
|
||||
key = key or model.__name__
|
||||
root_node = cls(key=model.__name__, expected_type=Type[model], instruction="", example="")
|
||||
root_node = cls(key=key, expected_type=Type[model], instruction="", example="")
|
||||
|
||||
for field_name, field_model in model.model_fields.items():
|
||||
# Extracting field details
|
||||
expected_type = field_model.annotation
|
||||
instruction = field_model.description or ""
|
||||
example = field_model.default
|
||||
for field_name, field_info in model.model_fields.items():
|
||||
field_type = field_info.annotation
|
||||
description = field_info.description
|
||||
default = field_info.default
|
||||
|
||||
# Check if the field is a Pydantic model itself.
|
||||
# Use isinstance to avoid typing.List, typing.Dict, etc. (they are instances of type, not subclasses)
|
||||
if isinstance(expected_type, type) and issubclass(expected_type, BaseModel):
|
||||
# Recursively process the nested model
|
||||
child_node = cls.from_pydantic(expected_type, key=field_name)
|
||||
# Recursively handle nested models if needed
|
||||
if not isinstance(field_type, typing._GenericAlias) and issubclass(field_type, BaseModel):
|
||||
child_node = cls.from_pydantic(field_type, key=field_name)
|
||||
else:
|
||||
child_node = cls(key=field_name, expected_type=expected_type, instruction=instruction, example=example)
|
||||
child_node = cls(key=field_name, expected_type=field_type, instruction=description, example=default)
|
||||
|
||||
root_node.add_child(child_node)
|
||||
|
||||
return root_node
|
||||
|
||||
|
||||
class ToolUse(BaseModel):
|
||||
tool_name: str = Field(default="a", description="tool name", examples=[])
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
task_id: int = Field(default="1", description="task id", examples=[1, 2, 3])
|
||||
name: str = Field(default="Get data from ...", description="task name", examples=[])
|
||||
dependent_task_ids: List[int] = Field(default=[], description="dependent task ids", examples=[1, 2, 3])
|
||||
tool: ToolUse = Field(default=ToolUse(), description="tool use", examples=[])
|
||||
|
||||
|
||||
class Tasks(BaseModel):
|
||||
tasks: List[Task] = Field(default=[], description="tasks", examples=[])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
node = ActionNode.from_pydantic(Tasks)
|
||||
print("Tasks")
|
||||
print(Tasks.model_json_schema())
|
||||
print("Task")
|
||||
print(Task.model_json_schema())
|
||||
print(node)
|
||||
prompt = node.compile(context="")
|
||||
node.create_children_class()
|
||||
print(prompt)
|
||||
|
|
|
|||
|
|
@ -117,4 +117,4 @@ class WriteDesign(Action):
|
|||
|
||||
async def _save_mermaid_file(self, data: str, pathname: Path):
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
await mermaid_to_file(self.config.mermaid_engine, data, pathname)
|
||||
await mermaid_to_file(self.config.mermaid.engine, data, pathname)
|
||||
|
|
|
|||
|
|
@ -42,8 +42,8 @@ Determine the ONE file to rewrite in order to fix the error, for example, xyz.py
|
|||
Determine if all of the code works fine, if so write PASS, else FAIL,
|
||||
WRITE ONLY ONE WORD, PASS OR FAIL, IN THIS SECTION
|
||||
## Send To:
|
||||
Please write Engineer if the errors are due to problematic development codes, and QaEngineer to problematic test codes, and NoOne if there are no errors,
|
||||
WRITE ONLY ONE WORD, Engineer OR QaEngineer OR NoOne, IN THIS SECTION.
|
||||
Please write NoOne if there are no errors, Engineer if the errors are due to problematic development codes, else QaEngineer,
|
||||
WRITE ONLY ONE WORD, NoOne OR Engineer OR QaEngineer, IN THIS SECTION.
|
||||
---
|
||||
You should fill in necessary instruction, status, send to, and finally return all content between the --- segment line.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ Options:
|
|||
Default: 'google'
|
||||
|
||||
Example:
|
||||
python3 -m metagpt.actions.write_docstring ./metagpt/startup.py --overwrite False --style=numpy
|
||||
python3 -m metagpt.actions.write_docstring ./metagpt/software_company.py --overwrite False --style=numpy
|
||||
|
||||
This script uses the 'fire' library to create a command-line interface. It generates docstrings for the given Python code using
|
||||
the specified docstring style and adds them to the code.
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ class WritePRD(Action):
|
|||
return
|
||||
pathname = self.repo.workdir / COMPETITIVE_ANALYSIS_FILE_REPO / Path(prd_doc.filename).stem
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
await mermaid_to_file(self.config.mermaid_engine, quadrant_chart, pathname)
|
||||
await mermaid_to_file(self.config.mermaid.engine, quadrant_chart, pathname)
|
||||
|
||||
async def _rename_workspace(self, prd):
|
||||
if not self.project_name:
|
||||
|
|
|
|||
|
|
@ -67,24 +67,18 @@ class Config(CLIParams, YamlModel):
|
|||
code_review_k_times: int = 2
|
||||
|
||||
# Will be removed in the future
|
||||
llm_for_researcher_summary: str = "gpt3"
|
||||
llm_for_researcher_report: str = "gpt3"
|
||||
METAGPT_TEXT_TO_IMAGE_MODEL_URL: str = ""
|
||||
metagpt_tti_url: str = ""
|
||||
language: str = "English"
|
||||
redis_key: str = "placeholder"
|
||||
mmdc: str = "mmdc"
|
||||
puppeteer_config: str = ""
|
||||
pyppeteer_executable_path: str = ""
|
||||
IFLYTEK_APP_ID: str = ""
|
||||
IFLYTEK_API_SECRET: str = ""
|
||||
IFLYTEK_API_KEY: str = ""
|
||||
AZURE_TTS_SUBSCRIPTION_KEY: str = ""
|
||||
AZURE_TTS_REGION: str = ""
|
||||
mermaid_engine: str = "nodejs"
|
||||
iflytek_app_id: str = ""
|
||||
iflytek_api_secret: str = ""
|
||||
iflytek_api_key: str = ""
|
||||
azure_tts_subscription_key: str = ""
|
||||
azure_tts_region: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_home(cls, path):
|
||||
"""Load config from ~/.metagpt/config.yaml"""
|
||||
"""Load config from ~/.metagpt/config2.yaml"""
|
||||
pathname = CONFIG_ROOT / path
|
||||
if not pathname.exists():
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -74,5 +74,5 @@ class LLMConfig(YamlModel):
|
|||
@classmethod
|
||||
def check_llm_key(cls, v):
|
||||
if v in ["", None, "YOUR_API_KEY"]:
|
||||
raise ValueError("Please set your API key in config.yaml")
|
||||
raise ValueError("Please set your API key in config2.yaml")
|
||||
return v
|
||||
|
|
|
|||
|
|
@ -14,5 +14,6 @@ class MermaidConfig(YamlModel):
|
|||
"""Config for Mermaid"""
|
||||
|
||||
engine: Literal["nodejs", "ink", "playwright", "pyppeteer"] = "nodejs"
|
||||
path: str = ""
|
||||
puppeteer_config: str = "" # Only for nodejs engine
|
||||
path: str = "mmdc" # mmdc
|
||||
puppeteer_config: str = ""
|
||||
pyppeteer_path: str = "/usr/bin/google-chrome-stable"
|
||||
|
|
|
|||
|
|
@ -8,5 +8,5 @@ async def google_search(query: str, max_results: int = 6, **kwargs):
|
|||
:param max_results: The number of search results to retrieve
|
||||
:return: The web search results in markdown format.
|
||||
"""
|
||||
results = await SearchEngine().run(query, max_results=max_results, as_string=False)
|
||||
results = await SearchEngine(**kwargs).run(query, max_results=max_results, as_string=False)
|
||||
return "\n".join(f"{i}. [{j['title']}]({j['link']}): {j['snippet']}" for i, j in enumerate(results, 1))
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ async def text_to_image(text, size_type: str = "512x512", config: Config = metag
|
|||
"""
|
||||
image_declaration = "data:image/png;base64,"
|
||||
|
||||
model_url = config.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
model_url = config.metagpt_tti_url
|
||||
if model_url:
|
||||
binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url)
|
||||
elif config.get_openai_llm():
|
||||
|
|
|
|||
|
|
@ -39,8 +39,8 @@ async def text_to_speech(
|
|||
|
||||
"""
|
||||
|
||||
subscription_key = config.AZURE_TTS_SUBSCRIPTION_KEY
|
||||
region = config.AZURE_TTS_REGION
|
||||
subscription_key = config.azure_tts_subscription_key
|
||||
region = config.azure_tts_region
|
||||
if subscription_key and region:
|
||||
audio_declaration = "data:audio/wav;base64,"
|
||||
base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
|
||||
|
|
@ -50,9 +50,9 @@ async def text_to_speech(
|
|||
return f"[{text}]({url})"
|
||||
return audio_declaration + base64_data if base64_data else base64_data
|
||||
|
||||
iflytek_app_id = config.IFLYTEK_APP_ID
|
||||
iflytek_api_key = config.IFLYTEK_API_KEY
|
||||
iflytek_api_secret = config.IFLYTEK_API_SECRET
|
||||
iflytek_app_id = config.iflytek_app_id
|
||||
iflytek_api_key = config.iflytek_api_key
|
||||
iflytek_api_secret = config.iflytek_api_secret
|
||||
if iflytek_app_id and iflytek_api_key and iflytek_api_secret:
|
||||
audio_declaration = "data:audio/mp3;base64,"
|
||||
base64_data = await oas3_iflytek_tts(
|
||||
|
|
@ -65,5 +65,5 @@ async def text_to_speech(
|
|||
return audio_declaration + base64_data if base64_data else base64_data
|
||||
|
||||
raise ValueError(
|
||||
"AZURE_TTS_SUBSCRIPTION_KEY, AZURE_TTS_REGION, IFLYTEK_APP_ID, IFLYTEK_API_KEY, IFLYTEK_API_SECRET error"
|
||||
"azure_tts_subscription_key, azure_tts_region, iflytek_app_id, iflytek_api_key, iflytek_api_secret error"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ class BaseLLM(ABC):
|
|||
|
||||
def get_choice_delta_text(self, rsp: dict) -> str:
|
||||
"""Required to provide the first text of stream choice"""
|
||||
return rsp.get("choices")[0]["delta"]["content"]
|
||||
return rsp.get("choices", [{}])[0].get("delta", {}).get("content", "")
|
||||
|
||||
def get_choice_function(self, rsp: dict) -> dict:
|
||||
"""Required to provide the first function of choice
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
|
@ -9,24 +10,25 @@ import typer
|
|||
from metagpt.config2 import config
|
||||
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
|
||||
from metagpt.context import Context
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
|
||||
|
||||
|
||||
def generate_repo(
|
||||
idea,
|
||||
investment,
|
||||
n_round,
|
||||
code_review,
|
||||
run_tests,
|
||||
implement,
|
||||
project_name,
|
||||
inc,
|
||||
project_path,
|
||||
reqa_file,
|
||||
max_auto_summarize_code,
|
||||
recover_path,
|
||||
):
|
||||
investment=3.0,
|
||||
n_round=5,
|
||||
code_review=True,
|
||||
run_tests=False,
|
||||
implement=True,
|
||||
project_name="",
|
||||
inc=False,
|
||||
project_path="",
|
||||
reqa_file="",
|
||||
max_auto_summarize_code=0,
|
||||
recover_path=None,
|
||||
) -> ProjectRepo:
|
||||
"""Run the startup logic. Can be called from CLI or other Python scripts."""
|
||||
from metagpt.roles import (
|
||||
Architect,
|
||||
|
|
@ -67,6 +69,8 @@ def generate_repo(
|
|||
company.run_project(idea)
|
||||
asyncio.run(company.run(n_round=n_round))
|
||||
|
||||
return ctx.repo
|
||||
|
||||
|
||||
@app.command("", help="Start a new project.")
|
||||
def startup(
|
||||
20
metagpt/strategy/search_space.py
Normal file
20
metagpt/strategy/search_space.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/30 17:15
|
||||
@Author : alexanderwu
|
||||
@File : search_space.py
|
||||
"""
|
||||
|
||||
|
||||
class SearchSpace:
    """A search space whose nodes are ActionNode objects, indexed by key."""

    def __init__(self):
        # Mapping of node key -> node.
        self.search_space = {}

    def add_node(self, node):
        """Insert *node* into the search space under ``node.key``."""
        self.search_space[node.key] = node

    def get_node(self, key):
        """Return the node stored under *key*; raises KeyError if absent."""
        return self.search_space[key]
|
||||
77
metagpt/strategy/solver.py
Normal file
77
metagpt/strategy/solver.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/30 17:13
|
||||
@Author : alexanderwu
|
||||
@File : solver.py
|
||||
"""
|
||||
from abc import ABC, abstractmethod

from metagpt.actions.action_graph import ActionGraph
from metagpt.provider.base_llm import BaseLLM
from metagpt.strategy.search_space import SearchSpace
|
||||
|
||||
|
||||
class BaseSolver(ABC):
    """Abstract solver interface: executes an ActionGraph over a search space.

    Fix: the original decorated ``solve`` with ``@abstractmethod`` but did not
    inherit from ``abc.ABC``, so the decorator had no effect and the abstract
    base could be instantiated directly. Inheriting ``ABC`` enforces that
    subclasses implement ``solve``; concrete subclasses are unaffected.
    """

    def __init__(self, graph: ActionGraph, search_space: SearchSpace, llm: BaseLLM, context):
        """
        :param graph: ActionGraph of actions to execute
        :param search_space: SearchSpace of candidate nodes
        :param llm: BaseLLM used to fill nodes
        :param context: Context passed to each node execution
        """
        self.graph = graph
        self.search_space = search_space
        self.llm = llm
        self.context = context

    @abstractmethod
    async def solve(self):
        """Abstract method to solve the problem; must be implemented by subclasses."""
|
||||
|
||||
|
||||
class NaiveSolver(BaseSolver):
    """Naive strategy: run every node of the graph sequentially, in dependency order."""

    async def solve(self):
        """Topologically sort the graph, then fill each node one after another."""
        self.graph.topological_sort()
        for node_key in self.graph.execution_order:
            await self.graph.nodes[node_key].fill(self.context, self.llm, mode="root")
|
||||
|
||||
|
||||
class TOTSolver(BaseSolver):
    """TOTSolver: Tree of Thought.

    Placeholder — the Tree-of-Thought search strategy is not implemented yet.
    """

    async def solve(self):
        # Deliberately unimplemented; calling solve() raises NotImplementedError.
        raise NotImplementedError
|
||||
|
||||
|
||||
class CodeInterpreterSolver(BaseSolver):
    """CodeInterpreterSolver: Write&Run code in the graph.

    Placeholder — the code-interpreter strategy is not implemented yet.
    """

    async def solve(self):
        # Deliberately unimplemented; calling solve() raises NotImplementedError.
        raise NotImplementedError
|
||||
|
||||
|
||||
class ReActSolver(BaseSolver):
    """ReActSolver: ReAct algorithm.

    Placeholder — the ReAct strategy is not implemented yet.
    """

    async def solve(self):
        # Deliberately unimplemented; calling solve() raises NotImplementedError.
        raise NotImplementedError
|
||||
|
||||
|
||||
class IOSolver(BaseSolver):
    """IOSolver: use LLM directly to solve the problem.

    Placeholder — the direct input/output strategy is not implemented yet.
    """

    async def solve(self):
        # Deliberately unimplemented; calling solve() raises NotImplementedError.
        raise NotImplementedError
|
||||
|
||||
|
||||
class COTSolver(BaseSolver):
    """COTSolver: Chain of Thought.

    Placeholder — the Chain-of-Thought strategy is not implemented yet.
    """

    async def solve(self):
        # Deliberately unimplemented; calling solve() raises NotImplementedError.
        raise NotImplementedError
|
||||
|
|
@ -61,9 +61,11 @@ class SerpAPIWrapper(BaseModel):
|
|||
if not self.aiosession:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, params=params) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
else:
|
||||
async with self.aiosession.get(url, params=params) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
|
||||
return res
|
||||
|
|
|
|||
|
|
@ -55,9 +55,11 @@ class SerperWrapper(BaseModel):
|
|||
if not self.aiosession:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, data=payloads, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
else:
|
||||
async with self.aiosession.get.post(url, data=payloads, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
|
||||
return res
|
||||
|
|
|
|||
|
|
@ -35,10 +35,10 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt
|
|||
# tmp.write_text(mermaid_code, encoding="utf-8")
|
||||
|
||||
if engine == "nodejs":
|
||||
if check_cmd_exists(config.mmdc) != 0:
|
||||
if check_cmd_exists(config.mermaid.path) != 0:
|
||||
logger.warning(
|
||||
"RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc,"
|
||||
"or consider changing MERMAID_ENGINE to `playwright`, `pyppeteer`, or `ink`."
|
||||
"or consider changing engine to `playwright`, `pyppeteer`, or `ink`."
|
||||
)
|
||||
return -1
|
||||
|
||||
|
|
@ -47,11 +47,11 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt
|
|||
# Call the `mmdc` command to convert the Mermaid code to a PNG
|
||||
logger.info(f"Generating {output_file}..")
|
||||
|
||||
if config.puppeteer_config:
|
||||
if config.mermaid.puppeteer_config:
|
||||
commands = [
|
||||
config.mmdc,
|
||||
config.mermaid.path,
|
||||
"-p",
|
||||
config.puppeteer_config,
|
||||
config.mermaid.puppeteer_config,
|
||||
"-i",
|
||||
str(tmp),
|
||||
"-o",
|
||||
|
|
@ -62,7 +62,7 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt
|
|||
str(height),
|
||||
]
|
||||
else:
|
||||
commands = [config.mmdc, "-i", str(tmp), "-o", output_file, "-w", str(width), "-H", str(height)]
|
||||
commands = [config.mermaid.path, "-i", str(tmp), "-o", output_file, "-w", str(width), "-H", str(height)]
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
" ".join(commands), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
|
|
|||
|
|
@ -30,14 +30,14 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
suffixes = ["png", "svg", "pdf"]
|
||||
__dirname = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
if config.pyppeteer_executable_path:
|
||||
if config.mermaid.pyppeteer_path:
|
||||
browser = await launch(
|
||||
headless=True,
|
||||
executablePath=config.pyppeteer_executable_path,
|
||||
executablePath=config.mermaid.pyppeteer_path,
|
||||
args=["--disable-extensions", "--no-sandbox"],
|
||||
)
|
||||
else:
|
||||
logger.error("Please set the environment variable:PYPPETEER_EXECUTABLE_PATH.")
|
||||
logger.error("Please set the var mermaid.pyppeteer_path in the config2.yaml.")
|
||||
return -1
|
||||
page = await browser.newPage()
|
||||
device_scale_factor = 1.0
|
||||
|
|
|
|||
|
|
@ -102,6 +102,13 @@ class ProjectRepo(FileRepository):
|
|||
self.tests = self._git_repo.new_file_repository(relative_path=TEST_CODES_FILE_REPO)
|
||||
self.test_outputs = self._git_repo.new_file_repository(relative_path=TEST_OUTPUTS_FILE_REPO)
|
||||
self._srcs_path = None
|
||||
self.code_files_exists()
|
||||
|
||||
def __str__(self):
|
||||
repo_str = f"ProjectRepo({self._git_repo.workdir})"
|
||||
docs_str = f"Docs({self.docs.all_files})"
|
||||
srcs_str = f"Srcs({self.srcs.all_files})"
|
||||
return f"{repo_str}\n{docs_str}\n{srcs_str}"
|
||||
|
||||
@property
|
||||
async def requirement(self):
|
||||
|
|
|
|||
|
|
@ -119,15 +119,22 @@ def repair_json_format(output: str) -> str:
|
|||
logger.info(f"repair_json_format: {'}]'}")
|
||||
elif output.startswith("{") and output.endswith("]"):
|
||||
output = output[:-1] + "}"
|
||||
|
||||
# remove `#` in output json str, usually appeared in `glm-4`
|
||||
# remove comments in output json string, after json value content, maybe start with #, maybe start with //
|
||||
arr = output.split("\n")
|
||||
new_arr = []
|
||||
for line in arr:
|
||||
idx = line.find("#")
|
||||
if idx >= 0:
|
||||
line = line[:idx]
|
||||
new_arr.append(line)
|
||||
for json_line in arr:
|
||||
# look for # or // comments and make sure they are not inside the string value
|
||||
comment_index = -1
|
||||
for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", json_line):
|
||||
if match.group(1): # if the string value
|
||||
continue
|
||||
if match.group(2): # if comments
|
||||
comment_index = match.start(2)
|
||||
break
|
||||
# if comments, then delete them
|
||||
if comment_index != -1:
|
||||
json_line = json_line[:comment_index].rstrip()
|
||||
new_arr.append(json_line)
|
||||
output = "\n".join(new_arr)
|
||||
return output
|
||||
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class YamlModelWithoutDefault(YamlModel):
|
|||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_not_default_config(cls, values):
|
||||
"""Check if there is any default config in config.yaml"""
|
||||
"""Check if there is any default config in config2.yaml"""
|
||||
if any(["YOUR" in v for v in values]):
|
||||
raise ValueError("Please set your config in config.yaml")
|
||||
raise ValueError("Please set your config in config2.yaml")
|
||||
return values
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue