Bài 14: Long short term memory (LSTM)
Nội dung
Giới thiệu về LSTM
Bài trước mình đã giới thiệu về recurrent neural network (RNN). RNN có thể xử lý thông tin dạng chuỗi (sequence/ time-series). Như ở bài dự đoán hành động trong video ở bài trước, RNN có thể mang thông tin của frame (ảnh) từ state trước tới các state sau, rồi ở state cuối là sự kết hợp của tất cả các ảnh để dự đoán hành động trong video.
Đạo hàm của L với W ở state thứ i: \displaystyle \frac{\partial L}{\partial W} = \frac{\partial L}{\partial \hat{y}} * \frac{\partial \hat{y}}{\partial s_{30}} * \frac{\partial s_{30}}{\partial s_i} * \frac{\partial s'_i}{\partial W}, trong đó \displaystyle \frac{\partial s_{30}}{\partial s_i} = \prod_{j=i}^{29} \frac{\partial s_{j+1}}{\partial s_j}
Giả sử activation là tanh function, \displaystyle s_{t} = tanh(U*x_t + W*s_{t-1})
\displaystyle \frac{\partial s_{t}}{\partial s_{t-1}} = (1-s_{t}^2) * W => \frac{\partial s_{30}}{\partial s_i} = W^{30-i} * \prod_{j=i}^{29} (1-s_j^2).
Ta có \displaystyle s_j < 1, W < 1 => Ở những state xa thì \displaystyle \frac{ \partial s_{30}}{\partial s_i} \approx 0 hay \displaystyle \frac{\partial L}{\partial W} \approx 0 , hiện tượng vanishing gradient
Ta có thể thấy là các state càng xa ở trước đó thì càng bị vanishing gradient và các hệ số không được update với các frame ở xa. Hay nói cách khác là RNN không học được từ các thông tin ở trước đó xa do vanishing gradient.
Như vậy về lý thuyết là RNN có thể mang thông tin từ các layer trước đến các layer sau, nhưng thực tế là thông tin chỉ mang được qua một số lượng state nhất định, sau đó thì sẽ bị vanishing gradient, hay nói cách khác là model chỉ học được từ các state gần nó => short term memory.
Cùng thử lấy ví dụ về short term memory nhé. Bài toán là dự đoán từ tiếp theo trong đoạn văn. Đoạn đầu tiên “Mặt trời mọc ở hướng …”, ta có thể chỉ sử dụng các từ trước trong câu để đoán là đông. Tuy nhiên, với đoạn, “Tôi là người Việt Nam. Tôi đang sống ở nước ngoài. Tôi có thể nói trôi chảy tiếng …” thì rõ ràng là chỉ sử dụng từ trong câu đấy hoặc câu trước là không thể dự đoán được từ cần điền là Việt. Ta cần các thông tin từ state ở trước đó rất xa => cần long term memory điều mà RNN không làm được => Cần một mô hình mới để giải quyết vấn đề này => Long short term memory (LSTM) ra đời.
Mô hình LSTM
Ở state thứ t của mô hình LSTM:
- Output: c_t, h_t , ta gọi c là cell state, h là hidden state.
- Input: c_{t-1}, h_{t-1}, x_t. Trong đó x_t là input ở state thứ t của model. c_{t-1}, h_{t-1} là output của layer trước. h đóng vai trò khá giống như s ở RNN, trong khi c là điểm mới của LSTM.
Các đọc biểu đồ trên: bạn nhìn thấy kí hiệu \sigma, tanh ý là bước đấy dùng sigma, tanh activation function. Phép nhân ở đây là element-wise multiplication, phép cộng là cộng ma trận.
f_t, i_t, o_t tương ứng với forget gate, input gate và output gate.
- Forget gate: \displaystyle f_t = \sigma(U_f*x_t + W_f*h_{t-1} + b_f)
- Input gate: \displaystyle i_t = \sigma(U_i*x_t + W_i*h_{t-1} + b_i)
- Output gate: \displaystyle o_t = \sigma(U_o*x_t + W_o*h_{t-1} + b_o)
Nhận xét: 0 < f_t, i_t, o_t < 1; b_f, b_i, b_o là các hệ số bias; hệ số W, U giống như trong bài RNN.
\displaystyle \tilde{c_t} = \tanh(U_c*x_t + W_c*h_{t-1} + b_c) , bước này giống hệt như tính s_t trong RNN.
\displaystyle c_t = f_t * c_{t-1} + i_t * \tilde{c_t}, forget gate quyết định xem cần lấy bao nhiêu từ cell state trước và input gate sẽ quyết định lấy bao nhiêu từ input của state và hidden layer của layer trước.
\displaystyle h_t = o_t * tanh(c_{t}), output gate quyết định xem cần lấy bao nhiêu từ cell state để trở thành output của hidden state. Ngoài ra h_t cũng được dùng để tính ra output y_t cho state t.
Nhận xét: h_t, \tilde{c_t} khá giống với RNN, nên model có short term memory. Trong khi đó c_t giống như một băng chuyền ở trên mô hình RNN vậy, thông tin nào cần quan trọng và dùng ở sau sẽ được gửi vào và dùng khi cần => có thể mang thông tin từ đi xa=> long term memory. Do đó mô hình LSTM có cả short term memory và long term memory.
LSTM chống vanishing gradient
Ta cũng áp dụng thuật toán back propagation through time cho LSTM tương tự như RNN.
Thành phần chính gây là vanishing gradient trong RNN là \displaystyle \frac{\partial s_{t+1}}{\partial s_t} = (1-s_{t}^2) * W , trong đó s_t, W < 1.
Tương tự trong LSTM ta quan tâm đến \displaystyle \frac{\partial c_t}{\partial c_{t-1}} =f_t. Do 0 < f_t < 1 nên về cơ bản thì LSTM vẫn bị vanishing gradient nhưng bị ít hơn so với RNN. Hơn thế nữa, khi mang thông tin trên cell state thì ít khi cần phải quên giá trị cell cũ, nên f_t \approx 1 => Tránh được vanishing gradient.
Do đó LSTM được dùng phổ biến hơn RNN cho các toán thông tin dạng chuỗi. Bài sau mình sẽ giới thiệu về ứng dụng LSTM cho image captioning.