| 1 | # postgres.py |
|---|
| 2 | # Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com |
|---|
| 3 | # |
|---|
| 4 | # This module is part of SQLAlchemy and is released under |
|---|
| 5 | # the MIT License: http://www.opensource.org/licenses/mit-license.php |
|---|
| 6 | |
|---|
| 7 | """Support for the PostgreSQL database. |
|---|
| 8 | |
|---|
| 9 | Driver |
|---|
| 10 | ------ |
|---|
| 11 | |
|---|
| 12 | The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ . |
|---|
| 13 | The dialect has several behaviors which are specifically tailored towards compatibility |
|---|
| 14 | with this module. |
|---|
| 15 | |
|---|
| 16 | Note that psycopg1 is **not** supported. |
|---|
| 17 | |
|---|
| 18 | Connecting |
|---|
| 19 | ---------- |
|---|
| 20 | |
|---|
| 21 | URLs are of the form `postgres://user:password@host:port/dbname[?key=value&key=value...]`. |
|---|
| 22 | |
|---|
| 23 | PostgreSQL-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are: |
|---|
| 24 | |
|---|
| 25 | * *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support |
|---|
| 26 | this feature. What this essentially means from a psycopg2 point of view is that the cursor is |
|---|
| 27 | created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows |
|---|
| 28 | are not immediately pre-fetched and buffered after statement execution, but are instead left |
|---|
| 29 | on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy` |
|---|
| 30 | uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows |
|---|
| 31 | at a time are fetched over the wire to reduce conversational overhead. |
|---|
| 32 | |
|---|
| 33 | Sequences/SERIAL |
|---|
| 34 | ---------------- |
|---|
| 35 | |
|---|
| 36 | PostgreSQL supports sequences, and SQLAlchemy uses these as the default means of creating |
|---|
| 37 | new primary key values for integer-based primary key columns. When creating tables, |
|---|
| 38 | SQLAlchemy will issue the ``SERIAL`` datatype for integer-based primary key columns, |
|---|
| 39 | which generates a sequence corresponding to the column and associated with it based on |
|---|
| 40 | a naming convention. |
|---|
| 41 | |
|---|
| 42 | To specify a specific named sequence to be used for primary key generation, use the |
|---|
| 43 | :func:`~sqlalchemy.schema.Sequence` construct:: |
|---|
| 44 | |
|---|
| 45 | Table('sometable', metadata, |
|---|
| 46 | Column('id', Integer, Sequence('some_id_seq'), primary_key=True) |
|---|
| 47 | ) |
|---|
| 48 | |
|---|
| 49 | Currently, when SQLAlchemy issues a single insert statement, to fulfill the contract of |
|---|
| 50 | having the "last insert identifier" available, the sequence is executed independently |
|---|
| 51 | beforehand and the new value is retrieved, to be used in the subsequent insert. Note |
|---|
| 52 | that when an :func:`~sqlalchemy.sql.expression.insert()` construct is executed using |
|---|
| 53 | "executemany" semantics, the sequence is not pre-executed and normal PG SERIAL behavior |
|---|
| 54 | is used. |
|---|
| 55 | |
|---|
| 56 | PostgreSQL 8.3 supports an ``INSERT...RETURNING`` syntax which SQLAlchemy supports |
|---|
| 57 | as well. A future release of SQLA will use this feature by default in lieu of |
|---|
| 58 | sequence pre-execution in order to retrieve new primary key values, when available. |
|---|
| 59 | |
|---|
| 60 | INSERT/UPDATE...RETURNING |
|---|
| 61 | ------------------------- |
|---|
| 62 | |
|---|
| 63 | The dialect supports PG 8.3's ``INSERT..RETURNING`` and ``UPDATE..RETURNING`` syntaxes, |
|---|
| 64 | but must be explicitly enabled on a per-statement basis:: |
|---|
| 65 | |
|---|
| 66 | # INSERT..RETURNING |
|---|
| 67 | result = table.insert(postgres_returning=[table.c.col1, table.c.col2]).\\ |
|---|
| 68 | values(name='foo') |
|---|
| 69 | print result.fetchall() |
|---|
| 70 | |
|---|
| 71 | # UPDATE..RETURNING |
|---|
| 72 | result = table.update(postgres_returning=[table.c.col1, table.c.col2]).\\ |
|---|
| 73 | where(table.c.name=='foo').values(name='bar') |
|---|
| 74 | print result.fetchall() |
|---|
| 75 | |
|---|
| 76 | Indexes |
|---|
| 77 | ------- |
|---|
| 78 | |
|---|
| 79 | PostgreSQL supports partial indexes. To create them pass a postgres_where |
|---|
| 80 | option to the Index constructor:: |
|---|
| 81 | |
|---|
| 82 | Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10) |
|---|
| 83 | |
|---|
| 84 | Transactions |
|---|
| 85 | ------------ |
|---|
| 86 | |
|---|
| 87 | The PostgreSQL dialect fully supports SAVEPOINT and two-phase commit operations. |
|---|
| 88 | |
|---|
| 89 | |
|---|
| 90 | """ |
|---|
| 91 | |
|---|
| 92 | import decimal, random, re, string |
|---|
| 93 | |
|---|
| 94 | from sqlalchemy import sql, schema, exc, util |
|---|
| 95 | from sqlalchemy.engine import base, default |
|---|
| 96 | from sqlalchemy.sql import compiler, expression |
|---|
| 97 | from sqlalchemy.sql import operators as sql_operators |
|---|
| 98 | from sqlalchemy import types as sqltypes |
|---|
| 99 | |
|---|
| 100 | |
|---|
class PGInet(sqltypes.TypeEngine):
    """PostgreSQL ``INET`` network-address column type."""

    def get_col_spec(self):
        # DDL keyword emitted for this type in CREATE TABLE.
        return "INET"
|---|
| 104 | |
|---|
class PGCidr(sqltypes.TypeEngine):
    """PostgreSQL ``CIDR`` network-block column type."""

    def get_col_spec(self):
        # DDL keyword emitted for this type in CREATE TABLE.
        return "CIDR"
|---|
| 108 | |
|---|
class PGMacAddr(sqltypes.TypeEngine):
    """PostgreSQL ``MACADDR`` hardware-address column type."""

    def get_col_spec(self):
        # DDL keyword emitted for this type in CREATE TABLE.
        return "MACADDR"
|---|
| 112 | |
|---|
class PGNumeric(sqltypes.Numeric):
    """PostgreSQL ``NUMERIC`` type.

    psycopg2 already returns ``Decimal`` objects, so no bind-side
    conversion is needed; on the result side values are converted to
    ``float`` only when ``asdecimal`` is disabled.
    """

    def get_col_spec(self):
        # Only render precision/scale when a precision was configured.
        if self.precision:
            return "NUMERIC(%s, %s)" % (self.precision, self.scale)
        return "NUMERIC"

    def bind_processor(self, dialect):
        # Decimal values pass straight through to the driver.
        return None

    def result_processor(self, dialect):
        if self.asdecimal:
            # driver already yields Decimal; nothing to do
            return None
        def process(value):
            # coerce Decimal results to float; leave NULLs and other
            # values untouched
            if isinstance(value, decimal.Decimal):
                return float(value)
            return value
        return process
|---|
| 133 | |
|---|
class PGFloat(sqltypes.Float):
    """PostgreSQL ``FLOAT`` type."""

    def get_col_spec(self):
        # Render the binary precision only when one was configured.
        if self.precision:
            return "FLOAT(%s)" % self.precision
        return "FLOAT"
|---|
| 140 | |
|---|
| 141 | |
|---|
class PGInteger(sqltypes.Integer):
    """PostgreSQL ``INTEGER`` (4-byte) type."""

    def get_col_spec(self):
        # DDL keyword emitted for this type in CREATE TABLE.
        return "INTEGER"
|---|
| 145 | |
|---|
class PGSmallInteger(sqltypes.Smallinteger):
    """PostgreSQL ``SMALLINT`` (2-byte) type."""

    def get_col_spec(self):
        # DDL keyword emitted for this type in CREATE TABLE.
        return "SMALLINT"
|---|
| 149 | |
|---|
class PGBigInteger(PGInteger):
    """PostgreSQL ``BIGINT`` (8-byte) type."""

    def get_col_spec(self):
        # DDL keyword emitted for this type in CREATE TABLE.
        return "BIGINT"
|---|
| 153 | |
|---|
class PGDateTime(sqltypes.DateTime):
    """PostgreSQL ``TIMESTAMP``, with or without time zone."""

    def get_col_spec(self):
        if self.timezone:
            qualifier = "WITH"
        else:
            qualifier = "WITHOUT"
        return "TIMESTAMP %s TIME ZONE" % qualifier
|---|
| 157 | |
|---|
class PGDate(sqltypes.Date):
    """PostgreSQL ``DATE`` type."""

    def get_col_spec(self):
        # DDL keyword emitted for this type in CREATE TABLE.
        return "DATE"
|---|
| 161 | |
|---|
class PGTime(sqltypes.Time):
    """PostgreSQL ``TIME``, with or without time zone."""

    def get_col_spec(self):
        if self.timezone:
            qualifier = "WITH"
        else:
            qualifier = "WITHOUT"
        return "TIME %s TIME ZONE" % qualifier
|---|
| 165 | |
|---|
class PGInterval(sqltypes.TypeEngine):
    """PostgreSQL ``INTERVAL`` time-span type."""

    def get_col_spec(self):
        # DDL keyword emitted for this type in CREATE TABLE.
        return "INTERVAL"
|---|
| 169 | |
|---|
class PGText(sqltypes.Text):
    """PostgreSQL ``TEXT`` unbounded-string type."""

    def get_col_spec(self):
        # DDL keyword emitted for this type in CREATE TABLE.
        return "TEXT"
|---|
| 173 | |
|---|
class PGString(sqltypes.String):
    """PostgreSQL ``VARCHAR`` type."""

    def get_col_spec(self):
        # An unconstrained VARCHAR is legal in PG when no length is given.
        if not self.length:
            return "VARCHAR"
        return "VARCHAR(%d)" % self.length
|---|
| 180 | |
|---|
class PGChar(sqltypes.CHAR):
    """PostgreSQL fixed-width ``CHAR`` type."""

    def get_col_spec(self):
        # A bare CHAR (length 1 in PG) is emitted when no length is given.
        if not self.length:
            return "CHAR"
        return "CHAR(%d)" % self.length
|---|
| 187 | |
|---|
class PGBinary(sqltypes.Binary):
    """PostgreSQL ``BYTEA`` binary type."""

    def get_col_spec(self):
        # DDL keyword emitted for this type in CREATE TABLE.
        return "BYTEA"
|---|
| 191 | |
|---|
class PGBoolean(sqltypes.Boolean):
    """PostgreSQL ``BOOLEAN`` type."""

    def get_col_spec(self):
        # DDL keyword emitted for this type in CREATE TABLE.
        return "BOOLEAN"
|---|
| 195 | |
|---|
class PGBit(sqltypes.TypeEngine):
    """PostgreSQL ``BIT`` bit-string type."""

    def get_col_spec(self):
        # DDL keyword emitted for this type in CREATE TABLE.
        return "BIT"
|---|
| 199 | |
|---|
class PGUuid(sqltypes.TypeEngine):
    """PostgreSQL ``UUID`` type."""

    def get_col_spec(self):
        # DDL keyword emitted for this type in CREATE TABLE.
        return "UUID"
|---|
| 203 | |
|---|
class PGDoublePrecision(sqltypes.Float):
    """PostgreSQL ``DOUBLE PRECISION`` (8-byte float) type."""

    def get_col_spec(self):
        # DDL keyword emitted for this type in CREATE TABLE.
        return "DOUBLE PRECISION"
|---|
| 207 | |
|---|
class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
    """PostgreSQL ARRAY type, holding elements of ``item_type``.

    When ``mutable`` is True (the default), values are treated as
    mutable by the ORM: ``copy_value`` returns a fresh list so that
    in-place changes can be detected by comparison.
    """

    def __init__(self, item_type, mutable=True):
        # accept either a type class (PGArray(Integer)) or an
        # instance (PGArray(Integer()))
        if isinstance(item_type, type):
            item_type = item_type()
        self.item_type = item_type
        self.mutable = mutable

    def copy_value(self, value):
        # mutable arrays are copied so the original can serve as a
        # "last known" snapshot; immutable arrays are shared as-is
        if value is None:
            return None
        elif self.mutable:
            return list(value)
        else:
            return value

    def compare_values(self, x, y):
        return x == y

    def is_mutable(self):
        return self.mutable

    def dialect_impl(self, dialect, **kwargs):
        # shallow-copy this instance, swapping in the dialect-specific
        # implementation of the item type.  __new__ is used deliberately
        # to bypass __init__'s class-vs-instance normalization.
        impl = self.__class__.__new__(self.__class__)
        impl.__dict__.update(self.__dict__)
        impl.item_type = self.item_type.dialect_impl(dialect)
        return impl

    def bind_processor(self, dialect):
        """Return a processor converting each (possibly nested) element
        via the item type's bind processor."""
        item_proc = self.item_type.bind_processor(dialect)
        def process(value):
            if value is None:
                return value
            def convert_item(item):
                # on the way in, both lists and tuples denote nested arrays
                if isinstance(item, (list, tuple)):
                    return [convert_item(child) for child in item]
                else:
                    if item_proc:
                        return item_proc(item)
                    else:
                        return item
            return [convert_item(item) for item in value]
        return process

    def result_processor(self, dialect):
        """Return a processor converting each (possibly nested) element
        via the item type's result processor."""
        item_proc = self.item_type.result_processor(dialect)
        def process(value):
            if value is None:
                return value
            def convert_item(item):
                # on the way out only lists are recursed into (note the
                # narrower check than in bind_processor)
                if isinstance(item, list):
                    return [convert_item(child) for child in item]
                else:
                    if item_proc:
                        return item_proc(item)
                    else:
                        return item
            return [convert_item(item) for item in value]
        return process

    def get_col_spec(self):
        # e.g. "INTEGER[]"
        return self.item_type.get_col_spec() + '[]'
|---|
| 268 | |
|---|
# Maps generic sqlalchemy type classes to the PG-specific subclasses
# defined above; consumed by PGDialect.type_descriptor() via
# sqltypes.adapt_type().
colspecs = {
    sqltypes.Integer : PGInteger,
    sqltypes.Smallinteger : PGSmallInteger,
    sqltypes.Numeric : PGNumeric,
    sqltypes.Float : PGFloat,
    PGDoublePrecision : PGDoublePrecision,
    sqltypes.DateTime : PGDateTime,
    sqltypes.Date : PGDate,
    sqltypes.Time : PGTime,
    sqltypes.String : PGString,
    sqltypes.Binary : PGBinary,
    sqltypes.Boolean : PGBoolean,
    sqltypes.Text : PGText,
    sqltypes.CHAR: PGChar,
}
|---|
| 284 | |
|---|
# Maps type names as rendered by pg_catalog.format_type() (with any
# length/precision suffix stripped) to type classes; used by
# reflecttable() when reflecting column types.
ischema_names = {
    'integer' : PGInteger,
    'bigint' : PGBigInteger,
    'smallint' : PGSmallInteger,
    'character varying' : PGString,
    'character' : PGChar,
    '"char"' : PGChar,
    'name': PGChar,
    'text' : PGText,
    'numeric' : PGNumeric,
    'float' : PGFloat,
    'real' : PGFloat,
    'inet': PGInet,
    'cidr': PGCidr,
    'uuid':PGUuid,
    'bit':PGBit,
    'macaddr': PGMacAddr,
    'double precision' : PGDoublePrecision,
    'timestamp' : PGDateTime,
    'timestamp with time zone' : PGDateTime,
    'timestamp without time zone' : PGDateTime,
    'time with time zone' : PGTime,
    'time without time zone' : PGTime,
    'date' : PGDate,
    'time': PGTime,
    'bytea' : PGBinary,
    'boolean' : PGBoolean,
    'interval':PGInterval,
    'interval year to month':PGInterval,
    'interval day to second':PGInterval,
}
|---|
| 316 | |
|---|
# TODO: filter out 'FOR UPDATE' statements
# Matches textual statements eligible for a server-side (named) cursor:
# optional leading whitespace followed by SELECT, case-insensitively.
SERVER_SIDE_CURSOR_RE = re.compile(r'\s*SELECT', re.IGNORECASE | re.UNICODE)
|---|
| 321 | |
|---|
class PGExecutionContext(default.DefaultExecutionContext):
    """Execution context which optionally uses psycopg2 named
    (server-side) cursors."""

    def create_cursor(self):
        # TODO: coverage for server side cursors + select.for_update()
        # Use a named cursor only when the dialect was created with
        # server_side_cursors=True, AND either:
        #  - the compiled statement is a Selectable without FOR UPDATE, or
        #  - raw/textual SQL is executing and it matches
        #    SERVER_SIDE_CURSOR_RE (i.e. begins with SELECT).
        is_server_side = \
            self.dialect.server_side_cursors and \
            ((self.compiled and isinstance(self.compiled.statement, expression.Selectable)
              and not getattr(self.compiled.statement, 'for_update', False)) \
            or \
            (
                (not self.compiled or isinstance(self.compiled.statement, expression._TextClause))
                and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))
            )

        # remembered for get_result_proxy() below
        self.__is_server_side = is_server_side
        if is_server_side:
            # use server-side cursors:
            # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
            # the cursor name must be unique per connection; build one from
            # this context's id plus a random suffix
            ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
            return self._connection.connection.cursor(ident)
        else:
            return self._connection.connection.cursor()

    def get_result_proxy(self):
        # Named cursors do not pre-buffer rows, so a buffering proxy is
        # required; __is_server_side was set by create_cursor().
        if self.__is_server_side:
            return base.BufferedRowResultProxy(self)
        else:
            return base.ResultProxy(self)
|---|
| 349 | |
|---|
class PGDialect(default.DefaultDialect):
    """SQLAlchemy dialect for PostgreSQL via the psycopg2 driver."""

    name = 'postgres'
    supports_alter = True
    # statements are sent to the driver as encoded strings, not unicode
    # (see the .encode(self.encoding) calls in has_table/reflecttable)
    supports_unicode_statements = False
    # PostgreSQL truncates identifiers at 63 characters
    max_identifier_length = 63
    supports_sane_rowcount = True
    supports_sane_multi_rowcount = False
    # sequences are executed before single-row INSERTs to obtain
    # primary key values (see module docstring)
    preexecute_pk_sequences = True
    supports_pk_autoincrement = False
    default_paramstyle = 'pyformat'
    supports_default_values = True
    supports_empty_insert = False
|---|
| 362 | |
|---|
    def __init__(self, server_side_cursors=False, **kwargs):
        """Construct the dialect.

        server_side_cursors
          when True, use psycopg2 named cursors for SELECT statements;
          see the module docstring for details.
        """
        default.DefaultDialect.__init__(self, **kwargs)
        self.server_side_cursors = server_side_cursors
|---|
| 366 | |
|---|
| 367 | def dbapi(cls): |
|---|
| 368 | import psycopg2 as psycopg |
|---|
| 369 | return psycopg |
|---|
| 370 | dbapi = classmethod(dbapi) |
|---|
| 371 | |
|---|
| 372 | def create_connect_args(self, url): |
|---|
| 373 | opts = url.translate_connect_args(username='user') |
|---|
| 374 | if 'port' in opts: |
|---|
| 375 | opts['port'] = int(opts['port']) |
|---|
| 376 | opts.update(url.query) |
|---|
| 377 | return ([], opts) |
|---|
| 378 | |
|---|
    def type_descriptor(self, typeobj):
        # Adapt a generic type to its PG-specific subclass via the
        # module-level colspecs lookup table.
        return sqltypes.adapt_type(typeobj, colspecs)
|---|
| 381 | |
|---|
    def do_begin_twophase(self, connection, xid):
        # A two-phase transaction starts as an ordinary transaction;
        # the xid only comes into play at PREPARE time.
        self.do_begin(connection.connection)
|---|
| 384 | |
|---|
    def do_prepare_twophase(self, connection, xid):
        # PREPARE TRANSACTION with the xid passed as a bound parameter.
        connection.execute(sql.text("PREPARE TRANSACTION :tid", bindparams=[sql.bindparam('tid', xid)]))
|---|
| 387 | |
|---|
    def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
        if is_prepared:
            if recover:
                # FIXME: ugly hack to get out of transaction context when
                # committing recoverable transactions.
                # Must find out a way how to make the dbapi not open a transaction.
                connection.execute(sql.text("ROLLBACK"))
            connection.execute(sql.text("ROLLBACK PREPARED :tid", bindparams=[sql.bindparam('tid', xid)]))
            # re-open and immediately roll back a transaction so the
            # connection is left in a clean, non-transactional state
            connection.execute(sql.text("BEGIN"))
            self.do_rollback(connection.connection)
        else:
            # not yet prepared: a plain rollback suffices
            self.do_rollback(connection.connection)
|---|
| 399 | |
|---|
    def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
        if is_prepared:
            if recover:
                # exit any transaction the dbapi opened before issuing
                # COMMIT PREPARED (same workaround as rollback recovery)
                connection.execute(sql.text("ROLLBACK"))
            connection.execute(sql.text("COMMIT PREPARED :tid", bindparams=[sql.bindparam('tid', xid)]))
            # re-open and immediately roll back a transaction so the
            # connection is left in a clean, non-transactional state
            connection.execute(sql.text("BEGIN"))
            self.do_rollback(connection.connection)
        else:
            # not yet prepared: a plain commit suffices
            self.do_commit(connection.connection)
|---|
| 409 | |
|---|
| 410 | def do_recover_twophase(self, connection): |
|---|
| 411 | resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) |
|---|
| 412 | return [row[0] for row in resultset] |
|---|
| 413 | |
|---|
    def get_default_schema_name(self, connection):
        # current_schema() is evaluated once per connection; the memoize
        # wrapper below caches the value on the connection itself.
        return connection.scalar("select current_schema()", None)
    get_default_schema_name = base.connection_memoize(
        ('dialect', 'default_schema_name'))(get_default_schema_name)
|---|
| 418 | |
|---|
| 419 | def last_inserted_ids(self): |
|---|
| 420 | if self.context.last_inserted_ids is None: |
|---|
| 421 | raise exc.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without PostgreSQL OIDs enabled") |
|---|
| 422 | else: |
|---|
| 423 | return self.context.last_inserted_ids |
|---|
| 424 | |
|---|
| 425 | def has_table(self, connection, table_name, schema=None): |
|---|
| 426 | # seems like case gets folded in pg_class... |
|---|
| 427 | if schema is None: |
|---|
| 428 | cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding)}); |
|---|
| 429 | else: |
|---|
| 430 | cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=%(schema)s and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding), 'schema':schema}); |
|---|
| 431 | try: |
|---|
| 432 | return bool(cursor.fetchone()) |
|---|
| 433 | finally: |
|---|
| 434 | cursor.close() |
|---|
| 435 | |
|---|
| 436 | def has_sequence(self, connection, sequence_name, schema=None): |
|---|
| 437 | if schema is None: |
|---|
| 438 | cursor = connection.execute( |
|---|
| 439 | sql.text("SELECT relname FROM pg_class c join pg_namespace n on " |
|---|
| 440 | "n.oid=c.relnamespace where relkind='S' and n.nspname=current_schema() and lower(relname)=:name", |
|---|
| 441 | bindparams=[sql.bindparam('name', unicode(sequence_name.lower()), type_=sqltypes.Unicode)] |
|---|
| 442 | ) |
|---|
| 443 | ) |
|---|
| 444 | else: |
|---|
| 445 | cursor = connection.execute( |
|---|
| 446 | sql.text("SELECT relname FROM pg_class c join pg_namespace n on " |
|---|
| 447 | "n.oid=c.relnamespace where relkind='S' and n.nspname=:schema and lower(relname)=:name", |
|---|
| 448 | bindparams=[sql.bindparam('name', unicode(sequence_name.lower()), type_=sqltypes.Unicode), |
|---|
| 449 | sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode)] |
|---|
| 450 | ) |
|---|
| 451 | ) |
|---|
| 452 | |
|---|
| 453 | try: |
|---|
| 454 | return bool(cursor.fetchone()) |
|---|
| 455 | finally: |
|---|
| 456 | cursor.close() |
|---|
| 457 | |
|---|
| 458 | def is_disconnect(self, e): |
|---|
| 459 | if isinstance(e, self.dbapi.OperationalError): |
|---|
| 460 | return 'closed the connection' in str(e) or 'connection not open' in str(e) |
|---|
| 461 | elif isinstance(e, self.dbapi.InterfaceError): |
|---|
| 462 | return 'connection already closed' in str(e) or 'cursor already closed' in str(e) |
|---|
| 463 | elif isinstance(e, self.dbapi.ProgrammingError): |
|---|
| 464 | # yes, it really says "losed", not "closed" |
|---|
| 465 | return "losed the connection unexpectedly" in str(e) |
|---|
| 466 | else: |
|---|
| 467 | return False |
|---|
| 468 | |
|---|
| 469 | def table_names(self, connection, schema): |
|---|
| 470 | s = """ |
|---|
| 471 | SELECT relname |
|---|
| 472 | FROM pg_class c |
|---|
| 473 | WHERE relkind = 'r' |
|---|
| 474 | AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) |
|---|
| 475 | """ % locals() |
|---|
| 476 | return [row[0].decode(self.encoding) for row in connection.execute(s)] |
|---|
| 477 | |
|---|
| 478 | def server_version_info(self, connection): |
|---|
| 479 | v = connection.execute("select version()").scalar() |
|---|
| 480 | m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v) |
|---|
| 481 | if not m: |
|---|
| 482 | raise AssertionError("Could not determine version from string '%s'" % v) |
|---|
| 483 | return tuple([int(x) for x in m.group(1, 2, 3)]) |
|---|
| 484 | |
|---|
    def reflecttable(self, connection, table, include_columns):
        """Populate ``table`` with columns, primary key, foreign keys and
        indexes reflected from the pg_catalog system tables.

        include_columns, if non-empty, restricts reflection to the named
        columns.  Raises NoSuchTableError if the table does not exist.
        """
        preparer = self.identifier_preparer
        # resolve the schema filter: either an explicit schema name, or
        # "whatever is visible on the search path"
        if table.schema is not None:
            schema_where_clause = "n.nspname = :schema"
            schemaname = table.schema
            if isinstance(schemaname, str):
                schemaname = schemaname.decode(self.encoding)
        else:
            schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
            schemaname = None

        ## query column name, formatted type, default expression,
        ## NOT NULL flag, ordinal position and owning relation oid
        SQL_COLS = """
            SELECT a.attname,
              pg_catalog.format_type(a.atttypid, a.atttypmod),
              (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d
               WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef)
              AS DEFAULT,
              a.attnotnull, a.attnum, a.attrelid as table_oid
            FROM pg_catalog.pg_attribute a
            WHERE a.attrelid = (
                SELECT c.oid
                FROM pg_catalog.pg_class c
                LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                WHERE (%s)
                AND c.relname = :table_name AND c.relkind in ('r','v')
            ) AND a.attnum > 0 AND NOT a.attisdropped
            ORDER BY a.attnum
        """ % schema_where_clause

        s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode})
        tablename = table.name
        if isinstance(tablename, str):
            tablename = tablename.decode(self.encoding)
        c = connection.execute(s, table_name=tablename, schema=schemaname)
        rows = c.fetchall()

        if not rows:
            raise exc.NoSuchTableError(table.name)

        # domain types resolve through their base type
        domains = self._load_domains(connection)

        for name, format_type, default, notnull, attnum, table_oid in rows:
            if include_columns and name not in include_columns:
                continue

            ## strip (30) from character varying(30)
            attype = re.search('([^\([]+)', format_type).group(1)
            nullable = not notnull
            is_array = format_type.endswith('[]')

            # extract any "(30)" / "(10,2)" length or precision suffix
            try:
                charlen = re.search('\(([\d,]+)\)', format_type).group(1)
            except:
                charlen = False

            numericprec = False
            numericscale = False
            if attype == 'numeric':
                if charlen is False:
                    numericprec, numericscale = (None, None)
                else:
                    numericprec, numericscale = charlen.split(',')
                charlen = False
            elif attype == 'double precision':
                numericprec, numericscale = (True, False)
                charlen = False
            elif attype == 'integer':
                numericprec, numericscale = (32, 0)
                charlen = False

            # positional args for the type constructor: None is passed
            # through, False means "not applicable" and is skipped
            args = []
            for a in (charlen, numericprec, numericscale):
                if a is None:
                    args.append(None)
                elif a is not False:
                    args.append(int(a))

            kwargs = {}
            if attype == 'timestamp with time zone':
                kwargs['timezone'] = True
            elif attype == 'timestamp without time zone':
                kwargs['timezone'] = False

            coltype = None
            if attype in ischema_names:
                coltype = ischema_names[attype]
            else:
                if attype in domains:
                    domain = domains[attype]
                    if domain['attype'] in ischema_names:
                        # A table can't override whether the domain is nullable.
                        nullable = domain['nullable']

                        if domain['default'] and not default:
                            # It can, however, override the default value, but can't set it to null.
                            default = domain['default']
                        coltype = ischema_names[domain['attype']]

            if coltype:
                coltype = coltype(*args, **kwargs)
                if is_array:
                    coltype = PGArray(coltype)
            else:
                util.warn("Did not recognize type '%s' of column '%s'" %
                          (attype, name))
                coltype = sqltypes.NULLTYPE

            colargs = []
            if default is not None:
                match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
                if match is not None:
                    # the default is related to a Sequence
                    sch = table.schema
                    if '.' not in match.group(2) and sch is not None:
                        # unconditionally quote the schema name. this could
                        # later be enhanced to obey quoting rules / "quote schema"
                        default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3)
                colargs.append(schema.DefaultClause(sql.text(default)))
            table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))


        # Primary keys
        # table_oid survives from the column loop above
        PK_SQL = """
            SELECT attname FROM pg_attribute
            WHERE attrelid = (
                SELECT indexrelid FROM pg_index i
                WHERE i.indrelid = :table
                AND i.indisprimary = 't')
            ORDER BY attnum
        """
        t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode})
        c = connection.execute(t, table=table_oid)
        for row in c.fetchall():
            pk = row[0]
            if pk in table.c:
                col = table.c[pk]
                table.primary_key.add(col)
                if col.default is None:
                    # a PK column without a server default cannot
                    # autoincrement
                    col.autoincrement = False

        # Foreign keys
        FK_SQL = """
          SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
          FROM pg_catalog.pg_constraint r
          WHERE r.conrelid = :table AND r.contype = 'f'
          ORDER BY 1
        """

        t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode})
        c = connection.execute(t, table=table_oid)
        for conname, condef in c.fetchall():
            # parse the textual constraint definition produced by
            # pg_get_constraintdef()
            m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups()
            (constrained_columns, referred_schema, referred_table, referred_columns) = m
            constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)]
            if referred_schema:
                referred_schema = preparer._unquote_identifier(referred_schema)
            elif table.schema is not None and table.schema == self.get_default_schema_name(connection):
                # no schema (i.e. its the default schema), and the table we're
                # reflecting has the default schema explicit, then use that.
                # i.e. try to use the user's conventions
                referred_schema = table.schema
            referred_table = preparer._unquote_identifier(referred_table)
            referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)]

            # autoload the referred table into the same MetaData, then
            # build dotted column specs for the ForeignKeyConstraint
            refspec = []
            if referred_schema is not None:
                schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
                             autoload_with=connection)
                for column in referred_columns:
                    refspec.append(".".join([referred_schema, referred_table, column]))
            else:
                schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
                for column in referred_columns:
                    refspec.append(".".join([referred_table, column]))

            table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True))

        # Indexes
        IDX_SQL = """
          SELECT c.relname, i.indisunique, i.indexprs, i.indpred,
            a.attname
          FROM pg_index i, pg_class c, pg_attribute a
          WHERE i.indrelid = :table AND i.indexrelid = c.oid
            AND a.attrelid = i.indexrelid AND i.indisprimary = 'f'
          ORDER BY c.relname, a.attnum
        """
        t = sql.text(IDX_SQL, typemap={'attname':sqltypes.Unicode})
        c = connection.execute(t, table=table_oid)
        indexes = {}
        # remembers the last index warned about so multi-column
        # expression indexes produce a single warning
        sv_idx_name = None
        for row in c.fetchall():
            idx_name, unique, expr, prd, col = row

            if expr:
                # expression-based indexes cannot be represented; skip
                if not idx_name == sv_idx_name:
                    util.warn(
                        "Skipped unsupported reflection of expression-based index %s"
                        % idx_name)
                sv_idx_name = idx_name
                continue
            if prd and not idx_name == sv_idx_name:
                # partial index: the columns are reflected but the WHERE
                # predicate is dropped
                util.warn(
                    "Predicate of partial index %s ignored during reflection"
                    % idx_name)
                sv_idx_name = idx_name

            if not indexes.has_key(idx_name):
                indexes[idx_name] = [unique, []]
            indexes[idx_name][1].append(col)

        for name, (unique, columns) in indexes.items():
            schema.Index(name, *[table.columns[c] for c in columns],
                         **dict(unique=unique))
|---|
| 698 | |
|---|
| 699 | |
|---|
| 700 | |
|---|
| 701 | def _load_domains(self, connection): |
|---|
| 702 | ## Load data types for domains: |
|---|
| 703 | SQL_DOMAINS = """ |
|---|
| 704 | SELECT t.typname as "name", |
|---|
| 705 | pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype", |
|---|
| 706 | not t.typnotnull as "nullable", |
|---|
| 707 | t.typdefault as "default", |
|---|
| 708 | pg_catalog.pg_type_is_visible(t.oid) as "visible", |
|---|
| 709 | n.nspname as "schema" |
|---|
| 710 | FROM pg_catalog.pg_type t |
|---|
| 711 | LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace |
|---|
| 712 | LEFT JOIN pg_catalog.pg_constraint r ON t.oid = r.contypid |
|---|
| 713 | WHERE t.typtype = 'd' |
|---|
| 714 | """ |
|---|
| 715 | |
|---|
| 716 | s = sql.text(SQL_DOMAINS, typemap={'attname':sqltypes.Unicode}) |
|---|
| 717 | c = connection.execute(s) |
|---|
| 718 | |
|---|
| 719 | domains = {} |
|---|
| 720 | for domain in c.fetchall(): |
|---|
| 721 | ## strip (30) from character varying(30) |
|---|
| 722 | attype = re.search('([^\(]+)', domain['attype']).group(1) |
|---|
| 723 | if domain['visible']: |
|---|
| 724 | # 'visible' just means whether or not the domain is in a |
|---|
| 725 | # schema that's on the search path -- or not overriden by |
|---|
| 726 | # a schema with higher presedence. If it's not visible, |
|---|
| 727 | # it will be prefixed with the schema-name when it's used. |
|---|
| 728 | name = domain['name'] |
|---|
| 729 | else: |
|---|
| 730 | name = "%s.%s" % (domain['schema'], domain['name']) |
|---|
| 731 | |
|---|
| 732 | domains[name] = {'attype':attype, 'nullable': domain['nullable'], 'default': domain['default']} |
|---|
| 733 | |
|---|
| 734 | return domains |
|---|
| 735 | |
|---|
| 736 | |
|---|
class PGCompiler(compiler.DefaultCompiler):
    """Statement compiler producing PostgreSQL-specific SQL constructs
    (ILIKE, DISTINCT ON, RETURNING, LIMIT ALL, full-text match)."""

    operators = compiler.DefaultCompiler.operators.copy()
    operators.update(
        {
            sql_operators.mod : '%%',
            sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
            sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
            sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y),
        }
    )

    functions = compiler.DefaultCompiler.functions.copy()
    functions.update(
        {
            'TIMESTAMP':util.deprecated(message="Use a literal string 'timestamp <value>' instead")(lambda x:'TIMESTAMP %s' % x),
        }
    )

    def visit_sequence(self, seq):
        """Render nextval() for a sequence; optional sequences render nothing
        (the database supplies the value via SERIAL)."""
        if seq.optional:
            return None
        else:
            return "nextval('%s')" % self.preparer.format_sequence(seq)

    def post_process_text(self, text):
        # psycopg2 uses '%' for its paramstyle, so literal percent signs
        # in text() SQL must be doubled before reaching the driver.
        if '%%' in text:
            util.warn("The SQLAlchemy psycopg2 dialect now automatically escapes '%' in text() expressions to '%%'.")
        return text.replace('%', '%%')

    def limit_clause(self, select):
        """Render LIMIT/OFFSET; PostgreSQL needs an explicit LIMIT ALL
        when only an OFFSET is present."""
        text = ""
        if select._limit is not None:
            text += " \n LIMIT " + str(select._limit)
        if select._offset is not None:
            if select._limit is None:
                text += " \n LIMIT ALL"
            text += " OFFSET " + str(select._offset)
        return text

    def get_select_precolumns(self, select):
        """Render DISTINCT, or PostgreSQL's DISTINCT ON (...) when the
        _distinct attribute carries column expressions."""
        if select._distinct:
            if isinstance(select._distinct, bool):
                return "DISTINCT "
            elif isinstance(select._distinct, (list, tuple)):
                return "DISTINCT ON (" + ', '.join(
                    [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct]
                )+ ") "
            else:
                return "DISTINCT ON (" + unicode(select._distinct) + ") "
        else:
            return ""

    def for_update_clause(self, select):
        """Support the PG-specific FOR UPDATE NOWAIT variant."""
        if select.for_update == 'nowait':
            return " FOR UPDATE NOWAIT"
        else:
            return super(PGCompiler, self).for_update_clause(select)

    def _append_returning(self, text, stmt):
        """Append a RETURNING clause built from the statement's
        'postgres_returning' keyword argument."""
        returning_cols = stmt.kwargs['postgres_returning']
        def flatten_columnlist(collist):
            # expand selectables (e.g. Table objects) into their columns
            for c in collist:
                if isinstance(c, expression.Selectable):
                    for co in c.columns:
                        yield co
                else:
                    yield c
        columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)]
        # str.join replaces the deprecated string.join() module function
        text += ' RETURNING ' + ', '.join(columns)
        return text

    def visit_update(self, update_stmt):
        text = super(PGCompiler, self).visit_update(update_stmt)
        if 'postgres_returning' in update_stmt.kwargs:
            return self._append_returning(text, update_stmt)
        else:
            return text

    def visit_insert(self, insert_stmt):
        text = super(PGCompiler, self).visit_insert(insert_stmt)
        if 'postgres_returning' in insert_stmt.kwargs:
            return self._append_returning(text, insert_stmt)
        else:
            return text

    def visit_extract(self, extract, **kwargs):
        # cast to ::timestamp so EXTRACT behaves uniformly on the operand
        field = self.extract_map.get(extract.field, extract.field)
        return "EXTRACT(%s FROM %s::timestamp)" % (
            field, self.process(extract.expr))
|---|
| 826 | |
|---|
| 827 | |
|---|
class PGSchemaGenerator(compiler.SchemaGenerator):
    """DDL generator handling SERIAL/BIGSERIAL columns, sequences and
    PostgreSQL partial indexes."""

    def get_column_specification(self, column, **kwargs):
        """Render a column DDL clause, emitting SERIAL/BIGSERIAL for
        auto-increment integer primary keys with no explicit default."""
        colspec = self.preparer.format_column(column)
        if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
            if isinstance(column.type, PGBigInteger):
                colspec += " BIGSERIAL"
            else:
                colspec += " SERIAL"
        else:
            colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
            default = self.get_column_default_string(column)
            if default is not None:
                colspec += " DEFAULT " + default

        if not column.nullable:
            colspec += " NOT NULL"
        return colspec

    def visit_sequence(self, sequence):
        """Emit CREATE SEQUENCE for a non-optional sequence, honoring
        the checkfirst flag."""
        if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name, schema=sequence.schema)):
            self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
            self.execute()

    def visit_index(self, index):
        """Emit CREATE [UNIQUE] INDEX, supporting the 'postgres_where'
        keyword argument for partial indexes."""
        preparer = self.preparer
        self.append("CREATE ")
        if index.unique:
            self.append("UNIQUE ")
        # str.join replaces the deprecated string.join() module function
        self.append("INDEX %s ON %s (%s)" \
                    % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
                       preparer.format_table(index.table),
                       ', '.join([preparer.format_column(c) for c in index.columns])))
        whereclause = index.kwargs.get('postgres_where', None)
        if whereclause is not None:
            compiler = self._compile(whereclause, None)
            # this might belong to the compiler class -- inline bind values,
            # since DDL statements cannot carry bind parameters
            inlined_clause = str(compiler) % dict(
                [(key, bind.value) for key, bind in compiler.binds.items()])
            self.append(" WHERE " + inlined_clause)
        self.execute()
|---|
| 868 | |
|---|
class PGSchemaDropper(compiler.SchemaDropper):
    """DDL dropper with sequence-aware DROP handling."""

    def visit_sequence(self, sequence):
        """Emit DROP SEQUENCE for a non-optional sequence, honoring the
        checkfirst flag."""
        if sequence.optional:
            return
        # when checkfirst is set, only drop sequences that actually exist
        if self.checkfirst and not self.dialect.has_sequence(self.connection, sequence.name, schema=sequence.schema):
            return
        self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
        self.execute()
|---|
| 874 | |
|---|
class PGDefaultRunner(base.DefaultRunner):
    """Default-value executor which pre-fetches passive defaults and
    implicit SERIAL sequence values on primary key columns."""

    def __init__(self, context):
        base.DefaultRunner.__init__(self, context)
        # create cursor which won't conflict with a server-side cursor
        self.cursor = context._connection.connection.cursor()

    def get_column_default(self, column, isinsert=True):
        """Return a pre-executed default value for *column*.

        Primary key columns get special treatment: a server-side default
        expression is executed via SELECT, and auto-increment integer
        columns fetch nextval() from the implicit SERIAL sequence.
        """
        if column.primary_key:
            # pre-execute passive defaults on primary keys
            if (isinstance(column.server_default, schema.DefaultClause) and
                column.server_default.arg is not None):
                return self.execute_string("select %s" % column.server_default.arg)
            elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
                sch = column.table.schema
                # TODO: this has to build into the Sequence object so we can get the quoting
                # logic from it
                if sch is not None:
                    exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
                else:
                    exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
                # encode with the dialect's client encoding before handing
                # the raw SQL string to the driver (Python 2 str/unicode)
                return self.execute_string(exc.encode(self.dialect.encoding))

        return super(PGDefaultRunner, self).get_column_default(column)

    def visit_sequence(self, seq):
        """Pre-execute nextval() for a non-optional sequence; optional
        sequences (implicit SERIAL) are left to the database."""
        if not seq.optional:
            return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)))
        else:
            return None
|---|
| 904 | |
|---|
class PGIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer adding un-quoting of reflected identifiers."""

    def _unquote_identifier(self, value):
        """Strip surrounding quote characters from *value* and un-escape
        embedded doubled quotes; unquoted values pass through unchanged."""
        # startswith() also handles the empty-string case safely, where
        # value[0] would raise IndexError
        if value.startswith(self.initial_quote):
            value = value[1:-1].replace('""','"')
        return value
|---|
| 910 | |
|---|
# Wire the dialect's component classes onto PGDialect (defined earlier in
# this module); `dialect` is the name the engine machinery imports.
dialect = PGDialect
dialect.statement_compiler = PGCompiler
dialect.schemagenerator = PGSchemaGenerator
dialect.schemadropper = PGSchemaDropper
dialect.preparer = PGIdentifierPreparer
dialect.defaultrunner = PGDefaultRunner
dialect.execution_ctx_cls = PGExecutionContext
|---|