niteshade.utils.train_test_cifar
- niteshade.utils.train_test_cifar(dir='datasets/', transform=None, val_size=None)
Function to load torchvision's CIFAR10 dataset, split into train and test sets and, if val_size is not None, an additional validation set.
- Parameters
transform (torchvision.transforms) – Sequence of transformations to apply to the train and test sets. Default: None, in which case transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()]) is used.
val_size (float, optional) – Value between 0 and 1 giving the fraction of the training set to allocate to the validation set. Default: None (no validation set is created).
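The effect of val_size can be illustrated with a small, library-independent sketch. It assumes (this is not confirmed by the niteshade source) that the validation set is a random subset of round(val_size * N) training examples and that the rest stays in the training set. `split_indices` is a hypothetical helper, not part of niteshade. CIFAR10's training split has 50,000 images, so val_size=0.2 would leave 40,000 images for training and 10,000 for validation.

```python
import random
from typing import List, Optional, Tuple


def split_indices(n: int, val_size: Optional[float] = None,
                  seed: int = 0) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: partition range(n) into train and validation indices.

    Assumes val_size is the fraction of the n training examples moved into
    the validation set; None means no validation set is created.
    """
    if val_size is None:
        return list(range(n)), []
    if not 0 < val_size < 1:
        raise ValueError("val_size must be strictly between 0 and 1")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # reproducible random shuffle
    n_val = round(val_size * n)           # assumed rounding rule
    return indices[n_val:], indices[:n_val]


# CIFAR10 has 50,000 training images.
train_idx, val_idx = split_indices(50_000, val_size=0.2)
print(len(train_idx), len(val_idx))  # 40000 10000
```

The index lists can be applied directly to tensors, e.g. `X_train[train_idx]`, since PyTorch supports indexing a tensor with a list of integers.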
- Returns
X_train (torch.Tensor) : Train inputs
y_train (torch.Tensor) : Train labels
X_test (torch.Tensor) : Test inputs
y_test (torch.Tensor) : Test labels
X_val (torch.Tensor) : Validation inputs (only if val_size is not None)
y_val (torch.Tensor) : Validation labels (only if val_size is not None)
- Return type
tuple of torch.Tensor
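Because the number of returned tensors depends on val_size (four without a validation set, six with one), callers have to unpack the result accordingly. The sketch below uses a hypothetical `unpack_splits` helper, not part of niteshade, and placeholder strings in place of tensors so that it runs without downloading CIFAR10. It assumes the return order documented above.

```python
from typing import Any, Dict, Sequence


def unpack_splits(result: Sequence[Any]) -> Dict[str, Any]:
    """Hypothetical helper: name the values returned by train_test_cifar.

    Assumes the documented order: X_train, y_train, X_test, y_test,
    followed by X_val, y_val only when val_size is not None.
    """
    names = ["X_train", "y_train", "X_test", "y_test"]
    if len(result) == 6:
        names += ["X_val", "y_val"]
    elif len(result) != 4:
        raise ValueError(f"expected 4 or 6 values, got {len(result)}")
    return dict(zip(names, result))


# Placeholders stand in for the tensors so the example runs offline.
without_val = unpack_splits(("Xtr", "ytr", "Xte", "yte"))
with_val = unpack_splits(("Xtr", "ytr", "Xte", "yte", "Xv", "yv"))
print(with_val["X_val"])  # Xv
```

With niteshade installed, the same helper could wrap a real call such as `unpack_splits(train_test_cifar(val_size=0.2))`.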