Medical Imaging Fine-Tuning & Ablation Study

By Fouzil Ali · Published June 29, 2026

Systematic ablation study on BreastMNIST comparing augmentation strategies, LR schedules, and backbone architectures (ResNet, EfficientNet, ViT) to find the configuration that overfits least.

  • medical-imaging
  • fine-tuning
  • ablation-study
  • classification
  • regularization
24 cells · 1 experiment · 25 views · 0 forks

Inside this notebook

# Medical Imaging Fine-Tuning: Anti-Overfitting Pipeline + Ablations

### Dataset: BreastMNIST (MedMNIST v2), 780 images (546 train / 78 val / 156 test), binary breast ultrasound classification

---

## 🔬 SOTA Context (2024-2026 Literature)

| Finding | Evidence |
|---|---|
| **DINOv2 ViT-B/14** leads end-to-end fine-tuning on small medical sets | Self-supervised pretraining is key to bridging the domain gap |
| **CNNs are more reliable in very low-resource settings** (<500 images) | ResNet18/EfficientNet-B0 are the preferred baselines |
| **CutMix/CutOut work best for ViTs**, simpler augmentation for CNNs | Augmentation policy should match the architecture |
| **Warmup + cosine decay** is the dominant LR schedule for fine-tuning | Linear warmup stabilises early gradient flow |
| **BreastMNIST ViT-B/16 AUC ≈ 0.94** is the published strong baseline | MedMNIST v2 leaderboard |
| **PEFT / LoRA** is the top anti-overfitting strategy for transformer fine-tuning | Reduces trainable params from 86M → ~1M |

---

## 🧰 Anti-Overfitting Stack Used in This Pipeline

| Technique | Where Applied | Why |
|---|---|---|
| **ImageNet pretraining** | All backbones | Warm start; reduces the amount of data needed |
| **Label smoothing (ε=0.1)** | CrossEntropyLoss | Prevents over-confident predictions |
| **Weight decay (1e-2)** | AdamW optimizer | Decoupled weight decay (L2-style shrinkage) |
| **Dropout (0.3)** | Classification head | Stochastic regularization |
| **Gradient clipping (norm=1.0)** | Every backward pass | Stability on small batches |
| **Early stopping (patience=8)** |…
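
The early-stopping entry in the stack above can be illustrated with a small, framework-free sketch. The preview truncates that table row, so the monitored metric (validation AUC, higher is better), the `min_delta` parameter, and the names `EarlyStopping` and `run_until_stopped` are assumptions made for illustration, not the notebook's actual implementation. Only `patience=8` comes from the table.

```python
class EarlyStopping:
    """Signal a stop once the metric has not improved for `patience` epochs.

    Assumption: the monitored metric is "higher is better" (e.g. val AUC).
    """

    def __init__(self, patience: int = 8, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta      # hypothetical: minimum gain that counts
        self.best = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, metric: float, epoch: int) -> bool:
        """Record one epoch's metric; return True when training should stop."""
        if metric > self.best + self.min_delta:
            self.best = metric
            self.best_epoch = epoch
            self.bad_epochs = 0         # improvement resets the counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def run_until_stopped(val_metrics, patience: int = 8):
    """Replay a list of per-epoch metrics; return (stop_epoch, best_epoch, best)."""
    stopper = EarlyStopping(patience=patience)
    last_epoch = -1
    for epoch, metric in enumerate(val_metrics):
        last_epoch = epoch
        if stopper.step(metric, epoch):
            break
    return last_epoch, stopper.best_epoch, stopper.best


# Made-up validation curve: improves for three epochs, then plateaus.
simulated_val_auc = [0.80, 0.85, 0.88] + [0.87] * 10
stop_epoch, best_epoch, best_auc = run_until_stopped(simulated_val_auc, patience=8)
print(f"stopped at epoch {stop_epoch}, best epoch {best_epoch}, best AUC {best_auc:.2f}")
```

In a real training loop the checkpoint would typically be saved whenever `step` records a new best, so the weights from `best_epoch` can be restored after stopping.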

# Medical Imaging Fine-Tuning: BreastMNIST Ablation Study

This notebook runs a systematic ablation study on the **BreastMNIST** dataset (binary classification: malignant vs. normal/benign ultrasound images) along three axes of variation:

1. **Augmentation strategy**: none / basic / heavy (with Mixup)
2. **LR schedule**: constant / cosine / warmup_cosine
3. **Backbone architecture**: ResNet-18 / EfficientNet-B0 / ViT-Tiny

**Key outputs:** per-experiment AUC-ROC and balanced accuracy, learning curves, overfitting-gap analysis, and a saved checkpoint for the best configuration.
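
To make the `warmup_cosine` option on the LR-schedule axis concrete, here is a minimal sketch of a linear-warmup plus cosine-decay multiplier. The function name `warmup_cosine_factor`, the step counts, the `min_factor` floor, and the base learning rate are all hypothetical; the notebook's own scheduler code is not visible in this preview and may differ.

```python
import math


def warmup_cosine_factor(step: int, warmup_steps: int, total_steps: int,
                         min_factor: float = 0.0) -> float:
    """Multiplier applied to the base LR at a given (0-indexed) step.

    Linear ramp up to 1.0 over `warmup_steps`, then cosine decay from 1.0
    down to `min_factor` at `total_steps`.
    """
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    if step < warmup_steps:
        # Linear warmup: 1/warmup_steps, 2/warmup_steps, ..., 1.0
        return (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1]
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_factor + (1.0 - min_factor) * cosine


# Illustrative values only: 5 warmup steps out of 50, base LR of 1e-3.
base_lr = 1e-3
schedule = [base_lr * warmup_cosine_factor(s, 5, 50) for s in range(50)]
print(f"first LR: {schedule[0]:.2e}, peak LR: {max(schedule):.2e}, last LR: {schedule[-1]:.2e}")
```

A function of this shape can be handed to `torch.optim.lr_scheduler.LambdaLR` as its `lr_lambda` (for example via `functools.partial`), which multiplies the optimizer's base learning rate by the returned factor.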

!pip install -q medmnist timm torchmetrics scikit-learn matplotlib seaborn tqdm
import warnings; warnings.filterwarnings('ignore')
import os, random, json, time
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns

import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import roc_auc_score, balanced_accuracy_score, roc_curve, brier_score_loss

import timm
…
Device: cuda | IMG_SIZE: 224
Dataset: BreastMNIST | Task: binary-class | Classes: {'0': 'malignant', '1': 'normal, benign'} | n_channels: 1
Train/Val/Test split sizes: {'train': 546, 'val': 78, 'test': 156}
# ── Augmentation transforms ────────────────────────────────────────────────
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

def get_transforms(aug_mode: str, img_size: int = IMG_SIZE):
    """Return (train_tf, val_tf) for aug_mode in {none, basic, heavy}."""
    to3ch   = T.Grayscale(3)
    resize  = T.Resize((img_size, img_size))
    totensor= T.ToTensor()
    norm    = T.Normalize(IMAGENET_MEAN, IMAGENET_STD)

    val_tf = T.Compose([to3ch, resize, totensor, norm])

    if aug_mode == 'none':
        train_tf = val_tf

    elif aug_mode == 'basic':
        train_tf = T.Compose([
…
Downloading https://zenodo.org/records/10519652/files/breastmnist_224.npz?download=1 to /tmp/medmnist/breastmnist_224.npz
Using downloaded and verified file: /tmp/medmnist/breastmnist_224.npz
Batch shape: torch.Size([32, 3, 224, 224]) | Label shape: torch.Size([32, 1])
Train batches: 18 | Val batches: 3 | Test batches: 5
Class distribution in batch - 0 (malignant): 7  1 (normal/benign): 25
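
The batch above is imbalanced (7 malignant vs. 25 normal/benign), which is why balanced accuracy is a more informative metric here than plain accuracy. The sketch below re-implements both in plain Python for illustration; the helper names are hypothetical, and the notebook itself imports `balanced_accuracy_score` from scikit-learn, which uses the same mean-of-per-class-recall definition.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the labels."""
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall over the classes present in y_true."""
    recalls = []
    for cls in sorted(set(y_true)):
        idx = [i for i, t in enumerate(y_true) if t == cls]
        hits = sum(1 for i in idx if y_pred[i] == cls)
        recalls.append(hits / len(idx))
    return sum(recalls) / len(recalls)


# Same class counts as the batch shown above: 7 malignant (0), 25 normal/benign (1).
y_true = [0] * 7 + [1] * 25
# A degenerate classifier that always predicts the majority class.
y_pred_majority = [1] * 32

print("accuracy:         ", accuracy(y_true, y_pred_majority))
print("balanced accuracy:", balanced_accuracy(y_true, y_pred_majority))
```

The always-majority classifier scores well on plain accuracy while missing every malignant case; balanced accuracy exposes this by weighting both classes equally.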

## Dataset Exploration: BreastMNIST

fig, axes = plt.subplots(3, 8, figsize=(16, 7))
class_names = ['Malignant', 'Normal/Benign']
colors = ['#e74c3c', '#2ecc71']

for row_idx, aug_mode in enumerate(['none', 'basic', 'heavy']):
    train_tf, _ = get_transforms(aug_mode)
    ds = BreastMNIST(split='train', transform=train_tf, download=True, root=DATA_ROOT, size=IMG_SIZE)
    shown = 0
    for img, label in ds:
        if shown >= 8: break
        ax = axes[row_idx, shown]
        # Un-normalize for display
        img_np = img.permute(1,2,0).numpy()
        img_np = img_np * np.array(IMAGENET_STD) + np.array(IMAGENET_MEAN)
        img_np = np.clip(img_np, 0, 1)
        ax.imshow(img_np[:, :, 0], cmap='gray')
        lbl_idx = int(label.item()) if hasattr(label, 'item') else int(label)
        ax.set_title(class_names[lbl_idx], fontsize=7, color=colors[lbl_idx], fontweight='bold')
…
Using downloaded and verified file: /tmp/medmnist/breastmnist_224.npz
Dataset visualization saved.

## Model Factory: Three Backbones

BACKBONE_CONFIGS = {
    'resnet18':               dict(pretrained=True, drop_rate=DROPOUT),
    'efficientnet_b0':        dict(pretrained=True, drop_rate=DROPOUT),
    'vit_tiny_patch16_224':   dict(pretrained=True, drop_rate=DROPOUT),
}

def create_model(backbone: str = 'efficientnet_b0', num_classes: int = 2) -> nn.Module:
    """Build a pretrained timm backbone with a fresh `num_classes`-way head."""
    # Copy the config so the shared BACKBONE_CONFIGS entry is never mutated
    cfg = dict(BACKBONE_CONFIGS.get(backbone, dict(pretrained=True, drop_rate=DROPOUT)))
    # ViTs assume a fixed input resolution: when IMG_SIZE differs from 224,
    # pass img_size so the pretrained positional embeddings are interpolated.
    # Setting it up front avoids instantiating the model twice.
    if 'vit' in backbone and IMG_SIZE != 224:
        cfg['img_size'] = IMG_SIZE
    model = timm.create_model(backbone, num_classes=num_classes, **cfg)
    return model.to(DEVICE)


# Print parameter counts
print(f"{'Backbone':<30} {'Params (M)':>12}")
…
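
The `img_size` special case in `create_model` exists because a ViT's positional embedding has one entry per token, and the token count depends on the input resolution. The arithmetic can be sketched as follows; `vit_num_tokens` is a hypothetical helper, and it assumes a standard ViT with square inputs, non-overlapping patches, and a single class token.

```python
def vit_num_tokens(img_size: int, patch_size: int = 16, class_token: bool = True) -> int:
    """Sequence length of a standard ViT for a square input.

    Assumes non-overlapping square patches and (optionally) one class token.
    """
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    grid = img_size // patch_size          # patches per side
    return grid * grid + (1 if class_token else 0)


for size in (224, 128, 64):
    print(f"img_size={size:>3}: {vit_num_tokens(size)} tokens")
```

A patch-16 model pretrained at 224 pixels has a 14 x 14 patch grid, so running it at another resolution changes the grid and requires resampling the positional embeddings. CNN backbones such as ResNet-18 and EfficientNet-B0 have no such table, which is why the factory only special-cases ViTs.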
