Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/saev/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def train(
if cfg.slurm_acct:
executor = submitit.SlurmExecutor(folder=cfg.log_to)
executor.update_parameters(
time=int(cfg.n_hours * 60),
time=int(4500),
partition=cfg.slurm_partition,
gpus_per_node=1,
ntasks_per_node=1,
Expand Down
8 changes: 4 additions & 4 deletions src/saev/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@ class DataLoad:
"""Maximum value for activations; activations will be clamped to within [-clamp, clamp]`."""
n_random_samples: int = 2**19
"""Number of random samples used to calculate approximate dataset means at startup."""
scale_mean: bool | str = True
"""Whether to subtract approximate dataset means from examples. If a string, manually load from the filepath."""
scale_norm: bool | str = True
"""Whether to scale average dataset norm to sqrt(d_vit). If a string, manually load from the filepath."""
scale_mean: bool | str = False
"""Whether to subtract approximate dataset means from examples. If a string, manually load from the filepath. Changed from True to False"""
scale_norm: bool | str = False
"""Whether to scale average dataset norm to sqrt(d_vit). If a string, manually load from the filepath. Changed from True to False"""


@beartype.beartype
Expand Down
83 changes: 43 additions & 40 deletions src/saev/interactive/features.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import marimo

__generated_with = "0.9.32"
__generated_with = "0.13.15"
app = marimo.App(width="full")


@app.cell
def __():
def _():
import json
import os
import random
Expand All @@ -17,15 +17,15 @@ def __():
import torch
import tqdm

return json, mo, np, os, pl, plt, random, torch, tqdm
return json, mo, np, os, plt, torch, tqdm


@app.cell
def __(mo, os):
def _(mo, os):
def make_ckpt_dropdown():
try:
choices = sorted(
os.listdir("/research/nfs_su_809/workspace/stevens.994/saev/features")
os.listdir("/users/PZS1151/igwilson99/local/src/saev/checkpoints/")
)

except FileNotFoundError:
Expand All @@ -34,32 +34,32 @@ def make_ckpt_dropdown():
return mo.ui.dropdown(choices, label="Checkpoint:")

ckpt_dropdown = make_ckpt_dropdown()
return ckpt_dropdown, make_ckpt_dropdown
return (ckpt_dropdown,)


@app.cell
def __(ckpt_dropdown, mo):
def _(ckpt_dropdown, mo):
mo.hstack([ckpt_dropdown], justify="start")
return


@app.cell
def __(ckpt_dropdown, mo):
def _(ckpt_dropdown, mo):
mo.stop(
ckpt_dropdown.value is None,
mo.md(
"Run `uv run main.py webapp --help` to fill out at least one checkpoint."
),
)

webapp_dir = f"/research/nfs_su_809/workspace/stevens.994/saev/features/{ckpt_dropdown.value}/sort_by_patch"
webapp_dir = f"/users/PZS1151/igwilson99/local/src/saev/checkpoints/"

get_i, set_i = mo.state(0)
return get_i, set_i, webapp_dir


@app.cell
def __(mo):
def _(mo):
sort_by_freq_btn = mo.ui.run_button(label="Sort by frequency")

sort_by_value_btn = mo.ui.run_button(label="Sort by value")
Expand All @@ -69,15 +69,15 @@ def __(mo):


@app.cell
def __(mo, sort_by_freq_btn, sort_by_latent_btn, sort_by_value_btn):
def _(mo, sort_by_freq_btn, sort_by_latent_btn, sort_by_value_btn):
mo.hstack(
[sort_by_freq_btn, sort_by_value_btn, sort_by_latent_btn], justify="start"
)
return


@app.cell
def __(
def _(
json,
mo,
os,
Expand Down Expand Up @@ -111,11 +111,11 @@ def get_neurons() -> list[dict]:
neurons = sorted(neurons, key=lambda dct: dct["log10_value"], reverse=True)

mo.md(f"Found {len(neurons)} saved neurons.")
return get_neurons, neurons
return (neurons,)


@app.cell
def __(mo, neurons, set_i):
def _(mo, neurons, set_i):
next_button = mo.ui.button(
label="Next",
on_change=lambda _: set_i(lambda v: (v + 1) % len(neurons)),
Expand All @@ -129,7 +129,7 @@ def __(mo, neurons, set_i):


@app.cell
def __(get_i, mo, neurons, set_i):
def _(get_i, mo, neurons, set_i):
neuron_slider = mo.ui.slider(
0,
len(neurons),
Expand All @@ -141,12 +141,12 @@ def __(get_i, mo, neurons, set_i):


@app.cell
def __():
def _():
return


@app.cell
def __(
def _(
display_info,
get_i,
mo,
Expand All @@ -165,12 +165,12 @@ def __(


@app.cell
def __():
def _():
return


@app.cell
def __(get_i, mo, neurons):
def _(get_i, mo, neurons):
def display_info(log10_freq: float, log10_value: float, neuron: int):
return mo.md(
f"Neuron {neuron} ({get_i()}/{len(neurons)}; {get_i() / len(neurons) * 100:.1f}%) | Frequency: {10**log10_freq * 100:.3f}% of inputs | Mean Value: {10**log10_value:.3f}"
Expand All @@ -180,7 +180,7 @@ def display_info(log10_freq: float, log10_value: float, neuron: int):


@app.cell
def __(mo, webapp_dir):
def _(mo, webapp_dir):
def show_img(n: int, i: int):
label = "No label found."
try:
Expand All @@ -195,7 +195,7 @@ def show_img(n: int, i: int):


@app.cell
def __(get_i, mo, neurons, show_img):
def _(get_i, mo, neurons, show_img):
n = neurons[get_i()]["neuron"]

mo.vstack([
Expand Down Expand Up @@ -250,21 +250,21 @@ def __(get_i, mo, neurons, show_img):
widths="equal",
),
])
return (n,)
return


@app.cell
def __(os, torch, webapp_dir):
def _(os, torch, webapp_dir):
sparsity_fpath = os.path.join(webapp_dir, "sparsity.pt")
sparsity = torch.load(sparsity_fpath, weights_only=True, map_location="cpu")

values_fpath = os.path.join(webapp_dir, "mean_values.pt")
values = torch.load(values_fpath, weights_only=True, map_location="cpu")
return sparsity, sparsity_fpath, values, values_fpath
return sparsity, values


@app.cell
def __(mo, np, plt, sparsity):
def _(mo, np, plt, sparsity):
def plot_hist(counts):
fig, ax = plt.subplots()
ax.hist(np.log10(counts.numpy() + 1e-9), bins=100)
Expand All @@ -279,17 +279,19 @@ def plot_hist(counts):


@app.cell
def __(mo, plot_hist, values):
mo.md(f"""
def _(mo, plot_hist, values):
mo.md(
f"""
Mean Value Log10

{mo.as_html(plot_hist(values))}
""")
"""
)
return


@app.cell
def __(np, plt, sparsity, values):
def _(np, plt, sparsity, values):
def plot_dist(
min_log_sparsity: float,
max_log_sparsity: float,
Expand Down Expand Up @@ -341,53 +343,54 @@ def plot_dist(


@app.cell
def __(mo, plot_dist, sparsity_slider, value_slider):
mo.md(f"""
def _(mo, plot_dist, sparsity_slider, value_slider):
mo.md(
f"""
Log Sparsity Range: {sparsity_slider}
{sparsity_slider.value}

Log Value Range: {value_slider}
{value_slider.value}

{mo.as_html(plot_dist(sparsity_slider.value[0], sparsity_slider.value[1], value_slider.value[0], value_slider.value[1]))}
""")
"""
)
return


@app.cell
def __(mo):
def _(mo):
sparsity_slider = mo.ui.range_slider(start=-8, stop=0, step=0.1, value=[-6, -1])
return (sparsity_slider,)


@app.cell
def __(mo):
def _(mo):
value_slider = mo.ui.range_slider(start=-3, stop=1, step=0.1, value=[-0.75, 1.0])
return (value_slider,)


@app.cell
def __():
def _():
return


@app.cell
def __():
def _():
return


@app.cell
def __():
def _():
return


@app.cell
def __():
def _():
return


@app.cell
def __():
def _():
return


Expand Down
Loading