"""
OAuth2 utilities and base classes for OAuth-based veins.

Any vein using OAuth2 (Google, GitHub, GitLab, etc.) can inherit from
BaseOAuthVein and use TokenStorage.
"""

import asyncio
import json
import logging
import os
import tempfile
from abc import abstractmethod
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Union

from .base import BaseVein, TClient, TCredentials

logger = logging.getLogger(__name__)

# Tokens within this window of their expiry are treated as already expired,
# so a token handed out now does not lapse in the middle of a request.
EXPIRY_MARGIN = timedelta(minutes=5)


class TokenStorage:
    """
    File-based token storage for OAuth2 tokens.

    Can be overridden for Redis/database storage in production.
    Each vein gets its own storage directory holding one JSON file per user
    (``tokens_{user_id}.json``).
    """

    def __init__(
        self, vein_name: str, storage_dir: Optional[Union[str, Path]] = None
    ):
        """
        Initialize token storage.

        Args:
            vein_name: Name of the vein (e.g., 'google', 'github')
            storage_dir: Base storage directory (defaults to
                veins/{vein_name}/storage, i.e. ``{vein_name}/storage`` next
                to this module). Created, with parents, if missing.
        """
        if storage_dir is None:
            # Default: veins/{vein_name}/storage/
            storage_dir = Path(__file__).parent / vein_name / "storage"
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

    def _get_path(self, user_id: str) -> Path:
        """
        Get token file path for user.

        Raises:
            ValueError: If user_id contains a path separator; such an id could
                address a file outside storage_dir.
        """
        uid = str(user_id)  # tolerate non-str ids, as the f-string always did
        separators = {os.sep, os.altsep} - {None}
        if any(sep in uid for sep in separators):
            raise ValueError(f"Invalid user_id for token storage: {uid!r}")
        return self.storage_dir / f"tokens_{uid}.json"

    @staticmethod
    def _write_json_atomic(path: Path, data: dict) -> None:
        """
        Write data as JSON to path atomically, readable only by the owner.

        The JSON goes to a temporary file in the same directory, which then
        replaces path. A crash mid-write therefore never leaves a truncated
        token file. tempfile.mkstemp creates the file readable and writable
        only by the creating user, keeping secrets away from other local
        users.
        """
        fd, tmp_name = tempfile.mkstemp(
            dir=path.parent, prefix=f".{path.name}.", suffix=".tmp"
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_name, path)
        except BaseException:
            # Don't leave stray temp files behind on failure.
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
            raise

    def save_tokens(self, user_id: str, tokens: dict) -> dict:
        """
        Save OAuth2 tokens for a user.

        tokens should contain:
        - access_token
        - refresh_token (optional)
        - expires_in (seconds)
        - scope
        - token_type

        The caller's dict is not modified.

        Returns:
            The record as stored: a copy of tokens plus an ``expires_at`` ISO
            timestamp (naive local time) when ``expires_in`` was present.
        """
        record = dict(tokens)
        if "expires_in" in record:
            # float() also tolerates an expires_in sent as a numeric string.
            lifetime = timedelta(seconds=float(record["expires_in"]))
            record["expires_at"] = (datetime.now() + lifetime).isoformat()
        self._write_json_atomic(self._get_path(user_id), record)
        return record

    def load_tokens(self, user_id: str) -> Optional[dict]:
        """Load OAuth2 tokens for a user. Returns None if not found."""
        path = self._get_path(user_id)
        try:
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    def is_expired(self, tokens: dict) -> bool:
        """
        Check if access token is expired.

        Returns True if expired, expiring within EXPIRY_MARGIN (5 minutes),
        or if tokens carry no ``expires_at`` at all.
        """
        if "expires_at" not in tokens:
            return True
        expires_at = datetime.fromisoformat(tokens["expires_at"])
        return datetime.now() >= expires_at - EXPIRY_MARGIN

    def delete_tokens(self, user_id: str) -> None:
        """Delete tokens for a user. Does nothing if none are stored."""
        try:
            self._get_path(user_id).unlink()
        except FileNotFoundError:
            pass


class BaseOAuthVein(BaseVein[TCredentials, TClient]):
    """
    Base class for OAuth2-based veins.

    Extends BaseVein with OAuth2 flow management:
    - Authorization URL generation
    - Code exchange for tokens
    - Token refresh
    - Token storage
    """

    def __init__(self, storage: Optional[TokenStorage] = None):
        """
        Initialize OAuth vein.

        Args:
            storage: Token storage instance (creates default if None)
        """
        if storage is None:
            # NOTE(review): self.name is presumably defined by BaseVein or the
            # subclass; it is not defined in this file.
            storage = TokenStorage(vein_name=self.name)
        self.storage = storage

    @abstractmethod
    def get_auth_url(self, state: Optional[str] = None) -> str:
        """
        Generate OAuth2 authorization URL.

        Args:
            state: Optional state parameter for CSRF protection

        Returns:
            URL to redirect user for authorization
        """
        pass

    @abstractmethod
    async def exchange_code(self, code: str) -> dict:
        """
        Exchange authorization code for tokens.

        Args:
            code: Authorization code from callback

        Returns:
            Token dict containing access_token, refresh_token, etc.
        """
        pass

    @abstractmethod
    async def refresh_token(self, refresh_token: str) -> dict:
        """
        Refresh an expired access token.

        Args:
            refresh_token: The refresh token

        Returns:
            New token dict with fresh access_token
        """
        pass

    def _needs_refresh(self, tokens: dict) -> bool:
        """Return True when tokens are expired and can be refreshed."""
        return self.storage.is_expired(tokens) and "refresh_token" in tokens

    def _store_refreshed(
        self, user_id: str, old_tokens: dict, new_tokens: dict
    ) -> dict:
        """Persist a refresh response, keeping the old refresh_token if absent."""
        merged = dict(new_tokens)
        # RFC 6749 section 6: the server MAY issue a new refresh token, so a
        # refresh response can omit it. Keep the old one, or the next expiry
        # would force the user to re-authenticate.
        merged.setdefault("refresh_token", old_tokens["refresh_token"])
        stored = self.storage.save_tokens(user_id, merged)
        # Custom storages may still return None from save_tokens.
        return stored if stored is not None else merged

    def get_valid_tokens(self, user_id: str) -> Optional[dict]:
        """
        Get valid tokens for user, refreshing if needed.

        Expired tokens without a refresh_token are returned unchanged. This
        includes tokens stored without expires_in, which is_expired reports
        as expired.

        This synchronous variant cannot refresh while an event loop is running
        in the current thread. Use get_valid_tokens_async from async code.

        Args:
            user_id: User identifier

        Returns:
            Valid tokens or None if not authenticated (or refresh failed)
        """
        tokens = self.storage.load_tokens(user_id)
        if not tokens:
            return None
        if not self._needs_refresh(tokens):
            return tokens

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # no loop running in this thread: asyncio.run is safe
        else:
            # asyncio.run() would raise here; don't create a coroutine that is
            # never awaited.
            logger.warning(
                "Cannot refresh tokens for user %r from inside a running event "
                "loop; use get_valid_tokens_async instead",
                user_id,
            )
            return None

        try:
            new_tokens = asyncio.run(self.refresh_token(tokens["refresh_token"]))
            return self._store_refreshed(user_id, tokens, new_tokens)
        except Exception:
            # Refresh failed, user needs to re-authenticate.
            logger.warning(
                "Token refresh failed for user %r", user_id, exc_info=True
            )
            return None

    async def get_valid_tokens_async(self, user_id: str) -> Optional[dict]:
        """
        Async variant of get_valid_tokens, usable inside a running event loop.

        Args:
            user_id: User identifier

        Returns:
            Valid tokens or None if not authenticated (or refresh failed)
        """
        tokens = self.storage.load_tokens(user_id)
        if not tokens:
            return None
        if not self._needs_refresh(tokens):
            return tokens

        try:
            new_tokens = await self.refresh_token(tokens["refresh_token"])
            return self._store_refreshed(user_id, tokens, new_tokens)
        except Exception:
            # Refresh failed, user needs to re-authenticate.
            logger.warning(
                "Token refresh failed for user %r", user_id, exc_info=True
            )
            return None