Matthew Wiesner committed
Commit 33dbb99 · 1 Parent(s): 8c9cdc2

Initial commit (without model)

Files changed (3)
  1. app.py +194 -0
  2. packages.txt +1 -0
  3. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,194 @@
+ import torch
+ import time
+ import lhotse
+ from lhotse import CutSet
+ import numpy as np
+ import os
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2ForPreTraining
+ import gradio as gr
+ import geoviews as gv
+ import geoviews.tile_sources as gts
+ import uuid
+ import math
+ import torch.nn as nn
+
+
+ device = torch.device("cpu")
+
+ class AttentionPool(nn.Module):
+     def __init__(self, att, query_embed):
+         super(AttentionPool, self).__init__()
+         self.query_embed = query_embed
+         self.att = att
+
+     def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
+         # Build a binary padding mask from the per-utterance lengths
+         max_seq_length = x_lens.max().item()
+         mask = torch.arange(max_seq_length)[None, :].to(x.device) >= x_lens[:, None]
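+         # (Added note, not part of the original commit.) True marks padded frames,
+         # which nn.MultiheadAttention's key_padding_mask ignores; e.g. for x_lens = [3, 5]:
+         #   [[False, False, False,  True,  True],
+         #    [False, False, False, False, False]]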
+
+         # Pool the frame sequence with a single learned query: the query embedding
+         # attends over the (batch_size, max_seq_length, odim) frames, and the
+         # key_padding_mask of shape (batch_size, max_seq_length) excludes padding
+         x, w = self.att(
+             self.query_embed.unsqueeze(0).unsqueeze(1).repeat(x.size(0), 1, 1),
+             x,
+             x,
+             key_padding_mask=mask
+         )
+         x = x.squeeze(1)
+         return x, w
+
+
+ class AveragePool(nn.Module):
+     def __init__(self):
+         super(AveragePool, self).__init__()
+
+     def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
+         # Mark padded frames with NaN, then average over the valid frames only
+         max_seq_length = x_lens.max().item()
+         mask = torch.arange(max_seq_length)[None, :].to(x.device) >= x_lens[:, None]
+         x[mask] = torch.nan
+         return x.nanmean(dim=1), None
+
+
+ class Wav2Vec2Model(nn.Module):
+     def __init__(self,
+                  modelpath='facebook/mms-300m',
+                  freeze_feat_extractor=True,
+                  pooling_loc=0,
+                  pooling_type='att',
+                  ):
+         super(Wav2Vec2Model, self).__init__()
+         try:
+             self.encoder = Wav2Vec2ForCTC.from_pretrained(modelpath).wav2vec2
+         except Exception:
+             self.encoder = Wav2Vec2ForPreTraining.from_pretrained(modelpath).wav2vec2
+
+         if freeze_feat_extractor:
+             self.encoder.feature_extractor._freeze_parameters()
+         self.freeze_feat_extractor = freeze_feat_extractor
+         self.odim = self._get_output_dim()
+
+         self.frozen = False
+         if pooling_type == 'att':
+             assert pooling_loc == 0
+             self.att = nn.MultiheadAttention(self.odim, 1, batch_first=True)
+             self.loc_embed = nn.Parameter(
+                 torch.FloatTensor(self.odim).uniform_(-1, 1)
+             )
+             self.pooling = AttentionPool(self.att, self.loc_embed)
+         elif pooling_type == 'avg':
+             self.pooling = AveragePool()
+         self.pooling_type = pooling_type
+         # pooling_loc: 0 = pool the encoder embeddings, 1 = pool unnormalized coords,
+         # 2 = pool normalized coords
+         self.pooling_loc = pooling_loc
+         self.linear_out = nn.Linear(self.odim, 3)
+
+     def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
+         x = self.encoder(
+             x.squeeze(-1), output_hidden_states=False
+         )[0]
+
+         # Track the sequence lengths through the wav2vec2 conv feature extractor,
+         # one (kernel width, stride) pair per conv layer
+         for width, stride in [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]:
+             x_lens = torch.floor((x_lens - width) / stride + 1)
+         if self.pooling_loc == 0:
+             x, w = self.pooling(x, x_lens)
+             x = self.linear_out(x)
+             x = x.div(x.norm(dim=1).unsqueeze(-1))
+         elif self.pooling_loc == 1:
+             x = self.linear_out(x)
+             x, w = self.pooling(x, x_lens)
+             x = x.div(x.norm(dim=1).unsqueeze(-1))
+         elif self.pooling_loc == 2:
+             x = self.linear_out(x)
+             x = x.div(x.norm(dim=1).unsqueeze(-1))
+             x, w = self.pooling(x, x_lens)
+             x = x.div(x.norm(dim=1).unsqueeze(-1))
+         return x, w
+
+     def freeze_encoder(self):
+         for p in self.encoder.encoder.parameters():
+             if p.requires_grad:
+                 p.requires_grad = False
+         self.frozen = True
+
+     def unfreeze_encoder(self):
+         for p in self.encoder.encoder.parameters():
+             p.requires_grad = True
+         if self.freeze_feat_extractor:
+             self.encoder.feature_extractor._freeze_parameters()
+         self.frozen = False
+
+     def _get_output_dim(self):
+         x = torch.rand(1, 400)
+         return self.encoder(x).last_hidden_state.size(-1)
+
+
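+ # (Added sketch, not part of the original commit.) The (width, stride) pairs above
+ # follow the usual wav2vec2 downsampling recurrence L_out = floor((L_in - width) / stride + 1);
+ # e.g. a 2-second clip at 16 kHz (32000 samples) maps to 99 encoder frames.
+ def _num_encoder_frames(n_samples: int) -> int:
+     for width, stride in [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]:
+         n_samples = math.floor((n_samples - width) / stride + 1)
+     return n_samples
+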
+ # if 'checkpoint.pt' not in os.listdir():
+ #     checkpoint_url = "https://drive.google.com/uc?id=162jJ_YC4MGEfXBWvAK-kXnZcXX3v1smr"
+ #     output = "checkpoint.pt"
+ #     gdown.download(checkpoint_url, output, quiet=False)
+
+
+ model = Wav2Vec2Model()
+ model.to(device)
+
+ # load model checkpoint
+ checkpoint_path = "mms-300m-checkpoint.pt"
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+ model.load_state_dict(checkpoint)
+ model.eval()
+ print(f'Loaded state dict {checkpoint_path}')
+
+ def predict(audio_path):
+     # get raw audio data
+     try:
+         a = lhotse.Recording.from_file(audio_path)
+     except Exception:
+         return (None, "Please wait a bit until the audio file has uploaded, then try again")
+     cuts = CutSet.from_recordings([a])
+     cuts = cuts.resample(16000)
+     cuts = cuts.cut_into_windows(10).to_mono(mono_downmix=True)
+     # The model wasn't trained on anything shorter than 2 seconds
+     cuts = cuts.filter(lambda c: c.duration >= 2)
+     audio_data, audio_lens = lhotse.dataset.collation.collate_audio(cuts)
+
+     # pass through model and average the per-window predictions
+     x, _ = model.forward(audio_data, audio_lens)
+     x = x.mean(dim=0, keepdim=True)
+     print(x)
+
+     # convert the predicted unit vector to latitude/longitude in degrees
+     pred_lon = torch.atan2(x[:, 0], x[:, 1]).unsqueeze(-1)
+     pred_lat = torch.asin(x[:, 2]).unsqueeze(-1)
+     x_polar = torch.cat((pred_lat, pred_lon), dim=1).to(device)
+     coords = x_polar.mul(180. / math.pi).cpu().detach().numpy()
+     print(coords)
+
+     # wraparound fix (lat > 90): reflect the latitude back below the pole and negate
+     # the longitude; the result is ordered [lon, lat] for plotting
+     coords = [[-lon, math.degrees(math.asin(math.sin(math.radians(lat))))] if lat > 90 else [lon, lat] for lat, lon in coords][0]
+
+     # create plot
+     guesses = gv.Points([coords]).opts(
+         size=8, cmap='Spectral_r', color='blue', fill_alpha=1
+     )
+     plot = (gts.OSM * guesses).options(
+         gv.opts.Points(width=800, height=400, xlim=(-180*110000, 180*110000), ylim=(-90*140000, 90*140000), xaxis=None, yaxis=None)
+     )
+     filename = f"{str(uuid.uuid4())}.png"
+     gv.save(plot, filename=filename, fmt='png')
+     coords = [round(i, 2) for i in coords]
+     coords = [coords[1], coords[0]]
+     print(filename, coords)
+     return (filename, str(coords)[1:-1])
+
+
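+ # (Added illustrative helper, not part of the original commit.) predict() treats the
+ # model output as a point (x, y, z) on the unit sphere and recovers geographic
+ # coordinates with lat = asin(z), lon = atan2(x, y), mirroring the lines above:
+ def _unit_vector_to_lat_lon(x, y, z):
+     return math.degrees(math.asin(z)), math.degrees(math.atan2(x, y))
+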
+ gradio_app = gr.Interface(
+     predict,
+     inputs=gr.Audio(label="Record Audio (5 seconds)", type="filepath", min_length=5.0),
+     outputs=[gr.Image(type="filepath", label="Map of Prediction"), gr.Textbox(placeholder="Latitude, Longitude", label="Prediction (Latitude, Longitude)")],
+     title="Speech Geolocation Demo",
+ )
+
+ if __name__ == "__main__":
+     gradio_app.launch()
packages.txt ADDED
@@ -0,0 +1 @@
+ chromium-driver
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ bokeh==3.3.4
+ Cartopy==0.22.0
+ geoviews==1.11.0
+ gradio==5.13.0
+ lhotse==1.19.2
+ pydub==0.25.1
+ torch==2.1.2
+ transformers==4.37.1
+ selenium==4.0.0
+ numpy<2
+ torchaudio