Learning objectives:

Main topics:


Motivating questions and a summary of "napkin math" (back-of-the-envelope estimates):


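Tensors

Tensor basics

Tensors are the basic building blocks for storing everything: parameters, gradients, optimizer state, data, and activations.

In PyTorch, tensors can be created in several ways: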
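A few common constructors, as a minimal sketch (the shapes and values are arbitrary):

```python
import torch

x = torch.tensor([[1., 2, 3], [4, 5, 6]])  # from (nested) Python data
zeros = torch.zeros(4, 8)                   # 4x8 matrix of zeros
ones = torch.ones(4, 8)                     # 4x8 matrix of ones
gauss = torch.randn(4, 8)                   # i.i.d. samples from N(0, 1)
uninit = torch.empty(4, 8)                  # memory allocated, values left uninitialized

print(x.shape, x.dtype)  # torch.Size([2, 3]) torch.float32
```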

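Uninitialized tensors are often used when the values will be filled in later by some specific logic, for example:

nn.init.trunc_normal_(x, mean=0, std=1, a=-2, b=2), which fills x in place with samples from a normal distribution truncated to [a, b].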
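A runnable version of the example above: allocate with `torch.empty` (no initialization), then fill the tensor in place. Note that `a` and `b` are absolute cutoffs, so with `std=1` they happen to correspond to ±2 standard deviations.

```python
import torch
import torch.nn as nn

# Allocate memory without initializing it
x = torch.empty(4, 8)

# Fill in place with a truncated normal; every value ends up inside [a, b]
nn.init.trunc_normal_(x, mean=0, std=1, a=-2, b=2)

print(x.min().item() >= -2, x.max().item() <= 2)  # True True
```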

Tensor memory
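A tensor's memory footprint is its number of elements times the number of bytes per element of its dtype. The sketch below uses the same formula as the `get_memory_usage` helper listed under memory and FLOPs estimation later in these notes; the shapes are arbitrary.

```python
import torch

def get_memory_usage(x: torch.Tensor) -> int:
    # bytes = number of elements * bytes per element
    return x.numel() * x.element_size()

x32 = torch.zeros(4, 8)                        # float32: 4 bytes per element
x16 = torch.zeros(4, 8, dtype=torch.bfloat16)  # bfloat16: 2 bytes per element

print(get_memory_usage(x32), get_memory_usage(x16))  # 128 64
```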

Using tensors on the GPU
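Tensors are created on the CPU by default and have to be moved to (or allocated on) the GPU explicitly. A minimal sketch that falls back to the CPU when no GPU is present:

```python
import torch

# Tensors live on the CPU by default
x = torch.zeros(32, 32)

# Use the GPU if there is one, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Copy an existing tensor to the device, or allocate directly on it
y = x.to(device)
z = torch.zeros(32, 32, device=device)

if torch.cuda.is_available():
    # Bytes currently occupied by tensors on the GPU
    print("GPU memory allocated:", torch.cuda.memory_allocated())
```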


Tensor operations

Overview: most tensors are produced by operating on existing tensors, and every operation has a memory and compute cost.

1. Storage (tensor_storage) [PyTorch docs]


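```python
import torch

# Create a 4x4 tensor
x = torch.tensor([
    [0., 1, 2, 3],
    [4, 5, 6, 7],
    [8, 9, 10, 11],
    [12, 13, 14, 15],
])

# The stride of a dimension is the number of storage elements to skip
# to move one step along that dimension.
# dim=0 is the row direction: moving to the next row skips 4 elements
assert x.stride(0) == 4

# dim=1 is the column direction: moving to the next column skips 1 element
assert x.stride(1) == 1

# Locate an element's index in the underlying storage
r, c = 1, 2
index = r * x.stride(0) + c * x.stride(1)  # row offset + column offset
assert index == 6

print("tensor:\n", x)
print("stride:", x.stride())
print(f"x[{r}, {c}] is at storage index {index}, value: {x[r, c].item()}")
```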

2. Slicing and views (tensor_slicing)

```python
import torch

def same_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Return True if the two tensors share the same underlying storage."""
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

# =========================================
# Tensor view basics
# =========================================
x = torch.tensor([[1., 2, 3], [4, 5, 6]])  # @inspect x

# 1. Take row 0
y = x[0]  # @inspect y
assert torch.equal(y, torch.tensor([1., 2, 3]))
assert same_storage(x, y)   # shares storage

# 2. Take column 1
y = x[:, 1]  # @inspect y
assert torch.equal(y, torch.tensor([2., 5]))
assert same_storage(x, y)   # shares storage

# 3. View the 2x3 matrix as a 3x2 matrix
y = x.view(3, 2)  # @inspect y
assert torch.equal(y, torch.tensor([[1., 2], [3, 4], [5, 6]]))
assert same_storage(x, y)

# 4. Transpose the matrix
y = x.transpose(1, 0)  # @inspect y
assert torch.equal(y, torch.tensor([[1., 4], [2, 5], [3, 6]]))
assert same_storage(x, y)

# 5. Mutating the original tensor is visible through the view
x[0][0] = 100  # @inspect x, @inspect y
assert y[0][0] == 100

# =========================================
# Limitations of non-contiguous tensors
# =========================================
x = torch.tensor([[1., 2, 3], [4, 5, 6]])  # @inspect x
y = x.transpose(1, 0)  # @inspect y
assert not y.is_contiguous()

# Calling view directly on a non-contiguous tensor raises an error
try:
    y.view(2, 3)
    assert False
except RuntimeError as e:
    assert "view size is not compatible with input tensor's size and stride" in str(e)

# Fix: call contiguous() first, then view
y = x.transpose(1, 0).contiguous().view(2, 3)  # @inspect y
assert not same_storage(x, y)  # the storage was copied
```
 

3. Element-wise operations (tensor_elementwise)

```python
import torch

# =========================================
# Element-wise operations
# =========================================
x = torch.tensor([1, 4, 9])

# Power
assert torch.equal(x.pow(2), torch.tensor([1, 16, 81]))  # square each element

# Square root (integer inputs are promoted to float)
assert torch.equal(x.sqrt(), torch.tensor([1., 2, 3]))  # square root of each element

# Reciprocal square root (1/sqrt)
assert torch.allclose(x.rsqrt(), torch.tensor([1, 1/2, 1/3]))

# Addition
assert torch.equal(x + x, torch.tensor([2, 8, 18]))

# Multiplication
assert torch.equal(x * 2, torch.tensor([2, 8, 18]))

# Division (true division returns float)
assert torch.equal(x / 0.5, torch.tensor([2., 8, 18]))

# =========================================
# Upper triangular matrix
# =========================================
x = torch.ones(3, 3).triu()  # @inspect x

expected = torch.tensor([
    [1., 1, 1],
    [0, 1, 1],
    [0, 0, 1],
])
assert torch.equal(x, expected)

# This upper-triangular mask is commonly used for a **causal attention mask**,
# where M[i, j] indicates whether position i is allowed to contribute to position j.
```
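A sketch of how such a mask is typically applied, assuming the common convention that row i of the score matrix is the query and column j the key: the strictly upper triangle (`triu(diagonal=1)`) marks future keys (j > i), which are set to -inf so the softmax gives them zero weight. The random scores stand in for real attention logits.

```python
import torch

torch.manual_seed(0)
T = 4  # sequence length (illustrative)
scores = torch.randn(T, T)  # hypothetical attention scores: row i = query, column j = key

# True strictly above the diagonal: key j lies in the future of query i (j > i)
future = torch.ones(T, T).triu(diagonal=1).bool()

# Block future positions with -inf, then normalize each row
weights = scores.masked_fill(future, float("-inf")).softmax(dim=-1)

print(weights[0])  # position 0 can only attend to itself
```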

4. Matrix multiplication (tensor_matmul)


```python
import torch

# =========================================
# The core operation in deep learning: matrix multiplication
# =========================================

# A single matrix multiplication
x = torch.ones(16, 32)  # input matrix: 16 rows, 32 columns
w = torch.ones(32, 2)   # weight matrix: 32 rows, 2 columns

y = x @ w  # matrix multiplication
assert y.size() == torch.Size([16, 2])
print("y.shape:", y.shape)  # torch.Size([16, 2])

# =========================================
# Batched and multi-dimensional matrix multiplication
# =========================================
# A 4D tensor laid out as batch x seq_len x dim1 x dim2
x = torch.ones(4, 8, 16, 32)
w = torch.ones(32, 2)

# The matmul is applied to the last two dimensions; the two leading dimensions are broadcast
y = x @ w
assert y.size() == torch.Size([4, 8, 16, 2])
print("y.shape (batch & sequence):", y.shape)

# Note:
# For multi-dimensional tensors, matmul iterates over the extra leading dimensions,
# as if multiplying separately for each batch element and sequence position.
```

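Einops tensor operations

Einops is a library for manipulating tensors in which dimensions are named. It is inspired by Einstein summation notation (Einstein, 1916). [Einops tutorial]

Motivation for Einops:

It lets you manipulate tensors through named dimensions, avoiding the confusion of positional dimension indices (such as 2 or -1) in conventional PyTorch code.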

```python
import torch

# =========================================
# Batch & sequence matrix multiplication
# =========================================

# Input tensors: batch x sequence x hidden
x = torch.ones(2, 2, 3)  # @inspect x
y = torch.ones(2, 2, 3)  # @inspect y

# Matrix multiplication over the last two dimensions.
# y's last two dimensions (-2, -1) must be transposed so that its hidden dim lines up with x's last dim
z = x @ y.transpose(-2, -1)  # result shape: batch x sequence x sequence  @inspect z

print("x.shape:", x.shape)  # torch.Size([2, 2, 3])
print("y.shape:", y.shape)  # torch.Size([2, 2, 3])
print("z.shape:", z.shape)  # torch.Size([2, 2, 2])
```

jaxtyping

Annotate tensor dimensions so that the dimension information is documented, for example:

x: Float[torch.Tensor, "batch seq heads hidden"]

```python
import torch
from jaxtyping import Float

# =========================================
# Tracking tensor dimensions
# =========================================

# Old way
x = torch.ones(2, 2, 1, 3)  # batch, seq, heads, hidden  @inspect x
print("x.shape (old way):", x.shape)  # torch.Size([2, 2, 1, 3])

# New way (jaxtyping style, mainly for documentation)
x: Float[torch.Tensor, "batch seq heads hidden"] = torch.ones(2, 2, 1, 3)  # @inspect x
print("x.shape (jaxtyping):", x.shape)  # torch.Size([2, 2, 1, 3])

# Notes:
# - On their own, jaxtyping annotations are documentation only; they do not enforce dtype or shape
#   (runtime checking requires pairing jaxtyping's jaxtyped decorator with a checker such as beartype).
# - For large models or complex tensor computations, such annotations record what each dimension means.
```

einsum (generalized matrix multiplication):

```python
import torch
from jaxtyping import Float
from torch import einsum

# =========================================
# Generalized matrix multiplication: einsum
# =========================================

# Define tensors, annotating dimensions with jaxtyping
x: Float[torch.Tensor, "batch seq1 hidden"] = torch.ones(2, 3, 4)  # @inspect x
y: Float[torch.Tensor, "batch seq2 hidden"] = torch.ones(2, 3, 4)  # @inspect y

# -----------------------------------------
# Old way: matmul + transpose
# -----------------------------------------
# Compute the sequence-to-sequence similarity matrix
z = x @ y.transpose(-2, -1)  # shape: batch x seq1 x seq2  @inspect z
print("z.shape (traditional):", z.shape)  # torch.Size([2, 3, 3])

# -----------------------------------------
# New way: einsum (dimension labels spelled out)
# -----------------------------------------
# The correspondence is explicit: b = batch, i = seq1, j = seq2, h = hidden
z_einsum = einsum("b i h, b j h -> b i j", x, y)
print("z_einsum.shape:", z_einsum.shape)  # torch.Size([2, 3, 3])

# Use ... to stand for any number of leading dimensions (broadcast)
z = einsum("... i h, ... j h -> ... i j", x, y)
print("z.shape (torch.einsum with ...):", z.shape)  # torch.Size([2, 3, 3])

# Notes:
# - einsum sums over every index that does not appear in the output (here h).
# - ... covers batch or other extra dimensions, simplifying multi-dimensional computations.
```

reduce (reduction):

```python
import torch
from jaxtyping import Float
from einops import reduce

# =========================================
# Reductions
# =========================================

# Define the tensor and annotate its dimensions
x: Float[torch.Tensor, "batch seq hidden"] = torch.ones(2, 3, 4)  # @inspect x
print("x.shape:", x.shape)  # torch.Size([2, 3, 4])

# -----------------------------------------
# Old way: mean along the last dimension
# -----------------------------------------
y = x.mean(dim=-1)  # @inspect y
print("y.shape (mean along hidden):", y.shape)  # torch.Size([2, 3])

# -----------------------------------------
# New way: einops reduce
# -----------------------------------------
# Sum over the last dimension (hidden)
y = reduce(x, "... hidden -> ...", "sum")  # @inspect y
print("y.shape (einops sum):", y.shape)  # torch.Size([2, 3])

# Notes:
# - reduce can reduce over any named dimensions with operations such as "sum", "mean", "max", "min".
# - "..." stands for any number of leading dimensions, which simplifies multi-dimensional code.
```

rearrange (reordering dimensions):

```python
import torch
from jaxtyping import Float
from einops import rearrange, einsum

# =========================================
# Splitting and merging dimensions + einsum
# =========================================

# Define the tensor; total_hidden is the flattened heads * hidden1 dimension
x: Float[torch.Tensor, "batch seq total_hidden"] = torch.ones(2, 3, 8)  # @inspect x
print("x.shape (original):", x.shape)  # torch.Size([2, 3, 8])

# Weight matrix
w: Float[torch.Tensor, "hidden1 hidden2"] = torch.ones(4, 4)

# -----------------------------------------
# 1. Split total_hidden into heads and hidden1
# -----------------------------------------
x = rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2)  # @inspect x
print("x.shape (after split):", x.shape)  # torch.Size([2, 3, 2, 4])

# -----------------------------------------
# 2. Apply a linear transform to the hidden1 dimension
# -----------------------------------------
x = einsum(x, w, "... hidden1, hidden1 hidden2 -> ... hidden2")  # @inspect x
print("x.shape (after einsum):", x.shape)  # torch.Size([2, 3, 2, 4])

# -----------------------------------------
# 3. Merge heads and hidden2 back into total_hidden
# -----------------------------------------
x = rearrange(x, "... heads hidden2 -> ... (heads hidden2)")  # @inspect x
print("x.shape (after combine):", x.shape)  # torch.Size([2, 3, 8])

# Notes:
# - This pattern is common in multi-head attention and similar structures: split the flattened
#   hidden dimension, process each part, then merge it back.
# - einsum + rearrange express complex multi-dimensional transforms clearly.
```

tensor_operations_flops

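Having covered the various tensor operations, we now look at their computational cost. A floating-point operation (FLOP) is a basic operation such as an addition or a multiplication. Two abbreviations are easily confused: FLOPs denotes a total number of floating-point operations and measures how much computation was done; FLOP/s (or FLOPS) denotes floating-point operations per second and measures how fast the hardware computes.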
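To make the distinction concrete: the FLOPs of a matmul are fixed by its shapes, and dividing them by a device's FLOP/s gives an idealized lower bound on its running time. The throughput below is a made-up round number for illustration, not a real hardware spec.

```python
def matmul_flops(B: int, D: int, K: int) -> int:
    # (B x D) @ (D x K): each of the B*K outputs needs D multiplies and (about) D adds
    return 2 * B * D * K

flops = matmul_flops(1024, 256, 64)  # total work, in FLOPs

flop_per_sec = 1e12  # hypothetical throughput: 1 TFLOP/s

seconds = flops / flop_per_sec  # FLOPs divided by FLOP/s gives seconds
print(flops, seconds)
```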

Intuition

Linear model

Worked example

```python
import torch

# Choose matrix sizes depending on whether a GPU is available
if torch.cuda.is_available():
    B = 16384  # number of points (batch size)
    D = 32768  # feature dimension
    K = 8192   # output dimension
else:
    B = 1024
    D = 256
    K = 64

# Pick the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input tensor and weight matrix
x = torch.ones(B, D, device=device)
w = torch.randn(D, K, device=device)

# Matrix multiplication
y = x @ w
print("y.shape:", y.shape)  # torch.Size([B, K])

# =========================================
# Count FLOPs
# =========================================
# For each (i, j, k) triple there is one multiplication and one addition
flops = B * D * K * 2
print(f"FLOPs for this matmul: {flops:.2e}")
```
 

FLOPs of other operations
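A rough count, as a sketch that charges one FLOP per scalar add or multiply: element-wise operations on an m×n matrix cost on the order of mn FLOPs, whereas an (m×n) @ (n×p) matmul costs 2mnp, so for large dimensions matmuls dominate. The sizes are arbitrary.

```python
m, n, p = 1024, 1024, 1024

elementwise_add_flops = m * n    # A + B: one add per entry
scale_flops = m * n              # 2 * A: one multiply per entry
matmul_cost_flops = 2 * m * n * p  # A @ C with C of shape (n, p)

# The matmul costs 2p times as much as an element-wise add
print(matmul_cost_flops // elementwise_add_flops)  # 2048
```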

FLOPs and wall-clock time
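Wall-clock time has to be measured. One possible body for the `time_matmul` helper that these notes leave as a stub is sketched below; the warm-up run and the trial count are arbitrary choices. Because CUDA kernels launch asynchronously, the sketch synchronizes before reading the clock.

```python
import time

import torch

def time_matmul(a: torch.Tensor, b: torch.Tensor, num_trials: int = 5) -> float:
    """Return the average wall-clock seconds for a @ b (a sketch)."""
    def sync():
        # CUDA launches are asynchronous: wait for all queued kernels to finish
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    a @ b  # warm-up run (kernel selection, caches)
    sync()
    start = time.perf_counter()
    for _ in range(num_trials):
        a @ b
    sync()
    return (time.perf_counter() - start) / num_trials

x = torch.randn(256, 128)
w = torch.randn(128, 64)
elapsed = time_matmul(x, w)
print(f"{elapsed:.2e} s, {2 * 256 * 128 * 64 / elapsed:.2e} FLOP/s")
```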

Model FLOPs utilization (MFU)
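MFU compares the FLOP/s a computation actually achieves with the peak FLOP/s the hardware promises. A minimal sketch with made-up numbers:

```python
def mfu(actual_flops: float, seconds: float, promised_flop_per_sec: float) -> float:
    """Model FLOPs utilization: achieved FLOP/s divided by the promised peak FLOP/s."""
    return (actual_flops / seconds) / promised_flop_per_sec

# Hypothetical: 1e12 FLOPs finished in 2 s on a device that promises 1e12 FLOP/s
print(mfu(1e12, 2.0, 1e12))  # 0.5
```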

BFloat16 data type example

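```python
import torch

# =========================================
# Matrix multiplication in bfloat16 and its MFU
# =========================================

# Assumes x, w, and device from the example above, plus the helpers
# time_matmul and get_promised_flop_per_sec from the utility functions section.
# x: torch.Tensor of shape [B, D]
# w: torch.Tensor of shape [D, K]

# Convert to bfloat16 to speed up computation
x = x.to(torch.bfloat16)
w = w.to(torch.bfloat16)

# FLOPs of this matmul: one multiply and one add per (i, j, k) triple
actual_num_flops = 2 * x.shape[0] * x.shape[1] * w.shape[1]

# Measure the actual matmul time (seconds)
bf16_actual_time = time_matmul(x, w)  # @inspect bf16_actual_time
print("Actual time (bf16):", bf16_actual_time)

# FLOP/s actually achieved
bf16_actual_flop_per_sec = actual_num_flops / bf16_actual_time  # @inspect bf16_actual_flop_per_sec
print("Actual FLOP/s (bf16):", bf16_actual_flop_per_sec)

# Theoretical peak FLOP/s of the device for this dtype
bf16_promised_flop_per_sec = get_promised_flop_per_sec(device, x.dtype)  # @inspect bf16_promised_flop_per_sec
print("Promised FLOP/s (bf16):", bf16_promised_flop_per_sec)

# Model FLOPs utilization (MFU)
bf16_mfu = bf16_actual_flop_per_sec / bf16_promised_flop_per_sec
print("MFU (bfloat16):", bf16_mfu)

# Notes:
# - bfloat16 keeps float32's exponent range but has fewer mantissa bits: it trades precision for throughput
# - MFU (Model FLOPs Utilization) = actual FLOP/s / promised FLOP/s, a measure of hardware efficiency
```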
 

Summary

gradients

Gradient computation basics

Example: a simple linear model

```python
import torch

# =========================================
# Forward and backward pass of a simple linear model
# =========================================

x = torch.tensor([1., 2., 3])
w = torch.tensor([1., 1, 1], requires_grad=True)

# Forward pass
pred_y = x @ w
loss = 0.5 * (pred_y - 5).pow(2)

# Backward pass
loss.backward()

# Check gradients: only leaf tensors with requires_grad=True (here w) get .grad populated
assert loss.grad is None
assert pred_y.grad is None
assert x.grad is None
assert torch.equal(w.grad, torch.tensor([1., 2., 3]))
```
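The value asserted for `w.grad` can be checked by hand: pred_y = x · w = 6, so the chain rule gives ∂loss/∂w = (pred_y − 5) · x = x. A small sketch that recomputes it manually:

```python
import torch

x = torch.tensor([1., 2., 3])
w = torch.tensor([1., 1, 1], requires_grad=True)

pred_y = x @ w                    # = 6
loss = 0.5 * (pred_y - 5).pow(2)
loss.backward()

# Chain rule by hand: dloss/dw = (pred_y - 5) * dpred_y/dw = (pred_y - 5) * x
manual_grad = (pred_y.detach() - 5) * x
assert torch.equal(w.grad, manual_grad)
print(manual_grad)  # tensor([1., 2., 3.])
```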
 

gradient flops

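Forward-pass FLOPs

For the linear model x --w1--> h1 --w2--> h2 (x is B×D, w1 is D×D, w2 is D×K), the total forward-pass FLOPs are the sum of the FLOPs of the two matrix multiplications.

So the total forward FLOPs = (2 * B * D * D) + (2 * B * D * K).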

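Backward-pass FLOPs

The backward pass (gradient computation) costs about twice the FLOPs of the forward pass: each matrix multiplication in the forward pass needs two in the backward pass, one for the gradient with respect to its weight and one for the gradient with respect to its input.

So the total backward FLOPs = (2 * B * D * K) + (2 * B * D * K) + (2 * B * D * D) + (2 * B * D * D) (for w2.grad, h1.grad, w1.grad, and x.grad respectively), i.e. twice the total forward FLOPs.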
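Combining the two counts, and writing P = D² + DK for the number of parameters in w1 and w2, one training step on B examples costs:

$$
\begin{aligned}
\text{forward}  &= 2BD^2 + 2BDK = 2B\,(D^2 + DK) = 2BP,\\
\text{backward} &= 4BD^2 + 4BDK = 4BP,\\
\text{total}    &= 2BP + 4BP = 6BP.
\end{aligned}
$$

That is, roughly 6 FLOPs per parameter per example: 2 in the forward pass and 4 in the backward pass.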

A nice graphical visualization:  [article]


```python
import torch

# =========================================
# Forward and backward FLOPs of a two-layer linear model
# =========================================

# Matrix sizes
if torch.cuda.is_available():
    B, D, K = 16384, 32768, 8192
else:
    B, D, K = 1024, 256, 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Inputs and parameters
x = torch.ones(B, D, device=device)
w1 = torch.randn(D, D, device=device, requires_grad=True)
w2 = torch.randn(D, K, device=device, requires_grad=True)

# Forward pass
h1 = x @ w1
h2 = h1 @ w2
loss = h2.pow(2).mean()

# Forward FLOPs: one matmul per layer
num_forward_flops = (2 * B * D * D) + (2 * B * D * K)  # @inspect num_forward_flops

# Keep the gradients of intermediate (non-leaf) tensors so they can be inspected
h1.retain_grad()
h2.retain_grad()

# Backward pass
loss.backward()

# Backward FLOPs: two matmuls per layer (gradient w.r.t. the weight and w.r.t. the input)
num_backward_flops = 0
num_backward_flops += 2 * B * D * K  # w2.grad
num_backward_flops += 2 * B * D * K  # h1.grad
num_backward_flops += (2 + 2) * B * D * D  # w1.grad + x.grad (counted even though x does not require grad here)  # @inspect num_backward_flops

# Check shapes
assert w2.grad.size() == torch.Size([D, K])
assert h1.size() == torch.Size([B, D])
assert h2.grad.size() == torch.Size([B, K])
assert h1.grad.size() == torch.Size([B, D])
assert w2.size() == torch.Size([D, K])
```
 

Model parameters

Parameter initialization

```python
import torch
import torch.nn as nn
import numpy as np

input_dim, output_dim = 16384, 32

# Model parameters are nn.Parameter, a subclass of torch.Tensor
w = nn.Parameter(torch.randn(input_dim, output_dim))
assert isinstance(w, torch.Tensor)
assert isinstance(w.data, torch.Tensor)

# Naive init: each output sums input_dim terms, so it grows like √input_dim
x = torch.randn(input_dim)
out1 = x @ w  # ~ O(√input_dim)

# Scale by 1/√input_dim (Xavier-style) so that outputs stay O(1)
w2 = nn.Parameter(torch.randn(input_dim, output_dim) / np.sqrt(input_dim))
out2 = x @ w2  # ~ O(1)

# Truncated normal (safer: rules out extreme values).
# a and b are absolute cutoffs, not multiples of std, so truncating at ±3 std needs a=-3*std, b=3*std
std = 1 / np.sqrt(input_dim)
w3 = nn.Parameter(
    nn.init.trunc_normal_(
        torch.empty(input_dim, output_dim),
        std=std,
        a=-3 * std, b=3 * std
    )
)
out3 = x @ w3  # ~ O(1)
```
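A quick empirical check of the scaling comments above (a sketch; the seed is arbitrary): with unscaled weights the output's standard deviation is on the order of √input_dim = 128, while dividing the weights by √input_dim brings it down to about 1.

```python
import torch

torch.manual_seed(0)
input_dim, output_dim = 16384, 32
x = torch.randn(input_dim)

naive_w = torch.randn(input_dim, output_dim)                       # entries with std 1
scaled_w = torch.randn(input_dim, output_dim) / input_dim ** 0.5   # entries with std 1/√input_dim

# Each output sums input_dim products, so its std grows like √input_dim unless the weights are scaled
naive_std = (x @ naive_w).std().item()
scaled_std = (x @ scaled_w).std().item()
print(naive_std, scaled_std)  # on the order of 128 and on the order of 1, respectively
```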

Custom model

```python
import torch
import torch.nn as nn
import numpy as np

# ---- A simple linear layer ----
class Linear(nn.Module):
    """Simple linear layer."""
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(input_dim, output_dim) / np.sqrt(input_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight

# ---- A deep linear model ----
class Cruncher(nn.Module):
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([
            Linear(dim, dim) for _ in range(num_layers)
        ])
        self.final = Linear(dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, D = x.size()
        for layer in self.layers:
            x = layer(x)
        # Final projection
        x = self.final(x)
        assert x.size() == torch.Size([B, 1])
        return x.squeeze(-1)

# ---- Helper functions ----
def get_num_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())

def get_device() -> torch.device:
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- custom_model example ----
def custom_model():
    D = 64   # Dimension
    num_layers = 2
    model = Cruncher(dim=D, num_layers=num_layers)

    # Check parameter sizes
    param_sizes = [(name, param.numel()) for name, param in model.state_dict().items()]
    assert param_sizes == [
        ("layers.0.weight", D * D),
        ("layers.1.weight", D * D),
        ("final.weight", D),
    ]

    # Total number of parameters
    num_parameters = get_num_parameters(model)
    assert num_parameters == (D * D) + (D * D) + D

    # Move to the GPU (if available)
    device = get_device()
    model = model.to(device)

    # Run the model
    B = 8  # Batch size
    x = torch.randn(B, D, device=device)
    y = model(x)  # shape [B] after the final squeeze
```
 

get batch

Memory management and asynchronous transfer

[article]

[article]

```python
import torch
import numpy as np

def get_batch(data: np.ndarray, batch_size: int, sequence_length: int, device: str) -> torch.Tensor:
    """
    Sample a random batch of sequences from data.

    Args:
        data: numpy array of data
        batch_size: number of sequences per batch
        sequence_length: length of each sequence
        device: target device ("cpu" or "cuda")

    Returns:
        x: torch.Tensor [batch_size, sequence_length]
    """
    # Randomly choose starting positions
    start_indices = torch.randint(len(data) - sequence_length, (batch_size,))
    assert start_indices.size() == torch.Size([batch_size])

    # Build the batch (stack into one array first; building a tensor from a list of arrays is slow)
    x = torch.from_numpy(
        np.stack([data[start:start + sequence_length] for start in start_indices.tolist()]).astype(np.float32)
    )
    assert x.size() == torch.Size([batch_size, sequence_length])

    # Pinned (page-locked) memory speeds up CPU -> GPU copies and allows them to be asynchronous
    if torch.cuda.is_available() and str(device).startswith("cuda"):
        x = x.pin_memory()

    # Transfer to the device; non_blocking=True lets the copy overlap with other work
    x = x.to(device, non_blocking=True)

    return x
```
 

Randomness and reproducibility

Setting random seeds

For full reproducibility, set the random seed in all three main libraries:
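The three libraries are typically Python's `random`, NumPy, and PyTorch. A minimal sketch, where `set_seed` is a hypothetical helper name:

```python
import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    """Hypothetical helper: seed Python, NumPy, and PyTorch together."""
    random.seed(seed)        # Python's built-in RNG
    np.random.seed(seed)     # NumPy's global RNG
    torch.manual_seed(seed)  # PyTorch's RNG on all devices

set_seed(0)
first = (random.random(), np.random.rand(), torch.randn(3))
set_seed(0)
second = (random.random(), np.random.rand(), torch.randn(3))

print(first[0] == second[0], first[1] == second[1], torch.equal(first[2], second[2]))  # True True True
```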

Data loading

Loading large datasets efficiently
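One common approach, sketched here as an assumption rather than as these notes' method: store token IDs in a flat binary file and open it with `np.memmap`, so slices are read from disk on demand instead of loading the whole dataset into RAM. The file name and dtype are illustrative.

```python
import os
import tempfile

import numpy as np

# Write a small illustrative token file (in practice, a large pre-tokenized dataset)
path = os.path.join(tempfile.mkdtemp(), "tokens.bin")
np.arange(10_000, dtype=np.uint16).tofile(path)

# Map the file into memory: data is only read when a slice is accessed
data = np.memmap(path, dtype=np.uint16, mode="r")
print(len(data), data[100:105].tolist())  # 10000 [100, 101, 102, 103, 104]
```

The resulting array can be sliced like any NumPy array, for example by a batch sampler such as `get_batch`.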

SGD (stochastic gradient descent)

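```python
from typing import Iterable

import torch
import torch.nn as nn

class SGD(torch.optim.Optimizer):
    def __init__(self, params: Iterable[nn.Parameter], lr: float = 0.01):
        super(SGD, self).__init__(params, dict(lr=lr))

    def step(self):
        for group in self.param_groups:
            lr = group["lr"]  # learning rate of this parameter group
            for p in group["params"]:
                if p.grad is None:  # skip parameters that received no gradient
                    continue
                grad = p.grad.data
                # Plain gradient step: p <- p - lr * grad
                p.data -= lr * grad
```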
 

Characteristics:

AdaGrad

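```python
from typing import Iterable

import torch
import torch.nn as nn

class AdaGrad(torch.optim.Optimizer):
    def __init__(self, params: Iterable[nn.Parameter], lr: float = 0.01):
        super(AdaGrad, self).__init__(params, dict(lr=lr))

    def step(self):
        for group in self.param_groups:
            lr = group["lr"]
            for p in group["params"]:
                if p.grad is None:  # skip parameters that received no gradient
                    continue
                state = self.state[p]  # per-parameter optimizer state
                grad = p.grad.data
                # Running sum of squared gradients (the optimizer state)
                g2 = state.get("g2", torch.zeros_like(grad))
                g2 += torch.square(grad)
                state["g2"] = g2
                # Per-coordinate step: larger accumulated gradients give smaller steps
                p.data -= lr * grad / torch.sqrt(g2 + 1e-5)
```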
 

Characteristics:

Relationships among optimizer families

Reference paper: AdaGrad

Basic training loop

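```python
import torch.nn.functional as F

# Assumes Cruncher, get_device, and SGD from the sections above
def train(name: str, get_batch, D: int, num_layers: int, B: int, num_train_steps: int, lr: float):
    model = Cruncher(dim=D, num_layers=num_layers).to(get_device())
    optimizer = SGD(model.parameters(), lr=lr)
    for t in range(num_train_steps):
        # Get a batch of data
        x, y = get_batch(B=B)
        # Forward pass: compute the loss
        pred_y = model(x)
        loss = F.mse_loss(pred_y, y)
        # Backward pass: compute gradients
        loss.backward()
        # Update the parameters
        optimizer.step()
        # Free the gradients before the next step
        optimizer.zero_grad(set_to_none=True)
```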
 

Data generation example

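```python
import torch

# Assumes get_device from the sections above
def get_batch(B: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Synthetic regression data: y = x @ true_w with fixed true weights 0, 1, ..., D-1
    D = 16
    true_w = torch.arange(D, dtype=torch.float32, device=get_device())
    x = torch.randn(B, D).to(get_device())
    true_y = x @ true_w
    return x, true_y
```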
 

Checkpointing

Save the training state periodically so that an interruption does not lose training progress:

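```python
import torch

# Assumes Cruncher, AdaGrad, and get_device from the sections above
model = Cruncher(dim=64, num_layers=3).to(get_device())
optimizer = AdaGrad(model.parameters(), lr=0.01)

# Save
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}
torch.save(checkpoint, "model_checkpoint.pt")

# Load, then restore the saved state into the model and the optimizer
loaded_checkpoint = torch.load("model_checkpoint.pt")
model.load_state_dict(loaded_checkpoint["model"])
optimizer.load_state_dict(loaded_checkpoint["optimizer"])
```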
 

Mixed precision training

PyTorch AMP: official documentation

NVIDIA Transformer Engine supports FP8: reference paper
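A minimal sketch of PyTorch AMP's `torch.autocast` context manager: inside it, eligible operations such as matmuls run in the lower-precision dtype, while tensors created outside (e.g. parameters) stay in float32. CPU with bfloat16 is used here only so that the sketch runs without a GPU; on NVIDIA GPUs one would pass `device_type="cuda"`.

```python
import torch

x = torch.randn(8, 16)
w = torch.randn(16, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = x @ w  # the matmul is autocast to bfloat16

print(x.dtype, w.dtype, y.dtype)  # torch.float32 torch.float32 torch.bfloat16
```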


Memory and FLOPs estimation

```python
import torch
import torch.nn as nn

def get_memory_usage(x: torch.Tensor) -> int:
    # bytes = number of elements * bytes per element
    return x.numel() * x.element_size()

def get_num_parameters(model: nn.Module) -> int:
    return sum(param.numel() for param in model.parameters())

def get_promised_flop_per_sec(device: str, dtype: torch.dtype) -> float:
    ...
```
 

Example:
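A sketch of such an estimate for the Cruncher model defined earlier (num_layers D×D layers plus a final D×1 layer). It assumes float32 everywhere, AdaGrad as the optimizer (one extra state value per parameter), only the B×D output of each hidden layer counted as activations, and the 6 FLOPs per parameter per example from the gradient FLOPs section; the sizes are illustrative.

```python
def estimate(B: int, D: int, num_layers: int, bytes_per_value: int = 4) -> tuple[int, int]:
    num_parameters = D * D * num_layers + D   # hidden layers + final projection
    num_activations = B * D * num_layers      # one B x D activation per hidden layer
    num_gradients = num_parameters            # one gradient per parameter
    num_optimizer_states = num_parameters     # AdaGrad's running sum g2
    total_memory = bytes_per_value * (
        num_parameters + num_activations + num_gradients + num_optimizer_states
    )
    flops = 6 * B * num_parameters            # 2 forward + 4 backward FLOPs per parameter per example
    return total_memory, flops

memory_bytes, flops = estimate(B=8, D=64, num_layers=2)
print(memory_bytes, flops)  # 103168 396288
```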


Utility functions

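```python
import torch

def same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
    # Two tensors share storage if their untyped storages start at the same address
    return x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr()

def time_matmul(a: torch.Tensor, b: torch.Tensor) -> float:
    ...

def get_device(index: int = 0) -> torch.device:
    # Use the index-th GPU if one is available, otherwise fall back to the CPU
    if torch.cuda.is_available():
        return torch.device(f"cuda:{index}")
    else:
        return torch.device("cpu")
```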
 

7. Summary