File tree Expand file tree Collapse file tree 1 file changed +7
-2
lines changed Expand file tree Collapse file tree 1 file changed +7
-2
lines changed Original file line number Diff line number Diff line change @@ -80,7 +80,12 @@ def load_model():
80
80
global model , scaler , label_encoders
81
81
try :
82
82
model = StarlinkTransformer (input_dim = 8 )
83
- model .load_state_dict (torch .load (os .path .join (MODELS_DIR , 'starlink_transformer.pth' )))
83
+
84
+ if torch .cuda .is_available ():
85
+ model .load_state_dict (torch .load (os .path .join (MODELS_DIR , 'starlink_transformer.pth' )))
86
+ else :
87
+ model .load_state_dict (torch .load (os .path .join (MODELS_DIR , 'starlink_transformer.pth' ), map_location = torch .device ('cpu' )))
88
+
84
89
model .eval ()
85
90
86
91
scaler = joblib .load (os .path .join (MODELS_DIR , 'scaler.joblib' ))
@@ -211,4 +216,4 @@ def after_request(response):
211
216
print (" - POST /predict" )
212
217
print ("3. Debug: /debug/cors\n " )
213
218
214
- app .run (host = '0.0.0.0' , port = 35001 , debug = False )
219
+ app .run (host = '0.0.0.0' , port = 35001 , debug = False )
You can’t perform that action at this time.
0 commit comments