1 | |
---|
2 | from sqlalchemy.interfaces import ConnectionProxy |
---|
3 | from sqlalchemy.engine.default import DefaultDialect |
---|
4 | from sqlalchemy.engine.base import Connection |
---|
5 | from sqlalchemy import util |
---|
6 | import testing |
---|
7 | import re |
---|
8 | |
---|
class AssertRule(object):
    """Base class for one expectation checked against executed SQL.

    Subclasses receive execution events through ``process_execute`` and
    ``process_cursor_execute`` and report their state through
    ``is_consumed``, ``rule_passed`` and ``consume_final``.
    """

    # Outcome of the most recent check: None until a statement has been
    # tested.  Declared at class level so that ``consume_final`` below works
    # for subclasses that never assign it; previously such subclasses hit an
    # AttributeError instead of the intended AssertionError.
    _result = None

    def process_execute(self, clauseelement, *multiparams, **params):
        """Receive the arguments of a Connection-level execute; no-op by default."""
        pass

    def process_cursor_execute(self, statement, parameters, context, executemany):
        """Receive the arguments of a cursor-level execute; no-op by default."""
        pass

    def is_consumed(self):
        """Return True if this rule has been consumed, False if not.

        Should raise an AssertionError if this rule's condition has definitely failed.

        """
        raise NotImplementedError()

    def rule_passed(self):
        """Return True if the last test of this rule passed, False if failed, None if no test was applied."""

        raise NotImplementedError()

    def consume_final(self):
        """Return True if this rule has been consumed.

        Should raise an AssertionError if this rule's condition has not been consumed or has failed.

        Raises AssertionError("Rule has not been consumed") when ``_result``
        is still None, i.e. no statement was ever tested; otherwise defers
        to ``is_consumed``.
        """

        if self._result is None:
            assert False, "Rule has not been consumed"

        return self.is_consumed()
class SQLMatchRule(AssertRule):
    """Rule that tests a single statement and remembers the outcome.

    ``_result`` is None until a statement has been examined, after which it
    holds the match outcome; ``_errmsg`` holds the failure description.
    """

    def __init__(self):
        self._errmsg = ""
        self._result = None

    def rule_passed(self):
        # Report the raw outcome: None (untested), or the recorded result.
        return self._result

    def is_consumed(self):
        # Still waiting for a statement to examine.
        if self._result is not None:
            # A recorded failure surfaces here as an AssertionError.
            assert self._result, self._errmsg
            return True
        return False
class ExactSQL(SQLMatchRule):
    """Match a statement whose SQL text equals ``sql``.

    ``params``, if given, is a value or a list of values, or a callable that
    takes the execution context and returns one.  A non-list value is
    wrapped in a list, and the result must equal the context's
    ``compiled_parameters`` exactly.
    """

    def __init__(self, sql, params=None):
        SQLMatchRule.__init__(self)
        self.sql = sql
        self.params = params

    def process_cursor_execute(self, statement, parameters, context, executemany):
        # Executions without a context are not examined.
        if not context:
            return

        received_sql = _process_engine_statement(statement, context)
        received_params = context.compiled_parameters

        # TODO: remove this step once all unit tests are migrated, as
        # ExactSQL should really be *exact* SQL.  Until then the expected
        # text is rewritten into the dialect's paramstyle.
        expected_sql = _process_assertion_statement(self.sql, context)
        matched = received_sql == expected_sql

        if not self.params:
            expected_params = {}
        else:
            expected_params = self.params
            if util.callable(expected_params):
                expected_params = expected_params(context)
            if not isinstance(expected_params, list):
                expected_params = [expected_params]
            matched = matched and expected_params == context.compiled_parameters

        self._result = matched
        if not matched:
            self._errmsg = ("Testing for exact statement %r exact params %r, "
                            "received %r with params %r") % (
                expected_sql, expected_params, received_sql, received_params)
class RegexSQL(SQLMatchRule):
    """Match a statement whose SQL text matches the regular expression ``regex``.

    ``regex`` is applied with ``re.match``, i.e. anchored at the start of the
    received statement.  ``params``, if given, is a dict or a list of dicts,
    or a callable that takes the execution context and returns one.  The
    parameter comparison is positive-only: each expected key must appear with
    an equal value in the corresponding received parameter set.  Extra
    received keys and extra received parameter sets are ignored, but every
    expected set must have a received counterpart.
    """

    def __init__(self, regex, params=None):
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        self.orig_regex = regex
        self.params = params

    @staticmethod
    def _params_match(expected, received):
        """Return True if every expected parameter set is positively matched."""
        # zip() stops at the shorter sequence, so surplus expected sets would
        # otherwise never be checked and the rule would falsely pass.
        if len(expected) > len(received):
            return False
        for param, got in zip(expected, received):
            for key, value in param.items():
                if key not in got or got[key] != value:
                    return False
        return True

    def process_cursor_execute(self, statement, parameters, context, executemany):
        # Executions without a context are not examined.
        if not context:
            return

        _received_statement = _process_engine_statement(statement, context)
        _received_parameters = context.compiled_parameters

        equivalent = bool(self.regex.match(_received_statement))
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params

            if not isinstance(params, list):
                params = [params]

            # do a positive compare only
            if not self._params_match(params, _received_parameters):
                equivalent = False
        else:
            params = {}

        self._result = equivalent
        if not self._result:
            self._errmsg = "Testing for regex %r partial params %r, "\
                "received %r with params %r" % (self.orig_regex, params, _received_statement, _received_parameters)
class CompiledSQL(SQLMatchRule):
    """Match a statement by recompiling it with ``DefaultDialect``.

    The executed construct (``context.compiled.statement``) is recompiled
    with ``DefaultDialect``, newlines are stripped, and the text is compared
    for equality with ``statement``.  ``params``, if given, is a dict or a
    list of dicts, or a callable that takes the execution context and returns
    one.  The comparison is positive-only: each expected key must appear with
    an equal value in the corresponding received set, and every expected set
    must have a received counterpart.  ``params`` defaults to None (no
    parameter check), consistent with ExactSQL and RegexSQL.
    """

    def __init__(self, statement, params=None):
        SQLMatchRule.__init__(self)
        self.statement = statement
        self.params = params

    @staticmethod
    def _params_match(expected, received):
        """Return True if every expected parameter set is positively matched."""
        # zip() stops at the shorter sequence, so surplus expected sets would
        # otherwise never be checked and the rule would falsely pass.
        if len(expected) > len(received):
            return False
        for param, got in zip(expected, received):
            for key, value in param.items():
                if key not in got or got[key] != value:
                    return False
        return True

    def process_cursor_execute(self, statement, parameters, context, executemany):
        # Executions without a context are not examined.
        if not context:
            return

        _received_parameters = context.compiled_parameters

        # recompile from the context, using the default dialect, so the
        # expected text does not depend on the executing backend's dialect
        compiled = context.compiled.statement.compile(
            dialect=DefaultDialect(), column_keys=context.compiled.column_keys)

        _received_statement = str(compiled).replace('\n', '')

        equivalent = self.statement == _received_statement
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params

            if not isinstance(params, list):
                params = [params]

            # do a positive compare only
            if not self._params_match(params, _received_parameters):
                equivalent = False
        else:
            params = {}

        self._result = equivalent
        if not self._result:
            self._errmsg = "Testing for compiled statement %r partial params %r, " \
                "received %r with params %r" % (self.statement, params, _received_statement, _received_parameters)
class CountStatements(AssertRule):
    """Assert that ``process_execute`` is called exactly ``count`` times.

    ``is_consumed`` always answers False, so the rule is never retired
    mid-run.  The tally is only checked by ``consume_final``.
    """

    def __init__(self, count):
        self._statement_count = 0
        self.count = count

    def process_execute(self, clauseelement, *multiparams, **params):
        # One tick per call; the statement arguments themselves are ignored.
        self._statement_count = self._statement_count + 1

    def process_cursor_execute(self, statement, parameters, context, executemany):
        # Cursor-level executions are not counted.
        pass

    def is_consumed(self):
        return False

    def consume_final(self):
        expected, actual = self.count, self._statement_count
        assert expected == actual, "desired statement count %d does not match %d" % (expected, actual)
        return True
class AllOf(AssertRule):
    """Composite rule satisfied once every member rule has passed.

    Members may pass in any order.  Each time ``is_consumed`` is asked, at
    least one still-pending member must report that it passed; the first
    such member is retired.
    """

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_execute(self, clauseelement, *multiparams, **params):
        # Forward the event to every pending member.
        for member in self.rules:
            member.process_execute(clauseelement, *multiparams, **params)

    def process_cursor_execute(self, statement, parameters, context, executemany):
        for member in self.rules:
            member.process_cursor_execute(statement, parameters, context, executemany)

    def is_consumed(self):
        if not self.rules:
            return True

        matched = None
        for member in self.rules:
            if member.rule_passed():
                matched = member
                break

        if matched is None:
            assert False, "No assertion rules were satisfied for statement"
        else:
            # Only one member is retired per check.
            self.rules.remove(matched)
            return not self.rules

    def consume_final(self):
        return not self.rules
219 | def _process_engine_statement(query, context): |
---|
220 | if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'): |
---|
221 | query = query[:-25] |
---|
222 | |
---|
223 | query = re.sub(r'\n', '', query) |
---|
224 | |
---|
225 | return query |
---|
226 | |
---|
227 | def _process_assertion_statement(query, context): |
---|
228 | paramstyle = context.dialect.paramstyle |
---|
229 | if paramstyle == 'named': |
---|
230 | pass |
---|
231 | elif paramstyle =='pyformat': |
---|
232 | query = re.sub(r':([\w_]+)', r"%(\1)s", query) |
---|
233 | else: |
---|
234 | # positional params |
---|
235 | repl = None |
---|
236 | if paramstyle=='qmark': |
---|
237 | repl = "?" |
---|
238 | elif paramstyle=='format': |
---|
239 | repl = r"%s" |
---|
240 | elif paramstyle=='numeric': |
---|
241 | repl = None |
---|
242 | query = re.sub(r':([\w_]+)', repl, query) |
---|
243 | |
---|
244 | return query |
---|
245 | |
---|
class SQLAssert(ConnectionProxy):
    """Connection proxy that checks executed statements against queued rules.

    ``rules`` is None while no assertion is active.  Otherwise it is the list
    of pending rules, and only its head receives execution events.
    """

    rules = None

    def add_rules(self, rules):
        """Activate a new sequence of rules, consumed front to back."""
        self.rules = list(rules)

    def statement_complete(self):
        """Run final checks on every rule still pending.

        Raises AssertionError if any rule's ``consume_final`` returns a false
        value.
        """
        for rule in self.rules:
            if not rule.consume_final():
                assert False, "All statements are complete, but pending assertion rules remain"

    def clear_rules(self):
        """Deactivate rule checking.

        Assigning None rather than ``del self.rules`` reads the same afterwards
        (the class attribute is None).  Unlike ``del``, it is also safe when
        ``add_rules`` was never called or when clearing twice; ``del`` raised
        AttributeError in those cases because no instance attribute existed.
        """
        self.rules = None

    def execute(self, conn, execute, clauseelement, *multiparams, **params):
        # Run the statement first.  Assumption (not shown here): the proxy's
        # cursor_execute fires during this call, so the head rule sees the
        # cursor-level event before being asked whether it is consumed.
        result = execute(clauseelement, *multiparams, **params)

        if self.rules is not None:
            if not self.rules:
                assert False, "All rules have been exhausted, but further statements remain"
            rule = self.rules[0]
            rule.process_execute(clauseelement, *multiparams, **params)
            if rule.is_consumed():
                self.rules.pop(0)

        return result

    def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
        result = execute(cursor, statement, parameters, context)

        # Only the head rule sees cursor-level events.
        if self.rules:
            rule = self.rules[0]
            rule.process_cursor_execute(statement, parameters, context, executemany)

        return result
# Module-level SQLAssert instance shared by everything that imports this
# module.  Its ``rules`` is None (checking inactive) until add_rules() is
# called.  NOTE(review): presumably installed as a connection proxy by the
# test harness; confirm against the caller.
asserter = SQLAssert()