코딩걸음마

[딥러닝] Pytorch 확률적 경사하강법 SGD(Stochastic Gradient Descent) 본문

딥러닝_Pytorch

[딥러닝] Pytorch 확률적 경사하강법 SGD(Stochastic Gradient Descent)

코딩걸음마 2022. 6. 30. 02:26
728x90

 

Loss Function을 계산할 때 전체 Train-Set을 사용하는 것을 Batch Gradient Descent라고 합니다. 

그러나 이렇게 계산하면 한번 계산할 때  전체 데이터에 대해 Loss Function을 계산해야 하므로 너무 많은 계산이 필요합니다.

이를 방지하기 위해 Stochastic Gradient Desenct(SGD)를 사용합니다. 이 방법에서는 Loss Function을 계산할 때, 전체 데이터(Batch) 대신 일부 데이터의 모음(Mini-Batch)를 사용하여 Loss Function을 계산합니다.

(Mini-Batch의 기대값은 전체 train-set을 계산한 값과 같다는 가정)

Batch Gradient Descent보다 다소 부정확할 수는 있지만, 계산 속도가 훨씬 빠르기 때문에 같은 시간에 더 많은 step을 갈 수 있으며, 여러 번 반복할 경우 Batch 처리한 결과로 수렴합니다.

또한 반복을 통해서 Batch Gradient Descent에서 빠질 Local Minima에 빠지지 않고 더 좋은 방향으로 수렴할 가능성도 높습니다.

 

캘리포니아 주택가격예측데이터를 활용해 SGD개념을 익혀봅시다.

1. 데이터 준비하기

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import fetch_california_housing


california = fetch_california_housing()
df = pd.DataFrame(california.data, columns=california.feature_names)
df["Target"] = california.target

 

StandardScaler를 사용해서 Scaling을 실행합니다.

scaler = StandardScaler()
scaler.fit(df.values[:, :-1])
df.values[:, :-1] = scaler.transform(df.values[:, :-1])

 

데이터를 시각화하여 확인해봅시다.

sns.pairplot(df.sample(1000))
plt.show()

 

2. 데이터 전처리 및 모델 params 설정

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

data = torch.from_numpy(df.values).float()

x = data[:, :-1]
y = data[:, -1:]

n_epochs = 4000
batch_size = 256  #512  768
print_interval = 200
learning_rate = 1e-2

 

3. 모델 생성

model = nn.Sequential(
    nn.Linear(x.size(-1), 6),
    nn.LeakyReLU(),
    nn.Linear(6, 5),
    nn.LeakyReLU(),
    nn.Linear(5, 4),
    nn.LeakyReLU(),
    nn.Linear(4, 3),
    nn.LeakyReLU(),
    nn.Linear(3, y.size(-1)),
)

model

최적화 함수를 적용해줍시다.

optimizer = optim.SGD(model.parameters(), lr=learning_rate)

 

4. 모델 실행

# |x| = (total_size, input_dim)
# |y| = (total_size, output_dim)

for i in range(n_epochs):
    # Shuffle the index to feed-forward.
    indices = torch.randperm(x.size(0))
    x_ = torch.index_select(x, dim=0, index=indices)
    y_ = torch.index_select(y, dim=0, index=indices)
    
    x_ = x_.split(batch_size, dim=0)
    y_ = y_.split(batch_size, dim=0)
    # |x_[i]| = (batch_size, input_dim)
    # |y_[i]| = (batch_size, output_dim)
    
    y_hat = []
    total_loss = 0
    
    for x_i, y_i in zip(x_, y_):
        # |x_i| = |x_[i]|
        # |y_i| = |y_[i]|
        y_hat_i = model(x_i)
        loss = F.mse_loss(y_hat_i, y_i)

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        
        total_loss += float(loss) # This is very important to prevent memory leak.
        y_hat += [y_hat_i]

    total_loss = total_loss / len(x_)
    if (i + 1) % print_interval == 0:
        print('Epoch %d: loss=%.4e' % (i + 1, total_loss))
    
y_hat = torch.cat(y_hat, dim=0)
y = torch.cat(y_, dim=0)
# |y_hat| = (total_size, output_dim)
# |y| = (total_size, output_dim)
Epoch 200: loss=3.3620e-01
Epoch 400: loss=3.1224e-01
Epoch 600: loss=3.1349e-01
......
Epoch 3800: loss=3.0736e-01
Epoch 4000: loss=3.0732e-01

 

5. 시각화

df = pd.DataFrame(torch.cat([y, y_hat], dim=1).detach().numpy(),
                  columns=["y", "y_hat"])

sns.pairplot(df, height=5)
plt.show()

728x90
Comments