resolve problems

This commit is contained in:
usamimeri_renko 2024-04-28 19:18:12 +08:00
parent 98cb452911
commit 79251cd3cd
9 changed files with 25 additions and 216 deletions

View file

@ -32,7 +32,7 @@ class LLMType(Enum):
MISTRAL = "mistral"
YI = "yi" # lingyiwanwu
OPENROUTER = "openrouter"
AMAZON_BEDROCK = "amazon_bedrock"
BEDROCK = "bedrock"
def __missing__(self, key):
return self.OPENAI

View file

@ -17,7 +17,7 @@ from metagpt.provider.spark_api import SparkLLM
from metagpt.provider.qianfan_api import QianFanLLM
from metagpt.provider.dashscope_api import DashScopeLLM
from metagpt.provider.anthropic_api import AnthropicLLM
from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM
from metagpt.provider.bedrock_api import AmazonBedrockLLM
__all__ = [
"GeminiLLM",

View file

@ -1,192 +0,0 @@
### Python template
# Byte-compiled / optimized / DLL files
__pycache__
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
metagpt/tools/schemas/
examples/data/search_kb/*.json
# PyInstaller
# Usually these files are written by a python scripts from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
unittest.txt
# Translations
*.mo
*.pot
# Django stuff:
*.log
logs
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# report
allure-report
allure-results
# idea / vscode / macos
.idea
.DS_Store
.vscode
key.yaml
/data/
data.ms
examples/nb/
examples/default__vector_store.json
examples/docstore.json
examples/graph_store.json
examples/image__vector_store.json
examples/index_store.json
.chroma
*~$*
workspace/*
tmp
metagpt/roles/idea_agent.py
.aider*
*.bak
*.bk
# output folder
output
tmp.png
.dependencies.json
tests/metagpt/utils/file_repo_git
tests/data/rsp_cache_new.json
*.tmp
*.png
htmlcov
htmlcov.*
cov.xml
*.dot
*.pkl
*.faiss
*-structure.csv
*-structure.json
*.dot
.python-version
# aws access key
config.py

View file

@ -76,7 +76,7 @@ class Ai21Provider(BaseBedrockProvider):
max_tokens_field_name = "maxTokens"
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict['completions'][0]["data"]["text"]
return rsp_dict["completions"][0]["data"]["text"]
class AmazonProvider(BaseBedrockProvider):
@ -92,7 +92,7 @@ class AmazonProvider(BaseBedrockProvider):
return body
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict['results'][0]['outputText']
return rsp_dict["results"][0]["outputText"]
def get_choice_text_from_stream(self, event) -> str:
rsp_dict = json.loads(event["chunk"]["bytes"])

View file

@ -15,7 +15,7 @@ except ImportError:
"boto3 not found! please install it by `pip install boto3` ")
@register_provider([LLMType.AMAZON_BEDROCK])
@register_provider([LLMType.BEDROCK])
class AmazonBedrockLLM(BaseLLM):
def __init__(self, config: LLMConfig):
self.config = config
@ -36,10 +36,12 @@ class AmazonBedrockLLM(BaseLLM):
client = session.client(service_name)
return client
def _get_client(self):
@property
def client(self):
return self.__client
def _get_provider(self):
@property
def provider(self):
return self.__provider
def list_models(self):
@ -53,22 +55,21 @@ class AmazonBedrockLLM(BaseLLM):
"""
client = self.__init_client("bedrock")
# only output text-generation models
response = client.list_foundation_models(byOutputModality='TEXT')
response = client.list_foundation_models(byOutputModality="TEXT")
summaries = [f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}'
for summary in response["modelSummaries"]]
logger.info("\n"+"\n".join(summaries))
def invoke_model(self, request_body) -> dict:
def invoke_model(self, request_body: str) -> dict:
response = self.__client.invoke_model(
modelId=self.config.model, body=request_body
)
response_body = self._get_response_body(response)
return response_body
def invoke_model_with_response_stream(self, request_body) -> EventStream:
def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
response = self.__client.invoke_model_with_response_stream(
modelId=self.config.model, body=request_body
)
modelId=self.config.model, body=request_body)
return response
@property

View file

@ -70,3 +70,4 @@ qianfan==0.3.2
dashscope==1.14.1
rank-bm25==0.2.2 # for tool recommendation
gymnasium==0.29.1
boto3==1.34.92

View file

@ -62,7 +62,7 @@ mock_llm_config_anthropic = LLMConfig(
)
mock_llm_config_bedrock = LLMConfig(
api_type="amazon_bedrock",
api_type="bedrock",
model="gpt-100",
region_name="somewhere",
access_key="123abc",

View file

@ -160,8 +160,7 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message:
),
ContentBlockDeltaEvent(
index=0,
delta=TextDelta(text=resp_cont_tmpl.format(
name=name), type="text_delta"),
delta=TextDelta(text=resp_cont_tmpl.format(name=name), type="text_delta"),
type="content_block_delta",
),
]
@ -237,5 +236,5 @@ BEDROCK_PROVIDER_RESPONSE_BODY = {
"stop_reason": "", "stop_sequence": "", "usage": {"input_tokens": 0, "output_tokens": 0}
},
"amazon": {'inputTextTokenCount': 0, 'results': [{'tokenCount': 0, 'outputText': 'Hello World', 'completionReason': ""}]}
}
"amazon": {"inputTextTokenCount": 0, "results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}]}
}

View file

@ -1,6 +1,6 @@
import pytest
import json
from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM
from metagpt.provider.bedrock_api import AmazonBedrockLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS
from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY
@ -34,7 +34,7 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict:
BEDROCK_PROVIDER_RESPONSE_BODY[provider])
response_body_stream = {
"body": [{'chunk': {'bytes': response_body_bytes}}]}
"body": [{"chunk": {"bytes": response_body_bytes}}]}
return response_body_stream
@ -74,13 +74,13 @@ def bedrock_api(request) -> AmazonBedrockLLM:
class TestAPI:
def test_generate_kwargs(self, bedrock_api: AmazonBedrockLLM):
provider = bedrock_api._get_provider()
provider = bedrock_api.provider
assert bedrock_api._generate_kwargs[provider.max_tokens_field_name] <= get_max_tokens(
bedrock_api.config.model)
def test_get_request_body(self, bedrock_api: AmazonBedrockLLM):
"""Ensure request body has correct format"""
provider = bedrock_api._get_provider()
provider = bedrock_api.provider
request_body = json.loads(provider.get_request_body(
messages, bedrock_api._generate_kwargs))
@ -88,13 +88,13 @@ class TestAPI:
bedrock_api.config.model))
def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model",
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model",
mock_bedrock_provider_response)
assert bedrock_api.completion(messages) == "Hello World"
def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model",
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model",
mock_bedrock_provider_response)
mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream",
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream",
mock_bedrock_provider_stream_response)
assert bedrock_api._chat_completion_stream(messages) == "Hello World"