| 1 | """ |
|---|
| 2 | Extensions to SQLAlchemy for altering existing tables. |
|---|
| 3 | |
|---|
| 4 | At the moment, this isn't so much based off of ANSI as much as |
|---|
| 5 | things that just happen to work with multiple databases. |
|---|
| 6 | """ |
|---|
| 7 | import sqlalchemy as sa |
|---|
| 8 | from sqlalchemy.engine.base import Connection, Dialect |
|---|
| 9 | from sqlalchemy.sql.compiler import SchemaGenerator |
|---|
| 10 | from sqlalchemy.schema import ForeignKeyConstraint |
|---|
| 11 | from migrate.changeset import constraint, exceptions |
|---|
| 12 | |
|---|
| 13 | SchemaIterator = sa.engine.SchemaIterator |
|---|
| 14 | |
|---|
| 15 | |
|---|
| 16 | class RawAlterTableVisitor(object): |
|---|
| 17 | """Common operations for ``ALTER TABLE`` statements.""" |
|---|
| 18 | |
|---|
| 19 | def _to_table(self, param): |
|---|
| 20 | """Returns the table object for the given param object.""" |
|---|
| 21 | if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)): |
|---|
| 22 | ret = param.table |
|---|
| 23 | else: |
|---|
| 24 | ret = param |
|---|
| 25 | return ret |
|---|
| 26 | |
|---|
| 27 | def _to_table_name(self, param): |
|---|
| 28 | """Returns the table name for the given param object.""" |
|---|
| 29 | ret = self._to_table(param) |
|---|
| 30 | if isinstance(ret, sa.Table): |
|---|
| 31 | ret = ret.fullname |
|---|
| 32 | return ret |
|---|
| 33 | |
|---|
| 34 | def _do_quote_table_identifier(self, identifier): |
|---|
| 35 | """Returns a quoted version of the given table identifier.""" |
|---|
| 36 | return '"%s"'%identifier |
|---|
| 37 | |
|---|
| 38 | def start_alter_table(self, param): |
|---|
| 39 | """Returns the start of an ``ALTER TABLE`` SQL-Statement. |
|---|
| 40 | |
|---|
| 41 | Use the param object to determine the table name and use it |
|---|
| 42 | for building the SQL statement. |
|---|
| 43 | |
|---|
| 44 | :param param: object to determine the table from |
|---|
| 45 | :type param: :class:`sqlalchemy.Column`, :class:`sqlalchemy.Index`, |
|---|
| 46 | :class:`sqlalchemy.schema.Constraint`, :class:`sqlalchemy.Table`, |
|---|
| 47 | or string (table name) |
|---|
| 48 | """ |
|---|
| 49 | table = self._to_table(param) |
|---|
| 50 | table_name = self._to_table_name(table) |
|---|
| 51 | self.append('\nALTER TABLE %s ' % \ |
|---|
| 52 | self._do_quote_table_identifier(table_name)) |
|---|
| 53 | return table |
|---|
| 54 | |
|---|
| 55 | def _pk_constraint(self, table, column, status): |
|---|
| 56 | """Create a primary key constraint from a table, column. |
|---|
| 57 | |
|---|
| 58 | Status: true if the constraint is being added; false if being dropped |
|---|
| 59 | """ |
|---|
| 60 | if isinstance(column, basestring): |
|---|
| 61 | column = getattr(table.c, name) |
|---|
| 62 | |
|---|
| 63 | ret = constraint.PrimaryKeyConstraint(*table.primary_key) |
|---|
| 64 | if status: |
|---|
| 65 | # Created PK |
|---|
| 66 | ret.c.append(column) |
|---|
| 67 | else: |
|---|
| 68 | # Dropped PK |
|---|
| 69 | names = [c.name for c in cons.c] |
|---|
| 70 | index = names.index(col.name) |
|---|
| 71 | del ret.c[index] |
|---|
| 72 | |
|---|
| 73 | # Allow explicit PK name assignment |
|---|
| 74 | if isinstance(pk, basestring): |
|---|
| 75 | ret.name = pk |
|---|
| 76 | return ret |
|---|
| 77 | |
|---|
| 78 | |
|---|
| 79 | class AlterTableVisitor(SchemaIterator, RawAlterTableVisitor): |
|---|
| 80 | """Common operations for ``ALTER TABLE`` statements""" |
|---|
| 81 | pass |
|---|
| 82 | |
|---|
| 83 | |
|---|
| 84 | class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator): |
|---|
| 85 | """Extends ansisql generator for column creation (alter table add col)""" |
|---|
| 86 | |
|---|
| 87 | def visit_column(self, column): |
|---|
| 88 | """Create a column (table already exists). |
|---|
| 89 | |
|---|
| 90 | :param column: column object |
|---|
| 91 | :type column: :class:`sqlalchemy.Column` |
|---|
| 92 | """ |
|---|
| 93 | table = self.start_alter_table(column) |
|---|
| 94 | self.append(" ADD ") |
|---|
| 95 | colspec = self.get_column_specification(column) |
|---|
| 96 | self.append(colspec) |
|---|
| 97 | self.execute() |
|---|
| 98 | |
|---|
| 99 | def visit_table(self, table): |
|---|
| 100 | """Default table visitor, does nothing. |
|---|
| 101 | |
|---|
| 102 | :param table: table object |
|---|
| 103 | :type table: :class:`sqlalchemy.Table` |
|---|
| 104 | """ |
|---|
| 105 | pass |
|---|
| 106 | |
|---|
| 107 | |
|---|
| 108 | class ANSIColumnDropper(AlterTableVisitor): |
|---|
| 109 | """Extends ANSI SQL dropper for column dropping (``ALTER TABLE |
|---|
| 110 | DROP COLUMN``).""" |
|---|
| 111 | |
|---|
| 112 | def visit_column(self, column): |
|---|
| 113 | """Drop a column from its table. |
|---|
| 114 | |
|---|
| 115 | :param column: the column object |
|---|
| 116 | :type column: :class:`sqlalchemy.Column` |
|---|
| 117 | """ |
|---|
| 118 | table = self.start_alter_table(column) |
|---|
| 119 | self.append(' DROP COLUMN %s' % \ |
|---|
| 120 | self._do_quote_column_identifier(column.name)) |
|---|
| 121 | self.execute() |
|---|
| 122 | |
|---|
| 123 | |
|---|
| 124 | class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator): |
|---|
| 125 | """Manages changes to existing schema elements. |
|---|
| 126 | |
|---|
| 127 | Note that columns are schema elements; ``ALTER TABLE ADD COLUMN`` |
|---|
| 128 | is in SchemaGenerator. |
|---|
| 129 | |
|---|
| 130 | All items may be renamed. Columns can also have many of their properties - |
|---|
| 131 | type, for example - changed. |
|---|
| 132 | |
|---|
| 133 | Each function is passed a tuple, containing (object,name); where |
|---|
| 134 | object is a type of object you'd expect for that function |
|---|
| 135 | (ie. table for visit_table) and name is the object's new |
|---|
| 136 | name. NONE means the name is unchanged. |
|---|
| 137 | """ |
|---|
| 138 | |
|---|
| 139 | def _do_quote_column_identifier(self, identifier): |
|---|
| 140 | """override this function to define how identifiers (table and |
|---|
| 141 | column names) should be written in the SQL. For instance, in |
|---|
| 142 | PostgreSQL, double quotes should surround the identifier |
|---|
| 143 | """ |
|---|
| 144 | return identifier |
|---|
| 145 | |
|---|
| 146 | def visit_table(self, param): |
|---|
| 147 | """Rename a table. Other ops aren't supported.""" |
|---|
| 148 | table, newname = param |
|---|
| 149 | self.start_alter_table(table) |
|---|
| 150 | self.append("RENAME TO %s"%newname) |
|---|
| 151 | self.execute() |
|---|
| 152 | |
|---|
| 153 | def visit_column(self, delta): |
|---|
| 154 | """Rename/change a column.""" |
|---|
| 155 | # ALTER COLUMN is implemented as several ALTER statements |
|---|
| 156 | keys = delta.keys() |
|---|
| 157 | if 'type' in keys: |
|---|
| 158 | self._run_subvisit(delta, self._visit_column_type) |
|---|
| 159 | if 'nullable' in keys: |
|---|
| 160 | self._run_subvisit(delta, self._visit_column_nullable) |
|---|
| 161 | if 'server_default' in keys: |
|---|
| 162 | # Skip 'default': only handle server-side defaults, others |
|---|
| 163 | # are managed by the app, not the db. |
|---|
| 164 | self._run_subvisit(delta, self._visit_column_default) |
|---|
| 165 | if 'name' in keys: |
|---|
| 166 | self._run_subvisit(delta, self._visit_column_name) |
|---|
| 167 | |
|---|
| 168 | def _run_subvisit(self, delta, func, col_name=None, table_name=None): |
|---|
| 169 | if table_name is None: |
|---|
| 170 | table_name = self._to_table(delta.table) |
|---|
| 171 | if col_name is None: |
|---|
| 172 | col_name = delta.current_name |
|---|
| 173 | ret = func(table_name, col_name, delta) |
|---|
| 174 | self.execute() |
|---|
| 175 | return ret |
|---|
| 176 | |
|---|
| 177 | def _visit_column_foreign_key(self, delta): |
|---|
| 178 | table = delta.table |
|---|
| 179 | column = getattr(table.c, delta.current_name) |
|---|
| 180 | cons = constraint.ForeignKeyConstraint(column, autoload=True) |
|---|
| 181 | fk = delta['foreign_key'] |
|---|
| 182 | if fk: |
|---|
| 183 | # For now, cons.columns is limited to one column: |
|---|
| 184 | # no multicolumn FKs |
|---|
| 185 | column.foreign_key = ForeignKey(*cons.columns) |
|---|
| 186 | else: |
|---|
| 187 | column_foreign_key = None |
|---|
| 188 | cons.drop() |
|---|
| 189 | cons.create() |
|---|
| 190 | |
|---|
| 191 | def _visit_column_primary_key(self, delta): |
|---|
| 192 | table = delta.table |
|---|
| 193 | col = getattr(table.c, delta.current_name) |
|---|
| 194 | pk = delta['primary_key'] |
|---|
| 195 | cons = self._pk_constraint(table, col, pk) |
|---|
| 196 | cons.drop() |
|---|
| 197 | cons.create() |
|---|
| 198 | |
|---|
| 199 | def _visit_column_nullable(self, table_name, col_name, delta): |
|---|
| 200 | nullable = delta['nullable'] |
|---|
| 201 | table = self._to_table(delta) |
|---|
| 202 | self.start_alter_table(table_name) |
|---|
| 203 | self.append("ALTER COLUMN %s " % \ |
|---|
| 204 | self._do_quote_column_identifier(col_name)) |
|---|
| 205 | if nullable: |
|---|
| 206 | self.append("DROP NOT NULL") |
|---|
| 207 | else: |
|---|
| 208 | self.append("SET NOT NULL") |
|---|
| 209 | |
|---|
| 210 | def _visit_column_default(self, table_name, col_name, delta): |
|---|
| 211 | server_default = delta['server_default'] |
|---|
| 212 | # Dummy column: get_col_default_string needs a column for some |
|---|
| 213 | # reason |
|---|
| 214 | dummy = sa.Column(None, None, server_default=server_default) |
|---|
| 215 | default_text = self.get_column_default_string(dummy) |
|---|
| 216 | self.start_alter_table(table_name) |
|---|
| 217 | self.append("ALTER COLUMN %s " % \ |
|---|
| 218 | self._do_quote_column_identifier(col_name)) |
|---|
| 219 | if default_text is not None: |
|---|
| 220 | self.append("SET DEFAULT %s"%default_text) |
|---|
| 221 | else: |
|---|
| 222 | self.append("DROP DEFAULT") |
|---|
| 223 | |
|---|
| 224 | def _visit_column_type(self, table_name, col_name, delta): |
|---|
| 225 | type = delta['type'] |
|---|
| 226 | if not isinstance(type, sa.types.AbstractType): |
|---|
| 227 | # It's the class itself, not an instance... make an |
|---|
| 228 | # instance |
|---|
| 229 | type = type() |
|---|
| 230 | type_text = type.dialect_impl(self.dialect).get_col_spec() |
|---|
| 231 | self.start_alter_table(table_name) |
|---|
| 232 | self.append("ALTER COLUMN %s TYPE %s" % \ |
|---|
| 233 | (self._do_quote_column_identifier(col_name), |
|---|
| 234 | type_text)) |
|---|
| 235 | |
|---|
| 236 | def _visit_column_name(self, table_name, col_name, delta): |
|---|
| 237 | new_name = delta['name'] |
|---|
| 238 | self.start_alter_table(table_name) |
|---|
| 239 | self.append('RENAME COLUMN %s TO %s' % \ |
|---|
| 240 | (self._do_quote_column_identifier(col_name), |
|---|
| 241 | self._do_quote_column_identifier(new_name))) |
|---|
| 242 | |
|---|
| 243 | def visit_index(self, param): |
|---|
| 244 | """Rename an index; #36""" |
|---|
| 245 | index, newname = param |
|---|
| 246 | self.append("ALTER INDEX %s RENAME TO %s" % (index.name, newname)) |
|---|
| 247 | self.execute() |
|---|
| 248 | |
|---|
| 249 | |
|---|
| 250 | class ANSIConstraintCommon(AlterTableVisitor): |
|---|
| 251 | """ |
|---|
| 252 | Migrate's constraints require a separate creation function from |
|---|
| 253 | SA's: Migrate's constraints are created independently of a table; |
|---|
| 254 | SA's are created at the same time as the table. |
|---|
| 255 | """ |
|---|
| 256 | |
|---|
| 257 | def get_constraint_name(self, cons): |
|---|
| 258 | """Gets a name for the given constraint. |
|---|
| 259 | |
|---|
| 260 | If the name is already set it will be used otherwise the |
|---|
| 261 | constraint's :meth:`autoname |
|---|
| 262 | <migrate.changeset.constraint.ConstraintChangeset.autoname>` |
|---|
| 263 | method is used. |
|---|
| 264 | |
|---|
| 265 | :param cons: constraint object |
|---|
| 266 | :type cons: :class:`migrate.changeset.constraint.ConstraintChangeset` |
|---|
| 267 | """ |
|---|
| 268 | if cons.name is not None: |
|---|
| 269 | ret = cons.name |
|---|
| 270 | else: |
|---|
| 271 | ret = cons.name = cons.autoname() |
|---|
| 272 | return ret |
|---|
| 273 | |
|---|
| 274 | |
|---|
| 275 | class ANSIConstraintGenerator(ANSIConstraintCommon): |
|---|
| 276 | |
|---|
| 277 | def get_constraint_specification(self, cons, **kwargs): |
|---|
| 278 | if isinstance(cons, constraint.PrimaryKeyConstraint): |
|---|
| 279 | col_names = ','.join([i.name for i in cons.columns]) |
|---|
| 280 | ret = "PRIMARY KEY (%s)" % col_names |
|---|
| 281 | if cons.name: |
|---|
| 282 | # Named constraint |
|---|
| 283 | ret = ("CONSTRAINT %s " % cons.name)+ret |
|---|
| 284 | elif isinstance(cons, constraint.ForeignKeyConstraint): |
|---|
| 285 | params = dict( |
|---|
| 286 | columns=','.join([c.name for c in cons.columns]), |
|---|
| 287 | reftable=cons.reftable, |
|---|
| 288 | referenced=','.join([c.name for c in cons.referenced]), |
|---|
| 289 | name=self.get_constraint_name(cons), |
|---|
| 290 | ) |
|---|
| 291 | ret = "CONSTRAINT %(name)s FOREIGN KEY (%(columns)s) "\ |
|---|
| 292 | "REFERENCES %(reftable)s (%(referenced)s)" % params |
|---|
| 293 | if cons.onupdate: |
|---|
| 294 | ret = ret + " ON UPDATE %s" % cons.onupdate |
|---|
| 295 | if cons.ondelete: |
|---|
| 296 | ret = ret + " ON DELETE %s" % cons.ondelete |
|---|
| 297 | elif isinstance(cons, constraint.CheckConstraint): |
|---|
| 298 | ret = "CHECK (%s)" % cons.sqltext |
|---|
| 299 | else: |
|---|
| 300 | raise exceptions.InvalidConstraintError(cons) |
|---|
| 301 | return ret |
|---|
| 302 | |
|---|
| 303 | def _visit_constraint(self, constraint): |
|---|
| 304 | table = self.start_alter_table(constraint) |
|---|
| 305 | self.append("ADD ") |
|---|
| 306 | spec = self.get_constraint_specification(constraint) |
|---|
| 307 | self.append(spec) |
|---|
| 308 | self.execute() |
|---|
| 309 | |
|---|
| 310 | def visit_migrate_primary_key_constraint(self, *p, **k): |
|---|
| 311 | return self._visit_constraint(*p, **k) |
|---|
| 312 | |
|---|
| 313 | def visit_migrate_foreign_key_constraint(self, *p, **k): |
|---|
| 314 | return self._visit_constraint(*p, **k) |
|---|
| 315 | |
|---|
| 316 | def visit_migrate_check_constraint(self, *p, **k): |
|---|
| 317 | return self._visit_constraint(*p, **k) |
|---|
| 318 | |
|---|
| 319 | |
|---|
| 320 | class ANSIConstraintDropper(ANSIConstraintCommon): |
|---|
| 321 | |
|---|
| 322 | def _visit_constraint(self, constraint): |
|---|
| 323 | self.start_alter_table(constraint) |
|---|
| 324 | self.append("DROP CONSTRAINT ") |
|---|
| 325 | self.append(self.get_constraint_name(constraint)) |
|---|
| 326 | self.execute() |
|---|
| 327 | |
|---|
| 328 | def visit_migrate_primary_key_constraint(self, *p, **k): |
|---|
| 329 | return self._visit_constraint(*p, **k) |
|---|
| 330 | |
|---|
| 331 | def visit_migrate_foreign_key_constraint(self, *p, **k): |
|---|
| 332 | return self._visit_constraint(*p, **k) |
|---|
| 333 | |
|---|
| 334 | def visit_migrate_check_constraint(self, *p, **k): |
|---|
| 335 | return self._visit_constraint(*p, **k) |
|---|
| 336 | |
|---|
| 337 | |
|---|
| 338 | class ANSIFKGenerator(AlterTableVisitor, SchemaGenerator): |
|---|
| 339 | """Extends ansisql generator for column creation (alter table add col)""" |
|---|
| 340 | |
|---|
| 341 | def __init__(self, *args, **kwargs): |
|---|
| 342 | self.fk = kwargs.get('fk', None) |
|---|
| 343 | if self.fk: |
|---|
| 344 | del kwargs['fk'] |
|---|
| 345 | super(ANSIFKGenerator, self).__init__(*args, **kwargs) |
|---|
| 346 | |
|---|
| 347 | def visit_column(self, column): |
|---|
| 348 | """Create foreign keys for a column (table already exists); #32""" |
|---|
| 349 | |
|---|
| 350 | if self.fk: |
|---|
| 351 | self.add_foreignkey(self.fk.constraint) |
|---|
| 352 | |
|---|
| 353 | if self.buffer.getvalue() !='': |
|---|
| 354 | self.execute() |
|---|
| 355 | |
|---|
| 356 | def visit_table(self, table): |
|---|
| 357 | pass |
|---|
| 358 | |
|---|
| 359 | |
|---|
| 360 | class ANSIDialect(object): |
|---|
| 361 | columngenerator = ANSIColumnGenerator |
|---|
| 362 | columndropper = ANSIColumnDropper |
|---|
| 363 | schemachanger = ANSISchemaChanger |
|---|
| 364 | columnfkgenerator = ANSIFKGenerator |
|---|
| 365 | |
|---|
| 366 | @classmethod |
|---|
| 367 | def visitor(self, name): |
|---|
| 368 | return getattr(self, name) |
|---|
| 369 | |
|---|
| 370 | def reflectconstraints(self, connection, table_name): |
|---|
| 371 | raise NotImplementedError() |
|---|