mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Merge remote-tracking branch 'origin/main' into fix/cstream_not_work
This commit is contained in:
commit
cda6451b59
25 changed files with 733 additions and 57 deletions
1
.github/workflows/pre-commit.yaml
vendored
1
.github/workflows/pre-commit.yaml
vendored
|
|
@ -11,6 +11,7 @@ on:
|
|||
jobs:
|
||||
pre-commit-check:
|
||||
runs-on: ubuntu-latest
|
||||
environment: pre-commit
|
||||
steps:
|
||||
- name: Checkout Source Code
|
||||
uses: actions/checkout@v2
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ embedding:
|
|||
model: ""
|
||||
api_version: ""
|
||||
embed_batch_size: 100
|
||||
dimensions: # output dimension of embedding model
|
||||
|
||||
repair_llm_output: true # when the output is not a valid json, try to repair it
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
{
|
||||
"executablePath": "/usr/bin/chromium",
|
||||
"args": [
|
||||
"--no-sandbox"
|
||||
]
|
||||
}
|
||||
"executablePath": "/usr/bin/chromium",
|
||||
"args": ["--no-sandbox"]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,13 +18,13 @@ from metagpt.rag.schema import (
|
|||
)
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
LLM_TIP = "If you not sure, just answer I don't know."
|
||||
|
||||
DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
|
||||
QUESTION = "What are key qualities to be a good writer?"
|
||||
QUESTION = f"What are key qualities to be a good writer? {LLM_TIP}"
|
||||
|
||||
TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt"
|
||||
TRAVEL_QUESTION = "What does Bob like?"
|
||||
|
||||
LLM_TIP = "If you not sure, just answer I don't know."
|
||||
TRAVEL_QUESTION = f"What does Bob like? {LLM_TIP}"
|
||||
|
||||
|
||||
class Player(BaseModel):
|
||||
|
|
@ -40,21 +40,21 @@ class Player(BaseModel):
|
|||
|
||||
|
||||
class RAGExample:
|
||||
"""Show how to use RAG.
|
||||
"""Show how to use RAG."""
|
||||
|
||||
Default engine use LLM Reranker, if the answer from the LLM is incorrect, may encounter `IndexError: list index out of range`.
|
||||
"""
|
||||
|
||||
def __init__(self, engine: SimpleEngine = None):
|
||||
def __init__(self, engine: SimpleEngine = None, use_llm_ranker: bool = True):
|
||||
self._engine = engine
|
||||
self._use_llm_ranker = use_llm_ranker
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
if not self._engine:
|
||||
ranker_configs = [LLMRankerConfig()] if self._use_llm_ranker else None
|
||||
|
||||
self._engine = SimpleEngine.from_docs(
|
||||
input_files=[DOC_PATH],
|
||||
retriever_configs=[FAISSRetrieverConfig()],
|
||||
ranker_configs=[LLMRankerConfig()],
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
return self._engine
|
||||
|
||||
|
|
@ -105,7 +105,7 @@ class RAGExample:
|
|||
"""
|
||||
self._print_title("Add Docs")
|
||||
|
||||
travel_question = f"{TRAVEL_QUESTION}{LLM_TIP}"
|
||||
travel_question = f"{TRAVEL_QUESTION}"
|
||||
travel_filepath = TRAVEL_DOC_PATH
|
||||
|
||||
logger.info("[Before add docs]")
|
||||
|
|
@ -240,8 +240,14 @@ class RAGExample:
|
|||
|
||||
|
||||
async def main():
|
||||
"""RAG pipeline."""
|
||||
e = RAGExample()
|
||||
"""RAG pipeline.
|
||||
|
||||
Note:
|
||||
1. If `use_llm_ranker` is True, then it will use LLM Reranker to get better result, but it is not always guaranteed that the output will be parseable for reranking,
|
||||
prefer `gpt-4-turbo`, otherwise might encounter `IndexError: list index out of range` or `ValueError: invalid literal for int() with base 10`.
|
||||
"""
|
||||
e = RAGExample(use_llm_ranker=False)
|
||||
|
||||
await e.run_pipeline()
|
||||
await e.add_docs()
|
||||
await e.add_objects()
|
||||
|
|
|
|||
|
|
@ -20,11 +20,13 @@ class EmbeddingConfig(YamlModel):
|
|||
---------
|
||||
api_type: "openai"
|
||||
api_key: "YOU_API_KEY"
|
||||
dimensions: "YOUR_MODEL_DIMENSIONS"
|
||||
|
||||
api_type: "azure"
|
||||
api_key: "YOU_API_KEY"
|
||||
base_url: "YOU_BASE_URL"
|
||||
api_version: "YOU_API_VERSION"
|
||||
dimensions: "YOUR_MODEL_DIMENSIONS"
|
||||
|
||||
api_type: "gemini"
|
||||
api_key: "YOU_API_KEY"
|
||||
|
|
@ -32,6 +34,7 @@ class EmbeddingConfig(YamlModel):
|
|||
api_type: "ollama"
|
||||
base_url: "YOU_BASE_URL"
|
||||
model: "YOU_MODEL"
|
||||
dimensions: "YOUR_MODEL_DIMENSIONS"
|
||||
"""
|
||||
|
||||
api_type: Optional[EmbeddingType] = None
|
||||
|
|
@ -41,6 +44,7 @@ class EmbeddingConfig(YamlModel):
|
|||
|
||||
model: Optional[str] = None
|
||||
embed_batch_size: Optional[int] = None
|
||||
dimensions: Optional[int] = None # output dimension of embedding model
|
||||
|
||||
@field_validator("api_type", mode="before")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class LLMType(Enum):
|
|||
MISTRAL = "mistral"
|
||||
YI = "yi" # lingyiwanwu
|
||||
OPENROUTER = "openrouter"
|
||||
BEDROCK = "bedrock"
|
||||
|
||||
def __missing__(self, key):
|
||||
return self.OPENAI
|
||||
|
|
@ -73,11 +74,15 @@ class LLMConfig(YamlModel):
|
|||
frequency_penalty: float = 0.0
|
||||
best_of: Optional[int] = None
|
||||
n: Optional[int] = None
|
||||
stream: bool = True
|
||||
logprobs: Optional[bool] = None # https://cookbook.openai.com/examples/using_logprobs
|
||||
stream: bool = True
|
||||
# https://cookbook.openai.com/examples/using_logprobs
|
||||
logprobs: Optional[bool] = None
|
||||
top_logprobs: Optional[int] = None
|
||||
timeout: int = 600
|
||||
|
||||
# For Amazon Bedrock
|
||||
region_name: str = None
|
||||
|
||||
# For Network
|
||||
proxy: Optional[str] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from metagpt.provider.spark_api import SparkLLM
|
|||
from metagpt.provider.qianfan_api import QianFanLLM
|
||||
from metagpt.provider.dashscope_api import DashScopeLLM
|
||||
from metagpt.provider.anthropic_api import AnthropicLLM
|
||||
from metagpt.provider.bedrock_api import BedrockLLM
|
||||
|
||||
__all__ = [
|
||||
"GeminiLLM",
|
||||
|
|
@ -30,4 +31,5 @@ __all__ = [
|
|||
"QianFanLLM",
|
||||
"DashScopeLLM",
|
||||
"AnthropicLLM",
|
||||
"BedrockLLM",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ class BaseLLM(ABC):
|
|||
# image url or image base64
|
||||
url = image if image.startswith("http") else f"data:image/jpeg;base64,{image}"
|
||||
# it can with multiple-image inputs
|
||||
content.append({"type": "image_url", "image_url": url})
|
||||
content.append({"type": "image_url", "image_url": {"url": url}})
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
def _assistant_msg(self, msg: str) -> dict[str, str]:
|
||||
|
|
|
|||
0
metagpt/provider/bedrock/__init__.py
Normal file
0
metagpt/provider/bedrock/__init__.py
Normal file
28
metagpt/provider/bedrock/base_provider.py
Normal file
28
metagpt/provider/bedrock/base_provider.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseBedrockProvider(ABC):
|
||||
# to handle different generation kwargs
|
||||
max_tokens_field_name = "max_tokens"
|
||||
|
||||
@abstractmethod
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
...
|
||||
|
||||
def get_request_body(self, messages: list[dict], const_kwargs, *args, **kwargs) -> str:
|
||||
body = json.dumps({"prompt": self.messages_to_prompt(messages), **const_kwargs})
|
||||
return body
|
||||
|
||||
def get_choice_text(self, response_body: dict) -> str:
|
||||
completions = self._get_completion_from_dict(response_body)
|
||||
return completions
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
completions = self._get_completion_from_dict(rsp_dict)
|
||||
return completions
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]) -> str:
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
|
||||
121
metagpt/provider/bedrock/bedrock_provider.py
Normal file
121
metagpt/provider/bedrock/bedrock_provider.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
import json
|
||||
from typing import Literal
|
||||
|
||||
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
|
||||
from metagpt.provider.bedrock.utils import (
|
||||
messages_to_prompt_llama2,
|
||||
messages_to_prompt_llama3,
|
||||
)
|
||||
|
||||
|
||||
class MistralProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
return messages_to_prompt_llama2(messages)
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["outputs"][0]["text"]
|
||||
|
||||
|
||||
class AnthropicProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
|
||||
body = json.dumps({"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
|
||||
return body
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["content"][0]["text"]
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
# https://docs.anthropic.com/claude/reference/messages-streaming
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
if rsp_dict["type"] == "content_block_delta":
|
||||
completions = rsp_dict["delta"]["text"]
|
||||
return completions
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
class CohereProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["generations"][0]["text"]
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
|
||||
body = json.dumps(
|
||||
{"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs}
|
||||
)
|
||||
return body
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
completions = rsp_dict.get("text", "")
|
||||
return completions
|
||||
|
||||
|
||||
class MetaProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
|
||||
|
||||
max_tokens_field_name = "max_gen_len"
|
||||
|
||||
def __init__(self, llama_version: Literal["llama2", "llama3"]) -> None:
|
||||
self.llama_version = llama_version
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
if self.llama_version == "llama2":
|
||||
return messages_to_prompt_llama2(messages)
|
||||
else:
|
||||
return messages_to_prompt_llama3(messages)
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["generation"]
|
||||
|
||||
|
||||
class Ai21Provider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html
|
||||
|
||||
max_tokens_field_name = "maxTokens"
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["completions"][0]["data"]["text"]
|
||||
|
||||
|
||||
class AmazonProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
|
||||
|
||||
max_tokens_field_name = "maxTokenCount"
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
|
||||
body = json.dumps({"inputText": self.messages_to_prompt(messages), "textGenerationConfig": generate_kwargs})
|
||||
return body
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["results"][0]["outputText"]
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
completions = rsp_dict["outputText"]
|
||||
return completions
|
||||
|
||||
|
||||
PROVIDERS = {
|
||||
"mistral": MistralProvider,
|
||||
"meta": MetaProvider,
|
||||
"ai21": Ai21Provider,
|
||||
"cohere": CohereProvider,
|
||||
"anthropic": AnthropicProvider,
|
||||
"amazon": AmazonProvider,
|
||||
}
|
||||
|
||||
|
||||
def get_provider(model_id: str):
|
||||
provider, model_name = model_id.split(".")[0:2] # meta、mistral……
|
||||
if provider not in PROVIDERS:
|
||||
raise KeyError(f"{provider} is not supported!")
|
||||
if provider == "meta":
|
||||
# distinguish llama2 and llama3
|
||||
return PROVIDERS[provider](model_name[:6])
|
||||
return PROVIDERS[provider]()
|
||||
112
metagpt/provider/bedrock/utils.py
Normal file
112
metagpt/provider/bedrock/utils.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
from metagpt.logs import logger
|
||||
|
||||
# max_tokens for each model
|
||||
NOT_SUUPORT_STREAM_MODELS = {
|
||||
"ai21.j2-grande-instruct": 8000,
|
||||
"ai21.j2-jumbo-instruct": 8000,
|
||||
"ai21.j2-mid": 8000,
|
||||
"ai21.j2-mid-v1": 8000,
|
||||
"ai21.j2-ultra": 8000,
|
||||
"ai21.j2-ultra-v1": 8000,
|
||||
}
|
||||
|
||||
SUPPORT_STREAM_MODELS = {
|
||||
"amazon.titan-tg1-large": 8000,
|
||||
"amazon.titan-text-express-v1": 8000,
|
||||
"amazon.titan-text-express-v1:0:8k": 8000,
|
||||
"amazon.titan-text-lite-v1:0:4k": 4000,
|
||||
"amazon.titan-text-lite-v1": 4000,
|
||||
"anthropic.claude-instant-v1": 100000,
|
||||
"anthropic.claude-instant-v1:2:100k": 100000,
|
||||
"anthropic.claude-v1": 100000,
|
||||
"anthropic.claude-v2": 100000,
|
||||
"anthropic.claude-v2:1": 200000,
|
||||
"anthropic.claude-v2:0:18k": 18000,
|
||||
"anthropic.claude-v2:1:200k": 200000,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
|
||||
# currently (2024-4-29) only available at US West (Oregon) AWS Region.
|
||||
"anthropic.claude-3-opus-20240229-v1:0": 200000,
|
||||
"cohere.command-text-v14": 4000,
|
||||
"cohere.command-text-v14:7:4k": 4000,
|
||||
"cohere.command-light-text-v14": 4000,
|
||||
"cohere.command-light-text-v14:7:4k": 4000,
|
||||
"meta.llama2-13b-chat-v1:0:4k": 4000,
|
||||
"meta.llama2-13b-chat-v1": 2000,
|
||||
"meta.llama2-70b-v1": 4000,
|
||||
"meta.llama2-70b-v1:0:4k": 4000,
|
||||
"meta.llama2-70b-chat-v1": 4000,
|
||||
"meta.llama2-70b-chat-v1:0:4k": 4000,
|
||||
"meta.llama3-8b-instruct-v1:0": 2000,
|
||||
"meta.llama3-70b-instruct-v1:0": 2000,
|
||||
"mistral.mistral-7b-instruct-v0:2": 32000,
|
||||
"mistral.mixtral-8x7b-instruct-v0:1": 32000,
|
||||
"mistral.mistral-large-2402-v1:0": 32000,
|
||||
}
|
||||
|
||||
# TODO:use a more general function for constructing chat templates.
|
||||
|
||||
|
||||
def messages_to_prompt_llama2(messages: list[dict]) -> str:
|
||||
BOS = ("<s>",)
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
|
||||
prompt = f"{BOS}"
|
||||
for message in messages:
|
||||
role = message.get("role", "")
|
||||
content = message.get("content", "")
|
||||
if role == "system":
|
||||
prompt += f"{B_SYS} {content} {E_SYS}"
|
||||
elif role == "user":
|
||||
prompt += f"{B_INST} {content} {E_INST}"
|
||||
elif role == "assistant":
|
||||
prompt += f"{content}"
|
||||
else:
|
||||
logger.warning(f"Unknown role name {role} when formatting messages")
|
||||
prompt += f"{content}"
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def messages_to_prompt_llama3(messages: list[dict]) -> str:
|
||||
BOS = "<|begin_of_text|>"
|
||||
GENERAL_TEMPLATE = "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
|
||||
|
||||
prompt = f"{BOS}"
|
||||
for message in messages:
|
||||
role = message.get("role", "")
|
||||
content = message.get("content", "")
|
||||
prompt += GENERAL_TEMPLATE.format(role=role, content=content)
|
||||
|
||||
if role != "assistant":
|
||||
prompt += "<|start_header_id|>assistant<|end_header_id|>"
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def messages_to_prompt_claude2(messages: list[dict]) -> str:
|
||||
GENERAL_TEMPLATE = "\n\n{role}: {content}"
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
role = message.get("role", "")
|
||||
content = message.get("content", "")
|
||||
prompt += GENERAL_TEMPLATE.format(role=role, content=content)
|
||||
|
||||
if role != "assistant":
|
||||
prompt += "\n\nAssistant:"
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def get_max_tokens(model_id: str) -> int:
|
||||
try:
|
||||
max_tokens = (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
|
||||
except KeyError:
|
||||
logger.warning(f"Couldn't find model:{model_id} , max tokens has been set to 2048")
|
||||
max_tokens = 2048
|
||||
return max_tokens
|
||||
140
metagpt/provider/bedrock_api.py
Normal file
140
metagpt/provider/bedrock_api.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
import json
|
||||
from typing import Literal
|
||||
|
||||
import boto3
|
||||
from botocore.eventstream import EventStream
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.bedrock.bedrock_provider import get_provider
|
||||
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import BEDROCK_TOKEN_COSTS
|
||||
|
||||
|
||||
@register_provider([LLMType.BEDROCK])
|
||||
class BedrockLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self.__client = self.__init_client("bedrock-runtime")
|
||||
self.__provider = get_provider(self.config.model)
|
||||
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
|
||||
logger.warning("Amazon bedrock doesn't support asynchronous now")
|
||||
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
|
||||
logger.warning(f"model {self.config.model} doesn't support streaming output!")
|
||||
|
||||
def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
|
||||
"""initialize boto3 client"""
|
||||
# access key and secret key from https://us-east-1.console.aws.amazon.com/iam
|
||||
self.__credentital_kwargs = {
|
||||
"aws_secret_access_key": self.config.secret_key,
|
||||
"aws_access_key_id": self.config.access_key,
|
||||
"region_name": self.config.region_name,
|
||||
}
|
||||
session = boto3.Session(**self.__credentital_kwargs)
|
||||
client = session.client(service_name)
|
||||
return client
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
return self.__client
|
||||
|
||||
@property
|
||||
def provider(self):
|
||||
return self.__provider
|
||||
|
||||
def list_models(self):
|
||||
"""list all available text-generation models
|
||||
|
||||
```shell
|
||||
ai21.j2-ultra-v1 Support Streaming:False
|
||||
meta.llama3-70b-instruct-v1:0 Support Streaming:True
|
||||
……
|
||||
```
|
||||
"""
|
||||
client = self.__init_client("bedrock")
|
||||
# only output text-generation models
|
||||
response = client.list_foundation_models(byOutputModality="TEXT")
|
||||
summaries = [
|
||||
f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}'
|
||||
for summary in response["modelSummaries"]
|
||||
]
|
||||
logger.info("\n" + "\n".join(summaries))
|
||||
|
||||
def invoke_model(self, request_body: str) -> dict:
|
||||
response = self.__client.invoke_model(modelId=self.config.model, body=request_body)
|
||||
usage = self._get_usage(response)
|
||||
self._update_costs(usage, self.config.model)
|
||||
response_body = self._get_response_body(response)
|
||||
return response_body
|
||||
|
||||
def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
|
||||
response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body)
|
||||
usage = self._get_usage(response)
|
||||
self._update_costs(usage, self.config.model)
|
||||
return response
|
||||
|
||||
@property
|
||||
def _const_kwargs(self) -> dict:
|
||||
model_max_tokens = get_max_tokens(self.config.model)
|
||||
if self.config.max_token > model_max_tokens:
|
||||
max_tokens = model_max_tokens
|
||||
else:
|
||||
max_tokens = self.config.max_token
|
||||
|
||||
return {self.__provider.max_tokens_field_name: max_tokens, "temperature": self.config.temperature}
|
||||
|
||||
# boto3 don't support support asynchronous calls.
|
||||
# for asynchronous version of boto3, check out:
|
||||
# https://aioboto3.readthedocs.io/en/latest/usage.html
|
||||
# However,aioboto3 doesn't support invoke model
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
return self.__provider.get_choice_text(rsp)
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
request_body = self.__provider.get_request_body(messages, self._const_kwargs)
|
||||
response_body = self.invoke_model(request_body)
|
||||
return response_body
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
|
||||
return await self.acompletion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
|
||||
rsp = await self.acompletion(messages)
|
||||
full_text = self.get_choice_text(rsp)
|
||||
log_llm_stream(full_text)
|
||||
return full_text
|
||||
|
||||
request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)
|
||||
|
||||
response = self.invoke_model_with_response_stream(request_body)
|
||||
collected_content = []
|
||||
for event in response["body"]:
|
||||
chunk_text = self.__provider.get_choice_text_from_stream(event)
|
||||
collected_content.append(chunk_text)
|
||||
log_llm_stream(chunk_text)
|
||||
|
||||
log_llm_stream("\n")
|
||||
full_text = ("".join(collected_content)).lstrip()
|
||||
return full_text
|
||||
|
||||
def _get_response_body(self, response) -> dict:
|
||||
response_body = json.loads(response["body"].read())
|
||||
return response_body
|
||||
|
||||
def _get_usage(self, response) -> dict[str, int]:
|
||||
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
|
||||
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
|
||||
completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
|
||||
usage = (
|
||||
{
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
},
|
||||
)
|
||||
return usage
|
||||
|
|
@ -33,6 +33,7 @@ class HumanProvider(BaseLLM):
|
|||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
generator: bool = False,
|
||||
timeout=USE_CONFIG_TIMEOUT,
|
||||
**kwargs
|
||||
) -> str:
|
||||
return self.ask(msg, timeout=self.get_timeout(timeout))
|
||||
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ class OpenAILLM(BaseLLM):
|
|||
log_llm_stream(chunk_message)
|
||||
collected_messages.append(chunk_message)
|
||||
if finish_reason:
|
||||
if hasattr(chunk, "usage"):
|
||||
if hasattr(chunk, "usage") and chunk.usage is not None:
|
||||
# Some services have usage as an attribute of the chunk, such as Fireworks
|
||||
usage = CompletionUsage(**chunk.usage)
|
||||
elif hasattr(chunk.choices[0], "usage"):
|
||||
|
|
|
|||
|
|
@ -40,8 +40,10 @@ class DynamicBM25Retriever(BM25Retriever):
|
|||
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
|
||||
self.bm25 = BM25Okapi(self._corpus)
|
||||
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
if self._index:
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
"""Support persist."""
|
||||
self._index.storage_context.persist(persist_dir)
|
||||
if self._index:
|
||||
self._index.storage_context.persist(persist_dir)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
|||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.interface import RAGObject
|
||||
|
||||
|
||||
|
|
@ -44,7 +45,13 @@ class FAISSRetrieverConfig(IndexRetrieverConfig):
|
|||
@model_validator(mode="after")
|
||||
def check_dimensions(self):
|
||||
if self.dimensions == 0:
|
||||
self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536)
|
||||
self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get(
|
||||
config.embedding.api_type, 1536
|
||||
)
|
||||
if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions:
|
||||
logger.warning(
|
||||
f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.roles.role import Role, RoleReactMode
|
||||
from metagpt.utils.common import any_to_name
|
||||
|
||||
|
||||
|
|
@ -35,17 +35,8 @@ class ProductManager(Role):
|
|||
|
||||
self.set_actions([PrepareDocuments, WritePRD])
|
||||
self._watch([UserRequirement, PrepareDocuments])
|
||||
self.todo_action = any_to_name(PrepareDocuments)
|
||||
|
||||
async def _think(self) -> bool:
|
||||
"""Decide what to do"""
|
||||
if self.git_repo and not self.config.git_reinit:
|
||||
self._set_state(1)
|
||||
else:
|
||||
self._set_state(0)
|
||||
self.config.git_reinit = False
|
||||
self.todo_action = any_to_name(WritePRD)
|
||||
return bool(self.rc.todo)
|
||||
self.rc.react_mode = RoleReactMode.BY_ORDER
|
||||
self.todo_action = any_to_name(WritePRD)
|
||||
|
||||
async def _observe(self, ignore_memory=False) -> int:
|
||||
return await super()._observe(ignore_memory=True)
|
||||
|
|
|
|||
|
|
@ -370,6 +370,12 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
self.recovered = False # avoid max_react_loop out of work
|
||||
return True
|
||||
|
||||
if self.rc.react_mode == RoleReactMode.BY_ORDER:
|
||||
if self.rc.max_react_loop != len(self.actions):
|
||||
self.rc.max_react_loop = len(self.actions)
|
||||
self._set_state(self.rc.state + 1)
|
||||
return self.rc.state >= 0 and self.rc.state < len(self.actions)
|
||||
|
||||
prompt = self._get_prefix()
|
||||
prompt += STATE_TEMPLATE.format(
|
||||
history=self.rc.history,
|
||||
|
|
@ -460,8 +466,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
rsp = Message(content="No actions taken yet", cause_by=Action) # will be overwritten after Role _act
|
||||
while actions_taken < self.rc.max_react_loop:
|
||||
# think
|
||||
await self._think()
|
||||
if self.rc.todo is None:
|
||||
todo = await self._think()
|
||||
if not todo:
|
||||
break
|
||||
# act
|
||||
logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}")
|
||||
|
|
@ -469,15 +475,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
actions_taken += 1
|
||||
return rsp # return output from the last action
|
||||
|
||||
async def _act_by_order(self) -> Message:
|
||||
"""switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ..."""
|
||||
start_idx = self.rc.state if self.rc.state >= 0 else 0 # action to run from recovered state
|
||||
rsp = Message(content="No actions taken yet") # return default message if actions=[]
|
||||
for i in range(start_idx, len(self.states)):
|
||||
self._set_state(i)
|
||||
rsp = await self._act()
|
||||
return rsp # return output from the last action
|
||||
|
||||
async def _plan_and_act(self) -> Message:
|
||||
"""first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... Use llm to come up with the plan dynamically."""
|
||||
|
||||
|
|
@ -518,10 +515,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
|
||||
async def react(self) -> Message:
|
||||
"""Entry to one of three strategies by which Role reacts to the observed Message"""
|
||||
if self.rc.react_mode == RoleReactMode.REACT:
|
||||
if self.rc.react_mode == RoleReactMode.REACT or self.rc.react_mode == RoleReactMode.BY_ORDER:
|
||||
rsp = await self._react()
|
||||
elif self.rc.react_mode == RoleReactMode.BY_ORDER:
|
||||
rsp = await self._act_by_order()
|
||||
elif self.rc.react_mode == RoleReactMode.PLAN_AND_ACT:
|
||||
rsp = await self._plan_and_act()
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -35,8 +35,11 @@ TOKEN_COSTS = {
|
|||
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-turbo": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-turbo-2024-04-09": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator
|
||||
"gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4o": {"prompt": 0.005, "completion": 0.015},
|
||||
"gpt-4o-2024-05-13": {"prompt": 0.005, "completion": 0.015},
|
||||
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
|
||||
"glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
"glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
|
||||
|
|
@ -56,11 +59,14 @@ TOKEN_COSTS = {
|
|||
"claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075},
|
||||
"yi-34b-chat-0205": {"prompt": 0.0003, "completion": 0.0003},
|
||||
"yi-34b-chat-200k": {"prompt": 0.0017, "completion": 0.0017},
|
||||
"yi-large": {"prompt": 0.0028, "completion": 0.0028},
|
||||
"microsoft/wizardlm-2-8x22b": {"prompt": 0.00108, "completion": 0.00108}, # for openrouter, start
|
||||
"meta-llama/llama-3-70b-instruct": {"prompt": 0.008, "completion": 0.008},
|
||||
"llama3-70b-8192": {"prompt": 0.0059, "completion": 0.0079},
|
||||
"openai/gpt-3.5-turbo-0125": {"prompt": 0.0005, "completion": 0.0015},
|
||||
"openai/gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"deepseek-chat": {"prompt": 0.00014, "completion": 0.00028},
|
||||
"deepseek-coder": {"prompt": 0.00014, "completion": 0.00028},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -155,6 +161,9 @@ FIREWORKS_GRADE_TOKEN_COSTS = {
|
|||
|
||||
# https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
|
||||
TOKEN_MAX = {
|
||||
"gpt-4o-2024-05-13": 128000,
|
||||
"gpt-4o": 128000,
|
||||
"gpt-4-turbo-2024-04-09": 128000,
|
||||
"gpt-4-0125-preview": 128000,
|
||||
"gpt-4-turbo-preview": 128000,
|
||||
"gpt-4-1106-preview": 128000,
|
||||
|
|
@ -191,11 +200,61 @@ TOKEN_MAX = {
|
|||
"claude-3-opus-20240229": 200000,
|
||||
"yi-34b-chat-0205": 4000,
|
||||
"yi-34b-chat-200k": 200000,
|
||||
"yi-large": 16385,
|
||||
"microsoft/wizardlm-2-8x22b": 65536,
|
||||
"meta-llama/llama-3-70b-instruct": 8192,
|
||||
"llama3-70b-8192": 8192,
|
||||
"openai/gpt-3.5-turbo-0125": 16385,
|
||||
"openai/gpt-4-turbo-preview": 128000,
|
||||
"deepseek-chat": 32768,
|
||||
"deepseek-coder": 16385,
|
||||
}
|
||||
|
||||
# For Amazon Bedrock US region
|
||||
# See https://aws.amazon.com/cn/bedrock/pricing/
|
||||
|
||||
BEDROCK_TOKEN_COSTS = {
|
||||
"amazon.titan-tg1-large": {"prompt": 0.0008, "completion": 0.0008},
|
||||
"amazon.titan-text-express-v1": {"prompt": 0.0008, "completion": 0.0008},
|
||||
"amazon.titan-text-express-v1:0:8k": {"prompt": 0.0008, "completion": 0.0008},
|
||||
"amazon.titan-text-lite-v1:0:4k": {"prompt": 0.0003, "completion": 0.0004},
|
||||
"amazon.titan-text-lite-v1": {"prompt": 0.0003, "completion": 0.0004},
|
||||
"anthropic.claude-instant-v1": {"prompt": 0.0008, "completion": 0.00024},
|
||||
"anthropic.claude-instant-v1:2:100k": {"prompt": 0.0008, "completion": 0.00024},
|
||||
"anthropic.claude-v1": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-v2": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-v2:1": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-v2:0:18k": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-v2:1:200k": {"prompt": 0.008, "completion": 0.0024},
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:200k": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": {"prompt": 0.00025, "completion": 0.00125},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:48k": {"prompt": 0.00025, "completion": 0.00125},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:200k": {"prompt": 0.00025, "completion": 0.00125},
|
||||
# currently (2024-4-29) only available at US West (Oregon) AWS Region.
|
||||
"anthropic.claude-3-opus-20240229-v1:0": {"prompt": 0.015, "completion": 0.075},
|
||||
"cohere.command-text-v14": {"prompt": 0.0015, "completion": 0.0015},
|
||||
"cohere.command-text-v14:7:4k": {"prompt": 0.0015, "completion": 0.0015},
|
||||
"cohere.command-light-text-v14": {"prompt": 0.0003, "completion": 0.0003},
|
||||
"cohere.command-light-text-v14:7:4k": {"prompt": 0.0003, "completion": 0.0003},
|
||||
"meta.llama2-13b-chat-v1:0:4k": {"prompt": 0.00075, "completion": 0.001},
|
||||
"meta.llama2-13b-chat-v1": {"prompt": 0.00075, "completion": 0.001},
|
||||
"meta.llama2-70b-v1": {"prompt": 0.00195, "completion": 0.00256},
|
||||
"meta.llama2-70b-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256},
|
||||
"meta.llama2-70b-chat-v1": {"prompt": 0.00195, "completion": 0.00256},
|
||||
"meta.llama2-70b-chat-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256},
|
||||
"meta.llama3-8b-instruct-v1:0": {"prompt": 0.0004, "completion": 0.0006},
|
||||
"meta.llama3-70b-instruct-v1:0": {"prompt": 0.00265, "completion": 0.0035},
|
||||
"mistral.mistral-7b-instruct-v0:2": {"prompt": 0.00015, "completion": 0.0002},
|
||||
"mistral.mixtral-8x7b-instruct-v0:1": {"prompt": 0.00045, "completion": 0.0007},
|
||||
"mistral.mistral-large-2402-v1:0": {"prompt": 0.008, "completion": 0.024},
|
||||
"ai21.j2-grande-instruct": {"prompt": 0.0125, "completion": 0.0125},
|
||||
"ai21.j2-jumbo-instruct": {"prompt": 0.0188, "completion": 0.0188},
|
||||
"ai21.j2-mid": {"prompt": 0.0125, "completion": 0.0125},
|
||||
"ai21.j2-mid-v1": {"prompt": 0.0125, "completion": 0.0125},
|
||||
"ai21.j2-ultra": {"prompt": 0.0188, "completion": 0.0188},
|
||||
"ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -224,6 +283,8 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
|
|||
"gpt-4-turbo",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4-1106-vision-preview",
|
||||
"gpt-4o-2024-05-13",
|
||||
"gpt-4o",
|
||||
}:
|
||||
tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|>
|
||||
tokens_per_name = 1
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ PyYAML==6.0.1
|
|||
# sentence_transformers==2.2.2
|
||||
setuptools==65.6.3
|
||||
tenacity==8.2.3
|
||||
tiktoken==0.6.0
|
||||
tiktoken==0.7.0
|
||||
tqdm==4.66.2
|
||||
#unstructured[local-inference]
|
||||
# selenium>4
|
||||
|
|
@ -70,3 +70,4 @@ qianfan==0.3.2
|
|||
dashscope==1.14.1
|
||||
rank-bm25==0.2.2 # for tool recommendation
|
||||
gymnasium==0.29.1
|
||||
boto3==1.34.92
|
||||
|
|
@ -60,3 +60,12 @@ mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model
|
|||
mock_llm_config_anthropic = LLMConfig(
|
||||
api_type="anthropic", api_key="xxx", base_url="https://api.anthropic.com", model="claude-3-opus-20240229"
|
||||
)
|
||||
|
||||
mock_llm_config_bedrock = LLMConfig(
|
||||
api_type="bedrock",
|
||||
model="gpt-100",
|
||||
region_name="somewhere",
|
||||
access_key="123abc",
|
||||
secret_key="123abc",
|
||||
max_token=10000,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -183,3 +183,90 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
|
|||
|
||||
resp = await llm.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_cont
|
||||
|
||||
|
||||
# For Amazon Bedrock
|
||||
# Check the API documentation of each model
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide
|
||||
BEDROCK_PROVIDER_REQUEST_BODY = {
|
||||
"mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0},
|
||||
"meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0},
|
||||
"ai21": {
|
||||
"prompt": "",
|
||||
"temperature": 0.0,
|
||||
"topP": 0.0,
|
||||
"maxTokens": 0,
|
||||
"stopSequences": [],
|
||||
"countPenalty": {"scale": 0.0},
|
||||
"presencePenalty": {"scale": 0.0},
|
||||
"frequencyPenalty": {"scale": 0.0},
|
||||
},
|
||||
"cohere": {
|
||||
"prompt": "",
|
||||
"temperature": 0.0,
|
||||
"p": 0.0,
|
||||
"k": 0.0,
|
||||
"max_tokens": 0,
|
||||
"stop_sequences": [],
|
||||
"return_likelihoods": "NONE",
|
||||
"stream": False,
|
||||
"num_generations": 0,
|
||||
"logit_bias": {},
|
||||
"truncate": "NONE",
|
||||
},
|
||||
"anthropic": {
|
||||
"anthropic_version": "bedrock-2023-05-31",
|
||||
"max_tokens": 0,
|
||||
"system": "",
|
||||
"messages": [{"role": "", "content": ""}],
|
||||
"temperature": 0.0,
|
||||
"top_p": 0.0,
|
||||
"top_k": 0,
|
||||
"stop_sequences": [],
|
||||
},
|
||||
"amazon": {
|
||||
"inputText": "",
|
||||
"textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []},
|
||||
},
|
||||
}
|
||||
|
||||
BEDROCK_PROVIDER_RESPONSE_BODY = {
|
||||
"mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]},
|
||||
"meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""},
|
||||
"ai21": {
|
||||
"id": "",
|
||||
"prompt": {"text": "Hello World", "tokens": []},
|
||||
"completions": [
|
||||
{"data": {"text": "Hello World", "tokens": []}, "finishReason": {"reason": "length", "length": 2}}
|
||||
],
|
||||
},
|
||||
"cohere": {
|
||||
"generations": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"id": "",
|
||||
"text": "Hello World",
|
||||
"likelihood": 0.0,
|
||||
"token_likelihoods": [{"token": 0.0}],
|
||||
"is_finished": True,
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
"id": "",
|
||||
"prompt": "",
|
||||
},
|
||||
"anthropic": {
|
||||
"id": "",
|
||||
"model": "",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Hello World"}],
|
||||
"stop_reason": "",
|
||||
"stop_sequence": "",
|
||||
"usage": {"input_tokens": 0, "output_tokens": 0},
|
||||
},
|
||||
"amazon": {
|
||||
"inputTextTokenCount": 0,
|
||||
"results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}],
|
||||
},
|
||||
}
|
||||
|
|
|
|||
109
tests/metagpt/provider/test_bedrock_api.py
Normal file
109
tests/metagpt/provider/test_bedrock_api.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.bedrock.utils import (
|
||||
NOT_SUUPORT_STREAM_MODELS,
|
||||
SUPPORT_STREAM_MODELS,
|
||||
)
|
||||
from metagpt.provider.bedrock_api import BedrockLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
BEDROCK_PROVIDER_REQUEST_BODY,
|
||||
BEDROCK_PROVIDER_RESPONSE_BODY,
|
||||
)
|
||||
|
||||
# all available model from bedrock
|
||||
models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS
|
||||
messages = [{"role": "user", "content": "Hi!"}]
|
||||
usage = {
|
||||
"prompt_tokens": 1000000,
|
||||
"completion_tokens": 1000000,
|
||||
}
|
||||
|
||||
|
||||
def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict:
|
||||
provider = self.config.model.split(".")[0]
|
||||
self._update_costs(usage, self.config.model)
|
||||
return BEDROCK_PROVIDER_RESPONSE_BODY[provider]
|
||||
|
||||
|
||||
def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
|
||||
# use json object to mock EventStream
|
||||
def dict2bytes(x):
|
||||
return json.dumps(x).encode("utf-8")
|
||||
|
||||
provider = self.config.model.split(".")[0]
|
||||
|
||||
if provider == "amazon":
|
||||
response_body_bytes = dict2bytes({"outputText": "Hello World"})
|
||||
elif provider == "anthropic":
|
||||
response_body_bytes = dict2bytes(
|
||||
{"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello World"}}
|
||||
)
|
||||
elif provider == "cohere":
|
||||
response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"})
|
||||
else:
|
||||
response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
|
||||
|
||||
response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]}
|
||||
self._update_costs(usage, self.config.model)
|
||||
return response_body_stream
|
||||
|
||||
|
||||
def get_bedrock_request_body(model_id) -> dict:
|
||||
provider = model_id.split(".")[0]
|
||||
return BEDROCK_PROVIDER_REQUEST_BODY[provider]
|
||||
|
||||
|
||||
def is_subset(subset, superset) -> bool:
|
||||
"""Ensure all fields in request body are allowed.
|
||||
|
||||
```python
|
||||
subset = {"prompt": "hello","kwargs": {"temperature": 0.9,"p": 0.0}}
|
||||
superset = {"prompt": "hello", "kwargs": {"temperature": 0.0, "top-p": 0.0}}
|
||||
is_subset(subset, superset)
|
||||
```
|
||||
>>>False
|
||||
"""
|
||||
for key, value in subset.items():
|
||||
if key not in superset:
|
||||
return False
|
||||
if isinstance(value, dict):
|
||||
if not isinstance(superset[key], dict):
|
||||
return False
|
||||
if not is_subset(value, superset[key]):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture(scope="class", params=models)
|
||||
def bedrock_api(request) -> BedrockLLM:
|
||||
model_id = request.param
|
||||
mock_llm_config_bedrock.model = model_id
|
||||
api = BedrockLLM(mock_llm_config_bedrock)
|
||||
return api
|
||||
|
||||
|
||||
class TestBedrockAPI:
|
||||
def _patch_invoke_model(self, mocker):
|
||||
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_invoke_model)
|
||||
|
||||
def _patch_invoke_model_stream(self, mocker):
|
||||
mocker.patch(
|
||||
"metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
|
||||
mock_invoke_model_stream,
|
||||
)
|
||||
|
||||
def test_get_request_body(self, bedrock_api: BedrockLLM):
|
||||
"""Ensure request body has correct format"""
|
||||
provider = bedrock_api.provider
|
||||
request_body = json.loads(provider.get_request_body(messages, bedrock_api._const_kwargs))
|
||||
assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model(mocker)
|
||||
self._patch_invoke_model_stream(mocker)
|
||||
assert await bedrock_api.aask(messages, stream=False) == "Hello World"
|
||||
assert await bedrock_api.aask(messages, stream=True) == "Hello World"
|
||||
|
|
@ -10,7 +10,6 @@ import json
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.const import REQUIREMENT_FILENAME
|
||||
from metagpt.context import Context
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -30,11 +29,7 @@ async def test_product_manager(new_filename):
|
|||
rsp = await product_manager.run(MockMessages.req)
|
||||
assert context.git_repo
|
||||
assert context.repo
|
||||
assert rsp.cause_by == any_to_str(PrepareDocuments)
|
||||
assert REQUIREMENT_FILENAME in context.repo.docs.changed_files
|
||||
|
||||
# write prd
|
||||
rsp = await product_manager.run(rsp)
|
||||
assert rsp.cause_by == any_to_str(WritePRD)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue