resolve problems

This commit is contained in:
usamimeri_renko 2024-04-28 19:18:12 +08:00
parent 98cb452911
commit 79251cd3cd
9 changed files with 25 additions and 216 deletions

View file

@ -32,7 +32,7 @@ class LLMType(Enum):
MISTRAL = "mistral"
YI = "yi" # lingyiwanwu
OPENROUTER = "openrouter"
AMAZON_BEDROCK = "amazon_bedrock"
BEDROCK = "bedrock"
def __missing__(self, key):
return self.OPENAI

View file

@ -17,7 +17,7 @@ from metagpt.provider.spark_api import SparkLLM
from metagpt.provider.qianfan_api import QianFanLLM
from metagpt.provider.dashscope_api import DashScopeLLM
from metagpt.provider.anthropic_api import AnthropicLLM
from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM
from metagpt.provider.bedrock_api import AmazonBedrockLLM
__all__ = [
"GeminiLLM",

View file

@ -1,192 +0,0 @@
### Python template
# Byte-compiled / optimized / DLL files
__pycache__
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
metagpt/tools/schemas/
examples/data/search_kb/*.json
# PyInstaller
# Usually these files are written by a python scripts from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
unittest.txt
# Translations
*.mo
*.pot
# Django stuff:
*.log
logs
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# report
allure-report
allure-results
# idea / vscode / macos
.idea
.DS_Store
.vscode
key.yaml
/data/
data.ms
examples/nb/
examples/default__vector_store.json
examples/docstore.json
examples/graph_store.json
examples/image__vector_store.json
examples/index_store.json
.chroma
*~$*
workspace/*
tmp
metagpt/roles/idea_agent.py
.aider*
*.bak
*.bk
# output folder
output
tmp.png
.dependencies.json
tests/metagpt/utils/file_repo_git
tests/data/rsp_cache_new.json
*.tmp
*.png
htmlcov
htmlcov.*
cov.xml
*.dot
*.pkl
*.faiss
*-structure.csv
*-structure.json
*.dot
.python-version
# aws access key
config.py

View file

@ -76,7 +76,7 @@ class Ai21Provider(BaseBedrockProvider):
max_tokens_field_name = "maxTokens"
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict['completions'][0]["data"]["text"]
return rsp_dict["completions"][0]["data"]["text"]
class AmazonProvider(BaseBedrockProvider):
@ -92,7 +92,7 @@ class AmazonProvider(BaseBedrockProvider):
return body
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict['results'][0]['outputText']
return rsp_dict["results"][0]["outputText"]
def get_choice_text_from_stream(self, event) -> str:
rsp_dict = json.loads(event["chunk"]["bytes"])

View file

@ -15,7 +15,7 @@ except ImportError:
"boto3 not found! please install it by `pip install boto3` ")
@register_provider([LLMType.AMAZON_BEDROCK])
@register_provider([LLMType.BEDROCK])
class AmazonBedrockLLM(BaseLLM):
def __init__(self, config: LLMConfig):
self.config = config
@ -36,10 +36,12 @@ class AmazonBedrockLLM(BaseLLM):
client = session.client(service_name)
return client
def _get_client(self):
@property
def client(self):
return self.__client
def _get_provider(self):
@property
def provider(self):
return self.__provider
def list_models(self):
@ -53,22 +55,21 @@ class AmazonBedrockLLM(BaseLLM):
"""
client = self.__init_client("bedrock")
# only output text-generation models
response = client.list_foundation_models(byOutputModality='TEXT')
response = client.list_foundation_models(byOutputModality="TEXT")
summaries = [f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}'
for summary in response["modelSummaries"]]
logger.info("\n"+"\n".join(summaries))
def invoke_model(self, request_body) -> dict:
def invoke_model(self, request_body: str) -> dict:
response = self.__client.invoke_model(
modelId=self.config.model, body=request_body
)
response_body = self._get_response_body(response)
return response_body
def invoke_model_with_response_stream(self, request_body) -> EventStream:
def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
response = self.__client.invoke_model_with_response_stream(
modelId=self.config.model, body=request_body
)
modelId=self.config.model, body=request_body)
return response
@property