Source code for wtforms_alchemy.generator
import inspect
from collections import OrderedDict
from decimal import Decimal
from enum import Enum
import sqlalchemy as sa
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy_utils import types
from wtforms import (
BooleanField,
Field,
FloatField,
PasswordField,
TextAreaField,
)
from wtforms.widgets import CheckboxInput, TextArea
from wtforms_components import (
ColorField,
DateField,
DateIntervalField,
DateTimeField,
DateTimeIntervalField,
DateTimeLocalField,
DecimalField,
DecimalIntervalField,
EmailField,
IntegerField,
IntIntervalField,
SelectField,
StringField,
TimeField,
)
from wtforms_components.widgets import (
ColorInput,
DateInput,
DateTimeInput,
DateTimeLocalInput,
EmailInput,
NumberInput,
TextInput,
TimeInput,
)
from .exc import (
AttributeTypeException,
InvalidAttributeException,
UnknownTypeException,
)
from .fields import CountryField, PhoneNumberField, WeekDaysField
from .utils import (
choice_type_coerce_factory,
ClassMap,
flatten,
is_date_column,
is_number,
is_number_range,
is_scalar,
null_or_unicode,
strip_string,
translated_attributes,
)
[docs]
class FormGenerator:
    """
    Base form generator, you can make your own form generators by inheriting
    this class.

    The class-level ``TYPE_MAP`` and ``WIDGET_MAP`` tables below drive the
    column-type-to-field and field-to-widget resolution used by the
    generator methods.
    """

    # When converting SQLAlchemy types to fields this ordered dict is iterated
    # in given order. This allows smart type conversion of different inherited
    # type objects.
    #
    # Each entry pairs a column type with the form field used for it. Order is
    # significant: specialised types are listed ahead of their more general
    # counterparts (e.g. BigInteger/SmallInteger before Integer, UnicodeText
    # before Text, Unicode before String). ``get_field_class`` also accepts a
    # non-Field callable as a value and calls it with the column.
    TYPE_MAP = ClassMap(
        (
            (sa.types.UnicodeText, TextAreaField),
            (sa.types.BigInteger, IntegerField),
            (sa.types.SmallInteger, IntegerField),
            (sa.types.Text, TextAreaField),
            (sa.types.Boolean, BooleanField),
            (sa.types.Date, DateField),
            (sa.types.DateTime, DateTimeField),
            (sa.types.Enum, SelectField),
            (sa.types.Float, FloatField),
            (sa.types.Integer, IntegerField),
            (sa.types.Numeric, DecimalField),
            (sa.types.Unicode, StringField),
            (sa.types.String, StringField),
            (sa.types.Time, TimeField),
            (sa.types.JSON, TextAreaField),
            (types.ArrowType, DateTimeField),
            (types.ChoiceType, SelectField),
            (types.ColorType, ColorField),
            (types.CountryType, CountryField),
            (types.DateRangeType, DateIntervalField),
            (types.DateTimeRangeType, DateTimeIntervalField),
            (types.EmailType, EmailField),
            (types.IntRangeType, IntIntervalField),
            (types.NumericRangeType, DecimalIntervalField),
            (types.PasswordType, PasswordField),
            (types.PhoneNumberType, PhoneNumberField),
            (types.ScalarListType, StringField),
            (types.URLType, StringField),
            (types.UUIDType, StringField),
            (types.WeekDaysType, WeekDaysField),
        )
    )

    # Maps a field class to the widget class that ``widget`` instantiates when
    # the widget needs constructor arguments (currently only ``step``).
    WIDGET_MAP = OrderedDict(
        (
            (BooleanField, CheckboxInput),
            (ColorField, ColorInput),
            (DateField, DateInput),
            (DateTimeField, DateTimeInput),
            (DateTimeLocalField, DateTimeLocalInput),
            (DecimalField, NumberInput),
            (EmailField, EmailInput),
            (FloatField, NumberInput),
            (IntegerField, NumberInput),
            (TextAreaField, TextArea),
            (TimeField, TimeInput),
            (StringField, TextInput),
        )
    )
def __init__(self, form_class):
    """
    Initializes the form generator.

    :param form_class: ModelForm class to be used as the base of generation
        process. Its ``Meta`` is read for ``model`` and ``type_map``.
    """
    self.form_class = form_class
    self.meta = form_class.Meta
    self.model_class = self.meta.model
    # Build a per-instance copy before applying the form specific overrides.
    # Updating the class-level TYPE_MAP in place would leak one form's
    # ``Meta.type_map`` into every other generator sharing that attribute.
    shared_map = self.TYPE_MAP
    self.TYPE_MAP = type(shared_map)(shared_map.items())
    self.TYPE_MAP.update(self.meta.type_map)
[docs]
def create_form(self, form):
    """
    Creates the form.

    Gathers the model's column properties that are not skipped, adds the
    translated attributes on top, filters the result through the meta
    options and generates the fields.

    :param form: ModelForm instance
    :returns: whatever ``create_fields`` returns
    """
    mapper = sa.inspect(self.model_class)
    collected = OrderedDict(
        (name, candidate)
        for name, candidate in mapper.attrs.items()
        if isinstance(candidate, ColumnProperty)
        and not self.skip_column_property(candidate)
    )
    # Translated attributes are added unconditionally and override any
    # same-named entry collected above.
    for translated in translated_attributes(self.model_class):
        collected[translated.key] = translated.property
    return self.create_fields(form, self.filter_attributes(collected))
[docs]
def filter_attributes(self, attrs):
    """
    Filter set of model attributes based on only, exclude and include
    meta parameters.

    ``only`` replaces the given set entirely. Otherwise ``include`` entries
    are added to it and ``exclude`` entries removed from it (in place).

    :param attrs: Ordered mapping of attribute name to property
    :returns: the filtered mapping
    :raises InvalidAttributeException: when an excluded name is missing and
        ``meta.attr_errors`` is enabled
    """
    meta = self.meta

    def resolved(names):
        # validate_attribute hands back a falsy name for rejected entries.
        return [
            (name, prop)
            for name, prop in map(self.validate_attribute, names)
            if name
        ]

    if meta.only:
        return OrderedDict(resolved(meta.only))

    if meta.include:
        attrs.update(resolved(meta.include))
    if meta.exclude:
        for name in meta.exclude:
            try:
                del attrs[name]
            except KeyError:
                if meta.attr_errors:
                    raise InvalidAttributeException(name)
    return attrs
[docs]
def validate_attribute(self, attr_name):
    """
    Finds out whether or not given sqlalchemy model attribute name is
    valid. Returns attribute property if valid.

    The name is looked up on the model class first and then on the class
    found under ``__translatable__["class"]`` of the model.

    :param attr_name: Attribute name
    :returns: ``(attr_name, property)`` when valid, ``(None, None)`` when
        invalid and ``meta.attr_errors`` is disabled
    :raises InvalidAttributeException: unknown or non-column attribute while
        ``meta.attr_errors`` is enabled
    :raises AttributeTypeException: the attribute has no ``property``
    """
    model = self.model_class
    try:
        attribute = getattr(model, attr_name)
    except AttributeError:
        try:
            translation_model = model.__translatable__["class"]
            attribute = getattr(translation_model, attr_name)
        except AttributeError:
            if self.meta.attr_errors:
                raise InvalidAttributeException(attr_name)
            return None, None

    try:
        if not isinstance(attribute.property, ColumnProperty):
            if self.meta.attr_errors:
                raise InvalidAttributeException(attr_name)
            return None, None
    except AttributeError:
        raise AttributeTypeException(attr_name)

    return attr_name, attribute.property
[docs]
def create_fields(self, form, properties):
    """
    Creates fields for given form based on given model properties.

    A generated field is only attached when the form does not already have
    an attribute with the same name.

    :param form: form to attach the generated fields into
    :param properties: mapping of attribute name to column property to
        generate the form fields from
    :raises UnknownTypeException: for an unmapped column type, unless
        ``meta.skip_unknown_types`` is enabled
    """
    for name, column_property in properties.items():
        first_column = column_property.columns[0]
        try:
            generated = self.create_field(column_property, first_column)
        except UnknownTypeException:
            if self.meta.skip_unknown_types:
                continue
            raise
        if hasattr(form, name):
            continue
        setattr(form, name, generated)
[docs]
def skip_column_property(self, column_property):
    """
    Whether or not to skip column property in the generation process.

    Polymorphic discriminators are always skipped; anything else is decided
    by ``skip_column`` on the property's first column.

    :param column_property: SQLAlchemy ColumnProperty object
    """
    is_discriminator = column_property._is_polymorphic_discriminator
    return True if is_discriminator else self.skip_column(column_property.columns[0])
[docs]
def skip_column(self, column):
    """
    Whether or not to skip column in the generation process.

    The checks are evaluated in order and stop at the first one that holds.

    :param column: SQLAlchemy Column object
    :returns: ``True`` when the column must not produce a field
    """
    meta = self.meta
    return bool(
        # Foreign and primary keys are opt-in.
        (not meta.include_foreign_keys and column.foreign_keys)
        or (not meta.include_primary_keys and column.primary_key)
        # Datetime columns that carry a default are opt-in as well.
        or (
            not meta.include_datetimes_with_default
            and isinstance(column.type, sa.types.DateTime)
            and column.default
        )
        or isinstance(column.type, types.TSVectorType)
        or (meta.only_indexed_fields and not self.has_index(column))
        # Skip all non columns (this is the case when using column_property
        # methods).
        or not isinstance(column, sa.Column)
    )
[docs]
def has_index(self, column):
    """
    Whether or not given column has an index.

    Primary key and foreign key columns count as indexed. Otherwise only a
    single-column index on the column's table that contains the column name
    qualifies.

    :param column: Column object to inspect the indexes from
    """
    if column.primary_key or column.foreign_keys:
        return True
    return any(
        len(index.columns) == 1 and column.name in index.columns
        for index in column.table.indexes
    )
[docs]
def create_field(self, prop, column):
    """
    Create form field for given column.

    Keyword arguments are layered in this order, later ones winning:
    default/validators/filters, type agnostic parameters, type specific
    parameters, ``meta.field_args`` for the property key and finally
    ``places`` for decimal fields.

    :param prop: SQLAlchemy ColumnProperty object.
    :param column: SQLAlchemy Column object.
    :returns: the instantiated field
    """
    field_class = self.get_field_class(column)
    field_kwargs = {
        "default": self.default(column),
        "validators": self.create_validators(prop, column),
        "filters": self.filters(column),
    }
    field_kwargs.update(self.type_agnostic_parameters(prop.key, column))
    field_kwargs.update(self.type_specific_parameters(column))

    overrides = self.meta.field_args
    if prop.key in overrides:
        field_kwargs.update(overrides[prop.key])

    # Decimal fields take their number of places from the column scale.
    if issubclass(field_class, DecimalField) and hasattr(column.type, "scale"):
        field_kwargs["places"] = column.type.scale

    return field_class(**field_kwargs)
[docs]
def default(self, column):
    """
    Return field default for given column.

    A scalar column default wins. Without one, non-nullable columns fall
    back to ``meta.default`` and nullable columns get ``None``.

    :param column: SQLAlchemy Column object
    """
    column_default = column.default
    if column_default and is_scalar(column_default.arg):
        return column_default.arg
    if column.nullable:
        return None
    return self.meta.default
[docs]
def filters(self, column):
    """
    Return filters for given column.

    Starts from the filters listed in ``column.info["filters"]`` and appends
    ``strip_string`` when trimming applies: either ``column.info["trim"]``
    is ``True``, or it is unset while the column is a string type and
    ``meta.strip_string_fields`` is enabled.

    :param column: SQLAlchemy Column object
    :returns: a new list of filters; ``column.info`` is never modified
    """
    should_trim = column.info.get("trim", None)
    # Copy the configured filters: appending to the list stored in
    # column.info would grow that shared list on every form generation.
    filters = list(column.info.get("filters", []))
    if should_trim is True or (
        should_trim is None
        and self.meta.strip_string_fields
        and isinstance(column.type, sa.types.String)
    ):
        filters.append(strip_string)
    return filters
[docs]
def date_format(self, column):
    """
    Returns date format for given column.

    :param column: SQLAlchemy Column object
    :returns: ``meta.datetime_format`` for DateTime and Arrow columns,
        ``meta.date_format`` for Date columns, otherwise ``None``
    """
    column_type = column.type
    if isinstance(column_type, (sa.types.DateTime, types.ArrowType)):
        return self.meta.datetime_format
    if isinstance(column_type, sa.types.Date):
        return self.meta.date_format
    return None
[docs]
def type_specific_parameters(self, column):
    """
    Returns type specific parameters for given column.

    Adds select field arguments for enum/choice columns, ``format`` for
    date-like columns, ``region`` when the column type defines one and
    always a ``widget`` entry (possibly ``None``).

    :param column: SQLAlchemy Column object
    """
    column_type = column.type
    params = {}

    is_select = (
        hasattr(column_type, "enums")
        or column.info.get("choices")
        or isinstance(column_type, types.ChoiceType)
    )
    if is_select:
        params.update(self.select_field_kwargs(column))

    fmt = self.date_format(column)
    if fmt:
        params["format"] = fmt

    if hasattr(column_type, "region"):
        params["region"] = column_type.region

    params["widget"] = self.widget(column)
    return params
[docs]
def widget(self, column):
    """
    Returns WTForms widget for given column.

    An explicit ``column.info["widget"]`` is returned as is. Otherwise a
    widget is only built when a ``step`` is known, either from
    ``column.info["step"]`` or derived from the scale of a Numeric column
    that has no choices.

    :param column: SQLAlchemy Column object
    :returns: a widget, or ``None`` when no widget arguments apply
    """
    info = column.info
    explicit = info.get("widget", None)
    if explicit is not None:
        return explicit

    widget_kwargs = {}
    explicit_step = info.get("step", None)
    if explicit_step is not None:
        widget_kwargs["step"] = explicit_step
    elif isinstance(column.type, sa.types.Numeric):
        if column.type.scale is not None and not info.get("choices"):
            widget_kwargs["step"] = self.scale_to_step(column.type.scale)

    if not widget_kwargs:
        return None
    widget_class = self.WIDGET_MAP[self.get_field_class(column)]
    return widget_class(**widget_kwargs)
[docs]
def scale_to_step(self, scale):
    """
    Returns HTML5 compatible step attribute for given decimal scale.

    :param scale: an integer that defines a Numeric column's scale
    :returns: ``0.1`` raised to ``scale``, as a string (``2`` -> ``"0.01"``)
    """
    smallest_increment = Decimal("0.1") ** scale
    return str(smallest_increment)
[docs]
def type_agnostic_parameters(self, key, column):
    """
    Returns all type agnostic form field parameters for given column.

    :param key: property key, used as the label when none is configured
    :param column: SQLAlchemy Column object
    :returns: dict with ``description`` and ``label`` read from
        ``column.info``
    """
    info = column.info
    return {
        "description": info.get("description", ""),
        "label": info.get("label", key),
    }
[docs]
def select_field_kwargs(self, column):
    """
    Returns key value args for SelectField based on SQLAlchemy column
    definitions.

    Choices come from, in order of precedence: the ChoiceType's choices
    (Enum classes are expanded to ``(value, str(member))`` pairs), a
    non-empty ``column.info["choices"]``, or the column type's ``enums``.

    :param column: SQLAlchemy Column object
    :returns: dict with ``coerce`` and ``choices``
    """
    options = {"coerce": self.coerce(column)}
    column_type = column.type

    if isinstance(column_type, types.ChoiceType):
        declared = column_type.choices
        if isinstance(declared, type) and issubclass(declared, Enum):
            options["choices"] = [
                (member.value, str(member)) for member in declared
            ]
        else:
            options["choices"] = declared
    elif column.info.get("choices"):
        options["choices"] = column.info["choices"]
    else:
        options["choices"] = [(value, value) for value in column_type.enums]
    return options
[docs]
def coerce(self, column):
    """
    Returns coerce callable for given column

    Precedence: ``column.info["coerce"]``, the ChoiceType specific factory,
    then the column type's ``python_type``. ``null_or_unicode`` is used when
    the python type is not implemented or the column is a nullable string.

    :param column: SQLAlchemy Column object
    """
    info = column.info
    if "coerce" in info:
        return info["coerce"]

    column_type = column.type
    if isinstance(column_type, types.ChoiceType):
        return choice_type_coerce_factory(column_type)

    try:
        python_type = column_type.python_type
    except NotImplementedError:
        return null_or_unicode

    nullable_text = column.nullable and issubclass(python_type, str)
    return null_or_unicode if nullable_text else python_type
[docs]
def create_validators(self, prop, column):
    """
    Returns validators for given column

    Combines the required, length, unique and range validators, adds email
    and url validators for the matching column types, drops the ``None``
    entries, flattens the result and appends the additional validators.

    :param prop: SQLAlchemy ColumnProperty object
    :param column: SQLAlchemy Column object
    """
    candidates = [
        self.required_validator(column),
        self.length_validator(column),
        self.unique_validator(prop.key, column),
        self.range_validator(column),
    ]
    typed_validators = (
        (types.EmailType, "email"),
        (types.URLType, "url"),
    )
    for type_class, validator_name in typed_validators:
        if isinstance(column.type, type_class):
            candidates.append(self.get_validator(validator_name))

    present = [candidate for candidate in candidates if candidate is not None]
    result = flatten(present)
    result.extend(self.additional_validators(prop.key, column))
    return result
[docs]
def required_validator(self, column):
    """
    Returns required / optional validator for given column based on column
    nullability and form configuration.

    A column is treated as required when it is not nullable, has no default
    and ``meta.all_fields_optional`` is off. The validator is then looked up
    in ``meta.not_null_validator_type_map`` by column type (and by ``impl``
    for type decorators), falling back to ``meta.not_null_validator``.
    Every other case yields the ``optional`` validator.

    :param column: SQLAlchemy Column object
    """
    meta = self.meta
    required = (
        not meta.all_fields_optional
        and not column.default
        and not column.nullable
    )
    if not required:
        return self.get_validator("optional")

    type_map = meta.not_null_validator_type_map
    try:
        return type_map[column.type]
    except KeyError:
        pass

    if isinstance(column.type, sa.types.TypeDecorator):
        impl_type = column.type.impl
        try:
            return type_map[impl_type]
        except KeyError:
            pass

    if meta.not_null_validator is not None:
        return meta.not_null_validator
    return self.get_validator("optional")
def get_validator(self, name, **kwargs):
    """
    Build the validator configured on the form meta under
    ``<name>_validator``.

    :param name: validator name without the ``_validator`` suffix
    :param kwargs: keyword arguments passed to the validator factory
    :returns: the constructed validator, or ``None`` when the meta
        attribute is ``None``
    """
    factory = getattr(self.meta, "{}_validator".format(name))
    return None if factory is None else factory(**kwargs)
[docs]
def additional_validators(self, key, column):
    """
    Returns additional validators for given column

    Validators configured in ``meta.validators[key]`` come first, followed
    by those in ``column.info["validators"]``. Either source may be a
    single validator or an iterable of validators.

    :param key: String key of the column property
    :param column: SQLAlchemy Column object
    """
    collected = []

    def absorb(extra):
        # A non-iterable value is a single validator rather than a list.
        try:
            collected.extend(extra)
        except TypeError:
            collected.append(extra)

    form_level = self.meta.validators
    if key in form_level:
        absorb(form_level[key])

    column_level = column.info.get("validators")
    if column_level:
        absorb(column_level)
    return collected
[docs]
def unique_validator(self, key, column):
    """
    Returns unique validator for given column if the column is flagged
    unique.

    :param key: String key of the column property
    :param column: SQLAlchemy Column object
    :returns: the ``unique`` validator bound to the model attribute and the
        form class's ``get_session``, or ``None`` for non-unique columns
    """
    if not column.unique:
        return None
    model_attribute = getattr(self.model_class, key)
    return self.get_validator(
        "unique",
        column=model_attribute,
        get_session=self.form_class.get_session,
    )
[docs]
def range_validator(self, column):
    """
    Returns range validator based on column type and column info min and
    max arguments

    :param column: SQLAlchemy Column object
    :returns: a ``number_range``, ``date_range`` or ``time_range``
        validator, or ``None`` when no bound is configured or the column
        type supports none of them
    """
    lower = column.info.get("min")
    upper = column.info.get("max")
    if lower is None and upper is None:
        return None

    column_type = column.type
    if is_number(column_type) or is_number_range(column_type):
        validator_name = "number_range"
    elif is_date_column(column):
        validator_name = "date_range"
    elif isinstance(column_type, sa.types.Time):
        validator_name = "time_range"
    else:
        return None
    return self.get_validator(validator_name, min=lower, max=upper)
[docs]
def length_validator(self, column):
    """
    Returns length validator for given column

    :param column: SQLAlchemy Column object
    :returns: a ``length`` validator capped at the column type's length for
        string columns with a truthy length, otherwise ``None``
    """
    column_type = column.type
    if not isinstance(column_type, sa.types.String):
        return None
    if not getattr(column_type, "length", None):
        return None
    return self.get_validator("length", max=column_type.length)
[docs]
def get_field_class(self, column):
    """
    Returns WTForms field class. Class is based on a custom field class
    attribute or SQLAlchemy column type.

    Resolution order: ``column.info["form_field_class"]``, ``SelectField``
    when ``column.info["choices"]`` is set, then the ``TYPE_MAP`` entry for
    the column type (or for its ``impl`` when the type is an unmapped type
    decorator). A mapped value that is not a ``Field`` subclass is treated
    as a factory and called with the column.

    :param column: SQLAlchemy Column object
    :raises UnknownTypeException: when the column type has no ``TYPE_MAP``
        entry
    """
    info = column.info
    if info.get("form_field_class"):
        return info["form_field_class"]
    if info.get("choices"):
        return SelectField

    column_type = column.type
    if column_type not in self.TYPE_MAP and isinstance(
        column_type, sa.types.TypeDecorator
    ):
        check_type = column_type.impl
    else:
        check_type = column_type

    # Only the lookup is guarded: a KeyError raised inside a user supplied
    # factory must surface as is instead of being reported as an unknown
    # column type.
    try:
        mapped = self.TYPE_MAP[check_type]
    except KeyError:
        raise UnknownTypeException(column)

    if inspect.isclass(mapped) and issubclass(mapped, Field):
        return mapped
    return mapped(column)