Source code for neuroconv.datainterfaces.ecephys.baserecordingextractorinterface
import warnings
from typing import Literal
import numpy as np
from pynwb import NWBFile
from pynwb.device import Device
from pynwb.ecephys import ElectricalSeries, ElectrodeGroup
from ...baseextractorinterface import BaseExtractorInterface
from ...utils import (
DeepDict,
get_base_schema,
get_schema_from_hdmf_class,
)
[docs]
class BaseRecordingExtractorInterface(BaseExtractorInterface):
"""Parent class for all RecordingExtractorInterfaces."""
keywords = ("extracellular electrophysiology", "voltage", "recording")
def _initialize_extractor(self, interface_kwargs: dict):
"""
Initialize and return the extractor instance for recording interfaces.
Extends the base implementation to also remove the 'es_key' parameter
which is specific to the recording interface, not the extractor.
Also adds 'all_annotations=True' to ensure all metadata is loaded.
Parameters
----------
interface_kwargs : dict
The source data parameters passed to the interface constructor.
Returns
-------
extractor_instance
An initialized recording extractor instance.
"""
self.extractor_kwargs = interface_kwargs.copy()
self.extractor_kwargs.pop("verbose", None)
self.extractor_kwargs.pop("es_key", None)
self.extractor_kwargs["all_annotations"] = True
extractor_class = self.get_extractor_class()
extractor_instance = extractor_class(**self.extractor_kwargs)
return extractor_instance
def __init__(self, verbose: bool = False, es_key: str = "ElectricalSeries", **source_data):
"""
Parameters
----------
verbose : bool, default: False
If True, will print out additional information.
es_key : str, default: "ElectricalSeries"
The key of this ElectricalSeries in the metadata dictionary.
source_data : dict
The key-value pairs of extractor-specific arguments.
"""
super().__init__(**source_data)
self.recording_extractor = self._extractor_instance
property_names = self.recording_extractor.get_property_keys()
# TODO remove this and go and change all the uses of channel_name once spikeinterface > 0.101.0 is released
if "channel_name" not in property_names and "channel_names" in property_names:
channel_names = self.recording_extractor.get_property("channel_names")
self.recording_extractor.set_property("channel_name", channel_names)
self.recording_extractor.delete_property("channel_names")
self.verbose = verbose
self.es_key = es_key
self._number_of_segments = self.recording_extractor.get_num_segments()
[docs]
def get_metadata_schema(self) -> dict:
"""
Compile metadata schema for the RecordingExtractor.
Returns
-------
dict
The metadata schema dictionary containing definitions for Device, ElectrodeGroup,
Electrodes, and optionally ElectricalSeries.
"""
metadata_schema = super().get_metadata_schema()
metadata_schema["properties"]["Ecephys"] = get_base_schema(tag="Ecephys")
metadata_schema["properties"]["Ecephys"]["required"] = ["Device", "ElectrodeGroup"]
metadata_schema["properties"]["Ecephys"]["properties"] = dict(
Device=dict(type="array", minItems=1, items={"$ref": "#/properties/Ecephys/definitions/Device"}),
ElectrodeGroup=dict(
type="array", minItems=1, items={"$ref": "#/properties/Ecephys/definitions/ElectrodeGroup"}
),
Electrodes=dict(
type="array",
minItems=0,
renderForm=False,
items={"$ref": "#/properties/Ecephys/definitions/Electrodes"},
),
)
# Schema definition for arrays
metadata_schema["properties"]["Ecephys"]["definitions"] = dict(
Device=get_schema_from_hdmf_class(Device),
ElectrodeGroup=get_schema_from_hdmf_class(ElectrodeGroup),
Electrodes=dict(
type="object",
additionalProperties=False,
required=["name"],
properties=dict(
name=dict(type="string", description="name of this electrodes column"),
description=dict(type="string", description="description of this electrodes column"),
),
),
)
if self.es_key is not None:
metadata_schema["properties"]["Ecephys"]["properties"].update(
{self.es_key: get_schema_from_hdmf_class(ElectricalSeries)}
)
return metadata_schema
[docs]
def get_metadata(self) -> DeepDict:
metadata = super().get_metadata()
from ...tools.spikeinterface.spikeinterface import _get_group_name
channel_groups_array = _get_group_name(recording=self.recording_extractor)
unique_channel_groups = set(channel_groups_array) if channel_groups_array is not None else ["ElectrodeGroup"]
electrode_metadata = [
dict(name=str(group_id), description="no description", location="unknown", device="DeviceEcephys")
for group_id in unique_channel_groups
]
metadata["Ecephys"] = dict(
Device=[dict(name="DeviceEcephys", description="no description")],
ElectrodeGroup=electrode_metadata,
)
if self.es_key is not None:
metadata["Ecephys"][self.es_key] = dict(
name=self.es_key, description=f"Acquisition traces for the {self.es_key}."
)
return metadata
@property
def channel_ids(self):
"Gets the channel ids of the data."
return self.recording_extractor.get_channel_ids()
[docs]
def get_original_timestamps(self) -> np.ndarray | list[np.ndarray]:
"""
Retrieve the original unaltered timestamps for the data in this interface.
This function should retrieve the data on-demand by re-initializing the IO.
Returns
-------
timestamps: numpy.ndarray or list of numpy.ndarray
The timestamps for the data stream; if the recording has multiple segments, then a list of timestamps is returned.
"""
new_recording = self._initialize_extractor(self.source_data)
if self._number_of_segments == 1:
return new_recording.get_times()
else:
return [
new_recording.get_times(segment_index=segment_index)
for segment_index in range(self._number_of_segments)
]
[docs]
def get_timestamps(self) -> np.ndarray | list[np.ndarray]:
"""
Retrieve the timestamps for the data in this interface.
Returns
-------
timestamps: numpy.ndarray or list of numpy.ndarray
The timestamps for the data stream; if the recording has multiple segments, then a list of timestamps is returned.
"""
if self._number_of_segments == 1:
return self.recording_extractor.get_times()
else:
return [
self.recording_extractor.get_times(segment_index=segment_index)
for segment_index in range(self._number_of_segments)
]
[docs]
def set_aligned_timestamps(self, aligned_timestamps: np.ndarray):
assert (
self._number_of_segments == 1
), "This recording has multiple segments; please use 'align_segment_timestamps' instead."
self.recording_extractor.set_times(times=aligned_timestamps, with_warning=False)
[docs]
def set_aligned_segment_timestamps(self, aligned_segment_timestamps: list[np.ndarray]):
"""
Replace all timestamps for all segments in this interface with those aligned to the common session start time.
Must be in units seconds relative to the common 'session_start_time'.
Parameters
----------
aligned_segment_timestamps : list of numpy.ndarray
The synchronized timestamps for segment of data in this interface.
"""
assert isinstance(
aligned_segment_timestamps, list
), "Recording has multiple segment! Please pass a list of timestamps to align each segment."
assert (
len(aligned_segment_timestamps) == self._number_of_segments
), f"The number of timestamp vectors ({len(aligned_segment_timestamps)}) does not match the number of segments ({self._number_of_segments})!"
for segment_index in range(self._number_of_segments):
self.recording_extractor.set_times(
times=aligned_segment_timestamps[segment_index],
segment_index=segment_index,
with_warning=False,
)
[docs]
def set_aligned_starting_time(self, aligned_starting_time: float):
if self._number_of_segments == 1:
self.set_aligned_timestamps(aligned_timestamps=self.get_timestamps() + aligned_starting_time)
else:
self.set_aligned_segment_timestamps(
aligned_segment_timestamps=[
segment_timestamps + aligned_starting_time for segment_timestamps in self.get_timestamps()
]
)
[docs]
def set_aligned_segment_starting_times(self, aligned_segment_starting_times: list[float]):
"""
Align the starting time for each segment in this interface relative to the common session start time.
Must be in units seconds relative to the common 'session_start_time'.
Parameters
----------
aligned_segment_starting_times : list of floats
The starting time for each segment of data in this interface.
"""
assert len(aligned_segment_starting_times) == self._number_of_segments, (
f"The length of the starting_times ({len(aligned_segment_starting_times)}) does not match the "
"number of segments ({self._number_of_segments})!"
)
if self._number_of_segments == 1:
self.set_aligned_starting_time(aligned_starting_time=aligned_segment_starting_times[0])
else:
aligned_segment_timestamps = [
segment_timestamps + aligned_segment_starting_time
for segment_timestamps, aligned_segment_starting_time in zip(
self.get_timestamps(), aligned_segment_starting_times
)
]
self.set_aligned_segment_timestamps(aligned_segment_timestamps=aligned_segment_timestamps)
[docs]
def set_probe(self, probe: "Probe | ProbeGroup", group_mode: Literal["by_shank", "by_probe"]):
"""
Set the probe information via a ProbeInterface object.
Parameters
----------
probe : probeinterface.Probe or probeinterface.ProbeGroup
The probe object(s). Can be a single Probe or a ProbeGroup containing multiple probes.
group_mode : {'by_shank', 'by_probe'}
How to group the channels for electrode group assignment in the NWB file:
- 'by_probe': Each probe becomes a separate electrode group. For a ProbeGroup with
multiple probes, each probe gets its own group (group 0, 1, 2, etc.). For a single
probe, all channels are assigned to group 0.
- 'by_shank': Each unique combination of probe and shank becomes a separate electrode
group. Requires that shank_ids are defined for all probes. Groups are assigned
sequentially for each unique (probe_index, shank_id) pair.
The resulting groups determine how electrode groups and electrodes are organized
in the NWB file, with each group corresponding to one ElectrodeGroup.
"""
# Set the probe to the recording extractor
self.recording_extractor._set_probes(
probe,
in_place=True,
group_mode=group_mode,
)
# Spike interface sets the "group" property
# But neuroconv allows "group_name" property to override spike interface "group" value
# So we re-set this here to avoid a conflict
self.recording_extractor.set_property("group_name", self.recording_extractor.get_property("group").astype(str))
[docs]
def has_probe(self) -> bool:
"""
Check if the recording extractor has probe information.
Returns
-------
bool
True if the recording extractor has probe information, False otherwise.
"""
return self.recording_extractor.has_probe()
[docs]
def align_by_interpolation(
self,
unaligned_timestamps: np.ndarray,
aligned_timestamps: np.ndarray,
):
if self._number_of_segments == 1:
self.set_aligned_timestamps(
aligned_timestamps=np.interp(x=self.get_timestamps(), xp=unaligned_timestamps, fp=aligned_timestamps)
)
else:
raise NotImplementedError("Multi-segment support for aligning by interpolation has not been added yet.")
[docs]
def add_to_nwbfile(
self,
nwbfile: NWBFile,
metadata: dict | None = None,
*,
stub_test: bool = False,
write_as: Literal["raw", "lfp", "processed"] = "raw",
write_electrical_series: bool = True,
iterator_type: str | None = "v2",
iterator_options: dict | None = None,
iterator_opts: dict | None = None,
always_write_timestamps: bool = False,
):
"""
Primary function for converting raw (unprocessed) RecordingExtractor data to the NWB standard.
Parameters
----------
nwbfile : NWBFile
NWBFile to which the recording information is to be added
metadata : dict, optional
metadata info for constructing the NWB file.
Should be of the format::
metadata['Ecephys']['ElectricalSeries'] = dict(name=my_name, description=my_description)
stub_test : bool, default: False
If True, will truncate the data to run the conversion faster and take up less memory.
write_as : {'raw', 'processed', 'lfp'}, default='raw'
Specifies how to save the trace data in the NWB file. Options are:
- 'raw': Save the data in the acquisition group.
- 'processed': Save the data as FilteredEphys in a processing module.
- 'lfp': Save the data as LFP in a processing module.
write_electrical_series : bool, default: True
Electrical series are written in acquisition. If False, only device, electrode_groups,
and electrodes are written to NWB.
iterator_type : {'v2', None}, default: 'v2'
The type of iterator for chunked data writing.
'v2': Uses iterative write with control over chunking and progress bars.
None: Loads all data into memory before writing (not recommended for large datasets).
iterator_options : dict, optional
Options for controlling iterative write when iterator_type='v2'.
See the `pynwb tutorial on iterative write
<https://pynwb.readthedocs.io/en/stable/tutorials/advanced_io/plot_iterative_write.html#sphx-glr-tutorials-advanced-io-plot-iterative-write-py>`_
for more information on chunked data writing.
Available options:
* buffer_gb : float, default: 1.0
RAM to use for buffering data chunks in GB. Recommended to be as much free RAM as available.
* buffer_shape : tuple, optional
Manual specification of buffer shape. Must be a multiple of chunk_shape along each axis.
Cannot be set if buffer_gb is specified.
* display_progress : bool, default: False
Enable tqdm progress bar during data write.
* progress_bar_options : dict, optional
Additional options passed to tqdm progress bar.
See https://github.com/tqdm/tqdm#parameters for all tqdm options.
Note: To configure chunk size and compression, use the backend configuration system
via ``get_default_backend_configuration()`` and ``configure_backend()`` after calling
this method. See the backend configuration documentation for details.
iterator_opts : dict, optional
Deprecated. Use 'iterator_options' instead.
always_write_timestamps : bool, default: False
Set to True to always write timestamps.
By default (False), the function checks if the timestamps are uniformly sampled, and if so, stores the data
using a regular sampling rate instead of explicit timestamps. If set to True, timestamps will be written
explicitly, regardless of whether the sampling rate is uniform.
"""
from ...tools.spikeinterface import (
_stub_recording,
add_recording_metadata_to_nwbfile,
add_recording_to_nwbfile,
)
# Handle deprecated iterator_opts parameter
if iterator_opts is not None:
warnings.warn(
"The 'iterator_opts' parameter is deprecated and will be removed on or after March 2026. "
"Use 'iterator_options' instead.",
FutureWarning,
stacklevel=2,
)
if iterator_options is not None:
raise ValueError("Cannot specify both 'iterator_opts' and 'iterator_options'. Use 'iterator_options'.")
iterator_options = iterator_opts
recording = self.recording_extractor
if stub_test:
recording = _stub_recording(recording=recording)
metadata = metadata or self.get_metadata()
if write_electrical_series:
add_recording_to_nwbfile(
recording=recording,
nwbfile=nwbfile,
metadata=metadata,
write_as=write_as,
es_key=self.es_key,
iterator_type=iterator_type,
iterator_options=iterator_options,
always_write_timestamps=always_write_timestamps,
)
else:
add_recording_metadata_to_nwbfile(
recording=recording,
nwbfile=nwbfile,
metadata=metadata,
)