forked from breadbread1984/wavenet-tf2.0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Sampler.py
48 lines (38 loc) · 1.82 KB
/
Sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#!/usr/bin/python3
import sys;
from os.path import exists, join;
import librosa;
import pandas as pd;
import tensorflow as tf;
from WaveNet import GCConv1D, calculate_receptive_field;
from create_dateset import mu_law_decode;
class Sampler(object):
    """Generate audio from a trained WaveNet, conditioned on a speaker.

    The network is loaded from a saved Keras model file, and a speaker lookup
    table (a pandas DataFrame with ``person_id`` and ``class_id`` columns) is
    read from a pickle file.
    """
    def __init__(self, model_path = join('model', 'wavenet.h5'), category_path = 'category.pkl'):
        """Load the trained network and the speaker table.

        Args:
            model_path: saved Keras model produced by training. Defaults to
                ``model/wavenet.h5``.
            category_path: pickled pandas DataFrame mapping ``person_id`` to
                ``class_id``. Defaults to ``category.pkl``.

        Raises:
            FileNotFoundError: if ``model_path`` does not exist.
        """
        if not exists(model_path):
            # Raise a real exception type: raising a plain str is itself a
            # TypeError in Python 3 and would hide this message.
            raise FileNotFoundError('train the wavenet first before sampling audios');
        # GCConv1D is a custom layer, so Keras needs it to deserialize the model.
        self.wavenet = tf.keras.models.load_model(model_path, compile = False, custom_objects = {'GCConv1D': GCConv1D});
        self.receptive_field = calculate_receptive_field();
        # NOTE(review): read_pickle can execute arbitrary code; only load a
        # category file you produced yourself.
        self.category = pd.read_pickle(category_path);

    def sample(self, person_id, length = 10000):
        """Autoregressively generate ``length`` samples for one speaker.

        The context window is seeded with uniform random integer codes in
        [0, 256) spanning the receptive field. At every step the single most
        likely next code (greedy argmax, no random sampling) is appended and
        the oldest code is dropped.

        Args:
            person_id: value looked up in the ``person_id`` column of the
                speaker table; it must match that column's dtype (for example
                ``str`` vs ``int``).
            length: number of samples to generate; must be at least 1.

        Returns:
            A float32 numpy array of the generated codes, each passed through
            ``mu_law_decode``, with one entry per generated step.

        Raises:
            ValueError: if ``length`` is less than 1.
            IndexError: if ``person_id`` is not in the speaker table.
        """
        if length < 1:
            # tf.concat cannot join an empty list, so fail early and clearly.
            raise ValueError('length must be at least 1');
        class_id = self.category[self.category['person_id'] == person_id]['class_id'].iloc[0];
        glob_cond = tf.reshape(class_id, (1, 1)); # glob_cond.shape = (1, 1)
        inputs = tf.random.uniform((1, self.receptive_field, 1), minval = 0, maxval = 256, dtype = tf.int32); # inputs.shape = (1, receptive_field, 1)
        samples = list();
        for _ in range(length):
            # Assumes outputs are (1, T, 256) scores per time step (the original
            # comment says T == 1); only the newest step predicts the next code.
            outputs = self.wavenet([inputs, glob_cond]);
            # argmax drops the class axis -> (1, 1). Produce int32 so the result
            # can be concatenated with the int32 context window.
            index = tf.math.argmax(outputs[:, -1:, :], axis = -1, output_type = tf.int32);
            index = tf.expand_dims(index, axis = -1); # index.shape = (1, 1, 1)
            samples.append(index);
            # Slide the window: drop the oldest code and append the new one.
            inputs = tf.concat([inputs[:, 1:, :], index], axis = 1); # inputs.shape = (1, receptive_field, 1)
        samples = tf.squeeze(tf.concat(samples, axis = 1), axis = 0); # samples.shape = (length, 1)
        audio = tf.constant([mu_law_decode(sample) for sample in samples], dtype = tf.float32);
        return audio.numpy();

    def list_person(self):
        """Return the ``person_id`` column of the speaker table (a pandas Series)."""
        return self.category['person_id'];
if __name__ == "__main__":
    # Command-line entry point: sample audio for one speaker and save it to
    # sample.wav at 16 kHz. librosa.output.write_wave never existed (and
    # librosa.output was removed in 0.8), so write with scipy directly.
    from scipy.io import wavfile;
    if len(sys.argv) != 2:
        print("Usage: " + sys.argv[0] + " <person_id>", file = sys.stderr);
        sys.exit(1);
    sampler = Sampler();
    # NOTE(review): argv values are always str; this assumes the category
    # table stores person_id as strings — confirm against category.pkl.
    audio = sampler.sample(sys.argv[1]);
    wavfile.write('sample.wav', 16000, audio);