mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-20 15:38:09 +02:00
Merge branch 'feat_werewolf' of github.com:better629/MetaGPT into feat_werewolf
This commit is contained in:
commit
bdfec451de
17 changed files with 436 additions and 40 deletions
|
|
@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Literal, Optional
|
|||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from metagpt.configs.browser_config import BrowserConfig
|
||||
from metagpt.configs.embedding_config import EmbeddingConfig
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.configs.mermaid_config import MermaidConfig
|
||||
from metagpt.configs.redis_config import RedisConfig
|
||||
|
|
@ -47,6 +48,9 @@ class Config(CLIParams, YamlModel):
|
|||
# Key Parameters
|
||||
llm: LLMConfig
|
||||
|
||||
# RAG Embedding
|
||||
embedding: EmbeddingConfig = EmbeddingConfig()
|
||||
|
||||
# Global Proxy. Will be used if llm.proxy is not set
|
||||
proxy: str = ""
|
||||
|
||||
|
|
|
|||
50
metagpt/configs/embedding_config.py
Normal file
50
metagpt/configs/embedding_config.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class EmbeddingType(Enum):
    """Identifiers of the embedding API providers, keyed by their string value."""

    OPENAI = "openai"
    AZURE = "azure"
    GEMINI = "gemini"
    OLLAMA = "ollama"
|
||||
|
||||
|
||||
class EmbeddingConfig(YamlModel):
    """Settings for the embedding backend.

    Every field is optional and defaults to ``None``. An ``api_type`` given as
    an empty string is converted to ``None`` before validation.

    Examples:
    ---------
    api_type: "openai"
    api_key: "YOUR_API_KEY"

    api_type: "azure"
    api_key: "YOUR_API_KEY"
    base_url: "YOUR_BASE_URL"
    api_version: "YOUR_API_VERSION"

    api_type: "gemini"
    api_key: "YOUR_API_KEY"

    api_type: "ollama"
    base_url: "YOUR_BASE_URL"
    model: "YOUR_MODEL"
    """

    api_type: Optional[EmbeddingType] = None
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    api_version: Optional[str] = None

    model: Optional[str] = None
    embed_batch_size: Optional[int] = None

    @field_validator("api_type", mode="before")
    @classmethod
    def check_api_type(cls, v):
        """Turn an empty-string ``api_type`` into ``None``; return anything else unchanged."""
        return None if v == "" else v
|
||||
|
|
@ -43,7 +43,15 @@ class ZhiPuAILLM(BaseLLM):
|
|||
self.llm = ZhiPuModelAPI(api_key=self.api_key)
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
|
||||
max_tokens = self.config.max_token if self.config.max_token > 0 else 1024
|
||||
temperature = self.config.temperature if self.config.temperature > 0.0 else 0.3
|
||||
kwargs = {
|
||||
"model": self.model,
|
||||
"max_tokens": max_tokens,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
"temperature": temperature,
|
||||
}
|
||||
return kwargs
|
||||
|
||||
def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,9 @@ class GenericFactory:
|
|||
if creator:
|
||||
return creator(**kwargs)
|
||||
|
||||
self._raise_for_key(key)
|
||||
|
||||
def _raise_for_key(self, key: Any):
|
||||
raise ValueError(f"Creator not registered for key: {key}")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,37 +1,103 @@
|
|||
"""RAG Embedding Factory."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
from llama_index.embeddings.gemini import GeminiEmbedding
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.base import GenericFactory
|
||||
|
||||
|
||||
class RAGEmbeddingFactory(GenericFactory):
|
||||
"""Create LlamaIndex Embedding with MetaGPT's config."""
|
||||
"""Create LlamaIndex Embedding with MetaGPT's embedding config."""
|
||||
|
||||
def __init__(self):
|
||||
creators = {
|
||||
EmbeddingType.OPENAI: self._create_openai,
|
||||
EmbeddingType.AZURE: self._create_azure,
|
||||
EmbeddingType.GEMINI: self._create_gemini,
|
||||
EmbeddingType.OLLAMA: self._create_ollama,
|
||||
# For backward compatibility
|
||||
LLMType.OPENAI: self._create_openai,
|
||||
LLMType.AZURE: self._create_azure,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding:
|
||||
"""Key is LLMType, default use config.llm.api_type."""
|
||||
return super().get_instance(key or config.llm.api_type)
|
||||
def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding:
|
||||
"""Key is EmbeddingType."""
|
||||
return super().get_instance(key or self._resolve_embedding_type())
|
||||
|
||||
def _create_openai(self):
|
||||
return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url)
|
||||
def _resolve_embedding_type(self) -> EmbeddingType | LLMType:
|
||||
"""Resolves the embedding type.
|
||||
|
||||
def _create_azure(self):
|
||||
return AzureOpenAIEmbedding(
|
||||
azure_endpoint=config.llm.base_url,
|
||||
api_key=config.llm.api_key,
|
||||
api_version=config.llm.api_version,
|
||||
If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE.
|
||||
Raise TypeError if embedding type not found.
|
||||
"""
|
||||
if config.embedding.api_type:
|
||||
return config.embedding.api_type
|
||||
|
||||
if config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]:
|
||||
return config.llm.api_type
|
||||
|
||||
raise TypeError("To use RAG, please set your embedding in config2.yaml.")
|
||||
|
||||
def _create_openai(self) -> OpenAIEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key or config.llm.api_key,
|
||||
api_base=config.embedding.base_url or config.llm.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return OpenAIEmbedding(**params)
|
||||
|
||||
def _create_azure(self) -> AzureOpenAIEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key or config.llm.api_key,
|
||||
azure_endpoint=config.embedding.base_url or config.llm.base_url,
|
||||
api_version=config.embedding.api_version or config.llm.api_version,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return AzureOpenAIEmbedding(**params)
|
||||
|
||||
def _create_gemini(self) -> GeminiEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key,
|
||||
api_base=config.embedding.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return GeminiEmbedding(**params)
|
||||
|
||||
def _create_ollama(self) -> OllamaEmbedding:
|
||||
params = dict(
|
||||
base_url=config.embedding.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return OllamaEmbedding(**params)
|
||||
|
||||
def _try_set_model_and_batch_size(self, params: dict):
|
||||
"""Set the model_name and embed_batch_size only when they are specified."""
|
||||
if config.embedding.model:
|
||||
params["model_name"] = config.embedding.model
|
||||
|
||||
if config.embedding.embed_batch_size:
|
||||
params["embed_batch_size"] = config.embedding.embed_batch_size
|
||||
|
||||
def _raise_for_key(self, key: Any):
|
||||
raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}")
|
||||
|
||||
|
||||
get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""RAG LLM."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW
|
||||
|
|
@ -15,7 +15,7 @@ from pydantic import Field
|
|||
from metagpt.config2 import config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.async_helper import run_coroutine_in_new_loop
|
||||
from metagpt.utils.async_helper import NestAsyncio
|
||||
from metagpt.utils.token_counter import TOKEN_MAX
|
||||
|
||||
|
||||
|
|
@ -39,7 +39,8 @@ class RAGLLM(CustomLLM):
|
|||
|
||||
@llm_completion_callback()
|
||||
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
|
||||
return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))
|
||||
NestAsyncio.apply_once()
|
||||
return asyncio.get_event_loop().run_until_complete(self.acomplete(prompt, **kwargs))
|
||||
|
||||
@llm_completion_callback()
|
||||
async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
|
||||
|
|
|
|||
|
|
@ -1,15 +1,17 @@
|
|||
"""RAG schemas."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, ClassVar, Literal, Optional, Union
|
||||
|
||||
from chromadb.api.types import CollectionMetadata
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.core.vector_stores.types import VectorStoreQueryMode
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.rag.interface import RAGObject
|
||||
|
||||
|
||||
|
|
@ -32,7 +34,19 @@ class IndexRetrieverConfig(BaseRetrieverConfig):
|
|||
class FAISSRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for FAISS-based retrievers."""
|
||||
|
||||
dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.")
|
||||
dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.")
|
||||
|
||||
_embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
|
||||
EmbeddingType.GEMINI: 768,
|
||||
EmbeddingType.OLLAMA: 4096,
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_dimensions(self):
|
||||
if self.dimensions == 0:
|
||||
self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BM25RetrieverConfig(IndexRetrieverConfig):
|
||||
|
|
|
|||
|
|
@ -20,3 +20,18 @@ def run_coroutine_in_new_loop(coroutine) -> Any:
|
|||
new_loop.call_soon_threadsafe(new_loop.stop)
|
||||
t.join()
|
||||
new_loop.close()
|
||||
|
||||
|
||||
class NestAsyncio:
    """Make asyncio event loop reentrant."""

    # Class-wide flag recording whether `nest_asyncio.apply()` has already run.
    is_applied = False

    @classmethod
    def apply_once(cls):
        """Call `nest_asyncio.apply()` on the first invocation; later calls do nothing."""
        if cls.is_applied:
            return

        # The import is only executed on the call that actually performs the patching.
        import nest_asyncio

        nest_asyncio.apply()
        cls.is_applied = True
|
||||
|
|
|
|||
40
metagpt/utils/stream_pipe.py
Normal file
40
metagpt/utils/stream_pipe.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Time : 2024/3/27 10:00
|
||||
# @Author : leiwu30
|
||||
# @File : stream_pipe.py
|
||||
# @Version : None
|
||||
# @Description : None
|
||||
|
||||
import json
|
||||
import time
|
||||
from multiprocessing import Pipe
|
||||
|
||||
|
||||
class StreamPipe:
    """Pipe-backed message channel that can also render messages as stream chunks.

    Each instance owns its own `multiprocessing.Pipe` connection pair, so a
    message sent through one `StreamPipe` is never delivered to another one.
    """

    # Default value of the flag; an instance overrides it by plain assignment.
    finish: bool = False

    # Template of the chunk serialized by `msg2stream`. It is treated as
    # read-only: `msg2stream` builds a new dict per call instead of mutating it.
    format_data = {
        "id": "chatcmpl-96bVnBOOyPFZZxEoTIGbdpFcVEnur",
        "object": "chat.completion.chunk",
        "created": 1711361191,
        "model": "gpt-3.5-turbo-0125",
        "system_fingerprint": "fp_3bc1b5746c",
        "choices": [
            {"index": 0, "delta": {"role": "assistant", "content": "content"}, "logprobs": None, "finish_reason": None}
        ],
    }

    def __init__(self):
        """Create the connection pair used by `set_message` and `get_message`."""
        # Created per instance: a pair built in the class body would be shared
        # by every StreamPipe and mix up messages of unrelated streams.
        self.parent_conn, self.child_conn = Pipe()

    def set_message(self, msg):
        """Send `msg` through the parent end of the pipe."""
        self.parent_conn.send(msg)

    def get_message(self, timeout: int = 3):
        """Return the next message from the child end of the pipe.

        Args:
            timeout: Maximum number of seconds to wait for a message.

        Returns:
            The received object, or `None` if nothing arrived within `timeout`.
        """
        if self.child_conn.poll(timeout):
            return self.child_conn.recv()
        return None

    def msg2stream(self, msg):
        """Serialize `msg` as one `data: <json>` line encoded as UTF-8 bytes.

        The chunk is a copy of `format_data` with `created` set to the current
        Unix time and the first choice's `delta.content` set to `msg`.

        Args:
            msg: Value placed in `choices[0].delta.content`.

        Returns:
            The encoded line, terminated by a single newline.
        """
        template = self.format_data
        first_choice, *other_choices = template["choices"]
        # Re-using existing keys keeps the key order, so the serialized JSON
        # is laid out exactly as it would be when updating the template in place.
        chunk = {
            **template,
            "created": int(time.time()),
            "choices": [
                {**first_choice, "delta": {**first_choice["delta"], "content": msg}},
                *other_choices,
            ],
        }
        return f"data: {json.dumps(chunk, ensure_ascii=False)}\n".encode("utf-8")
|
||||
|
|
@ -28,6 +28,7 @@ TOKEN_COSTS = {
|
|||
"gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-turbo": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator
|
||||
|
|
@ -147,6 +148,7 @@ FIREWORKS_GRADE_TOKEN_COSTS = {
|
|||
TOKEN_MAX = {
|
||||
"gpt-4-0125-preview": 128000,
|
||||
"gpt-4-turbo-preview": 128000,
|
||||
"gpt-4-turbo": 128000,
|
||||
"gpt-4-1106-preview": 128000,
|
||||
"gpt-4-vision-preview": 128000,
|
||||
"gpt-4-1106-vision-preview": 128000,
|
||||
|
|
@ -202,6 +204,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
|
|||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-0125-preview",
|
||||
"gpt-4-1106-preview",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue