@@ -94,9 +94,14 @@ def determine_available_memory(self) -> int:
94
94
xpu_get_used_global_memory ,
95
95
)
96
96
97
- total_memory = xpu_get_total_global_memory (self .local_rank )
98
- used_memory = xpu_get_used_global_memory (self .local_rank )
99
- free_memory = xpu_get_free_global_memory (self .local_rank )
97
+ assert self .device_ids [self .local_rank ] is not None , f"device_id is none for rank { self .local_rank } "
98
+ assert (
99
+ len (self .device_ids ) > self .local_rank
100
+ ), f"device number must be greater than local rank, but get device number is { len (self .device_ids )} , rank is { self .local_rank } "
101
+
102
+ total_memory = xpu_get_total_global_memory (int (self .device_ids [self .local_rank ]))
103
+ used_memory = xpu_get_used_global_memory (int (self .device_ids [self .local_rank ]))
104
+ free_memory = xpu_get_free_global_memory (int (self .device_ids [self .local_rank ]))
100
105
101
106
logger .info (
102
107
f"Before warm up, total_memory: { total_memory } , \
@@ -107,7 +112,7 @@ def determine_available_memory(self) -> int:
107
112
self .model_runner .profile_run ()
108
113
109
114
total_available_memory = int (total_memory * self .cache_config .gpu_memory_utilization )
110
- used_memory = xpu_get_used_global_memory (self .local_rank )
115
+ used_memory = xpu_get_used_global_memory (int ( self .device_ids [ self . local_rank ]) )
111
116
available_kv_cache_memory = total_available_memory - used_memory
112
117
model_block_memory_used = self .cal_theortical_kvcache ()
113
118
available_kv_cache_memory += model_block_memory_used * self .parallel_config .total_block_num
0 commit comments