pre-commit

This commit is contained in:
luxiangtao 2024-04-15 17:50:11 +08:00
parent e9603c6f03
commit fd6eb8882e
4 changed files with 67 additions and 61 deletions

View file

@ -1,16 +1,17 @@
import json
import chromadb
import chromadb
from pydantic import BaseModel
from examples.rag_pipeline import TRAVEL_DOC_PATH
from metagpt.actions import Action
from metagpt.const import SERDESER_PATH
from metagpt.logs import logger
from metagpt.prompts.di.get_task_summary import TASK_CODE_DESCRIPTION_PROMPT
from metagpt.schema import Task
from metagpt.strategy.planner import Planner
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import ChromaRetrieverConfig
from examples.rag_pipeline import TRAVEL_DOC_PATH
from metagpt.schema import Task
from metagpt.strategy.planner import Planner
class Trajectory(BaseModel):
@ -22,7 +23,8 @@ class Trajectory(BaseModel):
def rag_key(self) -> str:
"""For search"""
return self.task.instruction
class Experience(BaseModel):
code_summary: str = ""
trajectory: Trajectory = None
@ -31,6 +33,7 @@ class Experience(BaseModel):
"""For search"""
return self.code_summary
EXPERIENCE_COLLECTION_NAME = "di_experience_0"
TRAJECTORY_COLLECTION_NAME = "di_trajectory_0"
PERSIST_PATH = SERDESER_PATH / "data_interpreter/chroma"
@ -40,14 +43,11 @@ class AddNewExperiences(Action):
name: str = "AddNewTaskExperiences"
def _init_engine(self, collection_name: str):
"""Initialize a collection for storing code experiences
"""
"""Initialize a collection for storing code experiences"""
engine = SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
retriever_configs=[ChromaRetrieverConfig(
persist_path=PERSIST_PATH,
collection_name=collection_name)],
retriever_configs=[ChromaRetrieverConfig(persist_path=PERSIST_PATH, collection_name=collection_name)],
)
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
@ -55,20 +55,23 @@ class AddNewExperiences(Action):
chroma_collection.delete(where_document={"$contains": "Bob likes traveling"})
return engine
async def _single_task_summary(self,trajectory_collection_name: str,experience_collection_name: str):
async def _single_task_summary(self, trajectory_collection_name: str, experience_collection_name: str):
trajectory_engine = self._init_engine(collection_name=trajectory_collection_name)
experience_engine = self._init_engine(collection_name=experience_collection_name)
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
collection = db.get_or_create_collection(trajectory_collection_name)
unused_ids=[id for id in collection.get()["ids"] if json.loads(collection.get([id])["metadatas"][0]["obj_json"])["is_used"]==False]
unused_ids = [
id
for id in collection.get()["ids"]
if json.loads(collection.get([id])["metadatas"][0]["obj_json"])["is_used"] == False
]
trajectory_dicts = [json.loads(metadata["obj_json"]) for metadata in collection.get(unused_ids)["metadatas"]]
trajectories = []
experiences = []
for trajectory_dict in trajectory_dicts:
trajectory_dict["is_used"] = True
trajectory = Trajectory(**trajectory_dict)
trajectories.append(trajectory)
@ -76,12 +79,11 @@ class AddNewExperiences(Action):
code_summary = await self.task_code_sumarization(trajectory)
experience = Experience(code_summary=code_summary, trajectory=trajectory)
experiences.append(experience)
collection.delete(unused_ids)
trajectory_engine.add_objs(trajectories)
experience_engine.add_objs(experiences)
async def task_code_sumarization(self, trajectory: Trajectory):
"""Summarize the task code
Args:
@ -90,16 +92,18 @@ class AddNewExperiences(Action):
A summary of the task code.
"""
task = trajectory.task
prompt = TASK_CODE_DESCRIPTION_PROMPT.format(code_snippet=task.code, code_result=task.result,
code_success="Success" if task.is_success else "Failure")
prompt = TASK_CODE_DESCRIPTION_PROMPT.format(
code_snippet=task.code, code_result=task.result, code_success="Success" if task.is_success else "Failure"
)
resp = await self._aask(prompt=prompt)
return resp
async def run(self,
trajectory_collection_name: str=TRAJECTORY_COLLECTION_NAME,
experience_collection_name: str=EXPERIENCE_COLLECTION_NAME,
mode :str = "single_task_summary"
):
async def run(
self,
trajectory_collection_name: str = TRAJECTORY_COLLECTION_NAME,
experience_collection_name: str = EXPERIENCE_COLLECTION_NAME,
mode: str = "single_task_summary",
):
"""Initiate a collection and Add a new task experience to the collection
Args:
@ -109,22 +113,23 @@ class AddNewExperiences(Action):
"""
if mode == "single_task_summary":
await self._single_task_summary(trajectory_collection_name=trajectory_collection_name,experience_collection_name=experience_collection_name)
await self._single_task_summary(
trajectory_collection_name=trajectory_collection_name,
experience_collection_name=experience_collection_name,
)
else:
pass # TODO:add other methods to generate experiences from trajectories
pass # TODO:add other methods to generate experiences from trajectories
class AddNewTrajectories(Action):
name: str = "AddNewTrajectories"
def _init_engine(self,collection_name: str):
"""Initialize a collection for storing code experiences
"""
def _init_engine(self, collection_name: str):
"""Initialize a collection for storing code experiences"""
engine = SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
retriever_configs=[ChromaRetrieverConfig(
persist_path=PERSIST_PATH,
collection_name=collection_name)],
retriever_configs=[ChromaRetrieverConfig(persist_path=PERSIST_PATH, collection_name=collection_name)],
)
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
@ -132,8 +137,8 @@ class AddNewTrajectories(Action):
chroma_collection.delete(where_document={"$contains": "Bob likes traveling"})
return engine
async def run(self, planner: Planner, trajectory_collection_name: str=TRAJECTORY_COLLECTION_NAME):
async def run(self, planner: Planner, trajectory_collection_name: str = TRAJECTORY_COLLECTION_NAME):
"""
Initiate a collection and add new trajectories to the collection
"""
@ -141,19 +146,21 @@ class AddNewTrajectories(Action):
if not planner.plan.tasks:
return
user_requirement = planner.plan.goal
task_map = planner.plan.task_map
trajectories = [Trajectory(user_requirement=user_requirement, task_map=task_map, task=task, is_used=False) for task in planner.plan.tasks]
engine.add_objs(trajectories)
trajectories = [
Trajectory(user_requirement=user_requirement, task_map=task_map, task=task, is_used=False)
for task in planner.plan.tasks
]
engine.add_objs(trajectories)
class RetrieveExperiences(Action):
name: str = "RetrieveExperiences"
def _init_engine(self,collection_name: str,top_k: int):
def _init_engine(self, collection_name: str, top_k: int):
"""Initialize a SimpleEngine for retrieving experiences
Args:
@ -163,10 +170,11 @@ class RetrieveExperiences(Action):
engine = SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
retriever_configs=[ChromaRetrieverConfig(
persist_path=PERSIST_PATH,
collection_name=collection_name,
similarity_top_k=top_k)],
retriever_configs=[
ChromaRetrieverConfig(
persist_path=PERSIST_PATH, collection_name=collection_name, similarity_top_k=top_k
)
],
)
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
@ -175,7 +183,9 @@ class RetrieveExperiences(Action):
return engine
async def run(self, query: str,experience_collection_name: str=EXPERIENCE_COLLECTION_NAME, top_k: int = 5) -> str:
async def run(
self, query: str, experience_collection_name: str = EXPERIENCE_COLLECTION_NAME, top_k: int = 5
) -> str:
"""Retrieve past attempted tasks
Args:
@ -186,7 +196,7 @@ class RetrieveExperiences(Action):
Returns:
_type_: _description_
"""
engine = self._init_engine(collection_name=experience_collection_name,top_k=top_k)
engine = self._init_engine(collection_name=experience_collection_name, top_k=top_k)
if len(query) <= 2: # not "" or not '""'
return ""
@ -208,7 +218,7 @@ class RetrieveExperiences(Action):
"Code summary": code_summary,
"Task result": trajectory.task.result,
"Task outcome": "Success" if trajectory.task.is_success else "Failure",
"Task ownership's requirement": "This task is part of " + trajectory.user_requirement
"Task ownership's requirement": "This task is part of " + trajectory.user_requirement,
}
# Replace the placeholder in the keys

View file

@ -7,4 +7,4 @@ Code Execution Result:
{code_result}
Code Success or Failure:
{code_success}
"""
"""

View file

@ -123,7 +123,9 @@ class DataInterpreter(Role):
while not success and counter < max_retry:
### write code ###
code, cause_by = await self._write_code(counter, plan_status, tool_info, experiences = experiences if counter == 0 else "")
code, cause_by = await self._write_code(
counter, plan_status, tool_info, experiences=experiences if counter == 0 else ""
)
self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by))
@ -144,13 +146,7 @@ class DataInterpreter(Role):
return code, result, success
async def _write_code(
self,
counter: int,
plan_status: str = "",
tool_info: str = "",
experiences: str = ""
):
async def _write_code(self, counter: int, plan_status: str = "", tool_info: str = "", experiences: str = ""):
todo = self.rc.todo # todo is WriteAnalysisCode
logger.info(f"ready to {todo.name}")
use_reflection = counter > 0 and self.use_reflection # only use reflection after the first trial
@ -163,7 +159,7 @@ class DataInterpreter(Role):
tool_info=tool_info,
working_memory=self.working_memory.get(),
use_reflection=use_reflection,
experiences = experiences
experiences=experiences,
)
return code, todo

View file

@ -30,12 +30,12 @@ from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validat
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.actions.di.use_experience import RetrieveExperiences, AddNewTrajectories
from metagpt.actions.di.use_experience import AddNewTrajectories, RetrieveExperiences
from metagpt.context_mixin import ContextMixin
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.provider import HumanProvider
from metagpt.schema import Message, MessageQueue, SerializationMixin, TaskResult, Task
from metagpt.schema import Message, MessageQueue, SerializationMixin, Task, TaskResult
from metagpt.strategy.planner import Planner
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
from metagpt.utils.project_repo import ProjectRepo
@ -493,13 +493,13 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
# retrieve past tasks for this task
experiences = await RetrieveExperiences().run(query=task.instruction) if self.use_experience else ""
# take on current task
task_result = await self._act_on_task(task, experiences)
# process the result, such as reviewing, confirming, plan updating
await self.planner.process_task_result(task_result)
await AddNewTrajectories().run(self.planner)
rsp = self.planner.get_useful_memories()[0] # return the completed plan as a response