forked from vadimkantorov/caffemodel2json
-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathcaffemodel2json.py
More file actions
executable file
·82 lines (73 loc) · 2.65 KB
/
caffemodel2json.py
File metadata and controls
executable file
·82 lines (73 loc) · 2.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import argparse
import json
import os
import shutil
import subprocess
import sys
import tempfile
def log(x, prefix=None):
    """Write one diagnostic line to stderr.

    Diagnostics go to stderr because stdout carries the generated JSON.

    Args:
        x: the message; it is formatted with ``%s``.
        prefix: optional tag. When given, the line becomes
            ``'<prefix right-aligned to 8 columns>: <message>'``.
    """
    if prefix is not None:
        x = '%8s: %s' % (prefix, x)
    # sys.stderr.write behaves the same on Python 2 and 3, unlike the
    # Python 2-only ``print >> sys.stderr`` statement.
    sys.stderr.write('%s\n' % (x,))
def pb2json(pb):
    """Convert a protobuf message into a JSON-serializable dict.

    Only the fields returned by ``pb.ListFields()`` (the populated ones,
    including extensions) are emitted. Nested messages are converted
    recursively and repeated fields become lists.

    Args:
        pb: a protobuf message instance.

    Returns:
        A dict mapping each field name to its converted value.

    A field whose type has no converter is reported through ``log`` and left
    out of the result.
    """
    from google.protobuf.descriptor import FieldDescriptor as FD

    # Python 2 has dedicated long/unicode types; on Python 3 they do not
    # exist, so fall back to int/str.
    try:
        long_type, text_type = long, unicode
    except NameError:
        long_type, text_type = int, str

    def escape_bytes(raw):
        """Return an ASCII-only, backslash-escaped rendering of a bytes value."""
        try:
            # Python 2: byte strings provide the 'string_escape' codec.
            return raw.encode('string_escape')
        except AttributeError:
            # Python 3: bytes has no .encode(); map each byte to one code
            # point via latin-1, then escape. Unlike 'string_escape', this
            # does not escape single quotes.
            return raw.decode('latin-1').encode('unicode_escape').decode('ascii')

    # NOTE(review): the (S)FIXED32/64 types are integer wire types but are
    # mapped to float here, as in the original; float cannot represent every
    # 64-bit integer exactly. Kept to preserve existing output -- confirm
    # before changing.
    converters = {
        FD.TYPE_DOUBLE: float,
        FD.TYPE_FLOAT: float,
        FD.TYPE_INT64: long_type,
        FD.TYPE_UINT64: long_type,
        FD.TYPE_INT32: int,
        FD.TYPE_FIXED64: float,
        FD.TYPE_FIXED32: float,
        FD.TYPE_BOOL: bool,
        FD.TYPE_STRING: text_type,
        FD.TYPE_BYTES: escape_bytes,
        FD.TYPE_UINT32: int,
        FD.TYPE_ENUM: int,
        FD.TYPE_SFIXED32: float,
        FD.TYPE_SFIXED64: float,
        FD.TYPE_SINT32: int,
        FD.TYPE_SINT64: long_type,
    }

    js = {}
    for field, value in pb.ListFields():  # only filled (including extensions)
        if field.type == FD.TYPE_MESSAGE:
            convert = pb2json
        elif field.type in converters:
            convert = converters[field.type]
        else:
            log("WARNING: Field %s.%s of type '%d' is not supported" % (pb.__class__.__name__, field.name, field.type, ))
            # Skip the field: without this, the converter would be unbound
            # (NameError) or left over from the previous field.
            continue

        if field.label == FD.LABEL_REPEATED:
            js_value = [convert(v) for v in value]
            # Add the 3 commented lines if you want just a "preview" (short) mode of the json generated (for human inspection)
            # note that the preview json is useless for actual deep learning processing
            #
            # if len(js_value) > 64 or (field.name == 'data' and len(js_value) > 8):
            #     head_n = 5
            #     js_value = js_value[:head_n] + ['(%d elements more)' % (len(js_value) - head_n)]
        else:
            js_value = convert(value)
        js[field.name] = js_value
    return js
# Command-line driver: compile caffe.proto with protoc, import the generated
# caffe_pb2 module, parse the .caffemodel and dump it as JSON on stdout.
# Progress messages go to stderr through log().

# The text is passed as ``description``; passed positionally it would be taken
# as ``prog`` and shown as the program name in the usage line.
parser = argparse.ArgumentParser(description='Dump model_name.caffemodel to a file JSON format for debugging')
parser.add_argument('caffe_proto', help='Path to caffe.proto (typically located at CAFFE_ROOT/src/caffe/proto/caffe.proto)')
parser.add_argument('model_caffemodel', help='Path to model.caffemodel')
# The default is None rather than tempfile.mkdtemp(): an eager default would
# create a directory on every run, even for -h or when the option is supplied.
parser.add_argument('--codegenDir', help='Path to an existing temporary directory to save generated protobuf Python classes', default=None)
args = parser.parse_args()

# Remember whether the directory is ours, so only a directory created here is
# ever deleted.
owns_codegen_dir = args.codegenDir is None
if owns_codegen_dir:
    args.codegenDir = tempfile.mkdtemp()

try:
    log('calling protoc', 'protobuf')
    # dirname() is '' for a bare file name and protoc rejects an empty
    # --proto_path, so fall back to the current directory.
    proto_dir = os.path.dirname(args.caffe_proto) or '.'
    subprocess.check_call(['protoc', '--proto_path', proto_dir, '--python_out', args.codegenDir, args.caffe_proto])
    log('generated', 'protobuf')

    sys.path.insert(0, args.codegenDir)
    import caffe_pb2
    log('imported', 'protobuf')
finally:
    # The generated sources are only needed for the import above.
    if owns_codegen_dir:
        shutil.rmtree(args.codegenDir, ignore_errors=True)

netParam = caffe_pb2.NetParameter()
# Read through a context manager so the file handle is closed promptly.
with open(args.model_caffemodel, 'rb') as model_file:
    msg = model_file.read()
log('caffemodel read in memory. Deserialization will take a few minutes. Take a coffee!', 'model')
netParam.ParseFromString(msg)
log('deserialized', 'model')

json.dump(pb2json(netParam), sys.stdout, indent=2)
log('json saved', 'model')

log('')
log('ALLOK. Quitting')