root/galaxy-central/eggs/SQLAlchemy-0.5.6_dev_r6498-py2.6.egg/sqlalchemy/test/engines.py @ 3

リビジョン 3, 7.8 KB (コミッタ: kohda, 14 年 前)

Install Unix tools  http://hannonlab.cshl.edu/galaxy_unix_tools/galaxy.html

行番号 
1import sys, types, weakref
2from collections import deque
3import config
4from sqlalchemy.util import function_named, callable
5
class ConnectionKiller(object):
    """Track checked-out connection proxies so tests can clean them up.

    ``checkout`` has the signature of a pool listener hook and records every
    connection proxy it is given.  Proxies are stored as keys of a
    ``WeakKeyDictionary``, so being tracked here never keeps a proxy alive.
    """

    def __init__(self):
        # proxy -> True; entries vanish once the proxy is garbage collected.
        self.proxy_refs = weakref.WeakKeyDictionary()

    def checkout(self, dbapi_con, con_record, con_proxy):
        """Checkout hook: remember *con_proxy* for later cleanup."""
        self.proxy_refs[con_proxy] = True

    def _apply_all(self, methods):
        """Apply each entry of *methods* to every tracked, still-valid proxy.

        An entry is either a callable, invoked as ``entry(proxy)``, or a
        method name, invoked as ``getattr(proxy, entry)()``; entries are
        applied in order.  If one raises, the remaining entries for that
        proxy are skipped, the error is written to stderr and processing
        continues with the next proxy.  SystemExit and KeyboardInterrupt
        propagate.
        """
        # Snapshot the keys first: rolling back / closing may release the
        # last strong reference to a proxy, and a WeakKeyDictionary must not
        # change size while it is being iterated.
        for rec in list(self.proxy_refs.keys()):
            if rec is None or not rec.is_valid:
                continue
            try:
                for name in methods:
                    if callable(name):
                        name(rec)
                    else:
                        getattr(rec, name)()
            except (SystemExit, KeyboardInterrupt):
                raise
            except Exception as e:
                # Best-effort cleanup: report and move on to the next proxy.
                sys.stderr.write("\n" + str(e) + "\n")

    def rollback_all(self):
        """Roll back every tracked, still-valid proxy."""
        self._apply_all(('rollback',))

    def close_all(self):
        """Roll back, then close, every tracked, still-valid proxy."""
        self._apply_all(('rollback', 'close'))

    def assert_all_closed(self):
        """Raise AssertionError if any tracked proxy still reports is_valid.

        Raised explicitly (not via ``assert``) so the check also runs
        under ``python -O``.
        """
        for rec in list(self.proxy_refs.keys()):
            if rec.is_valid:
                raise AssertionError(
                    "connection %r was not closed" % (rec,))
38
# Shared ConnectionKiller for this module: testing_engine() appends it to
# each engine's pool listeners, and the decorators below use it to roll back,
# close, or check the connections those engines have checked out.
testing_reaper = ConnectionKiller()
40
def assert_conns_closed(fn):
    """Decorator that asserts all tracked connections are closed after fn.

    The check runs in a ``finally`` clause, so it happens even when *fn*
    raises.  *fn*'s return value is passed through to the caller.
    """

    def decorated(*args, **kw):
        try:
            return fn(*args, **kw)
        finally:
            testing_reaper.assert_all_closed()
    return function_named(decorated, fn.__name__)
48
def rollback_open_connections(fn):
    """Decorator that rolls back all open connections after fn execution.

    The rollback runs in a ``finally`` clause, so it happens even when *fn*
    raises.  *fn*'s return value is passed through to the caller.
    """

    def decorated(*args, **kw):
        try:
            return fn(*args, **kw)
        finally:
            testing_reaper.rollback_all()
    return function_named(decorated, fn.__name__)
58
def close_open_connections(fn):
    """Decorator that closes all connections after fn execution.

    Cleanup (rollback, then close) runs in a ``finally`` clause, so it
    happens even when *fn* raises.  *fn*'s return value is passed through
    to the caller.
    """

    def decorated(*args, **kw):
        try:
            return fn(*args, **kw)
        finally:
            testing_reaper.close_all()
    return function_named(decorated, fn.__name__)
68
def all_dialects():
    """Yield one dialect instance per module named in databases.__all__.

    Each ``sqlalchemy.databases.<name>`` submodule is imported on demand and
    its ``dialect`` class is instantiated with no arguments.
    """
    import sqlalchemy.databases as d
    for dialect_name in d.__all__:
        # Importing the dotted path makes the submodule an attribute of d.
        __import__('sqlalchemy.databases.%s' % dialect_name)
        dialect_module = getattr(d, dialect_name)
        yield dialect_module.dialect()
74       
class ReconnectFixture(object):
    """Wrapper around a DB-API module that remembers every connection opened.

    Attributes not defined here are looked up on the wrapped module, so the
    fixture can stand in for it (reconnecting_engine passes it as the
    ``module`` option).  ``shutdown()`` closes every connection opened so
    far; reconnecting_engine exposes it as ``engine.test_shutdown``.
    """

    def __init__(self, dbapi):
        self.dbapi = dbapi
        self.connections = []

    def __getattr__(self, key):
        # Only reached when normal lookup fails.  Guard 'dbapi' so an
        # instance whose __init__ never ran (built via __new__, copy, pickle)
        # raises AttributeError instead of recursing forever.
        if key == 'dbapi':
            raise AttributeError(key)
        return getattr(self.dbapi, key)

    def connect(self, *args, **kwargs):
        """Open a connection through the wrapped DB-API and track it."""
        conn = self.dbapi.connect(*args, **kwargs)
        self.connections.append(conn)
        return conn

    def shutdown(self):
        """Close every tracked connection, then forget all of them.

        A close() that raises is reported on stderr and does not prevent
        the remaining connections from being closed.
        """
        for c in list(self.connections):
            try:
                c.close()
            except Exception as e:
                sys.stderr.write(
                    "\nReconnectFixture couldn't close connection: %s\n" % e)
        self.connections = []
92
def reconnecting_engine(url=None, options=None):
    """Return a testing engine whose DB-API connections can be force-closed.

    The configured DB-API module (``config.db.dialect.dbapi``) is wrapped in
    a ReconnectFixture and passed as the ``module`` option.  The engine gets
    a ``test_shutdown`` attribute bound to the fixture's ``shutdown``, which
    closes every connection opened through it.

    :param url: database URL; defaults to ``config.db_url``.
    :param options: extra ``create_engine`` keyword arguments.  The dict is
        copied before ``module`` is added, so the caller's dict is left
        untouched.
    """
    url = url or config.db_url
    dbapi = config.db.dialect.dbapi
    options = dict(options or {})
    options['module'] = ReconnectFixture(dbapi)
    engine = testing_engine(url, options)
    engine.test_shutdown = engine.dialect.dbapi.shutdown
    return engine
102
def testing_engine(url=None, options=None):
    """Produce an engine configured by --options with optional overrides.

    :param url: database URL; defaults to ``config.db_url``.
    :param options: ``create_engine`` keyword arguments; when empty or None,
        ``config.db_opts`` is used instead.

    ``proxy`` defaults to the SQL asserter, and ``testing_reaper`` is added
    to the ``listeners`` list.  Both are applied to copies, so neither the
    caller's dict nor the shared ``config.db_opts`` is modified.
    """

    from sqlalchemy import create_engine
    from sqlalchemy.test.assertsql import asserter

    url = url or config.db_url
    # Copy the options and the listeners list: mutating them in place would
    # append another testing_reaper to a shared list on every call.
    options = dict(options or config.db_opts)

    options.setdefault('proxy', asserter)

    listeners = list(options.get('listeners') or [])
    listeners.append(testing_reaper)
    options['listeners'] = listeners

    engine = create_engine(url, **options)

    return engine
120
def utf8_engine(url=None, options=None):
    """Hook for dialects or drivers that don't handle utf8 by default.

    On MySQL, the URL (default ``config.db_url``) gets ``charset=utf8`` and
    ``use_unicode=0`` query parameters.  A RuntimeError is raised when the
    driver's ``version_info`` is below (1, 2, 1) or is one of the listed
    1.2.1 gamma releases.  On other backends, *url* and *options* are handed
    to testing_engine unchanged.
    """

    from sqlalchemy.engine import url as engine_url

    if config.db.name != 'mysql':
        return testing_engine(url, options)

    dbapi_ver = config.db.dialect.dbapi.version_info
    # 1.2.1 gamma releases treated as lacking charset support.
    unsupported_gammas = tuple((1, 2, 1, 'gamma', n) for n in (1, 2, 3, 5))
    if dbapi_ver < (1, 2, 1) or dbapi_ver in unsupported_gammas:
        raise RuntimeError('Character set support unavailable with this '
                           'driver version: %s' % repr(dbapi_ver))

    parsed = engine_url.make_url(url or config.db_url)
    parsed.query['charset'] = 'utf8'
    parsed.query['use_unicode'] = '0'
    return testing_engine(str(parsed), options)
141
def mock_engine(db=None):
    """Provides a mocking engine based on the current testing.db.

    The engine uses the 'mock' strategy for the dialect named by *db*
    (default ``config.db``).  Every statement given to its executor is
    appended to a list, which is exposed as ``engine.mock``.
    """

    from sqlalchemy import create_engine

    if db:
        dbi = db
    else:
        dbi = config.db

    statements = []

    def executor(sql, *a, **kw):
        statements.append(sql)

    engine = create_engine(dbi.name + '://',
                           strategy='mock', executor=executor)
    # Make sure we are not overwriting an existing attribute.
    assert not hasattr(engine, 'mock')
    engine.mock = statements
    return engine
156
class ReplayableSession(object):
    """A simple record/playback tool.

    This is *not* a mock testing class.  It only records a session for later
    playback and makes no assertions on call consistency whatsoever.  It's
    unlikely to be suitable for anything other than DB-API recording.

    ``recorder(base)`` returns a Recorder proxy around *base*.  The proxy
    appends every call result and attribute-lookup result to a deque shared
    with this session.  ``player()`` returns a Player that pops those values
    back off the same deque in FIFO order.

    """

    # Marker recorded in place of a non-native result (the Recorder wraps
    # such results in a further Recorder); the Player returns itself for it.
    Callable = object()
    # Marker recorded when looking up an attribute on the subject raised
    # AttributeError; the Player re-raises AttributeError for it.
    NoAttribute = object()
    # Objects named by every public attribute of the `types` module, minus
    # the function/method types.  Results whose type is in this set are
    # recorded by value; anything else is recorded as `Callable`.
    # NOTE(review): 'UnboundMethodType' appears to exist only on Python 2's
    # `types` module, so this class body would raise AttributeError on
    # Python 3 -- confirm before porting.
    Natives = set([getattr(types, t)
                   for t in dir(types) if not t.startswith('_')]). \
                   difference([getattr(types, t)
                               for t in ('FunctionType', 'BuiltinFunctionType',
                                         'MethodType', 'BuiltinMethodType',
                                         'LambdaType', 'UnboundMethodType',)])
    def __init__(self):
        # Shared FIFO of recorded values, consumed by the Player.
        self.buffer = deque()

    def recorder(self, base):
        """Return a Recorder proxying *base* that writes into this session."""
        return self.Recorder(self.buffer, base)

    def player(self):
        """Return a Player that replays this session's recorded values."""
        return self.Player(self.buffer)

    class Recorder(object):
        """Proxy that forwards calls/lookups to a subject and records results.

        Lookups that the proxy itself can satisfy (its own attributes and
        methods, and anything inherited from ``object``) are served
        directly and are not recorded.
        """

        def __init__(self, buffer, subject):
            self._buffer = buffer
            self._subject = subject

        def __call__(self, *args, **kw):
            # object.__getattribute__ bypasses our own __getattribute__
            # override, so reading internals is never recorded.
            subject, buffer = [object.__getattribute__(self, x)
                               for x in ('_subject', '_buffer')]

            result = subject(*args, **kw)
            if type(result) not in ReplayableSession.Natives:
                # Non-native result: record a marker and keep proxying.
                buffer.append(ReplayableSession.Callable)
                return type(self)(buffer, result)
            else:
                buffer.append(result)
                return result

        def __getattribute__(self, key):
            # Attributes of the proxy itself win and are not recorded.
            try:
                return object.__getattribute__(self, key)
            except AttributeError:
                pass

            subject, buffer = [object.__getattribute__(self, x)
                               for x in ('_subject', '_buffer')]
            try:
                # NOTE(review): calling __getattribute__ directly means a
                # __getattr__ fallback on the subject's class is not used.
                result = type(subject).__getattribute__(subject, key)
            except AttributeError:
                # Record the miss so playback raises at the same point.
                buffer.append(ReplayableSession.NoAttribute)
                raise
            else:
                if type(result) not in ReplayableSession.Natives:
                    buffer.append(ReplayableSession.Callable)
                    return type(self)(buffer, result)
                else:
                    buffer.append(result)
                    return result

    class Player(object):
        """Replays a recorded buffer, one value per call or attribute lookup.

        A single Player object stands in for every proxy the Recorder
        created: the `Callable` marker makes it return itself.
        """

        def __init__(self, buffer):
            self._buffer = buffer

        def __call__(self, *args, **kw):
            # Arguments are ignored; the next recorded value is returned.
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                return self
            else:
                return result

        def __getattribute__(self, key):
            # The Player's own attributes are served without consuming
            # the buffer.
            try:
                return object.__getattribute__(self, key)
            except AttributeError:
                pass
            buffer = object.__getattribute__(self, '_buffer')
            result = buffer.popleft()
            if result is ReplayableSession.Callable:
                return self
            elif result is ReplayableSession.NoAttribute:
                raise AttributeError(key)
            else:
                return result
Note: リポジトリブラウザについてのヘルプは TracBrowser を参照してください。