174 lines
5.2 KiB
Python
174 lines
5.2 KiB
Python
"""
Authentication routes.

Generic routes that delegate to the configured provider vein.

/auth/login    - Start the login flow (redirects to the provider)
/auth/callback - Handle the provider callback and create a session
/auth/logout   - Clear the session
/auth/me       - Get the current user's info
"""
|
|
|
|
import os
import secrets
from datetime import datetime, timedelta
from typing import Optional
from urllib.parse import urlencode

import httpx
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import RedirectResponse

from .config import AuthConfig
|
|
|
|
# Every route in this module is served under /auth and grouped under the
# "auth" tag.
router = APIRouter(tags=["auth"], prefix="/auth")

# Active auth configuration; stays None until init_auth() installs one.
# run.py is expected to do that during auth setup (setup_auth()) — confirm.
# While it is None, the login and callback routes answer 500
# "Auth not configured".
auth_config: Optional[AuthConfig] = None
|
|
|
|
|
|
def init_auth(config: AuthConfig):
    """
    Install ``config`` as this module's auth configuration.

    run.py calls this while wiring up authentication. Until it does, the
    module-level ``auth_config`` is ``None``.

    Args:
        config: The configuration the routes in this module should use.
    """
    global auth_config
    auth_config = config
|
|
|
|
|
|
def _get_provider_base_url() -> str:
    """
    Return the URL path prefix of the configured provider vein.

    The provider setting is a vein name (e.g. ``"google_login"``), which is
    served under ``/artery/<provider>``.

    Raises:
        HTTPException: 500 if no auth configuration has been installed.
    """
    cfg = auth_config
    if cfg:
        # e.g. "google_login" -> "/artery/google_login"
        return f"/artery/{cfg.provider}"
    raise HTTPException(500, "Auth not configured")
|
|
|
|
|
|
@router.get("/login")
async def login(request: Request, next: str = "/"):
    """
    Start the login flow.

    Stores a fresh CSRF ``state`` token and the post-login destination in the
    session, then redirects the browser to the configured provider vein's
    ``/oauth/start`` endpoint.

    Args:
        request: Incoming request. ``oauth_state`` and ``oauth_next`` are
            written to its session.
        next: Path to return to after login. It is user-controlled, so only
            same-site absolute paths are kept and anything else becomes
            ``"/"``.

    Raises:
        HTTPException: 500 if auth has not been configured.
    """
    if not auth_config:
        raise HTTPException(500, "Auth not configured")

    # Anyone can craft a login link, so `next` must not be able to point at
    # another origin ("https://evil", "//evil", "/\evil"). Control characters
    # are rejected too, because browsers strip some of them before parsing.
    next_url = next
    if (
        not next_url.startswith("/")
        or next_url.startswith("//")
        or "\\" in next_url
        or any(ord(ch) < 0x20 or ch == "\x7f" for ch in next_url)
    ):
        next_url = "/"

    # CSRF state token; the callback compares it with the echoed value.
    state = secrets.token_urlsafe(32)
    request.session["oauth_state"] = state
    request.session["oauth_next"] = next_url

    # Domain hint: the first allowed domain, if any is configured.
    hd = auth_config.allowed_domains[0] if auth_config.allowed_domains else None

    # Build the query with proper percent-encoding instead of raw f-strings.
    query = {"state": state}
    if hd:
        query["hd"] = hd

    # Only state (and the optional hd hint) are sent. The return address
    # (/auth/callback) is not passed here, so the provider vein must already
    # know it — NOTE(review): confirm.
    provider_url = _get_provider_base_url()
    return RedirectResponse(url=f"{provider_url}/oauth/start?{urlencode(query)}")
|
|
|
|
|
|
def _safe_redirect_target(url: Optional[str], default: str = "/") -> str:
    """
    Return ``url`` if it is a same-site absolute path, otherwise ``default``.

    Accepts paths such as ``/dashboard?tab=1``. Rejects anything a browser
    could resolve to another origin:

    * full URLs (``https://evil.example``)
    * protocol-relative URLs (``//evil.example``)
    * backslash variants
    * strings containing ASCII control characters (browsers strip some of
      them before parsing)
    """
    if not url or not url.startswith("/") or url.startswith("//"):
        return default
    if "\\" in url or any(ord(ch) < 0x20 or ch == "\x7f" for ch in url):
        return default
    return url


@router.get("/callback")
async def callback(
    request: Request,
    code: Optional[str] = None,
    state: Optional[str] = None,
    error: Optional[str] = None,
):
    """
    Handle the OAuth callback.

    Steps:

    1. Verify the CSRF ``state`` against the token stored by ``/auth/login``.
    2. Ask the provider vein to exchange ``code`` for user info.
    3. Enforce the configured email/domain allow-lists.
    4. Write the user into the session and redirect to the saved ``next``
       target, falling back to ``/`` if that target is not a same-site path.

    Args:
        request: Incoming request; its session holds ``oauth_state`` and
            ``oauth_next``.
        code: Authorization code from the provider.
        state: CSRF token echoed back by the provider.
        error: Error reported by the provider, if any.

    Raises:
        HTTPException:
            500 if auth is not configured, or the provider vein is unreachable
            or returns an unusable body.
            400 on a provider error, invalid state, missing code or missing
            email.
            403 if the account is not on the allow-lists.
    """
    if not auth_config:
        raise HTTPException(500, "Auth not configured")

    if error:
        raise HTTPException(400, f"OAuth error: {error}")

    # The state token is single-use. Remove it from the session before
    # checking it, so a failed or replayed callback cannot reuse it.
    # Compare bytes in constant time: compare_digest rejects non-ASCII str.
    expected_state = request.session.pop("oauth_state", None)
    if (
        not state
        or not isinstance(expected_state, str)
        or not secrets.compare_digest(state.encode(), expected_state.encode())
    ):
        raise HTTPException(400, "Invalid state parameter")

    if not code:
        raise HTTPException(400, "Missing authorization code")

    # Call the provider vein to exchange the code for user info.
    provider_url = _get_provider_base_url()
    # NOTE(review): request.base_url is derived from the client's Host header.
    # Unless allowed hosts are pinned (e.g. TrustedHostMiddleware or a proxy
    # that rewrites Host), a client could point this server-side call at a
    # host it controls and receive a session for forged user info. Prefer a
    # configured internal URL or an in-process call.
    base_url = str(request.base_url).rstrip("/")
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{base_url}{provider_url}/oauth/userinfo",
                params={"code": code},
            )
    except httpx.RequestError as e:
        raise HTTPException(500, f"Failed to contact provider: {e}") from e

    if response.status_code != 200:
        raise HTTPException(400, f"Provider error: {response.text}")

    try:
        user_info = response.json()
    except ValueError as e:  # covers JSONDecodeError / UnicodeDecodeError
        raise HTTPException(500, "Provider returned invalid JSON") from e
    if not isinstance(user_info, dict):
        raise HTTPException(500, "Provider returned unexpected user info")

    user_email = user_info.get("email")
    user_domain = user_info.get("hd")
    if not user_email:
        raise HTTPException(400, "Provider did not return an email address")

    # Verify the domain/email restriction. With no allow-lists configured,
    # any account is accepted.
    email_allowed = user_email in auth_config.allowed_emails
    domain_allowed = bool(user_domain) and user_domain in auth_config.allowed_domains
    no_restrictions = not auth_config.allowed_domains and not auth_config.allowed_emails

    if not (email_allowed or domain_allowed or no_restrictions):
        raise HTTPException(
            403,
            f"Access restricted. Your account ({user_email}) is not authorized.",
        )

    # Create the session. Timestamps are naive local time (datetime.now()),
    # so anything comparing against expires_at must also use naive local
    # time. One `now` keeps authenticated_at and expires_at consistent.
    now = datetime.now()
    expires_at = now + timedelta(hours=auth_config.session_timeout_hours)
    request.session.update(
        {
            "user_email": user_email,
            "user_name": user_info.get("name"),
            "user_picture": user_info.get("picture"),
            "domain": user_domain,
            "authenticated_at": now.isoformat(),
            "expires_at": expires_at.isoformat(),
        }
    )

    # Never redirect off-site, whatever ended up in oauth_next.
    next_url = _safe_redirect_target(request.session.pop("oauth_next", None))
    return RedirectResponse(url=next_url)
|
|
|
|
|
|
@router.get("/logout")
async def logout(request: Request):
    """
    Log the user out.

    Removes every key from the session, then sends the browser back to the
    login route.
    """
    session = request.session
    session.clear()
    login_url = "/auth/login"
    return RedirectResponse(url=login_url)
|
|
|
|
|
|
@router.get("/me")
async def me(request: Request):
    """
    Return current user info.

    API endpoint for checking auth status. Uses ``request.state.user`` when
    it is set (presumably by the auth middleware). Otherwise the user is
    rebuilt from the session.

    Returns:
        ``request.state.user`` as-is, or a dict with ``email``, ``name``,
        ``picture`` and ``domain`` read from the session.

    Raises:
        HTTPException: 401 if there is no user, or if the session's
            ``expires_at`` is missing, malformed or in the past.
    """
    user = getattr(request.state, "user", None)
    if user:
        return user

    # Fall back to the session directly (in case the middleware didn't run).
    # This path bypasses whatever checks the middleware would apply, so
    # enforce the session expiry here and fail closed on bad values.
    session = request.session
    user_email = session.get("user_email")
    if not user_email:
        raise HTTPException(401, "Not authenticated")

    try:
        # Compared against naive local "now". A missing value (TypeError), an
        # unparseable one (ValueError), or a tz-aware one (TypeError on
        # comparison) all count as expired.
        expired = datetime.fromisoformat(session.get("expires_at")) <= datetime.now()
    except (TypeError, ValueError):
        expired = True
    if expired:
        raise HTTPException(401, "Not authenticated")

    return {
        "email": user_email,
        "name": session.get("user_name"),
        "picture": session.get("user_picture"),
        "domain": session.get("domain"),
    }
|