PIUnet Architecture

PIUnet is a 3D convolutional + temporal attention network for multi-frame super-resolution, designed to be permutation-invariant (output does not depend on input frame ordering) and to produce calibrated uncertainty estimates. This article covers the architectural details, the variants we built, and the fundamental weaknesses that motivated moving to modern architectures.

Data Flow

Input: (B, T, H, W) -- a stack of T low-resolution frames (T=9 for LWIR). Output: (B, 1, 3H, 3W) -- a single 3x super-resolved image, plus a (B, 1, 3H, 3W) uncertainty map.

The data flows through six stages:

LR stack (B,T,H,W)
  |
  v
[Input embedding]  -- Conv3d(1 -> N_feat, k=[1,3,3]) + MHA + BN
  |
  v
[16x TEFA blocks]  -- Feature extraction with temporal attention + SE gating
  |
  v
[Mid processing]   -- Conv3d + MHA + BN + residual skip from input embedding
  |
  v
[TERN]             -- Implicit alignment via predicted 5x5 kernels
  |
  v
[Temporal mean]    -- Collapse T dimension: mean over frames
  |
  v
[Reconstruction]   -- Pixel shuffle 3x + global residual from bicubic(mean(LR))

Source: piunet/models/piunet.py lines 102-173.

TEFA: Temporal Enhancement with Feature Attention

TEFA is the core building block, repeated 16 times. Each block performs spatial feature extraction, temporal information exchange via self-attention, and channel-wise recalibration.

Structure of One TEFA Block

Input h: (B, N_feat, T, H, W)
  |
  +---> [Conv3d k=[1,3,3]] -> [BN3d] -> [LeakyReLU]
  |       Spatial feature extraction (no temporal mixing)
  |
  +---> [MultiheadAttention]  (self-attention across T dimension)
  |       Reshapes to (T, B*H*W, N_feat) for MHA
  |       Residual connection: h = h + MHA(h, h, h)
  |
  +---> [BN3d] -> [LeakyReLU]
  |
  +---> [Conv3d k=[1,3,3]] -> [BN3d] -> [LeakyReLU]
  |
  +---> [MultiheadAttention]  (second temporal attention layer)
  |       Residual connection
  |
  +---> [BN3d]
  |
  |  (Squeeze-Excitation gating)
  +---> Global average pool over (T, H, W) -> (B, N_feat)
  +---> Linear(N_feat -> N_feat/R_bneck) -> LeakyReLU
  +---> Linear(N_feat/R_bneck -> N_feat) -> Sigmoid
  +---> Broadcast and multiply: h = h * gate
  |
  +---> Residual: output = gated_h + input

Source: piunet/models/piunet.py lines 16-63.
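
The Squeeze-Excitation gating at the end of the block can be sketched in isolation (a minimal standalone example; the class and layer names here are illustrative, not taken from the source):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SEGate(nn.Module):
    """Squeeze-Excitation gating over (T, H, W), as in the TEFA diagram."""
    def __init__(self, n_feat=42, r_bneck=8):
        super().__init__()
        self.squeeze = nn.Linear(n_feat, n_feat // r_bneck)  # 42 -> 5
        self.excite = nn.Linear(n_feat // r_bneck, n_feat)   # 5 -> 42

    def forward(self, h):                     # h: (B, N_feat, T, H, W)
        s = h.mean(dim=(2, 3, 4))             # pool over (T, H, W) -> (B, N_feat)
        g = torch.sigmoid(self.excite(F.leaky_relu(self.squeeze(s))))
        return h * g[:, :, None, None, None]  # channel-wise gate, broadcast

h = torch.randn(2, 42, 9, 8, 8)
gated = SEGate()(h)
```

Because the gate is a per-channel sigmoid, it can only attenuate features, never amplify them; the residual connection around the whole block restores the un-gated signal path.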

Attention Mechanics

The temporal attention reshapes (B, F, T, H, W) to (T, B*H*W, F) using the to_mha() helper. This means:

  - Sequence length = T (9 frames)
  - Batch dimension = B * H * W (all spatial positions treated independently)
  - Feature dimension = N_feat (42)

Each spatial position independently attends across all T frames. This is how PIUnet achieves permutation invariance -- the attention mechanism discovers which frames are informative for each pixel position, regardless of input order.

With N_heads = 1, the full 42-dimensional feature vector participates in a single attention head. The bottleneck ratio R_bneck = 8 means the SE gating compresses to 42/8 = 5 channels before expanding back.
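The reshape pattern can be reproduced standalone (a minimal sketch; to_mha itself lives in the source file):

```python
import torch
import torch.nn as nn

B, F_, T, H, W = 2, 42, 9, 16, 16
h = torch.randn(B, F_, T, H, W)

# (B, F, T, H, W) -> (T, B*H*W, F): each spatial position becomes an
# independent batch element carrying a 9-token temporal sequence.
tokens = h.permute(2, 0, 3, 4, 1).reshape(T, B * H * W, F_)

mha = nn.MultiheadAttention(embed_dim=F_, num_heads=1)
attn_out, _ = mha(tokens, tokens, tokens)   # self-attention across frames

# Reshape back and apply the residual connection h = h + MHA(h, h, h)
h = h + attn_out.reshape(T, B, H, W, F_).permute(1, 4, 0, 2, 3)
```

Note that nn.MultiheadAttention defaults to (seq, batch, feature) layout, which is exactly what the reshape produces.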

Complexity Note

For T=9 frames, the attention matrix is only 9x9 per spatial position, which is trivially small. The real memory cost comes from the fact that there are B * H * W independent attention computations (e.g., 8 * 64 * 64 = 32,768 at training resolution). Total tokens processed by MHA: 9 * 32,768 = 294,912.

The linear attention analysis (piunet/LINEAR_ATTENTION_ANALYSIS.md) concluded that linear attention is not worthwhile here because: (a) T=9 is too short for O(T^2) to matter, (b) the bottleneck is Conv3D activations, not attention matrices, and (c) linear attention degrades quality on short sequences.

TERN: Temporal Enhancement with Registration Network

TERN is PIUnet's implicit alignment module. It predicts a single 5x5 convolution kernel per frame and applies it to align feature maps before temporal fusion.

How TERN Works

Input h: (B, N_feat, T, H, W)
  |
  +---> [Conv3d k=[1,3,3]] -> [BN3d] -> [LeakyReLU]
  +---> [MultiheadAttention] (temporal, with residual)
  +---> [BN3d] -> [LeakyReLU]
  |
  +---> Global average pool over (H, W) -> (B, N_feat, T)
  |       Collapses spatial dims: each frame gets one feature vector
  |
  +---> [Conv1d(N_feat -> 25, k=1)] -> (B, 25, T)
  |       Predicts 5*5 = 25 kernel weights per frame
  |
  +---> Reshape to (B*T, 1, 1, 5, 5) convolution kernels
  |
  +---> Apply via grouped Conv3d:
  |       Input reshaped to (1, B*T, N_feat, H, W)
  |       F.conv3d with groups=B*T, padding=[0,2,2]
  |       Each frame gets its own 5x5 kernel applied uniformly
  |
  +---> Reshape back to (B, N_feat, T, H, W)

Source: piunet/models/piunet.py lines 66-98.
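
The grouped-convolution trick in the diagram can be reproduced in isolation (random kernels stand in for TERN's predicted ones):

```python
import torch
import torch.nn.functional as F

B, N_feat, T, H, W = 2, 42, 9, 16, 16
h = torch.randn(B, N_feat, T, H, W)

# One 5x5 kernel per frame (random here; TERN predicts these from
# globally pooled features via a 1x1 Conv1d).
kernels = torch.randn(B * T, 1, 1, 5, 5)

# Move T next to B, then fold (B, T) into the channel dim so that
# groups=B*T gives every frame its own kernel. Kernel depth 1 means the
# same 5x5 spatial filter is applied to every feature channel.
x = h.permute(0, 2, 1, 3, 4).reshape(1, B * T, N_feat, H, W)
x = F.conv3d(x, kernels, groups=B * T, padding=[0, 2, 2])
aligned = x.reshape(B, T, N_feat, H, W).permute(0, 2, 1, 3, 4)
```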

Critical Limitation: Spatially-Invariant Kernels

The key weakness of TERN is that after the global average pool on line 89 (h = torch.mean(h,[3,4])), all spatial information is discarded. The predicted 5x5 kernel is the same for every pixel in the frame. This is equivalent to assuming that every pixel in a given frame has the same sub-pixel shift.

This assumption is reasonable for satellite imagery (PROBA-V), where frames are related by near-rigid translations. It breaks down badly for LWIR aerial imagery, where:

  1. Altitude parallax causes spatially-varying displacement -- objects at different heights (trees, buildings) shift differently than flat ground.
  2. Lens distortion residuals create radially-varying shifts even after undistortion.
  3. Rolling shutter on the thermal sensor creates non-rigid per-row shifts.

A 5x5 kernel can represent shifts up to +/- 2 pixels, but only a single uniform shift per frame. Real LWIR displacements vary across the frame by several pixels.
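To make the limitation concrete, a one-hot 5x5 kernel implements exactly one uniform integer shift (a toy example, not TERN code):

```python
import torch
import torch.nn.functional as F

# A one-hot 5x5 kernel realizes a uniform shift of up to +/- 2 pixels
# (sub-pixel shifts correspond to interpolating kernel weights).
img = torch.arange(49.0).reshape(1, 1, 7, 7)
k = torch.zeros(1, 1, 5, 5)
k[0, 0, 3, 2] = 1.0                     # delta at (2+dy, 2+dx) with dy=1, dx=0
shifted = F.conv2d(img, k, padding=2)   # every pixel moves by the same (dy, dx)
```

One kernel per frame therefore means one shift for the whole frame, which is precisely what spatially-varying displacement violates.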

Modern approaches address this differently:

  - RASD+Restormer uses deformable attention where each query can attend to offset positions, enabling per-pixel alignment.
  - QMambaBSR uses state-space models with implicit spatial-frequency alignment.
  - Deformable convolutions (DCNv2/v3) predict per-pixel offsets; this was identified as a potential upgrade in piunet/TODO_NEXT_IMPROVEMENTS.md but never implemented.

Reconstruction Variants

Original: Direct Prediction + Global Residual

The original PIUNET (piunet.py) predicts the full HR image via pixel shuffle, then adds a bicubic-upsampled mean of all LR frames as a global residual:

x_mu = F.pixel_shuffle(self.conv_d2s_mu(x), 3)      # Network prediction
x_up = self.up(torch.mean(x_in, dim=1).unsqueeze(1))  # Bicubic(mean(all_LR))
mu_sr = x_mu + x_up                                    # Global residual

Source: piunet/models/piunet.py lines 160-166.

V1 Residual: Reference Frame + Learned Gain

PIUNETResidual (piunet_residual.py) uses a designated reference frame instead of the mean, and adds a learned gain parameter:

residual_centered = residual_raw - residual_raw.mean()    # Zero-mean residual
bicubic_hr = bicubic_upsample(denormalize(x_ref))          # Upsample reference
hr_pred = bicubic_hr + gain * residual_centered             # Scaled residual

Gain is a per-batch scalar predicted via global average pooling + FC layers, constrained to [20, 80] DN via sigmoid. The idea is that the network only needs to predict the high-frequency detail that bicubic misses, which is a much easier learning problem.
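The gain head can be sketched as follows (a hypothetical reconstruction from the description above; class, layer sizes, and names are assumptions, not the source implementation):

```python
import torch
import torch.nn as nn

class GainHead(nn.Module):
    """Per-batch gain: global average pool -> FC -> sigmoid in [20, 80] DN."""
    def __init__(self, n_feat=42, lo=20.0, hi=80.0):
        super().__init__()
        self.lo, self.hi = lo, hi
        self.fc = nn.Sequential(
            nn.Linear(n_feat, n_feat // 2),
            nn.LeakyReLU(),
            nn.Linear(n_feat // 2, 1),
        )

    def forward(self, feats):                       # feats: (B, N_feat, H, W)
        pooled = feats.mean(dim=(2, 3))             # (B, N_feat)
        g = torch.sigmoid(self.fc(pooled))          # (B, 1) in (0, 1)
        return self.lo + (self.hi - self.lo) * g    # constrained to [20, 80]

gain = GainHead()(torch.randn(4, 42, 16, 16))
```

The sigmoid guarantees the gain stays inside the configured range regardless of the network's raw output, which stabilizes early training.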

Source: piunet/models/piunet_residual.py lines 169-258.

V2 Residual: Statistics-Matched Bicubic

PIUNETResidual V2 (piunet_residual_v2.py) fixes a normalization mismatch where the bicubic baseline had different mean/std than the HR target:

# During training: match bicubic stats to HR target
bicubic_hr = (bicubic_raw - mu_bicubic) / sigma_bicubic * sigma_hr + mu_hr
# Residual scaled as fraction of sigma
hr_pred = bicubic_hr + gain * sigma_for_residual * residual_centered

Gain range changed to [0.05, 0.5] (fraction of HR std, not absolute DN). Also replaced BatchNorm3d with GroupNorm(6 groups) for stability at batch_size=4.
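
The statistics matching amounts to a standardize-then-rescale step, which can be sketched as (per-image statistics assumed; not the exact V2 implementation):

```python
import torch

def match_stats(bicubic_raw, hr_target):
    # Standardize the bicubic baseline, then re-express it in the HR
    # target's mean/std so the residual only carries detail, not bias.
    mu_b, sigma_b = bicubic_raw.mean(), bicubic_raw.std()
    mu_hr, sigma_hr = hr_target.mean(), hr_target.std()
    return (bicubic_raw - mu_b) / sigma_b * sigma_hr + mu_hr

bicubic_raw = torch.randn(1, 1, 48, 48) * 3 + 10   # mismatched stats
hr_target = torch.randn(1, 1, 48, 48) * 7 + 50
bicubic_hr = match_stats(bicubic_raw, hr_target)
```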

Source: piunet/models/piunet_residual_v2.py lines 171-291.

LR Fusion: No Upsampling

PIUNETLRFusion models (piunet_lr_fusion.py, piunet_lr_fusion_v2.py) do not perform super-resolution at all. They fuse 9 LR frames into a single clean LR image (matching what you'd get from downsampling the HR mosaic). Output stays at LR resolution. This was intended as Stage 1 of a two-stage pipeline, with a separate SR network as Stage 2.

The V2 fusion model added Flash Attention 2 and uncertainty-aware temporal pooling (frames weighted by attention_weights * confidence^2 instead of uniform mean).

Source: piunet/models/piunet_lr_fusion_v2.py lines 202-361.
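
The uncertainty-aware pooling can be sketched as follows (tensor layouts and names are assumptions based on the description above):

```python
import torch

def uncertainty_weighted_pool(frames, attn_weights, confidence):
    """Weight frames by attention_weights * confidence**2, then pool over T."""
    w = attn_weights * confidence ** 2      # (B, T, 1, 1)
    w = w / w.sum(dim=1, keepdim=True)      # normalize over frames
    return (frames * w).sum(dim=1)          # weighted temporal pool

frames = torch.randn(2, 9, 16, 16)
attn_weights = torch.rand(2, 9, 1, 1)
confidence = torch.rand(2, 9, 1, 1)
fused = uncertainty_weighted_pool(frames, attn_weights, confidence)
```

With uniform weights and confidence this reduces to the plain temporal mean used by the SR models.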

Permutation Invariance

PIUnet achieves permutation invariance through two mechanisms:

  1. Self-attention over temporal dimension -- attention weights are computed from the content of each frame, not its position. Reordering frames produces the same attention weights.
  2. Temporal mean pooling -- after TERN alignment, the T dimension is collapsed by torch.mean(x, dim=2). Mean is order-invariant.

This design was important for PROBA-V where frame ordering was arbitrary. For LWIR it is arguably a weakness, since the reference frame (temporally closest to mosaic capture, best-registered) should receive higher weight. The V2 fusion model partially addresses this with learned temporal attention weights, but the SR models still use uniform pooling.
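
Both mechanisms can be verified numerically: self-attention without positional encoding is permutation-equivariant, and a final mean over T is permutation-invariant (a standalone check, not code from the source):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, N, F_ = 9, 4, 42                       # frames, batch, features
tokens = torch.randn(T, N, F_)
perm = torch.randperm(T)                  # arbitrary frame reordering

mha = nn.MultiheadAttention(embed_dim=F_, num_heads=1).eval()
with torch.no_grad():
    out, _ = mha(tokens, tokens, tokens)
    out_perm, _ = mha(tokens[perm], tokens[perm], tokens[perm])

# Permuting the input frames just permutes the outputs ...
equivariant = torch.allclose(out[perm], out_perm, atol=1e-5)
# ... so the mean over T is unchanged by frame reordering.
invariant = torch.allclose(out.mean(0), out_perm.mean(0), atol=1e-5)
```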

Uncertainty Estimation

All PIUnet variants output an uncertainty map alongside the SR prediction. The uncertainty branch takes the same pooled features, applies bicubic upsampling to HR resolution, then predicts a single-channel log-variance map:

x_sigma = self.conv_d2s_sigma(self.up(x))  # Upsample features
x_sigma = self.norm_sigma(x_sigma)
x_sigma = F.leaky_relu(x_sigma)
sigma_sr = self.conv_out_sigma(x_sigma)     # (B, 1, 3H, 3W)

This is trained via the Laplacian NLL loss: sigma + |y - mu| * exp(-sigma), which encourages the network to output higher uncertainty in regions where the prediction error is large.
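
The loss can be written directly from the formula above (a minimal sketch; reduction and any masking in the source may differ):

```python
import torch

def laplacian_nll(mu, sigma, y):
    # sigma is a log-scale map; per the text: sigma + |y - mu| * exp(-sigma)
    return (sigma + (y - mu).abs() * torch.exp(-sigma)).mean()

mu = torch.zeros(4)
y = torch.full((4,), 2.0)
err = (y - mu).abs()

# The per-pixel optimum is sigma = log|y - mu|: the network is rewarded
# for admitting high uncertainty exactly where its error is large.
best = laplacian_nll(mu, torch.log(err), y)
worse = laplacian_nll(mu, torch.zeros(4), y)
```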

Source: piunet/models/piunet.py lines 168-172, piunet/training/losses.py lines 57-111.

Registered Loss Functions

A distinctive feature of PIUnet's training is the shift-search loss. Since LR-to-HR registration is imperfect, the loss evaluates 49 candidate pixel shifts (7x7 grid, +/-3 pixels) and takes the minimum:

X = []
for i in range(7):
    for j in range(7):
        labels = y_true[:, i:i+size, j:j+size]          # crop at shift (i, j)
        # Per-shift brightness bias correction
        b = (labels - predictions).mean(dim=(1, 2), keepdim=True)
        corrected = predictions + b
        X.append((labels - corrected).abs().mean(dim=(1, 2)))  # per-sample L1
min_l1 = torch.min(torch.stack(X), dim=0).values        # best shift per sample

This makes training robust to small alignment errors but has two downsides:

  1. Computational cost -- the loss is evaluated 49 times per training step.
  2. Leaky supervision -- the shift that minimizes the loss may not be the correct alignment, allowing the model to "cheat" by matching at a wrong offset.

Source: piunet/training/losses.py lines 11-54.

Weaknesses Identified from Paper Analysis

Our analysis of modern MFSR literature identified several fundamental limitations of PIUnet's design:

1. Spatially-Invariant Alignment (TERN)

As detailed above, TERN predicts one 5x5 kernel per frame, applied uniformly. This cannot handle:

  - Per-pixel parallax from altitude differences
  - Spatially-varying lens distortion residuals
  - Non-rigid deformations (rolling shutter, atmospheric turbulence)

Modern solutions: deformable attention (RASD+Restormer), per-pixel flow estimation, or implicit alignment via cross-attention between frames.

2. No Base-Frame Priority

PIUnet treats all input frames symmetrically (permutation invariance). In our LWIR pipeline, the reference frame is much better-registered to the HR target than supporting frames. The network has no built-in mechanism to prioritize information from this frame.

Modern solutions: explicit reference frame conditioning, cross-attention from reference to supporting frames (RASD+Restormer reference-aware design).

3. End-to-End Gradient Competition

TEFA (feature extraction), TERN (alignment), and the reconstruction head are all trained jointly with a single loss. The alignment module (TERN) competes for gradient signal with the SR module. If alignment is poor, the SR module learns to compensate, which prevents TERN from learning good alignment.

Modern solutions: stage-wise training (align first, then SR), or explicit alignment loss with flow supervision. Our two-stage LR Fusion approach was an attempt to address this but was never completed.

4. Limited Receptive Field

TEFA uses 3x3 spatial convolutions (with [1,3,3] kernels, no temporal extent). After 16 blocks, the effective receptive field is approximately 33x33 pixels at LR resolution. TERN adds a 5x5 kernel. This may be insufficient to capture long-range spatial correlations in thermal imagery.

Modern solutions: RASD+Restormer uses global self-attention across spatial dimensions; QMambaBSR uses Mamba state-space scanning with effectively infinite receptive field along scan paths.

5. Fixed 3x Scale Factor

PIUnet hardcodes 3x upsampling via pixel shuffle (a Conv2d outputs 9 channels, which pixel shuffle rearranges into 3x3 blocks of HR pixels). Adapting to 2x (our actual LWIR use case, since the HA/LA altitude ratio is approximately 1.5:1 to 2:1) requires architecture changes, not just retraining.
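
The channel arithmetic makes the coupling explicit (a standalone sketch, not the source code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# 3x pixel shuffle needs scale**2 = 9 output channels per image channel.
x = torch.randn(1, 42, 16, 16)
conv3x = nn.Conv2d(42, 9, kernel_size=3, padding=1)
sr3 = F.pixel_shuffle(conv3x(x), 3)   # -> (1, 1, 48, 48)

# Moving to 2x requires a different conv (scale**2 = 4 channels), i.e.
# a weight-shape change in the graph, not just retraining.
conv2x = nn.Conv2d(42, 4, kernel_size=3, padding=1)
sr2 = F.pixel_shuffle(conv2x(x), 2)   # -> (1, 1, 32, 32)
```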

Comparison to Modern Approaches

| Feature             | PIUnet                                  | RASD+Restormer             | QMambaBSR                   |
|---------------------|-----------------------------------------|----------------------------|-----------------------------|
| Alignment           | Spatially-invariant 5x5 kernel (TERN)   | Deformable cross-attention | Implicit via Mamba scanning |
| Attention           | Standard MHA, T=9 only                  | Window + channel attention | State-space model (linear)  |
| Scale               | Fixed 3x                                | Flexible                   | Flexible                    |
| Base-frame priority | None (permutation invariant)            | Reference-aware            | Configurable                |
| Receptive field     | ~33px (stacked 3x3 conv)                | Global (self-attention)    | Effectively global (SSM)    |
| Uncertainty         | Yes (Laplacian NLL)                     | No (typically)             | No                          |
| Temporal invariance | Yes (by design)                         | No                         | No                          |

PIUnet's uncertainty estimation is a genuine advantage not found in most modern architectures, but the alignment and receptive field limitations likely explain why it could not beat bicubic on our LWIR data.