mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
fix llm set
This commit is contained in:
parent
f59449d5d2
commit
6e44b2b515
8 changed files with 38 additions and 24 deletions
|
|
@ -13,7 +13,9 @@ from metagpt.roles import Role
|
|||
from metagpt.team import Team
|
||||
|
||||
action1 = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
action1.llm.model = "gpt-4-1106-preview"
|
||||
action2 = Action(name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
action2.llm.model = "gpt-3.5-turbo-1106"
|
||||
alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2])
|
||||
bob = Role(name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1])
|
||||
env = Environment(desc="US election live broadcast")
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from metagpt.utils.file_repository import FileRepository
|
|||
|
||||
|
||||
class Action(SerializationMixin, ContextMixin, BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
name: str = ""
|
||||
i_context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""
|
||||
|
|
|
|||
|
|
@ -57,6 +57,8 @@ class ContextMixin(BaseModel):
|
|||
def set_config(self, config: Config, override=False):
|
||||
"""Set config"""
|
||||
self.set("private_config", config, override)
|
||||
if config is not None:
|
||||
_ = self.llm # init llm
|
||||
|
||||
def set_llm(self, llm: BaseLLM, override=False):
|
||||
"""Set llm"""
|
||||
|
|
|
|||
|
|
@ -220,10 +220,12 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
@handle_exception
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage:
|
||||
if self.config.calc_usage and usage and self.cost_manager:
|
||||
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
if not self.cost_manager:
|
||||
return Costs()
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
def _get_max_tokens(self, messages: list[dict]):
|
||||
|
|
|
|||
|
|
@ -131,6 +131,13 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
|
||||
role_id: str = ""
|
||||
states: list[str] = []
|
||||
|
||||
# scenarios to set action system_prompt:
|
||||
# 1. `__init__` while using Role(actions=[...])
|
||||
# 2. add action to role while using `role.set_action(action)`
|
||||
# 3. set_todo while using `role.set_todo(action)`
|
||||
# 4. when role.system_prompt is being updated (e.g. by `role.system_prompt = "..."`)
|
||||
# Additional, if llm is not set, we will use role's llm
|
||||
actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)
|
||||
rc: RoleContext = Field(default_factory=RoleContext)
|
||||
addresses: set[str] = set()
|
||||
|
|
@ -222,6 +229,12 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
def _setting(self):
|
||||
return f"{self.name}({self.profile})"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_actions(self):
|
||||
"""Check actions and set llm and prefix for each action."""
|
||||
self.set_actions(self.actions)
|
||||
return self
|
||||
|
||||
def _init_action(self, action: Action):
|
||||
action.set_llm(self.llm, override=False)
|
||||
action.set_prefix(self._get_prefix())
|
||||
|
|
@ -306,6 +319,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
if env:
|
||||
env.set_addresses(self, self.addresses)
|
||||
self.llm.system_prompt = self._get_prefix()
|
||||
self.set_actions(self.actions) # reset actions to update llm and prefix
|
||||
|
||||
def _get_prefix(self):
|
||||
"""Get the role prefix"""
|
||||
|
|
@ -318,7 +332,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
prefix += CONSTRAINT_TEMPLATE.format(**{"constraints": self.constraints})
|
||||
|
||||
if self.rc.env and self.rc.env.desc:
|
||||
other_role_names = ", ".join(self.rc.env.role_names())
|
||||
all_roles = self.rc.env.role_names()
|
||||
other_role_names = ", ".join([r for r in all_roles if r != self.name])
|
||||
env_desc = f"You are in {self.rc.env.desc} with roles({other_role_names})."
|
||||
prefix += env_desc
|
||||
return prefix
|
||||
|
|
@ -478,7 +493,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
if not msg.cause_by:
|
||||
msg.cause_by = UserRequirement
|
||||
self.put_message(msg)
|
||||
|
||||
if not await self._observe():
|
||||
# If there is no new information, suspend and wait
|
||||
logger.debug(f"{self._setting}: no news. waiting.")
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -23,14 +23,12 @@ from metagpt.team import Team
|
|||
async def test_debate_two_roles():
|
||||
action1 = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
action2 = Action(name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
biden = Role(
|
||||
alex = Role(
|
||||
name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2]
|
||||
)
|
||||
trump = Role(
|
||||
name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1]
|
||||
)
|
||||
bob = Role(name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1])
|
||||
env = Environment(desc="US election live broadcast")
|
||||
team = Team(investment=10.0, env=env, roles=[biden, trump])
|
||||
team = Team(investment=10.0, env=env, roles=[alex, bob])
|
||||
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
|
||||
assert "Alex" in history
|
||||
|
|
@ -39,9 +37,9 @@ async def test_debate_two_roles():
|
|||
@pytest.mark.asyncio
|
||||
async def test_debate_one_role_in_env():
|
||||
action = Action(name="Debate", instruction="Express your opinion with emotion and don't repeat it")
|
||||
biden = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
|
||||
alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
|
||||
env = Environment(desc="US election live broadcast")
|
||||
team = Team(investment=10.0, env=env, roles=[biden])
|
||||
team = Team(investment=10.0, env=env, roles=[alex])
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
|
||||
assert "Alex" in history
|
||||
|
||||
|
|
@ -49,8 +47,8 @@ async def test_debate_one_role_in_env():
|
|||
@pytest.mark.asyncio
|
||||
async def test_debate_one_role():
|
||||
action = Action(name="Debate", instruction="Express your opinion with emotion and don't repeat it")
|
||||
biden = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
|
||||
msg: Message = await biden.run("Topic: climate change. Under 80 words per message.")
|
||||
alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
|
||||
msg: Message = await alex.run("Topic: climate change. Under 80 words per message.")
|
||||
|
||||
assert len(msg.content) > 10
|
||||
assert msg.sent_from == "metagpt.roles.role.Role"
|
||||
|
|
|
|||
|
|
@ -104,19 +104,12 @@ async def test_debate_two_roles():
|
|||
config.llm.model = "gpt-4-1106-preview"
|
||||
action1 = Action(config=config, name="AlexSay", instruction="Say your opinion with emotion and don't repeat it")
|
||||
action2 = Action(name="BobSay", instruction="Say your opinion with emotion and don't repeat it")
|
||||
biden = Role(
|
||||
alex = Role(
|
||||
name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2]
|
||||
)
|
||||
trump = Role(
|
||||
name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1]
|
||||
)
|
||||
bob = Role(name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1])
|
||||
env = Environment(desc="US election live broadcast")
|
||||
team = Team(investment=10.0, env=env, roles=[biden, trump])
|
||||
|
||||
print(action1.llm.system_prompt)
|
||||
print(action2.llm.system_prompt)
|
||||
print(biden.llm.system_prompt)
|
||||
print(trump.llm.system_prompt)
|
||||
team = Team(investment=10.0, env=env, roles=[alex, bob])
|
||||
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
|
||||
assert "Alex" in history
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue