Segmentation Use Case — Per-instance & Per-sample Signals (PyTorch)¶
This page walks through the segmentation integration from:
weightslab/examples/PyTorch/ws-segmentation/main.py
It builds on the classification Use Case Example (PyTorch) page and focuses on what is specific to segmentation: a list of per-instance masks per sample, a custom collate, and custom user signals computed both per-sample and per-instance (Dice as a metric, BCE as a loss).
Goal¶
Use Weightslab to:
track a U-Net segmentation model with tracked loaders,
store per-instance targets (one mask per object/class) alongside the per-sample view,
compute and log custom Dice/BCE signals at both granularities,
drive the dashboard’s shape × ODD-slice analysis from per-instance signals.
The multi-index data model¶
Segmentation samples are expanded into a (sample_id, annotation_id)
multi-index:
annotation_id == 0is the canonical sample row — it holds per-sample predictions/targets/signals plus sample-level metadata, origin and tags.annotation_id >= 1are the instance rows — one per object/class mask, holding only that instance’s target and per-instance signals.
So a sample with N instance masks occupies N + 1 rows. The studio collapses
this back to one row per sample for display: it keeps the instance_id 0 row
as the main value and only falls back to aggregating the instance rows for cells
that are empty on the sample row.
1) A dataset that returns a list of instance masks¶
utils/data.py returns (image, uid, instances, metadata) where
instances is a list of per-instance mask tensors (pixel value = class id):
def get_items(self, idx, ...):
...
mask_t_instances = []
lbl_max = mask_t.max().item()
for i in range(1, lbl_max + 1):
m = torch.zeros_like(mask_t)
m[mask_t == i] = i # one mask per class id
mask_t_instances.append(m)
return img_t, uid, mask_t_instances, metadata
Why a list (not a single dense mask):
A list of array-like targets is what tells Weightslab there are multiple instances → it creates instance rows
1..N. A single array/scalar target is treated as the sample’s own target and stays oninstance_id 0only.
2) A custom collate to batch variable-length instance lists¶
The default PyTorch collate cannot batch variable-length lists, so the example
provides seg_collate and passes it to the loader:
from utils.data import seg_collate
train_loader = wl.watch_or_edit(
_train_dataset, flag="data", loader_name="train_loader",
batch_size=2, collate_fn=seg_collate, preload_labels=False,
)
seg_collate returns (images, ids, labels, metas) where labels is a
list[B] and labels[s] is that sample’s list of instance masks (empty
masks filtered out).
3) Custom Dice (metric) and BCE (loss) signals¶
utils/criterions.py defines four small nn.Module criterions. Each scores
every instance mask against the model’s per-class probability map, then exposes
the value either aggregated per sample or flat per instance:
from utils.criterions import (
PerSampleDice, PerInstanceDice, # metric
PerSampleBCE, PerInstanceBCE, # loss
)
def _make_seg_signals(split):
return {
"dice_sample": wl.watch_or_edit(PerSampleDice(), flag="metric",
name=f"{split}_dice/sample", per_sample=True, log=True),
"dice_instance": wl.watch_or_edit(PerInstanceDice(), flag="metric",
name=f"{split}_dice/instance", per_instance=True, log=True),
"bce_sample": wl.watch_or_edit(PerSampleBCE(), flag="loss",
name=f"{split}_bce/sample", per_sample=True, log=True),
"bce_instance": wl.watch_or_edit(PerInstanceBCE(), flag="loss",
name=f"{split}_bce/instance", per_instance=True, log=True),
}
train_sig = _make_seg_signals("train")
test_sig = _make_seg_signals("test")
Why two flavors:
per_sample=True→ the returned[B]vector is logged and written to the sample row (instance_id 0) via the per-sample path.per_instance=True→ the returned flat ``[total_instances]`` tensor is auto-saved at(sample_id, annotation_id)forannotation_id >= 1viawl.save_instance_signals(). Ordering is sample-major and must match thebatch_idxyou pass (see next step).
4) The training step: build batch_idx and route signals¶
The per-instance wrapper needs a batch_idx that maps each instance (in flat,
sample-major order) to its sample position; build it from the same instance
lists so ordering lines up:
def _instance_batch_idx(labels):
return torch.tensor(
[s for s, insts in enumerate(labels) for _ in insts],
dtype=torch.long,
)
with guard_training_context:
inputs, ids, labels, _ = next(loader)
outputs = model(inputs) # [B, C, H, W]
batch_idx = _instance_batch_idx(labels)
# Per-sample (→ IID 0) and per-instance (→ IID >= 1) signals.
bce_sample = sig["bce_sample"](outputs, labels, batch_ids=ids, preds=preds)
dice_sample = sig["dice_sample"](outputs, labels, batch_ids=ids)
sig["dice_instance"](outputs, labels, batch_ids=ids, batch_idx=batch_idx, targets=flat_masks)
sig["bce_instance"](outputs, labels, batch_ids=ids, batch_idx=batch_idx)
# Custom per-sample aggregate, saved on the sample row and used for backward.
loss_per_sample = 0.5 * dice_sample + 0.5 * bce_sample
wl.save_signals({"combined_bce_dice_per_sample": loss_per_sample}, ids)
loss_per_sample.mean().backward()
optimizer.step()
Important:
Pass
batch_ids=idsto every watched criterion so Weightslab can bind values to real samples (and apply discard masking).For per-instance criterions also pass
batch_idx=...; passtargets=(a flat, sample-major list of per-instance masks) to also persist the instance GT masks atannotation_id >= 1.per_instanceannotation ids are 1-based (instance_id 0is reserved for the sample row), assigned in the order instances appear per sample.
5) Custom static / dynamic signals (@wl.signal)¶
utils/criterions.py also registers free-form signals via custom_signals()
— a static signal computed from the image, and a dynamic signal that reacts to a
logged metric:
@wl.signal(name="blue_pixels") # STATIC: from ctx.image
def compute_blue_pixels(ctx: wl.SignalContext) -> int:
img = ctx.image
...
return int(blue_mask.sum())
@wl.signal(name="blue_weighted_loss", # DYNAMIC: subscribes to a metric
subscribe_to="train_mlt_loss/CE", compute_every_n_steps=1)
def compute_blue_weighted_loss(ctx: wl.SignalContext) -> float:
blue = ctx.dataframe.get_value(ctx.origin, ctx.sample_id, "signals_blue_pixels")
return ctx.subscribed_value * (float(blue or 0) / (128 * 128))
custom_signals() # register before wl.serve()
See User Functions Reference for the @wl.signal / SignalContext reference.
Where the arrays come from in the studio¶
When the UI requests a sample for a segmentation run:
Raw image — read directly from the dataset file each time (never stored in the dataframe).
Prediction mask — loaded lazily from the array store (
arrays.h5) via a proxy, from whatever the per-sample path saved oninstance_id 0.GT label — taken from the sample row’s
targetif present, otherwise reconstructed from the dataset file; the individual per-instance masks live oninstance_id >= 1.
Practical checklist¶
Return a list of instance masks per sample and wire
collate_fn=seg_collate.Wrap per-sample criterions with
per_sample=Trueand per-instance ones withper_instance=True.Build
batch_idxfrom the same instance lists; pass it (andbatch_ids) to the per-instance criterions, plus a flattargetslist to persist instance masks.Keep Dice as a
metricand BCE as aloss; aggregate them per sample for the backward pass withwl.save_signals.