From 6af796815ab87c4a3d039b31efbc7c8e4849a412 Mon Sep 17 00:00:00 2001 From: ThanatosShinji Date: Fri, 29 Aug 2025 12:25:05 +0800 Subject: [PATCH 1/8] add xpu --- basicsr/data/prefetch_dataloader.py | 40 +++++++++++++++++++++++++ basicsr/models/base_model.py | 23 ++++++++++++-- basicsr/models/hifacegan_model.py | 2 +- basicsr/models/realesrgan_model.py | 8 ++--- basicsr/models/realesrnet_model.py | 8 ++--- basicsr/models/sr_model.py | 3 +- basicsr/models/video_base_model.py | 4 +-- basicsr/models/video_recurrent_model.py | 2 +- basicsr/train.py | 19 +++++++++--- 9 files changed, 88 insertions(+), 21 deletions(-) diff --git a/basicsr/data/prefetch_dataloader.py b/basicsr/data/prefetch_dataloader.py index 332abd32f..93e3255c4 100644 --- a/basicsr/data/prefetch_dataloader.py +++ b/basicsr/data/prefetch_dataloader.py @@ -120,3 +120,43 @@ def next(self): def reset(self): self.loader = iter(self.ori_loader) self.preload() + +class XPUPrefetcher(): + """XPU prefetcher. + + It may consume more GPU memory. + + Args: + loader: Dataloader. + opt (dict): Options. + """ + + def __init__(self, loader, opt): + self.ori_loader = loader + self.loader = iter(loader) + self.opt = opt + self.stream = torch.xpu.Stream() + self.device = torch.device('xpu' if opt['num_gpu'] != 0 else 'cpu') + self.preload() + + def preload(self): + try: + self.batch = next(self.loader) # self.batch is a dict + except StopIteration: + self.batch = None + return None + # put tensors to gpu + with torch.cuda.stream(self.stream): + for k, v in self.batch.items(): + if torch.is_tensor(v): + self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) + + def next(self): + torch.xpu.current_stream().wait_stream(self.stream) + batch = self.batch + self.preload() + return batch + + def reset(self): + self.loader = iter(self.ori_loader) + self.preload() diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py index fbf8229f5..710fbb920 100644 --- a/basicsr/models/base_model.py +++ b/basicsr/models/base_model.py @@ -15,11 +15,23 @@ class BaseModel(): def __init__(self, opt): self.opt = opt - self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.device = 'cpu' + self.dtype = torch.float32 + if opt['num_gpu'] != 0: + if torch.cuda.is_available(): + self.device = 'cuda' + if torch.xpu.is_available(): + self.device = 'xpu' self.is_train = opt['is_train'] self.schedulers = [] self.optimizers = [] + def empty_cache(self): + if self.device == 'cuda': + torch.cuda.empty_cache() + elif self.device == 'xpu': + torch.xpu.empty_cache() + def feed_data(self, data): pass @@ -91,11 +103,16 @@ def model_to_device(self, net): Args: net (nn.Module) """ - net = net.to(self.device) + net = net.to(self.device, dtype = self.dtype) if self.opt['dist']: find_unused_parameters = self.opt.get('find_unused_parameters', False) + ids = [0] + if self.device == 'cuda': + ids = [torch.cuda.current_device()] + if self.device == 'xpu': + ids = [torch.xpu.current_device()] net = DistributedDataParallel( - net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters) + net, device_ids=ids, find_unused_parameters=find_unused_parameters) elif self.opt['num_gpu'] > 1: net = DataParallel(net) return net diff --git a/basicsr/models/hifacegan_model.py b/basicsr/models/hifacegan_model.py index 435a2b179..591508ab4 100644 --- a/basicsr/models/hifacegan_model.py +++ b/basicsr/models/hifacegan_model.py @@ -249,7 +249,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): # tentative for out of GPU memory del self.lq del self.output - torch.cuda.empty_cache() + self.empty_cache() if save_img: if self.opt['is_train']: diff --git a/basicsr/models/realesrgan_model.py b/basicsr/models/realesrgan_model.py index c74b28fb1..0eb072332 100644 --- a/basicsr/models/realesrgan_model.py +++ b/basicsr/models/realesrgan_model.py @@ -24,8 +24,8 @@ class RealESRGANModel(SRGANModel): def __init__(self, opt): super(RealESRGANModel, self).__init__(opt) - self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts - self.usm_sharpener = USMSharp().cuda() # do usm sharpening + self.jpeger = DiffJPEG(differentiable=False).to(self.device) # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().to(self.device) # do usm sharpening self.queue_size = opt.get('queue_size', 180) @torch.no_grad() @@ -40,9 +40,9 @@ def _dequeue_and_enqueue(self): b, c, h, w = self.lq.size() if not hasattr(self, 'queue_lr'): assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' - self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_lr = torch.zeros(self.queue_size, c, h, w).to(self.device) _, c, h, w = self.gt.size() - self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).to(self.device) self.queue_ptr = 0 if self.queue_ptr == self.queue_size: # the pool is full # do dequeue and enqueue diff --git a/basicsr/models/realesrnet_model.py b/basicsr/models/realesrnet_model.py index f5790918b..1d51f6da3 100644 --- a/basicsr/models/realesrnet_model.py +++ b/basicsr/models/realesrnet_model.py @@ -23,8 +23,8 @@ class RealESRNetModel(SRModel): def __init__(self, opt): super(RealESRNetModel, self).__init__(opt) - self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts - self.usm_sharpener = USMSharp().cuda() # do usm sharpening + self.jpeger = DiffJPEG(differentiable=False).to(self.device) # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().to(self.device) # do usm sharpening self.queue_size = opt.get('queue_size', 180) @torch.no_grad() @@ -39,9 +39,9 @@ def _dequeue_and_enqueue(self): b, c, h, w = self.lq.size() if not hasattr(self, 'queue_lr'): assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' - self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_lr = torch.zeros(self.queue_size, c, h, w).to(self.device) _, c, h, w = self.gt.size() - self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).to(self.device) self.queue_ptr = 0 if self.queue_ptr == self.queue_size: # the pool is full # do dequeue and enqueue diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py index 787f1fd2e..6701b41e3 100644 --- a/basicsr/models/sr_model.py +++ b/basicsr/models/sr_model.py @@ -215,8 +215,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): # tentative for out of GPU memory del self.lq del self.output - torch.cuda.empty_cache() - + self.empty_cache() if save_img: if self.opt['is_train']: save_img_path = osp.join(self.opt['path']['visualization'], img_name, diff --git a/basicsr/models/video_base_model.py b/basicsr/models/video_base_model.py index 9f7993a15..5a8146765 100644 --- a/basicsr/models/video_base_model.py +++ b/basicsr/models/video_base_model.py @@ -30,7 +30,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img): num_frame_each_folder = Counter(dataset.data_info['folder']) for folder, num_frame in num_frame_each_folder.items(): self.metric_results[folder] = torch.zeros( - num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda') + num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device=self.device) # initialize the best metric results self._initialize_best_metric_results(dataset_name) # zero self.metric_results @@ -64,7 +64,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img): # tentative for out of GPU memory del self.lq del self.output - torch.cuda.empty_cache() + self.empty_cache() if save_img: if self.opt['is_train']: diff --git a/basicsr/models/video_recurrent_model.py b/basicsr/models/video_recurrent_model.py index 796ee57d5..d9aa0a241 100644 --- a/basicsr/models/video_recurrent_model.py +++ b/basicsr/models/video_recurrent_model.py @@ -78,7 +78,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img): num_frame_each_folder = Counter(dataset.data_info['folder']) for folder, num_frame in num_frame_each_folder.items(): self.metric_results[folder] = torch.zeros( - num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda') + num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device=self.device) # initialize the best metric results self._initialize_best_metric_results(dataset_name) # zero self.metric_results diff --git a/basicsr/train.py b/basicsr/train.py index e02d98fe0..7b1ad78e5 100644 --- a/basicsr/train.py +++ b/basicsr/train.py @@ -7,7 +7,7 @@ from basicsr.data import build_dataloader, build_dataset from basicsr.data.data_sampler import EnlargedSampler -from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher +from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher, XPUPrefetcher from basicsr.models import build_model from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str, init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir) @@ -82,8 +82,14 @@ def load_resume_state(opt): if resume_state_path is None: resume_state = None else: - device_id = torch.cuda.current_device() - resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id)) + map_location = None + if torch.cuda.is_available(): + device_id = torch.cuda.current_device() + map_location = lambda storage, loc: storage.cuda(device_id) + if torch.xpu.is_available(): + device_id = torch.xpu.current_device() + map_location = lambda storage, loc: storage.xpu(device_id) + resume_state = torch.load(resume_state_path, map_location=map_location) check_resume(opt, resume_state['iter']) return resume_state @@ -138,13 +144,18 @@ def train_pipeline(root_path): prefetch_mode = opt['datasets']['train'].get('prefetch_mode') if prefetch_mode is None or prefetch_mode == 'cpu': prefetcher = CPUPrefetcher(train_loader) + elif prefetch_mode == 'xpu': + prefetcher = XPUPrefetcher(train_loader, opt) + logger.info(f'Use {prefetch_mode} prefetch dataloader') + if opt['datasets']['train'].get('pin_memory') is not True: + raise ValueError('Please set pin_memory=True for XPUPrefetcher.') elif prefetch_mode == 'cuda': prefetcher = CUDAPrefetcher(train_loader, opt) logger.info(f'Use {prefetch_mode} prefetch dataloader') if opt['datasets']['train'].get('pin_memory') is not True: raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') else: - raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.") + raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'xpu', 'cuda', 'cpu'.") # training logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}') From dec31c70cd5b4559eae39ca19d1396e4f3103ce1 Mon Sep 17 00:00:00 2001 From: ThanatosShinji Date: Fri, 29 Aug 2025 12:50:24 +0800 Subject: [PATCH 2/8] fix num_gpu --- basicsr/utils/img_process_util.py | 4 ++-- basicsr/utils/options.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/basicsr/utils/img_process_util.py b/basicsr/utils/img_process_util.py index 52e02f099..fbc26fc7c 100644 --- a/basicsr/utils/img_process_util.py +++ b/basicsr/utils/img_process_util.py @@ -22,11 +22,11 @@ def filter2D(img, kernel): if kernel.size(0) == 1: # apply the same kernel to all batch images - img = img.view(b * c, 1, ph, pw) + img = img.reshape(b * c, 1, ph, pw) kernel = kernel.view(1, 1, k, k) return F.conv2d(img, kernel, padding=0).view(b, c, h, w) else: - img = img.view(1, b * c, ph, pw) + img = img.reshape(1, b * c, ph, pw) kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py index 5c7155ecc..45ec94276 100644 --- a/basicsr/utils/options.py +++ b/basicsr/utils/options.py @@ -151,7 +151,10 @@ def parse_options(root_path, is_train=True): opt['name'] = 'debug_' + opt['name'] if opt['num_gpu'] == 'auto': - opt['num_gpu'] = torch.cuda.device_count() + if torch.cuda.is_available(): + opt['num_gpu'] = torch.cuda.device_count() + if torch.xpu.is_available(): + opt['num_gpu'] = torch.xpu.device_count() # datasets for phase, dataset in opt['datasets'].items(): From 67f1507ffc5c140d77ec04800f0bd2d70edfc07e Mon Sep 17 00:00:00 2001 From: ThanatosShinji Date: Fri, 29 Aug 2025 13:07:10 +0800 Subject: [PATCH 3/8] fix fp64 not support for poisson --- basicsr/data/degradations.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/basicsr/data/degradations.py b/basicsr/data/degradations.py index 14319605d..0b4c327b6 100644 --- a/basicsr/data/degradations.py +++ b/basicsr/data/degradations.py @@ -634,7 +634,11 @@ def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0): vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)] vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1) - out = torch.poisson(img_gray * vals) / vals + device = img_gray.device + if device.type == 'xpu': + out = torch.poisson((img_gray * vals).to('cpu')).to(device) / vals + else: + out = torch.poisson(img * vals) / vals noise_gray = out - img_gray noise_gray = noise_gray.expand(b, 3, h, w) @@ -645,7 +649,11 @@ def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0): vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)] vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] vals = img.new_tensor(vals_list).view(b, 1, 1, 1) - out = torch.poisson(img * vals) / vals + device = img_gray.device + if device.type == 'xpu': + out = torch.poisson((img_gray * vals).to('cpu')).to(device) / vals + else: + out = torch.poisson(img * vals) / vals noise = out - img if cal_gray_noise: noise = noise * (1 - gray_noise) + noise_gray * gray_noise From 0a4fe37fd1d96d65ca4e836f09037eb3e46f4fa6 Mon Sep 17 00:00:00 2001 From: ThanatosShinji Date: Sat, 30 Aug 2025 18:02:51 +0800 Subject: [PATCH 4/8] add dtype option --- basicsr/data/degradations.py | 4 ++-- basicsr/models/base_model.py | 2 +- basicsr/models/realesrgan_model.py | 2 ++ basicsr/models/realesrnet_model.py | 2 ++ basicsr/models/sr_model.py | 12 ++++++------ basicsr/utils/options.py | 6 ++++++ 6 files changed, 19 insertions(+), 9 deletions(-) diff --git a/basicsr/data/degradations.py b/basicsr/data/degradations.py index 0b4c327b6..6ca358ae1 100644 --- a/basicsr/data/degradations.py +++ b/basicsr/data/degradations.py @@ -649,9 +649,9 @@ def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0): vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)] vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] vals = img.new_tensor(vals_list).view(b, 1, 1, 1) - device = img_gray.device + device = img.device if device.type == 'xpu': - out = torch.poisson((img_gray * vals).to('cpu')).to(device) / vals + out = torch.poisson((img * vals).to('cpu')).to(device) / vals else: out = torch.poisson(img * vals) / vals noise = out - img diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py index 710fbb920..873dbc845 100644 --- a/basicsr/models/base_model.py +++ b/basicsr/models/base_model.py @@ -16,7 +16,7 @@ class BaseModel(): def __init__(self, opt): self.opt = opt self.device = 'cpu' - self.dtype = torch.float32 + self.dtype = opt['dtype'] if opt['num_gpu'] != 0: if torch.cuda.is_available(): self.device = 'cuda' diff --git a/basicsr/models/realesrgan_model.py b/basicsr/models/realesrgan_model.py index 0eb072332..7b80251f1 100644 --- a/basicsr/models/realesrgan_model.py +++ b/basicsr/models/realesrgan_model.py @@ -183,6 +183,8 @@ def feed_data(self, data): if 'gt' in data: self.gt = data['gt'].to(self.device) self.gt_usm = self.usm_sharpener(self.gt) + self.lq = self.lq.to(self.dtype) + self.gt = self.gt.to(self.dtype) def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): # do not use the synthetic process during validation diff --git a/basicsr/models/realesrnet_model.py b/basicsr/models/realesrnet_model.py index 1d51f6da3..e1ce888d4 100644 --- a/basicsr/models/realesrnet_model.py +++ b/basicsr/models/realesrnet_model.py @@ -181,6 +181,8 @@ def feed_data(self, data): if 'gt' in data: self.gt = data['gt'].to(self.device) self.gt_usm = self.usm_sharpener(self.gt) + self.lq = self.lq.to(self.dtype) + self.gt = self.gt.to(self.dtype) def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): # do not use the synthetic process during validation diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py index 6701b41e3..94dc6292d 100644 --- a/basicsr/models/sr_model.py +++ b/basicsr/models/sr_model.py @@ -43,7 +43,7 @@ def init_training_settings(self): # define network net_g with Exponential Moving Average (EMA) # net_g_ema is used only for testing on one GPU and saving # There is no need to wrap with DistributedDataParallel - self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + self.net_g_ema = build_network(self.opt['network_g']).to(self.device, dtype=self.dtype) # load pretrained model load_path = self.opt['path'].get('pretrain_network_g', None) if load_path is not None: @@ -54,12 +54,12 @@ def init_training_settings(self): # define losses if train_opt.get('pixel_opt'): - self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device, dtype=self.dtype) else: self.cri_pix = None if train_opt.get('perceptual_opt'): - self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device, dtype=self.dtype) else: self.cri_perceptual = None @@ -85,9 +85,9 @@ def setup_optimizers(self): self.optimizers.append(self.optimizer_g) def feed_data(self, data): - self.lq = data['lq'].to(self.device) + self.lq = data['lq'].to(self.device, dtype=self.dtype) if 'gt' in data: - self.gt = data['gt'].to(self.device) + self.gt = data['gt'].to(self.device, dtype=self.dtype) def optimize_parameters(self, current_iter): self.optimizer_g.zero_grad() @@ -144,7 +144,7 @@ def _transform(v, op): elif op == 't': tfnp = v2np.transpose((0, 1, 3, 2)).copy() - ret = torch.Tensor(tfnp).to(self.device) + ret = torch.Tensor(tfnp).to(self.device, dtype=self.dtype) # if self.precision == 'half': ret = ret.half() return ret diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py index 45ec94276..013192d6c 100644 --- a/basicsr/utils/options.py +++ b/basicsr/utils/options.py @@ -146,6 +146,12 @@ def parse_options(root_path, is_train=True): opt['auto_resume'] = args.auto_resume opt['is_train'] = is_train + opt['dtype'] = opt.get('dtype', 'float32') + if opt['dtype'] == 'float32': + opt['dtype'] = torch.float32 + elif opt['dtype'] == 'bfloat16': + opt['dtype'] = torch.bfloat16 + # debug setting if args.debug and not opt['name'].startswith('debug'): opt['name'] = 'debug_' + opt['name'] From 95313da5894cf0072e32e1cd3344d7b6af503f31 Mon Sep 17 00:00:00 2001 From: ThanatosShinji Date: Sat, 30 Aug 2025 18:15:04 +0800 Subject: [PATCH 5/8] fix xpu map_location --- basicsr/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basicsr/train.py b/basicsr/train.py index 7b1ad78e5..76376063e 100644 --- a/basicsr/train.py +++ b/basicsr/train.py @@ -88,7 +88,7 @@ def load_resume_state(opt): map_location = lambda storage, loc: storage.cuda(device_id) if torch.xpu.is_available(): device_id = torch.xpu.current_device() - map_location = lambda storage, loc: storage.xpu(device_id) + map_location = f'xpu:{device_id}' resume_state = torch.load(resume_state_path, map_location=map_location) check_resume(opt, resume_state['iter']) return resume_state From c20d9be52b0c61627795862890ccd0a2c88dfb7f Mon Sep 17 00:00:00 2001 From: ThanatosShinji Date: Sat, 30 Aug 2025 20:27:12 +0800 Subject: [PATCH 6/8] save float tensor on CPU --- basicsr/models/sr_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py index 94dc6292d..f6333d487 100644 --- a/basicsr/models/sr_model.py +++ b/basicsr/models/sr_model.py @@ -264,10 +264,10 @@ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): def get_current_visuals(self): out_dict = OrderedDict() - out_dict['lq'] = self.lq.detach().cpu() - out_dict['result'] = self.output.detach().cpu() + out_dict['lq'] = self.lq.detach().cpu().float() + out_dict['result'] = self.output.detach().cpu().float() if hasattr(self, 'gt'): - out_dict['gt'] = self.gt.detach().cpu() + out_dict['gt'] = self.gt.detach().cpu().float() return out_dict def save(self, epoch, current_iter): From a089547bc00436500fa9a63f53b8b3a87c42eb53 Mon Sep 17 00:00:00 2001 From: ThanatosShinji Date: Sat, 30 Aug 2025 20:53:15 +0800 Subject: [PATCH 7/8] support bfloat16 for srgan model --- basicsr/models/realesrgan_model.py | 1 + basicsr/models/srgan_model.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/basicsr/models/realesrgan_model.py b/basicsr/models/realesrgan_model.py index 7b80251f1..54f7528ab 100644 --- a/basicsr/models/realesrgan_model.py +++ b/basicsr/models/realesrgan_model.py @@ -185,6 +185,7 @@ def feed_data(self, data): self.gt_usm = self.usm_sharpener(self.gt) self.lq = self.lq.to(self.dtype) self.gt = self.gt.to(self.dtype) + self.gt_usm = self.gt_usm.to(self.dtype) def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): # do not use the synthetic process during validation diff --git a/basicsr/models/srgan_model.py b/basicsr/models/srgan_model.py index 45387ca79..31aff667c 100644 --- a/basicsr/models/srgan_model.py +++ b/basicsr/models/srgan_model.py @@ -22,7 +22,7 @@ def init_training_settings(self): # define network net_g with Exponential Moving Average (EMA) # net_g_ema is used only for testing on one GPU and saving # There is no need to wrap with DistributedDataParallel - self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + self.net_g_ema = build_network(self.opt['network_g']).to(self.device, dtype=self.dtype) # load pretrained model load_path = self.opt['path'].get('pretrain_network_g', None) if load_path is not None: @@ -47,22 +47,22 @@ def init_training_settings(self): # define losses if train_opt.get('pixel_opt'): - self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device, dtype=self.dtype) else: self.cri_pix = None if train_opt.get('ldl_opt'): - self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device) + self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device, dtype=self.dtype) else: self.cri_ldl = None if train_opt.get('perceptual_opt'): - self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device, dtype=self.dtype) else: self.cri_perceptual = None if train_opt.get('gan_opt'): - self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device, dtype=self.dtype) self.net_d_iters = train_opt.get('net_d_iters', 1) self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) From d00d17665c15516790346a4075adae4d6a9a60cb Mon Sep 17 00:00:00 2001 From: "Luo, Yu" Date: Wed, 29 Oct 2025 12:08:17 +0000 Subject: [PATCH 8/8] support distributed training --- basicsr/utils/dist_util.py | 11 +++++++---- basicsr/utils/options.py | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/basicsr/utils/dist_util.py b/basicsr/utils/dist_util.py index 0fab887b2..c2f232dd8 100644 --- a/basicsr/utils/dist_util.py +++ b/basicsr/utils/dist_util.py @@ -6,7 +6,6 @@ import torch.distributed as dist import torch.multiprocessing as mp - def init_dist(launcher, backend='nccl', **kwargs): if mp.get_start_method(allow_none=True) is None: mp.set_start_method('spawn') @@ -20,9 +19,13 @@ def init_dist(launcher, backend='nccl', **kwargs): def _init_dist_pytorch(backend, **kwargs): rank = int(os.environ['RANK']) - num_gpus = torch.cuda.device_count() - torch.cuda.set_device(rank % num_gpus) - dist.init_process_group(backend=backend, **kwargs) + if backend == 'xccl': + num_gpus = torch.xpu.device_count() + torch.xpu.set_device(rank % num_gpus) + if backend == 'nccl': + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend) def _init_dist_slurm(backend, port=None): diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py index 013192d6c..7cef1a8e1 100644 --- a/basicsr/utils/options.py +++ b/basicsr/utils/options.py @@ -102,7 +102,7 @@ def parse_options(root_path, is_train=True): parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') parser.add_argument('--auto_resume', action='store_true') parser.add_argument('--debug', action='store_true') - parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument('--local-rank', type=int, default=0) parser.add_argument( '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999') args = parser.parse_args() @@ -119,7 +119,7 @@ def parse_options(root_path, is_train=True): if args.launcher == 'slurm' and 'dist_params' in opt: init_dist(args.launcher, **opt['dist_params']) else: - init_dist(args.launcher) + init_dist(args.launcher, **opt['dist_params']) opt['rank'], opt['world_size'] = get_dist_info() # random seed