"""
drowsy.query_builder
~~~~~~~~~~~~~~~~~~~~
Tools for building SQLAlchemy queries.
"""
# :copyright: (c) 2016-2025 by Nicholas Repole and contributors.
# See AUTHORS for more details.
# :license: MIT - See LICENSE for more details.
from collections import defaultdict
import sqlite3
from drowsy.fields import NestedRelated
from drowsy.log import Loggable
from drowsy.parser import SortInfo, SubfilterInfo
from drowsy.utils import get_field_by_data_key
from sqlalchemy import and_, func, or_, select
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import aliased, contains_eager, subqueryload
from sqlalchemy.orm.interfaces import MANYTOMANY, ONETOMANY
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
from sqlalchemy.sql.selectable import Alias, Subquery
from mqlalchemy import (
InvalidMqlException, MqlBuilder, MqlFieldError,
MqlFieldPermissionError, MqlTooComplex)
def manipulate_filters_to_list(filters):
    """Normalize ``filters`` into a list of filter expressions.

    :param filters: SQLAlchemy filters to be applied to a query.
    :type filters: SQLAlchemy expression, list, tuple, or None
    :return: ``[]`` for ``None``; the very same object when
        ``filters`` is already a list; a new list holding the items
        of a tuple; otherwise a one element list wrapping
        ``filters``.
    :rtype: list
    """
    if filters is None:
        return []
    if isinstance(filters, list):
        # Already the desired shape; return the caller's object as-is.
        return filters
    if isinstance(filters, tuple):
        return list(filters)
    return [filters]
class QueryBuilder(Loggable):
    """Utility class for building a SQLAlchemy query."""

    def row_number_supported(self, dialect, dialect_override=None):
        """Given a SQL dialect, figure out if row_number is supported.

        :param str dialect: SQL dialect being used, e.g. ``"mssql"``.
        :param dialect_override: Can be used to force this method
            to return ``True`` or ``False``.
        :type dialect_override: bool or None
        :return: ``dialect_override`` when it is not ``None``.
            Otherwise ``True`` if ``dialect`` (case insensitive) is a
            dialect known to support ``row_number``, else ``False``.
        :rtype: bool
        """
        if dialect_override is not None:
            return dialect_override
        supported_dialects = ["mssql", "postgresql", "oracle", "mysql",
                              "mariadb"]
        # SQLite gained window functions (and thus row_number) in 3.25.
        if sqlite3.sqlite_version_info >= (3, 25):
            supported_dialects.append("sqlite")  # pragma: no cover
        return dialect.lower() in supported_dialects

    def _get_order_bys(self, record_class, sorts, convert_key_names_func):
        """Helper method for applying sorts.

        :param record_class: The type of model being sorted.
        :type record_class: :class:`~sqlalchemy.orm.util.AliasedClass`
            or SQLAlchemy model class.
        :param sorts: A list of sorts.
        :type sorts: list of :class:`~drowsy.parser.SortInfo`
        :param convert_key_names_func: Used to convert key names.
            See :func:`~drowsy.parser.parse_filters`.
        :type convert_key_names_func: callable
        :raise AttributeError: If a sort's converted attr name is
            ``None`` or is not an attribute of ``record_class``.
        :return: A list of order_by parameters to be applied to a
            query, in the same order as ``sorts``. A sort whose
            ``direction`` is ``"ASC"`` becomes ascending; any other
            direction becomes descending.
        :rtype: list
        """
        order_bys = []
        for sort in sorts:
            attr_name = convert_key_names_func(sort.attr)
            if attr_name is None or not hasattr(record_class, attr_name):
                raise AttributeError("Invalid attribute.")
            column = getattr(record_class, attr_name)
            if sort.direction == "ASC":
                order_bys.append(column.asc())
            else:
                order_bys.append(column.desc())
        return order_bys

    def apply_sorts(self, query, sorts, convert_key_names_func=str):
        """Apply sorts to a provided query.

        :param query: A SQLAlchemy query; filters must already have been
            applied.
        :type query: :class:`~sqlalchemy.orm.query.Query`
        :param sorts: A list of sorts to apply to this query.
        :type sorts: list of :class:`~drowsy.parser.SortInfo`
        :param convert_key_names_func: Used to convert key names.
            See :func:`~drowsy.parser.parse_filters`.
        :type convert_key_names_func: callable
        :raise AttributeError: If a sort with an invalid attr name is
            provided.
        :raise ValueError: If ``query`` doesn't select exactly one
            entity.
        :return: A modified version of the provided query object.
        :rtype: :class:`~sqlalchemy.orm.query.Query`
        """
        entities = [c["entity"] for c in query.column_descriptions]
        if len(entities) != 1:
            raise ValueError(
                "Sorts can only be applied to a query of a single "
                "model type.")
        order_bys = self._get_order_bys(
            entities[0], sorts, convert_key_names_func)
        for order_by in order_bys:
            query = query.order_by(order_by)
        return query

    def apply_offset(self, query, offset):
        """Applies an offset to the query if appropriate.

        :param query: Any desired filters must already have been
            applied.
        :type query: :class:`~sqlalchemy.orm.query.Query`
        :param offset: Integer used to offset the query result.
        :type offset: int or None
        :raise ValueError: If a non ``None`` offset is provided that
            converts to a negative integer, or is a string that
            ``int()`` can't parse.
        :raise TypeError: If a non ``None`` offset is provided
            of a type ``int()`` can't convert.
        :return: A modified query object with an offset applied, or
            the original query if ``offset`` is ``None``.
        :rtype: :class:`~sqlalchemy.orm.query.Query`
        """
        if offset is not None:
            offset = int(offset)
            if offset < 0:
                raise ValueError("offset can not be a negative integer.")
            query = query.offset(offset)
        return query

    def apply_limit(self, query, limit):
        """Applies a limit to the query if appropriate.

        :param query: Any desired filters must already have been
            applied.
        :type query: :class:`~sqlalchemy.orm.query.Query`
        :param limit: Integer used to limit the number of results
            returned.
        :type limit: int or None
        :raise ValueError: If a non ``None`` limit is provided that
            converts to a negative integer, or is a string that
            ``int()`` can't parse.
        :raise TypeError: If a non ``None`` limit is provided
            of a type ``int()`` can't convert.
        :return: A modified query object with a limit applied, or
            the original query if ``limit`` is ``None``.
        :rtype: :class:`~sqlalchemy.orm.query.Query`
        """
        if limit is not None:
            limit = int(limit)
            if limit < 0:
                raise ValueError("limit can not be a negative integer.")
            query = query.limit(limit)
        return query

    def _generate_filters(self, model_class, filters, whitelist=None,
                          nested_conditions=None, stack_size_limit=100,
                          convert_key_names_func=str, gettext=None):
        """Parse MQLAlchemy style filters without applying them.

        :param model_class: The model the filters target.
        :param filters: The MQLAlchemy style filters to parse.
        :type filters: dict or None
        :param whitelist: Used to determine what attributes are
            acceptable to be queried.
        :type whitelist: callable, list, set, or None
        :param nested_conditions: Callable accepting one param, or
            dict, where the key/param is a dot separated relationship
            name, and the return value is any required SQL expressions
            for filtering that relationship.
        :type nested_conditions: callable, dict, or None
        :param stack_size_limit: Used to limit the allowable complexity
            of the applied filters.
        :type stack_size_limit: int or None
        :param callable convert_key_names_func: Used to convert the
            attr names from user input (perhaps in camelCase) to the
            model format (likely in under_score format).
        :param gettext: Used to translate any errors.
        :type gettext: callable or None
        :raise InvalidMQLException: Raised in cases where invalid
            filters were supplied.
        :return: Whatever ``MqlBuilder().parse_mql_filters`` returns;
            presumably SQLAlchemy filter expression(s) rather than a
            query (TODO confirm against MQLAlchemy).
        """
        return MqlBuilder().parse_mql_filters(
            model_class,
            filters=filters,
            nested_conditions=nested_conditions,
            whitelist=whitelist,
            stack_size_limit=stack_size_limit,
            convert_key_names_func=convert_key_names_func,
            gettext=gettext
        )

    def apply_filters(self, model_class, query, filters, whitelist=None,
                      nested_conditions=None, stack_size_limit=100,
                      convert_key_names_func=str, gettext=None):
        """Apply filters to a query using MQLAlchemy.

        :param model_class: The model having filters applied to it.
        :param query: A SQLAlchemy session or query.
        :param filters: The MQLAlchemy style filters to apply.
        :type filters: dict or None
        :param whitelist: Used to determine what attributes are
            acceptable to be queried.
        :type whitelist: callable, list, set, or None
        :param nested_conditions: Callable accepting one param, or
            dict, where the key/param is a dot separated relationship
            name, and the return value is any required SQL expressions
            for filtering that relationship.
        :type nested_conditions: callable, dict, or None
        :param stack_size_limit: Used to limit the allowable complexity
            of the applied filters.
        :type stack_size_limit: int or None
        :param callable convert_key_names_func: Used to convert the
            attr names from user input (perhaps in camelCase) to the
            model format (likely in under_score format).
        :param gettext: Used to translate any errors.
        :type gettext: callable or None
        :raise InvalidMQLException: Raised in cases where invalid
            filters were supplied.
        :return: The query with filters applied, as returned by
            ``MqlBuilder().apply_mql_filters``.
        """
        return MqlBuilder().apply_mql_filters(
            model_class,
            query,
            filters=filters,
            nested_conditions=nested_conditions,
            whitelist=whitelist,
            stack_size_limit=stack_size_limit,
            convert_key_names_func=convert_key_names_func,
            gettext=gettext
        )
class ModelResourceQueryBuilder(QueryBuilder):
"""Class for building a SQLAlchemy query by using resources."""
class SubqueryNode(Loggable):
    """Used for building a query with subresource filters included.

    Each node describes one (sub)resource in a tree whose root has
    no parent. Not intended for external access.
    """

    def __init__(self, name, alias=None, parent=None, children=None,
                 subquery=None, limit=None, limit_source=None, offset=None,
                 sorts=None, convert_key_name=None, whitelist=None,
                 relationship_direction=None, joined=False, option=None,
                 join=None, filters=None, id_keys=None, strategy=None,
                 unaliased_filters=None):
        """Record the state for a single node of the subquery tree.

        List valued parameters (``children``, ``join``, ``filters``,
        ``id_keys``, ``unaliased_filters``) are replaced with a new
        empty list when a falsy value is given, so nodes never share
        list instances.

        :param str name: Name of the subresource.
        :param alias: Alias for this node's entity; presumably an
            :class:`~sqlalchemy.orm.util.AliasedClass` (TODO confirm).
        :param parent: The parent of this SubqueryNode. Always has
            a value, unless this is the root node.
        :type parent: SubqueryNode or None
        :param children: Holds any child subresource nodes.
        :type children: list of SubqueryNode or None
        :param subquery: An aliased SQLAlchemy select statement.
        :param limit: Limit the number of results for this
            subresource.
        :type limit: int or None
        :param str limit_source: Whether the limit was provided by
            the ``"user"`` or was a ``"default"``.
        :param offset: Offset to apply to the query for this node.
        :type offset: int or None
        :param sorts: Any sorts to apply to the query for this node.
        :type sorts: list of :class:`~drowsy.parser.SortInfo`
        :param convert_key_name: Function for converting from
            user facing field names to internal field names.
        :type convert_key_name: callable or None
        :param whitelist: What fields are acceptable to be queried
            for this model.
        :type whitelist: callable or None
        :param relationship_direction: Presumably one of SQLAlchemy's
            relationship directions such as ``ONETOMANY`` or
            ``MANYTOMANY`` (TODO confirm).
        :param bool joined: ``True`` if the subquery has been
            joined into our final query.
        :param option: Presumably a SQLAlchemy loader option for this
            node (TODO confirm).
        :param join: Join conditions for this node.
        :type join: list or None
        :param filters: Filters for this node.
        :type filters: list or None
        :param id_keys: Id key names for this node's resource.
        :type id_keys: list or None
        :param strategy: Loading strategy for this node; semantics
            are not visible here (TODO confirm).
        :param unaliased_filters: Filters for this node that have not
            been aliased.
        :type unaliased_filters: list or None
        """
        # Tree placement.
        self.name = name
        self.parent = parent
        self.children = children or []
        # SQL building blocks.
        self.alias = alias
        self.subquery = subquery
        self.join = join or []
        self.filters = filters or []
        self.unaliased_filters = unaliased_filters or []
        self.id_keys = id_keys or []
        # Pagination and ordering.
        self.limit = limit
        self.limit_source = limit_source
        self.offset = offset
        self.sorts = sorts
        # Resource/relationship metadata and load configuration.
        self.convert_key_name = convert_key_name
        self.whitelist = whitelist
        self.relationship_direction = relationship_direction
        self.joined = joined
        self.option = option
        self.strategy = strategy
def build(self, query, resource, filters, subfilters, embeds=None,
          offset=None, limit=None, sorts=None, strict=True,
          stack_size_limit=100, dialect_override=None):
    """Build a complex query using user supplied parameters.

    :param query: A SQLAlchemy query.
    :param resource: Base resource containing the sub resources
        that are to be filtered.
    :type resource: :class:`~drowsy.resource.BaseModelResource`
    :param filters: The MQLAlchemy style filters to apply.
    :type filters: dict or None
    :param subfilters: Dictionary of filters to apply to embedded
        subresources, with the subresource dot separated name as the
        key.
    :type subfilters: dict or None
    :param embeds: List of subresources and fields to embed.
    :type embeds: list or None
    :param offset: Integer used to offset the query result.
    :type offset: int or None
    :param limit: Integer used to limit the number of results
        returned.
    :type limit: int or None
    :param sorts: A list of sorts to apply to this query.
    :type sorts: list of :class:`~drowsy.parser.SortInfo`
    :param bool strict: If ``True``, will raise an exception when
        bad parameters are passed. If ``False``, will quietly ignore
        any bad input and treat it as if none was provided.
    :param stack_size_limit: Used to limit the allowable complexity
        of the applied filters. Forwarded to
        :meth:`apply_subquery_loads` so subfilters honor it too.
    :type stack_size_limit: int or None
    :param bool|None dialect_override: ``True`` will build query
        with row_number support regardless of any db limitations.
        If ``False``, will avoid using row_number even if the
        database supports it. Mainly used for testing.
    :return: query with joins, load options, and subresource
        filters applied as appropriate.
    :raise BadRequestError: Uses the provided resource to raise
        an error when filters, subfilters, embeds, sorts, offset,
        or limit are unable to be successfully applied.
    :raise TypeError: If a sort that isn't a
        :class:`~drowsy.parser.SortInfo` is supplied (simple query
        path, i.e. no subfilters or embeds).
    :raise ValueError: Due to programmer error. Generally
        Only raised if one of the above parameters is
        of the wrong type.
    """
    # Apply top level filters. _handle_filter_errors converts any
    # MQL exception into an error made by the resource and raises it.
    try:
        query = self.apply_filters(
            resource.model,
            query,
            filters=filters,
            nested_conditions=resource.get_required_nested_filters,
            whitelist=resource.whitelist,
            stack_size_limit=stack_size_limit,
            convert_key_names_func=resource.convert_key_name,
            gettext=resource.context.get("gettext", None))
    except InvalidMqlException as exc:
        self._handle_filter_errors(
            resource=resource,
            exc=exc)
    query = resource.apply_required_filters(query)
    if subfilters or embeds:
        # More complex process: offset/limit/sorts are not applied here;
        # apply_subquery_loads takes care of them. Forward
        # stack_size_limit so subfilters use the caller's limit rather
        # than apply_subquery_loads' default.
        return self.apply_subquery_loads(
            query=query,
            resource=resource,
            subfilters=subfilters,
            embeds=embeds,
            offset=offset,
            limit=limit,
            sorts=sorts,
            strict=strict,
            stack_size_limit=stack_size_limit,
            dialect_override=dialect_override
        )
    # Simple query: apply sorts/offset/limit directly.
    if not sorts and offset is not None:
        # Default to ordering by the id keys so an offset pages through
        # a consistent order.
        sorts = [
            SortInfo(attr=resource.schema.fields.get(key).data_key or key)
            for key in resource.schema.id_keys
        ]
    for sort in sorts or []:
        if not isinstance(sort, SortInfo):
            raise TypeError("Each sort must be of type SortInfo.")
        try:
            query = self.apply_sorts(
                query, [sort], resource.convert_key_name)
        except AttributeError:
            # Non-strict mode silently skips an invalid sort.
            if strict:
                raise resource.make_error(
                    "invalid_sort_field", field=sort.attr)
    try:
        query = self.apply_offset(query, offset)
    except ValueError:
        if strict:
            raise resource.make_error(
                "invalid_offset_value", offset=offset)
    try:
        query = self.apply_limit(query, limit)
    except ValueError:
        if strict:
            raise resource.make_error(
                "invalid_limit_value", limit=limit)
    return query
def _get_many_to_many_join(self, child, child_model, parent, parent_model,
                           relationship, assoc_queryable):
    """Build join conditions linking a parent to a MANYTOMANY child.

    Walks each clause of the relationship's primaryjoin (the parent
    to association table condition), pairs the parent column with the
    association table column, and re-targets each side at ``parent``
    / ``child`` when those are aliases or subqueries.

    :param child: The child being joined. When it is an
        :class:`~sqlalchemy.orm.util.AliasedClass` or
        :class:`~sqlalchemy.sql.selectable.Subquery`, the association
        column is looked up on its selectable's ``.c`` collection.
    :param child_model: The child's model class; used to detect a
        column name shared by the child model and the association
        table.
    :param parent: The parent entity. Aliased/subquery parents have
        their column looked up on the selectable's ``.c`` collection.
    :param parent_model: Unused; kept for interface compatibility.
    :param relationship: The relationship attribute (its ``.prop`` is
        read).
    :param assoc_queryable: The association table.
    :return: A list of ``child_expr == parent_expr`` conditions, one
        per primaryjoin clause.
    :rtype: list
    """
    parent_expressions = []
    join = []
    primaryjoin = relationship.prop.primaryjoin
    if isinstance(primaryjoin, BinaryExpression):
        parent_expressions.append(primaryjoin)
    elif isinstance(primaryjoin, BooleanClauseList):
        parent_expressions = primaryjoin.clauses
    for expression in parent_expressions:
        # The side belonging to the association table is the "child"
        # side; the other is the parent's column.
        if assoc_queryable == expression.right.table:
            parent_expr = expression.left
            child_expr = expression.right
        else:
            parent_expr = expression.right
            child_expr = expression.left
        parent_col_name = parent_expr.name
        child_col_name = child_expr.name
        if isinstance(parent, (AliasedClass, Subquery)):
            if hasattr(parent, "c"):
                parent_selectable = parent
            else:
                parent_selectable = inspect(parent).selectable
            parent_expr = getattr(parent_selectable.c, parent_col_name)
        if isinstance(child, (AliasedClass, Subquery)):
            # TODO - maybe loop through columns and identify one(s)
            # not originally part of the child (e.g. come from assoc)
            if hasattr(child, "c"):
                child_selectable = child
            else:
                child_selectable = inspect(child).selectable
            # If the child model and assoc table share this column
            # name, the child selectable presumably exposes the assoc
            # column under a "_1" de-duplicated label (TODO confirm).
            child_model_cols = select(child_model).selected_columns
            if hasattr(child_model_cols, child_col_name) and hasattr(
                    assoc_queryable.c, child_col_name):
                child_col_name = child_col_name + "_1"
            # Look the column up on the resolved selectable so aliased
            # children without a ``.c`` attribute work too.
            child_expr = getattr(child_selectable.c, child_col_name)
        join.append(child_expr == parent_expr)
    return join
def _get_partition_by_info(self, child, parent, relationship_name):
    """Get the partition_by needed for row_number in a subquery.

    Also returns the queryable and join condition to use for a
    MANYTOMANY relationship using an association table.

    :param child: The child entity being subqueried.
    :type child: :class:`~sqlalchemy.orm.util.AliasedClass`
    :param parent: The parent of the child entity.
    :type parent: :class:`~sqlalchemy.orm.util.AliasedClass` or
        SQLAlchemy model class.
    :param str relationship_name: Field name of the relationship.
    :raises ValueError: If the ``child``, ``parent``, and
        ``relationship_name`` can not be used to produce
        a valid result.
    :return: The partition_by, queryable, and join to use
        for a subquery join and row_number to limit
        subquery results. For MANYTOMANY relationships all three
        are populated. For ONETOMANY relationships only
        ``partition_by`` is populated (``queryable`` and ``join``
        are ``None``). For any other relationship direction all
        three are ``None``.
    :rtype: tuple
    """
    relationship = getattr(parent, relationship_name)
    # Returned unchanged for directions other than MANYTOMANY and
    # ONETOMANY.
    partition_by, queryable, join = (None, None, None)
    if relationship.prop.direction == MANYTOMANY:
        # For relationship Node.children:
        # Given assoc table to child join:
        #   t_NodeToNode.ChildNodeId = Node.NodeId and
        #   t_NodeToNode.ChildCompositeId = Node.CompositeId
        # and parent to assoc table join:
        #   t_NodeToNode.NodeId = Node.NodeId and
        #   t_NodeToNode.CompositeId = Node.CompositeId
        # We want to extract:
        #   queryable: t_NodeToNode
        #   join: child.node_id == t_NodeToNode.ChildNodeId and
        #         child.composite_id == t_NodeToNode.ChildCompositeId
        #   partition_by: t_NodeToNode.NodeId,
        #                 t_NodeToNode.CompositeId
        # This ultimately allows us to use row_number to limit
        # a subresource.
        primary_expressions = []
        secondary_expressions = []
        joins = []
        # Break the primaryjoin and secondaryjoin down
        # into lists of expressions. A single condition is a
        # BinaryExpression; a compound condition is a
        # BooleanClauseList whose .clauses are the individual parts.
        if isinstance(relationship.prop.primaryjoin, BinaryExpression):
            primary_expressions.append(relationship.prop.primaryjoin)
        elif isinstance(relationship.prop.primaryjoin, BooleanClauseList):
            primary_expressions = relationship.prop.primaryjoin.clauses
        if isinstance(relationship.prop.secondaryjoin, BinaryExpression):
            secondary_expressions.append(relationship.prop.secondaryjoin)
        elif isinstance(relationship.prop.secondaryjoin, BooleanClauseList):
            secondary_expressions = relationship.prop.secondaryjoin.clauses
        # default until proven otherwise:
        # secondaryjoin links the assoc table to the child, and
        # primaryjoin links the parent to the assoc table.
        child_expressions = secondary_expressions
        parent_expressions = primary_expressions
        for expression in child_expressions:
            # Find the association table
            # Then figure out the join condition
            left_table = expression.left.table
            right_table = expression.right.table
            child_table = inspect(child).mapper.local_table
            if left_table == child_table:
                child_side = expression.left
                assoc_side = expression.right
                queryable = right_table
            elif right_table == child_table:
                child_side = expression.right
                assoc_side = expression.left
                queryable = left_table
            else:
                # Shouldn't ever get here...
                raise ValueError  # pragma: no cover
            # Find the mapped attribute whose column key matches the
            # child side of the expression, and build the join against
            # that attribute on ``child`` (which may be an alias).
            child_insp = inspect(inspect(child).class_)
            for column_key in child_insp.columns.keys():
                if child_insp.columns[column_key].key == child_side.key:
                    child_condition = getattr(child, column_key)
                    joins.append(assoc_side == child_condition)
        # Now get the partition by...
        partition_by = []
        # Each parent expression pairs a parent table column with an
        # assoc table column; the assoc side becomes a partition_by
        # column.
        for expression in parent_expressions:
            assoc_side = None
            if isinstance(expression, BinaryExpression):
                # find if left or right is the assoc table
                left_table = expression.left.table
                right_table = expression.right.table
                parent_table = inspect(parent).mapper.local_table
                if left_table == parent_table:
                    # parent_side = expression.left
                    assoc_side = expression.right
                elif right_table == parent_table:
                    # parent_side = expression.right
                    assoc_side = expression.left
            if assoc_side is None:
                # Either no assoc_side was found, or
                # this wasn't a BinaryExpression
                raise ValueError  # pragma: no cover
            partition_by.append(assoc_side)
        if not joins or queryable is None or not partition_by:
            # To reach this, one of the following conditions
            # must be met:
            #
            # 1. partition_by is empty because parent_expressions
            #    was empty
            #
            # 2. queryable is None because child_expressions was
            #    empty.
            #
            # 3. joins is empty because no child column matched
            #    the child_expressions.
            raise ValueError  # pragma: no cover
        if len(joins) == 1:
            join = joins[0]
        else:
            join = and_(*joins)
    elif relationship.prop.direction == ONETOMANY:
        # For relationship Album.tracks:
        # Given primaryjoin (parent table to child table):
        #   Album.AlbumId = Track.AlbumId
        # We want to extract:
        #   queryable: None
        #   join: None
        #   partition_by: track.album_id,
        # This ultimately allows us to use row_number to limit
        # a subresource.
        primary_expressions = []
        partition_by = []
        join = None
        queryable = None
        # Break the primaryjoin down into a list of expressions.
        if isinstance(relationship.prop.primaryjoin, BinaryExpression):
            primary_expressions.append(relationship.prop.primaryjoin)
        elif isinstance(relationship.prop.primaryjoin, BooleanClauseList):
            primary_expressions = relationship.prop.primaryjoin.clauses
        for expression in primary_expressions:
            # find which side of the expression is the child side
            remote_side = relationship.prop.remote_side
            left_table = expression.left.table
            right_table = expression.right.table
            child_table = inspect(child).mapper.local_table
            if left_table == child_table and right_table == child_table:
                # Self referential one to many...
                # Both sides are the same table, so use remote_side
                # (the columns on the child end of the relationship)
                # to decide which side is the child's column.
                if expression.right in remote_side:
                    child_side = expression.right
                else:
                    child_side = expression.left  # pragma: no cover
            elif left_table == child_table:
                child_side = expression.left
            elif right_table == child_table:
                child_side = expression.right
            else:
                # Shouldn't ever get here...
                raise ValueError  # pragma: no cover
            # Partition by the matching attribute on ``child`` (which
            # may be an alias) rather than the raw table column.
            child_insp = inspect(inspect(child).class_)
            for column_key in child_insp.columns.keys():
                if child_insp.columns[column_key].key == child_side.key:
                    partition_by.append(getattr(child, column_key))
        if not partition_by:
            # Also shouldn't ever get here...
            raise ValueError  # pragma: no cover
    return partition_by, queryable, join
def _handle_filter_errors(self, resource, exc, subfilter_key=None):
    """Helper method to handle raising MQL related errors.

    Maps the MQLAlchemy exception type onto a resource error key and
    raises the error built by ``resource.make_error``, chained to
    ``exc``. An ``exc`` that isn't an MQL exception is re-raised
    unchanged.

    :param resource: Root resource filters are being applied to.
    :type resource: :class:`~drowsy.resource.BaseModelResource`
    :param exc: The exception raised by MQLAlchemy.
    :type exc: :class:`~mqlalchemy.exc.InvalidMqlException`
    :param subfilter_key: If the filters were being applied to a
        subresource, provide the unconverted dot separated name
        of that subresource.
    :type subfilter_key: str or None
    :raise DrowsyError: For any MQL exception (whatever
        ``resource.make_error`` builds).
    """
    # Order matters: the more specific types are checked before the
    # general InvalidMqlException, matching the original precedence.
    if isinstance(exc, MqlFieldPermissionError):
        error_key = "filters_permission_error"
    elif isinstance(exc, MqlFieldError):
        if exc.op:
            error_key = "filters_field_op_error"
        else:
            error_key = "filters_field_error"
    elif isinstance(exc, MqlTooComplex):
        error_key = "filters_too_complex"
    elif isinstance(exc, InvalidMqlException):  # pragma: no cover
        error_key = "invalid_filters"
    else:
        # Not an MQL error; propagate it untouched.
        raise exc
    raise resource.make_error(
        error_key,
        exc=exc,
        subresource_key=subfilter_key) from exc
def _initiate_subquery(self, query, resource, offset, limit, sorts,
                       supported, strict=True):
    """Handles query limit/offset/sorts for different dialects.

    To apply a limit or offset to a query that intends to load
    nested results, we must either use row_number, or a multi query
    process to first grab results matching the parent table,
    and then a separate query to join to those results.

    :param query: A SQLAlchemy query.
    :param resource: Base resource containing the sub resources
        that are to be filtered.
    :type resource: :class:`~drowsy.resource.BaseModelResource`
    :param offset: Integer used to offset the query result.
    :type offset: int or None
    :param limit: Integer used to limit the number of results
        returned. ``0`` yields no results.
    :type limit: int or None
    :param sorts: A list of sorts to apply to this query. All valid
        sorts are applied in order, followed by any id keys not
        already sorted on.
    :type sorts: list of :class:`~drowsy.parser.SortInfo`
    :param bool supported: ``True`` if ``row_number`` is supported
        by the SQL engine being used.
    :param bool strict: If ``True``, will raise an exception when
        bad parameters are passed. If ``False``, will quietly ignore
        any bad input and treat it as if none was provided.
    :raise BadRequestError: In strict mode, the resource's
        ``invalid_sort_field``, ``invalid_limit_value`` or
        ``invalid_offset_value`` error.
    :return: A query that has had the proper pagination logic
        applied before any subresources are embedded or filtered.
    :rtype: :class:`~sqlalchemy.orm.query.Query`
    """
    record_class = resource.model
    schema = resource.make_schema()
    id_keys = schema.id_keys
    order_bys = []
    # Raw and converted names of successfully applied sorts, used to
    # avoid ordering by an id key twice.
    sorted_attrs = set()
    for sort in sorts or []:
        try:
            # extend (not assign) so every valid sort is kept.
            order_bys.extend(self._get_order_bys(
                record_class, [sort], resource.convert_key_name))
        except AttributeError:
            if strict:
                raise resource.make_error(
                    "invalid_sort_field", field=sort.attr)
        else:
            sorted_attrs.add(sort.attr)
            sorted_attrs.add(resource.convert_key_name(sort.attr))
    # always sort by id_keys to ensure deterministic order
    for attr_name in id_keys:
        if attr_name not in sorted_attrs:
            order_bys.append(getattr(record_class, attr_name).asc())
    if limit is not None and limit < 0:
        if strict:
            raise resource.make_error("invalid_limit_value", limit=limit)
        limit = None
    if offset is not None and offset < 0:
        if strict:
            raise resource.make_error("invalid_offset_value",
                                      offset=offset)
        offset = None
    # limit=0 is a real limit (no rows), so test it against None.
    if limit is not None or offset:
        if supported:
            # Use row_number to figure out which rows to pull
            row_number = func.row_number().over(
                order_by=order_bys
            ).label("row_number")
            query = query.add_columns(row_number)
            # row_number is 1-based; translate offset/limit into an
            # inclusive [start, end] window.
            start = 1 if offset is None else offset + 1
            end = None if limit is None else start + limit - 1
            # Remove row_number from select list
            # NOTE - Not sure if there's a better way to do this...
            entities = []
            for col in query.column_descriptions:
                if (col["name"] != "row_number"
                        and col["entity"] not in entities):
                    entities.append(col["entity"])
            # Query from self, allowing us to filter by row_number
            # Only include non row_number expressions in SELECT
            sq = query.subquery()
            query = select(aliased(entities[0], sq))
            query = query.where(sq.c.row_number >= start)
            # end may legitimately be 0 or less than start (limit=0).
            if end is not None:
                query = query.where(sq.c.row_number <= end)
            order_bys = [sq.c.row_number]
        else:
            # Unable to use row_number, so unfortunately we have to
            # run an actual query with limit/offset/order applied,
            # and use those results to build our new query.
            # Super inefficient.
            temp_query = query
            for order_by in order_bys:
                temp_query = temp_query.order_by(order_by)
            if limit is not None:
                temp_query = self.apply_limit(temp_query, limit)
            if offset:
                temp_query = self.apply_offset(temp_query, offset)
            # TODO - maybe deprecate this whole section...
            session = resource.session
            results = session.execute(temp_query).scalars().all()
            if not results:
                # The requested page is empty (e.g. offset past the
                # end, or limit=0); match nothing instead of leaving
                # the query unfiltered and returning every row.
                from sqlalchemy import false
                query = query.where(false())
            elif len(id_keys) > 1:
                filters = [
                    and_(*[
                        getattr(record_class, id_key) ==
                        getattr(result, id_key)
                        for id_key in id_keys
                    ])
                    for result in results
                ]
                query = query.where(or_(*filters))
            else:
                # in condition
                id_key = id_keys[0]
                values = [getattr(r, id_key) for r in results]
                query = query.where(
                    getattr(record_class, id_key).in_(values))
    for order_by in order_bys:
        query = query.order_by(order_by)
    return query
def apply_subquery_loads(self, query, resource, subfilters, embeds=None,
                         offset=None, limit=None, sorts=None,
                         strict=True, stack_size_limit=100,
                         dialect_override=None):
    """Apply joins, load options, and subfilters to a query.

    Works in three phases:

    1. The root query is passed through ``_initiate_subquery`` so the
       top level ``offset``, ``limit`` and ``sorts`` are applied.
    2. A tree of ``SubqueryNode`` objects is built, one node per
       embedded or subfiltered relationship. For the final segment of
       each subfilter key, a subquery with the subresource's required
       filters and the user supplied filters is built. When a limit
       and/or offset applies and the dialect supports it,
       ``row_number`` is used to paginate the subresource per parent.
    3. The tree is walked breadth first and each node is either outer
       joined onto the query and loaded with ``contains_eager``, or
       loaded with ``subqueryload``, as decided by
       ``_decide_strategy``.

    :param query: A SQLAlchemy query.
    :param resource: Base resource containing the sub resources
        that are to be filtered.
    :type resource: :class:`~drowsy.resource.BaseModelResource`
    :param subfilters: Dictionary of filters to apply to embedded
        subresources, with the subresource dot separated name as the
        key and a :class:`~drowsy.parser.SubfilterInfo` as the value.
    :type subfilters: dict or None
    :param embeds: List of subresources and fields to embed.
    :type embeds: list or None
    :param offset: Integer used to offset the query result.
    :type offset: int or None
    :param limit: Integer used to limit the number of results
        returned.
    :type limit: int or None
    :param sorts: A list of sorts to apply to this query.
    :type sorts: list of :class:`~drowsy.parser.SortInfo`
    :param bool strict: If ``True``, will raise an exception when
        bad parameters are passed. If ``False``, will quietly ignore
        any bad input and treat it as if none was provided.
    :param stack_size_limit: Used to limit the allowable complexity
        of the applied filters.
    :type stack_size_limit: int or None
    :param bool|None dialect_override: ``True`` will build query
        with row_number support regardless of any db limitations.
        If ``False``, will avoid using row_number even if the
        database supports it. Mainly used for testing.
    :return: query with joins, load options, and subresource
        filters applied as appropriate.
    :rtype: :class:`~sqlalchemy.orm.query.Query`
    :raise BadRequestError: Uses the provided resource to raise
        an error when subfilters or embeds are unable to be
        successfully applied.
    :raise ValueError: Due to programmer error. Currently only
        raised when a value in ``subfilters`` is not a
        :class:`~drowsy.parser.SubfilterInfo`.
    """
    # NOTE: This is heavily dependent on ModelResource, which isn't
    # a great indication of good design. Should figure out how to
    # better separate these concerns.
    embeds = embeds or []
    subfilters = subfilters or {}
    # Whether row_number may be used to paginate subresources.
    dialect_supported = self.row_number_supported(
        dialect=resource.session.bind.dialect.name,
        dialect_override=dialect_override)
    # Apply the root level offset/limit/sorts first.
    query = self._initiate_subquery(
        query=query,
        resource=resource,
        offset=offset,
        limit=limit,
        sorts=sorts,
        supported=dialect_supported,
        strict=strict)
    # The root node wraps the first entity selected by the query.
    root = self.SubqueryNode(
        alias=query.column_descriptions[0]["entity"],
        name="$root"
    )
    # Number of nodes created per model class. Used to give every
    # alias a unique name, and passed to _decide_strategy.
    model_count = defaultdict(int)
    root_resource = resource
    # An embed that is also a subfilter key is handled as a subfilter.
    embeds = [e for e in embeds if e not in subfilters.keys()]
    # Embeds are processed before subfilters.
    subfilter_keys = embeds + list(subfilters.keys())
    # Unaliased relationship attributes already added to the tree; the
    # same relationship may not be loaded through two different paths.
    subquery_tracker = {}
    for subfilter_key in subfilter_keys:
        # Every key is resolved starting from the root resource, one
        # dot separated segment at a time.
        resource = root_resource
        schema = resource.make_schema()
        split_subfilter_keys = subfilter_key.split(".")
        last_node = root
        subfilter_info = subfilters.get(subfilter_key)
        is_embed = subfilter_key not in subfilters.keys()
        if is_embed:
            # Plain embeds carry no user supplied options.
            user_supplied_offset = None
            user_supplied_limit = None
            user_supplied_filters = None
            user_supplied_sorts = None
        elif isinstance(subfilter_info, SubfilterInfo):
            user_supplied_offset = subfilter_info.offset
            user_supplied_limit = subfilter_info.limit
            user_supplied_filters = subfilter_info.filters
            user_supplied_sorts = subfilter_info.sorts
        else:
            raise ValueError(
                "Each subfilter in subfilters must be a SubfilterInfo.")
        # Set to True (only when not strict) to abandon this key
        # without raising.
        failed = False
        while split_subfilter_keys and not failed:
            failed = False
            split_key = split_subfilter_keys.pop(0)
            # Resolve the schema field for this path segment.
            field = get_field_by_data_key(
                schema=schema,
                data_key=split_key)
            if isinstance(field, NestedRelated):
                # Descend into the nested subresource.
                resource = resource.make_subresource(name=split_key)
                schema = resource.make_schema()
                for node in last_node.children:
                    if node.name == field.name:
                        # node already exists in last_node children
                        last_node = node
                        # no need to update it, skip the below else
                        break
                else:
                    # for-else: runs only when no existing child matched.
                    # Need to add this node to last_node children
                    resource_model = schema.opts.model
                    model_count[resource_model] += 1
                    # NOTE - embedded_page_max_size for resources?
                    # No real clean way to limit embedded count
                    # Without row_number support no limit can be
                    # enforced, so there is no default limit.
                    if dialect_supported:
                        default_limit = resource.page_max_size
                    else:
                        default_limit = None
                    unaliased_relationship = getattr(
                        inspect(last_node.alias).class_,
                        field.name)
                    if unaliased_relationship in subquery_tracker:
                        raise root_resource.make_error(
                            "invalid_subresource_multi_embed",
                            subresource_key=subfilter_key)
                    subquery_tracker[unaliased_relationship] = True
                    relationship = getattr(last_node.alias, field.name)
                    relationship_direction = relationship.prop.direction
                    # Each node gets its own alias named after the model
                    # plus a running count, e.g. "Album1", "Album2".
                    # NOTE(review): offset, sorts (and a user limit,
                    # below) of the current subfilter are also set on
                    # intermediate nodes created along a dotted path,
                    # not only on the final one - confirm intended.
                    new_node = self.SubqueryNode(
                        parent=last_node,
                        name=field.name,
                        alias=aliased(
                            resource_model,
                            name=(
                                resource_model.__name__ +
                                str(model_count[resource_model])
                            )
                        ),
                        limit=default_limit,
                        limit_source="default",
                        offset=user_supplied_offset,
                        sorts=user_supplied_sorts,
                        children=[],
                        convert_key_name=resource.convert_key_name,
                        whitelist=resource.whitelist,
                        relationship_direction=relationship_direction,
                        id_keys=resource.schema.id_keys
                    )
                    # This takes care of embedding when is_embed
                    new_node.subquery = resource.apply_required_filters(
                        query=select(new_node.alias),
                        alias=new_node.alias
                    ).subquery(inspect(new_node.alias).name)
                    if default_limit is not None and (
                            user_supplied_limit is not None
                            and
                            user_supplied_limit > default_limit):
                        if strict:
                            raise root_resource.make_error(
                                "invalid_subresource_limit",
                                supplied_limit=user_supplied_limit,
                                max_limit=resource.page_max_size,
                                subresource_key=subfilter_key)
                        # Not strict: keep the default limit that was
                        # already set on new_node above.
                        user_supplied_limit = default_limit
                    elif user_supplied_limit is not None:
                        new_node.limit = user_supplied_limit
                        new_node.limit_source = "user"
                    last_node.children.append(new_node)
                    last_node = new_node
                # Final segment of a subfilter key: build the filtered
                # (and possibly paginated) subquery for this node.
                if not split_subfilter_keys and not is_embed:
                    # Default subquery likely to be overridden below
                    # This will get used in situations where
                    # no limit or offset is provided, since
                    # row_number won't be needed for pagination.
                    subquery = select(last_node.alias)
                    subquery = resource.apply_required_filters(
                        subquery,
                        alias=last_node.alias)
                    # Sort should only be provided with offset/limit
                    if strict and last_node.sorts is not None and (
                            last_node.offset is None and
                            last_node.limit is None):
                        raise root_resource.make_error(
                            "invalid_subresource_sorts",
                            subresource_key=subfilter_key)
                    # Start figuring out if we need row_number
                    # partition_by is None for relationships that can't
                    # be partitioned (see the MANYTOONE note below);
                    # queryable appears to be the association selectable
                    # for many-to-many relationships (it's passed as
                    # assoc_queryable below) - TODO confirm.
                    try:
                        partition_by, queryable, join_condition = (
                            self._get_partition_by_info(
                                last_node.alias,
                                last_node.parent.alias,
                                last_node.name
                            )
                        )
                    except ValueError:  # pragma: no cover
                        # Currently have no way to test this.
                        # Open to suggestions...
                        if not strict:
                            failed = True
                            continue
                        raise root_resource.make_error(
                            "invalid_subresource",
                            subresource_key=subfilter_key
                        )
                    # Check the status of the above partition_by
                    if partition_by is None and (
                            last_node.limit is not None or
                            last_node.offset is not None):
                        # This is a MANYTOONE and user supplied
                        # limit or offset. Fail if strict.
                        if strict:
                            raise root_resource.make_error(
                                "invalid_subresource_options",
                                subresource_key=subfilter_key
                            )
                        # else will continue without any row_number
                    # Valid partition by generated,
                    # time to use row_number
                    if partition_by is not None and (
                            last_node.limit is not None or
                            # 1 is not None or  # testing
                            last_node.offset is not None):
                        if not dialect_supported and strict:
                            # dialect doesn't support limit/offset
                            # fail accordingly
                            raise root_resource.make_error(
                                "invalid_subresource_options",
                                subresource_key=subfilter_key)
                        # if not strict, the below is skipped
                        # default subquery gets used
                        elif dialect_supported:
                            # NOTE - We're building up to using
                            # row_number for limiting/offsetting
                            # the subresource.
                            # Order by to be used in row_number
                            order_by = []
                            if last_node.sorts:
                                # Use sorts from user if provided
                                order_by = self._get_order_bys(
                                    last_node.alias,
                                    last_node.sorts,
                                    resource.convert_key_name
                                )
                            # always sort by id_keys for determinism
                            attr_names = schema.id_keys
                            for attr_name in attr_names:
                                already_sorted = False
                                if last_node.sorts:
                                    for sort in last_node.sorts:
                                        if sort.attr == attr_name:
                                            already_sorted = True
                                            break
                                if not already_sorted:
                                    order_by.append(
                                        getattr(
                                            last_node.alias,
                                            attr_name).asc()
                                    )
                            # Rows are numbered separately within each
                            # partition_by group.
                            row_number_expr = func.row_number().over(
                                partition_by=partition_by,
                                order_by=order_by
                            ).label("row_number")
                            # Build the complete query in one go
                            if queryable is None:
                                q1 = select(
                                    last_node.alias,
                                    row_number_expr
                                )
                            else:
                                q1 = select(
                                    last_node.alias,
                                    queryable,
                                    row_number_expr
                                ).join(queryable, join_condition)
                            nested_conditions = (
                                resource.get_required_nested_filters)
                            q1 = resource.apply_required_filters(
                                q1, alias=last_node.alias)
                            try:
                                q1 = self.apply_filters(
                                    query=q1,
                                    model_class=last_node.alias,
                                    whitelist=resource.whitelist,
                                    nested_conditions=nested_conditions,
                                    filters=user_supplied_filters,
                                    convert_key_names_func=(
                                        resource.convert_key_name),
                                    stack_size_limit=stack_size_limit
                                )
                            except InvalidMqlException as exc:
                                if strict:
                                    # Otherwise bad filters ignored
                                    self._handle_filter_errors(
                                        resource=resource,
                                        exc=exc,
                                        subfilter_key=subfilter_key)
                            q1 = q1.subquery("q1")
                            # limit and offset handling
                            # Keep rows whose row_number falls within
                            # [start, end] (1-based, inclusive).
                            start = 1
                            if last_node.offset is not None:
                                start = last_node.offset + 1
                            end = None
                            if last_node.limit is not None:
                                end = start + last_node.limit - 1
                            subquery = select(q1).where(
                                q1.c.row_number >= start)
                            if end is not None:
                                subquery = subquery.where(
                                    q1.c.row_number <= end
                                )
                            last_node.subquery = subquery.subquery(
                                inspect(last_node.alias).name)
                            child = last_node.subquery
                            parent = last_node.parent.subquery
                            if parent is None:
                                # root node won't have a subquery
                                parent = last_node.parent.alias
                            # Many-to-many: store an explicit join
                            # condition to be used when this node is
                            # outer joined onto the main query.
                            if queryable is not None:
                                relationship = getattr(
                                    last_node.parent.alias,
                                    last_node.name)
                                last_node.join = (
                                    self._get_many_to_many_join(
                                        child=child,
                                        child_model=last_node.alias,
                                        parent=parent,
                                        parent_model=last_node.parent.alias,
                                        relationship=relationship,
                                        assoc_queryable=queryable
                                    ))
                    else:
                        # No pagination needed: filter the default
                        # subquery directly. Aliased filters are used for
                        # the joined subquery; unaliased filters are kept
                        # for subqueryload's relationship.and_() below.
                        try:
                            req_filters = resource.get_required_filters(
                                alias=last_node.alias)
                            new_filters = self._generate_filters(
                                model_class=last_node.alias,
                                filters=user_supplied_filters,
                                nested_conditions=(
                                    resource.get_required_nested_filters),
                                whitelist=resource.whitelist,
                                stack_size_limit=stack_size_limit,
                                convert_key_names_func=(
                                    resource.convert_key_name)
                            )
                            # hacky....need a smarter way to do this
                            clean_req_filters = (
                                resource.get_required_filters())
                            clean_new_filters = self._generate_filters(
                                model_class=inspect(
                                    last_node.alias).class_,
                                filters=user_supplied_filters,
                                nested_conditions=(
                                    resource.get_required_nested_filters),
                                whitelist=resource.whitelist,
                                stack_size_limit=stack_size_limit,
                                convert_key_names_func=(
                                    resource.convert_key_name)
                            )
                            last_node.filters = manipulate_filters_to_list(
                                req_filters) + manipulate_filters_to_list(
                                new_filters)
                            last_node.unaliased_filters = (
                                manipulate_filters_to_list(
                                    clean_req_filters) +
                                manipulate_filters_to_list(
                                    clean_new_filters))
                            if last_node.filters:
                                subquery = subquery.where(
                                    *last_node.filters)
                            last_node.subquery = subquery.subquery(
                                inspect(last_node.alias).name)
                        except InvalidMqlException as exc:
                            if not strict:
                                failed = True
                                continue
                            self._handle_filter_errors(
                                resource=resource,
                                exc=exc,
                                subfilter_key=subfilter_key)
            else:
                # The segment is not a nested relationship field.
                if is_embed:
                    if not split_subfilter_keys:
                        # this is an embed ending in an attribute
                        # we're fine to continue
                        continue
                    else:
                        if not strict:
                            failed = True
                            continue
                        # strict mode
                        # invalid embed, fail accordingly.
                        raise root_resource.make_error(
                            "invalid_embed",
                            embed=subfilter_key
                        )
                else:
                    if not strict:
                        failed = True
                        continue
                    # subresource isn't valid, fail
                    raise root_resource.make_error(
                        "invalid_subresource",
                        subresource_key=subfilter_key)
    # build options and joins
    subfilter_options = []
    # Breadth first ordering guarantees a parent's option (and join)
    # is built before any of its children, which chain off of it.
    nodes = list()
    node_queue = [root]
    while node_queue:
        node = node_queue.pop(0)
        nodes.append(node)
        for child in node.children:
            node_queue.append(child)
    for node in nodes:
        if node.name == "$root":
            continue
        strategy = self._decide_strategy(node, model_count)
        relationship = getattr(inspect(node.parent.alias).class_,
                               node.name)
        if strategy == "join":
            # Join against the parent's subquery when it has one,
            # otherwise against the parent's alias.
            left = node.parent.subquery
            if left is None:
                left = node.parent.alias
            join_info = relationship.prop._create_joins(
                source_selectable=inspect(left).selectable,
                dest_selectable=inspect(node.subquery).selectable,
                source_polymorphic=True,
                of_type_entity=inspect(node.alias))
            primaryjoin = join_info[0]
            secondaryjoin = join_info[1]
            secondary = join_info[4]
            if secondary is not None:
                if node.join:
                    # Explicit many-to-many join built with row_number
                    query = query.outerjoin(
                        node.subquery, and_(*node.join))
                else:
                    # Default joins used when no limit/offset used for
                    # subquery
                    query = query.outerjoin(secondary, primaryjoin)
                    query = query.outerjoin(node.subquery, secondaryjoin)
            else:
                query = query.outerjoin(node.subquery, primaryjoin)
            # Populate the relationship from the joined subquery.
            entity_relation = getattr(node.parent.alias, node.name)
            if node.alias is not None:
                entity_relation = entity_relation.of_type(node.alias)
            if node.parent and node.parent.option:
                node.option = node.parent.option.contains_eager(
                    entity_relation,
                    alias=node.subquery)
            else:
                node.option = contains_eager(
                    entity_relation,
                    alias=node.subquery)
        elif strategy == "subqueryload":
            # Subquery loads are never applied when there's a
            # duplicate model type in the load chain, so aliasing
            # isn't necessary.
            if node.parent and node.parent.name == "$root":
                relationship = getattr(node.parent.alias,
                                       node.name)
            if node.parent and node.parent.option:
                node.option = node.parent.option.subqueryload(
                    relationship.and_(*node.unaliased_filters))
            else:
                node.option = subqueryload(
                    relationship.and_(*node.unaliased_filters))
        # Each child's option is chained from its parent's, so only
        # leaf options need to be handed to query.options().
        if not node.children:
            subfilter_options.append(node.option)
    if subfilter_options:
        query = query.options(*subfilter_options)
    return query
def _decide_strategy(self, node, model_count):
    """For a given node, decide the load strategy.

    A node is loaded with a join (outer join + ``contains_eager``)
    when any of the following hold:

    * the node itself has a limit and/or offset,
    * its parent is loaded with a join,
    * its model is the same as the root model (unaliased self
      referential relationships can't safely use subqueryload), or
    * ANY descendant has a limit and/or offset, uses the root model,
      or uses a model that is referenced more than once in the load
      (``model_count`` greater than one).

    Otherwise ``"subqueryload"`` is used.

    The goal here is to avoid exponential row increase on joins
    when possible, which joins+contains_eager can cause, while
    falling back to the safe join strategy whenever a descendant
    needs it. There are probably better ways to optimize this, and
    ideally we should let the user define which relationships they
    want loaded in what way ahead of time, but for now this should
    serve as relatively sane defaults.

    :param SubqueryNode node: Corresponds to a nested resource
        relationship.
    :param dict model_count: A dictionary where the key is a
        model class, and the value is the number of times that
        model is referenced in our data load. Missing keys are
        treated as zero.
    :return: The load strategy to use, either ``"join"`` or
        ``"subqueryload"``. Also sets the node's ``strategy`` attr.
    :rtype: str
    """
    # Find the root model type.
    # This is needed to prevent unaliased self referential relations
    # from using subqueryload/selectinload...only an issue when the
    # root model is the same as a child model.
    root = node
    while root.parent:
        root = root.parent
    unaliased_root = inspect(root.alias).class_
    unaliased_node = inspect(node.alias).class_
    if node.limit or node.offset or (
            node.parent and node.parent.strategy == "join") or (
            unaliased_node == unaliased_root):
        node.strategy = "join"
        return node.strategy
    # Every descendant must be checked, not just the last one visited:
    # a joined descendant is outer joined against its ancestors'
    # selectables, so a single such descendant anywhere below this
    # node forces this node to be joined too.
    pending = list(node.children)
    while pending:
        child = pending.pop()
        unaliased_child = inspect(child.alias).class_
        if (child.limit or child.offset or
                unaliased_child == unaliased_root or
                model_count.get(unaliased_child, 0) > 1):
            # stuck using contains eager to be safe...
            # needs refinement later, but this will get the job done
            node.strategy = "join"
            return node.strategy
        pending.extend(child.children)
    # NOTE - If selectinload becomes the default (TODO), descendants
    # with composite keys will still require subqueryload, since
    # selectinload can't handle composite keys and SQLAlchemy doesn't
    # allow a subqueryload after a selectinload.
    node.strategy = "subqueryload"
    return node.strategy