diff --git a/mmdnn/conversion/pytorch/pytorch_graph.py b/mmdnn/conversion/pytorch/pytorch_graph.py index 91136209..fe4de165 100644 --- a/mmdnn/conversion/pytorch/pytorch_graph.py +++ b/mmdnn/conversion/pytorch/pytorch_graph.py @@ -132,7 +132,10 @@ def build(self, shape): output_str = node.__str__().split('=')[0] output_shape_str = re.findall(r'[^()!]+', output_str) if len(output_shape_str) > 1: - output_shape = [int(x.replace('!', '')) for x in output_shape_str[1].split(',')] + try: + output_shape = [int(x.replace('!', '').split(':')[0]) for x in output_shape_str[1].split(',')] + except: + output_shape = None else: output_shape = None self.shape_dict[node_name] = output_shape