# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Adapted from https://github.com/jik876/hifi-gan
#
import os

import torch
def init_weights(m, mean=0.0, std=0.01):
    """Fill the ``weight`` of a convolution-like module with normal samples.

    The single-module signature suits ``torch.nn.Module.apply``. A module is
    treated as a convolution when its class name contains ``"Conv"``; any
    other module is left untouched.

    Args:
        m: The module to (possibly) re-initialise.
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.
    """
    if "Conv" not in m.__class__.__name__:
        return
    weight = getattr(m, "weight", None)
    # Recursive traversal can also hit "Conv"-named wrappers/containers that
    # own no tensor ``weight``; skip them instead of raising AttributeError.
    if not isinstance(weight, torch.Tensor):
        return
    # ``.data`` writes the new values without recording the op in autograd.
    # NOTE(review): if ``weight`` is recomputed on every access (e.g. a
    # weight-norm parametrization), this in-place write may not persist —
    # confirm for the layers callers pass in.
    weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
    """Return half of the extra span covered by a dilated kernel, as an int.

    Computes ``(kernel_size * dilation - dilation) / 2`` and truncates the
    result with ``int()``. Presumably used as "same" padding for stride-1
    convolutions with odd kernel sizes — confirm at call sites.

    Args:
        kernel_size: Number of taps in the kernel.
        dilation: Spacing between kernel taps.

    Returns:
        The truncated half-span as an ``int``.
    """
    extra_span = kernel_size * dilation - dilation
    half_span = extra_span / 2
    return int(half_span)
def load_checkpoint(filepath, device):
    """Load an object saved with ``torch.save``, remapping tensors to ``device``.

    Prints progress messages to stdout before and after loading.

    Args:
        filepath: Path to the checkpoint file (str or path-like).
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialized by ``torch.load``. The name suggests a dict of
        checkpoint entries — confirm against whatever wrote the file.

    Raises:
        FileNotFoundError: If ``filepath`` is not an existing regular file.
    """
    # An explicit raise, unlike ``assert``, is not stripped under ``python -O``.
    if not os.path.isfile(filepath):
        raise FileNotFoundError("Checkpoint file not found: '{}'".format(filepath))
    print("Loading '{}'".format(filepath))
    # NOTE(review): torch.load is pickle-based; depending on the torch version
    # and its ``weights_only`` default, loading an untrusted file can execute
    # arbitrary code. Only load checkpoints from trusted sources.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
class AttrDict(dict):
    """A ``dict`` whose entries are also readable and writable as attributes.

    The instance's attribute namespace (``__dict__``) is aliased to the dict
    itself, so ``d.key`` and ``d["key"]`` share the same storage. Because
    instance attributes take precedence over non-data descriptors, a key such
    as ``"items"`` hides the corresponding dict method on that instance.

    NOTE(review): the instance references itself through ``__dict__``, a
    reference cycle that only the cyclic garbage collector can reclaim.
    """

    def __init__(self, *args, **kwargs):
        # Populate the mapping exactly as ``dict(*args, **kwargs)`` would.
        super().__init__(*args, **kwargs)
        # Alias the attribute namespace to the mapping itself.
        self.__dict__ = self