diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index ae8c57ec5..41e04ab44 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -32,7 +32,7 @@ class LLMType(Enum): MISTRAL = "mistral" YI = "yi" # lingyiwanwu OPENROUTER = "openrouter" - AMAZON_BEDROCK = "amazon_bedrock" + BEDROCK = "bedrock" def __missing__(self, key): return self.OPENAI diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index dd5b4f89d..1311ccf61 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -17,7 +17,7 @@ from metagpt.provider.spark_api import SparkLLM from metagpt.provider.qianfan_api import QianFanLLM from metagpt.provider.dashscope_api import DashScopeLLM from metagpt.provider.anthropic_api import AnthropicLLM -from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM +from metagpt.provider.bedrock_api import AmazonBedrockLLM __all__ = [ "GeminiLLM", diff --git a/metagpt/provider/bedrock/.gitignore b/metagpt/provider/bedrock/.gitignore deleted file mode 100644 index 971fcecb7..000000000 --- a/metagpt/provider/bedrock/.gitignore +++ /dev/null @@ -1,192 +0,0 @@ -### Python template - -# Byte-compiled / optimized / DLL files -__pycache__ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST -metagpt/tools/schemas/ -examples/data/search_kb/*.json - -# PyInstaller -# Usually these files are written by a python scripts from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ -unittest.txt - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -logs -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# report -allure-report -allure-results - -# idea / vscode / macos -.idea -.DS_Store -.vscode - -key.yaml -/data/ -data.ms -examples/nb/ -examples/default__vector_store.json -examples/docstore.json -examples/graph_store.json -examples/image__vector_store.json -examples/index_store.json -.chroma -*~$* -workspace/* -tmp -metagpt/roles/idea_agent.py -.aider* -*.bak -*.bk - -# output folder -output -tmp.png -.dependencies.json -tests/metagpt/utils/file_repo_git -tests/data/rsp_cache_new.json -*.tmp -*.png -htmlcov -htmlcov.* -cov.xml -*.dot -*.pkl -*.faiss -*-structure.csv -*-structure.json -*.dot -.python-version -# aws access key -config.py \ No newline at end of file diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 01bcaac53..6378939c9 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -76,7 +76,7 @@ class Ai21Provider(BaseBedrockProvider): max_tokens_field_name = "maxTokens" def _get_completion_from_dict(self, rsp_dict: dict) -> str: - return rsp_dict['completions'][0]["data"]["text"] + return rsp_dict["completions"][0]["data"]["text"] class AmazonProvider(BaseBedrockProvider): @@ -92,7 +92,7 @@ class AmazonProvider(BaseBedrockProvider): return body def _get_completion_from_dict(self, rsp_dict: dict) -> str: - return rsp_dict['results'][0]['outputText'] + return rsp_dict["results"][0]["outputText"] def get_choice_text_from_stream(self, event) -> str: rsp_dict = json.loads(event["chunk"]["bytes"]) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock_api.py similarity index 94% rename from metagpt/provider/bedrock/amazon_bedrock_api.py rename to metagpt/provider/bedrock_api.py index 6230cd3f2..ac0f2e505 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -15,7 +15,7 @@ except ImportError: "boto3 not found! please install it by `pip install boto3` ") -@register_provider([LLMType.AMAZON_BEDROCK]) +@register_provider([LLMType.BEDROCK]) class AmazonBedrockLLM(BaseLLM): def __init__(self, config: LLMConfig): self.config = config @@ -36,10 +36,12 @@ class AmazonBedrockLLM(BaseLLM): client = session.client(service_name) return client - def _get_client(self): + @property + def client(self): return self.__client - def _get_provider(self): + @property + def provider(self): return self.__provider def list_models(self): @@ -53,22 +55,21 @@ class AmazonBedrockLLM(BaseLLM): """ client = self.__init_client("bedrock") # only output text-generation models - response = client.list_foundation_models(byOutputModality='TEXT') + response = client.list_foundation_models(byOutputModality="TEXT") summaries = [f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}' for summary in response["modelSummaries"]] logger.info("\n"+"\n".join(summaries)) - def invoke_model(self, request_body) -> dict: + def invoke_model(self, request_body: str) -> dict: response = self.__client.invoke_model( modelId=self.config.model, body=request_body ) response_body = self._get_response_body(response) return response_body - def invoke_model_with_response_stream(self, request_body) -> EventStream: + def invoke_model_with_response_stream(self, request_body: str) -> EventStream: response = self.__client.invoke_model_with_response_stream( - modelId=self.config.model, body=request_body - ) + modelId=self.config.model, body=request_body) return response @property diff --git a/requirements.txt b/requirements.txt index 6c219a9dc..76c158115 100644 --- a/requirements.txt +++ b/requirements.txt @@ -70,3 +70,4 @@ qianfan==0.3.2 dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation gymnasium==0.29.1 +boto3==1.34.92 \ No newline at end of file diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index 8660bc24f..8f2baea10 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ b/tests/metagpt/provider/mock_llm_config.py @@ -62,7 +62,7 @@ mock_llm_config_anthropic = LLMConfig( ) mock_llm_config_bedrock = LLMConfig( - api_type="amazon_bedrock", + api_type="bedrock", model="gpt-100", region_name="somewhere", access_key="123abc", diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py index 893c33704..fb754abf7 100644 --- a/tests/metagpt/provider/req_resp_const.py +++ b/tests/metagpt/provider/req_resp_const.py @@ -160,8 +160,7 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message: ), ContentBlockDeltaEvent( index=0, - delta=TextDelta(text=resp_cont_tmpl.format( - name=name), type="text_delta"), + delta=TextDelta(text=resp_cont_tmpl.format(name=name), type="text_delta"), type="content_block_delta", ), ] @@ -237,5 +236,5 @@ BEDROCK_PROVIDER_RESPONSE_BODY = { "stop_reason": "", "stop_sequence": "", "usage": {"input_tokens": 0, "output_tokens": 0} }, - "amazon": {'inputTextTokenCount': 0, 'results': [{'tokenCount': 0, 'outputText': 'Hello World', 'completionReason': ""}]} -} \ No newline at end of file + "amazon": {"inputTextTokenCount": 0, "results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}]} +} diff --git a/tests/metagpt/provider/test_amazon_bedrock_api.py b/tests/metagpt/provider/test_bedrock_api.py similarity index 86% rename from tests/metagpt/provider/test_amazon_bedrock_api.py rename to tests/metagpt/provider/test_bedrock_api.py index cd13c0b24..7e282db78 100644 --- a/tests/metagpt/provider/test_amazon_bedrock_api.py +++ b/tests/metagpt/provider/test_bedrock_api.py @@ -1,6 +1,6 @@ import pytest import json -from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM +from metagpt.provider.bedrock_api import AmazonBedrockLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY @@ -34,7 +34,7 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict: BEDROCK_PROVIDER_RESPONSE_BODY[provider]) response_body_stream = { - "body": [{'chunk': {'bytes': response_body_bytes}}]} + "body": [{"chunk": {"bytes": response_body_bytes}}]} return response_body_stream @@ -74,13 +74,13 @@ def bedrock_api(request) -> AmazonBedrockLLM: class TestAPI: def test_generate_kwargs(self, bedrock_api: AmazonBedrockLLM): - provider = bedrock_api._get_provider() + provider = bedrock_api.provider assert bedrock_api._generate_kwargs[provider.max_tokens_field_name] <= get_max_tokens( bedrock_api.config.model) def test_get_request_body(self, bedrock_api: AmazonBedrockLLM): """Ensure request body has correct format""" - provider = bedrock_api._get_provider() + provider = bedrock_api.provider request_body = json.loads(provider.get_request_body( messages, bedrock_api._generate_kwargs)) @@ -88,13 +88,13 @@ class TestAPI: bedrock_api.config.model)) def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker): - mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model", + mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model", mock_bedrock_provider_response) assert bedrock_api.completion(messages) == "Hello World" def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker): - mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model", + mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model", mock_bedrock_provider_response) - mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", + mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", mock_bedrock_provider_stream_response) assert bedrock_api._chat_completion_stream(messages) == "Hello World"