
PyTorch Lightning

Integrating Hydra with PyTorch Lightning makes machine learning projects more flexible and easier to scale. Hydra handles configuration management, letting you change hyperparameters and settings without touching the codebase, while PyTorch Lightning provides a structured framework that streamlines PyTorch training code. Lightning also supports advanced features such as mixed precision training, Fully Sharded Data Parallel (FSDP), multi-node Distributed Data Parallel (DDP), and multi-GPU training, making the combination a strong choice for scaling and optimizing deep learning workflows.

Example: Integrating Hydra with PyTorch Lightning

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import lightning as L
from omegaconf import DictConfig
import hydra
from lightning.pytorch.loggers import WandbLogger

torch.set_float32_matmul_precision('high')

class LitModel(L.LightningModule):
    def __init__(self, input_dim, hidden_dim, output_dim, learning_rate):
        super().__init__()
        self.save_hyperparameters()  # exposes the init arguments via self.hparams
        self.layer_1 = nn.Linear(input_dim * input_dim, hidden_dim)  # flattened 28x28 MNIST images
        self.layer_2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer_1(x))
        x = self.layer_2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)  # log training loss to the configured logger (W&B here)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig):
    # Data
    # Note: with version_base=None, Hydra changes the working directory per run,
    # so this relative root downloads MNIST into each run's output directory
    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
    train_loader = DataLoader(mnist_train, batch_size=32, num_workers=15)
    val_loader = DataLoader(mnist_val, batch_size=32, num_workers=15)

    # Model
    model = LitModel(
        input_dim=cfg.model.input_dim,
        hidden_dim=cfg.model.hidden_dim,
        output_dim=cfg.model.output_dim,
        learning_rate=cfg.model.learning_rate
    )

    # Initialize W&B logger
    wandb_logger = WandbLogger(project='my-awesome-project')

    # Trainer
    trainer = L.Trainer(
        accelerator='gpu',
        devices=cfg.trainer.gpus,
        max_epochs=cfg.trainer.max_epochs,
        precision='16-mixed',
        logger=wandb_logger,
    )

    # Training
    trainer.fit(model, train_loader, val_loader)

if __name__ == "__main__":
    main()

Create a Configuration File (since the script sets config_path="configs", save it as configs/config.yaml next to the training script):

# configs/config.yaml
model:
  input_dim: 28
  hidden_dim: 64
  output_dim: 10
  learning_rate: 0.001

trainer:
  max_epochs: 10
  gpus: 1
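
Optionally, as the project grows, this flat config can be split into Hydra config groups so that whole blocks (for example, alternative model sizes) can be swapped from the command line. A minimal sketch, where the configs/model/ group and the small name are illustrative and not part of the example above:

# configs/config.yaml
defaults:
  - model: small   # loads configs/model/small.yaml into the model node
  - _self_

trainer:
  max_epochs: 10
  gpus: 1

# configs/model/small.yaml
input_dim: 28
hidden_dim: 64
output_dim: 10
learning_rate: 0.001

With this layout, python train.py model=large would switch to a configs/model/large.yaml without changing any code.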

Install the Required Libraries:

pip install lightning hydra-core wandb torchvision
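
Because the script logs metrics through WandbLogger, authenticate with your Weights & Biases account once before the first run (or set the WANDB_API_KEY environment variable):

wandb login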

Run the Training Script (assuming the code above is saved as train.py):

python train.py

To override specific parameters without modifying the config.yaml file, use command-line arguments:

python train.py model.learning_rate=0.01 trainer.max_epochs=20
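
Hydra's multirun mode (-m or --multirun) extends this to parameter sweeps, launching one training run per combination of the listed values:

python train.py -m model.learning_rate=0.001,0.01 model.hidden_dim=64,128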

Benefits of Using Hydra with PyTorch Lightning:

  • Flexible Configuration Management: Hydra allows you to maintain a clean separation between code and configuration, facilitating easy experimentation with different settings.
  • Command-Line Overrides: Easily adjust parameters via command-line arguments, enabling rapid testing of various configurations.
  • Scalability: PyTorch Lightning's structured approach, combined with Hydra's configuration management, supports scaling from simple experiments to complex training pipelines.
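
To use the multi-GPU and multi-node features mentioned in the introduction, the Trainer construction in main() can be driven by a couple of extra config keys. The sketch below assumes hypothetical trainer.strategy and trainer.num_nodes entries that are not in the config.yaml shown earlier:

    # drop-in replacement for the Trainer block in main() above
    trainer = L.Trainer(
        accelerator='gpu',
        devices=cfg.trainer.gpus,          # GPUs per node, e.g. 4
        num_nodes=cfg.trainer.num_nodes,   # assumed extra key, e.g. 1
        strategy=cfg.trainer.strategy,     # assumed extra key, e.g. 'ddp' or 'fsdp'
        max_epochs=cfg.trainer.max_epochs,
        precision='16-mixed',
        logger=wandb_logger,
    )

A single-node, 4-GPU DDP run could then be launched with python train.py trainer.strategy=ddp trainer.gpus=4 trainer.num_nodes=1.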