Module lib.llama_params.convert_back_params
from jax import Array
import torch
import torch.nn as tnn
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel as LlamaModelPt
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaMLP, LlamaRMSNorm
from ..array_utils import jax2pt
from ..llama import Llama, LlamaModel
from ..llama.attention import Attention
from ..llama.decoder_block import DecoderBlock
from ..tree_utils import unstack_leaves
def convert_back_embedding(x: Array) -> tnn.Embedding:
    with torch.no_grad():
        embedding = tnn.Embedding(*x.shape)  # type: ignore
        embedding.weight = tnn.Parameter(jax2pt(x))
        return embedding

def convert_back_norm(x: Array, *, config: LlamaConfig) -> LlamaRMSNorm:
    d_model = config.hidden_size
    rms_norm_eps = config.rms_norm_eps
    with torch.no_grad():
        llama_rms_norm = LlamaRMSNorm(d_model, eps=rms_norm_eps)
        llama_rms_norm.weight = tnn.Parameter(jax2pt(x))
        return llama_rms_norm

def convert_back_proj(x: Array) -> tnn.Linear:
    # x is stored as (in_features, out_features); torch.nn.Linear keeps its
    # weight as (out_features, in_features), hence the transpose.
    with torch.no_grad():
        linear = tnn.Linear(*x.shape, bias=False)  # type: ignore
        linear.weight = tnn.Parameter(jax2pt(x).T)
        return linear

def convert_back_q_proj(x: Array, *, config: LlamaConfig) -> tnn.Linear:
    d_model = config.hidden_size
    n_rep_kv = config.num_attention_heads // config.num_key_value_heads
    n_heads_kv = config.num_key_value_heads
    d_k = config.hidden_size // config.num_attention_heads
    in_features = d_model
    out_features = n_rep_kv * n_heads_kv * d_k
    with torch.no_grad():
        linear = tnn.Linear(in_features, out_features, bias=False)
        # Collapse x to (in_features, out_features), then transpose into the
        # (out_features, in_features) layout used by torch.nn.Linear.
        linear.weight = tnn.Parameter(jax2pt(x).reshape(in_features, out_features).T)
        return linear

def convert_back_k_proj(x: Array, *, config: LlamaConfig) -> tnn.Linear:
    d_model = config.hidden_size
    n_heads_kv = config.num_key_value_heads
    d_k = config.hidden_size // config.num_attention_heads
    in_features = d_model
    out_features = n_heads_kv * d_k
    with torch.no_grad():
        linear = tnn.Linear(in_features, out_features, bias=False)
        linear.weight = tnn.Parameter(jax2pt(x).reshape(in_features, out_features).T)
        return linear

def convert_back_v_proj(x: Array, *, config: LlamaConfig) -> tnn.Linear:
    d_model = config.hidden_size
    n_heads_kv = config.num_key_value_heads
    d_v = config.hidden_size // config.num_attention_heads
    in_features = d_model
    out_features = n_heads_kv * d_v
    with torch.no_grad():
        linear = tnn.Linear(in_features, out_features, bias=False)
        linear.weight = tnn.Parameter(jax2pt(x).reshape(in_features, out_features).T)
        return linear

def convert_back_out_proj(x: Array, *, config: LlamaConfig) -> tnn.Linear:
    d_model = config.hidden_size
    n_rep_kv = config.num_attention_heads // config.num_key_value_heads
    n_heads_kv = config.num_key_value_heads
    d_v = config.hidden_size // config.num_attention_heads
    in_features = n_rep_kv * n_heads_kv * d_v
    out_features = d_model
    with torch.no_grad():
        linear = tnn.Linear(in_features, out_features, bias=False)  # type: ignore
        linear.weight = tnn.Parameter(jax2pt(x).reshape(in_features, out_features).T)
        return linear

def convert_back_attention(x: Attention, *, config: LlamaConfig) -> LlamaAttention:
    with torch.no_grad():
        llama_attention = LlamaAttention(config=config)
        llama_attention.q_proj = convert_back_q_proj(x.q_proj, config=config)
        llama_attention.k_proj = convert_back_k_proj(x.k_proj, config=config)
        llama_attention.v_proj = convert_back_v_proj(x.v_proj, config=config)
        llama_attention.o_proj = convert_back_out_proj(x.out_proj, config=config)
        return llama_attention

def convert_back_mlp(gate_proj: Array, up_proj: Array, down_proj: Array, *, config: LlamaConfig) -> LlamaMLP:
    with torch.no_grad():
        llama_mlp = LlamaMLP(config=config)
        llama_mlp.gate_proj = convert_back_proj(gate_proj)
        llama_mlp.up_proj = convert_back_proj(up_proj)
        llama_mlp.down_proj = convert_back_proj(down_proj)
        return llama_mlp

def convert_back_decoder_block(x: DecoderBlock, *, config: LlamaConfig) -> LlamaDecoderLayer:
    with torch.no_grad():
        llama_decoder_layer = LlamaDecoderLayer(config=config)
        llama_decoder_layer.self_attn = convert_back_attention(x.attention, config=config)
        llama_decoder_layer.mlp = convert_back_mlp(x.gate_proj, x.up_proj, x.down_proj, config=config)
        llama_decoder_layer.input_layernorm = convert_back_norm(x.input_norm, config=config)
        llama_decoder_layer.post_attention_layernorm = convert_back_norm(x.post_attn_norm, config=config)
        return llama_decoder_layer

def convert_back_llama_model(x: LlamaModel, *, config: LlamaConfig) -> LlamaModelPt:
    with torch.no_grad():
        llama_model = LlamaModelPt(config=config)
        llama_model.embed_tokens = convert_back_embedding(x.embedding)
        # unstack_leaves splits x.decoder into one DecoderBlock pytree per layer.
        llama_model.layers = tnn.ModuleList([
            convert_back_decoder_block(decoder_block, config=config)
            for decoder_block in unstack_leaves(x.decoder)
        ])
        llama_model.norm = convert_back_norm(x.norm, config=config)
        return llama_model

def convert_back_llama(x: Llama, *, config: LlamaConfig) -> LlamaForCausalLM:
    with torch.no_grad():
        llama = LlamaForCausalLM(config=config)
        llama.model = convert_back_llama_model(x.model, config=config)
        llama.lm_head = convert_back_proj(x.lm_head)
        return llama

# from pathlib import Path; import sys; sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
# from lib.proc_init_utils import initialise_cpu; initialise_cpu()
# model_pt = LlamaForCausalLM.from_pretrained('/dev/shm/llama-weights/llama2-7B')
# config = LlamaConfig.from_pretrained('/dev/shm/llama-weights/llama2-7B')
# from lib.param_utils.convert_params import convert_proj
# assert torch.equal(convert_back_proj(convert_proj(model_pt.lm_head)).weight, model_pt.lm_head.weight)
# assert torch.equal(convert_back_proj(convert_proj(model_pt.model.layers[0].self_attn.q_proj)).weight, model_pt.model.layers[0].self_attn.q_proj.weight)
# assert torch.equal(convert_back_proj(convert_proj(model_pt.model.layers[0].self_attn.k_proj)).weight, model_pt.model.layers[0].self_attn.k_proj.weight)
# assert torch.equal(convert_back_proj(convert_proj(model_pt.model.layers[0].self_attn.v_proj)).weight, model_pt.model.layers[0].self_attn.v_proj.weight)
# assert torch.equal(convert_back_proj(convert_proj(model_pt.model.layers[0].self_attn.o_proj)).weight, model_pt.model.layers[0].self_attn.o_proj.weight)
# model_pt.model.norm.weight
# model_pt.model.embed_tokens.weight
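
For orientation, the sketch below shows how the top-level entry point is typically driven: build a matching LlamaConfig, convert a JAX Llama pytree back into a Hugging Face model, and save it with the standard transformers API. The helper name, the origin of the params pytree, and the checkpoint paths are placeholders, not part of this module.

from transformers import LlamaConfig

from lib.llama import Llama
from lib.llama_params.convert_back_params import convert_back_llama

def export_to_hf(params: Llama, config: LlamaConfig, out_dir: str) -> None:
    # Rebuild a LlamaForCausalLM from the JAX parameter pytree and write a
    # standard Hugging Face checkpoint to out_dir.
    model_pt = convert_back_llama(params, config=config)
    model_pt.save_pretrained(out_dir)

# e.g. export_to_hf(params, LlamaConfig.from_pretrained('/path/to/llama2-7B'), '/path/to/exported')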
Functions
def convert_back_embedding(x: jax.Array) ‑> torch.nn.modules.sparse.Embedding
def convert_back_norm(x: jax.Array, *, config: transformers.models.llama.configuration_llama.LlamaConfig) ‑> transformers.models.llama.modeling_llama.LlamaRMSNorm
def convert_back_proj(x: jax.Array) ‑> torch.nn.modules.linear.Linear
def convert_back_q_proj(x: jax.Array, *, config: transformers.models.llama.configuration_llama.LlamaConfig) ‑> torch.nn.modules.linear.Linear
def convert_back_k_proj(x: jax.Array, *, config: transformers.models.llama.configuration_llama.LlamaConfig) ‑> torch.nn.modules.linear.Linear
def convert_back_v_proj(x: jax.Array, *, config: transformers.models.llama.configuration_llama.LlamaConfig) ‑> torch.nn.modules.linear.Linear
def convert_back_out_proj(x: jax.Array, *, config: transformers.models.llama.configuration_llama.LlamaConfig) ‑> torch.nn.modules.linear.Linear
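
The four projection converters above all funnel the JAX array through reshape(in_features, out_features).T, because torch.nn.Linear stores its weight as (out_features, in_features). The toy check below mimics that step with plain PyTorch tensors; the per-head axis order shown is an assumption made for illustration, since the converters only require that the elements flatten to (in_features, out_features) in row-major order.

import torch

# Hypothetical toy sizes; real values come from LlamaConfig.
d_model, n_rep_kv, n_heads_kv, d_k = 16, 4, 2, 3
in_features, out_features = d_model, n_rep_kv * n_heads_kv * d_k

# Stand-in for jax2pt(x) on a q-projection array with assumed per-head axes.
x = torch.randn(d_model, n_rep_kv, n_heads_kv, d_k)

weight = x.reshape(in_features, out_features).T  # (out_features, in_features)
assert weight.shape == torch.nn.Linear(in_features, out_features, bias=False).weight.shape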
def convert_back_attention(x: Attention, *, config: transformers.models.llama.configuration_llama.LlamaConfig) ‑> transformers.models.llama.modeling_llama.LlamaAttention
def convert_back_mlp(gate_proj: jax.Array, up_proj: jax.Array, down_proj: jax.Array, *, config: transformers.models.llama.configuration_llama.LlamaConfig) ‑> transformers.models.llama.modeling_llama.LlamaMLP
def convert_back_decoder_block(x: DecoderBlock, *, config: transformers.models.llama.configuration_llama.LlamaConfig) ‑> transformers.models.llama.modeling_llama.LlamaDecoderLayer
def convert_back_llama_model(x: LlamaModel, *, config: transformers.models.llama.configuration_llama.LlamaConfig) ‑> transformers.models.llama.modeling_llama.LlamaModel
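
convert_back_llama_model iterates over unstack_leaves(x.decoder) to build one LlamaDecoderLayer per layer, which implies the JAX decoder parameters are stored as a single pytree whose leaves carry a leading layer axis. The real unstack_leaves lives in lib.tree_utils and is not reproduced on this page; the snippet below is only a plausible reading of its behaviour, included for illustration.

import jax

def unstack_leaves_sketch(tree):
    # Split a pytree whose leaves share a leading axis into a list of per-layer
    # pytrees (illustrative only; not the actual lib.tree_utils implementation).
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    n_layers = leaves[0].shape[0]
    return [treedef.unflatten([leaf[i] for leaf in leaves]) for i in range(n_layers)]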
def convert_back_llama(x: Llama, *, config: transformers.models.llama.configuration_llama.LlamaConfig) ‑> transformers.models.llama.modeling_llama.LlamaForCausalLM