long-context

Extend context windows of transformer models using RoPE, YaRN, ALiBi, and position interpolation techniques. Use when processing long documents (32k-128k+…

INSTALLATION
npx skills add https://github.com/davila7/claude-code-templates --skill long-context
Run in your project or agent environment. Adjust flags if your CLI version differs.

SKILL.md

Long Context: Extending Transformer Context Windows

When to Use This Skill

Use Long Context techniques when you need to:

  • Process long documents (32k, 64k, 128k+ tokens) with transformer models
  • Extend context windows of pre-trained models (LLaMA, Mistral, etc.)
  • Implement efficient positional encodings (RoPE, ALiBi)
  • Train models with length extrapolation capabilities
  • Deploy models that handle variable-length inputs efficiently
  • Fine-tune existing models for longer contexts with minimal compute

Key Techniques: RoPE (Rotary Position Embeddings), YaRN, ALiBi (Attention with Linear Biases), Position Interpolation

Papers: RoFormer (arXiv 2104.09864), YaRN (arXiv 2309.00071), ALiBi (arXiv 2108.12409), Position Interpolation (arXiv 2306.15595)

Installation

# HuggingFace Transformers (includes RoPE, YaRN support)
pip install transformers torch

# For custom implementations
pip install einops  # Tensor operations
pip install rotary-embedding-torch  # Standalone RoPE

# Optional: FlashAttention for efficiency
pip install flash-attn --no-build-isolation

Quick Start

RoPE (Rotary Position Embeddings)

import torch
import torch.nn as nn

class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings (RoPE)."""
    def __init__(self, dim, max_seq_len=8192, base=10000):
        super().__init__()
        # Inverse frequencies theta_j = base^(-2j/dim), one per pair of dimensions
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len  # informational; cos/sin are computed per call

    def forward(self, seq_len, device):
        # Position indices 0..seq_len-1 (type_as also moves them to inv_freq's device)
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        # Angle for every (position, frequency) pair
        freqs = torch.outer(t, self.inv_freq)  # (seq_len, dim/2)
        # Duplicate so dims j and j + dim/2 share an angle (matches rotate_half)
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
        return emb.cos(), emb.sin()

def rotate_half(x):
    """Rotate half the hidden dimensions: (x1, x2) -> (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary embeddings to queries and keys."""
    # q, k shape: (batch, heads, seq_len, dim); cos, sin broadcast from (seq_len, dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# Usage (move the module to the target device so inv_freq lives there too)
device = "cuda" if torch.cuda.is_available() else "cpu"
rope = RotaryEmbedding(dim=64, max_seq_len=8192).to(device)
cos, sin = rope(seq_len=2048, device=device)

# In attention layer (random stand-ins for the real query/key tensors)
query = torch.randn(1, 8, 2048, 64, device=device)
key = torch.randn(1, 8, 2048, 64, device=device)
q_rotated, k_rotated = apply_rotary_pos_emb(query, key, cos, sin)

ALiBi (Attention with Linear Biases)

import math
import torch

def get_alibi_slopes(num_heads):
    """Get ALiBi slope values for each attention head."""
    def get_slopes_power_of_2(n):
        # Geometric sequence starting at 2^(-8/n) with the same ratio
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]

    if math.log2(num_heads).is_integer():
        return get_slopes_power_of_2(num_heads)
    else:
        # Closest lower power of 2
        closest_power = 2 ** math.floor(math.log2(num_heads))
        slopes = get_slopes_power_of_2(closest_power)
        # Fill the remaining heads with every other slope of the next power of 2
        extra = get_slopes_power_of_2(2 * closest_power)
        slopes.extend(extra[0::2][:num_heads - closest_power])
        return slopes

def create_alibi_bias(seq_len, num_heads):
    """Create ALiBi attention bias: -m * |i - j|."""
    # Distance matrix |i - j| between query position i and key position j
    context_position = torch.arange(seq_len)
    memory_position = torch.arange(seq_len)
    distance = (memory_position[None, :] - context_position[:, None]).abs()
    # Get slopes
    slopes = torch.tensor(get_alibi_slopes(num_heads))
    # Penalize attention linearly with distance
    alibi = -slopes[:, None, None] * distance[None, :, :]
    return alibi  # (num_heads, seq_len, seq_len)

# Usage in attention
device = "cuda" if torch.cuda.is_available() else "cpu"
num_heads = 8
seq_len = 2048
alibi_bias = create_alibi_bias(seq_len, num_heads).to(device)

# Add bias to pre-softmax attention scores (random stand-in scores here)
# attn_scores shape: (batch, num_heads, seq_len, seq_len)
attn_scores = torch.randn(1, num_heads, seq_len, seq_len, device=device)
attn_scores = attn_scores + alibi_bias
# Causal models must also mask future keys (j > i) before the softmax
attn_weights = torch.softmax(attn_scores, dim=-1)
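A causal language model also has to combine the ALiBi bias with a causal mask. The sketch below assumes PyTorch 2.x and a power-of-two head count, and the helper names `alibi_slopes` and `causal_alibi_mask` are hypothetical. It folds the bias and the `-inf` causal mask into a single float `attn_mask` for `F.scaled_dot_product_attention`, then checks the result against an explicit softmax. The mask already carries the causal structure, so `is_causal` stays at its default of `False`; PyTorch raises an error if both are set.

```python
import math
import torch
import torch.nn.functional as F

def alibi_slopes(num_heads: int) -> torch.Tensor:
    # Assumption of this sketch: power-of-two head count, so the slopes are
    # the geometric sequence m_h = 2^(-8h/n) for h = 1..n
    if not math.log2(num_heads).is_integer():
        raise ValueError("sketch assumes a power-of-two number of heads")
    h = torch.arange(1, num_heads + 1, dtype=torch.float32)
    return 2.0 ** (-8.0 * h / num_heads)

def causal_alibi_mask(seq_len: int, num_heads: int) -> torch.Tensor:
    pos = torch.arange(seq_len)
    rel = (pos[None, :] - pos[:, None]).float()                # j - i (<= 0 for past keys)
    bias = alibi_slopes(num_heads)[:, None, None] * rel[None]  # -m * (i - j)
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    return bias.masked_fill(future, float("-inf"))             # (heads, seq, seq)

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 8, 16, 32
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))
mask = causal_alibi_mask(seq_len, heads)

# SDPA adds a float attn_mask to QK^T / sqrt(head_dim) before the softmax
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Explicit reference computation of the same attention
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim) + mask
ref = torch.softmax(scores, dim=-1) @ v
print(torch.allclose(out, ref, atol=1e-5))
```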

Position Interpolation for LLaMA

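from transformers import LlamaForCausalLM

# Llama 2's native context: 4096 tokens
# rope_scaling is read when the model is built, so pass it to from_pretrained;
# editing model.config on an already-loaded model does not change its RoPE

# Extend to 32k with position interpolation (linearly scales position indices)
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    rope_scaling={"type": "linear", "factor": 8.0},  # 4096 * 8 = 32768
)

# Or use dynamic (NTK) scaling, which enlarges the RoPE base once inputs exceed 4096 tokens
# model = LlamaForCausalLM.from_pretrained(
#     "meta-llama/Llama-2-7b-hf",
#     rope_scaling={"type": "dynamic", "factor": 8.0},
# )

# Linear interpolation still needs a short fine-tune on long documents
# (the PI paper reports ~1000 steps); dynamic scaling is commonly used
# at inference without fine-tuning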

Core Concepts

1. RoPE (Rotary Position Embeddings)

How it works:

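  • Encodes absolute position via rotation matrix
  • Provides relative position dependency in attention
  • Supports variable sequence lengths, but extrapolates poorly far beyond the trained length (the motivation for Position Interpolation and YaRN)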

Mathematical formulation:

q_m = (W_q * x_m) * e^(imθ)

k_n = (W_k * x_n) * e^(inθ)

where θ_j = base^(-2j/d) for j ∈ [0, d/2)
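The relative-position property can be checked numerically. The sketch below runs in float64 on CPU, and the function `rope_rotate` is a hypothetical helper. It applies q·e^(imθ) to the dimension pairs (j, j + d/2), the same pairing `rotate_half` uses in the Quick Start, and shows that the score between a query at position m and a key at position n depends only on m - n.

```python
import torch

def rope_rotate(x: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    """Rotate a d-dim vector to position `pos` using the complex form x * e^(i*pos*theta)."""
    d = x.shape[-1]
    theta = base ** (-torch.arange(0, d, 2, dtype=torch.float64) / d)  # theta_j = base^(-2j/d)
    # Pair dim j with dim j + d/2, the same split rotate_half uses
    xc = torch.complex(x[: d // 2], x[d // 2 :])
    rot = torch.polar(torch.ones_like(theta), pos * theta)            # e^(i*pos*theta_j)
    yc = xc * rot
    return torch.cat([yc.real, yc.imag])

torch.manual_seed(0)
q = torch.randn(64, dtype=torch.float64)
k = torch.randn(64, dtype=torch.float64)

# Same offset m - n = 3 at very different absolute positions gives the same score
s_near = torch.dot(rope_rotate(q, 5), rope_rotate(k, 2))
s_far = torch.dot(rope_rotate(q, 1005), rope_rotate(k, 1002))
print(torch.allclose(s_near, s_far))  # True
```

Rotations preserve vector norms, so RoPE changes only the angles between queries and keys, never their magnitudes.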

Advantages:

  • Decaying inter-token dependency with distance
  • Compatible with linear attention
  • Modestly better extrapolation than sinusoidal absolute encodings, though quality still degrades well beyond the training length

2. YaRN (Yet another RoPE extensioN)

Key innovation:

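  • NTK-by-parts interpolation: high-frequency RoPE dimensions keep their original frequencies, low-frequency dimensions are interpolated (builds on NTK-aware scaling)
  • Attention temperature scaling
  • Efficient context extension (10× fewer training tokens than previous methods)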

Parameters:

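# YaRN configuration
yarn_config = {
    "scale": 16,                    # Extension factor s (new context / original context)
    "original_max_position": 2048,  # Context length the model was trained on
    "extrapolation_factor": 1.0,    # Weight of the keep-original-frequency mask (NTK-by-parts)
    "attn_factor": 1.0,             # Multiplier on the attention temperature term
    "beta_fast": 32,                # Dims rotating more than 32 times over the original context keep their frequency
    "beta_slow": 1,                 # Dims rotating less than once are fully interpolated
}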
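The sketch below shows how YaRN's NTK-by-parts frequencies and attention scale can be computed from these parameters. It was written from the paper's description and the public reference code, so verify details such as the rounding of the correction range against the repository; `yarn_inv_freq` is a hypothetical name. The returned `mscale` is meant to multiply both cos and sin, which scales every attention logit by mscale².

```python
import math
import torch

def _correction_dim(num_rotations, dim, base, original_max_pos):
    # Dimension index whose RoPE wavelength fits `num_rotations` times into the original context
    return (dim * math.log(original_max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

def yarn_inv_freq(dim, scale, original_max_pos, base=10000.0,
                  beta_fast=32, beta_slow=1, extrapolation_factor=1.0, attn_factor=1.0):
    pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs            # original RoPE frequencies
    inv_freq_interpolation = 1.0 / (scale * pos_freqs)  # Position Interpolation frequencies

    # Linear ramp between dims rotating > beta_fast times and dims rotating < beta_slow times
    low = max(math.floor(_correction_dim(beta_fast, dim, base, original_max_pos)), 0)
    high = min(math.ceil(_correction_dim(beta_slow, dim, base, original_max_pos)), dim - 1)
    if low == high:
        high += 0.001  # avoid dividing by zero
    ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low), 0, 1)

    # mask = 1: keep original frequency (high-frequency dims); mask = 0: interpolate
    mask = (1 - ramp) * extrapolation_factor
    inv_freq = inv_freq_interpolation * (1 - mask) + inv_freq_extrapolation * mask

    # Attention temperature: sqrt(1/t) = 0.1 * ln(s) + 1, applied by scaling cos and sin
    mscale = (0.1 * math.log(scale) + 1.0 if scale > 1 else 1.0) * attn_factor
    return inv_freq, mscale

inv_freq, mscale = yarn_inv_freq(dim=128, scale=16, original_max_pos=4096)
print(inv_freq[:3], inv_freq[-3:], mscale)
```

The fastest-rotating dimensions come out unchanged and the slowest are divided by `scale`. The dimensions in between are blended, which is what separates YaRN from plain linear interpolation.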

Performance:

  • Extends Llama 2 to 64k and 128k-token context windows
  • 2.5× fewer training steps than previous methods
  • Reported state-of-the-art context-extension results at publication

3. ALiBi (Attention with Linear Biases)

Core idea:

  • No positional embeddings added to tokens
  • Apply distance penalty directly to attention scores
  • Bias proportional to key-query distance

Formula:

attention_bias[i, j] = -m * |i - j|

where m = slope for each attention head
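For a power-of-two number of heads n, the slopes are the geometric sequence m_h = 2^(-8h/n) for h = 1..n. With 8 heads that gives 1/2, 1/4, ..., 1/256, which is what `get_slopes_power_of_2` in the Quick Start computes. Below is a dependency-free sketch of the slopes and of the bias formula above; the helper names are hypothetical.

```python
import math

def alibi_slopes(num_heads):
    # Power-of-two head counts only: m_h = 2^(-8h/n) for h = 1..n
    if not math.log2(num_heads).is_integer():
        raise ValueError("sketch assumes a power-of-two number of heads")
    return [2 ** (-8 * h / num_heads) for h in range(1, num_heads + 1)]

def alibi_bias(seq_len, slope):
    # attention_bias[i][j] = -m * |i - j|
    return [[-slope * abs(i - j) for j in range(seq_len)] for i in range(seq_len)]

slopes = alibi_slopes(8)
print(slopes)                       # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
print(alibi_bias(4, slopes[0])[0])  # [-0.0, -0.5, -1.0, -1.5]
```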

Advantages:

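  • 11% faster training and 11% less memory than a sinusoidal model at equal perplexity (in the paper's 1.3B-parameter comparison: ALiBi trained on 1024 tokens vs. sinusoidal trained on 2048, both evaluated at 2048)
  • Strong length extrapolation (train 1k, test 2k+)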
  • Inductive bias towards recency

4. Position Interpolation

Technique:

  • Linearly down-scale position indices
  • Interpolate within trained range (vs extrapolate beyond)
  • Minimal fine-tuning required

Formula:

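# Original: position indices [0, 1, 2, ..., L-1]  (L = trained context length)
# Extended input of 2L tokens: positions [0, 1, 2, ..., 2L-1]
# Interpolated (for 2× extension): [0, 0.5, 1.0, ..., L-0.5], all inside the trained range
scaled_position[i] = i / extension_factor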
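In code, Position Interpolation changes only the position indices fed to RoPE. The sketch below uses the same frequency layout as the Quick Start `RotaryEmbedding`, and the function name `rope_cos_sin` is hypothetical. Dividing positions by the factor is mathematically equivalent to dividing the frequencies by it. The last lines confirm that every other extended position lands exactly on an original one.

```python
import torch

def rope_cos_sin(seq_len, dim, base=10000.0, extension_factor=1.0):
    """cos/sin tables for RoPE, with Position Interpolation when extension_factor > 1."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # scaled_position[i] = i / extension_factor keeps every position inside the trained range
    positions = torch.arange(seq_len, dtype=torch.float32) / extension_factor
    freqs = torch.outer(positions, inv_freq)  # (seq_len, dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, dim), same layout as RotaryEmbedding
    return emb.cos(), emb.sin()

# Model trained on 2048 positions, run on 4096 tokens with a 2x factor
cos_ext, sin_ext = rope_cos_sin(4096, 64, extension_factor=2.0)
cos_orig, sin_orig = rope_cos_sin(2048, 64)
# Extended positions 0, 2, 4, ... coincide with original positions 0, 1, 2, ...
print(torch.allclose(cos_ext[::2], cos_orig))
```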

Results:

  • LLaMA 7B-65B extended to 32k (32768) tokens
  • Within 1000 fine-tuning steps
  • Theoretical upper bound on attention-score error is ~600× smaller for interpolation than for extrapolation

Method Comparison

Method | Max Context | Training Needed | Memory | Extrapolation | Best For
RoPE | 8k-32k | Full pre-training | Moderate | Limited beyond training length | New models
YaRN | 32k-128k | Minimal (10× fewer tokens than prior methods) | Moderate | Excellent | Extending existing models
ALiBi | No fixed limit | Full pre-training | Low (-11% in the paper's comparison) | Excellent | Training from scratch
Position Interpolation | 32k+ | Minimal (~1k steps) | Moderate | Poor (by design) | Quick extension

Implementation Patterns

HuggingFace Transformers Integration

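from transformers import AutoModelForCausalLM, AutoConfig

# RoPE with YaRN scaling
# rope_scaling support (and accepted keys) varies by architecture and transformers version
config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
config.rope_scaling = {
    "type": "yarn",
    "factor": 8.0,
    "original_max_position_embeddings": 8192,
    # "attention_factor" omitted: when absent, transformers derives it from factor
    # (0.1 * ln(factor) + 1); setting 1.0 would switch off YaRN's attention scaling
}
# from_pretrained keeps the trained weights; from_config would build a randomly initialized model
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", config=config)

# Alternatives: set one of these on the config, then load the model as above
# Position interpolation (simpler)
config.rope_scaling = {
    "type": "linear",
    "factor": 4.0
}
# Dynamic scaling (adjusts based on input length)
config.rope_scaling = {
    "type": "dynamic",
    "factor": 8.0
}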
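`"dynamic"` scaling keeps the original RoPE while the input fits in the trained window and enlarges the RoPE base once it does not (dynamic NTK scaling). The sketch below mirrors the base-update formula commonly used for Hugging Face's dynamic scaling; check your installed transformers version before relying on exact numbers. `dynamic_ntk_inv_freq` is a hypothetical name.

```python
import torch

def dynamic_ntk_inv_freq(seq_len, dim, original_max_pos, factor, base=10000.0):
    """RoPE inverse frequencies with a length-dependent ("dynamic NTK") base."""
    # Inside the trained window the base is unchanged
    seq_len = max(seq_len, original_max_pos)
    # Beyond it, grow the base so the slowest dimensions stretch to cover the longer input
    new_base = base * ((factor * seq_len / original_max_pos) - (factor - 1)) ** (dim / (dim - 2))
    return 1.0 / (new_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

short_ctx = dynamic_ntk_inv_freq(seq_len=2048, dim=128, original_max_pos=4096, factor=8.0)
long_ctx = dynamic_ntk_inv_freq(seq_len=32768, dim=128, original_max_pos=4096, factor=8.0)
print(short_ctx[-1].item(), long_ctx[-1].item())
```

Because the base depends on the current length, no fine-tuning is strictly required. The trade-off is that the same token sees different frequencies as the input grows.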

Custom RoPE Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

# Uses RotaryEmbedding and apply_rotary_pos_emb from the Quick Start
class LongContextAttention(nn.Module):
    """Causal multi-head attention with RoPE."""
    def __init__(self, hidden_size, num_heads, max_seq_len=32768):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Q, K, V projections
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)
        # RoPE
        self.rotary_emb = RotaryEmbedding(
            dim=self.head_dim,
            max_seq_len=max_seq_len
        )

    def forward(self, hidden_states):
        batch_size, seq_len, _ = hidden_states.shape
        # Project to Q, K, V
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        # Reshape for multi-head: (batch, heads, seq_len, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Apply RoPE (cast cos/sin to the activation dtype, e.g. fp16/bf16)
        cos, sin = self.rotary_emb(seq_len, device=hidden_states.device)
        q, k = apply_rotary_pos_emb(q, k, cos.to(q.dtype), sin.to(q.dtype))
        # Causal scaled dot-product attention (dispatches to fused kernels when available)
        attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Reshape and project
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, -1)
        output = self.o_proj(attn_output)
        return output

Fine-tuning for Long Context

Minimal Fine-tuning (Position Interpolation)

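from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Extend model config at load time: rope_scaling is read when the model is built,
# so assigning model.config.rope_scaling afterwards does not update the rotary embedding
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,  # placeholder: a LLaMA checkpoint with a 2048-token native context
    max_position_embeddings=32768,
    rope_scaling={"type": "linear", "factor": 16.0},  # 2048 * 16 = 32768 (use 8.0 for Llama 2's 4096)
)

# Training args (minimal steps needed)
training_args = TrainingArguments(
    output_dir="./llama-32k",
    num_train_epochs=1,
    max_steps=1000,           # Only 1000 steps! (max_steps overrides num_train_epochs)
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-5,
    warmup_steps=100,
    logging_steps=10,
    save_steps=500,
)

# Train on long documents
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=long_document_dataset,  # placeholder: tokenized 32k-token sequences
)
trainer.train()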
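`long_document_dataset` is left undefined above. The sketch below shows one way to build it, assuming each document has already been tokenized into a Python list of token IDs; `chunk_documents` is a hypothetical helper. Each document is cut into fixed-length blocks on its own, so every training example is a contiguous span of a single long document. The resulting list of dicts can then be wrapped in a dataset object, for example with the `datasets` library.

```python
def chunk_documents(tokenized_docs, block_size):
    """Split each tokenized document into non-overlapping blocks of exactly block_size tokens.

    Documents are never concatenated with each other, and tails shorter than
    block_size are dropped.
    """
    examples = []
    for ids in tokenized_docs:
        for start in range(0, len(ids) - block_size + 1, block_size):
            block = ids[start:start + block_size]
            # Causal LM: labels are the inputs; the model shifts them internally
            examples.append({"input_ids": block, "labels": list(block)})
    return examples

# Toy example with block_size=4; real runs would use 32768
examples = chunk_documents([list(range(10)), list(range(3))], block_size=4)
print([ex["input_ids"] for ex in examples])  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```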

YaRN Fine-tuning

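# Clone the YaRN reference implementation
git clone https://github.com/jquesnelle/yarn
cd yarn

# Fine-tune LLaMA with YaRN
# NOTE: the script path and flag names below are illustrative; check the
# repository README for the current training entry point and its arguments
python scripts/train.py \
    --model meta-llama/Llama-2-7b-hf \
    --scale 16 \
    --rope_theta 10000 \
    --max_length 32768 \
    --batch_size 1 \
    --gradient_accumulation 16 \
    --steps 400 \
    --learning_rate 2e-5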

Best Practices

1. Choose the Right Method

# For NEW models (training from scratch)
use_method = "ALiBi"  # Strong extrapolation, low memory

# For EXTENDING existing RoPE models
use_method = "YaRN"  # Most efficient extension (10× fewer training tokens)

# For QUICK or MODERATE extension with minimal compute
use_method = "Position Interpolation"  # ~1000 steps; built into HF as linear RoPE scaling

2. Scaling Factor Selection

# Conservative (safer, better quality)
scaling_factor = 2.0  # 8k → 16k

# Moderate (good balance)
scaling_factor = 4.0  # 8k → 32k

# Aggressive (requires more fine-tuning)
scaling_factor = 8.0  # 8k → 64k
scaling_factor = 16.0  # 8k → 128k

# Rule: Larger factors need more fine-tuning steps
steps_needed = 100 * scaling_factor  # Rough heuristic only; PI reports ~1000 steps at 16×, YaRN 400 steps at 16×

3. Fine-tuning Data

# ✅ Good: Long documents matching target length
train_data = [
    {"text": long_doc_32k_tokens},  # Full 32k
    {"text": long_doc_24k_tokens},  # Varied lengths
    {"text": long_doc_16k_tokens},
]

# ❌ Bad: Short documents (won't learn long context)
train_data = [
    {"text": short_doc_2k_tokens},
]

# Use datasets like:
# - PG-19 (books, long texts)
# - arXiv papers
# - Long-form conversations
# - GitHub repositories (concatenated files)
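The sketch below filters a corpus down to documents in the target length range. `select_long_documents` is a hypothetical helper, and `count_tokens` is any callable that returns a token count. With a Hugging Face tokenizer it could be `lambda t: len(tokenizer(t)["input_ids"])`.

```python
def select_long_documents(docs, count_tokens, min_tokens, max_tokens):
    """Keep documents whose token count falls in [min_tokens, max_tokens]."""
    kept = []
    for text in docs:
        n = count_tokens(text)
        if min_tokens <= n <= max_tokens:
            kept.append({"text": text, "num_tokens": n})
    return kept

# Toy example: whitespace-separated words stand in for real tokens
docs = ["word " * 10, "word " * 30, "word " * 50]
kept = select_long_documents(docs, lambda t: len(t.split()), min_tokens=20, max_tokens=40)
print([d["num_tokens"] for d in kept])  # [30]
```

Running it with a few overlapping ranges (for example 16k-24k, 24k-32k) is one way to get the varied lengths recommended above.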

4. Avoid Common Pitfalls

# ❌ Bad: Applying position interpolation without fine-tuning
model.config.rope_scaling = {"type": "linear", "factor": 16.0}
# Model will perform poorly without fine-tuning!

# ✅ Good: Fine-tune after scaling
model.config.rope_scaling = {"type": "linear", "factor": 16.0}
fine_tune(model, long_documents, steps=1000)
# (pseudocode: with HF transformers, pass rope_scaling when loading the model)

# ❌ Bad: Too aggressive scaling without data
scale_to_1M_tokens()  # Won't work without massive fine-tuning

# ✅ Good: Incremental scaling
# 8k → 16k → 32k → 64k (fine-tune at each step)
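The incremental schedule above can be planned ahead of time. The sketch below uses a hypothetical helper that doubles the window at each stage and reports the RoPE scaling factor relative to the original pre-training context (not the previous stage). These match the factors in Scaling Factor Selection.

```python
def incremental_schedule(original_ctx, target_ctx):
    """Return (context_length, rope_scaling_factor) per stage, doubling the window each stage."""
    stages = []
    ctx = original_ctx
    while ctx < target_ctx:
        ctx = min(ctx * 2, target_ctx)
        stages.append((ctx, ctx / original_ctx))
    return stages

for ctx, factor in incremental_schedule(8192, 65536):
    print(f"fine-tune at {ctx} tokens with rope_scaling factor {factor}")
```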

Production Deployment

Inference with Long Context

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load long-context model
model = AutoModelForCausalLM.from_pretrained(
    "togethercomputer/LLaMA-2-7B-32K",  # 32k context
    torch_dtype=torch.float16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("togethercomputer/LLaMA-2-7B-32K")

# Process long document
long_text = open("long_document.txt").read()  # placeholder: your ~30k-token document
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to(model.device)

# Generate (prompt length + max_new_tokens must fit in the 32k window)
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    do_sample=True,   # temperature only takes effect when sampling
    temperature=0.7,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
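With `truncation=False`, a prompt longer than the model's window is passed through as-is, and generation past the trained length is likely to degrade or fail. The sketch below reserves room for `max_new_tokens` by trimming the prompt, assuming the prompt is a plain list of token IDs; `fit_to_context` is hypothetical. Keeping the tail preserves the text closest to the question. For the tensors returned by the tokenizer, the equivalent slice is `inputs["input_ids"][:, -budget:]`, applied to `attention_mask` as well.

```python
def fit_to_context(input_ids, max_context, max_new_tokens, keep="tail"):
    """Trim a prompt so that prompt + generated tokens fit in the context window."""
    budget = max_context - max_new_tokens
    if budget <= 0:
        raise ValueError("max_new_tokens leaves no room for the prompt")
    if len(input_ids) <= budget:
        return input_ids
    # Keep the end of the prompt (closest to the question) or its beginning
    return input_ids[-budget:] if keep == "tail" else input_ids[:budget]

prompt_ids = list(range(40_000))  # stand-in for a tokenized ~40k-token prompt
trimmed = fit_to_context(prompt_ids, max_context=32768, max_new_tokens=512)
print(len(trimmed))  # 32256
```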

Memory Optimization

import torch
from transformers import AutoModelForCausalLM

# Use gradient checkpointing for fine-tuning (recomputes activations to save memory)
model.gradient_checkpointing_enable()

# Use Flash Attention 2 (requires the flash-attn package, a supported GPU, and fp16/bf16)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",  # faster, memory-efficient attention
    torch_dtype=torch.float16
)

# Use paged attention (vLLM)
from vllm import LLM

llm = LLM(
    model="togethercomputer/LLaMA-2-7B-32K",
    max_model_len=32768,  # 32k context
    gpu_memory_utilization=0.9
)
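At long context, much of the memory goes to the KV cache, which grows linearly with sequence length. The back-of-the-envelope sketch below uses a hypothetical helper and assumes Llama-2-7B dimensions (32 layers, 32 key/value heads, head dimension 128) with fp16 values.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_value=2, batch_size=1):
    """K and V tensors per layer, each of shape (batch, kv_heads, seq_len, head_dim)."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value * batch_size

size = kv_cache_bytes(num_layers=32, num_kv_heads=32, head_dim=128, seq_len=32768)
print(f"{size / 2**30:.1f} GiB")  # 16.0 GiB
```

At batch size 1 and 32k tokens, the cache alone is about 16 GiB on top of the fp16 weights. This is the memory that paged attention in vLLM allocates in blocks instead of one contiguous buffer.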

Resources

See Also

  • references/rope.md - Detailed RoPE implementation and theory
  • references/extension_methods.md - YaRN, ALiBi, Position Interpolation comparisons
  • references/fine_tuning.md - Complete fine-tuning guide for context extension