Bài 9: StarGAN
Nội dung
Giới thiệu về StarGAN
Bài trước mình đã giới thiệu mạng CycleGAN thuộc bài toán unsupervised uni-model image to image translation. Mình có thể chuyển thuộc tính từ ngựa thường sang ngựa vằn và ngược lại. Tuy nhiên cycleGAN chỉ học để chuyển từ 1 domain sang 1 domain khác. Giả sử bài toán của mình giờ là chuyển từ ảnh người tóc đen sang tóc vàng, nam sang nữ, già sang trẻ, mặt trắng nhợt sang bình thường thì mình cần build 4 model cycleGAN như vậy không hiệu quả đặc biệt khi số lượng domain mình tăng lên và giờ mình cần một GAN model có thể học và chuyển đổi nhiều domain khác nhau như tóc, giới tính, tuổi, da,… Và StarGAN sinh ra để giải quyết vấn đề đấy.
Tên paper: “StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation”, multi-domain như mình nói là có thể học và chuyển đổi được nhiều domain, unified là hợp nhất tức là chuyển đổi nhiều domain chỉ trong 1 mạng GAN. Trước khi vào bài toán mình cùng xem dataset CelebA
Dataset CelebA
CelebFaces Attributes Dataset (CelebA) bao gồm hơn 200.000 ảnh người nổi tiếng, với mỗi ảnh có 40 thuộc tính như tóc đen, tóc vàng, đeo kính, đeo mũ, … Thuộc tính mỗi ảnh sẽ được biểu diễn dưới dạng kiểu one-hot như sau [1 0 0 1 0 0 0 0 0…] (thuộc tính nào có trong ảnh sẽ được biểu diễn 1, không có được biểu diễn 0). Mỗi ảnh có nhiều giá trị 1 vì mỗi ảnh có thể có nhiều thuộc tính như tóc vàng, đeo kính,…)
Bài toán của mình hôm nay là sử dụng bộ dataset CelebA với 5 thuộc tính để học multi-domain image to image translation dùng StarGAN.
StarGAN model
Generator của StarGAN mang ý tưởng giống với conditional GAN (cGAN) tức mình sẽ generate ra ảnh với condition vào điều kiện nào đấy (ở đây là các thuộc tính mình muốn chuyển đổi sang). Input sẽ là ảnh gốc và target domain mà mình mong muốn chuyển thuộc tính sang, output sẽ là ảnh sinh ra với các target domain tương ứng (b). Ảnh sinh ra ở (b) sẽ được truyền đến 2 nhánh ở (c) và (d).
Sau đấy ảnh fake sinh ra sẽ kết hợp với các domain gốc ban đầu của ảnh được cho qua generator với mục đích học lại ảnh ban đầu. Ví dụ: ảnh gốc tóc đen + domain tóc vàng -> ảnh tóc vàng. Sau đó ảnh tóc vàng sinh ra + domain tóc đen -> ảnh tóc đen. Ở đây mình sẽ có construction loss, tức là ảnh sinh ra ở (c) sẽ giống với ảnh ban đầu input ở (b).
Discriminator thì input sẽ là 1 ảnh (ảnh fake do nhánh (b) sinh ra hoặc ảnh thật trong dataset) và sẽ phân biệt đấy là ảnh thật hay ảnh fake. Bên cạnh đó thì discriminator còn phân loại ảnh tới đúng domain của nó (bài toán multiple classification) ví dụ input là ảnh tóc vàng và đeo kính thì discriminator sẽ phân loại là tóc vàng và đeo kính.
Ta thấy StarGAN model gồm duy nhất 1 mạng Generator và 1 mạng Discriminator. Có 2 điểm khác so với GAN bình thường. Thứ nhất là reconstruction, phần này thì CycleGAN cũng có và thứ hai là classification các domain tương ứng của mỗi ảnh để giúp StarGAN có thể học và chuyển đổi được nhiều domain.
Generator
Input generator sẽ là ảnh gốc (512*512*3) và 1 vector (5*1) dạng one-hot thể hiện target domain mà mình mong muốn.
Vector 5*1 sẽ được lặp lại các giá trị thành tensor 3d kích thước 512*512*5 (cùng width và height với ảnh). Mọi người tưởng tượng như ở ma trận 512*512 thì mỗi ô sẽ chứa giá trị vector 5*1 thì sẽ thành tensor 3d 512*512*5.
Sau đó ảnh gốc sẽ được nối với tensor 3d sinh ra từ vector domain thành tensor 3d kích thước 512*512*(5+3) = 512*512*8
Sau đó tensor 3d này sẽ được cho vào mô hình generator và cho ra output ảnh màu kích thước 512 * 512. Mô hình giống với pix2pix dạng U-net và có dùng residual block.
Discriminator
Discriminator input sẽ là 1 ảnh và có 2 output:
- Phân biệt ảnh thật với ảnh fake (binary classification)
- Phân loại ảnh đúng với các domain của nó (multiple classification)
Cấu trúc của Discriminator trong StarGAN giống như 1 Discriminator bình thường, tuy nhiên ở layer gần cuối sẽ cho qua 2 convolutional layers riêng để cho ra 2 outputs.
Loss function
Các ảnh chụp loss function được lấy từ paper gốc.
x là dữ liệu trong dataset CelebA, c là target domain mà mình muốn chuyển sang, c’ là domain ban đầu ảnh.
GAN loss function (phân biệt ảnh thật/fake)
Classification loss, phân loại ảnh x trong dataset với các domain c’ tương ứng, phần -log(D) chính là entropy loss
Discrimination loss, phân loại ảnh sinh ra G(x, c) với các domain c tương ứng.
Reconstruction loss, muốn ảnh khôi phục lại khi qua generator 2 lần gần với ảnh ban đầu.
Discrimination loss bao gồm GAN loss và classification loss
Generator loss bao gồm GAN loss, classification loss và reconstruction loss
Các hệ số \lambda_{cls} và \lambda_{rec} là hyperparameter, trong paper gốc tác giả dùng \lambda_{cls} = 1 và \lambda_{rec}=10
Code
Code mọi người xem ở github theo paper. Code đơn giản, dễ đọc, có pre-trained model trên dữ liệu CelebA và RafD để mọi người thử.