4
model.py
4
model.py
@ -206,7 +206,7 @@ class FunASRNano(nn.Module):
|
|||||||
|
|
||||||
device_type = next(self.parameters()).device.type
|
device_type = next(self.parameters()).device.type
|
||||||
with torch.autocast(
|
with torch.autocast(
|
||||||
device_type=device_type if device_type in ["cuda", "mps"] else "cpu",
|
device_type=device_type if device_type in ["cuda", "xpu", "mps"] else "cpu",
|
||||||
enabled=True if self.llm_dtype != "fp32" else False,
|
enabled=True if self.llm_dtype != "fp32" else False,
|
||||||
dtype=dtype_map[self.llm_dtype],
|
dtype=dtype_map[self.llm_dtype],
|
||||||
):
|
):
|
||||||
@ -620,7 +620,7 @@ class FunASRNano(nn.Module):
|
|||||||
|
|
||||||
device_type = torch.device(kwargs.get("device", "cuda")).type
|
device_type = torch.device(kwargs.get("device", "cuda")).type
|
||||||
with torch.autocast(
|
with torch.autocast(
|
||||||
device_type=device_type if device_type in ["cuda", "mps"] else "cpu",
|
device_type=device_type if device_type in ["cuda", "xpu", "mps"] else "cpu",
|
||||||
enabled=True if llm_dtype != "fp32" else False,
|
enabled=True if llm_dtype != "fp32" else False,
|
||||||
dtype=dtype_map[llm_dtype]
|
dtype=dtype_map[llm_dtype]
|
||||||
):
|
):
|
||||||
|
|||||||
Reference in New Issue
Block a user