niteshade.utils.train_test_MNIST

niteshade.utils.train_test_MNIST(dir='datasets/', transform=None, val_size=None)

Function to load torchvision’s MNIST dataset, split into train, test, and validation sets (the latter only if val_size != None).

Parameters
  • dir (str) – Directory where the MNIST dataset files are stored. Default: 'datasets/'.

  • transform (torchvision.transforms) –

    Sequence of transformations to apply to the train and test sets. Default: transforms.Compose([torchvision.transforms.Normalize((0.1307,), (0.3081,)), transforms.ToTensor()])

  • val_size (float) – Value between 0 and 1 giving the fraction of the training set to allocate to the validation set. (Default = 0.2)
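As a rough illustration of what the default normalization does to pixel values, the sketch below reproduces ToTensor() followed by Normalize((0.1307,), (0.3081,)) with plain NumPy. It assumes uint8 images with pixels in [0, 255], ignores the channel-first layout change that ToTensor performs, and applies ToTensor first because torchvision's Normalize operates on tensors rather than PIL images. The function name normalize_mnist is hypothetical and not part of niteshade.

```python
import numpy as np

# Single-channel MNIST statistics quoted in the default transform.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize_mnist(images: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy analogue of ToTensor() then Normalize(mean, std).

    Step 1 (ToTensor-like): scale uint8 pixels from [0, 255] to float32 in [0, 1].
    Step 2 (Normalize-like): compute (x - mean) / std for the grey channel.
    The array shape is left unchanged (no HWC -> CHW reordering).
    """
    if images.dtype != np.uint8:
        raise TypeError("expected uint8 pixel values in [0, 255]")
    scaled = images.astype(np.float32) / 255.0
    return (scaled - MNIST_MEAN) / MNIST_STD


if __name__ == "__main__":
    # One black pixel and one white pixel.
    batch = np.array([[0, 255]], dtype=np.uint8)
    print(normalize_mnist(batch))
```

A black pixel (0) maps to about -0.424 and a white pixel (255) to about 2.821, so normalized inputs are centred near zero rather than lying in [0, 1].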

Returns

  • X_train (np.ndarray) – Train inputs

  • y_train (np.ndarray) – Train labels

  • X_test (np.ndarray) – Test inputs

  • y_test (np.ndarray) – Test labels

  • X_val (np.ndarray) – Validation inputs (only if val_size != None)

  • y_val (np.ndarray) – Validation labels (only if val_size != None)

Return type

  tuple of np.ndarray
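The return convention can be sketched with a hypothetical helper, split_like_train_test_mnist, that takes already-loaded train and test arrays and returns them in the documented order, carving a val_size fraction out of the training set when val_size is given. The seeded random split is an assumption made for illustration; niteshade's own implementation may divide the data differently.

```python
from typing import Optional, Tuple

import numpy as np


def split_like_train_test_mnist(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    val_size: Optional[float] = None,
    seed: int = 0,
) -> Tuple[np.ndarray, ...]:
    """Hypothetical helper mirroring the documented return order.

    val_size is None -> (X_train, y_train, X_test, y_test)
    otherwise        -> (X_train, y_train, X_test, y_test, X_val, y_val)
    """
    if val_size is None:
        return X_train, y_train, X_test, y_test
    if not 0.0 < val_size < 1.0:
        raise ValueError("val_size must lie strictly between 0 and 1")

    # Assumption: shuffle training rows with a fixed seed, then move the
    # first val_size fraction of them into the validation set.
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(X_train))
    n_val = int(round(val_size * len(X_train)))
    val_idx, train_idx = order[:n_val], order[n_val:]

    # Indexing inputs and labels with the same indices keeps pairs aligned.
    return (
        X_train[train_idx], y_train[train_idx],
        X_test, y_test,
        X_train[val_idx], y_train[val_idx],
    )
```

With the real function, the six-value form would be unpacked as X_train, y_train, X_test, y_test, X_val, y_val = train_test_MNIST(val_size=0.2), and the four-value form applies when val_size is left as None.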