migrated all pawprint work
This commit is contained in:
179
artery/veins/oauth.py
Normal file
179
artery/veins/oauth.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
OAuth2 utilities and base classes for OAuth-based veins.
|
||||
|
||||
Any vein using OAuth2 (Google, GitHub, GitLab, etc.) can inherit from
|
||||
BaseOAuthVein and use TokenStorage.
|
||||
"""
|
||||
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .base import BaseVein, TClient, TCredentials
|
||||
|
||||
|
||||
class TokenStorage:
|
||||
"""
|
||||
File-based token storage for OAuth2 tokens.
|
||||
|
||||
Can be overridden for Redis/database storage in production.
|
||||
Each vein gets its own storage directory.
|
||||
"""
|
||||
|
||||
def __init__(self, vein_name: str, storage_dir: Optional[Path] = None):
|
||||
"""
|
||||
Initialize token storage.
|
||||
|
||||
Args:
|
||||
vein_name: Name of the vein (e.g., 'google', 'github')
|
||||
storage_dir: Base storage directory (defaults to veins/{vein_name}/storage)
|
||||
"""
|
||||
if storage_dir is None:
|
||||
# Default: veins/{vein_name}/storage/
|
||||
storage_dir = Path(__file__).parent / vein_name / "storage"
|
||||
self.storage_dir = storage_dir
|
||||
self.storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _get_path(self, user_id: str) -> Path:
|
||||
"""Get token file path for user."""
|
||||
return self.storage_dir / f"tokens_{user_id}.json"
|
||||
|
||||
def save_tokens(self, user_id: str, tokens: dict) -> None:
|
||||
"""
|
||||
Save OAuth2 tokens for a user.
|
||||
|
||||
tokens should contain:
|
||||
- access_token
|
||||
- refresh_token (optional)
|
||||
- expires_in (seconds)
|
||||
- scope
|
||||
- token_type
|
||||
"""
|
||||
# Add expiry timestamp
|
||||
if "expires_in" in tokens:
|
||||
expires_at = datetime.now() + timedelta(seconds=tokens["expires_in"])
|
||||
tokens["expires_at"] = expires_at.isoformat()
|
||||
|
||||
path = self._get_path(user_id)
|
||||
with open(path, "w") as f:
|
||||
json.dump(tokens, f, indent=2)
|
||||
|
||||
def load_tokens(self, user_id: str) -> Optional[dict]:
|
||||
"""Load OAuth2 tokens for a user. Returns None if not found."""
|
||||
path = self._get_path(user_id)
|
||||
if not path.exists():
|
||||
return None
|
||||
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
def is_expired(self, tokens: dict) -> bool:
|
||||
"""
|
||||
Check if access token is expired.
|
||||
|
||||
Returns True if expired or expiring in less than 5 minutes.
|
||||
"""
|
||||
if "expires_at" not in tokens:
|
||||
return True
|
||||
|
||||
expires_at = datetime.fromisoformat(tokens["expires_at"])
|
||||
# Consider expired if less than 5 minutes remaining
|
||||
return datetime.now() >= expires_at - timedelta(minutes=5)
|
||||
|
||||
def delete_tokens(self, user_id: str) -> None:
|
||||
"""Delete tokens for a user."""
|
||||
path = self._get_path(user_id)
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
|
||||
|
||||
class BaseOAuthVein(BaseVein[TCredentials, TClient]):
|
||||
"""
|
||||
Base class for OAuth2-based veins.
|
||||
|
||||
Extends BaseVein with OAuth2 flow management:
|
||||
- Authorization URL generation
|
||||
- Code exchange for tokens
|
||||
- Token refresh
|
||||
- Token storage
|
||||
"""
|
||||
|
||||
def __init__(self, storage: Optional[TokenStorage] = None):
|
||||
"""
|
||||
Initialize OAuth vein.
|
||||
|
||||
Args:
|
||||
storage: Token storage instance (creates default if None)
|
||||
"""
|
||||
if storage is None:
|
||||
storage = TokenStorage(vein_name=self.name)
|
||||
self.storage = storage
|
||||
|
||||
@abstractmethod
|
||||
def get_auth_url(self, state: Optional[str] = None) -> str:
|
||||
"""
|
||||
Generate OAuth2 authorization URL.
|
||||
|
||||
Args:
|
||||
state: Optional state parameter for CSRF protection
|
||||
|
||||
Returns:
|
||||
URL to redirect user for authorization
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def exchange_code(self, code: str) -> dict:
|
||||
"""
|
||||
Exchange authorization code for tokens.
|
||||
|
||||
Args:
|
||||
code: Authorization code from callback
|
||||
|
||||
Returns:
|
||||
Token dict containing access_token, refresh_token, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def refresh_token(self, refresh_token: str) -> dict:
|
||||
"""
|
||||
Refresh an expired access token.
|
||||
|
||||
Args:
|
||||
refresh_token: The refresh token
|
||||
|
||||
Returns:
|
||||
New token dict with fresh access_token
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_valid_tokens(self, user_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Get valid tokens for user, refreshing if needed.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
|
||||
Returns:
|
||||
Valid tokens or None if not authenticated
|
||||
"""
|
||||
tokens = self.storage.load_tokens(user_id)
|
||||
if not tokens:
|
||||
return None
|
||||
|
||||
if self.storage.is_expired(tokens) and "refresh_token" in tokens:
|
||||
# Try to refresh
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
new_tokens = asyncio.run(self.refresh_token(tokens["refresh_token"]))
|
||||
self.storage.save_tokens(user_id, new_tokens)
|
||||
return new_tokens
|
||||
except Exception:
|
||||
# Refresh failed, user needs to re-authenticate
|
||||
return None
|
||||
|
||||
return tokens
|
||||
Reference in New Issue
Block a user