Merge pull request #33 from iorisa/feature/talk_prompt

This commit is contained in:
seeker-jie 2023-09-05 01:02:24 +08:00 committed by GitHub
commit 2781bd7019
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 29 additions and 13 deletions

View file

@ -26,7 +26,7 @@ class TalkAction(Action):
self._rsp = None
@property
def prompt(self):
def prompt_old(self):
prompt = ""
if CONFIG.agent_description:
prompt = (
@ -46,7 +46,7 @@ class TalkAction(Action):
return prompt
@property
def formation_prompt(self):
def prompt(self):
kvs = {
"{role}": CONFIG.agent_description or "",
"{history}": self._history_summary or "",
@ -57,6 +57,7 @@ class TalkAction(Action):
prompt = TalkAction.__FORMATION_LOOSE__
for k, v in kvs.items():
prompt = prompt.replace(k, v)
logger.info(f"PROMPT: {prompt}")
return prompt
async def run(self, *args, **kwargs) -> ActionOutput:

View file

@ -34,7 +34,7 @@ class BrainMemory(pydantic.BaseModel):
historical_summary: str = ""
last_history_id: str = ""
is_dirty: bool = False
last_talk: str = ""
last_talk: str = None
def add_talk(self, msg: Message):
msg.add_tag(MessageType.Talk.value)
@ -109,7 +109,6 @@ class BrainMemory(pydantic.BaseModel):
if msg.id:
if self.to_int(msg.id, 0) < self.to_int(self.last_history_id, -1):
return
self.last_history_id = str(self.to_int(msg.id, 0))
self.history.append(msg.dict())
self.is_dirty = True
@ -125,3 +124,8 @@ class BrainMemory(pydantic.BaseModel):
return int(v)
except:
return default_value
def pop_last_talk(self):
v = self.last_talk
self.last_talk = None
return v

View file

@ -226,21 +226,24 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
async def get_summary(self, text: str, max_words=200, keep_language: bool = False):
max_token_count = DEFAULT_MAX_TOKENS
max_count = 100
text_length = len(text)
while max_count > 0:
if len(text) < max_token_count:
if text_length < max_token_count:
return await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
padding_size = 20 if max_token_count > 20 else 0
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
summaries = []
for ws in text_windows:
response = await self._get_summary(text=ws, max_words=max_words, keep_language=keep_language)
response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language)
summaries.append(response)
if len(summaries) == 1:
return summaries[0]
# Merged and retry
text = "\n".join(summaries)
text_length = len(text)
max_count -= 1 # safeguard
raise openai.error.InvalidRequestError("text too long")

View file

@ -120,12 +120,12 @@ class Assistant(Role):
async def refine_memory(self) -> str:
history_text = self.memory.history_text
last_talk = self.memory.last_talk
last_talk = self.memory.pop_last_talk()
if last_talk is None: # No user feedback, unsure if past conversation is finished.
return None
if history_text == "":
return last_talk
history_summary = await self._llm.get_summary(history_text, max_words=500)
history_summary = await self._llm.get_summary(history_text, max_words=800, keep_language=True)
await self.memory.set_history_summary(
history_summary=history_summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS
)

View file

@ -4,6 +4,7 @@
# @Desc: { redis client }
# @Date: 2022/11/28 10:12
import json
import traceback
from datetime import timedelta
from enum import Enum
from typing import Awaitable, Callable, Dict, Optional, Union
@ -203,12 +204,19 @@ class Redis:
async def get(self, key: str) -> str:
if not self.is_valid() or not key:
return None
v = await RedisManager.get_with_cache_info(redis_cache_info=RedisCacheInfo(key=key))
return v
try:
v = await RedisManager.get_with_cache_info(redis_cache_info=RedisCacheInfo(key=key))
return v
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")
return None
async def set(self, key: str, data: str, timeout_sec: int):
if not self.is_valid() or not key:
return
await RedisManager.set_with_cache_info(
redis_cache_info=RedisCacheInfo(key=key, timeout=timeout_sec), value=data
)
try:
await RedisManager.set_with_cache_info(
redis_cache_info=RedisCacheInfo(key=key, timeout=timeout_sec), value=data
)
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")