2024-01-04 21:16:23 +08:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
"""
|
|
|
|
|
@Time : 2024/1/4 16:32
|
|
|
|
|
@Author : alexanderwu
|
|
|
|
|
@File : context.py
|
|
|
|
|
"""
|
|
|
|
|
import os
|
|
|
|
|
from pathlib import Path
|
2024-03-25 21:04:16 +08:00
|
|
|
from typing import Any, Dict, Optional
|
2024-01-04 21:16:23 +08:00
|
|
|
|
2024-01-09 17:39:09 +08:00
|
|
|
from pydantic import BaseModel, ConfigDict
|
2024-01-09 14:16:32 +08:00
|
|
|
|
2024-01-04 21:16:23 +08:00
|
|
|
from metagpt.config2 import Config
|
2024-02-29 10:14:15 +08:00
|
|
|
from metagpt.configs.llm_config import LLMConfig, LLMType
|
2024-01-04 21:16:23 +08:00
|
|
|
from metagpt.provider.base_llm import BaseLLM
|
2024-01-09 15:56:40 +08:00
|
|
|
from metagpt.provider.llm_provider_registry import create_llm_instance
|
2024-02-29 10:14:15 +08:00
|
|
|
from metagpt.utils.cost_manager import (
|
|
|
|
|
CostManager,
|
|
|
|
|
FireworksCostManager,
|
|
|
|
|
TokenCostManager,
|
|
|
|
|
)
|
2024-01-04 21:16:23 +08:00
|
|
|
from metagpt.utils.git_repository import GitRepository
|
2024-01-15 16:37:42 +08:00
|
|
|
from metagpt.utils.project_repo import ProjectRepo
|
2024-01-04 21:16:23 +08:00
|
|
|
|
|
|
|
|
|
2024-01-09 14:16:32 +08:00
|
|
|
class AttrDict(BaseModel):
    """Pydantic-compatible container whose keys can also be used as attributes.

    Every value is kept in the instance ``__dict__``. Reading an attribute that
    was never stored returns ``None`` instead of raising.
    """

    model_config = ConfigDict(extra="allow")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mirror all keyword values into __dict__ so attribute access finds them.
        store = self.__dict__
        store.update(kwargs)

    def __getattr__(self, key):
        # Only invoked when normal lookup fails: unknown keys read as None.
        store = self.__dict__
        return store.get(key, None)

    def __setattr__(self, key, value):
        # Write straight into __dict__, bypassing BaseModel.__setattr__.
        store = self.__dict__
        store[key] = value

    def __delattr__(self, key):
        store = self.__dict__
        if key not in store:
            raise AttributeError(f"No such attribute: {key}")
        del store[key]

    def set(self, key, val: Any):
        """Store ``val`` under ``key``."""
        store = self.__dict__
        store[key] = val

    def get(self, key, default: Any = None):
        """Return the value stored under ``key``, or ``default`` when absent."""
        store = self.__dict__
        return store.get(key, default)

    def remove(self, key):
        """Delete ``key`` if present; silently do nothing otherwise."""
        if key not in self.__dict__:
            return
        self.__delattr__(key)
|
|
|
|
|
|
2024-01-08 16:26:52 +08:00
|
|
|
|
2024-01-09 17:13:22 +08:00
|
|
|
class Context(BaseModel):
    """Env context for MetaGPT.

    Bundles the active configuration, optional project/git repositories, the
    context-wide cost manager and an ``AttrDict`` of free-form keyword values,
    and builds LLM instances from LLM configurations.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    kwargs: AttrDict = AttrDict()
    config: Config = Config.default()

    repo: Optional[ProjectRepo] = None
    git_repo: Optional[GitRepository] = None
    src_workspace: Optional[Path] = None
    cost_manager: CostManager = CostManager()

    # Most recent LLM created by ``llm()``. Not used as a cache: ``llm()`` builds a new one per call.
    _llm: Optional[BaseLLM] = None

    def new_environ(self) -> Dict[str, str]:
        """Return a copy of ``os.environ`` that the caller may modify freely."""
        return os.environ.copy()

    def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
        """Return the CostManager to attach to an LLM built from ``llm_config``.

        Fireworks and open-LLM backends get a fresh dedicated manager. Every other
        backend shares this context's own ``cost_manager`` object.
        """
        if llm_config.api_type == LLMType.FIREWORKS:
            return FireworksCostManager()
        elif llm_config.api_type == LLMType.OPEN_LLM:
            return TokenCostManager()
        else:
            return self.cost_manager

    def llm(self) -> BaseLLM:
        """Create an LLM from ``self.config.llm``, remember it in ``_llm`` and return it.

        A new instance is created on every call (fixme: support cache).
        """
        self._llm = self.llm_with_cost_manager_from_llm_config(self.config.llm)
        return self._llm

    def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM:
        """Create a new LLM from ``llm_config`` and ensure it has a cost manager.

        If the created LLM has no ``cost_manager``, one is chosen by
        ``_select_costmanager``. A new instance is created on every call.
        """
        llm = create_llm_instance(llm_config)
        if llm.cost_manager is None:
            llm.cost_manager = self._select_costmanager(llm_config)
        return llm

    def serialize(self) -> Dict[str, Any]:
        """Serialize the context's persistent state into a dictionary.

        Returns:
            Dict[str, Any]: A dictionary with ``workdir`` (the repo workdir as a string,
            or ``""`` when there is no repo), ``kwargs`` (a shallow copy of the stored
            keyword values) and ``cost_manager`` (the cost manager's JSON dump).
        """
        return {
            "workdir": str(self.repo.workdir) if self.repo else "",
            "kwargs": dict(self.kwargs.__dict__),
            "cost_manager": self.cost_manager.model_dump_json(),
        }

    def deserialize(self, serialized_data: Dict[str, Any]):
        """Restore state produced by ``serialize`` into this context.

        Args:
            serialized_data (Dict[str, Any]): A dictionary as returned by ``serialize``.
                An empty or ``None`` argument is a no-op. Missing or empty entries are
                skipped.
        """
        if not serialized_data:
            return

        workdir = serialized_data.get("workdir")
        if workdir:
            self.git_repo = GitRepository(local_path=workdir, auto_init=True)
            self.repo = ProjectRepo(self.git_repo)
            # The source workspace is expected at <workdir>/<workdir name>; adopt it only if it exists.
            src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
            if src_workspace.exists():
                self.src_workspace = src_workspace

        kwargs = serialized_data.get("kwargs")
        if kwargs:
            for k, v in kwargs.items():
                self.kwargs.set(k, v)

        cost_manager = serialized_data.get("cost_manager")
        if cost_manager:
            # model_validate_json is a classmethod that returns a *new* instance, so the
            # result must be assigned, otherwise the restored cost state is lost.
            # NOTE: LLMs created earlier that share the previous manager (default branch of
            # _select_costmanager) keep referencing that previous object.
            self.cost_manager = type(self.cost_manager).model_validate_json(cost_manager)
|