ldctbench.methods.wganvgg.Trainer
Trainer(args, device)
Bases: BaseTrainer
Trainer for WGAN-VGG [1].
[1] Q. Yang et al., “Low-dose CT image denoising using a generative adversarial network with Wasserstein distance and perceptual loss,” IEEE Transactions on Medical Imaging, vol. 37, no. 6, pp. 1348–1357, Jun. 2018.
Parameters:
- args (Namespace) – Arguments to configure the trainer.
- device (device) – Torch device to use for training.
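The two constructor arguments can be prepared with the standard library and PyTorch. This is a minimal sketch: the fields placed on the Namespace are hypothetical placeholders, because this page does not list which options the Trainer actually reads.

```python
import argparse

import torch

# Hypothetical configuration: the fields the Trainer actually reads are not
# listed on this page, so these names are placeholders only.
args = argparse.Namespace(lr=1e-5, batch_size=4, max_iterations=1000)

# Use a GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The documented constructor would then be called as:
# trainer = Trainer(args, device)
```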
gradient_penalty(target, fake, lam=10.0)
Compute the gradient penalty for the given target and fake.
Parameters:
- target (Tensor) – Ground-truth tensor.
- fake (Tensor) – Fake tensor, fake = G(x).
- lam (float, default: 10.0) – Weight (lambda) of the gradient penalty.
Returns:
- Tensor – Penalty computed from the provided target, fake and self.critic.
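A common form of this penalty is the WGAN-GP term lam * E[(||grad_xhat D(xhat)||_2 - 1)^2], evaluated at random interpolates xhat = eps * target + (1 - eps) * fake. The sketch below assumes that form; it is written as a free function that takes the critic explicitly (instead of reading self.critic), and it is not necessarily line-for-line what the library does.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic: nn.Module, target: torch.Tensor, fake: torch.Tensor,
                     lam: float = 10.0) -> torch.Tensor:
    """Sketch of a WGAN-GP style penalty: lam * E[(||grad D(x_hat)||_2 - 1)^2]."""
    # One interpolation weight per sample, broadcast over the remaining dims.
    eps = torch.rand(target.size(0), *([1] * (target.dim() - 1)), device=target.device)
    # Random points on straight lines between real and fake samples.
    x_hat = (eps * target.detach() + (1.0 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    # Gradient of the critic output with respect to the interpolated inputs.
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
    )[0]
    # Per-sample L2 norm of the gradient, then mean squared deviation from 1.
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return lam * ((grad_norm - 1.0) ** 2).mean()
```

For a linear critic D(x) = sum(w * x) the gradient is w everywhere, so the penalty reduces to lam * (||w||_2 - 1)^2. This makes it a convenient sanity check.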
train_step(batch)
Perform a single training step.
Parameters:
- batch (Dict[str, Tensor]) – Batch from the training dataloader containing the low-dose (LD) input and the high-dose (HD) ground truth.
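A WGAN-VGG style training step alternates a critic update (Wasserstein loss plus gradient penalty) with a generator update (adversarial loss plus a perceptual loss in feature space). The sketch below rests on several assumptions. The function name wgan_vgg_train_step and the batch keys "x" and "y" are hypothetical. It does one critic update per generator update. The feature_extractor argument stands in for the pretrained VGG used in the paper. The weight lam_perc=0.1 is a placeholder, not a value taken from this library.

```python
import torch
import torch.nn.functional as F


def gradient_penalty(critic, target, fake, lam=10.0):
    # Penalize deviations of the critic's gradient norm from 1 on random interpolates.
    eps = torch.rand(target.size(0), *([1] * (target.dim() - 1)), device=target.device)
    x_hat = (eps * target.detach() + (1.0 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    grads = torch.autograd.grad(scores, x_hat, torch.ones_like(scores), create_graph=True)[0]
    return lam * ((grads.flatten(start_dim=1).norm(2, dim=1) - 1.0) ** 2).mean()


def wgan_vgg_train_step(batch, generator, critic, feature_extractor,
                        opt_g, opt_c, lam_gp=10.0, lam_perc=0.1):
    # Hypothetical batch keys: "x" is the LD input, "y" the HD ground truth.
    x, y = batch["x"], batch["y"]

    # Critic update: Wasserstein estimate plus gradient penalty.
    fake = generator(x).detach()  # no generator gradients in this phase
    loss_c = critic(fake).mean() - critic(y).mean() + gradient_penalty(critic, y, fake, lam_gp)
    opt_c.zero_grad()
    loss_c.backward()
    opt_c.step()

    # Generator update: adversarial term plus perceptual (feature-space MSE) term.
    fake = generator(x)
    loss_adv = -critic(fake).mean()
    loss_perc = F.mse_loss(feature_extractor(fake), feature_extractor(y))
    loss_g = loss_adv + lam_perc * loss_perc
    opt_g.zero_grad()
    loss_g.backward()  # also fills critic .grad, which opt_c.zero_grad() clears next step
    opt_g.step()

    return {"loss_critic": loss_c.item(), "loss_generator": loss_g.item()}
```

Detaching the generator output during the critic phase keeps that update from touching the generator's weights. The frozen feature extractor only shapes the generator's loss.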
val_step(batch_idx, batch)
Perform a single validation step.
Parameters:
- batch_idx (int) – Index of the batch, needed for logging samples.
- batch (Dict[str, Tensor]) – Batch from the validation dataloader containing the low-dose (LD) input and the high-dose (HD) ground truth.
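A validation step typically runs the generator without gradients, computes image-quality metrics and uses batch_idx to decide whether to keep samples for logging. This is a minimal sketch under assumptions. The function name wgan_vgg_val_step, the batch keys "x" and "y", the choice of MSE and PSNR as metrics and the rule of logging only the first n_logged batches are all hypothetical. PSNR also assumes intensities scaled to [0, data_range].

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def wgan_vgg_val_step(batch_idx, batch, generator, n_logged=1, data_range=1.0):
    # Hypothetical batch keys: "x" is the LD input, "y" the HD ground truth.
    x, y = batch["x"], batch["y"]
    pred = generator(x)
    mse = F.mse_loss(pred, y)
    # PSNR in dB, assuming intensities are scaled to [0, data_range].
    psnr = 10.0 * torch.log10(data_range ** 2 / mse)
    metrics = {"val_mse": mse.item(), "val_psnr": psnr.item()}
    # Keep image triplets only for the first few batches, e.g. for logging.
    samples = (x, pred, y) if batch_idx < n_logged else None
    return metrics, samples
```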