forked from facebookresearch/denoiser
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path train.py
More file actions
executable file
·131 lines (100 loc) · 5.18 KB
/
train.py
File metadata and controls
executable file
·131 lines (100 loc) · 5.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# authors: adiyoss and adefossez
import logging
import os
import hydra
from denoiser.executor import start_ddp_workers
# Module-level logger for this training script.
logger = logging.getLogger(__name__)
def run(args):
    """Build the ConvTasNet model, data loaders and optimizer, then train.

    Args:
        args: Hydra/OmegaConf training config. Attributes read here: ``show``,
            ``sample_rate``, ``batch_size``, ``segment``, ``stride``, ``pad``,
            ``num_workers``, ``optim``, ``lr``, ``beta2`` and
            ``dset.{train,valid,test,matching}``. Optional: ``seed``,
            ``clean_dir`` and ``noisy_dir`` (defaults shown below).
            ``args.batch_size`` is divided in place by the world size.

    If ``args.show`` is truthy, the model and its size are logged and the
    function returns without training.

    Raises:
        ValueError: if ``args.batch_size`` is not divisible by
            ``distrib.world_size``.
    """
    import torch
    from denoiser import distrib  # distributed (DDP) helpers
    from denoiser.data import NoisyCleanSet
    from denoiser.ConvTasnet1 import ConvTasNet
    from denoiser.solver import Solver

    # torch.manual_seed also seeds CUDA when it is available. Honour a
    # ``seed`` from the config if present; 2036 was the previous fixed seed.
    torch.manual_seed(getattr(args, "seed", 2036))

    # Folders of clean / noisy .wav files handed to ConvTasNet as ``sources``.
    # Override with ``clean_dir`` / ``noisy_dir`` (e.g. ``+clean_dir=...`` on
    # the Hydra command line); the defaults are the previously hard-coded paths.
    clean_folder = getattr(
        args, "clean_dir", '/home/heisnproph11/pyProject/denoiser/train/clean')
    noise_folder = getattr(
        args, "noisy_dir", '/home/heisnproph11/pyProject/denoiser/train/noisy')
    clean_files = _list_wavs(clean_folder)
    noise_files = _list_wavs(noise_folder)
    model = ConvTasNet(sources={'clean': clean_files, 'noisy': noise_files},
                       N=256, L=20, B=256, H=512, P=3, X=8, R=4,
                       audio_channels=2, norm_type="gLN", causal=False,
                       mask_nonlinear='relu', samplerate=44100,
                       segment_length=44100 * 2 * 4,
                       frame_length=400, frame_step=100)
    # NOTE(review): the model is built with samplerate=44100 while the
    # datasets below use args.sample_rate — confirm these agree.

    if args.show:
        # Only report the architecture and size; do not train.
        logger.info(model)
        # Assumes 4 bytes per parameter (float32).
        mb = sum(p.numel() for p in model.parameters()) * 4 / 2**20
        logger.info('Size: %.1f MB', mb)
        if hasattr(model, 'valid_length'):
            # valid_length(1), converted from samples to milliseconds.
            field = model.valid_length(1)
            logger.info('Field: %.1f ms', field / args.sample_rate * 1000)
        return

    # The global batch is split evenly across distributed workers.
    if args.batch_size % distrib.world_size != 0:
        raise ValueError(
            f"batch_size ({args.batch_size}) must be divisible by the "
            f"world size ({distrib.world_size})")
    args.batch_size //= distrib.world_size

    # Training segment length and hop between segments, in samples.
    length = int(args.segment * args.sample_rate)
    stride = int(args.stride * args.sample_rate)
    # Some models require a specific number of samples to avoid zero padding
    # during training; let the model adjust the segment length if it can.
    if hasattr(model, 'valid_length'):
        length = model.valid_length(length)
    kwargs = {"matching": args.dset.matching, "sample_rate": args.sample_rate}

    # Building datasets and loaders.
    tr_dataset = NoisyCleanSet(
        args.dset.train, length=length, stride=stride, pad=args.pad, **kwargs)
    tr_loader = distrib.loader(
        tr_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    # Validation / test sets are optional; built without length/stride and
    # loaded with batch size 1.
    if args.dset.valid:
        cv_dataset = NoisyCleanSet(args.dset.valid, **kwargs)
        cv_loader = distrib.loader(cv_dataset, batch_size=1, num_workers=args.num_workers)
    else:
        cv_loader = None
    if args.dset.test:
        tt_dataset = NoisyCleanSet(args.dset.test, **kwargs)
        tt_loader = distrib.loader(tt_dataset, batch_size=1, num_workers=args.num_workers)
    else:
        tt_loader = None
    data = {"tr_loader": tr_loader, "cv_loader": cv_loader, "tt_loader": tt_loader}

    if torch.cuda.is_available():
        model.cuda()

    # Optimizer: selected by args.optim; only "adam" is supported.
    if args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, args.beta2))
    else:
        logger.critical('Invalid optimizer %s', args.optim)
        os._exit(1)

    # Construct Solver and train.
    solver = Solver(data, model, optimizer, args)
    solver.train()


def _list_wavs(folder):
    """Return the sorted full paths of the ``.wav`` files directly in *folder*."""
    # Sorted because os.listdir order is arbitrary; this keeps the file order
    # handed to the model reproducible across machines and runs.
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.endswith('.wav'))
def _main(args):
    """Prepare the config and logging, then start training.

    String entries of ``args.dset`` (except ``matching``, which is an option
    value rather than a path) are passed through
    ``hydra.utils.to_absolute_path``; so is this module's ``__file__``.
    If ``args.ddp`` is set and ``args.rank`` is None, DDP workers are started
    via ``start_ddp_workers``; otherwise ``run`` is called in this process.

    Args:
        args: Hydra/OmegaConf config. Reads ``dset``, ``verbose``, ``ddp`` and
            ``rank`` here; mutates ``args.dset`` in place.
    """
    global __file__
    # Resolve relative paths against the original launch directory, in case
    # Hydra changed the working directory for this run.
    for key, value in args.dset.items():
        if isinstance(value, str) and key != "matching":
            args.dset[key] = hydra.utils.to_absolute_path(value)
    __file__ = hydra.utils.to_absolute_path(__file__)
    if args.verbose:
        logger.setLevel(logging.DEBUG)
        # Logger names are dot-separated hierarchies: the package is
        # ``denoiser``, so ``"denoise"`` would not affect its module loggers.
        logging.getLogger("denoiser").setLevel(logging.DEBUG)

    logger.info("For logs, checkpoints and samples check %s", os.getcwd())
    logger.debug(args)
    if args.ddp and args.rank is None:
        # Launcher process: spawn one worker per rank.
        start_ddp_workers(args)
    else:
        run(args)
@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(args):
    """Hydra entry point: run ``_main`` and hard-exit with status 1 on failure.

    ``args`` is the config composed by Hydra from ``config`` in the ``conf``
    directory. Any exception is logged with its traceback before exiting.
    """
    try:
        _main(args)
    except Exception:
        logger.exception("Some error happened")
    else:
        return
    # Hydra intercepts the normal exit code (a beta fixes this, but it could
    # not be made to work), so terminate the process directly.
    os._exit(1)
# Script entry point; the hydra.main decorator on main() supplies its config argument.
if __name__ == "__main__":
    main()