SurfSense/surfsense_backend/tests/integration/composio/test_oauth_callback.py

114 lines
3.3 KiB
Python
Raw Normal View History

from uuid import UUID
import httpx
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
SearchSpace,
User,
)
from app.utils.oauth_security import OAuthStateManager
# Module-level pytest mark: tags every test collected from this file as `integration`.
pytestmark = pytest.mark.integration
def _state_for(space_id: int, user_id: UUID, toolkit_id: str = "googledrive") -> str:
    """Produce an OAuth ``state`` value to send to the callback under test.

    The value comes from ``OAuthStateManager.generate_secure_state`` on a
    manager constructed with ``config.SECRET_KEY``.

    Args:
        space_id: Search-space id to embed in the state.
        user_id: User id to embed in the state.
        toolkit_id: Toolkit identifier to embed; defaults to ``"googledrive"``.

    Returns:
        The generated state string.
    """
    # NOTE(review): presumably the same secret the application uses to
    # validate incoming states -- confirm against the callback route.
    state_manager = OAuthStateManager(config.SECRET_KEY)
    state_fields = {
        "space_id": space_id,
        "user_id": user_id,
        "toolkit_id": toolkit_id,
    }
    return state_manager.generate_secure_state(**state_fields)
async def _drive_connectors(
    session: AsyncSession,
    *,
    user_id: UUID,
    search_space_id: int,
) -> list[SearchSourceConnector]:
    """Fetch the Composio Google Drive connectors for a user in a search space.

    Args:
        session: Async SQLAlchemy session used to run the query.
        user_id: Value matched against ``SearchSourceConnector.user_id``.
        search_space_id: Value matched against
            ``SearchSourceConnector.search_space_id``.

    Returns:
        A list of every ``SearchSourceConnector`` row that matches both ids and
        has type ``COMPOSIO_GOOGLE_DRIVE_CONNECTOR`` (empty when none match).
    """
    # All criteria passed to a single ``where`` call are AND-ed together.
    criteria = (
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.search_space_id == search_space_id,
        SearchSourceConnector.connector_type
        == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
    )
    statement = select(SearchSourceConnector).where(*criteria)
    query_result = await session.execute(statement)
    rows = query_result.scalars().all()
    return list(rows)
async def test_callback_with_error_param_redirects_to_denied_page(
    client: httpx.AsyncClient,
    db_session: AsyncSession,
    db_user: User,
    db_search_space: SearchSpace,
):
    """A callback carrying ``error`` redirects to the denied page and stores nothing.

    Sends the callback with a generated state plus ``error=access_denied`` and
    checks that:

    * the response is a redirect (302, 303 or 307),
    * the ``location`` header contains the dashboard callback path for this
      search space with ``error=composio_oauth_denied``,
    * no Composio Google Drive connector exists for the user and search space.
    """
    redirect_codes = (302, 303, 307)

    state = _state_for(db_search_space.id, db_user.id)
    callback_url = (
        f"/api/v1/auth/composio/connector/callback?state={state}&error=access_denied"
    )
    denied_response = await client.get(callback_url)
    assert denied_response.status_code in redirect_codes

    redirect_target = denied_response.headers["location"]
    expected_fragment = (
        f"/dashboard/{db_search_space.id}/connectors/callback?"
        "error=composio_oauth_denied"
    )
    assert expected_fragment in redirect_target

    # The denied flow must not leave a connector row behind.
    stored = await _drive_connectors(
        db_session,
        user_id=db_user.id,
        search_space_id=db_search_space.id,
    )
    assert stored == []
async def test_second_oauth_for_same_toolkit_takes_reconnection_branch(
    client: httpx.AsyncClient,
    db_session: AsyncSession,
    db_user: User,
    db_search_space: SearchSpace,
):
    """A second callback for the same toolkit updates the existing connector.

    Hits the callback twice for the same user and search space with different
    ``connectedAccountId`` values. After the first call exactly one connector
    exists and holds the first account id. After the second call there is
    still exactly one connector, with the same primary key, now holding the
    second account id.
    """
    redirect_codes = (302, 303, 307)
    callback_path = "/api/v1/auth/composio/connector/callback"
    account_key = "composio_connected_account_id"

    # --- first OAuth round-trip: creates the connector -------------------
    first_state = _state_for(db_search_space.id, db_user.id)
    initial_response = await client.get(
        callback_path
        + f"?state={first_state}&connectedAccountId=fake-acct-googledrive-first"
    )
    assert initial_response.status_code in redirect_codes

    after_first = await _drive_connectors(
        db_session,
        user_id=db_user.id,
        search_space_id=db_search_space.id,
    )
    assert len(after_first) == 1
    first_connector = after_first[0]
    assert first_connector.config[account_key] == "fake-acct-googledrive-first"

    # --- second OAuth round-trip: must reuse the same row ----------------
    second_state = _state_for(db_search_space.id, db_user.id)
    repeat_response = await client.get(
        callback_path
        + f"?state={second_state}&connectedAccountId=fake-acct-googledrive-second"
    )
    assert repeat_response.status_code in redirect_codes

    after_second = await _drive_connectors(
        db_session,
        user_id=db_user.id,
        search_space_id=db_search_space.id,
    )
    assert len(after_second) == 1
    assert after_second[0].id == first_connector.id
    assert after_second[0].config[account_key] == "fake-acct-googledrive-second"