Package mekano :: Package ml :: Module svm
[hide private]
[frames | no frames]

Source Code for Module mekano.ml.svm

  1  from __future__ import with_statement 
  2   
  3  import tempfile 
  4  import os 
  5   
  6  from ..atoms import AtomVector 
  7  from multiclassifier import MultiClassifier 
  8   
class SVMClassifier:
    """Linear SVM wrapper driven by an external learner binary.

    The learner named by ``self.binary`` is run through the shell and the
    model file it writes is parsed into a weight vector ``w`` and a
    threshold ``b``.

    >>> svm = SVMClassifier(modelfile)
    >>> svm.score(av)
    """

    def __init__(self, modelfile = None):
        """Create an untrained classifier, optionally loading *modelfile*.

        modelfile -- path of an existing model file, or None to start empty.
        """
        self.w = AtomVector()
        self.b = 0.0
        self.sv = 0
        # Learner settings; turned into "-c"/"-j" options by _svm_params().
        self.c = 1.0
        self.j = None
        # Directory for the temporary training file, and the learner to run.
        self.tmp = "/tmp/"
        self.binary = "svm_perf_learn"
        if modelfile is not None:
            self.readmodelfile(modelfile)

    def readmodelfile(self, modelfile):
        """Replace this classifier's model with the one stored in *modelfile*.

        The file is read as: one header line that must contain "SVM",
        eight lines that are skipped, a line whose first token is the
        vector count ``sv``, a line whose first token is ``b``, and then
        ``sv - 1`` lines of the form ``alpha_y id:value id:value ... # ...``.

        The weights are collected in a fresh vector and only installed once
        the whole file has parsed, so that ``w``, ``b`` and ``sv`` always
        describe the same model (re-loading or re-training must not add the
        new weights on top of the old ones) and a failed parse leaves the
        classifier untouched.

        Raises AssertionError if the first line does not mention "SVM".
        """
        with open(modelfile) as fin:
            header = fin.readline()
            # Raised explicitly (not via ``assert``) so that the check also
            # runs when Python is started with -O.
            if "SVM" not in header:
                raise AssertionError("Not an SVM model file!")
            for _ in range(8):
                fin.readline()
            sv = int(fin.readline().split()[0])
            b = float(fin.readline().split()[0])

            weights = AtomVector()
            for _ in range(sv - 1):
                fields = fin.readline().rstrip().split()
                alpha_y = float(fields[0])
                av = AtomVector()
                for pair in fields[1:]:
                    # Anything from the first token containing '#' onwards
                    # is trailing comment text, not a feature.
                    if "#" in pair:
                        break
                    atom, value = map(float, pair.split(":"))
                    av[atom] = value * alpha_y
                weights.addvector(av)

        self.w = weights
        self.sv = sv
        self.b = b

    def train(self, ds):
        """Train on the binary dataset *ds* by running the external learner.

        *ds* is written to a temporary file under ``self.tmp`` with
        ``ds.toSVM``, the learner is invoked on it, and the resulting model
        file is loaded with readmodelfile().  Both files are cleaned up even
        when the learner or the parse fails.

        Raises AssertionError if *ds* is not binary or the learner exits
        with a non-zero status (see _run).
        """
        if not ds.isBinary():
            raise AssertionError("SVMClassifier.train requires a binary dataset")
        fout = tempfile.NamedTemporaryFile(suffix='svm', dir=self.tmp)
        modelfilename = fout.name + ".model"
        try:
            ds.toSVM(fout)
            # The learner opens the file by name, so the data must be on disk.
            fout.file.flush()
            _run("%s %s %s %s > /dev/null 2>&1" % (self.binary, _svm_params(self), fout.name, modelfilename))
            self.readmodelfile(modelfilename)
        finally:
            fout.close()
            if os.path.exists(modelfilename):
                os.remove(modelfilename)

    def __repr__(self):
        """Return a short summary: size of w, threshold b and vector count."""
        return "<SVMClassfier len(w)=%d b=%7.4f #sv=%d>" % (len(self.w), self.b, self.sv)

    def score(self, av):
        """Return the decision value ``av * w - b`` for the vector *av*."""
        # keep the shorter vector on the left side for faster dot products!
        return (av * self.w) - self.b
66
class SVMMultiClassifier:
    """Multi-label classifier holding one SVMClassifier per label.

    train() asks the dataset for its per-label binary views
    (``ds.binarize()``), runs the external learner once per label and
    registers each resulting SVMClassifier with a MultiClassifier, which
    score() delegates to.
    """

    def __init__(self):
        """Create an untrained classifier with default learner settings."""
        self.mc = MultiClassifier()
        self.labelset = set()
        # Learner settings; turned into "-c"/"-j" options by _svm_params().
        self.c = 1.0
        self.j = None
        # Directory for the temporary training file, and the learner to run.
        self.tmp = "/tmp/"
        self.binary = "svm_perf_learn"

    def train(self, ds):
        """Train one SVM per label of *ds*.

        The first label's data is written with ``toSVM``; for every later
        label the same temporary file is updated through
        ``toSVMSubsequent`` using the positions returned by that first
        write.  A dataset with no labels is accepted and simply leaves the
        label set empty.  The temporary training file and every per-label
        model file are removed even if the learner or the model parse fails.

        Raises AssertionError if the learner exits with a non-zero status
        (see _run).
        """
        bds = ds.binarize()
        # Parenthesised single-argument form: valid as a Python 2 print
        # statement and as a Python 3 function call, with identical output.
        print("SVMMultiClassifier: Training with %d docs, %d labels" % (len(ds.docs), len(bds)))
        fout = None
        positions = None
        try:
            for label in bds:
                if positions is None:
                    fout = tempfile.NamedTemporaryFile(suffix='svm', dir=self.tmp)
                    positions = bds[label].toSVM(fout)
                else:
                    bds[label].toSVMSubsequent(fout, positions)
                # The learner opens the file by name, so flush before running.
                fout.file.flush()
                # NOTE(review): the label is interpolated into a shell command
                # unquoted; labels containing spaces or shell metacharacters
                # would break it -- confirm labels are always simple tokens.
                modelfilename = "%s-%s.model" % (fout.name, label)
                try:
                    _run("%s %s %s %s > /dev/null 2>&1" % (self.binary, _svm_params(self), fout.name, modelfilename))
                    self.mc.add(label, SVMClassifier(modelfilename))
                finally:
                    if os.path.exists(modelfilename):
                        os.remove(modelfilename)
        finally:
            # fout stays None when there were no labels to train on.
            if fout is not None:
                fout.close()
        self.labelset = set(bds)

    def __repr__(self):
        """Return a short summary giving the number of trained labels."""
        return "<SVMMultiClassfier: %d labels>" % len(self.labelset)

    def score(self, av):
        """Return whatever the underlying MultiClassifier scores for *av*."""
        return self.mc.score(av)
99
100 -def _run(cmd):
101 print "running:", cmd 102 retcode = os.system(cmd) 103 assert retcode == 0
104 105
106 -def _svm_params(classifier):
107 """Extract .c and .j from classifier object and return cmd-line options for SVM. 108 109 For example: -c 1.0 -j 2.0""" 110 111 ret = "-c %f" % classifier.c 112 if classifier.j is not None: 113 ret += " -j %f" % classifier.j 114 return ret
115