feat: Enhance device compatibility by dynamically selecting CUDA, MPS, or CPU and updating autocast usage.

This commit is contained in:
GeorgeCaoJ
2025-12-19 09:31:07 +08:00
parent e604451a5c
commit 52879d2968
3 changed files with 23 additions and 6 deletions

View File

@ -204,7 +204,9 @@ class FunASRNano(nn.Module):
stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
)
with torch.cuda.amp.autocast(
device_type = next(self.parameters()).device.type
with torch.autocast(
device_type=device_type if device_type in ["cuda", "mps"] else "cpu",
enabled=True if self.llm_dtype != "fp32" else False,
dtype=dtype_map[self.llm_dtype],
):
@ -624,8 +626,11 @@ class FunASRNano(nn.Module):
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
with torch.cuda.amp.autocast(
enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
device_type = torch.device(kwargs.get("device", "cuda")).type
with torch.autocast(
device_type=device_type if device_type in ["cuda", "mps"] else "cpu",
enabled=True if llm_dtype != "fp32" else False,
dtype=dtype_map[llm_dtype]
):
label = contents["assistant"][-1]
self.llm = self.llm.to(dtype_map[llm_dtype])