root/galaxy-central/eggs/SQLAlchemy-0.5.6_dev_r6498-py2.6.egg/sqlalchemy/test/assertsql.py @ 3

リビジョン 3, 9.2 KB (コミッタ: kohda, 14 年 前)

Install Unix tools  http://hannonlab.cshl.edu/galaxy_unix_tools/galaxy.html

行番号 
1
2from sqlalchemy.interfaces import ConnectionProxy
3from sqlalchemy.engine.default import DefaultDialect
4from sqlalchemy.engine.base import Connection
5from sqlalchemy import util
6import testing
7import re
8
class AssertRule(object):
    """Base class for a rule checked by SQLAssert against executed statements.

    Subclasses override the ``process_*`` hooks to observe executions and
    implement ``is_consumed`` / ``rule_passed`` to report their state.

    """

    def process_execute(self, clauseelement, *multiparams, **params):
        """Observe a connection-level ``execute()`` call; no-op by default."""
        pass

    def process_cursor_execute(self, statement, parameters, context, executemany):
        """Observe a cursor-level execution; no-op by default."""
        pass

    def is_consumed(self):
        """Return True if this rule has been consumed, False if not.

        Should raise an AssertionError if this rule's condition has definitely failed.

        """
        raise NotImplementedError()

    def rule_passed(self):
        """Return True if the last test of this rule passed, False if failed, None if no test was applied."""

        raise NotImplementedError()

    def consume_final(self):
        """Return True if this rule has been consumed.

        Should raise an AssertionError if this rule's condition has not been consumed or has failed.

        :raises AssertionError: if no result has been recorded for this rule.

        """
        # ``_result`` is only initialised by subclasses (e.g. SQLMatchRule);
        # a rule that never set it is treated as "not consumed" rather than
        # crashing with AttributeError.  An explicit raise (instead of
        # ``assert False``) keeps the check active under ``python -O``.
        if getattr(self, '_result', None) is None:
            raise AssertionError("Rule has not been consumed")

        return self.is_consumed()
40
class SQLMatchRule(AssertRule):
    """Base for rules that compare a cursor execution against an expectation.

    Subclasses store the outcome of their last comparison in ``_result``
    (None until something has been compared) and a failure description in
    ``_errmsg``.

    """

    def __init__(self):
        self._errmsg = ""
        self._result = None

    def rule_passed(self):
        return self._result

    def is_consumed(self):
        # nothing compared yet: the rule is still waiting for a statement
        if self._result is not None:
            assert self._result, self._errmsg
            return True
        return False
56   
class ExactSQL(SQLMatchRule):
    """Match a cursor execution against an exact SQL string.

    ``sql`` is written with ``:name`` bind markers, which are rewritten to
    the dialect's paramstyle before being compared with the statement the
    cursor received.  When ``params`` is truthy (a dict, a list of dicts, or
    a callable taking the execution context and returning either), the
    context's compiled parameters must also equal it exactly.

    """

    def __init__(self, sql, params=None):
        SQLMatchRule.__init__(self)
        self.sql = sql
        self.params = params

    def process_cursor_execute(self, statement, parameters, context, executemany):
        if not context:
            return

        got_sql = _process_engine_statement(statement, context)
        got_params = context.compiled_parameters

        # TODO: remove this step once all unit tests are migrated, as
        # ExactSQL should really be *exact* SQL
        want_sql = _process_assertion_statement(self.sql, context)
        matched = got_sql == want_sql

        if not self.params:
            want_params = {}
        else:
            want_params = self.params
            if util.callable(want_params):
                want_params = want_params(context)
            if not isinstance(want_params, list):
                want_params = [want_params]
            matched = matched and want_params == context.compiled_parameters

        self._result = matched
        if not matched:
            self._errmsg = (
                "Testing for exact statement %r exact params %r, "
                "received %r with params %r"
                % (want_sql, want_params, got_sql, got_params)
            )
92   
93
class RegexSQL(SQLMatchRule):
    """Match a cursor execution whose statement matches a regular expression.

    The pattern is applied with ``re.match`` (anchored at the start only).
    When ``params`` is truthy (a dict, a list of dicts, or a callable taking
    the execution context), each expected dict is compared key-by-key with
    the received parameter dict at the same position; extra received keys
    are ignored.

    """

    def __init__(self, regex, params=None):
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        self.orig_regex = regex
        self.params = params

    def process_cursor_execute(self, statement, parameters, context, executemany):
        if not context:
            return

        got_sql = _process_engine_statement(statement, context)
        got_params = context.compiled_parameters

        matched = self.regex.match(got_sql) is not None

        if not self.params:
            want_params = {}
        else:
            want_params = self.params
            if util.callable(want_params):
                want_params = want_params(context)
            if not isinstance(want_params, list):
                want_params = [want_params]

            # do a positive compare only
            for wanted, got in zip(want_params, got_params):
                if any(key not in got or got[key] != value
                       for key, value in wanted.iteritems()):
                    matched = False

        self._result = matched
        if not matched:
            self._errmsg = (
                "Testing for regex %r partial params %r, "
                "received %r with params %r"
                % (self.orig_regex, want_params, got_sql, got_params)
            )
131
class CompiledSQL(SQLMatchRule):
    """Match a statement by recompiling its construct with ``DefaultDialect``.

    The construct behind the execution context is recompiled with
    ``DefaultDialect()`` (using the context's ``column_keys``), newlines are
    stripped, and the result must equal ``statement`` exactly.  When
    ``params`` is truthy, each expected dict is compared key-by-key with the
    received parameter dict at the same position; extra received keys are
    ignored.

    :param statement: the expected SQL string.
    :param params: optional dict, list of dicts, or callable taking the
        execution context and returning either.  Defaults to None (no
        parameter check), consistent with ExactSQL and RegexSQL.

    """

    def __init__(self, statement, params=None):
        SQLMatchRule.__init__(self)
        self.statement = statement
        self.params = params

    def process_cursor_execute(self, statement, parameters, context, executemany):
        if not context:
            return

        _received_parameters = context.compiled_parameters

        # recompile from the context, using the default dialect (presumably
        # so one expected string works for every backend -- confirm)
        compiled = context.compiled.statement.compile(
            dialect=DefaultDialect(), column_keys=context.compiled.column_keys)

        _received_statement = re.sub(r'\n', '', str(compiled))

        equivalent = self.statement == _received_statement
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params

            if not isinstance(params, list):
                params = [params]

            # do a positive compare only
            for param, received in zip(params, _received_parameters):
                for k, v in param.iteritems():
                    if k not in received or received[k] != v:
                        equivalent = False
                        break
        else:
            params = {}

        self._result = equivalent
        if not self._result:
            self._errmsg = "Testing for compiled statement %r partial params %r, " \
                    "received %r with params %r" % (self.statement, params, _received_statement, _received_parameters)
173   
174       
class CountStatements(AssertRule):
    """Assert how many connection-level ``execute()`` calls were observed.

    The rule never reports itself as consumed, so it remains at the head of
    the rule queue; the count is only checked in ``consume_final``.

    """

    def __init__(self, count):
        self.count = count
        self._statement_count = 0

    def process_execute(self, clauseelement, *multiparams, **params):
        # one tick per observed execute(), whatever its arguments
        self._statement_count = self._statement_count + 1

    def process_cursor_execute(self, statement, parameters, context, executemany):
        # cursor-level executions are not counted
        pass

    def is_consumed(self):
        return False

    def consume_final(self):
        wanted, seen = self.count, self._statement_count
        assert wanted == seen, "desired statement count %d does not match %d" % (wanted, seen)
        return True
192       
class AllOf(AssertRule):
    """Composite rule satisfied once every sub-rule has passed, in any order.

    Every observed execution is forwarded to all remaining sub-rules; after
    each statement one passing sub-rule is retired.

    """

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_execute(self, clauseelement, *multiparams, **params):
        for member in self.rules:
            member.process_execute(clauseelement, *multiparams, **params)

    def process_cursor_execute(self, statement, parameters, context, executemany):
        for member in self.rules:
            member.process_cursor_execute(statement, parameters, context, executemany)

    def is_consumed(self):
        if not self.rules:
            return True

        for candidate in tuple(self.rules):
            if not candidate.rule_passed():
                continue
            # a sub-rule matched this statement: retire it and report
            # whether anything is still outstanding
            self.rules.discard(candidate)
            return not self.rules

        assert False, "No assertion rules were satisfied for statement"

    def consume_final(self):
        return not self.rules
218       
219def _process_engine_statement(query, context):
220    if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'):
221        query = query[:-25]
222   
223    query = re.sub(r'\n', '', query)
224   
225    return query
226   
227def _process_assertion_statement(query, context):
228    paramstyle = context.dialect.paramstyle
229    if paramstyle == 'named':
230        pass
231    elif paramstyle =='pyformat':
232        query = re.sub(r':([\w_]+)', r"%(\1)s", query)
233    else:
234        # positional params
235        repl = None
236        if paramstyle=='qmark':
237            repl = "?"
238        elif paramstyle=='format':
239            repl = r"%s"
240        elif paramstyle=='numeric':
241            repl = None
242        query = re.sub(r':([\w_]+)', repl, query)
243
244    return query
245
class SQLAssert(ConnectionProxy):
    """Connection proxy that checks executed statements against queued rules.

    ``add_rules`` installs an ordered list of AssertRule objects.  Each
    ``execute()`` is offered to the rule at the head of the list, which is
    popped once it reports itself consumed.  Cursor-level executions are
    forwarded to the head rule as well.

    """

    # Class-level default; None means no rules are installed.
    rules = None

    def add_rules(self, rules):
        self.rules = list(rules)

    def statement_complete(self):
        """Check that every remaining rule is finally consumed.

        Does nothing if no rules were installed (previously this raised
        TypeError by iterating None).

        """
        for rule in self.rules or ():
            if not rule.consume_final():
                assert False, "All statements are complete, but pending assertion rules remain"

    def clear_rules(self):
        # Assign rather than ``del``: deleting the instance attribute raised
        # AttributeError when add_rules() had never been called.
        self.rules = None

    def execute(self, conn, execute, clauseelement, *multiparams, **params):
        result = execute(clauseelement, *multiparams, **params)

        if self.rules is not None:
            # rules were installed but all have been consumed already
            if not self.rules:
                assert False, "All rules have been exhausted, but further statements remain"
            rule = self.rules[0]
            rule.process_execute(clauseelement, *multiparams, **params)
            if rule.is_consumed():
                self.rules.pop(0)

        return result

    def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
        # NOTE(review): the wrapped callable is invoked without
        # ``executemany``; presumably this matches the ConnectionProxy
        # hook contract -- confirm.
        result = execute(cursor, statement, parameters, context)

        if self.rules:
            rule = self.rules[0]
            rule.process_cursor_execute(statement, parameters, context, executemany)

        return result
281
# Shared SQLAssert instance exposed by this module.
# NOTE(review): presumably installed as the ConnectionProxy on test engines
# by the testing harness -- confirm against callers.
asserter = SQLAssert()
283   
Note: リポジトリブラウザについてのヘルプは TracBrowser を参照してください。