1 | #!/usr/bin/env python |
---|
2 | # -*- coding: utf-8 -*- |
---|
3 | |
---|
4 | import shutil |
---|
5 | from StringIO import StringIO |
---|
6 | |
---|
7 | import migrate |
---|
8 | from migrate.versioning import exceptions, genmodel, schemadiff |
---|
9 | from migrate.versioning.base import operations |
---|
10 | from migrate.versioning.template import template |
---|
11 | from migrate.versioning.script import base |
---|
12 | from migrate.versioning.util import import_path, load_model, construct_engine |
---|
13 | |
---|
class PythonScript(base.BaseScript):
    """A migration script implemented as a Python module.

    The module must define a callable ``upgrade()`` (checked by
    :meth:`verify_module`); :meth:`run` calls ``upgrade`` or ``downgrade``
    while ``migrate.migrate_engine`` is bound to the engine being migrated.
    """

    # Placeholder body in the script template that generated code replaces.
    # NOTE(review): assumes the template indents function bodies with four
    # spaces -- confirm against the template shipped with
    # migrate.versioning.template.
    _PLACEHOLDER = '    pass'

    @classmethod
    def create(cls, path, **opts):
        """Create an empty migration script at *path* from the default template.

        :param path: destination file name; checked first with
            ``require_notfound`` (presumably raises if *path* exists).
        :param opts: accepted for interface compatibility; unused here.
        """
        cls.require_notfound(path)

        # TODO: Use the default script template (defined in the template
        # module) for now, but we might want to allow people to specify a
        # different one later.
        template_file = None
        src = template.get_script(template_file)
        shutil.copy(src, path)

    @staticmethod
    def _fill_placeholder(contents, funcdef, body):
        """Replace the first placeholder that follows *funcdef* with *body*.

        Returns *contents* unchanged when *funcdef* or a following
        placeholder cannot be found.
        """
        start = contents.find(funcdef)
        if start < 0:
            return contents
        pos = contents.find(PythonScript._PLACEHOLDER, start)
        if pos < 0:
            return contents
        end = pos + len(PythonScript._PLACEHOLDER)
        return contents[:pos] + body + contents[end:]

    @classmethod
    def make_update_script_for_model(cls, engine, oldmodel,
                                     model, repository, **opts):
        """Generate the source of a migration script from two models.

        :param engine: engine passed through to the schema diff.
        :param oldmodel: the model before the change (anything accepted by
            ``load_model``).
        :param model: the model after the change (same forms as *oldmodel*).
        :param repository: a ``Repository`` or a string from which one is
            built; its ``version_table`` is excluded from the diff.
        :param opts: accepted for interface compatibility; unused here.
        :returns: the generated script source as a string (not written to
            disk).
        """
        # Compute differences.
        if isinstance(repository, basestring):
            # Imported here rather than at module level to avoid an import
            # cycle with the repository module.
            from migrate.versioning.repository import Repository
            repository = Repository(repository)
        oldmodel = load_model(oldmodel)
        model = load_model(model)
        diff = schemadiff.getDiffOfModelAgainstModel(
            oldmodel,
            model,
            engine,
            excludeTables=[repository.version_table])
        decls, upgradeCommands, downgradeCommands = \
            genmodel.ModelGenerator(diff).toUpgradeDowngradePython()

        # Load the default script template, closing the file explicitly.
        template_file = None
        src = template.get_script(template_file)
        template_fp = open(src)
        try:
            contents = template_fp.read()
        finally:
            template_fp.close()

        # Model declarations go just before the upgrade() definition.
        search = 'def upgrade():'
        contents = contents.replace(search, '\n\n'.join((decls, search)), 1)

        # Fill each function's placeholder relative to its own definition, so
        # downgrade code can never land in upgrade() when the upgrade side is
        # empty.
        if upgradeCommands:
            contents = cls._fill_placeholder(
                contents, 'def upgrade():', upgradeCommands)
        if downgradeCommands:
            contents = cls._fill_placeholder(
                contents, 'def downgrade():', downgradeCommands)
        return contents

    @classmethod
    def verify_module(cls, path):
        """Import the script at *path* and check that it defines ``upgrade()``.

        Errors raised while importing the script itself propagate unchanged;
        they are the script author's problem, not ours.

        :returns: the imported module.
        :raises InvalidScriptError: if the module has no callable ``upgrade``.
        """
        module = import_path(path)
        # Explicit check instead of ``assert``: asserts vanish under -O and an
        # AssertionError carries no message.
        if not callable(getattr(module, 'upgrade', None)):
            raise exceptions.InvalidScriptError(
                path + ': %s' % 'upgrade() is not defined or not callable')
        return module

    def preview_sql(self, url, step, **args):
        """Run *step* against a mock engine and return the captured SQL text.

        The engine is built by ``construct_engine`` with
        ``engine_arg_strategy='mock'``; every statement handed to its
        executor is appended to an in-memory buffer, whose contents are
        returned.
        """
        buf = StringIO()
        args['engine_arg_strategy'] = 'mock'
        # A mock engine's executor may receive executable objects (e.g. DDL
        # elements) rather than strings, so stringify before concatenating.
        args['engine_arg_executor'] = lambda s, p='': buf.write(str(s) + p)
        engine = construct_engine(url, **args)

        self.run(engine, step)

        return buf.getvalue()

    def run(self, engine, step):
        """Execute the script's upgrade() or downgrade() function.

        :param engine: engine exposed to the script as
            ``migrate.migrate_engine`` for the duration of the call.
        :param step: positive runs ``upgrade``, negative runs ``downgrade``.
        :raises ScriptError: if *step* is 0 or the function is not defined.
        """
        if step > 0:
            op = 'upgrade'
        elif step < 0:
            op = 'downgrade'
        else:
            raise exceptions.ScriptError("%d is not a valid step" % step)
        funcname = operations[op]

        # The engine is published before the function is resolved: resolving
        # may import the script for the first time, and the script may read
        # migrate_engine at import time.
        migrate.migrate_engine = engine
        try:
            func = self._func(funcname)
            func()
        finally:
            # Always clear the global, even if the step fails.
            migrate.migrate_engine = None

    @property
    def module(self):
        """The script's module, imported and verified on first access."""
        if not hasattr(self, '_module'):
            self._module = self.verify_module(self.path)
        return self._module

    def _func(self, funcname):
        """Return the function named *funcname* from the script module.

        :raises ScriptError: if the module does not define it.
        """
        fn = getattr(self.module, funcname, None)
        if not fn:
            msg = "The function %s is not defined in this script"
            raise exceptions.ScriptError(msg % funcname)
        return fn
---|