add independent openllm and fireworks config fields, add llm output postprecess plugin

This commit is contained in:
better629 2023-11-23 01:46:14 +08:00
parent fc5c01e219
commit 642335317b
10 changed files with 243 additions and 29 deletions

View file

@ -34,6 +34,15 @@ RPM: 10
#### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY"
# ZHIPUAI_API_KEY: "YOUR_API_KEY"
#### if use self-host open llm model with openai-compatible interface
#OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1"
#OPEN_LLM_API_MODEL: "llama2-13b"
#
##### if use Fireworks api
#FIREWORKS_API_KEY: "YOUR_API_KEY"
#FIREWORKS_API_BASE: "https://api.fireworks.ai/inference/v1"
#FIREWORKS_API_MODEL: "YOUR_LLM_MODEL" # example, accounts/fireworks/models/llama-v2-13b-chat
#### for Search
## Supported values: serpapi/google/serper/ddg

View file

@ -6,17 +6,29 @@
@File : action.py
"""
import typing
from abc import ABC
from typing import Optional
from tenacity import retry, stop_after_attempt, wait_fixed, after_log
from tenacity import retry, stop_after_attempt, wait_fixed, after_log, _utils
from metagpt.actions.action_output import ActionOutput
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.utils.common import OutputParser
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, RepairType,\
retry_parse_json_text, extract_content_from_output
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
def action_after_log(logger: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
def log_it(retry_state: "RetryCallState") -> None:
if retry_state.fn is None:
fn_name = "<unknown>"
else:
fn_name = _utils.get_callback_name(retry_state.fn)
logger.error(f"Finished call to '{fn_name}' after {sec_format % retry_state.seconds_since_start}(s), "
f"this was the {_utils.to_ordinal(retry_state.attempt_number)} time calling it. "
f"exp: {retry_state.outcome.exception()}")
return log_it
class Action(ABC):
@ -53,7 +65,7 @@ class Action(ABC):
@retry(
stop=stop_after_attempt(3),
wait=wait_fixed(1),
after=after_log(logger, logger.level("ERROR").name),
after=action_after_log(logger),
)
async def _aask_v1(
self,
@ -70,14 +82,9 @@ class Action(ABC):
content = await self.llm.aask(prompt, system_msgs)
logger.debug(f"llm raw output:\n{content}")
output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping)
output_class_fields = list(output_class.schema()["properties"].keys()) # Custom ActionOutput's fields
if format == "json":
content = repair_llm_raw_output(content, req_keys=output_class_fields + ["[/CONTENT]"])
content = extract_content_from_output(content)
content = repair_llm_raw_output(content, req_keys=[None], repair_type=RepairType.JSON) # req_keys mocked
logger.info(f"extracted json CONTENT from output:\n{content}")
parsed_data = retry_parse_json_text(output=content) # should use output=content
parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key="[/CONTENT]")
else: # using markdown parser
parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)

View file

@ -46,10 +46,18 @@ class Config(metaclass=Singleton):
self.openai_api_key = self._get("OPENAI_API_KEY")
self.anthropic_api_key = self._get("Anthropic_API_KEY")
self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY")
self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and \
(not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) and \
(not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key):
raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first")
(not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key) and \
(not self.open_llm_api_base) and \
(not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key):
raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first "
"or FIREWORKS_API_KEY or OPEN_LLM_API_BASE")
self.openai_api_base = self._get("OPENAI_API_BASE")
openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
if openai_proxy:
@ -69,6 +77,9 @@ class Config(metaclass=Singleton):
self.domain = self._get("DOMAIN")
self.spark_url = self._get("SPARK_URL")
self.fireworks_api_base = self._get("FIREWORKS_API_BASE")
self.fireworks_api_model = self._get("FIREWORKS_API_MODEL")
self.claude_api_key = self._get("Anthropic_API_KEY")
self.serpapi_api_key = self._get("SERPAPI_API_KEY")
self.serper_api_key = self._get("SERPER_API_KEY")

View file

@ -6,12 +6,13 @@
@File : llm.py
"""
from metagpt.logs import logger
from metagpt.config import CONFIG
from metagpt.provider.anthropic_api import Claude2 as Claude
from metagpt.provider.openai_api import OpenAIGPTAPI
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
from metagpt.provider.spark_api import SparkAPI
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
from metagpt.provider.fireworks_api import FireWorksGPTAPI
from metagpt.provider.human_provider import HumanProvider
@ -26,6 +27,10 @@ def LLM() -> "BaseGPTAPI":
llm = SparkAPI()
elif CONFIG.zhipuai_api_key:
llm = ZhiPuAIGPTAPI()
elif CONFIG.open_llm_api_base:
llm = OpenLLMGPTAPI()
elif CONFIG.fireworks_api_key:
llm = FireWorksGPTAPI()
else:
raise RuntimeError("You should config a LLM configuration first")

View file

@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : fireworks.ai's api
import openai
from metagpt.config import CONFIG
from metagpt.provider.openai_api import OpenAIGPTAPI, CostManager, RateLimiter
class FireWorksGPTAPI(OpenAIGPTAPI):
def __init__(self):
self.__init_fireworks(CONFIG)
self.llm = openai
self.model = CONFIG.fireworks_api_model
self.auto_max_tokens = False
self._cost_manager = CostManager()
RateLimiter.__init__(self, rpm=self.rpm)
def __init_fireworks(self, config: "Config"):
openai.api_key = config.fireworks_api_key
openai.api_base = config.fireworks_api_base
self.rpm = int(config.get("RPM", 10))

View file

@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : self-host open llm model with openai-compatible interface
import openai
from metagpt.logs import logger
from metagpt.config import CONFIG
from metagpt.provider.openai_api import OpenAIGPTAPI, CostManager, RateLimiter
class OpenLLMCostManager(CostManager):
""" open llm model is self-host, it's free and without cost"""
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
logger.info(
f"Max budget: ${CONFIG.max_budget:.3f} | "
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
CONFIG.total_cost = self.total_cost
class OpenLLMGPTAPI(OpenAIGPTAPI):
def __init__(self):
self.__init_openllm(CONFIG)
self.llm = openai
self.model = CONFIG.open_llm_api_model
self.auto_max_tokens = False
self._cost_manager = OpenLLMCostManager()
RateLimiter.__init__(self, rpm=self.rpm)
def __init_openllm(self, config: "Config"):
openai.api_key = "sk-xx" # self-host api doesn't need api-key, use the default value
openai.api_base = config.open_llm_api_base
self.rpm = int(config.get("RPM", 10))

View file

@ -0,0 +1,72 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : base llm postprocess plugin to do the operations like repair the raw llm output
from typing import Union
from metagpt.logs import logger
from metagpt.utils.repair_llm_raw_output import RepairType
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, extract_content_from_output, \
retry_parse_json_text
class BasePostPrecessPlugin(object):
model = None # the plugin of the `model`, use to judge in `llm_postprecess`
def run_repair_llm_output(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]:
"""
repair steps
1. repair the case sensitive problem using the schema's fields
2. extract the content from the req_key pair( xx[REQ_KEY]xxx[/REQ_KEY]xx )
3. repair the invalid json text in the content
4. parse the json text and repair it according to the exception with retry loop
"""
output_class_fields = list(schema["properties"].keys()) # Custom ActionOutput's fields
content = self.run_repair_llm_raw_output(output, req_keys=output_class_fields + [req_key])
content = self.run_extract_content_from_output(content, right_key=req_key)
# # req_keys mocked
content = self.run_repair_llm_raw_output(content, req_keys=[None], repair_type=RepairType.JSON)
parsed_data = self.run_retry_parse_json_text(content)
return parsed_data
def run_repair_llm_raw_output(self, content: str, req_keys: list[str], repair_type: str = None) -> str:
""" inherited class can re-implement the function"""
return repair_llm_raw_output(content, req_keys=req_keys, repair_type=repair_type)
def run_extract_content_from_output(self, content: str, right_key: str) -> str:
""" inherited class can re-implement the function"""
return extract_content_from_output(content, right_key=right_key)
def run_retry_parse_json_text(self, content: str) -> Union[dict, list]:
""" inherited class can re-implement the function"""
logger.info(f"extracted json CONTENT from output:\n{content}")
parsed_data = retry_parse_json_text(output=content) # should use output=content
return parsed_data
def run(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]:
"""
this is used for prompt with a json-format output requirement and outer pair key, like
[REQ_KEY]
{
"Key": "value"
}
[/REQ_KEY]
Args
outer (str): llm raw output
schema: output json schema
req_key: outer pair right key, usually in `[/REQ_KEY]` format
"""
assert len(schema.get("properties")) > 0
assert "/" in req_key
# current, postprocess only deal the repair_llm_raw_output
new_output = self.run_repair_llm_output(
output=output,
schema=schema,
req_key=req_key
)
return new_output

View file

@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the entry of choosing which PostProcessPlugin to deal particular LLM model's output
from typing import Union
from metagpt.provider.postprecess.base_postprecess_plugin import BasePostPrecessPlugin
def llm_output_postprecess(output: str, schema: dict, req_key: str = "[/CONTENT]",
model_name: str = None) -> Union[dict, str]:
"""
default use BasePostPrecessPlugin if there is not matched plugin.
"""
# TODO choose different model's plugin according to the model_name
postprecess_plugin = BasePostPrecessPlugin()
result = postprecess_plugin.run(
output=output,
schema=schema,
req_key=req_key
)
return result

View file

@ -91,13 +91,13 @@ def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -
idx1 = sub_output.rfind("}")
idx2 = sub_output.rindex("]")
idx = idx1 if idx1 >= idx2 else idx2
sub_output = sub_output[: idx]
sub_output = sub_output[: idx+1]
return sub_output
if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)):
# # avoid [req_key]xx[req_key] case to append [/req_key]
output = output + "\n" + right_key
elif judge_potential_json(output, left_key):
elif judge_potential_json(output, left_key) and (not output.strip().endswith(left_key)):
sub_content = judge_potential_json(output, left_key)
output = sub_content + "\n" + right_key
@ -116,7 +116,7 @@ def repair_json_format(output: str) -> str:
elif output.endswith("}]"):
output = output[:-1]
logger.info(f"repair_json_format: {'}]'}")
elif output.startswith("{") and output.startswith("]"):
elif output.startswith("{") and output.endswith("]"):
output = output[:-1] + "}"
return output
@ -183,9 +183,11 @@ def repair_invalid_json(output: str, error: str) -> str:
if line.endswith("],"):
# problem, redundant char `]`
line = line.replace("]", "")
elif line.endswith("},"):
elif line.endswith("},") and not output.endswith("},"):
# problem, redundant char `}`
line = line.replace("}", "")
elif line.endswith("},") and output.endswith("},"):
line = line[:-1]
elif '",' not in line:
line = f'{line}",'
elif "," not in line:
@ -218,11 +220,10 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
"""
if retry_state.outcome.failed:
if len(retry_state.args) > 0:
# # can't used as args=retry_state.args
# # can't be used as args=retry_state.args
func_param_output = retry_state.args[0]
elif len(retry_state.kwargs) > 0:
func_param_output = retry_state.kwargs.get("output", "")
# import pdb; pdb.set_trace()
exp_str = str(retry_state.outcome.exception())
logger.warning(f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}")
@ -265,6 +266,7 @@ def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"):
break
return cont.strip()
# TODO construct the extract pattern with the `right_key`
raw_content = copy.deepcopy(content)
pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]"
new_content = re_extract_content(raw_content, pattern)

View file

@ -77,11 +77,11 @@ def test_required_key_pair_missing():
raw_output = '''[CONTENT]
{
"a": "b"
"key": "value"
]'''
target_output = '''[CONTENT]
{
"a": "b"
"key": "value"
]
[/CONTENT]'''
@ -92,17 +92,15 @@ def test_required_key_pair_missing():
raw_output = '''[CONTENT] tag
[CONTENT]
{
"a": "b"
"key": "value"
}
xxx
'''
target_output = '''[CONTENT] tag
[CONTENT]
target_output = '''[CONTENT]
{
"a": "b"
"key": "value"
}
[/CONTENT]
'''
[/CONTENT]'''
output = repair_llm_raw_output(output=raw_output,
req_keys=["[/CONTENT]"])
assert output == target_output
@ -117,6 +115,22 @@ def test_repair_json_format():
repair_type=RepairType.JSON)
assert output == target_output
raw_output = "[{ xxx }"
target_output = "{ xxx }"
output = repair_llm_raw_output(output=raw_output,
req_keys=[None],
repair_type=RepairType.JSON)
assert output == target_output
raw_output = "{ xxx ]"
target_output = "{ xxx }"
output = repair_llm_raw_output(output=raw_output,
req_keys=[None],
repair_type=RepairType.JSON)
assert output == target_output
def test_retry_parse_json_text():
invalid_json_text = """{
@ -130,7 +144,7 @@ def test_retry_parse_json_text():
"Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis",
"Requirement Analysis": "The requirements are clear and well-defined"
}
output = retry_parse_json_text(invalid_json_text)
output = retry_parse_json_text(output=invalid_json_text)
assert output == target_json
invalid_json_text = """{
@ -144,7 +158,7 @@ def test_retry_parse_json_text():
"Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis",
"Requirement Analysis": "The requirements are clear and well-defined"
}
output = retry_parse_json_text(invalid_json_text)
output = retry_parse_json_text(output=invalid_json_text)
assert output == target_json