Infrared small and dim target detection benefits from exploring correlations among targets, neighboring regions, and the background. However, existing methods that rely on convolutional neural networks and vision transformers often struggle to capture long-range correlations within images. To address this limitation, we propose CS-ViG-UNet, a framework that introduces vision graph convolution for infrared small and dim target detection. The framework uses a cyclic shift sparse graph attention mechanism to overcome the challenge of reduced expressive power. Additionally, we design the CS-ViG module to construct an effective graph structure from image patches, capturing feature information relevant to target recognition. Experimental results on the public datasets SIRST Aug and IRSTD-1K show a significant improvement, with F1 scores increased by 3.15% and 4.1%, respectively, compared to state-of-the-art methods.
We evaluated our model on infrared dim and small target detection using two publicly available datasets: SIRST Aug and IRSTD-1K. CS-ViG-UNet was implemented in PyTorch 1.8. The initial learning rate was set to 0.01, the batch size was set to 8, and an FCN head was used as the segmentation head. In the CS-ViG module, the parameter N was set to 6, and the graph node distance K was set to 4. The Soft-IoU loss function was used, and all computations were performed on an NVIDIA RTX 3090.
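For readers unfamiliar with the Soft-IoU loss mentioned above, the following is a minimal sketch of one common formulation for binary segmentation. It is not taken from this repository: the function name `soft_iou_loss`, the sigmoid on raw logits, the per-image reduction, and the `smooth` constant are all assumptions made for illustration.

```python
import torch


def soft_iou_loss(logits: torch.Tensor, target: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """Soft-IoU loss for binary masks (illustrative sketch, assumed formulation).

    logits: raw network outputs of shape (B, 1, H, W)
    target: binary ground-truth mask of the same shape
    smooth: assumed smoothing constant that avoids division by zero
    """
    prob = torch.sigmoid(logits)          # map logits to [0, 1]
    target = target.to(prob.dtype)
    dims = (1, 2, 3)                      # reduce over channel and spatial axes
    intersection = (prob * target).sum(dim=dims)
    union = prob.sum(dim=dims) + target.sum(dim=dims) - intersection
    iou = (intersection + smooth) / (union + smooth)
    return 1.0 - iou.mean()               # average over the batch
```

A prediction that matches the mask drives the loss toward 0, while a prediction with no overlap drives it toward 1, which is why this loss is a popular choice when targets occupy only a few pixels.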
Quantitative Results
| Method type | Method | SIRST Aug Precision | SIRST Aug Recall | SIRST Aug IoU | SIRST Aug F1 | IRSTD-1K Precision | IRSTD-1K Recall | IRSTD-1K IoU | IRSTD-1K F1 |
|---|---|---|---|---|---|---|---|---|---|
| Model-driven | TopHat | 0.7136 | 0.1825 | 0.17 | 0.2906 | 0.0719 | 0.2034 | 0.0561 | 0.1062 |
| Model-driven | TLLCM | 0.8091 | 0.076 | 0.0747 | 0.139 | 0.6898 | 0.098 | 0.0938 | 0.1716 |
| Model-driven | LIG | 0.8798 | 0.1587 | 0.1553 | 0.2689 | 0.3044 | 0.159 | 0.1166 | 0.2089 |
| Model-driven | RLCM | 0.698 | 0.1979 | 0.1823 | 0.3083 | 0.3731 | 0.2608 | 0.1813 | 0.307 |
| Model-driven | var_diff | 0.7235 | 0.0834 | 0.0808 | 0.1495 | 0.664 | 0.1182 | 0.1115 | 0.2007 |
| Model-driven | MSAAGD | 0.0716 | 0.2136 | 0.0567 | 0.1073 | 0.2061 | 0.2128 | 0.1169 | 0.2094 |
| Model-driven | MSLoG | 0.0179 | 0.8366 | 0.0179 | 0.0351 | 0.0021 | 0.7698 | 0.0021 | 0.0041 |
| Model-driven | LEF | 0.6523 | 0.1777 | 0.1623 | 0.2793 | 0.6778 | 0.1964 | 0.1796 | 0.3045 |
| Model-driven | PSTNN | 0.9441 | 0.1391 | 0.138 | 0.2425 | 0.4495 | 0.204 | 0.1632 | 0.2806 |
| Model-driven | ADDGD | 0.8876 | 0.0817 | 0.0809 | 0.1497 | 0.5659 | 0.0671 | 0.0638 | 0.12 |
| Model-driven | WSLCM | 0.8466 | 0.0538 | 0.0533 | 0.1012 | 0.6095 | 0.0971 | 0.0915 | 0.1676 |
| Model-driven | ADMD | 0.9343 | 0.1176 | 0.1167 | 0.209 | 0.5493 | 0.0936 | 0.0869 | 0.1599 |
| Data-driven | ACM | 0.8205 | 0.7607 | 0.6521 | 0.7895 | 0.682 | 0.722 | 0.5402 | 0.7014 |
| Data-driven | AGPCNet | 0.831 | 0.7718 | 0.6671 | 0.8003 | 0.6898 | 0.6433 | 0.499 | 0.6657 |
| Data-driven | DNANet | 0.8577 | 0.8029 | 0.7085 | 0.8294 | 0.7085 | 0.7228 | 0.5572 | 0.7156 |
| Data-driven | UCF | 0.8772 | 0.7773 | 0.7011 | 0.8243 | 0.1679 | 0.7775 | 0.1602 | 0.2762 |
| Data-driven | Swin | 0.686 | 0.8814 | 0.628 | 0.7715 | 0.3375 | 0.3101 | 0.1927 | 0.3232 |
| Data-driven | HRNet | 0.8062 | 0.8551 | 0.7093 | 0.8299 | 0.0715 | 0.6137 | 0.0684 | 0.128 |
| Data-driven | Ours | 0.8495 | 0.8627 | 0.7483 | 0.8561 | 0.7892 | 0.7054 | 0.5936 | 0.745 |
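The four metrics in the table are not independent when they are all computed from the same pooled pixel counts (true positives, false positives, false negatives). The sketch below assumes that setting and shows how F1 and IoU follow from precision and recall; the helper names `f1_from_pr` and `iou_from_pr` are hypothetical and not part of this repository.

```python
def f1_from_pr(precision: float, recall: float) -> float:
    """F1 is the harmonic mean of precision and recall."""
    return 2.0 * precision * recall / (precision + recall)


def iou_from_pr(precision: float, recall: float) -> float:
    """IoU = TP / (TP + FP + FN), rewritten in terms of precision and recall.

    1 / precision = (TP + FP) / TP and 1 / recall = (TP + FN) / TP, so
    1 / IoU = (TP + FP + FN) / TP = 1 / precision + 1 / recall - 1.
    """
    return 1.0 / (1.0 / precision + 1.0 / recall - 1.0)


# Precision and recall from the "Ours" row, SIRST Aug columns.
p, r = 0.8495, 0.8627
f1, iou = f1_from_pr(p, r), iou_from_pr(p, r)
```

For the "Ours" row on both datasets, these identities agree with the reported F1 and IoU to about three decimal places; the small residual is expected because the tabulated precision and recall are themselves rounded.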
Qualitative Results
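import torch
import torch.nn as nn


class MRConv4d(nn.Module):
    """
    Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type
    K is the number of superpatches, therefore hops equals res // K.
    """

    def __init__(self, in_channels, out_channels, K=2):
        super(MRConv4d, self).__init__()
        self.nn = nn.Sequential(
            # The input is [x, x_j] concatenated along channels, hence in_channels * 2.
            nn.Conv2d(in_channels * 2, out_channels, 1),
            # Normalize the conv output, which has out_channels channels.
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )
        self.K = K

    def forward(self, x):
        B, C, H, W = x.shape
        # Cyclically shift the feature map by half a stride before aggregation.
        x = torch.roll(x, shifts=(-(self.K // 2), -(self.K // 2)), dims=(2, 3))
        # Running element-wise maximum of relative features, initialized to zero.
        x_j = torch.zeros_like(x)
        # Compare each node with the nodes K, 2K, ... rows away (with wrap-around).
        for i in range(self.K, H, self.K):
            x_c = x - torch.roll(x, shifts=(-i, 0), dims=(2, 3))
            x_j = torch.max(x_j, x_c)
        # Compare each node with the nodes K, 2K, ... columns away (with wrap-around).
        for i in range(self.K, W, self.K):
            x_r = x - torch.roll(x, shifts=(0, -i), dims=(2, 3))
            x_j = torch.max(x_j, x_r)
        # Concatenate node features with the aggregated relative features.
        x = torch.cat([x, x_j], dim=1)
        # Undo the initial cyclic shift.
        x = torch.roll(x, shifts=(self.K // 2, self.K // 2), dims=(2, 3))
        return self.nn(x)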
@incollection{CSViG,
  title = {CS-ViG-UNet: Infrared Dim and Small Target Detection based on Cycle Shift Vision Graph Convolution Network},
  author = {Jian Lin},
  year = {2023},
}