# Compute the autocorrelation of a PCM audio waveform (WAV file) to detect the
# fundamental pitch frequency, then map the detected frequencies to MIDI notes
# through a lookup table.


# -*- coding: utf-8 -*-
import bisect
import math
import os
import time
import wave

import matplotlib.pyplot as pyplot
import midi
import numpy as np
import pyaudio
import pygame
import scipy.signal as signal
from midiutil import MIDIFile
from numpy.fft import rfft, rfftfreq

def play_music(music_file):
    """Play a music file through the pygame mixer, blocking until it ends.

    music_file -- path to a file pygame.mixer.music can load (e.g. MIDI).
    """
    timer = pygame.time.Clock()
    try:
        pygame.mixer.music.load(music_file)
    except pygame.error:
        print("File {} not found! {}".format(music_file, pygame.get_error()))
        return
    print("Music file {} loaded!".format(music_file))
    pygame.mixer.music.play()
    # Poll at 30 fps until playback has finished.
    while pygame.mixer.music.get_busy():
        timer.tick(30)

# Mixer configuration for MIDI playback (used by play_music below).
freq = 44100  # sample rate in Hz, audio CD quality
bitsize = -16  # signed 16-bit samples (a negative size means signed in pygame)
channels = 2  # 1 is mono, 2 is stereo
buffer = 2048  # samples per chunk (experiment to get right sound); NOTE: shadows the Py2 builtin `buffer`
pygame.mixer.init(freq, bitsize, channels, buffer)
# optional volume 0 to 1.0
pygame.mixer.music.set_volume(0.8)


# Equal-tempered (A4 = 440 Hz) frequencies of the 88 piano keys, A0 .. C8.
# GetMidi() uses these as ascending bin edges when mapping a detected
# frequency to a note.  Five values in the original table were typos against
# 12-TET (69.269, 116.34, 155.36, 622.23, 2793.0); corrected here.
midi_hz = [27.500, 29.135, 30.868, 32.703, 34.648, 36.708, 38.891,
           41.203, 43.654, 46.249, 48.999, 51.913, 55.0, 58.270,
           61.735, 65.406, 69.296, 73.416, 77.782, 82.407, 87.307,
           92.499, 97.999, 103.83, 110.00, 116.54, 123.47, 130.81, 138.59,
           146.83, 155.56, 164.81, 174.61, 185.0, 196.0, 207.65,
           220.0, 233.08, 246.94, 261.63, 277.18, 293.66, 311.13,
           329.63, 349.23, 369.99, 392.0, 415.3, 440.0, 466.16,
           493.88, 523.25, 554.37, 587.33, 622.25, 659.26, 698.46,
           739.99, 783.99, 830.61, 880.0, 932.33, 987.77, 1046.5,
           1108.7, 1174.7, 1244.5, 1318.5, 1396.9, 1480.0, 1568.0,
           1661.2, 1760.0, 1864.7, 1975.5, 2093.0, 2217.5, 2349.3,
           2489.0, 2637.0, 2793.8, 2960.0, 3136.0, 3322.4, 3520.0,
           3729.3, 3951.1, 4186.0]
# MIDI note numbers for those keys: 21 (A0) .. 108 (C8).
midi_num = list(range(21, 21 + 88))


class Wave:
    """In-memory PCM WAV file with playback, spectrum and plotting helpers."""

    def __init__(self, filename):
        """Load *filename* entirely into memory and open an output stream.

        filename -- path to a PCM ``.wav`` file.
        """
        self.f = wave.open(filename, "rb")
        self.nchannels, self.sampwidth, self.framerate, self.nframes = self.f.getparams()[0:4]
        print("{0} {1} {2} {3}".format(self.nchannels, self.sampwidth,
                                       self.framerate, self.nframes))
        self.data = self.f.readframes(self.nframes)
        self.f.close()
        self.frame_size = self.sampwidth * self.nchannels  # bytes per frame
        self.frame_time = 1.0 / self.framerate             # seconds per frame
        self.len = self.nframes * self.frame_size          # total audio bytes
        self.player = pyaudio.PyAudio()
        self.stream = self.player.open(format=self.player.get_format_from_width(self.sampwidth),
                channels=self.nchannels,
                rate=self.framerate,
                output=True)

    def close(self):
        """Release the output stream and the PyAudio instance.

        New helper: the original leaked the stream/player for the life of
        the process.
        """
        self.stream.stop_stream()
        self.stream.close()
        self.player.terminate()

    def play(self):
        """Write the buffered audio to the output stream, one second at a time."""
        chunk = self.frame_size * self.framerate  # bytes per second of audio
        pos = 0
        while pos + chunk < self.len:
            print("current pos -> {}\t/{}".format(pos + chunk, self.len))
            self.stream.write(self.data[pos:pos + chunk])
            pos += chunk
        # Flush the final partial chunk; the original skipped it, silently
        # truncating up to one second at the end of playback.
        self.stream.write(self.data[pos:self.len])
        print("current pos -> {}\t/{} over!".format(pos + chunk, self.len))

    def _decode(self):
        """Return the raw byte buffer as a 1-D numpy array of samples.

        24-bit audio has no native numpy dtype, so each sample is rebuilt
        from its three little-endian bytes.  Only the high byte carries the
        sign; the original decode read all three bytes as signed int8,
        corrupting any sample whose low/middle byte was >= 0x80.

        Raises ValueError for an unsupported sample width.
        """
        if self.sampwidth == 3:
            lo = np.frombuffer(self.data, dtype=np.uint8).astype(np.int32)
            hi = np.frombuffer(self.data, dtype=np.int8).astype(np.int32)
            return (hi[2::3] * 256 + lo[1::3]) * 256 + lo[0::3]
        dtype_map = {1: np.int8, 2: np.int16, 4: np.int32}
        if self.sampwidth not in dtype_map:
            raise ValueError('sampwidth %d unknown' % self.sampwidth)
        # np.fromstring is deprecated; frombuffer is the zero-copy equivalent.
        return np.frombuffer(self.data, dtype=dtype_map[self.sampwidth])

    def make_spectrum(self, start, duration):
        """Return a Spectrum of channel 0 over [start, start + duration) seconds.

        start, duration -- window position and length in seconds.
        """
        ys = self._decode()
        # De-interleave and keep channel 0.  Uses the real channel count;
        # the original hard-coded 2 channels and broke on mono files.
        ys = ys.reshape(-1, self.nchannels).T[0]
        start_frame = int(start * self.framerate)
        end_frame = int(start_frame + duration * self.framerate)
        ys = ys[start_frame:end_frame]
        n = len(ys)
        d = 1.0 / self.framerate
        fs = np.fft.rfftfreq(n, d)
        hs = np.fft.rfft(ys)
        hs = hs / (len(hs))  # normalise the amplitudes
        return Spectrum(hs, fs, self.framerate)

    def ploat(self):
        """Plot the normalised waveform of each channel with matplotlib.

        (Name kept as ``ploat`` for backward compatibility with callers.)
        """
        wave_data = self._decode()
        t = np.arange(0, self.nframes) * self.frame_time
        if self.nchannels == 1:
            wave_data = wave_data.reshape(-1, 1).T
            pyplot.subplot(1, 1, 1)
            d = wave_data[0] * 1.0 / max(abs(wave_data[0]))
            # Plot the normalised signal; the original computed ``d`` here
            # but then plotted the raw samples.
            pyplot.plot(t, d, c="darkorange")
        elif self.nchannels == 2:
            wave_data = wave_data.reshape(-1, 2).T
            pyplot.subplot(2, 1, 1)
            d = wave_data[0] * 1.0 / max(abs(wave_data[0]))
            pyplot.plot(t, d, c="darkorange")
            pyplot.subplot(2, 1, 2)
            d = wave_data[1] * 1.0 / max(abs(wave_data[1]))
            pyplot.plot(t, d, c="g")
        pyplot.xlabel("time (seconds)")


def corrcoef(xs, ys):
    """Pearson correlation coefficient between two equal-length sequences.

    Returns the off-diagonal entry of the 2x2 correlation matrix.
    """
    # np.corrcoef's ``ddof`` argument has been a deprecated no-op since
    # NumPy 1.10, so it is dropped here; the result is unchanged.
    return np.corrcoef(xs, ys)[0, 1]

def serial_corr(wave, lag=1):
    """Pearson correlation of a signal with itself shifted by ``lag`` samples."""
    length = len(wave)
    shifted = wave[lag:]
    trimmed = wave[:length - lag]
    return corrcoef(shifted, trimmed)

def autocorr(wave):
    """Return (lags, correlations) for every lag up to half the signal length."""
    half_span = len(wave) // 2
    lags = range(half_span)
    # One serial correlation per candidate lag; lag 0 is always 1.0.
    corrs = [serial_corr(wave, shift) for shift in lags]
    return lags, corrs

# def autocorr(wave):
#     lags = range(len(wave))
#     corrs = np.correlate(wave, wave, mode='same')
#     return lags, corrs

def GetMidi(hz, table=None, numbers=None):
    """Map a frequency in Hz to the MIDI number of the nearest key at or below it.

    hz      -- detected fundamental frequency.
    table   -- ascending frequency bin edges (defaults to module-level midi_hz).
    numbers -- MIDI numbers aligned with ``table`` (defaults to midi_num).

    Returns 0 when hz lies outside [table[0], table[-1]).
    """
    if table is None:
        table = midi_hz
    if numbers is None:
        numbers = midi_num
    # Binary search replaces the original linear scan, which also re-tested
    # the lower bound on every iteration.
    i = bisect.bisect_right(table, hz)
    if i == 0 or i == len(table):
        return 0
    return numbers[i - 1]

def ACFPitch(wave):
    """Detect per-frame pitch via autocorrelation and write the notes to ~/test.mid.

    wave -- a Wave instance whose raw buffer is 16-bit PCM (assumed 44100 Hz).
    """
    # np.fromstring is deprecated; frombuffer reads the same bytes zero-copy.
    wave_data = np.frombuffer(wave.data, dtype=np.int16)
    # Median filter suppresses impulse noise; detrend removes DC / linear drift.
    wave_data = signal.medfilt(wave_data, 401)
    wave_data = signal.detrend(wave_data, type='linear')
    frame_size = 4096
    frame_time = round(frame_size / 44100.0, 3)  # frame length in seconds
    # ``//`` keeps integer semantics on Python 3 (plain ``/`` would yield a
    # float and break range() below).
    step_total = len(wave_data) // frame_size

    MyMIDI = MIDIFile(1)
    MyMIDI.addTempo(0, 0, 120)
    degrees = []  # one MIDI note number per frame; 0 means no pitch detected
    for i in range(step_total - 1):
        start = i * frame_size
        lags, corrs = autocorr(wave_data[start:start + frame_size])
        width = np.arange(10, 30)
        # Drop the first peak: lag 0 always correlates perfectly.
        peaks = signal.find_peaks_cwt(corrs, width)[1:]
        data = [corrs[x] for x in peaks]
        md = 0
        if data:
            index = np.argmax(data)
            # The lag of the strongest peak is the pitch period in samples.
            hz = 1.0 / (float(lags[peaks[index]]) / 44100)
            md = GetMidi(hz)
        degrees.append(md)

    # Merge runs of identical per-frame notes into single MIDI events.
    count = 1
    t = 0  # time of frame i; renamed from ``time`` to stop shadowing the module
    for i in range(len(degrees) - 1):
        if degrees[i] == degrees[i + 1]:
            count += 1
        else:
            if degrees[i] != 0:
                # A run of ``count`` frames ending at frame i started
                # count - 1 frames earlier.  The original passed ``t`` itself
                # as the onset, which placed every multi-frame note too late.
                MyMIDI.addNote(0, 0, degrees[i],
                               t - (count - 1) * frame_time,
                               count * frame_time, 100)
            count = 1
        t += frame_time
    # Guard: degrees is empty when the input is shorter than two frames.
    if degrees and degrees[-1] != 0:
        MyMIDI.addNote(0, 0, degrees[-1],
                       t - (count - 1) * frame_time,
                       count * frame_time, 100)

    # open() does not expand "~"; resolve the home directory explicitly.
    with open(os.path.expanduser("~/test.mid"), "wb") as output_file:
        MyMIDI.writeFile(output_file)

if __name__ == "__main__":
    # open()/wave.open() do not expand "~"; resolve the home directory first.
    w = Wave(os.path.expanduser("~/test.wav"))
    w.play()
    ACFPitch(w)
    play_music(os.path.expanduser("~/test.mid"))