# mpsp/test/test_multiprocess.py
"""
多进程数据共享测试 - 验证跨进程数据同步能力
"""
import multiprocessing
import time
import pytest
from mpsp.mpsp import MultiProcessingSharedPool


# ==================== Helper functions (module level, so child processes can use them) ====================


def worker_put_data(key, value):
    """Child process: write data into the shared pool."""
    pool = MultiProcessingSharedPool()
    pool.put(key, value)
    return True


def worker_get_data(key, result_queue):
    """Child process: read data from the shared pool and put it on the result queue."""
    pool = MultiProcessingSharedPool()
    value = pool.get(key)
    result_queue.put(value)


def worker_check_exists(key, result_queue):
    """Child process: check whether a key exists."""
    pool = MultiProcessingSharedPool()
    result_queue.put(pool.exists(key))


def worker_modify_data(key, new_value, result_queue):
    """Child process: modify the data and report the old value."""
    pool = MultiProcessingSharedPool()
    old_value = pool.get(key)
    pool.put(key, new_value)
    result_queue.put(old_value)


def worker_wait_and_get(key, wait_time, result_queue):
    """Child process: wait for a while, then read the data."""
    time.sleep(wait_time)
    pool = MultiProcessingSharedPool()
    result_queue.put(pool.get(key))


def worker_increment_counter(key, iterations):
    """Child process: increment a counter repeatedly."""
    pool = MultiProcessingSharedPool()
    for _ in range(iterations):
        # Note: this read-modify-write is not atomic; it exists only to exercise
        # concurrent access (a locked variant is sketched below)
        current = pool.get(key, 0)
        pool.put(key, current + 1)
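

# A hypothetical atomic counterpart, for contrast with worker_increment_counter
# above. It is not used by the tests; it assumes a multiprocessing.Lock created
# by the parent is passed in, since MultiProcessingSharedPool is not shown to
# expose a locking API of its own. Holding the lock across the whole
# read-modify-write removes the lost-update race.
def worker_increment_counter_locked(key, iterations, lock):
    """Child process: increment a counter atomically under an external lock."""
    pool = MultiProcessingSharedPool()
    for _ in range(iterations):
        with lock:  # serialize the read and the write as one step
            pool.put(key, pool.get(key, 0) + 1)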


def worker_pop_data(key, result_queue):
    """Child process: pop data from the shared pool."""
    pool = MultiProcessingSharedPool()
    value = pool.pop(key)
    result_queue.put(value)


# ==================== Test classes ====================


class TestParentChildProcess:
    """Tests for data exchange between parent and child processes."""

    def test_parent_write_child_read(self):
        """Parent writes, child reads."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Parent process writes the data
        pool.put("shared_key", "shared_value")
        # Child process reads it back
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_data, args=("shared_key", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == "shared_value"

    def test_child_write_parent_read(self):
        """Child writes, parent reads."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Child process writes the data
        p = multiprocessing.Process(
            target=worker_put_data, args=("child_key", "child_value")
        )
        p.start()
        p.join()
        # Parent reads it back (a short wait gives the data time to sync)
        time.sleep(0.1)
        result = pool.get("child_key")
        assert result == "child_value"

    def test_parent_child_data_isolation(self):
        """Parent/child visibility - exercises the manager.dict synchronization mechanism."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Parent writes the initial value
        pool.put("isolation_test", "parent_value")
        # Child reads it, then overwrites it
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_modify_data,
            args=("isolation_test", "child_value", result_queue),
        )
        p.start()
        p.join()
        # The child should have seen the parent's value
        child_read = result_queue.get()
        assert child_read == "parent_value"
        # The parent should see the child's modification
        time.sleep(0.1)
        parent_read = pool.get("isolation_test")
        assert parent_read == "child_value"


class TestMultipleChildProcesses:
    """Tests for data sharing among multiple child processes."""

    def test_multiple_children_write_same_key(self):
        """Multiple children write the same key; later writes overwrite earlier ones."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        processes = []
        for i in range(5):
            p = multiprocessing.Process(
                target=worker_put_data, args=("shared_key", f"value_{i}")
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        # Short wait to make sure every write has landed
        time.sleep(0.1)
        # One of the values must have won
        result = pool.get("shared_key")
        assert result.startswith("value_")

    def test_multiple_children_write_different_keys(self):
        """Multiple children write distinct keys."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        num_processes = 5
        processes = []
        for i in range(num_processes):
            p = multiprocessing.Process(
                target=worker_put_data, args=(f"key_{i}", f"value_{i}")
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        # Short wait to make sure every write has landed
        time.sleep(0.1)
        # Every value should have been written
        for i in range(num_processes):
            assert pool.get(f"key_{i}") == f"value_{i}"

    def test_multiple_children_read_same_key(self):
        """Multiple children read the same key."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Parent writes the data
        pool.put("shared_key", "shared_value")
        # Several children read it
        num_processes = 5
        result_queue = multiprocessing.Queue()
        processes = []
        for _ in range(num_processes):
            p = multiprocessing.Process(
                target=worker_get_data, args=("shared_key", result_queue)
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        # Collect every result
        results = []
        for _ in range(num_processes):
            results.append(result_queue.get())
        # Every child should have read the same value
        assert all(r == "shared_value" for r in results)

    def test_concurrent_exists_check(self):
        """Concurrent existence checks."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Seed some data
        for i in range(5):
            pool.put(f"key_{i}", f"value_{i}")
        # Several children check existence concurrently
        result_queue = multiprocessing.Queue()
        processes = []
        for i in range(10):
            key = f"key_{i % 7}"  # some keys exist, some do not
            p = multiprocessing.Process(
                target=worker_check_exists, args=(key, result_queue)
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        # Collect the results
        results = []
        for _ in range(10):
            results.append(result_queue.get())
        # key_0..key_4 exist, and i % 7 re-checks key_0..key_2, so 8 of the
        # 10 checks hit an existing key; only key_5 and key_6 are missing
        assert sum(results) == 8


# Process-pool workers must be defined at module level: multiprocessing.Pool
# pickles the target callable to ship it to its workers, and pickle serializes
# functions by qualified name, so lambdas and nested functions would fail
def _pool_worker_map(args):
    """Worker for the process-pool map test."""
    idx, key = args
    shared_pool = MultiProcessingSharedPool()
    value = shared_pool.get(key)
    return idx, value


def _pool_worker_apply(key):
    """Worker for the process-pool apply_async test."""
    shared_pool = MultiProcessingSharedPool()
    return shared_pool.get(key)
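

# A quick illustration of the constraint above (standard-library pickle
# behavior, not part of the test suite):
#
#     import pickle
#     pickle.dumps(_pool_worker_apply)  # fine: resolvable as a module attribute
#     pickle.dumps(lambda k: k)         # fails: a lambda cannot be pickled by name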


class TestProcessPool:
    """Tests for usage inside a process pool."""

    def test_pool_map_with_shared_data(self):
        """Shared data accessed from a process-pool map."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Seed the inputs
        for i in range(5):
            pool.put(f"input_{i}", i * 10)
        # Run the workers through a pool
        with multiprocessing.Pool(processes=3) as process_pool:
            results = process_pool.map(
                _pool_worker_map, [(i, f"input_{i}") for i in range(5)]
            )
        # Check the results
        results_dict = dict(results)
        for i in range(5):
            assert results_dict[i] == i * 10

    def test_pool_apply_async_with_shared_data(self):
        """Shared data accessed from a process-pool apply_async."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("async_key", "async_value")
        with multiprocessing.Pool(processes=2) as process_pool:
            result = process_pool.apply_async(_pool_worker_apply, ("async_key",))
            assert result.get(timeout=5) == "async_value"


class TestDataVisibility:
    """Tests for data visibility and synchronization timing."""

    def test_immediate_visibility_after_put(self):
        """A write is visible immediately within the same process."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("immediate_key", "immediate_value")
        # Within the writing process the value should be visible at once
        assert pool.exists("immediate_key")
        assert pool.get("immediate_key") == "immediate_value"

    def test_cross_process_visibility_with_delay(self):
        """Cross-process visibility (with a delay)."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Parent writes
        pool.put("delayed_key", "delayed_value")
        # Child reads after a delay
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_wait_and_get, args=("delayed_key", 0.2, result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == "delayed_value"


class TestConcurrentModifications:
    """Tests for concurrent-modification scenarios."""

    def test_concurrent_counter_increments(self):
        """Concurrent counter increments (non-atomic, so races are expected)."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Initialize the counter
        pool.put("counter", 0)
        num_processes = 4
        iterations_per_process = 10
        processes = []
        for _ in range(num_processes):
            p = multiprocessing.Process(
                target=worker_increment_counter,
                args=("counter", iterations_per_process),
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        time.sleep(0.1)
        # Because of the race, the final value may fall short of the ideal
        # total; the point of this test is that concurrent access does not
        # crash (see the interleaving sketch after this test)
        final_count = pool.get("counter")
        assert isinstance(final_count, int)
        assert 0 <= final_count <= num_processes * iterations_per_process
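
    # Why the count can fall short - a lost update, with hypothetical worker
    # names P1/P2: P1 reads counter=5, P2 also reads 5, P1 writes 6, then P2
    # writes 6 again, so two increments net only +1. The
    # worker_increment_counter_locked sketch at the top of this module shows
    # the lock-based fix.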

    def test_concurrent_pop_operations(self):
        """Concurrent pop operations."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Seed several keys
        num_keys = 5
        for i in range(num_keys):
            pool.put(f"pop_key_{i}", f"pop_value_{i}")
        result_queue = multiprocessing.Queue()
        processes = []
        for i in range(num_keys):
            p = multiprocessing.Process(
                target=worker_pop_data, args=(f"pop_key_{i}", result_queue)
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        # Collect everything that was popped
        popped_values = []
        for _ in range(num_keys):
            popped_values.append(result_queue.get())
        # Every value should have been popped exactly once
        assert len(popped_values) == num_keys
        for i in range(num_keys):
            assert f"pop_value_{i}" in popped_values
        # Every key should now be gone
        time.sleep(0.1)
        for i in range(num_keys):
            assert not pool.exists(f"pop_key_{i}")


class TestComplexDataTypes:
    """Tests for sharing complex data types across processes."""

    def test_share_nested_dict(self):
        """Share a nested dict."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        nested_data = {
            "level1": {"level2": {"level3": [1, 2, 3]}},
            "list_of_dicts": [{"a": 1}, {"b": 2}],
        }
        pool.put("nested_key", nested_data)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_data, args=("nested_key", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == nested_data

    def test_share_large_list(self):
        """Share a large list."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        large_list = list(range(10000))
        pool.put("large_list_key", large_list)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_data, args=("large_list_key", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == large_list
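

# Convenience entry point (an assumed addition, not required by pytest): lets
# the suite run directly via `python mpsp/test/test_multiprocess.py` as well
# as through the usual `pytest` invocation.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])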