mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-14 15:25:17 +02:00
Merge pull request #33 from iorisa/feature/talk_prompt
This commit is contained in:
commit
2781bd7019
5 changed files with 29 additions and 13 deletions
|
|
@ -26,7 +26,7 @@ class TalkAction(Action):
|
|||
self._rsp = None
|
||||
|
||||
@property
|
||||
def prompt(self):
|
||||
def prompt_old(self):
|
||||
prompt = ""
|
||||
if CONFIG.agent_description:
|
||||
prompt = (
|
||||
|
|
@ -46,7 +46,7 @@ class TalkAction(Action):
|
|||
return prompt
|
||||
|
||||
@property
|
||||
def formation_prompt(self):
|
||||
def prompt(self):
|
||||
kvs = {
|
||||
"{role}": CONFIG.agent_description or "",
|
||||
"{history}": self._history_summary or "",
|
||||
|
|
@ -57,6 +57,7 @@ class TalkAction(Action):
|
|||
prompt = TalkAction.__FORMATION_LOOSE__
|
||||
for k, v in kvs.items():
|
||||
prompt = prompt.replace(k, v)
|
||||
logger.info(f"PROMPT: {prompt}")
|
||||
return prompt
|
||||
|
||||
async def run(self, *args, **kwargs) -> ActionOutput:
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class BrainMemory(pydantic.BaseModel):
|
|||
historical_summary: str = ""
|
||||
last_history_id: str = ""
|
||||
is_dirty: bool = False
|
||||
last_talk: str = ""
|
||||
last_talk: str = None
|
||||
|
||||
def add_talk(self, msg: Message):
|
||||
msg.add_tag(MessageType.Talk.value)
|
||||
|
|
@ -109,7 +109,6 @@ class BrainMemory(pydantic.BaseModel):
|
|||
if msg.id:
|
||||
if self.to_int(msg.id, 0) < self.to_int(self.last_history_id, -1):
|
||||
return
|
||||
self.last_history_id = str(self.to_int(msg.id, 0))
|
||||
self.history.append(msg.dict())
|
||||
self.is_dirty = True
|
||||
|
||||
|
|
@ -125,3 +124,8 @@ class BrainMemory(pydantic.BaseModel):
|
|||
return int(v)
|
||||
except:
|
||||
return default_value
|
||||
|
||||
def pop_last_talk(self):
|
||||
v = self.last_talk
|
||||
self.last_talk = None
|
||||
return v
|
||||
|
|
|
|||
|
|
@ -226,21 +226,24 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
async def get_summary(self, text: str, max_words=200, keep_language: bool = False):
|
||||
max_token_count = DEFAULT_MAX_TOKENS
|
||||
max_count = 100
|
||||
text_length = len(text)
|
||||
while max_count > 0:
|
||||
if len(text) < max_token_count:
|
||||
if text_length < max_token_count:
|
||||
return await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
|
||||
|
||||
padding_size = 20 if max_token_count > 20 else 0
|
||||
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
|
||||
part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
|
||||
summaries = []
|
||||
for ws in text_windows:
|
||||
response = await self._get_summary(text=ws, max_words=max_words, keep_language=keep_language)
|
||||
response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language)
|
||||
summaries.append(response)
|
||||
if len(summaries) == 1:
|
||||
return summaries[0]
|
||||
|
||||
# Merged and retry
|
||||
text = "\n".join(summaries)
|
||||
text_length = len(text)
|
||||
|
||||
max_count -= 1 # safeguard
|
||||
raise openai.error.InvalidRequestError("text too long")
|
||||
|
|
|
|||
|
|
@ -120,12 +120,12 @@ class Assistant(Role):
|
|||
|
||||
async def refine_memory(self) -> str:
|
||||
history_text = self.memory.history_text
|
||||
last_talk = self.memory.last_talk
|
||||
last_talk = self.memory.pop_last_talk()
|
||||
if last_talk is None: # No user feedback, unsure if past conversation is finished.
|
||||
return None
|
||||
if history_text == "":
|
||||
return last_talk
|
||||
history_summary = await self._llm.get_summary(history_text, max_words=500)
|
||||
history_summary = await self._llm.get_summary(history_text, max_words=800, keep_language=True)
|
||||
await self.memory.set_history_summary(
|
||||
history_summary=history_summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# @Desc: { redis client }
|
||||
# @Date: 2022/11/28 10:12
|
||||
import json
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, Dict, Optional, Union
|
||||
|
|
@ -203,12 +204,19 @@ class Redis:
|
|||
async def get(self, key: str) -> str:
|
||||
if not self.is_valid() or not key:
|
||||
return None
|
||||
v = await RedisManager.get_with_cache_info(redis_cache_info=RedisCacheInfo(key=key))
|
||||
return v
|
||||
try:
|
||||
v = await RedisManager.get_with_cache_info(redis_cache_info=RedisCacheInfo(key=key))
|
||||
return v
|
||||
except Exception as e:
|
||||
logger.exception(f"{e}, stack:{traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
async def set(self, key: str, data: str, timeout_sec: int):
|
||||
if not self.is_valid() or not key:
|
||||
return
|
||||
await RedisManager.set_with_cache_info(
|
||||
redis_cache_info=RedisCacheInfo(key=key, timeout=timeout_sec), value=data
|
||||
)
|
||||
try:
|
||||
await RedisManager.set_with_cache_info(
|
||||
redis_cache_info=RedisCacheInfo(key=key, timeout=timeout_sec), value=data
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"{e}, stack:{traceback.format_exc()}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue