mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-03 04:42:38 +02:00
pre-commit
This commit is contained in:
parent
e9603c6f03
commit
fd6eb8882e
4 changed files with 67 additions and 61 deletions
|
|
@ -1,16 +1,17 @@
|
|||
import json
|
||||
import chromadb
|
||||
|
||||
import chromadb
|
||||
from pydantic import BaseModel
|
||||
|
||||
from examples.rag_pipeline import TRAVEL_DOC_PATH
|
||||
from metagpt.actions import Action
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.prompts.di.get_task_summary import TASK_CODE_DESCRIPTION_PROMPT
|
||||
from metagpt.schema import Task
|
||||
from metagpt.strategy.planner import Planner
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.schema import ChromaRetrieverConfig
|
||||
from examples.rag_pipeline import TRAVEL_DOC_PATH
|
||||
from metagpt.schema import Task
|
||||
from metagpt.strategy.planner import Planner
|
||||
|
||||
|
||||
class Trajectory(BaseModel):
|
||||
|
|
@ -22,7 +23,8 @@ class Trajectory(BaseModel):
|
|||
def rag_key(self) -> str:
|
||||
"""For search"""
|
||||
return self.task.instruction
|
||||
|
||||
|
||||
|
||||
class Experience(BaseModel):
|
||||
code_summary: str = ""
|
||||
trajectory: Trajectory = None
|
||||
|
|
@ -31,6 +33,7 @@ class Experience(BaseModel):
|
|||
"""For search"""
|
||||
return self.code_summary
|
||||
|
||||
|
||||
EXPERIENCE_COLLECTION_NAME = "di_experience_0"
|
||||
TRAJECTORY_COLLECTION_NAME = "di_trajectory_0"
|
||||
PERSIST_PATH = SERDESER_PATH / "data_interpreter/chroma"
|
||||
|
|
@ -40,14 +43,11 @@ class AddNewExperiences(Action):
|
|||
name: str = "AddNewTaskExperiences"
|
||||
|
||||
def _init_engine(self, collection_name: str):
|
||||
"""Initialize a collection for storing code experiences
|
||||
"""
|
||||
"""Initialize a collection for storing code experiences"""
|
||||
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_files=[TRAVEL_DOC_PATH],
|
||||
retriever_configs=[ChromaRetrieverConfig(
|
||||
persist_path=PERSIST_PATH,
|
||||
collection_name=collection_name)],
|
||||
retriever_configs=[ChromaRetrieverConfig(persist_path=PERSIST_PATH, collection_name=collection_name)],
|
||||
)
|
||||
|
||||
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
|
||||
|
|
@ -55,20 +55,23 @@ class AddNewExperiences(Action):
|
|||
chroma_collection.delete(where_document={"$contains": "Bob likes traveling"})
|
||||
|
||||
return engine
|
||||
|
||||
async def _single_task_summary(self,trajectory_collection_name: str,experience_collection_name: str):
|
||||
|
||||
async def _single_task_summary(self, trajectory_collection_name: str, experience_collection_name: str):
|
||||
trajectory_engine = self._init_engine(collection_name=trajectory_collection_name)
|
||||
experience_engine = self._init_engine(collection_name=experience_collection_name)
|
||||
|
||||
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
|
||||
collection = db.get_or_create_collection(trajectory_collection_name)
|
||||
|
||||
unused_ids=[id for id in collection.get()["ids"] if json.loads(collection.get([id])["metadatas"][0]["obj_json"])["is_used"]==False]
|
||||
|
||||
unused_ids = [
|
||||
id
|
||||
for id in collection.get()["ids"]
|
||||
if json.loads(collection.get([id])["metadatas"][0]["obj_json"])["is_used"] == False
|
||||
]
|
||||
trajectory_dicts = [json.loads(metadata["obj_json"]) for metadata in collection.get(unused_ids)["metadatas"]]
|
||||
trajectories = []
|
||||
experiences = []
|
||||
for trajectory_dict in trajectory_dicts:
|
||||
|
||||
trajectory_dict["is_used"] = True
|
||||
trajectory = Trajectory(**trajectory_dict)
|
||||
trajectories.append(trajectory)
|
||||
|
|
@ -76,12 +79,11 @@ class AddNewExperiences(Action):
|
|||
code_summary = await self.task_code_sumarization(trajectory)
|
||||
experience = Experience(code_summary=code_summary, trajectory=trajectory)
|
||||
experiences.append(experience)
|
||||
|
||||
|
||||
collection.delete(unused_ids)
|
||||
trajectory_engine.add_objs(trajectories)
|
||||
experience_engine.add_objs(experiences)
|
||||
|
||||
|
||||
async def task_code_sumarization(self, trajectory: Trajectory):
|
||||
"""Summarize the task code
|
||||
Args:
|
||||
|
|
@ -90,16 +92,18 @@ class AddNewExperiences(Action):
|
|||
A summary of the task code.
|
||||
"""
|
||||
task = trajectory.task
|
||||
prompt = TASK_CODE_DESCRIPTION_PROMPT.format(code_snippet=task.code, code_result=task.result,
|
||||
code_success="Success" if task.is_success else "Failure")
|
||||
prompt = TASK_CODE_DESCRIPTION_PROMPT.format(
|
||||
code_snippet=task.code, code_result=task.result, code_success="Success" if task.is_success else "Failure"
|
||||
)
|
||||
resp = await self._aask(prompt=prompt)
|
||||
return resp
|
||||
|
||||
async def run(self,
|
||||
trajectory_collection_name: str=TRAJECTORY_COLLECTION_NAME,
|
||||
experience_collection_name: str=EXPERIENCE_COLLECTION_NAME,
|
||||
mode :str = "single_task_summary"
|
||||
):
|
||||
async def run(
|
||||
self,
|
||||
trajectory_collection_name: str = TRAJECTORY_COLLECTION_NAME,
|
||||
experience_collection_name: str = EXPERIENCE_COLLECTION_NAME,
|
||||
mode: str = "single_task_summary",
|
||||
):
|
||||
"""Initiate a collection and Add a new task experience to the collection
|
||||
|
||||
Args:
|
||||
|
|
@ -109,22 +113,23 @@ class AddNewExperiences(Action):
|
|||
|
||||
"""
|
||||
if mode == "single_task_summary":
|
||||
await self._single_task_summary(trajectory_collection_name=trajectory_collection_name,experience_collection_name=experience_collection_name)
|
||||
await self._single_task_summary(
|
||||
trajectory_collection_name=trajectory_collection_name,
|
||||
experience_collection_name=experience_collection_name,
|
||||
)
|
||||
else:
|
||||
pass # TODO:add other methods to generate experiences from trajectories
|
||||
|
||||
pass # TODO:add other methods to generate experiences from trajectories
|
||||
|
||||
|
||||
class AddNewTrajectories(Action):
|
||||
name: str = "AddNewTrajectories"
|
||||
|
||||
def _init_engine(self,collection_name: str):
|
||||
"""Initialize a collection for storing code experiences
|
||||
"""
|
||||
def _init_engine(self, collection_name: str):
|
||||
"""Initialize a collection for storing code experiences"""
|
||||
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_files=[TRAVEL_DOC_PATH],
|
||||
retriever_configs=[ChromaRetrieverConfig(
|
||||
persist_path=PERSIST_PATH,
|
||||
collection_name=collection_name)],
|
||||
retriever_configs=[ChromaRetrieverConfig(persist_path=PERSIST_PATH, collection_name=collection_name)],
|
||||
)
|
||||
|
||||
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
|
||||
|
|
@ -132,8 +137,8 @@ class AddNewTrajectories(Action):
|
|||
chroma_collection.delete(where_document={"$contains": "Bob likes traveling"})
|
||||
|
||||
return engine
|
||||
|
||||
async def run(self, planner: Planner, trajectory_collection_name: str=TRAJECTORY_COLLECTION_NAME):
|
||||
|
||||
async def run(self, planner: Planner, trajectory_collection_name: str = TRAJECTORY_COLLECTION_NAME):
|
||||
"""
|
||||
Initiate a collection and add new trajectories to the collection
|
||||
"""
|
||||
|
|
@ -141,19 +146,21 @@ class AddNewTrajectories(Action):
|
|||
|
||||
if not planner.plan.tasks:
|
||||
return
|
||||
|
||||
|
||||
user_requirement = planner.plan.goal
|
||||
task_map = planner.plan.task_map
|
||||
trajectories = [Trajectory(user_requirement=user_requirement, task_map=task_map, task=task, is_used=False) for task in planner.plan.tasks]
|
||||
|
||||
engine.add_objs(trajectories)
|
||||
trajectories = [
|
||||
Trajectory(user_requirement=user_requirement, task_map=task_map, task=task, is_used=False)
|
||||
for task in planner.plan.tasks
|
||||
]
|
||||
|
||||
engine.add_objs(trajectories)
|
||||
|
||||
|
||||
class RetrieveExperiences(Action):
|
||||
name: str = "RetrieveExperiences"
|
||||
|
||||
def _init_engine(self,collection_name: str,top_k: int):
|
||||
def _init_engine(self, collection_name: str, top_k: int):
|
||||
"""Initialize a SimpleEngine for retrieving experiences
|
||||
|
||||
Args:
|
||||
|
|
@ -163,10 +170,11 @@ class RetrieveExperiences(Action):
|
|||
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_files=[TRAVEL_DOC_PATH],
|
||||
retriever_configs=[ChromaRetrieverConfig(
|
||||
persist_path=PERSIST_PATH,
|
||||
collection_name=collection_name,
|
||||
similarity_top_k=top_k)],
|
||||
retriever_configs=[
|
||||
ChromaRetrieverConfig(
|
||||
persist_path=PERSIST_PATH, collection_name=collection_name, similarity_top_k=top_k
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
|
||||
|
|
@ -175,7 +183,9 @@ class RetrieveExperiences(Action):
|
|||
|
||||
return engine
|
||||
|
||||
async def run(self, query: str,experience_collection_name: str=EXPERIENCE_COLLECTION_NAME, top_k: int = 5) -> str:
|
||||
async def run(
|
||||
self, query: str, experience_collection_name: str = EXPERIENCE_COLLECTION_NAME, top_k: int = 5
|
||||
) -> str:
|
||||
"""Retrieve past attempted tasks
|
||||
|
||||
Args:
|
||||
|
|
@ -186,7 +196,7 @@ class RetrieveExperiences(Action):
|
|||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
engine = self._init_engine(collection_name=experience_collection_name,top_k=top_k)
|
||||
engine = self._init_engine(collection_name=experience_collection_name, top_k=top_k)
|
||||
|
||||
if len(query) <= 2: # not "" or not '""'
|
||||
return ""
|
||||
|
|
@ -208,7 +218,7 @@ class RetrieveExperiences(Action):
|
|||
"Code summary": code_summary,
|
||||
"Task result": trajectory.task.result,
|
||||
"Task outcome": "Success" if trajectory.task.is_success else "Failure",
|
||||
"Task ownership's requirement": "This task is part of " + trajectory.user_requirement
|
||||
"Task ownership's requirement": "This task is part of " + trajectory.user_requirement,
|
||||
}
|
||||
|
||||
# Replace the placeholder in the keys
|
||||
|
|
|
|||
|
|
@ -7,4 +7,4 @@ Code Execution Result:
|
|||
{code_result}
|
||||
Code Success or Failure:
|
||||
{code_success}
|
||||
"""
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -123,7 +123,9 @@ class DataInterpreter(Role):
|
|||
|
||||
while not success and counter < max_retry:
|
||||
### write code ###
|
||||
code, cause_by = await self._write_code(counter, plan_status, tool_info, experiences = experiences if counter == 0 else "")
|
||||
code, cause_by = await self._write_code(
|
||||
counter, plan_status, tool_info, experiences=experiences if counter == 0 else ""
|
||||
)
|
||||
|
||||
self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by))
|
||||
|
||||
|
|
@ -144,13 +146,7 @@ class DataInterpreter(Role):
|
|||
|
||||
return code, result, success
|
||||
|
||||
async def _write_code(
|
||||
self,
|
||||
counter: int,
|
||||
plan_status: str = "",
|
||||
tool_info: str = "",
|
||||
experiences: str = ""
|
||||
):
|
||||
async def _write_code(self, counter: int, plan_status: str = "", tool_info: str = "", experiences: str = ""):
|
||||
todo = self.rc.todo # todo is WriteAnalysisCode
|
||||
logger.info(f"ready to {todo.name}")
|
||||
use_reflection = counter > 0 and self.use_reflection # only use reflection after the first trial
|
||||
|
|
@ -163,7 +159,7 @@ class DataInterpreter(Role):
|
|||
tool_info=tool_info,
|
||||
working_memory=self.working_memory.get(),
|
||||
use_reflection=use_reflection,
|
||||
experiences = experiences
|
||||
experiences=experiences,
|
||||
)
|
||||
|
||||
return code, todo
|
||||
|
|
|
|||
|
|
@ -30,12 +30,12 @@ from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validat
|
|||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.actions.di.use_experience import RetrieveExperiences, AddNewTrajectories
|
||||
from metagpt.actions.di.use_experience import AddNewTrajectories, RetrieveExperiences
|
||||
from metagpt.context_mixin import ContextMixin
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.provider import HumanProvider
|
||||
from metagpt.schema import Message, MessageQueue, SerializationMixin, TaskResult, Task
|
||||
from metagpt.schema import Message, MessageQueue, SerializationMixin, Task, TaskResult
|
||||
from metagpt.strategy.planner import Planner
|
||||
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
|
@ -493,13 +493,13 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
|
||||
# retrieve past tasks for this task
|
||||
experiences = await RetrieveExperiences().run(query=task.instruction) if self.use_experience else ""
|
||||
|
||||
|
||||
# take on current task
|
||||
task_result = await self._act_on_task(task, experiences)
|
||||
|
||||
# process the result, such as reviewing, confirming, plan updating
|
||||
await self.planner.process_task_result(task_result)
|
||||
|
||||
|
||||
await AddNewTrajectories().run(self.planner)
|
||||
|
||||
rsp = self.planner.get_useful_memories()[0] # return the completed plan as a response
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue