| 1 | """ |
|---|
| 2 | Module for visitor class mapping. |
|---|
| 3 | """ |
|---|
| 4 | import sqlalchemy as sa |
|---|
| 5 | from migrate.changeset.databases import sqlite, postgres, mysql, oracle |
|---|
| 6 | from migrate.changeset import ansisql |
|---|
| 7 | |
|---|
| 8 | # Map SA dialects to the corresponding Migrate extensions |
|---|
| 9 | dialects = { |
|---|
| 10 | sa.engine.default.DefaultDialect: ansisql.ANSIDialect, |
|---|
| 11 | sa.databases.sqlite.SQLiteDialect: sqlite.SQLiteDialect, |
|---|
| 12 | sa.databases.postgres.PGDialect: postgres.PGDialect, |
|---|
| 13 | sa.databases.mysql.MySQLDialect: mysql.MySQLDialect, |
|---|
| 14 | sa.databases.oracle.OracleDialect: oracle.OracleDialect, |
|---|
| 15 | } |
|---|
| 16 | |
|---|
| 17 | |
|---|
| 18 | def get_engine_visitor(engine, name): |
|---|
| 19 | """ |
|---|
| 20 | Get the visitor implementation for the given database engine. |
|---|
| 21 | """ |
|---|
| 22 | return get_dialect_visitor(engine.dialect, name) |
|---|
| 23 | |
|---|
| 24 | |
|---|
| 25 | def get_dialect_visitor(sa_dialect, name): |
|---|
| 26 | """ |
|---|
| 27 | Get the visitor implementation for the given dialect. |
|---|
| 28 | |
|---|
| 29 | Finds the visitor implementation based on the dialect class and |
|---|
| 30 | returns and instance initialized with the given name. |
|---|
| 31 | """ |
|---|
| 32 | sa_dialect_cls = sa_dialect.__class__ |
|---|
| 33 | migrate_dialect_cls = dialects[sa_dialect_cls] |
|---|
| 34 | return migrate_dialect_cls.visitor(name) |
|---|