support mistral
commit ec7df8acdf
parent 4f14ee7ce1
5 changed files with 101 additions and 13 deletions
@@ -9,12 +9,15 @@ from metagpt.logs import log_llm_stream, logger
from botocore.config import Config
import boto3

+from metagpt.provider.bedrock.bedrock_provider import get_provider


@register_provider([LLMType.AMAZON_BEDROCK])
class AmazonBedrockLLM(BaseLLM):
    def __init__(self, config: LLMConfig):
        self.config = config
        self.__client = self.__init_client("bedrock-runtime")
+        self.provider = get_provider(self.config.model)

    def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
        # access key from https://us-east-1.console.aws.amazon.com/iam
@@ -28,7 +31,6 @@ class AmazonBedrockLLM(BaseLLM):
        return client

    def list_models(self):
        """see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html"""
        client = self.__init_client("bedrock")
        # only output text-generation models
        response = client.list_foundation_models(byOutputModality='TEXT')
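For context on what list_models iterates over: the bedrock client's list_foundation_models call returns its results under a modelSummaries key. A rough, abridged sketch of that shape (field names taken from the boto3 documentation linked above; the exact summary string built in the next hunk is not shown in this diff):

# Abridged illustration of the ListFoundationModels response shape.
response = {
    "modelSummaries": [
        {
            "modelId": "mistral.mistral-7b-instruct-v0:2",
            "modelName": "Mistral 7B Instruct",
            "providerName": "Mistral AI",
            "outputModalities": ["TEXT"],
            "responseStreamingSupported": True,
        },
        # ...
    ]
}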
@@ -36,14 +38,29 @@ class AmazonBedrockLLM(BaseLLM):
                     for summary in response.get("modelSummaries", {})]
        logger.info("\n"+"\n".join(summaries))

+    @property
+    def _generate_kwargs(self):
+        return {
+            "max_token": self.config.get("max_token", 1024),
+            "temperature": self.config.get("temperature", 0.3),
+            "top_p": self.config.get("top_p", 0.95),
+            "top_k": self.config.get("top_k", 1),
+        }

    def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
        pass

    def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
        pass

-    def completion(self, messages):
-        pass
+    def completion(self, messages: list[dict]):
+        request_body = self.provider.get_request_body(
+            messages, **self._generate_kwargs)
+        response = self.__client.invoke_model(
+            modelId=self.config.model, body=request_body
+        )
+        completions = self.provider.get_choice_text(response)
+        return completions

    def acompletion(self, messages: list[dict]):
        pass
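completion() delegates the payload format to the per-vendor provider and sends it with invoke_model. A minimal sketch of the same round trip done directly with boto3, assuming a Mistral model that is enabled in the chosen region (model id and region are illustrative):

import boto3
import json

client = boto3.client("bedrock-runtime", region_name="us-east-1")
body = json.dumps({
    "prompt": "<s>[INST]who are you?[/INST]",
    "max_tokens": 1024,
    "temperature": 0.3,
    "top_p": 0.95,
    "top_k": 1,
})
response = client.invoke_model(modelId="mistral.mistral-7b-instruct-v0:2", body=body)
# response["body"] is a streaming object; read it once, then parse the JSON.
result = json.loads(response["body"].read())
print(result["outputs"][0]["text"])  # this is what the provider's get_choice_text (below) extracts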
@@ -54,4 +71,4 @@ if __name__ == '__main__':
    prompt = "who are you?"
    messages = [{"role": "user", "content": prompt}]
    llm = AmazonBedrockLLM(my_config)
-    llm.list_models()
+    print(llm.completion(messages))
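Note that the async entry points (_achat_completion, _achat_completion_stream, acompletion) are left as stubs in this commit. One possible interim bridge, not part of the diff, is to push the blocking boto3 call onto a worker thread:

import asyncio

async def acompletion(self, messages: list[dict]):
    # Hypothetical stopgap until a native async/streaming path is added:
    # run the synchronous completion() off the event loop.
    return await asyncio.to_thread(self.completion, messages)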
@@ -1,19 +1,14 @@
import json


class BaseBedrockProvider(object):
-    # to handle different generation kwargs
-    max_length = "max_tokens"
-    temperature = "temperature"
-    top_p = "top-p"
-    top_k = "top-k"

-    def get_request_body(self, prompt, generate_kwargs: dict):
-        return {"prompt": prompt} | generate_kwargs
+    def get_request_body(self, messages, max_token=None, temperature=None, top_p=None, top_k=None, **kwargs):
+        return json.dumps({"prompt": self.messages_to_prompt(messages)})

    def get_choice_text(self, response) -> str:
        response_body = json.loads(response["body"].read())
-        completions = response_body["content"]["outputs"][0]['text']
+        completions = response_body["outputs"][0]['text']
        return completions

    def messages_to_prompt(self, messages: list[dict]):
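The diff context cuts off before the body of messages_to_prompt. Assuming it only flattens an OpenAI-style message list into a single string, a minimal hypothetical version would be:

def messages_to_prompt(self, messages: list[dict]) -> str:
    # Hypothetical sketch; the real body is not shown in this diff.
    # Produces lines such as "user: who are you?".
    return "\n".join(f'{m["role"]}: {m["content"]}' for m in messages)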
metagpt/provider/bedrock/bedrock_provider.py (new file, 73 additions)
@@ -0,0 +1,73 @@
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
import json


class MistralProvider(BaseBedrockProvider):

    def format_prompt(self, prompt: str) -> str:
        # for mixtral and llama
        return f"<s>[INST]{prompt}[/INST]"

    def get_request_body(self, messages, max_token=None, temperature=None, top_p=None, top_k=None, **kwargs):
        return json.dumps({
            "prompt": self.format_prompt(self.messages_to_prompt(messages)),
            "max_tokens": max_token,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k, })


PROVIDERS = {
    "mistral": MistralProvider()
}

NOT_SUUPORT_STREAM_MODELS = {
    "ai21.j2-grande-instruct",
    "ai21.j2-jumbo-instruct",
    "ai21.j2-mid",
    "ai21.j2-mid-v1",
    "ai21.j2-ultra",
    "ai21.j2-ultra-v1",
}

SUPPORT_STREAM_MODELS = {
    "amazon.titan-tg1-large",
    "amazon.titan-text-lite-v1:0:4k",
    "amazon.titan-text-lite-v1",
    "amazon.titan-text-express-v1:0:8k",
    "amazon.titan-text-express-v1",
    "anthropic.claude-instant-v1:2:100k",
    "anthropic.claude-instant-v1",
    "anthropic.claude-v2:0:18k",
    "anthropic.claude-v2:0:100k",
    "anthropic.claude-v2:1:18k",
    "anthropic.claude-v2:1:200k",
    "anthropic.claude-v2:1",
    "anthropic.claude-v2:2:18k",
    "anthropic.claude-v2:2:200k",
    "anthropic.claude-v2:2",
    "anthropic.claude-v2",
    "anthropic.claude-3-sonnet-20240229-v1:0:28k",
    "anthropic.claude-3-sonnet-20240229-v1:0:200k",
    "anthropic.claude-3-sonnet-20240229-v1:0",
    "anthropic.claude-3-haiku-20240307-v1:0:48k",
    "anthropic.claude-3-haiku-20240307-v1:0:200k",
    "anthropic.claude-3-haiku-20240307-v1:0",
    "cohere.command-text-v14:7:4k",
    "cohere.command-text-v14",
    "cohere.command-light-text-v14:7:4k",
    "cohere.command-light-text-v14",
    "meta.llama2-70b-v1",
    "meta.llama3-8b-instruct-v1:0",
    "meta.llama3-70b-instruct-v1:0",
    "mistral.mistral-7b-instruct-v0:2",
    "mistral.mixtral-8x7b-instruct-v0:1",
    "mistral.mistral-large-2402-v1:0",
}


def get_provider(model_id: str):
    model_name = model_id.split(".")[0]  # e.g. meta, mistral, ...
    if model_name not in PROVIDERS:
        raise KeyError(f"{model_name} is not supported!")
    return PROVIDERS[model_name]
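Putting the new module together, a short usage sketch of the dispatch plus request-body construction (the model id is illustrative, and the printed prompt assumes messages_to_prompt joins messages as sketched earlier):

from metagpt.provider.bedrock.bedrock_provider import get_provider

provider = get_provider("mistral.mistral-7b-instruct-v0:2")  # vendor prefix "mistral" -> MistralProvider
body = provider.get_request_body(
    [{"role": "user", "content": "who are you?"}],
    max_token=1024, temperature=0.3, top_p=0.95, top_k=1,
)
print(body)
# e.g. {"prompt": "<s>[INST]user: who are you?[/INST]", "max_tokens": 1024, ...}

A model id whose vendor prefix is not registered in PROVIDERS (for example any cohere.* id) raises KeyError until a matching provider class is added.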