diff --git a/tensorboardX/summary.py b/tensorboardX/summary.py index 27d99ea..f5bf234 100644 --- a/tensorboardX/summary.py +++ b/tensorboardX/summary.py @@ -373,36 +373,24 @@ def make_video(tensor, fps): def audio(tag, tensor, sample_rate=44100): tensor = make_np(tensor) - tensor = tensor.squeeze() if abs(tensor).max() > 1: print('warning: audio amplitude out of range, auto clipped.') tensor = tensor.clip(-1, 1) - assert(tensor.ndim == 1), 'input tensor should be 1 dimensional.' - - tensor_list = [int(32767.0 * x) for x in tensor] + assert(tensor.ndim == 2), 'input tensor should be 2 dimensional.' + length_frames, num_channels = tensor.shape + assert num_channels == 1 or num_channels == 2, f'Expected 1/2 channels, got {num_channels}' + import soundfile import io - import wave - import struct - fio = io.BytesIO() - Wave_write = wave.open(fio, 'wb') - Wave_write.setnchannels(1) - Wave_write.setsampwidth(2) - Wave_write.setframerate(sample_rate) - tensor_enc = b'' - tensor_enc += struct.pack("<" + "h" * len(tensor_list), *tensor_list) - - Wave_write.writeframes(tensor_enc) - Wave_write.close() - audio_string = fio.getvalue() - fio.close() + with io.BytesIO() as fio: + soundfile.write(fio, tensor, samplerate=sample_rate, format='wav') + audio_string = fio.getvalue() audio = Summary.Audio(sample_rate=sample_rate, - num_channels=1, - length_frames=len(tensor_list), + num_channels=num_channels, + length_frames=length_frames, encoded_audio_string=audio_string, content_type='audio/wav') return Summary(value=[Summary.Value(tag=tag, audio=audio)]) - def custom_scalars(layout): categoriesnames = layout.keys() categories = [] diff --git a/tensorboardX/writer.py b/tensorboardX/writer.py index 06337a7..58d57a1 100644 --- a/tensorboardX/writer.py +++ b/tensorboardX/writer.py @@ -716,7 +716,7 @@ class SummaryWriter(object): sample_rate (int): sample rate in Hz walltime (float): Optional override default walltime (time.time()) of event Shape: - snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1]. + snd_tensor: :math:`(L, c)`. The values should lie between [-1, 1]. """ if self._check_caffe2_blob(snd_tensor): snd_tensor = workspace.FetchBlob(snd_tensor)