[3] | 1 | # compiler.py |
---|
| 2 | # Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com |
---|
| 3 | # |
---|
| 4 | # This module is part of SQLAlchemy and is released under |
---|
| 5 | # the MIT License: http://www.opensource.org/licenses/mit-license.php |
---|
| 6 | |
---|
| 7 | """Base SQL and DDL compiler implementations. |
---|
| 8 | |
---|
| 9 | Provides the :class:`~sqlalchemy.sql.compiler.DefaultCompiler` class, which is |
---|
| 10 | responsible for generating all SQL query strings, as well as |
---|
| 11 | :class:`~sqlalchemy.sql.compiler.SchemaGenerator` and :class:`~sqlalchemy.sql.compiler.SchemaDropper` |
---|
| 12 | which issue CREATE and DROP DDL for tables, sequences, and indexes. |
---|
| 13 | |
---|
| 14 | The elements in this module are used by public-facing constructs like |
---|
| 15 | :class:`~sqlalchemy.sql.expression.ClauseElement` and :class:`~sqlalchemy.engine.Engine`. |
---|
| 16 | While dialect authors will want to be familiar with this module for the purpose of |
---|
| 17 | creating database-specific compilers and schema generators, the module |
---|
| 18 | is otherwise internal to SQLAlchemy. |
---|
| 19 | """ |
---|
| 20 | |
---|
| 21 | import string, re |
---|
| 22 | from sqlalchemy import schema, engine, util, exc |
---|
| 23 | from sqlalchemy.sql import operators, functions, util as sql_util, visitors |
---|
| 24 | from sqlalchemy.sql import expression as sql |
---|
| 25 | |
---|
# Lower-case SQL keywords treated as reserved words.
# NOTE(review): presumably consulted when deciding whether an identifier
# must be quoted -- confirm against the identifier preparer.
RESERVED_WORDS = set("""
    all analyse analyze and any array
    as asc asymmetric authorization between
    binary both case cast check collate
    column constraint create cross current_date
    current_role current_time current_timestamp
    current_user default deferrable desc
    distinct do else end except false
    for foreign freeze from full grant
    group having ilike in initially inner
    intersect into is isnull join leading
    left like limit localtime localtimestamp
    natural new not notnull null off offset
    old on only or order outer overlaps
    placing primary references right select
    session_user set similar some symmetric table
    then to trailing true union unique user
    using verbose when where
""".split())
---|
| 44 | |
---|
# An identifier made up solely of letters, digits, "_" and "$" (any case).
LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.IGNORECASE)
# Characters that may not begin an unquoted identifier.
ILLEGAL_INITIAL_CHARACTERS = re.compile(r'[0-9$]')

# A ":name" bind parameter inside literal SQL text.  The look-behind rejects
# a colon preceded by another colon, a word character, "$" or a backslash
# (\x5c), so "::" casts, "a:b" and escaped markers are left alone; the
# look-ahead rejects a name directly followed by ":", a word char or "$".
BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.U)
# A backslash-escaped ":name" marker; group 1 is the text minus the backslash.
BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]+)(?![:\w\$])', re.U)
---|
| 50 | |
---|
# Bind-parameter placeholder templates keyed by DB-API paramstyle.  Each one
# is %-formatted with a mapping that provides "name" and, for positional
# paramstyles, the 1-based "position" (see bindparam_string).
BIND_TEMPLATES = dict(
    pyformat="%%(%(name)s)s",
    qmark="?",
    format="%%s",
    numeric=":%(position)s",
    named=":%(name)s"
)
---|
| 58 | |
---|
| 59 | |
---|
def _like_operator(template):
    """Return a renderer for a LIKE-family operator.

    The renderer %-formats ``template`` with the compiled (left, right)
    operands and, when a truthy ``escape`` modifier is supplied, appends
    `` ESCAPE '<escape>'``.
    """
    def render(x, y, escape=None):
        text = template % (x, y)
        if escape:
            text += " ESCAPE '%s'" % escape
        return text
    return render


# SQL rendering for each operator from sqlalchemy.sql.operators: either a
# keyword/symbol string (placed between operands by visit_binary, or
# before/after the element by visit_unary), or a callable that receives the
# compiled operands plus any modifiers (see visit_binary).
OPERATORS = {
    # boolean logic
    operators.and_: 'AND',
    operators.or_: 'OR',
    operators.inv: 'NOT',

    # arithmetic
    operators.add: '+',
    operators.mul: '*',
    operators.sub: '-',
    operators.div: '/',
    operators.mod: '%',
    operators.truediv: '/',

    # comparison
    operators.lt: '<',
    operators.le: '<=',
    operators.ne: '!=',
    operators.gt: '>',
    operators.ge: '>=',
    operators.eq: '=',

    operators.distinct_op: 'DISTINCT',
    operators.concat_op: '||',

    # pattern matching with an optional ESCAPE character
    operators.like_op: _like_operator('%s LIKE %s'),
    operators.notlike_op: _like_operator('%s NOT LIKE %s'),
    operators.ilike_op: _like_operator("lower(%s) LIKE lower(%s)"),
    operators.notilike_op: _like_operator("lower(%s) NOT LIKE lower(%s)"),

    operators.between_op: 'BETWEEN',
    operators.match_op: 'MATCH',
    operators.in_op: 'IN',
    operators.notin_op: 'NOT IN',
    operators.comma_op: ', ',
    operators.desc_op: 'DESC',
    operators.asc_op: 'ASC',
    operators.from_: 'FROM',
    operators.as_: 'AS',
    operators.exists: 'EXISTS',
    operators.is_: 'IS',
    operators.isnot: 'IS NOT',
    operators.collate: 'COLLATE',
}
---|
| 96 | |
---|
# Rendering templates for known SQL functions, keyed by the function
# classes of sqlalchemy.sql.functions (function_string also accepts plain
# function names as keys).  visit_function %-formats each template with
# {'expr': <compiled argument list>}, so entries without "%(expr)s" render
# as bare keywords with no argument list at all.
FUNCTIONS = {
    # rendered followed by their compiled argument list
    functions.coalesce: 'coalesce%(expr)s',
    functions.random: 'random%(expr)s',

    # rendered as bare keywords
    functions.current_date: 'CURRENT_DATE',
    functions.current_time: 'CURRENT_TIME',
    functions.current_timestamp: 'CURRENT_TIMESTAMP',
    functions.current_user: 'CURRENT_USER',
    functions.localtime: 'LOCALTIME',
    functions.localtimestamp: 'LOCALTIMESTAMP',
    functions.session_user: 'SESSION_USER',
    functions.sysdate: 'sysdate',
    functions.user: 'USER',
}
---|
| 110 | |
---|
# Field names accepted by extract(), mapped to the name rendered inside
# EXTRACT(<field> FROM ...).  The default mapping is the identity; a
# compiler subclass can substitute its own ``extract_map`` to translate them.
EXTRACT_MAP = dict((_field, _field) for _field in (
    'month', 'day', 'year', 'second', 'hour', 'doy', 'minute',
    'quarter', 'dow', 'week', 'epoch', 'milliseconds', 'microseconds',
    'timezone_hour', 'timezone_minute'))
---|
| 128 | |
---|
class _CompileLabel(visitors.Visitable):
    """Lightweight label object which acts as an ``expression._Label``.

    Pairs a column expression (``element``) with the ``name`` it should be
    rendered under.  Its ``__visit_name__`` is ``'label'``, so presumably it
    compiles through ``visit_label`` just like a real label.
    """

    __visit_name__ = 'label'
    __slots__ = ('element', 'name')

    def __init__(self, col, name):
        # stored as ``element`` to mirror the _Label interface
        self.element, self.name = col, name

    quote = property(lambda self: self.element.quote,
                     doc="The ``quote`` setting of the wrapped element.")
---|
| 142 | |
---|
| 143 | class DefaultCompiler(engine.Compiled): |
---|
| 144 | """Default implementation of Compiled. |
---|
| 145 | |
---|
| 146 | Compiles ClauseElements into SQL strings. Uses a similar visit |
---|
| 147 | paradigm as visitors.ClauseVisitor but implements its own traversal. |
---|
| 148 | |
---|
| 149 | """ |
---|
| 150 | |
---|
| 151 | operators = OPERATORS |
---|
| 152 | functions = FUNCTIONS |
---|
| 153 | extract_map = EXTRACT_MAP |
---|
| 154 | |
---|
| 155 | # if we are insert/update/delete. |
---|
| 156 | # set to true when we visit an INSERT, UPDATE or DELETE |
---|
| 157 | isdelete = isinsert = isupdate = False |
---|
| 158 | |
---|
def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
    """Construct a new ``DefaultCompiler`` object.

    dialect
      Dialect to be used.

    statement
      ClauseElement to be compiled.

    column_keys
      a list of column names to be compiled into an INSERT or UPDATE
      statement.

    inline
      when true, INSERT/UPDATE defaults and sequences are compiled inline
      (no pre-execute); a true ``inline`` attribute on the statement has
      the same effect.

    Remaining keyword arguments are passed on to ``engine.Compiled``.
    """
    engine.Compiled.__init__(self, dialect, statement, column_keys, **kwargs)
    d = self.dialect

    self.inline = inline or getattr(statement, 'inline', False)

    # bind parameter key -> _BindParamClause
    self.binds = {}
    # _BindParamClause -> name actually present in the generated SQL
    self.bind_names = util.column_dict()
    # one entry per SELECT being compiled, innermost last
    self.stack = []
    # lower-cased result column name -> (name, ColumnElement objects or
    # None, TypeEngine); ResultProxy uses it for type processing and
    # column targeting
    self.result_map = {}

    self.positional = d.positional
    if self.positional:
        # bind names in the order their placeholders were emitted
        self.positiontup = []
    self.bindtemplate = BIND_TEMPLATES[d.paramstyle]

    # formats the quoting of identifiers
    self.preparer = d.identifier_preparer
    self.label_length = d.label_length or d.max_identifier_length

    # "anonymous" identifiers created on the fly, filled by _process_anon
    self.anon_map = util.PopulateDict(self._process_anon)
    # names truncated to label_length, plus per-class truncation counters
    self.truncated_names = {}
---|
| 212 | |
---|
def compile(self):
    """Render ``self.statement`` and store the SQL text on ``self.string``."""
    rendered = self.process(self.statement)
    self.string = rendered
---|
| 215 | |
---|
def process(self, obj, **kwargs):
    """Compile ``obj`` via its ``_compiler_dispatch``, forwarding keyword arguments."""
    dispatch = obj._compiler_dispatch
    return dispatch(self, **kwargs)
---|
| 218 | |
---|
def is_subquery(self):
    """True when at least two SELECTs are on the compilation stack."""
    return len(self.stack) >= 2
---|
| 221 | |
---|
def construct_params(self, params=None):
    """return a dictionary of bind parameter keys and values.

    Keys are the names actually rendered into the SQL (the values of
    ``self.bind_names``).  When ``params`` is non-empty, each bind's value
    is looked up there under, in order, its ``key``, its ``shortname`` and
    its compiled name.  A bind with no match -- or every bind, when
    ``params`` is empty or None -- falls back to its own ``value``, which
    is called first if it is callable.
    """
    if params:
        lookup = util.column_dict(params)
    else:
        lookup = None

    pd = {}
    for bindparam, name in self.bind_names.iteritems():
        found = False
        if lookup is not None:
            for paramname in (bindparam.key, bindparam.shortname, name):
                if paramname in lookup:
                    pd[name] = lookup[paramname]
                    found = True
                    break
        if not found:
            value = bindparam.value
            if util.callable(value):
                value = value()
            pd[name] = value
    return pd

# attribute form: ``compiled.params`` is ``compiled.construct_params()``
params = property(construct_params)
---|
| 249 | |
---|
def default_from(self):
    """Return the text appended to a SELECT that has no FROM objects.

    Called when a SELECT statement has no froms, so no FROM clause is
    rendered.  The default contributes nothing; it gives Oracle a chance
    to tack on a ``FROM DUAL`` to the string output.
    """
    return ''
---|
| 257 | |
---|
def visit_grouping(self, grouping, **kwargs):
    """Render a grouping as its element wrapped in parentheses.

    Keyword arguments are accepted but not forwarded to the element.
    """
    inner = self.process(grouping.element)
    return "(" + inner + ")"
---|
| 260 | |
---|
def visit_label(self, label, result_map=None, within_columns_clause=False, **kw):
    """Render a label.

    Labels are only rendered inside the columns clause of a SELECT
    (``within_columns_clause=True``).  There the output is
    ``<element> AS <name>``, where generated names are anonymized and
    truncated.  When ``result_map`` is given, the label is also registered
    under its lower-cased rendered name.  Everywhere else only the labeled
    element is rendered.  Dialect-specific compilers can modify this
    behavior.

    Like the other ``visit_*`` methods, this accepts (and ignores) extra
    keyword arguments.  A caller that forwards its own keywords, such as
    ``visit_unary``, therefore does not trigger a ``TypeError``.
    """
    if not within_columns_clause:
        return self.process(label.element)

    labelname = label.name
    if isinstance(labelname, sql._generated_label):
        labelname = self._truncated_identifier("colident", labelname)

    if result_map is not None:
        result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type)

    return self.process(label.element) + " " + \
        self.operator_string(operators.as_) + " " + \
        self.preparer.format_label(label, labelname)
---|
| 277 | |
---|
def visit_column(self, column, result_map=None, **kwargs):
    """Render a column reference, qualified by its table when that table is named.

    Generated (anonymous) names of non-literal columns are truncated.  When
    ``result_map`` is given, the column is registered there under its
    lower-cased name.  Literal columns are escaped with
    ``escape_literal_column``; all other names are quoted by the preparer.
    Schema-qualified tables also get a quoted ``schema.`` prefix.
    """
    is_literal = column.is_literal
    name = column.name
    if not is_literal and isinstance(name, sql._generated_label):
        name = self._truncated_identifier("colident", name)

    if result_map is not None:
        result_map[name.lower()] = (name, (column, ), column.type)

    if is_literal:
        rendered = self.escape_literal_column(name)
    else:
        rendered = self.preparer.quote(name, column.quote)

    table = column.table
    if table is None or not table.named_with_column:
        return rendered

    prefix = ''
    if table.schema:
        prefix = self.preparer.quote_schema(table.schema, table.quote_schema) + '.'

    tablename = table.name
    if isinstance(tablename, sql._generated_label):
        tablename = self._truncated_identifier("alias", tablename)

    return prefix + self.preparer.quote(tablename, table.quote) + "." + rendered
---|
| 303 | |
---|
def escape_literal_column(self, text):
    """Escape the text of a literal_column() construct.

    Every ``%`` is doubled (presumably so the text survives %-style
    paramstyle formatting).
    """
    # TODO: some dialects might need different behavior here
    return '%%'.join(text.split('%'))
---|
| 309 | |
---|
def visit_fromclause(self, fromclause, **kwargs):
    """Render a generic FROM object by its bare (unquoted) name."""
    from_name = fromclause.name
    return from_name
---|
| 312 | |
---|
def visit_index(self, index, **kwargs):
    """Render an index by its bare (unquoted) name."""
    index_name = index.name
    return index_name
---|
| 315 | |
---|
def visit_typeclause(self, typeclause, **kwargs):
    """Render the column type spec of ``typeclause.type`` as implemented by this dialect."""
    impl = typeclause.type.dialect_impl(self.dialect)
    return impl.get_col_spec()
---|
| 318 | |
---|
def post_process_text(self, text):
    """Hook for transforming text() SQL before its bind parameters are parsed.

    The default implementation returns ``text`` unchanged.
    """
    unchanged = text
    return unchanged
---|
| 321 | |
---|
def visit_textclause(self, textclause, **kwargs):
    """Render a text() construct.

    Column types from the clause's ``typemap`` (if any) are registered in
    ``self.result_map``.  Every ``:name`` in the text becomes a bind
    placeholder.  When the clause has its own bind parameter of that name,
    that parameter is compiled; otherwise a bare placeholder for ``name``
    is emitted.  Finally, backslash-escaped ``:name`` markers are
    un-escaped to a literal ``:name``.
    """
    if textclause.typemap is not None:
        for colname, type_ in textclause.typemap.iteritems():
            self.result_map[colname.lower()] = (colname, None, type_)

    def render_bind(match):
        key = match.group(1)
        if key not in textclause.bindparams:
            return self.bindparam_string(key)
        return self.process(textclause.bindparams[key])

    sql_text = self.post_process_text(textclause.text)
    sql_text = BIND_PARAMS.sub(render_bind, sql_text)
    # un-escape any \:params
    return BIND_PARAMS_ESC.sub(lambda match: match.group(1), sql_text)
---|
| 338 | |
---|
def visit_null(self, null, **kwargs):
    """Render the SQL ``NULL`` keyword."""
    keyword = 'NULL'
    return keyword
---|
| 341 | |
---|
def visit_clauselist(self, clauselist, **kwargs):
    """Render the members of a ClauseList joined by its operator.

    The separator depends on the operator:

    * ``None``: a single space;
    * ``comma_op``: ``', '``;
    * any other operator: rendered through ``operator_string`` and padded
      with spaces.

    Every member is compiled.  Members that render to ``None`` or to an
    empty string are left out of the join, so, for example, an empty nested
    ``and_()`` does not leave a dangling ``"x AND "`` behind.
    """
    sep = clauselist.operator
    if sep is None:
        sep = " "
    elif sep is operators.comma_op:
        sep = ', '
    else:
        sep = " " + self.operator_string(clauselist.operator) + " "
    # compile every member (bind registration happens during compilation),
    # then keep only the non-empty renderings
    return sep.join(s for s in (self.process(c) for c in clauselist.clauses)
                    if s)
---|
| 352 | |
---|
def visit_case(self, clause, **kwargs):
    """Render ``CASE [value] WHEN ... THEN ... [ELSE ...] END``.

    The value and ELSE parts are included only when truthy.
    """
    parts = ["CASE"]
    if clause.value:
        parts.append(self.process(clause.value))
    for cond, result in clause.whens:
        parts.extend(["WHEN", self.process(cond), "THEN", self.process(result)])
    if clause.else_:
        parts.extend(["ELSE", self.process(clause.else_)])
    parts.append("END")
    return " ".join(parts)
---|
| 363 | |
---|
def visit_cast(self, cast, **kwargs):
    """Render ``CAST(<expression> AS <type>)``."""
    expr = self.process(cast.clause)
    type_spec = self.process(cast.typeclause)
    return "CAST(%s AS %s)" % (expr, type_spec)
---|
| 366 | |
---|
def visit_extract(self, extract, **kwargs):
    """Render ``EXTRACT(<field> FROM <expr>)``, translating the field through ``extract_map``.

    Fields missing from the map are rendered unchanged.
    """
    field = extract.field
    mapped = self.extract_map.get(field, field)
    return "EXTRACT(%s FROM %s)" % (mapped, self.process(extract.expr))
---|
| 370 | |
---|
def visit_function(self, func, result_map=None, **kwargs):
    """Render a SQL function call.

    ``function_string`` supplies either a callable or a %-template:

    * a callable receives the compiled arguments and returns the SQL;
    * a template is prefixed with the function's package names (dotted)
      and filled with ``{'expr': <argspec>}``.

    When ``result_map`` is given, the function is registered there under
    its lower-cased name.
    """
    if result_map is not None:
        result_map[func.name.lower()] = (func.name, None, func.type)

    template = self.function_string(func)
    if not util.callable(template):
        qualified = ".".join(func.packagenames + [template])
        return qualified % {'expr': self.function_argspec(func)}
    return template(*[self.process(arg) for arg in func.clauses])
---|
| 381 | |
---|
def function_argspec(self, func, **kwargs):
    """Render the argument list of ``func`` (its ``clause_expr``), forwarding keywords to process()."""
    argument_clause = func.clause_expr
    return self.process(argument_clause, **kwargs)
---|
| 384 | |
---|
def function_string(self, func):
    """Return the rendering template (or callable) for ``func``.

    The function's class is looked up in ``self.functions`` first, then its
    name.  If neither is found, ``<name>%(expr)s`` is returned.
    """
    fallback = func.name + "%(expr)s"
    by_name = self.functions.get(func.name, fallback)
    return self.functions.get(func.__class__, by_name)
---|
| 387 | |
---|
def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
    """Render a compound SELECT such as ``a UNION b``.

    Each member SELECT is compiled with ``parens=False`` and its position as
    ``compound_index``; the results are joined by ``cs.keyword``.  The
    compound's own GROUP BY, ORDER BY and LIMIT/OFFSET follow.

    While the members compile, a wrapper entry (``iswrapper=True``,
    carrying the enclosing entry's FROM objects) sits on ``self.stack``;
    ``visit_select`` reads both keys.  The result is parenthesized when
    ``asfrom`` and ``parens`` are both true.
    """
    entry = self.stack and self.stack[-1] or {}
    self.stack.append({'from': entry.get('from', None), 'iswrapper': True})

    # str.join rather than the deprecated string.join() module function,
    # which no longer exists in Python 3
    keyword = " " + cs.keyword + " "
    text = keyword.join(
        self.process(c, asfrom=asfrom, parens=False, compound_index=i)
        for i, c in enumerate(cs.selects))

    group_by = self.process(cs._group_by_clause, asfrom=asfrom)
    if group_by:
        text += " GROUP BY " + group_by

    text += self.order_by_clause(cs)
    if cs._limit is not None or cs._offset is not None:
        text += self.limit_clause(cs)

    self.stack.pop(-1)
    if asfrom and parens:
        return "(" + text + ")"
    return text
---|
| 407 | |
---|
def visit_unary(self, unary, **kw):
    """Render a unary expression as ``[operator ]element[ modifier]``.

    Keyword arguments are forwarded to the element's compilation.
    """
    text = self.process(unary.element, **kw)
    operator, modifier = unary.operator, unary.modifier
    if operator:
        text = self.operator_string(operator) + " " + text
    if modifier:
        text = text + " " + self.operator_string(modifier)
    return text
---|
| 415 | |
---|
def visit_binary(self, binary, **kwargs):
    """Render a binary expression.

    A string operator produces ``left <op> right``.  A callable operator
    rendering is instead called with the compiled left/right operands plus
    ``binary.modifiers`` (e.g. ``escape`` for the LIKE family).
    """
    op = self.operator_string(binary.operator)
    left = self.process(binary.left)
    right = self.process(binary.right)
    if util.callable(op):
        return op(left, right, **binary.modifiers)
    return " ".join((left, op, right))
---|
| 422 | |
---|
def operator_string(self, operator):
    """Return the SQL rendering of ``operator`` from ``self.operators``.

    Operators that are not in the map fall back to ``str(operator)``.
    """
    fallback = str(operator)
    return self.operators.get(operator, fallback)
---|
| 425 | |
---|
def visit_bindparam(self, bindparam, **kwargs):
    """Register ``bindparam`` under its compiled name and return its placeholder.

    Raises ``CompileError`` when a different bind parameter is already
    registered under the same compiled name and either of the two is unique.
    """
    name = self._truncate_bindparam(bindparam)
    existing = self.binds.get(name, bindparam)
    if existing is not bindparam and (existing.unique or bindparam.unique):
        raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
    self.binds[bindparam.key] = bindparam
    self.binds[name] = bindparam
    return self.bindparam_string(name)
---|
| 434 | |
---|
def _truncate_bindparam(self, bindparam):
    """Return the compiled name of ``bindparam``, memoized in ``self.bind_names``.

    Generated (anonymous) keys go through ``_truncated_identifier``; other
    keys are used as-is.
    """
    names = self.bind_names
    if bindparam in names:
        return names[bindparam]

    key = bindparam.key
    if isinstance(key, sql._generated_label):
        key = self._truncated_identifier("bindparam", key)
    # remember the translation so construct_params can map values back
    names[bindparam] = key
    return key
---|
| 446 | |
---|
| 447 | def _truncated_identifier(self, ident_class, name): |
---|
| 448 | if (ident_class, name) in self.truncated_names: |
---|
| 449 | return self.truncated_names[(ident_class, name)] |
---|
| 450 | |
---|
| 451 | anonname = name % self.anon_map |
---|
| 452 | |
---|
| 453 | if len(anonname) > self.label_length: |
---|
| 454 | counter = self.truncated_names.get(ident_class, 1) |
---|
| 455 | truncname = anonname[0:max(self.label_length - 6, 0)] + "_" + hex(counter)[2:] |
---|
| 456 | self.truncated_names[ident_class] = counter + 1 |
---|
| 457 | else: |
---|
| 458 | truncname = anonname |
---|
| 459 | self.truncated_names[(ident_class, name)] = truncname |
---|
| 460 | return truncname |
---|
| 461 | |
---|
| 462 | def _anonymize(self, name): |
---|
| 463 | return name % self.anon_map |
---|
| 464 | |
---|
| 465 | def _process_anon(self, key): |
---|
| 466 | (ident, derived) = key.split(' ', 1) |
---|
| 467 | anonymous_counter = self.anon_map.get(derived, 1) |
---|
| 468 | self.anon_map[derived] = anonymous_counter + 1 |
---|
| 469 | return derived + "_" + str(anonymous_counter) |
---|
| 470 | |
---|
def bindparam_string(self, name):
    """Return the placeholder for bind parameter ``name`` in the dialect's paramstyle.

    For positional paramstyles the name is also appended to
    ``self.positiontup``.  Its 1-based position is then available to the
    template as ``position``.
    """
    if not self.positional:
        return self.bindtemplate % {'name': name}
    self.positiontup.append(name)
    return self.bindtemplate % {'name': name, 'position': len(self.positiontup)}
---|
| 477 | |
---|
def visit_alias(self, alias, asfrom=False, **kwargs):
    """Render an alias.

    As a FROM object it renders ``<original> AS <alias name>``, with
    generated alias names truncated.  Elsewhere only the aliased selectable
    is rendered.  Keyword arguments are forwarded in both cases.
    """
    if not asfrom:
        return self.process(alias.original, **kwargs)

    alias_name = alias.name
    if isinstance(alias_name, sql._generated_label):
        alias_name = self._truncated_identifier("alias", alias_name)

    return self.process(alias.original, asfrom=True, **kwargs) + " AS " + \
        self.preparer.format_alias(alias, alias_name)
---|
| 487 | |
---|
def label_select_column(self, select, column, asfrom):
    """label columns present in a select().

    The rules are applied in this order:

    * existing labels are returned unchanged;
    * with ``select.use_labels``, a column carrying a ``_label`` is labeled
      with it;
    * when rendering as a FROM object, a non-literal ColumnClause bound to
      a table that is not a SELECT gets a generated label of its own name;
    * other expressions with no ``name`` attribute, and all functions,
      receive their ``anon_label``; unary expressions, text and bind
      parameters are exempt;
    * anything else is returned unchanged.
    """
    if isinstance(column, sql._Label):
        return column

    if select.use_labels and column._label:
        return _CompileLabel(column, column._label)

    needs_own_name = (asfrom and
                      isinstance(column, sql.ColumnClause) and
                      not column.is_literal and
                      column.table is not None and
                      not isinstance(column.table, sql.Select))
    if needs_own_name:
        return _CompileLabel(column, sql._generated_label(column.name))

    if isinstance(column, (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)):
        return column
    if hasattr(column, 'name') and not isinstance(column, sql.Function):
        return column
    return _CompileLabel(column, column.anon_label)
---|
| 509 | |
---|
def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, compound_index=1, **kwargs):
    """Render a SELECT statement.

    asfrom
      true when the SELECT is rendered as a FROM object.  Columns are then
      labeled via ``label_select_column``, and the result is parenthesized
      if ``parens`` is also true.

    iswrapper
      stored on the stack entry pushed for this SELECT.  A SELECT compiled
      directly inside a wrapper entry populates ``self.result_map``.

    compound_index
      position of this SELECT within an enclosing compound SELECT.

    Parts are compiled in this order: column list, prefixes, FROM, WHERE,
    GROUP BY, HAVING, ORDER BY, LIMIT/OFFSET, FOR UPDATE.  The order
    matters for positional paramstyles, because placeholders are numbered
    in the order they are compiled (see ``bindparam_string``).

    NOTE(review): prefixes are compiled after the column list even though
    they are rendered before it.  Positional binds inside prefixes would
    therefore be numbered out of textual order; confirm this is intended.
    """
    parent = self.stack and self.stack[-1] or {}
    existingfroms = parent.get('from', None)

    froms = select._get_display_froms(existingfroms)
    correlate_froms = set(sql._from_objects(*froms))

    # TODO: might want to propagate existing froms for select(select(select))
    # where the innermost select should correlate to the outermost, i.e.
    # union correlate_froms with existingfroms when the latter is present.

    self.stack.append({'from': correlate_froms, 'iswrapper': iswrapper})

    # The outermost SELECT, or one directly inside a wrapper entry,
    # contributes its columns to self.result_map.
    # NOTE(review): visit_compound_select passes compound_index values
    # starting at 0.  The "compound_index == 1" test therefore only matters
    # for a top-level SELECT (default 1), and every member of a compound
    # populates result_map.
    populate_result_map = (compound_index == 1 and not parent) or parent.get('iswrapper', False)
    if populate_result_map:
        column_clause_args = {'result_map': self.result_map}
    else:
        column_clause_args = {}

    # the actual list of columns to print in the SELECT column list
    inner_columns = []
    for co in util.unique_list(select.inner_columns):
        rendered = self.process(
            self.label_select_column(select, co, asfrom=asfrom),
            within_columns_clause=True,
            **column_clause_args)
        if rendered is not None:
            inner_columns.append(rendered)

    text = "SELECT "
    if select._prefixes:
        text += " ".join(self.process(x) for x in select._prefixes) + " "
    text += self.get_select_precolumns(select)
    text += ', '.join(inner_columns)

    if froms:
        text += " \nFROM " + ', '.join(self.process(f, asfrom=True) for f in froms)
    else:
        text += self.default_from()

    if select._whereclause is not None:
        where = self.process(select._whereclause)
        if where:
            text += " \nWHERE " + where

    if select._group_by_clause.clauses:
        group_by = self.process(select._group_by_clause)
        if group_by:
            text += " GROUP BY " + group_by

    if select._having is not None:
        having = self.process(select._having)
        if having:
            text += " \nHAVING " + having

    if select._order_by_clause.clauses:
        text += self.order_by_clause(select)
    if select._limit is not None or select._offset is not None:
        text += self.limit_clause(select)
    if select.for_update:
        text += self.for_update_clause(select)

    self.stack.pop(-1)

    if asfrom and parens:
        return "(" + text + ")"
    return text
---|
| 584 | |
---|
def get_select_precolumns(self, select):
    """Return the text placed just before the column list of a ``SELECT``.

    The default renders ``DISTINCT `` for a distinct SELECT and nothing
    otherwise.
    """
    if select._distinct:
        return "DISTINCT "
    return ""
---|
| 589 | |
---|
def order_by_clause(self, select):
    """Return `` ORDER BY <clauses>`` for ``select``, or ``""`` when that renders empty."""
    rendered = self.process(select._order_by_clause)
    if not rendered:
        return ""
    return " ORDER BY " + rendered
---|
| 596 | |
---|
def for_update_clause(self, select):
    """Return `` FOR UPDATE`` when ``select.for_update`` is set, else ``""``."""
    if not select.for_update:
        return ""
    return " FOR UPDATE"
---|
| 602 | |
---|
def limit_clause(self, select):
    """Render the LIMIT/OFFSET suffix of ``select``.

    An OFFSET given without a LIMIT is preceded by ``LIMIT -1``.
    """
    limit, offset = select._limit, select._offset
    parts = []
    if limit is not None:
        parts.append(" \n LIMIT " + str(limit))
    if offset is not None:
        if limit is None:
            parts.append(" \n LIMIT -1")
        parts.append(" OFFSET " + str(offset))
    return "".join(parts)
---|
| 612 | |
---|
def visit_table(self, table, asfrom=False, **kwargs):
    """Render a table as a FROM object: ``[schema.]name``, each part quoted as needed.

    Outside a FROM list (``asfrom`` false) the table renders as ``""``.
    """
    if not asfrom:
        return ""
    schema = getattr(table, "schema", None)
    if schema:
        prefix = self.preparer.quote_schema(schema, table.quote_schema) + "."
    else:
        prefix = ""
    return prefix + self.preparer.quote(table.name, table.quote)
---|
| 621 | |
---|
def visit_join(self, join, asfrom=False, **kwargs):
    """Render ``<left> [LEFT OUTER ]JOIN <right> ON <onclause>``.

    Both sides are always compiled as FROM objects.  ``asfrom`` is accepted
    for signature compatibility but not used.
    """
    left = self.process(join.left, asfrom=True)
    if join.isouter:
        keyword = " LEFT OUTER JOIN "
    else:
        keyword = " JOIN "
    right = self.process(join.right, asfrom=True)
    onclause = self.process(join.onclause)
    return left + keyword + right + " ON " + onclause
---|
| 625 | |
---|
def visit_sequence(self, seq):
    """Render a Sequence; by default it contributes no SQL and yields ``None``.

    Callers such as ``_get_colparams`` skip a ``None`` rendering.
    """
    return
---|
| 628 | |
---|
def visit_insert(self, insert_stmt):
    """Render an INSERT statement.

    The normal output is
    ``INSERT [prefixes] INTO <table> (<columns>) VALUES (<values>)``,
    built from the column/value pairs returned by ``_get_colparams``.
    When there are no pairs:

    * ``DEFAULT VALUES`` is rendered if the dialect supports it;
    * otherwise empty column/value lists are rendered if the dialect
      supports empty inserts;
    * otherwise ``CompileError`` is raised.
    """
    self.isinsert = True
    colparams = self._get_colparams(insert_stmt)
    preparer = self.preparer
    dialect = self.dialect

    words = ["INSERT"]
    words.extend([self.process(prefix) for prefix in insert_stmt._prefixes])
    insert = ' '.join(words)

    if not colparams:
        if dialect.supports_default_values:
            return insert + " INTO %s DEFAULT VALUES" % (preparer.format_table(insert_stmt.table),)
        if not dialect.supports_empty_insert:
            raise exc.CompileError(
                "The version of %s you are using does not support empty inserts." % dialect.name)

    table = preparer.format_table(insert_stmt.table)
    column_list = ', '.join([preparer.format_column(c[0]) for c in colparams])
    value_list = ', '.join([c[1] for c in colparams])
    return insert + " INTO %s (%s) VALUES (%s)" % (table, column_list, value_list)
---|
| 649 | |
---|
def visit_update(self, update_stmt):
    """Render an UPDATE statement: ``UPDATE <table> SET col=value, ... [WHERE ...]``.

    The target table is pushed on ``self.stack`` as the FROM context while
    the statement compiles.  Presumably this lets nested SELECTs correlate
    against it.

    The WHERE clause is emitted only when it is present (``is not None``)
    and renders to non-empty text.  This is the same treatment
    ``visit_select`` gives its WHERE clause.
    """
    self.stack.append({'from': set([update_stmt.table])})

    self.isupdate = True
    colparams = self._get_colparams(update_stmt)

    text = ' '.join((
        "UPDATE",
        self.preparer.format_table(update_stmt.table),
        'SET',
        ', '.join(self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
                  for c in colparams)
    ))

    # Test against None rather than truthiness.  A clause object may define
    # its own boolean value, which says nothing about whether a WHERE clause
    # exists.  Also, an empty rendering must not leave a dangling " WHERE ".
    if update_stmt._whereclause is not None:
        where = self.process(update_stmt._whereclause)
        if where:
            text += " WHERE " + where

    self.stack.pop(-1)

    return text
---|
| 670 | |
---|
def _get_colparams(self, stmt):
    """create a set of tuples representing column/string pairs for use
    in an INSERT or UPDATE statement.

    Returns a list of ``(column, rendered_sql)`` tuples, one per column
    that should appear in the SET / VALUES clause.  ``rendered_sql`` is
    either a bind-parameter placeholder string or the compiled text of a
    SQL expression.

    Side effects on the compiler:

    * ``self.binds`` -- every bind parameter created here is registered
      under the column's key.
    * ``self.postfetch`` -- reset, then filled with columns whose value is
      rendered here as a SQL expression or left to a server-side default /
      sequence (presumably read back after execution -- confirm against
      the execution context).
    * ``self.prefetch`` -- reset, then filled with columns rendered as a
      ``None``-valued bind parameter whose real value must be supplied
      before execution (a Python-side default, or a pre-executed primary
      key default) -- presumably computed by the execution context; verify.

    """

    def create_bind_param(col, value):
        # Build a typed bind parameter keyed on the column, register it in
        # self.binds, and return its placeholder text (name possibly
        # truncated by _truncate_bindparam).
        bindparam = sql.bindparam(col.key, value, type_=col.type)
        self.binds[col.key] = bindparam
        return self.bindparam_string(self._truncate_bindparam(bindparam))

    self.postfetch = []
    self.prefetch = []

    # no parameters in the statement, no parameters in the
    # compiled params - return binds for all columns
    # (postfetch/prefetch stay empty on this path)
    if self.column_keys is None and stmt.parameters is None:
        return [(c, create_bind_param(c, None)) for c in stmt.table.columns]

    # if we have statement parameters - set defaults in the
    # compiled params
    if self.column_keys is None:
        parameters = {}
    else:
        parameters = dict((sql._column_as_key(key), None)
                          for key in self.column_keys)

    if stmt.parameters is not None:
        # setdefault: a key already present from column_keys keeps its
        # None placeholder, so it is rendered as a plain bind parameter
        # even when stmt.parameters also supplies a value for it.
        for k, v in stmt.parameters.iteritems():
            parameters.setdefault(sql._column_as_key(k), v)

    # create a list of column assignment clauses as tuples
    values = []
    for c in stmt.table.columns:
        if c.key in parameters:
            # Explicitly supplied value: literals become bind parameters,
            # anything else is compiled inline as a SQL expression.
            value = parameters[c.key]
            if sql._is_literal(value):
                value = create_bind_param(c, value)
            else:
                self.postfetch.append(c)
                value = self.process(value.self_group())
            values.append((c, value))
        elif isinstance(c, schema.Column):
            # No value supplied: consult the column's defaults.
            if self.isinsert:
                if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline):
                    # Primary key that must be known before the INSERT runs:
                    # a non-optional Sequence, a dialect without pk
                    # autoincrement, or any non-Sequence default.
                    if (((isinstance(c.default, schema.Sequence) and
                          not c.default.optional) or
                         not self.dialect.supports_pk_autoincrement) or
                        (c.default is not None and
                         not isinstance(c.default, schema.Sequence))):
                        values.append((c, create_bind_param(c, None)))
                        self.prefetch.append(c)
                elif isinstance(c.default, schema.ColumnDefault):
                    if isinstance(c.default.arg, sql.ClauseElement):
                        # SQL-expression default: render it inline.
                        values.append((c, self.process(c.default.arg.self_group())))
                        if not c.primary_key:
                            # dont add primary key column to postfetch
                            self.postfetch.append(c)
                    else:
                        # Python-side default: placeholder bind, value supplied later.
                        values.append((c, create_bind_param(c, None)))
                        self.prefetch.append(c)
                elif c.server_default is not None:
                    # Column omitted from the statement; the server fills it in.
                    if not c.primary_key:
                        self.postfetch.append(c)
                elif isinstance(c.default, schema.Sequence):
                    # The dialect may render the sequence inline, or return
                    # None to leave the column out of the statement.
                    proc = self.process(c.default)
                    if proc is not None:
                        values.append((c, proc))
                        if not c.primary_key:
                            self.postfetch.append(c)
            elif self.isupdate:
                if isinstance(c.onupdate, schema.ColumnDefault):
                    if isinstance(c.onupdate.arg, sql.ClauseElement):
                        values.append((c, self.process(c.onupdate.arg.self_group())))
                        self.postfetch.append(c)
                    else:
                        values.append((c, create_bind_param(c, None)))
                        self.prefetch.append(c)
                elif c.server_onupdate is not None:
                    self.postfetch.append(c)
                # deprecated? or remove?
                elif isinstance(c.onupdate, schema.FetchedValue):
                    self.postfetch.append(c)
    return values
---|
| 755 | |
---|
def visit_delete(self, delete_stmt):
    """Render ``DELETE FROM <table> [WHERE <criterion>]``.

    While the WHERE clause is compiled, the target table is pushed onto
    ``self.stack`` as the current ``'from'`` set, and the frame is popped
    afterwards.  Also flags the compiler with ``self.isdelete = True``.
    The WHERE clause is emitted only when ``_whereclause`` is truthy, so a
    missing or empty criterion produces no dangling ``WHERE``.
    """
    target = delete_stmt.table
    self.stack.append({'from': set([target])})
    self.isdelete = True

    pieces = ["DELETE FROM ", self.preparer.format_table(target)]

    criterion = delete_stmt._whereclause
    if criterion:
        pieces.append(" WHERE ")
        pieces.append(self.process(criterion))

    self.stack.pop(-1)
    return "".join(pieces)
---|
| 768 | |
---|
def visit_savepoint(self, savepoint_stmt):
    """Render ``SAVEPOINT <name>``, with the name formatted by the preparer."""
    name = self.preparer.format_savepoint(savepoint_stmt)
    return "SAVEPOINT %s" % (name,)
---|
| 771 | |
---|
def visit_rollback_to_savepoint(self, savepoint_stmt):
    """Render ``ROLLBACK TO SAVEPOINT <name>``, with the name formatted by the preparer."""
    target = self.preparer.format_savepoint(savepoint_stmt)
    return "ROLLBACK TO SAVEPOINT %s" % (target,)
---|
| 774 | |
---|
def visit_release_savepoint(self, savepoint_stmt):
    """Render ``RELEASE SAVEPOINT <name>``, with the name formatted by the preparer."""
    target = self.preparer.format_savepoint(savepoint_stmt)
    return "RELEASE SAVEPOINT %s" % (target,)
---|
| 777 | |
---|
| 778 | def __str__(self): |
---|
| 779 | return self.string or '' |
---|
| 780 | |
---|
class DDLBase(engine.SchemaIterator):
    """Helpers shared by the DDL generators (SchemaGenerator / SchemaDropper)."""

    def find_alterables(self, tables):
        """Return the ``use_alter`` foreign key constraints belonging to *tables*.

        Every constraint of every table in *tables* is traversed with a
        ``schema.SchemaVisitor``.  A foreign key constraint is collected when
        its ``use_alter`` flag is set and its own table is one of *tables*.
        Constraints are returned in traversal order.
        """
        found = []

        class _UseAlterCollector(schema.SchemaVisitor):
            def visit_foreign_key_constraint(self, constraint):
                if constraint.use_alter and constraint.table in tables:
                    found.append(constraint)

        collector = _UseAlterCollector()
        for tbl in tables:
            for cons in tbl.constraints:
                collector.traverse(cons)
        return found

    def _validate_identifier(self, ident, truncate):
        """Return an identifier usable by the dialect.

        With *truncate* false, the identifier is passed to
        ``dialect.validate_identifier`` and returned unchanged.  With
        *truncate* true, an identifier longer than
        ``dialect.max_identifier_length`` is cut to ``max - 6`` characters,
        and ``"_"`` plus the hex digits of a per-instance counter
        (``self.counter``, starting at 1) are appended.  Shorter
        identifiers are returned as-is.
        """
        if not truncate:
            self.dialect.validate_identifier(ident)
            return ident

        limit = self.dialect.max_identifier_length
        if len(ident) <= limit:
            return ident

        self.counter = getattr(self, 'counter', 0) + 1
        return ident[0:limit - 6] + "_" + hex(self.counter)[2:]
---|
| 805 | |
---|
| 806 | |
---|
class SchemaGenerator(DDLBase):
    """Issue CREATE DDL for tables and their indexes.

    ``visit_metadata`` creates each target table in the order returned by
    ``sql_util.sort_tables`` (presumably foreign-key dependency order).  When
    ``checkfirst`` is set, tables the dialect reports as existing are
    skipped.  When the dialect ``supports_alter``, ``use_alter`` foreign
    keys are then added with ``ALTER TABLE ... ADD``.  Dialects must
    implement ``get_column_specification``.
    """

    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
        super(SchemaGenerator, self).__init__(connection, **kwargs)
        # when true, tables reported by dialect.has_table() are skipped
        self.checkfirst = checkfirst
        # optional explicit subset of tables; None means "all tables in the MetaData"
        self.tables = tables and set(tables) or None
        self.preparer = dialect.identifier_preparer
        self.dialect = dialect

    def get_column_specification(self, column, first_pk=False):
        """Return the DDL fragment for *column*; dialects must override this."""
        raise NotImplementedError()

    def _can_create(self, table):
        """Validate the table/schema names; return False if ``checkfirst`` finds the table already present."""
        self.dialect.validate_identifier(table.name)
        if table.schema:
            self.dialect.validate_identifier(table.schema)
        return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)

    def visit_metadata(self, metadata):
        """Create the selected tables, then add deferred (``use_alter``) foreign keys."""
        if self.tables:
            tables = self.tables
        else:
            tables = metadata.tables.values()
        collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)]
        for table in collection:
            self.traverse_single(table)
        if self.dialect.supports_alter:
            for alterable in self.find_alterables(collection):
                self.add_foreignkey(alterable)

    def visit_table(self, table):
        """Emit CREATE TABLE for *table* (plus column defaults and indexes).

        'before-create' and 'after-create' DDL listeners are fired around
        the statement.
        """
        for listener in table.ddl_listeners['before-create']:
            listener('before-create', table, self.connection)

        for column in table.columns:
            if column.default is not None:
                self.traverse_single(column.default)

        self.append("\n" + " ".join(['CREATE'] +
                                    table._prefixes +
                                    ['TABLE',
                                     self.preparer.format_table(table),
                                     "("]))
        separator = "\n"

        # if only one primary key, specify it along with the column
        first_pk = False
        for column in table.columns:
            self.append(separator)
            separator = ", \n"
            self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk))
            if column.primary_key:
                first_pk = True
            for constraint in column.constraints:
                self.traverse_single(constraint)

        # On some DB order is significant: visit PK first, then the
        # other constraints (engine.ReflectionTest.testbasic failed on FB2)
        if table.primary_key:
            self.traverse_single(table.primary_key)
        for constraint in [c for c in table.constraints if c is not table.primary_key]:
            self.traverse_single(constraint)

        self.append("\n)%s\n\n" % self.post_create_table(table))
        self.execute()

        if hasattr(table, 'indexes'):
            for index in table.indexes:
                self.traverse_single(index)

        for listener in table.ddl_listeners['after-create']:
            listener('after-create', table, self.connection)

    def post_create_table(self, table):
        """Return text appended after the closing parenthesis of CREATE TABLE (none by default)."""
        return ''

    def get_column_default_string(self, column):
        """Return the DDL text of *column*'s server default, or None.

        A plain string default is rendered as a SQL string literal.
        Embedded single quotes are doubled so the literal stays well formed
        (``O'Brien`` renders as ``'O''Brien'``).  Any other default argument
        is compiled with this dialect's statement compiler.
        """
        if isinstance(column.server_default, schema.DefaultClause):
            arg = column.server_default.arg
            if isinstance(arg, basestring):
                # standard SQL escaping: a quote inside a literal is written twice
                return "'%s'" % arg.replace("'", "''")
            else:
                return unicode(self._compile(arg, None))
        else:
            return None

    def _compile(self, tocompile, parameters):
        """compile the given string/parameters using this SchemaGenerator's dialect."""
        compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
        compiler.compile()
        return compiler

    def visit_check_constraint(self, constraint):
        """Append a table-level (optionally named) CHECK constraint."""
        self.append(", \n\t")
        if constraint.name is not None:
            self.append("CONSTRAINT %s " %
                        self.preparer.format_constraint(constraint))
        self.append(" CHECK (%s)" % constraint.sqltext)
        self.define_constraint_deferrability(constraint)

    def visit_column_check_constraint(self, constraint):
        """Append an inline, column-level CHECK constraint."""
        self.append(" CHECK (%s)" % constraint.sqltext)
        self.define_constraint_deferrability(constraint)

    def visit_primary_key_constraint(self, constraint):
        """Append the PRIMARY KEY clause; an empty constraint emits nothing."""
        if len(constraint) == 0:
            return
        self.append(", \n\t")
        if constraint.name is not None:
            self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
        self.append("PRIMARY KEY ")
        self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
                                       for c in constraint))
        self.define_constraint_deferrability(constraint)

    def visit_foreign_key_constraint(self, constraint):
        """Append an inline FOREIGN KEY clause.

        ``use_alter`` constraints are skipped when the dialect supports
        ALTER; ``visit_metadata`` adds those after all tables exist.
        """
        if constraint.use_alter and self.dialect.supports_alter:
            return
        self.append(", \n\t ")
        self.define_foreign_key(constraint)

    def add_foreignkey(self, constraint):
        """Emit ``ALTER TABLE ... ADD`` for a deferred foreign key constraint."""
        self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table))
        self.define_foreign_key(constraint)
        self.execute()

    def define_foreign_key(self, constraint):
        """Append the ``[CONSTRAINT name] FOREIGN KEY(...) REFERENCES ...`` text.

        The referenced table is taken from the first element's target
        column.
        """
        preparer = self.preparer
        if constraint.name is not None:
            self.append("CONSTRAINT %s " %
                        preparer.format_constraint(constraint))
        table = list(constraint.elements)[0].column.table
        self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
            ', '.join(preparer.quote(f.parent.name, f.parent.quote)
                      for f in constraint.elements),
            preparer.format_table(table),
            ', '.join(preparer.quote(f.column.name, f.column.quote)
                      for f in constraint.elements)
        ))
        if constraint.ondelete is not None:
            self.append(" ON DELETE %s" % constraint.ondelete)
        if constraint.onupdate is not None:
            self.append(" ON UPDATE %s" % constraint.onupdate)
        self.define_constraint_deferrability(constraint)

    def visit_unique_constraint(self, constraint):
        """Append a table-level (optionally named) UNIQUE constraint."""
        self.append(", \n\t")
        if constraint.name is not None:
            self.append("CONSTRAINT %s " %
                        self.preparer.format_constraint(constraint))
        self.append(" UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)))
        self.define_constraint_deferrability(constraint)

    def define_constraint_deferrability(self, constraint):
        """Append ``[NOT] DEFERRABLE`` / ``INITIALLY ...`` when those options are set."""
        if constraint.deferrable is not None:
            if constraint.deferrable:
                self.append(" DEFERRABLE")
            else:
                self.append(" NOT DEFERRABLE")
        if constraint.initially is not None:
            self.append(" INITIALLY %s" % constraint.initially)

    def visit_column(self, column):
        pass

    def visit_index(self, index):
        """Emit ``CREATE [UNIQUE] INDEX``; over-long index names are truncated."""
        preparer = self.preparer
        self.append("CREATE ")
        if index.unique:
            self.append("UNIQUE ")
        self.append("INDEX %s ON %s (%s)" \
                    % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
                       preparer.format_table(index.table),
                       ', '.join(preparer.quote(c.name, c.quote)
                                 for c in index.columns)))
        self.execute()
---|
| 981 | |
---|
| 982 | |
---|
class SchemaDropper(DDLBase):
    """Issue DROP DDL for tables and indexes.

    Tables are dropped in the reverse of ``sql_util.sort_tables`` order.
    When ``checkfirst`` is set, only tables the dialect reports as existing
    are dropped.  When the dialect ``supports_alter``, ``use_alter`` foreign
    keys are removed first with ``ALTER TABLE ... DROP CONSTRAINT``.
    """

    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
        super(SchemaDropper, self).__init__(connection, **kwargs)
        self.dialect = dialect
        self.preparer = dialect.identifier_preparer
        # optional explicit subset of tables; falsy means "all tables in the MetaData"
        self.tables = tables
        # when true, only tables reported by dialect.has_table() are dropped
        self.checkfirst = checkfirst

    def visit_metadata(self, metadata):
        """Drop deferred foreign keys first, then each selected table."""
        source = self.tables or metadata.tables.values()
        droppable = [tbl for tbl in reversed(sql_util.sort_tables(source))
                     if self._can_drop(tbl)]

        if self.dialect.supports_alter:
            for fk in self.find_alterables(droppable):
                self.drop_foreignkey(fk)

        for tbl in droppable:
            self.traverse_single(tbl)

    def _can_drop(self, table):
        """Validate the table/schema names; with ``checkfirst``, require the table to exist."""
        self.dialect.validate_identifier(table.name)
        if table.schema:
            self.dialect.validate_identifier(table.schema)
        if not self.checkfirst:
            return True
        return self.dialect.has_table(self.connection, table.name, schema=table.schema)

    def visit_index(self, index):
        """Emit ``DROP INDEX`` for *index* (the name is validated, not truncated)."""
        checked_name = self._validate_identifier(index.name, False)
        self.append("\nDROP INDEX " + self.preparer.quote(checked_name, index.quote))
        self.execute()

    def drop_foreignkey(self, constraint):
        """Emit ``ALTER TABLE ... DROP CONSTRAINT`` for a deferred foreign key."""
        prep = self.preparer
        table_sql = prep.format_table(constraint.table)
        constraint_sql = prep.format_constraint(constraint)
        self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (table_sql, constraint_sql))
        self.execute()

    def visit_table(self, table):
        """Emit ``DROP TABLE``.

        Column defaults are traversed first, and the 'before-drop' /
        'after-drop' DDL listeners are fired around the statement.
        """
        for hook in table.ddl_listeners['before-drop']:
            hook('before-drop', table, self.connection)

        for col in table.columns:
            if col.default is None:
                continue
            self.traverse_single(col.default)

        self.append("\nDROP TABLE " + self.preparer.format_table(table))
        self.execute()

        for hook in table.ddl_listeners['after-drop']:
            hook('after-drop', table, self.connection)
---|
| 1032 | |
---|
| 1033 | |
---|
class IdentifierPreparer(object):
    """Handle quoting and case-folding of identifiers based on options."""

    reserved_words = RESERVED_WORDS

    legal_characters = LEGAL_CHARACTERS

    illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS

    def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False):
        """Construct a new ``IdentifierPreparer`` object.

        initial_quote
          Character that begins a delimited identifier.

        final_quote
          Character that ends a delimited identifier. Defaults to `initial_quote`.

        omit_schema
          Prevent prepending schema name. Useful for databases that do
          not support schemae.
        """
        self.dialect = dialect
        self.initial_quote = initial_quote
        self.final_quote = final_quote or initial_quote
        self.omit_schema = omit_schema
        # memo of identifier -> rendered text, used by quote() when force is None
        self._strings = {}

    def _escape_identifier(self, value):
        """Escape an identifier by doubling every ``"``.

        Subclasses should override this to provide database-dependent
        escaping behavior.
        """
        escaped = value.replace('"', '""')
        return escaped

    def _unescape_identifier(self, value):
        """Collapse doubled ``""`` back to ``"`` (the inverse of _escape_identifier).

        Subclasses should override this to provide database-dependent
        unescaping behavior that reverses _escape_identifier.
        """
        unescaped = value.replace('""', '"')
        return unescaped

    def quote_identifier(self, value):
        """Wrap the escaped identifier in the initial/final quote characters.

        Subclasses should override this to provide database-dependent
        quoting behavior.
        """
        return ''.join([self.initial_quote,
                        self._escape_identifier(value),
                        self.final_quote])

    def _requires_quotes(self, value):
        """Return a truthy value if *value* must be quoted.

        Quoting is required when the identifier is a reserved word, starts
        with an illegal initial character, contains characters outside
        ``legal_characters``, or is not entirely lower case.
        """
        lowered = value.lower()
        return (lowered in self.reserved_words
                or self.illegal_initial_characters.match(value[0])
                or not self.legal_characters.match(unicode(value))
                or lowered != value)

    def quote_schema(self, schema, force):
        """Quote a schema name; delegates to quote().

        Subclasses should override this to provide database-dependent
        quoting behavior.
        """
        return self.quote(schema, force)

    def quote(self, ident, force):
        """Return *ident*, quoted according to *force*.

        ``force`` truthy always quotes, and a non-None falsy ``force`` never
        quotes.  ``force=None`` quotes only when _requires_quotes() says so,
        and memoizes the result in ``self._strings``.
        """
        if force is not None:
            if force:
                return self.quote_identifier(ident)
            return ident

        memo = self._strings
        if ident not in memo:
            if self._requires_quotes(ident):
                memo[ident] = self.quote_identifier(ident)
            else:
                memo[ident] = ident
        return memo[ident]

    def format_sequence(self, sequence, use_schema=True):
        """Return the sequence name, schema-qualified when applicable."""
        rendered = self.quote(sequence.name, sequence.quote)
        qualify = not self.omit_schema and use_schema and sequence.schema is not None
        if qualify:
            rendered = self.quote_schema(sequence.schema, sequence.quote) + "." + rendered
        return rendered

    def format_label(self, label, name=None):
        """Quote *name* (or, if falsy, ``label.name``) honoring ``label.quote``."""
        chosen = name or label.name
        return self.quote(chosen, label.quote)

    def format_alias(self, alias, name=None):
        """Quote *name* (or, if falsy, ``alias.name``) honoring ``alias.quote``."""
        chosen = name or alias.name
        return self.quote(chosen, alias.quote)

    def format_savepoint(self, savepoint, name=None):
        """Quote *name* (or, if falsy, ``savepoint.ident``) honoring ``savepoint.quote``."""
        chosen = name or savepoint.ident
        return self.quote(chosen, savepoint.quote)

    def format_constraint(self, constraint):
        """Quote the constraint's name honoring ``constraint.quote``."""
        return self.quote(constraint.name, constraint.quote)

    def format_table(self, table, use_schema=True, name=None):
        """Prepare a quoted table and schema name.

        *name*, when given, replaces ``table.name``.  The schema prefix is
        added only if ``use_schema`` is true, ``omit_schema`` is off and
        the table has a truthy ``schema`` attribute.
        """
        if name is None:
            name = table.name
        rendered = self.quote(name, table.quote)
        if self.omit_schema or not use_schema:
            return rendered
        if not getattr(table, "schema", None):
            return rendered
        return self.quote_schema(table.schema, table.quote_schema) + "." + rendered

    def format_column(self, column, use_table=False, name=None, table_name=None):
        """Prepare a quoted column name.

        Literal columns (``is_literal`` true) are never quoted.  With
        *use_table*, the name is prefixed by the unqualified table name;
        *table_name* overrides the table's own name.
        """
        if name is None:
            name = column.name
        literal = getattr(column, 'is_literal', False)
        if use_table:
            qualifier = self.format_table(column.table, use_schema=False, name=table_name)
        if literal:
            # literal textual elements get stuck into ColumnClause a lot; they must not be quoted
            bare = name
        else:
            bare = self.quote(name, column.quote)
        if use_table:
            return qualifier + "." + bare
        return bare

    def format_table_seq(self, table, use_schema=True):
        """Format table name and schema as a tuple.

        Returns ``(schema, table)`` when the schema applies, else ``(table,)``.
        """

        # Dialects with more levels in their fully qualified references
        # ('database', 'owner', etc.) could override this and return
        # a longer sequence.

        qualify = not self.omit_schema and use_schema and getattr(table, 'schema', None)
        if not qualify:
            return (self.format_table(table, use_schema=False), )
        return (self.quote_schema(table.schema, table.quote_schema),
                self.format_table(table, use_schema=False))

    def unformat_identifiers(self, identifiers):
        """Unpack 'schema.table.column'-like strings into components.

        Quoted components are unescaped, e.g. ``'"my schema".tbl'`` yields
        ``['my schema', 'tbl']``.  The compiled pattern is cached on the
        instance as ``_r_identifiers``.
        """
        try:
            pattern = self._r_identifiers
        except AttributeError:
            initial, final, escaped_final = [
                re.escape(s) for s in (self.initial_quote,
                                       self.final_quote,
                                       self._escape_identifier(self.final_quote))]
            pattern = re.compile(
                r'(?:'
                r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
                r'|([^\.]+))(?=\.|$))+' %
                {'initial': initial,
                 'final': final,
                 'escaped': escaped_final})
            self._r_identifiers = pattern

        components = []
        for quoted, bare in pattern.findall(identifiers):
            components.append(self._unescape_identifier(quoted or bare))
        return components
---|