Module lib.gsm_data.gsm_collate_fn
from itertools import chain, repeat

import jax.numpy as jnp
from transformers import LlamaTokenizer

from ..data import TrainData, TestData


def gsm_collate_fn_train(tokenizer: LlamaTokenizer, max_len: int, data_batch: list[tuple[str, str]]):
    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id

    seq_list = []
    seq_mask_list = []
    labels_list = []
    labels_mask_list = []

    for question, answer in data_batch:
        question = tokenizer(question, add_special_tokens=False, return_attention_mask=False).input_ids
        answer = tokenizer(answer, add_special_tokens=False, return_attention_mask=False).input_ids

        len_question = len(question)
        len_answer = len(answer)
        len_seq = len_question + len_answer + 2  # + BOS and EOS
        len_pad = max_len - len_seq
        assert len_question + 1 < max_len, '`max_len` too small'

        # Input is [BOS] question answer [EOS], padded with EOS. `repeat` with a
        # negative count yields nothing, so over-long sequences are handled by
        # the `[:max_len]` slices below.
        seq = list(chain((bos_id,), question, answer, (eos_id,), repeat(eos_id, len_pad)))
        seq_mask = list(chain(repeat(True, 1 + len_question + len_answer + 1), repeat(False, len_pad)))

        # Labels are `seq` shifted left by one position (next-token targets);
        # the loss mask covers only the answer tokens and the closing EOS.
        labels = list(chain(question, answer, (eos_id,), repeat(eos_id, len_pad + 1)))
        labels_mask = list(chain(repeat(False, len_question), repeat(True, len_answer + 1), repeat(False, len_pad + 1)))

        seq = seq[:max_len]
        seq_mask = seq_mask[:max_len]
        labels = labels[:max_len]
        labels_mask = labels_mask[:max_len]

        seq_list.append(seq)
        seq_mask_list.append(seq_mask)
        labels_list.append(labels)
        labels_mask_list.append(labels_mask)

    seq_ = jnp.array(seq_list, dtype=jnp.uint16)
    seq_mask_ = jnp.array(seq_mask_list, dtype=jnp.bool_)
    labels_ = jnp.array(labels_list, dtype=jnp.uint16)
    labels_mask_ = jnp.array(labels_mask_list, dtype=jnp.bool_)

    return TrainData(seq_, seq_mask_, labels_, labels_mask_)


def gsm_collate_fn_test(tokenizer: LlamaTokenizer, max_len: int, data_batch: list[tuple[str, str]]):
    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id

    seq_list = []
    seq_mask_list = []
    labels_list = []

    for question, answer in data_batch:
        # Only the question is tokenized; the reference answer is kept as a raw
        # string for evaluation.
        question = tokenizer(question, add_special_tokens=False, return_attention_mask=False).input_ids

        len_question = len(question)
        len_seq = len_question + 1  # + BOS
        len_pad = max_len - len_seq
        assert len_question + 1 < max_len, '`max_len` too small'

        seq = list(chain((bos_id,), question, repeat(eos_id, len_pad)))
        seq_mask = list(chain(repeat(True, 1 + len_question), repeat(False, len_pad)))

        seq = seq[:max_len]
        seq_mask = seq_mask[:max_len]

        seq_list.append(seq)
        seq_mask_list.append(seq_mask)
        labels_list.append(answer)

    seq_ = jnp.array(seq_list, dtype=jnp.uint16)
    seq_mask_ = jnp.array(seq_mask_list, dtype=jnp.bool_)

    return TestData(seq_, seq_mask_, labels_list)
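For orientation, here is a minimal worked example of the layout gsm_collate_fn_train produces, using hypothetical token ids (bos_id = 1, eos_id = 2, a question tokenized to [5, 6], an answer tokenized to [7], max_len = 8):

from itertools import chain, repeat

bos_id, eos_id, max_len = 1, 2, 8
question, answer = [5, 6], [7]  # hypothetical token ids
len_pad = max_len - (len(question) + len(answer) + 2)

seq = list(chain((bos_id,), question, answer, (eos_id,), repeat(eos_id, len_pad)))
labels = list(chain(question, answer, (eos_id,), repeat(eos_id, len_pad + 1)))
labels_mask = list(chain(repeat(False, len(question)), repeat(True, len(answer) + 1), repeat(False, len_pad + 1)))

assert seq == [1, 5, 6, 7, 2, 2, 2, 2]     # BOS, question, answer, EOS, padding
assert labels == [5, 6, 7, 2, 2, 2, 2, 2]  # seq shifted left by one token
assert labels_mask == [False, False, True, True, False, False, False, False]  # loss on answer + EOS only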
Functions
def gsm_collate_fn_train(tokenizer: transformers.models.llama.tokenization_llama.LlamaTokenizer, max_len: int, data_batch: list[tuple[str, str]])
Collates a batch of (question, answer) string pairs into padded token ids with a sequence mask, plus next-token labels (the input shifted left by one) with a loss mask that covers only the answer tokens and the closing EOS. Returns a TrainData of jnp arrays.
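A usage sketch, not part of the module: the tokenizer and max_len are typically bound with functools.partial so the result can be passed to a data loader as a collate_fn. The checkpoint name, the max_len of 640, and the toy batch below are assumptions for illustration.

from functools import partial

from transformers import LlamaTokenizer

from lib.gsm_data.gsm_collate_fn import gsm_collate_fn_train

tokenizer = LlamaTokenizer.from_pretrained('huggyllama/llama-7b')  # assumed checkpoint
collate_fn = partial(gsm_collate_fn_train, tokenizer, 640)         # assumed max_len

# Toy (question, answer) pair; real batches come from the GSM dataset.
train_data = collate_fn([('What is 2 + 2?', 'The answer is 4.')])
# train_data bundles the padded ids, sequence mask, shifted labels, and loss mask.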
def gsm_collate_fn_test(tokenizer: transformers.models.llama.tokenization_llama.LlamaTokenizer, max_len: int, data_batch: list[tuple[str, str]])
Collates a batch of (question, answer) pairs for evaluation: only the questions are tokenized ([BOS] plus question, padded with EOS), and the reference answers are passed through as raw strings. Returns a TestData.
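The evaluation counterpart, under the same assumptions as the sketch above: only the questions are collated into padded arrays, and the reference answer strings ride along untokenized so they can be compared against generated output.

from functools import partial

from transformers import LlamaTokenizer

from lib.gsm_data.gsm_collate_fn import gsm_collate_fn_test

tokenizer = LlamaTokenizer.from_pretrained('huggyllama/llama-7b')  # assumed checkpoint
collate_fn = partial(gsm_collate_fn_test, tokenizer, 640)          # assumed max_len

test_data = collate_fn([('What is 2 + 2?', '4')])
# test_data bundles a (batch, max_len) uint16 id array, its boolean mask,
# and the raw answer strings ['4'].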