본문 바로가기
Data Science/Deep Learning

빠르고 간편한 Pytorch Lightning 개념 정리

by Queen2 2023. 4. 4.
728x90
반응형

깃헙에서 코드를 구경하다가 Pytorch Lightning 을 발견했는데요. (공식 홈페이지 )

 

기존 Pytorch 보다 코드가 훨씬 단축되어서 딥러닝 작업을 훨씬 간단히 수행할 수 있습니다. Pytorch와 API도 유사해서 기존 Pytorch 코드도 쉽게 변환이 가능합니다. 

 

가장 좋은 점은 기존에 Pytorch에서 했던 데이터 로더, 모델 정의,손실함수, 최적화 및 학습 루프 등 딥러닝 작업을 위한 기본 모델 구성을 , Python Lightning에서는 모듈화한 간단한 명령으로 수행이 가능합니다.

 

Pytorch Lightning 설치

pip install pytorch-lightning

 

Pytorch와 Pytorh_lightning 비교

Pytorch

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = x.view(-1, 784)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

 

Pytorch_lightning

# Python Lightning
import pytorch_lightning as pl
import torch.nn.functional as F

class LitNet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F

 

위 예제에서 본 것 처럼 마지막의 훈련 시, 기본 손실함수 정의, 최적화, 모델 정의를 훨씬 간단하게 수행할 수 있도록 하는 장점이 있습니다. 공부하는 입장에서는 사실 처음부터 Pytorch Lightning 을 쓰면 딥러닝 모델의 구조를 파악하기 힘들다는 단점이 있지만 코드의 길이 측면의 효율성에서는 충분히 사용할 만한 것 같습니다.

 

r공식홈페이지 보면서 조금씩 탐구해봐야 겠습니다 :)

 

 

728x90
반응형

댓글