Merge pull request #400 from stellaHSR/minecraft

Minecraft: bug fix for skill_manager and new round
This commit is contained in:
Sirui Hong 2023-10-08 22:49:18 +08:00 committed by GitHub
commit 7989e9fb2c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 58 additions and 16 deletions

View file

@ -24,6 +24,7 @@ class DesignTask(Action):
def __init__(self, name="", context=None, llm=None):
    """Build the action through the parent constructor and select its model.

    Args:
        name: Action name, forwarded unchanged to the parent constructor.
        context: Context value, forwarded unchanged to the parent constructor.
        llm: LLM handle, forwarded unchanged to the parent constructor.
    """
    super().__init__(name, context, llm)
    # NOTE(review): this assigns on whatever object ends up in self.llm; if
    # that object is shared between actions the model change affects all of
    # them — confirm.
    model_name = "gpt-3.5-turbo"
    self.llm.model = model_name
async def decompose_task(self, query, events):
system_msgs = SystemMessage(

View file

@ -15,6 +15,7 @@ class GenerateActionCode(Action):
def __init__(self, name="", context=None, llm=None):
    """Build the action through the parent constructor and select its model.

    Args:
        name: Action name, forwarded unchanged to the parent constructor.
        context: Context value, forwarded unchanged to the parent constructor.
        llm: LLM handle, forwarded unchanged to the parent constructor.
    """
    super().__init__(name, context, llm)
    # NOTE(review): this assigns on whatever object ends up in self.llm; if
    # that object is shared between actions the model change affects all of
    # them — confirm.
    model_name = "gpt-4"
    self.llm.model = model_name
async def generate_code(self, human_msg, system_msg=[]):
"""

View file

@ -18,6 +18,7 @@ class RetrieveSkills(Action):
def __init__(self, name="", context=None, llm=None):
    """Build the action through the parent constructor and select its model.

    Args:
        name: Action name, forwarded unchanged to the parent constructor.
        context: Context value, forwarded unchanged to the parent constructor.
        llm: LLM handle, forwarded unchanged to the parent constructor.
    """
    super().__init__(name, context, llm)
    # NOTE(review): this assigns on whatever object ends up in self.llm; if
    # that object is shared between actions the model change affects all of
    # them — confirm.
    model_name = "gpt-3.5-turbo"
    self.llm.model = model_name
async def run(self, query, skills, *args, **kwargs):
# Implement the logic for retrieving skills here.
@ -44,18 +45,22 @@ class AddNewSkills(Action):
def __init__(self, name="", context=None, llm=None):
    """Build the action through the parent constructor and select its model.

    Args:
        name: Action name, forwarded unchanged to the parent constructor.
        context: Context value, forwarded unchanged to the parent constructor.
        llm: LLM handle, forwarded unchanged to the parent constructor.
    """
    super().__init__(name, context, llm)
    # NOTE(review): this assigns on whatever object ends up in self.llm; if
    # that object is shared between actions the model change affects all of
    # them — confirm.
    model_name = "gpt-3.5-turbo"
    self.llm.model = model_name
async def run(
self, task, program_name, program_code, skills, skill_desp, *args, **kwargs
):
# Implement the logic for adding new skills here.
# TODO: Fix this
logger.info(f"check task {task}")
if task.startswith("Deposit useless items into the chest at"):
# No need to reuse the deposit skill
return {}
logger.info(
f"Skill Manager generated description for {program_name}:\n{skill_desp}\033[0m"
)
logger.info(f"check skills {skills}")
if program_name in skills:
logger.info(f"Skill {program_name} already exists. Rewriting!")
self.vectordb._collection.delete(ids=[program_name])
@ -97,6 +102,7 @@ class GenerateSkillDescription(Action):
def __init__(self, name="", context=None, llm=None):
    """Build the action through the parent constructor and select its model.

    Args:
        name: Action name, forwarded unchanged to the parent constructor.
        context: Context value, forwarded unchanged to the parent constructor.
        llm: LLM handle, forwarded unchanged to the parent constructor.
    """
    super().__init__(name, context, llm)
    # NOTE(review): this assigns on whatever object ends up in self.llm; if
    # that object is shared between actions the model change affects all of
    # them — confirm.
    model_name = "gpt-3.5-turbo"
    self.llm.model = model_name
async def run(self, program_name, human_message, system_message, *args, **kwargs):
# Implement the logic for generating skill descriptions here.

View file

@ -15,6 +15,7 @@ class VerifyTask(Action):
def __init__(self, name="", context=None, llm=None):
    """Build the action through the parent constructor and select its model.

    Args:
        name: Action name, forwarded unchanged to the parent constructor.
        context: Context value, forwarded unchanged to the parent constructor.
        llm: LLM handle, forwarded unchanged to the parent constructor.
    """
    super().__init__(name, context, llm)
    # NOTE(review): this assigns on whatever object ends up in self.llm; if
    # that object is shared between actions the model change affects all of
    # them — confirm.
    model_name = "gpt-3.5-turbo"
    self.llm.model = model_name
async def run(self,human_msg, system_msg, max_retries=5, *args, **kwargs):
# Implement the logic to verify the task here.

View file

@ -355,6 +355,7 @@ class MinecraftPlayer(SoftwareCompany):
role.finish_step = False
role.round_id += 1
role._rc.todo = None
role.finish_state = len(role._actions)
logger.info(f"round_id:{role.round_id}")
def hire(self, roles: list[Role]):
@ -394,6 +395,7 @@ class MinecraftPlayer(SoftwareCompany):
while n_round > 0:
# self._save()
if self.check_complete_round():
n_round -= 1
self.update_round()
round_id += 1

View file

@ -142,10 +142,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
Check https://platform.openai.com/examples for examples
"""
def __init__(self):
self.__init_openai(CONFIG)
def __init__(self, conf=CONFIG, **kwargs):
self.__init_openai(conf)
self.llm = openai
self.model = CONFIG.openai_api_model
self.model = conf.openai_api_model
self.auto_max_tokens = False
self._cost_manager = CostManager()
RateLimiter.__init__(self, rpm=self.rpm)

View file

@ -20,8 +20,6 @@ from metagpt.config import CONFIG
from metagpt.actions.minecraft.control_primitives_context import (
load_skills_code_context,
)
from metagpt.utils.minecraft import fix_and_parse_json
from metagpt.roles.minecraft.critic_agent import CriticReviewer
@agent_registry.register("action_developer")
@ -42,13 +40,14 @@ class ActionDeveloper(Base):
# Initialize actions specific to the Action role
self._init_actions([GenerateActionCode])
# Set events or actions the ActionAgent should watch or be aware of
# 需要根据events进行自己chest_observation的更新
self._watch([RetrieveSkills])
self.rollout_num_iter = 0
self.task_max_retries = 4
self.finish_state = len(self._actions)
self.critic_reviewer = None # self._rc.env.roles["Task Reviewer"]
logger.info(self.critic_reviewer)
def render_system_message(self, skills=[], *args, **kwargs):
"""
@ -198,6 +197,8 @@ class ActionDeveloper(Base):
if done:
break
# return [system_msg, human_msg], reward, done, info
# 结束前将critic_reviewer 轮次状态更新,以便进入下一轮
self.critic_reviewer.finish_step = True
return Message(
content=f"{info}",
instruct_content="generate_action_code",
@ -282,8 +283,12 @@ class ActionDeveloper(Base):
system_msg = message["system_msg"]
human_msg = message["human_msg"]
else:
self.perform_game_info_callback(
False, self.game_memory.update_exploration_progress
)
logger.info(f"Code is None. Update runtime_status failed!")
self.critic_reviewer.maintain_actions(VerifyTask())
logger.info(f"system msg is {system_msg}, \n human_msg is {human_msg}")
# logger.info(f"system msg is {system_msg}, \n human_msg is {human_msg}")
logger.info(f"\033[34m Trying again!\033[0m")
self.rollout_num_iter += 1
@ -326,7 +331,7 @@ class ActionDeveloper(Base):
# 获取最新的游戏周边信息
# events = await self._obtain_events()
events = self.game_memory.event
logger.info(events)
# logger.info(events)
# self.perform_game_info_callback(events, self.game_memory.update_event)
logger.info(self.game_memory.event_summary)
context = self.game_memory.context

View file

@ -34,6 +34,7 @@ class CriticReviewer(Base):
# Set events or actions the CriticReviewer should watch or be aware of
# 需要获取最新的events来进行评估
self._watch([])
self.finish_state = len(self._actions)
async def run(self, message=None):
"""Observe, only get the observation"""
@ -157,7 +158,7 @@ class CriticReviewer(Base):
# 获取最新的游戏周边信息
events = await self._execute_events()
self.perform_game_info_callback(events, self.game_memory.update_chest_memory)
logger.info(f"Execute return event is {self.game_memory.event}")
# logger.info(f"Execute return event is {self.game_memory.event}")
context = self.game_memory.context
task = self.game_memory.current_task
chest_observation = self.game_memory.chest_observation
@ -173,7 +174,7 @@ class CriticReviewer(Base):
VerifyTask: self.verify_task,
}
handler = handler_map.get(type(todo))
logger.info(handler)
# logger.info(handler)
if handler:
msg = await handler(**message)
msg.cause_by = type(todo)

View file

@ -32,6 +32,8 @@ class CurriculumDesigner(Base):
# Set events or actions the ActionAgent should watch or be aware of
self._watch([PlayerActions, DesignTask])
logger.info(self._actions)
self.finish_state = len(self._actions)
def render_curriculum_observation(self, *, events, chest_observation):
"""

View file

@ -52,26 +52,28 @@ class Minecraft(Role):
self.finish_step = False
# Track per-round progress for this role: decrement self.finish_state when
# `todo` is found in self._actions, and set self.finish_step to True once
# self.finish_state is zero or below. The counter is logged before and after.
# NOTE(review): indentation is not preserved in this view, so it is unclear
# whether the `finish_state<=0` check is nested under the membership test or
# runs unconditionally — confirm against the real file.
def maintain_actions(self, todo):
logger.info(f"{self._setting.name}:{self.finish_state}")
# NOTE(review): `in` compares `todo` with the stored action objects by
# equality/identity. A freshly constructed action instance passed by a caller
# will presumably not match unless the action class defines value equality —
# TODO confirm; comparing by type may be what is intended.
if todo in self._actions:
self.finish_state-=1
if self.finish_state<=0:
self.finish_step = True
logger.info(f"{self._setting.name}:{self.finish_state}")
async def _observe(self) -> int:
await super()._observe()
for msg in self._rc.news:
logger.info(f"check msg round :{msg.round_id}")
logger.info(msg.round_id == self.round_id)
# logger.info(msg.round_id == self.round_id)
self._rc.news = [
msg for msg in self._rc.news if msg.round_id == self.round_id
] # only relevant msgs count as observed news
logger.info(len(self._rc.news))
# logger.info(len(self._rc.news))
return len(self._rc.news)
async def _think(self) -> None:
logger.info(self._actions)
logger.info(self._rc.state)
# logger.info(self._rc.state)
if len(self._actions) == 1:
# If there is only one action, then only this one can be performed
self._set_state(0)
@ -133,5 +135,5 @@ agent_registry = Registry(name="Minecraft")
if __name__ == "__main__":
mc = Minecraft()
result = "Async operation result"
# 调用回调函数,并传递结果
# mc.perform_memory_callback(mc.my_callback)

View file

@ -11,7 +11,7 @@ from metagpt.actions.minecraft.manage_skills import (
RetrieveSkills,
AddNewSkills,
)
from metagpt.actions.minecraft.review_task import VerifyTask
from metagpt.actions.minecraft import GenerateActionCode
from metagpt.actions.minecraft.design_curriculumn import DesignCurriculum
from metagpt.utils.minecraft import load_prompt
@ -32,8 +32,10 @@ class SkillManager(Base):
# Set events or actions the SkillManager should watch or be aware of
self._watch(
[DesignCurriculum, VerifyTask, RetrieveSkills, GenerateSkillDescription]
[DesignCurriculum, GenerateActionCode, RetrieveSkills, GenerateSkillDescription]
)
self.finish_state = len(self._actions)
def encapsule_message(self, program_code, program_name, *args, **kwargs):
system_msg = self.render_system_message(load_prompt("skill"))
@ -128,8 +130,10 @@ class SkillManager(Base):
handler = handler_map.get(type(todo))
if handler:
if type(todo) == DesignCurriculum:
logger.info(retrieve_skills_message_step1)
msg = await handler(**retrieve_skills_message_step1)
elif type(todo) == RetrieveSkills:
logger.info(retrieve_skills_message_step2)
msg = await handler(**retrieve_skills_message_step2)
elif type(todo) == GenerateSkillDescription:
msg = await handler(**generate_skill_message)

View file

@ -108,10 +108,12 @@ class Role:
def _init_actions(self, actions):
self._reset()
for idx, action in enumerate(actions):
if not isinstance(action, Action):
i = action("")
else:
i = action
i.set_prefix(self._get_prefix(), self.profile)
self._actions.append(i)
self._states.append(f"{idx}. {action}")

View file

@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
# @Date : 2023/10/7 16:32
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import yaml
from metagpt.const import PROJECT_ROOT
def load_extra_conf(yaml_file=PROJECT_ROOT / "config/add_config.yaml"):
    """Read a YAML file and return its parsed content.

    Args:
        yaml_file: Path of the YAML file to read. Defaults to
            ``config/add_config.yaml`` under ``PROJECT_ROOT``.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file content.
    """
    # The file is decoded as UTF-8 explicitly rather than with the platform
    # default encoding.
    with open(yaml_file, mode="r", encoding="utf-8") as stream:
        return yaml.safe_load(stream)