Skip to content

Datasets

Classes for loading and processing datasets.

Note

📝 Device: All datasets return samples with tensors on the CPU (even when the synthetic data generation is done on the GPU). This is due to restrictions on GPU usage in multiprocessing settings, where GPU memory cannot easily be shared between processes.

📝 Dataloader: When using torch.utils.data.DataLoader with FetalSynthDataset, pass the multiprocessing_context="spawn" argument to the dataloader object so that the spawned worker processes have access to the GPU.

FetalDataset

Abstract class defining a dataset for loading fetal data.

Source code in fetalsyngen/data/datasets.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
class FetalDataset:
    """Abstract class defining a dataset for loading fetal data."""

    def __init__(
        self,
        bids_path: str,
        sub_list: list[str] | None,
    ) -> dict:
        """
        Args:
            bids_path: Path to the bids folder with the data.
            sub_list: List of the subjects to use. If None, all subjects are used.


        """
        super().__init__()

        self.bids_path = Path(bids_path)
        self.subjects = self.find_subjects(sub_list)
        if self.subjects is None:
            self.subjects = [x.name for x in self.bids_path.glob("sub-*")]
        self.sub_ses = [
            (x, y) for x in self.subjects for y in self._get_ses(self.bids_path, x)
        ]
        self.loader = SimpleITKReader()
        self.scaler = ScaleIntensity(minv=0, maxv=1)

        self.img_paths = self._load_bids_path(self.bids_path, "T2w")
        self.segm_paths = self._load_bids_path(self.bids_path, "dseg")

    def find_subjects(self, sub_list):
        subj_found = [x.name for x in Path(self.bids_path).glob("sub-*")]
        return list(set(subj_found) & set(sub_list)) if sub_list is not None else None

    def _sub_ses_string(self, sub, ses):
        return f"{sub}_{ses}" if ses is not None else sub

    def _sub_ses_idx(self, idx):
        sub, ses = self.sub_ses[idx]
        return self._sub_ses_string(sub, ses)

    def _get_ses(self, bids_path, sub):
        """Get the session names for the subject."""
        sub_path = bids_path / sub
        ses_dir = [x for x in sub_path.iterdir() if x.is_dir()]
        ses = []
        for s in ses_dir:
            if "anat" in s.name:
                ses.append(None)
            else:
                ses.append(s.name)

        return sorted(ses, key=lambda x: x or "")

    def _get_pattern(self, sub, ses, suffix, extension=".nii.gz"):
        """Get the pattern for the file name."""
        if ses is None:
            return f"{sub}/anat/{sub}*_{suffix}{extension}"
        else:
            return f"{sub}/{ses}/anat/{sub}_{ses}*_{suffix}{extension}"

    def _load_bids_path(self, path, suffix):
        """
        "Check that for a given path, all subjects have a file with the provided suffix
        """
        files_paths = []
        for sub, ses in self.sub_ses:
            pattern = self._get_pattern(sub, ses, suffix)
            files = list(path.glob(pattern))
            if len(files) == 0:
                raise FileNotFoundError(
                    f"No files found for requested subject {sub} in {path} "
                    f"({pattern} returned nothing)"
                )
            elif len(files) > 1:
                raise RuntimeError(
                    f"Multiple files found for requested subject {sub} in {path} "
                    f"({pattern} returned {files})"
                )
            files_paths.append(files[0])

        return files_paths

    def __len__(self):
        return len(self.subjects)

    def __getitem__(self, idx):
        raise NotImplementedError(
            "This method should be implemented in the child class."
        )

__init__(bids_path, sub_list)

Parameters:

Name Type Description Default
bids_path str

Path to the bids folder with the data.

required
sub_list list[str] | None

List of the subjects to use. If None, all subjects are used.

required
Source code in fetalsyngen/data/datasets.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def __init__(
    self,
    bids_path: str,
    sub_list: list[str] | None,
) -> None:
    """
    Discover the subjects/sessions in the BIDS folder and resolve their files.

    Args:
        bids_path: Path to the bids folder with the data.
        sub_list: List of the subjects to use. If None, all subjects are used.
    """
    super().__init__()

    self.bids_path = Path(bids_path)
    self.subjects = self.find_subjects(sub_list)
    if self.subjects is None:
        # Sorted so that the index -> subject mapping does not depend on
        # the order in which the filesystem lists the folders.
        self.subjects = sorted(x.name for x in self.bids_path.glob("sub-*"))
    # One entry per (subject, session) pair; a session may be None.
    self.sub_ses = [
        (x, y) for x in self.subjects for y in self._get_ses(self.bids_path, x)
    ]
    self.loader = SimpleITKReader()
    self.scaler = ScaleIntensity(minv=0, maxv=1)

    # Both lists are aligned with `self.sub_ses` (one path per entry).
    self.img_paths = self._load_bids_path(self.bids_path, "T2w")
    self.segm_paths = self._load_bids_path(self.bids_path, "dseg")

FetalTestDataset

Bases: FetalDataset

Dataset class for loading fetal images offline. Used to load test/validation data.

Use the transforms argument to pass additional processing steps (scaling, resampling, cropping, etc.).

Source code in fetalsyngen/data/datasets.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
class FetalTestDataset(FetalDataset):
    """Offline dataset of fetal images, meant for test/validation data.

    Additional processing steps (scaling, resampling, cropping, etc.)
    are supplied through the `transforms` argument.
    """

    def __init__(
        self,
        bids_path: str,
        sub_list: list[str] | None,
        transforms: Compose | None = None,
    ):
        """
        Args:
            bids_path: Path to the bids folder with the data.
            sub_list: List of the subjects to use. If None, all subjects are used.
            transforms: Compose object with the transformations to apply.
                Default is None, no transformations are applied.

        !!! Note
            We highly recommend passing `transforms` with at least the
            re-orientation to RAS and the intensity scaling to `[0, 1]`
            to ensure data consistency.

            See [inference.yaml](https://github.com/Medical-Image-Analysis-Laboratory/fetalsyngen/blob/dev/configs/dataset/transforms/inference.yaml) for an example of the transforms configuration.
        """
        super().__init__(bids_path, sub_list)
        self.transforms = transforms

    def _load_data(self, idx):
        """Read image and segmentation of item `idx`, adding a channel axis to 3D data."""
        image = self.loader(self.img_paths[idx], interp="linear")
        segm = self.loader(self.segm_paths[idx], interp="nearest")

        n_dims = len(image.shape)
        if n_dims not in (3, 4):
            raise ValueError(f"Expected 3D or 4D image, got {len(image.shape)}D image.")
        if n_dims == 3:
            # add channel dimension
            image, segm = image.unsqueeze(0), segm.unsqueeze(0)

        # a single string is used for the name, otherwise collate fails
        sub, ses = self.sub_ses[idx]
        return {
            "image": image,
            "label": segm.long(),
            "name": self._sub_ses_string(sub, ses=ses),
        }

    def __getitem__(self, idx) -> dict:
        """
        Returns:
            Dictionary with the `image`, `label` and `name` keys.
                `image` and `label` are the loaded volumes (with a leading
                channel axis, i.e. `(1, H, W, D)` for 3D inputs) after the
                optional `transforms`; `label` is cast with `.long()`.
                `name` is a string of the format `sub_ses` where `sub` is
                the subject name and `ses` is the session name.
        """
        sample = self._load_data(idx)
        pipeline = self.transforms
        processed = pipeline(sample) if pipeline else sample
        # transforms may change the dtype, so the label is cast again
        processed["label"] = processed["label"].long()
        return processed

    def reverse_transform(self, data: dict) -> dict:
        """Undo the transformations applied to the data.

        Args:
            data: Dictionary with the `image` and `label` keys,
                like the one returned by the `__getitem__` method.

        Returns:
            Dictionary with the `image` and `label` keys where
                the transformations are reversed. Returned unchanged
                when no transforms are set.
        """
        if not self.transforms:
            return data
        return self.transforms.inverse(data)

__init__(bids_path, sub_list, transforms=None)

Parameters:

Name Type Description Default
bids_path str

Path to the bids folder with the data.

required
sub_list list[str] | None

List of the subjects to use. If None, all subjects are used.

required
transforms Compose | None

Compose object with the transformations to apply. Default is None, no transformations are applied.

None

Note

We highly recommend using the transforms argument with at least the re-orientation transform to RAS and the intensity scaling to [0, 1] to ensure data consistency.

See inference.yaml for an example of the transforms configuration.

Source code in fetalsyngen/data/datasets.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def __init__(
    self,
    bids_path: str,
    sub_list: list[str] | None,
    transforms: Compose | None = None,
):
    """
    Args:
        bids_path: Path to the bids folder with the data.
        sub_list: List of the subjects to use. If None, all subjects are used.
        transforms: Compose object with the transformations to apply.
            Default is None, no transformations are applied.

    !!! Note
        We highly recommend passing `transforms` with at least the
        re-orientation to RAS and the intensity scaling to `[0, 1]`
        to ensure data consistency.

        See [inference.yaml](https://github.com/Medical-Image-Analysis-Laboratory/fetalsyngen/blob/dev/configs/dataset/transforms/inference.yaml) for an example of the transforms configuration.
    """
    # The parent resolves subjects, sessions and file paths.
    super().__init__(bids_path=bids_path, sub_list=sub_list)
    self.transforms = transforms

__getitem__(idx)

Returns:

Type Description
dict

Dictionary with the image , label and the name keys. image and label are torch.float32 monai.data.meta_tensor.MetaTensor instances with dimensions (1, H, W, D) and name is a string of a format sub_ses where sub is the subject name and ses is the session name.

Source code in fetalsyngen/data/datasets.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
def __getitem__(self, idx) -> dict:
    """
    Load the item at `idx` and apply the optional transforms.

    Returns:
        Dictionary with the `image`, `label` and `name` keys as produced
            by `_load_data` and, if set, passed through `transforms`.
            `label` is cast with `.long()` before it is returned.
            `name` is a string of the format `sub_ses` where `sub` is
            the subject name and `ses` is the session name.
    """
    sample = self._load_data(idx)
    pipeline = self.transforms
    processed = pipeline(sample) if pipeline else sample
    # transforms may change the dtype, so the label is cast again
    processed["label"] = processed["label"].long()
    return processed

reverse_transform(data)

Reverse the transformations applied to the data.

Parameters:

Name Type Description Default
data dict

Dictionary with the image and label keys, like the one returned by the __getitem__ method.

required

Returns:

Type Description
dict

Dictionary with the image and label keys where the transformations are reversed.

Source code in fetalsyngen/data/datasets.py
177
178
179
180
181
182
183
184
185
186
187
188
189
190
def reverse_transform(self, data: dict) -> dict:
    """Undo the transformations applied to the data.

    Args:
        data: Dictionary with the `image` and `label` keys,
            like the one returned by the `__getitem__` method.

    Returns:
        Dictionary with the `image` and `label` keys where
            the transformations are reversed. Returned unchanged
            when no transforms are set.
    """
    if not self.transforms:
        return data
    return self.transforms.inverse(data)

FetalSynthDataset

Bases: FetalDataset

Dataset class for generating/augmenting on-the-fly fetal images.

Source code in fetalsyngen/data/datasets.py
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
class FetalSynthDataset(FetalDataset):
    """Dataset class for generating/augmenting on-the-fly fetal images."""

    def __init__(
        self,
        bids_path: str,
        generator: FetalSynthGen,
        seed_path: str | None,
        sub_list: list[str] | None,
        load_image: bool = False,
        image_as_intensity: bool = False,
        transforms: Compose | None = None,
    ):
        """

        Args:
            bids_path: Path to the bids-formatted folder with the data.
            seed_path: Path to the folder with the seeds to use for
                intensity sampling. See `scripts/seed_generation.py`
                for details on the data formatting. If seed_path is None,
                the intensity  sampling step is skipped and the output image
                intensities will be based on the input image.
            generator: a class object defining a generator to use.
            sub_list: List of the subjects to use. If None, all subjects are used.
            load_image: If **True**, the image is loaded and passed to the generator,
                where it can be used as the intensity prior instead of a random
                intensity sampling or spatially deformed with the same transformation
                field as segmentation and the synthetic image. Default is **False**.
            image_as_intensity: If **True**, the image is used as the intensity prior,
                instead of sampling the intensities from the seeds. Default is **False**.
            transforms: Compose object applied to the generated `image`/`label`
                dictionary. Default is None, no transformations are applied.

        Raises:
            FileNotFoundError: If seeds are required and `seed_path` (or one
                of the expected seed sub-folders/files) does not exist.
        """
        super().__init__(bids_path, sub_list)
        self.seed_path = Path(seed_path) if isinstance(seed_path, str) else None
        self.load_image = load_image
        self.generator = generator
        self.image_as_intensity = image_as_intensity
        self.transforms = transforms
        # parse seeds paths (only needed when intensities are sampled from seeds)
        if not self.image_as_intensity and isinstance(self.seed_path, Path):
            if not self.seed_path.exists():
                raise FileNotFoundError(
                    f"Provided seed path {self.seed_path} does not exist."
                )
            else:
                self._load_seed_path()

    def _load_seed_path(self):
        """Load the seeds for the subjects.

        Fills `self.seed_paths` as
        `{sub_ses: {n_subclasses: {meta_label: path}}}` from the
        `subclasses_<n>` folders found in `self.seed_path`.

        Raises:
            FileNotFoundError: If no `subclasses_*` folder exists, if a folder
                is missing in the min..max range, or if a seed file is missing.
        """
        self.seed_paths = {
            self._sub_ses_string(sub, ses): defaultdict(dict)
            for (sub, ses) in self.sub_ses
        }
        avail_seeds = [
            int(x.name.replace("subclasses_", ""))
            for x in self.seed_path.glob("subclasses_*")
        ]
        if not avail_seeds:
            # min()/max() on an empty list would raise an opaque ValueError
            raise FileNotFoundError(
                f"No 'subclasses_*' folders found in seed path {self.seed_path}."
            )
        min_seeds_available = min(avail_seeds)
        max_seeds_available = max(avail_seeds)
        for n_sub in range(
            min_seeds_available,
            max_seeds_available + 1,
        ):
            seed_path = self.seed_path / f"subclasses_{n_sub}"
            if not seed_path.exists():
                raise FileNotFoundError(
                    f"Provided seed path {seed_path} does not exist."
                )
            # load the seeds for the subjects for each meta label
            for i in self.generator.intensity_generator.meta_labels:
                files = self._load_bids_path(seed_path, f"mlabel_{i}")
                for (sub, ses), file in zip(self.sub_ses, files):
                    sub_ses_str = self._sub_ses_string(sub, ses)
                    self.seed_paths[sub_ses_str][n_sub][i] = file

    def sample(self, idx, genparams: dict | None = None) -> tuple[dict, dict]:
        """
        Retrieve a single item from the dataset at the specified index.

        Args:
            idx (int): The index of the item to retrieve.
            genparams (dict): Dictionary with generation parameters.
                Used for fixed generation. Should follow exactly the same structure
                and be of the same type as the returned generation parameters.
                Can be used to replicate the augmentations (power)
                used for the generation of a specific sample.
                None (default) is treated as an empty dictionary.
        Returns:
            Dictionaries with the generated data and the generation parameters.
                First dictionary contains the `image`, `label` and the `name` keys.
                The second dictionary contains the parameters used for the generation.

        !!! Note
            The `image` is scaled to `[0, 1]`, oriented with the `label` to **RAS**
            and moved to the CPU before it is returned.
        """
        # a fresh dict per call; a `{}` default would be shared between calls
        if genparams is None:
            genparams = {}

        # use generation_params to track the parameters used for the generation
        generation_params = {}

        image = (
            self.loader(
                self.img_paths[idx], interp="linear", spatial_size=192, resolution=1.0
            )
            if self.load_image
            else None
        )
        segm = self.loader(
            self.segm_paths[idx], interp="nearest", spatial_size=192, resolution=1.0
        )

        # TODO: make a random re-orientation switchable; RAS is used for now
        axcodes = "RAS"
        orientation = Orientation(axcodes=axcodes)
        image = orientation(image.unsqueeze(0)).squeeze(0) if self.load_image else None
        segm = orientation(segm.unsqueeze(0)).squeeze(0)

        # transform name into a single string otherwise collate fails
        name = self.sub_ses[idx]
        name = self._sub_ses_string(name[0], ses=name[1])

        # seeds: dictionary with paths to the seeds volumes, or None when the
        # image is the intensity prior or no seed path was given. Seeds are
        # only loaded in __init__ when both conditions are false, so
        # `self.seed_paths` must not be touched otherwise.
        if self.image_as_intensity or self.seed_path is None:
            seeds = None
        else:
            seeds = self.seed_paths[name]

        # log input data
        generation_params["idx"] = idx
        generation_params["img_paths"] = str(self.img_paths[idx])
        generation_params["segm_paths"] = str(self.segm_paths[idx])
        generation_params["seeds"] = str(self.seed_path)
        generation_time_start = time.time()

        # generate the synthetic data
        gen_output, segmentation, image, synth_params = self.generator.sample(
            image=image,
            segmentation=segm,
            seeds=seeds,
            genparams=genparams,
            orientation=orientation,
        )

        # scale the images to [0, 1]
        gen_output = self.scaler(gen_output)
        image = self.scaler(image) if image is not None else None

        # ensure image and segmentation are on the cpu
        gen_output = gen_output.cpu()
        segmentation = segmentation.cpu()
        image = image.cpu() if image is not None else None

        generation_params = {**generation_params, **synth_params}
        generation_params["generation_time"] = time.time() - generation_time_start
        data_out = {
            "image": gen_output.unsqueeze(0),
            "label": segmentation.unsqueeze(0).long(),
        }

        if self.transforms:
            data_out = self.transforms(data_out)
        # added after the transforms so they only see tensor entries
        data_out["name"] = name

        return data_out, generation_params

    def __getitem__(self, idx) -> dict:
        """
        Retrieve a single item from the dataset at the specified index.

        The generation parameters of the call are kept in
        `self.generation_params`.

        Args:
            idx (int): The index of the item to retrieve.

        Returns:
            Dictionary with the `image`, `label` and the `name` keys
                as returned by `sample()`; `name` is a string of a format
                `sub_ses` where `sub` is the subject name
                and `ses` is the session name.

        !!!Note
            The `image` is scaled to `[0, 1]`, oriented to **RAS** and
            moved to the CPU before it is returned.
        """
        data_out, generation_params = self.sample(idx)
        self.generation_params = generation_params
        return data_out

    def sample_with_meta(self, idx: int, genparams: dict | None = None) -> dict:
        """
        Retrieve a sample along with its generation parameters
        and store them in the same dictionary.

        Args:
            idx: The index of the sample to retrieve.
            genparams: Dictionary with generation parameters.
                Used for fixed generation. Should follow exactly the same structure
                and be of the same type as the returned generation parameters from the `sample()` method.
                Can be used to replicate the augmentations (power)
                used for the generation of a specific sample.
                None (default) is treated as an empty dictionary.

        Returns:
            A dictionary with `image`, `label`, `name` and `generation_params` keys.
        """

        data, generation_params = self.sample(idx, genparams=genparams)
        data["generation_params"] = generation_params
        return data

__init__(bids_path, generator, seed_path, sub_list, load_image=False, image_as_intensity=False, transforms=None)

Parameters:

Name Type Description Default
bids_path str

Path to the bids-formatted folder with the data.

required
seed_path str | None

Path to the folder with the seeds to use for intensity sampling. See scripts/seed_generation.py for details on the data formatting. If seed_path is None, the intensity sampling step is skipped and the output image intensities will be based on the input image.

required
generator FetalSynthGen

a class object defining a generator to use.

required
sub_list list[str] | None

List of the subjects to use. If None, all subjects are used.

required
load_image bool

If True, the image is loaded and passed to the generator, where it can be used as the intensity prior instead of a random intensity sampling or spatially deformed with the same transformation field as the segmentation and the synthetic image. Default is False.

False
image_as_intensity bool

If True, the image is used as the intensity prior, instead of sampling the intensities from the seeds. Default is False.

False
Source code in fetalsyngen/data/datasets.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
def __init__(
    self,
    bids_path: str,
    generator: FetalSynthGen,
    seed_path: str | None,
    sub_list: list[str] | None,
    load_image: bool = False,
    image_as_intensity: bool = False,
    transforms: Compose | None = None,
):
    """
    Set up the on-the-fly synthetic dataset and, if needed, index the seeds.

    Args:
        bids_path: Path to the bids-formatted folder with the data.
        seed_path: Path to the folder with the seeds to use for
            intensity sampling. See `scripts/seed_generation.py`
            for details on the data formatting. If seed_path is None,
            the intensity sampling step is skipped and the output image
            intensities will be based on the input image.
        generator: a class object defining a generator to use.
        sub_list: List of the subjects to use. If None, all subjects are used.
        load_image: If **True**, the image is loaded and passed to the generator,
            where it can be used as the intensity prior instead of a random
            intensity sampling or spatially deformed with the same transformation
            field as segmentation and the synthetic image. Default is **False**.
        image_as_intensity: If **True**, the image is used as the intensity prior,
            instead of sampling the intensities from the seeds. Default is **False**.
        transforms: Compose object stored for use on the generated data.
            Default is None.

    Raises:
        FileNotFoundError: If seeds are required and `seed_path` does not exist.
    """
    super().__init__(bids_path, sub_list)
    self.generator = generator
    self.transforms = transforms
    self.load_image = load_image
    self.image_as_intensity = image_as_intensity

    # Only string paths are accepted; anything else means "no seeds".
    self.seed_path = None
    if isinstance(seed_path, str):
        self.seed_path = Path(seed_path)

    # Seeds are indexed only when intensities are sampled from them.
    if self.image_as_intensity or self.seed_path is None:
        return
    if not self.seed_path.exists():
        raise FileNotFoundError(
            f"Provided seed path {self.seed_path} does not exist."
        )
    self._load_seed_path()

sample(idx, genparams={})

Retrieve a single item from the dataset at the specified index.

Parameters:

Name Type Description Default
idx int

The index of the item to retrieve.

required
genparams dict

Dictionary with generation parameters. Used for fixed generation. Should follow exactly the same structure and be of the same type as the returned generation parameters. Can be used to replicate the augmentations (power) used for the generation of a specific sample.

{}

Returns: Dictionaries with the generated data and the generation parameters. First dictionary contains the image, label and the name keys. The second dictionary contains the parameters used for the generation.

Note

The image is scaled to [0, 1], oriented together with the label to RAS, and moved to the CPU before it is returned (even when the generator runs on another device).

Source code in fetalsyngen/data/datasets.py
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
def sample(self, idx, genparams: dict | None = None) -> tuple[dict, dict]:
    """
    Retrieve a single item from the dataset at the specified index.

    Args:
        idx (int): The index of the item to retrieve.
        genparams (dict): Dictionary with generation parameters.
            Used for fixed generation. Should follow exactly the same structure
            and be of the same type as the returned generation parameters.
            Can be used to replicate the augmentations (power)
            used for the generation of a specific sample.
            None (default) is treated as an empty dictionary.
    Returns:
        Dictionaries with the generated data and the generation parameters.
            First dictionary contains the `image`, `label` and the `name` keys.
            The second dictionary contains the parameters used for the generation.

    !!! Note
        The `image` is scaled to `[0, 1]`, oriented with the `label` to **RAS**
        and moved to the CPU before it is returned.
    """
    # a fresh dict per call; a `{}` default would be shared between calls
    if genparams is None:
        genparams = {}

    # use generation_params to track the parameters used for the generation
    generation_params = {}

    image = (
        self.loader(
            self.img_paths[idx], interp="linear", spatial_size=192, resolution=1.0
        )
        if self.load_image
        else None
    )
    segm = self.loader(
        self.segm_paths[idx], interp="nearest", spatial_size=192, resolution=1.0
    )

    # TODO: make a random re-orientation switchable; RAS is used for now
    axcodes = "RAS"
    orientation = Orientation(axcodes=axcodes)
    image = orientation(image.unsqueeze(0)).squeeze(0) if self.load_image else None
    segm = orientation(segm.unsqueeze(0)).squeeze(0)

    # transform name into a single string otherwise collate fails
    name = self.sub_ses[idx]
    name = self._sub_ses_string(name[0], ses=name[1])

    # seeds: dictionary with paths to the seeds volumes, or None when the
    # image is the intensity prior or no seed path was given. Previously
    # `seeds` could be left unbound, or `self.seed_paths` read although
    # it is only populated when seeds are actually needed.
    if self.image_as_intensity or self.seed_path is None:
        seeds = None
    else:
        seeds = self.seed_paths[name]

    # log input data
    generation_params["idx"] = idx
    generation_params["img_paths"] = str(self.img_paths[idx])
    generation_params["segm_paths"] = str(self.segm_paths[idx])
    generation_params["seeds"] = str(self.seed_path)
    generation_time_start = time.time()

    # generate the synthetic data
    gen_output, segmentation, image, synth_params = self.generator.sample(
        image=image,
        segmentation=segm,
        seeds=seeds,
        genparams=genparams,
        orientation=orientation,
    )

    # scale the images to [0, 1]
    gen_output = self.scaler(gen_output)
    image = self.scaler(image) if image is not None else None

    # ensure image and segmentation are on the cpu
    gen_output = gen_output.cpu()
    segmentation = segmentation.cpu()
    image = image.cpu() if image is not None else None

    generation_params = {**generation_params, **synth_params}
    generation_params["generation_time"] = time.time() - generation_time_start
    data_out = {
        "image": gen_output.unsqueeze(0),
        "label": segmentation.unsqueeze(0).long(),
    }

    if self.transforms:
        data_out = self.transforms(data_out)
    # added after the transforms so they only see tensor entries
    data_out["name"] = name

    return data_out, generation_params

__getitem__(idx)

Retrieve a single item from the dataset at the specified index.

Parameters:

Name Type Description Default
idx int

The index of the item to retrieve.

required

Returns:

Type Description
dict

Dictionary with the image, label and the name keys. image and label are torch.float32 monai.data.meta_tensor.MetaTensor and name is a string of a format sub_ses where sub is the subject name and ses is the session name.

Note

The image is scaled to [0, 1], oriented to RAS, and moved to the CPU before it is returned (even when the generator runs on another device).

Source code in fetalsyngen/data/datasets.py
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
def __getitem__(self, idx: int) -> dict:
    """
    Retrieve a single item from the dataset at the specified index.

    The item is produced by `self.sample(idx)`. The generation parameters
    returned alongside it are stored in `self.generation_params`, which is
    overwritten on every call.

    Args:
        idx (int): The index of the item to retrieve.

    Returns:
        Dictionary with the `image`, `label` and the `name` keys.
            `image` is a `torch.float32` and `label` an integer (`torch.long`,
            as cast by `sample()`)
            [`monai.data.meta_tensor.MetaTensor`](https://docs.monai.io/en/stable/data.html#metatensor)
            and `name` is a string of a format `sub_ses` where `sub` is the subject name
            and `ses` is the session name.

    !!!Note
        The `image` is scaled to `[0, 1]` and oriented to **RAS**. It is returned on the
        **CPU**: `sample()` moves the tensors to the CPU even when the synthetic data
        generation itself is done on the GPU.
    """
    data_out, generation_params = self.sample(idx)
    # keep the parameters of the most recently generated sample on the instance
    self.generation_params = generation_params
    return data_out

sample_with_meta(idx, genparams={})

Retrieve a sample along with its generation parameters and store them in the same dictionary.

Parameters:

Name Type Description Default
idx int

The index of the sample to retrieve.

required
genparams dict

Dictionary with generation parameters. Used for fixed generation. Should follow exactly the same structure and be of the same type as the returned generation parameters from the sample() method. Can be used to replicate the augmentations (power) used for the generation of a specific sample.

{}

Returns:

Type Description
dict

A dictionary with image, label, name and generation_params keys.

Source code in fetalsyngen/data/datasets.py
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
def sample_with_meta(self, idx: int, genparams: dict = {}) -> dict:
    """
    Retrieve a sample along with its generation parameters
    and store them in the same dictionary.

    Args:
        idx: The index of the sample to retrieve.
        genparams: Dictionary with generation parameters.
            Used for fixed generation. Should follow exactly the same structure
            and be of the same type as the returned generation parameters from the `sample()` method.
            Can be used to replicate the augmentations (power)
            used for the generation of a specific sample.

    Returns:
        A dictionary with `image`, `label`, `name` and `generation_params` keys.
    """

    data, generation_params = self.sample(idx, genparams=genparams)
    data["generation_params"] = generation_params
    return data

Fixed Image Generation

It is possible to generate synthetic images with the same 'augmentation' power as any given synthetic image. This is done by passing a genparams dictionary to the sample_with_meta (or sample) method of the FetalSynthDataset class. The genparams dictionary contains the parameters that were used to generate the original image. The method then uses these parameters to generate a new image with the same augmentation power as the original image.

This genparams dictionary can be obtained, for example, from the generation_params entry of the dictionary returned by the FetalSynthDataset.sample_with_meta method. It can then be used directly to fix some or all of the generation parameters for the new image.

See example below:

# initialize the dataset class
# see the Examples page for more details
dataset = FetalSynthDataset(...)

# first sample a synthetic image from the dataset;
# the parameters used to generate it are returned under the "generation_params" key
sample = dataset.sample_with_meta(0)
# then we sample a synthetic image with the same augmentation power as the first image
# by passing these parameters back as `genparams`
sample_copy = dataset.sample_with_meta(0, genparams=sample["generation_params"])

For example, generation parameters of the first image can be like this:

{'idx': 0,
 'img_paths': PosixPath('../data/sub-sta38/anat/sub-sta38_rec-irtk_T2w.nii.gz'),
 'segm_paths': PosixPath('../data/sub-sta38/anat/sub-sta38_rec-irtk_T2w.nii.gz'),
 'seeds': defaultdict(dict,
             {1: {1: PosixPath('../data/derivatives/seeds/subclasses_1/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_1.nii.gz'),
               2: PosixPath('../data/derivatives/seeds/subclasses_1/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_2.nii.gz'),
               3: PosixPath('../data/derivatives/seeds/subclasses_1/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_3.nii.gz'),
               4: PosixPath('../data/derivatives/seeds/subclasses_1/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_4.nii.gz')},
              2: {1: PosixPath('../data/derivatives/seeds/subclasses_2/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_1.nii.gz'),
               2: PosixPath('../data/derivatives/seeds/subclasses_2/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_2.nii.gz'),
               3: PosixPath('../data/derivatives/seeds/subclasses_2/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_3.nii.gz'),
               4: PosixPath('../data/derivatives/seeds/subclasses_2/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_4.nii.gz')},
              3: {1: PosixPath('../data/derivatives/seeds/subclasses_3/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_1.nii.gz'),
               2: PosixPath('../data/derivatives/seeds/subclasses_3/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_2.nii.gz'),
               3: PosixPath('../data/derivatives/seeds/subclasses_3/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_3.nii.gz'),
               4: PosixPath('../data/derivatives/seeds/subclasses_3/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_4.nii.gz')}}),
 'selected_seeds': {'mlabel2subclusters': {1: 2, 2: 1, 3: 3, 4: 1}},
 'seed_intensities': {'mus': tensor([109.6722, 220.9658, 100.9801,  38.6364, 125.5148, 108.1950, 216.1060,
          190.5462,  55.3930,  59.2667,  72.0628,  68.8775,  76.5113,  84.6639,
           90.0124,  94.1701,  67.0610,  25.9465,  31.5438,  21.0375, 192.4223,
          173.7434, 139.9284, 121.3904, 145.4289, 158.1318, 157.4630, 150.0894,
          183.9047, 181.7129, 114.8939,   9.5253,  29.0257,  97.9543, 122.0798,
           72.2969,  26.3086,  81.8050,  67.7463,  72.3737, 129.8539, 113.3900,
          141.8177, 225.0000,  35.3458, 173.7635,  29.5101, 135.9482, 188.2391,
          225.0000], device='cuda:0'),
  'sigmas': tensor([ 9.2432, 23.1060, 16.4965,  6.4289, 24.7862, 23.7996, 15.2424, 20.2845,
          12.6833,  6.9079,  6.1214, 22.1317,  9.7907,  5.5302, 14.3288, 11.1454,
          16.0453, 20.9057, 24.2358, 13.4785, 22.7258, 11.2053, 12.9420, 13.4270,
          14.8660, 22.4874,  5.6251,  9.8794,  8.8749, 19.0294,  9.7164,  6.2293,
          13.6376, 11.7447, 14.1414,  6.4362, 20.4575, 14.6729,  8.4719, 14.2926,
           6.9458, 11.5346, 14.6113,  6.6516, 22.1767,  8.3793, 20.1699,  6.3299,
           5.3340, 21.8027], device='cuda:0')},
 'deform_params': {'affine': {'rotations': array([ 0.0008224 ,  0.03067143, -0.0151502 ]),
   'shears': array([-0.01735838,  0.00744726,  0.00012507]),
   'scalings': array([1.09345725, 0.91695532, 0.98194215])},
  'non_rigid': {'nonlin_scale': array([0.05686841]),
   'nonlin_std': 1.048839010036788,
   'size_F_small': [15, 15, 15]},
  'flip': False},
 'gamma_params': {'gamma': 0.960299468352801},
 'bf_params': {'bf_scale': None, 'bf_std': None, 'bf_size': None},
 'resample_params': {'spacing': array([0.65685245, 0.65685245, 0.65685245])},
 'noise_params': {'noise_std': None},
 'generation_time': 0.5615839958190918}


If a key:value pair exists in the passed genparams dictionary, the sample method will use the value from the genparams dictionary directly. If a key:value pair does not exist in the genparams dictionary, or its value is None, the sample method will generate the value randomly, using the corresponding class attributes.

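Note that the keys bf_scale, bf_std, bf_size and noise_std are set to None in the genparams dictionary above. This means that the sample method will generate these values randomly (a value may also remain None, as noise_std does below). The same could have been achieved by not passing them at all. Sampling again with this genparams dictionary could, for example, return the following generation parameters: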

{'idx': 0, 'img_paths': PosixPath('../data/sub-sta38/anat/sub-sta38_rec-irtk_T2w.nii.gz'), 'segm_paths': PosixPath('../data/sub-sta38/anat/sub-sta38_rec-irtk_T2w.nii.gz'), 'seeds': defaultdict(dict, {1: {1: PosixPath('../data/derivatives/seeds/subclasses_1/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_1.nii.gz'), 2: PosixPath('../data/derivatives/seeds/subclasses_1/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_2.nii.gz'), 3: PosixPath('../data/derivatives/seeds/subclasses_1/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_3.nii.gz'), 4: PosixPath('../data/derivatives/seeds/subclasses_1/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_4.nii.gz')}, 2: {1: PosixPath('../data/derivatives/seeds/subclasses_2/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_1.nii.gz'), 2: PosixPath('../data/derivatives/seeds/subclasses_2/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_2.nii.gz'), 3: PosixPath('../data/derivatives/seeds/subclasses_2/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_3.nii.gz'), 4: PosixPath('../data/derivatives/seeds/subclasses_2/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_4.nii.gz')}, 3: {1: PosixPath('../data/derivatives/seeds/subclasses_3/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_1.nii.gz'), 2: PosixPath('../data/derivatives/seeds/subclasses_3/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_2.nii.gz'), 3: PosixPath('../data/derivatives/seeds/subclasses_3/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_3.nii.gz'), 4: PosixPath('../data/derivatives/seeds/subclasses_3/sub-sta38/anat/sub-sta38_rec-irtk_T2w_dseg_mlabel_4.nii.gz')}}), 'selected_seeds': {'mlabel2subclusters': {1: 2, 2: 1, 3: 3, 4: 1}}, 'seed_intensities': {'mus': tensor([109.6722, 220.9658, 100.9801, 38.6364, 125.5148, 108.1950, 216.1060, 190.5462, 55.3930, 59.2667, 72.0628, 68.8775, 76.5113, 84.6639, 90.0124, 94.1701, 67.0610, 25.9465, 31.5438, 21.0375, 192.4223, 173.7434, 139.9284, 121.3904, 145.4289, 158.1318, 157.4630, 150.0894, 183.9047, 181.7129, 114.8939, 
9.5253, 29.0257, 97.9543, 122.0798, 72.2969, 26.3086, 81.8050, 67.7463, 72.3737, 129.8539, 113.3900, 141.8177, 225.0000, 35.3458, 173.7635, 29.5101, 135.9482, 188.2391, 225.0000], device='cuda:0'), 'sigmas': tensor([ 9.2432, 23.1060, 16.4965, 6.4289, 24.7862, 23.7996, 15.2424, 20.2845, 12.6833, 6.9079, 6.1214, 22.1317, 9.7907, 5.5302, 14.3288, 11.1454, 16.0453, 20.9057, 24.2358, 13.4785, 22.7258, 11.2053, 12.9420, 13.4270, 14.8660, 22.4874, 5.6251, 9.8794, 8.8749, 19.0294, 9.7164, 6.2293, 13.6376, 11.7447, 14.1414, 6.4362, 20.4575, 14.6729, 8.4719, 14.2926, 6.9458, 11.5346, 14.6113, 6.6516, 22.1767, 8.3793, 20.1699, 6.3299, 5.3340, 21.8027], device='cuda:0')}, 'deform_params': {'affine': {'rotations': array([ 0.0008224 , 0.03067143, -0.0151502 ]), 'shears': array([-0.01735838, 0.00744726, 0.00012507]), 'scalings': array([1.09345725, 0.91695532, 0.98194215])}, 'non_rigid': {'nonlin_scale': array([0.05686841]), 'nonlin_std': 1.048839010036788, 'size_F_small': [15, 15, 15]}, 'flip': False}, 'gamma_params': {'gamma': 0.960299468352801}, 'bf_params': {'bf_scale': array([0.00797334]), 'bf_std': array([0.21896995]), 'bf_size': [2, 2, 2]}, 'resample_params': {'spacing': array([0.65685245, 0.65685245, 0.65685245])}, 'noise_params': {'noise_std': None}, 'generation_time': 0.6192283630371094}


Note

  • If a specific parameter is passed in genparams it means that the probability of its application is 100%. The internal prob is not used as the parameter is fixed.

  • If using custom values for the parameters, ensure that the values are within the ranges defined in the class attributes (especially for the spatial deformation parameters, as the grid is pre-defined at class initialization). Furthermore, ensure that the device location and type of each parameter are consistent with those in the returned generation_params dictionary.