Files
mpsp/test/test_numpy.py
2026-02-21 12:00:47 +08:00

570 lines
16 KiB
Python

"""
NumPy ndarray 支持测试 - 验证数组数据共享
"""
import multiprocessing
import time
import pytest
import numpy as np
from mpsp.mpsp import MultiProcessingSharedPool
# ==================== 辅助函数 ====================
def worker_get_array(key, result_queue):
    """Child-process entry point: look up ``key`` in the pool and report it.

    Args:
        key: Pool key to read.
        result_queue: Queue on which the fetched object is placed for the
            parent process.
    """
    result_queue.put(MultiProcessingSharedPool().get(key))
def worker_modify_array(key, index, value, result_queue):
    """Child-process entry point: read one element of a stored array.

    The element found at ``index`` is sent back through ``result_queue``.
    Despite the function name, nothing is written to the array here and
    ``value`` is accepted but never used.

    Args:
        key: Pool key of the array to read.
        index: Index passed to ``arr[index]``; when it is a tuple the
            selected element is copied before being queued.
        value: Unused.
        result_queue: Queue on which the element is placed.
    """
    shared = MultiProcessingSharedPool()
    stored = shared.get(key)
    element = stored[index]
    if isinstance(index, tuple):
        element = element.copy()
    # NOTE(review): the original comment states that pool.get() hands back a
    # copy of the array, so any write here would not reach the shared data --
    # confirm against the pool implementation.
    result_queue.put(element)
def worker_sum_array(key, result_queue):
    """Child-process entry point: sum the array stored under ``key``.

    Args:
        key: Pool key of the array to read.
        result_queue: Queue on which ``np.sum`` of the array is placed.
    """
    stored = MultiProcessingSharedPool().get(key)
    total = np.sum(stored)
    result_queue.put(total)
def worker_check_array_properties(key, expected_shape, expected_dtype, result_queue):
    """Child-process entry point: compare a stored array's metadata.

    Queues a 3-tuple ``(shape_matches, dtype_matches, ndim)`` for the parent.

    Args:
        key: Pool key of the array to inspect.
        expected_shape: Shape the array is compared against.
        expected_dtype: Dtype the array is compared against.
        result_queue: Queue on which the result tuple is placed.
    """
    fetched = MultiProcessingSharedPool().get(key)
    shape_ok = fetched.shape == expected_shape
    dtype_ok = fetched.dtype == expected_dtype
    result_queue.put((shape_ok, dtype_ok, fetched.ndim))
# ==================== 测试类 ====================
class TestNDBasicOperations:
    """Basic put/get round trips for NumPy arrays of increasing rank."""

    @staticmethod
    def _round_trip(key, source):
        """Store ``source`` under ``key`` in a freshly cleared pool and read it back."""
        shared = MultiProcessingSharedPool()
        shared.clear()
        shared.put(key, source)
        return shared.get(key)

    def test_1d_array(self):
        """A 1-D array comes back with equal contents and the same dtype."""
        source = np.array([1, 2, 3, 4, 5])
        fetched = self._round_trip("1d_array", source)
        np.testing.assert_array_equal(fetched, source)
        assert fetched.dtype == source.dtype

    def test_2d_array(self):
        """A 2-D array comes back with equal contents and shape (3, 3)."""
        source = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        fetched = self._round_trip("2d_array", source)
        np.testing.assert_array_equal(fetched, source)
        assert fetched.shape == (3, 3)

    def test_3d_array(self):
        """A 3-D array comes back with equal contents and shape (2, 3, 4)."""
        source = np.arange(24).reshape(2, 3, 4)
        fetched = self._round_trip("3d_array", source)
        np.testing.assert_array_equal(fetched, source)
        assert fetched.shape == (2, 3, 4)

    def test_multidimensional_array(self):
        """A random 4-D array comes back almost equal and with its shape intact."""
        # 4-D array
        source = np.random.rand(2, 3, 4, 5)
        fetched = self._round_trip("4d_array", source)
        np.testing.assert_array_almost_equal(fetched, source)
        assert fetched.shape == (2, 3, 4, 5)
class TestNDDataTypes:
    """Round trips for NumPy arrays of various element types."""

    def test_integer_dtypes(self):
        """Every signed and unsigned integer width keeps its dtype and values."""
        shared = MultiProcessingSharedPool()
        shared.clear()
        signed = [np.int8, np.int16, np.int32, np.int64]
        unsigned = [np.uint8, np.uint16, np.uint32, np.uint64]
        for int_type in signed + unsigned:
            source = np.array([1, 2, 3], dtype=int_type)
            slot = f"int_array_{int_type.__name__}"
            shared.put(slot, source)
            fetched = shared.get(slot)
            assert fetched.dtype == int_type
            np.testing.assert_array_equal(fetched, source)

    def test_float_dtypes(self):
        """float32 and float64 arrays keep their dtype and (almost) their values."""
        shared = MultiProcessingSharedPool()
        shared.clear()
        for float_type in (np.float32, np.float64):
            source = np.array([1.1, 2.2, 3.3], dtype=float_type)
            slot = f"float_array_{float_type.__name__}"
            shared.put(slot, source)
            fetched = shared.get(slot)
            assert fetched.dtype == float_type
            np.testing.assert_array_almost_equal(fetched, source)

    def test_complex_dtypes(self):
        """complex64 and complex128 arrays keep their dtype and (almost) their values."""
        shared = MultiProcessingSharedPool()
        shared.clear()
        for complex_type in (np.complex64, np.complex128):
            source = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=complex_type)
            slot = f"complex_array_{complex_type.__name__}"
            shared.put(slot, source)
            fetched = shared.get(slot)
            assert fetched.dtype == complex_type
            np.testing.assert_array_almost_equal(fetched, source)

    def test_bool_dtype(self):
        """A boolean array keeps its dtype and values."""
        shared = MultiProcessingSharedPool()
        shared.clear()
        flags = np.array([True, False, True, True, False], dtype=np.bool_)
        shared.put("bool_array", flags)
        fetched = shared.get("bool_array")
        assert fetched.dtype == np.bool_
        np.testing.assert_array_equal(fetched, flags)

    def test_string_dtype(self):
        """An array of strings keeps its values."""
        shared = MultiProcessingSharedPool()
        shared.clear()
        # Unicode strings
        words = np.array(["hello", "world", "mpsp"])
        shared.put("string_array", words)
        fetched = shared.get("string_array")
        np.testing.assert_array_equal(fetched, words)

    def test_object_dtype(self):
        """An object array holding mixed Python values keeps dtype and values."""
        shared = MultiProcessingSharedPool()
        shared.clear()
        # An object array can hold values of different types.
        mixed = np.array([1, "string", 3.14, [1, 2, 3]], dtype=object)
        shared.put("object_array", mixed)
        fetched = shared.get("object_array")
        assert fetched.dtype == object
        np.testing.assert_array_equal(fetched, mixed)
class TestNDCrossProcess:
    """Cross-process sharing of NumPy arrays through the pool."""

    # Upper bound, in seconds, for a worker to deliver its result and to exit.
    _WORKER_TIMEOUT = 30

    @classmethod
    def _run_worker(cls, target, *args):
        """Run ``target(*args, queue)`` in a child process and return its result.

        The queue is drained *before* the child is joined: joining a process
        that still has unflushed queue data can deadlock, and an untimed
        ``get()`` would hang the whole test run if the worker died before
        producing anything.

        Args:
            target: Worker function; it receives ``*args`` followed by the
                result queue.
            *args: Positional arguments forwarded to ``target``.

        Returns:
            The single object the worker put on the queue.

        Raises:
            AssertionError: If no result arrives within the timeout or the
                worker exits with a non-zero exit code.
        """
        from queue import Empty

        result_queue = multiprocessing.Queue()
        proc = multiprocessing.Process(target=target, args=(*args, result_queue))
        proc.start()
        try:
            result = result_queue.get(timeout=cls._WORKER_TIMEOUT)
        except Empty:
            raise AssertionError(
                f"worker {target.__name__} produced no result within "
                f"{cls._WORKER_TIMEOUT}s"
            ) from None
        finally:
            proc.join(timeout=cls._WORKER_TIMEOUT)
            if proc.is_alive():
                # Never leave a stuck child behind for later tests.
                proc.terminate()
                proc.join()
        assert proc.exitcode == 0, f"worker exited with code {proc.exitcode}"
        return result

    def test_1d_array_cross_process(self):
        """A 1-D array stored by the parent is readable from a child process."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        arr = np.array([10, 20, 30, 40, 50])
        pool.put("shared_1d", arr)
        result = self._run_worker(worker_get_array, "shared_1d")
        np.testing.assert_array_equal(result, arr)

    def test_2d_array_cross_process(self):
        """A 2-D array stored by the parent is readable from a child process."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        arr = np.array([[1, 2], [3, 4], [5, 6]])
        pool.put("shared_2d", arr)
        result = self._run_worker(worker_get_array, "shared_2d")
        np.testing.assert_array_equal(result, arr)

    def test_array_properties_cross_process(self):
        """Shape, dtype and ndim seen by a child match what the parent stored."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        arr = np.arange(12).reshape(3, 4).astype(np.float32)
        pool.put("property_test", arr)
        shape_match, dtype_match, ndim = self._run_worker(
            worker_check_array_properties, "property_test", (3, 4), np.float32
        )
        assert shape_match
        assert dtype_match
        assert ndim == 2

    def test_array_operations_cross_process(self):
        """A child process can compute on the array it reads from the pool."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
        pool.put("sum_test", arr)
        result = self._run_worker(worker_sum_array, "sum_test")
        assert result == 55  # sum of 1..10
class TestNDSpecialArrays:
    """Round trips for edge-case NumPy arrays."""

    def test_empty_array(self):
        """An empty array comes back empty."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        arr = np.array([])
        pool.put("empty_array", arr)
        retrieved = pool.get("empty_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert len(retrieved) == 0

    def test_single_element_array(self):
        """A one-element array keeps its single value."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        arr = np.array([42])
        pool.put("single_element", arr)
        retrieved = pool.get("single_element")
        np.testing.assert_array_equal(retrieved, arr)
        assert retrieved[0] == 42

    def test_zeros_array(self):
        """An all-zeros matrix comes back all zeros."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        arr = np.zeros((5, 5))
        pool.put("zeros_array", arr)
        retrieved = pool.get("zeros_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert np.all(retrieved == 0)

    def test_ones_array(self):
        """An all-ones matrix comes back all ones."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        arr = np.ones((3, 4))
        pool.put("ones_array", arr)
        retrieved = pool.get("ones_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert np.all(retrieved == 1)

    def test_eye_array(self):
        """An identity matrix comes back with ones on its diagonal."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        arr = np.eye(5)
        pool.put("eye_array", arr)
        retrieved = pool.get("eye_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert np.all(np.diag(retrieved) == 1)

    def test_nan_and_inf_array(self):
        """NaN, +Inf and -Inf survive the round trip alongside finite values."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        arr = np.array([1.0, np.nan, np.inf, -np.inf, 2.0])
        pool.put("special_values", arr)
        retrieved = pool.get("special_values")
        # Whole-array comparison also covers the finite elements and the
        # shape; assert_array_equal treats NaNs in matching positions as equal.
        np.testing.assert_array_equal(retrieved, arr)
        assert np.isnan(retrieved[1])
        assert np.isinf(retrieved[2]) and retrieved[2] > 0
        assert np.isinf(retrieved[3]) and retrieved[3] < 0

    def test_masked_array(self):
        """A masked array keeps both its data and its mask."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        data = np.array([1, 2, 3, 4, 5])
        mask = np.array([False, True, False, True, False])
        arr = np.ma.array(data, mask=mask)
        pool.put("masked_array", arr)
        retrieved = pool.get("masked_array")
        np.testing.assert_array_equal(retrieved.data, data)
        np.testing.assert_array_equal(retrieved.mask, mask)
class TestNDLargeArrays:
    """Round trips for larger NumPy arrays."""

    def test_large_1d_array(self):
        """A 10,000-element 1-D array comes back unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # Array with 10,000 elements
        arr = np.arange(10000)
        pool.put("large_1d", arr)
        retrieved = pool.get("large_1d")
        np.testing.assert_array_equal(retrieved, arr)

    def test_large_2d_array(self):
        """A random 1000x100 array comes back almost equal."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        # 1000x100 array
        arr = np.random.rand(1000, 100)
        pool.put("large_2d", arr)
        retrieved = pool.get("large_2d")
        np.testing.assert_array_almost_equal(retrieved, arr)

    def test_large_array_cross_process(self):
        """A child process sums a 1000x100 array read from the pool."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        arr = np.arange(100000).reshape(1000, 100)
        pool.put("large_cross", arr)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_sum_array, args=("large_cross", result_queue)
        )
        p.start()
        try:
            # Read the result before joining, and with a timeout: joining
            # first can deadlock on unflushed queue data, and an untimed
            # get() would hang forever if the worker crashed.
            result = result_queue.get(timeout=30)
        finally:
            p.join(timeout=30)
            if p.is_alive():
                # Never leave a stuck child behind for later tests.
                p.terminate()
                p.join()
        assert p.exitcode == 0, f"worker exited with code {p.exitcode}"
        expected_sum = np.sum(arr)
        assert result == expected_sum
class TestNDStructuredArrays:
    """Round trips for NumPy structured (record) arrays."""

    def test_structured_array(self):
        """A structured array keeps its compound dtype and its records."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        dt = np.dtype([("name", "U10"), ("age", "i4"), ("weight", "f4")])
        arr = np.array(
            [("Alice", 25, 55.5), ("Bob", 30, 85.3), ("Charlie", 35, 75.0)], dtype=dt
        )
        pool.put("structured_array", arr)
        retrieved = pool.get("structured_array")
        assert retrieved.dtype == dt
        np.testing.assert_array_equal(retrieved, arr)

    def test_structured_array_cross_process(self):
        """A structured array stored by the parent is readable from a child."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        dt = np.dtype([("x", "f4"), ("y", "f4"), ("z", "f4")])
        arr = np.array([(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)], dtype=dt)
        pool.put("structured_cross", arr)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_array, args=("structured_cross", result_queue)
        )
        p.start()
        try:
            # Read the result before joining, and with a timeout: joining
            # first can deadlock on unflushed queue data, and an untimed
            # get() would hang forever if the worker crashed.
            result = result_queue.get(timeout=30)
        finally:
            p.join(timeout=30)
            if p.is_alive():
                # Never leave a stuck child behind for later tests.
                p.terminate()
                p.join()
        assert p.exitcode == 0, f"worker exited with code {p.exitcode}"
        np.testing.assert_array_equal(result, arr)
class TestNDMatrixOperations:
    """Linear-algebra operations performed on arrays read back from the pool."""

    def test_matrix_multiplication(self):
        """Two stored matrices multiply to the expected product after retrieval."""
        shared = MultiProcessingSharedPool()
        shared.clear()
        left = np.array([[1, 2], [3, 4]])
        right = np.array([[5, 6], [7, 8]])
        for slot, matrix in (("matrix_A", left), ("matrix_B", right)):
            shared.put(slot, matrix)
        fetched_left = shared.get("matrix_A")
        fetched_right = shared.get("matrix_B")
        product = np.dot(fetched_left, fetched_right)
        np.testing.assert_array_equal(product, np.array([[19, 22], [43, 50]]))

    def test_eigenvalue_computation(self):
        """Eigenvalues of a retrieved symmetric matrix are 2 and 6."""
        shared = MultiProcessingSharedPool()
        shared.clear()
        # Symmetric matrix
        symmetric = np.array([[4, 2], [2, 4]])
        shared.put("eigen_matrix", symmetric)
        fetched = shared.get("eigen_matrix")
        eigenvalues, _ = np.linalg.eig(fetched)
        # The eigenvalues should be 6 and 2.
        ordered = sorted(eigenvalues)
        assert np.allclose(ordered, [2, 6])

    def test_svd_decomposition(self):
        """The SVD factors of a retrieved matrix multiply back to the original."""
        shared = MultiProcessingSharedPool()
        shared.clear()
        source = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        shared.put("svd_matrix", source)
        fetched = shared.get("svd_matrix")
        U, S, Vh = np.linalg.svd(fetched)
        # Check the factorisation by rebuilding the matrix.
        rebuilt = U @ np.diag(S) @ Vh
        np.testing.assert_array_almost_equal(rebuilt, source)
class TestNDBroadcasting:
    """Broadcasting between arrays read back from the pool."""

    def test_broadcasting_operation(self):
        """A retrieved (2, 3) array plus a retrieved (3,) array broadcasts row-wise."""
        shared = MultiProcessingSharedPool()
        shared.clear()
        grid = np.array([[1, 2, 3], [4, 5, 6]])
        row = np.array([10, 20, 30])
        shared.put("array_2d", grid)
        shared.put("array_1d", row)
        fetched_grid = shared.get("array_2d")
        fetched_row = shared.get("array_1d")
        total = fetched_grid + fetched_row
        np.testing.assert_array_equal(
            total, np.array([[11, 22, 33], [14, 25, 36]])
        )
class TestNDMixedTypes:
    """Round trips for containers that mix arrays with plain Python values."""

    def test_array_in_dict(self):
        """A dict holding arrays, an int and a str comes back with equal members."""
        shared = MultiProcessingSharedPool()
        shared.clear()
        payload = {
            "matrix": np.array([[1, 2], [3, 4]]),
            "vector": np.array([1, 2, 3]),
            "scalar": 42,
            "name": "test",
        }
        shared.put("dict_with_arrays", payload)
        fetched = shared.get("dict_with_arrays")
        for field in ("matrix", "vector"):
            np.testing.assert_array_equal(fetched[field], payload[field])
        assert fetched["scalar"] == 42
        assert fetched["name"] == "test"

    def test_array_in_list(self):
        """A list holding arrays, a str and an int comes back with equal members."""
        shared = MultiProcessingSharedPool()
        shared.clear()
        payload = [np.array([1, 2, 3]), np.array([[4, 5], [6, 7]]), "string", 42]
        shared.put("list_with_arrays", payload)
        fetched = shared.get("list_with_arrays")
        for position in (0, 1):
            np.testing.assert_array_equal(fetched[position], payload[position])
        assert fetched[2] == "string"
        assert fetched[3] == 42