This commit is contained in:
olasaadi 2022-06-02 15:25:07 +03:00
parent a3fb68fb56
commit 302d0c4b8c
3 changed files with 46 additions and 42 deletions

View file

@ -221,21 +221,21 @@ class PytorchData(Dataset):
def get_samples(self) -> OUTPUT_DATA_ARRAY_TYPE:
"""Return data samples as numpy array"""
return array2numpy(self._x)
return array2numpy(self, self._x)
def get_labels(self) -> OUTPUT_DATA_ARRAY_TYPE:
"""Return labels as numpy array"""
return array2numpy(self._y) if self._y is not None else None
return array2numpy(self, self._y) if self._y is not None else None
def get_sample_item(self, idx) -> Tensor:
return self.x[idx]
return self._x[idx]
def get_item(self, idx) -> Tensor:
sample, label = self.x[idx], self.y[idx]
sample, label = self._x[idx], self._y[idx]
return sample, label
def __len__(self):
return len(self.x)
return len(self._x)
class DatasetFactory: