Package mekano :: Package ml :: Module thresholder
[hide private]
[frames] | no frames]

Source Code for Module mekano.ml.thresholder

 1  from itertools import izip 
 2  from ..evaluator import ConfusionMatrix, DegenerateMetric 
 3  from utils import * 
 4   
class Thresholder:
    """
    Threshold finder for multi-label predictions.

    Usage::

        t = Thresholder(name)
        t.add(truth, prediction)
        thresholds = t.findthresholds(labelset)

    Really generic. 'truth' is any container supporting ``label in truth``
    (e.g. a set of labels). 'prediction' is a mapping from label to score
    that provides a ``get(label, default)`` method (e.g. a dict).
    """

    # The lower threshold for SCut.FBR method
    fbr = 0.1

    def __init__(self, name=""):
        """Create an empty thresholder identified by 'name'."""
        self.name = name
        self.truths = []
        self.predictions = []

    def add(self, truth, prediction):
        """Record the truth and the prediction for one document."""
        self.truths.append(truth)
        self.predictions.append(prediction)

    def addbatch(self, truths, predictions):
        """
        Replace the stored truths and predictions with the given sequences.

        The sequences are stored by reference (not copied), and anything
        previously recorded through add() or addbatch() is discarded.
        """
        self.truths = truths
        self.predictions = predictions

    def findthresholds(self, labelset):
        """
        Find, independently for each label, the score threshold that
        maximizes F1 over the recorded documents.

        Documents are visited in decreasing score order; after each one the
        F1 of "predict positive for everything seen so far" is evaluated.
        Documents sharing the same score are therefore still evaluated one
        at a time.

        Returns a dict mapping each label to its threshold. If no cut
        achieves an F1 strictly greater than Thresholder.fbr, the threshold
        is the score of the top-ranking document. Labels are skipped (with
        a printed notice) when no documents have been recorded.
        """
        ret = {}
        cm = ConfusionMatrix()
        # do it independently for each label {
        for label in labelset:
            data = []
            pos = 0
            neg = 0
            # for each document {
            for truth, prediction in izip(self.truths, self.predictions):
                binarylabel = label in truth
                if binarylabel:
                    pos += 1
                else:
                    neg += 1
                try:
                    # A label with no score is ranked with a very low score.
                    score = prediction.get(label, -100.0)
                except KeyError:
                    raise Exception("Threshold error!")
                data.append((score, binarylabel))
            # } for each doc
            if not data:
                print("Thresholding: Data length was zero. Continuing.")
                continue

            # now we have (prediction score, binary truth) tuples,
            # ranked from the highest score to the lowest
            data.sort(reverse=True)

            # starting point: threshold > highest score, i.e. nothing is
            # predicted positive
            cm.tp = cm.fp = 0
            cm.fn = pos
            cm.tn = neg
            # If we can't beat fbr F1, then the best threshold is
            # equal to the score of the top-ranking document.
            bestf1 = Thresholder.fbr
            bestf1_thres = data[0][0]
            for score, binarylabel in data:
                # Lower the cut to include this document as a positive.
                if binarylabel:
                    cm.tp += 1
                    cm.fn -= 1
                else:
                    cm.fp += 1
                    cm.tn -= 1
                try:
                    f1 = cm.f1()
                except DegenerateMetric:
                    f1 = 0.0
                if f1 > bestf1:
                    bestf1 = f1
                    bestf1_thres = score
            ret[label] = bestf1_thres
        # } for each label
        return ret
89
def findThresholdsForDataset(classifier, ds):
    """
    Compute per-label score thresholds for 'classifier' on dataset 'ds'.

    The dataset is digested, every document in ds.docs is scored with
    scoreAll(), and the scores are paired with ds.labels in a fresh
    Thresholder. Returns whatever Thresholder.findthresholds() yields for
    the classifier's label set.
    """
    ds.digest()
    scores = scoreAll(classifier, ds.docs)
    finder = Thresholder()
    finder.addbatch(ds.labels, scores)
    # We use classifier.labelset since it might be smaller than ds.labelset
    return finder.findthresholds(classifier.labelset)
97