| 1 | # access.py |
|---|
| 2 | # Copyright (C) 2007 Paul Johnston, paj@pajhome.org.uk |
|---|
| 3 | # Portions derived from jet2sql.py by Matt Keranen, mksql@yahoo.com |
|---|
| 4 | # |
|---|
| 5 | # This module is part of SQLAlchemy and is released under |
|---|
| 6 | # the MIT License: http://www.opensource.org/licenses/mit-license.php |
|---|
| 7 | |
|---|
| 8 | from sqlalchemy import sql, schema, types, exc, pool |
|---|
| 9 | from sqlalchemy.sql import compiler, expression |
|---|
| 10 | from sqlalchemy.engine import default, base |
|---|
| 11 | |
|---|
| 12 | |
|---|
class AcNumeric(types.Numeric):
    """Access NUMERIC column type.

    Bound values are converted to strings; fetched values are returned as the
    driver delivered them.
    """

    def get_col_spec(self):
        return "NUMERIC"

    def result_processor(self, dialect):
        # No result conversion.
        return None

    def bind_processor(self, dialect):
        def to_string(value):
            # NULL passes through untouched; anything else is stringified.
            if value is not None:
                return str(value)
            return value
        return to_string

class AcFloat(types.Float):
    """Access FLOAT column type."""

    def bind_processor(self, dialect):
        """Bind non-NULL values as strings so Decimal values round-trip."""
        def to_string(value):
            if value is None:
                return None
            return str(value)
        return to_string

    def get_col_spec(self):
        return "FLOAT"

class AcInteger(types.Integer):
    """Access INTEGER column type."""

    def get_col_spec(self):
        spec = "INTEGER"
        return spec

class AcTinyInteger(types.Integer):
    """Access TINYINT column type."""

    def get_col_spec(self):
        spec = "TINYINT"
        return spec

class AcSmallInteger(types.Smallinteger):
    """Access SMALLINT column type."""

    def get_col_spec(self):
        spec = "SMALLINT"
        return spec

class AcDateTime(types.DateTime):
    """Access DATETIME column type; any constructor arguments are ignored."""

    def __init__(self, *args, **kwargs):
        # Always initialised with a single positional False.
        super(AcDateTime, self).__init__(False)

    def get_col_spec(self):
        spec = "DATETIME"
        return spec

class AcDate(types.Date):
    """Date type stored in an Access DATETIME column; constructor arguments are ignored."""

    def __init__(self, *args, **kwargs):
        # Always initialised with a single positional False.
        super(AcDate, self).__init__(False)

    def get_col_spec(self):
        spec = "DATETIME"
        return spec

class AcText(types.Text):
    """Unbounded text stored in an Access MEMO column."""

    def get_col_spec(self):
        spec = "MEMO"
        return spec

class AcString(types.String):
    """Access TEXT column, sized as TEXT(n) when a length is set."""

    def get_col_spec(self):
        if self.length:
            return "TEXT(%d)" % self.length
        return "TEXT"

class AcUnicode(types.Unicode):
    """Unicode text in an Access TEXT column (TEXT(n) when a length is set).

    No bind or result conversion is applied; values pass straight through.
    """

    def get_col_spec(self):
        if self.length:
            return "TEXT(%d)" % self.length
        return "TEXT"

    def bind_processor(self, dialect):
        # Pass-through.
        return None

    def result_processor(self, dialect):
        # Pass-through.
        return None

class AcChar(types.CHAR):
    """CHAR stored in an Access TEXT column (TEXT(n) when a length is set)."""

    def get_col_spec(self):
        if self.length:
            return "TEXT(%d)" % self.length
        return "TEXT"

class AcBinary(types.Binary):
    """Access BINARY column type."""

    def get_col_spec(self):
        spec = "BINARY"
        return spec

class AcBoolean(types.Boolean):
    """Boolean type stored in an Access YESNO column."""

    def get_col_spec(self):
        return "YESNO"

    def result_processor(self, dialect):
        """Return a converter that maps NULL to None and anything else to a bool."""
        def process(value):
            if value is None:
                return None
            return bool(value)
        return process

    def bind_processor(self, dialect):
        """Return a converter that binds NULL as None and anything else as 1 or 0.

        Previously True/False were bound as 1/0 but other truthy or falsy
        values were bound as True/False. All non-NULL values now share the
        same integer representation.
        """
        def process(value):
            if value is None:
                return None
            if value:
                return 1
            return 0
        return process

class AcTimeStamp(types.TIMESTAMP):
    """Access TIMESTAMP column type."""

    def get_col_spec(self):
        spec = "TIMESTAMP"
        return spec

class AccessExecutionContext(default.DefaultExecutionContext):
    """Execution context that fetches the generated COUNTER value after INSERTs."""

    def _has_implicit_sequence(self, column):
        """Return True if ``column`` should behave as an implicit COUNTER.

        This holds for an autoincrementing Integer primary key that is not a
        foreign key and has either no default or only an optional Sequence.
        """
        if column.primary_key and column.autoincrement:
            if isinstance(column.type, types.Integer) and not column.foreign_keys:
                if column.default is None or (isinstance(column.default, schema.Sequence) and
                                              column.default.optional):
                    return True
        return False

    def post_exec(self):
        """If we inserted into a table with a COUNTER column, fetch the new ID.

        The table's COUNTER column is looked up once and cached on the Table
        as ``has_sequence`` (None when there is none). After an INSERT into
        such a table, ``SELECT @@identity`` runs on the same cursor and the
        value is stored in ``self._last_inserted_ids``.
        """
        # Only compiled INSERTs carry the target-table information needed
        # here. Text statements may have no compiled form, and dereferencing
        # it unconditionally used to raise AttributeError.
        compiled = getattr(self, 'compiled', None)
        if compiled is not None and compiled.isinsert:
            tbl = compiled.statement.table
            if not hasattr(tbl, 'has_sequence'):
                tbl.has_sequence = None
                for column in tbl.c:
                    if getattr(column, 'sequence', False) or self._has_implicit_sequence(column):
                        tbl.has_sequence = column
                        break

            # Explicit None test rather than relying on a Column's truth value.
            if tbl.has_sequence is not None:
                # TBD: _last_inserted_ids is not populated at this point
                # (unlike the corresponding point in mssql), so the identity
                # is always fetched.
                self.cursor.execute("SELECT @@identity AS lastrowid")
                row = self.cursor.fetchone()
                # Leave the ids untouched rather than failing in int() when
                # @@identity produced nothing.
                if row is not None and row[0] is not None:
                    self._last_inserted_ids = [int(row[0])]

        super(AccessExecutionContext, self).post_exec()


# Module-level handles filled in by AccessDialect.dbapi(): the win32com
# constants namespace and the DAO engine used for schema reflection.
const = None
daoEngine = None
class AccessDialect(default.DefaultDialect):
    """SQLAlchemy dialect for Microsoft Access.

    SQL is executed through pyodbc using the "Microsoft Access Driver
    (*.mdb)" ODBC driver. Schema reflection (``reflecttable``,
    ``table_names``) goes through DAO via win32com. ``dbapi()`` initialises
    the module globals ``const`` and ``daoEngine`` that reflection relies on.
    """

    colspecs = {
        types.Unicode: AcUnicode,
        types.Integer: AcInteger,
        types.Smallinteger: AcSmallInteger,
        types.Numeric: AcNumeric,
        types.Float: AcFloat,
        types.DateTime: AcDateTime,
        types.Date: AcDate,
        types.String: AcString,
        types.Binary: AcBinary,
        types.Boolean: AcBoolean,
        types.Text: AcText,
        types.CHAR: AcChar,
        types.TIMESTAMP: AcTimeStamp,
    }
    name = 'access'
    supports_sane_rowcount = False
    supports_sane_multi_rowcount = False

    def type_descriptor(self, typeobj):
        """Return the Access-specific adaptation of ``typeobj`` per ``colspecs``."""
        newobj = types.adapt_type(typeobj, self.colspecs)
        return newobj

    def __init__(self, **params):
        super(AccessDialect, self).__init__(**params)
        self.text_as_varchar = False
        self._dtbs = None

    def dbapi(cls):
        """Initialise DAO for reflection and return the pyodbc module.

        DAO.DBEngine versions 3.6, 3.5 and 3.0 are tried in that order.

        Raises:
            exc.InvalidRequestError: if none of those DAO engines can be created.
        """
        import win32com.client
        import pythoncom

        global const, daoEngine
        # Guarded on daoEngine, and both globals are assigned only after a
        # successful lookup. Setting ``const`` first used to make a single
        # failed attempt permanent, leaving daoEngine None for reflection.
        if daoEngine is None:
            for suffix in (".36", ".35", ".30"):
                try:
                    engine = win32com.client.gencache.EnsureDispatch("DAO.DBEngine" + suffix)
                    break
                except pythoncom.com_error:
                    pass
            else:
                raise exc.InvalidRequestError("Can't find a DB engine. Check http://support.microsoft.com/kb/239114 for details.")
            const = win32com.client.constants
            daoEngine = engine

        import pyodbc as module
        return module
    dbapi = classmethod(dbapi)

    def create_connect_args(self, url):
        """Build the pyodbc connection string for ``url``.

        The URL's database part is passed as ``Dbq``. UID and PWD entries are
        added only when a username is present.
        """
        opts = url.translate_connect_args()
        connectors = ["Driver={Microsoft Access Driver (*.mdb)}"]
        connectors.append("Dbq=%s" % opts["database"])
        user = opts.get("username", None)
        if user:
            connectors.append("UID=%s" % user)
            connectors.append("PWD=%s" % opts.get("password", ""))
        return [[";".join(connectors)], {}]

    def last_inserted_ids(self):
        # NOTE(review): nothing in this module assigns ``self.context`` on the
        # dialect, so this looks like legacy API; confirm before relying on it.
        return self.context.last_inserted_ids

    def do_execute(self, cursor, statement, params, **kwargs):
        """Execute ``statement``, passing an empty tuple instead of an empty dict."""
        if params == {}:
            params = ()
        super(AccessDialect, self).do_execute(cursor, statement, params, **kwargs)

    def _execute(self, c, statement, parameters):
        """Execute on cursor ``c``, wrapping any failure in ``exc.DBAPIError``."""
        # NOTE(review): like last_inserted_ids, this reads ``self.context``,
        # which nothing in this module sets on the dialect.
        try:
            if parameters == {}:
                parameters = ()
            c.execute(statement, parameters)
            self.context.rowcount = c.rowcount
        except Exception:
            # sys.exc_info() is valid on every Python version, unlike the
            # Python-2-only ``except Exception, e`` form.
            import sys
            e = sys.exc_info()[1]
            raise exc.DBAPIError.instance(statement, parameters, e)

    def has_table(self, connection, tablename, schema=None):
        """Return True if a table named ``tablename`` can be queried.

        Uses a probe query because it is more reliable than asking DAO. Any
        error from the probe is treated as "table does not exist".
        """
        try:
            result = connection.execute('select top 1 * from [%s]' % tablename)
        except Exception:
            return False
        # Release the probe cursor instead of leaving the result open.
        result.close()
        return True

    def _reflected_column(self, table, name):
        """Return ``table.c[name]``, or None if that column was not reflected."""
        try:
            return table.c[name]
        except KeyError:
            return None

    def reflecttable(self, connection, table, include_columns):
        """Populate ``table`` with columns, keys and indexes read through DAO.

        The table is matched case-insensitively. When ``include_columns`` is
        non-empty, only the named columns are reflected, and foreign keys that
        involve other columns are skipped.

        Raises:
            exc.NoSuchTableError: if no table matches ``table.name``.
        """
        import warnings

        # Built here rather than at class level because it relies on win32com
        # constants, which aren't imported until the dbapi method is called.
        if not hasattr(self, 'ischema_names'):
            self.ischema_names = {
                const.dbByte: AcTinyInteger,  # 1-byte integer, not binary data
                const.dbInteger: AcInteger,
                const.dbLong: AcInteger,
                const.dbSingle: AcFloat,
                const.dbDouble: AcFloat,
                const.dbDate: AcDateTime,
                const.dbLongBinary: AcBinary,
                const.dbMemo: AcText,
                const.dbBoolean: AcBoolean,
                const.dbText: AcUnicode,  # All Access strings are unicode
                const.dbCurrency: AcNumeric,
            }

        # A fresh DAO connection is opened for each reflection so that the
        # latest schema updates are visible.
        dtbs = daoEngine.OpenDatabase(connection.engine.url.database)

        try:
            wanted = table.name.lower()
            for tbl in dtbs.TableDefs:
                if tbl.Name.lower() == wanted:
                    break
            else:
                raise exc.NoSuchTableError(table.name)

            for col in tbl.Fields:
                if include_columns and col.Name not in include_columns:
                    continue

                coltype = self.ischema_names.get(col.Type)
                if coltype is None:
                    # Unknown DAO type: warn and fall back, rather than
                    # aborting the whole reflection with a bare KeyError.
                    warnings.warn("Did not recognize type '%s' of column '%s'"
                                  % (col.Type, col.Name), exc.SAWarning)
                    coltype = types.NullType
                elif col.Type == const.dbText:
                    coltype = coltype(col.Size)

                colargs = {
                    'nullable': not (col.Required or col.Attributes & const.dbAutoIncrField),
                }
                default = col.DefaultValue

                if col.Attributes & const.dbAutoIncrField:
                    colargs['default'] = schema.Sequence(col.Name + '_seq')
                elif default:
                    if col.Type == const.dbBoolean:
                        default = default == 'Yes' and '1' or '0'
                    colargs['server_default'] = schema.DefaultClause(sql.text(default))

                table.append_column(schema.Column(col.Name, coltype, **colargs))

            # TBD: check constraints

            # Find primary key columns first.
            for idx in tbl.Indexes:
                if not idx.Primary:
                    continue
                for field in idx.Fields:
                    thecol = self._reflected_column(table, field.Name)
                    if thecol is None:
                        continue
                    table.primary_key.add(thecol)
                    # An integer key without a reflected COUNTER sequence is
                    # not autoincrementing. The default may be the Sequence
                    # itself or a wrapper exposing it as ``.arg``.
                    default_gen = thecol.default
                    has_counter = isinstance(default_gen, schema.Sequence) or \
                        isinstance(getattr(default_gen, 'arg', None), schema.Sequence)
                    if isinstance(thecol.type, types.Integer) and not has_counter:
                        thecol.autoincrement = False

            # Then add other indexes.
            for idx in tbl.Indexes:
                if idx.Primary:
                    continue
                if len(idx.Fields) == 1:
                    col = self._reflected_column(table, idx.Fields[0].Name)
                    if col is not None and not col.primary_key:
                        col.index = True
                        col.unique = idx.Unique
                # TBD: multi-column indexes

            for fk in dtbs.Relations:
                # Compare with the DAO table name found above, so this stays
                # consistent with the case-insensitive table lookup.
                if fk.ForeignTable != tbl.Name:
                    continue
                scols = [c.ForeignName for c in fk.Fields]
                if include_columns and [n for n in scols if n not in include_columns]:
                    continue
                rcols = ['%s.%s' % (fk.Table, c.Name) for c in fk.Fields]
                table.append_constraint(schema.ForeignKeyConstraint(scols, rcols, link_to_name=True))

        finally:
            dtbs.Close()

    def table_names(self, connection, schema):
        """Return the names of user tables, skipping ``MSys*`` and ``~TMP*`` entries."""
        # A fresh DAO connection is opened for each call so that the latest
        # updates are visible. It is closed even if enumeration fails.
        dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
        try:
            return [t.Name for t in dtbs.TableDefs
                    if t.Name[:4] not in ("MSys", "~TMP")]
        finally:
            dtbs.Close()


class AccessCompiler(compiler.DefaultCompiler):
    """Statement compiler producing Access (Jet) SQL."""

    # EXTRACT field names mapped to Access DATEPART interval codes.
    extract_map = compiler.DefaultCompiler.extract_map.copy()
    extract_map.update({
        'month': 'm',
        'day': 'd',
        'year': 'yyyy',
        'second': 's',
        'hour': 'h',
        'doy': 'y',
        'minute': 'n',
        'quarter': 'q',
        'dow': 'w',
        'week': 'ww',
    })

    def visit_select_precolumns(self, select):
        """Emit DISTINCT and TOP (Access's version of LIMIT) right after SELECT.

        Raises:
            exc.InvalidRequestError: if the select has an OFFSET, which Access
            cannot express.
        """
        # Checked independently of LIMIT. An OFFSET without a LIMIT used to
        # be silently dropped, returning rows from the wrong starting point.
        if select.offset:
            raise exc.InvalidRequestError('Access does not support LIMIT with an offset')
        s = select.distinct and "DISTINCT " or ""
        if select.limit:
            s += "TOP %s " % (select.limit)
        return s

    def limit_clause(self, select):
        """Emit nothing; the limit is rendered as TOP after the SELECT keyword."""
        return ""

    def binary_operator_string(self, binary):
        """Access uses "mod" instead of "%"."""
        return binary.operator == '%' and 'mod' or binary.operator

    def label_select_column(self, select, column, asfrom):
        # Function columns are always wrapped in a label; everything else
        # uses the default labelling.
        if isinstance(column, expression.Function):
            return column.label()
        else:
            return super(AccessCompiler, self).label_select_column(select, column, asfrom)

    function_rewrites = {'current_date': 'now',
                         'current_timestamp': 'now',
                         'length': 'len',
                         }

    def visit_function(self, func, **kwargs):
        """Render ``func``, renaming common ANSI SQL functions to Access names.

        Keyword arguments from the compiler are forwarded to the base class.
        The function's original name is restored afterwards so the Access
        rename does not leak into the shared expression object, for example
        when it is later compiled for another dialect.
        """
        original_name = func.name
        func.name = self.function_rewrites.get(original_name, original_name)
        try:
            return super(AccessCompiler, self).visit_function(func, **kwargs)
        finally:
            func.name = original_name

    def for_update_clause(self, select):
        """FOR UPDATE is not supported by Access; silently ignore it."""
        return ''

    # Strip schema
    def visit_table(self, table, asfrom=False, **kwargs):
        if asfrom:
            return self.preparer.quote(table.name, table.quote)
        else:
            return ""

    def visit_join(self, join, asfrom=False, **kwargs):
        return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN ") +
                self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))

    def visit_extract(self, extract, **kwargs):
        """Render EXTRACT as DATEPART with an Access interval code.

        Extra keyword arguments from the compiler are accepted and ignored.
        """
        field = self.extract_map.get(extract.field, extract.field)
        return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))


class AccessSchemaGenerator(compiler.SchemaGenerator):
    """DDL generator that renders implicit IDENTITY columns as Access COUNTERs."""

    def get_column_specification(self, column, **kwargs):
        """Return the CREATE TABLE column definition for ``column``.

        Side effects: may set ``column.sequence`` on an eligible integer
        primary key, and sets ``column.table.has_sequence`` to any column
        rendered as ``counter``. AccessExecutionContext.post_exec reads these
        same attributes after INSERTs.
        """
        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()

        # install a sequence if we have an implicit IDENTITY column
        # (skipped once the table's has_sequence already holds a column)
        if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
                column.autoincrement and isinstance(column.type, types.Integer) and not column.foreign_keys:
            # Only when no default was given, or the default is an optional Sequence.
            if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
                column.sequence = schema.Sequence(column.name + '_seq')

        if not column.nullable:
            colspec += " NOT NULL"

        if hasattr(column, 'sequence'):
            # Counter columns become just "<name> counter". This replaces
            # the whole spec built above, including any NOT NULL.
            column.table.has_sequence = column
            colspec = self.preparer.format_column(column) + " counter"
        else:
            default = self.get_column_default_string(column)
            if default is not None:
                colspec += " DEFAULT " + default

        return colspec

class AccessSchemaDropper(compiler.SchemaDropper):
    """DDL dropper; Access's DROP INDEX names the index as [table].[index]."""

    def visit_index(self, index):
        table_name = index.table.name
        index_name = self._validate_identifier(index.name, False)
        statement = "\nDROP INDEX [%s].[%s]" % (table_name, index_name)
        self.append(statement)
        self.execute()

class AccessDefaultRunner(base.DefaultRunner):
    """Default runner for Access; inherits all behaviour unchanged."""

class AccessIdentifierPreparer(compiler.IdentifierPreparer):
    """Quotes identifiers with square brackets, as Access expects."""

    # The generic reserved words plus 'value' and 'text', built as a new set
    # so the shared RESERVED_WORDS is not modified.
    reserved_words = compiler.RESERVED_WORDS.union(['value', 'text'])

    def __init__(self, dialect):
        super(AccessIdentifierPreparer, self).__init__(
            dialect, initial_quote='[', final_quote=']')


# Attach the Access-specific components to the dialect class, then expose it
# under the module-level name ``dialect``.
AccessDialect.execution_ctx_cls = AccessExecutionContext
AccessDialect.defaultrunner = AccessDefaultRunner
AccessDialect.preparer = AccessIdentifierPreparer
AccessDialect.schemadropper = AccessSchemaDropper
AccessDialect.schemagenerator = AccessSchemaGenerator
AccessDialect.statement_compiler = AccessCompiler
AccessDialect.poolclass = pool.SingletonThreadPool

dialect = AccessDialect