Bug fix for PytorchData dataset

This commit is contained in:
Natalia Razinkov 2022-06-26 15:15:51 +03:00 committed by GitHub
parent 1c4b963add
commit bb224cd3dd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -292,7 +292,7 @@ class PytorchData(Dataset):
:type idx: int
:return: the sample as a pytorch Tensor
"""
return self.x[idx]
return self._x[idx]
def get_item(self, idx: int) -> Tensor:
"""
@ -302,11 +302,11 @@ class PytorchData(Dataset):
:type idx: int
:return: the sample and label as pytorch Tensors. Returned as a tuple (sample, label)
"""
sample, label = self.x[idx], self.y[idx]
sample, label = self._x[idx], self._y[idx]
return sample, label
def __len__(self):
return len(self.x)
return len(self._x)
class DatasetFactory: