Examples

Pure PyTorch

Shows how to embed scalers in a torch.nn.Module, fit them on training data, run a mini training loop, and round-trip through a checkpoint:

  1"""Pure PyTorch example: embedding scalers in a model with checkpointing.
  2
  3Shows how to:
  4- Embed ZScoreScaler instances as nn.Module submodules.
  5- Fit the scalers on training data before training begins.
  6- Save and reload the full model (scaler statistics are included automatically
  7  in the state_dict because they are registered as nn.Module buffers).
  8
  9Run with:
 10    uv run examples/pytorch_example.py
 11"""
 12
 13import os
 14import tempfile
 15
 16import torch
 17import torch.nn as nn
 18
 19from torchscalers import ZScoreScaler
 20
 21# ---------------------------------------------------------------------------
 22# Synthetic dataset
 23# ---------------------------------------------------------------------------
 24torch.manual_seed(0)
 25
 26N_TRAIN, N_TEST = 400, 100
 27N_FEATURES, N_TARGETS = 8, 1
 28
 29X_train = torch.randn(N_TRAIN, N_FEATURES) * 3 + 5  # deliberately off-centre
 30y_train = torch.randn(N_TRAIN, N_TARGETS) * 10 - 2
 31X_test = torch.randn(N_TEST, N_FEATURES) * 3 + 5
 32y_test = torch.randn(N_TEST, N_TARGETS) * 10 - 2
 33
 34
 35# ---------------------------------------------------------------------------
 36# Model definition
 37# ---------------------------------------------------------------------------
 38class SimpleModel(nn.Module):
 39    """Linear regression model with embedded input and target scalers.
 40
 41    The scalers are stored as child modules, so their fitted statistics are
 42    automatically included in state_dict() and moved with .to(device).
 43    """
 44
 45    def __init__(self, in_features: int, out_features: int) -> None:
 46        super().__init__()
 47        self.feature_scaler = ZScoreScaler()
 48        self.target_scaler = ZScoreScaler()
 49        self.linear = nn.Linear(in_features, out_features)
 50
 51    def forward(self, x: torch.Tensor) -> torch.Tensor:
 52        # Calling the scaler directly (scaler(x)) is equivalent to
 53        # scaler.transform(x) — forward() delegates to transform().
 54        x = self.feature_scaler(x)
 55        return self.linear(x)
 56
 57
 58# ---------------------------------------------------------------------------
 59# Step 1: fit scalers on training data only (before any training loop)
 60# ---------------------------------------------------------------------------
 61model = SimpleModel(N_FEATURES, N_TARGETS)
 62
 63model.feature_scaler.fit(X_train)
 64model.target_scaler.fit(y_train)
 65
 66print("After fitting:")
 67print(f"  feature_scaler.mean = {model.feature_scaler.mean}")
 68print(f"  target_scaler.mean  = {model.target_scaler.mean}")
 69
 70# ---------------------------------------------------------------------------
 71# Step 2: minimal training loop
 72# ---------------------------------------------------------------------------
 73optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
 74loss_fn = nn.MSELoss()
 75
 76model.train()
 77for epoch in range(3):
 78    optimiser.zero_grad()
 79    pred = model(X_train)
 80    # Scale targets into normalised space before computing the loss.
 81    target = model.target_scaler(y_train)
 82    loss = loss_fn(pred, target)
 83    loss.backward()
 84    optimiser.step()
 85    print(f"  epoch {epoch + 1}/3  loss={loss.item():.4f}")
 86
 87# ---------------------------------------------------------------------------
 88# Step 3: save checkpoint — scaler statistics are included automatically
 89# ---------------------------------------------------------------------------
 90with tempfile.TemporaryDirectory() as tmpdir:
 91    ckpt_path = os.path.join(tmpdir, "checkpoint.pt")
 92    torch.save(model.state_dict(), ckpt_path)
 93    print(f"\nCheckpoint saved to {ckpt_path!r}")
 94
 95    # Step 4: reload into a fresh (unfitted) model
 96    fresh_model = SimpleModel(N_FEATURES, N_TARGETS)
 97    state = torch.load(ckpt_path, weights_only=True)
 98    fresh_model.load_state_dict(state)
 99    print("Checkpoint reloaded successfully.")
100
101    # Step 5: verify the statistics survived the round-trip
102    assert torch.allclose(
103        model.feature_scaler.mean, fresh_model.feature_scaler.mean
104    ), "feature_scaler.mean mismatch after reload!"
105    assert torch.allclose(
106        model.target_scaler.mean, fresh_model.target_scaler.mean
107    ), "target_scaler.mean mismatch after reload!"
108    print("Scaler statistics verified — round-trip OK.")
109
110    # Step 6: inference with inverse transform
111    fresh_model.eval()
112    with torch.no_grad():
113        pred_scaled = fresh_model(X_test)
114        # Bring predictions back to the original target scale.
115        pred_orig = fresh_model.target_scaler.inverse_transform(pred_scaled)
116
117    print(f"\nFirst 5 predictions (original scale): {pred_orig[:5].squeeze().tolist()}")
Run with:

uv run examples/pytorch_example.py

PyTorch Lightning

Shows how to fit scalers inside a LightningDataModule (on the train split only, to prevent data leakage), pass them to a LightningModule so they are checkpointed automatically, and restore a run from the best checkpoint:

  1"""PyTorch Lightning example: scalers in a DataModule with checkpointing.
  2
  3Shows how to:
  4- Fit scalers inside a LightningDataModule (on train split only, to avoid
  5  data leakage).
  6- Pass the fitted scaler instances to a LightningModule so they become part
  7  of the model's state_dict and are saved automatically with every checkpoint.
  8- Use DataModule.state_dict / load_state_dict hooks to also persist scaler
  9  stats in the DataModule checkpoint (optional — documented here because the
 10  pattern is useful when the DataModule lives independently of the model).
 11- Restore a run from a checkpoint.
 12
 13Run with (requires the 'examples' optional dependency group):
 14    uv sync --extra examples
 15    uv run examples/lightning_example.py
 16"""
 17
 18import lightning as L
 19import torch
 20import torch.nn as nn
 21from torch import Tensor
 22from torch.utils.data import DataLoader, TensorDataset
 23
 24from torchscalers import ZScoreScaler
 25
 26# ---------------------------------------------------------------------------
 27# Synthetic dataset
 28# ---------------------------------------------------------------------------
 29torch.manual_seed(0)
 30
 31N_TOTAL, N_FEATURES, N_TARGETS = 500, 8, 1
 32X_all = torch.randn(N_TOTAL, N_FEATURES) * 3 + 5
 33y_all = torch.randn(N_TOTAL, N_TARGETS) * 10 - 2
 34
 35
 36# ---------------------------------------------------------------------------
 37# DataModule
 38# ---------------------------------------------------------------------------
 39class ExampleDataModule(L.LightningDataModule):
 40    """DataModule that fits scalers on the training split only."""
 41
 42    def __init__(self, X: Tensor, y: Tensor, batch_size: int = 64) -> None:
 43        super().__init__()
 44        self.X = X
 45        self.y = y
 46        self.batch_size = batch_size
 47
 48        # Scalers are kept here so they can be passed to the model after
 49        # setup() has been called manually (see usage pattern below).
 50        self.feature_scaler = ZScoreScaler()
 51        self.target_scaler = ZScoreScaler()
 52
 53    def setup(self, stage: str) -> None:
 54        if stage == "fit":
 55            n_train = int(len(self.X) * 0.8)
 56            X_train, X_val = self.X[:n_train], self.X[n_train:]
 57            y_train, y_val = self.y[:n_train], self.y[n_train:]
 58
 59            # Fit on training data only to prevent leakage into validation.
 60            self.feature_scaler.fit(X_train)
 61            self.target_scaler.fit(y_train)
 62
 63            self.train_dataset = TensorDataset(X_train, y_train)
 64            self.val_dataset = TensorDataset(X_val, y_val)
 65
 66    def train_dataloader(self) -> DataLoader:
 67        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
 68
 69    def val_dataloader(self) -> DataLoader:
 70        return DataLoader(self.val_dataset, batch_size=self.batch_size)
 71
 72    # ------------------------------------------------------------------
 73    # Optional DataModule checkpoint hooks.
 74    #
 75    # Because the scalers will also live inside the model as child modules
 76    # (see ExampleModel below), their statistics are already saved in the
 77    # model checkpoint automatically.  These hooks are shown here for cases
 78    # where the DataModule must be restored independently of the model.
 79    # ------------------------------------------------------------------
 80    def state_dict(self) -> dict:
 81        return {
 82            "feature_scaler": self.feature_scaler.state_dict(),
 83            "target_scaler": self.target_scaler.state_dict(),
 84        }
 85
 86    def load_state_dict(self, state_dict: dict) -> None:
 87        self.feature_scaler.load_state_dict(state_dict["feature_scaler"])
 88        self.target_scaler.load_state_dict(state_dict["target_scaler"])
 89
 90
 91# ---------------------------------------------------------------------------
 92# LightningModule
 93# ---------------------------------------------------------------------------
 94class ExampleModel(L.LightningModule):
 95    """Linear regression model that owns the fitted scalers as submodules."""
 96
 97    def __init__(
 98        self,
 99        feature_scaler: ZScoreScaler,
100        target_scaler: ZScoreScaler,
101        in_features: int,
102        out_features: int,
103    ) -> None:
104        super().__init__()
105        # Storing the scalers as attributes registers them as child modules,
106        # so their buffers (fitted statistics) are saved in every checkpoint.
107        self.feature_scaler = feature_scaler
108        self.target_scaler = target_scaler
109        self.linear = nn.Linear(in_features, out_features)
110
111    def forward(self, x: Tensor) -> Tensor:
112        # scaler(x) calls scaler.transform(x) via forward().
113        return self.linear(self.feature_scaler(x))
114
115    def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
116        x, y = batch
117        pred = self(x)
118        target = self.target_scaler(y)  # scale targets into normalised space
119        loss = nn.functional.mse_loss(pred, target)
120        self.log("train_loss", loss, prog_bar=True)
121        return loss
122
123    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None:
124        x, y = batch
125        pred = self(x)
126        target = self.target_scaler(y)
127        loss = nn.functional.mse_loss(pred, target)
128        self.log("val_loss", loss, prog_bar=True)
129
130    def configure_optimizers(self):
131        return torch.optim.Adam(self.parameters(), lr=1e-3)
132
133
# ---------------------------------------------------------------------------
# Usage pattern
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dm = ExampleDataModule(X_all, y_all, batch_size=64)

    # setup() is called manually here so the fitted scalers exist before
    # the model is constructed and can be passed in as constructor
    # arguments.
    dm.setup(stage="fit")

    model = ExampleModel(
        feature_scaler=dm.feature_scaler,
        target_scaler=dm.target_scaler,
        in_features=N_FEATURES,
        out_features=N_TARGETS,
    )

    # Because the scalers are submodules of the model, their statistics are
    # written into every checkpoint the Trainer produces.
    trainer = L.Trainer(
        max_epochs=3,
        enable_model_summary=False,
        logger=False,
    )
    trainer.fit(model, dm)

    # -----------------------------------------------------------------------
    # Restore from the best checkpoint (Lightning writes one automatically)
    # -----------------------------------------------------------------------
    ckpt_path = trainer.checkpoint_callback.best_model_path
    if ckpt_path:
        # Rebuild the DataModule and refit its scalers so a fresh model can
        # be constructed before the checkpoint weights are loaded on top.
        dm2 = ExampleDataModule(X_all, y_all)
        dm2.setup(stage="fit")

        restored_model = ExampleModel(
            feature_scaler=dm2.feature_scaler,
            target_scaler=dm2.target_scaler,
            in_features=N_FEATURES,
            out_features=N_TARGETS,
        )
        checkpoint = torch.load(ckpt_path, weights_only=True)
        restored_model.load_state_dict(checkpoint["state_dict"])
        print(f"Model restored from {ckpt_path!r}")

        restored_model.eval()
        with torch.no_grad():
            X_sample = X_all[:5]
            scaled_pred = restored_model(X_sample)
            pred_orig = restored_model.target_scaler.inverse_transform(scaled_pred)
        print(f"Predictions (original scale): {pred_orig.squeeze().tolist()}")

Requires the optional examples dependency group:

uv sync --extra examples
uv run examples/lightning_example.py