Module lib.llama_params.convert_params

Converts the parameters of a Hugging Face PyTorch LlamaForCausalLM into the JAX pytree
structures (Llama, LlamaModel, DecoderBlock, Attention) used by this library.

Source code
from jax import Array
import torch
import torch.nn as tnn
from transformers import LlamaForCausalLM, LlamaModel as LlamaModelPt
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer

from ..array_utils import pt2jax
from ..llama import Llama, LlamaModel, ModelConfig
from ..llama.attention import Attention
from ..llama.decoder_block import DecoderBlock
from ..tree_utils import stack_leaves

def convert_proj(x: tnn.Linear) -> Array:
    # torch.nn.Linear stores its weight as (out_features, in_features);
    # transpose to (in_features, out_features) for the JAX parameter.
    return pt2jax(x.weight.T)

def convert_q_proj(x: tnn.Linear, *, model_config: ModelConfig) -> Array:
    # Transpose, then split the query heads into (d_model, n_rep_kv, n_heads_kv, d_k):
    # n_rep_kv query heads share each of the n_heads_kv key/value heads.
    return pt2jax(x.weight.T.reshape(model_config.d_model, model_config.n_rep_kv, model_config.n_heads_kv, model_config.d_k))

def convert_k_proj(x: tnn.Linear, *, model_config: ModelConfig) -> Array:
    return pt2jax(x.weight.T.reshape(model_config.d_model, model_config.n_heads_kv, model_config.d_k))

def convert_v_proj(x: tnn.Linear, *, model_config: ModelConfig) -> Array:
    return pt2jax(x.weight.T.reshape(model_config.d_model, model_config.n_heads_kv, model_config.d_v))

def convert_out_proj(x: tnn.Linear, *, model_config: ModelConfig) -> Array:
    # The output projection maps the concatenated heads back to the model dimension,
    # hence the (n_rep_kv, n_heads_kv, d_v, d_model) layout.
    return pt2jax(x.weight.T.reshape(model_config.n_rep_kv, model_config.n_heads_kv, model_config.d_v, model_config.d_model))

def convert_attention(x: LlamaAttention, *, model_config: ModelConfig) -> Attention:
    q_proj = convert_q_proj(x.q_proj, model_config=model_config)
    k_proj = convert_k_proj(x.k_proj, model_config=model_config)
    v_proj = convert_v_proj(x.v_proj, model_config=model_config)
    out_proj = convert_out_proj(x.o_proj, model_config=model_config)
    return Attention(q_proj=q_proj, k_proj=k_proj, v_proj=v_proj, out_proj=out_proj)

def convert_decoder_block(x: LlamaDecoderLayer, *, model_config: ModelConfig) -> DecoderBlock:
    input_norm = pt2jax(x.input_layernorm.weight)
    attention = convert_attention(x.self_attn, model_config=model_config)
    post_attn_norm = pt2jax(x.post_attention_layernorm.weight)
    gate_proj = convert_proj(x.mlp.gate_proj)
    up_proj = convert_proj(x.mlp.up_proj)
    down_proj = convert_proj(x.mlp.down_proj)
    return DecoderBlock(input_norm=input_norm, attention=attention, post_attn_norm=post_attn_norm, gate_proj=gate_proj, up_proj=up_proj, down_proj=down_proj)

def convert_llama_model(model: LlamaModelPt, *, model_config: ModelConfig) -> LlamaModel:
    embedding = pt2jax(model.embed_tokens.weight)
    # Convert each decoder layer, then stack the per-layer pytrees leaf-wise so the
    # decoder parameters carry a leading layer axis.
    decoder = stack_leaves([convert_decoder_block(model.layers[i], model_config=model_config) for i in range(model_config.n_layers)])
    norm = pt2jax(model.norm.weight)
    return LlamaModel(embedding=embedding, decoder=decoder, norm=norm)

def convert_llama(model_pt: LlamaForCausalLM, *, model_config: ModelConfig) -> Llama:
    # Read the weights out under no_grad so the transposes/reshapes are not tracked by autograd.
    with torch.no_grad():
        model = convert_llama_model(model_pt.model, model_config=model_config)
        lm_head = convert_proj(model_pt.lm_head)
        return Llama(model=model, lm_head=lm_head)
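
A minimal usage sketch (the checkpoint name and the ModelConfig fields shown here are
illustrative assumptions, not values taken from this library):

from transformers import LlamaForCausalLM

from lib.llama import ModelConfig
from lib.llama_params.convert_params import convert_llama

# Hypothetical values roughly matching a 7B Llama checkpoint without grouped-query
# attention; the real ModelConfig may require additional fields.
model_config = ModelConfig(
    d_model=4096, n_rep_kv=1, n_heads_kv=32, d_k=128, d_v=128, n_layers=32,
)
model_pt = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
params = convert_llama(model_pt, model_config=model_config)  # Llama pytree of jax.Array leaves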

Functions

def convert_proj(x: torch.nn.modules.linear.Linear) -> jax.Array

Convert a plain linear projection (used for the MLP gate/up/down projections and the LM head): the (out_features, in_features) PyTorch weight is transposed into an (in_features, out_features) JAX array.
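
A small illustration of the resulting weight layout (the dimensions are arbitrary):

import torch
import torch.nn as tnn

from lib.llama_params.convert_params import convert_proj

with torch.no_grad():                      # as in convert_llama, keep autograd out of the way
    lin = tnn.Linear(16, 32, bias=False)   # PyTorch stores the weight as (32, 16)
    w = convert_proj(lin)
    assert w.shape == (16, 32)             # transposed to (in_features, out_features)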
def convert_q_proj(x: torch.nn.modules.linear.Linear, *, model_config: ModelConfig) -> jax.Array

Convert the query projection, reshaping the transposed weight to (d_model, n_rep_kv, n_heads_kv, d_k).
def convert_k_proj(x: torch.nn.modules.linear.Linear, *, model_config: ModelConfig) -> jax.Array

Convert the key projection, reshaping the transposed weight to (d_model, n_heads_kv, d_k).
def convert_v_proj(x: torch.nn.modules.linear.Linear, *, model_config: ModelConfig) -> jax.Array

Convert the value projection, reshaping the transposed weight to (d_model, n_heads_kv, d_v).
def convert_out_proj(x: torch.nn.modules.linear.Linear, *, model_config: ModelConfig) -> jax.Array

Convert the attention output projection, reshaping the transposed weight to (n_rep_kv, n_heads_kv, d_v, d_model).
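
The four attention-projection converters differ only in how the transposed weight is split
into heads. A quick shape check of the query/key layouts (the config stand-in and the head
counts below are illustrative, not taken from a real checkpoint):

from types import SimpleNamespace

import torch
import torch.nn as tnn

from lib.llama_params.convert_params import convert_k_proj, convert_q_proj

# The converters only read attributes from model_config, so a SimpleNamespace is enough here.
cfg = SimpleNamespace(d_model=4096, n_rep_kv=4, n_heads_kv=8, d_k=128, d_v=128)

with torch.no_grad():
    q_pt = tnn.Linear(cfg.d_model, cfg.n_rep_kv * cfg.n_heads_kv * cfg.d_k, bias=False)
    k_pt = tnn.Linear(cfg.d_model, cfg.n_heads_kv * cfg.d_k, bias=False)
    assert convert_q_proj(q_pt, model_config=cfg).shape == (4096, 4, 8, 128)
    assert convert_k_proj(k_pt, model_config=cfg).shape == (4096, 8, 128)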
def convert_attention(x: transformers.models.llama.modeling_llama.LlamaAttention, *, model_config: ModelConfig) -> Attention

Convert a LlamaAttention module into an Attention pytree holding the query, key, value and output projections.
def convert_decoder_block(x: transformers.models.llama.modeling_llama.LlamaDecoderLayer, *, model_config: ModelConfig) -> DecoderBlock

Convert a LlamaDecoderLayer into a DecoderBlock: the input and post-attention RMSNorm weights, the attention parameters, and the MLP gate/up/down projections.
def convert_llama_model(model: transformers.models.llama.modeling_llama.LlamaModel, *, model_config: ModelConfig) -> LlamaModel

Convert the base model: the token embedding table, all n_layers decoder blocks (stacked leaf-wise into a single pytree), and the final norm weight.
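
stack_leaves is defined in lib.tree_utils rather than here; the intent is a leaf-wise stack
that gives every decoder parameter a leading layer axis. A rough sketch of that behaviour
(an assumption about the helper, not its actual implementation):

import jax.numpy as jnp
from jax.tree_util import tree_map

def stack_leaves_sketch(pytrees):
    # Stack corresponding leaves of the per-layer pytrees along a new leading axis.
    return tree_map(lambda *leaves: jnp.stack(leaves), *pytrees)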
def convert_llama(model_pt: transformers.models.llama.modeling_llama.LlamaForCausalLM, *, model_config: ModelConfig) -> Llama

Convert a full LlamaForCausalLM into a Llama (base model plus LM head), with autograd disabled for the duration of the conversion.