lixi042 committed
Commit 7e31006 · Parent: bc6ea96
update

Changed files:
- .gitattributes +2 -0
- app.py +136 -4
- configs/data/__init__.py +0 -0
- configs/data/base.py +37 -0
- configs/data/megadepth_test_1500.py +23 -0
- configs/edm/outdoor/edm_base.py +17 -0
- requirements.txt +22 -0
- src/__init__.py +0 -0
- src/config/default.py +184 -0
- src/edm/__init__.py +2 -0
- src/edm/backbone/resnet.py +116 -0
- src/edm/edm.py +204 -0
- src/edm/head/coarse_matching.py +152 -0
- src/edm/head/fine_matching.py +383 -0
- src/edm/neck/__init__.py +1 -0
- src/edm/neck/loftr_module/__init__.py +1 -0
- src/edm/neck/loftr_module/transformer.py +418 -0
- src/edm/neck/neck.py +156 -0
- src/utils/misc.py +103 -0
- src/utils/plotting.py +219 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+weights/ filter=lfs diff=lfs merge=lfs -text
+assets/ filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,8 +1,140 @@
(previous 8-line placeholder app, ending in demo.launch(), removed; the new file follows)

import cv2
import numpy as np
import gradio as gr
import torch
import matplotlib.cm as cm
import sys

sys.path.append("src")

from src.utils.plotting import make_matching_figure
from src.edm import EDM
from src.config.default import get_cfg_defaults
from src.utils.misc import lower_config


HEADER = """
<div align="center">
  <p>
    <span style="font-size: 30px; vertical-align: bottom;"> 🎶 EDM: Efficient Deep Feature Matching</span>
  </p>
  <p style="margin-top: -15px;">
    <a href="https://arxiv.org/pdf/2503.05122" target="_blank" style="color: grey;">ArXiv Paper</a>

    <a href="https://github.com/chicleee/EDM" target="_blank" style="color: grey;">GitHub Repository</a>
  </p>
</div>
"""

ABSTRACT = """
Recent feature matching methods have achieved remarkable performance but lack efficiency consideration. In this paper, we revisit the mainstream detector-free matching pipeline and improve all its stages considering both accuracy and efficiency. We propose an Efficient Deep feature Matching network, EDM. We first adopt a deeper CNN with fewer dimensions to extract multi-level features. Then we present a Correlation Injection Module that conducts feature transformation on high-level deep features, and progressively injects feature correlations from global to local for efficient multi-scale feature aggregation, improving both speed and performance. In the refinement stage, a novel lightweight bidirectional axis-based regression head is designed to directly predict subpixel-level correspondences from latent features, avoiding the significant computational cost of explicitly locating keypoints on high-resolution local feature heatmaps. Moreover, effective selection strategies are introduced to enhance matching accuracy. Extensive experiments show that our EDM achieves competitive matching accuracy on various benchmarks and exhibits excellent efficiency, offering valuable best practices for real-world applications."""

def find_matches(image_0, image_1, conf_thres=0.2, border_rm=2, topk=10000):
    config = get_cfg_defaults()
    data_cfg_path = "configs/data/megadepth_test_1500.py"
    main_cfg_path = "configs/edm/outdoor/edm_base.py"
    config.merge_from_file(main_cfg_path)
    config.merge_from_file(data_cfg_path)

    W, H = 832, 832
    config.EDM.COARSE.MCONF_THR = conf_thres
    config.EDM.COARSE.BORDER_RM = border_rm
    config.EDM.COARSE.TOPK = topk

    _config = lower_config(config)
    matcher = EDM(config=_config["edm"])
    state_dict = torch.load("weights/edm_outdoor.ckpt")["state_dict"]
    matcher.load_state_dict(state_dict)
    matcher = matcher.eval()

    # Load example images
    img0_bgr = image_0
    img1_bgr = image_1

    h0, w0 = img0_bgr.shape[:2]
    h1, w1 = img1_bgr.shape[:2]

    h0_scale = h0 / H
    w0_scale = w0 / W
    h1_scale = h1 / H
    w1_scale = w1 / W

    # For inference
    img0_raw = cv2.cvtColor(img0_bgr, cv2.COLOR_BGR2GRAY)
    img1_raw = cv2.cvtColor(img1_bgr, cv2.COLOR_BGR2GRAY)
    img0_raw = cv2.resize(img0_raw, (W, H))
    img1_raw = cv2.resize(img1_raw, (W, H))

    img0 = torch.from_numpy(img0_raw)[None][None] / 255.
    img1 = torch.from_numpy(img1_raw)[None][None] / 255.
    batch = {'image0': img0, 'image1': img1}

    # Inference with EDM and get prediction
    with torch.no_grad():
        matcher(batch)

    mkpts0 = batch['mkpts0_f'].cpu().numpy()
    mkpts1 = batch['mkpts1_f'].cpu().numpy()
    mconf = batch['mconf'].cpu().numpy()

    # Map keypoints back to the original image resolutions
    mkpts0[:, 0] *= w0_scale
    mkpts0[:, 1] *= h0_scale
    mkpts1[:, 0] *= w1_scale
    mkpts1[:, 1] *= h1_scale

    color = cm.jet(mconf)
    # Draw
    text = [
        'EDM',
        'Matches: {}'.format(len(mkpts0)),
    ]
    fig = make_matching_figure(img0_bgr, img1_bgr, mkpts0, mkpts1, color, text=text)

    return fig


with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Accordion("Abstract (click to open)", open=False):
        gr.Image("assets/teaser.jpg")
        gr.Markdown(ABSTRACT)

    with gr.Row():
        image_1 = gr.Image()
        image_2 = gr.Image()
    with gr.Row():
        conf_thres = gr.Slider(minimum=0.01, maximum=1, value=0.2, step=0.01, label="Coarse Confidence Threshold")
        topk = gr.Slider(minimum=100, maximum=10000, value=5000, step=100, label="TopK")
        border_rm = gr.Slider(minimum=0, maximum=20, value=2, step=1, label="Border Remove (x8)")
    gr.HTML(
        """
        Note: images are actually resized to 832 x 832 for matching.
        """
    )
    with gr.Row():
        button = gr.Button(value="Find Matches")
        clear = gr.ClearButton(value="Clear")

    output = gr.Image()
    button.click(find_matches, [image_1, image_2, conf_thres, border_rm, topk], output)
    clear.add([image_1, image_2, output])

    gr.Examples(
        examples=[
            ["assets/scannet_sample_images/scene0707_00_15.jpg", "assets/scannet_sample_images/scene0707_00_45.jpg"],
            ["assets/scannet_sample_images/scene0758_00_165.jpg", "assets/scannet_sample_images/scene0758_00_510.jpg"],
            ["assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg", "assets/phototourism_sample_images/london_bridge_49190386_5209386933.jpg"],
            ["assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg", "assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"],
        ],
        inputs=[image_1, image_2],
        outputs=[output],
        fn=find_matches,
        cache_examples=None,
    )

if __name__ == "__main__":
    demo.launch(debug=True)
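Editorial note (not part of the commit): the new find_matches() takes plain numpy images, so the demo logic can also be exercised headlessly. A minimal sketch, assuming the bundled weights/edm_outdoor.ckpt and sample assets are present:

import cv2
from app import find_matches  # importing app builds the Blocks UI but does not launch it

# find_matches returns a matplotlib figure produced by make_matching_figure
img0 = cv2.imread("assets/scannet_sample_images/scene0707_00_15.jpg")
img1 = cv2.imread("assets/scannet_sample_images/scene0707_00_45.jpg")
fig = find_matches(img0, img1, conf_thres=0.2, border_rm=2, topk=5000)
fig.savefig("matches.png")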
configs/data/__init__.py ADDED
(empty file)
configs/data/base.py ADDED
@@ -0,0 +1,37 @@
"""
The data config will be the last one merged into the main config.
Settings in data configs will override all existing settings!
"""

from yacs.config import CfgNode as CN

_CN = CN()
_CN.DATASET = CN()
_CN.TRAINER = CN()
_CN.EDM = CN()
_CN.EDM.NECK = CN()

# training data config
_CN.DATASET.TRAIN_DATA_ROOT = None
_CN.DATASET.TRAIN_POSE_ROOT = None
_CN.DATASET.TRAIN_NPZ_ROOT = None
_CN.DATASET.TRAIN_LIST_PATH = None
_CN.DATASET.TRAIN_INTRINSIC_PATH = None
# validation set config
_CN.DATASET.VAL_DATA_ROOT = None
_CN.DATASET.VAL_POSE_ROOT = None
_CN.DATASET.VAL_NPZ_ROOT = None
_CN.DATASET.VAL_LIST_PATH = None
_CN.DATASET.VAL_INTRINSIC_PATH = None

# testing data config
_CN.DATASET.TEST_DATA_ROOT = None
_CN.DATASET.TEST_POSE_ROOT = None
_CN.DATASET.TEST_NPZ_ROOT = None
_CN.DATASET.TEST_LIST_PATH = None
_CN.DATASET.TEST_INTRINSIC_PATH = None

# dataset config
_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4
_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0  # for both test and val

cfg = _CN
configs/data/megadepth_test_1500.py ADDED
@@ -0,0 +1,23 @@
from configs.data.base import cfg

TEST_BASE_PATH = "assets/megadepth_test_1500_scene_info"

cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}"
cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt"

cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0

cfg.EDM.TRAIN_RES_H = 832
cfg.EDM.TRAIN_RES_W = 832
cfg.EDM.TEST_RES_H = 832
cfg.EDM.TEST_RES_W = 832

cfg.EDM.NECK.NPE = [
    cfg.EDM.TRAIN_RES_H,
    cfg.EDM.TRAIN_RES_W,
    cfg.EDM.TEST_RES_H,
    cfg.EDM.TEST_RES_W,
]
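Editorial note (not part of the commit): the four NPE values are consumed by RoPEPositionEncodingSine in src/edm/neck/loftr_module/transformer.py, which rescales positions by train_res / test_res so the sinusoid frequencies match training. With train and test resolution both 832 here, the ratio is 832 / 832 = 1, i.e. the rescaling is the identity.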
configs/edm/outdoor/edm_base.py ADDED
@@ -0,0 +1,17 @@
from src.config.default import _CN as cfg

cfg.TRAINER.CANONICAL_BS = 4 * 8
cfg.TRAINER.CANONICAL_LR = 2e-3
cfg.TRAINER.WARMUP_STEP = int(36800 / cfg.TRAINER.CANONICAL_BS * 3)  # 3 epochs
cfg.TRAINER.WARMUP_RATIO = 0.1
cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]
cfg.TRAINER.EPI_ERR_THR = 1e-4

cfg.EDM.COARSE.MCONF_THR = 0.05
cfg.EDM.FINE.SIGMA_THR = 1e-6
cfg.EDM.COARSE.BORDER_RM = 0

# Top-K should not exceed grid_size = TEST_RES_H / 8 * TEST_RES_W / 8
# The recommended value is approximately grid_size * 0.35 for MegaDepth
# cfg.EDM.COARSE.TOPK = int(832 / 8 * 832 / 8 * 0.35)  # 3786 for train & LO-RANSAC test
cfg.EDM.COARSE.TOPK = int(1152 / 8 * 1152 / 8 * 0.35)  # 7258 for test
requirements.txt ADDED
@@ -0,0 +1,22 @@
torch==2.0.0  # PyPI name is "torch"; the committed line read "pytorch==2.0.0", which pip cannot resolve
torchvision==0.15.0
torchaudio==2.0.0
opencv-python
albumentations==0.5.1 --no-binary=imgaug,albumentations
ray>=1.0.1
einops==0.3.0
kornia==0.4.1
loguru==0.5.3
yacs>=0.1.8
tqdm
autopep8
pylint
ipython
jupyterlab
matplotlib
h5py==3.1.0
pytorch-lightning==1.3.5
torchmetrics==0.7.0
joblib>=1.0.1
pillow==9.5.0
poselib
src/__init__.py ADDED
(empty file)
src/config/default.py ADDED
@@ -0,0 +1,184 @@
from yacs.config import CfgNode as CN

_CN = CN()

############## ↓ EDM Pipeline ↓ ##############
_CN.EDM = CN()
_CN.EDM.TRAIN_RES_H = 832
_CN.EDM.TRAIN_RES_W = 832
_CN.EDM.TEST_RES_H = 832
_CN.EDM.TEST_RES_W = 832
_CN.EDM.LOCAL_RESOLUTION = 8  # coarse matching at 1/8, window_size 8x8 in fine_level
_CN.EDM.MP = False  # do not use mixed precision (ELoFTR used mixed precision by default)
_CN.EDM.HALF = False  # not FP16
_CN.EDM.DEPLOY = False  # export ONNX model
_CN.EDM.EVAL_TIMES = 1

# 1. Backbone config
_CN.EDM.BACKBONE = CN()
_CN.EDM.BACKBONE.BLOCK_DIMS = [32, 64, 128, 256, 256]  # 1/2 -> 1/32

# 2. Neck config
_CN.EDM.NECK = CN()
_CN.EDM.NECK.D_MODEL = 256
_CN.EDM.NECK.NHEAD = 8
_CN.EDM.NECK.LAYER_NAMES = ["self", "cross"] * 2
_CN.EDM.NECK.AGG_SIZE0 = 1
_CN.EDM.NECK.AGG_SIZE1 = 1
_CN.EDM.NECK.ROPE = True
_CN.EDM.NECK.NPE = None

# 3. Coarse-Matching config
_CN.EDM.COARSE = CN()
_CN.EDM.COARSE.MCONF_THR = 0.2
_CN.EDM.COARSE.BORDER_RM = 0
_CN.EDM.COARSE.DSMAX_TEMPERATURE = 0.1
_CN.EDM.COARSE.TRAIN_PAD_NUM = 32  # training trick: avoid DDP deadlock
_CN.EDM.COARSE.TOPK = 2048
_CN.EDM.COARSE.DS_OPT = True

# 4. EDM-fine module config
_CN.EDM.FINE = CN()
_CN.EDM.FINE.DROPRATE = None
_CN.EDM.FINE.COORD_LENGTH = 16
_CN.EDM.FINE.BI_DIRECTIONAL_REFINE = True
_CN.EDM.FINE.SIGMA_THR = 0.0
_CN.EDM.FINE.SIGMA_SELECTION = True

# 5. EDM Losses
# -- # coarse-level
_CN.EDM.LOSS = CN()
_CN.EDM.LOSS.COARSE_TYPE = "focal"
_CN.EDM.LOSS.COARSE_WEIGHT = 1.0
_CN.EDM.LOSS.SPARSE_SPVS = True

# -- - -- # focal loss (coarse)
_CN.EDM.LOSS.FOCAL_ALPHA = 0.25
_CN.EDM.LOSS.FOCAL_GAMMA = 2.0
_CN.EDM.LOSS.POS_WEIGHT = 1.0
_CN.EDM.LOSS.NEG_WEIGHT = 1.0


# -- # fine-level
_CN.EDM.LOSS.FINE_TYPE = "rle"
_CN.EDM.LOSS.FINE_WEIGHT = 0.2
_CN.EDM.LOSS.Q_DISTRIBUTION = "laplace"  # options: ['laplace', 'gaussian']


############## Dataset ##############
_CN.DATASET = CN()
# 1. data config
# training and validating
_CN.DATASET.TRAINVAL_DATA_SOURCE = None  # options: ['ScanNet', 'MegaDepth']
_CN.DATASET.TRAIN_DATA_ROOT = None
_CN.DATASET.TRAIN_POSE_ROOT = None  # (optional directory for poses)
_CN.DATASET.TRAIN_NPZ_ROOT = None
_CN.DATASET.TRAIN_LIST_PATH = None
_CN.DATASET.TRAIN_INTRINSIC_PATH = None
_CN.DATASET.VAL_DATA_ROOT = None
_CN.DATASET.VAL_POSE_ROOT = None  # (optional directory for poses)
_CN.DATASET.VAL_NPZ_ROOT = None
_CN.DATASET.VAL_LIST_PATH = (
    None  # None if val data from all scenes are bundled into a single npz file
)
_CN.DATASET.VAL_INTRINSIC_PATH = None
# testing
_CN.DATASET.TEST_DATA_SOURCE = None
_CN.DATASET.TEST_DATA_ROOT = None
_CN.DATASET.TEST_POSE_ROOT = None  # (optional directory for poses)
_CN.DATASET.TEST_NPZ_ROOT = None
_CN.DATASET.TEST_LIST_PATH = (
    None  # None if test data from all scenes are bundled into a single npz file
)
_CN.DATASET.TEST_INTRINSIC_PATH = None

# 2. dataset config
# general options
_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = (
    0.4  # discard data with overlap_score < min_overlap_score
)
_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
_CN.DATASET.AUGMENTATION_TYPE = None  # options: [None, 'dark', 'mobile']

# MegaDepth options
_CN.DATASET.MGDPT_IMG_PAD = True  # pad img to square with size = MGDPT_IMG_RESIZE
_CN.DATASET.MGDPT_DEPTH_PAD = True  # pad depthmap to square with size = 2000
_CN.DATASET.MGDPT_DF = 8


############## Trainer ##############
_CN.TRAINER = CN()
_CN.TRAINER.WORLD_SIZE = 1
_CN.TRAINER.CANONICAL_BS = 4 * 8
_CN.TRAINER.CANONICAL_LR = 2e-3
_CN.TRAINER.SCALING = None  # this will be calculated automatically
_CN.TRAINER.FIND_LR = False  # use learning rate finder from pytorch-lightning

# optimizer
_CN.TRAINER.OPTIMIZER = "adamw"  # [adam, adamw]
_CN.TRAINER.TRUE_LR = None  # this will be calculated automatically at runtime
_CN.TRAINER.ADAM_DECAY = 0.0  # ADAM: for adam
_CN.TRAINER.ADAMW_DECAY = 0.1

# step-based warm-up
_CN.TRAINER.WARMUP_TYPE = "linear"  # [linear, constant]
_CN.TRAINER.WARMUP_RATIO = 0.0
_CN.TRAINER.WARMUP_STEP = 4800

# learning rate scheduler
# [MultiStepLR, CosineAnnealing, ExponentialLR]
_CN.TRAINER.SCHEDULER = "MultiStepLR"
_CN.TRAINER.SCHEDULER_INTERVAL = "epoch"  # [epoch, step]
_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12]  # MSLR: MultiStepLR
_CN.TRAINER.MSLR_GAMMA = 0.5
_CN.TRAINER.COSA_TMAX = 30  # COSA: CosineAnnealing
# ELR: ExponentialLR, this value for 'step' interval
_CN.TRAINER.ELR_GAMMA = 0.999992

# plotting related
_CN.TRAINER.ENABLE_PLOTTING = False
_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32  # number of val/test pairs for plotting
_CN.TRAINER.PLOT_MODE = "evaluation"  # ['evaluation', 'confidence']
_CN.TRAINER.PLOT_MATCHES_ALPHA = "dynamic"

# geometric metrics and pose solver
# recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
_CN.TRAINER.EPI_ERR_THR = 5e-4
_CN.TRAINER.POSE_GEO_MODEL = "E"  # ['E', 'F', 'H']
_CN.TRAINER.POSE_ESTIMATION_METHOD = "RANSAC"  # [RANSAC, LO-RANSAC]
_CN.TRAINER.RANSAC_PIXEL_THR = 0.5
_CN.TRAINER.RANSAC_CONF = 0.99999

# data sampler for train_dataloader
_CN.TRAINER.DATA_SAMPLER = (
    "scene_balance"  # options: ['scene_balance', 'random', 'normal']
)
# 'scene_balance' config
_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
# whether to sample each scene with replacement or not
_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True
# after sampling from scenes, whether to shuffle within the epoch or not
_CN.TRAINER.SB_SUBSET_SHUFFLE = True
_CN.TRAINER.SB_REPEAT = 1  # repeat N times for training the sampled data
# 'random' config
_CN.TRAINER.RDM_REPLACEMENT = True
_CN.TRAINER.RDM_NUM_SAMPLES = None

# gradient clipping
_CN.TRAINER.GRADIENT_CLIPPING = 0.5

# reproducibility
# This seed affects the data sampling. With the same seed, the data sampling is
# guaranteed to be the same. When resuming training from a checkpoint, it's better
# to use a different seed; otherwise the sampled data will be exactly the same as
# before resuming, which causes fewer unique data items to be sampled over the
# entire training run.
# Different seed values might affect the final training result, since not all data
# items are used during training on ScanNet. (60M image pairs are sampled during
# training from 230M pairs in total.)
_CN.TRAINER.SEED = 66


def get_cfg_defaults():
    """Get a yacs CfgNode object with default values for my_project."""
    # Return a clone so that the defaults will not be altered.
    # This is for the "local variable" use pattern.
    return _CN.clone()
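A quick editorial sketch (not part of the commit) of the intended use pattern: because get_cfg_defaults() returns a clone, per-run overrides never leak back into the module-level defaults.

from src.config.default import get_cfg_defaults, _CN

cfg = get_cfg_defaults()
cfg.EDM.COARSE.TOPK = 4096          # override for this run only
assert _CN.EDM.COARSE.TOPK == 2048  # module-level defaults stay untouched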
src/edm/__init__.py ADDED
@@ -0,0 +1,2 @@
from .edm import EDM
src/edm/backbone/resnet.py ADDED
@@ -0,0 +1,116 @@
import torch.nn as nn


def conv1x1(in_planes, out_planes, stride=1, bias=False):
    """1x1 convolution without padding"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=bias
    )


def conv3x3(in_planes, out_planes, stride=1, groups=1, bias=False):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=groups,
        bias=bias,
    )


class BasicBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.conv2 = conv3x3(planes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                conv1x1(in_planes, planes, stride=stride), nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        y = x
        y = self.relu(self.bn1(self.conv1(y)))
        y = self.bn2(self.conv2(y))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class ResNet18(nn.Module):
    """
    Fewer channels
    """

    def __init__(self, config=None):
        super().__init__()
        # Config
        block_dims = config["backbone"]["block_dims"]

        # Networks
        self.conv1 = nn.Conv2d(
            1, block_dims[0], kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(block_dims[0])
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(
            BasicBlock, block_dims[0], block_dims[0], stride=1
        )  # 1/2
        self.layer2 = self._make_layer(
            BasicBlock, block_dims[0], block_dims[1], stride=2
        )  # 1/4
        self.layer3 = self._make_layer(
            BasicBlock, block_dims[1], block_dims[2], stride=2
        )  # 1/8
        self.layer4 = self._make_layer(
            BasicBlock, block_dims[2], block_dims[3], stride=2
        )  # 1/16
        self.layer5 = self._make_layer(
            BasicBlock, block_dims[3], block_dims[4], stride=2
        )  # 1/32

        # For fine matching
        self.fine_conv = nn.Sequential(
            self._make_layer(BasicBlock, block_dims[2], block_dims[2], stride=1),
            conv1x1(block_dims[2], block_dims[4]),
            nn.BatchNorm2d(block_dims[4]),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, in_dim, out_dim, stride=1):
        layer1 = block(in_dim, out_dim, stride=stride)
        layer2 = block(out_dim, out_dim, stride=1)
        layers = (layer1, layer2)

        return nn.Sequential(*layers)

    def forward(self, x):
        x0 = self.relu(self.bn1(self.conv1(x)))
        x1 = self.layer1(x0)  # 1/2
        x2 = self.layer2(x1)  # 1/4
        x3 = self.layer3(x2)  # 1/8
        x4 = self.layer4(x3)  # 1/16
        x5 = self.layer5(x4)  # 1/32

        xf = self.fine_conv(x3)  # 1/8

        return [x3, x4, x5, xf]
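A shape sanity check (editorial sketch, not in the commit): the backbone takes a 1-channel image and returns features at strides 8, 16, and 32 plus a separate fine branch at stride 8.

import torch
from src.edm.backbone.resnet import ResNet18

config = {"backbone": {"block_dims": [32, 64, 128, 256, 256]}}
net = ResNet18(config).eval()
with torch.no_grad():
    x3, x4, x5, xf = net(torch.randn(1, 1, 832, 832))
print(x3.shape, x4.shape, x5.shape, xf.shape)
# 1/8: [1, 128, 104, 104], 1/16: [1, 256, 52, 52], 1/32: [1, 256, 26, 26], fine: [1, 256, 104, 104]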
src/edm/edm.py ADDED
@@ -0,0 +1,204 @@
from ..utils.misc import detect_NaN
from .head.fine_matching import FineMatching
from .head.coarse_matching import CoarseMatching
from .neck.neck import CIM
from .backbone.resnet import ResNet18
from einops.einops import rearrange
import torch.nn.functional as F
import torch.nn as nn
import torch

torch.set_float32_matmul_precision("highest")  # highest (default) / high / medium


class EDM(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Misc
        self.config = config
        self.local_resolution = self.config["local_resolution"]
        self.bi_directional_refine = self.config["fine"]["bi_directional_refine"]
        self.deploy = self.config["deploy"]
        self.topk = config["coarse"]["topk"]

        # Modules
        self.backbone = ResNet18(config)
        self.neck = CIM(config)
        self.coarse_matching = CoarseMatching(config)
        self.fine_matching = FineMatching(config)

    def forward(self, data):
        """
        Update:
            data (dict): {
                'image0': (torch.Tensor): (N, 1, H, W)
                'image1': (torch.Tensor): (N, 1, H, W)
                'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
                'mask1'(optional) : (torch.Tensor): (N, H, W)
            }
        """
        if self.deploy:
            image0, image1 = data.split(1, 1)
            data = {"image0": image0, "image1": image1}

        data.update(
            {
                "bs": data["image0"].size(0),
                "hw0_i": data["image0"].shape[2:],
                "hw1_i": data["image1"].shape[2:],
            }
        )

        # 1. Feature Extraction
        if data["hw0_i"] == data["hw1_i"]:
            # faster & better BN convergence
            feats = self.backbone(torch.cat([data["image0"], data["image1"]], dim=0))
            f8, f16, f32, f8_fine = feats
            ms_feats = f8, f16, f32
            feat_f0, feat_f1 = f8_fine.chunk(2)
        else:
            # handle different input shapes
            # raise ValueError("image0 and image1 should have the same shape.")
            feats0, feats1 = self.backbone(data["image0"]), self.backbone(data["image1"])
            f8_0, f16_0, f32_0, feat_f0 = feats0
            f8_1, f16_1, f32_1, feat_f1 = feats1
            ms_feats = f8_0, f16_0, f32_0, f8_1, f16_1, f32_1

        mask_c0 = mask_c1 = None  # mask is useful in training
        if "mask0" in data:
            mask_c0, mask_c1 = data["mask0"], data["mask1"]

        # 2. Feature Interaction & Multi-Scale Fusion
        feat_c0, feat_c1 = self.neck(ms_feats, mask_c0, mask_c1)

        data.update(
            {
                "hw0_c": feat_c0.shape[2:],
                "hw1_c": feat_c1.shape[2:],
                "hw0_f": feat_c0.shape[2:] * self.config["local_resolution"],
                "hw1_f": feat_c1.shape[2:] * self.config["local_resolution"],
            }
        )
        feat_c0 = rearrange(feat_c0, "n c h w -> n (h w) c")
        feat_c1 = rearrange(feat_c1, "n c h w -> n (h w) c")
        feat_f0 = rearrange(feat_f0, "n c h w -> n (h w) c")
        feat_f1 = rearrange(feat_f1, "n c h w -> n (h w) c")

        # detect NaN during mixed precision training
        if self.config["mp"] and (
            torch.any(torch.isnan(feat_c0)) or torch.any(torch.isnan(feat_c1))
        ):
            detect_NaN(feat_c0, feat_c1)

        # 3. Coarse-Level Matching
        conf_matrix = self.coarse_matching(
            feat_c0,
            feat_c1,
            data,
            mask_c0=(mask_c0.view(mask_c0.size(0), -1) if mask_c0 is not None else mask_c0),
            mask_c1=(mask_c1.view(mask_c1.size(0), -1) if mask_c1 is not None else mask_c1),
        )

        if self.deploy:
            k = self.topk
            row_max_val, row_max_idx = torch.max(conf_matrix, dim=2)
            topk_val, topk_idx = torch.topk(row_max_val, k, dim=1)

            b_ids = (
                torch.arange(conf_matrix.shape[0], device=conf_matrix.device)
                .unsqueeze(1)
                .repeat(1, k)
                .flatten()
            )
            i_ids = topk_idx.flatten()
            j_ids = row_max_idx[b_ids, i_ids].flatten()
            mconf = conf_matrix[b_ids, i_ids, j_ids]

            scale = data["hw0_i"][0] / data["hw0_c"][0]
            scale0 = scale * data["scale0"][b_ids] if "scale0" in data else scale
            scale1 = scale * data["scale1"][b_ids] if "scale1" in data else scale
            mkpts0_c = (
                torch.stack(
                    [
                        i_ids % data["hw0_c"][1],
                        torch.div(i_ids, data["hw0_c"][1], rounding_mode="floor"),
                    ],
                    dim=1,
                )
                * scale0
            )
            mkpts1_c = (
                torch.stack(
                    [
                        j_ids % data["hw1_c"][1],
                        torch.div(j_ids, data["hw1_c"][1], rounding_mode="floor"),
                    ],
                    dim=1,
                )
                * scale1
            )

            data.update(
                {
                    "mconf": mconf,
                    "mkpts0_c": mkpts0_c,
                    "mkpts1_c": mkpts1_c,
                    "b_ids": b_ids,
                    "i_ids": i_ids,
                    "j_ids": j_ids,
                }
            )

        # 4. Fine-Level Matching
        K0 = data["i_ids"].shape[0] // data["bs"]
        K1 = data["j_ids"].shape[0] // data["bs"]
        feat_f0 = feat_f0[data["b_ids"], data["i_ids"]].reshape(data["bs"], K0, -1)
        feat_f1 = feat_f1[data["b_ids"], data["j_ids"]].reshape(data["bs"], K1, -1)
        feat_c0 = feat_c0[data["b_ids"], data["i_ids"]].reshape(data["bs"], K0, -1)
        feat_c1 = feat_c1[data["b_ids"], data["j_ids"]].reshape(data["bs"], K1, -1)

        if self.bi_directional_refine:
            # Bidirectional Refinement
            offset, score = self.fine_matching(
                torch.cat([feat_f0, feat_f1], dim=1),
                torch.cat([feat_f1, feat_f0], dim=1),
                torch.cat([feat_c0, feat_c1], dim=1),
                torch.cat([feat_c1, feat_c0], dim=1),
                data,
            )
        else:
            offset, score = self.fine_matching(feat_f0, feat_f1, feat_c0, feat_c1, data)

        if self.deploy:
            if self.bi_directional_refine:
                fine_offset01, fine_offset10 = offset.chunk(2)
                fine_score01, fine_score10 = score.unsqueeze(dim=1).chunk(2)
                output = torch.cat(
                    [mkpts0_c, mkpts1_c, fine_offset01, fine_offset10,
                     fine_score01, fine_score10, mconf.unsqueeze(dim=1)], 1)  # [K, 11]
            else:
                output = torch.cat(
                    [mkpts0_c, mkpts1_c, offset, score, mconf.unsqueeze(dim=1)], 1)
            return output

    def load_state_dict(self, state_dict, *args, **kwargs):
        for k in list(state_dict.keys()):
            if k.startswith("matcher."):
                state_dict[k.replace("matcher.", "", 1)] = state_dict.pop(k)
        return super().load_state_dict(state_dict, *args, **kwargs)
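Editorial note (not part of the commit): with bi_directional_refine enabled, the deploy output above packs eleven columns per match, in the concatenation order of the deploy branch. A sketch of unpacking it:

# output: [K, 11], as concatenated in the deploy branch above
mkpts0_c = output[:, 0:2]   # coarse keypoints in image0 (x, y)
mkpts1_c = output[:, 2:4]   # coarse keypoints in image1 (x, y)
offset01 = output[:, 4:6]   # fine offset refining direction 0 -> 1
offset10 = output[:, 6:8]   # fine offset refining direction 1 -> 0
score01  = output[:, 8]     # fine confidence, direction 0 -> 1
score10  = output[:, 9]     # fine confidence, direction 1 -> 0
mconf    = output[:, 10]    # coarse dual-softmax confidence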
src/edm/head/coarse_matching.py ADDED
@@ -0,0 +1,152 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


INF = 1e9  # -1e4 for fp16 matmul


class CoarseMatching(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.window_size = config["local_resolution"]
        self.thr = config["coarse"]["mconf_thr"]
        self.temperature = config["coarse"]["dsmax_temperature"]
        self.ds_opt = config["coarse"]["ds_opt"]
        self.pad_num = config["coarse"]["train_pad_num"]
        self.topk = config["coarse"]["topk"]
        self.deploy = config["deploy"]

    def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
        """
        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            data (dict)
            mask_c0 (torch.Tensor): [N, L] (optional)
            mask_c1 (torch.Tensor): [N, S] (optional)
        Update:
            data (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
            NOTE: M' != M during training.
        """
        # normalize
        feat_c0, feat_c1 = map(
            lambda feat: feat / feat.shape[-1] ** 0.5, [feat_c0, feat_c1]
        )

        with torch.autocast(enabled=False, device_type="cuda"):
            sim_matrix = (
                torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) / self.temperature
            )
            del feat_c0, feat_c1
            if mask_c0 is not None:
                sim_matrix = sim_matrix.float().masked_fill(
                    ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
                    -INF,
                )

            if not self.training and self.ds_opt:
                # Alternative implementation of the dual-softmax operator for efficient inference
                sim_matrix = torch.exp(sim_matrix)
                conf_matrix = F.normalize(sim_matrix, p=1, dim=1) * F.normalize(
                    sim_matrix, p=1, dim=2
                )
            else:
                # Native dual-softmax operator
                conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)

        data.update(
            {
                "conf_matrix": conf_matrix,
            }
        )

        if not self.deploy:
            # predict coarse matches from conf_matrix
            self.coarse_matching_selection(data)

        return conf_matrix  # Returning the sim_matrix can be faster, but it may reduce the accuracy.

    # Static tensor shape for mini-batch inference.
    @torch.no_grad()
    def coarse_matching_selection(self, data):
        """
        Args:
            data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
        Returns:
            coarse_matches (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
        """
        conf_matrix = data["conf_matrix"]

        # mutual nearest
        # mask = (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
        #     * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
        # conf_matrix[~mask] = 0

        k = self.topk
        row_max_val, row_max_idx = torch.max(conf_matrix, dim=2)

        # prevent out of range
        if k == -1 or k > row_max_val.shape[-1]:
            k = row_max_val.shape[-1]

        topk_val, topk_idx = torch.topk(row_max_val, k)
        b_ids = (
            torch.arange(conf_matrix.shape[0], device=conf_matrix.device)
            .unsqueeze(1)
            .repeat(1, k)
            .flatten()
        )
        i_ids = topk_idx.flatten()
        j_ids = row_max_idx[b_ids, i_ids].flatten()
        mconf = conf_matrix[b_ids, i_ids, j_ids]

        scale = data["hw0_i"][0] / data["hw0_c"][0]
        scale0 = scale * data["scale0"][b_ids] if "scale0" in data else scale
        scale1 = scale * data["scale1"][b_ids] if "scale1" in data else scale
        mkpts0_c = (
            torch.stack(
                [
                    i_ids % data["hw0_c"][1],
                    torch.div(i_ids, data["hw0_c"][1], rounding_mode="floor"),
                ],
                dim=1,
            )
            * scale0
        )
        mkpts1_c = (
            torch.stack(
                [
                    j_ids % data["hw1_c"][1],
                    torch.div(j_ids, data["hw1_c"][1], rounding_mode="floor"),
                ],
                dim=1,
            )
            * scale1
        )

        data.update(
            {
                "mconf": mconf,
                "mkpts0_c": mkpts0_c,
                "mkpts1_c": mkpts1_c,
                "b_ids": b_ids,
                "i_ids": i_ids,
                "j_ids": j_ids,
            }
        )
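A numerical sketch (editorial, not in the commit) checking that the ds_opt inference shortcut above (exp() followed by two L1 normalizations) reproduces the native dual-softmax product; since exp() is strictly positive, L1-normalizing along a dimension equals taking a softmax along it.

import torch
import torch.nn.functional as F

sim = torch.randn(2, 100, 100)
native = F.softmax(sim, 1) * F.softmax(sim, 2)
e = torch.exp(sim)
fast = F.normalize(e, p=1, dim=1) * F.normalize(e, p=1, dim=2)
print(torch.allclose(native, fast, atol=1e-6))  # True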
src/edm/head/fine_matching.py ADDED
@@ -0,0 +1,383 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions


class Conv1d_BN_Act(nn.Sequential):
    def __init__(
        self,
        a,
        b,
        ks=1,
        stride=1,
        pad=0,
        dilation=1,
        groups=1,
        bn_weight_init=1,
        act=None,
        drop=None,
    ):
        super().__init__()
        self.inp_channel = a
        self.out_channel = b
        self.ks = ks
        self.pad = pad
        self.stride = stride
        self.dilation = dilation
        self.groups = groups

        self.add_module(
            "c", nn.Conv1d(a, b, ks, stride, pad, dilation, groups, bias=False)
        )
        bn = nn.BatchNorm1d(b)
        nn.init.constant_(bn.weight, bn_weight_init)
        nn.init.constant_(bn.bias, 0)
        self.add_module("bn", bn)
        if act is not None:
            self.add_module("a", act)
        if drop is not None:
            self.add_module("d", nn.Dropout(drop))


class RealNVP(nn.Module):
    """RealNVP: a flow-based generative model

    `Density estimation using Real NVP
    arXiv: <https://arxiv.org/abs/1605.08803>`_.

    Code is modified from `the mmpose implementation of RLE
    <https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/utils/realnvp.py>`_.
    """

    @staticmethod
    def get_scale_net(channel):
        """Get the scale model in a single invertible mapping."""
        return nn.Sequential(
            nn.Linear(2, channel),
            nn.GELU(),
            nn.Linear(channel, channel),
            nn.GELU(),
            nn.Linear(channel, 2),
            nn.Tanh(),
        )

    @staticmethod
    def get_trans_net(channel):
        """Get the translation model in a single invertible mapping."""
        return nn.Sequential(
            nn.Linear(2, channel),
            nn.GELU(),
            nn.Linear(channel, channel),
            nn.GELU(),
            nn.Linear(channel, 2),
        )

    @property
    def prior(self):
        """The prior distribution."""
        return distributions.MultivariateNormal(self.loc, self.cov)

    def __init__(self, channel=64):
        super(RealNVP, self).__init__()
        self.channel = channel
        self.register_buffer("loc", torch.zeros(2))
        self.register_buffer("cov", torch.eye(2))
        self.register_buffer(
            "mask", torch.tensor([[0, 1], [1, 0]] * 3, dtype=torch.float32)
        )

        self.s = torch.nn.ModuleList(
            [self.get_scale_net(self.channel) for _ in range(len(self.mask))]
        )
        self.t = torch.nn.ModuleList(
            [self.get_trans_net(self.channel) for _ in range(len(self.mask))]
        )
        self.init_weights()

    def init_weights(self):
        """Initialize model weights."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.01)

    def backward_p(self, x):
        """Apply the mapping from the data space to the latent space and calculate
        the log determinant of the Jacobian matrix."""

        log_det_jacob, z = x.new_zeros(x.shape[0]), x
        for i in reversed(range(len(self.t))):
            z_ = self.mask[i] * z
            s = self.s[i](z_) * (1 - self.mask[i])  # torch.exp(s): betas
            t = self.t[i](z_) * (1 - self.mask[i])  # gammas
            z = (1 - self.mask[i]) * (z - t) * torch.exp(-s) + z_
            log_det_jacob -= s.sum(dim=1)
        return z, log_det_jacob

    def log_prob(self, x):
        """Calculate the log probability of a given sample in data space."""

        z, log_det = self.backward_p(x)
        return self.prior.log_prob(z) + log_det


def soft_argmax(x, temperature=1.0):
    L = x.shape[1]
    assert L % 2  # L is odd to ensure symmetry
    idx = torch.arange(0, L, 1, device=x.device).repeat(x.shape[0], 1)
    scale_x = x / temperature
    out = F.softmax(scale_x, dim=1) * idx
    out = out.sum(dim=1, keepdim=True)

    return out


class FineMatching(nn.Module):
    def __init__(self, config, act_layer=nn.GELU):
        super(FineMatching, self).__init__()
        self.config = config

        self.block_dims = self.config["backbone"]["block_dims"]
        self.local_resolution = self.config["local_resolution"]
        self.drop = self.config["fine"]["droprate"]
        self.coord_length = self.config["fine"]["coord_length"]
        self.bi_directional_refine = self.config["fine"]["bi_directional_refine"]
        self.sigma_selection = self.config["fine"]["sigma_selection"]
        self.mconf_thr = self.config["coarse"]["mconf_thr"]
        self.sigma_thr = self.config["fine"]["sigma_thr"]
        self.border_rm = self.config["coarse"]["border_rm"] * self.local_resolution
        self.deploy = self.config["deploy"]

        # network
        self.query_encoder = nn.Sequential(
            Conv1d_BN_Act(
                self.block_dims[-1],
                self.block_dims[-1],
                act=act_layer(),
                drop=self.drop,
            ),
            Conv1d_BN_Act(
                self.block_dims[-1],
                self.block_dims[-1],
                act=act_layer(),
                drop=self.drop,
            ),
        )

        self.reference_encoder = nn.Sequential(
            Conv1d_BN_Act(
                self.block_dims[-1],
                self.block_dims[-1],
                act=act_layer(),
                drop=self.drop,
            ),
            Conv1d_BN_Act(
                self.block_dims[-1],
                self.block_dims[-1],
                act=act_layer(),
                drop=self.drop,
            ),
        )

        self.merge_qr = nn.Sequential(
            Conv1d_BN_Act(
                self.block_dims[-1] * 2,
                self.block_dims[-1] * 2,
                act=act_layer(),
                drop=self.drop,
            ),
            Conv1d_BN_Act(
                self.block_dims[-1] * 2,
                self.block_dims[-1] * 2,
                act=act_layer(),
                drop=self.drop,
            ),
        )

        self.x_head = nn.Sequential(
            Conv1d_BN_Act(
                self.block_dims[-1] * 2,
                self.block_dims[-1] * 2,
                act=act_layer(),
                drop=self.drop,
            ),
            nn.Conv1d(self.block_dims[-1] * 2, self.coord_length + 2, kernel_size=1),
        )

        self.y_head = nn.Sequential(
            Conv1d_BN_Act(
                self.block_dims[-1] * 2,
                self.block_dims[-1] * 2,
                act=act_layer(),
                drop=self.drop,
            ),
            nn.Conv1d(self.block_dims[-1] * 2, self.coord_length + 2, kernel_size=1),
        )

        self.flow = RealNVP()
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data={}):
        q = self.query_encoder(
            feat_f0.permute(0, 2, 1).contiguous()
            + feat_c0.permute(0, 2, 1).contiguous()
        )
        r = self.reference_encoder(
            feat_f1.permute(0, 2, 1).contiguous()
            + feat_c1.permute(0, 2, 1).contiguous()
        )
        out = self.merge_qr(torch.cat([q, r], dim=1))

        x = self.x_head(out).permute(0, 2, 1).contiguous()
        y = self.y_head(out).permute(0, 2, 1).contiguous()

        if self.bi_directional_refine:
            x01, x10 = x.chunk(2, dim=1)
            x01 = x01.reshape(-1, self.coord_length + 2)
            x10 = x10.reshape(-1, self.coord_length + 2)
            x_out = torch.cat([x01, x10])

            y01, y10 = y.chunk(2, dim=1)
            y01 = y01.reshape(-1, self.coord_length + 2)
            y10 = y10.reshape(-1, self.coord_length + 2)
            y_out = torch.cat([y01, y10])
        else:
            x_out = x.reshape(-1, self.coord_length + 2)
            y_out = y.reshape(-1, self.coord_length + 2)

        x_cls = x_out[:, : self.coord_length + 1]
        coord_x = soft_argmax(x_cls) / self.coord_length - 0.5  # range [-0.5, +0.5]
        x_sigma = x_out[:, -1:].sigmoid()

        y_cls = y_out[:, : self.coord_length + 1]
        coord_y = soft_argmax(y_cls) / self.coord_length - 0.5
        y_sigma = y_out[:, -1:].sigmoid()

        coord = torch.cat([coord_x, coord_y], dim=1)
        sigma = torch.cat([x_sigma, y_sigma], dim=1)

        if data.get("target_uv", None) is not None:
            gt_uv = data["target_uv"]
            mask = data["target_uv_weight"].clone()

            if mask.sum() == 0:
                mask[0] = True
            mask_coord = coord[mask]
            mask_gt_uv = gt_uv[mask]
            mask_sigma = sigma[mask]

            mask_sigma = torch.clamp(mask_sigma, 1e-6, 1 - 1e-6)
            bar_mu = (mask_coord - mask_gt_uv) / mask_sigma

            log_phi = self.flow.log_prob(bar_mu).unsqueeze(-1)
            nf_loss = torch.log(mask_sigma) - log_phi

            data.update(
                {
                    "pred_coord": coord,
                    "pred_score": 1.0 - torch.mean(sigma, dim=-1).flatten(),
                    "mask_coord": mask_coord,
                    "mask_sigma": mask_sigma,
                    "nf_loss": nf_loss,
                }
            )
        else:
            data.update(
                {
                    "pred_coord": coord,
                    "pred_score": 1.0 - torch.mean(sigma, dim=-1).flatten(),
                }
            )

        if not self.deploy:
            self.final_matching_selection(data)

        return data["pred_coord"], data["pred_score"]

    @torch.no_grad()
    def final_matching_selection(self, data):
        offset = data["pred_coord"] * self.local_resolution

        if self.bi_directional_refine:
            fine_offset01, fine_offset10 = torch.clamp(
                offset, -self.local_resolution / 2, self.local_resolution / 2
            ).chunk(2)
        else:
            fine_offset01 = torch.clamp(
                offset, -self.local_resolution / 2, self.local_resolution / 2
            )

        h0, w0 = data["hw0_i"]
        h1, w1 = data["hw1_i"]
        scale0 = data["scale0"][data["b_ids"]] if "scale0" in data else 1.0
        scale1 = data["scale1"][data["b_ids"]] if "scale1" in data else 1.0
        scale0_w = scale0[:, 0] if "scale0" in data else 1.0
        scale0_h = scale0[:, 1] if "scale0" in data else 1.0
        scale1_w = scale1[:, 0] if "scale1" in data else 1.0
        scale1_h = scale1[:, 1] if "scale1" in data else 1.0

        # Filter by mconf and border
        mkpts0_f = data["mkpts0_c"]
        mkpts1_f = data["mkpts1_c"] + fine_offset01 * scale1
        mask = (
            (data["mconf"] > self.mconf_thr)
            & (mkpts0_f[:, 0] >= self.border_rm)
            & (mkpts0_f[:, 0] <= w0 * scale0_w - self.border_rm)
            & (mkpts0_f[:, 1] >= self.border_rm)
            & (mkpts0_f[:, 1] <= h0 * scale0_h - self.border_rm)
            & (mkpts1_f[:, 0] >= self.border_rm)
            & (mkpts1_f[:, 0] <= w1 * scale1_w - self.border_rm)
            & (mkpts1_f[:, 1] >= self.border_rm)
            & (mkpts1_f[:, 1] <= h1 * scale1_h - self.border_rm)
        )
        if self.bi_directional_refine:
            mkpts0_f_ = data["mkpts0_c"] + fine_offset10 * scale0
            mkpts1_f_ = data["mkpts1_c"]
            mask_ = (
                (data["mconf"] > self.mconf_thr)
                & (mkpts0_f_[:, 0] >= self.border_rm)
                & (mkpts0_f_[:, 0] <= w0 * scale0_w - self.border_rm)
                & (mkpts0_f_[:, 1] >= self.border_rm)
                & (mkpts0_f_[:, 1] <= h0 * scale0_h - self.border_rm)
                & (mkpts1_f_[:, 0] >= self.border_rm)
                & (mkpts1_f_[:, 0] <= w1 * scale1_w - self.border_rm)
                & (mkpts1_f_[:, 1] >= self.border_rm)
                & (mkpts1_f_[:, 1] <= h1 * scale1_h - self.border_rm)
            )

        if self.bi_directional_refine:
            mkpts0_f = torch.cat([mkpts0_f, mkpts0_f_])
            mkpts1_f = torch.cat([mkpts1_f, mkpts1_f_])
            mask = torch.cat([mask, mask_])
            data["mconf"] = torch.cat([data["mconf"], data["mconf"]])
            data["b_ids"] = torch.cat([data["b_ids"], data["b_ids"]])

        # Filter by sigma
        if self.bi_directional_refine and self.sigma_selection:
            # Retain the more confident matching pair (smaller sigma, i.e. more
            # significant) of the two bi-directional matching pairs
            pred_score01, pred_score10 = data["pred_score"].chunk(2)
            pred_score_mask = pred_score01 > pred_score10
            pred_score_mask = torch.cat([pred_score_mask, ~pred_score_mask])
            pred_score_mask &= data["pred_score"] > self.sigma_thr
            mask &= pred_score_mask

        data.update(
            {
                # "gt_mask": data["mconf"] == 0,
                "m_bids": data["b_ids"][mask],
                "mkpts0_f": mkpts0_f[mask],
                "mkpts1_f": mkpts1_f[mask],
                "mconf": data["mconf"][mask],
            }
        )
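A small sketch (editorial, not in the commit) of the soft_argmax used above: it reduces a length-L score vector (L odd) to a continuous expected index, which the head then maps into the [-0.5, +0.5] local window by dividing by coord_length and subtracting 0.5.

import torch
from src.edm.head.fine_matching import soft_argmax

logits = torch.zeros(1, 17)  # coord_length = 16 -> L = 17 bins
logits[0, 12] = 10.0         # sharp peak at index 12
out = soft_argmax(logits)    # ~= 12.0, the expected index under softmax weights
print(out / 16 - 0.5)        # ~= +0.25 within the local window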
src/edm/neck/__init__.py ADDED
@@ -0,0 +1 @@
from .loftr_module import LocalFeatureTransformer
src/edm/neck/loftr_module/__init__.py ADDED
@@ -0,0 +1 @@
from .transformer import LocalFeatureTransformer
src/edm/neck/loftr_module/transformer.py ADDED
@@ -0,0 +1,418 @@
from torch.nn import Module
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.einops import rearrange
import math
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))


class RoPEPositionEncodingSine(nn.Module):
    """
    This is a sinusoidal position encoding generalized to 2-dimensional images
    """

    def __init__(self, d_model, max_shape=(128, 128), npe=None, ropefp16=True):
        """
        Args:
            max_shape (tuple): for a 1/32 featmap, the max length of 128 corresponds to 4096 pixels
        """
        super().__init__()

        i_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(-1)  # [H, W, 1]
        j_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(-1)  # [H, W, 1]

        assert npe is not None
        train_res_H, train_res_W, test_res_H, test_res_W = (
            npe[0],
            npe[1],
            npe[2],
            npe[3],
        )
        i_position, j_position = (
            i_position * train_res_H / test_res_H,
            j_position * train_res_W / test_res_W,
        )

        div_term = torch.exp(
            torch.arange(0, d_model // 4, 1).float()
            * (-math.log(10000.0) / (d_model // 4))
        )
        div_term = div_term[None, None, :]  # [1, 1, C//4]

        sin = torch.zeros(
            *max_shape, d_model // 2, dtype=torch.float16 if ropefp16 else torch.float32
        )
        cos = torch.zeros(
            *max_shape, d_model // 2, dtype=torch.float16 if ropefp16 else torch.float32
        )

        sin[:, :, 0::2] = (
            torch.sin(i_position * div_term).half()
            if ropefp16
            else torch.sin(i_position * div_term)
        )
        sin[:, :, 1::2] = (
            torch.sin(j_position * div_term).half()
            if ropefp16
            else torch.sin(j_position * div_term)
        )
        cos[:, :, 0::2] = (
            torch.cos(i_position * div_term).half()
            if ropefp16
            else torch.cos(i_position * div_term)
        )
        cos[:, :, 1::2] = (
            torch.cos(j_position * div_term).half()
            if ropefp16
            else torch.cos(j_position * div_term)
        )

        sin = sin.repeat_interleave(2, dim=-1)
        cos = cos.repeat_interleave(2, dim=-1)

        self.register_buffer("sin", sin.unsqueeze(0), persistent=False)  # [1, H, W, C]
        self.register_buffer("cos", cos.unsqueeze(0), persistent=False)  # [1, H, W, C]

    def forward(self, x, ratio=1):
        """
        Args:
            x: [N, H, W, C]
        """
        return (x * self.cos[:, : x.size(1), : x.size(2), :]) + (
            self.rotate_half(x) * self.sin[:, : x.size(1), : x.size(2), :]
        )

    def rotate_half(self, x):
        # x = x.unflatten(-1, (-1, 2))
        a, b, c, d = x.shape
        x = x.reshape(a, b, c, d // 2, 2)

        x1, x2 = x.unbind(dim=-1)
        return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)


"""
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
"""


def crop_feature(query, key, value, x_mask, source_mask):
    mask_h0, mask_w0, mask_h1, mask_w1 = (
        x_mask[0].sum(-2)[0],
        x_mask[0].sum(-1)[0],
        source_mask[0].sum(-2)[0],
        source_mask[0].sum(-1)[0],
    )
    query = query[:, :mask_h0, :mask_w0, :]
    key = key[:, :mask_h1, :mask_w1, :]
    value = value[:, :mask_h1, :mask_w1, :]
    return query, key, value, mask_h0, mask_w0

(rendering truncated here; 418 lines were added to this file in total)
return query, key, value, mask_h0, mask_w0
|
122 |
+
|
123 |
+
|
124 |
+
def pad_feature(m, mask_h0, mask_w0, x_mask):
|
125 |
+
bs, hw, nhead, dim = m.shape
|
126 |
+
m = m.view(bs, mask_h0, mask_w0, nhead, dim)
|
127 |
+
if mask_h0 != x_mask.size(-2):
|
128 |
+
m = torch.cat(
|
129 |
+
[
|
130 |
+
m,
|
131 |
+
torch.zeros(
|
132 |
+
m.size(0),
|
133 |
+
x_mask.size(-2) - mask_h0,
|
134 |
+
x_mask.size(-1),
|
135 |
+
nhead,
|
136 |
+
dim,
|
137 |
+
device=m.device,
|
138 |
+
dtype=m.dtype,
|
139 |
+
),
|
140 |
+
],
|
141 |
+
dim=1,
|
142 |
+
)
|
143 |
+
elif mask_w0 != x_mask.size(-1):
|
144 |
+
m = torch.cat(
|
145 |
+
[
|
146 |
+
m,
|
147 |
+
torch.zeros(
|
148 |
+
m.size(0),
|
149 |
+
x_mask.size(-2),
|
150 |
+
x_mask.size(-1) - mask_w0,
|
151 |
+
nhead,
|
152 |
+
dim,
|
153 |
+
device=m.device,
|
154 |
+
dtype=m.dtype,
|
155 |
+
),
|
156 |
+
],
|
157 |
+
dim=2,
|
158 |
+
)
|
159 |
+
return m
|
160 |
+
|
161 |
+
|
162 |
+
class Attention(Module):
|
163 |
+
def __init__(self, nhead=8, dim=256, re=False):
|
164 |
+
super().__init__()
|
165 |
+
|
166 |
+
self.nhead = nhead
|
167 |
+
self.dim = dim
|
168 |
+
|
169 |
+
def attention(self, query, key, value, q_mask=None, kv_mask=None):
|
170 |
+
assert (
|
171 |
+
q_mask is None and kv_mask is None
|
172 |
+
), "Not support generalized attention mask yet."
|
173 |
+
# Scaled Cosine Attention
|
174 |
+
# Refer to "Query-key normalization for transformers" and "https://kexue.fm/archives/9859"
|
175 |
+
query = F.normalize(query, p=2, dim=3)
|
176 |
+
key = F.normalize(key, p=2, dim=3)
|
177 |
+
QK = torch.einsum("nlhd,nshd->nlsh", query, key)
|
178 |
+
s = 20.0
|
179 |
+
A = torch.softmax(s * QK, dim=2)
|
180 |
+
|
181 |
+
out = torch.einsum("nlsh,nshd->nlhd", A, value)
|
182 |
+
return out
|
183 |
+
|
184 |
+
def _forward(self, query, key, value, q_mask=None, kv_mask=None):
|
185 |
+
if q_mask is not None:
|
186 |
+
query, key, value, mask_h0, mask_w0 = crop_feature(
|
187 |
+
query, key, value, q_mask, kv_mask
|
188 |
+
)
|
189 |
+
|
190 |
+
query, key, value = map(
|
191 |
+
lambda x: rearrange(
|
192 |
+
x,
|
193 |
+
"n h w (nhead d) -> n (h w) nhead d",
|
194 |
+
nhead=self.nhead,
|
195 |
+
d=self.dim,
|
196 |
+
),
|
197 |
+
[query, key, value],
|
198 |
+
)
|
199 |
+
|
200 |
+
m = self.attention(query, key, value, q_mask=None, kv_mask=None)
|
201 |
+
|
202 |
+
if q_mask is not None:
|
203 |
+
m = pad_feature(m, mask_h0, mask_w0, q_mask)
|
204 |
+
|
205 |
+
return m
|
206 |
+
|
207 |
+
def forward(self, query, key, value, q_mask=None, kv_mask=None):
|
208 |
+
"""
|
209 |
+
Args:
|
210 |
+
queries: [N, L, H, D]
|
211 |
+
keys: [N, S, H, D]
|
212 |
+
values: [N, S, H, D]
|
213 |
+
q_mask: [N, L]
|
214 |
+
kv_mask: [N, S]
|
215 |
+
Returns:
|
216 |
+
queried_values: (N, L, H, D)
|
217 |
+
"""
|
218 |
+
bs = query.size(0)
|
219 |
+
if bs == 1 or q_mask is None:
|
220 |
+
m = self._forward(query, key, value,
|
221 |
+
q_mask=q_mask, kv_mask=kv_mask)
|
222 |
+
else: # for faster trainning with padding mask while batch size > 1
|
223 |
+
m_list = []
|
224 |
+
for i in range(bs):
|
225 |
+
m_list.append(
|
226 |
+
self._forward(
|
227 |
+
query[i: i + 1],
|
228 |
+
key[i: i + 1],
|
229 |
+
value[i: i + 1],
|
230 |
+
q_mask=q_mask[i: i + 1],
|
231 |
+
kv_mask=kv_mask[i: i + 1],
|
232 |
+
)
|
233 |
+
)
|
234 |
+
m = torch.cat(m_list, dim=0)
|
235 |
+
return m
|
236 |
+
|
237 |
+
|
238 |
+
class AG_RoPE_EncoderLayer(nn.Module):
|
239 |
+
def __init__(
|
240 |
+
self,
|
241 |
+
d_model,
|
242 |
+
nhead,
|
243 |
+
agg_size0=2,
|
244 |
+
agg_size1=2,
|
245 |
+
rope=False,
|
246 |
+
npe=None,
|
247 |
+
):
|
248 |
+
super(AG_RoPE_EncoderLayer, self).__init__()
|
249 |
+
self.dim = d_model // nhead
|
250 |
+
self.nhead = nhead
|
251 |
+
self.agg_size0, self.agg_size1 = agg_size0, agg_size1
|
252 |
+
self.rope = rope
|
253 |
+
|
254 |
+
# aggregate and position encoding
|
255 |
+
self.aggregate = (
|
256 |
+
nn.Conv2d(
|
257 |
+
d_model,
|
258 |
+
d_model,
|
259 |
+
kernel_size=agg_size0,
|
260 |
+
padding=0,
|
261 |
+
stride=agg_size0,
|
262 |
+
bias=False,
|
263 |
+
groups=d_model,
|
264 |
+
)
|
265 |
+
if self.agg_size0 != 1
|
266 |
+
else nn.Identity()
|
267 |
+
)
|
268 |
+
self.max_pool = (
|
269 |
+
torch.nn.MaxPool2d(kernel_size=self.agg_size1,
|
270 |
+
stride=self.agg_size1)
|
271 |
+
if self.agg_size1 != 1
|
272 |
+
else nn.Identity()
|
273 |
+
)
|
274 |
+
self.mask_max_pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
|
275 |
+
self.rope_pos_enc = RoPEPositionEncodingSine(
|
276 |
+
d_model, max_shape=(128, 128), npe=npe, ropefp16=True
|
277 |
+
)
|
278 |
+
|
279 |
+
# multi-head attention
|
280 |
+
self.q_proj = nn.Linear(d_model, d_model, bias=False)
|
281 |
+
self.k_proj = nn.Linear(d_model, d_model, bias=False)
|
282 |
+
self.v_proj = nn.Linear(d_model, d_model, bias=False)
|
283 |
+
self.attention = Attention(self.nhead, self.dim)
|
284 |
+
self.merge = nn.Linear(d_model, d_model, bias=False)
|
285 |
+
|
286 |
+
# feed-forward network
|
287 |
+
self.mlp = nn.Sequential(
|
288 |
+
nn.Linear(d_model * 2, d_model * 2, bias=False),
|
289 |
+
nn.LeakyReLU(inplace=True),
|
290 |
+
nn.Linear(d_model * 2, d_model, bias=False),
|
291 |
+
)
|
292 |
+
|
293 |
+
# norm
|
294 |
+
self.norm1 = nn.LayerNorm(d_model)
|
295 |
+
self.norm2 = nn.LayerNorm(d_model)
|
296 |
+
|
297 |
+
def forward(self, x, source, x_mask=None, source_mask=None):
|
298 |
+
"""
|
299 |
+
Args:
|
300 |
+
x (torch.Tensor): [N, C, H0, W0]
|
301 |
+
source (torch.Tensor): [N, C, H1, W1]
|
302 |
+
x_mask (torch.Tensor): [N, H0, W0] (optional) (L = H0*W0)
|
303 |
+
source_mask (torch.Tensor): [N, H1, W1] (optional) (S = H1*W1)
|
304 |
+
"""
|
305 |
+
bs, C, H0, W0 = x.size()
|
306 |
+
H1, W1 = source.size(-2), source.size(-1)
|
307 |
+
|
308 |
+
# Aggragate feature
|
309 |
+
# assert x_mask is None and source_mask is None
|
310 |
+
|
311 |
+
query, source = self.norm1(self.aggregate(x).permute(0, 2, 3, 1)), self.norm1(
|
312 |
+
self.max_pool(source).permute(0, 2, 3, 1)
|
313 |
+
) # [N, H, W, C]
|
314 |
+
if x_mask is not None:
|
315 |
+
# mask 1/8 to 1/32
|
316 |
+
x_mask, source_mask = map(
|
317 |
+
lambda x: self.mask_max_pool(
|
318 |
+
self.mask_max_pool(x.float())).bool(),
|
319 |
+
[x_mask, source_mask],
|
320 |
+
)
|
321 |
+
query, key, value = self.q_proj(
|
322 |
+
query), self.k_proj(source), self.v_proj(source)
|
323 |
+
|
324 |
+
# Positional encoding
|
325 |
+
if self.rope:
|
326 |
+
query = self.rope_pos_enc(query)
|
327 |
+
key = self.rope_pos_enc(key)
|
328 |
+
|
329 |
+
# multi-head attention handle padding mask
|
330 |
+
m = self.attention(query, key, value, q_mask=x_mask,
|
331 |
+
kv_mask=source_mask)
|
332 |
+
m = self.merge(m.reshape(bs, -1, self.nhead * self.dim)) # [N, L, C]
|
333 |
+
|
334 |
+
# Upsample feature
|
335 |
+
m = rearrange(
|
336 |
+
m, "b (h w) c -> b c h w", h=H0 // self.agg_size0, w=W0 // self.agg_size0
|
337 |
+
) # [N, C, H0, W0]
|
338 |
+
|
339 |
+
if self.agg_size0 != 1:
|
340 |
+
m = torch.nn.functional.interpolate(
|
341 |
+
m, size=(H0, W0), mode="bilinear", align_corners=False
|
342 |
+
) # [N, C, H0, W0]
|
343 |
+
|
344 |
+
# feed-forward network
|
345 |
+
m = self.mlp(torch.cat([x, m], dim=1).permute(
|
346 |
+
0, 2, 3, 1)) # [N, H0, W0, C]
|
347 |
+
m = self.norm2(m).permute(0, 3, 1, 2) # [N, C, H0, W0]
|
348 |
+
|
349 |
+
return x + m
|
350 |
+
|
351 |
+
|
352 |
+
'''
|
353 |
+
Modified from EfficientLoFTR
|
354 |
+
'''
|
355 |
+
class LocalFeatureTransformer(nn.Module):
|
356 |
+
"""A Local Feature Transformer (LoFTR) module."""
|
357 |
+
|
358 |
+
def __init__(self, config):
|
359 |
+
super(LocalFeatureTransformer, self).__init__()
|
360 |
+
self.d_model = config["d_model"]
|
361 |
+
self.nhead = config["nhead"]
|
362 |
+
self.layer_names = config["layer_names"]
|
363 |
+
self.agg_size0, self.agg_size1 = config["agg_size0"], config["agg_size1"]
|
364 |
+
self.rope = config["rope"]
|
365 |
+
|
366 |
+
self_layer = AG_RoPE_EncoderLayer(
|
367 |
+
config["d_model"],
|
368 |
+
config["nhead"],
|
369 |
+
config["agg_size0"],
|
370 |
+
config["agg_size1"],
|
371 |
+
config["rope"],
|
372 |
+
config["npe"],
|
373 |
+
)
|
374 |
+
cross_layer = AG_RoPE_EncoderLayer(
|
375 |
+
config["d_model"],
|
376 |
+
config["nhead"],
|
377 |
+
config["agg_size0"],
|
378 |
+
config["agg_size1"],
|
379 |
+
False,
|
380 |
+
config["npe"],
|
381 |
+
)
|
382 |
+
|
383 |
+
self.layers = nn.ModuleList(
|
384 |
+
[
|
385 |
+
(
|
386 |
+
copy.deepcopy(self_layer)
|
387 |
+
if _ == "self"
|
388 |
+
else copy.deepcopy(cross_layer)
|
389 |
+
)
|
390 |
+
for _ in self.layer_names
|
391 |
+
]
|
392 |
+
)
|
393 |
+
self._reset_parameters()
|
394 |
+
|
395 |
+
def _reset_parameters(self):
|
396 |
+
for p in self.parameters():
|
397 |
+
if p.dim() > 1:
|
398 |
+
nn.init.xavier_uniform_(p)
|
399 |
+
|
400 |
+
def forward(self, feat0, feat1, mask0=None, mask1=None, data=None):
|
401 |
+
"""
|
402 |
+
Args:
|
403 |
+
feat0 (torch.Tensor): [N, C, H, W]
|
404 |
+
feat1 (torch.Tensor): [N, C, H, W]
|
405 |
+
mask0 (torch.Tensor): [N, L] (optional)
|
406 |
+
mask1 (torch.Tensor): [N, S] (optional)
|
407 |
+
"""
|
408 |
+
for i, (layer, name) in enumerate(zip(self.layers, self.layer_names)):
|
409 |
+
if name == "self":
|
410 |
+
feat0 = layer(feat0, feat0, mask0, mask0)
|
411 |
+
feat1 = layer(feat1, feat1, mask1, mask1)
|
412 |
+
elif name == "cross":
|
413 |
+
feat0 = layer(feat0, feat1, mask0, mask1)
|
414 |
+
feat1 = layer(feat1, feat0, mask1, mask0)
|
415 |
+
else:
|
416 |
+
raise KeyError
|
417 |
+
|
418 |
+
return feat0, feat1
|
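To illustrate the scaled cosine attention used in Attention.attention() above, here is a minimal sketch with assumed toy shapes:

import torch
import torch.nn.functional as F

# Queries and keys are L2-normalized per head, so QK is a cosine similarity in
# [-1, 1]; a fixed temperature s = 20.0 then sharpens the softmax, as in the layer above.
q = torch.randn(1, 4, 2, 8)  # [N, L, H, D], toy sizes
k = torch.randn(1, 5, 2, 8)  # [N, S, H, D]
v = torch.randn(1, 5, 2, 8)

q, k = F.normalize(q, p=2, dim=3), F.normalize(k, p=2, dim=3)
sim = torch.einsum("nlhd,nshd->nlsh", q, k)  # cosine similarities
attn = torch.softmax(20.0 * sim, dim=2)      # attention over the S dimension
out = torch.einsum("nlsh,nshd->nlhd", attn, v)
print(out.shape)  # torch.Size([1, 4, 2, 8])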
src/edm/neck/neck.py
ADDED
@@ -0,0 +1,156 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .loftr_module.transformer import LocalFeatureTransformer
+
+
+class Conv2d_BN_Act(nn.Sequential):
+    def __init__(
+        self,
+        a,
+        b,
+        ks=1,
+        stride=1,
+        pad=0,
+        dilation=1,
+        groups=1,
+        bn_weight_init=1,
+        act=None,
+        drop=None,
+    ):
+        super().__init__()
+        self.inp_channel = a
+        self.out_channel = b
+        self.ks = ks
+        self.pad = pad
+        self.stride = stride
+        self.dilation = dilation
+        self.groups = groups
+
+        self.add_module(
+            "c", nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)
+        )
+        bn = nn.BatchNorm2d(b)
+        nn.init.constant_(bn.weight, bn_weight_init)
+        nn.init.constant_(bn.bias, 0)
+        self.add_module("bn", bn)
+        if act is not None:
+            self.add_module("a", act)
+        if drop is not None:
+            self.add_module("d", nn.Dropout(drop))
+
+
+class CIM(nn.Module):
+    """Feature Aggregation, Correlation Injection Module"""
+
+    def __init__(self, config):
+        super(CIM, self).__init__()
+
+        self.block_dims = config["backbone"]["block_dims"]
+        self.drop = config["fine"]["droprate"]
+
+        self.fc32 = Conv2d_BN_Act(
+            self.block_dims[-1], self.block_dims[-1], 1, drop=self.drop
+        )
+        self.fc16 = Conv2d_BN_Act(
+            self.block_dims[-2], self.block_dims[-1], 1, drop=self.drop
+        )
+        self.fc8 = Conv2d_BN_Act(
+            self.block_dims[-3], self.block_dims[-1], 1, drop=self.drop
+        )
+        self.att32 = Conv2d_BN_Act(
+            self.block_dims[-1],
+            self.block_dims[-1],
+            1,
+            act=nn.Sigmoid(),
+            drop=self.drop,
+        )
+        self.att16 = Conv2d_BN_Act(
+            self.block_dims[-1],
+            self.block_dims[-1],
+            1,
+            act=nn.Sigmoid(),
+            drop=self.drop,
+        )
+        self.dwconv16 = nn.Sequential(
+            Conv2d_BN_Act(
+                self.block_dims[-1],
+                self.block_dims[-1],
+                ks=3,
+                pad=1,
+                groups=self.block_dims[-1],
+                act=nn.GELU(),
+            ),
+            Conv2d_BN_Act(self.block_dims[-1], self.block_dims[-1], 1),
+        )
+        self.dwconv8 = nn.Sequential(
+            Conv2d_BN_Act(
+                self.block_dims[-1],
+                self.block_dims[-1],
+                ks=3,
+                pad=1,
+                groups=self.block_dims[-1],
+                act=nn.GELU(),
+            ),
+            Conv2d_BN_Act(self.block_dims[-1], self.block_dims[-1], 1),
+        )
+
+        self.loftr_32 = LocalFeatureTransformer(config["neck"])
+
+    def forward(self, ms_feats, mask_c0=None, mask_c1=None):
+        if len(ms_feats) == 3:  # both images share the same shape
+            f8, f16, f32 = ms_feats
+            f32 = self.fc32(f32)
+
+            f32_0, f32_1 = f32.chunk(2, dim=0)
+            f32_0, f32_1 = self.loftr_32(f32_0, f32_1, mask_c0, mask_c1)
+            f32 = torch.cat([f32_0, f32_1], dim=0)
+
+            f32_up = F.interpolate(f32, scale_factor=2.0, mode="bilinear")
+            att32_up = F.interpolate(self.att32(f32), scale_factor=2.0, mode="bilinear")
+            f16 = self.fc16(f16)
+            f16 = self.dwconv16(f16 * att32_up + f32_up)
+            f16_up = F.interpolate(f16, scale_factor=2.0, mode="bilinear")
+            att16_up = F.interpolate(self.att16(f16), scale_factor=2.0, mode="bilinear")
+            f8 = self.fc8(f8)
+            f8 = self.dwconv8(f8 * att16_up + f16_up)
+
+            feat_c0, feat_c1 = f8.chunk(2)
+
+        elif len(ms_feats) == 6:  # the two images have different shapes
+            f8_0, f16_0, f32_0, f8_1, f16_1, f32_1 = ms_feats
+            f32_0 = self.fc32(f32_0)
+            f32_1 = self.fc32(f32_1)
+
+            f32_0, f32_1 = self.loftr_32(f32_0, f32_1, mask_c0, mask_c1)
+
+            f8, f16, f32 = f8_0, f16_0, f32_0
+            f32_up = F.interpolate(f32, scale_factor=2.0, mode="bilinear")
+            att32_up = F.interpolate(self.att32(f32), scale_factor=2.0, mode="bilinear")
+            f16 = self.fc16(f16)
+            f16 = self.dwconv16(f16 * att32_up + f32_up)
+            f16_up = F.interpolate(f16, scale_factor=2.0, mode="bilinear")
+            att16_up = F.interpolate(self.att16(f16), scale_factor=2.0, mode="bilinear")
+            f8 = self.fc8(f8)
+            f8 = self.dwconv8(f8 * att16_up + f16_up)
+            feat_c0 = f8
+
+            f8, f16, f32 = f8_1, f16_1, f32_1
+            f32_up = F.interpolate(f32, scale_factor=2.0, mode="bilinear")
+            att32_up = F.interpolate(self.att32(f32), scale_factor=2.0, mode="bilinear")
+            f16 = self.fc16(f16)
+            f16 = self.dwconv16(f16 * att32_up + f32_up)
+            f16_up = F.interpolate(f16, scale_factor=2.0, mode="bilinear")
+            att16_up = F.interpolate(self.att16(f16), scale_factor=2.0, mode="bilinear")
+            f8 = self.fc8(f8)
+            f8 = self.dwconv8(f8 * att16_up + f16_up)
+            feat_c1 = f8
+
+        return feat_c0, feat_c1
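A minimal sketch of one injection step in CIM.forward(): a sigmoid attention map from the deeper level gates the shallower features before the upsampled deep features are added. The channel count and spatial sizes here are assumptions:

import torch
import torch.nn.functional as F

f32 = torch.randn(2, 64, 8, 8)    # deep features at 1/32 resolution
f16 = torch.randn(2, 64, 16, 16)  # shallower features at 1/16 resolution

att32 = torch.sigmoid(torch.randn(2, 64, 8, 8))  # stand-in for self.att32(f32)
f32_up = F.interpolate(f32, scale_factor=2.0, mode="bilinear")
att32_up = F.interpolate(att32, scale_factor=2.0, mode="bilinear")
fused16 = f16 * att32_up + f32_up  # what self.dwconv16 then smooths
print(fused16.shape)  # torch.Size([2, 64, 16, 16])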
src/utils/misc.py
ADDED
@@ -0,0 +1,103 @@
+import os
+import contextlib
+import joblib
+from typing import Union
+from loguru import _Logger, logger
+from itertools import chain
+
+import torch
+from yacs.config import CfgNode as CN
+
+
+def lower_config(yacs_cfg):
+    if not isinstance(yacs_cfg, CN):
+        return yacs_cfg
+    return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
+
+
+def upper_config(dict_cfg):
+    if not isinstance(dict_cfg, dict):
+        return dict_cfg
+    return {k.upper(): upper_config(v) for k, v in dict_cfg.items()}
+
+
+def log_on(condition, message, level):
+    if condition:
+        assert level in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"]
+        logger.log(level, message)
+
+
+def setup_gpus(gpus: Union[str, int]) -> int:
+    """A temporary fix for pytorch-lightning 1.3.x"""
+    gpus = str(gpus)
+    gpu_ids = []
+
+    if "," not in gpus:
+        n_gpus = int(gpus)
+        return n_gpus if n_gpus != -1 else torch.cuda.device_count()
+    else:
+        gpu_ids = [i.strip() for i in gpus.split(",") if i != ""]
+
+        # set up environment variables
+        visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
+        if visible_devices is None:
+            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
+            visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
+            logger.warning(
+                f"[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}"
+            )
+        else:
+            logger.warning(
+                "[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process."
+            )
+        return len(gpu_ids)
+
+
+def flattenList(x):
+    return list(chain(*x))
+
+
+@contextlib.contextmanager
+def tqdm_joblib(tqdm_object):
+    """Context manager that patches joblib to report progress into the given tqdm bar.
+
+    Usage:
+        with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
+            Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
+
+    When iterating over a generator, using tqdm directly is also a solution (but it
+    monitors task queuing instead of task completion):
+        ret_vals = Parallel(n_jobs=args.world_size)(
+            delayed(lambda x: _compute_cov_score(pid, *x))(param)
+            for param in tqdm(combinations(image_ids, 2),
+                              desc=f'Computing cov_score of [{pid}]',
+                              total=len(image_ids)*(len(image_ids)-1)/2))
+    Src: https://stackoverflow.com/a/58936697
+    """
+
+    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        def __call__(self, *args, **kwargs):
+            tqdm_object.update(n=self.batch_size)
+            return super().__call__(*args, **kwargs)
+
+    old_batch_callback = joblib.parallel.BatchCompletionCallBack
+    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
+    try:
+        yield tqdm_object
+    finally:
+        joblib.parallel.BatchCompletionCallBack = old_batch_callback
+        tqdm_object.close()
+
+
+def detect_NaN(feat_0, feat_1):
+    logger.info("NaN detected in features")
+    logger.info(
+        f"#NaN in feat_0: {torch.isnan(feat_0).int().sum()}, #NaN in feat_1: {torch.isnan(feat_1).int().sum()}"
+    )
+    feat_0[torch.isnan(feat_0)] = 0
+    feat_1[torch.isnan(feat_1)] = 0
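A short sketch of how lower_config above is used: a nested yacs CfgNode becomes a plain dict with lower-cased keys, which is how the model code indexes config["d_model"], config["neck"], and so on. The config key and value here are illustrative:

from yacs.config import CfgNode as CN

cfg = CN()
cfg.NECK = CN()
cfg.NECK.D_MODEL = 256  # illustrative value

print(lower_config(cfg))  # {'neck': {'d_model': 256}}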
src/utils/plotting.py
ADDED
@@ -0,0 +1,219 @@
+import bisect
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib
+import cv2
+
+
+def _compute_conf_thresh(data):
+    dataset_name = data["dataset_name"][0].lower()
+    if dataset_name == "scannet":
+        thr = 5e-4
+    elif dataset_name == "megadepth":
+        thr = 1e-4
+    else:
+        raise ValueError(f"Unknown dataset: {dataset_name}")
+    return thr
+
+
+# --- VISUALIZATION --- #
+
+
+def make_matching_figure(
+    img0,
+    img1,
+    mkpts0,
+    mkpts1,
+    color,
+    kpts0=None,
+    kpts1=None,
+    text=[],
+    path=None,
+):
+    """
+    Draw a match visualization with OpenCV.
+
+    Args:
+        img0: first image (BGR)
+        img1: second image (BGR)
+        mkpts0: matched keypoints in the first image (Nx2 array)
+        mkpts1: matched keypoints in the second image (Nx2 array)
+        color: color of each match
+        kpts0: all keypoints in the first image (optional)
+        kpts1: all keypoints in the second image (optional)
+        text: text lines to draw (optional)
+        path: path to save the image to (optional)
+
+    Returns:
+        The rendered OpenCV image (BGR).
+    """
+    # The two sets of matched keypoints must have the same length
+    assert mkpts0.shape[0] == mkpts1.shape[0], \
+        f"mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}"
+
+    # The canvas is as tall as the taller image
+    h0, w0 = img0.shape[:2]
+    h1, w1 = img1.shape[:2]
+    max_height = max(h0, h1)
+
+    # Create a canvas with the two images side by side
+    canvas = np.ones((max_height, w0 + w1, 3), dtype=np.uint8) * 255
+    canvas[:h0, :w0] = img0
+    canvas[:h1, w0:w0 + w1] = img1
+
+    # Draw all keypoints (if provided)
+    if kpts0 is not None and kpts1 is not None:
+        for (x, y) in kpts0.astype(np.int32):
+            cv2.circle(canvas, (x, y), 1, (255, 255, 255), -1)
+        for (x, y) in kpts1.astype(np.int32):
+            cv2.circle(canvas, (x + w0, y), 1, (255, 255, 255), -1)
+
+    # Draw matches and their connecting lines
+    if mkpts0.shape[0] > 0 and mkpts1.shape[0] > 0:
+        # Convert to integer coordinates
+        mkpts0_int = mkpts0.astype(np.int32)
+        mkpts1_int = mkpts1.astype(np.int32)
+
+        # Connecting lines
+        for i in range(len(mkpts0_int)):
+            x0, y0 = mkpts0_int[i]
+            x1, y1 = mkpts1_int[i]
+            # x in the second image is offset by the width of the first
+            x1 += w0
+
+            # Scale the color from the 0-1 range to 0-255
+            line_color = tuple(int(c * 255) for c in color[i][:3])
+            # Swap to BGR if needed (OpenCV uses BGR)
+            # line_color = (line_color[2], line_color[1], line_color[0])
+            cv2.line(canvas, (x0, y0), (x1, y1), line_color, 1)
+
+        # Match points
+        for i in range(len(mkpts0_int)):
+            x0, y0 = mkpts0_int[i]
+            x1, y1 = mkpts1_int[i]
+            x1 += w0
+
+            pt_color = tuple(int(c * 255) for c in color[i][:3])
+            # pt_color = (pt_color[2], pt_color[1], pt_color[0])
+            cv2.circle(canvas, (x0, y0), 2, pt_color, -1)
+            cv2.circle(canvas, (x1, y1), 2, pt_color, -1)
+
+    # Add text
+    if text:
+        # Pick the text color based on the image brightness
+        roi = img0[:100, :200] if h0 > 100 and w0 > 200 else img0
+        brightness = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY).mean()
+        text_color = (0, 0, 0) if brightness > 200 else (255, 255, 255)
+
+        y_pos = 30
+        for i, line in enumerate(text):
+            cv2.putText(
+                canvas, line, (10, y_pos + i * 30),
+                cv2.FONT_HERSHEY_SIMPLEX, 0.8, text_color, 2
+            )
+
+    # Save the image if a path is given
+    if path:
+        cv2.imwrite(path, canvas)
+
+    return canvas
+
+
+def _make_evaluation_figure(data, b_id, alpha="dynamic"):
+    b_mask = data["m_bids"] == b_id
+    conf_thr = _compute_conf_thresh(data)
+
+    img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+    img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+    kpts0 = data["mkpts0_f"][b_mask].clone().detach().cpu().numpy()
+    kpts1 = data["mkpts1_f"][b_mask].clone().detach().cpu().numpy()
+
+    # for MegaDepth, we visualize matches on the resized image
+    if "scale0" in data:
+        kpts0 = kpts0 / data["scale0"][b_id].cpu().numpy()[[1, 0]]
+        kpts1 = kpts1 / data["scale1"][b_id].cpu().numpy()[[1, 0]]
+
+    epi_errs = data["epi_errs"][b_mask].cpu().numpy()
+    correct_mask = epi_errs < conf_thr
+    precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
+    n_correct = np.sum(correct_mask)
+    n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu())
+    recall = 0 if n_gt_matches == 0 else n_correct / n_gt_matches
+    # Recall might be larger than 1, since conf_matrix_gt is computed from
+    # ground-truth depths and camera poses, while the epipolar distance is used here.
+
+    # matching info
+    if alpha == "dynamic":
+        alpha = dynamic_alpha(len(correct_mask))
+    color = error_colormap(epi_errs, conf_thr, alpha=alpha)
+
+    text = [
+        f"#Matches {len(kpts0)}",
+        f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}",
+        f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}",
+    ]
+
+    # make the figure
+    figure = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)
+    return figure
+
+
+def _make_confidence_figure(data, b_id):
+    # TODO: Implement confidence figure
+    raise NotImplementedError()
+
+
+def make_matching_figures(data, config, mode="evaluation"):
+    """Make matching figures for a batch.
+
+    Args:
+        data (Dict): a batch updated by PL_LoFTR.
+        config (Dict): matcher config
+    Returns:
+        figures (Dict[str, List[plt.figure]])
+    """
+    assert mode in ["evaluation", "confidence", "gt"]
+    figures = {mode: []}
+    for b_id in range(data["image0"].size(0)):
+        if mode == "evaluation":
+            fig = _make_evaluation_figure(
+                data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA
+            )
+        elif mode == "confidence":
+            fig = _make_confidence_figure(data, b_id)
+        else:
+            raise ValueError(f"Unknown plot mode: {mode}")
+        figures[mode].append(fig)
+    return figures
+
+
+def dynamic_alpha(
+    n_matches, milestones=[0, 300, 1000, 2000], alphas=[1.0, 0.8, 0.4, 0.2]
+):
+    if n_matches == 0:
+        return 1.0
+    ranges = list(zip(alphas, alphas[1:] + [None]))
+    loc = bisect.bisect_right(milestones, n_matches) - 1
+    _range = ranges[loc]
+    if _range[1] is None:
+        return _range[0]
+    return _range[1] + (milestones[loc + 1] - n_matches) / (
+        milestones[loc + 1] - milestones[loc]
+    ) * (_range[0] - _range[1])
+
+
+def error_colormap(err, thr, alpha=1.0):
+    assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
+    x = 1 - np.clip(err / (thr * 2), 0, 1)
+    return np.clip(
+        np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1),
+        0,
+        1,
+    )
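Finally, a small usage sketch for dynamic_alpha and error_colormap; the error values and threshold here are illustrative:

import numpy as np

epi_errs = np.array([0.0, 1e-4, 5e-4])  # assumed epipolar errors
alpha = dynamic_alpha(len(epi_errs))
colors = error_colormap(epi_errs, thr=1e-4, alpha=alpha)
# Errors near 0 map to green, errors near thr to yellow, and errors >= 2*thr to red;
# the resulting RGBA rows color the match lines in make_matching_figure.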