From 270c1e477037ebcbc6a0ad2395fad447db54f3d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Wed, 27 Mar 2024 12:48:40 +0800 Subject: [PATCH] refactor: private classes --- metagpt/actions/intent_detect.py | 103 +++++++++---------------------- tests/mock/mock_llm.py | 11 ++-- 2 files changed, 34 insertions(+), 80 deletions(-) diff --git a/metagpt/actions/intent_detect.py b/metagpt/actions/intent_detect.py index 7a538cb43..faf7c6d86 100644 --- a/metagpt/actions/intent_detect.py +++ b/metagpt/actions/intent_detect.py @@ -51,19 +51,6 @@ SOP_CONFIG = [ ] -class _IntentDetectIntention(BaseModel): - """ - Represents detected intentions. - - Attributes: - ref (str): The reference to the original words. - intent (str): The detected intention of the referenced words. - """ - - ref: str - intent: str - - class IntentDetectClarification(BaseModel): """ Represents clarifications for unclear intentions. @@ -77,19 +64,6 @@ class IntentDetectClarification(BaseModel): clarification: str -class _IntentDetectDialogIntentions(BaseModel): - """ - Represents dialog intentions. - - Attributes: - intentions (List[IntentDetectIntention]): List of detected intentions. - clarifications (List[IntentDetectClarification]): List of clarifications for unclear intentions. - """ - - intentions: List[_IntentDetectIntention] - clarifications: List[IntentDetectClarification] - - class IntentDetectIntentionRef(BaseModel): """ Represents intentions along with their references. @@ -103,49 +77,6 @@ class IntentDetectIntentionRef(BaseModel): refs: List[str] -class _IntentDetectUnrefs(BaseModel): - """ - Represents unreferenced content along with reasons. - - Attributes: - ref (str): The unreferenced original text. - reason (str): Explanation for why it is unreferenced. - """ - - ref: str - reason: str - - -class _IntentSOP(BaseModel): - """ - Represents a mapping between an intention and a Standard Operating Procedure (SOP). - - Attributes: - intent (str): The intention related to the SOP. - sop (str): The description of the Standard Operating Procedure. - sop_index (int): The index of the description of the Standard Operating Procedure. - reason (str): Explanation for why the intention is unreferenced. - """ - - intent: str - sop: str - sop_index: int - reason: str - - -class _IntentDetectReferences(BaseModel): - """ - Represents references to intentions and unreferenced content. - - Attributes: - intentions (List[IntentDetectIntentionRef]): List of intentions with their references. - unrefs (List[IntentDetectUnrefs]): List of unreferenced content with reasons. - """ - - intentions: List[IntentDetectIntentionRef] - unrefs: List[_IntentDetectUnrefs] - - class IntentDetectIntentionSOP(BaseModel): """ Represents an intention mapped to a Standard Operating Procedure (SOP). @@ -187,9 +118,31 @@ class IntentDetect(Action): Result object containing the outcome of intention detection. """ - _dialog_intentions: _IntentDetectDialogIntentions = None - _references: _IntentDetectReferences = None - _intent_to_sops: List[_IntentSOP] = None + class IntentDetectDialogIntentions(BaseModel): + class IntentDetectIntention(BaseModel): + ref: str + intent: str + + intentions: List[IntentDetectIntention] + clarifications: List[IntentDetectClarification] + + class IntentDetectReferences(BaseModel): + class IntentDetectUnrefs(BaseModel): + ref: str + reason: str + + intentions: List[IntentDetectIntentionRef] + unrefs: List[IntentDetectUnrefs] + + class IntentSOP(BaseModel): + intent: str + sop: str + sop_index: int + reason: str + + _dialog_intentions: IntentDetectDialogIntentions = None + _references: IntentDetectReferences = None + _intent_to_sops: List[IntentSOP] = None result: IntentDetectResult = None async def run(self, with_messages: List[Message] = None, **kwargs) -> Message: @@ -240,7 +193,7 @@ class IntentDetect(Action): json_blocks = parse_json_code_block(rsp) if not json_blocks: return [] - self._dialog_intentions = _IntentDetectDialogIntentions.model_validate_json(json_blocks[0]) + self._dialog_intentions = self.IntentDetectDialogIntentions.model_validate_json(json_blocks[0]) return [i.intent for i in self._dialog_intentions.intentions] async def _get_references(self, msg_markdown: str, intentions: List[str]): @@ -267,7 +220,7 @@ class IntentDetect(Action): json_blocks = parse_json_code_block(rsp) if not json_blocks: return [] - self._references = _IntentDetectReferences.model_validate_json(json_blocks[0]) + self._references = self.IntentDetectReferences.model_validate_json(json_blocks[0]) async def _get_sops(self): intention_list = "" @@ -296,7 +249,7 @@ class IntentDetect(Action): logger.debug(rsp) json_blocks = parse_json_code_block(rsp) vv = json.loads(json_blocks[0]) - self._intent_to_sops = [_IntentSOP.model_validate(i) for i in vv] + self._intent_to_sops = [self.IntentSOP.model_validate(i) for i in vv] @staticmethod def _message_to_markdown(messages) -> str: diff --git a/tests/mock/mock_llm.py b/tests/mock/mock_llm.py index c4262e080..f6c206d5e 100644 --- a/tests/mock/mock_llm.py +++ b/tests/mock/mock_llm.py @@ -3,6 +3,7 @@ from typing import Optional, Union from metagpt.config2 import config from metagpt.configs.llm_config import LLMType +from metagpt.const import LLM_API_TIMEOUT from metagpt.logs import logger from metagpt.provider.azure_openai_api import AzureOpenAILLM from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA @@ -22,7 +23,7 @@ class MockLLM(OriginalLLM): self.rsp_cache: dict = {} self.rsp_candidates: list[dict] = [] # a test can have multiple calls with the same llm, thus a list - async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, timeout=LLM_API_TIMEOUT) -> str: """Overwrite original acompletion_text to cancel retry""" if stream: resp = await self._achat_completion_stream(messages, timeout=timeout) @@ -37,7 +38,7 @@ class MockLLM(OriginalLLM): system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, images: Optional[Union[str, list[str]]] = None, - timeout=3, + timeout=LLM_API_TIMEOUT, stream=True, ) -> str: if system_msgs: @@ -56,7 +57,7 @@ class MockLLM(OriginalLLM): rsp = await self.acompletion_text(message, stream=stream, timeout=timeout) return rsp - async def original_aask_batch(self, msgs: list, timeout=3) -> str: + async def original_aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str: """A copy of metagpt.provider.base_llm.BaseLLM.aask_batch, we can't use super().aask because it will be mocked""" context = [] for msg in msgs: @@ -83,7 +84,7 @@ class MockLLM(OriginalLLM): system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, images: Optional[Union[str, list[str]]] = None, - timeout=3, + timeout=LLM_API_TIMEOUT, stream=True, ) -> str: # used to identify it a message has been called before @@ -98,7 +99,7 @@ class MockLLM(OriginalLLM): rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, images, timeout, stream) return rsp - async def aask_batch(self, msgs: list, timeout=3) -> str: + async def aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str: msg_key = "#MSG_SEP#".join([msg if isinstance(msg, str) else msg.content for msg in msgs]) rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout) return rsp