""" This module defines standalone schema constraint classes. """ import sqlalchemy from sqlalchemy import schema class ConstraintChangeset(object): """Base class for Constraint classes. """ def _normalize_columns(self, cols, fullname=False): """Given: column objects or names; return col names and (maybe) a table""" colnames = [] table = None for col in cols: if isinstance(col, schema.Column): if col.table is not None and table is None: table = col.table if fullname: col = '.'.join((col.table.name, col.name)) else: col = col.name colnames.append(col) return colnames, table def create(self, engine=None): """Create the constraint in the database. :param engine: the database engine to use. If this is :keyword:`None` the instance's engine will be used :type engine: :class:`sqlalchemy.engine.base.Engine` """ if engine is None: engine = self.engine engine.create(self) def drop(self, engine=None): """Drop the constraint from the database. :param engine: the database engine to use. If this is :keyword:`None` the instance's engine will be used :type engine: :class:`sqlalchemy.engine.base.Engine` """ if engine is None: engine = self.engine engine.drop(self) def _derived_metadata(self): return self.table._derived_metadata() def accept_schema_visitor(self, visitor, *p, **k): """ :raises: :exc:`NotImplementedError` if this method is not \ overridden by a subclass """ raise NotImplementedError() def _accept_schema_visitor(self, visitor, func, *p, **k): """Call the visitor only if it defines the given function""" try: func = getattr(visitor, func) except AttributeError: return return func(self) def autoname(self): """Automatically generate a name for the constraint instance. Subclasses must implement this method. :raises: :exc:`NotImplementedError` if this method is not \ overridden by a subclass """ raise NotImplementedError() def _engine_run_visitor(engine, visitorcallable, element, **kwargs): conn = engine.connect() try: element.accept_schema_visitor(visitorcallable(conn)) finally: conn.close() class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint): """Primary key constraint class.""" def __init__(self, *cols, **kwargs): colnames, table = self._normalize_columns(cols) table = kwargs.pop('table', table) super(PrimaryKeyConstraint, self).__init__(*colnames, **kwargs) if table is not None: self._set_parent(table) def _set_parent(self, table): self.table = table return super(ConstraintChangeset, self)._set_parent(table) def create(self, *args, **kwargs): from migrate.changeset.databases.visitor import get_engine_visitor visitorcallable = get_engine_visitor(self.table.bind, 'constraintgenerator') _engine_run_visitor(self.table.bind, visitorcallable, self) def autoname(self): """Mimic the database's automatic constraint names""" ret = "%(table)s_pkey"%dict( table=self.table.name, ) return ret def drop(self, *args, **kwargs): from migrate.changeset.databases.visitor import get_engine_visitor visitorcallable = get_engine_visitor(self.table.bind, 'constraintdropper') _engine_run_visitor(self.table.bind, visitorcallable, self) self.columns.clear() return self def accept_schema_visitor(self, visitor, *p, **k): func = 'visit_migrate_primary_key_constraint' return self._accept_schema_visitor(visitor, func, *p, **k) class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint): """Foreign key constraint class.""" def __init__(self, columns, refcolumns, *p, **k): colnames, table = self._normalize_columns(columns) table = k.pop('table', table) refcolnames, reftable = self._normalize_columns(refcolumns, fullname=True) super(ForeignKeyConstraint, self).__init__(colnames, refcolnames, *p, **k) if table is not None: self._set_parent(table) def _get_referenced(self): return [e.column for e in self.elements] referenced = property(_get_referenced) def _get_reftable(self): return self.referenced[0].table reftable = property(_get_reftable) def autoname(self): """Mimic the database's automatic constraint names""" ret = "%(table)s_%(reftable)s_fkey"%dict( table=self.table.name, reftable=self.reftable.name, ) return ret def create(self, *args, **kwargs): from migrate.changeset.databases.visitor import get_engine_visitor visitorcallable = get_engine_visitor(self.table.bind, 'constraintgenerator') _engine_run_visitor(self.table.bind, visitorcallable, self) return self def drop(self, *args, **kwargs): from migrate.changeset.databases.visitor import get_engine_visitor visitorcallable = get_engine_visitor(self.table.bind, 'constraintdropper') _engine_run_visitor(self.table.bind, visitorcallable, self) self.columns.clear() return self def accept_schema_visitor(self, visitor, *p, **k): func = 'visit_migrate_foreign_key_constraint' return self._accept_schema_visitor(visitor, func, *p, **k) class CheckConstraint(ConstraintChangeset, schema.CheckConstraint): """Check constraint class.""" def __init__(self, sqltext, *args, **kwargs): cols = kwargs.pop('columns') colnames, table = self._normalize_columns(cols) table = kwargs.pop('table', table) ConstraintChangeset.__init__(self, *args, **kwargs) schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs) if table is not None: self._set_parent(table) self.colnames = colnames def _set_parent(self, table): self.table = table return super(ConstraintChangeset, self)._set_parent(table) def create(self): from migrate.changeset.databases.visitor import get_engine_visitor visitorcallable = get_engine_visitor(self.table.bind, 'constraintgenerator') _engine_run_visitor(self.table.bind, visitorcallable, self) def drop(self): from migrate.changeset.databases.visitor import get_engine_visitor visitorcallable = get_engine_visitor(self.table.bind, 'constraintdropper') _engine_run_visitor(self.table.bind, visitorcallable, self) self.columns.clear() return self def autoname(self): return "%(table)s_%(cols)s_check" % \ {"table": self.table.name, "cols": "_".join(self.colnames)} def accept_schema_visitor(self, visitor, *args, **kwargs): func = 'visit_migrate_check_constraint' return self._accept_schema_visitor(visitor, func, *args, **kwargs)