• Home
  • Blog
  • Reading
  • Travel
  • Projects

SERIES · 斯坦福CS336: Language Modeling from Scratch

Stanford CS336: assignment 1

2025-11-22 · 60 min read · by GUMP

Stanford CS336: assignment 1

作业一官方仓库:https://github.com/stanford-cs336/assignment1-basics

Next in Series

Stanford CS336 Lang. Modeling from Scratch | Spring 2025 | Lec. 3: Architectures, Hyperparameters -><- Stanford CS336: lecture 2 Pytorch, Resource Accounting

This Series

  1. 01Stanford CS336: lecture 1 Overview, tokenization
  2. 02Stanford CS336: lecture 2 Pytorch, Resource Accounting
  3. 03Stanford CS336: assignment 1
  4. 04Stanford CS336 Lang. Modeling from Scratch | Spring 2025 | Lec. 3: Architectures, Hyperparameters

Tags

#large-language-model

On This Page

  1. assignment 1
  2. 字节对编码(BPE)分词器
  3. 2.1 Unicode standard
  4. Problem (unicode1): Understanding Unicode (1 point)
  5. 2.2 Unicode Encodings
  6. Problem (unicode2): Unicode Encodings
  7. 2.3 Subword Tokenization
  8. 2.4 BPE Tokenizer Training
  9. Example (bpe_example): BPE training example
  10. 2.5 Experimenting with BPE Tokenizer Training
  11. 并行化预分词处理
  12. 预分词前移除特殊标记
  13. 优化合并步骤
  14. 实施要点
  15. Problem (train_bpe): BPE Tokenizer Training
  16. Problem (train_bpe_tinystories): BPE Training on TinyStories
GUMP'S WORLD

A personal archive of writing, reading notes, travel journals, and software projects.

  • Home
  • Blog
  • Reading
  • Travel
  • Projects
GitHub ProfileEmail Me

© 2026 GUMP's World. All rights reserved.

assignment 1

低资源/降级技巧:性能分析 应使⽤如 cProfile 或 scalene 等性能分析⼯具来识别实现中的瓶颈,并重点优化这些部分。

字节对编码(BPE)分词器

2.1 Unicode standard

Problem (unicode1): Understanding Unicode (1 point)

  1. What Unicode character does chr(0) return?

    1. '\x00’ ,返回 Unicode 码点 U+0000,对应NULL空字符。这个字符在Unicode标准中被称为"NULL"或"NUL",是一个控制字符。
  2. How does this character’s string representation (repr()) differ from its printed representation?

    1. __repr__() / repr()表示:显示为 '\x00',这是一个可读的转义序列表示,明确显示这是一个十六进制值为00的字符。
    2. 打印表示(print):NULL字符是一个不可见的控制字符,所以当你打印它时,通常不会显示任何可见内容,或者可能显示为一个空白。

    这种区别的存在是因为:

    • repr() 的目的是提供一个"开发者友好"的字符串表示,能够明确显示字符的实际值
    • print() 的目的是显示字符的实际外观,但NULL字符本身就是不可见的
    python
    # 演示 chr(0) 的不同表示方式
     
    null_char = chr(0)
     
    print("字符本身:", null_char)
    print("字符串表示 (__repr__):", repr(null_char))
    print("字符的Unicode编码点:", ord(null_char))
    print("使用print()打印:", end="")
    print(null_char)
    print("(上面是打印的结果)")
     
    # 更清楚地看到区别
    print("\n比较:")
    print(f"repr(chr(0)) = {repr(null_char)}")
    print(f"print(chr(0))打印的是一个不可见字符")
    print(f"字符串长度: {len(null_char)}")
     
    # 在字符串中的表现
    test_string = "Hello" + chr(0) + "World"
    print(f"\n包含null字符的字符串:")
    print(f"repr: {repr(test_string)}")
    print(f"print: {test_string}")
    print(f"长度: {len(test_string)}")
    python
    字符本身:
    字符串表示 (__repr__): '\x00'
    字符的Unicode编码点: 0
    使用print()打印:
    (上面是打印的结果)
     
    比较:
    repr(chr(0)) = '\x00'
    print(chr(0))打印的是一个不可见字符
    字符串长度: 1
     
    包含null字符的字符串:
    repr: 'Hello\x00World'
    print: HelloWorld
    长度: 11
  3. 当该字符出现在文本中时会发生什么?

    1. chr(0)

      返回 NULL 字符,即 '\x00'。它在屏幕上不可见,但仍然是字符串中的一个有效字符。

    2. print(chr(0))

      打印 NULL 字符时,因为它没有可显示的符号,看起来是空白的。

    3. "this is a test" + chr(0) + "string"

      生成一个字符串,其中 "test" 和 "string" 之间包含一个 NULL 字符。

      它的 repr 表示为:

      python
      'this is a test\x00string'
    4. print("this is a test" + chr(0) + "string")

      打印出来时,NULL 字符依然不可见,所以屏幕上看起来是:

      this is a teststring
      

      但在内存中,'\x00' 仍然存在。chr(0) 会在文本中插入一个隐藏的 NULL 字符。Python 中字符串不会报错,但如果把这个字符串写入文件或传给某些 C 语言库,'\x00' 可能会被当作字符串结束符,导致后面的内容被截断。

2.2 Unicode Encodings

Unicode 与编码

  • Unicode 标准定义了 字符到码点(整数)的映射。
  • 直接用 Unicode 码点训练 tokenizer 不现实:
    • 词表太大(~150K)
    • 稀疏(很多字符很少出现)

解决方法:Unicode 编码

  • 将 Unicode 字符转换为 字节序列。
  • 标准三种编码:
    • UTF-8(互联网主流,占比 >98%)
    • UTF-16
    • UTF-32

Python 示例:UTF-8 编码

python
test_string = "hello! こんにちは!"
 
# 字符串编码为 UTF-8
utf8_encoded = test_string.encode("utf-8")
print(utf8_encoded)
# b'hello! \\xe3\\x81\\x93\\xe3\\x82\\x93\\xe3\\x81\\xab\\xe3\\x81\\xa1\\xe3\\x81\\xaf!'
 
# 类型:bytes
print(type(utf8_encoded))  # <class 'bytes'>
 
# 获取底层字节值(范围 0~255)
list(utf8_encoded)
# [104, 101, 108, 108, 111, 33, 32, 227, 129, 147, 227, 130, 147, 227, 129, 171, 227, 129, 161, 227, 129, 175, 33]
 
# 字符数 vs 字节数
len(test_string)      # 13
len(utf8_encoded)     # 23
 
# 解码回 Unicode 字符串
print(utf8_encoded.decode("utf-8"))  # "hello! こんにちは!"

关键现象

  • bytes 类型可通过 list() 查看每个字节的整数值(0–255)。
  • 一个 Unicode 字符 ≠ 一个字节(UTF-8 可变长编码)。
  • 示例:
    • len("hello! こんにちは!") = 13 (13 个字符)
    • len(utf8_encoded) = 23 (23 个字节)

意义

  • 将 Unicode 码点序列(0–154,997) 转换为 字节序列(0–255)。
  • 字节词表大小仅 256,更易处理。
  • 字节级分词优势:
    • 不会有 OOV(out-of-vocabulary)问题。
    • 任意文本都能转化为字节序列。

Problem (unicode2): Unicode Encodings

  1. 为什么我们更倾向于使用UTF-8编码的字节来训练我们的分词器,而不是UTF-16或UTF-32?比较这些编码对不同输入字符串的输出结果可能会有所帮助。

    1. 对常见文本更紧凑。UTF-8:ASCII 字符只需 1 个字节。UTF-16:至少 2 个字节,即使是英文字母、数字、空格。UTF-32:永远 4 个字节。因为大多数训练数据中包含大量 ASCII(标点、空格、英语),所以 UTF-8 更省空间。
    2. **互联网标准。**UTF-8 是互联网的主流编码(>98% 网页使用)。数据管道、API、文件几乎都默认 UTF-8。直接用 UTF-8 可以避免不必要的转换。
    3. **词表大小一样,但 UTF-8 序列更短。**不管 UTF-8、UTF-16 还是 UTF-32,底层都是 0–255 的字节。区别在于:UTF-8 需要的字节数更少(尤其对混合文本)。结果:文本 token 数更少。训练序列更短。训练推理速度更快。
    4. **没有字节序 (endianness) 问题。**UTF-16 / UTF-32 可能是 大端 (BE) 或 小端 (LE),还可能需要字节顺序标记 (BOM)。UTF-8 没有这个问题,格式唯一,处理更简单。
  2. 考虑以下(错误的)函数,其设计⽬的是将 UTF-8 字节字符串解码为 Unicode 字符串。为何该函数是错误的?请提供⼀个会产⽣错误结果的输⼊字节字符串⽰例。

    python
    def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    	return "".join([bytes([b]).decode("utf-8") for b in bytestring])
    >>> decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))
    'hello'
    python
    >>> decode_utf8_bytes_to_str_wrong("你好".encode("utf-8"))
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "<stdin>", line 2, in decode_utf8_bytes_to_str_wrong
      File "<stdin>", line 2, in <listcomp>
    UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe4 in position 0: unexpected end of data

    原因: 该函数逐字节调用 .decode("utf-8"),而 UTF-8 中的多字节字符(如中文)需要按完整字节序列解码,逐字节解码会导致乱码。

    python
    # 错误的函数
    def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
        return "".join([bytes([b]).decode("utf-8") for b in bytestring])
     
    # 演示问题的示例
    test_string = "你好"  # 包含中文字符
    byte_string = test_string.encode("utf-8")
     
    print(f"原始字符串: {test_string}")
    print(f"UTF-8 字节: {byte_string}")
    print(f"字节列表: {list(byte_string)}")
     
    # 尝试错误的函数 - 这会引发错误
    try:
        result = decode_utf8_bytes_to_str_wrong(byte_string)
        print(f"错误函数结果: {result}")
    except UnicodeDecodeError as e:
        print(f"错误: {e}")
     
    # 正确的解码方式
    correct_result = byte_string.decode("utf-8")
    print(f"正确结果: {correct_result}")
     
    # 解释为什么会失败:中文字符在UTF-8中需要3个字节
    print(f"\n字符'你'的UTF-8编码: {list('你'.encode('utf-8'))}")
    print(f"字符'好'的UTF-8编码: {list('好'.encode('utf-8'))}")
    print("单个字节无法作为有效的UTF-8字符解码")
  3. 给出⼀个两字节序列,该序列⽆法解码为任何 Unicode 字符。

    1. 示例: b'\xc0\x80'。解释: 这个字节序列是一个无效的UTF-8序列,因为它是字符U+0000(null字符)的"过长编码"形式——UTF-8标准禁止使用多字节序列来编码本可以用更短序列表示的字符,以防止安全漏洞和编码歧义。
    python
    # 无效的UTF-8字节序列示例
    invalid_sequence = b'\xc0\x80'
     
    print(f"字节序列: {invalid_sequence}")
    print(f"十六进制表示: {invalid_sequence.hex()}")
     
    # 尝试解码 - 这会失败
    try:
        decoded = invalid_sequence.decode('utf-8')
        print(f"解码结果: {decoded}")
    except UnicodeDecodeError as e:
        print(f"解码错误: {e}")
     
    # 解释为什么这是无效的
    print("\n解释:")
    print("0xC0 0x80 是null字符(U+0000)的过长编码")
    print("正确的UTF-8编码中,U+0000应该编码为单字节0x00")
    print("UTF-8标准禁止过长编码以防止安全漏洞")
     
    # 其他无效序列的例子
    other_invalid = [
        b'\xff\xff',  # 0xFF在UTF-8中永远无效
        b'\xc0\xc0',  # 无效的起始字节组合
        b'\x80\x80',  # 孤立的继续字节
    ]
     
    print("\n其他无效UTF-8序列的例子:")
    for seq in other_invalid:
        try:
            seq.decode('utf-8')
            print(f"{seq.hex()}: 有效")
        except UnicodeDecodeError:
            print(f"{seq.hex()}: 无效")

2.3 Subword Tokenization

  • 背景
    • 词级分词器:词汇外(OOV)问题严重,但序列较短(例如 10 个词 = 10 个 token)。
    • 字节级分词器:没有 OOV 问题,但序列过长(10 个词可能变成 50+ 个 token),训练更慢,并引入更长的依赖关系。
  • 子词分词器(Subword Tokenizer)
    • 介于词级和字节级之间,平衡 OOV 问题和序列长度。
    • 字节级分词器词表只有 256 个 token(所有字节)。
    • 子词分词器通过增加词表规模,压缩输入字节序列。
    • 例如:字节序列 b'the' 出现频繁 → 在词表中加入 'the' → 原本 3 个 token 变成 1 个 token。
  • BPE(Byte-Pair Encoding)
    • 一种压缩算法(Gage, 1994;Sennrich et al., 2016 提出用于 NLP)。
    • 迭代合并训练语料中最频繁的字节对,生成新的 token。
    • 如果一个词出现足够多次,它就会被整体作为单个子词单元。
  • 结果
    • BPE 分词器(BPE Tokenizer) = 字节或合并后的字节序列。
    • 既能避免 OOV 问题,又能保持合理的输入长度。
    • 词表的构建过程称为 训练分词器。

2.4 BPE Tokenizer Training

BPE 分词器的训练过程主要分为 三步:

  1. Vocabulary Initialization(词表初始化)

    • 初始词表为 所有可能的字节值(0–255,共 256 个 token)。
    • 每个字节对应一个唯一的整数 ID。
  2. Pre-tokenization(预分词)

    • 目的:避免逐字节扫描语料库的高计算开销,并减少语义相近但带标点的词被割裂(如 dog! vs dog.)。
    • 方法:
      • 将语料进行粗粒度分词(pre-token)。
      • 每个 pre-token 转换为 UTF-8 字节序列。
      • 统计字节对的频率时,按 pre-token 出现次数加权(如 'text' 出现 10 次,则 't' 与 'e' 的相邻统计 +10)。
    • 实现方式:
      • Sennrich et al. (2016):用空格分词 s.split(" ")。

      • GPT-2 / tiktoken:使用 正则表达式预分词器:

        python
        PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
         
      • 推荐用 re.finditer 来遍历预分词结果(避免存储大规模 token 列表)。

  3. Compute BPE Merges(计算 BPE 合并)

    • 算法核心:
      1. 迭代统计所有相邻字节对的频率。
      2. 找到频率最高的字节对 (A, B)。
      3. 将所有 (A, B) 替换为新 token "AB"。
      4. 将 "AB" 加入词表。
    • 结果:最终词表大小 = 初始 256 + BPE 合并次数。
    • 注意事项:
      • 不跨 pre-token 边界统计合并对。

      • 若存在频率并列,按 字典序最大 的 pair 合并:

        python
        max([("A", "B"), ("A", "C"), ("B", "ZZ"), ("BA", "A")])
        # 结果: ('BA', 'A')
         
  4. Special Tokens(特殊 token)

    • 一些字符串(如 &lt;|endoftext|&gt;)需保留为单一 token,不应被拆分。
    • 在训练时,必须把这些特殊 token 加入词表,并为其分配固定 ID。
  5. 参考实现

    • Sennrich et al. (2016) 的 Algorithm 1:给出了一版低效的 BPE 实现,可作为练习帮助理解上述过程。

👉 总结一句话:

BPE 分词器训练 = 初始化 256 字节词表 → 预分词(统计频率)→ 迭代字节对合并 → 加入特殊 token,最终得到既紧凑又可扩展的词表。

Example (bpe_example): BPE training example

以下是 Sennrich 等⼈[2016]提出的⻛格化⽰例。假设语料库包含如下⽂本:

low low low low low

lower lower widest widest widest

newest newest newest newest newest newest

且词汇表中包含特殊标记**<|endoftext|>**。

Vocabulary 我们使⽤特殊标记**<|endoftext|>**和 256 个字节值初始化词汇表。

Pre-tokenization 为简化流程并专注于合并步骤,本例中我们假设预分词仅按空格分割。经预分词统计后,我们得到如下频次表。

{low: 5, lower: 2, widest: 3, newest: 6}

⽅便将其表⽰为字典类型dict[tuple[bytes], int],例如{(l,o,w): 5 …}。需要注意的是,在 Python 中即使单个字节也是 bytes 对象。Python 中没有单独的 byte 类型来表⽰单个字节,就像没有 char 类型来表⽰单个字符⼀样。

⾸先,我们查看每对相邻字节,并统计它们所在单词的出现频率{lo:7, ow:7, we:8, er:2, wi:3, id:3, de:3, es:9, st:9, ne:6, ew:6}。其中(es)和(st)出现频率相同,因此选择字典序较⼤的(st)进⾏合并。随后我们将预标记合并为{(l,o,w):5, (l,o,w,e,r):2, (w,i,d,e,st):3, (n,e,w,e,st):6}。

第⼆轮中,(e, st)成为最⾼频对(出现 9 次),合并后得到{(l,o,w):5, (l,o,w,e,r):2, (w,i,d,est):3, (n,e,w,est):6}。继续该过程,最终得到的合并序列为['s t', 'e st', 'o w', 'l ow', 'w est', 'n e', 'ne west', 'w i', 'wi d', 'wid est', 'lowe', 'lowe r']。

若进⾏ 6 次合并,得到['s t', 'e st', 'o w', 'l ow', 'w est', 'n e'],此时词汇表元素为[&lt;|endoftext|&gt;, [...256 字节字符], st, est, ow, low, west, ne]。

基于该词汇表和合并规则,单词"newest"将被分词为[ne, west]。

2.5 Experimenting with BPE Tokenizer Training

在 TinyStories 数据集上训练一个字节级 BPE(Byte Pair Encoding)分词器,并通过多进程并行化来优化性能。

并行化预分词处理

问题: 预分词是主要性能瓶颈 解决方案: 使用 multiprocessing 库进行并行处理

关键实现:

  • 将语料库按特殊标记 &lt;|endoftext|&gt; 分块
  • 确保分块边界始终在特殊标记开头,避免跨文档合并
  • 使用提供的 find_chunk_boundaries() 函数获取分块边界

分块原理:

python
# 每个进程处理一个文本块
for start, end in zip(boundaries[:-1], boundaries[1:]):
    f.seek(start)
    chunk = f.read(end - start).decode("utf-8", errors="ignore")
    # 对chunk进行预分词处理
 

优势:

  • 语义完整性:不会在文档中间切断
  • 独立处理:各块可以并行处理
  • 内存效率:每次只处理部分数据

预分词前移除特殊标记

目的: 防止跨文档边界的合并操作

处理流程:

  1. 按特殊标记分割文本块:[文档1] &lt;|endoftext|&gt; [文档2]
  2. 移除特殊标记
  3. 分别对每个文档进行预分词:[文档1] 和 [文档2]

实现方式:

python
# 使用 re.split 按特殊标记分割
pattern = "|".join(re.escape(token) for token in special_tokens)
documents = re.split(pattern, text_chunk)
 

注意事项:

  • 使用 re.escape() 处理特殊字符(如 |)
  • 避免跨文档的字节对合并

优化合并步骤

问题: 朴素实现每次合并都要遍历所有字节对 解决方案: 建立计数索引,增量更新统计

优化策略:

  • 建立索引: 创建所有字节对的计数缓存
  • 增量更新: 只更新与已合并字节对重叠的配对统计
  • 避免全遍历: 不再显式遍历每个字节对进行频率统计

性能提升:

  • 显著加速 BPE 训练过程
  • 减少重复计算开销

限制:

  • 合并环节在 Python 中无法并行化
  • 但缓存机制仍能带来可观加速

实施要点

数据处理流程

  1. 数据准备: 查看和理解 TinyStories 数据集结构
  2. 文件分块: 使用 find_chunk_boundaries() 按 &lt;|endoftext|&gt; 分块
  3. 并行预分词: 多进程处理各个文本块
  4. 特殊标记处理: 移除标记,避免跨文档合并
  5. BPE训练: 使用优化的合并算法

关键技术细节

  • 分块边界: 确保在 &lt;|endoftext|&gt; 处分割
  • 内存管理: 使用 4KB 小块逐步读取
  • 编码处理: 使用 decode("utf-8", errors="ignore") 处理编码问题
  • 去重排序: sorted(set(chunk_boundaries)) 确保边界唯一性

测试验证

  • 使用 test_train_bpe_special_tokens 测试用例验证特殊标记处理功能
  • 确保分块和合并逻辑的正确性

性能优化总结

  1. 并行化: 预分词阶段通过多进程提升速度
  2. 智能分块: 保证语义完整性的同时实现并行处理
  3. 缓存优化: 合并步骤使用增量更新减少计算开销
  4. 内存优化: 分块读取避免内存溢出

这种实现方式在保证训练质量的同时,显著提升了 BPE 分词器的训练效率。

低资源/降级技巧:性能分析 应使⽤如 cProfile 或 scalene 等性能分析⼯具来识别实现中的瓶颈,并重点优化这些部分。

低资源/降级技巧:"降级处理" 建议先在小规模数据子集(如 TinyStories 验证集,2.2 万份文档而非 212 万份)上训练分词器。这体现了一种通用的开发加速策略:通过降级处理(更小的数据集、更小的模型等)来快速迭代。调试集规模和超参数需权衡:既要足够大以保留主要瓶颈特征,保证优化措施具普适性,又不能过大以免运行过慢。

Problem (train_bpe): BPE Tokenizer Training

Deliverable: Write a function that, given a path to an input text file, trains a (byte-level) BPE tokenizer. Your BPE training function should handle (at least) the following input arameters:

input_path: str Path to a text file with BPE tokenizer training data.

vocab_size: int A positive integer that defines the maximum final vocabulary size (including the initial byte vocabulary, vocabulary items produced from merging, and any special tokens).

special_tokens: list[str] A list of strings to add to the vocabulary. These special tokens do not otherwise affect BPE training.

Your BPE training function should return the resulting vocabulary and merges:

vocab: dict[int, bytes] The tokenizer vocabulary, a mapping from int (token ID in the vocabulary) to bytes (token bytes).

merges: list[tuple[bytes, bytes]] A list of BPE merges produced from training. Each list item is a tuple of bytes (<token1>, <token2>), representing that <token1> was merged with <token2>. The merges should be ordered by order of creation.

To test your BPE training function against our provided tests, you will first need to implement the test adapter at [adapters.run_train_bpe]. Then, run uv run pytest tests/test_train_bpe.py. Your implementation should be able to pass all tests. Optionally (this could be a large time-investment), you can implement the key parts of your training method using some systems language, for instance C++ (consider cppyy for this) or Rust (using PyO3). If you do this, be aware of which operations require copying vs reading directly from Python memory, and make sure to leave build instructions, or make sure it builds using only pyproject.toml. Also note that the GPT-2 regex is not well-supported in most regex engines and will be too slow in most that do. We have verified that Oniguruma is reasonably fast and supports negative lookahead, but the regex package in Python is, if anything, even faster.

Problem (train_bpe_tinystories): BPE Training on TinyStories

  1. Train a byte-level BPE tokenizer on the TinyStories dataset, using a maximum vocabulary sizeof 10,000. Make sure to add the TinyStories <|endoftext|> special token to the vocabulary. Serialize the resulting vocabulary and merges to disk for further inspection. How many hoursand memory did training take? What is the longest token in the vocabulary? Does it make sense?

    Resource requirements: ≤ 30 minutes (no GPUs), ≤ 30GB RAM

    Hint You should be able to get under 2 minutes for BPE training using multiprocessing during pretokenization and the following two facts:

    1. The <|endoftext|> token delimits documents in the data files.
    2. The <|endoftext|> token is handled as a special case before the BPE merges are applied.

    Deliverable: A one-to-two sentence response.

  2. Profile your code. What part of the tokenizer training process takes the most time?Deliverable: A one-to-two sentence response.

Build and Train GPT-4 Tokenizer from scratch

Implementing A Byte Pair Encoding (BPE) Tokenizer From Scratch

How to Train BPE, WordPiece, and Unigram Tokenizers from Scratch using Hugging Face

NLP From Scratch— Part 1: BPE Tokenization

https://github.com/karpathy/minbpe

https://github.com/JohannesVod/QuickBPE

https://github.com/openai/tiktoken

https://github.com/glample/fastBPE

https://github.com/bheinzerling/bpemb