root/galaxy-central/eggs/SQLAlchemy-0.5.6_dev_r6498-py2.6.egg/sqlalchemy/orm/util.py @ 3

リビジョン 3, 21.5 KB (コミッタ: kohda, 14 年 前)

Install Unix tools  http://hannonlab.cshl.edu/galaxy_unix_tools/galaxy.html

行番号 
1# mapper/util.py
2# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
3#
4# This module is part of SQLAlchemy and is released under
5# the MIT License: http://www.opensource.org/licenses/mit-license.php
6
7import sqlalchemy.exceptions as sa_exc
8from sqlalchemy import sql, util
9from sqlalchemy.sql import expression, util as sql_util, operators
10from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, MapperProperty, AttributeExtension
11from sqlalchemy.orm import attributes, exc
12
13
14all_cascades = frozenset(("delete", "delete-orphan", "all", "merge",
15                          "expunge", "save-update", "refresh-expire",
16                          "none"))
17
18_INSTRUMENTOR = ('mapper', 'instrumentor')
19
20class CascadeOptions(object):
21    """Keeps track of the options sent to relation().cascade"""
22
23    def __init__(self, arg=""):
24        if not arg:
25            values = set()
26        else:
27            values = set(c.strip() for c in arg.split(','))
28        self.delete_orphan = "delete-orphan" in values
29        self.delete = "delete" in values or "all" in values
30        self.save_update = "save-update" in values or "all" in values
31        self.merge = "merge" in values or "all" in values
32        self.expunge = "expunge" in values or "all" in values
33        self.refresh_expire = "refresh-expire" in values or "all" in values
34
35        if self.delete_orphan and not self.delete:
36            util.warn("The 'delete-orphan' cascade option requires "
37                        "'delete'.  This will raise an error in 0.6.")
38
39        for x in values:
40            if x not in all_cascades:
41                raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x)
42
43    def __contains__(self, item):
44        return getattr(self, item.replace("-", "_"), False)
45
46    def __repr__(self):
47        return "CascadeOptions(%s)" % repr(",".join(
48            [x for x in ['delete', 'save_update', 'merge', 'expunge',
49                         'delete_orphan', 'refresh-expire']
50             if getattr(self, x, False) is True]))
51
52
class Validator(AttributeExtension):
    """Attribute extension that routes incoming values through a validator.

    Instances are produced by the :func:`~sqlalchemy.orm.validates`
    decorator; application code seldom needs to create one directly.

    """

    def __init__(self, key, validator):
        """Build a Validator for a single attribute.

            key - the attribute name; it is passed to ``validator`` as the
            second positional argument (the first is the object instance).

            validator - a function or instance method called as
            ``validator(instance, key, value)``.  Its return value replaces
            the incoming value, so it should hand back ``value`` unchanged
            unless it intends to modify it.

        """
        self.validator = validator
        self.key = key

    def append(self, state, value, initiator):
        """Run the validator on a value being appended."""
        instance = state.obj()
        return self.validator(instance, self.key, value)

    def set(self, state, value, oldvalue, initiator):
        """Run the validator on a value being set; ``oldvalue`` is ignored."""
        instance = state.obj()
        return self.validator(instance, self.key, value)
82
def polymorphic_union(table_map, typecolname, aliasname='p_union'):
    """Create a ``UNION`` statement used by a polymorphic mapper.

    See  :ref:`concrete_inheritance` for an example of how
    this is used.

    table_map
      dictionary mapping each polymorphic identity to a Table (or Select).
      Every entry becomes one SELECT of the resulting UNION ALL.  The
      dictionary itself is not modified.

    typecolname
      name of a literal discriminator column holding each entry's key, or
      None to omit the discriminator column.

    aliasname
      name of the alias wrapping the UNION ALL.

    Columns absent from a given table are filled in with a NULL cast to
    the column's type, so every SELECT lists the same columns in the same,
    deterministic (first-seen) order.
    """
    colnames = []        # union of all column names, in first-seen order
    seen = set()
    colnamemaps = {}     # selectable -> {column name: column}
    types = {}           # column name -> type, used for NULL fill-ins
    selectables = []     # (identity, selectable) pairs, in table_map order

    for key, table in table_map.items():
        # mysql doesn't like selecting from a select; make it an alias of
        # the select.  The alias is kept locally so the caller's dict is
        # left untouched.
        if isinstance(table, sql.Select):
            table = table.alias()
        selectables.append((key, table))

        m = {}
        for c in table.c:
            if c.key not in seen:
                seen.add(c.key)
                colnames.append(c.key)
            m[c.key] = c
            types[c.key] = c.type
        colnamemaps[table] = m

    def col(name, table):
        try:
            return colnamemaps[table][name]
        except KeyError:
            return sql.cast(sql.null(), types[name]).label(name)

    result = []
    for identity, table in selectables:
        columns = [col(name, table) for name in colnames]
        if typecolname is not None:
            columns.append(
                sql.literal_column("'%s'" % identity).label(typecolname))
        result.append(sql.select(columns, from_obj=[table]))
    return sql.union_all(*result).alias(aliasname)
124
def identity_key(*args, **kwargs):
    """Get an identity key.

    Valid call signatures:

    * ``identity_key(class, ident)``

      class
          mapped class (must be a positional argument)

      ident
          primary key, if the key is composite this is a tuple

    * ``identity_key(class, ident=ident)``

      as above, with ``ident`` given as a keyword argument.

    * ``identity_key(instance=instance)``

      instance
          object instance (must be given as a keyword arg)

    * ``identity_key(class, row=row)``

      class
          mapped class (must be a positional argument)

      row
          result proxy row (must be given as a keyword arg)

    Raises ArgumentError when more than two positional arguments or any
    unrecognized keyword arguments are passed.  A missing required keyword
    (``ident``/``row`` alongside a lone class, or ``instance`` when no
    positional arguments are given) surfaces as a KeyError.

    """
    if args:
        from_row = False
        if len(args) == 1:
            class_ = args[0]
            if "row" in kwargs:
                row = kwargs.pop("row")
                from_row = True
            else:
                ident = kwargs.pop("ident")
        elif len(args) == 2:
            class_, ident = args
        else:
            # A third positional argument has no defined meaning; it used
            # to reach a two-name tuple unpack and fail with ValueError.
            raise sa_exc.ArgumentError("expected one or two "
                "positional arguments, got %s" % len(args))
        if kwargs:
            raise sa_exc.ArgumentError("unknown keyword arguments: %s"
                % ", ".join(kwargs.keys()))
        mapper = class_mapper(class_)
        if from_row:
            return mapper.identity_key_from_row(row)
        return mapper.identity_key_from_primary_key(ident)

    instance = kwargs.pop("instance")
    if kwargs:
        raise sa_exc.ArgumentError("unknown keyword arguments: %s"
            % ", ".join(kwargs.keys()))
    mapper = object_mapper(instance)
    return mapper.identity_key_from_instance(instance)
180
class ExtensionCarrier(dict):
    """Fronts an ordered collection of MapperExtension objects.

    Bundles multiple MapperExtensions into a unified callable unit,
    encapsulating ordering, looping and EXT_CONTINUE logic.  The
    ExtensionCarrier implements the MapperExtension interface, e.g.::

      carrier.after_insert(...args...)

    The dictionary interface provides containment for implemented
    method names mapped to a callable which executes that method
    for participating extensions.  A method name becomes a key only once
    some registered extension overrides the MapperExtension default.
    Iterating the carrier yields the extensions, not the keys.

    """

    # Public method names making up the MapperExtension interface.
    interface = set(method for method in dir(MapperExtension)
                    if not method.startswith('_'))

    def __init__(self, extensions=None):
        self._extensions = []
        for ext in extensions or ():
            self.append(ext)

    def copy(self):
        """Return a new carrier over the same extensions, in the same order."""
        return ExtensionCarrier(self._extensions)

    def push(self, extension):
        """Insert a MapperExtension at the beginning of the collection."""
        self._register(extension)
        self._extensions.insert(0, extension)

    def append(self, extension):
        """Append a MapperExtension at the end of the collection."""
        self._register(extension)
        self._extensions.append(extension)

    def __iter__(self):
        """Iterate over MapperExtensions in the collection."""
        return iter(self._extensions)

    @staticmethod
    def _function_of(method):
        """Return the plain function behind a (bound or unbound) method.

        Objects that are not method wrappers are returned unchanged.
        """
        return getattr(method, 'im_func', getattr(method, '__func__', method))

    def _register(self, extension):
        """Register callable fronts for overridden interface methods."""

        # self.keys(), not self: __iter__ is overridden to yield extensions,
        # so differencing against self would never exclude a method name.
        for method in self.interface.difference(self.keys()):
            impl = getattr(extension, method, None)
            if not impl:
                continue
            # Compare underlying functions.  Attribute access creates a new
            # method object each time, so an identity test on the method
            # objects never matches and would treat every inherited
            # default as an override.
            default = getattr(MapperExtension, method)
            if self._function_of(impl) is not self._function_of(default):
                self[method] = self._create_do(method)

    def _create_do(self, method):
        """Return a closure that loops over impls of the named method.

        The first extension returning something other than EXT_CONTINUE
        wins; if all of them continue, EXT_CONTINUE is returned.
        """

        def _do(*args, **kwargs):
            for ext in self._extensions:
                ret = getattr(ext, method)(*args, **kwargs)
                if ret is not EXT_CONTINUE:
                    return ret
            return EXT_CONTINUE
        _do.__name__ = method
        return _do

    @staticmethod
    def _pass(*args, **kwargs):
        return EXT_CONTINUE

    def __getattr__(self, key):
        """Delegate MapperExtension methods to bundled fronts.

        Interface methods that no extension overrides resolve to a no-op
        returning EXT_CONTINUE; any other missing name raises
        AttributeError.
        """

        if key not in self.interface:
            raise AttributeError(key)
        return self.get(key, self._pass)
252
class ORMAdapter(sql_util.ColumnAdapter):
    """ColumnAdapter that also understands ORM entities.

    ``entity`` may be a mapped class, a mapper, an AliasedClass or a plain
    selectable; the selectable adapted against is derived from it via
    ``_entity_info``.  When ``entity`` is an AliasedClass it is kept on
    ``self.aliased_class``; otherwise that attribute is None.

    """
    def __init__(self, entity, equivalents=None, chain_to=None, adapt_required=False):
        mapper, selectable, is_aliased_class = _entity_info(entity)
        self.mapper = mapper
        self.aliased_class = entity if is_aliased_class else None
        sql_util.ColumnAdapter.__init__(
            self, selectable, equivalents, chain_to,
            adapt_required=adapt_required)

    def replace(self, elem):
        parent = elem._annotations.get('parentmapper', None)
        if parent and not parent.isa(self.mapper):
            # annotated with a mapper for which isa(self.mapper) is false:
            # leave the element alone
            return None
        return sql_util.ColumnAdapter.replace(self, elem)
274
class AliasedClass(object):
    """Represents an 'alias'ed form of a mapped class for usage with Query.

    The ORM equivalent of a :class:`~sqlalchemy.sql.expression.Alias`
    object, this object mimics the mapped class using a
    __getattr__ scheme and maintains a reference to a
    real Alias object.   It indicates to Query that the
    selectable produced for this class should be aliased,
    and also adapts PropComparators produced by the class'
    InstrumentedAttributes so that they adapt the
    "local" side of SQL expressions against the alias.

    Note: the double-underscore attributes are name-mangled to
    ``_AliasedClass__mapper`` / ``_AliasedClass__alias`` etc., and other
    module-level helpers read them under those mangled names.

    """
    def __init__(self, cls, alias=None, name=None):
        """cls may be a mapped class, mapper or AliasedClass; alias defaults
        to a fresh alias of the mapper's _with_polymorphic_selectable."""
        self.__mapper = _class_to_mapper(cls)
        self.__target = self.__mapper.class_
        alias = alias or self.__mapper._with_polymorphic_selectable.alias()
        # rewrites expressions against the mapped selectable so that they
        # reference the alias instead
        self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
        self.__alias = alias
        # used to assign a name to the RowTuple object
        # returned by Query.
        self._sa_label_name = name
        self.__name__ = 'AliasedClass_' + str(self.__target)

    def __getstate__(self):
        # Only mapper, alias and name are pickled; the adapter is rebuilt
        # from them in __setstate__.
        return {'mapper':self.__mapper, 'alias':self.__alias, 'name':self._sa_label_name}

    def __setstate__(self, state):
        # Mirrors __init__, using the pickled alias instead of creating one.
        self.__mapper = state['mapper']
        self.__target = self.__mapper.class_
        alias = state['alias']
        self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
        self.__alias = alias
        name = state['name']
        self._sa_label_name = name
        self.__name__ = 'AliasedClass_' + str(self.__target)

    def __adapt_element(self, elem):
        # Adapt elem to the alias, then annotate it with this AliasedClass
        # and its mapper as the parent entity/mapper.
        return self.__adapter.traverse(elem)._annotate({'parententity': self, 'parentmapper':self.__mapper})

    def __adapt_prop(self, prop):
        # Build a QueryableAttribute that shares the target class attribute's
        # impl but whose comparator adapts expressions to the alias.
        existing = getattr(self.__target, prop.key)
        comparator = existing.comparator.adapted(self.__adapt_element)

        queryattr = attributes.QueryableAttribute(prop.key,
            impl=existing.impl, parententity=self, comparator=comparator)
        # Cache on the instance: later lookups of this key find the instance
        # attribute directly and no longer reach __getattr__.
        setattr(self, prop.key, queryattr)
        return queryattr

    def __getattr__(self, key):
        """Resolve attributes not found on the instance.

        Mapped properties come back adapted to the alias; other attributes
        are looked up on the target class's MRO and handled below.
        """
        # NOTE(review): reading self.__mapper here re-enters __getattr__ if
        # it has not been set yet (before __init__/__setstate__ has run).
        prop = self.__mapper._get_property(key, raiseerr=False)
        if prop:
            return self.__adapt_prop(prop)

        # Find the raw attribute as stored on the first class in the MRO
        # that has it.
        for base in self.__target.__mro__:
            try:
                attr = object.__getattribute__(base, key)
            except AttributeError:
                continue
            else:
                break
        else:
            raise AttributeError(key)

        if hasattr(attr, 'func_code'):
            # Function-like attribute.  If the target class yields it as a
            # bound method (im_self set), rebind the function to this
            # AliasedClass; otherwise None is returned.
            # NOTE(review): attr.im_func only exists when attr is itself a
            # method object; confirm a plain function cannot reach that line.
            is_method = getattr(self.__target, key, None)
            if is_method and is_method.im_self is not None:
                return util.types.MethodType(attr.im_func, self, self)
            else:
                return None
        elif hasattr(attr, '__get__'):
            # Other descriptors are invoked with no instance and this
            # AliasedClass as the owner.
            return attr.__get__(None, self)
        else:
            return attr

    def __repr__(self):
        return '<AliasedClass at 0x%x; %s>' % (
            id(self), self.__target.__name__)
353
def _orm_annotate(element, exclude=None):
    """Return a deep copy of ``element`` with every element flagged ``_orm_adapt``.

    Elements contained in ``exclude`` are still cloned, but are left
    without the annotation.

    """
    orm_flag = {'_orm_adapt': True}
    return sql_util._deep_annotate(element, orm_flag, exclude)

# Module-level alias of sql_util._deep_deannotate (presumably the inverse
# of the annotation above).
_orm_deannotate = sql_util._deep_deannotate
363
class _ORMJoin(expression.Join):
    """Extend Join to support ORM constructs as input.

    ``left`` and ``right`` may be mapped classes, mappers, AliasedClass
    objects or plain selectables.  ``onclause`` may be a relation name
    (string), a class-bound attribute, a MapperProperty, or a SQL
    expression used as is.

    """

    __visit_name__ = expression.Join.__visit_name__

    def __init__(self, left, right, onclause=None, isouter=False, join_to_left=True):
        # Selectable the relation's "source" side is adapted to, if any.
        adapt_from = None

        if hasattr(left, '_orm_mappers'):
            # left is itself an ORM join: chain from its right-hand mapper
            # and adapt to its right-hand selectable.
            left_mapper = left._orm_mappers[1]
            if join_to_left:
                adapt_from = left.right
        else:
            left_mapper, left, left_is_aliased = _entity_info(left)
            # Adapt to the left side only if it is aliased or unmapped.
            if join_to_left and (left_is_aliased or not left_mapper):
                adapt_from = left

        right_mapper, right, right_is_aliased = _entity_info(right)
        if right_is_aliased:
            adapt_to = right
        else:
            adapt_to = None

        if left_mapper or right_mapper:
            self._orm_mappers = (left_mapper, right_mapper)

            # Resolve onclause to a MapperProperty when possible.
            # NOTE(review): a string onclause with an unmapped left side
            # reaches None.get_property() and raises AttributeError.
            if isinstance(onclause, basestring):
                prop = left_mapper.get_property(onclause)
            elif isinstance(onclause, attributes.QueryableAttribute):
                if not adapt_from:
                    adapt_from = onclause.__clause_element__()
                prop = onclause.property
            elif isinstance(onclause, MapperProperty):
                prop = onclause
            else:
                prop = None

            if prop:
                pj, sj, source, dest, secondary, target_adapter = prop._create_joins(
                                source_selectable=adapt_from,
                                dest_selectable=adapt_to,
                                source_polymorphic=True,
                                dest_polymorphic=True,
                                of_type=right_mapper)

                if sj:
                    # A secondary join exists: join left to the secondary
                    # table on pj, then use sj as the ON clause to right.
                    left = sql.join(left, secondary, pj, isouter)
                    onclause = sj
                else:
                    onclause = pj
                # Only set when onclause resolved to a property.
                self._target_adapter = target_adapter

        expression.Join.__init__(self, left, right, onclause, isouter)

    def join(self, right, onclause=None, isouter=False, join_to_left=True):
        """Chain another ORM join onto this one."""
        return _ORMJoin(self, right, onclause, isouter, join_to_left)

    def outerjoin(self, right, onclause=None, join_to_left=True):
        """Chain another ORM join onto this one, as a left outer join."""
        return _ORMJoin(self, right, onclause, True, join_to_left)
423
def join(left, right, onclause=None, isouter=False, join_to_left=True):
    """Produce an inner join between left and right clauses.

    Beyond what :func:`~sqlalchemy.sql.expression.join()` accepts, left
    and right may be mapped classes or AliasedClass instances, and the
    onclause may be the string name of a relation() or a class-bound
    descriptor representing one.

    When join_to_left is True, the ON clause, in whatever form it is
    given, is aliased against the selectable passed as the left side
    where applicable; when False it is used as is.  ``isouter=True``
    produces a left outer join instead.

    """
    return _ORMJoin(left, right, onclause=onclause, isouter=isouter,
                    join_to_left=join_to_left)
440
def outerjoin(left, right, onclause=None, join_to_left=True):
    """Produce a left outer join between left and right clauses.

    Beyond what :func:`~sqlalchemy.sql.expression.outerjoin()` accepts,
    left and right may be mapped classes or AliasedClass instances, and
    the onclause may be the string name of a relation() or a class-bound
    descriptor representing one.

    join_to_left behaves as in :func:`join`: when True the ON clause is
    aliased against the left-hand selectable where applicable, when False
    it is used as is.

    """
    return _ORMJoin(left, right, onclause=onclause, isouter=True,
                    join_to_left=join_to_left)
452
def with_parent(instance, prop):
    """Return criterion which selects instances with a given parent.

    instance
      a parent instance, which should be persistent or detached.

    prop
      a class-attached descriptor, MapperProperty or string property name
      attached to the parent instance.  A string is resolved against the
      instance's mapper, following synonyms.

    The criterion is produced by the property's ``compare()`` using
    equality against ``instance`` as the parent value.

    """
    if isinstance(prop, basestring):
        mapper = object_mapper(instance)
        prop = mapper.get_property(prop, resolve_synonyms=True)
    elif isinstance(prop, attributes.QueryableAttribute):
        prop = prop.property

    return prop.compare(operators.eq, instance, value_is_parent=True)
475
476
def _entity_info(entity, compile=True):
    """Describe a class, mapper or AliasedClass as ``(mapper, selectable, is_aliased)``.

    The selectable is the mapper's ``_with_polymorphic_selectable`` (or the
    alias, for an AliasedClass).  When ``compile`` is true the mapper is
    compiled before being returned.

    Anything that is not a mapper, mapped class or aliased construct comes
    back as ``(None, entity, False)``, which lets unmapped selectables
    pass straight through.

    """
    if isinstance(entity, AliasedClass):
        return entity._AliasedClass__mapper, entity._AliasedClass__alias, True

    if not _is_mapped_class(entity):
        # unmapped: hand the object back untouched
        return None, entity, False

    if isinstance(entity, type):
        mapper = class_mapper(entity, compile)
    elif compile:
        mapper = entity.compile()
    else:
        mapper = entity
    return mapper, mapper._with_polymorphic_selectable, False
501
def _entity_descriptor(entity, key):
    """Look up attribute ``key`` on a class, mapper or AliasedClass.

    Returns ``(descriptor, descriptor.property)``: the
    InstrumentedAttribute (or adapted attribute, for an AliasedClass) and
    its MapperProperty.

    """
    if isinstance(entity, AliasedClass):
        desc = getattr(entity, key)
    elif isinstance(entity, type):
        desc = attributes.manager_of_class(entity)[key]
    else:
        desc = entity.class_manager[key]
    return desc, desc.property
517
def _orm_columns(entity):
    """Return a list of the columns of ``entity``'s selectable.

    If the resolved object is not a Selectable, it is returned as the sole
    element of the list.
    """
    selectable = _entity_info(entity)[1]
    if not isinstance(selectable, expression.Selectable):
        return [selectable]
    return list(selectable.c)
524
def _orm_selectable(entity):
    """Return the selectable that ``_entity_info`` resolves ``entity`` to."""
    return _entity_info(entity)[1]
528
def _is_aliased_class(entity):
    """Return True when ``entity`` is an AliasedClass instance."""
    result = isinstance(entity, AliasedClass)
    return result
531
532def _state_mapper(state):
533    return state.manager.mapper
534
def object_mapper(instance):
    """Return the primary Mapper for an object instance.

    Raises UnmappedInstanceError when the instance has no instrumentation
    state or its class manager has no mapper.

    """
    try:
        mapper = attributes.instance_state(instance).manager.mapper
        if not mapper:
            raise exc.UnmappedInstanceError(instance)
        return mapper
    except exc.NO_STATE:
        raise exc.UnmappedInstanceError(instance)
548
def class_mapper(class_, compile=True):
    """Return the primary Mapper for a class, compiled unless ``compile`` is false.

    Raises UnmappedClassError when the class is not instrumented or its
    manager carries no mapper.

    """
    try:
        mapper = attributes.manager_of_class(class_).mapper
    except exc.NO_STATE:
        raise exc.UnmappedClassError(class_)

    # HACK until [ticket:1142] is complete: a class manager may exist
    # without a mapper attached.
    if mapper is None:
        raise exc.UnmappedClassError(class_)

    return mapper.compile() if compile else mapper
569
def _class_to_mapper(class_or_mapper, compile=True):
    """Coerce an AliasedClass, mapped class or mapper into a Mapper.

    Classes and mappers are compiled when ``compile`` is true; the mapper
    of an AliasedClass is returned as is.  Anything without a ``compile``
    method raises UnmappedClassError.
    """
    if _is_aliased_class(class_or_mapper):
        return class_or_mapper._AliasedClass__mapper
    if isinstance(class_or_mapper, type):
        return class_mapper(class_or_mapper, compile=compile)
    if not hasattr(class_or_mapper, 'compile'):
        raise exc.UnmappedClassError(class_or_mapper)
    if compile:
        return class_or_mapper.compile()
    return class_or_mapper
582
def has_identity(object):
    """Return True if the instance's state has a (truthy) identity key."""
    return _state_has_identity(attributes.instance_state(object))
586
587def _state_has_identity(state):
588    return bool(state.key)
589
def _is_mapped_class(cls):
    """Return True if ``cls`` is a Mapper, an AliasedClass, or a class
    whose class manager was installed by a mapper.

    ClauseElements and any other non-class objects return False.  The
    result is always a bool.
    """
    from sqlalchemy.orm import mapperlib as mapper
    if isinstance(cls, (AliasedClass, mapper.Mapper)):
        return True
    if isinstance(cls, expression.ClauseElement):
        return False
    if isinstance(cls, type):
        manager = attributes.manager_of_class(cls)
        # Test against None explicitly: the manager is indexable like a
        # dict, so its truth value may reflect its contents rather than
        # whether the class is managed at all.
        return manager is not None and _INSTRUMENTOR in manager.info
    return False
600
def instance_str(instance):
    """Return a short string such as ``<ClassName at 0x...>`` for an instance."""
    state = attributes.instance_state(instance)
    return state_str(state)
605
def state_str(state):
    """Describe the instance behind an InstanceState as ``<ClassName at 0x...>``.

    A ``None`` state is rendered as the string ``"None"``.
    """
    if state is not None:
        class_name = state.class_.__name__
        return '<%s at 0x%x>' % (class_name, id(state.obj()))
    return "None"
613
def attribute_str(instance, attribute):
    """Return ``"<ClassName at 0x...>.attribute"`` for an instance."""
    prefix = instance_str(instance)
    return prefix + "." + attribute
616
def state_attribute_str(state, attribute):
    """Return ``"<ClassName at 0x...>.attribute"`` given an InstanceState."""
    prefix = state_str(state)
    return prefix + "." + attribute
619
def identity_equal(a, b):
    """Return True if ``a`` and ``b`` are the same object or share an identity key.

    Objects without instrumentation state, ``None``, and states whose key
    is None never compare equal (except an object to itself).
    """
    if a is b:
        return True
    if a is None or b is None:
        return False
    try:
        state_a = attributes.instance_state(a)
        state_b = attributes.instance_state(b)
    except exc.NO_STATE:
        return False
    key_a, key_b = state_a.key, state_b.key
    if key_a is None or key_b is None:
        return False
    return key_a == key_b
633
634
# TODO: Avoid circular import.
# Install these helpers onto the attributes module at import time instead
# of having it import them from here.
for _name, _func in (('identity_equal', identity_equal),
                     ('_is_aliased_class', _is_aliased_class),
                     ('_entity_info', _entity_info)):
    setattr(attributes, _name, _func)
del _name, _func
Note: リポジトリブラウザについてのヘルプは TracBrowser を参照してください。