diff --git a/examples/benchmarks/TRA/src/model.py b/examples/benchmarks/TRA/src/model.py index ebafd6a521..a760c5ee2a 100644 --- a/examples/benchmarks/TRA/src/model.py +++ b/examples/benchmarks/TRA/src/model.py @@ -1,3 +1,4 @@ +import ast # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. @@ -51,7 +52,7 @@ def __init__( self.logger = get_module_logger("TRA") self.logger.info("TRA Model...") - self.model = eval(model_type)(**model_config).to(device) + self.model = ast.literal_eval(model_type)(**model_config).to(device) if model_init_state: self.model.load_state_dict(torch.load(model_init_state, map_location="cpu")["model"]) if freeze_model: