Merge pull request #642 from iorisa/feature/unittest

feat: +unit test
This commit is contained in:
geekan 2023-12-28 18:15:15 +08:00 committed by GitHub
commit ec67964925
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 263 additions and 111 deletions

View file

@ -130,7 +130,8 @@ class CollectLinks(Action):
if len(remove) == 0:
break
prompt = reduce_message_length(gen_msg(), self.llm.model, system_text, CONFIG.max_tokens_rsp)
model_name = CONFIG.get_model_name(CONFIG.get_default_llm_provider_enum())
prompt = reduce_message_length(gen_msg(), model_name, system_text, CONFIG.max_tokens_rsp)
logger.debug(prompt)
queries = await self._aask(prompt, [system_text])
try:

View file

@ -111,11 +111,7 @@ class Config(metaclass=Singleton):
if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
warnings.warn("Use Gemini requires Python >= 3.10")
model_mappings = {
LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
}
model_name = model_mappings.get(provider)
model_name = self.get_model_name(provider=provider)
if model_name:
logger.info(f"{provider} Model: {model_name}")
if provider:
@ -123,6 +119,14 @@ class Config(metaclass=Singleton):
return provider
raise NotConfiguredException("You should config a LLM configuration first")
def get_model_name(self, provider=None) -> str:
provider = provider or self.get_default_llm_provider_enum()
model_mappings = {
LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
}
return model_mappings.get(provider, "")
@staticmethod
def _is_valid_llm_key(k: str) -> bool:
return bool(k and k != "YOUR_API_KEY")

View file

@ -55,9 +55,9 @@ class BrainMemory(BaseModel):
return "\n".join(texts)
@staticmethod
async def loads(redis_key: str, redis_conf: Dict = None) -> "BrainMemory":
redis = Redis(conf=redis_conf)
if not redis.is_valid() or not redis_key:
async def loads(redis_key: str) -> "BrainMemory":
redis = Redis()
if not redis.is_valid or not redis_key:
return BrainMemory()
v = await redis.get(key=redis_key)
logger.debug(f"REDIS GET {redis_key} {v}")
@ -67,11 +67,11 @@ class BrainMemory(BaseModel):
return bm
return BrainMemory()
async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60, redis_conf: Dict = None):
async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60):
if not self.is_dirty:
return
redis = Redis(conf=redis_conf)
if not redis.is_valid() or not redis_key:
redis = Redis()
if not redis.is_valid or not redis_key:
return False
v = self.model_dump_json()
if self.cacheable:
@ -86,26 +86,27 @@ class BrainMemory(BaseModel):
async def set_history_summary(self, history_summary, redis_key, redis_conf):
if self.historical_summary == history_summary:
if self.is_dirty:
await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
await self.dumps(redis_key=redis_key)
self.is_dirty = False
return
self.historical_summary = history_summary
self.history = []
await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
await self.dumps(redis_key=redis_key)
self.is_dirty = False
def add_history(self, msg: Message):
if msg.id:
if self.to_int(msg.id, 0) <= self.to_int(self.last_history_id, -1):
return
self.history.append(msg.model_dump())
self.history.append(msg)
self.last_history_id = str(msg.id)
self.is_dirty = True
def exists(self, text) -> bool:
for m in reversed(self.history):
if m.get("content") == text:
if m.content == text:
return True
return False
@ -163,7 +164,7 @@ class BrainMemory(BaseModel):
msgs.reverse()
self.history = msgs
self.is_dirty = True
await self.dumps(redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS_CONF)
await self.dumps(redis_key=CONFIG.REDIS_KEY)
self.is_dirty = False
return BrainMemory.to_metagpt_history_format(self.history)
@ -217,7 +218,7 @@ class BrainMemory(BaseModel):
return await self._openai_rewrite(sentence=sentence, context=context, llm=llm)
@staticmethod
async def _metagpt_rewrite(sentence: str):
async def _metagpt_rewrite(sentence: str, **kwargs):
return sentence
@staticmethod

View file

@ -43,7 +43,8 @@ class SerpAPIWrapper(BaseModel):
async def run(self, query, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through SerpAPI and parse result async."""
return self._process_response(await self.results(query, max_results), as_string=as_string)
result = await self.results(query, max_results)
return self._process_response(result, as_string=as_string)
async def results(self, query: str, max_results: int) -> dict:
"""Use aiohttp to run query through SerpAPI and return the results async."""

View file

@ -63,5 +63,5 @@ class Redis:
self._client = None
@property
def is_valid(self):
return bool(self._client)
def is_valid(self) -> bool:
return self._client is not None