1 | import optparse, os, sys, re, ConfigParser, StringIO, time, warnings |
---|
# Stand-in for the stdlib ``logging`` module, which is imported lazily:
# _log() replaces this with the real module the first time it runs.
logging = None

__all__ = ('parser', 'configure', 'options')

# Testing engine, plus the label, URL and keyword options used to build it.
db = None
db_label = None
db_url = None
db_opts = {}

# Parsed command-line options and the loaded configuration object; both are
# presumably assigned by the configuration step elsewhere in this module
# (NOTE(review): confirm).
options = None
file_config = None

# Default [db] section: maps each --db label to an engine URL.
base_config = """
[db]
sqlite=sqlite:///:memory:
sqlite_file=sqlite:///querytest.db
postgres=postgres://scott:tiger@127.0.0.1:5432/test
postgresql=postgres://scott:tiger@127.0.0.1:5432/test
mysql=mysql://scott:tiger@127.0.0.1:3306/test
oracle=oracle://scott:tiger@127.0.0.1:1521
oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb
maxdb=maxdb://MONA:RED@/maxdb1
"""

26 | def _log(option, opt_str, value, parser): |
---|
27 | global logging |
---|
28 | if not logging: |
---|
29 | import logging |
---|
30 | logging.basicConfig() |
---|
31 | |
---|
32 | if opt_str.endswith('-info'): |
---|
33 | logging.getLogger(value).setLevel(logging.INFO) |
---|
34 | elif opt_str.endswith('-debug'): |
---|
35 | logging.getLogger(value).setLevel(logging.DEBUG) |
---|
36 | |
---|
37 | |
---|
def _list_dbs(*args):
    """optparse callback: print every known --db label with its URL, then exit.

    Labels come from the ``[db]`` section of the module-level
    ``file_config`` and are printed in sorted order. Always terminates the
    process with exit status 0.
    """
    db_section = 'db'
    print("Available --db options (use --dburi to override)")
    for label in sorted(file_config.options(db_section)):
        url = file_config.get(db_section, label)
        print("%20s\t%s" % (label, url))
    sys.exit(0)

def _server_side_cursors(options, opt_str, value, parser):
    """optparse callback: request server-side cursors for the testing engine.

    Records ``server_side_cursors=True`` in the module-level ``db_opts``.
    """
    db_opts.update(server_side_cursors=True)

def _engine_strategy(options, opt_str, value, parser):
    """optparse callback: record an engine *strategy* in ``db_opts``.

    Empty or otherwise falsy values are ignored.
    """
    if not value:
        return
    db_opts['strategy'] = value

51 | class _ordered_map(object): |
---|
52 | def __init__(self): |
---|
53 | self._keys = list() |
---|
54 | self._data = dict() |
---|
55 | |
---|
56 | def __setitem__(self, key, value): |
---|
57 | if key not in self._keys: |
---|
58 | self._keys.append(key) |
---|
59 | self._data[key] = value |
---|
60 | |
---|
61 | def __iter__(self): |
---|
62 | for key in self._keys: |
---|
63 | yield self._data[key] |
---|
64 | |
---|
65 | # at one point in refactoring, modules were injecting into the config |
---|
66 | # process. this could probably just become a list now. |
---|
67 | post_configure = _ordered_map() |
---|
68 | |
---|
def _engine_uri(options, file_config):
    """Post-configure hook: resolve the module-level ``db_label`` and ``db_url``.

    Precedence:
      * ``options.dburi`` is used as the URL, and its scheme (the text before
        the first ``:``) becomes the label;
      * otherwise ``options.db`` is the label and the URL is looked up in
        the ``[db]`` section of *file_config*;
      * otherwise the label defaults to ``'sqlite'``. If ``db_url`` is still
        None, it is looked up the same way.

    Raises RuntimeError when a URL must be looked up but the label is not a
    known ``[db]`` entry. Raises ValueError when ``options.dburi`` contains
    no ``:``.
    """
    global db_label, db_url
    db_label = 'sqlite'
    if options.dburi:
        db_url = options.dburi
        db_label = db_url[:db_url.index(':')]
    elif options.db:
        db_label = options.db
        db_url = None

    if db_url is not None:
        return

    known_labels = file_config.options('db')
    if db_label not in known_labels:
        raise RuntimeError(
            "Unknown engine. Specify --dbs for known engines.")
    db_url = file_config.get('db', db_label)
post_configure['engine_uri'] = _engine_uri

def _requirement_name(requirement):
    """Return the lowercased project name of a requirement string.

    Everything from the first comparison character (``<``, ``>``, ``=``,
    ``!``, ``~``) onwards is dropped, e.g. ``"SQLAlchemy >= 0.5"`` ->
    ``"sqlalchemy"``. Names are lowercased because pkg_resources matches
    project names case-insensitively.
    """
    return re.split(r'\s*[<>=!~]', requirement, maxsplit=1)[0].strip().lower()


def _require(options, file_config):
    """Post-configure hook: enforce package version requirements.

    Requirements come from ``options.require`` (command line) and from the
    ``[require]`` section of *file_config*. Each one is checked with
    ``pkg_resources.require``, which raises if it cannot be satisfied.

    A ``[require]`` entry applies only when its label equals the active
    ``db_label`` or starts with ``"<db_label>."``. It is skipped when the
    same project was already required on the command line, so the
    command-line requirement wins.

    Raises RuntimeError when setuptools (``pkg_resources``) is unavailable.
    """
    # options.require is presumably an optparse "append" option, whose
    # default is None when no --require was given; treat that as empty.
    cmdline_reqs = options.require or ()
    if file_config.has_section('require'):
        file_reqs = file_config.items('require')
    else:
        file_reqs = []
    if not (cmdline_reqs or file_reqs):
        return

    try:
        import pkg_resources
    except ImportError:
        raise RuntimeError("setuptools is required for version requirements")

    # Project names required on the command line; these override any
    # [require] entry for the same project.
    overridden = []
    for requirement in cmdline_reqs:
        pkg_resources.require(requirement)
        overridden.append(_requirement_name(requirement))

    for label, requirement in file_reqs:
        # Parenthesized so `not` negates the whole test: an entry applies
        # to this database if its label is exact or a dotted sub-label.
        if not (label == db_label or label.startswith('%s.' % db_label)):
            continue
        if _requirement_name(requirement) in overridden:
            continue
        pkg_resources.require(requirement)
post_configure['require'] = _require

def _engine_pool(options, file_config):
    """Post-configure hook: select AssertionPool when ``options.mockpool`` is set.

    Stores ``sqlalchemy.pool.AssertionPool`` under ``db_opts['poolclass']``.
    Does nothing otherwise.
    """
    if not options.mockpool:
        return
    from sqlalchemy.pool import AssertionPool
    db_opts['poolclass'] = AssertionPool
post_configure['engine_pool'] = _engine_pool

def _create_testing_engine(options, file_config):
    """Post-configure hook: build the shared testing engine.

    Calls ``sqlalchemy.test.engines.testing_engine`` with the module-level
    ``db_url`` and ``db_opts``. The resulting engine is published both as
    this module's ``db`` and as ``sqlalchemy.test.testing.db``.
    """
    global db
    from sqlalchemy.test import engines, testing
    engine = engines.testing_engine(db_url, db_opts)
    db = testing.db = engine
post_configure['create_engine'] = _create_testing_engine

def _prep_testing_database(options, file_config):
    """Post-configure hook: with ``options.dropfirst``, drop all existing tables.

    Opens ``engines.utf8_engine()`` and lists its tables. If any exist, it
    prints them, waits 5 seconds so the run can be aborted, then reflects
    and drops them all. The engine is always disposed afterwards.
    KeyboardInterrupt and SystemExit propagate. Any other error is reported
    as a RuntimeWarning and does not stop the run.
    """
    from sqlalchemy.test import engines
    from sqlalchemy import schema

    try:
        # also create alt schemas etc. here?
        if options.dropfirst:
            engine = engines.utf8_engine()
            try:
                existing = engine.table_names()
                if existing:
                    # %-formatting (not +) so a missing db_url cannot abort
                    # the drop with a TypeError.
                    print("Dropping existing tables in database: %s" % db_url)
                    try:
                        print("Tables: %s" % ', '.join(existing))
                    except Exception:
                        # The listing is informational only; presumably this
                        # guards against names the console cannot encode.
                        pass
                    print("Abort within 5 seconds...")
                    time.sleep(5)
                    md = schema.MetaData(engine, reflect=True)
                    md.drop_all()
            finally:
                # Release the engine even when reflection or the drop fails.
                engine.dispose()
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception:
        # sys.exc_info() is valid under both old and new exception syntax.
        err = sys.exc_info()[1]
        warnings.warn(RuntimeWarning(
            "Error checking for existing tables in testing "
            "database: %s" % err))
post_configure['prep_db'] = _prep_testing_database

def _set_table_options(options, file_config):
    """Post-configure hook: populate ``sqlalchemy.test.schema.table_options``.

    Each ``options.tableopts`` entry is a ``key=value`` string. Only the
    first ``=`` separates key from value, so values may themselves contain
    ``=``. A truthy ``options.mysql_engine`` is stored under
    ``'mysql_engine'`` and takes precedence over a same-named tableopts
    entry.

    Raises ValueError for a tableopts entry that contains no ``=``.
    """
    from sqlalchemy.test import schema

    table_options = schema.table_options
    # tableopts is presumably an optparse "append" option (default None when
    # never given); treat None as no entries.
    for spec in options.tableopts or ():
        key, value = spec.split('=', 1)
        table_options[key] = value

    if options.mysql_engine:
        table_options['mysql_engine'] = options.mysql_engine
post_configure['table_options'] = _set_table_options

def _reverse_topological(options, file_config):
    """Post-configure hook: when ``options.reversetop`` is set, reverse sort input.

    Installs a subclass of ``topological.QueueDependencySorter`` whose
    constructor stores its ``tuples`` and ``allitems`` in reverse order. The
    subclass is assigned to both ``topological.QueueDependencySorter`` and
    ``orm.unitofwork.DependencySorter``. Presumably this exposes code that
    relies on an incidental ordering (NOTE(review): confirm intent).
    """
    if not options.reversetop:
        return

    from sqlalchemy.orm import unitofwork
    from sqlalchemy import topological

    class RevQueueDepSort(topological.QueueDependencySorter):
        # Does not call the base __init__; only these two attributes are set,
        # each as a reversed copy of its input.
        def __init__(self, tuples, allitems):
            self.tuples = list(tuples)[::-1]
            self.allitems = list(allitems)[::-1]

    topological.QueueDependencySorter = RevQueueDepSort
    unitofwork.DependencySorter = RevQueueDepSort
post_configure['topological'] = _reverse_topological
