-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathencoder.py
More file actions
1216 lines (1058 loc) · 42.4 KB
/
encoder.py
File metadata and controls
1216 lines (1058 loc) · 42.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Encoding categorical data as integers
Authors
* Samuele Cornell 2020
* Aku Rouhe 2020
"""
import ast
import collections
import itertools
import torch
import speechbrain as sb
from speechbrain.utils.checkpoints import (
mark_as_loader,
mark_as_saver,
register_checkpoint_hooks,
)
from speechbrain.utils.logger import get_logger
logger = get_logger(__name__)
# NOTE: Changing these does NOT change the defaults in the classes.
# Consider these read-only.
DEFAULT_UNK = "<unk>"
DEFAULT_BOS = "<bos>"
DEFAULT_EOS = "<eos>"
DEFAULT_BLANK = "<blank>"
@register_checkpoint_hooks
class CategoricalEncoder:
"""Encode labels of a discrete set.
Used for encoding, e.g., speaker identities in speaker recognition.
Given a collection of hashables (e.g a strings) it encodes
every unique item to an integer value: ["spk0", "spk1"] --> [0, 1]
Internally the correspondence between each label to its index is handled by
two dictionaries: lab2ind and ind2lab.
The label integer encoding can be generated automatically from a SpeechBrain
DynamicItemDataset by specifying the desired entry (e.g., spkid) in the annotation
and calling update_from_didataset method:
>>> from speechbrain.dataio.encoder import CategoricalEncoder
>>> from speechbrain.dataio.dataset import DynamicItemDataset
>>> dataset = {
... "ex_{}".format(x): {"spkid": "spk{}".format(x)} for x in range(20)
... }
>>> dataset = DynamicItemDataset(dataset)
>>> encoder = CategoricalEncoder()
>>> encoder.update_from_didataset(dataset, "spkid")
>>> assert len(encoder) == len(
... dataset
... ) # different speaker for each utterance
However can also be updated from an iterable:
>>> from speechbrain.dataio.encoder import CategoricalEncoder
>>> from speechbrain.dataio.dataset import DynamicItemDataset
>>> dataset = ["spk{}".format(x) for x in range(20)]
>>> encoder = CategoricalEncoder()
>>> encoder.update_from_iterable(dataset)
>>> assert len(encoder) == len(dataset)
Note
----
In both methods it can be specified it the single element in the iterable
or in the dataset should be treated as a sequence or not (default False).
If it is a sequence each element in the sequence will be encoded.
>>> from speechbrain.dataio.encoder import CategoricalEncoder
>>> from speechbrain.dataio.dataset import DynamicItemDataset
>>> dataset = [[x + 1, x + 2] for x in range(20)]
>>> encoder = CategoricalEncoder()
>>> encoder.ignore_len()
>>> encoder.update_from_iterable(dataset, sequence_input=True)
>>> assert len(encoder) == 21 # there are only 21 unique elements 1-21
This class offers 4 different methods to explicitly add a label in the internal
dicts: add_label, ensure_label, insert_label, enforce_label.
add_label and insert_label will raise an error if it is already present in the
internal dicts. insert_label, enforce_label allow also to specify the integer value
to which the desired label is encoded.
Encoding can be performed using 4 different methods:
encode_label, encode_sequence, encode_label_torch and encode_sequence_torch.
encode_label operate on single labels and simply returns the corresponding
integer encoding:
>>> from speechbrain.dataio.encoder import CategoricalEncoder
>>> from speechbrain.dataio.dataset import DynamicItemDataset
>>> dataset = ["spk{}".format(x) for x in range(20)]
>>> encoder.update_from_iterable(dataset)
>>>
22
>>>
encode_sequence on sequences of labels:
>>> encoder.encode_sequence(["spk1", "spk19"])
[22, 40]
>>>
encode_label_torch and encode_sequence_torch return torch tensors
>>> encoder.encode_sequence_torch(["spk1", "spk19"])
tensor([22, 40])
>>>
Decoding can be performed using decode_torch and decode_ndim methods.
>>> encoded = encoder.encode_sequence_torch(["spk1", "spk19"])
>>> encoder.decode_torch(encoded)
['spk1', 'spk19']
>>>
decode_ndim is used for multidimensional list or pytorch tensors
>>> encoded = encoded.unsqueeze(0).repeat(3, 1)
>>> encoder.decode_torch(encoded)
[['spk1', 'spk19'], ['spk1', 'spk19'], ['spk1', 'spk19']]
>>>
In some applications, it can happen that during testing a label which has not
been encountered during training is encountered. To handle this out-of-vocabulary
problem add_unk can be used. Every out-of-vocab label is mapped to this special
<unk> label and its corresponding integer encoding.
>>> import torch
>>> try:
... encoder.encode_label("spk42")
... except KeyError:
... print("spk42 is not in the encoder this raises an error!")
spk42 is not in the encoder this raises an error!
>>> encoder.add_unk()
41
>>> encoder.encode_label("spk42")
41
>>>
returns the <unk> encoding
This class offers also methods to save and load the internal mappings between
labels and tokens using: save and load methods as well as load_or_create.
"""
VALUE_SEPARATOR = " => "
EXTRAS_SEPARATOR = "================\n"
    def __init__(self, starting_index=0, **special_labels):
        """Create an empty encoder.

        Arguments
        ---------
        starting_index : int
            The first integer index handed out to a label (default 0).
        **special_labels
            Forwarded to handle_special_labels, e.g. unk_label=...
        """
        self.lab2ind = {}
        self.ind2lab = {}
        self.starting_index = starting_index
        # NOTE: unk_label is not necessarily set at all!
        # This is because None is a suitable value for unk.
        # So the test is: hasattr(self, "unk_label")
        # rather than self.unk_label is not None
        self.handle_special_labels(special_labels)
def handle_special_labels(self, special_labels):
"""Handles special labels such as unk_label."""
if "unk_label" in special_labels:
self.add_unk(special_labels["unk_label"])
    def __len__(self):
        """Return the number of labels currently known to the encoder."""
        return len(self.lab2ind)
@classmethod
def from_saved(cls, path):
"""Recreate a previously saved encoder directly"""
obj = cls()
obj.load(path)
return obj
def update_from_iterable(self, iterable, sequence_input=False):
"""Update from iterator
Arguments
---------
iterable : iterable
Input sequence on which to operate.
sequence_input : bool
Whether iterable yields sequences of labels or individual labels
directly. (default False)
"""
if sequence_input:
label_iterator = itertools.chain.from_iterable(iterable)
else:
label_iterator = iter(iterable)
for label in label_iterator:
self.ensure_label(label)
def update_from_didataset(
self, didataset, output_key, sequence_input=False
):
"""Update from DynamicItemDataset.
Arguments
---------
didataset : DynamicItemDataset
Dataset on which to operate.
output_key : str
Key in the dataset (in data or a dynamic item) to encode.
sequence_input : bool
Whether the data yielded with the specified key consists of
sequences of labels or individual labels directly.
"""
with didataset.output_keys_as([output_key]):
self.update_from_iterable(
(data_point[output_key] for data_point in didataset),
sequence_input=sequence_input,
)
def limited_labelset_from_iterable(
self, iterable, sequence_input=False, n_most_common=None, min_count=1
):
"""Produce label mapping from iterable based on label counts
Used to limit label set size.
Arguments
---------
iterable : iterable
Input sequence on which to operate.
sequence_input : bool
Whether iterable yields sequences of labels or individual labels
directly. False by default.
n_most_common : int, None
Take at most this many labels as the label set, keeping the most
common ones. If None (as by default), take all.
min_count : int
Don't take labels if they appear less than this many times.
Returns
-------
collections.Counter
The counts of the different labels (unfiltered).
"""
if self.lab2ind:
clsname = self.__class__.__name__
logger.info(
f"Limited_labelset_from_iterable called, "
f"but {clsname} is not empty. "
"The new labels will be added, i.e. won't overwrite. "
"This is normal if there is e.g. an unk label already."
)
if sequence_input:
label_iterator = itertools.chain.from_iterable(iterable)
else:
label_iterator = iter(iterable)
counts = collections.Counter(label_iterator)
for label, count in counts.most_common(n_most_common):
if count < min_count:
# .most_common() produces counts in descending order,
# so no more labels can be found
break
self.add_label(label)
return counts
def load_or_create(
self,
path,
from_iterables=[],
from_didatasets=[],
sequence_input=False,
output_key=None,
special_labels={},
):
"""Convenient syntax for creating the encoder conditionally
This pattern would be repeated in so many experiments that
we decided to add a convenient shortcut for it here. The
current version is multi-gpu (DDP) safe.
"""
try:
if sb.utils.distributed.if_main_process():
if not self.load_if_possible(path):
for iterable in from_iterables:
self.update_from_iterable(iterable, sequence_input)
for didataset in from_didatasets:
if output_key is None:
raise ValueError(
"Provide an output_key for DynamicItemDataset"
)
self.update_from_didataset(
didataset, output_key, sequence_input
)
self.handle_special_labels(special_labels)
self.save(path)
finally:
sb.utils.distributed.ddp_barrier()
self.load(path)
def add_label(self, label):
"""Add new label to the encoder, at the next free position.
Arguments
---------
label : hashable
Most often labels are str, but anything that can act as dict key is
supported. Note that default save/load only supports Python
literals.
Returns
-------
int
The index that was used to encode this label.
"""
if label in self.lab2ind:
clsname = self.__class__.__name__
raise KeyError(f"Label already present in {clsname}")
index = self._next_index()
self.lab2ind[label] = index
self.ind2lab[index] = label
return index
def ensure_label(self, label):
"""Add a label if it is not already present.
Arguments
---------
label : hashable
Most often labels are str, but anything that can act as dict key is
supported. Note that default save/load only supports Python
literals.
Returns
-------
int
The index that was used to encode this label.
"""
if label in self.lab2ind:
return self.lab2ind[label]
else:
return self.add_label(label)
    def insert_label(self, label, index):
        """Add a new label, forcing its index to a specific value.

        If another label currently occupies the given index, that label is
        moved to the next free position (see enforce_label). The label being
        inserted must not already exist in the encoder.

        Arguments
        ---------
        label : hashable
            Most often labels are str, but anything that can act as dict key is
            supported. Note that default save/load only supports Python
            literals.
        index : int
            The specific index to use.

        Raises
        ------
        KeyError
            If the label is already present in the encoder.
        """
        if label in self.lab2ind:
            clsname = self.__class__.__name__
            raise KeyError(f"Label already present in {clsname}")
        else:
            # Delegates to enforce_label, which handles evicting any label
            # currently mapped to `index`.
            self.enforce_label(label, index)
    def enforce_label(self, label, index):
        """Make sure label is present and encoded to a particular index.

        If the label is present but encoded to some other index, it is
        moved to the given index.

        If there is already another label at the
        given index, that label is moved to the next free position.

        Arguments
        ---------
        label : hashable
            The label to position in the mapping.
        index : int
            The index this label must encode to. Coerced with int().
        """
        index = int(index)
        if label in self.lab2ind:
            if index == self.lab2ind[label]:
                # Already at the requested index: nothing to do.
                return
            else:
                # Delete old index mapping. Everything else gets overwritten.
                del self.ind2lab[self.lab2ind[label]]
        # Move other label out of the way:
        # (remember the occupant before overwriting ind2lab below)
        if index in self.ind2lab:
            saved_label = self.ind2lab[index]
            moving_other = True
        else:
            moving_other = False
        # Ready to push the new index.
        self.lab2ind[label] = index
        self.ind2lab[index] = label
        # And finally put the moved index in new spot.
        # NOTE: this must happen after the new mapping is in place so that
        # _next_index sees `index` as occupied.
        if moving_other:
            logger.info(
                f"Moving label {repr(saved_label)} from index "
                f"{index}, because {repr(label)} was put at its place."
            )
            new_index = self._next_index()
            self.lab2ind[saved_label] = new_index
            self.ind2lab[new_index] = saved_label
    def add_unk(self, unk_label=DEFAULT_UNK):
        """Add label for unknown tokens (out-of-vocab).

        When asked to encode unknown labels, they can be mapped to this.

        Arguments
        ---------
        unk_label : hashable, optional
            Most often labels are str, but anything that can act as dict key is
            supported. Note that default save/load only supports Python
            literals. Default: <unk>. This can be None, as well!

        Returns
        -------
        int
            The index that was used to encode this.
        """
        # The attribute is set before add_label, so hasattr(self, "unk_label")
        # checks elsewhere see it; add_label raises KeyError if the label
        # already exists in the mapping.
        self.unk_label = unk_label
        return self.add_label(unk_label)
def _next_index(self):
"""The index to use for the next new label"""
index = self.starting_index
while index in self.ind2lab:
index += 1
return index
def is_continuous(self):
"""Check that the set of indices doesn't have gaps
For example:
If starting index = 1
Continuous: [1,2,3,4]
Continuous: [0,1,2]
Non-continuous: [2,3,4]
Non-continuous: [1,2,4]
Returns
-------
bool
True if continuous.
"""
# Because of Python indexing this also handles the special cases
# of 0 or 1 labels.
indices = sorted(self.ind2lab.keys())
return self.starting_index in indices and all(
j - i == 1 for i, j in zip(indices[:-1], indices[1:])
)
def encode_label(self, label, allow_unk=True):
"""Encode label to int
Arguments
---------
label : hashable
Label to encode, must exist in the mapping.
allow_unk : bool
If given, that label is not in the label set
AND unk_label has been added with add_unk(),
allows encoding to unk_label's index.
Returns
-------
int
Corresponding encoded int value.
"""
self._assert_len()
try:
return self.lab2ind[label]
except KeyError:
if hasattr(self, "unk_label") and allow_unk:
return self.lab2ind[self.unk_label]
elif hasattr(self, "unk_label") and not allow_unk:
raise KeyError(
f"Unknown label {label}, and explicitly "
"disallowed the use of the existing unk-label"
)
elif not hasattr(self, "unk_label") and allow_unk:
raise KeyError(
f"Cannot encode unknown label {label}. "
"You have not called add_unk() to add a special "
"unk-label for unknown labels."
)
else:
raise KeyError(
f"Couldn't and wouldn't encode unknown label {label}."
)
def encode_label_torch(self, label, allow_unk=True):
"""Encode label to torch.LongTensor.
Arguments
---------
label : hashable
Label to encode, must exist in the mapping.
allow_unk : bool
If given, that label is not in the label set
AND unk_label has been added with add_unk(),
allows encoding to unk_label's index.
Returns
-------
torch.LongTensor
Corresponding encoded int value.
Tensor shape [1].
"""
return torch.LongTensor([self.encode_label(label, allow_unk)])
def encode_sequence(self, sequence, allow_unk=True):
"""Encode a sequence of labels to list
Arguments
---------
sequence : iterable
Labels to encode, must exist in the mapping.
allow_unk : bool
If given, that label is not in the label set
AND unk_label has been added with add_unk(),
allows encoding to unk_label's index.
Returns
-------
list
Corresponding integer labels.
"""
self._assert_len()
return [self.encode_label(label, allow_unk) for label in sequence]
def encode_sequence_torch(self, sequence, allow_unk=True):
"""Encode a sequence of labels to torch.LongTensor
Arguments
---------
sequence : iterable
Labels to encode, must exist in the mapping.
allow_unk : bool
If given, that label is not in the label set
AND unk_label has been added with add_unk(),
allows encoding to unk_label's index.
Returns
-------
torch.LongTensor
Corresponding integer labels.
Tensor shape [len(sequence)].
"""
return torch.LongTensor(
[self.encode_label(label, allow_unk) for label in sequence]
)
def decode_torch(self, x):
"""Decodes an arbitrarily nested torch.Tensor to a list of labels.
Provided separately because Torch provides clearer introspection,
and so doesn't require try-except.
Arguments
---------
x : torch.Tensor
Torch tensor of some integer dtype (Long, int) and any shape to
decode.
Returns
-------
list
list of original labels
"""
self._assert_len()
decoded = []
# Recursively operates on the different dimensions.
if x.ndim == 1: # Last dimension!
for element in x:
decoded.append(self.ind2lab[int(element)])
else:
for subtensor in x:
decoded.append(self.decode_torch(subtensor))
return decoded
def decode_ndim(self, x):
"""Decodes an arbitrarily nested iterable to a list of labels.
This works for essentially any pythonic iterable (including torch), and
also single elements.
Arguments
---------
x : Any
Python list or other iterable or torch.Tensor or a single integer element
Returns
-------
list, Any
ndim list of original labels, or if input was single element,
output will be, too.
"""
self._assert_len()
# Recursively operates on the different dimensions.
try:
decoded = []
for subtensor in x:
decoded.append(self.decode_ndim(subtensor))
return decoded
except TypeError: # Not an iterable, bottom level!
return self.ind2lab[int(x)]
@mark_as_saver
def save(self, path):
"""Save the categorical encoding for later use and recovery
Saving uses a Python literal format, which supports things like
tuple labels, but is considered safe to load (unlike e.g. pickle).
Arguments
---------
path : str, Path
Where to save. Will overwrite.
"""
extras = self._get_extras()
self._save_literal(path, self.lab2ind, extras)
def load(self, path):
"""Loads from the given path.
CategoricalEncoder uses a Python literal format, which supports things
like tuple labels, but is considered safe to load (unlike e.g. pickle).
Arguments
---------
path : str, Path
Where to load from.
"""
if self.lab2ind:
clsname = self.__class__.__name__
logger.info(
f"Load called, but {clsname} is not empty. "
"Loaded data will overwrite everything. "
"This is normal if there is e.g. an unk label defined at init."
)
lab2ind, ind2lab, extras = self._load_literal(path)
self.lab2ind = lab2ind
self.ind2lab = ind2lab
self._set_extras(extras)
# If we're here, load was a success!
logger.debug(f"Loaded categorical encoding from {path}")
@mark_as_loader
def load_if_possible(self, path, end_of_epoch=False):
"""Loads if possible, returns a bool indicating if loaded or not.
Arguments
---------
path : str, Path
Where to load from.
end_of_epoch : bool
Whether the checkpoint was end-of-epoch or not.
Returns
-------
bool :
If load was successful.
Example
-------
>>> encoding_file = getfixture("tmpdir") / "encoding.txt"
>>> encoder = CategoricalEncoder()
>>> # The idea is in an experiment script to have something like this:
>>> if not encoder.load_if_possible(encoding_file):
... encoder.update_from_iterable("abcd")
... encoder.save(encoding_file)
>>> # So the first time you run the experiment, the encoding is created.
>>> # However, later, the encoding exists:
>>> encoder = CategoricalEncoder()
>>> encoder.expect_len(4)
>>> if not encoder.load_if_possible(encoding_file):
... assert False # We won't get here!
>>> encoder.decode_ndim(range(4))
['a', 'b', 'c', 'd']
"""
del end_of_epoch # Unused here.
try:
self.load(path)
except FileNotFoundError:
logger.debug(
f"Would load categorical encoding from {path}, "
"but file doesn't exist yet."
)
return False
except (ValueError, SyntaxError):
logger.debug(
f"Would load categorical encoding from {path}, "
"and file existed but seems to be corrupted or otherwise couldn't load."
)
return False
return True # If here, all good
    def expect_len(self, expected_len):
        """Specify the expected category count. If the category count observed
        during encoding/decoding does NOT match this, an error will be raised.

        This can prove useful to detect bugs in scenarios where the encoder is
        dynamically built using a dataset, but downstream code expects a
        specific category count (and may silently break otherwise).

        This can be called anytime and the category count check will only be
        performed during an actual encoding/decoding task.

        Arguments
        ---------
        expected_len : int
            The expected final category count, i.e. `len(encoder)`.

        Example
        -------
        >>> encoder = CategoricalEncoder()
        >>> encoder.update_from_iterable("abcd")
        >>> encoder.expect_len(3)
        >>> encoder.encode_label("a")
        Traceback (most recent call last):
            ...
        RuntimeError: .expect_len(3) was called, but 4 categories found
        >>> encoder.expect_len(4)
        >>> encoder.encode_label("a")
        0
        """
        # Only stored here; the actual check happens in _assert_len, which
        # runs at encode/decode time.
        self.expected_len = expected_len
    def ignore_len(self):
        """Specifies that category count shall be ignored at encoding/decoding
        time.

        Effectively inhibits the ".expect_len was never called" warning.

        Prefer :py:meth:`~CategoricalEncoder.expect_len` when the category count
        is known."""
        # None is the sentinel meaning "skip the check" in _assert_len.
        self.expected_len = None
def _assert_len(self):
"""If `expect_len` was called, then check if len(self) matches the
expected value. If it does not, raise a RuntimeError.
If neither `expect_len` or `ignore_len` were ever called, warn once."""
if hasattr(self, "expected_len"):
# skip when ignore_len() was called
if self.expected_len is None:
return
real_len = len(self)
if real_len != self.expected_len:
raise RuntimeError(
f".expect_len({self.expected_len}) was called, "
f"but {real_len} categories found"
)
else:
logger.warning_once(
f"{self.__class__.__name__}.expect_len was never called: "
f"assuming category count of {len(self)} to be correct! "
"Sanity check your encoder using `.expect_len`. "
"Ensure that downstream code also uses the correct size. "
"If you are sure this does not apply to you, use `.ignore_len`."
)
self.ignore_len()
return
def _get_extras(self):
"""Override this to provide any additional things to save
Call super()._get_extras() to get the base extras
"""
extras = {"starting_index": self.starting_index}
if hasattr(self, "unk_label"):
extras["unk_label"] = self.unk_label
return extras
def _set_extras(self, extras):
"""Override this to e.g. load any extras needed
Call super()._set_extras(extras) to set the base extras
"""
if "unk_label" in extras:
self.unk_label = extras["unk_label"]
self.starting_index = extras["starting_index"]
@staticmethod
def _save_literal(path, lab2ind, extras):
"""Save which is compatible with _load_literal"""
with open(path, "w", encoding="utf-8") as f:
for label, ind in lab2ind.items():
f.write(
repr(label)
+ CategoricalEncoder.VALUE_SEPARATOR
+ str(ind)
+ "\n"
)
f.write(CategoricalEncoder.EXTRAS_SEPARATOR)
for key, value in extras.items():
f.write(
repr(key)
+ CategoricalEncoder.VALUE_SEPARATOR
+ repr(value)
+ "\n"
)
f.flush()
@staticmethod
def _load_literal(path):
"""Load which supports Python literals as keys.
This is considered safe for user input, as well (unlike e.g. pickle).
"""
lab2ind = {}
ind2lab = {}
extras = {}
with open(path, encoding="utf-8") as f:
# Load the label to index mapping (until EXTRAS_SEPARATOR)
for line in f:
if line == CategoricalEncoder.EXTRAS_SEPARATOR:
break
literal, ind = line.strip().split(
CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1
)
ind = int(ind)
label = ast.literal_eval(literal)
lab2ind[label] = ind
ind2lab[ind] = label
# Load the extras:
for line in f:
literal_key, literal_value = line.strip().split(
CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1
)
key = ast.literal_eval(literal_key)
value = ast.literal_eval(literal_value)
extras[key] = value
return lab2ind, ind2lab, extras
class TextEncoder(CategoricalEncoder):
"""CategoricalEncoder subclass which offers specific methods for encoding text and handle
special tokens for training of sequence to sequence models.
In detail, aside special <unk> token already present in CategoricalEncoder
for handling out-of-vocab tokens here special methods to handle
<bos> beginning of sequence and <eos> tokens are defined.
Note: update_from_iterable and update_from_didataset here have as default
sequence_input=True because it is assumed that this encoder is used on
iterables of strings: e.g.
>>> from speechbrain.dataio.encoder import TextEncoder
>>> dataset = [["encode", "this", "textencoder"], ["foo", "bar"]]
>>> encoder = TextEncoder()
>>> encoder.update_from_iterable(dataset)
>>> encoder.expect_len(5)
>>> encoder.encode_label("this")
1
>>> encoder.add_unk()
5
>>> encoder.expect_len(6)
>>> encoder.encode_sequence(["this", "out-of-vocab"])
[1, 5]
>>>
Two methods can be used to add <bos> and <eos> to the internal dicts:
insert_bos_eos, add_bos_eos.
>>> encoder.add_bos_eos()
>>> encoder.expect_len(8)
>>> encoder.lab2ind[encoder.eos_label]
7
>>>
add_bos_eos adds the special tokens at the end of the dict indexes
>>> encoder = TextEncoder()
>>> encoder.update_from_iterable(dataset)
>>> encoder.insert_bos_eos(bos_index=0, eos_index=1)
>>> encoder.expect_len(7)
>>> encoder.lab2ind[encoder.eos_label]
1
>>>
insert_bos_eos allows to specify whose index will correspond to each of them.
Note that you can also specify the same integer encoding for both.
Four methods can be used to prepend <bos> and append <eos>.
prepend_bos_label and append_eos_label add respectively the <bos> and <eos>
string tokens to the input sequence
>>> words = ["foo", "bar"]
>>> encoder.prepend_bos_label(words)
['<bos>', 'foo', 'bar']
>>> encoder.append_eos_label(words)
['foo', 'bar', '<eos>']
prepend_bos_index and append_eos_index add respectively the <bos> and <eos>
indexes to the input encoded sequence.
>>> words = ["foo", "bar"]
>>> encoded = encoder.encode_sequence(words)
>>> encoder.prepend_bos_index(encoded)
[0, 3, 4]
>>> encoder.append_eos_index(encoded)
[3, 4, 1]
"""
    def handle_special_labels(self, special_labels):
        """Handles special labels such as bos and eos."""
        super().handle_special_labels(special_labels)
        # NOTE: bos_label and eos_label are not necessarily set at all!
        # This is because None is a suitable value.
        # So the test is: hasattr(self, "bos_label")
        # rather than self.bos_label is not None
        # Same thing with unk, see base class.
        if "bos_label" in special_labels and "eos_label" in special_labels:
            # NOTE(review): the values given for bos_label/eos_label are
            # passed on as *indices*, while the label strings themselves are
            # fixed to "<bos>"/"<eos>" — confirm this is the intended init
            # contract for these kwargs.
            self.insert_bos_eos(
                bos_label="<bos>",
                eos_label="<eos>",
                bos_index=special_labels["bos_label"],
                eos_index=special_labels["eos_label"],
            )
        elif "bos_label" in special_labels or "eos_label" in special_labels:
            raise TypeError("Only BOS or EOS specified. Need both for init.")
def update_from_iterable(self, iterable, sequence_input=True):
"""Change default for sequence_input to True."""
return super().update_from_iterable(iterable, sequence_input)
def update_from_didataset(self, didataset, output_key, sequence_input=True):
"""Change default for sequence_input to True."""
return super().update_from_didataset(
didataset, output_key, sequence_input
)
def limited_labelset_from_iterable(
self, iterable, sequence_input=True, n_most_common=None, min_count=1
):
"""Change default for sequence_input to True."""
return super().limited_labelset_from_iterable(
iterable, sequence_input=True, n_most_common=None, min_count=1
)
def add_bos_eos(
self,
bos_label=DEFAULT_BOS,
eos_label=DEFAULT_EOS,
):
"""Add sentence boundary markers in the label set.
If the beginning-of-sentence and end-of-sentence markers
are the same, will just use one sentence-boundary label.
This method adds to the end of the index, rather than at the beginning,
like insert_bos_eos.
Arguments
---------
bos_label : hashable
Beginning-of-sentence label, any label.
eos_label : hashable
End-of-sentence label, any label. If set to the same label as
bos_label, will just use one sentence-boundary label.
"""
if bos_label == eos_label:
logger.debug(
"BOS and EOS labels are the same so using just one sentence "
"boundary label"
)
self.add_label(bos_label)
else:
self.add_label(bos_label)
self.add_label(eos_label)
self.bos_label = bos_label
self.eos_label = eos_label
def insert_bos_eos(
self,
bos_label=DEFAULT_BOS,
eos_label=DEFAULT_EOS,
bos_index=0,
eos_index=None,
):
"""Insert sentence boundary markers in the label set.
If the beginning-of-sentence and end-of-sentence markers
are the same, will just use one sentence-boundary label.
Arguments
---------
bos_label : hashable
Beginning-of-sentence label, any label
eos_label : hashable
End-of-sentence label, any label. If set to the same label as
bos_label, will just use one sentence-boundary label.
bos_index : int
Where to insert bos_label. eos_index = bos_index + 1
eos_index : optional, int
Where to insert eos_label. Default: eos_index = bos_index + 1
"""
if bos_label == eos_label:
logger.debug(
"BOS and EOS labels are the same so using just one sentence "
"boundary label"
)
self.insert_label(bos_label, bos_index)
else:
self.insert_label(bos_label, bos_index)
if eos_index is None:
logger.debug("EOS label not specified, using BOS label + 1")
self.insert_label(eos_label, bos_index + 1)
else:
self.insert_label(eos_label, eos_index)
self.bos_label = bos_label
self.eos_label = eos_label
def get_bos_index(self):
"""Returns the index to which blank encodes"""
if not hasattr(self, "bos_label"):
raise RuntimeError("BOS label is not set!")