第10章 10.2 单元测试与 pytest
「上章回顾」 上一章我们学会了如何把一个项目从「单文件乱炖」整理成「分门别类的文件夹+文件」,小明的代码终于看起来像个正经项目了。但是,新的问题来了——每次改动代码后,怎么知道自己的修改有没有把原来能跑的功能搞坏了?
「本章目标」 读完这篇文章,你就能:
- 明白「单元测试」到底是什么、为什么科技公司都要求「测试驱动开发」
- 用 pytest 写出能自动跑、帮你发现 bug 的测试代码
- 用 fixture 管理测试数据,像搭积木一样复用测试环境
- 课程结束时,你会有一个「带完整测试用例的个人工具箱」
🎯 开场 3 分钟:为什么要学这个?
一个真实场景
想象你是个厨师,今天要往「红烧肉」食谱里加一味新调料。改完之后,你得尝一口才知道味道对不对——但如果这道菜已经被 1000 个人点过了,你怎么确保新调料没把老味道毁了?
代码也是这样。
小明上周给「购物车」模块加了个「打折功能」,结果把原来的「计算总价」功能搞坏了。因为他只手动测了打折后的价格,没测不打折的情况。上线后,用户发现不打折的订单全部显示 NaN(Not a Number,意思是「不是数字」)。
两个痛点
- 手动测试太累:每次改代码,要手动跑一遍所有功能。代码量超过 1000 行时,根本测不过来。
- 回归 bug 太可怕:你修复了一个 bug,结果引入了三个新 bug。上线后用户骂声一片。
学完能解决什么
学完这一章,你写的每一个「函数」都会有对应的「测试函数」。改代码后一键跑测试,谁的 bug 引入的清清楚楚。
🧱 基础 25 分钟:核心概念
10.2.1 什么是单元测试?为什么要用?
生活类比:产品质检员
想象工厂流水线生产手机。每个零件(电池、屏幕、芯片)在组装前,都要单独检测——这就是「单元测试」。只有每个零件都合格,整机才可能合格。
代码里的「单元」通常指的是最小的可测试单元:一个函数、一个类方法。
为什么要单元测试?
| 不用测试 | 用测试 |
|---|---|
| 改代码靠「运气」 | 改代码靠「证据」 |
| bug 偷偷上线 | bug 在本地就被抓住 |
| 加班到半夜修 bug | 早发现早治疗 |
| 改一处坏三处 | 改一处只有一处可能坏 |
10.2.2 assert 断言:最简单的测试
Python 自带的 assert 语句,就是最原始的「测试」。
def add(a, b):
return a + b
# 用 assert 来「验证」结果
result = add(2, 3)
assert result == 5, "2 + 3 应该等于 5 啊,怎么会等于 {}".format(result)
print("测试通过!2 + 3 = {}".format(result))
运行结果:
测试通过!2 + 3 = 5
如果 add(2, 3) 返回了 6,assert 会抛出 AssertionError。
说白了:assert 条件, "失败时的提示" 就是「如果这个条件不成立,就报错」。
10.2.3 unittest:Python 内置测试框架
Python 自带 unittest 模块,是第一个标准的测试工具。
import unittest
def add(a, b):
return a + b
class TestAddFunction(unittest.TestCase):
"""测试 add 函数的各种场景"""
def test_positive_numbers(self):
"""测试正数相加"""
self.assertEqual(add(2, 3), 5)
def test_negative_numbers(self):
"""测试负数相加"""
self.assertEqual(add(-1, -1), -2)
def test_mixed_numbers(self):
"""测试一正一负"""
self.assertEqual(add(-1, 1), 0)
if __name__ == "__main__":
unittest.main()
运行结果:
...
----------------------------------------------------------------------
Ran 3 tests in 0.001s
OK
三个要点:
- unittest.TestCase 是「测试用例」的基类,继承它才有测试能力
- self.assertEqual(a, b) 是「断言 a 等于 b」,不等就报错
- 方法名必须以 test_ 开头,框架会自动找到这些方法

10.2.4 pytest:更优雅的测试框架
unittest 虽好,但写起来有点啰嗦。pytest 是社区最喜欢的测试框架,语法更简洁,插件更丰富。
先安装:
pip install pytest
用 pytest 重写上面的测试:
def add(a, b):
return a + b
def test_positive_numbers():
"""测试正数相加"""
assert add(2, 3) == 5
def test_negative_numbers():
"""测试负数相加"""
assert add(-1, -1) == -2
def test_mixed_numbers():
"""测试一正一负"""
assert add(-1, 1) == 0
对比一下:
- 不需要 import unittest
- 不需要继承 unittest.TestCase
- 不需要 self.assertEqual,直接用 assert
- 方法名还是需要 test_ 开头
运行方式:
pytest test_file.py -v
-v 表示 verbose(详细模式),会显示每个测试的名字和结果。
10.2.5 fixture:测试的「准备工作」和「清理工作」
生活类比:考试前的布置考场
每次考试(测试)前,需要有人摆桌子、发试卷、贴考号。考完后需要有人收试卷、清扫教室。fixture 就是干这个的。
场景:测试「用户注册」功能
每个测试用户需要:名字、邮箱、密码。如果每个测试函数都手动创建一遍,代码会又臭又长。
用 pytest fixture:
import pytest
@pytest.fixture
def test_user():
"""创建一个测试用户,供所有测试函数使用"""
user = {"name": "小明", "email": "xiaoming@example.com", "password": "123456"}
return user
def test_user_name(test_user):
"""测试用户名是否正确"""
assert test_user["name"] == "小明"
def test_user_email(test_user):
"""测试用户邮箱是否正确"""
assert test_user["email"] == "xiaoming@example.com"
def test_user_password_length(test_user):
"""测试密码长度"""
assert len(test_user["password"]) == 6
运行结果:
pytest test_fixture.py -v
========================== test session starts ===========================
collected 3 items
test_fixture.py::test_user_name PASSED [ 33%]
test_fixture.py::test_user_email PASSED [ 66%]
test_fixture.py::test_user_password_length PASSED [100%]
========================== 3 passed in 0.02s ===========================
说白了:@pytest.fixture 装饰的函数返回值,可以直接作为测试函数的参数传入。pytest 会自动帮你管理创建和销毁。
10.2.6 参数化测试:同一个测试,多组数据
场景:测试「计算器」的「除法」函数
除法有这些边界情况:
- 正数除以正数:6 / 2 = 3
- 负数除以正数:-6 / 2 = -3
- 零作为被除数:0 / 5 = 0
- 零作为除数:5 / 0 = 报错
如果写四个测试函数,代码重复率很高。用 @pytest.mark.parametrize:
import pytest
def divide(a, b):
if b == 0:
raise ValueError("除数不能为 0")
return a / b
@pytest.mark.parametrize("a,b,expected", [
(6, 2, 3), # 正数除以正数
(-6, 2, -3), # 负数除以正数
(0, 5, 0), # 零作为被除数
(10, 4, 2.5), # 整数除法结果是小数
])
def test_divide(a, b, expected):
"""测试除法运算"""
assert divide(a, b) == expected
def test_divide_by_zero():
"""测试除数为零的情况"""
with pytest.raises(ValueError):
divide(5, 0)
运行结果:
========================== test session starts ===========================
test_param.py::test_divide[6-2-3] PASSED [ 20%]
test_param.py::test_divide[-6-2--3] PASSED [ 40%]
test_param.py::test_divide[0-5-0] PASSED [ 60%]
test_param.py::test_divide[10-4-2.5] PASSED [ 80%]
test_param.py::test_divide_by_zero PASSED [100%]
========================== 5 passed in 0.03s ===========================

🔥 实战 35 分钟:3 个递进的小项目
项目 1:给「计算器」函数写测试(5 分钟)
目标:理解 pytest 基本用法,会写能跑的测试
# calculator.py
"""一个简单的计算器模块"""
def add(a, b):
"""加法"""
return a + b
def subtract(a, b):
"""减法"""
return a - b
def multiply(a, b):
"""乘法"""
return a * b
def divide(a, b):
"""除法"""
if b == 0:
raise ValueError("除数不能为 0")
return a / b
# test_calculator.py
"""测试计算器模块"""
import pytest
from calculator import add, subtract, multiply, divide
class TestCalculator:
"""计算器测试套件"""
def test_add(self):
assert add(2, 3) == 5
assert add(-1, 1) == 0
def test_subtract(self):
assert subtract(10, 3) == 7
assert subtract(3, 10) == -7
def test_multiply(self):
assert multiply(3, 4) == 12
assert multiply(-2, 3) == -6
def test_divide(self):
assert divide(10, 2) == 5
assert divide(-6, 2) == -3
def test_divide_by_zero(self):
with pytest.raises(ValueError):
divide(1, 0)
运行:
pytest test_calculator.py -v
预期输出:
========================== test session starts ===========================
collected 5 items
test_calculator.py::TestCalculator::test_add PASSED [ 20%]
test_calculator.py::TestCalculator::test_subtract PASSED [ 40%]
test_calculator.py::TestCalculator::test_multiply PASSED [ 60%]
test_calculator.py::TestCalculator::test_divide PASSED [ 80%]
test_calculator.py::TestCalculator::test_divide_by_zero PASSED [100%]
========================== 5 passed in 0.03s ===========================
一句话解释:每个 test_ 函数测试一个场景,assert 后面跟「期望值」。
项目 2:测试「待办清单」管理类(15 分钟)
目标:从 JSON 文件读写数据,用 fixture 管理测试数据
首先,待办清单的数据存储在 todos.json:
[
{"id": 1, "title": "买牛奶", "done": false},
{"id": 2, "title": "写周报", "done": false},
{"id": 3, "title": "健身", "done": true}
]
待办清单管理类:
# todo_manager.py
"""待办清单管理"""
import json
from pathlib import Path
class TodoManager:
"""待办清单管理器"""
def __init__(self, file_path="todos.json"):
self.file_path = file_path
self.todos = []
self.load()
def load(self):
"""从文件加载待办"""
if Path(self.file_path).exists():
with open(self.file_path, "r", encoding="utf-8") as f:
self.todos = json.load(f)
def save(self):
"""保存待办到文件"""
with open(self.file_path, "w", encoding="utf-8") as f:
json.dump(self.todos, f, ensure_ascii=False, indent=2)
def add(self, title):
"""添加待办"""
new_id = max([t["id"] for t in self.todos], default=0) + 1
self.todos.append({"id": new_id, "title": title, "done": False})
self.save()
return new_id
def complete(self, todo_id):
"""标记完成"""
for todo in self.todos:
if todo["id"] == todo_id:
todo["done"] = True
self.save()
return True
return False
def get_pending(self):
"""获取未完成的待办"""
return [t for t in self.todos if not t["done"]]
def get_completed(self):
"""获取已完成的待办"""
return [t for t in self.todos if t["done"]]
测试代码(用 fixture 管理测试数据文件):
# test_todo_manager.py
"""测试待办清单管理器"""
import pytest
import json
import tempfile
from todo_manager import TodoManager
@pytest.fixture
def temp_todo_file():
"""创建一个临时待办文件"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False, encoding='utf-8') as f:
json.dump([
{"id": 1, "title": "买牛奶", "done": False},
{"id": 2, "title": "写周报", "done": False},
{"id": 3, "title": "健身", "done": True}
], f, ensure_ascii=False)
temp_path = f.name
yield temp_path # 返回临时文件路径给测试用
# 测试结束后清理
Path(temp_path).unlink(missing_ok=True)
def test_load_todos(temp_todo_file):
"""测试加载待办"""
manager = TodoManager(temp_todo_file)
assert len(manager.todos) == 3
def test_add_todo(temp_todo_file):
"""测试添加待办"""
manager = TodoManager(temp_todo_file)
new_id = manager.add("学 pytest")
assert new_id == 4
assert len(manager.todos) == 4
# 验证文件已更新
manager2 = TodoManager(temp_todo_file)
assert len(manager2.todos) == 4
def test_complete_todo(temp_todo_file):
"""测试标记完成"""
manager = TodoManager(temp_todo_file)
result = manager.complete(1)
assert result is True
assert manager.todos[0]["done"] is True
def test_get_pending(temp_todo_file):
"""测试获取未完成待办"""
manager = TodoManager(temp_todo_file)
pending = manager.get_pending()
assert len(pending) == 2
assert all(not t["done"] for t in pending)
def test_get_completed(temp_todo_file):
"""测试获取已完成待办"""
manager = TodoManager(temp_todo_file)
completed = manager.get_completed()
assert len(completed) == 1
assert completed[0]["title"] == "健身"
运行:
pytest test_todo_manager.py -v
预期输出:
========================== test session starts ===========================
collected 5 items
test_todo_manager.py::test_load_todos PASSED [ 20%]
test_todo_manager.py::test_add_todo PASSED [ 40%]
test_todo_manager.py::test_complete_todo PASSED [ 60%]
test_todo_manager.py::test_get_pending PASSED [ 80%]
test_todo_manager.py::test_get_completed PASSED [100%]
========================== 5 passed in 0.05s ===========================
一句话解释:tempfile 创建临时文件测试 IO 功能,yield 返回临时路径后自动清理。
项目 3:做一个「工资计算器」并测试(15 分钟)
目标:组合计算逻辑和数据读写,做一个有点真实用处的工具
需求:
1. 从 CSV 读取员工工资数据
2. 计算每个人扣除社保后的实发工资
3. 将结果写入新 CSV 文件
salary.csv(原始数据):
name,base_salary,bonus
张三,8000,1000
李四,12000,2000
王五,6000,500
salary_calculator.py:
"""工资计算器"""
import csv
from pathlib import Path
def calculate_net_salary(base_salary, bonus, social_rate=0.105):
"""
计算实发工资
参数:
base_salary: 基本工资
bonus: 奖金
social_rate: 社保缴费比例(默认 10.5%)
返回:
实发工资(基本工资 + 奖金 - 社保扣款)
"""
social_deduction = (base_salary + bonus) * social_rate
net_salary = base_salary + bonus - social_deduction
return round(net_salary, 2)
def process_salary_file(input_file, output_file, social_rate=0.105):
"""
处理工资表:读取、计算、写入
参数:
input_file: 输入的 CSV 文件路径
output_file: 输出的 CSV 文件路径
social_rate: 社保缴费比例
"""
results = []
with open(input_file, "r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
name = row["name"]
base_salary = float(row["base_salary"])
bonus = float(row["bonus"])
net_salary = calculate_net_salary(base_salary, bonus, social_rate)
social_deduction = round((base_salary + bonus) * social_rate, 2)
results.append({
"name": name,
"base_salary": base_salary,
"bonus": bonus,
"social_deduction": social_deduction,
"net_salary": net_salary
})
with open(output_file, "w", encoding="utf-8", newline="") as f:
fieldnames = ["name", "base_salary", "bonus", "social_deduction", "net_salary"]
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(results)
return results
test_salary_calculator.py:
"""测试工资计算器"""
import pytest
import csv
import tempfile
from salary_calculator import calculate_net_salary, process_salary_file
class TestCalculateNetSalary:
"""测试工资计算函数"""
def test_normal_case(self):
"""普通情况:基本工资 8000 + 奖金 1000,社保 10.5%"""
net = calculate_net_salary(8000, 1000)
expected = 8000 + 1000 - (9000 * 0.105) # 955.5 社保扣款
assert net == 8954.5
def test_zero_bonus(self):
"""零奖金情况"""
net = calculate_net_salary(10000, 0)
expected = 10000 - (10000 * 0.105)
assert net == 8950
def test_custom_social_rate(self):
"""自定义社保比例"""
net = calculate_net_salary(10000, 0, social_rate=0.08)
expected = 10000 - (10000 * 0.08)
assert net == 9200
@pytest.fixture
def sample_salary_csv():
"""创建测试用的工资 CSV"""
content = """name,base_salary,bonus
张三,8000,1000
李四,12000,2000
王五,6000,500"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as f:
f.write(content)
temp_path = f.name
yield temp_path
Path(temp_path).unlink(missing_ok=True)
def test_process_salary_file(sample_salary_csv):
"""测试完整的工资处理流程"""
output_file = sample_salary_csv.replace(".csv", "_result.csv")
try:
results = process_salary_file(sample_salary_csv, output_file)
# 验证返回结果
assert len(results) == 3
assert results[0]["name"] == "张三"
assert results[0]["net_salary"] == 8954.5
# 验证输出文件
with open(output_file, "r", encoding="utf-8") as f:
reader = csv.DictReader(f)
rows = list(reader)
assert len(rows) == 3
assert rows[0]["name"] == "张三"
assert rows[0]["net_salary"] == "8954.5"
finally:
Path(output_file).unlink(missing_ok=True)
运行:
pytest test_salary_calculator.py -v
预期输出:
========================== test session starts ===========================
collected 5 items
test_salary_calculator.py::TestCalculateNetSalary::test_normal_case PASSED [ 20%]
test_salary_calculator.py::TestCalculateNetSalary::test_zero_bonus PASSED [ 40%]
test_salary_calculator.py::TestCalculateNetSalary::test_custom_social_rate PASSED [ 60%]
test_salary_calculator.py::test_process_salary_file PASSED [100%]
========================== 4 passed in 0.04s ===========================
一句话解释:calculate_net_salary 测试纯计算逻辑,process_salary_file 测试文件 IO,两者独立测试互不干扰。
💪 进阶 20 分钟:常见坑 + 性能小贴士
坑 1:测试之间的「状态污染」
❌ 错误示例:全局变量被多个测试修改
# ❌ 错误代码
data = [] # 全局状态
def test_add():
data.append(1)
assert len(data) == 1
def test_add_again():
data.append(2)
assert len(data) == 1 # 失败!因为 data 里有 [1] 了
✅ 正确做法:用 fixture 或在每个测试前重置
# ✅ 正确代码
@pytest.fixture
def clean_data():
data = []
yield data
# 测试后清理
def test_add(clean_data):
clean_data.append(1)
assert len(clean_data) == 1
def test_add_again(clean_data):
clean_data.append(2)
assert len(clean_data) == 1
坑 2:浮点数比较的精度问题
❌ 错误示例:直接用 == 比较浮点数
# ❌ 错误代码
result = 0.1 + 0.2
assert result == 0.3 # 失败!因为 0.1 + 0.2 = 0.30000000000000004
✅ 正确做法:用 pytest.approx 或 math.isclose
# ✅ 正确代码
result = 0.1 + 0.2
assert result == pytest.approx(0.3) # 通过!
坑 3:测试「只跑一次」的初始化代码
❌ 错误示例:把「只该跑一次」的代码放在测试函数里
# ❌ 错误代码
def test_database_query():
conn = connect_to_database() # 每次测试都新建连接,拖慢速度
result = conn.query("SELECT * FROM users")
assert len(result) > 0
✅ 正确做法:用 scope="session" 的 fixture
# ✅ 正确代码
@pytest.fixture(scope="session")
def db_connection():
"""整个测试会话只创建一次连接"""
conn = connect_to_database()
yield conn
conn.close()
def test_database_query(db_connection):
result = db_connection.query("SELECT * FROM users")
assert len(result) > 0
坑 4:忘记 __init__.py 导致模块导入失败
❌ 错误示例:
# ❌ 命令
$ pytest tests/test_calculator.py
# 报错:ModuleNotFoundError: No module named 'calculator'
✅ 正确做法:确保项目结构正确,或用相对导入
my_project/
├── calculator.py
└── tests/
├── __init__.py
└── test_calculator.py
# ✅ 在 tests/__init__.py 中添加
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
坑 5:测试「太全面」反而脆弱
❌ 错误示例:测试内部实现细节
# ❌ 错误代码:测试「怎么做的」而不是「做什么的」
def test_sort_algorithm():
data = [3, 1, 2]
result = quick_sort(data) # 如果改成 merge_sort,这个测试就挂了
assert result._internal_tree._root.value == 1 # 测内部状态,噩梦
✅ 正确做法:只测「公开行为」
# ✅ 正确代码:只测「做什么的」
def test_sort_result():
data = [3, 1, 2]
result = sort(data) # 无论是快排还是归并,只要结果对就行
assert result == [1, 2, 3]
性能小贴士:测试顺序优化
pytest 默认按文件/函数名的字母顺序跑测试。把「快」测试放前面,能更快发现问题。
# 文件名按字母顺序:test_a_fast.py, test_b_medium.py, test_c_slow.py
# 函数名也按字母顺序
def test_config_loading(): # 非常快,毫秒级
...
def test_database_query(): # 中等,可能要 100ms
...
def test_full_integration(): # 最慢,可能要几秒
...
调试技巧:用 -s 和 -pdb
# -s:显示 print 输出(默认测试时 print 被捕获不显示)
pytest test_debug.py -s
# -pdb:测试失败时进入交互式调试器
pytest test_debug.py --pdb
# 或者只想看局部变量的值:
pytest test_debug.py -v --tb=short
✏️ 练习题
练习 1(2 分钟):抄改练习
题目:修改项目 1 的 test_multiply 测试,增加一组测试数据 5 * 5 = 25。
# 输入:在 test_multiply 中添加一行
# 预期输出:测试仍然全部通过
# 提示:直接在 assert 语句后加一行新的 assert
练习 2(2 分钟):加个判断
题目:在 calculator.py 中新增 power(a, b) 函数(a 的 b 次方),并写测试验证 2^10 = 1024。
# 输入:在 calculator.py 添加 power 函数
# 预期输出:pytest 显示 power 相关测试 PASSED
# 提示:可以用 ** 运算符,也可以自己写循环
练习 3(5 分钟):新数据处理
题目:将项目 2 的 TodoManager 扩展一个 delete(todo_id) 方法,删除指定待办,并写测试验证。
# 输入:在 TodoManager 类添加 delete 方法
# 预期输出:删除后 get_pending() 长度减少 1
# 提示:遍历 self.todos,用 pop() 或 remove() 删除
练习 4(5 分钟):串起来
题目:将项目 2 的待办清单和项目 3 的工资计算器组合:读取 employees.csv,计算每个人的实发工资,然后把「工资最低的员工」的名字写入待办清单。
# 输入:employees.csv 文件
# 预期输出:待办清单中增加一条「关注:XXX(工资最低)」
# 提示:先用项目 3 的逻辑找最低工资的人,再用项目 2 的 add 方法
练习 5(5 分钟):debug 挑战
题目:以下测试失败了,分析原因并修复:
# test_buggy.py
def test_string_reverse():
assert reverse_string("hello") == "olleh"
assert reverse_string("world") == "dlrow"
def reverse_string(s):
return s[::-1]
# 运行结果:
# FAILED test_buggy.py::test_string_reverse
# 预期:两个都 PASSED
# 实际:第二个才 PASSED
提示:reverse_string 的实现是对的,问题是「测试代码的结构」。
作业:做一个「带测试的个人记账本」
需求描述:开发一个命令行记账工具,支持:
- 记录收入/支出(日期、金额、分类、备注)
- 按月份统计收支
- 从 JSON 文件读写数据
功能点:
1. add_record(type, amount, category, note) - 添加记录
2. get_monthly_summary(year, month) - 获取月度摘要
3. save_to_file(path) / load_from_file(path) - 文件持久化
加分项:
1. 用 @pytest.mark.parametrize 测试多组收支场景
2. 用 fixture 管理测试用的临时文件
验收标准:
- 能跑起来(python account_book.py 无报错)
- 添加记录后,保存再加载,数据一致
- 月度摘要计算正确(收入合计 - 支出合计 = 结余)
📚 总结 + 资源
本文学到的 3 个核心点:
- 单元测试是「自动化质检员」:每个函数/类方法都应有对应测试,改代码后一键跑全部
- pytest 三剑客:
assert断言结果、@pytest.fixture管理测试环境、@pytest.mark.parametrize一测多组数据 - 测试要测「行为」不要测「实现」:公开 API 的输入输出正确性,而不是内部变量怎么变
延伸学习资源:
- pytest 官方文档 - 最权威的参考资料
- 《Python 测试驱动开发》- 经典书籍,测试驱动开发(TDD)入门必读
- Python Uncle Bob 的 Clean Code 系列 - 教你写出「可测试、可维护」代码
「下章预告」:测试写好了,代码质量有保证了——下一步怎么让全世界的人都能用上你写的工具?下一章我们学习「打包发布到 PyPI」,把你的项目变成一个 pip install 就能安装的包!
「互动钩子」:你在工作中遇到过「改代码改出 bug」的场景吗?后来是怎么解决的?评论区聊聊你的故事,老粉优先回复!

评论(0)