1 | """ |
---|
2 | Module for visitor class mapping. |
---|
3 | """ |
---|
4 | import sqlalchemy as sa |
---|
5 | from migrate.changeset.databases import sqlite, postgres, mysql, oracle |
---|
6 | from migrate.changeset import ansisql |
---|
7 | |
---|
8 | # Map SA dialects to the corresponding Migrate extensions |
---|
9 | dialects = { |
---|
10 | sa.engine.default.DefaultDialect: ansisql.ANSIDialect, |
---|
11 | sa.databases.sqlite.SQLiteDialect: sqlite.SQLiteDialect, |
---|
12 | sa.databases.postgres.PGDialect: postgres.PGDialect, |
---|
13 | sa.databases.mysql.MySQLDialect: mysql.MySQLDialect, |
---|
14 | sa.databases.oracle.OracleDialect: oracle.OracleDialect, |
---|
15 | } |
---|
16 | |
---|
17 | |
---|
18 | def get_engine_visitor(engine, name): |
---|
19 | """ |
---|
20 | Get the visitor implementation for the given database engine. |
---|
21 | """ |
---|
22 | return get_dialect_visitor(engine.dialect, name) |
---|
23 | |
---|
24 | |
---|
25 | def get_dialect_visitor(sa_dialect, name): |
---|
26 | """ |
---|
27 | Get the visitor implementation for the given dialect. |
---|
28 | |
---|
29 | Finds the visitor implementation based on the dialect class and |
---|
30 | returns and instance initialized with the given name. |
---|
31 | """ |
---|
32 | sa_dialect_cls = sa_dialect.__class__ |
---|
33 | migrate_dialect_cls = dialects[sa_dialect_cls] |
---|
34 | return migrate_dialect_cls.visitor(name) |
---|