Source code for disbi.join

Contains the class Relations that stores relations between models
and joins them together appropriately.
# standard library
from collections import deque
from itertools import product

# Django
from django.apps import apps
from django.conf import settings
from django.db import connection, models

from disbi.db_utils import (exec_query, get_field_query_name,
                            get_fk_query_name, get_m2m_field,
from disbi.models import MetaModel
from disbi.option_utils import get_models_of_superclass

[docs]class Relations(): """ Stores relations of models and has the ability to join them. """ # pylint: disable=too-many-instance-attributes # Eight is reasonable in this case. def __init__(self, app_label, model_superclass=None): """ Initialize Relations. Args: app_label (str): The label of the app the models live in. model_superclass (iterable of type): The classes from which all models of interest derive. Defaults to None. """ self.app_label = app_label self.relation_map = self._construct_relation_map(model_superclass) self.sql = '' self.m2m_count = 0 def _construct_relation_map(self, model_superclass): """ Store the relations bewteen models in a 2-D hashmap. Args: model_superclass (iterable of type): The classes from which all models of interest derive. Returns: dict: A 2-D hashmap where either the ForeignKey field, the intermediary model or False is stored for each two models. Depending on whether their relation is 1:N, N:M or they have no relation. """ if model_superclass is not None: self.models = get_models_of_superclass(self.app_label, model_superclass) else: self.models = apps.get_app_config(self.app_label).get_models() # Use dict as hashmap for adjacency matrix. relation_map = dict((t, None) for t in product(self.models, self.models)) # Set model pairs that have no realtion to False. for pair in relation_map.keys(): model_a, model_b = pair # If the pair is a mapping to itself, set entry to False. if model_a == model_b: relation_map[pair] = False continue # Introspect relation between the two models. model_a, model_b = pair # Get concrete relationship fields. relation_fields = [f for f in model_a._meta.get_fields() if f.is_relation and f.concrete] for field in relation_fields: if field.related_model == model_b: # Set the relation in both directions, so it does not # matter how it is specified in if field.many_to_many: relation_map[pair] = field.remote_field.through elif field.many_to_one: relation_map[pair] = field relation_map[tuple(reversed(pair))] = relation_map[pair] # If nothing has been found set relation to False. if relation_map[pair] is None: relation_map[pair] = False return relation_map #TODO: could be refactored with later depth first search. @property def is_tree(self): """ Determine whether the relation_map is a tree. Perform some setup and then call the depth first search, starting from an arbitrary node. Returns: bool: True if the graph is a tree. If it contains at least on cycle or is not connected, return False. """ self.nodes = self.models self.visited = set() start = self.nodes[0] cyclic = self._df_search(start, None) connected = (self.visited == set(self.nodes)) return (not cyclic and connected) def _df_search(self, current, parent): """ Depth first search for cycles in a relation map. Args: current: The node that is searched in this function call. parent: The node from which the current node was called as child. Returns: bool: True if the graph contains at least on cycle, else False. """ # Find all childs of the node. childs = set() for node in self.nodes: if (node is not parent and node is not current and self.relation_map[current, node]): childs.add(node) # Check if one of the childs was already visited, which would # denote a cycle. for child in childs: if child in self.visited: return True # Now the node has been checked and can be added to the # visited set. self.visited.add(current) # Search each child node recursively. for child in childs: if self._df_search(child, current): return True return False def _get_children(self, parent, group, visited): """ Get all children of a parent model in a specific group of models. Args: parent (Model): The model for which the childs are searched. Returns: set: All child models of the parent model. """ children = [] for model in group: if (model is not parent and model not in visited and self.relation_map[parent, model]): children.append(model) return children
[docs] def start_join(self): """ Start the join process between all models. """ # Confirm that DB backend is Postgres. if not connection.vendor == 'postgresql': raise ValueError('Database system {} not supported. Use PostgreSQL.'.format(connection.vendor)) # First check whether the graph formed by the relation_map # is actually a tree. if not self.is_tree: raise ValueError('The graph formed by the relationships of the ' 'models is no tree.') # list that stores all models in a linear order, # determining the order the columns are displayed in the end. self.linearized = [] # Get the model that was set as the root of the relation tree. for model in self.models: if getattr(model, 'di_first', False): first = model break self.linearized.append(first) # Construct the from statement with the root model. self._construct_SQL(first) # Start the JOIN. self._join(first) # Add the select clause. self.sql = self._construct_SELECT() + self.sql
def _join(self, model): """ Depth first, ForeignKey first traversal trough the relation tree. Construct the appropriate SQL JOIN statement for each child of the model. Then call the function directly for each child. Args: model (Model): The current model in the traversal. """ childs = self._get_prioritized_childs(model) for child in childs: self._construct_SQL(model, child) self.linearized.append(child) self._join(child) def _get_prioritized_childs(self, parent): """ Get all childs of a model. Prioritize ForeignKeys over ManyToMany relations. Args: parent (Model): The model for which the childs are searched and prioritized. Returns: deque: All child models of the parent model. N:1 related models first, N:M related models last. """ childs = deque() for model in self.models: if (model not in self.linearized and self.relation_map[(parent, model)]): if isinstance(self.relation_map[(parent, model)], models.ForeignKey): childs.appendleft(model) else: childs.append(model) return childs def _construct_SQL(self, model_a, model_b=None): """ Construct SQL statement for JOINing two models depending on their relation. Args: model_a (Model): The model that comes first in the JOIN statement. model_a (Model): The model that comes second in the JOIN statement. Defaults to None. """ # If only one model is given, it is the start of the SQL statement. # Insert the model in a FROM clause. if model_b is None: sql = ''' FROM %s ''' % model_a._meta.db_table # Two models are given, a JOIN clause is required. else: sql = self._construct_PostgreSQLJOIN(model_a, model_b) # Add the newly constructed SQL to the models attribute. self.sql += sql def _construct_PostgreSQLJOIN(self, model_a, model_b): """ Construct SQL statement for JOINing two N:M related models in PostgreSQL dialect. Args: model_a (Model): The model that comes first in the JOIN statement. model_a (Model): The model that comes second in the JOIN statement. intermediary_model (Model): The model mediating the N:M relation between model_a and model_b. Returns: str: The constructed SQL JOIN statement. """ # When dealing with a N:1 relation a FULL JOIN with the other model # is sufficient. if isinstance(self.relation_map[(model_a, model_b)], models.ForeignKey): sql = ''' FULL JOIN %s ON %s = %s ''' % (model_b._meta.db_table, get_fk_query_name(model_a, model_b), get_pk_query_name(model_b), ) # When dealing with a N:M relation a FULL JOIN spanning # the intermediary model is required. elif issubclass(self.relation_map[(model_a, model_b)], models.Model): intermediary_model = self.relation_map[(model_a, model_b)] # When dealing with a N:M relation a FULL JOIN spanning # the intermediary model is required. intermediary_model = self.relation_map[(model_a, model_b)] sql = ''' FULL JOIN %s ON %s = %s FULL JOIN %s ON %s = %s ''' % (intermediary_model._meta.db_table, get_pk_query_name(model_a), get_field_query_name(intermediary_model, get_m2m_field(intermediary_model, model_a)), model_b._meta.db_table, get_field_query_name(intermediary_model, get_m2m_field(intermediary_model, model_b)), get_pk_query_name(model_b), ) else: raise ValueError('Unsupported relation. The relation between {} and {} is neither ForeignKey nor ManyToMany.'. format(model_a, model_b)) return sql def _construct_directions(self, num): """ Return a list of lists of directions needed for a FULL OUTER JOIN of `num` N:M related tables. Moving from LEFT to RIGHT. Args: num (int): The number of tables to be joined together. Returns: list of lists of str: The directions by which the tables have to be joined together. """ if num < 2: err_msg = 'Tried to join {number} tables. Less than 2 tables cannot be joined.' raise ValueError(err_msg.format(number=num)) directions = [] num += 1 for i in range(1, num+1): left = num - i right = num - 1 - left directions.append(((['RIGHT'] * right) + ['LEFT'] * left)) return directions def _format_direction(self, direction): """ Return a tuple were each entry of a list is duplicated. This is done because the direction of a join spans two tables, the intermediary and the the related table, in N:M relations. Args: direction (list): A list of strings of the direction needed for one UNION statement. Returns: tuple: Each entry of directions list duplicated. """ duplicated_direction = tuple((lr, lr) for lr in direction) return sum(duplicated_direction, ()) def _construct_SELECT(self): """ Construct the SELECT clause for the SQL JOIN. The order of the SELECTed models is based on the order in that they were linerized. """ column_names = (self._get_column_names(model) for model in self.linearized) # Flatten column_names = sum(column_names, []) select_stmt = 'SELECT DISTINCT %s' % ', '.join(column_names) return select_stmt def _get_column_names(self, model): """ Get the DB column names for the fields of a MeasurementModel of an experiment. Args: model (Model): The model. Returns: list: A list of strings containing the display names. """ # Construct the id column name as 'modelname_id' column_names = [ '%s.%s AS %s_%s' % (model._meta.db_table,, model.__name__.lower(), ] # Add all columns where `show` is True. column_names += [ self._as_display_name(model, field) for field in model._meta.get_fields() if getattr(field, 'di_show', False) ] return column_names def _as_display_name(self, model, field): """ Get the name by which a model field should be fetched from the DB. Either as the plain name as an AS statement with the display_name. Args: model (Model): The model class. field (Field): The field class of the model. Returns: str: The DB name of the field or an AS statement with its display name. """ display_name = getattr(field, 'di_display_name', False) if display_name: return '%s.%s AS %s' % (model._meta.db_table, field.column, display_name) else: return field.column
[docs] def create_joined_table(self): """ Execute the the SQL JOIN and create a table thereof. """ table_name = '%s_%s' % (self.app_label, settings.DISBI['JOINED_TABLENAME']) exec_query('DROP TABLE IF EXISTS %s;' % table_name) sql = ''' CREATE TABLE %s AS %s ''' % (table_name, self.sql) exec_query(sql)