mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-27 13:56:22 +02:00
update
This commit is contained in:
parent
a3fb68fb56
commit
302d0c4b8c
3 changed files with 46 additions and 42 deletions
|
|
@ -221,21 +221,21 @@ class PytorchData(Dataset):
|
|||
|
||||
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""Return data samples as numpy array"""
|
||||
return array2numpy(self._x)
|
||||
return array2numpy(self, self._x)
|
||||
|
||||
def get_labels(self) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""Return labels as numpy array"""
|
||||
return array2numpy(self._y) if self._y is not None else None
|
||||
return array2numpy(self, self._y) if self._y is not None else None
|
||||
|
||||
def get_sample_item(self, idx) -> Tensor:
|
||||
return self.x[idx]
|
||||
return self._x[idx]
|
||||
|
||||
def get_item(self, idx) -> Tensor:
|
||||
sample, label = self.x[idx], self.y[idx]
|
||||
sample, label = self._x[idx], self._y[idx]
|
||||
return sample, label
|
||||
|
||||
def __len__(self):
|
||||
return len(self.x)
|
||||
return len(self._x)
|
||||
|
||||
|
||||
class DatasetFactory:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue