feat: merge feature/intent_detect

This commit is contained in:
莘权 马 2024-03-30 14:14:01 +08:00
commit de77fccc9a
38 changed files with 1748 additions and 103 deletions

View file

@ -0,0 +1,26 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from pathlib import Path
import pytest
from metagpt.actions.extract_readme import ExtractReadMe
from metagpt.llm import LLM
@pytest.mark.asyncio
async def test_learn_readme(context):
action = ExtractReadMe(
name="RedBean",
i_context=str(Path(__file__).parent.parent.parent.parent),
llm=LLM(),
context=context,
)
await action.run()
rows = await action.graph_db.select()
assert rows
assert context.repo.docs.graph_repo.changed_files
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest
from metagpt.actions.import_repo import ImportRepo
from metagpt.context import Context
from metagpt.utils.common import list_files
@pytest.mark.asyncio
@pytest.mark.parametrize(
"repo_path",
[
"https://github.com/spec-first/connexion.git",
# "https://github.com/geekan/MetaGPT.git"
],
)
@pytest.mark.skip
async def test_import_repo(repo_path):
context = Context()
action = ImportRepo(repo_path=repo_path, context=context)
await action.run()
assert context.repo
prd = list_files(context.repo.docs.prd.workdir)
assert prd
design = list_files(context.repo.docs.system_design.workdir)
assert design
assert prd[0].stem == design[0].stem
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,79 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json
import pytest
from metagpt.actions.intent_detect import IntentDetect, LightIntentDetect
from metagpt.logs import logger
from metagpt.schema import Message
DEMO_CONTENT = [
{
"role": "user",
"content": "Can you build TextToSummarize which is a SMS number that I can text and it will scrape a website "
"and summarize it with ChatGPT",
},
{
"role": "assistant",
"content": "Absolutely, I can build a service like TextToSummarize for you. The process will involve setting up"
" an SMS service that can receive your texts, scraping the website content you send, and then using"
" an AI to summarize the content.\nTo get started, I'll need to set up a number for receiving SMS"
", then I'll work on the script to scrape the website content from the URLs you text, and finally."
" I'll integrate it with an AI service to Generalte the summaries.\n I'll keep you updated on my"
" progress!",
},
{"role": "user", "content": "What exactly do we need the web app for?"},
{
"role": "assistant",
"content": "The web app will service as the interface between the SMS service and the AI summarization service"
". When you send a text with a URL to the SMS number, teh web app will process that, scrape the"
" content from the URL, and then use the AI to summarize it. The summary will then be sent back to"
" you. It's the central hub that ties all the components of the service togather.",
},
]
DEMO1_CONTENT = [
{
"role": "user",
"content": "Extract all of the blog posts from `https://stripe.com/blog/page/1` and return a CSV with the"
" columns `date`, `article_text`, `author` and `summary`. Generate a summary for each article"
" yourself.",
}
]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"content",
[json.dumps(DEMO1_CONTENT), json.dumps(DEMO_CONTENT)],
)
async def test_intent_detect(content: str, context):
action = IntentDetect(context=context)
messages = [Message.model_validate(i) for i in json.loads(content)]
rsp = await action.run(messages)
assert isinstance(rsp, Message)
assert action._dialog_intentions
assert action._references
assert action._intent_to_sops
assert action.result
logger.info(action.result.model_dump_json())
@pytest.mark.asyncio
@pytest.mark.parametrize(
"content",
[json.dumps(DEMO1_CONTENT), json.dumps(DEMO_CONTENT)],
)
async def test_light_intent_detect(content: str, context):
action = LightIntentDetect(context=context)
messages = [Message.model_validate(i) for i in json.loads(content)]
rsp = await action.run(messages)
assert isinstance(rsp, Message)
assert action._dialog_intentions
assert action._intent_to_sops
assert action.result
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -61,7 +61,7 @@ async def test_rebuild(context, mocker):
],
)
def test_get_full_filename(root, pathname, want):
res = RebuildSequenceView._get_full_filename(root=root, pathname=pathname)
res = RebuildSequenceView.get_full_filename(root=root, pathname=pathname)
assert res == want

View file

@ -8,6 +8,7 @@ from pathlib import Path
import pytest
from metagpt.context import Context
from metagpt.logs import logger
from metagpt.roles import Architect, ProductManager, ProjectManager
from metagpt.team import Team
@ -146,5 +147,21 @@ async def test_team_recover_multi_roles_save(mocker, context):
await new_company.run(n_round=4)
@pytest.mark.asyncio
async def test_context(context):
context.kwargs.set("a", "a")
context.cost_manager.max_budget = 9
company = Team(context=context)
save_to = context.repo.workdir / "serial"
company.serialize(save_to)
company.deserialize(save_to, Context())
assert company.env.context.repo
assert company.env.context.repo.workdir == context.repo.workdir
assert company.env.context.kwargs.a == "a"
assert company.env.context.cost_manager.max_budget == context.cost_manager.max_budget
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest
from metagpt.actions.intent_detect import IntentDetectResult
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.tools.libs.dialog import intent_detect, software_development_intent_detect
from tests.metagpt.actions.test_intent_detect import DEMO_CONTENT
@pytest.mark.asyncio
@pytest.mark.skip
async def test_intent_detect():
messages = [Message.model_validate(i) for i in DEMO_CONTENT]
result = await intent_detect(messages)
assert isinstance(result, IntentDetectResult)
assert result
logger.info(f"dialog:{DEMO_CONTENT}\nresult:{result.model_dump_json()}")
@pytest.mark.asyncio
async def test_software_develop_intent_detect():
messages = [Message.model_validate(i) for i in DEMO_CONTENT]
result = await software_development_intent_detect(messages)
assert isinstance(result, IntentDetectResult)
assert result
logger.info(f"dialog:{DEMO_CONTENT}\nresult:{result.model_dump_json()}")
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest
from pydantic import BaseModel
from metagpt.tools.libs.git import git_checkout, git_clone
from metagpt.utils.git_repository import GitRepository
class SWEBenchItem(BaseModel):
base_commit: str
repo: str
@pytest.mark.asyncio
@pytest.mark.parametrize(
["url", "commit_id"], [("https://github.com/sqlfluff/sqlfluff.git", "d19de0ecd16d298f9e3bfb91da122734c40c01e5")]
)
async def test_git(url: str, commit_id: str):
repo_dir = await git_clone(url)
assert repo_dir
await git_checkout(repo_dir, commit_id)
repo = GitRepository(repo_dir, auto_init=False)
repo.delete_repository()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest
from metagpt.tools.libs.shell import shell_execute
@pytest.mark.asyncio
@pytest.mark.parametrize(
["command", "expect_stdout", "expect_stderr"],
[
(["file", f"{__file__}"], "Python script text executable, ASCII text", ""),
(f"file {__file__}", "Python script text executable, ASCII text", ""),
],
)
async def test_shell(command, expect_stdout, expect_stderr):
stdout, stderr = await shell_execute(command)
assert expect_stdout in stdout
assert stderr == expect_stderr
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,65 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest
from metagpt.tools.libs import (
fix_bug,
git_archive,
run_qa_test,
write_codes,
write_design,
write_prd,
write_project_plan,
)
from metagpt.tools.libs.software_development import import_git_repo
@pytest.mark.asyncio
async def test_software_team():
path = await write_prd("snake game")
assert path
path = await write_design(path)
assert path
path = await write_project_plan(path)
assert path
path = await write_codes(path)
assert path
path = await run_qa_test(path)
assert path
issue = """
pygame 2.0.1 (SDL 2.0.14, Python 3.9.17)
Hello from the pygame community. https://www.pygame.org/contribute.html
Traceback (most recent call last):
File "/Users/ix/github/bak/MetaGPT/workspace/snake_game/snake_game/main.py", line 10, in <module>
main()
File "/Users/ix/github/bak/MetaGPT/workspace/snake_game/snake_game/main.py", line 7, in main
game.start_game()
File "/Users/ix/github/bak/MetaGPT/workspace/snake_game/snake_game/game.py", line 81, in start_game
x
NameError: name 'x' is not defined
"""
path = await fix_bug(path, issue)
assert path
new_path = await write_prd("snake game with moving enemy", path)
assert new_path == path
git_log = await git_archive(new_path)
assert git_log
@pytest.mark.asyncio
async def test_import_repo():
url = "https://github.com/spec-first/connexion.git"
path = await import_git_repo(url)
assert path
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -62,6 +62,7 @@ async def test_js_parser():
@pytest.mark.asyncio
@pytest.mark.skip
async def test_codes():
path = DEFAULT_WORKSPACE_ROOT / "snake_game"
repo_parser = RepoParser(base_directory=path)
@ -81,5 +82,13 @@ async def test_codes():
print(data)
@pytest.mark.asyncio
async def test_graph_select():
gdb_path = Path(__file__).parent / "../../data/graph_db/networkx.sequence_view.json"
gdb = await DiGraphRepository.load_from(gdb_path)
rows = await gdb.select()
assert rows
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -10,7 +10,12 @@ from metagpt.utils.repo_to_markdown import repo_to_markdown
@pytest.mark.parametrize(
["repo_path", "output"],
[(Path(__file__).parent.parent, Path(__file__).parent.parent.parent / f"workspace/unittest/{uuid.uuid4().hex}.md")],
[
(
Path(__file__).parent.parent.parent,
Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}.md",
),
],
)
@pytest.mark.asyncio
async def test_repo_to_markdown(repo_path: Path, output: Path):

View file

@ -3,6 +3,7 @@ from typing import Optional, Union
from metagpt.config2 import config
from metagpt.configs.llm_config import LLMType
from metagpt.const import LLM_API_TIMEOUT
from metagpt.logs import logger
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
@ -22,7 +23,7 @@ class MockLLM(OriginalLLM):
self.rsp_cache: dict = {}
self.rsp_candidates: list[dict] = [] # a test can have multiple calls with the same llm, thus a list
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=LLM_API_TIMEOUT) -> str:
"""Overwrite original acompletion_text to cancel retry"""
if stream:
resp = await self._achat_completion_stream(messages, timeout=timeout)
@ -37,7 +38,7 @@ class MockLLM(OriginalLLM):
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=3,
timeout=LLM_API_TIMEOUT,
stream=True,
) -> str:
if system_msgs:
@ -56,7 +57,7 @@ class MockLLM(OriginalLLM):
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
return rsp
async def original_aask_batch(self, msgs: list, timeout=3) -> str:
async def original_aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str:
"""A copy of metagpt.provider.base_llm.BaseLLM.aask_batch, we can't use super().aask because it will be mocked"""
context = []
for msg in msgs:
@ -83,7 +84,7 @@ class MockLLM(OriginalLLM):
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=3,
timeout=LLM_API_TIMEOUT,
stream=True,
) -> str:
# used to identify it a message has been called before
@ -98,7 +99,7 @@ class MockLLM(OriginalLLM):
rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, images, timeout, stream)
return rsp
async def aask_batch(self, msgs: list, timeout=3) -> str:
async def aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str:
msg_key = "#MSG_SEP#".join([msg if isinstance(msg, str) else msg.content for msg in msgs])
rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
return rsp