""" 多进程数据共享测试 - 验证跨进程数据同步能力 """ import multiprocessing import time import pytest from mpsp.mpsp import MultiProcessingSharedPool # ==================== 辅助函数(需要在模块级别定义以便多进程使用)==================== def worker_put_data(key, value): """子进程:往共享池写入数据""" pool = MultiProcessingSharedPool() pool.put(key, value) return True def worker_get_data(key, result_queue): """子进程:从共享池读取数据并放入结果队列""" pool = MultiProcessingSharedPool() value = pool.get(key) result_queue.put(value) def worker_check_exists(key, result_queue): """子进程:检查 key 是否存在""" pool = MultiProcessingSharedPool() result_queue.put(pool.exists(key)) def worker_modify_data(key, new_value, result_queue): """子进程:修改数据并返回旧值""" pool = MultiProcessingSharedPool() old_value = pool.get(key) pool.put(key, new_value) result_queue.put(old_value) def worker_wait_and_get(key, wait_time, result_queue): """子进程:等待一段时间后读取数据""" time.sleep(wait_time) pool = MultiProcessingSharedPool() result_queue.put(pool.get(key)) def worker_increment_counter(key, iterations): """子进程:对计数器进行递增""" pool = MultiProcessingSharedPool() for _ in range(iterations): # 注意:这不是原子操作,仅用于测试并发访问 current = pool.get(key, 0) pool.put(key, current + 1) def worker_pop_data(key, result_queue): """子进程:弹出数据""" pool = MultiProcessingSharedPool() value = pool.pop(key) result_queue.put(value) # ==================== 测试类 ==================== class TestParentChildProcess: """测试父子进程间数据传递""" def test_parent_write_child_read(self): """测试父进程写入,子进程读取""" pool = MultiProcessingSharedPool() pool.clear() # 父进程写入数据 pool.put("shared_key", "shared_value") # 子进程读取数据 result_queue = multiprocessing.Queue() p = multiprocessing.Process( target=worker_get_data, args=("shared_key", result_queue) ) p.start() p.join() result = result_queue.get() assert result == "shared_value" def test_child_write_parent_read(self): """测试子进程写入,父进程读取""" pool = MultiProcessingSharedPool() pool.clear() # 子进程写入数据 p = multiprocessing.Process( target=worker_put_data, args=("child_key", "child_value") ) p.start() p.join() # 父进程读取数据(需要短暂等待以确保数据同步) time.sleep(0.1) result = pool.get("child_key") assert result == "child_value" def test_parent_child_data_isolation(self): """测试父子进程数据隔离 - 验证 manager.dict 的同步机制""" pool = MultiProcessingSharedPool() pool.clear() # 父进程写入初始数据 pool.put("isolation_test", "parent_value") # 子进程读取并修改 result_queue = multiprocessing.Queue() p = multiprocessing.Process( target=worker_modify_data, args=("isolation_test", "child_value", result_queue), ) p.start() p.join() # 验证子进程读取到了父进程的值 child_read = result_queue.get() assert child_read == "parent_value" # 验证父进程可以看到子进程的修改 time.sleep(0.1) parent_read = pool.get("isolation_test") assert parent_read == "child_value" class TestMultipleChildProcesses: """测试多个子进程间的数据共享""" def test_multiple_children_write_same_key(self): """测试多个子进程写入同一个 key(后面的覆盖前面的)""" pool = MultiProcessingSharedPool() pool.clear() processes = [] for i in range(5): p = multiprocessing.Process( target=worker_put_data, args=("shared_key", f"value_{i}") ) processes.append(p) p.start() for p in processes: p.join() # 短暂等待确保所有写入完成 time.sleep(0.1) # 验证有一个值被成功写入 result = pool.get("shared_key") assert result.startswith("value_") def test_multiple_children_write_different_keys(self): """测试多个子进程写入不同的 key""" pool = MultiProcessingSharedPool() pool.clear() num_processes = 5 processes = [] for i in range(num_processes): p = multiprocessing.Process( target=worker_put_data, args=(f"key_{i}", f"value_{i}") ) processes.append(p) p.start() for p in processes: p.join() # 短暂等待确保所有写入完成 time.sleep(0.1) # 验证所有值都被成功写入 for i in range(num_processes): assert 
pool.get(f"key_{i}") == f"value_{i}" def test_multiple_children_read_same_key(self): """测试多个子进程读取同一个 key""" pool = MultiProcessingSharedPool() pool.clear() # 父进程写入数据 pool.put("shared_key", "shared_value") # 多个子进程读取 num_processes = 5 result_queue = multiprocessing.Queue() processes = [] for _ in range(num_processes): p = multiprocessing.Process( target=worker_get_data, args=("shared_key", result_queue) ) processes.append(p) p.start() for p in processes: p.join() # 收集所有结果 results = [] for _ in range(num_processes): results.append(result_queue.get()) # 所有子进程都应该读取到相同的值 assert all(r == "shared_value" for r in results) def test_concurrent_exists_check(self): """测试并发检查 key 是否存在""" pool = MultiProcessingSharedPool() pool.clear() # 写入一些数据 for i in range(5): pool.put(f"key_{i}", f"value_{i}") # 多个子进程并发检查 result_queue = multiprocessing.Queue() processes = [] for i in range(10): key = f"key_{i % 7}" # 有些 key 存在,有些不存在 p = multiprocessing.Process( target=worker_check_exists, args=(key, result_queue) ) processes.append(p) p.start() for p in processes: p.join() # 收集结果 results = [] for _ in range(10): results.append(result_queue.get()) # 前 5 个应该存在 (key_0 到 key_4),后 5 个不存在 (key_5, key_6 重复检查) assert sum(results) >= 5 # 至少 5 个存在 # 进程池测试需要在模块级别定义 worker 函数 def _pool_worker_map(args): """进程池 map 操作的 worker""" idx, key = args shared_pool = MultiProcessingSharedPool() value = shared_pool.get(key) return idx, value def _pool_worker_apply(key): """进程池 apply_async 的 worker""" shared_pool = MultiProcessingSharedPool() return shared_pool.get(key) class TestProcessPool: """测试在进程池中使用""" def test_pool_map_with_shared_data(self): """测试在进程池 map 操作中使用共享数据""" pool = MultiProcessingSharedPool() pool.clear() # 写入测试数据 for i in range(5): pool.put(f"input_{i}", i * 10) # 使用进程池 with multiprocessing.Pool(processes=3) as process_pool: results = process_pool.map( _pool_worker_map, [(i, f"input_{i}") for i in range(5)] ) # 验证结果 results_dict = {idx: val for idx, val in results} for i in range(5): assert results_dict[i] == i * 10 def test_pool_apply_async_with_shared_data(self): """测试在进程池 apply_async 中使用共享数据""" pool = MultiProcessingSharedPool() pool.clear() pool.put("async_key", "async_value") with multiprocessing.Pool(processes=2) as process_pool: result = process_pool.apply_async(_pool_worker_apply, ("async_key",)) assert result.get(timeout=5) == "async_value" class TestDataVisibility: """测试数据可见性和同步时机""" def test_immediate_visibility_after_put(self): """测试写入后立即可见""" pool = MultiProcessingSharedPool() pool.clear() pool.put("immediate_key", "immediate_value") # 同一进程内应该立即可见 assert pool.exists("immediate_key") assert pool.get("immediate_key") == "immediate_value" def test_cross_process_visibility_with_delay(self): """测试跨进程可见性(带延迟)""" pool = MultiProcessingSharedPool() pool.clear() # 父进程写入 pool.put("delayed_key", "delayed_value") # 子进程延迟后读取 result_queue = multiprocessing.Queue() p = multiprocessing.Process( target=worker_wait_and_get, args=("delayed_key", 0.2, result_queue) ) p.start() p.join() result = result_queue.get() assert result == "delayed_value" class TestConcurrentModifications: """测试并发修改场景""" def test_concurrent_counter_increments(self): """测试并发计数器递增(非原子操作,预期会有竞争条件)""" pool = MultiProcessingSharedPool() pool.clear() # 初始化计数器 pool.put("counter", 0) num_processes = 4 iterations_per_process = 10 processes = [] for _ in range(num_processes): p = multiprocessing.Process( target=worker_increment_counter, args=("counter", iterations_per_process), ) processes.append(p) p.start() for p in processes: p.join() time.sleep(0.1) # 


class TestConcurrentModifications:
    """Test concurrent-modification scenarios."""

    def test_concurrent_counter_increments(self):
        """Concurrent counter increments (non-atomic, so races are expected)."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Initialize the counter
        pool.put("counter", 0)

        num_processes = 4
        iterations_per_process = 10
        processes = []
        for _ in range(num_processes):
            p = multiprocessing.Process(
                target=worker_increment_counter,
                args=("counter", iterations_per_process),
            )
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        time.sleep(0.1)

        # Because of the race condition the final value may be less than the
        # ideal total; this test mainly verifies that concurrent access does
        # not crash.
        final_count = pool.get("counter")
        assert isinstance(final_count, int)
        assert 0 <= final_count <= num_processes * iterations_per_process

    def test_concurrent_pop_operations(self):
        """Concurrent pop operations."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Initialize several keys
        num_keys = 5
        for i in range(num_keys):
            pool.put(f"pop_key_{i}", f"pop_value_{i}")

        result_queue = multiprocessing.Queue()
        processes = []
        for i in range(num_keys):
            p = multiprocessing.Process(
                target=worker_pop_data, args=(f"pop_key_{i}", result_queue)
            )
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        # Collect everything that was popped
        popped_values = []
        for _ in range(num_keys):
            popped_values.append(result_queue.get())

        # Every value should have been popped
        assert len(popped_values) == num_keys
        for i in range(num_keys):
            assert f"pop_value_{i}" in popped_values

        # Every key should have been removed
        time.sleep(0.1)
        for i in range(num_keys):
            assert not pool.exists(f"pop_key_{i}")


class TestComplexDataTypes:
    """Test multiprocess sharing of complex data types."""

    def test_share_nested_dict(self):
        """Share a nested dict."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        nested_data = {
            "level1": {"level2": {"level3": [1, 2, 3]}},
            "list_of_dicts": [{"a": 1}, {"b": 2}],
        }
        pool.put("nested_key", nested_data)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_data, args=("nested_key", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == nested_data

    def test_share_large_list(self):
        """Share a large list."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        large_list = list(range(10000))
        pool.put("large_list_key", large_list)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_data, args=("large_list_key", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == large_list
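

# ---- Hypothetical sketch: a deterministic counter test using the locked worker ----
# If the locked-increment sketch near the top of this module holds (put/get
# individually consistent across processes, lock shared via Process args), the
# final count becomes deterministic. This class is illustrative and assumes
# those semantics; it is not part of the original suite.
class TestLockedCounterSketch:
    """Example: concurrent increments become lossless when guarded by a lock."""

    def test_locked_counter_increments(self):
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("locked_counter", 0)

        lock = multiprocessing.Lock()
        num_processes = 4
        iterations_per_process = 10
        processes = [
            multiprocessing.Process(
                target=worker_increment_counter_locked,
                args=("locked_counter", iterations_per_process, lock),
            )
            for _ in range(num_processes)
        ]
        for p in processes:
            p.start()
        for p in processes:
            p.join()

        time.sleep(0.1)
        # With the lock held across each read-modify-write, no update is lost
        assert pool.get("locked_counter") == num_processes * iterations_per_process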