Tips for Training Better Neural Networks

A practical checklist for training better deep learning models — from diagnosing failures to regularization and 3D medical imaging tips.

Training deep learning models is both an art and a science. Here’s a practical checklist for getting better results — especially relevant for medical imaging tasks like CT reconstruction.


🔬 Step 0: Diagnose Before You Tune

Before touching any hyperparameter, understand why results are bad:

  • Underfitting → model too simple, learning rate too low, not enough epochs
  • Overfitting → training loss drops but validation loss rises → need regularization, more data, or simpler architecture
```mermaid
flowchart LR
    A["📊 Measure\ntrain vs val loss"]:::step --> B{Gap?}
    B -->|"Val >> Train\n(Overfitting)"| C["🛡️ Regularize\nor simplify"]:::fix
    B -->|"Both high\n(Underfitting)"| D["🔧 Bigger model\nor more epochs"]:::fix
    B -->|"Both low\n✅"| E["🎯 Tune further"]:::good

    classDef step fill:#4A90D9,stroke:#2c5f8a,color:#fff
    classDef fix fill:#D97B4A,stroke:#9e5430,color:#fff
    classDef good fill:#5BA85A,stroke:#3a6e39,color:#fff
```
The gap between train and val loss is your most important signal. Don’t touch hyperparameters until you know which regime you’re in.
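The decision above can be sketched as a small helper. This is a minimal illustration, not a rule: the `low` and `gap` thresholds are hypothetical and should be tuned per task and loss scale.

```python
def diagnose(train_loss, val_loss, low=0.1, gap=0.5):
    """Classify the training regime from final train/val losses.

    Thresholds are illustrative placeholders:
      - `gap`: relative train/val gap that counts as overfitting
      - `low`: loss below which we consider the model to be fitting
    """
    if val_loss - train_loss > gap * max(train_loss, 1e-8):
        return "overfitting"    # val >> train: regularize or simplify
    if train_loss > low and val_loss > low:
        return "underfitting"   # both high: bigger model or more epochs
    return "ok"                 # both low: keep tuning
```

Run it on the last few epochs rather than a single noisy epoch, so one bad validation batch doesn't flip the diagnosis.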


📊 Data Preparation

  • Normalize inputs (zero mean, unit std) — especially critical for CT/medical images
  • Augment: flips, rotations, elastic deformations work well for medical imaging
  • Check for data leakage between train/val splits — a common silent killer

Data leakage (patient IDs split across train/val) is the most common silent mistake in medical imaging. Always split by patient, not by slice.
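A patient-level split can be done in a few lines. This is a minimal sketch with a hypothetical `patient_of` mapping (slice id → patient id); in practice you would read patient IDs from your dataset's metadata.

```python
import random

def split_by_patient(slice_ids, patient_of, val_frac=0.2, seed=0):
    """Split slices into train/val so no patient appears in both.

    `patient_of` maps slice id -> patient id (hypothetical schema).
    """
    patients = sorted({patient_of[s] for s in slice_ids})
    rng = random.Random(seed)          # fixed seed for a reproducible split
    rng.shuffle(patients)
    n_val = max(1, int(len(patients) * val_frac))
    val_patients = set(patients[:n_val])
    train = [s for s in slice_ids if patient_of[s] not in val_patients]
    val = [s for s in slice_ids if patient_of[s] in val_patients]
    return train, val
```

Splitting at the patient level (not the slice level) is what prevents nearly identical adjacent slices from leaking into validation.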


⚙️ Architecture & Initialization

  • Use BatchNorm (or InstanceNorm for medical images — often better for small batch sizes)
  • Use ReLU or LeakyReLU in hidden layers; avoid sigmoid/tanh in deep networks (vanishing gradients)
  • Use proper weight initialization — He init for ReLU-based networks

For medical imaging with small batch sizes (1–4), InstanceNorm almost always outperforms BatchNorm. BatchNorm statistics become unreliable at small batch sizes.
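As a PyTorch sketch, here is one way to combine these three points in a single 3D conv block; the block structure and channel counts are illustrative, not a prescribed architecture.

```python
import torch
import torch.nn as nn

def conv_block(in_ch, out_ch):
    """3D conv block for small-batch medical imaging:
    InstanceNorm instead of BatchNorm, LeakyReLU, He (Kaiming) init."""
    conv = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1)
    nn.init.kaiming_normal_(conv.weight, nonlinearity="leaky_relu")  # He init
    nn.init.zeros_(conv.bias)
    return nn.Sequential(conv, nn.InstanceNorm3d(out_ch), nn.LeakyReLU(0.01))

block = conv_block(1, 8)
x = torch.randn(1, 1, 16, 16, 16)  # batch size 1 is fine with InstanceNorm
y = block(x)
```

Because InstanceNorm normalizes each sample independently, this block behaves identically at batch size 1 and batch size 16, which is exactly what you want for 3D volumes.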


📉 Training Dynamics

  • Learning rate is the most important hyperparameter — use a learning rate finder to pick a good starting point
  • Use LR schedulers: CosineAnnealingLR or ReduceLROnPlateau work well in practice
  • Apply gradient clipping if you see exploding gradients (common in RNNs, sometimes in deep CNNs)
  • Start with Adam; if you see overfitting, switch to AdamW (decoupled weight decay, which interacts correctly with adaptive learning rates)
| Optimizer | When to use |
|---|---|
| Adam | Default choice, fast convergence |
| AdamW | When overfitting is a problem (decoupled weight decay) |
| SGD + momentum | Fine-tuning, when you want better generalization |
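Putting these pieces together, a training step might look like the sketch below (the tiny linear model and tensors are stand-ins; the optimizer, scheduler, and clipping calls are the standard PyTorch APIs).

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 1)  # stand-in model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100)

x, y = torch.randn(4, 10), torch.randn(4, 1)
for step in range(3):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    # Clip if you see exploding gradients; max_norm=1.0 is a common default
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    opt.step()
    sched.step()  # cosine decay of the learning rate each step
```

Note the order: `backward()`, then clip, then `opt.step()`, then the scheduler; clipping after the step does nothing.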

🛡️ Regularization

  • L2 / weight decay via AdamW — usually the first thing to try
  • Dropout (0.2–0.5): effective on FC layers; use carefully in conv layers
  • Early stopping based on validation loss — saves time and prevents overfitting
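Early stopping needs only a small amount of bookkeeping. A minimal sketch (the `patience` and `min_delta` values are typical defaults, not recommendations):

```python
class EarlyStopping:
    """Stop when validation loss hasn't improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience, self.min_delta = patience, min_delta
        self.best, self.bad_epochs = float("inf"), 0

    def step(self, val_loss):
        """Call once per epoch; returns True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best, self.bad_epochs = val_loss, 0   # improvement: reset
        else:
            self.bad_epochs += 1                       # no improvement
        return self.bad_epochs >= self.patience
```

Pair this with checkpointing so you can restore the weights from the best epoch, not the last one.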

✅ Sanity Checks (Do These Every Time)

  1. Overfit a single batch first — if the model can’t memorize one batch, the architecture or loss function is broken
  2. Log both training and validation loss every epoch — don’t fly blind
  3. Save checkpoints, not just the final model — you’ll want to roll back

Overfitting a single batch is the most important sanity check. If your model can’t memorize 4 examples, something is fundamentally wrong — check your loss function, forward pass, and data loading.
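The single-batch check fits in a dozen lines. The model and data here are toy placeholders; the point is the loop shape, namely repeatedly fitting one fixed batch and confirming the loss collapses toward zero.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(8, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1)
)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
x, y = torch.randn(4, 8), torch.randn(4, 1)  # one tiny fixed batch

for _ in range(1000):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
# loss should now be near zero; if not, debug the loss, forward pass, or data
```

If the loss plateaus well above zero here, no amount of hyperparameter tuning will save the full run.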


🩻 For 3D Medical Imaging (e.g., CT Reconstruction)

A few extra things that matter:

  • InstanceNorm often outperforms BatchNorm when batch size is small (common with 3D volumes)
  • Multi-component loss (e.g., MSE + perceptual + data-consistency) can help the model learn structure and physics simultaneously:
\[\mathcal{L} = \lambda_1 \|\hat{x} - x\|^2 + \lambda_2 \|A\hat{x} - y\|^2 + \lambda_3 \|\phi(\hat{x}) - \phi(x)\|^2\]
  • Patch-based training — full 3D volumes rarely fit in memory; train on crops (e.g. 64×64×64)
  • Use mixed precision (torch.cuda.amp) to fit larger batches and speed up training
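For the patch-based point above, a random 3D crop is a one-function sketch (the 64³ patch size matches the example in the list; adjust to your GPU memory):

```python
import numpy as np

def random_patch(volume, size=64, rng=None):
    """Random 3D crop for patch-based training.

    Full volumes rarely fit in GPU memory; `volume` is assumed to have
    spatial dims last, e.g. (C, D, H, W).
    """
    rng = rng or np.random.default_rng()
    d, h, w = volume.shape[-3:]
    z = rng.integers(0, d - size + 1)
    y = rng.integers(0, h - size + 1)
    x = rng.integers(0, w - size + 1)
    return volume[..., z:z + size, y:y + size, x:x + size]
```

At inference time the usual counterpart is sliding-window prediction with overlapping patches, averaged in the overlap regions.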

The biggest mistake is tuning hyperparameters before fixing data issues or a broken training loop. Get the basics right first, then optimize.


Part of my deep learning engineering notes. Next: learning rate schedules and when to use each one.

This post is licensed under CC BY 4.0 by the author.