[2] | 1 | #!/usr/bin/env python |
---|
| 2 | #Guruprasad Ananda |
---|
| 3 | """ |
---|
| 4 | This tool provides the SQL "group by" functionality. |
---|
| 5 | """ |
---|
| 6 | import sys, string, re, commands, tempfile, random |
---|
| 7 | from rpy import * |
---|
| 8 | |
---|
def stop_err(msg):
    """Report a fatal error and terminate the script.

    msg -- text written verbatim to standard error (no newline is added).

    Raises SystemExit with status 1.  A bare sys.exit() would end the
    process with status 0, i.e. report success to the caller even though
    the tool has just failed.
    """
    sys.stderr.write(msg)
    sys.exit(1)
---|
| 12 | |
---|
# Operations that can handle string/non-numeric input; every other operation
# is only applied to values that convert with float().
_STR_OPS = ('c', 'length', 'unique', 'random', 'cuniq', 'Mode')

# Labels used in the summary message for operations whose R name differs.
_OP_LABELS = {
    'c': 'concat',
    'length': 'count',
    'unique': 'count_distinct',
    'random': 'randomly_pick',
    'cuniq': 'concat_distinct',
}


def _is_numeric(value):
    """Return True when value can be converted with float()."""
    try:
        float(value)
    except (TypeError, ValueError):
        return False
    return True


def _sort_by_column(inputfile, outputfile, group_col, ignorecase):
    """Sort inputfile on one tab-separated column into outputfile using sort.

    group_col is zero-based; the sort -k option counts columns from 1, so 1
    is added.  The key is given as POS,POS because without an end position
    newer versions of sort consider the rest of the line as well.

    The command is passed as an argument list (no shell), so file names
    containing spaces or shell metacharacters are handled safely, and the
    tab delimiter is spelled as an escape instead of a literal character.

    Returns (exit status, combined stdout/stderr text).
    Raises OSError when the sort program cannot be started.
    """
    import subprocess  # local import: not part of the file-level imports

    key = str(group_col + 1)
    cmd = ['sort', '-t', '\t']
    if ignorecase == 1:
        cmd.append('-f')
    cmd.extend(['-k', key + ',' + key, '-o', outputfile, inputfile])
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT,
                            universal_newlines=True)
    output = proc.communicate()[0]
    return proc.returncode, output.strip()


def _aggregate(op, values, do_round):
    """Apply one operation to the values collected for one group and column.

    op       -- operation name; except for 'random' it is looked up as a
                function on the rpy object r ('cuniq' uses r.c).
    values   -- list of stripped strings collected for the column.
    do_round -- for numeric operations, round the result to an int.

    Returns the raw result: an int or a '%g'-formatted string for numeric
    operations, otherwise whatever the R function (or the random pick)
    returned.
    """
    if op == 'random':
        return random.choice(values)
    if op == 'cuniq':
        rname = 'c'
    else:
        rname = op
    # SECURITY: op comes from the command line.  It used to be passed through
    # eval("r." + op); getattr performs the same attribute lookup without
    # evaluating arbitrary expressions.
    rfunc = getattr(r, rname)
    if op not in _STR_OPS:
        rout = float(rfunc([float(value) for value in values]))
        if do_round:
            return int(round(rout))
        return '%g' % rout
    rout = rfunc(values)
    if op == 'unique':
        # Count of distinct values rather than the values themselves.
        rout = r.length(rout)
    return rout


def _format_field(op, rout):
    """Turn an aggregate result into the text of one output column."""
    if op in ('c', 'cuniq') and isinstance(rout, list):
        if op == 'cuniq':
            rout = list(set(rout))
        return ','.join(rout)
    return str(rout)


def _write_group(fout, key, group_vals, ops, rounds):
    """Aggregate one finished group and write its output row(s) to fout.

    A failing operation leaves an empty field, so the remaining columns stay
    aligned, and is counted in the return value.  When a 'Mode' operation
    reports '>1 mode', one row is written per value collected for that
    column, each with the value in place of the Mode field (the last such
    column wins when there are several).

    Returns the number of operations that failed.
    """
    row = [key]
    mode_index = None
    failures = 0
    for i, op in enumerate(ops):
        try:
            rout = _aggregate(op, group_vals[i], rounds[i] == 'yes')
            field = _format_field(op, rout)
        except Exception:
            rout = None
            field = ''
            failures += 1
        if op == 'Mode' and rout == '>1 mode':
            mode_index = i
        row.append(field)

    if mode_index is None:
        fout.write('\t'.join(row) + '\n')
    else:
        # row[0] is the group key, so operation i lives at row[i + 1].
        for val in group_vals[mode_index]:
            expanded = row[:mode_index + 1] + [str(val)] + row[mode_index + 2:]
            fout.write('\t'.join(expanded) + '\n')
    return failures


def _group_rows(sorted_path, fout, group_col, ignorecase, ops, col_idx, rounds):
    """Read the sorted file and write one aggregated row per group to fout.

    Consecutive lines with the same value in column group_col (lower-cased
    when ignorecase is 1) form a group; this relies on the file having been
    sorted on that column.  Blank lines, lines starting with '#' and lines
    with too few columns are skipped.  A value that is not numeric for a
    numeric operation is left out of that column and the line is counted as
    skipped once, whether it starts a group or continues one.

    Returns (skipped_lines, first_invalid_line, invalid_value,
    invalid_column); invalid_value is None unless the first invalid line
    was rejected for a non-numeric value.
    """
    group_key = None  # None (not "") so that an empty key is a real group
    group_vals = []
    skipped_lines = 0
    first_invalid_line = 0
    invalid_value = None
    invalid_column = 0
    line_no = 0

    sorted_file = open(sorted_path)
    try:
        for ii, line in enumerate(sorted_file):
            line_no = ii + 1
            line = line.rstrip('\r\n')
            if not line or line.startswith('#'):
                skipped_lines += 1
                if not first_invalid_line:
                    first_invalid_line = line_no
                continue

            fields = line.split('\t')
            # Extract everything before touching the group state so that a
            # short line cannot leave the per-column lists out of step.
            try:
                item = fields[group_col]
                values = [fields[col].strip() for col in col_idx]
            except IndexError:
                skipped_lines += 1
                if not first_invalid_line:
                    first_invalid_line = line_no
                continue
            if ignorecase == 1:
                item = item.lower()

            if item != group_key:
                # A new value was encountered: flush the previous group.
                if group_key is not None:
                    failed = _write_group(fout, group_key, group_vals, ops, rounds)
                    if failed:
                        skipped_lines += failed
                        if not first_invalid_line:
                            first_invalid_line = line_no
                group_key = item
                group_vals = [[] for _ in ops]

            line_ok = True
            for i, value in enumerate(values):
                if ops[i] in _STR_OPS or _is_numeric(value):
                    group_vals[i].append(value)
                    continue
                if line_ok and not first_invalid_line:
                    invalid_value = fields[col_idx[i]]
                    invalid_column = col_idx[i] + 1
                line_ok = False
            if not line_ok:
                skipped_lines += 1
                if not first_invalid_line:
                    first_invalid_line = line_no
    finally:
        sorted_file.close()

    # Handle the last grouped value (nothing to write if no row was usable).
    if group_key is not None:
        failed = _write_group(fout, group_key, group_vals, ops, rounds)
        if failed:
            skipped_lines += failed
            if not first_invalid_line:
                first_invalid_line = line_no

    return skipped_lines, first_invalid_line, invalid_value, invalid_column


def main():
    """Group a tab-separated file on one column and aggregate other columns.

    Command line (sys.argv):
      1   output file
      2   input file (tab-separated; lines starting with '#' are ignored)
      3   group column, counted from 1
      4   1 to group case-insensitively, anything else for case-sensitive
      5+  one "<operation> <column> <round>" triple per aggregate, e.g.
          "mean 1 no", "min 3 yes", "c 4 no"

    The input is sorted on the group column with the external sort program,
    then read sequentially; each run of equal keys becomes one output row
    holding the key followed by one field per requested operation.  A
    summary line is printed to standard output.

    Errors in the arguments, a non-numeric first data row for a numeric
    operation, or a failing sort end the script through stop_err().
    """
    inputfile = sys.argv[2]
    ignorecase = int(sys.argv[4])
    ops = []
    cols = []
    rounds = []
    for var in sys.argv[5:]:
        parts = var.split()
        ops.append(parts[0])
        cols.append(parts[1])
        rounds.append(parts[2])
    # At this point, ops, cols and rounds will look something like this:
    #   ops:    ['mean', 'min', 'c']
    #   cols:   ['1', '3', '4']
    #   rounds: ['no', 'yes', 'no']

    if 'Mode' in ops:
        try:
            r.library('prettyR')
        except Exception:
            stop_err('R package prettyR could not be loaded. Please make sure it is installed.')

    # Find the first data line among the first 31 lines of the input.
    elems = []
    infile = open(inputfile)
    try:
        for i, line in enumerate(infile):
            line = line.rstrip('\r\n')
            if line and not line.startswith('#'):
                elems = line.split('\t')
                break
            if i == 30:
                break  # Hopefully we'll never get here...
    finally:
        infile.close()
    if not elems:
        stop_err("The data in your input dataset is either missing or not formatted properly.")

    try:
        group_col = int(sys.argv[3]) - 1
    except (IndexError, ValueError):
        stop_err("Group column not specified.")

    # Numeric operations are rejected up front when the first data line does
    # not hold a number in the requested column.
    col_idx = [int(col) - 1 for col in cols]
    for k, col in enumerate(col_idx):
        if ops[k] in _STR_OPS:
            continue
        try:
            float(elems[col])
        except (IndexError, ValueError):
            try:
                msg = "Operation '%s' cannot be performed on non-numeric column %d containing value '%s'." % (ops[k], col + 1, elems[col])
            except IndexError:
                msg = "Operation '%s' cannot be performed on non-numeric data." % ops[k]
            stop_err(msg)

    tmpfile = tempfile.NamedTemporaryFile()
    try:
        try:
            error_code, sort_output = _sort_by_column(inputfile, tmpfile.name, group_col, ignorecase)
        except OSError:
            stop_err('Initialization error -> %s' % str(sys.exc_info()[1]))
        if error_code != 0:
            stop_err("Sorting input dataset resulted in error: %s: %s" % (error_code, sort_output))

        fout = open(sys.argv[1], "w")
        try:
            stats = _group_rows(tmpfile.name, fout, group_col, ignorecase, ops, col_idx, rounds)
        finally:
            fout.close()
    finally:
        tmpfile.close()
    skipped_lines, first_invalid_line, invalid_value, invalid_column = stats

    # Generate a useful info message.
    msg = "--Group by c%d: " % (group_col + 1)
    for i, op in enumerate(ops):
        msg += _OP_LABELS.get(op, op) + "[c" + cols[i] + "] "
    if skipped_lines > 0:
        msg += "--skipped %d invalid lines starting with line %d." % (skipped_lines, first_invalid_line)
        # Only blame a non-numeric value when one was actually recorded.
        if invalid_value is not None:
            msg += " Value '%s' in column %d is not numeric." % (invalid_value, invalid_column)

    print(msg)
---|
| 287 | |
---|
# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
---|