Skip to content

Generation

API

FetalSynthGen

Source code in fetalsyngen/generator/model.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
class FetalSynthGen:
    """Synthetic image generator built from a chain of randomized stages.

    The pipeline is: intensity generation (from seeds or from an image prior),
    spatial deformation, gamma, bias field, down-sampling, noise, up-sampling,
    optional SR artifacts and an optional stack sampler. The parameters used by
    each stage are collected and returned by `sample`, and can be passed back
    through its `genparams` argument for fixed generation.
    """

    def __init__(
        self,
        shape: Iterable[int],
        resolution: Iterable[float],
        device: str,
        intensity_generator: ImageFromSeeds,
        spatial_deform: SpatialDeformation,
        resampler: RandResample,
        bias_field: RandBiasField,
        noise: RandNoise,
        gamma: RandGamma,
        # optional SR artifacts
        blur_cortex: BlurCortex | None = None,
        struct_noise: StructNoise | None = None,
        simulate_motion: SimulateMotion | None = None,
        boundaries: SimulatedBoundaries | None = None,
        # optional
        stack_sampler: StackSampler | None = None,
    ):
        """
        Initialize the model with the given parameters.

        !!!Note
            Augmentations related to SR artifacts are optional and can be set to None
            if not needed.

        Args:
            shape: Shape of the output image.
            resolution: Resolution of the output image.
            device: Device to use for computation.
            intensity_generator: Intensity generator.
            spatial_deform: Spatial deformation generator.
            resampler: Resampler.
            bias_field: Bias field generator.
            noise: Noise generator.
            gamma: Gamma correction generator.
            blur_cortex: Cortex blurring generator.
            struct_noise: Structural noise generator.
            simulate_motion: Motion simulation generator.
            boundaries: Boundaries generator.
            stack_sampler: Stack sampler.

        """
        self.shape = shape
        self.resolution = resolution
        self.intensity_generator = intensity_generator
        self.spatial_deform = spatial_deform
        self.resampled = resampler
        self.biasfield = bias_field
        self.gamma = gamma
        self.noise = noise

        # Applied in this order by `sample`; None entries are skipped.
        self.artifacts = {
            "blur_cortex": blur_cortex,
            "struct_noise": struct_noise,
            "simulate_motion": simulate_motion,
            "boundaries": boundaries,
        }
        self.stack_sampler = stack_sampler
        self.device = device

    def _validated_genparams(self, d: dict) -> dict:
        """Recursively removes all the keys with None values as they are not fixed in the generation."""
        if not isinstance(d, dict):
            return d  # Return non-dictionaries as-is

        return {
            key: self._validated_genparams(value)
            for key, value in d.items()
            if value is not None
        }

    def sample(
        self,
        orientation,
        image: torch.Tensor | None,
        segmentation: torch.Tensor,
        seeds: torch.Tensor | None,
        genparams: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, dict]:
        """
        Generate a synthetic image from the input data.
        Supports both random generation and from a fixed genparams dictionary.

        Args:
            orientation: Orientation transform forwarded to
                `intensity_generator.load_seeds`; only used when `seeds` is given.
            image: Image to use as intensity prior if required. Must be given
                when `seeds` is None.
            segmentation: Segmentation to use as spatial prior.
            seeds: Seeds to use for intensity generation.
            genparams: Dictionary with generation parameters.
                Used for fixed generation.
                Should follow the structure and be of the same type as
                the returned generation parameters. Entries that are None are
                dropped (i.e. sampled randomly). Defaults to None (fully random).
                For artifacts, `genparams["artifacts"][name]` is used; the
                legacy flat `genparams["artifact_params"]` is used as fallback.

        Returns:
            The synthetic image, the segmentation, the spatially deformed input
            image (None if no image was given), and the generation parameters.

        Raises:
            ValueError: If both `seeds` and `image` are None.
        """
        genparams = self._validated_genparams(genparams) if genparams else {}

        # 1. Generate intensity output.
        if seeds is not None:
            seeds, selected_seeds = self.intensity_generator.load_seeds(
                seeds=seeds,
                genparams=genparams.get("selected_seeds", {}),
                orientation=orientation,
            )
            output, seed_intensities = self.intensity_generator.sample_intensities(
                seeds=seeds,
                device=self.device,
                genparams=genparams.get("seed_intensities", {}),
            )
        else:
            if image is None:
                raise ValueError(
                    "If no seeds are passed, an image must be loaded to be used as intensity prior!"
                )
            # normalize the image from 0 to 255 to
            # match the intensity generator
            image_min = image.min()
            value_range = image.max() - image_min
            if value_range > 0:
                output = (image - image_min) / value_range * 255
            else:
                # A constant image has no contrast to rescale; 0/0 would give NaN.
                output = torch.zeros_like(
                    image, dtype=torch.result_type(image, 1.0)
                )
            selected_seeds = {}
            seed_intensities = {}

        # ensure that tensors are on the same device
        output = output.to(self.device)
        segmentation = segmentation.to(self.device)
        image = image.to(self.device) if image is not None else None

        # 2. Spatially deform the data
        image, segmentation, output, deform_params = self.spatial_deform.deform(
            image=image,
            segmentation=segmentation,
            output=output,
            genparams=genparams.get("deform_params", {}),
        )

        # 3. Gamma contrast transformation
        output, gamma_params = self.gamma(
            output, self.device, genparams=genparams.get("gamma_params", {})
        )

        # 4. Bias field corruption
        output, bf_params = self.biasfield(
            output, self.device, genparams=genparams.get("bf_params", {})
        )

        # 5. Downsample to simulate lower reconstruction resolution
        output, factors, resample_params = self.resampled(
            output,
            np.array(self.resolution),
            self.device,
            genparams=genparams.get("resample_params", {}),
        )

        # 6. Noise corruption
        output, noise_params = self.noise(
            output, self.device, genparams=genparams.get("noise_params", {})
        )

        # 7. Up-sample back to the original resolution/shape
        output = self.resampled.resize_back(output, factors)

        # 8. Induce SR-artifacts
        # Per-artifact parameters mirror the "artifacts" entry returned below,
        # so a returned genparams dict can be fed back for fixed generation.
        artifact_genparams = genparams.get("artifacts", {})
        legacy_artifact_params = genparams.get("artifact_params", {})
        artifacts = {}
        for name, artifact in self.artifacts.items():
            if artifact is None:
                continue
            output, metadata = artifact(
                output,
                segmentation,
                self.device,
                artifact_genparams.get(name, legacy_artifact_params),
                resolution=self.resolution,
            )
            artifacts[name] = metadata

        # 9. Apply stack sampler if available
        stack_meta = None
        if self.stack_sampler is not None:
            output, segmentation, stack_meta = self.stack_sampler(
                output, segmentation, device=self.device
            )

            # squeeze (not unsqueeze) dim 0 to match the expected shape
            output = output.squeeze(0)
            segmentation = segmentation.squeeze(0)

        # 10. Aggregate the synth params
        synth_params = {
            "selected_seeds": selected_seeds,
            "seed_intensities": seed_intensities,
            "deform_params": deform_params,
            "gamma_params": gamma_params,
            "bf_params": bf_params,
            "resample_params": resample_params,
            "noise_params": noise_params,
            "artifacts": artifacts,
            "stack_sampler": stack_meta,
        }

        return output, segmentation, image, synth_params

__init__(shape, resolution, device, intensity_generator, spatial_deform, resampler, bias_field, noise, gamma, blur_cortex=None, struct_noise=None, simulate_motion=None, boundaries=None, stack_sampler=None)

Initialize the model with the given parameters.

Note

Augmentations related to SR artifacts are optional and can be set to None if not needed.

Parameters:

Name Type Description Default
shape Iterable[int]

Shape of the output image.

required
resolution Iterable[float]

Resolution of the output image.

required
device str

Device to use for computation.

required
intensity_generator ImageFromSeeds

Intensity generator.

required
spatial_deform SpatialDeformation

Spatial deformation generator.

required
resampler RandResample

Resampler.

required
bias_field RandBiasField

Bias field generator.

required
noise RandNoise

Noise generator.

required
gamma RandGamma

Gamma correction generator.

required
blur_cortex BlurCortex | None

Cortex blurring generator.

None
struct_noise StructNoise | None

Structural noise generator.

None
simulate_motion SimulateMotion | None

Motion simulation generator.

None
boundaries SimulatedBoundaries | None

Boundaries generator

None
stack_sampler StackSampler | None

Stack sampler.

None
Source code in fetalsyngen/generator/model.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def __init__(
    self,
    shape: Iterable[int],
    resolution: Iterable[float],
    device: str,
    intensity_generator: ImageFromSeeds,
    spatial_deform: SpatialDeformation,
    resampler: RandResample,
    bias_field: RandBiasField,
    noise: RandNoise,
    gamma: RandGamma,
    # optional SR artifacts
    blur_cortex: BlurCortex | None = None,
    struct_noise: StructNoise | None = None,
    simulate_motion: SimulateMotion | None = None,
    boundaries: SimulatedBoundaries | None = None,
    # optional
    stack_sampler: StackSampler | None = None,
):
    """
    Store the configuration and the stages of the generation pipeline.

    !!!Note
        The SR-artifact stages (blur_cortex, struct_noise, simulate_motion,
        boundaries) and the stack sampler are optional; pass None to disable.

    Args:
        shape: Shape of the output image.
        resolution: Resolution of the output image.
        device: Device to use for computation.
        intensity_generator: Intensity generator.
        spatial_deform: Spatial deformation generator.
        resampler: Resampler (stored as `self.resampled`).
        bias_field: Bias field generator (stored as `self.biasfield`).
        noise: Noise generator.
        gamma: Gamma correction generator.
        blur_cortex: Cortex blurring generator.
        struct_noise: Structural noise generator.
        simulate_motion: Motion simulation generator.
        boundaries: Boundaries generator.
        stack_sampler: Stack sampler.

    """
    # General configuration.
    self.device = device
    self.shape = shape
    self.resolution = resolution

    # Mandatory pipeline stages.
    self.intensity_generator = intensity_generator
    self.spatial_deform = spatial_deform
    self.resampled = resampler
    self.biasfield = bias_field
    self.gamma = gamma
    self.noise = noise

    # Optional SR-artifact stages, kept in application order.
    self.artifacts = dict(
        blur_cortex=blur_cortex,
        struct_noise=struct_noise,
        simulate_motion=simulate_motion,
        boundaries=boundaries,
    )
    self.stack_sampler = stack_sampler

sample(orientation, image, segmentation, seeds, genparams={})

Generate a synthetic image from the input data. Supports both random generation and from a fixed genparams dictionary.

Parameters:

Name Type Description Default
image Tensor | None

Image to use as intensity prior if required.

required
segmentation Tensor

Segmentation to use as spatial prior.

required
seeds Tensor | None

Seeds to use for intensity generation.

required
genparams dict

Dictionary with generation parameters. Used for fixed generation. Should follow the structure and be of the same type as the returned generation parameters.

{}

Returns:

Type Description
tuple[Tensor, Tensor, Tensor, dict]

The synthetic image, the segmentation, the spatially deformed input image (None if no image was given), and the generation parameters.

Source code in fetalsyngen/generator/model.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
def sample(
    self,
    orientation,
    image: torch.Tensor | None,
    segmentation: torch.Tensor,
    seeds: torch.Tensor | None,
    genparams: dict = {},
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
    """
    Generate a synthetic image from the input data.
    Supports both random generation and from a fixed genparams dictionary.

    Args:
        image: Image to use as intensity prior if required.
        segmentation: Segmentation to use as spatial prior.
        seeds: Seeds to use for intensity generation.
        genparams: Dictionary with generation parameters.
            Used for fixed generation.
            Should follow the structure and be of the same type as
            the returned generation parameters.

    Returns:
        The synthetic image, the segmentation, the original image, and the generation parameters.

    """
    if genparams:
        genparams = self._validated_genparams(genparams)

    # 1. Generate intensity output.
    if seeds is not None:
        seeds, selected_seeds = self.intensity_generator.load_seeds(
            seeds=seeds,
            genparams=genparams.get("selected_seeds", {}),
            orientation=orientation,
        )
        output, seed_intensities = self.intensity_generator.sample_intensities(
            seeds=seeds,
            device=self.device,
            genparams=genparams.get("seed_intensities", {}),
        )
    else:
        if image is None:
            raise ValueError(
                "If no seeds are passed, an image must be loaded to be used as intensity prior!"
            )
        # normalize the image from 0 to 255 to
        # match the intensity generator
        output = (image - image.min()) / (image.max() - image.min()) * 255
        selected_seeds = {}
        seed_intensities = {}

    # ensure that tensors are on the same device
    output = output.to(self.device)
    segmentation = segmentation.to(self.device)
    image = image.to(self.device) if image is not None else None

    # 2. Spatially deform the data
    image, segmentation, output, deform_params = self.spatial_deform.deform(
        image=image,
        segmentation=segmentation,
        output=output,
        genparams=genparams.get("deform_params", {}),
    )

    # 3. Gamma contrast transformation
    output, gamma_params = self.gamma(
        output, self.device, genparams=genparams.get("gamma_params", {})
    )

    # 4. Bias field corruption
    output, bf_params = self.biasfield(
        output, self.device, genparams=genparams.get("bf_params", {})
    )

    # 5. Downsample to simulate lower reconstruction resolution
    output, factors, resample_params = self.resampled(
        output,
        np.array(self.resolution),
        self.device,
        genparams=genparams.get("resample_params", {}),
    )

    # 6. Noise corruption
    output, noise_params = self.noise(
        output, self.device, genparams=genparams.get("noise_params", {})
    )

    # 7. Up-sample back to the original resolution/shape
    output = self.resampled.resize_back(output, factors)

    # 8. Induce SR-artifacts
    artifacts = {}
    for name, artifact in self.artifacts.items():
        if artifact is not None:
            output, metadata = artifact(
                output,
                segmentation,
                self.device,
                genparams.get("artifact_params", {}),
                resolution=self.resolution,
            )
            artifacts[name] = metadata

    # 9. Apply stack sampler if available
    if self.stack_sampler is not None:
        output, segmentation, meta = self.stack_sampler(
            output, segmentation, device=self.device
        )

        # unsqueeze the image to match the expected shape
        output = output.squeeze(0)
        segmentation = segmentation.squeeze(0)

    # 10. Aggregete the synth params
    synth_params = {
        "selected_seeds": selected_seeds,
        "seed_intensities": seed_intensities,
        "deform_params": deform_params,
        "gamma_params": gamma_params,
        "bf_params": bf_params,
        "resample_params": resample_params,
        "noise_params": noise_params,
        "artifacts": artifacts,
        "stack_sampler": meta if self.stack_sampler is not None else None,
    }

    return output, segmentation, image, synth_params

ImageFromSeeds

Source code in fetalsyngen/generator/intensity/rand_gmm.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
class ImageFromSeeds:
    """Build synthetic intensity images from seed label maps with one Gaussian per label."""

    def __init__(
        self,
        min_subclusters: int,
        max_subclusters: int,
        seed_labels: Iterable[int],
        generation_classes: Iterable[int],
        meta_labels: list[int] | None = None,
    ):
        """

        Args:
            min_subclusters (int): Minimum number of subclusters to use.
            max_subclusters (int): Maximum number of subclusters to use.
            seed_labels (Iterable[int]): Iterable with all possible labels
                that can occur in the loaded seeds. Should be a unique set of
                integers starting from [0, ...]. 0 is reserved for the background,
                that will not have any intensity generated.
            generation_classes (Iterable[int]): Classes to use for generation.
                Seeds with the same generation class will be generated with
                the same GMM. Should be the same length as seed_labels.
            meta_labels (list[int], optional): Meta-labels whose seeds are
                loaded and combined. Defaults to [1, 2, 3, 4].

        Raises:
            ValueError: If seed_labels has duplicates or its length differs
                from generation_classes.
        """
        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        if len(set(seed_labels)) != len(seed_labels):
            raise ValueError("Parameter seed_labels should have unique values.")
        if len(seed_labels) != len(generation_classes):
            raise ValueError(
                "Parameters seed_labels and generation_classes should have the same lengths."
            )
        self.min_subclusters = min_subclusters
        self.max_subclusters = max_subclusters
        self.seed_labels = seed_labels
        self.generation_classes = generation_classes
        # Fresh list per instance instead of a shared mutable default.
        self.meta_labels = [1, 2, 3, 4] if meta_labels is None else meta_labels
        self.loader = SimpleITKReader()

    def load_seeds(
        self,
        seeds: dict[int, dict[int, Path]],
        mlabel2subclusters: dict[int, int] | None = None,
        genparams: dict | None = None,
        orientation: Orientation = Orientation("RAS"),
    ) -> tuple[torch.Tensor, dict]:
        """Load the seed label maps and combine them into a single label image.

        If `mlabel2subclusters` is provided, it is used to
        select the number of subclusters to use for
        each meta label. Otherwise, the number of subclusters
        is randomly selected from a uniform discrete distribution
        between `min_subclusters` and `max_subclusters` (both inclusive).
        A mapping in `genparams["mlabel2subclusters"]` overrides both.

        Args:

            seeds: Dictionary with the mapping `subcluster_number: {meta_label: seed_path}`,
                i.e. indexed as `seeds[n_subclusters][meta_label]`.
            mlabel2subclusters: Mapping to use when defining how many subclusters to
                use for each meta-label. Defaults to None.
            genparams: Dictionary with generation parameters. Defaults to None.
                Should contain the key "mlabel2subclusters" if the mapping is to be fixed.
            orientation: Orientation to use. Defaults to Orientation("RAS").


        Returns:
            A tuple `(seed, params)`. `seed` is the summed label image as a
                long tensor. Tensor dimensions are **(H, W, D)**. Values inside the tensor
                correspond to the subclusters, and are grouped by meta-label.
                `1-19: CSF, 20-29: GM, 30-39: WM, 40-49: Extra-cerebral`.
                `params` is `{"mlabel2subclusters": mapping_used}`.
        """
        genparams = {} if genparams is None else genparams

        # if no mapping is provided, randomly select the number of subclusters
        # to use for each meta-label in the format {mlabel: n_subclusters}
        if mlabel2subclusters is None:
            mlabel2subclusters = {
                meta_label: np.random.randint(
                    self.min_subclusters, self.max_subclusters + 1
                )
                for meta_label in self.meta_labels
            }
        if "mlabel2subclusters" in genparams:
            mlabel2subclusters = genparams["mlabel2subclusters"]

        # The first key of the mapping is loaded first, followed by the
        # remaining meta-labels in `self.meta_labels` order. Every seed is
        # looked up with its own meta-label as the inner key.
        first_mlab = next(iter(mlabel2subclusters))
        mlabels = [first_mlab] + [m for m in self.meta_labels if m != first_mlab]

        seed = None
        for mlabel in mlabels:
            new_seed = self.loader(
                seeds[mlabel2subclusters[mlabel]][mlabel],
                interp="nearest",
                spatial_size=192,
                resolution=1.0,
            )
            # re-orient seeds (RAS by default)
            new_seed = orientation(new_seed.unsqueeze(0))
            if seed is None:
                seed = new_seed
            else:
                seed += new_seed

        return seed.long().squeeze(0), {"mlabel2subclusters": mlabel2subclusters}

    def sample_intensities(
        self, seeds: torch.Tensor, device: str, genparams: dict | None = None
    ) -> tuple[torch.Tensor, dict]:
        """Sample the intensities from the seeds.

        Each label gets a Gaussian with a random mean in [25, 225) and a random
        std in [5, 25). When `generation_classes` differs from `seed_labels`,
        the mean of each seed label is re-drawn around the mean of its
        generation class (+ 25 * N(0, 1), clamped to [0, 225]).

        Args:
            seeds (torch.Tensor): Tensor with the seeds, used to index the
                per-label means/stds.
            device (str): Device to use. Should be "cuda" or "cpu".
            genparams (dict, optional): Dictionary with generation parameters.
                Defaults to None. Should contain the keys "mus" and "sigmas" if
                the GMM parameters are to be fixed. Fixed "mus" are used as-is
                (they already contain the generation-class perturbation) and
                are never modified in place.

        Returns:
            The intensity image (negative values set to 0) and a dict with the
            "mus" and "sigmas" that were used.

        Raises:
            ValueError: If `seeds` contains indices outside `[0, len(mus) - 1]`.
        """
        genparams = {} if genparams is None else genparams
        nlabels = max(self.seed_labels) + 1

        # Sample GMMs means and stds (or reuse fixed ones)
        fixed_mus = "mus" in genparams
        if fixed_mus:
            mus = torch.as_tensor(genparams["mus"], dtype=torch.float, device=device)
        else:
            mus = 25 + 200 * torch.rand(nlabels, dtype=torch.float, device=device)
        if "sigmas" in genparams:
            sigmas = torch.as_tensor(
                genparams["sigmas"], dtype=torch.float, device=device
            )
        else:
            sigmas = 5 + 20 * torch.rand(nlabels, dtype=torch.float, device=device)

        # Ensure that seeds are within valid range
        if (seeds < 0).any() or (seeds >= mus.size(0)).any():
            raise ValueError(
                f"Invalid seed indices detected: min={seeds.min().item()}, max={seeds.max().item()} (expected range: 0-{mus.size(0)-1})"
            )

        # if there are seed labels from the same generation class
        # set their mean to be the same with some random perturbation.
        # Skipped for fixed mus, which already include this perturbation.
        if not fixed_mus and self.generation_classes != self.seed_labels:
            mus[self.seed_labels] = torch.clamp(
                mus[self.generation_classes]
                + 25
                * torch.randn(len(self.seed_labels), dtype=torch.float, device=device),
                0,
                225,
            )
        intensity_image = mus[seeds] + sigmas[seeds] * torch.randn(
            seeds.shape, dtype=torch.float, device=device
        )
        intensity_image.clamp_(min=0)

        return intensity_image, {
            "mus": mus,
            "sigmas": sigmas,
        }

__init__(min_subclusters, max_subclusters, seed_labels, generation_classes, meta_labels=[1, 2, 3, 4])

Parameters:

Name Type Description Default
min_subclusters int

Minimum number of subclusters to use.

required
max_subclusters int

Maximum number of subclusters to use.

required
seed_labels Iterable[int]

Iterable with all possible labels that can occur in the loaded seeds. Should be a unique set of integers starting from [0, ...]. 0 is reserved for the background, that will not have any intensity generated.

required
generation_classes Iterable[int]

Classes to use for generation. Seeds with the same generation class will be generated with the same GMM. Should be the same length as seed_labels.

required
meta_labels list[int]

Meta-labels whose seeds are loaded and combined. Defaults to [1, 2, 3, 4].

[1, 2, 3, 4]
Source code in fetalsyngen/generator/intensity/rand_gmm.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def __init__(
    self,
    min_subclusters: int,
    max_subclusters: int,
    seed_labels: Iterable[int],
    generation_classes: Iterable[int],
    meta_labels: list[int] = [1, 2, 3, 4],
):
    """

    Args:
        min_subclusters (int): Minimum number of subclusters to use.
        max_subclusters (int): Maximum number of subclusters to use.
        seed_labels (Iterable[int]): Iterable with all possible labels
            that can occur in the loaded seeds. Should be a unique set of
            integers starting from [0, ...]. 0 is reserved for the background,
            that will not have any intensity generated.
        generation_classes (Iterable[int]): Classes to use for generation.
            Seeds with the same generation calss will be generated with
            the same GMM. Should be the same length as seed_labels.
        meta_labels (int, optional): Number of meta-labels used. Defaults to 4.
    """
    self.min_subclusters = min_subclusters
    self.max_subclusters = max_subclusters
    try:
        assert len(set(seed_labels)) == len(seed_labels)
    except AssertionError:
        raise ValueError("Parameter seed_labels should have unique values.")
    try:
        assert len(seed_labels) == len(generation_classes)
    except AssertionError:
        raise ValueError(
            "Parameters seed_labels and generation_classes should have the same lengths."
        )
    self.seed_labels = seed_labels
    self.generation_classes = generation_classes
    self.meta_labels = meta_labels
    self.loader = SimpleITKReader()

load_seeds(seeds, mlabel2subclusters=None, genparams={}, orientation=Orientation('RAS'))

Load the seed label maps and combine them into a single label image. If mlabel2subclusters is provided, it is used to select the number of subclusters to use for each meta label. Otherwise, the number of subclusters is randomly selected from a uniform discrete distribution between min_subclusters and max_subclusters (both inclusive). A mapping given in genparams["mlabel2subclusters"] takes precedence over both.

Args:

seeds: Dictionary with the mapping `subcluster_number: {meta_label: seed_path}`.
mlabel2subclusters: Mapping to use when defining how many subclusters to
    use for each meta-label. Defaults to None.
genparams: Dictionary with generation parameters. Defaults to {}.
    Should contain the key "mlabel2subclusters" if the mapping is to be fixed.
orientation: Orientation to use. Defaults to Orientation("RAS").

Returns:

Type Description
tuple[Tensor, dict]

A tuple (seed, params). seed is the summed label image (long tensor) with dimensions (H, W, D). Values inside the tensor correspond to the subclusters, and are grouped by meta-label. 1-19: CSF, 20-29: GM, 30-39: WM, 40-49: Extra-cerebral. params is {"mlabel2subclusters": mapping_used}.

Source code in fetalsyngen/generator/intensity/rand_gmm.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def load_seeds(
    self,
    seeds: dict[int, dict[int, Path]],
    mlabel2subclusters: dict[int, int] | None = None,
    genparams: dict | None = None,
    orientation: Orientation = Orientation("RAS"),
) -> tuple[torch.Tensor, dict]:
    """Load the seed label maps and combine them into a single label image.

    If `mlabel2subclusters` is provided, it is used to
    select the number of subclusters to use for
    each meta label. Otherwise, the number of subclusters
    is randomly selected from a uniform discrete distribution
    between `min_subclusters` and `max_subclusters` (both inclusive).
    A mapping in `genparams["mlabel2subclusters"]` overrides both.

    Args:

        seeds: Dictionary with the mapping `subcluster_number: {meta_label: seed_path}`,
            i.e. indexed as `seeds[n_subclusters][meta_label]`.
        mlabel2subclusters: Mapping to use when defining how many subclusters to
            use for each meta-label. Defaults to None.
        genparams: Dictionary with generation parameters. Defaults to None.
            Should contain the key "mlabel2subclusters" if the mapping is to be fixed.
        orientation: Orientation to use. Defaults to Orientation("RAS").


    Returns:
        A tuple `(seed, params)`. `seed` is the summed label image as a
            long tensor. Tensor dimensions are **(H, W, D)**. Values inside the tensor
            correspond to the subclusters, and are grouped by meta-label.
            `1-19: CSF, 20-29: GM, 30-39: WM, 40-49: Extra-cerebral`.
            `params` is `{"mlabel2subclusters": mapping_used}`.
    """
    genparams = {} if genparams is None else genparams

    # if no mapping is provided, randomly select the number of subclusters
    # to use for each meta-label in the format {mlabel: n_subclusters}
    if mlabel2subclusters is None:
        mlabel2subclusters = {
            meta_label: np.random.randint(
                self.min_subclusters, self.max_subclusters + 1
            )
            for meta_label in self.meta_labels
        }
    if "mlabel2subclusters" in genparams:
        mlabel2subclusters = genparams["mlabel2subclusters"]

    # The first key of the mapping is loaded first, followed by the
    # remaining meta-labels in `self.meta_labels` order. Every seed is
    # looked up with its own meta-label as the inner key.
    first_mlab = next(iter(mlabel2subclusters))
    mlabels = [first_mlab] + [m for m in self.meta_labels if m != first_mlab]

    seed = None
    for mlabel in mlabels:
        new_seed = self.loader(
            seeds[mlabel2subclusters[mlabel]][mlabel],
            interp="nearest",
            spatial_size=192,
            resolution=1.0,
        )
        # re-orient seeds (RAS by default)
        new_seed = orientation(new_seed.unsqueeze(0))
        if seed is None:
            seed = new_seed
        else:
            seed += new_seed

    return seed.long().squeeze(0), {"mlabel2subclusters": mlabel2subclusters}

sample_intensities(seeds, device, genparams={})

Sample the intensities from the seeds.

Parameters:

Name Type Description Default
seeds Tensor

Tensor with the seeds.

required
device str

Device to use. Should be "cuda" or "cpu".

required
genparams dict

Dictionary with generation parameters. Defaults to {}. Should contain the keys "mus" and "sigmas" if the GMM parameters are to be fixed.

{}

Returns:

Type Description
Tensor

torch.Tensor: Tensor with the intensities.

Source code in fetalsyngen/generator/intensity/rand_gmm.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def sample_intensities(
    self, seeds: torch.Tensor, device: str, genparams: dict = {}
) -> torch.Tensor:
    """Sample the intensities from the seeds.

    Args:
        seeds (torch.Tensor): Tensor with the seeds.
        device (str): Device to use. Should be "cuda" or "cpu".
        genparams (dict, optional): Dictionary with generation parameters.
            Defaults to {}. Should contain the keys "mus" and "sigmas" if
            the GMM parameters are to be fixed.

    Returns:
        torch.Tensor: Tensor with the intensities.
    """
    nlabels = max(self.seed_labels) + 1
    nsamp = len(self.seed_labels)

    # # Sample GMMs means and stds
    mus = (
        25 + 200 * torch.rand(nlabels, dtype=torch.float, device=device)
        if "mus" not in genparams.keys()
        else genparams["mus"]
    )
    sigmas = (
        5
        + 20
        * torch.rand(
            nlabels,
            dtype=torch.float,
            device=device,
        )
        if "sigmas" not in genparams.keys()
        else genparams["sigmas"]
    )

    # if there are seed labels from the same generation class
    # set their mean to be the same with some random perturbation
    if self.generation_classes != self.seed_labels:
        # Ensure that seeds are within valid range
        if (seeds < 0).any() or (seeds >= mus.size(0)).any():
            raise ValueError(
                f"Invalid seed indices detected: min={seeds.min().item()}, max={seeds.max().item()} (expected range: 0-{mus.size(0)-1})"
            )

        mus[self.seed_labels] = torch.clamp(
            mus[self.generation_classes]
            + 25 * torch.randn(nsamp, dtype=torch.float, device=device),
            0,
            225,
        )
    intensity_image = mus[seeds] + sigmas[seeds] * torch.randn(
        seeds.shape, dtype=torch.float, device=device
    )
    intensity_image[intensity_image < 0] = 0

    return intensity_image, {
        "mus": mus,
        "sigmas": sigmas,
    }

SpatialDeformation

Class defining the spatial deformation of the image. Combines both random affine and nonlinear transformations to deform the image.

Source code in fetalsyngen/generator/deformation/affine_nonrigid.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
class SpatialDeformation:
    """
    Class defining the spatial deformation of the image.
    Combines both random affine and nonlinear transformations to deform the image.
    """

    def __init__(
        self,
        max_rotation: float,
        max_shear: float,
        max_scaling: float,
        size: Iterable[int],
        prob: float,
        nonlinear_transform: bool,
        nonlin_scale_min: float,
        nonlin_scale_max: float,
        nonlin_std_max: float,
        flip_prb: float,
        device: str,
    ):
        """Initialize the spatial deformation.

        Args:
            max_rotation (float): Maximum rotation in degrees.
            max_shear (float): Maximum shear.
            max_scaling (float): Maximum scaling.
            size (Iterable[int]): Size of the output image.
            prob (float): Probability of applying the deformation.
            nonlinear_transform (bool): Whether to apply nonlinear transformation.
            nonlin_scale_min (float): Minimum scale for the nonlinear transformation.
            nonlin_scale_max (float): Maximum scale for the nonlinear transformation.
            nonlin_std_max (float): Maximum standard deviation for the nonlinear transformation.
            flip_prb (float): Probability of flipping the image.
            device (str): Device to use for computation. Either "cuda" or "cpu".
        """
        self.size = size  # 256, 256, 256
        self.prob = prob
        self.flip_prb = flip_prb

        # randaffine parameters
        self.max_rotation = max_rotation
        self.max_shear = max_shear
        self.max_scaling = max_scaling

        # nonlinear transform parameters
        self.nonlinear_transform = nonlinear_transform
        self.nonlin_scale_min = nonlin_scale_min
        self.nonlin_scale_max = nonlin_scale_max
        self.nonlin_std_max = nonlin_std_max

        self.device = device

        self._prepare_grid()

    def _prepare_grid(self):

        xx, yy, zz = np.meshgrid(
            range(self.size[0]),
            range(self.size[1]),
            range(self.size[2]),
            sparse=False,
            indexing="ij",
        )
        self.xx = torch.tensor(xx, dtype=torch.float, device=self.device)
        self.yy = torch.tensor(yy, dtype=torch.float, device=self.device)
        self.zz = torch.tensor(zz, dtype=torch.float, device=self.device)
        self.c = torch.tensor(
            (np.array(self.size) - 1) / 2,
            dtype=torch.float,
            device=self.device,
        )
        self.xc = self.xx - self.c[0]
        self.yc = self.yy - self.c[1]
        self.zc = self.zz - self.c[2]

    def deform(
        self, image, segmentation, output, genparams: dict = {}
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
        """Deform the image, segmentation and output.

        Args:
            image (torch.Tensor): Image to deform.
            segmentation (torch.Tensor): Segmentation to deform.
            output (torch.Tensor): Output to deform.
            genparams (dict, optional): Dictionary with generation parameters. Defaults to {}.
                Should contain the keys "affine" and "non_rigid" if the parameters are fixed.
                Affine parameters should contain the keys "rotations", "shears" and "scalings".
                Non-rigid parameters should contain the keys "nonlin_scale", "nonlin_std" and "size_F_small".

        Returns:
            Deformed image, segmentation, output and deformation parameters.
        """
        deform_params = {}
        if np.random.rand() < self.prob or len(genparams.keys()) > 0:
            image_shape = output.shape
            flip = (
                np.random.rand() < self.flip_prb
                if "flip" not in genparams.keys()
                else genparams["flip"]
            )
            xx2, yy2, zz2, x1, y1, z1, x2, y2, z2, deform_params = (
                self.generate_deformation(
                    image_shape, random_shift=True, genparams=genparams
                )
            )
            # flip the image if nessesary
            if flip:
                segmentation = torch.flip(segmentation, [0])
                output = torch.flip(output, [0])
                image = torch.flip(image, [0]) if image is not None else None

            output = fast_3D_interp_torch(output, xx2, yy2, zz2, "linear")
            segmentation = fast_3D_interp_torch(
                segmentation.to(self.device), xx2, yy2, zz2, "nearest"
            )
            if image is not None:
                image = fast_3D_interp_torch(
                    image.to(self.device), xx2, yy2, zz2, "linear"
                )

            deform_params["flip"] = flip

        return image, segmentation, output, deform_params

    def generate_deformation(self, image_shape, random_shift=True, genparams={}):

        # sample affine deformation
        A, c2, aff_params = self.random_affine_transform(
            shp=image_shape,
            max_rotation=self.max_rotation,
            max_shear=self.max_shear,
            max_scaling=self.max_scaling,
            random_shift=random_shift,
            genparams=genparams.get("affine", {}),
        )

        # sample nonlinear deformation
        if self.nonlinear_transform:
            F, non_rigid_params = self.random_nonlinear_transform(
                nonlin_scale_min=self.nonlin_scale_min,
                nonlin_scale_max=self.nonlin_scale_max,
                nonlin_std_max=self.nonlin_std_max,
                genparams=genparams.get("non_rigid", {}),
            )
        else:
            F = None
            non_rigid_params = {}

        # deform the images
        xx2, yy2, zz2, x1, y1, z1, x2, y2, z2 = self.deform_image(image_shape, A, c2, F)

        return (
            xx2,
            yy2,
            zz2,
            x1,
            y1,
            z1,
            x2,
            y2,
            z2,
            {
                "affine": aff_params,
                "non_rigid": non_rigid_params,
            },
        )

    def random_affine_transform(
        self,
        shp,
        max_rotation,
        max_shear,
        max_scaling,
        random_shift=True,
        genparams={},
    ):
        rotations = (
            ((2 * max_rotation * np.random.rand(3) - max_rotation) / 180.0 * np.pi)
            if "rotations" not in genparams.keys()
            else genparams["rotations"]
        )

        shears = (
            2 * max_shear * np.random.rand(3) - max_shear
            if "shears" not in genparams.keys()
            else genparams["shears"]
        )
        scalings = (
            1 + (2 * max_scaling * np.random.rand(3) - max_scaling)
            if "scalings" not in genparams.keys()
            else genparams["scalings"]
        )
        # we divide distance maps by this, not perfect, but better than nothing
        A = torch.tensor(
            make_affine_matrix(rotations, shears, scalings),
            dtype=torch.float,
            device=self.device,
        )
        # sample center
        if random_shift:
            max_shift = (
                torch.tensor(
                    np.array(shp[0:3]) - self.size,
                    dtype=torch.float,
                    device=self.device,
                )
            ) / 2
            max_shift[max_shift < 0] = 0
            c2 = torch.tensor(
                (np.array(shp[0:3]) - 1) / 2,
                dtype=torch.float,
                device=self.device,
            ) + (
                2 * (max_shift * torch.rand(3, dtype=float, device=self.device))
                - max_shift
            )
        else:
            c2 = torch.tensor(
                (np.array(shp[0:3]) - 1) / 2,
                dtype=torch.float,
                device=self.device,
            )
        affine_params = {
            "rotations": rotations,
            "shears": shears,
            "scalings": scalings,
        }

        return A, c2, affine_params

    def random_nonlinear_transform(
        self, nonlin_scale_min, nonlin_scale_max, nonlin_std_max, genparams={}
    ):

        nonlin_scale = (
            nonlin_scale_min + np.random.rand(1) * (nonlin_scale_max - nonlin_scale_min)
            if "nonlin_scale" not in genparams.keys()
            else genparams["nonlin_scale"]
        )
        size_F_small = (
            np.round(nonlin_scale * np.array(self.size)).astype(int).tolist()
            if "size_F_small" not in genparams.keys()
            else genparams["size_F_small"]
        )
        nonlin_std = (
            nonlin_std_max * np.random.rand()
            if "nonlin_std" not in genparams.keys()
            else genparams["nonlin_std"]
        )
        Fsmall = nonlin_std * torch.randn(
            [*size_F_small, 3], dtype=torch.float, device=self.device
        )
        F = myzoom_torch(Fsmall, np.array(self.size) / size_F_small)

        return F, {
            "nonlin_scale": nonlin_scale,
            "nonlin_std": nonlin_std,
            "size_F_small": size_F_small,
        }

    def deform_image(self, shp, A, c2, F):
        if F is not None:
            # deform the images (we do nonlinear "first" ie after so we can do heavy coronal deformations in photo mode)
            xx1 = self.xc + F[:, :, :, 0]
            yy1 = self.yc + F[:, :, :, 1]
            zz1 = self.zc + F[:, :, :, 2]
        else:
            xx1 = self.xc
            yy1 = self.yc
            zz1 = self.zc

        xx2 = A[0, 0] * xx1 + A[0, 1] * yy1 + A[0, 2] * zz1 + c2[0]
        yy2 = A[1, 0] * xx1 + A[1, 1] * yy1 + A[1, 2] * zz1 + c2[1]
        zz2 = A[2, 0] * xx1 + A[2, 1] * yy1 + A[2, 2] * zz1 + c2[2]
        xx2[xx2 < 0] = 0
        yy2[yy2 < 0] = 0
        zz2[zz2 < 0] = 0
        xx2[xx2 > (shp[0] - 1)] = shp[0] - 1
        yy2[yy2 > (shp[1] - 1)] = shp[1] - 1
        zz2[zz2 > (shp[2] - 1)] = shp[2] - 1

        # Get the margins for reading images
        x1 = torch.floor(torch.min(xx2))
        y1 = torch.floor(torch.min(yy2))
        z1 = torch.floor(torch.min(zz2))
        x2 = 1 + torch.ceil(torch.max(xx2))
        y2 = 1 + torch.ceil(torch.max(yy2))
        z2 = 1 + torch.ceil(torch.max(zz2))
        xx2 -= x1
        yy2 -= y1
        zz2 -= z1

        x1 = x1.cpu().numpy().astype(int)
        y1 = y1.cpu().numpy().astype(int)
        z1 = z1.cpu().numpy().astype(int)
        x2 = x2.cpu().numpy().astype(int)
        y2 = y2.cpu().numpy().astype(int)
        z2 = z2.cpu().numpy().astype(int)
        return xx2, yy2, zz2, x1, y1, z1, x2, y2, z2

__init__(max_rotation, max_shear, max_scaling, size, prob, nonlinear_transform, nonlin_scale_min, nonlin_scale_max, nonlin_std_max, flip_prb, device)

Initialize the spatial deformation.

Parameters:

Name Type Description Default
max_rotation float

Maximum rotation in degrees.

required
max_shear float

Maximum shear.

required
max_scaling float

Maximum scaling.

required
size Iterable[int]

Size of the output image.

required
prob float

Probability of applying the deformation.

required
nonlinear_transform bool

Whether to apply nonlinear transformation.

required
nonlin_scale_min float

Minimum scale for the nonlinear transformation.

required
nonlin_scale_max float

Maximum scale for the nonlinear transformation.

required
nonlin_std_max float

Maximum standard deviation for the nonlinear transformation.

required
flip_prb float

Probability of flipping the image.

required
device str

Device to use for computation. Either "cuda" or "cpu".

required
Source code in fetalsyngen/generator/deformation/affine_nonrigid.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def __init__(
    self,
    max_rotation: float,
    max_shear: float,
    max_scaling: float,
    size: Iterable[int],
    prob: float,
    nonlinear_transform: bool,
    nonlin_scale_min: float,
    nonlin_scale_max: float,
    nonlin_std_max: float,
    flip_prb: float,
    device: str,
):
    """Store the deformation settings and precompute the output grid.

    Args:
        max_rotation (float): Maximum rotation in degrees.
        max_shear (float): Maximum shear.
        max_scaling (float): Maximum scaling.
        size (Iterable[int]): Size of the output image.
        prob (float): Probability of applying the deformation.
        nonlinear_transform (bool): Whether to apply nonlinear transformation.
        nonlin_scale_min (float): Minimum scale for the nonlinear transformation.
        nonlin_scale_max (float): Maximum scale for the nonlinear transformation.
        nonlin_std_max (float): Maximum standard deviation for the nonlinear transformation.
        flip_prb (float): Probability of flipping the image.
        device (str): Device to use for computation. Either "cuda" or "cpu".
    """
    # output grid shape and application probabilities
    self.size, self.prob, self.flip_prb = size, prob, flip_prb

    # ranges for the random affine part
    self.max_rotation, self.max_shear, self.max_scaling = (
        max_rotation,
        max_shear,
        max_scaling,
    )

    # settings of the random nonlinear part
    self.nonlinear_transform = nonlinear_transform
    self.nonlin_scale_min, self.nonlin_scale_max = nonlin_scale_min, nonlin_scale_max
    self.nonlin_std_max = nonlin_std_max

    self.device = device

    # must run last: the grid builder reads self.size and self.device
    self._prepare_grid()

deform(image, segmentation, output, genparams={})

Deform the image, segmentation and output.

Parameters:

Name Type Description Default
image Tensor

Image to deform.

required
segmentation Tensor

Segmentation to deform.

required
output Tensor

Output to deform.

required
genparams dict

Dictionary with generation parameters. Defaults to {}. Should contain the keys "affine" and "non_rigid" if the parameters are fixed. Affine parameters should contain the keys "rotations", "shears" and "scalings". Non-rigid parameters should contain the keys "nonlin_scale", "nonlin_std" and "size_F_small".

{}

Returns:

Type Description
tuple[Tensor, Tensor, Tensor, dict]

Deformed image, segmentation, output and deformation parameters.

Source code in fetalsyngen/generator/deformation/affine_nonrigid.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
def deform(
    self, image, segmentation, output, genparams: dict = {}
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
    """Deform the image, segmentation and output.

    Args:
        image (torch.Tensor): Image to deform.
        segmentation (torch.Tensor): Segmentation to deform.
        output (torch.Tensor): Output to deform.
        genparams (dict, optional): Dictionary with generation parameters. Defaults to {}.
            Should contain the keys "affine" and "non_rigid" if the parameters are fixed.
            Affine parameters should contain the keys "rotations", "shears" and "scalings".
            Non-rigid parameters should contain the keys "nonlin_scale", "nonlin_std" and "size_F_small".

    Returns:
        Deformed image, segmentation, output and deformation parameters.
    """
    deform_params = {}
    if np.random.rand() < self.prob or len(genparams.keys()) > 0:
        image_shape = output.shape
        flip = (
            np.random.rand() < self.flip_prb
            if "flip" not in genparams.keys()
            else genparams["flip"]
        )
        xx2, yy2, zz2, x1, y1, z1, x2, y2, z2, deform_params = (
            self.generate_deformation(
                image_shape, random_shift=True, genparams=genparams
            )
        )
        # flip the image if nessesary
        if flip:
            segmentation = torch.flip(segmentation, [0])
            output = torch.flip(output, [0])
            image = torch.flip(image, [0]) if image is not None else None

        output = fast_3D_interp_torch(output, xx2, yy2, zz2, "linear")
        segmentation = fast_3D_interp_torch(
            segmentation.to(self.device), xx2, yy2, zz2, "nearest"
        )
        if image is not None:
            image = fast_3D_interp_torch(
                image.to(self.device), xx2, yy2, zz2, "linear"
            )

        deform_params["flip"] = flip

    return image, segmentation, output, deform_params

RandResample

Bases: RandTransform

Resample the input image to a random resolution sampled uniformly between min_resolution and max_resolution with a probability of prob.

Axes whose sampled resolution is finer than or equal to the input resolution are not blurred; the image is still resampled onto the new grid.

Source code in fetalsyngen/generator/augmentation/synthseg.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
class RandResample(RandTransform):
    """Resample the input image to a random resolution sampled uniformly between
    `min_resolution` and `max_resolution` with a probability of `prob`.

    Before resizing, the image is Gaussian-blurred along every axis whose
    new spacing is coarser than the input resolution. Axes whose sampled
    spacing is finer than or equal to the input resolution are not blurred,
    but they are still resampled onto the new grid.
    """

    def __init__(
        self,
        prob: float,
        min_resolution: float,
        max_resolution: float,
    ):
        """
        Initialize the augmentation parameters.

        Args:
            prob (float): Probability of applying the augmentation.
            min_resolution (float): Minimum resolution for the augmentation (in mm).
            max_resolution (float): Maximum resolution for the augmentation (in mm).
        """
        self.prob = prob
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution

    def __call__(
        self, output, input_resolution, device, genparams: dict | None = None
    ) -> tuple[torch.Tensor, np.ndarray | None, dict]:
        """Apply the resampling to the input image.

        Args:
            output (torch.Tensor): Input image to resample.
            input_resolution (np.array): Resolution of the input image.
            device (str): Device to use for computation.
            genparams (dict, optional): Generation parameters. Defaults to None.
                Should contain the key "spacing" if the spacing is fixed. A
                recorded ``{"spacing": None}`` (returned when no resampling
                happened) replays as "no resampling".

        Returns:
            Tuple of the resampled image, the per-axis resizing factors (None
            when not applied) and ``{"spacing": ...}`` (None when not applied).
        """
        genparams = {} if genparams is None else genparams
        # draw first so the RNG stream is the same as on the random path
        sampled = np.random.rand() < self.prob
        if "spacing" in genparams:
            spacing = genparams["spacing"]
            if spacing is None:
                return output, None, {"spacing": None}
        elif sampled:
            spacing = np.array([1.0, 1.0, 1.0]) * self.random_uniform(
                self.min_resolution, self.max_resolution
            )
        else:
            return output, None, {"spacing": None}

        input_size = np.array(output.shape)
        # Ensure spacing and input_resolution are numpy arrays
        spacing = np.array(spacing)
        input_resolution = np.array(input_resolution)

        # stds of the Gaussian kernels used to simulate acquisition at the
        # coarser resolution (randomly jittered by 0.85-1.15)
        stds = (
            (0.85 + 0.3 * np.random.rand())
            * np.log(5)
            / np.pi
            * spacing
            / input_resolution
        )
        # no blur if thickness is equal or smaller to the resolution of the training data
        stds[spacing <= input_resolution] = 0.0
        output_blurred = gaussian_blur_3d(output, stds, device)

        # grid size at the new resolution; at least one voxel per axis so the
        # factors below never become zero
        new_size = np.maximum(
            (input_size * input_resolution / spacing).astype(int), 1
        )

        # resizing factors and the half-voxel offset of the new grid
        factors = new_size / input_size
        delta = (1.0 - factors) / (2.0 * factors)
        vx = np.arange(
            delta[0], delta[0] + new_size[0] / factors[0], 1 / factors[0]
        )[: new_size[0]]
        vy = np.arange(
            delta[1], delta[1] + new_size[1] / factors[1], 1 / factors[1]
        )[: new_size[1]]
        vz = np.arange(
            delta[2], delta[2] + new_size[2] / factors[2], 1 / factors[2]
        )[: new_size[2]]
        II, JJ, KK = np.meshgrid(vx, vy, vz, sparse=False, indexing="ij")
        II = torch.tensor(II, dtype=torch.float, device=device)
        JJ = torch.tensor(JJ, dtype=torch.float, device=device)
        KK = torch.tensor(KK, dtype=torch.float, device=device)

        output_resized = fast_3D_interp_torch(output_blurred, II, JJ, KK, "linear")
        return output_resized, factors, {"spacing": spacing.tolist()}

    def resize_back(self, output_resized, factors):
        """Undo the resizing of ``__call__`` and normalise by the maximum.

        When ``factors`` is None (no resampling happened), the input is
        returned unchanged and is not normalised.
        """
        if factors is not None:
            output_resized = myzoom_torch(output_resized, 1 / factors)
            return output_resized / torch.max(output_resized)
        else:
            return output_resized

__init__(prob, min_resolution, max_resolution)

Initialize the augmentation parameters.

Parameters:

Name Type Description Default
prob float

Probability of applying the augmentation.

required
min_resolution float

Minimum resolution for the augmentation (in mm).

required
max_resolution float

Maximum resolution for the augmentation (in mm).

required
Source code in fetalsyngen/generator/augmentation/synthseg.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def __init__(
    self,
    prob: float,
    min_resolution: float,
    max_resolution: float,
):
    """
    Store the probability and the resolution range of the augmentation.

    Args:
        prob (float): Probability of applying the augmentation.
        min_resolution (float): Minimum resolution for the augmentation (in mm).
        max_resolution (float): Maximum resolution for the augmentation.
    """
    self.prob, self.min_resolution, self.max_resolution = (
        prob,
        min_resolution,
        max_resolution,
    )

__call__(output, input_resolution, device, genparams={})

Apply the resampling to the input image.

Parameters:

Name Type Description Default
output Tensor

Input image to resample.

required
input_resolution array

Resolution of the input image.

required
device str

Device to use for computation.

required
genparams dict

Generation parameters. Default: {}. Should contain the key "spacing" if the spacing is fixed.

{}

Returns:

Type Description
Tensor

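Resampled image, the per-axis resizing factors (None when no resampling was applied) and a dict with the applied spacing.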

Source code in fetalsyngen/generator/augmentation/synthseg.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def __call__(
    self, output, input_resolution, device, genparams: dict = {}
) -> torch.Tensor:
    """Apply the resampling to the input image.

    Args:
        output (torch.Tensor): Input image to resample.
        input_resolution (np.array): Resolution of the input image.
        device (str): Device to use for computation.
        genparams (dict): Generation parameters.
            Default: {}. Should contain the key "spacing" if the spacing is fixed.

    Returns:
        Resampled image.
    """
    if np.random.rand() < self.prob or "spacing" in genparams.keys():
        input_size = np.array(output.shape)
        spacing = (
            np.array([1.0, 1.0, 1.0])
            * self.random_uniform(self.min_resolution, self.max_resolution)
            if "spacing" not in genparams.keys()
            else genparams["spacing"]
        )
        # Ensure spacing and input_resolution are numpy arrays
        spacing = np.array(spacing)
        input_resolution = np.array(input_resolution)

        # calculate stds of gaussian kernels
        # used for blurring to simulate resampling
        # the data to different resolutions
        stds = (
            (0.85 + 0.3 * np.random.rand())
            * np.log(5)
            / np.pi
            * spacing
            / input_resolution
        )
        # no blur if thickness is equal or smaller to the resolution of the training data
        stds[spacing <= input_resolution] = 0.0
        output_blurred = gaussian_blur_3d(output, stds, device)

        # resize the blurred output to the new resolution
        new_size = (np.array(input_size) * input_resolution / spacing).astype(int)

        # calculate the factors for the interpolation
        factors = np.array(new_size) / np.array(input_size)
        # delta is the offset for the interpolation
        delta = (1.0 - factors) / (2.0 * factors)
        vx = np.arange(
            delta[0], delta[0] + new_size[0] / factors[0], 1 / factors[0]
        )[: new_size[0]]
        vy = np.arange(
            delta[1], delta[1] + new_size[1] / factors[1], 1 / factors[1]
        )[: new_size[1]]
        vz = np.arange(
            delta[2], delta[2] + new_size[2] / factors[2], 1 / factors[2]
        )[: new_size[2]]
        II, JJ, KK = np.meshgrid(vx, vy, vz, sparse=False, indexing="ij")
        II = torch.tensor(II, dtype=torch.float, device=device)
        JJ = torch.tensor(JJ, dtype=torch.float, device=device)
        KK = torch.tensor(KK, dtype=torch.float, device=device)

        output_resized = fast_3D_interp_torch(output_blurred, II, JJ, KK, "linear")
        return output_resized, factors, {"spacing": spacing.tolist()}
    else:
        return output, None, {"spacing": None}

RandBiasField

Bases: RandTransform

Add a random bias field to the input image with a probability of prob.

Source code in fetalsyngen/generator/augmentation/synthseg.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
class RandBiasField(RandTransform):
    """Add a random bias field to the input image with a probability of `prob`."""

    def __init__(
        self,
        prob: float,
        scale_min: float,
        scale_max: float,
        std_min: float,
        std_max: float,
    ):
        """
        Initialize the augmentation parameters.

        Args:
            prob: Probability of applying the augmentation.
            scale_min: Minimum scale of the bias field.
            scale_max: Maximum scale of the bias field.
            std_min: Minimum standard deviation of the bias field.
            std_max: Maximum standard deviation of the bias field.
        """

        self.prob = prob
        self.scale_min = scale_min
        self.scale_max = scale_max
        self.std_min = std_min
        self.std_max = std_max

    def __call__(
        self, output, device, genparams: dict | None = None
    ) -> tuple[torch.Tensor, dict]:
        """Apply the bias field to the input image.

        The field is built in three steps. First, a low-resolution grid of
        shape ``round(bf_scale * output.shape)`` is filled with Gaussian noise
        of standard deviation ``bf_std``. Second, it is upsampled to the image
        shape. Third, it is exponentiated. The image is then multiplied by it.

        Args:
            output (torch.Tensor): Input image to apply the bias field.
            device (str): Device to use for computation.
            genparams (dict, optional): Generation parameters. Defaults to None.
                May contain "bf_scale" and/or "bf_std" to fix the bias field
                parameters; any non-empty dict forces the augmentation.
                "bf_size" is always derived from "bf_scale" and the image
                shape (a provided value is not used). A dict whose values are
                all None (as returned when the augmentation was skipped)
                replays that skip.

        Returns:
            Tuple of the image with the bias field applied and a dict with
            "bf_scale", "bf_std" and "bf_size" (all None when not applied).
        """
        genparams = {} if genparams is None else genparams
        apply = np.random.rand() < self.prob or len(genparams) > 0
        replay_skip = bool(genparams) and all(
            value is None for value in genparams.values()
        )
        if not apply or replay_skip:
            return output, {"bf_scale": None, "bf_std": None, "bf_size": None}

        image_size = output.shape
        bf_scale = genparams.get("bf_scale")
        if bf_scale is None:
            bf_scale = self.scale_min + np.random.rand(1) * (
                self.scale_max - self.scale_min
            )
        # at least one control point per axis so the zoom factor stays finite
        bf_size = (
            np.maximum(np.round(bf_scale * np.array(image_size)), 1)
            .astype(int)
            .tolist()
        )
        bf_std = genparams.get("bf_std")
        if bf_std is None:
            bf_std = self.std_min + (self.std_max - self.std_min) * np.random.rand(1)

        bf_low_scale = torch.tensor(
            bf_std,
            dtype=torch.float,
            device=device,
        ) * torch.randn(bf_size, dtype=torch.float, device=device)
        bf_interp = myzoom_torch(bf_low_scale, np.array(image_size) / bf_size)
        bf = torch.exp(bf_interp)

        return output * bf, {
            "bf_scale": bf_scale,
            "bf_std": bf_std,
            "bf_size": bf_size,
        }

__init__(prob, scale_min, scale_max, std_min, std_max)

Parameters:

Name Type Description Default
prob float

Probability of applying the augmentation.

required
scale_min float

Minimum scale of the bias field.

required
scale_max float

Maximum scale of the bias field.

required
std_min float

Minimum standard deviation of the bias field.

required
std_max float

Maximum standard deviation of the bias field.

required
Source code in fetalsyngen/generator/augmentation/synthseg.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def __init__(
    self,
    prob: float,
    scale_min: float,
    scale_max: float,
    std_min: float,
    std_max: float,
):
    """Store the sampling ranges used to build a random bias field.

    Args:
        prob: Probability of applying the augmentation.
        scale_min: Minimum scale of the bias field.
        scale_max: Maximum scale of the bias field.
        std_min: Minimum standard deviation of the bias field.
        std_max: Maximum standard deviation of the bias field.
    """
    # Chance of applying the bias field on a given call.
    self.prob = prob
    # Bounds of the uniformly sampled scale (relative size of the coarse grid).
    self.scale_min, self.scale_max = scale_min, scale_max
    # Bounds of the uniformly sampled standard deviation of the field.
    self.std_min, self.std_max = std_min, std_max

__call__(output, device, genparams={})

Apply the bias field to the input image.

Parameters:

Name Type Description Default
output Tensor

Input image to apply the bias field.

required
device str

Device to use for computation.

required
genparams dict

Generation parameters. Default: {}. May contain "bf_scale" and/or "bf_std" to fix those parameters, in which case the bias field is always applied. "bf_size" is not read; it is recomputed from "bf_scale" and the image shape.

{}

Returns:

Type Description
Tensor

Image with the bias field applied.

Source code in fetalsyngen/generator/augmentation/synthseg.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def __call__(self, output, device, genparams: dict = {}) -> torch.Tensor:
    """Apply the bias field to the input image.

    The field is applied with probability ``self.prob``, or always when
    ``genparams`` fixes at least one parameter. It is built in three steps:
    sample Gaussian noise on a coarse grid of size
    ``round(bf_scale * output.shape)`` (at least 1 per axis), scale it by
    ``bf_std``, then upsample it to the image shape with ``myzoom_torch``.
    The result is exponentiated and multiplied into the image.

    Args:
        output (torch.Tensor): Input image to apply the bias field.
        device (str): Device to use for computation.
        genparams (dict): Generation parameters.
            Default: {}. May contain "bf_scale" and/or "bf_std" to fix the
            corresponding parameters. Entries whose value is None, as returned
            by a call in which the field was not applied, are ignored.
            "bf_size" is not read; it is always recomputed from "bf_scale"
            and the image shape. The dict is never modified.

    Returns:
        A tuple. The first element is the image with the bias field applied,
        or the unchanged input when it was not applied. The second is a dict
        with the used "bf_scale", "bf_std" and "bf_size", all None when the
        field was not applied.
    """
    # Drop None-valued entries. Replaying the params of a skipped call
    # ({"bf_scale": None, ...}) would otherwise force the branch below and
    # crash on ``None * np.array(...)``.
    fixed = {key: value for key, value in genparams.items() if value is not None}
    # The random draw is evaluated first, so RNG consumption is unchanged.
    if np.random.rand() < self.prob or len(fixed) > 0:
        image_size = output.shape
        bf_scale = (
            self.scale_min + np.random.rand(1) * (self.scale_max - self.scale_min)
            if "bf_scale" not in fixed
            else fixed["bf_scale"]
        )
        # Clamp to 1 so that a tiny scale or a thin axis cannot yield an empty
        # grid and a division by zero in the zoom factor below.
        bf_size = (
            np.maximum(np.round(bf_scale * np.array(image_size)), 1)
            .astype(int)
            .tolist()
        )
        bf_std = (
            self.std_min + (self.std_max - self.std_min) * np.random.rand(1)
            if "bf_std" not in fixed
            else fixed["bf_std"]
        )

        # Low-resolution field, upsampled to the image size and exponentiated
        # so that the multiplicative bias is strictly positive.
        bf_low_scale = torch.tensor(
            bf_std,
            dtype=torch.float,
            device=device,
        ) * torch.randn(bf_size, dtype=torch.float, device=device)
        bf_interp = myzoom_torch(bf_low_scale, np.array(image_size) / bf_size)
        bf = torch.exp(bf_interp)

        return output * bf, {
            "bf_scale": bf_scale,
            "bf_std": bf_std,
            "bf_size": bf_size,
        }
    else:
        return output, {"bf_scale": None, "bf_std": None, "bf_size": None}

RandNoise

Bases: RandTransform

Add random Gaussian noise to the input image with a probability of prob.

Source code in fetalsyngen/generator/augmentation/synthseg.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
class RandNoise(RandTransform):
    """Add random Gaussian noise to the input image with a probability of `prob`."""

    def __init__(self, prob: float, std_min: float, std_max: float):
        """
        The image scale is 0-255 so the noise is added in the same scale.

        Args:
            prob: Probability of applying the augmentation.
            std_min: Minimum standard deviation of the noise.
            std_max: Maximum standard deviation of the noise.
        """
        self.prob = prob
        self.std_min = std_min
        self.std_max = std_max

    def __call__(self, output, device, genparams: dict = {}) -> torch.Tensor:
        """Apply the noise to the input image.

        The noise is applied with probability ``self.prob``, or always when
        ``genparams`` fixes a non-None "noise_std". Negative intensities
        produced by the noise are clipped to 0. The input tensor is not
        modified in place.

        Args:
            output (torch.Tensor): Input image to apply the noise.
            device (str): Device to use for computation.
            genparams (dict): Generation parameters.
                Default: {}. Should contain the key "noise_std" if the noise
                standard deviation is fixed. A None value, as returned by a
                call in which no noise was applied, is treated as not fixed.

        Returns:
            A tuple of the noisy image, or the unchanged input when no noise
            was applied, and ``{"noise_std": float | None}``.
        """
        # Treat {"noise_std": None} as "not fixed". Previously it reached
        # torch.tensor(None) and raised.
        fixed_std = genparams.get("noise_std")
        noise_std = None
        # The random draw is evaluated first, so RNG consumption is unchanged.
        if np.random.rand() < self.prob or fixed_std is not None:
            noise_std = (
                self.std_min + (self.std_max - self.std_min) * np.random.rand(1)
                if fixed_std is None
                else fixed_std
            )

            noise_std = torch.tensor(
                noise_std,
                dtype=torch.float,
                device=device,
            )
            # Out-of-place add, so the in-place clip below only touches the copy.
            output = output + noise_std * torch.randn(
                output.shape, dtype=torch.float, device=device
            )
            output[output < 0] = 0
        noise_std = noise_std.item() if noise_std is not None else None
        return output, {"noise_std": noise_std}

__init__(prob, std_min, std_max)

The image scale is 0–255, so the noise is added on the same scale. Parameters: prob — probability of applying the augmentation; std_min — minimum standard deviation of the noise; std_max — maximum standard deviation of the noise.

Source code in fetalsyngen/generator/augmentation/synthseg.py
199
200
201
202
203
204
205
206
207
208
209
def __init__(self, prob: float, std_min: float, std_max: float):
    """Configure the Gaussian-noise augmentation.

    The image scale is 0-255 so the noise is added in the same scale.

    Args:
        prob: Probability of applying the augmentation.
        std_min: Lower bound of the sampled noise standard deviation.
        std_max: Upper bound of the sampled noise standard deviation.
    """
    self.prob = prob
    # Bounds of the uniform range from which the noise std is drawn.
    self.std_min, self.std_max = std_min, std_max

__call__(output, device, genparams={})

Apply the noise to the input image.

Parameters:

Name Type Description Default
output Tensor

Input image to apply the noise.

required
device str

Device to use for computation.

required
genparams dict

Generation parameters. Default: {}. Should contain the key "noise_std" if the noise standard deviation is fixed.

{}

Returns:

Type Description
Tensor

Image with the noise applied.

Source code in fetalsyngen/generator/augmentation/synthseg.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
def __call__(self, output, device, genparams: dict = {}) -> torch.Tensor:
    """Apply the noise to the input image.

    The noise is applied with probability ``self.prob``, or always when
    ``genparams`` fixes a non-None "noise_std". Negative intensities produced
    by the noise are clipped to 0. The input tensor is not modified in place.

    Args:
        output (torch.Tensor): Input image to apply the noise.
        device (str): Device to use for computation.
        genparams (dict): Generation parameters.
            Default: {}. Should contain the key "noise_std" if the noise
            standard deviation is fixed. A None value, as returned by a call
            in which no noise was applied, is treated as not fixed.

    Returns:
        A tuple of the noisy image, or the unchanged input when no noise was
        applied, and ``{"noise_std": float | None}``.
    """
    # Treat {"noise_std": None} as "not fixed". Previously it reached
    # torch.tensor(None) and raised.
    fixed_std = genparams.get("noise_std")
    noise_std = None
    # The random draw is evaluated first, so RNG consumption is unchanged.
    if np.random.rand() < self.prob or fixed_std is not None:
        noise_std = (
            self.std_min + (self.std_max - self.std_min) * np.random.rand(1)
            if fixed_std is None
            else fixed_std
        )

        noise_std = torch.tensor(
            noise_std,
            dtype=torch.float,
            device=device,
        )
        # Out-of-place add, so the in-place clip below only touches the copy.
        output = output + noise_std * torch.randn(
            output.shape, dtype=torch.float, device=device
        )
        output[output < 0] = 0
    noise_std = noise_std.item() if noise_std is not None else None
    return output, {"noise_std": noise_std}

RandGamma

Bases: RandTransform

Apply gamma correction to the input image with a probability of prob.

Source code in fetalsyngen/generator/augmentation/synthseg.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
class RandGamma(RandTransform):
    """Apply gamma correction to the input image with a probability of `prob`."""

    def __init__(self, prob: float, gamma_std: float):
        """
        Args:
            prob: Probability of applying the augmentation.
            gamma_std: Standard deviation of the gamma correction.
        """
        self.prob = prob
        self.gamma_std = gamma_std

    def __call__(self, output, device, genparams: dict = {}) -> torch.Tensor:
        """Apply the gamma correction to the input image.

        The correction ``300 * (output / 300) ** gamma`` is applied with
        probability ``self.prob``, or always when ``genparams`` fixes a
        non-None "gamma". A random gamma is drawn as
        ``exp(gamma_std * N(0, 1))``.

        Args:
            output (torch.Tensor): Input image to apply the gamma correction.
            device (str): Device to use for computation.
            genparams (dict): Generation parameters.
                Default: {}. Should contain the key "gamma" if the gamma
                correction is fixed. A None value, as returned by a call in
                which no correction was applied, is treated as not fixed.

        Returns:
            A tuple of the corrected image, or the unchanged input when no
            correction was applied, and ``{"gamma": value | None}``.
        """
        # Treat {"gamma": None} as "not fixed". Previously it reached
        # torch.tensor(None) and raised.
        fixed_gamma = genparams.get("gamma")
        gamma = None
        # The random draw is evaluated first, so RNG consumption is unchanged.
        if np.random.rand() < self.prob or fixed_gamma is not None:
            gamma = (
                np.exp(self.gamma_std * np.random.randn(1)[0])
                if fixed_gamma is None
                else fixed_gamma
            )
            gamma_tensor = torch.tensor(
                gamma,
                dtype=float,
                device=device,
            )
            output = 300.0 * (output / 300.0) ** gamma_tensor
        return output, {"gamma": gamma}

__init__(prob, gamma_std)

Parameters:

Name Type Description Default
prob float

Probability of applying the augmentation.

required
gamma_std float

Standard deviation of the gamma correction.

required
Source code in fetalsyngen/generator/augmentation/synthseg.py
246
247
248
249
250
251
252
253
def __init__(self, prob: float, gamma_std: float):
    """Configure the random gamma augmentation.

    Args:
        prob: Probability of applying the augmentation.
        gamma_std: Standard deviation of the gamma correction. The sampled
            gamma is ``exp(gamma_std * N(0, 1))``.
    """
    self.prob, self.gamma_std = prob, gamma_std

__call__(output, device, genparams={})

Apply the gamma correction to the input image.

Parameters:

Name Type Description Default
output Tensor

Input image to apply the gamma correction.

required
device str

Device to use for computation.

required
genparams dict

Generation parameters. Default: {}. Should contain the key "gamma" if the gamma correction is fixed.

{}

Returns:

Type Description
Tensor

Image with the gamma correction applied.

Source code in fetalsyngen/generator/augmentation/synthseg.py
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
def __call__(self, output, device, genparams: dict = {}) -> torch.Tensor:
    """Apply the gamma correction to the input image.

    The correction ``300 * (output / 300) ** gamma`` is applied with
    probability ``self.prob``, or always when ``genparams`` fixes a non-None
    "gamma". A random gamma is drawn as ``exp(gamma_std * N(0, 1))``.

    Args:
        output (torch.Tensor): Input image to apply the gamma correction.
        device (str): Device to use for computation.
        genparams (dict): Generation parameters.
            Default: {}. Should contain the key "gamma" if the gamma
            correction is fixed. A None value, as returned by a call in which
            no correction was applied, is treated as not fixed.

    Returns:
        A tuple of the corrected image, or the unchanged input when no
        correction was applied, and ``{"gamma": value | None}``.
    """
    # Treat {"gamma": None} as "not fixed". Previously it reached
    # torch.tensor(None) and raised.
    fixed_gamma = genparams.get("gamma")
    gamma = None
    # The random draw is evaluated first, so RNG consumption is unchanged.
    if np.random.rand() < self.prob or fixed_gamma is not None:
        gamma = (
            np.exp(self.gamma_std * np.random.randn(1)[0])
            if fixed_gamma is None
            else fixed_gamma
        )
        gamma_tensor = torch.tensor(
            gamma,
            dtype=float,
            device=device,
        )
        output = 300.0 * (output / 300.0) ** gamma_tensor
    return output, {"gamma": gamma}

Fixed Image Generation