This commit is contained in:
seehi 2024-09-10 16:44:22 +08:00
parent 31dbff5474
commit 5d21d255e4
5 changed files with 95 additions and 10 deletions

View file

@ -57,11 +57,11 @@ class RoleZeroLongTermMemory(BaseModel):
if not query:
return []
nodes: list[NodeWithScore] = self.rag_engine.retrieve(query)
nodes = self.rag_engine.retrieve(query)
items = self._get_items_from_nodes(nodes)
memories = []
for node in nodes:
item: LongTermMemoryItem = node.metadata["obj"]
for item in items:
memories.append(item.user_message)
memories.append(item.ai_message)
@ -78,3 +78,11 @@ class RoleZeroLongTermMemory(BaseModel):
return
self.rag_engine.add_objs([item])
def _get_items_from_nodes(self, nodes: list["NodeWithScore"]) -> list[LongTermMemoryItem]:
    """Collect the memory item attached to each retrieved node, ordered oldest-first.

    Each node is expected to carry its `LongTermMemoryItem` under the
    "obj" key of its metadata; results are sorted by `created_at`.
    """
    retrieved = (node.metadata["obj"] for node in nodes)
    # sorted() is stable, matching the original in-place list.sort() ordering.
    return sorted(retrieved, key=lambda entry: entry.created_at)

View file

@ -285,12 +285,12 @@ class RoleZero(Role):
self._add_memory(AIMessage(content=self.command_rsp))
if not ok:
error_msg = commands
self._add_memory(UserMessage(content=error_msg))
self._add_memory(UserMessage(content=error_msg, cause_by=RunCommand))
return error_msg
logger.info(f"Commands: \n{commands}")
outputs = await self._run_commands(commands)
logger.info(f"Commands outputs: \n{outputs}")
self._add_memory(UserMessage(content=outputs))
self._add_memory(UserMessage(content=outputs, cause_by=RunCommand))
return AIMessage(
content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}",
@ -605,7 +605,8 @@ class RoleZero(Role):
if not self._should_use_longterm_memory(k=k, k_memories=memories):
return memories
related_memories = self.longterm_memory.fetch(memories[-1].content)
query = self._build_longterm_memory_query(memories)
related_memories = self.longterm_memory.fetch(query)
logger.info(f"Fetched {len(related_memories)} long-term memories.")
# Keep user and AI messages paired.
@ -666,3 +667,11 @@ class RoleZero(Role):
def _is_first_message_from_ai(self, memories: list[Message]) -> bool:
    """Return True when `memories` is non-empty and its first entry is an AI message."""
    if not memories:
        return False
    return memories[0].is_ai_message()
def _build_longterm_memory_query(self, memories: list[Message]) -> str:
    """Return the text used to query related long-term memory.

    Scans `memories` from newest to oldest and returns the content of the
    first real user message found, or an empty string when there is none.
    """
    for message in reversed(memories):
        if message.is_real_user_message():
            return message.content
    return ""

View file

@ -18,6 +18,7 @@ from __future__ import annotations
import asyncio
import json
import os.path
import time
import uuid
from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
@ -408,12 +409,15 @@ class Message(BaseModel):
dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()})
return dynamic_class.model_validate(kvs)
def is_user_message(self):
def is_user_message(self) -> bool:
    """True when this message carries the "user" role."""
    return self.role == "user"
def is_ai_message(self):
def is_ai_message(self) -> bool:
    """True when this message carries the "assistant" role."""
    return self.role == "assistant"
def is_real_user_message(self) -> bool:
    """True for a user-role message whose `cause_by` mentions UserRequirement.

    Distinguishes genuine human input from other user-role messages
    (e.g. command outputs) injected into memory.
    """
    if not self.is_user_message():
        return False
    return "UserRequirement" in self.cause_by
class UserMessage(Message):
"""便于支持OpenAI的消息
@ -966,6 +970,7 @@ class BaseEnum(Enum):
class LongTermMemoryItem(BaseModel):
    """One long-term memory entry: a paired user/AI exchange plus its creation time."""

    user_message: Message
    ai_message: Message
    # Stamped at construction; used to order retrieved items chronologically.
    created_at: Optional[float] = Field(default_factory=time.time)

    def rag_key(self) -> str:
        """Return the text indexed for retrieval: the user message's content."""
        source = self.user_message
        return source.content