import json
import tempfile
from abc import abstractmethod
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Literal
import numpy as np
import pytest
from hdmf_zarr import NWBZarrIO
from jsonschema.validators import Draft7Validator, validate
from numpy.testing import assert_array_equal
from pynwb import NWBHDF5IO
from pynwb.testing.mock.file import mock_NWBFile
from neuroconv import BaseDataInterface, NWBConverter
from neuroconv.datainterfaces.ecephys.baserecordingextractorinterface import (
BaseRecordingExtractorInterface,
)
from neuroconv.datainterfaces.ecephys.basesortingextractorinterface import (
BaseSortingExtractorInterface,
)
from neuroconv.datainterfaces.ophys.baseimagingextractorinterface import (
BaseImagingExtractorInterface,
)
from neuroconv.datainterfaces.ophys.basesegmentationextractorinterface import (
BaseSegmentationExtractorInterface,
)
from neuroconv.tools.nwb_helpers import (
configure_backend,
get_default_backend_configuration,
)
from neuroconv.utils.json_schema import _NWBMetaDataEncoder
class DataInterfaceTestMixin:
    """
    Generic class for testing DataInterfaces.

    Several of the checks must run in a specific order. ``test_all_conversion_checks`` calls the
    ``check_*`` methods in the appropriate order after the ``setup_interface`` fixture has created a
    fresh ``self.interface`` from ``interface_kwargs``.

    Class Attributes
    ----------------
    data_interface_cls : type[BaseDataInterface]
        The interface class (not an instance) under test.
    interface_kwargs : dict
        Keyword arguments passed to the constructor of ``data_interface_cls``. The fixtures in this mixin
        unpack it with ``**``, so it must be a single dictionary here; subclasses that iterate over
        ``interface_kwargs`` themselves may also accept a list of dictionaries.
    save_directory : Path, optional
        Directory where test files may be saved. Defaults to a temporary directory created when the
        class is defined.
    conversion_options : dict, optional
        Extra keyword arguments forwarded to ``run_conversion`` / ``create_nwbfile``. A falsy value is
        replaced with an empty dict by the ``setup_default_conversion_options`` fixture.
    """

    data_interface_cls: type[BaseDataInterface]
    interface_kwargs: dict
    save_directory: Path = Path(tempfile.mkdtemp())
    conversion_options: dict | None = None
    # Presumably for concrete test classes that also derive from unittest.TestCase (full assertion diffs).
    maxDiff = None

    @staticmethod
    def _ensure_session_start_time(metadata: dict) -> dict:
        """Set ``metadata["NWBFile"]["session_start_time"]`` to the current local time if it is missing.

        The dict is modified in place and also returned for convenience.
        """
        if "session_start_time" not in metadata["NWBFile"]:
            metadata["NWBFile"].update(session_start_time=datetime.now().astimezone())
        return metadata

    def _make_test_converter(self) -> NWBConverter:
        """Build an ``NWBConverter`` holding only the interface under test, registered under the key ``"Test"``."""

        class TestNWBConverter(NWBConverter):
            data_interface_classes = dict(Test=type(self.interface))

        return TestNWBConverter(source_data=dict(Test=self.interface_kwargs))

    @pytest.fixture
    def setup_interface(self, request):
        """Add this as a fixture when you want a freshly created interface in the test.

        Returns
        -------
        tuple
            ``(interface, test_name)``; ``test_name`` is an empty string used when building output file names.
        """
        self.test_name: str = ""
        self.interface = self.data_interface_cls(**self.interface_kwargs)
        return self.interface, self.test_name

    @pytest.fixture(scope="class", autouse=True)
    def setup_default_conversion_options(self, request):
        """Replace a falsy ``conversion_options`` on the test class with an empty dict so it can be ``**``-unpacked."""
        cls = request.cls
        cls.conversion_options = cls.conversion_options or dict()
        return cls.conversion_options

    def test_source_schema_valid(self):
        """The source schema of the interface class is a valid Draft-7 JSON schema."""
        schema = self.data_interface_cls.get_source_schema()
        Draft7Validator.check_schema(schema=schema)

    def test_conversion_options_schema_valid(self, setup_interface):
        """The conversion-options schema of the interface instance is a valid Draft-7 JSON schema."""
        schema = self.interface.get_conversion_options_schema()
        Draft7Validator.check_schema(schema=schema)

    @pytest.mark.parametrize("backend", ["hdf5", "zarr"])
    def test_run_conversion_with_backend(self, setup_interface, tmp_path, backend):
        """Run a standalone conversion with each backend; the Zarr output is additionally read back."""
        nwbfile_path = str(tmp_path / f"conversion_with_backend{backend}-{self.test_name}.nwb")

        metadata = self._ensure_session_start_time(self.interface.get_metadata())
        self.interface.run_conversion(
            nwbfile_path=nwbfile_path,
            overwrite=True,
            metadata=metadata,
            backend=backend,
            **self.conversion_options,
        )

        if backend == "zarr":
            with NWBZarrIO(path=nwbfile_path, mode="r") as io:
                io.read()

    @pytest.mark.parametrize("backend", ["hdf5", "zarr"])
    def test_run_conversion_with_backend_configuration(self, setup_interface, tmp_path, backend):
        """Run a standalone conversion using the default backend configuration derived from an in-memory NWB file."""
        metadata = self._ensure_session_start_time(self.interface.get_metadata())

        nwbfile_path = str(tmp_path / f"conversion_with_backend_configuration{backend}-{self.test_name}.nwb")
        nwbfile = self.interface.create_nwbfile(metadata=metadata, **self.conversion_options)
        backend_configuration = self.interface.get_default_backend_configuration(nwbfile=nwbfile, backend=backend)
        self.interface.run_conversion(
            nwbfile_path=nwbfile_path,
            overwrite=True,
            metadata=metadata,
            backend_configuration=backend_configuration,
            **self.conversion_options,
        )

    def test_all_conversion_checks(self, setup_interface, tmp_path):
        """Run the order-dependent NWBConverter checks, then ``check_read_nwb`` and ``run_custom_checks``."""
        # ``setup_interface`` is requested only for its side effect of creating ``self.interface``.
        nwbfile_path = str(tmp_path / f"{self.__class__.__name__}_{self.test_name}.nwb")
        self.nwbfile_path = nwbfile_path

        self.check_run_conversion_in_nwbconverter_with_backend(nwbfile_path=nwbfile_path, backend="hdf5")
        self.check_run_conversion_in_nwbconverter_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5")

        self.check_read_nwb(nwbfile_path=nwbfile_path)

        # Any extra custom checks to run
        self.run_custom_checks()

    @abstractmethod
    def check_read_nwb(self, nwbfile_path: str):
        """Read the produced NWB file and compare it to the interface.

        NOTE: this mixin does not use ``ABCMeta``, so ``@abstractmethod`` is not enforced; a subclass that does
        not override this method silently gets this no-op.
        """
        pass

    def run_custom_checks(self):
        """Override this in child classes to inject additional custom checks."""
        pass

    def check_run_conversion_in_nwbconverter_with_backend(
        self, nwbfile_path: str, backend: Literal["hdf5", "zarr"] = "hdf5"
    ):
        """Wrap the interface in a one-interface ``NWBConverter`` and run a conversion with ``backend``."""
        converter = self._make_test_converter()

        metadata = self._ensure_session_start_time(converter.get_metadata())
        conversion_options = dict(Test=self.conversion_options)
        converter.run_conversion(
            nwbfile_path=nwbfile_path,
            overwrite=True,
            metadata=metadata,
            backend=backend,
            conversion_options=conversion_options,
        )

    def check_run_conversion_in_nwbconverter_with_backend_configuration(
        self, nwbfile_path: str, backend: Literal["hdf5", "zarr"] = "hdf5"
    ):
        """Wrap the interface in a one-interface ``NWBConverter`` and convert with its default backend configuration."""
        converter = self._make_test_converter()

        metadata = self._ensure_session_start_time(converter.get_metadata())
        conversion_options = dict(Test=self.conversion_options)

        nwbfile = converter.create_nwbfile(metadata=metadata, conversion_options=conversion_options)
        backend_configuration = converter.get_default_backend_configuration(nwbfile=nwbfile, backend=backend)
        converter.run_conversion(
            nwbfile_path=nwbfile_path,
            overwrite=True,
            metadata=metadata,
            backend_configuration=backend_configuration,
            conversion_options=conversion_options,
        )
class TemporalAlignmentMixin:
    """
    Generic class for testing temporal alignment methods.

    ``test_interface_alignment`` runs the ``check_*`` methods in order. Every check except the placeholder
    ``check_nwbfile_temporal_alignment`` rebuilds the interface via ``setUpFreshInterface``, so alignment
    state set by one check does not leak into the next.
    """

    data_interface_cls: type[BaseDataInterface]
    interface_kwargs: dict
    save_directory: Path = Path(tempfile.mkdtemp())
    conversion_options: dict | None = None
    maxDiff = None

    @pytest.fixture
    def setup_interface(self, request):
        """Create a fresh interface from ``interface_kwargs``; returns ``(interface, test_name)``."""
        self.test_name: str = ""
        self.interface = self.data_interface_cls(**self.interface_kwargs)
        return self.interface, self.test_name

    @pytest.fixture(scope="class", autouse=True)
    def setup_default_conversion_options(self, request):
        """Replace a falsy ``conversion_options`` on the test class with an empty dict."""
        cls = request.cls
        cls.conversion_options = cls.conversion_options or dict()
        return cls.conversion_options

    def setUpFreshInterface(self):
        """Protocol for creating a fresh instance of the interface.

        Subclasses that iterate over several sets of constructor arguments store the current set in
        ``self.test_kwargs`` before running the checks. Those arguments are used when present; otherwise
        ``self.interface_kwargs`` is used. This keeps list-valued ``interface_kwargs`` from being ``**``-unpacked.
        """
        constructor_kwargs = getattr(self, "test_kwargs", None)
        if constructor_kwargs is None:
            constructor_kwargs = self.interface_kwargs
        self.interface = self.data_interface_cls(**constructor_kwargs)

    def check_interface_get_original_timestamps(self):
        """
        Just to ensure each interface can call .get_original_timestamps() without an error raising.
        Also, that it always returns non-empty.
        """
        self.setUpFreshInterface()
        original_timestamps = self.interface.get_original_timestamps()

        assert len(original_timestamps) != 0

    def check_interface_get_timestamps(self):
        """
        Just to ensure each interface can call .get_timestamps() without an error raising.
        Also, that it always returns non-empty.
        """
        self.setUpFreshInterface()
        timestamps = self.interface.get_timestamps()

        assert len(timestamps) != 0

    def check_interface_set_aligned_timestamps(self):
        """Ensure that internal mechanisms for the timestamps getter/setter work as expected."""
        self.setUpFreshInterface()
        unaligned_timestamps = self.interface.get_timestamps()

        # Seeded jitter so the aligned values are not a plain constant shift, yet reproducible.
        random_number_generator = np.random.default_rng(seed=0)
        aligned_timestamps = (
            unaligned_timestamps + 1.23 + random_number_generator.random(size=unaligned_timestamps.shape)
        )
        self.interface.set_aligned_timestamps(aligned_timestamps=aligned_timestamps)

        retrieved_aligned_timestamps = self.interface.get_timestamps()
        assert_array_equal(retrieved_aligned_timestamps, aligned_timestamps)

    def check_shift_timestamps_by_start_time(self):
        """Ensure that internal mechanisms for shifting timestamps by a starting time work as expected."""
        self.setUpFreshInterface()
        unaligned_timestamps = self.interface.get_timestamps()

        aligned_starting_time = 1.23
        self.interface.set_aligned_starting_time(aligned_starting_time=aligned_starting_time)

        aligned_timestamps = self.interface.get_timestamps()
        expected_timestamps = unaligned_timestamps + aligned_starting_time
        assert_array_equal(aligned_timestamps, expected_timestamps)

    def check_interface_original_timestamps_inmutability(self):
        """Check aligning the timestamps for the interface does not change the value of .get_original_timestamps()."""
        self.setUpFreshInterface()
        pre_alignment_original_timestamps = self.interface.get_original_timestamps()

        aligned_timestamps = pre_alignment_original_timestamps + 1.23
        self.interface.set_aligned_timestamps(aligned_timestamps=aligned_timestamps)

        post_alignment_original_timestamps = self.interface.get_original_timestamps()
        assert_array_equal(post_alignment_original_timestamps, pre_alignment_original_timestamps)

    def check_nwbfile_temporal_alignment(self):
        """Check the temporally aligned timing information makes it into the NWB file."""
        pass  # TODO: will be easier to add when interface have 'add' methods separate from .run_conversion()

    def test_interface_alignment(self, setup_interface):
        """Run every temporal-alignment check in sequence."""
        # ``setup_interface`` is requested only for its side effect of creating ``self.interface``.
        self.check_interface_get_original_timestamps()
        self.check_interface_get_timestamps()
        self.check_interface_set_aligned_timestamps()
        self.check_shift_timestamps_by_start_time()
        self.check_interface_original_timestamps_inmutability()

        self.check_nwbfile_temporal_alignment()
class AudioInterfaceTestMixin(DataInterfaceTestMixin, TemporalAlignmentMixin):
    """
    A mixin for testing Audio interfaces.

    Both the NWB read-back check and the temporal-alignment test are disabled here because they are
    currently asserted in the downstream testing suite.
    """

    # Currently asserted in the downstream testing suite; could be refactored in future PR
    def check_read_nwb(self, nwbfile_path: str):
        """No-op: read-back assertions live in the downstream testing suite."""
        pass

    # Currently asserted in the downstream testing suite
    def test_interface_alignment(self):
        """No-op override that disables the inherited temporal-alignment test."""
        pass
class VideoInterfaceMixin(DataInterfaceTestMixin, TemporalAlignmentMixin):
    """
    A mixin for testing Video interfaces.

    Video interfaces return one timestamp array per file (segment), so the alignment checks below operate on
    lists of arrays and compare them segment by segment. ``interface_kwargs`` is expected to contain a
    ``"file_paths"`` list.
    """

    def check_read_nwb(self, nwbfile_path: str):
        """The acquisition group contains the video series named after the first file's extension."""
        with NWBHDF5IO(path=nwbfile_path, mode="r", load_namespaces=True) as io:
            nwbfile = io.read()
            video_type = Path(self.interface_kwargs["file_paths"][0]).suffix[1:]
            assert f"Video video_{video_type}" in nwbfile.acquisition

    def check_interface_set_aligned_timestamps(self):
        """Setting per-segment aligned timestamps round-trips through ``get_timestamps``."""
        self.setUpFreshInterface()
        all_unaligned_timestamps = self.interface.get_original_timestamps()

        random_number_generator = np.random.default_rng(seed=0)
        aligned_timestamps = [
            unaligned_timestamps + 1.23 + random_number_generator.random(size=unaligned_timestamps.shape)
            for unaligned_timestamps in all_unaligned_timestamps
        ]
        self.interface.set_aligned_timestamps(aligned_timestamps=aligned_timestamps)

        retrieved_aligned_timestamps = self.interface.get_timestamps()
        # Compare per segment: segments may differ in length, and recent NumPy versions refuse to build an
        # array from a ragged list, so one assert_array_equal over the whole list would error.
        assert len(retrieved_aligned_timestamps) == len(aligned_timestamps), "Number of segments changed."
        for retrieved_segment, expected_segment in zip(retrieved_aligned_timestamps, aligned_timestamps):
            assert_array_equal(retrieved_segment, expected_segment)

    def check_shift_timestamps_by_start_time(self):
        """Shifting by a global starting time adds that offset to every segment's timestamps."""
        self.setUpFreshInterface()

        aligned_starting_time = 1.23
        self.interface.set_aligned_timestamps(aligned_timestamps=self.interface.get_original_timestamps())
        self.interface.set_aligned_starting_time(aligned_starting_time=aligned_starting_time)

        all_aligned_timestamps = self.interface.get_timestamps()
        unaligned_timestamps = self.interface.get_original_timestamps()
        all_expected_timestamps = [timestamps + aligned_starting_time for timestamps in unaligned_timestamps]

        assert len(all_aligned_timestamps) == len(all_expected_timestamps), "Number of segments changed."
        for aligned_timestamps, expected_timestamps in zip(all_aligned_timestamps, all_expected_timestamps):
            assert_array_equal(aligned_timestamps, expected_timestamps)

    def check_set_aligned_segment_starting_times(self):
        """Shifting each segment by its own starting time is reflected in ``get_timestamps``.

        NOTE: ``TemporalAlignmentMixin.test_interface_alignment`` does not call this check; invoke it explicitly.
        """
        self.setUpFreshInterface()

        # One starting time per video file (segment): 0.0, 1.23, 2.46, ...
        number_of_segments = len(self.interface_kwargs["file_paths"])
        aligned_segment_starting_times = [1.23 * file_path_index for file_path_index in range(number_of_segments)]
        self.interface.set_aligned_segment_starting_times(aligned_segment_starting_times=aligned_segment_starting_times)

        all_aligned_timestamps = self.interface.get_timestamps()
        unaligned_timestamps = self.interface.get_original_timestamps()
        assert len(unaligned_timestamps) == number_of_segments, "Expected one timestamp array per file."
        all_expected_timestamps = [
            timestamps + segment_starting_time
            for timestamps, segment_starting_time in zip(unaligned_timestamps, aligned_segment_starting_times)
        ]

        assert len(all_aligned_timestamps) == len(all_expected_timestamps), "Number of segments changed."
        for aligned_timestamps, expected_timestamps in zip(all_aligned_timestamps, all_expected_timestamps):
            assert_array_equal(aligned_timestamps, expected_timestamps)

    def check_interface_original_timestamps_inmutability(self):
        """Aligning the timestamps does not change any segment returned by ``get_original_timestamps``."""
        self.setUpFreshInterface()
        all_pre_alignment_original_timestamps = self.interface.get_original_timestamps()

        all_aligned_timestamps = [
            pre_alignment_original_timestamps + 1.23
            for pre_alignment_original_timestamps in all_pre_alignment_original_timestamps
        ]
        self.interface.set_aligned_timestamps(aligned_timestamps=all_aligned_timestamps)

        all_post_alignment_original_timestamps = self.interface.get_original_timestamps()
        assert len(all_post_alignment_original_timestamps) == len(
            all_pre_alignment_original_timestamps
        ), "Number of segments changed."
        for post_alignment_original_timestamps, pre_alignment_original_timestamps in zip(
            all_post_alignment_original_timestamps, all_pre_alignment_original_timestamps
        ):
            assert_array_equal(post_alignment_original_timestamps, pre_alignment_original_timestamps)
class MedPCInterfaceMixin(DataInterfaceTestMixin, TemporalAlignmentMixin):
    """
    A mixin for testing MedPC interfaces.

    The generic tests inherited from ``DataInterfaceTestMixin`` are disabled and replaced by ``check_*``
    methods that take an explicit ``metadata`` dict, driven from ``test_all_conversion_checks``. The
    timestamp checks take a ``medpc_name_to_info_dict`` and operate on dicts keyed by timestamp name.
    """

    def test_conversion_options_schema_valid(self):
        """Disabled; ``check_conversion_options_schema_valid`` runs inside ``test_all_conversion_checks``."""
        pass

    def test_run_conversion_with_backend(self):
        """Disabled; ``check_run_conversion_with_backend`` runs inside ``test_all_conversion_checks``."""
        pass

    def test_run_conversion_with_backend_configuration(self):
        """Disabled; ``check_run_conversion_with_backend_configuration`` runs inside ``test_all_conversion_checks``."""
        pass

    def check_conversion_options_schema_valid(self):
        """The conversion-options schema of ``self.interface`` is a valid Draft-7 JSON schema."""
        schema = self.interface.get_conversion_options_schema()
        Draft7Validator.check_schema(schema=schema)

    def check_run_conversion_with_backend(
        self, nwbfile_path: str, metadata: dict, backend: Literal["hdf5", "zarr"] = "hdf5"
    ):
        """Run a standalone conversion of ``self.interface`` with the given ``backend``."""
        self.interface.run_conversion(
            nwbfile_path=nwbfile_path,
            overwrite=True,
            metadata=metadata,
            backend=backend,
            **self.conversion_options,
        )

    def check_run_conversion_with_backend_configuration(
        self, nwbfile_path: str, metadata: dict, backend: Literal["hdf5", "zarr"] = "hdf5"
    ):
        """Run a standalone conversion using the default backend configuration of an in-memory NWB file."""
        nwbfile = self.interface.create_nwbfile(metadata=metadata, **self.conversion_options)
        backend_configuration = self.interface.get_default_backend_configuration(nwbfile=nwbfile, backend=backend)
        self.interface.run_conversion(
            nwbfile_path=nwbfile_path,
            overwrite=True,
            metadata=metadata,
            backend_configuration=backend_configuration,
            **self.conversion_options,
        )

    def check_run_conversion_in_nwbconverter_with_backend(
        self, nwbfile_path: str, metadata: dict, backend: Literal["hdf5", "zarr"] = "hdf5"
    ):
        """Wrap the interface (built from ``self.test_kwargs``) in a one-interface ``NWBConverter`` and convert."""

        class TestNWBConverter(NWBConverter):
            data_interface_classes = dict(Test=type(self.interface))

        # ``self.test_kwargs`` is normally one dict; fall back to its first entry if it is a list.
        test_kwargs = self.test_kwargs[0] if isinstance(self.test_kwargs, list) else self.test_kwargs
        source_data = dict(Test=test_kwargs)
        converter = TestNWBConverter(source_data=source_data)

        conversion_options = dict(Test=self.conversion_options)
        converter.run_conversion(
            nwbfile_path=nwbfile_path,
            overwrite=True,
            metadata=metadata,
            backend=backend,
            conversion_options=conversion_options,
        )

    def check_run_conversion_in_nwbconverter_with_backend_configuration(
        self, nwbfile_path: str, metadata: dict, backend: Literal["hdf5", "zarr"] = "hdf5"
    ):
        """Like the NWBConverter backend check, but converting with the converter's default backend configuration."""

        class TestNWBConverter(NWBConverter):
            data_interface_classes = dict(Test=type(self.interface))

        test_kwargs = self.test_kwargs[0] if isinstance(self.test_kwargs, list) else self.test_kwargs
        source_data = dict(Test=test_kwargs)
        converter = TestNWBConverter(source_data=source_data)

        conversion_options = dict(Test=self.conversion_options)
        nwbfile = converter.create_nwbfile(metadata=metadata, conversion_options=conversion_options)
        backend_configuration = converter.get_default_backend_configuration(nwbfile=nwbfile, backend=backend)
        converter.run_conversion(
            nwbfile_path=nwbfile_path,
            overwrite=True,
            metadata=metadata,
            backend_configuration=backend_configuration,
            conversion_options=conversion_options,
        )

    def test_all_conversion_checks(self, metadata: dict):
        """Run every conversion check once per set of constructor arguments in ``interface_kwargs``.

        ``metadata`` is presumably supplied as a pytest fixture by the concrete test module — confirm.

        NOTE(review): ``self.subTest`` is a ``unittest.TestCase`` API, and ``check_metadata_schema_valid``,
        ``check_metadata``, ``check_no_metadata_mutation`` and ``check_configure_backend_for_equivalent_nwbfiles``
        are not defined on this mixin or its visible bases; presumably the concrete test class provides
        them — verify.
        """
        interface_kwargs = self.interface_kwargs
        if isinstance(interface_kwargs, dict):
            interface_kwargs = [interface_kwargs]
        for num, kwargs in enumerate(interface_kwargs):
            with self.subTest(str(num)):
                self.case = num
                self.test_kwargs = kwargs
                self.interface = self.data_interface_cls(**self.test_kwargs)

                self.check_metadata_schema_valid()
                self.check_conversion_options_schema_valid()
                self.check_metadata()
                self.nwbfile_path = str(self.save_directory / f"{self.__class__.__name__}_{num}.nwb")

                self.check_no_metadata_mutation(metadata=metadata)

                self.check_configure_backend_for_equivalent_nwbfiles(metadata=metadata)

                self.check_run_conversion_in_nwbconverter_with_backend(
                    nwbfile_path=self.nwbfile_path, metadata=metadata, backend="hdf5"
                )
                self.check_run_conversion_in_nwbconverter_with_backend_configuration(
                    nwbfile_path=self.nwbfile_path, metadata=metadata, backend="hdf5"
                )

                self.check_run_conversion_with_backend(
                    nwbfile_path=self.nwbfile_path, metadata=metadata, backend="hdf5"
                )
                self.check_run_conversion_with_backend_configuration(
                    nwbfile_path=self.nwbfile_path, metadata=metadata, backend="hdf5"
                )

                self.check_read_nwb(nwbfile_path=self.nwbfile_path)

                # TODO: enable when all H5DataIO prewraps are gone
                # self.nwbfile_path = str(self.save_directory / f"{self.__class__.__name__}_{num}.nwb.zarr")
                # self.check_run_conversion(nwbfile_path=self.nwbfile_path, backend="zarr")
                # self.check_run_conversion_custom_backend(nwbfile_path=self.nwbfile_path, backend="zarr")
                # self.check_basic_zarr_read(nwbfile_path=self.nwbfile_path)

                # Any extra custom checks to run
                self.run_custom_checks()

    def check_interface_get_original_timestamps(self, medpc_name_to_info_dict: dict):
        """
        Just to ensure each interface can call .get_original_timestamps() without an error raising.
        Also, that it always returns non-empty.
        """
        self.setUpFreshInterface()
        original_timestamps_dict = self.interface.get_original_timestamps(
            medpc_name_to_info_dict=medpc_name_to_info_dict
        )
        # Only the names listed in the interface's "aligned_timestamp_names" source argument are checked.
        for name in self.interface.source_data["aligned_timestamp_names"]:
            original_timestamps = original_timestamps_dict[name]
            assert len(original_timestamps) != 0, f"Timestamps for {name} are empty."

    def check_interface_get_timestamps(self):
        """
        Just to ensure each interface can call .get_timestamps() without an error raising.
        Also, that it always returns non-empty.
        """
        self.setUpFreshInterface()
        timestamps_dict = self.interface.get_timestamps()
        for name, timestamps in timestamps_dict.items():
            assert len(timestamps) != 0, f"Timestamps for {name} are empty."

    def check_interface_set_aligned_timestamps(self, medpc_name_to_info_dict: dict):
        """Ensure that internal mechanisms for the timestamps getter/setter work as expected."""
        self.setUpFreshInterface()
        unaligned_timestamps_dict = self.interface.get_original_timestamps(
            medpc_name_to_info_dict=medpc_name_to_info_dict
        )

        random_number_generator = np.random.default_rng(seed=0)
        aligned_timestamps_dict = {}
        for name, unaligned_timestamps in unaligned_timestamps_dict.items():
            aligned_timestamps = (
                unaligned_timestamps + 1.23 + random_number_generator.random(size=unaligned_timestamps.shape)
            )
            aligned_timestamps_dict[name] = aligned_timestamps
        self.interface.set_aligned_timestamps(aligned_timestamps_dict=aligned_timestamps_dict)

        retrieved_aligned_timestamps = self.interface.get_timestamps()
        for name, aligned_timestamps in aligned_timestamps_dict.items():
            assert_array_equal(retrieved_aligned_timestamps[name], aligned_timestamps)

    def check_shift_timestamps_by_start_time(self, medpc_name_to_info_dict: dict):
        """Ensure that internal mechanisms for shifting timestamps by a starting time work as expected."""
        self.setUpFreshInterface()
        unaligned_timestamps_dict = self.interface.get_original_timestamps(
            medpc_name_to_info_dict=medpc_name_to_info_dict
        )

        aligned_starting_time = 1.23
        self.interface.set_aligned_starting_time(
            aligned_starting_time=aligned_starting_time,
            medpc_name_to_info_dict=medpc_name_to_info_dict,
        )

        aligned_timestamps = self.interface.get_timestamps()
        expected_timestamps_dict = {
            name: unaligned_timestamps + aligned_starting_time
            for name, unaligned_timestamps in unaligned_timestamps_dict.items()
        }
        for name, expected_timestamps in expected_timestamps_dict.items():
            assert_array_equal(aligned_timestamps[name], expected_timestamps)

    def check_interface_original_timestamps_inmutability(self, medpc_name_to_info_dict: dict):
        """Check aligning the timestamps for the interface does not change the value of .get_original_timestamps()."""
        self.setUpFreshInterface()
        pre_alignment_original_timestamps_dict = self.interface.get_original_timestamps(
            medpc_name_to_info_dict=medpc_name_to_info_dict
        )

        aligned_timestamps_dict = {
            name: pre_alignment_og_timestamps + 1.23
            for name, pre_alignment_og_timestamps in pre_alignment_original_timestamps_dict.items()
        }
        self.interface.set_aligned_timestamps(aligned_timestamps_dict=aligned_timestamps_dict)

        post_alignment_original_timestamps_dict = self.interface.get_original_timestamps(
            medpc_name_to_info_dict=medpc_name_to_info_dict
        )
        # Loop variable is a single array; do not rebind the dict being iterated.
        for name, post_alignment_original_timestamps in post_alignment_original_timestamps_dict.items():
            assert_array_equal(post_alignment_original_timestamps, pre_alignment_original_timestamps_dict[name])

    def test_interface_alignment(self, medpc_name_to_info_dict: dict):
        """Run every temporal-alignment check once per set of constructor arguments in ``interface_kwargs``.

        ``medpc_name_to_info_dict`` is presumably supplied as a pytest fixture by the concrete test module — confirm.
        NOTE(review): ``self.subTest`` is a ``unittest.TestCase`` API; verify the concrete class provides it.
        """
        interface_kwargs = self.interface_kwargs
        if isinstance(interface_kwargs, dict):
            interface_kwargs = [interface_kwargs]
        for num, kwargs in enumerate(interface_kwargs):
            with self.subTest(str(num)):
                self.case = num
                self.test_kwargs = kwargs
                self.check_interface_get_original_timestamps(medpc_name_to_info_dict=medpc_name_to_info_dict)
                self.check_interface_get_timestamps()
                self.check_interface_set_aligned_timestamps(medpc_name_to_info_dict=medpc_name_to_info_dict)
                self.check_shift_timestamps_by_start_time(medpc_name_to_info_dict=medpc_name_to_info_dict)
                self.check_interface_original_timestamps_inmutability(medpc_name_to_info_dict=medpc_name_to_info_dict)
                self.check_nwbfile_temporal_alignment()
class MiniscopeImagingInterfaceMixin(DataInterfaceTestMixin, TemporalAlignmentMixin):
    """
    A mixin for testing Miniscope Imaging interfaces.

    Concrete test classes must define ``device_name``, ``imaging_plane_name`` and ``photon_series_name``;
    ``check_read_nwb`` looks these up in the written file.
    """

    def check_read_nwb(self, nwbfile_path: str):
        """Check the Miniscope device, imaging plane and OnePhotonSeries written to the NWB file.

        The expected data shape ``(15, 752, 480)``, dtype ``uint8`` and 15 timestamps are hard-coded, so this
        check presumably targets one specific test dataset.
        """
        # Imported inside the method, presumably to keep ndx-miniscope optional for the rest of the module.
        from ndx_miniscope import Miniscope

        with NWBHDF5IO(nwbfile_path, "r") as io:
            nwbfile = io.read()

            assert self.device_name in nwbfile.devices
            device = nwbfile.devices[self.device_name]
            assert isinstance(device, Miniscope)

            imaging_plane = nwbfile.imaging_planes[self.imaging_plane_name]
            assert imaging_plane.device.name == self.device_name

            # Check OnePhotonSeries
            assert self.photon_series_name in nwbfile.acquisition
            one_photon_series = nwbfile.acquisition[self.photon_series_name]
            assert one_photon_series.unit == "px"
            assert one_photon_series.data.shape == (15, 752, 480)
            assert one_photon_series.data.dtype == np.uint8
            # Timing is stored as explicit timestamps rather than a rate/starting frame.
            assert one_photon_series.rate is None
            assert one_photon_series.starting_frame is None
            assert one_photon_series.timestamps.shape == (15,)

            interface_times = self.interface.get_original_timestamps()
            assert_array_equal(one_photon_series.timestamps, interface_times)
class TDTFiberPhotometryInterfaceMixin(DataInterfaceTestMixin, TemporalAlignmentMixin):
    """Mixin for testing TDT Fiber Photometry interfaces.

    The generic tests inherited from ``DataInterfaceTestMixin`` are disabled and replaced by ``check_*``
    methods driven from ``test_all_conversion_checks``. Timestamp getters and setters take a ``t1``/``t2``
    time window, read from ``conversion_options``, and return dicts keyed by stream name.
    """

    def test_no_metadata_mutation(self):
        """Disabled for TDT fiber photometry interfaces."""
        pass

    def test_conversion_options_schema_valid(self):
        """Disabled; ``check_conversion_options_schema_valid`` runs inside ``test_all_conversion_checks``."""
        pass

    def test_run_conversion_with_backend(self):
        """Disabled; ``check_run_conversion_with_backend`` runs inside ``test_all_conversion_checks``."""
        pass

    def test_run_conversion_with_backend_configuration(self):
        """Disabled; ``check_run_conversion_with_backend_configuration`` runs inside ``test_all_conversion_checks``."""
        pass

    def _get_t1_t2(self) -> tuple[float, float]:
        """Return the ``(t1, t2)`` time window from ``conversion_options``; each bound defaults to 0.0."""
        return self.conversion_options.get("t1", 0.0), self.conversion_options.get("t2", 0.0)

    def check_conversion_options_schema_valid(self):
        """The conversion-options schema of ``self.interface`` is a valid Draft-7 JSON schema."""
        schema = self.interface.get_conversion_options_schema()
        Draft7Validator.check_schema(schema=schema)

    def check_run_conversion_with_backend(
        self, nwbfile_path: str, metadata: dict, backend: Literal["hdf5", "zarr"] = "hdf5"
    ):
        """Run a standalone conversion of ``self.interface`` with the given ``backend``."""
        self.interface.run_conversion(
            nwbfile_path=nwbfile_path,
            overwrite=True,
            metadata=metadata,
            backend=backend,
            **self.conversion_options,
        )

    def check_run_conversion_with_backend_configuration(
        self, nwbfile_path: str, metadata: dict, backend: Literal["hdf5", "zarr"] = "hdf5"
    ):
        """Run a standalone conversion using the default backend configuration of an in-memory NWB file."""
        nwbfile = self.interface.create_nwbfile(metadata=metadata, **self.conversion_options)
        backend_configuration = self.interface.get_default_backend_configuration(nwbfile=nwbfile, backend=backend)
        self.interface.run_conversion(
            nwbfile_path=nwbfile_path,
            metadata=metadata,
            overwrite=True,
            backend_configuration=backend_configuration,
            **self.conversion_options,
        )

    def check_run_conversion_in_nwbconverter_with_backend(
        self, nwbfile_path: str, metadata: dict, backend: Literal["hdf5", "zarr"] = "hdf5"
    ):
        """Wrap the interface (built from ``self.test_kwargs``) in a one-interface ``NWBConverter`` and convert."""

        class TestNWBConverter(NWBConverter):
            data_interface_classes = dict(Test=type(self.interface))

        # ``self.test_kwargs`` is normally one dict; fall back to its first entry if it is a list.
        test_kwargs = self.test_kwargs[0] if isinstance(self.test_kwargs, list) else self.test_kwargs
        source_data = dict(Test=test_kwargs)
        converter = TestNWBConverter(source_data=source_data)

        conversion_options = dict(Test=self.conversion_options)
        converter.run_conversion(
            nwbfile_path=nwbfile_path,
            overwrite=True,
            metadata=metadata,
            backend=backend,
            conversion_options=conversion_options,
        )

    def check_run_conversion_in_nwbconverter_with_backend_configuration(
        self, nwbfile_path: str, metadata: dict, backend: Literal["hdf5", "zarr"] = "hdf5"
    ):
        """Like the NWBConverter backend check, but converting with the converter's default backend configuration."""

        class TestNWBConverter(NWBConverter):
            data_interface_classes = dict(Test=type(self.interface))

        test_kwargs = self.test_kwargs[0] if isinstance(self.test_kwargs, list) else self.test_kwargs
        source_data = dict(Test=test_kwargs)
        converter = TestNWBConverter(source_data=source_data)

        conversion_options = dict(Test=self.conversion_options)
        nwbfile = converter.create_nwbfile(metadata=metadata, conversion_options=conversion_options)
        backend_configuration = converter.get_default_backend_configuration(nwbfile=nwbfile, backend=backend)
        converter.run_conversion(
            nwbfile_path=nwbfile_path,
            overwrite=True,
            metadata=metadata,
            backend_configuration=backend_configuration,
            conversion_options=conversion_options,
        )

    def test_all_conversion_checks(self, metadata: dict):
        """Run every conversion check once per set of constructor arguments in ``interface_kwargs``.

        ``metadata`` is presumably supplied as a pytest fixture by the concrete test module — confirm.

        NOTE(review): ``self.subTest`` is a ``unittest.TestCase`` API, and ``check_metadata_schema_valid``,
        ``check_metadata``, ``check_no_metadata_mutation`` and ``check_configure_backend_for_equivalent_nwbfiles``
        are not defined on this mixin or its visible bases; presumably the concrete test class provides
        them — verify.
        """
        interface_kwargs = self.interface_kwargs
        if isinstance(interface_kwargs, dict):
            interface_kwargs = [interface_kwargs]
        for num, kwargs in enumerate(interface_kwargs):
            with self.subTest(str(num)):
                self.case = num
                self.test_kwargs = kwargs
                self.interface = self.data_interface_cls(**self.test_kwargs)

                self.check_metadata_schema_valid()
                self.check_conversion_options_schema_valid()
                self.check_metadata()
                self.nwbfile_path = str(self.save_directory / f"{self.__class__.__name__}_{num}.nwb")

                self.check_no_metadata_mutation(metadata=metadata)

                self.check_configure_backend_for_equivalent_nwbfiles(metadata=metadata)

                self.check_run_conversion_in_nwbconverter_with_backend(
                    nwbfile_path=self.nwbfile_path, metadata=metadata, backend="hdf5"
                )
                self.check_run_conversion_in_nwbconverter_with_backend_configuration(
                    nwbfile_path=self.nwbfile_path, metadata=metadata, backend="hdf5"
                )

                self.check_run_conversion_with_backend(
                    nwbfile_path=self.nwbfile_path, metadata=metadata, backend="hdf5"
                )
                self.check_run_conversion_with_backend_configuration(
                    nwbfile_path=self.nwbfile_path, metadata=metadata, backend="hdf5"
                )

                self.check_read_nwb(nwbfile_path=self.nwbfile_path)

                # TODO: enable when all H5DataIO prewraps are gone
                # self.nwbfile_path = str(self.save_directory / f"{self.__class__.__name__}_{num}.nwb.zarr")
                # self.check_run_conversion(nwbfile_path=self.nwbfile_path, backend="zarr")
                # self.check_run_conversion_custom_backend(nwbfile_path=self.nwbfile_path, backend="zarr")
                # self.check_basic_zarr_read(nwbfile_path=self.nwbfile_path)

                # Any extra custom checks to run
                self.run_custom_checks()

    def check_interface_get_original_timestamps(self):
        """
        Just to ensure each interface can call .get_original_timestamps() without an error raising.
        Also, that it always returns non-empty.
        """
        self.setUpFreshInterface()
        t1, t2 = self._get_t1_t2()
        stream_name_to_timestamps = self.interface.get_original_timestamps(t1=t1, t2=t2)
        for stream_name, timestamps in stream_name_to_timestamps.items():
            assert len(timestamps) != 0, f"Timestamps for {stream_name} are empty."

    def check_interface_get_timestamps(self):
        """
        Just to ensure each interface can call .get_timestamps() without an error raising.
        Also, that it always returns non-empty.
        """
        self.setUpFreshInterface()
        t1, t2 = self._get_t1_t2()
        stream_name_to_timestamps = self.interface.get_timestamps(t1=t1, t2=t2)
        for stream_name, timestamps in stream_name_to_timestamps.items():
            assert len(timestamps) != 0, f"Timestamps for {stream_name} are empty."

    def check_interface_set_aligned_timestamps(self):
        """Ensure that internal mechanisms for the timestamps getter/setter work as expected."""
        t1, t2 = self._get_t1_t2()
        self.setUpFreshInterface()
        unaligned_stream_name_to_timestamps = self.interface.get_original_timestamps(t1=t1, t2=t2)

        random_number_generator = np.random.default_rng(seed=0)
        aligned_stream_name_to_timestamps = {}
        for stream_name, unaligned_timestamps in unaligned_stream_name_to_timestamps.items():
            aligned_timestamps = (
                unaligned_timestamps + 1.23 + random_number_generator.random(size=unaligned_timestamps.shape)
            )
            aligned_stream_name_to_timestamps[stream_name] = aligned_timestamps
        self.interface.set_aligned_timestamps(stream_name_to_aligned_timestamps=aligned_stream_name_to_timestamps)

        # Every aligned sample was shifted by 1.23 + U[0, 1), i.e. by 1.23 to 2.23 s, so widen a non-zero window
        # by the minimum shift at the start and the maximum shift at the end. A bound of 0.0 is left as-is
        # (presumably it means "unbounded" for TDT — confirm).
        t1 += 1.23 if t1 != 0.0 else 0.0
        t2 += 2.23 if t2 != 0.0 else 0.0
        retrieved_aligned_stream_name_to_timestamps = self.interface.get_timestamps(t1=t1, t2=t2)
        for stream_name, aligned_timestamps in aligned_stream_name_to_timestamps.items():
            retrieved_aligned_timestamps = retrieved_aligned_stream_name_to_timestamps[stream_name]
            assert_array_equal(retrieved_aligned_timestamps, aligned_timestamps)

    def check_shift_timestamps_by_start_time(self):
        """Ensure that internal mechanisms for shifting timestamps by a starting time work as expected."""
        t1, t2 = self._get_t1_t2()
        self.setUpFreshInterface()
        unaligned_stream_name_to_timestamps = self.interface.get_original_timestamps(t1=t1, t2=t2)

        aligned_starting_time = 1.23
        self.interface.set_aligned_starting_time(aligned_starting_time=aligned_starting_time, t1=t1, t2=t2)

        # Shift a non-zero window by the same offset so it still selects the same samples.
        t1 += aligned_starting_time if t1 != 0.0 else 0.0
        t2 += aligned_starting_time if t2 != 0.0 else 0.0
        aligned_stream_name_to_timestamps = self.interface.get_timestamps(t1=t1, t2=t2)
        expected_timestamps_dict = {
            name: unaligned_timestamps + aligned_starting_time
            for name, unaligned_timestamps in unaligned_stream_name_to_timestamps.items()
        }
        for name, expected_timestamps in expected_timestamps_dict.items():
            timestamps = aligned_stream_name_to_timestamps[name]
            assert_array_equal(timestamps, expected_timestamps)

    def check_interface_original_timestamps_inmutability(self):
        """Check aligning the timestamps for the interface does not change the value of .get_original_timestamps()."""
        t1, t2 = self._get_t1_t2()
        self.setUpFreshInterface()
        pre_alignment_stream_name_to_timestamps = self.interface.get_original_timestamps(t1=t1, t2=t2)

        aligned_stream_name_to_timestamps = {
            name: pre_alignment_timestamps + 1.23
            for name, pre_alignment_timestamps in pre_alignment_stream_name_to_timestamps.items()
        }
        self.interface.set_aligned_timestamps(stream_name_to_aligned_timestamps=aligned_stream_name_to_timestamps)

        post_alignment_stream_name_to_timestamps = self.interface.get_original_timestamps(t1=t1, t2=t2)
        for name, post_alignment_timestamps in post_alignment_stream_name_to_timestamps.items():
            pre_alignment_timestamps = pre_alignment_stream_name_to_timestamps[name]
            assert_array_equal(post_alignment_timestamps, pre_alignment_timestamps)

    def test_interface_alignment(self):
        """Run every temporal-alignment check once per set of constructor arguments in ``interface_kwargs``.

        NOTE(review): ``self.subTest`` is a ``unittest.TestCase`` API; verify the concrete class provides it.
        """
        interface_kwargs = self.interface_kwargs
        if isinstance(interface_kwargs, dict):
            interface_kwargs = [interface_kwargs]
        for num, kwargs in enumerate(interface_kwargs):
            with self.subTest(str(num)):
                self.case = num
                self.test_kwargs = kwargs
                self.check_interface_get_original_timestamps()
                self.check_interface_get_timestamps()
                self.check_interface_set_aligned_timestamps()
                self.check_shift_timestamps_by_start_time()
                self.check_interface_original_timestamps_inmutability()
                self.check_nwbfile_temporal_alignment()
class PoseEstimationInterfaceTestMixin(DataInterfaceTestMixin, TemporalAlignmentMixin):
    """
    Generic class for testing any pose estimation interface.
    """

    def check_read_nwb(self, nwbfile_path: str):
        """Check that pose estimation data can be read back from NWB file.

        Looks in the ``behavior`` processing module for containers that expose ``pose_estimation_series``.
        It then validates every series of the first such container: timestamps and data must be non-empty,
        data must be 2-D, and there must be one data row per timestamp.
        """
        with NWBHDF5IO(nwbfile_path, "r") as io:
            nwbfile = io.read()

            # Check that behavior module exists
            assert "behavior" in nwbfile.processing
            behavior_module = nwbfile.processing["behavior"]

            # Containers are found by duck typing because the container class may vary by interface.
            pose_containers = [
                data_interface
                for data_interface in behavior_module.data_interfaces.values()
                if hasattr(data_interface, "pose_estimation_series")
            ]
            assert len(pose_containers) > 0, "No pose estimation containers found in behavior module"

            # The filter above guarantees the attribute exists; check it is non-empty.
            pose_container = pose_containers[0]
            assert len(pose_container.pose_estimation_series) > 0

            # Check that timestamps and data are properly written for every series.
            for series in pose_container.pose_estimation_series.values():
                assert hasattr(series, "timestamps")
                assert len(series.timestamps) > 0
                assert hasattr(series, "data")
                assert len(series.data) > 0

                # Check data dimensions (should be 2D: time x spatial_dims)
                assert len(series.data.shape) == 2
                assert series.data.shape[0] == len(series.timestamps)