refine code

This commit is contained in:
geekan 2024-01-08 17:14:12 +08:00
parent b16315f6c7
commit 244fa81ffe
7 changed files with 22 additions and 12 deletions

View file

@ -78,7 +78,7 @@ class FireworksLLM(OpenAILLM):
self.cost_manager = FireworksCostManager()
def _make_client_kwargs(self) -> dict:
kwargs = dict(api_key=self.config.fireworks_api_key, base_url=self.config.fireworks_api_base)
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
return kwargs
def _update_costs(self, usage: CompletionUsage):

View file

@ -17,6 +17,8 @@ class YamlModel(BaseModel):
@classmethod
def read_yaml(cls, file_path: Path) -> Dict:
if not file_path.exists():
return {}
with open(file_path, "r") as file:
return yaml.safe_load(file)

View file

@ -0,0 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/8 17:03
@Author : alexanderwu
@File : mock_llm_config.py
"""
from metagpt.configs.llm_config import LLMConfig
mock_llm_config = LLMConfig(
llm_type="mock",
api_key="mock_api_key",
)

View file

@ -6,10 +6,8 @@
import pytest
from anthropic.resources.completions import Completion
from metagpt.config import CONFIG
from metagpt.provider.anthropic_api import Claude2
CONFIG.anthropic_api_key = "xxx"
from tests.metagpt.provider.mock_llm_config import mock_llm_config
prompt = "who are you"
resp = "I'am Claude2"
@ -25,10 +23,10 @@ async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_
def test_claude2_ask(mocker):
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
assert resp == Claude2().ask(prompt)
assert resp == Claude2(mock_llm_config).ask(prompt)
@pytest.mark.asyncio
async def test_claude2_aask(mocker):
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
assert resp == await Claude2().aask(prompt)
assert resp == await Claude2(mock_llm_config).aask(prompt)

View file

@ -13,17 +13,13 @@ from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.config import CONFIG
from metagpt.provider.fireworks_api import (
MODEL_GRADE_TOKEN_COSTS,
FireworksCostManager,
FireworksLLM,
)
from metagpt.utils.cost_manager import Costs
CONFIG.fireworks_api_key = "xxx"
CONFIG.max_budget = 10
CONFIG.calc_usage = True
from tests.metagpt.provider.mock_llm_config import mock_llm_config
resp_content = "I'm fireworks"
default_resp = ChatCompletion(
@ -92,7 +88,7 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
async def test_fireworks_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
fireworks_gpt = FireworksLLM()
fireworks_gpt = FireworksLLM(mock_llm_config)
fireworks_gpt.model = "llama-v2-13b-chat"
fireworks_gpt._update_costs(