diff --git a/config/examples/spark_lite.yaml b/config/examples/spark_lite.yaml new file mode 100644 index 000000000..5164e73e5 --- /dev/null +++ b/config/examples/spark_lite.yaml @@ -0,0 +1,10 @@ +# 适用于讯飞星火的spark-lite 参考 https://www.xfyun.cn/doc/spark/Web.html#_2-function-call%E8%AF%B4%E6%98%8E + +llm: + api_type: 'spark' + # 对应模型的url 参考 https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E + base_url: "wss://spark-api.xf-yun.com/v1.1/chat" + app_id: "" + api_key: "" + api_secret: "" + domain: "general" # 取值为 [general,generalv2,generalv3,generalv3.5] 和url一一对应 diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index 36b7efedd..93fc533ea 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -10,7 +10,10 @@ from metagpt.logs import log_llm_stream from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider -# from sparkai.schema import LLMResult, HumanMessage, AIMessage 由于其使用Pydantic V1,导入会报错 +# from sparkai.schema import LLMResult, HumanMessage, AIMessage # 由于其使用Pydantic V1,导入会报错 +from metagpt.utils.common import any_to_str +from metagpt.utils.cost_manager import CostManager +from metagpt.utils.token_counter import SPARK_TOKENS @register_provider(LLMType.SPARK) @@ -21,6 +24,8 @@ class SparkLLM(BaseLLM): def __init__(self, config: LLMConfig): self.config = config + self.cost_manager = CostManager(token_costs=SPARK_TOKENS) + self.model = self.config.domain self._init_client() def _init_client(self): @@ -60,7 +65,7 @@ class SparkLLM(BaseLLM): return response async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): - return self._achat_completion(messages, timeout) + return await self._achat_completion(messages, timeout) async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str: response = self.client.astream(messages) @@ -76,3 +81,6 @@ class SparkLLM(BaseLLM): 
self._update_costs(usage) full_content = "".join(collected_content) return full_content + + def _extract_assistant_rsp(self, context): + return "\n".join([i.content for i in context if "AIMessage" in any_to_str(i)]) diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index a8652e607..549776999 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -258,6 +258,14 @@ BEDROCK_TOKEN_COSTS = { "ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188}, } +# https://xinghuo.xfyun.cn/sparkapi?scr=price +SPARK_TOKENS = { + "general": {"prompt": 0.0, "completion": 0.0}, # Spark-Lite + "generalv2": {"prompt": 0.0188, "completion": 0.0188}, # Spark V2.0 + "generalv3": {"prompt": 0.0035, "completion": 0.0035}, # Spark Pro + "generalv3.5": {"prompt": 0.0035, "completion": 0.0035}, # Spark3.5 Max +} + def count_input_tokens(messages, model="gpt-3.5-turbo-0125"): """Return the number of tokens used by a list of messages.""" diff --git a/requirements.txt b/requirements.txt index 0d006d0b9..f8d3ec3b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -71,3 +71,4 @@ dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation gymnasium==0.29.1 boto3~=1.34.69 +spark_ai_python~=0.3.30 \ No newline at end of file