[3] | 1 | import optparse, os, sys, re, ConfigParser, StringIO, time, warnings |
---|
# Placeholder for the stdlib logging module: stays None until _log() first
# needs it, at which point _log() imports logging into this global.
logging = None

__all__ = ('parser', 'configure', 'options')

# Database state shared by the option callbacks and post_configure hooks:
# _engine_uri assigns db_label / db_url, _create_testing_engine assigns db,
# and several callbacks/hooks add entries to db_opts (the extra keyword
# options handed to engines.testing_engine).
db = None
db_label = None
db_url = None
db_opts = {}

# NOTE(review): presumably assigned by configure() (not visible here) to the
# parsed optparse values and the loaded ConfigParser -- verify.
options = None
file_config = None
---|
| 11 | |
---|
# Default configuration text in ConfigParser format.  Its [db] section maps
# each database label (as looked up by _engine_uri for options.db) to a
# database URL.  NOTE(review): presumably loaded into file_config elsewhere
# -- verify.  The leading and trailing empty entries reproduce the newline
# that opens and closes the text.
base_config = "\n".join([
    "",
    "[db]",
    "sqlite=sqlite:///:memory:",
    "sqlite_file=sqlite:///querytest.db",
    "postgres=postgres://scott:tiger@127.0.0.1:5432/test",
    "postgresql=postgres://scott:tiger@127.0.0.1:5432/test",
    "mysql=mysql://scott:tiger@127.0.0.1:3306/test",
    "oracle=oracle://scott:tiger@127.0.0.1:1521",
    "oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0",
    # Raw literal: the URL contains one real backslash (SQUAWK\SQLEXPRESS).
    r"mssql=mssql://scott:tiger@SQUAWK\SQLEXPRESS/test",
    "firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb",
    "maxdb=maxdb://MONA:RED@/maxdb1",
    "",
])
---|
| 25 | |
---|
def _log(option, opt_str, value, parser):
    """Option callback (optparse signature): set the level of logger *value*.

    On first use, the stdlib ``logging`` module is imported into the
    module-level ``logging`` global and ``logging.basicConfig()`` is
    called.  An option string ending in ``-info`` sets INFO, one ending
    in ``-debug`` sets DEBUG, and any other option string leaves the
    logger untouched.
    """
    global logging
    if not logging:
        import logging
        logging.basicConfig()

    # Suffix -> level-name table; at most one suffix can match a string.
    for suffix, level_name in (('-info', 'INFO'), ('-debug', 'DEBUG')):
        if opt_str.endswith(suffix):
            logging.getLogger(value).setLevel(getattr(logging, level_name))
            break
---|
| 36 | |
---|
| 37 | |
---|
def _list_dbs(*args):
    """Print every [db] label of file_config with its URL, then exit(0).

    Labels are printed in sorted order.  Any positional arguments (for
    example an option callback's arguments) are accepted and ignored.
    """
    print("Available --db options (use --dburi to override)")
    section = 'db'
    for label in sorted(file_config.options(section)):
        url = file_config.get(section, label)
        print("%20s\t%s" % (label, url))
    sys.exit(0)
---|
| 43 | |
---|
def _server_side_cursors(options, opt_str, value, parser):
    """Option callback: record ``server_side_cursors=True`` in db_opts."""
    db_opts.update(server_side_cursors=True)
---|
| 46 | |
---|
def _engine_strategy(options, opt_str, value, parser):
    """Option callback: store a non-empty *value* as db_opts['strategy'].

    A falsy value (None or empty string) leaves db_opts unchanged.
    """
    if not value:
        return
    db_opts['strategy'] = value
---|
| 50 | |
---|
class _ordered_map(object):
    """Minimal insertion-ordered mapping whose iteration yields *values*.

    Assigning to an existing key replaces its value but keeps the key at
    its original position.  Unlike ``dict``, iterating yields the stored
    values (ordered by first insertion of their keys), not the keys.
    """

    def __init__(self):
        self._keys = []    # keys, in first-insertion order
        self._data = {}    # key -> current value

    def __setitem__(self, key, value):
        # Check membership against the dict rather than scanning the key
        # list: O(1) instead of O(n).  An unhashable key also raises
        # TypeError here, *before* _keys is modified, so a failed
        # assignment cannot leave a dangling key that breaks iteration.
        if key not in self._data:
            self._keys.append(key)
        self._data[key] = value

    def __iter__(self):
        for key in self._keys:
            yield self._data[key]
---|
| 64 | |
---|
# Registry of setup hooks, keyed by name.  Each hook takes
# (options, file_config) and is registered right after its definition as
# ``post_configure[name] = func``.  Iterating the map yields the hook
# functions in registration order; re-registering an existing name
# replaces the function but keeps its original position.
#
# Historical note: at one point in refactoring, modules injected hooks into
# the config process; this could probably be a plain list now.
post_configure = _ordered_map()
---|
| 68 | |
---|
def _engine_uri(options, file_config):
    """Resolve the module-level ``db_label`` and ``db_url`` from *options*.

    Precedence:
      1. ``options.dburi``: used as the URL; the label is the text before
         its first ':'.
      2. ``options.db``: a label whose URL is read from the [db] section
         of *file_config*.
      3. Otherwise the label ``'sqlite'``, also looked up in [db].

    Raises:
        RuntimeError: the label has no entry in the [db] section.
        ValueError: ``options.dburi`` contains no ':' (from ``str.index``).
    """
    global db_label, db_url
    # Reset both globals on every call.  Previously db_url kept its value
    # when neither option was given, so a repeated call could pair the
    # default 'sqlite' label with a stale URL from an earlier call.
    db_label, db_url = 'sqlite', None
    if options.dburi:
        db_url = options.dburi
        db_label = db_url[:db_url.index(':')]
    elif options.db:
        db_label = options.db

    if db_url is None:
        if db_label not in file_config.options('db'):
            raise RuntimeError(
                "Unknown engine. Specify --dbs for known engines.")
        db_url = file_config.get('db', db_label)
---|
# Hook: resolve db_label / db_url.  Registered ahead of the hooks that read
# them ('require', 'create_engine', 'prep_db'); post_configure iterates in
# registration order.
post_configure['engine_uri'] = _engine_uri
---|
| 85 | |
---|
def _requirement_key(requirement):
    """Return the lower-cased project name of a requirement string.

    The first comparison-operator character, any whitespace before it, and
    everything after it are dropped, e.g. ``'Foo >= 1.0'`` -> ``'foo'``.
    """
    return re.split(r'\s*[<>=!~]', requirement, 1)[0].strip().lower()


def _require(options, file_config):
    """Enforce package version requirements through ``pkg_resources``.

    Requirements come from ``options.require`` (an iterable of requirement
    strings; None counts as empty) and from the [require] section of
    *file_config*.  A [require] entry applies only when its label equals
    the active ``db_label`` or starts with ``db_label + '.'``.  It is
    skipped when a command-line requirement names the same project, so
    the command line overrides the file.

    Raises:
        RuntimeError: requirements exist but ``pkg_resources`` (setuptools)
            cannot be imported.
        Whatever ``pkg_resources.require`` raises for an unmet requirement
        propagates unchanged.
    """
    cmdline_requirements = options.require or []
    file_requirements = []
    if file_config.has_section('require'):
        file_requirements = file_config.items('require')
    if not (cmdline_requirements or file_requirements):
        return

    try:
        import pkg_resources
    except ImportError:
        raise RuntimeError("setuptools is required for version requirements")

    # Project names already pinned on the command line.  The old regex
    # '\s*(<!>=)' only matched the literal text "<!>=", so full
    # requirement strings were kept and overrides never took effect.
    cmdline_projects = set()
    for requirement in cmdline_requirements:
        pkg_resources.require(requirement)
        cmdline_projects.add(_requirement_key(requirement))

    for label, requirement in file_requirements:
        # Parenthesised on purpose: without the parentheses `not` bound
        # only to the `==` test, so dotted labels were always skipped.
        if not (label == db_label or label.startswith('%s.' % db_label)):
            continue
        # Compare whole project names; a prefix match let "foo" override
        # "foobar".
        if _requirement_key(requirement) in cmdline_projects:
            continue
        pkg_resources.require(requirement)
---|
# Hook: enforce version requirements; reads db_label, which the
# 'engine_uri' hook (registered earlier) assigns.
post_configure['require'] = _require
---|
| 111 | |
---|
def _engine_pool(options, file_config):
    """Select ``pool.AssertionPool`` in db_opts when ``options.mockpool`` is set.

    When the flag is unset this returns without importing sqlalchemy.
    """
    if not options.mockpool:
        return
    from sqlalchemy import pool
    db_opts['poolclass'] = pool.AssertionPool
---|
# Hook: put AssertionPool into db_opts when options.mockpool is set;
# registered before 'create_engine', which passes db_opts to the engine.
post_configure['engine_pool'] = _engine_pool
---|
| 117 | |
---|
def _create_testing_engine(options, file_config):
    """Build the testing engine from db_url and db_opts and publish it.

    The engine is stored both in this module's ``db`` global and as
    ``sqlalchemy.test.testing.db``.
    """
    global db
    from sqlalchemy.test import engines, testing

    engine = engines.testing_engine(db_url, db_opts)
    # Chained assignment binds left to right: db first, then testing.db.
    db = testing.db = engine
---|
# Hook: build the testing engine from db_url and db_opts.
post_configure['create_engine'] = _create_testing_engine
---|
| 124 | |
---|
def _prep_testing_database(options, file_config):
    """When ``options.dropfirst`` is set, drop all existing tables.

    Connects through ``engines.utf8_engine()`` (presumably the configured
    test database -- verify).  If tables exist, it lists them, sleeps 5
    seconds so the run can be aborted with Ctrl-C, then reflects and drops
    them.  The engine is always disposed.  Any error other than
    KeyboardInterrupt / SystemExit is reported as a RuntimeWarning instead
    of being raised.
    """
    from sqlalchemy.test import engines
    from sqlalchemy import schema

    try:
        # also create alt schemas etc. here?
        if options.dropfirst:
            engine = engines.utf8_engine()
            try:
                existing = engine.table_names()
                if existing:
                    print("Dropping existing tables in database: " + db_url)
                    try:
                        print("Tables: %s" % ', '.join(existing))
                    except Exception:
                        # Listing the names is best-effort (e.g. names the
                        # console cannot encode).  Catching Exception rather
                        # than everything lets Ctrl-C still abort here.
                        pass
                    print("Abort within 5 seconds...")
                    time.sleep(5)
                    md = schema.MetaData(engine, reflect=True)
                    md.drop_all()
            finally:
                # Release connections even if reflection or drop_all fails.
                engine.dispose()
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception:
        # sys.exc_info() works on every Python version the file targets and
        # avoids rebinding a name that shadows the engine.
        err = sys.exc_info()[1]
        warnings.warn(RuntimeWarning(
            "Error checking for existing tables in testing "
            "database: %s" % err))
---|
# Hook: optionally drop existing tables (options.dropfirst).
post_configure['prep_db'] = _prep_testing_database
---|
| 152 | |
---|
def _set_table_options(options, file_config):
    """Copy command-line table options into ``sqlalchemy.test.schema.table_options``.

    Each entry of ``options.tableopts`` is a ``key=value`` string.  Only the
    first '=' separates key from value, so values may themselves contain
    '='.  A None ``tableopts`` is treated as empty.  A truthy
    ``options.mysql_engine`` is stored under 'mysql_engine' afterwards, so
    it wins over a tableopts entry with the same key.

    Raises:
        ValueError: a tableopts entry contains no '='.
    """
    from sqlalchemy.test import schema

    table_options = schema.table_options
    for spec in options.tableopts or ():
        # maxsplit=1: previously a value such as "a=b" made the unpacking
        # fail with ValueError.
        key, value = spec.split('=', 1)
        table_options[key] = value

    if options.mysql_engine:
        table_options['mysql_engine'] = options.mysql_engine
---|
# Hook: apply options.tableopts / options.mysql_engine to
# sqlalchemy.test.schema.table_options.
post_configure['table_options'] = _set_table_options
---|
| 164 | |
---|
def _reverse_topological(options, file_config):
    """When ``options.reversetop`` is set, feed the dependency sorter reversed input.

    Defines a ``topological.QueueDependencySorter`` subclass that stores
    its ``tuples`` and ``allitems`` arguments as reversed lists.  It then
    rebinds both ``topological.QueueDependencySorter`` and
    ``unitofwork.DependencySorter`` to that subclass.  When the flag is
    unset, nothing is imported or patched.
    """
    if not options.reversetop:
        return

    from sqlalchemy.orm import unitofwork
    from sqlalchemy import topological

    class RevQueueDepSort(topological.QueueDependencySorter):
        # The base __init__ is not called.  NOTE(review): presumably it
        # only stores these two attributes -- verify against topological.
        def __init__(self, tuples, allitems):
            self.tuples = list(tuples)[::-1]
            self.allitems = list(allitems)[::-1]

    topological.QueueDependencySorter = RevQueueDepSort
    unitofwork.DependencySorter = RevQueueDepSort
---|
# Hook: optionally reverse the dependency sorter's input (options.reversetop).
post_configure['topological'] = _reverse_topological
---|
| 178 | |
---|