Skip to content

Commit a065061

Browse files
authored
Merge pull request #14 from clwheel/starlink
Load model with CPU if GPU is unavailable
2 parents f81c269 + 93a9d00 commit a065061

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

starlink/starlink-inference.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,12 @@ def load_model():
8080
global model, scaler, label_encoders
8181
try:
8282
model = StarlinkTransformer(input_dim=8)
83-
model.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'starlink_transformer.pth')))
83+
84+
if torch.cuda.is_available():
85+
model.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'starlink_transformer.pth')))
86+
else:
87+
model.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'starlink_transformer.pth'), map_location=torch.device('cpu')))
88+
8489
model.eval()
8590

8691
scaler = joblib.load(os.path.join(MODELS_DIR, 'scaler.joblib'))
@@ -211,4 +216,4 @@ def after_request(response):
211216
print(" - POST /predict")
212217
print("3. Debug: /debug/cors\n")
213218

214-
app.run(host='0.0.0.0', port=35001, debug=False)
219+
app.run(host='0.0.0.0', port=35001, debug=False)

0 commit comments

Comments (0)