Adjacency Lists in Django with Postgres

Today, I’m going to walk through modelling a tree in Django, using an Adjacency List, and a Postgres View that dynamically creates the materialised path of ancestors for each node.

With this, we will be able to query the tree for a range of operations using the Django ORM.

We will start with our model:

class Node(models.Model):
    node_id = models.AutoField(primary_key=True)
    parent = models.ForeignKey('tree.Node', related_name='children', null=True, blank=True)

    class Meta:
        app_label = 'tree'

We will also build an unmanaged model that will be backed by our view.

from django.contrib.postgres.fields import ArrayField

class Tree(models.Model):
    root = models.ForeignKey(Node, related_name='+')
    node = models.OneToOneField(Node, related_name='tree_node', primary_key=True)
    ancestors = ArrayField(base_field=models.IntegerField())

    class Meta:
        app_label = 'tree'
        managed = False

You’ll notice I’ve included a root relation. This could be obtained by using ancestors[0] if ancestors else node_id, but that’s a bit cumbersome.
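
(If you’d rather not add the extra column, a property on the unmanaged model can derive the same value; root_node_id here is just an illustrative name.)

    @property
    def root_node_id(self):
        # Roots have an empty ancestors array, so they are their own root.
        return self.ancestors[0] if self.ancestors else self.node_id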

So, on to the View:

CREATE RECURSIVE VIEW tree_tree(root_id, node_id, ancestors) AS

SELECT node_id, node_id, ARRAY[]::INTEGER[]
FROM tree_node WHERE parent_id IS NULL

UNION ALL

SELECT tree.root_id, node.node_id, tree.ancestors || node.parent_id
FROM tree_node node INNER JOIN tree_tree tree ON (node.parent_id = tree.node_id);

I’ve written this view before, so I won’t go into any detail.

We can now create a tree. Normally I wouldn’t specify the primary keys, but since we will be referring to those values shortly, I will. It also means you can delete the rows and recreate them with this code, without worrying about the sequence values.

from tree.models import Node

Node.objects.bulk_create([
  Node(pk=1),
  Node(pk=2, parent_id=1),
  Node(pk=3, parent_id=1),
  Node(pk=4, parent_id=2),
  Node(pk=5, parent_id=2),
  Node(pk=6, parent_id=3),
  Node(pk=7, parent_id=3),
  Node(pk=8, parent_id=4),
  Node(pk=9, parent_id=8),
  Node(pk=10),
  Node(pk=11, parent_id=10),
  Node(pk=12, parent_id=11),
  Node(pk=13, parent_id=11),
  Node(pk=14, parent_id=12),
  Node(pk=15, parent_id=12),
  Node(pk=16, parent_id=12),
])

Okay, let’s start looking at how we might perform some operations on it.

We’ve already seen how to create nodes, whether roots or leaves. No worries there.

What about inserting an intermediate node, say between 11 and 12?

node = Node.objects.create(parent_id=11)
node.parent.children.exclude(pk=node.pk).update(parent=node)

I’m not sure if it is possible to do it in a single statement.

Okay, let’s jump to some tree-based statements. We’ll start by finding a sub-tree.

Node.objects.filter(tree_node__ancestors__contains=[2])

Oh, that’s pretty nice. It’s not necessarily sorted, but it will do for now.
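
If you do want some ordering, one cheap option is to order by the ancestors array itself: Postgres compares arrays element by element, and a prefix sorts before arrays that extend it, so every node will appear after its ancestors.

Node.objects.filter(
    tree_node__ancestors__contains=[2]
).order_by('tree_node__ancestors')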

We can also query directly for a root:

Node.objects.filter(tree_node__root=10)

We could spell that one as tree_node__ancestors__0=10, but I think this is more explicit. Also, that one will not include the root node itself.

Deletions are also simple: if we can build a queryset, we can delete it. Thus, deleting a full tree (or a sub-tree) is just a matter of following any of the querysets above with a .delete().
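
For instance, a sketch of deleting the entire tree rooted at node 10 (a sub-tree delete looks the same, just using the ancestors filter from above):

# Everything whose root is node 10, which includes node 10 itself.
# Depending on your Django version, you may want on_delete=models.DO_NOTHING
# on the Tree model's fields, so the delete doesn't try to cascade into the view.
Node.objects.filter(tree_node__root=10).delete()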

Fetching a node’s ancestors is a little trickier, because all we have is an array of node ids; thus it takes two queries.

Node.objects.filter(pk__in=Node.objects.get(pk=15).tree_node.ancestors)

The count of ancestors doesn’t require the second query:

len(Node.objects.get(pk=15).tree_node.ancestors)

Getting ancestors to a given depth is also simple, although it still requires two queries:

Node.objects.filter(pk__in=Node.objects.get(pk=15).tree_node.ancestors[-2:])

This is a fairly simple way to enable relatively performance-aware queries of tree data. There are still places where it’s not perfect, and in reality, you’d probably look at building up queryset or model methods for wrapping common operations.
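
For example, a minimal sketch of what such a queryset class might look like (the method names here are my own, not from any library):

class NodeQuerySet(models.QuerySet):
    def roots(self):
        return self.filter(parent=None)

    def descendants_of(self, node):
        return self.filter(tree_node__ancestors__contains=[node.pk])

    def ancestors_of(self, node):
        return self.filter(pk__in=node.tree_node.ancestors)

Attaching it is then just a matter of objects = NodeQuerySet.as_manager() on the Node model.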

Postgres Tree Shootout part 3: Adjacency List using Views

It’s been a while, but I’ve finally found some time to revisit this series. As promised last time, I’m going to rewrite the queries from the Adjacency List “solutions” using a View. Indeed, there will be two versions of the view - one of which is a MATERIALIZED VIEW. There will also be discussion of when the two different types of view might be best used.

One of the reasons this post took so long to write was that I was sidetracked by writing an SVG generator that would allow for graphically seeing what the different operations discussed in this series look like in terms of an actual tree. That didn’t eventuate.

We will start by defining what our tree view will actually look like. You’ll notice it is rather like the CTE that we saw in the previous post.

CREATE TABLE nodes (
  node_id SERIAL PRIMARY KEY,
  parent_id INTEGER REFERENCES nodes(node_id)
);

CREATE RECURSIVE VIEW tree (node_id, ancestors) AS (
  SELECT node_id, ARRAY[]::integer[] AS ancestors
  FROM nodes WHERE parent_id IS NULL

  UNION ALL

  SELECT nodes.node_id, tree.ancestors || nodes.parent_id
  FROM nodes, tree
  WHERE nodes.parent_id = tree.node_id
);

INSERT INTO nodes VALUES
  (1, NULL),
  (2, 1),
  (3, 1),
  (4, 2),
  (5, 2),
  (6, 3),
  (7, 3),
  (8, 4),
  (9, 8),
  (10, NULL),
  (11, 10),
  (12, 11),
  (13, 11),
  (14, 12),
  (15, 12),
  (16, 12);

Insertions

None of the insertions require access to the tree view, since the beauty of an Adjacency List model is that you only ever need to operate on the immediate parent-child relationship.

Removals

Similarly, we will skip over the simple operations: those don’t require access to any more of the tree than just the parent-child relationship. It’s not until we need to remove a subtree that it becomes interesting.

DELETE FROM nodes
WHERE node_id IN (
  SELECT node_id FROM tree WHERE 2 = ANY(ancestors)
) OR node_id = 2;

If you are paying attention, you will notice that this is virtually identical to the CTE version, except that we no longer need to redeclare the CTE each time. The full tree deletion is the same, as is removing all descendants:

DELETE FROM nodes
WHERE node_id IN (
  SELECT node_id FROM tree WHERE 2 = ANY(ancestors)
);

Moves

Again, the operations that don’t require the actual tree are unchanged: this is where the Adjacency List really shines.

Fetches

Since we are starting with the “full” tree, we should be able to use it for all of the queries. These queries (unlike those we have seen before) may be slightly slower than the CTE versions (specifically, those where the CTE is customised for that operation).

Descendants

Let’s get all of node 10’s descendants:

SELECT node_id FROM tree WHERE 10 = ANY(ancestors);

This query is far less complicated than the CTE version, as expected. However, when dealing with very large datasets, it performs far worse. I have a data set with 221000 nodes, in 1001 different trees. Performing this query takes around 5 seconds, but the customised CTE version takes about 750ms.

Turning this view into a materialised view:

CREATE MATERIALIZED VIEW tree_mat AS
SELECT node_id, ancestors FROM tree;

and then querying that turns this into around 75ms.

To limit the query to nodes within a given depth requires slightly more work: we want node 10 to appear within the last N positions of the ancestors array. For a depth of two:

SELECT node_id, ancestors FROM tree
WHERE ARRAY_POSITION(ancestors, 10) >= ARRAY_LENGTH(ancestors, 1) - 1;

Ancestors

Fetching ancestors of a node is again trivial:

SELECT unnest(ancestors) FROM tree WHERE node_id = 15;

And the count of ancestors:

SELECT ARRAY_LENGTH(ancestors, 1) FROM tree WHERE node_id=15;

Getting a set of ancestors to a given depth is actually a little tricky: since we can’t change which end of the ancestors array the parent node is appended to, we can’t use that trick. We’ll have to enumerate the rows, and then extract those we care about. You can’t use OFFSET with a variable, otherwise that would be a nice trick.

WITH ancestors AS (
  SELECT unnest(ancestors) AS node_id
  FROM tree
  WHERE node_id = 15
), enumerated AS (
  SELECT
    row_number() OVER () AS row,
    count(*) OVER () AS ancestor_count,
    node_id
  FROM ancestors
)
SELECT node_id
FROM enumerated
WHERE "row" > ancestor_count - 2;

Ugh. That’s way worse than the CTE version.

Special queries

None of the special queries access the tree in any way, so can be omitted for now.

Discussion

So how does using a view stack up to the ad-hoc CTE queries?

Mostly pretty well. In the case where you have only small data sets, then the view that builds up the complete tree each time is not that much of a problem (I ran some tests with tens of thousands of items, and it still performed relatively well). When it moves up to hundreds of thousands, then the ad-hoc CTE versions can greatly outperform the full tree view.

However, using a materialised view changes everything. It now becomes just as fast as querying a table: indeed, that’s just what it is. You could have triggers based on changes to the nodes table causing a REFRESH MATERIALIZED VIEW, but it is worth keeping in mind that this will take some time: in my case, a full refresh of 221000 rows took upwards of 4.5 seconds.

Using a materialised view gets us most of the way to (and leads nicely into) the next method: storing a materialised path. The similarity of the names here should be a trigger, but now I’m just making foreshadowing jokes.

Tree data as a nested list

One of the nice things about Adjacency Lists as a method of storing tree structures is that there is not much redundancy: you only store a reference to the parent, and that’s it.

It does mean that getting that data into a nested object is a bit complicated. I’ve written before about getting data out of a database; I’m sure I’ll revisit that again, but for now I’m going to deal with data of the following shape: that is, data that has already been built up into a Materialised Path:

[
  {
    "node": 1,
    "ancestors": [],
    "label": "Australia"
  },
  {
    "node": 2,
    "ancestors": [1],
    "label": "South Australia"
  },
  {
    "node": 3,
    "ancestors": [1],
    "label": "Victoria"
  },
  {
    "node": 4,
    "ancestors": [1, 2],
    "label": "South-East"
  },
  {
    "node": 5,
    "ancestors": [1, 3],
    "label": "Western Districts"
  },
  {
    "node": 6,
    "ancestors": [],
    "label": "New Zealand"
  },
  {
    "node": 7,
    "ancestors": [1, 2],
    "label": "Barossa Valley"
  },
  {
    "node": 8,
    "ancestors": [1, 2],
    "label": "Riverland"
  }
]

From here, we want to build up something that looks like:

  • Australia
    • South Australia
      • Barossa Valley
      • Riverland
      • South-East
    • Victoria
      • Western Districts
  • New Zealand

Or, a nested python data structure:

[
  ('Australia', [
    ('South Australia', [
      ('Barossa Valley', []),
      ('Riverland', []),
      ('South-East', [])
    ]),
    ('Victoria', [
      ('Western Districts', [])
    ])
  ]),
  ('New Zealand', [])
]

You’ll see that each node is a 2-tuple, and each set of siblings is a list. Even a node with no children still gets an empty list.

We can build up this data structure in two steps, based on the fact that a dict’s key-value pairs map naturally onto 2-tuples. That is, we will start by creating:

{
  1: {
    2: {
      4: {},
      7: {},
      8: {},
    },
    3: {
      5: {},
    }
  },
  6: {},
}

You might be reaching for python’s defaultdict class at this point, but there is a slightly nicer way:

class Tree(dict):
    def __missing__(self, key):
        value = self[key] = type(self)()
        return value

(Note: This class, and the seed of the idea, came from this answer on StackOverflow).
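
A quick demonstration of the auto-vivification this gives us:

>>> t = Tree()
>>> t[1][2][3]
{}
>>> t
{1: {2: {3: {}}}}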

We can also create a recursive method on this class that creates a node and all of its ancestors:

    def insert(self, key, ancestors):
        if ancestors:
            self[ancestors[0]].insert(key, ancestors[1:])
        else:
            self[key]

>>> tree = Tree()
>>> for node in data:
...     tree.insert(node['node'], node['ancestors'])
...
>>> print tree
{1: {2: {8: {}, 4: {}, 7: {}}, 3: {5: {}}}, 6: {}}

Looking good.

Let’s make another method that allows us to actually insert the labels (and apply a sort, if we want):

    def label(self, label_dict, sort_key=lambda x: x[0]):
        return sorted([
          (label_dict.get(key), value.label(label_dict, sort_key))
          for key, value in self.items()
        ], key=sort_key)

We also need to build up the simple key-value store to pass as label_dict, but that’s pretty easy.
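
Something like a dict comprehension over the source data will do:

labels = {node['node']: node['label'] for node in data}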

Let’s look at the full code: first the complete class:

class Tree(dict):
    """Simple Tree data structure

    Stores data in the form:

    {
        "a": {
            "b": {},
            "c": {},
        },
        "d": {
            "e": {},
        },
    }

    And can be nested to any depth.
    """

    def __missing__(self, key):
        value = self[key] = type(self)()
        return value

    def insert(self, node, ancestors):
        """Insert the supplied node, creating all ancestors as required.

        This expects a list (possibly empty) containing the ancestors,
        and a value for the node.
        """
        if not ancestors:
            self[node]
        else:
            self[ancestors[0]].insert(node, ancestors[1:])

    def label(self, labels, sort_key=lambda x: x[0]):
        """Return a nested 2-tuple with just the supplied labels.

        Realistically, the labels could be any type of object.
        """
        return sorted([
            (
                labels.get(key),
                value.label(labels, sort_key)
            ) for key, value in self.items()
        ], key=sort_key)

Now, using it:

>>> tree = Tree()
>>> labels = {}
>>>
>>> for node in data:
...     tree.insert(node['node'], node['ancestors'])
...     labels[node['node']] = node['label']
...
>>> from pprint import pprint
>>> pprint(tree.label(labels))

[('Australia',
  [('South Australia',
    [('Barossa Valley', []), ('Riverland', []), ('South-East', [])]),
   ('Victoria', [('Western Districts', [])])]),
 ('New Zealand', [])]

Awesome. Now use your template rendering of choice to turn this into a nicely formatted list.

slugify() for postgres (almost)

A recent discussion in #django suggested “what we need is a PG slugify function”.

The actual algorithm in Django for this is fairly simple, and easy to follow. Shouldn’t be too hard to write it in SQL.

The function is slugify(value, allow_unicode=False), and the steps are (a rough Python rendering follows the list):

  • Convert to ASCII if allow_unicode is false
  • Remove characters that aren’t alphanum, underscores, hyphens
  • Strip leading/trailing whitespace
  • Convert to lowercase
  • Convert spaces to hyphens
  • Remove repeated hyphens
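
For reference, here’s that rough Python rendering of the steps (paraphrased from the list above, not Django’s exact code):

import re
import unicodedata

def slugify_py(value, allow_unicode=False):
    # Convert to ASCII unless unicode is allowed.
    if not allow_unicode:
        value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
    # Remove anything that isn't alphanumeric, underscore, hyphen or whitespace.
    value = re.sub(r'[^\w\s-]', '', value)
    # Strip, lowercase, and collapse whitespace and repeated hyphens into single hyphens.
    return re.sub(r'[-\s]+', '-', value.strip().lower())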

(As an aside, the comment in the django function is slightly misleading: if you followed the algorithm there, you’d get a different result with respect to leading/trailing whitespace. I shall submit a PR).

We can write an SQL function that uses the Postgres unaccent extension (you’ll need to have run CREATE EXTENSION unaccent first) to get pretty close:

CREATE OR REPLACE FUNCTION slugify("value" TEXT, "allow_unicode" BOOLEAN)
RETURNS TEXT AS $$

  WITH "normalized" AS (
    SELECT CASE
      WHEN "allow_unicode" THEN "value"
      ELSE unaccent("value")
    END AS "value"
  ),
  "remove_chars" AS (
    SELECT regexp_replace("value", E'[^\\w\\s-]', '', 'gi') AS "value"
    FROM "normalized"
  ),
  "lowercase" AS (
    SELECT lower("value") AS "value"
    FROM "remove_chars"
  ),
  "trimmed" AS (
    SELECT trim("value") AS "value"
    FROM "lowercase"
  ),
  "hyphenated" AS (
    SELECT regexp_replace("value", E'[-\\s]+', '-', 'gi') AS "value"
    FROM "trimmed"
  )
  SELECT "value" FROM "hyphenated";

$$ LANGUAGE SQL STRICT IMMUTABLE;

I’ve used a CTE to get each step as a separate query: you can do it with just two levels if you don’t mind looking at nested function calls:

CREATE OR REPLACE FUNCTION slugify("value" TEXT, "allow_unicode" BOOLEAN)
RETURNS TEXT AS $$

  WITH "normalized" AS (
    SELECT CASE
      WHEN "allow_unicode" THEN "value"
      ELSE unaccent("value")
    END AS "value"
  )
  SELECT regexp_replace(
    trim(
      lower(
        regexp_replace(
          "value",
          E'[^\\w\\s-]',
          '',
          'gi'
        )
      )
    ),
    E'[-\\s]+', '-', 'gi'
  ) FROM "normalized";

$$ LANGUAGE SQL STRICT IMMUTABLE;

To get the default value for the second argument, we can have an overloaded version with only a single argument:

CREATE OR REPLACE FUNCTION slugify(TEXT)
RETURNS TEXT AS 'SELECT slugify($1, false)' LANGUAGE SQL IMMUTABLE STRICT;

Now for some tests. I’ve been using pgTAP lately, so here’s some tests using that:

BEGIN;

SELECT plan(7);

SELECT is(slugify('Hello, World!', false), 'hello-world');
SELECT is(slugify('Héllø, Wørld!', false), 'hello-world');
SELECT is(slugify('spam & eggs', false), 'spam-eggs');
SELECT is(slugify('spam & ıçüş', true), 'spam-ıçüş');
SELECT is(slugify('foo ıç bar', true), 'foo-ıç-bar');
SELECT is(slugify('    foo ıç bar', true), 'foo-ıç-bar');
SELECT is(slugify('你好', true), '你好');

SELECT * FROM finish();

ROLLBACK;

And we get one failing test:

=# SELECT is(slugify('你好', true), '你好');

          is
──────────────────────
 not ok 7            ↵
 # Failed test 7     ↵
 #         have:     ↵
 #         want: 你好
(1 row)

Time: 2.004 ms

It seems there is no way to get the equivalent to the python re.U flag on a postgres regular expression function, so that is as close as we can get.

Row Level Security in Postgres and Django

Postgres keeps introducing new things that pique my attention. One of the latest of these is Row Level Security, which essentially hides rows that a given database user cannot view. There’s a bit more to it than that; a good writeup is at Postgres 9.5 feature highlight: Row-Level Security and Policies.

However, it’s worth noting that the way Django connects to a database uses a single database user. Indeed, if your users table is in your database, then you’ll need some way to connect to it to authenticate. I haven’t come up with a nice way to use the Postgres users for authentication within Django just yet.

I did have an idea about a workflow that may just work.

  • Single Postgres User is used for authentication (Login User).
  • Every Django user gets an associated Postgres User (Session User), that may not log in.
  • This Session User is automatically created using a Postgres trigger, whenever the Django users table is updated.
  • After authentication, a SET SESSION ROLE (or SET SESSION AUTHORIZATION) statement is used to change to the correct Session User for the remainder of the session.

Then, we can implement the Postgres Row Level Security policies as required.

Initially, I had thought that perhaps the Session Users would have the same level of access as the Login User, however tonight it occurred to me that perhaps this could replace the whole Django permissions concept.


We do have a few things we need to work out before this is a done deal.

  • The trigger function that performs the CREATE ROLE statement when the django users table is updated.
  • Some mechanism of handling GRANT and REVOKE statements.
  • Similarly, some mechanism for showing current permissions for the given user.
  • A middleware that sets the SESSION USER according to the django user.

The simplest part of this is the last one, so we will start there. We can (in the meantime) manually create the users and their permissions to see how well it all goes. No point doing work that doesn’t work.

from django.db import connection


class SetSessionAuthorization(object):
    def process_view(self, request, *args, **kwargs):
        if request.user.pk:
            connection.cursor().execute(
                'SET SESSION SESSION AUTHORIZATION "django:{}"'.format(request.user.pk)
            )

We need to add this to our project’s middleware.
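
Since this is an old-style middleware class, that means adding it to MIDDLEWARE_CLASSES in settings (the dotted path here is a placeholder for wherever you put the class):

MIDDLEWARE_CLASSES = [
    # ...
    'myproject.middleware.SetSessionAuthorization',
]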

You’ll see we are using roles of the form django:<id>, which need to be quoted. We use the user id rather than the username, because usernames may be changed.

We’ll want to create a user for each of the existing Django users: I currently have a single user in this database, with id 1. I also have an existing SUPERUSER with the name django. We need to use a superuser if we are using SET SESSION AUTHORIZATION, which seems to be the best option. I haven’t found anything which really does a good job of explaining how this and SET SESSION ROLE differ.

CREATE USER "django:1" NOLOGIN;
GRANT "django:1" TO django;
GRANT ALL ON ALL TABLES IN SCHEMA public TO public;
GRANT ALL ON ALL SEQUENCES IN SCHEMA public TO public;

Note we have just for now enabled full access to all tables and sequences. This will remain until we find a good way to handle this.

We can start up a project using this, and see if it works. Unless I’ve missed something, it should.

Next, we will turn on row-level-security for the auth_user table, and see if it works.

ALTER TABLE auth_user ENABLE ROW LEVEL SECURITY;

Then try to view the list of users (even as a django superuser). You should see an empty list.

We’ll turn on the ability to see our own user object:

CREATE POLICY read_own_data ON auth_user FOR
SELECT USING ('django:' || id = current_user);

Phew, that was close. Now we can view our user.

However, we can’t update it. Let’s fix that:

CREATE POLICY update_own_user_data ON auth_user FOR
UPDATE USING ('django:' || id = current_user)
WITH CHECK ('django:' || id = current_user);

We should be able to do some magic there to prevent a user toggling their own superuser status.

Let’s investigate writing a trigger function that creates a new ROLE when we update the django user.

CREATE OR REPLACE FUNCTION create_shadow_role()
RETURNS TRIGGER AS $$

BEGIN

  EXECUTE 'CREATE USER "django:' || NEW.id || '" NOLOGIN';
  EXECUTE 'GRANT "django:' || NEW.id || '" TO django';

  RETURN NULL;

END;

$$
LANGUAGE plpgsql
SECURITY DEFINER
SET search_path =  public, pg_temp
VOLATILE;


CREATE TRIGGER create_shadow_role
  AFTER INSERT ON auth_user
  FOR EACH ROW
  EXECUTE PROCEDURE create_shadow_role();

Note we still can’t create users from the admin (due to the RLS restrictions that are there), so we need to resort to ./manage.py createsuperuser again.

Having done that, we should see that our new user gets a ROLE:

# \du
                                        List of roles
 Role name │                         Attributes                         │      Member of
───────────┼────────────────────────────────────────────────────────────┼─────────────────────
 django    │ Superuser                                                  │ {django:1,django:6}
 django:1  │ Cannot login                                               │ {}
 django:6  │ Cannot login                                               │ {}
 matt      │ Superuser, Create role, Create DB                          │ {}
 postgres  │ Superuser, Create role, Create DB, Replication, Bypass RLS │ {}

We should be able to write similar triggers for update. We can, for example, shadow the Django is_superuser attribute to the Postgres SUPERUSER attribute. I’m not sure if that’s a super idea or not.

But we can write a simple function that allows us to see if the current django user is a superuser:

CREATE FUNCTION is_superuser()
RETURNS BOOLEAN AS $$

SELECT is_superuser
FROM auth_user
WHERE 'django:' || id = current_user

$$ LANGUAGE SQL;

We can now use this to allow superuser access to all records. (Note that you will likely need to make is_superuser() SECURITY DEFINER, owned by a role that can bypass row level security, so that these policies don’t recurse through the auth_user policies.)

CREATE POLICY superuser_user_select ON auth_user
FOR SELECT USING (is_superuser());

CREATE POLICY superuser_user_update ON auth_user
FOR UPDATE USING (is_superuser())
WITH CHECK (is_superuser());

That gives us a fair bit of functionality. We still don’t have any mechanism for viewing or setting permissions. Because of the way Django’s permissions work, we can’t quite use the same trick on the auth_user_user_permissions table, because we’d also need to look at the auth_user_groups and auth_group_permissions tables.

I’m still not sure if this is a good idea or not, but it is a fun thought process.

Unicode Flags in Python

I can’t even remember how I got onto this idea.

I’ve been an avid user of django-countries, which is a really nice way to have a country field, without having to maintain your own database of countries.

One neat feature is that Chris includes flag icons for all countries. I have some code in my project that uses these to have the correct flag shown next to the country select box whenever you change the selection.

However, these flags are small, and cannot resize easily (and require a request to fetch the image). Then it occurred to me that new iOS/OS X (and probably other platforms) now support Emoji/Unicode flags.

A bit of research turned up Unicode’s encoding of national flags is just crazy enough to work, which discusses how the system works: basically the two-letter ISO 3166-1 alpha-2 code is used, but instead of the plain letters “AU”, the corresponding Regional Indicator Symbols are used. When two valid Regional Indicator Symbols appear one after another, they are combined into a single glyph.

We just need to add an offset to the integer value of each character in the code, and convert this back into a unicode character.

The offset can be calculated from knowing the start of the Regional Indicator Symbol range:

$ python3
>>> ord('🇦')
127462

To generate a flag glyph in python, you can use the following:

OFFSET = ord('🇦') - ord('A')

def flag(code):
    return chr(ord(code[0]) + OFFSET) + chr(ord(code[1]) + OFFSET)

This only works with Python 3, and we can expand it a bit to make it more robust:

OFFSET = ord('🇦') - ord('A')

def flag(code):
    if not code:
        return u''
    points = [ord(x) + OFFSET for x in code.upper()]
    try:
        return chr(points[0]) + chr(points[1])
    except ValueError:
        return ('\\U%08x\\U%08x' % tuple(points)).decode('unicode-escape')

This relies on the fact that Python 2 will raise a ValueError when attempting to chr() a value greater than 255. Whilst we could use unichr(), that fails on systems that have python compiled without wide unicode support, which happened to be the case on my local system.
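
A quick check (whether the combined glyphs actually render as flags depends on your platform):

>>> flag('AU')
'🇦🇺'
>>> flag('nz')
'🇳🇿'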

Thoughts on Mutation testing in Python

Writing code is fun.

Writing tests can be less fun, but is a great way to have code that is more likely to work as desired.

Using a coverage tool will show you what percentage of your program is executed when you run your tests, but getting 100% code coverage does not mean your code is 100% tested.

For instance, testing a function:

def product(a, b):
  return a * b

If your tests are not carefully written, you may end up with tests that pass when they really should not. For instance:

>>> product(2, 2)
4

This execution shows only that it works for that choice of values. But how many sets of values do we need to test in order to be satisfied that a function does what is expected?

I happened across a document about Mutation Testing some time ago. It was new to me (and to a couple of people I respect), but it was quite interesting. The idea behind it is that you can change bits of the program, re-run your tests, and if your tests still succeed, then your tests are not sufficient.

Mutation Testing, then, is testing your tests.

For instance, mutating the * in the previous function to + results in a passing test, even though the code has changed (and, to us at least, is clearly different). Thus, we need to improve our test(s).

In theory, we should be able to mutate each statement in our program in all the ways it could be mutated, and if any of these mutants are not killed by our tests, then our tests are insufficient.

In practice, this is untenable. Some mutants will be equivalent (for instance, a * b is the same as b * a), but more so, there will be many, many mutants created for all but the simplest program.


Aside: doctest

Python has a fantastic library called doctest. This enables you to write tests as part of the function definition. I’ll use that here, as it makes writing the tests simple.

def product(a, b):
  """Multiply two numbers together.

  >>> product(2, 3)
  6
  """
  return a * b

They are really useful for documentation that is an accurate representation of what the code does. They don’t replace unit tests (indeed, you’ll want your unit tests to run your doctests), but they are good fun.
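
(If you’re wondering how to get your unit tests to run your doctests, the standard load_tests hook does the job; mymodule below is a stand-in for whichever module holds them.)

import doctest

import mymodule  # the module whose doctests we want to run


def load_tests(loader, tests, ignore):
    # unittest calls this hook during test discovery: add the doctests to the suite.
    tests.addTests(doctest.DocTestSuite(mymodule))
    return tests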

This investigation will use doctests, as then the tests for a module are always self-contained.


Selection of mutations

With a suitably complicated program, the number of possible mutations will be very, very large. We will, for now, consider applying only one mutation per mutant; doing otherwise would result in huge numbers of possible mutants.

There is quite a bit of academic research into Mutation Testing (indeed, if I were an academic, I’d probably write a paper or two on it), and one part that is relatively well documented is the discussion of mutation selection. That is, how can we reduce the number of mutant programs (and make our tests run in a satisfactory time), without missing mutations that could show how our tests may be improved.

Offutt et al. (1996) discuss how we can reduce the number of mutation operators, and come to the conclusion that with just five operators, sufficient tests can be determined.

Using the classifications from Mothra, they are:

  • ABS: ABSolute value insertion. Force arithmetic expressions to take on (-, 0, +) values.
  • AOR: Arithmetic Operator Replacement. Replace every arithmetic operator with every syntactically legal other arithmetic operator.
  • LCR: Logical Connector Replacement. Replace every logical connector (and, or) with several kinds of logical connectors.
  • ROR: Relational Operator Replacement. Replace every relational operator with every other relational operator.
  • UOI: Unary Operator Insertion. Insert unary operators in front of expressions. In the case of python, the unary + operator will not be added, as that usually results in an equivalent mutant. I’m actually also going to include swapping unary operators as part of this operation.

These mutation operations make sense in python. For instance, in our toy example, we’d see the following (showing only the return statement from our product function):

# original
return a * b

# ABS
return -1
return 0
return 1
return a * -1
return a * 0
return a * 1
return -1 * b
return 0 * b
return 1 * b

# AOR
return a + b
return a - b
return a / b
return a // b
return a ** b
return a % b

# UOI
return -a * b
return a * -b
return -(a * b)
return (not a) * b
return a * (not b)
return not (a * b)
return ~a * b
return a * ~b
return ~(a * b)

We can see from our toy example that there are no LCR or ROR mutations possible, since there are no logical or relational operations or connectors. However, in the case of python, we could see code like:

# original
return a < b and a > 0

# LCR
return a < b or a > 0
return a < b and not a > 0
return a < b or not a > 0

# ROR
return a > b and a > 0
return a >= b and a > 0
return a <= b and a > 0
return a != b and a > 0
return a == b and a > 0

return a < b and a < 0
return a < b and a <= 0
return a < b and a == 0
return a < b and a >= 0
return a < b and a != 0

Limiting ourselves to just these five mutation operations is great: it simplifies our mutation code immensely.

Within the context of python, I actually think there should also be some other mutations. A common mistake in Python code is to omit a call to super(), or to miss the explicit self argument in a method definition. For now, we’ll stick to the five above though.

I’m guessing the extreme age of Mothra means it wasn’t used in the case of object-oriented programming languages, even more so with multiple inheritance!

Okay, we are done with the generic stuff, time to get specific.


Aside: ast

Another great module that is included in python is ast. This allows you to build an Abstract Syntax Tree from (among other things) a string containing python code, and, along with astor, to write it back out as python code after performing our mutation operation(s).

import ast
tree = ast.parse(open(filename).read())
ast.fix_missing_locations(tree)

The ast.fix_missing_locations stuff fixes up any line numbers, which we may use for reporting later.


Mutating the AST.

Bingo, now we have an abstract syntax tree containing our module. ast also contains classes for walking this tree, which we can subclass to do interesting things. For instance, to collect all statements that should be mutated, we can do something like:

SWAPS = {
  ast.Add: [ast.Mult, ast.Sub, ast.Div, ast.Mod, ast.Pow, ast.FloorDiv],
  # etc, etc
}

class Collector(ast.NodeVisitor):
  def __init__(self, *args, **kwargs):
    self.mutable_nodes = []

  def __visit_binary_operation(self, node):
    # Visit our left child node.
    self.visit(node.left)
    # Create all possible mutant nodes for this, according to AOR
    self.mutable_nodes.extend([
      (node, node.__class__(
        left=node.left, op=op(), right=node.right
      )) for op in SWAPS.get(node.op.__class__, [])
    ])
    # Visit our right child node.
    self.visit(node.right)

  # These three node types are handled identically.
  visit_BinOp = __visit_binary_operation
  visit_BoolOp = __visit_binary_operation
  visit_Compare = __visit_binary_operation

  def visit_UnaryOp(self, node):
    # Create all possible mutant nodes, according to UOI.
    self.mutable_nodes.extend([
      (node, node.__class__(
        op=op(), operand=node.operand)
      ) for op in SWAPS.get(node.op.__class__, [])
    ])
    # Also, create a node without the operator.
    self.mutable_nodes.append((node, node.operand))
    self.visit(node.operand)

Our mutator code is actually much simpler. Since we keep a reference to the nodes we want to mutate (and the new node it should be replaced with), we can just swap them each time a given node is visited:

class Mutator(ast.NodeTransformer):
  def __init__(self, original, replacement):
    self.original = original
    self.replacement = replacement

  def __swap(self, node):
    # Handle child nodes.
    self.generic_visit(node)

    # Don't swap out the node if it wasn't our target node.
    if node not in [self.original, self.replacement]:
        return node

    # Swap the node back if we are being visited again.
    if node == self.replacement:
        return self.original

    # Otherwise, swap the node out for the replacement.
    return self.replacement

  # We can just use this same function for a whole stack of visits.
  visit_BinOp = __swap
  visit_BoolOp = __swap
  visit_Compare = __swap
  visit_UnaryOp = __swap

Note that calling mutator.visit(tree) on a mutated tree will revert the mutation.

To use these, assuming we have a special function test that runs the tests we want:

tree = ast.parse(open(filename).read())
ast.fix_missing_locations(tree)

results = test(tree)
if results.failed:
    raise Exception("Unable to run tests without mutations.")

# Collect all of the possible mutants.
collector = Collector()
collector.visit(tree)

survivors = []

for (node, replacement) in collector.mutable_nodes:
    mutator = Mutator(node, replacement)
    # Apply our mutation
    mutator.visit(tree)
    ast.fix_missing_locations(tree)

    try:
        results = test(tree)
    except Exception:
        # Looks like this mutant was DOA.
        continue

    if not results.failed:
        survivors.append((node, replacement, results))

    # Revert our mutation
    mutator.visit(tree)

This is a bit of a simplification, but it’s actually pretty close to working code. We use multiple processes to run it in parallel (and also have a timeout based on the initial test run time, assuming we should not take more than twice as long), and compile the tree into a module to test it.
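
The test function is left abstract above; a minimal sketch of one way it could work, compiling the (possibly mutated) AST into a throwaway module and running that module’s doctests:

import doctest
import types

def test(tree):
    # Compile the AST and execute it into a fresh module object.
    code = compile(tree, '<mutant>', 'exec')
    module = types.ModuleType('mutant')
    exec(code, module.__dict__)
    # Run the module's doctests: returns TestResults(failed=..., attempted=...).
    return doctest.testmod(module)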

You may see the current version at pymutant. Keep in mind that this is little more than a proof of concept at this stage.


Testing our testing

So, let’s look at some toy examples, and see what the outcomes are.

def product(a, b):
    """
    >>> product(2, 2)
    4
    >>> product(2, 3)
    6
    """
    return a * b

def addition(a, b):
    """
    >>> addition(1, 0)
    1
    >>> addition(1, 1)
    2
    """
    return a + b

def negate(a):
    """
    >>> negate(1)
    -1
    >>> negate(-1)
    1
    """
    return -a

Running in a debug mode, we can view the mutants, and the results:

TestResults(failed=2, attempted=6) (a * b) -> (not (a * b))
TestResults(failed=1, attempted=6) (a * b) -> (a ** b)
TestResults(failed=2, attempted=6) (a * b) -> (a // b)
TestResults(failed=2, attempted=6) (a * b) -> 1
TestResults(failed=2, attempted=6) (a * b) -> (a % b)
TestResults(failed=2, attempted=6) (a * b) -> 0
TestResults(failed=2, attempted=6) (a * b) -> (- (a * b))
TestResults(failed=2, attempted=6) (a * b) -> (~ (a * b))
TestResults(failed=2, attempted=6) (a * b) -> True
TestResults(failed=2, attempted=6) (a * b) -> (-1)
TestResults(failed=2, attempted=6) (a * b) -> False
TestResults(failed=2, attempted=6) (a * b) -> (a / b)
TestResults(failed=1, attempted=6) (a * b) -> (a + b)
TestResults(failed=2, attempted=6) return (a * b) -> pass
TestResults(failed=2, attempted=6) (a * b) -> None
TestResults(failed=2, attempted=6) (a * b) -> (a - b)
TestResults(failed=2, attempted=6) (a + b) -> (not (a + b))
TestResults(failed=2, attempted=6) (a + b) -> (a % b)
TestResults(failed=1, attempted=6) (a + b) -> (a - b)
TestResults(failed=1, attempted=6) (a + b) -> (a ** b)
TestResults(failed=2, attempted=6) (a + b) -> (~ (a + b))
TestResults(failed=2, attempted=6) (a + b) -> (-1)
TestResults(failed=1, attempted=6) (a + b) -> 1
TestResults(failed=2, attempted=6) (a + b) -> (- (a + b))
TestResults(failed=2, attempted=6) (a + b) -> (a // b)
TestResults(failed=2, attempted=6) (a + b) -> None
TestResults(failed=2, attempted=6) (a + b) -> 0
TestResults(failed=1, attempted=6) (a + b) -> True
TestResults(failed=2, attempted=6) (a + b) -> False
TestResults(failed=2, attempted=6) return (a + b) -> pass
TestResults(failed=2, attempted=6) (a + b) -> (a / b)
TestResults(failed=2, attempted=6) (a + b) -> (a * b)
TestResults(failed=2, attempted=6) (- a) -> 0
TestResults(failed=2, attempted=6) (- a) -> (- (- a))
TestResults(failed=2, attempted=6) return (- a) -> pass
TestResults(failed=2, attempted=6) (- a) -> (not (- a))
TestResults(failed=1, attempted=6) (- a) -> True
TestResults(failed=2, attempted=6) (- a) -> False
TestResults(failed=1, attempted=6) (- a) -> (-1)
TestResults(failed=2, attempted=6) (- a) -> None
TestResults(failed=1, attempted=6) (- a) -> 1
TestResults(failed=2, attempted=6) (- a) -> (~ (- a))

===============
Mutation Report
===============

* Generated 42 mutants, and tested in 0.107773065567 seconds.

* 0 of these mutants were unable to execute correctly.

* 0 of these mutants were killed for taking too long to execute.

* Tests killed of 42 of the remaining mutants, leaving 0 survivors.

* Your innoculation rate is 100%.

Well, that’s all nice, but there’s more we can think about than this. What about tests that are “useless”? Tests that never fail, for instance?

Unfortunately, doctest.testmod() only returns the count of failures and attempts (and looking into that module, I’m not sure that which tests passed/failed is actually stored). It would be really nice to be able to capture this, but perhaps that is a task for a unittest-based approach.

What about a slightly more complex example?

def aggregate(items):
    """
    Aggregate a list of 2-tuples, which refer to start/finish values.

    Returns a list with overlaps merged.

    >>> aggregate([])
    []
    >>> aggregate([(1, 3)])
    [(1, 3)]
    >>> aggregate([(1, 3), (2, 6)])
    [(1, 6)]
    >>> aggregate([(1, 3), (4, 6)])
    [(1, 3), (4, 6)]
    >>> aggregate([(3, 4), (1, 9)])
    [(1, 9)]
    """

    # Sort our items first, by the first value in the tuple. This means we can
    # iterate through them later.
    sorted_items = sorted(items)

    i = 0
    while i < len(sorted_items) - 1:
        current = sorted_items[i]
        next = sorted_items[i + 1]

        if current[1] >= next[1]:
            # Skip over the next item totally.
            sorted_items.remove(next)
            continue

        if current[1] >= next[0]:
            # Merge the two items.
            sorted_items[i:i+2] = ((current[0], next[1]),)
            continue

        i += 1

    return sorted_items

And when we test with mutants (no debug mode this time, as we have 169 mutants):

===============
Mutation Report
===============

* Generated 169 mutants, and tested in 0.905333995819 seconds.

* 0 of these mutants were unable to execute correctly.

* 61 of these mutants were killed for taking too long to execute.

* Tests killed of 99 of the remaining mutants, leaving 9 survivors.

* Your innoculation rate is 94%.

Survivor Report
===============

0 at line 23, changed to False
(i < (len(sorted_items) - 1)) at line 24, changed to (- (i < (len(sorted_items) - 1)))
(i + 1) at line 26, changed to (i - 1)
(i + 1) at line 26, changed to (- (i + 1))
(i + 1) at line 26, changed to 1
(i + 1) at line 26, changed to (-1)
(i + 1) at line 26, changed to True
(current[1] >= next[1]) at line 28, changed to (- (current[1] >= next[1]))
(current[1] >= next[1]) at line 28, changed to (current[1] > next[1])

Timeout Report
==============

0 at line 23, changed to (-1)
...

I’ve omitted the remainder of the timeout report.

But, this does show us that our tests are incomplete, and perhaps what we should be doing to fix this.

In particular, note the group of mutants at line 26 that all survived: indicating that this particular line of code is not being tested well at all.

Perhaps the biggest takeaway (and an indicator of how Mutation Testing may be really useful) is the last listed mutant. It’s showing that this particular comparison is clearly not being tested for off-by-one errors.

Postgres Domains and Triggers

It occurred to me tonight (after rewriting my Tax File Number generation script in JavaScript, so it works on Text Expander Touch), that it might be nice to do TFN validation in Postgres.

You can write a function that validates a TFN (this is not the first version I wrote, this one is a bit more capable in terms of the range of values it accepts):

CREATE OR REPLACE FUNCTION valid_tfn(TEXT)
RETURNS BOOLEAN AS $$

WITH weights AS (
  SELECT
    row_number() OVER (), weight
  FROM unnest('{1,4,3,7,5,8,6,9,10}'::integer[]) weight
),
digits AS (
  SELECT
    row_number() OVER (),
    digit
  FROM (
    SELECT
      unnest(digit)::integer as digit
    FROM regexp_matches($1, '^(\d)(\d)(\d)[- ]?(\d)(\d)(\d)[- ]?(\d)(\d)(\d)$') AS digit
  ) digits
)

SELECT
  COALESCE(sum(weight * digit) % 11 = 0, FALSE)
FROM weights INNER JOIN digits USING (row_number);

$$ LANGUAGE SQL IMMUTABLE;

Once you have this (which incidentally limits values to 9 digits, with optional spacers of - or a space), you may create a DOMAIN that validates values:

CREATE DOMAIN tax_file_number AS TEXT
CONSTRAINT valid_tfn CHECK (valid_tfn(VALUE));

Now, we can test it:

# SELECT valid_tfn('123-456-789') ; --> FALSE
# SELECT valid_tfn('123 456 782') ; --> TRUE

However, we might want to convert our data into a canonical “format”: in this case, always store it as XXX-XXX-XXX. We can write a function that does this:

CREATE OR REPLACE FUNCTION normalise_tfn(text)
RETURNS tax_file_number AS $$

SELECT string_agg(block, '-'::text)::tax_file_number
FROM (
  SELECT unnest(value) AS block
  FROM regexp_matches($1, '(\d\d\d)', 'g') value
) value;

$$ LANGUAGE SQL IMMUTABLE;

But how can we make our data that we insert always stored in this format?

Postgres triggers to the rescue. We can’t do it as part of a Domain (although we probably could do it as part of a scalar type, but that’s a whole other kettle of fish).

Let’s set up a table to store our tax declarations in:

CREATE TABLE tax_declaration
(
  tax_declaration_id SERIAL PRIMARY KEY,
  tfn tax_file_number,
  lodgement_date date
);

And now create a trigger function and trigger:

CREATE OR REPLACE FUNCTION rewrite_tfn() RETURNS TRIGGER AS $$

BEGIN
  NEW.tfn = normalise_tfn(NEW.tfn);
  RETURN NEW;
END;

$$ LANGUAGE plpgsql;

CREATE TRIGGER rewrite_tfn BEFORE INSERT OR UPDATE ON tax_declaration
FOR EACH ROW EXECUTE PROCEDURE rewrite_tfn();

Now, we can insert some data:

INSERT INTO tax_declaration (tfn, lodgement_date)
VALUES ('123456782', now()::date);

And now we can query it:

SELECT * FROM tax_declaration;
 tax_declaration_id |     tfn     | lodgement_date
--------------------+-------------+----------------
                  1 | 123-456-782 | 2015-10-13
(1 row)

Django second AutoField

Sometimes, your ORM just seems to be out to get you.

For instance, I’ve been investigating a technique for the most important data structure in a system to be essentially immutable.

That is, instead of updating an existing instance of the object, we always create a new instance.

This requires a handful of things to be useful (and useful for querying).

  • We probably want to have a self-relation so we can see which object supersedes another. A series of objects that supersede one another is called a lifecycle.
  • We want to have a timestamp on each object, so we can view a snapshot at a given time: that is, which phase of the lifecycle was active at that point.
  • We should have a column whose value is shared by every phase of a lifecycle (and unique to that lifecycle): this makes querying for all objects of a lifecycle much simpler (although we could use a recursive query for that).
  • There must be a facility to prevent multiple heads on a lifecycle: that is, at most one phase of a lifecycle may be non-superseded.
  • The lifecycle phases needn’t be in the same order, or really have any differentiating features (like status). In practice they may, but for the purposes of this, they are just “what it was like at that time”.

I’m not sure these ideas will ever get into a released product, but the work behind them was fun (and all my private work).

The basic model structure might look something like:

class Phase(models.Model):
    phase_id = models.AutoField(primary_key=True)
    lifecycle_id = models.AutoField(primary_key=False, editable=False)

    superseded_by = models.OneToOneField('self',
        related_name='supersedes',
        null=True, blank=True, editable=False
    )
    timestamp = models.DateTimeField(auto_now_add=True)

    # Any other fields you might want...

    objects = PhaseQuerySet.as_manager()

So, that looks nice and simple.

Our second AutoField will have a sequence generated for it, and the database will give us a unique value from that sequence whenever we insert a row without providing this column.

However, there is one problem: Django will not let us have a second AutoField in a model. And, even if it did, there would still be some problems. For instance, every time we attempt to create a new instance, AutoField values are not sent to the database, which breaks our ability to keep the lifecycle_id between phases.

So, we will need a custom field. Luckily, all we really need is the SERIAL database type: that creates the sequence for us automatically.

class SerialField(models.Field):
    def db_type(self, connection):
        return 'serial'

So now, using that field type instead, we can write a bit more of our model:

class Phase(models.Model):
    phase_id = models.AutoField(primary_key=True)
    lifecycle_id = SerialField(editable=False)
    superseded_by = models.OneToOneField('self', ...)
    timestamp = models.DateTimeField(auto_now_add=True)

    def save(self, **kwargs):
        self.pk = None
        super(Phase, self).save(**kwargs)

This now ensures each time we save our object, a new instance is created. The lifecycle_id will stay the same.

Still not totally done though. We currently aren’t handling a newly created lifecycle (which should be handled by the associated postgres sequence), nor are we marking the previous instance as superseded.

It’s possible, using some black magic, to get the default value for a database column, and, in this case, execute a query with that default to get the next value. However, that’s pretty horrid: not to mention it also runs an extra two queries.

Similarly, we want to get the phase_id of the newly created instance, and set that as the superseded_by of the old instance. This would require yet another query, after the INSERT, but also has the sinister side-effect of making us unable to apply the not-superseded-by-per-lifecycle requirement.

As an aside, we can investigate storing the self-relation on the other end - this would enable us to just do:

    def save(self, **kwargs):
        self.supersedes = self.pk
        self.pk = None
        super(Phase, self).save(**kwargs)

However, this turns out to be less useful when querying: we are much more likely to be interested in phases that are not superseded, as they are the “current” phase of each lifecycle. Although we could query, it would be running sub-queries for each row.

Our two issues, setting the lifecycle and storing the superseding data, can be handled with one Postgres BEFORE INSERT trigger function:

CREATE FUNCTION lifecycle_and_supersedes()
RETURNS TRIGGER AS $$

  BEGIN
    IF NEW.lifecycle_id IS NULL THEN
      NEW.lifecycle_id = nextval('app_phase_lifecycle_id_seq'::regclass);
    ELSE
      NEW.phase_id = nextval('app_phase_phase_id_seq'::regclass);
      UPDATE app_phase
        SET superseded_by_id = NEW.phase_id
        WHERE lifecycle_id = NEW.lifecycle_id
        AND superseded_by_id IS NULL;
    END IF;

    RETURN NEW;
  END;

$$ LANGUAGE plpgsql VOLATILE;

CREATE TRIGGER lifecycle_and_supersedes
  BEFORE INSERT ON app_phase
  FOR EACH ROW
  EXECUTE PROCEDURE lifecycle_and_supersedes();

So, now all we need to do is prevent multiple-headed lifecycles. We can do this using a UNIQUE INDEX:

CREATE UNIQUE INDEX prevent_hydra_lifecycles
ON app_phase (lifecycle_id)
WHERE superseded_by_id IS NULL;

Wow, that was simple.

So, we have most of the db-level code written. How do we use our model? We can write some nice queryset methods to make getting the various bits easier:

class PhaseQuerySet(models.query.QuerySet):
    def current(self):
        return self.filter(superseded_by=None)

    def superseded(self):
        return self.exclude(superseded_by=None)

    def initial(self):
        return self.filter(supersedes=None)

    def snapshot_at(self, timestamp):
        return self.filter(timestamp__lte=timestamp).order_by('lifecycle_id', '-timestamp').distinct('lifecycle_id')

The queries generated by the ORM for these should be pretty good: we could look at sticking an index on the lifecycle_id column.

There is one more thing to say on the lifecycle: we can add a model method to fetch the complete lifecycle for a given phase, too:

    def lifecycle(self):
        return self.model.objects.filter(lifecycle_id=self.lifecycle_id)

(That was why I used the lifecycle_id as the column).


Whilst building this prototype, I came across a couple of things that were also interesting. The first was a mechanism to get the default value for a column:

from django.db import connection

def database_default(table, column):
    cursor = connection.cursor()
    QUERY = """SELECT d.adsrc AS default_value
               FROM   pg_catalog.pg_attribute a
               LEFT   JOIN pg_catalog.pg_attrdef d ON (a.attrelid, a.attnum)
                                                   = (d.adrelid,  d.adnum)
               WHERE  NOT a.attisdropped   -- no dropped (dead) columns
               AND    a.attnum > 0         -- no system columns
               AND    a.attrelid = %s::regclass
               AND    a.attname = %s"""
    cursor.execute(QUERY, [table, column])
    cursor.execute('SELECT {}'.format(*cursor.fetchone()))
    return cursor.fetchone()[0]

You can probably see why I didn’t want to use this. Other than the aforementioned two extra queries, it’s executing a query with data that comes back from the database. It may be possible to inject a default value into a table that causes it to do Very Bad Things™. We could sanitise it, perhaps ensure it matches a regular expression:

import re

NEXTVAL = re.compile(r"^nextval\('(?P<sequence>[a-zA-Z_0-9]+)'::regclass\)$")

However, the trigger-based approach is nicer in every way.

The other thing I discovered, and this one is really nice, is a way to create an exclusion constraint that only applies if a column is NULL. For instance, ensure that no two classes for a given student overlap, but only if they are not superseded (or deleted). (Note that the equality test on student_id within a GiST exclusion constraint requires the btree_gist extension.)

ALTER TABLE "student_enrolments"
ADD CONSTRAINT "prevent_overlaps"
EXCLUDE USING gist(period WITH &&, student_id WITH =)
WHERE (
  superseded_by_id IS NULL
  AND
  status <> 'deleted'
);

Django Proxy Model Relations

I’ve got lots of code I’d do a different way if I were to start over, but often, we have to live with what we have.

One situation I would seriously reconsider is the structure I use for storing data related to how I interact with external systems. I have an Application object, and I create instances of this for each external system I interact with. Each new Application gets a UUID, and is created as part of a migration. Code in the system uses this UUID to determine if something is for that system.

But that’s not the worst of it.

I also have an AppConfig object, and other related objects that store a relation to an Application. This was fine initially, but as my code got more complex, I hit upon the idea of using Django’s Proxy models, and using the related Application to determine the subclass. So, I have AppConfig subclasses for a range of systems. This is nice: we can even ensure that we only get the right instances (using a lookup to the application to get the discriminator, which I’d probably do a different way next time).

However, we also have other bits of information that we need to store, that has a relation to this AppConfig object.

And here is where we run into problems. Eventually, I had the need to subclass these other objects, and deal with them. That gives a similar benefit to the above for fetching filtered lists of objects; however, when we try to follow the relations between these, something annoying happens.

Instead of getting the subclass of AppConfig, that we probably want to use because the business logic hangs off that, we instead get the actual AppConfig instances. So, in order to get the subclass, we have to fetch the object again, or swizzle the __class__. And, going back the other way would have the same problem.
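
To make that concrete, the workaround looks something like this (SpecificAppConfig and the attribute name are stand-ins, not real names from the project):

config = something.app_config                         # a plain AppConfig comes back
config = SpecificAppConfig.objects.get(pk=config.pk)  # re-fetch as the proxy, or...
config.__class__ = SpecificAppConfig                  # ...swizzle the class in place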

Python is a dynamic language, so we should be able to do better.

In theory, all we have to do is replace the attributes on the two classes with ones that will do what we want them to do. In practice, we need to muck around a little bit more to make sure it all works out right.

It would be nice to be able to decorate the declaration of the overridden field, but that’s not valid python syntax:

>>> class Foo(object):
...   @override
...   bar = object()

So, we’ll have to do one of two things: alter the class after it has been defined, or leverage the metaclass magic Django already does.

class Foo(models.Model):
    bar = models.ForeignKey('bar.Bar')


class FooProxy(Foo):
    bar = ProxyForeignKey('bar.BarProxy')  # Note the proxy class

    class Meta:
      proxy = True

However, we can’t just use the contribute_to_class(cls, name) method as-is, as the Proxy model attributes get dealt with before the parent model. So, we need to register a signal, and get the framework to run our code after the class has been prepared:

class ProxyField(object):
    def __init__(self, field):
        self.field = field

    def contribute_to_class(self, model, name):
        @receiver(models.signals.class_prepared, sender=model, weak=False)
        def late_bind(sender, *args, **kwargs):
          override_model_field(model, name, self.field)


class ProxyForeignKey(ProxyField):
    def __init__(self, *args, **kwargs):
        super(ProxyForeignKey, self).__init__(ForeignKey(*args, **kwargs))

Then, it’s a matter of working out what needs to happen to override_model_field.

It turns out: not much. Until we start thinking about edge cases, anyway:

def override_model_field(model, name, field):
    original_field = model._meta.get_field(name)

    model.add_to_class(name, field)
    if field.rel:
        field.rel.to.add_to_class(
            field.related_name,
            ForeignRelatedObjectsDescriptor(field.related)
        )

There is a little more to it than that:

  • We need to use the passed-in related_name if one was provided in the new field definition, else we want to use what the original field’s related name was. However, if not explicitly set, then neither field will actually have a related_name attribute.
  • We cannot allow an override of a foreign key to a non-proxy model: that would hijack the original model’s related queryset.
  • Similarly, we can only allow a single proxy to override-relate to another proxy: any subsequent override-relations would likewise hijack the related object queryset.
  • For non-related fields, we can only allow an override if the field is compatible. What that means I’m not completely sure just yet, but for now, we will only allow the same field class (or a subclass). Things that would require a db change would be verboten.

So, we need to guard our override_model_field somewhat:

def override_model_field(model, name, field):
    original_field = model._meta.get_field(name)

    if not isinstance(field, original_field.__class__):
        raise TypeError('...')

    # Must do these checks before the `add_to_class`, otherwise it breaks tests.
    if field.rel:
        if not field.rel.to._meta.proxy:
            raise TypeError('...')

        related_name = getattr(field, 'related_name', original_field.related.get_accessor_name())
        related_model = getattr(field.rel.to, related_name).related.model

        # Do we already have an overridden relation to this model?
        if related_model._meta.proxy:
            raise TypeError('...')

    model.add_to_class(name, field)

    if field.rel:
        field.rel.to.add_to_class(
          related_name,
          ForeignRelatedObjectsDescriptor(field.related)
        )

There is an installable app that includes tests: django-proxy-overrides.