1 | import operator |
---|
2 | from sqlalchemy.sql import operators, functions |
---|
3 | from sqlalchemy.sql import expression as sql |
---|
4 | |
---|
5 | |
---|
class UnevaluatableError(Exception):
    """Signal that a SQL expression has no in-Python evaluation.

    ``EvaluatorCompiler`` raises this when a clause type has no matching
    ``visit_*`` method, or when a clause uses an operator it cannot
    translate into Python.
    """
---|
8 | |
---|
# Binary operators that EvaluatorCompiler.visit_binary applies directly to the
# two Python operand values (after returning None if either operand is None).
# Each name is resolved on sqlalchemy.sql.operators.
_straight_ops = set(
    getattr(operators, op_name)
    for op_name in (
        'add', 'mul', 'sub', 'div', 'mod', 'truediv',
        'lt', 'le', 'ne', 'gt', 'ge', 'eq',
    )
)
---|
12 | |
---|
13 | |
---|
# Operators resolved from sqlalchemy.sql.operators that have no Python
# evaluation in this module.
# NOTE(review): not referenced by the code visible here; presumably consumed
# elsewhere to recognise known-unsupported operators -- confirm usage.
_notimplemented_ops = set(
    getattr(operators, op_name)
    for op_name in (
        'like_op', 'notlike_op', 'ilike_op', 'notilike_op',
        'between_op', 'in_op', 'notin_op',
        'endswith_op', 'concat_op',
    )
)
---|
18 | |
---|
class EvaluatorCompiler(object):
    """Compile a SQL expression clause into a plain Python callable.

    ``process(clause)`` returns a one-argument function ``evaluate(obj)``
    that computes the clause's value from the attributes of ``obj``.  SQL
    NULL is modelled as ``None``: comparison/arithmetic operators return
    ``None`` when either operand is ``None``, and AND / OR / NOT follow SQL
    three-valued logic.

    Clause types or operators with no Python equivalent raise
    ``UnevaluatableError`` when the clause is compiled.
    """

    def process(self, clause):
        """Dispatch ``clause`` to ``visit_<clause.__visit_name__>``.

        :raises UnevaluatableError: if no visitor exists for this clause type.
        """
        meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
        if not meth:
            raise UnevaluatableError("Cannot evaluate %s" % type(clause).__name__)
        return meth(clause)

    def visit_grouping(self, clause):
        """A parenthesised grouping evaluates to its wrapped element."""
        return self.process(clause.element)

    def visit_null(self, clause):
        """SQL NULL always evaluates to ``None``."""
        return lambda obj: None

    def visit_column(self, clause):
        """Read the column's value from the matching attribute of ``obj``.

        If the column carries a ``parentmapper`` annotation, the attribute
        name comes from the mapper's property for that column (presumably
        the mapped attribute name, which may differ from the column key).
        Otherwise ``clause.key`` is used.
        """
        if 'parentmapper' in clause._annotations:
            key = clause._annotations['parentmapper']._get_col_to_prop(clause).key
        else:
            key = clause.key
        # attrgetter is already a one-argument callable; no wrapper needed.
        return operator.attrgetter(key)

    def visit_clauselist(self, clause):
        """Compile an AND / OR clause list using SQL three-valued logic.

        OR: True if any operand is truthy; otherwise None if any operand is
        None; otherwise False.
        AND: False if any operand is falsy and not None; otherwise None if
        any operand is None; otherwise True.

        :raises UnevaluatableError: for any other clause-list operator.
        """
        # Materialise the sub-evaluators. On Python 3 ``map`` is a one-shot
        # iterator, which would be empty on every call after the first.
        evaluators = list(map(self.process, clause.clauses))
        if clause.operator is operators.or_:
            def evaluate(obj):
                has_null = False
                for sub_evaluate in evaluators:
                    value = sub_evaluate(obj)
                    if value:
                        return True
                    has_null = has_null or value is None
                if has_null:
                    return None
                return False
        elif clause.operator is operators.and_:
            def evaluate(obj):
                has_null = False
                for sub_evaluate in evaluators:
                    value = sub_evaluate(obj)
                    if value is None:
                        # Keep scanning: a later FALSE still decides the
                        # result (NULL AND FALSE is FALSE).
                        has_null = True
                    elif not value:
                        return False
                if has_null:
                    return None
                return True
        else:
            # Without this branch ``evaluate`` would be unbound below.
            raise UnevaluatableError(
                "Cannot evaluate %s with operator %s"
                % (type(clause).__name__, clause.operator))
        return evaluate

    def visit_binary(self, clause):
        """Compile a binary comparison or arithmetic expression.

        ``IS`` / ``IS NOT`` are evaluated with ``==`` / ``!=``.  Operators
        listed in ``_straight_ops`` are applied to the two operand values,
        with the result ``None`` if either operand is ``None``.

        :raises UnevaluatableError: for any other operator.
        """
        eval_left = self.process(clause.left)
        eval_right = self.process(clause.right)
        # Named ``op`` so it does not shadow the stdlib ``operator`` module
        # used by visit_column.
        op = clause.operator
        if op is operators.is_:
            def evaluate(obj):
                return eval_left(obj) == eval_right(obj)
        elif op is operators.isnot:
            def evaluate(obj):
                return eval_left(obj) != eval_right(obj)
        elif op in _straight_ops:
            def evaluate(obj):
                left_val = eval_left(obj)
                right_val = eval_right(obj)
                if left_val is None or right_val is None:
                    return None
                # Reuse the computed operands instead of evaluating each twice.
                return op(left_val, right_val)
        else:
            raise UnevaluatableError(
                "Cannot evaluate %s with operator %s"
                % (type(clause).__name__, clause.operator))
        return evaluate

    def visit_unary(self, clause):
        """Compile NOT (``operators.inv``); NOT NULL stays ``None``.

        :raises UnevaluatableError: for any other unary operator.
        """
        eval_inner = self.process(clause.element)
        if clause.operator is operators.inv:
            def evaluate(obj):
                value = eval_inner(obj)
                if value is None:
                    return None
                return not value
            return evaluate
        raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator))

    def visit_bindparam(self, clause):
        """A bound parameter evaluates to the value it held at compile time."""
        val = clause.value
        return lambda obj: val
---|