
主要介绍训练模型所需的基本要素,从张量到模型、再到优化器与训练循环,强调资源效率,尤其是内存(GB)与计算量(FLOPs)的核算。课程不涉及 Transformer,而是通过更简单的模型来讲解。
SERIES · 斯坦福CS336: Language Modeling from Scratch
2025-09-25 · 25 min read · by GUMP

主要介绍训练模型所需的基本要素,从张量到模型、再到优化器与训练循环,强调资源效率,尤其是内存(GB)与计算量(FLOPs)的核算。课程不涉及 Transformer,而是通过更简单的模型来讲解。
学习目标包括:
主要内容:
动机问题与“纸上估算”摘要:
h100_flop_per_sec = 1979e12 / 2,mfu = 0.5。total_flops = 6 * 70e9 * 15e12flops_per_day = h100_flop_per_sec * mfu * 1024 * 60 * 60 * 24days = total_flops / flops_per_dayh100_bytes = 80e9;每参数内存:4(参数)+4(梯度)+(4+4(优化器状态))=16 字节。num_parameters = (h100_bytes * 8) / 16。朴素地用 float32 表示参数与梯度;也可用 bf16(2+2)并保留一份 fp32 参数副本(4),速度更快但不省显存。
未计入激活开销(依赖 batch size 与序列长度)。
以上为粗略估算。
张量是存储一切内容的基本构件,包括参数、梯度、优化器状态、数据与激活。
在 PyTorch 中可通过多种方式创建张量:
torch.tensor([[1., 2, 3], [4, 5, 6]]):直接定义二维张量torch.zeros(4, 8):4×8 全零矩阵torch.ones(4, 8):4×8 全一矩阵torch.randn(4, 8):4×8 标准正态分布随机矩阵torch.empty(4, 8):4×8 未初始化矩阵(可自定义填充值)未初始化张量常用于后续用特定逻辑赋值,例如:
nn.init.trunc_normal_(x, mean=0, std=1, a=-2, b=2) 进行截断正态初始化。
存储内容:几乎所有数据(参数、梯度、激活、优化器状态)都以浮点数形式存储。
float32(fp32,单精度)
torch.zeros(4,8) → 占用 4*8*4=128 字节。
float16(fp16,半精度)
1e-8 → 0。训练中可能导致不稳定。

bfloat16(bf16)

fp8

训练影响:
torch.zeros(32, 32) → 设备为 cpu。
torch.cuda.is_available()torch.cuda.device_count()torch.cuda.get_device_properties(i)torch.cuda.memory_allocated()y = x.to("cuda:0") → 张量移至 0 号 GPU。z = torch.zeros(32, 32, device="cuda:0")memory_used = 2 * (32 * 32 * 4) → 8192 字节。总体:大多数张量来源于对已有张量的操作,每个操作都有内存与计算代价。
x[1,2] 对应存储索引 1*4 + 2*1 = 6。
import torch
# 创建一个 4x4 张量
x = torch.tensor([
[0., 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15],
])
# 张量的 stride 表示在某一维度上移动一个单位所需跳过的元素数
# dim=0 表示行方向,每下一行需要跳过 4 个元素
assert x.stride(0) == 4
# dim=1 表示列方向,每下一列需要跳过 1 个元素
assert x.stride(1) == 1
# 查找某个元素在内存中的索引
r, c = 1, 2
index = r * x.stride(0) + c * x.stride(1) # 行偏移 + 列偏移
assert index == 6
print("张量:\n", x)
print("stride:", x.stride())
print(f"元素 x[{r}, {c}] 在存储中的索引: {index}, 值为: {x[r, c].item()}").view())会报错,需要先 .contiguous()。import torch
def same_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
"""判断两个张量是否共享底层存储"""
return a.storage().data_ptr() == b.storage().data_ptr()
# =========================================
# 张量视图 (View) 基础
# =========================================
x = torch.tensor([[1., 2, 3], [4, 5, 6]]) # @inspect x
# 1. 获取第 0 行
y = x[0] # @inspect y
assert torch.equal(y, torch.tensor([1., 2, 3]))
assert same_storage(x, y) # 共享存储
# 2. 获取第 1 列
y = x[:, 1] # @inspect y
assert torch.equal(y, torch.tensor([2, 5]))
assert same_storage(x, y) # 共享存储
# 3. 将 2x3 矩阵重新视为 3x2 矩阵
y = x.view(3, 2) # @inspect y
assert torch.equal(y, torch.tensor([[1, 2], [3, 4], [5, 6]]))
assert same_storage(x, y)
# 4. 转置矩阵
y = x.transpose(1, 0) # @inspect y
assert torch.equal(y, torch.tensor([[1, 4], [2, 5], [3, 6]]))
assert same_storage(x, y)
# 5. 修改原张量,视图也会同步改变
x[0][0] = 100 # @inspect x, @inspect y
assert y[0][0] == 100
# =========================================
# 非连续存储 (Non-contiguous) 的限制
# =========================================
x = torch.tensor([[1., 2, 3], [4, 5, 6]]) # @inspect x
y = x.transpose(1, 0) # @inspect y
assert not y.is_contiguous()
# 尝试对非连续张量直接 view 会报错
try:
y.view(2, 3)
assert False
except RuntimeError as e:
assert "view size is not compatible with input tensor's size and stride" in str(e)
# 解决方法:先 contiguous() 再 view
y = x.transpose(1, 0).contiguous().view(2, 3) # @inspect y
assert not same_storage(x, y) # 复制了存储
pow、sqrt、+、、/ 等。triu 可取上三角矩阵,常用于因果注意力掩码。import torch
# =========================================
# 张量的逐元素操作 (Element-wise operations)
# =========================================
x = torch.tensor([1, 4, 9])
# 幂运算
assert import torch
# =========================================
# 张量的逐元素操作 (Element-wise operations)
# =========================================
x = torch.tensor([1, 4, 9]) # 每个元素平方
# 开方
assert torch.equal(x.sqrt(), torch.tensor([1, 2, 3])) # 每个元素开方
# 取倒数开方 (1/sqrt)
assert torch.equal(x.rsqrt(), torch.tensor([1, 1/2, 1/3]))
# 加法
assert torch.equal(x + x, torch.tensor([2, 8, 18]))
# 乘法
assert torch.equal(x * 2, torch.tensor([2, 8, 18]))
# 除法
assert torch.equal(x / 0.5, torch.tensor([2, 8, 18]))
# =========================================
# 上三角矩阵 (Upper triangular matrix)
# =========================================
x = torch.ones(3, 3).triu() # @inspect x
expected = torch.tensor([
[1, 1, 1],
[0, 1, 1],
[0, 0, 1],
])
assert torch.equal(x, expected)
# 这种上三角掩码常用于 **因果注意力 (causal attention mask)**,
# 其中 M[i, j] 表示位置 i 对位置 j 的贡献是否允许。(16,32) @ (32,2) → (16,2)。(4,8,16,32) @ (32,2) → (4,8,16,2)。
import torch
# =========================================
# 深度学习核心操作:矩阵乘法 (Matrix Multiplication)
# =========================================
# 单个矩阵乘法示例
x = torch.ones(16, 32) # 输入矩阵:16 行 32 列
w = torch.ones(32, 2) # 权重矩阵:32 行 2 列
y = x @ w # 矩阵乘法
assert y.size() == torch.Size([16, 2])
print("y.shape:", y.shape) # 输出: torch.Size([16, 2])
# =========================================
# 批处理和多维矩阵乘法
# =========================================
# 假设有一个 4 维张量,表示 batch x seq_len x dim1 x dim2
x = torch.ones(4, 8, 16, 32)
w = torch.ones(32, 2)
# 对最后两个维度执行矩阵乘法,前两个维度自动广播
y = x @ w
assert y.size() == torch.Size([4, 8, 16, 2])
print("y.shape (batch & sequence):", y.shape)
# 说明:
# 对于多维张量,矩阵乘法会在前面多余的维度上迭代,
# 类似于对每个 batch 和序列位置分别进行乘法。Einops 是一个用于操作张量的库,其中的维数都是命名的。它的灵感来自爱因斯坦求和符号(爱因斯坦,1916 年)。[Einops tutorial]
提供以命名维度操作张量的方法,避免传统 PyTorch 操作中维度易混乱的问题(如 2, -1)。
import torch
# =========================================
# 批次与序列维度的矩阵乘法 (Batch & Sequence Matrix Multiplication)
# =========================================
# 输入张量:batch x sequence x hidden
x = torch.ones(2, 2, 3) # @inspect x
y = torch.ones(2, 2, 3) # @inspect y
# 对最后两个维度做矩阵乘法
# 注意 y 需要转置最后两个维度 (-2, -1) 才能匹配 x 的最后一维
z = x @ y.transpose(-2, -1) # 结果 shape: batch x sequence x sequence @inspect z
print("x.shape:", x.shape) # torch.Size([2, 2, 3])
print("y.shape:", y.shape) # torch.Size([2, 2, 3])
print("z.shape:", z.shape) # torch.Size([2, 2, 2])为张量维度加注释,便于文档化维度信息,例如:
x: Float[torch.Tensor, "batch seq heads hidden"]
import torch
from jaxtyping import Float
# =========================================
# 张量维度管理示例 (Tracking Tensor Dimensions)
# =========================================
# 传统方式 (Old way)
x = torch.ones(2, 2, 1, 3) # batch, seq, heads, hidden @inspect x
print("x.shape (old way):", x.shape) # torch.Size([2, 2, 1, 3])
# 新方式 (jaxtyping 风格,主要用于文档注释)
x: Float[torch.Tensor, "batch seq heads hidden"] = torch.ones(2, 2, 1, 3) # @inspect x
print("x.shape (jaxtyping):", x.shape) # torch.Size([2, 2, 1, 3])
# 注意:
# - jaxtyping 的注释只是文档说明,并不会强制类型或 shape。
# - 对大型模型或复杂张量计算,使用这种注释可以更清楚地记录维度语义。... 表示任意数量维度的广播。import torch
from jaxtyping import Float
from torch import einsum
# =========================================
# 通用矩阵乘法:Einsum 示例
# =========================================
# 定义张量,使用 jaxtyping 注释维度
x: Float[torch.Tensor, "batch seq1 hidden"] = torch.ones(2, 3, 4) # @inspect x
y: Float[torch.Tensor, "batch seq2 hidden"] = torch.ones(2, 3, 4) # @inspect y
# -----------------------------------------
# 传统方式:矩阵乘法 + 转置
# -----------------------------------------
# 计算序列间相似度矩阵
z = x @ y.transpose(-2, -1) # shape: batch x seq1 x seq2 @inspect z
print("z.shape (traditional):", z.shape) # torch.Size([2, 3, 3])
# -----------------------------------------
# 新方式:einsum(使用命名维度)
# -----------------------------------------
# 明确维度对应关系
z_einsum = einsum("b i h, b j h -> b i j", x, y)
print("z_einsum.shape:", z_einsum.shape) # torch.Size([2, 3, 3])
# 使用 ... 表示任意数量的前置维度(广播)
z = einsum("b i h, b j h -> b i j", x, y)
print("z.shape (torch.einsum):", z.shape) # torch.Size([2, 3, 3])
# 说明:
# - einsum 可以自动对未在输出中出现的维度求和。
# - ... 可用于批次或额外维度的广播,简化多维计算。reduce(x, "... hidden -> ...", "sum")import torch
from jaxtyping import Float
from einops import reduce
# =========================================
# 张量归约操作 (Reduction)
# =========================================
# 定义张量并注释维度
x: Float[torch.Tensor, "batch seq hidden"] = torch.ones(2, 3, 4) # @inspect x
print("x.shape:", x.shape) # torch.Size([2, 3, 4])
# -----------------------------------------
# 传统方式:沿最后一维求均值
# -----------------------------------------
y = x.mean(dim=-1) # @inspect y
print("y.shape (mean along hidden):", y.shape) # torch.Size([2, 3])
# -----------------------------------------
# 新方式:einops reduce
# -----------------------------------------
# 对最后一维 hidden 进行求和
y = reduce(x, "... hidden -> ...", "sum") # @inspect y
print("y.shape (einops sum):", y.shape) # torch.Size([2, 3])
# 说明:
# - reduce 可以指定任意维度进行归约操作,如 "sum", "mean", "max", "min"。
# - 使用 "..." 可方便表示任意数量的前置维度,简化多维张量操作。rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2)rearrange(x, "... heads hidden2 -> ... (heads hidden2)")einsum 结合进行复杂变换。import torch
from jaxtyping import Float
from einops import rearrange, einsum
# =========================================
# 拆分和重组维度 + einsum 操作
# =========================================
# 定义张量,total_hidden 表示 heads * hidden1 的展开维度
x: Float[torch.Tensor, "batch seq total_hidden"] = torch.ones(2, 3, 8) # @inspect x
print("x.shape (original):", x.shape) # torch.Size([2, 3, 8])
# 权重矩阵
w: Float[torch.Tensor, "hidden1 hidden2"] = torch.ones(4, 4)
# -----------------------------------------
# 1. 拆分 total_hidden 为 heads 和 hidden1
# -----------------------------------------
x = rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2) # @inspect x
print("x.shape (after split):", x.shape) # torch.Size([2, 3, 2, 4])
# -----------------------------------------
# 2. 对 hidden1 维度进行线性变换
# -----------------------------------------
x = einsum(x, w, "... hidden1, hidden1 hidden2 -> ... hidden2") # @inspect x
print("x.shape (after einsum):", x.shape) # torch.Size([2, 3, 2, 4])
# -----------------------------------------
# 3. 合并 heads 和 hidden2 回 total_hidden
# -----------------------------------------
x = rearrange(x, "... heads hidden2 -> ... (heads hidden2)") # @inspect x
print("x.shape (after combine):", x.shape) # torch.Size([2, 3, 8])
# 说明:
# - 这种操作常用于多头注意力或类似结构,将 flattened hidden 维度拆开处理后再合并。
# - einsum + rearrange 可以清晰地处理复杂的多维变换。经过对各种张量操作的分析,需要关注它们的计算成本。浮点运算(FLOP)是基本操作,如加法或乘法。存在两个容易混淆的缩写:FLOPs 表示浮点运算的总量,用于衡量完成的计算量;FLOP/s(或 FLOPS)表示每秒浮点运算次数,用于衡量硬件的计算速度。
训练规模:
GPU 性能:
示例估算:
total_flops = 8 * 60*60*24*7 * h100_flop_per_sec计算示例
y = x @ w。
x 的维度为 (B, D)。w 的维度为 (D, K)。y 的维度为 (B, K)。x @ w 涉及一次乘法和一次加法。import torch
# 根据是否有 GPU 调整矩阵大小
if torch.cuda.is_available():
B = 16384 # 点的数量 (batch size)
D = 32768 # 特征维度
K = 8192 # 输出维度
else:
B = 1024
D = 256
K = 64
# 获取设备 (CPU 或 GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 输入张量和权重矩阵
x = torch.ones(B, D, device=device)
w = torch.randn(D, K, device=device)
# 矩阵乘法
y = x @ w
print("y.shape:", y.shape) # torch.Size([B, K])
# =========================================
# 计算 FLOPs (浮点运算量)
# =========================================
# 对每个 (i, j, k) 三元组有一次乘法和一次加法
flops = B * D * K * 2
print(f"FLOPs for this matmul: {flops:.2e}")
浮点运算与实际时间
理论与实际:线性模型前向传播的FLOPs可以概括为**2×(tokens)×(parameters),**这一规律在Transformer模型中也大致适用。
计算实际时间:
actual_time:通过对矩阵乘法进行计时获得实际耗时(以秒为单位)。actual_flop_per_sec:用总的FLOPs除以实际耗时,得到每秒的实际浮点运算次数。 actual_time = time_matmul(x, w) # @inspect actual_time
actual_flop_per_sec = actual_num_flops / actual_time # @inspect actual_flop_per_sec峰值性能:每个GPU都有一个规格表,会报告其峰值性能(如A100和H100)**[spec] [spec]**。**每秒浮点运算数(FLOP/s)**会根据使用的数据类型(如FP32、FP16)有很大差异。
actual_flop_per_sec / promised_flop_per_secBFloat16数据类型示例
float32切换到bfloat16后,实际浮点运算速度(bf16_actual_flop_per_sec)会变高。import torch
# =========================================
# 使用 bfloat16 进行矩阵乘法并计算 MFU
# =========================================
# 假设 x, w 已经定义好并在合适的设备上
# x: torch.Tensor of shape [B, D]
# w: torch.Tensor of shape [D, K]
# 转为 bfloat16 以加速计算
x = x.to(torch.bfloat16)
w = w.to(torch.bfloat16)
# 测量实际矩阵乘法时间
bf16_actual_time = time_matmul(x, w) # @inspect bf16_actual_time
print("Actual time (bf16):", bf16_actual_time)
# 计算实际 FLOPs/s
bf16_actual_flop_per_sec = actual_num_flops / bf16_actual_time # @inspect bf16_actual_flop_per_sec
print("Actual FLOPs/s (bf16):", bf16_actual_flop_per_sec)
# 获取设备理论峰值 FLOPs/s
bf16_promised_flop_per_sec = get_promised_flop_per_sec(device, x.dtype) # @inspect bf16_promised_flop_per_sec
print("Promised FLOPs/s (bf16):", bf16_promised_flop_per_sec)
# 计算最大填充利用率 (MFU)
bf16_mfu = bf16_actual_flop_per_sec / bf16_promised_flop_per_sec
print("MFU (bfloat16):", bf16_mfu)
# 说明:
# - bfloat16 可以在保持精度的同时提高吞吐量
# - MFU(Maximum Filling Utilization) 衡量硬件利用效率
x:输入张量 [1., 2, 3]。w:参数张量 [1., 1, 1],设置 requires_grad=True 以便计算梯度。pred_y:x @ w 得到 1*1 + 2*1 + 3*1 = 6。loss:0.5 * (6 - 5)^2 = 0.5。loss.backward() 执行反向传播,自动计算梯度。w 的梯度为**[1, 2, 3]*。loss、pred_y 和 x 等没有设置 requires_grad=True 的张量,其梯度为 None。import torch
# =========================================
# 简单线性模型正向和反向传播
# =========================================
x = torch.tensor([1., 2., 3])
w = torch.tensor([1., 1, 1], requires_grad=True)
# 前向传播
pred_y = x @ w
loss = 0.5 * (pred_y - 5).pow(2)
# 反向传播
loss.backward()
# 检查梯度
assert loss.grad is None
assert pred_y.grad is None
assert x.grad is None
assert torch.equal(w.grad, torch.tensor([1., 2., 3]))
对于给定的线性模型:x --w1--> h1 --w2--> h2,其前向传播的总FLOPs等于两次矩阵乘法的FLOPs之和。
x @ w1:矩阵维度为 (B, D) 和 (D, D)。FLOPs为 2 * B * D * D。h1 @ w2:矩阵维度为 (B, D) 和 (D, K)。FLOPs为 2 * B * D * K。因此,前向传播总FLOPs = (2 * B * D * D) + (2 * B * D * K)。
反向传播(梯度计算)的FLOPs是前向传播的两倍。
w2 的梯度:w2.grad 的计算涉及 h1 和 h2.grad 的矩阵乘法,这与前向传播中 h1 @ w2 的计算量类似。FLOPs约为 2 * B * D * K。h1 的梯度:h1.grad 的计算涉及 h2.grad 和 w2 的矩阵乘法,FLOPs约为 2 * B * D * K。w1 和 x 的梯度:这一步的计算量类似。其中,w1.grad 涉及 x 和 h1.grad 的矩阵乘法,FLOPs约为 2 * B * D * D;x.grad 涉及 h1.grad 和 w1 的矩阵乘法,FLOPs约为 2 * B * D * D。因此,反向传播总FLOPs = (2 * B * D * K) + (2 * B * D * K) + (2 * B * D * D) + (2 * B * D * D),约等于前向传播总FLOPs的两倍。
2 * (数据点数) * (参数数量)*。4 * (数据点数) * (参数数量)*。6 * (数据点数) * (参数数量)*。A nice graphical visualization: [article]

import torch
# =========================================
# 计算线性模型的前向和反向 FLOPs
# =========================================
# 设置矩阵大小
if torch.cuda.is_available():
B, D, K = 16384, 32768, 8192
else:
B, D, K = 1024, 256, 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义输入和参数
x = torch.ones(B, D, device=device)
w1 = torch.randn(D, D, device=device, requires_grad=True)
w2 = torch.randn(D, K, device=device, requires_grad=True)
# 前向传播
h1 = x @ w1
h2 = h1 @ w2
loss = h2.pow(2).mean()
# 计算前向 FLOPs
num_forward_flops = (2 * B * D * D) + (2 * B * D * K) # @inspect num_forward_flops
# 保留中间梯度
h1.retain_grad()
h2.retain_grad()
# 反向传播
loss.backward()
# 计算 w2 相关的反向 FLOPs
num_backward_flops = 0
num_backward_flops += 2 * B * D * K # w2.grad
num_backward_flops += 2 * B * D * K # h1.grad
num_backward_flops += (2 + 2) * B * D * D # w1.grad # @inspect num_backward_flops
# 检查维度
assert w2.grad.size() == torch.Size([D, K])
assert h1.size() == torch.Size([B, D])
assert h2.grad.size() == torch.Size([B, K])
assert h1.grad.size() == torch.Size([B, D])
assert w2.size() == torch.Size([D, K])
nn.Parameter对象的形式存储,它本质上是一种特殊的张量(torch.Tensor),可以通过 .data 属性访问底层张量。input_dim)的增加而不成比例地增大,其增长速率为\sqrt{\text{input_dim}}。这可能导致梯度爆炸,使模型训练变得不稳定。[-3, 3]的范围内。import torch
import torch.nn as nn
import numpy as np
input_dim, output_dim = 16384, 32
# 模型参数是 nn.Parameter
w = nn.Parameter(torch.randn(input_dim, output_dim))
assert isinstance(w, torch.Tensor)
assert isinstance(w.data, torch.Tensor)
# 普通初始化 -> 输出随 √input_dim 变大
x = nn.Parameter(torch.randn(input_dim))
out1 = x @ w # ~ O(√input_dim)
# Xavier 初始化 (缩放 1/√input_dim)
w2 = nn.Parameter(torch.randn(input_dim, output_dim) / np.sqrt(input_dim))
out2 = x @ w2 # ~ O(1)
# 截断正态分布 (更安全)
w3 = nn.Parameter(
nn.init.trunc_normal_(
torch.empty(input_dim, output_dim),
std=1 / np.sqrt(input_dim),
a=-3, b=3
)
)
out3 = x @ w3 # ~ O(1)import torch
import torch.nn as nn
import numpy as np
# ---- 定义简单线性层 ----
class Linear(nn.Module):
"""Simple linear layer."""
def __init__(self, input_dim: int, output_dim: int):
super().__init__()
self.weight = nn.Parameter(
torch.randn(input_dim, output_dim) / np.sqrt(input_dim)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x @ self.weight
# ---- 定义深度线性模型 ----
class Cruncher(nn.Module):
def __init__(self, dim: int, num_layers: int):
super().__init__()
self.layers = nn.ModuleList([
Linear(dim, dim) for _ in range(num_layers)
])
self.final = Linear(dim, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, D = x.size()
for layer in self.layers:
x = layer(x)
# Final projection
x = self.final(x)
assert x.size() == torch.Size([B, 1])
return x.squeeze(-1)
# ---- 工具函数 ----
def get_num_parameters(model: nn.Module) -> int:
return sum(p.numel() for p in model.parameters())
def get_device() -> torch.device:
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ---- custom_model 示例 ----
def custom_model():
D = 64 # Dimension
num_layers = 2
model = Cruncher(dim=D, num_layers=num_layers)
# 检查参数大小
param_sizes = [(name, param.numel()) for name, param in model.state_dict().items()]
assert param_sizes == [
("layers.0.weight", D * D),
("layers.1.weight", D * D),
("final.weight", D),
]
# 参数总数
num_parameters = get_num_parameters(model)
assert num_parameters == (D * D) + (D * D) + D
# 移动到 GPU
device = get_device()
model = model.to(device)
# 运行模型
B = 8 # Batch size
x = torch.randn(B, D, dev
data 中,随机采样出batch_size个序列,每个序列的长度为sequence_length。torch.randint 随机生成 batch_size 个起始索引 start_indices,确保每个索引都可以在数据范围内截取一个完整的序列。start_indices 索引到 data 中,构建一个大小为 [batch_size, sequence_length] 的输入张量 x。.pin_memory(),可以将张量显式地放入固定内存中。non_blocking=True。import torch
import numpy as np
def get_batch(data: np.ndarray, batch_size: int, sequence_length: int, device: str) -> torch.Tensor:
"""
Sample a random batch of sequences from data.
Args:
data: numpy array of data
batch_size: number of sequences per batch
sequence_length: length of each sequence
device: target device ("cpu" or "cuda")
Returns:
x: torch.Tensor [batch_size, sequence_length]
"""
# 随机选择起始位置
start_indices = torch.randint(len(data) - sequence_length, (batch_size,))
assert start_indices.size() == torch.Size([batch_size])
# 构造 batch
x = torch.tensor(
[data[start:start + sequence_length] for start in start_indices],
dtype=torch.float32
)
assert x.size() == torch.Size([batch_size, sequence_length])
# 固定内存 (提高 CPU→GPU 数据拷贝效率)
if torch.cuda.is_available():
x = x.pin_memory()
# 传输到设备
x = x.to(device, non_blocking=True)
return x
为了确保完全的可复现性,应在三个主要库中同时设置随机种子:
torch.manual_seed(seed)。np.random.seed(seed)。random.seed(seed)。numpy数组,并以.npy文件格式存储,以便于加载。np.memmap:numpy的内存映射功能允许**延迟加载(lazily load)**数据。这意味着只有在访问数据文件的特定部分时,才会将其加载到内存中,从而节省了大量的RAM。B)和固定长度(L)的序列,形成一个大小为[B, L]的张量。class SGD(torch.optim.Optimizer):
def __init__(self, params: Iterable[nn.Parameter], lr: float = 0.01):
super(SGD, self).__init__(params, dict(lr=lr))
def step(self):
for group in self.param_groups:
lr = group["lr"]
for p in group["params"]:
grad = p.grad.data
p.data -= lr * grad
特点:
p = p - lr * gradclass AdaGrad(torch.optim.Optimizer):
def __init__(self, params: Iterable[nn.Parameter], lr: float = 0.01):
super(AdaGrad, self).__init__(params, dict(lr=lr))
def step(self):
for group in self.param_groups:
lr = group["lr"]
for p in group["params"]:
state = self.state[p]
grad = p.grad.data
g2 = state.get("g2", torch.zeros_like(grad))
g2 += torch.square(grad)
state["g2"] = g2
p.data -= lr * grad / torch.sqrt(g2 + 1e-5)
特点:
参考论文:AdaGrad
def train(name: str, get_batch, D: int, num_layers: int, B: int, num_train_steps: int, lr: float):
model = Cruncher(dim=D, num_layers=num_layers).to(get_device())
optimizer = SGD(model.parameters(), lr=lr)
for t in range(num_train_steps):
x, y = get_batch(B=B)
pred_y = model(x)
loss = F.mse_loss(pred_y, y)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
def get_batch(B: int) -> tuple[torch.Tensor, torch.Tensor]:
D = 16
true_w = torch.arange(D, dtype=torch.float32, device=get_device())
x = torch.randn(B, D).to(get_device())
true_y = x @ true_w
return x, true_y
保存训练状态,避免训练中断导致的数据丢失:
model = Cruncher(dim=64, num_layers=3).to(get_device())
optimizer = AdaGrad(model.parameters(), lr=0.01)
# 保存
checkpoint = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
}
torch.save(checkpoint, "model_checkpoint.pt")
# 加载
loaded_checkpoint = torch.load("model_checkpoint.pt")
PyTorch AMP:官方文档
NVIDIA Transformer Engine 支持 FP8:参考论文
def get_memory_usage(x: torch.Tensor):
return x.numel() * x.element_size()
def get_num_parameters(model: nn.Module) -> int:
return sum(param.numel() for param in model.parameters())
def get_promised_flop_per_sec(device: str, dtype: torch.dtype) -> float:
...
示例:
num_parameters = D*D*num_layers + Dnum_activations = B * D * num_layersnum_gradients = num_parametersnum_optimizer_states = num_parameterstotal_memory = 4 * (num_parameters + num_activations + num_gradients + num_optimizer_states)flops = 6 * B * num_parametersdef same_storage(x: torch.Tensor, y: torch.Tensor):
return x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr()
def time_matmul(a: torch.Tensor, b: torch.Tensor) -> float:
...
def get_device(index: int = 0) -> torch.device:
if torch.cuda.is_available():
return torch.device(f"cuda:{index}")
else:
return torch.device("cpu")