1 | # compiler.py |
---|
2 | # Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com |
---|
3 | # |
---|
4 | # This module is part of SQLAlchemy and is released under |
---|
5 | # the MIT License: http://www.opensource.org/licenses/mit-license.php |
---|
6 | |
---|
7 | """Base SQL and DDL compiler implementations. |
---|
8 | |
---|
9 | Provides the :class:`~sqlalchemy.sql.compiler.DefaultCompiler` class, which is |
---|
10 | responsible for generating all SQL query strings, as well as |
---|
11 | :class:`~sqlalchemy.sql.compiler.SchemaGenerator` and :class:`~sqlalchemy.sql.compiler.SchemaDropper` |
---|
12 | which issue CREATE and DROP DDL for tables, sequences, and indexes. |
---|
13 | |
---|
14 | The elements in this module are used by public-facing constructs like |
---|
15 | :class:`~sqlalchemy.sql.expression.ClauseElement` and :class:`~sqlalchemy.engine.Engine`. |
---|
16 | While dialect authors will want to be familiar with this module for the purpose of |
---|
17 | creating database-specific compilers and schema generators, the module |
---|
18 | is otherwise internal to SQLAlchemy. |
---|
19 | """ |
---|
20 | |
---|
21 | import string, re |
---|
22 | from sqlalchemy import schema, engine, util, exc |
---|
23 | from sqlalchemy.sql import operators, functions, util as sql_util, visitors |
---|
24 | from sqlalchemy.sql import expression as sql |
---|
25 | |
---|
# Lower-case words treated as reserved.  Kept as one whitespace-separated
# string so the list is easy to scan; split() yields the individual words.
RESERVED_WORDS = set("""
    all analyse analyze and any array
    as asc asymmetric authorization between
    binary both case cast check collate
    column constraint create cross current_date
    current_role current_time current_timestamp
    current_user default deferrable desc
    distinct do else end except false
    for foreign freeze from full grant
    group having ilike in initially inner
    intersect into is isnull join leading
    left like limit localtime localtimestamp
    natural new not notnull null off offset
    old on only or order outer overlaps
    placing primary references right select
    session_user set similar some symmetric table
    then to trailing true union unique user
    using verbose when where
""".split())

# An identifier made up solely of these characters (case-insensitive).
LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.IGNORECASE)
# Characters that may not begin an identifier.
ILLEGAL_INITIAL_CHARACTERS = re.compile(r'[0-9$]')

# ":name" bind parameters in textual SQL.  The lookbehind rejects "::" casts,
# word-embedded colons and backslash-escaped "\:name" occurrences.
BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE)
# Backslash-escaped "\:name" sequences; group 1 is the text without the backslash.
BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]+)(?![:\w\$])', re.UNICODE)

# DBAPI paramstyle -> bind placeholder template.  Templates are %-formatted
# with a mapping providing 'name' and, for positional styles, 'position'.
BIND_TEMPLATES = dict(
    pyformat="%%(%(name)s)s",
    qmark="?",
    format="%%s",
    numeric=":%(position)s",
    named=":%(name)s",
)
---|
58 | |
---|
59 | |
---|
# Operator function -> SQL rendering.  A string is placed between the two
# rendered operands of a binary expression (or before/after the operand of a
# unary one); a callable receives the rendered operands plus any modifiers
# (here: ``escape``) and returns the complete expression.
OPERATORS = {
    # boolean connectives
    operators.and_: 'AND',
    operators.or_: 'OR',
    operators.inv: 'NOT',

    # arithmetic
    operators.add: '+',
    operators.mul: '*',
    operators.sub: '-',
    operators.div: '/',
    operators.mod: '%',
    operators.truediv: '/',

    # comparison
    operators.lt: '<',
    operators.le: '<=',
    operators.ne: '!=',
    operators.gt: '>',
    operators.ge: '>=',
    operators.eq: '=',

    operators.distinct_op: 'DISTINCT',
    operators.concat_op: '||',

    # pattern matching; a truthy ``escape`` appends an ESCAPE clause
    operators.like_op:
        lambda x, y, escape=None:
            '%s LIKE %s' % (x, y) + (escape and " ESCAPE '%s'" % escape or ''),
    operators.notlike_op:
        lambda x, y, escape=None:
            '%s NOT LIKE %s' % (x, y) + (escape and " ESCAPE '%s'" % escape or ''),
    operators.ilike_op:
        lambda x, y, escape=None:
            "lower(%s) LIKE lower(%s)" % (x, y) + (escape and " ESCAPE '%s'" % escape or ''),
    operators.notilike_op:
        lambda x, y, escape=None:
            "lower(%s) NOT LIKE lower(%s)" % (x, y) + (escape and " ESCAPE '%s'" % escape or ''),

    # range / set membership / full text
    operators.between_op: 'BETWEEN',
    operators.match_op: 'MATCH',
    operators.in_op: 'IN',
    operators.notin_op: 'NOT IN',

    # list separators, ordering and other keywords
    operators.comma_op: ', ',
    operators.desc_op: 'DESC',
    operators.asc_op: 'ASC',
    operators.from_: 'FROM',
    operators.as_: 'AS',
    operators.exists: 'EXISTS',
    operators.is_: 'IS',
    operators.isnot: 'IS NOT',
    operators.collate: 'COLLATE',
}
---|
96 | |
---|
# Function class -> rendering template.  "%(expr)s" is substituted with the
# function's rendered argument spec; templates without it render as-is
# (e.g. the ANSI niladic keywords below take no argument list).
FUNCTIONS = {
    # functions that take arguments
    functions.coalesce: 'coalesce%(expr)s',
    functions.random: 'random%(expr)s',

    # keyword-style functions rendered without an argument list
    functions.current_date: 'CURRENT_DATE',
    functions.current_time: 'CURRENT_TIME',
    functions.current_timestamp: 'CURRENT_TIMESTAMP',
    functions.current_user: 'CURRENT_USER',
    functions.localtime: 'LOCALTIME',
    functions.localtimestamp: 'LOCALTIMESTAMP',
    functions.sysdate: 'sysdate',
    functions.session_user: 'SESSION_USER',
    functions.user: 'USER',
}
---|
110 | |
---|
# EXTRACT field name -> name rendered in SQL.  The default mapping is the
# identity; dialect compilers may substitute their own ``extract_map``.
EXTRACT_MAP = dict(
    (field, field) for field in (
        'month', 'day', 'year', 'second', 'hour', 'doy', 'minute',
        'quarter', 'dow', 'week', 'epoch', 'milliseconds', 'microseconds',
        'timezone_hour', 'timezone_minute',
    )
)
---|
128 | |
---|
class _CompileLabel(visitors.Visitable):
    """Lightweight stand-in for ``expression._Label`` used while compiling.

    Pairs a column expression (``element``) with the name it should be
    rendered under (``name``).  Its ``__visit_name__`` is ``'label'``, so it
    is presumably dispatched to the compiler's ``visit_label``.
    """

    __visit_name__ = 'label'
    __slots__ = ('element', 'name')

    def __init__(self, col, name):
        self.element, self.name = col, name

    # quoting behavior is taken from the wrapped column expression
    quote = property(lambda self: self.element.quote)
---|
142 | |
---|
class DefaultCompiler(engine.Compiled):
    """Default implementation of Compiled.

    Compiles ClauseElements into SQL strings.  Uses a similar visit
    paradigm as visitors.ClauseVisitor but implements its own traversal:
    :meth:`process` hands an element to its ``_compiler_dispatch`` hook,
    which invokes the matching ``visit_*`` method; each visit method
    returns the SQL text for its construct, recursing into children via
    :meth:`process`.

    """

    operators = OPERATORS
    functions = FUNCTIONS
    extract_map = EXTRACT_MAP

    # if we are insert/update/delete.
    # set to true when we visit an INSERT, UPDATE or DELETE
    isdelete = isinsert = isupdate = False

    def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
        """Construct a new ``DefaultCompiler`` object.

        dialect
          Dialect to be used

        statement
          ClauseElement to be compiled

        column_keys
          a list of column names to be compiled into an INSERT or UPDATE
          statement.

        inline
          if true, compile INSERT/UPDATE defaults and sequences inline
          instead of pre-executing them; a true ``inline`` attribute on
          ``statement`` has the same effect.

        Any other keyword arguments are passed to ``engine.Compiled.__init__``.

        """
        engine.Compiled.__init__(self, dialect, statement, column_keys, **kwargs)

        # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute)
        self.inline = inline or getattr(statement, 'inline', False)

        # a dictionary of bind parameter keys to _BindParamClause instances.
        self.binds = {}

        # a dictionary of _BindParamClause instances to "compiled" names that are
        # actually present in the generated SQL
        self.bind_names = util.column_dict()

        # stack which keeps track of nested SELECT statements
        self.stack = []

        # relates label names in the final SQL to
        # a tuple of local column/label name, ColumnElement object (if any) and TypeEngine.
        # ResultProxy uses this for type processing and column targeting
        self.result_map = {}

        # true if the paramstyle is positional
        self.positional = self.dialect.positional
        if self.positional:
            # bind names, in the order their placeholders were emitted
            self.positiontup = []

        self.bindtemplate = BIND_TEMPLATES[self.dialect.paramstyle]

        # an IdentifierPreparer that formats the quoting of identifiers
        self.preparer = self.dialect.identifier_preparer

        # longest generated label/identifier before truncation kicks in
        self.label_length = self.dialect.label_length or self.dialect.max_identifier_length

        # a map which tracks "anonymous" identifiers that are
        # created on the fly here
        self.anon_map = util.PopulateDict(self._process_anon)

        # a map which tracks "truncated" names based on dialect.label_length
        # or dialect.max_identifier_length
        self.truncated_names = {}

    def compile(self):
        """Compile ``self.statement`` and store the SQL text in ``self.string``."""
        self.string = self.process(self.statement)

    def process(self, obj, **kwargs):
        """Return the SQL text for ``obj`` via its ``_compiler_dispatch`` hook."""
        return obj._compiler_dispatch(self, **kwargs)

    def is_subquery(self):
        """Return True when more than one statement is on the compile stack."""
        return len(self.stack) > 1

    def construct_params(self, params=None):
        """Return a dictionary of bind parameter keys and values.

        Keys of the result are the compiled bind names.  When ``params`` is
        given, each bind parameter's value is looked up in it by the
        parameter's key, then its shortname, then its compiled name; if none
        is present (or ``params`` is empty/None) the parameter's own
        ``value`` is used, calling it first when it is callable.

        """
        if params:
            params = util.column_dict(params)
            pd = {}
            for bindparam, name in self.bind_names.iteritems():
                for paramname in (bindparam.key, bindparam.shortname, name):
                    if paramname in params:
                        pd[name] = params[paramname]
                        break
                else:
                    if util.callable(bindparam.value):
                        pd[name] = bindparam.value()
                    else:
                        pd[name] = bindparam.value
            return pd
        else:
            pd = {}
            for bindparam in self.bind_names:
                if util.callable(bindparam.value):
                    pd[self.bind_names[bindparam]] = bindparam.value()
                else:
                    pd[self.bind_names[bindparam]] = bindparam.value
            return pd

    params = property(construct_params)

    def default_from(self):
        """Called when a SELECT statement has no froms, and no FROM clause is to be appended.

        Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output.

        """
        return ""

    def visit_grouping(self, grouping, **kwargs):
        """Render a grouping as its element wrapped in parentheses."""
        return "(" + self.process(grouping.element) + ")"

    def visit_label(self, label, result_map=None, within_columns_clause=False):
        """Render a label.

        ``<expr> AS <name>`` is emitted only within a columns clause, where
        the label is also recorded in ``result_map`` if one is given;
        elsewhere only the labeled expression is rendered.
        """
        # only render labels within the columns clause
        # or ORDER BY clause of a select. dialect-specific compilers
        # can modify this behavior.
        if within_columns_clause:
            labelname = isinstance(label.name, sql._generated_label) and \
                    self._truncated_identifier("colident", label.name) or label.name

            if result_map is not None:
                result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type)

            return self.process(label.element) + " " + \
                        self.operator_string(operators.as_) + " " + \
                        self.preparer.format_label(label, labelname)
        else:
            return self.process(label.element)

    def visit_column(self, column, result_map=None, **kwargs):
        """Render a column, qualified by table (and schema) when its table is named.

        Generated names are truncated to ``label_length``; literal columns
        are escaped rather than quoted.  The column is recorded in
        ``result_map`` if one is given.
        """
        name = column.name
        if not column.is_literal and isinstance(name, sql._generated_label):
            name = self._truncated_identifier("colident", name)

        if result_map is not None:
            result_map[name.lower()] = (name, (column, ), column.type)

        if column.is_literal:
            name = self.escape_literal_column(name)
        else:
            name = self.preparer.quote(name, column.quote)

        if column.table is None or not column.table.named_with_column:
            return name
        else:
            if column.table.schema:
                schema_prefix = self.preparer.quote_schema(column.table.schema, column.table.quote_schema) + '.'
            else:
                schema_prefix = ''
            tablename = column.table.name
            tablename = isinstance(tablename, sql._generated_label) and \
                            self._truncated_identifier("alias", tablename) or tablename

            return schema_prefix + self.preparer.quote(tablename, column.table.quote) + "." + name

    def escape_literal_column(self, text):
        """provide escaping for the literal_column() construct."""

        # TODO: some dialects might need different behavior here
        return text.replace('%', '%%')

    def visit_fromclause(self, fromclause, **kwargs):
        return fromclause.name

    def visit_index(self, index, **kwargs):
        return index.name

    def visit_typeclause(self, typeclause, **kwargs):
        """Render the dialect-specific column type specification."""
        return typeclause.type.dialect_impl(self.dialect).get_col_spec()

    def post_process_text(self, text):
        """Hook for dialects to rewrite textual SQL before bind substitution."""
        return text

    def visit_textclause(self, textclause, **kwargs):
        """Render a text() construct.

        ``:name`` tokens become bind placeholders (using the clause's own
        bind parameters where present), then escaped ``\\:name`` tokens are
        un-escaped.  Any typemap entries are recorded in ``self.result_map``.
        """
        if textclause.typemap is not None:
            for colname, type_ in textclause.typemap.iteritems():
                self.result_map[colname.lower()] = (colname, None, type_)

        def do_bindparam(m):
            name = m.group(1)
            if name in textclause.bindparams:
                return self.process(textclause.bindparams[name])
            else:
                return self.bindparam_string(name)

        # un-escape any \:params
        return BIND_PARAMS_ESC.sub(lambda m: m.group(1),
            BIND_PARAMS.sub(do_bindparam, self.post_process_text(textclause.text))
        )

    def visit_null(self, null, **kwargs):
        return 'NULL'

    def visit_clauselist(self, clauselist, **kwargs):
        """Join the rendered clauses with the list's operator, skipping None results."""
        sep = clauselist.operator
        if sep is None:
            sep = " "
        elif sep is operators.comma_op:
            sep = ', '
        else:
            sep = " " + self.operator_string(clauselist.operator) + " "
        return sep.join(s for s in (self.process(c) for c in clauselist.clauses)
                            if s is not None)

    def visit_case(self, clause, **kwargs):
        """Render ``CASE [value] WHEN ... THEN ... [ELSE ...] END``.

        ``value`` and ``else_`` are tested with ``is not None``: a clause
        object may define its own boolean value, so truthiness is not a
        reliable "is present" check.
        """
        x = "CASE "
        if clause.value is not None:
            x += self.process(clause.value) + " "
        for cond, result in clause.whens:
            x += "WHEN " + self.process(cond) + " THEN " + self.process(result) + " "
        if clause.else_ is not None:
            x += "ELSE " + self.process(clause.else_) + " "
        x += "END"
        return x

    def visit_cast(self, cast, **kwargs):
        return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))

    def visit_extract(self, extract, **kwargs):
        """Render EXTRACT, translating the field name through ``extract_map``."""
        field = self.extract_map.get(extract.field, extract.field)
        return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr))

    def visit_function(self, func, result_map=None, **kwargs):
        """Render a function call.

        A callable template from :meth:`function_string` receives the
        rendered arguments; a string template is prefixed with the
        function's package names and has ``%(expr)s`` replaced by
        :meth:`function_argspec`.
        """
        if result_map is not None:
            result_map[func.name.lower()] = (func.name, None, func.type)

        name = self.function_string(func)

        if util.callable(name):
            return name(*[self.process(x) for x in func.clauses])
        else:
            return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)}

    def function_argspec(self, func, **kwargs):
        """Render the argument portion of a function call."""
        return self.process(func.clause_expr, **kwargs)

    def function_string(self, func):
        """Return the template for ``func``: by class, then by name, else ``name%(expr)s``."""
        return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s"))

    def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
        """Render a UNION/INTERSECT/EXCEPT style compound of SELECTs.

        Pushes a wrapper entry on the stack so the member SELECTs know they
        are part of a compound, then appends GROUP BY, ORDER BY and
        LIMIT/OFFSET.  Parenthesized when used as a FROM with ``parens``.
        """
        entry = self.stack and self.stack[-1] or {}
        self.stack.append({'from':entry.get('from', None), 'iswrapper':True})

        keyword = " " + cs.keyword + " "
        text = keyword.join(
            self.process(c, asfrom=asfrom, parens=False, compound_index=i)
            for i, c in enumerate(cs.selects))

        group_by = self.process(cs._group_by_clause, asfrom=asfrom)
        if group_by:
            text += " GROUP BY " + group_by

        text += self.order_by_clause(cs)
        if cs._limit is not None or cs._offset is not None:
            text += self.limit_clause(cs)

        self.stack.pop(-1)
        if asfrom and parens:
            return "(" + text + ")"
        else:
            return text

    def visit_unary(self, unary, **kw):
        """Render a unary expression: ``[operator ]element[ modifier]``."""
        s = self.process(unary.element, **kw)
        if unary.operator:
            s = self.operator_string(unary.operator) + " " + s
        if unary.modifier:
            s = s + " " + self.operator_string(unary.modifier)
        return s

    def visit_binary(self, binary, **kwargs):
        """Render a binary expression.

        A callable operator rendering receives both rendered operands plus
        the expression's modifiers; a string one is placed between them.
        """
        op = self.operator_string(binary.operator)
        if util.callable(op):
            return op(self.process(binary.left), self.process(binary.right), **binary.modifiers)
        else:
            return self.process(binary.left) + " " + op + " " + self.process(binary.right)

    def operator_string(self, operator):
        """Return the rendering for ``operator``, falling back to ``str(operator)``."""
        return self.operators.get(operator, str(operator))

    def visit_bindparam(self, bindparam, **kwargs):
        """Register ``bindparam`` and return its placeholder.

        Raises CompileError when a different bind parameter already
        registered under the same compiled name is (or this one is) unique.
        """
        name = self._truncate_bindparam(bindparam)
        if name in self.binds:
            existing = self.binds[name]
            if existing is not bindparam and (existing.unique or bindparam.unique):
                raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
        self.binds[bindparam.key] = self.binds[name] = bindparam
        return self.bindparam_string(name)

    def _truncate_bindparam(self, bindparam):
        """Return (and memoize in ``bind_names``) the compiled name for ``bindparam``."""
        if bindparam in self.bind_names:
            return self.bind_names[bindparam]

        bind_name = bindparam.key
        bind_name = isinstance(bind_name, sql._generated_label) and \
                        self._truncated_identifier("bindparam", bind_name) or bind_name
        # add to bind_names for translation
        self.bind_names[bindparam] = bind_name

        return bind_name

    def _truncated_identifier(self, ident_class, name):
        """Resolve anonymous tokens in ``name`` and truncate it to ``label_length``.

        Over-long names are cut and suffixed with ``_<hex counter>``, the
        counter being kept per ``ident_class``.  Results are memoized per
        ``(ident_class, name)``.
        """
        if (ident_class, name) in self.truncated_names:
            return self.truncated_names[(ident_class, name)]

        anonname = name % self.anon_map

        if len(anonname) > self.label_length:
            counter = self.truncated_names.get(ident_class, 1)
            truncname = anonname[0:max(self.label_length - 6, 0)] + "_" + hex(counter)[2:]
            self.truncated_names[ident_class] = counter + 1
        else:
            truncname = anonname
        self.truncated_names[(ident_class, name)] = truncname
        return truncname

    def _anonymize(self, name):
        return name % self.anon_map

    def _process_anon(self, key):
        """Produce ``<derived>_<n>`` for an anonymous key of the form ``"<ident> <derived>"``.

        Used as the ``anon_map`` populator; ``n`` counts per ``derived`` name.
        """
        (ident, derived) = key.split(' ', 1)
        anonymous_counter = self.anon_map.get(derived, 1)
        self.anon_map[derived] = anonymous_counter + 1
        return derived + "_" + str(anonymous_counter)

    def bindparam_string(self, name):
        """Return the paramstyle placeholder for ``name``, tracking positional order."""
        if self.positional:
            self.positiontup.append(name)
            return self.bindtemplate % {'name':name, 'position':len(self.positiontup)}
        else:
            return self.bindtemplate % {'name':name}

    def visit_alias(self, alias, asfrom=False, **kwargs):
        """Render ``<original> AS <name>`` in a FROM context, else just the original."""
        if asfrom:
            alias_name = isinstance(alias.name, sql._generated_label) and \
                            self._truncated_identifier("alias", alias.name) or alias.name

            return self.process(alias.original, asfrom=True, **kwargs) + " AS " + \
                    self.preparer.format_alias(alias, alias_name)
        else:
            return self.process(alias.original, **kwargs)

    def label_select_column(self, select, column, asfrom):
        """label columns present in a select()."""

        if isinstance(column, sql._Label):
            return column

        if select.use_labels and column._label:
            return _CompileLabel(column, column._label)

        if \
            asfrom and \
            isinstance(column, sql.ColumnClause) and \
            not column.is_literal and \
            column.table is not None and \
            not isinstance(column.table, sql.Select):
            return _CompileLabel(column, sql._generated_label(column.name))
        elif not isinstance(column, (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \
                and (not hasattr(column, 'name') or isinstance(column, sql.Function)):
            return _CompileLabel(column, column.anon_label)
        else:
            return column

    def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, compound_index=1, **kwargs):
        """Render a SELECT statement.

        Columns populate ``self.result_map`` only for the outermost SELECT
        or for members of a compound (whose stack entry is a wrapper).
        Parenthesized when used as a FROM with ``parens``.
        """
        entry = self.stack and self.stack[-1] or {}

        existingfroms = entry.get('from', None)

        froms = select._get_display_froms(existingfroms)

        correlate_froms = set(sql._from_objects(*froms))

        # TODO: might want to propagate existing froms for select(select(select))
        # where innermost select should correlate to outermost
        # if existingfroms:
        #     correlate_froms = correlate_froms.union(existingfroms)

        self.stack.append({'from':correlate_froms, 'iswrapper':iswrapper})

        if compound_index==1 and not entry or entry.get('iswrapper', False):
            column_clause_args = {'result_map':self.result_map}
        else:
            column_clause_args = {}

        # the actual list of columns to print in the SELECT column list.
        inner_columns = [
            c for c in [
                self.process(
                    self.label_select_column(select, co, asfrom=asfrom),
                    within_columns_clause=True,
                    **column_clause_args)
                for co in util.unique_list(select.inner_columns)
            ]
            if c is not None
        ]

        text = "SELECT "  # we're off to a good start !
        if select._prefixes:
            text += " ".join(self.process(x) for x in select._prefixes) + " "
        text += self.get_select_precolumns(select)
        text += ', '.join(inner_columns)

        if froms:
            text += " \nFROM "
            text += ', '.join(self.process(f, asfrom=True) for f in froms)
        else:
            text += self.default_from()

        if select._whereclause is not None:
            t = self.process(select._whereclause)
            if t:
                text += " \nWHERE " + t

        if select._group_by_clause.clauses:
            group_by = self.process(select._group_by_clause)
            if group_by:
                text += " GROUP BY " + group_by

        if select._having is not None:
            t = self.process(select._having)
            if t:
                text += " \nHAVING " + t

        if select._order_by_clause.clauses:
            text += self.order_by_clause(select)
        if select._limit is not None or select._offset is not None:
            text += self.limit_clause(select)
        if select.for_update:
            text += self.for_update_clause(select)

        self.stack.pop(-1)

        if asfrom and parens:
            return "(" + text + ")"
        else:
            return text

    def get_select_precolumns(self, select):
        """Called when building a ``SELECT`` statement, position is just before column list."""

        return select._distinct and "DISTINCT " or ""

    def order_by_clause(self, select):
        """Return `` ORDER BY ...`` for ``select``, or "" when it renders empty."""
        order_by = self.process(select._order_by_clause)
        if order_by:
            return " ORDER BY " + order_by
        else:
            return ""

    def for_update_clause(self, select):
        if select.for_update:
            return " FOR UPDATE"
        else:
            return ""

    def limit_clause(self, select):
        """Render LIMIT/OFFSET; an OFFSET without LIMIT is emitted with ``LIMIT -1``."""
        text = ""
        if select._limit is not None:
            text += " \n LIMIT " + str(select._limit)
        if select._offset is not None:
            if select._limit is None:
                text += " \n LIMIT -1"
            text += " OFFSET " + str(select._offset)
        return text

    def visit_table(self, table, asfrom=False, **kwargs):
        """Render a (schema-qualified) table name in a FROM context, else ""."""
        if asfrom:
            if getattr(table, "schema", None):
                return self.preparer.quote_schema(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote)
            else:
                return self.preparer.quote(table.name, table.quote)
        else:
            return ""

    def visit_join(self, join, asfrom=False, **kwargs):
        return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
            self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))

    def visit_sequence(self, seq):
        """Sequences render nothing by default; dialects supporting them override this."""
        return None

    def visit_insert(self, insert_stmt):
        """Render an INSERT statement.

        Raises CompileError when there are no column values and the dialect
        supports neither ``DEFAULT VALUES`` nor empty inserts.
        """
        self.isinsert = True
        colparams = self._get_colparams(insert_stmt)
        preparer = self.preparer

        insert = ' '.join(["INSERT"] +
                          [self.process(x) for x in insert_stmt._prefixes])

        if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert:
            raise exc.CompileError(
                "The version of %s you are using does not support empty inserts." % self.dialect.name)
        elif not colparams and self.dialect.supports_default_values:
            return (insert + " INTO %s DEFAULT VALUES" % (
                (preparer.format_table(insert_stmt.table),)))
        else:
            return (insert + " INTO %s (%s) VALUES (%s)" %
                (preparer.format_table(insert_stmt.table),
                 ', '.join([preparer.format_column(c[0])
                           for c in colparams]),
                 ', '.join([c[1] for c in colparams])))

    def visit_update(self, update_stmt):
        """Render an UPDATE statement.

        The WHERE criterion is checked with ``is not None`` (not truthiness,
        which a clause object may define for itself) and, as in
        visit_select, the keyword is omitted when it renders empty.
        """
        self.stack.append({'from': set([update_stmt.table])})

        self.isupdate = True
        colparams = self._get_colparams(update_stmt)

        text = ' '.join((
            "UPDATE",
            self.preparer.format_table(update_stmt.table),
            'SET',
            ', '.join(self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
                      for c in colparams)
            ))

        if update_stmt._whereclause is not None:
            t = self.process(update_stmt._whereclause)
            if t:
                text += " WHERE " + t

        self.stack.pop(-1)

        return text

    def _get_colparams(self, stmt):
        """create a set of tuples representing column/string pairs for use
        in an INSERT or UPDATE statement.

        Also resets and fills ``self.postfetch`` (columns whose values are
        generated by SQL expressions / server-side) and ``self.prefetch``
        (columns given a bind parameter for a default computed before
        execution).

        """

        def create_bind_param(col, value):
            bindparam = sql.bindparam(col.key, value, type_=col.type)
            self.binds[col.key] = bindparam
            return self.bindparam_string(self._truncate_bindparam(bindparam))

        self.postfetch = []
        self.prefetch = []

        # no parameters in the statement, no parameters in the
        # compiled params - return binds for all columns
        if self.column_keys is None and stmt.parameters is None:
            return [(c, create_bind_param(c, None)) for c in stmt.table.columns]

        # if we have statement parameters - set defaults in the
        # compiled params
        if self.column_keys is None:
            parameters = {}
        else:
            parameters = dict((sql._column_as_key(key), None)
                              for key in self.column_keys)

        if stmt.parameters is not None:
            for k, v in stmt.parameters.iteritems():
                parameters.setdefault(sql._column_as_key(k), v)

        # create a list of column assignment clauses as tuples
        values = []
        for c in stmt.table.columns:
            if c.key in parameters:
                value = parameters[c.key]
                if sql._is_literal(value):
                    value = create_bind_param(c, value)
                else:
                    self.postfetch.append(c)
                    value = self.process(value.self_group())
                values.append((c, value))
            elif isinstance(c, schema.Column):
                if self.isinsert:
                    if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline):
                        if (((isinstance(c.default, schema.Sequence) and
                              not c.default.optional) or
                             not self.dialect.supports_pk_autoincrement) or
                            (c.default is not None and
                             not isinstance(c.default, schema.Sequence))):
                            values.append((c, create_bind_param(c, None)))
                            self.prefetch.append(c)
                    elif isinstance(c.default, schema.ColumnDefault):
                        if isinstance(c.default.arg, sql.ClauseElement):
                            values.append((c, self.process(c.default.arg.self_group())))
                            if not c.primary_key:
                                # dont add primary key column to postfetch
                                self.postfetch.append(c)
                        else:
                            values.append((c, create_bind_param(c, None)))
                            self.prefetch.append(c)
                    elif c.server_default is not None:
                        if not c.primary_key:
                            self.postfetch.append(c)
                    elif isinstance(c.default, schema.Sequence):
                        proc = self.process(c.default)
                        if proc is not None:
                            values.append((c, proc))
                            if not c.primary_key:
                                self.postfetch.append(c)
                elif self.isupdate:
                    if isinstance(c.onupdate, schema.ColumnDefault):
                        if isinstance(c.onupdate.arg, sql.ClauseElement):
                            values.append((c, self.process(c.onupdate.arg.self_group())))
                            self.postfetch.append(c)
                        else:
                            values.append((c, create_bind_param(c, None)))
                            self.prefetch.append(c)
                    elif c.server_onupdate is not None:
                        self.postfetch.append(c)
                    # deprecated? or remove?
                    elif isinstance(c.onupdate, schema.FetchedValue):
                        self.postfetch.append(c)
        return values

    def visit_delete(self, delete_stmt):
        """Render a DELETE statement.

        The WHERE criterion is checked with ``is not None`` (not truthiness,
        which a clause object may define for itself) and, as in
        visit_select, the keyword is omitted when it renders empty.
        """
        self.stack.append({'from': set([delete_stmt.table])})
        self.isdelete = True

        text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)

        if delete_stmt._whereclause is not None:
            t = self.process(delete_stmt._whereclause)
            if t:
                text += " WHERE " + t

        self.stack.pop(-1)

        return text

    def visit_savepoint(self, savepoint_stmt):
        return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)

    def visit_rollback_to_savepoint(self, savepoint_stmt):
        return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)

    def visit_release_savepoint(self, savepoint_stmt):
        return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)

    def __str__(self):
        return self.string or ''
---|
780 | |
---|
781 | class DDLBase(engine.SchemaIterator): |
---|
782 | def find_alterables(self, tables): |
---|
783 | alterables = [] |
---|
784 | class FindAlterables(schema.SchemaVisitor): |
---|
785 | def visit_foreign_key_constraint(self, constraint): |
---|
786 | if constraint.use_alter and constraint.table in tables: |
---|
787 | alterables.append(constraint) |
---|
788 | findalterables = FindAlterables() |
---|
789 | for table in tables: |
---|
790 | for c in table.constraints: |
---|
791 | findalterables.traverse(c) |
---|
792 | return alterables |
---|
793 | |
---|
794 | def _validate_identifier(self, ident, truncate): |
---|
795 | if truncate: |
---|
796 | if len(ident) > self.dialect.max_identifier_length: |
---|
797 | counter = getattr(self, 'counter', 0) |
---|
798 | self.counter = counter + 1 |
---|
799 | return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:] |
---|
800 | else: |
---|
801 | return ident |
---|
802 | else: |
---|
803 | self.dialect.validate_identifier(ident) |
---|
804 | return ident |
---|
805 | |
---|
806 | |
---|
807 | class SchemaGenerator(DDLBase): |
---|
808 | def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): |
---|
809 | super(SchemaGenerator, self).__init__(connection, **kwargs) |
---|
810 | self.checkfirst = checkfirst |
---|
811 | self.tables = tables and set(tables) or None |
---|
812 | self.preparer = dialect.identifier_preparer |
---|
813 | self.dialect = dialect |
---|
814 | |
---|
815 | def get_column_specification(self, column, first_pk=False): |
---|
816 | raise NotImplementedError() |
---|
817 | |
---|
818 | def _can_create(self, table): |
---|
819 | self.dialect.validate_identifier(table.name) |
---|
820 | if table.schema: |
---|
821 | self.dialect.validate_identifier(table.schema) |
---|
822 | return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema) |
---|
823 | |
---|
824 | def visit_metadata(self, metadata): |
---|
825 | if self.tables: |
---|
826 | tables = self.tables |
---|
827 | else: |
---|
828 | tables = metadata.tables.values() |
---|
829 | collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)] |
---|
830 | for table in collection: |
---|
831 | self.traverse_single(table) |
---|
832 | if self.dialect.supports_alter: |
---|
833 | for alterable in self.find_alterables(collection): |
---|
834 | self.add_foreignkey(alterable) |
---|
835 | |
---|
836 | def visit_table(self, table): |
---|
837 | for listener in table.ddl_listeners['before-create']: |
---|
838 | listener('before-create', table, self.connection) |
---|
839 | |
---|
840 | for column in table.columns: |
---|
841 | if column.default is not None: |
---|
842 | self.traverse_single(column.default) |
---|
843 | |
---|
844 | self.append("\n" + " ".join(['CREATE'] + |
---|
845 | table._prefixes + |
---|
846 | ['TABLE', |
---|
847 | self.preparer.format_table(table), |
---|
848 | "("])) |
---|
849 | separator = "\n" |
---|
850 | |
---|
851 | # if only one primary key, specify it along with the column |
---|
852 | first_pk = False |
---|
853 | for column in table.columns: |
---|
854 | self.append(separator) |
---|
855 | separator = ", \n" |
---|
856 | self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)) |
---|
857 | if column.primary_key: |
---|
858 | first_pk = True |
---|
859 | for constraint in column.constraints: |
---|
860 | self.traverse_single(constraint) |
---|
861 | |
---|
862 | # On some DB order is significant: visit PK first, then the |
---|
863 | # other constraints (engine.ReflectionTest.testbasic failed on FB2) |
---|
864 | if table.primary_key: |
---|
865 | self.traverse_single(table.primary_key) |
---|
866 | for constraint in [c for c in table.constraints if c is not table.primary_key]: |
---|
867 | self.traverse_single(constraint) |
---|
868 | |
---|
869 | self.append("\n)%s\n\n" % self.post_create_table(table)) |
---|
870 | self.execute() |
---|
871 | |
---|
872 | if hasattr(table, 'indexes'): |
---|
873 | for index in table.indexes: |
---|
874 | self.traverse_single(index) |
---|
875 | |
---|
876 | for listener in table.ddl_listeners['after-create']: |
---|
877 | listener('after-create', table, self.connection) |
---|
878 | |
---|
879 | def post_create_table(self, table): |
---|
880 | return '' |
---|
881 | |
---|
882 | def get_column_default_string(self, column): |
---|
883 | if isinstance(column.server_default, schema.DefaultClause): |
---|
884 | if isinstance(column.server_default.arg, basestring): |
---|
885 | return "'%s'" % column.server_default.arg |
---|
886 | else: |
---|
887 | return unicode(self._compile(column.server_default.arg, None)) |
---|
888 | else: |
---|
889 | return None |
---|
890 | |
---|
891 | def _compile(self, tocompile, parameters): |
---|
892 | """compile the given string/parameters using this SchemaGenerator's dialect.""" |
---|
893 | compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters) |
---|
894 | compiler.compile() |
---|
895 | return compiler |
---|
896 | |
---|
897 | def visit_check_constraint(self, constraint): |
---|
898 | self.append(", \n\t") |
---|
899 | if constraint.name is not None: |
---|
900 | self.append("CONSTRAINT %s " % |
---|
901 | self.preparer.format_constraint(constraint)) |
---|
902 | self.append(" CHECK (%s)" % constraint.sqltext) |
---|
903 | self.define_constraint_deferrability(constraint) |
---|
904 | |
---|
905 | def visit_column_check_constraint(self, constraint): |
---|
906 | self.append(" CHECK (%s)" % constraint.sqltext) |
---|
907 | self.define_constraint_deferrability(constraint) |
---|
908 | |
---|
909 | def visit_primary_key_constraint(self, constraint): |
---|
910 | if len(constraint) == 0: |
---|
911 | return |
---|
912 | self.append(", \n\t") |
---|
913 | if constraint.name is not None: |
---|
914 | self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) |
---|
915 | self.append("PRIMARY KEY ") |
---|
916 | self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) |
---|
917 | for c in constraint)) |
---|
918 | self.define_constraint_deferrability(constraint) |
---|
919 | |
---|
920 | def visit_foreign_key_constraint(self, constraint): |
---|
921 | if constraint.use_alter and self.dialect.supports_alter: |
---|
922 | return |
---|
923 | self.append(", \n\t ") |
---|
924 | self.define_foreign_key(constraint) |
---|
925 | |
---|
926 | def add_foreignkey(self, constraint): |
---|
927 | self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table)) |
---|
928 | self.define_foreign_key(constraint) |
---|
929 | self.execute() |
---|
930 | |
---|
931 | def define_foreign_key(self, constraint): |
---|
932 | preparer = self.preparer |
---|
933 | if constraint.name is not None: |
---|
934 | self.append("CONSTRAINT %s " % |
---|
935 | preparer.format_constraint(constraint)) |
---|
936 | table = list(constraint.elements)[0].column.table |
---|
937 | self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( |
---|
938 | ', '.join(preparer.quote(f.parent.name, f.parent.quote) |
---|
939 | for f in constraint.elements), |
---|
940 | preparer.format_table(table), |
---|
941 | ', '.join(preparer.quote(f.column.name, f.column.quote) |
---|
942 | for f in constraint.elements) |
---|
943 | )) |
---|
944 | if constraint.ondelete is not None: |
---|
945 | self.append(" ON DELETE %s" % constraint.ondelete) |
---|
946 | if constraint.onupdate is not None: |
---|
947 | self.append(" ON UPDATE %s" % constraint.onupdate) |
---|
948 | self.define_constraint_deferrability(constraint) |
---|
949 | |
---|
950 | def visit_unique_constraint(self, constraint): |
---|
951 | self.append(", \n\t") |
---|
952 | if constraint.name is not None: |
---|
953 | self.append("CONSTRAINT %s " % |
---|
954 | self.preparer.format_constraint(constraint)) |
---|
955 | self.append(" UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint))) |
---|
956 | self.define_constraint_deferrability(constraint) |
---|
957 | |
---|
958 | def define_constraint_deferrability(self, constraint): |
---|
959 | if constraint.deferrable is not None: |
---|
960 | if constraint.deferrable: |
---|
961 | self.append(" DEFERRABLE") |
---|
962 | else: |
---|
963 | self.append(" NOT DEFERRABLE") |
---|
964 | if constraint.initially is not None: |
---|
965 | self.append(" INITIALLY %s" % constraint.initially) |
---|
966 | |
---|
967 | def visit_column(self, column): |
---|
968 | pass |
---|
969 | |
---|
970 | def visit_index(self, index): |
---|
971 | preparer = self.preparer |
---|
972 | self.append("CREATE ") |
---|
973 | if index.unique: |
---|
974 | self.append("UNIQUE ") |
---|
975 | self.append("INDEX %s ON %s (%s)" \ |
---|
976 | % (preparer.quote(self._validate_identifier(index.name, True), index.quote), |
---|
977 | preparer.format_table(index.table), |
---|
978 | ', '.join(preparer.quote(c.name, c.quote) |
---|
979 | for c in index.columns))) |
---|
980 | self.execute() |
---|
981 | |
---|
982 | |
---|
983 | class SchemaDropper(DDLBase): |
---|
984 | def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): |
---|
985 | super(SchemaDropper, self).__init__(connection, **kwargs) |
---|
986 | self.checkfirst = checkfirst |
---|
987 | self.tables = tables |
---|
988 | self.preparer = dialect.identifier_preparer |
---|
989 | self.dialect = dialect |
---|
990 | |
---|
991 | def visit_metadata(self, metadata): |
---|
992 | if self.tables: |
---|
993 | tables = self.tables |
---|
994 | else: |
---|
995 | tables = metadata.tables.values() |
---|
996 | collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)] |
---|
997 | if self.dialect.supports_alter: |
---|
998 | for alterable in self.find_alterables(collection): |
---|
999 | self.drop_foreignkey(alterable) |
---|
1000 | for table in collection: |
---|
1001 | self.traverse_single(table) |
---|
1002 | |
---|
1003 | def _can_drop(self, table): |
---|
1004 | self.dialect.validate_identifier(table.name) |
---|
1005 | if table.schema: |
---|
1006 | self.dialect.validate_identifier(table.schema) |
---|
1007 | return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema) |
---|
1008 | |
---|
1009 | def visit_index(self, index): |
---|
1010 | self.append("\nDROP INDEX " + self.preparer.quote(self._validate_identifier(index.name, False), index.quote)) |
---|
1011 | self.execute() |
---|
1012 | |
---|
1013 | def drop_foreignkey(self, constraint): |
---|
1014 | self.append("ALTER TABLE %s DROP CONSTRAINT %s" % ( |
---|
1015 | self.preparer.format_table(constraint.table), |
---|
1016 | self.preparer.format_constraint(constraint))) |
---|
1017 | self.execute() |
---|
1018 | |
---|
1019 | def visit_table(self, table): |
---|
1020 | for listener in table.ddl_listeners['before-drop']: |
---|
1021 | listener('before-drop', table, self.connection) |
---|
1022 | |
---|
1023 | for column in table.columns: |
---|
1024 | if column.default is not None: |
---|
1025 | self.traverse_single(column.default) |
---|
1026 | |
---|
1027 | self.append("\nDROP TABLE " + self.preparer.format_table(table)) |
---|
1028 | self.execute() |
---|
1029 | |
---|
1030 | for listener in table.ddl_listeners['after-drop']: |
---|
1031 | listener('after-drop', table, self.connection) |
---|
1032 | |
---|
1033 | |
---|
1034 | class IdentifierPreparer(object): |
---|
1035 | """Handle quoting and case-folding of identifiers based on options.""" |
---|
1036 | |
---|
1037 | reserved_words = RESERVED_WORDS |
---|
1038 | |
---|
1039 | legal_characters = LEGAL_CHARACTERS |
---|
1040 | |
---|
1041 | illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS |
---|
1042 | |
---|
1043 | def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False): |
---|
1044 | """Construct a new ``IdentifierPreparer`` object. |
---|
1045 | |
---|
1046 | initial_quote |
---|
1047 | Character that begins a delimited identifier. |
---|
1048 | |
---|
1049 | final_quote |
---|
1050 | Character that ends a delimited identifier. Defaults to `initial_quote`. |
---|
1051 | |
---|
1052 | omit_schema |
---|
1053 | Prevent prepending schema name. Useful for databases that do |
---|
1054 | not support schemae. |
---|
1055 | """ |
---|
1056 | |
---|
1057 | self.dialect = dialect |
---|
1058 | self.initial_quote = initial_quote |
---|
1059 | self.final_quote = final_quote or self.initial_quote |
---|
1060 | self.omit_schema = omit_schema |
---|
1061 | self._strings = {} |
---|
1062 | |
---|
1063 | def _escape_identifier(self, value): |
---|
1064 | """Escape an identifier. |
---|
1065 | |
---|
1066 | Subclasses should override this to provide database-dependent |
---|
1067 | escaping behavior. |
---|
1068 | """ |
---|
1069 | |
---|
1070 | return value.replace('"', '""') |
---|
1071 | |
---|
1072 | def _unescape_identifier(self, value): |
---|
1073 | """Canonicalize an escaped identifier. |
---|
1074 | |
---|
1075 | Subclasses should override this to provide database-dependent |
---|
1076 | unescaping behavior that reverses _escape_identifier. |
---|
1077 | """ |
---|
1078 | |
---|
1079 | return value.replace('""', '"') |
---|
1080 | |
---|
1081 | def quote_identifier(self, value): |
---|
1082 | """Quote an identifier. |
---|
1083 | |
---|
1084 | Subclasses should override this to provide database-dependent |
---|
1085 | quoting behavior. |
---|
1086 | """ |
---|
1087 | |
---|
1088 | return self.initial_quote + self._escape_identifier(value) + self.final_quote |
---|
1089 | |
---|
1090 | def _requires_quotes(self, value): |
---|
1091 | """Return True if the given identifier requires quoting.""" |
---|
1092 | lc_value = value.lower() |
---|
1093 | return (lc_value in self.reserved_words |
---|
1094 | or self.illegal_initial_characters.match(value[0]) |
---|
1095 | or not self.legal_characters.match(unicode(value)) |
---|
1096 | or (lc_value != value)) |
---|
1097 | |
---|
1098 | def quote_schema(self, schema, force): |
---|
1099 | """Quote a schema. |
---|
1100 | |
---|
1101 | Subclasses should override this to provide database-dependent |
---|
1102 | quoting behavior. |
---|
1103 | """ |
---|
1104 | return self.quote(schema, force) |
---|
1105 | |
---|
1106 | def quote(self, ident, force): |
---|
1107 | if force is None: |
---|
1108 | if ident in self._strings: |
---|
1109 | return self._strings[ident] |
---|
1110 | else: |
---|
1111 | if self._requires_quotes(ident): |
---|
1112 | self._strings[ident] = self.quote_identifier(ident) |
---|
1113 | else: |
---|
1114 | self._strings[ident] = ident |
---|
1115 | return self._strings[ident] |
---|
1116 | elif force: |
---|
1117 | return self.quote_identifier(ident) |
---|
1118 | else: |
---|
1119 | return ident |
---|
1120 | |
---|
1121 | def format_sequence(self, sequence, use_schema=True): |
---|
1122 | name = self.quote(sequence.name, sequence.quote) |
---|
1123 | if not self.omit_schema and use_schema and sequence.schema is not None: |
---|
1124 | name = self.quote_schema(sequence.schema, sequence.quote) + "." + name |
---|
1125 | return name |
---|
1126 | |
---|
1127 | def format_label(self, label, name=None): |
---|
1128 | return self.quote(name or label.name, label.quote) |
---|
1129 | |
---|
1130 | def format_alias(self, alias, name=None): |
---|
1131 | return self.quote(name or alias.name, alias.quote) |
---|
1132 | |
---|
1133 | def format_savepoint(self, savepoint, name=None): |
---|
1134 | return self.quote(name or savepoint.ident, savepoint.quote) |
---|
1135 | |
---|
1136 | def format_constraint(self, constraint): |
---|
1137 | return self.quote(constraint.name, constraint.quote) |
---|
1138 | |
---|
1139 | def format_table(self, table, use_schema=True, name=None): |
---|
1140 | """Prepare a quoted table and schema name.""" |
---|
1141 | |
---|
1142 | if name is None: |
---|
1143 | name = table.name |
---|
1144 | result = self.quote(name, table.quote) |
---|
1145 | if not self.omit_schema and use_schema and getattr(table, "schema", None): |
---|
1146 | result = self.quote_schema(table.schema, table.quote_schema) + "." + result |
---|
1147 | return result |
---|
1148 | |
---|
1149 | def format_column(self, column, use_table=False, name=None, table_name=None): |
---|
1150 | """Prepare a quoted column name.""" |
---|
1151 | |
---|
1152 | if name is None: |
---|
1153 | name = column.name |
---|
1154 | if not getattr(column, 'is_literal', False): |
---|
1155 | if use_table: |
---|
1156 | return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(name, column.quote) |
---|
1157 | else: |
---|
1158 | return self.quote(name, column.quote) |
---|
1159 | else: |
---|
1160 | # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted |
---|
1161 | if use_table: |
---|
1162 | return self.format_table(column.table, use_schema=False, name=table_name) + "." + name |
---|
1163 | else: |
---|
1164 | return name |
---|
1165 | |
---|
1166 | def format_table_seq(self, table, use_schema=True): |
---|
1167 | """Format table name and schema as a tuple.""" |
---|
1168 | |
---|
1169 | # Dialects with more levels in their fully qualified references |
---|
1170 | # ('database', 'owner', etc.) could override this and return |
---|
1171 | # a longer sequence. |
---|
1172 | |
---|
1173 | if not self.omit_schema and use_schema and getattr(table, 'schema', None): |
---|
1174 | return (self.quote_schema(table.schema, table.quote_schema), |
---|
1175 | self.format_table(table, use_schema=False)) |
---|
1176 | else: |
---|
1177 | return (self.format_table(table, use_schema=False), ) |
---|
1178 | |
---|
1179 | def unformat_identifiers(self, identifiers): |
---|
1180 | """Unpack 'schema.table.column'-like strings into components.""" |
---|
1181 | |
---|
1182 | try: |
---|
1183 | r = self._r_identifiers |
---|
1184 | except AttributeError: |
---|
1185 | initial, final, escaped_final = \ |
---|
1186 | [re.escape(s) for s in |
---|
1187 | (self.initial_quote, self.final_quote, |
---|
1188 | self._escape_identifier(self.final_quote))] |
---|
1189 | r = re.compile( |
---|
1190 | r'(?:' |
---|
1191 | r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' |
---|
1192 | r'|([^\.]+))(?=\.|$))+' % |
---|
1193 | { 'initial': initial, |
---|
1194 | 'final': final, |
---|
1195 | 'escaped': escaped_final }) |
---|
1196 | self._r_identifiers = r |
---|
1197 | |
---|
1198 | return [self._unescape_identifier(i) |
---|
1199 | for i in [a or b for a, b in r.findall(identifiers)]] |
---|