mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-25 00:36:55 +02:00
resolve problems
This commit is contained in:
parent
98cb452911
commit
79251cd3cd
9 changed files with 25 additions and 216 deletions
|
|
@ -32,7 +32,7 @@ class LLMType(Enum):
|
|||
MISTRAL = "mistral"
|
||||
YI = "yi" # lingyiwanwu
|
||||
OPENROUTER = "openrouter"
|
||||
AMAZON_BEDROCK = "amazon_bedrock"
|
||||
BEDROCK = "bedrock"
|
||||
|
||||
def __missing__(self, key):
|
||||
return self.OPENAI
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from metagpt.provider.spark_api import SparkLLM
|
|||
from metagpt.provider.qianfan_api import QianFanLLM
|
||||
from metagpt.provider.dashscope_api import DashScopeLLM
|
||||
from metagpt.provider.anthropic_api import AnthropicLLM
|
||||
from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM
|
||||
from metagpt.provider.bedrock_api import AmazonBedrockLLM
|
||||
|
||||
__all__ = [
|
||||
"GeminiLLM",
|
||||
|
|
|
|||
192
metagpt/provider/bedrock/.gitignore
vendored
192
metagpt/provider/bedrock/.gitignore
vendored
|
|
@ -1,192 +0,0 @@
|
|||
### Python template
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
metagpt/tools/schemas/
|
||||
examples/data/search_kb/*.json
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python scripts from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
unittest.txt
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
logs
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# report
|
||||
allure-report
|
||||
allure-results
|
||||
|
||||
# idea / vscode / macos
|
||||
.idea
|
||||
.DS_Store
|
||||
.vscode
|
||||
|
||||
key.yaml
|
||||
/data/
|
||||
data.ms
|
||||
examples/nb/
|
||||
examples/default__vector_store.json
|
||||
examples/docstore.json
|
||||
examples/graph_store.json
|
||||
examples/image__vector_store.json
|
||||
examples/index_store.json
|
||||
.chroma
|
||||
*~$*
|
||||
workspace/*
|
||||
tmp
|
||||
metagpt/roles/idea_agent.py
|
||||
.aider*
|
||||
*.bak
|
||||
*.bk
|
||||
|
||||
# output folder
|
||||
output
|
||||
tmp.png
|
||||
.dependencies.json
|
||||
tests/metagpt/utils/file_repo_git
|
||||
tests/data/rsp_cache_new.json
|
||||
*.tmp
|
||||
*.png
|
||||
htmlcov
|
||||
htmlcov.*
|
||||
cov.xml
|
||||
*.dot
|
||||
*.pkl
|
||||
*.faiss
|
||||
*-structure.csv
|
||||
*-structure.json
|
||||
*.dot
|
||||
.python-version
|
||||
# aws access key
|
||||
config.py
|
||||
|
|
@ -76,7 +76,7 @@ class Ai21Provider(BaseBedrockProvider):
|
|||
max_tokens_field_name = "maxTokens"
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict['completions'][0]["data"]["text"]
|
||||
return rsp_dict["completions"][0]["data"]["text"]
|
||||
|
||||
|
||||
class AmazonProvider(BaseBedrockProvider):
|
||||
|
|
@ -92,7 +92,7 @@ class AmazonProvider(BaseBedrockProvider):
|
|||
return body
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict['results'][0]['outputText']
|
||||
return rsp_dict["results"][0]["outputText"]
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ except ImportError:
|
|||
"boto3 not found! please install it by `pip install boto3` ")
|
||||
|
||||
|
||||
@register_provider([LLMType.AMAZON_BEDROCK])
|
||||
@register_provider([LLMType.BEDROCK])
|
||||
class AmazonBedrockLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
|
|
@ -36,10 +36,12 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
client = session.client(service_name)
|
||||
return client
|
||||
|
||||
def _get_client(self):
|
||||
@property
|
||||
def client(self):
|
||||
return self.__client
|
||||
|
||||
def _get_provider(self):
|
||||
@property
|
||||
def provider(self):
|
||||
return self.__provider
|
||||
|
||||
def list_models(self):
|
||||
|
|
@ -53,22 +55,21 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
"""
|
||||
client = self.__init_client("bedrock")
|
||||
# only output text-generation models
|
||||
response = client.list_foundation_models(byOutputModality='TEXT')
|
||||
response = client.list_foundation_models(byOutputModality="TEXT")
|
||||
summaries = [f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}'
|
||||
for summary in response["modelSummaries"]]
|
||||
logger.info("\n"+"\n".join(summaries))
|
||||
|
||||
def invoke_model(self, request_body) -> dict:
|
||||
def invoke_model(self, request_body: str) -> dict:
|
||||
response = self.__client.invoke_model(
|
||||
modelId=self.config.model, body=request_body
|
||||
)
|
||||
response_body = self._get_response_body(response)
|
||||
return response_body
|
||||
|
||||
def invoke_model_with_response_stream(self, request_body) -> EventStream:
|
||||
def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
|
||||
response = self.__client.invoke_model_with_response_stream(
|
||||
modelId=self.config.model, body=request_body
|
||||
)
|
||||
modelId=self.config.model, body=request_body)
|
||||
return response
|
||||
|
||||
@property
|
||||
|
|
@ -70,3 +70,4 @@ qianfan==0.3.2
|
|||
dashscope==1.14.1
|
||||
rank-bm25==0.2.2 # for tool recommendation
|
||||
gymnasium==0.29.1
|
||||
boto3==1.34.92
|
||||
|
|
@ -62,7 +62,7 @@ mock_llm_config_anthropic = LLMConfig(
|
|||
)
|
||||
|
||||
mock_llm_config_bedrock = LLMConfig(
|
||||
api_type="amazon_bedrock",
|
||||
api_type="bedrock",
|
||||
model="gpt-100",
|
||||
region_name="somewhere",
|
||||
access_key="123abc",
|
||||
|
|
|
|||
|
|
@ -160,8 +160,7 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message:
|
|||
),
|
||||
ContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=TextDelta(text=resp_cont_tmpl.format(
|
||||
name=name), type="text_delta"),
|
||||
delta=TextDelta(text=resp_cont_tmpl.format(name=name), type="text_delta"),
|
||||
type="content_block_delta",
|
||||
),
|
||||
]
|
||||
|
|
@ -237,5 +236,5 @@ BEDROCK_PROVIDER_RESPONSE_BODY = {
|
|||
"stop_reason": "", "stop_sequence": "", "usage": {"input_tokens": 0, "output_tokens": 0}
|
||||
},
|
||||
|
||||
"amazon": {'inputTextTokenCount': 0, 'results': [{'tokenCount': 0, 'outputText': 'Hello World', 'completionReason': ""}]}
|
||||
}
|
||||
"amazon": {"inputTextTokenCount": 0, "results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}]}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
import json
|
||||
from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM
|
||||
from metagpt.provider.bedrock_api import AmazonBedrockLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
|
||||
from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS
|
||||
from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY
|
||||
|
|
@ -34,7 +34,7 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict:
|
|||
BEDROCK_PROVIDER_RESPONSE_BODY[provider])
|
||||
|
||||
response_body_stream = {
|
||||
"body": [{'chunk': {'bytes': response_body_bytes}}]}
|
||||
"body": [{"chunk": {"bytes": response_body_bytes}}]}
|
||||
return response_body_stream
|
||||
|
||||
|
||||
|
|
@ -74,13 +74,13 @@ def bedrock_api(request) -> AmazonBedrockLLM:
|
|||
|
||||
class TestAPI:
|
||||
def test_generate_kwargs(self, bedrock_api: AmazonBedrockLLM):
|
||||
provider = bedrock_api._get_provider()
|
||||
provider = bedrock_api.provider
|
||||
assert bedrock_api._generate_kwargs[provider.max_tokens_field_name] <= get_max_tokens(
|
||||
bedrock_api.config.model)
|
||||
|
||||
def test_get_request_body(self, bedrock_api: AmazonBedrockLLM):
|
||||
"""Ensure request body has correct format"""
|
||||
provider = bedrock_api._get_provider()
|
||||
provider = bedrock_api.provider
|
||||
request_body = json.loads(provider.get_request_body(
|
||||
messages, bedrock_api._generate_kwargs))
|
||||
|
||||
|
|
@ -88,13 +88,13 @@ class TestAPI:
|
|||
bedrock_api.config.model))
|
||||
|
||||
def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
|
||||
mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model",
|
||||
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model",
|
||||
mock_bedrock_provider_response)
|
||||
assert bedrock_api.completion(messages) == "Hello World"
|
||||
|
||||
def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
|
||||
mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model",
|
||||
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model",
|
||||
mock_bedrock_provider_response)
|
||||
mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream",
|
||||
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream",
|
||||
mock_bedrock_provider_stream_response)
|
||||
assert bedrock_api._chat_completion_stream(messages) == "Hello World"
|
||||
Loading…
Add table
Add a link
Reference in a new issue