add experience

This commit is contained in:
luxiangtao 2024-04-15 17:46:14 +08:00
parent bf77d2a339
commit e9603c6f03
6 changed files with 250 additions and 6 deletions

View file

@ -0,0 +1,219 @@
import json
import chromadb
from pydantic import BaseModel
from metagpt.actions import Action
from metagpt.const import SERDESER_PATH
from metagpt.logs import logger
from metagpt.prompts.di.get_task_summary import TASK_CODE_DESCRIPTION_PROMPT
from metagpt.schema import Task
from metagpt.strategy.planner import Planner
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import ChromaRetrieverConfig
from examples.rag_pipeline import TRAVEL_DOC_PATH
class Trajectory(BaseModel):
    """One plan task bundled with the plan context it belongs to.

    ``rag_key`` reads ``task.instruction``, so ``task`` must be set before the
    object is indexed; with the default ``None`` the call raises
    ``AttributeError``.
    """

    # Requirement text of the plan this task is part of.
    user_requirement: str = ""
    # Mapping of string keys to the tasks of the plan.
    task_map: dict[str, Task] = {}
    # The task this trajectory records.
    task: Task = None
    # Marks whether the trajectory has already been consumed.
    is_used: bool = False

    def rag_key(self) -> str:
        """Return the text this object is searched by: the task instruction."""
        recorded_task = self.task
        return recorded_task.instruction
class Experience(BaseModel):
    """A code summary paired with the trajectory it was produced from."""

    # Text summary of the trajectory's task code; also the search key.
    code_summary: str = ""
    # The trajectory that was summarised.
    trajectory: Trajectory = None

    def rag_key(self) -> str:
        """Return the text this object is searched by: the code summary."""
        summary_text = self.code_summary
        return summary_text
# Default chroma collection name used for Experience objects.
EXPERIENCE_COLLECTION_NAME = "di_experience_0"
# Default chroma collection name used for Trajectory objects.
TRAJECTORY_COLLECTION_NAME = "di_trajectory_0"
# Directory of the persistent chroma store shared by both collections.
PERSIST_PATH = SERDESER_PATH / "data_interpreter/chroma"
class AddNewExperiences(Action):
    """Summarise stored, still-unused trajectories and save them as experiences."""

    name: str = "AddNewTaskExperiences"

    def _init_engine(self, collection_name: str):
        """Build a SimpleEngine backed by the persistent chroma collection.

        The engine is created from ``TRAVEL_DOC_PATH``; afterwards every
        document containing "Bob likes traveling" is deleted from the
        collection again.
        NOTE(review): presumably the travel document only serves to bootstrap
        the engine and must not remain in the store -- confirm.

        Args:
            collection_name: Name of the chroma collection to open or create.

        Returns:
            The initialised engine.
        """
        engine = SimpleEngine.from_docs(
            input_files=[TRAVEL_DOC_PATH],
            retriever_configs=[
                ChromaRetrieverConfig(
                    persist_path=PERSIST_PATH,
                    collection_name=collection_name,
                )
            ],
        )
        db = chromadb.PersistentClient(path=str(PERSIST_PATH))
        chroma_collection = db.get_or_create_collection(collection_name)
        chroma_collection.delete(where_document={"$contains": "Bob likes traveling"})
        return engine

    async def _single_task_summary(self, trajectory_collection_name: str, experience_collection_name: str):
        """Turn every unused trajectory into one experience.

        Each trajectory whose ``is_used`` flag is false is marked as used,
        summarised, removed from the trajectory collection and re-added in its
        updated form; the resulting experience is added to the experience
        collection.

        Args:
            trajectory_collection_name: Collection the trajectories are read from.
            experience_collection_name: Collection the experiences are written to.
        """
        trajectory_engine = self._init_engine(collection_name=trajectory_collection_name)
        experience_engine = self._init_engine(collection_name=experience_collection_name)

        db = chromadb.PersistentClient(path=str(PERSIST_PATH))
        collection = db.get_or_create_collection(trajectory_collection_name)

        # Read the collection once instead of issuing one query per record.
        records = collection.get()
        unused_ids = []
        trajectory_dicts = []
        for record_id, metadata in zip(records["ids"], records["metadatas"]):
            trajectory_dict = json.loads(metadata["obj_json"])
            if not trajectory_dict["is_used"]:
                unused_ids.append(record_id)
                trajectory_dicts.append(trajectory_dict)

        # Nothing to do; also avoids calling chroma with an empty id list.
        if not unused_ids:
            return

        trajectories = []
        experiences = []
        for trajectory_dict in trajectory_dicts:
            trajectory_dict["is_used"] = True
            trajectory = Trajectory(**trajectory_dict)
            trajectories.append(trajectory)
            code_summary = await self.task_code_sumarization(trajectory)
            experiences.append(Experience(code_summary=code_summary, trajectory=trajectory))

        # Replace the stale records only after every summary succeeded.
        collection.delete(ids=unused_ids)
        trajectory_engine.add_objs(trajectories)
        experience_engine.add_objs(experiences)

    async def task_code_sumarization(self, trajectory: Trajectory) -> str:
        """Ask the LLM to summarise the code of a trajectory's task.

        Args:
            trajectory: Trajectory whose task code, result and success flag
                are inserted into the prompt.

        Returns:
            The LLM response.
        """
        task = trajectory.task
        prompt = TASK_CODE_DESCRIPTION_PROMPT.format(
            code_snippet=task.code,
            code_result=task.result,
            code_success="Success" if task.is_success else "Failure",
        )
        return await self._aask(prompt=prompt)

    async def run(
        self,
        trajectory_collection_name: str = TRAJECTORY_COLLECTION_NAME,
        experience_collection_name: str = EXPERIENCE_COLLECTION_NAME,
        mode: str = "single_task_summary",
    ):
        """Generate experiences from the stored trajectories.

        Args:
            trajectory_collection_name: Collection the trajectories are read from.
            experience_collection_name: Collection the experiences are saved to.
            mode: How experiences are generated. Only "single_task_summary"
                is implemented; any other value is logged and ignored.
        """
        if mode == "single_task_summary":
            await self._single_task_summary(
                trajectory_collection_name=trajectory_collection_name,
                experience_collection_name=experience_collection_name,
            )
        else:
            # TODO: add other methods to generate experiences from trajectories
            logger.warning(f"Unsupported experience generation mode: {mode!r}; nothing was generated")
class AddNewTrajectories(Action):
    """Store every task of a planner's plan as a trajectory."""

    name: str = "AddNewTrajectories"

    def _init_engine(self, collection_name: str):
        """Create a SimpleEngine on top of the persistent chroma collection.

        The engine is built from ``TRAVEL_DOC_PATH``, after which all documents
        containing "Bob likes traveling" are deleted from the collection.
        """
        chroma_config = ChromaRetrieverConfig(
            persist_path=PERSIST_PATH,
            collection_name=collection_name,
        )
        engine = SimpleEngine.from_docs(
            input_files=[TRAVEL_DOC_PATH],
            retriever_configs=[chroma_config],
        )
        client = chromadb.PersistentClient(path=str(PERSIST_PATH))
        target_collection = client.get_or_create_collection(collection_name)
        target_collection.delete(where_document={"$contains": "Bob likes traveling"})
        return engine

    async def run(self, planner: Planner, trajectory_collection_name: str = TRAJECTORY_COLLECTION_NAME):
        """Add one trajectory per task of ``planner.plan`` to the collection.

        The engine is initialised first; if the plan has no tasks the call
        returns without adding anything.
        """
        engine = self._init_engine(trajectory_collection_name)
        plan = planner.plan
        if not plan.tasks:
            return

        goal = plan.goal
        plan_task_map = plan.task_map
        trajectories = []
        for plan_task in plan.tasks:
            # Every new trajectory starts out as not yet used.
            trajectories.append(
                Trajectory(
                    user_requirement=goal,
                    task_map=plan_task_map,
                    task=plan_task,
                    is_used=False,
                )
            )
        engine.add_objs(trajectories)
class RetrieveExperiences(Action):
    """Look up stored experiences that are similar to a task instruction."""

    name: str = "RetrieveExperiences"

    def _init_engine(self, collection_name: str, top_k: int):
        """Initialize a SimpleEngine for retrieving experiences.

        The engine is built from ``TRAVEL_DOC_PATH``; afterwards all documents
        containing "Bob likes traveling" are deleted from the collection.

        Args:
            collection_name: The chroma collection name.
            top_k: The number of experiences to be retrieved.

        Returns:
            The initialised engine.
        """
        engine = SimpleEngine.from_docs(
            input_files=[TRAVEL_DOC_PATH],
            retriever_configs=[
                ChromaRetrieverConfig(
                    persist_path=PERSIST_PATH,
                    collection_name=collection_name,
                    similarity_top_k=top_k,
                )
            ],
        )
        db = chromadb.PersistentClient(path=str(PERSIST_PATH))
        chroma_collection = db.get_or_create_collection(collection_name)
        chroma_collection.delete(where_document={"$contains": "Bob likes traveling"})
        return engine

    async def run(self, query: str, experience_collection_name: str = EXPERIENCE_COLLECTION_NAME, top_k: int = 5) -> str:
        """Retrieve past attempted tasks.

        Args:
            query: The task instruction to be used for retrieval.
            experience_collection_name: The collection to retrieve experiences from.
            top_k: The number of experiences to be retrieved. Defaults to 5.

        Returns:
            An empty string when the query has at most two characters,
            otherwise a JSON array (indent 4) with one object per retrieved
            experience.
        """
        # Check the query first so a skipped lookup never builds the engine.
        if len(query) <= 2:  # not "" or not '""'
            return ""

        engine = self._init_engine(collection_name=experience_collection_name, top_k=top_k)
        nodes = await engine.aretrieve(query)

        new_experiences = []
        for i, node in enumerate(nodes):
            try:
                stored_obj = node.node.metadata["obj"]
                code_summary = stored_obj.code_summary
                trajectory = stored_obj.trajectory
            except (KeyError, AttributeError):
                # The node does not carry an experience object; skip it.
                continue

            task = trajectory.task
            new_experiences.append(
                {
                    # Keeps the retrieval position, including skipped nodes.
                    f"Reference {i}": task.instruction,
                    "Task code": task.code,
                    "Code summary": code_summary,
                    "Task result": task.result,
                    "Task outcome": "Success" if task.is_success else "Failure",
                    "Task ownership's requirement": "This task is part of " + trajectory.user_requirement,
                }
            )

        logger.info("retrieval done")
        return json.dumps(new_experiences, indent=4)

View file

@ -41,11 +41,13 @@ class WriteAnalysisCode(Action):
tool_info: str = "",
working_memory: list[Message] = None,
use_reflection: bool = False,
experiences: str = "",
**kwargs,
) -> str:
structual_prompt = STRUCTUAL_PROMPT.format(
user_requirement=user_requirement,
plan_status=plan_status,
experiences=experiences,
tool_info=tool_info,
)

View file

@ -0,0 +1,10 @@
# Prompt template asking for a one-paragraph explanation of a code snippet.
# Placeholders: {code_snippet}, {code_result}, {code_success}.
TASK_CODE_DESCRIPTION_PROMPT = """
Please explain in a paragraph what the following code snippet does. Only the function of the code snippet needs to be explained, no variable names need to be explained.
Code snippet:
{code_snippet}
Code Execution Result:
{code_result}
Code Success or Failure:
{code_success}
"""

View file

@ -7,6 +7,10 @@ STRUCTUAL_PROMPT = """
# Plan Status
{plan_status}
# Reference experience (can be empty):
This is some previous coding experience that is similar to the current task. You can learn from the successful code and avoid the mistakes from the failed code. If there are other codes you don't know about in the experience, please don't refer to it.
{experiences}
# Tool Info
{tool_info}

View file

@ -38,6 +38,7 @@ class DataInterpreter(Role):
auto_run: bool = True
use_plan: bool = True
use_reflection: bool = False
use_experience: bool = False
execute_code: ExecuteNbCode = Field(default_factory=ExecuteNbCode, exclude=True)
tools: list[str] = [] # Use special symbol ["<all>"] to indicate use of all registered tools
tool_recommender: ToolRecommender = None
@ -94,13 +95,13 @@ class DataInterpreter(Role):
await self.execute_code.terminate()
raise e
async def _act_on_task(self, current_task: Task) -> TaskResult:
async def _act_on_task(self, current_task: Task, experiences: str) -> TaskResult:
"""Useful in 'plan_and_act' mode. Wrap the output in a TaskResult for review and confirmation."""
code, result, is_success = await self._write_and_exec_code()
code, result, is_success = await self._write_and_exec_code(experiences=experiences)
task_result = TaskResult(code=code, result=result, is_success=is_success)
return task_result
async def _write_and_exec_code(self, max_retry: int = 3):
async def _write_and_exec_code(self, max_retry: int = 3, experiences: str = ""):
counter = 0
success = False
@ -122,7 +123,7 @@ class DataInterpreter(Role):
while not success and counter < max_retry:
### write code ###
code, cause_by = await self._write_code(counter, plan_status, tool_info)
code, cause_by = await self._write_code(counter, plan_status, tool_info, experiences = experiences if counter == 0 else "")
self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by))
@ -148,6 +149,7 @@ class DataInterpreter(Role):
counter: int,
plan_status: str = "",
tool_info: str = "",
experiences: str = ""
):
todo = self.rc.todo # todo is WriteAnalysisCode
logger.info(f"ready to {todo.name}")
@ -161,6 +163,7 @@ class DataInterpreter(Role):
tool_info=tool_info,
working_memory=self.working_memory.get(),
use_reflection=use_reflection,
experiences = experiences
)
return code, todo

View file

@ -30,11 +30,12 @@ from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validat
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.actions.di.use_experience import RetrieveExperiences, AddNewTrajectories
from metagpt.context_mixin import ContextMixin
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.provider import HumanProvider
from metagpt.schema import Message, MessageQueue, SerializationMixin
from metagpt.schema import Message, MessageQueue, SerializationMixin, TaskResult, Task
from metagpt.strategy.planner import Planner
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
from metagpt.utils.project_repo import ProjectRepo
@ -490,11 +491,16 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
task = self.planner.current_task
logger.info(f"ready to take on task {task}")
# retrieve past tasks for this task
experiences = await RetrieveExperiences().run(query=task.instruction) if self.use_experience else ""
# take on current task
task_result = await self._act_on_task(task)
task_result = await self._act_on_task(task, experiences)
# process the result, such as reviewing, confirming, plan updating
await self.planner.process_task_result(task_result)
await AddNewTrajectories().run(self.planner)
rsp = self.planner.get_useful_memories()[0] # return the completed plan as a response