"""
========================================
The :mod:`array_split.split_test` Module
========================================

.. currentmodule:: array_split.split_test

Module defining :mod:`array_split.split` unit-tests.
Execute as::

   python -m array_split.split_test

Classes
=======

.. autosummary::
   :toctree: generated/

   SplitTest - :obj:`unittest.TestCase` for :mod:`array_split.split` functions.
"""
from __future__ import absolute_import
from .license import license as _license, copyright as _copyright
import array_split.unittest as _unittest
import array_split.logging as _logging
import array_split as _array_split
import numpy as _np
from .split import ShapeSplitter, array_split, shape_split
from .split import calculate_num_slices_per_axis, shape_factors
from .split import calculate_tile_shape_for_max_bytes
from .split import ARRAY_BOUNDS, NO_BOUNDS
# Module metadata; the license and copyright text are obtained from the
# package's ``license`` module (imported above as ``_license``/``_copyright``).
__author__ = "Shane J. Latham"
__license__, __copyright__ = _license(), _copyright()
__version__ = _array_split.__version__
class SplitTest(_unittest.TestCase):
    """
    :obj:`unittest.TestCase` for :mod:`array_split.split` functions.
    """

    #: Class attribute for :obj:`logging.Logger` logging.
    logger = _logging.getLogger(__name__ + ".SplitTest")

    @staticmethod
    def _tile_list(bounds):
        """
        Returns an expected 1D split as a :obj:`list` of 1-tuples of :obj:`slice`.

        :type bounds: sequence of :samp:`(start, stop)` pairs
        :param bounds: Per-tile :samp:`(start, stop)` slice bounds along axis 0.
        :rtype: :obj:`list`
        :return: :samp:`[(slice(start, stop),), ...]`, the form these tests
           compare against :samp:`split.tolist()`.
        """
        return [(slice(*bound),) for bound in bounds]

    @staticmethod
    def _tile_grid(axis0_bounds, axis1_bounds):
        """
        Returns an expected 2D split as nested lists of :obj:`slice` pairs.

        Element :samp:`[i][j]` is
        :samp:`(slice(*axis0_bounds[i]), slice(*axis1_bounds[j]))`, i.e. the tiles
        are the Cartesian product of the per-axis slices, in the nested-list form
        these tests compare against :samp:`split.tolist()`.

        :type axis0_bounds: sequence of :samp:`(start, stop)` pairs
        :param axis0_bounds: Per-tile slice bounds along axis 0 (grid rows).
        :type axis1_bounds: sequence of :samp:`(start, stop)` pairs
        :param axis1_bounds: Per-tile slice bounds along axis 1 (grid columns).
        :rtype: :obj:`list`
        :return: Nested :obj:`list` of :samp:`(slice, slice)` tuples.
        """
        return [
            [(slice(*bound0), slice(*bound1)) for bound1 in axis1_bounds]
            for bound0 in axis0_bounds
        ]

    def _log_split(self, split):
        """
        Logs (at INFO level) the :samp:`shape` and the contents of a calculated split.

        :param split: The calculated split, e.g. the return value of
           :meth:`array_split.split.ShapeSplitter.calculate_split`.
        """
        self.logger.info("split.shape = %s", split.shape)
        self.logger.info("split =\n%s", split)

    def test_shape_factors(self):
        """
        Tests for :func:`array_split.split.shape_factors`.

        Each case checks both the number of returned factors (which must equal
        the requested dimension) and the factor values.
        """
        cases = (
            # (n, dim, expected factors)
            (4, 2, [2, 2]),
            (4, 1, [4, ]),
            (5, 2, [1, 5]),
            (6, 2, [2, 3]),
            (6, 3, [1, 2, 3]),
        )
        for n, dim, expected in cases:
            f = shape_factors(n, dim)
            # Length check: a bare ``f == 2`` comparison would broadcast and pass
            # for any number of factors.
            self.assertEqual(dim, len(f))
            self.assertTrue(_np.all(f == expected))

    def test_calculate_num_slices_per_axis(self):
        """
        Tests for :func:`array_split.split.calculate_num_slices_per_axis`.
        """
        cases = (
            # (positional arguments, expected number of slices per axis)
            (([0, ], 5), [5, ]),
            (([2, 0], 4), [2, 2]),
            (([0, 2], 4), [2, 2]),
            (([0, 0], 4), [2, 2]),
            (([0, 0], 16), [4, 4]),
            (([0, 0, 0], 8), [2, 2, 2]),
            (([0, 1, 0], 8), [4, 1, 2]),
            (([0, 1, 0], 17), [17, 1, 1]),
            (([0, 1, 0], 15, [1, _np.inf, _np.inf]), [1, 1, 15]),
        )
        for args, expected in cases:
            spa = calculate_num_slices_per_axis(*args)
            self.assertEqual(len(expected), len(spa))
            self.assertTrue(_np.all(spa == expected))

    def test_calculate_tile_shape_for_max_bytes_1d(self):
        """
        Test case for :func:`array_split.split.calculate_tile_shape_for_max_bytes`,
        where :samp:`array_shape` parameter is 1D, i.e. of the form :samp:`(N,)`.
        """
        cases = (
            # (keyword arguments, expected tile shape) for array_shape=(512,)
            (dict(array_itemsize=1, max_tile_bytes=1024), (512,)),
            (dict(array_itemsize=1, max_tile_bytes=1024, sub_tile_shape=[64, ]), (512,)),
            (dict(array_itemsize=1, max_tile_bytes=1024, sub_tile_shape=[26, ]), (260,)),
            (dict(array_itemsize=1, max_tile_bytes=512), (512,)),
            (dict(array_itemsize=2, max_tile_bytes=512), (256,)),
            (dict(array_itemsize=2, max_tile_bytes=512, sub_tile_shape=[32, ]), (256,)),
            (dict(array_itemsize=2, max_tile_bytes=512, sub_tile_shape=[60, ]), (180,)),
            (dict(array_itemsize=1, max_tile_bytes=512, halo=1), (256,)),
            (dict(array_itemsize=1, max_tile_bytes=514, halo=1), (512,)),
            (dict(array_itemsize=2, max_tile_bytes=511), (171,)),
        )
        for kwargs, expected in cases:
            tile_shape = calculate_tile_shape_for_max_bytes(array_shape=(512,), **kwargs)
            self.assertSequenceEqual(expected, tile_shape, msg="kwargs=%r" % (kwargs,))

    def test_calculate_tile_shape_for_max_bytes_2d(self):
        """
        Test case for :func:`array_split.split.calculate_tile_shape_for_max_bytes`,
        where :samp:`array_shape` parameter is 2D, i.e. of the form :samp:`(H,W)`.
        """
        cases = (
            # (keyword arguments, expected tile shape)
            (
                dict(array_shape=(512, 512), array_itemsize=1, max_tile_bytes=512**2),
                (512, 512)
            ),
            (
                dict(array_shape=(512, 512), array_itemsize=1, max_tile_bytes=512**2 - 1),
                (256, 512)
            ),
            (
                dict(array_shape=(513, 512), array_itemsize=1, max_tile_bytes=512**2 - 1),
                (257, 512)
            ),
            (
                dict(array_shape=(512, 512), array_itemsize=1, max_tile_bytes=512**2 // 2),
                (256, 512)
            ),
            (
                dict(array_shape=(512, 512), array_itemsize=2, max_tile_bytes=512**2 // 2),
                (128, 512)
            ),
            (
                dict(
                    array_shape=(512, 512),
                    array_itemsize=2,
                    max_tile_bytes=512**2 // 2,
                    sub_tile_shape=(32, 64)
                ),
                (128, 512)
            ),
            (
                dict(
                    array_shape=(512, 512),
                    array_itemsize=1,
                    max_tile_bytes=512**2 // 2,
                    sub_tile_shape=(30, 64)
                ),
                (180, 512)
            ),
            (
                dict(
                    array_shape=(512, 512),
                    array_itemsize=2,
                    max_tile_bytes=512**2 // 2,
                    sub_tile_shape=(30, 64)
                ),
                (120, 512)
            ),
            (
                dict(
                    array_shape=(512, 1024),
                    array_itemsize=1,
                    max_tile_bytes=512**2,
                    sub_tile_shape=(30, 60)
                ),
                (180, 540)
            ),
        )
        for kwargs, expected in cases:
            tile_shape = calculate_tile_shape_for_max_bytes(**kwargs)
            self.assertSequenceEqual(expected, tile_shape.tolist(), msg="kwargs=%r" % (kwargs,))

    def test_array_split(self):
        """
        Test case for :func:`array_split.split.array_split`, comparing its output
        with :func:`numpy.array_split` and checking the :samp:`max_tile_bytes` mode.
        """
        x = _np.arange(9.0)
        for indices_or_sections in (3, 4, [2, 3, 5, ]):
            self.assertArraySplitEqual(
                _np.array_split(x, indices_or_sections),
                array_split(x, indices_or_sections)
            )

        x = _np.arange(32)
        x = x.reshape((4, 8))
        self.logger.info("_np.array_split(x, 3, axis=0) = \n%s", _np.array_split(x, 3, axis=0))
        self.logger.info(
            "array_split.split.array_split(x, 3, axis=0) = \n%s", array_split(x, 3, axis=0)
        )
        self.assertArraySplitEqual(
            _np.array_split(x, 3, axis=0),
            array_split(x, 3, axis=0)
        )
        self.logger.info("_np.array_split(x, 3, axis=1) = \n%s", _np.array_split(x, 3, axis=1))
        self.logger.info(
            "array_split.split.array_split(x, 3, axis=1) = \n%s", array_split(x, 3, axis=1)
        )
        self.assertArraySplitEqual(
            _np.array_split(x, 3, axis=1),
            array_split(x, 3, axis=1)
        )
        self.logger.info("_np.array_split(x, 8, axis=0) = \n%s", _np.array_split(x, 8, axis=0))
        self.assertArraySplitEqual(
            _np.array_split(x, 8, axis=0),
            array_split(x, 8, axis=0)
        )

        x = _np.arange(0, 64)
        x = x.reshape((4, 16))
        self.assertArraySplitEqual(
            _np.array_split(x, [3, 8, 12], axis=1),
            array_split(x, [3, 8, 12], axis=1)
        )

        # int16 elements are 2 bytes, so 512-byte tiles hold 256 elements each.
        x = _np.arange(0, 512, dtype="int16")
        self.assertArraySplitEqual(
            [_np.arange(0, 256), _np.arange(256, 512)],
            array_split(x, max_tile_bytes=512)
        )

    def test_split_by_per_axis_indices(self):
        """
        Test case for splitting by specified indices::

           ShapeSplitter(array_shape=(10, 4), indices_or_sections=[[2, 6, 8], ]).calculate_split()
        """
        splitter = ShapeSplitter((10, 4), [[2, 6, 8], ])
        split = splitter.calculate_split()
        self._log_split(split)
        self.assertSequenceEqual((4, 1), split.shape)
        self.assertEqual(slice(0, 2), split[0, 0][0])  # axis 0 slice
        self.assertEqual(slice(2, 6), split[1, 0][0])  # axis 0 slice
        self.assertEqual(slice(6, 8), split[2, 0][0])  # axis 0 slice
        self.assertEqual(slice(8, 10), split[3, 0][0])  # axis 0 slice
        self.assertEqual(slice(0, 4), split[0, 0][1])  # axis 1 slice
        self.assertEqual(slice(0, 4), split[1, 0][1])  # axis 1 slice
        self.assertEqual(slice(0, 4), split[2, 0][1])  # axis 1 slice
        self.assertEqual(slice(0, 4), split[3, 0][1])  # axis 1 slice

        splitter = ShapeSplitter((10, 13), [None, [2, 5, 8], ])
        split = splitter.calculate_split()
        self._log_split(split)
        self.assertSequenceEqual((1, 4), split.shape)
        self.assertEqual(slice(0, 10), split[0, 0][0])  # axis 0 slice
        self.assertEqual(slice(0, 10), split[0, 1][0])  # axis 0 slice
        self.assertEqual(slice(0, 10), split[0, 2][0])  # axis 0 slice
        self.assertEqual(slice(0, 10), split[0, 3][0])  # axis 0 slice
        self.assertEqual(slice(0, 2), split[0, 0][1])  # axis 1 slice
        self.assertEqual(slice(2, 5), split[0, 1][1])  # axis 1 slice
        self.assertEqual(slice(5, 8), split[0, 2][1])  # axis 1 slice
        self.assertEqual(slice(8, 13), split[0, 3][1])  # axis 1 slice

        splitter = ShapeSplitter((10, 4), [[2, 6], [2, ]])
        split = splitter.calculate_split()
        self._log_split(split)
        self.assertSequenceEqual((3, 2), split.shape)
        self.assertEqual(slice(0, 2), split[0, 0][0])  # axis 0 slice
        self.assertEqual(slice(2, 6), split[1, 0][0])  # axis 0 slice
        self.assertEqual(slice(6, 10), split[2, 0][0])  # axis 0 slice
        self.assertEqual(slice(0, 2), split[0, 1][0])  # axis 0 slice
        self.assertEqual(slice(2, 6), split[1, 1][0])  # axis 0 slice
        self.assertEqual(slice(6, 10), split[2, 1][0])  # axis 0 slice
        self.assertEqual(slice(0, 2), split[0, 0][1])  # axis 1 slice
        self.assertEqual(slice(0, 2), split[1, 0][1])  # axis 1 slice
        self.assertEqual(slice(0, 2), split[2, 0][1])  # axis 1 slice
        self.assertEqual(slice(2, 4), split[0, 1][1])  # axis 1 slice
        self.assertEqual(slice(2, 4), split[1, 1][1])  # axis 1 slice
        self.assertEqual(slice(2, 4), split[2, 1][1])  # axis 1 slice

        splitter = ShapeSplitter((10,), [[2, 6, 8], ])
        split = splitter.calculate_split()
        self._log_split(split)
        self.assertSequenceEqual((4,), split.shape)
        self.assertEqual(slice(0, 2), split[0][0])  # axis 0 slice
        self.assertEqual(slice(2, 6), split[1][0])  # axis 0 slice
        self.assertEqual(slice(6, 8), split[2][0])  # axis 0 slice
        self.assertEqual(slice(8, 10), split[3][0])  # axis 0 slice

    def test_split_by_num_slices(self):
        """
        Test case for splitting by number of slice elements::

           ShapeSplitter(array_shape=(10, 13), indices_or_sections=3).calculate_split()
           ShapeSplitter(array_shape=(10, 13), axis=[2, 3]).calculate_split()
        """
        cases = (
            # (array_shape, positional args, keyword args, expected split)
            ((10,), (3, ), {}, self._tile_list([(0, 4), (4, 7), (7, 10)])),
            ((10,), (), {"axis": [3, ]}, self._tile_list([(0, 4), (4, 7), (7, 10)])),
            ((10,), (3, ), {"axis": [3, ]}, self._tile_list([(0, 4), (4, 7), (7, 10)])),
            ((10,), (3, ), {"axis": [0, ]}, self._tile_list([(0, 4), (4, 7), (7, 10)])),
            ((10,), (2, ), {"axis": [0, ]}, self._tile_list([(0, 5), (5, 10)])),
            (
                (10, 13), (4, ), {"axis": [1, 0]},
                self._tile_grid([(0, 10)], [(0, 4), (4, 7), (7, 10), (10, 13)])
            ),
        )
        # Every one of these axis/section combinations yields the same 2x2 split.
        two_by_two = self._tile_grid([(0, 5), (5, 10)], [(0, 7), (7, 13)])
        cases += tuple(
            ((10, 13), args, kwargs, two_by_two)
            for args, kwargs in (
                ((), {"axis": [2, 2]}),
                ((4, ), {"axis": [2, 2]}),
                ((4, ), {"axis": [0, 2]}),
                ((4, ), {"axis": [2, 0]}),
                ((4, ), {"axis": [0, 0]}),
            )
        )
        for array_shape, args, kwargs, expected in cases:
            msg = "array_shape=%r, args=%r, kwargs=%r" % (array_shape, args, kwargs)
            splitter = ShapeSplitter(array_shape, *args, **kwargs)
            split = splitter.calculate_split()
            self._log_split(split)
            expected_shape = \
                (len(expected),) if len(array_shape) == 1 else (len(expected), len(expected[0]))
            self.assertSequenceEqual(expected_shape, split.shape, msg=msg)
            self.assertSequenceEqual(expected, split.tolist(), msg=msg)

    def test_calculate_split_by_tile_shape_1d(self):
        """
        Test case for :meth:`array_split.split.ShapeSplitter.calculate_split` with
        the :samp:`tile_shape` parameter, for a 1D array shape. The last tile is
        truncated at the array boundary.
        """
        cases = (
            # (tile_shape, expected per-tile bounds)
            ((3,), [(0, 3), (3, 6), (6, 9), (9, 10)]),
            ((4,), [(0, 4), (4, 8), (8, 10)]),
            ((5,), [(0, 5), (5, 10)]),
            ((10,), [(0, 10)]),
            ((11,), [(0, 10)]),
        )
        for tile_shape, bounds in cases:
            msg = "tile_shape=%r" % (tile_shape,)
            splitter = ShapeSplitter((10, ), tile_shape=tile_shape)
            split = splitter.calculate_split()
            self._log_split(split)
            self.assertSequenceEqual((len(bounds),), split.shape, msg=msg)
            self.assertSequenceEqual(self._tile_list(bounds), split.tolist(), msg=msg)

    def test_calculate_split_by_tile_shape_2d(self):
        """
        Test case for :meth:`array_split.split.ShapeSplitter.calculate_split` with
        the :samp:`tile_shape` parameter, for a 2D array shape. The result is
        compared with the equivalent split by per-axis indices.
        """
        cases = (
            # (tile_shape, expected split shape, equivalent per-axis indices)
            ((3, 8), (4, 3), [[3, 6, 9], [8, 16]]),
            ((2, 9), (5, 2), [[2, 4, 6, 8], [9, ]]),
        )
        for tile_shape, expected_shape, indices in cases:
            msg = "tile_shape=%r" % (tile_shape,)
            splitter = ShapeSplitter((10, 17), tile_shape=tile_shape)
            split = splitter.calculate_split()
            self._log_split(split)
            self.assertSequenceEqual(expected_shape, split.shape, msg=msg)
            self.assertSequenceEqual(
                shape_split(splitter.array_shape, indices).flatten().tolist(),
                split.flatten().tolist(),
                msg=msg
            )

    def test_calculate_split_by_tile_max_bytes_1d(self):
        """
        Test case for :meth:`array_split.split.ShapeSplitter.calculate_split` with
        the :samp:`max_tile_bytes` parameter (optionally combined with
        :samp:`halo`, :samp:`max_tile_shape` or :samp:`sub_tile_shape`), for a
        1D array shape.
        """
        cases = (
            # (keyword arguments, expected per-tile bounds) for array_shape=(512,)
            (
                {"max_tile_bytes": 256, "array_itemsize": 1},
                [(0, 256), (256, 512)]
            ),
            (
                {"max_tile_bytes": 256, "array_itemsize": 2},
                [(0, 128), (128, 256), (256, 384), (384, 512)]
            ),
            (
                {"max_tile_bytes": 511, "array_itemsize": 2},
                [(0, 171), (171, 342), (342, 512)]
            ),
            (
                {"max_tile_bytes": 256, "array_itemsize": 1, "halo": 1},
                [(0, 172), (170, 343), (341, 512)]
            ),
            (
                {"max_tile_bytes": 256, "array_itemsize": 1, "max_tile_shape": (128,)},
                [(0, 128), (128, 256), (256, 384), (384, 512)]
            ),
            (
                {"max_tile_bytes": 256, "array_itemsize": 1, "sub_tile_shape": (130,)},
                [(0, 130), (130, 260), (260, 390), (390, 512)]
            ),
        )
        for kwargs, bounds in cases:
            msg = "kwargs=%r" % (kwargs,)
            splitter = ShapeSplitter((512, ), **kwargs)
            split = splitter.calculate_split()
            self._log_split(split)
            self.assertSequenceEqual((len(bounds),), split.shape, msg=msg)
            self.assertSequenceEqual(self._tile_list(bounds), split.tolist(), msg=msg)

    def test_calculate_split_with_array_start_1d(self):
        """
        Test case for :func:`array_split.split.shape_split` with the
        :samp:`array_start` parameter, for a 1D array shape: the slices are
        offset by :samp:`array_start`.
        """
        cases = (
            ((0,), [(0, 5), (5, 10)]),
            ((32,), [(32, 37), (37, 42)]),
        )
        for array_start, bounds in cases:
            split = shape_split((10,), 2, array_start=array_start)
            self.assertSequenceEqual(
                self._tile_list(bounds),
                split.tolist(),
                msg="array_start=%r" % (array_start,)
            )

    def test_calculate_split_with_array_start_2d(self):
        """
        Test case for :func:`array_split.split.shape_split` with the
        :samp:`array_start` parameter, for a 2D array shape: the slices are
        offset by :samp:`array_start` along each axis.
        """
        cases = (
            ((0, 0), self._tile_grid([(0, 5), (5, 10)], [(0, 6), (6, 12)])),
            ((32, 16), self._tile_grid([(32, 37), (37, 42)], [(16, 22), (22, 28)])),
        )
        for array_start, expected in cases:
            split = shape_split((10, 12), axis=(2, 2), array_start=array_start)
            self.assertSequenceEqual(
                expected,
                split.tolist(),
                msg="array_start=%r" % (array_start,)
            )

    def test_calculate_split_with_halo_1d(self):
        """
        Test case for :func:`array_split.split.shape_split` with the :samp:`halo`
        and :samp:`tile_bounds_policy` parameters, for a 1D array shape.
        """
        cases = (
            # (keyword arguments, expected per-tile bounds)
            ({"halo": (0,)}, [(0, 4), (4, 7), (7, 10)]),
            ({"halo": (0, 0)}, [(0, 4), (4, 7), (7, 10)]),
            ({"halo": (1, 0)}, [(0, 4), (3, 7), (6, 10)]),
            ({"halo": (0, 1)}, [(0, 5), (4, 8), (7, 10)]),
            ({"halo": (1, 1)}, [(0, 5), (3, 8), (6, 10)]),
            ({"halo": [(1, 2), ]}, [(0, 6), (3, 9), (6, 10)]),
            ({"halo": 1}, [(0, 5), (3, 8), (6, 10)]),
            (
                {"halo": 1, "tile_bounds_policy": ARRAY_BOUNDS},
                [(0, 5), (3, 8), (6, 10)]
            ),
            # NO_BOUNDS lets halo'd tiles extend past the array boundary.
            (
                {"halo": 1, "tile_bounds_policy": NO_BOUNDS},
                [(-1, 5), (3, 8), (6, 11)]
            ),
            (
                {"halo": ((2, 3),), "tile_bounds_policy": NO_BOUNDS},
                [(-2, 7), (2, 10), (5, 13)]
            ),
            (
                {"halo": (2, 3), "tile_bounds_policy": NO_BOUNDS},
                [(-2, 7), (2, 10), (5, 13)]
            ),
        )
        for kwargs, bounds in cases:
            split = shape_split((10,), 3, **kwargs)
            self.assertSequenceEqual(
                self._tile_list(bounds),
                split.tolist(),
                msg="kwargs=%r" % (kwargs,)
            )

    def test_calculate_split_with_halo_2d(self):
        """
        Test case for :func:`array_split.split.shape_split` with the :samp:`halo`
        and :samp:`tile_bounds_policy` parameters, for a 2D array shape.
        """
        no_halo = self._tile_grid([(0, 5), (5, 10), (10, 15)], [(0, 5), (5, 9), (9, 13)])
        cases = (
            # (keyword arguments, expected split)
            ({"halo": 0}, no_halo),
            ({"halo": (0, 0)}, no_halo),
            ({"halo": [[0, 0], [0, 0]]}, no_halo),
            ({"halo": [[0, 0], [0, 0]], "tile_bounds_policy": ARRAY_BOUNDS}, no_halo),
            ({"halo": [[0, 0], [0, 0]], "tile_bounds_policy": NO_BOUNDS}, no_halo),
            # ARRAY_BOUNDS clips halo'd tiles to the array extent.
            (
                {"halo": 1, "tile_bounds_policy": ARRAY_BOUNDS},
                self._tile_grid([(0, 6), (4, 11), (9, 15)], [(0, 6), (4, 10), (8, 13)])
            ),
            (
                {"halo": (2, 3), "tile_bounds_policy": ARRAY_BOUNDS},
                self._tile_grid([(0, 7), (3, 12), (8, 15)], [(0, 8), (2, 12), (6, 13)])
            ),
            (
                {"halo": [[1, 2], [2, 3]], "tile_bounds_policy": ARRAY_BOUNDS},
                self._tile_grid([(0, 7), (4, 12), (9, 15)], [(0, 8), (3, 12), (7, 13)])
            ),
            # NO_BOUNDS lets halo'd tiles extend past the array extent.
            (
                {"halo": 1, "tile_bounds_policy": NO_BOUNDS},
                self._tile_grid([(-1, 6), (4, 11), (9, 16)], [(-1, 6), (4, 10), (8, 14)])
            ),
            (
                {"halo": (2, 3), "tile_bounds_policy": NO_BOUNDS},
                self._tile_grid([(-2, 7), (3, 12), (8, 17)], [(-3, 8), (2, 12), (6, 16)])
            ),
            (
                {"halo": [[1, 2], [2, 3]], "tile_bounds_policy": NO_BOUNDS},
                self._tile_grid([(-1, 7), (4, 12), (9, 17)], [(-2, 8), (3, 12), (7, 16)])
            ),
        )
        for kwargs, expected in cases:
            split = shape_split((15, 13), axis=[3, 3], **kwargs)
            self.assertSequenceEqual(expected, split.tolist(), msg="kwargs=%r" % (kwargs,))
# Export every name visible at module scope that does not start with an underscore.
__all__ = [name for name in dir() if not name.startswith("_")]
# Delegate to the project's unittest wrapper, passing this module's name;
# NOTE(review): presumably it only runs the tests when __name__ == "__main__" — confirm.
_unittest.main(__name__)