From 2bca5c9d064c4b45bcfa23d1d6961788cc38bdcf Mon Sep 17 00:00:00 2001 From: EvensXia Date: Wed, 30 Oct 2024 09:55:13 +0800 Subject: [PATCH] fix embedding output --- metagpt/provider/ollama_api.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 4537a8a2c..3f7d20d0a 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -49,7 +49,7 @@ class OllamaMessageBase: if tpe == "text": return msg["text"], None elif tpe == "image_url": - return None, msg["image_url"]["url"][self._image_b64_rms:] + return None, msg["image_url"]["url"][self._image_b64_rms :] else: raise ValueError else: @@ -300,6 +300,10 @@ class OllamaEmbeddings(OllamaLLM): def _llama_api_kwargs(self) -> dict: return {"options": {"temperature": 0.3}} + @property + def _llama_embedding_key(self) -> str: + return "embedding" + async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict: resp, _, _ = await self.client.arequest( method=self.http_method, @@ -307,7 +311,7 @@ class OllamaEmbeddings(OllamaLLM): params=self.ollama_message.apply(messages=messages), request_timeout=self.get_timeout(timeout), ) - return self.ollama_message.decode(resp)["embedding"] + return self.ollama_message.decode(resp)[self._llama_embedding_key] async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str: return await self._achat_completion(messages, timeout=self.get_timeout(timeout)) @@ -321,3 +325,7 @@ class OllamaEmbed(OllamaEmbeddings): @property def _llama_api_inuse(self) -> OllamaMessageAPI: return OllamaMessageAPI.EMBED + + @property + def _llama_embedding_key(self) -> str: + return "embeddings"