From df36fcc9297400ad840ba40c70d36de300f07276 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Sun, 28 Apr 2024 22:49:43 +0800 Subject: [PATCH] feat: `Message` add the ability to generate key-values based on `content` --- metagpt/schema.py | 34 +++++++++++++++++++++++++++++ tests/metagpt/test_schema.py | 42 ++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/metagpt/schema.py b/metagpt/schema.py index d0396ec26..5a7f1cbab 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -18,6 +18,7 @@ from __future__ import annotations import asyncio import json import os.path +import re import uuid from abc import ABC from asyncio import Queue, QueueEmpty, wait_for @@ -186,6 +187,14 @@ class Documents(BaseModel): return ActionOutput(content=self.model_dump_json(), instruct_content=self) +class Resource(BaseModel): + """Used by `Message`.`parse_resources`""" + + resource_type: str # the type of resource + value: str # a string type of resource content + description: str # explanation + + class Message(BaseModel): """list[: ]""" @@ -311,6 +320,31 @@ class Message(BaseModel): logger.error(f"parse json failed: {val}, error:{err}") return None + async def parse_resources(self, llm: "BaseLLM", key_descriptions: Dict[str, str] = None) -> Dict: + if not self.content: + return {} + content = f"## Original Requirement\n```text\n{self.content}\n```\n" + return_format = ( + "Return a markdown JSON object with:\n" + '- a "resources" key contain a list of objects. Each object with:\n' + ' - a "resource_type" key explain the type of resource;\n' + ' - a "value" key containing a string type of resource content;\n' + ' - a "description" key explaining why;\n' + ) + key_descriptions = key_descriptions or {} + for k, v in key_descriptions.items(): + return_format += f'- a "{k}" key containing {v};\n' + return_format += '- a "reason" key explaining why;\n' + instructions = ['Lists all the resources contained in the "Original Requirement".', return_format] + rsp = await llm.aask(msg=content, system_msgs=instructions) + pattern = r"```json\s*({[\s\S]*?})\s*```" + matches = re.findall(pattern, rsp) + if not matches: + return {} + m = json.loads(matches[0]) + m["resources"] = [Resource(**i) for i in m.get("resources", [])] + return m + class UserMessage(Message): """便于支持OpenAI的消息 diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 22f6ae9fb..6f54b062d 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -350,5 +350,47 @@ class TestPlan: assert plan.current_task_id == "2" +@pytest.mark.parametrize( + ("content", "key_descriptions"), + [ + ( + """ +Traceback (most recent call last): + File "/Users/iorishinier/github/MetaGPT/workspace/game_2048_1/game_2048/main.py", line 38, in + Main().main() + File "/Users/iorishinier/github/MetaGPT/workspace/game_2048_1/game_2048/main.py", line 28, in main + self.user_interface.draw() + File "/Users/iorishinier/github/MetaGPT/workspace/game_2048_1/game_2048/user_interface.py", line 16, in draw + if grid[i][j] != 0: +TypeError: 'Grid' object is not subscriptable + """, + { + "filename": "the string type of the path name of the source code where the bug resides", + "line": "the integer type of the line error occurs", + "function_name": "the string type of the function name the error occurs in", + "code": "the string type of the codes where the error occurs at", + "info": "the string type of the error information", + }, + ), + ( + "将代码提交到github上的iorisa/repo1的branch1分支,发起pull request ,合并到master分支。", + { + "repo_name": "the string type of github repo to create pull", + "head": "the string type of github branch to be pushed", + "base": "the string type of github branch to merge the changes into", + }, + ), + ], +) +async def test_parse_resources(context, content: str, key_descriptions): + msg = Message(content=content) + llm = context.llm_with_cost_manager_from_llm_config(context.config.llm) + result = await msg.parse_resources(llm=llm, key_descriptions=key_descriptions) + assert result + assert result.get("resources") + for k in key_descriptions.keys(): + assert k in result + + if __name__ == "__main__": pytest.main([__file__, "-s"])