[3] | 1 | from sqlalchemy import types as sqltypes |
---|
| 2 | from sqlalchemy.sql.expression import ( |
---|
| 3 | ClauseList, Function, _literal_as_binds, text |
---|
| 4 | ) |
---|
| 5 | from sqlalchemy.sql import operators |
---|
| 6 | from sqlalchemy.sql.visitors import VisitableType |
---|
| 7 | |
---|
class _GenericMeta(VisitableType):
    """Metaclass that coerces positional constructor arguments to binds."""

    def __call__(self, *args, **kwargs):
        # Every positional argument goes through _literal_as_binds before
        # the class is instantiated; keyword arguments are forwarded as-is.
        coerced = [_literal_as_binds(arg) for arg in args]
        return type.__call__(self, *coerced, **kwargs)
---|
| 12 | |
---|
class GenericFunction(Function):
    """Base class for SQL functions named after their Python class.

    The instance ``name`` is the class name, and the type comes from the
    ``type_`` argument or, failing that, from a ``__return_type__`` class
    attribute when the class defines one.
    """

    # Python 2 style metaclass declaration.
    __metaclass__ = _GenericMeta

    def __init__(self, type_=None, args=(), **kwargs):
        """Build the argument clause list and resolve the function's type.

        :param type_: explicit return type; when falsy, ``__return_type__``
          is used if the class defines it.
        :param args: sequence of clause elements forming the argument list.
        :param kwargs: only ``bind`` is read here; it is stored on ``_bind``.
        """
        self.packagenames = []
        self.name = self.__class__.__name__
        self._bind = kwargs.get('bind')

        # Arguments are comma-joined, then grouped as a unit.
        argument_list = ClauseList(
            operator=operators.comma_op, group_contents=True, *args)
        self.clause_expr = argument_list.self_group()

        return_type = type_ or getattr(self, '__return_type__', None)
        self.type = sqltypes.to_instance(return_type)
---|
| 25 | |
---|
class AnsiFunction(GenericFunction):
    """A GenericFunction whose constructor takes keyword arguments only."""

    def __init__(self, **kwargs):
        # No positional arguments are accepted; the keywords are forwarded
        # unchanged to the base initializer.
        GenericFunction.__init__(self, **kwargs)
---|
| 29 | |
---|
class ReturnTypeFromArgs(GenericFunction):
    """Define a function whose return type is the same as its arguments."""

    def __init__(self, *args, **kwargs):
        # The type is always inferred from the positional arguments, but it
        # is only applied when the caller did not pass ``type_`` explicitly.
        inferred_type = _type_from_args(args)
        kwargs.setdefault('type_', inferred_type)
        GenericFunction.__init__(self, args=args, **kwargs)
---|
| 36 | |
---|
class coalesce(ReturnTypeFromArgs):
    """The ``coalesce`` function; its type is taken from its arguments."""
---|
| 39 | |
---|
class max(ReturnTypeFromArgs):
    """The ``max`` function; its type is taken from its arguments."""
---|
| 42 | |
---|
class min(ReturnTypeFromArgs):
    """The ``min`` function; its type is taken from its arguments."""
---|
| 45 | |
---|
class sum(ReturnTypeFromArgs):
    """The ``sum`` function; its type is taken from its arguments."""
---|
| 48 | |
---|
class now(GenericFunction):
    """The ``now`` function; typed as DateTime unless ``type_`` is given."""

    __return_type__ = sqltypes.DateTime
---|
| 51 | |
---|
class concat(GenericFunction):
    """The ``concat`` function; typed as String unless ``type_`` is given."""

    __return_type__ = sqltypes.String

    def __init__(self, *args, **kwargs):
        # Any number of positional arguments becomes the argument list.
        GenericFunction.__init__(self, args=args, **kwargs)
---|
| 56 | |
---|
class char_length(GenericFunction):
    """The ``char_length`` function of exactly one argument.

    Typed as Integer unless ``type_`` is given.
    """

    __return_type__ = sqltypes.Integer

    def __init__(self, arg, **kwargs):
        # Wrap the single argument so the base class sees a sequence.
        single_argument = (arg,)
        GenericFunction.__init__(self, args=single_argument, **kwargs)
---|
| 62 | |
---|
class random(GenericFunction):
    """The ``random`` function; no ``__return_type__`` is declared."""

    def __init__(self, *args, **kwargs):
        # Ensure ``type_`` is always present, defaulting to None.
        if 'type_' not in kwargs:
            kwargs['type_'] = None
        GenericFunction.__init__(self, args=args, **kwargs)
---|
| 67 | |
---|
class count(GenericFunction):
    r"""The ANSI COUNT aggregate function. With no arguments, emits COUNT \*.

    Typed as Integer unless ``type_`` is given.
    """
    # The docstring is a raw string so that the backslash-star sequence is
    # not parsed as an (invalid) string escape.

    __return_type__ = sqltypes.Integer

    def __init__(self, expression=None, **kwargs):
        """Create the function over ``expression``.

        :param expression: the single argument; when None, the literal
          text ``*`` is used instead.
        :param kwargs: forwarded to ``GenericFunction.__init__``.
        """
        if expression is None:
            expression = text('*')
        GenericFunction.__init__(self, args=(expression,), **kwargs)
---|
| 77 | |
---|
class current_date(AnsiFunction):
    """The ``current_date`` function; typed as Date by default."""

    __return_type__ = sqltypes.Date
---|
| 80 | |
---|
class current_time(AnsiFunction):
    """The ``current_time`` function; typed as Time by default."""

    __return_type__ = sqltypes.Time
---|
| 83 | |
---|
class current_timestamp(AnsiFunction):
    """The ``current_timestamp`` function; typed as DateTime by default."""

    __return_type__ = sqltypes.DateTime
---|
| 86 | |
---|
class current_user(AnsiFunction):
    """The ``current_user`` function; typed as String by default."""

    __return_type__ = sqltypes.String
---|
| 89 | |
---|
class localtime(AnsiFunction):
    """The ``localtime`` function; typed as DateTime by default."""

    __return_type__ = sqltypes.DateTime
---|
| 92 | |
---|
class localtimestamp(AnsiFunction):
    """The ``localtimestamp`` function; typed as DateTime by default."""

    __return_type__ = sqltypes.DateTime
---|
| 95 | |
---|
class session_user(AnsiFunction):
    """The ``session_user`` function; typed as String by default."""

    __return_type__ = sqltypes.String
---|
| 98 | |
---|
class sysdate(AnsiFunction):
    """The ``sysdate`` function; typed as DateTime by default."""

    __return_type__ = sqltypes.DateTime
---|
| 101 | |
---|
class user(AnsiFunction):
    """The ``user`` function; typed as String by default."""

    __return_type__ = sqltypes.String
---|
| 104 | |
---|
def _type_from_args(args):
    """Return the type of the first argument whose type is not a NullType.

    When ``args`` is empty, or every argument's ``type`` is a NullType
    instance, the ``NullType`` class itself is returned.
    """
    for arg in args:
        if isinstance(arg.type, sqltypes.NullType):
            continue
        return arg.type
    return sqltypes.NullType
---|