Examples
Pure PyTorch
Shows how to embed scalers in a torch.nn.Module, fit them on
training data, run a mini training loop, and round-trip through a checkpoint:
"""Pure PyTorch example: embedding scalers in a model with checkpointing.

Shows how to:
- Embed ZScoreScaler instances as nn.Module submodules.
- Fit the scalers on training data before training begins.
- Save and reload the full model (scaler statistics are included automatically
  in the state_dict because they are registered as nn.Module buffers).

Run with:
    uv run examples/pytorch_example.py
"""

import os
import tempfile

import torch
import torch.nn as nn

from torchscalers import ZScoreScaler
# ---------------------------------------------------------------------------
# Synthetic dataset
# ---------------------------------------------------------------------------
# Seed the global RNG so the example is fully reproducible across runs.
torch.manual_seed(0)

N_TRAIN, N_TEST = 400, 100
N_FEATURES, N_TARGETS = 8, 1

# Features and targets are deliberately off-centre / off-scale so the scalers
# have real work to do (mean ~5 / std ~3 for X, mean ~-2 / std ~10 for y).
X_train = torch.randn(N_TRAIN, N_FEATURES) * 3 + 5  # deliberately off-centre
y_train = torch.randn(N_TRAIN, N_TARGETS) * 10 - 2
X_test = torch.randn(N_TEST, N_FEATURES) * 3 + 5
y_test = torch.randn(N_TEST, N_TARGETS) * 10 - 2
34
# ---------------------------------------------------------------------------
# Model definition
# ---------------------------------------------------------------------------
class SimpleModel(nn.Module):
    """Linear regression model with embedded input and target scalers.

    The scalers are stored as child modules, so their fitted statistics are
    automatically included in state_dict() and moved with .to(device).
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.feature_scaler = ZScoreScaler()
        self.target_scaler = ZScoreScaler()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Calling the scaler directly (scaler(x)) is equivalent to
        # scaler.transform(x) — forward() delegates to transform().
        x = self.feature_scaler(x)
        return self.linear(x)
56
# ---------------------------------------------------------------------------
# Step 1: fit scalers on training data only (before any training loop)
# ---------------------------------------------------------------------------
model = SimpleModel(N_FEATURES, N_TARGETS)

# Fitting stores per-feature statistics as buffers on the scaler modules,
# so they travel with the model's state_dict from here on.
model.feature_scaler.fit(X_train)
model.target_scaler.fit(y_train)

print("After fitting:")
print(f"  feature_scaler.mean = {model.feature_scaler.mean}")
print(f"  target_scaler.mean = {model.target_scaler.mean}")

# ---------------------------------------------------------------------------
# Step 2: minimal training loop
# ---------------------------------------------------------------------------
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

model.train()
for epoch in range(3):
    optimiser.zero_grad()
    pred = model(X_train)
    # Scale targets into normalised space before computing the loss.
    target = model.target_scaler(y_train)
    loss = loss_fn(pred, target)
    loss.backward()
    optimiser.step()
    print(f"  epoch {epoch + 1}/3 loss={loss.item():.4f}")

# ---------------------------------------------------------------------------
# Step 3: save checkpoint — scaler statistics are included automatically
# ---------------------------------------------------------------------------
with tempfile.TemporaryDirectory() as tmpdir:
    ckpt_path = os.path.join(tmpdir, "checkpoint.pt")
    torch.save(model.state_dict(), ckpt_path)
    print(f"\nCheckpoint saved to {ckpt_path!r}")

    # Step 4: reload into a fresh (unfitted) model
    fresh_model = SimpleModel(N_FEATURES, N_TARGETS)
    state = torch.load(ckpt_path, weights_only=True)
    fresh_model.load_state_dict(state)
    print("Checkpoint reloaded successfully.")

    # Step 5: verify the statistics survived the round-trip
    assert torch.allclose(
        model.feature_scaler.mean, fresh_model.feature_scaler.mean
    ), "feature_scaler.mean mismatch after reload!"
    assert torch.allclose(
        model.target_scaler.mean, fresh_model.target_scaler.mean
    ), "target_scaler.mean mismatch after reload!"
    print("Scaler statistics verified — round-trip OK.")

    # Step 6: inference with inverse transform
    fresh_model.eval()
    with torch.no_grad():
        pred_scaled = fresh_model(X_test)
        # Bring predictions back to the original target scale.
        pred_orig = fresh_model.target_scaler.inverse_transform(pred_scaled)

    print(f"\nFirst 5 predictions (original scale): {pred_orig[:5].squeeze().tolist()}")
uv run examples/pytorch_example.py
PyTorch Lightning
Shows how to fit scalers inside a LightningDataModule (on the train split
only, to prevent data leakage), pass them to a LightningModule so they are
checkpointed automatically, and restore a run from the best checkpoint:
"""PyTorch Lightning example: scalers in a DataModule with checkpointing.

Shows how to:
- Fit scalers inside a LightningDataModule (on train split only, to avoid
  data leakage).
- Pass the fitted scaler instances to a LightningModule so they become part
  of the model's state_dict and are saved automatically with every checkpoint.
- Use DataModule.state_dict / load_state_dict hooks to also persist scaler
  stats in the DataModule checkpoint (optional — documented here because the
  pattern is useful when the DataModule lives independently of the model).
- Restore a run from a checkpoint.

Run with (requires the 'examples' optional dependency group):
    uv sync --extra examples
    uv run examples/lightning_example.py
"""

import lightning as L
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset

from torchscalers import ZScoreScaler
# ---------------------------------------------------------------------------
# Synthetic dataset
# ---------------------------------------------------------------------------
# Seed the global RNG so the example is fully reproducible across runs.
torch.manual_seed(0)

N_TOTAL, N_FEATURES, N_TARGETS = 500, 8, 1
# Deliberately off-centre / off-scale data so the scalers have work to do.
X_all = torch.randn(N_TOTAL, N_FEATURES) * 3 + 5
y_all = torch.randn(N_TOTAL, N_TARGETS) * 10 - 2
35
# ---------------------------------------------------------------------------
# DataModule
# ---------------------------------------------------------------------------
class ExampleDataModule(L.LightningDataModule):
    """DataModule that fits scalers on the training split only."""

    def __init__(self, X: Tensor, y: Tensor, batch_size: int = 64) -> None:
        super().__init__()
        self.X = X
        self.y = y
        self.batch_size = batch_size

        # Scalers are kept here so they can be passed to the model after
        # setup() has been called manually (see usage pattern below).
        self.feature_scaler = ZScoreScaler()
        self.target_scaler = ZScoreScaler()

    def setup(self, stage: str) -> None:
        if stage == "fit":
            # 80/20 contiguous split; the data is i.i.d. so no shuffling needed.
            n_train = int(len(self.X) * 0.8)
            X_train, X_val = self.X[:n_train], self.X[n_train:]
            y_train, y_val = self.y[:n_train], self.y[n_train:]

            # Fit on training data only to prevent leakage into validation.
            self.feature_scaler.fit(X_train)
            self.target_scaler.fit(y_train)

            self.train_dataset = TensorDataset(X_train, y_train)
            self.val_dataset = TensorDataset(X_val, y_val)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    # ------------------------------------------------------------------
    # Optional DataModule checkpoint hooks.
    #
    # Because the scalers will also live inside the model as child modules
    # (see ExampleModel below), their statistics are already saved in the
    # model checkpoint automatically. These hooks are shown here for cases
    # where the DataModule must be restored independently of the model.
    # ------------------------------------------------------------------
    def state_dict(self) -> dict:
        return {
            "feature_scaler": self.feature_scaler.state_dict(),
            "target_scaler": self.target_scaler.state_dict(),
        }

    def load_state_dict(self, state_dict: dict) -> None:
        self.feature_scaler.load_state_dict(state_dict["feature_scaler"])
        self.target_scaler.load_state_dict(state_dict["target_scaler"])
89
90
# ---------------------------------------------------------------------------
# LightningModule
# ---------------------------------------------------------------------------
class ExampleModel(L.LightningModule):
    """Linear regression model that owns the fitted scalers as submodules."""

    def __init__(
        self,
        feature_scaler: ZScoreScaler,
        target_scaler: ZScoreScaler,
        in_features: int,
        out_features: int,
    ) -> None:
        super().__init__()
        # Storing the scalers as attributes registers them as child modules,
        # so their buffers (fitted statistics) are saved in every checkpoint.
        self.feature_scaler = feature_scaler
        self.target_scaler = target_scaler
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: Tensor) -> Tensor:
        # scaler(x) calls scaler.transform(x) via forward().
        return self.linear(self.feature_scaler(x))

    def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        x, y = batch
        pred = self(x)
        target = self.target_scaler(y)  # scale targets into normalised space
        loss = nn.functional.mse_loss(pred, target)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None:
        x, y = batch
        pred = self(x)
        target = self.target_scaler(y)
        loss = nn.functional.mse_loss(pred, target)
        self.log("val_loss", loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
133
# ---------------------------------------------------------------------------
# Usage pattern
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dm = ExampleDataModule(X_all, y_all, batch_size=64)

    # Call setup() manually before constructing the model so the fitted
    # scalers can be passed in as constructor arguments.
    dm.setup(stage="fit")

    model = ExampleModel(
        feature_scaler=dm.feature_scaler,
        target_scaler=dm.target_scaler,
        in_features=N_FEATURES,
        out_features=N_TARGETS,
    )

    # Scaler statistics are now part of model.state_dict() and will be saved
    # automatically in every checkpoint written by the Trainer.
    trainer = L.Trainer(
        max_epochs=3,
        enable_model_summary=False,
        logger=False,
    )
    trainer.fit(model, dm)

    # -----------------------------------------------------------------------
    # Restore from the best checkpoint (Lightning writes one automatically)
    # -----------------------------------------------------------------------
    ckpt_path = trainer.checkpoint_callback.best_model_path
    if ckpt_path:
        # Reconstruct the DataModule and set up scalers before loading.
        dm2 = ExampleDataModule(X_all, y_all)
        dm2.setup(stage="fit")

        restored_model = ExampleModel(
            feature_scaler=dm2.feature_scaler,
            target_scaler=dm2.target_scaler,
            in_features=N_FEATURES,
            out_features=N_TARGETS,
        )
        restored_model.load_state_dict(
            torch.load(ckpt_path, weights_only=True)["state_dict"]
        )
        print(f"Model restored from {ckpt_path!r}")

        restored_model.eval()
        with torch.no_grad():
            X_sample = X_all[:5]
            pred_scaled = restored_model(X_sample)
            pred_orig = restored_model.target_scaler.inverse_transform(pred_scaled)
            print(f"Predictions (original scale): {pred_orig.squeeze().tolist()}")
Requires the optional examples dependency group:
uv sync --extra examples
uv run examples/lightning_example.py