Why CustomUser subclasses are not such a good idea

Background

The system I work on has People who may or may not be Users, and very infrequently Users who may not be a Person. In fact, an extension to the system has meant that there will be more of these: a User who needs to be able to generate reports (say, a Franchisor who needs to only be able to access aggregate data from franchises, that might belong to multiple companies) who is never rostered on for shifts, which is what the Person class is all about.

Anyway, the long and the short of this was that I thought it might be a good idea to look at sub-classing User for ManagementUser.

I guess I should have listened to those smarter than me who shouted that sub-classing User is not cool. They never gave any concrete reasons, but now I have one.

You cannot easily convert a superclass object to a specialised sub-class. Once a user is a User, it’s hard to make them into a ManagementUser.

It can be done: the following code will take a User (or whatever) subclass, a User (or any parent class) object, and any other keyword arguments that should be passed into the constructor. It saves the newly upgraded object, and returns it.

def create_subclass(SubClass, old_instance, **kwargs):
    # Extra keyword arguments are passed to the subclass constructor.
    new_instance = SubClass(**kwargs)
    # Copy every local field value across from the parent instance.
    for field in old_instance._meta.local_fields:
        setattr(new_instance, field.name, getattr(old_instance, field.name))
    new_instance.save()
    return new_instance

However, it really should check that there isn’t an existing instance, and maybe some other checks.
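
The same copy-the-fields idea can be tried without Django. This is only a sketch using plain dataclasses: the User, ManagementUser and upgrade names are made up for the example, and nothing is saved anywhere. It also refuses to upgrade an object that is already an instance of the subclass, which is one of the checks mentioned above.

```python
from dataclasses import dataclass, fields


@dataclass
class User:
    username: str
    email: str = ""


@dataclass
class ManagementUser(User):
    region: str = ""


def upgrade(sub_class, old_instance, **kwargs):
    """Build a sub_class instance from the fields of old_instance."""
    if isinstance(old_instance, sub_class):
        raise TypeError("instance is already a %s" % sub_class.__name__)
    # Copy every field the parent instance has, then apply the extras.
    values = {f.name: getattr(old_instance, f.name) for f in fields(old_instance)}
    values.update(kwargs)
    return sub_class(**values)


manager = upgrade(ManagementUser, User("alice", "alice@example.com"), region="SA")
```

The original User object is left untouched: the "upgrade" is really a copy, which is exactly why it is awkward once a database row is involved.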

What advantages does sub-classing have?

The biggest advantage, or so I thought, was to have it so you can automatically downcast your models on user login, and then get access to the extended user details. For instance, if your authentication backend automatically converts User to Person, then you can get access to the Person’s attributes (like the company they work for, their shifts, etc) without an extra level of attribute access:

# request.user is always an auth.User instance:
request.user.person.company
# request.user might be a person, etc.
request.user.company

But it turns out that even this is bad. Now, in guard decorators on view functions, you cannot just test the value of an attribute, as not all users will have that attribute. Instead, you need to test to see if the attribute exists, and then test the attribute itself.

So, what do you do instead?

The preferred method in Django for extending User is to use a UserProfile class. This is just a model that has a OneToOneField linked back to User. I would look at doing a very small amount of duck-punching, just to make getting hold of the profile object easier:

import logging
from django.contrib.auth.models import User
from django.db import models

class Person(models.Model):
    # on_delete is required from Django 2.0 onwards.
    user = models.OneToOneField(User, related_name="_person", on_delete=models.CASCADE)
    date_of_birth = models.DateField(null=True, blank=True)

def get_person(user):
    # Returns None when the user has no related Person.
    try:
        return user._person
    except Person.DoesNotExist:
        return None

def set_person(user, person):
    user._person = person

if hasattr(User, 'person'):
    logging.error('Model User already has an attribute "person".')
else:
    User.person = property(get_person, set_person)

By having the person’s related name attribute as _person, we can wrap read access to it in an exception handler, and then use a view decorator like:

from django.contrib.auth.decorators import user_passes_test

@user_passes_test(lambda u: u.person)
def person_only_view(request, **kwargs):
    pass

We know this view will only be available to logged in users who have a related Person object.

I will point out that I am duck-punching/monkey-patching here. However, I feel that this particular method of doing it is relatively safe. I check before adding the property, and in reality I probably would raise an exception rather than just log an error.
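
As a rough sketch of the "raise rather than log" variant, here is the same check using plain Python classes, so it runs without Django. The add_property helper and the FakeUser class are made up for this example.

```python
def add_property(cls, name, getter, setter=None):
    """Attach a property to cls, refusing to overwrite an existing attribute."""
    if hasattr(cls, name):
        raise AttributeError(
            '%s already has an attribute "%s".' % (cls.__name__, name)
        )
    setattr(cls, name, property(getter, setter))


class FakeUser(object):
    """Stand-in for auth.User in this sketch."""
    _person = None


add_property(FakeUser, 'person', lambda user: user._person)

# A second attempt to patch the same name is rejected.
try:
    add_property(FakeUser, 'person', lambda user: None)
    patched_twice = True
except AttributeError:
    patched_twice = False
```

Failing loudly at import time means a clash with a future Django attribute would be noticed immediately, rather than quietly changing behaviour.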

Postgres and Django

Frank Wiles gave a great talk: Secrets of PostgreSQL Performance.

Don’t do dumb things

  • Dedicate a single server to your database
  • Only fetch what you need

Do smart things

  • cache everything
  • limit number of queries

Tuning

  • shared_buffers : 25% of available RAM
  • effective_cache_size : OS disk cache size
  • work_mem : in-memory sort size

Less important

  • wal_buffers : set to 16MB
  • checkpoint_segments : at least 10
  • maintenance_work_mem : 50MB for every GB of RAM

Can also transactionally turn on grouping of transactions.
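
The rules of thumb above are simple arithmetic. A small sketch (the function name and the dictionary layout are my own, and the numbers are only the starting points from these notes, not tuned values):

```python
def suggested_settings(ram_gb):
    """Return starting values (in MB unless noted) for a server with ram_gb of RAM."""
    ram_mb = ram_gb * 1024
    return {
        'shared_buffers': ram_mb // 4,          # 25% of available RAM
        'maintenance_work_mem': 50 * ram_gb,    # 50MB for every GB of RAM
        'wal_buffers': 16,                      # fixed at 16MB
        'checkpoint_segments': 10,              # a count: "at least 10"
    }


settings = suggested_settings(16)
```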

Hardware

  • As much RAM as you can afford - fit whole db if you can.
  • Faster disks.
    • Disk speed is important
    • RAID5 is bad
    • RAID-1+0 is good
    • WAL on own disk → 4x write performance
  • CPU speed - unlikely to be the limiting factor.

Other

  • use PgBouncer to pool connections
  • use tablespaces to move tables/indexes onto other disks
    • i.e., indexes on fastest disk
    • stuff that might run in background and hit only specific tables that are not used by other bits

Keyset Pagination in Django

Pagination is great. Nothing worse than having an HTML page that renders 25000 rows in a table.

Django Pagination is also great. It makes it super easy to declare that a view (that inherits from MultipleObjectMixin) should paginate its results:

class List(ListView):
    queryset = Foo.objects.order_by('bar', '-baz')
    paginate_by = 10
    template_name = 'foo.html'

Django pagination uses the LIMIT/OFFSET method. This is fine for smaller offsets, but once you start getting beyond a few pages, it can perform really badly. This is because the database needs to fetch all of the previous rows, even though it discards them.

Using Keyset Pagination allows for better performing “next page” fetches, at the cost of not being able to randomly fetch a page. That is, if you know the last element from page N-1, then you may fetch page N, but otherwise you really can’t.
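
To make the difference concrete, here is a small sketch using the sqlite3 module from the standard library rather than Django or Postgres. The table and column names are only for illustration. Both queries return the same second page, but the keyset version seeks past the last row of page one instead of counting rows to skip.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE foo (id INTEGER PRIMARY KEY, bar INTEGER NOT NULL)')
conn.executemany('INSERT INTO foo (id, bar) VALUES (?, ?)',
                 [(i, i % 4) for i in range(1, 21)])

PAGE_SIZE = 5

# OFFSET pagination: the database still has to walk past the skipped rows.
offset_page_2 = conn.execute(
    'SELECT id, bar FROM foo ORDER BY bar, id LIMIT ? OFFSET ?',
    (PAGE_SIZE, PAGE_SIZE),
).fetchall()

# Keyset pagination: remember the last row of page 1, then seek past it.
page_1 = conn.execute(
    'SELECT id, bar FROM foo ORDER BY bar, id LIMIT ?', (PAGE_SIZE,)
).fetchall()
last_id, last_bar = page_1[-1]
keyset_page_2 = conn.execute(
    'SELECT id, bar FROM foo '
    'WHERE bar > ? OR (bar = ? AND id > ?) '
    'ORDER BY bar, id LIMIT ?',
    (last_bar, last_bar, last_id, PAGE_SIZE),
).fetchall()
```

Note that the keyset query needs the values from the last row of the previous page, which is why jumping straight to an arbitrary page is not possible.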

Keyset Pagination, sometimes called the Seek Method, has been documented by Markus Winand and Joe Nelson. If you are not familiar with the concept, I strongly suggest you read the articles above.

Django’s pagination is somewhat pluggable: you may switch out the paginator in a Django ListView, for instance, allowing you to do things like switch the Page class, or how the various parts are computed. I used it recently to allow for a different query to be used when calculating the total number of objects in a queryset, to vastly improve performance of a particular paginated queryset.

However, there are limits. Both the view and the paginator expect (nay, demand) an integer page number, which, as we shall see shortly, will not work in this case. I also feel like the view is over-reaching its remit by casting the page number to an integer, as I’ll discuss below.

In order to get consistent results with any type of pagination, you must ensure that the ordering on the queryset is stable: that is, there are no rows that will be ‘tied’. Doing otherwise will mean that the database will “break the tie”, and not always in the same order. I’ve seen a bug that was extremely hard to track down that was caused by exactly this problem (and that was just with OFFSET pagination).

Often, to ensure stable ordering, the primary key is used as the “last” sort column. This is perfectly valid, but is not always necessary.

Because in many cases we will need to sort by multiple columns, we’ll need some mechanism for passing through to the paginator the “last value” in a given page for each of these columns. For instance, if we are ordering by timestamp and then group, we would need to pass through both the timestamp and the group of the last object. Because I like to use GET forms to allow me to paginate filtered results, I’ll want to have all of the values combined into one query parameter. If you were constructing links instead, you could look at having these as different parameters. However, you’d still need to be careful, because you aren’t filtering all results on these parameters. Having them serialised into a single parameter (using JSON) means that they are all in the one place, and you can just use that for the filtering to get the page results.
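
A minimal sketch of that serialisation, using only the standard library. The page_token function, the dict standing in for a model instance, and the "page" parameter name are assumptions for the example. Real timestamps would also need converting to something JSON can represent, which is glossed over here by using a string.

```python
import json
from urllib.parse import parse_qs, urlencode


def page_token(last_row, ordering):
    """Serialise the ordering values of the last row on a page into one string."""
    keys = [name.lstrip('-') for name in ordering]
    return json.dumps([last_row[key] for key in keys])


ordering = ['-timestamp', 'group']
last_row = {'timestamp': '2018-08-01T10:00:00', 'group': 'staff', 'id': 17}

token = page_token(last_row, ordering)
query_string = urlencode({'page': token})

# The receiving side gets all of the values back from the one parameter.
decoded = json.loads(parse_qs(query_string)['page'][0])
```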

I’ve built a working implementation of keyset pagination, at least for forwards traversal, at django-keyset-pagination.

We can see from this that there really is not that much that we needed to do. We use a different Page object, which enables us to change what the next_page_number will generate. When I figure out how, it will also allow us to work out the previous_page_number.

Likewise, we needed to change how we validate a page number, and how we fetch results for a page. That method, _get_page(number) is the one that does most of the work.

Ultimately, we wind up with a filter that looks like:

  WHERE (A < ?) OR (A = ? AND B > ?) OR (A = ? AND B = ? AND C < ?)

The direction of the test (< vs >) depends upon the sorting of that column, but hopefully you get the idea.
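
As a sketch of how that filter could be generated from an ordering (this is not the code from django-keyset-pagination, just an illustration with a made-up function name), using Django-style names where a leading "-" means descending:

```python
def keyset_where(ordering):
    """Build the WHERE fragment and the order in which values must be bound.

    ordering uses Django-style names: a leading '-' means descending.
    """
    clauses = []
    bind_order = []
    for index, name in enumerate(ordering):
        column = name.lstrip('-')
        operator = '<' if name.startswith('-') else '>'
        # All earlier columns must be equal; this column must be "after".
        parts = ['%s = ?' % prev.lstrip('-') for prev in ordering[:index]]
        parts.append('%s %s ?' % (column, operator))
        clauses.append('(%s)' % ' AND '.join(parts))
        bind_order.extend(prev.lstrip('-') for prev in ordering[:index])
        bind_order.append(column)
    return ' OR '.join(clauses), bind_order


sql, bind_order = keyset_where(['-A', 'B', '-C'])
```

With that ordering, sql is the same expression as the filter shown above, and bind_order lists which column's "last value" fills each placeholder.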

In order to enable the query planner to be able to use an index effectively (if one exists), we need to adjust this to (thanks Markus):

WHERE A <= ? AND (
  (A < ?) OR (A = ? AND B > ?) ...
)

It’s also possible, in Postgres at least, to use a ROW() constructor comparison to order rows. However, this only works if the direction of each column ordering is the same, which in my use case it was not. I have a proof of concept of using ROW() constructors, but I need to figure out how to detect if they are available to the database in use.

In order to use the new paginator, we need to work around some issues in the Django class based views: namely that they force an integer value (or use the special string last, neither of which are acceptable in this case):

from django.core.paginator import InvalidPage
from django.http import Http404
from django.utils.translation import gettext as _


class PaginateMixin(object):
    "Make pagination work for non integer page numbers"
    def paginate_queryset(self, queryset, page_size):
        # This is very similar to how django currently (2.1) does it: I may submit a PR to use this
        # mechanism instead, as it is more flexible.
        paginator = self.get_paginator(
            queryset, page_size, orphans=self.get_paginate_orphans(),
            allow_empty_first_page=self.get_allow_empty()
        )
        page_kwarg = self.page_kwarg
        page_number = self.kwargs.get(page_kwarg) or self.request.GET.get(page_kwarg) or 1

        # Let the paginator decide what a valid page "number" looks like.
        try:
            page_number = paginator.validate_number(page_number)
        except ValueError:
            raise Http404(_('Page could not be parsed.'))

        try:
            page = paginator.page(page_number)
            return (paginator, page, page.object_list, page.has_other_pages())
        except InvalidPage as e:
            raise Http404(
                _('Invalid page (%(page_number)s): %(message)s') % {
                    'page_number': page_number,
                    'message': str(e)
                }
            )

There’s really only one change: instead of just casting the page number to an integer, we let the paginator handle that.
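
To illustrate what "letting the paginator handle that" might mean, here is a sketch of just the validation step, written without Django so it can be run on its own. The class name and its rules are assumptions for this example and are not the actual django-keyset-pagination code. It accepts a JSON list with one value per ordering column, and raises ValueError otherwise, which is what the mixin above turns into a 404.

```python
import json


class SketchKeysetPaginator(object):
    """Only the validation step; not the real django-keyset-pagination code."""

    def __init__(self, ordering):
        self.ordering = ordering

    def validate_number(self, number):
        # The first page has no "last row" to seek past.
        if number in (None, '', 1, '1'):
            return None
        try:
            values = json.loads(number)
        except (TypeError, ValueError):
            raise ValueError('Page could not be parsed.')
        if not isinstance(values, list) or len(values) != len(self.ordering):
            raise ValueError('Page does not match the ordering.')
        return values


paginator = SketchKeysetPaginator(['timestamp', 'bar', 'baz'])
parsed = paginator.validate_number('["2018-08-01", 3, "x"]')

# An ordinary integer page number is not a valid keyset "page".
try:
    paginator.validate_number('7')
    rejected = False
except ValueError:
    rejected = True
```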

Okay, once all that is done, we can use our paginator:

class List(PaginateMixin, ListView):
    paginator_class = KeysetPaginator
    paginate_by = 10

    def get_queryset(self):
        return Foo.objects.order_by('timestamp', 'bar', 'baz')

We’ll need to change our template rendering to only render a next page link or button, rather than trying to render them for each page. We also don’t have any way to return to the previous page: I’m still working through a mechanism for that.


This post was originally written using the ROW() constructor, and this part of the post discussed the shortcomings. Now that has been resolved, the main shortcoming is that it is not yet possible to traverse to the previous page of results. In many cases that may not be necessary (we could use a browser’s back button, or rely on the fact that if it’s infinite scrolling the data is already in the document), however I would like to investigate how hard it is to actually get the previous page.

Set-returning and row-accepting functions in Django and Postgres

Postgres set-returning functions are an awesome thing. With them, you can do fun things like unnesting an array, and will end up with a new row for each item in the array. For example:

class Post(models.Model):
    author = models.ForeignKey(AUTH_USER_MODEL, related_name='posts')
    tags = ArrayField(base_field=TextField(), null=True, blank=True)
    created_at = models.DateTimeField()
    content = models.TextField()

The equivalent SQL might be something like:

CREATE TABLE blog_post (
  id SERIAL NOT NULL PRIMARY KEY,
  author_id INTEGER NOT NULL REFERENCES auth_user (id),
  tags TEXT[],
  created_at TIMESTAMPTZ NOT NULL,
  content TEXT NOT NULL
);

We can “explode” the table so that we have one tag per row:

SELECT author_id, UNNEST(tags) AS tag, created_at, content
FROM blog_post;
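
If it helps to see the shape of the result, the same "explode" can be mimicked in plain Python. The sample data is invented, and this is only an analogy for what the database does:

```python
posts = [
    {'author_id': 1, 'tags': ['django', 'postgres'], 'content': 'First'},
    {'author_id': 2, 'tags': ['python'], 'content': 'Second'},
    {'author_id': 3, 'tags': [], 'content': 'Third'},
]

# One output row per (post, tag) pair; a post with no tags yields no rows.
exploded = [
    {'author_id': post['author_id'], 'tag': tag, 'content': post['content']}
    for post in posts
    for tag in post['tags']
]
```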

To do the same sort of thing in Django, we can use a Func:

from django.db.models import F, Func

Post.objects.annotate(tag=Func(F('tags'), function='UNNEST'))

In practice, just like in the Django docs, I’ll create a convenience function:

class Unnest(Func):
    function = 'UNNEST'

    @property
    def output_field(self):
        output_fields = [x.output_field for x in self.get_source_expressions()]
        if len(output_fields) == 1:
            return output_fields[0].base_field

        return super(Unnest, self).output_field

The opposite of this is aggregation: in the case of UNNEST, it’s almost ARRAY_AGG, although because of the handling of nested arrays, this doesn’t quite round-trip. We already know how to do aggregation in Django, so I will not discuss that here.

However, there is another related operation: what if you want to turn a row into something else? In my case, this was turning a row from a result into a JSON object.

SELECT id,
       to_jsonb(myapp_mymodel) - 'id' AS "json"
  FROM myapp_mymodel

This will get all of the columns except ‘id’, and put them into a new column called “json”.

But how do we get Django to output SQL that will enable us to use a Model as the argument to a function? Ultimately, we want to get to the following:

class ToJSONB(Func):
    function = 'TO_JSONB'
    output_field = JSONField()


MyModel.objects.annotate(
  json=ToJSONB(MyModel) - Value('id')
).values('id', 'json')

Our first attempt could be to use RawSQL. However, this has a couple of problems. The first is that we are writing lots of raw SQL, the second is that it won’t work so well if the table is aliased by the ORM. That is, if you use this in a join or subquery, where Django automatically assigns an alias to this table, then referring directly to the table name will not work.

MyModel.objects.annotate(json=RawSQL("to_jsonb(myapp_mymodel) - 'id'", [], output_field=JSONField()))

Instead, we need to dynamically find out what the current alias for the model is in this query, and use that. We’ll also want to figure out how to “subtract” the id key from the JSON object.

class Table(django.db.models.Expression):
    def __init__(self, model, *args, **kwargs):
        self.model = model
        self.query = None
        super(Table, self).__init__(*args, **kwargs)

    def resolve_expression(self, query, *args, **kwargs):
        clone = super(Table, self).resolve_expression(query, *args, **kwargs)
        clone.query = query
        return clone

    def as_sql(self, compiler, connection, **kwargs):
        if not self.query:
            raise ValueError('Unresolved Table expression')
        alias = self.query.table_map.get(self.model._meta.db_table, [self.model._meta.db_table])[0]
        return compiler.quote_name_unless_alias(alias), []

Okay, there’s a fair bit going on there. Let’s look through it. We’ll start with how we use it:

MyModel.objects.annotate(json=ToJSONB(Table(MyModel)))

We create a Table instance, which stores a reference to the model. Technically, all we need later on is the database table name that will be used, but we’ll keep the model for now.

When the ORM “resolves” the queryset, we grab the query object, and store a reference to that.

When the ORM asks us to generate some SQL, we look at the query object we have a reference to, and see if our model’s table name has an entry in the table_map dict: if so, we get the first entry from that, otherwise we just use the table name.
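
The lookup itself is just a dict access with a fallback. A tiny sketch, where the dicts only imitate the shape of the query's table_map (table name mapped to a list of aliases) and the alias names are invented:

```python
def current_alias(table_map, db_table):
    """Return the first alias the query uses for db_table, or the table name."""
    return table_map.get(db_table, [db_table])[0]


# No entry for the table: fall back to the table name itself.
plain = current_alias({}, 'myapp_mymodel')
# An entry exists: use the first alias recorded for the table.
aliased = current_alias({'myapp_mymodel': ['T3', 'T4']}, 'myapp_mymodel')
```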

Okay, what about being able to remove the entry in the JSONB object for ‘id’?

We can’t just use the subtraction operator, because Postgres will try to convert the RHS value into JSONB first, and fail. So, we need to ensure it renders it as TEXT. We also need to wrap it in an ExpressionWrapper, so we can indicate what the output field type will be:

id_value = models.Func(models.Value('id'), template='%(expressions)s::TEXT')
MyModel.objects.annotate(
    json=ExpressionWrapper(
        ToJSONB(Table(MyModel)) - id_value, output_field=JSONField()
    )
)

I also often use a convenience Cast function, that automatically does this based on the supplied output_field, but this is a little easier to use here. Note there is a possible use for ToJSONB in a different context, where it doesn’t take a table, but some other primitive.


There’s one more way we can use this construct: the geo_unique_indexer function from a previous post needs a table name, but also the name of a field to omit from the index. So, we can wrap this up nicely:

class GeoMatch(models.Func):
    function = 'geo_unique_indexer'
    output_field = JSONField()

    def __init__(self, model, *args, **kwargs):
        table = Table(model)
        pk = models.Value(model._meta.pk.db_column or model._meta.pk.name)
        super(GeoMatch, self).__init__(table, pk, *args, **kwargs)

This is really tidy: it takes the model class (or maybe an instance, I didn’t try), and builds a Table, and gets the primary key. These are just used as the arguments for the function, and then it all works.

Django ComputedField()

A very common pattern, at least in code that I’ve written (and read) is to annotate on a field that uses an expression that is based on one or more other fields. This could then be used to filter the objects, or just in some other way.

The usual method of doing this is:

from django.db import models
from django.db.models.expressions import F, Value
from django.db.models.functions import Concat


class PersonQuerySet(models.query.QuerySet):
    def with_name(self):
        return self.annotate(
            name=Concat(F('first_name'), Value(' '), F('last_name'), output_field=models.TextField()),
        )


class Person(models.Model):
    first_name = models.TextField()
    last_name = models.TextField()

    objects = PersonQuerySet.as_manager()

Yes, I’m aware of falsehoods programmers believe about names, but this is an easy-to-follow example.

In order to be able to access the name field, we must use the with_name() queryset method. This is usually okay, but if it is something that we almost always want, it can be a little tiresome. Alternatively, you could override the get_queryset() method of a custom manager, but that makes it somewhat surprising to a reader of the code. There are also some places where a custom manager will not automatically be used, or where it will be cumbersome to include the fields from a custom manager (select_related, for instance).

It would be much nicer if we could write the field declaratively, and have it use the normal django mechanism of defer and only to remove it from the query if required.

class Person(models.Model):
    first_name = models.TextField()
    last_name = models.TextField()
    name = ComputedField(Concat(F('first_name'), Value(' '), F('last_name'), output_field=models.TextField()))

I’ve spent some time digging around in the django source code, and have a fairly reasonable understanding of how fields work, and how queries are built up. But I did wonder how close to a working proof of concept of this type of field we could get without having to change any of the django source code. After all, I was able to backport the entire Subquery expression stuff to older versions of django after writing that. It would be nice to repeat the process here.

There are a few things you need to do to get this to work:

  • store the expression
  • prevent the field from creating a migration
  • ensure the field knows how to interpret data from the database
  • ensure the field adds the expression to its serialised version
  • prevent the field from writing data back to the database
  • inject the expression into the query instead of the field name

Storing the expression happens in the constructor, which also marks the field as non-editable:

class ComputedField(models.Field):
    def __init__(self, expression, *args, **kwargs):
        self.expression = expression.copy()
        kwargs.update(editable=False)
        super().__init__(*args, **kwargs)

There is already a mechanism for a field to prevent a migration operation from being generated: it can return a db_type of None.

    def db_type(self, connection):
        return None

We can delegate the responsibility of interpreting the data from the database to the output field of the expression - that’s how it works in the normal operation of expressions.

    def from_db_value(self, value, expression, connection):
        return self.expression.output_field.from_db_value(value, expression, connection)

Storing the expression in the serialised version of a field is explained in the documentation on custom fields:

    def deconstruct(self):
        name, path, args, kwargs = super().deconstruct()
        return name, path, [self.expression] + args, kwargs

To prevent the field from being included in the data we write back to the database turned out to be fairly tricky. There are a couple of mechanisms that can be used, but ultimately the only one that worked in the way I needed was something that is used by the inheritance mechanism. We have to indicate that it is a “private” field. I’m not 100% sure of what the other implications of this might be, but the outcome of making this field private is that it no longer appears in the list of local fields. There is one drawback to this, which I’ll discuss below.

    def contribute_to_class(self, cls, name, private_only=False):
        return super().contribute_to_class(cls, name, True)

So, we only have one task to complete. How do we inject the expression into the query instead of the column?

When django evaluates a queryset, it looks at the annotations, and the expressions that are in these. It will then “resolve” these expressions (which means the expression gets told which “query” is being used to evaluate it, allowing it to do whatever it needs to do to make things work).

When a regular field is encountered, it is not resolved: instead it is turned into a Col. This happens in a few different places, but the problem is that a Col should not need to know which query it belongs to: at most it needs to know what the aliased table name is. So, we don’t have a query object we can pass to the resolve_expression method of our expression.

Instead, we’ll need to use Python’s introspection to look up the stack until we find a place that has a reference to this query.

    def get_col(self, alias, output_field=None):
        import inspect

        query = None

        for frame in inspect.stack():
            if frame.function in ['get_default_columns', 'get_order_by']:
                query = frame.frame.f_locals['self'].query
                break
            if frame.function in ['add_fields', 'build_filter']:
                query = frame.frame.f_locals['self']
                break
        else:
            # Aaargh! We don't handle this one yet!
            import pdb; pdb.set_trace()

        col = self.expression.resolve_expression(query=query)
        col.target = self
        return col

So, how does this code actually work? We go through each frame in the stack, and look for a function (or method, but they are really just functions in python) that matches one of the types we know about that have a reference to the query. Then, we grab that, stop iterating and resolve our expression. We have to set the “target” of our resolved expression to the original field, which is how the Col interface works.
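
The stack-walking trick can be seen in isolation with a few lines of plain Python. FakeCompiler and find_query are stand-ins invented for this sketch; only the method name matters, because that is what the search matches on.

```python
import inspect


def find_query():
    """Walk up the call stack looking for a frame known to hold a query."""
    for frame_info in inspect.stack():
        if frame_info.function == 'get_default_columns':
            # In that frame, the local variable "self" is the compiler.
            return frame_info.frame.f_locals['self'].query
    return None


class FakeCompiler(object):
    """Stand-in for a SQL compiler; only the method name matters here."""

    def __init__(self, query):
        self.query = query

    def get_default_columns(self):
        # Note that nothing is passed down: find_query() discovers us.
        return find_query()


found = FakeCompiler(query='the-query').get_default_columns()
not_found = find_query()
```

This is fragile by design: it depends on the names of functions further up the stack, so any caller that is not in the known list finds nothing.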

This moves the resolve_expression into the get_col, which is where it needs to be. The (resolved) expression is used as the faked column, and it knows how to generate its own SQL, which will be put into the query in the correct location.

And this works, almost.

There is one more situation that needs to be taken into account: when we are referencing the field through a join (the x__y lookup syntax you often see in django filters).

Because F() expressions reference the local query, we need to first turn any of these that we find in our computed field’s expression (at any level) into a Col that refers to the correct model. We need to do this before the resolve_expression takes place.

    def get_col(self, alias, output_field=None):
        import inspect
        from django.db.models.expressions import Col

        query = None

        for frame in inspect.stack():
            if frame.function in ['get_default_columns', 'get_order_by']:
                query = frame.frame.f_locals['self'].query
                break
            if frame.function in ['add_fields', 'build_filter']:
                query = frame.frame.f_locals['self']
                break
        else:
            # Aaargh! We don't handle this one yet!
            import pdb; pdb.set_trace()

        def resolve_f(expression):
            if hasattr(expression, 'get_source_expressions'):
                expression = expression.copy()
                expression.set_source_expressions([
                  resolve_f(expr) for expr in expression.get_source_expressions()
                ])
            if isinstance(expression, models.F):
                field = self.model._meta.get_field(expression.name)
                if hasattr(field, 'expression'):
                    return resolve_f(field.expression)
                return Col(alias, field)
            return expression

        col = resolve_f(self.expression).resolve_expression(query=query)
        col.target = self
        return col
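
The recursive rewrite in resolve_f is easier to follow on a toy expression tree. In this sketch the F and Concat classes are simplified stand-ins (not Django's), and columns are represented as plain strings:

```python
class F(object):
    def __init__(self, name):
        self.name = name


class Concat(object):
    def __init__(self, *sources):
        self.sources = list(sources)


def resolve_f(expression, columns):
    """Return a copy of the tree with every F replaced by its column."""
    if isinstance(expression, F):
        return columns[expression.name]
    if isinstance(expression, Concat):
        # Rebuild the node rather than mutating the original tree.
        return Concat(*[resolve_f(source, columns) for source in expression.sources])
    return expression


tree = Concat(F('first_name'), ' ', Concat(F('last_name')))
resolved = resolve_f(tree, {'first_name': 'T2.first_name', 'last_name': 'T2.last_name'})
```

As in the real method, the original expression is copied rather than changed, so the field's stored expression can be resolved again for a different alias.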

There is a repo containing this, which has a bunch of tests showing how the different query types can use the computed field:

https://github.com/schinckel/django-computed-field


But wait, there is one more thing…

A very common requirement, especially if you are planning on using this column for filtering, would be to stick an index on there.

Unfortunately, that’s not currently possible: the mechanism for preventing the field name from being in the write queries, making it a private field, prevents using this field in an index. Anyway, function/expression indexes are not currently supported in Django.

It’s not all bad news though: Markus has a Pull Request that will enable this feature; from there we could (if db_index is set) automatically add an expression index to Model._meta.indexes in contribute_to_class, but it would also be great to be able to use index_together.

I suspect to get that, though, we’ll need another mechanism to prevent it being in the write queries, but still be a local field.

(Thanks to FunkyBob for suggestions, including suggesting the field at all).

Extracting values from environment variables in tox

Tox is a great tool for automated testing. We use it, not only to run matrix testing, but to run different types of tests in different environments, enabling us to parallelise our test runs, and get better reporting about what types of tests failed.

Recently, we started using Robot Framework for some automated UI testing. This needs to run a django server, and almost certainly wants to run against a different database. This will require our tox -e robot to drop the database if it exists, and then create it.

Because we use dj-database-url to provide our database settings, our Codeship configuration sets a DATABASE_URL environment variable. This contains the host, port and database name, as well as the username/password if applicable. However, we don’t have the database name (or port) directly available in their own environment variables.

Instead, I wanted to extract these out of the postgres://user:password@host:port/dbname string.

My tox environment also needed to ensure that a distinct database was used for robot:

[testenv:robot]
setenv=
  CELERY_ALWAYS_EAGER=True
  DATABASE_URL={env:DATABASE_URL}_robot
  PORT=55002
  BROWSER=headlesschrome
whitelist_externals=
  /bin/sh
commands=
  sh -c 'dropdb --if-exists $(echo {env:DATABASE_URL} | cut -d "/" -f 4)'
  sh -c 'createdb $(echo {env:DATABASE_URL} | cut -d "/" -f 4)'
  coverage run --parallel-mode --branch manage.py robot --runserver={env:PORT}

And this was working great. I’m also using the $PG_USER environment variable, which is supplied by Codeship, but that just clutters things up.

However, when merged to our main repo, which has its own Codeship environment, tests were failing. It would complain about the database not being present when attempting to run the robot tests.

It seems that we were using a different version of postgres, and thus were using a different port.

So, how can we extract the port from the $DATABASE_URL?

commands=
  sh -c 'dropdb --if-exists \
                -p $(echo {env:DATABASE_URL} | cut -d "/" -f 3 | cut -d ":" -f 3) \
                $(echo {env:DATABASE_URL} | cut -d "/" -f 4)'

Which is all well and good, until you have a $DATABASE_URL that omits the port…

dropdb: error: missing required argument database name

Ah, that would mean the command being executed was:

$ dropdb --if-exists -p  <database-name>

Eventually, I came up with the following:

sh -c 'export PG_PORT=$(echo {env:DATABASE_URL} | cut -d "/" -f 3 | cut -d ":" -f 3); \
              dropdb --if-exists \
                     -p $\{PG_PORT:-5432} \
                     $(echo {env:DATABASE_URL} | cut -d "/" -f 4)'

Whew, that is a mouthful!

We store the extracted value in a variable PG_PORT, and then use bash variable substitution (rather than tox variable substitution) to put it in, with a default value. But because of tox variable substitution, we need to escape the curly brace to allow it to be passed through to bash: $\{PG_PORT:-5432}. Also note that you’ll need a space after this before a line continuation, because bash seems to strip leading spaces from the continued line.
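
For comparison, the same extraction (including the default port) can be sketched in Python with the standard library's URL parser. This is not a drop-in replacement for the tox command above, and the function name is made up; it just shows the logic the shell pipeline is implementing.

```python
from urllib.parse import urlparse


def database_parts(database_url, default_port=5432):
    """Return (port, database name) from a postgres:// style URL."""
    parsed = urlparse(database_url)
    # parsed.port is None when the URL omits the port.
    port = parsed.port or default_port
    return port, parsed.path.lstrip('/')


with_port = database_parts('postgres://user:password@host:5433/dbname')
without_port = database_parts('postgres://user:password@host/dbname_robot')
```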

Postgres VIEW from Django QuerySet

It’s already possible, given an existing Postgres (or other database) VIEW, to stick a Django Model in front of it, and have it fetch data from that instead of a table.

Creating the views can currently be done using raw SQL (and a RunSQL migration operation), or using some helpers to store the SQL in files for easy versioning.

It would be excellent if it was possible to use Django’s ORM to actually generate the VIEW, and even better if you could make the migration autodetector generate migrations.

But why would this be necessary? Surely, if you were able to create a QuerySet instance that contains the items in your view, that should be good enough?

Not quite, because currently using the ORM it is not possible to perform the following type of query:

SELECT foo.a,
       foo.b,
       bar.d
  FROM foo
  INNER JOIN (
    SELECT baz.a,
           ARRAY_AGG(baz.c) AS d
      FROM baz
     GROUP BY baz.a) bar ON (foo.a = bar.a)

That is, generating a join to a subquery is not possible in the ORM. In this case, you could probably get away with a correlated Subquery, however that would probably not perform as well as using a join in this case. This is because a subquery in a SELECT is evaluated once for each row, whereas a subquery join will be evaluated once.
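
The two query shapes can be compared side by side using sqlite3 from the standard library. This is only a sketch: SQLite has no ARRAY_AGG, so SUM stands in for the aggregate, and the sample data is invented. The results are the same; the difference is in how the database is asked to get there.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.executescript("""
    CREATE TABLE foo (a INTEGER PRIMARY KEY, b TEXT);
    CREATE TABLE baz (a INTEGER, c INTEGER);
    INSERT INTO foo VALUES (1, 'one'), (2, 'two');
    INSERT INTO baz VALUES (1, 10), (1, 11), (2, 20);
""")

# Join to a subquery: the aggregate is computed once, then joined.
joined = conn.execute("""
    SELECT foo.a, foo.b, bar.d
      FROM foo
     INNER JOIN (
        SELECT baz.a, SUM(baz.c) AS d
          FROM baz
         GROUP BY baz.a) bar ON (foo.a = bar.a)
     ORDER BY foo.a
""").fetchall()

# Correlated subquery in the SELECT list: conceptually run per row of foo.
correlated = conn.execute("""
    SELECT foo.a, foo.b,
           (SELECT SUM(baz.c) FROM baz WHERE baz.a = foo.a) AS d
      FROM foo
     ORDER BY foo.a
""").fetchall()
```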

So, we could use a VIEW for the subquery component:

CREATE OR REPLACE VIEW bar AS

SELECT baz.a,
       ARRAY_AGG(baz.c) AS d
  FROM baz
 GROUP BY baz.a;

And then stick a model in front of that, and join accordingly:

SELECT foo.a,
       foo.b,
       bar.d
  FROM foo
 INNER JOIN bar ON (foo.a = bar.a)

The Django model for the view would look something like:

class Bar(models.Model):
    a = models.OneToOneField(
        'foo.Foo',
        on_delete=models.DO_NOTHING,
        primary_key=True,
        related_name='bar'
    )
    d = django.contrib.postgres.fields.ArrayField(
        base_field=models.TextField()
    )

    class Meta:
        managed = False

The on_delete=models.DO_NOTHING is important: without it, a delete of a Foo instance would trigger an attempted delete of a Bar instance - which would cause a database error, because it’s coming from a VIEW instead of a TABLE.

Then, we’d be able to use:

queryset = Foo.objects.select_related('bar')

So, that’s the logic behind needing to be able to do a subquery, and it becomes even more compelling if you need that subquery/view to filter the objects, or perform some other expression/operation. So, how can we make Django emit code that will enable us to handle that?

There are two problems:

  • Turn a queryset into a VIEW.
  • Get the migration autodetector to trigger VIEW creation.

The other day I came across Create Table As Select in Django, and it made me realise that we can use basically the same logic for creating a view. So, we can create a migration operation that will perform this for us:

class CreateOrReplaceView(Operation):
    def __init__(self, view_name, queryset):
        self.view_name = view_name
        self.queryset = queryset

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        queryset = self.queryset
        compiler = queryset.query.get_compiler(using=schema_editor.connection.alias)
        sql, params = compiler.as_sql()
        sql = 'CREATE OR REPLACE VIEW {view} AS {sql}'.format(
            view=schema_editor.connection.ops.quote_name(self.view_name),
            sql=sql
        )
        schema_editor.execute(sql, params)

    def state_forwards(self, app_label, state):
        pass

We can then use this operation in a migration file, passing it the queryset that defines the view.

This doesn’t really solve how to define the queryset for the view, and have some mechanism for resolving changes made to that queryset (so we can generate a new migration if necessary). It also means we have a queryset written in our migration operation. We won’t be able to leave it like that: due to loading issues, you won’t be able to import model classes during the migration setup - and even if you could, you shouldn’t be accessing them during a migration anyway - you should use models from the ProjectState which is tied to where in the migration graph you currently are.

What would be excellent is if we could write something like:

class Bar(models.Model):
    a = models.OneToOneField(
        'foo.Foo',
        on_delete=models.DO_NOTHING,
        primary_key=True,
        related_name='bar',
    )
    d = django.contrib.postgres.fields.ArrayField(
        base_field=models.TextField()
    )

    class Meta:
        managed = False

    @property
    def view_queryset(self):
        return Baz.objects.values('a').annotate(d=ArrayAgg('c'))

And then, if we change our view definition:

@property
def view_queryset(self):
    return Baz.objects.values('a').filter(
        c__startswith='qux',
    ).annotate(
        d=ArrayAgg('c')
    )

… we would want a migration operation generated that includes the new queryset, or at least be able to know that it has changed. Ideally, we’d want to have a queryset attribute inside the Meta class in our model, which could itself be a property. However, that’s not possible without making changes to django itself.

In the meantime, we can borrow the pattern used by RunPython to have a callable that is passed some parameters during application of the migration, which returns the queryset. We can then have a migration file that looks somewhat like:

def view_queryset(apps, schema_editor):
    Baz = apps.get_model('foo', 'Baz')

    return Baz.objects.values('a').filter(
        c__startswith='qux'
    ).annotate(
        d=ArrayAgg('c')
    )


class Migration(migrations.Migration):
    dependencies = [
        ('foo', '0001_initial'),
    ]

    operations = [
        migrations.CreateModel(
            name='Bar',
            fields=[
                ('a', models.OneToOneField(...)),
                ('d', ArrayField(base_field=models.TextField(), ...)),
            ],
            options={
                'managed': False,
            }
        ),
        CreateOrReplaceView('Bar', view_queryset),
    ]

We still need to have the CreateModel statement so Django knows about our model, but the important bit in this file is the CreateOrReplaceView, which references the callable.

Now for the actual migration operation.

class CreateOrReplaceView(migrations.Operation):
    def __init__(self, model, queryset_factory):
        self.model = model
        self.queryset_factory = queryset_factory

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        model = from_state.apps.get_model(app_label, self.model)
        queryset = self.queryset_factory(from_state.apps, schema_editor)
        compiler = queryset.query.get_compiler(using=schema_editor.connection.alias)
        sql, params = compiler.as_sql()
        sql = 'CREATE OR REPLACE VIEW {view_name} AS {query}'.format(
            view_name=model._meta.db_table,
            query=sql,
        )
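        schema_editor.execute(sql, params)

    def state_forwards(self, app_label, state):
        # Operation subclasses must provide this. The view itself does not
        # alter the project state: CreateModel has already added the model.
        pass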

The backwards migration is not quite a solved problem yet: I do have a working solution that steps up the stack to determine what the current migration name is, and then finds the previous migration that contains one of these operations for this model, but that’s a bit nasty.


There’s no (clean) way to inject ourselves into the migration autodetector and “notice” when we need to generate a new version of the view. However, we can leverage the checks framework to notify the user when our view queryset is out of date compared to the latest migration.

from django.apps import apps
from django.core.checks import register

@register()
def check_view_definitions(app_configs, **kwargs):
    errors = []

    if app_configs is None:
        app_configs = apps.app_configs.values()

    for app_config in app_configs:
        errors.extend(_check_view_definitions(app_config))

    return errors

And then we need to implement _check_view_definitions:

def get_out_of_date_views(app_config):
    app_name = app_config.name

    view_models = [
        model
        # We need the real app_config, not the migration one.
        for model in apps.get_app_config(app_name.split('.')[-1]).get_models()
        if not model._meta.managed and hasattr(model, 'get_view_queryset')
    ]

    for model in view_models:
        latest = get_latest_queryset(model)
        current = model.get_view_queryset()

        if latest is None or current.query.sql_with_params() != latest.query.sql_with_params():
            yield MissingViewMigration(
                model,
                current,
                latest,
                Warning(W003.format(app_name=app_name, model_name=model._meta.model_name), id='sql_helpers.W003'),
            )


def _check_view_definitions(app_config):
    return [x.warning for x in get_out_of_date_views(app_config)]
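
Stripped of the Django plumbing, the comparison boils down to “is the compiled query for this view the same as the one in the most recent migration?”. A minimal sketch of just that step, assuming each query has already been reduced to an (sql, params) pair, might look like this. The names CompiledQuery and out_of_date_views are made up for the example, and are not part of Django or of my app.

```python
from collections import namedtuple

# Hypothetical stand-in for the result of query.sql_with_params().
CompiledQuery = namedtuple('CompiledQuery', ['sql', 'params'])


def out_of_date_views(current, latest):
    """
    Yield the names of views that need a new migration.

    `current` maps view name -> CompiledQuery from the model as it is now.
    `latest` maps view name -> CompiledQuery from the newest migration; a
    view that has never been migrated is simply missing from this mapping.
    """
    for name, query in sorted(current.items()):
        # A missing entry gives None, which never equals a CompiledQuery.
        if latest.get(name) != query:
            yield name


current = {
    'bar': CompiledQuery('SELECT a, ARRAY_AGG(c) FROM baz WHERE c LIKE %s GROUP BY a', ('qux%',)),
    'other': CompiledQuery('SELECT 1', ()),
    'unmigrated': CompiledQuery('SELECT 2', ()),
}
latest = {
    'bar': CompiledQuery('SELECT a, ARRAY_AGG(c) FROM baz GROUP BY a', ()),
    'other': CompiledQuery('SELECT 1', ()),
}
stale = list(out_of_date_views(current, latest))
```

Comparing the parameters as well as the SQL matters, because changing only a filter value leaves the SQL text untouched.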

The last puzzle piece there is get_latest_queryset, which is a bit more complicated:

def get_latest_queryset(model, before=None):
    from django.db.migrations.loader import MigrationLoader
    from django.db import connection

    migration_loader = MigrationLoader(None)
    migrations = dict(migration_loader.disk_migrations)
    migration_loader.build_graph()
    state = migration_loader.project_state()
    app_label = model._meta.app_label
    root_node = dict(migration_loader.graph.root_nodes())[app_label]
    # We want to skip any migrations in our reverse list until we have
    # hit a specific node: however, if that is not supplied, it means
    # we don't skip any.
    seen_before = not before
    for node in migration_loader.graph.backwards_plan((app_label, root_node)):
        if node == before:
            seen_before = True
            continue
        if not seen_before:
            continue
        migration = migrations[node]
        for operation in migration.operations:
            if (
                isinstance(operation, CreateOrReplaceView) and
                operation.model.lower() == model._meta.model_name.lower()
            ):
                return operation.queryset_factory(state.apps, connection.schema_editor())

This also has code to allow us to pass in a node (before), which limits the search to migrations that occur before that node in the forwards migration plan.
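
The skipping logic is easier to see in isolation. Here is a sketch of the same idea using plain strings for migration names: the plan is assumed to be ordered newest first (as the backwards plan is), and latest_matching is a name I’ve invented for the illustration.

```python
def latest_matching(plan, matches, before=None):
    """
    Return the first node in `plan` (ordered newest first) for which
    `matches(node)` is true, only considering nodes that come after
    `before` in the plan. With no `before`, every node is considered.
    """
    seen_before = before is None
    for node in plan:
        if node == before:
            # Everything after this point is older than `before`.
            seen_before = True
            continue
        if not seen_before:
            continue
        if matches(node):
            return node
    return None


plan = ['0004_view', '0003_other', '0002_view', '0001_initial']


def has_view(name):
    return name.endswith('_view')


newest = latest_matching(plan, has_view)
previous = latest_matching(plan, has_view, before='0004_view')
nothing = latest_matching(plan, has_view, before='0002_view')
```

This is the lookup a backwards migration would need: given the migration being unapplied, find the view definition that was in effect before it.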

Since we already have the bits in place, we could also have a management command that creates a stub migration (without the queryset factory, that’s a problem I haven’t yet solved). I’ve built this into my related “load SQL from files” app.


This is still a bit of a work in progress, but writing it down helped me clarify some concepts.

Django properties from expressions, or ComputedField part 2

I’ve discussed the concept of a ComputedField in the past. On the weekend, a friend pointed me towards SQLAlchemy’s Hybrid Attributes. The main difference here is that in a ComputedField, the calculation is always done in the database. Thus, if a change is made to the model instance (and it is not yet saved), then the ComputedField will not change its value. Let’s look at an example from that original post:

class Person(models.Model):
    first_name = models.TextField()
    last_name = models.TextField()
    display_name = ComputedField(
        Concat(F('first_name'), Value(' '), F('last_name')),
        output_field=models.TextField()
    )

We can use this to query, or as an attribute:

Person.objects.filter(display_name__startswith='foo')
Person.objects.first().display_name

But, if we make changes, we don’t see them until we re-query:

person = Person(first_name='Fred', last_name='Jones')
person.display_name  # This is not set

So, it got me thinking. Is it possible to turn a django ORM expression into python code that can execute and have the same output?

And, perhaps the syntax SQLAlchemy uses is nicer?

class Person(models.Model):
    first_name = models.TextField()
    last_name = models.TextField()

    @shared_property
    def display_name(self):
        return Concat(
            F('first_name'),
            Value(' '),
            F('last_name'),
            output_field=models.TextField(),
        )

The advantage to using the decorator approach is that you could have a more complex expression - but perhaps that is actually a disadvantage. It might be nice to ensure that the code can be turned into a python function, after all.


The first step is to get the expression we need to convert to a python function. Writing a python decorator will give us access to the “function” object. As long as it does not refer to self at all, we can call it without an instance by passing None instead:

class shared_property(object):
    def __init__(self, function):
        expression = function(None)

This gives us the expression object. Because this is a python object, we can just look at it directly, and turn that into an AST. Having a class for parsing this makes things a bit simpler. Let’s look at a parser that can handle this expression.

import ast


class Parser:
    def __init__(self, function):
        # Make a copy, in case this expression is used elsewhere, and we change it.
        # Keep a reference to it: the expression (and its output_field) is needed later.
        self.expression = function(None).copy()
        tree = self.build_expression(self.expression)
        # Need to turn this into code...
        self.code = compile(tree, mode='eval', filename=function.__code__.co_filename)

    def build_expression(self, expression):
        # Dynamically find the method we need to call to handle this expression.
        return getattr(self, 'handle_{}'.format(expression.__class__.__name__.lower()))(expression)

    def handle_concat(self, concat):
        # A Concat() contains only one source expression: ConcatPair().
        return self.build_expression(*concat.get_source_expressions())

    def handle_concatpair(self, pair):
        left, right = pair.get_source_expressions()
        return ast.BinOp(
            left=self.build_expression(left),
            op=ast.Add(),
            right=self.build_expression(right),
        )

    def handle_f(self, f):
        # Probably some more work here around transforms/lookups...
        # Set this, because without it we get errors. Will have to
        # figure out a better way to handle this later...
        f.contains_aggregate = False
        return ast.Attribute(
            value=ast.Name(id='self'),
            attr=f.name,
        )

    def handle_value(self, value):
        # ast.Constant (Python 3.8+) represents None, strings, numbers and
        # booleans alike, so a bool is never mistaken for an int here
        # (bool is a subclass of int).
        if value.value is None or isinstance(value.value, (str, int, float, bool)):
            return ast.Constant(value=value.value)

        # ... others?
        raise ValueError('Unable to handle {}'.format(value))

There’s a bit more “noise” required in there (every node must have a ctx, and a filename, lineno and col_offset), but they make it a bit harder to follow.

So, we have our expression, and we have turned that into an equivalent python expression, and compiled it…except it won’t compile. We need to wrap it in an ast.Expression(), and then we can compile it (and call it).

Roughly, we’ll end up with a code object that does:

self.first_name + (' ' + self.last_name)

We can call this with our context set:

eval(code, {'self': instance})
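
To make that concrete without any Django involved, here is a hand-built version of the same tree using only the standard library. It is a sketch of the end result rather than of the Parser: the attribute helper and the SimpleNamespace stand-in for a model instance are just for this example, and ast.fix_missing_locations is used to fill in the lineno and col_offset “noise” mentioned above.

```python
import ast
from types import SimpleNamespace


def attribute(name):
    # Equivalent to the source text `self.<name>`.
    return ast.Attribute(
        value=ast.Name(id='self', ctx=ast.Load()),
        attr=name,
        ctx=ast.Load(),
    )


# self.first_name + (' ' + self.last_name)
tree = ast.Expression(
    body=ast.BinOp(
        left=attribute('first_name'),
        op=ast.Add(),
        right=ast.BinOp(
            left=ast.Constant(value=' '),
            op=ast.Add(),
            right=attribute('last_name'),
        ),
    )
)
# Every node needs line/column information before it can be compiled.
ast.fix_missing_locations(tree)
code = compile(tree, filename='<display_name>', mode='eval')

instance = SimpleNamespace(first_name='Fred', last_name='Jones')
display_name = eval(code, {'self': instance})
```

Here display_name ends up as 'Fred Jones', computed entirely in python from the unsaved values.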

But, before we head down that route (I did, but you don’t need to), it’s worth noticing that not all ORM expressions can be mapped directly onto a single python expression. For instance, if we added an optional preferred_name field to our model, our display_name expression may look like:

@shared_property
def display_name(self):
    return Case(
        When(preferred_name__isnull=True, then=Concat(F('first_name'), Value(' '), F('last_name'))),
        When(preferred_name__exact=Value(''), then=Concat(F('first_name'), Value(' '), F('last_name'))),
        default=Concat(F('first_name'), Value(' ('), F('preferred_name'), Value(') '), F('last_name')),
        output_field=models.TextField()
    )

This will roughly translate to:

@property
def display_name(self):
    if all([self.preferred_name is None]):
        return self.first_name + ' ' + self.last_name
    elif all([self.preferred_name == '']):
        return self.first_name + ' ' + self.last_name
    else:
        return self.first_name + ' (' + self.preferred_name + ') ' + self.last_name

Whilst this is still a single ast node, it is not an expression (and cannot easily be turned into an expression - although in this case we could use a dict lookup based on self.preferred_name, but that’s not always going to work). Instead, we’ll need to change our code to generate a statement that contains a function definition, and then evaluate that to get the function object in the context. Then, we’ll have a callable that we can call with our model instance to get our result.
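
A rough, standard-library-only sketch of that approach follows: build a module containing a single function definition, compile it in 'exec' mode, execute it into a namespace, and pull the function out by name. This is hand-written for the preferred_name example rather than generated from the ORM expression, and the helper names (attribute, concat, short_name, full_name, preferred_name_is) are mine.

```python
import ast
from types import SimpleNamespace


def attribute(name):
    return ast.Attribute(value=ast.Name(id='self', ctx=ast.Load()), attr=name, ctx=ast.Load())


def concat(*nodes):
    # Fold the nodes into nested additions: ((a + b) + c) + ...
    result = nodes[0]
    for node in nodes[1:]:
        result = ast.BinOp(left=result, op=ast.Add(), right=node)
    return result


def short_name():
    return concat(attribute('first_name'), ast.Constant(value=' '), attribute('last_name'))


def full_name():
    return concat(
        attribute('first_name'), ast.Constant(value=' ('), attribute('preferred_name'),
        ast.Constant(value=') '), attribute('last_name'),
    )


def preferred_name_is(op, value):
    return ast.Compare(left=attribute('preferred_name'), ops=[op], comparators=[ast.Constant(value=value)])


function = ast.FunctionDef(
    name='display_name',
    args=ast.arguments(posonlyargs=[], args=[ast.arg(arg='self')],
                       kwonlyargs=[], kw_defaults=[], defaults=[]),
    body=[ast.If(
        test=preferred_name_is(ast.Is(), None),
        body=[ast.Return(value=short_name())],
        orelse=[ast.If(
            test=preferred_name_is(ast.Eq(), ''),
            body=[ast.Return(value=short_name())],
            orelse=[ast.Return(value=full_name())],
        )],
    )],
    decorator_list=[],
    type_params=[],  # Only a real field on Python 3.12+; harmless before that.
)
module = ast.fix_missing_locations(ast.Module(body=[function], type_ignores=[]))

namespace = {}
exec(compile(module, filename='<shared_property>', mode='exec'), namespace)
display_name = namespace['display_name']

person = SimpleNamespace(first_name='Fred', last_name='Jones', preferred_name=None)
plain = display_name(person)
person.preferred_name = 'Jonesy'
preferred = display_name(person)
```

The callable takes the instance as its only argument, which is exactly what a descriptor needs.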

There are a few hitches along the way though. The first is turning our method into both a private field and a property. That is the relatively straightforward part:

class shared_property:
    def __init__(self, function):
        self.parsed = Parser(function)
        context = {}
        eval(self.parsed.code, context)
        self.callable = context[function.__code__.co_name]

    def __get__(self, instance, cls=None):
        # Magic Descriptor method: this method will be called when this property
        # is accessed on the instance.
        if instance is None:
            return self
        return self.callable(instance)

    def contribute_to_class(self, cls, name, private_only=False):
        # Magic Django method: this is called by django on class instantiation, and allows
        # us to add our field (and ourself) to the model. Mostly this is the same as
        # a normal Django Field class would do, with the exception of setting concrete
        # to false, and using the output_field instead of ourself.
        field = self.parsed.expression.output_field
        field.set_attributes_from_name(name)
        field.model = cls
        field.concrete = False
        # This next line is important - it's the key to having everything work when querying.
        field.cached_col = ExpressionCol(self.parsed.expression)
        cls._meta.add_field(field, private=True)
        if not getattr(cls, field.attname, None):
            setattr(cls, field.attname, self)

There are a few things to note in that last method.

  • We use the output_field from the expression as the added field.
  • We mark this field as a private, non-concrete field. This prevents django from writing it back to the database, but it also means it will not appear in a .values() unless we explicitly ask for it. That’s actually fine, because we want the python property to execute instead of just using the value the database gave us.
  • The cached_col attribute is used when generating queries - we’ll look more at that now.
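
The descriptor half of this is plain python, and can be tried on its own. Below is a minimal sketch of just the __get__ behaviour, with no Django and no parsing: computed_property is a made-up name, and it wraps an ordinary function instead of a callable generated from an expression.

```python
class computed_property:
    """Descriptor that recomputes its value from the instance on every access."""

    def __init__(self, function):
        self.function = function

    def __get__(self, instance, cls=None):
        # Accessed on the class itself: hand back the descriptor object.
        if instance is None:
            return self
        # Accessed on an instance: compute from the current attribute values.
        return self.function(instance)


class Person:
    def __init__(self, first_name, last_name):
        self.first_name = first_name
        self.last_name = last_name

    @computed_property
    def display_name(self):
        return self.first_name + ' ' + self.last_name


person = Person('Fred', 'Jones')
before = person.display_name
person.first_name = 'Bob'
after = person.display_name
```

Because nothing is cached, the second access picks up the changed first_name without any database round trip.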

When I previously wrote the ComputedField implementation, the part I was not happy with was the get_col() method and the cached_col attribute. Indeed, to get that to work, I needed to use inspect to sniff up the stack to find a query instance to resolve the expression.

This time around though, I took a different approach. I was not able to use the regular resolve_expression path, because fields are assumed not to require access to the query to resolve to a Col expression. Instead, we can delay the resolve until we have something that gives us the query object.

class ExpressionCol:
    contains_aggregate = False
    def __init__(self, expression):
        self.expression = expression
        self.output_field = expression.output_field

    def get_lookup(self, name):
        return self.output_field.get_lookup(name)

    def get_transform(self, name):
        return self.output_field.get_transform(name)

    def as_sql(self, compiler, connection):
        resolved = self.expression.resolve_expression(compiler.query)
        return resolved.as_sql(compiler, connection)

    def get_db_converters(self, connection):
        return self.output_field.get_db_converters(connection) + \
               self.expression.get_db_converters(connection)

This doesn’t need to be a full Expression subclass, because it mostly delegates things to the output field, but when it is turned into SQL, it can resolve the expression before then using that resolved expression to build the SQL.

So, let’s see how this works now (without showing the new Nodes that are handled by the Parser):

Person.objects.filter(display_name__startswith='Bob')

Yeah, that correctly limits the queryset. How about the ability to re-evaluate without a db round trip?

person = Person(first_name='Fred', last_name='Jones')
person.display_name  # -> 'Fred Jones'
person.preferred_name = 'Jonesy'
person.display_name  # -> 'Fred (Jonesy) Jones'

Success!


This project is not done yet: I have improved the Parser (as implied) to support more expressions, but there is still a bit more to go. It did occur to me (but not until I was writing this post) that the ComputedField(expression) version may actually be nicer. As hinted, that requires the value to be an expression, rather than a function call. It would be possible to create a function that references self, for instance, and breaks in all sorts of ways.

Preventing Model Overwrites in Django and Postgres

I had an idea tonight while helping someone in #django. It revolved around using a postgres trigger to prevent overwrites with stale data.

Consider the following model:

class Person(models.Model):
    first_name = models.TextField()
    last_name = models.TextField()

If we had two users attempting to update a given instance at around the same time, Django would fetch whatever it had in the database when they did the GET request to fetch the form, and display that to them. It would also use whatever they sent back to save the object. In that case, the last update wins. Sometimes, this is what is required, but it does mean that one user’s changes would be completely overwritten, even if they had only changed something that the subsequent user did not change.

There are a couple of solutions to this problem. One is to use something like django-model-utils FieldTracker to record which fields have been changed, and only write those back using instance.save(update_fields=...). If you are using a django Form (and you probably should be), then you can also inspect form.changed_data to see what fields have changed.

However, that may not always be the best behaviour. Another solution would be to refuse to save something that had changed since they initially fetched the object, and instead show them the changes, allow them to update to whatever it should be now, and then resubmit. After which time, someone else may have made changes, but then the process repeats.

But how can we know that the object has changed?

One solution could be to use a trigger (and an extra column).

class Person(models.Model):
    first_name = models.TextField()
    last_name = models.TextField()
    _last_state = models.UUIDField()

And in our database trigger:

CREATE EXTENSION "uuid-ossp";

CREATE OR REPLACE FUNCTION prevent_clobbering()
RETURNS TRIGGER AS $prevent_clobbering$

BEGIN
  IF NEW._last_state != OLD._last_state THEN
    RAISE EXCEPTION 'Object was changed';
  END IF;
  NEW._last_state = uuid_generate_v4();
  RETURN NEW;
END;

$prevent_clobbering$
LANGUAGE plpgsql;

CREATE TRIGGER prevent_clobbering
BEFORE UPDATE ON person_person
FOR EACH ROW EXECUTE PROCEDURE prevent_clobbering();

You’d also want to have some level of handling in Django to capture the exception, and re-display the form. You can’t use the form/model validation handling for this, as it needs to happen during the save.

To make this work would also require the _last_state column to have a DEFAULT uuid_generate_v4(), so that newly created rows would get a value.


This is only a prototype at this stage, but does work as a mechanism for preventing overwrites. As usual, there’s probably more work required in the application server, and indeed in the UI, for displaying stale/updated values.

What this does have going for it is that it’s happening at the database level. There is no way that an update based on stale data could happen (unless the request coming in happened to guess what the new UUID was going to be).

What about drawbacks? Well, there is a bit more storage in the UUID, and we need to regenerate a new one each time we save a row. We could have something that checks the other rows looking for changes.

Perhaps we could even have the hash of the previous row’s value stored in this field - that way it would not matter that there had been N changes, what matters is the value the user saw before they entered their changes.
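
As a rough illustration of the hashing idea (in python rather than in the trigger, and with invented helper names row_fingerprint and is_stale), the fingerprint only depends on the values the user saw, not on how many writes have happened since:

```python
import hashlib
import json


def row_fingerprint(row):
    """Return a stable hash of a row, given as a dict of column -> value."""
    # sort_keys gives a stable ordering; default=str copes with dates, UUIDs, etc.
    payload = json.dumps(row, sort_keys=True, default=str)
    return hashlib.sha256(payload.encode('utf-8')).hexdigest()


def is_stale(seen_fingerprint, current_row):
    """True if the row no longer matches what the user saw when they fetched it."""
    return seen_fingerprint != row_fingerprint(current_row)


seen = row_fingerprint({'first_name': 'Fred', 'last_name': 'Jones'})

# Someone else changed the row in the meantime.
changed = is_stale(seen, {'first_name': 'Fred', 'last_name': 'Smith'})

# The row was changed and then changed back: the values match what was seen.
changed_back = is_stale(seen, {'last_name': 'Jones', 'first_name': 'Fred'})
```

In this sketch a row that was changed and then restored is not treated as stale, which is the behaviour described above.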

Another drawback is that it’s hard-coded to a specific column. We could rewrite the function to allow defining the column when we create the trigger:

CREATE TRIGGER prevent_clobbering
BEFORE UPDATE ON person_person
FOR EACH ROW EXECUTE PROCEDURE prevent_clobbering('_last_state');

But that requires a bit more work in the function itself:

CREATE OR REPLACE FUNCTION prevent_clobbering()
RETURNS TRIGGER AS $prevent_clobbering$

BEGIN
  IF to_jsonb(NEW)->TG_ARGV[0] != to_jsonb(OLD)->TG_ARGV[0] THEN
    RAISE EXCEPTION 'Object was changed';
  END IF;
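  -- Write the new value to the configured column, rather than a hard-coded one.
  NEW := jsonb_populate_record(NEW, jsonb_build_object(TG_ARGV[0], uuid_generate_v4()));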
  RETURN NEW;
END;

$prevent_clobbering$
LANGUAGE plpgsql;

Django and Robot Framework

One of my colleagues has spent a bunch of time investigating and then implementing some testing using Robot Framework. Whilst at times the command line feels like it was written by someone who hasn’t used unix much, it’s pretty powerful. There are also some nice tools, like several Google Chrome plugins that will record what you are doing and generate a script based upon that. There are also other tools to help build testing scripts.

There is also an existing DjangoLibrary for integrating with Django.

It’s an interesting approach: you install some extra middleware that allows you to perform requests directly to the server to create instances using Factory Boy, or fetch data from Querysets. However, it requires that the data is serialised before sending to the django server, and the same in the other direction. This means, for instance, that you cannot follow object references to get a related object without a bunch of legwork: usually you end up doing another QuerySet query.

There are some things in it that I do not like:

  • A new instance of the django runserver command is started for each Test Suite. In our case, this takes over 10 seconds to start as all imports are processed.
  • The database is flushed between Test Suites. We have data that is added through migrations that is required for the system to operate correctly, and in some cases for tests to execute. This is the same problem I’ve seen with TransactionTestCase.
  • Migrations are applied before running each Test Suite. This is unnecessary, and just takes more time.
  • Migrations are created automatically before running each Test Suite. This is just the wrong approach: at worst you’d want to warn that migrations are not up to date - otherwise you are testing migrations that may not have been committed: your CI would pass because the migrations were generated, but your system would fail in reality because those migrations do not really exist. Unless you are also making migrations directly on your production server and not committing them at all, in which case you really should stop that.

That’s in addition to having to install extra middleware.

But, back onto the initial issue: interacting with Django models.

What would be much nicer is if you could just call the python code directly. You’d get python objects back, which means you can follow references, and not have to deal with serialisation.

It’s fairly easy to write a Library for Robot Framework, as it already runs under Python. The tricky bit is that to access Django models (or Factory Boy factories), you’ll want to have the Django infrastructure all managed for you.

Let’s look at what the DjangoLibrary might look like if you are able to assume that django is already available and configured:

import importlib

from django.apps import apps
from django.urls import reverse

from robot.libraries.BuiltIn import BuiltIn


class DjangoLibrary:
    """

    Tools for making interaction with Django easier.

    Installation: ensure that in your `resource.robot` or test file, you have the
    following in your "***Settings***" section:

        Library         djangobot.DjangoLibrary     ${HOSTNAME}     ${PORT}

    The following keywords are provided:


    Factory:        execute the named factory with the args and kwargs. You may omit
                    the 'factories' module from the path to reduce the amount of code
                    required.

        ${obj}=     Factory     app_label.FactoryName       arg  kwarg=value
        ${obj}=     Factory     app_label.factories.FactoryName     arg  kwarg=value


    Queryset:       return a queryset of the installed model, using the default manager
                    and filtering according to any keyword arguments.

        ${qs}=      Queryset    auth.User       pk=1


    Method Call:    Execute the callable with the args/kwargs provided. This differs
                    from the Builtin "Call Method" in that it expects a callable, rather
                    than an instance and a method name.

        ${x}=       Method Call     ${foo.bar}      arg  kwargs=value


    Relative Url:   Resolve the named url and args/kwargs, and return the path. Not
                    quite as useful as the "Url", since it has no hostname, but may be
                    useful when dealing with `?next=/path/` values, for instance.

        ${url}=     Relative Url        foo:bar     baz=qux


    Url:            Resolve the named url with args/kwargs, and return the fully qualified url.

        ${url}=     Url                 foo:bar     baz=qux


    Fetch Url:      Resolve the named url with args/kwargs, and then using SeleniumLibrary,
                    navigate to that URL. This should be used instead of the "Go To" command,
                    as it allows using named urls instead of manually specifying urls.

        Fetch Url   foo:bar     baz=qux


    Url Should Match:   Assert that the current page matches the named url with args/kwargs.

        Url Should Match        foo:bar     baz=qux

    """

    def __init__(self, hostname, port, **kwargs):
        self.hostname = hostname
        self.port = port
        self.protocol = kwargs.pop('protocol', 'http')

    @property
    def selenium(self):
        return BuiltIn().get_library_instance('SeleniumLibrary')

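    def factory(self, factory, *args, **kwargs):
        module, name = factory.rsplit('.', 1)
        # Allow 'app_label.FactoryName' as shorthand for 'app_label.factories.FactoryName'.
        if not module.endswith('.factories'):
            module += '.factories'
        factory = getattr(importlib.import_module(module), name)
        return factory(*args, **kwargs)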

    def queryset(self, dotted_path, **kwargs):
        # get_model() accepts a single 'app_label.ModelName' string.
        return apps.get_model(dotted_path)._default_manager.filter(**kwargs)

    def method_call(self, method, *args, **kwargs):
        return method(*args, **kwargs)

    def fetch_url(self, name, *args, **kwargs):
        return self.selenium.go_to(self.url(name, *args, **kwargs))

    def relative_url(self, name, *args, **kwargs):
        return reverse(name, args=args, kwargs=kwargs)

    def url(self, name, *args, **kwargs):
        return '{}://{}:{}'.format(
            self.protocol,
            self.hostname,
            self.port,
        ) + reverse(name, args=args, kwargs=kwargs)

    def url_should_match(self, name, *args, **kwargs):
        self.selenium.location_should_be(self.url(name, *args, **kwargs))

To get that, you can write a management command: this allows you to hook in to Django’s existing infrastructure. Then, instead of calling robot directly, you use ./manage.py robot.

What’s even nicer about using a management command is that you can have that (optionally, because in development you probably will already have a devserver running) start runserver, and kill it when it’s finished. This is the same approach that robotframework-DjangoLibrary already takes, but we can start it once before running our tests, and kill it at the end.

So, what could our management command look like? Omitting the code for starting runserver, it’s quite neat:

from __future__ import absolute_import

from django.core.management import BaseCommand, CommandError

import robot


class Command(BaseCommand):
    def add_arguments(self, parser):
        parser.add_argument('tests', nargs='?', action='append')
        parser.add_argument('--variable', action='append')
        parser.add_argument('--include', action='append')

    def handle(self, **options):
        robot_options = {
            'outputdir': 'robot_results',
            'variable': options.get('variable') or []
        }
        if options.get('include'):
            robot_options['include'] = options['include']

        args = [
            'robot_tests/{}_test.robot'.format(arg)
            for arg in options['tests'] or ()
            if arg
        ] or ['robot_tests']

        result = robot.run(*args, **robot_options)

        if result:
            raise CommandError('Robot tests failed: {}'.format(result))

I think I’d like to do a bit more work on finding tests, but this works as a starting point. We can call this like:

./manage.py robot foo --variable BROWSER:firefox --variable PORT:8000

This will find a test called robot_tests/foo_test.robot, and execute that. If you omit the test argument, it will run on all tests in the robot_tests/ directory.
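
The path handling is the only slightly fiddly part, so here it is pulled out into a standalone function that can be tried without Django or Robot Framework installed. The name resolve_test_paths and the directory argument are my own additions for the sketch; the logic mirrors the list comprehension in the command above.

```python
def resolve_test_paths(tests, directory='robot_tests'):
    """
    Turn the test names given on the command line into .robot file paths,
    falling back to the whole directory when no names were supplied.
    """
    paths = [
        '{}/{}_test.robot'.format(directory, name)
        for name in tests or ()
        # argparse may hand us [None] when the optional argument is omitted.
        if name
    ]
    return paths or [directory]


single = resolve_test_paths(['foo'])
omitted = resolve_test_paths([None])
```

With this, ./manage.py robot foo maps to robot_tests/foo_test.robot, and a bare ./manage.py robot runs everything under robot_tests/.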

I’ve still got a bit to do on cleaning up the code that starts/stops the server, but I think this is useful even without that.