
Commit 1f1dc8a

updating so that convergence must be measured twice in a row
1 parent bc599ba
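The diffs below apply the same pattern to each of the five solvers (grad_exact, conj_grad_fr, conj_grad_pr, newton_exact, bfgs): conv_count is now reset to zero whenever the relative/absolute test fails, so the test must pass on two consecutive iterations before the solver stops. A minimal, self-contained sketch of that pattern follows; run_until_converged, step, and the inline form of the test are illustrative assumptions rather than gradoptorch's actual code (the reset follows the diff, and the threshold of 2 follows the commit message):

# Sketch of the new stopping logic (not the library's exact code): the
# relative/absolute test must hold on two consecutive iterations before the
# loop terminates; any failed check resets the counter.

def run_until_converged(f, step, x_k, ep_a=1e-6, ep_r=1e-4, iter_lim=1000):
    f_k = f(x_k)
    conv_count = 0
    for _ in range(iter_lim):
        x_k1 = step(x_k)                  # one gradient / CG / Newton / BFGS update
        f_k1 = f(x_k1)

        # check relative and absolute convergence criteria
        # (assumed form of the rel_abs_convergence helper; see the sketch
        #  after the gradoptorch.py diff below)
        if abs(f_k1 - f_k) <= ep_a + ep_r * abs(f_k):
            conv_count += 1
        else:
            conv_count = 0                # new in this commit: reset on any non-converged step

        x_k, f_k = x_k1, f_k1
        if conv_count >= 2:               # "twice in a row" before declaring convergence
            break
    return x_k

# example: minimize f(x) = x**2 with a fixed-step gradient update
x_min = run_until_converged(lambda x: x * x, lambda x: x - 0.1 * (2 * x), x_k=5.0)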

File tree

2 files changed: +12 -2 lines changed


gradoptorch/gradoptorch.py

Lines changed: 11 additions & 1 deletion
@@ -11,7 +11,7 @@
 default_opt_settings = {
     "ep_g": 1e-8,
     "ep_a": 1e-6,
-    "ep_r": 1e-2,
+    "ep_r": 1e-4,
     "iter_lim": 1000,
     "restart_iter": 50,  # for conjugate gradient methods gradient stability
     "Hessian": None,
@@ -195,6 +195,8 @@ def grad_exact(f, g, x_guess, opt_params, ls_method, ls_params):
         # check relative and absolute convergence criteria
         if rel_abs_convergence(f_k, f_k1, ep_a, ep_r):
             conv_count += 1
+        else:
+            conv_count = 0

         x_k = x_k1
         f_k = f_k1
@@ -281,6 +283,8 @@ def conj_grad_fr(f, g, x_guess, opt_params, ls_method, ls_params):
         # check relative and absolute convergence criteria
         if rel_abs_convergence(f_k, f_k1, ep_a, ep_r):
             conv_count += 1
+        else:
+            conv_count = 0

         x_k = x_k1
         f_k = f_k1
@@ -386,6 +390,8 @@ def conj_grad_pr(f, g, x_guess, opt_params, ls_method, ls_params):
         # check relative and absolute convergence criteria
         if rel_abs_convergence(f_k, f_k1, ep_a, ep_r):
             conv_count += 1
+        else:
+            conv_count = 0

         x_k = x_k1
         f_k = f_k1
@@ -500,6 +506,8 @@ def newton_exact(f, g, x_guess, opt_params, ls_method, ls_params):
         # check relative and absolute convergence criteria
         if rel_abs_convergence(f_k, f_k1, ep_a, ep_r):
             conv_count += 1
+        else:
+            conv_count = 0

         x_k = x_k1
         f_k = f_k1
@@ -585,6 +593,8 @@ def bfgs(f, g, x_guess, opt_params, ls_method, ls_params):
         # check relative and absolute convergence criteria
         if rel_abs_convergence(f_k, f_k1, ep_a, ep_r):
             conv_count += 1
+        else:
+            conv_count = 0

         x_k = x_k1
         f_k = f_k1
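
rel_abs_convergence itself is not shown in this diff. A common form of a combined relative/absolute test, given here only as an assumption about how ep_a and ep_r are used, is:

# Assumed form only; gradoptorch's actual helper may differ.
def rel_abs_convergence(f_k, f_k1, ep_a, ep_r):
    # pass when the change in the objective is small both in absolute terms (ep_a)
    # and relative to the current objective value (ep_r)
    return abs(f_k1 - f_k) <= ep_a + ep_r * abs(f_k)

Under that reading, the first hunk's change of ep_r from 1e-2 to 1e-4 tightens the relative tolerance from roughly 1% to roughly 0.01% of the current objective value (on top of the absolute floor ep_a), complementing the stricter two-in-a-row stopping rule.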

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gradoptorch"
-version = "0.1.0"
+version = "0.1.1"
 description = "Classical gradient based optimization in PyTorch"
 authors = ["Kevin Course <kevin.course@mail.utoronto.ca>"]
 license = "GNUv3"
