add docstring and others

This commit is contained in:
luxiangtao 2024-04-16 17:30:11 +08:00
parent fd6eb8882e
commit 258a0894b8
3 changed files with 116 additions and 101 deletions

View file

@ -39,109 +39,28 @@ TRAJECTORY_COLLECTION_NAME = "di_trajectory_0"
PERSIST_PATH = SERDESER_PATH / "data_interpreter/chroma"
class AddNewExperiences(Action):
name: str = "AddNewTaskExperiences"
def _init_engine(self, collection_name: str):
"""Initialize a collection for storing code experiences"""
engine = SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
retriever_configs=[ChromaRetrieverConfig(persist_path=PERSIST_PATH, collection_name=collection_name)],
)
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
chroma_collection = db.get_or_create_collection(collection_name)
chroma_collection.delete(where_document={"$contains": "Bob likes traveling"})
return engine
async def _single_task_summary(self, trajectory_collection_name: str, experience_collection_name: str):
trajectory_engine = self._init_engine(collection_name=trajectory_collection_name)
experience_engine = self._init_engine(collection_name=experience_collection_name)
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
collection = db.get_or_create_collection(trajectory_collection_name)
unused_ids = [
id
for id in collection.get()["ids"]
if json.loads(collection.get([id])["metadatas"][0]["obj_json"])["is_used"] == False
]
trajectory_dicts = [json.loads(metadata["obj_json"]) for metadata in collection.get(unused_ids)["metadatas"]]
trajectories = []
experiences = []
for trajectory_dict in trajectory_dicts:
trajectory_dict["is_used"] = True
trajectory = Trajectory(**trajectory_dict)
trajectories.append(trajectory)
code_summary = await self.task_code_sumarization(trajectory)
experience = Experience(code_summary=code_summary, trajectory=trajectory)
experiences.append(experience)
collection.delete(unused_ids)
trajectory_engine.add_objs(trajectories)
experience_engine.add_objs(experiences)
async def task_code_sumarization(self, trajectory: Trajectory):
"""Summarize the task code
Args:
task: The task to be summarized.
Returns:
A summary of the task code.
"""
task = trajectory.task
prompt = TASK_CODE_DESCRIPTION_PROMPT.format(
code_snippet=task.code, code_result=task.result, code_success="Success" if task.is_success else "Failure"
)
resp = await self._aask(prompt=prompt)
return resp
async def run(
self,
trajectory_collection_name: str = TRAJECTORY_COLLECTION_NAME,
experience_collection_name: str = EXPERIENCE_COLLECTION_NAME,
mode: str = "single_task_summary",
):
"""Initiate a collection and Add a new task experience to the collection
Args:
trajectory_collection_name(str): the trajectory collection_name to be used for geting experiences.
experience_collection_name(str): the experience collection_name to be used for saving experiences.
mode(str): how to generate experiences
"""
if mode == "single_task_summary":
await self._single_task_summary(
trajectory_collection_name=trajectory_collection_name,
experience_collection_name=experience_collection_name,
)
else:
pass # TODO:add other methods to generate experiences from trajectories
class AddNewTrajectories(Action):
"""Record the execution status of each task as a trajectory and store it."""
name: str = "AddNewTrajectories"
def _init_engine(self, collection_name: str):
"""Initialize a collection for storing code experiences"""
"""Initialize a collection for storing code experiences."""
engine = SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
retriever_configs=[ChromaRetrieverConfig(persist_path=PERSIST_PATH, collection_name=collection_name)],
)
# due to an irrelevant record being added to the vector database when loading from SimpleEngine.from_docs(), it is necessary to remove it.
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
chroma_collection = db.get_or_create_collection(collection_name)
chroma_collection.delete(where_document={"$contains": "Bob likes traveling"})
chroma_collection = db.get_or_create_collection(collection_name) # get chromadb collection
chroma_collection.delete(where_document={"$contains": "Bob likes traveling"}) # delete the irrelevant record
return engine
async def run(self, planner: Planner, trajectory_collection_name: str = TRAJECTORY_COLLECTION_NAME):
"""
Initiate a collection and add new trajectories to the collection
"""
"""Initiate a collection and add new trajectories to the collection."""
engine = self._init_engine(trajectory_collection_name)
if not planner.plan.tasks:
@ -157,14 +76,110 @@ class AddNewTrajectories(Action):
engine.add_objs(trajectories)
class AddNewExperiences(Action):
"""Retrieve the trajectories from the vector database where trajectories are stored,
compare and summarize them to form experiences, and then store these experiences in the vector database.
"""
name: str = "AddNewTaskExperiences"
def _init_engine(self, collection_name: str):
"""Initialize a collection for storing code experiences."""
engine = SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
retriever_configs=[ChromaRetrieverConfig(persist_path=PERSIST_PATH, collection_name=collection_name)],
)
# due to an irrelevant record being added to the vector database when loading from SimpleEngine.from_docs(), it is necessary to remove it.
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
chroma_collection = db.get_or_create_collection(collection_name)
chroma_collection.delete(where_document={"$contains": "Bob likes traveling"})
return engine
async def _single_task_summary(self, trajectory_collection_name: str, experience_collection_name: str):
trajectory_engine = self._init_engine(collection_name=trajectory_collection_name)
experience_engine = self._init_engine(collection_name=experience_collection_name)
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
collection = db.get_or_create_collection(trajectory_collection_name)
# get the ids of all trajectories where the is_used attribute is false.
unused_ids = [
id
for id in collection.get()["ids"] # collection.get()["ids"] will get all the ids in the collection
if json.loads(collection.get([id])["metadatas"][0]["obj_json"])["is_used"]
== False # Check if the is_used attribute of the trajectory corresponding to the given id is false.
]
trajectory_dicts = [
json.loads(metadata["obj_json"]) for metadata in collection.get(unused_ids)["metadatas"]
] # get the trajectory in dictionary format
trajectories = []
experiences = []
for trajectory_dict in trajectory_dicts:
# set the is_used attribute of the trajectory to true and create a new trajectory (the old trajectory will be deleted below).
trajectory_dict["is_used"] = True
trajectory = Trajectory(**trajectory_dict)
trajectories.append(trajectory)
# summarize the trajectory using LLM and assemble it into a single experience
code_summary = await self.task_code_sumarization(trajectory)
experience = Experience(code_summary=code_summary, trajectory=trajectory)
experiences.append(experience)
collection.delete(unused_ids) # delete the old trajectories
trajectory_engine.add_objs(trajectories)
experience_engine.add_objs(experiences)
async def task_code_sumarization(self, trajectory: Trajectory):
"""use LLM to summarize the task code.
Args:
trajectory: The trajectory to be summarized.
Returns:
A summary of the trajectory's code.
"""
task = trajectory.task
prompt = TASK_CODE_DESCRIPTION_PROMPT.format(
code_snippet=task.code, code_result=task.result, code_success="Success" if task.is_success else "Failure"
)
resp = await self._aask(prompt=prompt)
return resp
async def run(
self,
trajectory_collection_name: str = TRAJECTORY_COLLECTION_NAME,
experience_collection_name: str = EXPERIENCE_COLLECTION_NAME,
mode: str = "single_task_summary",
):
"""Initiate a collection and Add a new task experience to the collection.
Args:
trajectory_collection_name(str): the trajectory collection_name to be used for geting experiences.
experience_collection_name(str): the experience collection_name to be used for saving experiences.
mode(str): how to generate experiences.
"""
if mode == "single_task_summary":
await self._single_task_summary(
trajectory_collection_name=trajectory_collection_name,
experience_collection_name=experience_collection_name,
)
else:
pass # TODO:add other methods to generate experiences from trajectories.
class RetrieveExperiences(Action):
"""Retrieve the most relevant experience from the vector database based on the input task."""
name: str = "RetrieveExperiences"
def _init_engine(self, collection_name: str, top_k: int):
"""Initialize a SimpleEngine for retrieving experiences
"""Initialize a SimpleEngine for retrieving experiences.
Args:
query (str): The chromadb collectin_name
query (str): The chromadb collectin_name.
top_k (int): The number of eperiences to be retrieved.
"""
@ -176,7 +191,7 @@ class RetrieveExperiences(Action):
)
],
)
# due to an irrelevant record being added to the vector database when loading from SimpleEngine.from_docs(), it is necessary to remove it.
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
chroma_collection = db.get_or_create_collection(collection_name)
chroma_collection.delete(where_document={"$contains": "Bob likes traveling"})
@ -190,7 +205,7 @@ class RetrieveExperiences(Action):
Args:
query (str): The task instruction to be used for retrieval.
experience_collection_name(str): the collextion_name for retrieving experiences
experience_collection_name(str): the collextion_name for retrieving experiences.
top_k (int, optional): The number of experiences to be retrieved. Defaults to 5.
Returns:

View file

@ -7,6 +7,7 @@ from pydantic import Field, model_validator
from metagpt.actions.di.ask_review import ReviewConst
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
from metagpt.actions.di.use_experience import AddNewTrajectories, RetrieveExperiences
from metagpt.actions.di.write_analysis_code import CheckData, WriteAnalysisCode
from metagpt.logs import logger
from metagpt.prompts.di.write_analysis_code import DATA_INFO
@ -89,14 +90,19 @@ class DataInterpreter(Role):
async def _plan_and_act(self) -> Message:
try:
rsp = await super()._plan_and_act()
await AddNewTrajectories().run(
self.planner
) # extract trajectories based on the execution status of each task in the planner
await self.execute_code.terminate()
return rsp
except Exception as e:
await self.execute_code.terminate()
raise e
async def _act_on_task(self, current_task: Task, experiences: str) -> TaskResult:
async def _act_on_task(self, current_task: Task) -> TaskResult:
"""Useful in 'plan_and_act' mode. Wrap the output in a TaskResult for review and confirmation."""
# retrieve past tasks for this task
experiences = await RetrieveExperiences().run(query=current_task.instruction) if self.use_experience else ""
code, result, is_success = await self._write_and_exec_code(experiences=experiences)
task_result = TaskResult(code=code, result=result, is_success=is_success)
return task_result

View file

@ -30,7 +30,6 @@ from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validat
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.actions.di.use_experience import AddNewTrajectories, RetrieveExperiences
from metagpt.context_mixin import ContextMixin
from metagpt.logs import logger
from metagpt.memory import Memory
@ -491,17 +490,12 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
task = self.planner.current_task
logger.info(f"ready to take on task {task}")
# retrieve past tasks for this task
experiences = await RetrieveExperiences().run(query=task.instruction) if self.use_experience else ""
# take on current task
task_result = await self._act_on_task(task, experiences)
task_result = await self._act_on_task(task)
# process the result, such as reviewing, confirming, plan updating
await self.planner.process_task_result(task_result)
await AddNewTrajectories().run(self.planner)
rsp = self.planner.get_useful_memories()[0] # return the completed plan as a response
self.rc.memory.add(rsp) # add to persistent memory