1 from __future__ import with_statement
2
3 import tempfile
4 import os
5
6 from ..atoms import AtomVector
7 from multiclassifier import MultiClassifier
8
# NOTE(review): the class/method header lines were missing from the reviewed
# listing and are reconstructed from attribute usage and call sites such as
# ``SVMClassifier(modelfilename)``; confirm the base class.
class SVMClassifier(object):
    """Linear SVM wrapper around an external SVM-light style learner.

    Usage::

        svm = SVMClassifier(modelfile)
        svm.score(av)
    """

    def __init__(self, modelfile=None):
        """Create an empty classifier, loading *modelfile* if it is given.

        Attributes:
            w: weight vector (AtomVector) built from the model's vectors.
            b: threshold subtracted in score().
            sv: count read from the model file's support-vector line.
            c, j: learner options passed as ``-c`` / ``-j`` (``-j`` is
                omitted while j is None).
            tmp: directory used for temporary training files.
            binary: learner executable invoked by train().
        """
        self.w = AtomVector()
        self.b = 0.0
        self.sv = 0
        self.c = 1.0
        self.j = None
        self.tmp = "/tmp/"
        self.binary = "svm_perf_learn"
        if modelfile is not None:
            self.readmodelfile(modelfile)

    def readmodelfile(self, modelfile):
        """Load w, b and sv from *modelfile*, replacing any loaded model.

        Layout parsed here: a header line containing "SVM", 8 skipped
        lines, a line whose first field is the support-vector count, a line
        whose first field is the threshold b, then one line per vector of
        the form ``alpha*y index:value index:value ... # comment``.

        Raises AssertionError if the header line does not contain "SVM".
        """
        # Build into a fresh vector and assign at the end.  Adding into the
        # existing self.w made a second load (e.g. train() on an instance
        # created from a model file) sum both models' weights while b and sv
        # were overwritten.
        w = AtomVector()
        with open(modelfile) as fin:
            header = fin.readline()
            assert "SVM" in header, "Not an SVM model file!"
            for _ in xrange(8):
                fin.readline()
            sv = int(fin.readline().split()[0])
            b = float(fin.readline().split()[0])
            # The count line presumably holds "#SV + 1" (SVM-light
            # convention), hence sv - 1 vector lines follow -- confirm.
            for _ in xrange(sv - 1):
                fields = fin.readline().rstrip().split()
                alpha_y = float(fields[0])
                av = AtomVector()
                for pair in fields[1:]:
                    if "#" in pair:
                        break  # start of the trailing comment
                    # Indices are parsed as floats (as before); confirm
                    # AtomVector does not require int keys.
                    index, value = map(float, pair.split(":"))
                    av[index] = value * alpha_y
                w.addvector(av)
        self.w = w
        self.sv = sv
        self.b = b

    def train(self, ds):
        """Train on the binary dataset *ds* with self.binary and load the model.

        ds must provide isBinary() and toSVM(fileobj).  The temporary data
        file and the generated model file are removed even when training or
        loading fails.

        Raises AssertionError if ds is not binary or the learner fails.
        """
        assert ds.isBinary(), "SVMClassifier.train requires a binary dataset"
        fout = tempfile.NamedTemporaryFile(suffix='svm', dir=self.tmp)
        modelfilename = fout.name + ".model"
        try:
            ds.toSVM(fout)
            fout.file.flush()  # the external learner reads the file by name
            _run("%s %s %s %s > /dev/null 2>&1"
                 % (self.binary, _svm_params(self), fout.name, modelfilename))
            self.readmodelfile(modelfilename)
        finally:
            fout.close()  # NamedTemporaryFile deletes the data file on close
            if os.path.exists(modelfilename):
                os.remove(modelfilename)

    def __repr__(self):
        # Class name was previously misspelled "SVMClassfier".
        return "<SVMClassifier len(w)=%d b=%7.4f #sv=%d>" % (
            len(self.w), self.b, self.sv)

    def score(self, av):
        """Return ``(av * w) - b`` for the vector *av*.

        ``*`` is AtomVector's operator -- presumably a dot product; confirm
        in atoms.AtomVector.
        """
        return (av * self.w) - self.b
66
# NOTE(review): the class/method header lines were missing from the reviewed
# listing and are reconstructed here; confirm the base class and that
# __init__ takes no arguments.
class SVMMultiClassifier(object):
    """Multi-label classifier made of one binary SVMClassifier per label.

    Each per-label model is trained on ``ds.binarize()[label]`` and
    registered in a MultiClassifier, which performs the scoring.
    """

    def __init__(self):
        # mc: combines the per-label classifiers; labelset: labels seen by
        # the last train(); c, j, tmp, binary: as for SVMClassifier.
        self.mc = MultiClassifier()
        self.labelset = set()
        self.c = 1.0
        self.j = None
        self.tmp = "/tmp/"
        self.binary = "svm_perf_learn"

    def train(self, ds):
        """Train one SVMClassifier per label of ``ds.binarize()``.

        ds must provide ``docs`` and binarize(); each binarized dataset must
        provide toSVM(fileobj) and toSVMSubsequent(fileobj, positions).
        Temporary files are removed even when a training run fails.

        Raises AssertionError if the learner fails for any label.
        """
        bds = ds.binarize()
        print("SVMMultiClassifier: Training with %d docs, %d labels"
              % (len(ds.docs), len(bds)))
        # fout stays None when bds is empty (previously a NameError at close).
        fout = None
        positions = None
        try:
            for index, label in enumerate(bds):
                if positions is None:
                    # First label: write the full data file.  toSVM returns
                    # the positions toSVMSubsequent uses to reuse this same
                    # file for the remaining labels.
                    if fout is not None:
                        fout.close()
                    fout = tempfile.NamedTemporaryFile(suffix='svm',
                                                       dir=self.tmp)
                    positions = bds[label].toSVM(fout)
                else:
                    bds[label].toSVMSubsequent(fout, positions)
                fout.file.flush()  # the learner reads the file by name
                # Name the model by position, not by label: labels are data
                # and may contain '/', whitespace or shell metacharacters
                # that would break the path or the shell command.
                modelfilename = "%s-%d.model" % (fout.name, index)
                try:
                    _run("%s %s %s %s > /dev/null 2>&1"
                         % (self.binary, _svm_params(self), fout.name,
                            modelfilename))
                    self.mc.add(label, SVMClassifier(modelfilename))
                finally:
                    if os.path.exists(modelfilename):
                        os.remove(modelfilename)
        finally:
            if fout is not None:
                fout.close()
        self.labelset = set(bds)

    def __repr__(self):
        # Class name was previously misspelled "SVMMultiClassfier".
        return "<SVMMultiClassifier: %d labels>" % len(self.labelset)

    def score(self, av):
        """Delegate scoring of *av* to the underlying MultiClassifier."""
        return self.mc.score(av)
99
def _run(cmd):
    """Echo *cmd*, then run it through the shell with os.system().

    cmd is passed to the shell unmodified, so every path it contains must
    already be shell-safe.

    Raises AssertionError if the command exits with a non-zero status.
    """
    print("running: %s" % cmd)
    retcode = os.system(cmd)
    if retcode != 0:
        # Explicit raise instead of a bare assert so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        raise AssertionError("command failed with status %d: %s"
                             % (retcode, cmd))
104
105
def _svm_params(classifier):
    """Return the learner's command-line options for *classifier*.

    ``-c`` always comes from classifier.c; ``-j`` from classifier.j is
    appended only when j is not None.  Values use ``%f`` formatting, e.g.
    ``-c 1.000000 -j 2.000000``.
    """
    j_value = classifier.j
    c_option = "-c %f" % classifier.c
    if j_value is None:
        return c_option
    return c_option + " -j %f" % j_value
115