From 52879d29685aa30b6d07da692eef782a6f4fe86e Mon Sep 17 00:00:00 2001
From: GeorgeCaoJ <851383386@qq.com>
Date: Fri, 19 Dec 2025 09:31:07 +0800
Subject: [PATCH] feat: Enhance device compatibility by dynamically selecting
 CUDA, MPS, or CPU and updating autocast usage.

---
 demo1.py | 10 ++++++++--
 demo2.py |  8 +++++++-
 model.py | 11 ++++++++---
 3 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/demo1.py b/demo1.py
index 6210a44..bc0cc02 100644
--- a/demo1.py
+++ b/demo1.py
@@ -1,13 +1,19 @@
+import torch
 from funasr import AutoModel
 
 
 def main():
     model_dir = "FunAudioLLM/Fun-ASR-Nano-2512"
+    device = (
+        "cuda:0"
+        if torch.cuda.is_available()
+        else "mps" if torch.backends.mps.is_available() else "cpu"
+    )
 
     model = AutoModel(
         model=model_dir,
         trust_remote_code=True,
         remote_code="./model.py",
-        device="cuda:0",
+        device=device,
     )
     wav_path = f"{model.model_path}/example/zh.mp3"
@@ -28,7 +34,7 @@ def main():
         vad_model="fsmn-vad",
         vad_kwargs={"max_single_segment_time": 30000},
         remote_code="./model.py",
-        device="cuda:0",
+        device=device,
     )
     res = model.generate(input=[wav_path], cache={}, batch_size=1)
     text = res[0]["text"]
diff --git a/demo2.py b/demo2.py
index f536429..2564875 100644
--- a/demo2.py
+++ b/demo2.py
@@ -1,9 +1,15 @@
+import torch
 from model import FunASRNano
 
 
 def main():
     model_dir = "FunAudioLLM/Fun-ASR-Nano-2512"
-    m, kwargs = FunASRNano.from_pretrained(model=model_dir, device="cuda:0")
+    device = (
+        "cuda:0"
+        if torch.cuda.is_available()
+        else "mps" if torch.backends.mps.is_available() else "cpu"
+    )
+    m, kwargs = FunASRNano.from_pretrained(model=model_dir, device=device)
     m.eval()
 
     wav_path = f"{kwargs['model_path']}/example/zh.mp3"
diff --git a/model.py b/model.py
index 2c9d7d2..06421c3 100644
--- a/model.py
+++ b/model.py
@@ -204,7 +204,9 @@ class FunASRNano(nn.Module):
             stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
         )
 
-        with torch.cuda.amp.autocast(
+        device_type = next(self.parameters()).device.type
+        with torch.autocast(
+            device_type=device_type if device_type in ["cuda", "mps"] else "cpu",
             enabled=True if self.llm_dtype != "fp32" else False,
             dtype=dtype_map[self.llm_dtype],
         ):
@@ -624,8 +626,11 @@ class FunASRNano(nn.Module):
         llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
         llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
 
-        with torch.cuda.amp.autocast(
-            enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
+        device_type = torch.device(kwargs.get("device", "cuda")).type
+        with torch.autocast(
+            device_type=device_type if device_type in ["cuda", "mps"] else "cpu",
+            enabled=True if llm_dtype != "fp32" else False,
+            dtype=dtype_map[llm_dtype]
         ):
             label = contents["assistant"][-1]
             self.llm = self.llm.to(dtype_map[llm_dtype])
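-- 
Not part of the patch (everything below the signature marker is ignored by
git am): a minimal sanity-check sketch of the behavior the diff relies on.
It mirrors the device-selection expression from demo1.py/demo2.py and the
device_type normalization from model.py. The variable names here are
illustrative only, and it assumes a PyTorch build whose autocast supports
the selected backend.

import torch

# Same fallback order as the patch: CUDA first, then Apple-silicon MPS,
# then plain CPU.
device = (
    "cuda:0"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print("selected device:", device)

# torch.autocast() expects a bare device *type* ("cuda"/"mps"/"cpu"), not
# an indexed string like "cuda:0", which is why model.py normalizes it
# via torch.device(...).type before entering the context.
device_type = torch.device(device).type
with torch.autocast(
    device_type=device_type,
    dtype=torch.float16,
    # This sketch simply disables autocast on CPU; the patch instead keys
    # enabled on llm_dtype ("fp32" disables it).
    enabled=device_type != "cpu",
):
    x = torch.randn(4, 4, device=device)
    y = x @ x  # matmul runs in float16 wherever autocast is active
print("matmul dtype:", y.dtype)  # torch.float16 on cuda/mps, torch.float32 on cpu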