| 1 | |
|---|
| 2 | from sqlalchemy.interfaces import ConnectionProxy |
|---|
| 3 | from sqlalchemy.engine.default import DefaultDialect |
|---|
| 4 | from sqlalchemy.engine.base import Connection |
|---|
| 5 | from sqlalchemy import util |
|---|
| 6 | import testing |
|---|
| 7 | import re |
|---|
| 8 | |
|---|
class AssertRule(object):
    """Base class for rules that check the statements an engine executes.

    Subclasses override the ``process_*`` hooks to inspect statements and
    implement ``is_consumed()`` / ``rule_passed()`` to report their state.
    """

    def process_execute(self, clauseelement, *multiparams, **params):
        """Hook called with the arguments of a statement execution; no-op here."""
        pass

    def process_cursor_execute(self, statement, parameters, context, executemany):
        """Hook called with the arguments of a cursor execution; no-op here."""
        pass

    def is_consumed(self):
        """Return True if this rule has been consumed, False if not.

        Should raise an AssertionError if this rule's condition has definitely failed.

        """
        raise NotImplementedError()

    def rule_passed(self):
        """Return True if the last test of this rule passed, False if failed, None if no test was applied."""

        raise NotImplementedError()

    def consume_final(self):
        """Return True if this rule has been consumed.

        Should raise an AssertionError if this rule's condition has not been consumed or has failed.

        """

        # ``_result`` is only assigned by subclasses; read it defensively so a
        # rule that never set it reports "not consumed" rather than failing
        # with an AttributeError.
        if getattr(self, '_result', None) is None:
            # Raised explicitly (not via ``assert``) so the check is not
            # stripped when Python runs with -O.
            raise AssertionError("Rule has not been consumed")

        return self.is_consumed()
|---|
| 40 | |
|---|
class SQLMatchRule(AssertRule):
    """Rule base that records the outcome of its most recent statement check.

    ``_result`` is None until a check has run, then holds the check's
    outcome; ``_errmsg`` carries the failure text used by ``is_consumed()``.
    """

    def __init__(self):
        self._errmsg = ""
        self._result = None

    def rule_passed(self):
        """Return the stored outcome of the last check (None if none ran)."""
        return self._result

    def is_consumed(self):
        """Return False if no check ran yet, True if the last check passed.

        Raises an AssertionError carrying ``_errmsg`` if the last check failed.
        """
        if self._result is not None:
            assert self._result, self._errmsg
            return True

        return False
|---|
| 56 | |
|---|
class ExactSQL(SQLMatchRule):
    """Rule that passes when the executed statement equals an expected string.

    ``params`` may be a parameter dict, a list of them, or a callable that
    receives the execution context and returns either; when given, it must
    equal the context's compiled parameters.
    """

    def __init__(self, sql, params=None):
        SQLMatchRule.__init__(self)
        self.params = params
        self.sql = sql

    def process_cursor_execute(self, statement, parameters, context, executemany):
        """Compare the executed statement and parameters with the expectation."""
        # without an execution context there is nothing to compare against
        if not context:
            return

        actual_sql = _process_engine_statement(statement, context)
        actual_params = context.compiled_parameters

        # TODO: remove this step once all unit tests
        # are migrated, as ExactSQL should really be *exact* SQL
        expected_sql = _process_assertion_statement(self.sql, context)

        matched = actual_sql == expected_sql
        if not self.params:
            expected_params = {}
        else:
            expected_params = self.params
            if util.callable(expected_params):
                expected_params = expected_params(context)

            # normalize a single parameter set to a one-element list
            if not isinstance(expected_params, list):
                expected_params = [expected_params]
            matched = matched and expected_params == context.compiled_parameters

        self._result = matched
        if not matched:
            self._errmsg = "Testing for exact statement %r exact params %r, " \
                "received %r with params %r" % (
                    expected_sql, expected_params, actual_sql, actual_params)
|---|
| 92 | |
|---|
| 93 | |
|---|
class RegexSQL(SQLMatchRule):
    """Rule that passes when the executed statement matches a regular expression.

    ``params`` may be a parameter dict, a list of them, or a callable that
    receives the execution context and returns either. Parameters are
    compared positively only: every expected key must be present with an
    equal value, while extra received keys are ignored.
    """

    def __init__(self, regex, params=None):
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        # keep the uncompiled pattern for the failure message
        self.orig_regex = regex
        self.params = params

    def process_cursor_execute(self, statement, parameters, context, executemany):
        """Match the executed statement and parameters against the expectation."""
        # without an execution context there is nothing to compare against
        if not context:
            return

        _received_statement = _process_engine_statement(statement, context)
        _received_parameters = context.compiled_parameters

        # ``match`` anchors at the start of the statement only
        equivalent = bool(self.regex.match(_received_statement))
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params

            # normalize a single parameter set to a one-element list
            if not isinstance(params, list):
                params = [params]

            # do a positive compare only
            for param, received in zip(params, _received_parameters):
                # items() instead of the Python-2-only iteritems(); the
                # pairs visited are the same on either interpreter.
                for k, v in param.items():
                    if k not in received or received[k] != v:
                        equivalent = False
                        break
        else:
            params = {}

        self._result = equivalent
        if not self._result:
            self._errmsg = "Testing for regex %r partial params %r, "\
                "received %r with params %r" % (
                    self.orig_regex, params, _received_statement, _received_parameters)
|---|
| 131 | |
|---|
class CompiledSQL(SQLMatchRule):
    """Rule that passes when the statement, recompiled with the default dialect, equals an expected string.

    ``params`` may be a parameter dict, a list of them, or a callable that
    receives the execution context and returns either. Parameters are
    compared positively only: every expected key must be present with an
    equal value, while extra received keys are ignored.
    """

    def __init__(self, statement, params=None):
        # ``params`` defaults to None for parity with ExactSQL and RegexSQL;
        # a falsy value skips the parameter comparison.
        SQLMatchRule.__init__(self)
        self.statement = statement
        self.params = params

    def process_cursor_execute(self, statement, parameters, context, executemany):
        """Recompile the executed construct and compare it with the expectation."""
        # without an execution context there is nothing to compare against
        if not context:
            return

        _received_parameters = context.compiled_parameters

        # recompile from the context, using the default dialect
        compiled = context.compiled.statement.\
            compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys)

        # newlines are removed outright, not replaced by spaces
        _received_statement = re.sub(r'\n', '', str(compiled))

        equivalent = self.statement == _received_statement
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params

            # normalize a single parameter set to a one-element list
            if not isinstance(params, list):
                params = [params]

            # do a positive compare only
            for param, received in zip(params, _received_parameters):
                # items() instead of the Python-2-only iteritems(); the
                # pairs visited are the same on either interpreter.
                for k, v in param.items():
                    if k not in received or received[k] != v:
                        equivalent = False
                        break
        else:
            params = {}

        self._result = equivalent
        if not self._result:
            self._errmsg = "Testing for compiled statement %r partial params %r, " \
                "received %r with params %r" % (
                    self.statement, params, _received_statement, _received_parameters)
|---|
| 173 | |
|---|
| 174 | |
|---|
class CountStatements(AssertRule):
    """Rule that counts ``process_execute`` calls and checks the total at the end."""

    def __init__(self, count):
        self._statement_count = 0
        self.count = count

    def process_execute(self, clauseelement, *multiparams, **params):
        """Tally one more executed statement."""
        self._statement_count = self._statement_count + 1

    def process_cursor_execute(self, statement, parameters, context, executemany):
        """Cursor-level executions are not counted."""
        pass

    def is_consumed(self):
        """Always False: the count is only checked in ``consume_final()``."""
        return False

    def consume_final(self):
        """Return True if the tally equals the desired count, else raise AssertionError."""
        assert self.count == self._statement_count, \
            "desired statement count %d does not match %d" % (
                self.count, self._statement_count)
        return True
|---|
| 192 | |
|---|
class AllOf(AssertRule):
    """Composite rule holding a set of rules, each of which must pass once.

    Every event is forwarded to all rules still pending; a rule is dropped
    from the set as soon as it reports a pass.
    """

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_execute(self, clauseelement, *multiparams, **params):
        """Forward the execute event to every pending rule."""
        for pending in self.rules:
            pending.process_execute(clauseelement, *multiparams, **params)

    def process_cursor_execute(self, statement, parameters, context, executemany):
        """Forward the cursor execute event to every pending rule."""
        for pending in self.rules:
            pending.process_cursor_execute(statement, parameters, context, executemany)

    def is_consumed(self):
        """Drop the first pending rule that passed; return True once none remain.

        Raises an AssertionError if rules are pending and none of them passed.
        """
        if not self.rules:
            return True

        # iterate over a snapshot, since the set is mutated on a hit
        for candidate in tuple(self.rules):
            if not candidate.rule_passed():
                continue
            # a rule passed, move on
            self.rules.remove(candidate)
            return not self.rules

        assert False, "No assertion rules were satisfied for statement"

    def consume_final(self):
        """Return True only if every rule has been satisfied and removed."""
        return not self.rules
|---|
| 218 | |
|---|
def _process_engine_statement(query, context):
    """Normalize a statement string received from the engine for comparison.

    When the context's engine is named 'mssql', a trailing
    '; select scope_identity()' is cut off; all newline characters are then
    removed from the statement.
    """
    identity_suffix = '; select scope_identity()'
    if context.engine.name == 'mssql' and query.endswith(identity_suffix):
        query = query[:-len(identity_suffix)]

    return query.replace('\n', '')
|---|
| 226 | |
|---|
def _process_assertion_statement(query, context):
    """Rewrite ``:name`` bind markers in an expected statement for the dialect's paramstyle.

    ``query`` is written with named markers (``:name``). The paramstyle is
    read from ``context.dialect.paramstyle`` and handled as follows:

    * ``named``    - returned unchanged
    * ``pyformat`` - ``:name`` becomes ``%(name)s``
    * ``qmark``    - ``:name`` becomes ``?``
    * ``format``   - ``:name`` becomes ``%s``
    * ``numeric``  - markers become ``:1``, ``:2``, ... in order of appearance

    Raises ValueError for any other paramstyle.
    """
    paramstyle = context.dialect.paramstyle
    bind_marker = r':(\w+)'

    if paramstyle == 'named':
        return query
    if paramstyle == 'pyformat':
        return re.sub(bind_marker, r"%(\1)s", query)
    if paramstyle == 'qmark':
        return re.sub(bind_marker, "?", query)
    if paramstyle == 'format':
        return re.sub(bind_marker, "%s", query)
    if paramstyle == 'numeric':
        # Previously ``None`` was handed to re.sub() here, which raised a
        # TypeError. Number the markers by position, per the DB-API
        # 'numeric' style.
        seen = []

        def _numbered(match):
            seen.append(match.group(1))
            return ":%d" % len(seen)

        return re.sub(bind_marker, _numbered, query)

    raise ValueError("Unsupported paramstyle %r" % (paramstyle,))
|---|
| 245 | |
|---|
class SQLAssert(ConnectionProxy):
    """Connection proxy that feeds executed statements to a queue of assertion rules.

    ``rules`` is None when no assertion is active; otherwise it is the list
    of rules still pending, and the first entry is the one being checked.
    """

    rules = None

    def add_rules(self, rules):
        """Install a copy of ``rules`` as the pending rule list."""
        self.rules = list(rules)

    def statement_complete(self):
        """Finalize every pending rule; raise AssertionError if one is unsatisfied."""
        for rule in self.rules:
            if not rule.consume_final():
                # raised explicitly so the check is not stripped under -O
                raise AssertionError(
                    "All statements are complete, but pending assertion rules remain")

    def clear_rules(self):
        """Deactivate assertion checking."""
        # Reset to None instead of ``del self.rules``: deleting raises an
        # AttributeError when no instance attribute exists (rules never
        # added, or cleared twice). Either way, ``self.rules`` then reads
        # as None.
        self.rules = None

    def execute(self, conn, execute, clauseelement, *multiparams, **params):
        """Run the statement, then let the current rule inspect it.

        Raises AssertionError if rules are active but none remain pending.
        """
        result = execute(clauseelement, *multiparams, **params)

        if self.rules is not None:
            if not self.rules:
                # raised explicitly so the check is not stripped under -O
                raise AssertionError(
                    "All rules have been exhausted, but further statements remain")
            rule = self.rules[0]
            rule.process_execute(clauseelement, *multiparams, **params)
            # a consumed rule is finished; later statements go to the next one
            if rule.is_consumed():
                self.rules.pop(0)

        return result

    def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
        """Run the cursor execution, then pass its arguments to the current rule."""
        result = execute(cursor, statement, parameters, context)

        if self.rules:
            rule = self.rules[0]
            rule.process_cursor_execute(statement, parameters, context, executemany)

        return result

# module-level SQLAssert instance
asserter = SQLAssert()
|---|
| 283 | |
|---|