Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions flopco_keras/compute_layer_flops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

def numel(w : list):
out = 1
for k in w:
def compute_conv2d_flops(layer, macs = False):
    """Count FLOPs (or MACs) for a Keras Conv2D layer.

    Parameters
    ----------
    layer : Conv2D layer exposing ``input``/``output`` tensors (with
        ``.shape``), ``kernel_size``, ``data_format`` and ``use_bias``.
    macs : bool
        When True return the multiply-accumulate count; when False return
        FLOPs (2 ops per MAC, plus one add per output element for bias).

    Returns
    -------
    int
    """
    # _, cin, h, w = input_shape
    if layer.data_format == "channels_first":
        _, input_channels, _, _ = layer.input.shape
        _, output_channels, h, w = layer.output.shape
    elif layer.data_format == "channels_last":
        _, _, _, input_channels = layer.input.shape
        _, h, w, output_channels = layer.output.shape

    w_h, w_w = layer.kernel_size

    # One MAC per kernel element, per input channel, per output position.
    # flops = h * w * output_channels * input_channels * w_h * w_w / (stride**2)
    flops = h * w * output_channels * input_channels * w_h * w_w

    if not macs:
        # Bug fix: ``use_bias`` is a bool, so the previous test
        # ``layer.use_bias is not None`` was always True and bias ops were
        # counted even for layers built with use_bias=False.
        flops_bias = numel(layer.output.shape[1:]) if layer.use_bias else 0
        flops = 2 * flops + flops_bias

    return int(flops)


def compute_fc_flops(layer, macs = False):
ft_in, ft_out = layer.input_shape[-1], layer.output_shape[-1]
ft_in, ft_out = layer.input.shape[-1], layer.output.shape[-1]
flops = ft_in * ft_out

if not macs:
Expand All @@ -51,19 +49,18 @@ def compute_fc_flops(layer, macs = False):

def compute_bn2d_flops(layer, macs = False):
    """Count FLOPs (or MACs) for a BatchNormalization layer.

    Two ops per input element (subtract mean, divide by std); when
    ``macs`` is False the gamma scale and beta shift double that count.
    """
    n_elements = 1
    for dim in layer.input.shape[1:]:
        n_elements *= dim
    # subtract, divide, gamma, beta
    flops = 2 * n_elements
    if not macs:
        flops *= 2
    return int(flops)


def compute_relu_flops(layer, macs = False):
    """Count FLOPs for a ReLU activation: one comparison per element.

    MACs are zero — ReLU performs no multiply-accumulate.
    """
    if macs:
        return 0
    n_elements = 1
    for dim in layer.input.shape[1:]:
        n_elements *= dim
    return int(n_elements)

def compute_maxpool2d_flops(layer, macs = False):
    """Count FLOPs for a MaxPooling2D layer: one comparison per kernel
    element per output value. MACs are zero.

    NOTE(review): assumes a square pool — ``pool_size[1]`` is ignored.
    """
    if macs:
        return 0
    out_elements = 1
    for dim in layer.output.shape[1:]:
        out_elements *= dim
    return layer.pool_size[0] ** 2 * out_elements

def compute_pool2d_flops(layer, macs = False):
    """Count FLOPs for a (average) Pooling2D layer: one op per kernel
    element per output value. MACs are zero.

    NOTE(review): assumes a square pool — ``pool_size[1]`` is ignored.
    """
    if macs:
        return 0
    out_elements = 1
    for dim in layer.output.shape[1:]:
        out_elements *= dim
    return layer.pool_size[0] ** 2 * out_elements

def compute_globalavgpool2d_flops(layer, macs = False):
    """Count FLOPs for a GlobalAveragePooling2D layer.

    Returns ``h * w``: the additions over one spatial plane.
    NOTE(review): the channel count and the ``macs`` flag are ignored
    here — confirm whether a per-channel cost was intended.
    """
    if layer.data_format == "channels_first":
        _, _n_in, h, w = layer.input.shape
        _, _n_out = layer.output.shape
    elif layer.data_format == "channels_last":
        _, h, w, _n_in = layer.input.shape
        _, _n_out = layer.output.shape

    return h * w

def compute_softmax_flops(layer, macs = False):

nfeatures = numel(layer.input_shape[1:])
nfeatures = numel(layer.input.shape[1:])

total_exp = nfeatures # https://stackoverflow.com/questions/3979942/what-is-the-complexity-real-cost-of-exp-in-cmath-compared-to-a-flop
total_add = nfeatures - 1
Expand Down