Medical Imaging Fine-Tuning & Ablation Study
By Fouzil Ali · Published June 29, 2026
Systematic ablation study on BreastMNIST comparing augmentation strategies, LR schedules, and backbone architectures (ResNet, EfficientNet, ViT) to find the configuration that best resists overfitting.
- medical-imaging
- fine-tuning
- ablation-study
- classification
- regularization
Inside this notebook
# Medical Imaging Fine-Tuning: Anti-Overfitting Pipeline + Ablations

### Dataset: BreastMNIST (MedMNIST v2): 780 images (546 train / 78 val / 156 test), binary breast ultrasound classification

---

## 🔬 SOTA Context (2024-2026 Literature)

| Finding | Evidence |
|---|---|
| **DINOv2 ViT-B/14** leads end-to-end fine-tuning on small medical sets | Self-supervised pretraining is key to bridging the domain gap |
| **CNNs are more reliable in very low-resource settings** (<500 images) | ResNet-18/EfficientNet-B0 are the preferred baselines |
| **CutMix/CutOut work best for ViTs**; simpler augmentation suits CNNs | Augmentation policy should match the architecture |
| **Warmup + cosine decay** is the dominant LR schedule for fine-tuning | Linear warmup stabilises early gradient flow |
| **BreastMNIST ViT-B/16 AUC ≈ 0.94** is the published strong baseline | MedMNIST v2 leaderboard |
| **PEFT / LoRA** is the top anti-overfitting strategy for transformer fine-tuning | Reduces trainable parameters from ~86M to ~1M |

---

## 🧰 Anti-Overfitting Stack Used in This Pipeline

| Technique | Where Applied | Why |
|---|---|---|
| **ImageNet pretraining** | All backbones | Warm start; reduces the amount of labelled data needed |
| **Label smoothing (ε=0.1)** | CrossEntropyLoss | Prevents over-confident predictions |
| **Weight decay (1e-2)** | AdamW optimizer | Decoupled weight decay: shrinks weights directly instead of adding an L2 term to the loss |
| **Dropout (0.3)** | Classification head | Stochastic regularization |
| **Gradient clipping (norm=1.0)** | Every backward pass | Stability on small batches |
| **Early stopping (patience=8)** |…
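Several rows of the stack map directly onto standard PyTorch calls. The following is a minimal sketch of how they might be combined, not the notebook's own training loop: the helper names `make_criterion_and_optimizer`, `train_step`, and `EarlyStopping` are hypothetical, and the early-stopping monitor is assumed to be a higher-is-better validation metric such as AUC. Only the values (ε=0.1, weight decay 1e-2, clip norm 1.0, patience 8) come from the table.

```python
import torch
import torch.nn as nn

# Values from the anti-overfitting table; everything else is illustrative.
LABEL_SMOOTHING = 0.1
WEIGHT_DECAY = 1e-2
CLIP_NORM = 1.0
PATIENCE = 8


def make_criterion_and_optimizer(model: nn.Module, lr: float = 1e-4):
    """Label-smoothed cross-entropy plus AdamW (decoupled weight decay)."""
    criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)
    return criterion, optimizer


def train_step(model, criterion, optimizer, images, labels) -> float:
    """One optimisation step with gradient-norm clipping; returns the batch loss."""
    model.train()
    optimizer.zero_grad(set_to_none=True)
    loss = criterion(model(images), labels)
    loss.backward()
    # Rescale gradients so their global L2 norm is at most CLIP_NORM.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=CLIP_NORM)
    optimizer.step()
    return loss.item()


class EarlyStopping:
    """Signal a stop once a higher-is-better metric fails to improve for `patience` epochs."""

    def __init__(self, patience: int = PATIENCE, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('-inf')
        self.bad_epochs = 0

    def step(self, metric: float) -> bool:
        """Record one epoch's metric; return True when training should stop."""
        if metric > self.best + self.min_delta:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

In a full loop, `EarlyStopping.step` would be called once per epoch with the validation AUC, and the checkpoint saved whenever `best` improves.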
# Medical Imaging Fine-Tuning: BreastMNIST Ablation Study

This notebook runs a systematic ablation study on the **BreastMNIST** dataset (binary classification of breast ultrasound images: malignant vs. normal/benign) along three axes of variation:

1. **Augmentation strategy**: none / basic / heavy (with Mixup)
2. **LR schedule**: constant / cosine / warmup_cosine
3. **Backbone architecture**: ResNet-18 / EfficientNet-B0 / ViT-Tiny

**Key outputs:** per-experiment AUC-ROC and balanced accuracy, learning curves, overfitting-gap analysis, and a saved checkpoint for the best configuration.
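The three LR-schedule settings differ only in the multiplier applied to the base learning rate at each step. A hedged sketch of one way to express them, using `LambdaLR` rather than the `CosineAnnealingLR` the notebook imports; `lr_multiplier`, `make_scheduler`, and the per-batch stepping convention are assumptions for illustration:

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def lr_multiplier(step: int, total_steps: int, mode: str, warmup_steps: int = 0) -> float:
    """Factor applied to the base LR at optimizer step `step` (0-indexed)."""
    if mode == 'constant':
        return 1.0
    if mode == 'warmup_cosine' and step < warmup_steps:
        # Linear ramp from 1/warmup_steps up to 1.0 during warmup.
        return (step + 1) / warmup_steps
    if mode in ('cosine', 'warmup_cosine'):
        # Cosine decay from 1.0 to 0.0 over the steps remaining after warmup.
        start = warmup_steps if mode == 'warmup_cosine' else 0
        progress = (step - start) / max(1, total_steps - start)
        progress = min(max(progress, 0.0), 1.0)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    raise ValueError(f'unknown schedule: {mode}')


def make_scheduler(optimizer, mode: str, total_steps: int, warmup_steps: int = 0) -> LambdaLR:
    """Wrap lr_multiplier in a LambdaLR that is stepped once per batch."""
    return LambdaLR(optimizer, lambda s: lr_multiplier(s, total_steps, mode, warmup_steps))
```

With 18 training batches per epoch (see the loader output), `total_steps` would be `18 * num_epochs`; the warmup length is a free choice.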
!pip install -q medmnist timm torchmetrics scikit-learn matplotlib seaborn tqdm
import warnings; warnings.filterwarnings('ignore')
import os, random, json, time
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import roc_auc_score, balanced_accuracy_score, roc_curve, brier_score_loss
import timm
from medmnist import BreastMNIST  # dataset class used by the exploration cell
…
Device: cuda | IMG_SIZE: 224
Dataset: BreastMNIST | Task: binary-class | Classes: {'0': 'malignant', '1': 'normal, benign'} | n_channels: 1
Train/Val/Test split sizes: {'train': 546, 'val': 78, 'test': 156}
# ── Augmentation transforms ────────────────────────────────────────────────
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def get_transforms(aug_mode: str, img_size: int = IMG_SIZE):
    """Return (train_tf, val_tf) for aug_mode in {none, basic, heavy}."""
    to3ch = T.Grayscale(3)  # replicate the single ultrasound channel for ImageNet backbones
    resize = T.Resize((img_size, img_size))
    totensor = T.ToTensor()
    norm = T.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    val_tf = T.Compose([to3ch, resize, totensor, norm])
    if aug_mode == 'none':
        train_tf = val_tf
    elif aug_mode == 'basic':
        train_tf = T.Compose([
…
Downloading https://zenodo.org/records/10519652/files/breastmnist_224.npz?download=1 to /tmp/medmnist/breastmnist_224.npz
Using downloaded and verified file: /tmp/medmnist/breastmnist_224.npz
Using downloaded and verified file: /tmp/medmnist/breastmnist_224.npz
Batch shape: torch.Size([32, 3, 224, 224]) | Label shape: torch.Size([32, 1])
Train batches: 18 | Val batches: 3 | Test batches: 5
Class distribution in batch:
  0 (malignant): 7
  1 (normal/benign): 25
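The output above shows two details that matter for training and evaluation: labels arrive with shape `[32, 1]`, while `nn.CrossEntropyLoss` expects int64 targets of shape `[N]`, and the batch is skewed (7 malignant vs. 25 normal/benign), which is why the notebook reports balanced accuracy alongside AUC-ROC. A minimal sketch of both steps; `prepare_labels` and `binary_metrics` are hypothetical helpers, not functions from the notebook:

```python
import torch
from sklearn.metrics import balanced_accuracy_score, roc_auc_score


def prepare_labels(labels: torch.Tensor) -> torch.Tensor:
    """MedMNIST yields labels of shape [N, 1]; CrossEntropyLoss wants int64 of shape [N]."""
    return labels.reshape(-1).long()


def binary_metrics(logits: torch.Tensor, labels: torch.Tensor) -> dict:
    """AUC-ROC and balanced accuracy from 2-class logits."""
    probs = torch.softmax(logits.detach().float(), dim=1).cpu().numpy()
    y_true = prepare_labels(labels).cpu().numpy()
    y_pred = probs.argmax(axis=1)
    return {
        # Scoring P(class 1) gives the same AUC as scoring P(class 0) with
        # malignant (label 0) as the positive class, since the two sum to 1.
        'auc': roc_auc_score(y_true, probs[:, 1]),
        # Mean of per-class recall, so the minority class weighs equally.
        'balanced_acc': balanced_accuracy_score(y_true, y_pred),
    }
```

Plain accuracy on a 7:25 batch would already reach about 78% by always predicting "normal/benign"; balanced accuracy would score that degenerate classifier at 50%.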
## Dataset Exploration: BreastMNIST
fig, axes = plt.subplots(3, 8, figsize=(16, 7))
class_names = ['Malignant', 'Normal/Benign']
colors = ['#e74c3c', '#2ecc71']
for row_idx, aug_mode in enumerate(['none', 'basic', 'heavy']):
    train_tf, _ = get_transforms(aug_mode)
    ds = BreastMNIST(split='train', transform=train_tf, download=True, root=DATA_ROOT, size=IMG_SIZE)
    shown = 0
    for img, label in ds:
        if shown >= 8:
            break
        ax = axes[row_idx, shown]
        # Un-normalize for display
        img_np = img.permute(1, 2, 0).numpy()
        img_np = img_np * np.array(IMAGENET_STD) + np.array(IMAGENET_MEAN)
        img_np = np.clip(img_np, 0, 1)
        ax.imshow(img_np[:, :, 0], cmap='gray')
        lbl_idx = int(label.item()) if hasattr(label, 'item') else int(label)
        ax.set_title(class_names[lbl_idx], fontsize=7, color=colors[lbl_idx], fontweight='bold')
…
Using downloaded and verified file: /tmp/medmnist/breastmnist_224.npz
Using downloaded and verified file: /tmp/medmnist/breastmnist_224.npz
Using downloaded and verified file: /tmp/medmnist/breastmnist_224.npz
Dataset visualization saved.
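The "heavy" mode includes Mixup, which blends pairs of images and their targets. Because it needs two samples at once, Mixup is usually applied to a whole batch inside the training loop rather than in a per-image transform, so a per-sample preview grid may not reflect it. The notebook's own Mixup code is not shown in this preview; the following is a minimal sketch assuming a Beta(α, α) mixing coefficient with α = 0.2 (an assumed value) and hard integer labels. `mixup_batch` and `mixup_loss` are hypothetical names. timm also provides `timm.data.Mixup`, which produces soft targets instead.

```python
import numpy as np
import torch
import torch.nn as nn


def mixup_batch(images: torch.Tensor, labels: torch.Tensor, alpha: float = 0.2, rng=None):
    """Blend each sample with a randomly permuted partner; lambda ~ Beta(alpha, alpha)."""
    rng = rng if rng is not None else np.random.default_rng()
    lam = float(rng.beta(alpha, alpha)) if alpha > 0 else 1.0
    perm = torch.from_numpy(rng.permutation(images.size(0)))
    mixed = lam * images + (1.0 - lam) * images[perm]
    # Return both target sets so the loss can be mixed with the same lambda.
    return mixed, labels, labels[perm], lam


def mixup_loss(criterion: nn.Module, logits, labels_a, labels_b, lam: float):
    """Convex combination of the losses against the two target sets."""
    return lam * criterion(logits, labels_a) + (1.0 - lam) * criterion(logits, labels_b)
```

Mixing the two losses with the same λ is equivalent to training on the correspondingly mixed one-hot targets, which lets the label-smoothed criterion from the stack be reused unchanged.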
## Model Factory: Three Backbones
BACKBONE_CONFIGS = {
    'resnet18': dict(pretrained=True, drop_rate=DROPOUT),
    'efficientnet_b0': dict(pretrained=True, drop_rate=DROPOUT),
    'vit_tiny_patch16_224': dict(pretrained=True, drop_rate=DROPOUT),
}
def create_model(backbone: str = 'efficientnet_b0', num_classes: int = 2) -> nn.Module:
    """Build a pretrained timm backbone with a new `num_classes`-way head."""
    cfg = dict(BACKBONE_CONFIGS.get(backbone, dict(pretrained=True, drop_rate=DROPOUT)))
    # ViTs use a fixed-length positional embedding; passing img_size lets timm
    # interpolate the pretrained embedding when IMG_SIZE differs from 224.
    if 'vit' in backbone and IMG_SIZE != 224:
        cfg['img_size'] = IMG_SIZE
    model = timm.create_model(backbone, num_classes=num_classes, **cfg)
    return model.to(DEVICE)
# Print parameter counts
print(f"{'Backbone':<30} {'Params (M)':>12}")
…
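The introduction lists an overfitting-gap analysis among the key outputs, but the preview ends before that cell. One plausible way to summarise it, assuming per-epoch train and validation AUC histories are recorded: report the train-minus-validation gap at the best validation epoch, its maximum, and its final value. `GapReport` and `overfitting_gap` are hypothetical names, not taken from the notebook.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class GapReport:
    best_epoch: int      # epoch with the highest validation AUC (0-indexed)
    gap_at_best: float   # train AUC minus val AUC at that epoch
    max_gap: float       # largest gap over all epochs
    final_gap: float     # gap at the last recorded epoch


def overfitting_gap(train_auc: List[float], val_auc: List[float]) -> GapReport:
    """Summarise the train-minus-validation AUC gap across epochs."""
    if not train_auc or len(train_auc) != len(val_auc):
        raise ValueError('need two non-empty histories of equal length')
    gaps = [t - v for t, v in zip(train_auc, val_auc)]
    # First epoch reaching the best validation AUC (ties go to the earliest).
    best_epoch = max(range(len(val_auc)), key=lambda i: val_auc[i])
    return GapReport(best_epoch, gaps[best_epoch], max(gaps), gaps[-1])
```

A final gap much larger than the gap at the best epoch indicates training continued well into the overfitting regime, which is the situation early stopping is meant to prevent.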