Skip to content

Commit 560e820

Browse files
committed
Fix image classification module error
1 parent 71defaf commit 560e820

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

paddlehub/module/cv_module.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def validation_step(self, batch: int, batch_idx: int) -> dict:
7070
'''
7171
images = batch[0]
7272
labels = paddle.unsqueeze(batch[1], axis=-1)
73+
labels = labels.astype('int64')
7374

7475
preds, feature = self(images)
7576

@@ -104,7 +105,7 @@ def predict(self, images: List[np.ndarray], batch_size: int = 1, top_k: int = 1)
104105
batch_data.append(image)
105106
except:
106107
pass
107-
batch_image = np.array(batch_data)
108+
batch_image = np.array(batch_data, dtype='float32')
108109
preds, feature = self(paddle.to_tensor(batch_image))
109110
preds = F.softmax(preds, axis=1).numpy()
110111
pred_idxs = np.argsort(preds)[:, ::-1][:, :top_k]

0 commit comments

Comments
 (0)