diff --git a/config/config.yaml b/config/config.yaml index 72d2c0b19..080de4000 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -34,6 +34,15 @@ RPM: 10 #### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY" # ZHIPUAI_API_KEY: "YOUR_API_KEY" +#### if use self-host open llm model with openai-compatible interface +#OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1" +#OPEN_LLM_API_MODEL: "llama2-13b" +# +##### if use Fireworks api +#FIREWORKS_API_KEY: "YOUR_API_KEY" +#FIREWORKS_API_BASE: "https://api.fireworks.ai/inference/v1" +#FIREWORKS_API_MODEL: "YOUR_LLM_MODEL" # example, accounts/fireworks/models/llama-v2-13b-chat + #### for Search ## Supported values: serpapi/google/serper/ddg diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 7433c3857..cb5bd9ce1 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -6,17 +6,29 @@ @File : action.py """ +import typing from abc import ABC from typing import Optional -from tenacity import retry, stop_after_attempt, wait_fixed, after_log +from tenacity import retry, stop_after_attempt, wait_fixed, after_log, _utils from metagpt.actions.action_output import ActionOutput from metagpt.llm import LLM from metagpt.logs import logger from metagpt.utils.common import OutputParser -from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, RepairType,\ - retry_parse_json_text, extract_content_from_output +from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess + + +def action_after_log(logger: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]: + def log_it(retry_state: "RetryCallState") -> None: + if retry_state.fn is None: + fn_name = "" + else: + fn_name = _utils.get_callback_name(retry_state.fn) + logger.error(f"Finished call to '{fn_name}' after {sec_format % retry_state.seconds_since_start}(s), " + f"this was the {_utils.to_ordinal(retry_state.attempt_number)} time calling it. " + f"exp: {retry_state.outcome.exception()}") + return log_it class Action(ABC): @@ -53,7 +65,7 @@ class Action(ABC): @retry( stop=stop_after_attempt(3), wait=wait_fixed(1), - after=after_log(logger, logger.level("ERROR").name), + after=action_after_log(logger), ) async def _aask_v1( self, @@ -70,14 +82,9 @@ class Action(ABC): content = await self.llm.aask(prompt, system_msgs) logger.debug(f"llm raw output:\n{content}") output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping) - output_class_fields = list(output_class.schema()["properties"].keys()) # Custom ActionOutput's fields if format == "json": - content = repair_llm_raw_output(content, req_keys=output_class_fields + ["[/CONTENT]"]) - content = extract_content_from_output(content) - content = repair_llm_raw_output(content, req_keys=[None], repair_type=RepairType.JSON) # req_keys mocked - logger.info(f"extracted json CONTENT from output:\n{content}") - parsed_data = retry_parse_json_text(output=content) # should use output=content + parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key="[/CONTENT]") else: # using markdown parser parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) diff --git a/metagpt/config.py b/metagpt/config.py index a4c43c28a..2ce75b013 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -46,10 +46,18 @@ class Config(metaclass=Singleton): self.openai_api_key = self._get("OPENAI_API_KEY") self.anthropic_api_key = self._get("Anthropic_API_KEY") self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY") + + self.open_llm_api_base = self._get("OPEN_LLM_API_BASE") + self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL") + + self.fireworks_api_key = self._get("FIREWORKS_API_KEY") if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and \ (not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) and \ - (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key): - raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first") + (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key) and \ + (not self.open_llm_api_base) and \ + (not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key): + raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first " + "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE") self.openai_api_base = self._get("OPENAI_API_BASE") openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy if openai_proxy: @@ -69,6 +77,9 @@ class Config(metaclass=Singleton): self.domain = self._get("DOMAIN") self.spark_url = self._get("SPARK_URL") + self.fireworks_api_base = self._get("FIREWORKS_API_BASE") + self.fireworks_api_model = self._get("FIREWORKS_API_MODEL") + self.claude_api_key = self._get("Anthropic_API_KEY") self.serpapi_api_key = self._get("SERPAPI_API_KEY") self.serper_api_key = self._get("SERPER_API_KEY") diff --git a/metagpt/llm.py b/metagpt/llm.py index 4edcd7a83..1f7d1b4c9 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -6,12 +6,13 @@ @File : llm.py """ -from metagpt.logs import logger from metagpt.config import CONFIG from metagpt.provider.anthropic_api import Claude2 as Claude from metagpt.provider.openai_api import OpenAIGPTAPI from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI from metagpt.provider.spark_api import SparkAPI +from metagpt.provider.open_llm_api import OpenLLMGPTAPI +from metagpt.provider.fireworks_api import FireWorksGPTAPI from metagpt.provider.human_provider import HumanProvider @@ -26,6 +27,10 @@ def LLM() -> "BaseGPTAPI": llm = SparkAPI() elif CONFIG.zhipuai_api_key: llm = ZhiPuAIGPTAPI() + elif CONFIG.open_llm_api_base: + llm = OpenLLMGPTAPI() + elif CONFIG.fireworks_api_key: + llm = FireWorksGPTAPI() else: raise RuntimeError("You should config a LLM configuration first") diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py new file mode 100644 index 000000000..23126af2d --- /dev/null +++ b/metagpt/provider/fireworks_api.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : fireworks.ai's api + +import openai + +from metagpt.config import CONFIG +from metagpt.provider.openai_api import OpenAIGPTAPI, CostManager, RateLimiter + + +class FireWorksGPTAPI(OpenAIGPTAPI): + + def __init__(self): + self.__init_fireworks(CONFIG) + self.llm = openai + self.model = CONFIG.fireworks_api_model + self.auto_max_tokens = False + self._cost_manager = CostManager() + RateLimiter.__init__(self, rpm=self.rpm) + + def __init_fireworks(self, config: "Config"): + openai.api_key = config.fireworks_api_key + openai.api_base = config.fireworks_api_base + self.rpm = int(config.get("RPM", 10)) diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py new file mode 100644 index 000000000..a6820b42b --- /dev/null +++ b/metagpt/provider/open_llm_api.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : self-host open llm model with openai-compatible interface + +import openai + +from metagpt.logs import logger +from metagpt.config import CONFIG +from metagpt.provider.openai_api import OpenAIGPTAPI, CostManager, RateLimiter + + +class OpenLLMCostManager(CostManager): + """ open llm model is self-host, it's free and without cost""" + + def update_cost(self, prompt_tokens, completion_tokens, model): + """ + Update the total cost, prompt tokens, and completion tokens. + + Args: + prompt_tokens (int): The number of tokens used in the prompt. + completion_tokens (int): The number of tokens used in the completion. + model (str): The model used for the API call. + """ + self.total_prompt_tokens += prompt_tokens + self.total_completion_tokens += completion_tokens + + logger.info( + f"Max budget: ${CONFIG.max_budget:.3f} | " + f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" + ) + CONFIG.total_cost = self.total_cost + + +class OpenLLMGPTAPI(OpenAIGPTAPI): + + def __init__(self): + self.__init_openllm(CONFIG) + self.llm = openai + self.model = CONFIG.open_llm_api_model + self.auto_max_tokens = False + self._cost_manager = OpenLLMCostManager() + RateLimiter.__init__(self, rpm=self.rpm) + + def __init_openllm(self, config: "Config"): + openai.api_key = "sk-xx" # self-host api doesn't need api-key, use the default value + openai.api_base = config.open_llm_api_base + self.rpm = int(config.get("RPM", 10)) diff --git a/metagpt/provider/postprecess/base_postprecess_plugin.py b/metagpt/provider/postprecess/base_postprecess_plugin.py new file mode 100644 index 000000000..702a03194 --- /dev/null +++ b/metagpt/provider/postprecess/base_postprecess_plugin.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : base llm postprocess plugin to do the operations like repair the raw llm output + +from typing import Union + +from metagpt.logs import logger +from metagpt.utils.repair_llm_raw_output import RepairType +from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, extract_content_from_output, \ + retry_parse_json_text + + +class BasePostPrecessPlugin(object): + + model = None # the plugin of the `model`, use to judge in `llm_postprecess` + + def run_repair_llm_output(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]: + """ + repair steps + 1. repair the case sensitive problem using the schema's fields + 2. extract the content from the req_key pair( xx[REQ_KEY]xxx[/REQ_KEY]xx ) + 3. repair the invalid json text in the content + 4. parse the json text and repair it according to the exception with retry loop + """ + output_class_fields = list(schema["properties"].keys()) # Custom ActionOutput's fields + + content = self.run_repair_llm_raw_output(output, req_keys=output_class_fields + [req_key]) + content = self.run_extract_content_from_output(content, right_key=req_key) + # # req_keys mocked + content = self.run_repair_llm_raw_output(content, req_keys=[None], repair_type=RepairType.JSON) + parsed_data = self.run_retry_parse_json_text(content) + + return parsed_data + + def run_repair_llm_raw_output(self, content: str, req_keys: list[str], repair_type: str = None) -> str: + """ inherited class can re-implement the function""" + return repair_llm_raw_output(content, req_keys=req_keys, repair_type=repair_type) + + def run_extract_content_from_output(self, content: str, right_key: str) -> str: + """ inherited class can re-implement the function""" + return extract_content_from_output(content, right_key=right_key) + + def run_retry_parse_json_text(self, content: str) -> Union[dict, list]: + """ inherited class can re-implement the function""" + logger.info(f"extracted json CONTENT from output:\n{content}") + parsed_data = retry_parse_json_text(output=content) # should use output=content + return parsed_data + + def run(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]: + """ + this is used for prompt with a json-format output requirement and outer pair key, like + [REQ_KEY] + { + "Key": "value" + } + [/REQ_KEY] + + Args + outer (str): llm raw output + schema: output json schema + req_key: outer pair right key, usually in `[/REQ_KEY]` format + """ + assert len(schema.get("properties")) > 0 + assert "/" in req_key + + # current, postprocess only deal the repair_llm_raw_output + new_output = self.run_repair_llm_output( + output=output, + schema=schema, + req_key=req_key + ) + return new_output diff --git a/metagpt/provider/postprecess/llm_output_postprecess.py b/metagpt/provider/postprecess/llm_output_postprecess.py new file mode 100644 index 000000000..4b5955061 --- /dev/null +++ b/metagpt/provider/postprecess/llm_output_postprecess.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the entry of choosing which PostProcessPlugin to deal particular LLM model's output + +from typing import Union + +from metagpt.provider.postprecess.base_postprecess_plugin import BasePostPrecessPlugin + + +def llm_output_postprecess(output: str, schema: dict, req_key: str = "[/CONTENT]", + model_name: str = None) -> Union[dict, str]: + """ + default use BasePostPrecessPlugin if there is not matched plugin. + """ + # TODO choose different model's plugin according to the model_name + postprecess_plugin = BasePostPrecessPlugin() + + result = postprecess_plugin.run( + output=output, + schema=schema, + req_key=req_key + ) + return result diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index a12a36fcc..4a632b80c 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -91,13 +91,13 @@ def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") - idx1 = sub_output.rfind("}") idx2 = sub_output.rindex("]") idx = idx1 if idx1 >= idx2 else idx2 - sub_output = sub_output[: idx] + sub_output = sub_output[: idx+1] return sub_output if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)): # # avoid [req_key]xx[req_key] case to append [/req_key] output = output + "\n" + right_key - elif judge_potential_json(output, left_key): + elif judge_potential_json(output, left_key) and (not output.strip().endswith(left_key)): sub_content = judge_potential_json(output, left_key) output = sub_content + "\n" + right_key @@ -116,7 +116,7 @@ def repair_json_format(output: str) -> str: elif output.endswith("}]"): output = output[:-1] logger.info(f"repair_json_format: {'}]'}") - elif output.startswith("{") and output.startswith("]"): + elif output.startswith("{") and output.endswith("]"): output = output[:-1] + "}" return output @@ -183,9 +183,11 @@ def repair_invalid_json(output: str, error: str) -> str: if line.endswith("],"): # problem, redundant char `]` line = line.replace("]", "") - elif line.endswith("},"): + elif line.endswith("},") and not output.endswith("},"): # problem, redundant char `}` line = line.replace("}", "") + elif line.endswith("},") and output.endswith("},"): + line = line[:-1] elif '",' not in line: line = f'{line}",' elif "," not in line: @@ -218,11 +220,10 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R """ if retry_state.outcome.failed: if len(retry_state.args) > 0: - # # can't used as args=retry_state.args + # # can't be used as args=retry_state.args func_param_output = retry_state.args[0] elif len(retry_state.kwargs) > 0: func_param_output = retry_state.kwargs.get("output", "") - # import pdb; pdb.set_trace() exp_str = str(retry_state.outcome.exception()) logger.warning(f"parse json from content inside [CONTENT][/CONTENT] failed at retry " f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}") @@ -265,6 +266,7 @@ def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"): break return cont.strip() + # TODO construct the extract pattern with the `right_key` raw_content = copy.deepcopy(content) pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]" new_content = re_extract_content(raw_content, pattern) diff --git a/tests/metagpt/utils/test_repair_llm_raw_output.py b/tests/metagpt/utils/test_repair_llm_raw_output.py index dfcf60ad5..8779c965c 100644 --- a/tests/metagpt/utils/test_repair_llm_raw_output.py +++ b/tests/metagpt/utils/test_repair_llm_raw_output.py @@ -77,11 +77,11 @@ def test_required_key_pair_missing(): raw_output = '''[CONTENT] { - "a": "b" + "key": "value" ]''' target_output = '''[CONTENT] { - "a": "b" + "key": "value" ] [/CONTENT]''' @@ -92,17 +92,15 @@ def test_required_key_pair_missing(): raw_output = '''[CONTENT] tag [CONTENT] { - "a": "b" + "key": "value" } xxx ''' - target_output = '''[CONTENT] tag -[CONTENT] + target_output = '''[CONTENT] { - "a": "b" + "key": "value" } -[/CONTENT] -''' +[/CONTENT]''' output = repair_llm_raw_output(output=raw_output, req_keys=["[/CONTENT]"]) assert output == target_output @@ -117,6 +115,22 @@ def test_repair_json_format(): repair_type=RepairType.JSON) assert output == target_output + raw_output = "[{ xxx }" + target_output = "{ xxx }" + + output = repair_llm_raw_output(output=raw_output, + req_keys=[None], + repair_type=RepairType.JSON) + assert output == target_output + + raw_output = "{ xxx ]" + target_output = "{ xxx }" + + output = repair_llm_raw_output(output=raw_output, + req_keys=[None], + repair_type=RepairType.JSON) + assert output == target_output + def test_retry_parse_json_text(): invalid_json_text = """{ @@ -130,7 +144,7 @@ def test_retry_parse_json_text(): "Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis", "Requirement Analysis": "The requirements are clear and well-defined" } - output = retry_parse_json_text(invalid_json_text) + output = retry_parse_json_text(output=invalid_json_text) assert output == target_json invalid_json_text = """{ @@ -144,7 +158,7 @@ def test_retry_parse_json_text(): "Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis", "Requirement Analysis": "The requirements are clear and well-defined" } - output = retry_parse_json_text(invalid_json_text) + output = retry_parse_json_text(output=invalid_json_text) assert output == target_json