Initial commit.

commit cd335c1b3f
2026-02-21 12:00:47 +08:00
14 changed files with 3492 additions and 0 deletions

test/test_multiprocess.py (new file, 446 lines)

@@ -0,0 +1,446 @@
"""
多进程数据共享测试 - 验证跨进程数据同步能力
"""
import multiprocessing
import time
import pytest
from mpsp.mpsp import MultiProcessingSharedPool
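
# A minimal sketch of the MultiProcessingSharedPool API these tests assume,
# inferred purely from usage in this file (not from the library's docs):
#   pool.put(key, value)     - write a value into the shared pool
#   pool.get(key, default)   - read a value (the default argument is assumed optional)
#   pool.exists(key)         - check whether a key is present
#   pool.pop(key)            - remove and return a value
#   pool.clear()             - empty the pool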

# ==================== Helper functions (defined at module level so multiprocessing can pickle them) ====================
def worker_put_data(key, value):
    """Child process: write data into the shared pool."""
    pool = MultiProcessingSharedPool()
    pool.put(key, value)
    return True


def worker_get_data(key, result_queue):
    """Child process: read data from the shared pool and put it on the result queue."""
    pool = MultiProcessingSharedPool()
    value = pool.get(key)
    result_queue.put(value)


def worker_check_exists(key, result_queue):
    """Child process: check whether a key exists."""
    pool = MultiProcessingSharedPool()
    result_queue.put(pool.exists(key))


def worker_modify_data(key, new_value, result_queue):
    """Child process: overwrite the data and report the old value."""
    pool = MultiProcessingSharedPool()
    old_value = pool.get(key)
    pool.put(key, new_value)
    result_queue.put(old_value)


def worker_wait_and_get(key, wait_time, result_queue):
    """Child process: wait for a while, then read the data."""
    time.sleep(wait_time)
    pool = MultiProcessingSharedPool()
    result_queue.put(pool.get(key))


def worker_increment_counter(key, iterations):
    """Child process: increment a counter."""
    pool = MultiProcessingSharedPool()
    for _ in range(iterations):
        # Note: this read-modify-write is not atomic; it is only meant to
        # exercise concurrent access.
        current = pool.get(key, 0)
        pool.put(key, current + 1)


def worker_pop_data(key, result_queue):
    """Child process: pop the data."""
    pool = MultiProcessingSharedPool()
    value = pool.pop(key)
    result_queue.put(value)
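

# A hedged sketch of how the lost updates in worker_increment_counter could be
# avoided: serialize the read-modify-write with a multiprocessing.Lock created
# in the parent and passed to each child. The `lock` argument is a hypothetical
# addition for illustration; it is not used by the tests below.
def worker_increment_counter_locked(key, iterations, lock):
    """Child process: increment a counter under an externally supplied lock."""
    pool = MultiProcessingSharedPool()
    for _ in range(iterations):
        with lock:  # only one process may read-modify-write at a time
            current = pool.get(key, 0)
            pool.put(key, current + 1)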


# ==================== Test classes ====================
class TestParentChildProcess:
    """Tests for passing data between parent and child processes."""

    def test_parent_write_child_read(self):
        """Parent writes, child reads."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Parent process writes the data
        pool.put("shared_key", "shared_value")
        # Child process reads it back
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_data, args=("shared_key", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == "shared_value"

    def test_child_write_parent_read(self):
        """Child writes, parent reads."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Child process writes the data
        p = multiprocessing.Process(
            target=worker_put_data, args=("child_key", "child_value")
        )
        p.start()
        p.join()
        # Parent reads it back (a short wait helps ensure the data has synced)
        time.sleep(0.1)
        result = pool.get("child_key")
        assert result == "child_value"

    def test_parent_child_data_isolation(self):
        """Parent/child visibility of updates - exercises the manager.dict sync mechanism."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Parent writes the initial value
        pool.put("isolation_test", "parent_value")
        # Child reads it, then overwrites it
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_modify_data,
            args=("isolation_test", "child_value", result_queue),
        )
        p.start()
        p.join()
        # The child should have seen the parent's value
        child_read = result_queue.get()
        assert child_read == "parent_value"
        # The parent should see the child's modification
        time.sleep(0.1)
        parent_read = pool.get("isolation_test")
        assert parent_read == "child_value"


class TestMultipleChildProcesses:
    """Tests for data sharing among multiple child processes."""

    def test_multiple_children_write_same_key(self):
        """Multiple children write the same key; later writes overwrite earlier ones."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        processes = []
        for i in range(5):
            p = multiprocessing.Process(
                target=worker_put_data, args=("shared_key", f"value_{i}")
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        # Short wait to make sure all writes have completed
        time.sleep(0.1)
        # One of the values should have been written successfully
        result = pool.get("shared_key")
        assert result.startswith("value_")

    def test_multiple_children_write_different_keys(self):
        """Multiple children write different keys."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        num_processes = 5
        processes = []
        for i in range(num_processes):
            p = multiprocessing.Process(
                target=worker_put_data, args=(f"key_{i}", f"value_{i}")
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        # Short wait to make sure all writes have completed
        time.sleep(0.1)
        # Every value should have been written successfully
        for i in range(num_processes):
            assert pool.get(f"key_{i}") == f"value_{i}"

    def test_multiple_children_read_same_key(self):
        """Multiple children read the same key."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Parent writes the data
        pool.put("shared_key", "shared_value")
        # Several children read it
        num_processes = 5
        result_queue = multiprocessing.Queue()
        processes = []
        for _ in range(num_processes):
            p = multiprocessing.Process(
                target=worker_get_data, args=("shared_key", result_queue)
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        # Collect all results
        results = []
        for _ in range(num_processes):
            results.append(result_queue.get())
        # Every child should have read the same value
        assert all(r == "shared_value" for r in results)

    def test_concurrent_exists_check(self):
        """Concurrent exists() checks."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Write some data
        for i in range(5):
            pool.put(f"key_{i}", f"value_{i}")
        # Several children check concurrently
        result_queue = multiprocessing.Queue()
        processes = []
        for i in range(10):
            key = f"key_{i % 7}"  # some keys exist, some do not
            p = multiprocessing.Process(
                target=worker_check_exists, args=(key, result_queue)
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        # Collect the results
        results = []
        for _ in range(10):
            results.append(result_queue.get())
        # With i % 7 the checks hit key_0..key_6, then key_0..key_2 again;
        # key_0..key_4 exist while key_5 and key_6 do not, so 8 of the 10
        # checks should succeed.
        assert sum(results) >= 5  # loose lower bound: at least the 5 distinct existing keys


# Process-pool tests also need their worker functions defined at module level
def _pool_worker_map(args):
    """Worker for the process-pool map test."""
    idx, key = args
    shared_pool = MultiProcessingSharedPool()
    value = shared_pool.get(key)
    return idx, value


def _pool_worker_apply(key):
    """Worker for the process-pool apply_async test."""
    shared_pool = MultiProcessingSharedPool()
    return shared_pool.get(key)
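

# Why module level? multiprocessing serializes the target callable with pickle,
# and functions defined inside another function or test method cannot be
# pickled. A minimal sketch of the failure mode, for illustration only:
#
#     def _demo():
#         def local_worker():  # nested, hence unpicklable
#             return 1
#         with multiprocessing.Pool(1) as pp:
#             pp.apply(local_worker)  # raises "Can't pickle local object"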


class TestProcessPool:
    """Tests for use inside a process pool."""

    def test_pool_map_with_shared_data(self):
        """Shared data accessed from a process-pool map."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Write the test data
        for i in range(5):
            pool.put(f"input_{i}", i * 10)
        # Run the workers in a process pool
        with multiprocessing.Pool(processes=3) as process_pool:
            results = process_pool.map(
                _pool_worker_map, [(i, f"input_{i}") for i in range(5)]
            )
        # Verify the results
        results_dict = {idx: val for idx, val in results}
        for i in range(5):
            assert results_dict[i] == i * 10

    def test_pool_apply_async_with_shared_data(self):
        """Shared data accessed from a process-pool apply_async."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("async_key", "async_value")
        with multiprocessing.Pool(processes=2) as process_pool:
            result = process_pool.apply_async(_pool_worker_apply, ("async_key",))
            assert result.get(timeout=5) == "async_value"


class TestDataVisibility:
    """Tests for data visibility and synchronization timing."""

    def test_immediate_visibility_after_put(self):
        """A write is visible immediately within the same process."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("immediate_key", "immediate_value")
        # Within the same process the write should be visible at once
        assert pool.exists("immediate_key")
        assert pool.get("immediate_key") == "immediate_value"

    def test_cross_process_visibility_with_delay(self):
        """Cross-process visibility (with a delay)."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Parent writes
        pool.put("delayed_key", "delayed_value")
        # Child reads after a delay
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_wait_and_get, args=("delayed_key", 0.2, result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == "delayed_value"


class TestConcurrentModifications:
    """Tests for concurrent-modification scenarios."""

    def test_concurrent_counter_increments(self):
        """Concurrent counter increments (non-atomic; race conditions are expected)."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Initialize the counter
        pool.put("counter", 0)
        num_processes = 4
        iterations_per_process = 10
        processes = []
        for _ in range(num_processes):
            p = multiprocessing.Process(
                target=worker_increment_counter,
                args=("counter", iterations_per_process),
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        time.sleep(0.1)
        # Because of the race between get and put, the final value may be lower
        # than num_processes * iterations_per_process; the point of this test
        # is that concurrent access does not crash.
        final_count = pool.get("counter")
        assert isinstance(final_count, int)
        assert 0 <= final_count <= num_processes * iterations_per_process

    def test_concurrent_pop_operations(self):
        """Concurrent pop operations."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Initialize several keys
        num_keys = 5
        for i in range(num_keys):
            pool.put(f"pop_key_{i}", f"pop_value_{i}")
        result_queue = multiprocessing.Queue()
        processes = []
        for i in range(num_keys):
            p = multiprocessing.Process(
                target=worker_pop_data, args=(f"pop_key_{i}", result_queue)
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        # Collect everything that was popped
        popped_values = []
        for _ in range(num_keys):
            popped_values.append(result_queue.get())
        # Every value should have been popped exactly once
        assert len(popped_values) == num_keys
        for i in range(num_keys):
            assert f"pop_value_{i}" in popped_values
        # Every key should have been removed
        time.sleep(0.1)
        for i in range(num_keys):
            assert not pool.exists(f"pop_key_{i}")


class TestComplexDataTypes:
    """Tests for sharing complex data types across processes."""

    def test_share_nested_dict(self):
        """Share a nested dict."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        nested_data = {
            "level1": {"level2": {"level3": [1, 2, 3]}},
            "list_of_dicts": [{"a": 1}, {"b": 2}],
        }
        pool.put("nested_key", nested_data)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_data, args=("nested_key", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == nested_data

    def test_share_large_list(self):
        """Share a large list."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        large_list = list(range(10000))
        pool.put("large_list_key", large_list)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_data, args=("large_list_key", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == large_list
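

# A minimal sketch for running this file directly; an addition for convenience,
# assuming pytest is installed. It is not required when invoking pytest itself.
if __name__ == "__main__":
    multiprocessing.freeze_support()  # no-op except in frozen Windows executables
    pytest.main([__file__, "-v"])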