diff --git a/tools/analysis_tools/get_flops.py b/tools/analysis_tools/get_flops.py index caa97203a..092520f47 100644 --- a/tools/analysis_tools/get_flops.py +++ b/tools/analysis_tools/get_flops.py @@ -38,6 +38,7 @@ def main(): cfg = Config.fromfile(args.config) init_default_scope(cfg.get('default_scope', 'mmocr')) model = MODELS.build(cfg.model) + model.eval() flops = FlopCountAnalysis(model, torch.ones(input_shape))