From 4d1fb207855b446e8b22691711dab2c52691da5a Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 14:31:22 +0800 Subject: [PATCH] add generate_kwargs --- .../provider/bedrock/amazon_bedrock_api.py | 1 - metagpt/provider/bedrock/base_provider.py | 4 +-- metagpt/provider/bedrock/bedrock_provider.py | 25 +++++++++++++------ 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 92b137a97..11a2bb3f5 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -41,7 +41,6 @@ class AmazonBedrockLLM(BaseLLM): @property def _generate_kwargs(self): return { - "max_token": self.config.get("max_token", 1024), "temperature": self.config.get("temperature", 0.3), "top_p": self.config.get("top_p", 0.95), "top_k": self.config.get("top_k", 1), diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 46f6ea58c..9a6f9659c 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -3,8 +3,8 @@ import json class BaseBedrockProvider(object): # to handle different generation kwargs - def get_request_body(self, messages, max_token=None, temperature=None, top_p=None, top_k=None, **kwargs): - return json.dumps({"prompt": self.messages_to_prompt(messages)}) + def get_request_body(self, messages, **generate_kwargs): + return json.dumps({"prompt": self.messages_to_prompt(messages)} | generate_kwargs) def get_choice_text(self, response) -> str: response_body = json.loads(response["body"].read()) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 10ab66b34..3ae84d8c3 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -8,13 +8,24 @@ class MistralProvider(BaseBedrockProvider): # for mixtral and llama return f"[INST]{prompt}[/INST]" - def get_request_body(self, messages, max_token=None, temperature=None, top_p=None, top_k=None, **kwargs): - return json.dumps({ - "prompt": self.format_prompt(self.messages_to_prompt(messages)), - "max_tokens": max_token, - "temperature": temperature, - "top_p": top_p, - "top_k": top_k, }) + def get_request_body(self, messages, **generate_kwargs): + return json.dumps({"prompt": self.format_prompt(self.messages_to_prompt(messages))} | generate_kwargs) + + +class AnthropicProvider(BaseBedrockProvider): + pass + + +class CohereProvider(BaseBedrockProvider): + pass + + +class MetaProvider(BaseBedrockProvider): + pass + + +class Ai21Provider(BaseBedrockProvider): + pass PROVIDERS = {