1 | """ |
---|
2 | Database schema version management. |
---|
3 | """ |
---|
4 | from sqlalchemy import (Table, Column, MetaData, String, Text, Integer, |
---|
5 | create_engine) |
---|
6 | from sqlalchemy.sql import and_ |
---|
7 | from sqlalchemy import exceptions as sa_exceptions |
---|
8 | |
---|
9 | from migrate.versioning import exceptions, genmodel, schemadiff |
---|
10 | from migrate.versioning.repository import Repository |
---|
11 | from migrate.versioning.util import load_model |
---|
12 | from migrate.versioning.version import VerNum |
---|
13 | |
---|
14 | |
---|
class ControlledSchema(object):
    """A database under version control.

    Pairs an engine with the :class:`Repository` that controls it and keeps
    ``self.version`` in sync with this repository's row in the repository's
    version table (see :meth:`_load`).
    """

    def __init__(self, engine, repository):
        """
        :param engine: engine connected to the controlled database.
        :param repository: a :class:`Repository` or a path to one.
        :raise: :exc:`exceptions.DatabaseNotControlledError` if the version
            table, or this repository's row in it, is missing.
        """
        # Accept unicode paths too, like the other entry points do.
        if isinstance(repository, basestring):
            repository = Repository(repository)
        self.engine = engine
        self.repository = repository
        self.meta = MetaData(engine)
        self._load()

    def __eq__(self, other):
        """Equal when both share the same repository object and version."""
        if not isinstance(other, ControlledSchema):
            return NotImplemented
        return (self.repository is other.repository
                and self.version == other.version)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def _load(self):
        """Load controlled schema version info from DB.

        Reflects the version table on first use, then reads this
        repository's row and stores its ``version`` in ``self.version``.

        :raise: :exc:`exceptions.DatabaseNotControlledError` if the version
            table does not exist or has no row for this repository.
        """
        tname = self.repository.version_table
        self.meta = MetaData(self.engine)
        if getattr(self, 'table', None) is None:
            try:
                self.table = Table(tname, self.meta, autoload=True)
            except sa_exceptions.NoSuchTableError:
                # Reflection raises SQLAlchemy's NoSuchTableError, not the
                # migrate exception of the same name.
                raise exceptions.DatabaseNotControlledError(tname)
        # TODO?: verify that the table is correct (# cols, etc.)
        result = self.engine.execute(self.table.select(
            self.table.c.repository_id == str(self.repository.id)))
        # fetchall() drains (and so closes) the result.
        rows = result.fetchall()
        if not rows:
            # The table exists but this repository has no entry in it.
            raise exceptions.DatabaseNotControlledError(tname)
        # TODO: check repository id, exception if incorrect
        self.version = rows[0]['version']

    def _get_repository(self):
        """
        Given a database engine, try to guess the repository.

        :raise: :exc:`NotImplementedError`
        """
        # TODO: no guessing yet; for now, a repository must be supplied
        raise NotImplementedError()

    @classmethod
    def create(cls, engine, repository, version=None):
        """
        Declare a database to be under a repository's version control.

        :param engine: engine connected to the database.
        :param repository: a :class:`Repository` or a path to one.
        :param version: initial version to record; ``None`` means 0.
        :raise: :exc:`exceptions.InvalidVersionError` for a bad version,
            :exc:`exceptions.DatabaseAlreadyControlledError` if this
            repository already has an entry.
        :return: a :class:`ControlledSchema` for the database.
        """
        if isinstance(repository, basestring):
            repository = Repository(repository)
        # Confirm that the version # is valid: positive, integer,
        # exists in repos
        version = cls._validate_version(repository, version)
        cls._create_table_version(engine, repository, version)
        # TODO: history table
        # Load repository information and return
        return cls(engine, repository)

    @classmethod
    def _validate_version(cls, repository, version):
        """
        Ensures this is a valid version number for this repository.

        :raises: :exc:`exceptions.InvalidVersionError` if invalid
        :return: valid version number (a :class:`VerNum`); ``None`` maps to 0
        """
        if version is None:
            version = 0
        try:
            version = VerNum(version)  # raises ValueError
            if version < 0 or version > repository.latest:
                raise ValueError()
        except ValueError:
            raise exceptions.InvalidVersionError(version)
        return version

    @classmethod
    def _create_table_version(cls, engine, repository, version):
        """
        Creates the versioning table in a database (if missing) and inserts
        this repository's entry.

        :raise: :exc:`exceptions.DatabaseAlreadyControlledError` if the
            insert violates an integrity constraint (an entry for this
            repository already exists).
        :return: the version :class:`Table`.
        """
        tname = repository.version_table
        meta = MetaData(engine)

        table = Table(
            tname, meta,
            Column('repository_id', String(255), primary_key=True),
            Column('repository_path', Text),
            Column('version', Integer), )

        if not table.exists():
            table.create()

        try:
            engine.execute(table.insert(), repository_id=repository.id,
                           repository_path=repository.path,
                           version=int(version))
        except sa_exceptions.IntegrityError:
            # An Entry for this repo already exists.
            raise exceptions.DatabaseAlreadyControlledError()
        return table

    @classmethod
    def compare_model_to_db(cls, engine, model, repository):
        """
        Compare the current model against the current database.

        The repository's version table is excluded from the comparison.
        """
        if isinstance(repository, basestring):
            repository = Repository(repository)
        model = load_model(model)
        diff = schemadiff.getDiffOfModelAgainstDatabase(
            model, engine, excludeTables=[repository.version_table])
        return diff

    @classmethod
    def create_model(cls, engine, repository, declarative=False):
        """
        Dump the current database as a Python model.

        Diffs an empty ``MetaData`` against the database (excluding the
        version table) and renders the result as Python source.
        """
        if isinstance(repository, basestring):
            repository = Repository(repository)
        diff = schemadiff.getDiffOfModelAgainstDatabase(
            MetaData(), engine, excludeTables=[repository.version_table])
        return genmodel.ModelGenerator(diff, declarative).toPython()

    def update_db_from_model(self, model):
        """
        Modify the database to match the structure of the current Python model.

        Afterwards the stored version is set to the repository's latest
        version and ``self.version`` is refreshed.
        """
        if isinstance(self.repository, basestring):
            self.repository = Repository(self.repository)
        model = load_model(model)
        diff = schemadiff.getDiffOfModelAgainstDatabase(
            model, self.engine, excludeTables=[self.repository.version_table])
        genmodel.ModelGenerator(diff).applyModel()
        update = self.table.update(
            self.table.c.repository_id == str(self.repository.id))
        self.engine.execute(update, version=int(self.repository.latest))
        # Keep the in-memory version in sync, as runchange() does.
        self._load()

    def drop(self):
        """
        Remove version control from a database by dropping the version table.

        :raise: :exc:`exceptions.DatabaseNotControlledError` if the drop fails.
        """
        try:
            self.table.drop()
        except (sa_exceptions.SQLError):
            raise exceptions.DatabaseNotControlledError(str(self.table))

    def _engine_db(self, engine):
        """
        Returns the database name of an engine - ``postgres``, ``sqlite`` ...

        Derived from the last dotted component of the dialect's module name.
        """
        # TODO: This is a bit of a hack...
        return str(engine.dialect.__module__).split('.')[-1]

    def changeset(self, version=None):
        """Return the repository changeset from the current version to
        ``version`` for this engine's database type."""
        database = self._engine_db(self.engine)
        start_ver = self.version
        changeset = self.repository.changeset(database, start_ver, version)
        return changeset

    def runchange(self, ver, change, step):
        """Run a single ``change`` from ``ver`` by ``step`` and record the
        new version.

        :raise: :exc:`exceptions.InvalidVersionError` if the database is not
            currently at ``ver``.
        """
        startver = ver
        endver = ver + step
        # Current database version must be correct! Don't run if corrupt!
        if self.version != startver:
            raise exceptions.InvalidVersionError("%s is not %s" % \
                (self.version, startver))
        # Run the change
        change.run(self.engine, step)
        # Update/refresh database version; the version condition guards
        # against overwriting a row that moved underneath us.
        update = self.table.update(
            and_(self.table.c.version == int(startver),
                 self.table.c.repository_id == str(self.repository.id)))
        self.engine.execute(update, version=int(endver))
        self._load()

    def upgrade(self, version=None):
        """
        Upgrade (or downgrade) to a specified version, or latest version.
        """
        changeset = self.changeset(version)
        for ver, change in changeset:
            self.runchange(ver, change, changeset.step)