"""Edge-case and exception tests for MultiProcessingSharedPool — robustness verification."""
|
|
|
|
import multiprocessing
|
|
import time
|
|
import pickle
|
|
import pytest
|
|
from mpsp.mpsp import MultiProcessingSharedPool
|
|
|
|
|
|
# ==================== 辅助函数 ====================
|
|
|
|
|
|
def worker_put_empty_key(result_queue):
    """Child process: overwrite the empty-string key and report the value read back."""
    shared = MultiProcessingSharedPool()
    try:
        shared.put("", "empty_value")
        result_queue.put(("success", shared.get("")))
    except Exception as exc:
        result_queue.put(("error", str(exc)))
|
|
|
|
|
|
def worker_get_nonexistent(result_queue):
    """Child process: look up a key that was never stored and report the result."""
    shared = MultiProcessingSharedPool()
    result_queue.put(shared.get("definitely_nonexistent_key_12345"))
|
|
|
|
|
|
def worker_put_large_object(key, data, result_queue):
    """Child process: store a large object and report whether put() succeeded."""
    shared = MultiProcessingSharedPool()
    try:
        result_queue.put(("success", shared.put(key, data)))
    except Exception as exc:
        result_queue.put(("error", str(exc)))
|
|
|
|
|
|
# ==================== 测试类 ====================
|
|
|
|
|
|
class TestEmptyAndNoneValues:
    """Empty containers and None round-trip through the pool unchanged."""

    def test_put_empty_string_value(self):
        """An empty string value is stored and returned as-is."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        shared.put("empty_value_key", "")
        assert shared.get("empty_value_key") == ""

    def test_put_none_value(self):
        """None is a storable value, distinct from a missing key."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        shared.put("none_value_key", None)
        assert shared.get("none_value_key") is None

    def test_put_empty_list(self):
        """An empty list round-trips unchanged."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        shared.put("empty_list", [])
        assert shared.get("empty_list") == []

    def test_put_empty_dict(self):
        """An empty dict round-trips unchanged."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        shared.put("empty_dict", {})
        assert shared.get("empty_dict") == {}

    def test_put_empty_tuple(self):
        """An empty tuple round-trips unchanged."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        shared.put("empty_tuple", ())
        assert shared.get("empty_tuple") == ()

    def test_put_empty_set(self):
        """An empty set round-trips unchanged."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        shared.put("empty_set", set())
        assert shared.get("empty_set") == set()
|
|
|
|
|
|
class TestSpecialLabelNames:
    """Unusual label (key) names are accepted and round-trip correctly."""

    def test_empty_string_label(self):
        """The empty string is a usable key."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        shared.put("", "empty_key_value")
        assert shared.get("") == "empty_key_value"
        assert shared.exists("")

    def test_empty_string_label_cross_process(self):
        """A child process can overwrite the empty-string key set by the parent."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        shared.put("", "parent_empty_value")

        q = multiprocessing.Queue()
        child = multiprocessing.Process(target=worker_put_empty_key, args=(q,))
        child.start()
        child.join()

        status, value = q.get()
        # The child is expected to replace the parent's value under "".
        assert status == "success"
        assert value == "empty_value"

    def test_unicode_label(self):
        """Non-ASCII keys (CJK, emoji, symbols) work."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        unicode_keys = [
            "中文键",
            "日本語キー",
            "한국어키",
            "emoji_😀",
            "special_©_®_™",
            "math_∑_∏_√",
            "arrows_→_←_↑_↓",
        ]

        for label in unicode_keys:
            shared.put(label, f"value_for_{label}")

        for label in unicode_keys:
            assert shared.get(label) == f"value_for_{label}"

    def test_whitespace_label(self):
        """Whitespace-only and whitespace-containing keys are preserved exactly."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        whitespace_keys = [
            " ",  # single space
            "  ",  # two spaces
            "\t",  # tab
            "\n",  # newline
            "\r\n",  # Windows newline
            " key_with_leading_space",
            "key_with_trailing_space ",
            " key_with_both_spaces ",
            "key\twith\ttabs",
        ]

        for label in whitespace_keys:
            shared.put(label, f"value_for_repr_{repr(label)}")

        for label in whitespace_keys:
            assert shared.get(label) == f"value_for_repr_{repr(label)}"

    def test_special_chars_label(self):
        """Punctuation and shell-special characters are valid inside keys."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        special_keys = [
            "key.with.dots",
            "key/with/slashes",
            "key:with:colons",
            "key|with|pipes",
            "key*with*asterisks",
            "key?with?question",
            "key<with>brackets",
            "key[with]square",
            "key{with}curly",
            "key+with+plus",
            "key=with=equals",
            "key!with!exclamation",
            "key@with@at",
            "key#with#hash",
            "key$with$dollar",
            "key%with%percent",
            "key^with^caret",
            "key&with&ersand",
            "key'with'quotes",
            'key"with"double',
            "key`with`backtick",
            "key~with~tilde",
            "key-with-hyphens",
            "key_with_underscores",
        ]

        for label in special_keys:
            shared.put(label, f"value_for_{label}")

        for label in special_keys:
            assert shared.get(label) == f"value_for_{label}"

    def test_very_long_label(self):
        """A 1000-character key is accepted."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        long_key = "a" * 1000
        shared.put(long_key, "long_key_value")
        assert shared.get(long_key) == "long_key_value"
|
|
|
|
|
|
class TestNonExistentKeys:
    """Lookups, removals and pops of missing keys fail gracefully."""

    def test_get_nonexistent(self):
        """get() on a missing key returns None."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        assert shared.get("nonexistent_key_12345") is None

    def test_get_nonexistent_with_default(self):
        """get() on a missing key returns the caller-supplied default."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        assert shared.get("nonexistent", "default") == "default"
        assert shared.get("nonexistent", 0) == 0
        assert shared.get("nonexistent", []) == []
        assert shared.get("nonexistent", {}) == {}

    def test_get_nonexistent_cross_process(self):
        """A child process also sees None for a key nobody stored."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        q = multiprocessing.Queue()
        child = multiprocessing.Process(target=worker_get_nonexistent, args=(q,))
        child.start()
        child.join()

        assert q.get() is None

    def test_remove_nonexistent(self):
        """remove() on a missing key reports False rather than raising."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        assert shared.remove("nonexistent_key") is False

    def test_pop_nonexistent(self):
        """pop() on a missing key yields None, or the supplied default."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        assert shared.pop("nonexistent_key") is None
        assert shared.pop("nonexistent_key", "default") == "default"

    def test_exists_nonexistent(self):
        """exists() on a missing key is False."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        assert shared.exists("nonexistent_key") is False
|
|
|
|
|
|
class TestLargeObjects:
    """Large payloads serialize and deserialize intact."""

    def test_large_list(self):
        """A 100k-element list survives a round trip with endpoints intact."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        payload = list(range(100000))
        shared.put("large_list", payload)

        out = shared.get("large_list")
        assert len(out) == 100000
        assert out[0] == 0
        assert out[99999] == 99999

    def test_large_dict(self):
        """A 10k-entry dict survives a round trip."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        payload = {f"key_{i}": f"value_{i}" for i in range(10000)}
        shared.put("large_dict", payload)

        out = shared.get("large_dict")
        assert len(out) == 10000
        assert out["key_0"] == "value_0"
        assert out["key_9999"] == "value_9999"

    def test_large_string(self):
        """A 1 MB string survives a round trip."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        payload = "x" * 1000000  # 1 MB string
        shared.put("large_string", payload)

        out = shared.get("large_string")
        assert len(out) == 1000000
        assert out[0] == "x"
        assert out[-1] == "x"

    def test_deeply_nested_structure(self):
        """A 50-level nested dict round-trips with level ordering intact."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        depth = 50
        nested = "bottom"
        for level in range(depth):
            nested = {"level": level, "nested": nested}

        shared.put("deep_nested", nested)

        node = shared.get("deep_nested")
        # Outermost dict carries the highest level; walk down to the bottom.
        for expected_level in reversed(range(depth)):
            assert node["level"] == expected_level
            node = node["nested"]
        assert node == "bottom"

    def test_large_object_cross_process(self):
        """A child process can store a large object successfully."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        large_data = {"items": list(range(10000)), "name": "large_test"}

        q = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_put_large_object,
            args=("large_cross", large_data, q),
        )
        child.start()
        child.join()

        status, ok = q.get()
        assert status == "success"
        assert ok is True
|
|
|
|
|
|
class TestCircularReferences:
    """Self-referential containers survive a put/get round trip.

    The original assertions were tautological (``retrieved[3] is not None``
    is always true once indexing succeeds, and ``"self" in retrieved`` does
    not prove the cycle was restored).  They are strengthened to identity
    checks: pickle's memo table preserves internal object identity, so the
    restored cycle must point back at the restored container itself.
    NOTE(review): assumes the pool serializes with pickle (or returns the
    original object) — confirm against the mpsp implementation.
    """

    def test_circular_list(self):
        """A list that contains itself is restored with the cycle intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Build a self-referential list: circular[3] is circular itself.
        circular = [1, 2, 3]
        circular.append(circular)

        pool.put("circular_list", circular)

        retrieved = pool.get("circular_list")
        assert retrieved[0] == 1
        assert retrieved[1] == 2
        assert retrieved[2] == 3
        # The restored cycle must reference the restored list, not a copy.
        assert retrieved[3] is retrieved

    def test_circular_dict(self):
        """A dict that contains itself is restored with the cycle intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Build a self-referential dict: circular["self"] is circular itself.
        circular = {"a": 1, "b": 2}
        circular["self"] = circular

        pool.put("circular_dict", circular)

        retrieved = pool.get("circular_dict")
        assert retrieved["a"] == 1
        assert retrieved["b"] == 2
        # The restored cycle must reference the restored dict, not a copy.
        assert "self" in retrieved
        assert retrieved["self"] is retrieved
|
|
|
|
|
|
class TestBinaryData:
    """Raw byte payloads are preserved exactly."""

    def test_bytes_data(self):
        """A short bytes value round-trips unchanged."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        blob = b"\x00\x01\x02\x03\xff\xfe\xfd\xfc"
        shared.put("binary_data", blob)

        assert shared.get("binary_data") == blob

    def test_large_binary_data(self):
        """A 256 KB bytes value round-trips unchanged."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        blob = bytes(range(256)) * 1000  # 256KB
        shared.put("large_binary", blob)

        assert shared.get("large_binary") == blob

    def test_bytearray_data(self):
        """A bytearray value still compares equal after a round trip."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        buf = bytearray(b"\x00\x01\x02\x03")
        shared.put("bytearray_data", buf)

        assert shared.get("bytearray_data") == buf
|
|
|
|
|
|
class TestMixedTypes:
    """Heterogeneous containers keep each element's type and value."""

    def test_heterogeneous_list(self):
        """A list mixing every basic Python type round-trips element-wise."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        mixed_list = [
            1,  # int
            3.14,  # float
            "string",  # str
            True,  # bool
            None,  # NoneType
            [1, 2, 3],  # list
            {"a": 1},  # dict
            (1, 2),  # tuple
            {1, 2, 3},  # set
            b"binary",  # bytes
        ]

        shared.put("mixed_list", mixed_list)

        got = shared.get("mixed_list")
        assert got[0] == 1
        assert abs(got[1] - 3.14) < 1e-10
        assert got[2] == "string"
        assert got[3] is True
        assert got[4] is None
        assert got[5] == [1, 2, 3]
        assert got[6] == {"a": 1}
        assert got[7] == (1, 2)
        assert got[8] == {1, 2, 3}
        assert got[9] == b"binary"

    def test_heterogeneous_dict(self):
        """A dict with values of many types round-trips entry by entry."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        mixed_dict = {
            "int_key": 42,
            "float_key": 3.14,
            "str_key": "hello",
            "bool_key": True,
            "none_key": None,
            "list_key": [1, 2, 3],
            "dict_key": {"nested": "value"},
            "tuple_key": (1, 2, 3),
        }

        shared.put("mixed_dict", mixed_dict)

        got = shared.get("mixed_dict")
        for label, expected in mixed_dict.items():
            if isinstance(expected, float):
                # Tolerance guard against any float re-encoding on the way through.
                assert abs(got[label] - expected) < 1e-10
            else:
                assert got[label] == expected
|
|
|
|
|
|
def worker_stress_test(key_prefix, iterations, result_queue):
    """Child process: repeatedly put/get/remove keys and collect any errors.

    Defined at module level (the stray duplicate ``class TestConcurrentAccess``
    header that preceded it was dead code — the class is redefined below, and
    the tests reference ``worker_stress_test`` as a bare global name, which a
    multiprocessing target must be anyway to stay picklable under spawn).

    Args:
        key_prefix: unique prefix so concurrent workers use disjoint keys.
        iterations: number of put/get/remove cycles to run.
        result_queue: queue receiving the list of error strings (empty on success).
    """
    pool = MultiProcessingSharedPool()
    errors = []

    for i in range(iterations):
        try:
            key = f"{key_prefix}_{i}"
            pool.put(key, f"value_{i}")
            value = pool.get(key)
            if value != f"value_{i}":
                errors.append(f"Value mismatch at {key}")
            pool.remove(key)
        except Exception as e:
            errors.append(str(e))

    result_queue.put(errors)
|
|
|
|
class TestConcurrentAccess:
    """The pool stays consistent under concurrent and rapid access."""

    def test_stress_concurrent_writes(self):
        """Four processes hammering disjoint key ranges produce no errors."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        num_processes = 4
        iterations = 100

        q = multiprocessing.Queue()
        workers = []

        for idx in range(num_processes):
            proc = multiprocessing.Process(
                target=worker_stress_test,
                args=(f"stress_{idx}", iterations, q),
            )
            workers.append(proc)
            proc.start()

        for proc in workers:
            proc.join()

        # Gather every worker's error list.
        all_errors = []
        for _ in range(num_processes):
            all_errors.extend(q.get())

        assert len(all_errors) == 0, f"Errors occurred: {all_errors}"

    def test_rapid_put_get_cycle(self):
        """Overwriting one key 1000 times always reads back the latest value."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        for i in range(1000):
            shared.put("rapid_key", f"value_{i}")
            assert shared.get("rapid_key") == f"value_{i}"

    def test_rapid_key_creation_deletion(self):
        """Keys created and deleted in quick succession track existence correctly."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        for i in range(100):
            label = f"temp_key_{i}"
            shared.put(label, f"temp_value_{i}")
            assert shared.exists(label)
            shared.remove(label)
            assert not shared.exists(label)
|
|
|
|
|
|
class TestErrorRecovery:
    """The pool remains usable after failed operations."""

    def test_put_after_error(self):
        """A rejected put (non-string key) does not poison later operations."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        try:
            # A non-string key is expected to raise TypeError; either way the
            # pool must keep working afterwards.
            shared.put(123, "value")
        except TypeError:
            pass

        shared.put("valid_key", "valid_value")
        assert shared.get("valid_key") == "valid_value"

    def test_get_after_nonexistent(self):
        """A missed lookup does not affect subsequent puts and gets."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        assert shared.get("nonexistent") is None

        shared.put("new_key", "new_value")
        assert shared.get("new_key") == "new_value"

    def test_multiple_singleton_access(self):
        """Every handle to the singleton sees the same underlying data."""
        pool1 = MultiProcessingSharedPool()
        pool1.put("key1", "value1")

        pool2 = MultiProcessingSharedPool()
        pool2.put("key2", "value2")

        pool3 = MultiProcessingSharedPool.get_instance()
        pool3.put("key3", "value3")

        # Writes through any handle are visible through every other handle.
        assert pool1.get("key1") == "value1"
        assert pool1.get("key2") == "value2"
        assert pool1.get("key3") == "value3"
|
|
|
|
|
|
class TestCleanup:
    """clear() and remove() fully empty the pool."""

    def test_clear_after_multiple_puts(self):
        """clear() drops all 100 stored entries and empties keys()."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        for i in range(100):
            shared.put(f"key_{i}", f"value_{i}")

        assert shared.size() == 100

        shared.clear()

        assert shared.size() == 0
        assert shared.keys() == []

    def test_remove_all_keys_one_by_one(self):
        """Removing every key individually leaves an empty pool."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        labels = [f"key_{i}" for i in range(50)]
        for label in labels:
            shared.put(label, f"value_for_{label}")

        for label in labels:
            assert shared.remove(label) is True

        assert shared.size() == 0
|