mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
feat: +SentenceIntentDetect
This commit is contained in:
parent
66b68399eb
commit
edb39a50d3
4 changed files with 69 additions and 15 deletions
|
|
@ -13,6 +13,7 @@ Dependencies:
|
|||
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -395,3 +396,41 @@ class LightIntentDetect(IntentDetect):
|
|||
if i.intent == intent:
|
||||
refs.append(i.ref)
|
||||
return "\n".join(refs)
|
||||
|
||||
|
||||
class SentenceIntentDetect(IntentDetect):
|
||||
sop: List[str] = None
|
||||
|
||||
async def run(self, with_messages: List[Message] = None, **kwargs) -> Message:
|
||||
"""
|
||||
Runs the intention detection action.
|
||||
|
||||
Args:
|
||||
with_messages (List[Message]): List of messages representing the conversation content.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
msg_markdown = self._message_to_markdown(with_messages)
|
||||
self.sop = await self._get_intentions(msg_markdown)
|
||||
return Message(content="", role="assistant", cause_by=self)
|
||||
|
||||
async def _get_intentions(self, msg_markdown: str) -> List[str]:
|
||||
prompt = f"## Dialog\n{msg_markdown}\n"
|
||||
prompt += "## Intentions\n"
|
||||
for i, v in enumerate(SOP_CONFIG):
|
||||
prompt += f"{i + 1}. {v.description}\n"
|
||||
prompt += f"{len(SOP_CONFIG) + 1}. Others"
|
||||
rsp = await self.llm.aask(
|
||||
prompt,
|
||||
system_msgs=[
|
||||
'You are a tool for selecting a suitable intention from the "Intentions" section.',
|
||||
'Select the intention that matches the conversation in the "Dialog" section from the "Intentions" section.',
|
||||
'If no matching intention is found, choose "Others".',
|
||||
"Return the integer index of your choice.",
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
logger.debug(rsp)
|
||||
idx = int(re.findall(r"\b\d+\b", rsp)[0]) - 1
|
||||
if idx < len(SOP_CONFIG):
|
||||
return SOP_CONFIG[idx].sop
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
import asyncio
|
||||
from typing import Dict, List
|
||||
|
||||
from metagpt.actions.intent_detect import LightIntentDetect
|
||||
from metagpt.actions.intent_detect import SentenceIntentDetect
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -15,21 +15,16 @@ class MGX(DataInterpreter):
|
|||
intents: Dict = {}
|
||||
|
||||
async def _intent_detect(self, user_msgs: List[Message] = None, **kwargs):
|
||||
todo = LightIntentDetect(context=self.context)
|
||||
todo = SentenceIntentDetect(context=self.context)
|
||||
await todo.run(user_msgs)
|
||||
logger.info(f"intent_desp is {todo.result.model_dump_json()}")
|
||||
logger.info(f"intent_desp is {todo.sop}")
|
||||
|
||||
# Extract intent and sop prompt
|
||||
intention_ref = ""
|
||||
for i in todo.result.intentions:
|
||||
if not intention_ref:
|
||||
intention_ref = "\n".join(i.intention.refs)
|
||||
if not i.sop:
|
||||
continue
|
||||
|
||||
self.intents[intention_ref] = i.sop.sop
|
||||
logger.debug(f"refs: {intention_ref}, sop: {i.sop.sop}")
|
||||
sop_str = "\n".join([f"- {i}" for i in i.sop.sop])
|
||||
intention_ref = "\n".join([i.content for i in user_msgs])
|
||||
if todo.sop:
|
||||
self.intents[intention_ref] = todo.sop
|
||||
logger.debug(f"refs: {intention_ref}, sop: {todo.sop}")
|
||||
sop_str = "\n".join([f"- {i}" for i in todo.sop])
|
||||
markdown = (
|
||||
f"### User Requirement Detail\n```text\n{intention_ref}\n````\n"
|
||||
f"### Knowledge\nTo meet user requirements, the following standard operating procedure(SOP) must be"
|
||||
|
|
|
|||
|
|
@ -4,7 +4,11 @@ import json
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.intent_detect import IntentDetect, LightIntentDetect
|
||||
from metagpt.actions.intent_detect import (
|
||||
IntentDetect,
|
||||
LightIntentDetect,
|
||||
SentenceIntentDetect,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
|
@ -140,6 +144,7 @@ DEMO3_CONTENT = [
|
|||
"content",
|
||||
[json.dumps(DEMO1_CONTENT), json.dumps(DEMO_CONTENT), json.dumps(DEMO2_CONTENT), json.dumps(DEMO3_CONTENT)],
|
||||
)
|
||||
# @pytest.mark.skip
|
||||
async def test_intent_detect(content: str, context):
|
||||
action = IntentDetect(context=context)
|
||||
messages = [Message.model_validate(i) for i in json.loads(content)]
|
||||
|
|
@ -157,6 +162,7 @@ async def test_intent_detect(content: str, context):
|
|||
"content",
|
||||
[json.dumps(DEMO1_CONTENT), json.dumps(DEMO_CONTENT), json.dumps(DEMO2_CONTENT), json.dumps(DEMO3_CONTENT)],
|
||||
)
|
||||
# @pytest.mark.skip
|
||||
async def test_light_intent_detect(content: str, context):
|
||||
action = LightIntentDetect(context=context)
|
||||
messages = [Message.model_validate(i) for i in json.loads(content)]
|
||||
|
|
@ -167,5 +173,18 @@ async def test_light_intent_detect(content: str, context):
|
|||
assert action.result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"content",
|
||||
[json.dumps(DEMO1_CONTENT), json.dumps(DEMO_CONTENT), json.dumps(DEMO2_CONTENT), json.dumps(DEMO3_CONTENT)],
|
||||
)
|
||||
# @pytest.mark.skip
|
||||
async def test_sentence_intent(content: str, context):
|
||||
action = SentenceIntentDetect(context=context)
|
||||
messages = [Message.model_validate(i) for i in json.loads(content)]
|
||||
await action.run(messages)
|
||||
assert action.sop is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -48,7 +48,8 @@ async def test_mgx_fixbug(user_message: Message, history_messages: List[Message]
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("user_message", [Message.model_validate(i) for i in DEMO3_CONTENT if i.role == "user"])
|
||||
@pytest.mark.parametrize("user_message", [Message.model_validate(i) for i in DEMO3_CONTENT if i["role"] == "user"])
|
||||
@pytest.mark.skip
|
||||
async def test_git_import(user_message, context):
|
||||
mgx = MGX(context=context, tools=["<all>"])
|
||||
await mgx.run(user_message)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue