diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
index 2a04116f5..170005c21 100644
--- a/metagpt/configs/llm_config.py
+++ b/metagpt/configs/llm_config.py
@@ -100,3 +100,6 @@ class LLMConfig(YamlModel):
     @classmethod
     def check_timeout(cls, v):
         return v or LLM_API_TIMEOUT
+
+    def get(self, key: str, default = None):
+        return getattr(self, key, default)
diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py
index ecdee4154..92b137a97 100644
--- a/metagpt/provider/bedrock/amazon_bedrock_api.py
+++ b/metagpt/provider/bedrock/amazon_bedrock_api.py
@@ -9,12 +9,15 @@ from metagpt.logs import log_llm_stream, logger
 from botocore.config import Config
 import boto3
 
+from metagpt.provider.bedrock.bedrock_provider import get_provider
+
 
 @register_provider([LLMType.AMAZON_BEDROCK])
 class AmazonBedrockLLM(BaseLLM):
     def __init__(self, config: LLMConfig):
         self.config = config
         self.__client = self.__init_client("bedrock-runtime")
+        self.provider = get_provider(self.config.model)
 
     def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
         # access key from https://us-east-1.console.aws.amazon.com/iam
@@ -28,7 +31,6 @@ class AmazonBedrockLLM(BaseLLM):
         return client
 
     def list_models(self):
-        """see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html"""
         client = self.__init_client("bedrock")
         # only output text-generation models
         response = client.list_foundation_models(byOutputModality='TEXT')
@@ -36,14 +38,29 @@
                      for summary in response.get("modelSummaries", {})]
         logger.info("\n"+"\n".join(summaries))
 
+    @property
+    def _generate_kwargs(self):
+        return {
+            "max_token": self.config.get("max_token", 1024),
+            "temperature": self.config.get("temperature", 0.3),
+            "top_p": self.config.get("top_p", 0.95),
+            "top_k": self.config.get("top_k", 1),
+        }
+
     def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
         pass
 
     def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
         pass
 
-    def completion(self, messages):
-        pass
+    def completion(self, messages: list[dict]):
+        request_body = self.provider.get_request_body(
+            messages, **self._generate_kwargs)
+        response = self.__client.invoke_model(
+            modelId=self.config.model, body=request_body
+        )
+        completions = self.provider.get_choice_text(response)
+        return completions
 
     def acompletion(self, messages: list[dict]):
         pass
@@ -54,4 +71,4 @@ if __name__ == '__main__':
     prompt = "who are you?"
     messages = [{"role": "user", "content": prompt}]
     llm = AmazonBedrockLLM(my_config)
-    llm.list_models()
+    print(llm.completion(messages))
diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py
index 086569736..46f6ea58c 100644
--- a/metagpt/provider/bedrock/base_provider.py
+++ b/metagpt/provider/bedrock/base_provider.py
@@ -1,19 +1,14 @@
-
 import json
 
+
 class BaseBedrockProvider(object):
     # to handle different generation kwargs
-    max_length = "max_tokens"
-    temperature = "temperature"
-    top_p = "top-p"
-    top_k = "top-k"
-
-    def get_request_body(self, prompt, generate_kwargs: dict):
-        return {"prompt": prompt} | generate_kwargs
+    def get_request_body(self, messages, max_token=None, temperature=None, top_p=None, top_k=None, **kwargs):
+        return json.dumps({"prompt": self.messages_to_prompt(messages)})
 
     def get_choice_text(self, response) -> str:
         response_body = json.loads(response["body"].read())
-        completions = response_body["content"]["outputs"][0]['text']
+        completions = response_body["outputs"][0]['text']
         return completions
 
     def messages_to_prompt(self, messages: list[dict]):
diff --git a/metagpt/provider/bedrock/bedrock_provide.py b/metagpt/provider/bedrock/bedrock_provide.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py
new file mode 100644
index 000000000..10ab66b34
--- /dev/null
+++ b/metagpt/provider/bedrock/bedrock_provider.py
@@ -0,0 +1,73 @@
+from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
+import json
+
+
+class MistralProvider(BaseBedrockProvider):
+
+    def format_prompt(self, prompt: str) -> str:
+        # for mixtral and llama
+        return f"[INST]{prompt}[/INST]"
+
+    def get_request_body(self, messages, max_token=None, temperature=None, top_p=None, top_k=None, **kwargs):
+        return json.dumps({
+            "prompt": self.format_prompt(self.messages_to_prompt(messages)),
+            "max_tokens": max_token,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k, })
+
+
+PROVIDERS = {
+    "mistral": MistralProvider()
+}
+
+NOT_SUUPORT_STREAM_MODELS = {
+    "ai21.j2-grande-instruct",
+    "ai21.j2-jumbo-instruct",
+    "ai21.j2-mid",
+    "ai21.j2-mid-v1",
+    "ai21.j2-ultra",
+    "ai21.j2-ultra-v1",
+}
+
+SUPPORT_STREAM_MODELS = {
+    "amazon.titan-tg1-large",
+    "amazon.titan-text-lite-v1:0:4k",
+    "amazon.titan-text-lite-v1",
+    "amazon.titan-text-express-v1:0:8k",
+    "amazon.titan-text-express-v1",
+    "anthropic.claude-instant-v1:2:100k",
+    "anthropic.claude-instant-v1",
+    "anthropic.claude-v2:0:18k",
+    "anthropic.claude-v2:0:100k",
+    "anthropic.claude-v2:1:18k",
+    "anthropic.claude-v2:1:200k",
+    "anthropic.claude-v2:1",
+    "anthropic.claude-v2:2:18k",
+    "anthropic.claude-v2:2:200k",
+    "anthropic.claude-v2:2",
+    "anthropic.claude-v2",
+    "anthropic.claude-3-sonnet-20240229-v1:0:28k",
+    "anthropic.claude-3-sonnet-20240229-v1:0:200k",
+    "anthropic.claude-3-sonnet-20240229-v1:0",
+    "anthropic.claude-3-haiku-20240307-v1:0:48k",
+    "anthropic.claude-3-haiku-20240307-v1:0:200k",
+    "anthropic.claude-3-haiku-20240307-v1:0",
+    "cohere.command-text-v14:7:4k",
+    "cohere.command-text-v14",
+    "cohere.command-light-text-v14:7:4k",
+    "cohere.command-light-text-v14",
+    "meta.llama2-70b-v1",
+    "meta.llama3-8b-instruct-v1:0",
+    "meta.llama3-70b-instruct-v1:0",
+    "mistral.mistral-7b-instruct-v0:2",
+    "mistral.mixtral-8x7b-instruct-v0:1",
+    "mistral.mistral-large-2402-v1:0",
+}
+
+
+def get_provider(model_id: str):
+    model_name = model_id.split(".")[0]  # meta、mistral……
+    if model_name not in PROVIDERS:
+        raise KeyError(f"{model_name} is not supported!")
+    return PROVIDERS[model_name]
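
Below the patch, a minimal usage sketch (not part of the diff) of the provider dispatch this change introduces; the model id, message, and generation values are illustrative assumptions only:

# Illustrative sketch: exercises get_provider()/get_request_body() added in this patch.
from metagpt.provider.bedrock.bedrock_provider import get_provider

# "mistral.mixtral-8x7b-instruct-v0:1" is an example model id; its vendor prefix
# "mistral" is currently the only key in PROVIDERS, so any other prefix raises KeyError.
provider = get_provider("mistral.mixtral-8x7b-instruct-v0:1")

# Builds the JSON request body that completion() passes to bedrock-runtime
# invoke_model(): the messages are converted to a prompt by messages_to_prompt(),
# wrapped in [INST]...[/INST] by format_prompt(), and the generation kwargs are
# attached (note the config key "max_token" maps to the Bedrock field "max_tokens").
request_body = provider.get_request_body(
    [{"role": "user", "content": "who are you?"}],
    max_token=1024,
    temperature=0.3,
    top_p=0.95,
    top_k=1,
)
print(request_body)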