root/galaxy-central/eggs/SQLAlchemy-0.5.6_dev_r6498-py2.6.egg/sqlalchemy/sql/util.py @ 3

リビジョン 3, 19.2 KB (コミッタ: kohda, 14 年 前)

Install Unix tools  http://hannonlab.cshl.edu/galaxy_unix_tools/galaxy.html

行番号 
"""Utility functions that build upon SQL and Schema constructs."""

from sqlalchemy import exc, schema, topological, util, sql
from sqlalchemy.sql import expression, operators, visitors
from itertools import chain

def sort_tables(tables):
    """Sort a collection of Table objects in order of their foreign-key dependency.

    Every foreign key found while traversing each table contributes a
    ``(referenced_table, referencing_table)`` edge, provided the referenced
    table is itself part of ``tables`` and the key is not flagged
    ``use_alter``.  The edges and the table list are handed to
    ``topological.sort``.

    :param tables: an iterable of Table objects; it is consumed once.
    :returns: the result of ``topological.sort(tuples, tables)``.
    """
    tables = list(tables)
    # set for O(1) membership checks; the list is kept for its ordering.
    table_set = set(tables)
    tuples = []

    def visit_foreign_key(fkey):
        # foreign keys flagged use_alter never contribute a dependency edge.
        if fkey.use_alter:
            return
        parent_table = fkey.column.table
        if parent_table in table_set:
            child_table = fkey.parent.table
            tuples.append((parent_table, child_table))

    traverse_options = {'schema_visitor': True}
    visitor_map = {'foreign_key': visit_foreign_key}
    for table in tables:
        visitors.traverse(table, traverse_options, visitor_map)
    return topological.sort(tuples, tables)

def find_join_source(clauses, join_to):
    """Given a list of FROM clauses and a selectable,
    return the first index and element from the list of
    clauses which can be joined against the selectable.  Returns
    ``(None, None)`` if no match is found.

    A clause matches when it is derived from any of the FROM objects
    of ``join_to``.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        find_join_source([clause1, clause2], join_to) == (0, clause1)

    """
    selectables = list(expression._from_objects(join_to))
    for i, f in enumerate(clauses):
        for s in selectables:
            if f.is_derived_from(s):
                return i, f
    return None, None


def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False):
    """Locate Table objects within the given expression.

    ``clause`` is traversed without descending into column collections.
    Every element visited as a table is appended to the result, in
    visitation order.  An element visited more than once appears more
    than once.

    :param check_columns: also append ``column.table`` for every column visited.
    :param include_aliases: also append elements visited as 'alias'.
    :param include_joins: also append elements visited as 'join'.
    :param include_selects: also append elements visited as 'select'
        or 'compound_select'.
    :returns: a list.
    """
    tables = []
    _visitors = {}

    if include_selects:
        _visitors['select'] = _visitors['compound_select'] = tables.append

    if include_joins:
        _visitors['join'] = tables.append

    if include_aliases:
        _visitors['alias'] = tables.append

    if check_columns:
        def visit_column(column):
            tables.append(column.table)
        _visitors['column'] = visit_column

    _visitors['table'] = tables.append

    visitors.traverse(clause, {'column_collections': False}, _visitors)
    return tables

def find_columns(clause):
    """Locate Column objects within the given expression.

    :returns: a ``util.column_set`` of every element visited as a column.
    """
    cols = util.column_set()
    visitors.traverse(clause, {}, {'column': cols.add})
    return cols

def join_condition(a, b, ignore_nonexistent_tables=False):
    """Create a join condition between two tables.

    Each foreign key on ``b`` that references a column of ``a`` contributes
    a ``referenced_column == fk.parent`` criterion.  When ``a`` is not ``b``,
    each foreign key on ``a`` that references a column of ``b`` contributes
    one as well.

    ignore_nonexistent_tables=True allows a join condition to be
    determined between two tables which may contain references to
    other not-yet-defined tables.  In general the NoReferencedTableError
    raised is only required if the user is trying to join selectables
    across multiple MetaData objects (which is an extremely rare use
    case).

    :returns: the single criterion, or ``sql.and_()`` of all of them.
    :raises exc.ArgumentError: if no criterion is found, or if the criteria
        come from more than one foreign key constraint.
    :raises exc.NoReferencedTableError: when a foreign key cannot be resolved
        and ``ignore_nonexistent_tables`` is False.
    """
    crit = []
    constraints = set()

    def collect(source, target):
        # add a criterion for every FK on ``source`` referring to ``target``.
        for fk in source.foreign_keys:
            try:
                col = fk.get_referent(target)
            except exc.NoReferencedTableError:
                if ignore_nonexistent_tables:
                    continue
                raise
            if col is not None:
                crit.append(col == fk.parent)
                constraints.add(fk.constraint)

    collect(b, a)
    if a is not b:
        collect(a, b)

    if not crit:
        raise exc.ArgumentError(
            "Can't find any foreign key relationships "
            "between '%s' and '%s'" % (a.description, b.description))
    elif len(constraints) > 1:
        raise exc.ArgumentError(
            "Can't determine join between '%s' and '%s'; "
            "tables have more than one foreign key "
            "constraint relationship between them. "
            "Please specify the 'onclause' of this "
            "join explicitly." % (a.description, b.description))
    elif len(crit) == 1:
        return crit[0]
    else:
        return sql.and_(*crit)


class Annotated(object):
    """Clones a ClauseElement and applies an 'annotations' dictionary.

    Unlike regular clones, this clone also mimics __hash__() and
    __cmp__() of the original element so that it takes its place
    in hashed collections.

    A reference to the original element is maintained, for the important
    reason of keeping its hash value current.  When GC'ed, the
    hash value may be reused, causing conflicts.

    Instances are always of a subclass that inherits from both
    ``Annotated`` and the element's own class.  ``__new__`` picks that
    subclass out of the ``annotated_classes`` registry.

    """

    def __new__(cls, *args):
        if not args:
            # clone constructor: _annotate() calls __new__ with no
            # arguments and then fills in __dict__ itself.
            return object.__new__(cls)

        element, values = args
        # pull the appropriate subclass from the registry of annotated
        # classes, creating and caching one on the fly if the element's
        # class has no entry yet.
        try:
            cls = annotated_classes[element.__class__]
        except KeyError:
            cls = annotated_classes[element.__class__] = type.__new__(
                type,
                "Annotated%s" % element.__class__.__name__,
                (Annotated, element.__class__), {})
        return object.__new__(cls)

    def __init__(self, element, values):
        # force FromClause to generate their internal
        # collections into __dict__ before that __dict__ is copied below.
        if isinstance(element, expression.FromClause):
            element.c

        self.__dict__ = element.__dict__.copy()
        self.__element = element
        self._annotations = values

    def _annotate(self, values):
        """Return a copy of this object whose annotations are this
        object's annotations updated with ``values``."""
        _values = self._annotations.copy()
        _values.update(values)
        clone = self.__class__.__new__(self.__class__)
        clone.__dict__ = self.__dict__.copy()
        clone._annotations = _values
        return clone

    def _deannotate(self):
        """Return the original, un-annotated element."""
        return self.__element

    def _clone(self):
        clone = self.__element._clone()
        if clone is self.__element:
            # detect immutable, don't change anything
            return self
        else:
            # update the clone with any changes that have occurred
            # to this object's __dict__.
            clone.__dict__.update(self.__dict__)
            return Annotated(clone, self._annotations)

    def __hash__(self):
        # hash as the original element so this object can stand in for it.
        return hash(self.__element)

    def __cmp__(self, other):
        return cmp(hash(self.__element), hash(other))

# hard-generate Annotated subclasses: one for every ClauseElement subclass
# found in the ``expression`` module namespace, plus schema.Column and
# schema.Table.  Each generated class is bound to a module-level name
# (``Annotated<ClassName>``) with ``__module__`` set to this module.
# pickle locates classes by module + name, so the resulting objects are
# pickleable.  The classes Annotated.__new__ creates lazily for other
# element types are not bound this way.
annotated_classes = {}

for cls in list(expression.__dict__.values()) + [schema.Column, schema.Table]:
    if isinstance(cls, type) and issubclass(cls, expression.ClauseElement):
        # three-argument type() resolves the metaclass from the bases exactly
        # as a ``class`` statement would.
        annotated_cls = type(
            "Annotated%s" % cls.__name__,
            (Annotated, cls),
            {'__visit_name__': cls.__visit_name__, '__module__': __name__})
        globals()[annotated_cls.__name__] = annotated_cls
        annotated_classes[cls] = annotated_cls

def _deep_annotate(element, annotations, exclude=None):
    """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary.

    Elements within the exclude collection will be cloned but not annotated.
    An element counts as excluded when any member of its ``proxy_set``
    is in ``exclude``.  Elements whose ``_annotations`` already equal
    ``annotations`` are neither re-annotated nor cloned at that level,
    but their internals are still copied recursively.

    :returns: the copied element, or None if ``element`` is None.
    """
    def clone(elem):
        # check if element is present in the exclude list.
        # take into account proxying relationships.
        if exclude and elem.proxy_set.intersection(exclude):
            elem = elem._clone()
        elif annotations != elem._annotations:
            # each element gets its own copy of the annotations dict.
            elem = elem._annotate(annotations.copy())
        elem._copy_internals(clone=clone)
        return elem

    if element is not None:
        element = clone(element)
    return element

def _deep_deannotate(element):
    """Deep copy the given element, removing all annotations.

    Each element is replaced by ``elem._deannotate()``, and its internals
    are copied recursively with the same replacement.

    :returns: the copied element, or None if ``element`` is None.
    """
    def clone(elem):
        elem = elem._deannotate()
        elem._copy_internals(clone=clone)
        return elem

    if element is not None:
        element = clone(element)
    return element


def splice_joins(left, right, stop_on=None):
    """Adapt the join chain ``right`` so that it is expressed against ``left``.

    Returns ``right`` unchanged when ``left`` is None.  Otherwise it walks
    down the left-hand spine of ``right``:

    - every ``Join`` encountered (other than ``stop_on``) is cloned, and its
      ON clause is adapted with ``ClauseAdapter(left)``;
    - the first element that is not a Join, or is ``stop_on``, is adapted
      as a whole.

    Each adapted element is re-attached as the ``left`` of the cloned join
    above it.  The adapted top-level element is returned.
    """
    if left is None:
        return right

    stack = [(right, None)]

    adapter = ClauseAdapter(left)
    ret = None
    while stack:
        (right, prevright) = stack.pop()
        if isinstance(right, expression.Join) and right is not stop_on:
            right = right._clone()
            right._reset_exported()
            right.onclause = adapter.traverse(right.onclause)
            stack.append((right.left, right))
        else:
            right = adapter.traverse(right)
        # compare against None explicitly rather than relying on the
        # truthiness of clause objects.
        if prevright is not None:
            prevright.left = right
        if ret is None:
            ret = right

    return ret

def reduce_columns(columns, *clauses, **kw):
    r"""Given a list of columns, return a 'reduced' set based on natural equivalents.

    The set is reduced to the smallest list of columns which have no natural
    equivalent present in the list.  A "natural equivalent" means that two columns
    will ultimately represent the same value because they are related by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured.

    This function is primarily used to determine the most minimal "primary key"
    from a selectable, by reducing the set of primary key columns present
    in the selectable to just those that are not repeated.

    :returns: an ``expression.ColumnSet`` of the remaining columns.
    :raises exc.NoReferencedTableError: when a foreign key cannot be resolved
        and 'ignore_nonexistent_tables' is not set.
    """
    ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False)

    columns = util.ordered_column_set(columns)

    omit = util.column_set()
    for col in columns:
        # a column is omitted when one of its foreign keys (or those of any
        # column it proxies) points at another column in the list.
        for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
            for c in columns:
                if c is col:
                    continue
                try:
                    fk_col = fk.column
                except exc.NoReferencedTableError:
                    if ignore_nonexistent_tables:
                        continue
                    raise
                if fk_col.shares_lineage(c):
                    omit.add(col)
                    break

    if clauses:
        def visit_binary(binary):
            if binary.operator == operators.eq:
                # recomputed per binary because ``omit`` grows as we go.
                cols = util.column_set(chain(*[c.proxy_set for c in columns.difference(omit)]))
                if binary.left in cols and binary.right in cols:
                    for c in columns:
                        if c.shares_lineage(binary.right):
                            omit.add(c)
                            break
        for clause in clauses:
            visitors.traverse(clause, {}, {'binary': visit_binary})

    return expression.ColumnSet(columns.difference(omit))

def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_referenced_keys=None, any_operator=False):
    """Traverse an expression and locate binary criterion pairs.

    Only binary expressions whose operands are both ColumnElements are
    considered.  Unless ``any_operator`` is true, the operator must also
    be ``==``.  Each pair is returned as ``(referenced, referencing)``:

    - with ``consider_as_foreign_keys``, the operand found in that
      collection is the second (referencing) element;
    - with ``consider_as_referenced_keys``, the operand found in that
      collection is the first (referenced) element;
    - otherwise, both operands must be ``schema.Column`` objects and the
      pair is oriented by ``Column.references()``.

    :raises exc.ArgumentError: if both ``consider_as_*`` collections are given.
    :returns: a list of 2-tuples.
    """
    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'")

    pairs = []

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return
        if not isinstance(binary.left, sql.ColumnElement) or not isinstance(binary.right, sql.ColumnElement):
            return

        if consider_as_foreign_keys:
            if binary.left in consider_as_foreign_keys and (binary.right is binary.left or binary.right not in consider_as_foreign_keys):
                pairs.append((binary.right, binary.left))
            elif binary.right in consider_as_foreign_keys and (binary.left is binary.right or binary.left not in consider_as_foreign_keys):
                pairs.append((binary.left, binary.right))
        elif consider_as_referenced_keys:
            if binary.left in consider_as_referenced_keys and (binary.right is binary.left or binary.right not in consider_as_referenced_keys):
                pairs.append((binary.left, binary.right))
            elif binary.right in consider_as_referenced_keys and (binary.left is binary.right or binary.left not in consider_as_referenced_keys):
                pairs.append((binary.right, binary.left))
        else:
            if isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column):
                if binary.left.references(binary.right):
                    pairs.append((binary.right, binary.left))
                elif binary.right.references(binary.left):
                    pairs.append((binary.left, binary.right))

    visitors.traverse(expression, {}, {'binary': visit_binary})
    return pairs

def folded_equivalents(join, equivs=None):
    """Return a list of uniquely named columns.

    The column list of the given Join will be narrowed
    down to a list of all equivalently-named,
    equated columns folded into one column, where 'equated' means they are
    equated to each other in the ON clause of this join.

    Nested Joins on either side are folded recursively.  They share the
    same ``equivs`` set, which this function mutates.

    This function is used by Join.select(fold_equivalents=True).

    Deprecated.   This function is used for a certain kind of
    "polymorphic_union" which is designed to achieve joined
    table inheritance where the base table has no "discriminator"
    column; [ticket:1131] will provide a better way to
    achieve this.

    """
    if equivs is None:
        equivs = set()

    def visit_binary(binary):
        # record both sides of every ``a == b`` where the names match.
        if binary.operator == operators.eq and binary.left.name == binary.right.name:
            equivs.add(binary.right)
            equivs.add(binary.left)

    visitors.traverse(join.onclause, {}, {'binary': visit_binary})

    if isinstance(join.left, expression.Join):
        left = folded_equivalents(join.left, equivs)
    else:
        left = list(join.left.columns)
    if isinstance(join.right, expression.Join):
        right = folded_equivalents(join.right, equivs)
    else:
        right = list(join.right.columns)

    collist = []
    used = set()
    for c in left + right:
        if c in equivs:
            # keep only the first equated column of each name.
            if c.name not in used:
                collist.append(c)
                used.add(c.name)
        else:
            collist.append(c)
    return collist

class AliasedRow(object):
    """Wrap a RowProxy with a translation map.

    This object allows a set of keys to be translated
    to those present in a RowProxy: a lookup for ``key`` is
    forwarded to the wrapped row as ``row[map[key]]``.

    """
    def __init__(self, row, map):
        # AliasedRow objects don't nest, so un-nest
        # if another AliasedRow was passed (its own map is discarded).
        if isinstance(row, AliasedRow):
            self.row = row.row
        else:
            self.row = row
        self.map = map

    def __contains__(self, key):
        return self.map[key] in self.row

    def has_key(self, key):
        return key in self

    def __getitem__(self, key):
        return self.row[self.map[key]]

    def keys(self):
        # NOTE: these are the wrapped row's own, untranslated keys.
        return self.row.keys()


class ClauseAdapter(visitors.ReplacingCloningVisitor):
    """Clones and modifies clauses based on column correspondence.

    E.g.::

      table1 = Table('sometable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )
      table2 = Table('someothertable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )

      condition = table1.c.col1 == table2.c.col1

    make an alias of table1::

      s = table1.alias('foo')

    calling ``ClauseAdapter(s).traverse(condition)`` converts
    condition to read::

      s.c.col1 == table2.c.col1

    """
    def __init__(self, selectable, equivalents=None, include=None, exclude=None):
        self.__traverse_options__ = {'column_collections': False, 'stop_on': [selectable]}
        self.selectable = selectable
        self.include = include
        self.exclude = exclude
        # maps a column to a collection of columns considered equivalent to it;
        # consulted when the selectable has no direct correspondence.
        self.equivalents = util.column_dict(equivalents or {})

    def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET):
        """Return the selectable's column corresponding to ``col``.

        If there is none, try the columns listed as equivalent to ``col``,
        recursively; ``_seen`` guards against cycles.  Returns None if
        nothing corresponds.
        """
        newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded)

        if newcol is None and col in self.equivalents and col not in _seen:
            for equiv in self.equivalents[col]:
                newcol = self._corresponding_column(equiv, require_embedded=require_embedded, _seen=_seen.union([col]))
                if newcol is not None:
                    return newcol
        return newcol

    def replace(self, col):
        """Return the replacement for ``col``, or None.

        A FromClause that the selectable is derived from is replaced by the
        selectable itself.  Non-column elements, and columns rejected by
        ``include``/``exclude``, yield None.  Other columns are mapped via
        ``_corresponding_column(col, True)``.
        """
        if isinstance(col, expression.FromClause):
            if self.selectable.is_derived_from(col):
                return self.selectable

        if not isinstance(col, expression.ColumnElement):
            return None

        if self.include and col not in self.include:
            return None
        elif self.exclude and col in self.exclude:
            return None

        return self._corresponding_column(col, True)

class ColumnAdapter(ClauseAdapter):
    """Extends ClauseAdapter with extra utility functions.

    Provides the ability to "wrap" this ClauseAdapter
    around another, a columns dictionary which returns
    cached, adapted elements given an original, and an
    adapted_row() factory.

    """
    def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None, adapt_required=False):
        ClauseAdapter.__init__(self, selectable, equivalents, include, exclude)
        if chain_to:
            self.chain(chain_to)
        # dictionary of original element -> adapted element, whose entries
        # are produced by _locate_col.
        self.columns = util.populate_column_dict(self._locate_col)
        self.adapt_required = adapt_required

    def wrap(self, adapter):
        """Return a copy of this adapter whose column lookups, adapt_clause()
        and adapt_list() apply this adapter first and then ``adapter``."""
        ac = self.__class__.__new__(self.__class__)
        ac.__dict__ = self.__dict__.copy()
        ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col)
        ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause)
        ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list)
        ac.columns = util.populate_column_dict(ac._locate_col)
        return ac

    adapt_clause = ClauseAdapter.traverse
    adapt_list = ClauseAdapter.copy_and_process

    def _wrap(self, local, wrapped):
        def locate(col):
            col = local(col)
            return wrapped(col)
        return locate

    def _locate_col(self, col):
        c = self._corresponding_column(col, False)
        if c is None:
            c = self.adapt_clause(col)

            # anonymize labels in case they have a hardcoded name
            if isinstance(c, expression._Label):
                c = c.label(None)

        # adapt_required indicates that if we got the same column
        # back which we put in (i.e. it passed through),
        # it's not correct.  this is used by eagerloading which
        # knows that all columns and expressions need to be adapted
        # to a result row, and a "passthrough" is definitely targeting
        # the wrong column.
        if self.adapt_required and c is col:
            return None

        return c

    def adapted_row(self, row):
        return AliasedRow(row, self.columns)

    def __getstate__(self):
        # drop the 'columns' dictionary; __setstate__ rebuilds it.
        d = self.__dict__.copy()
        del d['columns']
        return d

    def __setstate__(self, state):
        self.__dict__.update(state)
        # use the same dictionary factory as __init__ and wrap().
        self.columns = util.populate_column_dict(self._locate_col)
Note: リポジトリブラウザについてのヘルプは TracBrowser を参照してください。