@@ -35,6 +35,7 @@ def print_mean_std(name, v):
3535 # Always use np.float32, to avoid floating errors in descaling + stats.
3636 v = jsa .asarray (data , dtype = np .float32 )
3737 m , s = np .mean (v ), np .std (v )
38+ # print(data)
3839 print (f"{ name } : MEAN({ m :.4f} ) / STD({ s :.4f} ) / SCALE({ scale :.4f} )" )
3940
4041
@@ -119,10 +120,10 @@ def data_stream():
119120 batches = data_stream ()
120121 params = init_random_params (param_scale , layer_sizes )
121122 # Transform parameters to `ScaledArray` and proper dtype.
122- params = jsa .as_scaled_array (params , scale = scale_dtype (1 ))
123+ params = jsa .as_scaled_array (params , scale = scale_dtype (param_scale ))
123124 params = jax .tree_map (lambda v : v .astype (training_dtype ), params , is_leaf = jsa .core .is_scaled_leaf )
124125
125- @jit
126+ # @jit
126127 @jsa .autoscale
127128 def update (params , batch ):
128129 grads = grad (loss )(params , batch )
@@ -150,9 +151,9 @@ def update(params, batch):
150151 epoch_time = time .time () - start_time
151152
152153 # Evaluation in float32, for consistency.
153- raw_params = jsa .asarray (params , dtype = np .float32 )
154- train_acc = accuracy (raw_params , (train_images , train_labels ))
155- test_acc = accuracy (raw_params , (test_images , test_labels ))
156- print (f"Epoch { epoch } in { epoch_time :0.2f} sec" )
157- print (f"Training set accuracy { train_acc :0.5f} " )
158- print (f"Test set accuracy { test_acc :0.5f} " )
154+ # raw_params = jsa.asarray(params, dtype=np.float32)
155+ # train_acc = accuracy(raw_params, (train_images, train_labels))
156+ # test_acc = accuracy(raw_params, (test_images, test_labels))
157+ # print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
158+ # print(f"Training set accuracy {train_acc:0.5f}")
159+ # print(f"Test set accuracy {test_acc:0.5f}")
0 commit comments