mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-23 15:48:11 +02:00
update provider unittests to update coverage rate
This commit is contained in:
parent
5fc8207950
commit
0f047e5693
26 changed files with 509 additions and 76 deletions
|
|
@ -17,7 +17,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
|
||||
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
|
||||
from metagpt.utils.common import OutputParser, general_after_log
|
||||
|
||||
TAG = "CONTENT"
|
||||
|
|
@ -275,7 +275,7 @@ class ActionNode:
|
|||
output_class = self.create_model_class(output_class_name, output_data_mapping)
|
||||
|
||||
if schema == "json":
|
||||
parsed_data = llm_output_postprecess(
|
||||
parsed_data = llm_output_postprocess(
|
||||
output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
|
||||
)
|
||||
else: # using markdown parser
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ def log_info(message, **params):
|
|||
def log_warn(message, **params):
|
||||
msg = logfmt(dict(message=message, **params))
|
||||
print(msg, file=sys.stderr)
|
||||
logger.warn(msg)
|
||||
logger.warning(msg)
|
||||
|
||||
|
||||
def logfmt(props):
|
||||
|
|
|
|||
|
|
@ -79,9 +79,6 @@ class GeminiLLM(BaseLLM):
|
|||
except Exception as e:
|
||||
logger.error(f"google gemini updats costs failed! exp: {e}")
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def get_choice_text(self, resp: GenerateContentResponse) -> str:
|
||||
return resp.text
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ class OpenLLMCostManager(CostManager):
|
|||
f"Max budget: ${CONFIG.max_budget:.3f} | reference "
|
||||
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
CONFIG.total_cost = self.total_cost
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.OPEN_LLM)
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ from metagpt.utils.repair_llm_raw_output import (
|
|||
)
|
||||
|
||||
|
||||
class BasePostPrecessPlugin(object):
|
||||
model = None # the plugin of the `model`, use to judge in `llm_postprecess`
|
||||
class BasePostProcessPlugin(object):
|
||||
model = None # the plugin of the `model`, use to judge in `llm_postprocess`
|
||||
|
||||
def run_repair_llm_output(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]:
|
||||
"""
|
||||
|
|
@ -4,17 +4,17 @@
|
|||
|
||||
from typing import Union
|
||||
|
||||
from metagpt.provider.postprecess.base_postprecess_plugin import BasePostPrecessPlugin
|
||||
from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin
|
||||
|
||||
|
||||
def llm_output_postprecess(
|
||||
def llm_output_postprocess(
|
||||
output: str, schema: dict, req_key: str = "[/CONTENT]", model_name: str = None
|
||||
) -> Union[dict, str]:
|
||||
"""
|
||||
default use BasePostPrecessPlugin if there is not matched plugin.
|
||||
default use BasePostProcessPlugin if there is not matched plugin.
|
||||
"""
|
||||
# TODO choose different model's plugin according to the model_name
|
||||
postprecess_plugin = BasePostPrecessPlugin()
|
||||
postprocess_plugin = BasePostProcessPlugin()
|
||||
|
||||
result = postprecess_plugin.run(output=output, schema=schema, req_key=req_key)
|
||||
result = postprocess_plugin.run(output=output, schema=schema, req_key=req_key)
|
||||
return result
|
||||
|
|
@ -33,7 +33,7 @@ class ZhiPuModelAPI(ModelAPI):
|
|||
zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
|
||||
"""
|
||||
arr = zhipu_api_url.split("/api/")
|
||||
# ("https://open.bigmodel.cn/api/" , "/paas/v3/model-api/chatglm_turbo/invoke")
|
||||
# ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke")
|
||||
return f"{arr[0]}/api", f"/{arr[1]}"
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -68,9 +68,6 @@ class ZhiPuAILLM(BaseLLM):
|
|||
except Exception as e:
|
||||
logger.error(f"zhipuai updats costs failed! exp: {e}")
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
"""get the first text of choice from llm response"""
|
||||
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue