diff --git a/versions/2022/supervised/python/kws-infer.py b/versions/2022/supervised/python/kws-infer.py index 8eb7a55..a98d9b0 100644 --- a/versions/2022/supervised/python/kws-infer.py +++ b/versions/2022/supervised/python/kws-infer.py @@ -8,12 +8,19 @@ To use microphone input, run: python3 kws-infer.py --gui + On RPi 4: + python3 kws-infer.py --rpi --gui + Dependencies: pip3 install pysimplegui pip3 install sounddevice + pip3 install librosa + + sudo apt-get install libasound2-dev libportaudio2 Inference time: 0.03 sec Quad Core Intel i7 2.3GHz + 0.09 sec on RPi 4 ''' @@ -38,6 +45,7 @@ def get_args(): parser.add_argument("--wav-file", type=str, default=None) parser.add_argument("--model-path", type=str, default="resnet18-kws-best-acc.pt") parser.add_argument("--gui", default=False, action="store_true") + parser.add_argument("--rpi", default=False, action="store_true") args = parser.parse_args() return args @@ -82,18 +90,6 @@ def get_args(): if not args.gui: waveform, sample_rate = torchaudio.load(wav_file) - - #wav = np.expand_dims(waveform.numpy().squeeze(), axis=1) - - #print("Shape:", wav.shape) - #sd.default.samplerate = sample_rate - #sd.default.channels = 1 - #sd.play(wav, blocking=True) - - #print("Shape:", waveform.shape) - #sd.fs = sample_rate - #sd.channel_count = waveform.shape[1] - #sd.play(waveform.numpy(), sample_rate) transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_fft=args.n_fft, @@ -127,8 +123,17 @@ def get_args(): if waveform.max() > 1.0: continue start_time = time.time() - waveform = torch.from_numpy(waveform).unsqueeze(0) - mel = ToTensor()(librosa.power_to_db(transform(waveform).squeeze().numpy(), ref=np.max)) + if args.rpi: + waveform = torch.FloatTensor(waveform.tolist()) + mel = np.array(transform(waveform).squeeze().tolist()) + mel = librosa.power_to_db(mel, ref=np.max).tolist() + + mel = torch.FloatTensor(mel) + mel = mel.unsqueeze(0) + + else: + waveform = torch.from_numpy(waveform).unsqueeze(0) + mel = ToTensor()(librosa.power_to_db(transform(waveform).squeeze().numpy(), ref=np.max)) mel = mel.unsqueeze(0) pred = scripted_module(mel) max_prob = pred.max()