root/galaxy-central/eggs/SQLAlchemy-0.5.6_dev_r6498-py2.6.egg/sqlalchemy/databases/access.py

リビジョン 3, 15.6 KB (コミッタ: kohda, 14 年 前)

Install Unix tools  http://hannonlab.cshl.edu/galaxy_unix_tools/galaxy.html

行番号 
1# access.py
2# Copyright (C) 2007 Paul Johnston, paj@pajhome.org.uk
3# Portions derived from jet2sql.py by Matt Keranen, mksql@yahoo.com
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
import operator

from sqlalchemy import sql, schema, types, exc, pool
from sqlalchemy.sql import compiler, expression
from sqlalchemy.engine import default, base
11
12
class AcNumeric(types.Numeric):
    """Access NUMERIC column type.

    Non-None bind values are sent to the driver as strings; fetched values
    are returned without any conversion.
    """

    def get_col_spec(self):
        return "NUMERIC"

    def result_processor(self, dialect):
        # No result conversion: the driver's value is used as-is.
        return None

    def bind_processor(self, dialect):
        def process(value):
            # None is passed through unchanged; everything else is stringified.
            if value is not None:
                return str(value)
            return value
        return process
28
class AcFloat(types.Float):
    """Access FLOAT column type."""

    def get_col_spec(self):
        return "FLOAT"

    def bind_processor(self, dialect):
        """Bind non-None values as strings so Decimal values can round-trip."""
        def process(value):
            if value is None:
                return None
            return str(value)
        return process
40
class AcInteger(types.Integer):
    """Access INTEGER column type."""

    def get_col_spec(self):
        spec = "INTEGER"
        return spec
44
class AcTinyInteger(types.Integer):
    """Access TINYINT column type."""

    def get_col_spec(self):
        spec = "TINYINT"
        return spec
48
class AcSmallInteger(types.Smallinteger):
    """Access SMALLINT column type."""

    def get_col_spec(self):
        spec = "SMALLINT"
        return spec
52
class AcDateTime(types.DateTime):
    """Access DATETIME column type.

    Constructor arguments are accepted for compatibility but ignored.
    """

    def __init__(self, *a, **kw):
        # NOTE(review): the single positional False is presumably the base
        # type's timezone flag -- confirm against types.DateTime.
        super(AcDateTime, self).__init__(False)

    def get_col_spec(self):
        spec = "DATETIME"
        return spec
59
class AcDate(types.Date):
    """Access date column type; stored as DATETIME.

    Constructor arguments are accepted for compatibility but ignored.
    """

    def __init__(self, *a, **kw):
        # Always initialise the base type with a single positional False.
        super(AcDate, self).__init__(False)

    def get_col_spec(self):
        spec = "DATETIME"
        return spec
66
class AcText(types.Text):
    """Access long-text column type (MEMO)."""

    def get_col_spec(self):
        spec = "MEMO"
        return spec
70
class AcString(types.String):
    """Access TEXT column type, with an optional ``(length)`` suffix."""

    def get_col_spec(self):
        # A falsy length (None/0) yields plain "TEXT".
        suffix = "(%d)" % self.length if self.length else ""
        return "TEXT" + suffix
74
class AcUnicode(types.Unicode):
    """Access TEXT column type for unicode data.

    No bind or result conversion is applied; values pass straight through
    to and from the driver.
    """

    def get_col_spec(self):
        # A falsy length (None/0) yields plain "TEXT".
        suffix = "(%d)" % self.length if self.length else ""
        return "TEXT" + suffix

    def bind_processor(self, dialect):
        return None

    def result_processor(self, dialect):
        return None
84
class AcChar(types.CHAR):
    """Access fixed-length character type, rendered as TEXT(length)."""

    def get_col_spec(self):
        # A falsy length (None/0) yields plain "TEXT".
        suffix = "(%d)" % self.length if self.length else ""
        return "TEXT" + suffix
88
class AcBinary(types.Binary):
    """Access BINARY column type."""

    def get_col_spec(self):
        spec = "BINARY"
        return spec
92
class AcBoolean(types.Boolean):
    """Access YESNO column type.

    Results are coerced to ``bool`` (None preserved).  The exact singletons
    True/False are bound as 1/0; other non-None values are bound as their
    truth value.
    """

    def get_col_spec(self):
        return "YESNO"

    def result_processor(self, dialect):
        def process(value):
            if value is None:
                return None
            return bool(value)
        return process

    def bind_processor(self, dialect):
        def process(value):
            if value is None:
                return None
            if value is True:
                return 1
            if value is False:
                return 0
            return bool(value)
        return process
115
class AcTimeStamp(types.TIMESTAMP):
    """Access TIMESTAMP column type."""

    def get_col_spec(self):
        spec = "TIMESTAMP"
        return spec
119
class AccessExecutionContext(default.DefaultExecutionContext):
    """Execution context that fetches COUNTER (autoincrement) ids after INSERT."""

    def _has_implicit_sequence(self, column):
        """Return True if *column* qualifies as an implicit COUNTER column.

        That is an autoincrementing Integer primary key that is not a foreign
        key and has either no default or only an optional Sequence default.
        """
        if not (column.primary_key and column.autoincrement):
            return False
        if not isinstance(column.type, types.Integer) or column.foreign_keys:
            return False
        column_default = column.default
        if column_default is None:
            return True
        return bool(isinstance(column_default, schema.Sequence) and column_default.optional)

    def post_exec(self):
        """If we inserted into a table with a COUNTER column, fetch the new id.

        The table's COUNTER column (or None) is looked up once and cached on
        the Table object as ``has_sequence``.  The id is read with
        ``SELECT @@identity`` and stored in ``_last_inserted_ids``.
        """
        if self.compiled.isinsert:
            tbl = self.compiled.statement.table
            if not hasattr(tbl, 'has_sequence'):
                tbl.has_sequence = None
                for column in tbl.c:
                    if getattr(column, 'sequence', None) is not None or \
                            self._has_implicit_sequence(column):
                        tbl.has_sequence = column
                        break

            # has_sequence holds a Column or None; test identity rather than
            # relying on the truthiness of a SQL expression object.
            if tbl.has_sequence is not None:
                # The original author noted _last_inserted_ids was not usable at
                # this point (unlike mssql), so it is replaced outright.
                self.cursor.execute("SELECT @@identity AS lastrowid")
                row = self.cursor.fetchone()
                # Guard against a missing row / NULL identity instead of
                # failing with an opaque TypeError from int(None).
                if row is not None and row[0] is not None:
                    self._last_inserted_ids = [int(row[0])]

        super(AccessExecutionContext, self).post_exec()
151
152
# win32com DAO constants and engine; populated lazily by AccessDialect.dbapi().
const = None
daoEngine = None
class AccessDialect(default.DefaultDialect):
    """SQLAlchemy dialect for Microsoft Access (Jet) databases.

    Statements run through pyodbc.  Schema reflection goes through DAO via
    win32com, using the module-level ``const`` and ``daoEngine`` globals that
    :meth:`dbapi` initialises.
    """

    colspecs = {
        types.Unicode: AcUnicode,
        types.Integer: AcInteger,
        types.Smallinteger: AcSmallInteger,
        types.Numeric: AcNumeric,
        types.Float: AcFloat,
        types.DateTime: AcDateTime,
        types.Date: AcDate,
        types.String: AcString,
        types.Binary: AcBinary,
        types.Boolean: AcBoolean,
        types.Text: AcText,
        types.CHAR: AcChar,
        types.TIMESTAMP: AcTimeStamp,
    }
    name = 'access'
    supports_sane_rowcount = False
    supports_sane_multi_rowcount = False

    def type_descriptor(self, typeobj):
        """Adapt a generic type to its Access implementation via ``colspecs``."""
        return types.adapt_type(typeobj, self.colspecs)

    def __init__(self, **params):
        super(AccessDialect, self).__init__(**params)
        self.text_as_varchar = False
        self._dtbs = None

    @classmethod
    def dbapi(cls):
        """Initialise the DAO engine (used for reflection); return pyodbc.

        Raises InvalidRequestError if no DAO.DBEngine version can be
        dispatched.
        """
        import win32com.client, pythoncom

        global const, daoEngine
        # Key the one-time setup on daoEngine, and only publish the globals
        # after a successful lookup.  Previously const was set first, so a
        # failed lookup made every later call skip the search and leave
        # daoEngine as None.
        if daoEngine is None:
            for suffix in (".36", ".35", ".30"):
                try:
                    engine = win32com.client.gencache.EnsureDispatch("DAO.DBEngine" + suffix)
                    break
                except pythoncom.com_error:
                    pass
            else:
                raise exc.InvalidRequestError("Can't find a DB engine. Check http://support.microsoft.com/kb/239114 for details.")
            const = win32com.client.constants
            daoEngine = engine

        import pyodbc as module
        return module

    def create_connect_args(self, url):
        """Build the ODBC connection string for the Access driver.

        Returns ``[[connection_string], {}]``.  UID/PWD are added only when
        the URL carries a username.
        """
        opts = url.translate_connect_args()
        connectors = ["Driver={Microsoft Access Driver (*.mdb)}",
                      "Dbq=%s" % opts["database"]]
        user = opts.get("username", None)
        if user:
            connectors.append("UID=%s" % user)
            connectors.append("PWD=%s" % opts.get("password", ""))
        return [[";".join(connectors)], {}]

    def last_inserted_ids(self):
        # NOTE(review): nothing in this module sets ``self.context`` on the
        # dialect; this looks like legacy API -- confirm before relying on it.
        return self.context.last_inserted_ids

    def do_execute(self, cursor, statement, params, **kwargs):
        # An empty parameter dict is passed to the cursor as an empty tuple
        # (presumably for pyodbc's positional paramstyle).
        if params == {}:
            params = ()
        super(AccessDialect, self).do_execute(cursor, statement, params, **kwargs)

    def _execute(self, c, statement, parameters):
        """Execute on cursor *c*, record rowcount and wrap errors as DBAPIError."""
        try:
            if parameters == {}:
                parameters = ()
            c.execute(statement, parameters)
            self.context.rowcount = c.rowcount
        except Exception as e:
            raise exc.DBAPIError.instance(statement, parameters, e)

    def has_table(self, connection, tablename, schema=None):
        """Return True if *tablename* can be selected from on *connection*.

        Probes with ``select top 1`` (the original author found this more
        reliable than DAO).  A DBAPI error is taken to mean the table does
        not exist.  Other errors (e.g. a closed connection) now propagate
        instead of being misreported as "no such table".
        """
        try:
            result = connection.execute('select top 1 * from [%s]' % tablename)
        except exc.DBAPIError:
            return False
        # Release the probe's cursor rather than leaving the result pending.
        result.close()
        return True

    def reflecttable(self, connection, table, include_columns):
        """Populate *table* from the database's DAO TableDef.

        Reflects columns (restricted to *include_columns* when given),
        primary key, single-column indexes and foreign keys.  The table is
        located by case-insensitive name.  Raises NoSuchTableError if it is
        not found.  DAO field types with no mapping are reflected as
        NullType with a warning.
        """
        import warnings

        # This is defined in the function, as it relies on win32com constants,
        # that aren't imported until dbapi method is called
        if not hasattr(self, 'ischema_names'):
            self.ischema_names = {
                const.dbByte:       AcTinyInteger,   # DAO Byte: 8-bit integer
                const.dbInteger:    AcSmallInteger,  # DAO Integer: 16-bit integer
                const.dbLong:       AcInteger,
                const.dbSingle:     AcFloat,
                const.dbDouble:     AcFloat,
                const.dbDate:       AcDateTime,
                const.dbLongBinary: AcBinary,
                const.dbMemo:       AcText,
                const.dbBoolean:    AcBoolean,
                const.dbText:       AcUnicode, # All Access strings are unicode
                const.dbCurrency:   AcNumeric,
            }

        # A fresh DAO connection is opened for each reflection
        # This is necessary, so we get the latest updates
        dtbs = daoEngine.OpenDatabase(connection.engine.url.database)

        try:
            wanted = table.name.lower()
            for tbl in dtbs.TableDefs:
                if tbl.Name.lower() == wanted:
                    break
            else:
                raise exc.NoSuchTableError(table.name)

            # Names of the columns actually appended, so that indexes and
            # keys referring to excluded columns can be skipped below.
            reflected = set()

            for col in tbl.Fields:
                if include_columns and col.Name not in include_columns:
                    continue

                try:
                    coltype = self.ischema_names[col.Type]
                except KeyError:
                    warnings.warn("Did not recognize DAO type %r of column '%s'"
                                  % (col.Type, col.Name))
                    coltype = types.NullType
                else:
                    if col.Type == const.dbText:
                        coltype = coltype(col.Size)

                is_autoinc = col.Attributes & const.dbAutoIncrField
                colargs = {
                    'nullable': not (col.Required or is_autoinc),
                }
                default = col.DefaultValue

                if is_autoinc:
                    colargs['default'] = schema.Sequence(col.Name + '_seq')
                elif default:
                    if col.Type == const.dbBoolean:
                        default = default == 'Yes' and '1' or '0'
                    colargs['server_default'] = schema.DefaultClause(sql.text(default))

                table.append_column(schema.Column(col.Name, coltype, **colargs))
                reflected.add(col.Name)

                # TBD: check constraints

            # Find primary key columns first
            for idx in tbl.Indexes:
                if not idx.Primary:
                    continue
                for field in idx.Fields:
                    if field.Name not in reflected:
                        continue
                    thecol = table.c[field.Name]
                    table.primary_key.add(thecol)
                    # Integer PKs that are not DAO counters must not be
                    # treated as autoincrementing.
                    if isinstance(thecol.type, types.Integer) and \
                            not self._is_counter_default(thecol.default):
                        thecol.autoincrement = False

            # Then add other single-column indexes (TBD: multi-column indexes)
            for idx in tbl.Indexes:
                if idx.Primary or len(idx.Fields) != 1:
                    continue
                field_name = idx.Fields[0].Name
                if field_name not in reflected:
                    continue
                col = table.c[field_name]
                if not col.primary_key:
                    col.index = True
                    col.unique = idx.Unique

            for fk in dtbs.Relations:
                # Match the case-insensitive lookup used for the table itself.
                if fk.ForeignTable.lower() != wanted:
                    continue
                scols = [c.ForeignName for c in fk.Fields]
                if not all(name in reflected for name in scols):
                    continue
                rcols = ['%s.%s' % (fk.Table, c.Name) for c in fk.Fields]
                table.append_constraint(schema.ForeignKeyConstraint(scols, rcols, link_to_name=True))

        finally:
            dtbs.Close()

    def _is_counter_default(self, column_default):
        """Return True if a reflected column default is the COUNTER Sequence.

        Accepts either a bare Sequence or a default object wrapping one in
        its ``arg`` attribute.
        """
        if column_default is None:
            return False
        if isinstance(column_default, schema.Sequence):
            return True
        return isinstance(getattr(column_default, 'arg', None), schema.Sequence)

    def table_names(self, connection, schema):
        """Return user table names, skipping ``MSys*`` and ``~TMP*`` tables."""
        # A fresh DAO connection is opened for each reflection
        # This is necessary, so we get the latest updates
        dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
        try:
            return [t.Name for t in dtbs.TableDefs
                    if not t.Name.startswith(("MSys", "~TMP"))]
        finally:
            dtbs.Close()
328
329
class AccessCompiler(compiler.DefaultCompiler):
    """Statement compiler producing Jet/Access SQL."""

    # SQLAlchemy EXTRACT() field names -> DATEPART() interval codes.
    extract_map = compiler.DefaultCompiler.extract_map.copy()
    extract_map.update({
        'month': 'm',
        'day': 'd',
        'year': 'yyyy',
        'second': 's',
        'hour': 'h',
        'doy': 'y',
        'minute': 'n',
        'quarter': 'q',
        'dow': 'w',
        'week': 'ww',
    })

    # Access spells modulo as "mod".  Operators are looked up by the operator
    # function, so the string comparison in binary_operator_string never
    # matched and "%" was emitted.  Padded with spaces so the output is valid
    # whether or not the base compiler adds its own spacing.
    # NOTE(review): assumes SQLAlchemy's sql.operators.mod is operator.mod.
    operators = compiler.DefaultCompiler.operators.copy()
    operators[operator.mod] = ' mod '

    # ANSI function names and their Access equivalents.
    function_rewrites = {'current_date':       'now',
                         'current_timestamp':  'now',
                         'length':             'len',
                         }

    def visit_select_precolumns(self, select):
        """Render DISTINCT and TOP (Access's form of LIMIT) after SELECT.

        Raises InvalidRequestError if an OFFSET is requested.
        """
        s = "DISTINCT " if select.distinct else ""
        if select.limit:
            s += "TOP %s " % (select.limit)
        if select.offset:
            raise exc.InvalidRequestError('Access does not support LIMIT with an offset')
        return s

    def limit_clause(self, select):
        """Emit nothing: the limit is rendered as TOP after SELECT."""
        return ""

    def binary_operator_string(self, binary):
        """Legacy hook mapping a string "%" operator to "mod".

        Modulo rendering is handled by the ``operators`` map above.
        """
        return 'mod' if binary.operator == '%' else binary.operator

    def label_select_column(self, select, column, asfrom):
        # Functions in the columns clause are always given a label.
        if isinstance(column, expression.Function):
            return column.label()
        return super(AccessCompiler, self).label_select_column(select, column, asfrom)

    def visit_function(self, func, **kwargs):
        """Compile *func* under its Access name (see ``function_rewrites``).

        The rename is applied only for the duration of this call.  The
        expression's original name is restored afterwards, so compiling it
        again or with another dialect is unaffected.
        """
        original_name = func.name
        func.name = self.function_rewrites.get(original_name, original_name)
        try:
            return super(AccessCompiler, self).visit_function(func, **kwargs)
        finally:
            func.name = original_name

    def for_update_clause(self, select):
        """FOR UPDATE is not supported by Access; silently ignore"""
        return ''

    # Strip schema
    def visit_table(self, table, asfrom=False, **kwargs):
        if asfrom:
            return self.preparer.quote(table.name, table.quote)
        return ""

    def visit_join(self, join, asfrom=False, **kwargs):
        """Render a JOIN, parenthesizing nested joins as Access requires."""
        # Process left, right, then onclause: the same order as before,
        # which keeps positional bind parameters in sequence.
        left = self.process(join.left, asfrom=True)
        right = self.process(join.right, asfrom=True)
        if isinstance(join.left, expression.Join):
            left = "(" + left + ")"
        if isinstance(join.right, expression.Join):
            right = "(" + right + ")"
        keyword = " LEFT OUTER JOIN " if join.isouter else " INNER JOIN "
        return left + keyword + right + " ON " + self.process(join.onclause)

    def visit_extract(self, extract):
        field = self.extract_map.get(extract.field, extract.field)
        return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
395
396
class AccessSchemaGenerator(compiler.SchemaGenerator):
    """CREATE TABLE generator that renders implicit identity columns as ``counter``."""

    def get_column_specification(self, column, **kwargs):
        """Return the DDL fragment for *column*.

        An autoincrementing Integer primary key that is not a foreign key,
        and has no default other than an optional Sequence, is given a
        ``sequence`` attribute (unless its table already records a
        ``has_sequence``).  Any column with ``sequence`` is rendered as
        ``<name> counter`` and recorded on its table as ``has_sequence``.
        Otherwise the type, NOT NULL and DEFAULT clauses are rendered.
        """
        name = self.preparer.format_column(column)
        type_spec = column.type.dialect_impl(self.dialect).get_col_spec()

        # install a sequence if we have an implicit IDENTITY column.
        # has_sequence is absent, None or a Column: test identity, not truth.
        if getattr(column.table, 'has_sequence', None) is None and column.primary_key and \
                column.autoincrement and isinstance(column.type, types.Integer) and not column.foreign_keys:
            if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
                column.sequence = schema.Sequence(column.name + '_seq')

        if hasattr(column, 'sequence'):
            # COUNTER carries its own type; no NOT NULL/DEFAULT is emitted.
            column.table.has_sequence = column
            return name + " counter"

        colspec = name + " " + type_spec
        if not column.nullable:
            colspec += " NOT NULL"
        default = self.get_column_default_string(column)
        if default is not None:
            colspec += " DEFAULT " + default
        return colspec
419
class AccessSchemaDropper(compiler.SchemaDropper):
    """DROP generator using Jet SQL syntax."""

    def visit_index(self, index):
        """Emit and execute ``DROP INDEX [index] ON [table]``.

        Jet/Access SQL spells this ``DROP INDEX index ON table``.  The
        previous ``[table].[index]`` form is SQL Server syntax.
        """
        self.append("\nDROP INDEX [%s] ON [%s]" % (
            self._validate_identifier(index.name, False), index.table.name))
        self.execute()
425
class AccessDefaultRunner(base.DefaultRunner):
    """Default-value runner for Access; inherits all behaviour unchanged."""
428
class AccessIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer quoting names with Access-style square brackets."""

    # Default reserved words plus 'value' and 'text'.
    reserved_words = compiler.RESERVED_WORDS.union(['value', 'text'])

    def __init__(self, dialect):
        super(AccessIdentifierPreparer, self).__init__(
            dialect, initial_quote='[', final_quote=']')
434
435
# Attach the Access-specific component classes to the dialect class, then
# expose it under the module-level name ``dialect``.
AccessDialect.poolclass = pool.SingletonThreadPool
AccessDialect.statement_compiler = AccessCompiler
AccessDialect.schemagenerator = AccessSchemaGenerator
AccessDialect.schemadropper = AccessSchemaDropper
AccessDialect.preparer = AccessIdentifierPreparer
AccessDialect.defaultrunner = AccessDefaultRunner
AccessDialect.execution_ctx_cls = AccessExecutionContext
dialect = AccessDialect
Note: リポジトリブラウザについてのヘルプは TracBrowser を参照してください。