from itertools import izip
from atoms import *
from textual import WordNumberRegexTokenizer
from io import SMARTParser


class Dataset:
    """
    A Dataset with 'documents' and 'labels'.

    Documents should be L{AtomVector}-like, i.e. they should be
    iterable, yielding C{(a, v)} pairs.

    >>> ds = Dataset("reuters")
    >>> ds.add(doc, labels)
    >>> ds.digest()

    @todo: Fix the semantics of labels.
    """

    def __init__(self, name=""):
        self.name = name

        self.labelset = set()      # distinct labels, filled in by digest()
        self.docs = []             # AtomVector-like documents
        self.labels = []           # per-document labels (sequence or scalar)
        self.digested = True       # an empty dataset is trivially digested

        self.cs = None             # optional corpus statistics
        self.catfactory = None     # L{AtomFactory} for category names
        self.tokenfactory = None   # L{AtomFactory} for tokens

    def __iter__(self):
        """Iterate over (doc, labels) tuples."""
        return izip(self.docs, self.labels)

    def add(self, doc, labels):
        """
        Add a (doc, labels) pair to the dataset.

        'labels' can be either a sequence (e.g. [1, 2, 5])
        or a single value (e.g. True or False).
        """
        self.docs.append(doc)
        self.labels.append(labels)
        self.digested = False

    def digest(self, force=False):
        """Analyze the data and generate the internal set of labels.

        Useful for binarizing etc.
        """
        if self.digested and not force:
            return

        self.labelset = set()
        for labels in self.labels:
            if hasattr(labels, "__iter__"):
                for label in labels:
                    self.labelset.add(label)
            elif labels:
                self.labelset.add(labels)
        self.digested = True

    def isBinary(self):
        """Return True if the dataset is binary, i.e. has a single (truthy) label."""
        self.digest()
        return len(self.labelset) == 1

    def getLabelCounts(self):
        """Get a dictionary of labels and their respective document counts.

        This is an O(n) operation!
        """
        self.digest()
        counts = dict([(l, 0) for l in self.labelset])
        for labels in self.labels:
            if hasattr(labels, "__iter__"):
                for label in labels:
                    counts[label] += 1
            elif labels:
                counts[labels] += 1
        return counts

    def __repr__(self):
        self.digest()
        return "<Dataset '%s', %d docs, %d labels>" % (
            self.name, len(self.docs), len(self.labelset))

106 """Write a multi-class dataset to fout in SVM format.
107
108 This can be directly consumed by LIBSVM.
109 """
110 for doc, labels in self:
111 svm_label = labels[0]
112 fout.write("%s %s\n" % (svm_label,
113 " ".join(["%d:%-7.4f" % (a,v) for a,v in sorted(doc.iteritems())])))

    def toSVMFirst(self, fout):
        """Write a binary dataset to fout in SVM format.

        Returns the byte positions of the labels, which can be used by
        L{toSVMSubsequent}() to overwrite the labels with something else.
        """
        assert self.isBinary()
        positions = []
        for doc, label in self:
            if label:
                svm_label = "+1"
            else:
                svm_label = "-1"
            positions.append(fout.tell())
            fout.write("%s %s\n" % (svm_label,
                " ".join(["%d:%-7.4f" % (a, v)
                          for a, v in sorted(doc.iteritems())])))
        return positions

    def toSVMSubsequent(self, fout, positions):
        """Overwrite the labels written by L{toSVMFirst}() at the given
        byte positions with this dataset's binary labels.
        """
        assert self.isBinary()
        for position, (doc, label) in izip(positions, self):
            if label:
                svm_label = "+1"
            else:
                svm_label = "-1"
            fout.seek(position)
            fout.write(svm_label)
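
    # Sketch: write the shared feature matrix once, then reuse the same
    # file for several binary problems by rewriting only the labels.
    # binary_sets and "out.svm" are hypothetical, and toSVMFirst is the
    # reconstructed name of the writer above.
    #
    #   sets = binary_sets.values()
    #   fout = open("out.svm", "w")
    #   positions = sets[0].toSVMFirst(fout)
    #   for ds in sets[1:]:
    #       ds.toSVMSubsequent(fout, positions)
    #       fout.flush()
    #       # ... train on "out.svm" here ...
    #   fout.close()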

    def toSMART(self, fout):
        """Write the dataset to fout in SMART format."""
        if self.catfactory is None or self.tokenfactory is None:
            raise Exception("Dataset must have catfactory and tokenfactory")

        for doc, labels in self:
            fout.write(".I %s\n" % doc.name)
            fout.write(".C\n")
            fout.write("; ".join(["%s 1" % self.catfactory.get_object(a)
                                  for a in labels]))
            fout.write("\n")
            fout.write(".T\n\n")
            fout.write(".W\n")
            fout.write(" ".join([" ".join([self.tokenfactory.get_object(a)] * int(v))
                                 for a, v in doc.iteritems()]))
            fout.write("\n")
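
    # For reference, each record comes out roughly like this (document
    # name, category, and tokens are hypothetical; a token is repeated
    # once per unit of weight):
    #
    #   .I doc17
    #   .C
    #   acq 1
    #   .T
    #
    #   .W
    #   oil oil price market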

    def binarize(self):
        """Create and return binary datasets, one per label.

        @return: A C{{k: v}} dictionary where k is a category name and
        v is a binary dataset.
        """
        self.digest()

        assert not self.isBinary(), "Dataset is already binary"

        name = self.name
        all_labels = self.labelset

        ret = dict([(l, Dataset("%s.%s" % (name, str(l)))) for l in all_labels])

        for doc, doclabels in self:
            doclabels = set(doclabels)
            for label in all_labels:
                if label in doclabels:
                    ret[label].labels.append(True)
                else:
                    ret[label].labels.append(False)

        for ds in ret.values():
            # All binary datasets share this dataset's document list.
            ds.docs = self.docs
            ds.digest(force=True)
            ds.catfactory = self.catfactory
            ds.tokenfactory = self.tokenfactory
            ds.cs = self.cs

        return ret
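
    # Usage sketch (category name hypothetical): one binary dataset per
    # category, keyed by the label itself.
    #
    #   binary_sets = ds.binarize()
    #   earn = binary_sets[ds.catfactory["earn"]]
    #   print earn.getLabelCounts()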

    def toWeighted(self, cs=None):
        """Convert to a weighted (e.g. LTC) dataset in place.

        @param cs: An optional L{CorpusStats} object; otherwise one is
        created and associated with the dataset.
        """
        if cs is None:
            # Build the corpus statistics from our own documents.
            cs = CorpusStats()
            for doc, doclabels in self:
                cs.add(doc)

        wvc = WeightVectors(cs)
        for i in range(len(self.docs)):
            self.docs[i] = wvc[self.docs[i]]

        self.cs = cs
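
    # Sketch (toWeighted is a reconstructed name): weight a training set,
    # then weight its test set with the same corpus statistics so both
    # live in one term-weight space.
    #
    #   train.toWeighted()          # builds and keeps a CorpusStats
    #   test.toWeighted(train.cs)   # reuses the training statistics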
212 """Creates count subsets of the dataset.
213
214 Subsetting is performed using round-robin.
215
216 @param count : Number of subsets to create
217 @return : A list of datasets
218 """
219 n = len(self.docs)
220 docs_per_set = int(n/count)
221 if docs_per_set < 1:
222 raise Exception, "#subsets > #docs"
223 subsets = [Dataset("%s-%d" % (self.name, i+1)) for i in range(count)]
224 j = 0
225 for i in range(n):
226 j = i % count
227 subsets[j].add(self.docs[i], self.labels[i])
228
229 for ds in subsets:
230 ds.digest()
231 ds.catfactory = self.catfactory
232 ds.tokenfactory = self.tokenfactory
233 ds.cs = self.cs
234
235 return subsets
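
    # e.g. with count=3, document i lands in subsets[i % 3]:
    #
    #   pieces = ds.subset(3)   # names: "<name>-1", "<name>-2", "<name>-3"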

    def folds(self, count):
        """Create cross-validation folds.

        The dataset is broken into `count` pieces; each fold (i.e.
        train-test pair) is created by assigning one piece to `test`
        and the remaining `count - 1` pieces to `train`.

        @param count: Number of folds
        @return: A list of [train, test] dataset pairs
        """
        subsets = self.subset(count)
        folds = [[Dataset(), Dataset()] for i in range(count)]
        for i in range(count):
            for j in range(count):
                if i == j:
                    folds[i][1] = subsets[j]
                else:
                    folds[i][0] += subsets[j]
        return folds
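
    # Sketch of a cross-validation loop over the returned pairs
    # (train_and_eval is a hypothetical stand-in for a real learner):
    #
    #   for train, test in ds.folds(10):
    #       train_and_eval(train, test)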

    def __add__(self, other):
        """Add two datasets.

        If both datasets are non-empty, they must be 'compatible',
        i.e., share the same factories and corpus stats.

        The resulting dataset combines the docs and labels, and inherits
        the factories and corpus stats of a non-empty parent dataset.

        If both parents were L{digest}ed, the resulting dataset is also
        digested.
        """
        result = Dataset()

        if len(self.docs) > 0 and len(other.docs) > 0:
            if (self.catfactory != other.catfactory
                    or self.tokenfactory != other.tokenfactory
                    or self.cs != other.cs):
                raise Exception("Incompatible datasets")

        if len(self.docs) > 0:
            reference_ds = self
        else:
            reference_ds = other

        result.docs = self.docs + other.docs
        result.labels = self.labels + other.labels
        if self.digested and other.digested:
            result.labelset = self.labelset | other.labelset
            result.digested = True
        else:
            result.digested = False

        result.catfactory = reference_ds.catfactory
        result.tokenfactory = reference_ds.tokenfactory
        result.cs = reference_ds.cs

        result.name = self.name + "+" + other.name

        return result
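
    # e.g. two compatible subsets recombine into one dataset whose name
    # joins the parents' names with "+":
    #
    #   merged = pieces[0] + pieces[1]   # name: "<name>-1+<name>-2"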

    @staticmethod
    def from_smart(filename, linkto=None):
        """Create a dataset from a file in SMART format.

        @param filename: File containing SMART-formatted documents
        @param linkto: Another dataset whose L{AtomFactory} we should borrow
        @return: A brand new dataset
        """
        ds = Dataset(filename)
        if linkto is None:
            catfactory = AtomFactory("cats")
            tokenfactory = AtomFactory("tokens")
        else:
            catfactory = linkto.catfactory
            tokenfactory = linkto.tokenfactory

        def handler(docid, cats, text):
            # Map categories and tokens to atoms, counting token occurrences.
            catatoms = [catfactory[c] for c in cats]
            av = AtomVector()
            for token in WordNumberRegexTokenizer(text):
                tokenatom = tokenfactory[token]
                av[tokenatom] += 1
            ds.add(av, catatoms)

        with open(filename) as fin:
            sp = SMARTParser(fin, handler, ["T", "W"])
            sp.parse()

        ds.digest()
        ds.catfactory = catfactory
        ds.tokenfactory = tokenfactory
        return ds
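
    # Sketch (from_smart is a reconstructed name; file names hypothetical):
    # parse a SMART-format corpus, then a companion test file that shares
    # the same atom factories.
    #
    #   train = Dataset.from_smart("train.smart")
    #   test = Dataset.from_smart("test.smart", linkto=train)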

    @staticmethod
    def from_rainbow(filename, linkto=None):
        """Create a dataset from rainbow's output.

        $ rainbow -d model --index 20news/train/*
        $ rainbow -d model --print-matrix=siw > train.txt

        >>> ds = from_rainbow("train.txt")

        C{ds.catfactory} holds the L{AtomFactory} for category names.
        C{ds.tokenfactory} holds the L{AtomFactory} for the tokens.

        A test set should share its factories with its training set;
        therefore, read it like so:

        >>> ds2 = from_rainbow("testfile.txt", linkto=ds)

        @param filename: File containing rainbow's output
        @param linkto: Another dataset whose L{AtomFactory} we should borrow
        @return: A brand new dataset
        """
        ds = Dataset(filename)
        if linkto is None:
            catfactory = AtomFactory("cats")
            tokenfactory = AtomFactory("tokens")
        else:
            catfactory = linkto.catfactory
            tokenfactory = linkto.tokenfactory

        fin = open(filename, "r")
        for line in fin:
            # Each line is "<path/docname> <class> <token> <count> ...".
            a = line.split(None, 2)
            catatom = catfactory[a[1]]
            a0 = a[0]
            p = a0.rfind("/")
            if p != -1:
                docname = a0[p + 1:]   # keep only the basename
            else:
                docname = a0
            a = a[2].split()
            l = len(a)
            av = AtomVector(name=docname)
            # (token, count) pairs alternate in the rest of the line.
            for i in range(0, l, 2):
                atom = tokenfactory[a[i]]
                count = float(a[i + 1])
                av.set(atom, count)
            ds.add(av, [catatom])
        fin.close()

        ds.digest()
        ds.catfactory = catfactory
        ds.tokenfactory = tokenfactory
        return ds
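
# A minimal end-to-end sketch, assuming rainbow output files exist at the
# paths below (paths and the output file name are hypothetical):
#
#   train = Dataset.from_rainbow("train.txt")
#   test = Dataset.from_rainbow("test.txt", linkto=train)
#   print train, test
#   fout = open("train.svm", "w")
#   train.toSVM(fout)
#   fout.close()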