ms180 commited on
Commit
3dfdefc
·
verified ·
1 Parent(s): e323307

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -6
app.py CHANGED
@@ -78,14 +78,44 @@ Please consider citing the following papers if you find our work helpful.
78
 
79
  device = "cuda"
80
 
81
- # device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- s2l = Speech2Language.from_pretrained(
85
- model_tag=f"espnet/owsm_v4_medium_1B",
86
- device=device,
87
- nbest=1,
88
- )
89
 
90
  s2t_ar = ARSpeech2Text.from_pretrained(
91
  model_tag=f"espnet/owsm_v4_medium_1B",
 
78
 
79
  device = "cuda"
80
 
81
+ try:
82
+ s2l = Speech2Language.from_pretrained(
83
+ model_tag=f"espnet/owsm_v4_medium_1B",
84
+ device="cpu",
85
+ nbest=1,
86
+ )
87
+ except Exception as e:
88
+ print("File downloaded")
89
 
90
+ # 2. Remove unrequired file
91
+ import yaml
92
+ from pathlib import Path
93
+ import espnet_model_zoo
94
+
95
+ d = "models--espnet--owsm_v4_medium_1B/snapshots/471418ddaf0b03c9ab1fd75f1f5d26fc3aea3aa9/exp/s2t_train_conv2d8_size1024_e18_d18_mel128_raw_bpe50000/config.yaml"
96
+ p = Path(espnet_model_zoo.__file__)
97
+ config_path = p.parent / d
98
+
99
+ def remove_key(obj, key="gradient_checkpoint_layers"):
100
+ if isinstance(obj, dict):
101
+ if key in obj:
102
+ del obj[key]
103
+ for k, v in list(obj.items()):
104
+ remove_key(v, key)
105
+ elif isinstance(obj, list):
106
+ for item in obj:
107
+ remove_key(item, key)
108
+
109
+ with open(config_path, "r") as f:
110
+ config = yaml.safe_load(f)
111
+
112
+ remove_key(config)
113
+
114
+ with open(config_path, "w") as f:
115
+ yaml.safe_dump(config, f, sort_keys=False, allow_unicode=True)
116
+
117
+ print("Done! All 'gradient_checkpoint_layers' keys removed.")
118
 
 
 
 
 
 
119
 
120
  s2t_ar = ARSpeech2Text.from_pretrained(
121
  model_tag=f"espnet/owsm_v4_medium_1B",