Deep Learning for 3D CT Reconstruction: A Technical Guide

A comprehensive technical guide to 3D U-Net-based CT reconstruction, covering sinogram-to-image domain transfer, patch-based training strategies, and a full PyTorch implementation.

A comprehensive guide to 3D U-Net-based CT reconstruction for newcomers to medical image processing. We cover core concepts, domain transfer strategies, the memory/data-consistency challenges, a complete PyTorch framework, and the clinical relevance of sparse-view CT reconstruction.


🩻 1. Introduction: A New Computational Paradigm for Medical Imaging

Computed Tomography (CT) has been a cornerstone of modern medical diagnostics since Godfrey Hounsfield's first clinical CT scanner in the early 1970s. The technology has evolved through several generations, from translate-rotate systems to multi-slice spiral CT and dual-energy CT.

Traditional CT reconstruction, exemplified by Filtered Back Projection (FBP), has been the industry standard for decades. FBP is derived from rigorous mathematics (the Fourier Slice Theorem), is computationally fast, and produces predictable results. However, FBP is highly sensitive to noise and assumes complete, ideal projection data. In low-dose CT (LDCT), in sparse-view CT, or in the presence of metal implants, FBP reconstructions suffer from severe noise and streak artifacts.

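In recent years, deep learning (DL), particularly convolutional neural networks (CNNs), has brought a third paradigm shift to CT reconstruction, after analytical methods such as FBP and iterative reconstruction. Instead of relying only on a hand-derived inversion formula, deep learning methods learn a non-linear mapping from large training datasets, either from the sinogram directly or, more commonly, from a preliminary FBP image to a clean image. Once trained, they can produce high-quality reconstructions in a fraction of the time that iterative reconstruction needs.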

Figure: Evolution of CT reconstruction, FBP (fast, but artifacts at low dose) → Iterative Reconstruction (better, slow) → Deep Learning (fast, best quality).

📡 2. Core Concepts: Bridging the Domain Gap

Before diving into deep learning architectures, it is essential to build an intuitive understanding of CT imaging physics and data formats. The core task of CT reconstruction, inverting the Radon transform, is not merely a denoising problem; it is a physical inverse problem involving a geometric transformation between two domains.

2.1 The Intuitive Duality of Sinograms and CT Images

The Shadow-and-Object Analogy

  • CT image (Object): A complex semi-transparent object (e.g., a cross-section of the human chest). We want to know the density (attenuation coefficient) at every interior point. This is the final result in the “image domain,” typically represented as $f(x, y)$ or $f(x, y, z)$.

  • Projection: You shine a powerful flashlight from one side of the object; a shadow appears on the opposite wall. This shadow records the total amount of X-ray absorbed as the beam passed through — the line integral of attenuation.

  • Sinogram: You walk around the object with the flashlight, recording the shadow at each angular step. Stacking all these shadow profiles sequentially into a 2D image produces the sinogram.

    • Horizontal axis: Detector channel position (s)
    • Vertical axis: Rotation angle (θ)

Why “Sinogram”?

Consider an infinitesimally small, strongly attenuating point at position $(x, y)$, offset from the rotation center. Plotting the detector position $s$ of its shadow against the rotation angle $\theta$ traces a perfect sinusoidal curve in $(s, \theta)$ coordinates:

\[s = x\cos\theta + y\sin\theta\]

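Every point (pixel, or voxel in 3D) of the CT image maps to its own sinusoidal curve in the sinogram, and conversely every point of the sinogram (one ray at one angle) corresponds to a straight line through the image. This point-to-curve / point-to-line duality is the key to understanding why CT reconstruction is hard: the information about any single pixel is spread along a curve that spans the entire angular range of the sinogram.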
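
The sinusoid can be checked numerically. The sketch below assumes scikit-image is installed; it places a single bright pixel in an otherwise empty 2D image, computes its sinogram with `skimage.transform.radon`, and fits $s(\theta) = a\cos\theta + b\sin\theta + c$ to the detector bin where the point's shadow peaks at each angle. The grid size and the point position are arbitrary choices for illustration.

```python
import numpy as np
from skimage.transform import radon

# Toy 2D "object": one bright pixel in an otherwise empty 128x128 image
n = 128
image = np.zeros((n, n))
row, col = 40, 90                      # arbitrary point, offset from the center (64, 64)
image[row, col] = 1.0

theta = np.arange(0.0, 180.0, 1.0)     # projection angles in degrees
sinogram = radon(image, theta=theta)   # shape: (detector bins, angles)

# Detector bin where the point's "shadow" lands at each angle
s_peak = sinogram.argmax(axis=0).astype(float)

# Least-squares fit of s(theta) = a*cos(theta) + b*sin(theta) + c
t = np.deg2rad(theta)
A = np.stack([np.cos(t), np.sin(t), np.ones_like(t)], axis=1)
coef, *_ = np.linalg.lstsq(A, s_peak, rcond=None)
a, b, c = coef
residual = s_peak - A @ coef

amplitude = np.hypot(a, b)             # should match the point's distance from the center
distance = np.hypot(row - n // 2, col - n // 2)
print(f"amplitude {amplitude:.1f} px vs. distance {distance:.1f} px, "
      f"max fit residual {np.abs(residual).max():.2f} bins")
```

The residual stays within about one detector bin (the argmax quantizes the peak position), and the amplitude of the fitted sinusoid matches the point's distance from the rotation center, as $s = x\cos\theta + y\sin\theta$ predicts.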

2.2 The Deep Learning Bottleneck: Non-Locality

In standard computer vision tasks, U-Net and CNN architectures excel because the problems involve local features. However, training a standard CNN to map sinograms directly to CT images typically fails to converge or produces severe artifacts.

The fundamental reason is non-locality:

  1. Local receptive fields of convolution: A CNN’s core operation is convolution, where a small kernel (e.g., 3×3) extracts local features.
  2. Global dependency of the Radon transform: Recovering a single pixel in the reconstructed image requires information from an entire sinusoidal curve spanning the full sinogram.
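  3. Slow receptive-field growth: Each stride-1 3×3 convolution widens the receptive field by only 2 pixels, so spanning a 512-bin detector row takes about 256 layers. Downsampling enlarges the receptive field faster, but the network would still have to learn the full, position-dependent sinusoidal geometry from data.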
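
Point 3 can be made concrete with the standard receptive-field recurrence, in which each layer with kernel size $k$ widens the field by $(k-1)$ times the product of all earlier strides. The sketch below counts how many stride-1 3×3 convolutions are needed to span a 512-bin detector row, and computes the per-axis receptive field at the bottleneck of the encoder used in Section 4.2 (two 3×3×3 convolutions per level, three 2×2×2 max-pools). The helper names are illustrative.

```python
import math

def receptive_field(layers):
    """Receptive field (in input pixels, per axis) of a stack of (kernel, stride) layers."""
    rf, jump = 1, 1
    for kernel, stride in layers:
        rf += (kernel - 1) * jump   # each layer widens the field by (k-1) input-pixel steps
        jump *= stride              # striding/pooling makes later steps larger
    return rf

def plain_layers_needed(width, kernel=3):
    """Stride-1 conv layers needed before one output pixel can see `width` input pixels."""
    return math.ceil((width - 1) / (kernel - 1))

conv, pool = (3, 1), (2, 2)
encoder = [conv, conv] + [pool, conv, conv] * 3   # encoder of the 3D U-Net in Section 4.2

print(plain_layers_needed(512))     # 256 layers of 3x3 convs to span 512 detector bins
print(receptive_field(encoder))     # 68 voxels per axis at the U-Net bottleneck
```

Even with pooling, the encoder's bottleneck sees only 68 voxels per axis, far less than the extent of one sinusoid across a full sinogram.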

2.3 The Solution: Domain Transform Strategies

The core idea is divide and conquer: use traditional algorithms for the deterministic geometric transformation, and use deep learning for complex statistical processing and image quality optimization.

Core Strategy: FBP-ConvNet (Hybrid Reconstruction)

This is one of the most widely used, robust, and beginner-friendly strategies, formally proposed by Jin et al. in 2017 in IEEE Transactions on Image Processing.

Key insight: Do not ask the neural network to learn geometric rotation and back-projection. FBP already handles the “sinogram → image” geometric transformation exactly for a known acquisition geometry, so leverage it and let the network correct what FBP cannot.

Workflow:

Figure: FBP-ConvNet pipeline, Sinogram (noisy/sparse) → FBP (domain transform) → 3D U-Net (artifact removal) → Clean CT volume.

  1. Input (Sinogram): The raw sinogram, possibly containing noise, artifacts, or undersampling
  2. Domain transform (FBP): Converts the sinogram to the image domain. Because of input deficiencies (e.g., sparse views), the FBP image may contain severe streak artifacts, but the positions and geometry of anatomical structures are correct
  3. Image refinement (3D U-Net): The network solves an image-to-image denoising and artifact-removal task. It learns to predict the residual $I_\text{target} - I_\text{FBP}$, which is added back to the FBP input

This residual approach (predicting a correction on top of the FBP baseline) typically converges faster and produces sharper results than learning the full mapping from scratch.

Advanced Strategy: Dual Domain Learning

For cases requiring even higher quality, particularly metal artifact reduction (MAR) or extreme sparse-view reconstruction, dual-domain networks such as DuDoNet process the data in both the sinogram domain and the image domain.

Workflow:

  1. Sinogram repair network: A U-Net repairs missing sinogram data
  2. Domain transform: Differentiable FBP converts the repaired sinogram to the image domain
  3. Image repair network: Denoising and artifact removal in the image domain
  4. Data consistency (DC): Re-project the image-domain output into sinogram space and penalize (or correct) deviations from the original measurements
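
Step 4 is often implemented as a gradient step on the measurement misfit $\tfrac{1}{2}\|Ax - y\|^2$, where $A$ is the forward projector and $y$ the measured sinogram. The NumPy sketch below is a generic illustration of that idea, not DuDoNet's actual layer: a random matrix stands in for $A$ (with fewer measurements than unknowns, as in sparse-view CT), and a perturbed image plays the role of the network output.

```python
import numpy as np

def data_consistency_step(x, A, y, step):
    """One gradient step on 0.5 * ||A x - y||^2, pulling x toward the measurements."""
    return x - step * A.T @ (A @ x - y)

rng = np.random.default_rng(0)
n_pixels, n_measurements = 64, 32                 # under-determined, like sparse-view CT
A = rng.normal(size=(n_measurements, n_pixels)) / np.sqrt(n_pixels)  # stand-in projector
x_true = rng.random(n_pixels)
y = A @ x_true                                    # noise-free "measured" sinogram

x_net = x_true + 0.1 * rng.normal(size=n_pixels)  # pretend network output (inconsistent)
step = 1.0 / np.linalg.norm(A, 2) ** 2            # 1 / largest eigenvalue of A^T A: stable step

x_dc = x_net.copy()
for _ in range(50):
    x_dc = data_consistency_step(x_dc, A, y, step)

before = np.linalg.norm(A @ x_net - y)
after = np.linalg.norm(A @ x_dc - y)
print(f"sinogram misfit: {before:.4f} -> {after:.4f}")
```

The update always lies in the range of $A^\top$, so it only changes the image components that the measurements can see; the null-space components (what the sparse views cannot determine) are left to the network's learned prior. In a trained pipeline the projector would need a differentiable implementation so gradients can flow through this step.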

Strategy comparison:

| Strategy | Input → Output | Advantages | Disadvantages | Best Use Case |
|---|---|---|---|---|
| Direct inversion | Sinogram → Image | End-to-end, no physics model needed | Extremely hard to train | Simple small-image experiments |
| FBP-ConvNet | Sinogram → FBP → Image | Fast training, good results, easy to implement | Cannot recover information lost in FBP | Low-dose denoising, mild sparse-view reconstruction |
| Dual domain | Sinogram → SinoNet → FBP → ImgNet → Image | Highest accuracy, physical consistency | Complex architecture | Metal artifact reduction, extreme sparse-view reconstruction |

🔧 3. Technical Challenges and Solutions

Moving from 2D to 3D CT reconstruction multiplies the engineering challenges: memory and compute grow with the cube of the volume's side length rather than its square. The two most critical issues are GPU memory explosion and data inconsistency.

3.1 The 3D Memory Problem: Patch-Based Training

Why Is 3D So Expensive?

A standard medical 3D CT volume is typically 512×512×Z slices (Z ranging from 200 to 500). Consider an input of 512×512×256:

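  • At float32 precision, the input volume alone occupies 256 MiB (512 × 512 × 256 voxels × 4 bytes)
  • Training a 3D U-Net on the full volume can need well over 100 GB of GPU memory for activations alone: a single 32-channel float32 feature map at this size is 8 GiB, and a U-Net keeps dozens of such maps for backpropagation
  • Even an 80 GB NVIDIA A100 cannot handle full-resolution 3D U-Net training on complete volumes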

Typical patch sizes: 64×64×64, 96×96×96, or 128×128×64, depending on available GPU memory.
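
These memory figures follow from simple arithmetic: voxels × channels × bytes per value. The small helper below (the name is illustrative) compares the input volume, one 32-channel feature map at full resolution (the width of the first level of the U-Net in Section 4.2), the same feature map for a 64³ patch, and the full-resolution map stored in float16.

```python
def tensor_gib(shape, bytes_per_element=4):
    """Size of a dense tensor in GiB (4 bytes for float32, 2 for float16)."""
    n = 1
    for d in shape:
        n *= d
    return n * bytes_per_element / 2**30

volume = (512, 512, 256)
print(tensor_gib(volume))                  # 0.25 GiB: the single-channel input
print(tensor_gib((32, *volume)))           # 8.0 GiB: one 32-channel feature map
print(tensor_gib((32, 64, 64, 64)))        # 0.03125 GiB (32 MiB) for a 64^3 patch
print(tensor_gib((32, *volume), 2))        # 4.0 GiB with float16 activations
```

A 64³ patch shrinks every feature map by a factor of 256 relative to the full volume, which is what makes patch-based training fit on a single GPU.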

Solution: 3D Patch-Based Strategy

A. Training Strategy

Under the FBP-ConvNet architecture:

  1. Full-volume FBP: First, perform FBP reconstruction on the complete volume (on CPU or GPU). FBP is a global operation and cannot be applied at the patch level.
  2. Random cropping: During training iterations, randomly crop spatially aligned small cubes from the FBP input volume and the corresponding ground truth volume.
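
Random cropping amounts to drawing one random corner and applying the same slices to both volumes. A NumPy sketch, assuming both volumes are already on the same voxel grid with shape `(D, H, W)`; the function name is hypothetical:

```python
import numpy as np

def random_aligned_crop(fbp_volume, gt_volume, patch=(64, 64, 64), rng=None):
    """Crop the same randomly placed cube from the FBP input and the ground truth."""
    if fbp_volume.shape != gt_volume.shape:
        raise ValueError("input and target volumes must be spatially aligned")
    rng = np.random.default_rng() if rng is None else rng
    # One random corner per axis, chosen so the cube stays inside the volume
    corner = [rng.integers(0, size - p + 1) for size, p in zip(fbp_volume.shape, patch)]
    slices = tuple(slice(c, c + p) for c, p in zip(corner, patch))
    return fbp_volume[slices], gt_volume[slices]

rng = np.random.default_rng(0)
fbp = rng.random((96, 128, 128)).astype(np.float32)
gt = fbp * 2.0                                       # toy target tied voxel-wise to the input
x_patch, y_patch = random_aligned_crop(fbp, gt, rng=rng)
print(x_patch.shape, np.allclose(y_patch, 2.0 * x_patch))   # (64, 64, 64) True
```

Using one shared tuple of slices is what guarantees alignment; cropping the two volumes with independently drawn corners would silently pair mismatched anatomy.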

B. Inference Strategy

Naively assembling non-overlapping patches produces visible “blocking artifacts”: voxels near a patch border see less context (and more zero padding) than voxels at the center, so the CNN's predictions there are less accurate.

  1. Sliding window: Use a window of the same size as the training patch, with 25–50% overlap
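  2. Gaussian blending: In overlapping regions, weight each patch's prediction by a Gaussian that peaks at the patch center instead of averaging uniformly. The less reliable border predictions then contribute less, which suppresses seam artifacts and produces smooth results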
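
The inference procedure can be sketched in PyTorch as follows. This is a minimal version under stated assumptions: a single volume of shape `(1, 1, D, H, W)` with every spatial size at least the patch size, and a model that preserves spatial shape. The Gaussian width of one-eighth of the patch size is a common heuristic rather than a fixed rule, and the function names are hypothetical (MONAI's `monai.inferers.sliding_window_inference` offers a more complete implementation).

```python
import numpy as np
import torch

def gaussian_importance(patch, sigma_scale=0.125):
    """Weight map peaking at the patch center; border voxels get small (but non-zero) weight."""
    axes = [np.arange(p) - (p - 1) / 2 for p in patch]
    grids = np.meshgrid(*axes, indexing="ij")
    sq = sum((g / (p * sigma_scale)) ** 2 for g, p in zip(grids, patch))
    w = np.exp(-0.5 * sq)
    return torch.from_numpy((w / w.max()).astype(np.float32))

def window_starts(length, patch, stride):
    """Start indices covering [0, length); the last window is shifted flush with the end."""
    starts = list(range(0, length - patch + 1, stride))
    if starts[-1] + patch < length:
        starts.append(length - patch)
    return starts

@torch.no_grad()
def gaussian_sliding_window(model, volume, patch=(64, 64, 64), overlap=0.5):
    """Overlapping-patch inference with Gaussian-weighted blending of the predictions."""
    D, H, W = volume.shape[-3:]
    stride = [max(1, int(p * (1 - overlap))) for p in patch]   # 50% overlap -> stride p/2
    weight = gaussian_importance(patch).to(volume.device, volume.dtype)
    out = torch.zeros_like(volume)
    norm = torch.zeros_like(volume)
    for z in window_starts(D, patch[0], stride[0]):
        for y in window_starts(H, patch[1], stride[1]):
            for x in window_starts(W, patch[2], stride[2]):
                sl = (..., slice(z, z + patch[0]), slice(y, y + patch[1]), slice(x, x + patch[2]))
                out[sl] += model(volume[sl]) * weight   # weighted prediction
                norm[sl] += weight                      # accumulated weights
    return out / norm                                   # weighted average per voxel

# Sanity check: blending the predictions of an exact (identity) model returns the input
vol = torch.rand(1, 1, 80, 72, 64)
blended = gaussian_sliding_window(torch.nn.Identity(), vol)
print(torch.allclose(blended, vol, atol=1e-5))   # True
```

With a real network, call `model.eval()` first so that BatchNorm uses its running statistics rather than per-patch statistics.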

3.2 Data Inconsistency and Dimension Matching

Physical Constraints and Data Consistency

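Negative value handling: Neural networks may output negative values, but linear attenuation coefficients $\mu$ cannot be negative. Hounsfield Units are defined as $\text{HU} = 1000 \cdot \frac{\mu - \mu_\text{water}}{\mu_\text{water} - \mu_\text{air}}$; since $\mu_\text{air} \approx 0$, HU + 1000 is proportional to $\mu$, and the physical floor is about -1000 HU (air). Apply a ReLU activation at the output (when the network predicts attenuation values), or clamp the output before loss computation, to enforce non-negativity.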
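
In code, the HU/attenuation relationship and the clamp look like the sketch below. `MU_WATER` is an illustrative assumption (water's attenuation depends on the X-ray spectrum, so real pipelines use a calibrated value), and the helper names are hypothetical.

```python
import torch

MU_WATER = 0.02   # mm^-1; assumed illustrative value, spectrum-dependent in practice

def hu_to_mu(hu, mu_water=MU_WATER):
    """Invert HU = 1000 * (mu - mu_water) / mu_water (air attenuation neglected)."""
    return mu_water * (1.0 + hu / 1000.0)

def enforce_nonnegative_mu(hu_pred):
    """Clamp predictions at -1000 HU, i.e. mu = 0 (clamped voxels get zero gradient)."""
    return torch.clamp(hu_pred, min=-1000.0)

hu = torch.tensor([-1200.0, -1000.0, 0.0, 1000.0])
mu = hu_to_mu(enforce_nonnegative_mu(hu))
print(mu)   # values 0, 0, 0.02, 0.04 (mm^-1)
```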

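Data leakage between patients across train/validation splits is one of the most common silent errors with medical imaging datasets: slices from the same patient are highly correlated, so a slice-level split inflates validation scores. Always split by patient, not by slice.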
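
Splitting by patient only requires a patient ID per sample. A NumPy sketch follows; the function name and the toy IDs are hypothetical, and scikit-learn's `GroupShuffleSplit` or `GroupKFold` provide the same guarantee if scikit-learn is available.

```python
import numpy as np

def split_by_patient(patient_ids, val_fraction=0.2, seed=0):
    """Return (train_idx, val_idx) such that no patient contributes to both splits."""
    patient_ids = np.asarray(patient_ids)
    patients = np.unique(patient_ids)
    rng = np.random.default_rng(seed)
    rng.shuffle(patients)                                   # shuffle patients, not slices
    n_val = max(1, round(len(patients) * val_fraction))
    is_val = np.isin(patient_ids, patients[:n_val])
    return np.flatnonzero(~is_val), np.flatnonzero(is_val)

# One entry per slice (or patch): the patient it came from
slice_patients = ["p01"] * 4 + ["p02"] * 3 + ["p03"] * 5 + ["p04"] * 2 + ["p05"] * 4
train_idx, val_idx = split_by_patient(slice_patients)
train_p = {slice_patients[i] for i in train_idx}
val_p = {slice_patients[i] for i in val_idx}
print(sorted(val_p), train_p.isdisjoint(val_p))   # one held-out patient, True
```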


💻 4. Core Code Implementation (PyTorch)

A minimal code framework based on the FBP-ConvNet strategy: ODL (Operator Discretization Library) handles the projection geometry and FBP, and PyTorch handles the network and training.

Dependencies:

pip install torch torchvision odl astra-toolbox scikit-image matplotlib

4.1 Dataset Class: Simulated Data Generation and ODL Integration

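import numpy as np
import odl
import torch
from torch.utils.data import Dataset

class CTReconstructionDataset(Dataset):
    """Simulated (FBP input, ground truth) pairs for FBP-ConvNet training.

    Note: every sample uses the same modified Shepp-Logan phantom and only the
    sinogram noise differs. Use real or randomized phantoms for meaningful training.
    """
    def __init__(self, num_samples=100, size=128, mode='train', noise_level=0.5):
        self.num_samples = num_samples
        self.size = size
        self.mode = mode  # kept for API symmetry; unused in this toy example
        self.noise_level = noise_level

        # Reconstruction space: a size^3 voxel grid covering [-20, 20]^3
        self.space = odl.uniform_discr(
            min_pt=[-20, -20, -20], max_pt=[20, 20, 20],
            shape=[size, size, size], dtype='float32'
        )

        # 3D parallel-beam geometry with 180 projection angles
        geometry = odl.tomo.parallel_beam_geometry(
            self.space, num_angles=180
        )

        # ASTRA implements 3D projectors on the GPU only ('astra_cpu' is 2D-only),
        # so there is no CPU fallback for this 3D setup.
        try:
            self.ray_transform = odl.tomo.RayTransform(self.space, geometry, impl='astra_cuda')
        except Exception as err:
            raise RuntimeError(
                "3D ray transforms in ODL require ASTRA with CUDA support"
            ) from err

        # FBP operator for the sinogram -> image domain transform
        self.fbp_operator = odl.tomo.fbp_op(self.ray_transform)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Ground truth: modified Shepp-Logan phantom (values roughly in [0, 1])
        phantom = odl.phantom.shepp_logan(self.space, modified=True)
        target_np = phantom.asarray()

        # Forward projection: image -> sinogram
        clean_sinogram = self.ray_transform(phantom)

        # Additive Gaussian noise as a crude low-dose model (real LDCT noise is closer
        # to Poisson on photon counts); seeding by idx keeps each sample reproducible
        rng = np.random.default_rng(seed=idx)
        noise = rng.normal(0.0, self.noise_level, clean_sinogram.shape)
        noisy_sinogram = self.ray_transform.range.element(clean_sinogram.asarray() + noise)

        # Domain transform: Sinogram -> FBP image (the network input)
        noisy_fbp_np = self.fbp_operator(noisy_sinogram).asarray()

        # Convert to float32 tensors with a channel axis: (1, D, H, W)
        input_tensor = torch.from_numpy(np.ascontiguousarray(noisy_fbp_np, dtype=np.float32)).unsqueeze(0)
        target_tensor = torch.from_numpy(np.ascontiguousarray(target_np, dtype=np.float32)).unsqueeze(0)

        return input_tensor, target_tensor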

4.2 Model Architecture: 3D U-Net with Residual Learning

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(Conv3d 3x3x3 => BatchNorm => ReLU) * 2; padding=1 keeps the spatial size."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # With batch sizes of 1-2, InstanceNorm3d or GroupNorm are common alternatives to BatchNorm3d
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet3D(nn.Module):
    """Three-level 3D U-Net with a global residual connection (FBP-ConvNet style)."""
    def __init__(self, n_channels=1, n_classes=1):
        super().__init__()
        if n_channels != n_classes:
            raise ValueError("the global residual (input + residual) requires n_channels == n_classes")
        # Encoder: each level halves the spatial size and doubles the channels
        self.inc = DoubleConv(n_channels, 32)
        self.down1 = nn.Sequential(nn.MaxPool3d(2), DoubleConv(32, 64))
        self.down2 = nn.Sequential(nn.MaxPool3d(2), DoubleConv(64, 128))
        self.down3 = nn.Sequential(nn.MaxPool3d(2), DoubleConv(128, 256))  # Bottleneck

        # Decoder: transposed convs upsample 2x; DoubleConv fuses the skip connection
        self.up1 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(256, 128)
        self.up2 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(128, 64)
        self.up3 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(64, 32)

        self.outc = nn.Conv3d(32, n_classes, kernel_size=1)

    def forward(self, x):
        # Three 2x poolings: every spatial dimension must be divisible by 8
        if any(s % 8 for s in x.shape[-3:]):
            raise ValueError(f"spatial size {tuple(x.shape[-3:])} is not divisible by 8")

        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        # Decoder with skip connections
        u = self.up1(x4)
        u = torch.cat([x3, u], dim=1)
        u = self.conv1(u)

        u = self.up2(u)
        u = torch.cat([x2, u], dim=1)
        u = self.conv2(u)

        u = self.up3(u)
        u = torch.cat([x1, u], dim=1)
        u = self.conv3(u)

        # Predict residual (artifacts/noise)
        residual = self.outc(u)

        # Global residual connection: Output = Input + Residual
        return x + residual

4.3 Evaluation Metrics: PSNR and SSIM

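Two standard metrics quantify reconstruction quality:

  • PSNR (Peak Signal-to-Noise Ratio): Measures pixel-wise error on a logarithmic (dB) scale; higher is better
  • SSIM (Structural Similarity Index): Compares local luminance, contrast, and structure between two images; 1 means identical

\[\text{PSNR} = 20 \cdot \log_{10}\left(\frac{d_{\max}}{\sqrt{\text{MSE}}}\right)\]

where $d_{\max}$ is the dynamic range of the data (e.g., 1.0 for images normalized to [0, 1]) and MSE is the mean squared error between the reconstruction and the ground truth.
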
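import torch
import torch.nn.functional as F

def calculate_psnr(img1, img2, data_range=1.0):
    """PSNR in dB; `data_range` is d_max, the dynamic range of the images."""
    mse = F.mse_loss(img1, img2)
    if mse == 0:
        return float('inf')  # identical images
    return (20 * torch.log10(data_range / torch.sqrt(mse))).item()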
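
The snippet above covers only PSNR. One option for SSIM, assuming scikit-image is installed, is `skimage.metrics.structural_similarity`, which accepts N-D arrays and therefore works on whole 3D volumes (its default 7-voxel window requires each axis to be at least 7 long). The wrapper name is hypothetical.

```python
import torch
from skimage.metrics import structural_similarity

def calculate_ssim(img1, img2, data_range=1.0):
    """SSIM between two single volumes, e.g. tensors of shape (1, 1, D, H, W) or (D, H, W)."""
    a = img1.detach().float().cpu().numpy().squeeze()   # drop batch/channel axes of size 1
    b = img2.detach().float().cpu().numpy().squeeze()
    return float(structural_similarity(a, b, data_range=data_range))

target = torch.rand(1, 1, 32, 32, 32)
noisy = (target + 0.1 * torch.randn_like(target)).clamp(0, 1)
print(calculate_ssim(target, target))   # identical volumes: SSIM of 1
print(calculate_ssim(noisy, target))    # below 1 for the noisy copy
```

Because `squeeze()` removes every size-1 axis, the wrapper expects one volume at a time; for a batch, loop over the batch dimension.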

4.4 Training Loop: Automatic Mixed Precision (AMP)

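To address the 3D memory bottleneck, we use PyTorch's Automatic Mixed Precision (AMP), which runs eligible operations such as convolutions in float16. This roughly halves the memory of the activations those operations store and speeds them up on GPUs with Tensor Cores.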

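import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Uses CTReconstructionDataset, UNet3D and calculate_psnr from the snippets above.

def train_model(num_epochs=20):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_amp = device.type == 'cuda'  # mixed precision is only enabled on the GPU here
    model = UNet3D(n_channels=1, n_classes=1).to(device)

    # 64^3 volumes are small enough to train whole; for real data, crop patches instead
    dataset = CTReconstructionDataset(num_samples=100, size=64, mode='train')
    # num_workers=0: the simulation calls ASTRA on the GPU, and CUDA does not mix well with forked workers
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.MSELoss()
    # torch.amp.GradScaler needs PyTorch >= 2.3; older releases use torch.cuda.amp.GradScaler
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    model.train()
    for epoch in range(num_epochs):
        for i, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad(set_to_none=True)

            # Forward pass in mixed precision (float16 by default on CUDA);
            # autocast runs the MSE loss itself in float32
            with torch.autocast(device_type=device.type, enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, targets)

            # Scale the loss to avoid float16 gradient underflow; step() unscales first
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if i % 10 == 0:
                psnr = calculate_psnr(outputs.detach().float(), targets)
                print(f"Epoch {epoch+1}, Step {i}, Loss: {loss.item():.6f}, PSNR: {psnr:.2f} dB")

if __name__ == '__main__':
    train_model()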

Use autocast + GradScaler (torch.autocast and torch.amp.GradScaler in recent PyTorch; torch.cuda.amp in older releases) for 3D medical imaging: it is often the difference between fitting in GPU memory and not.


🏥 5. Application: Sparse-View CT Reconstruction

5.1 What Is Sparse-View CT?

Standard clinical CT scans acquire roughly 1,000 to 2,000 projection angles (views) per rotation. Sparse-view CT drastically reduces the number of acquired views (e.g., to only 50 or 100). At a fixed dose per view, reducing the view count by a factor of 10 reduces the radiation dose by the same factor.

5.2 The Challenge and the Deep Learning Solution

The challenge: Nyquist sampling theory dictates that insufficient sampling prevents perfect signal recovery. FBP applied to sparse-view data produces severe streak artifacts — sharp radial streaks that obscure tumors, vessels, and other fine structures.
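
The effect is easy to reproduce in 2D. The sketch below, assuming scikit-image is available, builds a hypothetical ellipse-and-disk phantom, reconstructs it with FBP (`skimage.transform.iradon` with a ramp filter) from 180 and from 18 evenly spaced views, and reports the root-mean-square error against the phantom.

```python
import numpy as np
from skimage.transform import radon, iradon

# Hypothetical test object: an ellipse "body" with two inserts, zero outside the scan circle
n = 128
yy, xx = np.mgrid[-1:1:n * 1j, -1:1:n * 1j]
phantom = np.zeros((n, n))
phantom[(xx / 0.7) ** 2 + (yy / 0.85) ** 2 <= 1] = 1.0
phantom[(xx - 0.25) ** 2 + yy ** 2 <= 0.15 ** 2] = 0.5
phantom[(xx + 0.3) ** 2 + (yy + 0.3) ** 2 <= 0.08 ** 2] = 2.0

def fbp_rmse(num_views):
    """FBP from `num_views` equally spaced angles over 180 degrees; RMSE vs. the phantom."""
    theta = np.linspace(0.0, 180.0, num_views, endpoint=False)
    sinogram = radon(phantom, theta=theta)
    recon = iradon(sinogram, theta=theta, filter_name="ramp")
    return float(np.sqrt(np.mean((recon - phantom) ** 2)))

rmse_full, rmse_sparse = fbp_rmse(180), fbp_rmse(18)
print(f"180 views: RMSE {rmse_full:.3f}   18 views: RMSE {rmse_sparse:.3f}")
```

Displaying the two reconstructions (e.g., with matplotlib) shows that much of the extra error at 18 views takes the form of radial streaks, the pattern an FBP-ConvNet is trained to remove.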

Why deep learning helps:

  1. Artifact recognition and removal: Streak artifacts have distinctive global geometric signatures (straight radial lines), while anatomical structures are locally connected and organically shaped. A well-trained U-Net can learn to suppress streak artifacts while largely preserving true structural detail.

  2. Sinogram inpainting: Dual-domain networks can predict missing projection angles in the sinogram domain, transforming a sparse sinogram into a synthesized full-angle sinogram.

5.3 Clinical Significance

In lung cancer screening, dental CBCT, and cardiac imaging, sparse-view CT reconstruction has shown considerable promise in research settings. It could make low-dose screening more practical, reduce cumulative radiation exposure and the associated risk of radiation-induced cancer, and is particularly valuable for patients requiring frequent follow-up CT scans.

A 10× reduction in view count → ~10× dose reduction. For pediatric or frequent-imaging patients, this is clinically significant.


💡 6. Conclusion

3D U-Net-based CT reconstruction rests on the close integration of traditional physics-based models with modern data-driven methods:

  1. FBP (traditional algorithm): Handles the deterministic physical geometry, providing the structural skeleton of the image.
  2. U-Net (deep learning): Handles complex statistical noise and non-linear artifacts, refining the image quality.

For beginners, mastering the FBP-ConvNet architecture is the best entry point. From there, patch-based training addresses the memory bottleneck, and dual-domain learning enforces consistency with the measured data for more complex artifacts.


Part of my FYP notes series on deep learning for medical imaging. Next: data-consistency layers and physics-guided constraints.

This post is licensed under CC BY 4.0 by the author.