Initial commit.
This commit is contained in:
283
tests/test_core.py
Normal file
283
tests/test_core.py
Normal file
@ -0,0 +1,283 @@
|
||||
"""
|
||||
Unit tests for the MultiProcessingSharedPool core functionality.
|
||||
|
||||
These tests verify the singleton pattern, basic CRUD operations,
|
||||
error handling, and dictionary-like interface.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from easy_mp_share import MultiProcessingSharedPool, shared_pool
|
||||
|
||||
|
||||
class TestMultiProcessingSharedPool:
|
||||
"""Test cases for MultiProcessingSharedPool class."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures before each test method."""
|
||||
# Get a fresh instance for each test
|
||||
self.pool = MultiProcessingSharedPool()
|
||||
self.pool.clear() # Clear any existing data
|
||||
|
||||
def test_singleton_pattern(self):
|
||||
"""Test that MultiProcessingSharedPool follows singleton pattern."""
|
||||
pool1 = MultiProcessingSharedPool()
|
||||
pool2 = MultiProcessingSharedPool()
|
||||
|
||||
# Should be the same instance
|
||||
assert pool1 is pool2
|
||||
|
||||
# get_instance should also return the same instance
|
||||
pool3 = MultiProcessingSharedPool.get_instance()
|
||||
assert pool1 is pool3
|
||||
|
||||
def test_basic_put_and_get(self):
|
||||
"""Test basic put and get operations."""
|
||||
# Test string data
|
||||
assert self.pool.put("test_key", "test_value")
|
||||
assert self.pool.get("test_key") == "test_value"
|
||||
|
||||
# Test numeric data
|
||||
assert self.pool.put("number", 42)
|
||||
assert self.pool.get("number") == 42
|
||||
|
||||
# Test list data
|
||||
test_list = [1, 2, 3, "four"]
|
||||
assert self.pool.put("my_list", test_list)
|
||||
assert self.pool.get("my_list") == test_list
|
||||
|
||||
# Test dict data
|
||||
test_dict = {"key1": "value1", "key2": 123}
|
||||
assert self.pool.put("my_dict", test_dict)
|
||||
assert self.pool.get("my_dict") == test_dict
|
||||
|
||||
def test_get_with_default(self):
|
||||
"""Test get operation with default values."""
|
||||
# Non-existent key should return None by default
|
||||
assert self.pool.get("non_existent") is None
|
||||
|
||||
# Non-existent key should return provided default
|
||||
assert self.pool.get("non_existent", "default_value") == "default_value"
|
||||
assert self.pool.get("non_existent", 42) == 42
|
||||
|
||||
# Existent key should return actual value, not default
|
||||
self.pool.put("existing", "actual_value")
|
||||
assert self.pool.get("existing", "default") == "actual_value"
|
||||
|
||||
def test_exists_method(self):
|
||||
"""Test the exists method."""
|
||||
assert not self.pool.exists("non_existent")
|
||||
|
||||
self.pool.put("existing_key", "value")
|
||||
assert self.pool.exists("existing_key")
|
||||
|
||||
# Test with empty string key
|
||||
self.pool.put("", "empty_key_value")
|
||||
assert self.pool.exists("")
|
||||
|
||||
# Test with special characters
|
||||
self.pool.put("key!@#$%", "special_chars")
|
||||
assert self.pool.exists("key!@#$%")
|
||||
|
||||
def test_remove_method(self):
|
||||
"""Test the remove method."""
|
||||
# Remove non-existent key should return False
|
||||
assert not self.pool.remove("non_existent")
|
||||
|
||||
# Remove existing key should return True
|
||||
self.pool.put("to_remove", "value")
|
||||
assert self.pool.remove("to_remove")
|
||||
assert not self.pool.exists("to_remove")
|
||||
|
||||
# Should not be able to get removed key
|
||||
assert self.pool.get("to_remove") is None
|
||||
|
||||
def test_pop_method(self):
|
||||
"""Test the pop method."""
|
||||
# Pop non-existent key should return default
|
||||
assert self.pool.pop("non_existent") is None
|
||||
assert self.pool.pop("non_existent", "default") == "default"
|
||||
|
||||
# Pop existing key should return value and remove it
|
||||
self.pool.put("to_pop", "popped_value")
|
||||
result = self.pool.pop("to_pop")
|
||||
assert result == "popped_value"
|
||||
assert not self.pool.exists("to_pop")
|
||||
|
||||
# Second pop should return default
|
||||
assert self.pool.pop("to_pop", "gone") == "gone"
|
||||
|
||||
def test_size_method(self):
|
||||
"""Test the size method."""
|
||||
assert self.pool.size() == 0
|
||||
|
||||
self.pool.put("key1", "value1")
|
||||
assert self.pool.size() == 1
|
||||
|
||||
self.pool.put("key2", "value2")
|
||||
assert self.pool.size() == 2
|
||||
|
||||
self.pool.remove("key1")
|
||||
assert self.pool.size() == 1
|
||||
|
||||
self.pool.clear()
|
||||
assert self.pool.size() == 0
|
||||
|
||||
def test_keys_method(self):
|
||||
"""Test the keys method."""
|
||||
assert self.pool.keys() == []
|
||||
|
||||
# Add some keys
|
||||
self.pool.put("first", "value1")
|
||||
self.pool.put("second", "value2")
|
||||
self.pool.put("third", "value3")
|
||||
|
||||
keys = self.pool.keys()
|
||||
assert len(keys) == 3
|
||||
assert "first" in keys
|
||||
assert "second" in keys
|
||||
assert "third" in keys
|
||||
|
||||
# Remove a key
|
||||
self.pool.remove("second")
|
||||
keys = self.pool.keys()
|
||||
assert len(keys) == 2
|
||||
assert "second" not in keys
|
||||
|
||||
def test_clear_method(self):
|
||||
"""Test the clear method."""
|
||||
self.pool.put("key1", "value1")
|
||||
self.pool.put("key2", "value2")
|
||||
assert self.pool.size() == 2
|
||||
|
||||
self.pool.clear()
|
||||
assert self.pool.size() == 0
|
||||
assert self.pool.keys() == []
|
||||
assert self.pool.get("key1") is None
|
||||
assert self.pool.get("key2") is None
|
||||
|
||||
def test_dictionary_interface(self):
|
||||
"""Test dictionary-like interface."""
|
||||
# Test __setitem__ and __getitem__
|
||||
self.pool["dict_key"] = "dict_value"
|
||||
assert self.pool["dict_key"] == "dict_value"
|
||||
|
||||
# Test __contains__
|
||||
assert "dict_key" in self.pool
|
||||
assert "non_existent" not in self.pool
|
||||
|
||||
# Test __delitem__
|
||||
del self.pool["dict_key"]
|
||||
assert "dict_key" not in self.pool
|
||||
|
||||
# Test KeyError for non-existent key
|
||||
with pytest.raises(KeyError):
|
||||
_ = self.pool["non_existent"]
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
del self.pool["non_existent"]
|
||||
|
||||
def test_type_validation(self):
|
||||
"""Test type validation for keys."""
|
||||
# Non-string keys should raise TypeError for put, get, and pop
|
||||
with pytest.raises(TypeError):
|
||||
self.pool.put(123, "value")
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
self.pool.get(123)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
self.pool.pop(123)
|
||||
|
||||
# Note: __contains__ doesn't raise TypeError as it delegates to exists()
|
||||
# which doesn't have type validation for performance reasons
|
||||
# This is the current behavior of the implementation
|
||||
result = 123 in self.pool # Should return False, not raise
|
||||
assert result is False
|
||||
|
||||
def test_context_manager(self):
|
||||
"""Test context manager interface."""
|
||||
with MultiProcessingSharedPool() as pool:
|
||||
pool.put("context_key", "context_value")
|
||||
assert pool.get("context_key") == "context_value"
|
||||
|
||||
# Pool should still be usable after context
|
||||
pool = MultiProcessingSharedPool()
|
||||
assert pool.get("context_key") == "context_value"
|
||||
|
||||
def test_shared_pool_convenience_instance(self):
|
||||
"""Test the shared_pool convenience instance."""
|
||||
# Should be the same as MultiProcessingSharedPool.get_instance()
|
||||
pool = MultiProcessingSharedPool.get_instance()
|
||||
assert shared_pool is pool
|
||||
|
||||
# Should work like any other instance
|
||||
shared_pool.put("convenience_key", "convenience_value")
|
||||
assert shared_pool.get("convenience_key") == "convenience_value"
|
||||
|
||||
# Should be accessible from other instances
|
||||
pool = MultiProcessingSharedPool()
|
||||
assert pool.get("convenience_key") == "convenience_value"
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not hasattr(mp, 'Process') or mp.get_start_method() != 'spawn',
|
||||
reason="Multiprocessing test requires spawn method"
|
||||
)
|
||||
def test_multiprocess_access(self):
|
||||
"""Test that the pool works across multiple processes."""
|
||||
def worker_function(results_queue: mp.Queue, worker_id: int):
|
||||
"""Worker function for multiprocess testing."""
|
||||
pool = MultiProcessingSharedPool()
|
||||
|
||||
# Put some data
|
||||
pool.put(f"worker_{worker_id}_data", f"value_from_worker_{worker_id}")
|
||||
|
||||
# Get data from main process
|
||||
main_data = pool.get("main_data")
|
||||
|
||||
# Get data from previous workers
|
||||
if worker_id > 1:
|
||||
prev_data = pool.get(f"worker_{worker_id-1}_data")
|
||||
results_queue.put(prev_data)
|
||||
|
||||
results_queue.put(main_data)
|
||||
results_queue.put(f"worker_{worker_id}_data")
|
||||
|
||||
# Put initial data from main process
|
||||
self.pool.put("main_data", "value_from_main")
|
||||
|
||||
results_queue = mp.Queue()
|
||||
processes = []
|
||||
|
||||
# Start multiple worker processes
|
||||
for i in range(1, 4):
|
||||
p = mp.Process(target=worker_function, args=(results_queue, i))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
# Wait for all processes to complete
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
# Collect results
|
||||
results = []
|
||||
while not results_queue.empty():
|
||||
results.append(results_queue.get())
|
||||
|
||||
# Verify data sharing worked
|
||||
assert "value_from_main" in results
|
||||
assert "worker_1_data" in results
|
||||
assert "worker_2_data" in results
|
||||
assert "worker_3_data" in results
|
||||
|
||||
# Verify data is accessible from main process
|
||||
assert self.pool.get("worker_1_data") == "value_from_worker_1"
|
||||
assert self.pool.get("worker_2_data") == "value_from_worker_2"
|
||||
assert self.pool.get("worker_3_data") == "value_from_worker_3"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user