Merge branch 'dev' into code_intepreter

This commit is contained in:
yzlin 2024-01-31 18:14:22 +08:00
commit f8d69ed01b
15 changed files with 494 additions and 94 deletions

View file

@ -8,7 +8,7 @@
from typing import List, Tuple
import pytest
from pydantic import ValidationError
from pydantic import BaseModel, Field, ValidationError
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode, ReviewMode, ReviseMode
@ -241,6 +241,47 @@ def test_create_model_class_with_mapping():
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
class ToolDef(BaseModel):
tool_name: str = Field(default="a", description="tool name", examples=[])
description: str = Field(default="b", description="tool description", examples=[])
class Task(BaseModel):
task_id: int = Field(default=1, description="task id", examples=[1, 2, 3])
name: str = Field(default="Get data from ...", description="task name", examples=[])
dependent_task_ids: List[int] = Field(default=[], description="dependent task ids", examples=[1, 2, 3])
tool: ToolDef = Field(default=ToolDef(), description="tool use", examples=[])
class Tasks(BaseModel):
tasks: List[Task] = Field(default=[], description="tasks", examples=[])
def test_action_node_from_pydantic_and_print_everything():
node = ActionNode.from_pydantic(Task)
print("1. Tasks")
print(Task().model_dump_json(indent=4))
print(Tasks.model_json_schema())
print("2. Task")
print(Task.model_json_schema())
print("3. ActionNode")
print(node)
print("4. node.compile prompt")
prompt = node.compile(context="")
assert "tool_name" in prompt, "tool_name should be in prompt"
print(prompt)
print("5. node.get_children_mapping")
print(node._get_children_mapping())
print("6. node.create_children_class")
children_class = node._create_children_class()
print(children_class)
import inspect
code = inspect.getsource(Tasks)
print(code)
assert "tasks" in code, "tasks should be in code"
if __name__ == "__main__":
test_create_model_class()
test_create_model_class_with_mapping()

View file

@ -1,27 +1,21 @@
import asyncio
import pytest
from pydantic import BaseModel
from metagpt.learn.google_search import google_search
from metagpt.tools import SearchEngineType
async def mock_google_search():
@pytest.mark.asyncio
async def test_google_search(search_engine_mocker):
class Input(BaseModel):
input: str
inputs = [{"input": "ai agent"}]
for i in inputs:
seed = Input(**i)
result = await google_search(seed.input)
result = await google_search(
seed.input,
engine=SearchEngineType.SERPER_GOOGLE,
serper_api_key="mock-serper-key",
)
assert result != ""
def test_suite():
loop = asyncio.get_event_loop()
task = loop.create_task(mock_google_search())
loop.run_until_complete(task)
if __name__ == "__main__":
test_suite()

View file

@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/31 13:54
@Author : alexanderwu
@File : test_solver.py
"""
import pytest
from metagpt.actions.action_graph import ActionGraph
from metagpt.llm import LLM
from metagpt.strategy.search_space import SearchSpace
from metagpt.strategy.solver import NaiveSolver
@pytest.mark.asyncio
async def test_solver():
from metagpt.actions.write_prd_an import (
COMPETITIVE_ANALYSIS,
ISSUE_TYPE,
PRODUCT_GOALS,
REQUIREMENT_POOL,
)
graph = ActionGraph()
graph.add_node(ISSUE_TYPE)
graph.add_node(PRODUCT_GOALS)
graph.add_node(COMPETITIVE_ANALYSIS)
graph.add_node(REQUIREMENT_POOL)
graph.add_edge(ISSUE_TYPE, PRODUCT_GOALS)
graph.add_edge(PRODUCT_GOALS, COMPETITIVE_ANALYSIS)
graph.add_edge(PRODUCT_GOALS, REQUIREMENT_POOL)
graph.add_edge(COMPETITIVE_ANALYSIS, REQUIREMENT_POOL)
search_space = SearchSpace()
llm = LLM()
context = "Create a 2048 game"
solver = NaiveSolver(graph, search_space, llm, context)
await solver.solve()
print("## graph.nodes")
print(graph.nodes)
for k, v in graph.nodes.items():
print(f"{v.key} | prevs: {[i.key for i in v.prevs]} | nexts: {[i.key for i in v.nexts]}")
assert len(graph.nodes) == 4
assert len(graph.execution_order) == 4
assert graph.execution_order == [ISSUE_TYPE.key, PRODUCT_GOALS.key, COMPETITIVE_ANALYSIS.key, REQUIREMENT_POOL.key]