Merge pull request #24 from GeorgeCaoJ/feat-mps-support
Enhance device compatibility to support MPS for MacOS
This commit is contained in:
10
demo1.py
10
demo1.py
@@ -1,13 +1,19 @@
|
|||||||
|
import torch
|
||||||
from funasr import AutoModel
|
from funasr import AutoModel
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
model_dir = "FunAudioLLM/Fun-ASR-Nano-2512"
|
model_dir = "FunAudioLLM/Fun-ASR-Nano-2512"
|
||||||
|
device = (
|
||||||
|
"cuda:0"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||||
|
)
|
||||||
model = AutoModel(
|
model = AutoModel(
|
||||||
model=model_dir,
|
model=model_dir,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
remote_code="./model.py",
|
remote_code="./model.py",
|
||||||
device="cuda:0",
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
wav_path = f"{model.model_path}/example/zh.mp3"
|
wav_path = f"{model.model_path}/example/zh.mp3"
|
||||||
@@ -28,7 +34,7 @@ def main():
|
|||||||
vad_model="fsmn-vad",
|
vad_model="fsmn-vad",
|
||||||
vad_kwargs={"max_single_segment_time": 30000},
|
vad_kwargs={"max_single_segment_time": 30000},
|
||||||
remote_code="./model.py",
|
remote_code="./model.py",
|
||||||
device="cuda:0",
|
device=device,
|
||||||
)
|
)
|
||||||
res = model.generate(input=[wav_path], cache={}, batch_size=1)
|
res = model.generate(input=[wav_path], cache={}, batch_size=1)
|
||||||
text = res[0]["text"]
|
text = res[0]["text"]
|
||||||
|
|||||||
8
demo2.py
8
demo2.py
@@ -1,9 +1,15 @@
|
|||||||
|
import torch
|
||||||
from model import FunASRNano
|
from model import FunASRNano
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
model_dir = "FunAudioLLM/Fun-ASR-Nano-2512"
|
model_dir = "FunAudioLLM/Fun-ASR-Nano-2512"
|
||||||
m, kwargs = FunASRNano.from_pretrained(model=model_dir, device="cuda:0")
|
device = (
|
||||||
|
"cuda:0"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
m, kwargs = FunASRNano.from_pretrained(model=model_dir, device=device)
|
||||||
m.eval()
|
m.eval()
|
||||||
|
|
||||||
wav_path = f"{kwargs['model_path']}/example/zh.mp3"
|
wav_path = f"{kwargs['model_path']}/example/zh.mp3"
|
||||||
|
|||||||
11
model.py
11
model.py
@@ -204,7 +204,9 @@ class FunASRNano(nn.Module):
|
|||||||
stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
|
stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
|
||||||
)
|
)
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(
|
device_type = next(self.parameters()).device.type
|
||||||
|
with torch.autocast(
|
||||||
|
device_type=device_type if device_type in ["cuda", "mps"] else "cpu",
|
||||||
enabled=True if self.llm_dtype != "fp32" else False,
|
enabled=True if self.llm_dtype != "fp32" else False,
|
||||||
dtype=dtype_map[self.llm_dtype],
|
dtype=dtype_map[self.llm_dtype],
|
||||||
):
|
):
|
||||||
@@ -624,8 +626,11 @@ class FunASRNano(nn.Module):
|
|||||||
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
|
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
|
||||||
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
|
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(
|
device_type = torch.device(kwargs.get("device", "cuda")).type
|
||||||
enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
|
with torch.autocast(
|
||||||
|
device_type=device_type if device_type in ["cuda", "mps"] else "cpu",
|
||||||
|
enabled=True if llm_dtype != "fp32" else False,
|
||||||
|
dtype=dtype_map[llm_dtype]
|
||||||
):
|
):
|
||||||
label = contents["assistant"][-1]
|
label = contents["assistant"][-1]
|
||||||
self.llm = self.llm.to(dtype_map[llm_dtype])
|
self.llm = self.llm.to(dtype_map[llm_dtype])
|
||||||
|
|||||||
Reference in New Issue
Block a user