diff --git a/model.py b/model.py index 657edcc..7200e2b 100644 --- a/model.py +++ b/model.py @@ -206,7 +206,7 @@ class FunASRNano(nn.Module): device_type = next(self.parameters()).device.type with torch.autocast( - device_type=device_type if device_type in ["cuda", "mps"] else "cpu", + device_type=device_type if device_type in ["cuda", "xpu", "mps"] else "cpu", enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype], ): @@ -620,7 +620,7 @@ class FunASRNano(nn.Module): device_type = torch.device(kwargs.get("device", "cuda")).type with torch.autocast( - device_type=device_type if device_type in ["cuda", "mps"] else "cpu", + device_type=device_type if device_type in ["cuda", "xpu", "mps"] else "cpu", enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype] ):