IT/Machine Learning
pytorch EarlyStopping 적용하기
엘티엘
2023. 6. 15. 19:50
pytorch 에서 EarlyStopping 을 적용해 보자
EarlyStopping 이란?
아래 설명은 Keras에서 제공하는 EarlyStopping Class에 대한 설명이다. 즉, 모델의 학습중에 더이상의 성능향상이 없을 경우 중단하는 기능을 의미한다.
Stop training when a monitored metric has stopped improving.
pytorch 는 keras와 달리 기본적으로 제공하는 EarlyStopping 클래스가 없다. pytorch-light, pytorch-ignite 등에서 제공하는 패키지가 있긴하나, 기본 pytorch 와 간단하게 연동이 되지는 않는것 같다. 대체적으로 직접 EarlyStopping 클래스를 만들어서 사용하는 경우가 대다수인듯 하다.
EarlyStopping 클래스 만들기
아래 코드를 참고해서 EarlyStopping 클래스를 정의한다. (복붙도 상관없다) 성능 metric 이 좋아지지 않을경우 중단하는 기본적인 기능(patience, delta 등) 만을 사용하고자 한다면 이정도면 충분하다.
데이터 준비, train 함수 정의
자세한 내용은 이전글을 참고한다.
- train 함수에서 EarlyStopping 객체를 파라미터로 받을수 있도록 정의한다.
- EarlyStopping 이 결과가 좋아지지 않았다고 판단할 경우 break 하는 방식이다.
- 링크에서 제공하는 EarlyStopping 은 결과값이 감소해야 좋아지는 것으로 판단하도록 구현되어 있다(loss 측면). 다만 예시에서는 평가 지표로 accurary 를 사용하고 있기 때문에 결과값이 클수록 성능이 좋음을 의미한다. 따라서 EarlyStopping에 -accuracy 로 전달한다.
from sklearn import datasets
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torchmetrics
# 데이터 준비
data = datasets.load_digits()
x_train, x_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2, random_state=0)
train_set = TensorDataset(torch.tensor(x_train).float(), torch.tensor(y_train).long())
train_loader = DataLoader(train_set, batch_size=128)
x_test = torch.tensor(x_test).float()
y_test = torch.tensor(y_test).long()
# 학습함수 정의
def train(loader, model, loss_fn, optimizer, es):
epoch=100000
check_num=500
for e in range(epoch):
for i, (x, y) in enumerate(loader):
pred = model(x)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if e % check_num == 0:
pred = model(x_test)
acc = accuracy(pred, y_test)
es(-acc, model)
if es.early_stop:
print(e, loss.item())
break
print(e, loss.item())
학습 + EarlyStopping
EarlyStopping 객체를 생성하고, train 함수로 전달한다. EarlyStopping 파라미터 의미는 다음과 같다
- patience: 평가지표가 좋아지지 않았을때 참는(?) 횟수
- delta: 최소로 좋아져야 하는 결과지표 크기 (해당값 미만일 경우 좋아지지 않았다고 판단)
- verbose: 중간 결과 출력여부
# EarlyStopping 클래스는 사전에 정의되어 있어야 한다.
es = EarlyStopping(patience = 3, verbose = True, delta=0.001)
model = nn.Sequential(
nn.Linear(data.data.shape[-1], len(set(data.target)))
)
loss_fn = nn.functional.cross_entropy
optimizer= torch.optim.SGD(model.parameters(), lr=0.001)
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=len(set(data.target)))
# 학습시작
train(train_loader, model, loss_fn, optimizer, es)
결과는 다음과 같다
- epoch=100000 으로 적용하였지만 3500번에 성능향상 효과가 없다고 판단되어 중단되었다.
- 최종 정확도는 0.963%
Validation loss decreased (inf --> -0.150000). Saving model ...
Validation loss decreased (-0.150000 --> -0.955556). Saving model ...
Validation loss decreased (-0.955556 --> -0.961111). Saving model ...
EarlyStopping counter: 1 out of 3
Validation loss decreased (-0.961111 --> -0.963889). Saving model ...
EarlyStopping counter: 1 out of 3
EarlyStopping counter: 2 out of 3
EarlyStopping counter: 3 out of 3
3500 0.008139846846461296
3500 0.008139846846461296
반응형