Module lib.gsm_data.GSMDataset

Expand source code
import json
import os
from typing import Literal, Union

from torch.utils.data import Dataset

def load_data(
    *,
    split: Literal['train', 'test'],
    data_dir: str = '../grade-school-math/grade_school_math/data',
):
    """Load one grade-school-math split as ``(prompt, answer)`` string pairs.

    Each non-blank line of ``<data_dir>/<split>.jsonl`` is parsed as a JSON
    object with ``question`` and ``answer`` keys. The question is wrapped in a
    fixed instruction-style prompt template. Every ``'#### '`` marker in the
    answer is replaced with ``'Final answer:'`` followed by a newline.

    Args:
        split: Name of the ``.jsonl`` file (without extension) to read,
            e.g. ``'train'`` or ``'test'``. Previously this was accidentally
            given a typing object as its default value; it is now a required
            keyword argument.
        data_dir: Directory containing the split files. The default is the
            original hard-coded path, resolved relative to the current
            working directory.

    Returns:
        A list of ``(prompt, answer)`` tuples, in file order.

    Raises:
        FileNotFoundError: If ``<data_dir>/<split>.jsonl`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks a ``question`` or ``answer`` key.
    """
    path = os.path.join(data_dir, f'{split}.jsonl')
    res = []
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(path, encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. an extra trailing newline) instead of
            # letting json.loads raise on them.
            if not line.strip():
                continue
            data = json.loads(line)
            question = (
                'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n'
                + data['question']
                + '\n\n### Response:\n'
            )
            answer = data['answer'].replace('#### ', 'Final answer:\n')
            res.append((question, answer))
    return res

class GSMDataset(Dataset):
    """Map-style dataset over one grade-school-math split.

    Items are the ``(prompt, answer)`` string tuples returned by
    ``load_data``. The entire split is read into memory at construction time.
    """

    def __init__(self, *, split: Literal['train', 'test']) -> None:
        """Load ``split`` eagerly via ``load_data``.

        Args:
            split: Split name forwarded to ``load_data`` (e.g. ``'train'`` or
                ``'test'``). Previously this was accidentally given a typing
                object as its default value; it is now a required keyword
                argument.
        """
        super().__init__()
        self.data = load_data(split=split)

    def __getitem__(self, idx: int):
        """Return the ``(prompt, answer)`` pair at position ``idx``."""
        return self.data[idx]

    def __len__(self) -> int:
        """Return the number of examples in the loaded split."""
        return len(self.data)

Functions

def load_data(*, split=typing.Union[typing.Literal['train'], typing.Literal['test']])
Expand source code
def load_data(
    *,
    split: Literal['train', 'test'],
    data_dir: str = '../grade-school-math/grade_school_math/data',
):
    """Load one grade-school-math split as ``(prompt, answer)`` string pairs.

    Each non-blank line of ``<data_dir>/<split>.jsonl`` is parsed as a JSON
    object with ``question`` and ``answer`` keys. The question is wrapped in a
    fixed instruction-style prompt template. Every ``'#### '`` marker in the
    answer is replaced with ``'Final answer:'`` followed by a newline.

    Args:
        split: Name of the ``.jsonl`` file (without extension) to read,
            e.g. ``'train'`` or ``'test'``. Previously this was accidentally
            given a typing object as its default value; it is now a required
            keyword argument.
        data_dir: Directory containing the split files. The default is the
            original hard-coded path, resolved relative to the current
            working directory.

    Returns:
        A list of ``(prompt, answer)`` tuples, in file order.

    Raises:
        FileNotFoundError: If ``<data_dir>/<split>.jsonl`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks a ``question`` or ``answer`` key.
    """
    path = os.path.join(data_dir, f'{split}.jsonl')
    res = []
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(path, encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. an extra trailing newline) instead of
            # letting json.loads raise on them.
            if not line.strip():
                continue
            data = json.loads(line)
            question = (
                'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n'
                + data['question']
                + '\n\n### Response:\n'
            )
            answer = data['answer'].replace('#### ', 'Final answer:\n')
            res.append((question, answer))
    return res

Classes

class GSMDataset (*, split=typing.Union[typing.Literal['train'], typing.Literal['test']])

An abstract class representing a :class:`Dataset`.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should override :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally override :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and the default options of :class:`~torch.utils.data.DataLoader`. Subclasses could also optionally implement :meth:`__getitems__` to speed up batched sample loading. This method accepts a list of sample indices for a batch and returns a list of samples.

Note

:class:`~torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

Expand source code
class GSMDataset(Dataset):
    """Map-style dataset over one grade-school-math split.

    Items are the ``(prompt, answer)`` string tuples returned by
    ``load_data``. The entire split is read into memory at construction time.
    """

    def __init__(self, *, split: Literal['train', 'test']) -> None:
        """Load ``split`` eagerly via ``load_data``.

        Args:
            split: Split name forwarded to ``load_data`` (e.g. ``'train'`` or
                ``'test'``). Previously this was accidentally given a typing
                object as its default value; it is now a required keyword
                argument.
        """
        super().__init__()
        self.data = load_data(split=split)

    def __getitem__(self, idx: int):
        """Return the ``(prompt, answer)`` pair at position ``idx``."""
        return self.data[idx]

    def __len__(self) -> int:
        """Return the number of examples in the loaded split."""
        return len(self.data)

Ancestors

  • torch.utils.data.dataset.Dataset
  • typing.Generic