feat: +software_development_intent_detect

This commit is contained in:
莘权 马 2024-03-29 21:46:46 +08:00
parent b029a1996a
commit a7b4af738a
7 changed files with 165 additions and 485 deletions

View file

@ -4,7 +4,8 @@ import json
import pytest
from metagpt.actions.intent_detect import IntentDetect
from metagpt.actions.intent_detect import IntentDetect, LightIntentDetect
from metagpt.logs import logger
from metagpt.schema import Message
DEMO_CONTENT = [
@ -56,6 +57,19 @@ async def test_intent_detect(content: str, context):
assert action._references
assert action._intent_to_sops
assert action.result
logger.info(action.result.model_dump_json())
@pytest.mark.asyncio
@pytest.mark.parametrize(
"content",
[json.dumps(DEMO1_CONTENT), json.dumps(DEMO_CONTENT)],
)
async def test_light_intent_detect(content: str, context):
action = LightIntentDetect(context=context)
messages = [Message.model_validate(i) for i in json.loads(content)]
rsp = await action.run(messages)
assert isinstance(rsp, Message)
if __name__ == "__main__":

View file

@ -6,11 +6,12 @@ import pytest
from metagpt.actions.intent_detect import IntentDetectResult
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.tools.libs.dialog import intent_detect
from metagpt.tools.libs.dialog import intent_detect, software_development_intent_detect
from tests.metagpt.actions.test_intent_detect import DEMO_CONTENT
@pytest.mark.asyncio
@pytest.mark.skip
async def test_intent_detect():
messages = [Message.model_validate(i) for i in DEMO_CONTENT]
result = await intent_detect(messages)
@ -19,5 +20,14 @@ async def test_intent_detect():
logger.info(f"dialog:{DEMO_CONTENT}\nresult:{result.model_dump_json()}")
@pytest.mark.asyncio
async def test_software_develop_intent_detect():
messages = [Message.model_validate(i) for i in DEMO_CONTENT]
result = await software_development_intent_detect(messages)
assert isinstance(result, list)
assert result
logger.info(f"dialog:{DEMO_CONTENT}\nresult:{result}")
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
import pytest
from metagpt.tools.libs.shell import execute
from metagpt.tools.libs.shell import shell_execute
@pytest.mark.asyncio
@ -14,7 +14,7 @@ from metagpt.tools.libs.shell import execute
],
)
async def test_shell(command, expect_stdout, expect_stderr):
stdout, stderr = await execute(command)
stdout, stderr = await shell_execute(command)
assert expect_stdout in stdout
assert stderr == expect_stderr

View file

@ -62,6 +62,7 @@ async def test_js_parser():
@pytest.mark.asyncio
@pytest.mark.skip
async def test_codes():
path = DEFAULT_WORKSPACE_ROOT / "snake_game"
repo_parser = RepoParser(base_directory=path)
@ -81,5 +82,13 @@ async def test_codes():
print(data)
@pytest.mark.asyncio
async def test_graph_select():
gdb_path = Path(__file__).parent / "../../data/graph_db/networkx.sequence_view.json"
gdb = await DiGraphRepository.load_from(gdb_path)
rows = await gdb.select()
assert rows
if __name__ == "__main__":
pytest.main([__file__, "-s"])