PyTorch Lightning
Integrating Hydra with PyTorch Lightning makes machine learning projects more flexible and easier to scale. Hydra handles configuration management, so hyperparameters and settings can be changed without touching the code, while PyTorch Lightning provides a structured framework around raw PyTorch training code. Lightning also supports advanced features such as mixed precision training, Fully Sharded Data Parallel (FSDP), Distributed Data Parallel (DDP) across multiple nodes, and multi-GPU training, making the combination well suited to scaling and optimizing deep learning workflows.
Example: Integrating Hydra with PyTorch Lightning
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import lightning as L
from omegaconf import DictConfig
import hydra
from lightning.pytorch.loggers import WandbLogger

torch.set_float32_matmul_precision('high')


class LitModel(L.LightningModule):
    def __init__(self, input_dim, hidden_dim, output_dim, learning_rate):
        super().__init__()
        # Store constructor arguments under self.hparams and in checkpoints.
        self.save_hyperparameters()
        self.layer_1 = nn.Linear(input_dim * input_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Flatten 28x28 images into vectors before the linear layers.
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer_1(x))
        x = self.layer_2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig):
    # Data
    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
    train_loader = DataLoader(mnist_train, batch_size=32, num_workers=15)
    val_loader = DataLoader(mnist_val, batch_size=32, num_workers=15)

    # Model
    model = LitModel(
        input_dim=cfg.model.input_dim,
        hidden_dim=cfg.model.hidden_dim,
        output_dim=cfg.model.output_dim,
        learning_rate=cfg.model.learning_rate,
    )

    # Initialize W&B logger
    wandb_logger = WandbLogger(project='my-awesome-project')

    # Trainer
    trainer = L.Trainer(
        accelerator='gpu',
        devices=cfg.trainer.gpus,
        max_epochs=cfg.trainer.max_epochs,
        precision='16-mixed',
        logger=wandb_logger,
    )

    # Training
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    main()
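Because Hydra composes the final configuration at launch time, a common optional addition is to print the resolved config at the start of main(), so each run records exactly which settings it used. A minimal sketch using OmegaConf.to_yaml (not part of the example above):

import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig):
    # Print the fully resolved configuration, including any command-line overrides.
    print(OmegaConf.to_yaml(cfg))
    # ... rest of main() as in the example above ...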
Create a Configuration File (placed in the configs/ directory named by config_path, relative to the script):
# configs/config.yaml
model:
  input_dim: 28
  hidden_dim: 64
  output_dim: 10
  learning_rate: 0.001

trainer:
  max_epochs: 10
  gpus: 1
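For larger projects, Hydra's config groups let you swap whole blocks of settings from the command line. A minimal sketch, assuming a hypothetical configs/model/ group that is not part of the example above:

# configs/config.yaml (config-group variant)
defaults:
  - model: small
  - _self_

trainer:
  max_epochs: 10
  gpus: 1

# configs/model/small.yaml
input_dim: 28
hidden_dim: 64
output_dim: 10
learning_rate: 0.001

With this layout, python train.py model=large would select a different model configuration, provided a configs/model/large.yaml exists.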
Install the Required Libraries:
pip install lightning hydra-core wandb torchvision
Run the Training Script:
python train.py
To override specific parameters without modifying the config.yaml file, use command-line arguments:
python train.py model.learning_rate=0.01 trainer.max_epochs=20
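For hyperparameter sweeps, Hydra's multirun mode (the -m / --multirun flag) launches one run per combination of the listed values; the values below are illustrative:

python train.py -m model.learning_rate=0.001,0.01 model.hidden_dim=64,128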
Benefits of Using Hydra with PyTorch Lightning:
- Flexible Configuration Management: Hydra allows you to maintain a clean separation between code and configuration, facilitating easy experimentation with different settings.
- Command-Line Overrides: Easily adjust parameters via command-line arguments, enabling rapid testing of various configurations.
- Scalability: PyTorch Lightning's structured approach, combined with Hydra's configuration management, supports scaling from simple single-GPU experiments to multi-GPU and multi-node training pipelines (see the sketch below).
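As noted in the introduction, the same script scales to multi-GPU and multi-node training by adjusting only the Trainer. A hedged sketch of how the Trainer block inside main() might look for DDP or FSDP; the device and node counts are illustrative assumptions, not values from config.yaml above:

# Drop-in variant of the Trainer block inside main() (illustrative values).
trainer = L.Trainer(
    accelerator='gpu',
    devices=4,                 # GPUs per node (assumed)
    num_nodes=2,               # number of machines (assumed)
    strategy='ddp',            # or 'fsdp' for Fully Sharded Data Parallel
    precision='16-mixed',      # mixed precision training
    max_epochs=cfg.trainer.max_epochs,
    logger=wandb_logger,
)

Single-node multi-GPU DDP typically needs no extra launch tooling; multi-node runs additionally require an appropriate launcher or cluster environment (for example, SLURM).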