Skip to content

Commit fc93245

Browse files
committed
updated loading in ACtivation Patching in TL Demo to use transformer bridge
1 parent 680d4e7 commit fc93245

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

.github/workflows/checks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ jobs:
144144
strategy:
145145
matrix:
146146
notebook:
147-
# - "Activation_Patching_in_TL_Demo"
147+
- "Activation_Patching_in_TL_Demo"
148148
# - "Attribution_Patching_Demo"
149149
- "ARENA_Content"
150150
- "Colab_Compatibility"

demos/Activation_Patching_in_TL_Demo.ipynb

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,13 @@
121121
},
122122
{
123123
"cell_type": "code",
124-
"execution_count": 4,
124+
"execution_count": null,
125125
"metadata": {},
126126
"outputs": [],
127127
"source": [
128128
"import transformer_lens\n",
129129
"import transformer_lens.utils as utils\n",
130-
"from transformer_lens.hook_points import (\n",
131-
" HookedRootModule,\n",
132-
" HookPoint,\n",
133-
") # Hooking utilities\n",
134-
"from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache"
130+
"from transformer_lens.model_bridge import TransformerBridge"
135131
]
136132
},
137133
{
@@ -196,7 +192,7 @@
196192
},
197193
{
198194
"cell_type": "code",
199-
"execution_count": 8,
195+
"execution_count": null,
200196
"metadata": {},
201197
"outputs": [
202198
{
@@ -215,7 +211,7 @@
215211
}
216212
],
217213
"source": [
218-
"model = HookedTransformer.from_pretrained(\"gpt2-small\")"
214+
"model = TransformerBridge.boot_transformers(\"gpt2\")"
219215
]
220216
},
221217
{
@@ -943,7 +939,7 @@
943939
},
944940
{
945941
"cell_type": "code",
946-
"execution_count": 19,
942+
"execution_count": null,
947943
"metadata": {},
948944
"outputs": [
949945
{
@@ -955,7 +951,7 @@
955951
}
956952
],
957953
"source": [
958-
"attn_only = HookedTransformer.from_pretrained(\"attn-only-2l\")\n",
954+
"attn_only = TransformerBridge.boot_transformers(\"attn-only-2l\")\n",
959955
"batch = 4\n",
960956
"seq_len = 20\n",
961957
"rand_tokens_A = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)\n",

0 commit comments

Comments
 (0)