Initial commit.

This commit is contained in:
2026-02-21 12:00:47 +08:00
commit cd335c1b3f
14 changed files with 3492 additions and 0 deletions

0
test/__init__.py Normal file
View File

46
test/conftest.py Normal file
View File

@ -0,0 +1,46 @@
"""
Pytest 配置文件
"""
import pytest
import multiprocessing
def pytest_configure(config):
    """Configure pytest for the whole test session.

    Selects the multiprocessing start method before any tests run:
    "fork" is fast on Linux and lets children inherit parent state (the
    cross-process tests in this suite rely on that). The original code
    forced "fork" unconditionally, but on platforms where it is
    unavailable (e.g. Windows) ``set_start_method`` raises ``ValueError``
    — not the ``RuntimeError`` that was caught — which would abort the
    entire session. Only force "fork" where the platform supports it.
    """
    if "fork" in multiprocessing.get_all_start_methods():
        try:
            multiprocessing.set_start_method("fork", force=True)
        except RuntimeError:
            # Already configured elsewhere — ignore.
            pass
@pytest.fixture(scope="function", autouse=True)
def reset_shared_pool():
    """Clear the shared pool before and after every test function.

    Autouse fixture: guarantees each test starts from an empty pool and
    leaves no data behind for the next one. The pool class is a
    singleton, so this clears the very instance the tests use.
    """
    from mpsp.mpsp import MultiProcessingSharedPool

    pool = MultiProcessingSharedPool()
    pool.clear()
    yield
    # Clean up after the test as well.
    pool.clear()
@pytest.fixture(scope="session")
def shared_pool():
    """Provide the shared-pool instance for the whole test session.

    Session-scoped: every test requesting this fixture receives the same
    instance (the pool class follows the singleton pattern).
    """
    from mpsp.mpsp import MultiProcessingSharedPool

    return MultiProcessingSharedPool()

324
test/test_basic.py Normal file
View File

@ -0,0 +1,324 @@
"""
基础功能测试 - 验证核心 API 正确性
"""
import pytest
from mpsp.mpsp import MultiProcessingSharedPool
class TestBasicOperations:
    """Basic single-process operations of the shared pool."""

    def test_singleton_pattern(self):
        """Repeated construction / get_instance() must yield one shared instance."""
        pool1 = MultiProcessingSharedPool()
        pool2 = MultiProcessingSharedPool.get_instance()
        pool3 = MultiProcessingSharedPool()
        assert pool1 is pool2
        assert pool2 is pool3
        assert pool1 is pool3

    def test_put_and_get_basic_types(self):
        """Round-trip scalar builtin types through put/get."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # int
        pool.put("int_key", 42)
        assert pool.get("int_key") == 42
        # float — compare with tolerance
        pool.put("float_key", 3.14159)
        assert abs(pool.get("float_key") - 3.14159) < 1e-10
        # str
        pool.put("str_key", "hello mpsp")
        assert pool.get("str_key") == "hello mpsp"
        # bool
        pool.put("bool_key", True)
        assert pool.get("bool_key") is True
        # None is a storable value (distinct from "key missing")
        pool.put("none_key", None)
        assert pool.get("none_key") is None

    def test_put_and_get_collections(self):
        """Round-trip container types through put/get."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # list
        test_list = [1, 2, 3, "a", "b", "c"]
        pool.put("list_key", test_list)
        assert pool.get("list_key") == test_list
        # dict (including nesting)
        test_dict = {"name": "test", "value": 100, "nested": {"a": 1}}
        pool.put("dict_key", test_dict)
        assert pool.get("dict_key") == test_dict
        # tuple
        test_tuple = (1, 2, 3)
        pool.put("tuple_key", test_tuple)
        assert pool.get("tuple_key") == test_tuple
        # set
        test_set = {1, 2, 3}
        pool.put("set_key", test_set)
        assert pool.get("set_key") == test_set

    def test_exists(self):
        """exists() reflects puts and removes."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        assert not pool.exists("nonexistent_key")
        pool.put("existing_key", "value")
        assert pool.exists("existing_key")
        pool.remove("existing_key")
        assert not pool.exists("existing_key")

    def test_remove(self):
        """remove() returns True for present keys, False for missing ones."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # remove an existing key
        pool.put("key_to_remove", "value")
        assert pool.remove("key_to_remove") is True
        assert not pool.exists("key_to_remove")
        # remove a missing key
        assert pool.remove("nonexistent_key") is False

    def test_pop(self):
        """pop() returns and removes the value; missing keys yield the default."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # pop an existing key
        pool.put("key_to_pop", "popped_value")
        value = pool.pop("key_to_pop")
        assert value == "popped_value"
        assert not pool.exists("key_to_pop")
        # pop a missing key with an explicit default
        default_value = pool.pop("nonexistent_key", "default")
        assert default_value == "default"
        # pop a missing key without a default → None
        assert pool.pop("nonexistent_key") is None

    def test_size(self):
        """size() tracks puts, removes and clear."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        assert pool.size() == 0
        pool.put("key1", "value1")
        assert pool.size() == 1
        pool.put("key2", "value2")
        pool.put("key3", "value3")
        assert pool.size() == 3
        pool.remove("key1")
        assert pool.size() == 2
        pool.clear()
        assert pool.size() == 0

    def test_keys(self):
        """keys() lists exactly the stored labels (order not asserted)."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        assert pool.keys() == []
        pool.put("key1", "value1")
        pool.put("key2", "value2")
        pool.put("key3", "value3")
        keys = pool.keys()
        assert len(keys) == 3
        assert set(keys) == {"key1", "key2", "key3"}

    def test_clear(self):
        """clear() empties the pool completely."""
        pool = MultiProcessingSharedPool()
        pool.put("key1", "value1")
        pool.put("key2", "value2")
        pool.clear()
        assert pool.size() == 0
        assert pool.keys() == []
        assert not pool.exists("key1")
class TestDictInterface:
    """Dict-style access: __getitem__/__setitem__/__delitem__/__contains__."""

    def test_getitem(self):
        """pool[key] returns the value; missing keys raise KeyError."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("test_key", "test_value")
        assert pool["test_key"] == "test_value"
        # accessing a missing key must raise KeyError
        with pytest.raises(KeyError):
            _ = pool["nonexistent_key"]

    def test_setitem(self):
        """pool[key] = value stores the value."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool["new_key"] = "new_value"
        assert pool.get("new_key") == "new_value"

    def test_delitem(self):
        """del pool[key] removes the entry; missing keys raise KeyError."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("key_to_delete", "value")
        del pool["key_to_delete"]
        assert not pool.exists("key_to_delete")
        # deleting a missing key must raise KeyError
        with pytest.raises(KeyError):
            del pool["nonexistent_key"]

    def test_contains(self):
        """The `in` operator (__contains__) reflects key presence."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("existing_key", "value")
        assert "existing_key" in pool
        assert "nonexistent_key" not in pool
class TestContextManager:
    """`with` statement support."""

    def test_context_manager(self):
        """Data written inside a `with` block survives after exit.

        The context-manager exit does not shut the manager down, so the
        stored value must still be visible afterwards.
        """
        pool = MultiProcessingSharedPool()
        pool.clear()
        with MultiProcessingSharedPool() as p:
            p.put("ctx_key", "ctx_value")
            assert p.get("ctx_key") == "ctx_value"
        # the data should still exist after the context exits
        assert pool.get("ctx_key") == "ctx_value"
class TestErrorHandling:
    """Label validation and defaulting behaviour."""

    def test_invalid_label_type_put(self):
        """put() rejects non-string labels with TypeError."""
        pool = MultiProcessingSharedPool()
        with pytest.raises(TypeError, match="Label must be a string"):
            pool.put(123, "value")
        with pytest.raises(TypeError, match="Label must be a string"):
            pool.put(None, "value")

    def test_invalid_label_type_get(self):
        """get() rejects non-string labels with TypeError."""
        pool = MultiProcessingSharedPool()
        with pytest.raises(TypeError, match="Label must be a string"):
            pool.get(123)

    def test_invalid_label_type_pop(self):
        """pop() rejects non-string labels with TypeError."""
        pool = MultiProcessingSharedPool()
        with pytest.raises(TypeError, match="Label must be a string"):
            pool.pop(123)

    def test_get_with_default(self):
        """get() returns the default for missing keys, the stored value otherwise."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # missing key → supplied default
        assert pool.get("nonexistent", "default") == "default"
        assert pool.get("nonexistent", None) is None
        assert pool.get("nonexistent") is None
        # present key → stored value, default ignored
        pool.put("existing", "real_value")
        assert pool.get("existing", "default") == "real_value"

    def test_special_label_names(self):
        """Unusual but legal string labels round-trip unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # the empty string is a valid label
        pool.put("", "empty_string_key")
        assert pool.get("") == "empty_string_key"
        # labels containing special characters
        special_keys = [
            "key with spaces",
            "key\twith\ttabs",
            "key\nwith\nnewlines",
            "key/with/slashes",
            "key.with.dots",
            "key:with:colons",
            "UPPERCASE_KEY",
            "mixedCase_Key",
            "unicode_中文_key",
            "emoji_😀_key",
        ]
        for key in special_keys:
            pool.put(key, f"value_for_{key}")
            assert pool.get(key) == f"value_for_{key}"
class TestOverwrite:
    """Overwriting existing keys."""

    def test_overwrite_value(self):
        """A second put() on the same key replaces the value, any type."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("key", "original_value")
        assert pool.get("key") == "original_value"
        pool.put("key", "new_value")
        assert pool.get("key") == "new_value"
        # overwrite with a value of a different type
        pool.put("key", 12345)
        assert pool.get("key") == 12345

    def test_overwrite_with_none(self):
        """Overwriting a value with None keeps the key present."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("key", "value")
        pool.put("key", None)
        assert pool.get("key") is None
        assert pool.exists("key")  # the key itself must remain

652
test/test_edge_cases.py Normal file
View File

@ -0,0 +1,652 @@
"""
边界与异常测试 - 验证鲁棒性
"""
import multiprocessing
import time
import pickle
import pytest
from mpsp.mpsp import MultiProcessingSharedPool
# ==================== 辅助函数 ====================
def worker_put_empty_key(result_queue):
    """Child process: store a value under the empty-string key and report back."""
    shared = MultiProcessingSharedPool()
    try:
        shared.put("", "empty_value")
        outcome = ("success", shared.get(""))
    except Exception as exc:
        outcome = ("error", str(exc))
    result_queue.put(outcome)
def worker_get_nonexistent(result_queue):
    """Child process: look up a key known not to exist and report the result."""
    shared = MultiProcessingSharedPool()
    result_queue.put(shared.get("definitely_nonexistent_key_12345"))
def worker_put_large_object(key, data, result_queue):
    """Child process: store a large object and report success or the error text."""
    shared = MultiProcessingSharedPool()
    try:
        stored = shared.put(key, data)
        outcome = ("success", stored)
    except Exception as exc:
        outcome = ("error", str(exc))
    result_queue.put(outcome)
# ==================== 测试类 ====================
class TestEmptyAndNoneValues:
    """Storing empty and None values."""

    def test_put_empty_string_value(self):
        """An empty string is a storable value."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("empty_value_key", "")
        assert pool.get("empty_value_key") == ""

    def test_put_none_value(self):
        """None is a storable value."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("none_value_key", None)
        assert pool.get("none_value_key") is None

    def test_put_empty_list(self):
        """An empty list round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("empty_list", [])
        assert pool.get("empty_list") == []

    def test_put_empty_dict(self):
        """An empty dict round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("empty_dict", {})
        assert pool.get("empty_dict") == {}

    def test_put_empty_tuple(self):
        """An empty tuple round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("empty_tuple", ())
        assert pool.get("empty_tuple") == ()

    def test_put_empty_set(self):
        """An empty set round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("empty_set", set())
        assert pool.get("empty_set") == set()
class TestSpecialLabelNames:
    """Unusual label strings."""

    def test_empty_string_label(self):
        """The empty string works as a label."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("", "empty_key_value")
        assert pool.get("") == "empty_key_value"
        assert pool.exists("")

    def test_empty_string_label_cross_process(self):
        """The empty-string label is visible and writable across processes."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("", "parent_empty_value")
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(target=worker_put_empty_key, args=(result_queue,))
        p.start()
        p.join()
        status, result = result_queue.get()
        # the child process must be able to overwrite the empty-string key
        assert status == "success"
        assert result == "empty_value"

    def test_unicode_label(self):
        """Unicode labels round-trip."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        unicode_keys = [
            "中文键",
            "日本語キー",
            "한국어키",
            "emoji_😀",
            "special_©_®_™",
            "math_∑_∏_√",
            "arrows_→_←_↑_↓",
        ]
        for key in unicode_keys:
            pool.put(key, f"value_for_{key}")
        for key in unicode_keys:
            assert pool.get(key) == f"value_for_{key}"

    def test_whitespace_label(self):
        """Whitespace-only and whitespace-containing labels round-trip."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        whitespace_keys = [
            " ",  # single space
            "  ",  # two spaces
            "\t",  # tab
            "\n",  # newline
            "\r\n",  # Windows newline
            " key_with_leading_space",
            "key_with_trailing_space ",
            " key_with_both_spaces ",
            "key\twith\ttabs",
        ]
        for key in whitespace_keys:
            pool.put(key, f"value_for_repr_{repr(key)}")
        for key in whitespace_keys:
            assert pool.get(key) == f"value_for_repr_{repr(key)}"

    def test_special_chars_label(self):
        """Punctuation-heavy labels round-trip."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        special_keys = [
            "key.with.dots",
            "key/with/slashes",
            "key:with:colons",
            "key|with|pipes",
            "key*with*asterisks",
            "key?with?question",
            "key<with>brackets",
            "key[with]square",
            "key{with}curly",
            "key+with+plus",
            "key=with=equals",
            "key!with!exclamation",
            "key@with@at",
            "key#with#hash",
            "key$with$dollar",
            "key%with%percent",
            "key^with^caret",
            "key&with&ampersand",
            "key'with'quotes",
            'key"with"double',
            "key`with`backtick",
            "key~with~tilde",
            "key-with-hyphens",
            "key_with_underscores",
        ]
        for key in special_keys:
            pool.put(key, f"value_for_{key}")
        for key in special_keys:
            assert pool.get(key) == f"value_for_{key}"

    def test_very_long_label(self):
        """A 1000-character label works."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        long_key = "a" * 1000
        pool.put(long_key, "long_key_value")
        assert pool.get(long_key) == "long_key_value"
class TestNonExistentKeys:
    """Behaviour for keys that do not exist."""

    def test_get_nonexistent(self):
        """get() of a missing key returns None."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        result = pool.get("nonexistent_key_12345")
        assert result is None

    def test_get_nonexistent_with_default(self):
        """get() of a missing key returns the supplied default."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        assert pool.get("nonexistent", "default") == "default"
        assert pool.get("nonexistent", 0) == 0
        assert pool.get("nonexistent", []) == []
        assert pool.get("nonexistent", {}) == {}

    def test_get_nonexistent_cross_process(self):
        """A missing key reads as None from a child process too."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(target=worker_get_nonexistent, args=(result_queue,))
        p.start()
        p.join()
        result = result_queue.get()
        assert result is None

    def test_remove_nonexistent(self):
        """remove() of a missing key returns False."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        assert pool.remove("nonexistent_key") is False

    def test_pop_nonexistent(self):
        """pop() of a missing key returns None / the default."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        assert pool.pop("nonexistent_key") is None
        assert pool.pop("nonexistent_key", "default") == "default"

    def test_exists_nonexistent(self):
        """exists() of a missing key returns False."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        assert pool.exists("nonexistent_key") is False
class TestLargeObjects:
    """Serialization of large objects."""

    def test_large_list(self):
        """A 100k-element list round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        large_list = list(range(100000))
        pool.put("large_list", large_list)
        retrieved = pool.get("large_list")
        assert len(retrieved) == 100000
        assert retrieved[0] == 0
        assert retrieved[99999] == 99999

    def test_large_dict(self):
        """A 10k-entry dict round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        large_dict = {f"key_{i}": f"value_{i}" for i in range(10000)}
        pool.put("large_dict", large_dict)
        retrieved = pool.get("large_dict")
        assert len(retrieved) == 10000
        assert retrieved["key_0"] == "value_0"
        assert retrieved["key_9999"] == "value_9999"

    def test_large_string(self):
        """A 1 MB string round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        large_string = "x" * 1000000  # 1 MB string
        pool.put("large_string", large_string)
        retrieved = pool.get("large_string")
        assert len(retrieved) == 1000000
        assert retrieved[0] == "x"
        assert retrieved[-1] == "x"

    def test_deeply_nested_structure(self):
        """A 50-level nested dict round-trips intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # build a deeply nested dict, innermost value "bottom"
        depth = 50
        nested = "bottom"
        for i in range(depth):
            nested = {"level": i, "nested": nested}
        pool.put("deep_nested", nested)
        retrieved = pool.get("deep_nested")
        # walk back down, verifying each level number
        current = retrieved
        for i in range(depth):
            assert current["level"] == depth - 1 - i
            current = current["nested"]
        assert current == "bottom"

    def test_large_object_cross_process(self):
        """A child process can store a large object."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        large_data = {"items": list(range(10000)), "name": "large_test"}
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_put_large_object,
            args=("large_cross", large_data, result_queue),
        )
        p.start()
        p.join()
        status, result = result_queue.get()
        assert status == "success"
        assert result is True
class TestCircularReferences:
    """Self-referencing containers."""

    def test_circular_list(self):
        """A list containing itself can be stored and read back."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # build a self-referencing list
        circular = [1, 2, 3]
        circular.append(circular)  # circular reference
        pool.put("circular_list", circular)
        retrieved = pool.get("circular_list")
        assert retrieved[0] == 1
        assert retrieved[1] == 2
        assert retrieved[2] == 3
        # the cycle should survive serialization in some form
        # NOTE(review): only non-None is asserted — whether retrieved[3]
        # is again the list itself depends on the pool's serializer.
        assert retrieved[3] is not None

    def test_circular_dict(self):
        """A dict containing itself can be stored and read back."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # build a self-referencing dict
        circular = {"a": 1, "b": 2}
        circular["self"] = circular  # circular reference
        pool.put("circular_dict", circular)
        retrieved = pool.get("circular_dict")
        assert retrieved["a"] == 1
        assert retrieved["b"] == 2
        # the cycle key must still be present
        assert "self" in retrieved
class TestBinaryData:
    """Binary payloads."""

    def test_bytes_data(self):
        """Arbitrary bytes round-trip."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        binary_data = b"\x00\x01\x02\x03\xff\xfe\xfd\xfc"
        pool.put("binary_data", binary_data)
        retrieved = pool.get("binary_data")
        assert retrieved == binary_data

    def test_large_binary_data(self):
        """A 256 KB bytes object round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        binary_data = bytes(range(256)) * 1000  # 256 KB
        pool.put("large_binary", binary_data)
        retrieved = pool.get("large_binary")
        assert retrieved == binary_data

    def test_bytearray_data(self):
        """A bytearray round-trips (equality, not necessarily identity of type)."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        ba = bytearray(b"\x00\x01\x02\x03")
        pool.put("bytearray_data", ba)
        retrieved = pool.get("bytearray_data")
        assert retrieved == ba
class TestMixedTypes:
    """Heterogeneous containers."""

    def test_heterogeneous_list(self):
        """A list mixing every basic type round-trips element-by-element."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        mixed_list = [
            1,  # int
            3.14,  # float
            "string",  # str
            True,  # bool
            None,  # NoneType
            [1, 2, 3],  # list
            {"a": 1},  # dict
            (1, 2),  # tuple
            {1, 2, 3},  # set
            b"binary",  # bytes
        ]
        pool.put("mixed_list", mixed_list)
        retrieved = pool.get("mixed_list")
        assert retrieved[0] == 1
        assert abs(retrieved[1] - 3.14) < 1e-10
        assert retrieved[2] == "string"
        assert retrieved[3] is True
        assert retrieved[4] is None
        assert retrieved[5] == [1, 2, 3]
        assert retrieved[6] == {"a": 1}
        assert retrieved[7] == (1, 2)
        assert retrieved[8] == {1, 2, 3}
        assert retrieved[9] == b"binary"

    def test_heterogeneous_dict(self):
        """A dict mixing value types round-trips key-by-key."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        mixed_dict = {
            "int_key": 42,
            "float_key": 3.14,
            "str_key": "hello",
            "bool_key": True,
            "none_key": None,
            "list_key": [1, 2, 3],
            "dict_key": {"nested": "value"},
            "tuple_key": (1, 2, 3),
        }
        pool.put("mixed_dict", mixed_dict)
        retrieved = pool.get("mixed_dict")
        for key, value in mixed_dict.items():
            if isinstance(value, float):
                # floats compared with tolerance
                assert abs(retrieved[key] - value) < 1e-10
            else:
                assert retrieved[key] == value
def worker_stress_test(key_prefix, iterations, result_queue):
    """Child process: repeated put/get/remove cycles, collecting any mismatch.

    Module-level on purpose: ``multiprocessing.Process`` targets must be
    importable by name. In the original this function was defined inside a
    first ``class TestConcurrentAccess`` that was then shadowed by a second
    class of the same name, so the bare ``worker_stress_test`` reference in
    ``test_stress_concurrent_writes`` would raise NameError at run time.
    """
    pool = MultiProcessingSharedPool()
    errors = []
    for i in range(iterations):
        try:
            key = f"{key_prefix}_{i}"
            pool.put(key, f"value_{i}")
            value = pool.get(key)
            if value != f"value_{i}":
                errors.append(f"Value mismatch at {key}")
            pool.remove(key)
        except Exception as e:
            errors.append(str(e))
    result_queue.put(errors)


class TestConcurrentAccess:
    """Concurrent-access stability."""

    def test_stress_concurrent_writes(self):
        """Several processes write/read/remove disjoint key ranges without error."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        num_processes = 4
        iterations = 100
        result_queue = multiprocessing.Queue()
        processes = []
        for i in range(num_processes):
            p = multiprocessing.Process(
                target=worker_stress_test,
                args=(f"stress_{i}", iterations, result_queue),
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        # collect the error lists from every worker
        all_errors = []
        for _ in range(num_processes):
            all_errors.extend(result_queue.get())
        # there must be none
        assert len(all_errors) == 0, f"Errors occurred: {all_errors}"

    def test_rapid_put_get_cycle(self):
        """1000 back-to-back put/get cycles on one key stay consistent."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        for i in range(1000):
            pool.put("rapid_key", f"value_{i}")
            value = pool.get("rapid_key")
            assert value == f"value_{i}"

    def test_rapid_key_creation_deletion(self):
        """Rapidly creating and deleting keys leaves no residue."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        for i in range(100):
            key = f"temp_key_{i}"
            pool.put(key, f"temp_value_{i}")
            assert pool.exists(key)
            pool.remove(key)
            assert not pool.exists(key)
class TestErrorRecovery:
    """The pool stays usable after errors."""

    def test_put_after_error(self):
        """A rejected put (bad label type) does not break later puts."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # trigger the label-type validation error
        try:
            pool.put(123, "value")  # should raise TypeError
        except TypeError:
            pass
        # normal use must still work afterwards
        pool.put("valid_key", "valid_value")
        assert pool.get("valid_key") == "valid_value"

    def test_get_after_nonexistent(self):
        """Reading a missing key does not break later operations."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # read a missing key
        result = pool.get("nonexistent")
        assert result is None
        # normal use must still work afterwards
        pool.put("new_key", "new_value")
        assert pool.get("new_key") == "new_value"

    def test_multiple_singleton_access(self):
        """Writes through any singleton handle are visible through all of them."""
        pool1 = MultiProcessingSharedPool()
        pool1.put("key1", "value1")
        pool2 = MultiProcessingSharedPool()
        pool2.put("key2", "value2")
        pool3 = MultiProcessingSharedPool.get_instance()
        pool3.put("key3", "value3")
        # every handle must see all three entries
        assert pool1.get("key1") == "value1"
        assert pool1.get("key2") == "value2"
        assert pool1.get("key3") == "value3"
class TestCleanup:
    """Bulk cleanup behaviour."""

    def test_clear_after_multiple_puts(self):
        """clear() removes all of many entries at once."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        for i in range(100):
            pool.put(f"key_{i}", f"value_{i}")
        assert pool.size() == 100
        pool.clear()
        assert pool.size() == 0
        assert pool.keys() == []

    def test_remove_all_keys_one_by_one(self):
        """Removing every key individually empties the pool."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        keys = [f"key_{i}" for i in range(50)]
        for key in keys:
            pool.put(key, f"value_for_{key}")
        for key in keys:
            assert pool.remove(key) is True
        assert pool.size() == 0

523
test/test_functions.py Normal file
View File

@ -0,0 +1,523 @@
"""
函数序列化测试 - 验证 cloudpickle 集成
"""
import multiprocessing
import time
import math
import pytest
from mpsp.mpsp import MultiProcessingSharedPool
# ==================== 模块级别的普通函数 ====================
def simple_function(x):
    """Return *x* incremented by one."""
    return 1 + x
def multiply_function(a, b):
    """Return the product of *a* and *b*."""
    product = a * b
    return product
def function_with_default(x, y=10):
    """Return *x* plus *y*, where *y* defaults to 10."""
    total = x + y
    return total
def function_with_kwargs(*args, **kwargs):
    """Return the sum of all positional and keyword argument values."""
    total = sum(args)
    for value in kwargs.values():
        total += value
    return total
def recursive_function(n):
    """Compute n! recursively; values <= 1 yield 1.

    Deliberately kept recursive — the suite uses it to verify that
    self-referential functions serialize correctly.
    """
    return 1 if n <= 1 else n * recursive_function(n - 1)
def closure_factory(base):
    """Build and return a closure that offsets its argument by the captured *base*."""

    def inner(x):
        # `base` is captured from the enclosing scope
        return x + base

    return inner
# ==================== 辅助函数(模块级别定义)====================
def worker_execute_function(key, result_queue):
    """Child process: fetch the callable stored under *key* and execute it.

    The argument set is chosen per test key, so each stored function is
    invoked with the inputs its test expects. The return value — or an
    "ERROR: ..." string on exception, or None if the key is missing — is
    pushed onto *result_queue*.
    """
    pool = MultiProcessingSharedPool()
    func = pool.get(key)
    if func is None:
        result_queue.put(None)
        return
    # dispatch: pick arguments matching the function each test stored
    try:
        if key == "simple_func":
            result = func(5)
        elif key == "multiply_func":
            result = func(3, 4)
        elif key == "default_func":
            result = func(5)
        elif key == "kwargs_func":
            result = func(1, 2, 3, a=4, b=5)
        elif key == "recursive_func":
            result = func(5)
        elif key == "closure_func":
            result = func(10)
        elif key == "lambda_func":
            result = func(7)
        elif key == "lambda_with_capture":
            result = func()
        elif key == "nested_func":
            result = func(3)
        else:
            result = func()
        result_queue.put(result)
    except Exception as e:
        result_queue.put(f"ERROR: {e}")
def worker_execute_lambda_with_arg(key, arg, result_queue):
    """Child process: fetch a stored callable and invoke it with *arg*."""
    shared = MultiProcessingSharedPool()
    func = shared.get(key)
    if func is None:
        result_queue.put(None)
    else:
        result_queue.put(func(arg))
def get_lambda_description(func):
    """Return a human-readable identifier for *func*.

    Prefers ``__name__``; objects without one fall back to ``str(func)``.
    """
    return getattr(func, "__name__", str(func))
# ==================== 测试类 ====================
class TestNormalFunctions:
    """Serialization round-trips for ordinary module-level functions."""

    def test_simple_function(self):
        """A plain function can be stored and called in the same process."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # store the function
        result = pool.put("simple_func", simple_function)
        assert result is True
        # verify in the current process
        retrieved = pool.get("simple_func")
        assert retrieved is not None
        assert retrieved(5) == 6
        assert retrieved(10) == 11

    def test_simple_function_cross_process(self):
        """A plain function stored by the parent runs in a child process."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # parent stores the function
        pool.put("simple_func", simple_function)
        # child fetches and executes it
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("simple_func", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == 6  # simple_function(5) = 5 + 1

    def test_function_with_multiple_args(self):
        """A multi-argument function round-trips across processes."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("multiply_func", multiply_function)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("multiply_func", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == 12  # multiply_function(3, 4) = 12

    def test_function_with_default_args(self):
        """Default argument values survive serialization."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("default_func", function_with_default)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("default_func", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == 15  # function_with_default(5) = 5 + 10

    def test_function_with_kwargs(self):
        """*args/**kwargs handling survives serialization."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("kwargs_func", function_with_kwargs)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("kwargs_func", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == 15  # sum(1,2,3) + sum(4,5) = 6 + 9 = 15

    def test_recursive_function(self):
        """A self-referential (recursive) function survives serialization."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("recursive_func", recursive_function)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("recursive_func", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == 120  # 5! = 120
class TestLambdaFunctions:
    """Serialization round-trips for lambda functions (cloudpickle territory)."""

    def test_simple_lambda(self):
        """A lambda can be stored and called in the same process."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        simple_lambda = lambda x: x * 2
        result = pool.put("lambda_func", simple_lambda)
        assert result is True
        # verify in the current process
        retrieved = pool.get("lambda_func")
        assert retrieved(5) == 10
        assert retrieved(7) == 14

    def test_simple_lambda_cross_process(self):
        """A lambda stored by the parent runs in a child process."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        simple_lambda = lambda x: x * 3
        pool.put("lambda_func", simple_lambda)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("lambda_func", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == 21  # lambda(7) = 7 * 3 = 21

    def test_lambda_with_capture(self):
        """A lambda capturing an enclosing variable keeps the captured value."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        captured_value = 100
        capturing_lambda = lambda: captured_value + 1
        pool.put("lambda_with_capture", capturing_lambda)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("lambda_with_capture", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == 101  # captured_value + 1 = 101

    def test_lambda_in_list_comprehension(self):
        """Each lambda built in a comprehension keeps its own bound default."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # bind i as a default arg so each lambda captures its own value
        lambdas = [(lambda x, i=i: x + i) for i in range(5)]
        for i, lam in enumerate(lambdas):
            pool.put(f"lambda_{i}", lam)
        # each stored lambda must have captured its own i
        for i in range(5):
            retrieved = pool.get(f"lambda_{i}")
            assert retrieved(10) == 10 + i

    def test_complex_lambda(self):
        """A multi-argument lambda expression round-trips across processes."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        complex_lambda = lambda x, y: (x**2 + y**2) ** 0.5
        pool.put("complex_lambda", complex_lambda)

        # NOTE(review): a locally defined worker only works as a Process
        # target under the "fork" start method (set in conftest) — it is
        # not importable by name for "spawn".
        def worker_execute_complex_lambda(key, x, y, result_queue):
            pool = MultiProcessingSharedPool()
            func = pool.get(key)
            result_queue.put(func(x, y))

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_complex_lambda,
            args=("complex_lambda", 3, 4, result_queue),
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert abs(result - 5.0) < 1e-10  # sqrt(3^2 + 4^2) = 5
class TestNestedFunctions:
    """Functions defined inside other functions (closures and nesting)."""

    def test_nested_function(self):
        """A function containing an inner helper function round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        def outer_function(x):
            def inner_function(y):
                return y * y

            return inner_function(x) + x

        pool.put("nested_func", outer_function)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("nested_func", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == 12  # outer_function(3) = 3*3 + 3 = 12

    def test_closure_function(self):
        """A closure keeps its captured value across processes."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        closure_func = closure_factory(100)
        pool.put("closure_func", closure_func)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("closure_func", result_queue)
        )
        p.start()
        p.join()
        result = result_queue.get()
        assert result == 110  # closure_func(10) = 10 + 100

    def test_multiple_closures(self):
        """Distinct closures keep distinct captured values."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        closures = [closure_factory(i) for i in range(5)]
        for i, closure in enumerate(closures):
            pool.put(f"closure_{i}", closure)
        # each closure must have captured a different base
        for i in range(5):
            retrieved = pool.get(f"closure_{i}")
            assert retrieved(10) == 10 + i
class TestClassMethods:
    """Serialization of methods attached to classes."""

    def test_static_method(self):
        """Static methods of a locally defined class can be stored and called."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        class Calculator:
            @staticmethod
            def add(x, y):
                return x + y

            @staticmethod
            def multiply(x, y):
                return x * y

        pool.put("static_add", Calculator.add)
        pool.put("static_multiply", Calculator.multiply)
        # verify both static methods work after retrieval
        add_func = pool.get("static_add")
        multiply_func = pool.get("static_multiply")
        assert add_func(2, 3) == 5
        assert multiply_func(2, 3) == 6

    def test_class_method(self):
        """Best-effort check for classmethods (may legitimately fail)."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        class Counter:
            count = 0

            @classmethod
            def increment(cls):
                cls.count += 1
                return cls.count

        # NOTE: classmethods usually cannot be serialized correctly by
        # cloudpickle because they depend on the class definition.
        result = pool.put("class_method", Counter.increment)
        # if it stored successfully, try retrieving it
        if result:
            try:
                method = pool.get("class_method")
                # a deserialized classmethod may not be callable —
                # depends on the cloudpickle implementation
            except Exception:
                pass  # failure is expected to be possible here
class TestBuiltInFunctions:
    """Serialization of built-in and stdlib C functions."""

    def test_builtin_functions(self):
        """Python builtins round-trip (standard pickle handles them by name)."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # most builtins are picklable by reference
        pool.put("builtin_sum", sum)
        pool.put("builtin_max", max)
        pool.put("builtin_min", min)
        pool.put("builtin_len", len)
        # verify each one still works
        assert pool.get("builtin_sum")([1, 2, 3]) == 6
        assert pool.get("builtin_max")([1, 2, 3]) == 3
        assert pool.get("builtin_min")([1, 2, 3]) == 1
        assert pool.get("builtin_len")([1, 2, 3]) == 3

    def test_math_functions(self):
        """math-module C functions round-trip."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("math_sqrt", math.sqrt)
        pool.put("math_sin", math.sin)
        pool.put("math_cos", math.cos)
        # verify with tolerance
        assert abs(pool.get("math_sqrt")(16) - 4.0) < 1e-10
        assert abs(pool.get("math_sin")(0) - 0.0) < 1e-10
        assert abs(pool.get("math_cos")(0) - 1.0) < 1e-10
class TestFunctionReturnValues:
    """Higher-order functions: functions that return functions."""

    def test_function_returned_from_function(self):
        """A stored factory can be fetched in a child and used to build a closure."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        def create_multiplier(factor):
            return lambda x: x * factor

        pool.put("create_multiplier", create_multiplier)

        # NOTE(review): local worker relies on the "fork" start method
        # (set in conftest); it is not importable by name for "spawn".
        def worker_get_multiplier(result_queue):
            pool = MultiProcessingSharedPool()
            factory = pool.get("create_multiplier")
            multiplier_func = factory(5)  # build a multiply-by-5 closure
            result_queue.put(multiplier_func(10))

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(target=worker_get_multiplier, args=(result_queue,))
        p.start()
        p.join()
        result = result_queue.get()
        assert result == 50  # 10 * 5 = 50
class TestErrorHandling:
    """Error handling around function serialization."""

    def test_unpicklable_function_fallback(self):
        """A locally defined function (rejected by standard pickle) stores via cloudpickle."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # local functions cannot be pickled by reference by standard pickle
        def local_function(x):
            return x**2

        # the cloudpickle fallback should store it anyway
        result = pool.put("local_func", local_function)
        assert result is True
        # and it should still execute correctly
        retrieved = pool.get("local_func")
        assert retrieved(5) == 25

    def test_function_with_unpicklable_capture(self):
        """A lambda capturing an unserializable object (an open file handle)."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # capture a file object, which cannot be serialized
        try:
            with open(__file__, "r") as f:
                file_capturing_lambda = lambda: f.read()
                # attempting to store it is expected to fail
                result = pool.put("file_lambda", file_capturing_lambda)
                # NOTE(review): nothing is asserted here — the test only
                # checks the failure mode raises cleanly, if it raises.
        except Exception:
            pass  # failure is acceptable / expected here

446
test/test_multiprocess.py Normal file
View File

@ -0,0 +1,446 @@
"""
多进程数据共享测试 - 验证跨进程数据同步能力
"""
import multiprocessing
import time
import pytest
from mpsp.mpsp import MultiProcessingSharedPool
# ==================== 辅助函数(需要在模块级别定义以便多进程使用)====================
def worker_put_data(key, value):
    """Child process: write one entry into the shared pool."""
    MultiProcessingSharedPool().put(key, value)
    return True
def worker_get_data(key, result_queue):
    """Child process: read one entry from the shared pool onto the result queue."""
    shared = MultiProcessingSharedPool()
    result_queue.put(shared.get(key))
def worker_check_exists(key, result_queue):
"""子进程:检查 key 是否存在"""
pool = MultiProcessingSharedPool()
result_queue.put(pool.exists(key))
def worker_modify_data(key, new_value, result_queue):
"""子进程:修改数据并返回旧值"""
pool = MultiProcessingSharedPool()
old_value = pool.get(key)
pool.put(key, new_value)
result_queue.put(old_value)
def worker_wait_and_get(key, wait_time, result_queue):
"""子进程:等待一段时间后读取数据"""
time.sleep(wait_time)
pool = MultiProcessingSharedPool()
result_queue.put(pool.get(key))
def worker_increment_counter(key, iterations):
"""子进程:对计数器进行递增"""
pool = MultiProcessingSharedPool()
for _ in range(iterations):
# 注意:这不是原子操作,仅用于测试并发访问
current = pool.get(key, 0)
pool.put(key, current + 1)
def worker_pop_data(key, result_queue):
"""子进程:弹出数据"""
pool = MultiProcessingSharedPool()
value = pool.pop(key)
result_queue.put(value)
# ==================== 测试类 ====================
class TestParentChildProcess:
    """Data exchange between a parent process and one child process."""

    def test_parent_write_child_read(self):
        """Parent writes; child reads the same key back."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("shared_key", "shared_value")

        queue = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_get_data, args=("shared_key", queue)
        )
        child.start()
        child.join()

        assert queue.get() == "shared_value"

    def test_child_write_parent_read(self):
        """Child writes; parent reads the key back."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        child = multiprocessing.Process(
            target=worker_put_data, args=("child_key", "child_value")
        )
        child.start()
        child.join()

        # Give the manager a moment to propagate the child's write.
        time.sleep(0.1)
        assert pool.get("child_key") == "child_value"

    def test_parent_child_data_isolation(self):
        """Round trip: the child sees the parent's value, and the parent
        then sees the child's overwrite (exercises manager.dict sync)."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("isolation_test", "parent_value")

        queue = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_modify_data,
            args=("isolation_test", "child_value", queue),
        )
        child.start()
        child.join()

        # The child must have observed the parent's original value...
        assert queue.get() == "parent_value"

        # ...and the parent must observe the child's overwrite.
        time.sleep(0.1)
        assert pool.get("isolation_test") == "child_value"
class TestMultipleChildProcesses:
    """Data sharing among several concurrent child processes."""

    def test_multiple_children_write_same_key(self):
        """Five children write the same key; exactly one value survives."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        processes = []
        for i in range(5):
            p = multiprocessing.Process(
                target=worker_put_data, args=("shared_key", f"value_{i}")
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()

        # Brief pause so every write is visible through the manager.
        time.sleep(0.1)

        # Whichever write won, it must be one of the five written values.
        result = pool.get("shared_key")
        assert result.startswith("value_")

    def test_multiple_children_write_different_keys(self):
        """Each child writes its own key; all writes must be visible."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        num_processes = 5
        processes = []
        for i in range(num_processes):
            p = multiprocessing.Process(
                target=worker_put_data, args=(f"key_{i}", f"value_{i}")
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()

        time.sleep(0.1)  # let the manager settle

        for i in range(num_processes):
            assert pool.get(f"key_{i}") == f"value_{i}"

    def test_multiple_children_read_same_key(self):
        """Five children read one key; every reader sees the same value."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("shared_key", "shared_value")

        num_processes = 5
        result_queue = multiprocessing.Queue()
        processes = []
        for _ in range(num_processes):
            p = multiprocessing.Process(
                target=worker_get_data, args=("shared_key", result_queue)
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()

        results = [result_queue.get() for _ in range(num_processes)]
        assert all(r == "shared_value" for r in results)

    def test_concurrent_exists_check(self):
        """Concurrent exists() checks over a mix of present/absent keys.

        Keys ``key_0`` .. ``key_4`` are written before any child starts.
        The ten children check ``key_{i % 7}`` for i in 0..9 — that is
        key_0..key_6 plus key_0..key_2 again — so exactly 8 checks hit
        existing keys and 2 (key_5, key_6) do not.
        """
        pool = MultiProcessingSharedPool()
        pool.clear()

        for i in range(5):
            pool.put(f"key_{i}", f"value_{i}")

        result_queue = multiprocessing.Queue()
        processes = []
        for i in range(10):
            key = f"key_{i % 7}"  # key_5 and key_6 were never written
            p = multiprocessing.Process(
                target=worker_check_exists, args=(key, result_queue)
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()

        results = [result_queue.get() for _ in range(10)]
        # The original assertion was a loose ">= 5" and its comment
        # miscounted the split; the count is deterministic (all writes
        # complete before any child starts), so assert it exactly.
        assert sum(results) == 8
# 进程池测试需要在模块级别定义 worker 函数
def _pool_worker_map(args):
    """Pool.map worker: fetch *key* from the shared pool, keeping the index."""
    idx, key = args
    return idx, MultiProcessingSharedPool().get(key)


def _pool_worker_apply(key):
    """apply_async worker: fetch *key* from the shared pool."""
    return MultiProcessingSharedPool().get(key)
class TestProcessPool:
    """Using the shared pool from within multiprocessing.Pool workers."""

    def test_pool_map_with_shared_data(self):
        """Pool.map workers can each read their own shared-pool entry."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        for i in range(5):
            pool.put(f"input_{i}", i * 10)

        with multiprocessing.Pool(processes=3) as workers:
            pairs = workers.map(
                _pool_worker_map, [(i, f"input_{i}") for i in range(5)]
            )

        by_index = dict(pairs)
        for i in range(5):
            assert by_index[i] == i * 10

    def test_pool_apply_async_with_shared_data(self):
        """apply_async workers see values written by the parent."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("async_key", "async_value")

        with multiprocessing.Pool(processes=2) as workers:
            async_result = workers.apply_async(_pool_worker_apply, ("async_key",))
            assert async_result.get(timeout=5) == "async_value"
class TestDataVisibility:
    """Visibility and synchronization timing of writes."""

    def test_immediate_visibility_after_put(self):
        """Within the writing process, a put is visible immediately."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("immediate_key", "immediate_value")

        assert pool.exists("immediate_key")
        assert pool.get("immediate_key") == "immediate_value"

    def test_cross_process_visibility_with_delay(self):
        """A child that sleeps before reading still sees the parent's write."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("delayed_key", "delayed_value")

        queue = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_wait_and_get, args=("delayed_key", 0.2, queue)
        )
        child.start()
        child.join()

        assert queue.get() == "delayed_value"
class TestConcurrentModifications:
    """Concurrent-modification scenarios."""

    def test_concurrent_counter_increments(self):
        """Racy get+put increments must not crash (lost updates allowed).

        The increment is deliberately non-atomic, so the final count may
        land anywhere between 0 and the theoretical maximum; the point is
        that concurrent access does not break the pool.
        """
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("counter", 0)

        num_processes = 4
        iterations_per_process = 10

        workers = [
            multiprocessing.Process(
                target=worker_increment_counter,
                args=("counter", iterations_per_process),
            )
            for _ in range(num_processes)
        ]
        for w in workers:
            w.start()
        for w in workers:
            w.join()

        time.sleep(0.1)

        final_count = pool.get("counter")
        assert isinstance(final_count, int)
        assert 0 <= final_count <= num_processes * iterations_per_process

    def test_concurrent_pop_operations(self):
        """Each child pops a distinct key; all values come out, all keys go."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        num_keys = 5
        for i in range(num_keys):
            pool.put(f"pop_key_{i}", f"pop_value_{i}")

        result_queue = multiprocessing.Queue()
        workers = [
            multiprocessing.Process(
                target=worker_pop_data, args=(f"pop_key_{i}", result_queue)
            )
            for i in range(num_keys)
        ]
        for w in workers:
            w.start()
        for w in workers:
            w.join()

        popped_values = [result_queue.get() for _ in range(num_keys)]

        # Every stored value must have been popped exactly once.
        assert len(popped_values) == num_keys
        for i in range(num_keys):
            assert f"pop_value_{i}" in popped_values

        # After the pops, none of the keys may remain.
        time.sleep(0.1)
        for i in range(num_keys):
            assert not pool.exists(f"pop_key_{i}")
class TestComplexDataTypes:
    """Sharing complex container types across processes."""

    def test_share_nested_dict(self):
        """A deeply nested dict survives the round trip to a child."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        nested_data = {
            "level1": {"level2": {"level3": [1, 2, 3]}},
            "list_of_dicts": [{"a": 1}, {"b": 2}],
        }
        pool.put("nested_key", nested_data)

        queue = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_get_data, args=("nested_key", queue)
        )
        child.start()
        child.join()

        assert queue.get() == nested_data

    def test_share_large_list(self):
        """A 10,000-element list survives the round trip to a child."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        large_list = list(range(10000))
        pool.put("large_list_key", large_list)

        queue = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_get_data, args=("large_list_key", queue)
        )
        child.start()
        child.join()

        assert queue.get() == large_list

569
test/test_numpy.py Normal file
View File

@ -0,0 +1,569 @@
"""
NumPy ndarray 支持测试 - 验证数组数据共享
"""
import multiprocessing
import time
import pytest
import numpy as np
from mpsp.mpsp import MultiProcessingSharedPool
# ==================== 辅助函数 ====================
def worker_get_array(key, result_queue):
    """Child process: fetch an array and send it back through the queue."""
    result_queue.put(MultiProcessingSharedPool().get(key))


def worker_modify_array(key, index, value, result_queue):
    """Child process: read an array and report the element at *index*.

    NOTE: *value* is accepted but never written back — ``get`` hands out
    a copy of the array, so an in-place write would not reach the pool
    anyway.  The signature is kept for compatibility with callers.
    """
    arr = MultiProcessingSharedPool().get(key)
    old_value = arr[index].copy() if isinstance(index, tuple) else arr[index]
    result_queue.put(old_value)


def worker_sum_array(key, result_queue):
    """Child process: report the sum of the array's elements."""
    result_queue.put(np.sum(MultiProcessingSharedPool().get(key)))


def worker_check_array_properties(key, expected_shape, expected_dtype, result_queue):
    """Child process: report (shape matches, dtype matches, ndim)."""
    arr = MultiProcessingSharedPool().get(key)
    result_queue.put(
        (arr.shape == expected_shape, arr.dtype == expected_dtype, arr.ndim)
    )
# ==================== 测试类 ====================
class TestNDBasicOperations:
    """Basic store/retrieve round trips for NumPy arrays."""

    def test_1d_array(self):
        """A 1-D array round-trips with values and dtype intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.array([1, 2, 3, 4, 5])
        pool.put("1d_array", original)
        loaded = pool.get("1d_array")
        np.testing.assert_array_equal(loaded, original)
        assert loaded.dtype == original.dtype

    def test_2d_array(self):
        """A 3x3 matrix round-trips with its shape intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        pool.put("2d_array", original)
        loaded = pool.get("2d_array")
        np.testing.assert_array_equal(loaded, original)
        assert loaded.shape == (3, 3)

    def test_3d_array(self):
        """A 2x3x4 tensor round-trips with its shape intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.arange(24).reshape(2, 3, 4)
        pool.put("3d_array", original)
        loaded = pool.get("3d_array")
        np.testing.assert_array_equal(loaded, original)
        assert loaded.shape == (2, 3, 4)

    def test_multidimensional_array(self):
        """A random 4-D array round-trips within float tolerance."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.random.rand(2, 3, 4, 5)
        pool.put("4d_array", original)
        loaded = pool.get("4d_array")
        np.testing.assert_array_almost_equal(loaded, original)
        assert loaded.shape == (2, 3, 4, 5)
class TestNDDataTypes:
    """Round trips for NumPy arrays of various dtypes."""

    def test_integer_dtypes(self):
        """All signed/unsigned integer dtypes survive unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        int_dtypes = (
            np.int8, np.int16, np.int32, np.int64,
            np.uint8, np.uint16, np.uint32, np.uint64,
        )
        for dtype in int_dtypes:
            original = np.array([1, 2, 3], dtype=dtype)
            key = f"int_array_{dtype.__name__}"
            pool.put(key, original)
            loaded = pool.get(key)
            assert loaded.dtype == dtype
            np.testing.assert_array_equal(loaded, original)

    def test_float_dtypes(self):
        """float32/float64 arrays survive within floating tolerance."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        for dtype in (np.float32, np.float64):
            original = np.array([1.1, 2.2, 3.3], dtype=dtype)
            key = f"float_array_{dtype.__name__}"
            pool.put(key, original)
            loaded = pool.get(key)
            assert loaded.dtype == dtype
            np.testing.assert_array_almost_equal(loaded, original)

    def test_complex_dtypes(self):
        """complex64/complex128 arrays survive."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        for dtype in (np.complex64, np.complex128):
            original = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=dtype)
            key = f"complex_array_{dtype.__name__}"
            pool.put(key, original)
            loaded = pool.get(key)
            assert loaded.dtype == dtype
            np.testing.assert_array_almost_equal(loaded, original)

    def test_bool_dtype(self):
        """Boolean arrays survive unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        original = np.array([True, False, True, True, False], dtype=np.bool_)
        pool.put("bool_array", original)
        loaded = pool.get("bool_array")
        assert loaded.dtype == np.bool_
        np.testing.assert_array_equal(loaded, original)

    def test_string_dtype(self):
        """Unicode string arrays survive unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        original = np.array(["hello", "world", "mpsp"])
        pool.put("string_array", original)
        np.testing.assert_array_equal(pool.get("string_array"), original)

    def test_object_dtype(self):
        """Object arrays (heterogeneous contents) survive unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        original = np.array([1, "string", 3.14, [1, 2, 3]], dtype=object)
        pool.put("object_array", original)
        loaded = pool.get("object_array")
        assert loaded.dtype == object
        np.testing.assert_array_equal(loaded, original)
class TestNDCrossProcess:
    """NumPy arrays shared across process boundaries."""

    def test_1d_array_cross_process(self):
        """A 1-D array written by the parent is readable in a child."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.array([10, 20, 30, 40, 50])
        pool.put("shared_1d", original)

        queue = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_get_array, args=("shared_1d", queue)
        )
        child.start()
        child.join()

        np.testing.assert_array_equal(queue.get(), original)

    def test_2d_array_cross_process(self):
        """A 2-D array written by the parent is readable in a child."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.array([[1, 2], [3, 4], [5, 6]])
        pool.put("shared_2d", original)

        queue = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_get_array, args=("shared_2d", queue)
        )
        child.start()
        child.join()

        np.testing.assert_array_equal(queue.get(), original)

    def test_array_properties_cross_process(self):
        """shape / dtype / ndim are preserved across the process boundary."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.arange(12).reshape(3, 4).astype(np.float32)
        pool.put("property_test", original)

        queue = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_check_array_properties,
            args=("property_test", (3, 4), np.float32, queue),
        )
        child.start()
        child.join()

        shape_ok, dtype_ok, ndim = queue.get()
        assert shape_ok
        assert dtype_ok
        assert ndim == 2

    def test_array_operations_cross_process(self):
        """A child process can run reductions over a shared array."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
        pool.put("sum_test", original)

        queue = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_sum_array, args=("sum_test", queue)
        )
        child.start()
        child.join()

        assert queue.get() == 55  # 1 + 2 + ... + 10
class TestNDSpecialArrays:
    """Edge-case NumPy arrays."""

    def test_empty_array(self):
        """A zero-length array round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.array([])
        pool.put("empty_array", original)
        loaded = pool.get("empty_array")
        np.testing.assert_array_equal(loaded, original)
        assert len(loaded) == 0

    def test_single_element_array(self):
        """A single-element array round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.array([42])
        pool.put("single_element", original)
        loaded = pool.get("single_element")
        np.testing.assert_array_equal(loaded, original)
        assert loaded[0] == 42

    def test_zeros_array(self):
        """An all-zeros matrix round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.zeros((5, 5))
        pool.put("zeros_array", original)
        loaded = pool.get("zeros_array")
        np.testing.assert_array_equal(loaded, original)
        assert np.all(loaded == 0)

    def test_ones_array(self):
        """An all-ones matrix round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.ones((3, 4))
        pool.put("ones_array", original)
        loaded = pool.get("ones_array")
        np.testing.assert_array_equal(loaded, original)
        assert np.all(loaded == 1)

    def test_eye_array(self):
        """An identity matrix round-trips with its diagonal intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.eye(5)
        pool.put("eye_array", original)
        loaded = pool.get("eye_array")
        np.testing.assert_array_equal(loaded, original)
        assert np.all(np.diag(loaded) == 1)

    def test_nan_and_inf_array(self):
        """NaN and +/-Inf values are preserved."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.array([1.0, np.nan, np.inf, -np.inf, 2.0])
        pool.put("special_values", original)
        loaded = pool.get("special_values")
        assert np.isnan(loaded[1])
        assert np.isinf(loaded[2]) and loaded[2] > 0
        assert np.isinf(loaded[3]) and loaded[3] < 0

    def test_masked_array(self):
        """Masked arrays keep both their data and their mask."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        data = np.array([1, 2, 3, 4, 5])
        mask = np.array([False, True, False, True, False])
        original = np.ma.array(data, mask=mask)
        pool.put("masked_array", original)
        loaded = pool.get("masked_array")
        np.testing.assert_array_equal(loaded.data, data)
        np.testing.assert_array_equal(loaded.mask, mask)
class TestNDLargeArrays:
    """Larger NumPy arrays."""

    def test_large_1d_array(self):
        """A 10,000-element vector round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.arange(10000)
        pool.put("large_1d", original)
        np.testing.assert_array_equal(pool.get("large_1d"), original)

    def test_large_2d_array(self):
        """A 1000x100 random matrix round-trips."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.random.rand(1000, 100)
        pool.put("large_2d", original)
        np.testing.assert_array_almost_equal(pool.get("large_2d"), original)

    def test_large_array_cross_process(self):
        """A child can correctly sum a large shared array."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.arange(100000).reshape(1000, 100)
        pool.put("large_cross", original)

        queue = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_sum_array, args=("large_cross", queue)
        )
        child.start()
        child.join()

        assert queue.get() == np.sum(original)
class TestNDStructuredArrays:
    """Structured (record) arrays."""

    def test_structured_array(self):
        """A structured array keeps its compound dtype and records."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        record_dtype = np.dtype([("name", "U10"), ("age", "i4"), ("weight", "f4")])
        original = np.array(
            [("Alice", 25, 55.5), ("Bob", 30, 85.3), ("Charlie", 35, 75.0)],
            dtype=record_dtype,
        )
        pool.put("structured_array", original)
        loaded = pool.get("structured_array")
        assert loaded.dtype == record_dtype
        np.testing.assert_array_equal(loaded, original)

    def test_structured_array_cross_process(self):
        """A structured array survives the trip to a child process."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        point_dtype = np.dtype([("x", "f4"), ("y", "f4"), ("z", "f4")])
        original = np.array([(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)], dtype=point_dtype)
        pool.put("structured_cross", original)

        queue = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_get_array, args=("structured_cross", queue)
        )
        child.start()
        child.join()

        np.testing.assert_array_equal(queue.get(), original)
class TestNDMatrixOperations:
    """Linear-algebra operations on retrieved arrays."""

    def test_matrix_multiplication(self):
        """Retrieved matrices multiply to the expected product."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        A = np.array([[1, 2], [3, 4]])
        B = np.array([[5, 6], [7, 8]])
        pool.put("matrix_A", A)
        pool.put("matrix_B", B)

        product = np.dot(pool.get("matrix_A"), pool.get("matrix_B"))
        np.testing.assert_array_equal(product, np.array([[19, 22], [43, 50]]))

    def test_eigenvalue_computation(self):
        """Eigenvalues of a retrieved symmetric matrix are correct."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        pool.put("eigen_matrix", np.array([[4, 2], [2, 4]]))

        eigenvalues, _ = np.linalg.eig(pool.get("eigen_matrix"))
        # [[4, 2], [2, 4]] has eigenvalues 2 and 6.
        assert np.allclose(sorted(eigenvalues), [2, 6])

    def test_svd_decomposition(self):
        """SVD of a retrieved matrix reconstructs the original."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        original = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        pool.put("svd_matrix", original)

        U, S, Vh = np.linalg.svd(pool.get("svd_matrix"))
        np.testing.assert_array_almost_equal(U @ np.diag(S) @ Vh, original)
class TestNDBroadcasting:
    """Broadcasting between retrieved arrays."""

    def test_broadcasting_operation(self):
        """A retrieved 1-D array broadcasts against a retrieved 2-D array."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        matrix = np.array([[1, 2, 3], [4, 5, 6]])
        row = np.array([10, 20, 30])
        pool.put("array_2d", matrix)
        pool.put("array_1d", row)

        total = pool.get("array_2d") + pool.get("array_1d")
        np.testing.assert_array_equal(total, np.array([[11, 22, 33], [14, 25, 36]]))
class TestNDMixedTypes:
"""测试混合数据类型的数组相关操作"""
def test_array_in_dict(self):
    """A dict mixing arrays with scalars and strings round-trips intact."""
    pool = MultiProcessingSharedPool()
    pool.clear()

    payload = {
        "matrix": np.array([[1, 2], [3, 4]]),
        "vector": np.array([1, 2, 3]),
        "scalar": 42,
        "name": "test",
    }
    pool.put("dict_with_arrays", payload)

    loaded = pool.get("dict_with_arrays")
    np.testing.assert_array_equal(loaded["matrix"], payload["matrix"])
    np.testing.assert_array_equal(loaded["vector"], payload["vector"])
    assert loaded["scalar"] == 42
    assert loaded["name"] == "test"
def test_array_in_list(self):
"""测试列表中包含数组"""
pool = MultiProcessingSharedPool()
pool.clear()
data = [np.array([1, 2, 3]), np.array([[4, 5], [6, 7]]), "string", 42]
pool.put("list_with_arrays", data)
retrieved = pool.get("list_with_arrays")
np.testing.assert_array_equal(retrieved[0], data[0])
np.testing.assert_array_equal(retrieved[1], data[1])
assert retrieved[2] == "string"
assert retrieved[3] == 42