Mirror of https://github.com/FoundationAgents/MetaGPT.git (synced 2026-04-25 00:36:55 +02:00)
update ollama

commit fdb834674d (parent 3fefc48efa)
3 changed files with 184 additions and 61 deletions

@@ -26,7 +26,9 @@ class LLMType(Enum):
     GEMINI = "gemini"
     METAGPT = "metagpt"
     AZURE = "azure"
-    OLLAMA = "ollama"
+    OLLAMA = "ollama"  # /chat at ollama api
+    OLLAMA_GENERATE = "ollama.generate"  # /generate at ollama api
+    OLLAMA_EMBEDDING = "ollama.embeddings"  # /embeddings at ollama api
     QIANFAN = "qianfan"  # Baidu BCE
     DASHSCOPE = "dashscope"  # Aliyun LingJi DashScope
     MOONSHOT = "moonshot"

@@ -104,7 +106,8 @@ class LLMConfig(YamlModel):
             root_config_path = CONFIG_ROOT / "config2.yaml"
             if root_config_path.exists():
                 raise ValueError(
-                    f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \nthe former will overwrite the latter. This may cause unexpected result.\n"
+                    f"Please set your API key in {root_config_path}. If you also set your config in {
+                        repo_config_path}, \nthe former will overwrite the latter. This may cause unexpected result.\n"
                 )
             elif repo_config_path.exists():
                 raise ValueError(f"Please set your API key in {repo_config_path}")
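
The new LLMType values above are what you put in the api_type field of an LLMConfig (config2.yaml) entry. A minimal sketch of selecting the generate-style provider programmatically, assuming a reachable local Ollama server; the base_url, model name and the create_llm_instance helper call are illustrative rather than prescribed by this commit:

from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.llm_provider_registry import create_llm_instance

# "llama3" is a placeholder; any model pulled into the local Ollama instance works
config = LLMConfig(api_type="ollama.generate", base_url="http://127.0.0.1:11434/api", model="llama3")
llm = create_llm_instance(config)  # resolves the provider registered for LLMType.OLLAMA_GENERATE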

@@ -13,6 +13,7 @@ import time
 from contextlib import asynccontextmanager
 from enum import Enum
 from typing import (
+    Any,
     AsyncGenerator,
     AsyncIterator,
     Dict,

@@ -121,7 +122,7 @@ def logfmt(props):
 
 
 class OpenAIResponse:
-    def __init__(self, data, headers):
+    def __init__(self, data: Union[bytes, Any], headers: dict):
         self._headers = headers
         self.data = data
 
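
The widened annotation reflects that the requestor wraps either a whole response body or a raw streamed chunk in OpenAIResponse; the Ollama message classes further down decode it with a plain json.loads. A minimal illustration (the payload is made up):

resp = OpenAIResponse(b'{"done": true, "eval_count": 7}', headers={})
decoded = json.loads(resp.data.decode("utf-8"))  # -> {"done": True, "eval_count": 7}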

@@ -3,7 +3,8 @@
 # @Desc : self-host open llm model with ollama which isn't openai-api-compatible
 
 import json
-from typing import AsyncGenerator, Tuple
+from enum import Enum, auto
+from typing import AsyncGenerator, Optional, Tuple
 
 from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.const import USE_CONFIG_TIMEOUT

@@ -14,6 +15,114 @@ from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.utils.cost_manager import TokenCostManager
 
 
+class OllamaMessageAPI(Enum):
+    # default
+    CHAT = auto()
+    GENERATE = auto()
+    EMBED = auto()
+
+
+class OllamaMessageBase:
+    api_type = OllamaMessageAPI.CHAT
+
+    def __init__(self, model: str, **additional_kwargs) -> None:
+        self.model, self.additional_kwargs = model, additional_kwargs
+
+    @property
+    def api_suffix(self) -> str:
+        raise NotImplementedError
+
+    def apply(self, messages: list[dict]) -> dict:
+        raise NotImplementedError
+
+    def decode(self, response: OpenAIResponse) -> dict:
+        return json.loads(response.data.decode("utf-8"))
+
+    def _parse_input_msg(self, msg: dict) -> Tuple[Optional[str], Optional[str]]:
+        if "role" in msg:
+            return msg["content"], None
+        elif "type" in msg:
+            tpe = msg["type"]
+            if tpe == "text":
+                return msg["text"], None
+            elif tpe == "image_url":
+                return None, msg["image_url"]["url"]
+            else:
+                raise ValueError
+        else:
+            raise ValueError
+
+
+class OllamaMessageMeta(type):
+    registed_message = {}
+
+    def __init__(cls, name, bases, attrs):
+        super().__init__(name, bases, attrs)
+        for base in bases:
+            if issubclass(base, OllamaMessageBase):
+                api_type = attrs["api_type"]
+                if isinstance(api_type, list):
+                    for tpe in api_type:
+                        assert tpe not in OllamaMessageMeta.registed_message, "api_type already exist"
+                        assert isinstance(tpe, OllamaMessageAPI), "api_type not support"
+                        OllamaMessageMeta.registed_message[tpe] = cls
+                else:
+                    assert api_type not in OllamaMessageMeta.registed_message, "api_type already exist"
+                    assert isinstance(api_type, OllamaMessageAPI), "api_type not support"
+                    OllamaMessageMeta.registed_message[api_type] = cls
+
+    @classmethod
+    def get_message(cls, input_type: OllamaMessageAPI) -> type[OllamaMessageBase]:
+        return cls.registed_message[input_type]
+
+
+class OllamaMessageChat(OllamaMessageBase, metaclass=OllamaMessageMeta):
+    api_type = OllamaMessageAPI.CHAT
+
+    @property
+    def api_suffix(self) -> str:
+        return "/chat"
+
+    def apply(self, messages: list[dict]) -> dict:
+        prompts = []
+        images = []
+        for msg in messages:
+            prompt, image = self._parse_input_msg(msg)
+            if prompt:
+                prompts.append(prompt)
+            if image:
+                images.append(image)
+        sends = {"model": self.model, "prompt": "\n".join(prompts), "images": images}
+        sends.update(self.additional_kwargs)
+        return sends
+
+
+class OllamaMessageGenerate(OllamaMessageChat, metaclass=OllamaMessageMeta):
+    api_type = OllamaMessageAPI.GENERATE
+
+    @property
+    def api_suffix(self) -> str:
+        return "/generate"
+
+
+class OllamaMessageEmbed(OllamaMessageBase, metaclass=OllamaMessageMeta):
+    api_type = OllamaMessageAPI.EMBED
+
+    @property
+    def api_suffix(self) -> str:
+        return "/embeddings"
+
+    def apply(self, messages: list[dict]) -> dict:
+        prompts = []
+        for msg in messages:
+            prompt, _ = self._parse_input_msg(msg)
+            if prompt:
+                prompts.append(prompt)
+        sends = {"model": self.model, "prompt": "\n".join(prompts)}
+        sends.update(self.additional_kwargs)
+        return sends
+
+
 @register_provider(LLMType.OLLAMA)
 class OllamaLLM(BaseLLM):
     """
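
Taken together, OllamaMessageMeta registers each concrete message class under its api_type, so a provider only needs an OllamaMessageAPI value to obtain the matching endpoint suffix and payload builder. A small usage sketch based on the classes above (model name and prompt are placeholders):

msg_cls = OllamaMessageMeta.get_message(OllamaMessageAPI.GENERATE)  # -> OllamaMessageGenerate
msg = msg_cls(model="llama3", stream=False)
msg.api_suffix  # "/generate"
msg.apply(messages=[{"role": "user", "content": "hello"}])
# -> {"model": "llama3", "prompt": "hello", "images": [], "stream": False}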

@@ -21,43 +130,45 @@ class OllamaLLM(BaseLLM):
     """
 
     def __init__(self, config: LLMConfig):
-        self.__init_ollama(config)
         self.client = GeneralAPIRequestor(base_url=config.base_url)
         self.config = config
-        self.suffix_url = "/chat"
         self.http_method = "post"
         self.use_system_prompt = False
         self.cost_manager = TokenCostManager()
+        self.__init_ollama(config)
 
+    def _get_headers(self):
+        return (
+            None
+            if not self.config.api_key or self.config.api_key == "sk-"
+            else {"Authorization": f"Bearer {self.config.api_key}"}
+        )
+
+    @property
+    def _llama_api_inuse(self) -> OllamaMessageAPI:
+        return OllamaMessageAPI.CHAT
+
+    @property
+    def _llama_api_kwargs(self) -> dict:
+        return {"options": {"temperature": 0.3}, "stream": self.config.stream}
+
     def __init_ollama(self, config: LLMConfig):
         assert config.base_url, "ollama base url is required!"
         self.model = config.model
         self.pricing_plan = self.model
-
-    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
-        return {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
-
-    def get_choice_text(self, resp: dict) -> str:
-        """get the resp content from llm response"""
-        assist_msg = resp.get("message", {})
-        if assist_msg.get("role", None) == "assistant":  # chat
-            return assist_msg.get("content")
-        else:  # llava
-            return resp["response"]
+        ollama_message = OllamaMessageMeta.get_message(self._llama_api_inuse)
+        self.ollama_message = ollama_message(model=self.model, **self._llama_api_kwargs)
+        self.headers = self._get_headers()
 
     def get_usage(self, resp: dict) -> dict:
         return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)}
 
-    def _decode_and_load(self, openai_resp: OpenAIResponse, encoding: str = "utf-8") -> dict:
-        return json.loads(openai_resp.data.decode(encoding))
-
     async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
-        messages, fixed_suffix_url = self._apply_llava(messages)
         resp, _, _ = await self.client.arequest(
             method=self.http_method,
-            url=fixed_suffix_url,
-            headers=self._get_headers(),
-            params=messages,
+            url=self.ollama_message.api_suffix,
+            headers=self.headers,
+            params=self.ollama_message.apply(messages=messages),
             request_timeout=self.get_timeout(timeout),
         )
         if isinstance(resp, AsyncGenerator):
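
With this hunk the provider resolves its message helper once in __init_ollama and builds every request from it, so the default chat provider posts to /chat with a payload shaped by OllamaMessageChat plus the kwargs from _llama_api_kwargs. A rough sketch of the resulting request params (model and prompt are placeholders; the stream flag follows config.stream):

chat_msg = OllamaMessageChat(model="llama3", options={"temperature": 0.3}, stream=False)
chat_msg.apply(messages=[{"role": "user", "content": "hi"}])
# -> {"model": "llama3", "prompt": "hi", "images": [], "options": {"temperature": 0.3}, "stream": False}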

@@ -65,30 +176,29 @@ class OllamaLLM(BaseLLM):
         elif isinstance(resp, OpenAIResponse):
             return self._processing_openai_response(resp)
         else:
-            raise NotImplementedError
+            raise ValueError
 
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
 
     async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
-        messages, fixed_suffix_url = self._apply_llava(messages, stream=True)
-        stream_resp, _, _ = await self.client.arequest(
+        resp, _, _ = await self.client.arequest(
             method=self.http_method,
-            url=fixed_suffix_url,
-            headers=self._get_headers(),
-            stream=True,
-            params=messages,
+            url=self.ollama_message.api_suffix,
+            headers=self.headers,
+            params=self.ollama_message.apply(messages=messages),
             request_timeout=self.get_timeout(timeout),
+            stream=True,
         )
-        if isinstance(stream_resp, AsyncGenerator):
-            return await self._processing_openai_response_async_generator(stream_resp)
-        elif isinstance(stream_resp, OpenAIResponse):
-            return self._processing_openai_response(stream_resp)
+        if isinstance(resp, AsyncGenerator):
+            return await self._processing_openai_response_async_generator(resp)
+        elif isinstance(resp, OpenAIResponse):
+            return self._processing_openai_response(resp)
         else:
-            raise NotImplementedError
+            raise ValueError
 
     def _processing_openai_response(self, openai_resp: OpenAIResponse):
-        resp = self._decode_and_load(openai_resp)
+        resp = self.ollama_message.decode(openai_resp)
         usage = self.get_usage(resp)
         self._update_costs(usage)
         return resp
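
For cost tracking, the decoded Ollama payload goes straight to get_usage, which reads Ollama's prompt_eval_count / eval_count counters (present on non-streamed responses and on the final streamed chunk). An illustrative payload with made-up counts:

final_resp = {"done": True, "prompt_eval_count": 26, "eval_count": 298}
usage = {"prompt_tokens": final_resp.get("prompt_eval_count", 0), "completion_tokens": final_resp.get("eval_count", 0)}
# -> {"prompt_tokens": 26, "completion_tokens": 298}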

@@ -97,7 +207,7 @@ class OllamaLLM(BaseLLM):
         collected_content = []
         usage = {}
         async for raw_chunk in ag_openai_resp:
-            chunk = self._decode_and_load(raw_chunk)
+            chunk = self.ollama_message.decode(raw_chunk)
 
             if not chunk.get("done", False):
                 content = self.get_choice_text(chunk)

@@ -112,28 +222,37 @@ class OllamaLLM(BaseLLM):
         full_content = "".join(collected_content)
         return full_content
 
-    def _get_headers(self):
-        return (
-            None
-            if not self.config.api_key or self.config.api_key == "sk-"
-            else {"Authorization": f"Bearer {self.config.api_key}"}
+
+@register_provider(LLMType.OLLAMA_GENERATE)
+class OllamaGenerate(OllamaLLM):
+    @property
+    def _llama_api_inuse(self) -> OllamaMessageAPI:
+        return OllamaMessageAPI.GENERATE
+
+    @property
+    def _llama_api_kwargs(self) -> dict:
+        return {"options": {"temperature": 0.3}, "stream": self.config.stream}
+
+
+@register_provider(LLMType.OLLAMA_EMBEDDING)
+class OllamaEmbeddings(OllamaLLM):
+    @property
+    def _llama_api_inuse(self) -> OllamaMessageAPI:
+        return OllamaMessageAPI.EMBED
+
+    @property
+    def _llama_api_kwargs(self) -> dict:
+        return {"options": {"temperature": 0.3}}
+
+    async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
+        resp, _, _ = await self.client.arequest(
+            method=self.http_method,
+            url=self.ollama_message.api_suffix,
+            headers=self.headers,
+            params=self.ollama_message.apply(messages=messages),
+            request_timeout=self.get_timeout(timeout),
         )
+        return self.ollama_message.decode(resp)["embedding"]
 
-    def _apply_llava(self, messages: list[dict], stream: bool = False) -> Tuple[dict, str]:
-        llava = False
-        if isinstance(messages[0]["content"], str):
-            return self._const_kwargs(messages, stream), self.suffix_url
-
-        if any(len(msg["content"]) >= 2 for msg in messages):
-            assert all(len(msg["content"]) >= 2 for msg in messages), "input should have the same api type"
-            llava = True
-        if not llava:
-            return self._const_kwargs(messages, stream), self.suffix_url
-
-        assert len(messages) <= 1, "not support batch massages in llava calling images"
-        contents = messages[0]["content"]
-        return {
-            "model": self.model,
-            "prompt": contents[0]["text"],
-            "images": [i["image_url"]["url"][23:] for i in contents[1:]],
-        }, "/generate"
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+        return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
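
A minimal end-to-end sketch for the new embeddings provider, assuming a reachable local Ollama server; the config values and model name are placeholders, not part of this commit:

import asyncio

from metagpt.configs.llm_config import LLMConfig

config = LLMConfig(api_type="ollama.embeddings", base_url="http://127.0.0.1:11434/api", model="nomic-embed-text")
llm = OllamaEmbeddings(config)
vector = asyncio.run(llm.acompletion([{"role": "user", "content": "hello world"}]))
# vector is the raw list of floats taken from the /embeddings response, i.e. resp["embedding"]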