mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-17 15:35:21 +02:00
implement STRole.observe
This commit is contained in:
parent
f30946b557
commit
42c12ab277
5 changed files with 191 additions and 3 deletions
|
|
@ -10,3 +10,5 @@ class MazeEnvironment(Environment):
|
|||
def __init__(self, name: str, maze: Maze) -> None:
|
||||
self.name = name
|
||||
self.maze = maze
|
||||
|
||||
|
||||
|
|
@ -46,6 +46,9 @@ class BasicMemory(Message):
|
|||
self.keywords: list = keywords # keywords
|
||||
self.filling: list = filling # 装的与之相关联的memory_id的列表
|
||||
|
||||
def summary(self):
|
||||
return (self.subject, self.predicate, self.object)
|
||||
|
||||
def save_to_dict(self) -> dict:
|
||||
"""
|
||||
将MemoryBasic类转化为字典,用于存储json文件
|
||||
|
|
@ -316,3 +319,10 @@ class AgentMemory(Memory):
|
|||
|
||||
self.embeddings[embedding_pair[0]] = embedding_pair[1]
|
||||
return memory_node
|
||||
|
||||
|
||||
def get_summarized_latest_events(self, retention):
|
||||
ret_set = set()
|
||||
for e_node in self.event_list[:retention]:
|
||||
ret_set.add(e_node.summary())
|
||||
return ret_set
|
||||
|
|
@ -104,6 +104,24 @@ class MemoryTree:
|
|||
return x
|
||||
|
||||
|
||||
def add_tile_info(self, tile_info: dict):
|
||||
if tile_info["world"]:
|
||||
if (tile_info["world"] not in self.tree):
|
||||
self.tree[tile_info["world"]] = {}
|
||||
if tile_info["sector"]:
|
||||
if (tile_info["sector"] not in self.tree[tile_info["world"]]):
|
||||
self.tree[tile_info["world"]][tile_info["sector"]] = {}
|
||||
if tile_info["arena"]:
|
||||
if (tile_info["arena"] not in self.tree[tile_info["world"]]
|
||||
[tile_info["sector"]]):
|
||||
self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] = []
|
||||
if tile_info["game_object"]:
|
||||
if (tile_info["game_object"] not in self.tree[tile_info["world"]]
|
||||
[tile_info["sector"]]
|
||||
[tile_info["arena"]]):
|
||||
self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] += [
|
||||
tile_info["game_object"]]
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,9 +10,10 @@ Do the steps following:
|
|||
- reflect, do the High-level thinking based on memories and re-add into the memory
|
||||
- execute, move or else in the Maze
|
||||
"""
|
||||
|
||||
import math
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
from operator import itemgetter
|
||||
|
||||
from metagpt.roles.role import Role, RoleContext
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -24,6 +25,7 @@ from ..actions.user_requirement import UserRequirement
|
|||
from ..maze_environment import MazeEnvironment
|
||||
from ..memory.retrieve import agent_retrieve
|
||||
from ..memory.scratch import Scratch
|
||||
from ..utils.utils import get_embedding, generate_poig_score
|
||||
|
||||
|
||||
class STRoleContext(RoleContext):
|
||||
|
|
@ -68,14 +70,150 @@ class STRole(Role):
|
|||
|
||||
async def observe(self):
|
||||
# TODO observe info from maze_env
|
||||
pass
|
||||
"""
|
||||
Perceive events around the role and saves it to the memory, both events
|
||||
and spaces.
|
||||
|
||||
We first perceive the events nearby the role, as determined by its
|
||||
<vision_r>. If there are a lot of events happening within that radius, we
|
||||
take the <att_bandwidth> of the closest events. Finally, we check whether
|
||||
any of them are new, as determined by <retention>. If they are new, then we
|
||||
save those and return the <BasicMemory> instances for those events.
|
||||
|
||||
OUTPUT:
|
||||
ret_events: a list of <BasicMemory> that are perceived and new.
|
||||
"""
|
||||
maze = self._rc.env.maze
|
||||
# PERCEIVE SPACE
|
||||
# We get the nearby tiles given our current tile and the persona's vision
|
||||
# radius.
|
||||
nearby_tiles = maze.get_nearby_tiles(self._rc.scratch.curr_tile,
|
||||
self._rc.scratch.vision_r)
|
||||
|
||||
# We then store the perceived space. Note that the s_mem of the persona is
|
||||
# in the form of a tree constructed using dictionaries.
|
||||
for tile in nearby_tiles:
|
||||
tile_info = maze.access_tile(tile)
|
||||
self._rc.spatial_memory.add_tile_info(tile_info)
|
||||
|
||||
# PERCEIVE EVENTS.
|
||||
# We will perceive events that take place in the same arena as the
|
||||
# persona's current arena.
|
||||
curr_arena_path = maze.get_tile_path(self._rc.scratch.curr_tile, "arena")
|
||||
# We do not perceive the same event twice (this can happen if an object is
|
||||
# extended across multiple tiles).
|
||||
percept_events_set = set()
|
||||
# We will order our percept based on the distance, with the closest ones
|
||||
# getting priorities.
|
||||
percept_events_list = []
|
||||
# First, we put all events that are occuring in the nearby tiles into the
|
||||
# percept_events_list
|
||||
for tile in nearby_tiles:
|
||||
tile_details = maze.access_tile(tile)
|
||||
if tile_details["events"]:
|
||||
if maze.get_tile_path(tile, "arena") == curr_arena_path:
|
||||
# This calculates the distance between the persona's current tile,
|
||||
# and the target tile.
|
||||
dist = math.dist([tile[0], tile[1]],
|
||||
[self._rc.scratch.curr_tile[0],
|
||||
self._rc.scratch.curr_tile[1]])
|
||||
# Add any relevant events to our temp set/list with the distant info.
|
||||
for event in tile_details["events"]:
|
||||
if event not in percept_events_set:
|
||||
percept_events_list += [[dist, event]]
|
||||
percept_events_set.add(event)
|
||||
|
||||
# We sort, and perceive only self._rc.scratch.att_bandwidth of the closest
|
||||
# events. If the bandwidth is larger, then it means the persona can perceive
|
||||
# more elements within a small area.
|
||||
percept_events_list = sorted(percept_events_list, key=itemgetter(0))
|
||||
perceived_events = []
|
||||
for dist, event in percept_events_list[:self._rc.scratch.att_bandwidth]:
|
||||
perceived_events += [event]
|
||||
|
||||
# Storing events.
|
||||
# <ret_events> is a list of <BasicMemory> instances from the persona's
|
||||
# associative memory.
|
||||
ret_events = []
|
||||
for p_event in perceived_events:
|
||||
s, p, o, desc = p_event
|
||||
if not p:
|
||||
# If the object is not present, then we default the event to "idle".
|
||||
p = "is"
|
||||
o = "idle"
|
||||
desc = "idle"
|
||||
desc = f"{s.split(':')[-1]} is {desc}"
|
||||
p_event = (s, p, o)
|
||||
|
||||
# We retrieve the latest self._rc.scratch.retention events. If there is
|
||||
# something new that is happening (that is, p_event not in latest_events),
|
||||
# then we add that event to the a_mem and return it.
|
||||
latest_events = self._rc.memory.get_summarized_latest_events(
|
||||
self._rc.scratch.retention)
|
||||
if p_event not in latest_events:
|
||||
# We start by managing keywords.
|
||||
keywords = set()
|
||||
sub = p_event[0]
|
||||
obj = p_event[2]
|
||||
if ":" in p_event[0]:
|
||||
sub = p_event[0].split(":")[-1]
|
||||
if ":" in p_event[2]:
|
||||
obj = p_event[2].split(":")[-1]
|
||||
keywords.update([sub, obj])
|
||||
|
||||
# Get event embedding
|
||||
desc_embedding_in = desc
|
||||
if "(" in desc:
|
||||
desc_embedding_in = (desc_embedding_in.split("(")[1]
|
||||
.split(")")[0]
|
||||
.strip())
|
||||
if desc_embedding_in in self._rc.memory.embeddings:
|
||||
event_embedding = self._rc.memory.embeddings[desc_embedding_in]
|
||||
else:
|
||||
event_embedding = get_embedding(desc_embedding_in)
|
||||
event_embedding_pair = (desc_embedding_in, event_embedding)
|
||||
|
||||
# Get event poignancy.
|
||||
event_poignancy = generate_poig_score(self,
|
||||
"action",
|
||||
desc_embedding_in)
|
||||
|
||||
# If we observe the persona's self chat, we include that in the memory
|
||||
# of the persona here.
|
||||
chat_node_ids = []
|
||||
if p_event[0] == f"{self.name}" and p_event[1] == "chat with":
|
||||
curr_event = self._rc.scratch.act_event
|
||||
if self._rc.scratch.act_description in self._rc.memory.embeddings:
|
||||
chat_embedding = self._rc.memory.embeddings[
|
||||
self._rc.scratch.act_description]
|
||||
else:
|
||||
chat_embedding = get_embedding(self._rc.scratch
|
||||
.act_description)
|
||||
chat_embedding_pair = (self._rc.scratch.act_description,
|
||||
chat_embedding)
|
||||
chat_poignancy = generate_poig_score(self._rc.scratch, "chat",
|
||||
self._rc.scratch.act_description)
|
||||
chat_node = self._rc.memory.add_chat(self._rc.scratch.curr_time, None,
|
||||
curr_event[0], curr_event[1], curr_event[2],
|
||||
self._rc.scratch.act_description, keywords,
|
||||
chat_poignancy, chat_embedding_pair,
|
||||
self._rc.scratch.chat)
|
||||
chat_node_ids = [chat_node.node_id]
|
||||
|
||||
# Finally, we add the current event to the agent's memory.
|
||||
ret_events += [self._rc.memory.add_event(self._rc.scratch.curr_time, None,
|
||||
s, p, o, desc, keywords, event_poignancy,
|
||||
event_embedding_pair, chat_node_ids)]
|
||||
self._rc.scratch.importance_trigger_curr -= event_poignancy
|
||||
self._rc.scratch.importance_ele_n += 1
|
||||
|
||||
return ret_events
|
||||
|
||||
async def retrieve(self, query, n = 30 ,topk = 4):
|
||||
# TODO retrieve memories from agent_memory
|
||||
retrieve_memories = agent_retrieve(self._rc.memory, self._rc.scratch.curr_time, self._rc.scratch.recency_decay, query, n, topk)
|
||||
return retrieve_memories
|
||||
|
||||
|
||||
async def plan(self):
|
||||
# TODO make a plan
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import json
|
|||
import openai
|
||||
from pathlib import Path
|
||||
import csv
|
||||
from ..prompts.run_gpt_prompts import get_poignancy_action, get_poignancy_chat
|
||||
|
||||
|
||||
def read_json_file(json_file: str, encoding=None) -> list[Any]:
|
||||
|
|
@ -25,6 +26,7 @@ def write_json_file(json_file: str, data: list, encoding=None):
|
|||
with open(json_file, "w", encoding=encoding) as fout:
|
||||
json.dump(data, fout, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
def embedding_tools(query):
|
||||
embedding_result = openai.Embedding.create(
|
||||
model="text-embedding-ada-002",
|
||||
|
|
@ -33,6 +35,7 @@ def embedding_tools(query):
|
|||
embedding_key = embedding_result['data'][0]["embedding"]
|
||||
return embedding_key
|
||||
|
||||
|
||||
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
|
||||
"""
|
||||
Reads in a csv file to a list of list. If header is True, it returns a
|
||||
|
|
@ -61,3 +64,20 @@ def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
|
|||
analysis_list += [row]
|
||||
return analysis_list[0], analysis_list[1:]
|
||||
|
||||
|
||||
def get_embedding(text, model: str="text-embedding-ada-002"):
|
||||
text = text.replace("\n", " ")
|
||||
if not text:
|
||||
text = "this is blank"
|
||||
return openai.Embedding.create(
|
||||
input=[text], model=model)['data'][0]['embedding']
|
||||
|
||||
|
||||
def generate_poig_score(scratch, event_type, description):
|
||||
if "is idle" in description:
|
||||
return 1
|
||||
if event_type == "action":
|
||||
return get_poignancy_action(scratch, description)[0]
|
||||
elif event_type == "chat":
|
||||
return get_poignancy_chat(scratch, description)[0]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue