diff --git a/README.md b/README.md index 91a5483e0..396742077 100644 --- a/README.md +++ b/README.md @@ -76,20 +76,20 @@ # Step 3: Clone the repository to your local machine, and install it. **Note:** - If already have Chrome, Chromium, or MS Edge installed, you can skip downloading Chromium by setting the environment variable - `PUPPETEER_SKIP_CHROMIUM_DOWNLOAD` to `true`. +`PUPPETEER_SKIP_CHROMIUM_DOWNLOAD` to `true`. - Some people are [having issues](https://github.com/mermaidjs/mermaid.cli/issues/15) installing this tool globally. Installing it locally is an alternative solution, - ```bash - npm install @mermaid-js/mermaid-cli - ``` + ```bash + npm install @mermaid-js/mermaid-cli + ``` - don't forget to the configuration for mmdc in config.yml - ```yml - PUPPETEER_CONFIG: "./config/puppeteer-config.json" - MMDC: "./node_modules/.bin/mmdc" - ``` + ```yml + PUPPETEER_CONFIG: "./config/puppeteer-config.json" + MMDC: "./node_modules/.bin/mmdc" + ``` - if `pip install -e.` fails with error `[Errno 13] Permission denied: '/usr/local/lib/python3.11/dist-packages/test-easy-install-13129.write-test'`, try instead running `pip install -e. --user` @@ -224,12 +224,12 @@ # Run the script # Do not hire an engineer to implement the project python startup.py "Write a cli snake game" --implement False # Hire an engineer and perform code reviews -python startup.py "Write a cli snake game" --code_review True +python startup.py "Write a cli snake game" --code_review True ``` After running the script, you can find your new project in the `workspace/` directory. -### Preference of Platform or Tool +### Preference of Platform or Tool You can tell which platform or tool you want to use when stating your requirements. @@ -286,7 +286,7 @@ ### Code walkthrough ## QuickStart -It is difficult to install and configure the local environment for some users. The following tutorials will allow you to quickly experience the charm of MetaGPT. +It is difficult to install and configure the local environment for some users. The following tutorials will allow you to quickly experience the charm of MetaGPT. - [MetaGPT quickstart](https://deepwisdom.feishu.cn/wiki/CyY9wdJc4iNqArku3Lncl4v8n2b) @@ -299,7 +299,7 @@ ## Citation ```bibtex @misc{hong2023metagpt, - title={MetaGPT: Meta Programming for Multi-Agent Collaborative Framework}, + title={MetaGPT: Meta Programming for Multi-Agent Collaborative Framework}, author={Sirui Hong and Xiawu Zheng and Jonathan Chen and Yuheng Cheng and Jinlin Wang and Ceyao Zhang and Zili Wang and Steven Ka Shing Yau and Zijuan Lin and Liyang Zhou and Chenyu Ran and Lingfeng Xiao and Chenglin Wu}, year={2023}, eprint={2308.00352}, diff --git a/docs/resources/gitattributes.txt b/docs/resources/gitattributes.txt new file mode 100644 index 000000000..bb940d6a1 --- /dev/null +++ b/docs/resources/gitattributes.txt @@ -0,0 +1 @@ +tasks.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/metagpt/actions/design.py b/metagpt/actions/design.py new file mode 100644 index 000000000..968640e64 --- /dev/null +++ b/metagpt/actions/design.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +# @Date : 2023/8/17 13:43 +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : + +# Standard library imports +from functools import wraps +from typing import Callable, Any, List, Optional, Tuple + +# Local library imports +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.utils.common import OutputParser +from metagpt.prompts.sd_design import ( + SD_PROMPT_KW_OPTIMIZE_TEMPLATE, + SD_PROMPT_IMPROVE_OPTIMIZE_TEMPLATE, + FORMAT_INSTRUCTIONS, + PROMPT_OUTPUT_MAPPING +) +from metagpt.utils.resp_parse import flatten_json_structure, try_parse_json + +# A default template for the system primer. +SYSTEM_PRIMER_TEMPLATE = "Act like you are a terminal and always format your response as json. Always return exactly {answer_count} answers per question in English." + + +class Tool: + """Define a tool with its name, function and description.""" + + def __init__(self, name: str, func: Callable, description: str) -> None: + """Initialize tool.""" + self.name = name + self.func = func + self.description = description + + +# Decorator for the BaseModelAction to wrap it with system primer details +def system_primer_decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + system_primer = SYSTEM_PRIMER_TEMPLATE.format(answer_count=kwargs.get('answer_count', 1)) + logger.info(system_primer) + return await func(*args, system_primer=system_primer, **kwargs) + + return wrapper + + +class BaseModelAction(Action): + + def __init__(self, name: str = "", description: str = "", *args, **kwargs): + super().__init__(name, *args, **kwargs) + self.desc = description + + async def handle_response(self, resp: str) -> Any: + """Handle JSON response and extract value.""" + try: + resp_json = flatten_json_structure(try_parse_json(resp)) + logger.info(resp_json) + return resp_json + + except Exception as exp: + logger.error(f" JSON response {exp}") + return None + + @system_primer_decorator + async def run_optimize_or_improve(self, query: str, domain: str, template: str, answer_count: int = 1, + system_primer=None) -> List[str]: + """Run optimization or improvement based on the given template.""" + prompt = template.format(messages=query, domain=domain, answer_count=answer_count) + resp: str = await self._aask(prompt=prompt, system_msgs=[system_primer]) + result = await self.handle_response(resp) + return result or [query] + + +class SDPromptOptimize(BaseModelAction): + """ + Optimize graphical prompts based on keywords. + 扩充画图的提示词,根据keyword + """ + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, description="PromptOptimize", *args, **kwargs) + + async def run(self, query: str, domain: str = "realistic", answer_count: int = 1) -> List[str]: + """Run the optimization for the given query.""" + return await self.run_optimize_or_improve(query, domain, SD_PROMPT_KW_OPTIMIZE_TEMPLATE, + answer_count=answer_count) + + +class SDPromptImprove(BaseModelAction): + """Enhance the input prompt (recommended when the input prompt is long). + fixme: 接入提示词优化的FT模型 + """ + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, description="PromptImprove", *args, **kwargs) + + async def run(self, query: str, domain: str = "realistic", answer_count: int = 1) -> List[str]: + """Run the improvement for the given query.""" + return await self.run_optimize_or_improve(query, domain, SD_PROMPT_IMPROVE_OPTIMIZE_TEMPLATE, + answer_count=answer_count) + + +class SDPromptExtend(BaseModelAction): + """Action class to extend the prompt.""" + + def __init__(self, name: str = "", tools: Optional[List[Tool]] = [], **kwargs): + super().__init__(name, description="Prompt Extend", **kwargs) + self.tools = tools + logger.info(self.tools) + + def _parse_tools(self) -> Tuple[str, str]: + """Parse tool names and descriptions.""" + tool_strings = [f"{tool.name}: {tool.description}" for tool in self.tools] + formatted_tools = "\n".join(tool_strings) + tool_names = ", ".join(tool.name for tool in self.tools) + return tool_names, formatted_tools + + async def run(self, query: str, answer_count: int = 1, domain: str = "realistic", + model_name="realisticVisionV30_v30VAE") -> str: + """Extend the prompt and get the "Final Action" from the output.""" + tool_names, formatted_tools = self._parse_tools() + msg = FORMAT_INSTRUCTIONS.format(query=query, tool_names=tool_names, + tool_description=formatted_tools, + model_name=model_name, domain=domain) + + resp = await self._aask(msg) + output_block = OutputParser.parse_data_with_mapping(resp, PROMPT_OUTPUT_MAPPING) + return output_block["Final Action"] + + diff --git a/metagpt/actions/minecraft/design_curriculumn.py b/metagpt/actions/minecraft/design_curriculumn.py index 28299c620..9d0daa72e 100644 --- a/metagpt/actions/minecraft/design_curriculumn.py +++ b/metagpt/actions/minecraft/design_curriculumn.py @@ -10,13 +10,11 @@ from langchain.vectorstores import Chroma from metagpt.document_store import FaissStore from metagpt.logs import logger -from metagpt.actions import Action +from metagpt.actions.minecraft.player_action import PlayerActions as Action from metagpt.utils.minecraft import load_prompt, fix_and_parse_json from metagpt.schema import HumanMessage, SystemMessage from metagpt.const import CKPT_DIR -# from metagpt.actions.minecraft import PlayerActions - class DesignTask(Action): """ @@ -63,11 +61,12 @@ class DesignTask(Action): response = self.parse_llm_response( curriculum ) # Task: Craft 4 wooden planks. + logger.info(f"Parsed Curriculum Agent response\n{response}") assert "next_task" in response return response["next_task"] except Exception as e: logger.info(f"Error parsing curriculum response: {e}. Trying again!") - return self.generate_task( + return await self.generate_task( human_msg=human_msg, system_msg=system_msg, max_retries=max_retries - 1, @@ -92,29 +91,6 @@ class DesignCurriculum(Action): def __init__(self, name="", context=None, llm=None): super().__init__(name, context, llm) # voyager vectordb using - self.qa_cache = {} - self.qa_cache_questions_vectordb = Chroma( - collection_name="qa_cache_questions_vectordb", - embedding_function=OpenAIEmbeddings(), - persist_directory=f"{CKPT_DIR}/curriculum/vectordb", - ) - # TODO: change to FaissStore - # self.qa_cache_questions_vectordb = FaissStore( {CKPT_DIR}/ 'curriculum/vectordb') - # TODO: - # assert self.qa_cache_questions_vectordb._collection.count() == len( - # self.qa_cache - # ), ( - # f"Curriculum Agent's qa cache question vectordb is not synced with qa_cache.json.\n" - # f"There are {self.qa_cache_questions_vectordb._collection.count()} questions in vectordb " - # f"but {len(self.qa_cache)} questions in qa_cache.json.\n" - # f"Did you set resume=False when initializing the agent?\n" - # f"You may need to manually delete the qa cache question vectordb directory for running from scratch.\n" - # ) - - @classmethod - def set_qa_cache(cls, qa_cache): - cls.qa_cache = qa_cache - # Check if qa_cache right using @classmethod def generate_qa(cls, events, chest_observation): @@ -232,7 +208,7 @@ class DesignCurriculum(Action): return context except Exception as e: logger.info(f"Error parsing curriculum response: {e}. Trying again!") - return self.generate_context( + return await self.generate_context( task=task, max_retries=max_retries - 1, ) diff --git a/metagpt/actions/minecraft/generate_actions.py b/metagpt/actions/minecraft/generate_actions.py index 8cc32ec08..65433f326 100644 --- a/metagpt/actions/minecraft/generate_actions.py +++ b/metagpt/actions/minecraft/generate_actions.py @@ -22,6 +22,7 @@ class GenerateActionCode(Action): Implement the logic for generating action code here. """ + # logger.info(f"human_msg {human_msg}, system_msg {system_msg}") rsp = await self._aask(prompt=human_msg, system_msgs=system_msg) parsed_result = parse_action_response(rsp) # logger.info(f"parsed_result is HERE: {parsed_result}") diff --git a/metagpt/actions/minecraft/manage_skills.py b/metagpt/actions/minecraft/manage_skills.py index bee726f15..caec6c560 100644 --- a/metagpt/actions/minecraft/manage_skills.py +++ b/metagpt/actions/minecraft/manage_skills.py @@ -5,11 +5,8 @@ import os import json -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.vectorstores import Chroma -from metagpt.document_store import FaissStore from metagpt.logs import logger -from metagpt.actions import Action +from metagpt.actions.minecraft.player_action import PlayerActions as Action from metagpt.const import CKPT_DIR @@ -21,21 +18,6 @@ class RetrieveSkills(Action): def __init__(self, name="", context=None, llm=None): super().__init__(name, context, llm) - # TODO: mv to PlayerAction - self.retrieval_top_k = 5 - self.vectordb = Chroma( - collection_name="skill_vectordb", - embedding_function=OpenAIEmbeddings(), - persist_directory=f"{CKPT_DIR}/skill/vectordb", - ) - # Check if skills right using - # TODO: - # assert self.vectordb._collection.count() == len(self.skills), ( - # f"Skill Manager's vectordb is not synced with skills.json.\n" - # f"There are {self.vectordb._collection.count()} skills in vectordb but {len(self.skills)} skills in skills.json.\n" - # f"Did you set resume=False when initializing the manager?\n" - # f"You may need to manually delete the vectordb directory for running from scratch." - # ) async def run(self, query, skills, *args, **kwargs): # Implement the logic for retrieving skills here. @@ -62,22 +44,6 @@ class AddNewSkills(Action): def __init__(self, name="", context=None, llm=None): super().__init__(name, context, llm) - # TODO: mv to PlayerAction - self.vectordb = Chroma( - collection_name="skill_vectordb", - embedding_function=OpenAIEmbeddings(), - persist_directory=f"{CKPT_DIR}/skill/vectordb", - ) - # TODO: change to FaissStore - # self.qa_cache_questions_vectordb = FaissStore( {CKPT_DIR}/ 'skill/vectordb') - # TODO: - # Check if skills right using - # assert self.vectordb._collection.count() == len(self.skills), ( - # f"Skill Manager's vectordb is not synced with skills.json.\n" - # f"There are {self.vectordb._collection.count()} skills in vectordb but {len(self.skills)} skills in skills.json.\n" - # f"Did you set resume=False when initializing the manager?\n" - # f"You may need to manually delete the vectordb directory for running from scratch." - # ) async def run( self, task, program_name, program_code, skills, skill_desp, *args, **kwargs diff --git a/metagpt/actions/minecraft/player_action.py b/metagpt/actions/minecraft/player_action.py index 6597fc9a1..d83deff8f 100644 --- a/metagpt/actions/minecraft/player_action.py +++ b/metagpt/actions/minecraft/player_action.py @@ -3,8 +3,57 @@ # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : from metagpt.actions import Action +from langchain.vectorstores import Chroma +from langchain.embeddings.openai import OpenAIEmbeddings +from metagpt.document_store import FaissStore +from metagpt.const import CKPT_DIR class PlayerActions(Action): + def __init__(self, name="", context=None, llm=None): + super().__init__(name, context, llm) + self.skills = {} + self.qa_cache = {} + self.retrieval_top_k = 5 + self.vectordb = Chroma( + collection_name="skill_vectordb", + embedding_function=OpenAIEmbeddings(), + persist_directory=f"{CKPT_DIR}/skill/vectordb", + ) + + self.qa_cache_questions_vectordb = Chroma( + collection_name="qa_cache_questions_vectordb", + embedding_function=OpenAIEmbeddings(), + persist_directory=f"{CKPT_DIR}/curriculum/vectordb", + ) + # TODO: change to FaissStore + # self.qa_cache_questions_vectordb = FaissStore( {CKPT_DIR}/ 'curriculum/vectordb' + + @classmethod + def set_skills(cls, skills): + cls.skills = skills + # Check if Skill Manager's vectordb right using + assert cls.vectordb._collection.count() == len(cls.skills), ( + f"Skill Manager's vectordb is not synced with skills.json.\n" + f"There are {cls.vectordb._collection.count()} skills in vectordb but {len(cls.skills)} skills in skills.json.\n" + f"Did you set resume=False when initializing the manager?\n" + f"You may need to manually delete the vectordb directory for running from scratch." + ) + + @classmethod + def set_qa_cache(cls, qa_cache): + cls.qa_cache = qa_cache + # Check if qa_cache right using + # Check if Skill Manager's vectordb right using + assert cls.qa_cache_questions_vectordb._collection.count() == len( + cls.qa_cache + ), ( + f"Curriculum Agent's qa cache question vectordb is not synced with qa_cache.json.\n" + f"There are {cls.qa_cache_questions_vectordb._collection.count()} questions in vectordb " + f"but {len(cls.qa_cache)} questions in qa_cache.json.\n" + f"Did you set resume=False when initializing the agent?\n" + f"You may need to manually delete the qa cache question vectordb directory for running from scratch.\n" + ) + """Minecraft player info without any implementation details""" async def run(self, *args, **kwargs): raise NotImplementedError \ No newline at end of file diff --git a/metagpt/actions/minecraft/review_task.py b/metagpt/actions/minecraft/review_task.py index b532fb370..3a46b9752 100644 --- a/metagpt/actions/minecraft/review_task.py +++ b/metagpt/actions/minecraft/review_task.py @@ -15,7 +15,6 @@ class VerifyTask(Action): def __init__(self, name="", context=None, llm=None): super().__init__(name, context, llm) - self.vect_db = "" async def run(self,human_msg, system_msg, max_retries=5, *args, **kwargs): # Implement the logic to verify the task here. @@ -29,7 +28,8 @@ class VerifyTask(Action): logger.info(f"Failed to parse Critic Agent response. Consider updating your prompt.") return False, "" - if human_msg or system_msg is None: + if human_msg is None: + logger.warning(f"Failed to get human_msg or system_msg.") return False, "" critic = await self._aask(prompt=human_msg, system_msgs=system_msg) try: diff --git a/metagpt/actions/ui_design.py b/metagpt/actions/ui_design.py new file mode 100644 index 000000000..693ca904d --- /dev/null +++ b/metagpt/actions/ui_design.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# @Date : 2023/8/17 13:43 +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +from typing import List, Union + +from metagpt.tools.sd_engine import SDEngine + +from metagpt.actions.design import BaseModelAction +from metagpt.prompts.sd_design import MODEL_SELECTION_PROMPT + + +class SDPromptRanker(BaseModelAction): + """ + Class responsible for ranking multiple prompts based on current requirements and + the underlying model to determine the most suitable prompt. + + """ + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, description="Prompt ranker", *args, **kwargs) + + +class SDImgScorer(BaseModelAction): + """ + 根据多个SD的生成结果,进行美学评分,选出评分最高的图片 + Class responsible for aesthetically scoring multiple SD generated results and + selecting the highest scoring image. + + """ + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, description="Image Scorer", *args, **kwargs) + + +class LoraSelection(BaseModelAction): + """ + Class responsible for selecting the most suitable Lora based on the + current model and requirements. + """ + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, *args, **kwargs) + + +class ModelSelection(BaseModelAction): + DEFAULT_MODEL_INFO = { + "realisticVisionV30_v30VAE": "Real Effects, Real Photo/Photography, v3.0", + "pixelmix_v10": "an anime model merge with finetuned lineart and eyes." + } + + def __init__(self, name="ModelSelection", *args, **kwargs): + super().__init__(name, description="Select models", *args, **kwargs) + + def add_models(self, model_name="", model_desc=""): + updated_info = {model_name: model_desc} if model_name else {} + return {**self.DEFAULT_MODEL_INFO, **updated_info} + + async def run(self, query: str, system_text: str = "model selection"): + prompt = MODEL_SELECTION_PROMPT.format(query=query, model_info=self.add_models()) + resp = await self._aask(prompt=prompt, system_msgs=[system_text]) + result = resp.split("||") + model_name = result[0].replace("Model:", "").strip() + domain = result[-1].replace("Domain:", "").strip() + return model_name, domain + + +class SDGeneration(BaseModelAction): + """Generates an image via the sd t2i API.""" + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, description="Stable Diffusion Generator", *args, **kwargs) + self.engine = SDEngine() + self.negative_prompts = {"realisticVisionV30_v30VAE": "worst quality, low quality, easynegative", + "pixelmix_v10": ""} + + def _construct_prompt(self, query: str, model_name: str) -> str: + """Constructs a prompt for the provided query and model.""" + negative_prompt = self.negative_prompts.get(model_name, "") + return self.engine.construct_payload(query, negative_prompt=negative_prompt, sd_model=model_name) + + async def _generate_image(self, queries: List[str], model_name: str, img_name: str) -> None: + """Generates image(s) using the provided queries and model name.""" + prompts = [self._construct_prompt(query, model_name) for query in queries] + await self.engine.run_t2i(prompts, save_name=img_name) + + async def run(self, query: Union[str, List[str]], model_name: str, **kwargs) -> None: + """ + Generate image via sd t2i API. + """ + img_name = kwargs.get("image_name", "") + + queries = [query] if isinstance(query, str) else query + await self._generate_image(queries, model_name, img_name) diff --git a/metagpt/minecraft_team.py b/metagpt/minecraft_team.py index 5ead788ce..68e20ea89 100644 --- a/metagpt/minecraft_team.py +++ b/metagpt/minecraft_team.py @@ -25,7 +25,7 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True): """ 游戏环境的记忆,用于多个agent进行信息的共享和缓存,而不需要重复在自己的角色内维护缓存 """ - + event: dict[str, Any] = Field(default_factory=dict) current_task: str = Field(default="Mine 1 wood log") task_execution_time: float = Field(default=float) @@ -35,23 +35,24 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True): code: str = Field(default="") program_name: str = Field(default="") critique: str = Field(default="") - skills: dict = Field(default_factory=dict) # for skills.json - retrieve_skills: list[str] = Field(default_factory=list) + skills: dict = Field(default_factory=dict) # for skills.json + retrieve_skills: list[str] = Field(default_factory=list) event_summary: str = Field(default="") - + qa_cache: dict[str, str] = Field(default_factory=dict) completed_tasks: list[str] = Field(default_factory=list) # Critique things failed_tasks: list[str] = Field(default_factory=list) - + skill_desp: str = Field(default="") - + chest_memory: dict[str, Any] = Field( default_factory=dict ) # eg: {'(1344, 64, 1381)': 'Unknown'} chest_observation: str = Field(default="") # eg: "Chests: None\n\n" - + mf_instance: MineflayerEnv = Field(default_factory=MineflayerEnv) - + runtime_status: bool = False # equal to action execution status: success or failed + @property def progress(self): # return len(self.completed_tasks) + 10 # Test only @@ -61,30 +62,30 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True): def programs(self): programs = "" if self.code == "": - return programs # TODO: maybe fix 10054 now, a better way is isolating env.step() like voyager + return programs # TODO: maybe fix 10054 now, a better way is isolating env.step() like voyager for skill_name, entry in self.skills.items(): programs += f"{entry['code']}\n\n" for primitives in load_skills_code(): programs += f"{primitives}\n\n" - return programs - + return programs + @property def warm_up(self): return self.mf_instance.warm_up - + @property def core_inv_items_regex(self): return self.mf_instance.core_inv_items_regex - + def set_mc_port(self, mc_port): self.mf_instance.set_mc_port(mc_port) - + def set_mc_resume(self, resume: bool = False): # TODO: mv to config if resume: logger.info(f"Loading Action Developer from {CKPT_DIR}/action") with open(f"{CKPT_DIR}/action/chest_memory.json", "r") as f: self.chest_memory = json.load(f) - + logger.info(f"Loading Curriculum Agent from {CKPT_DIR}/curriculum") with open(f"{CKPT_DIR}/curriculum/completed_tasks.json", "r") as f: self.completed_tasks = json.load(f) @@ -92,46 +93,46 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True): self.failed_tasks = json.load(f) with open(f"{CKPT_DIR}/curriculum/qa_cache.json", "r") as f: self.qa_cache = json.load(f) - + logger.info(f"Loading Skill Manager from {CKPT_DIR}/skill\033[0m") with open(f"{CKPT_DIR}/skill/skills.json", "r") as f: self.skills = json.load(f) - + def register_roles(self, roles: Iterable[Minecraft]): for role in roles: role.set_memory(self) - + def update_event(self, event: Dict): if self.event == event: return self.event = event - self.update_chest_memory(event) - self.event_summary = self.summarize_chatlog(event) - + # self.update_chest_memory(event) + # self.event_summary = self.summarize_chatlog(event) + def update_task(self, task: str): self.current_task = task - + def update_context(self, context: str): self.context = context - + def update_code(self, code: str): self.code = code # action_developer.gen_action_code to HERE - + def update_program_name(self, program_name: str): self.program_name = program_name - + def update_critique(self, critique: str): self.critique = critique # critic_agent.check_task_success to HERE - + def append_skill(self, skill: dict): self.skills[self.program_name] = skill # skill_manager.retrieve_skills to HERE - + def update_retrieve_skills(self, retrieve_skills: list): self.retrieve_skills = retrieve_skills - + def update_skill_desp(self, skill_desp: str): self.skill_desp = skill_desp - + def update_chest_memory(self, events: Dict): """ Input: events: Dict @@ -151,13 +152,13 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True): self.chest_memory[position] = chest with open(f"{CKPT_DIR}/action/chest_memory.json", "w") as f: json.dump(self.chest_memory, f) - + def update_chest_observation(self): """ update chest_memory to chest_observation. Refer to @ https://github.com/MineDojo/Voyager/blob/main/voyager/agents/action.py """ - + chests = [] for chest_position, chest in self.chest_memory.items(): if isinstance(chest, dict) and len(chest) > 0: @@ -175,7 +176,7 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True): self.chest_observation = f"Chests:\n{chests}\n\n" else: self.chest_observation = f"Chests: None\n\n" - + def summarize_chatlog(self, events): def filter_item(message: str): craft_pattern = r"I cannot make \w+ because I need: (.*)" @@ -184,22 +185,25 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True): ) mine_pattern = r"I need at least a (.*) to mine \w+!" if re.match(craft_pattern, message): - return re.match(craft_pattern, message).groups()[0] + self.event_summary = re.match(craft_pattern, message).groups()[0] elif re.match(craft_pattern2, message): - return "a nearby crafting table" + self.event_summary = "a nearby crafting table" elif re.match(mine_pattern, message): - return re.match(mine_pattern, message).groups()[0] + self.event_summary = re.match(mine_pattern, message).groups()[0] else: - return "" - + self.event_summary = "" chatlog = set() for event_type, event in events: if event_type == "onChat": item = filter_item(event["onChat"]) if item: chatlog.add(item) - return "I also need " + ", ".join(chatlog) + "." if chatlog else "" - + self.event_summary = "I also need " + ", ".join(chatlog) + "." if chatlog else "" + + def reset_block_info(self): + # revert all the placing event in the last step + pass + def update_exploration_progress(self, success: bool): """ Split task into completed_tasks or failed_tasks @@ -209,6 +213,7 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True): "conversations": self.conversations, } """ + self.runtime_status = success task = self.current_task if task.startswith("Deposit useless items into the chest at"): return @@ -218,26 +223,25 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True): else: logger.info(f"Failed to complete task {task}. Skipping to next task.") self.failed_tasks.append(task) - # TODO: when not success, transform code below to update event!(isolate step soon!) - # if self.reset_placed_if_failed and not success: - # # revert all the placing event in the last step - # blocks = [] - # positions = [] - # for event_type, event in events: - # if event_type == "onSave" and event["onSave"].endswith("_placed"): - # block = event["onSave"].split("_placed")[0] - # position = event["status"]["position"] - # blocks.append(block) - # positions.append(position) - # new_events = self.env.step( - # f"await givePlacedItemBack(bot, {U.json_dumps(blocks)}, {U.json_dumps(positions)})", - # programs=self.skill_manager.programs, - # ) - # events[-1][1]["inventory"] = new_events[-1][1]["inventory"] - # events[-1][1]["voxels"] = new_events[-1][1]["voxels"] - + # when not success, below to update event! + # revert all the placing event in the last step + blocks = [] + positions = [] + for event_type, event in self.event: + if event_type == "onSave" and event["onSave"].endswith("_placed"): + block = event["onSave"].split("_placed")[0] + position = event["status"]["position"] + blocks.append(block) + positions.append(position) + new_events = self.mf_instance.step( + f"await givePlacedItemBack(bot, {json.dumps(blocks)}, {json.dumps(positions)})", + programs=self.programs, + ) + self.event[-1][1]["inventory"] = new_events[-1][1]["inventory"] + self.event[-1][1]["voxels"] = new_events[-1][1]["voxels"] + self.save_sorted_tasks() - + def save_sorted_tasks(self): updated_completed_tasks = [] # record repeated failed tasks @@ -246,25 +250,55 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True): for task in self.completed_tasks: if task not in updated_completed_tasks: updated_completed_tasks.append(task) - + # remove completed tasks from failed tasks for task in updated_completed_tasks: while task in updated_failed_tasks: updated_failed_tasks.remove(task) - + self.completed_tasks = updated_completed_tasks self.failed_tasks = updated_failed_tasks - + # dump to json with open(f"{CKPT_DIR}/curriculum/completed_tasks.json", "w") as f: json.dump(self.completed_tasks, f) with open(f"{CKPT_DIR}/curriculum/failed_tasks.json", "w") as f: json.dump(self.failed_tasks, f) - - async def on_event(self, *args): + + async def on_event_retrieve(self, *args): """ Retrieve Minecraft events. + Returns: + list: A list of Minecraft events. + + Raises: + Exception: If there is an issue retrieving events. + """ + try: + self.mf_instance.reset( + options={ + "mode": "soft", + "wait_ticks": 20, + } + ) + difficulty = ( + "easy" if len(self.completed_tasks) > 15 else "peaceful" + ) + events = self.mf_instance.step( + "bot.chat(`/time set ${getNextTime()}`);\n" + + f"bot.chat('/difficulty {difficulty}');" + ) + self.update_event(events) + return events + except Exception as e: + logger.error(f"Failed to retrieve Minecraft events: {str(e)}") + raise {} + + async def on_event_execute(self, *args): + """ + Execute Minecraft events. + This function is used to obtain events from the Minecraft environment. Check the implementation in the 'voyager/env/bridge.py step()' function to capture events generated within the game. @@ -275,37 +309,14 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True): Exception: If there is an issue retrieving events. """ try: - if not self.mf_instance.has_reset: - # TODO Modify - logger.info("Environment has not been reset yet, is resetting") - self.mf_instance.reset( - options={ - "mode": "soft", - "wait_ticks": 20, - } - ) - # raise {} - self.mf_instance.check_process() - self.mf_instance.unpause() - data = { - "code": self.code, - "programs": self.programs, - } - res = requests.post( - f"{self.mf_instance.server}/step", - json=data, - timeout=self.mf_instance.request_timeout, + events = self.mf_instance.step( + code = self.code, + programs=self.programs, ) - if res.status_code != 200: - logger.error("Failed to step Minecraft server") - raise {} - returned_data = res.json() - self.mf_instance.pause() - events = json.loads(returned_data) - logger.info(f"Get Current Event: {events}") + self.update_event(events) return events except Exception as e: - logger.error(f"Failed to retrieve Minecraft events: {str(e)}") + logger.error(f"Failed to execute Minecraft events: {str(e)}") raise {} @@ -314,16 +325,16 @@ class MinecraftPlayer(SoftwareCompany): Software Company: Possesses a team, SOP (Standard Operating Procedures), and a platform for instant messaging, dedicated to writing executable code. """ - + environment: Environment = Field(default_factory=Environment) game_memory: GameEnvironment = Field(default_factory=GameEnvironment) investment: float = Field(default=50.0) task: str = Field(default="") game_info: dict = Field(default={}) - + def set_port(self, mc_port): self.game_memory.set_mc_port(mc_port) - + def set_resume(self, resume: bool = False): self.game_memory.set_mc_resume(resume=resume) @@ -332,9 +343,9 @@ class MinecraftPlayer(SoftwareCompany): for role in self.environment.roles.values(): status = role.finish_step complete_round.append(status) - #if not status: + # if not status: # return complete_round - #complete_round = True + # complete_round = True complete_round_tag = all(complete_round) logger.info(f"complete_round {complete_round}") return complete_round_tag @@ -342,14 +353,14 @@ class MinecraftPlayer(SoftwareCompany): def update_round(self): for role in self.environment.roles.values(): role.finish_step = False - role.round_id+=1 + role.round_id += 1 role._rc.todo = None logger.info(f"round_id:{role.round_id}") - + def hire(self, roles: list[Role]): self.environment.add_roles(roles) self.game_memory.register_roles(roles) - + def start(self, task, round=0): """Start a project from publishing boss requirement.""" self.task = task @@ -357,30 +368,42 @@ class MinecraftPlayer(SoftwareCompany): Message(role="Player", content=task, cause_by=PlayerActions, round_id=round) ) logger.info(self.game_info) - + def _save(self): logger.info(self.json()) def _reset(self): for role_profile, role in self.environment.roles.items(): role.reset_state() - + async def run(self, n_round=3): """Run company until target round or no money""" - round_id=0 + round_id = 0 + self.game_memory.mf_instance.reset( + options={ + "mode": "soft", + "wait_ticks": 20, + } + ) + events = self.game_memory.mf_instance.step( + code="", + programs="", + ) + self.game_memory.update_event(events) + while n_round > 0: # self._save() if self.check_complete_round(): n_round -= 1 self.update_round() - round_id+=1 + round_id += 1 # add new task into env and continue - #fixme: update self.task + # fixme: update self.task self.start(task=self.task, round=round_id) - + logger.info(f"{n_round=}") self._check_balance() await self.environment.run() - #self.environment.memory.clear() - #self._reset() + # self.environment.memory.clear() + # self._reset() return self.environment.history diff --git a/metagpt/mineflayer_environment.py b/metagpt/mineflayer_environment.py index f01e10e50..6bec20fe7 100644 --- a/metagpt/mineflayer_environment.py +++ b/metagpt/mineflayer_environment.py @@ -136,6 +136,24 @@ class MineflayerEnv: self.pause() return json.loads(returned_data) + def step(self, code: str, programs: str = ""): + if not self.has_reset: + raise RuntimeError("Environment has not been reset yet") + self.check_process() + self.unpause() + data = { + "code": code, + "programs": programs, + } + res = requests.post( + f"{self.server}/step", json=data, timeout=self.request_timeout + ) + if res.status_code != 200: + raise RuntimeError("Failed to step Minecraft server") + returned_data = res.json() + self.pause() + return json.loads(returned_data) + def close(self): self.unpause() if self.connected: diff --git a/metagpt/prompts/sd_design.py b/metagpt/prompts/sd_design.py new file mode 100644 index 000000000..8719f1a5a --- /dev/null +++ b/metagpt/prompts/sd_design.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/18 09:51 +@Author : stellahong +@File : __init__.py +""" + +MODEL_SELECTION_PROMPT = """Please help me find a suitable model for painting in this scene. +Model list will be given in the format like: +''' +model_name: model desc, +''' + +you should select the model and tell me the model name. answer it in the form like Model: model_name || Domain:xxx + +### +Model List: +{model_info} + +My scene is: {query} +""" + +DOMAIN_JUDGEMENT_TEMPLATE = ''' +use model {model_name}, decide the domain, answer it in the form like Domain: xxx + +### +Model Information: +{model_info} + +''' + +MODEL_SELECTION_OUTPUT_MAPPING = { + "Model:": (str, ...), } + +SD_PROMPT_KW_OPTIMIZE_TEMPLATE = ''' +I want you to act as a prompt generator. Compose each answer as a visual sentence. Do not write explanations on replies. Format the answers as javascript json arrays with a single string per answer. Return exactly {answer_count} to my question. Answer the questions exactly, in the form like responses:xxx. Answer the following question: + +Find 3 keywords related to the prompt "{messages}" that are not found in the prompt. The keywords should be related to each other. Each keyword is a single word. + +''' + +SD_PROMPT_IMPROVE_OPTIMIZE_TEMPLATE = ''' +I want you to act as a prompt generator. Compose each answer as a visual sentence. Do not write explanations on replies. Format the answers as javascript json arrays with a single string per answer. Return exactly {answer_count} to my question. Answer the questions exactly, in the form like responses:xxx. Answer the following question: + +domain is {domain} + +if domain is anime or game like, Take the prompt "{messages}, Cute kawaii sticker , white background, vector, pastel colors" and improve it. + +if domain is realistic like, Take the prompt "{messages}" and improve it. + +''' +# Die-cut sticker, illustration minimalism, + +FORMAT_INSTRUCTIONS = """The problem is to make the user input a better text2image prompt, the input is {query}" + + Let's first understand the problem and devise a plan to solve the problem. + + Based on the text2image model selected {model_name} and domain {domain} + You have access to the following tools: + + {tool_names} + {tool_description} + + Use a json blob to specify a tool by providing an action key (tool name) and an Observation (tool description). + + Valid "action" values: {tool_names} + + Provide only ONE action per $JSON_BLOB, as shown: + + ``` + {{{{ + "action": $TOOL_NAME, + "Observation": $TOOL_DESCRIPTION + }}}} + ``` + + Follow this format: + + ## Think Chain + ``` + Question: input question to answer + Thought: select a better method for the input by go through these two tools and its observations respectively + Action1: + ``` + $JSON_BLOB + ``` + Action2: + ``` + $JSON_BLOB + ``` + + Thought:When evaluating a prompt's richness, I need to specify which tool to use and I can only select one tool . To finish this selection, in the form: + ## Final Action: + TOOL_NAME + + """ + +PROMPT_OUTPUT_MAPPING = { + "Final Action:": (str, ...), +} diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 7e865f288..303b1bbf7 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -187,7 +187,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): "max_tokens": self.get_max_tokens(messages), "n": 1, "stop": None, - "temperature": 0.3, + "temperature": 0.0, "timeout": 3, } if CONFIG.openai_api_type == "azure": diff --git a/metagpt/roles/minecraft/action_developer.py b/metagpt/roles/minecraft/action_developer.py index 9171e455b..4f585ea26 100644 --- a/metagpt/roles/minecraft/action_developer.py +++ b/metagpt/roles/minecraft/action_developer.py @@ -2,6 +2,8 @@ # @Date : 2023/9/23 12:45 # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : +import copy + from metagpt.logs import logger from metagpt.roles.minecraft.minecraft_base import Minecraft as Base from metagpt.schema import Message, HumanMessage, SystemMessage @@ -12,11 +14,14 @@ from metagpt.actions.minecraft.manage_skills import ( RetrieveSkills, AddNewSkills, ) +from metagpt.actions.minecraft.review_task import VerifyTask import metagpt.utils.minecraft as utils from metagpt.config import CONFIG from metagpt.actions.minecraft.control_primitives_context import ( load_skills_code_context, ) +from metagpt.utils.minecraft import fix_and_parse_json +from metagpt.roles.minecraft.critic_agent import CriticReviewer @agent_registry.register("action_developer") @@ -25,28 +30,32 @@ class ActionDeveloper(Base): iterative prompting mechanism in paper. generate action code based on environment observation and plan, as well as skills retrieval results """ - + def __init__( - self, - name: str = "Bob", - profile: str = "Generate code for specified tasks", - goal: str = "Produce accurate and efficient code solutions in Python and JavaScript", - constraints: str = "Adhere to coding best practices and style guidelines", + self, + name: str = "Bob", + profile: str = "Generate code for specified tasks", + goal: str = "Produce accurate and efficient code solutions in Python and JavaScript", + constraints: str = "Adhere to coding best practices and style guidelines", ) -> None: super().__init__(name, profile, goal, constraints) # Initialize actions specific to the Action role self._init_actions([GenerateActionCode]) - + # Set events or actions the ActionAgent should watch or be aware of # 需要根据events进行自己chest_observation的更新 self._watch([RetrieveSkills]) - + self.rollout_num_iter = 0 + self.task_max_retries = 4 + self.critic_reviewer = None # self._rc.env.roles["Task Reviewer"] + logger.info(self.critic_reviewer) + def render_system_message(self, skills=[], *args, **kwargs): """ According to basic skills context files to genenarate js skill codes. Refer to @ https://github.com/MineDojo/Voyager/blob/main/voyager/agents/action.py """ - + action_template = utils.load_prompt("action_template") base_skills = [ "exploreUntil", @@ -69,21 +78,21 @@ class ActionDeveloper(Base): system_action_message = SystemMessage(content=system_action_prompt) assert isinstance(system_action_message, SystemMessage) return system_action_message - + def render_human_message( - self, events, code="", task="", context="", critique="", *args, **kwargs + self, events, code="", task="", context="", critique="", *args, **kwargs ): """ Integrate observation about the environment(especially events), add to HumanMessage. Refer to @ https://github.com/MineDojo/Voyager/blob/main/voyager/agents/action.py """ - + # Deal with events info chat_messages = [] error_messages = [] # damage_messages = [] # TODO: try to add damage_messages into prompt later assert events[-1][0] == "observe", "Last event must be observe" - + for i, (event_type, event) in enumerate(events): if event_type == "onChat": chat_messages.append(event["onChat"]) @@ -101,30 +110,30 @@ class ActionDeveloper(Base): inventory_used = event["status"]["inventoryUsed"] inventory = event["inventory"] assert i == len(events) - 1, "observe must be the last event" - + # Collect all the environment information into a str: observation observation = "" - + observation = ( f"Code from the last round:\n{code or 'No code in the first round'}\n\n" ) - + if error_messages: error = "\n".join(error_messages) observation += f"Execution error:\n{error}\n\n" else: observation += f"Execution error: No error\n\n" - + if chat_messages: chat_log = "\n".join(chat_messages) observation += f"Chat log: {chat_log}\n\n" else: observation += f"Chat log: None\n\n" - + observation += f"Biome: {biome}\n\n" observation += f"Time: {time_of_day}\n\n" observation += f"Nearby blocks: {', '.join(voxels) if voxels else 'None'}\n\n" - + if entities: nearby_entities = [ k for k, v in sorted(entities.items(), key=lambda x: x[1]) @@ -132,35 +141,35 @@ class ActionDeveloper(Base): observation += f"Nearby entities (nearest to farthest): {', '.join(nearby_entities)}\n\n" else: observation += f"Nearby entities (nearest to farthest): None\n\n" - + observation += f"Health: {health:.1f}/20\n\n" observation += f"Hunger: {hunger:.1f}/20\n\n" observation += f"Position: x={position['x']:.1f}, y={position['y']:.1f}, z={position['z']:.1f}\n\n" observation += f"Equipment: {equipment}\n\n" observation += f"Inventory ({inventory_used}/36): {'Empty' if not inventory else ', '.join(inventory)}\n\n" - + if not ( - task == "Place and deposit useless items into a chest" - or task.startswith("Deposit useless items into the chest at") + task == "Place and deposit useless items into a chest" + or task.startswith("Deposit useless items into the chest at") ): observation += self.game_memory.chest_observation - + observation += f"Task: {task}\n\n" observation += f"Context: {context or 'None'}\n\n" observation += f"Critique: {critique or 'None'}\n\n" - + return HumanMessage(content=observation) - + def encapsule_message( - self, - events, - code="", - task="", - context="", - critique="", - skills=[], - *args, - **kwargs, + self, + events, + code="", + task="", + context="", + critique="", + skills=[], + *args, + **kwargs, ): system_message = self.render_system_message(skills=skills) human_message = self.render_human_message( @@ -170,7 +179,7 @@ class ActionDeveloper(Base): "system_msg": [system_message.content], "human_msg": human_message.content, } - + async def _observe(self) -> int: await super()._observe() for msg in self._rc.news: @@ -180,7 +189,117 @@ class ActionDeveloper(Base): ] # only relevant msgs count as observed news logger.info(len(self._rc.news)) return len(self._rc.news) + + async def run_step(self, human_msg, system_msg, *args, **kwargs): + while True: + logger.info(f"self.rollout_num_iter {self.rollout_num_iter}") + system_msg, human_msg, reward, done, info = await self.runcode_and_evaluate(human_msg, system_msg, *args, + **kwargs) + if done: + break + # return [system_msg, human_msg], reward, done, info + return Message( + content=f"{info}", + instruct_content="generate_action_code", + role=self.profile, + ) + + async def handle_add_new_skills( + self, task, program_name, program_code, skills, *args, **kwargs + ): + skill_desp = self.game_memory.skill_desp + new_skills_info = await AddNewSkills().run( + task, program_name, program_code, skills, skill_desp + ) + # update skills in game memory + self.perform_game_info_callback(new_skills_info, self.game_memory.append_skill) + + async def retrieve_skills(self, query, skills, *args, **kwargs): + retrieve_skills = await RetrieveSkills().run(query, skills) + logger.info(f"Render Action Agent system message with {len(retrieve_skills)} skills") + self.perform_game_info_callback(retrieve_skills, self.game_memory.update_retrieve_skills) + # return Message(content=f"{retrieve_skills}", instruct_content="retrieve_skills", + # role=self.profile, send_to=agent_registry.entries["action_developer"]()._setting.name) + + async def runcode_and_evaluate(self, human_msg, system_msg, *args, **kwargs): + """ + equal to step() in voyager + """ + task = self.game_memory.current_task + context = self.game_memory.context + + # 更新生成的代码和对应程序名称 + code, program_name = await GenerateActionCode().run( + human_msg, system_msg, *args, **kwargs + ) + # logger.warning(type(code)) + # logger.info(f"Code is Here:{code}") + + if code is not None: + # fixme:若有独立的mc code执行入口函数,使用独立的函数 + events = await self._execute_events() + # 注意:这里的events对应是执行了新的action函数之后的events信息 + # 更新了评估结果, 回调了最新的环境信息到ga + self.critic_reviewer = self._rc.env.roles["Task Reviewer"] + await self.critic_reviewer._act() # todo: critic act内的update event放在这里似乎更合理? + + critique = self.game_memory.critique + self.perform_game_info_callback(self.game_memory.event, self.game_memory.summarize_chatlog) + + event_summary = self.game_memory.event_summary + skills = self.game_memory.skills + + if not self.game_memory.runtime_status: + # todo: callback game memory reset block info + logger.info("Not success, reset block info !") + logger.info( + f"\033[32m****Action Agent human message****\n{human_msg}\033[0m" + ) + + # add new skills no matter success or not + # add_new_skills_message = { + # "task": task, + # "program_name": program_name, + # "program_code": code, + # "skills": self.game_memory.skills, + # } + new_skill_info = {"query": context + "\n\n" + event_summary, "skills": skills} + + # await self.handle_add_new_skills(**add_new_skills_message) + await self.retrieve_skills(**new_skill_info) + retrieve_skills = self.game_memory.retrieve_skills + + message = self.encapsule_message( + events=events, + code=code, + task=task, + context=context, + critique=critique, + skills=retrieve_skills, + ) + + system_msg = message["system_msg"] + human_msg = message["human_msg"] + else: + self.critic_reviewer.maintain_actions(VerifyTask()) + logger.info(f"system msg is {system_msg}, \n human_msg is {human_msg}") + logger.info(f"\033[34m Trying again!\033[0m") + + self.rollout_num_iter += 1 + done = (self.rollout_num_iter >= self.task_max_retries or self.game_memory.runtime_status) + info = { + "task": self.game_memory.current_task, + "success": self.game_memory.runtime_status, + } + logger.info(f"info is {info}") + self.perform_game_info_callback(code, self.game_memory.update_code) + self.perform_game_info_callback( + program_name, self.game_memory.update_program_name + ) + + return system_msg, human_msg, 0, done, info + async def generate_action_code(self, human_msg, system_msg, *args, **kwargs): code, program_name = await GenerateActionCode().run( human_msg, system_msg, *args, **kwargs @@ -196,22 +315,27 @@ class ActionDeveloper(Base): instruct_content="generate_action_code", role=self.profile, ) - # logger.info(msg) + return msg - + async def _act(self) -> Message: todo = self._rc.todo logger.debug(f"Todo is {todo}") self.maintain_actions(todo) + # 获取最新的游戏周边信息 - events = await self._obtain_events() - self.perform_game_info_callback(events, self.game_memory.update_event) + # events = await self._obtain_events() + events = self.game_memory.event + logger.info(events) + # self.perform_game_info_callback(events, self.game_memory.update_event) + logger.info(self.game_memory.event_summary) context = self.game_memory.context task = self.game_memory.current_task code = self.game_memory.code critique = self.game_memory.critique retrieve_skills = self.game_memory.retrieve_skills - + + # 对自己所需的环境信息进行处理 message = self.encapsule_message( events=events, code=code, @@ -222,17 +346,18 @@ class ActionDeveloper(Base): ) logger.info(todo) handler_map = { - GenerateActionCode: self.generate_action_code, + GenerateActionCode: self.run_step # self.generate_action_code, } handler = handler_map.get(type(todo)) logger.info(handler) - + if handler: msg = await handler(**message) - msg.cause_by = type(todo) + msg.cause_by = GenerateActionCode msg.round_id = self.round_id logger.info(msg.send_to) + self.rollout_num_iter = 0 self._publish_message(msg) return msg - + raise ValueError(f"Unknown todo type: {type(todo)}") diff --git a/metagpt/roles/minecraft/critic_agent.py b/metagpt/roles/minecraft/critic_agent.py index 7bb90767a..3bf632909 100644 --- a/metagpt/roles/minecraft/critic_agent.py +++ b/metagpt/roles/minecraft/critic_agent.py @@ -28,11 +28,27 @@ class CriticReviewer(Base): ) -> None: super().__init__(name, profile, goal, constraints) # Initialize actions specific to the CriticReviewer role + # self._init_actions([VerifyTask]) self._init_actions([VerifyTask]) # Set events or actions the CriticReviewer should watch or be aware of # 需要获取最新的events来进行评估 - self._watch([GenerateActionCode, AddNewSkills]) + self._watch([]) + + async def run(self, message=None): + """Observe, only get the observation""" + if message: + if isinstance(message, str): + message = Message(message) + if isinstance(message, Message): + self.recv(message) + if isinstance(message, list): + self.recv(Message("\n".join(message))) + elif not await self._observe(): + # If there is no new information, suspend and wait + logger.info(f"{self._setting}: no news. waiting.") + return + self._rc.todo = VerifyTask def render_system_message(self): system_message = SystemMessage(content=load_prompt("critic")) @@ -119,6 +135,9 @@ class CriticReviewer(Base): self.perform_game_info_callback( success, self.game_memory.update_exploration_progress ) + self.perform_game_info_callback( + critique, self.game_memory.update_critique + ) return Message( content=f"{critique}", instruct_content="verify_task", @@ -126,16 +145,19 @@ class CriticReviewer(Base): send_to=agent_registry.entries["skill_manager"]()._setting.name, ) # addnewskill # TODO:if not success + async def _act(self) -> Message: + self._rc.todo = VerifyTask() todo = self._rc.todo + logger.debug(f"Todo is {todo}") + self.maintain_actions(todo) # 获取最新的游戏周边信息 - events = await self._obtain_events() - self.perform_game_info_callback( - events, self.game_memory.update_event - ) # update chest_memory / chest observation + events = await self._execute_events() + self.perform_game_info_callback(events, self.game_memory.update_chest_memory) + logger.info(f"Execute return event is {self.game_memory.event}") context = self.game_memory.context task = self.game_memory.current_task chest_observation = self.game_memory.chest_observation diff --git a/metagpt/roles/minecraft/curriculum_agent.py b/metagpt/roles/minecraft/curriculum_agent.py index 602792d06..68e394786 100644 --- a/metagpt/roles/minecraft/curriculum_agent.py +++ b/metagpt/roles/minecraft/curriculum_agent.py @@ -313,11 +313,11 @@ class CurriculumDesigner(Base): logger.debug(f"Todo is {todo}") self.maintain_actions(todo) # 获取最新的游戏周边环境信息 - events = await self._obtain_events() - self.perform_game_info_callback(events, self.game_memory.update_event) + # events = await self._obtain_events() + events = self.game_memory.event chest_observation = self.game_memory.chest_observation - DesignCurriculum.set_qa_cache(self.game_memory.qa_cache) + # DesignCurriculum.set_qa_cache(self.game_memory.qa_cache) # msg = self._rc.memory.get(k=1)[0] # query = msg.content @@ -335,7 +335,7 @@ class CurriculumDesigner(Base): } handler = handler_map.get(type(todo)) if handler: - if type(todo) == "DesignTask": + if type(todo) == DesignTask: msg = await handler(**design_task_message) else: msg = await handler(**design_curriculum_message) diff --git a/metagpt/roles/minecraft/minecraft_base.py b/metagpt/roles/minecraft/minecraft_base.py index 47852b4fb..dbc3c10a9 100644 --- a/metagpt/roles/minecraft/minecraft_base.py +++ b/metagpt/roles/minecraft/minecraft_base.py @@ -103,7 +103,10 @@ class Minecraft(Role): self._rc.todo = None async def _obtain_events(self): - return await self.game_memory.on_event() + return await self.game_memory.on_event_retrieve() + + async def _execute_events(self): + return await self.game_memory.on_event_execute() def set_memory(self, shared_memory: 'GameEnviroment'): self.game_memory = shared_memory @@ -116,7 +119,7 @@ class Minecraft(Role): @staticmethod def perform_game_info_callback(info: object, callback: object) -> object: - logger.info(info) + # logger.info(info) callback(info) def encapsule_message(self, msg, *args, **kwargs): @@ -130,5 +133,5 @@ agent_registry = Registry(name="Minecraft") if __name__ == "__main__": mc = Minecraft() result = "Async operation result" - # 调用回调函数,并传递结果 + # ûصݽ # mc.perform_memory_callback(mc.my_callback) diff --git a/metagpt/roles/minecraft/skill_manager.py b/metagpt/roles/minecraft/skill_manager.py index 4dddf0ab1..161ec08ae 100644 --- a/metagpt/roles/minecraft/skill_manager.py +++ b/metagpt/roles/minecraft/skill_manager.py @@ -28,7 +28,7 @@ class SkillManager(Base): super().__init__(name, profile, goal, constraints) # Initialize actions specific to the SkillManager role - self._init_actions([RetrieveSkills, GenerateSkillDescription]) #AddNewSkills])#先去掉add + self._init_actions([RetrieveSkills, GenerateSkillDescription, AddNewSkills]) #AddNewSkills])#先去掉add # Set events or actions the SkillManager should watch or be aware of self._watch( @@ -36,8 +36,8 @@ class SkillManager(Base): ) def encapsule_message(self, program_code, program_name, *args, **kwargs): - human_msg = self.render_system_message(load_prompt("skill")) - system_msg = self.render_human_message( + system_msg = self.render_system_message(load_prompt("skill")) + human_msg = self.render_human_message( program_code + "\n\n" + f"The main function is `{program_name}`." ) return {"system_msg": [system_msg.content], "human_msg": human_msg.content} @@ -65,6 +65,13 @@ class SkillManager(Base): async def handle_add_new_skills( self, task, program_name, program_code, skills, *args, **kwargs ): + if not self.game_memory.runtime_status: + return Message( + content="", + instruct_content="handle_add_new_skills", + role=self.profile, + ) + skill_desp = self.game_memory.skill_desp new_skills_info = await AddNewSkills().run( task, program_name, program_code, skills, skill_desp @@ -83,8 +90,10 @@ class SkillManager(Base): # 获取最新的游戏周边信息 context = self.game_memory.context task = self.game_memory.current_task - event_summary = self.game_memory.event_summary + code = self.game_memory.code + self.perform_game_info_callback(self.game_memory.event, self.game_memory.summarize_chatlog) + event_summary = self.game_memory.event_summary try: program_code = code["program_code"] # TODO: Handle code is None, cuz first round DesignCurriculum(code is None) trigger this except (KeyError, TypeError): @@ -96,7 +105,9 @@ class SkillManager(Base): # msg = self._rc.memory.get(k=1)[0] retrieve_skills_message_step1 = {"query": context, "skills": skills} - + logger.info(f"check query {context}") + logger.info(f"check event summary {event_summary}") + retrieve_skills_message_step2 = {"query": context + "\n\n" + event_summary, "skills": skills} generate_skill_message = self.encapsule_message(program_code, program_name) diff --git a/metagpt/roles/ui_designer.py b/metagpt/roles/ui_designer.py new file mode 100644 index 000000000..05b906c16 --- /dev/null +++ b/metagpt/roles/ui_designer.py @@ -0,0 +1,205 @@ +# -*- coding: utf-8 -*- +# @Date : 2023/8/16 13:58 +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +from functools import wraps +import json5 + +from metagpt.logs import logger +from metagpt.roles import Role +from metagpt.schema import Message + +from metagpt.actions.design import Tool, SDPromptExtend, SDPromptOptimize, SDPromptImprove +from metagpt.actions.ui_design import ModelSelection, SDGeneration + + +def retrieve(func): + @wraps(func) + def wrapper(*args, **kwargs): + content, keyword = func(*args, **kwargs) + info = content.replace(keyword, "") + return info + + return wrapper + + +class Designer(Role): + """Class representing the UI designer Role.""" + + def __init__( + self, + name="Catherine", + profile="UI Design", + goal="Generate UI icon", + constraints="Give clear icon description and generate images to finish the design", + actions=[ModelSelection, SDPromptExtend, SDGeneration]): + super().__init__(name, profile, goal, constraints) + + self._init_actions(actions) + + @property + def memory_model_name(self): + return "MODEL_NAME: " + + @property + def memory_user_input(self): + return "User Input: " + + @property + def memory_domain(self): + return "Domain: " + + def memory_property(self, memory_keyword: str, memory_content: str): + self._rc.memory.add(Message(f"{memory_keyword}{memory_content}", role=self.profile)) + + @retrieve + def get_important_memory(self, keyword: str): + query_memory = self._rc.memory.get_by_content(keyword)[0] + return query_memory.content, keyword + + async def _plan_and_select(self): + """ + 这里实现的是二选一的option,action在这里进行了选择 + 理论上应该可以实现4种选择 (&:表示串行顺序),目前只选择了前2种 + 1) action1 + 2) action2 + 3) action1 & action2 + 4) action2 & action1 + """ + msg = self._rc.memory.get(k=1)[0] + query = msg.content + logger.info(query) + if query == "PromptImprove": + self._actions.insert(self._rc.state + 1, SDPromptImprove()) + elif query == "PromptOptimize": + self._actions.insert(self._rc.state + 1, SDPromptOptimize()) + return self._rc.state + + async def _think(self) -> None: + logger.info(self._rc.state) + if self._rc.todo is None: + self._set_state(0) + return + + if self._rc.state == 1: + await self._plan_and_select() + self._set_state(self._rc.state + 1) + + elif self._rc.state + 1 < len(self._actions): + self._set_state(self._rc.state + 1) + else: + self._rc.todo = None + + async def handle_model_selection(self, query, **kwargs): + ms = ModelSelection() + model_name, domain = await ms.run(query) + logger.info(f"{model_name}, {domain}") + + self.memory_property(self.memory_user_input, query) + self.memory_property(self.memory_model_name, model_name) + self.memory_property(self.memory_domain, domain) + return f"{model_name}||{domain}" + + async def handle_sd_prompt_extend(self, *args, **kwargs): + tools = [ + Tool(name="PromptOptimize", + func=SDPromptOptimize().run, + description="Find 3 keywords related to the prompt that are not found in the prompt. The keywords should be related to each other. Each keyword is a single word. useful for when you need to add extra keywords for input prompt, specially for long enough input"), + + Tool(name="PromptImprove", + func=SDPromptImprove().run, + description="Take the prompt and improve it. useful for when you need to add improve and extend the prompt for input prompt, specially for short input"), + + ] + + query = self.get_important_memory(self.memory_user_input) + domain = self.get_important_memory(self.memory_domain) + sd_exd = SDPromptExtend(tools=tools) + resp = await sd_exd.run(query=query, domain=domain, answer_count=1) + return resp + + async def handle_sd_prompt_improve(self, *args, **kwargs): + query = self.get_important_memory(self.memory_user_input) + domain = self.get_important_memory(self.memory_domain) + sd_pi = SDPromptImprove() + resp = await sd_pi.run(query=query, domain=domain, answer_count=1) + return resp + + async def handle_sd_prompt_optimize(self, *args, **kwargs): + query = self.get_important_memory(self.memory_user_input) + domain = self.get_important_memory(self.memory_domain) + sd_op = SDPromptOptimize() + resp = await sd_op.run(query=query, domain=domain, answer_count=1) + return resp + + async def handle_sd_generation(self, *args, **kwargs): + msg = self._rc.memory.get_by_action(SDPromptImprove)[0] + image_name = self.get_important_memory(self.memory_user_input) + logger.info(type(msg.content)) + logger.info(msg.content) + resp = json5.loads(msg.content) + logger.info(resp) + model_name = self.get_important_memory(self.memory_model_name) + await SDGeneration().run(query=resp, model_name=model_name, **{"image_name":image_name}) + return resp + + async def _act(self) -> Message: + logger.info(f"{self._setting}: ready to {self._rc.todo}") + todo = self._rc.todo + msg = self._rc.memory.get(k=1)[0] + query = msg.content + logger.info(msg.cause_by) + logger.info(query) + logger.info(todo) + + handler_map = { + ModelSelection: self.handle_model_selection, + SDPromptExtend: self.handle_sd_prompt_extend, + + SDPromptImprove: self.handle_sd_prompt_improve, + SDPromptOptimize: self.handle_sd_prompt_optimize, + + SDGeneration: self.handle_sd_generation, + } + + handler = handler_map.get(type(todo)) + if handler: + resp = await handler(query) + if type(todo) in [SDPromptImprove, SDPromptOptimize]: + ret = Message(f"{resp}", role=self.profile, cause_by=SDPromptImprove) + else: + ret = Message(f"{resp}", role=self.profile, cause_by=type(todo)) + self._rc.memory.add(ret) + return ret + + raise ValueError(f"Unknown todo type: {type(todo)}") + + async def _react(self) -> Message: + while True: + await self._think() + if self._rc.todo is None: + break + + msg = await self._act() + return msg + + +if __name__ == "__main__": + import asyncio + import platform + test_queries = ["Flappy Bird", + "Clash of Clans", + "Subway Surfers", + "Pokémon Go", + "Super Mario", + "Tetris", + "Call of Duty" + ] + + for prompt in test_queries: + + designer = Designer() + if platform.system() == "Windows": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + asyncio.run(designer.run(prompt)) + \ No newline at end of file diff --git a/metagpt/tools/sd_engine.py b/metagpt/tools/sd_engine.py index 1d9cd0b2a..2e34be9de 100644 --- a/metagpt/tools/sd_engine.py +++ b/metagpt/tools/sd_engine.py @@ -27,7 +27,7 @@ payload = { "batch_size": 1, "n_iter": 1, "steps": 20, - "cfg_scale": 7, + "cfg_scale": 9, "width": 512, "height": 768, "restore_faces": False, @@ -62,52 +62,54 @@ class SDEngine: # Define default payload settings for SD API self.payload = payload logger.info(self.sd_t2i_url) - + def construct_payload( - self, - prompt, - negtive_prompt=default_negative_prompt, - width=512, - height=512, - sd_model="galaxytimemachinesGTM_photoV20", + self, + prompt, + negative_prompt=default_negative_prompt, + width=512, + height=512, + sd_model="galaxytimemachinesGTM_photoV20", + **kwargs ): # Configure the payload with provided inputs self.payload["prompt"] = prompt - self.payload["negtive_prompt"] = negtive_prompt + self.payload["negative_prompt"] = negative_prompt self.payload["width"] = width self.payload["height"] = height self.payload["override_settings"]["sd_model_checkpoint"] = sd_model + self.payload.update(**kwargs) logger.info(f"call sd payload is {self.payload}") return self.payload - + def _save(self, imgs, save_name=""): save_dir = WORKSPACE_ROOT / "resources" / "SD_Output" - if not os.path.exists(save_dir): + if not save_dir.exists(): os.makedirs(save_dir, exist_ok=True) batch_decode_base64_to_image(imgs, save_dir, save_name=save_name) - - async def run_t2i(self, prompts: List): + + async def run_t2i(self, prompts: List, save_name=""): # Asynchronously run the SD API for multiple prompts session = ClientSession() for payload_idx, payload in enumerate(prompts): results = await self.run(url=self.sd_t2i_url, payload=payload, session=session) - self._save(results, save_name=f"output_{payload_idx}") + self._save(results, save_name=f"{save_name}_output_{payload_idx}") await session.close() - + async def run(self, url, payload, session): # Perform the HTTP POST request to the SD API async with session.post(url, json=payload, timeout=600) as rsp: data = await rsp.read() - + rsp_json = json.loads(data) imgs = rsp_json["images"] logger.info(f"callback rsp json is {rsp_json.keys()}") return imgs - + async def run_i2i(self): # todo: 添加图生图接口调用 raise NotImplementedError - + async def run_sam(self): # todo:添加SAM接口调用 raise NotImplementedError @@ -128,8 +130,8 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""): if __name__ == "__main__": engine = SDEngine() prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary" - + engine.construct_payload(prompt) - + event_loop = asyncio.get_event_loop() event_loop.run_until_complete(engine.run_t2i(prompt)) diff --git a/metagpt/utils/__init__.py b/metagpt/utils/__init__.py index f13175cf8..a7535383a 100644 --- a/metagpt/utils/__init__.py +++ b/metagpt/utils/__init__.py @@ -6,7 +6,7 @@ @File : __init__.py """ -from metagpt.utils.read_document import read_docx +#from metagpt.utils.read_document import read_docx from metagpt.utils.singleton import Singleton from metagpt.utils.token_counter import ( TOKEN_COSTS, @@ -16,7 +16,7 @@ from metagpt.utils.token_counter import ( __all__ = [ - "read_docx", +# "read_docx", "Singleton", "TOKEN_COSTS", "count_message_tokens", diff --git a/metagpt/utils/resp_parse.py b/metagpt/utils/resp_parse.py new file mode 100644 index 000000000..d87b87883 --- /dev/null +++ b/metagpt/utils/resp_parse.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +# @Date : 2023/8/22 22:18 +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import json5 +import re + + +def flatten_json_structure(json_array): + if (isinstance(json_array, list) and len(json_array) == 1 and not isinstance(json_array[0], str)): + return flatten_json_structure(json_array[0]) + + if (isinstance(json_array, dict) and len(json_array.values()) == 1 and not isinstance(list(json_array.values())[0], + str)): + return flatten_json_structure(list(json_array.values())[0]) + + flattened_json_array = [] + + if (isinstance(json_array, dict)): + json_array = json_array.values() + + for json_object in json_array: + flattened_dict = flatten_json_object(json_object) + flattened_values = ", ".join(str(v) for v in flattened_dict.values()) + flattened_json_array.append(flattened_values) + + return flattened_json_array + + +def flatten_json_object(obj, parent_key='', sep=', '): + if isinstance(obj, str): + return dict([("value", obj)]) + + if isinstance(obj, list): + return dict([("value", sep.join(str(v) for v in obj))]) + + items = [] + for key, value in obj.items(): + new_key = f"{parent_key}{sep}{key}" if parent_key else key + if isinstance(value, dict): + items.extend(flatten_json_object(value, new_key, sep=sep).items()) + elif isinstance(value, list) or isinstance(value, tuple): + items.append((new_key, sep.join(str(v) for v in value))) + else: + items.append((new_key, value)) + return dict(items) + + +def try_parse_json(input_text): + input_text.index + start_index_brackets = input_text.find('[') + end_index_brackets = input_text.rfind(']') + start_index_curly = input_text.find('{') + end_index_curly = input_text.rfind('}') + + start_index = start_index_brackets + end_index = end_index_brackets + + if (start_index_curly != -1 and (start_index_curly < start_index_brackets or start_index_brackets < 0)): + start_index = start_index_curly + end_index = end_index_curly + + if start_index >= 0 and end_index > 0: + json_string = input_text[start_index:end_index + 1] + json_string = re.sub(r'\}[\s]*\{', '}, {', json_string) + json_string = re.sub(r'\][\s]*\[', '], [', json_string) + json_string = re.sub(r'\"[\s]*\"', '", "', json_string) + + try: + json_object = json5.loads(json_string) + except ValueError: + json_object = json5.loads(f"[{json_string}]") + + return json_object + + raise Exception("No JSON object found in input text.") diff --git a/minecraft_run.py b/minecraft_run.py index d7d2cf7c2..ed9276265 100644 --- a/minecraft_run.py +++ b/minecraft_run.py @@ -13,7 +13,7 @@ from metagpt.minecraft_team import MinecraftPlayer async def learn(task="Start", investment: float = 50.0, n_round: int = 3): mc_player = MinecraftPlayer() - mc_player.set_port(1077) # Modify this to your Minecraft LAN port + mc_player.set_port(33141) # Modify this to your Minecraft LAN port # mc_player.set_resume(True) # If load json from ckpt dir(include chest_memory, skills, ...) mc_player.hire( [ diff --git a/tests/metagpt/actions/test_sd_design.py b/tests/metagpt/actions/test_sd_design.py new file mode 100644 index 000000000..0e321b342 --- /dev/null +++ b/tests/metagpt/actions/test_sd_design.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# @Date : 2023/7/22 02:40 +# @Author : stellahong (stellahong@fuzhi.ai) +# + +import pytest +from typing import List +from metagpt.actions.design import SDPromptOptimize, SDPromptImprove +from metagpt.actions.ui_design import ModelSelection + + +@pytest.mark.asyncio +async def test_ui_model_selection(): + ms = ModelSelection() + model_name, domain = await ms.run("Pokémon Go") + assert model_name == "pixelmix_v10" + + +@pytest.mark.asyncio +async def test_ui_sd_generation(): + pass + + +@pytest.mark.asyncio +async def test_ui_sd_prompt_optimize(): + sd_po = SDPromptOptimize() + resp = await sd_po.run(query="Pokémon Go", domain="Anime", answer_count=1) + assert type(resp) == List + assert len(resp) == 1 + + +@pytest.mark.asyncio +async def test_ui_sd_optimize_answer_count(): + sd_po = SDPromptOptimize() + answer_count = 2 + resp = await sd_po.run(query="Pokémon Go", domain="Anime", answer_count=2) + assert type(resp) == List + assert len(resp) == answer_count + +@pytest.mark.asyncio +async def test_ui_sd_improve_answer_count(): + sd_pi = SDPromptImprove() + answer_count = 2 + resp = await sd_pi.run(query="Pokémon Go", domain="Anime", answer_count=2) + assert type(resp) == List + assert len(resp) == answer_count + + +@pytest.mark.asyncio +async def test_ui_sd_prompt_improve(): + sd_pi = SDPromptImprove() + resp = await sd_pi.run(query="Pokémon Go", domain="Anime", answer_count=1) + assert type(resp) == List + assert len(resp) == 1 diff --git a/tests/metagpt/utils/test_flatten_json_object.py b/tests/metagpt/utils/test_flatten_json_object.py new file mode 100644 index 000000000..25d3ddb1e --- /dev/null +++ b/tests/metagpt/utils/test_flatten_json_object.py @@ -0,0 +1,23 @@ +import unittest +import json5 + +from metagpt.utils.resp_parse import flatten_json_object + +class TestFlattenJsonObject(unittest.TestCase): + def test_flatten_json_object(self): + json_obj = json5.loads('{"a": 1, "b": {"c": 2, "d": {"e": 3, "f": 4}}, "g": [5, 6, 7]}') + expected_result = {'a': 1, 'b, c': 2, 'b, d, e': 3, 'b, d, f': 4, 'g': '5, 6, 7'} + self.assertEqual(flatten_json_object(json_obj), expected_result) + + def test_flatten_json_object_with_string(self): + json_obj = json5.loads('{"a": "hello"}') + expected_result = {'a': 'hello'} + self.assertEqual(flatten_json_object(json_obj), expected_result) + + def test_flatten_json_object_with_list(self): + json_obj = json5.loads('{"a": [1, 2, 3]}') + expected_result = {'a': '1, 2, 3'} + self.assertEqual(flatten_json_object(json_obj), expected_result) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/metagpt/utils/test_flatten_json_structure_json.py b/tests/metagpt/utils/test_flatten_json_structure_json.py new file mode 100644 index 000000000..6d2732714 --- /dev/null +++ b/tests/metagpt/utils/test_flatten_json_structure_json.py @@ -0,0 +1,45 @@ +import unittest + +from metagpt.utils.resp_parse import flatten_json_structure + +class TestFlattenJson(unittest.TestCase): + def test_flatten_json_structure(self): + input_json = [ + { + "name": "John", + "age": 30, + "city": "New York" + }, + { + "name": "Jane", + "age": 25, + "city": "Chicago" + } + ] + expected_output = ["John, 30, New York", "Jane, 25, Chicago"] + self.assertEqual(flatten_json_structure(input_json), expected_output) + + def test_flatten_json_structure_with_nested_json(self): + input_json = [ + { + "name": "John", + "age": 30, + "address": { + "city": "New York", + "state": "NY" + } + }, + { + "name": "Jane", + "age": 25, + "address": { + "city": "Chicago", + "state": "IL" + } + } + ] + expected_output = ["John, 30, New York, NY", "Jane, 25, Chicago, IL"] + self.assertEqual(flatten_json_structure(input_json), expected_output) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/metagpt/utils/test_try_parse_json.py b/tests/metagpt/utils/test_try_parse_json.py new file mode 100644 index 000000000..54919a31d --- /dev/null +++ b/tests/metagpt/utils/test_try_parse_json.py @@ -0,0 +1,30 @@ +import unittest + +from metagpt.utils.resp_parse import try_parse_json + + +class TestTryParseJson(unittest.TestCase): + def test_valid_json(self): + input_text = '{"name": "John", "age": 30, "city": "New York"}' + expected_output = {"name": "John", "age": 30, "city": "New York"} + self.assertEqual(try_parse_json(input_text), expected_output) + + def test_invalid_json(self): + input_text = 'This is not a JSON string' + with self.assertRaises(Exception) as context: + try_parse_json(input_text) + self.assertTrue('No JSON object found in input text.' in str(context.exception)) + + def test_empty_json(self): + input_text = '{}' + expected_output = {} + self.assertEqual(try_parse_json(input_text), expected_output) + + def test_nested_json(self): + input_text = '{"name": "John", "age": 30, "city": "New York", "friends": ["Mike", "Anna"]}' + expected_output = {"name": "John", "age": 30, "city": "New York", "friends": ["Mike", "Anna"]} + self.assertEqual(try_parse_json(input_text), expected_output) + +if __name__ == '__main__': + unittest.main() + try_parse_json('{"a": [ jjj}') \ No newline at end of file