1 | # mapper/util.py |
---|
2 | # Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com |
---|
3 | # |
---|
4 | # This module is part of SQLAlchemy and is released under |
---|
5 | # the MIT License: http://www.opensource.org/licenses/mit-license.php |
---|
6 | |
---|
7 | import sqlalchemy.exceptions as sa_exc |
---|
8 | from sqlalchemy import sql, util |
---|
9 | from sqlalchemy.sql import expression, util as sql_util, operators |
---|
10 | from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, MapperProperty, AttributeExtension |
---|
11 | from sqlalchemy.orm import attributes, exc |
---|
12 | |
---|
13 | |
---|
all_cascades = frozenset(("delete", "delete-orphan", "all", "merge",
                          "expunge", "save-update", "refresh-expire",
                          "none"))

# Marker key that _is_mapped_class() looks for in a class manager's
# ``info`` dict when deciding whether a class is mapped.
_INSTRUMENTOR = ('mapper', 'instrumentor')


class CascadeOptions(object):
    """Keeps track of the options sent to relation().cascade.

    ``arg`` is a comma-separated string of names taken from
    ``all_cascades``, e.g. ``"save-update, merge"``.  Each option is exposed
    as a boolean attribute with hyphens turned into underscores
    (``delete``, ``save_update``, ``merge``, ``expunge``, ``delete_orphan``,
    ``refresh_expire``).  ``"all"`` switches on every option except
    ``delete-orphan``; ``"none"`` switches on nothing.

    Raises ``ArgumentError`` for any name not listed in ``all_cascades``.

    """

    def __init__(self, arg=""):
        if not arg:
            values = set()
        else:
            values = set(c.strip() for c in arg.split(','))

        # Reject unknown names first, so a typo is reported as an error
        # rather than being preceded by a warning about the other options.
        for x in values:
            if x not in all_cascades:
                raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x)

        everything = "all" in values
        self.delete_orphan = "delete-orphan" in values
        self.delete = everything or "delete" in values
        self.save_update = everything or "save-update" in values
        self.merge = everything or "merge" in values
        self.expunge = everything or "expunge" in values
        self.refresh_expire = everything or "refresh-expire" in values

        if self.delete_orphan and not self.delete:
            util.warn("The 'delete-orphan' cascade option requires "
                      "'delete'. This will raise an error in 0.6.")

    def __contains__(self, item):
        """Support ``"save-update" in options``.

        Hyphens in ``item`` are mapped to underscores before the attribute
        lookup.  Option names without a matching attribute (including
        ``"all"``) give False.
        """
        return getattr(self, item.replace("-", "_"), False)

    def __repr__(self):
        # Attribute names must use underscores; a hyphenated name such as
        # 'refresh-expire' never matches and would always be omitted.
        enabled = [x for x in ('delete', 'save_update', 'merge', 'expunge',
                               'delete_orphan', 'refresh_expire')
                   if getattr(self, x, False) is True]
        return "CascadeOptions(%s)" % repr(",".join(enabled))
---|
51 | |
---|
52 | |
---|
class Validator(AttributeExtension):
    """Attribute extension that passes set/appended values through a validator.

    Used by the :func:`~sqlalchemy.orm.validates` decorator; direct
    access is usually not needed.

    """

    def __init__(self, key, validator):
        """Construct a new Validator.

        key - name of the attribute being validated; it is handed to the
        validation method as its second argument (the first being the
        object instance itself).

        validator - a function or instance method taking three arguments:
        an instance (usually just 'self' for a method), the attribute's key
        name, and the value.  Whatever it returns is used as the value, so
        it should return the value it was given unless it wants to change it.

        """
        self.validator = validator
        self.key = key

    def append(self, state, value, initiator):
        # Collection append: validate the item being added.
        target = state.obj()
        return self.validator(target, self.key, value)

    def set(self, state, value, oldvalue, initiator):
        # Scalar set: validate the incoming value (oldvalue is not consulted).
        target = state.obj()
        return self.validator(target, self.key, value)
---|
82 | |
---|
def polymorphic_union(table_map, typecolname, aliasname='p_union'):
    """Create a ``UNION`` statement used by a polymorphic mapper.

    See :ref:`concrete_inheritance` for an example of how
    this is used.

    table_map
      dictionary mapping a discriminator value to a table or select.
      ``Select`` values are replaced *in this dictionary* by an alias of
      the select.

    typecolname
      name of an extra column added to every SELECT, holding the
      table_map key as a string literal; pass None to omit it.

    aliasname
      name given to the alias wrapping the resulting UNION ALL.

    Each SELECT lists the same column keys in the same order; a table
    lacking a column contributes ``CAST(NULL AS <type>)`` labeled with
    that column's name instead.
    """
    # Ordered, de-duplicated column keys across all tables.  A list (rather
    # than a set) makes the column order follow first appearance instead of
    # hash order, so the generated SQL is predictable.
    colnames = []
    seen = set()
    colnamemaps = {}
    types = {}
    for key in list(table_map.keys()):
        table = table_map[key]

        # mysql doesnt like selecting from a select; make it an alias of the select
        if isinstance(table, sql.Select):
            table = table.alias()
            table_map[key] = table

        m = {}
        for c in table.c:
            if c.key not in seen:
                seen.add(c.key)
                colnames.append(c.key)
            m[c.key] = c
            types[c.key] = c.type
        colnamemaps[table] = m

    def col(name, table):
        # The table's own column, or a typed NULL placeholder if it lacks one.
        try:
            return colnamemaps[table][name]
        except KeyError:
            return sql.cast(sql.null(), types[name]).label(name)

    result = []
    for discriminator, table in table_map.items():
        columns = [col(name, table) for name in colnames]
        if typecolname is not None:
            columns.append(
                sql.literal_column("'%s'" % discriminator).label(typecolname))
        result.append(sql.select(columns, from_obj=[table]))
    return sql.union_all(*result).alias(aliasname)
---|
124 | |
---|
def identity_key(*args, **kwargs):
    """Get an identity key.

    Valid call signatures:

    * ``identity_key(class, ident)``

      class
          mapped class (must be a positional argument)

      ident
          primary key, if the key is composite this is a tuple

    * ``identity_key(instance=instance)``

      instance
          object instance (must be given as a keyword arg)

    * ``identity_key(class, row=row)``

      class
          mapped class (must be a positional argument)

      row
          result proxy row (must be given as a keyword arg)

    With a single positional class, ``ident`` may also be given by keyword.
    A third positional argument is accepted for backwards compatibility and
    is not used.

    Raises ArgumentError for more than three positional arguments or for
    unrecognized keyword arguments.

    """
    if args:
        use_row = False
        if len(args) == 1:
            class_ = args[0]
            if "row" in kwargs:
                row = kwargs.pop("row")
                use_row = True
            else:
                # KeyError if neither 'row' nor 'ident' was supplied.
                ident = kwargs.pop("ident")
        elif len(args) in (2, 3):
            # Unpacking all three values into two names would raise
            # ValueError; only the class and ident are used.
            class_, ident = args[:2]
        else:
            raise sa_exc.ArgumentError("expected up to three "
                "positional arguments, got %s" % len(args))
        if kwargs:
            raise sa_exc.ArgumentError("unknown keyword arguments: %s"
                % ", ".join(kwargs.keys()))
        mapper = class_mapper(class_)
        if use_row:
            return mapper.identity_key_from_row(row)
        return mapper.identity_key_from_primary_key(ident)

    instance = kwargs.pop("instance")
    if kwargs:
        raise sa_exc.ArgumentError("unknown keyword arguments: %s"
            % ", ".join(kwargs.keys()))
    mapper = object_mapper(instance)
    return mapper.identity_key_from_instance(instance)
---|
180 | |
---|
class ExtensionCarrier(dict):
    """Fronts an ordered collection of MapperExtension objects.

    Bundles multiple MapperExtensions into a unified callable unit,
    encapsulating ordering, looping and EXT_CONTINUE logic. The
    ExtensionCarrier implements the MapperExtension interface, e.g.::

        carrier.after_insert(...args...)

    The dictionary interface provides containment for implemented
    method names mapped to a callable which executes that method
    for participating extensions.

    """

    # Public (non-underscore) attribute names of MapperExtension.
    interface = set(method for method in dir(MapperExtension)
                    if not method.startswith('_'))

    def __init__(self, extensions=None):
        self._extensions = []
        for ext in extensions or ():
            self.append(ext)

    def copy(self):
        """Return a new carrier over the same extensions, in the same order."""
        return ExtensionCarrier(self._extensions)

    def push(self, extension):
        """Insert a MapperExtension at the beginning of the collection."""
        self._register(extension)
        self._extensions.insert(0, extension)

    def append(self, extension):
        """Append a MapperExtension at the end of the collection."""
        self._register(extension)
        self._extensions.append(extension)

    def __iter__(self):
        """Iterate over MapperExtensions in the collection."""
        return iter(self._extensions)

    @staticmethod
    def _function_of(obj):
        """Return the plain function behind a bound/unbound method, else obj."""
        # im_func: Python 2 method objects; __func__: Python 2.6+ / 3.
        return getattr(obj, 'im_func', getattr(obj, '__func__', obj))

    def _register(self, extension):
        """Register callable fronts for overridden interface methods.

        Only methods the extension actually overrides are registered, so
        ``name in carrier`` tells whether any bundled extension implements
        ``name``.
        """
        for method in self.interface.difference(self):
            impl = getattr(extension, method, None)
            if not impl:
                continue
            # Method objects are created anew on each attribute access, so
            # comparing them by identity is never true; compare the
            # underlying functions instead.
            base = getattr(MapperExtension, method)
            if self._function_of(impl) is not self._function_of(base):
                self[method] = self._create_do(method)

    def _create_do(self, method):
        """Return a closure that loops over impls of the named method."""

        def _do(*args, **kwargs):
            # First result other than EXT_CONTINUE wins; if every extension
            # continues (or there are none), EXT_CONTINUE is returned.
            for ext in self._extensions:
                ret = getattr(ext, method)(*args, **kwargs)
                if ret is not EXT_CONTINUE:
                    return ret
            else:
                return EXT_CONTINUE
        _do.__name__ = method
        return _do

    @staticmethod
    def _pass(*args, **kwargs):
        # Stand-in for interface methods no bundled extension overrides.
        return EXT_CONTINUE

    def __getattr__(self, key):
        """Delegate MapperExtension methods to bundled fronts.

        Only reached for names not found by normal lookup.  Interface
        methods without a registered front resolve to ``_pass``; any other
        name raises AttributeError.
        """

        if key not in self.interface:
            raise AttributeError(key)
        return self.get(key, self._pass)
---|
252 | |
---|
class ORMAdapter(sql_util.ColumnAdapter):
    """ColumnAdapter that also accepts ORM entities.

    The selectable to adapt to is taken from the given entity via
    _entity_info(); when the entity is an AliasedClass it is also kept
    on ``aliased_class`` (otherwise ``aliased_class`` is None).

    """
    def __init__(self, entity, equivalents=None, chain_to=None, adapt_required=False):
        info = _entity_info(entity)
        self.mapper, selectable, is_aliased_class = info
        self.aliased_class = entity if is_aliased_class else None
        sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to,
                                        adapt_required=adapt_required)

    def replace(self, elem):
        # Leave alone elements annotated with a 'parentmapper' that does not
        # satisfy isa() against our mapper; adapt everything else normally.
        parent = elem._annotations.get('parentmapper', None)
        if parent and not parent.isa(self.mapper):
            return None
        return sql_util.ColumnAdapter.replace(self, elem)
---|
274 | |
---|
class AliasedClass(object):
    """Represents an 'alias'ed form of a mapped class for usage with Query.

    The ORM equivalent of a :class:`~sqlalchemy.sql.expression.Alias`
    object, this object mimics the mapped class using a
    __getattr__ scheme and maintains a reference to a
    real Alias object. It indicates to Query that the
    selectable produced for this class should be aliased,
    and also adapts PropComparators produced by the class'
    InstrumentedAttributes so that they adapt the
    "local" side of SQL expressions against the alias.

    cls may be a mapped class, a mapper, or another AliasedClass.  If
    alias is not given, a new alias of the mapper's
    ``_with_polymorphic_selectable`` is created.

    """
    def __init__(self, cls, alias=None, name=None):
        # _class_to_mapper() resolves a class, mapper or AliasedClass.
        self.__mapper = _class_to_mapper(cls)
        self.__target = self.__mapper.class_
        alias = alias or self.__mapper._with_polymorphic_selectable.alias()
        # Adapter that rewrites clause elements against the alias, using the
        # mapper's equivalent-column map.
        self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
        self.__alias = alias
        # used to assign a name to the RowTuple object
        # returned by Query.
        self._sa_label_name = name
        self.__name__ = 'AliasedClass_' + str(self.__target)

    def __getstate__(self):
        # Pickle only the mapper, alias and label name; the target class and
        # adapter are rebuilt from them in __setstate__.
        return {'mapper':self.__mapper, 'alias':self.__alias, 'name':self._sa_label_name}

    def __setstate__(self, state):
        # Mirrors __init__, starting from the pickled mapper/alias/name.
        self.__mapper = state['mapper']
        self.__target = self.__mapper.class_
        alias = state['alias']
        self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
        self.__alias = alias
        name = state['name']
        self._sa_label_name = name
        self.__name__ = 'AliasedClass_' + str(self.__target)

    def __adapt_element(self, elem):
        # Adapt elem to the alias and annotate the result with this
        # AliasedClass as 'parententity' and the mapper as 'parentmapper'.
        return self.__adapter.traverse(elem)._annotate({'parententity': self, 'parentmapper':self.__mapper})

    def __adapt_prop(self, prop):
        # Build a QueryableAttribute for prop that shares the target class
        # attribute's impl but uses a comparator adapted to the alias.
        existing = getattr(self.__target, prop.key)
        comparator = existing.comparator.adapted(self.__adapt_element)

        queryattr = attributes.QueryableAttribute(prop.key,
            impl=existing.impl, parententity=self, comparator=comparator)
        # Cache on the instance so later lookups of this key do not reach
        # __getattr__ again.
        setattr(self, prop.key, queryattr)
        return queryattr

    def __getattr__(self, key):
        # Only reached for names not found by normal attribute lookup.
        prop = self.__mapper._get_property(key, raiseerr=False)
        if prop:
            return self.__adapt_prop(prop)

        # Not a mapped property: find the first class in the target's MRO
        # that has the attribute.  object.__getattribute__ is used,
        # presumably so values stored in the class namespace come back
        # without being bound to the target class.
        for base in self.__target.__mro__:
            try:
                attr = object.__getattribute__(base, key)
            except AttributeError:
                continue
            else:
                break
        else:
            raise AttributeError(key)

        if hasattr(attr, 'func_code'):
            # Function or method object (Python 2 ``func_code``).  If the
            # target class yields a method with ``im_self`` set, rebind the
            # underlying function to this AliasedClass; otherwise return None.
            # NOTE(review): under Python 2 an ordinary instance method comes
            # back from the class unbound (im_self is None) and so takes the
            # ``return None`` branch -- confirm this is intended.
            is_method = getattr(self.__target, key, None)
            if is_method and is_method.im_self is not None:
                return util.types.MethodType(attr.im_func, self, self)
            else:
                return None
        elif hasattr(attr, '__get__'):
            # Other descriptors are invoked with no instance and this
            # AliasedClass as the owner.
            return attr.__get__(None, self)
        else:
            # Plain class-level values are returned unchanged.
            return attr

    def __repr__(self):
        return '<AliasedClass at 0x%x; %s>' % (
            id(self), self.__target.__name__)
---|
353 | |
---|
def _orm_annotate(element, exclude=None):
    """Deep-copy a ClauseElement, tagging each copied element with ``_orm_adapt``.

    Elements within the ``exclude`` collection are cloned but left
    unannotated.

    """
    orm_flag = {'_orm_adapt': True}
    return sql_util._deep_annotate(element, orm_flag, exclude)

# Counterpart of _orm_annotate; simply an alias for sql_util._deep_deannotate.
_orm_deannotate = sql_util._deep_deannotate
---|
363 | |
---|
class _ORMJoin(expression.Join):
    """Extend Join to support ORM constructs as input.

    ``left`` and ``right`` may be anything _entity_info() accepts (mapped
    classes, mappers, AliasedClass objects, plain selectables); ``left`` may
    also be a join carrying ``_orm_mappers``.  ``onclause`` may be a
    relation() name string, a class-bound attribute, a MapperProperty, or a
    plain SQL expression (used unchanged).

    """

    # Reuse Join's visit name so this is visited the same way as a plain Join.
    __visit_name__ = expression.Join.__visit_name__

    def __init__(self, left, right, onclause=None, isouter=False, join_to_left=True):
        # Selectable the ON clause's source side is adapted to, if any.
        adapt_from = None

        if hasattr(left, '_orm_mappers'):
            # Chained join (an _ORMJoin sets _orm_mappers below): continue
            # from the mapper of its right side.
            left_mapper = left._orm_mappers[1]
            if join_to_left:
                adapt_from = left.right
        else:
            left_mapper, left, left_is_aliased = _entity_info(left)
            # Adapt to the left side when it is an alias or unmapped.
            if join_to_left and (left_is_aliased or not left_mapper):
                adapt_from = left

        right_mapper, right, right_is_aliased = _entity_info(right)
        # Only an aliased right side needs the destination adapted.
        if right_is_aliased:
            adapt_to = right
        else:
            adapt_to = None

        if left_mapper or right_mapper:
            # Recorded so that a further join against this one can find them.
            self._orm_mappers = (left_mapper, right_mapper)

        # Resolve the onclause to a MapperProperty where possible.
        if isinstance(onclause, basestring):
            # NOTE(review): left_mapper is None when left is unmapped, and
            # this then fails with AttributeError rather than a clear error.
            prop = left_mapper.get_property(onclause)
        elif isinstance(onclause, attributes.QueryableAttribute):
            # Unless a source was already chosen, adapt from the attribute's
            # own clause element.
            if not adapt_from:
                adapt_from = onclause.__clause_element__()
            prop = onclause.property
        elif isinstance(onclause, MapperProperty):
            prop = onclause
        else:
            prop = None

        if prop:
            # Have the property build its join condition(s) against the
            # (possibly adapted) source and destination selectables.
            pj, sj, source, dest, secondary, target_adapter = prop._create_joins(
                            source_selectable=adapt_from,
                            dest_selectable=adapt_to,
                            source_polymorphic=True,
                            dest_polymorphic=True,
                            of_type=right_mapper)

            if sj:
                # Join through the secondary table: left JOIN secondary ON pj,
                # then the outer join to right uses sj.
                left = sql.join(left, secondary, pj, isouter)
                onclause = sj
            else:
                onclause = pj
            self._target_adapter = target_adapter

        expression.Join.__init__(self, left, right, onclause, isouter)

    def join(self, right, onclause=None, isouter=False, join_to_left=True):
        # Chain a further ORM-aware join onto this one.
        return _ORMJoin(self, right, onclause, isouter, join_to_left)

    def outerjoin(self, right, onclause=None, join_to_left=True):
        # Chain a further ORM-aware left outer join onto this one.
        return _ORMJoin(self, right, onclause, True, join_to_left)
---|
423 | |
---|
def join(left, right, onclause=None, isouter=False, join_to_left=True):
    """Produce an inner join between left and right clauses.

    Beyond what :func:`~sqlalchemy.sql.expression.join()` accepts, left
    and right may be mapped classes or AliasedClass instances, and the
    onclause may be the string name of a relation() or a class-bound
    descriptor representing a relation.

    ``isouter`` set to True produces a LEFT OUTER JOIN instead.

    ``join_to_left`` requests that the ON clause, in whatever form it is
    passed, be aliased to the selectable given as the left side.  When
    False the onclause is used as is.

    """
    return _ORMJoin(left, right, onclause=onclause, isouter=isouter,
                    join_to_left=join_to_left)
---|
440 | |
---|
def outerjoin(left, right, onclause=None, join_to_left=True):
    """Produce a left outer join between left and right clauses.

    Beyond what :func:`~sqlalchemy.sql.expression.outerjoin()` accepts,
    left and right may be mapped classes or AliasedClass instances, and
    the onclause may be the string name of a relation() or a class-bound
    descriptor representing a relation.

    ``join_to_left`` has the same meaning as for :func:`join`.

    """
    return _ORMJoin(left, right, onclause=onclause, isouter=True,
                    join_to_left=join_to_left)
---|
452 | |
---|
def with_parent(instance, prop):
    """Return criterion which selects instances with a given parent.

    instance
      a parent instance, which should be persistent or detached.

    prop
      a class-attached descriptor, MapperProperty or string property name
      attached to the parent instance.

    The criterion is whatever the property's ``compare()`` returns for
    ``operators.eq`` against ``instance`` with ``value_is_parent=True``.

    """
    if isinstance(prop, basestring):
        # Look the name up on the instance's own mapper.
        mapper = object_mapper(instance)
        prop = mapper.get_property(prop, resolve_synonyms=True)
    elif isinstance(prop, attributes.QueryableAttribute):
        # Class-bound attribute: use the MapperProperty behind it.
        prop = prop.property

    return prop.compare(operators.eq, instance, value_is_parent=True)
---|
475 | |
---|
476 | |
---|
def _entity_info(entity, compile=True):
    """Return mapping information given a class, mapper, or AliasedClass.

    Result is a 3-tuple ``(mapper, mapped selectable, is_aliased)`` where
    ``is_aliased`` says whether the entity is an aliased() construct.

    Anything that is not a mapper, mapped class or aliased construct
    yields ``(None, entity, False)``; this is typically how unmapped
    selectables are allowed through.

    """
    if isinstance(entity, AliasedClass):
        return entity._AliasedClass__mapper, entity._AliasedClass__alias, True

    if not _is_mapped_class(entity):
        return None, entity, False

    if isinstance(entity, type):
        mapper = class_mapper(entity, compile)
    elif compile:
        mapper = entity.compile()
    else:
        mapper = entity
    return mapper, mapper._with_polymorphic_selectable, False
---|
501 | |
---|
def _entity_descriptor(entity, key):
    """Return attribute/property information given an entity and string name.

    The result is a 2-tuple ``(descriptor, descriptor.property)``, where
    the descriptor is fetched from an AliasedClass via getattr, from a
    class via its class manager, or otherwise from ``entity.class_manager``
    (presumably a mapper -- confirm against callers).

    """
    if isinstance(entity, AliasedClass):
        desc = getattr(entity, key)
    elif isinstance(entity, type):
        desc = attributes.manager_of_class(entity)[key]
    else:
        desc = entity.class_manager[key]
    return desc, desc.property
---|
517 | |
---|
def _orm_columns(entity):
    """Return the list of column-like items an entity contributes.

    A selectable contributes all of its columns; anything else that
    _entity_info() passes through (presumably a single column expression)
    is returned as a one-element list.
    """
    selectable = _entity_info(entity)[1]
    if not isinstance(selectable, expression.Selectable):
        return [selectable]
    return list(selectable.c)
---|
524 | |
---|
525 | def _orm_selectable(entity): |
---|
526 | mapper, selectable, is_aliased_class = _entity_info(entity) |
---|
527 | return selectable |
---|
528 | |
---|
529 | def _is_aliased_class(entity): |
---|
530 | return isinstance(entity, AliasedClass) |
---|
531 | |
---|
532 | def _state_mapper(state): |
---|
533 | return state.manager.mapper |
---|
534 | |
---|
def object_mapper(instance):
    """Given an object, return the primary Mapper associated with the object instance.

    Raises UnmappedInstanceError if the object has no instance state or
    its class manager has no mapper configured.

    """
    try:
        state = attributes.instance_state(instance)
        mapper = state.manager.mapper
        if not mapper:
            raise exc.UnmappedInstanceError(instance)
        return mapper
    except exc.NO_STATE:
        raise exc.UnmappedInstanceError(instance)
---|
548 | |
---|
def class_mapper(class_, compile=True):
    """Given a class, return the primary Mapper associated with it.

    When ``compile`` is true, the result of ``mapper.compile()`` is
    returned instead of the raw mapper.

    Raises UnmappedClassError if no mapping is configured.

    """
    try:
        mapper = attributes.manager_of_class(class_).mapper
        if mapper is None:
            # HACK until [ticket:1142] is complete: route a missing mapper
            # through the same handler as a missing manager (presumably
            # exc.NO_STATE includes AttributeError -- confirm).
            raise AttributeError
    except exc.NO_STATE:
        raise exc.UnmappedClassError(class_)

    return mapper.compile() if compile else mapper
---|
569 | |
---|
def _class_to_mapper(class_or_mapper, compile=True):
    """Resolve an AliasedClass, mapped class or mapper to a Mapper.

    Objects that are neither an AliasedClass nor a class but have a
    ``compile`` attribute are treated as mappers, and are compiled when
    ``compile`` is true.  Anything else raises UnmappedClassError.
    """
    if _is_aliased_class(class_or_mapper):
        return class_or_mapper._AliasedClass__mapper
    if isinstance(class_or_mapper, type):
        return class_mapper(class_or_mapper, compile=compile)
    if not hasattr(class_or_mapper, 'compile'):
        raise exc.UnmappedClassError(class_or_mapper)
    return class_or_mapper.compile() if compile else class_or_mapper
---|
582 | |
---|
583 | def has_identity(object): |
---|
584 | state = attributes.instance_state(object) |
---|
585 | return _state_has_identity(state) |
---|
586 | |
---|
587 | def _state_has_identity(state): |
---|
588 | return bool(state.key) |
---|
589 | |
---|
def _is_mapped_class(cls):
    """Return True for an AliasedClass, a Mapper, or a mapped class.

    SQL expression constructs, classes whose manager lacks the
    ``_INSTRUMENTOR`` marker, and any other value give False.
    """
    # Imported locally, presumably to avoid a circular import with mapperlib.
    from sqlalchemy.orm import mapperlib as mapper
    if isinstance(cls, (AliasedClass, mapper.Mapper)):
        return True
    if isinstance(cls, expression.ClauseElement):
        return False
    if isinstance(cls, type):
        manager = attributes.manager_of_class(cls)
        # bool() so a missing/falsy manager yields False, not the manager itself.
        return bool(manager and _INSTRUMENTOR in manager.info)
    return False
---|
600 | |
---|
def instance_str(instance):
    """Return a string describing an instance."""
    state = attributes.instance_state(instance)
    return state_str(state)

def state_str(state):
    """Return a string describing an instance via its InstanceState.

    Gives ``"None"`` for a missing state, otherwise
    ``<ClassName at 0xADDR>`` using the id() of the state's object.
    """
    if state is not None:
        return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj()))
    return "None"

def attribute_str(instance, attribute):
    """Return ``instance_str(instance)`` followed by ``.attribute``."""
    return ".".join((instance_str(instance), attribute))

def state_attribute_str(state, attribute):
    """Return ``state_str(state)`` followed by ``.attribute``."""
    return ".".join((state_str(state), attribute))
---|
619 | |
---|
def identity_equal(a, b):
    """Return True if a and b are the same object or share an identity key.

    Distinct objects are equal only when both have instance state, both
    states carry a non-None ``key``, and those keys compare equal.
    """
    if a is b:
        return True
    if a is None or b is None:
        return False
    try:
        states = (attributes.instance_state(a), attributes.instance_state(b))
    except exc.NO_STATE:
        return False
    key_a, key_b = states[0].key, states[1].key
    if key_a is None or key_b is None:
        return False
    return key_a == key_b
---|
633 | |
---|
634 | |
---|
# This module imports ``attributes``, so ``attributes`` cannot import these
# helpers back without a circular import; they are injected into it here
# instead, at import time of this module.
# TODO: Avoid circular import.
attributes.identity_equal = identity_equal
attributes._is_aliased_class = _is_aliased_class
attributes._entity_info = _entity_info
---|