# Files
# soleprint/soleprint/common/auth/routes.py
#
# 174 lines
# 5.2 KiB
# Python

"""
Authentication routes.
Generic routes that delegate to configured provider vein.
/auth/login - Start login flow (redirects to provider)
/auth/callback - Handle provider callback, create session
/auth/logout - Clear session
/auth/me - Get current user info
"""
import os
import secrets
from datetime import datetime, timedelta
from typing import Optional
from urllib.parse import urlencode

import httpx
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import RedirectResponse

from .config import AuthConfig
# Router holding every endpoint in this module; all paths are prefixed with /auth.
router = APIRouter(prefix="/auth", tags=["auth"])
# Module-wide auth settings. Stays None until init_auth() stores a config; while
# unset, /auth/login and /auth/callback answer 500 "Auth not configured".
# NOTE(review): originally documented as "initialized by setup_auth() in run.py" —
# presumably setup_auth() calls init_auth(); verify.
auth_config: Optional[AuthConfig] = None
def init_auth(config: AuthConfig):
    """
    Install ``config`` as the module-wide auth configuration.

    run.py calls this while wiring up authentication. Until it has run, the
    login and callback routes reject requests with "Auth not configured".

    Args:
        config: Settings (provider vein, allowed domains/emails, session
            timeout) read by the routes in this module.
    """
    global auth_config
    auth_config = config
def _get_provider_base_url() -> str:
"""Get base URL for the configured provider vein."""
if not auth_config:
raise HTTPException(500, "Auth not configured")
# Provider is a vein name like "google_login"
return f"/artery/{auth_config.provider}"
@router.get("/login")
async def login(request: Request, next: str = "/"):
    """
    Start login flow.

    Stores a fresh CSRF ``state`` token and the post-login destination in the
    session, then redirects to the configured provider vein's OAuth start
    endpoint.

    Args:
        request: Incoming request; its session receives ``oauth_state`` and
            ``oauth_next``.
        next: Path to return to after a successful login. Only same-site
            absolute paths (e.g. ``/dashboard``) are kept; anything else is
            replaced by ``/`` so the login flow cannot act as an open redirect.

    Raises:
        HTTPException: 500 if auth has not been configured via init_auth().
    """
    if not auth_config:
        raise HTTPException(500, "Auth not configured")

    # Reject off-site destinations. "https://x", "//x" and "/\\x" all leave the
    # site, and browsers strip tabs/newlines, so "/\t/x" behaves like "//x".
    if (
        not next.startswith("/")
        or next.startswith("//")
        or "\\" in next
        or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in next)
    ):
        next = "/"

    # Generate CSRF state token; /auth/callback checks it against the session.
    state = secrets.token_urlsafe(32)
    request.session["oauth_state"] = state
    request.session["oauth_next"] = next

    query = {"state": state}
    # Domain hint from config (first allowed domain).
    hd = auth_config.allowed_domains[0] if auth_config.allowed_domains else None
    if hd:
        query["hd"] = hd

    provider_url = _get_provider_base_url()
    return RedirectResponse(url=f"{provider_url}/oauth/start?{urlencode(query)}")
def _safe_next_url(url: Optional[str]) -> str:
    """Return ``url`` if it is a same-site absolute path, otherwise ``"/"``."""
    if not url or not url.startswith("/") or url.startswith("//"):
        return "/"
    # "/\\evil.com" and paths with control characters (browsers drop tabs and
    # newlines) can be interpreted as protocol-relative, off-site URLs.
    if "\\" in url or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in url):
        return "/"
    return url


async def _fetch_user_info(request: Request, code: str) -> dict:
    """Ask the provider vein served by this app to exchange ``code`` for user info."""
    provider_url = _get_provider_base_url()
    # Dispatch in-process rather than over the network. Building the target from
    # request.base_url trusts the client-supplied Host header, which would let a
    # caller point this request at their own server and return a forged identity.
    # The incoming scheme/host are reused so the vein sees the Host it saw before.
    # NOTE(review): assumes the /artery/* vein routes are registered on request.app.
    transport = httpx.ASGITransport(app=request.app, raise_app_exceptions=False)
    origin = f"{request.url.scheme}://{request.url.netloc}"
    try:
        async with httpx.AsyncClient(transport=transport, base_url=origin) as client:
            response = await client.get(
                f"{provider_url}/oauth/userinfo",
                params={"code": code},
            )
    except httpx.RequestError as e:
        raise HTTPException(500, f"Failed to contact provider: {e}") from e
    if response.status_code != 200:
        raise HTTPException(400, f"Provider error: {response.text}")
    try:
        user_info = response.json()
    except ValueError as e:
        raise HTTPException(400, "Provider error: invalid JSON response") from e
    if not isinstance(user_info, dict):
        raise HTTPException(400, "Provider error: unexpected response format")
    return user_info


@router.get("/callback")
async def callback(
    request: Request,
    code: Optional[str] = None,
    state: Optional[str] = None,
    error: Optional[str] = None,
):
    """
    Handle OAuth callback.

    Verifies the CSRF state, asks the provider vein to exchange ``code`` for
    user info, enforces the allowed-domain/allowed-email restrictions, stores
    the user in the session and redirects to the destination saved by
    /auth/login (falling back to ``/`` for anything off-site).

    Raises:
        HTTPException: 500 if auth is not configured or the provider cannot be
            reached; 400 for a provider-reported error, a bad/missing state,
            a missing code or an unusable provider response; 403 when the
            account is not in the allowed domains/emails.
    """
    if not auth_config:
        raise HTTPException(500, "Auth not configured")
    if error:
        raise HTTPException(400, f"OAuth error: {error}")

    # Consume the state token up-front so it is single-use even if a later step
    # fails; compare as bytes in constant time (compare_digest rejects non-ASCII str).
    expected_state = request.session.pop("oauth_state", None)
    if (
        not state
        or not expected_state
        or not secrets.compare_digest(state.encode(), str(expected_state).encode())
    ):
        raise HTTPException(400, "Invalid state parameter")

    if not code:
        raise HTTPException(400, "Missing authorization code")

    user_info = await _fetch_user_info(request, code)

    # Verify domain/email restriction
    user_email = user_info.get("email")
    if not user_email:
        raise HTTPException(400, "Provider did not return an email address")
    user_domain = user_info.get("hd")
    email_allowed = user_email in auth_config.allowed_emails
    domain_allowed = bool(user_domain) and user_domain in auth_config.allowed_domains
    no_restrictions = not auth_config.allowed_domains and not auth_config.allowed_emails
    if not (email_allowed or domain_allowed or no_restrictions):
        raise HTTPException(
            403,
            f"Access restricted. Your account ({user_email}) is not authorized.",
        )

    # Create session (naive local timestamps, as elsewhere in this module)
    expires_at = datetime.now() + timedelta(hours=auth_config.session_timeout_hours)
    request.session.update(
        {
            "user_email": user_email,
            "user_name": user_info.get("name"),
            "user_picture": user_info.get("picture"),
            "domain": user_domain,
            "authenticated_at": datetime.now().isoformat(),
            "expires_at": expires_at.isoformat(),
        }
    )

    next_url = _safe_next_url(request.session.pop("oauth_next", "/"))
    return RedirectResponse(url=next_url)
@router.get("/logout")
async def logout(request: Request):
    """
    Log the user out.

    Removes every key from the session, then sends the browser to
    /auth/login.
    """
    session = request.session
    session.clear()
    login_path = "/auth/login"
    return RedirectResponse(url=login_path)
@router.get("/me")
async def me(request: Request):
    """
    Return current user info.

    API endpoint for checking auth status. Prefers ``request.state.user`` when
    middleware has set it; otherwise rebuilds the user from the session keys
    written by /auth/callback, honouring the stored ``expires_at``.

    Raises:
        HTTPException: 401 when there is no user in the session, the expiry
            cannot be parsed, or the session has expired.
    """
    user = getattr(request.state, "user", None)
    if user:
        return user

    # Try to get from session directly (in case middleware didn't run)
    user_email = request.session.get("user_email")
    if not user_email:
        raise HTTPException(401, "Not authenticated")

    # The fallback must enforce the expiry written by /auth/callback; otherwise
    # an expired session would still be reported as authenticated here.
    expires_at = request.session.get("expires_at")
    if expires_at is not None:
        try:
            expiry = datetime.fromisoformat(expires_at)
        except (TypeError, ValueError):
            raise HTTPException(401, "Not authenticated") from None
        # Compare naive-with-naive / aware-with-aware to avoid a TypeError.
        if datetime.now(expiry.tzinfo) >= expiry:
            raise HTTPException(401, "Session expired")

    return {
        "email": user_email,
        "name": request.session.get("user_name"),
        "picture": request.session.get("user_picture"),
        "domain": request.session.get("domain"),
    }