Commits (29)
c711fa0 - new constructor for stop window (bqth29, Jan 2, 2024)
37e0de8 - new definition of final spins (bqth29, Jan 2, 2024)
3e223ab - remove converged agents from oscillators (bqth29, Jan 2, 2024)
5f9f325 - docstring (bqth29, Jan 2, 2024)
3859311 - replaced stop window by convergence checker (bqth29, Jan 2, 2024)
a21dcfe - Rename variable (bqth29, Jan 4, 2024)
e3b70b3 - Merge branch 'main' into feature/more_efficient_window (bqth29, Jan 12, 2024)
0d08769 - Merge branch 'main' into rename_stop_window (bqth29, Jan 12, 2024)
edee05b - Merge branch 'main' into feature/more_efficient_window (bqth29, Jan 12, 2024)
0756fa4 - Merge branch 'main' into rename_stop_window (bqth29, Jan 12, 2024)
fd15233 - package version in doc conf (bqth29, Jan 12, 2024)
384a65a - Merge branch 'rename_stop_window' of https://github.com/bqth29/simula… (bqth29, Jan 12, 2024)
60c6880 - non systematic reshape (bqth29, Jan 12, 2024)
dddf7f1 - sinc with stop window branch (bqth29, Jan 12, 2024)
85b26a2 - Merge branch 'main' into feature/more_efficient_window (bqth29, Feb 8, 2024)
9cde4f5 - Merge branch 'main' into rename_stop_window (bqth29, Feb 8, 2024)
99bbddb - lint black 24.1.1 (bqth29, Feb 8, 2024)
64bface - lint black 24.1.1 (bqth29, Feb 8, 2024)
9095838 - Merge branch 'main' into feature/more_efficient_window (bqth29, Mar 25, 2024)
e15d943 - Merge branch 'main' into rename_stop_window (bqth29, Mar 25, 2024)
a9112ec - Merge branch 'main' into rename_stop_window (bqth29, Jun 17, 2024)
67429a6 - Merge branch 'main' into feature/more_efficient_window (bqth29, Jun 17, 2024)
debf45f - Merge branch 'main' into rename_stop_window (bqth29, Jun 17, 2024)
5f22d94 - Merge branch 'main' into feature/more_efficient_window (bqth29, Jun 17, 2024)
006ea92 - Merge branch 'main' into feature/more_efficient_window (bqth29, Sep 29, 2024)
3a8a25c - Merge branch 'main' into rename_stop_window (bqth29, Sep 29, 2024)
9c0dde9 - class definition (bqth29, Sep 29, 2024)
2af52d9 - revert class definition (bqth29, Sep 29, 2024)
c7984b6 - Merge pull request #56 from bqth29/rename_stop_window (bqth29, Sep 29, 2024)
@@ -128,11 +128,9 @@ def __init_quadratic_scale_parameter(self, matrix: torch.Tensor):

def __init_window(self, matrix: torch.Tensor, use_window: bool) -> None:
self.window = StopWindow(
self.convergence_threshold,
matrix,
self.agents,
self.convergence_threshold,
matrix.dtype,
matrix.device,
(self.verbose and use_window),
)

@@ -150,7 +148,9 @@ def __step_update(self) -> None:

def __check_stop(self, use_window: bool) -> None:
if use_window and self.__do_sampling:
self.run = self.window.must_continue()
stored_spins = self.window.get_stored_spins()
some_agents_not_converged = torch.any(torch.eq(stored_spins, 0)).item()
self.run = some_agents_not_converged
if not self.run:
LOGGER.info("Optimizer stopped. Reason: all agents converged.")
return
@@ -208,13 +208,22 @@ def __symplectic_update(
self.__step_update()
if use_window and self.__do_sampling:
sampled_spins = self.symplectic_integrator.sample_spins()
self.window.update(sampled_spins)
not_converged_agents = self.window.update(sampled_spins)
self.__remove_converged_agents(not_converged_agents)

self.__check_stop(use_window)

sampled_spins = self.symplectic_integrator.sample_spins()
return sampled_spins

def __remove_converged_agents(self, not_converged_agents: torch.Tensor):
self.symplectic_integrator.momentum = self.symplectic_integrator.momentum[
:, not_converged_agents
]
self.symplectic_integrator.position = self.symplectic_integrator.position[
:, not_converged_agents
]

def __heat(self, momentum_copy: torch.Tensor) -> None:
torch.add(
self.symplectic_integrator.momentum,
@@ -292,8 +301,11 @@ def get_final_spins(self, spins: torch.Tensor, use_window: bool) -> torch.Tensor
torch.Tensor
"""
if use_window:
if not self.window.has_bifurcated_spins():
final_spins = self.window.get_stored_spins()
any_converged_agents = torch.any(torch.not_equal(final_spins, 0)).item()
if not any_converged_agents:
warnings.warn(ConvergenceWarning(), stacklevel=2)
return self.window.get_bifurcated_spins(spins)
final_spins[:, torch.all(torch.eq(final_spins, 0), dim=0)] = spins
return final_spins
else:
return spins
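
Reviewer note: the behavioural core of the change above is that get_final_spins no longer asks the window which agents bifurcated. It reads the window's stored spins (one column per agent, left at zero for agents that never converged) and fills the remaining all-zero columns with the last sampled spins. A minimal standalone sketch of that merge, with hypothetical sizes and variable names rather than the optimizer's own:

import torch

# Hypothetical sizes; stored_spins plays the role of StopWindow.get_stored_spins().
n_spins, n_agents = 4, 3
stored_spins = torch.zeros(n_spins, n_agents)
stored_spins[:, 0] = torch.tensor([1.0, -1.0, 1.0, 1.0])  # agent 0 converged in-window

# Columns that stayed all-zero belong to agents that never converged.
not_converged = torch.all(torch.eq(stored_spins, 0), dim=0)

# The integrator only keeps non-converged agents, so its last sample has
# exactly one column per remaining agent (random ±1 values stand in for it here).
last_sampled = torch.randint(0, 2, (n_spins, int(not_converged.sum()))).float() * 2 - 1

final_spins = stored_spins.clone()
final_spins[:, not_converged] = last_sampled  # same assignment pattern as in the diff
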
src/simulated_bifurcation/optimizer/stop_window.py (250 changes: 148 additions, 102 deletions)
@@ -1,51 +1,84 @@
from typing import Tuple, Union

import torch
from tqdm import tqdm


class StopWindow:

"""
Optimization tool to monitor spins bifurcation and convergence
for the Simulated Bifurcation (SB) algorithm.
Allows an early stopping of the iterations and saves computation time.
Optimization tool to monitor agents bifurcation and convergence for the Simulated
Bifurcation (SB) algorithm. Allows an early stopping of the iterations and saves
computation time.
"""

def __init__(
self,
convergence_threshold: int,
ising_tensor: torch.Tensor,
n_agents: int,
convergence_threshold: int,
dtype: torch.dtype,
device: Union[str, torch.device],
verbose: bool,
) -> None:
self.__check_convergence_threshold(convergence_threshold)
self.convergence_threshold = convergence_threshold
self.ising_tensor = ising_tensor
self.n_spins = self.ising_tensor.shape[0]
self.n_agents = n_agents
self.__init_convergence_threshold(convergence_threshold)
self.dtype = dtype
self.device = device
self.__init_tensors()
self.__init_energies()
self.final_spins = self.__init_spins()
self.progress = self.__init_progress_bar(verbose)

@property
def shape(self) -> Tuple[int, int]:
return self.n_spins, self.n_agents

def __init_progress_bar(self, verbose: bool) -> tqdm:
return tqdm(
total=self.n_agents,
self.stability = torch.zeros(
n_agents, dtype=torch.int16, device=ising_tensor.device
)
self.energies = torch.tensor(
[float("inf") for _ in range(n_agents)],
dtype=ising_tensor.dtype,
device=ising_tensor.device,
)
self.progress = tqdm(
total=n_agents,
desc="🏁 Bifurcated agents",
disable=not verbose,
smoothing=0,
unit=" agents",
)
self.stored_spins = torch.zeros(
ising_tensor.shape[0],
n_agents,
dtype=ising_tensor.dtype,
device=ising_tensor.device,
)
self.shifted_agents_indices = torch.tensor(
list(range(n_agents)), device=ising_tensor.device
)

def __init_convergence_threshold(self, convergence_threshold: int) -> None:
def __compute_energies(self, sampled_spins: torch.Tensor) -> torch.Tensor:
"""
Compute the Ising energy (modulo a -2 factor) of the sampled spins.

Parameters
----------
sampled_spins : torch.Tensor
Sampled spins provided by the optimizer.

Returns
-------
torch.Tensor
The energy of each agent.
"""
return torch.nn.functional.bilinear(
sampled_spins.t(), sampled_spins.t(), torch.unsqueeze(self.ising_tensor, 0)
).reshape(sampled_spins.shape[1])

def __check_convergence_threshold(self, convergence_threshold: int) -> None:
"""
Check that the provided convergence threshold is a positive integer.

Parameters
----------
convergence_threshold : int
Convergence threshold that defines a convergence criterion for the agents.

Raises
------
TypeError
If the convergence threshold is not an integer.
ValueError
If the convergence threshold is negative or bigger than 2**15 - 1 (32767).
"""
if not isinstance(convergence_threshold, int):
raise TypeError(
"convergence_threshold should be an integer, "
@@ -61,82 +94,95 @@ def __init_convergence_threshold(self, convergence_threshold: int) -> None:
"convergence_threshold should be less than or equal to "
f"{torch.iinfo(torch.int16).max}, received {convergence_threshold}."
)
self.convergence_threshold = convergence_threshold

def __init_tensor(self, dtype: torch.dtype) -> torch.Tensor:
return torch.zeros(self.n_agents, device=self.device, dtype=dtype)

def __init_energies(self) -> None:
self.energies = torch.tensor([float("inf") for _ in range(self.n_agents)])

def __init_tensors(self) -> None:
self.stability = self.__init_tensor(torch.int16)
self.newly_bifurcated = self.__init_tensor(torch.bool)
self.previously_bifurcated = self.__init_tensor(torch.bool)
self.bifurcated = self.__init_tensor(torch.bool)
self.stable_agents = self.__init_tensor(torch.bool)

def __init_spins(self) -> torch.Tensor:
return torch.zeros(size=self.shape, dtype=self.dtype, device=self.device)

def __update_final_spins(self, sampled_spins) -> None:
self.final_spins[:, self.newly_bifurcated] = sampled_spins[
:, self.newly_bifurcated
]

def __set_previously_bifurcated_spins(self) -> None:
self.previously_bifurcated = torch.clone(self.bifurcated)

def __set_newly_bifurcated_spins(self) -> None:
torch.logical_xor(
self.bifurcated, self.previously_bifurcated, out=self.newly_bifurcated
)

def __update_bifurcated_spins(self) -> None:
torch.eq(self.stability, self.convergence_threshold - 1, out=self.bifurcated)

def __update_stability_streak(self) -> None:
self.stability[torch.logical_and(self.stable_agents, self.not_bifurcated)] += 1
self.stability[torch.logical_and(self.changed_agents, self.not_bifurcated)] = 0

@property
def changed_agents(self) -> torch.Tensor:
return torch.logical_not(self.stable_agents)

@property
def not_bifurcated(self) -> torch.Tensor:
return torch.logical_not(self.bifurcated)

def __compare_energies(self, sampled_spins: torch.Tensor) -> None:
energies = torch.nn.functional.bilinear(
sampled_spins.t(), sampled_spins.t(), torch.unsqueeze(self.ising_tensor, 0)
).reshape(self.n_agents)
torch.eq(
energies,
self.energies,
out=self.stable_agents,
)
def update(self, sampled_spins: torch.Tensor) -> torch.Tensor:
"""
Update the stability streaks and the stored spins of the
window with sampled spins from the Simulated Bifurcation
optimizer. When an agent converges, it is stored in the
window's memory and removed from the optimization process.

Return a boolean tensor that indicates which agents still have not converged.

Parameters
----------
sampled_spins : torch.Tensor
Sampled spins provided by the optimizer.

Returns
-------
torch.Tensor
The agents that still have not converged (as a boolean tensor).
"""
self.__update_stability_streaks(sampled_spins)
self.__update_progressbar(sampled_spins.shape[1])
return self.__store_converged_spins(sampled_spins)

def __update_stability_streaks(self, sampled_spins: torch.Tensor):
"""
Update the stability streaks of the window from the
sampled spins provided by the optimizer.

Parameters
----------
sampled_spins : torch.Tensor
Sampled spins provided by the optimizer.
"""
current_agents = self.energies.shape[0]
energies = self.__compute_energies(sampled_spins)
stable_agents = torch.eq(energies, self.energies)
self.energies = energies
self.stability = torch.where(
stable_agents,
self.stability + 1,
torch.zeros(current_agents, device=self.ising_tensor.device),
)

def __get_number_newly_bifurcated_agents(self) -> int:
return torch.count_nonzero(self.newly_bifurcated).item()

def update(self, sampled_spins: torch.Tensor):
self.__compare_energies(sampled_spins)
self.__update_stability_streak()
self.__update_bifurcated_spins()
self.__set_newly_bifurcated_spins()
self.__set_previously_bifurcated_spins()
self.__update_final_spins(sampled_spins)
self.progress.update(self.__get_number_newly_bifurcated_agents())

def must_continue(self) -> bool:
return torch.any(
torch.lt(self.stability, self.convergence_threshold - 1)
).item()

def has_bifurcated_spins(self) -> bool:
return torch.any(self.bifurcated).item()

def get_bifurcated_spins(self, spins: torch.Tensor) -> torch.Tensor:
return torch.where(self.bifurcated, self.final_spins, spins)
def __store_converged_spins(self, sampled_spins: torch.Tensor) -> torch.Tensor:
"""
Store the newly converged agents in the window's memory and updates the
utility tensors by removing data relative to converged agents.

Return a boolean tensor that indicates which agents still have not converged.

Parameters
----------
sampled_spins : torch.Tensor
Sampled spins provided by the optimizer.

Returns
-------
torch.Tensor
The agents that still have not converged (as a boolean tensor).
"""
converged_agents = torch.eq(self.stability, self.convergence_threshold - 1)
not_converged_agents = torch.logical_not(converged_agents)
self.stored_spins[
:, self.shifted_agents_indices[converged_agents]
] = sampled_spins[:, converged_agents]
self.shifted_agents_indices = self.shifted_agents_indices[not_converged_agents]
self.energies = self.energies[not_converged_agents]
self.stability = self.stability[not_converged_agents]
return not_converged_agents

def __update_progressbar(self, previous_agents: int):
"""
Update the progressbar with the number of newly converged agents.

Parameters
----------
previous_agents : int
Previous number of agents.
"""
new_agents = self.energies.shape[0]
self.progress.update(previous_agents - new_agents)

def get_stored_spins(self) -> torch.Tensor:
"""
Return the converged spins stored in the window.

Returns
-------
torch.Tensor
"""
return self.stored_spins.clone()
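
For orientation, a rough end-to-end sketch of how the refactored window is driven, based only on the constructor and methods shown above. The import path, the toy Ising tensor and the fixed spins are assumptions for illustration; in the real optimizer the spins come from the symplectic integrator, and converged agents are also removed from its momentum and position tensors via __remove_converged_agents.

import torch
from simulated_bifurcation.optimizer.stop_window import StopWindow  # assumed import path

ising_tensor = torch.tensor([[0.0, 1.0], [1.0, 0.0]])  # toy 2-spin Ising tensor
n_agents = 8
window = StopWindow(ising_tensor, n_agents, 50, verbose=False)

# Stand-in for the integrator's sampled spins; kept constant so every agent's
# energy is stable and all agents converge after ~50 updates.
spins = torch.randint(0, 2, (2, n_agents)).float() * 2 - 1

for _ in range(1000):
    not_converged = window.update(spins)  # newly converged agents are stored in the window
    spins = spins[:, not_converged]       # drop converged agents, as the optimizer now does
    if not torch.any(torch.eq(window.get_stored_spins(), 0)).item():
        break                             # no zero entries left: every agent has been stored

print(window.get_stored_spins())  # (2, 8) tensor filled with the converged ±1 spins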