From 4f14ee7ce143125e3190370d6ef997982d6bfdbc Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 12:30:20 +0800 Subject: [PATCH] implement base provider --- metagpt/provider/bedrock/base_provider.py | 30 ++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index eaedfe045..086569736 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -1,3 +1,27 @@ -from abc import ABC -class BaseBedrockProvider(ABC): - pass \ No newline at end of file + +import json + +class BaseBedrockProvider(object): + # to handle different generation kwargs + max_length = "max_tokens" + temperature = "temperature" + top_p = "top-p" + top_k = "top-k" + + def get_request_body(self, prompt, generate_kwargs: dict): + return {"prompt": prompt} | generate_kwargs + + def get_choice_text(self, response) -> str: + response_body = json.loads(response["body"].read()) + completions = response_body["content"]["outputs"][0]['text'] + return completions + + def messages_to_prompt(self, messages: list[dict]): + """[{"role": "user", "content": msg}] to user: etc.""" + return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) + + def format_prompt(self, prompt: str) -> str: + return prompt + + def format_messages(self, messages: list[dict]) -> list[dict]: + return messages