import math
from pathlib import Path
import numpy as np
from numpy.typing import DTypeLike
from pydantic import DirectoryPath
from pynwb import NWBHDF5IO, H5DataIO, TimeSeries
from pynwb.testing.mock.file import mock_NWBFile
from ..importing import is_package_installed
from ...utils import ArrayType
def _check_parameter_dtype_consistency(
parameter_name: str,
parameter_value: int | float,
generic_dtype: type, # Literal[np.integer, np.floating]
):
"""Helper for `generate_mock_ttl_signal` to assert consistency between parameters and expected trace dtype."""
end_format = "an integer" if generic_dtype == np.integer else "a float"
assert np.issubdtype(type(parameter_value), generic_dtype), (
f"If specifying the '{parameter_name}' manually, please ensure it matches the 'dtype'! "
f"Received '{type(parameter_value).__name__}', should be {end_format}."
)
[docs]
def generate_mock_ttl_signal(
signal_duration: float = 7.0,
ttl_times: ArrayType | None = None,
ttl_duration: float = 1.0,
sampling_frequency_hz: float = 25_000.0,
dtype: DTypeLike = "int16",
baseline_mean: int | float | None = None,
signal_mean: int | float | None = None,
channel_noise: int | float | None = None,
random_seed: int | None = 0,
) -> np.ndarray:
"""
Generate a synthetic signal of TTL pulses similar to those seen in .nidq.bin files using SpikeGLX.
Parameters
----------
signal_duration : float, default: 7.0
The number of seconds to simulate.
ttl_times : array of floats, optional
The times within the `signal_duration` to trigger the TTL pulse.
In conjunction with the `ttl_duration`, these must produce disjoint 'on' intervals.
The default generates a periodic 1 second on, 1 second off pattern.
ttl_duration : float, default: 1.0
How long the TTL pulse stays in the 'on' state when triggered, in seconds.
In conjunction with the `ttl_times`, these must produce disjoint 'on' intervals.
sampling_frequency_hz : float, default: 25,000.0
The sampling frequency of the signal in Hz.
The default is 25000 Hz; similar to that of typical .nidq.bin files.
dtype : numpy data type or one of its accepted string input, default: "int16"
The data type of the trace.
Must match the data type of `baseline_mean`, `signal_mean`, and `channel_noise`, if any of those are specified.
Recommended to be int16 for maximum efficiency, but can also be any size float to represent voltage scalings.
baseline_mean : integer or float, depending on specified 'dtype', optional
The average value for the baseline; usually around 0 Volts.
The default is approximately 0.005645752 Volts, estimated from a real example of a TTL pulse in a .nidq.bin file.
signal_mean : integer or float, optional
Type depends on specified 'dtype'. The average value for the signal; usually around 5 Volts.
The default is approximately 4.980773925 Volts, estimated from a real example of a TTL pulse in a .nidq.bin file.
channel_noise : integer or float, optional
Type depends on specified 'dtype'. The standard deviation of white noise in the channel.
The default is approximately 0.002288818 Volts, estimated from a real example of a TTL pulse in a .nidq.bin file.
random_seed : int or None, default: 0
The seed to set for the numpy random number generator.
Set to None to choose the seed randomly.
The default is kept at 0 for generating reproducible outputs.
Returns
-------
trace: numpy.ndarray
The synethic trace representing a channel with TTL pulses.
"""
dtype = np.dtype(dtype)
# Default values estimated from real files
baseline_mean_int16_default = 37
signal_mean_int16_default = 32642
channel_noise_int16_default = 15
default_gain_to_volts = 152.58789062 * 1e-6
if np.issubdtype(dtype, np.unsignedinteger):
# If data type is an unsigned integer, increment the signed default values by the midpoint of the unsigned range
shift = math.floor(np.iinfo(dtype).max / 2)
baseline_mean_int16_default += shift
signal_mean_int16_default += shift
if np.issubdtype(dtype, np.integer):
baseline_mean = baseline_mean or baseline_mean_int16_default
signal_mean = signal_mean or signal_mean_int16_default
channel_noise = channel_noise or channel_noise_int16_default
generic_dtype = np.integer
else:
baseline_mean = baseline_mean or baseline_mean_int16_default * default_gain_to_volts
signal_mean = signal_mean or signal_mean_int16_default * default_gain_to_volts
channel_noise = channel_noise or channel_noise_int16_default * default_gain_to_volts
generic_dtype = np.floating
parameters_to_check = dict(baseline_mean=baseline_mean, signal_mean=signal_mean, channel_noise=channel_noise)
for parameter_name, parameter_value in parameters_to_check.items():
_check_parameter_dtype_consistency(
parameter_name=parameter_name, parameter_value=parameter_value, generic_dtype=generic_dtype
)
np.random.seed(seed=random_seed)
num_frames = np.ceil(signal_duration * sampling_frequency_hz).astype(int)
trace = (np.random.randn(num_frames) * channel_noise + baseline_mean).astype(dtype)
if ttl_times is not None:
ttl_times = np.array(ttl_times)
else:
ttl_times = np.arange(start=1.0, stop=signal_duration, step=2.0)
assert len(ttl_times) == 1 or not any( # np.diff errors out when len(ttl_times) < 2
np.diff(ttl_times) <= ttl_duration
), "There are overlapping TTL 'on' intervals! Please specify disjoint on/off periods."
ttl_start_frames = np.round(ttl_times * sampling_frequency_hz).astype(int)
num_frames_ttl_duration = np.round(ttl_duration * sampling_frequency_hz).astype(int)
ttl_intervals = (slice(start, start + num_frames_ttl_duration) for start in ttl_start_frames)
for ttl_interval in ttl_intervals:
trace[ttl_interval] += signal_mean
return trace
[docs]
def regenerate_test_cases(folder_path: DirectoryPath, regenerate_reference_images: bool = False): # pragma: no cover
"""
Regenerate the test cases of the file included in the main testing suite, which is frozen between breaking changes.
Parameters
----------
folder_path : PathType
Folder to save the resulting NWB file in. For use in the testing suite, this must be the
'/test_testing/test_mock_ttl/' subfolder adjacent to the 'test_mock_tt.py' file.
regenerate_reference_images : bool
If true, uses the kaleido package with plotly (you may need to install both) to regenerate the images used
as references in the documentation.
"""
folder_path = Path(folder_path)
if regenerate_reference_images:
assert is_package_installed("plotly") and is_package_installed("kaleido"), (
"To regenerate the reference images, " "you must install both plotly and kaleido!"
)
import plotly.graph_objects as go
from plotly.subplots import make_subplots
image_file_path = folder_path / "example_ttl_reference.png"
nwbfile_path = folder_path / "mock_ttl_examples.nwb"
compression_options = dict(compression="gzip", compression_opts=9)
unit = "Volts"
rate = 1000.0 # For non-default series to produce less data
nwbfile = mock_NWBFile()
# Test Case 1: Default
default_ttl_signal = generate_mock_ttl_signal()
nwbfile.add_acquisition(
TimeSeries(
name="DefaultTTLSignal",
unit=unit,
rate=25000.0,
data=H5DataIO(data=default_ttl_signal, chunks=default_ttl_signal.shape, **compression_options),
)
)
non_default_series = dict()
# Test Case 2: Irregular short pulses
irregular_short_pulses = generate_mock_ttl_signal(
signal_duration=2.5, ttl_times=[0.22, 1.37], ttl_duration=0.25, sampling_frequency_hz=rate
)
non_default_series.update(IrregularShortPulses=irregular_short_pulses)
# Test Case 3: Non-default regular
non_default_regular = generate_mock_ttl_signal(
signal_duration=2.7,
ttl_times=[0.2, 1.2, 2.2],
ttl_duration=0.3,
sampling_frequency_hz=rate,
)
non_default_series.update(NonDefaultRegular=non_default_regular)
# Test Case 4: Non-default regular with adjusted means
non_default_regular_adjusted_means = generate_mock_ttl_signal(
signal_duration=2.7,
ttl_times=[0.2, 1.2, 2.2],
ttl_duration=0.3,
sampling_frequency_hz=rate,
baseline_mean=300,
signal_mean=20000,
)
non_default_series.update(NonDefaultRegularAdjustedMeans=non_default_regular_adjusted_means)
# Test Case 5: Irregular short pulses with adjusted noise
irregular_short_pulses_adjusted_noise = generate_mock_ttl_signal(
signal_duration=2.5,
ttl_times=[0.22, 1.37],
ttl_duration=0.25,
sampling_frequency_hz=rate,
channel_noise=2,
)
non_default_series.update(IrregularShortPulsesAdjustedNoise=irregular_short_pulses_adjusted_noise)
# Test Case 6: Non-default regular as floats
non_default_regular_as_floats = generate_mock_ttl_signal(
signal_duration=2.7,
ttl_times=[0.2, 1.2, 2.2],
ttl_duration=0.3,
sampling_frequency_hz=rate,
dtype="float32",
)
non_default_series.update(NonDefaultRegularFloats=non_default_regular_as_floats)
# Test Case 7: Non-default regular as floats with adjusted means and noise (which are also, then, floats)
non_default_regular_as_floats_adjusted_means_and_noise = generate_mock_ttl_signal(
signal_duration=2.7,
ttl_times=[0.2, 1.2, 2.2],
ttl_duration=0.3,
sampling_frequency_hz=rate,
dtype="float32",
baseline_mean=1.1,
signal_mean=7.2,
channel_noise=0.4,
)
non_default_series.update(FloatsAdjustedMeansAndNoise=non_default_regular_as_floats_adjusted_means_and_noise)
# Test Case 8: Non-default regular as uint16
non_default_regular_as_uint16 = generate_mock_ttl_signal(
signal_duration=2.7,
ttl_times=[0.2, 1.2, 2.2],
ttl_duration=0.3,
sampling_frequency_hz=rate,
dtype="uint16",
)
non_default_series.update(NonDefaultRegularUInt16=non_default_regular_as_uint16)
# Test Case 9: Irregular short pulses with different seed
irregular_short_pulses_different_seed = generate_mock_ttl_signal(
signal_duration=2.5,
ttl_times=[0.22, 1.37],
ttl_duration=0.25,
sampling_frequency_hz=rate,
random_seed=1,
)
non_default_series.update(IrregularShortPulsesDifferentSeed=irregular_short_pulses_different_seed)
if regenerate_reference_images:
num_cols = 5
plot_index = 1
subplot_titles = ["Default"]
subplot_titles.extend(list(non_default_series))
fig = make_subplots(rows=2, cols=num_cols, subplot_titles=subplot_titles)
fig.add_trace(go.Scatter(y=default_ttl_signal, text="Default"), row=1, col=1)
for time_series_name, time_series_data in non_default_series.items():
nwbfile.add_acquisition(
TimeSeries(
name=time_series_name,
unit=unit,
rate=rate,
data=H5DataIO(data=time_series_data, chunks=time_series_data.shape, **compression_options),
)
)
if regenerate_reference_images:
fig.add_trace(
go.Scatter(y=time_series_data, text=time_series_name),
row=math.floor(plot_index / num_cols) + 1,
col=int(plot_index % num_cols) + 1,
)
plot_index += 1
if regenerate_reference_images:
fig.update_annotations(font_size=6)
fig.update_layout(showlegend=False)
fig.update_yaxes(tickfont=dict(size=5))
fig.update_xaxes(showticklabels=False)
fig.write_image(file=image_file_path)
with NWBHDF5IO(path=nwbfile_path, mode="w") as io:
io.write(nwbfile)