Issue #71 – Knowledge Distillation for Neural Machine Translation
Author: Dr. Chao-Hong Liu, Machine Translation Scientist @ Iconic
While Neural Machine Translation (NMT) has shown its ability to translate sentences from one language into another, it requires very high capacity to perform well. A typical NMT network usually contains thousands of neurons (e.g. LSTM units) in multiple layers (e.g. four to sixteen), which makes good models hard to train and costly to run. In this post we review the work of Kim and Rush (2016), who discuss a word-level knowledge distillation approach and propose two novel methods: sequence-level knowledge distillation and sequence-level interpolation. Thanks to Dr Marcin Junczys-Dowmunt, whose tweet brought this paper to my attention.
Knowledge distillation, in the context of machine learning (ML), is the training of a smaller network (the student) that learns from an already trained network (the teacher). The teacher itself is a large network trained directly on the training set, since a large network is needed for the model to perform well. The idea is that we can train a smaller student from the teacher, rather than directly from the training set, so that the student runs much faster while performing approximately as well as the teacher, and might even outperform it.
Figure 1. Overview of three knowledge distillation approaches for NMT, excerpted from Kim and Rush (2016).
Knowledge Distillation for NMT
Knowledge distillation is particularly attractive for NMT, because NMT models normally need to be huge, with thousands of neurons, to perform well when trained directly on the millions of sentence pairs in a training set. Fig. 1 shows three different approaches to knowledge distillation for NMT. On the left-hand side is word-level knowledge distillation. The teacher follows the traditional ML approach and is trained by minimizing cross-entropy over the training set. The student, on the other hand, is trained by minimizing, at each target word position, the cross-entropy with the teacher’s probability distribution (shown in yellow in the diagram).
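The difference between the teacher’s training loss and the word-level distillation loss can be illustrated with a minimal sketch. It assumes that both models already output a probability distribution over the vocabulary at every target position (in a real NMT system these come from the decoder’s softmax, conditioned on the source sentence and the previous target words). The function names below are hypothetical and not taken from any toolkit.

```python
import math
from typing import List, Sequence

Dist = List[float]  # a probability distribution over the vocabulary


def softmax(logits: Sequence[float]) -> Dist:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def nll_loss(student: List[Dist], targets: List[int]) -> float:
    """Standard training loss: -log p(reference word), summed over positions."""
    return -sum(math.log(p[t]) for p, t in zip(student, targets))


def word_kd_loss(student: List[Dist], teacher: List[Dist]) -> float:
    """Word-level KD: cross-entropy against the teacher's full distribution,
    summed over every target position j and vocabulary entry k."""
    return -sum(
        q_k * math.log(p_k)
        for p_j, q_j in zip(student, teacher)
        for p_k, q_k in zip(p_j, q_j)
    )


# Toy example: 2 target positions, vocabulary of 3 words.
student = [softmax([2.0, 0.5, -1.0]), softmax([0.1, 1.5, 0.3])]
teacher = [softmax([1.0, 0.0, -0.5]), softmax([0.0, 2.0, 0.0])]
print(nll_loss(student, [0, 1]), word_kd_loss(student, teacher))
```

If the teacher distribution were one-hot on the reference word, `word_kd_loss` would reduce to `nll_loss`; it is the teacher’s soft distribution over the other words that gives the student additional signal.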
Kim and Rush (2016) proposed two sequence-level knowledge distillation methods, shown in the middle and on the right-hand side of Fig. 1. The idea is that, instead of learning directly from the training set as the teacher did, the student learns from the output of the teacher. The middle diagram (sequence-level knowledge distillation) shows the student being trained on the highest-scoring output of the teacher’s beam search. On the right-hand side (sequence-level interpolation), the student instead learns from the output in the teacher’s beam that is most similar to the target sequence.
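How the two sequence-level methods pick a training target can be sketched on a toy example. In sequence-level distillation, the teacher re-translates each source sentence with beam search and the student is trained on that output in place of the reference. The sketch assumes a hypothetical `toy_teacher` that returns next-word probabilities for a target prefix (a real teacher also conditions on the source sentence). Kim and Rush measure similarity to the reference with sentence-level BLEU; this sketch swaps in a simpler unigram F1, so treat it as an illustration of the selection rule only.

```python
import math
from collections import Counter
from typing import Callable, Dict, List, Tuple

Hyp = Tuple[str, ...]
EOS = "</s>"

# Hypothetical toy teacher: next-word probabilities given the target prefix.
TOY_TEACHER: Dict[Hyp, Dict[str, float]] = {
    (): {"das": 0.7, "ein": 0.3},
    ("das",): {"gebaeude": 0.6, "haus": 0.4},
    ("ein",): {"haus": 0.9, EOS: 0.1},
}


def toy_teacher(prefix: Hyp) -> Dict[str, float]:
    return TOY_TEACHER.get(prefix, {EOS: 1.0})  # unseen prefixes end the sentence


def beam_search(next_probs: Callable[[Hyp], Dict[str, float]],
                beam_size: int, max_len: int) -> List[Tuple[Hyp, float]]:
    """K-best list of (hypothesis without EOS, log-probability), best first."""
    beam: List[Tuple[Hyp, float]] = [((), 0.0)]
    finished: List[Tuple[Hyp, float]] = []
    for _ in range(max_len):
        candidates = [
            (prefix + (tok,), score + math.log(p))
            for prefix, score in beam
            for tok, p in next_probs(prefix).items()
        ]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beam = []
        for hyp, score in candidates[:beam_size]:
            if hyp[-1] == EOS:
                finished.append((hyp[:-1], score))
            else:
                beam.append((hyp, score))
        if not beam:
            break
    finished.extend(beam)  # keep unfinished hypotheses if max_len was reached
    return sorted(finished, key=lambda c: c[1], reverse=True)


def unigram_f1(hyp: Hyp, ref: Hyp) -> float:
    """Stand-in similarity measure (the paper uses sentence-level BLEU)."""
    overlap = sum((Counter(hyp) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(hyp), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


def seq_kd_target(kbest: List[Tuple[Hyp, float]]) -> Hyp:
    """Sequence-level KD: the teacher's highest-scoring beam output."""
    return kbest[0][0]


def seq_inter_target(kbest: List[Tuple[Hyp, float]], reference: Hyp) -> Hyp:
    """Sequence-level interpolation: the beam output closest to the reference."""
    return max(kbest, key=lambda c: unigram_f1(c[0], reference))[0]


kbest = beam_search(toy_teacher, beam_size=3, max_len=5)
print(seq_kd_target(kbest))                      # ('das', 'gebaeude')
print(seq_inter_target(kbest, ("das", "haus")))  # ('das', 'haus')
```

In this toy case the teacher’s favourite translation differs from the reference, so the two methods hand the student different targets: sequence-level distillation follows the teacher, while sequence-level interpolation stays within the teacher’s beam but moves towards the reference.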
Experiments and Results
Datasets Two experimental setups are used. For the high-resource scenario, English-to-German models are trained on the WMT 2014 dataset (4 million sentence pairs), with newstest2012/newstest2013 as the dev set and newstest2014 as the test set. For the low-resource scenario, Thai-to-English models are trained on the IWSLT 2015 dataset (90K sentence pairs), using the 2010/2011/2012 data as the dev set and 2012/2013 as the test set.
Results In the English-to-German experiments, the vocabulary size is set to 50K, the teacher model is a 4×1000 LSTM network (four layers of 1,000 units), and the student models are 2×300 and 2×500 networks. In the Thai-to-English experiments, the vocabulary size is 25K, the teacher model is a 2×500 network, and the student model is a 2×100 network. Overall, combining the three knowledge distillation methods improves the student’s MT performance by about four BLEU points, and the two sequence-level methods contribute much larger gains than the word-level approach.
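To get a feel for what the layers×units notation means for model size, here is a back-of-the-envelope sketch for the English-to-German configurations. It assumes a plain encoder-decoder LSTM without attention, an embedding size equal to the hidden size, one bias vector per gate, and separate source and target embeddings. These simplifications are my own, so the totals are only indicative and should not be read as the paper’s figures.

```python
def lstm_layer_params(input_size: int, hidden: int) -> int:
    """Weights of one LSTM layer: 4 gates x (input + recurrent weights + one bias)."""
    return 4 * (hidden * (input_size + hidden) + hidden)


def seq2seq_params(layers: int, hidden: int, vocab: int) -> int:
    """Rough count for a plain encoder-decoder LSTM (no attention),
    assuming the word-embedding size equals the hidden size."""
    stack = layers * lstm_layer_params(hidden, hidden)
    encoder_decoder = 2 * stack          # encoder and decoder stacks
    embeddings = 2 * vocab * hidden      # separate source and target embeddings
    softmax = hidden * vocab + vocab     # output projection + bias
    return encoder_decoder + embeddings + softmax


configs = {"teacher 4x1000": (4, 1000), "student 2x500": (2, 500), "student 2x300": (2, 300)}
for name, (layers, hidden) in configs.items():
    print(f"{name}: {seq2seq_params(layers, hidden, 50_000) / 1e6:.1f}M")
# teacher 4x1000: 214.1M
# student 2x500: 83.1M
# student 2x300: 47.9M
```

Under these assumptions, the embedding and softmax layers account for most of a small student’s parameters with a 50K vocabulary (about 45M of the roughly 48M for the 2×300 model), so the total shrinks by roughly 4.5× even though the LSTM stacks themselves shrink far more.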
Kim and Rush (2016) pioneered the application of knowledge distillation to NMT. Their results show that MT performance can be improved by 4 BLEU points, and that student models can decode up to ten times faster than their teachers. These knowledge distillation methods can also be applied to other sequence-to-sequence tasks, e.g. parsing and POS tagging.