root/galaxy-central/eggs/SQLAlchemy-0.5.6_dev_r6498-py2.6.egg/sqlalchemy/sql/compiler.py

リビジョン 3, 45.6 KB (コミッタ: kohda, 14 年 前)

Install Unix tools  http://hannonlab.cshl.edu/galaxy_unix_tools/galaxy.html

行番号 
1# compiler.py
2# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
3#
4# This module is part of SQLAlchemy and is released under
5# the MIT License: http://www.opensource.org/licenses/mit-license.php
6
7"""Base SQL and DDL compiler implementations.
8
9Provides the :class:`~sqlalchemy.sql.compiler.DefaultCompiler` class, which is
10responsible for generating all SQL query strings, as well as
11:class:`~sqlalchemy.sql.compiler.SchemaGenerator` and :class:`~sqlalchemy.sql.compiler.SchemaDropper`
12which issue CREATE and DROP DDL for tables, sequences, and indexes.
13
14The elements in this module are used by public-facing constructs like
15:class:`~sqlalchemy.sql.expression.ClauseElement` and :class:`~sqlalchemy.engine.Engine`.
16While dialect authors will want to be familiar with this module for the purpose of
17creating database-specific compilers and schema generators, the module
18is otherwise internal to SQLAlchemy.
19"""
20
21import string, re
22from sqlalchemy import schema, engine, util, exc
23from sqlalchemy.sql import operators, functions, util as sql_util, visitors
24from sqlalchemy.sql import expression as sql
25
# Lower-case SQL keywords treated as reserved by the default dialect.  This is
# a plain (mutable) set of strings; presumably the identifier preparer checks
# names against it to decide when quoting is required -- confirm at call site.
RESERVED_WORDS = set("""
    all analyse analyze and any array
    as asc asymmetric authorization between
    binary both case cast check collate
    column constraint create cross current_date
    current_role current_time current_timestamp
    current_user default deferrable desc
    distinct do else end except false
    for foreign freeze from full grant
    group having ilike in initially inner
    intersect into is isnull join leading
    left like limit localtime localtimestamp
    natural new not notnull null off offset
    old on only or order outer overlaps
    placing primary references right select
    session_user set similar some symmetric table
    then to trailing true union unique user
    using verbose when where
""".split())
44
# Identifier character checks: a name made solely of letters, digits, "_"
# and "$" (case-insensitively), and the characters a name may not begin with.
LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.IGNORECASE)
ILLEGAL_INITIAL_CHARACTERS = re.compile(r'[0-9$]')

# ":name" bind parameter tokens inside text() clauses.  The look-behind and
# look-ahead reject colons adjacent to other colons or word characters (so
# "x::int" is not a bind) and a colon preceded by a backslash (\x5c).
# BIND_PARAMS_ESC finds those backslash-escaped tokens so the backslash can
# be stripped once real binds have been substituted.
BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.U)
BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]+)(?![:\w\$])', re.U)
50
# Bind-parameter placeholder template for each DB-API paramstyle.  Each
# template is %-formatted with "name" and/or "position" when a placeholder is
# rendered; a doubled "%%" survives that formatting as a literal "%".
BIND_TEMPLATES = dict(
    pyformat="%%(%(name)s)s",
    qmark="?",
    format="%%s",
    numeric=":%(position)s",
    named=":%(name)s"
)
58
59
def _like_operator(template):
    """Return a renderer for a LIKE-family operator.

    ``template`` is a two-slot format string that receives the rendered left
    and right operands.  When an ``escape`` value is supplied, an
    ``ESCAPE '<escape>'`` suffix is appended to the result.
    """
    def render(x, y, escape=None):
        return template % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or '')
    return render


# SQL text for each operator.  Plain strings are placed between (or around)
# the rendered operands; callables are invoked with the rendered operands and
# the expression's modifiers as keyword arguments.
OPERATORS = {
    # boolean connectives
    operators.and_: 'AND',
    operators.or_: 'OR',
    operators.inv: 'NOT',

    # arithmetic
    operators.add: '+',
    operators.mul: '*',
    operators.sub: '-',
    operators.div: '/',
    operators.mod: '%',
    operators.truediv: '/',

    # comparison
    operators.lt: '<',
    operators.le: '<=',
    operators.ne: '!=',
    operators.gt: '>',
    operators.ge: '>=',
    operators.eq: '=',

    operators.distinct_op: 'DISTINCT',
    operators.concat_op: '||',

    # pattern matching, with an optional ESCAPE character
    operators.like_op: _like_operator('%s LIKE %s'),
    operators.notlike_op: _like_operator('%s NOT LIKE %s'),
    operators.ilike_op: _like_operator("lower(%s) LIKE lower(%s)"),
    operators.notilike_op: _like_operator("lower(%s) NOT LIKE lower(%s)"),

    operators.between_op: 'BETWEEN',
    operators.match_op: 'MATCH',
    operators.in_op: 'IN',
    operators.notin_op: 'NOT IN',
    operators.comma_op: ', ',
    operators.desc_op: 'DESC',
    operators.asc_op: 'ASC',
    operators.from_: 'FROM',
    operators.as_: 'AS',
    operators.exists: 'EXISTS',
    operators.is_: 'IS',
    operators.isnot: 'IS NOT',
    operators.collate: 'COLLATE',
}
96
# SQL renderings for generic functions, keyed by function class (a function's
# name may also be used as a key).  The chosen template is %-formatted with
# the rendered argument list as "expr"; templates without "%(expr)s" are
# therefore emitted with no argument list at all.
FUNCTIONS = dict([
    (functions.coalesce, 'coalesce%(expr)s'),
    (functions.random, 'random%(expr)s'),
    (functions.current_date, 'CURRENT_DATE'),
    (functions.current_time, 'CURRENT_TIME'),
    (functions.current_timestamp, 'CURRENT_TIMESTAMP'),
    (functions.current_user, 'CURRENT_USER'),
    (functions.localtime, 'LOCALTIME'),
    (functions.localtimestamp, 'LOCALTIMESTAMP'),
    (functions.session_user, 'SESSION_USER'),
    (functions.sysdate, 'sysdate'),
    (functions.user, 'USER'),
])
110
# Field names accepted by extract(), mapped to the keyword rendered inside
# EXTRACT(<field> FROM <expr>).  The default mapping is the identity;
# compiler subclasses may override the ``extract_map`` class attribute.
EXTRACT_MAP = dict((field, field) for field in (
    'month', 'day', 'year', 'second', 'hour', 'doy', 'minute', 'quarter',
    'dow', 'week', 'epoch', 'milliseconds', 'microseconds',
    'timezone_hour', 'timezone_minute',
))
128
class _CompileLabel(visitors.Visitable):
    """Lightweight stand-in for ``expression._Label`` used while compiling.

    Pairs a column expression (``element``) with the label ``name`` it is to
    be rendered under.  Its ``__visit_name__`` is ``'label'``, so the compiler
    handles it the same way as a real label.
    """

    __visit_name__ = 'label'
    __slots__ = 'element', 'name'

    def __init__(self, col, name):
        self.element, self.name = col, name

    # the quoting preference is taken from the wrapped element
    quote = property(lambda self: self.element.quote)
142
class DefaultCompiler(engine.Compiled):
    """Default implementation of Compiled.

    Compiles ClauseElements into SQL strings.  Uses a similar visit
    paradigm as visitors.ClauseVisitor but implements its own traversal:
    :meth:`process` hands an element to its ``_compiler_dispatch``, which
    invokes the matching ``visit_*`` method here; that method renders the
    element (processing its children recursively) and returns SQL text.

    Dialect-specific compilers can subclass this and override individual
    ``visit_*`` methods or the ``operators``, ``functions`` and
    ``extract_map`` lookup tables.

    """

    # lookup tables consulted by operator_string(), function_string() and
    # visit_extract(); subclasses may replace them
    operators = OPERATORS
    functions = FUNCTIONS
    extract_map = EXTRACT_MAP

    # if we are insert/update/delete.
    # set to true when we visit an INSERT, UPDATE or DELETE
    isdelete = isinsert = isupdate = False

    def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
        """Construct a new ``DefaultCompiler`` object.

        dialect
          Dialect to be used

        statement
          ClauseElement to be compiled

        column_keys
          a list of column names to be compiled into an INSERT or UPDATE
          statement.

        inline
          if true, INSERT/UPDATE defaults and sequences are compiled inline
          rather than pre-executed.  Also enabled when ``statement`` has a
          true ``inline`` attribute.

        Remaining keyword arguments are passed on to ``engine.Compiled``.

        """
        engine.Compiled.__init__(self, dialect, statement, column_keys, **kwargs)

        # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute)
        self.inline = inline or getattr(statement, 'inline', False)

        # a dictionary of bind parameter keys to _BindParamClause instances.
        self.binds = {}

        # a dictionary of _BindParamClause instances to "compiled" names that are
        # actually present in the generated SQL
        self.bind_names = util.column_dict()

        # stack which keeps track of nested SELECT statements
        self.stack = []

        # relates label names in the final SQL to
        # a tuple of local column/label name, ColumnElement object (if any) and TypeEngine.
        # ResultProxy uses this for type processing and column targeting
        self.result_map = {}

        # true if the paramstyle is positional
        self.positional = self.dialect.positional
        if self.positional:
            # compiled bind names, in the order their placeholders are rendered
            self.positiontup = []

        self.bindtemplate = BIND_TEMPLATES[self.dialect.paramstyle]

        # an IdentifierPreparer that formats the quoting of identifiers
        self.preparer = self.dialect.identifier_preparer

        self.label_length = self.dialect.label_length or self.dialect.max_identifier_length

        # a map which tracks "anonymous" identifiers that are
        # created on the fly here
        self.anon_map = util.PopulateDict(self._process_anon)

        # a map which tracks "truncated" names based on dialect.label_length
        # or dialect.max_identifier_length
        self.truncated_names = {}

    def compile(self):
        """Compile ``self.statement`` and store the SQL in ``self.string``."""
        self.string = self.process(self.statement)

    def process(self, obj, **kwargs):
        """Render ``obj`` via its ``_compiler_dispatch``, forwarding ``kwargs``."""
        return obj._compiler_dispatch(self, **kwargs)

    def is_subquery(self):
        """Return True when more than one statement is on the nesting stack."""
        return len(self.stack) > 1

    def construct_params(self, params=None):
        """Return a dictionary of bind parameter values keyed by compiled name.

        For each bind parameter, a value in ``params`` is looked up under the
        parameter's key, its shortname, then its compiled name.  When none is
        found (or ``params`` is empty) the parameter's own value is used,
        calling it first if it is callable.

        """
        if params:
            params = util.column_dict(params)
            pd = {}
            for bindparam, name in self.bind_names.iteritems():
                for paramname in (bindparam.key, bindparam.shortname, name):
                    if paramname in params:
                        pd[name] = params[paramname]
                        break
                else:
                    if util.callable(bindparam.value):
                        pd[name] = bindparam.value()
                    else:
                        pd[name] = bindparam.value
            return pd
        else:
            pd = {}
            for bindparam in self.bind_names:
                if util.callable(bindparam.value):
                    pd[self.bind_names[bindparam]] = bindparam.value()
                else:
                    pd[self.bind_names[bindparam]] = bindparam.value
            return pd

    params = property(construct_params)

    def default_from(self):
        """Called when a SELECT statement has no froms, and no FROM clause is to be appended.

        Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output.

        """
        return ""

    def visit_grouping(self, grouping, **kwargs):
        """Render the grouped element inside parentheses."""
        return "(" + self.process(grouping.element) + ")"

    def visit_label(self, label, result_map=None, within_columns_clause=False, **kw):
        """Render a label.

        Within a SELECT's columns clause this renders ``<expr> AS <name>``
        (truncating generated names) and records the label in ``result_map``
        when one is given.  Elsewhere only the labeled expression is
        rendered.  Other keyword arguments are accepted and ignored.
        """
        # only render labels within the columns clause of a select.
        # dialect-specific compilers can modify this behavior.
        if within_columns_clause:
            labelname = isinstance(label.name, sql._generated_label) and \
                    self._truncated_identifier("colident", label.name) or label.name

            if result_map is not None:
                result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type)

            return self.process(label.element) + " " + \
                        self.operator_string(operators.as_) + " " + \
                        self.preparer.format_label(label, labelname)
        else:
            return self.process(label.element)

    def visit_column(self, column, result_map=None, **kwargs):
        """Render a column reference, qualified by its (schema and) table name
        when the column's table is named."""
        name = column.name
        # generated (anonymous) names may need shortening to label_length
        if not column.is_literal and isinstance(name, sql._generated_label):
            name = self._truncated_identifier("colident", name)

        if result_map is not None:
            result_map[name.lower()] = (name, (column, ), column.type)

        if column.is_literal:
            name = self.escape_literal_column(name)
        else:
            name = self.preparer.quote(name, column.quote)

        if column.table is None or not column.table.named_with_column:
            return name
        else:
            if column.table.schema:
                schema_prefix = self.preparer.quote_schema(column.table.schema, column.table.quote_schema) + '.'
            else:
                schema_prefix = ''
            tablename = column.table.name
            tablename = isinstance(tablename, sql._generated_label) and \
                            self._truncated_identifier("alias", tablename) or tablename

            return schema_prefix + self.preparer.quote(tablename, column.table.quote) + "." + name

    def escape_literal_column(self, text):
        """provide escaping for the literal_column() construct.

        Doubles every '%' in ``text`` (presumably because the final SQL is
        %-formatted by percent-style paramstyles).
        """

        # TODO: some dialects might need different behavior here
        return text.replace('%', '%%')

    def visit_fromclause(self, fromclause, **kwargs):
        """Render a generic FROM element by its name."""
        return fromclause.name

    def visit_index(self, index, **kwargs):
        """Render an index by its name."""
        return index.name

    def visit_typeclause(self, typeclause, **kwargs):
        """Render the dialect-specific column type spec of ``typeclause.type``."""
        return typeclause.type.dialect_impl(self.dialect).get_col_spec()

    def post_process_text(self, text):
        """Hook for dialects to rewrite text() SQL before binds are parsed."""
        return text

    def visit_textclause(self, textclause, **kwargs):
        """Render a text() clause.

        Any ``typemap`` entries are recorded in ``result_map``.  ``:name``
        tokens become bind parameters (using the clause's own bindparams when
        one of that name exists), and backslash-escaped ``:name`` tokens are
        then emitted without the backslash.
        """
        if textclause.typemap is not None:
            for colname, type_ in textclause.typemap.iteritems():
                self.result_map[colname.lower()] = (colname, None, type_)

        def do_bindparam(m):
            name = m.group(1)
            if name in textclause.bindparams:
                return self.process(textclause.bindparams[name])
            else:
                return self.bindparam_string(name)

        # un-escape any \:params
        return BIND_PARAMS_ESC.sub(lambda m: m.group(1),
            BIND_PARAMS.sub(do_bindparam, self.post_process_text(textclause.text))
        )

    def visit_null(self, null, **kwargs):
        """Render the SQL NULL keyword."""
        return 'NULL'

    def visit_clauselist(self, clauselist, **kwargs):
        """Join the rendered clauses with the list's operator.

        A space is used when the operator is None and ', ' for comma_op;
        clauses that render as None are skipped.
        """
        sep = clauselist.operator
        if sep is None:
            sep = " "
        elif sep is operators.comma_op:
            sep = ', '
        else:
            sep = " " + self.operator_string(clauselist.operator) + " "
        return sep.join(s for s in (self.process(c) for c in clauselist.clauses)
                        if s is not None)

    def visit_case(self, clause, **kwargs):
        """Render a CASE expression.

        ``value`` and ``else_`` are compared against None rather than tested
        for truthiness, so a clause element that happens to define its own
        truth value can never be silently dropped.
        """
        x = "CASE "
        if clause.value is not None:
            x += self.process(clause.value) + " "
        for cond, result in clause.whens:
            x += "WHEN " + self.process(cond) + " THEN " + self.process(result) + " "
        if clause.else_ is not None:
            x += "ELSE " + self.process(clause.else_) + " "
        x += "END"
        return x

    def visit_cast(self, cast, **kwargs):
        """Render CAST(<clause> AS <type>)."""
        return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))

    def visit_extract(self, extract, **kwargs):
        """Render EXTRACT(<field> FROM <expr>), translating the field via extract_map."""
        field = self.extract_map.get(extract.field, extract.field)
        return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr))

    def visit_function(self, func, result_map=None, **kwargs):
        """Render a function call and record it in ``result_map`` when given.

        Callable templates receive the rendered arguments; string templates
        are prefixed with the function's package names and %-formatted with
        the rendered argument list as ``expr``.
        """
        if result_map is not None:
            result_map[func.name.lower()] = (func.name, None, func.type)

        name = self.function_string(func)

        if util.callable(name):
            return name(*[self.process(x) for x in func.clauses])
        else:
            return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)}

    def function_argspec(self, func, **kwargs):
        """Render the argument list of ``func``."""
        return self.process(func.clause_expr, **kwargs)

    def function_string(self, func):
        """Return the template for ``func``: looked up by class, then by name,
        defaulting to ``<name>%(expr)s``."""
        return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s"))

    def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
        """Render a compound select of ``cs.selects`` joined by ``cs.keyword``
        (e.g. UNION).

        The member selects are rendered without their own parentheses, then
        the compound's GROUP BY, ORDER BY and LIMIT/OFFSET follow.  The result
        is parenthesized when rendered as a FROM element with ``parens``.
        """
        entry = self.stack and self.stack[-1] or {}
        # iswrapper=True makes visit_select record the member selects'
        # columns in result_map
        self.stack.append({'from':entry.get('from', None), 'iswrapper':True})

        keyword = " " + cs.keyword + " "
        text = keyword.join(
            self.process(c, asfrom=asfrom, parens=False, compound_index=i)
            for i, c in enumerate(cs.selects))

        group_by = self.process(cs._group_by_clause, asfrom=asfrom)
        if group_by:
            text += " GROUP BY " + group_by

        text += self.order_by_clause(cs)
        if cs._limit is not None or cs._offset is not None:
            text += self.limit_clause(cs)

        self.stack.pop(-1)
        if asfrom and parens:
            return "(" + text + ")"
        else:
            return text

    def visit_unary(self, unary, **kw):
        """Render a unary expression: ``operator`` goes before the element and
        ``modifier`` after it."""
        s = self.process(unary.element, **kw)
        if unary.operator:
            s = self.operator_string(unary.operator) + " " + s
        if unary.modifier:
            s = s + " " + self.operator_string(unary.modifier)
        return s

    def visit_binary(self, binary, **kwargs):
        """Render ``<left> <op> <right>``, or call a callable operator renderer
        with both rendered operands and the expression's modifiers."""
        op = self.operator_string(binary.operator)
        if util.callable(op):
            return op(self.process(binary.left), self.process(binary.right), **binary.modifiers)
        else:
            return self.process(binary.left) + " " + op + " " + self.process(binary.right)

    def operator_string(self, operator):
        """Return the SQL for ``operator``, falling back to ``str(operator)``."""
        return self.operators.get(operator, str(operator))

    def visit_bindparam(self, bindparam, **kwargs):
        """Register ``bindparam`` and return its placeholder.

        Raises CompileError when a different bind parameter with the same
        compiled name is already registered and either of them is unique.
        """
        name = self._truncate_bindparam(bindparam)
        if name in self.binds:
            existing = self.binds[name]
            if existing is not bindparam and (existing.unique or bindparam.unique):
                raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
        self.binds[bindparam.key] = self.binds[name] = bindparam
        return self.bindparam_string(name)

    def _truncate_bindparam(self, bindparam):
        """Return the compiled name of ``bindparam``, memoized in bind_names;
        generated names are truncated to label_length."""
        if bindparam in self.bind_names:
            return self.bind_names[bindparam]

        bind_name = bindparam.key
        bind_name = isinstance(bind_name, sql._generated_label) and \
                        self._truncated_identifier("bindparam", bind_name) or bind_name
        # add to bind_names for translation
        self.bind_names[bindparam] = bind_name

        return bind_name

    def _truncated_identifier(self, ident_class, name):
        """Resolve anonymous tokens in ``name`` and shorten it to label_length.

        Results are memoized per ``(ident_class, name)``.  A name that is too
        long keeps a prefix and gets ``_<hex counter>`` appended, the counter
        being kept per ``ident_class`` so truncated names stay distinct.
        """
        if (ident_class, name) in self.truncated_names:
            return self.truncated_names[(ident_class, name)]

        anonname = name % self.anon_map

        if len(anonname) > self.label_length:
            counter = self.truncated_names.get(ident_class, 1)
            truncname = anonname[0:max(self.label_length - 6, 0)] + "_" + hex(counter)[2:]
            self.truncated_names[ident_class] = counter + 1
        else:
            truncname = anonname
        self.truncated_names[(ident_class, name)] = truncname
        return truncname

    def _anonymize(self, name):
        """Resolve anonymous-name tokens in ``name`` through ``anon_map``."""
        return name % self.anon_map

    def _process_anon(self, key):
        """Produce the concrete name for an anonymous ``key``.

        ``key`` is split at its first space; the part after it is the base
        name, which gets a per-base counter suffix (``base_1``, ``base_2``,
        ...).  Presumably invoked by ``anon_map`` for keys it does not yet
        hold -- confirm against util.PopulateDict.
        """
        (ident, derived) = key.split(' ', 1)
        anonymous_counter = self.anon_map.get(derived, 1)
        self.anon_map[derived] = anonymous_counter + 1
        return derived + "_" + str(anonymous_counter)

    def bindparam_string(self, name):
        """Return the placeholder text for bind ``name`` in this paramstyle.

        For positional paramstyles the name is also appended to positiontup,
        whose length supplies the ``position`` value.
        """
        if self.positional:
            self.positiontup.append(name)
            return self.bindtemplate % {'name':name, 'position':len(self.positiontup)}
        else:
            return self.bindtemplate % {'name':name}

    def visit_alias(self, alias, asfrom=False, **kwargs):
        """Render ``<original> AS <alias name>`` as a FROM element; elsewhere
        render just the aliased original."""
        if asfrom:
            alias_name = isinstance(alias.name, sql._generated_label) and \
                            self._truncated_identifier("alias", alias.name) or alias.name

            return self.process(alias.original, asfrom=True, **kwargs) + " AS " + \
                    self.preparer.format_alias(alias, alias_name)
        else:
            return self.process(alias.original, **kwargs)

    def label_select_column(self, select, column, asfrom):
        """label columns present in a select().

        Returns ``column`` itself when no label is needed, or a
        _CompileLabel wrapping it:

        * existing labels are returned unchanged;
        * with ``select.use_labels``, the column's ``_label`` is used;
        * table-bound, non-literal columns of a select used as a FROM
          element are labeled with their own name;
        * unnamed expressions and functions get an anonymous label.
        """

        if isinstance(column, sql._Label):
            return column

        if select.use_labels and column._label:
            return _CompileLabel(column, column._label)

        if \
            asfrom and \
            isinstance(column, sql.ColumnClause) and \
            not column.is_literal and \
            column.table is not None and \
            not isinstance(column.table, sql.Select):
            return _CompileLabel(column, sql._generated_label(column.name))
        elif not isinstance(column, (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \
                and (not hasattr(column, 'name') or isinstance(column, sql.Function)):
            return _CompileLabel(column, column.anon_label)
        else:
            return column

    def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, compound_index=1, **kwargs):
        """Render a SELECT statement.

        asfrom
          the select is rendered as a FROM element; with ``parens`` the
          result is wrapped in parentheses.

        iswrapper
          stored on this select's stack entry; selects nested directly in a
          wrapper record their columns in ``result_map``.

        compound_index
          position of this select within an enclosing compound select.
        """

        entry = self.stack and self.stack[-1] or {}

        existingfroms = entry.get('from', None)

        froms = select._get_display_froms(existingfroms)

        correlate_froms = set(sql._from_objects(*froms))

        # TODO: might want to propagate existing froms for select(select(select))
        # where innermost select should correlate to outermost
        # if existingfroms:
        #     correlate_froms = correlate_froms.union(existingfroms)

        self.stack.append({'from':correlate_froms, 'iswrapper':iswrapper})

        # columns go into result_map only for the top-level select or for a
        # select sitting directly inside a wrapper (e.g. a compound select)
        if (compound_index == 1 and not entry) or entry.get('iswrapper', False):
            column_clause_args = {'result_map':self.result_map}
        else:
            column_clause_args = {}

        # the actual list of columns to print in the SELECT column list.
        inner_columns = [
            c for c in [
                self.process(
                    self.label_select_column(select, co, asfrom=asfrom),
                    within_columns_clause=True,
                    **column_clause_args)
                for co in util.unique_list(select.inner_columns)
            ]
            if c is not None
        ]

        text = "SELECT "  # we're off to a good start !
        if select._prefixes:
            text += " ".join(self.process(x) for x in select._prefixes) + " "
        text += self.get_select_precolumns(select)
        text += ', '.join(inner_columns)

        if froms:
            text += " \nFROM "
            text += ', '.join(self.process(f, asfrom=True) for f in froms)
        else:
            text += self.default_from()

        if select._whereclause is not None:
            t = self.process(select._whereclause)
            if t:
                text += " \nWHERE " + t

        if select._group_by_clause.clauses:
            group_by = self.process(select._group_by_clause)
            if group_by:
                text += " GROUP BY " + group_by

        if select._having is not None:
            t = self.process(select._having)
            if t:
                text += " \nHAVING " + t

        if select._order_by_clause.clauses:
            text += self.order_by_clause(select)
        if select._limit is not None or select._offset is not None:
            text += self.limit_clause(select)
        if select.for_update:
            text += self.for_update_clause(select)

        self.stack.pop(-1)

        if asfrom and parens:
            return "(" + text + ")"
        else:
            return text

    def get_select_precolumns(self, select):
        """Called when building a ``SELECT`` statement, position is just before column list."""

        return select._distinct and "DISTINCT " or ""

    def order_by_clause(self, select):
        """Return " ORDER BY ..." for ``select``, or "" when it renders empty."""
        order_by = self.process(select._order_by_clause)
        if order_by:
            return " ORDER BY " + order_by
        else:
            return ""

    def for_update_clause(self, select):
        """Return " FOR UPDATE" when ``select.for_update`` is set, else ""."""
        if select.for_update:
            return " FOR UPDATE"
        else:
            return ""

    def limit_clause(self, select):
        """Return the LIMIT/OFFSET suffix for ``select``.

        When only an offset is present, "LIMIT -1" is emitted ahead of the
        OFFSET as a stand-in for "no limit"; dialects may override this.
        """
        text = ""
        if select._limit is not None:
            text +=  " \n LIMIT " + str(select._limit)
        if select._offset is not None:
            if select._limit is None:
                text += " \n LIMIT -1"
            text += " OFFSET " + str(select._offset)
        return text

    def visit_table(self, table, asfrom=False, **kwargs):
        """Render a (schema-qualified) table name as a FROM element; "" elsewhere."""
        if asfrom:
            if getattr(table, "schema", None):
                return self.preparer.quote_schema(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote)
            else:
                return self.preparer.quote(table.name, table.quote)
        else:
            return ""

    def visit_join(self, join, asfrom=False, **kwargs):
        """Render ``<left> [LEFT OUTER] JOIN <right> ON <onclause>``."""
        return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
            self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))

    def visit_sequence(self, seq, **kwargs):
        """Render a sequence; the default compiler renders nothing (None),
        which _get_colparams treats as "no inline value"."""
        return None

    def visit_insert(self, insert_stmt):
        """Render an INSERT statement.

        Raises CompileError when there are no column values and the dialect
        supports neither DEFAULT VALUES nor empty inserts.
        """
        # must be set before _get_colparams, which branches on it
        self.isinsert = True
        colparams = self._get_colparams(insert_stmt)
        preparer = self.preparer

        insert = ' '.join(["INSERT"] +
                          [self.process(x) for x in insert_stmt._prefixes])

        if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert:
            raise exc.CompileError(
                "The version of %s you are using does not support empty inserts." % self.dialect.name)
        elif not colparams and self.dialect.supports_default_values:
            return (insert + " INTO %s DEFAULT VALUES" % (
                (preparer.format_table(insert_stmt.table),)))
        else:
            return (insert + " INTO %s (%s) VALUES (%s)" %
                (preparer.format_table(insert_stmt.table),
                 ', '.join([preparer.format_column(c[0])
                           for c in colparams]),
                 ', '.join([c[1] for c in colparams])))

    def visit_update(self, update_stmt):
        """Render an UPDATE ... SET ... [WHERE ...] statement."""
        # the target table is the correlation source for subqueries
        self.stack.append({'from': set([update_stmt.table])})

        # must be set before _get_colparams, which branches on it
        self.isupdate = True
        colparams = self._get_colparams(update_stmt)

        text = ' '.join((
            "UPDATE",
            self.preparer.format_table(update_stmt.table),
            'SET',
            ', '.join(self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
                      for c in colparams)
            ))

        # compare against None as visit_select does: a clause element's
        # truthiness need not reflect whether a WHERE clause is present
        if update_stmt._whereclause is not None:
            t = self.process(update_stmt._whereclause)
            if t:
                text += " WHERE " + t

        self.stack.pop(-1)

        return text

    def _get_colparams(self, stmt):
        """create a set of tuples representing column/string pairs for use
        in an INSERT or UPDATE statement.

        Returns a list of ``(column, sql_text)`` pairs.  As a side effect it
        resets and fills ``self.postfetch`` (columns whose value is produced
        by a SQL expression, server default or sequence) and
        ``self.prefetch`` (columns given a placeholder bind parameter for a
        default value).  The names suggest these are fetched after /
        computed before execution -- confirm with the execution context.

        """

        def create_bind_param(col, value):
            bindparam = sql.bindparam(col.key, value, type_=col.type)
            self.binds[col.key] = bindparam
            return self.bindparam_string(self._truncate_bindparam(bindparam))

        self.postfetch = []
        self.prefetch = []

        # no parameters in the statement, no parameters in the
        # compiled params - return binds for all columns
        if self.column_keys is None and stmt.parameters is None:
            return [(c, create_bind_param(c, None)) for c in stmt.table.columns]

        # if we have statement parameters - set defaults in the
        # compiled params
        if self.column_keys is None:
            parameters = {}
        else:
            parameters = dict((sql._column_as_key(key), None)
                              for key in self.column_keys)

        # statement-level values fill in keys not already named by column_keys
        if stmt.parameters is not None:
            for k, v in stmt.parameters.iteritems():
                parameters.setdefault(sql._column_as_key(k), v)

        # create a list of column assignment clauses as tuples
        values = []
        for c in stmt.table.columns:
            if c.key in parameters:
                value = parameters[c.key]
                if sql._is_literal(value):
                    value = create_bind_param(c, value)
                else:
                    # a SQL expression: rendered inline, value known only
                    # on the database side
                    self.postfetch.append(c)
                    value = self.process(value.self_group())
                values.append((c, value))
            elif isinstance(c, schema.Column):
                if self.isinsert:
                    if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline):
                        if (((isinstance(c.default, schema.Sequence) and
                              not c.default.optional) or
                             not self.dialect.supports_pk_autoincrement) or
                            (c.default is not None and
                             not isinstance(c.default, schema.Sequence))):
                            values.append((c, create_bind_param(c, None)))
                            self.prefetch.append(c)
                    elif isinstance(c.default, schema.ColumnDefault):
                        if isinstance(c.default.arg, sql.ClauseElement):
                            values.append((c, self.process(c.default.arg.self_group())))
                            if not c.primary_key:
                                # dont add primary key column to postfetch
                                self.postfetch.append(c)
                        else:
                            values.append((c, create_bind_param(c, None)))
                            self.prefetch.append(c)
                    elif c.server_default is not None:
                        if not c.primary_key:
                            self.postfetch.append(c)
                    elif isinstance(c.default, schema.Sequence):
                        proc = self.process(c.default)
                        if proc is not None:
                            values.append((c, proc))
                            if not c.primary_key:
                                self.postfetch.append(c)
                elif self.isupdate:
                    if isinstance(c.onupdate, schema.ColumnDefault):
                        if isinstance(c.onupdate.arg, sql.ClauseElement):
                            values.append((c, self.process(c.onupdate.arg.self_group())))
                            self.postfetch.append(c)
                        else:
                            values.append((c, create_bind_param(c, None)))
                            self.prefetch.append(c)
                    elif c.server_onupdate is not None:
                        self.postfetch.append(c)
                    # deprecated? or remove?
                    elif isinstance(c.onupdate, schema.FetchedValue):
                        self.postfetch.append(c)
        return values

    def visit_delete(self, delete_stmt):
        """Render a DELETE FROM ... [WHERE ...] statement."""
        # the target table is the correlation source for subqueries
        self.stack.append({'from': set([delete_stmt.table])})
        self.isdelete = True

        text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)

        # compare against None as visit_select does: a clause element's
        # truthiness need not reflect whether a WHERE clause is present
        if delete_stmt._whereclause is not None:
            t = self.process(delete_stmt._whereclause)
            if t:
                text += " WHERE " + t

        self.stack.pop(-1)

        return text

    def visit_savepoint(self, savepoint_stmt):
        """Render SAVEPOINT <name>."""
        return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)

    def visit_rollback_to_savepoint(self, savepoint_stmt):
        """Render ROLLBACK TO SAVEPOINT <name>."""
        return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)

    def visit_release_savepoint(self, savepoint_stmt):
        """Render RELEASE SAVEPOINT <name>."""
        return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)

    def __str__(self):
        """Return the compiled SQL string, or '' if nothing was compiled."""
        return self.string or ''
780
class DDLBase(engine.SchemaIterator):
    """Shared helpers for the CREATE/DROP schema generators."""

    def find_alterables(self, tables):
        """Collect foreign key constraints that must be handled via ALTER TABLE.

        A constraint is collected when its ``use_alter`` flag is set and the
        table owning it is one of ``tables``.  Returns a list of constraints
        in traversal order.
        """
        found = []

        class _AlterableCollector(schema.SchemaVisitor):
            def visit_foreign_key_constraint(self, constraint):
                if not constraint.use_alter:
                    return
                if constraint.table in tables:
                    found.append(constraint)

        collector = _AlterableCollector()
        for tbl in tables:
            for cons in tbl.constraints:
                collector.traverse(cons)
        return found

    def _validate_identifier(self, ident, truncate):
        """Return a usable identifier for ``ident``.

        With ``truncate`` false, ``ident`` is passed to the dialect's
        ``validate_identifier`` and returned unchanged.  With ``truncate``
        true, an identifier longer than ``dialect.max_identifier_length`` is
        shortened to its first ``max_identifier_length - 6`` characters plus
        ``"_"`` and a hex suffix taken from a per-instance counter (stored as
        ``self.counter``); shorter identifiers are returned as-is.
        """
        if not truncate:
            self.dialect.validate_identifier(ident)
            return ident

        limit = self.dialect.max_identifier_length
        if len(ident) <= limit:
            return ident

        self.counter = getattr(self, 'counter', 0) + 1
        return ident[0:limit - 6] + "_" + hex(self.counter)[2:]
805
806
class SchemaGenerator(DDLBase):
    """Emit CREATE DDL for tables, their constraints and indexes.

    Foreign keys flagged ``use_alter`` are added afterwards via
    ``ALTER TABLE ... ADD`` when the dialect supports ALTER.  Dialect
    subclasses must implement :meth:`get_column_specification`; the other
    hooks (:meth:`post_create_table`, the ``visit_*`` methods) may be
    overridden.
    """

    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
        super(SchemaGenerator, self).__init__(connection, **kwargs)
        # when true, _can_create() skips tables the dialect reports as existing
        self.checkfirst = checkfirst
        # optional explicit set of tables to create; None means "every table
        # of the MetaData passed to visit_metadata()"
        self.tables = tables and set(tables) or None
        self.preparer = dialect.identifier_preparer
        self.dialect = dialect

    def get_column_specification(self, column, first_pk=False):
        """Return the DDL fragment declaring ``column`` inside CREATE TABLE.

        ``first_pk`` is true only for the first primary key column of the
        table.  Must be implemented by dialect subclasses.
        """
        raise NotImplementedError()

    def _can_create(self, table):
        """Validate ``table``'s name (and schema) and decide whether to create it.

        Returns true unless ``checkfirst`` is set and the dialect reports the
        table as already present.
        """
        self.dialect.validate_identifier(table.name)
        if table.schema:
            self.dialect.validate_identifier(table.schema)
        return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)

    def visit_metadata(self, metadata):
        """Create the selected tables in dependency order, then add any
        ``use_alter`` foreign keys via ALTER TABLE if the dialect supports it.
        """
        if self.tables:
            tables = self.tables
        else:
            tables = metadata.tables.values()
        collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)]
        for table in collection:
            self.traverse_single(table)
        if self.dialect.supports_alter:
            for alterable in self.find_alterables(collection):
                self.add_foreignkey(alterable)

    def visit_table(self, table):
        """Emit and execute CREATE TABLE for ``table``, then its indexes.

        'before-create' / 'after-create' DDL listeners are invoked around
        the statement, and column defaults (such as sequences) are traversed
        first so they exist before the table refers to them.
        """
        for listener in table.ddl_listeners['before-create']:
            listener('before-create', table, self.connection)

        for column in table.columns:
            if column.default is not None:
                self.traverse_single(column.default)

        self.append("\n" + " ".join(['CREATE'] +
                                    table._prefixes +
                                    ['TABLE',
                                     self.preparer.format_table(table),
                                     "("]))
        separator = "\n"

        # if only one primary key, specify it along with the column
        first_pk = False
        for column in table.columns:
            self.append(separator)
            separator = ", \n"
            self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk))
            if column.primary_key:
                first_pk = True
            for constraint in column.constraints:
                self.traverse_single(constraint)

        # On some DB order is significant: visit PK first, then the
        # other constraints (engine.ReflectionTest.testbasic failed on FB2)
        if table.primary_key:
            self.traverse_single(table.primary_key)
        for constraint in [c for c in table.constraints if c is not table.primary_key]:
            self.traverse_single(constraint)

        self.append("\n)%s\n\n" % self.post_create_table(table))
        self.execute()

        if hasattr(table, 'indexes'):
            for index in table.indexes:
                self.traverse_single(index)

        for listener in table.ddl_listeners['after-create']:
            listener('after-create', table, self.connection)

    def post_create_table(self, table):
        """Return text appended after the closing parenthesis of CREATE TABLE.

        The base implementation returns an empty string.
        """
        return ''

    def get_column_default_string(self, column):
        """Return the DDL text for ``column``'s server-side DEFAULT, or None.

        A plain string default is rendered as a SQL string literal with every
        embedded single quote doubled, so values such as ``O'Brien`` still
        produce well-formed DDL.  Non-string defaults are compiled with this
        generator's dialect.  Returns None when ``server_default`` is not a
        ``DefaultClause``.
        """
        if isinstance(column.server_default, schema.DefaultClause):
            arg = column.server_default.arg
            if isinstance(arg, basestring):
                # standard SQL escapes a quote inside a literal by doubling it
                return "'%s'" % arg.replace("'", "''")
            else:
                return unicode(self._compile(arg, None))
        else:
            return None

    def _compile(self, tocompile, parameters):
        """compile the given string/parameters using this SchemaGenerator's dialect."""
        compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
        compiler.compile()
        return compiler

    def visit_check_constraint(self, constraint):
        """Append a table-level CHECK constraint to the pending CREATE TABLE."""
        self.append(", \n\t")
        if constraint.name is not None:
            self.append("CONSTRAINT %s " %
                        self.preparer.format_constraint(constraint))
        self.append(" CHECK (%s)" % constraint.sqltext)
        self.define_constraint_deferrability(constraint)

    def visit_column_check_constraint(self, constraint):
        """Append a column-level CHECK constraint inline after the column."""
        self.append(" CHECK (%s)" % constraint.sqltext)
        self.define_constraint_deferrability(constraint)

    def visit_primary_key_constraint(self, constraint):
        """Append a PRIMARY KEY clause; nothing is emitted for an empty key."""
        if len(constraint) == 0:
            return
        self.append(", \n\t")
        if constraint.name is not None:
            self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
        self.append("PRIMARY KEY ")
        self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
                                       for c in constraint))
        self.define_constraint_deferrability(constraint)

    def visit_foreign_key_constraint(self, constraint):
        """Append an inline FOREIGN KEY clause.

        Skipped for ``use_alter`` constraints when the dialect supports
        ALTER; those are emitted later by :meth:`add_foreignkey`.
        """
        if constraint.use_alter and self.dialect.supports_alter:
            return
        self.append(", \n\t ")
        self.define_foreign_key(constraint)

    def add_foreignkey(self, constraint):
        """Emit and execute ``ALTER TABLE ... ADD`` for a foreign key."""
        self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table))
        self.define_foreign_key(constraint)
        self.execute()

    def define_foreign_key(self, constraint):
        """Append the body of a foreign key definition.

        The referenced table is taken from the first element's target column;
        ON DELETE / ON UPDATE and deferrability are added when set.
        """
        preparer = self.preparer
        if constraint.name is not None:
            self.append("CONSTRAINT %s " %
                        preparer.format_constraint(constraint))
        table = list(constraint.elements)[0].column.table
        self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
            ', '.join(preparer.quote(f.parent.name, f.parent.quote)
                      for f in constraint.elements),
            preparer.format_table(table),
            ', '.join(preparer.quote(f.column.name, f.column.quote)
                      for f in constraint.elements)
        ))
        if constraint.ondelete is not None:
            self.append(" ON DELETE %s" % constraint.ondelete)
        if constraint.onupdate is not None:
            self.append(" ON UPDATE %s" % constraint.onupdate)
        self.define_constraint_deferrability(constraint)

    def visit_unique_constraint(self, constraint):
        """Append a UNIQUE constraint over the constraint's columns."""
        self.append(", \n\t")
        if constraint.name is not None:
            self.append("CONSTRAINT %s " %
                        self.preparer.format_constraint(constraint))
        self.append(" UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)))
        self.define_constraint_deferrability(constraint)

    def define_constraint_deferrability(self, constraint):
        """Append [NOT] DEFERRABLE and INITIALLY clauses when they are set."""
        if constraint.deferrable is not None:
            if constraint.deferrable:
                self.append(" DEFERRABLE")
            else:
                self.append(" NOT DEFERRABLE")
        if constraint.initially is not None:
            self.append(" INITIALLY %s" % constraint.initially)

    def visit_column(self, column):
        """Columns are rendered by visit_table(); nothing to do here."""
        pass

    def visit_index(self, index):
        """Emit and execute CREATE [UNIQUE] INDEX.

        The index name is truncated via _validate_identifier() if it exceeds
        the dialect's identifier length limit.
        """
        preparer = self.preparer
        self.append("CREATE ")
        if index.unique:
            self.append("UNIQUE ")
        self.append("INDEX %s ON %s (%s)" \
                    % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
                       preparer.format_table(index.table),
                       ', '.join(preparer.quote(c.name, c.quote)
                                 for c in index.columns)))
        self.execute()
981
982
class SchemaDropper(DDLBase):
    """Emit DROP DDL for tables and indexes.

    Foreign keys flagged ``use_alter`` are dropped via ALTER TABLE first
    when the dialect supports ALTER.
    """

    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
        super(SchemaDropper, self).__init__(connection, **kwargs)
        # when true, _can_drop() only accepts tables the dialect reports present
        self.checkfirst = checkfirst
        # optional explicit collection of tables; falsy means "all in MetaData"
        self.tables = tables
        self.preparer = dialect.identifier_preparer
        self.dialect = dialect

    def visit_metadata(self, metadata):
        """Drop the selected tables in reverse dependency order.

        ``use_alter`` foreign keys among those tables are dropped first, if
        the dialect supports ALTER.
        """
        source = self.tables or metadata.tables.values()
        droppable = []
        for tbl in reversed(sql_util.sort_tables(source)):
            if self._can_drop(tbl):
                droppable.append(tbl)

        if self.dialect.supports_alter:
            for fk in self.find_alterables(droppable):
                self.drop_foreignkey(fk)

        for tbl in droppable:
            self.traverse_single(tbl)

    def _can_drop(self, table):
        """Validate ``table``'s name (and schema) and decide whether to drop it.

        Always true when ``checkfirst`` is off; otherwise the dialect's
        ``has_table`` result.
        """
        self.dialect.validate_identifier(table.name)
        if table.schema:
            self.dialect.validate_identifier(table.schema)
        if not self.checkfirst:
            return True
        return self.dialect.has_table(self.connection, table.name, schema=table.schema)

    def visit_index(self, index):
        """Emit and execute DROP INDEX (the name is validated, not truncated)."""
        index_name = self._validate_identifier(index.name, False)
        self.append("\nDROP INDEX " + self.preparer.quote(index_name, index.quote))
        self.execute()

    def drop_foreignkey(self, constraint):
        """Emit and execute ``ALTER TABLE ... DROP CONSTRAINT`` for a foreign key."""
        table_sql = self.preparer.format_table(constraint.table)
        constraint_sql = self.preparer.format_constraint(constraint)
        self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (table_sql, constraint_sql))
        self.execute()

    def visit_table(self, table):
        """Emit and execute DROP TABLE, wrapped by the drop DDL listeners.

        Column defaults are traversed before the table itself is dropped.
        """
        for hook in table.ddl_listeners['before-drop']:
            hook('before-drop', table, self.connection)

        for col in table.columns:
            col_default = col.default
            if col_default is None:
                continue
            self.traverse_single(col_default)

        self.append("\nDROP TABLE " + self.preparer.format_table(table))
        self.execute()

        for hook in table.ddl_listeners['after-drop']:
            hook('after-drop', table, self.connection)
1032
1033
class IdentifierPreparer(object):
    """Handle quoting and case-folding of identifiers based on options."""

    # identifiers whose lower-cased form is in this set are always quoted
    reserved_words = RESERVED_WORDS

    # pattern that must match an identifier for it to stay unquoted
    legal_characters = LEGAL_CHARACTERS

    # pattern tested against the first character; a match forces quoting
    illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS

    def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False):
        """Construct a new ``IdentifierPreparer`` object.

        initial_quote
          Character that begins a delimited identifier.

        final_quote
          Character that ends a delimited identifier. Defaults to `initial_quote`.

        omit_schema
          Prevent prepending schema name. Useful for databases that do
          not support schemae.
        """
        self.dialect = dialect
        self.initial_quote = initial_quote
        if final_quote:
            self.final_quote = final_quote
        else:
            self.final_quote = initial_quote
        self.omit_schema = omit_schema
        # memo of quote() results for identifiers quoted with force=None
        self._strings = {}

    def _escape_identifier(self, value):
        """Escape an identifier by doubling every double quote.

        Subclasses should override this to provide database-dependent
        escaping behavior.
        """
        return value.replace('"', '""')

    def _unescape_identifier(self, value):
        """Canonicalize an escaped identifier (doubled double quotes -> one).

        Subclasses should override this to provide database-dependent
        unescaping behavior that reverses _escape_identifier.
        """
        return value.replace('""', '"')

    def quote_identifier(self, value):
        """Unconditionally wrap ``value`` in the quote characters, escaped.

        Subclasses should override this to provide database-dependent
        quoting behavior.
        """
        escaped = self._escape_identifier(value)
        return self.initial_quote + escaped + self.final_quote

    def _requires_quotes(self, value):
        """Return a true value if the given identifier requires quoting.

        Quoting is needed for reserved words, an illegal first character,
        any character outside ``legal_characters``, or any upper-case
        letters (the identifier differs from its lower-cased form).
        """
        lowered = value.lower()
        if lowered in self.reserved_words:
            return True
        bad_initial = self.illegal_initial_characters.match(value[0])
        if bad_initial:
            return bad_initial
        if not self.legal_characters.match(unicode(value)):
            return True
        return lowered != value

    def quote_schema(self, schema, force):
        """Quote a schema name; by default identical to :meth:`quote`.

        Subclasses should override this to provide database-dependent
        quoting behavior.
        """
        return self.quote(schema, force)

    def quote(self, ident, force):
        """Return ``ident``, quoted according to ``force``.

        ``force`` true always quotes, ``force`` false never quotes, and
        ``force`` None quotes only when :meth:`_requires_quotes` says so,
        memoizing the outcome per identifier in ``self._strings``.
        """
        if force is not None:
            if force:
                return self.quote_identifier(ident)
            return ident

        cache = self._strings
        if ident not in cache:
            if self._requires_quotes(ident):
                cache[ident] = self.quote_identifier(ident)
            else:
                cache[ident] = ident
        return cache[ident]

    def format_sequence(self, sequence, use_schema=True):
        """Return the sequence name, schema-qualified when applicable."""
        quoted = self.quote(sequence.name, sequence.quote)
        if self.omit_schema or not use_schema or sequence.schema is None:
            return quoted
        return self.quote_schema(sequence.schema, sequence.quote) + "." + quoted

    def format_label(self, label, name=None):
        """Quote ``name``, falling back to the label's own name."""
        if not name:
            name = label.name
        return self.quote(name, label.quote)

    def format_alias(self, alias, name=None):
        """Quote ``name``, falling back to the alias's own name."""
        if not name:
            name = alias.name
        return self.quote(name, alias.quote)

    def format_savepoint(self, savepoint, name=None):
        """Quote ``name``, falling back to the savepoint's ``ident``."""
        if not name:
            name = savepoint.ident
        return self.quote(name, savepoint.quote)

    def format_constraint(self, constraint):
        """Quote the constraint's name."""
        return self.quote(constraint.name, constraint.quote)

    def format_table(self, table, use_schema=True, name=None):
        """Prepare a quoted table and schema name.

        ``name`` overrides ``table.name``.  The schema prefix is added only
        when schemas are not omitted, ``use_schema`` is true and the table
        has a truthy ``schema`` attribute.
        """
        if name is None:
            name = table.name
        quoted = self.quote(name, table.quote)
        if self.omit_schema or not use_schema or not getattr(table, "schema", None):
            return quoted
        return self.quote_schema(table.schema, table.quote_schema) + "." + quoted

    def format_column(self, column, use_table=False, name=None, table_name=None):
        """Prepare a quoted column name, optionally prefixed by its table.

        Columns flagged ``is_literal`` hold literal text, which is emitted
        verbatim instead of being quoted.
        """
        if name is None:
            name = column.name
        should_quote = not getattr(column, 'is_literal', False)

        if not use_table:
            if should_quote:
                return self.quote(name, column.quote)
            return name

        table_part = self.format_table(column.table, use_schema=False, name=table_name)
        if should_quote:
            return table_part + "." + self.quote(name, column.quote)
        return table_part + "." + name

    def format_table_seq(self, table, use_schema=True):
        """Format table name and schema as a tuple.

        Dialects with more levels in their fully qualified references
        ('database', 'owner', etc.) could override this and return a
        longer sequence.
        """
        if self.omit_schema or not use_schema or not getattr(table, 'schema', None):
            return (self.format_table(table, use_schema=False), )
        return (self.quote_schema(table.schema, table.quote_schema),
                self.format_table(table, use_schema=False))

    def unformat_identifiers(self, identifiers):
        """Unpack 'schema.table.column'-like strings into components.

        Each dot-separated segment may be delimited by this preparer's quote
        characters (with escaped final quotes inside); segments are returned
        unescaped.  The compiled pattern is cached as ``_r_identifiers``.
        """
        try:
            matcher = self._r_identifiers
        except AttributeError:
            escaped = [re.escape(s) for s in
                       (self.initial_quote, self.final_quote,
                        self._escape_identifier(self.final_quote))]
            matcher = re.compile(
                r'(?:'
                r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
                r'|([^\.]+))(?=\.|$))+' %
                { 'initial': escaped[0],
                  'final': escaped[1],
                  'escaped': escaped[2] })
            self._r_identifiers = matcher

        parts = []
        for quoted_part, bare_part in matcher.findall(identifiers):
            parts.append(self._unescape_identifier(quoted_part or bare_part))
        return parts
Note: リポジトリブラウザについてのヘルプは TracBrowser を参照してください。