1 from itertools import izip
2 from ..evaluator import ConfusionMatrix, DegenerateMetric
3 from utils import *
4
6 """
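Threshold finder for multi-label predictions.

    t = Thresholder(name)
    t.add(truth, prediction)

Generic by design: 'truth' is a set of the gold labels for one sample, and
'prediction' is a dict mapping each label to its score for that sample.
Labels missing from 'prediction' are treated as scoring -100.0.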
13
14 """
15
16
# Floor for the per-label F1 search: a candidate threshold is only taken
# when its F1 strictly exceeds this value. If no candidate does, the label's
# highest observed score is used as its threshold.
17 fbr = 0.1
18
# Per-sample storage, appended to by add(). The two lists are parallel:
# truths[i] holds the gold labels of sample i (tested with `label in truth`)
# and predictions[i] maps label -> score for the same sample.
20 self.name = name
21 self.truths = []
22 self.predictions = []
23
def add(self, truth, prediction):
    """Record a single sample.

    truth      -- gold labels for the sample; thresholding checks each label
                  with ``label in truth`` (e.g. a set).
    prediction -- mapping of label -> score for the same sample; thresholding
                  reads it with ``prediction.get(label, -100.0)``.
    """
    # Bind each list's append just before using it so the order of
    # attribute lookups and appends matches a plain two-line version.
    push_truth = self.truths.append
    push_truth(truth)
    push_prediction = self.predictions.append
    push_prediction(prediction)
27
def addbatch(self, truths, predictions):
    """Replace the stored samples with a whole batch at once.

    This *replaces* anything recorded earlier via ``add`` or a previous
    ``addbatch`` call; it does not append.

    truths      -- iterable of per-sample gold-label containers (e.g. sets).
    predictions -- iterable of per-sample label -> score mappings, parallel
                   to ``truths``.
    """
    # Copy into fresh lists. Storing the caller's objects directly aliased
    # them (a later add() would append into the caller's list, and a tuple
    # would make add() fail). A one-shot iterator would also be consumed by
    # the first label during thresholding, leaving every other label with
    # no data.
    self.truths = list(truths)
    self.predictions = list(predictions)
31
# For every label in `labelset`, choose the score threshold with the best F1
# over the stored samples and return {label: threshold}. Each threshold is
# one of the scores observed for that label. A label is omitted from the
# result when there are no samples to threshold.
33 ret = {}
# One ConfusionMatrix is reused for all labels; its four counts are reset
# before each label's sweep.
34 cm = ConfusionMatrix()
35
36 for label in labelset:
37 data = []
38 pos = 0
39 neg = 0
40
# Build (score, is_positive) pairs for this label. A label missing from a
# sample's prediction gets score -100.0 (presumably meant as "below any real
# score" -- TODO confirm real scores never go that low).
41 for truth, prediction in izip(self.truths, self.predictions):
42 binarylabel = label in truth
43 if binarylabel: pos += 1
44 else: neg += 1
# NOTE(review): dict.get with a default does not raise KeyError, so for
# plain dict predictions this handler looks unreachable.
45 try:
46 data.append((prediction.get(label, -100.0), binarylabel))
47 except KeyError:
48 raise Exception, "Threshold error!"
49
# Highest score first; among equal scores, positives (True) sort before
# negatives (False).
50 data = sorted(data, reverse=True)
# Empty only when there are no stored samples (or one of the parallel lists
# is empty), so in that case every label is skipped.
51 if len(data) == 0:
52 print "Thresholding: Data length was zero. Continuing."
53 continue
54
55
56
# Start the sweep with every sample predicted negative.
57 cm.tp = cm.fp = 0
58 cm.fn = pos
59 cm.tn = neg
60
61
# If no candidate beats the fbr floor, fall back to the highest score.
62 bestf1 = Thresholder.fbr
63 bestf1_thres = data[0][0]
# acc counts correct predictions (tp + tn) at the current cut-off.
# NOTE(review): bestacc / bestthreshold are tracked but never used in the
# returned result.
64 acc = neg
65 bestacc = acc
66 bestthreshold = data[0][0]
# Walk down the sorted scores. After each step, every sample up to and
# including the current one counts as predicted positive, and F1 is computed
# for that cut-off. The candidate threshold is the current sample's score.
# NOTE(review): F1 is also evaluated partway through runs of tied scores. If
# callers apply the threshold as `score >= t`, the whole tie group is
# included and the realised F1 can differ from bestf1 -- confirm how the
# returned thresholds are applied.
67 for pair in data:
68 if pair[1]:
69 acc += 1
70 cm.tp += 1
71 cm.fn -= 1
72 else:
73 acc -= 1
74 cm.fp += 1
75 cm.tn -= 1
76 if acc > bestacc:
77 bestacc = acc
78 bestthreshold = pair[0]
# If cm.f1() raises DegenerateMetric, this cut-off is scored as F1 0.0.
79 try:
80 f1 = cm.f1()
81 except DegenerateMetric:
82 f1 = 0.0
83 if f1 > bestf1:
84 bestf1 = f1
85 bestf1_thres = pair[0]
86 ret[label] = bestf1_thres
87
88 return ret
89
97