niteshade.defence.KNN_Defender
- class niteshade.defence.KNN_Defender(init_x, init_y, nearest_neighbours: int, confidence_threshold: float, one_hot=False)
Bases: niteshade.defence.PointModifierDefender
A KNN-based defender, inheriting from PointModifierDefender, that flips the label of an input point to the most frequent label among its nearest neighbours when the proportion of neighbours carrying that label exceeds a threshold. A scikit-learn KNeighborsClassifier is used to find the nearest neighbours. The KNN_Defender implements a defence strategy discussed by Paudice, Andrea, et al. “Label Sanitization against Label Flipping Poisoning Attacks.” 2018.
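The relabelling rule for a single point can be sketched in plain Python. This is a minimal illustration, not niteshade's code. The helper name `sanitize_label` is hypothetical, and the sketch assumes the strict "exceeds" comparison described above. Ties between equally frequent labels are broken by first occurrence, which is how `collections.Counter.most_common` behaves; the library's tie-breaking may differ.

```python
from collections import Counter
from typing import Hashable, Sequence


def sanitize_label(neighbour_labels: Sequence[Hashable],
                   current_label: Hashable,
                   confidence_threshold: float) -> Hashable:
    """Return the (possibly flipped) label of one point.

    Hypothetical helper for illustration; not part of niteshade.
    """
    # Most frequent label among the neighbours and how many neighbours carry it.
    top_label, top_count = Counter(neighbour_labels).most_common(1)[0]
    proportion = top_count / len(neighbour_labels)
    # Strict comparison, matching "exceeds a threshold" in the description.
    if proportion > confidence_threshold:
        return top_label
    # Neighbourhood not confident enough: keep the incoming label.
    return current_label
```

For example, `sanitize_label([1, 1, 1, 0, 1], 0, 0.6)` returns `1`, because 4 of the 5 neighbours (a proportion of 0.8) carry label 1.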
- Parameters
init_x (np.ndarray, torch.Tensor) – point data (shape (batch_size, data dimensionality))
init_y (np.ndarray, torch.Tensor) – label data (shape (batch_size,))
nearest_neighbours (int) – number of nearest neighbours to consult for each point
confidence_threshold (float) – proportion that the most frequent label among the neighbours must exceed for a point's label to be flipped
one_hot (bool) – whether the labels are one-hot encoded (default False)
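Because the proportion is computed over nearest_neighbours points, confidence_threshold effectively fixes how many neighbours must agree before a label is flipped. The hypothetical helper below (not part of niteshade) finds that minimum count, again assuming the strict "proportion > threshold" comparison. With nearest_neighbours=5 and confidence_threshold=0.6, 3 agreeing neighbours (exactly 0.6) are not enough, so at least 4 are needed.

```python
from typing import Optional


def min_agreeing_neighbours(nearest_neighbours: int,
                            confidence_threshold: float) -> Optional[int]:
    """Smallest number of neighbours sharing the top label that triggers a flip.

    Hypothetical helper; assumes the rule "count / nearest_neighbours > threshold".
    Returns None when no count can exceed the threshold (threshold >= 1).
    """
    # Try every possible vote count from 1 up to k and return the first that passes.
    for count in range(1, nearest_neighbours + 1):
        if count / nearest_neighbours > confidence_threshold:
            return count
    return None
```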
- __init__(init_x, init_y, nearest_neighbours: int, confidence_threshold: float, one_hot=False) → None
Constructor method of the KNN_Defender class. If the labels are one-hot encoded, they are converted to integer class labels so that the scikit-learn classifier can be used.
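One common way to turn one-hot labels into integer class labels is an argmax over the class axis. The sketch below shows that conversion and its inverse with NumPy. The function names are hypothetical, and whether niteshade uses exactly this mapping internally is an assumption.

```python
import numpy as np


def one_hot_to_int(labels: np.ndarray) -> np.ndarray:
    """Map one-hot rows of shape (batch_size, num_classes) to class ids of shape (batch_size,)."""
    # The position of the 1 in each row is the class id.
    return np.argmax(labels, axis=1)


def int_to_one_hot(class_ids: np.ndarray, num_classes: int) -> np.ndarray:
    """Inverse mapping: integer class ids back to one-hot rows (float dtype)."""
    # Indexing the identity matrix by class id selects the matching one-hot row.
    return np.eye(num_classes)[class_ids]
```

If relabelled output has to be returned in one-hot form again, int_to_one_hot restores the original layout.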
Methods
__init__(init_x, init_y, nearest_neighbours, ...) – Constructor method of the KNN_Defender class.
defend(datapoints, input_labels, **kwargs) – The defend method for the KNN_Defender.
- defend(datapoints, input_labels, **kwargs)
- The defend method for the KNN_Defender.
For each incoming point, its nearest neighbours and their labels are found. If the proportion of those neighbours carrying the most frequent label is higher than confidence_threshold, the point's label is flipped to that most frequent label; otherwise it is left unchanged.
- Parameters
datapoints (np.ndarray, torch.Tensor) – point data (shape (batch_size, data dimensionality)).
input_labels (np.ndarray, torch.Tensor) – label data (shape (batch_size,)).
- Returns
datapoints (np.ndarray, torch.Tensor) – point data (shape (batch_size, data dimensionality))
flipped_labels (np.ndarray, torch.Tensor) – modified label data (shape (batch_size,))
- Return type
tuple (datapoints, flipped_labels)
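The defend step can be approximated end to end by the following sketch. It is a hypothetical re-implementation for illustration, not niteshade's source. It assumes integer (non-one-hot) labels and NumPy inputs, fits a scikit-learn KNeighborsClassifier on init_x and init_y, queries neighbour indices with `kneighbors(..., return_distance=False)`, and applies the strict threshold rule. Any updating of the reference data with incoming points, which the real defender may or may not perform, is omitted.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier


def knn_defend_sketch(init_x, init_y, datapoints, input_labels,
                      nearest_neighbours=5, confidence_threshold=0.6):
    """Flip labels of points whose neighbourhood strongly agrees on another label.

    Hypothetical sketch; expects non-negative integer labels and NumPy arrays.
    """
    # Flatten each reference point to a feature vector and fit the neighbour index.
    ref_x = np.asarray(init_x).reshape(len(init_x), -1)
    ref_y = np.asarray(init_y)
    clf = KNeighborsClassifier(n_neighbors=nearest_neighbours)
    clf.fit(ref_x, ref_y)

    # Indices into ref_x of each point's nearest neighbours: shape (batch_size, k).
    points = np.asarray(datapoints).reshape(len(datapoints), -1)
    neighbour_idx = clf.kneighbors(points, return_distance=False)

    flipped_labels = np.array(input_labels, copy=True)
    for i, idx in enumerate(neighbour_idx):
        counts = np.bincount(ref_y[idx])   # votes per integer label
        top_label = counts.argmax()        # ties go to the smallest label id
        if counts[top_label] / nearest_neighbours > confidence_threshold:
            flipped_labels[i] = top_label
    # Points pass through unchanged; only labels may be modified.
    return datapoints, flipped_labels
```

The returned tuple mirrors the documented (datapoints, flipped_labels) return type. Raising confidence_threshold makes the sketch more conservative: at 1.0 no label can ever be flipped under the strict comparison.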