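"""
NumPy ndarray 支持测试 - 验证数组数据共享

NumPy ndarray support tests: verify that arrays (and containers holding
arrays) stored in MultiProcessingSharedPool round-trip correctly, both in the
same process and when read from child processes.
"""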
|
|
|
|
import multiprocessing
|
|
import time
|
|
import pytest
|
|
import numpy as np
|
|
from mpsp.mpsp import MultiProcessingSharedPool
|
|
|
|
|
|
# ==================== 辅助函数 ====================
|
|
|
|
|
|
def worker_get_array(key, result_queue):
    """Child-process helper: fetch ``key`` from the shared pool and send it back.

    Args:
        key: Pool key of the value to read.
        result_queue: ``multiprocessing.Queue`` that receives the fetched value.
    """
    shared = MultiProcessingSharedPool()
    result_queue.put(shared.get(key))
|
|
|
|
|
|
def worker_modify_array(key, index, value, result_queue):
    """Child-process helper: read an array, overwrite ``arr[index]``, report the old value.

    The value stored at ``index`` is captured *before* the write and sent
    back through ``result_queue``. The write touches only this process's
    array object. If the retrieved array is read-only, the write goes to a
    private copy instead. Presumably ``pool.get`` already returns a copy, so
    the pool itself is not changed (TODO confirm against the pool
    implementation).

    Args:
        key: Pool key of the array to read.
        index: Any valid NumPy index (int, slice, tuple, ...).
        value: Value assigned to ``arr[index]`` after the old value is captured.
        result_queue: ``multiprocessing.Queue`` that receives the old value.
    """
    pool = MultiProcessingSharedPool()
    arr = pool.get(key)

    old_value = arr[index]
    # Basic indexing (an int on an N-d array, a slice, a partial tuple) yields
    # a *view*. Copy it so the write below cannot change the value reported.
    if isinstance(old_value, np.ndarray):
        old_value = old_value.copy()

    # Never write into a read-only buffer; fall back to a private copy.
    if not arr.flags.writeable:
        arr = arr.copy()
    arr[index] = value

    result_queue.put(old_value)
|
|
|
|
|
|
def worker_sum_array(key, result_queue):
    """Child-process helper: send back ``np.sum`` of the array stored under ``key``.

    Args:
        key: Pool key of the array to reduce.
        result_queue: ``multiprocessing.Queue`` that receives the sum.
    """
    total = np.sum(MultiProcessingSharedPool().get(key))
    result_queue.put(total)
|
|
|
|
|
|
def worker_check_array_properties(key, expected_shape, expected_dtype, result_queue):
    """Child-process helper: compare the stored array's shape and dtype to expectations.

    Args:
        key: Pool key of the array to inspect.
        expected_shape: Shape the array should have (compared with ``==``).
        expected_dtype: dtype the array should have (compared with ``==``).
        result_queue: ``multiprocessing.Queue`` that receives the 3-tuple
            ``(shape_matches, dtype_matches, ndim)``.
    """
    arr = MultiProcessingSharedPool().get(key)
    shape_ok = arr.shape == expected_shape
    dtype_ok = arr.dtype == expected_dtype
    result_queue.put((shape_ok, dtype_ok, arr.ndim))
|
|
|
|
|
|
# ==================== 测试类 ====================
|
|
|
|
|
|
class TestNDBasicOperations:
    """Round-trip arrays of increasing rank through the shared pool."""

    @staticmethod
    def _fresh_pool():
        """Return a pool handle whose contents have just been cleared."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        return pool

    def test_1d_array(self):
        """A 1-D integer array comes back with equal values and dtype."""
        pool = self._fresh_pool()

        arr = np.array([1, 2, 3, 4, 5])
        pool.put("1d_array", arr)

        retrieved = pool.get("1d_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert retrieved.dtype == arr.dtype

    def test_2d_array(self):
        """A 3x3 array keeps its values, shape and dtype."""
        pool = self._fresh_pool()

        arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        pool.put("2d_array", arr)

        retrieved = pool.get("2d_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert retrieved.shape == (3, 3)
        assert retrieved.dtype == arr.dtype

    def test_3d_array(self):
        """A (2, 3, 4) array keeps its values, shape and dtype."""
        pool = self._fresh_pool()

        arr = np.arange(24).reshape(2, 3, 4)
        pool.put("3d_array", arr)

        retrieved = pool.get("3d_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert retrieved.shape == (2, 3, 4)
        assert retrieved.dtype == arr.dtype

    def test_multidimensional_array(self):
        """A 4-D float array keeps its values, shape and dtype."""
        pool = self._fresh_pool()

        # Seeded generator so that any failure is reproducible run-to-run.
        arr = np.random.default_rng(0).random((2, 3, 4, 5))
        pool.put("4d_array", arr)

        retrieved = pool.get("4d_array")
        # Storage must be lossless, so compare exactly. An approximate check
        # would hide corruption in low-order bits.
        np.testing.assert_array_equal(retrieved, arr)
        assert retrieved.shape == (2, 3, 4, 5)
        assert retrieved.dtype == arr.dtype
|
|
|
|
|
|
class TestNDDataTypes:
    """Round-trip arrays of each common NumPy dtype family through the pool."""

    @staticmethod
    def _fresh_pool():
        """Return a pool handle whose contents have just been cleared."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        return pool

    @staticmethod
    def _assert_round_trip(pool, key, arr):
        """Store ``arr`` under ``key``, read it back, and require identical dtype and values.

        ``key`` is embedded in every failure message. Inside a loop over
        dtypes, a failure then names the dtype that broke. Values are compared
        exactly because storage must be lossless.

        Returns:
            The retrieved array, for any extra assertions by the caller.
        """
        pool.put(key, arr)
        retrieved = pool.get(key)
        assert retrieved.dtype == arr.dtype, (
            f"{key}: dtype {retrieved.dtype} != {arr.dtype}"
        )
        np.testing.assert_array_equal(retrieved, arr, err_msg=key)
        return retrieved

    def test_integer_dtypes(self):
        """Signed and unsigned integers of every width keep dtype and values."""
        pool = self._fresh_pool()

        dtypes = [
            np.int8,
            np.int16,
            np.int32,
            np.int64,
            np.uint8,
            np.uint16,
            np.uint32,
            np.uint64,
        ]
        for dtype in dtypes:
            arr = np.array([1, 2, 3], dtype=dtype)
            self._assert_round_trip(pool, f"int_array_{dtype.__name__}", arr)

    def test_float_dtypes(self):
        """Single- and double-precision floats keep dtype and exact values."""
        pool = self._fresh_pool()

        for dtype in (np.float32, np.float64):
            arr = np.array([1.1, 2.2, 3.3], dtype=dtype)
            self._assert_round_trip(pool, f"float_array_{dtype.__name__}", arr)

    def test_complex_dtypes(self):
        """Single- and double-precision complex numbers keep dtype and exact values."""
        pool = self._fresh_pool()

        for dtype in (np.complex64, np.complex128):
            arr = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=dtype)
            self._assert_round_trip(pool, f"complex_array_{dtype.__name__}", arr)

    def test_bool_dtype(self):
        """A boolean array keeps dtype ``np.bool_`` and its values."""
        pool = self._fresh_pool()

        arr = np.array([True, False, True, True, False], dtype=np.bool_)
        retrieved = self._assert_round_trip(pool, "bool_array", arr)
        assert retrieved.dtype == np.bool_

    def test_string_dtype(self):
        """A fixed-width Unicode string array keeps its dtype (e.g. ``<U5``) and values."""
        pool = self._fresh_pool()

        arr = np.array(["hello", "world", "mpsp"])
        self._assert_round_trip(pool, "string_array", arr)

    def test_object_dtype(self):
        """An object array holding heterogeneous Python values keeps dtype and values."""
        pool = self._fresh_pool()

        # Object arrays can hold values of different Python types.
        arr = np.array([1, "string", 3.14, [1, 2, 3]], dtype=object)
        retrieved = self._assert_round_trip(pool, "object_array", arr)
        assert retrieved.dtype == object
|
|
|
|
|
|
class TestNDCrossProcess:
    """Verify that arrays stored by the parent are readable in child processes."""

    # Upper bound (seconds) on how long a child may take. This keeps a crashed
    # or stuck child from hanging the whole test run.
    CHILD_TIMEOUT = 30

    @staticmethod
    def _fresh_pool():
        """Return a pool handle whose contents have just been cleared."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        return pool

    @classmethod
    def _run_child(cls, target, *args):
        """Run ``target(*args, result_queue)`` in a child process; return what it queued.

        The result is read *before* the child is joined. A process that has
        put data on a ``multiprocessing.Queue`` may not exit until that data
        has been consumed, so joining first can deadlock. The read is bounded
        by ``CHILD_TIMEOUT``: a child that dies without queuing anything fails
        the test with ``queue.Empty`` instead of hanging it.

        Raises:
            queue.Empty: if the child queues nothing within ``CHILD_TIMEOUT``.
            AssertionError: if the child exits with a non-zero code.
        """
        result_queue = multiprocessing.Queue()
        proc = multiprocessing.Process(target=target, args=(*args, result_queue))
        proc.start()
        try:
            result = result_queue.get(timeout=cls.CHILD_TIMEOUT)
        finally:
            proc.join(timeout=cls.CHILD_TIMEOUT)
            if proc.is_alive():
                proc.terminate()
                proc.join()
        assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
        return result

    def test_1d_array_cross_process(self):
        """A 1-D array put by the parent is read back unchanged by a child."""
        pool = self._fresh_pool()

        arr = np.array([10, 20, 30, 40, 50])
        pool.put("shared_1d", arr)

        result = self._run_child(worker_get_array, "shared_1d")
        np.testing.assert_array_equal(result, arr)

    def test_2d_array_cross_process(self):
        """A 2-D array put by the parent is read back unchanged by a child."""
        pool = self._fresh_pool()

        arr = np.array([[1, 2], [3, 4], [5, 6]])
        pool.put("shared_2d", arr)

        result = self._run_child(worker_get_array, "shared_2d")
        np.testing.assert_array_equal(result, arr)

    def test_array_properties_cross_process(self):
        """Shape, dtype and ndim seen by a child match what the parent stored."""
        pool = self._fresh_pool()

        arr = np.arange(12).reshape(3, 4).astype(np.float32)
        pool.put("property_test", arr)

        shape_match, dtype_match, ndim = self._run_child(
            worker_check_array_properties, "property_test", (3, 4), np.float32
        )
        assert shape_match
        assert dtype_match
        assert ndim == 2

    def test_array_operations_cross_process(self):
        """A child can compute on the stored array (sum of 1..10 is 55)."""
        pool = self._fresh_pool()

        arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
        pool.put("sum_test", arr)

        result = self._run_child(worker_sum_array, "sum_test")
        assert result == 55  # sum of 1..10
|
|
|
|
|
|
class TestNDSpecialArrays:
    """Edge-case arrays: empty, single-element, constant, identity, non-finite, masked."""

    @staticmethod
    def _fresh_pool():
        """Return a pool handle whose contents have just been cleared."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        return pool

    def test_empty_array(self):
        """A zero-length array survives with its shape and dtype."""
        pool = self._fresh_pool()

        arr = np.array([])
        pool.put("empty_array", arr)

        retrieved = pool.get("empty_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert retrieved.shape == (0,)
        assert retrieved.dtype == arr.dtype

    def test_single_element_array(self):
        """A one-element array keeps its single value."""
        pool = self._fresh_pool()

        arr = np.array([42])
        pool.put("single_element", arr)

        retrieved = pool.get("single_element")
        np.testing.assert_array_equal(retrieved, arr)
        assert retrieved[0] == 42

    def test_zeros_array(self):
        """A 5x5 array of zeros stays all zeros."""
        pool = self._fresh_pool()

        arr = np.zeros((5, 5))
        pool.put("zeros_array", arr)

        retrieved = pool.get("zeros_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert np.all(retrieved == 0)

    def test_ones_array(self):
        """A 3x4 array of ones stays all ones."""
        pool = self._fresh_pool()

        arr = np.ones((3, 4))
        pool.put("ones_array", arr)

        retrieved = pool.get("ones_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert np.all(retrieved == 1)

    def test_eye_array(self):
        """A 5x5 identity matrix keeps its unit diagonal."""
        pool = self._fresh_pool()

        arr = np.eye(5)
        pool.put("eye_array", arr)

        retrieved = pool.get("eye_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert np.all(np.diag(retrieved) == 1)

    def test_nan_and_inf_array(self):
        """NaN, +Inf and -Inf survive, and so do the finite neighbours."""
        pool = self._fresh_pool()

        arr = np.array([1.0, np.nan, np.inf, -np.inf, 2.0])
        pool.put("special_values", arr)

        retrieved = pool.get("special_values")
        # assert_array_equal treats NaNs in matching positions as equal, so
        # this also pins the shape and the finite entries (indices 0 and 4).
        np.testing.assert_array_equal(retrieved, arr)
        assert np.isnan(retrieved[1])
        assert np.isinf(retrieved[2]) and retrieved[2] > 0
        assert np.isinf(retrieved[3]) and retrieved[3] < 0

    def test_masked_array(self):
        """A masked array keeps both its data and its mask."""
        pool = self._fresh_pool()

        data = np.array([1, 2, 3, 4, 5])
        mask = np.array([False, True, False, True, False])
        arr = np.ma.array(data, mask=mask)

        pool.put("masked_array", arr)

        retrieved = pool.get("masked_array")
        # Check the type first. Otherwise a plain ndarray (mask lost) would
        # fail later with an unhelpful AttributeError on ``.mask``.
        assert isinstance(retrieved, np.ma.MaskedArray)
        np.testing.assert_array_equal(retrieved.data, data)
        np.testing.assert_array_equal(retrieved.mask, mask)
|
|
|
|
|
|
class TestNDLargeArrays:
    """Larger arrays: stored and read locally, and reduced in a child process."""

    # Upper bound (seconds) on how long a child may take. This keeps a crashed
    # or stuck child from hanging the whole test run.
    CHILD_TIMEOUT = 30

    @staticmethod
    def _fresh_pool():
        """Return a pool handle whose contents have just been cleared."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        return pool

    @classmethod
    def _run_child(cls, target, *args):
        """Run ``target(*args, result_queue)`` in a child process; return what it queued.

        The result is read *before* the child is joined. A process that has
        put data on a ``multiprocessing.Queue`` may not exit until that data
        has been consumed, so joining first can deadlock. The read is bounded
        by ``CHILD_TIMEOUT``: a child that dies without queuing anything fails
        the test with ``queue.Empty`` instead of hanging it.

        Raises:
            queue.Empty: if the child queues nothing within ``CHILD_TIMEOUT``.
            AssertionError: if the child exits with a non-zero code.
        """
        result_queue = multiprocessing.Queue()
        proc = multiprocessing.Process(target=target, args=(*args, result_queue))
        proc.start()
        try:
            result = result_queue.get(timeout=cls.CHILD_TIMEOUT)
        finally:
            proc.join(timeout=cls.CHILD_TIMEOUT)
            if proc.is_alive():
                proc.terminate()
                proc.join()
        assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
        return result

    def test_large_1d_array(self):
        """A 10,000-element integer array round-trips unchanged."""
        pool = self._fresh_pool()

        arr = np.arange(10000)
        pool.put("large_1d", arr)

        retrieved = pool.get("large_1d")
        np.testing.assert_array_equal(retrieved, arr)

    def test_large_2d_array(self):
        """A 1000x100 float array round-trips bit-for-bit."""
        pool = self._fresh_pool()

        # Seeded generator so that any failure is reproducible run-to-run.
        arr = np.random.default_rng(0).random((1000, 100))
        pool.put("large_2d", arr)

        retrieved = pool.get("large_2d")
        # Storage must be lossless, so compare exactly rather than approximately.
        np.testing.assert_array_equal(retrieved, arr)

    def test_large_array_cross_process(self):
        """A child reading a 1000x100 array computes the same sum as the parent."""
        pool = self._fresh_pool()

        arr = np.arange(100000).reshape(1000, 100)
        pool.put("large_cross", arr)

        result = self._run_child(worker_sum_array, "large_cross")
        assert result == np.sum(arr)
|
|
|
|
|
|
class TestNDStructuredArrays:
    """Structured (record) arrays, locally and across processes."""

    # Upper bound (seconds) on how long a child may take. This keeps a crashed
    # or stuck child from hanging the whole test run.
    CHILD_TIMEOUT = 30

    @staticmethod
    def _fresh_pool():
        """Return a pool handle whose contents have just been cleared."""
        pool = MultiProcessingSharedPool()
        pool.clear()
        return pool

    @classmethod
    def _run_child(cls, target, *args):
        """Run ``target(*args, result_queue)`` in a child process; return what it queued.

        The result is read *before* the child is joined. A process that has
        put data on a ``multiprocessing.Queue`` may not exit until that data
        has been consumed, so joining first can deadlock. The read is bounded
        by ``CHILD_TIMEOUT``: a child that dies without queuing anything fails
        the test with ``queue.Empty`` instead of hanging it.

        Raises:
            queue.Empty: if the child queues nothing within ``CHILD_TIMEOUT``.
            AssertionError: if the child exits with a non-zero code.
        """
        result_queue = multiprocessing.Queue()
        proc = multiprocessing.Process(target=target, args=(*args, result_queue))
        proc.start()
        try:
            result = result_queue.get(timeout=cls.CHILD_TIMEOUT)
        finally:
            proc.join(timeout=cls.CHILD_TIMEOUT)
            if proc.is_alive():
                proc.terminate()
                proc.join()
        assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
        return result

    def test_structured_array(self):
        """A record array with string, int and float fields keeps dtype and values."""
        pool = self._fresh_pool()

        dt = np.dtype([("name", "U10"), ("age", "i4"), ("weight", "f4")])
        arr = np.array(
            [("Alice", 25, 55.5), ("Bob", 30, 85.3), ("Charlie", 35, 75.0)], dtype=dt
        )

        pool.put("structured_array", arr)

        retrieved = pool.get("structured_array")
        assert retrieved.dtype == dt
        np.testing.assert_array_equal(retrieved, arr)

    def test_structured_array_cross_process(self):
        """A child reads back the same record dtype and values the parent stored."""
        pool = self._fresh_pool()

        dt = np.dtype([("x", "f4"), ("y", "f4"), ("z", "f4")])
        arr = np.array([(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)], dtype=dt)

        pool.put("structured_cross", arr)

        result = self._run_child(worker_get_array, "structured_cross")
        assert result.dtype == dt
        np.testing.assert_array_equal(result, arr)
|
|
|
|
|
|
class TestNDMatrixOperations:
    """Linear-algebra routines applied to matrices read back from the pool."""

    @staticmethod
    def _fresh_pool():
        """Create a pool handle and empty it before the test uses it."""
        handle = MultiProcessingSharedPool()
        handle.clear()
        return handle

    def test_matrix_multiplication(self):
        """Two retrieved 2x2 integer matrices multiply to the known product."""
        pool = self._fresh_pool()

        pool.put("matrix_A", np.array([[1, 2], [3, 4]]))
        pool.put("matrix_B", np.array([[5, 6], [7, 8]]))

        lhs = pool.get("matrix_A")
        rhs = pool.get("matrix_B")

        np.testing.assert_array_equal(
            np.dot(lhs, rhs), np.array([[19, 22], [43, 50]])
        )

    def test_eigenvalue_computation(self):
        """The symmetric matrix [[4, 2], [2, 4]] has eigenvalues 2 and 6."""
        pool = self._fresh_pool()

        pool.put("eigen_matrix", np.array([[4, 2], [2, 4]]))

        # Only the eigenvalues are checked; the eigenvectors are discarded.
        eigenvalues = np.linalg.eig(pool.get("eigen_matrix"))[0]
        assert np.allclose(sorted(eigenvalues), [2, 6])

    def test_svd_decomposition(self):
        """U @ diag(S) @ Vh rebuilds the retrieved matrix to within rounding."""
        pool = self._fresh_pool()

        original = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        pool.put("svd_matrix", original)

        u, s, vh = np.linalg.svd(pool.get("svd_matrix"))

        # SVD is computed in floating point, so the rebuild is only approximate.
        np.testing.assert_array_almost_equal(u @ np.diag(s) @ vh, original)
|
|
|
|
|
|
class TestNDBroadcasting:
    """Retrieved arrays must still obey NumPy broadcasting rules."""

    def test_broadcasting_operation(self):
        """Adding a (3,) row to a (2, 3) matrix broadcasts across both rows."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        stored = {
            "array_2d": np.array([[1, 2, 3], [4, 5, 6]]),
            "array_1d": np.array([10, 20, 30]),
        }
        for name, value in stored.items():
            pool.put(name, value)

        matrix = pool.get("array_2d")
        row = pool.get("array_1d")

        np.testing.assert_array_equal(
            matrix + row, np.array([[11, 22, 33], [14, 25, 36]])
        )
|
|
|
|
|
|
class TestNDMixedTypes:
    """Arrays nested inside ordinary Python containers."""

    def test_array_in_dict(self):
        """A dict mixing arrays with a scalar and a string round-trips intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        payload = {
            "matrix": np.array([[1, 2], [3, 4]]),
            "vector": np.array([1, 2, 3]),
            "scalar": 42,
            "name": "test",
        }
        pool.put("dict_with_arrays", payload)

        got = pool.get("dict_with_arrays")
        for field in ("matrix", "vector"):
            np.testing.assert_array_equal(got[field], payload[field])
        assert got["scalar"] == 42
        assert got["name"] == "test"

    def test_array_in_list(self):
        """A list mixing arrays with a string and an int round-trips intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        items = [np.array([1, 2, 3]), np.array([[4, 5], [6, 7]]), "string", 42]
        pool.put("list_with_arrays", items)

        got = pool.get("list_with_arrays")
        # The first two entries are arrays; compare them element-wise.
        for position in (0, 1):
            np.testing.assert_array_equal(got[position], items[position])
        assert got[2] == "string"
        assert got[3] == 42
|