Merge pull request #431 from stellaHSR/minecraft

Minecraft:reduce top-k
This commit is contained in:
Sirui Hong 2023-10-13 23:53:49 +08:00 committed by GitHub
commit d7f8a861de
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 50 additions and 39 deletions

View file

@ -26,7 +26,13 @@ class RetrieveSkills(Action):
if k == 0:
return []
logger.info(f"Skill Manager retrieving for {k} skills")
docs_and_scores = vectordb.similarity_search_with_score(query, k=k)
try:
docs_and_scores = vectordb.similarity_search_with_score(query, k=k)
except Exception as e:
docs_and_scores = vectordb.similarity_search_with_score(query, k=1)
logger.debug(f"{e}")
logger.info(
f"Skill Manager retrieved skills: "
f"{', '.join([doc.metadata['name'] for doc, _ in docs_and_scores])}"

View file

@ -136,7 +136,7 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True):
]
metadatas = [{"name": program_name} for program_name in program_names]
# add vectordb from file
ids = self.vectordb.add_texts(
self.vectordb.add_texts(
texts=skill_desps,
ids=program_names,
metadatas=metadatas,
@ -155,7 +155,7 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True):
f"INIT_CHECK: There are {self.vectordb._collection.count()} skills in vectordb and {len(self.skills)} skills in skills.json."
)
# Check that the Skill Manager's vectordb is in sync with skills.json
assert self.vectordb._collection.count() >= len(self.skills), (
assert self.vectordb._collection.count() == len(self.skills), (
f"Skill Manager's vectordb is not synced with skills.json.\n"
f"There are {self.vectordb._collection.count()} skills in vectordb but {len(self.skills)} skills in skills.json.\n"
f"Did you set resume=False when initializing the manager?\n"
@ -165,7 +165,7 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True):
logger.info(
f"INIT_CHECK: There are {self.qa_cache_questions_vectordb._collection.count()} qa_cache in vectordb and {len(self.qa_cache)} questions in qa_cache.json."
)
assert self.qa_cache_questions_vectordb._collection.count() >= len(
assert self.qa_cache_questions_vectordb._collection.count() == len(
self.qa_cache
), (
f"Curriculum Agent's qa cache question vectordb is not synced with qa_cache.json.\n"
@ -183,8 +183,8 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True):
if self.event == event:
return
self.event = event
self.update_chest_memory(event)
self.update_chest_observation()
# self.update_chest_memory(event)
# self.event_summary = self.summarize_chatlog(event)
def update_task(self, task: str):

View file

@ -71,6 +71,7 @@ Inventory (xx/36): ...
Chests: You can ask me to deposit or take items from these chests. There also might be some unknown chest, you should ask me to open and check items inside the unknown chest.
Completed tasks so far: ...
Failed tasks that are too hard: ...
Last Task: ...
You must follow the following criteria:
1) You should act as a mentor and guide me to the next task based on my current learning progress.
@ -85,6 +86,7 @@ You must follow the following criteria:
You should only respond in the format as described below:
RESPONSE FORMAT:
Reasoning: Based on the information I listed above, do reasoning about what the next task should be.
Task: The next task.
Here's an example response:

View file

@ -224,11 +224,12 @@ class ActionDeveloper(Base):
self.perform_game_info_callback(new_skills_info, self.game_memory.append_skill)
async def retrieve_skills(self, query, skills, *args, **kwargs):
retrieve_skills = await RetrieveSkills().run(query, skills, vectordb=self.game_memory.vectordb)
skill_retrieve = RetrieveSkills()
skill_retrieve.retrieval_top_k = max(1, skill_retrieve.retrieval_top_k - int(self.round_id // 50))
retrieve_skills = await skill_retrieve.run(query, skills, vectordb=self.game_memory.vectordb)
logger.info(f"Render Action Agent system message with {len(retrieve_skills)} skills")
self.perform_game_info_callback(retrieve_skills, self.game_memory.update_retrieve_skills)
# return Message(content=f"{retrieve_skills}", instruct_content="retrieve_skills",
# role=self.profile, send_to=agent_registry.entries["action_developer"]()._setting.name)
async def runcode_and_evaluate(self, human_msg, system_msg, *args, **kwargs):
"""

View file

@ -134,6 +134,9 @@ class CurriculumDesigner(Base):
observation = self.render_curriculum_observation(
events=events, chest_observation=chest_observation
)
# Add last task
observation["last_task"] = f"Last Task: {self.game_memory.current_task}\n\n"
if self.game_memory.progress >= warm_up["context"]:
# if self.game_memory.progress >= 0: # TEST ONLY
human_msg = self.render_design_curriculum_human_message(

View file

@ -19,42 +19,41 @@ from metagpt.utils.minecraft import load_prompt
@agent_registry.register("skill_manager")
class SkillManager(Base):
def __init__(
self,
name: str = "John",
profile: str = "Skills Management Specialist",
goal: str = "To oversee and optimize the acquisition, development, and utilization of skills within the organization, ensuring workforce competence and efficiency.",
constraints: str = "Resource allocation, training budgets, and alignment with organizational goals.",
self,
name: str = "John",
profile: str = "Skills Management Specialist",
goal: str = "To oversee and optimize the acquisition, development, and utilization of skills within the organization, ensuring workforce competence and efficiency.",
constraints: str = "Resource allocation, training budgets, and alignment with organizational goals.",
) -> None:
super().__init__(name, profile, goal, constraints)
# Initialize actions specific to the SkillManager role
self._init_actions([RetrieveSkills, GenerateSkillDescription, AddNewSkills])
# Set events or actions the SkillManager should watch or be aware of
self._watch(
[DesignCurriculum, GenerateActionCode, RetrieveSkills, GenerateSkillDescription]
)
self.finish_state = len(self._actions)
def encapsule_message(self, program_code, program_name, *args, **kwargs):
    """Bundle the system and human prompts built around a program.

    The system message is rendered from the ``"skill"`` prompt template; the
    human message is the program source followed by a sentence naming its
    main function.

    Args:
        program_code: Source text of the program.
        program_name: Name of the program's main function.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        A dict with ``"system_msg"`` (a one-element list holding the system
        message content) and ``"human_msg"`` (the human message content).
    """
    skill_prompt = load_prompt("skill")
    rendered_system = self.render_system_message(skill_prompt)

    human_text = program_code + "\n\n" + f"The main function is `{program_name}`."
    rendered_human = self.render_human_message(human_text)

    payload = {
        "system_msg": [rendered_system.content],
        "human_msg": rendered_human.content,
    }
    return payload
async def retrieve_skills(self, query, skills, *args, **kwargs):
vectordb = self.game_memory.vectordb
retrieve_skills = await RetrieveSkills().run(query, skills, vectordb)
skill_retrieve = RetrieveSkills()
skill_retrieve.retrieval_top_k = max(1, skill_retrieve.retrieval_top_k - int(self.round_id // 50))
retrieve_skills = await skill_retrieve.run(query, skills, vectordb)
logger.info(f"Render Action Agent system message with {len(retrieve_skills)} skills")
self.perform_game_info_callback(retrieve_skills, self.game_memory.update_retrieve_skills)
return Message(content=f"{retrieve_skills}", instruct_content="retrieve_skills",
return Message(content=f"{retrieve_skills}", instruct_content="retrieve_skills",
role=self.profile, send_to=agent_registry.entries["action_developer"]()._setting.name)
# return Message(
# content=f"{skills}", instruct_content="retrieve_skills", role=self.profile
# ) # Unit test only
async def generate_skill_descp(self, human_msg, system_msg, *args, **kwargs):
program_name = self.game_memory.program_name
desp = await GenerateSkillDescription().run(program_name, human_msg, system_msg)
@ -64,16 +63,16 @@ class SkillManager(Base):
instruct_content="generate_skill_descp",
role=self.profile,
)
async def handle_add_new_skills(
self, task, program_name, program_code, skills, *args, **kwargs
self, task, program_name, program_code, skills, *args, **kwargs
):
if not self.game_memory.runtime_status:
return Message(
content="",
instruct_content="handle_add_new_skills",
role=self.profile,
)
content="",
instruct_content="handle_add_new_skills",
role=self.profile,
)
skill_desp = self.game_memory.skill_desp
vectordb = self.game_memory.vectordb
@ -86,12 +85,12 @@ class SkillManager(Base):
instruct_content="handle_add_new_skills",
role=self.profile,
)
async def _act(self) -> Message:
todo = self._rc.todo
logger.debug(f"Todo is {todo}")
self.maintain_actions(todo)
# Fetch the latest game-related context information
context = self.game_memory.context
task = self.game_memory.current_task
@ -100,27 +99,27 @@ class SkillManager(Base):
self.perform_game_info_callback(self.game_memory.event, self.game_memory.summarize_chatlog)
event_summary = self.game_memory.event_summary
program_code = self.game_memory.program_code
program_name = self.game_memory.program_name
skills = self.game_memory.skills
# msg = self._rc.memory.get(k=1)[0]
retrieve_skills_message_step1 = {"query": context, "skills": skills}
logger.info(f"check query {context}")
logger.info(f"check event summary {event_summary}")
retrieve_skills_message_step2 = {"query": context + "\n\n" + event_summary, "skills": skills}
generate_skill_message = self.encapsule_message(program_code, program_name)
add_new_skills_message = {
"task": task,
"program_name": program_name,
"program_code": program_code,
"skills": skills,
}
handler_map = {
DesignCurriculum: self.retrieve_skills,
RetrieveSkills: self.retrieve_skills,
@ -139,10 +138,10 @@ class SkillManager(Base):
msg = await handler(**generate_skill_message)
else:
msg = await handler(**add_new_skills_message)
msg.cause_by = type(todo)
msg.round_id = self.round_id
self._publish_message(msg)
return msg
raise ValueError(f"Unknown todo type: {type(todo)}")