Tuesday, August 07, 2007

Python: Database Migrations

As part of my day job, I've written a Rails-style database migration script. You write a migration for each step from one version of your schema to the next. That lets you develop schemas iteratively and upgrade or downgrade to any revision. Best of all, if an upgrade fails partway through, the script backs out the steps that already succeeded, even if you're not using transactions. Of course, this relies on you writing an "up" and a "down" routine for each step; it's practical, not magical.
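To make the idea concrete, here is a minimal, standalone sketch of the "back it out" logic. It is not the script below. It's written for Python 3 and uses the standard sqlite3 module instead of SQLAlchemy and MySQL. The names `run_atom` and `apply_up` are made up for this illustration. Each up atom that succeeds has its down atom remembered. If a later step fails, only the remembered down atoms run, newest first, and then the original exception is re-raised.

```python
import sqlite3


def run_atom(conn, atom):
    """Run one atom: an SQL string or a callable that takes the connection."""
    if isinstance(atom, str):
        conn.execute(atom)
    else:
        atom(conn)


def apply_up(conn, migration):
    """Apply the up atoms of a list of (up, down) pairs.

    If an up atom fails, roll back the open transaction (if any), run the
    down atoms of the steps that already succeeded (newest first), and
    re-raise the original exception.
    """
    undo_atoms = []
    try:
        for up, down in migration:
            run_atom(conn, up)
            undo_atoms.append(down)  # remember only steps that succeeded
        conn.commit()
    except Exception:
        conn.rollback()
        # The rollback may already have undone some steps, so down atoms
        # must be safe to run anyway (e.g. DROP TABLE IF EXISTS).
        for down in reversed(undo_atoms):
            run_atom(conn, down)
        conn.commit()
        raise


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    migration = [
        ("CREATE TABLE a (x INTEGER)", "DROP TABLE IF EXISTS a"),
        ("CREATE TABLE b (y INTEGER)", "DROP TABLE IF EXISTS b"),
        ("INSERT INTO no_such_table VALUES (1)", "DELETE FROM no_such_table"),
    ]
    try:
        apply_up(conn, migration)
    except sqlite3.OperationalError as exc:
        print("Up migration failed and was backed out:", exc)
    tables = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table'").fetchall()
    print(tables)  # prints [] because both CREATE TABLEs were backed out
```

Both CREATE TABLE steps in the demo succeed before the INSERT fails, so both tables are dropped again. The sketch rolls back first and then runs the down atoms anyway. That is why idempotent down atoms such as "DROP TABLE IF EXISTS" matter.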

I'm releasing this code in the hope that others will find it useful. It's well-written, solid, and well-tested. This is the type of thing you could probably write in a day. I took four, and polished the heck out of it.

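It uses SQLAlchemy to talk to the database. However, that doesn't mean the rest of your app has to use SQLAlchemy. Personally, I like writing CREATE TABLE statements by hand, but each migration step can also be a Python function that receives a SQLAlchemy connection. You can do either.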

My database configuration is stored in a .ini file à la Paste / Pylons, so the script takes a .ini file and reads the database configuration from it. If you don't use Pylons but still want to use my script, that's an easy change to make. Migrations are stored in Python modules at "${yourpackage}/db/migration_${number}.py", where ${number} has three digits and starts at 000. Again, I use Pylons to figure out what "${yourpackage}" is, but that's easy enough to change.
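If you don't use Pylons, one way to make that change is to read the URI yourself. The following is a sketch for Python 3 using the standard configparser module. It assumes the setting lives under `sqlalchemy.dburi` in an `[app:main]` section, and `read_dburi` is a hypothetical helper name. Paste-style files often refer to `%(here)s`, the directory that contains the file, so the sketch supplies that as a default. It doesn't understand any other Paste Deploy extensions.

```python
import configparser
import os


def read_dburi(ini_path, section="app:main"):
    """Return the sqlalchemy.dburi setting from a Paste-style .ini file.

    Raises OSError if the file can't be read, and KeyError if the section
    or the setting is missing.
    """
    here = os.path.dirname(os.path.abspath(ini_path))
    # Paste-style files may use %(here)s for the file's own directory.
    parser = configparser.ConfigParser(defaults={"here": here})
    if not parser.read(ini_path):
        raise OSError("Could not read %s" % ini_path)
    return parser[section]["sqlalchemy.dburi"]
```

The migrate script could then hand this URI to whatever engine factory you use instead of relying on Pylons' global config.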

The name of my Pylons app is "multicosmic", and the script is installed in my application. You'll need to change the name to match your app.
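Likewise, if you'd rather pass the package name in than look it up through Pylons, the discovery step can take it as a parameter. This is a Python 3 sketch that mirrors `find_migration_modules` in the script below, using importlib. `find_migrations` is a hypothetical name, and the sketch assumes `${package}.db` is a regular package directory on disk.

```python
import glob
import importlib
import os


def find_migrations(package):
    """Import ${package}.db.migration_NNN modules in revision order.

    Raises ValueError if the numbering doesn't start at 000 or has gaps.
    """
    db = importlib.import_module(package + ".db")
    dirname = os.path.dirname(db.__file__)
    names = sorted(os.path.basename(path) for path in
                   glob.glob(os.path.join(dirname, "migration_*.py")))
    modules = []
    for i, name in enumerate(names):
        expected = "migration_%03d.py" % i
        if name != expected:
            raise ValueError("Expected %s, got %s" % (expected, name))
        module_name = "%s.db.%s" % (package, name[:-len(".py")])
        modules.append(importlib.import_module(module_name))
    return modules
```

For example, `find_migrations("multicosmic")` would return the migration_000 and migration_001 modules shown later in this post.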

Start by creating directories and __init__.py files for "multicosmic/db" and "multicosmic/scripts".
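If you'd like to script that step, here is a small Python 3 sketch using pathlib. `make_skeleton` is a hypothetical helper. It assumes the package directory itself already has an __init__.py, as a Pylons app does. Pass your own package name instead of "multicosmic".

```python
from pathlib import Path


def make_skeleton(root, package="multicosmic"):
    """Create ${package}/db and ${package}/scripts, each with an __init__.py.

    Existing files are left untouched. Returns the __init__.py paths that
    were newly created.
    """
    created = []
    for sub in ("db", "scripts"):
        directory = Path(root) / package / sub
        directory.mkdir(parents=True, exist_ok=True)
        init = directory / "__init__.py"
        if not init.exists():
            init.touch()
            created.append(init)
    return created
```

Running it a second time is harmless: it finds both files already present and returns an empty list.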

First, there's a migrate script in "multicosmic/scripts/migrate.py":
#!/usr/bin/env python

"""This is a script to apply database migrations.

Run this script with the -h flag to get usage information.

Migration Modules
-----------------

Each migration is a module stored in
``${appname}/db/migration_${revision}.py`` where revision starts at 000
(i.e. an empty database). Each such module should have a module-level
global named migration containing a list of pairs of atoms. For
instance::

    migration = [
        # (up, down)
        ("CREATE TABLE A(...)", "DROP TABLE IF EXISTS A"),
        ("CREATE TABLE B(...)", "DROP TABLE IF EXISTS B")
    ]

The up and down atoms may either be SQL strings, or they may be
functions that accept a SQLAlchemy connection.

Since I'm using SQLAlchemy, you might wonder why I'm writing actual SQL.
I like to use the SQLAlchemy ORM. However, when creating tables in
MySQL, there are so many fancy options that I find it easier to write
the SQL by hand.

Error Handling
--------------

* If something goes wrong when down migrating, just let the exception
  propagate.

* If something goes wrong when up migrating, complain, try to back it
  out, and then let the exception propagate. If backing it out fails,
  just let that exception propagate.

* Use transactions as appropriate. There are a lot of cases in
  MySQL where transactions aren't supported. Hence, backing things
  out is sometimes necessary. However, it's also possible that a
  transaction might roll back, and then the code to back things out
  runs anyway. It's best to make your down atoms idempotent. For
  instance, use "DROP TABLE IF EXISTS" rather than just
  "DROP TABLE".

Avoiding SQLAlchemy, Pylons, Paste, and Python 2.5
--------------------------------------------------

I'm using SQLAlchemy, but that doesn't force you to use SQLAlchemy in
the rest of your app. I'm using Paste's configuration mechanism because
that's how my database configuration information is stored. Passing a
CONFIG.ini to the script meets the needs of Paste and Pylons users
perfectly. If you're not one of those users and you want to use my
script, it's easy to subclass it and do something differently.
Similarly, if you're not using Python 2.5, I'm happy to remove the
Python 2.5-isms. Let's talk!

"""

# Copyright: "Shannon -jj Behrens <jjinux@gmail.com>"
# License: I am contributing this code to the Pylons project under the same license as Pylons.

from __future__ import with_statement

from contextlib import contextmanager, closing
from glob import glob
from optparse import OptionParser
import os
import re
import sys
import traceback

from paste.deploy import loadapp
from pylons import config as conf
from pylons.database import create_engine

__docformat__ = "restructuredtext"


class Migrate:

    """This is the main class that runs the migrations."""

    def __init__(self, args=None):
        """Set everything up, but don't run the migrations.

        args
            This defaults to ``sys.argv[1:]``.

        """
        self.setup_option_parser(args)

    def setup_option_parser(self, args):
        """Parse command line arguments."""
        self.args = args
        usage = "usage: %prog [options] CONFIG.ini"
        self.parser = OptionParser(usage=usage)
        self.parser.add_option('-r', '--revision', type='int',
                               help='schema revision; defaults to most current')
        self.parser.add_option('-p', '--print-revision', action="store_true",
                               default=False,
                               help='print current revision and exit')
        self.parser.add_option("-v", "--verbose", action="store_true",
                               default=False)
        (self.options, self.args) = self.parser.parse_args(self.args)
        if len(self.args) != 1:
            self.parser.error("Expected exactly one argument for CONFIG.ini")

    def run(self):
        """Run the migrations.

        All database activity starts from here.

        """
        self.load_configuration()
        self.engine = create_engine()
        self.engine.echo = bool(self.options.verbose)
        with closing(self.engine.connect()) as self.connection:
            self.find_migration_modules()
            self.find_desired_revision()
            self.find_current_revision()
            if self.options.print_revision:
                print self.current_revision
                return
            self.find_desired_migrations()
            self.print_overview()
            for migration in self.desired_migrations:
                self.apply_migration(migration)

    def load_configuration(self):
        """Load the configuration."""
        try:
            loadapp('config:%s' % self.args[0], relative_to=os.getcwd())
        except OSError, e:
            self.parser.error(str(e))
        dburi = conf.get('sqlalchemy.dburi')
        if not dburi:
            self.parser.error("%s: No sqlalchemy.dburi found" % self.args[0])

    def find_migration_modules(self):
        """Figure out what migrations exist.

        They should start at 000.

        """
        package = conf['pylons.package']
        module = __import__(package + '.db', fromlist=['db'])
        dirname = os.path.dirname(module.__file__)
        glob_pattern = os.path.join(dirname, 'migration_*.py')
        files = glob(glob_pattern)
        files.sort()
        basenames = map(os.path.basename, files)
        # Insist on a contiguous sequence: migration_000, migration_001, ...
        for (i, name) in enumerate(basenames):
            expected = 'migration_%03d.py' % i
            if name != expected:
                raise ValueError("Expected %s, got %s" % (expected, name))
        self.migration_modules = []
        for name in basenames:
            name = name[:-len('.py')]
            module = __import__('%s.db.%s' % (package, name),
                                fromlist=[name])
            self.migration_modules.append(module)

    def find_desired_revision(self):
        """Find the target revision."""
        len_migration_modules = len(self.migration_modules)
        if self.options.revision is None:
            self.desired_revision = len_migration_modules - 1
        else:
            self.desired_revision = self.options.revision
        if (self.desired_revision < 0 or
                self.desired_revision >= len_migration_modules):
            self.parser.error(
                "Revision argument out of range [0, %s]" %
                (len_migration_modules - 1))

    def find_current_revision(self):
        """Figure out what revision we're currently at."""
        # MySQL-specific: SHOW TABLES LIKE returns no rows when the
        # revision table doesn't exist yet, i.e. at revision 0.
        if self.connection.execute(
                "SHOW TABLES LIKE 'revision'").rowcount == 0:
            self.current_revision = 0
        else:
            result = self.connection.execute(
                "SELECT revision_id FROM revision")
            self.current_revision = int(result.fetchone()[0])

    def find_desired_migrations(self):
        """Figure out which migrations need to be applied."""
        self.find_migration_range()
        self.desired_migrations = [
            self.migration_modules[i]
            for i in self.migration_range
        ]

    def find_migration_range(self):

        """Figure out the range of the migrations that need to be applied."""

        if self.current_revision <= self.desired_revision:

            # Don't reapply the current revision. Do apply the
            # desired revision.

            self.step = 1
            self.migration_range = range(self.current_revision + self.step,
                                         self.desired_revision + self.step)
        else:

            # Unapply the current revision. Don't unapply the
            # desired revision.

            self.step = -1
            self.migration_range = range(self.current_revision,
                                         self.desired_revision, self.step)

    def print_overview(self):
        """If verbose, tell the user what's going on."""
        if self.options.verbose:
            print "Current revision:", self.current_revision
            print "Desired revision:", self.desired_revision
            print "Direction:", ("up" if self.step == 1 else "down")
            print "Migrations to be applied:", self.migration_range

    def apply_migration(self, migration):
        """Apply the given migration list.

        migration
            This is a migration module.

        """
        name = migration.__name__
        revision = self.parse_revision(name)
        if self.options.verbose:
            print "Applying migration:", name
        if self.step == -1:
            # Down: run the down atoms in reverse order and let any
            # exception propagate.
            with self.manage_transaction():
                for (up, down) in reversed(migration.migration):
                    self.apply_atom(down)
                self.record_revision(revision - 1)
        else:
            # Up: remember the down atom of every up atom that succeeds,
            # so that exactly those steps can be backed out on failure.
            undo_atoms = []
            try:
                with self.manage_transaction():
                    for (up, down) in migration.migration:
                        self.apply_atom(up)
                        undo_atoms.append(down)
                    self.record_revision(revision)
            except Exception:
                print >> sys.stderr, "An exception occurred:"
                traceback.print_exc()
                print >> sys.stderr, "Trying to back out migration:", name
                with self.manage_transaction():
                    for down in reversed(undo_atoms):
                        self.apply_atom(down)
                print >> sys.stderr, "Backed out migration:", name
                print >> sys.stderr, "Re-raising original exception."
                raise

    def apply_atom(self, atom):
        """Apply the given atom. Let exceptions propagate."""
        if isinstance(atom, basestring):
            # An SQL string.
            self.connection.execute(atom)
        else:
            # A function that accepts a SQLAlchemy connection.
            atom(self.connection)

    def parse_revision(self, s):
        """Given a module name, return the revision number embedded in it.

        Raise a ValueError on failure.

        """
        # Anchor on "migration_" so that digits elsewhere in the package
        # name (e.g. "myapp2.db.migration_003") aren't taken as the revision.
        match = re.search(r'migration_(\d+)$', s)
        if match is None:
            raise ValueError("Couldn't find a revision in: %s" % s)
        return int(match.group(1))

    def record_revision(self, revision):
        """Record the given revision.

        The current revision is stored in a table named revision.
        There's nothing to do if revision is 0, since the revision table
        itself is created by migration 001.

        """
        if revision != 0:
            self.connection.execute("UPDATE revision SET revision_id = %s",
                                    revision)
        self.current_revision = revision

    @contextmanager
    def manage_transaction(self):
        """Manage a database transaction.

        Usage::

            with self.manage_transaction():
                ...

        """
        transaction = self.connection.begin()
        try:
            yield
            transaction.commit()
        except:
            transaction.rollback()
            raise


if __name__ == '__main__':
    Migrate().run()

It comes with two migrations.

multicosmic/db/migration_000.py:
"""This is the first migration.

It doesn't really do anything; it represents an empty database. It
makes sense that a database at revision 0 should be empty.

"""

__docformat__ = "restructuredtext"


migration = []

multicosmic/db/migration_001.py:
"""Create the revision table with a revision_id column."""

__docformat__ = "restructuredtext"


# I'm using a creative whitespace style that makes it readable both here
# and when printed.

migration = [
    ("""\
CREATE TABLE revision (
    revision_id INT NOT NULL
) ENGINE = INNODB""",
     """\
DROP TABLE IF EXISTS revision"""),

    # Subsequent migrations don't need to manage this value. The
    # migrate.py script will take care of it.

    ("""\
INSERT INTO revision (revision_id) VALUES (1)""",
     """\
DELETE FROM revision""")
]
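
Later migrations have the same shape, and they don't need to touch the revision table because migrate.py records the revision for you. As a hypothetical illustration, a third migration could mix an SQL pair with a function atom. The users table, the seed_admin function, and the file name migration_002.py are all made up. Function atoms receive the connection opened by migrate.py. The sketch assumes that connection's execute() accepts a raw SQL string, which apply_atom in the script already relies on.

```python
"""Hypothetical migration_002.py: add a users table and seed one row.

The table, column, and function names are illustrative only.
"""

__docformat__ = "restructuredtext"


def seed_admin(connection):
    """A function atom: it receives the connection migrate.py opened."""
    connection.execute("INSERT INTO users (name) VALUES ('admin')")


migration = [
    ("""\
CREATE TABLE users (
    id INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
    name VARCHAR(64) NOT NULL
) ENGINE = INNODB""",
     """\
DROP TABLE IF EXISTS users"""),

    # Keep down atoms idempotent: deleting a row that's already gone is
    # harmless.
    (seed_admin,
     """\
DELETE FROM users WHERE name = 'admin'"""),
]
```

When migrating down, the pairs run in reverse, so the DELETE happens before the DROP.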
Last of all, there are test cases in multicosmic/tests/functional/test_migrate_script.py:
"""Test that the migrate script works."""

# Copyright: "Shannon -jj Behrens <jjinux@gmail.com>"
# License: I am contributing this code to the Pylons project under the same license as Pylons.

from cStringIO import StringIO
import sys

from nose.tools import assert_raises
from sqlalchemy.exceptions import SQLError

from multicosmic.scripts.migrate import Migrate
from multicosmic.db.migration_001 import migration as migration_001

__docformat__ = "restructuredtext"

BASE_ARGS = ['-v', 'test.ini']


def setup_module():
    _do_migration(0)


def teardown_module():
    _do_migration()


def test_setup_option_parser():
    migrate = Migrate(['-r1'] + BASE_ARGS)
    assert migrate.options.revision == 1
    assert migrate.options.verbose


def test_bad_up_migration():
    orig_stderr = sys.stderr
    fake_stderr = StringIO()
    # Add a step that must fail: there is no table named garbage.
    migration_001.append(("INSERT INTO garbage", "DELETE FROM garbage"))
    sys.stderr = fake_stderr
    try:
        migrate = Migrate(['-r1'] + BASE_ARGS)
        assert_raises(SQLError, migrate.run)
    finally:
        sys.stderr = orig_stderr
        migration_001.pop()
    # The script should have complained and backed out to revision 0.
    assert fake_stderr.getvalue()
    migrate = Migrate(['-p'] + BASE_ARGS)
    migrate.run()
    assert migrate.current_revision == 0


def test_bad_down_migration():
    _do_migration(1)
    # Going down runs this bogus down atom first, so the migration fails.
    migration_001.append(("INSERT INTO garbage", "DELETE FROM garbage"))
    try:
        migrate = Migrate(['-r0'] + BASE_ARGS)
        assert_raises(SQLError, migrate.run)
    finally:
        migration_001.pop()
    migrate = Migrate(['-p'] + BASE_ARGS)
    migrate.run()
    assert migrate.current_revision == 1


def _do_migration(revision=None):
    """Construct and run the Migrate class. Return it."""
    args = BASE_ARGS
    if revision is not None:
        args = ['-r%s' % revision] + BASE_ARGS
    migrate = Migrate(args)
    migrate.run()
    if revision is None:
        assert migrate.current_revision > 0
    else:
        assert migrate.current_revision == revision
    return migrate

I use nose for my tests. You can find out more about using nose with Pylons, including things to watch out for, here.
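
The functional tests above need a MySQL database. The direction-and-range logic, however, can be tested without one if it's pulled out into a pure function. Here is a sketch for Python 3. `migration_range` is a hypothetical standalone version of `find_migration_range`, not a function the script defines. Because the test function's name starts with test_, nose would collect it if it lived in a test module.

```python
def migration_range(current, desired):
    """Return (step, revisions) for moving from current to desired.

    Going up, the current revision is skipped and the desired one is
    applied. Going down, the current revision is unapplied and the
    desired one is kept.
    """
    if current <= desired:
        return 1, list(range(current + 1, desired + 1))
    return -1, list(range(current, desired, -1))


def test_migration_range():
    assert migration_range(0, 2) == (1, [1, 2])
    assert migration_range(2, 2) == (1, [])
    assert migration_range(3, 1) == (-1, [3, 2])
    assert migration_range(1, 0) == (-1, [1])
```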

If this code works out for you, leave me a comment :)

8 comments:

Max Ischenko said...

Cool. I need something like this and was about to write one. I'll see if I can use your solution verbatim. The only twist I need: my upgrade scripts are written as plain .sql files. Could you expand your Migrate() class to support this behavior out of the box? If you set up an svn repository, I could probably contribute back.

Shannon -jj Behrens said...

Cool. Well, I hope it works out for you.

I'm using plain SQL, but I'm embedding it in Python files. There's a reason. A single migration might involve multiple steps. For instance, you might create three tables, and each CREATE TABLE is a step. In MySQL, there are a lot of things that ignore transactions. Hence, you might want to create, create, alter, create, but if anything fails, you want it to back out exactly what succeeded. Using Python, I can create a list of up/down pairs like: [(up, down), (up, down), (up, down)].

TG said...

Any plans to mention this on the pylons or SA lists? Or did I miss it?

As you likely know it's an oft-requested feature...

Shannon -jj Behrens said...

I sent email to Ben and Mike Bayer. I asked Ben to mention it on the Pylons mailing list. If you mention it on those mailing lists, I'd be quite grateful. I'm not currently subscribed because I'm in "just-had-a-baby-working-for-a-startup" mode. I'd also be happy to contribute this code to either of those projects.

Noah Gift said...

Nice script. OptionParser is a treat, isn't it?

The "one" thing I wish it had was built-in support for config file parsing...

Keep up the good work.

Shannon -jj Behrens said...

> Nice script.

Thanks ;)

samokk said...

Hi,

Thanks!

This script seems to be the closest thing to what I've been looking for.

And frankly, I'm pretty sure TONS of other people are looking for something similar.

What about trying to make this script a completely standalone project, unrelated to any specific web framework?
Indeed, the impression that it's tied to another project would only hold back its wider use.

It could definitely be seen as the tool of choice to support agile database refactoring!!!

Thanks again!
Sami Dalouche

Shannon -jj Behrens said...

Thanks for the comments, Sami.

I didn't have time to make this into a full project, so I figured the best I could do was blog about it and release the code. I hope that at least helps you out.

I do think there are a few parts that are going to be specific to your individual situations such as a) how to get the database connection parameters b) whether you want to use SQLAlchemy or something else c) where to get your migrations from, etc.

I think the truly generic parts are in just a few functions, such as knowing how to correctly apply a migration and back it out safely.

Happy Hacking!