@@ -194,26 +194,38 @@ def test_mel_filter_bank_kaldi(self):
194
194
triangularize_in_mel_space = True ,
195
195
)
196
196
# fmt: off
197
+ # here the expected values from torchaudio.compliance.kaldi.get_mel_banks
198
+ # note that we compute values in float64 while they do it in float32
197
199
expected = np .array (
198
- [[0.0000 , 0.0000 , 0.0000 , 0.0000 ],
199
- [0.6086 , 0.0000 , 0.0000 , 0.0000 ],
200
- [0.8689 , 0.1311 , 0.0000 , 0.0000 ],
201
- [0.4110 , 0.5890 , 0.0000 , 0.0000 ],
202
- [0.0036 , 0.9964 , 0.0000 , 0.0000 ],
203
- [0.0000 , 0.6366 , 0.3634 , 0.0000 ],
204
- [0.0000 , 0.3027 , 0.6973 , 0.0000 ],
205
- [0.0000 , 0.0000 , 0.9964 , 0.0036 ],
206
- [0.0000 , 0.0000 , 0.7135 , 0.2865 ],
207
- [0.0000 , 0.0000 , 0.4507 , 0.5493 ],
208
- [0.0000 , 0.0000 , 0.2053 , 0.7947 ],
209
- [0.0000 , 0.0000 , 0.0000 , 0.9752 ],
210
- [0.0000 , 0.0000 , 0.0000 , 0.7585 ],
211
- [0.0000 , 0.0000 , 0.0000 , 0.5539 ],
212
- [0.0000 , 0.0000 , 0.0000 , 0.3599 ],
213
- [0.0000 , 0.0000 , 0.0000 , 0.1756 ]]
200
+ [
201
+ [0.0000000000000000 , 0.0000000000000000 , 0.0000000000000000 , 0.0000000000000000 ],
202
+ [0.6457883715629578 , 0.0000000000000000 , 0.0000000000000000 , 0.0000000000000000 ],
203
+ [0.8044781088829041 , 0.1955219060182571 , 0.0000000000000000 , 0.0000000000000000 ],
204
+ [0.3258901536464691 , 0.6741098165512085 , 0.0000000000000000 , 0.0000000000000000 ],
205
+ [0.0000000000000000 , 0.9021250009536743 , 0.0978749766945839 , 0.0000000000000000 ],
206
+ [0.0000000000000000 , 0.5219038724899292 , 0.4780961275100708 , 0.0000000000000000 ],
207
+ [0.0000000000000000 , 0.1771058291196823 , 0.8228941559791565 , 0.0000000000000000 ],
208
+ [0.0000000000000000 , 0.0000000000000000 , 0.8616894483566284 , 0.1383105516433716 ],
209
+ [0.0000000000000000 , 0.0000000000000000 , 0.5710380673408508 , 0.4289619624614716 ],
210
+ [0.0000000000000000 , 0.0000000000000000 , 0.3015440106391907 , 0.6984559893608093 ],
211
+ [0.0000000000000000 , 0.0000000000000000 , 0.0503356307744980 , 0.9496643543243408 ],
212
+ [0.0000000000000000 , 0.0000000000000000 , 0.0000000000000000 , 0.8150880336761475 ],
213
+ [0.0000000000000000 , 0.0000000000000000 , 0.0000000000000000 , 0.5938932299613953 ],
214
+ [0.0000000000000000 , 0.0000000000000000 , 0.0000000000000000 , 0.3851676583290100 ],
215
+ [0.0000000000000000 , 0.0000000000000000 , 0.0000000000000000 , 0.1875794380903244 ],
216
+ ],
217
+ dtype = np .float64 ,
214
218
)
215
219
# fmt: on
216
- self .assertTrue (np .allclose (mel_filters , expected , atol = 5e-5 ))
220
+
221
+ # kaldi implementation does not compute values for last fft bin
222
+ # indeed, they enforce max_frequency <= sampling_rate / 2 and
223
+ # therefore they know that last fft bin filter bank values will be all 0
224
+ # and pad after with zeros
225
+ # to comply with our API for `mel_filter_bank`, we need to also pad here
226
+ expected = np .pad (expected , ((0 , 1 ), (0 , 0 )))
227
+
228
+ self .assertTrue (np .allclose (mel_filters , expected ))
217
229
218
230
def test_mel_filter_bank_slaney_norm (self ):
219
231
mel_filters = mel_filter_bank (
@@ -369,7 +381,7 @@ def test_spectrogram_integration_test(self):
369
381
self .assertTrue (np .allclose (spec [:64 , 400 ], expected ))
370
382
371
383
mel_filters = mel_filter_bank (
372
- num_frequency_bins = 256 ,
384
+ num_frequency_bins = 257 ,
373
385
num_mel_filters = 400 ,
374
386
min_frequency = 20 ,
375
387
max_frequency = 8000 ,
@@ -379,8 +391,6 @@ def test_spectrogram_integration_test(self):
379
391
triangularize_in_mel_space = True ,
380
392
)
381
393
382
- mel_filters = np .pad (mel_filters , ((0 , 1 ), (0 , 0 )))
383
-
384
394
spec = spectrogram (
385
395
waveform ,
386
396
window_function (400 , "povey" , periodic = False ),
@@ -510,7 +520,7 @@ def test_spectrogram_batch_integration_test(self):
510
520
self .assertTrue (np .allclose (spec_list [2 ][:64 , 400 ], expected3 ))
511
521
512
522
mel_filters = mel_filter_bank (
513
- num_frequency_bins = 256 ,
523
+ num_frequency_bins = 257 ,
514
524
num_mel_filters = 400 ,
515
525
min_frequency = 20 ,
516
526
max_frequency = 8000 ,
@@ -520,8 +530,6 @@ def test_spectrogram_batch_integration_test(self):
520
530
triangularize_in_mel_space = True ,
521
531
)
522
532
523
- mel_filters = np .pad (mel_filters , ((0 , 1 ), (0 , 0 )))
524
-
525
533
spec_list = spectrogram_batch (
526
534
waveform_list ,
527
535
window_function (400 , "povey" , periodic = False ),
0 commit comments