diff --git a/calflops/pytorch_ops.py b/calflops/pytorch_ops.py index b52b9fa..f95bede 100644 --- a/calflops/pytorch_ops.py +++ b/calflops/pytorch_ops.py @@ -292,7 +292,7 @@ def _einsum_flops_compute(equation, *operands): Count flops for the einsum operation. """ equation = equation.replace(" ", "") - input_shapes = [o.shape for o in operands] + input_shapes = [o_element.shape for o in operands for o_element in (o if isinstance(o, list) else [o])] # Re-map equation so that same equation with different alphabet # representations will look the same.