1 | """Provides a thread-local transactional wrapper around the root Engine class. |
---|
2 | |
---|
3 | The ``threadlocal`` module is invoked when using the ``strategy="threadlocal"`` flag |
---|
4 | with :func:`~sqlalchemy.engine.create_engine`. This module is semi-private and is |
---|
5 | invoked automatically when the threadlocal engine strategy is used. |
---|
6 | """ |
---|
7 | |
---|
8 | from sqlalchemy import util |
---|
9 | from sqlalchemy.engine import base |
---|
10 | |
---|
class TLSession(object):
    """Per-thread transaction state for a TLEngine.

    Tracks the connection that carries the current thread's transaction (if
    any) and how many ``begin()`` calls are outstanding.  Only the outermost
    ``begin()`` starts a real transaction; nested ``begin()``/``commit()``
    pairs only move the counter, while ``rollback()`` always rolls back and
    resets everything.
    """

    def __init__(self, engine):
        # The owning engine; get_connection() uses its ``pool`` and
        # ``TLConnection`` attributes.
        self.engine = engine
        # begin() nesting depth; 0 means no transaction in progress.  The
        # private attributes ``__transaction`` (connection carrying the
        # transaction) and ``__trans`` (object returned by begin()) are
        # created by begin()/begin_twophase() and deleted by reset(); their
        # absence means "no transaction connection".
        self.__tcount = 0

    def get_connection(self, close_with_result=False):
        """Return a connection for the current thread.

        If this session holds a transaction connection, that connection is
        shared: its open count is incremented and it is returned as-is
        (``close_with_result`` is then ignored).  Otherwise a new
        ``engine.TLConnection`` is built around a connection checked out
        from ``engine.pool``.
        """
        # Only the attribute lookup is guarded, so an AttributeError raised
        # inside _increment_connect() is not mistaken for "no connection".
        try:
            conn = self.__transaction
        except AttributeError:
            return self.engine.TLConnection(
                self, self.engine.pool.connect(),
                close_with_result=close_with_result)
        return conn._increment_connect()

    def reset(self):
        """Force-close the transaction connection (if any) and clear state."""
        try:
            self.__transaction._force_close()
            del self.__transaction
            del self.__trans
        except AttributeError:
            # Tolerate missing state: no transaction connection at all, or
            # begin() failed after assigning ``__transaction`` but before
            # ``__trans`` was assigned.
            pass
        self.__tcount = 0

    def _conn_closed(self):
        """Hook called by TLConnection.close() when its last reference closes.

        With exactly one begin() level open, the transaction is rolled back
        and the session is reset.  reset() runs in a ``finally`` (as in
        rollback() and commit()), so a failing rollback cannot leave the
        session claiming a transaction on an already-closed connection.
        """
        if self.__tcount == 1:
            try:
                self.__trans._trans.rollback()
            finally:
                self.reset()

    def in_transaction(self):
        """Return True while at least one begin() is outstanding."""
        return self.__tcount > 0

    def prepare(self):
        """Call prepare() on the underlying transaction at the outermost level."""
        if self.__tcount == 1:
            self.__trans._trans.prepare()

    def begin_twophase(self, xid=None):
        """Start a two-phase transaction, or nest into the current one.

        ``xid`` is only used when no transaction is in progress; nested calls
        return the existing transaction object.
        """
        if self.__tcount == 0:
            self.__transaction = self.get_connection()
            self.__trans = self.__transaction._begin_twophase(xid=xid)
        self.__tcount += 1
        return self.__trans

    def begin(self, **kwargs):
        """Start a transaction, or nest into the current one.

        ``kwargs`` are only forwarded when no transaction is in progress;
        nested calls return the existing transaction object.
        """
        if self.__tcount == 0:
            self.__transaction = self.get_connection()
            self.__trans = self.__transaction._begin(**kwargs)
        self.__tcount += 1
        return self.__trans

    def rollback(self):
        """Roll back the whole transaction (regardless of nesting) and reset."""
        if self.__tcount > 0:
            try:
                self.__trans._trans.rollback()
            finally:
                self.reset()

    def commit(self):
        """Commit at the outermost level; otherwise just unwind one level."""
        if self.__tcount == 1:
            try:
                self.__trans._trans.commit()
            finally:
                self.reset()
        elif self.__tcount > 1:
            self.__tcount -= 1

    def close(self):
        """Roll back at the outermost level; otherwise just unwind one level."""
        if self.__tcount == 1:
            self.rollback()
        elif self.__tcount > 1:
            self.__tcount -= 1

    def is_begun(self):
        """Return True while at least one begin() is outstanding."""
        return self.__tcount > 0
---|
81 | |
---|
82 | |
---|
class TLConnection(base.Connection):
    """A Connection owned by a TLSession and shared by open-count.

    The session hands the same instance out repeatedly while a transaction
    is in progress; the connection is only really closed when the last
    holder calls ``close()``.  Public ``begin*()`` calls are routed through
    the session so that transactions are thread-scoped.
    """

    def __init__(self, session, connection, **kwargs):
        base.Connection.__init__(self, session.engine, connection, **kwargs)
        self.__session = session
        # Number of close() calls still expected before the connection is
        # actually released.
        self.__opencount = 1

    def _branch(self):
        # A branch is a plain engine Connection over the same pooled
        # connection; it is not a TLConnection and keeps no open-count.
        engine = self.engine
        return engine.Connection(engine, self.connection, _branch=True)

    @property
    def session(self):
        """The TLSession this connection belongs to."""
        return self.__session

    def _increment_connect(self):
        # Another holder now shares this connection.
        self.__opencount = self.__opencount + 1
        return self

    def _begin(self, **kwargs):
        # Bypass this class's begin() (which defers to the session) and start
        # the real transaction on base.Connection, wrapped for the session.
        real_trans = super(TLConnection, self).begin(**kwargs)
        return TLTransaction(real_trans, self.__session)

    def _begin_twophase(self, xid=None):
        real_trans = super(TLConnection, self).begin_twophase(xid=xid)
        return TLTransaction(real_trans, self.__session)

    def in_transaction(self):
        """Delegate to the owning session's transaction state."""
        owner = self.session
        return owner.in_transaction()

    def begin(self, **kwargs):
        """Begin (or nest into) the session-wide transaction."""
        owner = self.session
        return owner.begin(**kwargs)

    def begin_twophase(self, xid=None):
        """Begin (or nest into) the session-wide two-phase transaction."""
        owner = self.session
        return owner.begin_twophase(xid=xid)

    def begin_nested(self):
        """SAVEPOINTs are not supported by the threadlocal strategy."""
        raise NotImplementedError(
            "SAVEPOINT transactions with the 'threadlocal' strategy")

    def close(self):
        """Release one holder; really close when the last one releases."""
        is_last_holder = self.__opencount == 1
        if is_last_holder:
            base.Connection.close(self)
            # Let the session roll back / reset a single-level transaction.
            self.__session._conn_closed()
        self.__opencount -= 1

    def _force_close(self):
        # Drop every outstanding holder and close unconditionally.
        self.__opencount = 0
        base.Connection.close(self)
---|
129 | |
---|
130 | |
---|
class TLTransaction(base.Transaction):
    """Session-aware wrapper around a real Transaction.

    ``commit()``, ``rollback()``, ``prepare()`` and ``close()`` are delegated
    to the owning TLSession so that begin() nesting is honored;
    ``connection`` and ``is_active`` are read from the wrapped transaction.
    """

    def __init__(self, trans, session):
        # trans: the transaction returned by base.Connection.begin() or
        # begin_twophase() (see TLConnection._begin / _begin_twophase).
        self._trans = trans
        self._session = session

    def connection(self):
        return self._trans.connection
    connection = property(connection)

    def is_active(self):
        return self._trans.is_active
    is_active = property(is_active)

    def rollback(self):
        self._session.rollback()

    def prepare(self):
        self._session.prepare()

    def commit(self):
        self._session.commit()

    def close(self):
        self._session.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        """Commit on clean exit, roll back on error or inactive transaction.

        Goes through this wrapper's commit()/rollback(), and thus through the
        session.  Delegating to the wrapped transaction's ``__exit__`` would
        commit or roll back the real transaction directly, leaving the
        session's nesting count and checked-out connection out of sync.
        Returns None, so any exception raised in the block propagates.
        """
        if type is None and self.is_active:
            self.commit()
        else:
            self.rollback()
---|
161 | |
---|
162 | |
---|
class TLEngine(base.Engine):
    """Engine variant whose transactions are tracked per thread.

    Each thread gets its own TLSession, kept on a thread-local ``context``.
    ``begin``/``commit``/``rollback``/``prepare`` on the engine act on that
    session.  The engine depends on its Pool being "threadlocal", so that
    repeated checkouts from the same thread yield the same connection.
    """

    def __init__(self, *args, **kwargs):
        """Construct a new TLEngine."""
        super(TLEngine, self).__init__(*args, **kwargs)
        self.context = util.threading.local()

        # When a connection proxy is configured, wrap TLConnection so every
        # connection this engine hands out goes through it.
        proxy = kwargs.get('proxy')
        if proxy:
            connection_cls = base._proxy_connection_cls(TLConnection, proxy)
        else:
            connection_cls = TLConnection
        self.TLConnection = connection_cls

    @property
    def session(self):
        """The calling thread's TLSession, created on first access."""
        ctx = self.context
        if not hasattr(ctx, 'session'):
            ctx.session = TLSession(self)
        return ctx.session

    def contextual_connect(self, **kwargs):
        """Return a TLConnection which is thread-locally scoped."""
        current = self.session
        return current.get_connection(**kwargs)

    def begin_twophase(self, **kwargs):
        """Begin (or nest into) this thread's two-phase transaction."""
        current = self.session
        return current.begin_twophase(**kwargs)

    def begin_nested(self):
        """SAVEPOINTs are not supported by the threadlocal strategy."""
        raise NotImplementedError(
            "SAVEPOINT transactions with the 'threadlocal' strategy")

    def begin(self, **kwargs):
        """Begin (or nest into) this thread's transaction."""
        current = self.session
        return current.begin(**kwargs)

    def prepare(self):
        """Prepare this thread's transaction."""
        current = self.session
        current.prepare()

    def commit(self):
        """Commit (or unwind one nesting level of) this thread's transaction."""
        current = self.session
        current.commit()

    def rollback(self):
        """Roll back this thread's transaction."""
        current = self.session
        current.rollback()

    def __repr__(self):
        return 'TLEngine(%s)' % (str(self.url),)
---|