Bài 6: Lưu và load model trong Pytorch | Deep Learning cơ bản
 

Bài 6: Lưu và load model trong Pytorch

| Posted in Pytorch

Những bài trước mình đã học cách xây dựng và train deep learning model bằng Pytorch. Tuy nhiên, khi train xong model mình cần lưu được model đã train, để sau có thể dùng để dự đoán hoặc tiếp tục train mà không cần train lại từ đầu. Bài này mình sẽ hướng dẫn lưu và load model trong Pytorch.

State_dict là gì?

State_dict của model là một Python dict, với key là tên của layer và value là parameter của layer đó, bao gồm weight và bias. Bên cạnh model, optimizer (torch.optim) cũng có state_dict, có chứa những thông tin về optimizer’s state, cũng như các hyperparameter đi cùng.

Vì state_dict là Python dict, nên có thể dễ dàng cập nhật, thay đổi, lưu và load lên. Ví dụ như model CNN mình build ở bài Neural Network

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(8 * 8 * 8, 32)
        # bài toán phân loại 10 lớp nên output ra 10 nodes
        self.fc2 = nn.Linear(32, 10)
        
    def forward(self, x):
        out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)
        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)
        # flatten về dạng vector để cho vào neural network
        out = out.view(-1, 8 * 8 * 8)
        out = torch.tanh(self.fc1(out))
        out = self.fc2(out)
        return out

# Initialize model
model = Net()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Mình có thể in state_dict của model và optimizer.

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
'''
Output:
Model's state_dict:
conv1.weight 	 torch.Size([16, 3, 3, 3])
conv1.bias 	 torch.Size([16])
conv2.weight 	 torch.Size([8, 16, 3, 3])
conv2.bias 	 torch.Size([8])
fc1.weight 	 torch.Size([32, 512])
fc1.bias 	 torch.Size([32])
fc2.weight 	 torch.Size([10, 32])
fc2.bias 	 torch.Size([10])
'''

Mọi người thấy là state_dict của model có key là tên layer như conv1.weight, conv1.bias,…,fc2.weight, fc2.bias còn value là các tensor hệ số tương ứng với các layer đó.

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
'''
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.03, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': True, 'params': [0, 1, 2, 3, 4, 5, 6, 7]}]
'''

Tại mình chưa train nên state vẫn chưa có gì. Còn param_groups mọi người thấy các hyperparameter của optimizer như learning rate, momentum,…. Tuy nhiên sau khi mình train 1 vài epoch:

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
'''
state 	 {..., 7: {'momentum_buffer': tensor([-0.1845,  0.0265, -0.1256,  0.0954,  0.1391,  0.0451, -0.1127,  0.0178,
         0.0500,  0.0489])}}
param_groups 	 [{'lr': 0.03, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': True, 'params': [0, 1, 2, 3, 4, 5, 6, 7]}]
'''

Mọi người thấy là do optimizer mình có dùng momentum nên giờ state sẽ lưu lại giá trị cộng dồn của các tham số tương ứng, state giờ sẽ là 1 dict, với key là index của layer (0-7) còn value là giá trị cộng dồn của các hệ số (weight, bias).

Lưu và load model cho inference

Lưu state_dict của model

torch.save(model.state_dict(), PATH)

trong đó PATH là đường dẫn đến file lưu model, thông thường pytorch model lưu dưới dạng .pt hoặc .pth

Khi load model thì mình cần dựng lại kiến trúc của model trước, sau đó sẽ gọi hàm để load state_dict vào model.

model = Net()
model.load_state_dict(torch.load(PATH))

*lưu ý: hàm load_sate_dict nhận input là 1 dict nên mình cần load state_dict của model nên bằng hàm torch.load trước. Gọi thẳng trực tiếp model.load_state_dict(PATH) sẽ lỗi.

Lưu cả model

Thông thường Pytorch sẽ lưu model dưới dạng .pt hoặc .pth

torch.save(model, PATH)

Vì mình lưu cả model nên khi load mình không cần dựng lại kiến trúc của model trước mà có thể load thẳng lên

model = torch.load(PATH)

Mọi người sẽ thấy lưu cả model thì tiện hơn rất nhiều, không cần định nghĩa lại model và không cần quan tâm đến state_dict là gì. Khi mọi người lưu cả model thì Pytorch sẽ dùng pickle module của Python để lưu. Tuy nhiên, pickle thì không lưu trực tiếp model class (class định nghĩa model, dùng để định nghĩa kiến trúc model) mà lưu đường dẫn tới file chứa model class. Thế nên khi load model, nếu mình refactor code và đường dẫn đến file chứa model class thay đổi thì code sẽ lỗi và không load model lên được.

Lưu và load checkpoint để tiếp tục training

Khi mình lưu model để tiếp tục training thì bên cạnh model state_dict, mình còn phải lưu thêm các thông tin như optimizer state_dict, số epoch, loss hiện tại.

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            }, PATH)

Ở lần sau mình train tiếp, thì mình load các thông số đã lưu lên để cho tiếp tục training

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# train model
model.train()
# tiếp tục train model

Load_state_dict argument

Strict

Bình thường khi mọi người load state_dict của model thì tất cả các layer của model phải khớp với các key trong state_dict mà đã lưu trước đó.

model.load_state_dict(checkpoint['model_state_dict'])

Khi state_dict của model đã lưu thiếu một vài key hoặc có nhiều hơn một vài key với model hiện tại của mọi người thì hàm load_state_dict sẽ báo lỗi. Để giải quyết trường hợp này thì hàm load_state_dict hỗ trợ thuộc tính strict (mặc định là True), nếu strict là False thì model sẽ chỉ load những key-value mà 2 bên match với nhau.

model.load_state_dict(checkpoint['model_state_dict'], strict=False)

Map_location

Khi mọi người load lưu và load trên device khác nhau, ví dụ như save model trên gpu và load model trên cpu hoặc save model trên cpu và load model trên gpu, thì khi load model mọi người cần truyền map_location với device tương ứng.

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = Net()
model.load_state_dict(torch.load(PATH, map_location=device))
model.to(device)


Deep Learning cơ bản ©2025. All Rights Reserved.
Powered by WordPress. Theme by Phoenix Web Solutions