26 changes: 19 additions & 7 deletions notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb
@@ -93,7 +93,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -129,7 +129,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -179,7 +179,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -296,6 +296,16 @@
"unet_kwargs[\"encoder_hidden_states\"] = torch.ones((2, 154, 4096))\n",
"unet_kwargs[\"pooled_projections\"] = torch.ones((2, 2048))\n",
"\n",
"# Feature map height and width are dynamic\n",
"fm_height = torch.export.Dim(\"fm_height\", min=16, max=256)\n",
"fm_width = torch.export.Dim(\"fm_width\", min=16, max=256)\n",
"dim = torch.export.Dim(\"dim\", min=1, max=16)\n",
"fm_height = 16 * dim\n",
"fm_width = 16 * dim\n",
"\n",
"dynamic_shapes = {\"sample\": {2: fm_height, 3: fm_width}}\n",
"# iterate through the unet kwargs and set only hidden state kwarg to dynamic\n",
"dynamic_shapes_transformer = {key: (None if key != \"hidden_states\" else {2: fm_height, 3: fm_width}) for key in unet_kwargs.keys()}\n",
"\n",
"with torch.no_grad():\n",
" with disable_patching():\n",
" text_encoder = torch.export.export_for_training(\n",
@@ -308,10 +318,12 @@
" args=(text_encoder_input,),\n",
" kwargs=(text_encoder_kwargs),\n",
" ).module()\n",
" pipe.vae.decoder = torch.export.export_for_training(pipe.vae.decoder.eval(), args=(vae_decoder_input,)).module()\n",
" pipe.vae.encoder = torch.export.export_for_training(pipe.vae.encoder.eval(), args=(vae_encoder_input,)).module()\n",
" pipe.vae.decoder = torch.export.export_for_training(pipe.vae.decoder.eval(), args=(vae_decoder_input,), dynamic_shapes=dynamic_shapes).module()\n",
" pipe.vae.encoder = torch.export.export_for_training(pipe.vae.encoder.eval(), args=(vae_encoder_input,), dynamic_shapes=dynamic_shapes).module()\n",
" vae = pipe.vae\n",
" transformer = torch.export.export_for_training(pipe.transformer.eval(), args=(), kwargs=(unet_kwargs)).module()\n",
" transformer = torch.export.export_for_training(\n",
" pipe.transformer.eval(), args=(), kwargs=(unet_kwargs), dynamic_shapes=dynamic_shapes_transformer\n",
" ).module()\n",
"models_dict = {}\n",
"models_dict[\"transformer\"] = transformer\n",
"models_dict[\"vae\"] = vae\n",
@@ -766,7 +778,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
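Note on the pattern above: here is a minimal, self-contained sketch of how a derived torch.export.Dim combines with a per-kwarg dynamic_shapes mapping (where None keeps an input fully static). It assumes PyTorch 2.5+ for torch.export.export_for_training; the Block module and its tensor shapes are illustrative stand-ins, not code from the notebook.

import torch


# Illustrative stand-in for the SD3 transformer: one image-like kwarg whose
# spatial dims should be dynamic, plus a second kwarg that stays static.
class Block(torch.nn.Module):
    def forward(self, hidden_states, pooled_projections):
        return hidden_states * pooled_projections.mean()


# One symbolic dim; height and width are both derived as 16 * dim, so the
# exported graph accepts any square spatial size in [16, 256] divisible by 16.
dim = torch.export.Dim("dim", min=1, max=16)
fm_height = 16 * dim
fm_width = 16 * dim

kwargs = {
    "hidden_states": torch.ones((2, 16, 64, 64)),
    "pooled_projections": torch.ones((2, 2048)),
}
# None marks an input as fully static; only hidden_states gets dynamic H/W.
dynamic_shapes = {key: (None if key != "hidden_states" else {2: fm_height, 3: fm_width}) for key in kwargs}

with torch.no_grad():
    exported = torch.export.export_for_training(
        Block().eval(), args=(), kwargs=kwargs, dynamic_shapes=dynamic_shapes
    ).module()

# The same exported module now runs at other resolutions within the bounds.
out = exported(hidden_states=torch.ones((2, 16, 128, 128)), pooled_projections=torch.ones((2, 2048)))
print(out.shape)  # torch.Size([2, 16, 128, 128]); 128 = 16 * 8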