root/galaxy-central/eggs/SQLAlchemy-0.5.6_dev_r6498-py2.6.egg/sqlalchemy/databases/postgres.py @ 3

リビジョン 3, 35.8 KB (コミッタ: kohda, 14 年 前)

Install Unix tools  http://hannonlab.cshl.edu/galaxy_unix_tools/galaxy.html

行番号 
1# postgres.py
2# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
3#
4# This module is part of SQLAlchemy and is released under
5# the MIT License: http://www.opensource.org/licenses/mit-license.php
6
7"""Support for the PostgreSQL database.
8
9Driver
10------
11
12The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ .
13The dialect has several behaviors  which are specifically tailored towards compatibility
14with this module.
15
16Note that psycopg1 is **not** supported.
17
18Connecting
19----------
20
21URLs are of the form `postgres://user:password@host:port/dbname[?key=value&key=value...]`.
22
23PostgreSQL-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are:
24
25* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support
26  this feature.  What this essentially means from a psycopg2 point of view is that the cursor is
27  created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows
28  are not immediately pre-fetched and buffered after statement execution, but are instead left
29  on the server and only retrieved as needed.    SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy`
30  uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows
31  at a time are fetched over the wire to reduce conversational overhead.
32
33Sequences/SERIAL
34----------------
35
36PostgreSQL supports sequences, and SQLAlchemy uses these as the default means of creating
37new primary key values for integer-based primary key columns.   When creating tables,
38SQLAlchemy will issue the ``SERIAL`` datatype for integer-based primary key columns,
39which generates a sequence corresponding to the column and associated with it based on
40a naming convention.
41
42To specify a specific named sequence to be used for primary key generation, use the
43:func:`~sqlalchemy.schema.Sequence` construct::
44
45    Table('sometable', metadata,
46            Column('id', Integer, Sequence('some_id_seq'), primary_key=True)
47        )
48
49Currently, when SQLAlchemy issues a single insert statement, to fulfill the contract of
50having the "last insert identifier" available, the sequence is executed independently
51beforehand and the new value is retrieved, to be used in the subsequent insert.  Note
52that when an :func:`~sqlalchemy.sql.expression.insert()` construct is executed using
53"executemany" semantics, the sequence is not pre-executed and normal PG SERIAL behavior
54is used.
55
56PostgreSQL 8.3 supports an ``INSERT...RETURNING`` syntax which SQLAlchemy supports
57as well.  A future release of SQLA will use this feature by default in lieu of
58sequence pre-execution in order to retrieve new primary key values, when available.
59
60INSERT/UPDATE...RETURNING
61-------------------------
62
63The dialect supports PG 8.3's ``INSERT..RETURNING`` and ``UPDATE..RETURNING`` syntaxes,
64but must be explicitly enabled on a per-statement basis::
65
66    # INSERT..RETURNING
67    result = table.insert(postgres_returning=[table.c.col1, table.c.col2]).\\
68        values(name='foo')
69    print result.fetchall()
70   
71    # UPDATE..RETURNING
72    result = table.update(postgres_returning=[table.c.col1, table.c.col2]).\\
73        where(table.c.name=='foo').values(name='bar')
74    print result.fetchall()
75
76Indexes
77-------
78
79PostgreSQL supports partial indexes. To create them pass a postgres_where
80option to the Index constructor::
81
82  Index('my_index', my_table.c.id, postgres_where=my_table.c.value > 10)
83
84Transactions
85------------
86
87The PostgreSQL dialect fully supports SAVEPOINT and two-phase commit operations.
88
89
90"""
91
92import decimal, random, re, string
93
94from sqlalchemy import sql, schema, exc, util
95from sqlalchemy.engine import base, default
96from sqlalchemy.sql import compiler, expression
97from sqlalchemy.sql import operators as sql_operators
98from sqlalchemy import types as sqltypes
99
100
class PGInet(sqltypes.TypeEngine):
    """Column type rendered as PostgreSQL ``INET``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "INET"
104
class PGCidr(sqltypes.TypeEngine):
    """Column type rendered as PostgreSQL ``CIDR``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "CIDR"
108
class PGMacAddr(sqltypes.TypeEngine):
    """Column type rendered as PostgreSQL ``MACADDR``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "MACADDR"
112
class PGNumeric(sqltypes.Numeric):
    """``NUMERIC`` column type with optional precision and scale."""

    def get_col_spec(self):
        """Return the DDL for this type.

        Renders ``NUMERIC`` when no precision is set, ``NUMERIC(p)`` when a
        precision is set but the scale is ``None``, and ``NUMERIC(p, s)``
        when both are set.
        """
        if not self.precision:
            return "NUMERIC"
        if self.scale is None:
            # Avoid emitting the invalid "NUMERIC(p, None)".
            return "NUMERIC(%(precision)s)" % {'precision': self.precision}
        return "NUMERIC(%(precision)s, %(scale)s)" % {
            'precision': self.precision, 'scale': self.scale}

    def bind_processor(self, dialect):
        """Return ``None``: bound values are passed through unchanged."""
        return None

    def result_processor(self, dialect):
        """Return a result converter, or ``None`` when ``asdecimal`` is set.

        When ``asdecimal`` is false the returned callable converts
        ``decimal.Decimal`` values to ``float`` and leaves any other
        value (including ``None``) untouched.
        """
        if self.asdecimal:
            return None

        def process(value):
            if isinstance(value, decimal.Decimal):
                return float(value)
            return value
        return process
133
class PGFloat(sqltypes.Float):
    """``FLOAT`` column type with optional precision."""

    def get_col_spec(self):
        """Return ``FLOAT(p)`` when a precision is set, else ``FLOAT``."""
        if self.precision:
            return "FLOAT(%(precision)s)" % {'precision': self.precision}
        return "FLOAT"
140
141
class PGInteger(sqltypes.Integer):
    """Column type rendered as PostgreSQL ``INTEGER``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "INTEGER"
145
class PGSmallInteger(sqltypes.Smallinteger):
    """Column type rendered as PostgreSQL ``SMALLINT``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "SMALLINT"
149
class PGBigInteger(PGInteger):
    """Column type rendered as PostgreSQL ``BIGINT``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "BIGINT"
153
class PGDateTime(sqltypes.DateTime):
    """``TIMESTAMP`` column type, with or without time zone."""

    def get_col_spec(self):
        """Return ``TIMESTAMP WITH|WITHOUT TIME ZONE`` based on ``self.timezone``."""
        zone_keyword = "WITH" if self.timezone else "WITHOUT"
        return "TIMESTAMP " + zone_keyword + " TIME ZONE"
157
class PGDate(sqltypes.Date):
    """Column type rendered as PostgreSQL ``DATE``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "DATE"
161
class PGTime(sqltypes.Time):
    """``TIME`` column type, with or without time zone."""

    def get_col_spec(self):
        """Return ``TIME WITH|WITHOUT TIME ZONE`` based on ``self.timezone``."""
        zone_keyword = "WITH" if self.timezone else "WITHOUT"
        return "TIME " + zone_keyword + " TIME ZONE"
165
class PGInterval(sqltypes.TypeEngine):
    """Column type rendered as PostgreSQL ``INTERVAL``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "INTERVAL"
169
class PGText(sqltypes.Text):
    """Column type rendered as PostgreSQL ``TEXT``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "TEXT"
173
class PGString(sqltypes.String):
    """``VARCHAR`` column type with optional length."""

    def get_col_spec(self):
        """Return ``VARCHAR(n)`` when a length is set, else ``VARCHAR``."""
        if not self.length:
            return "VARCHAR"
        return "VARCHAR(%(length)d)" % {'length': self.length}
180
class PGChar(sqltypes.CHAR):
    """``CHAR`` column type with optional length."""

    def get_col_spec(self):
        """Return ``CHAR(n)`` when a length is set, else ``CHAR``."""
        if not self.length:
            return "CHAR"
        return "CHAR(%(length)d)" % {'length': self.length}
187
class PGBinary(sqltypes.Binary):
    """Column type rendered as PostgreSQL ``BYTEA``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "BYTEA"
191
class PGBoolean(sqltypes.Boolean):
    """Column type rendered as PostgreSQL ``BOOLEAN``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "BOOLEAN"
195
class PGBit(sqltypes.TypeEngine):
    """Column type rendered as PostgreSQL ``BIT``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "BIT"
199       
class PGUuid(sqltypes.TypeEngine):
    """Column type rendered as PostgreSQL ``UUID``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "UUID"
203
class PGDoublePrecision(sqltypes.Float):
    """Column type rendered as PostgreSQL ``DOUBLE PRECISION``."""

    def get_col_spec(self):
        """Return the DDL type name for this column type."""
        return "DOUBLE PRECISION"
207   
class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
    """Array column type wrapping a single item type.

    ``item_type`` may be given as a type class or a type instance; a class
    is instantiated with no arguments.  ``mutable`` controls whether
    :meth:`copy_value` returns a list copy of the value.
    """

    def __init__(self, item_type, mutable=True):
        self.item_type = item_type() if isinstance(item_type, type) else item_type
        self.mutable = mutable

    def copy_value(self, value):
        """Return a list copy of *value* when mutable, else *value* itself."""
        if value is not None and self.mutable:
            return list(value)
        return value

    def compare_values(self, x, y):
        """Compare two values with plain ``==``."""
        return x == y

    def is_mutable(self):
        """Return the ``mutable`` flag given at construction."""
        return self.mutable

    def dialect_impl(self, dialect, **kwargs):
        """Return a shallow clone whose item type is adapted to *dialect*."""
        clone = self.__class__.__new__(self.__class__)
        clone.__dict__.update(self.__dict__)
        clone.item_type = self.item_type.dialect_impl(dialect)
        return clone

    def bind_processor(self, dialect):
        """Return a callable applying the item bind processor to each element.

        Nested lists and tuples are converted recursively into lists;
        ``None`` is passed through.
        """
        item_proc = self.item_type.bind_processor(dialect)

        def convert(element):
            if isinstance(element, (list, tuple)):
                return [convert(sub) for sub in element]
            if item_proc:
                return item_proc(element)
            return element

        def process(value):
            if value is None:
                return value
            return [convert(entry) for entry in value]
        return process

    def result_processor(self, dialect):
        """Return a callable applying the item result processor to each element.

        Only nested ``list`` values are descended into; ``None`` is passed
        through.
        """
        item_proc = self.item_type.result_processor(dialect)

        def convert(element):
            if isinstance(element, list):
                return [convert(sub) for sub in element]
            if item_proc:
                return item_proc(element)
            return element

        def process(value):
            if value is None:
                return value
            return [convert(entry) for entry in value]
        return process

    def get_col_spec(self):
        """Return the item type's column spec followed by ``[]``."""
        return self.item_type.get_col_spec() + '[]'
268
# Maps generic SQLAlchemy types to the PG-specific implementations above;
# passed to sqltypes.adapt_type() by PGDialect.type_descriptor().
colspecs = {
    sqltypes.Integer: PGInteger,
    sqltypes.Smallinteger: PGSmallInteger,
    sqltypes.Numeric: PGNumeric,
    sqltypes.Float: PGFloat,
    PGDoublePrecision: PGDoublePrecision,
    sqltypes.DateTime: PGDateTime,
    sqltypes.Date: PGDate,
    sqltypes.Time: PGTime,
    sqltypes.String: PGString,
    sqltypes.Binary: PGBinary,
    sqltypes.Boolean: PGBoolean,
    sqltypes.Text: PGText,
    sqltypes.CHAR: PGChar,
}
284
# Maps type names as reported by the catalogs (format_type() output with
# any "(...)" / "[]" suffix removed) to the type classes used in reflection.
ischema_names = {
    'integer': PGInteger,
    'bigint': PGBigInteger,
    'smallint': PGSmallInteger,
    'character varying': PGString,
    'character': PGChar,
    '"char"': PGChar,
    'name': PGChar,
    'text': PGText,
    'numeric': PGNumeric,
    'float': PGFloat,
    'real': PGFloat,
    'inet': PGInet,
    'cidr': PGCidr,
    'uuid': PGUuid,
    'bit': PGBit,
    'macaddr': PGMacAddr,
    'double precision': PGDoublePrecision,
    'timestamp': PGDateTime,
    'timestamp with time zone': PGDateTime,
    'timestamp without time zone': PGDateTime,
    'time with time zone': PGTime,
    'time without time zone': PGTime,
    'date': PGDate,
    'time': PGTime,
    'bytea': PGBinary,
    'boolean': PGBoolean,
    'interval': PGInterval,
    'interval year to month': PGInterval,
    'interval day to second': PGInterval,
}
316
# Matches raw statements that begin (after optional whitespace) with SELECT,
# case-insensitively.
# TODO: filter out 'FOR UPDATE' statements
SERVER_SIDE_CURSOR_RE = re.compile(r'\s*SELECT', re.IGNORECASE | re.UNICODE)
321
class PGExecutionContext(default.DefaultExecutionContext):
    """Execution context that can use named (server-side) psycopg2 cursors."""

    def create_cursor(self):
        """Create the DB-API cursor, named when server-side cursors apply.

        A named cursor is used only when the dialect has
        ``server_side_cursors`` enabled and either the compiled statement is
        a Selectable without ``for_update``, or the statement is textual /
        uncompiled and its string matches ``SERVER_SIDE_CURSOR_RE``.
        """
        # TODO: coverage for server side cursors + select.for_update()
        use_named = False
        if self.dialect.server_side_cursors:
            compiled = self.compiled
            if compiled and isinstance(compiled.statement, expression.Selectable) \
                    and not getattr(compiled.statement, 'for_update', False):
                use_named = True
            elif (not compiled or isinstance(compiled.statement, expression._TextClause)) \
                    and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement):
                use_named = True

        # Remembered so get_result_proxy() can pick the matching proxy class.
        self.__is_server_side = use_named
        dbapi_connection = self._connection.connection
        if not use_named:
            return dbapi_connection.cursor()

        # use server-side cursors:
        # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
        ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
        return dbapi_connection.cursor(ident)

    def get_result_proxy(self):
        """Return a row-buffering proxy for server-side cursors, else a plain one."""
        proxy_cls = base.BufferedRowResultProxy if self.__is_server_side else base.ResultProxy
        return proxy_cls(self)
349
350class PGDialect(default.DefaultDialect):
351    name = 'postgres'
352    supports_alter = True
353    supports_unicode_statements = False
354    max_identifier_length = 63
355    supports_sane_rowcount = True
356    supports_sane_multi_rowcount = False
357    preexecute_pk_sequences = True
358    supports_pk_autoincrement = False
359    default_paramstyle = 'pyformat'
360    supports_default_values = True
361    supports_empty_insert = False
362   
363    def __init__(self, server_side_cursors=False, **kwargs):
364        default.DefaultDialect.__init__(self, **kwargs)
365        self.server_side_cursors = server_side_cursors
366
367    def dbapi(cls):
368        import psycopg2 as psycopg
369        return psycopg
370    dbapi = classmethod(dbapi)
371
372    def create_connect_args(self, url):
373        opts = url.translate_connect_args(username='user')
374        if 'port' in opts:
375            opts['port'] = int(opts['port'])
376        opts.update(url.query)
377        return ([], opts)
378
379    def type_descriptor(self, typeobj):
380        return sqltypes.adapt_type(typeobj, colspecs)
381
382    def do_begin_twophase(self, connection, xid):
383        self.do_begin(connection.connection)
384
385    def do_prepare_twophase(self, connection, xid):
386        connection.execute(sql.text("PREPARE TRANSACTION :tid", bindparams=[sql.bindparam('tid', xid)]))
387
388    def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
389        if is_prepared:
390            if recover:
391                #FIXME: ugly hack to get out of transaction context when commiting recoverable transactions
392                # Must find out a way how to make the dbapi not open a transaction.
393                connection.execute(sql.text("ROLLBACK"))
394            connection.execute(sql.text("ROLLBACK PREPARED :tid", bindparams=[sql.bindparam('tid', xid)]))
395            connection.execute(sql.text("BEGIN"))
396            self.do_rollback(connection.connection)
397        else:
398            self.do_rollback(connection.connection)
399
400    def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
401        if is_prepared:
402            if recover:
403                connection.execute(sql.text("ROLLBACK"))
404            connection.execute(sql.text("COMMIT PREPARED :tid", bindparams=[sql.bindparam('tid', xid)]))
405            connection.execute(sql.text("BEGIN"))
406            self.do_rollback(connection.connection)
407        else:
408            self.do_commit(connection.connection)
409
410    def do_recover_twophase(self, connection):
411        resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
412        return [row[0] for row in resultset]
413
414    def get_default_schema_name(self, connection):
415        return connection.scalar("select current_schema()", None)
416    get_default_schema_name = base.connection_memoize(
417        ('dialect', 'default_schema_name'))(get_default_schema_name)
418
419    def last_inserted_ids(self):
420        if self.context.last_inserted_ids is None:
421            raise exc.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without PostgreSQL OIDs enabled")
422        else:
423            return self.context.last_inserted_ids
424
425    def has_table(self, connection, table_name, schema=None):
426        # seems like case gets folded in pg_class...
427        if schema is None:
428            cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding)});
429        else:
430            cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=%(schema)s and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding), 'schema':schema});
431        try:
432            return bool(cursor.fetchone())
433        finally:
434            cursor.close()
435
436    def has_sequence(self, connection, sequence_name, schema=None):
437        if schema is None:
438            cursor = connection.execute(
439                        sql.text("SELECT relname FROM pg_class c join pg_namespace n on "
440                            "n.oid=c.relnamespace where relkind='S' and n.nspname=current_schema() and lower(relname)=:name",
441                            bindparams=[sql.bindparam('name', unicode(sequence_name.lower()), type_=sqltypes.Unicode)]
442                        )
443                    )
444        else:
445            cursor = connection.execute(
446                        sql.text("SELECT relname FROM pg_class c join pg_namespace n on "
447                            "n.oid=c.relnamespace where relkind='S' and n.nspname=:schema and lower(relname)=:name",
448                            bindparams=[sql.bindparam('name', unicode(sequence_name.lower()), type_=sqltypes.Unicode),
449                                sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode)]
450                        )
451                    )
452       
453        try:
454            return bool(cursor.fetchone())
455        finally:
456            cursor.close()
457
458    def is_disconnect(self, e):
459        if isinstance(e, self.dbapi.OperationalError):
460            return 'closed the connection' in str(e) or 'connection not open' in str(e)
461        elif isinstance(e, self.dbapi.InterfaceError):
462            return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
463        elif isinstance(e, self.dbapi.ProgrammingError):
464            # yes, it really says "losed", not "closed"
465            return "losed the connection unexpectedly" in str(e)
466        else:
467            return False
468
469    def table_names(self, connection, schema):
470        s = """
471        SELECT relname
472        FROM pg_class c
473        WHERE relkind = 'r'
474          AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
475        """ % locals()
476        return [row[0].decode(self.encoding) for row in connection.execute(s)]
477
478    def server_version_info(self, connection):
479        v = connection.execute("select version()").scalar()
480        m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v)
481        if not m:
482            raise AssertionError("Could not determine version from string '%s'" % v)
483        return tuple([int(x) for x in m.group(1, 2, 3)])
484
485    def reflecttable(self, connection, table, include_columns):
486        preparer = self.identifier_preparer
487        if table.schema is not None:
488            schema_where_clause = "n.nspname = :schema"
489            schemaname = table.schema
490            if isinstance(schemaname, str):
491                schemaname = schemaname.decode(self.encoding)
492        else:
493            schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
494            schemaname = None
495
496        SQL_COLS = """
497            SELECT a.attname,
498              pg_catalog.format_type(a.atttypid, a.atttypmod),
499              (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d
500               WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef)
501              AS DEFAULT,
502              a.attnotnull, a.attnum, a.attrelid as table_oid
503            FROM pg_catalog.pg_attribute a
504            WHERE a.attrelid = (
505                SELECT c.oid
506                FROM pg_catalog.pg_class c
507                     LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
508                     WHERE (%s)
509                     AND c.relname = :table_name AND c.relkind in ('r','v')
510            ) AND a.attnum > 0 AND NOT a.attisdropped
511            ORDER BY a.attnum
512        """ % schema_where_clause
513
514        s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode})
515        tablename = table.name
516        if isinstance(tablename, str):
517            tablename = tablename.decode(self.encoding)
518        c = connection.execute(s, table_name=tablename, schema=schemaname)
519        rows = c.fetchall()
520
521        if not rows:
522            raise exc.NoSuchTableError(table.name)
523
524        domains = self._load_domains(connection)
525
526        for name, format_type, default, notnull, attnum, table_oid in rows:
527            if include_columns and name not in include_columns:
528                continue
529
530            ## strip (30) from character varying(30)
531            attype = re.search('([^\([]+)', format_type).group(1)
532            nullable = not notnull
533            is_array = format_type.endswith('[]')
534
535            try:
536                charlen = re.search('\(([\d,]+)\)', format_type).group(1)
537            except:
538                charlen = False
539
540            numericprec = False
541            numericscale = False
542            if attype == 'numeric':
543                if charlen is False:
544                    numericprec, numericscale = (None, None)
545                else:
546                    numericprec, numericscale = charlen.split(',')
547                charlen = False
548            elif attype == 'double precision':
549                numericprec, numericscale = (True, False)
550                charlen = False
551            elif attype == 'integer':
552                numericprec, numericscale = (32, 0)
553                charlen = False
554
555            args = []
556            for a in (charlen, numericprec, numericscale):
557                if a is None:
558                    args.append(None)
559                elif a is not False:
560                    args.append(int(a))
561
562            kwargs = {}
563            if attype == 'timestamp with time zone':
564                kwargs['timezone'] = True
565            elif attype == 'timestamp without time zone':
566                kwargs['timezone'] = False
567
568            coltype = None
569            if attype in ischema_names:
570                coltype = ischema_names[attype]
571            else:
572                if attype in domains:
573                    domain = domains[attype]
574                    if domain['attype'] in ischema_names:
575                        # A table can't override whether the domain is nullable.
576                        nullable = domain['nullable']
577
578                        if domain['default'] and not default:
579                            # It can, however, override the default value, but can't set it to null.
580                            default = domain['default']
581                        coltype = ischema_names[domain['attype']]
582
583            if coltype:
584                coltype = coltype(*args, **kwargs)
585                if is_array:
586                    coltype = PGArray(coltype)
587            else:
588                util.warn("Did not recognize type '%s' of column '%s'" %
589                          (attype, name))
590                coltype = sqltypes.NULLTYPE
591
592            colargs = []
593            if default is not None:
594                match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
595                if match is not None:
596                    # the default is related to a Sequence
597                    sch = table.schema
598                    if '.' not in match.group(2) and sch is not None:
599                        # unconditionally quote the schema name.  this could
600                        # later be enhanced to obey quoting rules / "quote schema"
601                        default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3)
602                colargs.append(schema.DefaultClause(sql.text(default)))
603            table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
604
605
606        # Primary keys
607        PK_SQL = """
608          SELECT attname FROM pg_attribute
609          WHERE attrelid = (
610             SELECT indexrelid FROM pg_index i
611             WHERE i.indrelid = :table
612             AND i.indisprimary = 't')
613          ORDER BY attnum
614        """
615        t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode})
616        c = connection.execute(t, table=table_oid)
617        for row in c.fetchall():
618            pk = row[0]
619            if pk in table.c:
620                col = table.c[pk]
621                table.primary_key.add(col)
622                if col.default is None:
623                    col.autoincrement = False
624
625        # Foreign keys
626        FK_SQL = """
627          SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
628          FROM  pg_catalog.pg_constraint r
629          WHERE r.conrelid = :table AND r.contype = 'f'
630          ORDER BY 1
631        """
632
633        t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode})
634        c = connection.execute(t, table=table_oid)
635        for conname, condef in c.fetchall():
636            m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups()
637            (constrained_columns, referred_schema, referred_table, referred_columns) = m
638            constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)]
639            if referred_schema:
640                referred_schema = preparer._unquote_identifier(referred_schema)
641            elif table.schema is not None and table.schema == self.get_default_schema_name(connection):
642                # no schema (i.e. its the default schema), and the table we're
643                # reflecting has the default schema explicit, then use that.
644                # i.e. try to use the user's conventions
645                referred_schema = table.schema
646            referred_table = preparer._unquote_identifier(referred_table)
647            referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)]
648
649            refspec = []
650            if referred_schema is not None:
651                schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
652                            autoload_with=connection)
653                for column in referred_columns:
654                    refspec.append(".".join([referred_schema, referred_table, column]))
655            else:
656                schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
657                for column in referred_columns:
658                    refspec.append(".".join([referred_table, column]))
659
660            table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True))
661
662        # Indexes
663        IDX_SQL = """
664          SELECT c.relname, i.indisunique, i.indexprs, i.indpred,
665            a.attname
666          FROM pg_index i, pg_class c, pg_attribute a
667          WHERE i.indrelid = :table AND i.indexrelid = c.oid
668            AND a.attrelid = i.indexrelid AND i.indisprimary = 'f'
669          ORDER BY c.relname, a.attnum
670        """
671        t = sql.text(IDX_SQL, typemap={'attname':sqltypes.Unicode})
672        c = connection.execute(t, table=table_oid)
673        indexes = {}
674        sv_idx_name = None
675        for row in c.fetchall():
676            idx_name, unique, expr, prd, col = row
677
678            if expr:
679                if not idx_name == sv_idx_name:
680                    util.warn(
681                      "Skipped unsupported reflection of expression-based index %s"
682                      % idx_name)
683                sv_idx_name = idx_name
684                continue
685            if prd and not idx_name == sv_idx_name:
686                util.warn(
687                   "Predicate of partial index %s ignored during reflection"
688                   % idx_name)
689                sv_idx_name = idx_name
690
691            if not indexes.has_key(idx_name):
692                indexes[idx_name] = [unique, []]
693            indexes[idx_name][1].append(col)
694
695        for name, (unique, columns) in indexes.items():
696            schema.Index(name, *[table.columns[c] for c in columns],
697                         **dict(unique=unique))
698 
699
700
701    def _load_domains(self, connection):
702        ## Load data types for domains:
703        SQL_DOMAINS = """
704            SELECT t.typname as "name",
705                   pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype",
706                   not t.typnotnull as "nullable",
707                   t.typdefault as "default",
708                   pg_catalog.pg_type_is_visible(t.oid) as "visible",
709                   n.nspname as "schema"
710            FROM pg_catalog.pg_type t
711                 LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
712                 LEFT JOIN pg_catalog.pg_constraint r ON t.oid = r.contypid
713            WHERE t.typtype = 'd'
714        """
715
716        s = sql.text(SQL_DOMAINS, typemap={'attname':sqltypes.Unicode})
717        c = connection.execute(s)
718
719        domains = {}
720        for domain in c.fetchall():
721            ## strip (30) from character varying(30)
722            attype = re.search('([^\(]+)', domain['attype']).group(1)
723            if domain['visible']:
724                # 'visible' just means whether or not the domain is in a
725                # schema that's on the search path -- or not overriden by
726                # a schema with higher presedence. If it's not visible,
727                # it will be prefixed with the schema-name when it's used.
728                name = domain['name']
729            else:
730                name = "%s.%s" % (domain['schema'], domain['name'])
731
732            domains[name] = {'attype':attype, 'nullable': domain['nullable'], 'default': domain['default']}
733
734        return domains
735
736
737class PGCompiler(compiler.DefaultCompiler):
    # Copy the default compiler's operator table, then override/add the
    # PostgreSQL-specific renderings below.
    operators = compiler.DefaultCompiler.operators.copy()
    operators.update(
        {
            # Modulo is rendered with a doubled percent sign -- presumably
            # because the dialect's default paramstyle is 'pyformat'; confirm.
            sql_operators.mod : '%%',
            # ILIKE / NOT ILIKE with an optional ESCAPE clause; the escape
            # value is placed directly between single quotes.
            sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
            sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
            # Match operator rendered as "<x> @@ to_tsquery(<y>)".
            sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y),
        }
    )

    # Copy the default compiler's function table and add a deprecated
    # 'TIMESTAMP' entry that renders "TIMESTAMP <x>".
    functions = compiler.DefaultCompiler.functions.copy()
    functions.update (
        {
            'TIMESTAMP':util.deprecated(message="Use a literal string 'timestamp <value>' instead")(lambda x:'TIMESTAMP %s' % x),
        }
    )
754
755    def visit_sequence(self, seq):
756        if seq.optional:
757            return None
758        else:
759            return "nextval('%s')" % self.preparer.format_sequence(seq)
760
761    def post_process_text(self, text):
762        if '%%' in text:
763            util.warn("The SQLAlchemy psycopg2 dialect now automatically escapes '%' in text() expressions to '%%'.")
764        return text.replace('%', '%%')
765
766    def limit_clause(self, select):
767        text = ""
768        if select._limit is not None:
769            text +=  " \n LIMIT " + str(select._limit)
770        if select._offset is not None:
771            if select._limit is None:
772                text += " \n LIMIT ALL"
773            text += " OFFSET " + str(select._offset)
774        return text
775
776    def get_select_precolumns(self, select):
777        if select._distinct:
778            if isinstance(select._distinct, bool):
779                return "DISTINCT "
780            elif isinstance(select._distinct, (list, tuple)):
781                return "DISTINCT ON (" + ', '.join(
782                    [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct]
783                )+ ") "
784            else:
785                return "DISTINCT ON (" + unicode(select._distinct) + ") "
786        else:
787            return ""
788
789    def for_update_clause(self, select):
790        if select.for_update == 'nowait':
791            return " FOR UPDATE NOWAIT"
792        else:
793            return super(PGCompiler, self).for_update_clause(select)
794
795    def _append_returning(self, text, stmt):
796        returning_cols = stmt.kwargs['postgres_returning']
797        def flatten_columnlist(collist):
798            for c in collist:
799                if isinstance(c, expression.Selectable):
800                    for co in c.columns:
801                        yield co
802                else:
803                    yield c
804        columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)]
805        text += ' RETURNING ' + string.join(columns, ', ')
806        return text
807
808    def visit_update(self, update_stmt):
809        text = super(PGCompiler, self).visit_update(update_stmt)
810        if 'postgres_returning' in update_stmt.kwargs:
811            return self._append_returning(text, update_stmt)
812        else:
813            return text
814
815    def visit_insert(self, insert_stmt):
816        text = super(PGCompiler, self).visit_insert(insert_stmt)
817        if 'postgres_returning' in insert_stmt.kwargs:
818            return self._append_returning(text, insert_stmt)
819        else:
820            return text
821
822    def visit_extract(self, extract, **kwargs):
823        field = self.extract_map.get(extract.field, extract.field)
824        return "EXTRACT(%s FROM %s::timestamp)" % (
825            field, self.process(extract.expr))
826
827
class PGSchemaGenerator(compiler.SchemaGenerator):
    """Emits PostgreSQL DDL for columns, sequences and indexes."""

    def get_column_specification(self, column, **kwargs):
        """Return the column definition fragment for CREATE TABLE.

        An autoincrementing integer (but not small-integer) primary key with
        no foreign keys and no default -- or only an optional Sequence as
        default -- is rendered as SERIAL, or BIGSERIAL for PGBigInteger.
        Every other column gets its dialect type plus an optional DEFAULT.
        ``NOT NULL`` is appended for non-nullable columns in both cases.
        """
        use_serial = (
            column.primary_key
            and len(column.foreign_keys) == 0
            and column.autoincrement
            and isinstance(column.type, sqltypes.Integer)
            and not isinstance(column.type, sqltypes.SmallInteger)
            and (
                column.default is None
                or (isinstance(column.default, schema.Sequence)
                    and column.default.optional)
            )
        )

        parts = [self.preparer.format_column(column)]
        if use_serial:
            if isinstance(column.type, PGBigInteger):
                parts.append(" BIGSERIAL")
            else:
                parts.append(" SERIAL")
        else:
            parts.append(" " + column.type.dialect_impl(self.dialect).get_col_spec())
            default = self.get_column_default_string(column)
            if default is not None:
                parts.append(" DEFAULT " + default)

        if not column.nullable:
            parts.append(" NOT NULL")
        return "".join(parts)

    def visit_sequence(self, sequence):
        """Issue CREATE SEQUENCE for a non-optional sequence.

        With ``checkfirst`` set, nothing is emitted when the dialect reports
        that the sequence already exists.
        """
        if sequence.optional:
            return
        if self.checkfirst and self.dialect.has_sequence(
                self.connection, sequence.name, schema=sequence.schema):
            return
        self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
        self.execute()

    def visit_index(self, index):
        """Issue CREATE [UNIQUE] INDEX, with an optional WHERE clause.

        The WHERE clause comes from the index's 'postgres_where' keyword
        argument.  Its bind values are substituted into the compiled text
        with plain ``%`` formatting.
        """
        preparer = self.preparer
        self.append("CREATE ")
        if index.unique:
            self.append("UNIQUE ")

        index_name = preparer.quote(
            self._validate_identifier(index.name, True), index.quote)
        table_name = preparer.format_table(index.table)
        column_list = ', '.join([preparer.format_column(c) for c in index.columns])
        self.append("INDEX %s ON %s (%s)" % (index_name, table_name, column_list))

        whereclause = index.kwargs.get('postgres_where', None)
        if whereclause is not None:
            compiled = self._compile(whereclause, None)
            # this might belong to the compiler class
            bind_values = dict(
                (key, bind.value) for key, bind in compiled.binds.iteritems())
            self.append(" WHERE " + str(compiled) % bind_values)
        self.execute()
868
class PGSchemaDropper(compiler.SchemaDropper):
    """Emits PostgreSQL DDL for dropping sequences."""

    def visit_sequence(self, sequence):
        """Issue DROP SEQUENCE for a non-optional sequence.

        With ``checkfirst`` set, nothing is emitted unless the dialect
        reports that the sequence exists.
        """
        if sequence.optional:
            return
        if self.checkfirst and not self.dialect.has_sequence(
                self.connection, sequence.name, schema=sequence.schema):
            return
        self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
        self.execute()
874
class PGDefaultRunner(base.DefaultRunner):
    """Default-value runner that pre-executes primary key defaults."""

    def __init__(self, context):
        base.DefaultRunner.__init__(self, context)
        # create a cursor which won't conflict with a server-side cursor
        raw_connection = context._connection.connection
        self.cursor = raw_connection.cursor()

    def get_column_default(self, column, isinsert=True):
        """Return the default value for ``column``.

        For primary key columns, a server-side default clause is evaluated
        with ``select <arg>``.  Failing that, an autoincrementing integer
        column with no default (or only an optional Sequence) pulls its next
        value from the ``<table>_<column>_seq`` sequence.  Everything else is
        delegated to the base class.  ``isinsert`` is accepted but not used
        here.
        """
        if column.primary_key:
            # pre-execute passive defaults on primary keys
            server_default = column.server_default
            if (isinstance(server_default, schema.DefaultClause)
                    and server_default.arg is not None):
                return self.execute_string("select %s" % server_default.arg)

            implicit_sequence = (
                isinstance(column.type, sqltypes.Integer)
                and column.autoincrement
                and (
                    column.default is None
                    or (isinstance(column.default, schema.Sequence)
                        and column.default.optional)
                )
            )
            if implicit_sequence:
                table = column.table
                sch = table.schema
                # TODO: this has to build into the Sequence object so we can
                # get the quoting logic from it
                if sch is not None:
                    stmt = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, table.name, column.name)
                else:
                    stmt = "select nextval('\"%s_%s_seq\"')" % (table.name, column.name)
                return self.execute_string(stmt.encode(self.dialect.encoding))

        return super(PGDefaultRunner, self).get_column_default(column)

    def visit_sequence(self, seq):
        """Execute ``select nextval(...)`` for ``seq``; None if it is optional."""
        if seq.optional:
            return None
        sequence_name = self.dialect.identifier_preparer.format_sequence(seq)
        return self.execute_string("select nextval('%s')" % sequence_name)
904
class PGIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer with PostgreSQL-specific unquoting."""

    def _unquote_identifier(self, value):
        """Strip surrounding quotes from ``value`` and un-double embedded ones.

        A value that does not start with ``self.initial_quote`` is returned
        unchanged.  That includes the empty string, which previously raised
        IndexError.  For a quoted value the first and last characters are
        dropped and each ``""`` pair inside becomes a single ``"``.
        """
        # Slice rather than index so an empty identifier is handled safely.
        if value[:1] == self.initial_quote:
            value = value[1:-1].replace('""', '"')
        return value
910
# Attach the PostgreSQL-specific helper classes to PGDialect as class
# attributes, then expose the dialect class under the module-level name
# ``dialect``.
PGDialect.statement_compiler = PGCompiler
PGDialect.schemagenerator = PGSchemaGenerator
PGDialect.schemadropper = PGSchemaDropper
PGDialect.preparer = PGIdentifierPreparer
PGDialect.defaultrunner = PGDefaultRunner
PGDialect.execution_ctx_cls = PGExecutionContext
dialect = PGDialect
Note: リポジトリブラウザについてのヘルプは TracBrowser を参照してください。