1
1
import base64
2
2
from pickle import dumps , loads
3
- from random import randrange
4
3
from typing import Dict , List
5
4
6
5
from .player import Player
@@ -22,15 +21,26 @@ class EvolvablePlayer(Player):
22
21
parent_class = Player
23
22
parent_kwargs = [] # type: List[str]
24
23
24
+ def __init__ (self , seed = None ):
25
+ # Parameter seed is required for reproducibility. Player will throw
26
+ # a warning to the user otherwise.
27
+ super ().__init__ ()
28
+ self .set_seed (seed = seed )
29
+
25
30
def overwrite_init_kwargs (self , ** kwargs ):
26
31
"""Use to overwrite parameters for proper cloning and testing."""
27
32
for k , v in kwargs .items ():
28
33
self .init_kwargs [k ] = v
29
34
30
35
def create_new (self , ** kwargs ):
31
- """Creates a new variant with parameters overwritten by kwargs."""
36
+ """Creates a new variant with parameters overwritten by kwargs. This differs from
37
+ cloning the Player because it propagates a seed forward, and is intended to be
38
+ used by the mutation and crossover methods."""
32
39
init_kwargs = self .init_kwargs .copy ()
33
40
init_kwargs .update (kwargs )
41
+ # Propagate seed forward for reproducibility.
42
+ if "seed" not in kwargs :
43
+ init_kwargs ["seed" ] = self ._random .random_seed_int ()
34
44
return self .__class__ (** init_kwargs )
35
45
36
46
# Serialization and deserialization. You may overwrite to obtain more human readable serializations
@@ -74,15 +84,15 @@ def copy_lists(lists: List[List]) -> List[List]:
74
84
return list (map (list , lists ))
75
85
76
86
77
- def crossover_lists (list1 : List , list2 : List ) -> List :
78
- cross_point = randrange ( len (list1 ))
87
+ def crossover_lists (list1 : List , list2 : List , rng ) -> List :
88
+ cross_point = rng . randint ( 0 , len (list1 ))
79
89
new_list = list (list1 [:cross_point ]) + list (list2 [cross_point :])
80
90
return new_list
81
91
82
92
83
- def crossover_dictionaries (table1 : Dict , table2 : Dict ) -> Dict :
93
+ def crossover_dictionaries (table1 : Dict , table2 : Dict , rng ) -> Dict :
84
94
keys = list (table1 .keys ())
85
- cross_point = randrange ( len (keys ))
95
+ cross_point = rng . randint ( 0 , len (keys ))
86
96
new_items = [(k , table1 [k ]) for k in keys [:cross_point ]]
87
97
new_items += [(k , table2 [k ]) for k in keys [cross_point :]]
88
98
new_table = dict (new_items )
0 commit comments