Bài 8: CycleGAN | Deep Learning cơ bản
 

Bài 8: CycleGAN

| Posted in GAN

Giới thiệu về CycleGAN

Bài trước mình đã nói về mạng Pix2pix thuộc bài toán supervised uni-model image to image translation. Dataset mình cần chuẩn bị từng pair (input, output) tốn rất nhiều thời gian, công sức để làm. Bài này mình sẽ giới thiệu mạng CycleGAN thuộc bài toán unsupervised uni-model image to image translation. Bài toán hôm nay là chuyển từ ảnh ngựa thường sang ngựa vằn và ngược lại. Mình thấy hình dạng con ngựa vẫn không thay đổi mà chỉ màu con ngựa thay đổi để từ ngựa thường hay ngựa vằn và ngược lại.

chuyển từ ngựa vằn sang ngựa thường

Do là bài toán unsupervised nên dataset của mình chỉ có 2 tập, 1 tập gồm ảnh ngựa vằn và 1 tập ảnh ngựa thường, việc của mình là chuyển từ ngựa thường sang ngựa vằn. Nếu như ở bài toán supervised trong dataset mình có từng pair (input, output) nên mình có thể xây dựng mạng GAN như pix2pix để học mapping giữa input và output thì ở bài toán unsupervised thì sẽ khó hơn không có từng pair ảnh, cùng xem CycleGAN làm thế nào để học được nhé.

So sánh bài toán supervised (trái) và unsupervised (phải)

CycleGAN model

Nếu như ở pix2pix mình có pair (input, output) để train thì khi generator sinh ảnh mình có thể dùng L1 loss hay GAN loss để train. Tuy nhiên ở bài toán unsupervised mình không có từng pair để train như thế.

Mọi người nhớ ở bài toán unsupervised trong autoencoder để tìm latent space, mình phải dựng thêm decoder để tạo loss function để train mạng. Ý tưởng là việc nén chỉ tốt khi giải nén lại được toàn vẹn dữ liệu. Gọi dataset A là ảnh ngựa vằn, dataset B là ảnh ngựa thường, CycleGAN cũng có ý tưởng tương tự, mạng sẽ có 2 generator gọi là G_{ab}G_{ba}. G_{ab} sẽ có input là ảnh ngựa vằn và output ảnh ngựa thường, còn G_{ba} sẽ input ảnh ngựa thường và output là ảnh ngựa vằn.

CycleGAN mode, nguồn.

Như vậy ta mong muốn là ảnh ngựa vằn qua G_{ab} sau đó ouput qua G_{ba} sẽ thành ảnh ban đầu, hay G_{ba}(G_{ab}(a)) = a (loss này khá giống auto-encoder)

Bên cạnh đó mình cũng có 2 Discriminator, gọi là D_{a}D_{b}. Trong khi D_{b} phân biệt ảnh ngựa thường trong dataset B và ngựa thường do G_{ab} sinh ra, D_{a} để phân biệt ảnh ngựa vằn trong dataset A và ngựa vằn do G_{ba} sinh ra.

Như vậy ngoài L2 loss để khôi phục ảnh ban đầu qua 2 generator thì mình còn có 2 GAN loss để học cho ảnh sinh ra bởi các generator giống với các ảnh trong dataset A và dataset B.

Tóm lại mình có 2 generators (G_{ab}, G_{ba}) và 2 discriminator (D_{a}, D_{b}):

  • dataset A: ảnh ngựa vằn, dataset B: ảnh ngựa thường
  • G_{ab}: chuyển từ ngựa vằn sang ngựa thường, G_{ba}: chuyển từ ngựa thường sang ngựa vằn
  • D_{a}: phân biệt ảnh ngựa vằn trong dataset A và ảnh do G_{ba} sinh ra. D_{b}: phân biệt ảnh ngựa thường trong dataset B và ảnh do G_{ab} sinh ra.

Generator

Mạng generator input là ảnh 1 domain output là ảnh domain khác, nên kiến trúc tương tự generator của mạng pix2pix.

Kiến trúc cycleGAN generator, nguồn.

Generator gồm 3 phần: encoder, transformer và decoder. Phần encoder giảm kích thước ảnh cũng như tăng depth bằng 3 conv liên tiếp. Sau đó output được cho qua phần transformer với 6 residual block và cuối cùng được cho qua phần decoder để về ảnh có kích thước giống ban đầu.

Nhận xét: mạng dùng conv layer với stride = 2 để thay pooling layer. Mạng toàn conv layer mà không có fully connected layer nên có thể nhận input ảnh kích thước tùy ý.

Discriminator

Discriminator dùng để phân biệt ảnh sinh ra bởi generator hay ảnh thật trong dataset.

Ý tưởng vẫn sử dụng PatchGAN để phân biệt từng vùng nhỏ trên ảnh để cho hiệu suất tốt hơn.

Kiến trúc CycleGAN discriminator, nguồn.

Loss function

Như nói ở trên ngoài 2 GAN loss function mình còn có reconstruction loss (l2 loss).

GAN loss

\displaystyle L_{GAN(A2B)} = \operatorname{\mathbb{E}}_{b\sim p_{data}(b)}[\log D_b(b)] + \operatorname{\mathbb{E}}_{a\sim p_{data}(a)}[\log (1 - D_b(G_{ab}((a))] \newline\newline \displaystyle L_{GAN(B2A)} = \operatorname{\mathbb{E}}_{a\sim p_{data}(a)}[\log D_a(a)] + \operatorname{\mathbb{E}}_{b\sim p_{data}(b)}[\log (1 - D_a(G_{ba}((b))] \newline\newline

Cycle-consistancy loss

\displaystyle L_{cyc} = \operatorname{\mathbb{E}}_{a\sim p_{data}(a)}[||G_{ba}(G_{ab}(a)) - a||_2] + \operatorname{\mathbb{E}}_{b\sim p_{data}(b)}[||G_{ab}(G_{ba}(b)) - b||_2] \newline\newline

Full loss

\displaystyle L = L_{GAN(A2B)} + L_{GAN(B2A)} + \lambda L_{cyc}

\lambda là hyper-paramerter, được khuyên chọn bằng 10.

Bản thân GAN loss có thể mapping phân phối giữa A và B, nhưng để mapping được như vậy thì cần rất nhiều ảnh nên L2 loss được thêm vào để model tạo ra ảnh thật hơn.

Code

Code về cycleGAN mọi người tham khảo ở đây.

Còn implement cycleGAN lại đơn giản hơn mọi người tham khảo ở đây.

Nhận xét: mặc dù bài toán là unsupervised nhưng cycleGAN cho ra chất lượng ảnh khá tốt. Tuy nhiên nó thường chỉ tốt với những task liên quan đến thay đổi màu sắc (style) như ngày sang đêm, còn với những task mà có sự thay đổi lớn về hình dạng của input và output như chó sang mèo thì kết quả ra không tốt.

CycleGAN chuyển mèo sang chó, nguồn.

Bênh cạnh đó nó cũng có thể detect nhầm đối tượng xung quanh thành ngựa rồi chuyển đặc điểm dẫn tới nhầm lẫn.

CycleGAN chuyển ngựa thường sang ngựa vằn, nguồn.


Deep Learning cơ bản ©2024. All Rights Reserved.
Powered by WordPress. Theme by Phoenix Web Solutions