Source code for neuroconv.tools.testing.mock_ttl_signals

import math
from pathlib import Path

import numpy as np
from numpy.typing import DTypeLike
from pydantic import DirectoryPath
from pynwb import NWBHDF5IO, H5DataIO, TimeSeries
from pynwb.testing.mock.file import mock_NWBFile

from ..importing import is_package_installed
from ...utils import ArrayType


def _check_parameter_dtype_consistency(
    parameter_name: str,
    parameter_value: int | float,
    generic_dtype: type,  # Literal[np.integer, np.floating]
):
    """Helper for `generate_mock_ttl_signal` to assert consistency between parameters and expected trace dtype."""
    end_format = "an integer" if generic_dtype == np.integer else "a float"
    assert np.issubdtype(type(parameter_value), generic_dtype), (
        f"If specifying the '{parameter_name}' manually, please ensure it matches the 'dtype'! "
        f"Received '{type(parameter_value).__name__}', should be {end_format}."
    )
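

# Illustrative sketch (not part of the original module): `_check_parameter_dtype_consistency`
# passes silently when the Python type of the value matches the requested generic dtype and
# raises an AssertionError otherwise. The function name `_example_check_dtype_consistency`
# below is hypothetical and exists only to demonstrate the behavior.
def _example_check_dtype_consistency() -> None:
    # Consistent: an integer value checked against np.integer passes silently.
    _check_parameter_dtype_consistency(parameter_name="baseline_mean", parameter_value=37, generic_dtype=np.integer)

    # Inconsistent: a float value checked against np.integer raises an AssertionError.
    try:
        _check_parameter_dtype_consistency(
            parameter_name="baseline_mean", parameter_value=0.0056, generic_dtype=np.integer
        )
    except AssertionError as exception:
        print(exception)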


def generate_mock_ttl_signal(
    signal_duration: float = 7.0,
    ttl_times: ArrayType | None = None,
    ttl_duration: float = 1.0,
    sampling_frequency_hz: float = 25_000.0,
    dtype: DTypeLike = "int16",
    baseline_mean: int | float | None = None,
    signal_mean: int | float | None = None,
    channel_noise: int | float | None = None,
    random_seed: int | None = 0,
) -> np.ndarray:
    """
    Generate a synthetic signal of TTL pulses similar to those seen in .nidq.bin files using SpikeGLX.

    Parameters
    ----------
    signal_duration : float, default: 7.0
        The number of seconds to simulate.
    ttl_times : array of floats, optional
        The times within the `signal_duration` to trigger the TTL pulse.
        In conjunction with the `ttl_duration`, these must produce disjoint 'on' intervals.
        The default generates a periodic 1 second on, 1 second off pattern.
    ttl_duration : float, default: 1.0
        How long the TTL pulse stays in the 'on' state when triggered, in seconds.
        In conjunction with the `ttl_times`, these must produce disjoint 'on' intervals.
    sampling_frequency_hz : float, default: 25,000.0
        The sampling frequency of the signal in Hz.
        The default is 25000 Hz; similar to that of typical .nidq.bin files.
    dtype : numpy data type or one of its accepted string inputs, default: "int16"
        The data type of the trace.
        Must match the data type of `baseline_mean`, `signal_mean`, and `channel_noise`, if any of those are specified.
        Recommended to be int16 for maximum efficiency, but can also be any size float to represent voltage scalings.
    baseline_mean : integer or float, depending on specified 'dtype', optional
        The average value for the baseline; usually around 0 Volts.
        The default is approximately 0.005645752 Volts, estimated from a real example of a TTL pulse in a .nidq.bin file.
    signal_mean : integer or float, optional
        Type depends on specified 'dtype'.
        The average value for the signal; usually around 5 Volts.
        The default is approximately 4.980773925 Volts, estimated from a real example of a TTL pulse in a .nidq.bin file.
    channel_noise : integer or float, optional
        Type depends on specified 'dtype'.
        The standard deviation of white noise in the channel.
        The default is approximately 0.002288818 Volts, estimated from a real example of a TTL pulse in a .nidq.bin file.
    random_seed : int or None, default: 0
        The seed to set for the numpy random number generator.
        Set to None to choose the seed randomly.
        The default is kept at 0 for generating reproducible outputs.

    Returns
    -------
    trace : numpy.ndarray
        The synthetic trace representing a channel with TTL pulses.
    """
    dtype = np.dtype(dtype)

    # Default values estimated from real files
    baseline_mean_int16_default = 37
    signal_mean_int16_default = 32642
    channel_noise_int16_default = 15
    default_gain_to_volts = 152.58789062 * 1e-6

    if np.issubdtype(dtype, np.unsignedinteger):
        # If data type is an unsigned integer, increment the signed default values by the midpoint of the unsigned range
        shift = math.floor(np.iinfo(dtype).max / 2)
        baseline_mean_int16_default += shift
        signal_mean_int16_default += shift

    if np.issubdtype(dtype, np.integer):
        baseline_mean = baseline_mean or baseline_mean_int16_default
        signal_mean = signal_mean or signal_mean_int16_default
        channel_noise = channel_noise or channel_noise_int16_default
        generic_dtype = np.integer
    else:
        baseline_mean = baseline_mean or baseline_mean_int16_default * default_gain_to_volts
        signal_mean = signal_mean or signal_mean_int16_default * default_gain_to_volts
        channel_noise = channel_noise or channel_noise_int16_default * default_gain_to_volts
        generic_dtype = np.floating

    parameters_to_check = dict(baseline_mean=baseline_mean, signal_mean=signal_mean, channel_noise=channel_noise)
    for parameter_name, parameter_value in parameters_to_check.items():
        _check_parameter_dtype_consistency(
            parameter_name=parameter_name, parameter_value=parameter_value, generic_dtype=generic_dtype
        )

    np.random.seed(seed=random_seed)
    num_frames = np.ceil(signal_duration * sampling_frequency_hz).astype(int)
    trace = (np.random.randn(num_frames) * channel_noise + baseline_mean).astype(dtype)

    if ttl_times is not None:
        ttl_times = np.array(ttl_times)
    else:
        ttl_times = np.arange(start=1.0, stop=signal_duration, step=2.0)

    assert len(ttl_times) == 1 or not any(  # np.diff errors out when len(ttl_times) < 2
        np.diff(ttl_times) <= ttl_duration
    ), "There are overlapping TTL 'on' intervals! Please specify disjoint on/off periods."

    ttl_start_frames = np.round(ttl_times * sampling_frequency_hz).astype(int)
    num_frames_ttl_duration = np.round(ttl_duration * sampling_frequency_hz).astype(int)
    ttl_intervals = (slice(start, start + num_frames_ttl_duration) for start in ttl_start_frames)
    for ttl_interval in ttl_intervals:
        trace[ttl_interval] += signal_mean

    return trace
def regenerate_test_cases(folder_path: DirectoryPath, regenerate_reference_images: bool = False):  # pragma: no cover
    """
    Regenerate the test cases of the file included in the main testing suite, which is frozen between breaking changes.

    Parameters
    ----------
    folder_path : DirectoryPath
        Folder to save the resulting NWB file in.
        For use in the testing suite, this must be the '/test_testing/test_mock_ttl/' subfolder adjacent to
        the 'test_mock_ttl.py' file.
    regenerate_reference_images : bool
        If true, uses the kaleido package with plotly (you may need to install both) to regenerate the images
        used as references in the documentation.
    """
    folder_path = Path(folder_path)
    if regenerate_reference_images:
        assert is_package_installed("plotly") and is_package_installed("kaleido"), (
            "To regenerate the reference images, you must install both plotly and kaleido!"
        )
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots

        image_file_path = folder_path / "example_ttl_reference.png"

    nwbfile_path = folder_path / "mock_ttl_examples.nwb"
    compression_options = dict(compression="gzip", compression_opts=9)

    unit = "Volts"
    rate = 1000.0  # For non-default series to produce less data
    nwbfile = mock_NWBFile()

    # Test Case 1: Default
    default_ttl_signal = generate_mock_ttl_signal()
    nwbfile.add_acquisition(
        TimeSeries(
            name="DefaultTTLSignal",
            unit=unit,
            rate=25000.0,
            data=H5DataIO(data=default_ttl_signal, chunks=default_ttl_signal.shape, **compression_options),
        )
    )

    non_default_series = dict()

    # Test Case 2: Irregular short pulses
    irregular_short_pulses = generate_mock_ttl_signal(
        signal_duration=2.5, ttl_times=[0.22, 1.37], ttl_duration=0.25, sampling_frequency_hz=rate
    )
    non_default_series.update(IrregularShortPulses=irregular_short_pulses)

    # Test Case 3: Non-default regular
    non_default_regular = generate_mock_ttl_signal(
        signal_duration=2.7,
        ttl_times=[0.2, 1.2, 2.2],
        ttl_duration=0.3,
        sampling_frequency_hz=rate,
    )
    non_default_series.update(NonDefaultRegular=non_default_regular)

    # Test Case 4: Non-default regular with adjusted means
    non_default_regular_adjusted_means = generate_mock_ttl_signal(
        signal_duration=2.7,
        ttl_times=[0.2, 1.2, 2.2],
        ttl_duration=0.3,
        sampling_frequency_hz=rate,
        baseline_mean=300,
        signal_mean=20000,
    )
    non_default_series.update(NonDefaultRegularAdjustedMeans=non_default_regular_adjusted_means)

    # Test Case 5: Irregular short pulses with adjusted noise
    irregular_short_pulses_adjusted_noise = generate_mock_ttl_signal(
        signal_duration=2.5,
        ttl_times=[0.22, 1.37],
        ttl_duration=0.25,
        sampling_frequency_hz=rate,
        channel_noise=2,
    )
    non_default_series.update(IrregularShortPulsesAdjustedNoise=irregular_short_pulses_adjusted_noise)

    # Test Case 6: Non-default regular as floats
    non_default_regular_as_floats = generate_mock_ttl_signal(
        signal_duration=2.7,
        ttl_times=[0.2, 1.2, 2.2],
        ttl_duration=0.3,
        sampling_frequency_hz=rate,
        dtype="float32",
    )
    non_default_series.update(NonDefaultRegularFloats=non_default_regular_as_floats)

    # Test Case 7: Non-default regular as floats with adjusted means and noise (which are also, then, floats)
    non_default_regular_as_floats_adjusted_means_and_noise = generate_mock_ttl_signal(
        signal_duration=2.7,
        ttl_times=[0.2, 1.2, 2.2],
        ttl_duration=0.3,
        sampling_frequency_hz=rate,
        dtype="float32",
        baseline_mean=1.1,
        signal_mean=7.2,
        channel_noise=0.4,
    )
    non_default_series.update(FloatsAdjustedMeansAndNoise=non_default_regular_as_floats_adjusted_means_and_noise)

    # Test Case 8: Non-default regular as uint16
    non_default_regular_as_uint16 = generate_mock_ttl_signal(
        signal_duration=2.7,
        ttl_times=[0.2, 1.2, 2.2],
        ttl_duration=0.3,
        sampling_frequency_hz=rate,
        dtype="uint16",
    )
    non_default_series.update(NonDefaultRegularUInt16=non_default_regular_as_uint16)

    # Test Case 9: Irregular short pulses with different seed
    irregular_short_pulses_different_seed = generate_mock_ttl_signal(
        signal_duration=2.5,
        ttl_times=[0.22, 1.37],
        ttl_duration=0.25,
        sampling_frequency_hz=rate,
        random_seed=1,
    )
    non_default_series.update(IrregularShortPulsesDifferentSeed=irregular_short_pulses_different_seed)

    if regenerate_reference_images:
        num_cols = 5
        plot_index = 1
        subplot_titles = ["Default"]
        subplot_titles.extend(list(non_default_series))
        fig = make_subplots(rows=2, cols=num_cols, subplot_titles=subplot_titles)
        fig.add_trace(go.Scatter(y=default_ttl_signal, text="Default"), row=1, col=1)

    for time_series_name, time_series_data in non_default_series.items():
        nwbfile.add_acquisition(
            TimeSeries(
                name=time_series_name,
                unit=unit,
                rate=rate,
                data=H5DataIO(data=time_series_data, chunks=time_series_data.shape, **compression_options),
            )
        )

        if regenerate_reference_images:
            fig.add_trace(
                go.Scatter(y=time_series_data, text=time_series_name),
                row=math.floor(plot_index / num_cols) + 1,
                col=int(plot_index % num_cols) + 1,
            )
            plot_index += 1

    if regenerate_reference_images:
        fig.update_annotations(font_size=6)
        fig.update_layout(showlegend=False)
        fig.update_yaxes(tickfont=dict(size=5))
        fig.update_xaxes(showticklabels=False)
        fig.write_image(file=image_file_path)

    with NWBHDF5IO(path=nwbfile_path, mode="w") as io:
        io.write(nwbfile)
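

# Illustrative usage sketch (not part of the original module): regenerating the frozen test
# cases and, optionally, the reference image. The path below is hypothetical and should point
# at the '/test_testing/test_mock_ttl/' subfolder described in the docstring above.
def _example_regenerate_test_cases() -> None:
    example_folder_path = Path("/path/to/tests/test_testing/test_mock_ttl")  # hypothetical location
    regenerate_test_cases(folder_path=example_folder_path, regenerate_reference_images=False)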