SurfSense/surfsense_backend/tests/integration/composio/test_oauth_callback.py

113 lines
3.3 KiB
Python

from uuid import UUID
import httpx
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
SearchSpace,
User,
)
from app.utils.oauth_security import OAuthStateManager
# Apply the ``integration`` marker to every test in this module, so the
# suite can be selected or deselected with ``-m integration``.
pytestmark = pytest.mark.integration
def _state_for(space_id: int, user_id: UUID, toolkit_id: str = "googledrive") -> str:
    """Produce an OAuth ``state`` value for the Composio connector callback.

    The value is generated by an ``OAuthStateManager`` keyed with
    ``config.SECRET_KEY``, the same manager class the callback endpoint is
    expected to validate against.

    Args:
        space_id: ID of the search space the state is issued for.
        user_id: ID of the user the state is issued for.
        toolkit_id: Toolkit identifier passed through to
            ``generate_secure_state``; defaults to ``"googledrive"``.

    Returns:
        The string returned by ``OAuthStateManager.generate_secure_state``.
    """
    manager = OAuthStateManager(config.SECRET_KEY)
    return manager.generate_secure_state(
        space_id=space_id,
        user_id=user_id,
        toolkit_id=toolkit_id,
    )
async def _drive_connectors(
    session: AsyncSession,
    *,
    user_id: UUID,
    search_space_id: int,
) -> list[SearchSourceConnector]:
    """Return the Composio Google Drive connectors for a user in a search space.

    The tests call this before and after hitting the callback endpoint using
    the same ``session``. By default, SQLAlchemy does not overwrite the loaded
    attributes of instances already in the session's identity map when a
    query returns their rows again. The query therefore sets
    ``populate_existing`` so the returned connectors always reflect the
    current database rows. This matters when the endpoint wrote through a
    different session/connection; whether it does is not visible here.

    Args:
        session: Session used to run the lookup.
        user_id: Owner of the connectors.
        search_space_id: Search space the connectors belong to.

    Returns:
        All matching ``COMPOSIO_GOOGLE_DRIVE_CONNECTOR`` connectors, as a list.
    """
    stmt = (
        select(SearchSourceConnector)
        .where(
            SearchSourceConnector.user_id == user_id,
            SearchSourceConnector.search_space_id == search_space_id,
            SearchSourceConnector.connector_type
            == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
        )
        # Refresh identity-mapped instances from the returned rows instead of
        # handing back previously loaded (possibly stale) attribute values.
        .execution_options(populate_existing=True)
    )
    result = await session.execute(stmt)
    return list(result.scalars().all())
async def test_callback_with_error_param_redirects_to_denied_page(
    client: httpx.AsyncClient,
    db_session: AsyncSession,
    db_user: User,
    db_search_space: SearchSpace,
):
    """A callback carrying ``error=access_denied`` redirects to the denied page.

    The response must be a 302/303/307 redirect whose ``Location`` points at
    the dashboard connector-callback page with ``error=composio_oauth_denied``.
    No Composio Google Drive connector may exist for the user afterwards.
    """
    state = _state_for(db_search_space.id, db_user.id)

    # Pass the query via ``params`` so httpx URL-encodes it; interpolating the
    # state into the URL by hand would corrupt reserved characters in it.
    response = await client.get(
        "/api/v1/auth/composio/connector/callback",
        params={"state": state, "error": "access_denied"},
    )

    assert response.status_code in {302, 303, 307}
    location = response.headers["location"]
    assert (
        f"/dashboard/{db_search_space.id}/connectors/callback?"
        "error=composio_oauth_denied"
    ) in location

    connectors = await _drive_connectors(
        db_session,
        user_id=db_user.id,
        search_space_id=db_search_space.id,
    )
    assert connectors == []
async def test_second_oauth_for_same_toolkit_takes_reconnection_branch(
    client: httpx.AsyncClient,
    db_session: AsyncSession,
    db_user: User,
    db_search_space: SearchSpace,
):
    """A second OAuth callback for the same toolkit reuses the existing connector.

    After two successful callbacks for the same user, search space and
    toolkit, exactly one Google Drive connector must exist. It must be the
    connector created by the first callback, and its
    ``composio_connected_account_id`` must hold the second callback's
    account ID.
    """
    callback_path = "/api/v1/auth/composio/connector/callback"

    # First OAuth completion: creates the connector.
    first_state = _state_for(db_search_space.id, db_user.id)
    first_response = await client.get(
        callback_path,
        params={
            "state": first_state,
            "connectedAccountId": "fake-acct-googledrive-first",
        },
    )
    assert first_response.status_code in {302, 303, 307}

    first_connectors = await _drive_connectors(
        db_session,
        user_id=db_user.id,
        search_space_id=db_search_space.id,
    )
    assert len(first_connectors) == 1
    first_connector = first_connectors[0]
    # Record the id now, so the later comparison does not depend on the ORM
    # instance's state after the second request.
    first_connector_id = first_connector.id
    assert first_connector.config["composio_connected_account_id"] == (
        "fake-acct-googledrive-first"
    )

    # Second OAuth completion for the same toolkit: should update in place.
    second_state = _state_for(db_search_space.id, db_user.id)
    second_response = await client.get(
        callback_path,
        params={
            "state": second_state,
            "connectedAccountId": "fake-acct-googledrive-second",
        },
    )
    assert second_response.status_code in {302, 303, 307}

    second_connectors = await _drive_connectors(
        db_session,
        user_id=db_user.id,
        search_space_id=db_search_space.id,
    )
    assert len(second_connectors) == 1
    assert second_connectors[0].id == first_connector_id
    assert second_connectors[0].config["composio_connected_account_id"] == (
        "fake-acct-googledrive-second"
    )