From fd4c0fde265a3653503ccfff23df480cd6b67852 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 12:00:27 +0800 Subject: [PATCH 01/32] implement framework --- metagpt/configs/llm_config.py | 7 +- metagpt/provider/bedrock/.gitignore | 192 ++++++++++++++++++ metagpt/provider/bedrock/__init__.py | 0 .../provider/bedrock/amazon_bedrock_api.py | 57 ++++++ metagpt/provider/bedrock/base_provider.py | 3 + metagpt/provider/bedrock/bedrock_provide.py | 0 6 files changed, 258 insertions(+), 1 deletion(-) create mode 100644 metagpt/provider/bedrock/.gitignore create mode 100644 metagpt/provider/bedrock/__init__.py create mode 100644 metagpt/provider/bedrock/amazon_bedrock_api.py create mode 100644 metagpt/provider/bedrock/base_provider.py create mode 100644 metagpt/provider/bedrock/bedrock_provide.py diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 222e116ee..2a04116f5 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -32,6 +32,7 @@ class LLMType(Enum): MISTRAL = "mistral" YI = "yi" # lingyiwanwu OPENROUTER = "openrouter" + AMAZON_BEDROCK = "amazon_bedrock" def __missing__(self, key): return self.OPENAI @@ -74,10 +75,14 @@ class LLMConfig(YamlModel): best_of: Optional[int] = None n: Optional[int] = None stream: bool = False - logprobs: Optional[bool] = None # https://cookbook.openai.com/examples/using_logprobs + # https://cookbook.openai.com/examples/using_logprobs + logprobs: Optional[bool] = None top_logprobs: Optional[int] = None timeout: int = 600 + # For Amazon Bedrock + region_name: str = None + # For Network proxy: Optional[str] = None diff --git a/metagpt/provider/bedrock/.gitignore b/metagpt/provider/bedrock/.gitignore new file mode 100644 index 000000000..971fcecb7 --- /dev/null +++ b/metagpt/provider/bedrock/.gitignore @@ -0,0 +1,192 @@ +### Python template + +# Byte-compiled / optimized / DLL files +__pycache__ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +metagpt/tools/schemas/ +examples/data/search_kb/*.json + +# PyInstaller +# Usually these files are written by a python scripts from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ +unittest.txt + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +logs +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# report +allure-report +allure-results + +# idea / vscode / macos +.idea +.DS_Store +.vscode + +key.yaml +/data/ +data.ms +examples/nb/ +examples/default__vector_store.json +examples/docstore.json +examples/graph_store.json +examples/image__vector_store.json +examples/index_store.json +.chroma +*~$* +workspace/* +tmp +metagpt/roles/idea_agent.py +.aider* +*.bak +*.bk + +# output folder +output +tmp.png +.dependencies.json +tests/metagpt/utils/file_repo_git +tests/data/rsp_cache_new.json +*.tmp +*.png +htmlcov +htmlcov.* +cov.xml +*.dot +*.pkl +*.faiss +*-structure.csv +*-structure.json +*.dot +.python-version +# aws access key +config.py \ No newline at end of file diff --git a/metagpt/provider/bedrock/__init__.py b/metagpt/provider/bedrock/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py new file mode 100644 index 000000000..ecdee4154 --- /dev/null +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -0,0 +1,57 @@ + +import json +from typing import Coroutine, Literal +from metagpt.const import USE_CONFIG_TIMEOUT +from metagpt.provider.llm_provider_registry import register_provider +from metagpt.configs.llm_config import LLMConfig, LLMType +from metagpt.provider.base_llm import BaseLLM +from metagpt.logs import log_llm_stream, logger +from botocore.config import Config +import boto3 + + +@register_provider([LLMType.AMAZON_BEDROCK]) +class AmazonBedrockLLM(BaseLLM): + def __init__(self, config: LLMConfig): + self.config = config + self.__client = self.__init_client("bedrock-runtime") + + def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): + # access key from https://us-east-1.console.aws.amazon.com/iam + self.__credentital_kwards = { + "aws_secret_access_key": self.config.secret_key, + "aws_access_key_id": self.config.access_key, + "region_name": self.config.region_name + } + session = boto3.Session(**self.__credentital_kwards) + client = session.client(service_name) + return client + + def list_models(self): + """see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html""" + client = self.__init_client("bedrock") + # only output text-generation models + response = client.list_foundation_models(byOutputModality='TEXT') + summaries = [f'{summary.get("modelId", ""):50} Support Streaming:{summary.get("responseStreamingSupported","")}' + for summary in response.get("modelSummaries", {})] + logger.info("\n"+"\n".join(summaries)) + + def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + pass + + def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + pass + + def 
completion(self, messages): + pass + + def acompletion(self, messages: list[dict]): + pass + + +if __name__ == '__main__': + from .config import my_config + prompt = "who are you?" + messages = [{"role": "user", "content": prompt}] + llm = AmazonBedrockLLM(my_config) + llm.list_models() diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py new file mode 100644 index 000000000..eaedfe045 --- /dev/null +++ b/metagpt/provider/bedrock/base_provider.py @@ -0,0 +1,3 @@ +from abc import ABC +class BaseBedrockProvider(ABC): + pass \ No newline at end of file diff --git a/metagpt/provider/bedrock/bedrock_provide.py b/metagpt/provider/bedrock/bedrock_provide.py new file mode 100644 index 000000000..e69de29bb From 4f14ee7ce143125e3190370d6ef997982d6bfdbc Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 12:30:20 +0800 Subject: [PATCH 02/32] implement base provider --- metagpt/provider/bedrock/base_provider.py | 30 ++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index eaedfe045..086569736 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -1,3 +1,27 @@ -from abc import ABC -class BaseBedrockProvider(ABC): - pass \ No newline at end of file + +import json + +class BaseBedrockProvider(object): + # to handle different generation kwargs + max_length = "max_tokens" + temperature = "temperature" + top_p = "top-p" + top_k = "top-k" + + def get_request_body(self, prompt, generate_kwargs: dict): + return {"prompt": prompt} | generate_kwargs + + def get_choice_text(self, response) -> str: + response_body = json.loads(response["body"].read()) + completions = response_body["content"]["outputs"][0]['text'] + return completions + + def messages_to_prompt(self, messages: list[dict]): + """[{"role": "user", "content": msg}] to user: etc.""" + return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) + + def format_prompt(self, prompt: str) -> str: + return prompt + + def format_messages(self, messages: list[dict]) -> list[dict]: + return messages From ec7df8acdf4ae7ff768c5485e03028a403ee20be Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 13:47:31 +0800 Subject: [PATCH 03/32] support mistral --- metagpt/configs/llm_config.py | 3 + .../provider/bedrock/amazon_bedrock_api.py | 25 ++++++- metagpt/provider/bedrock/base_provider.py | 13 +--- metagpt/provider/bedrock/bedrock_provide.py | 0 metagpt/provider/bedrock/bedrock_provider.py | 73 +++++++++++++++++++ 5 files changed, 101 insertions(+), 13 deletions(-) delete mode 100644 metagpt/provider/bedrock/bedrock_provide.py create mode 100644 metagpt/provider/bedrock/bedrock_provider.py diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 2a04116f5..170005c21 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -100,3 +100,6 @@ class LLMConfig(YamlModel): @classmethod def check_timeout(cls, v): return v or LLM_API_TIMEOUT + + def get(self, key: str, default = None): + return getattr(self, key, default) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index ecdee4154..92b137a97 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -9,12 +9,15 @@ from metagpt.logs import log_llm_stream, 
logger from botocore.config import Config import boto3 +from metagpt.provider.bedrock.bedrock_provider import get_provider + @register_provider([LLMType.AMAZON_BEDROCK]) class AmazonBedrockLLM(BaseLLM): def __init__(self, config: LLMConfig): self.config = config self.__client = self.__init_client("bedrock-runtime") + self.provider = get_provider(self.config.model) def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): # access key from https://us-east-1.console.aws.amazon.com/iam @@ -28,7 +31,6 @@ class AmazonBedrockLLM(BaseLLM): return client def list_models(self): - """see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html""" client = self.__init_client("bedrock") # only output text-generation models response = client.list_foundation_models(byOutputModality='TEXT') @@ -36,14 +38,29 @@ class AmazonBedrockLLM(BaseLLM): for summary in response.get("modelSummaries", {})] logger.info("\n"+"\n".join(summaries)) + @property + def _generate_kwargs(self): + return { + "max_token": self.config.get("max_token", 1024), + "temperature": self.config.get("temperature", 0.3), + "top_p": self.config.get("top_p", 0.95), + "top_k": self.config.get("top_k", 1), + } + def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): pass def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): pass - def completion(self, messages): - pass + def completion(self, messages: list[dict]): + request_body = self.provider.get_request_body( + messages, **self._generate_kwargs) + response = self.__client.invoke_model( + modelId=self.config.model, body=request_body + ) + completions = self.provider.get_choice_text(response) + return completions def acompletion(self, messages: list[dict]): pass @@ -54,4 +71,4 @@ if __name__ == '__main__': prompt = "who are you?" 
messages = [{"role": "user", "content": prompt}] llm = AmazonBedrockLLM(my_config) - llm.list_models() + print(llm.completion(messages)) diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 086569736..46f6ea58c 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -1,19 +1,14 @@ - import json + class BaseBedrockProvider(object): # to handle different generation kwargs - max_length = "max_tokens" - temperature = "temperature" - top_p = "top-p" - top_k = "top-k" - - def get_request_body(self, prompt, generate_kwargs: dict): - return {"prompt": prompt} | generate_kwargs + def get_request_body(self, messages, max_token=None, temperature=None, top_p=None, top_k=None, **kwargs): + return json.dumps({"prompt": self.messages_to_prompt(messages)}) def get_choice_text(self, response) -> str: response_body = json.loads(response["body"].read()) - completions = response_body["content"]["outputs"][0]['text'] + completions = response_body["outputs"][0]['text'] return completions def messages_to_prompt(self, messages: list[dict]): diff --git a/metagpt/provider/bedrock/bedrock_provide.py b/metagpt/provider/bedrock/bedrock_provide.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py new file mode 100644 index 000000000..10ab66b34 --- /dev/null +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -0,0 +1,73 @@ +from metagpt.provider.bedrock.base_provider import BaseBedrockProvider +import json + + +class MistralProvider(BaseBedrockProvider): + + def format_prompt(self, prompt: str) -> str: + # for mixtral and llama + return f"[INST]{prompt}[/INST]" + + def get_request_body(self, messages, max_token=None, temperature=None, top_p=None, top_k=None, **kwargs): + return json.dumps({ + "prompt": self.format_prompt(self.messages_to_prompt(messages)), + "max_tokens": max_token, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, }) + + +PROVIDERS = { + "mistral": MistralProvider() +} + +NOT_SUUPORT_STREAM_MODELS = { + "ai21.j2-grande-instruct", + "ai21.j2-jumbo-instruct", + "ai21.j2-mid", + "ai21.j2-mid-v1", + "ai21.j2-ultra", + "ai21.j2-ultra-v1", +} + +SUPPORT_STREAM_MODELS = { + "amazon.titan-tg1-large", + "amazon.titan-text-lite-v1:0:4k", + "amazon.titan-text-lite-v1", + "amazon.titan-text-express-v1:0:8k", + "amazon.titan-text-express-v1", + "anthropic.claude-instant-v1:2:100k", + "anthropic.claude-instant-v1", + "anthropic.claude-v2:0:18k", + "anthropic.claude-v2:0:100k", + "anthropic.claude-v2:1:18k", + "anthropic.claude-v2:1:200k", + "anthropic.claude-v2:1", + "anthropic.claude-v2:2:18k", + "anthropic.claude-v2:2:200k", + "anthropic.claude-v2:2", + "anthropic.claude-v2", + "anthropic.claude-3-sonnet-20240229-v1:0:28k", + "anthropic.claude-3-sonnet-20240229-v1:0:200k", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-3-haiku-20240307-v1:0:48k", + "anthropic.claude-3-haiku-20240307-v1:0:200k", + "anthropic.claude-3-haiku-20240307-v1:0", + "cohere.command-text-v14:7:4k", + "cohere.command-text-v14", + "cohere.command-light-text-v14:7:4k", + "cohere.command-light-text-v14", + "meta.llama2-70b-v1", + "meta.llama3-8b-instruct-v1:0", + "meta.llama3-70b-instruct-v1:0", + "mistral.mistral-7b-instruct-v0:2", + "mistral.mixtral-8x7b-instruct-v0:1", + "mistral.mistral-large-2402-v1:0", +} + + +def get_provider(model_id: str): + model_name = model_id.split(".")[0] # meta、mistral…… + 
if model_name not in PROVIDERS: + raise KeyError(f"{model_name} is not supported!") + return PROVIDERS[model_name] From 4d1fb207855b446e8b22691711dab2c52691da5a Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 14:31:22 +0800 Subject: [PATCH 04/32] add generate_kwargs --- .../provider/bedrock/amazon_bedrock_api.py | 1 - metagpt/provider/bedrock/base_provider.py | 4 +-- metagpt/provider/bedrock/bedrock_provider.py | 25 +++++++++++++------ 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 92b137a97..11a2bb3f5 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -41,7 +41,6 @@ class AmazonBedrockLLM(BaseLLM): @property def _generate_kwargs(self): return { - "max_token": self.config.get("max_token", 1024), "temperature": self.config.get("temperature", 0.3), "top_p": self.config.get("top_p", 0.95), "top_k": self.config.get("top_k", 1), diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 46f6ea58c..9a6f9659c 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -3,8 +3,8 @@ import json class BaseBedrockProvider(object): # to handle different generation kwargs - def get_request_body(self, messages, max_token=None, temperature=None, top_p=None, top_k=None, **kwargs): - return json.dumps({"prompt": self.messages_to_prompt(messages)}) + def get_request_body(self, messages, **generate_kwargs): + return json.dumps({"prompt": self.messages_to_prompt(messages)} | generate_kwargs) def get_choice_text(self, response) -> str: response_body = json.loads(response["body"].read()) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 10ab66b34..3ae84d8c3 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -8,13 +8,24 @@ class MistralProvider(BaseBedrockProvider): # for mixtral and llama return f"[INST]{prompt}[/INST]" - def get_request_body(self, messages, max_token=None, temperature=None, top_p=None, top_k=None, **kwargs): - return json.dumps({ - "prompt": self.format_prompt(self.messages_to_prompt(messages)), - "max_tokens": max_token, - "temperature": temperature, - "top_p": top_p, - "top_k": top_k, }) + def get_request_body(self, messages, **generate_kwargs): + return json.dumps({"prompt": self.format_prompt(self.messages_to_prompt(messages))} | generate_kwargs) + + +class AnthropicProvider(BaseBedrockProvider): + pass + + +class CohereProvider(BaseBedrockProvider): + pass + + +class MetaProvider(BaseBedrockProvider): + pass + + +class Ai21Provider(BaseBedrockProvider): + pass PROVIDERS = { From a7414884100f67fbf6e3aed0abaceecb30472873 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 15:16:05 +0800 Subject: [PATCH 05/32] add stream --- .../provider/bedrock/amazon_bedrock_api.py | 41 +++++++++++++------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 11a2bb3f5..d8aaed8e9 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -42,16 +42,8 @@ class AmazonBedrockLLM(BaseLLM): def _generate_kwargs(self): return { "temperature": self.config.get("temperature", 0.3), 
- "top_p": self.config.get("top_p", 0.95), - "top_k": self.config.get("top_k", 1), } - def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): - pass - - def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): - pass - def completion(self, messages: list[dict]): request_body = self.provider.get_request_body( messages, **self._generate_kwargs) @@ -61,13 +53,38 @@ class AmazonBedrockLLM(BaseLLM): completions = self.provider.get_choice_text(response) return completions - def acompletion(self, messages: list[dict]): - pass + def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + request_body = self.provider.get_request_body( + messages, **self._generate_kwargs) + response = self.__client.invoke_model_with_response_stream( + modelId=self.config.model, body=request_body + ) + collected_content = [] + + for event in response.get("body"): + chunk_text = json.loads(event["chunk"]["bytes"])[ + "outputs"][0]["text"] + collected_content.append(chunk_text) + log_llm_stream(chunk_text) + + log_llm_stream("\n") + full_text = ("".join(collected_content)).lstrip() + return full_text + + async def acompletion(self, messages: list[dict]): + return self._achat_completion(messages) + + async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + # TODO:make it async + return self.completion(messages) + + async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + return self._chat_completion_stream(messages) if __name__ == '__main__': from .config import my_config - prompt = "who are you?" + prompt = "write an essay for living on mars in 1000 word" messages = [{"role": "user", "content": prompt}] llm = AmazonBedrockLLM(my_config) - print(llm.completion(messages)) + llm._chat_completion_stream(messages) From 9775a2b1eb77e82898e0c36a9205951ce81402f7 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 20:05:25 +0800 Subject: [PATCH 06/32] implement meta --- .../provider/bedrock/amazon_bedrock_api.py | 16 +++++++---- metagpt/provider/bedrock/base_provider.py | 14 +++++----- metagpt/provider/bedrock/bedrock_provider.py | 28 ++++++++++++------- metagpt/provider/bedrock/utils.py | 25 +++++++++++++++++ 4 files changed, 60 insertions(+), 23 deletions(-) create mode 100644 metagpt/provider/bedrock/utils.py diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index d8aaed8e9..123495da5 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -51,6 +51,7 @@ class AmazonBedrockLLM(BaseLLM): modelId=self.config.model, body=request_body ) completions = self.provider.get_choice_text(response) + log_llm_stream(completions) return completions def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): @@ -59,11 +60,10 @@ class AmazonBedrockLLM(BaseLLM): response = self.__client.invoke_model_with_response_stream( modelId=self.config.model, body=request_body ) - collected_content = [] - for event in response.get("body"): - chunk_text = json.loads(event["chunk"]["bytes"])[ - "outputs"][0]["text"] + collected_content = [] + for event in response["body"]: + chunk_text = self.provider.get_choice_text_from_stream(event) collected_content.append(chunk_text) log_llm_stream(chunk_text) @@ -84,7 +84,11 @@ class AmazonBedrockLLM(BaseLLM): if __name__ == '__main__': from .config import my_config - prompt = "write an essay 
for living on mars in 1000 word" - messages = [{"role": "user", "content": prompt}] + messages = [ + {"role": "system", "content": "your name is Bob"}, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hello,my friend"}, + {"role": "user", "content": "What is your name?"}] llm = AmazonBedrockLLM(my_config) + llm.completion(messages) llm._chat_completion_stream(messages) diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 9a6f9659c..3ecd5789a 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -7,16 +7,16 @@ class BaseBedrockProvider(object): return json.dumps({"prompt": self.messages_to_prompt(messages)} | generate_kwargs) def get_choice_text(self, response) -> str: - response_body = json.loads(response["body"].read()) + response_body = self._get_response_body_json(response) completions = response_body["outputs"][0]['text'] return completions + def get_choice_text_from_stream(self, event): + return json.loads(event["chunk"]["bytes"])["outputs"][0]["text"] + + def _get_response_body_json(self, response): + return json.loads(response["body"].read()) + def messages_to_prompt(self, messages: list[dict]): """[{"role": "user", "content": msg}] to user: etc.""" return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) - - def format_prompt(self, prompt: str) -> str: - return prompt - - def format_messages(self, messages: list[dict]) -> list[dict]: - return messages diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 3ae84d8c3..a50f9abed 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -1,18 +1,16 @@ -from metagpt.provider.bedrock.base_provider import BaseBedrockProvider import json +from metagpt.provider.bedrock.base_provider import BaseBedrockProvider +from metagpt.provider.bedrock.utils import messages_to_prompt_llama class MistralProvider(BaseBedrockProvider): - - def format_prompt(self, prompt: str) -> str: - # for mixtral and llama - return f"[INST]{prompt}[/INST]" - - def get_request_body(self, messages, **generate_kwargs): - return json.dumps({"prompt": self.format_prompt(self.messages_to_prompt(messages))} | generate_kwargs) + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html + def messages_to_prompt(self, messages: list[dict]): + return messages_to_prompt_llama(messages) class AnthropicProvider(BaseBedrockProvider): + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html pass @@ -21,7 +19,16 @@ class CohereProvider(BaseBedrockProvider): class MetaProvider(BaseBedrockProvider): - pass + def messages_to_prompt(self, messages: list[dict]): + return messages_to_prompt_llama(messages) + + def get_choice_text(self, response) -> str: + response_body = self._get_response_body_json(response) + completions = response_body['generation'] + return completions + + def get_choice_text_from_stream(self, event): + return json.loads(event["chunk"]["bytes"])["generation"] class Ai21Provider(BaseBedrockProvider): @@ -29,7 +36,8 @@ class Ai21Provider(BaseBedrockProvider): PROVIDERS = { - "mistral": MistralProvider() + "mistral": MistralProvider(), + "meta": MetaProvider(), } NOT_SUUPORT_STREAM_MODELS = { diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py new file mode 100644 index 000000000..7352a15a0 --- /dev/null +++ 
b/metagpt/provider/bedrock/utils.py
@@ -0,0 +1,25 @@
+from metagpt.logs import logger
+
+
+def messages_to_prompt_llama(messages: list[dict]):
+    BOS, EOS = "<s>", "</s>"
+    B_INST, E_INST = "[INST]", "[/INST]"
+    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+    prompt = f"{BOS}"
+    for message in messages:
+        role = message["role"]
+        content = message["content"]
+        if role == "system":
+            prompt += f"{B_SYS} {content} {E_SYS}"
+        elif role == "user":
+            prompt += f"{B_INST} {content} {E_INST}"
+        elif role == "assistant":
+            prompt += f"{content}"
+        else:
+            logger.warning(
+                f"Unknown role name {role} when formatting messages")
+            prompt += f"{content}"
+
+    return prompt
+

From 187e9ef698cc1820451aed917e19a57ddb61b7e0 Mon Sep 17 00:00:00 2001
From: usamimeri_renko <1710269958@qq.com>
Date: Thu, 25 Apr 2024 20:24:39 +0800
Subject: [PATCH 07/32] support anthropic

---
 metagpt/provider/bedrock/base_provider.py    | 10 +++++++---
 metagpt/provider/bedrock/bedrock_provider.py | 17 +++++++++++++++--
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py
index 3ecd5789a..2a17c335b 100644
--- a/metagpt/provider/bedrock/base_provider.py
+++ b/metagpt/provider/bedrock/base_provider.py
@@ -4,7 +4,9 @@
 class BaseBedrockProvider(object):
     # to handle different generation kwargs
     def get_request_body(self, messages, **generate_kwargs):
-        return json.dumps({"prompt": self.messages_to_prompt(messages)} | generate_kwargs)
+        body = json.dumps(
+            {"prompt": self.messages_to_prompt(messages), **generate_kwargs})
+        return body
 
     def get_choice_text(self, response) -> str:
         response_body = self._get_response_body_json(response)
@@ -12,10 +14,12 @@ class BaseBedrockProvider(object):
         return completions
 
     def get_choice_text_from_stream(self, event):
-        return json.loads(event["chunk"]["bytes"])["outputs"][0]["text"]
+        completions = json.loads(event["chunk"]["bytes"])["outputs"][0]["text"]
+        return completions
 
     def _get_response_body_json(self, response):
-        return json.loads(response["body"].read())
+        response_body = json.loads(response["body"].read())
+        return response_body
 
     def messages_to_prompt(self, messages: list[dict]):
         """[{"role": "user", "content": msg}] to user: etc."""
diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py
index a50f9abed..bbf3de223 100644
--- a/metagpt/provider/bedrock/bedrock_provider.py
+++ b/metagpt/provider/bedrock/bedrock_provider.py
@@ -11,7 +11,19 @@ class MistralProvider(BaseBedrockProvider):
     # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
-    pass
+    def get_request_body(self, messages, **generate_kwargs):
+        body = json.dumps(
+            {"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
+        return body
+
+    def get_choice_text(self, response) -> str:
+        response_body = self._get_response_body_json(response)
+        completions = response_body["content"][0]['text']
+        return completions
+
+    def get_choice_text_from_stream(self, event):
+        completions = json.loads(event["chunk"]["bytes"])["content"][0]["text"]
+        return completions
 
 
 class 
Ai21Provider(BaseBedrockProvider): From 6355c5c0ed2db2ad612dbb0e33dc1136a8948567 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 20:55:59 +0800 Subject: [PATCH 08/32] implement all model --- .../provider/bedrock/amazon_bedrock_api.py | 20 +++-- metagpt/provider/bedrock/base_provider.py | 12 ++- metagpt/provider/bedrock/bedrock_provider.py | 84 +++++-------------- metagpt/provider/bedrock/utils.py | 44 ++++++++++ 4 files changed, 87 insertions(+), 73 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 123495da5..b6d12b8d9 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -1,15 +1,14 @@ import json -from typing import Coroutine, Literal +from typing import Literal from metagpt.const import USE_CONFIG_TIMEOUT from metagpt.provider.llm_provider_registry import register_provider from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.provider.base_llm import BaseLLM from metagpt.logs import log_llm_stream, logger -from botocore.config import Config -import boto3 - from metagpt.provider.bedrock.bedrock_provider import get_provider +from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS +import boto3 @register_provider([LLMType.AMAZON_BEDROCK]) @@ -40,8 +39,9 @@ class AmazonBedrockLLM(BaseLLM): @property def _generate_kwargs(self): + # for now only use temperature due to the difference of request body return { - "temperature": self.config.get("temperature", 0.3), + "temperature": self.config.get("temperature", 0.1), } def completion(self, messages: list[dict]): @@ -51,10 +51,14 @@ class AmazonBedrockLLM(BaseLLM): modelId=self.config.model, body=request_body ) completions = self.provider.get_choice_text(response) - log_llm_stream(completions) return completions def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + if self.config.model in NOT_SUUPORT_STREAM_MODELS: + logger.warning( + f"model {self.config.model} doesn't support streaming output!") + return self.completion(messages) + request_body = self.provider.get_request_body( messages, **self._generate_kwargs) response = self.__client.invoke_model_with_response_stream( @@ -90,5 +94,5 @@ if __name__ == '__main__': {"role": "assistant", "content": "hello,my friend"}, {"role": "user", "content": "What is your name?"}] llm = AmazonBedrockLLM(my_config) - llm.completion(messages) - llm._chat_completion_stream(messages) + print(llm.completion(messages)) + print(llm._chat_completion_stream(messages)) diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 2a17c335b..4a96192d9 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -1,7 +1,8 @@ import json +from abc import ABC, abstractmethod -class BaseBedrockProvider(object): +class BaseBedrockProvider(ABC): # to handle different generation kwargs def get_request_body(self, messages, **generate_kwargs): body = json.dumps( @@ -10,17 +11,22 @@ class BaseBedrockProvider(object): def get_choice_text(self, response) -> str: response_body = self._get_response_body_json(response) - completions = response_body["outputs"][0]['text'] + completions = self._get_completion_from_dict(response_body) return completions def get_choice_text_from_stream(self, event): - completions = json.loads(event["chunk"]["bytes"])["outputs"][0]["text"] + rsp_dict = 
json.loads(event["chunk"]["bytes"]) + completions = self._get_completion_from_dict(rsp_dict) return completions def _get_response_body_json(self, response): response_body = json.loads(response["body"].read()) return response_body + @abstractmethod + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + ... + def messages_to_prompt(self, messages: list[dict]): """[{"role": "user", "content": msg}] to user: etc.""" return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index bbf3de223..d0fe42725 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -5,96 +5,56 @@ from metagpt.provider.bedrock.utils import messages_to_prompt_llama class MistralProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html + def messages_to_prompt(self, messages: list[dict]): return messages_to_prompt_llama(messages) + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict["outputs"][0]["text"] + class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + def get_request_body(self, messages, **generate_kwargs): body = json.dumps( {"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) return body - def get_choice_text(self, response) -> str: - response_body = self._get_response_body_json(response) - completions = response_body["content"][0]['text'] - return completions - - def get_choice_text_from_stream(self, event): - completions = json.loads(event["chunk"]["bytes"])["content"][0]["text"] - return completions + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict["content"][0]["text"] class CohereProvider(BaseBedrockProvider): - pass + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html + + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict["generations"][0]["text"] class MetaProvider(BaseBedrockProvider): + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html + def messages_to_prompt(self, messages: list[dict]): return messages_to_prompt_llama(messages) - def get_choice_text(self, response) -> str: - response_body = self._get_response_body_json(response) - completions = response_body['generation'] - return completions - - def get_choice_text_from_stream(self, event): - completions = json.loads(event["chunk"]["bytes"])["generation"] - return completions + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict["generation"] class Ai21Provider(BaseBedrockProvider): - pass + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html + + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict['completions'][0]["data"]["text"] PROVIDERS = { "mistral": MistralProvider(), "meta": MetaProvider(), -} - -NOT_SUUPORT_STREAM_MODELS = { - "ai21.j2-grande-instruct", - "ai21.j2-jumbo-instruct", - "ai21.j2-mid", - "ai21.j2-mid-v1", - "ai21.j2-ultra", - "ai21.j2-ultra-v1", -} - -SUPPORT_STREAM_MODELS = { - "amazon.titan-tg1-large", - "amazon.titan-text-lite-v1:0:4k", - "amazon.titan-text-lite-v1", - "amazon.titan-text-express-v1:0:8k", - "amazon.titan-text-express-v1", - "anthropic.claude-instant-v1:2:100k", - "anthropic.claude-instant-v1", 
- "anthropic.claude-v2:0:18k", - "anthropic.claude-v2:0:100k", - "anthropic.claude-v2:1:18k", - "anthropic.claude-v2:1:200k", - "anthropic.claude-v2:1", - "anthropic.claude-v2:2:18k", - "anthropic.claude-v2:2:200k", - "anthropic.claude-v2:2", - "anthropic.claude-v2", - "anthropic.claude-3-sonnet-20240229-v1:0:28k", - "anthropic.claude-3-sonnet-20240229-v1:0:200k", - "anthropic.claude-3-sonnet-20240229-v1:0", - "anthropic.claude-3-haiku-20240307-v1:0:48k", - "anthropic.claude-3-haiku-20240307-v1:0:200k", - "anthropic.claude-3-haiku-20240307-v1:0", - "cohere.command-text-v14:7:4k", - "cohere.command-text-v14", - "cohere.command-light-text-v14:7:4k", - "cohere.command-light-text-v14", - "meta.llama2-70b-v1", - "meta.llama3-8b-instruct-v1:0", - "meta.llama3-70b-instruct-v1:0", - "mistral.mistral-7b-instruct-v0:2", - "mistral.mixtral-8x7b-instruct-v0:1", - "mistral.mistral-large-2402-v1:0", + "ai21": Ai21Provider(), + "cohere": CohereProvider(), + "anthropic": AnthropicProvider(), } diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 7352a15a0..57b83681c 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -23,3 +23,47 @@ def messages_to_prompt_llama(messages: list[dict]): return prompt + +NOT_SUUPORT_STREAM_MODELS = { + "ai21.j2-grande-instruct", + "ai21.j2-jumbo-instruct", + "ai21.j2-mid", + "ai21.j2-mid-v1", + "ai21.j2-ultra", + "ai21.j2-ultra-v1", +} + +SUPPORT_STREAM_MODELS = { + "amazon.titan-tg1-large", + "amazon.titan-text-lite-v1:0:4k", + "amazon.titan-text-lite-v1", + "amazon.titan-text-express-v1:0:8k", + "amazon.titan-text-express-v1", + "anthropic.claude-instant-v1:2:100k", + "anthropic.claude-instant-v1", + "anthropic.claude-v2:0:18k", + "anthropic.claude-v2:0:100k", + "anthropic.claude-v2:1:18k", + "anthropic.claude-v2:1:200k", + "anthropic.claude-v2:1", + "anthropic.claude-v2:2:18k", + "anthropic.claude-v2:2:200k", + "anthropic.claude-v2:2", + "anthropic.claude-v2", + "anthropic.claude-3-sonnet-20240229-v1:0:28k", + "anthropic.claude-3-sonnet-20240229-v1:0:200k", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-3-haiku-20240307-v1:0:48k", + "anthropic.claude-3-haiku-20240307-v1:0:200k", + "anthropic.claude-3-haiku-20240307-v1:0", + "cohere.command-text-v14:7:4k", + "cohere.command-text-v14", + "cohere.command-light-text-v14:7:4k", + "cohere.command-light-text-v14", + "meta.llama2-70b-v1", + "meta.llama3-8b-instruct-v1:0", + "meta.llama3-70b-instruct-v1:0", + "mistral.mistral-7b-instruct-v0:2", + "mistral.mixtral-8x7b-instruct-v0:1", + "mistral.mistral-large-2402-v1:0", +} From a6058ca629ccefc5770ff8b91801879dbb4e38ae Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 20:58:50 +0800 Subject: [PATCH 09/32] change provider to private --- metagpt/provider/bedrock/amazon_bedrock_api.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index b6d12b8d9..7d615ec5e 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -16,7 +16,7 @@ class AmazonBedrockLLM(BaseLLM): def __init__(self, config: LLMConfig): self.config = config self.__client = self.__init_client("bedrock-runtime") - self.provider = get_provider(self.config.model) + self.__provider = get_provider(self.config.model) def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): # access key from 
https://us-east-1.console.aws.amazon.com/iam @@ -33,8 +33,8 @@ class AmazonBedrockLLM(BaseLLM): client = self.__init_client("bedrock") # only output text-generation models response = client.list_foundation_models(byOutputModality='TEXT') - summaries = [f'{summary.get("modelId", ""):50} Support Streaming:{summary.get("responseStreamingSupported","")}' - for summary in response.get("modelSummaries", {})] + summaries = [f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}' + for summary in response["modelSummaries"]] logger.info("\n"+"\n".join(summaries)) @property @@ -45,12 +45,12 @@ class AmazonBedrockLLM(BaseLLM): } def completion(self, messages: list[dict]): - request_body = self.provider.get_request_body( + request_body = self.__provider.get_request_body( messages, **self._generate_kwargs) response = self.__client.invoke_model( modelId=self.config.model, body=request_body ) - completions = self.provider.get_choice_text(response) + completions = self.__provider.get_choice_text(response) return completions def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): @@ -59,7 +59,7 @@ class AmazonBedrockLLM(BaseLLM): f"model {self.config.model} doesn't support streaming output!") return self.completion(messages) - request_body = self.provider.get_request_body( + request_body = self.__provider.get_request_body( messages, **self._generate_kwargs) response = self.__client.invoke_model_with_response_stream( modelId=self.config.model, body=request_body @@ -67,7 +67,7 @@ class AmazonBedrockLLM(BaseLLM): collected_content = [] for event in response["body"]: - chunk_text = self.provider.get_choice_text_from_stream(event) + chunk_text = self.__provider.get_choice_text_from_stream(event) collected_content.append(chunk_text) log_llm_stream(chunk_text) From f45a379183fe19eec7fcd2e1288882691bac2a01 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 21:16:22 +0800 Subject: [PATCH 10/32] add titan --- metagpt/provider/bedrock/amazon_bedrock_api.py | 12 +----------- metagpt/provider/bedrock/base_provider.py | 9 +++++---- metagpt/provider/bedrock/bedrock_provider.py | 18 ++++++++++++++++++ 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 7d615ec5e..3b2ea0b81 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -7,7 +7,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.provider.base_llm import BaseLLM from metagpt.logs import log_llm_stream, logger from metagpt.provider.bedrock.bedrock_provider import get_provider -from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS +from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, SUPPORT_STREAM_MODELS import boto3 @@ -86,13 +86,3 @@ class AmazonBedrockLLM(BaseLLM): return self._chat_completion_stream(messages) -if __name__ == '__main__': - from .config import my_config - messages = [ - {"role": "system", "content": "your name is Bob"}, - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "hello,my friend"}, - {"role": "user", "content": "What is your name?"}] - llm = AmazonBedrockLLM(my_config) - print(llm.completion(messages)) - print(llm._chat_completion_stream(messages)) diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 4a96192d9..c591549ce 100644 --- 
a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -4,6 +4,11 @@ from abc import ABC, abstractmethod class BaseBedrockProvider(ABC): # to handle different generation kwargs + + @abstractmethod + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + ... + def get_request_body(self, messages, **generate_kwargs): body = json.dumps( {"prompt": self.messages_to_prompt(messages), **generate_kwargs}) @@ -23,10 +28,6 @@ class BaseBedrockProvider(ABC): response_body = json.loads(response["body"].read()) return response_body - @abstractmethod - def _get_completion_from_dict(self, rsp_dict: dict) -> str: - ... - def messages_to_prompt(self, messages: list[dict]): """[{"role": "user", "content": msg}] to user: etc.""" return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index d0fe42725..e2dba9223 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -49,12 +49,30 @@ class Ai21Provider(BaseBedrockProvider): return rsp_dict['completions'][0]["data"]["text"] +class AmazonProvider(BaseBedrockProvider): + def get_request_body(self, messages, **generate_kwargs): + body = json.dumps({ + "inputText": self.messages_to_prompt(messages), + "textGenerationConfig": generate_kwargs + }) + return body + + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict['results'][0]['outputText'].strip() + + def get_choice_text_from_stream(self, event): + rsp_dict = json.loads(event["chunk"]["bytes"]) + completions = rsp_dict["outputText"] + return completions + + PROVIDERS = { "mistral": MistralProvider(), "meta": MetaProvider(), "ai21": Ai21Provider(), "cohere": CohereProvider(), "anthropic": AnthropicProvider(), + "amazon": AmazonProvider() } From aded1dc2ed3000ae5e0be39344935c01f17d8781 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 21:23:25 +0800 Subject: [PATCH 11/32] add some type hint --- metagpt/provider/bedrock/amazon_bedrock_api.py | 10 +++------- metagpt/provider/bedrock/base_provider.py | 6 +++--- metagpt/provider/bedrock/bedrock_provider.py | 8 +++++--- metagpt/provider/bedrock/utils.py | 1 - 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 3b2ea0b81..2a72de019 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -1,5 +1,3 @@ - -import json from typing import Literal from metagpt.const import USE_CONFIG_TIMEOUT from metagpt.provider.llm_provider_registry import register_provider @@ -38,13 +36,13 @@ class AmazonBedrockLLM(BaseLLM): logger.info("\n"+"\n".join(summaries)) @property - def _generate_kwargs(self): + def _generate_kwargs(self) -> dict: # for now only use temperature due to the difference of request body return { "temperature": self.config.get("temperature", 0.1), } - def completion(self, messages: list[dict]): + def completion(self, messages: list[dict]) -> str: request_body = self.__provider.get_request_body( messages, **self._generate_kwargs) response = self.__client.invoke_model( @@ -53,7 +51,7 @@ class AmazonBedrockLLM(BaseLLM): completions = self.__provider.get_choice_text(response) return completions - def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + def _chat_completion_stream(self, 
messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: if self.config.model in NOT_SUUPORT_STREAM_MODELS: logger.warning( f"model {self.config.model} doesn't support streaming output!") @@ -84,5 +82,3 @@ class AmazonBedrockLLM(BaseLLM): async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): return self._chat_completion_stream(messages) - - diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index c591549ce..724cbb669 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -9,7 +9,7 @@ class BaseBedrockProvider(ABC): def _get_completion_from_dict(self, rsp_dict: dict) -> str: ... - def get_request_body(self, messages, **generate_kwargs): + def get_request_body(self, messages: list[dict], **generate_kwargs): body = json.dumps( {"prompt": self.messages_to_prompt(messages), **generate_kwargs}) return body @@ -19,7 +19,7 @@ class BaseBedrockProvider(ABC): completions = self._get_completion_from_dict(response_body) return completions - def get_choice_text_from_stream(self, event): + def get_choice_text_from_stream(self, event) -> str: rsp_dict = json.loads(event["chunk"]["bytes"]) completions = self._get_completion_from_dict(rsp_dict) return completions @@ -28,6 +28,6 @@ class BaseBedrockProvider(ABC): response_body = json.loads(response["body"].read()) return response_body - def messages_to_prompt(self, messages: list[dict]): + def messages_to_prompt(self, messages: list[dict]) -> str: """[{"role": "user", "content": msg}] to user: etc.""" return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index e2dba9223..47348c083 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -16,7 +16,7 @@ class MistralProvider(BaseBedrockProvider): class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - def get_request_body(self, messages, **generate_kwargs): + def get_request_body(self, messages: list[dict], **generate_kwargs): body = json.dumps( {"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) return body @@ -50,7 +50,9 @@ class Ai21Provider(BaseBedrockProvider): class AmazonProvider(BaseBedrockProvider): - def get_request_body(self, messages, **generate_kwargs): + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html + + def get_request_body(self, messages: list[dict], **generate_kwargs): body = json.dumps({ "inputText": self.messages_to_prompt(messages), "textGenerationConfig": generate_kwargs @@ -60,7 +62,7 @@ class AmazonProvider(BaseBedrockProvider): def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict['results'][0]['outputText'].strip() - def get_choice_text_from_stream(self, event): + def get_choice_text_from_stream(self, event) -> str: rsp_dict = json.loads(event["chunk"]["bytes"]) completions = rsp_dict["outputText"] return completions diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 57b83681c..58236157d 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -1,6 +1,5 @@ from metagpt.logs import logger - def messages_to_prompt_llama(messages: list[dict]): BOS, EOS = "", "" B_INST, E_INST = "[INST]", "[/INST]" From 
6452adf82aa16b10ad19d76377672da486aa17ca Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 21:47:57 +0800 Subject: [PATCH 12/32] update provider package --- metagpt/configs/llm_config.py | 4 +--- metagpt/provider/__init__.py | 2 ++ metagpt/provider/bedrock/amazon_bedrock_api.py | 5 +++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 170005c21..ae8c57ec5 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -100,6 +100,4 @@ class LLMConfig(YamlModel): @classmethod def check_timeout(cls, v): return v or LLM_API_TIMEOUT - - def get(self, key: str, default = None): - return getattr(self, key, default) + diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 14d5e7682..dd5b4f89d 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -17,6 +17,7 @@ from metagpt.provider.spark_api import SparkLLM from metagpt.provider.qianfan_api import QianFanLLM from metagpt.provider.dashscope_api import DashScopeLLM from metagpt.provider.anthropic_api import AnthropicLLM +from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM __all__ = [ "GeminiLLM", @@ -30,4 +31,5 @@ __all__ = [ "QianFanLLM", "DashScopeLLM", "AnthropicLLM", + "AmazonBedrockLLM" ] diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 2a72de019..7262a11be 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -39,7 +39,7 @@ class AmazonBedrockLLM(BaseLLM): def _generate_kwargs(self) -> dict: # for now only use temperature due to the difference of request body return { - "temperature": self.config.get("temperature", 0.1), + "temperature": self.config.temperature } def completion(self, messages: list[dict]) -> str: @@ -74,11 +74,12 @@ class AmazonBedrockLLM(BaseLLM): return full_text async def acompletion(self, messages: list[dict]): + # Amazon bedrock doesn't support async now return self._achat_completion(messages) async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): - # TODO:make it async return self.completion(messages) async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): return self._chat_completion_stream(messages) + From 784c6265ee89f8e26271559009358bd1ffed6d25 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 23:41:21 +0800 Subject: [PATCH 13/32] update max tokens and support max_tokens field --- .../provider/bedrock/amazon_bedrock_api.py | 18 ++++- metagpt/provider/bedrock/base_provider.py | 1 + metagpt/provider/bedrock/bedrock_provider.py | 3 + metagpt/provider/bedrock/utils.py | 75 ++++++++----------- 4 files changed, 53 insertions(+), 44 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 7262a11be..e4382f7bd 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -5,7 +5,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.provider.base_llm import BaseLLM from metagpt.logs import log_llm_stream, logger from metagpt.provider.bedrock.bedrock_provider import get_provider -from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, SUPPORT_STREAM_MODELS +from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, 
get_max_tokens import boto3 @@ -15,6 +15,7 @@ class AmazonBedrockLLM(BaseLLM): self.config = config self.__client = self.__init_client("bedrock-runtime") self.__provider = get_provider(self.config.model) + logger.warning("Amazon bedrock doesn't support async now") def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): # access key from https://us-east-1.console.aws.amazon.com/iam @@ -38,7 +39,13 @@ class AmazonBedrockLLM(BaseLLM): @property def _generate_kwargs(self) -> dict: # for now only use temperature due to the difference of request body + model_max_tokens = get_max_tokens(self.config.model) + if self.config.max_token > model_max_tokens: + max_tokens = model_max_tokens + else: + max_tokens = self.config.max_token return { + self.__provider.max_tokens_field_name: max_tokens, "temperature": self.config.temperature } @@ -59,6 +66,7 @@ class AmazonBedrockLLM(BaseLLM): request_body = self.__provider.get_request_body( messages, **self._generate_kwargs) + response = self.__client.invoke_model_with_response_stream( modelId=self.config.model, body=request_body ) @@ -75,7 +83,13 @@ class AmazonBedrockLLM(BaseLLM): async def acompletion(self, messages: list[dict]): # Amazon bedrock doesn't support async now - return self._achat_completion(messages) + return await self._achat_completion(messages) + + async def acompletion_text(self, messages: list[dict], stream: bool = False, + timeout: int = USE_CONFIG_TIMEOUT) -> str: + if stream: + return await self._achat_completion_stream(messages) + return await self._achat_completion(messages) async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): return self.completion(messages) diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 724cbb669..c24556645 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod class BaseBedrockProvider(ABC): # to handle different generation kwargs + max_tokens_field_name = "max_tokens" @abstractmethod def _get_completion_from_dict(self, rsp_dict: dict) -> str: diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 47348c083..2aa90c7ee 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -34,6 +34,7 @@ class CohereProvider(BaseBedrockProvider): class MetaProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html + max_tokens_field_name = "max_gen_len" def messages_to_prompt(self, messages: list[dict]): return messages_to_prompt_llama(messages) @@ -44,6 +45,7 @@ class MetaProvider(BaseBedrockProvider): class Ai21Provider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html + max_tokens_field_name = "maxTokens" def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict['completions'][0]["data"]["text"] @@ -51,6 +53,7 @@ class Ai21Provider(BaseBedrockProvider): class AmazonProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html + max_tokens_field_name = "maxTokenCount" def get_request_body(self, messages: list[dict], **generate_kwargs): body = json.dumps({ diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 58236157d..2df0bf163 100644 --- a/metagpt/provider/bedrock/utils.py +++ 
b/metagpt/provider/bedrock/utils.py @@ -1,5 +1,35 @@ from metagpt.logs import logger +NOT_SUUPORT_STREAM_MODELS = { + "ai21.j2-grande-instruct": 8000, + "ai21.j2-jumbo-instruct": 8000, + "ai21.j2-mid": 8000, + "ai21.j2-mid-v1": 8000, + "ai21.j2-ultra": 8000, + "ai21.j2-ultra-v1": 8000, +} + +SUPPORT_STREAM_MODELS = { + "amazon.titan-tg1-large": 8000, + "amazon.titan-text-express-v1": 8000, + "anthropic.claude-instant-v1": 100000, + "anthropic.claude-v1": 100000, + "anthropic.claude-v2": 100000, + "anthropic.claude-v2:1": 200000, + "anthropic.claude-3-sonnet-20240229-v1:0": 200000, + "anthropic.claude-3-haiku-20240307-v1:0": 200000, + "anthropic.claude-3-opus-20240229-v1:0": 200000, + "cohere.command-text-v14": 4096, + "cohere.command-light-text-v14": 4096, + "meta.llama2-70b-v1": 4096, + "meta.llama3-8b-instruct-v1:0": 2000, + "meta.llama3-70b-instruct-v1:0": 2000, + "mistral.mistral-7b-instruct-v0:2": 32000, + "mistral.mixtral-8x7b-instruct-v0:1": 32000, + "mistral.mistral-large-2402-v1:0": 32000, +} + + def messages_to_prompt_llama(messages: list[dict]): BOS, EOS = "", "" B_INST, E_INST = "[INST]", "[/INST]" @@ -23,46 +53,7 @@ def messages_to_prompt_llama(messages: list[dict]): return prompt -NOT_SUUPORT_STREAM_MODELS = { - "ai21.j2-grande-instruct", - "ai21.j2-jumbo-instruct", - "ai21.j2-mid", - "ai21.j2-mid-v1", - "ai21.j2-ultra", - "ai21.j2-ultra-v1", -} +def get_max_tokens(model_id) -> int: + return (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] + -SUPPORT_STREAM_MODELS = { - "amazon.titan-tg1-large", - "amazon.titan-text-lite-v1:0:4k", - "amazon.titan-text-lite-v1", - "amazon.titan-text-express-v1:0:8k", - "amazon.titan-text-express-v1", - "anthropic.claude-instant-v1:2:100k", - "anthropic.claude-instant-v1", - "anthropic.claude-v2:0:18k", - "anthropic.claude-v2:0:100k", - "anthropic.claude-v2:1:18k", - "anthropic.claude-v2:1:200k", - "anthropic.claude-v2:1", - "anthropic.claude-v2:2:18k", - "anthropic.claude-v2:2:200k", - "anthropic.claude-v2:2", - "anthropic.claude-v2", - "anthropic.claude-3-sonnet-20240229-v1:0:28k", - "anthropic.claude-3-sonnet-20240229-v1:0:200k", - "anthropic.claude-3-sonnet-20240229-v1:0", - "anthropic.claude-3-haiku-20240307-v1:0:48k", - "anthropic.claude-3-haiku-20240307-v1:0:200k", - "anthropic.claude-3-haiku-20240307-v1:0", - "cohere.command-text-v14:7:4k", - "cohere.command-text-v14", - "cohere.command-light-text-v14:7:4k", - "cohere.command-light-text-v14", - "meta.llama2-70b-v1", - "meta.llama3-8b-instruct-v1:0", - "meta.llama3-70b-instruct-v1:0", - "mistral.mistral-7b-instruct-v0:2", - "mistral.mixtral-8x7b-instruct-v0:1", - "mistral.mistral-large-2402-v1:0", -} From 0cdca1b642175777a119c1ac8934ce29f85fc572 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Fri, 26 Apr 2024 01:05:04 +0800 Subject: [PATCH 14/32] fix llama3 chat template bug --- .../provider/bedrock/amazon_bedrock_api.py | 1 - metagpt/provider/bedrock/bedrock_provider.py | 36 ++++++++++++------- metagpt/provider/bedrock/utils.py | 21 +++++++++-- 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index e4382f7bd..a5cacec8c 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -96,4 +96,3 @@ class AmazonBedrockLLM(BaseLLM): async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): return self._chat_completion_stream(messages) - diff --git 
a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 2aa90c7ee..729697ad7 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -1,13 +1,14 @@ import json +from typing import Literal from metagpt.provider.bedrock.base_provider import BaseBedrockProvider -from metagpt.provider.bedrock.utils import messages_to_prompt_llama +from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3 class MistralProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html def messages_to_prompt(self, messages: list[dict]): - return messages_to_prompt_llama(messages) + return messages_to_prompt_llama2(messages) def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict["outputs"][0]["text"] @@ -36,8 +37,14 @@ class MetaProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html max_tokens_field_name = "max_gen_len" + def __init__(self, llama_version: Literal["llama2", "llama3"]) -> None: + self.llama_version = llama_version + def messages_to_prompt(self, messages: list[dict]): - return messages_to_prompt_llama(messages) + if self.llama_version == "llama2": + return messages_to_prompt_llama2(messages) + else: + return messages_to_prompt_llama3(messages) def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict["generation"] @@ -72,17 +79,20 @@ class AmazonProvider(BaseBedrockProvider): PROVIDERS = { - "mistral": MistralProvider(), - "meta": MetaProvider(), - "ai21": Ai21Provider(), - "cohere": CohereProvider(), - "anthropic": AnthropicProvider(), - "amazon": AmazonProvider() + "mistral": MistralProvider, + "meta": MetaProvider, + "ai21": Ai21Provider, + "cohere": CohereProvider, + "anthropic": AnthropicProvider, + "amazon": AmazonProvider } def get_provider(model_id: str): - model_name = model_id.split(".")[0] # meta、mistral…… - if model_name not in PROVIDERS: - raise KeyError(f"{model_name} is not supported!") - return PROVIDERS[model_name] + provider, model_name = model_id.split(".")[0:2] # meta、mistral…… + if provider not in PROVIDERS: + raise KeyError(f"{provider} is not supported!") + if provider == "meta": + # distinguish llama2 and llama3 + return PROVIDERS[provider](model_name[:6]) + return PROVIDERS[provider]() diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 2df0bf163..61778e8e8 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -29,8 +29,10 @@ SUPPORT_STREAM_MODELS = { "mistral.mistral-large-2402-v1:0": 32000, } +# TODO:use a general function for constructing chat templates. 
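One way the TODO above could eventually be resolved is a single data-driven builder instead of one function per model family; a minimal sketch, assuming hypothetical names (`PROMPT_STYLES`, `build_prompt`) that are not part of this patch:

```python
# Hypothetical sketch for the TODO above; names are illustrative only.
PROMPT_STYLES = {
    "llama3": {
        "bos": "<|begin_of_text|>",
        "turn": "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>",
        "generation_prefix": "<|start_header_id|>assistant<|end_header_id|>",
    },
    "claude": {
        "bos": "",
        "turn": "\n\n{role}: {content}",
        "generation_prefix": "\n\nAssistant:",
    },
}


def build_prompt(messages: list[dict], style: str) -> str:
    spec = PROMPT_STYLES[style]
    prompt = spec["bos"]
    for message in messages:
        prompt += spec["turn"].format(role=message["role"], content=message["content"])
    # Append the generation prefix when the last turn is not the assistant's,
    # mirroring the behaviour of the per-model builders.
    if messages and messages[-1]["role"] != "assistant":
        prompt += spec["generation_prefix"]
    return prompt
```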
-def messages_to_prompt_llama(messages: list[dict]): + +def messages_to_prompt_llama2(messages: list[dict]): BOS, EOS = "", "" B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "<>\n", "\n<>\n\n" @@ -53,7 +55,20 @@ def messages_to_prompt_llama(messages: list[dict]): return prompt +def messages_to_prompt_llama3(messages: list[dict]): + BOS, EOS = "<|begin_of_text|>", "<|eot_id|>" + GENERAL_TEMPLATE = "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>" + + prompt = f"{BOS}" + for message in messages: + role = message["role"] + content = message["content"] + prompt += GENERAL_TEMPLATE.format(role=role, content=content) + if role != "assistant": + prompt += f"<|start_header_id|>assistant<|end_header_id|>" + + return prompt + + def get_max_tokens(model_id) -> int: return (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] - - From e9723f4955ee1c5c8b3f322e0515dc1a7ab2cac5 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Fri, 26 Apr 2024 01:16:54 +0800 Subject: [PATCH 15/32] add claude chat template --- metagpt/provider/bedrock/bedrock_provider.py | 4 +++- metagpt/provider/bedrock/utils.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 729697ad7..850b57c4f 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -1,7 +1,7 @@ import json from typing import Literal from metagpt.provider.bedrock.base_provider import BaseBedrockProvider -from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3 +from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3, messages_to_prompt_claude class MistralProvider(BaseBedrockProvider): @@ -16,6 +16,8 @@ class MistralProvider(BaseBedrockProvider): class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + def messages_to_prompt(self, messages: list[dict]) -> str: + return messages_to_prompt_claude(messages) def get_request_body(self, messages: list[dict], **generate_kwargs): body = json.dumps( diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 61778e8e8..47a23caeb 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -70,5 +70,18 @@ def messages_to_prompt_llama3(messages: list[dict]): return prompt +def messages_to_prompt_claude(messages: list[dict]): + GENERAL_TEMPLATE = "\n\n{role}: {content}" + prompt = "" + for message in messages: + role = message["role"] + content = message["content"] + prompt += GENERAL_TEMPLATE.format(role=role, content=content) + if role != "assistant": + prompt += f"\n\nAssistant:" + return prompt + + def get_max_tokens(model_id) -> int: return (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] + From 6561c7aa7e03756b57ed809b8f97e44e32ffb0d9 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Fri, 26 Apr 2024 01:29:43 +0800 Subject: [PATCH 16/32] add docs --- .../provider/bedrock/amazon_bedrock_api.py | 27 ++++++++++++++++--- metagpt/provider/bedrock/utils.py | 7 +++-- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index a5cacec8c..640be3534 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ 
b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -11,14 +11,21 @@ import boto3 @register_provider([LLMType.AMAZON_BEDROCK]) class AmazonBedrockLLM(BaseLLM): + """ + check out: + https://docs.aws.amazon.com/code-library/latest/ug/python_3_bedrock-runtime_code_examples.html + """ + def __init__(self, config: LLMConfig): self.config = config self.__client = self.__init_client("bedrock-runtime") self.__provider = get_provider(self.config.model) - logger.warning("Amazon bedrock doesn't support async now") + logger.warning( + "Amazon bedrock doesn't support asynchronous calls now") def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): - # access key from https://us-east-1.console.aws.amazon.com/iam + """initialize boto3 client""" + # access key and secret key from https://us-east-1.console.aws.amazon.com/iam self.__credentital_kwards = { "aws_secret_access_key": self.config.secret_key, "aws_access_key_id": self.config.access_key, @@ -29,6 +36,14 @@ class AmazonBedrockLLM(BaseLLM): return client def list_models(self): + """list all available text-generation models + + ```shell + ai21.j2-ultra-v1 Support Streaming:False + meta.llama3-70b-instruct-v1:0 Support Streaming:True + …… + ``` + """ client = self.__init_client("bedrock") # only output text-generation models response = client.list_foundation_models(byOutputModality='TEXT') @@ -38,12 +53,12 @@ class AmazonBedrockLLM(BaseLLM): @property def _generate_kwargs(self) -> dict: - # for now only use temperature due to the difference of request body model_max_tokens = get_max_tokens(self.config.model) if self.config.max_token > model_max_tokens: max_tokens = model_max_tokens else: max_tokens = self.config.max_token + return { self.__provider.max_tokens_field_name: max_tokens, "temperature": self.config.temperature @@ -81,8 +96,12 @@ class AmazonBedrockLLM(BaseLLM): full_text = ("".join(collected_content)).lstrip() return full_text + # boto3 don't support support asynchronous calls. + # for asynchronous version of boto3,check out: + # https://aioboto3.readthedocs.io/en/latest/usage.html + # However,aioboto3 doesn't support invoke model + async def acompletion(self, messages: list[dict]): - # Amazon bedrock doesn't support async now return await self._achat_completion(messages) async def acompletion_text(self, messages: list[dict], stream: bool = False, diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 47a23caeb..80b7b82bd 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -1,5 +1,6 @@ from metagpt.logs import logger +# max_tokens for each model NOT_SUUPORT_STREAM_MODELS = { "ai21.j2-grande-instruct": 8000, "ai21.j2-jumbo-instruct": 8000, @@ -29,7 +30,7 @@ SUPPORT_STREAM_MODELS = { "mistral.mistral-large-2402-v1:0": 32000, } -# TODO:use a general function for constructing chat templates. +# TODO:use a more general function for constructing chat templates. 
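As a concrete reference for the builders below, the llama3 template turns a one-turn chat into the following prompt (illustrative usage only):

```python
from metagpt.provider.bedrock.utils import messages_to_prompt_llama3

prompt = messages_to_prompt_llama3([{"role": "user", "content": "Hi!"}])
# prompt == (
#     "<|begin_of_text|>"
#     "<|start_header_id|>user<|end_header_id|>\n\nHi!<|eot_id|>"
#     "<|start_header_id|>assistant<|end_header_id|>"
# )
```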
def messages_to_prompt_llama2(messages: list[dict]): @@ -64,6 +65,7 @@ def messages_to_prompt_llama3(messages: list[dict]): role = message["role"] content = message["content"] prompt += GENERAL_TEMPLATE.format(role=role, content=content) + if role != "assistant": prompt += f"<|start_header_id|>assistant<|end_header_id|>" @@ -77,11 +79,12 @@ def messages_to_prompt_claude(messages: list[dict]): role = message["role"] content = message["content"] prompt += GENERAL_TEMPLATE.format(role=role, content=content) + if role != "assistant": prompt += f"\n\nAssistant:" + return prompt def get_max_tokens(model_id) -> int: return (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] - From cafe666bfd74c21f9df14db097904a91520b14d0 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Fri, 26 Apr 2024 02:02:38 +0800 Subject: [PATCH 17/32] lazy installation --- metagpt/provider/bedrock/amazon_bedrock_api.py | 8 ++++++-- metagpt/provider/bedrock/bedrock_provider.py | 4 ++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 640be3534..184e934fa 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -6,7 +6,11 @@ from metagpt.provider.base_llm import BaseLLM from metagpt.logs import log_llm_stream, logger from metagpt.provider.bedrock.bedrock_provider import get_provider from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens -import boto3 +try: + import boto3 +except ImportError: + raise ImportError( + "boto3 not found! please install it by `pip install boto3` first ") @register_provider([LLMType.AMAZON_BEDROCK]) @@ -97,7 +101,7 @@ class AmazonBedrockLLM(BaseLLM): return full_text # boto3 don't support support asynchronous calls. 
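Until boto3 (or aioboto3) covers model invocation asynchronously, the blocking call could be pushed onto a worker thread; a minimal sketch, assuming an `AmazonBedrockLLM` instance and a hypothetical helper name that is not part of this patch:

```python
import asyncio


async def acompletion_in_thread(llm, messages: list[dict]) -> str:
    # Hypothetical workaround, not part of this patch: run the blocking,
    # boto3-backed completion in a worker thread so the event loop stays responsive.
    return await asyncio.to_thread(llm.completion, messages)
```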
- # for asynchronous version of boto3,check out: + # for asynchronous version of boto3, check out: # https://aioboto3.readthedocs.io/en/latest/usage.html # However,aioboto3 doesn't support invoke model diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 850b57c4f..375657f12 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -16,6 +16,7 @@ class MistralProvider(BaseBedrockProvider): class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + def messages_to_prompt(self, messages: list[dict]) -> str: return messages_to_prompt_claude(messages) @@ -37,6 +38,7 @@ class CohereProvider(BaseBedrockProvider): class MetaProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html + max_tokens_field_name = "max_gen_len" def __init__(self, llama_version: Literal["llama2", "llama3"]) -> None: @@ -54,6 +56,7 @@ class MetaProvider(BaseBedrockProvider): class Ai21Provider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html + max_tokens_field_name = "maxTokens" def _get_completion_from_dict(self, rsp_dict: dict) -> str: @@ -62,6 +65,7 @@ class Ai21Provider(BaseBedrockProvider): class AmazonProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html + max_tokens_field_name = "maxTokenCount" def get_request_body(self, messages: list[dict], **generate_kwargs): From 4c394a1cac32fca2a526425a6600f3355e209aba Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Fri, 26 Apr 2024 16:08:39 +0800 Subject: [PATCH 18/32] add test --- .../provider/bedrock/amazon_bedrock_api.py | 39 ++++++-- metagpt/provider/bedrock/base_provider.py | 7 +- metagpt/provider/bedrock/bedrock_provider.py | 5 +- metagpt/provider/bedrock/utils.py | 6 +- tests/metagpt/provider/mock_llm_config.py | 9 ++ tests/metagpt/provider/req_resp_const.py | 70 ++++++++++++-- .../provider/test_amazon_bedrock_api.py | 91 +++++++++++++++++++ 7 files changed, 198 insertions(+), 29 deletions(-) create mode 100644 tests/metagpt/provider/test_amazon_bedrock_api.py diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 184e934fa..07687c682 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -1,4 +1,5 @@ from typing import Literal +import json from metagpt.const import USE_CONFIG_TIMEOUT from metagpt.provider.llm_provider_registry import register_provider from metagpt.configs.llm_config import LLMConfig, LLMType @@ -8,9 +9,10 @@ from metagpt.provider.bedrock.bedrock_provider import get_provider from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens try: import boto3 + from botocore.response import StreamingBody except ImportError: raise ImportError( - "boto3 not found! please install it by `pip install boto3` first ") + "boto3 not found! 
please install it by `pip install boto3` ") @register_provider([LLMType.AMAZON_BEDROCK]) @@ -25,7 +27,7 @@ class AmazonBedrockLLM(BaseLLM): self.__client = self.__init_client("bedrock-runtime") self.__provider = get_provider(self.config.model) logger.warning( - "Amazon bedrock doesn't support asynchronous calls now") + "Amazon bedrock doesn't support asynchronous now") def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): """initialize boto3 client""" @@ -39,6 +41,12 @@ class AmazonBedrockLLM(BaseLLM): client = session.client(service_name) return client + def _get_client(self): + return self.__client + + def _get_provider(self): + return self.__provider + def list_models(self): """list all available text-generation models @@ -55,6 +63,19 @@ class AmazonBedrockLLM(BaseLLM): for summary in response["modelSummaries"]] logger.info("\n"+"\n".join(summaries)) + def invoke_model(self, request_body) -> dict: + response = self.__client.invoke_model( + modelId=self.config.model, body=request_body + ) + response_body = self._get_response_body(response) + return response_body + + def invoke_model_with_response_stream(self, request_body) -> StreamingBody: + response = self.__client.invoke_model_with_response_stream( + modelId=self.config.model, body=request_body + ) + return response + @property def _generate_kwargs(self) -> dict: model_max_tokens = get_max_tokens(self.config.model) @@ -71,10 +92,8 @@ class AmazonBedrockLLM(BaseLLM): def completion(self, messages: list[dict]) -> str: request_body = self.__provider.get_request_body( messages, **self._generate_kwargs) - response = self.__client.invoke_model( - modelId=self.config.model, body=request_body - ) - completions = self.__provider.get_choice_text(response) + response_body = self.invoke_model(request_body) + completions = self.__provider.get_choice_text(response_body) return completions def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: @@ -86,9 +105,7 @@ class AmazonBedrockLLM(BaseLLM): request_body = self.__provider.get_request_body( messages, **self._generate_kwargs) - response = self.__client.invoke_model_with_response_stream( - modelId=self.config.model, body=request_body - ) + response = self.invoke_model_with_response_stream(request_body) collected_content = [] for event in response["body"]: @@ -119,3 +136,7 @@ class AmazonBedrockLLM(BaseLLM): async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): return self._chat_completion_stream(messages) + + def _get_response_body(self, response) -> dict: + response_body = json.loads(response["body"].read()) + return response_body diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index c24556645..449e0f5c8 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -15,8 +15,7 @@ class BaseBedrockProvider(ABC): {"prompt": self.messages_to_prompt(messages), **generate_kwargs}) return body - def get_choice_text(self, response) -> str: - response_body = self._get_response_body_json(response) + def get_choice_text(self, response_body: dict) -> str: completions = self._get_completion_from_dict(response_body) return completions @@ -25,10 +24,6 @@ class BaseBedrockProvider(ABC): completions = self._get_completion_from_dict(rsp_dict) return completions - def _get_response_body_json(self, response): - response_body = json.loads(response["body"].read()) - return response_body - def messages_to_prompt(self, messages: 
list[dict]) -> str: """[{"role": "user", "content": msg}] to user: etc.""" return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 375657f12..e01184705 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -1,7 +1,7 @@ import json from typing import Literal from metagpt.provider.bedrock.base_provider import BaseBedrockProvider -from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3, messages_to_prompt_claude +from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3 class MistralProvider(BaseBedrockProvider): @@ -17,9 +17,6 @@ class MistralProvider(BaseBedrockProvider): class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - def messages_to_prompt(self, messages: list[dict]) -> str: - return messages_to_prompt_claude(messages) - def get_request_body(self, messages: list[dict], **generate_kwargs): body = json.dumps( {"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 80b7b82bd..f69d04a7b 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -33,7 +33,7 @@ SUPPORT_STREAM_MODELS = { # TODO:use a more general function for constructing chat templates. -def messages_to_prompt_llama2(messages: list[dict]): +def messages_to_prompt_llama2(messages: list[dict]) -> str: BOS, EOS = "", "" B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "<>\n", "\n<>\n\n" @@ -56,7 +56,7 @@ def messages_to_prompt_llama2(messages: list[dict]): return prompt -def messages_to_prompt_llama3(messages: list[dict]): +def messages_to_prompt_llama3(messages: list[dict]) -> str: BOS, EOS = "<|begin_of_text|>", "<|eot_id|>" GENERAL_TEMPLATE = "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>" @@ -72,7 +72,7 @@ def messages_to_prompt_llama3(messages: list[dict]): return prompt -def messages_to_prompt_claude(messages: list[dict]): +def messages_to_prompt_claude2(messages: list[dict]) -> str: GENERAL_TEMPLATE = "\n\n{role}: {content}" prompt = "" for message in messages: diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index 0c56cc8ea..8660bc24f 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ b/tests/metagpt/provider/mock_llm_config.py @@ -60,3 +60,12 @@ mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model mock_llm_config_anthropic = LLMConfig( api_type="anthropic", api_key="xxx", base_url="https://api.anthropic.com", model="claude-3-opus-20240229" ) + +mock_llm_config_bedrock = LLMConfig( + api_type="amazon_bedrock", + model="gpt-100", + region_name="somewhere", + access_key="123abc", + secret_key="123abc", + max_token=10000, +) diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py index 7e4c1a49c..6a244cbe4 100644 --- a/tests/metagpt/provider/req_resp_const.py +++ b/tests/metagpt/provider/req_resp_const.py @@ -65,17 +65,20 @@ def get_openai_chat_completion(name: str) -> ChatCompletion: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)), + message=ChatCompletionMessage( + role="assistant", 
content=resp_cont_tmpl.format(name=name)), logprobs=None, ) ], - usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), + usage=CompletionUsage(completion_tokens=110, + prompt_tokens=92, total_tokens=202), ) return openai_chat_completion def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk: - usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202) + usage = CompletionUsage(completion_tokens=110, + prompt_tokens=92, total_tokens=202) usage = usage if not usage_as_dict else usage.model_dump() openai_chat_completion_chunk = ChatCompletionChunk( id="cmpl-a6652c1bb181caae8dd19ad8", @@ -84,7 +87,8 @@ def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> created=1703300855, choices=[ AChoice( - delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)), + delta=ChoiceDelta(role="assistant", + content=resp_cont_tmpl.format(name=name)), finish_reason="stop", index=0, logprobs=None, @@ -133,7 +137,8 @@ def get_dashscope_response(name: str) -> GenerationResponse: ], } ), - usage=GenerationUsage(**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}), + usage=GenerationUsage( + **{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}), ) ) @@ -155,7 +160,8 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message: ), ContentBlockDeltaEvent( index=0, - delta=TextDelta(text=resp_cont_tmpl.format(name=name), type="text_delta"), + delta=TextDelta(text=resp_cont_tmpl.format( + name=name), type="text_delta"), type="content_block_delta", ), ] @@ -165,7 +171,8 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message: model=name, role="assistant", type="message", - content=[ContentBlock(text=resp_cont_tmpl.format(name=name), type="text")], + content=[ContentBlock( + text=resp_cont_tmpl.format(name=name), type="text")], usage=AnthropicUsage(input_tokens=10, output_tokens=10), ) @@ -183,3 +190,52 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[ resp = await llm.acompletion_text(messages, stream=True) assert resp == resp_cont + + +# For Amazon Bedrock +BEDROCK_PROVIDER_REQUEST_BODY = { + "mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0}, + "meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0}, + "ai21": { + "prompt": "", "temperature": 0.0, "topP": 0.0, "maxTokens": 0, + "stopSequences": [], "countPenalty": {"scale": 0.0}, + "presencePenalty": {"scale": 0.0}, "frequencyPenalty": {"scale": 0.0} + }, + "cohere": { + "prompt": "", "temperature": 0.0, "p": 0.0, "k": 0.0, "max_tokens": 0, "stop_sequences": [], + "return_likelihoods": "NONE", "stream": False, "num_generations": 0, "logit_bias": {}, "truncate": "NONE" + }, + "anthropic": { + "anthropic_version": "bedrock-2023-05-31", "max_tokens": 0, "system": "", + "messages": [{"role": "", "content": ""}], "temperature": 0.0, "top_p": 0.0, "top_k": 0, "stop_sequences": [] + }, + "amazon": { + "inputText": "", "textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []} + } +} + +BEDROCK_PROVIDER_RESPONSE_BODY = { + "mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]}, + + "meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""}, + + "ai21": { + "id": "", "prompt": {"text": "Hello World", "tokens": []}, + "completions": [{"data": {"text": "Hello World", "tokens": 
[]}, + "finishReason": {"reason": "length", "length": 2}}] + }, + "cohere": { + "generations": [ + {"finish_reason": "", "id": "", "text": "Hello World", "likelihood": 0.0, + "token_likelihoods": [{"token": 0.0}], "is_finished": True, "index": 0}], "id": "", "prompt": "" + }, + + "anthropic": { + "id": "", "model": "", "type": "message", "role": "assistant", "content": [{"type": "text", "text": "Hello World"}], + "stop_reason": "", "stop_sequence": "", "usage": {"input_tokens": 0, "output_tokens": 0} + }, + + "amazon": {'inputTextTokenCount': 0, 'results': [{'tokenCount': 0, 'outputText': 'Hello World', 'completionReason': ""}]} +} + +BEDROCK_PROVIDER_STREAM_RESPONSE = {} diff --git a/tests/metagpt/provider/test_amazon_bedrock_api.py b/tests/metagpt/provider/test_amazon_bedrock_api.py new file mode 100644 index 000000000..0b8206463 --- /dev/null +++ b/tests/metagpt/provider/test_amazon_bedrock_api.py @@ -0,0 +1,91 @@ +import pytest +import json +from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM +from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock +from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS +from tests.metagpt.provider.req_resp_const import ( + BEDROCK_PROVIDER_REQUEST_BODY, + BEDROCK_PROVIDER_RESPONSE_BODY, + BEDROCK_PROVIDER_STREAM_RESPONSE) +from botocore.response import StreamingBody + +# all available model from bedrock +models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS +messages = [{"role": "user", "content": "Hi!"}] + + +def mock_bedrock_provider_response(self, *args, **kwargs) -> dict: + provider = self.config.model.split(".")[0] + return BEDROCK_PROVIDER_RESPONSE_BODY[provider] + + +def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> StreamingBody: + provider = self.config.model.split(".")[0] + response_json = BEDROCK_PROVIDER_STREAM_RESPONSE[provider] + return + + +def get_bedrock_request_body(model_id) -> dict: + provider = model_id.split(".")[0] + return BEDROCK_PROVIDER_REQUEST_BODY[provider] + + +def is_subset(subset, superset): + """Ensure all fields in request body are allowed. 
+ + ```python + subset = {"prompt": "hello","kwargs": {"temperature": 0.9,"p": 0.0}} + superset = {"prompt": "hello", "kwargs": {"temperature": 0.0, "top-p": 0.0}} + is_subset(subset, superset) + ``` + >>>False + """ + for key, value in subset.items(): + if key not in superset: + return False + if isinstance(value, dict): + if not isinstance(superset[key], dict): + return False + if not is_subset(value, superset[key]): + return False + return True + + +@ pytest.fixture(scope="class", params=models) +def bedrock_api(request) -> AmazonBedrockLLM: + model_id = request.param + mock_llm_config_bedrock.model = model_id + api = AmazonBedrockLLM(mock_llm_config_bedrock) + return api + + +class TestAPI: + def test_generate_kwargs(self, bedrock_api: AmazonBedrockLLM): + provider = bedrock_api._get_provider() + assert bedrock_api._generate_kwargs[provider.max_tokens_field_name] <= get_max_tokens( + bedrock_api.config.model) + + def test_get_request_body(self, bedrock_api: AmazonBedrockLLM): + provider = bedrock_api._get_provider() + request_body = json.loads(provider.get_request_body( + messages, **bedrock_api._generate_kwargs)) + print(get_bedrock_request_body( + bedrock_api.config.model)) + print(request_body) + + assert is_subset(request_body, get_bedrock_request_body( + bedrock_api.config.model)) + + def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker): + mocker.patch( + "metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model", mock_bedrock_provider_response) + assert bedrock_api.completion(messages) == "Hello World" + + # def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker): + # mocker.patch( + # "metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", mock_bedrock_provider_response) + # assert bedrock_api._chat_completion_stream(messages) == "Hello World" + + +if __name__ == '__main__': + print(get_bedrock_request_body("amazon.titan-tg1-large")) From 8fafa2eb4ed6c0492f87b3caa7c4347001737158 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Fri, 26 Apr 2024 17:03:06 +0800 Subject: [PATCH 19/32] stream test --- .../provider/bedrock/amazon_bedrock_api.py | 5 +-- metagpt/provider/bedrock/bedrock_provider.py | 7 +--- tests/metagpt/provider/req_resp_const.py | 6 +-- .../provider/test_amazon_bedrock_api.py | 41 ++++++++++--------- 4 files changed, 27 insertions(+), 32 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 07687c682..3d1b08f47 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -9,7 +9,7 @@ from metagpt.provider.bedrock.bedrock_provider import get_provider from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens try: import boto3 - from botocore.response import StreamingBody + from botocore.eventstream import EventStream except ImportError: raise ImportError( "boto3 not found! 
please install it by `pip install boto3` ") @@ -70,7 +70,7 @@ class AmazonBedrockLLM(BaseLLM): response_body = self._get_response_body(response) return response_body - def invoke_model_with_response_stream(self, request_body) -> StreamingBody: + def invoke_model_with_response_stream(self, request_body) -> EventStream: response = self.__client.invoke_model_with_response_stream( modelId=self.config.model, body=request_body ) @@ -106,7 +106,6 @@ class AmazonBedrockLLM(BaseLLM): messages, **self._generate_kwargs) response = self.invoke_model_with_response_stream(request_body) - collected_content = [] for event in response["body"]: chunk_text = self.__provider.get_choice_text_from_stream(event) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index e01184705..a4b90a82f 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -73,12 +73,7 @@ class AmazonProvider(BaseBedrockProvider): return body def _get_completion_from_dict(self, rsp_dict: dict) -> str: - return rsp_dict['results'][0]['outputText'].strip() - - def get_choice_text_from_stream(self, event) -> str: - rsp_dict = json.loads(event["chunk"]["bytes"]) - completions = rsp_dict["outputText"] - return completions + return rsp_dict['results'][0]['outputText'] PROVIDERS = { diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py index 6a244cbe4..893c33704 100644 --- a/tests/metagpt/provider/req_resp_const.py +++ b/tests/metagpt/provider/req_resp_const.py @@ -193,6 +193,8 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[ # For Amazon Bedrock +# Check the API documentation of each model +# https://docs.aws.amazon.com/bedrock/latest/userguide BEDROCK_PROVIDER_REQUEST_BODY = { "mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0}, "meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0}, @@ -236,6 +238,4 @@ BEDROCK_PROVIDER_RESPONSE_BODY = { }, "amazon": {'inputTextTokenCount': 0, 'results': [{'tokenCount': 0, 'outputText': 'Hello World', 'completionReason': ""}]} -} - -BEDROCK_PROVIDER_STREAM_RESPONSE = {} +} \ No newline at end of file diff --git a/tests/metagpt/provider/test_amazon_bedrock_api.py b/tests/metagpt/provider/test_amazon_bedrock_api.py index 0b8206463..bf8e467e7 100644 --- a/tests/metagpt/provider/test_amazon_bedrock_api.py +++ b/tests/metagpt/provider/test_amazon_bedrock_api.py @@ -3,14 +3,10 @@ import json from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS -from tests.metagpt.provider.req_resp_const import ( - BEDROCK_PROVIDER_REQUEST_BODY, - BEDROCK_PROVIDER_RESPONSE_BODY, - BEDROCK_PROVIDER_STREAM_RESPONSE) -from botocore.response import StreamingBody +from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY # all available model from bedrock -models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS +models = (SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS) messages = [{"role": "user", "content": "Hi!"}] @@ -19,10 +15,15 @@ def mock_bedrock_provider_response(self, *args, **kwargs) -> dict: return BEDROCK_PROVIDER_RESPONSE_BODY[provider] -def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> StreamingBody: 
+def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict: + # use json object to mock EventStream provider = self.config.model.split(".")[0] - response_json = BEDROCK_PROVIDER_STREAM_RESPONSE[provider] - return + response_body_bytes = json.dumps( + BEDROCK_PROVIDER_RESPONSE_BODY[provider]).encode("utf-8") + # decoded bytes share the same format as non-stream response_body + response_body_stream = { + "body": [{'chunk': {'bytes': response_body_bytes}},]} + return response_body_stream def get_bedrock_request_body(model_id) -> dict: @@ -51,7 +52,7 @@ def is_subset(subset, superset): return True -@ pytest.fixture(scope="class", params=models) +@pytest.fixture(scope="class", params=models) def bedrock_api(request) -> AmazonBedrockLLM: model_id = request.param mock_llm_config_bedrock.model = model_id @@ -66,6 +67,7 @@ class TestAPI: bedrock_api.config.model) def test_get_request_body(self, bedrock_api: AmazonBedrockLLM): + """Ensure request body has correct format""" provider = bedrock_api._get_provider() request_body = json.loads(provider.get_request_body( messages, **bedrock_api._generate_kwargs)) @@ -77,15 +79,14 @@ class TestAPI: bedrock_api.config.model)) def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker): - mocker.patch( - "metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model", mock_bedrock_provider_response) + mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model", + mock_bedrock_provider_response) assert bedrock_api.completion(messages) == "Hello World" - # def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker): - # mocker.patch( - # "metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", mock_bedrock_provider_response) - # assert bedrock_api._chat_completion_stream(messages) == "Hello World" - - -if __name__ == '__main__': - print(get_bedrock_request_body("amazon.titan-tg1-large")) + def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker): + mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model", + mock_bedrock_provider_response) + mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", + mock_bedrock_provider_stream_response) + assert bedrock_api._chat_completion_stream( + messages) == "Hello World" From 83c8ccb6b9733f61b799f3104c7d504c3f51fe9e Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Fri, 26 Apr 2024 17:15:24 +0800 Subject: [PATCH 20/32] update --- metagpt/provider/bedrock/amazon_bedrock_api.py | 5 ----- tests/metagpt/provider/test_amazon_bedrock_api.py | 8 ++------ 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 3d1b08f47..31862e2d5 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -17,11 +17,6 @@ except ImportError: @register_provider([LLMType.AMAZON_BEDROCK]) class AmazonBedrockLLM(BaseLLM): - """ - check out: - https://docs.aws.amazon.com/code-library/latest/ug/python_3_bedrock-runtime_code_examples.html - """ - def __init__(self, config: LLMConfig): self.config = config self.__client = self.__init_client("bedrock-runtime") diff --git a/tests/metagpt/provider/test_amazon_bedrock_api.py b/tests/metagpt/provider/test_amazon_bedrock_api.py index bf8e467e7..46d44e81b 100644 --- a/tests/metagpt/provider/test_amazon_bedrock_api.py 
+++ b/tests/metagpt/provider/test_amazon_bedrock_api.py @@ -31,7 +31,7 @@ def get_bedrock_request_body(model_id) -> dict: return BEDROCK_PROVIDER_REQUEST_BODY[provider] -def is_subset(subset, superset): +def is_subset(subset, superset) -> bool: """Ensure all fields in request body are allowed. ```python @@ -71,9 +71,6 @@ class TestAPI: provider = bedrock_api._get_provider() request_body = json.loads(provider.get_request_body( messages, **bedrock_api._generate_kwargs)) - print(get_bedrock_request_body( - bedrock_api.config.model)) - print(request_body) assert is_subset(request_body, get_bedrock_request_body( bedrock_api.config.model)) @@ -88,5 +85,4 @@ class TestAPI: mock_bedrock_provider_response) mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", mock_bedrock_provider_stream_response) - assert bedrock_api._chat_completion_stream( - messages) == "Hello World" + assert bedrock_api._chat_completion_stream(messages) == "Hello World" From b0ed292fa751e89b458e3214bdcdbf9d00d6a730 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Fri, 26 Apr 2024 17:30:03 +0800 Subject: [PATCH 21/32] update test --- metagpt/provider/bedrock/bedrock_provider.py | 5 +++++ tests/metagpt/provider/test_amazon_bedrock_api.py | 15 ++++++++++----- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index a4b90a82f..29cdf38a9 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -74,6 +74,11 @@ class AmazonProvider(BaseBedrockProvider): def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict['results'][0]['outputText'] + + def get_choice_text_from_stream(self, event) -> str: + rsp_dict = json.loads(event["chunk"]["bytes"]) + completions = rsp_dict["outputText"] + return completions PROVIDERS = { diff --git a/tests/metagpt/provider/test_amazon_bedrock_api.py b/tests/metagpt/provider/test_amazon_bedrock_api.py index 46d44e81b..c10b74d0a 100644 --- a/tests/metagpt/provider/test_amazon_bedrock_api.py +++ b/tests/metagpt/provider/test_amazon_bedrock_api.py @@ -17,12 +17,17 @@ def mock_bedrock_provider_response(self, *args, **kwargs) -> dict: def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict: # use json object to mock EventStream + def dict2bytes(x): + return json.dumps(x).encode("utf-8") provider = self.config.model.split(".")[0] - response_body_bytes = json.dumps( - BEDROCK_PROVIDER_RESPONSE_BODY[provider]).encode("utf-8") - # decoded bytes share the same format as non-stream response_body - response_body_stream = { - "body": [{'chunk': {'bytes': response_body_bytes}},]} + response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider]) + # decoded bytes share the same format as non-stream response_body except for titan + if provider == "amazon": + response_body_stream = { + "body": [{'chunk': {'bytes': dict2bytes({"outputText": "Hello World"})}}]} + else: + response_body_stream = { + "body": [{'chunk': {'bytes': response_body_bytes}}]} return response_body_stream From 6776e5cc18ce0ef4fd887cef4d4c949edb52dfe4 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <93753250+usamimeri@users.noreply.github.com> Date: Fri, 26 Apr 2024 23:12:19 +0800 Subject: [PATCH 22/32] Update amazon_bedrock_api.py fix naming --- metagpt/provider/bedrock/amazon_bedrock_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 31862e2d5..aad58f884 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -27,12 +27,12 @@ class AmazonBedrockLLM(BaseLLM): def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): """initialize boto3 client""" # access key and secret key from https://us-east-1.console.aws.amazon.com/iam - self.__credentital_kwards = { + self.__credentital_kwargs = { "aws_secret_access_key": self.config.secret_key, "aws_access_key_id": self.config.access_key, "region_name": self.config.region_name } - session = boto3.Session(**self.__credentital_kwards) + session = boto3.Session(**self.__credentital_kwargs) client = session.client(service_name) return client From a05d25757cd0ca2e924c44fa8daffb1160cb1412 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Sat, 27 Apr 2024 12:18:35 +0800 Subject: [PATCH 23/32] fix claude streaming bug --- metagpt/provider/bedrock/bedrock_provider.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 29cdf38a9..47d699313 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -25,6 +25,15 @@ class AnthropicProvider(BaseBedrockProvider): def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict["content"][0]["text"] + def get_choice_text_from_stream(self, event) -> str: + # https://docs.anthropic.com/claude/reference/messages-streaming + rsp_dict = json.loads(event["chunk"]["bytes"]) + if rsp_dict["type"] == "content_block_delta": + completions = rsp_dict["delta"]["text"] + return completions + else: + return "" + class CohereProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html @@ -74,7 +83,7 @@ class AmazonProvider(BaseBedrockProvider): def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict['results'][0]['outputText'] - + def get_choice_text_from_stream(self, event) -> str: rsp_dict = json.loads(event["chunk"]["bytes"]) completions = rsp_dict["outputText"] From 98cb45291112f2f83d8495c2c1e7746d6c6bef5b Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Sat, 27 Apr 2024 13:09:04 +0800 Subject: [PATCH 24/32] remove opus since unavailable now and fix bug --- .../provider/bedrock/amazon_bedrock_api.py | 4 ++-- metagpt/provider/bedrock/base_provider.py | 2 +- metagpt/provider/bedrock/bedrock_provider.py | 14 +++++++++++-- metagpt/provider/bedrock/utils.py | 17 +++++++++++---- .../provider/test_amazon_bedrock_api.py | 21 ++++++++++++------- 5 files changed, 42 insertions(+), 16 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index aad58f884..6230cd3f2 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -86,7 +86,7 @@ class AmazonBedrockLLM(BaseLLM): def completion(self, messages: list[dict]) -> str: request_body = self.__provider.get_request_body( - messages, **self._generate_kwargs) + messages, self._generate_kwargs) response_body = self.invoke_model(request_body) completions = self.__provider.get_choice_text(response_body) return completions @@ -98,7 +98,7 @@ class AmazonBedrockLLM(BaseLLM): return self.completion(messages) 
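The signature change in this patch passes `generate_kwargs` as a plain dict and forwards request-level flags such as `stream` through `**kwargs`; an illustrative call against the Cohere provider (parameter values hypothetical):

```python
from metagpt.provider.bedrock.bedrock_provider import CohereProvider

body = CohereProvider().get_request_body(
    [{"role": "user", "content": "Hi!"}],
    {"temperature": 0.7, "max_tokens": 256},
    stream=True,  # written into the request body as {"stream": true}
)
# body is a JSON string containing "prompt", "stream", and the generate kwargs.
```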
request_body = self.__provider.get_request_body( - messages, **self._generate_kwargs) + messages, self._generate_kwargs, stream=True) response = self.invoke_model_with_response_stream(request_body) collected_content = [] diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 449e0f5c8..5aa6ae605 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -10,7 +10,7 @@ class BaseBedrockProvider(ABC): def _get_completion_from_dict(self, rsp_dict: dict) -> str: ... - def get_request_body(self, messages: list[dict], **generate_kwargs): + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): body = json.dumps( {"prompt": self.messages_to_prompt(messages), **generate_kwargs}) return body diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 47d699313..01bcaac53 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -17,7 +17,7 @@ class MistralProvider(BaseBedrockProvider): class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - def get_request_body(self, messages: list[dict], **generate_kwargs): + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): body = json.dumps( {"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) return body @@ -41,6 +41,16 @@ class CohereProvider(BaseBedrockProvider): def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict["generations"][0]["text"] + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): + body = json.dumps( + {"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs}) + return body + + def get_choice_text_from_stream(self, event) -> str: + rsp_dict = json.loads(event["chunk"]["bytes"]) + completions = rsp_dict.get("text", "") + return completions + class MetaProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html @@ -74,7 +84,7 @@ class AmazonProvider(BaseBedrockProvider): max_tokens_field_name = "maxTokenCount" - def get_request_body(self, messages: list[dict], **generate_kwargs): + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): body = json.dumps({ "inputText": self.messages_to_prompt(messages), "textGenerationConfig": generate_kwargs diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index f69d04a7b..a87367b22 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -14,15 +14,24 @@ SUPPORT_STREAM_MODELS = { "amazon.titan-tg1-large": 8000, "amazon.titan-text-express-v1": 8000, "anthropic.claude-instant-v1": 100000, + "anthropic.claude-instant-v1:2:100k": 100000, "anthropic.claude-v1": 100000, "anthropic.claude-v2": 100000, "anthropic.claude-v2:1": 200000, "anthropic.claude-3-sonnet-20240229-v1:0": 200000, + "anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000, + "anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000, "anthropic.claude-3-haiku-20240307-v1:0": 200000, - "anthropic.claude-3-opus-20240229-v1:0": 200000, - "cohere.command-text-v14": 4096, - "cohere.command-light-text-v14": 4096, - "meta.llama2-70b-v1": 4096, + "anthropic.claude-3-haiku-20240307-v1:0:48k": 48000, + 
"anthropic.claude-3-haiku-20240307-v1:0:200k": 200000, + "cohere.command-text-v14": 4000, + "cohere.command-text-v14:7:4k": 4000, + "cohere.command-light-text-v14": 4000, + "cohere.command-light-text-v14:7:4k": 4000, + "meta.llama2-70b-v1": 4000, + "meta.llama2-13b-chat-v1:0:4k": 4000, + "meta.llama2-13b-chat-v1": 2000, + "meta.llama2-70b-v1:0:4k": 4000, "meta.llama3-8b-instruct-v1:0": 2000, "meta.llama3-70b-instruct-v1:0": 2000, "mistral.mistral-7b-instruct-v0:2": 32000, diff --git a/tests/metagpt/provider/test_amazon_bedrock_api.py b/tests/metagpt/provider/test_amazon_bedrock_api.py index c10b74d0a..cd13c0b24 100644 --- a/tests/metagpt/provider/test_amazon_bedrock_api.py +++ b/tests/metagpt/provider/test_amazon_bedrock_api.py @@ -20,14 +20,21 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict: def dict2bytes(x): return json.dumps(x).encode("utf-8") provider = self.config.model.split(".")[0] - response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider]) - # decoded bytes share the same format as non-stream response_body except for titan + if provider == "amazon": - response_body_stream = { - "body": [{'chunk': {'bytes': dict2bytes({"outputText": "Hello World"})}}]} + response_body_bytes = dict2bytes({"outputText": "Hello World"}) + elif provider == "anthropic": + response_body_bytes = dict2bytes({"type": "content_block_delta", "index": 0, + "delta": {"type": "text_delta", "text": "Hello World"}}) + elif provider == "cohere": + response_body_bytes = dict2bytes( + {"is_finished": False, "text": "Hello World"}) else: - response_body_stream = { - "body": [{'chunk': {'bytes': response_body_bytes}}]} + response_body_bytes = dict2bytes( + BEDROCK_PROVIDER_RESPONSE_BODY[provider]) + + response_body_stream = { + "body": [{'chunk': {'bytes': response_body_bytes}}]} return response_body_stream @@ -75,7 +82,7 @@ class TestAPI: """Ensure request body has correct format""" provider = bedrock_api._get_provider() request_body = json.loads(provider.get_request_body( - messages, **bedrock_api._generate_kwargs)) + messages, bedrock_api._generate_kwargs)) assert is_subset(request_body, get_bedrock_request_body( bedrock_api.config.model)) From 79251cd3cdf2b3dd5a1a8b27a53ae771a4682ef3 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Sun, 28 Apr 2024 19:18:12 +0800 Subject: [PATCH 25/32] resolve problems --- metagpt/configs/llm_config.py | 2 +- metagpt/provider/__init__.py | 2 +- metagpt/provider/bedrock/.gitignore | 192 ------------------ metagpt/provider/bedrock/bedrock_provider.py | 4 +- .../amazon_bedrock_api.py => bedrock_api.py} | 17 +- requirements.txt | 1 + tests/metagpt/provider/mock_llm_config.py | 2 +- tests/metagpt/provider/req_resp_const.py | 7 +- ...zon_bedrock_api.py => test_bedrock_api.py} | 14 +- 9 files changed, 25 insertions(+), 216 deletions(-) delete mode 100644 metagpt/provider/bedrock/.gitignore rename metagpt/provider/{bedrock/amazon_bedrock_api.py => bedrock_api.py} (94%) rename tests/metagpt/provider/{test_amazon_bedrock_api.py => test_bedrock_api.py} (86%) diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index ae8c57ec5..41e04ab44 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -32,7 +32,7 @@ class LLMType(Enum): MISTRAL = "mistral" YI = "yi" # lingyiwanwu OPENROUTER = "openrouter" - AMAZON_BEDROCK = "amazon_bedrock" + BEDROCK = "bedrock" def __missing__(self, key): return self.OPENAI diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 
dd5b4f89d..1311ccf61 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -17,7 +17,7 @@ from metagpt.provider.spark_api import SparkLLM from metagpt.provider.qianfan_api import QianFanLLM from metagpt.provider.dashscope_api import DashScopeLLM from metagpt.provider.anthropic_api import AnthropicLLM -from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM +from metagpt.provider.bedrock_api import AmazonBedrockLLM __all__ = [ "GeminiLLM", diff --git a/metagpt/provider/bedrock/.gitignore b/metagpt/provider/bedrock/.gitignore deleted file mode 100644 index 971fcecb7..000000000 --- a/metagpt/provider/bedrock/.gitignore +++ /dev/null @@ -1,192 +0,0 @@ -### Python template - -# Byte-compiled / optimized / DLL files -__pycache__ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST -metagpt/tools/schemas/ -examples/data/search_kb/*.json - -# PyInstaller -# Usually these files are written by a python scripts from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ -unittest.txt - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -logs -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# report -allure-report -allure-results - -# idea / vscode / macos -.idea -.DS_Store -.vscode - -key.yaml -/data/ -data.ms -examples/nb/ -examples/default__vector_store.json -examples/docstore.json -examples/graph_store.json -examples/image__vector_store.json -examples/index_store.json -.chroma -*~$* -workspace/* -tmp -metagpt/roles/idea_agent.py -.aider* -*.bak -*.bk - -# output folder -output -tmp.png -.dependencies.json -tests/metagpt/utils/file_repo_git -tests/data/rsp_cache_new.json -*.tmp -*.png -htmlcov -htmlcov.* -cov.xml -*.dot -*.pkl -*.faiss -*-structure.csv -*-structure.json -*.dot -.python-version -# aws access key -config.py \ No newline at end of file diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 01bcaac53..6378939c9 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -76,7 +76,7 @@ class Ai21Provider(BaseBedrockProvider): max_tokens_field_name = "maxTokens" def _get_completion_from_dict(self, rsp_dict: dict) -> str: - return rsp_dict['completions'][0]["data"]["text"] + return rsp_dict["completions"][0]["data"]["text"] class AmazonProvider(BaseBedrockProvider): @@ -92,7 +92,7 @@ class AmazonProvider(BaseBedrockProvider): return body def _get_completion_from_dict(self, rsp_dict: dict) -> str: - return rsp_dict['results'][0]['outputText'] + return rsp_dict["results"][0]["outputText"] def get_choice_text_from_stream(self, event) -> str: rsp_dict = json.loads(event["chunk"]["bytes"]) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock_api.py similarity index 94% rename from metagpt/provider/bedrock/amazon_bedrock_api.py rename to metagpt/provider/bedrock_api.py index 6230cd3f2..ac0f2e505 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -15,7 +15,7 @@ except ImportError: "boto3 not found! 
please install it by `pip install boto3` ") -@register_provider([LLMType.AMAZON_BEDROCK]) +@register_provider([LLMType.BEDROCK]) class AmazonBedrockLLM(BaseLLM): def __init__(self, config: LLMConfig): self.config = config @@ -36,10 +36,12 @@ class AmazonBedrockLLM(BaseLLM): client = session.client(service_name) return client - def _get_client(self): + @property + def client(self): return self.__client - def _get_provider(self): + @property + def provider(self): return self.__provider def list_models(self): @@ -53,22 +55,21 @@ class AmazonBedrockLLM(BaseLLM): """ client = self.__init_client("bedrock") # only output text-generation models - response = client.list_foundation_models(byOutputModality='TEXT') + response = client.list_foundation_models(byOutputModality="TEXT") summaries = [f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}' for summary in response["modelSummaries"]] logger.info("\n"+"\n".join(summaries)) - def invoke_model(self, request_body) -> dict: + def invoke_model(self, request_body: str) -> dict: response = self.__client.invoke_model( modelId=self.config.model, body=request_body ) response_body = self._get_response_body(response) return response_body - def invoke_model_with_response_stream(self, request_body) -> EventStream: + def invoke_model_with_response_stream(self, request_body: str) -> EventStream: response = self.__client.invoke_model_with_response_stream( - modelId=self.config.model, body=request_body - ) + modelId=self.config.model, body=request_body) return response @property diff --git a/requirements.txt b/requirements.txt index 6c219a9dc..76c158115 100644 --- a/requirements.txt +++ b/requirements.txt @@ -70,3 +70,4 @@ qianfan==0.3.2 dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation gymnasium==0.29.1 +boto3==1.34.92 \ No newline at end of file diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index 8660bc24f..8f2baea10 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ b/tests/metagpt/provider/mock_llm_config.py @@ -62,7 +62,7 @@ mock_llm_config_anthropic = LLMConfig( ) mock_llm_config_bedrock = LLMConfig( - api_type="amazon_bedrock", + api_type="bedrock", model="gpt-100", region_name="somewhere", access_key="123abc", diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py index 893c33704..fb754abf7 100644 --- a/tests/metagpt/provider/req_resp_const.py +++ b/tests/metagpt/provider/req_resp_const.py @@ -160,8 +160,7 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message: ), ContentBlockDeltaEvent( index=0, - delta=TextDelta(text=resp_cont_tmpl.format( - name=name), type="text_delta"), + delta=TextDelta(text=resp_cont_tmpl.format(name=name), type="text_delta"), type="content_block_delta", ), ] @@ -237,5 +236,5 @@ BEDROCK_PROVIDER_RESPONSE_BODY = { "stop_reason": "", "stop_sequence": "", "usage": {"input_tokens": 0, "output_tokens": 0} }, - "amazon": {'inputTextTokenCount': 0, 'results': [{'tokenCount': 0, 'outputText': 'Hello World', 'completionReason': ""}]} -} \ No newline at end of file + "amazon": {"inputTextTokenCount": 0, "results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}]} +} diff --git a/tests/metagpt/provider/test_amazon_bedrock_api.py b/tests/metagpt/provider/test_bedrock_api.py similarity index 86% rename from tests/metagpt/provider/test_amazon_bedrock_api.py rename to tests/metagpt/provider/test_bedrock_api.py index cd13c0b24..7e282db78 100644 --- 
a/tests/metagpt/provider/test_amazon_bedrock_api.py +++ b/tests/metagpt/provider/test_bedrock_api.py @@ -1,6 +1,6 @@ import pytest import json -from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM +from metagpt.provider.bedrock_api import AmazonBedrockLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY @@ -34,7 +34,7 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict: BEDROCK_PROVIDER_RESPONSE_BODY[provider]) response_body_stream = { - "body": [{'chunk': {'bytes': response_body_bytes}}]} + "body": [{"chunk": {"bytes": response_body_bytes}}]} return response_body_stream @@ -74,13 +74,13 @@ def bedrock_api(request) -> AmazonBedrockLLM: class TestAPI: def test_generate_kwargs(self, bedrock_api: AmazonBedrockLLM): - provider = bedrock_api._get_provider() + provider = bedrock_api.provider assert bedrock_api._generate_kwargs[provider.max_tokens_field_name] <= get_max_tokens( bedrock_api.config.model) def test_get_request_body(self, bedrock_api: AmazonBedrockLLM): """Ensure request body has correct format""" - provider = bedrock_api._get_provider() + provider = bedrock_api.provider request_body = json.loads(provider.get_request_body( messages, bedrock_api._generate_kwargs)) @@ -88,13 +88,13 @@ class TestAPI: bedrock_api.config.model)) def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker): - mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model", + mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model", mock_bedrock_provider_response) assert bedrock_api.completion(messages) == "Hello World" def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker): - mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model", + mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model", mock_bedrock_provider_response) - mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", + mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", mock_bedrock_provider_stream_response) assert bedrock_api._chat_completion_stream(messages) == "Hello World" From cb8e14dca8448dcfa984cf7ec7f0032e0c535e1c Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Sun, 28 Apr 2024 19:30:09 +0800 Subject: [PATCH 26/32] add token counts --- metagpt/provider/bedrock_api.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py index ac0f2e505..16b50d996 100644 --- a/metagpt/provider/bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -64,12 +64,16 @@ class AmazonBedrockLLM(BaseLLM): response = self.__client.invoke_model( modelId=self.config.model, body=request_body ) + usage = self._get_usage(response) + self._update_costs(usage) response_body = self._get_response_body(response) return response_body def invoke_model_with_response_stream(self, request_body: str) -> EventStream: response = self.__client.invoke_model_with_response_stream( modelId=self.config.model, body=request_body) + usage = self._get_usage(response) + self._update_costs(usage) return response @property @@ -135,3 +139,13 @@ class AmazonBedrockLLM(BaseLLM): def 
_get_response_body(self, response) -> dict: response_body = json.loads(response["body"].read()) return response_body + + def _get_usage(self, response) -> dict[str, int]: + headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {}) + prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0)) + completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0)) + usage = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + }, + return usage From aa8e123a8e152a5a4156c3ce3e892d644e3f983c Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Sun, 28 Apr 2024 19:31:59 +0800 Subject: [PATCH 27/32] delete cache --- config/puppeteer-config.json | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 config/puppeteer-config.json diff --git a/config/puppeteer-config.json b/config/puppeteer-config.json deleted file mode 100644 index 7b2851c29..000000000 --- a/config/puppeteer-config.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "executablePath": "/usr/bin/chromium", - "args": [ - "--no-sandbox" - ] -} \ No newline at end of file From 986e3c827e29895595c857eebfc64a806550aa33 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Mon, 29 Apr 2024 09:37:30 +0800 Subject: [PATCH 28/32] fix --- config/puppeteer-config.json | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 config/puppeteer-config.json diff --git a/config/puppeteer-config.json b/config/puppeteer-config.json new file mode 100644 index 000000000..b74a514e7 --- /dev/null +++ b/config/puppeteer-config.json @@ -0,0 +1,4 @@ +{ + "executablePath": "/usr/bin/chromium", + "args": ["--no-sandbox"] +} From 3f108abd062904c27b25766f37696f2dc3e68f64 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Mon, 29 Apr 2024 10:46:50 +0800 Subject: [PATCH 29/32] rename bedrock class and add more tests --- metagpt/provider/__init__.py | 4 +- metagpt/provider/bedrock/base_provider.py | 4 +- metagpt/provider/bedrock_api.py | 22 +++------- tests/metagpt/provider/test_bedrock_api.py | 50 ++++++++++++++-------- 4 files changed, 42 insertions(+), 38 deletions(-) diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 1311ccf61..fc91c7460 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -17,7 +17,7 @@ from metagpt.provider.spark_api import SparkLLM from metagpt.provider.qianfan_api import QianFanLLM from metagpt.provider.dashscope_api import DashScopeLLM from metagpt.provider.anthropic_api import AnthropicLLM -from metagpt.provider.bedrock_api import AmazonBedrockLLM +from metagpt.provider.bedrock_api import BedrockLLM __all__ = [ "GeminiLLM", @@ -31,5 +31,5 @@ __all__ = [ "QianFanLLM", "DashScopeLLM", "AnthropicLLM", - "AmazonBedrockLLM" + "BedrockLLM" ] diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 5aa6ae605..fd0508cb0 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -10,9 +10,9 @@ class BaseBedrockProvider(ABC): def _get_completion_from_dict(self, rsp_dict: dict) -> str: ... 
- def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): + def get_request_body(self, messages: list[dict], const_kwargs, *args, **kwargs) -> str: body = json.dumps( - {"prompt": self.messages_to_prompt(messages), **generate_kwargs}) + {"prompt": self.messages_to_prompt(messages), **const_kwargs}) return body def get_choice_text(self, response_body: dict) -> str: diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py index 16b50d996..b07520fd1 100644 --- a/metagpt/provider/bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -7,16 +7,12 @@ from metagpt.provider.base_llm import BaseLLM from metagpt.logs import log_llm_stream, logger from metagpt.provider.bedrock.bedrock_provider import get_provider from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens -try: - import boto3 - from botocore.eventstream import EventStream -except ImportError: - raise ImportError( - "boto3 not found! please install it by `pip install boto3` ") +import boto3 +from botocore.eventstream import EventStream @register_provider([LLMType.BEDROCK]) -class AmazonBedrockLLM(BaseLLM): +class BedrockLLM(BaseLLM): def __init__(self, config: LLMConfig): self.config = config self.__client = self.__init_client("bedrock-runtime") @@ -77,7 +73,7 @@ class AmazonBedrockLLM(BaseLLM): return response @property - def _generate_kwargs(self) -> dict: + def _const_kwargs(self) -> dict: model_max_tokens = get_max_tokens(self.config.model) if self.config.max_token > model_max_tokens: max_tokens = model_max_tokens @@ -91,7 +87,7 @@ class AmazonBedrockLLM(BaseLLM): def completion(self, messages: list[dict]) -> str: request_body = self.__provider.get_request_body( - messages, self._generate_kwargs) + messages, self._const_kwargs) response_body = self.invoke_model(request_body) completions = self.__provider.get_choice_text(response_body) return completions @@ -103,7 +99,7 @@ class AmazonBedrockLLM(BaseLLM): return self.completion(messages) request_body = self.__provider.get_request_body( - messages, self._generate_kwargs, stream=True) + messages, self._const_kwargs, stream=True) response = self.invoke_model_with_response_stream(request_body) collected_content = [] @@ -124,12 +120,6 @@ class AmazonBedrockLLM(BaseLLM): async def acompletion(self, messages: list[dict]): return await self._achat_completion(messages) - async def acompletion_text(self, messages: list[dict], stream: bool = False, - timeout: int = USE_CONFIG_TIMEOUT) -> str: - if stream: - return await self._achat_completion_stream(messages) - return await self._achat_completion(messages) - async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): return self.completion(messages) diff --git a/tests/metagpt/provider/test_bedrock_api.py b/tests/metagpt/provider/test_bedrock_api.py index 7e282db78..7b2797da0 100644 --- a/tests/metagpt/provider/test_bedrock_api.py +++ b/tests/metagpt/provider/test_bedrock_api.py @@ -1,6 +1,6 @@ import pytest import json -from metagpt.provider.bedrock_api import AmazonBedrockLLM +from metagpt.provider.bedrock_api import BedrockLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY @@ -65,36 +65,50 @@ def is_subset(subset, superset) -> bool: @pytest.fixture(scope="class", params=models) -def 
bedrock_api(request) -> AmazonBedrockLLM: +def bedrock_api(request) -> BedrockLLM: model_id = request.param mock_llm_config_bedrock.model = model_id - api = AmazonBedrockLLM(mock_llm_config_bedrock) + api = BedrockLLM(mock_llm_config_bedrock) return api -class TestAPI: - def test_generate_kwargs(self, bedrock_api: AmazonBedrockLLM): +class TestBedrockAPI: + def _patch_invoke_model(self, mocker): + mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_bedrock_provider_response) + + def _patch_invoke_model_stream(self, mocker): + mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream", + mock_bedrock_provider_stream_response) + + def test_const_kwargs(self, bedrock_api: BedrockLLM): provider = bedrock_api.provider - assert bedrock_api._generate_kwargs[provider.max_tokens_field_name] <= get_max_tokens( + assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens( bedrock_api.config.model) - def test_get_request_body(self, bedrock_api: AmazonBedrockLLM): + def test_get_request_body(self, bedrock_api: BedrockLLM): """Ensure request body has correct format""" provider = bedrock_api.provider request_body = json.loads(provider.get_request_body( - messages, bedrock_api._generate_kwargs)) + messages, bedrock_api._const_kwargs)) - assert is_subset(request_body, get_bedrock_request_body( - bedrock_api.config.model)) + assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model)) - def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker): - mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model", - mock_bedrock_provider_response) + def test_completion(self, bedrock_api: BedrockLLM, mocker): + self._patch_invoke_model(mocker) assert bedrock_api.completion(messages) == "Hello World" - def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker): - mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model", - mock_bedrock_provider_response) - mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", - mock_bedrock_provider_stream_response) + def test_chat_completion_stream(self, bedrock_api: BedrockLLM, mocker): + self._patch_invoke_model(mocker) + self._patch_invoke_model_stream(mocker) assert bedrock_api._chat_completion_stream(messages) == "Hello World" + + @pytest.mark.asyncio + async def test_achat_completion_stream(self, bedrock_api: BedrockLLM, mocker): + self._patch_invoke_model_stream(mocker) + self._patch_invoke_model(mocker) + assert await bedrock_api._achat_completion_stream(messages) == "Hello World" + + @pytest.mark.asyncio + async def test_acompletion(self, bedrock_api: BedrockLLM, mocker): + self._patch_invoke_model(mocker) + assert await bedrock_api.acompletion(messages) == "Hello World" From f14a1f63ef93d035b77cb285c078bc511290c094 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Mon, 29 Apr 2024 15:04:33 +0800 Subject: [PATCH 30/32] add pre-commit --- metagpt/configs/llm_config.py | 1 - metagpt/provider/__init__.py | 2 +- metagpt/provider/bedrock/base_provider.py | 3 +- metagpt/provider/bedrock/bedrock_provider.py | 19 ++-- metagpt/provider/bedrock_api.py | 62 ++++++------- tests/metagpt/provider/req_resp_const.py | 98 +++++++++++++------- tests/metagpt/provider/test_bedrock_api.py | 44 +++++---- 7 files changed, 132 insertions(+), 97 deletions(-) diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 41e04ab44..e202150f7 100644 --- 
a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -100,4 +100,3 @@ class LLMConfig(YamlModel): @classmethod def check_timeout(cls, v): return v or LLM_API_TIMEOUT - diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index fc91c7460..fcb5fa32a 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -31,5 +31,5 @@ __all__ = [ "QianFanLLM", "DashScopeLLM", "AnthropicLLM", - "BedrockLLM" + "BedrockLLM", ] diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index fd0508cb0..0d13ae938 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -11,8 +11,7 @@ class BaseBedrockProvider(ABC): ... def get_request_body(self, messages: list[dict], const_kwargs, *args, **kwargs) -> str: - body = json.dumps( - {"prompt": self.messages_to_prompt(messages), **const_kwargs}) + body = json.dumps({"prompt": self.messages_to_prompt(messages), **const_kwargs}) return body def get_choice_text(self, response_body: dict) -> str: diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 6378939c9..ff1d88a47 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -1,7 +1,11 @@ import json from typing import Literal + from metagpt.provider.bedrock.base_provider import BaseBedrockProvider -from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3 +from metagpt.provider.bedrock.utils import ( + messages_to_prompt_llama2, + messages_to_prompt_llama3, +) class MistralProvider(BaseBedrockProvider): @@ -18,8 +22,7 @@ class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): - body = json.dumps( - {"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) + body = json.dumps({"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) return body def _get_completion_from_dict(self, rsp_dict: dict) -> str: @@ -43,7 +46,8 @@ class CohereProvider(BaseBedrockProvider): def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): body = json.dumps( - {"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs}) + {"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs} + ) return body def get_choice_text_from_stream(self, event) -> str: @@ -85,10 +89,7 @@ class AmazonProvider(BaseBedrockProvider): max_tokens_field_name = "maxTokenCount" def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): - body = json.dumps({ - "inputText": self.messages_to_prompt(messages), - "textGenerationConfig": generate_kwargs - }) + body = json.dumps({"inputText": self.messages_to_prompt(messages), "textGenerationConfig": generate_kwargs}) return body def _get_completion_from_dict(self, rsp_dict: dict) -> str: @@ -106,7 +107,7 @@ PROVIDERS = { "ai21": Ai21Provider, "cohere": CohereProvider, "anthropic": AnthropicProvider, - "amazon": AmazonProvider + "amazon": AmazonProvider, } diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py index b07520fd1..de3fbae94 100644 --- a/metagpt/provider/bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -1,15 +1,17 @@ 
-from typing import Literal import json -from metagpt.const import USE_CONFIG_TIMEOUT -from metagpt.provider.llm_provider_registry import register_provider -from metagpt.configs.llm_config import LLMConfig, LLMType -from metagpt.provider.base_llm import BaseLLM -from metagpt.logs import log_llm_stream, logger -from metagpt.provider.bedrock.bedrock_provider import get_provider -from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens +from typing import Literal + import boto3 from botocore.eventstream import EventStream +from metagpt.configs.llm_config import LLMConfig, LLMType +from metagpt.const import USE_CONFIG_TIMEOUT +from metagpt.logs import log_llm_stream, logger +from metagpt.provider.base_llm import BaseLLM +from metagpt.provider.bedrock.bedrock_provider import get_provider +from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens +from metagpt.provider.llm_provider_registry import register_provider + @register_provider([LLMType.BEDROCK]) class BedrockLLM(BaseLLM): @@ -17,8 +19,7 @@ class BedrockLLM(BaseLLM): self.config = config self.__client = self.__init_client("bedrock-runtime") self.__provider = get_provider(self.config.model) - logger.warning( - "Amazon bedrock doesn't support asynchronous now") + logger.warning("Amazon bedrock doesn't support asynchronous now") def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): """initialize boto3 client""" @@ -26,7 +27,7 @@ class BedrockLLM(BaseLLM): self.__credentital_kwargs = { "aws_secret_access_key": self.config.secret_key, "aws_access_key_id": self.config.access_key, - "region_name": self.config.region_name + "region_name": self.config.region_name, } session = boto3.Session(**self.__credentital_kwargs) client = session.client(service_name) @@ -52,22 +53,21 @@ class BedrockLLM(BaseLLM): client = self.__init_client("bedrock") # only output text-generation models response = client.list_foundation_models(byOutputModality="TEXT") - summaries = [f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}' - for summary in response["modelSummaries"]] - logger.info("\n"+"\n".join(summaries)) + summaries = [ + f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}' + for summary in response["modelSummaries"] + ] + logger.info("\n" + "\n".join(summaries)) def invoke_model(self, request_body: str) -> dict: - response = self.__client.invoke_model( - modelId=self.config.model, body=request_body - ) + response = self.__client.invoke_model(modelId=self.config.model, body=request_body) usage = self._get_usage(response) self._update_costs(usage) response_body = self._get_response_body(response) return response_body def invoke_model_with_response_stream(self, request_body: str) -> EventStream: - response = self.__client.invoke_model_with_response_stream( - modelId=self.config.model, body=request_body) + response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body) usage = self._get_usage(response) self._update_costs(usage) return response @@ -80,26 +80,20 @@ class BedrockLLM(BaseLLM): else: max_tokens = self.config.max_token - return { - self.__provider.max_tokens_field_name: max_tokens, - "temperature": self.config.temperature - } + return {self.__provider.max_tokens_field_name: max_tokens, "temperature": self.config.temperature} def completion(self, messages: list[dict]) -> str: - request_body = self.__provider.get_request_body( - messages, self._const_kwargs) + request_body 
= self.__provider.get_request_body(messages, self._const_kwargs) response_body = self.invoke_model(request_body) completions = self.__provider.get_choice_text(response_body) return completions def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: if self.config.model in NOT_SUUPORT_STREAM_MODELS: - logger.warning( - f"model {self.config.model} doesn't support streaming output!") + logger.warning(f"model {self.config.model} doesn't support streaming output!") return self.completion(messages) - request_body = self.__provider.get_request_body( - messages, self._const_kwargs, stream=True) + request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True) response = self.invoke_model_with_response_stream(request_body) collected_content = [] @@ -134,8 +128,8 @@ class BedrockLLM(BaseLLM): headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {}) prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0)) completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0)) - usage = { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - }, + usage = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + } return usage diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py index fb754abf7..111b57f91 100644 --- a/tests/metagpt/provider/req_resp_const.py +++ b/tests/metagpt/provider/req_resp_const.py @@ -65,20 +65,17 @@ def get_openai_chat_completion(name: str) -> ChatCompletion: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage( - role="assistant", content=resp_cont_tmpl.format(name=name)), + message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)), logprobs=None, ) ], - usage=CompletionUsage(completion_tokens=110, - prompt_tokens=92, total_tokens=202), + usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), ) return openai_chat_completion def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk: - usage = CompletionUsage(completion_tokens=110, - prompt_tokens=92, total_tokens=202) + usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202) usage = usage if not usage_as_dict else usage.model_dump() openai_chat_completion_chunk = ChatCompletionChunk( id="cmpl-a6652c1bb181caae8dd19ad8", @@ -87,8 +84,7 @@ def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> AChoice( - delta=ChoiceDelta(role="assistant", - content=resp_cont_tmpl.format(name=name)), + delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)), finish_reason="stop", index=0, logprobs=None, @@ -137,8 +133,7 @@ def get_dashscope_response(name: str) -> GenerationResponse: ], } ), - usage=GenerationUsage( - **{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}), + usage=GenerationUsage(**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}), ) ) @@ -170,8 +165,7 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message: model=name, role="assistant", type="message", - content=[ContentBlock( - text=resp_cont_tmpl.format(name=name), type="text")], + content=[ContentBlock(text=resp_cont_tmpl.format(name=name), type="text")], usage=AnthropicUsage(input_tokens=10, output_tokens=10), ) @@ -198,43 +192,81 @@ BEDROCK_PROVIDER_REQUEST_BODY = { "mistral": {"prompt": "", "max_tokens": 0, "stop": [], 
"temperature": 0.0, "top_p": 0.0, "top_k": 0}, "meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0}, "ai21": { - "prompt": "", "temperature": 0.0, "topP": 0.0, "maxTokens": 0, - "stopSequences": [], "countPenalty": {"scale": 0.0}, - "presencePenalty": {"scale": 0.0}, "frequencyPenalty": {"scale": 0.0} + "prompt": "", + "temperature": 0.0, + "topP": 0.0, + "maxTokens": 0, + "stopSequences": [], + "countPenalty": {"scale": 0.0}, + "presencePenalty": {"scale": 0.0}, + "frequencyPenalty": {"scale": 0.0}, }, "cohere": { - "prompt": "", "temperature": 0.0, "p": 0.0, "k": 0.0, "max_tokens": 0, "stop_sequences": [], - "return_likelihoods": "NONE", "stream": False, "num_generations": 0, "logit_bias": {}, "truncate": "NONE" + "prompt": "", + "temperature": 0.0, + "p": 0.0, + "k": 0.0, + "max_tokens": 0, + "stop_sequences": [], + "return_likelihoods": "NONE", + "stream": False, + "num_generations": 0, + "logit_bias": {}, + "truncate": "NONE", }, "anthropic": { - "anthropic_version": "bedrock-2023-05-31", "max_tokens": 0, "system": "", - "messages": [{"role": "", "content": ""}], "temperature": 0.0, "top_p": 0.0, "top_k": 0, "stop_sequences": [] + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": 0, + "system": "", + "messages": [{"role": "", "content": ""}], + "temperature": 0.0, + "top_p": 0.0, + "top_k": 0, + "stop_sequences": [], }, "amazon": { - "inputText": "", "textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []} - } + "inputText": "", + "textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []}, + }, } BEDROCK_PROVIDER_RESPONSE_BODY = { "mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]}, - "meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""}, - "ai21": { - "id": "", "prompt": {"text": "Hello World", "tokens": []}, - "completions": [{"data": {"text": "Hello World", "tokens": []}, - "finishReason": {"reason": "length", "length": 2}}] + "id": "", + "prompt": {"text": "Hello World", "tokens": []}, + "completions": [ + {"data": {"text": "Hello World", "tokens": []}, "finishReason": {"reason": "length", "length": 2}} + ], }, "cohere": { "generations": [ - {"finish_reason": "", "id": "", "text": "Hello World", "likelihood": 0.0, - "token_likelihoods": [{"token": 0.0}], "is_finished": True, "index": 0}], "id": "", "prompt": "" + { + "finish_reason": "", + "id": "", + "text": "Hello World", + "likelihood": 0.0, + "token_likelihoods": [{"token": 0.0}], + "is_finished": True, + "index": 0, + } + ], + "id": "", + "prompt": "", }, - "anthropic": { - "id": "", "model": "", "type": "message", "role": "assistant", "content": [{"type": "text", "text": "Hello World"}], - "stop_reason": "", "stop_sequence": "", "usage": {"input_tokens": 0, "output_tokens": 0} + "id": "", + "model": "", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "Hello World"}], + "stop_reason": "", + "stop_sequence": "", + "usage": {"input_tokens": 0, "output_tokens": 0}, + }, + "amazon": { + "inputTextTokenCount": 0, + "results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}], }, - - "amazon": {"inputTextTokenCount": 0, "results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}]} } diff --git a/tests/metagpt/provider/test_bedrock_api.py b/tests/metagpt/provider/test_bedrock_api.py index 7b2797da0..54ff5afa4 100644 --- a/tests/metagpt/provider/test_bedrock_api.py +++ 
b/tests/metagpt/provider/test_bedrock_api.py @@ -1,12 +1,21 @@ -import pytest import json + +import pytest + +from metagpt.provider.bedrock.utils import ( + NOT_SUUPORT_STREAM_MODELS, + SUPPORT_STREAM_MODELS, + get_max_tokens, +) from metagpt.provider.bedrock_api import BedrockLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock -from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS -from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY +from tests.metagpt.provider.req_resp_const import ( + BEDROCK_PROVIDER_REQUEST_BODY, + BEDROCK_PROVIDER_RESPONSE_BODY, +) # all available model from bedrock -models = (SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS) +models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS messages = [{"role": "user", "content": "Hi!"}] @@ -19,22 +28,21 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict: # use json object to mock EventStream def dict2bytes(x): return json.dumps(x).encode("utf-8") + provider = self.config.model.split(".")[0] if provider == "amazon": response_body_bytes = dict2bytes({"outputText": "Hello World"}) elif provider == "anthropic": - response_body_bytes = dict2bytes({"type": "content_block_delta", "index": 0, - "delta": {"type": "text_delta", "text": "Hello World"}}) + response_body_bytes = dict2bytes( + {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello World"}} + ) elif provider == "cohere": - response_body_bytes = dict2bytes( - {"is_finished": False, "text": "Hello World"}) + response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"}) else: - response_body_bytes = dict2bytes( - BEDROCK_PROVIDER_RESPONSE_BODY[provider]) + response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider]) - response_body_stream = { - "body": [{"chunk": {"bytes": response_body_bytes}}]} + response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]} return response_body_stream @@ -77,19 +85,19 @@ class TestBedrockAPI: mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_bedrock_provider_response) def _patch_invoke_model_stream(self, mocker): - mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream", - mock_bedrock_provider_stream_response) + mocker.patch( + "metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream", + mock_bedrock_provider_stream_response, + ) def test_const_kwargs(self, bedrock_api: BedrockLLM): provider = bedrock_api.provider - assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens( - bedrock_api.config.model) + assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens(bedrock_api.config.model) def test_get_request_body(self, bedrock_api: BedrockLLM): """Ensure request body has correct format""" provider = bedrock_api.provider - request_body = json.loads(provider.get_request_body( - messages, bedrock_api._const_kwargs)) + request_body = json.loads(provider.get_request_body(messages, bedrock_api._const_kwargs)) assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model)) From 0006b6290170fe91d5be0516a6c9572e94bdf0a4 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Mon, 29 Apr 2024 17:29:41 +0800 Subject: [PATCH 31/32] resolve problem and add cost manager --- metagpt/provider/bedrock/utils.py | 43 +++++++++++++------- 
metagpt/provider/bedrock_api.py | 44 ++++++++++---------- metagpt/utils/token_counter.py | 47 ++++++++++++++++++++++ tests/metagpt/provider/test_bedrock_api.py | 41 +++++++------------ 4 files changed, 112 insertions(+), 63 deletions(-) diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index a87367b22..ee31da1b9 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -13,25 +13,34 @@ NOT_SUUPORT_STREAM_MODELS = { SUPPORT_STREAM_MODELS = { "amazon.titan-tg1-large": 8000, "amazon.titan-text-express-v1": 8000, + "amazon.titan-text-express-v1:0:8k": 8000, + "amazon.titan-text-lite-v1:0:4k": 4000, + "amazon.titan-text-lite-v1": 4000, "anthropic.claude-instant-v1": 100000, "anthropic.claude-instant-v1:2:100k": 100000, "anthropic.claude-v1": 100000, "anthropic.claude-v2": 100000, "anthropic.claude-v2:1": 200000, + "anthropic.claude-v2:0:18k": 18000, + "anthropic.claude-v2:1:200k": 200000, "anthropic.claude-3-sonnet-20240229-v1:0": 200000, "anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000, "anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000, "anthropic.claude-3-haiku-20240307-v1:0": 200000, "anthropic.claude-3-haiku-20240307-v1:0:48k": 48000, "anthropic.claude-3-haiku-20240307-v1:0:200k": 200000, + # currently (2024-4-29) only available at US West (Oregon) AWS Region. + "anthropic.claude-3-opus-20240229-v1:0": 200000, "cohere.command-text-v14": 4000, "cohere.command-text-v14:7:4k": 4000, "cohere.command-light-text-v14": 4000, "cohere.command-light-text-v14:7:4k": 4000, - "meta.llama2-70b-v1": 4000, "meta.llama2-13b-chat-v1:0:4k": 4000, "meta.llama2-13b-chat-v1": 2000, + "meta.llama2-70b-v1": 4000, "meta.llama2-70b-v1:0:4k": 4000, + "meta.llama2-70b-chat-v1": 4000, + "meta.llama2-70b-chat-v1:0:4k": 4000, "meta.llama3-8b-instruct-v1:0": 2000, "meta.llama3-70b-instruct-v1:0": 2000, "mistral.mistral-7b-instruct-v0:2": 32000, @@ -43,14 +52,14 @@ SUPPORT_STREAM_MODELS = { def messages_to_prompt_llama2(messages: list[dict]) -> str: - BOS, EOS = "<s>", "</s>" + BOS = "<s>" B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n" prompt = f"{BOS}" for message in messages: - role = message["role"] - content = message["content"] + role = message.get("role", "") + content = message.get("content", "") if role == "system": prompt += f"{B_SYS} {content} {E_SYS}" elif role == "user": @@ -58,25 +67,24 @@ def messages_to_prompt_llama2(messages: list[dict]) -> str: elif role == "assistant": prompt += f"{content}" else: - logger.warning( - f"Unknown role name {role} when formatting messages") + logger.warning(f"Unknown role name {role} when formatting messages") prompt += f"{content}" return prompt def messages_to_prompt_llama3(messages: list[dict]) -> str: - BOS, EOS = "<|begin_of_text|>", "<|eot_id|>" + BOS = "<|begin_of_text|>" GENERAL_TEMPLATE = "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>" prompt = f"{BOS}" for message in messages: - role = message["role"] - content = message["content"] + role = message.get("role", "") + content = message.get("content", "") prompt += GENERAL_TEMPLATE.format(role=role, content=content) if role != "assistant": - prompt += f"<|start_header_id|>assistant<|end_header_id|>" + prompt += "<|start_header_id|>assistant<|end_header_id|>" return prompt @@ -85,15 +93,20 @@ def messages_to_prompt_claude2(messages: list[dict]) -> str: GENERAL_TEMPLATE = "\n\n{role}: {content}" prompt = "" for message in messages: - role = message["role"] - content = message["content"] + role = 
message.get("role", "") + content = message.get("content", "") prompt += GENERAL_TEMPLATE.format(role=role, content=content) if role != "assistant": - prompt += f"\n\nAssistant:" + prompt += "\n\nAssistant:" return prompt -def get_max_tokens(model_id) -> int: - return (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] +def get_max_tokens(model_id: str) -> int: + try: + max_tokens = (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] + except KeyError: + logger.warning(f"Couldn't find model {model_id}, max tokens has been set to 2048") + max_tokens = 2048 + return max_tokens diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py index de3fbae94..483b08f29 100644 --- a/metagpt/provider/bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -11,6 +11,8 @@ from metagpt.provider.base_llm import BaseLLM from metagpt.provider.bedrock.bedrock_provider import get_provider from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens from metagpt.provider.llm_provider_registry import register_provider +from metagpt.utils.cost_manager import CostManager +from metagpt.utils.token_counter import BEDROCK_TOKEN_COSTS @register_provider([LLMType.BEDROCK]) @@ -19,6 +21,7 @@ class BedrockLLM(BaseLLM): self.config = config self.__client = self.__init_client("bedrock-runtime") self.__provider = get_provider(self.config.model) + self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS) logger.warning("Amazon bedrock doesn't support asynchronous now") def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): """initialize boto3 client""" @@ -62,14 +65,14 @@ class BedrockLLM(BaseLLM): def invoke_model(self, request_body: str) -> dict: response = self.__client.invoke_model(modelId=self.config.model, body=request_body) usage = self._get_usage(response) - self._update_costs(usage) + self._update_costs(usage, self.config.model) response_body = self._get_response_body(response) return response_body def invoke_model_with_response_stream(self, request_body: str) -> EventStream: response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body) usage = self._get_usage(response) - self._update_costs(usage) + self._update_costs(usage, self.config.model) return response @property @@ -82,16 +85,29 @@ class BedrockLLM(BaseLLM): return {self.__provider.max_tokens_field_name: max_tokens, "temperature": self.config.temperature} - def completion(self, messages: list[dict]) -> str: + # boto3 doesn't support asynchronous calls. 
+ # for asynchronous version of boto3, check out: + # https://aioboto3.readthedocs.io/en/latest/usage.html + # However, aioboto3 doesn't support invoke model + + def get_choice_text(self, rsp: dict) -> str: + return self.__provider.get_choice_text(rsp) + + async def acompletion(self, messages: list[dict]) -> dict: request_body = self.__provider.get_request_body(messages, self._const_kwargs) response_body = self.invoke_model(request_body) - completions = self.__provider.get_choice_text(response_body) - return completions + return response_body - def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: + async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict: + return await self.acompletion(messages) + + async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: if self.config.model in NOT_SUUPORT_STREAM_MODELS: logger.warning(f"model {self.config.model} doesn't support streaming output!") - return self.completion(messages) + rsp = await self.acompletion(messages) + full_text = self.get_choice_text(rsp) + log_llm_stream(full_text) + return full_text request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True) @@ -106,20 +122,6 @@ class BedrockLLM(BaseLLM): full_text = ("".join(collected_content)).lstrip() return full_text - # boto3 don't support support asynchronous calls. - # for asynchronous version of boto3, check out: - # https://aioboto3.readthedocs.io/en/latest/usage.html - # However,aioboto3 doesn't support invoke model - - async def acompletion(self, messages: list[dict]): - return await self._achat_completion(messages) - - async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): - return self.completion(messages) - - async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): - return self._chat_completion_stream(messages) - def _get_response_body(self, response) -> dict: response_body = json.loads(response["body"].read()) return response_body diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 724d49afc..9249d674e 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -198,6 +198,53 @@ TOKEN_MAX = { "openai/gpt-4-turbo-preview": 128000, } +# For Amazon Bedrock US region +# See https://aws.amazon.com/cn/bedrock/pricing/ + +BEDROCK_TOKEN_COSTS = { + "amazon.titan-tg1-large": {"prompt": 0.0008, "completion": 0.0008}, + "amazon.titan-text-express-v1": {"prompt": 0.0008, "completion": 0.0008}, + "amazon.titan-text-express-v1:0:8k": {"prompt": 0.0008, "completion": 0.0008}, + "amazon.titan-text-lite-v1:0:4k": {"prompt": 0.0003, "completion": 0.0004}, + "amazon.titan-text-lite-v1": {"prompt": 0.0003, "completion": 0.0004}, + "anthropic.claude-instant-v1": {"prompt": 0.0008, "completion": 0.0024}, + "anthropic.claude-instant-v1:2:100k": {"prompt": 0.0008, "completion": 0.0024}, + "anthropic.claude-v1": {"prompt": 0.008, "completion": 0.024}, + "anthropic.claude-v2": {"prompt": 0.008, "completion": 0.024}, + "anthropic.claude-v2:1": {"prompt": 0.008, "completion": 0.024}, + "anthropic.claude-v2:0:18k": {"prompt": 0.008, "completion": 0.024}, + "anthropic.claude-v2:1:200k": {"prompt": 0.008, "completion": 0.024}, + "anthropic.claude-3-sonnet-20240229-v1:0": {"prompt": 0.003, "completion": 0.015}, + "anthropic.claude-3-sonnet-20240229-v1:0:28k": {"prompt": 0.003, "completion": 0.015}, +
"anthropic.claude-3-sonnet-20240229-v1:0:200k": {"prompt": 0.003, "completion": 0.015}, + "anthropic.claude-3-haiku-20240307-v1:0": {"prompt": 0.00025, "completion": 0.00125}, + "anthropic.claude-3-haiku-20240307-v1:0:48k": {"prompt": 0.00025, "completion": 0.00125}, + "anthropic.claude-3-haiku-20240307-v1:0:200k": {"prompt": 0.00025, "completion": 0.00125}, + # currently (2024-4-29) only available at US West (Oregon) AWS Region. + "anthropic.claude-3-opus-20240229-v1:0": {"prompt": 0.015, "completion": 0.075}, + "cohere.command-text-v14": {"prompt": 0.0015, "completion": 0.0015}, + "cohere.command-text-v14:7:4k": {"prompt": 0.0015, "completion": 0.0015}, + "cohere.command-light-text-v14": {"prompt": 0.0003, "completion": 0.0003}, + "cohere.command-light-text-v14:7:4k": {"prompt": 0.0003, "completion": 0.0003}, + "meta.llama2-13b-chat-v1:0:4k": {"prompt": 0.00075, "completion": 0.001}, + "meta.llama2-13b-chat-v1": {"prompt": 0.00075, "completion": 0.001}, + "meta.llama2-70b-v1": {"prompt": 0.00195, "completion": 0.00256}, + "meta.llama2-70b-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256}, + "meta.llama2-70b-chat-v1": {"prompt": 0.00195, "completion": 0.00256}, + "meta.llama2-70b-chat-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256}, + "meta.llama3-8b-instruct-v1:0": {"prompt": 0.0004, "completion": 0.0006}, + "meta.llama3-70b-instruct-v1:0": {"prompt": 0.00265, "completion": 0.0035}, + "mistral.mistral-7b-instruct-v0:2": {"prompt": 0.00015, "completion": 0.0002}, + "mistral.mixtral-8x7b-instruct-v0:1": {"prompt": 0.00045, "completion": 0.0007}, + "mistral.mistral-large-2402-v1:0": {"prompt": 0.008, "completion": 0.024}, + "ai21.j2-grande-instruct": {"prompt": 0.0125, "completion": 0.0125}, + "ai21.j2-jumbo-instruct": {"prompt": 0.0188, "completion": 0.0188}, + "ai21.j2-mid": {"prompt": 0.0125, "completion": 0.0125}, + "ai21.j2-mid-v1": {"prompt": 0.0125, "completion": 0.0125}, + "ai21.j2-ultra": {"prompt": 0.0188, "completion": 0.0188}, + "ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188}, +} + def count_message_tokens(messages, model="gpt-3.5-turbo-0125"): """Return the number of tokens used by a list of messages.""" diff --git a/tests/metagpt/provider/test_bedrock_api.py b/tests/metagpt/provider/test_bedrock_api.py index 54ff5afa4..4760a2db2 100644 --- a/tests/metagpt/provider/test_bedrock_api.py +++ b/tests/metagpt/provider/test_bedrock_api.py @@ -5,7 +5,6 @@ import pytest from metagpt.provider.bedrock.utils import ( NOT_SUUPORT_STREAM_MODELS, SUPPORT_STREAM_MODELS, - get_max_tokens, ) from metagpt.provider.bedrock_api import BedrockLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock @@ -17,14 +16,19 @@ from tests.metagpt.provider.req_resp_const import ( # all available model from bedrock models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS messages = [{"role": "user", "content": "Hi!"}] +usage = { + "prompt_tokens": 1000000, + "completion_tokens": 1000000, +} -def mock_bedrock_provider_response(self, *args, **kwargs) -> dict: +def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict: provider = self.config.model.split(".")[0] + self._update_costs(usage, self.config.model) return BEDROCK_PROVIDER_RESPONSE_BODY[provider] -def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict: +def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict: # use json object to mock EventStream def dict2bytes(x): return json.dumps(x).encode("utf-8") @@ -43,6 +47,7 @@ def mock_bedrock_provider_stream_response(self, 
*args, **kwargs) -> dict: response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider]) response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]} + self._update_costs(usage, self.config.model) return response_body_stream @@ -82,41 +87,23 @@ def bedrock_api(request) -> BedrockLLM: class TestBedrockAPI: def _patch_invoke_model(self, mocker): - mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_bedrock_provider_response) + mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_invoke_model) def _patch_invoke_model_stream(self, mocker): mocker.patch( "metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream", - mock_bedrock_provider_stream_response, + mock_invoke_model_stream, ) - def test_const_kwargs(self, bedrock_api: BedrockLLM): - provider = bedrock_api.provider - assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens(bedrock_api.config.model) - def test_get_request_body(self, bedrock_api: BedrockLLM): """Ensure request body has correct format""" provider = bedrock_api.provider request_body = json.loads(provider.get_request_body(messages, bedrock_api._const_kwargs)) - assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model)) - def test_completion(self, bedrock_api: BedrockLLM, mocker): - self._patch_invoke_model(mocker) - assert bedrock_api.completion(messages) == "Hello World" - - def test_chat_completion_stream(self, bedrock_api: BedrockLLM, mocker): + @pytest.mark.asyncio + async def test_aask(self, bedrock_api: BedrockLLM, mocker): self._patch_invoke_model(mocker) self._patch_invoke_model_stream(mocker) - assert bedrock_api._chat_completion_stream(messages) == "Hello World" - - @pytest.mark.asyncio - async def test_achat_completion_stream(self, bedrock_api: BedrockLLM, mocker): - self._patch_invoke_model_stream(mocker) - self._patch_invoke_model(mocker) - assert await bedrock_api._achat_completion_stream(messages) == "Hello World" - - @pytest.mark.asyncio - async def test_acompletion(self, bedrock_api: BedrockLLM, mocker): - self._patch_invoke_model(mocker) - assert await bedrock_api.acompletion(messages) == "Hello World" + assert await bedrock_api.aask(messages, stream=False) == "Hello World" + assert await bedrock_api.aask(messages, stream=True) == "Hello World" From f7b29edaf7b7386f747118b74105a489949a235c Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Mon, 29 Apr 2024 17:35:58 +0800 Subject: [PATCH 32/32] change log for non-streaming model --- metagpt/provider/bedrock_api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py index 483b08f29..d192a5478 100644 --- a/metagpt/provider/bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -23,6 +23,8 @@ class BedrockLLM(BaseLLM): self.__provider = get_provider(self.config.model) self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS) logger.warning("Amazon bedrock doesn't support asynchronous now") + if self.config.model in NOT_SUUPORT_STREAM_MODELS: + logger.warning(f"model {self.config.model} doesn't support streaming output!") def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): """initialize boto3 client""" @@ -103,7 +105,6 @@ class BedrockLLM(BaseLLM): async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: if self.config.model in NOT_SUUPORT_STREAM_MODELS: - logger.warning(f"model {self.config.model} 
doesn't support streaming output!") rsp = await self.acompletion(messages) full_text = self.get_choice_text(rsp) log_llm_stream(full_text)
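
Usage sketch (editor's addition, not part of the patch series): a minimal example of how the finished provider could be exercised after PATCH 32/32, assuming the LLMConfig fields introduced in these patches (api_type, model, access_key, secret_key, region_name) and the aask() entry point covered by tests/metagpt/provider/test_bedrock_api.py. The model id, region, and credential values below are illustrative placeholders.

import asyncio

from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.bedrock_api import BedrockLLM

# Placeholder credentials; supply real AWS IAM keys and a Bedrock-enabled region.
config = LLMConfig(
    api_type="bedrock",
    model="anthropic.claude-3-haiku-20240307-v1:0",  # any id from SUPPORT_STREAM_MODELS
    access_key="<aws-access-key-id>",
    secret_key="<aws-secret-access-key>",
    region_name="us-east-1",
)

llm = BedrockLLM(config)
# aask() routes through _achat_completion_stream / acompletion as patched above;
# token usage is read from the x-amzn-bedrock-*-token-count response headers.
print(asyncio.run(llm.aask("Hi!", stream=True)))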