feat: +SentenceIntentDetect

This commit is contained in:
莘权 马 2024-04-02 15:58:09 +08:00
parent 66b68399eb
commit edb39a50d3
4 changed files with 69 additions and 15 deletions

View file

@ -13,6 +13,7 @@ Dependencies:
"""
import json
import re
from typing import List
from pydantic import BaseModel, Field
@ -395,3 +396,41 @@ class LightIntentDetect(IntentDetect):
if i.intent == intent:
refs.append(i.ref)
return "\n".join(refs)
class SentenceIntentDetect(IntentDetect):
sop: List[str] = None
async def run(self, with_messages: List[Message] = None, **kwargs) -> Message:
"""
Runs the intention detection action.
Args:
with_messages (List[Message]): List of messages representing the conversation content.
**kwargs: Additional keyword arguments.
"""
msg_markdown = self._message_to_markdown(with_messages)
self.sop = await self._get_intentions(msg_markdown)
return Message(content="", role="assistant", cause_by=self)
async def _get_intentions(self, msg_markdown: str) -> List[str]:
prompt = f"## Dialog\n{msg_markdown}\n"
prompt += "## Intentions\n"
for i, v in enumerate(SOP_CONFIG):
prompt += f"{i + 1}. {v.description}\n"
prompt += f"{len(SOP_CONFIG) + 1}. Others"
rsp = await self.llm.aask(
prompt,
system_msgs=[
'You are a tool for selecting a suitable intention from the "Intentions" section.',
'Select the intention that matches the conversation in the "Dialog" section from the "Intentions" section.',
'If no matching intention is found, choose "Others".',
"Return the integer index of your choice.",
],
stream=False,
)
logger.debug(rsp)
idx = int(re.findall(r"\b\d+\b", rsp)[0]) - 1
if idx < len(SOP_CONFIG):
return SOP_CONFIG[idx].sop
return []

View file

@ -4,7 +4,7 @@
import asyncio
from typing import Dict, List
from metagpt.actions.intent_detect import LightIntentDetect
from metagpt.actions.intent_detect import SentenceIntentDetect
from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.schema import Message
@ -15,21 +15,16 @@ class MGX(DataInterpreter):
intents: Dict = {}
async def _intent_detect(self, user_msgs: List[Message] = None, **kwargs):
todo = LightIntentDetect(context=self.context)
todo = SentenceIntentDetect(context=self.context)
await todo.run(user_msgs)
logger.info(f"intent_desp is {todo.result.model_dump_json()}")
logger.info(f"intent_desp is {todo.sop}")
# Extract intent and sop prompt
intention_ref = ""
for i in todo.result.intentions:
if not intention_ref:
intention_ref = "\n".join(i.intention.refs)
if not i.sop:
continue
self.intents[intention_ref] = i.sop.sop
logger.debug(f"refs: {intention_ref}, sop: {i.sop.sop}")
sop_str = "\n".join([f"- {i}" for i in i.sop.sop])
intention_ref = "\n".join([i.content for i in user_msgs])
if todo.sop:
self.intents[intention_ref] = todo.sop
logger.debug(f"refs: {intention_ref}, sop: {todo.sop}")
sop_str = "\n".join([f"- {i}" for i in todo.sop])
markdown = (
f"### User Requirement Detail\n```text\n{intention_ref}\n````\n"
f"### Knowledge\nTo meet user requirements, the following standard operating procedure(SOP) must be"

View file

@ -4,7 +4,11 @@ import json
import pytest
from metagpt.actions.intent_detect import IntentDetect, LightIntentDetect
from metagpt.actions.intent_detect import (
IntentDetect,
LightIntentDetect,
SentenceIntentDetect,
)
from metagpt.logs import logger
from metagpt.schema import Message
@ -140,6 +144,7 @@ DEMO3_CONTENT = [
"content",
[json.dumps(DEMO1_CONTENT), json.dumps(DEMO_CONTENT), json.dumps(DEMO2_CONTENT), json.dumps(DEMO3_CONTENT)],
)
# @pytest.mark.skip
async def test_intent_detect(content: str, context):
action = IntentDetect(context=context)
messages = [Message.model_validate(i) for i in json.loads(content)]
@ -157,6 +162,7 @@ async def test_intent_detect(content: str, context):
"content",
[json.dumps(DEMO1_CONTENT), json.dumps(DEMO_CONTENT), json.dumps(DEMO2_CONTENT), json.dumps(DEMO3_CONTENT)],
)
# @pytest.mark.skip
async def test_light_intent_detect(content: str, context):
action = LightIntentDetect(context=context)
messages = [Message.model_validate(i) for i in json.loads(content)]
@ -167,5 +173,18 @@ async def test_light_intent_detect(content: str, context):
assert action.result
@pytest.mark.asyncio
@pytest.mark.parametrize(
"content",
[json.dumps(DEMO1_CONTENT), json.dumps(DEMO_CONTENT), json.dumps(DEMO2_CONTENT), json.dumps(DEMO3_CONTENT)],
)
# @pytest.mark.skip
async def test_sentence_intent(content: str, context):
action = SentenceIntentDetect(context=context)
messages = [Message.model_validate(i) for i in json.loads(content)]
await action.run(messages)
assert action.sop is not None
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -48,7 +48,8 @@ async def test_mgx_fixbug(user_message: Message, history_messages: List[Message]
@pytest.mark.asyncio
@pytest.mark.parametrize("user_message", [Message.model_validate(i) for i in DEMO3_CONTENT if i.role == "user"])
@pytest.mark.parametrize("user_message", [Message.model_validate(i) for i in DEMO3_CONTENT if i["role"] == "user"])
@pytest.mark.skip
async def test_git_import(user_message, context):
mgx = MGX(context=context, tools=["<all>"])
await mgx.run(user_message)