| 1 | """ |
|---|
| 2 | Code to generate a Python model from a database or differences |
|---|
| 3 | between a model and database. |
|---|
| 4 | |
|---|
| 5 | Some of this is borrowed heavily from the AutoCode project at: |
|---|
| 6 | http://code.google.com/p/sqlautocode/ |
|---|
| 7 | """ |
|---|
| 8 | |
|---|
| 9 | import sys |
|---|
| 10 | |
|---|
| 11 | import migrate |
|---|
| 12 | import sqlalchemy |
|---|
| 13 | |
|---|
| 14 | |
|---|
# Banner shared by both generated-model preambles: a leading blank line,
# the "autogenerated" marker, a blank line, and a star import of sqlalchemy.
_GENERATED_PREAMBLE = (
    "\n"
    "## File autogenerated by genmodel.py\n"
    "\n"
    "from sqlalchemy import *\n"
)

# Preamble for Table-style generated models (module-level ``meta``).
HEADER = _GENERATED_PREAMBLE + "meta = MetaData()\n"

# Preamble for declarative-style generated models (module-level ``Base``).
DECLARATIVE_HEADER = (
    _GENERATED_PREAMBLE
    + "from sqlalchemy.ext import declarative\n"
    + "\n"
    + "Base = declarative.declarative_base()\n"
)
|---|
| 30 | |
|---|
| 31 | |
|---|
class ModelGenerator(object):
    """Generate Python model source from a schema diff, or apply the diff.

    ``diff`` is used below through the attributes ``conn``,
    ``tablesMissingInModel``, ``tablesMissingInDatabase``,
    ``tablesWithDiff``, ``colDiffs`` and ``reflected_model``.
    """

    def __init__(self, diff, declarative=False):
        """Create a generator for ``diff``.

        :param diff: schema diff; ``diff.conn.dialect`` selects the type
            mappings used when rendering column types.
        :param declarative: if true, render declarative classes instead of
            ``Table`` objects.
        """
        self.diff = diff
        self.declarative = declarative
        dialect = self.diff.conn.dialect
        # ``colspecs`` maps generic types to dialect-specific ones. Look for
        # it at module level first (the original lookup), then fall back to
        # the dialect object for dialects that only define it there.
        dialectModule = sys.modules[dialect.__module__]
        colspecs = getattr(dialectModule, 'colspecs', None)
        if colspecs is None:
            colspecs = getattr(dialect, 'colspecs', {})
        # Inverted: dialect-specific type class -> generic type class.
        self.colTypeMappings = dict((v, k) for k, v in colspecs.items())

    def column_repr(self, col):
        """Return Python source for a ``Column`` equivalent to ``col``.

        Non-declarative mode yields ``Column('name', <type>, ...)``;
        declarative mode yields ``name = Column(<type>, ...)``. Constraint
        reprs follow the type, then keyword arguments in the order key,
        primary_key, nullable, onupdate, default. ``col`` is not modified.
        """
        kwargs = []
        if col.key != col.name:
            kwargs.append(('key', col.key))
        if col.primary_key:
            # Emit a real bool (the stored value may be e.g. 1) without
            # mutating the caller's column object.
            kwargs.append(('primary_key', True))
        if not col.nullable:
            kwargs.append(('nullable', col.nullable))
        if col.onupdate:
            kwargs.append(('onupdate', col.onupdate))
        if col.default and not col.primary_key:
            # I found that PostgreSQL automatically creates a default value
            # for primary-key sequences; don't show it.
            kwargs.append(('default', col.default))

        name = col.name
        # crs: on Python 2 encoding strips the u'' prefix from repr(name);
        # on Python 3 names are already str and encoding would give b'...'.
        if not isinstance(name, str):
            name = name.encode('utf8')

        genericType = self.colTypeMappings.get(col.type.__class__, None)
        if genericType is not None:
            # Render an instance of the generic type instead of the
            # database-specific one.
            colType = genericType()
        else:
            # We must already be a model type, no need to map from the
            # database-specific types.
            colType = col.type

        parts = [repr(colType)]
        parts.extend([repr(cn) for cn in col.constraints])
        parts.extend(['%s=%r' % (k, v) for k, v in kwargs])
        commonStuff = ', '.join(parts)
        if self.declarative:
            return '%s = Column(%s)' % (name, commonStuff)
        return 'Column(%r, %s)' % (name, commonStuff)

    def getTableDefn(self, table):
        """Return the list of source lines defining ``table``.

        Declarative mode emits a ``Base`` subclass; otherwise a ``Table``
        attached to ``meta``, one column per line.
        """
        tableName = table.name
        if self.declarative:
            out = ["class %(table)s(Base):" % {'table': tableName},
                   " __tablename__ = '%(table)s'" % {'table': tableName}]
            out.extend([" %s" % self.column_repr(col)
                        for col in table.columns])
        else:
            out = ["%(table)s = Table('%(table)s', meta," %
                   {'table': tableName}]
            out.extend([" %s," % self.column_repr(col)
                        for col in table.columns])
            out.append(")")
        return out

    def toPython(self):
        """Assume database is current and model is empty.

        Return the source of a model module that defines every table in
        ``diff.tablesMissingInModel``, each followed by a blank line.
        """
        if self.declarative:
            out = [DECLARATIVE_HEADER]
        else:
            out = [HEADER]
        out.append("")
        for table in self.diff.tablesMissingInModel:
            out.extend(self.getTableDefn(table))
            out.append("")
        return '\n'.join(out)

    def toUpgradeDowngradePython(self, indent=' '):
        """Assume model is most current and database is out-of-date.

        Return ``(declarations, upgrade, downgrade)`` source strings. The
        upgrade drops tables missing from the model and creates tables
        missing from the database; the downgrade does the reverse. Each
        command line is prefixed with ``indent``.

        NOTE(review): tables in ``diff.tablesWithDiff`` are not handled
        here, and in declarative mode the emitted ``X.drop()``/``X.create()``
        calls target classes rather than tables -- confirm this is only
        used with ``declarative=False``.
        """
        decls = ['meta = MetaData(migrate_engine)']
        for table in (self.diff.tablesMissingInModel +
                      self.diff.tablesMissingInDatabase):
            decls.extend(self.getTableDefn(table))

        upgradeCommands, downgradeCommands = [], []
        for table in self.diff.tablesMissingInModel:
            upgradeCommands.append("%s.drop()" % (table.name,))
            downgradeCommands.append("%s.create()" % (table.name,))
        for table in self.diff.tablesMissingInDatabase:
            upgradeCommands.append("%s.create()" % (table.name,))
            downgradeCommands.append("%s.drop()" % (table.name,))

        return (
            '\n'.join(decls),
            '\n'.join(['%s%s' % (indent, line) for line in upgradeCommands]),
            '\n'.join(['%s%s' % (indent, line) for line in downgradeCommands]))

    def applyModel(self):
        """Apply model to current database.

        Drops tables missing from the model, creates tables missing from the
        database, then reconciles column differences for tables present in
        both (rebuilding the table via a data copy on SQLite when needed).
        """
        # Yuck! We have to import from changeset to apply the
        # monkey-patch to allow column adding/dropping.
        from migrate.changeset import schema

        def dbCanHandleThisChange(missingInDatabase, missingInModel, diffDecl):
            if missingInDatabase and not missingInModel and not diffDecl:
                # Even sqlite can handle this.
                return True
            # Dropping/altering columns is not supported by sqlite.
            return not self.diff.conn.url.drivername.startswith('sqlite')

        meta = sqlalchemy.MetaData(self.diff.conn.engine)

        for table in self.diff.tablesMissingInModel:
            table.tometadata(meta).drop()
        for table in self.diff.tablesMissingInDatabase:
            table.tometadata(meta).create()
        for modelTable in self.diff.tablesWithDiff:
            modelTable = modelTable.tometadata(meta)
            dbTable = self.diff.reflected_model.tables[modelTable.name]
            missingInDatabase, missingInModel, diffDecl = \
                self.diff.colDiffs[modelTable.name]
            if dbCanHandleThisChange(missingInDatabase, missingInModel,
                                     diffDecl):
                for col in missingInDatabase:
                    modelTable.columns[col.name].create()
                for col in missingInModel:
                    dbTable.columns[col.name].drop()
                for modelCol, databaseCol, modelDecl, databaseDecl in diffDecl:
                    databaseCol.alter(modelCol)
            else:
                self._rebuildTableViaCopy(modelTable, dbTable)

    def _rebuildTableViaCopy(self, modelTable, dbTable):
        """Recreate ``modelTable`` while keeping data of the shared columns.

        Sqlite doesn't support drop column, so: copy the rows to a temporary
        table, drop and recreate the table from the model, copy the columns
        present in both ``modelTable`` and ``dbTable`` back, and drop the
        temporary table -- all on one connection inside one transaction.
        """
        tableName = modelTable.name
        # NOTE(review): assumes no real table already uses this name.
        tempName = '_temp_%s' % tableName
        # Quote identifiers so reserved words / unusual names survive the
        # hand-built SQL below.
        preparer = self.diff.conn.engine.dialect.identifier_preparer
        quote = preparer.quote_identifier
        commonCols = [quote(modelCol.name) for modelCol in modelTable.columns
                      if modelCol.name in dbTable.columns]
        commonColsStr = ', '.join(commonCols)

        # Move the data in one transaction, so that we don't
        # leave the database in a nasty state.
        connection = self.diff.conn.connect()
        try:
            trans = connection.begin()
            try:
                connection.execute(
                    'CREATE TEMPORARY TABLE %s as SELECT * from %s' %
                    (quote(tempName), quote(tableName)))
                # make sure the drop takes place inside our
                # transaction with the bind parameter
                modelTable.drop(bind=connection)
                modelTable.create(bind=connection)
                # With no shared columns there is nothing to copy back, and
                # an empty column list would be invalid SQL.
                if commonCols:
                    connection.execute(
                        'INSERT INTO %s (%s) SELECT %s FROM %s' %
                        (quote(tableName), commonColsStr, commonColsStr,
                         quote(tempName)))
                connection.execute('DROP TABLE %s' % quote(tempName))
                trans.commit()
            except:
                # Roll back on any exception (even KeyboardInterrupt), then
                # let it propagate.
                trans.rollback()
                raise
        finally:
            # The connection was opened here, so release it here.
            connection.close()
|---|