mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
support spark
This commit is contained in:
parent
b9b268ad8b
commit
f59449d5d2
5 changed files with 50 additions and 8 deletions
|
|
@ -5,13 +5,15 @@
|
|||
@Author : alexanderwu
|
||||
@File : llm.py
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.context import CONTEXT
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
def LLM() -> BaseLLM:
|
||||
def LLM(llm_config: Optional[LLMConfig] = None) -> BaseLLM:
|
||||
"""get the default llm provider if name is None"""
|
||||
# context.use_llm(name=name, provider=provider)
|
||||
if llm_config is not None:
|
||||
CONTEXT.llm_with_cost_manager_from_llm_config(llm_config)
|
||||
return CONTEXT.llm()
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ class Engineer(Role):
|
|||
# Code review
|
||||
if review:
|
||||
action = WriteCodeReview(i_context=coding_context, context=self.context, llm=self.llm)
|
||||
self._init_action_system_message(action)
|
||||
self._init_action(action)
|
||||
coding_context = await action.run()
|
||||
await src_file_repo.save(
|
||||
coding_context.filename,
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
super().__init__(**data)
|
||||
|
||||
if self.is_human:
|
||||
self.llm = HumanProvider()
|
||||
self.llm = HumanProvider(None)
|
||||
|
||||
self.llm.system_prompt = self._get_prefix()
|
||||
self._watch(data.get("watch") or [UserRequirement])
|
||||
|
|
@ -222,7 +222,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
def _setting(self):
|
||||
return f"{self.name}({self.profile})"
|
||||
|
||||
def _init_action_system_message(self, action: Action):
|
||||
def _init_action(self, action: Action):
|
||||
action.set_llm(self.llm, override=False)
|
||||
action.set_prefix(self._get_prefix())
|
||||
|
||||
def set_action(self, action: Action):
|
||||
|
|
@ -238,7 +239,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
self._reset()
|
||||
for action in actions:
|
||||
if not isinstance(action, Action):
|
||||
i = action(name="", llm=self.llm)
|
||||
i = action()
|
||||
else:
|
||||
if self.is_human and not isinstance(action.llm, HumanProvider):
|
||||
logger.warning(
|
||||
|
|
@ -247,7 +248,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
f"try passing in Action classes instead of initialized instances"
|
||||
)
|
||||
i = action
|
||||
self._init_action_system_message(i)
|
||||
self._init_action(i)
|
||||
self.actions.append(i)
|
||||
self.states.append(f"{len(self.actions)}. {action}")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of spark api
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
|
|
@ -33,6 +35,14 @@ def mock_spark_get_msg_from_web_run(self) -> str:
|
|||
return resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_aask():
|
||||
llm = SparkLLM(Config.model_validate_yaml(Path.home() / ".metagpt" / "spark.yaml").llm)
|
||||
|
||||
resp = await llm.aask("Hello!")
|
||||
print(resp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
|
|
|||
|
|
@ -5,10 +5,15 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_context_mixin.py
|
||||
"""
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.context_mixin import ContextMixin
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.roles import Role
|
||||
from metagpt.team import Team
|
||||
from tests.metagpt.provider.mock_llm_config import (
|
||||
mock_llm_config,
|
||||
mock_llm_config_proxy,
|
||||
|
|
@ -91,3 +96,27 @@ def test_config_mixin_4_multi_inheritance_override_config():
|
|||
print(obj.__dict__.keys())
|
||||
assert "private_config" in obj.__dict__.keys()
|
||||
assert obj.llm.model == "mock_zhipu_model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debate_two_roles():
|
||||
config = Config.default()
|
||||
config.llm.model = "gpt-4-1106-preview"
|
||||
action1 = Action(config=config, name="AlexSay", instruction="Say your opinion with emotion and don't repeat it")
|
||||
action2 = Action(name="BobSay", instruction="Say your opinion with emotion and don't repeat it")
|
||||
biden = Role(
|
||||
name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2]
|
||||
)
|
||||
trump = Role(
|
||||
name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1]
|
||||
)
|
||||
env = Environment(desc="US election live broadcast")
|
||||
team = Team(investment=10.0, env=env, roles=[biden, trump])
|
||||
|
||||
print(action1.llm.system_prompt)
|
||||
print(action2.llm.system_prompt)
|
||||
print(biden.llm.system_prompt)
|
||||
print(trump.llm.system_prompt)
|
||||
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
|
||||
assert "Alex" in history
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue