mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
add independent openllm and fireworks config fields, add llm output postprecess plugin
This commit is contained in:
parent
fc5c01e219
commit
642335317b
10 changed files with 243 additions and 29 deletions
|
|
@ -34,6 +34,15 @@ RPM: 10
|
|||
#### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY"
|
||||
# ZHIPUAI_API_KEY: "YOUR_API_KEY"
|
||||
|
||||
#### if using a self-hosted open LLM model with an OpenAI-compatible interface
|
||||
#OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1"
|
||||
#OPEN_LLM_API_MODEL: "llama2-13b"
|
||||
#
|
||||
##### if using the Fireworks API
|
||||
#FIREWORKS_API_KEY: "YOUR_API_KEY"
|
||||
#FIREWORKS_API_BASE: "https://api.fireworks.ai/inference/v1"
|
||||
#FIREWORKS_API_MODEL: "YOUR_LLM_MODEL" # example, accounts/fireworks/models/llama-v2-13b-chat
|
||||
|
||||
#### for Search
|
||||
|
||||
## Supported values: serpapi/google/serper/ddg
|
||||
|
|
|
|||
|
|
@ -6,17 +6,29 @@
|
|||
@File : action.py
|
||||
"""
|
||||
|
||||
import typing
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed, after_log
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed, after_log, _utils
|
||||
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import OutputParser
|
||||
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, RepairType,\
|
||||
retry_parse_json_text, extract_content_from_output
|
||||
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
|
||||
|
||||
|
||||
def action_after_log(logger: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
    """Create a retry ``after`` callback that reports each finished attempt at error level.

    Args:
        logger: Object whose ``error`` method receives the formatted message.
        sec_format: ``%``-style format applied to the elapsed seconds.

    Returns:
        A callable taking the retry state and returning ``None``.
    """

    def log_it(retry_state: "RetryCallState") -> None:
        """Log the callee name, elapsed time, attempt ordinal and raised exception."""
        wrapped = retry_state.fn
        called_name = "<unknown>" if wrapped is None else _utils.get_callback_name(wrapped)
        elapsed = sec_format % retry_state.seconds_since_start
        ordinal = _utils.to_ordinal(retry_state.attempt_number)
        failure = retry_state.outcome.exception()
        logger.error(
            f"Finished call to '{called_name}' after {elapsed}(s), "
            f"this was the {ordinal} time calling it. "
            f"exp: {failure}"
        )

    return log_it
|
||||
|
||||
|
||||
class Action(ABC):
|
||||
|
|
@ -53,7 +65,7 @@ class Action(ABC):
|
|||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_fixed(1),
|
||||
after=after_log(logger, logger.level("ERROR").name),
|
||||
after=action_after_log(logger),
|
||||
)
|
||||
async def _aask_v1(
|
||||
self,
|
||||
|
|
@ -70,14 +82,9 @@ class Action(ABC):
|
|||
content = await self.llm.aask(prompt, system_msgs)
|
||||
logger.debug(f"llm raw output:\n{content}")
|
||||
output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping)
|
||||
output_class_fields = list(output_class.schema()["properties"].keys()) # Custom ActionOutput's fields
|
||||
|
||||
if format == "json":
|
||||
content = repair_llm_raw_output(content, req_keys=output_class_fields + ["[/CONTENT]"])
|
||||
content = extract_content_from_output(content)
|
||||
content = repair_llm_raw_output(content, req_keys=[None], repair_type=RepairType.JSON) # req_keys mocked
|
||||
logger.info(f"extracted json CONTENT from output:\n{content}")
|
||||
parsed_data = retry_parse_json_text(output=content) # should use output=content
|
||||
parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key="[/CONTENT]")
|
||||
|
||||
else: # using markdown parser
|
||||
parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
|
||||
|
|
|
|||
|
|
@ -46,10 +46,18 @@ class Config(metaclass=Singleton):
|
|||
self.openai_api_key = self._get("OPENAI_API_KEY")
|
||||
self.anthropic_api_key = self._get("Anthropic_API_KEY")
|
||||
self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY")
|
||||
|
||||
self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
|
||||
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
|
||||
|
||||
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
|
||||
if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and \
|
||||
(not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) and \
|
||||
(not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key):
|
||||
raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first")
|
||||
(not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key) and \
|
||||
(not self.open_llm_api_base) and \
|
||||
(not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key):
|
||||
raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first "
|
||||
"or FIREWORKS_API_KEY or OPEN_LLM_API_BASE")
|
||||
self.openai_api_base = self._get("OPENAI_API_BASE")
|
||||
openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
|
||||
if openai_proxy:
|
||||
|
|
@ -69,6 +77,9 @@ class Config(metaclass=Singleton):
|
|||
self.domain = self._get("DOMAIN")
|
||||
self.spark_url = self._get("SPARK_URL")
|
||||
|
||||
self.fireworks_api_base = self._get("FIREWORKS_API_BASE")
|
||||
self.fireworks_api_model = self._get("FIREWORKS_API_MODEL")
|
||||
|
||||
self.claude_api_key = self._get("Anthropic_API_KEY")
|
||||
self.serpapi_api_key = self._get("SERPAPI_API_KEY")
|
||||
self.serper_api_key = self._get("SERPER_API_KEY")
|
||||
|
|
|
|||
|
|
@ -6,12 +6,13 @@
|
|||
@File : llm.py
|
||||
"""
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.anthropic_api import Claude2 as Claude
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
from metagpt.provider.spark_api import SparkAPI
|
||||
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
|
||||
from metagpt.provider.fireworks_api import FireWorksGPTAPI
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
|
||||
|
||||
|
|
@ -26,6 +27,10 @@ def LLM() -> "BaseGPTAPI":
|
|||
llm = SparkAPI()
|
||||
elif CONFIG.zhipuai_api_key:
|
||||
llm = ZhiPuAIGPTAPI()
|
||||
elif CONFIG.open_llm_api_base:
|
||||
llm = OpenLLMGPTAPI()
|
||||
elif CONFIG.fireworks_api_key:
|
||||
llm = FireWorksGPTAPI()
|
||||
else:
|
||||
raise RuntimeError("You should config a LLM configuration first")
|
||||
|
||||
|
|
|
|||
24
metagpt/provider/fireworks_api.py
Normal file
24
metagpt/provider/fireworks_api.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : fireworks.ai's api
|
||||
|
||||
import openai
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI, CostManager, RateLimiter
|
||||
|
||||
|
||||
class FireWorksGPTAPI(OpenAIGPTAPI):
    """OpenAIGPTAPI variant wired to the ``fireworks_*`` settings of ``CONFIG``."""

    def __init__(self):
        # The helper sets the openai module credentials/endpoint and ``self.rpm``,
        # which the RateLimiter initializer below reads.
        self.__init_fireworks(CONFIG)
        self._cost_manager = CostManager()
        self.auto_max_tokens = False
        self.model = CONFIG.fireworks_api_model
        self.llm = openai
        RateLimiter.__init__(self, rpm=self.rpm)

    def __init_fireworks(self, config: "Config"):
        """Point the ``openai`` module at Fireworks and record the configured RPM.

        Args:
            config: Configuration object providing ``fireworks_api_key``,
                ``fireworks_api_base`` and a ``get`` lookup for ``"RPM"``.
        """
        openai.api_key = config.fireworks_api_key
        openai.api_base = config.fireworks_api_base
        configured_rpm = config.get("RPM", 10)  # falls back to 10 when RPM is not set
        self.rpm = int(configured_rpm)
|
||||
47
metagpt/provider/open_llm_api.py
Normal file
47
metagpt/provider/open_llm_api.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : self-host open llm model with openai-compatible interface
|
||||
|
||||
import openai
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI, CostManager, RateLimiter
|
||||
|
||||
|
||||
class OpenLLMCostManager(CostManager):
    """Cost manager for self-hosted open LLM models, which are free to call."""

    def update_cost(self, prompt_tokens, completion_tokens, model):
        """
        Accumulate the prompt and completion token counts; no price is added.

        Args:
            prompt_tokens (int): The number of tokens used in the prompt.
            completion_tokens (int): The number of tokens used in the completion.
            model (str): The model used for the API call (not read by this method).
        """
        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens

        budget_part = f"Max budget: ${CONFIG.max_budget:.3f} | "
        usage_part = f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
        logger.info(budget_part + usage_part)

        # Publish this manager's running total on the shared config object.
        CONFIG.total_cost = self.total_cost
|
||||
|
||||
|
||||
class OpenLLMGPTAPI(OpenAIGPTAPI):
    """OpenAIGPTAPI variant wired to the ``open_llm_*`` settings of ``CONFIG``."""

    def __init__(self):
        # The helper sets the openai module key/endpoint and ``self.rpm``,
        # which the RateLimiter initializer below reads.
        self.__init_openllm(CONFIG)
        self._cost_manager = OpenLLMCostManager()
        self.auto_max_tokens = False
        self.model = CONFIG.open_llm_api_model
        self.llm = openai
        RateLimiter.__init__(self, rpm=self.rpm)

    def __init_openllm(self, config: "Config"):
        """Point the ``openai`` module at the self-hosted endpoint and record the configured RPM.

        Args:
            config: Configuration object providing ``open_llm_api_base`` and a
                ``get`` lookup for ``"RPM"``.
        """
        # self-host api doesn't need api-key, use the default value
        openai.api_key = "sk-xx"
        openai.api_base = config.open_llm_api_base
        configured_rpm = config.get("RPM", 10)  # falls back to 10 when RPM is not set
        self.rpm = int(configured_rpm)
|
||||
72
metagpt/provider/postprecess/base_postprecess_plugin.py
Normal file
72
metagpt/provider/postprecess/base_postprecess_plugin.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc   : base LLM postprocess plugin for operations such as repairing the raw LLM output
|
||||
|
||||
from typing import Union
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.repair_llm_raw_output import RepairType
|
||||
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, extract_content_from_output, \
|
||||
retry_parse_json_text
|
||||
|
||||
|
||||
class BasePostPrecessPlugin(object):
    """Default plugin that repairs a raw LLM output and parses the JSON it contains.

    Each ``run_*`` step is a separate method so that an inherited class can
    re-implement it.
    """

    model = None  # the plugin of the `model`, use to judge in `llm_postprecess`

    def run_repair_llm_output(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]:
        """
        Run the whole repair pipeline on ``output``:
            1. repair the case-sensitivity problem using the schema's fields
            2. extract the content from the req_key pair ( xx[REQ_KEY]xxx[/REQ_KEY]xx )
            3. repair the invalid json text in the content
            4. parse the json text, repairing it according to the exception in a retry loop
        """
        field_names = list(schema["properties"])  # custom ActionOutput's fields
        required_keys = field_names + [req_key]

        repaired = self.run_repair_llm_raw_output(output, req_keys=required_keys)
        extracted = self.run_extract_content_from_output(repaired, right_key=req_key)
        # req_keys is only a placeholder for the JSON repair pass
        json_text = self.run_repair_llm_raw_output(extracted, req_keys=[None], repair_type=RepairType.JSON)
        return self.run_retry_parse_json_text(json_text)

    def run_repair_llm_raw_output(self, content: str, req_keys: list[str], repair_type: str = None) -> str:
        """Delegate to ``repair_llm_raw_output``; an inherited class can re-implement this."""
        return repair_llm_raw_output(content, req_keys=req_keys, repair_type=repair_type)

    def run_extract_content_from_output(self, content: str, right_key: str) -> str:
        """Delegate to ``extract_content_from_output``; an inherited class can re-implement this."""
        return extract_content_from_output(content, right_key=right_key)

    def run_retry_parse_json_text(self, content: str) -> Union[dict, list]:
        """Log the extracted text and parse it via ``retry_parse_json_text``; can be re-implemented."""
        logger.info(f"extracted json CONTENT from output:\n{content}")
        # must be passed as the keyword argument `output`
        return retry_parse_json_text(output=content)

    def run(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]:
        """
        Entry point for prompts that require a json-format output wrapped in an outer key pair, like
            [REQ_KEY]
            {
                "Key": "value"
            }
            [/REQ_KEY]

        Args
            output (str): llm raw output
            schema: output json schema; its "properties" must be non-empty
            req_key: outer pair right key, usually in `[/REQ_KEY]` format (must contain "/")
        """
        properties = schema.get("properties")
        assert len(properties) > 0
        assert "/" in req_key

        # currently, postprocess only deals with repairing the raw llm output
        return self.run_repair_llm_output(output=output, schema=schema, req_key=req_key)
|
||||
23
metagpt/provider/postprecess/llm_output_postprecess.py
Normal file
23
metagpt/provider/postprecess/llm_output_postprecess.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc   : entry point for choosing which PostProcessPlugin handles a particular LLM model's output
|
||||
|
||||
from typing import Union
|
||||
|
||||
from metagpt.provider.postprecess.base_postprecess_plugin import BasePostPrecessPlugin
|
||||
|
||||
|
||||
def llm_output_postprecess(output: str, schema: dict, req_key: str = "[/CONTENT]",
                           model_name: str = None) -> Union[dict, list]:
    """
    Post-process a raw LLM output with the plugin matching the model.

    BasePostPrecessPlugin is used by default if there is no matched plugin;
    at present it is always used and ``model_name`` is not consulted.

    Args:
        output: llm raw output
        schema: output json schema
        req_key: outer pair right key, usually in `[/REQ_KEY]` format
        model_name: name of the model that produced the output (currently unused)

    Returns:
        Whatever the plugin's ``run`` returns, i.e. the parsed dict or list.
    """
    # TODO choose different model's plugin according to the model_name
    postprecess_plugin = BasePostPrecessPlugin()

    return postprecess_plugin.run(output=output, schema=schema, req_key=req_key)
|
||||
|
|
@ -91,13 +91,13 @@ def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -
|
|||
idx1 = sub_output.rfind("}")
|
||||
idx2 = sub_output.rindex("]")
|
||||
idx = idx1 if idx1 >= idx2 else idx2
|
||||
sub_output = sub_output[: idx]
|
||||
sub_output = sub_output[: idx+1]
|
||||
return sub_output
|
||||
|
||||
if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)):
|
||||
# # avoid [req_key]xx[req_key] case to append [/req_key]
|
||||
output = output + "\n" + right_key
|
||||
elif judge_potential_json(output, left_key):
|
||||
elif judge_potential_json(output, left_key) and (not output.strip().endswith(left_key)):
|
||||
sub_content = judge_potential_json(output, left_key)
|
||||
output = sub_content + "\n" + right_key
|
||||
|
||||
|
|
@ -116,7 +116,7 @@ def repair_json_format(output: str) -> str:
|
|||
elif output.endswith("}]"):
|
||||
output = output[:-1]
|
||||
logger.info(f"repair_json_format: {'}]'}")
|
||||
elif output.startswith("{") and output.startswith("]"):
|
||||
elif output.startswith("{") and output.endswith("]"):
|
||||
output = output[:-1] + "}"
|
||||
|
||||
return output
|
||||
|
|
@ -183,9 +183,11 @@ def repair_invalid_json(output: str, error: str) -> str:
|
|||
if line.endswith("],"):
|
||||
# problem, redundant char `]`
|
||||
line = line.replace("]", "")
|
||||
elif line.endswith("},"):
|
||||
elif line.endswith("},") and not output.endswith("},"):
|
||||
# problem, redundant char `}`
|
||||
line = line.replace("}", "")
|
||||
elif line.endswith("},") and output.endswith("},"):
|
||||
line = line[:-1]
|
||||
elif '",' not in line:
|
||||
line = f'{line}",'
|
||||
elif "," not in line:
|
||||
|
|
@ -218,11 +220,10 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
|
|||
"""
|
||||
if retry_state.outcome.failed:
|
||||
if len(retry_state.args) > 0:
|
||||
# # can't used as args=retry_state.args
|
||||
# # can't be used as args=retry_state.args
|
||||
func_param_output = retry_state.args[0]
|
||||
elif len(retry_state.kwargs) > 0:
|
||||
func_param_output = retry_state.kwargs.get("output", "")
|
||||
# import pdb; pdb.set_trace()
|
||||
exp_str = str(retry_state.outcome.exception())
|
||||
logger.warning(f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
|
||||
f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}")
|
||||
|
|
@ -265,6 +266,7 @@ def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"):
|
|||
break
|
||||
return cont.strip()
|
||||
|
||||
# TODO construct the extract pattern with the `right_key`
|
||||
raw_content = copy.deepcopy(content)
|
||||
pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]"
|
||||
new_content = re_extract_content(raw_content, pattern)
|
||||
|
|
|
|||
|
|
@ -77,11 +77,11 @@ def test_required_key_pair_missing():
|
|||
|
||||
raw_output = '''[CONTENT]
|
||||
{
|
||||
"a": "b"
|
||||
"key": "value"
|
||||
]'''
|
||||
target_output = '''[CONTENT]
|
||||
{
|
||||
"a": "b"
|
||||
"key": "value"
|
||||
]
|
||||
[/CONTENT]'''
|
||||
|
||||
|
|
@ -92,17 +92,15 @@ def test_required_key_pair_missing():
|
|||
raw_output = '''[CONTENT] tag
|
||||
[CONTENT]
|
||||
{
|
||||
"a": "b"
|
||||
"key": "value"
|
||||
}
|
||||
xxx
|
||||
'''
|
||||
target_output = '''[CONTENT] tag
|
||||
[CONTENT]
|
||||
target_output = '''[CONTENT]
|
||||
{
|
||||
"a": "b"
|
||||
"key": "value"
|
||||
}
|
||||
[/CONTENT]
|
||||
'''
|
||||
[/CONTENT]'''
|
||||
output = repair_llm_raw_output(output=raw_output,
|
||||
req_keys=["[/CONTENT]"])
|
||||
assert output == target_output
|
||||
|
|
@ -117,6 +115,22 @@ def test_repair_json_format():
|
|||
repair_type=RepairType.JSON)
|
||||
assert output == target_output
|
||||
|
||||
raw_output = "[{ xxx }"
|
||||
target_output = "{ xxx }"
|
||||
|
||||
output = repair_llm_raw_output(output=raw_output,
|
||||
req_keys=[None],
|
||||
repair_type=RepairType.JSON)
|
||||
assert output == target_output
|
||||
|
||||
raw_output = "{ xxx ]"
|
||||
target_output = "{ xxx }"
|
||||
|
||||
output = repair_llm_raw_output(output=raw_output,
|
||||
req_keys=[None],
|
||||
repair_type=RepairType.JSON)
|
||||
assert output == target_output
|
||||
|
||||
|
||||
def test_retry_parse_json_text():
|
||||
invalid_json_text = """{
|
||||
|
|
@ -130,7 +144,7 @@ def test_retry_parse_json_text():
|
|||
"Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis",
|
||||
"Requirement Analysis": "The requirements are clear and well-defined"
|
||||
}
|
||||
output = retry_parse_json_text(invalid_json_text)
|
||||
output = retry_parse_json_text(output=invalid_json_text)
|
||||
assert output == target_json
|
||||
|
||||
invalid_json_text = """{
|
||||
|
|
@ -144,7 +158,7 @@ def test_retry_parse_json_text():
|
|||
"Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis",
|
||||
"Requirement Analysis": "The requirements are clear and well-defined"
|
||||
}
|
||||
output = retry_parse_json_text(invalid_json_text)
|
||||
output = retry_parse_json_text(output=invalid_json_text)
|
||||
assert output == target_json
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue