implement STRole.observe

This commit is contained in:
SereneWalden 2023-09-30 15:45:11 +08:00
parent f30946b557
commit 42c12ab277
5 changed files with 191 additions and 3 deletions

View file

@ -10,3 +10,5 @@ class MazeEnvironment(Environment):
def __init__(self, name: str, maze: Maze) -> None:
self.name = name
self.maze = maze

View file

@ -46,6 +46,9 @@ class BasicMemory(Message):
self.keywords: list = keywords # keywords
self.filling: list = filling # 装的与之相关联的memory_id的列表
def summary(self):
return (self.subject, self.predicate, self.object)
def save_to_dict(self) -> dict:
"""
将MemoryBasic类转化为字典用于存储json文件
@ -316,3 +319,10 @@ class AgentMemory(Memory):
self.embeddings[embedding_pair[0]] = embedding_pair[1]
return memory_node
def get_summarized_latest_events(self, retention):
ret_set = set()
for e_node in self.event_list[:retention]:
ret_set.add(e_node.summary())
return ret_set

View file

@ -104,6 +104,24 @@ class MemoryTree:
return x
def add_tile_info(self, tile_info: dict):
if tile_info["world"]:
if (tile_info["world"] not in self.tree):
self.tree[tile_info["world"]] = {}
if tile_info["sector"]:
if (tile_info["sector"] not in self.tree[tile_info["world"]]):
self.tree[tile_info["world"]][tile_info["sector"]] = {}
if tile_info["arena"]:
if (tile_info["arena"] not in self.tree[tile_info["world"]]
[tile_info["sector"]]):
self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] = []
if tile_info["game_object"]:
if (tile_info["game_object"] not in self.tree[tile_info["world"]]
[tile_info["sector"]]
[tile_info["arena"]]):
self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] += [
tile_info["game_object"]]

View file

@ -10,9 +10,10 @@ Do the steps following:
- reflect, do the High-level thinking based on memories and re-add into the memory
- execute, move or else in the Maze
"""
import math
from pydantic import Field
from pathlib import Path
from operator import itemgetter
from metagpt.roles.role import Role, RoleContext
from metagpt.schema import Message
@ -24,6 +25,7 @@ from ..actions.user_requirement import UserRequirement
from ..maze_environment import MazeEnvironment
from ..memory.retrieve import agent_retrieve
from ..memory.scratch import Scratch
from ..utils.utils import get_embedding, generate_poig_score
class STRoleContext(RoleContext):
@ -68,14 +70,150 @@ class STRole(Role):
async def observe(self):
# TODO observe info from maze_env
pass
"""
Perceive events around the role and saves it to the memory, both events
and spaces.
We first perceive the events nearby the role, as determined by its
<vision_r>. If there are a lot of events happening within that radius, we
take the <att_bandwidth> of the closest events. Finally, we check whether
any of them are new, as determined by <retention>. If they are new, then we
save those and return the <BasicMemory> instances for those events.
OUTPUT:
ret_events: a list of <BasicMemory> that are perceived and new.
"""
maze = self._rc.env.maze
# PERCEIVE SPACE
# We get the nearby tiles given our current tile and the persona's vision
# radius.
nearby_tiles = maze.get_nearby_tiles(self._rc.scratch.curr_tile,
self._rc.scratch.vision_r)
# We then store the perceived space. Note that the s_mem of the persona is
# in the form of a tree constructed using dictionaries.
for tile in nearby_tiles:
tile_info = maze.access_tile(tile)
self._rc.spatial_memory.add_tile_info(tile_info)
# PERCEIVE EVENTS.
# We will perceive events that take place in the same arena as the
# persona's current arena.
curr_arena_path = maze.get_tile_path(self._rc.scratch.curr_tile, "arena")
# We do not perceive the same event twice (this can happen if an object is
# extended across multiple tiles).
percept_events_set = set()
# We will order our percept based on the distance, with the closest ones
# getting priorities.
percept_events_list = []
# First, we put all events that are occuring in the nearby tiles into the
# percept_events_list
for tile in nearby_tiles:
tile_details = maze.access_tile(tile)
if tile_details["events"]:
if maze.get_tile_path(tile, "arena") == curr_arena_path:
# This calculates the distance between the persona's current tile,
# and the target tile.
dist = math.dist([tile[0], tile[1]],
[self._rc.scratch.curr_tile[0],
self._rc.scratch.curr_tile[1]])
# Add any relevant events to our temp set/list with the distant info.
for event in tile_details["events"]:
if event not in percept_events_set:
percept_events_list += [[dist, event]]
percept_events_set.add(event)
# We sort, and perceive only self._rc.scratch.att_bandwidth of the closest
# events. If the bandwidth is larger, then it means the persona can perceive
# more elements within a small area.
percept_events_list = sorted(percept_events_list, key=itemgetter(0))
perceived_events = []
for dist, event in percept_events_list[:self._rc.scratch.att_bandwidth]:
perceived_events += [event]
# Storing events.
# <ret_events> is a list of <BasicMemory> instances from the persona's
# associative memory.
ret_events = []
for p_event in perceived_events:
s, p, o, desc = p_event
if not p:
# If the object is not present, then we default the event to "idle".
p = "is"
o = "idle"
desc = "idle"
desc = f"{s.split(':')[-1]} is {desc}"
p_event = (s, p, o)
# We retrieve the latest self._rc.scratch.retention events. If there is
# something new that is happening (that is, p_event not in latest_events),
# then we add that event to the a_mem and return it.
latest_events = self._rc.memory.get_summarized_latest_events(
self._rc.scratch.retention)
if p_event not in latest_events:
# We start by managing keywords.
keywords = set()
sub = p_event[0]
obj = p_event[2]
if ":" in p_event[0]:
sub = p_event[0].split(":")[-1]
if ":" in p_event[2]:
obj = p_event[2].split(":")[-1]
keywords.update([sub, obj])
# Get event embedding
desc_embedding_in = desc
if "(" in desc:
desc_embedding_in = (desc_embedding_in.split("(")[1]
.split(")")[0]
.strip())
if desc_embedding_in in self._rc.memory.embeddings:
event_embedding = self._rc.memory.embeddings[desc_embedding_in]
else:
event_embedding = get_embedding(desc_embedding_in)
event_embedding_pair = (desc_embedding_in, event_embedding)
# Get event poignancy.
event_poignancy = generate_poig_score(self,
"action",
desc_embedding_in)
# If we observe the persona's self chat, we include that in the memory
# of the persona here.
chat_node_ids = []
if p_event[0] == f"{self.name}" and p_event[1] == "chat with":
curr_event = self._rc.scratch.act_event
if self._rc.scratch.act_description in self._rc.memory.embeddings:
chat_embedding = self._rc.memory.embeddings[
self._rc.scratch.act_description]
else:
chat_embedding = get_embedding(self._rc.scratch
.act_description)
chat_embedding_pair = (self._rc.scratch.act_description,
chat_embedding)
chat_poignancy = generate_poig_score(self._rc.scratch, "chat",
self._rc.scratch.act_description)
chat_node = self._rc.memory.add_chat(self._rc.scratch.curr_time, None,
curr_event[0], curr_event[1], curr_event[2],
self._rc.scratch.act_description, keywords,
chat_poignancy, chat_embedding_pair,
self._rc.scratch.chat)
chat_node_ids = [chat_node.node_id]
# Finally, we add the current event to the agent's memory.
ret_events += [self._rc.memory.add_event(self._rc.scratch.curr_time, None,
s, p, o, desc, keywords, event_poignancy,
event_embedding_pair, chat_node_ids)]
self._rc.scratch.importance_trigger_curr -= event_poignancy
self._rc.scratch.importance_ele_n += 1
return ret_events
async def retrieve(self, query, n = 30 ,topk = 4):
# TODO retrieve memories from agent_memory
retrieve_memories = agent_retrieve(self._rc.memory, self._rc.scratch.curr_time, self._rc.scratch.recency_decay, query, n, topk)
return retrieve_memories
async def plan(self):
# TODO make a plan

View file

@ -7,6 +7,7 @@ import json
import openai
from pathlib import Path
import csv
from ..prompts.run_gpt_prompts import get_poignancy_action, get_poignancy_chat
def read_json_file(json_file: str, encoding=None) -> list[Any]:
@ -25,6 +26,7 @@ def write_json_file(json_file: str, data: list, encoding=None):
with open(json_file, "w", encoding=encoding) as fout:
json.dump(data, fout, ensure_ascii=False, indent=4)
def embedding_tools(query):
embedding_result = openai.Embedding.create(
model="text-embedding-ada-002",
@ -33,6 +35,7 @@ def embedding_tools(query):
embedding_key = embedding_result['data'][0]["embedding"]
return embedding_key
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
"""
Reads in a csv file to a list of list. If header is True, it returns a
@ -61,3 +64,20 @@ def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
analysis_list += [row]
return analysis_list[0], analysis_list[1:]
def get_embedding(text, model: str="text-embedding-ada-002"):
text = text.replace("\n", " ")
if not text:
text = "this is blank"
return openai.Embedding.create(
input=[text], model=model)['data'][0]['embedding']
def generate_poig_score(scratch, event_type, description):
if "is idle" in description:
return 1
if event_type == "action":
return get_poignancy_action(scratch, description)[0]
elif event_type == "chat":
return get_poignancy_chat(scratch, description)[0]