Merge branch 'dev' into code_intepreter

This commit is contained in:
yzlin 2024-01-31 18:14:22 +08:00
commit f8d69ed01b
15 changed files with 494 additions and 94 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -8,7 +8,7 @@
from typing import List, Tuple
import pytest
from pydantic import ValidationError
from pydantic import BaseModel, Field, ValidationError
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode, ReviewMode, ReviseMode
@ -241,6 +241,47 @@ def test_create_model_class_with_mapping():
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
class ToolDef(BaseModel):
tool_name: str = Field(default="a", description="tool name", examples=[])
description: str = Field(default="b", description="tool description", examples=[])
class Task(BaseModel):
task_id: int = Field(default=1, description="task id", examples=[1, 2, 3])
name: str = Field(default="Get data from ...", description="task name", examples=[])
dependent_task_ids: List[int] = Field(default=[], description="dependent task ids", examples=[1, 2, 3])
tool: ToolDef = Field(default=ToolDef(), description="tool use", examples=[])
class Tasks(BaseModel):
tasks: List[Task] = Field(default=[], description="tasks", examples=[])
def test_action_node_from_pydantic_and_print_everything():
node = ActionNode.from_pydantic(Task)
print("1. Tasks")
print(Task().model_dump_json(indent=4))
print(Tasks.model_json_schema())
print("2. Task")
print(Task.model_json_schema())
print("3. ActionNode")
print(node)
print("4. node.compile prompt")
prompt = node.compile(context="")
assert "tool_name" in prompt, "tool_name should be in prompt"
print(prompt)
print("5. node.get_children_mapping")
print(node._get_children_mapping())
print("6. node.create_children_class")
children_class = node._create_children_class()
print(children_class)
import inspect
code = inspect.getsource(Tasks)
print(code)
assert "tasks" in code, "tasks should be in code"
if __name__ == "__main__":
test_create_model_class()
test_create_model_class_with_mapping()

View file

@ -1,27 +1,21 @@
import asyncio
import pytest
from pydantic import BaseModel
from metagpt.learn.google_search import google_search
from metagpt.tools import SearchEngineType
async def mock_google_search():
@pytest.mark.asyncio
async def test_google_search(search_engine_mocker):
class Input(BaseModel):
input: str
inputs = [{"input": "ai agent"}]
for i in inputs:
seed = Input(**i)
result = await google_search(seed.input)
result = await google_search(
seed.input,
engine=SearchEngineType.SERPER_GOOGLE,
serper_api_key="mock-serper-key",
)
assert result != ""
def test_suite():
loop = asyncio.get_event_loop()
task = loop.create_task(mock_google_search())
loop.run_until_complete(task)
if __name__ == "__main__":
test_suite()

View file

@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/31 13:54
@Author : alexanderwu
@File : test_solver.py
"""
import pytest
from metagpt.actions.action_graph import ActionGraph
from metagpt.llm import LLM
from metagpt.strategy.search_space import SearchSpace
from metagpt.strategy.solver import NaiveSolver
@pytest.mark.asyncio
async def test_solver():
from metagpt.actions.write_prd_an import (
COMPETITIVE_ANALYSIS,
ISSUE_TYPE,
PRODUCT_GOALS,
REQUIREMENT_POOL,
)
graph = ActionGraph()
graph.add_node(ISSUE_TYPE)
graph.add_node(PRODUCT_GOALS)
graph.add_node(COMPETITIVE_ANALYSIS)
graph.add_node(REQUIREMENT_POOL)
graph.add_edge(ISSUE_TYPE, PRODUCT_GOALS)
graph.add_edge(PRODUCT_GOALS, COMPETITIVE_ANALYSIS)
graph.add_edge(PRODUCT_GOALS, REQUIREMENT_POOL)
graph.add_edge(COMPETITIVE_ANALYSIS, REQUIREMENT_POOL)
search_space = SearchSpace()
llm = LLM()
context = "Create a 2048 game"
solver = NaiveSolver(graph, search_space, llm, context)
await solver.solve()
print("## graph.nodes")
print(graph.nodes)
for k, v in graph.nodes.items():
print(f"{v.key} | prevs: {[i.key for i in v.prevs]} | nexts: {[i.key for i in v.nexts]}")
assert len(graph.nodes) == 4
assert len(graph.execution_order) == 4
assert graph.execution_order == [ISSUE_TYPE.key, PRODUCT_GOALS.key, COMPETITIVE_ANALYSIS.key, REQUIREMENT_POOL.key]

View file

@ -39,3 +39,7 @@ class MockAioResponse:
data = await self.response.json(*args, **kwargs)
self.rsp_cache[self.key] = data
return data
def raise_for_status(self):
if self.response:
self.response.raise_for_status()