niteshade.defence.KNN_Defender

class niteshade.defence.KNN_Defender(init_x, init_y, nearest_neighbours: int, confidence_threshold: float, one_hot=False)

Bases: niteshade.defence.PointModifierDefender

A KNN-based defender, inheriting from PointModifierDefender, that flips the label of an input point if the proportion of the most frequent label among its nearest neighbours exceeds a threshold. A scikit-learn KNeighborsClassifier is used to find the nearest neighbours. KNN_Defender implements a defence strategy discussed by Paudice, Andrea, et al. “Label Sanitization against Label Flipping Poisoning Attacks.” 2018.

Parameters
  • init_x (np.ndarray, torch.Tensor) – point data (shape (batch_size, data dimensionality))

  • init_y (np.ndarray, torch.Tensor) – label data (shape (batch_size,))

  • nearest_neighbours (int) – number of nearest neighbours to consult for each point

  • confidence_threshold (float) – proportion of the most frequent neighbour label above which a point's label is flipped

  • one_hot (bool) – whether the labels are one-hot encoded (default False)

__init__(init_x, init_y, nearest_neighbours: int, confidence_threshold: float, one_hot=False) → None

Constructor method of the KNN_Defender class. If the labels are one-hot encoded, artificial integer labels are constructed so that the scikit-learn classifier can be used.
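The conversion between one-hot and integer labels mentioned above can be sketched as follows. This is a minimal illustration in plain Python, not the library's own code: the helper names one_hot_to_int and int_to_one_hot are hypothetical, and taking the index of the largest entry in each row is an assumption about how the integer labels are built.

```python
def one_hot_to_int(one_hot_labels):
    """Map each one-hot row to the index of its largest entry (assumed scheme)."""
    return [max(range(len(row)), key=row.__getitem__) for row in one_hot_labels]


def int_to_one_hot(int_labels, n_classes):
    """Rebuild one-hot rows from integer class indices."""
    return [[1 if j == label else 0 for j in range(n_classes)]
            for label in int_labels]


# Hypothetical labels for three points and three classes.
labels = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
int_labels = one_hot_to_int(labels)
restored = int_to_one_hot(int_labels, n_classes=3)
```

Under this assumption the integer labels can be passed to a classifier that expects one class index per point, and any modified labels can be mapped back to one-hot form afterwards.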

Methods

__init__(init_x, init_y, nearest_neighbours, ...)

Constructor method of KNN_Defender class.

defend(datapoints, input_labels, **kwargs)

The defend method for the KNN_Defender.

defend(datapoints, input_labels, **kwargs)
The defend method for the KNN_Defender.

For each incoming point, the nearest neighbours and their labels are found. If the proportion of the most frequent label among those neighbours is higher than the confidence threshold, the point's label is set to that most frequent label; otherwise the label is left unchanged.

Parameters
  • datapoints (np.ndarray, torch.Tensor) – point data (shape (batch_size, data dimensionality)).

  • input_labels (np.ndarray, torch.Tensor) – label data (shape (batch_size,)).

Returns

  • datapoints (np.ndarray, torch.Tensor) – point data (shape (batch_size, data dimensionality)).

  • flipped_labels (np.ndarray, torch.Tensor) – modified label data (shape (batch_size,)).

Return type

tuple (datapoints, flipped_labels)
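The label-flipping rule described for defend can be sketched in plain Python as below. This is a simplified illustration under stated assumptions, not the KNN_Defender implementation: the function name knn_sanitize is hypothetical, a brute-force Euclidean search stands in for the scikit-learn KNeighborsClassifier, and the neighbours are assumed to be drawn from the initial data (init_x, init_y) only.

```python
from collections import Counter
from math import dist


def knn_sanitize(init_x, init_y, datapoints, input_labels,
                 nearest_neighbours, confidence_threshold):
    """Relabel points whose neighbourhood agrees strongly on another label.

    Hypothetical helper: brute-force search replaces KNeighborsClassifier.
    """
    flipped_labels = []
    for point, label in zip(datapoints, input_labels):
        # Indices of the reference points, ordered by Euclidean distance.
        order = sorted(range(len(init_x)), key=lambda i: dist(point, init_x[i]))
        neighbour_labels = [init_y[i] for i in order[:nearest_neighbours]]
        # Most frequent label among the neighbours and how often it occurs.
        top_label, count = Counter(neighbour_labels).most_common(1)[0]
        # Flip only if the proportion is strictly higher than the threshold.
        if count / len(neighbour_labels) > confidence_threshold:
            label = top_label
        flipped_labels.append(label)
    return datapoints, flipped_labels


# Two well-separated clusters as reference data.
init_x = [(0, 0), (0, 1), (1, 0), (5, 5), (5, 6), (6, 5)]
init_y = [0, 0, 0, 1, 1, 1]

# The first incoming point sits in the class-0 cluster but is labelled 1.
points = [(0.2, 0.2), (5.5, 5.5)]
labels = [1, 1]

_, cleaned = knn_sanitize(init_x, init_y, points, labels,
                          nearest_neighbours=3, confidence_threshold=0.6)
```

In this example the three nearest neighbours of the first point all carry label 0, so its label is flipped, while the second point already agrees with its neighbourhood. Because the comparison is strict, a threshold of 1.0 never triggers a flip.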