1 | import sys, types, weakref |
---|
2 | from collections import deque |
---|
3 | import config |
---|
4 | from sqlalchemy.util import function_named, callable |
---|
5 | |
---|
class ConnectionKiller(object):
    """Track checked-out connection proxies so tests can clean them up.

    An instance is appended to the engine ``listeners`` by
    ``testing_engine``; its ``checkout`` hook records each connection
    proxy it is handed.  ``rollback_all`` / ``close_all`` then act on every
    recorded proxy that is still valid, and ``assert_all_closed`` fails if
    any of them is.
    """

    def __init__(self):
        # Weak keys: tracking a proxy must not keep it alive after the
        # test drops its last reference to it.
        self.proxy_refs = weakref.WeakKeyDictionary()

    def checkout(self, dbapi_con, con_record, con_proxy):
        """Listener hook: remember ``con_proxy`` for later cleanup."""
        self.proxy_refs[con_proxy] = True

    def _apply_all(self, methods):
        """Apply each of ``methods`` to every tracked, still-valid proxy.

        Entries in ``methods`` are either callables, invoked as
        ``method(proxy)``, or attribute names, invoked as
        ``getattr(proxy, name)()``.  Errors (other than SystemExit /
        KeyboardInterrupt) are written to stderr and do not stop the
        remaining methods or proxies from being processed.
        """
        # Iterate over a snapshot: rolling back / closing a proxy may cause
        # it to leave the WeakKeyDictionary, and mutating a mapping while
        # iterating it raises RuntimeError.
        for rec in list(self.proxy_refs):
            if rec is None or not rec.is_valid:
                continue
            for method in methods:
                # One try per method so that, e.g., a failed rollback
                # still lets close run instead of leaking the connection.
                try:
                    if callable(method):
                        method(rec)
                    else:
                        getattr(rec, method)()
                except (SystemExit, KeyboardInterrupt):
                    raise
                except Exception:
                    # sys.exc_info() keeps this valid on both old Python 2
                    # and Python 3 (no "except X, e" / "except X as e").
                    e = sys.exc_info()[1]
                    sys.stderr.write("\n" + str(e) + "\n")

    def rollback_all(self):
        """Roll back every tracked connection that is still valid."""
        self._apply_all(('rollback',))

    def close_all(self):
        """Roll back, then close, every tracked connection still valid."""
        self._apply_all(('rollback', 'close'))

    def assert_all_closed(self):
        """Fail if any tracked connection proxy is still valid.

        Raises AssertionError explicitly (rather than via ``assert``) so the
        check still runs under ``python -O`` and names what is still open.
        """
        still_open = [rec for rec in list(self.proxy_refs) if rec.is_valid]
        if still_open:
            raise AssertionError("%d connection(s) still open: %r"
                                 % (len(still_open), still_open))


# Module-wide reaper instance used by the decorators and testing_engine().
testing_reaper = ConnectionKiller()
---|
40 | |
---|
def assert_conns_closed(fn):
    """Decorator: after ``fn`` completes, assert that every connection
    tracked by ``testing_reaper`` has been closed.

    The wrapped function's return value is passed through, and the wrapper
    is given ``fn``'s name via ``function_named``.  If ``fn`` itself raises,
    that exception propagates unchanged; the open-connection check is only
    run on success, so it cannot mask the test's real failure.
    """
    def decorated(*args, **kw):
        result = fn(*args, **kw)
        testing_reaper.assert_all_closed()
        return result
    return function_named(decorated, fn.__name__)
---|
48 | |
---|
def rollback_open_connections(fn):
    """Decorator that rolls back all open connections after fn execution.

    The rollback runs whether ``fn`` returns or raises.  ``fn``'s return
    value is passed through, and the wrapper is given ``fn``'s name via
    ``function_named``.
    """

    def decorated(*args, **kw):
        try:
            return fn(*args, **kw)
        finally:
            testing_reaper.rollback_all()
    return function_named(decorated, fn.__name__)
---|
58 | |
---|
def close_open_connections(fn):
    """Decorator that closes all connections after fn execution.

    Every tracked connection is rolled back and closed whether ``fn``
    returns or raises.  ``fn``'s return value is passed through, and the
    wrapper is given ``fn``'s name via ``function_named``.
    """

    def decorated(*args, **kw):
        try:
            return fn(*args, **kw)
        finally:
            testing_reaper.close_all()
    return function_named(decorated, fn.__name__)
---|
68 | |
---|
def all_dialects():
    """Yield one default-constructed dialect for each name listed in
    ``sqlalchemy.databases.__all__``.

    Each dialect module is imported lazily, as the generator advances.
    """
    import sqlalchemy.databases as databases_pkg
    for dialect_name in databases_pkg.__all__:
        module_path = 'sqlalchemy.databases.%s' % dialect_name
        # A dotted __import__ without fromlist returns the top-level package.
        root_package = __import__(module_path)
        dialect_module = getattr(root_package.databases, dialect_name)
        yield dialect_module.dialect()
---|
74 | |
---|
class ReconnectFixture(object):
    """Stand-in DB-API module that remembers every connection it opens.

    Attribute lookups not handled here are forwarded to the wrapped DB-API
    module, which lets an instance be passed as ``create_engine``'s
    ``module`` option (as ``reconnecting_engine`` does).  ``shutdown()``
    closes every connection opened so far; presumably tests use this to
    simulate the database dropping connections under the pool.
    """

    def __init__(self, dbapi):
        self.dbapi = dbapi
        self.connections = []

    def __getattr__(self, key):
        # Only reached when normal lookup fails.  Guard 'dbapi' itself so an
        # instance lacking it (e.g. created via __new__ by copy/pickle)
        # raises AttributeError instead of recursing without end.
        if key == 'dbapi':
            raise AttributeError(key)
        return getattr(self.dbapi, key)

    def connect(self, *args, **kwargs):
        """Open a connection on the real DB-API and record it."""
        conn = self.dbapi.connect(*args, **kwargs)
        self.connections.append(conn)
        return conn

    def shutdown(self):
        """Close every recorded connection and forget them.

        Every connection gets a ``close()`` attempt even if an earlier one
        fails; the list is always cleared, and the first close failure (if
        any) is re-raised once all connections have been attempted.
        """
        connections, self.connections = self.connections, []
        first_error = None
        for conn in connections:
            try:
                conn.close()
            except (SystemExit, KeyboardInterrupt):
                raise
            except Exception:
                if first_error is None:
                    first_error = sys.exc_info()[1]
        if first_error is not None:
            raise first_error
---|
92 | |
---|
def reconnecting_engine(url=None, options=None):
    """Produce a testing engine whose DB-API connections can be force-closed.

    The configured DB-API module (``config.db.dialect.dbapi``) is wrapped
    in a ``ReconnectFixture`` and passed to ``testing_engine`` as the
    ``module`` option.  The returned engine gets a ``test_shutdown``
    attribute bound to ``engine.dialect.dbapi.shutdown`` (assumed to be
    the fixture passed in as ``module`` — TODO confirm).

    :param url: database URL; defaults to ``config.db_url``.
    :param options: extra ``create_engine`` options.  Copied, never
        modified in place.
    """
    url = url or config.db_url
    dbapi = config.db.dialect.dbapi
    # Copy so the caller's dict doesn't silently gain a 'module' entry
    # (or anything testing_engine adds).
    # NOTE(review): because this dict is never empty, testing_engine will
    # not fall back to config.db_opts here — confirm that is intended.
    options = dict(options or {})
    options['module'] = ReconnectFixture(dbapi)
    engine = testing_engine(url, options)
    engine.test_shutdown = engine.dialect.dbapi.shutdown
    return engine
---|
102 | |
---|
def testing_engine(url=None, options=None):
    """Produce an engine configured by --options with optional overrides.

    :param url: database URL; defaults to ``config.db_url``.
    :param options: keyword arguments for ``create_engine``.  When empty or
        None, ``config.db_opts`` is used instead.  Neither mapping is
        modified: the options are copied before the ``asserter`` proxy
        (unless one is given) and the ``testing_reaper`` listener are added.
    :returns: the engine built by ``create_engine``.
    """

    from sqlalchemy import create_engine
    from sqlalchemy.test.assertsql import asserter

    url = url or config.db_url
    # Work on a copy.  Mutating config.db_opts in place made every call
    # append another testing_reaper to the shared 'listeners' list, so each
    # later engine received one more duplicate listener.
    options = dict(options or config.db_opts)

    options.setdefault('proxy', asserter)

    # Fresh list (also tolerates an explicit listeners=None); register the
    # reaper only once.
    listeners = list(options.get('listeners') or [])
    if testing_reaper not in listeners:
        listeners.append(testing_reaper)
    options['listeners'] = listeners

    return create_engine(url, **options)
---|
120 | |
---|
def utf8_engine(url=None, options=None):
    """Build a testing engine that talks UTF-8 even on drivers that need
    to be told to.

    Only MySQL is special-cased.  The DB-API's ``version_info`` is checked
    first: versions below (1, 2, 1) and the listed 1.2.1 'gamma' releases
    raise RuntimeError.  Otherwise ``charset=utf8`` and ``use_unicode=0``
    are added to the URL's query.  Every other backend gets ``url``
    unchanged.  The engine itself comes from ``testing_engine``.
    """

    from sqlalchemy.engine import url as engine_url

    if config.db.name != 'mysql':
        return testing_engine(url, options)

    version = config.db.dialect.dbapi.version_info
    unsupported_gammas = ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2),
                          (1, 2, 1, 'gamma', 3), (1, 2, 1, 'gamma', 5))
    too_old = version < (1, 2, 1)
    if too_old or version in unsupported_gammas:
        raise RuntimeError('Character set support unavailable with this '
                           'driver version: %s' % repr(version))

    parsed = engine_url.make_url(url or config.db_url)
    parsed.query['charset'] = 'utf8'
    parsed.query['use_unicode'] = '0'
    return testing_engine(str(parsed), options)
---|
141 | |
---|
def mock_engine(db=None):
    """Return a 'mock'-strategy engine for the dialect of ``db``.

    ``db`` defaults to ``config.db``; only its ``name`` is used, to build a
    bare ``<name>://`` URL.  Every SQL construct handed to the engine's
    executor is appended, in order, to the list exposed as ``engine.mock``.
    """

    from sqlalchemy import create_engine

    source = db or config.db
    captured = []

    def executor(sql, *a, **kw):
        captured.append(sql)

    bare_url = source.name + '://'
    engine = create_engine(bare_url, strategy='mock', executor=executor)
    # Never clobber an attribute the engine already defines.
    assert not hasattr(engine, 'mock')
    engine.mock = captured
    return engine
---|
156 | |
---|
class ReplayableSession(object):
    """A simple record/playback tool.

    This is *not* a mock testing class. It only records a session for later
    playback and makes no assertions on call consistency whatsoever. It's
    unlikely to be suitable for anything other than DB-API recording.

    ``recorder(base)`` wraps a real object.  Each call result and attribute
    lookup on it is appended to a shared deque.  ``player()`` returns a
    stand-in that pops those entries back off the deque in the same order,
    without touching any real object.  Playback is only meaningful if the
    replayed code performs exactly the same sequence of lookups and calls
    as the recorded run.
    """

    # Marker stored in place of a result that was wrapped rather than
    # stored.  When recording, the result is wrapped in a new Recorder;
    # when replaying, the Player returns itself for this marker.
    Callable = object()
    # Marker stored when an attribute lookup raised AttributeError; the
    # Player raises AttributeError when it pops this during a lookup.
    NoAttribute = object()
    # Types whose values are recorded verbatim: every object exposed by the
    # `types` module minus the function/method types.  A result of any other
    # type (presumably DB-API connections, cursors, ...) is wrapped instead.
    # NOTE(review): relies on the Python 2 `types` module ('UnboundMethodType'
    # and the builtin-type aliases do not exist on Python 3).
    Natives = set([getattr(types, t)
                   for t in dir(types) if not t.startswith('_')]). \
              difference([getattr(types, t)
                          for t in ('FunctionType', 'BuiltinFunctionType',
                                    'MethodType', 'BuiltinMethodType',
                                    'LambdaType', 'UnboundMethodType',)])
    def __init__(self):
        # Shared FIFO: Recorders append to it, the Player pops from the left.
        self.buffer = deque()

    def recorder(self, base):
        """Return a Recorder proxying ``base`` that logs into this session."""
        return self.Recorder(self.buffer, base)

    def player(self):
        """Return a Player that replays this session's recorded buffer."""
        return self.Player(self.buffer)

    class Recorder(object):
        # Transparent proxy around a real object.  Results of calls and
        # attribute lookups are appended to the shared buffer; non-native
        # results are wrapped in another Recorder sharing the same buffer.
        # Internal state is read with object.__getattribute__ so that it
        # bypasses the proxying __getattribute__ below.
        def __init__(self, buffer, subject):
            self._buffer = buffer
            self._subject = subject

        def __call__(self, *args, **kw):
            subject, buffer = [object.__getattribute__(self, x)
                               for x in ('_subject', '_buffer')]

            result = subject(*args, **kw)
            if type(result) not in ReplayableSession.Natives:
                buffer.append(ReplayableSession.Callable)
                return type(self)(buffer, result)
            else:
                buffer.append(result)
                return result

        def __getattribute__(self, key):
            # Attributes that exist on the Recorder itself (its methods,
            # _buffer, _subject, ...) are returned directly and not recorded.
            try:
                return object.__getattribute__(self, key)
            except AttributeError:
                pass

            subject, buffer = [object.__getattribute__(self, x)
                               for x in ('_subject', '_buffer')]
            try:
                # Calling the type's __getattribute__ directly means a
                # __getattr__ fallback on the subject's class is not used.
                result = type(subject).__getattribute__(subject, key)
            except AttributeError:
                buffer.append(ReplayableSession.NoAttribute)
                raise
            else:
                if type(result) not in ReplayableSession.Natives:
                    buffer.append(ReplayableSession.Callable)
                    return type(self)(buffer, result)
                else:
                    buffer.append(result)
                    return result

    class Player(object):
        # Stand-in for the recorded object.  Each call, and each lookup of
        # an attribute the Player itself lacks, pops the next recorded entry.
        # Callable -> the Player itself (standing in for the wrapped result).
        # NoAttribute -> AttributeError (lookups only).
        # Anything else -> returned as-is.
        # popleft() raises IndexError once the recording is exhausted.
        def __init__(self, buffer):
            self._buffer = buffer

        def __call__(self, *args, **kw):
            # Arguments are ignored; only the recorded order matters.
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                return self
            else:
                return result

        def __getattribute__(self, key):
            # The Player's own attributes win and consume no buffer entries.
            try:
                return object.__getattribute__(self, key)
            except AttributeError:
                pass
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                return self
            elif result is ReplayableSession.NoAttribute:
                raise AttributeError(key)
            else:
                return result
---|