-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
100 lines (87 loc) · 3.95 KB
/
main.py
File metadata and controls
100 lines (87 loc) · 3.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from argparse import ArgumentParser
from params import get_parser
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
import random
from dataset import MixDataset, AdDataset, SingleDataset, collect_fn, collect_ad, load_data, build_dict
from module import ScT
import warnings
from torch.utils.data import DataLoader
import scanpy as sc
import numpy as np
import anndata as ad
from utils import count_parameters
# Install a blanket "ignore" filter so warnings issued through the ``warnings``
# module are not printed during the run.
warnings.filterwarnings('ignore')
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so the flag could never be
    switched off with an explicit value.

    Accepted (case-insensitive) spellings:
        true  -> ``true``, ``t``, ``yes``, ``y``, ``1``
        false -> ``false``, ``f``, ``no``, ``n``, ``0`` and the empty string
                 (the empty string is kept because it was the only value that
                 produced ``False`` under the former ``type=bool``).

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def cli_main():
    """Command-line entry point: load data, train an ``ScT`` model, then test it.

    Steps, in order:
      1. Seed all RNGs and parse the command line (project parser from
         ``get_parser()`` plus ``--precision``, ``--pretrain``, ``--shuffle``).
      2. Load the four AnnData objects, tag each with a ``species`` column and
         build the shared gene list and tissue / cell-type dictionaries.
      3. Split the training ``MixDataset`` 90/10 into train / validation and
         wrap train, validation and test sets in ``DataLoader`` objects.
      4. Fit the model, save a final checkpoint under ``model/`` and run the
         test loop.
    """
    import os

    pl.seed_everything(1234)

    # ------------
    # args
    # ------------
    parser = get_parser()
    parser.add_argument('--precision', default=16, type=int)
    # Booleans are parsed explicitly; see _str2bool for why type=bool is wrong.
    parser.add_argument('--pretrain', default=False, type=_str2bool)
    parser.add_argument('--shuffle', default=True, type=_str2bool)
    params = parser.parse_args()

    # ------------
    # data
    # ------------
    adata_m, adata_h, adata_ts, adata_ct = load_data(params)

    # Species label per AnnData object; the two test sets share a label that
    # depends on the dataset. Any other dataset name leaves 'species' unset.
    if params.dataset == 'hcl':
        adata_m.obs['species'] = 0
        adata_h.obs['species'] = 1
        adata_ct.obs['species'] = 0
        adata_ts.obs['species'] = 0
    elif params.dataset == 'brain':
        adata_m.obs['species'] = 0
        adata_h.obs['species'] = 1
        adata_ct.obs['species'] = 1
        adata_ts.obs['species'] = 1

    (gene_list, adata_m, adata_h,
     id2tissue, tissue2id,
     id2celltype, celltype2id) = build_dict(adata_m, adata_h)

    # Restrict the test sets to the gene list produced from the training data.
    adata_ts = adata_ts[:, gene_list]
    adata_ct = adata_ct[:, gene_list]
    adata_ts.var_names = gene_list
    adata_ct.var_names = gene_list

    gene2id = None
    params.n_genes = len(gene_list) + 1
    params.gene2id = gene2id
    # NOTE(review): params.n_val is not assigned here; presumably it is
    # supplied by get_parser() -- confirm.
    params.n_tissue = len(tissue2id)
    params.n_celltype = len(celltype2id)
    print('Dictionary builded, gene size: {}, number of tissue: {}, number of celltype: {}'.format(params.n_genes,params.n_tissue,params.n_celltype))

    dataset_train = MixDataset(adata_m, adata_h, gene2id, tissue2id, celltype2id)
    # 90% of the samples for training, the remainder for validation.
    n_total = len(dataset_train)
    tr_len = int(n_total * 0.9)
    val_len = n_total - tr_len
    ds_train, ds_val = random_split(dataset_train, [tr_len, val_len])
    dataset_test = MixDataset(adata_ts, adata_ct, gene2id, tissue2id, celltype2id)

    # NOTE(review): drop_last=True is also applied to validation and test, so
    # a trailing incomplete batch is never evaluated -- confirm this is intended.
    loader_kwargs = dict(
        batch_size=params.batch_size,
        collate_fn=collect_fn,
        num_workers=params.n_workers,
        drop_last=True,
    )
    train_loader = DataLoader(ds_train, shuffle=params.shuffle, **loader_kwargs)
    val_loader = DataLoader(ds_val, **loader_kwargs)
    test_loader = DataLoader(dataset_test, **loader_kwargs)

    # ------------
    # model
    # ------------
    model = ScT(params.n_genes, params.n_val, params.n_celltype, n_layers=params.n_layers, embed_dim=params.embed_dim)
    print('The number of parameters: {}'.format(count_parameters(model)))

    # ------------
    # training
    # ------------
    # NOTE(review): assumes the model logs a metric named 'train_loss'.
    checkpoint_callback = ModelCheckpoint(monitor='train_loss')
    trainer = pl.Trainer(gpus=1,
                         accelerator='ddp', precision=params.precision,
                         max_epochs=params.n_epochs, gradient_clip_val=0.5, callbacks=[checkpoint_callback])
    trainer.fit(model, train_loader, val_loader)

    # Make sure the output directory exists so the final save cannot fail
    # after a complete training run.
    checkpoint_path = "model/last_" + params.experiment + ".ckpt"
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    trainer.save_checkpoint(checkpoint_path)

    # ------------
    # testing
    # ------------
    trainer.test(test_dataloaders=test_loader)
# Start the training pipeline only when this file is executed as a script;
# importing the module has no side effect from this guard.
if __name__ == "__main__":
    cli_main()