update code: 更新断后恢复,支持vdb的更新和处理

This commit is contained in:
stellahsr 2023-10-12 11:57:43 +08:00
parent b55a080c83
commit bb5f3f8ecf
5 changed files with 120 additions and 63 deletions

View file

@ -91,8 +91,9 @@ class DesignCurriculum(Action):
# Build the action through the parent constructor, forwarding name, context and llm unchanged.
def __init__(self, name="", context=None, llm=None):
super().__init__(name, context, llm)
# Set the model name on this action's self.llm to "gpt-3.5-turbo".
self.llm.model = "gpt-3.5-turbo"
async def generate_qa(self, events, qa_cache, qa_cache_questions_vectordb, human_msg, system_msg):
async def generate_qa(self, events, qa_cache, qa_cache_questions_vectordb, game_memory, human_msg, system_msg):
"""
Generate qa for DesignTask's HumanMessage
"""
@ -100,7 +101,7 @@ class DesignCurriculum(Action):
events=events, human_msg=human_msg, system_msg=system_msg
)
logger.debug(f"Generate_qa_step1 result list is HERE: {questions_new}")
questions = []
answers = []
for question in questions_new:
@ -128,6 +129,10 @@ class DesignCurriculum(Action):
qa_cache_questions_vectordb.persist()
questions.append(question)
answers.append(answer)
game_memory.qa_cache = qa_cache
assert len(questions_new) == len(questions) == len(answers)
logger.info(f"Curriculum Agent generate_qa Questions: {questions}")
logger.info(f"Curriculum Agent generate_qa Answers: {answers}")
@ -170,7 +175,7 @@ class DesignCurriculum(Action):
# logger.info(f"Curriculum Agent generate_qa_step2 answer: {answer}")
return answer
async def get_context_from_task(self, task, qa_cache, qa_cache_questions_vectordb):
async def get_context_from_task(self, task, qa_cache, qa_cache_questions_vectordb, game_memory):
"""
Args: task
Returns: context: "Question: {question}\n{answer}"
@ -192,10 +197,12 @@ class DesignCurriculum(Action):
with open(f"{CKPT_DIR}/curriculum/qa_cache.json", "w") as f:
json.dump(qa_cache, f)
qa_cache_questions_vectordb.persist()
game_memory.qa_cache = qa_cache
context = f"Question: {question}\n{answer}"
return context
async def generate_context(self, task, qa_cache, qa_cache_questions_vectordb, max_retries=5):
async def generate_context(self, task, qa_cache, qa_cache_questions_vectordb, game_memory, max_retries=5):
"""
Refer to the code in the voyager/agents/curriculum.py propose_next_ai_task() for implementation details.
Returns: context
@ -206,7 +213,8 @@ class DesignCurriculum(Action):
raise RuntimeError("Max retries reached, failed to propose context.")
try:
context = await self.get_context_from_task(
task=task, qa_cache=qa_cache, qa_cache_questions_vectordb=qa_cache_questions_vectordb
task=task, qa_cache=qa_cache, qa_cache_questions_vectordb=qa_cache_questions_vectordb,
game_memory=game_memory,
) # Curriculum Agent Question: How to craft 4 wooden planks in Minecraft? & Curriculum Agent Answer: ...
return context
except Exception as e:
@ -215,14 +223,17 @@ class DesignCurriculum(Action):
task=task,
qa_cache=qa_cache,
qa_cache_questions_vectordb=qa_cache_questions_vectordb,
game_memory=game_memory,
max_retries=max_retries - 1,
)
async def run(self, task, qa_cache, qa_cache_questions_vectordb, human_msg, system_msg, *args, **kwargs):
async def run(self, task, qa_cache, qa_cache_questions_vectordb, game_memory, human_msg, system_msg, *args,
**kwargs):
logger.info(f"run {self.__repr__()}")
# Generate curriculum-related questions and answers.
# curriculum_qustion = await self.generate_qa_step1(events, human_msg, system_msg)
curriculum_context = await self.generate_context(task, qa_cache, qa_cache_questions_vectordb)
curriculum_context = await self.generate_context(task, qa_cache, qa_cache_questions_vectordb,
game_memory=game_memory)
# Return the generated questions and answers.
return curriculum_context

View file

@ -75,6 +75,12 @@ class AddNewSkills(Action):
ids=[program_name],
metadatas=[{"name": program_name}],
)
skills[program_name] = {
"code": program_code,
"description": skill_desp,
}
logger.debug(f"ADD_CHECK: There are {vectordb._collection.count()} skills in vectordb")
with open(f"{CKPT_DIR}/skill/code/{dumped_program_name}.js", "w") as f:

View file

@ -60,6 +60,10 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True):
mf_instance: MineflayerEnv = Field(default_factory=MineflayerEnv)
runtime_status: bool = False # equal to action execution status: success or failed
vectordb: Chroma = Field(default_factory=Chroma)
qa_cache_questions_vectordb: Chroma = Field(default_factory=Chroma)
@property
def progress(self):
# return len(self.completed_tasks) + 10 # Test only
@ -84,27 +88,40 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True):
def core_inv_items_regex(self):
return self.mf_instance.core_inv_items_regex
# Property that returns a Chroma handle for the "qa_cache_questions_vectordb" collection,
# using OpenAIEmbeddings and persisting under {CKPT_DIR}/curriculum/vectordb.
# A new Chroma object is constructed on every access.
# NOTE(review): the class also declares a field named qa_cache_questions_vectordb and
# set_mc_resume() assigns to self.qa_cache_questions_vectordb; this getter-only property
# looks superseded by that field -- confirm it is meant to be removed rather than kept alongside it.
@property
def qa_cache_questions_vectordb(self):
return Chroma(
collection_name="qa_cache_questions_vectordb",
embedding_function=OpenAIEmbeddings(),
persist_directory=f"{CKPT_DIR}/curriculum/vectordb",
)
# @property
# def qa_cache_questions_vectordb(self):
# return Chroma(
# collection_name="qa_cache_questions_vectordb",
# embedding_function=OpenAIEmbeddings(),
# persist_directory=f"{CKPT_DIR}/curriculum/vectordb",
# )
# Property that returns a Chroma handle for the "skill_vectordb" collection,
# using OpenAIEmbeddings and persisting under {CKPT_DIR}/skill/vectordb.
# A new Chroma object is constructed on every access.
# NOTE(review): the class also declares a field named vectordb and set_mc_resume()
# assigns to self.vectordb; this getter-only property looks superseded by that field --
# confirm it is meant to be removed rather than kept alongside it.
@property
def vectordb(self):
return Chroma(
collection_name="skill_vectordb",
embedding_function=OpenAIEmbeddings(),
persist_directory=f"{CKPT_DIR}/skill/vectordb",
)
# @property
# def vectordb(self):
# return Chroma(
# collection_name="skill_vectordb",
# embedding_function=OpenAIEmbeddings(),
# persist_directory=f"{CKPT_DIR}/skill/vectordb",
# )
# Hand the Minecraft port to the Mineflayer environment, then run set_mc_resume().
def set_mc_port(self, mc_port):
self.mf_instance.set_mc_port(mc_port)
# set_mc_resume() creates both Chroma handles and, when CONFIG.resume is set,
# loads checkpoint data from CKPT_DIR.
self.set_mc_resume()
def set_mc_resume(self):
self.qa_cache_questions_vectordb = Chroma(
collection_name="qa_cache_questions_vectordb",
embedding_function=OpenAIEmbeddings(),
persist_directory=f"{CKPT_DIR}/curriculum/vectordb",
)
self.vectordb = Chroma(
collection_name="skill_vectordb",
embedding_function=OpenAIEmbeddings(),
persist_directory=f"{CKPT_DIR}/skill/vectordb",
)
if CONFIG.resume:
logger.info(f"Loading Action Developer from {CKPT_DIR}/action")
with open(f"{CKPT_DIR}/action/chest_memory.json", "r") as f:
@ -125,6 +142,7 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True):
self.qa_cache = json.load(f)
if self.vectordb._collection.count() == 0:
logger.info(self.vectordb._collection.count())
# Set up the vector DBs (vdbs) for skills & qa_cache
skill_desps = [
skill["description"] for program_name, skill in self.skills.items()
@ -140,36 +158,41 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True):
metadatas=metadatas,
)
self.vectordb.persist()
logger.info(self.qa_cache_questions_vectordb._collection.count())
if self.qa_cache_questions_vectordb._collection.count() == 0:
questions = [question for question, answer in self.qa_cache.items()]
self.qa_cache_questions_vectordb.add_texts(texts=questions)
self.qa_cache_questions_vectordb.persist()
logger.info(
f"INIT_CHECK: There are {self.vectordb._collection.count()} skills in vectordb and {len(self.skills)} skills in skills.json."
)
# Check that the Skill Manager's vectordb is in sync with skills.json
assert self.vectordb._collection.count() >= len(self.skills), (
f"Skill Manager's vectordb is not synced with skills.json.\n"
f"There are {self.vectordb._collection.count()} skills in vectordb but {len(self.skills)} skills in skills.json.\n"
f"Did you set resume=False when initializing the manager?\n"
f"You may need to manually delete the vectordb directory for running from scratch."
)
logger.info(
f"INIT_CHECK: There are {self.vectordb._collection.count()} skills in vectordb and {len(self.skills)} skills in skills.json."
)
# Check that the Skill Manager's vectordb is in sync with skills.json
assert self.vectordb._collection.count() >= len(self.skills), (
f"Skill Manager's vectordb is not synced with skills.json.\n"
f"There are {self.vectordb._collection.count()} skills in vectordb but {len(self.skills)} skills in skills.json.\n"
f"Did you set resume=False when initializing the manager?\n"
f"You may need to manually delete the vectordb directory for running from scratch."
)
logger.info(
f"INIT_CHECK: There are {self.qa_cache_questions_vectordb._collection.count()} qa_cache in vectordb and {len(self.qa_cache)} questions in qa_cache.json."
)
assert self.qa_cache_questions_vectordb._collection.count() >= len(
self.qa_cache
), (
f"Curriculum Agent's qa cache question vectordb is not synced with qa_cache.json.\n"
f"There are {self.qa_cache_questions_vectordb._collection.count()} questions in vectordb "
f"but {len(self.qa_cache)} questions in qa_cache.json.\n"
f"Did you set resume=False when initializing the agent?\n"
f"You may need to manually delete the qa cache question vectordb directory for running from scratch.\n"
)
logger.info(
f"INIT_CHECK: There are {self.qa_cache_questions_vectordb._collection.count()} qa_cache in vectordb and {len(self.qa_cache)} questions in qa_cache.json."
)
assert self.qa_cache_questions_vectordb._collection.count() >= len(
self.qa_cache
), (
f"Curriculum Agent's qa cache question vectordb is not synced with qa_cache.json.\n"
f"There are {self.qa_cache_questions_vectordb._collection.count()} questions in vectordb "
f"but {len(self.qa_cache)} questions in qa_cache.json.\n"
f"Did you set resume=False when initializing the agent?\n"
f"You may need to manually delete the qa cache question vectordb directory for running from scratch.\n"
)
def register_roles(self, roles: Iterable[Minecraft]):
for role in roles:
@ -209,6 +232,9 @@ class GameEnvironment(BaseModel, arbitrary_types_allowed=True):
# Replace the stored skill description with the given string.
def update_skill_desp(self, skill_desp: str):
self.skill_desp = skill_desp
# Replace the stored QA cache with the given dict (the same object is kept, not a copy).
# NOTE(review): declared async although the body contains no await, unlike the neighbouring
# synchronous update_skill_desp / update_chest_memory -- callers must await it; confirm this is intended.
async def update_qa_cache(self, qa_cache: dict):
self.qa_cache = qa_cache
def update_chest_memory(self, events: Dict):
"""
Input: events: Dict
@ -517,3 +543,9 @@ class MinecraftPlayer(SoftwareCompany):
# self.environment.memory.clear()
# self._reset()
return self.environment.history
if __name__ == "__main__":
    # Ad-hoc manual check: build a default GameEnvironment and print the
    # persist directory of its QA-cache question vector store.
    game_env = GameEnvironment()
    qa_question_store = game_env.qa_cache_questions_vectordb
    # _persist_directory is a private attribute of the store object; it is
    # read here only for debugging output.
    print(qa_question_store._persist_directory)

View file

@ -216,7 +216,7 @@ class ActionDeveloper(Base):
self.perform_game_info_callback(new_skills_info, self.game_memory.append_skill)
async def retrieve_skills(self, query, skills, *args, **kwargs):
retrieve_skills = await RetrieveSkills().run(query, skills)
retrieve_skills = await RetrieveSkills().run(query, skills, vectordb=self.game_memory.vectordb)
logger.info(f"Render Action Agent system message with {len(retrieve_skills)} skills")
self.perform_game_info_callback(retrieve_skills, self.game_memory.update_retrieve_skills)
# return Message(content=f"{retrieve_skills}", instruct_content="retrieve_skills",

View file

@ -20,11 +20,11 @@ class CurriculumDesigner(Base):
"""
def __init__(
self,
name: str = "David",
profile: str = "Expertise in minecraft task design and curriculum development.",
goal: str = " Collect and integrate learner feedback to improve and refine educational content and pathways",
constraints: str = "Limited budget and resources for the development of educational content and technology tools.",
self,
name: str = "David",
profile: str = "Expertise in minecraft task design and curriculum development.",
goal: str = " Collect and integrate learner feedback to improve and refine educational content and pathways",
constraints: str = "Limited budget and resources for the development of educational content and technology tools.",
) -> None:
super().__init__(name, profile, goal, constraints)
# Initialize actions specific to the Action role
@ -56,12 +56,12 @@ class CurriculumDesigner(Base):
inventory = event["inventory"]
if not any(
"dirt" in block
or "log" in block
or "grass" in block
or "sand" in block
or "snow" in block
for block in voxels
"dirt" in block
or "log" in block
or "grass" in block
or "sand" in block
or "snow" in block
for block in voxels
):
biome = "underground"
@ -92,8 +92,8 @@ class CurriculumDesigner(Base):
# filter out optional inventory items if required
if (
self.game_memory.progress
< self.game_memory.warm_up["optional_inventory_items"]
self.game_memory.progress
< self.game_memory.warm_up["optional_inventory_items"]
):
inventory = {
k: v
@ -121,7 +121,7 @@ class CurriculumDesigner(Base):
# --------------------------------Design Task Prepare---------------------------------------
async def render_design_task_human_message(
self, events, chest_observation, *args, **kwargs
self, events, chest_observation, *args, **kwargs
):
"""
Returns: observation for curriculum
@ -135,15 +135,19 @@ class CurriculumDesigner(Base):
events=events, chest_observation=chest_observation
)
if self.game_memory.progress >= warm_up["context"]:
# if self.game_memory.progress >= 0: # TEST ONLY
# if self.game_memory.progress >= 0: # TEST ONLY
human_msg = self.render_design_curriculum_human_message(
events=events, chest_observation=chest_observation
).content
system_msg = [self.render_design_curriculum_system_message().content]
questions, answers = await DesignCurriculum().generate_qa(
events=events, qa_cache=qa_cache, human_msg=human_msg, system_msg=system_msg
events=events, qa_cache=qa_cache,
qa_cache_questions_vectordb=self.game_memory.qa_cache_questions_vectordb,
game_memory=self.game_memory,
human_msg=human_msg, system_msg=system_msg
)
logger.debug(f"Generate_qa result is HERE: Ques: {questions}, Ans: {answers}")
i = 1
for question, answer in zip(questions, answers):
if "Answer: Unknown" in answer or "language model" in answer:
@ -202,7 +206,7 @@ class CurriculumDesigner(Base):
return SystemMessage(content=load_prompt("curriculum_qa_step1_ask_questions"))
def render_design_curriculum_human_message(
self, events, chest_observation, *args, **kwargs
self, events, chest_observation, *args, **kwargs
):
observation = self.render_curriculum_observation(
events=events, chest_observation=chest_observation
@ -213,7 +217,7 @@ class CurriculumDesigner(Base):
return HumanMessage(content=content)
def encapsule_design_curriculum_message(
self, events, chest_observation, *args, **kwargs
self, events, chest_observation, *args, **kwargs
):
human_msg = self.render_design_curriculum_human_message(
events=events, chest_observation=chest_observation, *args, **kwargs
@ -309,8 +313,12 @@ class CurriculumDesigner(Base):
self, events=events, chest_observation=chest_observation
)
else:
logger.info(self.game_memory.qa_cache_questions_vectordb._collection.count())
logger.info(self.game_memory.vectordb._collection.count())
context = await DesignCurriculum().run(
task, qa_cache, qa_cache_questions_vectordb, human_msg, system_msg, *args, **kwargs
task, qa_cache, qa_cache_questions_vectordb, game_memory=self.game_memory,
human_msg=human_msg,
system_msg=system_msg, *args, **kwargs
)
self.perform_game_info_callback(context, self.game_memory.update_context)
return Message(