mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-28 02:23:52 +02:00
Merge branch 'dev' of https://github.com/geekan/MetaGPT into geekan/dev
This commit is contained in:
commit
e0191c9b3d
16 changed files with 499 additions and 95 deletions
|
|
@ -1,3 +1,3 @@
|
|||
llm:
|
||||
api_key: "YOUR_API_KEY"
|
||||
model: "gpt-3.5-turbo-1106"
|
||||
model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
|
||||
|
|
@ -2,7 +2,7 @@ llm:
|
|||
api_type: "openai"
|
||||
base_url: "YOUR_BASE_URL"
|
||||
api_key: "YOUR_API_KEY"
|
||||
model: "gpt-3.5-turbo-1106" # or gpt-4-1106-preview
|
||||
model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
|
||||
|
||||
proxy: "YOUR_PROXY"
|
||||
|
||||
|
|
|
|||
49
metagpt/actions/action_graph.py
Normal file
49
metagpt/actions/action_graph.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/30 13:52
|
||||
@Author : alexanderwu
|
||||
@File : action_graph.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
# from metagpt.actions.action_node import ActionNode
|
||||
|
||||
|
||||
class ActionGraph:
|
||||
"""ActionGraph: a directed graph to represent the dependency between actions."""
|
||||
|
||||
def __init__(self):
|
||||
self.nodes = {}
|
||||
self.edges = {}
|
||||
self.execution_order = []
|
||||
|
||||
def add_node(self, node):
|
||||
"""Add a node to the graph"""
|
||||
self.nodes[node.key] = node
|
||||
|
||||
def add_edge(self, from_node: "ActionNode", to_node: "ActionNode"):
|
||||
"""Add an edge to the graph"""
|
||||
if from_node.key not in self.edges:
|
||||
self.edges[from_node.key] = []
|
||||
self.edges[from_node.key].append(to_node.key)
|
||||
from_node.add_next(to_node)
|
||||
to_node.add_prev(from_node)
|
||||
|
||||
def topological_sort(self):
|
||||
"""Topological sort the graph"""
|
||||
visited = set()
|
||||
stack = []
|
||||
|
||||
def visit(k):
|
||||
if k not in visited:
|
||||
visited.add(k)
|
||||
if k in self.edges:
|
||||
for next_node in self.edges[k]:
|
||||
visit(next_node)
|
||||
stack.insert(0, k)
|
||||
|
||||
for key in self.nodes:
|
||||
visit(key)
|
||||
|
||||
self.execution_order = stack
|
||||
|
|
@ -9,6 +9,7 @@ NOTE: You should use typing.List instead of list to do type annotation. Because
|
|||
we can use typing to extract the type of the node, but we cannot use built-in list to extract.
|
||||
"""
|
||||
import json
|
||||
import typing
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
|
|
@ -39,7 +40,6 @@ TAG = "CONTENT"
|
|||
LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT."
|
||||
FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else."
|
||||
|
||||
|
||||
SIMPLE_TEMPLATE = """
|
||||
## context
|
||||
{context}
|
||||
|
|
@ -131,6 +131,8 @@ class ActionNode:
|
|||
|
||||
# Action Input
|
||||
key: str # Product Requirement / File list / Code
|
||||
func: typing.Callable # 与节点相关联的函数或LLM调用
|
||||
params: Dict[str, Type] # 输入参数的字典,键为参数名,值为参数类型
|
||||
expected_type: Type # such as str / int / float etc.
|
||||
# context: str # everything in the history.
|
||||
instruction: str # the instructions should be followed.
|
||||
|
|
@ -140,6 +142,10 @@ class ActionNode:
|
|||
content: str
|
||||
instruct_content: BaseModel
|
||||
|
||||
# For ActionGraph
|
||||
prevs: List["ActionNode"] # previous nodes
|
||||
nexts: List["ActionNode"] # next nodes
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
|
|
@ -157,6 +163,8 @@ class ActionNode:
|
|||
self.content = content
|
||||
self.children = children if children is not None else {}
|
||||
self.schema = schema
|
||||
self.prevs = []
|
||||
self.nexts = []
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
|
|
@ -167,6 +175,14 @@ class ActionNode:
|
|||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def add_prev(self, node: "ActionNode"):
|
||||
"""增加前置ActionNode"""
|
||||
self.prevs.append(node)
|
||||
|
||||
def add_next(self, node: "ActionNode"):
|
||||
"""增加后置ActionNode"""
|
||||
self.nexts.append(node)
|
||||
|
||||
def add_child(self, node: "ActionNode"):
|
||||
"""增加子ActionNode"""
|
||||
self.children[node.key] = node
|
||||
|
|
@ -186,41 +202,38 @@ class ActionNode:
|
|||
obj.add_children(nodes)
|
||||
return obj
|
||||
|
||||
def get_children_mapping_old(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""获得子ActionNode的字典,以key索引"""
|
||||
def _get_children_mapping(self, exclude=None) -> Dict[str, Any]:
|
||||
"""获得子ActionNode的字典,以key索引,支持多级结构。"""
|
||||
exclude = exclude or []
|
||||
return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude}
|
||||
|
||||
def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""获得子ActionNode的字典,以key索引,支持多级结构"""
|
||||
exclude = exclude or []
|
||||
mapping = {}
|
||||
|
||||
def _get_mapping(node: "ActionNode", prefix: str = ""):
|
||||
def _get_mapping(node: "ActionNode") -> Dict[str, Any]:
|
||||
mapping = {}
|
||||
for key, child in node.children.items():
|
||||
if key in exclude:
|
||||
continue
|
||||
full_key = f"{prefix}{key}"
|
||||
mapping[full_key] = (child.expected_type, ...)
|
||||
_get_mapping(child, prefix=f"{full_key}.")
|
||||
# 对于嵌套的子节点,递归调用 _get_mapping
|
||||
if child.children:
|
||||
mapping[key] = _get_mapping(child)
|
||||
else:
|
||||
mapping[key] = (child.expected_type, Field(default=child.example, description=child.instruction))
|
||||
return mapping
|
||||
|
||||
_get_mapping(self)
|
||||
return mapping
|
||||
return _get_mapping(self)
|
||||
|
||||
def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
|
||||
def _get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""get self key: type mapping"""
|
||||
return {self.key: (self.expected_type, ...)}
|
||||
|
||||
def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""get key: type mapping under mode"""
|
||||
if mode == "children" or (mode == "auto" and self.children):
|
||||
return self.get_children_mapping(exclude=exclude)
|
||||
return {} if exclude and self.key in exclude else self.get_self_mapping()
|
||||
return self._get_children_mapping(exclude=exclude)
|
||||
return {} if exclude and self.key in exclude else self._get_self_mapping()
|
||||
|
||||
@classmethod
|
||||
@register_action_outcls
|
||||
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
|
||||
"""基于pydantic v1的模型动态生成,用来检验结果类型正确性"""
|
||||
"""基于pydantic v2的模型动态生成,用来检验结果类型正确性"""
|
||||
|
||||
def check_fields(cls, values):
|
||||
required_fields = set(mapping.keys())
|
||||
|
|
@ -235,7 +248,17 @@ class ActionNode:
|
|||
|
||||
validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)}
|
||||
|
||||
new_class = create_model(class_name, __validators__=validators, **mapping)
|
||||
new_fields = {}
|
||||
for field_name, field_value in mapping.items():
|
||||
if isinstance(field_value, dict):
|
||||
# 对于嵌套结构,递归创建模型类
|
||||
nested_class_name = f"{class_name}_{field_name}"
|
||||
nested_class = cls.create_model_class(nested_class_name, field_value)
|
||||
new_fields[field_name] = (nested_class, ...)
|
||||
else:
|
||||
new_fields[field_name] = field_value
|
||||
|
||||
new_class = create_model(class_name, __validators__=validators, **new_fields)
|
||||
return new_class
|
||||
|
||||
def create_class(self, mode: str = "auto", class_name: str = None, exclude=None):
|
||||
|
|
@ -243,39 +266,48 @@ class ActionNode:
|
|||
mapping = self.get_mapping(mode=mode, exclude=exclude)
|
||||
return self.create_model_class(class_name, mapping)
|
||||
|
||||
def create_children_class(self, exclude=None):
|
||||
def _create_children_class(self, exclude=None):
|
||||
"""使用object内有的字段直接生成model_class"""
|
||||
class_name = f"{self.key}_AN"
|
||||
mapping = self.get_children_mapping(exclude=exclude)
|
||||
mapping = self._get_children_mapping(exclude=exclude)
|
||||
return self.create_model_class(class_name, mapping)
|
||||
|
||||
def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict:
|
||||
"""将当前节点与子节点都按照node: format的格式组织成字典"""
|
||||
nodes = self._to_dict(format_func=format_func, mode=mode, exclude=exclude)
|
||||
if not isinstance(nodes, dict):
|
||||
nodes = {self.key: nodes}
|
||||
return nodes
|
||||
|
||||
# 如果没有提供格式化函数,使用默认的格式化方式
|
||||
def _to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict:
|
||||
"""将当前节点与子节点都按照node: format的格式组织成字典"""
|
||||
|
||||
# 如果没有提供格式化函数,则使用默认的格式化函数
|
||||
if format_func is None:
|
||||
format_func = lambda node: f"{node.instruction}"
|
||||
format_func = lambda node: node.instruction
|
||||
|
||||
# 使用提供的格式化函数来格式化当前节点的值
|
||||
formatted_value = format_func(self)
|
||||
|
||||
# 创建当前节点的键值对
|
||||
if mode == "children" or (mode == "auto" and self.children):
|
||||
node_dict = {}
|
||||
if (mode == "children" or mode == "auto") and self.children:
|
||||
node_value = {}
|
||||
else:
|
||||
node_dict = {self.key: formatted_value}
|
||||
node_value = formatted_value
|
||||
|
||||
if mode == "root":
|
||||
return node_dict
|
||||
return {self.key: node_value}
|
||||
|
||||
# 遍历子节点并递归调用 to_dict 方法
|
||||
# 递归处理子节点
|
||||
exclude = exclude or []
|
||||
for _, child_node in self.children.items():
|
||||
if child_node.key in exclude:
|
||||
for child_key, child_node in self.children.items():
|
||||
if child_key in exclude:
|
||||
continue
|
||||
node_dict.update(child_node.to_dict(format_func))
|
||||
# 递归调用 to_dict 方法并更新节点字典
|
||||
child_dict = child_node._to_dict(format_func, mode, exclude)
|
||||
node_value[child_key] = child_dict
|
||||
|
||||
return node_dict
|
||||
return node_value
|
||||
|
||||
def update_instruct_content(self, incre_data: dict[str, Any]):
|
||||
assert self.instruct_content
|
||||
|
|
@ -344,6 +376,17 @@ class ActionNode:
|
|||
if schema == "raw":
|
||||
return context + "\n\n## Actions\n" + LANGUAGE_CONSTRAINT + "\n" + self.instruction
|
||||
|
||||
### 直接使用 pydantic BaseModel 生成 instruction 与 example,仅限 JSON
|
||||
# child_class = self._create_children_class()
|
||||
# node_schema = child_class.model_json_schema()
|
||||
# defaults = {
|
||||
# k: str(v)
|
||||
# for k, v in child_class.model_fields.items()
|
||||
# if k not in exclude
|
||||
# }
|
||||
# instruction = node_schema
|
||||
# example = json.dumps(defaults, indent=4)
|
||||
|
||||
# FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线",
|
||||
# compile example暂时不支持markdown
|
||||
instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude)
|
||||
|
|
@ -454,7 +497,7 @@ class ActionNode:
|
|||
continue
|
||||
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
|
||||
tmp.update(child.instruct_content.model_dump())
|
||||
cls = self.create_children_class()
|
||||
cls = self._create_children_class()
|
||||
self.instruct_content = cls(**tmp)
|
||||
return self
|
||||
|
||||
|
|
@ -645,49 +688,19 @@ class ActionNode:
|
|||
ActionNode: The root node of the created ActionNode tree.
|
||||
"""
|
||||
key = key or model.__name__
|
||||
root_node = cls(key=model.__name__, expected_type=Type[model], instruction="", example="")
|
||||
root_node = cls(key=key, expected_type=Type[model], instruction="", example="")
|
||||
|
||||
for field_name, field_model in model.model_fields.items():
|
||||
# Extracting field details
|
||||
expected_type = field_model.annotation
|
||||
instruction = field_model.description or ""
|
||||
example = field_model.default
|
||||
for field_name, field_info in model.model_fields.items():
|
||||
field_type = field_info.annotation
|
||||
description = field_info.description
|
||||
default = field_info.default
|
||||
|
||||
# Check if the field is a Pydantic model itself.
|
||||
# Use isinstance to avoid typing.List, typing.Dict, etc. (they are instances of type, not subclasses)
|
||||
if isinstance(expected_type, type) and issubclass(expected_type, BaseModel):
|
||||
# Recursively process the nested model
|
||||
child_node = cls.from_pydantic(expected_type, key=field_name)
|
||||
# Recursively handle nested models if needed
|
||||
if not isinstance(field_type, typing._GenericAlias) and issubclass(field_type, BaseModel):
|
||||
child_node = cls.from_pydantic(field_type, key=field_name)
|
||||
else:
|
||||
child_node = cls(key=field_name, expected_type=expected_type, instruction=instruction, example=example)
|
||||
child_node = cls(key=field_name, expected_type=field_type, instruction=description, example=default)
|
||||
|
||||
root_node.add_child(child_node)
|
||||
|
||||
return root_node
|
||||
|
||||
|
||||
class ToolUse(BaseModel):
|
||||
tool_name: str = Field(default="a", description="tool name", examples=[])
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
task_id: int = Field(default="1", description="task id", examples=[1, 2, 3])
|
||||
name: str = Field(default="Get data from ...", description="task name", examples=[])
|
||||
dependent_task_ids: List[int] = Field(default=[], description="dependent task ids", examples=[1, 2, 3])
|
||||
tool: ToolUse = Field(default=ToolUse(), description="tool use", examples=[])
|
||||
|
||||
|
||||
class Tasks(BaseModel):
|
||||
tasks: List[Task] = Field(default=[], description="tasks", examples=[])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
node = ActionNode.from_pydantic(Tasks)
|
||||
print("Tasks")
|
||||
print(Tasks.model_json_schema())
|
||||
print("Task")
|
||||
print(Task.model_json_schema())
|
||||
print(node)
|
||||
prompt = node.compile(context="")
|
||||
node.create_children_class()
|
||||
print(prompt)
|
||||
|
|
|
|||
|
|
@ -8,5 +8,5 @@ async def google_search(query: str, max_results: int = 6, **kwargs):
|
|||
:param max_results: The number of search results to retrieve
|
||||
:return: The web search results in markdown format.
|
||||
"""
|
||||
results = await SearchEngine().run(query, max_results=max_results, as_string=False)
|
||||
results = await SearchEngine(**kwargs).run(query, max_results=max_results, as_string=False)
|
||||
return "\n".join(f"{i}. [{j['title']}]({j['link']}): {j['snippet']}" for i, j in enumerate(results, 1))
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
|
@ -9,6 +10,7 @@ import typer
|
|||
from metagpt.config2 import config
|
||||
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
|
||||
from metagpt.context import Context
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
|
||||
|
||||
|
|
@ -26,7 +28,7 @@ def generate_repo(
|
|||
reqa_file,
|
||||
max_auto_summarize_code,
|
||||
recover_path,
|
||||
):
|
||||
) -> ProjectRepo:
|
||||
"""Run the startup logic. Can be called from CLI or other Python scripts."""
|
||||
from metagpt.roles import (
|
||||
Architect,
|
||||
|
|
@ -67,6 +69,8 @@ def generate_repo(
|
|||
company.run_project(idea)
|
||||
asyncio.run(company.run(n_round=n_round))
|
||||
|
||||
return ctx.repo
|
||||
|
||||
|
||||
@app.command("", help="Start a new project.")
|
||||
def startup(
|
||||
|
|
|
|||
20
metagpt/strategy/search_space.py
Normal file
20
metagpt/strategy/search_space.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/30 17:15
|
||||
@Author : alexanderwu
|
||||
@File : search_space.py
|
||||
"""
|
||||
|
||||
|
||||
class SearchSpace:
|
||||
"""SearchSpace: 用于定义一个搜索空间,搜索空间中的节点是 ActionNode 类。"""
|
||||
|
||||
def __init__(self):
|
||||
self.search_space = {}
|
||||
|
||||
def add_node(self, node):
|
||||
self.search_space[node.key] = node
|
||||
|
||||
def get_node(self, key):
|
||||
return self.search_space[key]
|
||||
77
metagpt/strategy/solver.py
Normal file
77
metagpt/strategy/solver.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/30 17:13
|
||||
@Author : alexanderwu
|
||||
@File : solver.py
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
|
||||
from metagpt.actions.action_graph import ActionGraph
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.strategy.search_space import SearchSpace
|
||||
|
||||
|
||||
class BaseSolver:
|
||||
"""AbstractSolver: defines the interface of a solver."""
|
||||
|
||||
def __init__(self, graph: ActionGraph, search_space: SearchSpace, llm: BaseLLM, context):
|
||||
"""
|
||||
:param graph: ActionGraph
|
||||
:param search_space: SearchSpace
|
||||
:param llm: BaseLLM
|
||||
:param context: Context
|
||||
"""
|
||||
self.graph = graph
|
||||
self.search_space = search_space
|
||||
self.llm = llm
|
||||
self.context = context
|
||||
|
||||
@abstractmethod
|
||||
async def solve(self):
|
||||
"""abstract method to solve the problem."""
|
||||
|
||||
|
||||
class NaiveSolver(BaseSolver):
|
||||
"""NaiveSolver: Iterate all the nodes in the graph and execute them one by one."""
|
||||
|
||||
async def solve(self):
|
||||
self.graph.topological_sort()
|
||||
for key in self.graph.execution_order:
|
||||
op = self.graph.nodes[key]
|
||||
await op.fill(self.context, self.llm, mode="root")
|
||||
|
||||
|
||||
class TOTSolver(BaseSolver):
|
||||
"""TOTSolver: Tree of Thought"""
|
||||
|
||||
async def solve(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CodeInterpreterSolver(BaseSolver):
|
||||
"""CodeInterpreterSolver: Write&Run code in the graph"""
|
||||
|
||||
async def solve(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ReActSolver(BaseSolver):
|
||||
"""ReActSolver: ReAct algorithm"""
|
||||
|
||||
async def solve(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IOSolver(BaseSolver):
|
||||
"""IOSolver: use LLM directly to solve the problem"""
|
||||
|
||||
async def solve(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class COTSolver(BaseSolver):
|
||||
"""COTSolver: Chain of Thought"""
|
||||
|
||||
async def solve(self):
|
||||
raise NotImplementedError
|
||||
|
|
@ -61,9 +61,11 @@ class SerpAPIWrapper(BaseModel):
|
|||
if not self.aiosession:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, params=params) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
else:
|
||||
async with self.aiosession.get(url, params=params) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
|
||||
return res
|
||||
|
|
|
|||
|
|
@ -55,9 +55,11 @@ class SerperWrapper(BaseModel):
|
|||
if not self.aiosession:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, data=payloads, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
else:
|
||||
async with self.aiosession.get.post(url, data=payloads, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
|
||||
return res
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -8,7 +8,7 @@
|
|||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode, ReviewMode, ReviseMode
|
||||
|
|
@ -241,6 +241,47 @@ def test_create_model_class_with_mapping():
|
|||
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
|
||||
|
||||
|
||||
class ToolDef(BaseModel):
|
||||
tool_name: str = Field(default="a", description="tool name", examples=[])
|
||||
description: str = Field(default="b", description="tool description", examples=[])
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
task_id: int = Field(default=1, description="task id", examples=[1, 2, 3])
|
||||
name: str = Field(default="Get data from ...", description="task name", examples=[])
|
||||
dependent_task_ids: List[int] = Field(default=[], description="dependent task ids", examples=[1, 2, 3])
|
||||
tool: ToolDef = Field(default=ToolDef(), description="tool use", examples=[])
|
||||
|
||||
|
||||
class Tasks(BaseModel):
|
||||
tasks: List[Task] = Field(default=[], description="tasks", examples=[])
|
||||
|
||||
|
||||
def test_action_node_from_pydantic_and_print_everything():
|
||||
node = ActionNode.from_pydantic(Task)
|
||||
print("1. Tasks")
|
||||
print(Task().model_dump_json(indent=4))
|
||||
print(Tasks.model_json_schema())
|
||||
print("2. Task")
|
||||
print(Task.model_json_schema())
|
||||
print("3. ActionNode")
|
||||
print(node)
|
||||
print("4. node.compile prompt")
|
||||
prompt = node.compile(context="")
|
||||
assert "tool_name" in prompt, "tool_name should be in prompt"
|
||||
print(prompt)
|
||||
print("5. node.get_children_mapping")
|
||||
print(node._get_children_mapping())
|
||||
print("6. node.create_children_class")
|
||||
children_class = node._create_children_class()
|
||||
print(children_class)
|
||||
import inspect
|
||||
|
||||
code = inspect.getsource(Tasks)
|
||||
print(code)
|
||||
assert "tasks" in code, "tasks should be in code"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_create_model_class()
|
||||
test_create_model_class_with_mapping()
|
||||
|
|
|
|||
|
|
@ -1,27 +1,21 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.learn.google_search import google_search
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
||||
|
||||
async def mock_google_search():
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_search(search_engine_mocker):
|
||||
class Input(BaseModel):
|
||||
input: str
|
||||
|
||||
inputs = [{"input": "ai agent"}]
|
||||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
result = await google_search(seed.input)
|
||||
result = await google_search(
|
||||
seed.input,
|
||||
engine=SearchEngineType.SERPER_GOOGLE,
|
||||
serper_api_key="mock-serper-key",
|
||||
)
|
||||
assert result != ""
|
||||
|
||||
|
||||
def test_suite():
|
||||
loop = asyncio.get_event_loop()
|
||||
task = loop.create_task(mock_google_search())
|
||||
loop.run_until_complete(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_suite()
|
||||
|
|
|
|||
47
tests/metagpt/strategy/test_solver.py
Normal file
47
tests/metagpt/strategy/test_solver.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/31 13:54
|
||||
@Author : alexanderwu
|
||||
@File : test_solver.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.action_graph import ActionGraph
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.strategy.search_space import SearchSpace
|
||||
from metagpt.strategy.solver import NaiveSolver
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_solver():
|
||||
from metagpt.actions.write_prd_an import (
|
||||
COMPETITIVE_ANALYSIS,
|
||||
ISSUE_TYPE,
|
||||
PRODUCT_GOALS,
|
||||
REQUIREMENT_POOL,
|
||||
)
|
||||
|
||||
graph = ActionGraph()
|
||||
graph.add_node(ISSUE_TYPE)
|
||||
graph.add_node(PRODUCT_GOALS)
|
||||
graph.add_node(COMPETITIVE_ANALYSIS)
|
||||
graph.add_node(REQUIREMENT_POOL)
|
||||
graph.add_edge(ISSUE_TYPE, PRODUCT_GOALS)
|
||||
graph.add_edge(PRODUCT_GOALS, COMPETITIVE_ANALYSIS)
|
||||
graph.add_edge(PRODUCT_GOALS, REQUIREMENT_POOL)
|
||||
graph.add_edge(COMPETITIVE_ANALYSIS, REQUIREMENT_POOL)
|
||||
search_space = SearchSpace()
|
||||
llm = LLM()
|
||||
context = "Create a 2048 game"
|
||||
solver = NaiveSolver(graph, search_space, llm, context)
|
||||
await solver.solve()
|
||||
|
||||
print("## graph.nodes")
|
||||
print(graph.nodes)
|
||||
for k, v in graph.nodes.items():
|
||||
print(f"{v.key} | prevs: {[i.key for i in v.prevs]} | nexts: {[i.key for i in v.nexts]}")
|
||||
|
||||
assert len(graph.nodes) == 4
|
||||
assert len(graph.execution_order) == 4
|
||||
assert graph.execution_order == [ISSUE_TYPE.key, PRODUCT_GOALS.key, COMPETITIVE_ANALYSIS.key, REQUIREMENT_POOL.key]
|
||||
|
|
@ -39,3 +39,7 @@ class MockAioResponse:
|
|||
data = await self.response.json(*args, **kwargs)
|
||||
self.rsp_cache[self.key] = data
|
||||
return data
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.response:
|
||||
self.response.raise_for_status()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue