Creating a Llama or GPT Model for Next-Token Prediction

import dataclasses

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

 

 

@dataclasses.dataclass
class LlamaConfig:
    """Define Llama model hyperparameters."""
    vocab_size: int = 50000  # Size of the tokenizer vocabulary
    max_position_embeddings: int = 2048  # Maximum sequence length
    hidden_size: int = 768  # Dimension of hidden layers
    intermediate_size: int = 4 * 768  # Dimension of the MLP's hidden layer
    num_hidden_layers: int = 12  # Number of transformer layers
    num_attention_heads: int = 12  # Number of attention heads
    num_key_value_heads: int = 3  # Number of key-value heads for GQA
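With these defaults the model below works out to roughly 141M parameters before the LM head. As a sanity check against the size printed later, here is a quick back-of-envelope sketch based on the layer shapes defined further down in this post:

cfg = LlamaConfig()
head_dim = cfg.hidden_size // cfg.num_attention_heads
# Attention: Q and O are (hidden x hidden); K and V are (hidden x num_kv_heads * head_dim)
attn = 2 * cfg.hidden_size * cfg.hidden_size + 2 * cfg.hidden_size * cfg.num_key_value_heads * head_dim
# SwiGLU MLP: gate, up, and down projections
mlp = 3 * cfg.hidden_size * cfg.intermediate_size
# Two RMSNorm weight vectors per layer
norms = 2 * cfg.hidden_size
per_layer = attn + mlp + norms
total = cfg.vocab_size * cfg.hidden_size + cfg.num_hidden_layers * per_layer + cfg.hidden_size
print(f"{total:,} parameters")  # 141,048,576 -> about 538 MB in float32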

 

 

def rotate_half(x: Tensor) -> Tensor:
    """Rotates half the hidden dims of the input.

    This is a helper function for rotary position embeddings (RoPE).
    For a tensor of shape (..., d), it returns a tensor where the last
    d/2 dimensions are rotated by swapping and negating.

    Args:
        x: Input tensor of shape (..., d)

    Returns:
        Tensor of same shape with rotated last dimension
    """
    x1, x2 = x.chunk(2, dim=-1)  # Split the last dimension into two halves
    return torch.cat((-x2, x1), dim=-1)  # Swap the halves and negate the second one
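For intuition, on a small vector the helper moves the second half to the front with its sign flipped:

v = torch.tensor([1., 2., 3., 4.])
print(rotate_half(v))  # tensor([-3., -4.,  1.,  2.])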

 

 

class RotaryPositionEncoding(nn.Module):
    """Rotary position encoding."""

    def __init__(self, dim: int, max_position_embeddings: int) -> None:
        """Initialize the RotaryPositionEncoding module.

        Args:
            dim: The hidden dimension of the input tensor to which RoPE is applied
            max_position_embeddings: The maximum sequence length of the input tensor
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        # compute a matrix of n * theta_i
        N = 10_000.0
        inv_freq = 1.0 / (N ** (torch.arange(0, dim, 2).float() / dim))
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)  # duplicate the 1D frequencies for both halves
        position = torch.arange(max_position_embeddings).float()
        sinusoid_inp = torch.outer(position, inv_freq)
        # save cosine and sine matrices as buffers, not parameters
        self.register_buffer("cos", sinusoid_inp.cos())
        self.register_buffer("sin", sinusoid_inp.sin())

    def forward(self, x: Tensor) -> Tensor:
        """Apply RoPE to tensor x.

        Args:
            x: Input tensor of shape (batch_size, seq_length, num_heads, head_dim)

        Returns:
            Output tensor of shape (batch_size, seq_length, num_heads, head_dim)
        """
        batch_size, seq_len, num_heads, head_dim = x.shape
        dtype = x.dtype
        # transform the cosine and sine matrices to 4D tensors with the same dtype as x
        cos = self.cos.to(dtype)[:seq_len].view(1, seq_len, 1, head_dim)
        sin = self.sin.to(dtype)[:seq_len].view(1, seq_len, 1, head_dim)
        # apply RoPE to x
        output = (x * cos) + (rotate_half(x) * sin)
        return output
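A quick shape check of the module (the sizes here are arbitrary); the BSHD layout of the input is preserved:

rope = RotaryPositionEncoding(dim=64, max_position_embeddings=128)
q = torch.randn(2, 10, 12, 64)  # (batch, seq, heads, head_dim)
print(rope(q).shape)  # torch.Size([2, 10, 12, 64])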

 

 

class LlamaAttention(nn.Module):
    """Grouped-query attention with rotary embeddings."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_heads = config.num_key_value_heads  # GQA: H_kv < H_q

        # hidden_size must be divisible by num_heads
        assert (self.head_dim * self.num_heads) == self.hidden_size

        # Linear layers for Q, K, V projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        bs, seq_len, dim = hidden_states.size()

        # Project inputs to Q, K, V
        query_states = self.q_proj(hidden_states).view(bs, seq_len, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)

        # Apply rotary position embeddings
        query_states = rope(query_states)
        key_states = rope(key_states)

        # Transpose tensors from BSHD to BHSD layout for scaled_dot_product_attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # Use PyTorch's optimized attention implementation;
        # setting is_causal=True is incompatible with setting an explicit attention mask
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attn_mask,
            dropout_p=0.0,
            enable_gqa=True,
        )

        # Transpose output tensor from BHSD back to BSHD, reshape to 3D, and then project the output
        attn_output = attn_output.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output
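The attention block can be exercised on its own. A minimal sketch with arbitrary sizes (the causal mask is built inline here rather than with the helper defined later in this post):

config = LlamaConfig()
attn = LlamaAttention(config)
rope = RotaryPositionEncoding(config.hidden_size // config.num_attention_heads,
                              config.max_position_embeddings)
h = torch.randn(2, 16, config.hidden_size)  # (batch, seq, hidden)
mask = torch.full((16, 16), float("-inf")).triu(diagonal=1)  # causal mask
print(attn(h, rope=rope, attn_mask=mask).shape)  # torch.Size([2, 16, 768])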

 

 

class LlamaMLP(nn.Module):
    """Feed-forward network with SwiGLU activation."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        # Two parallel projections for SwiGLU
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.act_fn = F.silu  # SwiGLU activation function
        # Project back to hidden size
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # SwiGLU activation: multiply gate and up-projected inputs
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

 

 

class LlamaDecoderLayer(nn.Module):
    """Single transformer layer for a Llama model."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.self_attn = LlamaAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.mlp = LlamaMLP(config)

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        # First residual block: self-attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(hidden_states, rope=rope, attn_mask=attn_mask)
        hidden_states = attn_outputs + residual

        # Second residual block: MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states) + residual
        return hidden_states

 

 

class LlamaModel(nn.Module):
    """The full Llama model without any pretraining heads."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.rotary_emb = RotaryPositionEncoding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings,
        )

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        # Convert input token IDs to embeddings
        hidden_states = self.embed_tokens(input_ids)
        # Process through all transformer layers, then the final norm layer
        for layer in self.layers:
            hidden_states = layer(hidden_states, rope=self.rotary_emb, attn_mask=attn_mask)
        hidden_states = self.norm(hidden_states)
        # Return the final hidden states
        return hidden_states

 

 

class LlamaForPretraining(nn.Module):
    """Llama model with a language-modeling head for next-token prediction."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.base_model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        hidden_states = self.base_model(input_ids, attn_mask)
        # Project hidden states to vocabulary logits
        return self.lm_head(hidden_states)

 

 

def create_causal_mask(seq_len: int, device: torch.device, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create a causal mask for self-attention.

    Args:
        seq_len: Length of the sequence
        device: Device to create the mask on
        dtype: Data type of the mask

    Returns:
        Causal mask of shape (seq_len, seq_len)
    """
    mask = torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=dtype) \
                .triu(diagonal=1)
    return mask
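For intuition, the mask for a 4-token sequence is -inf above the diagonal, so each query position can only attend to itself and earlier positions:

print(create_causal_mask(4, torch.device("cpu")))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])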

 

def create_padding_mask(batch: Tensor, padding_token_id: int, device: torch.device, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create a padding mask for a batch of sequences for self-attention.

    Args:
        batch: Batch of sequences, shape (batch_size, seq_len)
        padding_token_id: ID of the padding token
        device: Device to create the mask on
        dtype: Data type of the mask

    Returns:
        Padding mask of shape (batch_size, 1, seq_len, seq_len)
    """
    padded = torch.zeros_like(batch, device=device, dtype=dtype) \
                  .masked_fill(batch == padding_token_id, float("-inf"))
    mask = padded[:, :, None] + padded[:, None, :]
    return mask[:, None, :, :]

 

 

# Create model with default config
test_config = LlamaConfig()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = LlamaModel(test_config).to(device)
# print the model size
print(f"Model parameters size: {sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2:.2f} MB")
print(f"Model buffers size: {sum(b.numel() * b.element_size() for b in model.buffers()) / 1024**2:.2f} MB")

# Create a random batch of token IDs
PAD_TOKEN_ID = 0
bs, seq_len = 5, 13
x = torch.randint(1, test_config.vocab_size, (bs, seq_len), dtype=torch.int32, device=device)
# set a different number of padding tokens at the end of each sequence
for i, pad_length in enumerate([4, 1, 0, 3, 8]):
    if pad_length > 0:
        x[i, -pad_length:] = PAD_TOKEN_ID
# Create causal and padding masks
causal_mask = create_causal_mask(seq_len, device)
padding_mask = create_padding_mask(x, PAD_TOKEN_ID, device)
attn_mask = causal_mask + padding_mask
print(f"Input ids: {x}")
print(f"Attention mask: {attn_mask}")

# Run the model
output = model(x, attn_mask)
print("OK")
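To close the loop on next-token prediction, here is a minimal sketch of computing a pretraining loss with LlamaForPretraining on the same batch. The label shifting and padding handling below are one reasonable choice, not the only one:

pretrain_model = LlamaForPretraining(test_config).to(device)
logits = pretrain_model(x, attn_mask)  # (batch, seq_len, vocab_size)

# Shift so that positions 0..n-2 predict tokens 1..n-1
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = x[:, 1:].long().contiguous()
# Ignore padding tokens in the loss
shift_labels = shift_labels.masked_fill(shift_labels == PAD_TOKEN_ID, -100)
loss = F.cross_entropy(
    shift_logits.view(-1, test_config.vocab_size),
    shift_labels.view(-1),
    ignore_index=-100,
)
print(f"Next-token prediction loss: {loss.item():.4f}")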
