niteshade.utils.train_test_cifar

niteshade.utils.train_test_cifar(dir='datasets/', transform=None, val_size=None)

Function to load torchvision's CIFAR10 dataset, split into train, test, and validation sets (the latter only if val_size is not None).

Parameters
  • dir (str) – Directory in which the CIFAR10 dataset is stored. Default: 'datasets/'.

  • transform (torchvision.transforms) – Sequence of transformations to apply to the train and test sets. Default: transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()]).

  • val_size (float) – Value between 0 and 1 giving the fraction of the training set to allocate to the validation set. Default: None, in which case no validation set is created.
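The effect of a fractional val_size can be illustrated with a minimal pure-Python sketch. The helper split_train_val and its seed argument are hypothetical and not part of niteshade. Shuffling and then slicing is an assumed strategy, and the library may partition the data differently. The 50,000 figure is the size of the CIFAR10 training set.

```python
import random


def split_train_val(n_train, val_size=None, seed=0):
    """Hypothetical helper (not niteshade's code): partition the indices
    0..n_train-1 into (train_idx, val_idx) according to a fraction val_size."""
    indices = list(range(n_train))
    if val_size is None:
        # No validation split requested: every sample stays in the training set.
        return indices, []
    if not 0 < val_size < 1:
        raise ValueError("val_size must be between 0 and 1")
    # Assumed strategy: shuffle reproducibly, then slice off the validation part.
    rng = random.Random(seed)
    rng.shuffle(indices)
    n_val = int(round(val_size * n_train))
    # The first n_val shuffled indices go to validation, the rest to training.
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_train_val(50_000, val_size=0.2)
print(len(train_idx), len(val_idx))  # 40000 10000
```

With val_size=0.2, one fifth of the 50,000 training images would move to the validation set, leaving 40,000 for training.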

Returns

  • X_train (torch.Tensor) – Train inputs

  • y_train (torch.Tensor) – Train labels

  • X_test (torch.Tensor) – Test inputs

  • y_test (torch.Tensor) – Test labels

  • X_val (torch.Tensor) – Validation inputs (only if val_size is not None)

  • y_val (torch.Tensor) – Validation labels (only if val_size is not None)

Return type

tuple of torch.Tensor
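Because the number of returned tensors depends on val_size, the sketch below names them explicitly. It assumes the outputs come back as a tuple in the order listed under Returns. The helper unpack_splits is hypothetical and not part of niteshade, and placeholder strings stand in for the real tensors.

```python
def unpack_splits(splits):
    """Hypothetical helper: map the outputs of train_test_cifar to names,
    assuming they are returned as a tuple in the documented order."""
    names = ["X_train", "y_train", "X_test", "y_test"]
    if len(splits) == 6:
        # A validation split was requested (val_size is not None).
        names += ["X_val", "y_val"]
    elif len(splits) != 4:
        raise ValueError(f"expected 4 or 6 outputs, got {len(splits)}")
    return dict(zip(names, splits))


# Placeholder strings stand in for the torch.Tensor objects.
parts = unpack_splits(("Xtr", "ytr", "Xte", "yte", "Xv", "yv"))
print(parts["X_val"])  # Xv
```

Under the same ordering assumption, a direct call with a validation split would unpack as X_train, y_train, X_test, y_test, X_val, y_val = train_test_cifar(val_size=0.2). Without val_size, it would return only the first four tensors.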