58 lines
2.1 KiB
Python
58 lines
2.1 KiB
Python
from itertools import groupby
|
|
|
|
import soundfile as sf
|
|
import torch
|
|
import torchaudio
|
|
import torchaudio.functional as F
|
|
|
|
|
|
def load_audio(wav_path, rate: int = None, offset: float = 0, duration: float = None):
    """Load an audio file (optionally a segment of it), resampling if needed.

    Args:
        wav_path: Path to any audio file readable by ``soundfile``.
        rate: Target sample rate. If ``None``, audio is returned at its
            native rate with no resampling.
        offset: Start position within the file, in seconds.
        duration: Length to read, in seconds. ``None`` reads to end of file.

    Returns:
        Tuple ``(audio_tensor, sample_rate)``: a float32 tensor — 1-D
        ``(frames,)`` for mono, ``(frames, channels)`` for multichannel
        (soundfile's native layout) — and the sample rate of the returned
        audio (``rate`` if resampled, else the file's native rate).
    """
    with sf.SoundFile(wav_path) as f:
        start_frame = int(offset * f.samplerate)
        if duration is None:
            frames_to_read = f.frames - start_frame
        else:
            frames_to_read = int(duration * f.samplerate)
        f.seek(start_frame)
        audio_data = f.read(frames_to_read, dtype="float32")
        audio_tensor = torch.from_numpy(audio_data)
        if rate is not None and f.samplerate != rate:
            # torchaudio expects (channels, frames); soundfile yields
            # (frames,) for mono or (frames, channels) for multichannel.
            multichannel = audio_tensor.ndim > 1
            if multichannel:
                audio_tensor = audio_tensor.T
            else:
                audio_tensor = audio_tensor.unsqueeze(0)
            resampler = torchaudio.transforms.Resample(orig_freq=f.samplerate, new_freq=rate)
            audio_tensor = resampler(audio_tensor)
            if audio_tensor.shape[0] == 1:
                audio_tensor = audio_tensor.squeeze(0)
            elif multichannel:
                # BUGFIX: restore soundfile's (frames, channels) layout so the
                # returned shape is the same whether or not resampling happened
                # (the original returned (channels, frames) only when resampled).
                audio_tensor = audio_tensor.T
        return audio_tensor, rate if rate is not None else f.samplerate
|
|
|
|
|
|
def forced_align(log_probs: torch.Tensor, targets: torch.Tensor, blank: int = 0):
    """CTC forced alignment of a token sequence against frame log-probs.

    Args:
        log_probs: ``(frames, vocab)`` CTC emission log-probabilities
            (unbatched; a batch dimension is added internally).
        targets: ``(target_len,)`` token-id sequence to align.
        blank: Index of the CTC blank token.

    Returns:
        List of dicts with keys ``token`` (token id), ``start_time`` /
        ``end_time`` (frame indices, end exclusive) and ``score`` (max
        per-frame probability over the span, rounded to 3 decimals).
        Best-effort: on any alignment failure, returns whatever items were
        collected so far (possibly an empty list) instead of raising.
    """
    items = []
    try:
        # torchaudio's forced_align currently only supports batch_size == 1.
        log_probs, targets = log_probs.unsqueeze(0).cpu(), targets.unsqueeze(0).cpu()
        # Explicit validation instead of `assert` (asserts vanish under -O);
        # the raise is caught below, matching the original's silent failure.
        if log_probs.shape[1] < targets.shape[1]:
            raise ValueError("targets are longer than the number of frames")
        alignments, scores = F.forced_align(log_probs, targets, blank=blank)
        alignments, scores = alignments[0], torch.exp(scores[0]).tolist()
        # use enumerate to keep track of the original indices, then group by token value
        for token, group in groupby(enumerate(alignments), key=lambda item: item[1]):
            if token == blank:
                continue
            group = list(group)
            start = group[0][0]
            end = start + len(group)
            score = max(scores[start:end])
            items.append(
                {
                    "token": token.item(),
                    "start_time": start,
                    "end_time": end,
                    "score": round(score, 3),
                }
            )
    except Exception:
        # Deliberate best-effort: alignment failure yields partial/empty
        # results rather than crashing the caller. Narrowed from a bare
        # `except:` so KeyboardInterrupt/SystemExit still propagate.
        # NOTE(review): consider logging the exception for diagnosability.
        pass
    return items
|