Skip to content

Added support for Moore Threads GPUs #8011

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion comfy/clip_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
else:
scale_size = (size, size)

image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
if image.device.type == 'musa':
image = image.cpu()
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
image = image.to('musa')
else:
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
Expand Down
11 changes: 8 additions & 3 deletions comfy/ldm/flux/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,14 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
device = torch.device("cpu")
else:
device = pos.device

scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
omega = 1.0 / (theta**scale)
if device.type == "musa":
scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float32, device=device)
if not isinstance(theta, torch.Tensor):
theta = torch.tensor(theta, dtype=torch.float32, device=device)
omega = torch.exp(-scale * torch.log(theta + 1e-6))
else:
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
Expand Down
37 changes: 37 additions & 0 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ def get_supported_float8_types():
except:
mlu_available = False

try:
import torch_musa
_ = torch_musa.device_count()
musa_available = torch_musa.is_available()
if musa_available:
logging.info("MUSA device detected: {}".format(torch_musa.get_device_name(0)))
except:
musa_available = False

if args.cpu:
cpu_state = CPUState.CPU

Expand All @@ -151,6 +160,12 @@ def is_mlu():
return True
return False

def is_musa():
global musa_available
if musa_available:
return True
return False

def get_torch_device():
global directml_enabled
global cpu_state
Expand All @@ -168,6 +183,8 @@ def get_torch_device():
return torch.device("npu", torch.npu.current_device())
elif is_mlu():
return torch.device("mlu", torch.mlu.current_device())
elif is_musa():
return torch.device('musa', torch.musa.current_device())
else:
return torch.device(torch.cuda.current_device())

Expand Down Expand Up @@ -200,6 +217,12 @@ def get_total_memory(dev=None, torch_total_too=False):
_, mem_total_mlu = torch.mlu.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_mlu
elif is_musa():
stats = torch.musa.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
_, mem_total = torch.musa.mem_get_info(dev)
mem_total_torch = mem_reserved

else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
Expand Down Expand Up @@ -1099,6 +1122,14 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_mlu + mem_free_torch
elif is_musa():
stats = torch.musa.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_musa, _ = torch.musa.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_musa + mem_free_torch

else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
Expand Down Expand Up @@ -1171,6 +1202,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_mlu():
return True

if is_musa():
return True

if torch.version.hip:
return True

Expand Down Expand Up @@ -1231,6 +1265,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_ascend_npu():
return True

if is_musa():
return True

if is_amd():
arch = torch.cuda.get_device_properties(device).gcnArchName
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
Expand Down