238 lines
7.9 KiB
Python
238 lines
7.9 KiB
Python
"""
|
|
Django Extractor
|
|
|
|
Extracts model definitions from Django ORM models.
|
|
"""
|
|
|
|
import ast
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any, List, Optional
|
|
|
|
from ..schema import EnumDefinition, FieldDefinition, ModelDefinition
|
|
from .base import BaseExtractor
|
|
|
|
# Django field class names mapped to the type hint stored on FieldDefinition.
# Values are either Python builtins (str, int, float, bool) or string tags
# ("UUID", "bigint", "datetime", "dict", "FK", "M2M") that downstream code is
# expected to interpret. Field classes not listed here fall back to ``str``
# when a model is parsed, so every common column type should appear below.
DJANGO_FIELD_TYPES = {
    "CharField": str,
    "TextField": str,
    "EmailField": str,
    "URLField": str,
    "SlugField": str,
    "UUIDField": "UUID",
    "IntegerField": int,
    "BigIntegerField": "bigint",
    "SmallIntegerField": int,
    "PositiveIntegerField": int,
    "FloatField": float,
    # NOTE(review): lossy — Decimal values lose exact precision as float.
    "DecimalField": float,
    "BooleanField": bool,
    "NullBooleanField": bool,  # deprecated in Django 3.1; kept for older projects
    # NOTE(review): date-only and time-only columns share the "datetime" tag;
    # confirm downstream does not need distinct "date"/"time" tags.
    "DateField": "datetime",
    "DateTimeField": "datetime",
    "TimeField": "datetime",
    "JSONField": "dict",
    "ForeignKey": "FK",
    "OneToOneField": "FK",
    "ManyToManyField": "M2M",
    # Auto-incrementing and positive integer variants; without these entries
    # they (including default primary keys) would fall back to ``str``.
    "AutoField": int,
    "SmallAutoField": int,
    "BigAutoField": "bigint",
    "PositiveSmallIntegerField": int,
    "PositiveBigIntegerField": "bigint",
    # String-valued columns, listed explicitly for clarity.
    "GenericIPAddressField": str,
    "FilePathField": str,
    "FileField": str,  # the column stores the file's name/path
    "ImageField": str,
}
|
|
|
|
|
|
class DjangoExtractor(BaseExtractor):
    """Extract model and choices-enum definitions from a Django project.

    Parsing is purely static: every ``models.py`` under ``source_path`` is
    parsed with :mod:`ast`; no project code is imported or executed.
    """

    # ASCII byte markers whose presence in a models.py identifies Django code.
    _DJANGO_IMPORT_MARKERS = (
        b"from django.db import models",
        b"from django.db.models import",
    )
    # Base-class names treated as Django models, whether written bare
    # (``Model``) or qualified (``models.Model``, ``auth_models.AbstractUser``).
    _MODEL_BASES = frozenset({"Model", "AbstractUser", "AbstractBaseUser"})
    # Base-class names treated as Django enumeration (choices) types.
    _CHOICES_BASES = frozenset({"TextChoices", "IntegerChoices"})
    # Sentinel returned by _literal_value when an expression is not a literal
    # (distinct from a literal ``None``).
    _NO_LITERAL = object()

    def detect(self) -> bool:
        """Return True if ``source_path`` looks like a Django project.

        Cheap checks for ``manage.py`` / ``settings.py`` at the root run
        first; otherwise every ``models.py`` is scanned for a Django model
        import. Files that cannot be read are skipped.
        """
        if (self.source_path / "manage.py").exists():
            return True
        if (self.source_path / "settings.py").exists():
            return True

        for models_file in self.source_path.rglob("models.py"):
            try:
                # Reading bytes avoids depending on the locale's default text
                # encoding; the markers are ASCII so they match UTF-8 source.
                content = models_file.read_bytes()
            except OSError:
                continue
            if any(marker in content for marker in self._DJANGO_IMPORT_MARKERS):
                return True
        return False

    def extract(self) -> tuple[List[ModelDefinition], List[EnumDefinition]]:
        """Extract models and choices enums from every ``models.py`` file.

        Files are visited in sorted path order so the output is deterministic
        regardless of filesystem iteration order.
        """
        models: List[ModelDefinition] = []
        enums: List[EnumDefinition] = []

        for models_file in sorted(self.source_path.rglob("models.py")):
            file_models, file_enums = self._extract_from_file(models_file)
            models.extend(file_models)
            enums.extend(file_enums)

        return models, enums

    def _extract_from_file(
        self, file_path: Path
    ) -> tuple[List[ModelDefinition], List[EnumDefinition]]:
        """Extract models and choices enums from a single ``models.py`` file.

        A file that cannot be read or parsed is reported on stderr and
        contributes nothing, so one broken file does not abort extraction of
        the whole project.
        """
        models: List[ModelDefinition] = []
        enums: List[EnumDefinition] = []

        try:
            # Parsing bytes lets Python honour a PEP 263 coding declaration
            # instead of decoding with the locale's default encoding.
            tree = ast.parse(file_path.read_bytes(), filename=str(file_path))
        except (OSError, SyntaxError, ValueError) as exc:
            print(f"warning: skipping {file_path}: {exc}", file=sys.stderr)
            return models, enums

        # ast.walk also reaches classes nested inside other classes, e.g. a
        # TextChoices declared inside a model body.
        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue
            if self._is_django_model(node):
                model_def = self._parse_model_class(node)
                if model_def:
                    models.append(model_def)
            elif self._is_django_choices(node):
                enum_def = self._parse_choices_class(node)
                if enum_def:
                    enums.append(enum_def)

        return models, enums

    @staticmethod
    def _base_name(node: ast.expr) -> Optional[str]:
        """Return the final identifier of a name or dotted-attribute expression.

        ``models.Model`` -> ``"Model"``; ``Model`` -> ``"Model"``; any other
        expression (subscript, call, ...) -> None.
        """
        if isinstance(node, ast.Attribute):
            return node.attr
        if isinstance(node, ast.Name):
            return node.id
        return None

    def _is_django_model(self, node: ast.ClassDef) -> bool:
        """Check whether any base of the class is a recognised model base."""
        return any(self._base_name(base) in self._MODEL_BASES for base in node.bases)

    def _is_django_choices(self, node: ast.ClassDef) -> bool:
        """Check whether any base of the class is TextChoices/IntegerChoices."""
        return any(
            self._base_name(base) in self._CHOICES_BASES for base in node.bases
        )

    def _parse_model_class(self, node: ast.ClassDef) -> Optional[ModelDefinition]:
        """Build a ModelDefinition from the field declarations of a model class.

        Only direct statements of the class body are inspected; nested classes
        such as ``Meta`` and method definitions are ignored.
        """
        fields: List[FieldDefinition] = []

        for item in node.body:
            if isinstance(item, ast.Assign):
                field_def = self._parse_field_assignment(item)
            elif isinstance(item, ast.AnnAssign):
                # Annotated form, e.g. ``name: str = models.CharField(...)``.
                field_def = self._parse_annotated_field(item)
            else:
                continue
            if field_def:
                fields.append(field_def)

        return ModelDefinition(
            name=node.name,
            fields=fields,
            docstring=ast.get_docstring(node),
        )

    def _parse_field_assignment(self, node: ast.Assign) -> Optional[FieldDefinition]:
        """Parse ``name = models.SomeField(...)``; return None if it is not a field."""
        if not node.targets or not isinstance(node.targets[0], ast.Name):
            return None

        field_name = node.targets[0].id

        # Private attributes are not fields. ("Meta" is normally a nested
        # class rather than an assignment; the check is kept defensively.)
        if field_name.startswith("_") or field_name == "Meta":
            return None

        if isinstance(node.value, ast.Call):
            return self._parse_field_call(field_name, node.value)
        return None

    def _parse_annotated_field(self, node: ast.AnnAssign) -> Optional[FieldDefinition]:
        """Parse ``name: T = models.SomeField(...)``; return None if not a field."""
        if not isinstance(node.target, ast.Name):
            return None

        field_name = node.target.id
        if field_name.startswith("_"):
            return None

        # node.value is None for a bare annotation (``name: T``).
        if isinstance(node.value, ast.Call):
            return self._parse_field_call(field_name, node.value)
        return None

    @staticmethod
    def _looks_like_field(type_name: str) -> bool:
        """Heuristically decide whether a callee name is a model field class.

        Known Django fields are accepted, as is anything following the
        ``*Field`` / ``*Key`` naming convention (custom or third-party fields,
        ForeignKey subclasses). Callees matching neither, such as
        ``Manager``, are rejected.
        """
        return type_name in DJANGO_FIELD_TYPES or type_name.endswith(("Field", "Key"))

    @classmethod
    def _literal_value(cls, node: ast.expr) -> Any:
        """Return the value of a literal such as ``0``, ``"x"``, ``None`` or ``-1``.

        Returns ``cls._NO_LITERAL`` for anything that is not a plain constant
        or a signed numeric constant (names, calls, containers, ...).
        """
        if isinstance(node, ast.Constant):
            return node.value
        if (
            isinstance(node, ast.UnaryOp)
            and isinstance(node.op, (ast.USub, ast.UAdd))
            and isinstance(node.operand, ast.Constant)
            and isinstance(node.operand.value, (int, float, complex))
            and not isinstance(node.operand.value, bool)
        ):
            number = node.operand.value
            return -number if isinstance(node.op, ast.USub) else +number
        return cls._NO_LITERAL

    def _parse_field_call(
        self, field_name: str, call: ast.Call
    ) -> Optional[FieldDefinition]:
        """Parse a field constructor call like ``models.CharField(max_length=100)``.

        Returns None when the callee does not look like a model field, e.g.
        ``objects = models.Manager()``. ``null=True`` marks the field optional;
        a literal ``default=`` is recorded, while non-literal defaults (such as
        callables) are left as None.
        """
        field_type_name = self._base_name(call.func)
        if not field_type_name or not self._looks_like_field(field_type_name):
            return None

        # Unlisted (typically custom) field classes fall back to str.
        python_type = DJANGO_FIELD_TYPES.get(field_type_name, str)

        optional = False
        default = None

        for keyword in call.keywords:
            if keyword.arg == "null":
                if isinstance(keyword.value, ast.Constant):
                    optional = keyword.value.value is True
            elif keyword.arg == "default":
                literal = self._literal_value(keyword.value)
                if literal is not self._NO_LITERAL:
                    default = literal

        return FieldDefinition(
            name=field_name,
            type_hint=python_type,
            default=default,
            optional=optional,
        )

    def _parse_choices_class(self, node: ast.ClassDef) -> Optional[EnumDefinition]:
        """Build an EnumDefinition from a TextChoices/IntegerChoices class.

        Members are UPPERCASE simple assignments. The recorded value is the
        literal on the right-hand side (or the first element of a
        ``value, label`` tuple), converted to str; when no literal can be
        read, the lowercased member name is used. Returns None when the class
        declares no members.
        """
        values = []

        for item in node.body:
            if not isinstance(item, ast.Assign):
                continue
            if not item.targets or not isinstance(item.targets[0], ast.Name):
                continue

            name = item.targets[0].id
            if not name.isupper():  # members are conventionally UPPERCASE
                continue

            rhs = item.value
            if isinstance(rhs, ast.Tuple) and rhs.elts:
                # ``NAME = "value", "Label"`` -> the value is the first element.
                rhs = rhs.elts[0]

            literal = self._literal_value(rhs)
            value = name.lower() if literal is self._NO_LITERAL else str(literal)
            values.append((name, value))

        if not values:
            return None

        return EnumDefinition(name=node.name, values=values)