3D U-Net: Key Concepts in 5 Minutes
A concise visual reference for 3D U-Net — architecture, skip connections, loss functions, and how it applies to sparse-view CT reconstruction.
🧬 What Is 3D U-Net?
3D U-Net is a fully convolutional encoder-decoder network that operates directly on volumetric data (H × W × D). It extends the original 2D U-Net (Ronneberger et al., 2015) by replacing every 2D operation with its 3D counterpart.
Why 3D? Medical volumes like CT and MRI are inherently 3D. Processing slice-by-slice (2D) ignores inter-slice continuity. 3D U-Net captures spatial context across all three axes simultaneously.
🏗️ Architecture
The network has three parts: an encoder that compresses the volume, a bottleneck that captures high-level semantics, and a decoder that reconstructs the volume at full resolution. Skip connections bridge the encoder and decoder at each resolution level.
flowchart TB
Input(["🔵 Input Volume\nH × W × D"]):::io --> E1
subgraph ENC["⬇️ Encoder — Contracting Path"]
direction TB
E1["Conv3D × 2 | 64 ch"]:::enc -->|"MaxPool3D ÷2"| E2
E2["Conv3D × 2 | 128 ch"]:::enc -->|"MaxPool3D ÷2"| E3
E3["Conv3D × 2 | 256 ch"]:::enc -->|"MaxPool3D ÷2"| BN
end
BN(["🔶 Bottleneck | 512 ch"]):::bn
subgraph DEC["⬆️ Decoder — Expanding Path"]
direction TB
BN -->|"ConvTranspose3D ×2"| D1
D1["Skip Concat → Conv3D × 2 | 256 ch"]:::dec -->|"ConvTranspose3D ×2"| D2
D2["Skip Concat → Conv3D × 2 | 128 ch"]:::dec -->|"ConvTranspose3D ×2"| D3
D3["Skip Concat → Conv3D × 2 | 64 ch"]:::dec
end
D3 --> Output(["🟢 Output Volume\nH × W × D"]):::io
E3 -. "skip" .-> D1
E2 -. "skip" .-> D2
E1 -. "skip" .-> D3
classDef io fill:#4A90D9,stroke:#2c5f8a,color:#fff,rx:8
classDef enc fill:#5BA85A,stroke:#3a6e39,color:#fff
classDef dec fill:#D97B4A,stroke:#9e5430,color:#fff
classDef bn fill:#9B6EBD,stroke:#6b4785,color:#fff
Skip connections copy feature maps directly from each encoder level into the corresponding decoder level. This lets the decoder recover fine spatial details that would otherwise be lost during downsampling.
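The channel arithmetic behind a skip connection can be seen in a few lines of PyTorch. A minimal sketch with hypothetical shapes (the 256-channel level from the diagram above): the upsampled decoder map and the saved encoder map are concatenated along the channel axis, doubling the channel count before the next Conv3D pair reduces it again.

```python
# Hypothetical shapes illustrating one skip connection at the 256-channel level.
import torch

encoder_feat = torch.randn(1, 256, 16, 16, 16)   # saved from encoder level E3
decoder_feat = torch.randn(1, 256, 16, 16, 16)   # output of ConvTranspose3D

# Concatenate along the channel axis (dim=1): 256 + 256 = 512 channels
merged = torch.cat([decoder_feat, encoder_feat], dim=1)
print(merged.shape)  # torch.Size([1, 512, 16, 16, 16]) → then Conv3D × 2 → 256 ch
```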
🔑 Core Building Blocks
1 · Conv3D Block (repeated at every level)
Input
  │
  ▼
Conv3D (3×3×3) → BatchNorm → ReLU
  │
  ▼
Conv3D (3×3×3) → BatchNorm → ReLU
  │
  ▼
Output
Each block applies two convolutions with a 3×3×3 kernel — capturing local 3D context in all directions.
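A sketch of this block in PyTorch (names are illustrative, not from a specific codebase): `padding=1` with a 3×3×3 kernel keeps H × W × D unchanged, so only the pooling layers change resolution.

```python
# Conv3D block sketch: two (Conv3D → BatchNorm → ReLU) stages, as in the
# diagram above. padding=1 preserves spatial size for a 3×3×3 kernel.
import torch.nn as nn

def conv3d_block(in_ch, out_ch):
    return nn.Sequential(
        nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_ch),
        nn.ReLU(inplace=True),
    )
```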
2 · MaxPool3D (Encoder Downsampling)
- Kernel: 2×2×2, Stride: 2
- Halves spatial resolution in all three dimensions
- Side effect: doubles receptive field at next level
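The halving is easy to verify with dummy shapes:

```python
# MaxPool3D with kernel 2, stride 2 halves every spatial dimension.
import torch
import torch.nn as nn

pool = nn.MaxPool3d(kernel_size=2, stride=2)
x = torch.randn(1, 64, 32, 32, 32)
print(pool(x).shape)  # torch.Size([1, 64, 16, 16, 16])
```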
3 · ConvTranspose3D (Decoder Upsampling)
- Kernel: 2×2×2, Stride: 2
- Doubles spatial resolution — the learnable inverse of pooling
- Alternatively: `Upsample(mode='trilinear')` + `Conv3D` for smoother outputs (avoids the checkerboard artifacts transposed convolutions can produce)
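A sketch of the trilinear-upsample alternative (channel counts are illustrative, matching the 256 → 128 decoder step): interpolation doubles the resolution, then a regular Conv3D refines the result.

```python
# Upsample + Conv3D alternative to ConvTranspose3D (decoder step 256 → 128 ch).
import torch
import torch.nn as nn

up = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False),
    nn.Conv3d(256, 128, kernel_size=3, padding=1),
)
x = torch.randn(1, 256, 8, 8, 8)
print(up(x).shape)  # torch.Size([1, 128, 16, 16, 16])
```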
⚖️ 2D vs 3D Comparison
| Feature | 2D U-Net | 3D U-Net |
|---|---|---|
| Input shape | H × W | H × W × D |
| Conv kernel | 3×3 | 3×3×3 |
| Inter-slice context | ❌ None | ✅ Full volume |
| GPU memory | Low | High ⚠️ |
| Best for | 2D images | CT / MRI volumes |
Training 3D U-Net on full volumes is memory-intensive. In practice, use patch-based training (e.g. 64×64×64 or 96×96×96 patches) to fit within GPU memory.
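A minimal patch-sampling sketch in NumPy (function and argument names are illustrative, not from a specific library):

```python
# Patch-based training sketch: crop one random size³ sub-volume per sample
# so the network never has to see the full (memory-hungry) volume at once.
import numpy as np

def random_patch(volume, size=64, rng=None):
    """Crop a random size×size×size patch from a (H, W, D) volume."""
    rng = rng or np.random.default_rng()
    h, w, d = (int(rng.integers(0, s - size + 1)) for s in volume.shape)
    return volume[h:h + size, w:w + size, d:d + size]
```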
📉 Loss Functions
For reconstruction tasks (e.g. sparse-view CT), a composite loss works best:
\[\mathcal{L} = \lambda_1 \underbrace{\|{\hat{x} - x}\|^2}_{\text{MSE}} + \lambda_2 \underbrace{\|A\hat{x} - y\|^2}_{\text{Data Consistency}} + \lambda_3 \underbrace{\|\phi(\hat{x}) - \phi(x)\|^2}_{\text{Perceptual}}\]

| Term | What It Enforces |
|---|---|
| MSE | Pixel-wise fidelity to ground truth |
| Data Consistency | Output must match the measured sinogram projections (physics) |
| Perceptual | High-level feature similarity via a pretrained VGG network |
Where $A$ is the Radon forward projection operator, $y$ is the measured sparse projections, and $\phi(\cdot)$ extracts VGG feature maps.
Tuning the $\lambda$ weights is where the real engineering happens. A common starting point: $\lambda_1 = 1.0$, $\lambda_2 = 0.1$, $\lambda_3 = 0.01$ — then sweep from there.
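The composite loss is a weighted sum, so a sketch is short. `forward_project` (the operator $A$) and `vgg_features` ($\phi$) are placeholders here for a real differentiable Radon projector and a pretrained VGG feature extractor, both assumptions for illustration:

```python
# Composite loss sketch with the starting weights suggested above.
# `forward_project` and `vgg_features` are placeholder callables.
import torch
import torch.nn.functional as F

def composite_loss(x_hat, x, y, forward_project, vgg_features,
                   lam1=1.0, lam2=0.1, lam3=0.01):
    mse = F.mse_loss(x_hat, x)                               # pixel fidelity
    dc = F.mse_loss(forward_project(x_hat), y)               # physics consistency
    perc = F.mse_loss(vgg_features(x_hat), vgg_features(x))  # feature similarity
    return lam1 * mse + lam2 * dc + lam3 * perc
```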
🩻 Application: Sparse-View CT Reconstruction
In this FYP context, 3D U-Net is used as an image-to-image regression network, not a segmentation model:
flowchart LR
S(["Sparse-View\nSinogram"]):::input
F["FBP\n(Filtered Back Projection)"]:::proc
U["3D U-Net\n(Artifact Removal)"]:::model
DC["Data-Consistency\nLayer"]:::physics
O(["Clean CT\nVolume"]):::output
S --> F --> U --> DC --> O
classDef input fill:#4A90D9,stroke:#2c5f8a,color:#fff
classDef proc fill:#888,stroke:#555,color:#fff
classDef model fill:#5BA85A,stroke:#3a6e39,color:#fff
classDef physics fill:#D97B4A,stroke:#9e5430,color:#fff
classDef output fill:#9B6EBD,stroke:#6b4785,color:#fff
- FBP quickly reconstructs a noisy, artifact-laden volume from sparse projections
- 3D U-Net learns to suppress streak artifacts and recover anatomical detail
- Data-Consistency Layer enforces that the output is physically consistent with the measured projections — the network can’t hallucinate tissue that the scanner didn’t see
This residual approach (predict correction on top of FBP baseline, not raw image) typically converges faster and produces sharper results than learning the full mapping from scratch.
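The residual formulation fits in one line (a sketch; `net` stands in for any 3D U-Net):

```python
# Residual formulation sketch: the network predicts a correction to the FBP
# baseline rather than regressing the clean volume from scratch.
def residual_forward(net, fbp_volume):
    # output = noisy FBP baseline + learned artifact correction
    return fbp_volume + net(fbp_volume)
```

Because the baseline already carries most of the anatomy, the network only has to learn the (smaller-magnitude) artifact signal.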
▶️ Video Resources
U-Net Explained — Computerphile (great 2D foundation, 10 min)
3D U-Net Paper Walkthrough — architecture intuition
📄 Further Reading
- Original paper: Çiçek et al. 2016 — 3D U-Net (arXiv:1606.06650) — read Abstract + Fig. 1 for a 5-min pass
- 2D U-Net foundation: Ronneberger et al. 2015 (arXiv:1505.04597)
- CT Reconstruction review: Willemink & Noël 2019, European Radiology
💡 One-Sentence Intuition
The encoder compresses the volume to extract abstract features; the decoder reconstructs it to restore spatial detail; skip connections bridge the two so nothing important is lost along the way.
Part of my FYP notes series on deep learning for medical imaging. Next up: physics-guided constraints and the data-consistency module.