-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
111 lines (87 loc) · 4.19 KB
/
Copy patheval.py
File metadata and controls
111 lines (87 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
from dataloader.ST_dataset import STCacheDataset
from model.HEXST import HEXST
from utils.config import load_config
from utils.loss import STORLoss
from utils.metric import MetricCalculator
from utils.utils import set_seed, log
@torch.no_grad()
def eval_epoch(args, model, loader, criterion, device):
    """Run ``model`` over ``loader`` without gradients, save predictions, and score them.

    Each batch's predictions are saved with ``torch.save`` to
    ``<args.data.pred_save_dir>/<slide_id>.pt``, pairing the i-th collected
    slide id with the i-th prediction batch.

    Args:
        args: config object; only ``args.data.pred_save_dir`` is read here.
        model: module called as ``model(feats_img, coords)``; its output dict
            must contain a ``"preds"`` tensor.
        loader: iterable of batch dicts with keys ``"slide"``, ``"feats_img"``,
            ``"feats_gene"``, ``"gene_exp"`` and ``"coords"``.
        criterion: called as ``criterion(gene_exp, outputs)`` and must return a
            tensor whose ``.item()`` gives a scalar loss.
        device: device the image/gene feature and expression tensors are moved to.

    Returns:
        ``(total_loss, eval_metrics)``: ``total_loss`` is the sum of per-batch
        losses weighted by ``gene_exp.shape[0]``, divided by the total number of
        prediction rows; ``eval_metrics`` is the result of
        ``MetricCalculator.get_metrics`` on all predictions vs. all targets.

    Raises:
        ValueError: if ``loader`` yields no batches, or if the number of slide ids
            collected differs from the number of prediction batches (zip would
            otherwise silently mis-pair or drop saved prediction files).
    """
    model.eval()
    loss_sum = 0.0
    all_preds, all_y, all_slide_ids = [], [], []

    for batch in loader:
        slide_ids = batch["slide"]
        feats_img = batch["feats_img"].to(device)
        feats_gene = batch["feats_gene"].to(device)
        gene_exp = batch["gene_exp"].to(device)
        # Keep index 0 of the third axis of the collated coordinates.
        # NOTE(review): assumes coords collate to a 3-D array whose last axis is
        # the batch axis of size 1 — confirm against STCacheDataset.
        coords = np.array(batch["coords"])[:, :, 0]

        # Strip the leading batch axis when present (main builds the loader with
        # batch_size=1, so index 0 is the only sample).
        feats_img = feats_img[0] if feats_img.ndim == 3 else feats_img
        feats_gene = feats_gene[0] if feats_gene.ndim == 3 else feats_gene
        gene_exp = gene_exp[0] if gene_exp.ndim == 3 else gene_exp

        outputs = model(feats_img, coords)
        # The criterion reads the gene features from the outputs dict.
        outputs["feats_gene"] = feats_gene
        loss = criterion(gene_exp, outputs)

        all_preds.append(outputs["preds"].detach().cpu())
        all_y.append(gene_exp.detach().cpu())
        all_slide_ids.extend(slide_ids)
        # Weight by row count so total_loss is a per-row average.
        loss_sum += float(loss.item()) * gene_exp.shape[0]

    if not all_preds:
        raise ValueError("eval_epoch: loader yielded no batches")
    if len(all_slide_ids) != len(all_preds):
        raise ValueError(
            f"eval_epoch: collected {len(all_slide_ids)} slide ids for "
            f"{len(all_preds)} prediction batches; expected exactly one slide "
            "id per batch (batch_size=1)"
        )

    # One prediction file per slide; counts were validated above.
    for slide_id, pred in zip(all_slide_ids, all_preds):
        torch.save(pred, os.path.join(args.data.pred_save_dir, slide_id + '.pt'))

    all_preds = torch.cat(all_preds, 0)  # [N_total, G]
    all_y = torch.cat(all_y, 0)  # [N_total, G]
    total_loss = loss_sum / all_preds.shape[0]
    eval_metrics = MetricCalculator.get_metrics(all_preds, all_y)
    return total_loss, eval_metrics
def main(args):
    """Evaluate the best saved checkpoint of one experiment on the test split.

    Reads ``args.train.seed``, ``args.model.*``, ``args.loss.*``, ``args.data.*``
    and ``args.tag``. Side effects: rewrites ``args.data.base_path``,
    ``save_dir``, ``pred_save_dir`` and ``log_path`` into experiment-specific
    paths, fills ``args.model.num_genes`` from the dataset when it is None,
    creates the prediction directory, appends to ``log_eval.txt`` via ``log``,
    and (through ``eval_epoch``) writes per-slide prediction files.
    """
    set_seed(args.train.seed)

    # Experiment tag = model / loss function / loss mode / loss type, plus an
    # optional user-supplied suffix; it selects the run's save and pred dirs.
    exp_tag = "{}_{}_{}_{}".format(
        args.model.name,
        args.loss.function.name,
        args.loss.mode.name,
        args.loss.type.name,
    )
    if args.tag is not None and len(args.tag) > 0:
        exp_tag = exp_tag + "_" + args.tag

    args.data.base_path = os.path.join(args.data.base_path, args.data.project) + ".pt"
    args.data.save_dir = os.path.join(args.data.save_dir, args.data.project, exp_tag)
    args.data.pred_save_dir = os.path.join(args.data.pred_save_dir, args.data.project, exp_tag)
    args.data.log_path = os.path.join(args.data.save_dir, 'log_eval.txt')
    log(args.data.log_path, str(args))
    os.makedirs(args.data.pred_save_dir, exist_ok=True)

    test_ds = STCacheDataset(args.data.base_path, split="test")
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)
    # Fall back to the dataset's gene count unless the config pins one.
    args.model.num_genes = test_ds.num_genes if args.model.num_genes is None else args.model.num_genes
    log(args.data.log_path, f"[{args.data.project}] Slides: Test={len(test_ds)}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = HEXST(
        input_dim=args.model.img_dim,
        feat_dim=args.model.feat_dim,
        hidden_dim=args.model.hidden_dim,
        num_genes=args.model.num_genes,
        dropout=args.model.dropout,
    ).to(device)
    criterion = STORLoss(
        args.loss.type.img, args.loss.type.feat,
        args.loss.mode.regression, args.loss.mode.differential,
        args.loss.function.MSE, args.loss.function.PL,
    )

    best_model_path = os.path.join(args.data.save_dir, "model_best.pth")
    # map_location remaps stored tensors onto the selected device, so a
    # checkpoint saved from a GPU run also loads on the CPU fallback above;
    # without it torch.load tries to restore CUDA tensors and fails there.
    model.load_state_dict(torch.load(best_model_path, map_location=device))

    test_loss, test_metrics = eval_epoch(args, model, test_loader, criterion, device)
    log(args.data.log_path, f"[TEST] Loss={test_loss:.4f}")
    log(args.data.log_path, str(test_metrics))
if __name__ == "__main__":
    # CLI entry point: every flag is a config file path. They are handed to
    # load_config as one list (base config first), and the result drives main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--base_config", type=str, default="./config/baseline.yaml")
    # Optional configs; any flag left unset is passed through as None.
    for flag in (
        "--data_config",
        "--model_config",
        "--loss_function_config",
        "--loss_mode_config",
        "--loss_type_config",
    ):
        cli.add_argument(flag, type=str, default=None)
    opts = cli.parse_args()

    config_paths = [
        opts.base_config,
        opts.data_config,
        opts.model_config,
        opts.loss_function_config,
        opts.loss_mode_config,
        opts.loss_type_config,
    ]
    args = load_config(config_paths)
    main(args)