ldctbench.methods.wganvgg.Trainer
Trainer(args, device)
Bases: BaseTrainer
Trainer for WGAN-VGG [1].

[1] Q. Yang et al., "Low-dose CT image denoising using a generative adversarial network with wasserstein distance and perceptual loss," IEEE Transactions on Medical Imaging, vol. 37, no. 6, pp. 1348–1357, Jun. 2018.
Parameters:
- args (Namespace) – Arguments to configure the trainer.
- device (device) – Torch device to use for training.
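For orientation, the objectives behind WGAN-VGG can be written in the WGAN-GP form with a VGG perceptual term, as in the cited paper. Here x is the LD input, y the HD ground truth, G the generator, D the critic, φ a fixed VGG feature extractor whose feature maps have size w × h × d, λ the gradient-penalty weight (lam) and λ₁ the perceptual-loss weight. This is a sketch of the general formulation; the exact feature layer and weighting used by ldctbench are assumptions not stated in this reference.

```latex
\begin{aligned}
\mathcal{L}_D &= \mathbb{E}\big[D(G(x))\big] - \mathbb{E}\big[D(y)\big]
  + \lambda\,\mathbb{E}_{\hat{y}}\Big[\big(\lVert \nabla_{\hat{y}} D(\hat{y}) \rVert_2 - 1\big)^2\Big],
  \qquad \hat{y} = \epsilon\, y + (1-\epsilon)\, G(x),\quad \epsilon \sim \mathcal{U}[0,1] \\
\mathcal{L}_G &= -\,\mathbb{E}\big[D(G(x))\big]
  + \lambda_1\,\mathbb{E}\Big[\tfrac{1}{whd}\,\big\lVert \phi(G(x)) - \phi(y) \big\rVert_F^2\Big]
\end{aligned}
```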
gradient_penalty(target, fake, lam=10.0)
Compute the gradient penalty for the given target and fake samples.
Parameters:
- target (Tensor) – Ground truth tensor.
- fake (Tensor) – Generator output, fake = G(x).
- lam (float, default: 10.0) – Weight λ of the gradient penalty. Defaults to 10.0.
Returns:
- Tensor – Gradient penalty computed from target, fake and self.critic.
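The gradient penalty of WGAN-GP evaluates the critic on random interpolates between real and generated samples and penalizes deviations of the critic's gradient norm from 1. The following is a minimal standalone sketch of that technique, not ldctbench's code: the critic is passed in explicitly (the method itself uses self.critic), and the per-sample interpolation and detaching of fake are assumptions about a typical implementation.

```python
import torch


def gradient_penalty(critic, target, fake, lam=10.0):
    """WGAN-GP penalty: lam * E[(||grad D(x_hat)||_2 - 1)^2]."""
    batch_size = target.size(0)
    # One interpolation coefficient per sample, broadcast over the remaining dims
    alpha = torch.rand(batch_size, *([1] * (target.dim() - 1)), device=target.device)
    # Detach fake so this term only trains the critic; x_hat needs grad for autograd.grad
    x_hat = (alpha * target + (1 - alpha) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    # create_graph=True keeps the graph so the penalty can itself be backpropagated
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )
    # L2 norm of each sample's gradient over all non-batch dims
    grad_norm = grads.flatten(1).norm(2, dim=1)
    return lam * ((grad_norm - 1) ** 2).mean()
```

In a critic update, the returned value is added to critic(fake).mean() - critic(target).mean() before calling backward().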
train_step(batch)
Perform a single training step.
Parameters:
- batch (Dict[str, Tensor]) – Batch from the training dataloader containing the low-dose (LD) input and the high-dose (HD) ground truth.
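A WGAN-VGG training step alternates critic updates (Wasserstein loss plus gradient penalty) with a generator update (adversarial loss plus perceptual loss). The sketch below shows one way to write such a step in plain PyTorch. It is hypothetical: the batch keys "x"/"y", the explicit model and optimizer arguments, n_critic and the illustrative lam_perc value are assumptions, whereas the documented method takes only batch and keeps its models on the trainer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def gradient_penalty(critic, target, fake, lam=10.0):
    # Critic gradient norm on random real/fake interpolates, pushed towards 1
    alpha = torch.rand(target.size(0), *([1] * (target.dim() - 1)), device=target.device)
    x_hat = (alpha * target + (1 - alpha) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    (grads,) = torch.autograd.grad(scores, x_hat, torch.ones_like(scores), create_graph=True)
    return lam * ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()


def train_step(batch, generator, critic, feature_extractor, opt_g, opt_d,
               lam_perc=0.1, n_critic=1):
    # Hypothetical keys: "x" = LD input, "y" = HD ground truth
    x, y = batch["x"], batch["y"]

    # Critic updates: minimize D(G(x)) - D(y) + gradient penalty
    for _ in range(n_critic):
        with torch.no_grad():
            fake = generator(x)
        loss_d = critic(fake).mean() - critic(y).mean() + gradient_penalty(critic, y, fake)
        opt_d.zero_grad()
        loss_d.backward()
        opt_d.step()

    # Generator update: adversarial term plus perceptual (feature-space MSE) term.
    # Critic grads accumulated here are cleared by opt_d.zero_grad() next step.
    fake = generator(x)
    adv = -critic(fake).mean()
    perc = F.mse_loss(feature_extractor(fake), feature_extractor(y))
    loss_g = adv + lam_perc * perc  # lam_perc: illustrative weight, not from ldctbench
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()
    return {"loss_d": loss_d.item(), "loss_g": loss_g.item()}
```

In practice feature_extractor would be a frozen, pretrained VGG network; any frozen module mapping images to feature maps fits this sketch.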
val_step(batch_idx, batch)
Perform a single validation step.
Parameters:
- batch_idx (int) – Batch index, needed for logging samples.
- batch (Dict[str, Tensor]) – Batch from the validation dataloader containing the low-dose (LD) input and the high-dose (HD) ground truth.
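Because batch_idx is used for logging samples, a validation step typically evaluates the generator without gradients, computes a metric and keeps images from one batch for visual inspection. The sketch below illustrates that pattern under stated assumptions: the batch keys "x"/"y", the explicit generator argument, MSE as the metric and logging only batch 0 are all hypothetical choices, not ldctbench's behaviour.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def val_step(batch_idx, batch, generator):
    # Hypothetical keys: "x" = LD input, "y" = HD ground truth
    x, y = batch["x"], batch["y"]
    was_training = generator.training
    generator.eval()
    pred = generator(x)
    generator.train(was_training)  # restore the previous mode

    metrics = {"val_mse": F.mse_loss(pred, y).item()}
    samples = None
    if batch_idx == 0:  # keep the first image of the first batch for logging
        samples = {"input": x[:1].cpu(), "prediction": pred[:1].cpu(), "target": y[:1].cpu()}
    return metrics, samples
```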