diff --git a/README.md b/README.md index 84697baf..644eb0d7 100644 --- a/README.md +++ b/README.md @@ -235,7 +235,7 @@ The Simulated Bifurcation algorithm stops after a certain number of iterations, At regular intervals, the energy of the agents is sampled and compared with its previous value to calculate their stability period. If an agent's stability period exceeds a convergence threshold, it is considered to have converged and its value is frozen. If all agents converge before the maximum number of iterations has been reached, the algorithm stops. - The sampling period and the convergence threshold are respectively set using the `sampling_period` and `convergence_threshold` parameters of the `minimize` and `maximize` functions. -- To use early stopping in the SB algorithm, set the `use_window` parameter to `True`. +- To use early stopping in the SB algorithm, set the `early_stopping` parameter to `True`. - If only some agents have converged when the maximum number of iterations is reached, the algorithm stops and only these agents are considered in the results. ```python @@ -247,7 +247,7 @@ sb.minimize( matrix, sampling_period=30, convergence_threshold=50, - use_window=True, + early_stopping=True, ) ``` diff --git a/docs/conf.py b/docs/conf.py index c96595b5..03e85c07 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,7 +14,7 @@ project = "Simulated Bifurcation" copyright = "2023, Romain Ageron, Thomas Bouquet and Lorenzo Pugliese" author = "Romain Ageron, Thomas Bouquet and Lorenzo Pugliese" -release = "2.0.1" +release = "1.3.0.dev0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/pages/general/background.rst b/docs/pages/general/background.rst index b6123589..a08ef99e 100644 --- a/docs/pages/general/background.rst +++ b/docs/pages/general/background.rst @@ -109,10 +109,15 @@ The Simulated Bifurcation algorithm stops after a certain number of iterations o computation timeout is reached. However, this implementation comes with the possibility to perform early stopping and save computation time by defining convergence conditions. -At regular intervals, the state of the spins is sampled and compared with its previous value to calculate -their stability period. If an agent's stability period exceeds a convergence threshold, it is considered -to have converged and its value is frozen. If all agents converge before the maximum number of iterations -has been reached, the algorithm stops. +At regular intervals (this interval being called a sampling period), the agents (spin vectors) are +sampled and their Ising energy is compared with the previously sampled value. If the energy is +unchanged, the stability period of the agent is increased. If an agent's stability period exceeds a +convergence threshold, it is considered to have converged and its state is frozen. If all agents converge +before the maximum number of iterations has been reached, the algorithm stops early. + +The purpose of sampling the spins at regular intervals is to decorrelate successive samples and make their +stability more informative about the agents' convergence (because the evolution of the spins is *slow*, +most of the spins are not expected to change from one time step to the next).
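To make the convergence test described above concrete, here is a minimal, self-contained sketch of the per-agent bookkeeping it implies. This is illustrative only: the function name `update_stability` is an assumption for the example, not part of the package API; the `(n_spins, n_agents)` tensor layout follows the convention used elsewhere in this diff.

```python
import torch


def update_stability(
    ising_tensor: torch.Tensor,       # (n_spins, n_spins) Ising coupling matrix
    spins: torch.Tensor,              # (n_spins, n_agents), one agent per column
    previous_energies: torch.Tensor,  # (n_agents,) energies at the previous sampling
    stability: torch.Tensor,          # (n_agents,) current stability periods
):
    # Per-agent Ising energy (up to a constant factor): sum_ij s_i J_ij s_j.
    energies = torch.einsum("ia,ij,ja->a", spins, ising_tensor, spins)
    # Agents whose energy did not change since the last sampling extend their
    # stability period; the others are reset to zero.
    stability = torch.where(
        torch.eq(energies, previous_energies),
        stability + 1,
        torch.zeros_like(stability),
    )
    return energies, stability
```

An agent whose stability period reaches the convergence threshold is frozen; once every agent is frozen, the iterations stop before the maximum number of steps is reached.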
Notes ~~~~~ diff --git a/docs/pages/modules/optimizer.rst b/docs/pages/modules/optimizer.rst index 038f44dc..a0e21f8a 100644 --- a/docs/pages/modules/optimizer.rst +++ b/docs/pages/modules/optimizer.rst @@ -12,5 +12,5 @@ Optimizer .. autoclass:: SymplecticIntegrator :members: -.. autoclass:: StopWindow +.. autoclass:: ConvergenceChecker :members: diff --git a/src/simulated_bifurcation/core/ising.py b/src/simulated_bifurcation/core/ising.py index e5c3dd7d..72f0d429 100644 --- a/src/simulated_bifurcation/core/ising.py +++ b/src/simulated_bifurcation/core/ising.py @@ -245,7 +245,7 @@ def minimize( heated: bool = False, verbose: bool = True, *, - use_window: bool = True, + early_stopping: bool = True, sampling_period: int = 50, convergence_threshold: int = 50, timeout: Optional[float] = None, @@ -274,17 +274,17 @@ def minimize( verbose : bool, default=True Whether to display a progress bar to monitor the progress of the algorithm. - use_window : bool, default=True - Whether to use the window as a stopping criterion: an agent is - said to have converged if its energy has not changed over the - last `convergence_threshold` energy samplings (done every - `sampling_period` steps). - sampling_period : int, default=50 - Number of iterations between two consecutive energy samplings - by the window. - convergence_threshold : int, default=50 - Number of consecutive identical energy samplings considered as - a proof of convergence by the window. + early_stopping : bool, default=True, keyword-only + Whether to use early stopping or not, making agents' convergence a + stopping criterion. An agent is said to have converged if its energy + has not changed over the last `convergence_threshold` energy samplings + (done every `sampling_period` steps). + sampling_period : int, default=50, keyword-only + Number of iterations between two consecutive spins samplings used for + early stopping. + convergence_threshold : int, default=50, keyword-only + Number of consecutive identical energy samplings considered as a + proof of convergence of an agent. timeout : float | None, default=None Time in seconds after which the simulation is stopped. None means no timeout. @@ -306,7 +306,7 @@ def minimize( Warns ----- - If `use_window` is True and no agent has reached the convergence + If `early_stopping` is True and no agent has reached the convergence criterion defined by `sampling_period` and `convergence_threshold` within `max_steps` iterations, a warning is logged in the console. 
This is just an indication however; the returned vectors may still @@ -401,7 +401,7 @@ def minimize( convergence_threshold, ) tensor = self.as_simulated_bifurcation_tensor() - spins = optimizer.run_integrator(tensor, use_window) + spins = optimizer.run_integrator(tensor, early_stopping) if self.linear_term: self.computed_spins = spins[-1] * spins[:-1] else: diff --git a/src/simulated_bifurcation/core/quadratic_polynomial.py b/src/simulated_bifurcation/core/quadratic_polynomial.py index cbe34865..41e92eb0 100644 --- a/src/simulated_bifurcation/core/quadratic_polynomial.py +++ b/src/simulated_bifurcation/core/quadratic_polynomial.py @@ -318,7 +318,7 @@ def optimize( minimize: bool = True, verbose: bool = True, *, - use_window: bool = True, + early_stopping: bool = True, sampling_period: int = 50, convergence_threshold: int = 50, timeout: Optional[float] = None, @@ -344,13 +344,15 @@ def optimize( algorithm with a supplementary non-symplectic term to refine the model To stop the iterations of the symplectic integrator, a number of maximum - steps needs to be specified. However, a refined way to stop is also possible - using a window that checks that the spins have not changed among a set - number of previous steps. In practice, a every fixed number of steps - (called a sampling period) the current spins will be compared to the - previous ones. If they remain constant throughout a certain number of - consecutive samplings (called the convergence threshold), the spins are - considered to have bifurcated and the algorithm stops. + steps and/or a timeout need(s) to be specified. However, a refined way to stop + is also possible using a convergence checker that asserts that the energy + of the agents has not changed during a fixed number of steps. If so, the computation + stops earlier than expected. In practice, every fixed number of steps (called a + sampling period) the current spins will be compared to the previous + ones (energy-wise). If the energy remains constant throughout a certain number of + consecutive samplings (called the convergence threshold), the spins are considered + to have bifurcated and the algorithm stops. These spaced samplings make it possible + to decorrelate the spins and make their stability more informative. Finally, it is possible to make several particle vectors at the same time (each one being called an agent). As the vectors are randomly @@ -364,7 +366,7 @@ def optimize( ---------- * convergence_threshold : int, optional - number of consecutive identical spin sampling considered as a proof + number of consecutive identical spin samplings considered as a proof of convergence (default is 50) sampling_period : int, optional number of time steps between two spin sampling (default is 50) @@ -373,9 +375,9 @@ def optimize( (default is 10000) agents : int, optional number of vectors to make evolve at the same time (default is 128) - use_window : bool, optional - indicates whether to use the window as a stopping criterion or not - (default is True) + early_stopping : bool, optional + indicates whether to use early stopping or not, thus making agents' + convergence a stopping criterion (default is True) timeout : float | None, default=None Time in seconds after which the simulation is stopped. None means no timeout.
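For reference, a usage sketch of the renamed keyword through the module-level API, mirroring the README snippet above (the random `matrix` and the `agents` value are arbitrary choices for the example):

```python
import torch
import simulated_bifurcation as sb

matrix = torch.randn(20, 20)

# An agent is frozen once its energy is unchanged over `convergence_threshold`
# consecutive samplings, taken every `sampling_period` steps.
sb.minimize(
    matrix,
    agents=128,
    sampling_period=30,
    convergence_threshold=50,
    early_stopping=True,
)
```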
@@ -410,7 +412,7 @@ def optimize( ballistic, heated, verbose, - use_window=use_window, + early_stopping=early_stopping, sampling_period=sampling_period, convergence_threshold=convergence_threshold, timeout=timeout, @@ -434,7 +436,7 @@ def minimize( heated: bool = False, verbose: bool = True, *, - use_window: bool = True, + early_stopping: bool = True, sampling_period: int = 50, convergence_threshold: int = 50, timeout: Optional[float] = None, @@ -460,13 +462,15 @@ def minimize( algorithm with a supplementary non-symplectic term to refine the model To stop the iterations of the symplectic integrator, a number of maximum - steps needs to be specified. However, a refined way to stop is also possible - using a window that checks that the spins have not changed among a set - number of previous steps. In practice, a every fixed number of steps - (called a sampling period) the current spins will be compared to the - previous ones. If they remain constant throughout a certain number of - consecutive samplings (called the convergence threshold), the spins are - considered to have bifurcated and the algorithm stops. + steps and/or a timeout need(s) to be specified. However, a refined way to stop + is also possible using a convergence checker that asserts that the energy + of the agents has not changed during a fixed number of steps. If so, the computation + stops earlier than expected. In practice, every fixed number of steps (called a + sampling period) the current spins will be compared to the previous + ones (energy-wise). If the energy remains constant throughout a certain number of + consecutive samplings (called the convergence threshold), the spins are considered + to have bifurcated and the algorithm stops. These spaced samplings make it possible + to decorrelate the spins and make their stability more informative. Finally, it is possible to make several particle vectors at the same time (each one being called an agent). As the vectors are randomly @@ -480,7 +484,7 @@ def minimize( ---------- * convergence_threshold : int, optional - number of consecutive identical spin sampling considered as a proof + number of consecutive identical spin samplings considered as a proof of convergence (default is 50) sampling_period : int, optional number of time steps between two spin sampling (default is 50) @@ -489,9 +493,9 @@ def minimize( (default is 10000) agents : int, optional number of vectors to make evolve at the same time (default is 128) - use_window : bool, optional - indicates whether to use the window as a stopping criterion or not - (default is True) + early_stopping : bool, optional + indicates whether to use early stopping or not, thus making agents' + convergence a stopping criterion (default is True) timeout : float | None, default=None Time in seconds after which the simulation is stopped. None means no timeout. @@ -522,7 +526,7 @@ def minimize( heated, True, verbose, - use_window=use_window, + early_stopping=early_stopping, sampling_period=sampling_period, convergence_threshold=convergence_threshold, timeout=timeout, @@ -538,7 +542,7 @@ def maximize( heated: bool = False, verbose: bool = True, *, - use_window: bool = True, + early_stopping: bool = True, sampling_period: int = 50, convergence_threshold: int = 50, timeout: Optional[float] = None, @@ -564,13 +568,15 @@ def maximize( algorithm with a supplementary non-symplectic term to refine the model To stop the iterations of the symplectic integrator, a number of maximum - steps needs to be specified.
However, a refined way to stop is also possible - using a window that checks that the spins have not changed among a set - number of previous steps. In practice, a every fixed number of steps - (called a sampling period) the current spins will be compared to the - previous ones. If they remain constant throughout a certain number of - consecutive samplings (called the convergence threshold), the spins are - considered to have bifurcated and the algorithm stops. + steps and/or a timeout need(s) to be specified. However, a refined way to stop + is also possible using a convergence checker that asserts that the energy + of the agents has not changed during a fixed number of steps. If so, the computation + stops earlier than expected. In practice, every fixed number of steps (called a + sampling period) the current spins will be compared to the previous + ones (energy-wise). If the energy remains constant throughout a certain number of + consecutive samplings (called the convergence threshold), the spins are considered + to have bifurcated and the algorithm stops. These spaced samplings make it possible + to decorrelate the spins and make their stability more informative. Finally, it is possible to make several particle vectors at the same time (each one being called an agent). As the vectors are randomly @@ -583,7 +589,7 @@ def maximize( Parameters ---------- convergence_threshold : int, optional - number of consecutive identical spin sampling considered as a proof + number of consecutive identical spin samplings considered as a proof of convergence (default is 50) sampling_period : int, optional number of time steps between two spin sampling (default is 50) @@ -592,9 +598,9 @@ def maximize( (default is 10000) agents : int, optional number of vectors to make evolve at the same time (default is 128) - use_window : bool, optional - indicates whether to use the window as a stopping criterion or not - (default is True) + early_stopping : bool, optional + indicates whether to use early stopping or not, thus making agents' + convergence a stopping criterion (default is True) timeout : float | None, default=None Time in seconds after which the simulation is stopped. None means no timeout.
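Conversely, a sketch of a run with early stopping disabled; the run is then bounded only by the maximum number of steps and/or the timeout, and the optimizer refuses to start if no stopping criterion at all is provided. The `max_steps` and `timeout` names follow the docstrings in this diff; the values are arbitrary.

```python
import torch
import simulated_bifurcation as sb

matrix = torch.randn(20, 20)

# No convergence check: the symplectic integrator always runs until
# max_steps iterations or the timeout, whichever comes first.
sb.maximize(
    matrix,
    early_stopping=False,
    max_steps=20_000,
    timeout=60.0,
)
```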
@@ -625,7 +631,7 @@ def maximize( heated, False, verbose, - use_window=use_window, + early_stopping=early_stopping, sampling_period=sampling_period, convergence_threshold=convergence_threshold, timeout=timeout, diff --git a/src/simulated_bifurcation/models/abc_model.py b/src/simulated_bifurcation/models/abc_model.py index 690a7407..1f6e807c 100644 --- a/src/simulated_bifurcation/models/abc_model.py +++ b/src/simulated_bifurcation/models/abc_model.py @@ -31,7 +31,7 @@ def optimize( minimize: bool = True, verbose: bool = True, *, - use_window: bool = True, + early_stopping: bool = True, sampling_period: int = 50, convergence_threshold: int = 50, timeout: Optional[float] = None @@ -45,7 +45,7 @@ def optimize( heated, minimize, verbose, - use_window=use_window, + early_stopping=early_stopping, sampling_period=sampling_period, convergence_threshold=convergence_threshold, timeout=timeout, @@ -60,7 +60,7 @@ def minimize( heated: bool = False, verbose: bool = True, *, - use_window: bool = True, + early_stopping: bool = True, sampling_period: int = 50, convergence_threshold: int = 50, timeout: Optional[float] = None @@ -73,7 +73,7 @@ def minimize( heated, True, verbose, - use_window=use_window, + early_stopping=early_stopping, sampling_period=sampling_period, convergence_threshold=convergence_threshold, timeout=timeout, @@ -88,7 +88,7 @@ def maximize( heated: bool = False, verbose: bool = True, *, - use_window: bool = True, + early_stopping: bool = True, sampling_period: int = 50, convergence_threshold: int = 50, timeout: Optional[float] = None @@ -101,7 +101,7 @@ def maximize( heated, False, verbose, - use_window=use_window, + early_stopping=early_stopping, sampling_period=sampling_period, convergence_threshold=convergence_threshold, timeout=timeout, diff --git a/src/simulated_bifurcation/optimizer/__init__.py b/src/simulated_bifurcation/optimizer/__init__.py index d3e4be92..80f0753a 100644 --- a/src/simulated_bifurcation/optimizer/__init__.py +++ b/src/simulated_bifurcation/optimizer/__init__.py @@ -14,11 +14,11 @@ """ +from .convergence_checker import ConvergenceChecker from .environment import get_env, reset_env, set_env from .simulated_bifurcation_engine import SimulatedBifurcationEngine from .simulated_bifurcation_optimizer import ( ConvergenceWarning, SimulatedBifurcationOptimizer, ) -from .stop_window import StopWindow from .symplectic_integrator import SymplecticIntegrator diff --git a/src/simulated_bifurcation/optimizer/convergence_checker.py b/src/simulated_bifurcation/optimizer/convergence_checker.py new file mode 100644 index 00000000..d787076a --- /dev/null +++ b/src/simulated_bifurcation/optimizer/convergence_checker.py @@ -0,0 +1,185 @@ +import torch +from tqdm import tqdm + + +class ConvergenceChecker: + """ + Optimization tool to monitor agents bifurcation and convergence for the Simulated + Bifurcation (SB) algorithm. Allows an early stopping of the iterations and saves + computation time. 
+ """ + + def __init__( + self, + convergence_threshold: int, + ising_tensor: torch.Tensor, + n_agents: int, + verbose: bool, + ) -> None: + self.__check_convergence_threshold(convergence_threshold) + self.convergence_threshold = convergence_threshold + self.ising_tensor = ising_tensor + self.stability = torch.zeros( + n_agents, dtype=torch.int16, device=ising_tensor.device + ) + self.energies = torch.tensor( + [float("inf") for _ in range(n_agents)], + dtype=ising_tensor.dtype, + device=ising_tensor.device, + ) + self.progress = tqdm( + total=n_agents, + desc="🏁 Bifurcated agents", + disable=not verbose, + smoothing=0, + unit=" agents", + ) + self.stored_spins = torch.zeros( + ising_tensor.shape[0], + n_agents, + dtype=ising_tensor.dtype, + device=ising_tensor.device, + ) + self.shifted_agents_indices = torch.tensor( + list(range(n_agents)), device=ising_tensor.device + ) + + def __compute_energies(self, sampled_spins: torch.Tensor) -> torch.Tensor: + """ + Compute the Ising energy (modulo a -2 factor) of the sampled spins. + + Parameters + ---------- + sampled_spins : torch.Tensor + Sampled spins provided by the optimizer. + + Returns + ------- + torch.Tensor + The energy of each agent. + """ + return torch.nn.functional.bilinear( + sampled_spins.t(), sampled_spins.t(), torch.unsqueeze(self.ising_tensor, 0) + ).reshape(sampled_spins.shape[1]) + + def __check_convergence_threshold(self, convergence_threshold: int) -> None: + """ + Check that the provided convergence threshold is a positive integer. + + Parameters + ---------- + convergence_threshold : int + Convergence threshold that defines a convergence criterion for the agents. + + Raises + ------ + TypeError + If the convergence threshold is not an integer. + ValueError + If the convergence threshold is negative or bigger than 2**15 - 1 (32767). + """ + if not isinstance(convergence_threshold, int): + raise TypeError( + "convergence_threshold should be an integer, " + f"received {convergence_threshold}." + ) + if convergence_threshold <= 0: + raise ValueError( + "convergence_threshold should be a positive integer, " + f"received {convergence_threshold}." + ) + if convergence_threshold > torch.iinfo(torch.int16).max: + raise ValueError( + "convergence_threshold should be less than or equal to " + f"{torch.iinfo(torch.int16).max}, received {convergence_threshold}." + ) + + def update(self, sampled_spins: torch.Tensor) -> torch.Tensor: + """ + Update the stability streaks and the spins stored in the memory with sampled + spins from the Simulated Bifurcation optimizer. When an agent converges, it is + stored in the memory and removed from the optimization process. + + Return a boolean tensor that indicates which agents still have not converged. + + Parameters + ---------- + sampled_spins : torch.Tensor + Sampled spins provided by the optimizer. + + Returns + ------- + torch.Tensor + The agents that still have not converged (as a boolean tensor). + """ + self.__update_stability_streaks(sampled_spins) + self.__update_progressbar(sampled_spins.shape[1]) + return self.__store_converged_spins(sampled_spins) + + def __update_stability_streaks(self, sampled_spins: torch.Tensor): + """ + Update the stability streaks from the sampled spins provided by the optimizer. + + Parameters + ---------- + sampled_spins : torch.Tensor + Sampled spins provided by the optimizer. 
+ """ + current_agents = self.energies.shape[0] + energies = self.__compute_energies(sampled_spins) + stable_agents = torch.eq(energies, self.energies) + self.energies = energies + self.stability = torch.where( + stable_agents, + self.stability + 1, + torch.zeros(current_agents, device=self.ising_tensor.device), + ) + + def __store_converged_spins(self, sampled_spins: torch.Tensor) -> torch.Tensor: + """ + Store the newly converged agents in the memory and updates the utility tensors + by removing data relative to converged agents. + + Return a boolean tensor that indicates which agents still have not converged. + + Parameters + ---------- + sampled_spins : torch.Tensor + Sampled spins provided by the optimizer. + + Returns + ------- + torch.Tensor + The agents that still have not converged (as a boolean tensor). + """ + converged_agents = torch.eq(self.stability, self.convergence_threshold - 1) + not_converged_agents = torch.logical_not(converged_agents) + self.stored_spins[:, self.shifted_agents_indices[converged_agents]] = ( + sampled_spins[:, converged_agents] + ) + self.shifted_agents_indices = self.shifted_agents_indices[not_converged_agents] + self.energies = self.energies[not_converged_agents] + self.stability = self.stability[not_converged_agents] + return not_converged_agents + + def __update_progressbar(self, previous_agents: int): + """ + Update the progressbar with the number of newly converged agents. + + Parameters + ---------- + previous_agents : int + Previous number of agents. + """ + new_agents = self.energies.shape[0] + self.progress.update(previous_agents - new_agents) + + def get_stored_spins(self) -> torch.Tensor: + """ + Return the converged spins stored in the memory. + + Returns + ------- + torch.Tensor + """ + return self.stored_spins.clone() diff --git a/src/simulated_bifurcation/optimizer/simulated_bifurcation_optimizer.py b/src/simulated_bifurcation/optimizer/simulated_bifurcation_optimizer.py index 3df26b8b..5dac332e 100644 --- a/src/simulated_bifurcation/optimizer/simulated_bifurcation_optimizer.py +++ b/src/simulated_bifurcation/optimizer/simulated_bifurcation_optimizer.py @@ -7,9 +7,9 @@ from numpy import minimum from tqdm import tqdm +from .convergence_checker import ConvergenceChecker from .environment import ENVIRONMENT from .simulated_bifurcation_engine import SimulatedBifurcationEngine -from .stop_window import StopWindow from .symplectic_integrator import SymplecticIntegrator LOGGER = logging.getLogger("simulated_bifurcation_optimizer") @@ -45,13 +45,15 @@ class SimulatedBifurcationOptimizer: algorithm with a supplementary non-symplectic term to refine the model To stop the iterations of the symplectic integrator, a number of maximum - steps needs to be specified. However, a refined way to stop is also possible - using a window that checks that the spins have not changed among a set - number of previous steps. In practice, a every fixed number of steps - (called a sampling period) the current spins will be compared to the - previous ones. If they remain constant throughout a certain number of - consecutive samplings (called the convergence threshold), the spins are - considered to have bifurcated and the algorithm stops. + steps and/or a timeout need(s) to be specified. However, a refined way to stop + is also possible using a convergence checker that asserts that the energy + of the agents has not changed during a fixed number of steps. If so, the computation + stops earlier than expected. 
In practice, every fixed number of steps (called a + sampling period) the current spins will be compared to the previous + ones (energy-wise). If the energy remains constant throughout a certain number of + consecutive samplings (called the convergence threshold), the spins are considered + to have bifurcated and the algorithm stops. These spaced samplings make it possible + to decorrelate the spins and make their stability more informative. Finally, it is possible to make several particle vectors at the same time (each one being called an agent). As the vectors are randomly @@ -75,7 +77,7 @@ def __init__( ) -> None: # Optimizer setting self.engine = engine - self.window = None + self.convergence_checker = None self.symplectic_integrator = None self.heat_coefficient = ENVIRONMENT.heat_coefficient self.heated = engine.heated @@ -92,10 +94,10 @@ def __init__( self.max_steps = max_steps if max_steps is not None else float("inf") self.timeout = timeout if timeout is not None else float("inf") - def __reset(self, matrix: torch.Tensor, use_window: bool) -> None: + def __reset(self, matrix: torch.Tensor, early_stopping: bool) -> None: self.__init_progress_bars() self.__init_symplectic_integrator(matrix) - self.__init_window(matrix, use_window) + self.__init_convergence_checker(matrix, early_stopping) self.__init_quadratic_scale_parameter(matrix) self.run = True self.step = 0 @@ -125,14 +127,14 @@ def __init_quadratic_scale_parameter(self, matrix: torch.Tensor): 0.5 * (matrix.shape[0] - 1) ** 0.5 / (torch.sqrt(torch.sum(matrix**2))) ) - def __init_window(self, matrix: torch.Tensor, use_window: bool) -> None: - self.window = StopWindow( + def __init_convergence_checker( + self, matrix: torch.Tensor, early_stopping: bool + ) -> None: + self.convergence_checker = ConvergenceChecker( + self.convergence_threshold, matrix, self.agents, - self.convergence_threshold, - matrix.dtype, - matrix.device, - (self.verbose and use_window), + (self.verbose and early_stopping), ) def __init_symplectic_integrator(self, matrix: torch.Tensor) -> None: @@ -147,9 +149,11 @@ def __step_update(self) -> None: self.step += 1 self.iterations_progress.update() - def __check_stop(self, use_window: bool) -> None: - if use_window and self.__do_sampling: - self.run = self.window.must_continue() + def __check_stop(self, early_stopping: bool) -> None: + if early_stopping and self.__do_sampling: + stored_spins = self.convergence_checker.get_stored_spins() + not_all_agents_converged = torch.any(torch.eq(stored_spins, 0)).item() + self.run = not_all_agents_converged if not self.run: LOGGER.info("Optimizer stopped.
Reason: all agents converged.") return @@ -177,12 +181,12 @@ def __do_sampling(self) -> bool: def __close_progress_bars(self): self.iterations_progress.close() self.time_progress.close() - self.window.progress.close() + self.convergence_checker.progress.close() def __symplectic_update( self, matrix: torch.Tensor, - use_window: bool, + early_stopping: bool, ) -> torch.Tensor: self.start_time = time() while self.run: @@ -205,15 +209,26 @@ def __symplectic_update( self.__heat(momentum_copy) self.__step_update() - if use_window and self.__do_sampling: + if early_stopping and self.__do_sampling: sampled_spins = self.symplectic_integrator.sample_spins() - self.window.update(sampled_spins) + not_converged_agents = self.convergence_checker.update(sampled_spins) + # Only reshape the oscillators if some agents converged + if not torch.all(not_converged_agents).item(): + self.__remove_converged_agents(not_converged_agents) - self.__check_stop(use_window) + self.__check_stop(early_stopping) sampled_spins = self.symplectic_integrator.sample_spins() return sampled_spins + def __remove_converged_agents(self, not_converged_agents: torch.Tensor): + self.symplectic_integrator.momentum = self.symplectic_integrator.momentum[ + :, not_converged_agents + ] + self.symplectic_integrator.position = self.symplectic_integrator.position[ + :, not_converged_agents + ] + def __heat(self, momentum_copy: torch.Tensor) -> None: torch.add( self.symplectic_integrator.momentum, @@ -233,7 +248,9 @@ def __compute_symplectic_coefficients(self) -> Tuple[float, float, float]: def __pressure(self): return minimum(self.time_step * self.step * self.pressure_slope, 1.0) - def run_integrator(self, matrix: torch.Tensor, use_window: bool) -> torch.Tensor: + def run_integrator( + self, matrix: torch.Tensor, early_stopping: bool + ) -> torch.Tensor: """ Runs the Simulated Bifurcation (SB) algorithm. Given an input matrix, the SB algorithm aims at finding the groud state of the Ising model @@ -245,8 +262,8 @@ def run_integrator(self, matrix: torch.Tensor, use_window: bool) -> torch.Tensor ---------- matrix : torch.Tensor The matrix that defines the Ising model to optimize. - use_window : bool - Whether to use a stop window or not to perform early-stopping. + early_stopping : bool + Whether to perform early-stopping or not. Returns ------- @@ -261,38 +278,43 @@ def run_integrator(self, matrix: torch.Tensor, use_window: bool) -> torch.Tensor if ( self.max_steps == float("inf") and self.timeout == float("inf") - and not use_window + and not early_stopping ): raise ValueError("No stopping criterion provided.") - self.__reset(matrix, use_window) - spins = self.__symplectic_update(matrix, use_window) + self.__reset(matrix, early_stopping) + spins = self.__symplectic_update(matrix, early_stopping) self.__close_progress_bars() - return self.get_final_spins(spins, use_window) + return self.get_final_spins(spins, early_stopping) - def get_final_spins(self, spins: torch.Tensor, use_window: bool) -> torch.Tensor: + def get_final_spins( + self, spins: torch.Tensor, early_stopping: bool + ) -> torch.Tensor: """ Returns the final spins retrieved at the end of the Simulated Bifurcation (SB) algorithm. - If the stop window was used, it returns the bifurcated agents if any, + If the early stopping was used, it returns the converged agents if any, otherwise the actual final spins are returned. - If the stop window was not used, the final spins are returned. + If the early stopping was not used, the final spins are returned. 
Parameters ---------- spins : torch.Tensor The spins returned by the Simulated Bifurcation algorithm. - use_window : bool - Whether the stop window was used or not. + early_stopping : bool + Whether the early stopping was used or not. Returns ------- torch.Tensor """ - if use_window: - if not self.window.has_bifurcated_spins(): + if early_stopping: + final_spins = self.convergence_checker.get_stored_spins() + any_converged_agents = torch.any(torch.not_equal(final_spins, 0)).item() + if not any_converged_agents: warnings.warn(ConvergenceWarning(), stacklevel=2) - return self.window.get_bifurcated_spins(spins) + final_spins[:, torch.all(torch.eq(final_spins, 0), dim=0)] = spins + return final_spins else: return spins diff --git a/src/simulated_bifurcation/optimizer/stop_window.py b/src/simulated_bifurcation/optimizer/stop_window.py deleted file mode 100644 index 892b313a..00000000 --- a/src/simulated_bifurcation/optimizer/stop_window.py +++ /dev/null @@ -1,143 +0,0 @@ -from typing import Tuple, Union - -import torch -from tqdm import tqdm - - -class StopWindow: - """ - Optimization tool to monitor spins bifurcation and convergence - for the Simulated Bifurcation (SB) algorithm. - Allows an early stopping of the iterations and saves computation time. - """ - - def __init__( - self, - ising_tensor: torch.Tensor, - n_agents: int, - convergence_threshold: int, - dtype: torch.dtype, - device: Union[str, torch.device], - verbose: bool, - ) -> None: - self.ising_tensor = ising_tensor - self.n_spins = self.ising_tensor.shape[0] - self.n_agents = n_agents - self.__init_convergence_threshold(convergence_threshold) - self.dtype = dtype - self.device = device - self.__init_tensors() - self.__init_energies() - self.final_spins = self.__init_spins() - self.progress = self.__init_progress_bar(verbose) - - @property - def shape(self) -> Tuple[int, int]: - return self.n_spins, self.n_agents - - def __init_progress_bar(self, verbose: bool) -> tqdm: - return tqdm( - total=self.n_agents, - desc="🏁 Bifurcated agents", - disable=not verbose, - smoothing=0, - unit=" agents", - ) - - def __init_convergence_threshold(self, convergence_threshold: int) -> None: - if not isinstance(convergence_threshold, int): - raise TypeError( - "convergence_threshold should be an integer, " - f"received {convergence_threshold}." - ) - if convergence_threshold <= 0: - raise ValueError( - "convergence_threshold should be a positive integer, " - f"received {convergence_threshold}." - ) - if convergence_threshold > torch.iinfo(torch.int16).max: - raise ValueError( - "convergence_threshold should be less than or equal to " - f"{torch.iinfo(torch.int16).max}, received {convergence_threshold}." 
- ) - self.convergence_threshold = convergence_threshold - - def __init_tensor(self, dtype: torch.dtype) -> torch.Tensor: - return torch.zeros(self.n_agents, device=self.device, dtype=dtype) - - def __init_energies(self) -> None: - self.energies = torch.tensor( - [float("inf") for _ in range(self.n_agents)], device=self.device - ) - - def __init_tensors(self) -> None: - self.stability = self.__init_tensor(torch.int16) - self.newly_bifurcated = self.__init_tensor(torch.bool) - self.previously_bifurcated = self.__init_tensor(torch.bool) - self.bifurcated = self.__init_tensor(torch.bool) - self.stable_agents = self.__init_tensor(torch.bool) - - def __init_spins(self) -> torch.Tensor: - return torch.zeros(size=self.shape, dtype=self.dtype, device=self.device) - - def __update_final_spins(self, sampled_spins) -> None: - self.final_spins[:, self.newly_bifurcated] = sampled_spins[ - :, self.newly_bifurcated - ] - - def __set_previously_bifurcated_spins(self) -> None: - self.previously_bifurcated = torch.clone(self.bifurcated) - - def __set_newly_bifurcated_spins(self) -> None: - torch.logical_xor( - self.bifurcated, self.previously_bifurcated, out=self.newly_bifurcated - ) - - def __update_bifurcated_spins(self) -> None: - torch.eq(self.stability, self.convergence_threshold - 1, out=self.bifurcated) - - def __update_stability_streak(self) -> None: - self.stability[torch.logical_and(self.stable_agents, self.not_bifurcated)] += 1 - self.stability[torch.logical_and(self.changed_agents, self.not_bifurcated)] = 0 - - @property - def changed_agents(self) -> torch.Tensor: - return torch.logical_not(self.stable_agents) - - @property - def not_bifurcated(self) -> torch.Tensor: - return torch.logical_not(self.bifurcated) - - def __compare_energies(self, sampled_spins: torch.Tensor) -> None: - energies = torch.nn.functional.bilinear( - sampled_spins.t(), sampled_spins.t(), torch.unsqueeze(self.ising_tensor, 0) - ).reshape(self.n_agents) - torch.eq( - energies, - self.energies, - out=self.stable_agents, - ) - self.energies = energies - - def __get_number_newly_bifurcated_agents(self) -> int: - return torch.count_nonzero(self.newly_bifurcated).item() - - def update(self, sampled_spins: torch.Tensor): - self.__compare_energies(sampled_spins) - self.__update_stability_streak() - self.__update_bifurcated_spins() - self.__set_newly_bifurcated_spins() - self.__set_previously_bifurcated_spins() - self.__update_final_spins(sampled_spins) - self.progress.update(self.__get_number_newly_bifurcated_agents()) - - def must_continue(self) -> bool: - return torch.any( - torch.lt(self.stability, self.convergence_threshold - 1) - ).item() - - def has_bifurcated_spins(self) -> bool: - return torch.any(self.bifurcated).item() - - def get_bifurcated_spins(self, spins: torch.Tensor) -> torch.Tensor: - return torch.where(self.bifurcated, self.final_spins, spins) diff --git a/src/simulated_bifurcation/simulated_bifurcation.py b/src/simulated_bifurcation/simulated_bifurcation.py index 0e55e8e7..c8d90f64 100644 --- a/src/simulated_bifurcation/simulated_bifurcation.py +++ b/src/simulated_bifurcation/simulated_bifurcation.py @@ -168,7 +168,7 @@ def optimize( heated: bool = False, minimize: bool = True, verbose: bool = True, - use_window: bool = True, + early_stopping: bool = True, sampling_period: int = 50, convergence_threshold: int = 50, timeout: Optional[float] = None, @@ -242,17 +242,17 @@ def optimize( verbose : bool, default=True, keyword-only Whether to display a progress bar to monitor the progress of the algorithm. 
- use_window : bool, default=True, keyword-only - Whether to use the window as a stopping criterion. An agent is said - to have converged if its energy has not changed over the - last `convergence_threshold` energy samplings + early_stopping : bool, default=True, keyword-only + Whether to use early stopping or not, making agents' convergence a + stopping criterion. An agent is said to have converged if its energy + has not changed over the last `convergence_threshold` energy samplings (done every `sampling_period` steps). sampling_period : int, default=50, keyword-only - Number of iterations between two consecutive energy samplings by - the window. + Number of iterations between two consecutive spins samplings used for + early stopping. convergence_threshold : int, default=50, keyword-only Number of consecutive identical energy samplings considered as a - proof of convergence by the window. + proof of convergence of an agent. timeout : float | None, default=None, keyword-only Time, in seconds, after which the simulation will be stopped. None means no timeout. @@ -284,8 +284,8 @@ def optimize( Warns ----- - Use of Stop Window - If `use_window` is True and no agent has reached the convergence + Use of early stopping + If `early_stopping` is True and no agent has reached the convergence criterion defined by `sampling_period` and `convergence_threshold` within `max_steps` iterations, a warning is logged in the console. This is just an indication however; the returned vectors may still be @@ -426,7 +426,7 @@ def optimize( heated=heated, minimize=minimize, verbose=verbose, - use_window=use_window, + early_stopping=early_stopping, sampling_period=sampling_period, convergence_threshold=convergence_threshold, timeout=timeout, @@ -445,7 +445,7 @@ def minimize( ballistic: bool = False, heated: bool = False, verbose: bool = True, - use_window: bool = True, + early_stopping: bool = True, sampling_period: int = 50, convergence_threshold: int = 50, timeout: Optional[float] = None, @@ -516,17 +516,17 @@ def minimize( verbose : bool, default=True, keyword-only Whether to display a progress bar to monitor the progress of the algorithm. - use_window : bool, default=True, keyword-only - Whether to use the window as a stopping criterion. An agent is said - to have converged if its energy has not changed over the - last `convergence_threshold` energy samplings + early_stopping : bool, default=True, keyword-only + Whether to use early stopping or not, making agents' convergence a + stopping criterion. An agent is said to have converged if its energy + has not changed over the last `convergence_threshold` energy samplings (done every `sampling_period` steps). sampling_period : int, default=50, keyword-only - Number of iterations between two consecutive energy samplings by - the window. + Number of iterations between two consecutive spins samplings used for + early stopping. convergence_threshold : int, default=50, keyword-only Number of consecutive identical energy samplings considered as a - proof of convergence by the window. + proof of convergence of an agent. timeout : float | None, default=None, keyword-only Time, in seconds, after which the simulation will be stopped. None means no timeout. 
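A sketch of how the convergence warning mentioned in the `Warns` sections can be observed. It relies on the warning being emitted through Python's `warnings` machinery (as the `warnings.warn(ConvergenceWarning(), ...)` call in this diff suggests); the tiny `max_steps` value is chosen so that no energy sampling can happen and no agent can converge, and the parameter names follow the docstrings above.

```python
import warnings

import torch
import simulated_bifurcation as sb

matrix = torch.randn(10, 10)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # max_steps < sampling_period: the convergence checker never samples,
    # so no agent is declared converged and a convergence warning is issued.
    sb.minimize(
        matrix,
        max_steps=10,
        early_stopping=True,
        sampling_period=50,
        convergence_threshold=50,
        verbose=False,
    )

print([str(w.category) for w in caught])
```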
@@ -558,8 +558,8 @@ def minimize( Warns ----- - Use of Stop Window - If `use_window` is True and no agent has reached the convergence + Use of early stopping + If `early_stopping` is True and no agent has reached the convergence criterion defined by `sampling_period` and `convergence_threshold` within `max_steps` iterations, a warning is logged in the console. This is just an indication however; the returned vectors may still be @@ -690,7 +690,7 @@ def minimize( heated=heated, minimize=True, verbose=verbose, - use_window=use_window, + early_stopping=early_stopping, sampling_period=sampling_period, convergence_threshold=convergence_threshold, timeout=timeout, @@ -708,7 +708,7 @@ def maximize( ballistic: bool = False, heated: bool = False, verbose: bool = True, - use_window: bool = True, + early_stopping: bool = True, sampling_period: int = 50, convergence_threshold: int = 50, timeout: Optional[float] = None, @@ -779,17 +779,17 @@ def maximize( verbose : bool, default=True, keyword-only Whether to display a progress bar to monitor the progress of the algorithm. - use_window : bool, default=True, keyword-only - Whether to use the window as a stopping criterion. An agent is said - to have converged if its energy has not changed over the - last `convergence_threshold` energy samplings + early_stopping : bool, default=True, keyword-only + Whether to use early stopping or not, making agents' convergence a + stopping criterion. An agent is said to have converged if its energy + has not changed over the last `convergence_threshold` energy samplings (done every `sampling_period` steps). sampling_period : int, default=50, keyword-only - Number of iterations between two consecutive energy samplings by - the window. + Number of iterations between two consecutive spins samplings used for + early stopping. convergence_threshold : int, default=50, keyword-only Number of consecutive identical energy samplings considered as a - proof of convergence by the window. + proof of convergence of an agent. timeout : float | None, default=None, keyword-only Time, in seconds, after which the simulation will be stopped. None means no timeout. @@ -821,8 +821,8 @@ def maximize( Warns ----- - Use of Stop Window - If `use_window` is True and no agent has reached the convergence + Use of early stopping + If `early_stopping` is True and no agent has reached the convergence criterion defined by `sampling_period` and `convergence_threshold` within `max_steps` iterations, a warning is logged in the console. 
This is just an indication however; the returned vectors may still be @@ -953,7 +953,7 @@ def maximize( heated=heated, minimize=False, verbose=verbose, - use_window=use_window, + early_stopping=early_stopping, sampling_period=sampling_period, convergence_threshold=convergence_threshold, timeout=timeout, diff --git a/tests/models/test_sequential_markowitz.py b/tests/models/test_sequential_markowitz.py index 136df1a4..6f325ac9 100644 --- a/tests/models/test_sequential_markowitz.py +++ b/tests/models/test_sequential_markowitz.py @@ -84,7 +84,7 @@ def test_sequential_markowitz(): ) assert torch.equal(torch.tensor(-0.2), model[0]) - model.maximize(agents=128, use_window=False, verbose=False) + model.maximize(agents=128, early_stopping=False, verbose=False) assert (4, 2) == model.portfolio.shape assert torch.equal( torch.tensor([[0.0, 1.0], [0.0, 0.0], [1.0, 0.0], [1.0, 0.0]]), model.portfolio diff --git a/tests/optimizer/test_convergence_checker.py b/tests/optimizer/test_convergence_checker.py new file mode 100644 index 00000000..79023ea9 --- /dev/null +++ b/tests/optimizer/test_convergence_checker.py @@ -0,0 +1,152 @@ +import pytest +import torch + +from src.simulated_bifurcation.optimizer import ConvergenceChecker + +TENSOR = torch.tensor([[1.0, 0.5, -1.0], [0.5, 0.0, 1.0], [-1.0, 1.0, -2.0]]) +CONVERGENCE_THRESHOLD = 3 +SPINS = 3 +AGENTS = 2 +SCENARIO = [ + torch.tensor( + [ + [-1, -1], + [1, -1], + [1, -1], + ], + dtype=torch.float32, + ), + torch.tensor( + [ + [-1, -1], + [-1, 1], + [1, -1], + ], + dtype=torch.float32, + ), + torch.tensor( + [ + [-1, 1], + [1, -1], + [-1, 1], + ], + dtype=torch.float32, + ), + torch.tensor( + [ + [-1, -1], + [1, 1], + [-1, -1], + ], + dtype=torch.float32, + ), + # 1 agents has converged and was removed from the oscillators + torch.tensor( + [ + [-1], + [1], + [-1], + ], + dtype=torch.float32, + ), +] + + +def test_wrong_convergence_threshold_value(): + with pytest.raises( + TypeError, match="convergence_threshold should be an integer, received 30.0." + ): + # noinspection PyTypeChecker + ConvergenceChecker(30.0, TENSOR, AGENTS, verbose=False) + with pytest.raises( + ValueError, + match="convergence_threshold should be a positive integer, received 0.", + ): + ConvergenceChecker(0, TENSOR, AGENTS, verbose=False) + with pytest.raises( + ValueError, + match="convergence_threshold should be a positive integer, received -42.", + ): + ConvergenceChecker(-42, TENSOR, AGENTS, verbose=False) + with pytest.raises( + ValueError, + match="convergence_threshold should be less than or equal to 32767, received 32768.", + ): + ConvergenceChecker(2**15, TENSOR, AGENTS, verbose=False) + + +def test_use_scenario(): + """ + Ground state is degenerate: [-1, 1, -1] and [1, -1, 1] + both reach the minimal energy value -6. + + Test of the convergence checker's behavior on 2 agents: + - agent 1 converges to an optimal vector from step 3; + - agent 2 oscillates in the optimal space from step 2. 
+ """ + convergence_checker = ConvergenceChecker( + CONVERGENCE_THRESHOLD, TENSOR, AGENTS, verbose=False + ) + + # Initial state + assert torch.equal(convergence_checker.get_stored_spins(), torch.zeros(3, 2)) + assert torch.all(torch.isinf(convergence_checker.energies)) + assert torch.equal(convergence_checker.stability, torch.zeros(2)) + + # First update + convergence_checker.update(SCENARIO[0]) + assert torch.equal(convergence_checker.energies, torch.tensor([2.0, 0.0])) + assert torch.equal(convergence_checker.get_stored_spins(), torch.zeros((3, 2))) + assert torch.equal( + convergence_checker.stability, torch.tensor([0, 0], dtype=torch.int16) + ) + + # Second update + convergence_checker.update(SCENARIO[1]) + assert torch.equal(convergence_checker.energies, torch.tensor([0.0, -6.0])) + assert torch.equal(convergence_checker.get_stored_spins(), torch.zeros((3, 2))) + assert torch.equal( + convergence_checker.stability, torch.tensor([0, 0], dtype=torch.int16) + ) + + # Third update + convergence_checker.update(SCENARIO[2]) + assert torch.equal(convergence_checker.energies, torch.tensor([-6.0, -6.0])) + assert torch.equal(convergence_checker.get_stored_spins(), torch.zeros((3, 2))) + assert torch.equal( + convergence_checker.stability, torch.tensor([0, 1], dtype=torch.int16) + ) + + # Fourth update + convergence_checker.update(SCENARIO[3]) + assert torch.equal(convergence_checker.energies, torch.tensor([-6.0])) + assert torch.equal( + convergence_checker.get_stored_spins(), + torch.tensor( + [ + [0, -1], + [0, 1], + [0, -1], + ], + dtype=torch.float32, + ), + ) + assert torch.equal( + convergence_checker.stability, torch.tensor([1], dtype=torch.float32) + ) + + # Fourth update + convergence_checker.update(SCENARIO[4]) + assert torch.equal(convergence_checker.energies, torch.tensor([])) + assert torch.equal( + convergence_checker.get_stored_spins(), + torch.tensor( + [ + [-1, -1], + [1, 1], + [-1, -1], + ], + dtype=torch.float32, + ), + ) + assert torch.equal(convergence_checker.stability, torch.tensor([])) diff --git a/tests/optimizer/test_optimizer.py b/tests/optimizer/test_optimizer.py index 64013731..e912f921 100644 --- a/tests/optimizer/test_optimizer.py +++ b/tests/optimizer/test_optimizer.py @@ -27,7 +27,7 @@ def test_optimizer(): False, False, False, - use_window=False, + early_stopping=False, sampling_period=50, convergence_threshold=50, ) @@ -53,7 +53,7 @@ def test_optimizer_without_bifurcation(): False, False, False, - use_window=True, + early_stopping=True, sampling_period=50, convergence_threshold=50, ) @@ -70,7 +70,7 @@ def test_optimizer_without_bifurcation(): ) -def test_optimizer_with_window(): +def test_optimizer_with_convergence_checker(): torch.manual_seed(42) J = torch.tensor( [ @@ -88,7 +88,7 @@ def test_optimizer_with_window(): False, False, False, - use_window=True, + early_stopping=True, sampling_period=20, convergence_threshold=20, ) @@ -113,7 +113,7 @@ def test_optimizer_with_heating(): False, True, False, - use_window=False, + early_stopping=False, sampling_period=50, convergence_threshold=50, ) @@ -176,7 +176,7 @@ def test_timeout(): assert optimizer.simulation_time > 3.0 -def test_window(): +def test_convergence_checker(): torch.manual_seed(42) J = torch.tensor( [ diff --git a/tests/optimizer/test_stop_window.py b/tests/optimizer/test_stop_window.py deleted file mode 100644 index 25362cc2..00000000 --- a/tests/optimizer/test_stop_window.py +++ /dev/null @@ -1,238 +0,0 @@ -import pytest -import torch - -from src.simulated_bifurcation.optimizer import StopWindow 
- -TENSOR = torch.tensor([[1.0, 0.5, -1.0], [0.5, 0.0, 1.0], [-1.0, 1.0, -2.0]]) -CONVERGENCE_THRESHOLD = 3 -SPINS = 3 -AGENTS = 2 -SCENARIO = [ - torch.tensor( - [ - [-1, -1], - [1, -1], - [1, -1], - ], - dtype=torch.float32, - ), - torch.tensor( - [ - [-1, -1], - [-1, 1], - [1, -1], - ], - dtype=torch.float32, - ), - torch.tensor( - [ - [-1, 1], - [1, -1], - [-1, 1], - ], - dtype=torch.float32, - ), - torch.tensor( - [ - [-1, -1], - [1, 1], - [-1, -1], - ], - dtype=torch.float32, - ), - torch.tensor( - [ - [-1, 1], - [1, -1], - [-1, 1], - ], - dtype=torch.float32, - ), -] - - -def test_init_window(): - window = StopWindow( - TENSOR, - AGENTS, - CONVERGENCE_THRESHOLD, - dtype=torch.float32, - device="cpu", - verbose=False, - ) - assert window.n_spins == 3 - assert window.n_agents == 2 - assert window.convergence_threshold == 3 - assert window.shape == (3, 2) - assert torch.equal(window.final_spins, torch.zeros((3, 2))) - - -def test_wrong_convergence_threshold_value(): - with pytest.raises(TypeError): - # noinspection PyTypeChecker - StopWindow( - TENSOR, AGENTS, 30.0, dtype=torch.float32, device="cpu", verbose=False - ) - with pytest.raises(ValueError): - StopWindow(TENSOR, AGENTS, 0, dtype=torch.float32, device="cpu", verbose=False) - with pytest.raises(ValueError): - StopWindow( - TENSOR, AGENTS, -42, dtype=torch.float32, device="cpu", verbose=False - ) - with pytest.raises(ValueError): - StopWindow( - TENSOR, AGENTS, 2**15, dtype=torch.float32, device="cpu", verbose=False - ) - - -def test_use_scenario(): - """ - Ground state is degenerate: [-1, 1, -1] and [1, -1, 1] - both reach the minimal energy value -6. - - Test of the stop window's behavior on 2 agents: - - agent 1 converges to an optimal vector from step 3; - - agent 2 oscillates in the optimal space from step 2. 
- """ - window = StopWindow( - TENSOR, - AGENTS, - CONVERGENCE_THRESHOLD, - dtype=torch.float32, - device="cpu", - verbose=False, - ) - - # Initial state - assert torch.equal( - window.previously_bifurcated, torch.tensor([False, False], dtype=torch.bool) - ) - assert torch.all(torch.isinf(window.energies)) - - # First update - window.update(SCENARIO[0]) - assert window.must_continue() - assert not window.has_bifurcated_spins() - assert torch.equal(window.get_bifurcated_spins(SCENARIO[0]), SCENARIO[0]) - assert torch.equal(window.energies, torch.tensor([2.0, 0.0])) - assert torch.equal(window.final_spins, torch.zeros((3, 2))) - assert torch.equal(window.stability, torch.tensor([0, 0], dtype=torch.int16)) - assert torch.equal( - window.newly_bifurcated, torch.tensor([False, False], dtype=torch.bool) - ) - assert torch.equal( - window.bifurcated, torch.tensor([False, False], dtype=torch.bool) - ) - assert torch.equal( - window.stable_agents, torch.tensor([False, False], dtype=torch.bool) - ) - - # Second update - assert torch.equal( - window.previously_bifurcated, torch.tensor([False, False], dtype=torch.bool) - ) - window.update(SCENARIO[1]) - assert window.must_continue() - assert not window.has_bifurcated_spins() - assert torch.equal(window.get_bifurcated_spins(SCENARIO[1]), SCENARIO[1]) - assert torch.equal(window.energies, torch.tensor([0.0, -6.0])) - assert torch.equal(window.final_spins, torch.zeros((3, 2))) - assert torch.equal(window.stability, torch.tensor([0, 0], dtype=torch.int16)) - assert torch.equal( - window.newly_bifurcated, torch.tensor([False, False], dtype=torch.bool) - ) - assert torch.equal( - window.bifurcated, torch.tensor([False, False], dtype=torch.bool) - ) - assert torch.equal( - window.stable_agents, torch.tensor([False, False], dtype=torch.bool) - ) - - # Third update - assert torch.equal( - window.previously_bifurcated, torch.tensor([False, False], dtype=torch.bool) - ) - window.update(SCENARIO[2]) - assert window.must_continue() - assert not window.has_bifurcated_spins() - assert torch.equal(window.get_bifurcated_spins(SCENARIO[2]), SCENARIO[2]) - assert torch.equal(window.energies, torch.tensor([-6.0, -6.0])) - assert torch.equal(window.final_spins, torch.zeros((3, 2))) - assert torch.equal(window.stability, torch.tensor([0, 1], dtype=torch.int16)) - assert torch.equal( - window.newly_bifurcated, torch.tensor([False, False], dtype=torch.bool) - ) - assert torch.equal( - window.bifurcated, torch.tensor([False, False], dtype=torch.bool) - ) - assert torch.equal( - window.stable_agents, torch.tensor([False, True], dtype=torch.bool) - ) - - # Fourth update - assert torch.equal( - window.previously_bifurcated, torch.tensor([False, False], dtype=torch.bool) - ) - window.update(SCENARIO[3]) - assert window.must_continue() - assert window.has_bifurcated_spins() - assert torch.equal(window.get_bifurcated_spins(SCENARIO[3]), SCENARIO[3]) - assert torch.equal(window.energies, torch.tensor([-6.0, -6.0])) - assert torch.equal( - window.final_spins, - torch.tensor( - [ - [0, -1], - [0, 1], - [0, -1], - ], - dtype=torch.float32, - ), - ) - assert torch.equal(window.stability, torch.tensor([1, 2], dtype=torch.float32)) - assert torch.equal( - window.newly_bifurcated, torch.tensor([False, True], dtype=torch.bool) - ) - assert torch.equal(window.bifurcated, torch.tensor([False, True], dtype=torch.bool)) - assert torch.equal( - window.stable_agents, torch.tensor([True, True], dtype=torch.bool) - ) - - # Fourth update - assert torch.equal( - window.previously_bifurcated, 
torch.tensor([False, True], dtype=torch.bool) - ) - window.update(SCENARIO[4]) - assert not window.must_continue() - assert window.has_bifurcated_spins() - assert torch.equal( - window.get_bifurcated_spins(SCENARIO[4]), - torch.tensor( - [ - [-1, -1], - [1, 1], - [-1, -1], - ], - dtype=torch.float32, - ), - ) - assert torch.equal(window.energies, torch.tensor([-6.0, -6.0])) - assert torch.equal( - window.final_spins, - torch.tensor( - [ - [-1, -1], - [1, 1], - [-1, -1], - ], - dtype=torch.float32, - ), - ) - assert torch.equal(window.stability, torch.tensor([2, 2], dtype=torch.int16)) - assert torch.equal( - window.newly_bifurcated, torch.tensor([True, False], dtype=torch.bool) - ) - assert torch.equal(window.bifurcated, torch.tensor([True, True], dtype=torch.bool)) - assert torch.equal( - window.stable_agents, torch.tensor([True, True], dtype=torch.bool) - )
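To complement the new test above, a standalone sketch of how `ConvergenceChecker` is driven. Constructor and method names are taken from the new module in this diff; the installed import path `simulated_bifurcation.optimizer` is assumed (the tests use the repo-relative `src.` prefix), and the caller is responsible for dropping converged agents, as the optimizer does.

```python
import torch

from simulated_bifurcation.optimizer import ConvergenceChecker

J = torch.tensor([[1.0, 0.5, -1.0], [0.5, 0.0, 1.0], [-1.0, 1.0, -2.0]])
checker = ConvergenceChecker(3, J, n_agents=2, verbose=False)

# Two agents already sitting in the (degenerate) ground state of J.
spins = torch.tensor([[-1.0, 1.0], [1.0, -1.0], [-1.0, 1.0]])

for _ in range(3):
    # `update` returns a boolean mask of the agents that have not converged yet;
    # converged columns are stored internally and must be dropped by the caller.
    still_running = checker.update(spins)
    spins = spins[:, still_running]

# After three samplings with identical energies (convergence_threshold - 1
# consecutive stable comparisons), both agents are frozen in memory.
print(checker.get_stored_spins())
```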