From f6e61e5bce40e950e391333f24a17f4e28e12083 Mon Sep 17 00:00:00 2001 From: Isaac Wilson Date: Wed, 18 Jun 2025 12:05:07 -0400 Subject: [PATCH 1/2] Add siglip functionality, syntax changes, and portability changes --- src/saev/interactive/metrics.py | 112 ++++++++++++++++---------------- 1 file changed, 56 insertions(+), 56 deletions(-) diff --git a/src/saev/interactive/metrics.py b/src/saev/interactive/metrics.py index 7205d89..5c73267 100644 --- a/src/saev/interactive/metrics.py +++ b/src/saev/interactive/metrics.py @@ -1,11 +1,11 @@ import marimo -__generated_with = "0.9.32" +__generated_with = "0.13.15" app = marimo.App(width="medium") @app.cell -def __(): +def _(): import json import os @@ -22,43 +22,43 @@ def __(): @app.cell -def __(mo): +def _(mo): mo.md( """ - # SAE Metrics Explorer + # SAE Metrics Explorer - This notebook helps you analyze and compare SAE training runs from WandB. + This notebook helps you analyze and compare SAE training runs from WandB. - ## Setup Instructions + ## Setup Instructions - 1. Edit the configuration cell at the top to set your WandB username and project - 2. Make sure you have access to the original ViT activation shards - 3. Use the filters to narrow down which models to compare + 1. Edit the configuration cell at the top to set your WandB username and project + 2. Make sure you have access to the original ViT activation shards + 3. Use the filters to narrow down which models to compare - ## Troubleshooting + ## Troubleshooting - - **Missing data error**: This notebook needs access to the original ViT activation shards - - **No runs found**: Check your WandB username, project name, and tag filter - """ + - **Missing data error**: This notebook needs access to the original ViT activation shards + - **No runs found**: Check your WandB username, project name, and tag filter + """ ) return @app.cell -def __(): - WANDB_USERNAME = "samuelstevens" +def _(): + WANDB_USERNAME = "igwilson99-the-ohio-state-university" WANDB_PROJECT = "saev" return WANDB_PROJECT, WANDB_USERNAME @app.cell -def __(mo): - tag_input = mo.ui.text(value="classification-v1.0", label="Sweep Tag:") +def _(mo): + tag_input = mo.ui.text(value="full_eval", label="Sweep Tag:") return (tag_input,) @app.cell -def __(WANDB_PROJECT, WANDB_USERNAME, mo, tag_input): +def _(WANDB_PROJECT, WANDB_USERNAME, mo, tag_input): mo.vstack([ mo.md( f"Look at [{WANDB_USERNAME}/{WANDB_PROJECT} on WandB](https://wandb.ai/{WANDB_USERNAME}/{WANDB_PROJECT}/table) to pick your tag." @@ -69,23 +69,32 @@ def __(WANDB_PROJECT, WANDB_USERNAME, mo, tag_input): @app.cell -def __(alt, df, mo): +def _(mo): + mo.md(r"""When using the section below, the default is the mse or mean squared error, but if you want the the mean absolute error, the mse needs commented out and summary loss needs commented back in as well as y=alt.Y("summary/mse") needs changed to y=alt.Y("summary/loss")""") + return + + +@app.cell +def _(alt, df, mo): chart = mo.ui.altair_chart( alt.Chart( df.select( "summary/eval/l0", - "summary/losses/mse", + "summary/mse", + # "summary/loss", "id", - "config/sae/sparsity_coeff", + # "config/sae/sparsity_coeff", "config/lr", "config/sae/d_sae", "model_key", ) + # Modified summary/losses/mse to summary/loss to match wandb could also be changed to summary/mse ) .mark_point() .encode( x=alt.X("summary/eval/l0"), - y=alt.Y("summary/losses/mse"), + y=alt.Y("summary/mse"), + # If changing over to mse summary/loss needs changed to summary/mse tooltip=["id", "config/lr"], color="config/lr:Q", # shape="config/sae/sparsity_coeff:N", @@ -98,7 +107,7 @@ def __(alt, df, mo): @app.cell -def __(chart, df, mo, np, plot_dist, plt): +def _(chart, df, mo, np, plot_dist, plt): mo.stop( len(chart.value) < 2, mo.md( @@ -160,36 +169,23 @@ def __(chart, df, mo, np, plot_dist, plt): scatter_fig.tight_layout() hist_fig.tight_layout() - return ( - bins, - freq_hist_ax, - freqs, - hist_axes, - hist_fig, - id, - scatter_ax, - scatter_axes, - scatter_fig, - sub_df, - values, - values_hist_ax, - ) + return hist_fig, scatter_fig @app.cell -def __(scatter_fig): +def _(scatter_fig): scatter_fig return @app.cell -def __(hist_fig): +def _(hist_fig): hist_fig return @app.cell -def __(chart, df, pl): +def _(chart, df, pl): df.join(chart.value.select("id"), on="id", how="inner").sort( by="summary/eval/l0" ).select("id", pl.selectors.starts_with("config/")) @@ -197,7 +193,7 @@ def __(chart, df, pl): @app.cell -def __(Float, beartype, jaxtyped, np): +def _(Float, beartype, jaxtyped, np): @jaxtyped(typechecker=beartype.beartype) def plot_dist( freqs: Float[np.ndarray, " d_sae"], @@ -248,7 +244,9 @@ def plot_dist( @app.cell -def __( +def _( + WANDB_PROJECT, + WANDB_USERNAME, beartype, get_data_key, get_model_key, @@ -291,9 +289,10 @@ def find_metadata(shard_root: str): def make_df(tag: str): filters = {} if tag: - filters["config.tag"] = tag - runs = wandb.Api().runs(path="samuelstevens/saev", filters=filters) - + filters["tags"] = tag + # Changed config.tag to tags + runs = wandb.Api().runs(path=f"{WANDB_USERNAME}/{WANDB_PROJECT}", filters=filters) + # Changed hardcoded username and project to variables for better portability rows = [] for run in mo.status.progress_bar( runs, @@ -303,7 +302,6 @@ def make_df(tag: str): ): row = {} row["id"] = run.id - row.update(**{ f"summary/{key}": value for key, value in run.summary.items() }) @@ -364,35 +362,35 @@ def make_df(tag: str): return df df = make_df(tag_input.value) - return MetadataAccessError, df, find_metadata, make_df + return (df,) @app.cell -def __(beartype): +def _(beartype): @beartype.beartype def get_model_key(metadata: dict[str, object]) -> str: family = next( metadata[key] for key in ("vit_family", "model_family") if key in metadata ) - ckpt = next( metadata[key] for key in ("vit_ckpt", "model_ckpt") if key in metadata ) - if family == "dinov2" and ckpt == "dinov2_vitb14_reg": return "DINOv2 ViT-B/14 (reg)" if family == "clip" and ckpt == "ViT-B-16/openai": return "CLIP ViT-B/16" if family == "clip" and ckpt == "hf-hub:imageomics/bioclip": return "BioCLIP ViT-B/16" - + if family == "siglip" and ckpt == "hf-hub:timm/ViT-B-16-SigLIP2-256": + return "SigLIP2-256 ViT-B/16" print(f"Unknown model: {(family, ckpt)}") return ckpt @beartype.beartype def get_data_key(metadata: dict[str, object]) -> str | None: if ( - "train_mini" in metadata["data"] + "train" in metadata["data"] + # Removed _mini from train_mini and "ImageFolderDataset" in metadata["data"] ): return "iNat21" @@ -407,7 +405,7 @@ def get_data_key(metadata: dict[str, object]) -> str | None: @app.cell -def __(Float, json, np, os): +def _(Float, json, np, os): def load_freqs(run) -> Float[np.ndarray, " d_sae"]: try: for artifact in run.logged_artifacts(): @@ -446,7 +444,7 @@ def load_mean_values(run) -> Float[np.ndarray, " d_sae"]: @app.cell -def __(df): +def _(df): df.drop( "config/log_every", "config/slurm_acct", @@ -454,11 +452,13 @@ def __(df): "config/n_workers", "config/wandb_project", "config/track", - "config/slurm", + "config/slurm_partition", "config/log_to", "config/ckpt_path", - "config/sae/ghost_grads", + "config/sae/remove_parallel_grads", ) + # Changed ghost_grads to remove_parallel_grads as there is no ghost_grad under sae + # Changed slurm to slurm_partition return From 5dce30db7719960f27dfe51adeeb2792572e0166 Mon Sep 17 00:00:00 2001 From: isaac_w_dev Date: Wed, 18 Jun 2025 16:02:05 -0400 Subject: [PATCH 2/2] Debug commands added --- src/saev/__main__.py | 2 +- src/saev/config.py | 8 +-- src/saev/interactive/features.py | 83 +++++++++++++++++--------------- src/saev/visuals.py | 42 ++++++++++++++-- 4 files changed, 85 insertions(+), 50 deletions(-) diff --git a/src/saev/__main__.py b/src/saev/__main__.py index f8cbae6..a9fb7ed 100644 --- a/src/saev/__main__.py +++ b/src/saev/__main__.py @@ -61,7 +61,7 @@ def train( if cfg.slurm_acct: executor = submitit.SlurmExecutor(folder=cfg.log_to) executor.update_parameters( - time=int(cfg.n_hours * 60), + time=int(4500), partition=cfg.slurm_partition, gpus_per_node=1, ntasks_per_node=1, diff --git a/src/saev/config.py b/src/saev/config.py index ea8774b..d61dd79 100644 --- a/src/saev/config.py +++ b/src/saev/config.py @@ -144,10 +144,10 @@ class DataLoad: """Maximum value for activations; activations will be clamped to within [-clamp, clamp]`.""" n_random_samples: int = 2**19 """Number of random samples used to calculate approximate dataset means at startup.""" - scale_mean: bool | str = True - """Whether to subtract approximate dataset means from examples. If a string, manually load from the filepath.""" - scale_norm: bool | str = True - """Whether to scale average dataset norm to sqrt(d_vit). If a string, manually load from the filepath.""" + scale_mean: bool | str = False + """Whether to subtract approximate dataset means from examples. If a string, manually load from the filepath. Changed from True to False""" + scale_norm: bool | str = False + """Whether to scale average dataset norm to sqrt(d_vit). If a string, manually load from the filepath. Changed from True to False""" @beartype.beartype diff --git a/src/saev/interactive/features.py b/src/saev/interactive/features.py index d4bdb0e..632e4f7 100644 --- a/src/saev/interactive/features.py +++ b/src/saev/interactive/features.py @@ -1,11 +1,11 @@ import marimo -__generated_with = "0.9.32" +__generated_with = "0.13.15" app = marimo.App(width="full") @app.cell -def __(): +def _(): import json import os import random @@ -17,15 +17,15 @@ def __(): import torch import tqdm - return json, mo, np, os, pl, plt, random, torch, tqdm + return json, mo, np, os, plt, torch, tqdm @app.cell -def __(mo, os): +def _(mo, os): def make_ckpt_dropdown(): try: choices = sorted( - os.listdir("/research/nfs_su_809/workspace/stevens.994/saev/features") + os.listdir("/users/PZS1151/igwilson99/local/src/saev/checkpoints/") ) except FileNotFoundError: @@ -34,17 +34,17 @@ def make_ckpt_dropdown(): return mo.ui.dropdown(choices, label="Checkpoint:") ckpt_dropdown = make_ckpt_dropdown() - return ckpt_dropdown, make_ckpt_dropdown + return (ckpt_dropdown,) @app.cell -def __(ckpt_dropdown, mo): +def _(ckpt_dropdown, mo): mo.hstack([ckpt_dropdown], justify="start") return @app.cell -def __(ckpt_dropdown, mo): +def _(ckpt_dropdown, mo): mo.stop( ckpt_dropdown.value is None, mo.md( @@ -52,14 +52,14 @@ def __(ckpt_dropdown, mo): ), ) - webapp_dir = f"/research/nfs_su_809/workspace/stevens.994/saev/features/{ckpt_dropdown.value}/sort_by_patch" + webapp_dir = f"/users/PZS1151/igwilson99/local/src/saev/checkpoints/" get_i, set_i = mo.state(0) return get_i, set_i, webapp_dir @app.cell -def __(mo): +def _(mo): sort_by_freq_btn = mo.ui.run_button(label="Sort by frequency") sort_by_value_btn = mo.ui.run_button(label="Sort by value") @@ -69,7 +69,7 @@ def __(mo): @app.cell -def __(mo, sort_by_freq_btn, sort_by_latent_btn, sort_by_value_btn): +def _(mo, sort_by_freq_btn, sort_by_latent_btn, sort_by_value_btn): mo.hstack( [sort_by_freq_btn, sort_by_value_btn, sort_by_latent_btn], justify="start" ) @@ -77,7 +77,7 @@ def __(mo, sort_by_freq_btn, sort_by_latent_btn, sort_by_value_btn): @app.cell -def __( +def _( json, mo, os, @@ -111,11 +111,11 @@ def get_neurons() -> list[dict]: neurons = sorted(neurons, key=lambda dct: dct["log10_value"], reverse=True) mo.md(f"Found {len(neurons)} saved neurons.") - return get_neurons, neurons + return (neurons,) @app.cell -def __(mo, neurons, set_i): +def _(mo, neurons, set_i): next_button = mo.ui.button( label="Next", on_change=lambda _: set_i(lambda v: (v + 1) % len(neurons)), @@ -129,7 +129,7 @@ def __(mo, neurons, set_i): @app.cell -def __(get_i, mo, neurons, set_i): +def _(get_i, mo, neurons, set_i): neuron_slider = mo.ui.slider( 0, len(neurons), @@ -141,12 +141,12 @@ def __(get_i, mo, neurons, set_i): @app.cell -def __(): +def _(): return @app.cell -def __( +def _( display_info, get_i, mo, @@ -165,12 +165,12 @@ def __( @app.cell -def __(): +def _(): return @app.cell -def __(get_i, mo, neurons): +def _(get_i, mo, neurons): def display_info(log10_freq: float, log10_value: float, neuron: int): return mo.md( f"Neuron {neuron} ({get_i()}/{len(neurons)}; {get_i() / len(neurons) * 100:.1f}%) | Frequency: {10**log10_freq * 100:.3f}% of inputs | Mean Value: {10**log10_value:.3f}" @@ -180,7 +180,7 @@ def display_info(log10_freq: float, log10_value: float, neuron: int): @app.cell -def __(mo, webapp_dir): +def _(mo, webapp_dir): def show_img(n: int, i: int): label = "No label found." try: @@ -195,7 +195,7 @@ def show_img(n: int, i: int): @app.cell -def __(get_i, mo, neurons, show_img): +def _(get_i, mo, neurons, show_img): n = neurons[get_i()]["neuron"] mo.vstack([ @@ -250,21 +250,21 @@ def __(get_i, mo, neurons, show_img): widths="equal", ), ]) - return (n,) + return @app.cell -def __(os, torch, webapp_dir): +def _(os, torch, webapp_dir): sparsity_fpath = os.path.join(webapp_dir, "sparsity.pt") sparsity = torch.load(sparsity_fpath, weights_only=True, map_location="cpu") values_fpath = os.path.join(webapp_dir, "mean_values.pt") values = torch.load(values_fpath, weights_only=True, map_location="cpu") - return sparsity, sparsity_fpath, values, values_fpath + return sparsity, values @app.cell -def __(mo, np, plt, sparsity): +def _(mo, np, plt, sparsity): def plot_hist(counts): fig, ax = plt.subplots() ax.hist(np.log10(counts.numpy() + 1e-9), bins=100) @@ -279,17 +279,19 @@ def plot_hist(counts): @app.cell -def __(mo, plot_hist, values): - mo.md(f""" +def _(mo, plot_hist, values): + mo.md( + f""" Mean Value Log10 {mo.as_html(plot_hist(values))} - """) + """ + ) return @app.cell -def __(np, plt, sparsity, values): +def _(np, plt, sparsity, values): def plot_dist( min_log_sparsity: float, max_log_sparsity: float, @@ -341,53 +343,54 @@ def plot_dist( @app.cell -def __(mo, plot_dist, sparsity_slider, value_slider): - mo.md(f""" +def _(mo, plot_dist, sparsity_slider, value_slider): + mo.md( + f""" Log Sparsity Range: {sparsity_slider} {sparsity_slider.value} Log Value Range: {value_slider} {value_slider.value} - {mo.as_html(plot_dist(sparsity_slider.value[0], sparsity_slider.value[1], value_slider.value[0], value_slider.value[1]))} - """) + """ + ) return @app.cell -def __(mo): +def _(mo): sparsity_slider = mo.ui.range_slider(start=-8, stop=0, step=0.1, value=[-6, -1]) return (sparsity_slider,) @app.cell -def __(mo): +def _(mo): value_slider = mo.ui.range_slider(start=-3, stop=1, step=0.1, value=[-0.75, 1.0]) return (value_slider,) @app.cell -def __(): +def _(): return @app.cell -def __(): +def _(): return @app.cell -def __(): +def _(): return @app.cell -def __(): +def _(): return @app.cell -def __(): +def _(): return diff --git a/src/saev/visuals.py b/src/saev/visuals.py index c2a11dc..7ec4680 100644 --- a/src/saev/visuals.py +++ b/src/saev/visuals.py @@ -263,28 +263,49 @@ def get_topk_patch(cfg: config.Visuals) -> TopKPatch: Returns: A tuple of TopKPatch and m randomly sampled activation distributions. """ + print("DEBUG: Starting get_topk_patch") assert cfg.sort_by == "patch" assert cfg.data.patches == "patches" + print("DEBUG: Loading SAE and dataset") sae = nn.load(cfg.ckpt).to(cfg.device) - dataset = activations.Dataset(cfg.data) + print("DEBUG: SAE loaded successfully") + print("DEBUG: Loading dataset...") + print(f"DEBUG: cfg.data = {cfg.data}") + print(f"DEBUG: cfg.data.shard_root = {cfg.data.shard_root}") + print(f"DEBUG: cfg.data.layer = {cfg.data.layer}") + try: + dataset = activations.Dataset(cfg.data) + print(f"DEBUG: SAE loaded, dataset size: {len(dataset)}") + except Exception as e: + print(f"DEBUG: Dataset loading failed with error: {e}") + print(f"DEBUG: Error type: {type(e)}") + import traceback + traceback.print_exc() + raise + top_values_p = torch.full( (sae.cfg.d_sae, cfg.top_k, dataset.metadata.n_patches_per_img), -1.0, device=cfg.device, ) + print("Top values p:", top_values_p) top_i_im = torch.zeros( (sae.cfg.d_sae, cfg.top_k), dtype=torch.int, device=cfg.device ) - + print("Top i im: ", top_i_im) sparsity_S = torch.zeros((sae.cfg.d_sae,), device=cfg.device) + print("Sparsity S:", sparsity_S) mean_values_S = torch.zeros((sae.cfg.d_sae,), device=cfg.device) - + print("Mean Values S:", mean_values_S) distributions_MN = torch.zeros((cfg.n_distributions, len(dataset)), device="cpu") + # cfg.n_distributions + print("Distributions MN", distributions_MN) estimator = PercentileEstimator( cfg.percentile, len(dataset), shape=(sae.cfg.d_sae,) ) + print("Estimator: ", estimator) batch_size = ( cfg.topk_batch_size @@ -306,12 +327,16 @@ def get_topk_patch(cfg: config.Visuals) -> TopKPatch: for batch in helpers.progress(dataloader, desc="picking top-k"): vit_acts_BD = batch["act"] + print("Vit acts BD", vit_acts_BD) sae_acts_BS = get_sae_acts(vit_acts_BD, sae, cfg) - + print("SAE acts BS:", sae_acts_BS) for sae_act_S in sae_acts_BS: estimator.update(sae_act_S) + print("Estimator Update:", estimator) sae_acts_SB = einops.rearrange(sae_acts_BS, "batch d_sae -> d_sae batch") + print("SAE acts SB:", sae_acts_SB) + distributions_MN[:, batch["image_i"]] = sae_acts_SB[: cfg.n_distributions].to( "cpu" ) @@ -366,22 +391,29 @@ def dump_activations(cfg: config.Visuals): Returns: None. All data is saved to disk. """ + print(f"DEBUG: cfg.dump_to = '{cfg.dump_to}'") + print(f"DEBUG: cfg.sort_by = '{cfg.sort_by}'") + print(f"DEBUG: cfg.root = '{cfg.root}'") if cfg.sort_by == "img": topk = get_topk_img(cfg) elif cfg.sort_by == "patch": + print("DEBUG: About to call get_topk_patch") topk = get_topk_patch(cfg) + print("DEBUG: get_topk_patch completed") else: typing.assert_never(cfg.sort_by) + print(f"DEBUG: Creating directory: {cfg.root}") os.makedirs(cfg.root, exist_ok=True) + print("DEBUG: Saving files...") torch.save(topk.top_values, cfg.top_values_fpath) torch.save(topk.top_i, cfg.top_img_i_fpath) torch.save(topk.mean_values, cfg.mean_values_fpath) torch.save(topk.sparsity, cfg.sparsity_fpath) torch.save(topk.distributions, cfg.distributions_fpath) torch.save(topk.percentiles, cfg.percentiles_fpath) - + print("DEBUG: All files saved successfully!") @jaxtyped(typechecker=beartype.beartype) def plot_activation_distributions(