diff --git a/README.md b/README.md
index ea16dd3..65c688a 100644
--- a/README.md
+++ b/README.md
@@ -68,7 +68,8 @@ Note you might want to change the `batch_size` in the config file if you meet OU
 # caching the generated motions (seed included) to `./outputs`
 python evaluator.py --config_path ./configs/fact_v5_deeper_t10_cm12.config --model_dir ./checkpoints
 # calculate FIDs
-python tools/calculate_scores.py
+python tools/extract_aist_features.py
+python tools/calculate_fid_scores.py
 ```
diff --git a/tools/calculate_beat_scores.py b/tools/calculate_beat_scores.py
new file mode 100644
index 0000000..b0e55d1
--- /dev/null
+++ b/tools/calculate_beat_scores.py
@@ -0,0 +1,214 @@
+from absl import app
+from absl import flags
+from absl import logging
+
+import os
+import torch
+import numpy as np
+import pickle
+from scipy.spatial.transform import Rotation as R
+import scipy.signal as scisignal
+from aist_plusplus.loader import AISTDataset
+
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(
+    'anno_dir', '/mnt/data/aist_plusplus_final/',
+    'Path to the AIST++ annotation files.')
+flags.DEFINE_string(
+    'audio_dir', '/mnt/data/AIST/music/',
+    'Path to the AIST wav files.')
+flags.DEFINE_string(
+    'audio_cache_dir', './data/aist_audio_feats/',
+    'Path to the cache directory for audio features.')
+flags.DEFINE_enum(
+    'split', 'testval', ['train', 'testval'],
+    'Whether to process the training set or the testval set.')
+flags.DEFINE_string(
+    'result_files', '/mnt/data/aist_paper_results/*.pkl',
+    'The path pattern of the result files.')
+flags.DEFINE_bool(
+    'legacy', True,
+    'Whether the result files are in the legacy format.')
+
+
+def eye(n, batch_shape):
+    """Return a batch of n x n identity matrices, shape (*batch_shape, n, n)."""
+    iden = np.zeros(np.concatenate([batch_shape, [n, n]]))
+    for i in range(n):
+        iden[..., i, i] = 1.0
+    return iden
+
+
+def get_closest_rotmat(rotmats):
+    """Finds the rotation matrices closest to the inputs in the Frobenius norm.
+
+    For each input matrix it computes the SVD as R = USV' and sets
+    R_closest = UV'. Additionally, it ensures that det(R_closest) == 1.
+    Args:
+        rotmats: np array of shape (..., 3, 3).
+    Returns:
+        A numpy array of the same shape as the inputs.
+    """
+    u, s, vh = np.linalg.svd(rotmats)
+    r_closest = np.matmul(u, vh)
+
+    # if the determinant of UV' is -1, we must flip the sign of the last column of u
+    det = np.linalg.det(r_closest)  # (..., )
+    iden = eye(3, det.shape)
+    iden[..., 2, 2] = np.sign(det)
+    r_closest = np.matmul(np.matmul(u, iden), vh)
+    return r_closest
+
+
+def recover_to_axis_angles(motion):
+    batch_size, seq_len, dim = motion.shape
+    assert dim == 225
+    transl = motion[:, :, 6:9]
+    rotmats = get_closest_rotmat(
+        np.reshape(motion[:, :, 9:], (batch_size, seq_len, 24, 3, 3))
+    )
+    axis_angles = R.from_matrix(
+        rotmats.reshape(-1, 3, 3)
+    ).as_rotvec().reshape(batch_size, seq_len, 24, 3)
+    return axis_angles, transl
+
+
+def recover_motion_to_keypoints(motion, smpl_model):
+    smpl_poses, smpl_trans = recover_to_axis_angles(motion)
+    smpl_poses = np.squeeze(smpl_poses, axis=0)  # (seq_len, 24, 3)
+    smpl_trans = np.squeeze(smpl_trans, axis=0)  # (seq_len, 3)
+    keypoints3d = smpl_model.forward(
+        global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
+        body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
+        transl=torch.from_numpy(smpl_trans).float(),
+    ).joints.detach().numpy()[:, :24, :]  # (seq_len, 24, 3)
+    return keypoints3d
+
+
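The SVD projection in `get_closest_rotmat` above is what keeps the recovered per-joint matrices on SO(3) before they are converted to axis angles. A minimal self-check of that projection, assuming the repo root is on `PYTHONPATH` and the script's dependencies are installed:

```python
# Perturb a valid rotation off SO(3), project it back, and verify the result
# is a proper rotation (orthogonal, det == +1). Values are illustrative.
import numpy as np
from scipy.spatial.transform import Rotation as R
from tools.calculate_beat_scores import get_closest_rotmat

rng = np.random.default_rng(0)
rot = R.from_euler('xyz', [0.3, -0.2, 0.5]).as_matrix()  # a valid rotation
noisy = rot + 0.05 * rng.standard_normal((3, 3))         # no longer a rotation
closest = get_closest_rotmat(noisy[None])[0]             # batch of one

assert np.allclose(closest @ closest.T, np.eye(3), atol=1e-6)  # orthogonal
assert np.isclose(np.linalg.det(closest), 1.0)                 # det == +1
```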
+def motion_peak_onehot(joints):
+    """Calculate motion beats as local minima of the joint velocity envelope.
+
+    Args:
+        joints: [nframes, njoints, 3]
+    Returns:
+        peak_onehot: one-hot motion beats, shape [nframes,].
+    """
+    # Calculate velocity.
+    velocity = np.zeros_like(joints, dtype=np.float32)
+    velocity[1:] = joints[1:] - joints[:-1]
+    velocity_norms = np.linalg.norm(velocity, axis=2)
+    envelope = np.sum(velocity_norms, axis=1)  # (seq_len,)
+
+    # Find local minima in velocity -- these are the motion beats.
+    peak_idxs = scisignal.argrelextrema(envelope, np.less, axis=0, order=10)  # order=10 for 60 FPS
+    peak_onehot = np.zeros_like(envelope, dtype=bool)
+    peak_onehot[peak_idxs] = 1
+
+    # # Second-derivative of the velocity shows the energy of the beats
+    # peak_energy = np.gradient(np.gradient(envelope))  # (seq_len,)
+    # # optimize peaks
+    # peak_onehot[peak_energy < 0.001] = 0
+    return peak_onehot
+
+
+def alignment_score(music_beats, motion_beats, sigma=3):
+    """Calculate the alignment score between music beats and motion beats."""
+    if motion_beats.sum() == 0:
+        return 0.0
+    music_beat_idxs = np.where(music_beats)[0]
+    motion_beat_idxs = np.where(motion_beats)[0]
+    score_all = []
+    for motion_beat_idx in motion_beat_idxs:
+        dists = np.abs(music_beat_idxs - motion_beat_idx).astype(np.float32)
+        ind = np.argmin(dists)
+        score = np.exp(-dists[ind]**2 / (2 * sigma**2))
+        score_all.append(score)
+    return sum(score_all) / len(score_all)
+
+
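Together these two helpers define the beat metric: motion beats are strict local minima of the summed joint-speed envelope, and each motion beat contributes `exp(-d**2 / (2 * sigma**2))` for its distance `d` to the nearest music beat. A synthetic sanity check, again assuming the module is importable from the repo root:

```python
# A joint oscillating with a 30-frame period has speed minima roughly every
# 15 frames; a beat train compared against itself should score 1.0.
import numpy as np
from tools.calculate_beat_scores import motion_peak_onehot, alignment_score

t = np.arange(300)
joints = np.zeros((300, 24, 3), dtype=np.float32)
joints[:, 0, 1] = np.sin(2 * np.pi * t / 30)   # one joint swings up and down

motion_beats = motion_peak_onehot(joints)      # True at speed local minima
print("beat frames:", np.where(motion_beats)[0])
print("self-alignment:", alignment_score(motion_beats, motion_beats, sigma=3))
```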
+def main(_):
+    import glob
+    import tqdm
+    from smplx import SMPL
+
+    # set smpl
+    smpl = SMPL(model_path="/mnt/data/smpl/", gender='MALE', batch_size=1)
+
+    # create list
+    seq_names = []
+    if "train" in FLAGS.split:
+        seq_names += np.loadtxt(
+            os.path.join(FLAGS.anno_dir, "splits/crossmodal_train.txt"), dtype=str
+        ).tolist()
+    if "val" in FLAGS.split:
+        seq_names += np.loadtxt(
+            os.path.join(FLAGS.anno_dir, "splits/crossmodal_val.txt"), dtype=str
+        ).tolist()
+    if "test" in FLAGS.split:
+        seq_names += np.loadtxt(
+            os.path.join(FLAGS.anno_dir, "splits/crossmodal_test.txt"), dtype=str
+        ).tolist()
+    ignore_list = np.loadtxt(
+        os.path.join(FLAGS.anno_dir, "ignore_list.txt"), dtype=str
+    ).tolist()
+    seq_names = [name for name in seq_names if name not in ignore_list]
+
+    # calculate score on real data
+    dataset = AISTDataset(FLAGS.anno_dir)
+    n_samples = len(seq_names)
+    beat_scores = []
+    for i, seq_name in enumerate(seq_names):
+        logging.info("processing %d / %d", i + 1, n_samples)
+        # get real data motion beats
+        smpl_poses, smpl_scaling, smpl_trans = AISTDataset.load_motion(
+            dataset.motion_dir, seq_name)
+        smpl_trans /= smpl_scaling
+        keypoints3d = smpl.forward(
+            global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
+            body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
+            transl=torch.from_numpy(smpl_trans).float(),
+        ).joints.detach().numpy()[:, :24, :]  # (seq_len, 24, 3)
+        motion_beats = motion_peak_onehot(keypoints3d)
+        # get real data music beats
+        audio_name = seq_name.split("_")[4]  # e.g. gBR_sBM_cAll_d04_mBR0_ch01 -> mBR0
+        audio_feature = np.load(os.path.join(FLAGS.audio_cache_dir, f"{audio_name}.npy"))
+        audio_beats = audio_feature[:keypoints3d.shape[0], -1]  # last dim is the music beats
+        # get beat alignment scores
+        beat_score = alignment_score(audio_beats, motion_beats, sigma=3)
+        beat_scores.append(beat_score)
+    print("\nBeat score on real data: %.3f\n" % (sum(beat_scores) / n_samples))
+
+    # calculate score on generated motion data
+    result_files = sorted(glob.glob(FLAGS.result_files))
+    # keep files whose 4-char music id (just before ".pkl") appears earlier in the name
+    result_files = [f for f in result_files if f[-8:-4] in f[:-8]]
+    if FLAGS.legacy:
+        # For some reason there are repetitive results; skip the duplicates.
+        result_files = {f[-34:]: f for f in result_files}
+        result_files = result_files.values()
+    n_samples = len(result_files)
+    beat_scores = []
+    for result_file in tqdm.tqdm(result_files):
+        if FLAGS.legacy:
+            with open(result_file, "rb") as f:
+                data = pickle.load(f)
+            result_motion = np.concatenate([
+                np.pad(data["pred_trans"], ((0, 0), (0, 0), (6, 0))),
+                data["pred_motion"].reshape(1, -1, 24 * 9)
+            ], axis=-1)  # [1, 120 + 1200, 225]
+        else:
+            result_motion = np.load(result_file)[None, ...]  # [1, 120 + 1200, 225]
+        keypoints3d = recover_motion_to_keypoints(result_motion, smpl)
+        motion_beats = motion_peak_onehot(keypoints3d)
+        if FLAGS.legacy:
+            audio_beats = data["audio_beats"][0] > 0.5
+        else:
+            audio_name = result_file[-8:-4]
+            audio_feature = np.load(os.path.join(FLAGS.audio_cache_dir, f"{audio_name}.npy"))
+            audio_beats = audio_feature[:, -1]  # last dim is the music beats
+        # skip the 120-frame seed motion when scoring
+        beat_score = alignment_score(audio_beats[120:], motion_beats[120:], sigma=3)
+        beat_scores.append(beat_score)
+    print("\nBeat score on generated data: %.3f\n" % (sum(beat_scores) / n_samples))
+
+
+if __name__ == '__main__':
+    app.run(main)
diff --git a/tools/calculate_scores.py b/tools/calculate_fid_scores.py
similarity index 85%
rename from tools/calculate_scores.py
rename to tools/calculate_fid_scores.py
index 81bb49c..d72b918 100644
--- a/tools/calculate_scores.py
+++ b/tools/calculate_fid_scores.py
@@ -151,7 +151,7 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
             + np.trace(sigma2) - 2 * tr_covmean)
 
 
-def extract_feature(motion, smpl_model, mode="kinetic"):
+def recover_motion_to_keypoints(motion, smpl_model):
     smpl_poses, smpl_trans = recover_to_axis_angles(motion)
     smpl_poses = np.squeeze(smpl_poses, axis=0)  # (seq_len, 24, 3)
     smpl_trans = np.squeeze(smpl_trans, axis=0)  # (seq_len, 3)
@@ -160,7 +160,10 @@ def extract_feature(motion, smpl_model, mode="kinetic"):
         body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
         transl=torch.from_numpy(smpl_trans).float(),
     ).joints.detach().numpy()[:, :24, :]  # (seq_len, 24, 3)
+    return keypoints3d
+
 
+def extract_feature(keypoints3d, mode="kinetic"):
     if mode == "kinetic":
         feature = extract_kinetic_features(keypoints3d)
     elif mode == "manual":
@@ -170,6 +173,20 @@ def extract_feature(motion, smpl_model, mode="kinetic"):
     return feature  # (f_dim,)
 
 
+def calculate_avg_distance(feature_list, mean=None, std=None):
+    feature_list = np.stack(feature_list)
+    n = feature_list.shape[0]
+    # normalize the scale
+    if (mean is not None) and (std is not None):
+        feature_list = (feature_list - mean) / std
+    dist = 0
+    for i in range(n):
+        for j in range(i + 1, n):
+            dist += np.linalg.norm(feature_list[i] - feature_list[j])
+    dist /= (n * n - n) / 2  # average over the n * (n - 1) / 2 unordered pairs
+    return dist
+
+
 def calculate_frechet_feature_distance(feature_list1, feature_list2):
     feature_list1 = np.stack(feature_list1)
     feature_list2 = np.stack(feature_list2)
@@ -180,13 +197,14 @@ def calculate_frechet_feature_distance(feature_list1, feature_list2):
     feature_list1 = (feature_list1 - mean) / std
     feature_list2 = (feature_list2 - mean) / std
 
-    dist = calculate_frechet_distance(
+    frechet_dist = calculate_frechet_distance(
         mu1=np.mean(feature_list1, axis=0),
         sigma1=np.cov(feature_list1, rowvar=False),
         mu2=np.mean(feature_list2, axis=0),
         sigma2=np.cov(feature_list2, rowvar=False),
     )
-    return dist
+    avg_dist = calculate_avg_distance(feature_list2)
+    return frechet_dist, avg_dist
 
 
 if __name__ == "__main__":
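`calculate_frechet_feature_distance` now returns a `(FID, average pairwise distance)` pair, so the diversity numbers (`Dist_k` / `Dist_g` below) come for free with each FID call. A toy check on matched Gaussian features, assuming the renamed module and its dependencies import cleanly:

```python
# Two feature sets drawn from the same distribution: the FID should be small
# (sampling noise only), while the average pairwise distance stays clearly
# positive, since it measures spread rather than distribution mismatch.
import numpy as np
from tools.calculate_fid_scores import calculate_frechet_feature_distance

rng = np.random.default_rng(0)
real = [rng.standard_normal(32) for _ in range(500)]
fake = [rng.standard_normal(32) for _ in range(500)]

fid, avg_dist = calculate_frechet_feature_distance(real, fake)
print(f"FID: {fid:.4f}, avg pairwise dist: {avg_dist:.4f}")
```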
@@ -199,31 +217,34 @@ def calculate_frechet_feature_distance(feature_list1, feature_list2):
         "kinetic": [np.load(f) for f in glob.glob("./data/aist_features/*_kinetic.npy")],
         "manual": [np.load(f) for f in glob.glob("./data/aist_features/*_manual.npy")],
     }
-    
+
     # set smpl
     smpl = SMPL(model_path="/mnt/data/smpl/", gender='MALE', batch_size=1)
 
     # get motion features for the results
     result_features = {"kinetic": [], "manual": []}
     result_files = glob.glob("outputs/*.npy")
+    # result_files = [f for f in result_files if f[-8:-4] in f[:-8]]
     for result_file in tqdm.tqdm(result_files):
         result_motion = np.load(result_file)[None, ...]  # [1, 120 + 1200, 225]
         # visualize(result_motion, smpl)
-        result_features["kinetic"].append(
-            extract_feature(result_motion[:, 120:], smpl, "kinetic"))
-        result_features["manual"].append(
-            extract_feature(result_motion[:, 120:], smpl, "manual"))
-
+        keypoints3d = recover_motion_to_keypoints(result_motion[:, 120:], smpl)
+        result_features["kinetic"].append(extract_feature(keypoints3d, "kinetic"))
+        result_features["manual"].append(extract_feature(keypoints3d, "manual"))
+
     # FID metrics
-    FID_k = calculate_frechet_feature_distance(
+    FID_k, Dist_k = calculate_frechet_feature_distance(
         real_features["kinetic"], result_features["kinetic"])
-    FID_g = calculate_frechet_feature_distance(
+    FID_g, Dist_g = calculate_frechet_feature_distance(
         real_features["manual"], result_features["manual"])
 
-    # Evaluation: FID_k: ~38, FID_g: ~27
+    # Evaluation: FID_k: ~32, FID_g: ~17
+    # Evaluation: Dist_k: ~6, Dist_g: ~6
     # The AIChoreo paper used a bugged version of manual feature extractor from
     # fairmotion (see here: https://github.com/facebookresearch/fairmotion/issues/50)
     # So the FID_g here does not match with the paper. But this value should be correct.
     # In this aistplusplus_api repo the feature extractor bug has been fixed.
     # (see here: https://github.com/google/aistplusplus_api/blob/main/aist_plusplus/features/manual.py#L50)
-    print('\nEvaluation: FID_k: {:.4f}, FID_g: {:.4f}\n'.format(FID_k, FID_g))
+    print('\nEvaluation: FID_k: {:.4f}, FID_g: {:.4f}'.format(FID_k, FID_g))
+    print('Evaluation: Dist_k: {:.4f}, Dist_g: {:.4f}\n'.format(Dist_k, Dist_g))
diff --git a/tools/extract_aist_features.py b/tools/extract_aist_features.py
index d85b406..29e8a9f 100644
--- a/tools/extract_aist_features.py
+++ b/tools/extract_aist_features.py
@@ -65,5 +65,5 @@ def main(seq_name, motion_dir):
 
     # processing
     process = functools.partial(main, motion_dir=aist_dataset.motion_dir)
-    pool = multiprocessing.Pool(12)
+    pool = multiprocessing.Pool(8)
     pool.map(process, seq_names)
\ No newline at end of file
diff --git a/tools/preprocessing.py b/tools/preprocessing.py
index c5cae37..d141b37 100644
--- a/tools/preprocessing.py
+++ b/tools/preprocessing.py
@@ -20,7 +20,7 @@
     'audio_dir', '/mnt/data/AIST/music/',
     'Path to the AIST wav files.')
 flags.DEFINE_string(
-    'audio_cache_dir', '/tmp/aist_audio_feats/',
+    'audio_cache_dir', './data/aist_audio_feats/',
     'Path to cache dictionary for audio features.')
 flags.DEFINE_enum(
     'split', 'train', ['train', 'testval'],
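Both evaluation scripts resolve music beats from the cached audio features that `tools/preprocessing.py` writes into `audio_cache_dir`, reading the last channel as the beat track. A quick way to inspect one cache file (the music id `mBR0` is illustrative):

```python
# Inspect a cached audio-feature file produced by tools/preprocessing.py.
# The last channel holds the music-beat track that
# tools/calculate_beat_scores.py consumes.
import numpy as np

feats = np.load("./data/aist_audio_feats/mBR0.npy")  # (n_frames, n_channels)
music_beats = feats[:, -1] > 0.5                     # binarize the beat track
print(f"{feats.shape[0]} frames, {int(music_beats.sum())} music beats")
```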