diff --git a/metagpt/provider/llm_provider_registry.py b/metagpt/provider/llm_provider_registry.py
index df89d36aa..d598b678e 100644
--- a/metagpt/provider/llm_provider_registry.py
+++ b/metagpt/provider/llm_provider_registry.py
@@ -21,11 +21,16 @@ class LLMProviderRegistry:
         return self.providers[enum]
 
 
-def register_provider(key):
+def register_provider(keys):
     """register provider to registry"""
 
     def decorator(cls):
-        LLM_REGISTRY.register(key, cls)
+        # Accept a single key or a list/tuple of keys so one class can serve several LLM types.
+        if isinstance(keys, (list, tuple)):
+            for key in keys:
+                LLM_REGISTRY.register(key, cls)
+        else:
+            LLM_REGISTRY.register(keys, cls)
         return cls
 
     return decorator
diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py
index 36d6f6d77..5ed7168e3 100644
--- a/metagpt/provider/openai_api.py
+++ b/metagpt/provider/openai_api.py
@@ -30,7 +30,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
 from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.schema import Message
 from metagpt.utils.common import CodeParser, decode_image
-from metagpt.utils.cost_manager import CostManager
+from metagpt.utils.cost_manager import CostManager, Costs, TokenCostManager
 from metagpt.utils.exceptions import handle_exception
 from metagpt.utils.token_counter import (
     count_message_tokens,
@@ -50,7 +50,7 @@ See FAQ 5.8
     raise retry_state.outcome.exception()
 
 
-@register_provider(LLMType.OPENAI)
+@register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT])
 class OpenAILLM(BaseLLM):
     """Check https://platform.openai.com/examples for examples"""
 
@@ -84,14 +84,34 @@ class OpenAILLM(BaseLLM):
 
         return params
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
+    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
+        """Stream the completion, log each token, update costs and return the full reply text."""
         response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
             **self._cons_kwargs(messages, timeout=timeout), stream=True
         )
-
+        usage = None
+        collected_messages = []
         async for chunk in response:
             chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else ""  # extract the message
-            yield chunk_message
+            finish_reason = getattr(chunk.choices[0], "finish_reason", None) if chunk.choices else None
+            log_llm_stream(chunk_message)
+            collected_messages.append(chunk_message)
+            if finish_reason:
+                # Usage location varies by service: on the chunk itself (e.g. Fireworks) or on
+                # chunk.choices[0] (e.g. Moonshot). The attribute may exist but be None, and may
+                # be either a plain dict or an already-parsed CompletionUsage.
+                raw_usage = getattr(chunk, "usage", None) or getattr(chunk.choices[0], "usage", None)
+                if raw_usage:
+                    usage = raw_usage if isinstance(raw_usage, CompletionUsage) else CompletionUsage(**raw_usage)
+
+        log_llm_stream("\n")
+        full_reply_content = "".join(collected_messages)
+        if usage is None:
+            # Some services do not provide the usage attribute, such as OpenAI or OpenLLM
+            usage = self._calc_usage(messages, full_reply_content)
+
+        self._update_costs(usage)
+        return full_reply_content
 
     def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
         kwargs = {
@@ -126,18 +146,7 @@ class OpenAILLM(BaseLLM):
     async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         """when streaming, print each token in place."""
         if stream:
-            resp = self._achat_completion_stream(messages, timeout=timeout)
-
-            collected_messages = []
-            async for i in resp:
-                log_llm_stream(i)
-                collected_messages.append(i)
-            log_llm_stream("\n")
-
-            full_reply_content = "".join(collected_messages)
-            usage = self._calc_usage(messages, full_reply_content)
-            self._update_costs(usage)
-            return full_reply_content
+            return await self._achat_completion_stream(messages, timeout=timeout)
 
         rsp = await self._achat_completion(messages, timeout=timeout)
         return self.get_choice_text(rsp)
@@ -261,11 +270,12 @@ class OpenAILLM(BaseLLM):
         if not self.config.calc_usage:
             return usage
 
+        model = self.model if not isinstance(self.cost_manager, TokenCostManager) else "open-llm-model"
         try:
-            usage.prompt_tokens = count_message_tokens(messages, self.model)
-            usage.completion_tokens = count_string_tokens(rsp, self.model)
+            usage.prompt_tokens = count_message_tokens(messages, model)
+            usage.completion_tokens = count_string_tokens(rsp, model)
         except Exception as e:
-            logger.warning(f"usage calculation failed: {e}")
+            logger.error(f"usage calculation failed: {e}")
 
         return usage
 