mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-02 16:22:37 +02:00
Support for many new model output types (#93)
* General model wrappers and methods supporting multi-label classifiers * Support for pytorch multi-label binary classifier * New model output types + single implementation of score method that supports multiple output types. * Anonymization with pytorch multi-output binary model * Support for multi-label binary models in minimizer. * Support for multi-label logits/probabilities --------- Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
e00535d120
commit
57e38ea4fa
13 changed files with 913 additions and 172 deletions
|
|
@ -233,7 +233,7 @@ class ArrayDataset(Dataset):
|
|||
raise ValueError("The supplied features are not the same as in the data features")
|
||||
self.features_names = x.columns.to_list()
|
||||
|
||||
if self._y is not None and len(self._x) != len(self._y):
|
||||
if self._y is not None and self._x.shape[0] != self._y.shape[0]:
|
||||
raise ValueError("Non equivalent lengths of x and y")
|
||||
|
||||
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
|
|
@ -266,6 +266,8 @@ class DatasetWithPredictions(Dataset):
|
|||
Dataset that is based on arrays (e.g., numpy/pandas/list...). Includes predictions from a model, and possibly also
|
||||
features and true labels.
|
||||
|
||||
:param pred: collection of model predictions
|
||||
:type pred: numpy array or pandas DataFrame or list or pytorch Tensor
|
||||
:param x: collection of data samples
|
||||
:type x: numpy array or pandas DataFrame or list or pytorch Tensor
|
||||
:param y: collection of labels
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue