diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index ed4bbb144..d350a87f1 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -11,6 +11,7 @@ on: jobs: pre-commit-check: runs-on: ubuntu-latest + environment: pre-commit steps: - name: Checkout Source Code uses: actions/checkout@v2 diff --git a/config/config2.example.yaml b/config/config2.example.yaml index 3249f5ae3..e57ec3ee8 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -18,6 +18,7 @@ embedding: model: "" api_version: "" embed_batch_size: 100 + dimensions: # output dimension of embedding model repair_llm_output: true # when the output is not a valid json, try to repair it diff --git a/config/puppeteer-config.json b/config/puppeteer-config.json index 7b2851c29..b74a514e7 100644 --- a/config/puppeteer-config.json +++ b/config/puppeteer-config.json @@ -1,6 +1,4 @@ { - "executablePath": "/usr/bin/chromium", - "args": [ - "--no-sandbox" - ] -} \ No newline at end of file + "executablePath": "/usr/bin/chromium", + "args": ["--no-sandbox"] +} diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 7dbca35a6..5b716ce03 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -18,13 +18,13 @@ from metagpt.rag.schema import ( ) from metagpt.utils.exceptions import handle_exception +LLM_TIP = "If you not sure, just answer I don't know." + DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt" -QUESTION = "What are key qualities to be a good writer?" +QUESTION = f"What are key qualities to be a good writer? {LLM_TIP}" TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt" -TRAVEL_QUESTION = "What does Bob like?" - -LLM_TIP = "If you not sure, just answer I don't know." +TRAVEL_QUESTION = f"What does Bob like? {LLM_TIP}" class Player(BaseModel): @@ -40,21 +40,21 @@ class Player(BaseModel): class RAGExample: - """Show how to use RAG. + """Show how to use RAG.""" - Default engine use LLM Reranker, if the answer from the LLM is incorrect, may encounter `IndexError: list index out of range`. - """ - - def __init__(self, engine: SimpleEngine = None): + def __init__(self, engine: SimpleEngine = None, use_llm_ranker: bool = True): self._engine = engine + self._use_llm_ranker = use_llm_ranker @property def engine(self): if not self._engine: + ranker_configs = [LLMRankerConfig()] if self._use_llm_ranker else None + self._engine = SimpleEngine.from_docs( input_files=[DOC_PATH], retriever_configs=[FAISSRetrieverConfig()], - ranker_configs=[LLMRankerConfig()], + ranker_configs=ranker_configs, ) return self._engine @@ -105,7 +105,7 @@ class RAGExample: """ self._print_title("Add Docs") - travel_question = f"{TRAVEL_QUESTION}{LLM_TIP}" + travel_question = f"{TRAVEL_QUESTION}" travel_filepath = TRAVEL_DOC_PATH logger.info("[Before add docs]") @@ -240,8 +240,14 @@ class RAGExample: async def main(): - """RAG pipeline.""" - e = RAGExample() + """RAG pipeline. + + Note: + 1. If `use_llm_ranker` is True, then it will use LLM Reranker to get better result, but it is not always guaranteed that the output will be parseable for reranking, + prefer `gpt-4-turbo`, otherwise might encounter `IndexError: list index out of range` or `ValueError: invalid literal for int() with base 10`. + """ + e = RAGExample(use_llm_ranker=False) + await e.run_pipeline() await e.add_docs() await e.add_objects() diff --git a/metagpt/configs/embedding_config.py b/metagpt/configs/embedding_config.py index 20de47999..f9b41b9dc 100644 --- a/metagpt/configs/embedding_config.py +++ b/metagpt/configs/embedding_config.py @@ -20,11 +20,13 @@ class EmbeddingConfig(YamlModel): --------- api_type: "openai" api_key: "YOU_API_KEY" + dimensions: "YOUR_MODEL_DIMENSIONS" api_type: "azure" api_key: "YOU_API_KEY" base_url: "YOU_BASE_URL" api_version: "YOU_API_VERSION" + dimensions: "YOUR_MODEL_DIMENSIONS" api_type: "gemini" api_key: "YOU_API_KEY" @@ -32,6 +34,7 @@ class EmbeddingConfig(YamlModel): api_type: "ollama" base_url: "YOU_BASE_URL" model: "YOU_MODEL" + dimensions: "YOUR_MODEL_DIMENSIONS" """ api_type: Optional[EmbeddingType] = None @@ -41,6 +44,7 @@ class EmbeddingConfig(YamlModel): model: Optional[str] = None embed_batch_size: Optional[int] = None + dimensions: Optional[int] = None # output dimension of embedding model @field_validator("api_type", mode="before") @classmethod diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 3d6056aae..dbf04dac6 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -32,6 +32,7 @@ class LLMType(Enum): MISTRAL = "mistral" YI = "yi" # lingyiwanwu OPENROUTER = "openrouter" + BEDROCK = "bedrock" def __missing__(self, key): return self.OPENAI @@ -73,11 +74,15 @@ class LLMConfig(YamlModel): frequency_penalty: float = 0.0 best_of: Optional[int] = None n: Optional[int] = None - stream: bool = True - logprobs: Optional[bool] = None # https://cookbook.openai.com/examples/using_logprobs + stream: bool = True + # https://cookbook.openai.com/examples/using_logprobs + logprobs: Optional[bool] = None top_logprobs: Optional[int] = None timeout: int = 600 + # For Amazon Bedrock + region_name: str = None + # For Network proxy: Optional[str] = None diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 14d5e7682..fcb5fa32a 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -17,6 +17,7 @@ from metagpt.provider.spark_api import SparkLLM from metagpt.provider.qianfan_api import QianFanLLM from metagpt.provider.dashscope_api import DashScopeLLM from metagpt.provider.anthropic_api import AnthropicLLM +from metagpt.provider.bedrock_api import BedrockLLM __all__ = [ "GeminiLLM", @@ -30,4 +31,5 @@ __all__ = [ "QianFanLLM", "DashScopeLLM", "AnthropicLLM", + "BedrockLLM", ] diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index aae00f60d..a95e8dbd3 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -65,7 +65,7 @@ class BaseLLM(ABC): # image url or image base64 url = image if image.startswith("http") else f"data:image/jpeg;base64,{image}" # it can with multiple-image inputs - content.append({"type": "image_url", "image_url": url}) + content.append({"type": "image_url", "image_url": {"url": url}}) return {"role": "user", "content": content} def _assistant_msg(self, msg: str) -> dict[str, str]: diff --git a/metagpt/provider/bedrock/__init__.py b/metagpt/provider/bedrock/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py new file mode 100644 index 000000000..0d13ae938 --- /dev/null +++ b/metagpt/provider/bedrock/base_provider.py @@ -0,0 +1,28 @@ +import json +from abc import ABC, abstractmethod + + +class BaseBedrockProvider(ABC): + # to handle different generation kwargs + max_tokens_field_name = "max_tokens" + + @abstractmethod + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + ... + + def get_request_body(self, messages: list[dict], const_kwargs, *args, **kwargs) -> str: + body = json.dumps({"prompt": self.messages_to_prompt(messages), **const_kwargs}) + return body + + def get_choice_text(self, response_body: dict) -> str: + completions = self._get_completion_from_dict(response_body) + return completions + + def get_choice_text_from_stream(self, event) -> str: + rsp_dict = json.loads(event["chunk"]["bytes"]) + completions = self._get_completion_from_dict(rsp_dict) + return completions + + def messages_to_prompt(self, messages: list[dict]) -> str: + """[{"role": "user", "content": msg}] to user: etc.""" + return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py new file mode 100644 index 000000000..ff1d88a47 --- /dev/null +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -0,0 +1,121 @@ +import json +from typing import Literal + +from metagpt.provider.bedrock.base_provider import BaseBedrockProvider +from metagpt.provider.bedrock.utils import ( + messages_to_prompt_llama2, + messages_to_prompt_llama3, +) + + +class MistralProvider(BaseBedrockProvider): + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html + + def messages_to_prompt(self, messages: list[dict]): + return messages_to_prompt_llama2(messages) + + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict["outputs"][0]["text"] + + +class AnthropicProvider(BaseBedrockProvider): + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): + body = json.dumps({"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) + return body + + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict["content"][0]["text"] + + def get_choice_text_from_stream(self, event) -> str: + # https://docs.anthropic.com/claude/reference/messages-streaming + rsp_dict = json.loads(event["chunk"]["bytes"]) + if rsp_dict["type"] == "content_block_delta": + completions = rsp_dict["delta"]["text"] + return completions + else: + return "" + + +class CohereProvider(BaseBedrockProvider): + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html + + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict["generations"][0]["text"] + + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): + body = json.dumps( + {"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs} + ) + return body + + def get_choice_text_from_stream(self, event) -> str: + rsp_dict = json.loads(event["chunk"]["bytes"]) + completions = rsp_dict.get("text", "") + return completions + + +class MetaProvider(BaseBedrockProvider): + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html + + max_tokens_field_name = "max_gen_len" + + def __init__(self, llama_version: Literal["llama2", "llama3"]) -> None: + self.llama_version = llama_version + + def messages_to_prompt(self, messages: list[dict]): + if self.llama_version == "llama2": + return messages_to_prompt_llama2(messages) + else: + return messages_to_prompt_llama3(messages) + + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict["generation"] + + +class Ai21Provider(BaseBedrockProvider): + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html + + max_tokens_field_name = "maxTokens" + + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict["completions"][0]["data"]["text"] + + +class AmazonProvider(BaseBedrockProvider): + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html + + max_tokens_field_name = "maxTokenCount" + + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): + body = json.dumps({"inputText": self.messages_to_prompt(messages), "textGenerationConfig": generate_kwargs}) + return body + + def _get_completion_from_dict(self, rsp_dict: dict) -> str: + return rsp_dict["results"][0]["outputText"] + + def get_choice_text_from_stream(self, event) -> str: + rsp_dict = json.loads(event["chunk"]["bytes"]) + completions = rsp_dict["outputText"] + return completions + + +PROVIDERS = { + "mistral": MistralProvider, + "meta": MetaProvider, + "ai21": Ai21Provider, + "cohere": CohereProvider, + "anthropic": AnthropicProvider, + "amazon": AmazonProvider, +} + + +def get_provider(model_id: str): + provider, model_name = model_id.split(".")[0:2] # meta、mistral…… + if provider not in PROVIDERS: + raise KeyError(f"{provider} is not supported!") + if provider == "meta": + # distinguish llama2 and llama3 + return PROVIDERS[provider](model_name[:6]) + return PROVIDERS[provider]() diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py new file mode 100644 index 000000000..ee31da1b9 --- /dev/null +++ b/metagpt/provider/bedrock/utils.py @@ -0,0 +1,112 @@ +from metagpt.logs import logger + +# max_tokens for each model +NOT_SUUPORT_STREAM_MODELS = { + "ai21.j2-grande-instruct": 8000, + "ai21.j2-jumbo-instruct": 8000, + "ai21.j2-mid": 8000, + "ai21.j2-mid-v1": 8000, + "ai21.j2-ultra": 8000, + "ai21.j2-ultra-v1": 8000, +} + +SUPPORT_STREAM_MODELS = { + "amazon.titan-tg1-large": 8000, + "amazon.titan-text-express-v1": 8000, + "amazon.titan-text-express-v1:0:8k": 8000, + "amazon.titan-text-lite-v1:0:4k": 4000, + "amazon.titan-text-lite-v1": 4000, + "anthropic.claude-instant-v1": 100000, + "anthropic.claude-instant-v1:2:100k": 100000, + "anthropic.claude-v1": 100000, + "anthropic.claude-v2": 100000, + "anthropic.claude-v2:1": 200000, + "anthropic.claude-v2:0:18k": 18000, + "anthropic.claude-v2:1:200k": 200000, + "anthropic.claude-3-sonnet-20240229-v1:0": 200000, + "anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000, + "anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000, + "anthropic.claude-3-haiku-20240307-v1:0": 200000, + "anthropic.claude-3-haiku-20240307-v1:0:48k": 48000, + "anthropic.claude-3-haiku-20240307-v1:0:200k": 200000, + # currently (2024-4-29) only available at US West (Oregon) AWS Region. + "anthropic.claude-3-opus-20240229-v1:0": 200000, + "cohere.command-text-v14": 4000, + "cohere.command-text-v14:7:4k": 4000, + "cohere.command-light-text-v14": 4000, + "cohere.command-light-text-v14:7:4k": 4000, + "meta.llama2-13b-chat-v1:0:4k": 4000, + "meta.llama2-13b-chat-v1": 2000, + "meta.llama2-70b-v1": 4000, + "meta.llama2-70b-v1:0:4k": 4000, + "meta.llama2-70b-chat-v1": 4000, + "meta.llama2-70b-chat-v1:0:4k": 4000, + "meta.llama3-8b-instruct-v1:0": 2000, + "meta.llama3-70b-instruct-v1:0": 2000, + "mistral.mistral-7b-instruct-v0:2": 32000, + "mistral.mixtral-8x7b-instruct-v0:1": 32000, + "mistral.mistral-large-2402-v1:0": 32000, +} + +# TODO:use a more general function for constructing chat templates. + + +def messages_to_prompt_llama2(messages: list[dict]) -> str: + BOS = ("",) + B_INST, E_INST = "[INST]", "[/INST]" + B_SYS, E_SYS = "<>\n", "\n<>\n\n" + + prompt = f"{BOS}" + for message in messages: + role = message.get("role", "") + content = message.get("content", "") + if role == "system": + prompt += f"{B_SYS} {content} {E_SYS}" + elif role == "user": + prompt += f"{B_INST} {content} {E_INST}" + elif role == "assistant": + prompt += f"{content}" + else: + logger.warning(f"Unknown role name {role} when formatting messages") + prompt += f"{content}" + + return prompt + + +def messages_to_prompt_llama3(messages: list[dict]) -> str: + BOS = "<|begin_of_text|>" + GENERAL_TEMPLATE = "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>" + + prompt = f"{BOS}" + for message in messages: + role = message.get("role", "") + content = message.get("content", "") + prompt += GENERAL_TEMPLATE.format(role=role, content=content) + + if role != "assistant": + prompt += "<|start_header_id|>assistant<|end_header_id|>" + + return prompt + + +def messages_to_prompt_claude2(messages: list[dict]) -> str: + GENERAL_TEMPLATE = "\n\n{role}: {content}" + prompt = "" + for message in messages: + role = message.get("role", "") + content = message.get("content", "") + prompt += GENERAL_TEMPLATE.format(role=role, content=content) + + if role != "assistant": + prompt += "\n\nAssistant:" + + return prompt + + +def get_max_tokens(model_id: str) -> int: + try: + max_tokens = (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] + except KeyError: + logger.warning(f"Couldn't find model:{model_id} , max tokens has been set to 2048") + max_tokens = 2048 + return max_tokens diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py new file mode 100644 index 000000000..d192a5478 --- /dev/null +++ b/metagpt/provider/bedrock_api.py @@ -0,0 +1,140 @@ +import json +from typing import Literal + +import boto3 +from botocore.eventstream import EventStream + +from metagpt.configs.llm_config import LLMConfig, LLMType +from metagpt.const import USE_CONFIG_TIMEOUT +from metagpt.logs import log_llm_stream, logger +from metagpt.provider.base_llm import BaseLLM +from metagpt.provider.bedrock.bedrock_provider import get_provider +from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens +from metagpt.provider.llm_provider_registry import register_provider +from metagpt.utils.cost_manager import CostManager +from metagpt.utils.token_counter import BEDROCK_TOKEN_COSTS + + +@register_provider([LLMType.BEDROCK]) +class BedrockLLM(BaseLLM): + def __init__(self, config: LLMConfig): + self.config = config + self.__client = self.__init_client("bedrock-runtime") + self.__provider = get_provider(self.config.model) + self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS) + logger.warning("Amazon bedrock doesn't support asynchronous now") + if self.config.model in NOT_SUUPORT_STREAM_MODELS: + logger.warning(f"model {self.config.model} doesn't support streaming output!") + + def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): + """initialize boto3 client""" + # access key and secret key from https://us-east-1.console.aws.amazon.com/iam + self.__credentital_kwargs = { + "aws_secret_access_key": self.config.secret_key, + "aws_access_key_id": self.config.access_key, + "region_name": self.config.region_name, + } + session = boto3.Session(**self.__credentital_kwargs) + client = session.client(service_name) + return client + + @property + def client(self): + return self.__client + + @property + def provider(self): + return self.__provider + + def list_models(self): + """list all available text-generation models + + ```shell + ai21.j2-ultra-v1 Support Streaming:False + meta.llama3-70b-instruct-v1:0 Support Streaming:True + …… + ``` + """ + client = self.__init_client("bedrock") + # only output text-generation models + response = client.list_foundation_models(byOutputModality="TEXT") + summaries = [ + f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}' + for summary in response["modelSummaries"] + ] + logger.info("\n" + "\n".join(summaries)) + + def invoke_model(self, request_body: str) -> dict: + response = self.__client.invoke_model(modelId=self.config.model, body=request_body) + usage = self._get_usage(response) + self._update_costs(usage, self.config.model) + response_body = self._get_response_body(response) + return response_body + + def invoke_model_with_response_stream(self, request_body: str) -> EventStream: + response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body) + usage = self._get_usage(response) + self._update_costs(usage, self.config.model) + return response + + @property + def _const_kwargs(self) -> dict: + model_max_tokens = get_max_tokens(self.config.model) + if self.config.max_token > model_max_tokens: + max_tokens = model_max_tokens + else: + max_tokens = self.config.max_token + + return {self.__provider.max_tokens_field_name: max_tokens, "temperature": self.config.temperature} + + # boto3 don't support support asynchronous calls. + # for asynchronous version of boto3, check out: + # https://aioboto3.readthedocs.io/en/latest/usage.html + # However,aioboto3 doesn't support invoke model + + def get_choice_text(self, rsp: dict) -> str: + return self.__provider.get_choice_text(rsp) + + async def acompletion(self, messages: list[dict]) -> dict: + request_body = self.__provider.get_request_body(messages, self._const_kwargs) + response_body = self.invoke_model(request_body) + return response_body + + async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict: + return await self.acompletion(messages) + + async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: + if self.config.model in NOT_SUUPORT_STREAM_MODELS: + rsp = await self.acompletion(messages) + full_text = self.get_choice_text(rsp) + log_llm_stream(full_text) + return full_text + + request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True) + + response = self.invoke_model_with_response_stream(request_body) + collected_content = [] + for event in response["body"]: + chunk_text = self.__provider.get_choice_text_from_stream(event) + collected_content.append(chunk_text) + log_llm_stream(chunk_text) + + log_llm_stream("\n") + full_text = ("".join(collected_content)).lstrip() + return full_text + + def _get_response_body(self, response) -> dict: + response_body = json.loads(response["body"].read()) + return response_body + + def _get_usage(self, response) -> dict[str, int]: + headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {}) + prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0)) + completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0)) + usage = ( + { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + }, + ) + return usage diff --git a/metagpt/provider/human_provider.py b/metagpt/provider/human_provider.py index 87dbd105f..a16a49c20 100644 --- a/metagpt/provider/human_provider.py +++ b/metagpt/provider/human_provider.py @@ -33,6 +33,7 @@ class HumanProvider(BaseLLM): format_msgs: Optional[list[dict[str, str]]] = None, generator: bool = False, timeout=USE_CONFIG_TIMEOUT, + **kwargs ) -> str: return self.ask(msg, timeout=self.get_timeout(timeout)) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 7957f775c..68dc156c2 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -100,7 +100,7 @@ class OpenAILLM(BaseLLM): log_llm_stream(chunk_message) collected_messages.append(chunk_message) if finish_reason: - if hasattr(chunk, "usage"): + if hasattr(chunk, "usage") and chunk.usage is not None: # Some services have usage as an attribute of the chunk, such as Fireworks usage = CompletionUsage(**chunk.usage) elif hasattr(chunk.choices[0], "usage"): diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py index 241820cf4..dc75d87b0 100644 --- a/metagpt/rag/retrievers/bm25_retriever.py +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -40,8 +40,10 @@ class DynamicBM25Retriever(BM25Retriever): self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes] self.bm25 = BM25Okapi(self._corpus) - self._index.insert_nodes(nodes, **kwargs) + if self._index: + self._index.insert_nodes(nodes, **kwargs) def persist(self, persist_dir: str, **kwargs) -> None: """Support persist.""" - self._index.storage_context.persist(persist_dir) + if self._index: + self._index.storage_context.persist(persist_dir) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index e7b2e5ce9..618880a22 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -12,6 +12,7 @@ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from metagpt.config2 import config from metagpt.configs.embedding_config import EmbeddingType +from metagpt.logs import logger from metagpt.rag.interface import RAGObject @@ -44,7 +45,13 @@ class FAISSRetrieverConfig(IndexRetrieverConfig): @model_validator(mode="after") def check_dimensions(self): if self.dimensions == 0: - self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536) + self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get( + config.embedding.api_type, 1536 + ) + if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions: + logger.warning( + f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536" + ) return self diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index fbe139a99..9db9f7d9e 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -9,7 +9,7 @@ from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.prepare_documents import PrepareDocuments -from metagpt.roles.role import Role +from metagpt.roles.role import Role, RoleReactMode from metagpt.utils.common import any_to_name @@ -35,17 +35,8 @@ class ProductManager(Role): self.set_actions([PrepareDocuments, WritePRD]) self._watch([UserRequirement, PrepareDocuments]) - self.todo_action = any_to_name(PrepareDocuments) - - async def _think(self) -> bool: - """Decide what to do""" - if self.git_repo and not self.config.git_reinit: - self._set_state(1) - else: - self._set_state(0) - self.config.git_reinit = False - self.todo_action = any_to_name(WritePRD) - return bool(self.rc.todo) + self.rc.react_mode = RoleReactMode.BY_ORDER + self.todo_action = any_to_name(WritePRD) async def _observe(self, ignore_memory=False) -> int: return await super()._observe(ignore_memory=True) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 142c3a5b9..071f060ea 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -370,6 +370,12 @@ class Role(SerializationMixin, ContextMixin, BaseModel): self.recovered = False # avoid max_react_loop out of work return True + if self.rc.react_mode == RoleReactMode.BY_ORDER: + if self.rc.max_react_loop != len(self.actions): + self.rc.max_react_loop = len(self.actions) + self._set_state(self.rc.state + 1) + return self.rc.state >= 0 and self.rc.state < len(self.actions) + prompt = self._get_prefix() prompt += STATE_TEMPLATE.format( history=self.rc.history, @@ -460,8 +466,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel): rsp = Message(content="No actions taken yet", cause_by=Action) # will be overwritten after Role _act while actions_taken < self.rc.max_react_loop: # think - await self._think() - if self.rc.todo is None: + todo = await self._think() + if not todo: break # act logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}") @@ -469,15 +475,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel): actions_taken += 1 return rsp # return output from the last action - async def _act_by_order(self) -> Message: - """switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...""" - start_idx = self.rc.state if self.rc.state >= 0 else 0 # action to run from recovered state - rsp = Message(content="No actions taken yet") # return default message if actions=[] - for i in range(start_idx, len(self.states)): - self._set_state(i) - rsp = await self._act() - return rsp # return output from the last action - async def _plan_and_act(self) -> Message: """first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... Use llm to come up with the plan dynamically.""" @@ -518,10 +515,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel): async def react(self) -> Message: """Entry to one of three strategies by which Role reacts to the observed Message""" - if self.rc.react_mode == RoleReactMode.REACT: + if self.rc.react_mode == RoleReactMode.REACT or self.rc.react_mode == RoleReactMode.BY_ORDER: rsp = await self._react() - elif self.rc.react_mode == RoleReactMode.BY_ORDER: - rsp = await self._act_by_order() elif self.rc.react_mode == RoleReactMode.PLAN_AND_ACT: rsp = await self._plan_and_act() else: diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 724d49afc..de549cc5a 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -35,8 +35,11 @@ TOKEN_COSTS = { "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03}, "gpt-4-turbo": {"prompt": 0.01, "completion": 0.03}, + "gpt-4-turbo-2024-04-09": {"prompt": 0.01, "completion": 0.03}, "gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator "gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03}, + "gpt-4o": {"prompt": 0.005, "completion": 0.015}, + "gpt-4o-2024-05-13": {"prompt": 0.005, "completion": 0.015}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, "glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens "glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens @@ -56,11 +59,14 @@ TOKEN_COSTS = { "claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075}, "yi-34b-chat-0205": {"prompt": 0.0003, "completion": 0.0003}, "yi-34b-chat-200k": {"prompt": 0.0017, "completion": 0.0017}, + "yi-large": {"prompt": 0.0028, "completion": 0.0028}, "microsoft/wizardlm-2-8x22b": {"prompt": 0.00108, "completion": 0.00108}, # for openrouter, start "meta-llama/llama-3-70b-instruct": {"prompt": 0.008, "completion": 0.008}, "llama3-70b-8192": {"prompt": 0.0059, "completion": 0.0079}, "openai/gpt-3.5-turbo-0125": {"prompt": 0.0005, "completion": 0.0015}, "openai/gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03}, + "deepseek-chat": {"prompt": 0.00014, "completion": 0.00028}, + "deepseek-coder": {"prompt": 0.00014, "completion": 0.00028}, } @@ -155,6 +161,9 @@ FIREWORKS_GRADE_TOKEN_COSTS = { # https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo TOKEN_MAX = { + "gpt-4o-2024-05-13": 128000, + "gpt-4o": 128000, + "gpt-4-turbo-2024-04-09": 128000, "gpt-4-0125-preview": 128000, "gpt-4-turbo-preview": 128000, "gpt-4-1106-preview": 128000, @@ -191,11 +200,61 @@ TOKEN_MAX = { "claude-3-opus-20240229": 200000, "yi-34b-chat-0205": 4000, "yi-34b-chat-200k": 200000, + "yi-large": 16385, "microsoft/wizardlm-2-8x22b": 65536, "meta-llama/llama-3-70b-instruct": 8192, "llama3-70b-8192": 8192, "openai/gpt-3.5-turbo-0125": 16385, "openai/gpt-4-turbo-preview": 128000, + "deepseek-chat": 32768, + "deepseek-coder": 16385, +} + +# For Amazon Bedrock US region +# See https://aws.amazon.com/cn/bedrock/pricing/ + +BEDROCK_TOKEN_COSTS = { + "amazon.titan-tg1-large": {"prompt": 0.0008, "completion": 0.0008}, + "amazon.titan-text-express-v1": {"prompt": 0.0008, "completion": 0.0008}, + "amazon.titan-text-express-v1:0:8k": {"prompt": 0.0008, "completion": 0.0008}, + "amazon.titan-text-lite-v1:0:4k": {"prompt": 0.0003, "completion": 0.0004}, + "amazon.titan-text-lite-v1": {"prompt": 0.0003, "completion": 0.0004}, + "anthropic.claude-instant-v1": {"prompt": 0.0008, "completion": 0.00024}, + "anthropic.claude-instant-v1:2:100k": {"prompt": 0.0008, "completion": 0.00024}, + "anthropic.claude-v1": {"prompt": 0.008, "completion": 0.0024}, + "anthropic.claude-v2": {"prompt": 0.008, "completion": 0.0024}, + "anthropic.claude-v2:1": {"prompt": 0.008, "completion": 0.0024}, + "anthropic.claude-v2:0:18k": {"prompt": 0.008, "completion": 0.0024}, + "anthropic.claude-v2:1:200k": {"prompt": 0.008, "completion": 0.0024}, + "anthropic.claude-3-sonnet-20240229-v1:0": {"prompt": 0.003, "completion": 0.015}, + "anthropic.claude-3-sonnet-20240229-v1:0:28k": {"prompt": 0.003, "completion": 0.015}, + "anthropic.claude-3-sonnet-20240229-v1:0:200k": {"prompt": 0.003, "completion": 0.015}, + "anthropic.claude-3-haiku-20240307-v1:0": {"prompt": 0.00025, "completion": 0.00125}, + "anthropic.claude-3-haiku-20240307-v1:0:48k": {"prompt": 0.00025, "completion": 0.00125}, + "anthropic.claude-3-haiku-20240307-v1:0:200k": {"prompt": 0.00025, "completion": 0.00125}, + # currently (2024-4-29) only available at US West (Oregon) AWS Region. + "anthropic.claude-3-opus-20240229-v1:0": {"prompt": 0.015, "completion": 0.075}, + "cohere.command-text-v14": {"prompt": 0.0015, "completion": 0.0015}, + "cohere.command-text-v14:7:4k": {"prompt": 0.0015, "completion": 0.0015}, + "cohere.command-light-text-v14": {"prompt": 0.0003, "completion": 0.0003}, + "cohere.command-light-text-v14:7:4k": {"prompt": 0.0003, "completion": 0.0003}, + "meta.llama2-13b-chat-v1:0:4k": {"prompt": 0.00075, "completion": 0.001}, + "meta.llama2-13b-chat-v1": {"prompt": 0.00075, "completion": 0.001}, + "meta.llama2-70b-v1": {"prompt": 0.00195, "completion": 0.00256}, + "meta.llama2-70b-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256}, + "meta.llama2-70b-chat-v1": {"prompt": 0.00195, "completion": 0.00256}, + "meta.llama2-70b-chat-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256}, + "meta.llama3-8b-instruct-v1:0": {"prompt": 0.0004, "completion": 0.0006}, + "meta.llama3-70b-instruct-v1:0": {"prompt": 0.00265, "completion": 0.0035}, + "mistral.mistral-7b-instruct-v0:2": {"prompt": 0.00015, "completion": 0.0002}, + "mistral.mixtral-8x7b-instruct-v0:1": {"prompt": 0.00045, "completion": 0.0007}, + "mistral.mistral-large-2402-v1:0": {"prompt": 0.008, "completion": 0.024}, + "ai21.j2-grande-instruct": {"prompt": 0.0125, "completion": 0.0125}, + "ai21.j2-jumbo-instruct": {"prompt": 0.0188, "completion": 0.0188}, + "ai21.j2-mid": {"prompt": 0.0125, "completion": 0.0125}, + "ai21.j2-mid-v1": {"prompt": 0.0125, "completion": 0.0125}, + "ai21.j2-ultra": {"prompt": 0.0188, "completion": 0.0188}, + "ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188}, } @@ -224,6 +283,8 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"): "gpt-4-turbo", "gpt-4-vision-preview", "gpt-4-1106-vision-preview", + "gpt-4o-2024-05-13", + "gpt-4o", }: tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|> tokens_per_name = 1 diff --git a/requirements.txt b/requirements.txt index 6c219a9dc..82b478902 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ PyYAML==6.0.1 # sentence_transformers==2.2.2 setuptools==65.6.3 tenacity==8.2.3 -tiktoken==0.6.0 +tiktoken==0.7.0 tqdm==4.66.2 #unstructured[local-inference] # selenium>4 @@ -70,3 +70,4 @@ qianfan==0.3.2 dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation gymnasium==0.29.1 +boto3==1.34.92 \ No newline at end of file diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index 0c56cc8ea..8f2baea10 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ b/tests/metagpt/provider/mock_llm_config.py @@ -60,3 +60,12 @@ mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model mock_llm_config_anthropic = LLMConfig( api_type="anthropic", api_key="xxx", base_url="https://api.anthropic.com", model="claude-3-opus-20240229" ) + +mock_llm_config_bedrock = LLMConfig( + api_type="bedrock", + model="gpt-100", + region_name="somewhere", + access_key="123abc", + secret_key="123abc", + max_token=10000, +) diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py index 7e4c1a49c..111b57f91 100644 --- a/tests/metagpt/provider/req_resp_const.py +++ b/tests/metagpt/provider/req_resp_const.py @@ -183,3 +183,90 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[ resp = await llm.acompletion_text(messages, stream=True) assert resp == resp_cont + + +# For Amazon Bedrock +# Check the API documentation of each model +# https://docs.aws.amazon.com/bedrock/latest/userguide +BEDROCK_PROVIDER_REQUEST_BODY = { + "mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0}, + "meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0}, + "ai21": { + "prompt": "", + "temperature": 0.0, + "topP": 0.0, + "maxTokens": 0, + "stopSequences": [], + "countPenalty": {"scale": 0.0}, + "presencePenalty": {"scale": 0.0}, + "frequencyPenalty": {"scale": 0.0}, + }, + "cohere": { + "prompt": "", + "temperature": 0.0, + "p": 0.0, + "k": 0.0, + "max_tokens": 0, + "stop_sequences": [], + "return_likelihoods": "NONE", + "stream": False, + "num_generations": 0, + "logit_bias": {}, + "truncate": "NONE", + }, + "anthropic": { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": 0, + "system": "", + "messages": [{"role": "", "content": ""}], + "temperature": 0.0, + "top_p": 0.0, + "top_k": 0, + "stop_sequences": [], + }, + "amazon": { + "inputText": "", + "textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []}, + }, +} + +BEDROCK_PROVIDER_RESPONSE_BODY = { + "mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]}, + "meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""}, + "ai21": { + "id": "", + "prompt": {"text": "Hello World", "tokens": []}, + "completions": [ + {"data": {"text": "Hello World", "tokens": []}, "finishReason": {"reason": "length", "length": 2}} + ], + }, + "cohere": { + "generations": [ + { + "finish_reason": "", + "id": "", + "text": "Hello World", + "likelihood": 0.0, + "token_likelihoods": [{"token": 0.0}], + "is_finished": True, + "index": 0, + } + ], + "id": "", + "prompt": "", + }, + "anthropic": { + "id": "", + "model": "", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "Hello World"}], + "stop_reason": "", + "stop_sequence": "", + "usage": {"input_tokens": 0, "output_tokens": 0}, + }, + "amazon": { + "inputTextTokenCount": 0, + "results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}], + }, +} diff --git a/tests/metagpt/provider/test_bedrock_api.py b/tests/metagpt/provider/test_bedrock_api.py new file mode 100644 index 000000000..4760a2db2 --- /dev/null +++ b/tests/metagpt/provider/test_bedrock_api.py @@ -0,0 +1,109 @@ +import json + +import pytest + +from metagpt.provider.bedrock.utils import ( + NOT_SUUPORT_STREAM_MODELS, + SUPPORT_STREAM_MODELS, +) +from metagpt.provider.bedrock_api import BedrockLLM +from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock +from tests.metagpt.provider.req_resp_const import ( + BEDROCK_PROVIDER_REQUEST_BODY, + BEDROCK_PROVIDER_RESPONSE_BODY, +) + +# all available model from bedrock +models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS +messages = [{"role": "user", "content": "Hi!"}] +usage = { + "prompt_tokens": 1000000, + "completion_tokens": 1000000, +} + + +def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict: + provider = self.config.model.split(".")[0] + self._update_costs(usage, self.config.model) + return BEDROCK_PROVIDER_RESPONSE_BODY[provider] + + +def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict: + # use json object to mock EventStream + def dict2bytes(x): + return json.dumps(x).encode("utf-8") + + provider = self.config.model.split(".")[0] + + if provider == "amazon": + response_body_bytes = dict2bytes({"outputText": "Hello World"}) + elif provider == "anthropic": + response_body_bytes = dict2bytes( + {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello World"}} + ) + elif provider == "cohere": + response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"}) + else: + response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider]) + + response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]} + self._update_costs(usage, self.config.model) + return response_body_stream + + +def get_bedrock_request_body(model_id) -> dict: + provider = model_id.split(".")[0] + return BEDROCK_PROVIDER_REQUEST_BODY[provider] + + +def is_subset(subset, superset) -> bool: + """Ensure all fields in request body are allowed. + + ```python + subset = {"prompt": "hello","kwargs": {"temperature": 0.9,"p": 0.0}} + superset = {"prompt": "hello", "kwargs": {"temperature": 0.0, "top-p": 0.0}} + is_subset(subset, superset) + ``` + >>>False + """ + for key, value in subset.items(): + if key not in superset: + return False + if isinstance(value, dict): + if not isinstance(superset[key], dict): + return False + if not is_subset(value, superset[key]): + return False + return True + + +@pytest.fixture(scope="class", params=models) +def bedrock_api(request) -> BedrockLLM: + model_id = request.param + mock_llm_config_bedrock.model = model_id + api = BedrockLLM(mock_llm_config_bedrock) + return api + + +class TestBedrockAPI: + def _patch_invoke_model(self, mocker): + mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_invoke_model) + + def _patch_invoke_model_stream(self, mocker): + mocker.patch( + "metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream", + mock_invoke_model_stream, + ) + + def test_get_request_body(self, bedrock_api: BedrockLLM): + """Ensure request body has correct format""" + provider = bedrock_api.provider + request_body = json.loads(provider.get_request_body(messages, bedrock_api._const_kwargs)) + assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model)) + + @pytest.mark.asyncio + async def test_aask(self, bedrock_api: BedrockLLM, mocker): + self._patch_invoke_model(mocker) + self._patch_invoke_model_stream(mocker) + assert await bedrock_api.aask(messages, stream=False) == "Hello World" + assert await bedrock_api.aask(messages, stream=True) == "Hello World" diff --git a/tests/metagpt/roles/test_product_manager.py b/tests/metagpt/roles/test_product_manager.py index 59b5aa81a..143eef2f2 100644 --- a/tests/metagpt/roles/test_product_manager.py +++ b/tests/metagpt/roles/test_product_manager.py @@ -10,7 +10,6 @@ import json import pytest from metagpt.actions import WritePRD -from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.const import REQUIREMENT_FILENAME from metagpt.context import Context from metagpt.logs import logger @@ -30,11 +29,7 @@ async def test_product_manager(new_filename): rsp = await product_manager.run(MockMessages.req) assert context.git_repo assert context.repo - assert rsp.cause_by == any_to_str(PrepareDocuments) assert REQUIREMENT_FILENAME in context.repo.docs.changed_files - - # write prd - rsp = await product_manager.run(rsp) assert rsp.cause_by == any_to_str(WritePRD) logger.info(rsp) assert len(rsp.content) > 0