""" 边界与异常测试 - 验证鲁棒性 """ import multiprocessing import time import pickle import pytest from mpsp.mpsp import MultiProcessingSharedPool # ==================== 辅助函数 ==================== def worker_put_empty_key(result_queue): """子进程:测试空字符串 key""" pool = MultiProcessingSharedPool() try: pool.put("", "empty_value") result_queue.put(("success", pool.get(""))) except Exception as e: result_queue.put(("error", str(e))) def worker_get_nonexistent(result_queue): """子进程:获取不存在的 key""" pool = MultiProcessingSharedPool() result = pool.get("definitely_nonexistent_key_12345") result_queue.put(result) def worker_put_large_object(key, data, result_queue): """子进程:存储大对象""" pool = MultiProcessingSharedPool() try: success = pool.put(key, data) result_queue.put(("success", success)) except Exception as e: result_queue.put(("error", str(e))) # ==================== 测试类 ==================== class TestEmptyAndNoneValues: """测试空值和 None 处理""" def test_put_empty_string_value(self): """测试存储空字符串值""" pool = MultiProcessingSharedPool() pool.clear() pool.put("empty_value_key", "") assert pool.get("empty_value_key") == "" def test_put_none_value(self): """测试存储 None 值""" pool = MultiProcessingSharedPool() pool.clear() pool.put("none_value_key", None) assert pool.get("none_value_key") is None def test_put_empty_list(self): """测试存储空列表""" pool = MultiProcessingSharedPool() pool.clear() pool.put("empty_list", []) assert pool.get("empty_list") == [] def test_put_empty_dict(self): """测试存储空字典""" pool = MultiProcessingSharedPool() pool.clear() pool.put("empty_dict", {}) assert pool.get("empty_dict") == {} def test_put_empty_tuple(self): """测试存储空元组""" pool = MultiProcessingSharedPool() pool.clear() pool.put("empty_tuple", ()) assert pool.get("empty_tuple") == () def test_put_empty_set(self): """测试存储空集合""" pool = MultiProcessingSharedPool() pool.clear() pool.put("empty_set", set()) assert pool.get("empty_set") == set() class TestSpecialLabelNames: """测试特殊 label 名称""" def test_empty_string_label(self): """测试空字符串作为 label""" pool = MultiProcessingSharedPool() pool.clear() pool.put("", "empty_key_value") assert pool.get("") == "empty_key_value" assert pool.exists("") def test_empty_string_label_cross_process(self): """测试空字符串 label 跨进程""" pool = MultiProcessingSharedPool() pool.clear() pool.put("", "parent_empty_value") result_queue = multiprocessing.Queue() p = multiprocessing.Process(target=worker_put_empty_key, args=(result_queue,)) p.start() p.join() status, result = result_queue.get() # 子进程应该可以覆盖空字符串 key assert status == "success" assert result == "empty_value" def test_unicode_label(self): """测试 Unicode 字符作为 label""" pool = MultiProcessingSharedPool() pool.clear() unicode_keys = [ "中文键", "日本語キー", "한국어키", "emoji_😀", "special_©_®_™", "math_∑_∏_√", "arrows_→_←_↑_↓", ] for key in unicode_keys: pool.put(key, f"value_for_{key}") for key in unicode_keys: assert pool.get(key) == f"value_for_{key}" def test_whitespace_label(self): """测试空白字符作为 label""" pool = MultiProcessingSharedPool() pool.clear() whitespace_keys = [ " ", # 单个空格 " ", # 两个空格 "\t", # Tab "\n", # 换行 "\r\n", # Windows 换行 " key_with_leading_space", "key_with_trailing_space ", " key_with_both_spaces ", "key\twith\ttabs", ] for key in whitespace_keys: pool.put(key, f"value_for_repr_{repr(key)}") for key in whitespace_keys: assert pool.get(key) == f"value_for_repr_{repr(key)}" def test_special_chars_label(self): """测试特殊字符作为 label""" pool = MultiProcessingSharedPool() pool.clear() special_keys = [ "key.with.dots", "key/with/slashes", "key:with:colons", "key|with|pipes", "key*with*asterisks", "key?with?question", "keybrackets", "key[with]square", "key{with}curly", "key+with+plus", "key=with=equals", "key!with!exclamation", "key@with@at", "key#with#hash", "key$with$dollar", "key%with%percent", "key^with^caret", "key&with&ersand", "key'with'quotes", 'key"with"double', "key`with`backtick", "key~with~tilde", "key-with-hyphens", "key_with_underscores", ] for key in special_keys: pool.put(key, f"value_for_{key}") for key in special_keys: assert pool.get(key) == f"value_for_{key}" def test_very_long_label(self): """测试超长 label""" pool = MultiProcessingSharedPool() pool.clear() # 1000 字符的 label long_key = "a" * 1000 pool.put(long_key, "long_key_value") assert pool.get(long_key) == "long_key_value" class TestNonExistentKeys: """测试不存在的 key 处理""" def test_get_nonexistent(self): """测试获取不存在的 key""" pool = MultiProcessingSharedPool() pool.clear() result = pool.get("nonexistent_key_12345") assert result is None def test_get_nonexistent_with_default(self): """测试获取不存在的 key 带默认值""" pool = MultiProcessingSharedPool() pool.clear() assert pool.get("nonexistent", "default") == "default" assert pool.get("nonexistent", 0) == 0 assert pool.get("nonexistent", []) == [] assert pool.get("nonexistent", {}) == {} def test_get_nonexistent_cross_process(self): """测试跨进程获取不存在的 key""" pool = MultiProcessingSharedPool() pool.clear() result_queue = multiprocessing.Queue() p = multiprocessing.Process(target=worker_get_nonexistent, args=(result_queue,)) p.start() p.join() result = result_queue.get() assert result is None def test_remove_nonexistent(self): """测试删除不存在的 key""" pool = MultiProcessingSharedPool() pool.clear() assert pool.remove("nonexistent_key") is False def test_pop_nonexistent(self): """测试弹出不存在的 key""" pool = MultiProcessingSharedPool() pool.clear() assert pool.pop("nonexistent_key") is None assert pool.pop("nonexistent_key", "default") == "default" def test_exists_nonexistent(self): """测试检查不存在的 key""" pool = MultiProcessingSharedPool() pool.clear() assert pool.exists("nonexistent_key") is False class TestLargeObjects: """测试大对象序列化""" def test_large_list(self): """测试大型列表""" pool = MultiProcessingSharedPool() pool.clear() large_list = list(range(100000)) pool.put("large_list", large_list) retrieved = pool.get("large_list") assert len(retrieved) == 100000 assert retrieved[0] == 0 assert retrieved[99999] == 99999 def test_large_dict(self): """测试大型字典""" pool = MultiProcessingSharedPool() pool.clear() large_dict = {f"key_{i}": f"value_{i}" for i in range(10000)} pool.put("large_dict", large_dict) retrieved = pool.get("large_dict") assert len(retrieved) == 10000 assert retrieved["key_0"] == "value_0" assert retrieved["key_9999"] == "value_9999" def test_large_string(self): """测试大型字符串""" pool = MultiProcessingSharedPool() pool.clear() large_string = "x" * 1000000 # 1MB 字符串 pool.put("large_string", large_string) retrieved = pool.get("large_string") assert len(retrieved) == 1000000 assert retrieved[0] == "x" assert retrieved[-1] == "x" def test_deeply_nested_structure(self): """测试深度嵌套结构""" pool = MultiProcessingSharedPool() pool.clear() # 创建深度嵌套的字典 depth = 50 nested = "bottom" for i in range(depth): nested = {"level": i, "nested": nested} pool.put("deep_nested", nested) retrieved = pool.get("deep_nested") # 验证嵌套深度 current = retrieved for i in range(depth): assert current["level"] == depth - 1 - i current = current["nested"] assert current == "bottom" def test_large_object_cross_process(self): """测试跨进程传递大对象""" pool = MultiProcessingSharedPool() pool.clear() large_data = {"items": list(range(10000)), "name": "large_test"} result_queue = multiprocessing.Queue() p = multiprocessing.Process( target=worker_put_large_object, args=("large_cross", large_data, result_queue), ) p.start() p.join() status, result = result_queue.get() assert status == "success" assert result is True class TestCircularReferences: """测试循环引用""" def test_circular_list(self): """测试列表中的循环引用""" pool = MultiProcessingSharedPool() pool.clear() # 创建循环引用列表 circular = [1, 2, 3] circular.append(circular) # 循环引用 pool.put("circular_list", circular) retrieved = pool.get("circular_list") assert retrieved[0] == 1 assert retrieved[1] == 2 assert retrieved[2] == 3 # 循环引用应该被正确处理 assert retrieved[3] is not None def test_circular_dict(self): """测试字典中的循环引用""" pool = MultiProcessingSharedPool() pool.clear() # 创建循环引用字典 circular = {"a": 1, "b": 2} circular["self"] = circular # 循环引用 pool.put("circular_dict", circular) retrieved = pool.get("circular_dict") assert retrieved["a"] == 1 assert retrieved["b"] == 2 # 循环引用应该被正确处理 assert "self" in retrieved class TestBinaryData: """测试二进制数据""" def test_bytes_data(self): """测试字节数据""" pool = MultiProcessingSharedPool() pool.clear() binary_data = b"\x00\x01\x02\x03\xff\xfe\xfd\xfc" pool.put("binary_data", binary_data) retrieved = pool.get("binary_data") assert retrieved == binary_data def test_large_binary_data(self): """测试大型二进制数据""" pool = MultiProcessingSharedPool() pool.clear() binary_data = bytes(range(256)) * 1000 # 256KB pool.put("large_binary", binary_data) retrieved = pool.get("large_binary") assert retrieved == binary_data def test_bytearray_data(self): """测试 bytearray 数据""" pool = MultiProcessingSharedPool() pool.clear() ba = bytearray(b"\x00\x01\x02\x03") pool.put("bytearray_data", ba) retrieved = pool.get("bytearray_data") assert retrieved == ba class TestMixedTypes: """测试混合类型数据""" def test_heterogeneous_list(self): """测试异构列表""" pool = MultiProcessingSharedPool() pool.clear() mixed_list = [ 1, # int 3.14, # float "string", # str True, # bool None, # NoneType [1, 2, 3], # list {"a": 1}, # dict (1, 2), # tuple {1, 2, 3}, # set b"binary", # bytes ] pool.put("mixed_list", mixed_list) retrieved = pool.get("mixed_list") assert retrieved[0] == 1 assert abs(retrieved[1] - 3.14) < 1e-10 assert retrieved[2] == "string" assert retrieved[3] is True assert retrieved[4] is None assert retrieved[5] == [1, 2, 3] assert retrieved[6] == {"a": 1} assert retrieved[7] == (1, 2) assert retrieved[8] == {1, 2, 3} assert retrieved[9] == b"binary" def test_heterogeneous_dict(self): """测试异构字典""" pool = MultiProcessingSharedPool() pool.clear() mixed_dict = { "int_key": 42, "float_key": 3.14, "str_key": "hello", "bool_key": True, "none_key": None, "list_key": [1, 2, 3], "dict_key": {"nested": "value"}, "tuple_key": (1, 2, 3), } pool.put("mixed_dict", mixed_dict) retrieved = pool.get("mixed_dict") for key, value in mixed_dict.items(): if isinstance(value, float): assert abs(retrieved[key] - value) < 1e-10 else: assert retrieved[key] == value class TestConcurrentAccess: """测试并发访问稳定性""" def worker_stress_test(key_prefix, iterations, result_queue): """子进程:压力测试""" pool = MultiProcessingSharedPool() errors = [] for i in range(iterations): try: key = f"{key_prefix}_{i}" pool.put(key, f"value_{i}") value = pool.get(key) if value != f"value_{i}": errors.append(f"Value mismatch at {key}") pool.remove(key) except Exception as e: errors.append(str(e)) result_queue.put(errors) class TestConcurrentAccess: """测试并发访问稳定性""" def test_stress_concurrent_writes(self): """压力测试:并发写入""" pool = MultiProcessingSharedPool() pool.clear() num_processes = 4 iterations = 100 result_queue = multiprocessing.Queue() processes = [] for i in range(num_processes): p = multiprocessing.Process( target=worker_stress_test, args=(f"stress_{i}", iterations, result_queue), ) processes.append(p) p.start() for p in processes: p.join() # 收集所有错误 all_errors = [] for _ in range(num_processes): all_errors.extend(result_queue.get()) # 应该没有错误 assert len(all_errors) == 0, f"Errors occurred: {all_errors}" def test_rapid_put_get_cycle(self): """测试快速 put-get 循环""" pool = MultiProcessingSharedPool() pool.clear() for i in range(1000): pool.put("rapid_key", f"value_{i}") value = pool.get("rapid_key") assert value == f"value_{i}" def test_rapid_key_creation_deletion(self): """测试快速创建和删除 key""" pool = MultiProcessingSharedPool() pool.clear() for i in range(100): key = f"temp_key_{i}" pool.put(key, f"temp_value_{i}") assert pool.exists(key) pool.remove(key) assert not pool.exists(key) class TestErrorRecovery: """测试错误恢复能力""" def test_put_after_error(self): """测试错误后可以继续 put""" pool = MultiProcessingSharedPool() pool.clear() # 尝试使用非法 key 类型 try: pool.put(123, "value") # 应该抛出 TypeError except TypeError: pass # 应该可以继续正常使用 pool.put("valid_key", "valid_value") assert pool.get("valid_key") == "valid_value" def test_get_after_nonexistent(self): """测试获取不存在的 key 后可以继续使用""" pool = MultiProcessingSharedPool() pool.clear() # 获取不存在的 key result = pool.get("nonexistent") assert result is None # 应该可以继续正常使用 pool.put("new_key", "new_value") assert pool.get("new_key") == "new_value" def test_multiple_singleton_access(self): """测试多次获取单例后访问""" pool1 = MultiProcessingSharedPool() pool1.put("key1", "value1") pool2 = MultiProcessingSharedPool() pool2.put("key2", "value2") pool3 = MultiProcessingSharedPool.get_instance() pool3.put("key3", "value3") # 所有实例应该看到相同的数据 assert pool1.get("key1") == "value1" assert pool1.get("key2") == "value2" assert pool1.get("key3") == "value3" class TestCleanup: """测试清理功能""" def test_clear_after_multiple_puts(self): """测试多次 put 后 clear""" pool = MultiProcessingSharedPool() pool.clear() for i in range(100): pool.put(f"key_{i}", f"value_{i}") assert pool.size() == 100 pool.clear() assert pool.size() == 0 assert pool.keys() == [] def test_remove_all_keys_one_by_one(self): """测试逐个删除所有 key""" pool = MultiProcessingSharedPool() pool.clear() keys = [f"key_{i}" for i in range(50)] for key in keys: pool.put(key, f"value_for_{key}") for key in keys: assert pool.remove(key) is True assert pool.size() == 0