Require explicit opt-in to initialize the pool in a child process.
Change: in the current implementation, pools created in child processes CANNOT share data with the main process. An exception is therefore raised by default when attempting to create a pool in a child process; it can be suppressed with the `allow_child_init` parameter.
This commit is contained in:
@ -52,15 +52,26 @@ class MultiProcessingSharedPool:
|
||||
The singleton instance.
|
||||
_lock : threading.Lock
|
||||
Lock for thread-safe singleton instantiation.
|
||||
_allow_child_init : bool
|
||||
Whether to allow initialization in child processes.
|
||||
"""
|
||||
|
||||
_instance: Optional["MultiProcessingSharedPool"] = None
|
||||
_lock = threading.Lock()
|
||||
_allow_child_init: bool = False
|
||||
|
||||
def __new__(cls) -> "MultiProcessingSharedPool":
|
||||
def __new__(cls, allow_child_init: bool = False) -> "MultiProcessingSharedPool":
|
||||
"""
|
||||
Ensure singleton instance creation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
allow_child_init : bool, optional
|
||||
If True, allows the pool to be initialized in a child process.
|
||||
Default is False because data created in child processes won't be
|
||||
shared with the parent process (each process has its own Manager).
|
||||
Only set to True if you understand this limitation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
MultiProcessingSharedPool
|
||||
@ -71,13 +82,20 @@ class MultiProcessingSharedPool:
|
||||
if cls._instance is None:
|
||||
cls._instance = super(MultiProcessingSharedPool, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
cls._allow_child_init = allow_child_init
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, allow_child_init: bool = False):
|
||||
"""
|
||||
Initialize the MultiProcessingSharedPool instance.
|
||||
|
||||
Uses a flag to ensure initialization only happens once.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
allow_child_init : bool, optional
|
||||
If True, allows the pool to be initialized in a child process.
|
||||
Note: This parameter only takes effect on the first instantiation.
|
||||
"""
|
||||
if getattr(self, "_initialized", False):
|
||||
return
|
||||
@ -103,6 +121,16 @@ class MultiProcessingSharedPool:
|
||||
"""
|
||||
return cls()
|
||||
|
||||
def _is_child_process(self) -> bool:
    """
    Report whether the current process is a child process.

    Returns
    -------
    bool
        True when ``multiprocessing.parent_process()`` reports a parent.
        If that function is unavailable (it was added in Python 3.8),
        True when the current PID differs from ``self._owner_pid``.
    """
    try:
        parent = multiprocessing.parent_process()
    except AttributeError:
        # Older interpreter without parent_process(): fall back to comparing
        # the current PID with the owner PID stored on the instance.
        return os.getpid() != self._owner_pid
    return parent is not None
|
||||
|
||||
def _ensure_initialized(self):
|
||||
"""
|
||||
Lazy initialization of the multiprocessing Manager and shared dictionary.
|
||||
@ -111,10 +139,24 @@ class MultiProcessingSharedPool:
|
||||
------
|
||||
RuntimeError
|
||||
If the multiprocessing Manager fails to start.
|
||||
RuntimeError
|
||||
If attempting to initialize in a child process without allow_child_init=True.
|
||||
"""
|
||||
if self._shared_dict is None:
|
||||
with self._init_lock:
|
||||
if self._shared_dict is None:
|
||||
# Check if we are in a child process
|
||||
if self._is_child_process() and not self._allow_child_init:
|
||||
raise RuntimeError(
|
||||
"Cannot initialize MultiProcessingSharedPool in a child process. "
|
||||
"Data created in child processes won't be shared with the parent process "
|
||||
"because each process has its own Manager instance. "
|
||||
"To suppress this error and proceed anyway, create the pool with "
|
||||
"MultiProcessingSharedPool(allow_child_init=True). "
|
||||
"Best practice: Initialize the pool in the main process before "
|
||||
"starting any child processes."
|
||||
)
|
||||
|
||||
try:
|
||||
# Use the default context for manager
|
||||
self._manager = multiprocessing.Manager()
|
||||
|
||||
123
test/test_child_process_creation.py
Normal file
123
test/test_child_process_creation.py
Normal file
@ -0,0 +1,123 @@
|
||||
"""
|
||||
子进程创建 pool 的行为控制测试
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
import pytest
|
||||
from mpsp.mpsp import MultiProcessingSharedPool
|
||||
|
||||
|
||||
# 子进程创建 pool 相关的测试辅助函数
|
||||
def worker_create_pool_in_child(result_queue):
    """Child-process worker: try to create a pool with default settings.

    Puts ``("success", "created")`` on ``result_queue`` if creating the pool
    and storing a value works, or ``("error", <message>)`` if a
    ``RuntimeError`` is raised (the expected outcome).
    """
    try:
        # Drop any existing singleton so this behaves like a first-time
        # creation inside the child process.
        MultiProcessingSharedPool._instance = None
        child_pool = MultiProcessingSharedPool()
        child_pool.put("child_key", "child_value")
        outcome = ("success", "created")
    except RuntimeError as exc:
        outcome = ("error", str(exc))
    result_queue.put(outcome)
|
||||
|
||||
|
||||
def worker_create_pool_with_allow(result_queue):
    """Child-process worker: create a pool with ``allow_child_init=True``.

    Puts ``("success", "created_with_allow")`` on ``result_queue`` if creating
    the pool and storing a value works, or ``("error", <message>)`` if a
    ``RuntimeError`` is raised.
    """
    try:
        # Drop any existing singleton so this behaves like a first-time
        # creation inside the child process.
        MultiProcessingSharedPool._instance = None
        child_pool = MultiProcessingSharedPool(allow_child_init=True)
        child_pool.put("child_key", "child_value")
        outcome = ("success", "created_with_allow")
    except RuntimeError as exc:
        outcome = ("error", str(exc))
    result_queue.put(outcome)
|
||||
|
||||
|
||||
def worker_get_data(key, result_queue):
    """Child-process worker: read ``key`` from the shared pool.

    Obtains the pool via ``MultiProcessingSharedPool()`` and puts whatever
    ``pool.get(key)`` returns on ``result_queue``.
    """
    result_queue.put(MultiProcessingSharedPool().get(key))
|
||||
|
||||
|
||||
class TestChildProcessCreation:
    """Tests for controlling pool creation inside child processes."""

    # Upper bound (seconds) on waiting for a child process, so a crashed or
    # stuck child fails the test instead of hanging the whole run.
    _CHILD_TIMEOUT = 30

    @classmethod
    def _run_in_child(cls, target, *args):
        """
        Run ``target(*args, result_queue)`` in a child process.

        Parameters
        ----------
        target : callable
            Worker function; it receives ``args`` followed by a
            ``multiprocessing.Queue`` and must put exactly one item on it.
        *args
            Positional arguments passed to ``target`` before the queue.

        Returns
        -------
        object
            The single item the worker put on the queue.

        Raises
        ------
        queue.Empty
            If the worker puts nothing within ``_CHILD_TIMEOUT`` seconds
            (e.g. it crashed before reporting).
        """
        result_queue = multiprocessing.Queue()
        proc = multiprocessing.Process(target=target, args=(*args, result_queue))
        proc.start()
        try:
            # Drain the queue BEFORE joining: a process that has put items on
            # a queue may not exit until they are consumed, so join-then-get
            # can deadlock. The timeout turns a silent child failure into a
            # test failure rather than an indefinite hang.
            return result_queue.get(timeout=cls._CHILD_TIMEOUT)
        finally:
            proc.join(timeout=cls._CHILD_TIMEOUT)
            if proc.is_alive():
                proc.terminate()
                proc.join()

    def test_child_process_creation_blocked_by_default(self):
        """Creating a pool in a child process is rejected by default."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        status, message = self._run_in_child(worker_create_pool_in_child)

        assert status == "error"
        assert (
            "Cannot initialize MultiProcessingSharedPool in a child process" in message
        )

    def test_child_process_creation_allowed_with_flag(self):
        """``allow_child_init=True`` lets a child process create a pool."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        status, message = self._run_in_child(worker_create_pool_with_allow)

        assert status == "success"
        assert message == "created_with_allow"

    def test_allow_child_init_parameter(self):
        """The ``allow_child_init`` argument is recorded on the class."""
        # Start from a fresh singleton.
        MultiProcessingSharedPool._instance = None
        try:
            # Not allowed by default.
            MultiProcessingSharedPool()
            assert MultiProcessingSharedPool._allow_child_init is False

            # Reset the singleton again.
            MultiProcessingSharedPool._instance = None
            MultiProcessingSharedPool._allow_child_init = False

            # Explicitly allowed.
            MultiProcessingSharedPool(allow_child_init=True)
            assert MultiProcessingSharedPool._allow_child_init is True
        finally:
            # Always restore class state so a failed assertion above cannot
            # leak a modified singleton into other tests.
            MultiProcessingSharedPool._instance = None
            MultiProcessingSharedPool._allow_child_init = False

    def test_best_practice_main_process_init(self):
        """Best practice: initialize in the main process, then use in a child."""
        # Start from a fresh singleton.
        MultiProcessingSharedPool._instance = None
        MultiProcessingSharedPool._allow_child_init = False
        try:
            # The main process initializes the pool first.
            main_pool = MultiProcessingSharedPool()
            main_pool._ensure_initialized()
            main_pool.put("shared_key", "shared_value")

            # The child process uses the pool (it does not create a new one).
            result = self._run_in_child(worker_get_data, "shared_key")

            assert result == "shared_value"
        finally:
            # Clean up even if the child failed or the assertion did not hold.
            MultiProcessingSharedPool._instance = None
|
||||
Reference in New Issue
Block a user