Python基于自相关的基音提取



  • 通过对音频波形文件(PCM格式)进行自相关的计算,检测出音频基音所在的频率,最后通过查表将频率转换成MIDI格式的音符。


    # -*- coding: utf-8 -*-
    import itertools
    import math
    import os
    import time
    import wave

    import matplotlib.pyplot as pyplot
    import midi
    import numpy as np
    import pyaudio
    import pygame
    import scipy.signal as signal
    from midiutil import MIDIFile
    from numpy.fft import rfft, rfftfreq
    
    def play_music(music_file):
        """Load *music_file* into pygame's music mixer and block until it stops.

        If pygame raises ``pygame.error`` while loading, a message is printed
        and the function returns without playing anything.
        """
        ticker = pygame.time.Clock()
        try:
            pygame.mixer.music.load(music_file)
        except pygame.error:
            print("File {} not found! {}".format(music_file, pygame.get_error()))
            return
        else:
            print("Music file {} loaded!".format(music_file))
        pygame.mixer.music.play()
        # Poll at most 30 times per second until the mixer reports it is idle.
        while pygame.mixer.music.get_busy():
            ticker.tick(30)
    
    # Module-level pygame mixer setup (runs at import time).
    freq = 44100  # sample rate in Hz (audio CD quality)
    bitsize = -16  # 16-bit samples; the negative size selects SIGNED samples
    channels = 2  # 1 is mono, 2 is stereo
    # NOTE(review): `buffer` shadows the Python 2 builtin of the same name.
    buffer = 2048  # mixer buffer size in samples (experiment to get right sound)
    pygame.mixer.init(freq, bitsize, channels, buffer)
    # optional volume 0 to 1.0
    pygame.mixer.music.set_volume(0.8)
    
    
    # Equal-tempered frequencies in Hz (A4 = 440 Hz) of the 88 piano keys,
    # A0 .. C8, in ascending order: midi_hz[i] is the pitch of MIDI note
    # midi_num[i].  Values are 440 * 2 ** ((note - 69) / 12) rounded to five
    # significant figures.
    midi_hz = [27.500, 29.135, 30.868, 32.703, 34.648, 36.708, 38.891,
               41.203, 43.654, 46.249, 48.999, 51.913, 55.0, 58.270,
               61.735, 65.406, 69.296, 73.416, 77.782, 82.407, 87.307,
               92.499, 97.999, 103.83, 110.00, 116.54, 123.47, 130.81,
               138.59, 146.83, 155.56, 164.81, 174.61, 185.0, 196.0,
               207.65, 220.0, 233.08, 246.94, 261.63, 277.18, 293.66,
               311.13, 329.63, 349.23, 369.99, 392.0, 415.3, 440.0,
               466.16, 493.88, 523.25, 554.37, 587.33, 622.25, 659.26,
               698.46, 739.99, 783.99, 830.61, 880.0, 932.33, 987.77,
               1046.5, 1108.7, 1174.7, 1244.5, 1318.5, 1396.9, 1480.0,
               1568.0, 1661.2, 1760.0, 1864.7, 1975.5, 2093.0, 2217.5,
               2349.3, 2489.0, 2637.0, 2793.8, 2960.0, 3136.0, 3322.4,
               3520.0, 3729.3, 3951.1, 4186.0]
    # MIDI note numbers matching midi_hz: 21 (A0) .. 108 (C8).
    midi_num = list(range(21, 21 + len(midi_hz)))
    
    
    class Wave:
        """A PCM WAV file read fully into memory, with playback and analysis helpers.

        Attributes set by ``__init__``:
            nchannels, sampwidth, framerate, nframes: header values from the
                file (``sampwidth`` is bytes per sample).
            data: the raw interleaved frame bytes.
            frame_size: bytes per frame (``sampwidth * nchannels``).
            frame_time: seconds per frame.
            len: number of bytes in ``data``.
            player, stream: a PyAudio instance and an output stream opened
                with the file's format; used by ``play``.
        """

        def __init__(self, filename):
            self.f = wave.open(filename, "rb")
            try:
                (self.nchannels, self.sampwidth,
                 self.framerate, self.nframes) = self.f.getparams()[0:4]
                print("{} {} {} {}".format(self.nchannels, self.sampwidth,
                                           self.framerate, self.nframes))
                self.data = self.f.readframes(self.nframes)
            finally:
                # Close the file even if reading the header or frames fails.
                self.f.close()
            self.frame_size = self.sampwidth * self.nchannels
            self.frame_time = 1.0 / self.framerate
            # Length of the bytes actually read (equals nframes * frame_size
            # for a well-formed file).
            self.len = len(self.data)
            self.player = pyaudio.PyAudio()
            self.stream = self.player.open(
                format=self.player.get_format_from_width(self.sampwidth),
                channels=self.nchannels,
                rate=self.framerate,
                output=True)

        def play(self):
            """Write ``data`` to the output stream in one-second chunks.

            A progress line is printed for each full chunk.  The trailing
            partial chunk is written as well; previously the loop stopped
            before it, so up to a second at the end was never played.
            """
            chunk = self.frame_size * self.framerate  # bytes in one second of audio
            pos = 0
            while pos + chunk < self.len:
                print("current pos -> {}\t/{}".format(pos + chunk, self.len))
                self.stream.write(self.data[pos:pos + chunk])
                pos += chunk
            if pos < self.len:
                self.stream.write(self.data[pos:self.len])
            print("current pos -> {}\t/{} over!".format(self.len, self.len))

        def _samples(self):
            """Decode ``data`` into signed integers of shape (frames, nchannels).

            Raises:
                ValueError: if ``sampwidth`` is not 1, 2, 3 or 4 bytes.
            """
            if self.sampwidth == 1:
                # 8-bit WAV PCM is unsigned with a 128 offset.
                ys = np.frombuffer(self.data, dtype=np.uint8).astype(np.int32) - 128
            elif self.sampwidth == 2:
                ys = np.frombuffer(self.data, dtype="<i2")
            elif self.sampwidth == 3:
                # Little-endian 24-bit: only the top byte carries the sign, so
                # assemble from unsigned bytes and then sign-extend.
                raw = np.frombuffer(self.data, dtype=np.uint8).astype(np.int32)
                ys = raw[0::3] | (raw[1::3] << 8) | (raw[2::3] << 16)
                ys = np.where(ys >= 1 << 23, ys - (1 << 24), ys)
            elif self.sampwidth == 4:
                ys = np.frombuffer(self.data, dtype="<i4")
            else:
                raise ValueError('sampwidth %d unknown' % self.sampwidth)
            return ys.reshape(-1, self.nchannels)

        def _spectrum_arrays(self, start, duration):
            """Return ``(hs, fs)``: the normalized rfft of channel 0 and its bin frequencies.

            The window starts at *start* seconds and lasts *duration* seconds.
            """
            ys = self._samples()[:, 0]
            start_frame = int(start * self.framerate)
            end_frame = int(start_frame + duration * self.framerate)
            ys = ys[start_frame:end_frame]
            fs = np.fft.rfftfreq(len(ys), 1.0 / self.framerate)
            hs = np.fft.rfft(ys)
            hs = hs / len(hs)  # normalize by the number of rfft bins
            return hs, fs

        def make_spectrum(self, start, duration):
            """Build a ``Spectrum`` of channel 0 over [start, start + duration) seconds.

            Raises:
                ValueError: if ``sampwidth`` is not 1, 2, 3 or 4 bytes.
            """
            hs, fs = self._spectrum_arrays(start, duration)
            # NOTE(review): `Spectrum` is not defined or imported anywhere in
            # this file (it resembles ThinkDSP's class); this line raises
            # NameError until it is provided.
            return Spectrum(hs, fs, self.framerate)

        def ploat(self):
            """Plot each channel's waveform against time, scaled to its own peak.

            One vertically stacked subplot per channel ("darkorange" for the
            first, "g" for the second, alternating after that).  The figure is
            neither shown nor saved here.
            """
            samples = self._samples()
            t = np.arange(0, self.nframes) * self.frame_time
            colors = ("darkorange", "g")
            for k in range(self.nchannels):
                # Work in float: np.abs on an int16 -32768 overflows.
                channel = samples[:, k].astype(np.float64)
                peak = np.max(np.abs(channel)) if channel.size else 0.0
                # Silence (peak 0) is plotted as-is rather than divided by zero.
                d = channel / peak if peak else channel
                pyplot.subplot(self.nchannels, 1, k + 1)
                pyplot.plot(t, d, c=colors[k % len(colors)])
            pyplot.xlabel("time (seconds)")

        # Correctly spelled alias; `ploat` is kept for existing callers.
        plot = ploat
    
    
    def corrcoef(xs, ys):
        """Return the Pearson correlation coefficient of two equal-length sequences.

        ``np.corrcoef``'s ``ddof`` argument is deprecated and has no effect on
        the result, so it is no longer passed.
        """
        return np.corrcoef(xs, ys)[0, 1]
    
    def serial_corr(wave, lag=1):
        """Correlate *wave* with a copy of itself delayed by *lag* samples.

        The overlapping parts ``wave[lag:]`` and ``wave[:len(wave) - lag]`` are
        passed, in that order, to ``corrcoef``.
        """
        size = len(wave)
        trailing = wave[lag:]
        leading = wave[:size - lag]
        return corrcoef(trailing, leading)
    
    def autocorr(wave):
        """Return ``(lags, corrs)``: ``serial_corr`` for lags 0 .. len(wave)//2 - 1.

        ``lags`` is the ``range`` of lags tried; ``corrs`` is a list with one
        correlation per lag, in the same order.
        """
        lags = range(len(wave) // 2)
        corrs = []
        for lag in lags:
            corrs.append(serial_corr(wave, lag))
        return lags, corrs
    
    # def autocorr(wave):
    #     lags = range(len(wave))
    #     corrs = np.correlate(wave, wave, mode='same')
    #     return lags, corrs
    
    def GetMidi(hz):
        """Return the MIDI note number nearest to frequency *hz*, or 0 for "no note".

        The note is the equal-tempered semitone (A4 = 440 Hz = note 69) closest
        to *hz* on a logarithmic scale.  0 is returned when *hz* is not a
        positive finite number, or when the nearest note falls outside the
        88-key piano range 21 (A0) .. 108 (C8).
        """
        # The previous scan returned the highest note strictly below hz, so a
        # slightly flat pitch (e.g. 439 Hz) was reported a semitone low and
        # C8 (4186 Hz) itself mapped to 0; round to the nearest semitone.
        if not hz > 0 or math.isinf(hz):  # also rejects NaN
            return 0
        note = int(math.floor(69 + 12 * math.log(hz / 440.0, 2) + 0.5))
        if 21 <= note <= 108:
            return note
        return 0
    
    def _frame_degrees(samples, rate, frame_size):
        """Return one MIDI note number (0 = none) per full frame of *samples*."""
        data = signal.medfilt(np.asarray(samples, dtype=np.float64), 401)
        data = signal.detrend(data, type='linear')
        widths = np.arange(10, 30)
        degrees = []
        # Every full frame is analysed, including the last one.
        for start in range(0, len(data) - frame_size + 1, frame_size):
            lags, corrs = autocorr(data[start:start + frame_size])
            # Skip the first peak find_peaks_cwt reports, as before (presumably
            # the lag-0 maximum -- TODO confirm).  Also drop lag 0, which would
            # divide by zero, and NaN correlations.
            peaks = [p for p in signal.find_peaks_cwt(corrs, widths)[1:]
                     if lags[p] > 0 and not np.isnan(corrs[p])]
            note = 0
            if peaks:
                best = max(peaks, key=lambda p: corrs[p])  # first maximum, like argmax
                note = GetMidi(rate / lags[best])
            degrees.append(note)
        return degrees


    def ACFPitch(wave, output_path="~/test.mid", frame_size=4096, tempo=120):
        """Estimate per-frame pitch of *wave* by autocorrelation and write a MIDI file.

        The first channel of the 16-bit PCM data is median-filtered and
        detrended, then cut into non-overlapping frames of *frame_size*
        samples.  Each frame's strongest autocorrelation peak is converted to
        a frequency and then a note with ``GetMidi``.  Runs of identical
        consecutive notes become one MIDI note; frames mapped to 0 are rests.

        Args:
            wave: a ``Wave``; its ``data``, ``sampwidth``, ``nchannels`` and
                ``framerate`` attributes are read.
            output_path: destination of the MIDI file; ``~`` is expanded.
            frame_size: analysis window length in samples.
            tempo: tempo in BPM written to the file; also used to convert
                seconds into the beats that ``MIDIFile.addNote`` expects.

        Raises:
            ValueError: if the audio is not 16-bit PCM.
        """
        if wave.sampwidth != 2:
            raise ValueError("ACFPitch expects 16-bit PCM, got sampwidth %d"
                             % wave.sampwidth)
        rate = float(wave.framerate)
        # De-interleave and keep channel 0; analysing interleaved stereo as a
        # single signal doubles every lag and reports pitch an octave low.
        samples = np.frombuffer(wave.data, dtype="<i2").reshape(-1, wave.nchannels)[:, 0]
        degrees = _frame_degrees(samples, rate, frame_size)

        # addNote takes times in beats (quarter notes), not seconds; keep the
        # exact frame duration (no rounding) so note times do not drift.
        beats_per_frame = frame_size / rate * tempo / 60.0
        midi_file = MIDIFile(1)
        midi_file.addTempo(0, 0, tempo)
        first_frame = 0
        for note, run in itertools.groupby(degrees):
            length = len(list(run))
            if note != 0:
                # The note starts at the FIRST frame of its run.
                midi_file.addNote(0, 0, note, first_frame * beats_per_frame,
                                  length * beats_per_frame, 100)
            first_frame += length

        with open(os.path.expanduser(output_path), "wb") as output_file:
            midi_file.writeFile(output_file)
    
    if __name__ == "__main__":
        # wave.open()/open() do not expand "~", so resolve the home directory
        # explicitly instead of looking for a literal "~" folder.
        w = Wave(os.path.expanduser("~/test.wav"))
        w.play()
        ACFPitch(w)
        play_music(os.path.expanduser("~/test.mid"))