A ConvNet that works well with 20 samples: Wavelet Scattering

lihan yao
Towards Data Science
10 min read · Jul 25, 2019


Scatter Transform architecture, from J. Bruna and S. Mallat, Invariant Scattering Convolution Networks (2012)

Often in data-constrained scenarios, scene comprehension has to occur with only a few time series observations, whether audio, visual, or even radar. In this article we tackle such a setting using a surprisingly underrated technique called wavelet scattering.

Wavelet scattering (or scatter transform) generates a representation that’s invariant to data rotation/translation and stable to deformations of your data. Uninformative variations in your data are discarded — e.g. an audio sample time-shifted by various amounts. Information for downstream tasks like classification is preserved. Wavelet scattering requires no training and works great with low data.

Its main computation is convolution, making it fast and applicable to images and 1D signals. We focus on 1D signals in this article. We will retrace findings by the signal processing community and relate them to modern machine learning concepts. I show that, yes, we can do well without learning, using just 20 samples. Recreate the experiments and illustrations in this article with the colab notebook in this link.

Wavelets

A wavelet can be convolved with a signal in the same sense that filters can. I think of convolution as the continuous analog of an inner product, where a large activation (as it is commonly called in ML) or wavelet coefficient signals similarity between the two continuous objects. By convolving elements of a dictionary with the signal under inspection, we capture local, spatial dependencies.

Convolution is a pivotal computation in the emergence of deep learning because it is extremely fast. The wavelet scattering implementation used by this article calls a deep learning backend solely for the efficient convolution! Kymatio is a great Python package built by passionate researchers that implements wavelet scattering, leveraging the PyTorch framework.
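As a first taste, here is a minimal sketch of how a 1D scattering operator is typically constructed and applied with Kymatio. Argument names follow recent versions of the package (older releases import Scattering1D from kymatio directly), so treat this as a sketch rather than version-exact usage.

```python
import torch
from kymatio.torch import Scattering1D

T = 2 ** 13   # signal length
J = 6         # maximum log-scale: averaging window of width 2**J samples
Q = 8         # wavelets per octave

# Build the operator once; it holds a fixed, pre-computed filter bank.
scattering = Scattering1D(J=J, shape=T, Q=Q)

x = torch.randn(1, T)   # a batch containing one toy "signal"
Sx = scattering(x)      # scatter coefficients, computed with no training
print(Sx.shape)         # (batch, n_coefficient_rows, T / 2**J)
```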

Real and imaginary components of the Morlet wavelet, plotted in the time domain, from M. Adamczyk et al., Automatic Sleep Spindle Detection and Genetic Influence Estimation Using Continuous Wavelet Transform (2015)

The basic building block of wavelet scattering is the Morlet wavelet. It is a Gaussian-windowed sinusoid with deep connections to mammalian hearing and vision. By convolving x with wavelets ψᵥ indexed by different frequency locations v, we obtain the wavelet transform of x: the set of coefficients

{ x ∗ ψᵥ }ᵥ

When the wavelet’s sine component has room to dilate (the sine wave ‘slowing’ its oscillation), it decomposes the signal at decorrelated scales. That is good for revealing the signal’s frequency structure, but it does so over a longer time range. The consequence is that a wider Gaussian window trades temporal resolution for increased frequency resolution (itself a consequence of Heisenberg’s Uncertainty Principle). In practice, the width of the Gaussian window that tapers the sine wave is an important parameter [M. Cohen 2018].
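To make the construction concrete, here is a short numpy sketch of a Morlet-style wavelet as a Gaussian window tapering a complex sinusoid. The small zero-mean correction term of the true Morlet wavelet is omitted for brevity, and the freq and sigma values are purely illustrative.

```python
import numpy as np

def morlet(t, freq=1.0, sigma=1.0):
    """Morlet-style wavelet: a Gaussian window tapering a complex sinusoid."""
    gauss = np.exp(-t**2 / (2 * sigma**2))    # Gaussian window of width sigma
    carrier = np.exp(2j * np.pi * freq * t)   # complex sinusoid at frequency `freq`
    return gauss * carrier

t = np.linspace(-4, 4, 1024)
psi_narrow = morlet(t, freq=2.0, sigma=0.5)  # narrow window: fine time localization
psi_wide = morlet(t, freq=2.0, sigma=2.0)    # wide window: fine frequency localization
```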

Wavelet Scattering

The historic context of wavelet scattering starts with the Fourier transform, the canonical signal processing technique. One shortcoming of the Fourier representation is its instability to signal deformations at high frequency. For a signal x perturbed slightly by a high frequency deformation into x̃, their spectrogram representations look very different (large ‖FFT(x) − FFT(x̃)‖) even if they remain similar signals to the human eye. This instability is due to the sine wave’s inability to localize information in time, since sine itself has non-localized support.
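This instability is easy to reproduce. The toy numpy experiment below (with made-up numbers) dilates a high-frequency sinusoid by 1% and measures the relative distance between the two magnitude spectra:

```python
import numpy as np

t = np.linspace(0, 1, 4096, endpoint=False)
freq = 400                                    # high-frequency content
x = np.sin(2 * np.pi * freq * t)
x_def = np.sin(2 * np.pi * freq * 1.01 * t)   # a 1% dilation of the time axis

spec = np.abs(np.fft.rfft(x))
spec_def = np.abs(np.fft.rfft(x_def))
print(np.linalg.norm(spec - spec_def) / np.linalg.norm(spec))
# ~1.4: the two spectra are nearly orthogonal, despite a tiny deformation
```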

November, Golden Gardens credit: u/purebredcrab at reddit.com/r/analog

The wavelet transform fixes this by decomposing the signal with a family of wavelets at various dilations, where every wavelet has localized support (flattening out eventually, like the Morlet wavelet). The resulting wavelet representation localizes high frequency components of the signal. Yet because the wavelet operator commutes with translations, the resulting representation is translation covariant: shifting a signal also shifts its wavelet coefficients. This makes comparison between translated signals difficult, and translation invariance is key to tasks like classification. How do we achieve a signal representation Φ(x) that is translation invariant, stable under deformations, and offers good structural information at all frequencies?

Wavelet scattering builds a signal representation Φ(x) with a redundant dictionary of Morlet wavelets. While the space of signals X can be really high dimensional, the transform forms a kernel metric over the space of signals, inducing a lower dimensional manifold. Watch Stéphane Mallat discuss the manifold interpretation with visualization.

Many readers will have trained a convolutional neural network to encode an image into a latent manifold Z, whose code/latent representation is then used for classification or structure discovery; wavelet scattering does something analogous. It encodes the dataset X while discarding uninformative variability in X: translation, rotation, and scaling (the action of these groups).

The key benefits of transforming a signal by Φ [J. Bruna and S. Mallat, 2013] are that

Φ is invariant to signal translation.

Denote by xₜ a signal identical to x except translated in time; then Φ(x) = Φ(xₜ) (checked numerically in the sketch after this list).

Φ is stable under signal deformations.

I.e. Φ is Lipschitz continuous to deformations: the difference between the scatter representations of a signal and its deformed version grows at most linearly with the size of the deformation. A deformation can be some local displacement/distortion (or a ridiculous amount of distortion, as a later example shows). For a Lipschitz constant C > 0 and a displacement field τ(u) whose deformation creates x̃,

‖Φ(x) − Φ(x̃)‖ ≤ C ‖x‖ supᵤ|∇τ(u)|

where ‖x‖² = ∫ |x(u)|² du and supᵤ|∇τ(u)| is the global deformation amplitude.

Φ does not require learning.

The priors introduced by wavelet scattering are nice enough that its performance often makes learning redundant; plus it comes with interpretable features and outputs. In data-constrained scenarios, if comparable data is publicly available, a nice plan is to pipe your small dataset through a pretrained model. But in the difficult situation where your dataset is small and unique, consider wavelet scattering as an initialization for ConvNets and other models. I suspect the future of ‘data-constrained learning’ will be in synergizing predefined filters with learned filters.
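Here is a quick numerical check of the translation-invariance property, under the same assumed Kymatio API as before. Note the invariance is local: it holds for shifts well below the averaging scale 2**J.

```python
import torch
from kymatio.torch import Scattering1D

T, J, Q = 2 ** 12, 8, 8
scattering = Scattering1D(J=J, shape=T, Q=Q)

x = torch.randn(1, T)
x_shift = torch.roll(x, shifts=64, dims=-1)   # translate by 64 << 2**J samples

Sx, Sx_shift = scattering(x), scattering(x_shift)
print((torch.norm(Sx - Sx_shift) / torch.norm(Sx)).item())
# small relative difference: Phi barely moves under translation
```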

The figure below illustrates stability under deformations. Left: we applied the scatter transform to the voice of a speaker saying ‘zero’. The scatter representation consists of the coefficients derived from the averaging/low pass filter, order 1 wavelets, and order 2 wavelets. Right: after applying a displacement field that has mostly masked the structure of the original signal with a sine wave, Φ(x̃) is barely affected; the deformation’s effect has been linearized by Φ’s transformation.

Convolution Network A scatter representation consists of order 0, 1, and 2 coefficients, which are generated by composing wavelets in different sequences. Multiple wavelets composed together capture high frequency structure, e.g. the 2nd order coefficients display wave interference (heard as dissonance in music) within the signal. We now give a brief tour of how the computation is done (the complex modulus and averaging filter are explained more closely in the appendix).

At the m-th layer, the set of n pre-defined wavelets { ψᵥ₁ , … , ψᵥₙ } is convolved with coefficients from the (m−1)-th layer, which were produced by a previous wavelet. So in the figure above, a row of order 2 coefficients results from convolving with ψᵥ₁ in layer 1, then ψᵥ₂ in layer 2.

Altogether we denote a sequence of wavelets as a length-M path p = ( ψᵥ¹ , … , ψᵥᴹ ). By convolving and then taking the complex modulus |⋅| at each step, we obtain an ordered product of operators, denoted U[ v₁, v₂, …, vₘ]. For a path of length 2,

U[ v₁, v₂ ]x =| |x ∗ ψᵥ₁| ∗ ψᵥ₂|

To finally extract coefficients like a row in the above figure, apply an averaging filter φ, and call it S[ v₁, v₂, …, vₘ]:

S[ v₁, v₂ ]x = U[ v₁, v₂ ]x ∗ φ(u)

Starting at the data x as root, the set of rooted paths in this wavelet tree specifies Φ(x). In practice, paths up to length 2 are enough to extract all relevant frequency information from naturally occurring data. So the distinguishing aspect of the scatter transform over the wavelet transform is exactly the order 2 coefficients.
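To make the path computation concrete, here is a from-scratch numpy sketch of one order 2 coefficient row. It is a toy version: real implementations use carefully designed filter banks and subsample after averaging, and the morlet helper is the illustrative one from earlier.

```python
import numpy as np

def morlet(t, freq, sigma):
    # Illustrative Morlet-style wavelet (zero-mean correction omitted)
    return np.exp(-t**2 / (2 * sigma**2)) * np.exp(2j * np.pi * freq * t)

def conv(x, h):
    # Circular convolution via FFT, a simplification of real implementations
    return np.fft.ifft(np.fft.fft(x) * np.fft.fft(h, len(x)))

t = np.linspace(-4, 4, 2048)
x = np.random.randn(2048)                    # stand-in signal
psi1 = morlet(t, freq=8.0, sigma=0.25)       # first-layer wavelet, psi_v1
psi2 = morlet(t, freq=2.0, sigma=1.0)        # second-layer wavelet, psi_v2
phi = np.exp(-t**2 / 2); phi /= phi.sum()    # Gaussian averaging filter

U1 = np.abs(conv(x, psi1))                   # U[v1]x     = |x * psi_v1|
U2 = np.abs(conv(U1, psi2))                  # U[v1,v2]x  = ||x * psi_v1| * psi_v2|
S2 = np.real(conv(U2, phi))                  # S[v1,v2]x  = U[v1,v2]x * phi
```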

A 3-layer scatter network. Notice that to extract any coefficients, we must apply the averaging filter φ. Diagram from J. Bruna and S. Mallat, Invariant Scattering Convolution Networks (2012)

Example Task Let’s apply the scatter transform to a real dataset. The Free Spoken Digit Dataset (FSDD) has 2000 recordings from 4 speakers, each speaker saying each digit 50 times. After convolving with wavelets, I plotted a 2D t-SNE projection of the resulting high dimensional coefficients. Each point is an audio sample’s scatter representation Φ(x), colored by digit class.

Blue points belong to class zero? Red is one? Maybe a legend would have been helpful Lihan…

In time series classification tasks, rival techniques include dynamic time warping, Hidden Markov Models, and neural models like LSTMs and RNNs. Let us see how wavelet scattering performs, even under extremely data-constrained conditions. As preprocessing, we Z-normalize the audio samples. Then audio samples are 0-padded to uniform length, and ~20 audio samples are dropped to ensure scatter representations have the same dimensions. Continuing in the practice of not learning, we use 3-nearest neighbors as the classifier. 3NN(Φ(x)) predicts digit i if at least 2 of the 3 training-set neighbors closest to Φ(x) belong to digit i, or if the single closest neighbor does (the tie-breaking case when all three neighbors disagree).

With 1981 samples, we find that 3-fold classification accuracy is 91.5%; with a training set of only 20 samples and all other observations in the held-out set, we still reach 49%. Not bad, considering I only tried different values for the Q hyperparameter (number of wavelets per octave) and left the other dials alone. Try different values for the scale parameter J in the attached notebook.
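For reference, the experiment condenses to a pipeline like the sketch below. The random signals and labels are stand-ins for the real, preprocessed FSDD data (see the notebook for the full version), and the Kymatio API is as assumed earlier.

```python
import numpy as np
import torch
from kymatio.torch import Scattering1D
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Stand-ins for the real data: replace with z-normalized, zero-padded FSDD
# recordings and their digit labels.
signals = np.random.randn(200, 2 ** 13).astype(np.float32)
labels = np.repeat(np.arange(10), 20)

scattering = Scattering1D(J=8, shape=signals.shape[1], Q=12)
S = scattering(torch.from_numpy(signals))
X = S.reshape(S.shape[0], -1).numpy()     # flatten coefficients per sample

# 20 training samples, everything else held out, mirroring the 49% experiment
X_tr, X_te, y_tr, y_te = train_test_split(X, labels, train_size=20, stratify=labels)
clf = KNeighborsClassifier(n_neighbors=3).fit(X_tr, y_tr)
print(clf.score(X_te, y_te))
```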

I’m finished motivating wavelet scattering. At this point, I would love it if you started exploring the technique on your own, maybe starting with my digit audio jupyter notebook. The appendix goes into key properties surrounding wavelet scattering, like how to leverage Φ(x) when x is an image, and practical considerations.

Credits They say it takes a dev team to raise a Medium article. This would not be possible without Kymatio researchers and developers, including Edouard, Vincent, and Michael. I’m grateful for the opportunity to explain and leverage this amazing technology at Geometric Data Analytics, under the guidance of Paul Bendich, in collaboration with government agencies.

Appendix

Actually, given a path p = ( ψᵥ₁ , ψᵥ₂ ), extracting the 2nd order coefficients also requires the complex modulus |⋅| and averaging filter φ:

S[ v₁, v₂ ]x(u) = ||x ∗ ψᵥ₁| ∗ ψᵥ₂| ∗ φ(u)

Stéphane Mallat explains the predecessors to wavelet scattering, then the need for the complex modulus and averaging filter.

The complex modulus is the secret sauce of wavelet scattering. |⋅| is a non-linearity applied to coefficients that makes them 1) stable to diffeomorphisms and 2) stable in the Euclidean metric L². The complex modulus makes Φ(x) Lipschitz continuous with respect to deformations. Visually, the modulus forms an upper envelope over the coefficients.
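A tiny numpy illustration of that ‘upper envelope’ remark: convolving an oscillation with a complex wavelet tuned to its frequency and taking the modulus discards the fast phase and traces the slow amplitude (all numbers here are illustrative):

```python
import numpy as np

t = np.linspace(0, 1, 2048, endpoint=False)
x = np.sin(2 * np.pi * 50 * t) * np.exp(-4 * t)   # decaying 50 Hz oscillation

# Complex Morlet-style wavelet tuned near 50 Hz, at the signal's sampling rate
tau = np.arange(-200, 201) / 2048
psi = np.exp(-tau**2 / (2 * 0.01**2)) * np.exp(2j * np.pi * 50 * tau)
psi /= np.abs(psi).sum()

coeffs = np.convolve(x, psi, mode="same")
envelope = np.abs(coeffs)   # smooth envelope ~ exp(-4t); the 50 Hz phase is gone
```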

Whenever φ is applied to the coefficients |x ∗ ψᵥ|, we accept the removal of the high frequency component in order to localize the bin covering the lower spectrum (discarding the phase of x ∗ ψᵥ). To preserve high frequency information, the scatter transform propagates it further down the wavelet path p before localizing its distinct frequency bin via φ.

Wavelet Scattering for Images The 2D Morlet wavelets may be convolved in two dimensions with image data. In addition to the dilation parameter, the 2D wavelets can be rotated.

For example, researchers have found wavelet scattering to excel at the classification of texture images, where the major obstacle is finding a representation that is ‘zoom’, translation, and rotation invariant; as soon as these group actions are factored out of the sample generation process, texture classification is straightforward. It is precisely these variabilities that wavelet scattering discounts.
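Kymatio also ships a 2D transform; a minimal sketch, under the same version caveat as the 1D examples, where L is the number of wavelet rotation angles:

```python
import torch
from kymatio.torch import Scattering2D

# J sets the averaging scale (2**J pixels); L is the number of wavelet rotations
scattering = Scattering2D(J=2, shape=(32, 32), L=8)

img = torch.randn(1, 1, 32, 32)   # a batch with one grayscale toy "image"
S_img = scattering(img)
print(S_img.shape)                # coefficient maps downsampled by 2**J per side
```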

Parametrizing Wavelet Dictionary Recall from the Wavelets section that a wider Gaussian window trades temporal resolution for increased frequency resolution, and that this window width is an important parameter.

A wavelet dictionary in this article is parametrized by J, the maximum log scale, and Q, the number of wavelets per octave. Increasing J trades temporal resolution for frequency resolution by increasing the maximum wavelet width by a factor of 2^J. A large J means a ‘long’ wavelet that emits fewer wavelet coefficients, leading to shorter scatter representations.

“For audio signals, it is often beneficial to have a large value for Q (between 4 and 16), since these signals are often highly oscillatory and are better localized in frequency than they are in time.” — kymatio tutorial
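One way to build intuition for J is to sweep it and watch the output shape change (same assumed Kymatio API as before):

```python
import torch
from kymatio.torch import Scattering1D

T = 2 ** 13
x = torch.randn(1, T)

for J in (4, 6, 8, 10):
    Sx = Scattering1D(J=J, shape=T, Q=8)(x)
    # Larger J: longer wavelets, coarser time axis (roughly T / 2**J columns)
    print(J, tuple(Sx.shape))
```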

For signal processing practitioners, scatter transform’s wavelets form a non-orthogonal, redundant dictionary.

Signal energy is conserved The composition of wavelet convolutions preserves the signal norm, or essentially the signal energy. This follows from the fact that the wavelet transform is a contracting and invertible operator. Energy within the original signal is converted to ‘coefficient energy’ by the averaging filter after every convolution along a path. For the Caltech-101 image database, [J. Bruna and S. Mallat, 2013] show that ~99% of the signal energy is conserved by networks of depth 3, for any value of J.
