From 50e080179f0fe3a119cd655d5a848d27c8cd9371 Mon Sep 17 00:00:00 2001
From: Xing Han Lu <21180505+xhluca@users.noreply.github.com>
Date: Thu, 31 Aug 2023 12:50:48 -0400
Subject: [PATCH] Fix another error in pytorch-ddp-accelerate-transformers.md

---
 pytorch-ddp-accelerate-transformers.md | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/pytorch-ddp-accelerate-transformers.md b/pytorch-ddp-accelerate-transformers.md
index 07ac39f725..a998e2b244 100644
--- a/pytorch-ddp-accelerate-transformers.md
+++ b/pytorch-ddp-accelerate-transformers.md
@@ -41,7 +41,7 @@ class BasicNet(nn.Module):
         self.fc2 = nn.Linear(128, 10)
         self.act = F.relu
 
-    def forward(self, x):
+    def forward(self, x, labels=None):
         x = self.act(self.conv1(x))
         x = self.act(self.conv2(x))
         x = F.max_pool2d(x, 2)
@@ -54,7 +54,9 @@ class BasicNet(nn.Module):
         return output
 ```
 
-We define the training device (`cuda`):
+Note that we specified `labels=None`; this avoids an error when we try to pass a `labels` keyword argument later. However, it is not used here.
+
+Now, we define the training device (`cuda`):
 
 ```python
 device = "cuda"