diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py
index 4ff49befe..e4b3a3f17 100644
--- a/metagpt/provider/google_gemini_api.py
+++ b/metagpt/provider/google_gemini_api.py
@@ -1,9 +1,10 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 # @Desc   : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart
-
+import json
 import os
-from typing import Optional, Union
+from dataclasses import asdict
+from typing import List, Optional, Union
 
 import google.generativeai as genai
 from google.ai import generativelanguage as glm
@@ -11,6 +12,7 @@ from google.generativeai.generative_models import GenerativeModel
 from google.generativeai.types import content_types
 from google.generativeai.types.generation_types import (
     AsyncGenerateContentResponse,
+    BlockedPromptException,
     GenerateContentResponse,
     GenerationConfig,
 )
@@ -141,7 +143,11 @@ class GeminiLLM(BaseLLM):
         )
         collected_content = []
         async for chunk in resp:
-            content = chunk.text
+            try:
+                content = chunk.text
+            except Exception as e:
+                logger.warning(f"messages: {messages}\nerrors: {e}\n{BlockedPromptException(str(chunk))}")
+                raise BlockedPromptException(str(chunk))
             log_llm_stream(content)
             collected_content.append(content)
         log_llm_stream("\n")
@@ -150,3 +156,10 @@
         usage = await self.aget_usage(messages, full_content)
         self._update_costs(usage)
         return full_content
+
+    def list_models(self) -> List:
+        models = []
+        for model in genai.list_models(page_size=100):
+            models.append(asdict(model))
+        logger.info(json.dumps(models))
+        return models
diff --git a/requirements.txt b/requirements.txt
index a0ce1d1ac..da8aa26b2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -60,7 +60,7 @@ gitignore-parser==0.1.9
 # connexion[uvicorn]~=3.0.5  # Used by metagpt/tools/openapi_v3_hello.py
 websockets~=11.0
 networkx~=3.2.1
-google-generativeai==0.3.2
+google-generativeai==0.4.1
 playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py
 anytree
 ipywidgets==8.1.1