refactor: unified google vein, prefixed module loading, cfg separation
- Unified google vein with OAuth + Sheets API - Prefixed vein module loading (vein_google) to avoid pip package shadowing - Preload pip packages before vein loading - Added common/auth framework - Rebranded sbwrapper from Pawprint to Soleprint - Removed cfg/ from history (now separate repo) - Keep cfg/standalone/ as sample configuration - gitignore cfg/amar/ and cfg/dlt/ (private configs)
This commit is contained in:
3
soleprint/common/__init__.py
Normal file
3
soleprint/common/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Common module - shared abstractions reusable across soleprint systems.
|
||||
"""
|
||||
10
soleprint/common/auth/__init__.py
Normal file
10
soleprint/common/auth/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Generic authentication framework for soleprint.
|
||||
|
||||
Provider-agnostic - delegates to configured provider vein (e.g., google_login).
|
||||
"""
|
||||
|
||||
from .config import AuthConfig, load_auth_config
|
||||
from .session import get_current_user, require_auth
|
||||
|
||||
__all__ = ["AuthConfig", "load_auth_config", "get_current_user", "require_auth"]
|
||||
43
soleprint/common/auth/config.py
Normal file
43
soleprint/common/auth/config.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Authentication configuration.
|
||||
|
||||
Generic config that works with any provider vein.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
|
||||
"""Authentication configuration for a room."""
|
||||
|
||||
enabled: bool = False
|
||||
provider: str = "google" # Vein name to use for auth
|
||||
allowed_domains: list[str] = [] # Empty = allow any domain
|
||||
session_secret: str = "" # Required if enabled, can be "ENV:VAR_NAME"
|
||||
session_timeout_hours: int = 24
|
||||
login_redirect: str = "/"
|
||||
public_routes: list[str] = [
|
||||
"/health",
|
||||
"/auth/login",
|
||||
"/auth/callback",
|
||||
"/auth/logout",
|
||||
]
|
||||
|
||||
|
||||
def load_auth_config(config: dict) -> Optional[AuthConfig]:
|
||||
"""
|
||||
Load auth config from room config.json.
|
||||
|
||||
Returns None if auth is not enabled.
|
||||
"""
|
||||
auth_data = config.get("auth")
|
||||
if not auth_data:
|
||||
return None
|
||||
|
||||
auth_config = AuthConfig(**auth_data)
|
||||
if not auth_config.enabled:
|
||||
return None
|
||||
|
||||
return auth_config
|
||||
91
soleprint/common/auth/middleware.py
Normal file
91
soleprint/common/auth/middleware.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Authentication middleware for route protection.
|
||||
|
||||
Generic middleware, provider-agnostic.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse, RedirectResponse
|
||||
|
||||
from .config import AuthConfig
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that protects routes by requiring authentication.
|
||||
|
||||
- Public routes (configurable) are allowed without auth
|
||||
- Unauthenticated browser requests redirect to /auth/login
|
||||
- Unauthenticated API requests get 401 JSON response
|
||||
"""
|
||||
|
||||
def __init__(self, app, auth_config: AuthConfig):
|
||||
super().__init__(app)
|
||||
self.config = auth_config
|
||||
self.public_routes = set(auth_config.public_routes)
|
||||
# Also allow static files and common paths
|
||||
self.public_prefixes = ["/static", "/favicon", "/artery"]
|
||||
|
||||
async def dispatch(self, request, call_next):
|
||||
path = request.url.path
|
||||
|
||||
# Check if route is public
|
||||
if self._is_public(path):
|
||||
return await call_next(request)
|
||||
|
||||
# Check session
|
||||
session = request.session
|
||||
user_email = session.get("user_email")
|
||||
expires_at = session.get("expires_at")
|
||||
|
||||
if not user_email:
|
||||
return self._unauthorized(request, "Not authenticated")
|
||||
|
||||
# Check expiry
|
||||
if expires_at:
|
||||
if datetime.fromisoformat(expires_at) < datetime.now():
|
||||
session.clear()
|
||||
return self._unauthorized(request, "Session expired")
|
||||
|
||||
# Check domain restriction
|
||||
user_domain = session.get("domain")
|
||||
if self.config.allowed_domains:
|
||||
if not user_domain or user_domain not in self.config.allowed_domains:
|
||||
session.clear()
|
||||
return self._unauthorized(
|
||||
request,
|
||||
f"Access restricted to: {', '.join(self.config.allowed_domains)}",
|
||||
)
|
||||
|
||||
# Attach user to request state for downstream use
|
||||
request.state.user = {
|
||||
"email": user_email,
|
||||
"name": session.get("user_name"),
|
||||
"domain": user_domain,
|
||||
}
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
def _is_public(self, path: str) -> bool:
|
||||
"""Check if path is public (no auth required)."""
|
||||
if path in self.public_routes:
|
||||
return True
|
||||
for prefix in self.public_prefixes:
|
||||
if path.startswith(prefix):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _unauthorized(self, request, message: str):
|
||||
"""Return appropriate unauthorized response."""
|
||||
# API requests get JSON 401
|
||||
accept = request.headers.get("accept", "")
|
||||
if "application/json" in accept:
|
||||
return JSONResponse({"error": message}, status_code=401)
|
||||
|
||||
# Browser requests redirect to login with return URL
|
||||
next_url = str(request.url.path)
|
||||
if request.url.query:
|
||||
next_url += f"?{request.url.query}"
|
||||
return RedirectResponse(url=f"/auth/login?next={next_url}")
|
||||
170
soleprint/common/auth/routes.py
Normal file
170
soleprint/common/auth/routes.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Authentication routes.
|
||||
|
||||
Generic routes that delegate to configured provider vein.
|
||||
|
||||
/auth/login - Start login flow (redirects to provider)
|
||||
/auth/callback - Handle provider callback, create session
|
||||
/auth/logout - Clear session
|
||||
/auth/me - Get current user info
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
from .config import AuthConfig
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
# Will be initialized by setup_auth() in run.py
|
||||
auth_config: Optional[AuthConfig] = None
|
||||
|
||||
|
||||
def init_auth(config: AuthConfig):
|
||||
"""
|
||||
Initialize auth module with configuration.
|
||||
|
||||
Called by run.py when setting up authentication.
|
||||
"""
|
||||
global auth_config
|
||||
auth_config = config
|
||||
|
||||
|
||||
def _get_provider_base_url() -> str:
|
||||
"""Get base URL for the configured provider vein."""
|
||||
if not auth_config:
|
||||
raise HTTPException(500, "Auth not configured")
|
||||
# Provider is a vein name like "google_login"
|
||||
return f"/artery/{auth_config.provider}"
|
||||
|
||||
|
||||
@router.get("/login")
|
||||
async def login(request: Request, next: str = "/"):
|
||||
"""
|
||||
Start login flow.
|
||||
|
||||
Redirects to the configured provider vein's OAuth start endpoint.
|
||||
"""
|
||||
if not auth_config:
|
||||
raise HTTPException(500, "Auth not configured")
|
||||
|
||||
# Generate CSRF state token
|
||||
state = secrets.token_urlsafe(32)
|
||||
request.session["oauth_state"] = state
|
||||
request.session["oauth_next"] = next
|
||||
|
||||
# Get domain hint from config (first allowed domain)
|
||||
hd = auth_config.allowed_domains[0] if auth_config.allowed_domains else None
|
||||
|
||||
# Build provider OAuth URL
|
||||
provider_url = _get_provider_base_url()
|
||||
params = f"?state={state}"
|
||||
if hd:
|
||||
params += f"&hd={hd}"
|
||||
|
||||
# Redirect includes callback to our /auth/callback
|
||||
return RedirectResponse(url=f"{provider_url}/oauth/start{params}")
|
||||
|
||||
|
||||
@router.get("/callback")
|
||||
async def callback(
|
||||
request: Request,
|
||||
code: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Handle OAuth callback.
|
||||
|
||||
Receives code from provider, exchanges for user info, creates session.
|
||||
"""
|
||||
if not auth_config:
|
||||
raise HTTPException(500, "Auth not configured")
|
||||
|
||||
if error:
|
||||
raise HTTPException(400, f"OAuth error: {error}")
|
||||
|
||||
# Verify state
|
||||
expected_state = request.session.get("oauth_state")
|
||||
if not state or state != expected_state:
|
||||
raise HTTPException(400, "Invalid state parameter")
|
||||
|
||||
# Call provider vein to exchange code for user info
|
||||
provider_url = _get_provider_base_url()
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Get base URL from request
|
||||
base_url = str(request.base_url).rstrip("/")
|
||||
response = await client.get(
|
||||
f"{base_url}{provider_url}/oauth/userinfo",
|
||||
params={"code": code},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(400, f"Provider error: {response.text}")
|
||||
user_info = response.json()
|
||||
except httpx.RequestError as e:
|
||||
raise HTTPException(500, f"Failed to contact provider: {e}")
|
||||
|
||||
# Verify domain if restricted
|
||||
user_domain = user_info.get("hd")
|
||||
if auth_config.allowed_domains:
|
||||
if not user_domain or user_domain not in auth_config.allowed_domains:
|
||||
raise HTTPException(
|
||||
403,
|
||||
f"Access restricted to: {', '.join(auth_config.allowed_domains)}. "
|
||||
f"Your account is from: {user_domain or 'personal Gmail'}",
|
||||
)
|
||||
|
||||
# Create session
|
||||
expires_at = datetime.now() + timedelta(hours=auth_config.session_timeout_hours)
|
||||
request.session.update(
|
||||
{
|
||||
"user_email": user_info["email"],
|
||||
"user_name": user_info.get("name"),
|
||||
"user_picture": user_info.get("picture"),
|
||||
"domain": user_domain,
|
||||
"authenticated_at": datetime.now().isoformat(),
|
||||
"expires_at": expires_at.isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
# Clean up oauth state
|
||||
request.session.pop("oauth_state", None)
|
||||
next_url = request.session.pop("oauth_next", "/")
|
||||
|
||||
return RedirectResponse(url=next_url)
|
||||
|
||||
|
||||
@router.get("/logout")
|
||||
async def logout(request: Request):
|
||||
"""Clear session and redirect to login."""
|
||||
request.session.clear()
|
||||
return RedirectResponse(url="/auth/login")
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def me(request: Request):
|
||||
"""
|
||||
Return current user info.
|
||||
|
||||
API endpoint for checking auth status.
|
||||
"""
|
||||
user = getattr(request.state, "user", None)
|
||||
if not user:
|
||||
# Try to get from session directly (in case middleware didn't run)
|
||||
user_email = request.session.get("user_email")
|
||||
if not user_email:
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
user = {
|
||||
"email": user_email,
|
||||
"name": request.session.get("user_name"),
|
||||
"picture": request.session.get("user_picture"),
|
||||
"domain": request.session.get("domain"),
|
||||
}
|
||||
return user
|
||||
51
soleprint/common/auth/session.py
Normal file
51
soleprint/common/auth/session.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Session helpers for authentication.
|
||||
|
||||
Generic session management, provider-agnostic.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> Optional[dict]:
|
||||
"""
|
||||
Get current authenticated user from session.
|
||||
|
||||
Returns:
|
||||
User dict with email, name, domain, etc. or None if not authenticated.
|
||||
"""
|
||||
session = getattr(request, "session", None)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
user_email = session.get("user_email")
|
||||
if not user_email:
|
||||
return None
|
||||
|
||||
# Check expiry
|
||||
expires_at = session.get("expires_at")
|
||||
if expires_at:
|
||||
if datetime.fromisoformat(expires_at) < datetime.now():
|
||||
return None
|
||||
|
||||
return {
|
||||
"email": user_email,
|
||||
"name": session.get("user_name"),
|
||||
"picture": session.get("user_picture"),
|
||||
"domain": session.get("domain"),
|
||||
}
|
||||
|
||||
|
||||
def require_auth(request: Request) -> dict:
|
||||
"""
|
||||
Get current user or raise 401.
|
||||
|
||||
For use as FastAPI dependency.
|
||||
"""
|
||||
user = get_current_user(request)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
return user
|
||||
Reference in New Issue
Block a user