-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdetect_image.py
More file actions
57 lines (40 loc) · 1.36 KB
/
detect_image.py
File metadata and controls
57 lines (40 loc) · 1.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 21 18:51:51 2019
@author: myidispg
"""
import argparse
import torch
import numpy as np
import cv2
import os
from utilities.constants import threshold
from pose_detect import PoseDetect
from models.paf_model_v2 import StanceNet
def parse_args(argv=None):
    """Parse the command-line arguments.

    Args:
        argv: Optional list of argument strings. ``None`` lets argparse read
            ``sys.argv[1:]`` as usual.

    Returns:
        An ``argparse.Namespace`` with a single ``image_path`` attribute.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('image_path', type=str, help='The path to the image file.')
    return parser.parse_args(argv)


def build_output_path(image_path, output_dir, default_extension='.png'):
    """Return the path where the annotated copy of ``image_path`` is saved.

    The result is ``<output_dir>/<stem>_keypoints<ext>``, where ``stem`` and
    ``ext`` come from the input's file name. ``os.path.splitext`` splits on
    the last dot only, so ``my.photo.jpg`` keeps ``my.photo`` as the stem and
    ``.jpg`` as the extension.

    Args:
        image_path: Path of the input image.
        output_dir: Directory the output file is placed in.
        default_extension: Extension used when the input file name has none;
            ``cv2.imwrite`` picks its encoder from the output extension, so
            one is always required.

    Returns:
        The output file path as a string.
    """
    stem, extension = os.path.splitext(os.path.basename(image_path))
    if not extension:
        extension = default_extension
    return os.path.join(output_dir, f'{stem}_keypoints{extension}')


def main(argv=None):
    """Run pose detection on one image and save the annotated result.

    The image named on the command line is read, passed through
    ``PoseDetect.detect_poses`` and written to
    ``<cwd>/processed_images/<stem>_keypoints<ext>``.

    Args:
        argv: Optional argument list forwarded to ``parse_args``.

    Raises:
        SystemExit: if the input path does not exist, the file cannot be
            decoded as an image, or the output file cannot be written.
    """
    args = parse_args(argv)
    # Normalise Windows-style separators to '/'.
    image_path = args.image_path.replace('\\', '/')
    if not os.path.exists(image_path):
        raise SystemExit('No such path exists. Please check')

    # cv2.imread signals an unreadable/undecodable file by returning None
    # rather than raising, so check it before loading the model.
    orig_img = cv2.imread(image_path)
    if orig_img is None:
        raise SystemExit(f'Could not read an image from: {image_path}')

    output_dir = os.path.join(os.getcwd(), 'processed_images')
    os.makedirs(output_dir, exist_ok=True)
    output_path = build_output_path(image_path, output_dir)
    print(f'The processed image file will be saved in: {output_path}')

    # is_available must be *called*: the bare function object is always
    # truthy, which would select the GPU path even on CPU-only machines.
    use_gpu = torch.cuda.is_available()
    detect = PoseDetect('trained_models/trained_model.pth')
    # Perform pose detection on the given image.
    # NOTE(review): assumes detect_poses accepts use_gpu=False — confirm in pose_detect.
    orig_img = detect.detect_poses(orig_img, use_gpu=use_gpu)

    # cv2.imwrite may report failure by returning False instead of raising.
    if not cv2.imwrite(output_path, orig_img):
        raise SystemExit(f'Failed to write the processed image to: {output_path}')


if __name__ == '__main__':
    main()