update provider unittests to update coverage rate

This commit is contained in:
better629 2023-12-29 02:39:00 +08:00
parent 5fc8207950
commit 0f047e5693
26 changed files with 509 additions and 76 deletions

View file

@ -17,7 +17,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.config import CONFIG
from metagpt.llm import BaseLLM
from metagpt.logs import logger
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
from metagpt.utils.common import OutputParser, general_after_log
TAG = "CONTENT"
@ -275,7 +275,7 @@ class ActionNode:
output_class = self.create_model_class(output_class_name, output_data_mapping)
if schema == "json":
parsed_data = llm_output_postprecess(
parsed_data = llm_output_postprocess(
output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
)
else: # using markdown parser

View file

@ -100,7 +100,7 @@ def log_info(message, **params):
def log_warn(message, **params):
msg = logfmt(dict(message=message, **params))
print(msg, file=sys.stderr)
logger.warn(msg)
logger.warning(msg)
def logfmt(props):

View file

@ -79,9 +79,6 @@ class GeminiLLM(BaseLLM):
except Exception as e:
logger.error(f"google gemini updats costs failed! exp: {e}")
def close(self):
pass
def get_choice_text(self, resp: GenerateContentResponse) -> str:
return resp.text

View file

@ -31,7 +31,6 @@ class OpenLLMCostManager(CostManager):
f"Max budget: ${CONFIG.max_budget:.3f} | reference "
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
CONFIG.total_cost = self.total_cost
@register_provider(LLMProviderEnum.OPEN_LLM)

View file

@ -12,8 +12,8 @@ from metagpt.utils.repair_llm_raw_output import (
)
class BasePostPrecessPlugin(object):
model = None # the plugin of the `model`, use to judge in `llm_postprecess`
class BasePostProcessPlugin(object):
model = None # the plugin of the `model`, use to judge in `llm_postprocess`
def run_repair_llm_output(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]:
"""

View file

@ -4,17 +4,17 @@
from typing import Union
from metagpt.provider.postprecess.base_postprecess_plugin import BasePostPrecessPlugin
from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin
def llm_output_postprecess(
def llm_output_postprocess(
output: str, schema: dict, req_key: str = "[/CONTENT]", model_name: str = None
) -> Union[dict, str]:
"""
default use BasePostPrecessPlugin if there is not matched plugin.
default use BasePostProcessPlugin if there is not matched plugin.
"""
# TODO choose different model's plugin according to the model_name
postprecess_plugin = BasePostPrecessPlugin()
postprocess_plugin = BasePostProcessPlugin()
result = postprecess_plugin.run(output=output, schema=schema, req_key=req_key)
result = postprocess_plugin.run(output=output, schema=schema, req_key=req_key)
return result

View file

@ -33,7 +33,7 @@ class ZhiPuModelAPI(ModelAPI):
zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
"""
arr = zhipu_api_url.split("/api/")
# ("https://open.bigmodel.cn/api/" , "/paas/v3/model-api/chatglm_turbo/invoke")
# ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke")
return f"{arr[0]}/api", f"/{arr[1]}"
@classmethod

View file

@ -68,9 +68,6 @@ class ZhiPuAILLM(BaseLLM):
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")
def close(self):
pass
def get_choice_text(self, resp: dict) -> str:
"""get the first text of choice from llm response"""
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]