
Commit ff1062d

Add --reset_learning_rate option to lstmtraining (#3470)
When the --reset_learning_rate option is specified, the learning rate stored in each layer of the network loaded with --continue_from is reset to the value given by the --learning_rate option. If a checkpoint is available, the option does nothing.
1 parent: d8bd78f
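As a usage sketch (illustrative only: the paths and the flags other than --continue_from, --reset_learning_rate and --learning_rate are typical lstmtraining placeholders, not part of this commit), a fine-tuning run that discards the learning rates stored in the starting model might look like:

lstmtraining \
  --continue_from eng.lstm \
  --traineddata eng/eng.traineddata \
  --model_output output/ft \
  --train_listfile train.list \
  --reset_learning_rate \
  --learning_rate 0.0005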

3 files changed: +34 −0 lines

src/lstm/lstmrecognizer.h

Lines changed: 20 additions & 0 deletions
@@ -157,6 +157,26 @@ class TESS_API LSTMRecognizer {
     series->ScaleLayerLearningRate(&id[1], factor);
   }
 
+  // Set all the learning rates to the given value.
+  void SetLearningRate(float learning_rate)
+  {
+    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
+    learning_rate_ = learning_rate;
+    if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
+      for (auto &id : EnumerateLayers()) {
+        SetLayerLearningRate(id, learning_rate);
+      }
+    }
+  }
+  // Set the learning rate of the layer with the given id to the given value.
+  void SetLayerLearningRate(const std::string &id, float learning_rate)
+  {
+    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
+    ASSERT_HOST(id.length() > 1 && id[0] == ':');
+    auto *series = static_cast<Series *>(network_);
+    series->SetLayerLearningRate(&id[1], learning_rate);
+  }
+
   // Converts the network to int if not already.
   void ConvertToInt() {
     if ((training_flags_ & TF_INT_MODE) == 0) {
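As an illustrative caller (a sketch, not part of the commit: the helper function, the trainer object, and the layer id ":0" are assumptions; real ids are the colon-prefixed strings that EnumerateLayers() yields, as the assertion above requires):

// Sketch: reset every stored rate, then override a single layer.
// LSTMTrainer derives from LSTMRecognizer, so both setters are available.
void ResetRates(tesseract::LSTMTrainer &trainer) {
  trainer.SetLearningRate(1e-3f);             // network-wide reset; with
                                              // NF_LAYER_SPECIFIC_LR this also
                                              // updates every layer's rate
  trainer.SetLayerLearningRate(":0", 5e-4f);  // per-layer override (hypothetical id)
}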

src/lstm/plumbing.h

Lines changed: 8 additions & 0 deletions
@@ -120,6 +120,14 @@ class Plumbing : public Network {
     ASSERT_HOST(lr_ptr != nullptr);
     *lr_ptr *= factor;
   }
+
+  // Set the learning rate for a specific layer of the stack to the given value.
+  void SetLayerLearningRate(const char *id, float learning_rate) {
+    float *lr_ptr = LayerLearningRatePtr(id);
+    ASSERT_HOST(lr_ptr != nullptr);
+    *lr_ptr = learning_rate;
+  }
+
   // Returns a pointer to the learning rate for the given layer id.
   TESS_API
   float *LayerLearningRatePtr(const char *id);
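For contrast with the existing ScaleLayerLearningRate just above it, a minimal sketch of the two semantics through the shared accessor ("0" is a hypothetical Plumbing-relative layer id, i.e. without the leading colon that the recognizer strips off):

float *lr_ptr = LayerLearningRatePtr("0");
ASSERT_HOST(lr_ptr != nullptr);
*lr_ptr *= 0.5f;  // ScaleLayerLearningRate: relative adjustment
*lr_ptr = 1e-3f;  // SetLayerLearningRate: absolute reset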

src/training/lstmtraining.cpp

Lines changed: 6 additions & 0 deletions
@@ -36,6 +36,8 @@ static INT_PARAM_FLAG(perfect_sample_delay, 0, "How many imperfect samples betwe
 static DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.");
 static DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights.");
 static DOUBLE_PARAM_FLAG(learning_rate, 10.0e-4, "Weight factor for new deltas.");
+static BOOL_PARAM_FLAG(reset_learning_rate, false,
+                       "Resets all stored learning rates to the value specified by --learning_rate.");
 static DOUBLE_PARAM_FLAG(momentum, 0.5, "Decay factor for repeating deltas.");
 static DOUBLE_PARAM_FLAG(adam_beta, 0.999, "Decay factor for repeating deltas.");
 static INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images.");

@@ -157,6 +159,10 @@ int main(int argc, char **argv) {
       return EXIT_FAILURE;
     }
     tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
+    if (FLAGS_reset_learning_rate) {
+      trainer.SetLearningRate(FLAGS_learning_rate);
+      tprintf("Set learning rate to %f\n", static_cast<float>(FLAGS_learning_rate));
+    }
     trainer.InitIterations();
   }
   if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
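The resulting log line follows from the tprintf call above: a continued run passing --reset_learning_rate with the default --learning_rate of 10.0e-4 prints "Set learning rate to 0.001000" (plain %f formatting) right after the "Continuing from ..." message.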
