1 Star 2 Fork 1

phoneProject / pytorch-caffe-darknet-convert

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
prototxt.py 5.88 KB
一键复制 编辑 原始数据 按行查看 历史
xiaohang 提交于 2017-08-04 10:40 . add README.md
from collections import OrderedDict
# Prefer the caffe_pb2 module shipped inside an installed Caffe package and
# fall back to a caffe_pb2.py that is importable from the module search path.
try:
    import caffe.proto.caffe_pb2 as caffe_pb2
except Exception:
    # Any failure while importing the Caffe package triggers the fallback,
    # but KeyboardInterrupt/SystemExit are no longer swallowed.
    try:
        import caffe_pb2
    except ImportError:
        # Single-argument print() is valid on both Python 2 and Python 3.
        print('caffe_pb2.py not found. Try:')
        print(' protoc caffe.proto --python_out=.')
        # The module is unusable without caffe_pb2: stop with a failure
        # status instead of relying on the interactive-only exit() helper.
        raise SystemExit(1)
def parse_caffemodel(caffemodel):
    """Load a serialized caffemodel file into a ``caffe_pb2.NetParameter``.

    Args:
        caffemodel: Path of the model file; it is opened in binary mode.

    Returns:
        The ``NetParameter`` message after ``ParseFromString`` has been fed
        the complete file contents.
    """
    model = caffe_pb2.NetParameter()
    # Single-argument print() runs on Python 2 and 3 alike; the doubled space
    # reproduces the output of the former comma-separated print statement.
    print('Loading caffemodel:  %s' % (caffemodel,))
    with open(caffemodel, 'rb') as fp:
        model.ParseFromString(fp.read())
    return model
def parse_prototxt(protofile):
    """Parse a prototxt text file into nested ``OrderedDict`` objects.

    The parser is line oriented. After removing a trailing ``#`` comment and
    surrounding whitespace, each line is classified as:

    * ``key: value`` -- stored as a string with surrounding double quotes
      removed; a key that repeats inside the same scope is collected into a
      list in order of appearance;
    * ``name {`` -- opens a nested block that runs until a line consisting
      only of ``}``;
    * anything else -- ignored.

    NOTE: a line containing ``:`` is always treated as ``key: value``, so a
    block written on a single line (``param { lr_mult: 1 }``) is not
    recognised as a block. A nested block whose name repeats within the same
    parent replaces the earlier one.

    Args:
        protofile: Path of the prototxt file to read.

    Returns:
        When at least one top-level ``layer`` block exists, an ``OrderedDict``
        with the keys ``'props'`` (all other top-level entries) and
        ``'layers'`` (list of layer blocks). Otherwise the ``props``
        ``OrderedDict`` itself is returned.

    Raises:
        ValueError: If the file ends before an open block is closed.
    """
    KEY_VALUE, BLOCK_START, OTHER = 0, 1, -1

    def clean(raw):
        # Drop the trailing comment first, then whitespace, so that a line
        # such as "} # end" is still recognised as a closing brace.
        return raw.split('#')[0].strip()

    def line_type(line):
        if ':' in line:
            return KEY_VALUE
        if '{' in line:
            return BLOCK_START
        return OTHER

    def add_entry(target, line):
        # Split on the first colon only: values may contain colons themselves
        # (for example file paths or names such as "net:v1").
        key, value = line.split(':', 1)
        key = key.strip()
        value = value.strip().strip('"')
        if key not in target:
            target[key] = value
        elif isinstance(target[key], list):
            target[key].append(value)
        else:
            # Second occurrence of the key: promote the scalar to a list.
            target[key] = [target[key], value]

    def parse_block(fp):
        block = OrderedDict()
        # iter(readline, '') stops at EOF; readline is used everywhere so the
        # nested calls and the top-level loop share a single file position.
        for raw in iter(fp.readline, ''):
            line = clean(raw)
            if line == '}':
                return block
            kind = line_type(line)
            if kind == KEY_VALUE:
                add_entry(block, line)
            elif kind == BLOCK_START:
                name = line.split('{')[0].strip()
                block[name] = parse_block(fp)
        # Reaching EOF here means the closing brace is missing; fail loudly
        # instead of spinning forever on empty reads.
        raise ValueError('unterminated block in %s: missing "}"' % (protofile,))

    props = OrderedDict()
    layers = []
    with open(protofile, 'r') as fp:
        for raw in iter(fp.readline, ''):
            line = clean(raw)
            kind = line_type(line)
            if kind == KEY_VALUE:
                add_entry(props, line)
            elif kind == BLOCK_START:
                name = line.split('{')[0].strip()
                if name == 'layer':
                    layers.append(parse_block(fp))
                else:
                    props[name] = parse_block(fp)

    if layers:
        net_info = OrderedDict()
        net_info['props'] = props
        net_info['layers'] = layers
        return net_info
    return props
def is_number(s):
    """Report whether ``s`` is accepted by ``float()``.

    Args:
        s: Candidate value, normally a string taken from a prototxt file.

    Returns:
        True if ``float(s)`` succeeds, False otherwise.
    """
    try:
        float(s)
    except (TypeError, ValueError):
        # ValueError: a string that is not numeric.
        # TypeError: an object float() cannot handle at all (None, list, ...).
        return False
    return True
def print_prototxt(net_info):
    """Print a parsed network to standard output in prototxt text form.

    Args:
        net_info: Mapping with a ``'props'`` entry, which must provide
            ``'name'``, ``'input'`` and at least four ``'input_dim'`` values,
            and a ``'layers'`` entry holding the layer blocks to print.
    """
    bare_words = ('true', 'false', 'MAX', 'SUM', 'AVE')

    def format_value(value):
        # Numbers and the bare words above are emitted verbatim; every other
        # value is wrapped in double quotes.
        if is_number(value) or value in bare_words:
            return value
        return '"%s"' % value

    def print_block(block_info, prefix, indent):
        pad = ' ' * indent
        print('%s%s {' % (pad, prefix))
        for key, value in block_info.items():
            if type(value) == OrderedDict:
                # Nested block: recurse one level (four columns) deeper.
                print_block(value, key, indent + 4)
                continue
            # A repeated key is stored as a list; emit one line per item.
            items = value if type(value) == list else [value]
            for item in items:
                print('%s %s: %s' % (pad, key, format_value(item)))
        print('%s}' % pad)

    props = net_info['props']
    layers = net_info['layers']

    # Fixed header: network name, input blob name and four input dimensions.
    print('name: "%s"' % props['name'])
    print('input: "%s"' % props['input'])
    for axis in range(4):
        print('input_dim: %s' % props['input_dim'][axis])
    print('')

    for layer in layers:
        print_block(layer, 'layer', 0)
def save_prototxt(net_info, protofile, region=True):
    """Write a parsed network to ``protofile`` in prototxt text form.

    Args:
        net_info: Mapping with a ``'props'`` entry, which must provide
            ``'name'``, ``'input'`` and at least four ``'input_dim'`` values,
            and a ``'layers'`` entry holding the layer blocks to write.
        protofile: Path of the output file; it is created or truncated.
        region: When false, layers whose ``'type'`` is ``'Region'`` are
            left out of the output.
    """
    bare_words = ('true', 'false', 'MAX', 'SUM', 'AVE')

    def format_value(value):
        # Numbers and the bare words above are emitted verbatim; every other
        # value is wrapped in double quotes.
        if is_number(value) or value in bare_words:
            return value
        return '"%s"' % value

    def write_line(fp, text):
        # Replaces the Python-2-only "print >> fp" statement; the output is
        # the same: the text followed by one newline.
        fp.write(text + '\n')

    def write_block(fp, block_info, prefix, indent):
        pad = ' ' * indent
        write_line(fp, '%s%s {' % (pad, prefix))
        for key, value in block_info.items():
            if isinstance(value, OrderedDict):
                # Nested block: recurse one level (four columns) deeper.
                write_block(fp, value, key, indent + 4)
                continue
            # A repeated key is stored as a list; emit one line per item.
            items = value if isinstance(value, list) else [value]
            for item in items:
                write_line(fp, '%s %s: %s' % (pad, key, format_value(item)))
        write_line(fp, '%s}' % pad)

    props = net_info['props']
    layers = net_info['layers']

    # The context manager closes the file even if a lookup below raises.
    with open(protofile, 'w') as fp:
        # Fixed header: network name, input blob name, four input dimensions.
        write_line(fp, 'name: "%s"' % props['name'])
        write_line(fp, 'input: "%s"' % props['input'])
        for axis in range(4):
            write_line(fp, 'input_dim: %s' % props['input_dim'][axis])
        write_line(fp, '')
        for layer in layers:
            if layer['type'] != 'Region' or region:
                write_block(fp, layer, 'layer', 0)
if __name__ == '__main__':
    import sys

    # Exactly one argument is expected: the prototxt file to round-trip.
    if len(sys.argv) != 2:
        print('Usage: python prototxt.py model.prototxt')
        # sys.exit is always available, unlike the site-provided exit();
        # the non-zero status reports the usage error to the calling shell.
        sys.exit(1)

    net_info = parse_prototxt(sys.argv[1])
    print_prototxt(net_info)
    # NOTE: this overwrites tmp.prototxt in the current working directory.
    save_prototxt(net_info, 'tmp.prototxt')
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/giteebytsl/pytorch-caffe-darknet-convert.git
git@gitee.com:giteebytsl/pytorch-caffe-darknet-convert.git
giteebytsl
pytorch-caffe-darknet-convert
pytorch-caffe-darknet-convert
master

搜索帮助

344bd9b3 5694891 D2dac590 5694891