Skip to content

Commit b42d63a

Browse files
committed
wip
1 parent b1cdf1d commit b42d63a

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

experiments/mnist/mnist_classifier_from_scratch.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def print_mean_std(name, v):
3535
# Always use np.float32, to avoid floating errors in descaling + stats.
3636
v = jsa.asarray(data, dtype=np.float32)
3737
m, s = np.mean(v), np.std(v)
38+
# print(data)
3839
print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")
3940

4041

@@ -119,10 +120,10 @@ def data_stream():
119120
batches = data_stream()
120121
params = init_random_params(param_scale, layer_sizes)
121122
# Transform parameters to `ScaledArray` and proper dtype.
122-
params = jsa.as_scaled_array(params, scale=scale_dtype(1))
123+
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
123124
params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
124125

125-
@jit
126+
# @jit
126127
@jsa.autoscale
127128
def update(params, batch):
128129
grads = grad(loss)(params, batch)
@@ -150,9 +151,9 @@ def update(params, batch):
150151
epoch_time = time.time() - start_time
151152

152153
# Evaluation in float32, for consistency.
153-
raw_params = jsa.asarray(params, dtype=np.float32)
154-
train_acc = accuracy(raw_params, (train_images, train_labels))
155-
test_acc = accuracy(raw_params, (test_images, test_labels))
156-
print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
157-
print(f"Training set accuracy {train_acc:0.5f}")
158-
print(f"Test set accuracy {test_acc:0.5f}")
154+
# raw_params = jsa.asarray(params, dtype=np.float32)
155+
# train_acc = accuracy(raw_params, (train_images, train_labels))
156+
# test_acc = accuracy(raw_params, (test_images, test_labels))
157+
# print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
158+
# print(f"Training set accuracy {train_acc:0.5f}")
159+
# print(f"Test set accuracy {test_acc:0.5f}")

0 commit comments

Comments
 (0)