feat: +software_development_intent_detect

This commit is contained in:
莘权 马 2024-03-29 21:46:46 +08:00
parent b029a1996a
commit a7b4af738a
7 changed files with 165 additions and 485 deletions

View file

@ -157,15 +157,7 @@ class IntentDetect(Action):
intentions = await self._get_intentions(msg_markdown)
await self._get_references(msg_markdown, intentions)
await self._get_sops()
self.result = IntentDetectResult(clarifications=self._dialog_intentions.clarifications)
sops = {i.description: i for i in SOP_CONFIG}
intent_to_sops = {i.intent: i.sop for i in self._intent_to_sops if i.sop != ""}
for i in self._references.intentions:
item = IntentDetectIntentionSOP(intention=i)
key = intent_to_sops.get(i.intent)
if key:
item.sop = sops.get(key)
self.result.intentions.append(item)
await self._merge()
return Message(
content=self.result.model_dump_json(), role="assistant", cause_by=self, instruct_content=self.result
@ -251,6 +243,45 @@ class IntentDetect(Action):
vv = json.loads(json_blocks[0])
self._intent_to_sops = [self.IntentSOP.model_validate(i) for i in vv]
async def _merge(self):
self.result = IntentDetectResult(clarifications=self._dialog_intentions.clarifications)
sops = {i.description: i for i in SOP_CONFIG}
intent_to_sops = {i.intent: i.sop for i in self._intent_to_sops if i.sop != ""}
distinct = {}
for i in self._intent_to_sops:
if i.sop_index == 0: # 1-based index
ref = self._get_intent_ref(i.intent)
item = IntentDetectIntentionSOP(intention=ref)
self.result.intentions.append(item)
continue
distinct[i.sop_index] = [i.intent] + distinct.get(i.sop_index, [])
merge_intents = {}
for sop_index, intents in distinct.items():
if len(intents) > 1:
merge_intents[sop_index] = intents
continue
intent_ref = self._get_intent_ref(intents[0])
item = IntentDetectIntentionSOP(intention=intent_ref)
key = intent_to_sops.get(intents[0])
if key:
item.sop = sops.get(key)
self.result.intentions.append(item)
for sop_index, intents in merge_intents.items():
intent_ref = IntentDetectIntentionRef(intent="\n".join(intents), refs=[])
for i in intents:
ref = self._get_intent_ref(i)
intent_ref.refs.extend(ref.refs)
item = IntentDetectIntentionSOP(intention=intent_ref)
item.sop = SOP_CONFIG[sop_index - 1] # 1-based index
self.result.intentions.append(item)
def _get_intent_ref(self, intent: str):
mappings = {i.intent: i for i in self._references.intentions}
return mappings[intent]
@staticmethod
def _message_to_markdown(messages) -> str:
markdown = ""
@ -258,3 +289,50 @@ class IntentDetect(Action):
content = i.content.replace("\n", " ")
markdown += f"> {i.role}: {content}\n>\n"
return markdown
class LightIntentDetect(IntentDetect):
result: List[SOPItem] = None
async def run(self, with_messages: List[Message] = None, **kwargs) -> Message:
"""
Runs the intention detection action.
Args:
with_messages (List[Message]): List of messages representing the conversation content.
**kwargs: Additional keyword arguments.
"""
msg_markdown = self._message_to_markdown(with_messages)
await self._get_intentions(msg_markdown)
await self._get_sops()
distinct = {i.sop_index - 1: SOP_CONFIG[i.sop_index - 1] for i in self._intent_to_sops if i.sop_index > 0}
self.result = list(distinct.values())
return Message(content="", role="assistant", cause_by=self)
async def _get_sops(self):
intention_list = ""
for i, v in enumerate(self._dialog_intentions.intentions):
intention_list += f"{i + 1}. intent: {v.intent}\n - ref: {v.ref}\n"
sop_list = ""
for i, v in enumerate(SOP_CONFIG):
sop_list += f"{i + 1}. {v.description}\n"
prompt = f"## Intentions\n{intention_list}\n---\n## SOPs\n{sop_list}\n"
rsp = await self.llm.aask(
prompt,
system_msgs=[
"You are a tool that matches user intentions with Standard Operating Procedures (SOPs).",
'You search for matching SOPs under "SOPs" based on user intentions in "Intentions" and their related original descriptions.',
'Inspect each intention in "Intentions".',
"Return a markdown JSON list of objects, where each object contains:\n"
'- an "intent" key containing the intention from the "Intentions" section;\n'
'- a "sop" key containing the SOP description from the "SOPs" section; filled with an empty string if no match.\n'
'- a "sop_index" key containing the int type index of SOP description from the "SOPs" section; filled with 0 if no match.\n'
'- a "reason" key explaining why it is matching/mismatching.\n',
],
stream=False,
)
logger.debug(rsp)
json_blocks = parse_json_code_block(rsp)
vv = json.loads(json_blocks[0])
self._intent_to_sops = [self.IntentSOP.model_validate(i) for i in vv]

View file

@ -6,7 +6,11 @@ This script defines tools for dialog.
from typing import List
from metagpt.actions.intent_detect import IntentDetect, IntentDetectResult
from metagpt.actions.intent_detect import (
IntentDetect,
IntentDetectResult,
LightIntentDetect,
)
from metagpt.context import Context
from metagpt.schema import Message
from metagpt.tools.tool_registry import register_tool
@ -74,3 +78,39 @@ async def intent_detect(messages: List[Message]) -> IntentDetectResult:
action = IntentDetect(context=ctx)
await action.run(messages)
return action.result
@register_tool(tags=["dialog", "software development intent detect"])
async def software_development_intent_detect(messages: List[Message]) -> List[str]:
"""Detects software development intent from a list of dialog messages.
Args:
messages (List[Message]): A list of dialog messages.
Returns:
IntentDetectResult: The result of intent detection.
Example:
>>> # Create messages
>>> dialog = [
>>> {"role":"user", "content":"user queries ..."},
>>> {"role":"assistant", "content": "assistant answers ..."},
>>> ...
>>> ]
>>> from metagpt.schema import Message
>>> messages = [Message.model_validate(i) for i in dialog]
>>> result = await software_development_intent_detect(messages)
>>> print(result)
[
"Writes a PRD based on software requirements.",
"Writes a design to the project repository, based on the PRD of the project.",
"Writes a project plan to the project repository, based on the design of the project.",
"Writes codes to the project repository, based on the project plan of the project.",
"Run QA test on the project repository.",
"Stage and commit changes for the project repository using Git."
]
"""
ctx = Context()
action = LightIntentDetect(context=ctx)
await action.run(messages)
return action.result[0].sop if action.result else []

File diff suppressed because one or more lines are too long

View file

@ -4,7 +4,8 @@ import json
import pytest
from metagpt.actions.intent_detect import IntentDetect
from metagpt.actions.intent_detect import IntentDetect, LightIntentDetect
from metagpt.logs import logger
from metagpt.schema import Message
DEMO_CONTENT = [
@ -56,6 +57,19 @@ async def test_intent_detect(content: str, context):
assert action._references
assert action._intent_to_sops
assert action.result
logger.info(action.result.model_dump_json())
@pytest.mark.asyncio
@pytest.mark.parametrize(
"content",
[json.dumps(DEMO1_CONTENT), json.dumps(DEMO_CONTENT)],
)
async def test_light_intent_detect(content: str, context):
action = LightIntentDetect(context=context)
messages = [Message.model_validate(i) for i in json.loads(content)]
rsp = await action.run(messages)
assert isinstance(rsp, Message)
if __name__ == "__main__":

View file

@ -6,11 +6,12 @@ import pytest
from metagpt.actions.intent_detect import IntentDetectResult
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.tools.libs.dialog import intent_detect
from metagpt.tools.libs.dialog import intent_detect, software_development_intent_detect
from tests.metagpt.actions.test_intent_detect import DEMO_CONTENT
@pytest.mark.asyncio
@pytest.mark.skip
async def test_intent_detect():
messages = [Message.model_validate(i) for i in DEMO_CONTENT]
result = await intent_detect(messages)
@ -19,5 +20,14 @@ async def test_intent_detect():
logger.info(f"dialog:{DEMO_CONTENT}\nresult:{result.model_dump_json()}")
@pytest.mark.asyncio
async def test_software_develop_intent_detect():
messages = [Message.model_validate(i) for i in DEMO_CONTENT]
result = await software_development_intent_detect(messages)
assert isinstance(result, list)
assert result
logger.info(f"dialog:{DEMO_CONTENT}\nresult:{result}")
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
import pytest
from metagpt.tools.libs.shell import execute
from metagpt.tools.libs.shell import shell_execute
@pytest.mark.asyncio
@ -14,7 +14,7 @@ from metagpt.tools.libs.shell import execute
],
)
async def test_shell(command, expect_stdout, expect_stderr):
stdout, stderr = await execute(command)
stdout, stderr = await shell_execute(command)
assert expect_stdout in stdout
assert stderr == expect_stderr

View file

@ -62,6 +62,7 @@ async def test_js_parser():
@pytest.mark.asyncio
@pytest.mark.skip
async def test_codes():
path = DEFAULT_WORKSPACE_ROOT / "snake_game"
repo_parser = RepoParser(base_directory=path)
@ -81,5 +82,13 @@ async def test_codes():
print(data)
@pytest.mark.asyncio
async def test_graph_select():
gdb_path = Path(__file__).parent / "../../data/graph_db/networkx.sequence_view.json"
gdb = await DiGraphRepository.load_from(gdb_path)
rows = await gdb.select()
assert rows
if __name__ == "__main__":
pytest.main([__file__, "-s"])