# ECE3001 Project Part2: Speaker Identification

Updated on Mar. 26th, 2024

Assignment Maker: Yuejiao Xie, Lai Wei.

This notebook is recommended to run on Google Colab.
If you want to run it on your local machine, please make sure that the required packages are installed and that the datasets are downloaded and placed in the correct directory.

Follow the PDF description to complete the project.

## 1. Prepare the dataset & Install packages

## 1.1 Mount the Google Drive to this notebook.

**Note: You need to run this subsection (1.1) EVERY time you start this notebook.**

In [None]:
# Mount the user's Google Drive at /content/drive so files written there
# remain available through Drive (Colab-only: `google.colab` is not available elsewhere).
from google.colab import drive
drive.mount('/content/drive')

# Create the project folder "ECE3001_Project" inside the mounted Drive.
# `-p` makes this a no-op when the folder already exists.
!mkdir -p /content/drive/MyDrive/ECE3001_Project

# Show the current working directory; the relative dataset paths used later
# (e.g. ./stu_dataset/train) are resolved against it.
!pwd

## 1.2 Download and unzip the datasets

**Make sure your Google Drive has around 5GB free space.**

Run the following cell to:

Download the dataset to the Google Drive. It may take 1-2 minutes.

Unzip the dataset. It may take 3-4 minutes.


In [None]:
import os

# Download the dataset archive only when it is not already stored on the
# mounted Google Drive, so repeated runs can skip the download.
if os.path.exists('/content/drive/MyDrive/ECE3001_Project/stu_dataset.zip'):
  print('The datasets are already at the Google Drive, skip this step...')
else:
  print('The datasets do NOT exist at the Google Drive, Downloading...')
  # Dataset link: https://drive.google.com/file/d/1-0YkEZ3-PXPR5Vfvnpei96kfAb0f3w3Y/view?usp=sharing
  # Fetch the file by its Drive id into the project folder.
  !gdown --id 1-0YkEZ3-PXPR5Vfvnpei96kfAb0f3w3Y -O /content/drive/MyDrive/ECE3001_Project/

# The archive is always extracted into /content (presumably because that
# directory does not persist between sessions -- TODO confirm).
# `-o` overwrites existing files without prompting, `-q` suppresses output.
print('Unzipping the datasets...')
!unzip -o -q "/content/drive/MyDrive/ECE3001_Project/stu_dataset.zip" -d "/content"

## 2. Import packages

In [None]:
import torch
import argparse
from tqdm import tqdm
import datetime
import time
# from timm.utils import accuracy

import librosa
import numpy as np
import librosa.display
import math
import os
from torch.utils.data import Dataset, DataLoader
import IPython.display as ipd

import torch.nn as nn
from torchvision.models import vgg11, vgg11_bn, vgg13
from torchvision.models import resnet18

## 3. Define datasets & preprocessing helper functions

Hint: In most cases, you don't need to modify this part. Just run it.

In [None]:
# Loading data
class AudioDataset(Dataset):
    """Dataset that serves one feature matrix and one speaker label per file.

    Every entry of ``data_dir`` is treated as one sample. The label is parsed
    from the file name, which is expected to look like ``id<label>_<rest>``
    (e.g. ``id1_filename.wav`` -> label ``1``).

    Args:
        data_dir: directory that contains the audio files.
        max_len: number of rows every feature matrix is padded/truncated to.
        window_length: window length forwarded to the feature extractor.
        window_shift: hop length forwarded to the feature extractor.
        use_stft: forwarded to the feature extractor; selects log-STFT
            features (True) or raw framed samples (False).
    """

    def __init__(self, data_dir, max_len, window_length, window_shift, use_stft):
        self.data_dir = data_dir
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.file_list = sorted(os.listdir(data_dir))
        self.max_len = max_len
        self.window_shift = window_shift
        self.window_length = window_length
        self.use_stft = use_stft

    def __len__(self):
        """Return the number of files found in ``data_dir``."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for the file at position ``idx``.

        ``features`` is the extracted feature matrix with a leading channel
        axis of size 1; ``label`` is a one-element integer tensor.
        """
        filename = os.path.join(self.data_dir, self.file_list[idx])
        wav_data = extract_hpss_features_sg(filename, max_length=self.max_len, window_length=self.window_length, window_shift=self.window_shift, use_stft=self.use_stft)
        wav_data = torch.tensor(wav_data)
        wav_data = wav_data.unsqueeze(0)  # add the channel axis

        # Parse label from filename (filename format: id1_filename.wav):
        # take the part before the first '_' and drop the leading "id".
        label = self.file_list[idx].split('_')[0][2:]
        label = torch.tensor([int(label)])

        return wav_data, label

def extract_hpss_features_sg(wav_path, max_length, window_length=320, window_shift=160, use_stft=True):
    """Load one audio file and convert it to a fixed-length feature matrix.

    Despite its name, this function performs no harmonic/percussive
    separation: it computes either a log-magnitude spectrogram or a plain
    framing of the raw samples, normalises it with ``norm`` and pads or
    truncates it to ``max_length`` rows with ``pad_trunc_seq``.

    Args:
      wav_path: string, path of the audio file.
      max_length: integer, number of rows (frames) of the returned matrix.
      window_length: integer, STFT window / FFT size (used when use_stft).
      window_shift: integer, STFT hop length (used when use_stft).
      use_stft: bool, if True use the log-magnitude STFT, otherwise split the
                raw waveform into consecutive 256-sample frames.

    Returns:
      ndarray with ``max_length`` rows; one row per frame.

    Raises:
      ValueError: if the file contains no samples.
    """
    (audio, sr) = read_audio(wav_path)

    if audio.shape[0] == 0:
        print("File %s is corrupted!" % wav_path)
        raise ValueError("File %s is corrupted!" % wav_path)

    if use_stft:  # compute stft
        # The small offset keeps log() finite for zero-magnitude bins.
        spec = np.log(get_spectrogram(audio, window_length, window_shift) + 1e-8)
    else:  # not use stft
        frame = 256
        split_num = audio.shape[0] // frame
        # Drop the incomplete tail and lay the frames out as columns:
        # the result has shape (frame, split_num). Unlike np.split, reshape
        # also works when the audio is shorter than one frame (split_num == 0).
        spec = audio[:split_num * frame].reshape(split_num, frame).T

    # With zero frames there is nothing to normalise; padding below then
    # yields an all-zero feature matrix.
    if spec.shape[1] > 0:
        spec = norm(spec)
    spec = spec.T  # rows become frames so padding/truncation acts on time
    spec = pad_trunc_seq(spec, max_length)

    return spec

def read_audio(path, target_fs=None):
    """Read an audio file as a mono waveform.

    Args:
      path: string, path of the audio file.
      target_fs: integer or None, sample rate to resample to. None keeps the
                 file's own sample rate.

    Returns:
      (audio, fs): the samples as an ndarray and their sample rate.

    Raises:
      Whatever ``librosa.load`` raises; the offending path is printed first.
    """
    try:
        audio, fs = librosa.load(path, sr=None)  # sr=None: keep native sample rate
    except Exception:
        # Report which file failed, then let the original error propagate
        # instead of continuing with undefined variables.
        print(path)
        raise

    # More than one dimension means multi-channel audio. librosa lays such
    # audio out as (channels, samples), so average over the channel axis to
    # obtain a mono signal.
    if audio.ndim > 1:
        audio = np.mean(audio, axis=0)
    if target_fs is not None and fs != target_fs:
        # Resample the signal to the requested sample rate.
        audio = librosa.resample(audio, orig_sr=fs, target_sr=target_fs)
        fs = target_fs
    return audio, fs

def pad_trunc_seq(x, max_len):
    """Pad or truncate a sequence data to a fixed length.

    Padding appends zero rows along the first axis; truncation keeps the
    first ``max_len`` rows.

    Args:
      x: ndarray, input sequence data.
      max_len: integer, length of sequence to be padded or truncated.

    Returns:
      ndarray, Padded or truncated input sequence data, with the same dtype
      as ``x`` in both cases.
    """
    n_rows = x.shape[0]
    if n_rows >= max_len:
        return x[0:max_len]

    # Build the padding with x's dtype: a default (float64) zero block would
    # silently upcast float32 features, making the output dtype depend on
    # whether a given sample happened to be padded or truncated.
    pad = np.zeros((max_len - n_rows,) + x.shape[1:], dtype=x.dtype)
    return np.concatenate((x, pad), axis=0)

def get_spectrogram(wav, win_length, win_shift):
    """Return the magnitude of the Hamming-windowed STFT of ``wav``.

    ``win_length`` is used both as the FFT size and the window length;
    ``win_shift`` is the hop between consecutive frames.
    """
    stft_matrix = librosa.stft(
        wav,
        n_fft=win_length,
        hop_length=win_shift,
        win_length=win_length,
        window='hamming',
    )
    # magphase splits the complex STFT into magnitude and phase; only the
    # magnitude is needed here.
    magnitude, _ = librosa.magphase(stft_matrix)
    return magnitude


def norm(spec):
    """Standardise every row of a 2-D array to zero mean and unit variance.

    Statistics are computed along axis 1, i.e. independently for each row.
    Rows with zero variance (constant rows) are only mean-centred, so they
    come out as zeros instead of NaN/inf.

    Args:
      spec: 2-D ndarray.

    Returns:
      ndarray of the same shape as ``spec``.
    """
    mean = np.mean(spec, axis=1, keepdims=True)
    std = np.std(spec, axis=1, keepdims=True)
    # Avoid 0/0 for constant rows: dividing their (all-zero) centred values
    # by 1 leaves them at zero and does not affect any other row.
    std[std == 0] = 1
    # keepdims=True lets broadcasting replace the explicit np.repeat.
    return (spec - mean) / std

def get_one_wave(filename, args):
    """Turn a single wave file into a ready-to-use model input.

    Args:
        filename: string, the path of the wave file.
        args: the args object; its ``max_len``, ``window_length``,
            ``window_shift`` and ``use_stft`` fields drive feature extraction.

    Returns:
        float32 tensor holding the features behind two leading axes of size 1.
    """
    features = extract_hpss_features_sg(
        filename,
        max_length=args.max_len,
        window_length=args.window_length,
        window_shift=args.window_shift,
        use_stft=args.use_stft,
    )
    # Two leading singleton axes are added so the model receives
    # (batch, channel, time, freq).
    batch = torch.tensor(features).unsqueeze(0).unsqueeze(0)
    return batch.to(dtype=torch.float32)

# Copied from timm/utils/metrics.py
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (in percent) for every k listed in ``topk``.

    ``output`` holds one row of class scores per sample and ``target`` the
    true class indices. Values of k larger than the number of classes are
    clamped to the number of classes.
    """
    largest_k = min(max(topk), output.size()[1])
    n_samples = target.size(0)

    # Indices of the highest-scoring classes, transposed to (k, n_samples)
    # so that row r holds every sample's (r+1)-th best guess.
    top_idx = output.topk(largest_k, 1, True, True)[1].t()
    hits = top_idx.eq(target.reshape(1, -1).expand_as(top_idx))

    scores = []
    for k in topk:
        rows = min(k, largest_k)
        scores.append(hits[:rows].reshape(-1).float().sum(0) * 100. / n_samples)
    return scores


## 4. Define models

Hint: In most cases, you don't need to modify this part. Just run it.

But you can try to modify the model structure to improve the performance if you want.

In [None]:
# Loading Model
class vgg_base(nn.Module):
    """Convolutional feature extractor of an untrained VGG-11.

    Args:
        input_dim: number of channels of the input.
    """

    def __init__(self, input_dim):
        super().__init__()
        # The constructor default already builds the network without
        # pretrained weights; the explicit `pretrained=False` keyword is
        # deprecated in recent torchvision releases.
        self.vggmodel = vgg11().features
        # Replace the first convolution so the network accepts `input_dim`
        # input channels.
        self.vggmodel[0] = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the VGG feature maps for ``x``."""
        x = self.vggmodel(x)
        return x

class vggbn_base(nn.Module):
    """Convolutional feature extractor of an untrained VGG-11 with batch norm.

    Args:
        input_dim: number of channels of the input.
    """

    def __init__(self, input_dim):
        super().__init__()
        # The constructor default already builds the network without
        # pretrained weights; the explicit `pretrained=False` keyword is
        # deprecated in recent torchvision releases.
        self.vggmodel = vgg11_bn().features
        # Replace the first convolution so the network accepts `input_dim`
        # input channels.
        self.vggmodel[0] = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the VGG feature maps for ``x``."""
        x = self.vggmodel(x)
        return x


class resnet_base(nn.Module):
    """Untrained ResNet-18 whose first convolution accepts ``input_dim`` channels.

    The full network is kept, including its final classification layer.

    Args:
        input_dim: number of channels of the input.
    """

    def __init__(self, input_dim):
        super().__init__()
        # The constructor default already builds the network without
        # pretrained weights; the explicit `pretrained=False` keyword is
        # deprecated in recent torchvision releases.
        self.resnetmodel = resnet18()
        # Replace the stem convolution so the network accepts `input_dim`
        # input channels.
        self.resnetmodel.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)

    def forward(self, x):
        """Return the ResNet output for ``x``."""
        x = self.resnetmodel(x)
        return x


class My_model(nn.Module):
    """Speaker classifier: a CNN backbone followed by one linear layer.

    Args:
        input_dim: number of channels of the input features.
        num_classes: number of speakers to distinguish.
        model_base: backbone to use, one of "vgg", "vggbn" or "resnet".

    Raises:
        ValueError: if ``model_base`` is not one of the supported names.
    """

    def __init__(self, input_dim=1, num_classes=93, model_base="vgg"):
        super().__init__()
        if model_base == "vgg":
            self.backbone = vgg_base(input_dim)
            self.linear = nn.Linear(in_features=512, out_features=num_classes)
        elif model_base == "vggbn":
            self.backbone = vggbn_base(input_dim)
            self.linear = nn.Linear(in_features=512, out_features=num_classes)
        elif model_base == "resnet":
            self.backbone = resnet_base(input_dim)
            self.linear = nn.Linear(in_features=1000, out_features=num_classes)
        else:
            raise ValueError("model_base should be vgg, vggbn or resnet")

        self.model_base = model_base
        # NOTE(review): the VGG branch only yields the 512 features expected
        # by `self.linear` when the flattened feature map has exactly 200
        # positions -- presumably 25 x 8 for the default 800 x 256 input;
        # confirm before changing max_len or the feature size.
        self.avgpool = nn.AvgPool1d(kernel_size=200, stride=1)
        self.activate = nn.Softmax(dim=1)
        self.criteria = nn.CrossEntropyLoss()

    def forward(self, input, label=None):
        """Classify a batch of feature maps.

        Args:
            input: batch of input features for the backbone.
            label: optional tensor of target class indices. When given, the
                loss is computed as well.

        Returns:
            ``(loss, probabilities, pred_label)`` when ``label`` is given,
            otherwise ``(probabilities, pred_label)``.
        """
        features = self.backbone(input)

        if self.model_base in ["vgg", "vggbn"]:
            # Flatten the spatial axes and average over them to get one
            # feature vector per sample.
            features = features.view(features.size(0), features.size(1), -1)
            features = self.avgpool(features)
            features = features.reshape(features.size(0), -1)

        logits = self.linear(features)
        result = self.activate(logits)

        _, pred_label = result.max(-1)

        if label is not None:  # train
            # CrossEntropyLoss applies log-softmax itself, so it must get the
            # raw logits; feeding it the softmax output would squash the
            # scores twice and weaken the gradients.
            loss = self.criteria(logits, label.view(-1))
            return loss, result, pred_label
        else:  # test
            return result, pred_label


## 5. Define procedures for training and testing

Hint: In most cases, you don't need to modify this part. Just run it.

But if you want to trace the training process (e.g., save validation accuracy at each epoch), you can modify the `train` function.

In [None]:
def valid(args, model):
    """Evaluate ``model`` on the dataset at ``args.test_path``.

    Args:
        args: namespace providing ``test_path``, ``max_len``,
            ``window_length``, ``window_shift``, ``use_stft``, ``batchsize``
            and ``device``.
        model: the model to evaluate; it is switched to eval mode.

    Returns:
        ``(acc1, acc5)``: top-1 and top-5 accuracy as fractions in [0, 1],
        averaged over all samples.
    """
    audio_testset = AudioDataset(args.test_path, args.max_len, args.window_length, args.window_shift, args.use_stft)
    test_data = DataLoader(audio_testset, batch_size=args.batchsize, shuffle=False, num_workers=2)

    model.eval()
    correct1_total = 0.
    correct5_total = 0.
    sample_total = 0

    print("[Valid] : Start validation...")
    with torch.no_grad():
        for x, label in tqdm(test_data):
            x = x.to(dtype=torch.float32, device=args.device)
            label = label.to(args.device)
            result, pred = model(x)
            acc1, acc5 = accuracy(result, label.view(-1), topk=(1, 5))
            # Weight each batch by its size: the last batch may be smaller,
            # and a plain mean of per-batch accuracies would over-weight it.
            batch_len = label.size(0)
            correct1_total += acc1.item() / 100 * batch_len
            correct5_total += acc5.item() / 100 * batch_len
            sample_total += batch_len

    # An empty dataset reports 0 accuracy instead of dividing by zero.
    denom = max(sample_total, 1)
    acc1_avg = correct1_total / denom
    acc5_avg = correct5_total / denom

    print("[Valid] : Valid_acc1:{}, Valid_acc5: {}".format(acc1_avg, acc5_avg))
    return acc1_avg, acc5_avg


def train(args, model):
    """Train ``model`` and keep the checkpoint with the best validation top-1.

    After every epoch the model is evaluated with ``valid``; whenever the
    top-1 accuracy improves, the state dict is written to
    ``<save_path>/<model_base>_<w|wo>_stft_best.pt``.

    Args:
        args: namespace providing ``train_path``, ``test_path``,
            ``save_path``, ``epochs``, ``print_every``, ``batchsize``, ``lr``,
            ``model_base``, ``device`` and the feature parameters
            ``max_len``, ``window_length``, ``window_shift``, ``use_stft``.
        model: the model to optimise (already moved to ``args.device``).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    best_acc1 = 0

    # Create the checkpoint directory up front so a missing folder is
    # reported before training starts rather than after the first epoch.
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)

    audio_trainset = AudioDataset(args.train_path, args.max_len, args.window_length, args.window_shift, args.use_stft)
    print(f"[Train] : Length of training set: {len(audio_trainset)}")
    train_data = DataLoader(audio_trainset, batch_size=args.batchsize, shuffle=True, drop_last=True, num_workers=2)

    print("[Train] : Start training...")
    for epoch in range(args.epochs):
        print(f"[Train] : Training epoch {epoch}...")
        model.train()  # valid() leaves the model in eval mode
        acc1_total = 0.
        acc5_total = 0.
        loss_total = 0.

        tbar = tqdm(train_data)
        for step, (x, label) in enumerate(tbar):
            x = x.to(dtype=torch.float32, device=args.device)
            label = label.to(args.device)
            optimizer.zero_grad()
            loss, result, pred = model(x, label)
            acc1, acc5 = accuracy(result, label.view(-1), topk=(1, 5))

            acc1, acc5 = acc1.item() / 100, acc5.item() / 100
            loss.backward()
            optimizer.step()
            acc1_total += acc1
            acc5_total += acc5
            loss_total += float(loss.item())

            # Show running averages over the epoch so far.
            if step % args.print_every == 0:
                tbar.set_postfix_str('epoch %d, step %d, step_loss %.4f, step_acc1 %.4f, step_acc5 %.4f' %
                                    (epoch, step, loss_total/(step+1), acc1_total/(step+1), acc5_total/(step+1)))

        acc1, acc5 = valid(args, model)
        # Hint: you may want to save the acc1 and acc5 for each epoch and plot them later
        # so you need to create a list before training loop, and append the acc1 and acc5 to the list
        # finally, you can return the list for plotting

        if acc1 > best_acc1:
            best_acc1 = acc1
            pt_filename = f"{args.model_base}_{'w' if args.use_stft else 'wo'}_stft_best.pt"
            # os.path.join copes with save_path with or without a trailing slash.
            pt_filename = os.path.join(args.save_path, pt_filename)
            print('Achieve best acc1: %.4f, acc5: %.4f, epoch: %d. Saving to `%s`...' % (acc1, acc5, epoch, pt_filename))
            torch.save(model.state_dict(), pt_filename)




## 6. Start model training!

For this project, you need to at least run two experiments: one with `use_stft=True` and the other one with `use_stft=False`.

We recommend changing the `model_base` to other models (e.g., `vggbn`) to see whether the performance can be improved.

In [None]:
# Use the GPU when one is available. On a CPU-only runtime the cell stops
# here (after `device` has been set) because training would be far too slow.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
    raise Warning("[From ECE3001 TA:] Your device is CPU, training will be very very slow. Please change the runtime to GPU.")

# All settings for one training run, collected in a single namespace.
args = argparse.Namespace(
    device=device,                    # device, do not change

    ## path parameters
    train_path='./stu_dataset/train',  # training dataset path, do not change usually
    test_path='./stu_dataset/test',    # testing dataset path, do not change usually
    save_path='/content/drive/MyDrive/ECE3001_Project/',

    ## training parameters, you may change some of them
    model_base="resnet",              # model base: vgg, resnet, vggbn
    epochs=20,                        # number of epochs to train
    batchsize=16,                     # batch size
    lr=1e-4,                          # learning rate
    print_every=10,                   # print training information every print_every steps

    ## data processing parameters, you may change some of them
    use_stft=True,                    # whether to use stft
    window_length=510,                # window length
    window_shift=256,                 # hop shift
    max_len=800,                      # max length of audio
)


# Build the classifier, move it to the chosen device and start training.
model = My_model(num_classes=92, model_base=args.model_base).to(dtype=torch.float32, device=args.device)

train(args, model)

## 7. Model evaluation

Here, you can directly get the prediction results from the trained model.

Suppose you have trained the model and saved it (e.g., `resnet_w_stft_best.pt`).

You can set the `model_base` and the trained `model_path` below!

In [None]:
# Settings for evaluating a saved checkpoint. The data processing parameters
# must match the ones the checkpoint was trained with.
test_args = argparse.Namespace(
    device=device,                  # device, do not change
    batchsize=16,

    ## path parameters
    train_path='./stu_dataset/train', # training dataset path, do not change usually
    test_path='./stu_dataset/test', # testing dataset path, do not change usually
    save_path='/content/drive/MyDrive/ECE3001_Project/',

    ## data processing parameters, you may change some of them
    max_len=800,                    # max length of audio
    window_shift=256,               # hop shift
    window_length=510,              # window length
    use_stft=True,                  # whether to use stft


    model_base="resnet",            # model base: vgg, resnet, vggbn
    ## The model to be evaluated's path.
    model_path='resnet_w_stft_best.pt'
)

model = My_model(num_classes=92, model_base=test_args.model_base)
# map_location remaps the stored tensors to the current device, so a
# checkpoint saved on a GPU can also be loaded on a CPU-only runtime.
# os.path.join copes with save_path with or without a trailing slash.
model.load_state_dict(
    torch.load(
        os.path.join(test_args.save_path, test_args.model_path),
        map_location=test_args.device,
    )
)
model.eval()

print('Model (type "%s") loaded from `%s` successfully!' % (test_args.model_base, test_args.model_path))

# Simply run validation on the testing dataset
model = model.to(device)
valid(test_args, model)