diff --git a/op/conv2d_gradfix.py b/op/conv2d_gradfix.py index 64229c5a..150ac3f5 100755 --- a/op/conv2d_gradfix.py +++ b/op/conv2d_gradfix.py @@ -1,5 +1,6 @@ import contextlib import warnings +from packaging import version import torch from torch import autograd @@ -82,7 +83,7 @@ def could_use_op(input): if input.device.type != "cuda": return False - if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]): + if version.parse(torch.__version__) >= version.parse('1.7'): return True warnings.warn(