Module lib.gsm_data.GSMDataset
Expand source code
import json
import os
from typing import Literal, Union
from torch.utils.data import Dataset
def load_data(*, split: Literal['train', 'test']) -> list:
    """Load one split of the grade-school-math data as (prompt, answer) pairs.

    Reads ``../grade-school-math/grade_school_math/data/{split}.jsonl``; the
    path is relative to the current working directory, not to this module.
    Each non-blank line must be a JSON object with ``question`` and ``answer``
    keys.

    Args:
        split: Name of the split to load. It is keyword-only and required,
            and is interpolated into the file name as-is.

    Returns:
        A list of ``(question, answer)`` string tuples in file order. The
        question is wrapped in the instruction/response prompt template, and
        every ``'#### '`` marker in the answer is replaced by
        ``'Final answer:\\n'``.

    Raises:
        FileNotFoundError: If the split file does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``question`` or ``answer``.
    """
    path = f'../grade-school-math/grade_school_math/data/{split}.jsonl'
    res = []
    # Decode as UTF-8 explicitly so the result does not depend on the
    # platform's locale encoding.
    with open(path, encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. trailing whitespace at end of file).
            if not line.strip():
                continue
            data = json.loads(line)
            question = (
                'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n'
                + data['question']
                + '\n\n### Response:\n'
            )
            answer = data['answer'].replace('#### ', 'Final answer:\n')
            res.append((question, answer))
    return res
class GSMDataset(Dataset):
    """Map-style dataset of (prompt, answer) pairs for one data split.

    All examples are loaded eagerly by ``load_data`` at construction time and
    kept in ``self.data``.
    """

    def __init__(self, *, split: Literal['train', 'test']) -> None:
        """Load every example of the requested split.

        Args:
            split: Name of the split to load. It is keyword-only and
                required, and is forwarded unchanged to ``load_data``.
        """
        super().__init__()
        # List of examples as returned by load_data(split=split).
        self.data = load_data(split=split)

    def __getitem__(self, idx: int):
        """Return the example stored at position ``idx``."""
        return self.data[idx]

    def __len__(self) -> int:
        """Return the number of loaded examples."""
        return len(self.data)
Functions
def load_data(*, split=typing.Union[typing.Literal['train'], typing.Literal['test']])
-
Expand source code
def load_data(*, split: Literal['train', 'test']) -> list:
    """Load one split of the grade-school-math data as (prompt, answer) pairs.

    Reads ``../grade-school-math/grade_school_math/data/{split}.jsonl``; the
    path is relative to the current working directory, not to this module.
    Each non-blank line must be a JSON object with ``question`` and ``answer``
    keys.

    Args:
        split: Name of the split to load. It is keyword-only and required,
            and is interpolated into the file name as-is.

    Returns:
        A list of ``(question, answer)`` string tuples in file order. The
        question is wrapped in the instruction/response prompt template, and
        every ``'#### '`` marker in the answer is replaced by
        ``'Final answer:\\n'``.

    Raises:
        FileNotFoundError: If the split file does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``question`` or ``answer``.
    """
    path = f'../grade-school-math/grade_school_math/data/{split}.jsonl'
    res = []
    # Decode as UTF-8 explicitly so the result does not depend on the
    # platform's locale encoding.
    with open(path, encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. trailing whitespace at end of file).
            if not line.strip():
                continue
            data = json.loads(line)
            question = (
                'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n'
                + data['question']
                + '\n\n### Response:\n'
            )
            answer = data['answer'].replace('#### ', 'Final answer:\n')
            res.append((question, answer))
    return res
Classes
class GSMDataset (*, split=typing.Union[typing.Literal['train'], typing.Literal['test']])
-
An abstract class representing a :class:`Dataset`.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and the default options of :class:`~torch.utils.data.DataLoader`. Subclasses could also optionally implement :meth:`__getitems__` to speed up loading of batched samples. This method accepts a list of indices of the samples in a batch and returns a list of samples.

Note: :class:`~torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
Expand source code
class GSMDataset(Dataset):
    """Map-style dataset of (prompt, answer) pairs for one data split.

    All examples are loaded eagerly by ``load_data`` at construction time and
    kept in ``self.data``.
    """

    def __init__(self, *, split: Literal['train', 'test']) -> None:
        """Load every example of the requested split.

        Args:
            split: Name of the split to load. It is keyword-only and
                required, and is forwarded unchanged to ``load_data``.
        """
        super().__init__()
        # List of examples as returned by load_data(split=split).
        self.data = load_data(split=split)

    def __getitem__(self, idx: int):
        """Return the example stored at position ``idx``."""
        return self.data[idx]

    def __len__(self) -> int:
        """Return the number of loaded examples."""
        return len(self.data)
Ancestors
- torch.utils.data.dataset.Dataset
- typing.Generic