mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Merge pull request #1849 from GasolSun36/feature/bugfix-config-model
Use model instead of config.model inside an LLM instance
This commit is contained in:
commit
2f0c7fb14f
10 changed files with 40 additions and 33 deletions
|
|
@ -42,7 +42,9 @@ class BaseLLM(ABC):
|
|||
# OpenAI / Azure / Others
|
||||
aclient: Optional[Union[AsyncOpenAI]] = None
|
||||
cost_manager: Optional[CostManager] = None
|
||||
model: Optional[str] = None # deprecated
|
||||
# Keep the model name on the instance itself in case the global config has changed.
|
||||
# Code in this class should always use self.model, not self.config.model.
|
||||
model: Optional[str] = None
|
||||
pricing_plan: Optional[str] = None
|
||||
|
||||
_reasoning_content: Optional[str] = None # content from reasoning mode
|
||||
|
|
@ -87,7 +89,7 @@ class BaseLLM(ABC):
|
|||
return {"role": "system", "content": msg}
|
||||
|
||||
def support_image_input(self) -> bool:
|
||||
return any([m in self.config.model for m in MULTI_MODAL_MODELS])
|
||||
return any([m in self.model for m in MULTI_MODAL_MODELS])
|
||||
|
||||
def format_msg(self, messages: Union[str, "Message", list[dict], list["Message"], list[str]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
|
|
@ -321,7 +323,7 @@ class BaseLLM(ABC):
|
|||
|
||||
def with_model(self, model: str):
|
||||
"""Set model and return self. For example, `with_model("gpt-3.5-turbo")`."""
|
||||
self.config.model = model
|
||||
self.model = model
|
||||
return self
|
||||
|
||||
def get_timeout(self, timeout: int) -> int:
|
||||
|
|
@ -352,7 +354,7 @@ class BaseLLM(ABC):
|
|||
if compress_type == CompressType.NO_COMPRESS:
|
||||
return messages
|
||||
|
||||
max_token = TOKEN_MAX.get(self.config.model, max_token)
|
||||
max_token = TOKEN_MAX.get(self.model, max_token)
|
||||
keep_token = int(max_token * threshold)
|
||||
compressed = []
|
||||
|
||||
|
|
|
|||
|
|
@ -22,13 +22,14 @@ from metagpt.utils.token_counter import BEDROCK_TOKEN_COSTS
|
|||
class BedrockLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self.model = config.model
|
||||
self.__client = self.__init_client("bedrock-runtime")
|
||||
self.__provider = get_provider(
|
||||
self.config.model, reasoning=self.config.reasoning, reasoning_max_token=self.config.reasoning_max_token
|
||||
self.model, reasoning=self.config.reasoning, reasoning_max_token=self.config.reasoning_max_token
|
||||
)
|
||||
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
|
||||
if self.config.model in NOT_SUPPORT_STREAM_MODELS:
|
||||
logger.warning(f"model {self.config.model} doesn't support streaming output!")
|
||||
if self.model in NOT_SUPPORT_STREAM_MODELS:
|
||||
logger.warning(f"model {self.model} doesn't support streaming output!")
|
||||
|
||||
def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
|
||||
"""initialize boto3 client"""
|
||||
|
|
@ -72,25 +73,25 @@ class BedrockLLM(BaseLLM):
|
|||
async def invoke_model(self, request_body: str) -> dict:
|
||||
loop = asyncio.get_running_loop()
|
||||
response = await loop.run_in_executor(
|
||||
None, partial(self.client.invoke_model, modelId=self.config.model, body=request_body)
|
||||
None, partial(self.client.invoke_model, modelId=self.model, body=request_body)
|
||||
)
|
||||
usage = self._get_usage(response)
|
||||
self._update_costs(usage, self.config.model)
|
||||
self._update_costs(usage, self.model)
|
||||
response_body = self._get_response_body(response)
|
||||
return response_body
|
||||
|
||||
async def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
|
||||
loop = asyncio.get_running_loop()
|
||||
response = await loop.run_in_executor(
|
||||
None, partial(self.client.invoke_model_with_response_stream, modelId=self.config.model, body=request_body)
|
||||
None, partial(self.client.invoke_model_with_response_stream, modelId=self.model, body=request_body)
|
||||
)
|
||||
usage = self._get_usage(response)
|
||||
self._update_costs(usage, self.config.model)
|
||||
self._update_costs(usage, self.model)
|
||||
return response
|
||||
|
||||
@property
|
||||
def _const_kwargs(self) -> dict:
|
||||
model_max_tokens = get_max_tokens(self.config.model)
|
||||
model_max_tokens = get_max_tokens(self.model)
|
||||
if self.config.max_token > model_max_tokens:
|
||||
max_tokens = model_max_tokens
|
||||
else:
|
||||
|
|
@ -119,7 +120,7 @@ class BedrockLLM(BaseLLM):
|
|||
return await self.acompletion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
if self.config.model in NOT_SUPPORT_STREAM_MODELS:
|
||||
if self.model in NOT_SUPPORT_STREAM_MODELS:
|
||||
rsp = await self.acompletion(messages)
|
||||
full_text = self.get_choice_text(rsp)
|
||||
log_llm_stream(full_text)
|
||||
|
|
@ -132,7 +133,7 @@ class BedrockLLM(BaseLLM):
|
|||
full_text = ("".join(collected_content)).lstrip()
|
||||
if self.__provider.usage:
|
||||
# if the provider reports usage, update the costs with it
|
||||
self._update_costs(self.__provider.usage, self.config.model)
|
||||
self._update_costs(self.__provider.usage, self.model)
|
||||
return full_text
|
||||
|
||||
def _get_response_body(self, response) -> dict:
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ class HumanProvider(BaseLLM):
|
|||
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self.model = config.model
|
||||
|
||||
def ask(self, msg: str, timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
logger.info("It's your turn, please type in your response. You may also refer to the context below")
|
||||
|
|
|
|||
|
|
@ -195,6 +195,7 @@ class OllamaLLM(BaseLLM):
|
|||
def __init__(self, config: LLMConfig):
|
||||
self.client = GeneralAPIRequestor(base_url=config.base_url, key=config.api_key)
|
||||
self.config = config
|
||||
self.model = config.model
|
||||
self.http_method = "post"
|
||||
self.use_system_prompt = False
|
||||
self.cost_manager = TokenCostManager()
|
||||
|
|
|
|||
|
|
@ -321,6 +321,6 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
def count_tokens(self, messages: list[dict]) -> int:
|
||||
try:
|
||||
return count_message_tokens(messages, self.config.model)
|
||||
return count_message_tokens(messages, self.model)
|
||||
except:
|
||||
return super().count_tokens(messages)
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ class QianFanLLM(BaseLLM):
|
|||
self.cost_manager = CostManager(token_costs=self.token_costs)
|
||||
|
||||
def __init_qianfan(self):
|
||||
self.model = self.config.model
|
||||
if self.config.access_key and self.config.secret_key:
|
||||
# for system-level auth, use access_key and secret_key (the officially recommended approach)
|
||||
# set the environment variables, following the official recommendation
|
||||
|
|
@ -61,14 +62,14 @@ class QianFanLLM(BaseLLM):
|
|||
("ERNIE-Speed", "ernie_speed"),
|
||||
("EB-turbo-AppBuilder", "ai_apaas"),
|
||||
]
|
||||
if self.config.model in [pair[0] for pair in support_system_pairs]:
|
||||
if self.model in [pair[0] for pair in support_system_pairs]:
|
||||
# only some ERNIE models support a system prompt
|
||||
self.use_system_prompt = True
|
||||
if self.config.endpoint in [pair[1] for pair in support_system_pairs]:
|
||||
self.use_system_prompt = True
|
||||
|
||||
assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config"
|
||||
assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config"
|
||||
assert not (self.model and self.config.endpoint), "Only set `model` or `endpoint` in the config"
|
||||
assert self.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config"
|
||||
|
||||
self.token_costs = copy.deepcopy(QIANFAN_MODEL_TOKEN_COSTS)
|
||||
self.token_costs.update(QIANFAN_ENDPOINT_TOKEN_COSTS)
|
||||
|
|
@ -87,8 +88,8 @@ class QianFanLLM(BaseLLM):
|
|||
kwargs["temperature"] = self.config.temperature
|
||||
if self.config.endpoint:
|
||||
kwargs["endpoint"] = self.config.endpoint
|
||||
elif self.config.model:
|
||||
kwargs["model"] = self.config.model
|
||||
elif self.model:
|
||||
kwargs["model"] = self.model
|
||||
|
||||
if self.use_system_prompt:
|
||||
# if the model supports a system prompt, extract and pass it
|
||||
|
|
@ -99,7 +100,7 @@ class QianFanLLM(BaseLLM):
|
|||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
model_or_endpoint = self.config.model or self.config.endpoint
|
||||
model_or_endpoint = self.model or self.config.endpoint
|
||||
local_calc_usage = model_or_endpoint in self.token_costs
|
||||
super()._update_costs(usage, model_or_endpoint, local_calc_usage)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ mock_llm_config = LLMConfig(
|
|||
app_id="mock_app_id",
|
||||
api_secret="mock_api_secret",
|
||||
domain="mock_domain",
|
||||
model="mock_model",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -175,7 +175,7 @@ def test_format_msg(mocker):
|
|||
|
||||
def test_format_msg_w_images(mocker):
|
||||
base_llm = MockBaseLLM()
|
||||
base_llm.config.model = "gpt-4o"
|
||||
base_llm.model = "gpt-4o"
|
||||
msg_w_images = UserMessage(content="req1")
|
||||
msg_w_images.add_metadata(IMAGES, ["base64 string 1", "base64 string 2"])
|
||||
msg_w_empty_images = UserMessage(content="req2")
|
||||
|
|
|
|||
|
|
@ -46,9 +46,9 @@ def deal_special_provider(provider: str, model: str, stream: bool = False) -> st
|
|||
|
||||
|
||||
async def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict:
|
||||
provider = get_provider_name(self.config.model)
|
||||
self._update_costs(usage, self.config.model)
|
||||
provider = deal_special_provider(provider, self.config.model)
|
||||
provider = get_provider_name(self.model)
|
||||
self._update_costs(usage, self.model)
|
||||
provider = deal_special_provider(provider, self.model)
|
||||
return BEDROCK_PROVIDER_RESPONSE_BODY[provider]
|
||||
|
||||
|
||||
|
|
@ -57,7 +57,7 @@ async def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
|
|||
def dict2bytes(x):
|
||||
return json.dumps(x).encode("utf-8")
|
||||
|
||||
provider = get_provider_name(self.config.model)
|
||||
provider = get_provider_name(self.model)
|
||||
|
||||
if provider == "amazon":
|
||||
response_body_bytes = dict2bytes({"outputText": "Hello World"})
|
||||
|
|
@ -68,11 +68,11 @@ async def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
|
|||
elif provider == "cohere":
|
||||
response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"})
|
||||
else:
|
||||
provider = deal_special_provider(provider, self.config.model, stream=True)
|
||||
provider = deal_special_provider(provider, self.model, stream=True)
|
||||
response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
|
||||
|
||||
response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]}
|
||||
self._update_costs(usage, self.config.model)
|
||||
self._update_costs(usage, self.model)
|
||||
return response_body_stream
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -169,7 +169,7 @@ async def test_openai_acompletion(mocker):
|
|||
|
||||
def test_count_tokens():
|
||||
llm = LLM()
|
||||
llm.config.model = "gpt-4o"
|
||||
llm.model = "gpt-4o"
|
||||
messages = [
|
||||
llm._system_msg("some system msg"),
|
||||
llm._system_msg("some system message 2"),
|
||||
|
|
@ -184,7 +184,7 @@ def test_count_tokens():
|
|||
|
||||
def test_count_tokens_long():
|
||||
llm = LLM()
|
||||
llm.config.model = "gpt-4-0613"
|
||||
llm.model = "gpt-4-0613"
|
||||
test_msg_content = " ".join([str(i) for i in range(100000)])
|
||||
messages = [
|
||||
llm._system_msg("You are a helpful assistant"),
|
||||
|
|
@ -193,7 +193,7 @@ def test_count_tokens_long():
|
|||
cnt = llm.count_tokens(messages) # 299023, ~300k
|
||||
assert 290000 <= cnt <= 300000
|
||||
|
||||
llm.config.model = "test_llm" # a non-openai model, will use heuristics base count_tokens
|
||||
llm.model = "test_llm" # a non-openai model, will use heuristics base count_tokens
|
||||
cnt = llm.count_tokens(messages) # 294474, ~300k, ~2% difference
|
||||
assert 290000 <= cnt <= 300000
|
||||
|
||||
|
|
@ -202,7 +202,7 @@ def test_count_tokens_long():
|
|||
@pytest.mark.asyncio
|
||||
async def test_aask_long():
|
||||
llm = LLM()
|
||||
llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k
|
||||
llm.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k
|
||||
llm.config.compress_type = CompressType.POST_CUT_BY_TOKEN
|
||||
test_msg_content = " ".join([str(i) for i in range(100000)]) # corresponds to ~300k tokens
|
||||
messages = [
|
||||
|
|
@ -216,7 +216,7 @@ async def test_aask_long():
|
|||
@pytest.mark.asyncio
|
||||
async def test_aask_long_no_compress():
|
||||
llm = LLM()
|
||||
llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k
|
||||
llm.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k
|
||||
# Not specifying llm.config.compress_type will use default "", no compress
|
||||
test_msg_content = " ".join([str(i) for i in range(100000)]) # corresponds to ~300k tokens
|
||||
messages = [
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue