
ldctbench.methods.wganvgg.Trainer

Trainer(args, device)

Bases: BaseTrainer

Trainer for WGAN-VGG [1].


  1. Q. Yang et al., “Low-dose CT image denoising using a generative adversarial network with Wasserstein distance and perceptual loss,” IEEE Transactions on Medical Imaging, vol. 37, no. 6, pp. 1348–1357, Jun. 2018.
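As a reading aid, the WGAN-VGG objective from [1] can be summarized as two alternating losses. Here x is the LD input, y the HD ground truth, G the generator, D the critic, and φ a fixed VGG feature extractor whose feature maps have width w, height h and depth d. This is a paraphrase of the paper's formulation, not a transcription of this trainer's code; λ corresponds to the lam parameter of gradient_penalty, and λ₁ weights the perceptual term.

```latex
\begin{aligned}
\hat{y} &= \alpha\, y + (1 - \alpha)\, G(x), \qquad \alpha \sim \mathcal{U}[0, 1] \\
\mathcal{L}_D &= \mathbb{E}\big[D(G(x))\big] - \mathbb{E}\big[D(y)\big]
  + \lambda\, \mathbb{E}\Big[\big(\lVert \nabla_{\hat{y}} D(\hat{y}) \rVert_2 - 1\big)^2\Big] \\
\mathcal{L}_G &= -\mathbb{E}\big[D(G(x))\big]
  + \lambda_1\, \mathbb{E}\Big[\tfrac{1}{whd}\,\lVert \phi(G(x)) - \phi(y) \rVert_F^2\Big]
\end{aligned}
```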

Parameters:

  • args (Namespace) –

    Arguments to configure the trainer.

  • device (device) –

    Torch device to use for training.

gradient_penalty(target, fake, lam=10.0)

Compute the gradient penalty for the given target and fake tensors.

Parameters:

  • target (Tensor) –

    Ground-truth tensor.

  • fake (Tensor) –

    Generator output, i.e. fake = G(x).

  • lam (float, default: 10.0 ) –

    Weight λ of the gradient penalty, by default 10.

Returns:

  • Tensor

    Gradient penalty computed from the provided target and fake using self.critic.
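A minimal PyTorch sketch of a standard WGAN-GP style gradient penalty, assuming the usual formulation: interpolate between target and fake with a per-sample random weight, evaluate the critic at the interpolate, and penalize deviations of the input-gradient norm from 1. The standalone function gradient_penalty_sketch and its explicit critic argument are hypothetical stand-ins for the method, which uses self.critic; the actual ldctbench implementation may differ in details.

```python
import torch
import torch.nn as nn


def gradient_penalty_sketch(
    critic: nn.Module, target: torch.Tensor, fake: torch.Tensor, lam: float = 10.0
) -> torch.Tensor:
    # Hypothetical standalone version: the critic is passed in explicitly
    # One random interpolation weight per sample, broadcast over remaining dims
    shape = (target.size(0),) + (1,) * (target.dim() - 1)
    alpha = torch.rand(shape, device=target.device, dtype=target.dtype)
    # Points on straight lines between real and generated samples
    interp = (alpha * target + (1 - alpha) * fake.detach()).requires_grad_(True)
    scores = critic(interp)
    # Gradient of the critic output w.r.t. its input; create_graph=True keeps
    # the penalty differentiable w.r.t. the critic parameters
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )
    # Per-sample L2 norm of the input gradient
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    # Penalize deviation of the norm from 1, weighted by lam
    return lam * ((grad_norm - 1) ** 2).mean()
```

Because the interpolation weights are random, the result generally varies between calls. With a linear critic, however, the input gradient equals the weight vector everywhere, so the penalty is deterministic, which makes it easy to check.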

train_step(batch)

Perform one training step on a single batch.

Parameters:

  • batch (Dict[str, Tensor]) –

    Batch from the training dataloader, containing the low-dose (LD) input and the high-dose (HD) ground truth.
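A WGAN-VGG training step alternates a critic update and a generator update. The sketch below is an illustration under several assumptions: one critic update per generator update, batch keys "x" (LD input) and "y" (HD ground truth), and a generic feature extractor that is not optimized, standing in for VGG. The function name wgan_vgg_train_step, its arguments, and the default lam_vgg=0.1 are hypothetical, not the ldctbench API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def wgan_vgg_train_step(
    generator: nn.Module,
    critic: nn.Module,
    features: nn.Module,
    opt_g: torch.optim.Optimizer,
    opt_c: torch.optim.Optimizer,
    batch: dict,
    lam_gp: float = 10.0,
    lam_vgg: float = 0.1,  # assumed perceptual-loss weight
) -> dict:
    # Hypothetical batch keys: "x" = LD input, "y" = HD ground truth
    x, y = batch["x"], batch["y"]

    # Critic update: minimize E[D(G(x))] - E[D(y)] + gradient penalty
    fake = generator(x).detach()
    shape = (y.size(0),) + (1,) * (y.dim() - 1)
    alpha = torch.rand(shape, device=y.device, dtype=y.dtype)
    interp = (alpha * y + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interp)
    (grads,) = torch.autograd.grad(
        scores, interp, torch.ones_like(scores), create_graph=True
    )
    gp = lam_gp * ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()
    loss_c = critic(fake).mean() - critic(y).mean() + gp
    opt_c.zero_grad()
    loss_c.backward()
    opt_c.step()

    # Generator update: adversarial term plus feature-space MSE (perceptual) term
    fake = generator(x)
    loss_adv = -critic(fake).mean()
    loss_vgg = F.mse_loss(features(fake), features(y))
    loss_g = loss_adv + lam_vgg * loss_vgg
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()

    return {"loss_c": loss_c.item(), "loss_g": loss_g.item()}
```

Practical setups often run several critic updates per generator update; that ratio is deliberately left out of this sketch.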

val_step(batch_idx, batch)

Perform one validation step on a single batch.

Parameters:

  • batch_idx (int) –

    Batch index, needed for logging samples.

  • batch (Dict[str, Tensor]) –

    Batch from the validation dataloader, containing the low-dose (LD) input and the high-dose (HD) ground truth.
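A validation step typically runs the generator without gradients and keeps a few images for logging. The sketch below is only an assumption about what such a step might do: the metric is MSE against the HD ground truth, and batch_idx decides whether the LD/prediction/HD triplet of this batch is returned for logging. The name wgan_vgg_val_step, the batch keys "x" and "y", and the log_batches parameter are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def wgan_vgg_val_step(
    generator: nn.Module, batch_idx: int, batch: dict, log_batches=(0,)
) -> dict:
    # Hypothetical batch keys: "x" = LD input, "y" = HD ground truth
    x, y = batch["x"], batch["y"]
    pred = generator(x)
    out = {"val_mse": F.mse_loss(pred, y).item()}
    # batch_idx selects which batches contribute images for logging;
    # LD input, prediction and HD target are concatenated side by side
    if batch_idx in log_batches:
        out["samples"] = torch.cat([x, pred, y], dim=-1).cpu()
    return out
```

Keying sample logging on batch_idx keeps the logged images consistent across epochs, so the same slices can be compared as training progresses.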