"""
|
||
多进程数据共享测试 - 验证跨进程数据同步能力
|
||
"""
|
||
|
||
import multiprocessing
|
||
import time
|
||
import pytest
|
||
from mpsp.mpsp import MultiProcessingSharedPool
|
||
|
||
|
||
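# NOTE (editor's assumption): these tests rely on MultiProcessingSharedPool being
# backed by a shared store -- the docstrings below mention manager.dict -- so that
# separate MultiProcessingSharedPool() instances created in different processes
# observe the same data.
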
# ==================== Helper functions (defined at module level so multiprocessing can pickle them) ====================


def worker_put_data(key, value):
    """Child process: write a value into the shared pool."""
    pool = MultiProcessingSharedPool()
    pool.put(key, value)
    return True


def worker_get_data(key, result_queue):
    """Child process: read a value from the shared pool and put it on the result queue."""
    pool = MultiProcessingSharedPool()
    value = pool.get(key)
    result_queue.put(value)


def worker_check_exists(key, result_queue):
    """Child process: report whether a key exists."""
    pool = MultiProcessingSharedPool()
    result_queue.put(pool.exists(key))


def worker_modify_data(key, new_value, result_queue):
    """Child process: overwrite a value and report the old one."""
    pool = MultiProcessingSharedPool()
    old_value = pool.get(key)
    pool.put(key, new_value)
    result_queue.put(old_value)


def worker_wait_and_get(key, wait_time, result_queue):
    """Child process: sleep for wait_time seconds, then read a value."""
    time.sleep(wait_time)
    pool = MultiProcessingSharedPool()
    result_queue.put(pool.get(key))


def worker_increment_counter(key, iterations):
    """Child process: repeatedly increment a counter."""
    pool = MultiProcessingSharedPool()
    for _ in range(iterations):
        # Note: this read-modify-write is not atomic; it is only meant to
        # exercise concurrent access (a lock-based alternative is sketched below).
        current = pool.get(key, 0)
        pool.put(key, current + 1)


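# --- Illustrative sketch (not part of the original suite): the read-modify-write
# above is racy. Assuming the parent creates a multiprocessing.Lock and passes it
# to each child, the increment could be serialized like this:
def worker_increment_counter_locked(key, iterations, lock):
    """Sketch: make the increment atomic by holding a caller-supplied lock."""
    pool = MultiProcessingSharedPool()
    for _ in range(iterations):
        with lock:  # only one process may read-modify-write at a time
            pool.put(key, pool.get(key, 0) + 1)

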
def worker_pop_data(key, result_queue):
    """Child process: pop a value and report it."""
    pool = MultiProcessingSharedPool()
    value = pool.pop(key)
    result_queue.put(value)


# ==================== Test classes ====================


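# --- Illustrative sketch (assumption, not used by the tests below): the
# "create pool, then clear it" prologue repeated in every test could be
# factored into a pytest fixture:
@pytest.fixture
def fresh_pool():
    """Yield a freshly cleared shared pool."""
    pool = MultiProcessingSharedPool()
    pool.clear()
    yield pool

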
class TestParentChildProcess:
    """Data exchange between parent and child processes."""

    def test_parent_write_child_read(self):
        """Parent writes; child reads."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Parent process writes the data
        pool.put("shared_key", "shared_value")

        # Child process reads it back
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_data, args=("shared_key", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == "shared_value"

    def test_child_write_parent_read(self):
        """Child writes; parent reads."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Child process writes the data
        p = multiprocessing.Process(
            target=worker_put_data, args=("child_key", "child_value")
        )
        p.start()
        p.join()

        # Parent reads it back (a short wait gives the write time to propagate)
        time.sleep(0.1)
        result = pool.get("child_key")
        assert result == "child_value"

    def test_parent_child_data_isolation(self):
        """Parent and child writes are mutually visible - exercises the
        manager.dict synchronization mechanism."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Parent writes the initial value
        pool.put("isolation_test", "parent_value")

        # Child reads it, then overwrites it
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_modify_data,
            args=("isolation_test", "child_value", result_queue),
        )
        p.start()
        p.join()

        # The child should have read the parent's value
        child_read = result_queue.get()
        assert child_read == "parent_value"

        # The parent should see the child's modification
        time.sleep(0.1)
        parent_read = pool.get("isolation_test")
        assert parent_read == "child_value"


class TestMultipleChildProcesses:
    """Data sharing among multiple child processes."""

    def test_multiple_children_write_same_key(self):
        """Several children write the same key; later writes overwrite earlier ones."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        processes = []
        for i in range(5):
            p = multiprocessing.Process(
                target=worker_put_data, args=("shared_key", f"value_{i}")
            )
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        # Brief wait so every write has completed
        time.sleep(0.1)

        # One of the writes should have landed
        result = pool.get("shared_key")
        assert result.startswith("value_")

    def test_multiple_children_write_different_keys(self):
        """Several children each write a distinct key."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        num_processes = 5
        processes = []

        for i in range(num_processes):
            p = multiprocessing.Process(
                target=worker_put_data, args=(f"key_{i}", f"value_{i}")
            )
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        # Brief wait so every write has completed
        time.sleep(0.1)

        # Every value should have been written
        for i in range(num_processes):
            assert pool.get(f"key_{i}") == f"value_{i}"

    def test_multiple_children_read_same_key(self):
        """Several children read the same key."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Parent writes the data
        pool.put("shared_key", "shared_value")

        # Several children read it concurrently
        num_processes = 5
        result_queue = multiprocessing.Queue()
        processes = []

        for _ in range(num_processes):
            p = multiprocessing.Process(
                target=worker_get_data, args=("shared_key", result_queue)
            )
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        # Collect every result
        results = []
        for _ in range(num_processes):
            results.append(result_queue.get())

        # Every child should have read the same value
        assert all(r == "shared_value" for r in results)

    def test_concurrent_exists_check(self):
        """Concurrent exists() checks."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Write some data
        for i in range(5):
            pool.put(f"key_{i}", f"value_{i}")

        # Several children check concurrently
        result_queue = multiprocessing.Queue()
        processes = []

        for i in range(10):
            key = f"key_{i % 7}"  # some of these keys exist, some do not
            p = multiprocessing.Process(
                target=worker_check_exists, args=(key, result_queue)
            )
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        # Collect the results
        results = []
        for _ in range(10):
            results.append(result_queue.get())

        # i % 7 hits key_0..key_6 once and key_0..key_2 a second time; only
        # key_0..key_4 exist, so 8 of the 10 checks should come back True
        assert sum(results) >= 5  # conservative lower bound


# Process-pool tests also need module-level worker functions (Pool pickles its targets)
def _pool_worker_map(args):
    """Worker for the process-pool map() test."""
    idx, key = args
    shared_pool = MultiProcessingSharedPool()
    value = shared_pool.get(key)
    return idx, value


def _pool_worker_apply(key):
    """Worker for the process-pool apply_async() test."""
    shared_pool = MultiProcessingSharedPool()
    return shared_pool.get(key)


class TestProcessPool:
    """Using the shared pool from inside a process pool."""

    def test_pool_map_with_shared_data(self):
        """Read shared data from a process-pool map() call."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Write the test data
        for i in range(5):
            pool.put(f"input_{i}", i * 10)

        # Run the workers through a process pool
        with multiprocessing.Pool(processes=3) as process_pool:
            results = process_pool.map(
                _pool_worker_map, [(i, f"input_{i}") for i in range(5)]
            )

        # Check the results
        results_dict = dict(results)
        for i in range(5):
            assert results_dict[i] == i * 10

    def test_pool_apply_async_with_shared_data(self):
        """Read shared data from a process-pool apply_async() call."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("async_key", "async_value")

        with multiprocessing.Pool(processes=2) as process_pool:
            result = process_pool.apply_async(_pool_worker_apply, ("async_key",))
            assert result.get(timeout=5) == "async_value"


class TestDataVisibility:
    """Data visibility and synchronization timing."""

    def test_immediate_visibility_after_put(self):
        """A value is visible immediately after put()."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("immediate_key", "immediate_value")
        # Within the same process the write should be visible right away
        assert pool.exists("immediate_key")
        assert pool.get("immediate_key") == "immediate_value"

    def test_cross_process_visibility_with_delay(self):
        """Cross-process visibility after a delay."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Parent writes
        pool.put("delayed_key", "delayed_value")

        # Child reads after a delay
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_wait_and_get, args=("delayed_key", 0.2, result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == "delayed_value"


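# --- Illustrative sketch (assumption, not used above): the fixed time.sleep()
# waits scattered through this file could instead poll until a key becomes visible:
def wait_for_key(pool, key, timeout=2.0, interval=0.01):
    """Poll pool.exists(key) until it returns True or the timeout elapses."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if pool.exists(key):
            return True
        time.sleep(interval)
    return False

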
class TestConcurrentModifications:
    """Concurrent modification scenarios."""

    def test_concurrent_counter_increments(self):
        """Concurrent counter increments (non-atomic; races are expected)."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Initialize the counter
        pool.put("counter", 0)

        num_processes = 4
        iterations_per_process = 10

        processes = []
        for _ in range(num_processes):
            p = multiprocessing.Process(
                target=worker_increment_counter,
                args=("counter", iterations_per_process),
            )
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        time.sleep(0.1)

        # Because of the race the final value may fall short of the ideal total;
        # this test mainly verifies that concurrent access does not crash
        final_count = pool.get("counter")
        assert isinstance(final_count, int)
        assert 0 <= final_count <= num_processes * iterations_per_process

    def test_concurrent_pop_operations(self):
        """Concurrent pop() operations."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Initialize several keys
        num_keys = 5
        for i in range(num_keys):
            pool.put(f"pop_key_{i}", f"pop_value_{i}")

        result_queue = multiprocessing.Queue()
        processes = []

        for i in range(num_keys):
            p = multiprocessing.Process(
                target=worker_pop_data, args=(f"pop_key_{i}", result_queue)
            )
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        # Collect everything that was popped
        popped_values = []
        for _ in range(num_keys):
            popped_values.append(result_queue.get())

        # Every value should have been popped
        assert len(popped_values) == num_keys
        for i in range(num_keys):
            assert f"pop_value_{i}" in popped_values

        # Every key should have been removed
        time.sleep(0.1)
        for i in range(num_keys):
            assert not pool.exists(f"pop_key_{i}")


class TestComplexDataTypes:
    """Sharing complex data types across processes."""

    def test_share_nested_dict(self):
        """Share a nested dictionary."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        nested_data = {
            "level1": {"level2": {"level3": [1, 2, 3]}},
            "list_of_dicts": [{"a": 1}, {"b": 2}],
        }

        pool.put("nested_key", nested_data)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_data, args=("nested_key", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == nested_data

    def test_share_large_list(self):
        """Share a large list."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        large_list = list(range(10000))
        pool.put("large_list_key", large_list)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_data, args=("large_list_key", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == large_list
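

# --- Optional usage sketch (assumption): lets this module be run directly
# (e.g. `python <this file>`) instead of through the pytest CLI.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])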