PyTorch Lightning Integration

Weightslab is compatible with PyTorch Lightning and already includes a full example:

weightslab/examples/PyTorch_Lightning/ws-classification/main.py

This page explains how to integrate Weightslab into a Lightning workflow and scale to multiple GPUs.

Minimal integration pattern

  1. Wrap components with Weightslab (the sketch after this list shows the plain PyTorch objects being wrapped):

    • model with flag="model"

    • optimizer with flag="optimizer"

    • data loaders with flag="data"

    • loss and metric with flag="loss" / flag="metric"

  2. Build a LightningModule that uses Weightslab-wrapped objects.

  3. Use guard_training_context and guard_testing_context inside step methods.

  4. Start Weightslab services before trainer.fit(...).
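
For step 1, the objects being wrapped are ordinary PyTorch components. Below is a minimal sketch of one of each, assuming an MNIST-style classifier, torchvision for the data, and torchmetrics for the metric; the architecture, dataset, and learning rate are illustrative and not part of the Weightslab API.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchmetrics

# Plain PyTorch objects; each one is then wrapped by Weightslab with the
# matching flag ("model", "optimizer", "data", "loss", "metric").
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train_dataset = datasets.MNIST("data", train=True, download=True,
                               transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Per-sample losses, consistent with loss_batch.mean() in the excerpt below.
criterion = nn.CrossEntropyLoss(reduction="none")
metric = torchmetrics.classification.MulticlassAccuracy(num_classes=10)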

LightningModule excerpt

import torch
import pytorch_lightning as pl
# guard_training_context / guard_testing_context come from the Weightslab
# integration; see the example's main.py for the exact import.


class LitMNIST(pl.LightningModule):
    def __init__(self, model, optim, train_criterion_wl, val_criterion_wl, metric_wl):
        super().__init__()
        self.model = model
        self.optimizer = optim
        self.train_criterion_wl = train_criterion_wl
        self.val_criterion_wl = val_criterion_wl
        self.metric_wl = metric_wl

    def training_step(self, batch):
        with guard_training_context:
            x, ids, y = batch
            logits = self.model(x)
            preds = torch.argmax(logits, dim=1)
            loss_batch = self.train_criterion_wl(
                logits.float(),
                y.long(),
                batch_ids=ids,
                preds=preds,
            )
            return loss_batch.mean()

    def validation_step(self, batch):
        with guard_testing_context:
            x, ids, y = batch
            logits = self.model(x)
            preds = torch.argmax(logits, dim=1)
            self.val_criterion_wl(logits.float(), y.long(), batch_ids=ids, preds=preds)
            self.metric_wl.update(logits, y)

    def configure_optimizers(self):
        return self.optimizer
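
Each step method unpacks a batch as (x, ids, y), where ids are per-sample identifiers used for traceability. Purely to illustrate that batch layout (not how the Weightslab data wrapper produces the ids), a plain dataset can carry the sample index as its id, as in this sketch:

from torch.utils.data import Dataset

class IndexedDataset(Dataset):
    """Wraps an (x, y) dataset so every item also carries a stable sample id."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x, y = self.base[idx]
        return x, idx, y  # matches the `x, ids, y = batch` unpacking above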

Single-GPU trainer setup

trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    log_every_n_steps=0,
    enable_checkpointing=False,
    logger=False,
)

logger=False is intentional here because Weightslab manages training signals directly.

Multi-GPU (DDP) setup

For multiple GPUs on one node, use DDP:

use_gpu = torch.cuda.is_available()
gpu_count = torch.cuda.device_count() if use_gpu else 0
multi_gpu = gpu_count > 1

trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator="gpu" if use_gpu else "cpu",
    devices=gpu_count if multi_gpu else 1,
    strategy="ddp" if multi_gpu else "auto",
    sync_batchnorm=multi_gpu,
    use_distributed_sampler=multi_gpu,
    log_every_n_steps=0,
    enable_checkpointing=False,
    logger=False,
)

Notes:

  • use_distributed_sampler=True has Lightning wrap the loaders in a DistributedSampler, so each rank sees a distinct shard of the data each epoch.

  • Keep batch_ids passed to losses/signals to preserve per-sample traceability.

  • If your effective (total) batch size changes with GPU count, retune the learning rate and/or per-device batch size; a linear-scaling sketch follows this list.
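
For the last point, a common heuristic is linear learning-rate scaling: when the effective batch size grows with the number of GPUs, multiply the single-device learning rate by the world size and re-validate. A sketch, reusing gpu_count from the snippet above (base_lr and the choice of Adam are illustrative):

import torch

base_lr = 1e-3                      # LR tuned for a single device
world_size = max(gpu_count, 1)      # gpu_count computed above
scaled_lr = base_lr * world_size    # linear-scaling heuristic; validate empirically

optimizer = torch.optim.Adam(model.parameters(), lr=scaled_lr)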

Optional YAML-driven trainer config

The Lightning example already includes a ready-to-use template at:

weightslab/examples/PyTorch_Lightning/ws-classification/config.yaml

lightning:
  max_epochs: 10
  accelerator: gpu
  devices: 2
  strategy: ddp
  sync_batchnorm: true

Then map the lightning block into pl.Trainer(...) in your script, as sketched below.
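
A minimal sketch of that mapping, assuming PyYAML is installed and the config.yaml layout shown above (the remaining keyword arguments mirror the trainer settings used earlier):

import yaml
import pytorch_lightning as pl

# Adjust the path to wherever the example's config.yaml lives in your project.
with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

trainer = pl.Trainer(
    **cfg["lightning"],        # max_epochs, accelerator, devices, strategy, sync_batchnorm
    log_every_n_steps=0,
    enable_checkpointing=False,
    logger=False,
)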

Quick switch examples:

  • Single GPU: devices: 1, strategy: auto

  • Multi-GPU (single node): devices: 2 (or more), strategy: ddp

Preset blocks are also provided directly in the example config.yaml under lightning_presets. To switch mode, copy either single_gpu or multi_gpu_ddp into the top-level lightning block.
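
If you prefer to switch from code rather than by editing the file, here is a sketch that copies a preset into the lightning block before building the trainer; it assumes cfg was loaded as in the mapping sketch above and that the presets hold valid pl.Trainer keyword arguments, as in the template.

import torch
import pytorch_lightning as pl

preset = "multi_gpu_ddp" if torch.cuda.device_count() > 1 else "single_gpu"
cfg["lightning"] = dict(cfg["lightning_presets"][preset])

trainer = pl.Trainer(**cfg["lightning"], enable_checkpointing=False, logger=False)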

End-to-end sequence

# 1) Wrap hyperparameters/model/data/optimizer/loss/metric
# 2) Build LightningModule with wrapped objects
# 3) Start Weightslab services
wl.serve(serving_grpc=False, serving_cli=False)

# 4) Train with Lightning
trainer.fit(lightning_module, train_loader, val_loader)

# 5) Keep services alive if needed
wl.keep_serving()