lixi042 committed · Commit 7e31006 · 1 Parent(s): bc6ea96
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ weights/ filter=lfs diff=lfs merge=lfs -text
37
+ assets/ filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,8 +1,140 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
8
 
1
+ import cv2
2
+ import numpy as np
3
  import gradio as gr
4
 
5
+ import torch
6
+ import matplotlib.cm as cm
7
+ import sys
8
 
9
+ sys.path.append("src")
 
10
 
11
+ from src.utils.plotting import make_matching_figure
12
+ from src.edm import EDM
13
+ from src.config.default import get_cfg_defaults
14
+ from src.utils.misc import lower_config
15
+
16
+
17
+ HEADER = """
18
+ <div align="center">
19
+ <p>
20
+ <span style="font-size: 30px; vertical-align: bottom;"> 🎶 EDM: Efficient Deep Feature Matching</span>
21
+ </p>
22
+ <p style="margin-top: -15px;">
23
+ <a href="https://arxiv.org/pdf/2503.05122" target="_blank" style="color: grey;">ArXiv Paper</a>
24
+ &nbsp;
25
+ <a href="https://github.com/chicleee/EDM" target="_blank" style="color: grey;">GitHub Repository</a>
26
+ </p>
27
+
28
+ </div>
29
+ """
30
+
31
+ ABSTRACT = """
32
+ Recent feature matching methods have achieved remarkable performance but lack efficiency consideration. In this paper, we revisit the mainstream detector-free matching pipeline and improve all its stages considering both accuracy and efficiency. We propose an Efficient Deep feature Matching network, EDM. We first adopt a deeper CNN with fewer dimensions to extract multi-level features. Then we present a Correlation Injection Module that conducts feature transformation on high-level deep features, and progressively injects feature correlations from global to local for efficient multi-scale feature aggregation, improving both speed and performance. In the refinement stage, a novel lightweight bidirectional axis-based regression head is designed to directly predict subpixel-level correspondences from latent features, avoiding the significant computational cost of explicitly locating keypoints on high-resolution local feature heatmaps. Moreover, effective selection strategies are introduced to enhance matching accuracy. Extensive experiments show that our EDM achieves competitive matching accuracy on various benchmarks and exhibits excellent efficiency, offering valuable best practices for real-world applications."""
33
+
34
+ def find_matches(image_0, image_1, conf_thres=0.2, border_rm=2, topk=10000):
35
+ config = get_cfg_defaults()
36
+ data_cfg_path = "configs/data/megadepth_test_1500.py"
37
+ main_cfg_path = "configs/edm/outdoor/edm_base.py"
38
+ config.merge_from_file(main_cfg_path)
39
+ config.merge_from_file(data_cfg_path)
40
+
41
+ W, H = 832, 832
42
+ config.EDM.COARSE.MCONF_THR = conf_thres
43
+ config.EDM.COARSE.BORDER_RM = border_rm
44
+ config.EDM.COARSE.TOPK = topk
45
+
46
+ _config = lower_config(config)
47
+ matcher = EDM(config=_config["edm"])
48
+ state_dict = torch.load("weights/edm_outdoor.ckpt")["state_dict"]
49
+ matcher.load_state_dict(state_dict)
50
+ matcher = matcher.eval()
51
+
52
+ # Load example images
53
+ img0_bgr = image_0
54
+ img1_bgr = image_1
55
+
56
+
57
+ h0, w0 = img0_bgr.shape[:2]
58
+ h1, w1 = img1_bgr.shape[:2]
59
+
60
+ h0_scale = h0 / H
61
+ w0_scale = w0 / W
62
+ h1_scale = h1 / H
63
+ w1_scale = w1 / W
64
+
65
+ # For inference
66
+ img0_raw = cv2.cvtColor(img0_bgr, cv2.COLOR_BGR2GRAY)
67
+ img1_raw = cv2.cvtColor(img1_bgr, cv2.COLOR_BGR2GRAY)
68
+ img0_raw = cv2.resize(img0_raw, (W, H))
69
+ img1_raw = cv2.resize(img1_raw, (W, H))
70
+
71
+ img0 = torch.from_numpy(img0_raw)[None][None] / 255.
72
+ img1 = torch.from_numpy(img1_raw)[None][None] / 255.
73
+ batch = {'image0': img0, 'image1': img1}
74
+
75
+ # Inference with EDM and get prediction
76
+ with torch.no_grad():
77
+ matcher(batch)
78
+
79
+ mkpts0 = batch['mkpts0_f'].cpu().numpy()
80
+ mkpts1 = batch['mkpts1_f'].cpu().numpy()
81
+ mconf = batch['mconf'].cpu().numpy()
82
+
83
+ mkpts0[:, 0] *= w0_scale
84
+ mkpts0[:, 1] *= h0_scale
85
+ mkpts1[:, 0] *= w1_scale
86
+ mkpts1[:, 1] *= h1_scale
87
+
88
+ color = cm.jet(mconf)
89
+ # Draw
90
+ text = [
91
+ 'EDM',
92
+ 'Matches: {}'.format(len(mkpts0)),
93
+ ]
94
+ fig = make_matching_figure(img0_bgr, img1_bgr, mkpts0, mkpts1, color, text=text)
95
+
96
+ return fig
97
+
98
+
99
+ with gr.Blocks() as demo:
100
+
101
+ gr.Markdown(HEADER)
102
+ with gr.Accordion("Abstract (click to open)", open=False):
103
+ gr.Image("assets/teaser.jpg")
104
+ gr.Markdown(ABSTRACT)
105
+
106
+ with gr.Row():
107
+ image_1 = gr.Image()
108
+ image_2 = gr.Image()
109
+ with gr.Row():
110
+ conf_thres = gr.Slider(minimum=0.01, maximum=1, value=0.2, step=0.01, label="Coarse Confidence Threshold")
111
+ topk = gr.Slider(minimum=100, maximum=10000, value=5000, step=100, label="TopK")
112
+ border_rm = gr.Slider(minimum=0, maximum=20, value=2, step=1, label="Border Remove (x8)")
113
+ gr.HTML(
114
+ """
115
+ Note: images are actually resized to 832 x 832 for matching.
116
+ """
117
+ )
118
+ with gr.Row():
119
+ button = gr.Button(value="Find Matches")
120
+ clear = gr.ClearButton(value="Clear")
121
+
122
+ output = gr.Image()
123
+ button.click(find_matches, [image_1, image_2, conf_thres, border_rm, topk], output)
124
+ clear.add([image_1, image_2, output])
125
+
126
+ gr.Examples(
127
+ examples=[
128
+ ["assets/scannet_sample_images/scene0707_00_15.jpg", "assets/scannet_sample_images/scene0707_00_45.jpg"],
129
+ ["assets/scannet_sample_images/scene0758_00_165.jpg", "assets/scannet_sample_images/scene0758_00_510.jpg"],
130
+ ["assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg", "assets/phototourism_sample_images/london_bridge_49190386_5209386933.jpg"],
131
+ ["assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg", "assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"],
132
+ ],
133
+ inputs=[image_1, image_2],
134
+ outputs=[output],
135
+ fn=find_matches,
136
+ cache_examples=None,
137
+ )
138
+
139
+ if __name__ == "__main__":
140
+ demo.launch(debug=True)
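For readers who want to drive the matcher without the Gradio UI, here is a minimal sketch that condenses find_matches above into a standalone script. It only reuses calls that already appear in app.py; the image paths img0.jpg and img1.jpg are placeholders, and this is illustrative only, not code from the repository.

# Minimal sketch: run EDM on two images outside Gradio (placeholder image paths).
import cv2
import torch
import sys
sys.path.append("src")

from src.edm import EDM
from src.config.default import get_cfg_defaults
from src.utils.misc import lower_config

config = get_cfg_defaults()
config.merge_from_file("configs/edm/outdoor/edm_base.py")
config.merge_from_file("configs/data/megadepth_test_1500.py")  # data config merged last
matcher = EDM(config=lower_config(config)["edm"])
matcher.load_state_dict(torch.load("weights/edm_outdoor.ckpt")["state_dict"])
matcher = matcher.eval()

def load_gray(path, size=832):
    # Grayscale, resized to the 832 x 832 matching resolution used by the demo.
    img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2GRAY)
    img = cv2.resize(img, (size, size))
    return torch.from_numpy(img)[None][None] / 255.0

batch = {"image0": load_gray("img0.jpg"), "image1": load_gray("img1.jpg")}  # placeholder paths
with torch.no_grad():
    matcher(batch)
print(batch["mkpts0_f"].shape, batch["mkpts1_f"].shape, batch["mconf"].shape)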
configs/data/__init__.py ADDED
File without changes
configs/data/base.py ADDED
@@ -0,0 +1,37 @@
1
+ """
2
+ The data config will be the last one merged into the main config.
3
+ Setups in data configs will override all existing setups!
4
+ """
5
+
6
+ from yacs.config import CfgNode as CN
7
+ _CN = CN()
8
+ _CN.DATASET = CN()
9
+ _CN.TRAINER = CN()
10
+ _CN.EDM = CN()
11
+ _CN.EDM.NECK = CN()
12
+
13
+ # training data config
14
+ _CN.DATASET.TRAIN_DATA_ROOT = None
15
+ _CN.DATASET.TRAIN_POSE_ROOT = None
16
+ _CN.DATASET.TRAIN_NPZ_ROOT = None
17
+ _CN.DATASET.TRAIN_LIST_PATH = None
18
+ _CN.DATASET.TRAIN_INTRINSIC_PATH = None
19
+ # validation set config
20
+ _CN.DATASET.VAL_DATA_ROOT = None
21
+ _CN.DATASET.VAL_POSE_ROOT = None
22
+ _CN.DATASET.VAL_NPZ_ROOT = None
23
+ _CN.DATASET.VAL_LIST_PATH = None
24
+ _CN.DATASET.VAL_INTRINSIC_PATH = None
25
+
26
+ # testing data config
27
+ _CN.DATASET.TEST_DATA_ROOT = None
28
+ _CN.DATASET.TEST_POSE_ROOT = None
29
+ _CN.DATASET.TEST_NPZ_ROOT = None
30
+ _CN.DATASET.TEST_LIST_PATH = None
31
+ _CN.DATASET.TEST_INTRINSIC_PATH = None
32
+
33
+ # dataset config
34
+ _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4
35
+ _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
36
+
37
+ cfg = _CN
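The override rule stated in the docstring above can be seen with the demo's own merge order (defaults, then the main EDM config, then the data config): a key set in a later file wins. A small illustrative check, assuming the repository layout used by app.py; the printed values come from the configs in this commit.

# Illustration of the config merge order used in app.py (data config merged last).
from src.config.default import get_cfg_defaults

config = get_cfg_defaults()                                     # EDM.COARSE.MCONF_THR == 0.2, EDM.NECK.NPE is None
config.merge_from_file("configs/edm/outdoor/edm_base.py")       # main config: MCONF_THR becomes 0.05
config.merge_from_file("configs/data/megadepth_test_1500.py")   # data config: NPE becomes [832, 832, 832, 832]
print(config.EDM.COARSE.MCONF_THR, config.EDM.NECK.NPE)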
configs/data/megadepth_test_1500.py ADDED
@@ -0,0 +1,23 @@
1
+ from configs.data.base import cfg
2
+
3
+ TEST_BASE_PATH = "assets/megadepth_test_1500_scene_info"
4
+
5
+ cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
6
+ cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
7
+ cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}"
8
+ cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt"
9
+
10
+ cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
11
+
12
+ cfg.EDM.TRAIN_RES_H = 832
13
+ cfg.EDM.TRAIN_RES_W = 832
14
+ cfg.EDM.TEST_RES_H = 832
15
+ cfg.EDM.TEST_RES_W = 832
16
+
17
+ cfg.EDM.NECK.NPE = [
18
+ cfg.EDM.TRAIN_RES_H,
19
+ cfg.EDM.TRAIN_RES_W,
20
+ cfg.EDM.TEST_RES_H,
21
+ cfg.EDM.TEST_RES_W,
22
+ ]
23
+
configs/edm/outdoor/edm_base.py ADDED
@@ -0,0 +1,17 @@
1
+ from src.config.default import _CN as cfg
2
+
3
+ cfg.TRAINER.CANONICAL_BS = 4 * 8
4
+ cfg.TRAINER.CANONICAL_LR = 2e-3
5
+ cfg.TRAINER.WARMUP_STEP = int(36800 / cfg.TRAINER.CANONICAL_BS * 3) # 3 epochs
6
+ cfg.TRAINER.WARMUP_RATIO = 0.1
7
+ cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]
8
+ cfg.TRAINER.EPI_ERR_THR = 1e-4
9
+
10
+ cfg.EDM.COARSE.MCONF_THR = 0.05
11
+ cfg.EDM.FINE.SIGMA_THR = 1e-6
12
+ cfg.EDM.COARSE.BORDER_RM = 0
13
+
14
+ # Top-K should not exceed grid_size = TEST_RES_H / 8 * TEST_RES_W / 8
15
+ # The recommended value is approximately grid_size * 0.35 for Megadepth
16
+ # cfg.EDM.COARSE.TOPK = int(832 / 8 * 832 / 8 * 0.35) # 3786 for train & LO-RANSAC test
17
+ cfg.EDM.COARSE.TOPK = int(1152 / 8 * 1152 / 8 * 0.35) # 7258 for test
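As a quick sanity check on the Top-K guideline in the comments above (just arithmetic, not from the repository): the coarse grid has (H/8) x (W/8) cells, so 832 x 832 gives 104 x 104 = 10816 cells and 1152 x 1152 gives 144 x 144 = 20736 cells; 35% of those is roughly the 3786 and 7258 referenced above.

# Worked example of the grid_size * 0.35 rule of thumb.
grid_832 = (832 // 8) * (832 // 8)     # 104 * 104 = 10816 coarse cells (demo / training resolution)
grid_1152 = (1152 // 8) * (1152 // 8)  # 144 * 144 = 20736 coarse cells (test resolution)
print(round(grid_832 * 0.35), round(grid_1152 * 0.35))  # 3786, 7258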
requirements.txt ADDED
@@ -0,0 +1,22 @@
1
+ torch==2.0.0  # pip package name for PyTorch 2.0.0
2
+ torchvision==0.15.0
3
+ torchaudio==2.0.0
4
+ opencv-python
5
+ albumentations==0.5.1 --no-binary=imgaug,albumentations
6
+ ray>=1.0.1
7
+ einops==0.3.0
8
+ kornia==0.4.1
9
+ loguru==0.5.3
10
+ yacs>=0.1.8
11
+ tqdm
12
+ autopep8
13
+ pylint
14
+ ipython
15
+ jupyterlab
16
+ matplotlib
17
+ h5py==3.1.0
18
+ pytorch-lightning==1.3.5
19
+ torchmetrics==0.7.0
20
+ joblib>=1.0.1
21
+ pillow==9.5.0
22
+ poselib
src/__init__.py ADDED
File without changes
src/config/default.py ADDED
@@ -0,0 +1,184 @@
1
+ from yacs.config import CfgNode as CN
2
+
3
+ _CN = CN()
4
+
5
+ ############## ↓ EDM Pipeline ↓ ##############
6
+ _CN.EDM = CN()
7
+ _CN.EDM.TRAIN_RES_H = 832
8
+ _CN.EDM.TRAIN_RES_W = 832
9
+ _CN.EDM.TEST_RES_H = 832
10
+ _CN.EDM.TEST_RES_W = 832
11
+ _CN.EDM.LOCAL_RESOLUTION = 8 # coarse matching at 1/8, window_size 8x8 in fine_level
12
+ _CN.EDM.MP = False # Do not use mixed precision (ELoFTR used mixed precision by default)
13
+ _CN.EDM.HALF = False # Do not use FP16
14
+ _CN.EDM.DEPLOY = False # export onnx model
15
+ _CN.EDM.EVAL_TIMES = 1
16
+
17
+ # 1. Backbone config
18
+ _CN.EDM.BACKBONE = CN()
19
+ _CN.EDM.BACKBONE.BLOCK_DIMS = [32, 64, 128, 256, 256] # 1/2 -> 1/32
20
+
21
+ # 2. Neck (CIM) config
22
+ _CN.EDM.NECK = CN()
23
+ _CN.EDM.NECK.D_MODEL = 256
24
+ _CN.EDM.NECK.NHEAD = 8
25
+ _CN.EDM.NECK.LAYER_NAMES = ["self", "cross"] * 2
26
+ _CN.EDM.NECK.AGG_SIZE0 = 1
27
+ _CN.EDM.NECK.AGG_SIZE1 = 1
28
+ _CN.EDM.NECK.ROPE = True
29
+ _CN.EDM.NECK.NPE = None
30
+
31
+ # 3. Coarse-Matching config
32
+ _CN.EDM.COARSE = CN()
33
+ _CN.EDM.COARSE.MCONF_THR = 0.2
34
+ _CN.EDM.COARSE.BORDER_RM = 0
35
+ _CN.EDM.COARSE.DSMAX_TEMPERATURE = 0.1
36
+ _CN.EDM.COARSE.TRAIN_PAD_NUM = 32 # training tricks: avoid DDP deadlock
37
+ _CN.EDM.COARSE.TOPK = 2048
38
+ _CN.EDM.COARSE.DS_OPT = True
39
+
40
+ # 4. EDM-fine module config
41
+ _CN.EDM.FINE = CN()
42
+ _CN.EDM.FINE.DROPRATE = None
43
+ _CN.EDM.FINE.COORD_LENGTH = 16
44
+ _CN.EDM.FINE.BI_DIRECTIONAL_REFINE = True
45
+ _CN.EDM.FINE.SIGMA_THR = 0.0
46
+ _CN.EDM.FINE.SIGMA_SELECTION = True
47
+
48
+ # 5. EDM Losses
49
+ # -- # coarse-level
50
+ _CN.EDM.LOSS = CN()
51
+ _CN.EDM.LOSS.COARSE_TYPE = "focal"
52
+ _CN.EDM.LOSS.COARSE_WEIGHT = 1.0
53
+ _CN.EDM.LOSS.SPARSE_SPVS = True
54
+
55
+ # -- - -- # focal loss (coarse)
56
+ _CN.EDM.LOSS.FOCAL_ALPHA = 0.25
57
+ _CN.EDM.LOSS.FOCAL_GAMMA = 2.0
58
+ _CN.EDM.LOSS.POS_WEIGHT = 1.0
59
+ _CN.EDM.LOSS.NEG_WEIGHT = 1.0
60
+
61
+
62
+ # -- # fine-level
63
+ _CN.EDM.LOSS.FINE_TYPE = "rle"
64
+ _CN.EDM.LOSS.FINE_WEIGHT = 0.2
65
+ _CN.EDM.LOSS.Q_DISTRIBUTION = "laplace" # options: ['laplace', 'gaussian']
66
+
67
+
68
+ ############## Dataset ##############
69
+ _CN.DATASET = CN()
70
+ # 1. data config
71
+ # training and validating
72
+ _CN.DATASET.TRAINVAL_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth']
73
+ _CN.DATASET.TRAIN_DATA_ROOT = None
74
+ _CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses)
75
+ _CN.DATASET.TRAIN_NPZ_ROOT = None
76
+ _CN.DATASET.TRAIN_LIST_PATH = None
77
+ _CN.DATASET.TRAIN_INTRINSIC_PATH = None
78
+ _CN.DATASET.VAL_DATA_ROOT = None
79
+ _CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses)
80
+ _CN.DATASET.VAL_NPZ_ROOT = None
81
+ _CN.DATASET.VAL_LIST_PATH = (
82
+ None # None if val data from all scenes are bundled into a single npz file
83
+ )
84
+ _CN.DATASET.VAL_INTRINSIC_PATH = None
85
+ # testing
86
+ _CN.DATASET.TEST_DATA_SOURCE = None
87
+ _CN.DATASET.TEST_DATA_ROOT = None
88
+ _CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses)
89
+ _CN.DATASET.TEST_NPZ_ROOT = None
90
+ _CN.DATASET.TEST_LIST_PATH = (
91
+ None # None if test data from all scenes are bundled into a single npz file
92
+ )
93
+ _CN.DATASET.TEST_INTRINSIC_PATH = None
94
+
95
+ # 2. dataset config
96
+ # general options
97
+ _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = (
98
+ 0.4 # discard data with overlap_score < min_overlap_score
99
+ )
100
+ _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
101
+ _CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile']
102
+
103
+ # MegaDepth options
104
+ _CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE
105
+ _CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000
106
+ _CN.DATASET.MGDPT_DF = 8
107
+
108
+
109
+ ############## Trainer ##############
110
+ _CN.TRAINER = CN()
111
+ _CN.TRAINER.WORLD_SIZE = 1
112
+ _CN.TRAINER.CANONICAL_BS = 4 * 8
113
+ _CN.TRAINER.CANONICAL_LR = 2e-3
114
+ _CN.TRAINER.SCALING = None # this will be calculated automatically
115
+ _CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
116
+
117
+ # optimizer
118
+ _CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw]
119
+ _CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
120
+ _CN.TRAINER.ADAM_DECAY = 0.0 # ADAM: for adam
121
+ _CN.TRAINER.ADAMW_DECAY = 0.1
122
+
123
+ # step-based warm-up
124
+ _CN.TRAINER.WARMUP_TYPE = "linear" # [linear, constant]
125
+ _CN.TRAINER.WARMUP_RATIO = 0.0
126
+ _CN.TRAINER.WARMUP_STEP = 4800
127
+
128
+ # learning rate scheduler
129
+ # [MultiStepLR, CosineAnnealing, ExponentialLR]
130
+ _CN.TRAINER.SCHEDULER = "MultiStepLR"
131
+ _CN.TRAINER.SCHEDULER_INTERVAL = "epoch" # [epoch, step]
132
+ _CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR
133
+ _CN.TRAINER.MSLR_GAMMA = 0.5
134
+ _CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
135
+ # ELR: ExponentialLR, this value for 'step' interval
136
+ _CN.TRAINER.ELR_GAMMA = 0.999992
137
+
138
+ # plotting related
139
+ _CN.TRAINER.ENABLE_PLOTTING = False
140
+ _CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test pairs for plotting
141
+ _CN.TRAINER.PLOT_MODE = "evaluation" # ['evaluation', 'confidence']
142
+ _CN.TRAINER.PLOT_MATCHES_ALPHA = "dynamic"
143
+
144
+ # geometric metrics and pose solver
145
+ # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
146
+ _CN.TRAINER.EPI_ERR_THR = 5e-4
147
+ _CN.TRAINER.POSE_GEO_MODEL = "E" # ['E', 'F', 'H']
148
+ _CN.TRAINER.POSE_ESTIMATION_METHOD = "RANSAC" # [RANSAC, LO-RANSAC]
149
+ _CN.TRAINER.RANSAC_PIXEL_THR = 0.5
150
+ _CN.TRAINER.RANSAC_CONF = 0.99999
151
+
152
+ # data sampler for train_dataloader
153
+ _CN.TRAINER.DATA_SAMPLER = (
154
+ "scene_balance" # options: ['scene_balance', 'random', 'normal']
155
+ )
156
+ # 'scene_balance' config
157
+ _CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
158
+ # whether sample each scene with replacement or not
159
+ _CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True
160
+ # after sampling from scenes, whether shuffle within the epoch or not
161
+ _CN.TRAINER.SB_SUBSET_SHUFFLE = True
162
+ _CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data
163
+ # 'random' config
164
+ _CN.TRAINER.RDM_REPLACEMENT = True
165
+ _CN.TRAINER.RDM_NUM_SAMPLES = None
166
+
167
+ # gradient clipping
168
+ _CN.TRAINER.GRADIENT_CLIPPING = 0.5
169
+
170
+ # reproducibility
171
+ # This seed affects the data sampling. With the same seed, the data sampling is guaranteed
172
+ # to be the same. When resuming training from a checkpoint, it's better to use a different
173
+ # seed; otherwise the sampled data will be exactly the same as before resuming, which leads
174
+ # to fewer unique data items being sampled over the entire training run.
175
+ # Different seed values might affect the final training result, since not all data items
176
+ # are used during training on ScanNet (60M image pairs are sampled during training out of 230M in total).
177
+ _CN.TRAINER.SEED = 66
178
+
179
+
180
+ def get_cfg_defaults():
181
+ """Get a yacs CfgNode object with default values for my_project."""
182
+ # Return a clone so that the defaults will not be altered
183
+ # This is for the "local variable" use pattern
184
+ return _CN.clone()
src/edm/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .edm import EDM
2
+
src/edm/backbone/resnet.py ADDED
@@ -0,0 +1,116 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def conv1x1(in_planes, out_planes, stride=1, bias=False):
5
+ """1x1 convolution without padding"""
6
+ return nn.Conv2d(
7
+ in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=bias
8
+ )
9
+
10
+
11
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, bias=False):
12
+ """3x3 convolution with padding"""
13
+ return nn.Conv2d(
14
+ in_planes,
15
+ out_planes,
16
+ kernel_size=3,
17
+ stride=stride,
18
+ padding=1,
19
+ groups=groups,
20
+ bias=bias,
21
+ )
22
+
23
+
24
+ class BasicBlock(nn.Module):
25
+ def __init__(self, in_planes, planes, stride=1):
26
+ super().__init__()
27
+ self.conv1 = conv3x3(in_planes, planes, stride)
28
+ self.conv2 = conv3x3(planes, planes)
29
+ self.bn1 = nn.BatchNorm2d(planes)
30
+ self.bn2 = nn.BatchNorm2d(planes)
31
+ self.relu = nn.ReLU(inplace=True)
32
+
33
+ if stride == 1:
34
+ self.downsample = None
35
+ else:
36
+ self.downsample = nn.Sequential(
37
+ conv1x1(in_planes, planes, stride=stride), nn.BatchNorm2d(planes)
38
+ )
39
+
40
+ def forward(self, x):
41
+ y = x
42
+ y = self.relu(self.bn1(self.conv1(y)))
43
+ y = self.bn2(self.conv2(y))
44
+
45
+ if self.downsample is not None:
46
+ x = self.downsample(x)
47
+
48
+ return self.relu(x + y)
49
+
50
+
51
+ class ResNet18(nn.Module):
52
+ """
53
+ A ResNet-18 style backbone with fewer channels.
54
+ """
55
+
56
+ def __init__(self, config=None):
57
+ super().__init__()
58
+ # Config
59
+ block_dims = config["backbone"]["block_dims"]
60
+
61
+ # Networks
62
+ self.conv1 = nn.Conv2d(
63
+ 1, block_dims[0], kernel_size=7, stride=2, padding=3, bias=False
64
+ )
65
+ self.bn1 = nn.BatchNorm2d(block_dims[0])
66
+ self.relu = nn.ReLU(inplace=True)
67
+ self.layer1 = self._make_layer(
68
+ BasicBlock, block_dims[0], block_dims[0], stride=1
69
+ ) # 1/2
70
+ self.layer2 = self._make_layer(
71
+ BasicBlock, block_dims[0], block_dims[1], stride=2
72
+ ) # 1/4
73
+ self.layer3 = self._make_layer(
74
+ BasicBlock, block_dims[1], block_dims[2], stride=2
75
+ ) # 1/8
76
+ self.layer4 = self._make_layer(
77
+ BasicBlock, block_dims[2], block_dims[3], stride=2
78
+ ) # 1/16
79
+ self.layer5 = self._make_layer(
80
+ BasicBlock, block_dims[3], block_dims[4], stride=2
81
+ ) # 1/32
82
+
83
+ # For fine matching
84
+ self.fine_conv = nn.Sequential(
85
+ self._make_layer(
86
+ BasicBlock, block_dims[2], block_dims[2], stride=1),
87
+ conv1x1(block_dims[2], block_dims[4]),
88
+ nn.BatchNorm2d(block_dims[4]),
89
+ )
90
+
91
+ for m in self.modules():
92
+ if isinstance(m, nn.Conv2d):
93
+ nn.init.kaiming_normal_(
94
+ m.weight, mode="fan_out", nonlinearity="relu")
95
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
96
+ nn.init.constant_(m.weight, 1)
97
+ nn.init.constant_(m.bias, 0)
98
+
99
+ def _make_layer(self, block, in_dim, out_dim, stride=1):
100
+ layer1 = block(in_dim, out_dim, stride=stride)
101
+ layer2 = block(out_dim, out_dim, stride=1)
102
+ layers = (layer1, layer2)
103
+
104
+ return nn.Sequential(*layers)
105
+
106
+ def forward(self, x):
107
+ x0 = self.relu(self.bn1(self.conv1(x)))
108
+ x1 = self.layer1(x0) # 1/2
109
+ x2 = self.layer2(x1) # 1/4
110
+ x3 = self.layer3(x2) # 1/8
111
+ x4 = self.layer4(x3) # 1/16
112
+ x5 = self.layer5(x4) # 1/32
113
+
114
+ xf = self.fine_conv(x3) # 1/8
115
+
116
+ return [x3, x4, x5, xf]
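To make the resolution comments above concrete, a small shape check; the block_dims are the ones from the default config, the input is random, and the import path follows this commit's layout, so treat it as a sketch rather than a test from the repository.

# Shape sketch: backbone outputs for an 832 x 832 grayscale input.
import torch
from src.edm.backbone.resnet import ResNet18

cfg = {"backbone": {"block_dims": [32, 64, 128, 256, 256]}}
net = ResNet18(cfg).eval()
with torch.no_grad():
    x3, x4, x5, xf = net(torch.randn(1, 1, 832, 832))
# x3: [1, 128, 104, 104] (1/8), x4: [1, 256, 52, 52] (1/16),
# x5: [1, 256, 26, 26] (1/32), xf: [1, 256, 104, 104] (fine, 1/8)
print(x3.shape, x4.shape, x5.shape, xf.shape)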
src/edm/edm.py ADDED
@@ -0,0 +1,204 @@
1
+ from ..utils.misc import detect_NaN
2
+ from .head.fine_matching import FineMatching
3
+ from .head.coarse_matching import CoarseMatching
4
+ from .neck.neck import CIM
5
+ from .backbone.resnet import ResNet18
6
+ from einops.einops import rearrange
7
+ import torch.nn.functional as F
8
+ import torch.nn as nn
9
+ import torch
10
+ torch.set_float32_matmul_precision("highest")  # options: highest (default), high, medium
11
+
12
+
13
+ class EDM(nn.Module):
14
+ def __init__(self, config):
15
+ super().__init__()
16
+ # Misc
17
+ self.config = config
18
+ self.local_resolution = self.config["local_resolution"]
19
+ self.bi_directional_refine = self.config["fine"]["bi_directional_refine"]
20
+ self.deploy = self.config["deploy"]
21
+ self.topk = config["coarse"]["topk"]
22
+
23
+ # Modules
24
+ self.backbone = ResNet18(config)
25
+ self.neck = CIM(config)
26
+ self.coarse_matching = CoarseMatching(config)
27
+ self.fine_matching = FineMatching(config)
28
+
29
+ def forward(self, data):
30
+ """
31
+ Update:
32
+ data (dict): {
33
+ 'image0': (torch.Tensor): (N, 1, H, W)
34
+ 'image1': (torch.Tensor): (N, 1, H, W)
35
+ 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
36
+ 'mask1'(optional) : (torch.Tensor): (N, H, W)
37
+ }
38
+ """
39
+ if self.deploy:
40
+ image0, image1 = data.split(1, 1)
41
+ data = {"image0": image0, "image1": image1}
42
+
43
+ data.update(
44
+ {
45
+ "bs": data["image0"].size(0),
46
+ "hw0_i": data["image0"].shape[2:],
47
+ "hw1_i": data["image1"].shape[2:],
48
+ }
49
+ )
50
+
51
+ # 1. Feature Extraction
52
+ if data["hw0_i"] == data["hw1_i"]:
53
+ # faster & better BN convergence
54
+ feats = self.backbone(
55
+ torch.cat([data["image0"], data["image1"]], dim=0))
56
+ f8, f16, f32, f8_fine = feats
57
+ ms_feats = f8, f16, f32
58
+ feat_f0, feat_f1 = f8_fine.chunk(2)
59
+ else:
60
+ # handle different input shapes
61
+ # raise ValueError("image0 and image1 should have the same shape.")
62
+ feats0, feats1 = self.backbone(data["image0"]), self.backbone(
63
+ data["image1"]
64
+ )
65
+ f8_0, f16_0, f32_0, feat_f0 = feats0
66
+ f8_1, f16_1, f32_1, feat_f1 = feats1
67
+ ms_feats = f8_0, f16_0, f32_0, f8_1, f16_1, f32_1
68
+
69
+ mask_c0 = mask_c1 = None # mask is useful in training
70
+ if "mask0" in data:
71
+ mask_c0, mask_c1 = data["mask0"], data["mask1"]
72
+
73
+ # 2. Feature Interaction & Multi-Scale Fusion
74
+ feat_c0, feat_c1 = self.neck(ms_feats, mask_c0, mask_c1)
75
+
76
+ data.update(
77
+ {
78
+ "hw0_c": feat_c0.shape[2:],
79
+ "hw1_c": feat_c1.shape[2:],
80
+ "hw0_f": feat_c0.shape[2:] * self.config["local_resolution"],
81
+ "hw1_f": feat_c1.shape[2:] * self.config["local_resolution"],
82
+ }
83
+ )
84
+ feat_c0 = rearrange(feat_c0, "n c h w -> n (h w) c")
85
+ feat_c1 = rearrange(feat_c1, "n c h w -> n (h w) c")
86
+ feat_f0 = rearrange(feat_f0, "n c h w -> n (h w) c")
87
+ feat_f1 = rearrange(feat_f1, "n c h w -> n (h w) c")
88
+
89
+ # detect NaN during mixed precision training
90
+ if self.config["mp"] and (
91
+ torch.any(torch.isnan(feat_c0)) or torch.any(torch.isnan(feat_c1))
92
+ ):
93
+ detect_NaN(feat_c0, feat_c1)
94
+
95
+ # 3. Coarse-Level Matching
96
+ conf_matrix = self.coarse_matching(
97
+ feat_c0,
98
+ feat_c1,
99
+ data,
100
+ mask_c0=(
101
+ mask_c0.view(mask_c0.size(0), -
102
+ 1) if mask_c0 is not None else mask_c0
103
+ ),
104
+ mask_c1=(
105
+ mask_c1.view(mask_c1.size(0), -
106
+ 1) if mask_c1 is not None else mask_c1
107
+ ),
108
+ )
109
+
110
+ if self.deploy:
111
+ k = self.topk
112
+ row_max_val, row_max_idx = torch.max(conf_matrix, dim=2)
113
+ topk_val, topk_idx = torch.topk(row_max_val, k, dim=1)
114
+
115
+ b_ids = (
116
+ torch.arange(conf_matrix.shape[0], device=conf_matrix.device)
117
+ .unsqueeze(1)
118
+ .repeat(1, k)
119
+ .flatten()
120
+ )
121
+ i_ids = topk_idx.flatten()
122
+ j_ids = row_max_idx[b_ids, i_ids].flatten()
123
+ mconf = conf_matrix[b_ids, i_ids, j_ids]
124
+
125
+ scale = data["hw0_i"][0] / data["hw0_c"][0]
126
+ scale0 = scale * \
127
+ data["scale0"][b_ids] if "scale0" in data else scale
128
+ scale1 = scale * \
129
+ data["scale1"][b_ids] if "scale1" in data else scale
130
+ mkpts0_c = (
131
+ torch.stack(
132
+ [
133
+ i_ids % data["hw0_c"][1],
134
+ torch.div(i_ids, data["hw0_c"][1],
135
+ rounding_mode="floor"),
136
+ ],
137
+ dim=1,
138
+ )
139
+ * scale0
140
+ )
141
+ mkpts1_c = (
142
+ torch.stack(
143
+ [
144
+ j_ids % data["hw1_c"][1],
145
+ torch.div(j_ids, data["hw1_c"][1],
146
+ rounding_mode="floor"),
147
+ ],
148
+ dim=1,
149
+ )
150
+ * scale1
151
+ )
152
+
153
+ data.update(
154
+ {
155
+ "mconf": mconf,
156
+ "mkpts0_c": mkpts0_c,
157
+ "mkpts1_c": mkpts1_c,
158
+ "b_ids": b_ids,
159
+ "i_ids": i_ids,
160
+ "j_ids": j_ids,
161
+ }
162
+ )
163
+
164
+ # 4. Fine-Level Matching
165
+ K0 = data["i_ids"].shape[0] // data["bs"]
166
+ K1 = data["j_ids"].shape[0] // data["bs"]
167
+ feat_f0 = feat_f0[data["b_ids"], data["i_ids"]
168
+ ].reshape(data["bs"], K0, -1)
169
+ feat_f1 = feat_f1[data["b_ids"], data["j_ids"]
170
+ ].reshape(data["bs"], K1, -1)
171
+ feat_c0 = feat_c0[data["b_ids"], data["i_ids"]
172
+ ].reshape(data["bs"], K0, -1)
173
+ feat_c1 = feat_c1[data["b_ids"], data["j_ids"]
174
+ ].reshape(data["bs"], K1, -1)
175
+
176
+ if self.bi_directional_refine:
177
+ # Bidirectional Refinement
178
+ offset, score = self.fine_matching(
179
+ torch.cat([feat_f0, feat_f1], dim=1),
180
+ torch.cat([feat_f1, feat_f0], dim=1),
181
+ torch.cat([feat_c0, feat_c1], dim=1),
182
+ torch.cat([feat_c1, feat_c0], dim=1),
183
+ data,
184
+ )
185
+ else:
186
+ offset, score = self.fine_matching(
187
+ feat_f0, feat_f1, feat_c0, feat_c1, data)
188
+
189
+ if self.deploy:
190
+ if self.bi_directional_refine:
191
+ fine_offset01, fine_offset10 = offset.chunk(2)
192
+ fine_score01, fine_score10 = score.unsqueeze(dim=1).chunk(2)
193
+ output = torch.cat(
194
+ [mkpts0_c, mkpts1_c, fine_offset01, fine_offset10, fine_score01, fine_score10, mconf.unsqueeze(dim=1)], 1) # [K, 11]
195
+ else:
196
+ output = torch.cat(
197
+ [mkpts0_c, mkpts1_c, offset, score, mconf.unsqueeze(dim=1)], 1)
198
+ return output
199
+
200
+ def load_state_dict(self, state_dict, *args, **kwargs):
201
+ for k in list(state_dict.keys()):
202
+ if k.startswith("matcher."):
203
+ state_dict[k.replace("matcher.", "", 1)] = state_dict.pop(k)
204
+ return super().load_state_dict(state_dict, *args, **kwargs)
src/edm/head/coarse_matching.py ADDED
@@ -0,0 +1,152 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ INF = 1e9  # use -1e4 for fp16 matmul
7
+
8
+
9
+ class CoarseMatching(nn.Module):
10
+ def __init__(self, config):
11
+ super().__init__()
12
+ self.window_size = config["local_resolution"]
13
+ self.thr = config["coarse"]["mconf_thr"]
14
+ self.temperature = config["coarse"]["dsmax_temperature"]
15
+ self.ds_opt = config["coarse"]["ds_opt"]
16
+ self.pad_num = config["coarse"]["train_pad_num"]
17
+ self.topk = config["coarse"]["topk"]
18
+ self.deploy = config["deploy"]
19
+
20
+ def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
21
+ """
22
+ Args:
23
+ feat0 (torch.Tensor): [N, L, C]
24
+ feat1 (torch.Tensor): [N, S, C]
25
+ data (dict)
26
+ mask_c0 (torch.Tensor): [N, L] (optional)
27
+ mask_c1 (torch.Tensor): [N, S] (optional)
28
+ Update:
29
+ data (dict): {
30
+ 'b_ids' (torch.Tensor): [M'],
31
+ 'i_ids' (torch.Tensor): [M'],
32
+ 'j_ids' (torch.Tensor): [M'],
33
+ 'm_bids' (torch.Tensor): [M],
34
+ 'mkpts0_c' (torch.Tensor): [M, 2],
35
+ 'mkpts1_c' (torch.Tensor): [M, 2],
36
+ 'mconf' (torch.Tensor): [M]}
37
+ NOTE: M' != M during training.
38
+ """
39
+ # normalize
40
+ feat_c0, feat_c1 = map(
41
+ lambda feat: feat / feat.shape[-1] ** 0.5, [feat_c0, feat_c1]
42
+ )
43
+
44
+ with torch.autocast(enabled=False, device_type="cuda"):
45
+ sim_matrix = (
46
+ torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) /
47
+ self.temperature
48
+ )
49
+ del feat_c0, feat_c1
50
+ if mask_c0 is not None:
51
+ sim_matrix = sim_matrix.float().masked_fill(
52
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
53
+ -INF,
54
+ )
55
+
56
+ if not self.training and self.ds_opt:
57
+ # Alternative implementation of the dual-softmax operator for efficient inference
58
+ sim_matrix = torch.exp(sim_matrix)
59
+ conf_matrix = F.normalize(sim_matrix, p=1, dim=1) * F.normalize(
60
+ sim_matrix, p=1, dim=2
61
+ )
62
+ else:
63
+ # Native dual-softmax operator
64
+ conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
65
+
66
+ data.update(
67
+ {
68
+ "conf_matrix": conf_matrix,
69
+ }
70
+ )
71
+
72
+ if not self.deploy:
73
+ # predict coarse matches from conf_matrix
74
+ self.coarse_matching_selection(data)
75
+
76
+ return conf_matrix # Returning the sim_matrix can be faster, but it may reduce the accuracy.
77
+
78
+ # Static tensor shape for mini-batch inference.
79
+ @torch.no_grad()
80
+ def coarse_matching_selection(self, data):
81
+ """
82
+ Args:
83
+ data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
84
+ Returns:
85
+ coarse_matches (dict): {
86
+ 'b_ids' (torch.Tensor): [M'],
87
+ 'i_ids' (torch.Tensor): [M'],
88
+ 'j_ids' (torch.Tensor): [M'],
89
+ 'm_bids' (torch.Tensor): [M],
90
+ 'mkpts0_c' (torch.Tensor): [M, 2],
91
+ 'mkpts1_c' (torch.Tensor): [M, 2],
92
+ 'mconf' (torch.Tensor): [M]}
93
+ """
94
+ conf_matrix = data["conf_matrix"]
95
+
96
+ # mutual nearest
97
+ # mask = (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
98
+ # * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
99
+ # conf_matrix[~mask]=0
100
+
101
+ k = self.topk
102
+ row_max_val, row_max_idx = torch.max(conf_matrix, dim=2)
103
+
104
+ # prevent out of range
105
+ if k == -1 or k > row_max_val.shape[-1]:
106
+ k = row_max_val.shape[-1]
107
+
108
+ topk_val, topk_idx = torch.topk(row_max_val, k)
109
+ b_ids = (
110
+ torch.arange(conf_matrix.shape[0], device=conf_matrix.device)
111
+ .unsqueeze(1)
112
+ .repeat(1, k)
113
+ .flatten()
114
+ )
115
+ i_ids = topk_idx.flatten()
116
+ j_ids = row_max_idx[b_ids, i_ids].flatten()
117
+ mconf = conf_matrix[b_ids, i_ids, j_ids]
118
+
119
+ scale = data["hw0_i"][0] / data["hw0_c"][0]
120
+ scale0 = scale * data["scale0"][b_ids] if "scale0" in data else scale
121
+ scale1 = scale * data["scale1"][b_ids] if "scale1" in data else scale
122
+ mkpts0_c = (
123
+ torch.stack(
124
+ [
125
+ i_ids % data["hw0_c"][1],
126
+ torch.div(i_ids, data["hw0_c"][1], rounding_mode="floor"),
127
+ ],
128
+ dim=1,
129
+ )
130
+ * scale0
131
+ )
132
+ mkpts1_c = (
133
+ torch.stack(
134
+ [
135
+ j_ids % data["hw1_c"][1],
136
+ torch.div(j_ids, data["hw1_c"][1], rounding_mode="floor"),
137
+ ],
138
+ dim=1,
139
+ )
140
+ * scale1
141
+ )
142
+
143
+ data.update(
144
+ {
145
+ "mconf": mconf,
146
+ "mkpts0_c": mkpts0_c,
147
+ "mkpts1_c": mkpts1_c,
148
+ "b_ids": b_ids,
149
+ "i_ids": i_ids,
150
+ "j_ids": j_ids,
151
+ }
152
+ )
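The "alternative implementation of the dual-softmax operator" used at inference time above is algebraically identical to the native form, since exponentiating and L1-normalizing along a dimension is exactly a softmax along that dimension. A quick sanity-check sketch (illustrative, not from the repository):

# The inference-time exp + L1-normalize trick equals softmax(sim, 1) * softmax(sim, 2).
import torch
import torch.nn.functional as F

sim = torch.randn(2, 5, 7)
native = F.softmax(sim, 1) * F.softmax(sim, 2)
fast = F.normalize(torch.exp(sim), p=1, dim=1) * F.normalize(torch.exp(sim), p=1, dim=2)
print(torch.allclose(native, fast, atol=1e-6))  # True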
src/edm/head/fine_matching.py ADDED
@@ -0,0 +1,383 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import distributions
5
+
6
+
7
+ class Conv1d_BN_Act(nn.Sequential):
8
+ def __init__(
9
+ self,
10
+ a,
11
+ b,
12
+ ks=1,
13
+ stride=1,
14
+ pad=0,
15
+ dilation=1,
16
+ groups=1,
17
+ bn_weight_init=1,
18
+ act=None,
19
+ drop=None,
20
+ ):
21
+ super().__init__()
22
+ self.inp_channel = a
23
+ self.out_channel = b
24
+ self.ks = ks
25
+ self.pad = pad
26
+ self.stride = stride
27
+ self.dilation = dilation
28
+ self.groups = groups
29
+
30
+ self.add_module(
31
+ "c", nn.Conv1d(a, b, ks, stride, pad, dilation, groups, bias=False)
32
+ )
33
+ bn = nn.BatchNorm1d(b)
34
+ nn.init.constant_(bn.weight, bn_weight_init)
35
+ nn.init.constant_(bn.bias, 0)
36
+ self.add_module("bn", bn)
37
+ if act != None:
38
+ self.add_module("a", act)
39
+ if drop != None:
40
+ self.add_module("d", nn.Dropout(drop))
41
+
42
+
43
+ class RealNVP(nn.Module):
44
+ """RealNVP: a flow-based generative model
45
+
46
+ `Density estimation using Real NVP
47
+ arXiv: <https://arxiv.org/abs/1605.08803>`_.
48
+
49
+ Code is modified from `the mmpose implementation of RLE
50
+ <https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/utils/realnvp.py>`_.
51
+ """
52
+
53
+ @staticmethod
54
+ def get_scale_net(channel):
55
+ """Get the scale model in a single invertable mapping."""
56
+ return nn.Sequential(
57
+ nn.Linear(2, channel),
58
+ nn.GELU(),
59
+ nn.Linear(channel, channel),
60
+ nn.GELU(),
61
+ nn.Linear(channel, 2),
62
+ nn.Tanh(),
63
+ )
64
+
65
+ @staticmethod
66
+ def get_trans_net(channel):
67
+ """Get the translation model in a single invertable mapping."""
68
+ return nn.Sequential(
69
+ nn.Linear(2, channel),
70
+ nn.GELU(),
71
+ nn.Linear(channel, channel),
72
+ nn.GELU(),
73
+ nn.Linear(channel, 2),
74
+ )
75
+
76
+ @property
77
+ def prior(self):
78
+ """The prior distribution."""
79
+ return distributions.MultivariateNormal(self.loc, self.cov)
80
+
81
+ def __init__(self, channel=64):
82
+ super(RealNVP, self).__init__()
83
+ self.channel = channel
84
+ self.register_buffer("loc", torch.zeros(2))
85
+ self.register_buffer("cov", torch.eye(2))
86
+ self.register_buffer(
87
+ "mask", torch.tensor([[0, 1], [1, 0]] * 3, dtype=torch.float32)
88
+ )
89
+
90
+ self.s = torch.nn.ModuleList(
91
+ [self.get_scale_net(self.channel) for _ in range(len(self.mask))]
92
+ )
93
+ self.t = torch.nn.ModuleList(
94
+ [self.get_trans_net(self.channel) for _ in range(len(self.mask))]
95
+ )
96
+ self.init_weights()
97
+
98
+ def init_weights(self):
99
+ """Initialization model weights."""
100
+ for m in self.modules():
101
+ if isinstance(m, nn.Linear):
102
+ nn.init.xavier_uniform_(m.weight, gain=0.01)
103
+
104
+ def backward_p(self, x):
105
+ """Apply mapping form the data space to the latent space and calculate
106
+ the log determinant of the Jacobian matrix."""
107
+
108
+ log_det_jacob, z = x.new_zeros(x.shape[0]), x
109
+ for i in reversed(range(len(self.t))):
110
+ z_ = self.mask[i] * z
111
+ s = self.s[i](z_) * (1 - self.mask[i]) # torch.exp(s): betas
112
+ t = self.t[i](z_) * (1 - self.mask[i]) # gammas
113
+ z = (1 - self.mask[i]) * (z - t) * torch.exp(-s) + z_
114
+ log_det_jacob -= s.sum(dim=1)
115
+ return z, log_det_jacob
116
+
117
+ def log_prob(self, x):
118
+ """Calculate the log probability of given sample in data space."""
119
+
120
+ z, log_det = self.backward_p(x)
121
+ return self.prior.log_prob(z) + log_det
122
+
123
+
124
+ def soft_argmax(x, temperature=1.0):
125
+ L = x.shape[1]
126
+ assert L % 2 == 1  # L must be odd to ensure symmetry
127
+ idx = torch.arange(0, L, 1, device=x.device).repeat(x.shape[0], 1)
128
+ scale_x = x / temperature
129
+ out = F.softmax(scale_x, dim=1) * idx
130
+ out = out.sum(dim=1, keepdim=True)
131
+
132
+ return out
133
+
134
+
135
+ class FineMatching(nn.Module):
136
+ def __init__(self, config, act_layer=nn.GELU):
137
+ super(FineMatching, self).__init__()
138
+ self.config = config
139
+
140
+ self.block_dims = self.config["backbone"]["block_dims"]
141
+ self.local_resolution = self.config["local_resolution"]
142
+ self.drop = self.config["fine"]["droprate"]
143
+ self.coord_length = self.config["fine"]["coord_length"]
144
+ self.bi_directional_refine = self.config["fine"]["bi_directional_refine"]
145
+ self.sigma_selection = self.config["fine"]["sigma_selection"]
146
+ self.mconf_thr = self.config["coarse"]["mconf_thr"]
147
+ self.sigma_thr = self.config["fine"]["sigma_thr"]
148
+ self.border_rm = self.config["coarse"]["border_rm"] * \
149
+ self.local_resolution
150
+ self.deploy = self.config["deploy"]
151
+
152
+ # network
153
+ self.query_encoder = nn.Sequential(
154
+ Conv1d_BN_Act(
155
+ self.block_dims[-1],
156
+ self.block_dims[-1],
157
+ act=act_layer(),
158
+ drop=self.drop,
159
+ ),
160
+ Conv1d_BN_Act(
161
+ self.block_dims[-1],
162
+ self.block_dims[-1],
163
+ act=act_layer(),
164
+ drop=self.drop,
165
+ ),
166
+ )
167
+
168
+ self.reference_encoder = nn.Sequential(
169
+ Conv1d_BN_Act(
170
+ self.block_dims[-1],
171
+ self.block_dims[-1],
172
+ act=act_layer(),
173
+ drop=self.drop,
174
+ ),
175
+ Conv1d_BN_Act(
176
+ self.block_dims[-1],
177
+ self.block_dims[-1],
178
+ act=act_layer(),
179
+ drop=self.drop,
180
+ ),
181
+ )
182
+
183
+ self.merge_qr = nn.Sequential(
184
+ Conv1d_BN_Act(
185
+ self.block_dims[-1] * 2,
186
+ self.block_dims[-1] * 2,
187
+ act=act_layer(),
188
+ drop=self.drop,
189
+ ),
190
+ Conv1d_BN_Act(
191
+ self.block_dims[-1] * 2,
192
+ self.block_dims[-1] * 2,
193
+ act=act_layer(),
194
+ drop=self.drop,
195
+ ),
196
+ )
197
+
198
+ self.x_head = nn.Sequential(
199
+ Conv1d_BN_Act(
200
+ self.block_dims[-1] * 2,
201
+ self.block_dims[-1] * 2,
202
+ act=act_layer(),
203
+ drop=self.drop,
204
+ ),
205
+ nn.Conv1d(self.block_dims[-1] * 2,
206
+ self.coord_length + 2, kernel_size=1),
207
+ )
208
+
209
+ self.y_head = nn.Sequential(
210
+ Conv1d_BN_Act(
211
+ self.block_dims[-1] * 2,
212
+ self.block_dims[-1] * 2,
213
+ act=act_layer(),
214
+ drop=self.drop,
215
+ ),
216
+ nn.Conv1d(self.block_dims[-1] * 2,
217
+ self.coord_length + 2, kernel_size=1),
218
+ )
219
+
220
+ self.flow = RealNVP()
221
+ self.init_params()
222
+
223
+ def init_params(self):
224
+ for m in self.modules():
225
+ if isinstance(m, nn.Conv1d):
226
+ nn.init.kaiming_normal_(m.weight)
227
+ if m.bias is not None:
228
+ nn.init.constant_(m.bias, 0)
229
+
230
+ def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data={}):
231
+ q = self.query_encoder(
232
+ feat_f0.permute(0, 2, 1).contiguous()
233
+ + feat_c0.permute(0, 2, 1).contiguous()
234
+ )
235
+ r = self.reference_encoder(
236
+ feat_f1.permute(0, 2, 1).contiguous()
237
+ + feat_c1.permute(0, 2, 1).contiguous()
238
+ )
239
+ out = self.merge_qr(torch.cat([q, r], dim=1))
240
+
241
+ x = self.x_head(out).permute(0, 2, 1).contiguous()
242
+ y = self.y_head(out).permute(0, 2, 1).contiguous()
243
+
244
+ if self.bi_directional_refine:
245
+ x01, x10 = x.chunk(2, dim=1)
246
+ x01 = x01.reshape(-1, self.coord_length + 2)
247
+ x10 = x10.reshape(-1, self.coord_length + 2)
248
+ x_out = torch.cat([x01, x10])
249
+
250
+ y01, y10 = y.chunk(2, dim=1)
251
+ y01 = y01.reshape(-1, self.coord_length + 2)
252
+ y10 = y10.reshape(-1, self.coord_length + 2)
253
+ y_out = torch.cat([y01, y10])
254
+ else:
255
+ x_out = x.reshape(-1, self.coord_length + 2)
256
+ y_out = y.reshape(-1, self.coord_length + 2)
257
+
258
+ x_cls = x_out[:, : self.coord_length + 1]
259
+ coord_x = soft_argmax(x_cls) / self.coord_length - \
260
+ 0.5 # range [-0.5, +0.5]
261
+ x_sigma = x_out[:, -1:].sigmoid()
262
+
263
+ y_cls = y_out[:, : self.coord_length + 1]
264
+ coord_y = soft_argmax(y_cls) / self.coord_length - 0.5
265
+ y_sigma = y_out[:, -1:].sigmoid()
266
+
267
+ coord = torch.cat([coord_x, coord_y], dim=1)
268
+ sigma = torch.cat([x_sigma, y_sigma], dim=1)
269
+
270
+ if data.get("target_uv", None) is not None:
271
+ gt_uv = data["target_uv"]
272
+ mask = data["target_uv_weight"].clone()
273
+
274
+ if mask.sum() == 0:
275
+ mask[0] = True
276
+ mask_coord = coord[mask]
277
+ mask_gt_uv = gt_uv[mask]
278
+ mask_sigma = sigma[mask]
279
+
280
+ mask_sigma = torch.clamp(mask_sigma, 1e-6, 1 - 1e-6)
281
+ bar_mu = (mask_coord - mask_gt_uv) / mask_sigma
282
+
283
+ log_phi = self.flow.log_prob(bar_mu).unsqueeze(-1)
284
+ nf_loss = torch.log(mask_sigma) - log_phi
285
+
286
+ data.update(
287
+ {
288
+ "pred_coord": coord,
289
+ "pred_score": 1.0 - torch.mean(sigma, dim=-1).flatten(),
290
+ "mask_coord": mask_coord,
291
+ "mask_sigma": mask_sigma,
292
+ "nf_loss": nf_loss,
293
+ }
294
+ )
295
+ else:
296
+ data.update(
297
+ {
298
+ "pred_coord": coord,
299
+ "pred_score": 1.0 - torch.mean(sigma, dim=-1).flatten(),
300
+ }
301
+ )
302
+
303
+ if not self.deploy:
304
+ self.final_matching_selection(data)
305
+
306
+ return data["pred_coord"], data["pred_score"]
307
+
308
+ @torch.no_grad()
309
+ def final_matching_selection(self, data):
310
+ offset = data["pred_coord"] * self.local_resolution
311
+
312
+ if self.bi_directional_refine:
313
+ fine_offset01, fine_offset10 = torch.clamp(
314
+ offset, -self.local_resolution / 2, self.local_resolution / 2
315
+ ).chunk(2)
316
+ else:
317
+ fine_offset01 = torch.clamp(
318
+ offset, -self.local_resolution / 2, self.local_resolution / 2
319
+ )
320
+
321
+ h0, w0 = data["hw0_i"]
322
+ h1, w1 = data["hw1_i"]
323
+ scale0 = data["scale0"][data["b_ids"]] if "scale0" in data else 1.0
324
+ scale1 = data["scale1"][data["b_ids"]] if "scale1" in data else 1.0
325
+ scale0_w = scale0[:, 0] if "scale0" in data else 1.0
326
+ scale0_h = scale0[:, 1] if "scale0" in data else 1.0
327
+ scale1_w = scale1[:, 0] if "scale1" in data else 1.0
328
+ scale1_h = scale1[:, 1] if "scale1" in data else 1.0
329
+
330
+ # Filter by mconf and border
331
+ mkpts0_f = data["mkpts0_c"]
332
+ mkpts1_f = data["mkpts1_c"] + fine_offset01 * scale1
333
+ mask = (
334
+ (data["mconf"] > self.mconf_thr)
335
+ & (mkpts0_f[:, 0] >= self.border_rm)
336
+ & (mkpts0_f[:, 0] <= w0 * scale0_w - self.border_rm)
337
+ & (mkpts0_f[:, 1] >= self.border_rm)
338
+ & (mkpts0_f[:, 1] <= h0 * scale0_h - self.border_rm)
339
+ & (mkpts1_f[:, 0] >= self.border_rm)
340
+ & (mkpts1_f[:, 0] <= w1 * scale1_w - self.border_rm)
341
+ & (mkpts1_f[:, 1] >= self.border_rm)
342
+ & (mkpts1_f[:, 1] <= h1 * scale1_h - self.border_rm)
343
+ )
344
+ if self.bi_directional_refine:
345
+ mkpts0_f_ = data["mkpts0_c"] + fine_offset10 * scale0
346
+ mkpts1_f_ = data["mkpts1_c"]
347
+ mask_ = (
348
+ (data["mconf"] > self.mconf_thr)
349
+ & (mkpts0_f_[:, 0] >= self.border_rm)
350
+ & (mkpts0_f_[:, 0] <= w0 * scale0_w - self.border_rm)
351
+ & (mkpts0_f_[:, 1] >= self.border_rm)
352
+ & (mkpts0_f_[:, 1] <= h0 * scale0_h - self.border_rm)
353
+ & (mkpts1_f_[:, 0] >= self.border_rm)
354
+ & (mkpts1_f_[:, 0] <= w1 * scale1_w - self.border_rm)
355
+ & (mkpts1_f_[:, 1] >= self.border_rm)
356
+ & (mkpts1_f_[:, 1] <= h1 * scale1_h - self.border_rm)
357
+ )
358
+
359
+ if self.bi_directional_refine:
360
+ mkpts0_f = torch.cat([mkpts0_f, mkpts0_f_])
361
+ mkpts1_f = torch.cat([mkpts1_f, mkpts1_f_])
362
+ mask = torch.cat([mask, mask_])
363
+ data["mconf"] = torch.cat([data["mconf"], data["mconf"]])
364
+ data["b_ids"] = torch.cat([data["b_ids"], data["b_ids"]])
365
+
366
+ # Filter by sigma
367
+ if self.bi_directional_refine and self.sigma_selection:
368
+ # Keep the more confident match of each bi-directional pair (smaller sigma, i.e. higher score)
369
+ pred_score01, pred_score10 = data["pred_score"].chunk(2)
370
+ pred_score_mask = pred_score01 > pred_score10
371
+ pred_score_mask = torch.cat([pred_score_mask, ~pred_score_mask])
372
+ pred_score_mask &= data["pred_score"] > self.sigma_thr
373
+ mask &= pred_score_mask
374
+
375
+ data.update(
376
+ {
377
+ # "gt_mask": data["mconf"] == 0,
378
+ "m_bids": data["b_ids"][mask],
379
+ "mkpts0_f": mkpts0_f[mask],
380
+ "mkpts1_f": mkpts1_f[mask],
381
+ "mconf": data["mconf"][mask],
382
+ }
383
+ )
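A small numeric illustration (not from the repository) of how the axis-based regression head above turns its classification logits into a sub-pixel offset: with COORD_LENGTH = 16 each axis predicts 17 logits, soft_argmax gives an expected index in [0, 16], which is mapped to [-0.5, 0.5] and later scaled by LOCAL_RESOLUTION = 8 into a pixel offset in [-4, +4].

# Toy version of the soft-argmax to sub-pixel-offset mapping used above.
import torch

logits = torch.zeros(1, 17)
logits[0, 12] = 10.0                       # a confident peak at index 12
idx = torch.arange(17, dtype=torch.float32)
expected = (torch.softmax(logits, dim=1) * idx).sum()  # soft-argmax, close to 12.0
coord = expected / 16 - 0.5                # close to 0.25, i.e. a quarter of a cell
print(coord.item(), (coord * 8).item())    # ~0.25 and ~2.0 px offset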
src/edm/neck/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .loftr_module import LocalFeatureTransformer
src/edm/neck/loftr_module/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .transformer import LocalFeatureTransformer
src/edm/neck/loftr_module/transformer.py ADDED
@@ -0,0 +1,418 @@
1
+ from torch.nn import Module
2
+ import copy
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops.einops import rearrange
7
+ import math
8
+ import os
9
+ import sys
10
+
11
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
12
+
13
+
14
+ class RoPEPositionEncodingSine(nn.Module):
15
+ """
16
+ A sinusoidal position encoding generalized to 2-dimensional images.
17
+ """
18
+
19
+ def __init__(self, d_model, max_shape=(128, 128), npe=None, ropefp16=True):
20
+ """
21
+ Args:
22
+ max_shape (tuple): for 1/32 featmap, the max length of 128 corresponds to 4096 pixels
23
+ """
24
+ super().__init__()
25
+
26
+ i_position = torch.ones(max_shape).cumsum(
27
+ 0).float().unsqueeze(-1) # [H, 1]
28
+ j_position = torch.ones(max_shape).cumsum(
29
+ 1).float().unsqueeze(-1) # [W, 1]
30
+
31
+ assert npe is not None
32
+ train_res_H, train_res_W, test_res_H, test_res_W = (
33
+ npe[0],
34
+ npe[1],
35
+ npe[2],
36
+ npe[3],
37
+ )
38
+ i_position, j_position = (
39
+ i_position * train_res_H / test_res_H,
40
+ j_position * train_res_W / test_res_W,
41
+ )
42
+
43
+ div_term = torch.exp(
44
+ torch.arange(0, d_model // 4, 1).float()
45
+ * (-math.log(10000.0) / (d_model // 4))
46
+ )
47
+ div_term = div_term[None, None, :] # [1, 1, C//4]
48
+
49
+ sin = torch.zeros(
50
+ *max_shape, d_model // 2, dtype=torch.float16 if ropefp16 else torch.float32
51
+ )
52
+ cos = torch.zeros(
53
+ *max_shape, d_model // 2, dtype=torch.float16 if ropefp16 else torch.float32
54
+ )
55
+
56
+ sin[:, :, 0::2] = (
57
+ torch.sin(i_position * div_term).half()
58
+ if ropefp16
59
+ else torch.sin(i_position * div_term)
60
+ )
61
+ sin[:, :, 1::2] = (
62
+ torch.sin(j_position * div_term).half()
63
+ if ropefp16
64
+ else torch.sin(j_position * div_term)
65
+ )
66
+ cos[:, :, 0::2] = (
67
+ torch.cos(i_position * div_term).half()
68
+ if ropefp16
69
+ else torch.cos(i_position * div_term)
70
+ )
71
+ cos[:, :, 1::2] = (
72
+ torch.cos(j_position * div_term).half()
73
+ if ropefp16
74
+ else torch.cos(j_position * div_term)
75
+ )
76
+
77
+ sin = sin.repeat_interleave(2, dim=-1)
78
+ cos = cos.repeat_interleave(2, dim=-1)
79
+
80
+ self.register_buffer(
81
+ "sin", sin.unsqueeze(0), persistent=False
82
+ ) # [1, H, W, C//2]
83
+ self.register_buffer(
84
+ "cos", cos.unsqueeze(0), persistent=False
85
+ ) # [1, H, W, C//2]
86
+
87
+ def forward(self, x, ratio=1):
88
+ """
89
+ Args:
90
+ x: [N, H, W, C]
91
+ """
92
+ return (x * self.cos[:, : x.size(1), : x.size(2), :]) + (
93
+ self.rotate_half(x) * self.sin[:, : x.size(1), : x.size(2), :]
94
+ )
95
+
96
+ def rotate_half(self, x):
97
+ # x = x.unflatten(-1, (-1, 2))
98
+ a, b, c, d = x.shape
99
+ x = x.reshape(a, b, c, d // 2, 2)
100
+
101
+ x1, x2 = x.unbind(dim=-1)
102
+ return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
103
+
104
+
105
+ """
106
+ Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
107
+ Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
108
+ """
109
+
110
+
111
+ def crop_feature(query, key, value, x_mask, source_mask):
112
+ mask_h0, mask_w0, mask_h1, mask_w1 = (
113
+ x_mask[0].sum(-2)[0],
114
+ x_mask[0].sum(-1)[0],
115
+ source_mask[0].sum(-2)[0],
116
+ source_mask[0].sum(-1)[0],
117
+ )
118
+ query = query[:, :mask_h0, :mask_w0, :]
119
+ key = key[:, :mask_h1, :mask_w1, :]
120
+ value = value[:, :mask_h1, :mask_w1, :]
121
+ return query, key, value, mask_h0, mask_w0
122
+
123
+
124
+ def pad_feature(m, mask_h0, mask_w0, x_mask):
125
+ bs, hw, nhead, dim = m.shape
126
+ m = m.view(bs, mask_h0, mask_w0, nhead, dim)
127
+ if mask_h0 != x_mask.size(-2):
128
+ m = torch.cat(
129
+ [
130
+ m,
131
+ torch.zeros(
132
+ m.size(0),
133
+ x_mask.size(-2) - mask_h0,
134
+ x_mask.size(-1),
135
+ nhead,
136
+ dim,
137
+ device=m.device,
138
+ dtype=m.dtype,
139
+ ),
140
+ ],
141
+ dim=1,
142
+ )
143
+ elif mask_w0 != x_mask.size(-1):
144
+ m = torch.cat(
145
+ [
146
+ m,
147
+ torch.zeros(
148
+ m.size(0),
149
+ x_mask.size(-2),
150
+ x_mask.size(-1) - mask_w0,
151
+ nhead,
152
+ dim,
153
+ device=m.device,
154
+ dtype=m.dtype,
155
+ ),
156
+ ],
157
+ dim=2,
158
+ )
159
+ return m
160
+
161
+
162
+ class Attention(Module):
163
+ def __init__(self, nhead=8, dim=256, re=False):
164
+ super().__init__()
165
+
166
+ self.nhead = nhead
167
+ self.dim = dim
168
+
169
+ def attention(self, query, key, value, q_mask=None, kv_mask=None):
170
+ assert (
171
+ q_mask is None and kv_mask is None
172
+ ), "Not support generalized attention mask yet."
173
+ # Scaled Cosine Attention
174
+ # Refer to "Query-key normalization for transformers" and "https://kexue.fm/archives/9859"
175
+ query = F.normalize(query, p=2, dim=3)
176
+ key = F.normalize(key, p=2, dim=3)
177
+ QK = torch.einsum("nlhd,nshd->nlsh", query, key)
178
+ s = 20.0
179
+ A = torch.softmax(s * QK, dim=2)
180
+
181
+ out = torch.einsum("nlsh,nshd->nlhd", A, value)
182
+ return out
183
+
184
+ def _forward(self, query, key, value, q_mask=None, kv_mask=None):
185
+ if q_mask is not None:
186
+ query, key, value, mask_h0, mask_w0 = crop_feature(
187
+ query, key, value, q_mask, kv_mask
188
+ )
189
+
190
+ query, key, value = map(
191
+ lambda x: rearrange(
192
+ x,
193
+ "n h w (nhead d) -> n (h w) nhead d",
194
+ nhead=self.nhead,
195
+ d=self.dim,
196
+ ),
197
+ [query, key, value],
198
+ )
199
+
200
+ m = self.attention(query, key, value, q_mask=None, kv_mask=None)
201
+
202
+ if q_mask is not None:
203
+ m = pad_feature(m, mask_h0, mask_w0, q_mask)
204
+
205
+ return m
206
+
207
+ def forward(self, query, key, value, q_mask=None, kv_mask=None):
208
+ """
209
+ Args:
210
+ queries: [N, L, H, D]
211
+ keys: [N, S, H, D]
212
+ values: [N, S, H, D]
213
+ q_mask: [N, L]
214
+ kv_mask: [N, S]
215
+ Returns:
216
+ queried_values: (N, L, H, D)
217
+ """
218
+ bs = query.size(0)
219
+ if bs == 1 or q_mask is None:
220
+ m = self._forward(query, key, value,
221
+ q_mask=q_mask, kv_mask=kv_mask)
222
+ else:  # for faster training with padding mask when batch size > 1
223
+ m_list = []
224
+ for i in range(bs):
225
+ m_list.append(
226
+ self._forward(
227
+ query[i: i + 1],
228
+ key[i: i + 1],
229
+ value[i: i + 1],
230
+ q_mask=q_mask[i: i + 1],
231
+ kv_mask=kv_mask[i: i + 1],
232
+ )
233
+ )
234
+ m = torch.cat(m_list, dim=0)
235
+ return m
236
+
237
+
238
+ class AG_RoPE_EncoderLayer(nn.Module):
239
+ def __init__(
240
+ self,
241
+ d_model,
242
+ nhead,
243
+ agg_size0=2,
244
+ agg_size1=2,
245
+ rope=False,
246
+ npe=None,
247
+ ):
248
+ super(AG_RoPE_EncoderLayer, self).__init__()
249
+ self.dim = d_model // nhead
250
+ self.nhead = nhead
251
+ self.agg_size0, self.agg_size1 = agg_size0, agg_size1
252
+ self.rope = rope
253
+
254
+ # aggregate and position encoding
255
+ self.aggregate = (
256
+ nn.Conv2d(
257
+ d_model,
258
+ d_model,
259
+ kernel_size=agg_size0,
260
+ padding=0,
261
+ stride=agg_size0,
262
+ bias=False,
263
+ groups=d_model,
264
+ )
265
+ if self.agg_size0 != 1
266
+ else nn.Identity()
267
+ )
268
+ self.max_pool = (
269
+ torch.nn.MaxPool2d(kernel_size=self.agg_size1,
270
+ stride=self.agg_size1)
271
+ if self.agg_size1 != 1
272
+ else nn.Identity()
273
+ )
274
+ self.mask_max_pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
275
+ self.rope_pos_enc = RoPEPositionEncodingSine(
276
+ d_model, max_shape=(128, 128), npe=npe, ropefp16=True
277
+ )
278
+
279
+ # multi-head attention
280
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
281
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
282
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
283
+ self.attention = Attention(self.nhead, self.dim)
284
+ self.merge = nn.Linear(d_model, d_model, bias=False)
285
+
286
+ # feed-forward network
287
+ self.mlp = nn.Sequential(
288
+ nn.Linear(d_model * 2, d_model * 2, bias=False),
289
+ nn.LeakyReLU(inplace=True),
290
+ nn.Linear(d_model * 2, d_model, bias=False),
291
+ )
292
+
293
+ # norm
294
+ self.norm1 = nn.LayerNorm(d_model)
295
+ self.norm2 = nn.LayerNorm(d_model)
296
+
297
+ def forward(self, x, source, x_mask=None, source_mask=None):
298
+ """
299
+ Args:
300
+ x (torch.Tensor): [N, C, H0, W0]
301
+ source (torch.Tensor): [N, C, H1, W1]
302
+ x_mask (torch.Tensor): [N, H0, W0] (optional) (L = H0*W0)
303
+ source_mask (torch.Tensor): [N, H1, W1] (optional) (S = H1*W1)
304
+ """
305
+ bs, C, H0, W0 = x.size()
306
+ H1, W1 = source.size(-2), source.size(-1)
307
+
308
+ # Aggregate feature
309
+ # assert x_mask is None and source_mask is None
310
+
311
+ query, source = self.norm1(self.aggregate(x).permute(0, 2, 3, 1)), self.norm1(
312
+ self.max_pool(source).permute(0, 2, 3, 1)
313
+ ) # [N, H, W, C]
314
+ if x_mask is not None:
315
+ # mask 1/8 to 1/32
316
+ x_mask, source_mask = map(
317
+ lambda x: self.mask_max_pool(
318
+ self.mask_max_pool(x.float())).bool(),
319
+ [x_mask, source_mask],
320
+ )
321
+ query, key, value = self.q_proj(
322
+ query), self.k_proj(source), self.v_proj(source)
323
+
324
+ # Positional encoding
325
+ if self.rope:
326
+ query = self.rope_pos_enc(query)
327
+ key = self.rope_pos_enc(key)
328
+
329
+ # multi-head attention handle padding mask
330
+ m = self.attention(query, key, value, q_mask=x_mask,
331
+ kv_mask=source_mask)
332
+ m = self.merge(m.reshape(bs, -1, self.nhead * self.dim)) # [N, L, C]
333
+
334
+ # Upsample feature
335
+ m = rearrange(
336
+ m, "b (h w) c -> b c h w", h=H0 // self.agg_size0, w=W0 // self.agg_size0
337
+ ) # [N, C, H0, W0]
338
+
339
+ if self.agg_size0 != 1:
340
+ m = torch.nn.functional.interpolate(
341
+ m, size=(H0, W0), mode="bilinear", align_corners=False
342
+ ) # [N, C, H0, W0]
343
+
344
+ # feed-forward network
345
+ m = self.mlp(torch.cat([x, m], dim=1).permute(
346
+ 0, 2, 3, 1)) # [N, H0, W0, C]
347
+ m = self.norm2(m).permute(0, 3, 1, 2) # [N, C, H0, W0]
348
+
349
+ return x + m
350
+
351
+
352
+ '''
353
+ Modified from EfficientLoFTR
354
+ '''
355
+ class LocalFeatureTransformer(nn.Module):
356
+ """A Local Feature Transformer (LoFTR) module."""
357
+
358
+ def __init__(self, config):
359
+ super(LocalFeatureTransformer, self).__init__()
360
+ self.d_model = config["d_model"]
361
+ self.nhead = config["nhead"]
362
+ self.layer_names = config["layer_names"]
363
+ self.agg_size0, self.agg_size1 = config["agg_size0"], config["agg_size1"]
364
+ self.rope = config["rope"]
365
+
366
+ self_layer = AG_RoPE_EncoderLayer(
367
+ config["d_model"],
368
+ config["nhead"],
369
+ config["agg_size0"],
370
+ config["agg_size1"],
371
+ config["rope"],
372
+ config["npe"],
373
+ )
374
+ cross_layer = AG_RoPE_EncoderLayer(
375
+ config["d_model"],
376
+ config["nhead"],
377
+ config["agg_size0"],
378
+ config["agg_size1"],
379
+ False,
380
+ config["npe"],
381
+ )
382
+
383
+ self.layers = nn.ModuleList(
384
+ [
385
+ (
386
+ copy.deepcopy(self_layer)
387
+ if _ == "self"
388
+ else copy.deepcopy(cross_layer)
389
+ )
390
+ for _ in self.layer_names
391
+ ]
392
+ )
393
+ self._reset_parameters()
394
+
395
+ def _reset_parameters(self):
396
+ for p in self.parameters():
397
+ if p.dim() > 1:
398
+ nn.init.xavier_uniform_(p)
399
+
400
+ def forward(self, feat0, feat1, mask0=None, mask1=None, data=None):
401
+ """
402
+ Args:
403
+ feat0 (torch.Tensor): [N, C, H, W]
404
+ feat1 (torch.Tensor): [N, C, H, W]
405
+ mask0 (torch.Tensor): [N, L] (optional)
406
+ mask1 (torch.Tensor): [N, S] (optional)
407
+ """
408
+ for i, (layer, name) in enumerate(zip(self.layers, self.layer_names)):
409
+ if name == "self":
410
+ feat0 = layer(feat0, feat0, mask0, mask0)
411
+ feat1 = layer(feat1, feat1, mask1, mask1)
412
+ elif name == "cross":
413
+ feat0 = layer(feat0, feat1, mask0, mask1)
414
+ feat1 = layer(feat1, feat0, mask1, mask0)
415
+ else:
416
+ raise KeyError
417
+
418
+ return feat0, feat1
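As a quick orientation for this file, here is a minimal sketch of driving LocalFeatureTransformer on its own. The config values below (feature dimension, head count, layer order, npe) are illustrative assumptions rather than EDM's shipped settings, and the sketch assumes the Attention backend defined above runs on the chosen device.

import torch

ltr_cfg = {
    "d_model": 256,                                      # assumed feature dimension
    "nhead": 8,
    "layer_names": ["self", "cross", "self", "cross"],   # alternating self/cross attention
    "agg_size0": 1,
    "agg_size1": 1,
    "rope": True,
    "npe": [832, 832, 832, 832],                         # assumed RoPE normalization sizes
}
ltr = LocalFeatureTransformer(ltr_cfg).eval()

feat0 = torch.randn(1, 256, 26, 26)   # e.g. 1/32-resolution features of image 0
feat1 = torch.randn(1, 256, 26, 26)   # e.g. 1/32-resolution features of image 1
with torch.no_grad():
    out0, out1 = ltr(feat0, feat1)    # shapes preserved: [1, 256, 26, 26]

With agg_size0 = agg_size1 = 1 the aggregation convolution and max pooling collapse to identities, so each layer attends at the input resolution; larger values reduce the resolution before attention and interpolate back afterwards, trading detail for speed, as in the encoder layer above.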
src/edm/neck/neck.py ADDED
@@ -0,0 +1,156 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .loftr_module.transformer import LocalFeatureTransformer
6
+
7
+
8
+ class Conv2d_BN_Act(nn.Sequential):
9
+ def __init__(
10
+ self,
11
+ a,
12
+ b,
13
+ ks=1,
14
+ stride=1,
15
+ pad=0,
16
+ dilation=1,
17
+ groups=1,
18
+ bn_weight_init=1,
19
+ act=None,
20
+ drop=None,
21
+ ):
22
+ super().__init__()
23
+ self.inp_channel = a
24
+ self.out_channel = b
25
+ self.ks = ks
26
+ self.pad = pad
27
+ self.stride = stride
28
+ self.dilation = dilation
29
+ self.groups = groups
30
+
31
+ self.add_module(
32
+ "c", nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)
33
+ )
34
+ bn = nn.BatchNorm2d(b)
35
+ nn.init.constant_(bn.weight, bn_weight_init)
36
+ nn.init.constant_(bn.bias, 0)
37
+ self.add_module("bn", bn)
38
+ if act is not None:
39
+ self.add_module("a", act)
40
+ if drop is not None:
41
+ self.add_module("d", nn.Dropout(drop))
42
+
43
+
44
+ class CIM(nn.Module):
45
+ """Feature Aggregation, Correlation Injection Module"""
46
+
47
+ def __init__(self, config):
48
+ super(CIM, self).__init__()
49
+
50
+ self.block_dims = config["backbone"]["block_dims"]
51
+ self.drop = config["fine"]["droprate"]
52
+
53
+ self.fc32 = Conv2d_BN_Act(
54
+ self.block_dims[-1], self.block_dims[-1], 1, drop=self.drop
55
+ )
56
+ self.fc16 = Conv2d_BN_Act(
57
+ self.block_dims[-2], self.block_dims[-1], 1, drop=self.drop
58
+ )
59
+ self.fc8 = Conv2d_BN_Act(
60
+ self.block_dims[-3], self.block_dims[-1], 1, drop=self.drop
61
+ )
62
+ self.att32 = Conv2d_BN_Act(
63
+ self.block_dims[-1],
64
+ self.block_dims[-1],
65
+ 1,
66
+ act=nn.Sigmoid(),
67
+ drop=self.drop,
68
+ )
69
+ self.att16 = Conv2d_BN_Act(
70
+ self.block_dims[-1],
71
+ self.block_dims[-1],
72
+ 1,
73
+ act=nn.Sigmoid(),
74
+ drop=self.drop,
75
+ )
76
+ self.dwconv16 = nn.Sequential(
77
+ Conv2d_BN_Act(
78
+ self.block_dims[-1],
79
+ self.block_dims[-1],
80
+ ks=3,
81
+ pad=1,
82
+ groups=self.block_dims[-1],
83
+ act=nn.GELU(),
84
+ ),
85
+ Conv2d_BN_Act(self.block_dims[-1], self.block_dims[-1], 1),
86
+ )
87
+ self.dwconv8 = nn.Sequential(
88
+ Conv2d_BN_Act(
89
+ self.block_dims[-1],
90
+ self.block_dims[-1],
91
+ ks=3,
92
+ pad=1,
93
+ groups=self.block_dims[-1],
94
+ act=nn.GELU(),
95
+ ),
96
+ Conv2d_BN_Act(self.block_dims[-1], self.block_dims[-1], 1),
97
+ )
98
+
99
+ self.loftr_32 = LocalFeatureTransformer(config["neck"])
100
+
101
+ def forward(self, ms_feats, mask_c0=None, mask_c1=None):
102
+ if len(ms_feats) == 3: # same image shape
103
+ f8, f16, f32 = ms_feats
104
+ f32 = self.fc32(f32)
105
+
106
+ f32_0, f32_1 = f32.chunk(2, dim=0)
107
+ f32_0, f32_1 = self.loftr_32(f32_0, f32_1, mask_c0, mask_c1)
108
+ f32 = torch.cat([f32_0, f32_1], dim=0)
109
+
110
+ f32_up = F.interpolate(f32, scale_factor=2.0, mode="bilinear")
111
+ att32_up = F.interpolate(self.att32(
112
+ f32), scale_factor=2.0, mode="bilinear")
113
+ f16 = self.fc16(f16)
114
+ f16 = self.dwconv16(f16 * att32_up + f32_up)
115
+ f16_up = F.interpolate(f16, scale_factor=2.0, mode="bilinear")
116
+ att16_up = F.interpolate(self.att16(
117
+ f16), scale_factor=2.0, mode="bilinear")
118
+ f8 = self.fc8(f8)
119
+ f8 = self.dwconv8(f8 * att16_up + f16_up)
120
+
121
+ feat_c0, feat_c1 = f8.chunk(2)
122
+
123
+ elif len(ms_feats) == 6: # different image shapes
124
+ f8_0, f16_0, f32_0, f8_1, f16_1, f32_1 = ms_feats
125
+ f32_0 = self.fc32(f32_0)
126
+ f32_1 = self.fc32(f32_1)
127
+
128
+ f32_0, f32_1 = self.loftr_32(f32_0, f32_1, mask_c0, mask_c1)
129
+
130
+ f8, f16, f32 = f8_0, f16_0, f32_0
131
+ f32_up = F.interpolate(f32, scale_factor=2.0, mode="bilinear")
132
+ att32_up = F.interpolate(self.att32(
133
+ f32), scale_factor=2.0, mode="bilinear")
134
+ f16 = self.fc16(f16)
135
+ f16 = self.dwconv16(f16 * att32_up + f32_up)
136
+ f16_up = F.interpolate(f16, scale_factor=2.0, mode="bilinear")
137
+ att16_up = F.interpolate(self.att16(
138
+ f16), scale_factor=2.0, mode="bilinear")
139
+ f8 = self.fc8(f8)
140
+ f8 = self.dwconv8(f8 * att16_up + f16_up)
141
+ feat_c0 = f8
142
+
143
+ f8, f16, f32 = f8_1, f16_1, f32_1
144
+ f32_up = F.interpolate(f32, scale_factor=2.0, mode="bilinear")
145
+ att32_up = F.interpolate(self.att32(
146
+ f32), scale_factor=2.0, mode="bilinear")
147
+ f16 = self.fc16(f16)
148
+ f16 = self.dwconv16(f16 * att32_up + f32_up)
149
+ f16_up = F.interpolate(f16, scale_factor=2.0, mode="bilinear")
150
+ att16_up = F.interpolate(self.att16(
151
+ f16), scale_factor=2.0, mode="bilinear")
152
+ f8 = self.fc8(f8)
153
+ f8 = self.dwconv8(f8 * att16_up + f16_up)
154
+ feat_c1 = f8
155
+
156
+ return feat_c0, feat_c1
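Before moving on to the utility files, a shape-level sketch of the CIM contract may help: 1/8, 1/16 and 1/32 backbone features go in, and fused 1/8-resolution coarse features come out, one tensor per image. The channel dimensions, drop rate and transformer settings below are assumptions for illustration only.

import torch

neck_cfg = {
    "d_model": 256, "nhead": 8, "layer_names": ["self", "cross"],
    "agg_size0": 1, "agg_size1": 1, "rope": True,
    "npe": [832, 832, 832, 832],                     # assumed
}
cim_cfg = {
    "backbone": {"block_dims": [64, 128, 256]},      # assumed 1/8, 1/16, 1/32 channel dims
    "fine": {"droprate": 0.0},
    "neck": neck_cfg,
}
cim = CIM(cim_cfg).eval()

# Both images stacked along the batch axis (the "same image shape" branch).
f8 = torch.randn(2, 64, 104, 104)     # 1/8 resolution
f16 = torch.randn(2, 128, 52, 52)     # 1/16 resolution
f32 = torch.randn(2, 256, 26, 26)     # 1/32 resolution
with torch.no_grad():
    feat_c0, feat_c1 = cim([f8, f16, f32])   # each [1, 256, 104, 104]

Only the deepest (1/32) features pass through the transformer; the 1/16 and 1/8 levels receive the transformed signal through the sigmoid attention maps and upsampled features before the final chunk per image.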
src/utils/misc.py ADDED
@@ -0,0 +1,103 @@
1
+ import os
2
+ import contextlib
3
+ import joblib
4
+ from typing import Union
5
+ from loguru import _Logger, logger
6
+ from itertools import chain
7
+
8
+ import torch
9
+ from yacs.config import CfgNode as CN
10
+
11
+
12
+
13
+ def lower_config(yacs_cfg):
14
+ if not isinstance(yacs_cfg, CN):
15
+ return yacs_cfg
16
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
17
+
18
+
19
+ def upper_config(dict_cfg):
20
+ if not isinstance(dict_cfg, dict):
21
+ return dict_cfg
22
+ return {k.upper(): upper_config(v) for k, v in dict_cfg.items()}
23
+
24
+
25
+ def log_on(condition, message, level):
26
+ if condition:
27
+ assert level in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"]
28
+ logger.log(level, message)
29
+
30
+
31
+
32
+ def setup_gpus(gpus: Union[str, int]) -> int:
33
+ """A temporary fix for pytorch-lightning 1.3.x"""
34
+ gpus = str(gpus)
35
+ gpu_ids = []
36
+
37
+ if "," not in gpus:
38
+ n_gpus = int(gpus)
39
+ return n_gpus if n_gpus != -1 else torch.cuda.device_count()
40
+ else:
41
+ gpu_ids = [i.strip() for i in gpus.split(",") if i != ""]
42
+
43
+ # setup environment variables
44
+ visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
45
+ if visible_devices is None:
46
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
47
+ os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
48
+ visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
49
+ logger.warning(
50
+ f"[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}"
51
+ )
52
+ else:
53
+ logger.warning(
54
+ "[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process."
55
+ )
56
+ return len(gpu_ids)
57
+
58
+
59
+ def flattenList(x):
60
+ return list(chain(*x))
61
+
62
+
63
+ @contextlib.contextmanager
64
+ def tqdm_joblib(tqdm_object):
65
+ """Context manager to patch joblib so that it reports into the tqdm progress bar given as an argument
66
+
67
+ Usage:
68
+ with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
69
+ Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
70
+
71
+ When iterating over a generator, using tqdm directly is also a solution (but it monitors task queuing instead of task completion)
72
+ ret_vals = Parallel(n_jobs=args.world_size)(
73
+ delayed(lambda x: _compute_cov_score(pid, *x))(param)
74
+ for param in tqdm(combinations(image_ids, 2),
75
+ desc=f'Computing cov_score of [{pid}]',
76
+ total=len(image_ids)*(len(image_ids)-1)/2))
77
+ Src: https://stackoverflow.com/a/58936697
78
+ """
79
+
80
+ class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
81
+ def __init__(self, *args, **kwargs):
82
+ super().__init__(*args, **kwargs)
83
+
84
+ def __call__(self, *args, **kwargs):
85
+ tqdm_object.update(n=self.batch_size)
86
+ return super().__call__(*args, **kwargs)
87
+
88
+ old_batch_callback = joblib.parallel.BatchCompletionCallBack
89
+ joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
90
+ try:
91
+ yield tqdm_object
92
+ finally:
93
+ joblib.parallel.BatchCompletionCallBack = old_batch_callback
94
+ tqdm_object.close()
95
+
96
+
97
+ def detect_NaN(feat_0, feat_1):
98
+ logger.info("NaN detected in features")
99
+ logger.info(
100
+ f"#NaN in feat_0: {torch.isnan(feat_0).int().sum()}, #NaN in feat_1: {torch.isnan(feat_1).int().sum()}"
101
+ )
102
+ feat_0[torch.isnan(feat_0)] = 0
103
+ feat_1[torch.isnan(feat_1)] = 0
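As a quick illustration of the config helpers in this file, the snippet below round-trips a yacs node through lower_config and upper_config; the keys and values are made up for the example.

from yacs.config import CfgNode as CN

cfg = CN()
cfg.MODEL = CN()
cfg.MODEL.D_MODEL = 256              # made-up key/value

plain = lower_config(cfg)            # plain nested dict with lower-case keys
assert plain["model"]["d_model"] == 256
assert upper_config(plain)["MODEL"]["D_MODEL"] == 256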
src/utils/plotting.py ADDED
@@ -0,0 +1,219 @@
1
+ import bisect
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib
5
+ import cv2
6
+
7
+ def _compute_conf_thresh(data):
8
+ dataset_name = data["dataset_name"][0].lower()
9
+ if dataset_name == "scannet":
10
+ thr = 5e-4
11
+ elif dataset_name == "megadepth":
12
+ thr = 1e-4
13
+ else:
14
+ raise ValueError(f"Unknown dataset: {dataset_name}")
15
+ return thr
16
+
17
+
18
+ # --- VISUALIZATION --- #
19
+
20
+
21
+ def make_matching_figure(
22
+ img0,
23
+ img1,
24
+ mkpts0,
25
+ mkpts1,
26
+ color,
27
+ kpts0=None,
28
+ kpts1=None,
29
+ text=[],
30
+ path=None,
31
+ ):
32
+ """
33
+ Draw a visualization of the matches using OpenCV
34
+
35
+ Args:
36
+ img0: first image (BGR format)
37
+ img1: second image (BGR format)
38
+ mkpts0: matched keypoints in the first image (Nx2 array)
39
+ mkpts1: matched keypoints in the second image (Nx2 array)
40
+ color: color of each match
41
+ kpts0: all keypoints in the first image (optional)
42
+ kpts1: all keypoints in the second image (optional)
43
+ text: text lines to draw (optional)
44
+ path: path to save the image (optional)
45
+
46
+ Returns:
47
+ the rendered OpenCV image (BGR format)
48
+ """
49
+ # Ensure the two sets of matched points have the same length
50
+ assert mkpts0.shape[0] == mkpts1.shape[0], \
51
+ f"mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}"
52
+
53
+ # If the two images have different heights, fit them to a canvas with the larger height
54
+ h0, w0 = img0.shape[:2]
55
+ h1, w1 = img1.shape[:2]
56
+ max_height = max(h0, h1)
57
+
58
+ # Create a canvas showing the two images side by side
59
+ canvas = np.ones((max_height, w0 + w1, 3), dtype=np.uint8) * 255
60
+
61
+ # Place the images onto the canvas
62
+ canvas[:h0, :w0] = img0
63
+ canvas[:h1, w0:w0+w1] = img1
64
+
65
+ # Draw all keypoints (if provided)
66
+ if kpts0 is not None and kpts1 is not None:
67
+ for (x, y) in kpts0.astype(np.int32):
68
+ cv2.circle(canvas, (x, y), 1, (255, 255, 255), -1)
69
+
70
+ for (x, y) in kpts1.astype(np.int32):
71
+ cv2.circle(canvas, (x + w0, y), 1, (255, 255, 255), -1)
72
+
73
+ # Draw matched points and their connecting lines
74
+ if mkpts0.shape[0] > 0 and mkpts1.shape[0] > 0:
75
+ # Convert to integer coordinates
76
+ mkpts0_int = mkpts0.astype(np.int32)
77
+ mkpts1_int = mkpts1.astype(np.int32)
78
+
79
+ # Draw the connecting lines
80
+ for i in range(len(mkpts0_int)):
81
+ x0, y0 = mkpts0_int[i]
82
+ x1, y1 = mkpts1_int[i]
83
+ # x coordinates in the second image must be offset by the width of the first image
84
+ x1 += w0
85
+
86
+ # Convert the color from the 0-1 range to 0-255
87
+ line_color = tuple(int(c * 255) for c in color[i][:3])
88
+ # Convert to BGR format (OpenCV uses BGR)
89
+ # line_color = (line_color[2], line_color[1], line_color[0])
90
+
91
+ cv2.line(canvas, (x0, y0), (x1, y1), line_color, 1)
92
+
93
+ # Draw the matched points
94
+ for i in range(len(mkpts0_int)):
95
+ x0, y0 = mkpts0_int[i]
96
+ x1, y1 = mkpts1_int[i]
97
+ x1 += w0
98
+
99
+ pt_color = tuple(int(c * 255) for c in color[i][:3])
100
+ # pt_color = (pt_color[2], pt_color[1], pt_color[0])
101
+
102
+ cv2.circle(canvas, (x0, y0), 2, pt_color, -1)
103
+ cv2.circle(canvas, (x1, y1), 2, pt_color, -1)
104
+
105
+ # Add text overlay
106
+ if text:
107
+ # Choose the text color based on image brightness
108
+ roi = img0[:100, :200] if h0 > 100 and w0 > 200 else img0
109
+ brightness = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY).mean()
110
+ text_color = (0, 0, 0) if brightness > 200 else (255, 255, 255)
111
+
112
+ # Draw the text lines
113
+ y_pos = 30
114
+ for i, line in enumerate(text):
115
+ cv2.putText(
116
+ canvas, line, (10, y_pos + i * 30),
117
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, text_color, 2
118
+ )
119
+
120
+ # Save the image (if a path is given)
121
+ if path:
122
+ cv2.imwrite(path, canvas)
123
+
124
+ return canvas
125
+
126
+ def _make_evaluation_figure(data, b_id, alpha="dynamic"):
127
+ b_mask = data["m_bids"] == b_id
128
+ conf_thr = _compute_conf_thresh(data)
129
+
130
+ img0 = (data["image0"][b_id][0].cpu().numpy()
131
+ * 255).round().astype(np.uint8)
+ img0 = cv2.cvtColor(img0, cv2.COLOR_GRAY2BGR) # grayscale to 3-channel for OpenCV drawing
132
+ img1 = (data["image1"][b_id][0].cpu().numpy()
133
+ * 255).round().astype(np.uint8)
+ img1 = cv2.cvtColor(img1, cv2.COLOR_GRAY2BGR)
134
+ kpts0 = data["mkpts0_f"][b_mask].clone().detach().cpu().numpy()
135
+ kpts1 = data["mkpts1_f"][b_mask].clone().detach().cpu().numpy()
136
+
137
+ # for megadepth, we visualize matches on the resized image
138
+ if "scale0" in data:
139
+ kpts0 = kpts0 / data["scale0"][b_id].cpu().numpy()[[1, 0]]
140
+ kpts1 = kpts1 / data["scale1"][b_id].cpu().numpy()[[1, 0]]
141
+
142
+ epi_errs = data["epi_errs"][b_mask].cpu().numpy()
143
+ correct_mask = epi_errs < conf_thr
144
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
145
+ n_correct = np.sum(correct_mask)
146
+ n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu())
147
+ recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
148
+ # recall might be larger than 1, since the calculation of conf_matrix_gt
149
+ # uses groundtruth depths and camera poses, but epipolar distance is used here.
150
+
151
+ # matching info
152
+ if alpha == "dynamic":
153
+ alpha = dynamic_alpha(len(correct_mask))
154
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
155
+
156
+ text = [
157
+ f"#Matches {len(kpts0)}",
158
+ f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}",
159
+ f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}",
160
+ ]
161
+
162
+ # make the figure
163
+ figure = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)
164
+ return figure
165
+
166
+
167
+ def _make_confidence_figure(data, b_id):
168
+ # TODO: Implement confidence figure
169
+ raise NotImplementedError()
170
+
171
+
172
+ def make_matching_figures(data, config, mode="evaluation"):
173
+ """Make matching figures for a batch.
174
+
175
+ Args:
176
+ data (Dict): a batch updated by PL_LoFTR.
177
+ config (Dict): matcher config
178
+ Returns:
179
+ figures (Dict[str, List[np.ndarray]]): visualization images drawn with OpenCV
180
+ """
181
+ assert mode in ["evaluation", "confidence", "gt"] # 'confidence'
182
+ figures = {mode: []}
183
+ for b_id in range(data["image0"].size(0)):
184
+ if mode == "evaluation":
185
+ fig = _make_evaluation_figure(
186
+ data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA
187
+ )
188
+ elif mode == "confidence":
189
+ fig = _make_confidence_figure(data, b_id)
190
+ else:
191
+ raise ValueError(f"Unknown plot mode: {mode}")
192
+ figures[mode].append(fig)
193
+ return figures
194
+
195
+
196
+ def dynamic_alpha(
197
+ n_matches, milestones=[0, 300, 1000, 2000], alphas=[1.0, 0.8, 0.4, 0.2]
198
+ ):
199
+ if n_matches == 0:
200
+ return 1.0
201
+ ranges = list(zip(alphas, alphas[1:] + [None]))
202
+ loc = bisect.bisect_right(milestones, n_matches) - 1
203
+ _range = ranges[loc]
204
+ if _range[1] is None:
205
+ return _range[0]
206
+ return _range[1] + (milestones[loc + 1] - n_matches) / (
207
+ milestones[loc + 1] - milestones[loc]
208
+ ) * (_range[0] - _range[1])
209
+
210
+
211
+ def error_colormap(err, thr, alpha=1.0):
212
+ assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
213
+ x = 1 - np.clip(err / (thr * 2), 0, 1)
214
+ return np.clip(
215
+ np.stack([2 - x * 2, x * 2, np.zeros_like(x),
216
+ np.ones_like(x) * alpha], -1),
217
+ 0,
218
+ 1,
219
+ )
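Finally, a self-contained sketch of the OpenCV-based plotting helpers on random placeholder data; the images, matches and error values below are synthetic, not EDM outputs.

import numpy as np

h, w, n = 480, 640, 200
img0 = np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)   # fake BGR images
img1 = np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)
mkpts0 = np.random.rand(n, 2) * [w - 1, h - 1]                # fake matched points
mkpts1 = np.random.rand(n, 2) * [w - 1, h - 1]
epi_errs = np.random.rand(n) * 1e-3                           # fake epipolar errors

color = error_colormap(epi_errs, thr=5e-4, alpha=dynamic_alpha(n))
canvas = make_matching_figure(img0, img1, mkpts0, mkpts1, color,
                              text=[f"#Matches {n}"], path="matches.png")
print(canvas.shape)   # (480, 1280, 3): the two images side by side

error_colormap colors low-error matches green and moves away from green as the error approaches twice the threshold, while dynamic_alpha lowers the line opacity as the number of matches grows.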