IT/Machine Learning

pytorch EarlyStopping 적용하기

엘티엘 2023. 6. 15. 19:50

pytorch 에서 EarlyStopping 을 적용해 보자

EarlyStopping 이란?

아래 설명은 Keras에서 제공하는 EarlyStopping Class에 대한 설명이다. 즉, 모델의 학습중에 더이상의 성능향상이 없을 경우 중단하는 기능을 의미한다.

Stop training when a monitored metric has stopped improving.

pytorch 는 keras와 달리 기본적으로 제공하는 EarlyStopping 클래스가 없다. pytorch-light, pytorch-ignite 등에서 제공하는 패키지가 있긴하나, 기본 pytorch 와 간단하게 연동이 되지는 않는것 같다. 대체적으로 직접 EarlyStopping 클래스를 만들어서 사용하는 경우가 대다수인듯 하다.

EarlyStopping 클래스 만들기

아래 코드를 참고해서 EarlyStopping 클래스를 정의한다. (복붙도 상관없다) 성능 metric 이 좋아지지 않을경우 중단하는 기본적인 기능(patience, delta 등) 만을 사용하고자 한다면 이정도면 충분하다.

데이터 준비, train 함수 정의

자세한 내용은 이전글을 참고한다.

  • train 함수에서 EarlyStopping 객체를 파라미터로 받을수 있도록 정의한다.
  • EarlyStopping 이 결과가 좋아지지 않았다고 판단할 경우 break 하는 방식이다.
  • 링크에서 제공하는 EarlyStopping 은 결과값이 감소해야 좋아지는 것으로 판단하도록 구현되어 있다(loss 측면). 다만 예시에서는 평가 지표로 accurary 를 사용하고 있기 때문에 결과값이 클수록 성능이 좋음을 의미한다. 따라서 EarlyStopping에 -accuracy 로 전달한다.
from sklearn import datasets
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

import torchmetrics


# 데이터 준비
data = datasets.load_digits()
x_train, x_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2, random_state=0)

train_set = TensorDataset(torch.tensor(x_train).float(), torch.tensor(y_train).long())
train_loader = DataLoader(train_set, batch_size=128)

x_test = torch.tensor(x_test).float()
y_test = torch.tensor(y_test).long()


# 학습함수 정의
def train(loader, model, loss_fn, optimizer, es):
  epoch=100000
  check_num=500

  for e in range(epoch):
    for i, (x, y) in enumerate(loader):
      pred = model(x)
      loss = loss_fn(pred, y)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    if e % check_num == 0:
      pred = model(x_test)
      acc = accuracy(pred, y_test)

      es(-acc, model)
      if es.early_stop:
        print(e, loss.item())
        break

  print(e, loss.item())

학습 + EarlyStopping

EarlyStopping 객체를 생성하고, train 함수로 전달한다. EarlyStopping 파라미터 의미는 다음과 같다

  • patience: 평가지표가 좋아지지 않았을때 참는(?) 횟수
  • delta: 최소로 좋아져야 하는 결과지표 크기 (해당값 미만일 경우 좋아지지 않았다고 판단)
  • verbose: 중간 결과 출력여부
# EarlyStopping 클래스는 사전에 정의되어 있어야 한다.
es = EarlyStopping(patience = 3, verbose = True, delta=0.001)

model = nn.Sequential(
  nn.Linear(data.data.shape[-1], len(set(data.target)))
)
loss_fn = nn.functional.cross_entropy
optimizer= torch.optim.SGD(model.parameters(), lr=0.001)
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=len(set(data.target)))


# 학습시작
train(train_loader, model, loss_fn, optimizer, es)

결과는 다음과 같다

  • epoch=100000 으로 적용하였지만 3500번에 성능향상 효과가 없다고 판단되어 중단되었다.
  • 최종 정확도는 0.963%
Validation loss decreased (inf --> -0.150000).  Saving model ...
Validation loss decreased (-0.150000 --> -0.955556).  Saving model ...
Validation loss decreased (-0.955556 --> -0.961111).  Saving model ...
EarlyStopping counter: 1 out of 3
Validation loss decreased (-0.961111 --> -0.963889).  Saving model ...
EarlyStopping counter: 1 out of 3
EarlyStopping counter: 2 out of 3
EarlyStopping counter: 3 out of 3
3500 0.008139846846461296
3500 0.008139846846461296

 

 

반응형