# 284 lines, 9.8 KiB, Python
"""
Unit tests for the MultiProcessingSharedPool core functionality.

These tests verify the singleton pattern, basic CRUD operations,
error handling, the dictionary-like and context-manager interfaces,
and access to the pool from multiple processes.
"""
|
|
|
|
import pytest
|
|
import multiprocessing as mp
|
|
import time
|
|
from typing import Any, Dict
|
|
|
|
from easy_mp_share import MultiProcessingSharedPool, shared_pool
|
|
|
|
|
|
def _multiprocess_worker(results_queue, worker_id: int) -> None:
    """Child-process entry point used by ``test_multiprocess_access``.

    Writes ``value_from_worker_<id>`` under ``worker_<id>_data`` into the
    shared pool, then puts ``(worker_id, <value of "main_data">, <value read
    back for its own key>)`` on *results_queue*.

    This must live at module level: under the ``spawn`` start method the
    process target is pickled by reference, and a function defined inside a
    test method cannot be pickled.
    """
    pool = MultiProcessingSharedPool()
    own_key = f"worker_{worker_id}_data"
    pool.put(own_key, f"value_from_worker_{worker_id}")
    results_queue.put((worker_id, pool.get("main_data"), pool.get(own_key)))


class TestMultiProcessingSharedPool:
    """Test cases for MultiProcessingSharedPool class."""

    def setup_method(self):
        """Fetch the pool and empty it so every test starts from a clean state."""
        # MultiProcessingSharedPool() returns the singleton (see
        # test_singleton_pattern), so this is the same object every time;
        # clearing it is what isolates one test from the next.
        self.pool = MultiProcessingSharedPool()
        self.pool.clear()

    def teardown_method(self):
        """Empty the pool so keys written by a test do not leak into other test modules."""
        self.pool.clear()

    def test_singleton_pattern(self):
        """Test that MultiProcessingSharedPool follows singleton pattern."""
        pool1 = MultiProcessingSharedPool()
        pool2 = MultiProcessingSharedPool()

        # Should be the same instance
        assert pool1 is pool2

        # get_instance should also return the same instance
        pool3 = MultiProcessingSharedPool.get_instance()
        assert pool1 is pool3

    def test_basic_put_and_get(self):
        """Test basic put and get operations."""
        # Test string data
        assert self.pool.put("test_key", "test_value")
        assert self.pool.get("test_key") == "test_value"

        # Test numeric data
        assert self.pool.put("number", 42)
        assert self.pool.get("number") == 42

        # Test list data
        test_list = [1, 2, 3, "four"]
        assert self.pool.put("my_list", test_list)
        assert self.pool.get("my_list") == test_list

        # Test dict data
        test_dict = {"key1": "value1", "key2": 123}
        assert self.pool.put("my_dict", test_dict)
        assert self.pool.get("my_dict") == test_dict

    def test_get_with_default(self):
        """Test get operation with default values."""
        # Non-existent key should return None by default
        assert self.pool.get("non_existent") is None

        # Non-existent key should return provided default
        assert self.pool.get("non_existent", "default_value") == "default_value"
        assert self.pool.get("non_existent", 42) == 42

        # Existent key should return actual value, not default
        self.pool.put("existing", "actual_value")
        assert self.pool.get("existing", "default") == "actual_value"

    def test_exists_method(self):
        """Test the exists method."""
        assert not self.pool.exists("non_existent")

        self.pool.put("existing_key", "value")
        assert self.pool.exists("existing_key")

        # Test with empty string key
        self.pool.put("", "empty_key_value")
        assert self.pool.exists("")

        # Test with special characters
        self.pool.put("key!@#$%", "special_chars")
        assert self.pool.exists("key!@#$%")

    def test_remove_method(self):
        """Test the remove method."""
        # Remove non-existent key should return False
        assert not self.pool.remove("non_existent")

        # Remove existing key should return True
        self.pool.put("to_remove", "value")
        assert self.pool.remove("to_remove")
        assert not self.pool.exists("to_remove")

        # Should not be able to get removed key
        assert self.pool.get("to_remove") is None

    def test_pop_method(self):
        """Test the pop method."""
        # Pop non-existent key should return default
        assert self.pool.pop("non_existent") is None
        assert self.pool.pop("non_existent", "default") == "default"

        # Pop existing key should return value and remove it
        self.pool.put("to_pop", "popped_value")
        result = self.pool.pop("to_pop")
        assert result == "popped_value"
        assert not self.pool.exists("to_pop")

        # Second pop should return default
        assert self.pool.pop("to_pop", "gone") == "gone"

    def test_size_method(self):
        """Test the size method."""
        assert self.pool.size() == 0

        self.pool.put("key1", "value1")
        assert self.pool.size() == 1

        self.pool.put("key2", "value2")
        assert self.pool.size() == 2

        self.pool.remove("key1")
        assert self.pool.size() == 1

        self.pool.clear()
        assert self.pool.size() == 0

    def test_keys_method(self):
        """Test the keys method."""
        assert self.pool.keys() == []

        # Add some keys
        self.pool.put("first", "value1")
        self.pool.put("second", "value2")
        self.pool.put("third", "value3")

        keys = self.pool.keys()
        assert len(keys) == 3
        assert "first" in keys
        assert "second" in keys
        assert "third" in keys

        # Remove a key
        self.pool.remove("second")
        keys = self.pool.keys()
        assert len(keys) == 2
        assert "second" not in keys

    def test_clear_method(self):
        """Test the clear method."""
        self.pool.put("key1", "value1")
        self.pool.put("key2", "value2")
        assert self.pool.size() == 2

        self.pool.clear()
        assert self.pool.size() == 0
        assert self.pool.keys() == []
        assert self.pool.get("key1") is None
        assert self.pool.get("key2") is None

    def test_dictionary_interface(self):
        """Test dictionary-like interface."""
        # Test __setitem__ and __getitem__
        self.pool["dict_key"] = "dict_value"
        assert self.pool["dict_key"] == "dict_value"

        # Test __contains__
        assert "dict_key" in self.pool
        assert "non_existent" not in self.pool

        # Test __delitem__
        del self.pool["dict_key"]
        assert "dict_key" not in self.pool

        # Test KeyError for non-existent key
        with pytest.raises(KeyError):
            _ = self.pool["non_existent"]

        with pytest.raises(KeyError):
            del self.pool["non_existent"]

    def test_type_validation(self):
        """Test type validation for keys."""
        # Non-string keys should raise TypeError for put, get, and pop
        with pytest.raises(TypeError):
            self.pool.put(123, "value")

        with pytest.raises(TypeError):
            self.pool.get(123)

        with pytest.raises(TypeError):
            self.pool.pop(123)

        # Note: __contains__ doesn't raise TypeError as it delegates to exists()
        # which doesn't have type validation for performance reasons
        # This is the current behavior of the implementation
        result = 123 in self.pool  # Should return False, not raise
        assert result is False

    def test_context_manager(self):
        """Test context manager interface."""
        with MultiProcessingSharedPool() as pool:
            pool.put("context_key", "context_value")
            assert pool.get("context_key") == "context_value"

        # Pool should still be usable after context
        pool = MultiProcessingSharedPool()
        assert pool.get("context_key") == "context_value"

    def test_shared_pool_convenience_instance(self):
        """Test the shared_pool convenience instance."""
        # Should be the same as MultiProcessingSharedPool.get_instance()
        pool = MultiProcessingSharedPool.get_instance()
        assert shared_pool is pool

        # Should work like any other instance
        shared_pool.put("convenience_key", "convenience_value")
        assert shared_pool.get("convenience_key") == "convenience_value"

        # Should be accessible from other instances
        pool = MultiProcessingSharedPool()
        assert pool.get("convenience_key") == "convenience_value"

    @pytest.mark.skipif(
        not hasattr(mp, 'Process') or mp.get_start_method() != 'spawn',
        reason="Multiprocessing test requires spawn method"
    )
    def test_multiprocess_access(self):
        """Test that the pool works across multiple processes.

        Each worker must see the value written by the main process and be
        able to read back its own write, and the main process must see every
        worker's write after they finish.
        """
        timeout = 30  # seconds; fail instead of hanging if a worker gets stuck
        worker_ids = range(1, 4)

        # Put initial data from main process
        self.pool.put("main_data", "value_from_main")

        results_queue = mp.Queue()
        processes = [
            mp.Process(target=_multiprocess_worker, args=(results_queue, i))
            for i in worker_ids
        ]
        try:
            for p in processes:
                p.start()

            # Drain the queue BEFORE joining: a child that has put items on a
            # queue may not terminate until they are consumed, so join-then-read
            # can deadlock. Reading a known number of items with a timeout also
            # avoids Queue.empty(), which the docs describe as unreliable.
            results = {}
            for _ in worker_ids:
                worker_id, seen_main, seen_own = results_queue.get(timeout=timeout)
                results[worker_id] = (seen_main, seen_own)

            for p in processes:
                p.join(timeout=timeout)
                assert p.exitcode == 0, f"{p.name} exited with {p.exitcode}"
        finally:
            # Never leave stray children behind if a timeout or assertion fired.
            for p in processes:
                if p.is_alive():
                    p.terminate()
                    p.join()

        # Every worker saw the main process's data and could read its own write.
        assert results == {
            i: ("value_from_main", f"value_from_worker_{i}") for i in worker_ids
        }

        # Verify data written by the workers is accessible from main process
        for i in worker_ids:
            assert self.pool.get(f"worker_{i}_data") == f"value_from_worker_{i}"
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns an exit code instead of raising SystemExit;
    # propagate it so running this file directly fails when tests fail.
    raise SystemExit(pytest.main([__file__]))
|