mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-23 15:48:11 +02:00
add type hint
This commit is contained in:
parent
bb8ea2eaf9
commit
f304fb3871
3 changed files with 54 additions and 14 deletions
|
|
@ -2,6 +2,11 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
from sparkai.core.messages import _convert_to_message, convert_to_messages
|
||||
from sparkai.core.messages.ai import AIMessage
|
||||
from sparkai.core.messages.base import BaseMessage
|
||||
from sparkai.core.messages.human import HumanMessage
|
||||
from sparkai.core.messages.system import SystemMessage
|
||||
from sparkai.core.outputs.llm_result import LLMResult
|
||||
from sparkai.llm.llm import ChatSparkLLM
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
|
|
@ -9,8 +14,6 @@ from metagpt.const import USE_CONFIG_TIMEOUT
|
|||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
# from sparkai.schema import LLMResult, HumanMessage, AIMessage # 由于其使用Pydantic V1,导入会报错
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import SPARK_TOKENS
|
||||
|
|
@ -38,19 +41,19 @@ class SparkLLM(BaseLLM):
|
|||
streaming=True,
|
||||
)
|
||||
|
||||
def _system_msg(self, msg: str):
|
||||
def _system_msg(self, msg: str) -> SystemMessage:
|
||||
return _convert_to_message(msg)
|
||||
|
||||
def _user_msg(self, msg: str, **kwargs):
|
||||
def _user_msg(self, msg: str, **kwargs) -> HumanMessage:
|
||||
return _convert_to_message(msg)
|
||||
|
||||
def _assistant_msg(self, msg: str):
|
||||
def _assistant_msg(self, msg: str) -> AIMessage:
|
||||
return _convert_to_message(msg)
|
||||
|
||||
def get_choice_text(self, rsp) -> str:
|
||||
def get_choice_text(self, rsp: LLMResult) -> str:
|
||||
return rsp.generations[0][0].text
|
||||
|
||||
def get_usage(self, response):
|
||||
def get_usage(self, response: LLMResult):
|
||||
message = response.generations[0][0].message
|
||||
if hasattr(message, "additional_kwargs"):
|
||||
return message.additional_kwargs.get("token_usage", {})
|
||||
|
|
@ -59,7 +62,7 @@ class SparkLLM(BaseLLM):
|
|||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
messages = convert_to_messages(messages)
|
||||
response = await self.client.agenerate([messages])
|
||||
response = await self.acreate(messages, stream=False)
|
||||
usage = self.get_usage(response)
|
||||
self._update_costs(usage)
|
||||
return response
|
||||
|
|
@ -68,7 +71,7 @@ class SparkLLM(BaseLLM):
|
|||
return await self._achat_completion(messages, timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
response = self.client.astream(messages)
|
||||
response = await self.acreate(messages, stream=True)
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for chunk in response:
|
||||
|
|
@ -82,5 +85,11 @@ class SparkLLM(BaseLLM):
|
|||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
def _extract_assistant_rsp(self, context):
|
||||
def _extract_assistant_rsp(self, context: list[BaseMessage]) -> str:
|
||||
return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)])
|
||||
|
||||
async def acreate(self, messages: list[dict], stream: bool = True):
|
||||
if stream:
|
||||
return self.client.astream(messages)
|
||||
else:
|
||||
return await self.client.agenerate([messages])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue