| 1 | from sqlalchemy import types as sqltypes |
|---|
| 2 | from sqlalchemy.sql.expression import ( |
|---|
| 3 | ClauseList, Function, _literal_as_binds, text |
|---|
| 4 | ) |
|---|
| 5 | from sqlalchemy.sql import operators |
|---|
| 6 | from sqlalchemy.sql.visitors import VisitableType |
|---|
| 7 | |
|---|
class _GenericMeta(VisitableType):
    """Metaclass for GenericFunction that coerces positional constructor args.

    When a class using this metaclass is instantiated, every positional
    argument is passed through ``_literal_as_binds`` before ``__init__``
    receives it. Keyword arguments are forwarded unchanged.
    """

    def __call__(self, *args, **kwargs):
        args = [_literal_as_binds(c) for c in args]
        # Delegate through the metaclass MRO instead of hard-coding
        # type.__call__, so a __call__ defined on VisitableType (or any
        # other base metaclass) is not silently bypassed.
        return super(_GenericMeta, self).__call__(*args, **kwargs)
|---|
| 12 | |
|---|
class GenericFunction(Function):
    """Base class for SQL functions whose name is their Python class name.

    :param type_: explicit result type; when falsy, the class-level
        ``__return_type__`` (if defined) is used instead. Either way the
        value is normalized with ``sqltypes.to_instance``.
    :param args: the function's arguments, joined into a comma-separated
        ``ClauseList``.
    :param kwargs: only ``bind`` is read here; other keys are ignored.
    """

    __metaclass__ = _GenericMeta

    def __init__(self, type_=None, args=(), **kwargs):
        # Function.__init__ is not invoked; the attributes are set directly.
        self.packagenames = []
        self.name = type(self).__name__
        self._bind = kwargs.get('bind')
        arg_list = ClauseList(operator=operators.comma_op,
                              group_contents=True, *args)
        self.clause_expr = arg_list.self_group()
        # An explicit (truthy) type_ wins; otherwise fall back to the
        # subclass's declared __return_type__, or None if it has none.
        if type_:
            chosen = type_
        else:
            chosen = getattr(self, '__return_type__', None)
        self.type = sqltypes.to_instance(chosen)
|---|
| 25 | |
|---|
class AnsiFunction(GenericFunction):
    """GenericFunction subclass whose constructor takes keyword arguments only.

    Passing any positional argument raises TypeError; keyword options such
    as ``type_`` or ``bind`` are forwarded to GenericFunction unchanged.
    """

    def __init__(self, **kw):
        GenericFunction.__init__(self, **kw)
|---|
| 29 | |
|---|
class ReturnTypeFromArgs(GenericFunction):
    """Define a function whose return type is taken from its arguments.

    Unless ``type_`` is passed explicitly, the result type is the type of
    the first argument that is not NullType-typed, as chosen by
    ``_type_from_args`` (NullType when no argument has a concrete type).
    """

    def __init__(self, *args, **kwargs):
        # Derive the type only when the caller did not supply one;
        # dict.setdefault() would evaluate _type_from_args eagerly and then
        # discard the result whenever type_ was given.
        if 'type_' not in kwargs:
            kwargs['type_'] = _type_from_args(args)
        GenericFunction.__init__(self, args=args, **kwargs)
|---|
| 36 | |
|---|
class coalesce(ReturnTypeFromArgs):
    """The ``coalesce`` SQL function; its result type follows its arguments."""
|---|
| 39 | |
|---|
class max(ReturnTypeFromArgs):
    """The ``max`` SQL function; its result type follows its arguments."""
    # NOTE: this module-level class shadows the builtin max() in this module.
|---|
| 42 | |
|---|
class min(ReturnTypeFromArgs):
    """The ``min`` SQL function; its result type follows its arguments."""
    # NOTE: this module-level class shadows the builtin min() in this module.
|---|
| 45 | |
|---|
class sum(ReturnTypeFromArgs):
    """The ``sum`` SQL function; its result type follows its arguments."""
    # NOTE: this module-level class shadows the builtin sum() in this module.
|---|
| 48 | |
|---|
class now(GenericFunction):
    """The ``now`` SQL function, typed as DateTime by default."""

    __return_type__ = sqltypes.DateTime
|---|
| 51 | |
|---|
class concat(GenericFunction):
    """The ``concat`` SQL function over any number of arguments; typed as String."""

    __return_type__ = sqltypes.String

    def __init__(self, *clauses, **kwargs):
        # Every positional argument becomes an argument of the SQL function.
        GenericFunction.__init__(self, args=clauses, **kwargs)
|---|
| 56 | |
|---|
class char_length(GenericFunction):
    """The ``char_length`` SQL function of exactly one argument; typed as Integer."""

    __return_type__ = sqltypes.Integer

    def __init__(self, arg, **kwargs):
        single = (arg,)
        GenericFunction.__init__(self, args=single, **kwargs)
|---|
| 62 | |
|---|
class random(GenericFunction):
    """The ``random`` SQL function, accepting any positional arguments.

    No ``__return_type__`` is declared, so unless ``type_`` is passed the
    result type is ``sqltypes.to_instance(None)``.
    """

    def __init__(self, *args, **kwargs):
        # type_ already defaults to None in GenericFunction.__init__, so a
        # kwargs.setdefault('type_', None) here would be a no-op.
        GenericFunction.__init__(self, args=args, **kwargs)
|---|
| 67 | |
|---|
class count(GenericFunction):
    r"""The ANSI COUNT aggregate function. With no arguments, emits COUNT \*."""
    # The docstring is raw so the reST escape "\*" is not parsed as an
    # (invalid) Python string escape sequence; __doc__ text is unchanged.

    __return_type__ = sqltypes.Integer

    def __init__(self, expression=None, **kwargs):
        # With no expression, the single argument is the literal text '*'.
        # NOTE(review): positional arguments are first passed through
        # _literal_as_binds by the metaclass, so this branch presumably only
        # triggers when expression is omitted or passed by keyword — confirm.
        if expression is None:
            expression = text('*')
        GenericFunction.__init__(self, args=(expression,), **kwargs)
|---|
| 77 | |
|---|
class current_date(AnsiFunction):
    """The ``current_date`` function, typed as Date."""

    __return_type__ = sqltypes.Date
|---|
| 80 | |
|---|
class current_time(AnsiFunction):
    """The ``current_time`` function, typed as Time."""

    __return_type__ = sqltypes.Time
|---|
| 83 | |
|---|
class current_timestamp(AnsiFunction):
    """The ``current_timestamp`` function, typed as DateTime."""

    __return_type__ = sqltypes.DateTime
|---|
| 86 | |
|---|
class current_user(AnsiFunction):
    """The ``current_user`` function, typed as String."""

    __return_type__ = sqltypes.String
|---|
| 89 | |
|---|
class localtime(AnsiFunction):
    """The ``localtime`` function, typed as DateTime."""

    # NOTE(review): standard SQL LOCALTIME yields a TIME value; DateTime is
    # kept unchanged here — confirm this typing is intended.
    __return_type__ = sqltypes.DateTime
|---|
| 92 | |
|---|
class localtimestamp(AnsiFunction):
    """The ``localtimestamp`` function, typed as DateTime."""

    __return_type__ = sqltypes.DateTime
|---|
| 95 | |
|---|
class session_user(AnsiFunction):
    """The ``session_user`` function, typed as String."""

    __return_type__ = sqltypes.String
|---|
| 98 | |
|---|
class sysdate(AnsiFunction):
    """The ``sysdate`` function (keyword-only AnsiFunction), typed as DateTime."""

    __return_type__ = sqltypes.DateTime
|---|
| 101 | |
|---|
class user(AnsiFunction):
    """The ``user`` function, typed as String."""

    __return_type__ = sqltypes.String
|---|
| 104 | |
|---|
def _type_from_args(args):
    """Return the type of the first argument that is not NullType-typed.

    :param args: sequence of clause elements, each exposing a ``.type``.
    :returns: the first ``a.type`` that is not a ``sqltypes.NullType``
        instance; otherwise (all untyped, or ``args`` empty) the
        ``sqltypes.NullType`` class itself. Note that this fallback is the
        class, not an instance.
    """
    for a in args:
        if not isinstance(a.type, sqltypes.NullType):
            return a.type
    # Reached only when no argument carried a concrete type. A plain return
    # replaces the former for/else, whose else clause (no break in the loop)
    # always ran and only obscured the control flow.
    return sqltypes.NullType
|---|