3D U-Net: Key Concepts in 5 Minutes

A concise visual reference for 3D U-Net — architecture, skip connections, loss functions, and how it applies to sparse-view CT reconstruction.



🧬 What Is 3D U-Net?

3D U-Net is a fully convolutional encoder-decoder network that operates directly on volumetric data (H × W × D). It extends the original 2D U-Net (Ronneberger et al., 2015) by replacing every 2D operation with its 3D counterpart.

Why 3D? Medical volumes like CT and MRI are inherently 3D. Processing slice-by-slice (2D) ignores inter-slice continuity. 3D U-Net captures spatial context across all three axes simultaneously.
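As a minimal sketch of what "replacing every 2D operation with its 3D counterpart" means in practice (assuming PyTorch here, since the post's Conv3D naming matches `nn.Conv3d`), the swap only changes the expected tensor rank:

```python
import torch
import torch.nn as nn

# A 2D conv sees one slice at a time; its 3D counterpart sees a sub-volume.
conv2d = nn.Conv2d(1, 8, kernel_size=3, padding=1)
conv3d = nn.Conv3d(1, 8, kernel_size=3, padding=1)

slice_2d = torch.randn(1, 1, 128, 128)       # (N, C, H, W) — a single slice
volume_3d = torch.randn(1, 1, 32, 128, 128)  # (N, C, D, H, W) — a 32-slice stack

out2d = conv2d(slice_2d)   # same H × W, no inter-slice context
out3d = conv3d(volume_3d)  # kernel spans depth too: context across all three axes
```

Every layer in the network follows the same pattern: `Conv2d → Conv3d`, `MaxPool2d → MaxPool3d`, `BatchNorm2d → BatchNorm3d`.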


🏗️ Architecture

The network has three parts: an encoder that compresses the volume, a bottleneck that captures high-level semantics, and a decoder that reconstructs the volume at full resolution. Skip connections bridge the encoder and decoder at each resolution level.

flowchart TB
    Input(["🔵 Input Volume\nH × W × D"]):::io --> E1

    subgraph ENC["⬇️  Encoder  — Contracting Path"]
        direction TB
        E1["Conv3D × 2  |  64 ch"]:::enc -->|"MaxPool3D ÷2"| E2
        E2["Conv3D × 2  |  128 ch"]:::enc -->|"MaxPool3D ÷2"| E3
        E3["Conv3D × 2  |  256 ch"]:::enc -->|"MaxPool3D ÷2"| BN
    end

    BN(["🔶 Bottleneck  |  512 ch"]):::bn

    subgraph DEC["⬆️  Decoder  — Expanding Path"]
        direction TB
        BN -->|"ConvTranspose3D ×2"| D1
        D1["Skip Concat → Conv3D × 2  |  256 ch"]:::dec -->|"ConvTranspose3D ×2"| D2
        D2["Skip Concat → Conv3D × 2  |  128 ch"]:::dec -->|"ConvTranspose3D ×2"| D3
        D3["Skip Concat → Conv3D × 2  |  64 ch"]:::dec
    end

    D3 --> Output(["🟢 Output Volume\nH × W × D"]):::io

    E3 -. "skip" .-> D1
    E2 -. "skip" .-> D2
    E1 -. "skip" .-> D3

    classDef io fill:#4A90D9,stroke:#2c5f8a,color:#fff,rx:8
    classDef enc fill:#5BA85A,stroke:#3a6e39,color:#fff
    classDef dec fill:#D97B4A,stroke:#9e5430,color:#fff
    classDef bn fill:#9B6EBD,stroke:#6b4785,color:#fff

Skip connections copy feature maps directly from each encoder level into the corresponding decoder level. This lets the decoder recover fine spatial details that would otherwise be lost during downsampling.
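The whole diagram above can be condensed into a short PyTorch sketch (my own minimal version, not a reference implementation — channel widths and depth follow the diagram, and the skip connections are the `torch.cat` calls):

```python
import torch
import torch.nn as nn

def double_conv(c_in, c_out):
    # Two Conv3D → BatchNorm → ReLU stages — one block per level of the diagram.
    return nn.Sequential(
        nn.Conv3d(c_in, c_out, 3, padding=1), nn.BatchNorm3d(c_out), nn.ReLU(inplace=True),
        nn.Conv3d(c_out, c_out, 3, padding=1), nn.BatchNorm3d(c_out), nn.ReLU(inplace=True),
    )

class UNet3D(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, ch=(64, 128, 256, 512)):
        super().__init__()
        self.e1 = double_conv(in_ch, ch[0])
        self.e2 = double_conv(ch[0], ch[1])
        self.e3 = double_conv(ch[1], ch[2])
        self.pool = nn.MaxPool3d(2)
        self.bottleneck = double_conv(ch[2], ch[3])
        # Decoder: upsample, then double channels via skip concat, then convolve back down.
        self.up1, self.d1 = nn.ConvTranspose3d(ch[3], ch[2], 2, stride=2), double_conv(2 * ch[2], ch[2])
        self.up2, self.d2 = nn.ConvTranspose3d(ch[2], ch[1], 2, stride=2), double_conv(2 * ch[1], ch[1])
        self.up3, self.d3 = nn.ConvTranspose3d(ch[1], ch[0], 2, stride=2), double_conv(2 * ch[0], ch[0])
        self.head = nn.Conv3d(ch[0], out_ch, 1)

    def forward(self, x):
        s1 = self.e1(x)               # full resolution
        s2 = self.e2(self.pool(s1))   # 1/2 resolution
        s3 = self.e3(self.pool(s2))   # 1/4 resolution
        b = self.bottleneck(self.pool(s3))
        x = self.d1(torch.cat([self.up1(b), s3], dim=1))  # skip concat on channel axis
        x = self.d2(torch.cat([self.up2(x), s2], dim=1))
        x = self.d3(torch.cat([self.up3(x), s1], dim=1))
        return self.head(x)

# Smoke test on a small cube (channels reduced so it runs anywhere):
net = UNet3D(ch=(8, 16, 32, 64)).eval()
with torch.no_grad():
    out = net(torch.randn(1, 1, 16, 16, 16))  # (N, C, D, H, W) in → same shape out
```

Note the fully convolutional design: the output shape equals the input shape, so the same weights work on any patch size divisible by 8 (three ÷2 poolings).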


🔑 Core Building Blocks

1 · Conv3D Block (repeated at every level)

Input
  │
  ▼
Conv3D (3×3×3) → BatchNorm → ReLU
  │
  ▼
Conv3D (3×3×3) → BatchNorm → ReLU
  │
  ▼
Output

Each block applies two convolutions with a 3×3×3 kernel — capturing local 3D context in all directions.
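A quick sanity check on why two convolutions are stacked per block: each stride-1 k×k×k conv grows the receptive field by (k − 1), so two 3×3×3 convs let every output voxel see a 5×5×5 neighbourhood. A tiny arithmetic sketch:

```python
def receptive_field(num_convs, k=3):
    # Receptive field of stacked stride-1 convs: starts at 1, grows by (k-1) per conv.
    rf = 1
    for _ in range(num_convs):
        rf += k - 1
    return rf

rf_per_block = receptive_field(2)  # two 3×3×3 convs → 5×5×5 per output voxel
```

Pooling then doubles this effective footprint relative to the input at every deeper level.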

2 · MaxPool3D (Encoder Downsampling)

  • Kernel: 2×2×2, Stride: 2
  • Halves spatial resolution in all three dimensions
  • Side effect: doubles receptive field at next level

3 · ConvTranspose3D (Decoder Upsampling)

  • Kernel: 2×2×2, Stride: 2
  • Doubles spatial resolution — the learnable inverse of pooling
  • Alternatively: Upsample(mode='trilinear') + Conv3D for smoother outputs
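Both down- and upsampling halve or double all three spatial axes; a shape check in PyTorch (the `Upsample` + `Conv3d` variant is the smoother alternative mentioned above):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 16, 32, 32, 32)               # (N, C, D, H, W)

down = nn.MaxPool3d(kernel_size=2, stride=2)(x)  # halves D, H, W

# Learnable upsampling: the transposed conv inverts the pooling stride.
up_learned = nn.ConvTranspose3d(16, 16, kernel_size=2, stride=2)(down)

# Smoother alternative: fixed trilinear interpolation followed by a conv.
up_smooth = nn.Conv3d(16, 16, 3, padding=1)(
    nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)(down)
)
```

Both variants restore the original spatial shape; the trade-off is checkerboard artifacts (transposed conv) versus extra compute (interpolate + conv).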

⚖️ 2D vs 3D Comparison

Feature               2D U-Net      3D U-Net
Input shape           H × W         H × W × D
Conv kernel           3×3           3×3×3
Inter-slice context   ❌ None       ✅ Full volume
GPU memory            Low           High ⚠️
Best for              2D images     CT / MRI volumes

Training 3D U-Net on full volumes is memory-intensive. In practice, use patch-based training (e.g. 64×64×64 or 96×96×96 patches) to fit within GPU memory.
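A patch sampler can be as simple as a random crop (a minimal sketch — real pipelines usually also balance foreground/background patches):

```python
import torch

def random_patch(volume, patch=(64, 64, 64)):
    # volume: (C, D, H, W) — crop a random sub-cube that fits in GPU memory.
    _, D, H, W = volume.shape
    d = torch.randint(0, D - patch[0] + 1, (1,)).item()
    h = torch.randint(0, H - patch[1] + 1, (1,)).item()
    w = torch.randint(0, W - patch[2] + 1, (1,)).item()
    return volume[:, d:d + patch[0], h:h + patch[1], w:w + patch[2]]

vol = torch.randn(1, 128, 256, 256)  # a full CT volume, (C, D, H, W)
p = random_patch(vol)                # one 64³ training patch
```

At inference time, the full volume is tiled with (usually overlapping) patches and the predictions are stitched back together.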


📉 Loss Functions

For reconstruction tasks (e.g. sparse-view CT), a composite loss works best:

\[\mathcal{L} = \lambda_1 \underbrace{\|\hat{x} - x\|^2}_{\text{MSE}} + \lambda_2 \underbrace{\|A\hat{x} - y\|^2}_{\text{Data Consistency}} + \lambda_3 \underbrace{\|\phi(\hat{x}) - \phi(x)\|^2}_{\text{Perceptual}}\]
Term               What It Enforces
MSE                Pixel-wise fidelity to ground truth
Data Consistency   Output must match the measured sinogram projections (physics)
Perceptual         High-level feature similarity via a pretrained VGG network

Where $A$ is the Radon forward projection operator, $y$ is the measured sparse projections, and $\phi(\cdot)$ extracts VGG feature maps.

Tuning the $\lambda$ weights is where the real engineering happens. A common starting point: $\lambda_1 = 1.0$, $\lambda_2 = 0.1$, $\lambda_3 = 0.01$ — then sweep from there.
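The composite loss is straightforward to wire up. In this sketch the Radon operator $A$ and the VGG extractor $\phi$ are toy stand-ins (a sum along one axis and an average pool) so the example is self-contained — swap in a real projector and pretrained features in practice:

```python
import torch
import torch.nn.functional as F

def composite_loss(x_hat, x, y, forward_op, feat, lambdas=(1.0, 0.1, 0.01)):
    # forward_op: stand-in for the Radon projector A; feat: stand-in for VGG features φ.
    l_mse  = F.mse_loss(x_hat, x)                   # pixel-wise fidelity
    l_dc   = F.mse_loss(forward_op(x_hat), y)       # physics / data consistency
    l_perc = F.mse_loss(feat(x_hat), feat(x))       # high-level feature similarity
    l1, l2, l3 = lambdas
    return l1 * l_mse + l2 * l_dc + l3 * l_perc

# Hypothetical stand-ins, NOT a real Radon transform or VGG:
A = lambda v: v.sum(dim=-1)           # projection along one axis
phi = lambda v: F.avg_pool3d(v, 2)    # crude "feature" extractor

x     = torch.rand(1, 1, 8, 8, 8)     # ground-truth volume
x_hat = torch.rand(1, 1, 8, 8, 8)     # network output
y     = A(x)                          # "measured" projections
loss  = composite_loss(x_hat, x, y, A, phi)
```

A perfect reconstruction drives all three terms to zero simultaneously, which is exactly the property the λ-sweep is balancing.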


🩻 Application: Sparse-View CT Reconstruction

In my FYP, 3D U-Net is used as an image-to-image regression network, not a segmentation model:

flowchart LR
    S(["Sparse-View\nSinogram"]):::input
    F["FBP\n(Filtered Backprojection)"]:::proc
    U["3D U-Net\n(Artifact Removal)"]:::model
    DC["Data-Consistency\nLayer"]:::physics
    O(["Clean CT\nVolume"]):::output

    S --> F --> U --> DC --> O

    classDef input  fill:#4A90D9,stroke:#2c5f8a,color:#fff
    classDef proc   fill:#888,stroke:#555,color:#fff
    classDef model  fill:#5BA85A,stroke:#3a6e39,color:#fff
    classDef physics fill:#D97B4A,stroke:#9e5430,color:#fff
    classDef output fill:#9B6EBD,stroke:#6b4785,color:#fff
  1. FBP quickly reconstructs a noisy, artifact-laden volume from sparse projections
  2. 3D U-Net learns to suppress streak artifacts and recover anatomical detail
  3. Data-Consistency Layer enforces that the output is physically consistent with the measured projections — the network can’t hallucinate tissue that the scanner didn’t see

This residual approach (predict a correction on top of the FBP baseline rather than the raw image) typically converges faster and produces sharper results than learning the full mapping from scratch.
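The residual idea fits in a few lines (a sketch of the design choice, with a zero-initialized stand-in network so the identity property is visible):

```python
import torch
import torch.nn as nn

class ResidualWrapper(nn.Module):
    # The wrapped network predicts a *correction* to the FBP baseline, not the image itself.
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, fbp_volume):
        return fbp_volume + self.net(fbp_volume)  # output = baseline + learned correction

# Stand-in for the 3D U-Net, zero-initialized so the wrapper starts as the identity:
net = nn.Conv3d(1, 1, 3, padding=1)
nn.init.zeros_(net.weight)
nn.init.zeros_(net.bias)

model = ResidualWrapper(net)
fbp = torch.randn(1, 1, 16, 16, 16)
out = model(fbp)  # equals fbp at init — training only has to learn the artifacts
```

Because the model starts from the FBP baseline instead of from noise, gradients only need to describe the (sparse, structured) streak artifacts, which is why convergence is faster.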


▶️ Video Resources

U-Net Explained — Computerphile (great 2D foundation, 10 min)

3D U-Net Paper Walkthrough — architecture intuition




💡 One-Sentence Intuition

The encoder compresses the volume to extract abstract features; the decoder reconstructs it to restore spatial detail; skip connections bridge the two so nothing important is lost along the way.


Part of my FYP notes series on deep learning for medical imaging. Next up: physics-guided constraints and the data-consistency module.

This post is licensed under CC BY 4.0 by the author.