support mistral

usamimeri_renko 2024-04-25 13:47:31 +08:00
parent 4f14ee7ce1
commit ec7df8acdf
5 changed files with 101 additions and 13 deletions


@@ -9,12 +9,15 @@ from metagpt.logs import log_llm_stream, logger
 from botocore.config import Config
 import boto3
+from metagpt.provider.bedrock.bedrock_provider import get_provider


 @register_provider([LLMType.AMAZON_BEDROCK])
 class AmazonBedrockLLM(BaseLLM):
     def __init__(self, config: LLMConfig):
         self.config = config
         self.__client = self.__init_client("bedrock-runtime")
+        self.provider = get_provider(self.config.model)

     def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
         # access key from https://us-east-1.console.aws.amazon.com/iam
@@ -28,7 +31,6 @@ class AmazonBedrockLLM(BaseLLM):
         return client

     def list_models(self):
         """see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html"""
         client = self.__init_client("bedrock")
         # only output text-generation models
         response = client.list_foundation_models(byOutputModality='TEXT')
@@ -36,14 +38,29 @@ class AmazonBedrockLLM(BaseLLM):
                      for summary in response.get("modelSummaries", {})]
         logger.info("\n" + "\n".join(summaries))

+    @property
+    def _generate_kwargs(self):
+        return {
+            "max_token": self.config.get("max_token", 1024),
+            "temperature": self.config.get("temperature", 0.3),
+            "top_p": self.config.get("top_p", 0.95),
+            "top_k": self.config.get("top_k", 1),
+        }
+
     def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
         pass

     def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
         pass

-    def completion(self, messages):
-        pass
+    def completion(self, messages: list[dict]):
+        request_body = self.provider.get_request_body(
+            messages, **self._generate_kwargs)
+        response = self.__client.invoke_model(
+            modelId=self.config.model, body=request_body
+        )
+        completions = self.provider.get_choice_text(response)
+        return completions

     def acompletion(self, messages: list[dict]):
         pass
@@ -54,4 +71,4 @@ if __name__ == '__main__':
     prompt = "who are you?"
     messages = [{"role": "user", "content": prompt}]
     llm = AmazonBedrockLLM(my_config)
-    llm.list_models()
+    print(llm.completion(messages))
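
The new `completion` path is a straight request/response round trip: the per-vendor provider serializes the chat messages into a model-specific JSON body, `invoke_model` is called on the `bedrock-runtime` client, and the provider pulls the generated text back out. A minimal standalone sketch of that same flow against one of the Mistral models listed further down (assumes AWS credentials and Bedrock model access are already configured; the response shape is the one `get_choice_text` parses):

import json

import boto3

# Plain boto3 sketch of what AmazonBedrockLLM.completion() does for a
# Mistral model on Bedrock.
client = boto3.client("bedrock-runtime", region_name="us-east-1")

body = json.dumps({
    "prompt": "<s>[INST]who are you?[/INST]",  # MistralProvider.format_prompt
    "max_tokens": 1024,
    "temperature": 0.3,
    "top_p": 0.95,
    "top_k": 1,
})

response = client.invoke_model(
    modelId="mistral.mistral-7b-instruct-v0:2", body=body
)

# The payload comes back as a StreamingBody; Mistral models place the
# generation under outputs[0]["text"], which is what get_choice_text reads.
response_body = json.loads(response["body"].read())
print(response_body["outputs"][0]["text"])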


@@ -1,19 +1,14 @@
 import json


 class BaseBedrockProvider(object):
-    # to handle different generation kwargs
-    max_length = "max_tokens"
-    temperature = "temperature"
-    top_p = "top-p"
-    top_k = "top-k"
-
-    def get_request_body(self, prompt, generate_kwargs: dict):
-        return {"prompt": prompt} | generate_kwargs
+    def get_request_body(self, messages, max_token=None, temperature=None, top_p=None, top_k=None, **kwargs):
+        return json.dumps({"prompt": self.messages_to_prompt(messages)})

     def get_choice_text(self, response) -> str:
         response_body = json.loads(response["body"].read())
-        completions = response_body["content"]["outputs"][0]['text']
+        completions = response_body["outputs"][0]['text']
         return completions

     def messages_to_prompt(self, messages: list[dict]):
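
The hunk ends at the `messages_to_prompt` signature without showing its body. Given how `get_request_body` uses it to flatten the OpenAI-style message list into a single prompt string, a plausible implementation (an assumption, not taken from this commit) is a role-prefixed join:

# Hypothetical body for messages_to_prompt, consistent with its use in
# get_request_body; the commit's actual implementation is not shown here.
def messages_to_prompt(self, messages: list[dict]) -> str:
    # [{"role": "user", "content": "hi"}] -> "user: hi"
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages)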


@@ -0,0 +1,73 @@
+from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
+import json
+
+
+class MistralProvider(BaseBedrockProvider):
+    def format_prompt(self, prompt: str) -> str:
+        # for mixtral and llama
+        return f"<s>[INST]{prompt}[/INST]"
+
+    def get_request_body(self, messages, max_token=None, temperature=None, top_p=None, top_k=None, **kwargs):
+        return json.dumps({
+            "prompt": self.format_prompt(self.messages_to_prompt(messages)),
+            "max_tokens": max_token,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+        })
+
+
+PROVIDERS = {
+    "mistral": MistralProvider()
+}
+
+NOT_SUPPORT_STREAM_MODELS = {
+    "ai21.j2-grande-instruct",
+    "ai21.j2-jumbo-instruct",
+    "ai21.j2-mid",
+    "ai21.j2-mid-v1",
+    "ai21.j2-ultra",
+    "ai21.j2-ultra-v1",
+}
+
+SUPPORT_STREAM_MODELS = {
+    "amazon.titan-tg1-large",
+    "amazon.titan-text-lite-v1:0:4k",
+    "amazon.titan-text-lite-v1",
+    "amazon.titan-text-express-v1:0:8k",
+    "amazon.titan-text-express-v1",
+    "anthropic.claude-instant-v1:2:100k",
+    "anthropic.claude-instant-v1",
+    "anthropic.claude-v2:0:18k",
+    "anthropic.claude-v2:0:100k",
+    "anthropic.claude-v2:1:18k",
+    "anthropic.claude-v2:1:200k",
+    "anthropic.claude-v2:1",
+    "anthropic.claude-v2:2:18k",
+    "anthropic.claude-v2:2:200k",
+    "anthropic.claude-v2:2",
+    "anthropic.claude-v2",
+    "anthropic.claude-3-sonnet-20240229-v1:0:28k",
+    "anthropic.claude-3-sonnet-20240229-v1:0:200k",
+    "anthropic.claude-3-sonnet-20240229-v1:0",
+    "anthropic.claude-3-haiku-20240307-v1:0:48k",
+    "anthropic.claude-3-haiku-20240307-v1:0:200k",
+    "anthropic.claude-3-haiku-20240307-v1:0",
+    "cohere.command-text-v14:7:4k",
+    "cohere.command-text-v14",
+    "cohere.command-light-text-v14:7:4k",
+    "cohere.command-light-text-v14",
+    "meta.llama2-70b-v1",
+    "meta.llama3-8b-instruct-v1:0",
+    "meta.llama3-70b-instruct-v1:0",
+    "mistral.mistral-7b-instruct-v0:2",
+    "mistral.mixtral-8x7b-instruct-v0:1",
+    "mistral.mistral-large-2402-v1:0",
+}
+
+
+def get_provider(model_id: str):
+    model_name = model_id.split(".")[0]  # e.g. "meta", "mistral"
+    if model_name not in PROVIDERS:
+        raise KeyError(f"{model_name} is not supported!")
+    return PROVIDERS[model_name]
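
`get_provider` dispatches on the vendor prefix of the Bedrock model id, so every `mistral.*` model resolves to the single shared `MistralProvider` instance, and any vendor without an entry in `PROVIDERS` fails fast. A quick usage sketch (the printed body assumes the role-prefixed `messages_to_prompt` sketched earlier):

# Dispatch on the "mistral." prefix, then build a request body.
provider = get_provider("mistral.mistral-7b-instruct-v0:2")
body = provider.get_request_body(
    [{"role": "user", "content": "who are you?"}],
    max_token=1024, temperature=0.3, top_p=0.95, top_k=1,
)
print(body)
# {"prompt": "<s>[INST]user: who are you?[/INST]", "max_tokens": 1024,
#  "temperature": 0.3, "top_p": 0.95, "top_k": 1}

get_provider("cohere.command-text-v14")  # KeyError: only "mistral" is registered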