"""
NumPy ndarray support tests - verify that array data is shared correctly.

Round-trips arrays of various ranks, dtypes and special layouts through
``MultiProcessingSharedPool``, both within the test process and across
child processes.
"""

import multiprocessing
import queue
import time

import numpy as np
import pytest

from mpsp.mpsp import MultiProcessingSharedPool

# ==================== Helpers ====================

# Upper bound (seconds) a test waits on a child process, so a crashed or hung
# worker fails the test instead of blocking the whole run forever.
CHILD_TIMEOUT = 60


def _fresh_pool():
    """Return a pool handle after calling ``clear()`` on it, so each test starts from an emptied pool."""
    pool = MultiProcessingSharedPool()
    pool.clear()
    return pool


def _run_in_child(target, *args, timeout=CHILD_TIMEOUT):
    """Run ``target(*args, result_queue)`` in a child process and return its result.

    The worker receives a ``multiprocessing.Queue`` as its last argument and
    must put exactly one item on it.

    The result is read *before* joining the child: a process that has put
    data on a ``multiprocessing.Queue`` does not terminate until that data has
    been flushed to the underlying pipe, so joining first can deadlock once
    the payload exceeds the pipe buffer. Timeouts turn a crashed or hung
    worker into a test failure rather than an indefinite hang.

    Fails the current test (via ``pytest.fail``) if no result arrives within
    ``timeout`` seconds, if the child does not exit in time, or if it exits
    with a non-zero code.
    """
    result_queue = multiprocessing.Queue()
    proc = multiprocessing.Process(target=target, args=(*args, result_queue))
    proc.start()
    try:
        result = result_queue.get(timeout=timeout)
    except queue.Empty:
        proc.terminate()
        proc.join()
        pytest.fail(
            f"{target.__name__} produced no result within {timeout}s "
            f"(exitcode={proc.exitcode})"
        )
    proc.join(timeout)
    if proc.is_alive():
        proc.terminate()
        proc.join()
        pytest.fail(f"{target.__name__} did not exit within {timeout}s")
    if proc.exitcode != 0:
        pytest.fail(f"{target.__name__} exited with code {proc.exitcode}")
    return result


def worker_get_array(key, result_queue):
    """Child process: fetch the array stored under ``key`` and put it on ``result_queue``."""
    pool = MultiProcessingSharedPool()
    arr = pool.get(key)
    result_queue.put(arr)


def worker_modify_array(key, index, value, result_queue):
    """Child process: read the array, overwrite ``arr[index]`` with ``value``, report the old value.

    NOTE(review): this assumes ``pool.get`` returns a private copy of the
    array, so the write below only changes this child's copy and not the
    shared data - confirm against the pool implementation.
    """
    pool = MultiProcessingSharedPool()
    arr = pool.get(key)
    old_value = arr[index]
    if isinstance(old_value, np.ndarray):
        # Slice / partial indexing yields a view; snapshot it before the write
        # below would otherwise change the value we report.
        old_value = old_value.copy()
    arr[index] = value
    result_queue.put(old_value)


def worker_sum_array(key, result_queue):
    """Child process: put ``np.sum`` of the array stored under ``key`` on ``result_queue``."""
    pool = MultiProcessingSharedPool()
    arr = pool.get(key)
    result_queue.put(np.sum(arr))


def worker_check_array_properties(key, expected_shape, expected_dtype, result_queue):
    """Child process: report ``(shape matches, dtype matches, ndim)`` for the stored array."""
    pool = MultiProcessingSharedPool()
    arr = pool.get(key)
    result_queue.put(
        (arr.shape == expected_shape, arr.dtype == expected_dtype, arr.ndim)
    )


# ==================== Test classes ====================


class TestNDBasicOperations:
    """Basic put/get round trips for NumPy arrays of different ranks."""

    def test_1d_array(self):
        """Store and retrieve a 1-D array."""
        pool = _fresh_pool()
        arr = np.array([1, 2, 3, 4, 5])
        pool.put("1d_array", arr)
        retrieved = pool.get("1d_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert retrieved.dtype == arr.dtype

    def test_2d_array(self):
        """Store and retrieve a 2-D array."""
        pool = _fresh_pool()
        arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        pool.put("2d_array", arr)
        retrieved = pool.get("2d_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert retrieved.shape == (3, 3)

    def test_3d_array(self):
        """Store and retrieve a 3-D array."""
        pool = _fresh_pool()
        arr = np.arange(24).reshape(2, 3, 4)
        pool.put("3d_array", arr)
        retrieved = pool.get("3d_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert retrieved.shape == (2, 3, 4)

    def test_multidimensional_array(self):
        """Store and retrieve a 4-D float array."""
        pool = _fresh_pool()
        # Seeded so that any failure is reproducible.
        arr = np.random.default_rng(0).random((2, 3, 4, 5))
        pool.put("4d_array", arr)
        retrieved = pool.get("4d_array")
        # Storing data must be lossless, so require exact equality.
        np.testing.assert_array_equal(retrieved, arr)
        assert retrieved.shape == (2, 3, 4, 5)
        assert retrieved.dtype == arr.dtype


class TestNDDataTypes:
    """Round trips for NumPy arrays of different dtypes."""

    def test_integer_dtypes(self):
        """Signed and unsigned integer dtypes keep their dtype and values."""
        pool = _fresh_pool()
        dtypes = [
            np.int8,
            np.int16,
            np.int32,
            np.int64,
            np.uint8,
            np.uint16,
            np.uint32,
            np.uint64,
        ]
        for dtype in dtypes:
            arr = np.array([1, 2, 3], dtype=dtype)
            key = f"int_array_{dtype.__name__}"
            pool.put(key, arr)
            retrieved = pool.get(key)
            assert retrieved.dtype == dtype
            np.testing.assert_array_equal(retrieved, arr)

    def test_float_dtypes(self):
        """Floating-point dtypes keep their dtype and exact values."""
        pool = _fresh_pool()
        dtypes = [np.float32, np.float64]
        for dtype in dtypes:
            arr = np.array([1.1, 2.2, 3.3], dtype=dtype)
            key = f"float_array_{dtype.__name__}"
            pool.put(key, arr)
            retrieved = pool.get(key)
            assert retrieved.dtype == dtype
            np.testing.assert_array_equal(retrieved, arr)

    def test_complex_dtypes(self):
        """Complex dtypes keep their dtype and exact values."""
        pool = _fresh_pool()
        dtypes = [np.complex64, np.complex128]
        for dtype in dtypes:
            arr = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=dtype)
            key = f"complex_array_{dtype.__name__}"
            pool.put(key, arr)
            retrieved = pool.get(key)
            assert retrieved.dtype == dtype
            np.testing.assert_array_equal(retrieved, arr)

    def test_bool_dtype(self):
        """Boolean arrays round-trip unchanged."""
        pool = _fresh_pool()
        arr = np.array([True, False, True, True, False], dtype=np.bool_)
        pool.put("bool_array", arr)
        retrieved = pool.get("bool_array")
        assert retrieved.dtype == np.bool_
        np.testing.assert_array_equal(retrieved, arr)

    def test_string_dtype(self):
        """Unicode string arrays round-trip with their dtype (incl. max length)."""
        pool = _fresh_pool()
        arr = np.array(["hello", "world", "mpsp"])
        pool.put("string_array", arr)
        retrieved = pool.get("string_array")
        assert retrieved.dtype == arr.dtype
        np.testing.assert_array_equal(retrieved, arr)

    def test_object_dtype(self):
        """Object arrays holding heterogeneous Python values round-trip."""
        pool = _fresh_pool()
        # An object array can hold values of unrelated types, including a list.
        arr = np.array([1, "string", 3.14, [1, 2, 3]], dtype=object)
        pool.put("object_array", arr)
        retrieved = pool.get("object_array")
        assert retrieved.dtype == object
        np.testing.assert_array_equal(retrieved, arr)


class TestNDCrossProcess:
    """Arrays stored in the parent must be readable from a child process."""

    def test_1d_array_cross_process(self):
        """A child process retrieves the same 1-D array."""
        pool = _fresh_pool()
        arr = np.array([10, 20, 30, 40, 50])
        pool.put("shared_1d", arr)
        result = _run_in_child(worker_get_array, "shared_1d")
        np.testing.assert_array_equal(result, arr)

    def test_2d_array_cross_process(self):
        """A child process retrieves the same 2-D array."""
        pool = _fresh_pool()
        arr = np.array([[1, 2], [3, 4], [5, 6]])
        pool.put("shared_2d", arr)
        result = _run_in_child(worker_get_array, "shared_2d")
        np.testing.assert_array_equal(result, arr)

    def test_array_properties_cross_process(self):
        """Shape, dtype and ndim survive the trip to a child process."""
        pool = _fresh_pool()
        arr = np.arange(12).reshape(3, 4).astype(np.float32)
        pool.put("property_test", arr)
        shape_match, dtype_match, ndim = _run_in_child(
            worker_check_array_properties, "property_test", (3, 4), np.float32
        )
        assert shape_match
        assert dtype_match
        assert ndim == 2

    def test_array_operations_cross_process(self):
        """A child process can compute on the retrieved array."""
        pool = _fresh_pool()
        arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
        pool.put("sum_test", arr)
        result = _run_in_child(worker_sum_array, "sum_test")
        assert result == 55  # sum of 1..10


class TestNDSpecialArrays:
    """Edge-case arrays: empty, single-element, constant, non-finite, masked."""

    def test_empty_array(self):
        """An empty array round-trips with its shape."""
        pool = _fresh_pool()
        arr = np.array([])
        pool.put("empty_array", arr)
        retrieved = pool.get("empty_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert len(retrieved) == 0
        assert retrieved.shape == arr.shape

    def test_single_element_array(self):
        """A one-element array round-trips."""
        pool = _fresh_pool()
        arr = np.array([42])
        pool.put("single_element", arr)
        retrieved = pool.get("single_element")
        np.testing.assert_array_equal(retrieved, arr)
        assert retrieved[0] == 42

    def test_zeros_array(self):
        """An all-zeros array round-trips."""
        pool = _fresh_pool()
        arr = np.zeros((5, 5))
        pool.put("zeros_array", arr)
        retrieved = pool.get("zeros_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert np.all(retrieved == 0)

    def test_ones_array(self):
        """An all-ones array round-trips."""
        pool = _fresh_pool()
        arr = np.ones((3, 4))
        pool.put("ones_array", arr)
        retrieved = pool.get("ones_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert np.all(retrieved == 1)

    def test_eye_array(self):
        """An identity matrix round-trips."""
        pool = _fresh_pool()
        arr = np.eye(5)
        pool.put("eye_array", arr)
        retrieved = pool.get("eye_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert np.all(np.diag(retrieved) == 1)

    def test_nan_and_inf_array(self):
        """NaN and +/-Inf values survive the round trip, as do the finite ones."""
        pool = _fresh_pool()
        arr = np.array([1.0, np.nan, np.inf, -np.inf, 2.0])
        pool.put("special_values", arr)
        retrieved = pool.get("special_values")
        # assert_array_equal treats NaNs in matching positions as equal, so
        # this also covers the finite entries the checks below do not.
        np.testing.assert_array_equal(retrieved, arr)
        assert np.isnan(retrieved[1])
        assert np.isinf(retrieved[2]) and retrieved[2] > 0
        assert np.isinf(retrieved[3]) and retrieved[3] < 0

    def test_masked_array(self):
        """A masked array keeps its type, data and mask."""
        pool = _fresh_pool()
        data = np.array([1, 2, 3, 4, 5])
        mask = np.array([False, True, False, True, False])
        arr = np.ma.array(data, mask=mask)
        pool.put("masked_array", arr)
        retrieved = pool.get("masked_array")
        assert isinstance(retrieved, np.ma.MaskedArray)
        np.testing.assert_array_equal(retrieved.data, data)
        np.testing.assert_array_equal(retrieved.mask, mask)


class TestNDLargeArrays:
    """Larger arrays, in-process and across processes."""

    def test_large_1d_array(self):
        """A 10,000-element 1-D array round-trips."""
        pool = _fresh_pool()
        arr = np.arange(10000)
        pool.put("large_1d", arr)
        retrieved = pool.get("large_1d")
        np.testing.assert_array_equal(retrieved, arr)

    def test_large_2d_array(self):
        """A 1000x100 float array round-trips exactly."""
        pool = _fresh_pool()
        # Seeded so that any failure is reproducible.
        arr = np.random.default_rng(0).random((1000, 100))
        pool.put("large_2d", arr)
        retrieved = pool.get("large_2d")
        np.testing.assert_array_equal(retrieved, arr)

    def test_large_array_cross_process(self):
        """A child process sums a 100,000-element array to the parent's result."""
        pool = _fresh_pool()
        arr = np.arange(100000).reshape(1000, 100)
        pool.put("large_cross", arr)
        result = _run_in_child(worker_sum_array, "large_cross")
        # Compare against np.sum (not a closed form) so both sides use the
        # same accumulator dtype on every platform.
        expected_sum = np.sum(arr)
        assert result == expected_sum


class TestNDStructuredArrays:
    """Structured (record) arrays."""

    def test_structured_array(self):
        """A structured array keeps its compound dtype and records."""
        pool = _fresh_pool()
        dt = np.dtype([("name", "U10"), ("age", "i4"), ("weight", "f4")])
        arr = np.array(
            [("Alice", 25, 55.5), ("Bob", 30, 85.3), ("Charlie", 35, 75.0)], dtype=dt
        )
        pool.put("structured_array", arr)
        retrieved = pool.get("structured_array")
        assert retrieved.dtype == dt
        np.testing.assert_array_equal(retrieved, arr)

    def test_structured_array_cross_process(self):
        """A child process retrieves the same structured array."""
        pool = _fresh_pool()
        dt = np.dtype([("x", "f4"), ("y", "f4"), ("z", "f4")])
        arr = np.array([(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)], dtype=dt)
        pool.put("structured_cross", arr)
        result = _run_in_child(worker_get_array, "structured_cross")
        np.testing.assert_array_equal(result, arr)


class TestNDMatrixOperations:
    """Retrieved arrays are usable for linear algebra."""

    def test_matrix_multiplication(self):
        """Two retrieved matrices multiply to the expected product."""
        pool = _fresh_pool()
        A = np.array([[1, 2], [3, 4]])
        B = np.array([[5, 6], [7, 8]])
        pool.put("matrix_A", A)
        pool.put("matrix_B", B)
        retrieved_A = pool.get("matrix_A")
        retrieved_B = pool.get("matrix_B")
        result = retrieved_A @ retrieved_B
        expected = np.array([[19, 22], [43, 50]])
        np.testing.assert_array_equal(result, expected)

    def test_eigenvalue_computation(self):
        """Eigenvalues of a retrieved symmetric matrix are 2 and 6."""
        pool = _fresh_pool()
        arr = np.array([[4, 2], [2, 4]])  # symmetric
        pool.put("eigen_matrix", arr)
        retrieved = pool.get("eigen_matrix")
        # eigvalsh is the symmetric/Hermitian solver: real eigenvalues,
        # returned in ascending order.
        eigenvalues = np.linalg.eigvalsh(retrieved)
        np.testing.assert_allclose(eigenvalues, [2.0, 6.0])

    def test_svd_decomposition(self):
        """SVD of a retrieved matrix reconstructs the original."""
        pool = _fresh_pool()
        arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        pool.put("svd_matrix", arr)
        retrieved = pool.get("svd_matrix")
        U, S, Vh = np.linalg.svd(retrieved)
        # Floating-point reconstruction, so compare approximately.
        reconstructed = U @ np.diag(S) @ Vh
        np.testing.assert_array_almost_equal(reconstructed, arr)


class TestNDBroadcasting:
    """Retrieved arrays participate in broadcasting."""

    def test_broadcasting_operation(self):
        """Adding a retrieved 1-D array to a retrieved 2-D array broadcasts over rows."""
        pool = _fresh_pool()
        arr_2d = np.array([[1, 2, 3], [4, 5, 6]])
        arr_1d = np.array([10, 20, 30])
        pool.put("array_2d", arr_2d)
        pool.put("array_1d", arr_1d)
        retrieved_2d = pool.get("array_2d")
        retrieved_1d = pool.get("array_1d")
        result = retrieved_2d + retrieved_1d
        expected = np.array([[11, 22, 33], [14, 25, 36]])
        np.testing.assert_array_equal(result, expected)


class TestNDMixedTypes:
    """Arrays nested inside ordinary Python containers."""

    def test_array_in_dict(self):
        """A dict mixing arrays and plain values round-trips."""
        pool = _fresh_pool()
        data = {
            "matrix": np.array([[1, 2], [3, 4]]),
            "vector": np.array([1, 2, 3]),
            "scalar": 42,
            "name": "test",
        }
        pool.put("dict_with_arrays", data)
        retrieved = pool.get("dict_with_arrays")
        np.testing.assert_array_equal(retrieved["matrix"], data["matrix"])
        np.testing.assert_array_equal(retrieved["vector"], data["vector"])
        assert retrieved["scalar"] == 42
        assert retrieved["name"] == "test"

    def test_array_in_list(self):
        """A list mixing arrays and plain values round-trips."""
        pool = _fresh_pool()
        data = [np.array([1, 2, 3]), np.array([[4, 5], [6, 7]]), "string", 42]
        pool.put("list_with_arrays", data)
        retrieved = pool.get("list_with_arrays")
        np.testing.assert_array_equal(retrieved[0], data[0])
        np.testing.assert_array_equal(retrieved[1], data[1])
        assert retrieved[2] == "string"
        assert retrieved[3] == 42