mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: add authentication for OSS (#167)
* feat: add authentication for OSS Fixes #157 and #156 * fix: fix token generation * fix: limit fastapi workers to 1
This commit is contained in:
parent
0791975864
commit
642cc34e8c
48 changed files with 994 additions and 303 deletions
|
|
@ -0,0 +1,34 @@
|
|||
"""add user email and password
|
||||
|
||||
Revision ID: 6fd8fac02883
|
||||
Revises: 6d2f94baf4b7
|
||||
Create Date: 2026-02-20 11:43:47.679075
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "6fd8fac02883"
|
||||
down_revision: Union[str, None] = "6d2f94baf4b7"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("users", sa.Column("email", sa.String(), nullable=True))
|
||||
op.add_column("users", sa.Column("password_hash", sa.String(), nullable=True))
|
||||
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_users_email"), table_name="users")
|
||||
op.drop_column("users", "password_hash")
|
||||
op.drop_column("users", "email")
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -25,6 +25,7 @@ DATABASE_URL = os.environ["DATABASE_URL"]
|
|||
REDIS_URL = os.environ["REDIS_URL"]
|
||||
|
||||
DEPLOYMENT_MODE = os.getenv("DEPLOYMENT_MODE", "oss")
|
||||
AUTH_PROVIDER = os.getenv("AUTH_PROVIDER", "local")
|
||||
DOGRAH_MPS_SECRET_KEY = os.getenv("DOGRAH_MPS_SECRET_KEY", None)
|
||||
MPS_API_URL = os.getenv("MPS_API_URL", "https://services.dograh.com")
|
||||
|
||||
|
|
@ -118,3 +119,7 @@ TURN_HOST = os.getenv("TURN_HOST", "localhost")
|
|||
TURN_PORT = int(os.getenv("TURN_PORT", "3478"))
|
||||
TURN_TLS_PORT = int(os.getenv("TURN_TLS_PORT", "5349"))
|
||||
TURN_CREDENTIAL_TTL = int(os.getenv("TURN_CREDENTIAL_TTL", "86400"))
|
||||
|
||||
# OSS Email/Password Auth
|
||||
OSS_JWT_SECRET = os.getenv("OSS_JWT_SECRET", "change-me-in-production")
|
||||
OSS_JWT_EXPIRY_HOURS = int(os.getenv("OSS_JWT_EXPIRY_HOURS", "720")) # 30 days
|
||||
|
|
|
|||
|
|
@ -69,6 +69,8 @@ class UserModel(Base):
|
|||
back_populates="users",
|
||||
)
|
||||
is_superuser = Column(Boolean, default=False)
|
||||
email = Column(String, unique=True, index=True, nullable=True)
|
||||
password_hash = Column(String, nullable=True)
|
||||
|
||||
|
||||
class UserConfigurationModel(Base):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from loguru import logger
|
||||
|
|
@ -148,3 +149,35 @@ class UserClient(BaseDBClient):
|
|||
raise ValueError(f"User with ID {user_id} not found")
|
||||
|
||||
await session.commit()
|
||||
|
||||
async def update_user_email(self, user_id: int, email: str) -> None:
|
||||
"""Update the user's email address."""
|
||||
async with self.async_session() as session:
|
||||
from sqlalchemy import update
|
||||
|
||||
stmt = update(UserModel).where(UserModel.id == user_id).values(email=email)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserModel | None:
|
||||
"""Fetch a user by their email address."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserModel).where(UserModel.email == email)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def create_user_with_email(
|
||||
self, email: str, password_hash: str, name: str | None = None
|
||||
) -> UserModel:
|
||||
"""Create a new user with email and password hash."""
|
||||
async with self.async_session() as session:
|
||||
user = UserModel(
|
||||
provider_id=f"oss_{int(datetime.now(timezone.utc).timestamp())}_{uuid.uuid4()}",
|
||||
email=email,
|
||||
password_hash=password_hash,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
|
|
|||
|
|
@ -15,3 +15,5 @@ sqlalchemy[asyncio]==2.0.43
|
|||
msgpack==1.1.2
|
||||
docling[rapidocr]==2.68.0
|
||||
pgvector==0.4.2
|
||||
bcrypt==5.0.0
|
||||
email-validator==2.3.0
|
||||
|
|
|
|||
97
api/routes/auth.py
Normal file
97
api/routes/auth.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.schemas.auth import AuthResponse, LoginRequest, SignupRequest, UserResponse
|
||||
from api.services.auth.depends import create_user_configuration_with_mps_key, get_user
|
||||
from api.utils.auth import create_jwt_token, hash_password, verify_password
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/signup", response_model=AuthResponse)
|
||||
async def signup(request: SignupRequest):
|
||||
# Check if email is already taken
|
||||
existing_user = await db_client.get_user_by_email(request.email)
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=409, detail="Email already registered")
|
||||
|
||||
# Hash password and create user
|
||||
hashed = hash_password(request.password)
|
||||
user = await db_client.create_user_with_email(
|
||||
email=request.email,
|
||||
password_hash=hashed,
|
||||
name=request.name,
|
||||
)
|
||||
|
||||
# Create organization for the user
|
||||
org_provider_id = f"org_{user.provider_id}"
|
||||
organization, _ = await db_client.get_or_create_organization_by_provider_id(
|
||||
org_provider_id=org_provider_id, user_id=user.id
|
||||
)
|
||||
|
||||
# Link user to organization
|
||||
await db_client.add_user_to_organization(user.id, organization.id)
|
||||
await db_client.update_user_selected_organization(user.id, organization.id)
|
||||
|
||||
# Create default service configuration
|
||||
try:
|
||||
mps_config = await create_user_configuration_with_mps_key(
|
||||
user.id, organization.id, user.provider_id
|
||||
)
|
||||
if mps_config:
|
||||
await db_client.update_user_configuration(user.id, mps_config)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to create default configuration for OSS user", exc_info=True
|
||||
)
|
||||
|
||||
# Create JWT token
|
||||
token = create_jwt_token(user.id, request.email)
|
||||
|
||||
return AuthResponse(
|
||||
token=token,
|
||||
user=UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
name=request.name,
|
||||
organization_id=organization.id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=AuthResponse)
|
||||
async def login(request: LoginRequest):
|
||||
# Look up user by email
|
||||
user = await db_client.get_user_by_email(request.email)
|
||||
if not user or not user.password_hash:
|
||||
raise HTTPException(status_code=401, detail="Invalid email or password")
|
||||
|
||||
# Verify password
|
||||
if not verify_password(request.password, user.password_hash):
|
||||
raise HTTPException(status_code=401, detail="Invalid email or password")
|
||||
|
||||
# Create JWT token
|
||||
token = create_jwt_token(user.id, user.email)
|
||||
|
||||
return AuthResponse(
|
||||
token=token,
|
||||
user=UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
organization_id=user.selected_organization_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_current_user(user: UserModel = Depends(get_user)):
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
|
@ -2,6 +2,7 @@ from fastapi import APIRouter
|
|||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.routes.auth import router as auth_router
|
||||
from api.routes.campaign import router as campaign_router
|
||||
from api.routes.credentials import router as credentials_router
|
||||
from api.routes.integration import router as integration_router
|
||||
|
|
@ -50,17 +51,20 @@ router.include_router(public_agent_router)
|
|||
router.include_router(public_download_router)
|
||||
router.include_router(workflow_embed_router)
|
||||
router.include_router(knowledge_base_router)
|
||||
router.include_router(auth_router)
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
version: str
|
||||
backend_api_endpoint: str
|
||||
deployment_mode: str
|
||||
auth_provider: str
|
||||
|
||||
|
||||
@router.get("/health", response_model=HealthResponse)
|
||||
async def health() -> HealthResponse:
|
||||
from api.constants import APP_VERSION
|
||||
from api.constants import APP_VERSION, AUTH_PROVIDER, DEPLOYMENT_MODE
|
||||
from api.utils.common import get_backend_endpoints
|
||||
|
||||
logger.debug("Health endpoint called")
|
||||
|
|
@ -69,4 +73,6 @@ async def health() -> HealthResponse:
|
|||
status="ok",
|
||||
version=APP_VERSION,
|
||||
backend_api_endpoint=backend_endpoint,
|
||||
deployment_mode=DEPLOYMENT_MODE,
|
||||
auth_provider=AUTH_PROVIDER,
|
||||
)
|
||||
|
|
|
|||
31
api/schemas/auth.py
Normal file
31
api/schemas/auth.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from pydantic import BaseModel, EmailStr, field_validator
|
||||
|
||||
|
||||
class SignupRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
name: str | None = None
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_min_length(cls, v: str) -> str:
|
||||
if len(v) < 8:
|
||||
raise ValueError("Password must be at least 8 characters")
|
||||
return v
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: int
|
||||
email: str | None
|
||||
name: str | None = None
|
||||
organization_id: int | None = None
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
token: str
|
||||
user: UserResponse
|
||||
|
|
@ -5,12 +5,13 @@ from fastapi import Header, HTTPException, Query, WebSocket
|
|||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
|
||||
from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.auth.stack_auth import stackauth
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.utils.auth import decode_jwt_token
|
||||
|
||||
|
||||
async def get_user(
|
||||
|
|
@ -24,9 +25,9 @@ async def get_user(
|
|||
return await _handle_api_key_auth(x_api_key)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Check if we're in OSS deployment mode
|
||||
# Check if we're using local (email/password) auth
|
||||
# ------------------------------------------------------------------
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
if AUTH_PROVIDER == "local":
|
||||
return await _handle_oss_auth(authorization)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -54,6 +55,14 @@ async def get_user(
|
|||
|
||||
try:
|
||||
user_model = await db_client.get_or_create_user_by_provider_id(stack_user["id"])
|
||||
|
||||
# Sync email from Stack Auth if available and not already set
|
||||
stack_email = stack_user.get("primary_email_verified") and stack_user.get(
|
||||
"primary_email"
|
||||
)
|
||||
if stack_email and user_model.email != stack_email:
|
||||
await db_client.update_user_email(user_model.id, stack_email)
|
||||
user_model.email = stack_email
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error while creating user from database {e}"
|
||||
|
|
@ -125,7 +134,7 @@ async def get_user_optional(
|
|||
async def _handle_oss_auth(authorization: str | None) -> UserModel:
|
||||
"""
|
||||
Handle authentication for OSS deployment mode.
|
||||
Uses the authorization token as provider_id and creates user/org if needed.
|
||||
Validates JWT tokens issued by the email/password auth flow.
|
||||
"""
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Authorization header required")
|
||||
|
|
@ -141,49 +150,15 @@ async def _handle_oss_auth(authorization: str | None) -> UserModel:
|
|||
raise HTTPException(status_code=401, detail="Invalid authorization token")
|
||||
|
||||
try:
|
||||
# Use token as provider_id for OSS mode
|
||||
user_model = await db_client.get_or_create_user_by_provider_id(
|
||||
provider_id=token
|
||||
)
|
||||
|
||||
# Create or get organization for OSS user
|
||||
# Each OSS user gets their own organization using their token as org ID
|
||||
(
|
||||
organization,
|
||||
org_was_created,
|
||||
) = await db_client.get_or_create_organization_by_provider_id(
|
||||
org_provider_id=f"org_{token}", user_id=user_model.id
|
||||
)
|
||||
|
||||
# Ensure user is mapped to their organization
|
||||
if user_model.selected_organization_id != organization.id:
|
||||
# add_user_to_organization now handles race conditions with ON CONFLICT DO NOTHING
|
||||
await db_client.add_user_to_organization(user_model.id, organization.id)
|
||||
await db_client.update_user_selected_organization(
|
||||
user_model.id, organization.id
|
||||
)
|
||||
user_model.selected_organization_id = organization.id
|
||||
|
||||
# Only create default configuration if organization was just created
|
||||
# This prevents race conditions where multiple concurrent requests
|
||||
# might try to create configurations
|
||||
if org_was_created:
|
||||
existing_cfg = await db_client.get_user_configurations(user_model.id)
|
||||
if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
|
||||
mps_config = await create_user_configuration_with_mps_key(
|
||||
user_model.id, organization.id, token
|
||||
)
|
||||
if mps_config:
|
||||
await db_client.update_user_configuration(
|
||||
user_model.id, mps_config
|
||||
)
|
||||
|
||||
return user_model
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error while handling OSS authentication: {e}"
|
||||
)
|
||||
payload = decode_jwt_token(token)
|
||||
user = await db_client.get_user_by_id(int(payload["sub"]))
|
||||
if user:
|
||||
return user
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
|
||||
|
||||
async def _handle_api_key_auth(api_key: str) -> UserModel:
|
||||
|
|
@ -233,8 +208,8 @@ async def create_user_configuration_with_mps_key(
|
|||
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Use MPS API URL from constants
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
# For OSS mode, create a temporary service key without authentication
|
||||
if AUTH_PROVIDER == "local":
|
||||
# For local auth mode, create a temporary service key without authentication
|
||||
response = await client.post(
|
||||
f"{MPS_API_URL}/api/v1/service-keys/",
|
||||
json={
|
||||
|
|
|
|||
28
api/utils/auth.py
Normal file
28
api/utils/auth.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import bcrypt
|
||||
import jwt
|
||||
|
||||
from api.constants import OSS_JWT_EXPIRY_HOURS, OSS_JWT_SECRET
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
return bcrypt.checkpw(password.encode("utf-8"), password_hash.encode("utf-8"))
|
||||
|
||||
|
||||
def create_jwt_token(user_id: int, email: str) -> str:
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"email": email,
|
||||
"exp": datetime.now(UTC) + timedelta(hours=OSS_JWT_EXPIRY_HOURS),
|
||||
"iat": datetime.now(UTC),
|
||||
}
|
||||
return jwt.encode(payload, OSS_JWT_SECRET, algorithm="HS256")
|
||||
|
||||
|
||||
def decode_jwt_token(token: str) -> dict:
|
||||
return jwt.decode(token, OSS_JWT_SECRET, algorithms=["HS256"])
|
||||
|
|
@ -119,10 +119,6 @@ async def get_backend_endpoints() -> tuple[str, str]:
|
|||
_validate_url(BACKEND_API_ENDPOINT)
|
||||
|
||||
if BACKEND_API_ENDPOINT:
|
||||
logger.debug(
|
||||
f"Processing BACKEND_API_ENDPOINT from environment: {BACKEND_API_ENDPOINT}"
|
||||
)
|
||||
|
||||
# Handle localhost/127.0.0.1 special case - use tunnel URL if available
|
||||
if "localhost" in BACKEND_API_ENDPOINT or "127.0.0.1" in BACKEND_API_ENDPOINT:
|
||||
logger.debug(
|
||||
|
|
|
|||
|
|
@ -27,7 +27,6 @@ class TunnelURLProvider:
|
|||
# Try to get URL from cloudflared metrics
|
||||
urls = await cls._get_cloudflared_urls()
|
||||
if urls:
|
||||
logger.info(f"Retrieved tunnel URLs from cloudflared: {urls}")
|
||||
return urls
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get tunnel URL from cloudflared: {e}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue