import dataclasses
import tokenizers
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
# Model architecture (must match the training script so the checkpoint loads)
@dataclasses.dataclass
class LlamaConfig:
    """Define Llama model hyperparameters."""
vocab_size: int = 50000
max_position_embeddings: int = 2048
hidden_size: int = 768
intermediate_size: int = 4*768
num_hidden_layers: int = 12
num_attention_heads: int = 12
num_key_value_heads: int = 3
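    # With these defaults each attention head gets hidden_size // num_attention_heads
    # = 64 dimensions, and the 12 query heads share 3 key/value heads
    # (4 query heads per KV head) under grouped-query attention.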
class RotaryPositionEncoding(nn.Module):
    """Rotary position encoding."""
def __init__(self, dim: int, max_position_embeddings: int) -> None:
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
N = 10_000.0
inv_freq = 1.0 / (N ** (torch.arange(0, dim, 2) / dim))
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
position = torch.arange(max_position_embeddings)
sinusoid_inp = torch.outer(position, inv_freq)
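        # sinusoid_inp has shape (max_position_embeddings, dim); caching its
        # cos/sin as buffers keeps them out of the trainable parameters while
        # still moving them with the module across devices.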
        self.register_buffer("cos", sinusoid_inp.cos())
        self.register_buffer("sin", sinusoid_inp.sin())
def forward(self, x: Tensor) -> Tensor:
batch_size, seq_len, num_heads, head_dim = x.shape
device = x.device
dtype = x.dtype
        cos = self.cos.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
        sin = self.sin.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
return (x * cos) + (rotated * sin)
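# Illustrative sanity check of the RoPE module above (not executed here),
# assuming a hypothetical (batch=2, seq=16, heads=12, head_dim=64) query tensor:
# the output keeps the input shape, with each position rotated by its own angle.
#   rope = RotaryPositionEncoding(dim=64, max_position_embeddings=2048)
#   q = torch.randn(2, 16, 12, 64)
#   assert rope(q).shape == q.shape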
class LlamaAttention(nn.Module):
    """Grouped-query attention with rotary embeddings."""
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_kv_heads = config.num_key_value_heads
assert (self.head_dim * self.num_heads) == self.hidden_size
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding) -> Tensor:
bs, seq_len, dim = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bs, seq_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
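        # With enable_gqa=True, scaled_dot_product_attention (PyTorch 2.5+)
        # broadcasts the num_kv_heads key/value heads across the num_heads
        # query heads, so K/V never need to be repeated manually.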
attn_output = F.scaled_dot_product_attention(
rope(query_states).transpose(1, 2),
rope(key_states).transpose(1, 2),
value_states.transpose(1, 2),
is_causal=True,
dropout_p=0.0,
enable_gqa=True,
)
attn_output = attn_output.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)
return self.o_proj(attn_output)
class LlamaMLP(nn.Module):
    """Feed-forward network with SwiGLU activation."""
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.act_fn = F.silu
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
def forward(self, x: Tensor) -> Tensor:
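        # SwiGLU: SiLU(gate_proj(x)) acts as a learned gate that is multiplied
        # element-wise with up_proj(x) before the down projection.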
gate = self.act_fn(self.gate_proj(x))
up = self.up_proj(x)
return self.down_proj(gate * up)
class LlamaDecoderLayer(nn.Module):
    """Single transformer layer for a Llama model."""
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
self.self_attn = LlamaAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
self.mlp = LlamaMLP(config)
def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding) -> Tensor:
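        # Pre-norm residual block: normalize -> attention -> add residual,
        # then normalize -> MLP -> add residual.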
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_outputs = self.self_attn(hidden_states, rope=rope)
hidden_states = attn_outputs + residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
return self.mlp(hidden_states) + residual
class LlamaModel(nn.Module):
    """The full Llama model without any pretraining heads."""
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.rotary_emb = RotaryPositionEncoding(
config.hidden_size // config.num_attention_heads,
config.max_position_embeddings,
)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
        self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)
def forward(self, input_ids: Tensor) -> Tensor:
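        # Embed token ids, pass them through every decoder layer (sharing one
        # RoPE module), and apply the final RMSNorm.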
hidden_states = self.embed_tokens(input_ids)
for layer in self.layers:
hidden_states = layer(hidden_states, rope=self.rotary_emb)
return self.norm(hidden_states)
class LlamaForPretraining(nn.Module):
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.base_model = LlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
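        # Note: the output head is a separate projection; its weights are not
        # tied to embed_tokens.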
def forward(self, input_ids: Tensor) -> Tensor:
hidden_states = self.base_model(input_ids)
return self.lm_head(hidden_states)
def apply_repetition_penalty(logits: Tensor, tokens: list[int], penalty: float) -> Tensor:
    """Apply repetition penalty to the logits."""
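    # For penalty > 1, dividing a positive logit or multiplying a negative
    # logit both lower that token's probability, discouraging repeats.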
for tok in tokens:
if logits[tok] > 0:
logits[tok] /= penalty
else:
logits[tok] *= penalty
return logits
@torch.no_grad()
def generate(model, tokenizer, prompt, max_tokens=100, temperature=1.0, repetition_penalty=1.0,
repetition_penalty_range=10, top_k=50, device=None) -> str:
    """Generate text autoregressively from a prompt.
Args:
model: The trained LlamaForPretraining model
tokenizer: The tokenizer
prompt: Input text prompt
max_tokens: Maximum number of tokens to generate
temperature: Sampling temperature (higher = more random)
repetition_penalty: Penalty for repeating tokens
repetition_penalty_range: Number of previous tokens to consider for repetition penalty
top_k: Only sample from top k most likely tokens
device: Device the model is loaded on
Returns:
Generated text
    """
    # Switch to evaluation mode (disables training-only behavior such as dropout)
model.eval()
# Get special token IDs
    bot_id = tokenizer.token_to_id("[BOT]")
    eot_id = tokenizer.token_to_id("[EOT]")
# Tokenize the prompt into integer tensor
    prompt_tokens = [bot_id] + tokenizer.encode(" " + prompt).ids
input_ids = torch.tensor([prompt_tokens], dtype=torch.int64, device=device)
    # Autoregressively generate tokens, one per loop iteration
generated_tokens = []
for _step in range(max_tokens):
# Forward pass through model
logits = model(input_ids)
# Get logits for the last token
        next_token_logits = logits[0, -1, :] / temperature
# Apply repetition penalty
if repetition_penalty != 1.0 and len(generated_tokens) > 0:
next_token_logits = apply_repetition_penalty(
next_token_logits,
                generated_tokens[-repetition_penalty_range:],
repetition_penalty,
)
# Apply top-k filtering
if top_k > 0:
top_k_logits = torch.topk(next_token_logits, top_k)[0]
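            # torch.topk returns values sorted in descending order, so
            # top_k_logits[-1] is the k-th largest logit.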
            indices_to_remove = next_token_logits < top_k_logits[-1]
            next_token_logits[indices_to_remove] = float("-inf")
# Sample from the filtered distribution
        probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
# Early stop if EOT token is generated
if next_token.item() == eot_id:
break
# Append the new token to input_ids for next iteration
input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
generated_tokens.append(next_token.item())
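        # Note: input_ids grows by one token per step and is never truncated,
        # so the prompt length plus max_tokens should stay within
        # config.max_position_embeddings.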
# Decode all generated tokens
return tokenizer.decode(generated_tokens)
checkpoint = "llama_model_final.pth"  # saved model checkpoint
tokenizer_file = "bpe_50K.json"  # saved tokenizer
max_tokens = 100
temperature = 0.9
top_k = 50
penalty = 1.1
penalty_range = 10
# Load tokenizer and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
config = LlamaConfig()
model = LlamaForPretraining(config).to(device)
model.load_state_dict(torch.load(checkpoint, map_location=device))
prompt = "Once upon a time, there was"
response = generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
top_k=top_k,
repetition_penalty=penalty,
repetition_penalty_range=penalty_range,
device=device,
)
print(prompt)
print("-" * 20)
print(response)