Merge pull request #24 from GeorgeCaoJ/feat-mps-support

Enhance device compatibility to support MPS for MacOS
This commit is contained in:
彭震东
2025-12-19 13:44:56 +08:00
committed by GitHub
3 changed files with 23 additions and 6 deletions

View File

@ -1,13 +1,19 @@
import torch
from funasr import AutoModel from funasr import AutoModel
def main(): def main():
model_dir = "FunAudioLLM/Fun-ASR-Nano-2512" model_dir = "FunAudioLLM/Fun-ASR-Nano-2512"
device = (
"cuda:0"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)
model = AutoModel( model = AutoModel(
model=model_dir, model=model_dir,
trust_remote_code=True, trust_remote_code=True,
remote_code="./model.py", remote_code="./model.py",
device="cuda:0", device=device,
) )
wav_path = f"{model.model_path}/example/zh.mp3" wav_path = f"{model.model_path}/example/zh.mp3"
@ -28,7 +34,7 @@ def main():
vad_model="fsmn-vad", vad_model="fsmn-vad",
vad_kwargs={"max_single_segment_time": 30000}, vad_kwargs={"max_single_segment_time": 30000},
remote_code="./model.py", remote_code="./model.py",
device="cuda:0", device=device,
) )
res = model.generate(input=[wav_path], cache={}, batch_size=1) res = model.generate(input=[wav_path], cache={}, batch_size=1)
text = res[0]["text"] text = res[0]["text"]

View File

@ -1,9 +1,15 @@
import torch
from model import FunASRNano from model import FunASRNano
def main(): def main():
model_dir = "FunAudioLLM/Fun-ASR-Nano-2512" model_dir = "FunAudioLLM/Fun-ASR-Nano-2512"
m, kwargs = FunASRNano.from_pretrained(model=model_dir, device="cuda:0") device = (
"cuda:0"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)
m, kwargs = FunASRNano.from_pretrained(model=model_dir, device=device)
m.eval() m.eval()
wav_path = f"{kwargs['model_path']}/example/zh.mp3" wav_path = f"{kwargs['model_path']}/example/zh.mp3"

View File

@ -204,7 +204,9 @@ class FunASRNano(nn.Module):
stats["batch_size_x_frames"] - stats["batch_size_real_frames"] stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
) )
with torch.cuda.amp.autocast( device_type = next(self.parameters()).device.type
with torch.autocast(
device_type=device_type if device_type in ["cuda", "mps"] else "cpu",
enabled=True if self.llm_dtype != "fp32" else False, enabled=True if self.llm_dtype != "fp32" else False,
dtype=dtype_map[self.llm_dtype], dtype=dtype_map[self.llm_dtype],
): ):
@ -624,8 +626,11 @@ class FunASRNano(nn.Module):
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
with torch.cuda.amp.autocast( device_type = torch.device(kwargs.get("device", "cuda")).type
enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype] with torch.autocast(
device_type=device_type if device_type in ["cuda", "mps"] else "cpu",
enabled=True if llm_dtype != "fp32" else False,
dtype=dtype_map[llm_dtype]
): ):
label = contents["assistant"][-1] label = contents["assistant"][-1]
self.llm = self.llm.to(dtype_map[llm_dtype]) self.llm = self.llm.to(dtype_map[llm_dtype])