Source code for drowsy.fields
"""
drowsy.fields
~~~~~~~~~~~~~
Marshmallow fields used in resource schemas.
"""
# :copyright: (c) 2016-2025 by Nicholas Repole and contributors.
# See AUTHORS for more details.
# :license: MIT - See LICENSE for more details.
from marshmallow import EXCLUDE
from marshmallow.fields import Field, missing_
from marshmallow.utils import get_value
from marshmallow_sqlalchemy.fields import Related, ensure_list
from sqlalchemy import select
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import with_parent
from drowsy.base import EmbeddableMixinABC, NestedPermissibleABC
from drowsy.log import Loggable
[docs]
class EmbeddableRelationshipMixin(EmbeddableMixinABC):
    """Relationship mixin that falls back to a URL when not embedded.

    When the field is not embedded, serialization yields the
    relationship's URL, and deserialization leaves the parent
    instance's current value for this field untouched.
    """

    def get_url(self, obj):
        """Build the URL for this relationship.

        The URL is the value of the parent schema's ``"self"`` field
        (only if the parent schema defines one), followed by ``/`` and
        this field's ``data_key`` (or its ``name`` when no ``data_key``
        is set). Very likely that you'll want to override this.

        :param obj: The parent object being serialized.
        :return: The relationship URL.
        """
        prefix = ""
        if self.parent and "self" in self.parent.fields:
            prefix = self.parent.fields["self"].serialize("self", obj)
        return prefix + "/" + (self.data_key or self.name)

    def _deserialize_unembedded(self, value, *args, **kwargs):
        """Deserialize while the field isn't embedded.

        The incoming ``value`` is ignored entirely.

        :param value: The value being deserialized (unused).
        :param args: Positional arguments that were passed to the
            deserializer method (unused).
        :param kwargs: Keyword arguments that were passed to the
            deserializer method (unused).
        :return: The current value of this field's attribute on the
            parent schema's ``instance``, unmodified.
        """
        parent_instance = self.parent.instance
        return getattr(parent_instance, self.name)

    def _serialize_unembedded(self, attr, obj, *args, **kwargs):
        """Serialize while the field isn't embedded.

        :param str attr: The attribute or key to get from the object
            (unused).
        :param obj: The parent object; passed on to :meth:`get_url`.
        :param args: Positional arguments that were passed to the
            serializer method (unused).
        :param kwargs: Keyword arguments that were passed to the
            serializer method (unused).
        :return: The URL for this relationship.
        """
        return self.get_url(obj)

    def deserialize(self, value, *args, **kwargs):
        """Adjust ``self.embedded`` for this input, then deserialize.

        A required field on a non-partial load is forced to embedded.
        Otherwise, a string value (assumed to be the relationship's
        URL) switches the field to unembedded. In any other case the
        flag is left as it was. The flag is stored on the field itself
        and is not reset after the call.

        :param value: The value provided by the user for this field.
            If it's the field's URL, the value is essentially ignored.
        :return: Whatever the next ``deserialize`` in the MRO returns.
        """
        # Not perfect: a string that isn't actually a valid URL is also
        # treated as "unembedded" and silently ignored instead of
        # producing a validation error.
        must_embed = self.required and not self.parent.partial
        if must_embed:
            self.embedded = True
        elif isinstance(value, str):
            self.embedded = False
        return super().deserialize(value, *args, **kwargs)
[docs]
class NestedRelated(NestedPermissibleABC, Related):
    """A nested relationship field for use in `ModelResourceSchema`."""

    def __init__(self, nested, load_default=missing_, dump_default=missing_,
                 exclude=tuple(), only=None, many=False, columns=None,
                 permissions_cls=None, **kwargs):
        """Initialize a nested related field.

        :param nested: The Schema class or class name (string) to nest,
            or ``"self"`` to nest a :class:`~marshmallow.schema.Schema`
            within itself.
        :param load_default: Default value used on load when the field
            is missing; passed through to the parent initializer.
        :param dump_default: Default value used on dump when the
            attribute is missing; passed through to the parent
            initializer.
        :param exclude: Fields to exclude.
        :type exclude: list, tuple, or None
        :param only: A tuple or string of the field(s) to marshal. If
            ``None``, all fields will be marshalled. If a field name
            (string) is given, only a single value will be returned as
            output instead of a dictionary. This parameter takes
            precedence over ``exclude``.
        :type only: tuple, str, or None
        :param bool many: Whether the field is a collection of objects.
        :param list columns: Optional column names on the related
            model, stored as ``self.columns``. Used by
            :attr:`related_keys` when the nested schema supplies no
            ``id_keys``; if neither is available, the primary key(s)
            of the related model are used.
        :param permissions_cls: The class of permissions to apply to
            this relationship. Defaults to allowing all relationship
            operations. May be used to limit the operations that can
            be done.
        :param kwargs: The same keyword arguments that
            :class:`~marshmallow.fields.Field` receives.
        """
        super(NestedRelated, self).__init__(
            nested=nested,
            load_default=load_default,
            dump_default=dump_default,
            exclude=exclude,
            only=only,
            many=many,
            permissions_cls=permissions_cls,
            **kwargs)
        self.columns = ensure_list(columns or [])

    @property
    def model(self):
        """The model associated with the parent schema of this field."""
        schema = self.parent
        return schema.opts.model

    @property
    def related_keys(self):
        """Mapped properties that identify a related (nested) object.

        Key resolution order:

        1. If the nested schema for this field has a non-empty
           ``id_keys`` attr, use those keys.
        2. Else, if this field had a ``columns`` arg passed when
           initialized, use those column names.
        3. Else, use the related model's primary key columns.

        :return: The related model's mapped properties for those
            columns, in order.
        :rtype: list
        """
        mapper = self.related_model.__mapper__
        # self.schema is the schema for this nested field, not the
        # parent schema.
        key_names = getattr(self.schema, "id_keys", None)
        if key_names:
            columns = [mapper.columns[name] for name in key_names]
        elif self.columns:
            columns = [mapper.columns[name] for name in self.columns]
        else:
            columns = list(mapper.primary_key)
        return [mapper.get_property_by_column(column) for column in columns]

    def _get_resource_kwargs(self):
        """Get kwargs for creating a resource for this instance.

        Extends the parent implementation's kwargs with this field's
        ``session``.

        :return: Dictionary of keyword arguments to be passed
            to a resource initializer.
        :rtype: dict
        """
        result = super(NestedRelated, self)._get_resource_kwargs()
        result["session"] = self.session
        return result

    def _parent_contains_child(self, parent, instance, relationship_name):
        """Check whether the parent's relation contains the instance.

        Only the relationship named ``relationship_name`` is checked.
        If the parent schema's instance is persistent and that
        relationship is not yet loaded (and, for a scalar relation,
        is non-empty), the database is queried with ``with_parent``
        and the related keys. Otherwise, the in-memory value of the
        attribute on ``parent`` is inspected.

        :param parent: An object whose relationship for this field may
            contain this instance as a child object.
        :param instance: A potential child object of the parent.
        :param str relationship_name: The name of the relationship
            we're checking on `parent`.
        :return: ``True`` if the parent attr already contains the
            instance, ``False`` otherwise.
        :rtype: bool
        """
        parent_instance = self.parent.instance
        with_parentable = False
        if parent_instance is not None:
            state = inspect(parent_instance)
            if state.persistent and relationship_name in state.unloaded:
                # For a scalar relation this getattr triggers the load;
                # an empty relation cannot contain the instance, so no
                # query is needed.
                with_parentable = bool(
                    self.many or getattr(parent_instance, relationship_name))
        if with_parentable:
            filters = [
                getattr(self.related_model, prop.key) ==
                getattr(instance, prop.key)
                for prop in self.related_keys
            ]
            query = select(self.related_model).where(
                with_parent(
                    parent_instance,
                    getattr(self.parent.opts.model, relationship_name)
                )
            ).where(*filters)
            in_relation_instance = self.session.execute(
                query).scalars().first()
            return bool(in_relation_instance == instance)
        current = getattr(parent, relationship_name)
        if isinstance(current, list):
            return instance in current
        return bool(current == instance)

    def _get_identified_instance(self, obj_data):
        """Get a formed instance using the unique identifier of the obj.

        :param obj_data: Likely a dict, but could be any user provided
            data.
        :return: A SQLAlchemy object instance based on the identifier
            contained in the user supplied data, or ``None`` if no
            instance could be found.
        """
        # TODO - What if get_instance returns None due to permissions?
        # The object itself would still exist in the DB, but our nested
        # field would try to create it anyway. This will result in an
        # error, but it won't be very informative...
        # Maybe that's ok, since such a scenario would almost certainly
        # involve a malicious actor.
        with self.session.no_autoflush:
            # If the parent object hasn't yet been persisted,
            # autoflush can cause an error since it is yet
            # to be fully formed.
            return self.schema.get_instance(data=obj_data)

    def _perform_operation(self, operation, parent, instance, errors, index,
                           in_place, strict=True):
        """Perform an operation on the parent with a supplied instance.

        Example:
            If this field corresponds to the `tracks` collection of a
            parent ``album`` object, then the provided instance should
            be a ``track`` object, and the action might be to ``"add"``
            or ``"remove"`` the provided ``track`` instance from the
            parent ``album``.

        :param str operation: ``"add"`` or ``"remove"`` for collections,
            or ``"set"`` for one to one relations. ``None`` behaves like
            ``"add"``. Any other value, or ``"set"`` on a ``many``
            relation, is an ``"invalid_operation"`` failure.
        :param parent: Object containing the attribute the operation
            is being performed on.
        :param instance: A potential child object of the parent.
        :param dict errors: Dict of errors for this field, handed to
            ``_handle_op_failure`` when a failure is reported.
        :param index: If the relationship is of the many variety, the
            index at which this child was in the input.
        :type index: int or None
        :param bool in_place: Provided to indicate whether the nested
            data is being modified in-place (``True``) or completely
            overridden (``False``).
        :param bool strict: If ``True``, a failed ``"remove"`` (instance
            not in the relation) or an invalid operation is reported via
            ``_handle_op_failure``. If ``False``, such failures are
            skipped and ``errors`` is left untouched.
        :raise ValidationError: Presumably raised by
            ``_handle_op_failure`` in strict mode; confirm against its
            implementation.
        :return: The value of this field's attribute on ``parent`` after
            the operation has been performed.
        """
        is_instance_in_relation = self._parent_contains_child(
            parent, instance, self.name)
        # Permissions were good, load caused no problems.
        # Now perform the actual operation.
        if operation == "remove":
            if is_instance_in_relation:
                if in_place:
                    # NOTE(review): assumes a collection relationship;
                    # for many=False this attr is the related object
                    # itself, which likely has no ``remove``.
                    relation = getattr(parent, self.name)
                    relation.remove(instance)
                # When not in place there is nothing to remove, as the
                # list will already be empty.
            elif strict:
                self._handle_op_failure(
                    "invalid_remove",
                    errors=errors,
                    index=index,
                    strict=strict)
        elif operation is None or operation == "add" or (
                operation == "set" and not self.many):
            # Adding an instance that is already in the relation is a
            # deliberate no-op rather than an error.
            if not is_instance_in_relation:
                if self.many:
                    relation = getattr(parent, self.name)
                    relation.append(instance)
                else:
                    setattr(parent, self.name, instance)
        elif strict:
            self._handle_op_failure(
                "invalid_operation",
                errors=errors,
                index=index,
                strict=strict)
        return getattr(parent, self.name)

    def _load_existing_instance(self, obj_data, instance):
        """Deserialize the provided data into an existing instance.

        Performs a partial, single-object load that excludes unknown
        fields.

        :param obj_data: Likely a dict, but could be any user provided
            data.
        :param instance: A SQLAlchemy object instance. The data provided
            will be loaded into this instance.
        :return: Whatever ``self.schema.load`` returns for this call.
        """
        return self.schema.load(
            obj_data,
            session=self.session,
            instance=instance,
            partial=True,
            many=False,
            unknown=EXCLUDE)

    def _load_new_instance(self, obj_data):
        """Deserialize the provided data into a new SQLAlchemy instance.

        A fresh ``self.related_model()`` is created and loaded with a
        full (non-partial), single-object load that excludes unknown
        fields.

        :param obj_data: Likely a dict, but could be any user provided
            data.
        :return: Whatever ``self.schema.load`` returns for this call.
        """
        return self.schema.load(
            obj_data,
            session=self.session,
            instance=self.related_model(),
            partial=False,
            many=False,
            unknown=EXCLUDE)

    @property
    def context(self):
        """The context dictionary for the parent :class:`Schema`."""
        return self.parent.context
[docs]
class Relationship(EmbeddableRelationshipMixin, NestedRelated):
    """Default relationship field.

    Combines URL-fallback behavior with nested relationship handling.
    When not embedded, serialization produces the relationship's
    assumed URL. When embedded, it produces the nested values of the
    relationship.
    """
[docs]
class APIUrl(Field, Loggable):
    """Text field, displays the url of the resource it's attached to."""

    def __init__(self, endpoint_name, base_url=None, *args, **kwargs):
        """Initializes an APIUrl field.

        :param str endpoint_name: The name of this URL endpoint and
            where to access this resource.
        :param base_url: A str, or a callable taking no arguments that
            returns one, giving the base part of the API url, e.g.
            ``https://example.com/api``. Trailing slashes are removed
            when the url is built.
        :type base_url: str or callable
        :param args: Any field arguments to be passed to the super
            constructor.
        :param kwargs: Any field keyword arguments to be passed to
            the super constructor.
        """
        super(APIUrl, self).__init__(*args, **kwargs)
        self.endpoint_name = endpoint_name
        self.base_url = base_url

    def serialize(self, attr, obj, accessor=None, **kwargs):
        """Serialize an API url.

        The result is ``<base_url>/<endpoint_name>`` followed by
        ``/<value>`` for every key in the parent schema's ``id_keys``
        that exists as an attribute on ``obj``.

        :param str attr: The attribute name of this field. Unused.
        :param obj: The object to pull any needed info from.
        :param accessor: Function used to pull values from ``obj``.
            Defaults to :func:`~marshmallow.utils.get_value`.
        :type accessor: callable or None
        :return: The serialized API url value.
        :rtype: str
        """
        accessor_func = accessor or get_value
        id_keys = self.parent.id_keys
        base_url = self.base_url
        if callable(base_url):
            # Documented as "str or callable": resolve lazily so the
            # base can depend on runtime state (e.g. the request host).
            base_url = base_url()
        segments = [(base_url or "").rstrip("/"), self.endpoint_name]
        for column in id_keys:
            # Only keys present as attributes on obj contribute.
            if hasattr(obj, column):
                segments.append(str(accessor_func(obj, column, missing_)))
        return "/".join(segments)