oswaldoludwig/Pruning-pre-trained-models-using-evolutionary-computation
Pruning pre-trained models using evolutionary computation

This repository contains scripts to prune Wav2vec2 using a neuroevolution-based method. More details about this method can be found in the paper Compressing Wav2vec2 for Embedded Applications. If you publish work that uses ideas or code snippets from this repository, please cite this paper.

The algorithm applies a customized Genetic Algorithm (GA) designed specifically for the combinatorial optimization problem that pruning poses. Evaluating the GA population means running many copies of the Wav2vec2 decoder in parallel, using multiprocessing on a computer grid. The method can be applied to any pre-trained AI model, with the aim of preserving as much information as possible from its pre-training. The scripts are fully commented and indicate how parallel processing is configured and how the Wav2vec2 tensors are pruned.
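As a rough illustration of the idea above, the following is a minimal sketch of a GA whose population evaluation can be parallelized with `multiprocessing`. It is not the code from distributed_GA.py: the constants, the toy `fitness` function (which stands in for decoding a validation set with a pruned model), and the simple elitist selection and `crossover` are all assumptions made for the example.

```python
import random
from multiprocessing import Pool

# All constants below are hypothetical values chosen for illustration.
N_BLOCKS = 16     # prunable blocks in one layer
N_KEEP = 8        # blocks kept after pruning
POP_SIZE = 12
GENERATIONS = 5

# Toy importance scores standing in for a real evaluation of the pruned model.
IMPORTANCE = [((7 * i) % N_BLOCKS) / N_BLOCKS for i in range(N_BLOCKS)]


def fitness(chromosome):
    """Toy fitness: total importance of the blocks that are kept."""
    return sum(IMPORTANCE[i] for i in chromosome)


def random_chromosome(rng):
    """A chromosome is a sorted list of distinct block indexes to keep."""
    return sorted(rng.sample(range(N_BLOCKS), N_KEEP))


def crossover(parent_a, parent_b, rng):
    """Draw N_KEEP distinct indexes from the union of both parents."""
    candidates = sorted(set(parent_a) | set(parent_b))
    return sorted(rng.sample(candidates, N_KEEP))


def evolve(map_fn=map, seed=0):
    """Run the GA; map_fn decides whether evaluation is serial or parallel."""
    rng = random.Random(seed)
    population = [random_chromosome(rng) for _ in range(POP_SIZE)]
    best_per_generation = []
    for _ in range(GENERATIONS):
        # Population evaluation is the expensive step, so it goes through map_fn.
        scores = list(map_fn(fitness, population))
        ranked = [c for _, c in sorted(zip(scores, population), reverse=True)]
        best_per_generation.append(fitness(ranked[0]))
        elite = ranked[: POP_SIZE // 2]  # the best half survives unchanged
        children = [
            crossover(*rng.sample(elite, 2), rng)
            for _ in range(POP_SIZE - len(elite))
        ]
        population = elite + children
    return best_per_generation


if __name__ == "__main__":
    # Parallel evaluation on local cores; a grid would replace this Pool.
    with Pool(processes=2) as pool:
        print(evolve(map_fn=pool.map))
```

Passing the mapping function as a parameter keeps the GA logic independent of how evaluation is distributed, so the same loop can run serially for debugging or in parallel on a larger machine.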

The run_distributed_GA.sh script shows how to call distributed_GA.py, the master script that runs the GA. The master script automatically generates and calls population_eval.sh, a shell script containing the commands that parallelize the evaluation of the GA population on a computer grid.
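The pattern of a master script that writes a shell script for population evaluation might look like the sketch below. This is an assumption-laden illustration, not the contents of distributed_GA.py or population_eval.sh: the evaluation script name `evaluate_individual.py`, the chromosome file names, and the choice of one `srun` job step per individual are all hypothetical.

```python
import shlex


def build_population_script(chromosome_files,
                            eval_script="evaluate_individual.py",  # hypothetical name
                            cpus_per_task=1):
    """Return the text of a shell script that evaluates every individual."""
    lines = [
        "#!/bin/bash",
        "# Auto-generated: one background job step per individual.",
    ]
    for path in chromosome_files:
        cmd = [
            "srun",
            "--ntasks=1",
            f"--cpus-per-task={cpus_per_task}",
            "python",
            eval_script,
            str(path),
        ]
        # Quote each argument, then run the step in the background.
        lines.append(" ".join(shlex.quote(part) for part in cmd) + " &")
    # Block until every evaluation has finished.
    lines.append("wait")
    return "\n".join(lines) + "\n"


if __name__ == "__main__":
    # Hypothetical file names, one per individual in the population.
    print(build_population_script([f"individual_{i}.txt" for i in range(3)]))
```

A master script following this pattern would write the returned text to a file, execute it (for example with `subprocess.run(["bash", path], check=True)`), and then read back the fitness value that each evaluation produced.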

The pseudocode below explains the method. Algorithm 1 describes the main structure of the GA code, while Algorithm 2 details how a new chromosome is assembled under the constraint that line block indexes must not repeat within a fully connected layer. Satisfying this constraint may require applying the mutation operator; see Line 15 of Algorithm 2. The particularities of this algorithm are circled in red.

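[Figure: pseudocode of Algorithm 1 (main structure of the GA) and Algorithm 2 (assembly of a new chromosome)]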
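To make the no-repetition constraint concrete, here is a minimal sketch of assembling a child chromosome from two parents and falling back on mutation when both parental genes are already used. It is not a transcription of Algorithm 2: the function name `assemble_chromosome`, the uniform gene-by-gene choice between parents, and the uniform random mutation are assumptions made for the example.

```python
import random


def assemble_chromosome(parent_a, parent_b, n_blocks, rng):
    """Build a child of the parents' length with no repeated block index.

    Assumes both parents have the same length, which is at most n_blocks.
    """
    child = []
    used = set()
    for gene_a, gene_b in zip(parent_a, parent_b):
        # Pick the gene from one of the two parents at random.
        gene = rng.choice((gene_a, gene_b))
        if gene in used:
            # The chosen index is already in the child: try the other parent.
            gene = gene_b if gene == gene_a else gene_a
        if gene in used:
            # Both parental genes are taken: mutate to a random unused index.
            gene = rng.choice([i for i in range(n_blocks) if i not in used])
        child.append(gene)
        used.add(gene)
    return child


if __name__ == "__main__":
    rng = random.Random(0)
    print(assemble_chromosome([0, 1, 2, 3, 4], [4, 3, 2, 1, 0], 8, rng))
```

In this sketch mutation is applied only when it is needed to keep the indexes distinct, so the child inherits from its parents wherever the constraint allows.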

This repository also provides a toy dataset, generated by applying neural TTS to some sentences from the paper, to exemplify the data format; see the paper_abstract_TTS.hrl file and the corresponding WAV files in the wav_files_paper_abstract folder. An example of how to build a Singularity container for this codebase is provided in the container_for_w2v.def file. The Wav2vec2 checkpoint can be downloaded from https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-english/tree/main

This code was originally designed to run on a computer grid (CPUs only, no GPU required) with the Simple Linux Utility for Resource Management (Slurm 22.05.3), Ubuntu 20.04.3 LTS, and PyTorch 1.9.0+cu111. Other configurations may require minor adaptations.