Force explicit allowing initialization in child process.
Change: in current implementation, pools created in child processes COULD NOT share data with main process. Thus an exception will be raised by default when attemp to create pool in a child process, which can be suppressed with `allow_child_init` parameter.
This commit is contained in:
123
test/test_child_process_creation.py
Normal file
123
test/test_child_process_creation.py
Normal file
@ -0,0 +1,123 @@
|
||||
"""
|
||||
子进程创建 pool 的行为控制测试
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
import pytest
|
||||
from mpsp.mpsp import MultiProcessingSharedPool
|
||||
|
||||
|
||||
# 子进程创建 pool 相关的测试辅助函数
|
||||
def worker_create_pool_in_child(result_queue):
|
||||
"""子进程:尝试创建 pool(应该失败)"""
|
||||
try:
|
||||
# 重置单例以模拟子进程首次创建
|
||||
MultiProcessingSharedPool._instance = None
|
||||
pool = MultiProcessingSharedPool()
|
||||
pool.put("child_key", "child_value")
|
||||
result_queue.put(("success", "created"))
|
||||
except RuntimeError as e:
|
||||
result_queue.put(("error", str(e)))
|
||||
|
||||
|
||||
def worker_create_pool_with_allow(result_queue):
|
||||
"""子进程:使用 allow_child_init=True 创建 pool"""
|
||||
try:
|
||||
# 重置单例以模拟子进程首次创建
|
||||
MultiProcessingSharedPool._instance = None
|
||||
pool = MultiProcessingSharedPool(allow_child_init=True)
|
||||
pool.put("child_key", "child_value")
|
||||
result_queue.put(("success", "created_with_allow"))
|
||||
except RuntimeError as e:
|
||||
result_queue.put(("error", str(e)))
|
||||
|
||||
|
||||
def worker_get_data(key, result_queue):
|
||||
"""子进程:从共享池读取数据"""
|
||||
pool = MultiProcessingSharedPool()
|
||||
value = pool.get(key)
|
||||
result_queue.put(value)
|
||||
|
||||
|
||||
class TestChildProcessCreation:
|
||||
"""测试子进程中创建 pool 的行为控制"""
|
||||
|
||||
def test_child_process_creation_blocked_by_default(self):
|
||||
"""测试默认情况下子进程创建 pool 会被阻止"""
|
||||
pool = MultiProcessingSharedPool()
|
||||
pool.clear()
|
||||
|
||||
result_queue = multiprocessing.Queue()
|
||||
p = multiprocessing.Process(
|
||||
target=worker_create_pool_in_child, args=(result_queue,)
|
||||
)
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
status, message = result_queue.get()
|
||||
assert status == "error"
|
||||
assert (
|
||||
"Cannot initialize MultiProcessingSharedPool in a child process" in message
|
||||
)
|
||||
|
||||
def test_child_process_creation_allowed_with_flag(self):
|
||||
"""测试使用 allow_child_init=True 允许子进程创建 pool"""
|
||||
pool = MultiProcessingSharedPool()
|
||||
pool.clear()
|
||||
|
||||
result_queue = multiprocessing.Queue()
|
||||
p = multiprocessing.Process(
|
||||
target=worker_create_pool_with_allow, args=(result_queue,)
|
||||
)
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
status, message = result_queue.get()
|
||||
assert status == "success"
|
||||
assert message == "created_with_allow"
|
||||
|
||||
def test_allow_child_init_parameter(self):
|
||||
"""测试 allow_child_init 参数可以被正确设置"""
|
||||
# 重置单例
|
||||
MultiProcessingSharedPool._instance = None
|
||||
|
||||
# 默认情况下不允许
|
||||
pool1 = MultiProcessingSharedPool()
|
||||
assert MultiProcessingSharedPool._allow_child_init is False
|
||||
|
||||
# 重置单例
|
||||
MultiProcessingSharedPool._instance = None
|
||||
MultiProcessingSharedPool._allow_child_init = False
|
||||
|
||||
# 显式设置为允许
|
||||
pool2 = MultiProcessingSharedPool(allow_child_init=True)
|
||||
assert MultiProcessingSharedPool._allow_child_init is True
|
||||
|
||||
# 清理
|
||||
MultiProcessingSharedPool._instance = None
|
||||
MultiProcessingSharedPool._allow_child_init = False
|
||||
|
||||
def test_best_practice_main_process_init(self):
|
||||
"""测试最佳实践:主进程先初始化,子进程后使用"""
|
||||
# 重置单例
|
||||
MultiProcessingSharedPool._instance = None
|
||||
MultiProcessingSharedPool._allow_child_init = False
|
||||
|
||||
# 主进程先初始化
|
||||
main_pool = MultiProcessingSharedPool()
|
||||
main_pool._ensure_initialized()
|
||||
main_pool.put("shared_key", "shared_value")
|
||||
|
||||
# 子进程使用(不创建新的)
|
||||
result_queue = multiprocessing.Queue()
|
||||
p = multiprocessing.Process(
|
||||
target=worker_get_data, args=("shared_key", result_queue)
|
||||
)
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
result = result_queue.get()
|
||||
assert result == "shared_value"
|
||||
|
||||
# 清理
|
||||
MultiProcessingSharedPool._instance = None
|
||||
Reference in New Issue
Block a user