fixes and modelgen insert
This commit is contained in:
237
tools/loader/extract/django.py
Normal file
237
tools/loader/extract/django.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""
|
||||
Django Extractor
|
||||
|
||||
Extracts model definitions from Django ORM models.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from ..schema import EnumDefinition, FieldDefinition, ModelDefinition
|
||||
from .base import BaseExtractor
|
||||
|
||||
# Django field type mappings to Python types
|
||||
DJANGO_FIELD_TYPES = {
|
||||
"CharField": str,
|
||||
"TextField": str,
|
||||
"EmailField": str,
|
||||
"URLField": str,
|
||||
"SlugField": str,
|
||||
"UUIDField": "UUID",
|
||||
"IntegerField": int,
|
||||
"BigIntegerField": "bigint",
|
||||
"SmallIntegerField": int,
|
||||
"PositiveIntegerField": int,
|
||||
"FloatField": float,
|
||||
"DecimalField": float,
|
||||
"BooleanField": bool,
|
||||
"NullBooleanField": bool,
|
||||
"DateField": "datetime",
|
||||
"DateTimeField": "datetime",
|
||||
"TimeField": "datetime",
|
||||
"JSONField": "dict",
|
||||
"ForeignKey": "FK",
|
||||
"OneToOneField": "FK",
|
||||
"ManyToManyField": "M2M",
|
||||
}
|
||||
|
||||
|
||||
class DjangoExtractor(BaseExtractor):
|
||||
"""Extracts models from Django ORM."""
|
||||
|
||||
def detect(self) -> bool:
|
||||
"""Check if this is a Django project."""
|
||||
# Look for manage.py or settings.py
|
||||
manage_py = self.source_path / "manage.py"
|
||||
settings_py = self.source_path / "settings.py"
|
||||
|
||||
if manage_py.exists():
|
||||
return True
|
||||
|
||||
# Check for Django imports in any models.py
|
||||
for models_file in self.source_path.rglob("models.py"):
|
||||
content = models_file.read_text()
|
||||
if "from django.db import models" in content:
|
||||
return True
|
||||
|
||||
return settings_py.exists()
|
||||
|
||||
def extract(self) -> tuple[List[ModelDefinition], List[EnumDefinition]]:
|
||||
"""Extract Django models using AST parsing."""
|
||||
models = []
|
||||
enums = []
|
||||
|
||||
# Find all models.py files
|
||||
for models_file in self.source_path.rglob("models.py"):
|
||||
file_models, file_enums = self._extract_from_file(models_file)
|
||||
models.extend(file_models)
|
||||
enums.extend(file_enums)
|
||||
|
||||
return models, enums
|
||||
|
||||
def _extract_from_file(
|
||||
self, file_path: Path
|
||||
) -> tuple[List[ModelDefinition], List[EnumDefinition]]:
|
||||
"""Extract models from a single models.py file."""
|
||||
models = []
|
||||
enums = []
|
||||
|
||||
content = file_path.read_text()
|
||||
tree = ast.parse(content)
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef):
|
||||
# Check if it inherits from models.Model
|
||||
if self._is_django_model(node):
|
||||
model_def = self._parse_model_class(node)
|
||||
if model_def:
|
||||
models.append(model_def)
|
||||
# Check if it's a TextChoices/IntegerChoices enum
|
||||
elif self._is_django_choices(node):
|
||||
enum_def = self._parse_choices_class(node)
|
||||
if enum_def:
|
||||
enums.append(enum_def)
|
||||
|
||||
return models, enums
|
||||
|
||||
def _is_django_model(self, node: ast.ClassDef) -> bool:
|
||||
"""Check if class inherits from models.Model."""
|
||||
for base in node.bases:
|
||||
if isinstance(base, ast.Attribute):
|
||||
if base.attr == "Model":
|
||||
return True
|
||||
elif isinstance(base, ast.Name):
|
||||
if base.id in ("Model", "AbstractUser", "AbstractBaseUser"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_django_choices(self, node: ast.ClassDef) -> bool:
|
||||
"""Check if class is a Django TextChoices/IntegerChoices."""
|
||||
for base in node.bases:
|
||||
if isinstance(base, ast.Attribute):
|
||||
if base.attr in ("TextChoices", "IntegerChoices"):
|
||||
return True
|
||||
elif isinstance(base, ast.Name):
|
||||
if base.id in ("TextChoices", "IntegerChoices"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _parse_model_class(self, node: ast.ClassDef) -> Optional[ModelDefinition]:
|
||||
"""Parse a Django model class into ModelDefinition."""
|
||||
fields = []
|
||||
|
||||
for item in node.body:
|
||||
if isinstance(item, ast.Assign):
|
||||
field_def = self._parse_field_assignment(item)
|
||||
if field_def:
|
||||
fields.append(field_def)
|
||||
elif isinstance(item, ast.AnnAssign):
|
||||
# Handle annotated assignments (Django 4.0+ style)
|
||||
field_def = self._parse_annotated_field(item)
|
||||
if field_def:
|
||||
fields.append(field_def)
|
||||
|
||||
# Get docstring
|
||||
docstring = ast.get_docstring(node)
|
||||
|
||||
return ModelDefinition(
|
||||
name=node.name,
|
||||
fields=fields,
|
||||
docstring=docstring,
|
||||
)
|
||||
|
||||
def _parse_field_assignment(self, node: ast.Assign) -> Optional[FieldDefinition]:
|
||||
"""Parse a field assignment like: name = models.CharField(...)"""
|
||||
if not node.targets or not isinstance(node.targets[0], ast.Name):
|
||||
return None
|
||||
|
||||
field_name = node.targets[0].id
|
||||
|
||||
# Skip private fields and Meta class
|
||||
if field_name.startswith("_") or field_name == "Meta":
|
||||
return None
|
||||
|
||||
# Parse the field call
|
||||
if isinstance(node.value, ast.Call):
|
||||
return self._parse_field_call(field_name, node.value)
|
||||
|
||||
return None
|
||||
|
||||
def _parse_annotated_field(self, node: ast.AnnAssign) -> Optional[FieldDefinition]:
|
||||
"""Parse an annotated field assignment."""
|
||||
if not isinstance(node.target, ast.Name):
|
||||
return None
|
||||
|
||||
field_name = node.target.id
|
||||
|
||||
if field_name.startswith("_"):
|
||||
return None
|
||||
|
||||
if node.value and isinstance(node.value, ast.Call):
|
||||
return self._parse_field_call(field_name, node.value)
|
||||
|
||||
return None
|
||||
|
||||
def _parse_field_call(
|
||||
self, field_name: str, call: ast.Call
|
||||
) -> Optional[FieldDefinition]:
|
||||
"""Parse a Django field call like models.CharField(max_length=100)."""
|
||||
# Get field type name
|
||||
field_type_name = None
|
||||
|
||||
if isinstance(call.func, ast.Attribute):
|
||||
field_type_name = call.func.attr
|
||||
elif isinstance(call.func, ast.Name):
|
||||
field_type_name = call.func.id
|
||||
|
||||
if not field_type_name:
|
||||
return None
|
||||
|
||||
# Map to Python type
|
||||
python_type = DJANGO_FIELD_TYPES.get(field_type_name, str)
|
||||
|
||||
# Check for null=True
|
||||
optional = False
|
||||
default = None
|
||||
|
||||
for keyword in call.keywords:
|
||||
if keyword.arg == "null":
|
||||
if isinstance(keyword.value, ast.Constant):
|
||||
optional = keyword.value.value is True
|
||||
elif keyword.arg == "default":
|
||||
if isinstance(keyword.value, ast.Constant):
|
||||
default = keyword.value.value
|
||||
|
||||
return FieldDefinition(
|
||||
name=field_name,
|
||||
type_hint=python_type,
|
||||
default=default if default is not None else None,
|
||||
optional=optional,
|
||||
)
|
||||
|
||||
def _parse_choices_class(self, node: ast.ClassDef) -> Optional[EnumDefinition]:
|
||||
"""Parse a Django TextChoices/IntegerChoices class."""
|
||||
values = []
|
||||
|
||||
for item in node.body:
|
||||
if isinstance(item, ast.Assign):
|
||||
if item.targets and isinstance(item.targets[0], ast.Name):
|
||||
name = item.targets[0].id
|
||||
if name.isupper(): # Enum values are typically uppercase
|
||||
# Get the value
|
||||
value = name.lower() # Default to lowercase name
|
||||
if isinstance(item.value, ast.Constant):
|
||||
value = str(item.value.value)
|
||||
elif isinstance(item.value, ast.Tuple) and item.value.elts:
|
||||
# TextChoices: NAME = "value", "Label"
|
||||
if isinstance(item.value.elts[0], ast.Constant):
|
||||
value = str(item.value.elts[0].value)
|
||||
|
||||
values.append((name, value))
|
||||
|
||||
if not values:
|
||||
return None
|
||||
|
||||
return EnumDefinition(name=node.name, values=values)
|
||||
Reference in New Issue
Block a user