Models from JSON(B)

One of the things we’ll often try to do is reduce the number of database queries. In Django, this is most often done by using the select_related queryset method, which does a join to the related objects, thus returning the data from those, and then automatically creates the instances for those objects.

This works great if you have a foreign key relationship (for instance, you are fetching Employee objects, and you also want to fetch the Company for which they work).

It does not work if you are following the reverse of a foreign key, or a many-to-many relation. You can’t use select_related to get all of the locations at which an employee can work, for instance. To get around this, Django also provides a prefetch_related queryset method, which performs a second query to fetch all of the related objects for all of the objects in the initial queryset. This is evaluated at the same time as the initial queryset, so it works well with pagination, for example.
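For example, assuming Employee has a many-to-many field called locations (the field name here is only for illustration):

employees = Employee.objects.filter(
    company=company,
).prefetch_related('locations')

# One extra query fetches the locations for every employee in the queryset.
for employee in employees:
    print(employee, list(employee.locations.all()))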

But, we don’t always want all of the objects: sometimes we might only want the most recent related object. Perhaps we have a queryset of Employee objects, and we want their most recent EmploymentPeriod. If we just want one field from that object (their start date, for instance), then we can do that using a subquery:

from django.db.models import OuterRef, Subquery

employment_start = EmploymentPeriod.objects.filter(
    employee=OuterRef('pk'),
).order_by('-start').values('start')[:1]

employees = Employee.objects.filter(
    company=company,
).annotate(
    start_date=Subquery(employment_start)
)

We can go a little bit further, and limit to only those currently employed (that is, they have no termination date, or their termination date is in the future).

employment_start = EmploymentPeriod.objects.filter(
    employee=OuterRef('pk'),
).filter(
    models.Q(finish__isnull=True) |
    models.Q(finish__gte=datetime.date.today())
).order_by('-start').values('start')[:1]

employees = Employee.objects.filter(
    company=company,
).filter(
    # New Django 3.0 feature alert!
    Exists(employment_start)
).annotate(
    start_date=Subquery(employment_start)
)

We could rewrite this to remove the OR, perhaps by annotating a DATERANGE and filtering with an overlap: but we’d want to ensure there was an index on the table that Postgres would be able to use. Alternatively, we could store our employment period as a date range instead of a pair of date fields, but that makes some of the other queries around this a bit more complicated.
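Here’s a rough sketch of the first option, assuming the pair of date fields is start and finish, and glossing over the exact bound semantics of DATERANGE (which defaults to an inclusive lower and exclusive upper bound):

import datetime

from django.contrib.postgres.fields import DateRangeField
from django.db.models import F, Func, OuterRef
from psycopg2.extras import DateRange

# A range from today onwards: anything that overlaps it is "current".
current = DateRange(datetime.date.today(), None)

employment_start = EmploymentPeriod.objects.filter(
    employee=OuterRef('pk'),
).annotate(
    period=Func(
        F('start'), F('finish'),
        function='DATERANGE',
        output_field=DateRangeField(),
    ),
).filter(
    period__overlap=current,
).order_by('-start').values('start')[:1]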

But this does not help us if we want to get the whole related object. We could try to use a Prefetch object to obtain this:

Employee.objects.filter(
    company=company,
).prefetch_related(models.Prefetch(
    'employment_periods',
    queryset=EmploymentPeriod.objects.order_by('-start')[:1],
    to_attr='current_employment_periods',
))

But this will not work, because the filtering to the employee’s own employment periods is not applied until after the slice. We could do the ordering without the slice, and then just select the first one as the current employment period, but we would then still be returning a bunch of objects when we only need at most one.
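That fallback could look something like this sketch (the to_attr name is mine):

employees = Employee.objects.filter(
    company=company,
).prefetch_related(models.Prefetch(
    'employment_periods',
    queryset=EmploymentPeriod.objects.order_by('-start'),
    to_attr='ordered_employment_periods',
))

for employee in employees:
    # Every employment period was fetched; we only want the most recent one.
    current = (employee.ordered_employment_periods or [None])[0]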

It would be excellent if we could force Django to join to a subquery (because then we could use select_related), however that’s not currently possible (or likely in the short term). Instead, we can turn our EmploymentPeriod into a JSON object in the database, and then turn that JSON object into an instance of our target model.

from django.contrib.postgres.fields import JSONField
from django.db.models.expressions import Subquery


class ToJSONB(Subquery):
    template = '(SELECT to_jsonb("row") FROM (%(subquery)s) "row")'
    output_field = JSONField()


current_employment_period = EmploymentPeriod.objects.filter(
    employee=OuterRef('pk')
).order_by('-start')[:1]

employees = Employee.objects.filter(
    company=company
).annotate(
    current_employment_period=ToJSONB(current_employment_period)
)

This gets us part of the way, as it turns our employment period into a JSON object. Let’s try turning that JSON object into an instance:

class EmployeeQuerySet(models.query.QuerySet):
    def with_current_employment_period(self):
        current_employment_period = EmploymentPeriod.objects.filter(
            employee=OuterRef('pk')
        ).order_by('-start')[:1]

        return self.annotate(
            current_employment_period=ToJSONB(current_employment_period)
        )


class Employee(models.Model):
    # field definitions...

    objects = EmployeeQuerySet.as_manager()

    @property
    def current_employment_period(self):
        return getattr(self, '_current_employment_period', None)

    @current_employment_period.setter
    def current_employment_period(self, value):
        if value:
            self._current_employment_period = EmploymentPeriod(**value)

And now we can query and include the current employment period (as an instance), without having to have an extra query:

>>> Employee.objects.with_current_employment_period()[0].current_employment_period
<EmploymentPeriod: 2019-01-01  >

To improve this further, we could allow assigning an EmploymentPeriod object to the value (instead of just the dict), and if that is the case, we could write that value to the database, but that’s probably going to play havoc with constraints that would prevent overlaps.

Functions as Tables in Django and Postgres

Postgres has some excellent features. One of these is set-returning functions. It’s possible to have a function (written in SQL, or some other language) that returns a set of values. For instance, the in-built function generate_series() returns a set of values:

SELECT day::DATE
  FROM generate_series(now(),
                       now() + INTERVAL '1 month',
                       '1 day') day;

This uses a set returning function as a table source: in this case a single column table.

You can use scalar set-returning functions from within Django relatively easily: I blogged about it last year.
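One straightforward way (not necessarily the approach from that post) is to drop down to a cursor and read the rows back directly:

from django.db import connection


def days_in_next_month():
    # Evaluate generate_series() directly and return a list of dates.
    with connection.cursor() as cursor:
        cursor.execute("""
            SELECT day::DATE
              FROM generate_series(now(),
                                   now() + INTERVAL '1 month',
                                   '1 day') day
        """)
        return [row[0] for row in cursor.fetchall()]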

It is possible to create your own set-returning functions. Further, the return type can be a SETOF any table type in your database, or even a “new” table.

CREATE OR REPLACE FUNCTION foo(INTEGER, INTEGER)
RETURNS TABLE(id INTEGER,
              bar_id INTEGER,
              baz JSON[]) AS $$

  SELECT foo.id AS id,
         bar.id AS bar_id,
         ARRAY_AGG(JSON_BUILD_OBJECT(bar.x, foo.y))
    FROM foo
   INNER JOIN bar ON (foo.id = bar.foo_id)
   WHERE foo.y = $1
     AND bar.x > $2
   GROUP BY foo.id, bar.id

$$ LANGUAGE SQL STABLE;

It’s possible to have a Postgres VIEW as the data source for a Django model (you just set the Meta.db_table on the model, and mark it as Meta.managed = False). Using a FUNCTION is a bit trickier.

You can use the QuerySet.raw() method, something like:

qs = Foo.objects.raw('SELECT * FROM foo(%s, %s)', [x, y])

The downside of using raw() is that you can’t apply annotations, or use .filter() to limit the results.

What would be useful is if you could extract the relevant parameters out of a QuerySet, and inject them as the arguments to the function call.


But why would we want to have one of these set (or table) returning functions? Why not write a view?

I have some complex queries that reference a whole bunch of different tables. In order to be able to write a sane query, I decided to use a CTE. This allows me to write the query in a more structured manner:

WITH foo AS (
  SELECT ...
    FROM foo_bar
   WHERE ...
),
bar AS (
  SELECT ...
    FROM foo
   WHERE ...
   GROUP BY ...
)
SELECT ...
  FROM bar
 WHERE ...

There is a drawback to this approach, specifically how it interacts with Django. We can turn a query like this into a view, but any filtering we want to do using the ORM will only apply after the view has executed. Normally, this is not a problem, because Postgres can “push down” the filters it finds in the query, down into the view.

But Postgres versions before 12 are unable to perform this operation on a CTE: each clause of a CTE acts as an optimisation fence, and must be fully evaluated before the next one can run. In practice, if a clause of a CTE is not referenced later on, Postgres will not execute that clause, but that is the extent of the optimisation.

So, if we had 50 million objects in foo_bar, and we needed to filter them in a dynamic way (ie, from the ORM), we would be unable to do this efficiently. The initial clause would execute for all 50 million rows, and then any subsequent clauses would operate on all of those, and so on. The ORM’s filtering would only happen after the view had returned all of its rows.

The workaround, using a function, is to use the parameters to do the filtering as early as possible:

CREATE OR REPLACE FUNCTION foo(INTEGER, INTEGER, INTEGER)
RETURNS TABLE(...) AS $$

  WITH foo_1 AS (
    SELECT *
      FROM foo_bar
     WHERE x BETWEEN $1 AND $2
  ),
  bar AS (
    SELECT *
      FROM foo_1
     INNER JOIN baz USING (baz_id)
     WHERE baz.qux > $3
  )

  SELECT ...
    FROM bar
   GROUP BY ...

$$ LANGUAGE SQL STRICT STABLE;

Notice that we do the filtering on foo_bar as early as we possibly can, and likewise filter the baz the first time we reference it.

Now we have established why we may want to use a function as a model source, how do we go about doing that?


We are going to dig fairly deep into the internals of Django’s ORM now, so tighten up your boots!

When Django comes across a .filter() call, it looks at the arguments, and applies them to a new copy of the QuerySet’s query object: or more specifically, the query.where node. This has a list of children, which Django will turn into SQL and execute later. The QuerySet does some validation at this point: it will only allow fields known to the QuerySet (either fields of the Model, or those added using .annotate()). Any others will result in an exception.

This will require some extension, as it is possible that one or more arguments to a Postgres function may not be fields on the Model class used to define the shape of the results.

Each Node within a QuerySet’s query has a method: .as_sql(). This is the part that turns the python objects into actual SQL code. We are going to leverage the fact that even the python object that refers to the table itself uses .as_sql() to safely allow passing parameters to the function-as-a-table.

Normally, the .as_sql() method of the BaseTable object returns just the name of the table (possibly with the current alias attached), and an empty list as params. We can swap this class out with one that will return an SQL fragment, containing function_name(%s, %s) (with the correct number of %s arguments), and a list containing those parameters.
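As a sketch (the class and attribute names here are mine, not from any released code), such a replacement could look something like this, assuming a recent Django where BaseTable lives in django.db.models.sql.datastructures, and ignoring details like relabeled_clone():

from django.db.models.sql.datastructures import BaseTable


class FunctionTable(BaseTable):
    """
    A FROM-clause entry that renders as function_name(%s, %s, ...) with
    the supplied parameters, instead of a plain table name.
    """
    def __init__(self, function_name, alias, params):
        super().__init__(function_name, alias)
        self.params = list(params)

    def as_sql(self, compiler, connection):
        placeholders = ', '.join(['%s'] * len(self.params))
        alias_str = '' if self.table_alias == self.table_name else ' %s' % self.table_alias
        return '{}({}){}'.format(self.table_name, placeholders, alias_str), self.params

The remaining work, which is where most of the fiddly bits live, is swapping this object into the query’s alias_map in place of the normal BaseTable entry.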

Every Postgres function has a unique signature: the function name, and the list of parameters; or more specifically, their types. Thus, postgres will deem the functions:

  • foo(INTEGER, INTEGER)
  • foo(INTEGER, INTEGER, BOOLEAN)

as distinct entities. We will ignore for now the fact it is possible to have optional arguments, variadic arguments and polymorphic functions.

We need some way of storing what the signature of a Postgres function is. Initially, I used an analog (perhaps even just a subclass) of Django’s Model class. This enabled me to create (temporary) Exact(Col()) nodes to go in the query.where.children list, to be later removed and used for function arguments. I needed to ignore the primary key field, and it always felt wrong to add to the WhereNode, only to remove later.

I ended up settling on a class like Django’s Form class. It uses a metaclass (but requires a Meta.function_name), and uses Django’s form fields to define the arguments.

class FooFunction(Function):
    x = forms.IntegerField()
    y = forms.IntegerField()
    z = forms.BooleanField(required=False, default=True)

    class Meta:
        function_name = 'foo'

A Function subclass can generate a Manager on a Model class, and can also create the object that will in turn create the relevant SQL. That part happens automatically, when a queryset created from the manager is filtered using the appropriate arguments. The Function subclass uses its fields to validate that the params it will be passing are valid for their types. It’s a bit like the clean() method on a Form.

We then also need a model class (it could be a model you have already defined, if your function returns a SETOF <table-name>):

class Foo(models.Model):
    bar = models.ForeignKey('foo.Bar', related_name='+', on_delete=models.DO_NOTHING)
    baz = JSONField()

    objects = FooFunction.as_manager()

    class Meta:
        db_table = 'foo()'
        managed = False

Because this is a new model that does not map to a real table, we need to set on_delete so that the ORM won’t try to cascade deletes, and also mark the model as unmanaged. I also set the Meta.db_table to the function call without arguments, so it looks nicer in repr(queryset). It would be nice to be able to get the actual arguments in there, but I haven’t managed that yet.

If this was just a different way of fetching an existing model, you’d just need to add a new manager on to that model. Keep in mind that Django needs a primary key field, so ensure your function provides one.

Then, we can perform a query:

>>> print(Foo.objects.filter(x=1, y=2).query)
SELECT foo.id, foo.bar_id, foo.baz FROM foo(1, 2)
>>> print(Foo.objects.filter(x=1, y=2, z=True).query)
SELECT foo.id, foo.bar_id, foo.baz FROM foo(1, 2, true)
>>> print(Foo.objects.filter(x=1, y=2, bar__gt=5).query)
SELECT foo.id, foo.bar_id, foo.baz FROM foo(1, 2) WHERE foo.bar_id > 5
>>> print(Foo.objects.filter(x=1, y=2).annotate(qux=models.F('bar__name')).query)
SELECT foo.id, foo.bar_id, foo.baz, bar.name AS qux
FROM foo(1, 2) INNER JOIN bar ON (bar.id = foo.bar_id)

Note that filters that are not function parameters will apply after the function call. You can annotate on, and it will automatically create a join through a foreign key. If you omit a required parameter, then you’ll get an exception.

So, where’s the code for this?

Well, I need to clean it up a bit, and write up some automated tests on it. Expect it soon.

Orders please!

Things are better when they are ordered.

Why might we want things ordered?

From a user experience perspective, if things are not ordered the same way upon repeated viewings, they are liable to get confused, and think that things have changed, when really it’s only the order that has changed.

Some time ago, there were versions of Microsoft Office that would move menu items closer to the start of the menu if they had been used more recently or more frequently. This was not “most recent documents”, or anything like that: this was menu items like “Format” and other tools. There’s a pretty good rationale for why they stopped doing that!

Why might we need things ordered?

There are a few reasons why we must have things ordered. One good example is the reason Django will complain about a queryset without ordering when you use pagination: if you select the first 10 items when there is no ordering, and then select the next 10 items, there is no guarantee that those items will not have some of the same items as the first set. Without ordering, pagination is useless. There’s no way to page through items and ensure that you see all of them.

Another reason is where you need to process things in a specific order: for instance, to ensure that the first (or last) matching item is “the winner”. We’ll come across one of those shortly.

Ordering by a specific value or set of values (or function over the same) is usually a sane way to order. This could be alphabetically, by date of insertion, or some numeric value, for instance.

How should we order things?

Sometimes, it is obvious how we should order a list of things. Sometimes the order doesn’t matter, and sometimes we want to allow an ordering that doesn’t rely at all on any attributes of the objects in question.

Let’s look at a list of condition names that might be used for a workplace agreement in Australia.

  • Regular Hours
  • Public Holiday
  • Saturday
  • Sunday
  • After 6pm
  • Regular Overtime (first 2 hours)
  • Extended Overtime (more than 2 hours)

For those outside of Australia: under many work circumstances, employees will get paid at a higher rate because of what are called “Penalty” conditions. This could include being paid, for example, a 15% loading on a Saturday, and a 25% loading on a Sunday. We can rewrite this list with the loading percentages:

Condition          Penalty  Rule
Regular Hours      100%     When no other rules apply
Public Holiday     200%     Any hours on a public holiday
Saturday           125%     Any hours on a Saturday
Sunday             150%     Any hours on a Sunday
After 6pm          117%     Any hours between 6pm and midnight
Regular Overtime   150%     After 8 hours worked in a day or 38 hours in a week
Extended Overtime  200%     After 2 hours of overtime worked, or any overtime on a Sunday

So now we have two possible ways of ordering our list, by name:

  • After 6pm
  • Extended Overtime (more than 2 hours)
  • Public Holiday
  • Regular Hours
  • Regular Overtime (first 2 hours)
  • Saturday
  • Sunday

Or by penalty rate:

  • 100% Regular Hours
  • 117% After 6pm
  • 125% Saturday
  • 150% Regular Overtime (first 2 hours)
  • 150% Sunday
  • 200% Extended Overtime (more than 2 hours)
  • 200% Public Holiday

However, there’s actually a little more complexity than this. We can have the rules in this order, and apply them sequentially, and we should get the right result for an employee’s wages. But, according to our rules, we should be applying Extended Overtime for any overtime hours on a Sunday. So we need some way of controlling the ordering manually.

Specifically, we want to have our conditions ordered in a way that we can check each rule in turn, and ensure that the last rule that is applied will be the one that should apply, if multiple rules do match.

  • Regular hours
  • After 6pm
  • Saturday
  • Sunday
  • Regular Overtime
  • Extended Overtime
  • Public Holiday

So, how do we actually store this custom ordering?

The naive solution is to store an integer for each item, and ensure that we only have one of each value. Let’s look at that as a Django model:

class Condition(models.Model):
    name = models.TextField()
    # max_digits/decimal_places chosen arbitrarily for the example.
    penalty = models.DecimalField(max_digits=5, decimal_places=2)
    position = models.IntegerField(unique=True)

Then, we can ensure we get our conditions in the correct order by applying an ordering to our queryset:

queryset = Condition.objects.order_by('position')

However, this solution is not without its drawbacks. Primarily, to insert an object between two others, we must renumber all of the subsequent items:

condition = Condition(name='New condition', position=3, penalty=Decimal('1.20'))
Condition.objects.filter(position__gte=condition.position).update(position=F('position') + 1)
condition.save()

We must ensure that within our database, we are able to defer the unique constraint check until our transaction is complete, or we must perform the updates in a specific order. If we care about not having “holes” between our positions, then we also need to renumber the remaining objects when we remove one.
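If you are on Postgres, one way to get the “defer until commit” behaviour is to replace the unique constraint with a deferrable one in a RunSQL migration. The table, constraint and dependency names below are illustrative, not what Django will have generated for you:

from django.db import migrations


class Migration(migrations.Migration):

    dependencies = [
        ('conditions', '0002_condition_position'),  # hypothetical
    ]

    operations = [
        migrations.RunSQL(
            sql="""
                ALTER TABLE conditions_condition
                 DROP CONSTRAINT IF EXISTS conditions_condition_position_key,
                  ADD CONSTRAINT conditions_condition_position_key
                      UNIQUE (position) DEFERRABLE INITIALLY DEFERRED;
            """,
            reverse_sql="""
                ALTER TABLE conditions_condition
                 DROP CONSTRAINT conditions_condition_position_key,
                  ADD CONSTRAINT conditions_condition_position_key
                      UNIQUE (position);
            """,
        ),
    ]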

Let’s have a look at how we could handle this, with some client and server code. We can use a formset here, with ordering enabled, and then we just need the JS to make that happen:

ConditionPositionFormSet = modelformset_factory(Condition, can_order=True, fields=())

class OrderConditionsView(FormView):
    form_class = ConditionPositionFormSet

This will result in rewriting all items one by one, which must happen in a transaction, or in the correct order.


Alternatively, we could use a form that just indicates the new position of a given condition, and have that form handle the UPDATE queries: this will result in just two queries instead.

class ConditionPositionForm(forms.ModelForm):
    class Meta:
        model = Condition
        fields = ('position',)


class OrderConditionsView(UpdateView):
    """
    Handle an AJAX request to move one condition to a new position.

    The condition in that position will be moved to the next position, and so on.
    """
    form_class = ConditionPositionForm

    def form_valid(self, form):
        position = form.cleaned_data['position']
        self.get_queryset().filter(position__gte=position)\
                           .exclude(pk=form.instance.pk)\
                           .update(position=F('position') + 1)
        form.save()
        return HttpResponse(json.dumps({'status': 'ok'}))

    def form_invalid(self, form):
        return HttpResponse(json.dumps({
            'status': 'error',
            'message': list(form.errors.get('position', [])),
        }), status=409)

    def get_queryset(self):
        return Condition.objects.order_by('position')

# urls.py: simplified
from django.urls import include, path

urlpatterns = [
    path('condition/', include(([
        path('<int:pk>/position/', OrderConditionsView.as_view(), name='position'),
        # ...
    ], 'condition'), namespace='condition')),
    # ...
]

In this case, our HTML and JS can just send back the one that was moved, and its new position. We’ll want some mechanism for indicating whether the move was successful though, because it may not be permitted to move some objects to specific positions.

  <ul id="conditions">
    {% for condition in object_list %}
      <li data-position-url="{% url 'condition:position' condition.pk %}">
        {{ condition.name }} ({{ condition.penalty }}%)
      </li>
    {% endfor %}
  </ul>

  <script>
  // Uses SortableJS: https://github.com/SortableJS/Sortable

  require(['Sortable'], (Sortable) => {
    function save(evt) {
      const condition = evt.item;

      function revert() {
        if (evt.oldIndex < evt.newIndex) {
          // Move back to the old index.
          condition.before(condition.parentElement.children[evt.oldIndex]);
        } else {
          // The old index is actually in the wrong spot now,
          // move back to the old index + 1
          condition.before(condition.parentElement.children[evt.oldIndex + 1]);
        }
      }

      // (CSRF handling omitted for brevity.)
      fetch(condition.dataset.positionUrl, {
        method: 'POST',
        // URLSearchParams sends form-encoded data the Django form can read.
        body: new URLSearchParams({position: evt.newIndex}),
        credentials: 'same-origin'
      }).then(
        response => response.json()
      ).then(
        response => {
          if (response.status === 'error') {
            // We got an error: let's put the object back into the previous
            // position and then show that error.
            revert();
            // Maybe a nicer way than this though!
            alert(response.message);
          }
        }
      );
    }

    new Sortable(document.getElementById('conditions'), {onUpdate: save});

  });
  </script>

So, are there any other ways we could store the positions of our conditions that would allow us to insert a new one, or move one around, without having to renumber a bunch of other items?

Why yes, yes there are.

Instead of storing each position as an integer, we can instead store two numbers, and use the ratio of these two numbers as our ordering:

from django.db.models.functions import Cast

# Cast one operand, otherwise Postgres will perform integer division.
POSITION = ExpressionWrapper(
  Cast('position_a', models.FloatField()) / models.F('position_b'),
  output_field=models.FloatField()
)


class PositionQuerySet(models.query.QuerySet):
    def with_position(self):
        return self.annotate(position=POSITION)

    def order_by_position(self):
        return self.order_by(POSITION)


class Condition(models.Model):
    name = models.TextField()
    penalty = models.DecimalField(max_digits=5, decimal_places=2)
    position_a = models.IntegerField()
    position_b = models.IntegerField()

    class Meta:
      unique_together = (
        ('position_a', 'position_b'),  # In fact, this is not quite strict enough, more below.
      )

The logic behind this is that there is always a fraction between any two distinct fractions. Our initial “first” item can be given the values (1, 1), which will evaluate to 1.0.

To place another item, we have three cases:

  • The new value is before the first item in the list.
  • The new value is between two existing items.
  • The new value is after the last item in the list.

Let’s look at how we can calculate the position for the newly added item:

  • Before the first item in the list:

    Increment the denominator (position_b), and use the same numerator (position_a).

    For instance, adding a new first value where our current first value is (1, 1) would give us (1, 2), which evaluates to 0.5, less than 1.0

  • After the last item in the list:

    Increment the numerator (position_a), and use the same denominator (position_b).

    For instance, adding a new value where the last value is (5, 1) would give us (6, 1), which evaluates to 6.0, greater than 5.0

  • Inserting an item between two existing items.

    Use the sum of the two numerators (position_a) as the numerator, and the sum of the two denominators (position_b) as the new denominator. Optionally, you can then simplify this fraction if there are common factors.

    For instance, if we have two values (1, 4) and (5, 1), we can make a new value (6, 5), which evaluates to 1.2, which is between 0.25 and 5.0.
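Here’s a small Python sketch of those three rules (the fractions module is only used to show the resulting ordering values):

from fractions import Fraction


def position_before(first):
    # Same numerator, increment the denominator: always smaller.
    return (first[0], first[1] + 1)


def position_after(last):
    # Increment the numerator, same denominator: always larger.
    return (last[0] + 1, last[1])


def position_between(before, after):
    # The mediant of two fractions always lies between them.
    return (before[0] + after[0], before[1] + after[1])


assert Fraction(*position_before((1, 1))) < Fraction(1, 1)
assert Fraction(*position_after((5, 1))) > Fraction(5, 1)
assert Fraction(1, 4) < Fraction(*position_between((1, 4), (5, 1))) < Fraction(5, 1)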

The same operations can be performed on a move: we don’t actually care what the values are, just that the relative values are still consistent. Whenever we move an object, we just change its position_a and/or position_b to whatever makes sense for the new position.

<ul id="conditions">
  {% for condition in object_list %}
    <li class="condition"
        data-position-url="{% url 'condition:position' condition.pk %}"
        data-position-a="{{ condition.position_a }}"
        data-position-b="{{ condition.position_b }}">
      {{ condition.name }}
    </li>
  {% endfor %}
</ul>

<script>
  require(['Sortable'], (Sortable) => {
    function save(evt) {
      const item = evt.item;
      const prev = item.previousElementSibling;
      const next = item.nextElementSibling;

      if (prev == null && next == null) {
        // This is the only item in the list.
        item.dataset.positionA = 1;
        item.dataset.positionB = 1;
      } else if (prev == null) {
        // This is now the first item in the list.
        item.dataset.positionA = next.dataset.positionA;
        item.dataset.positionB = Number(next.dataset.positionB) + 1;
      } else if (next == null) {
        // This is now the last item in the list.
        item.dataset.positionA = Number(prev.dataset.positionA) + 1;
        item.dataset.positionB = prev.dataset.positionB;
      } else {
        // Neither the first nor the last.
        item.dataset.positionA = Number(prev.dataset.positionA) + Number(next.dataset.positionA);
        item.dataset.positionB = Number(prev.dataset.positionB) + Number(next.dataset.positionB);
      }

      function revert() {
        if (evt.oldIndex < evt.newIndex) {
          // Move back to the old index.
          item.before(item.parentElement.children[evt.oldIndex]);
        } else {
          // The old index is actually in the wrong spot now,
          // move back to the old index + 1.
          item.before(item.parentElement.children[evt.oldIndex + 1]);
        }
      }

      // This could do with some error handling...
      fetch(item.dataset.positionUrl, {
        method: 'POST',
        // Form-encoded body, so the Django view can read it like a normal POST.
        body: new URLSearchParams({
          position_a: item.dataset.positionA,
          position_b: item.dataset.positionB
        }),
        credentials: 'same-origin'
      }).then(response => response.json()).then(
        response => {
          if (response.status === 'error') {
            revert();
          }
        }
      );
    }

    new Sortable(document.getElementById('conditions'), {onUpdate: save});
  });
</script>

Postgres ENUM types in Django

Postgres has the ability to create custom types. There are several kinds of CREATE TYPE statement:

  • composite types
  • domain types
  • range types
  • base types
  • enumerated types

I’ve used a metaclass that is based on Django’s Model classes to do Composite Types in the past, and it’s been working fairly well. For the current stuff I have been working on, it made sense to use an Enumerated Type, because there are four possible values, and having a human readable version of them is going to be nicer than using a lookup table.

In the first iteration, I used just a TEXT column to store the data. However, when I then started to use an enum.Enum class for handling the values in python, I discovered that it was actually storing str(value) in the database, rather than value.value.
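The difference is easy to see in a Python shell:

>>> import enum
>>> class ChangeType(enum.Enum):
...     ADDED = 'added'
...
>>> str(ChangeType.ADDED)
'ChangeType.ADDED'
>>> ChangeType.ADDED.value
'added'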

So, I thought I would implement something similar to my Composite Type class. Not long after starting, I realised that I could make a cleaner implementation (and easier to declare) using a decorator:

@register_enum(db_type='change_type')
class ChangeType(enum.Enum):
    ADDED = 'added'
    CHANGED = 'changed'
    REMOVED = 'removed'
    CANCELLED = 'cancelled'


ChangeType.choices = [
    (ChangeType.ADDED, _('hours added')),
    (ChangeType.REMOVED, _('hours subtracted')),
    (ChangeType.CHANGED, _('start/finish changed with no loss of hours')),
    (ChangeType.CANCELLED, _('shift cancelled')),
]

Because I’m still on an older version of Python/Django, I could not use the brand new Enumeration types, so in order to make things a bit easier, I then annotate onto the class some extra helpers. It’s important to do this after declaring the class, because otherwise the attributes you define will become “members” of the enumeration. When I move to Django 3.0, I’ll probably try to update this register_enum decorator to work with those classes.

So, let’s get down to business with the decorator. I spent quite some time trying to get it to work using wrapt, before realising that I didn’t actually need to use it. In this case, the decorator is only valid for decorating classes, and we just add things onto the class (and register some things), so it can just return the new class, rather than having to muck around with docstrings and names.

from django.db import ProgrammingError
from psycopg2.extensions import (
    new_array_type,
    new_type,
    QuotedString,
    register_adapter,
    register_type,
)

known_types = set()


CREATE_TYPE = 'CREATE TYPE {0} AS ENUM ({1})'
SELECT_OIDS = 'SELECT %s::regtype::oid AS "oid", %s::regtype::oid AS "array_oid"'


class register_enum(object):
    def __init__(self, db_type, managed=True):
        self.db_type = db_type
        self.array_type = '{}[]'.format(db_type)
        self.managed = managed

    def __call__(self, cls):
        # Tell psycopg2 how to turn values of this class into db-ready values.
        register_adapter(cls, lambda value: QuotedString(value.value))

        # Store a reference to this instance's "register" method, which allows
        # us to do the magic to turn database values into this enum type.
        known_types.add(self.register)

        self.values = [
            member.value
            for member in cls.__members__.values()
        ]

        # We need to keep a reference to the new class around, so we can use it later.
        self.cls = cls

        return cls

    def register(self, connection):
        with connection.cursor() as cursor:
            try:
                cursor.execute(SELECT_OIDS, [self.db_type, self.array_type])
                oid, array_oid = cursor.fetchone()
            except ProgrammingError:
                if not self.managed:
                    return
                # The type doesn't exist yet: create it, then look the oids up again.
                cursor.execute(self.create_enum(connection), self.values)
                cursor.execute(SELECT_OIDS, [self.db_type, self.array_type])
                oid, array_oid = cursor.fetchone()

        custom_type = new_type(
            (oid,),
            self.db_type,
            lambda data, cursor: data and self.cls(data) or None
        )
        custom_array = new_array_type(
            (array_oid,),
            self.array_type,
            custom_type
        )
        register_type(custom_type, cursor.connection)
        register_type(custom_array, cursor.connection)

    def create_enum(self, connection):
        qn = connection.ops.quote_name
        return CREATE_TYPE.format(
            qn(self.db_type),
            ', '.join(['%s' for value in self.values])
        )

I’ve extracted out the create_enum method, because it’s then possible to use this in a migration (but I’m not totally happy with the code that generates this migration operation just yet). I also have other code that dynamically creates classes for a ModelField and FormField as attributes on the Enum subclass, but that complicates it a bunch.
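For reference, a hand-written migration using the same SQL could be as simple as this (a sketch, not the generated operation mentioned above; the dependency is hypothetical):

from django.db import migrations


class Migration(migrations.Migration):

    dependencies = [
        ('changes', '0001_initial'),  # hypothetical app/migration name
    ]

    operations = [
        migrations.RunSQL(
            sql="CREATE TYPE change_type AS ENUM ('added', 'changed', 'removed', 'cancelled')",
            reverse_sql="DROP TYPE change_type",
        ),
    ]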

Subquery and Subclasses

Being able to use correlated subqueries in the Django ORM arrived in 1.11, and I also backported it to 1.8.

Quite commonly, I am asked questions about how to use these, so here is an attempt to document them further.

There are three classes that are supplied with Django, but it’s easy to write extensions using subclassing.

Let’s first look at an example of how you might want to use the included classes. We’ll consider a set of temperature sensors, each with a location and a code, both of which are unique. These sensors will log their current temperature at some sort of interval: maybe it’s regular, maybe it varies between devices. We want to keep every reading, but only allow one reading for a given sensor+timestamp.

class Sensor(models.Model):
    location = models.TextField(unique=True)
    code = models.TextField(unique=True)


class Reading(models.Model):
    sensor = models.ForeignKey(Sensor, related_name='readings', on_delete=models.CASCADE)
    timestamp = models.DateTimeField()
    temperature = models.DecimalField(max_digits=6, decimal_places=3)

    class Meta:
        unique_together = (('sensor', 'timestamp'),)

Some of the things we might want to do for a given sensor:

  • Get the most recent temperature
  • Get the average temperature over a given period
  • Get the maximum temperature over a given period
  • Get the minimum temperature over a given period

If we start with a single sensor instance, we can do each of these without having to use Subquery and friends:

from django.db.models import Avg, Min, Max

most_recent_temperature = sensor.readings.order_by('-timestamp').first().temperature
period_readings = sensor.readings.filter(
    timestamp__gte=start,
    timestamp__lte=finish,
).aggregate(
    average=Avg('temperature'),
    minimum=Min('temperature'),
    maximum=Max('temperature'),
)

We could also get the minimum or maximum using ordering, like we did with the most_recent_temperature.

If we want to do the same for a set of sensors, mostly we can still achieve this (note how similar the code is to the block above):

sensor_readings = Reading.objects.filter(
  timestamp__gte=start,
  timestamp__lte=finish
).values('sensor').annotate(
  average=Avg('temperature'),
  minimum=Min('temperature'),
  maximum=Max('temperature'),
)

We might get something like:

[
    {
        'sensor': 1,
        'average': 17.5,
        'minimum': 11.3,
        'maximum': 25.9
    },
    {
        'sensor': 2,
        'average': 19.63,
        'minimum': 13.6,
        'maximum': 24.33
    },
]

However, it’s not obvious how we would get all of the sensors, and their current temperature in a single query.

Subquery to the rescue!

from django.db.models.expressions import Subquery, OuterRef

current_temperature = Reading.objects.filter(sensor=OuterRef('pk'))\
                                     .order_by('-timestamp')\
                                     .values('temperature')[:1]

Sensor.objects.annotate(
    current_temperature=Subquery(current_temperature)
)

What’s going on here is that we are filtering the Reading objects inside our subquery to only those associated with the sensor in the outer query. This uses the special OuterRef class, which will, when the queryset is “resolved”, build the association. It does mean that if we tried to inspect the current_temperature queryset on its own, we would get an error that it is unresolved.

We then order the filtered readings by newest timestamp first; this, coupled with the slice at the end will limit us to a single row. This is required because the database will reject a query that results in multiple rows being returned for a subquery.

Additionally, we may only have a single column in our subquery: that’s achieved by the .values('temperature').

But maybe there is a problem here: we actually want to know when the reading was taken, as well as the temperature.

We can do that a couple of ways. The simplest is to use two Subqueries:

current_temperature = Reading.objects.filter(sensor=OuterRef('pk'))\
                                     .order_by('-timestamp')[:1]

Sensor.objects.annotate(
    current_temperature=Subquery(current_temperature.values('temperature')),
    last_reading_at=Subquery(current_temperature.values('timestamp')),
)

However, this will do two subqueries at the database level. Since these subqueries will be performed separately for each row, each additional correlated subquery will result in more work for the database, with possible performance implications.

What about if we are using Postgres, and are okay with turning the temperature and timestamp pair into a JSONB object?

from django.db.models.expressions import Func, F, Value, OuterRef, Subquery
from django.contrib.postgres.fields import JSONField


class JsonBuildObject(Func):
    function = 'jsonb_build_object'
    output_field = JSONField()


last_temperature = Reading.objects.filter(sensor=OuterRef('pk'))\
                                  .order_by('-timestamp')\
                                  .annotate(
                                      json=JsonBuildObject(
                                          Value('timestamp'), F('timestamp'),
                                          Value('temperature'), F('temperature'),
                                      )
                                   ).values('json')[:1]

Sensor.objects.annotate(
    last_temperature=Subquery(last_temperature)
)

Now, your Sensor instances would have an attribute last_temperature, which will be a dict with the timestamp and temperature of the last reading.
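For example (the values shown here are made up):

>>> sensor = Sensor.objects.annotate(
...     last_temperature=Subquery(last_temperature)
... ).first()
>>> sensor.last_temperature
{'timestamp': '2019-01-01T09:12:35+00:00', 'temperature': 15.35}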


There is also a supplied Exists subquery that can be used to force the database to emit an EXISTS statement. This could be used to set a boolean field on our sensors to indicate they have data from within the last day:

recent_readings = Reading.objects.filter(
    sensor=OuterRef('pk'),
    timestamp__gte=datetime.datetime.utcnow() - datetime.timedelta(1)
)
Sensor.objects.annotate(
    has_recent_readings=Exists(recent_readings)
)

Sometimes we’ll have values from multiple rows that we will want to annotate on from the subquery. This can’t be done directly: you will need to aggregate those values in some way. Postgres has a neat feature where you can use an ARRAY() constructor and wrap a subquery in that:

SELECT foo,
       bar,
       ARRAY(SELECT baz
               FROM qux
              WHERE qux.bar = base.bar
              ORDER BY fizz
              LIMIT 5) AS baz
  FROM base

We can build this type of structure using a subclass of Subquery.

from django.contrib.postgres.fields import ArrayField
from django.core.exceptions import FieldError
from django.db.models.expressions import Subquery

class SubqueryArray(Subquery):
    template = 'ARRAY(%(subquery)s)'

    @property
    def output_field(self):
        output_fields = [x.output_field for x in self.get_source_expressions()]

        if len(output_fields) > 1:
            raise FieldError('More than one column detected')

        return ArrayField(base_field=output_fields[0])

And now we can use this where we’ve used a Subquery, but we no longer need to slice to a single row:

json_reading = JsonBuildObject(
    Value('timestamp'), F('timestamp'),
    Value('temperature'), F('temperature'),
)

last_five_readings = Reading.objects.filter(
    sensor=OuterRef('pk')
).order_by('-timestamp').annotate(
    json=json_reading
).values('json')[:5]

Sensor.objects.annotate(last_five_readings=SubqueryArray(last_five_readings))

Each sensor instance would now have up to 5 dicts in a list in its attribute last_five_readings.

We could get this data in a slightly different way: let’s say instead of an array, we want a dict keyed by a string representation of the timestamp:

sensor.last_five_readings = {
    '2019-01-01T09:12:35Z': 15.35,
    '2019-01-01T09:13:35Z': 14.33,
    '2019-01-01T09:14:35Z': 14.90,
    ...
}

There is a Postgres aggregate we can use there to do that, too:

class JsonObjectAgg(Subquery):
    template = '(SELECT json_object_agg("_j"."key", "_j"."value") FROM (%(subquery)s) "_j")'
    output_field = JSONField()


last_five_readings = Reading.objects.filter(
    sensor=OuterRef('pk')
).order_by('-timestamp').annotate(
    key=F('timestamp'),
    value=F('temperature'),
).values('key', 'value')[:5]

Sensor.objects.annotate(last_five_readings=JsonObjectAgg(last_five_readings))

Indeed, we can wrap any aggregate in a similar way: to get the number of values of a subquery:

class SubqueryCount(Subquery):
    template = '(SELECT count(*) FROM (%(subquery)s) _count)'
    output_field = models.IntegerField()

Since other aggregates need to operate on a single field, we’ll need something that ensures there is a single column in our .values(), extracts it, and uses that in the query.

class SubquerySum(Subquery):
    template = '(SELECT SUM(%(field)s) FROM (%(subquery)s) _sum)'

    def as_sql(self, compiler, connection, template=None, **extra_context):
        if 'field' not in extra_context and 'field' not in self.extra:
            if len(self.queryset._fields) > 1:
                raise FieldError('You must provide the field name, or have a single column')
            extra_context['field'] = self.queryset._fields[0]
        return super(SubquerySum, self).as_sql(
          compiler, connection, template=template, **extra_context
        )
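Usage might look something like this (a rough sketch; summing temperatures is only to show the mechanics, and the field names are taken from the models above):

readings = Reading.objects.filter(
    sensor=OuterRef('pk'),
).values('temperature')

Sensor.objects.annotate(
    temperature_sum=SubquerySum(readings),
)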

As I mentioned, it’s possible to write a subclass like that for any aggregate function, although it would be far nicer if there was a way to write that purely in the ORM. Maybe one day…

Handling overlapping values

One of the things that I enjoy most about Postgres are the rich types. Using these types can help reduce the amount of validation that the application needs to do.

Take for instance anything which contains a start date and a finish date. If you model this using two fields, then you also need to include validation about start <= finish (or perhaps start < finish, depending upon your requirements).

If you use a date range instead, then the database will do this validation for you. It is not possible to create a range value that is “backwards”. Sure, you’ll also need to do application-level (and probably client-side) validation, but there is something nice about having a reliable database that ensures you cannot possibly have invalid data.

Django is able to make good use of range types, and most of my new code seemingly has at least one range type: often a valid_period. So much so that I have a Mixin and a QuerySet that make dealing with these easier:

class ValidPeriodMixin(models.Model):
    valid_period = DateRangeField()

    class Meta:
        abstract = True

    @property
    def start(self):
        if self.valid_period.lower_inc:
            return self.valid_period.lower
        elif self.valid_period.lower is not None:
            return self.valid_period.lower + datetime.timedelta(1)

    @property
    def finish(self):
        if self.valid_period.upper_inc:
            return self.valid_period.upper
        elif self.valid_period.upper is not None:
            return self.valid_period.upper - datetime.timedelta(1)

    @property
    def forever(self):
        return self.valid_period.lower is None and self.valid_period.upper is None

    def get_valid_period_display(self):
        if self.forever:
            message = _('Always applies')
        elif self.start is None:
            message = _('{start} \u2192 no end date')
        elif self.finish is None:
            message = _('no start date \u2192 {finish}')
        else:
            message = _('{start} \u2192 {finish}')

        return message.format(
            start=self.start,
            finish=self.finish,
        )


def ensure_date_range(period):
    """
    If we have a 2-tuple of dates (or strings that are valid dates),
    ensure we turn that into a DateRange instance. This is because
    otherwise Django may mis-interpret this.
    """
    if not isinstance(period, DateRange):
        return DateRange(period[0] or None, period[1] or None, '[]')
    return period


class OverlappingQuerySet(models.query.QuerySet):
    def overlapping(self, period):
        return self.filter(valid_period__overlap=ensure_date_range(period))

    def on_date(self, date):
        return self.filter(valid_period__contains=date)

    def today(self):
        return self.on_date(datetime.date.today())

As you may notice from this, it is possible to do some filtering based on range types: specifically, you can use the && Postgres operator using .filter(field__overlap=value), and the containment operators (<@ and @>) using .filter(field__contains=value) and .filter(field__contained_by=value). There are also other operators we will see a bit later using other lookups.


If you have a legacy table that stores a start and a finish, you would need to have a validator on the model (or forms that write to the model) that ensures start < finish, as mentioned above. Also, there is no way (without extra columns) to tell if the upper and lower values should be inclusive or exclusive of the bounds. In Postgres, we write range values using a notation like a mathematical range: using ‘[’, ‘]’ and ‘(‘, ‘)’ to indicate inclusive and exclusive bounds.

SELECT '[2019-01-01,2020-01-01)'::DATERANGE AS period;

One caveat when dealing with discrete range types (like dates and integers) is that Postgres will, if it is able to, convert the range to a normalised value: it will store (2019-01-01,2019-12-31] as [2019-01-02,2020-01-01). This can become a problem when showing the value back to the user, because depending upon context, it’s likely that you will want to use inclusive bounds when showing and editing the values.

You can manage this by using a form field subclass that detects an exclusive upper bound and subtracts one “unit” accordingly:

import datetime

from django.contrib.postgres.forms.ranges import (
    DateRangeField, IntegerRangeField
)


class InclusiveRangeMixin(object):
    _unit_value = None

    def compress(self, values):
        range_value = super().compress(values)
        if range_value:
            return self.range_type(
                range_value.lower,
                range_value.upper,
                bounds='[]'
            )

    def prepare_value(self, value):
        value = super().prepare_value(value)
        value = [
            field.clean(val)
            for field, val in zip(self.fields, value)
        ]
        if value[1] is not None:
            value[1] = value[1] - self._unit_value
        return value


class InclusiveDateRangeField(
    InclusiveRangeMixin, DateRangeField
):
    _unit_value = datetime.timedelta(1)


class InclusiveIntegerRangeField(
    InclusiveRangeMixin, IntegerRangeField
):
    _unit_value = 1

Back on to the topic of storing two values instead of a range: it’s possible to add an expression index on the table that uses DATERANGE:

CREATE INDEX thing_period_idx
          ON thing_thing (DATERANGE(start, finish));

You would be able to annotate on this value, do some querying, and it should use the index, allowing you to build querysets like:

Thing.objects.annotate(
    period=Func(
      F('start'),
      F('finish'),
      function='DATERANGE',
      output_field=DateRangeField())
).filter(period__overlap=other_period)

Range types show their full power when used with exclusion constraints. These allow you to prevent writing rows that violate the constraint. For instance, consider this model (and some largely irrelevant other models, Team and Player):

class TeamMembership(ValidPeriodMixin):
    player = models.ForeignKey(
        Player,
        related_name='team_memberships',
        on_delete=models.CASCADE,
    )
    team = models.ForeignKey(
        Team,
        related_name='player_memberships',
        on_delete=models.CASCADE,
    )

A player may only belong to one team at a time: that is, we may not have any overlapping valid_periods for a player.

You can do this using an exclusion constraint, but it does need the btree_gist extension installed:

CREATE EXTENSION IF NOT EXISTS btree_gist;

ALTER TABLE team_teammembership
        ADD CONSTRAINT prevent_overlapping_team_memberships
    EXCLUDE USING gist(player_id WITH =, valid_period WITH &&)
 DEFERRABLE INITIALLY DEFERRED;

Since this type of constraint is not yet supported in Django, you’ll have to do it in a RunSQL migration.
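A sketch of what that migration could look like, using the SQL above (the app label and dependency are hypothetical, and CREATE EXTENSION needs sufficient database privileges):

from django.db import migrations


class Migration(migrations.Migration):

    dependencies = [
        ('team', '0001_initial'),  # hypothetical
    ]

    operations = [
        migrations.RunSQL(
            sql="""
                CREATE EXTENSION IF NOT EXISTS btree_gist;

                ALTER TABLE team_teammembership
                        ADD CONSTRAINT prevent_overlapping_team_memberships
                    EXCLUDE USING gist(player_id WITH =, valid_period WITH &&)
                 DEFERRABLE INITIALLY DEFERRED;
            """,
            reverse_sql="""
                ALTER TABLE team_teammembership
                       DROP CONSTRAINT prevent_overlapping_team_memberships;
            """,
        ),
    ]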

From here, we can attempt to write conflicting data, but the database will forbid it. You will still need to write code that checks before writing - this enables you to return a ValidationError to the user when you detect this conflict in a form, but having the exclusion constraint means that we can avoid the race condition where:

  • Check for overlapping ranges
  • Other process creates a range that will overlap
  • Save our data

You could possibly also use select_for_update in this context, but I prefer adding database constraints.

Note that the DEFERRABLE INITIALLY DEFERRED clause is important: it allows you, within a transaction, to write conflicting data, and it’s only when the transaction commits that the constraint is checked. This makes rewriting a bunch of values in one transaction much simpler: if you do not have this flag enabled then you will need to ensure you update them in an order that maintains no overlaps at each stage. I’m pretty confident this is always possible, but it’s a bunch of work (and it is possible that you might need to write some rows multiple times to maintain that).


So, now we can store range values (with database validation), and prevent overlapping data (with database validation).

What about a process that enables us to say “this row should replace, trim or split any that overlap with it”? I’m glad you asked.

It turns out given two rows, where one should “supersede” the other, there are five different conditions we need to take into account:

  • The rows do not overlap: no action required
  • The new row completely covers the old row: remove the old row
  • The old row has bounds that exceed the new row in both directions: split the old row into two rows
  • The old row has a lower bound that is smaller than the new row: trim the old row at the upper end
  • The old row has an upper bound that is larger than the new row: trim the old row at the lower end

It turns out we can perform this query with the Django range field lookups:

class OverlappingQuerySet(models.query.QuerySet):
    def with_overlap_type(self, period):
        period = ensure_date_range(period)
        return self.annotate(
            overlap_type=Case(
                # The objects do not overlap.
                When(~Q(valid_period__overlap=period),
                     then=Value(None)),
                # The existing value is covered by the new value
                When(valid_period__contained_by=period,
                     then=Value('replace')),
                # The existing value has no values
                # less than the new value
                When(valid_period__not_lt=period,
                     then=Value('trim:lower')),
                # The existing value has no values
                # greater than the new value
                When(valid_period__not_gt=period,
                     then=Value('trim:upper')),
                # The existing value contains the new value
                When(valid_period__contains=period,
                      then=Value('split')),
                output_field=models.TextField()
            )
        )

This works because a CASE WHEN stops evaluating when it finds a match: technically a trim:lower value could also match on containment (split), so we need to test that one earlier.

We are going to have to (possibly) perform multiple queries when writing back the data. If there are any that need to be “removed”, they will need a DELETE. Any that have a “trim” operation will require an UPDATE.

new_instance = Thing(
    valid_period=ensure_date_range(('2019-01-01', '2019-02-09'))
)
overlapping = Thing.objects.overlapping(
  new_instance.valid_period
).with_overlap_type(new_instance.valid_period)

overlapping.filter(overlap_type='replace').delete()
overlapping.filter(
    overlap_type__in=('trim:upper', 'trim:lower')
).update(
    valid_period=models.F('valid_period') - new_instance.valid_period
)

But the tricky part is that any that are “split” will require at least two: either a DELETE followed by an INSERT (that inserts two rows), or a single UPDATE and a single INSERT. Complicating this further, we would also need to read the values first if we are going to manipulate them in Python. Instead, we can look at how to do it in raw SQL, with the benefit that we can perform this in a single operation.

WITH new_period AS (
  SELECT %s AS new_period
),
split AS (
  SELECT thing_id,
         valid_period,
         other_field,
         new.new_period
    FROM thing_thing old
    INNER JOIN new_period new ON (
          LOWER(old.valid_period) < LOWER(new.new_period)
      AND UPPER(old.valid_period) > UPPER(new.new_period)
    )
), new_rows AS (
  SELECT other_field,
         DATERANGE(LOWER(valid_period),
                   LOWER(new_period)) AS valid_period
    FROM split

   UNION ALL

  SELECT other_field,
         DATERANGE(UPPER(new_period),
                   UPPER(valid_period)) AS valid_period
    FROM split
),
removed AS (
  DELETE FROM thing_thing
   WHERE thing_id IN (SELECT thing_id FROM split)
)
INSERT INTO thing_thing (other_field, valid_period)
SELECT other_field, valid_period FROM new_rows;

This is less than ideal, because we need to enumerate all of the fields (instead of just other_field), so this code is not especially reusable as-is.

Let’s look at alternatives:

# Fetch the existing items.
splits = list(overlapping.filter(overlap_type='split').values())
to_create = []
to_delete = []
for overlap in splits:
    to_delete.append(overlap.pop('thing_id'))
    valid_period = overlap.pop('valid_period')
    to_create.append(Thing(
        valid_period=(valid_period.lower, new_instance.valid_period.lower),
        **overlap
    ))
    to_create.append(Thing(
        valid_period=(new_instance.valid_period.upper, valid_period.upper),
        **overlap
    ))
overlapping.filter(pk__in=to_delete).delete()
Thing.objects.bulk_create(to_create)

We can stick all of that into a queryset method, to make it easier to manage.

import copy


class OverlappingQuerySet(models.query.QuerySet):
    def trim_overlapping(self, period):
        """
        Trim/split/remove all overlapping objects.

        * Remove objects in the queryset that are
          "covered" by the period.
        * Split objects that completely cover the
          new period with overlap at both sides
        * Trim objects that intersect with the new
          period and extend in one direction or the
          other, but not both.

        This will do a single query to trim object that need
        trimming, another query that fetches those that need
        splitting, a single delete query to remove all
        split/replaced objects, and finally an optional query
        to create replacement objects for those split.

        That means this method _may_ perform 3 or 4 queries.

        This particular algorithm should work without a
        transaction needing to be present, but in practice
        this action and the create of a new one should be
        in the same transaction, so they can all roll-back
        if anything goes wrong.
        """
        period = ensure_date_range(period)

        overlapping = self.overlapping(period)\
                          .with_overlap_type(period)

        # Easy first: update those that we can just update.
        overlapping.filter(
            overlap_type__startswith='trim'
        ).update(
            valid_period=models.F('valid_period') - period
        )

        # Create the new objects for each of the ones that
        # extend either side of the new value.
        # There will always be two of them: one for the lower
        # section, and one for the upper section.
        to_create = []
        for instance in overlapping.filter(overlap_type='split'):
            # Setting the primary key to None will trigger a new
            # instance.
            instance.pk = None
            # We need to create two instances, each with a different
            # valid_period.
            valid_period = instance.valid_period
            # The one _before_ the new value.
            instance.valid_period = DateRange(
                valid_period.lower, period.lower, bounds='[)'
            )
            to_create.append(instance)
            # And a new copy to go _after_ the new value.
            instance = copy.deepcopy(instance)
            instance.valid_period = DateRange(
                period.upper, valid_period.upper, bounds='(]'
            )
            to_create.append(instance)


        # Now clean up any that we need to get rid of.
        overlapping.filter(
            overlap_type__in=('replace', 'split')
        ).delete()

        # And finally add back in any replacement objects
        # that extended either side of the new value.
        if to_create:
            self.model._default_manager.bulk_create(to_create)
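
In use, this might look something like the following (a sketch only, assuming Thing uses OverlappingQuerySet as its manager, and new_instance is the object we are about to save):

from django.db import transaction

with transaction.atomic():
    # Trim/split/remove anything that clashes with the new period,
    # then store the new object, all in one transaction.
    Thing.objects.filter(
        other_field=new_instance.other_field,
    ).trim_overlapping(new_instance.valid_period)
    new_instance.save()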

Yeah, I think that will do for now.

Fallback values in Django

It’s not uncommon to have some type of cascading of values in a system. For instance, in our software, we allow a Brand to have some default settings, and then a Location may override some or all of these settings, or just fall back to the brand settings. I’m going to have a look at how this type of thing can be implemented using Django, and a way that this can be handled seamlessly.

We’ll start with our models:

class Brand(models.Model):
    brand_id = models.AutoField(primary_key=True)
    name = models.TextField()


class Location(models.Model):
    location_id = models.AutoField(primary_key=True)
    brand = models.ForeignKey(Brand, related_name='locations', on_delete=models.CASCADE)
    name = models.TextField()


WEEKDAYS = [
  (1, _('Monday')),
  (2, _('Tuesday')),
  (3, _('Wednesday')),
  (4, _('Thursday')),
  (5, _('Friday')),
  (6, _('Saturday')),
  (7, _('Sunday')),
]


class BrandSettings(models.Model):
    brand = models.OneToOneField(Brand, primary_key=True, related_name='settings', on_delete=models.CASCADE)
    opening_time = models.TimeField()
    closing_time = models.TimeField()
    start_day = models.IntegerField(choices=WEEKDAYS)


class LocationSettings(models.Model):
    location = models.OneToOneField(Location, primary_key=True, related_name='_raw_settings', on_delete=models.CASCADE)
    opening_time = models.TimeField(null=True, blank=True)
    closing_time = models.TimeField(null=True, blank=True)
    start_day = models.IntegerField(choices=WEEKDAYS, null=True, blank=True)

We can’t use an abstract base model here, because the LocationSettings values are all optional, but the BrandSettings are not. We might have a look later at a way we can have a base model and inherit-and-change-null on the fields. In the place where we have used this, the relationship between Location and Brand is optional, which complicates things even further.

In practice, we’d have a bunch more settings, but this will make it much easier for us to follow what is going on.

To use these, we want to use a value from the LocationSettings object if it is set, else fall back to the BrandSettings value for that column.

Location.objects.annotate(
    opening_time=Coalesce('_raw_settings__opening_time', 'brand__settings__opening_time'),
    closing_time=Coalesce('_raw_settings__closing_time', 'brand__settings__closing_time'),
    start_day=Coalesce('_raw_settings__start_day', 'brand__settings__start_day'),
)

And this is fine, but we can make it easier to manage: we want to be able to use Location().settings.start_day, and have that fall-back, but also build some niceness so that we can set values in a nice way in the UI.

We can use a postgres view, and then have a model in front of that:

CREATE OR REPLACE VIEW location_actualsettings AS (
  SELECT location_id,
         COALESCE(location.opening_time, brand.opening_time) AS opening_time,
         COALESCE(location.closing_time, brand.closing_time) AS closing_time,
         COALESCE(location.start_day, brand.start_day) AS start_day
    FROM location_location
   INNER JOIN location_brandsettings brand USING (brand_id)
   INNER JOIN location_locationsettings location USING (location_id)
)

Notice that we have used INNER JOIN for both tables: we are making the assumption that there will always be a settings object for each brand and location.

Now, we want a model in front of this:

class ActualSettings(models.Model):
    location = models.OneToOneField(Location, primary_key=True, related_name='settings', on_delete=models.DO_NOTHING)
    opening_time = models.TimeField(null=True, blank=True)
    closing_time = models.TimeField(null=True, blank=True)
    start_day = models.IntegerField(choices=WEEKDAYS, null=True, blank=True)

    class Meta:
        managed = False

We want to indicate that it should allow NULL values in the columns, as when we go to update it, None will be taken to mean “use the brand default”.

As for the ability to write to this model, we have a couple of options. The first is to make sure that when we edit instances of the model, we actually use the Location()._raw_settings instance instead of the Location().settings. The other is to make the ActualSettings view have an update trigger:

CREATE OR REPLACE FUNCTION update_location_settings()
RETURNS TRIGGER AS $$

BEGIN

  IF (TG_OP = 'DELETE') THEN
    RAISE NOTICE 'DELETE FROM location_locationsettings WHERE location_id = %', OLD.location_id;
    DELETE FROM location_locationsettings WHERE location_id = OLD.location_id;
    RETURN OLD;
  ELSIF (TG_OP = 'UPDATE') THEN
    UPDATE location_locationsettings
       SET opening_time = NEW.opening_time,
           closing_time = NEW.closing_time,
           start_day = NEW.start_day
     WHERE location_locationsettings.location_id = NEW.location_id;
    RETURN NEW;
  ELSIF (TG_OP = 'INSERT') THEN
    INSERT INTO location_locationsettings (SELECT NEW.*);
    RETURN NEW;
  END IF;
  RETURN NEW;
END;

$$ LANGUAGE plpgsql VOLATILE;

CREATE TRIGGER update_location_settings
       INSTEAD OF INSERT OR UPDATE OR DELETE
       ON location_actualsettings
       FOR EACH ROW EXECUTE PROCEDURE update_location_settings();

And this works as expected: however it is subject to a pretty significant drawback. If you add columns to the table/view, then you’ll need to update the function. Indeed, if you add columns to the tables, you’ll need to update the view too.

In many cases, this will be sufficient: those tables may not change much, and when they do, it’s just a matter of writing new migrations to update the view and function.


In practice, having the writeable view is probably overkill. You can just use a regular view, with a model in front of it, and then use that model when you need to use the coalesced values, but use the raw model when you are setting values.
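
In code, that split of responsibilities might look something like this (a sketch; the location_id lookup is purely illustrative):

# Reads go through the view-backed model, so we get the coalesced values.
location = Location.objects.select_related('settings').get(pk=location_id)
opening = location.settings.opening_time  # coalesced value from the view

# Writes go to the raw model: None means "fall back to the brand default".
# This assumes the LocationSettings row already exists.
raw = LocationSettings.objects.get(location=location)
raw.opening_time = None
raw.save()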

You can even add a UI affordance: instead of showing an empty value, show what the brand fallback value would be:

import datetime

from django import forms
from django.template import Context, Template
from django.utils.translation import gettext_lazy as _


class SettingsForm(forms.ModelForm):
    class Meta:
        model = LocationSettings
        fields = (
            'opening_time',
            'closing_time',
            'start_day'
        )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # We'll probably want to make sure we use a select_related() for this!
        brand = self.instance.location.brand
        brand_settings = brand.settings

        for name, field in self.fields.items():
            # See if the model knows how to display a nice value.
            display = 'get_{}_display'.format(name)
            if hasattr(brand_settings, display):
                brand_value = getattr(brand_settings, display)()
            else:
                brand_value = getattr(brand_settings, name)

            # If we have a time, then we want to format it nicely:
            if isinstance(brand_value, datetime.time):
                brand_value = Template('{{ value }}').render(Context({
                  'value': brand_value
                }))

            blank_label = _('Default for {brand}: {value}').format(
                brand=brand.name,
                value=brand_value,
            )

            # If we have a select that is _not_ a multiple select, then we
            # want to make it obvious that the brand default value can be
            # selected, or an explicit choice made.
            if hasattr(field, 'choices') and field.choices[0][0] == '':
                field.widget.choices = field.choices = [
                    (_('Brand default'), [('', blank_label)]),
                    (_('Choices'), list(field.choices[1:]))
                ]
            else:
                # On all other fields, set the placeholder, so that no value
                # entered will show the brand default label.
                field.widget.attrs['placeholder'] = blank_label

As mentioned in a comment in the code above, this uses a couple of lookups to get to the BrandSettings, so you’d want to make sure your view uses a .select_related():

class LocationSettingsView(UpdateView):
    form_class = SettingsForm

    def get_object(self):
        return LocationSettings.objects.select_related('location__brand__settings').get(
            location=self.kwargs['location']
        )

Again, this is all simplified when we have the requirement that there is always a Brand associated with a Location, and each of these always has a related settings object. It’s the latter part of this that is a little tricky. You can have objects automatically created in a signal handler, but in that case it would have to use default values.
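
A minimal sketch of the signal-handler approach might look like the following (the default values are made up, and would need to make sense for your domain):

import datetime

from django.db.models.signals import post_save
from django.dispatch import receiver


@receiver(post_save, sender=Brand)
def create_brand_settings(sender, instance, created, **kwargs):
    if created:
        BrandSettings.objects.create(
            brand=instance,
            # Made-up defaults: the BrandSettings fields are not nullable.
            opening_time=datetime.time(9),
            closing_time=datetime.time(17),
            start_day=1,
        )


@receiver(post_save, sender=Location)
def create_location_settings(sender, instance, created, **kwargs):
    if created:
        # All LocationSettings fields are nullable, so no defaults needed.
        LocationSettings.objects.create(location=instance)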


Just from a DRY perspective, it would be great if you could have all three models inherit from the one base class, and have the view and trigger function update automatically.

In order to do that, we’ll need to do a bit of magic.

class SettingsBase(models.Model):
    opening_time = models.TimeField()
    closing_time = models.TimeField()
    start_day = models.IntegerField(choices=WEEKDAYS)

    class Meta:
        abstract = True

    def __init_subclass__(cls):
        if getattr(cls, '_settings_optional', False):
            for field in cls._meta.fields:
                field.null = True
                field.blank = True


class BrandSettings(SettingsBase):
    brand = models.OneToOneField(
        Brand,
        primary_key=True,
        related_name='settings',
        on_delete=models.CASCADE,
    )


class LocationSettings(SettingsBase):
    location = models.OneToOneField(
        Location,
        primary_key=True,
        related_name='_raw_settings',
        on_delete=models.CASCADE,
    )
    _settings_optional = True


class ActualSettings(SettingsBase):
    location = models.OneToOneField(
        Location,
        primary_key=True,
        related_name='settings',
        on_delete=models.DO_NOTHING,
    )
    _settings_optional = True

    class Meta:
        managed = False

The magic is all clustered in the one spot, and the order in which Django does things makes this easy. By the time __init_subclass__ is evaluated, the subclass exists, and has all of the inherited fields, but none of the non-inherited fields. So, we can update those fields to not be required, if we find a class attribute _settings_optional that is true.

Automatically creating or replacing the view is a bit more work.

class ActualSettings(SettingsBase):
    location = models.OneToOneField(
        Location,
        primary_key=True,
        related_name='settings',
        on_delete=models.DO_NOTHING,
    )
    _settings_optional = True

    class Meta:
        managed = False

    @classmethod
    def view_queryset(cls):
        settings = {
            attribute: Coalesce(
              '_raw_settings__{}'.format(attribute),
              'brand__settings__{}'.format(attribute)
            ) for attribute in (f.name for f in cls._meta.fields)
            if attribute != 'location'
        }
        return Location.objects.annotate(**settings).values('pk', *settings.keys())

This would then need some extra machinery to put that into a migration, and then, when running makemigrations, we’d want to automatically look at the last rendered version of that view, and see if what we have now differs. However, intercepting makemigrations and changing the operations it creates is something I have not yet figured out how to achieve.
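
Doing it by hand is straightforward enough, though: something like this in a migration (a sketch only, not the machinery described above; the app label and dependency name are made up):

from django.db import migrations

from location.models import ActualSettings  # hypothetical app label


def create_view(apps, schema_editor):
    # Render the queryset's SQL and wrap it in a CREATE VIEW. This uses the
    # real model class (not apps.get_model), because view_queryset() is a
    # classmethod that is not available on historical models.
    sql, params = ActualSettings.view_queryset().query.sql_with_params()
    schema_editor.execute(
        'CREATE OR REPLACE VIEW location_actualsettings AS ({})'.format(sql),
        params,
    )


class Migration(migrations.Migration):
    dependencies = [('location', '0002_actualsettings')]  # hypothetical

    operations = [
        migrations.RunPython(create_view, migrations.RunPython.noop),
    ]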

Instead, for Versioning complex database migrations I wound up creating a new management command.

A nicer syntax might be to have some way of defining a postgres view by using a queryset.

ActualSettings = Location.objects.annotate(
    opening_time=Coalesce('_raw_settings__opening_time', 'brand__settings__opening_time'),
    closing_time=Coalesce('_raw_settings__closing_time', 'brand__settings__closing_time'),
    start_day=Coalesce('_raw_settings__start_day', 'brand__settings__start_day'),
).values('location_id', 'opening_time', 'closing_time', 'start_day').as_view()

The problem with this is that we can’t do that in a model definition, as the other models are not loaded at this point in time.

Another possible syntax could be:

class ActualSettings(View):
    location = models.F('location_id')
    opening_time = Coalesce('_raw_settings__opening_time', 'brand__settings__opening_time')
    closing_time = Coalesce('_raw_settings__closing_time', 'brand__settings__closing_time')
    start_day = Coalesce('_raw_settings__start_day', 'brand__settings__start_day')

    class Meta:
      queryset = Location.objects.all()

… but I’m starting to veer off into a different topic now.


Actually writing a trigger function that handles all columns seamlessly is something that we should be able to do. Be warned though, this one is a bit of a doozy:

CREATE OR REPLACE FUNCTION update_instead()
RETURNS TRIGGER AS $$
DECLARE
  primary_key TEXT;
  target_table TEXT;
  columns TEXT;

BEGIN
  -- You must pass as first parameter the name of the table to which writes should
  -- actually be made.
  target_table = TG_ARGV[0]::TEXT;

  -- We want to get the name of the primary key column for the target table,
  -- if that was not already supplied.
  IF (TG_ARGV[1] IS NULL) THEN
    primary_key = (SELECT column_name
                     FROM information_schema.table_constraints
               INNER JOIN information_schema.constraint_column_usage
                    USING (table_catalog, table_schema, table_name,
                           constraint_name, constraint_schema)
                    WHERE constraint_type = 'PRIMARY KEY'
                      AND table_schema = quote_ident(TG_TABLE_SCHEMA)
                      AND table_name = quote_ident(target_table));
  ELSE
    primary_key = TG_ARGV[1]::TEXT;
  END IF;

  -- We also need the names of all of the columns in the current view.
  columns = (SELECT STRING_AGG(quote_ident(column_name), ', ')
               FROM information_schema.columns
              WHERE table_schema = quote_ident(TG_TABLE_SCHEMA)
                AND table_name = quote_ident(TG_TABLE_NAME));

  IF (TG_OP = 'DELETE') THEN
    EXECUTE format(
      'DELETE FROM %1$I WHERE %2$I = ($1).%2$I',
      target_table, primary_key
    ) USING OLD;
    RETURN OLD;
  ELSIF (TG_OP = 'INSERT') THEN
    -- columns must be treated as a string, because we've already
    -- quoted the columns in the query above.
    EXECUTE format(
      'INSERT INTO %1$I (%2$s) (SELECT ($1).*)',
      target_table, columns
    ) USING NEW;
    RETURN NEW;
  ELSIF (TG_OP = 'UPDATE') THEN
    EXECUTE format(
      'UPDATE %1$I SET (%2$s) = (SELECT ($1).*) WHERE %3$I = ($1).%3$I',
      target_table, columns, primary_key
    ) USING NEW;
    RETURN NEW;
  END IF;

  RAISE EXCEPTION 'Unhandled.';
END;

$$ LANGUAGE plpgsql VOLATILE;

There are some things I learned about postgres when doing this: specifically that you can use the EXECUTE format('SELECT ... ($1).%s', arg) USING NEW syntax: the format() function makes it much neater than using string concatenation, and using the EXECUTE '...($1).%s' USING ... form was the only way I was able to access the values from the NEW and OLD aliases within an execute. There’s also a bunch of stuff you have to do to make sure that the columns line up correctly when updating or inserting into the target table.

We can then apply this to our view:

CREATE TRIGGER update_instead
INSTEAD OF UPDATE OR INSERT OR DELETE
ON location_actualsettings
FOR EACH ROW
EXECUTE PROCEDURE update_instead('location_locationsettings', 'location_id');

Django ComputedField()

A very common pattern, at least in code that I’ve written (and read) is to annotate on a field that uses an expression that is based on one or more other fields. This could then be used to filter the objects, or just in some other way.

The usual method of doing this is:

from django.db import models
from django.db.models.expressions import F, Value
from django.db.models.functions import Concat


class PersonQuerySet(models.query.QuerySet):
    def with_name(self):
        return self.annotate(
            name=Concat(F('first_name'), Value(' '), F('last_name'), output_field=models.TextField()),
        )


class Person(models.Model):
    first_name = models.TextField()
    last_name = models.TextField()

    objects = PersonQuerySet.as_manager()

Yes, I’m aware of falsehoods programmers believe about names, but this is an easy-to-follow example.

In order to be able to access the name field, we must use the with_name() queryset method. This is usually okay, but if it is something that we almost always want, it can be a little tiresome. Alternatively, you could override the get_queryset() method of a custom manager, but that makes it somewhat surprising to a reader of the code. There are also some places where a custom manager will not automatically be used, or where it will be cumbersome to include the fields from a custom manager (select_related, for instance).
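
For example, using the model above:

# With the queryset method, the annotation is present and filterable.
Person.objects.with_name().filter(name__startswith='Jane')

# Without it, there is no 'name' field, and this raises a FieldError.
Person.objects.filter(name__startswith='Jane')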

It would be much nicer if we could write the field declaratively, and have it use the normal django mechanism of defer and only to remove it from the query if required.

class Person(models.Model):
    first_name = models.TextField()
    last_name = models.TextField()
    name = ComputedField(Concat(F('first_name'), Value(' '), F('last_name'), output_field=models.TextField()))

I’ve spent some time digging around in the django source code, and have a fairly reasonable understanding of how fields work, and how queries are built up. But I did wonder how close to a working proof of concept of this type of field we could get without having to change any of the django source code. After all, I was able to backport the entire Subquery expression stuff to older versions of django after writing that. It would be nice to repeat the process here.

There are a few things you need to do to get this to work:

  • store the expression
  • prevent the field from creating a migration
  • ensure the field knows how to interpret data from the database
  • ensure the field adds the expression to its serialised version
  • prevent the field from writing data back to the database
  • inject the expression into the query instead of the field name

class ComputedField(models.Field):
    def __init__(self, expression, *args, **kwargs):
        self.expression = expression.copy()
        kwargs.update(editable=False)
        super().__init__(*args, **kwargs)

There is already a mechanism for a field to prevent a migration operation from being generated: it can return a db_type of None.

    def db_type(self, connection):
        return None

We can delegate the responsibility of interpreting the data from the database to the output field of the expression - that’s how it works in the normal operation of expressions.

    def from_db_value(self, value, expression, connection):
        return self.expression.output_field.from_db_value(value, expression, connection)

Storing the expression in the serialised version of a field is explained in the documentation on custom fields:

    def deconstruct(self):
        name, path, args, kwargs = super().deconstruct()
        return name, path, [self.expression] + args, kwargs

To prevent the field from being included in the data we write back to the database turned out to be fairly tricky. There are a couple of mechanisms that can be used, but ultimately the only one that worked in the way I needed was something that is used by the inheritance mechanism. We have to indicate that it is a “private” field. I’m not 100% sure of what the other implications of this might be, but the outcome of making this field private is that it no longer appears in the list of local fields. There is one drawback to this, which I’ll discuss below.

    def contribute_to_class(self, cls, name, private_only=False):
        return super().contribute_to_class(cls, name, True)

So, we only have one task to complete. How do we inject the expression into the query instead of the column?

When django evaluates a queryset, it looks at the annotations, and the expressions that are in these. It will then “resolve” these expressions (which means the expression gets told which “query” is being used to evaluate it, allowing it to do whatever it needs to do to make things work).

When a regular field is encountered, it is not resolved: instead it is turned into a Col. This happens in a few different places, but the problem is that a Col should not need to know which query it belongs to: at most it needs to know what the aliased table name is. So, we don’t have a query object we can pass to the resolve_expression method of our expression.

Instead, we’ll need to use Python’s introspection to look up the stack until we find a place that has a reference to this query.

    def get_col(self, alias, output_field=None):
        import inspect

        query = None

        for frame in inspect.stack():
            if frame.function in ['get_default_columns', 'get_order_by']:
                query = frame.frame.f_locals['self'].query
                break
            if frame.function in ['add_fields', 'build_filter']:
                query = frame.frame.f_locals['self']
                break
        else:
            # Aaargh! We don't handle this one yet!
            import pdb; pdb.set_trace()

        col = self.expression.resolve_expression(query=query)
        col.target = self
        return col

So, how does this code actually work? We go through each frame in the stack, and look for a function (or method, but they are really just functions in python) that matches one of the types we know about that have a reference to the query. Then, we grab that, stop iterating and resolve our expression. We have to set the “target” of our resolved expression to the original field, which is how the Col interface works.

This moves the resolve_expression into the get_col, which is where it needs to be. The (resolved) expression is used as the faked column, and it knows how to generate its own SQL, which will be put into the query in the correct location.

And this works, almost.

There is one more situation that needs to be taken into account: when we are referencing the field through a join (the x__y lookup syntax you often see in django filters).

Because F() expressions reference the local query, we need to first turn any of these that we find in our computed field’s expression (at any level) into a Col that refers to the correct model. We need to do this before the resolve_expression takes place.

    def get_col(self, alias, output_field=None):
        query = None

        for frame in inspect.stack():
            if frame.function in ['get_default_columns', 'get_order_by']:
                query = frame.frame.f_locals['self'].query
                break
            if frame.function in ['add_fields', 'build_filter']:
                query = frame.frame.f_locals['self']
                break
        else:
            # Aaargh! We don't handle this one yet!
            import pdb; pdb.set_trace()

        def resolve_f(expression):
            if hasattr(expression, 'get_source_expressions'):
                expression = expression.copy()
                expression.set_source_expressions([
                  resolve_f(expr) for expr in expression.get_source_expressions()
                ])
            if isinstance(expression, models.F):
                field = self.model._meta.get_field(expression.name)
                if hasattr(field, 'expression'):
                    return resolve_f(field.expression)
                return Col(alias, field)
            return expression

        col = resolve_f(self.expression).resolve_expression(query=query)
        col.target = self
        return col

There is a repo containing this, which has a bunch of tests showing how the different query types can use the computed field:

https://github.com/schinckel/django-computed-field
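
Usage then looks just like any other field (a sketch based on the Person example above):

Person.objects.create(first_name='Jane', last_name='Doe')

Person.objects.filter(name='Jane Doe')           # filter on the computed value
Person.objects.order_by('name')                  # order by it
Person.objects.values_list('name', flat=True)    # or select it directly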


But wait, there is one more thing…

A very common requirement, especially if you are planning on using this column for filtering, would be to stick an index on there.

Unfortunately, that’s not currently possible: the mechanism for preventing the field name from being in the write queries, making it a private field, prevents using this field in an index. Anyway, function/expression indexes are not currently supported in Django.

It’s not all bad news though: Markus has a Pull Request that will enable this feature; from there we could (if db_index is set) automatically add an expression index to Model._meta.indexes in contribute_to_class, but it would also be great to be able to use index_together.

I suspect to get that, though, we’ll need another mechanism to prevent it being in the write queries, but still have it be a local field.

(Thanks to FunkyBob for suggestions, including suggesting the field at all).

Django bulk_update without upsert

Postgres 9.5 brings a fantastic feature that I’ve really been looking forward to: upsert, in the form of INSERT … ON CONFLICT DO UPDATE. However, I’m not on 9.5 in production yet, and I had a situation that would really have benefitted from being able to use it.

I have to insert lots of objects, but if there is already an object in a given “slot”, then I need to instead update that existing object.

Doing this using the Django ORM can be done on a “one by one” basis, by iterating through the objects, finding which one (if any) matches the criteria, updating that, or creating a new one if there wasn’t a match.

However, this is really slow, as it does two queries for each object.
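
Something like this, which is fine for a handful of objects, but not for thousands (the model and field names are only for illustration):

# Two queries per row: one SELECT, then an UPDATE or an INSERT.
for row in incoming_rows:
    Thing.objects.update_or_create(
        slot=row['slot'],   # whatever identifies the matching "slot"
        defaults={'foo': row['foo'], 'bar': row['bar']},
    )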

Instead, it would be great to:

  • fetch all of the instances that could possibly overlap (keyed by the matching criteria)
  • iterate through the new data, looking for a match
    • modify the instance if an existing match is made, and stash into pile “update”
    • create a new instance if no match is found, and stash into the pile “create”
  • bulk_update all of the “update” objects
  • bulk_create all of the “create” objects
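
Sketching that plan out (again with made-up model and field names, and using the bulk_update we are about to write):

def plan_upsert(incoming_rows):
    # Fetch everything that could match, keyed by the matching criteria.
    existing = {
        obj.slot: obj
        for obj in Thing.objects.filter(slot__in=[row['slot'] for row in incoming_rows])
    }
    to_update, to_create = [], []
    for row in incoming_rows:
        match = existing.get(row['slot'])
        if match:
            match.foo = row['foo']
            match.bar = row['bar']
            to_update.append(match)
        else:
            to_create.append(Thing(**row))

    bulk_update(Thing, to_update, 'foo', 'bar')
    Thing.objects.bulk_create(to_create)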

Those familiar with Django may recognise that there is only one step here that cannot be done as of “now”.

So, how can we do a bulk update?

There are two ways I can think of doing it (at least with Postgres):

  • create a temporary table (cloning the structure of the table)
  • insert all of the data into this table
  • update the rows in the original table from the temporary table, based on pk column

and:

  • come up with some mechanism of using the UPDATE the_table SET ... FROM () sq WHERE sq.pk = the_table.pk syntax

It’s possible to use some of the really nice features of Postgres to create a temporary table, that clones an existing table, and will automatically be dropped at the end of the transaction:

BEGIN;

CREATE TEMPORARY TABLE upsert_source (LIKE my_table INCLUDING ALL) ON COMMIT DROP;

-- Bulk insert into upsert_source

UPDATE my_table
   SET foo = upsert_source.foo,
       bar = upsert_source.bar
  FROM upsert_source
 WHERE my_table.id = upsert_source.id;

The drawback of this approach is that it does two extra queries, but it is possible to implement fairly simply:

from django.db import transaction, connection

@transaction.atomic
def bulk_update(model, instances, *fields):
    cursor = connection.cursor()
    db_table = model._meta.db_table

    try:
        cursor.execute(
            'CREATE TEMPORARY TABLE update_{0} (LIKE {0} INCLUDING ALL) ON COMMIT DROP'.format(db_table)
        )

        model._meta.db_table = 'update_{}'.format(db_table)
        model.objects.bulk_create(instances)

        query = ' '.join([
            'UPDATE {table} SET ',
            ', '.join(
                ('%(field)s=update_{table}.%(field)s' % {'field': field})
                for field in fields
            ),
            'FROM update_{table}',
            'WHERE {table}.{pk}=update_{table}.{pk}'
        ]).format(
            table=db_table,
            pk=model._meta.pk.get_attname_column()[1]
        )
        cursor.execute(query)
    finally:
        model._meta.db_table = db_table

The advantage of this is that it mostly just uses the ORM. There’s limited scope for SQL injection (although you’d probably want to validate the field names).
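
Calling it might look like this (again, the model, field and helper names are only illustrative):

things = list(Thing.objects.filter(needs_recalculation=True))
for thing in things:
    thing.foo = calculate_foo(thing)   # some domain-specific recalculation

bulk_update(Thing, things, 'foo')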

It’s also possible to do the update directly from a subquery, but without the nice column names:

UPDATE my_table
   SET foo = upsert_source.column2,
       bar = upsert_source.column3
  FROM (
    VALUES (...), (...)
  ) AS upsert_source
 WHERE upsert_source.column1 = my_table.id;

Note that you must make sure your values are in the correct order (with the primary key first).

Attempting to prevent some likely SQL injection vectors, we want to build up the fixed parts of the query (and the parts that are controlled by the django model, like the table and field names), and then pass the values in as query parameters.

from django.db import connection

def bulk_update(model, instances, *fields):
    set_fields = ', '.join(
        ('%(field)s=update_{table}.column%(i)s' % {'field': field, 'i': i + 2})
        for i, field in enumerate(fields)
    )
    value_placeholder = '({})'.format(', '.join(['%s'] * (len(fields) + 1)))
    values = ','.join([value_placeholder] * len(instances))
    query = ' '.join([
        'UPDATE {table} SET ',
        set_fields,
        'FROM (VALUES ', values, ') update_{table}',
        'WHERE {table}.{pk} = update_{table}.column1'
    ]).format(table=model._meta.db_table, pk=model._meta.pk.get_attname_column()[1])
    params = []
    for instance in instances:
        params.append(instance.pk)
        for field in fields:
            params.append(getattr(instance, field))

    connection.cursor().execute(query, params)

This feels like a reasonable first draft, however I’d probably want to go look at how the query for bulk_create is created, and modify that. There’s a fair bit going on there that I haven’t followed as yet though. Note that this does not need the @transaction.atomic decorator, as it is only a single statement.

From here, we can build an upsert that assumes all objects with a PK need to be updated, and those without need to be inserted:

from django.utils.functional import partition
from django.db import transaction

@transaction.atomic
def bulk_upsert(model, instances, *fields):
    update, create = partition(lambda obj: obj.pk is None, instances)
    if update:
        bulk_update(model, update, *fields)
    if create:
        model.objects.bulk_create(create)

Django second AutoField

Sometimes, your ORM just seems to be out to get you.

For instance, I’ve been investigating a technique for the most important data structure in a system to be essentially immutable.

That is, instead of updating an existing instance of the object, we always create a new instance.

This requires a handful of things to be useful (and useful for querying).

  • We probably want to have a self-relation so we can see which object supersedes another. A series of objects that supersede one another is called a lifecycle.
  • We want to have a timestamp on each object, so we can view a snapshot at a given time: that is, which phase of the lifecycle was active at that point.
  • We should have a column that is unique per-lifecycle: this makes querying for all objects of a lifecycle much simpler (although we could use a recursive query for that).
  • There must be a facility to prevent multiple heads on a lifecycle: that is, at most one phase of a lifecycle may be non-superseded.
  • The lifecycle phases needn’t be in the same order, or really have any differentiating features (like status). In practice they may, but for the purposes of this, they are just “what it was like at that time”.

I’m not sure these ideas will ever get into a released product, but the work behind them was fun (and all my private work).

The basic model structure might look something like:

class Phase(models.Model):
    phase_id = models.AutoField(primary_key=True)
    lifecycle_id = models.AutoField(primary_key=False, editable=False)

    superseded_by = models.OneToOneField('self',
        related_name='supersedes',
        null=True, blank=True, editable=False
    )
    timestamp = models.DateTimeField(auto_now_add=True)

    # Any other fields you might want...

    objects = PhaseQuerySet.as_manager()

So, that looks nice and simple.

Our second AutoField will have a sequence generated for it, and the database will give us a unique value from a sequence when we try to create a row in the database without providing this column in the query.

However, there is one problem: Django will not let us have a second AutoField in a model. And, even if it did, there would still be some problems: for instance, AutoField values are never sent to the database when creating a new instance, which would break our ability to keep the same lifecycle_id between phases.

So, we will need a custom field. Luckily, all we really need is the SERIAL database type: that creates the sequence for us automatically.

class SerialField(models.Field):
    def db_type(self, connection):
        return 'serial'

So now, using that field type instead, we can write a bit more of our model:

class Phase(models.Model):
    phase_id = models.AutoField(primary_key=True)
    lifecycle_id = SerialField(editable=False)
    superseded_by = models.OneToOneField('self', ...)
    timestamp = models.DateTimeField(auto_now_add=True)

    def save(self, **kwargs):
        self.pk = None
        super(Phase, self).save(**kwargs)

This now ensures each time we save our object, a new instance is created. The lifecycle_id will stay the same.

Still not totally done though. We currently aren’t handling a newly created lifecycle (which should be handled by the associated postgres sequence), nor are we marking the previous instance as superseded.

It’s possible, using some black magic, to get the default value for a database column, and, in this case, execute a query with that default to get the next value. However, that’s pretty horrid: not to mention it also runs an extra two queries.

Similarly, we want to get the phase_id of the newly created instance, and set that as the superseded_by of the old instance. This would require yet another query, after the INSERT, but also has the sinister side-effect of making us unable to apply the not-superseded-by-per-lifecycle requirement.

As an aside, we can investigate storing the self-relation on the other end - this would enable us to just do:

    def save(self, **kwargs):
        self.supersedes = self.pk
        self.pk = None
        super(Phase, self).save(**kwargs)

However, this turns out to be less useful when querying: we are much more likely to be interested in phases that are not superseded, as they are the “current” phase of each lifecycle. Although we could still query for those, it would mean running a sub-query for each row.

Our two issues: setting the lifecycle, and storing the superseding data, can be done with one Postgres BEFORE INSERT trigger function:

CREATE FUNCTION lifecycle_and_supersedes()
RETURNS TRIGGER AS $$

  BEGIN
    IF NEW.lifecycle_id IS NULL THEN
      NEW.lifecycle_id = nextval('app_phase_lifecycle_id_seq'::regclass);
    ELSE
      NEW.phase_id = nextval('app_phase_phase_id_seq'::regclass);
      UPDATE app_phase
        SET superseded_by_id = NEW.phase_id
        WHERE lifecycle_id = NEW.lifecycle_id
        AND superseded_by_id IS NULL;
    END IF;
    RETURN NEW;
  END;

$$ LANGUAGE plpgsql VOLATILE;

CREATE TRIGGER lifecycle_and_supersedes
  BEFORE INSERT ON app_phase
  FOR EACH ROW
  EXECUTE PROCEDURE lifecycle_and_supersedes();

So, now all we need to do is prevent multiple-headed lifecycles. We can do this using a UNIQUE INDEX:

CREATE UNIQUE INDEX prevent_hydra_lifecycles
ON app_phase (lifecycle_id)
WHERE superseded_by_id IS NULL;

Wow, that was simple.

So, we have most of the db-level code written. How do we use our model? We can write some nice queryset methods to make getting the various bits easier:

class PhaseQuerySet(models.query.QuerySet):
    def current(self):
        return self.filter(superseded_by=None)

    def superseded(self):
        return self.exclude(superseded_by=None)

    def initial(self):
        return self.filter(supersedes=None)

    def snapshot_at(self, timestamp):
        return self.filter(timestamp__lte=timestamp).order_by('lifecycle_id', '-timestamp').distinct('lifecycle_id')

The queries generated by the ORM for these should be pretty good: we could look at sticking an index on the lifecycle_id column.

There is one more thing to say on the lifecycle: we can add a model method to fetch the complete lifecycle for a given phase, too:

    def lifecycle(self):
        return self.model.objects.filter(lifecycle_id=self.lifecycle_id)

(That was why I used the lifecycle_id as the column).
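
Putting those together (a sketch):

import datetime

from django.utils import timezone

# All of the "current" phases: at most one per lifecycle.
Phase.objects.current()

# What did things look like this time last week?
Phase.objects.snapshot_at(timezone.now() - datetime.timedelta(days=7))

# Every phase in the same lifecycle as a given phase.
phase = Phase.objects.current().first()
phase.lifecycle()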


Whilst building this prototype, I came across a couple of things that were also interesting. The first was a mechanism to get the default value for a column:

from django.db import connection

def database_default(table, column):
    cursor = connection.cursor()
    QUERY = """SELECT d.adsrc AS default_value
               FROM   pg_catalog.pg_attribute a
               LEFT   JOIN pg_catalog.pg_attrdef d ON (a.attrelid, a.attnum)
                                                   = (d.adrelid,  d.adnum)
               WHERE  NOT a.attisdropped   -- no dropped (dead) columns
               AND    a.attnum > 0         -- no system columns
               AND    a.attrelid = %s::regclass
               AND    a.attname = %s"""
    cursor.execute(QUERY, [table, column])
    cursor.execute('SELECT {}'.format(*cursor.fetchone()))
    return cursor.fetchone()[0]

You can probably see why I didn’t want to use this. Other than the aforementioned two extra queries, it’s executing a query with data that comes back from the database. It may be possible to inject a default value into a table that causes it to do Very Bad Things™. We could sanitise it, perhaps ensure it matches a regular expression:

NEXTVAL = re.compile(r"^nextval\('(?P<sequence>[a-zA-Z_0-9]+)'::regclass\)$")

However, the trigger-based approach is nicer in every way.

The other thing I discovered, and this one is really nice, is a way to create an exclusion constraint that only applies if a column is NULL. For instance, ensure that no two classes for a given student overlap, but only if they are not superseded (or deleted).

ALTER TABLE "student_enrolments"
ADD CONSTRAINT "prevent_overlaps"
EXCLUDE USING gist(period WITH &&, student_id WITH =)
WHERE (
  superseded_by_id IS NULL
  AND
  status <> 'deleted'
);