Loading model #168
base: Develop_copy
Changes from 12 commits
01f97f0
6cb1e7b
788b517
d021f05
f12c73d
73c70c8
44891c9
9766329
cf68cef
7ba12ea
f0aa6a8
09641d0
a17a284
@@ -0,0 +1,125 @@
# options: cifar10, cifar100, ImageNet16-120; test accuracy is reported for these
dataset: ImageNet16-120
# in the code base the default value for the seed is 2
# random seeds are logged, but the log files are not provided
# the paper does not mention which random seeds were used
seed: 99
# darts (or nb301)
# nb201
search_space: nasbench301
out_dir: run
optimizer: drnas

search:
  checkpoint_freq: 5
  # the default batch size in the code is 64
  batch_size: 64
  # learning rate for progressive and original: 0.025
  learning_rate: 0.025
  # learning rate for progressive and original: 0.025
  learning_rate_min: 0.001
  momentum: 0.9
  # weight_decay for progressive and original: 0.0003
  weight_decay: 0.0003
  # for cifar10 the search runs in 2 stages of 25 epochs each
  # in the code the default number of training epochs for nb201 is 100
  epochs: 100
  warm_start_epochs: 0
  grad_clip: 5
  # for cifar10 the train and optimization data (50k) are equally partitioned
  train_portion: 0.5
  # for cifar10 the train and optimization data (50k) are equally partitioned
  data_size: 25000

  # for these four args the values are the same for ordinary and progressive mode in nb201
  cutout: False
  cutout_length: 16
  cutout_prob: 1.0
  drop_path_prob: 0.0

  # for nb201 this value is False
  unrolled: False
  arch_learning_rate: 0.0003
  # not mentioned for progressive mode; for ordinary mode it is 1e-3 in nb201
  arch_weight_decay: 0.001
  output_weights: True

  fidelity: 200

  # GDAS

Review comment: In order to make yaml files generally more readable, should we focus only on the specific optimizer settings, @Neonkraft?
Reply: Agreed.
Reply: The darts_defaults.yaml was reverted to the format of the Develop_copy branch.

  tau_max: 10
  tau_min: 0.1

  # RE
  sample_size: 10
  population_size: 100

  # LS
  num_init: 10

  # GSparsity -> uncomment the lines below for GSparsity
  #seed: 50
  #grad_clip: 0
  #threshold: 0.000001
  #weight_decay: 120
  #learning_rate: 0.01
  #momentum: 0.8
  #normalization: div
  #normalization_exponent: 0.5
  #batch_size: 256
  #learning_rate_min: 0.0001
  #epochs: 100
  #warm_start_epochs: 0
  #train_portion: 0.9
  #data_size: 25000

  # BANANAS
  k: 10
  num_ensemble: 3
  acq_fn_type: its
  acq_fn_optimization: mutation
  encoding_type: path
  num_arches_to_mutate: 2
  max_mutations: 1
  num_candidates: 100

  # BasePredictor
  predictor_type: var_sparse_gp
  debug_predictor: False

evaluation:
  checkpoint_freq: 30
  # neither the paper nor the code base specifies the batch size; the default value is 64
  batch_size: 64

  learning_rate: 0.025
  learning_rate_min: 0.00
  # momentum is 0.9
  momentum: 0.9
  # for cifar weight_decay is 3e-4
  weight_decay: 0.0003
  # cifar evaluation is 600 epochs; for imagenet it is 250
  epochs: 250
  # for imagenet there are 5 epochs of warm starting
  warm_start_epochs: 5
  grad_clip: 5
  # uses the whole training data of cifar10 (50K) to train from scratch for 600 epochs
  train_portion: 1.
  data_size: 50000

  # for cifar10, cutout is used to allow fair comparison with previous work
  cutout: True
  # for cifar10 the cutout length is 16
  cutout_length: 16
  # for cifar10, cutout is used to allow fair comparison with previous work
  cutout_prob: 1.0
  # for cifar the drop path prob is 0.3
  drop_path_prob: 0.2
  # for cifar the auxiliary weight is 0.4
  auxiliary_weight: 0.4

# there is a partial channel variable whose default value is 1 in ordinary mode and 4 in progressive mode
# the nb201 code mentions regularization scales for l2 and kl (used for dirichlet)
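Side note on the train_portion / data_size pair above: the 50k CIFAR-10 training images are split evenly between weight training and architecture optimization. Below is a minimal sketch of such a split using plain torchvision; the function name and loader details are illustrative only, not NASLib's actual get_train_val_loaders.

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms

def build_split_loaders(data_dir="./data", train_portion=0.5, batch_size=64, seed=99):
    torch.manual_seed(seed)                      # seed: 99 in the config above
    train_data = datasets.CIFAR10(data_dir, train=True, download=True,
                                  transform=transforms.ToTensor())
    num_train = len(train_data)                  # 50000 images
    split = int(num_train * train_portion)       # 25000, matching data_size above
    indices = torch.randperm(num_train).tolist()
    # first half trains the operation weights, second half trains the architecture
    train_loader = DataLoader(train_data, batch_size=batch_size,
                              sampler=SubsetRandomSampler(indices[:split]))
    valid_loader = DataLoader(train_data, batch_size=batch_size,
                              sampler=SubsetRandomSampler(indices[split:]))
    return train_loader, valid_loader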
@@ -88,7 +88,6 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
np.random.seed(self.config.search.seed)
torch.manual_seed(self.config.search.seed)

self.optimizer.before_training()
checkpoint_freq = self.config.search.checkpoint_freq
if self.optimizer.using_step_function:
self.scheduler = self.build_search_scheduler(

@@ -101,6 +100,8 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
else:
start_epoch = self._setup_checkpointers(resume_from, period=checkpoint_freq)

self.optimizer.before_training()

if self.optimizer.using_step_function:
self.train_queue, self.valid_queue, _ = self.build_search_dataloaders(
self.config

@@ -146,7 +147,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int

self.train_loss.update(float(train_loss.detach().cpu()))
self.val_loss.update(float(val_loss.detach().cpu()))

# break

Review comment: Can this be removed?
Reply: Agreed. Please remove.
Reply: 'break' was used for debugging and got removed in the new commit.

self.scheduler.step()

end_time = time.time()

@@ -179,7 +180,10 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
self.train_top1.avg = train_acc
self.val_top1.avg = valid_acc

self.periodic_checkpointer.step(e)
# arch_weights = self.optimizer.get_checkpointables()["arch_weights"]

Review comment: Can this be removed?
Reply: Commented code has been removed.

add_checkpointables = self.optimizer.get_checkpointables()
del add_checkpointables["model"]
self.periodic_checkpointer.step(e, **add_checkpointables)

anytime_results = self.optimizer.test_statistics()
# if anytime_results:
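To make the intent of the add_checkpointables change above concrete, here is a minimal sketch of what could end up in the saved file; the function name and payload layout are assumptions for illustration, not the actual fvcore PeriodicCheckpointer behaviour.

import torch

def save_search_checkpoint(path, epoch, graph, optimizer_checkpointables):
    # everything the optimizer reports as checkpointable, minus the model graph
    extras = dict(optimizer_checkpointables)
    extras.pop("model", None)
    payload = {"iteration": epoch, "model": graph.state_dict()}
    # e.g. op_optimizer, arch_optimizer, arch_weights from get_checkpointables()
    for name, obj in extras.items():
        payload[name] = obj.state_dict() if hasattr(obj, "state_dict") else obj
    torch.save(payload, path)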
@@ -216,8 +220,8 @@ def evaluate_oneshot(self, resume_from="", dataloader=None):
evaluate with the current one-shot weights.
"""
logger.info("Start one-shot evaluation")
self.optimizer.before_training()
self._setup_checkpointers(resume_from)
self.optimizer.before_training()

loss = torch.nn.CrossEntropyLoss()

@@ -286,7 +290,7 @@ def evaluate(
best_arch = self.optimizer.get_final_architecture()
logger.info(f"Final architecture hash: {best_arch.get_hash()}")

if best_arch.QUERYABLE:
if best_arch.QUERYABLE and (not retrain):
if metric is None:
metric = Metric.TEST_ACCURACY
result = best_arch.query(
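For context, the guard changed above means the tabular benchmark is only queried when the caller did not request retraining; otherwise the discovered architecture is trained from scratch. A rough sketch of that control flow follows; the train_from_scratch helper and the exact query signature are assumptions.

def evaluate_final_arch(best_arch, retrain, metric, dataset, dataset_api):
    if best_arch.QUERYABLE and not retrain:
        # cheap path: look the result up in the benchmark tables
        return best_arch.query(metric, dataset, dataset_api=dataset_api)
    # expensive path: retrain the architecture from scratch and measure it
    return train_from_scratch(best_arch)  # hypothetical helper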
@@ -408,8 +412,10 @@ def evaluate(
logits_valid, target_valid, "val"
)

arch_weights = self.optimizer.get_checkpointables()["arch_weights"]

scheduler.step()
self.periodic_checkpointer.step(e)
self.periodic_checkpointer.step(iteration=e, arch_weights=arch_weights)
self._log_and_reset_accuracies(e)

# Disable drop path

@@ -585,8 +591,11 @@ def _setup_checkpointers(
if resume_from:
logger.info("loading model from file {}".format(resume_from))
checkpoint = checkpointer.resume_or_load(resume_from, resume=True)
# if resume=True, loading continues from the last_checkpoint
# if resume=False, it loads from the path given as resume_from
checkpoint = checkpointer.resume_or_load(resume_from, resume=False)
if checkpointer.has_checkpoint():
self.optimizer.set_checkpointables(checkpoint)
return checkpoint.get("iteration", -1) + 1
return 0
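The two new comments describe the resume semantics. A minimal sketch of that behaviour is below; the "last_checkpoint" marker-file layout is an assumption borrowed from fvcore-style checkpointers, not necessarily NASLib's exact implementation.

import os
import torch

def resume_or_load(resume_from, output_dir, resume=True):
    # resume=True: ignore the given path and continue from the last checkpoint
    # recorded in output_dir, if one exists
    if resume:
        marker = os.path.join(output_dir, "last_checkpoint")
        if os.path.exists(marker):
            with open(marker) as f:
                resume_from = os.path.join(output_dir, f.read().strip())
    # resume=False: load exactly the file passed as resume_from
    if resume_from and os.path.exists(resume_from):
        return torch.load(resume_from, map_location="cpu")
    return {}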
@@ -138,6 +138,7 @@ def single_evaluate(self, test_data, zc_api):
logger.info("Querying the predictor")
query_time_start = time.time()

# TODO: shouldn't mode="val" be passed?

Review comment: Makes sense for me.

_, _, test_loader, _, _ = utils.get_train_val_loaders(self.config)

# Iterate over the architectures, instantiate a graph with each architecture
@@ -121,3 +121,9 @@ def get_checkpointables(self):
(dict): with name as key and object as value. e.g. graph, arch weights, optimizers, ...
"""
pass

def set_checkpointables(self, architectural_weights):
"""
Sets the objects saved in the checkpoint during the last phase of training.

Review comment: Since other functions also include this, maybe add a description of parameters and return types.
Reply: +1, agreed.
Reply: The type of the Args has been specified. This function has no return value.

"""
pass
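As the thread suggests, the docstring could spell out the argument and return types. A possible version is sketched below; the parameter is named checkpointables here to match the DARTSOptimizer implementation further down in this diff, and its description is inferred from that usage.

def set_checkpointables(self, checkpointables):
    """
    Set the objects that were saved in the checkpoint during the last phase
    of training.

    Args:
        checkpointables (dict): name -> object mapping loaded from the
            checkpoint, e.g. op_optimizer, arch_optimizer, arch_weights.

    Returns:
        None
    """
    pass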
@@ -318,3 +318,4 @@ def get_arch_as_string(self, arch):
else:
str_arch = str(arch)
return str_arch
@@ -133,6 +133,12 @@ def new_epoch(self, epoch):
"""
Just log the architecture weights.
"""
# print("=====================================")

Review comment: Why was this code added? Can it be removed?
Reply: The extra code, which was used for debugging, has been removed.

# if self.architectural_weights.is_cuda:
# print("The tensor is on a GPU with index:", self.architectural_weights.get_device())
# else:
# print("The tensor is not on a GPU.")
# print("=====================================")
alpha_str = [
", ".join(["{:+.06f}".format(x) for x in a])
+ ", {}".format(np.argmax(a.detach().cpu().numpy()))

@@ -200,6 +206,11 @@ def get_op_optimizer(self):
def get_model_size(self):
return count_parameters_in_MB(self.graph)

def set_checkpointables(self, checkpointables):
self.op_optimizer = checkpointables.get("op_optimizer")
self.arch_optimizer = checkpointables.get("arch_optimizer")
self.architectural_weights = checkpointables.get("arch_weights")

def test_statistics(self):
# nb301 is not there but we use it anyway to generate the arch strings.
# if self.graph.QUERYABLE:
Review comment: Was this change only used for testing? 50 epochs is also stated in the paper.
Reply: Yes, please revert to 50.
Reply: I reverted the epochs to their original value.