diff --git a/metagpt/actions/intent_detect.py b/metagpt/actions/intent_detect.py index 0104d079d..7831624c6 100644 --- a/metagpt/actions/intent_detect.py +++ b/metagpt/actions/intent_detect.py @@ -245,42 +245,43 @@ class IntentDetect(Action): async def _merge(self): self.result = IntentDetectResult(clarifications=self._dialog_intentions.clarifications) - sops = {i.description: i for i in SOP_CONFIG} - intent_to_sops = {i.intent: i.sop for i in self._intent_to_sops if i.sop != ""} - distinct = {} for i in self._intent_to_sops: if i.sop_index == 0: # 1-based index - ref = self._get_intent_ref(i.intent) - item = IntentDetectIntentionSOP(intention=ref) + refs = self._get_intent_ref(i.intent) + item = IntentDetectIntentionSOP(intention=IntentDetectIntentionRef(intent=i.intent, refs=refs)) self.result.intentions.append(item) continue distinct[i.sop_index] = [i.intent] + distinct.get(i.sop_index, []) merge_intents = {} + intent_to_sops = {i.intent: i.sop_index for i in self._intent_to_sops if i.sop_index != 0} for sop_index, intents in distinct.items(): if len(intents) > 1: merge_intents[sop_index] = intents continue - intent_ref = self._get_intent_ref(intents[0]) - item = IntentDetectIntentionSOP(intention=intent_ref) - key = intent_to_sops.get(intents[0]) - if key: - item.sop = sops.get(key) + refs = self._get_intent_ref(intents[0]) + item = IntentDetectIntentionSOP(intention=IntentDetectIntentionRef(intent=intents[0], refs=refs)) + sop_index = intent_to_sops.get(intents[0]) + item.sop = SOP_CONFIG[sop_index - 1] # 1-based index self.result.intentions.append(item) for sop_index, intents in merge_intents.items(): intent_ref = IntentDetectIntentionRef(intent="\n".join(intents), refs=[]) for i in intents: - ref = self._get_intent_ref(i) - intent_ref.refs.extend(ref.refs) + refs = self._get_intent_ref(i) + intent_ref.refs.extend(refs) + intent_ref.refs = list(set(intent_ref.refs)) item = IntentDetectIntentionSOP(intention=intent_ref) item.sop = SOP_CONFIG[sop_index - 1] # 1-based index self.result.intentions.append(item) - def _get_intent_ref(self, intent: str): - mappings = {i.intent: i for i in self._references.intentions} - return mappings[intent] + def _get_intent_ref(self, intent: str) -> List[str]: + refs = [] + for i in self._references.intentions: + if i.intent == intent: + refs.extend(i.refs) + return refs @staticmethod def _message_to_markdown(messages) -> str: @@ -336,9 +337,6 @@ class LightIntentDetect(IntentDetect): async def _merge(self): self.result = IntentDetectResult(clarifications=[]) - sops = {i.description: i for i in SOP_CONFIG} - intent_to_sops = {i.intent: i.sop for i in self._intent_to_sops if i.sop != ""} - distinct = {} for i in self._intent_to_sops: if i.sop_index == 0: # 1-based index @@ -349,15 +347,16 @@ class LightIntentDetect(IntentDetect): distinct[i.sop_index] = [i.intent] + distinct.get(i.sop_index, []) merge_intents = {} + intent_to_sops = {i.intent: i.sop_index for i in self._intent_to_sops if i.sop_index != 0} for sop_index, intents in distinct.items(): if len(intents) > 1: merge_intents[sop_index] = intents continue ref = self._get_intent_ref(intents[0]) item = IntentDetectIntentionSOP(intention=IntentDetectIntentionRef(intent=intents[0], refs=[ref])) - key = intent_to_sops.get(intents[0]) - if key: - item.sop = sops.get(key) + sop_index = intent_to_sops.get(intents[0]) # 1-based + if sop_index: + item.sop = SOP_CONFIG[sop_index - 1] # 1-based index self.result.intentions.append(item) for sop_index, intents in merge_intents.items(): @@ -365,10 +364,14 @@ class LightIntentDetect(IntentDetect): for i in intents: ref = self._get_intent_ref(i) intent_ref.refs.append(ref) + intent_ref.refs = list(set(intent_ref.refs)) item = IntentDetectIntentionSOP(intention=intent_ref) item.sop = SOP_CONFIG[sop_index - 1] # 1-based index self.result.intentions.append(item) def _get_intent_ref(self, intent: str) -> str: - mappings = {i.intent: i.ref for i in self._dialog_intentions.intentions} - return mappings[intent] + refs = [] + for i in self._dialog_intentions.intentions: + if i.intent == intent: + refs.append(i.ref) + return "\n".join(refs) diff --git a/metagpt/roles/di/mgx.py b/metagpt/roles/di/mgx.py index 85fd44894..64b06b702 100644 --- a/metagpt/roles/di/mgx.py +++ b/metagpt/roles/di/mgx.py @@ -2,7 +2,6 @@ # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : import asyncio -import json from typing import Dict, List from metagpt.actions.intent_detect import IntentDetect @@ -17,20 +16,20 @@ class MGX(DataInterpreter): async def _intent_detect(self, user_msgs: List[Message] = None, **kwargs): todo = IntentDetect(context=self.context) - intent_desp = await todo.run(user_msgs) - intent_desp = json.loads(intent_desp.content) - logger.info(f"intent_desp is {intent_desp}") + await todo.run(user_msgs) + logger.info(f"intent_desp is {todo.result.model_dump_json()}") # Extract intent and sop prompt - intents = intent_desp.get("intentions", [{}])[0] - # Optional: handle the case where intentions might be empty or malformatted - intention_ref = intents.get("intention", {}).get("refs", [None])[0] - sop = intents.get("sop", {}).get("sop", None) - self.intents.update({intention_ref: sop}) - - if sop is None: - return intention_ref - return intention_ref + "\n---" + "\n".join(intents["sop"]["sop"]) + intention_ref = "" + for i in todo.result.intentions: + if not i.sop: + continue + intention_ref = "\n".join(i.intention.refs) + self.intents[intention_ref] = i.sop.sop + logger.debug(f"refs: {intention_ref}, sop: {i.sop.sop}") + sop_str = "\n".join(i.sop.sop) + return f"{intention_ref}\n---\n{sop_str}" + return intention_ref async def _plan_and_act(self) -> Message: """first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... Use llm to come up with the plan dynamically.""" diff --git a/tests/metagpt/roles/di/test_mgx.py b/tests/metagpt/roles/di/test_mgx.py new file mode 100644 index 000000000..59392e8d7 --- /dev/null +++ b/tests/metagpt/roles/di/test_mgx.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from typing import List + +import pytest + +from metagpt.context import Context +from metagpt.roles.di.mgx import MGX +from metagpt.schema import Message +from tests.metagpt.actions.test_intent_detect import DEMO_CONTENT + + +@pytest.mark.asyncio +@pytest.mark.parametrize("user_messages", [[Message.model_validate(i) for i in DEMO_CONTENT if i["role"] == "user"]]) +async def test_mgx(user_messages: List[Message]): + ctx = Context() + mgx = MGX(context=ctx) + + for i in user_messages: + await mgx.run(i) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"])