diff --git a/nodes.py b/nodes.py index 96fa2e0..d1d27f2 100644 --- a/nodes.py +++ b/nodes.py @@ -110,6 +110,9 @@ def INPUT_TYPES(cls): 'crop_size': ('INT', {'default': 512, 'min': 512, 'max': 1024, 'step': 128}), 'crop_factor': ('FLOAT', {'default': 1.5, 'min': 1.0, 'max': 3, 'step': 0.1}), 'mask_type': (mask_types,) + }, + 'optional': { + 'image_override': ('IMAGE',), } } @@ -118,7 +121,7 @@ def INPUT_TYPES(cls): FUNCTION = 'run' CATEGORY = 'facetools' - def run(self, faces, crop_size, crop_factor, mask_type): + def run(self, faces, crop_size, crop_factor, mask_type, image_override=None): if len(faces) == 0: empty_crop = torch.zeros((1,512,512,3)) empty_mask = torch.zeros((1,512,512)) @@ -132,7 +135,7 @@ def run(self, faces, crop_size, crop_factor, mask_type): masks = [] warps = [] for face in faces: - M, crop = face.crop(crop_size, crop_factor) + M, crop = face.crop(crop_size, crop_factor, image_override) mask = mask_crop(face, M, crop, mask_type) crops.append(np.array(crop[0])) masks.append(np.array(mask[0])) diff --git a/utils.py b/utils.py index 363d177..cf9bdd4 100644 --- a/utils.py +++ b/utils.py @@ -109,7 +109,7 @@ def __init__(self, img, a, b, c, d) -> None: rot = cv2.getRotationMatrix2D((128*s,128*s), 90*i, 1) self.R = np.vstack((rot, np.array((0,0,1)))) - def crop(self, size, crop_factor): + def crop(self, size, crop_factor, image_override=None): S = np.array([[1/crop_factor, 0, 0], [0, 1/crop_factor, 0], [0, 0, 1]]) M = estimate_norm(self.kps, size) N = M @ self.R @ self.T2 @@ -117,7 +117,8 @@ def crop(self, size, crop_factor): T3 = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]]) T4 = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]]) N = N @ T4 @ S @ T3 - crop = cv2.warpAffine(self.img.numpy(), N, (size, size)) + img = self.img if image_override is None else image_override.squeeze() + crop = cv2.warpAffine(img.numpy(), N, (size, size)) crop = torch.from_numpy(crop)[None] return N, crop