ldctbench.utils.metrics

Losses(dataloader, losses='loss')

Bases: object

Object to log losses

Examples:

In a training loop, losses can be logged as follows:

>>> from ldctbench.utils.metrics import Losses
>>> # Setup training
>>> dataloader = {"train": DataLoader(Dataset("train"), ...), "val": DataLoader(Dataset("val"), ...)}
>>> losses = Losses(dataloader)
>>> # Perform training and validation routine
>>> for epoch in range(n_epochs):
>>>     # Train model
>>>     for batch in dataloader["train"]:
>>>         x, y = batch
>>>         y_hat = model(x)
>>>         loss = criterion(y_hat, y)
>>>         # Perform weight optim
>>>         ...
>>>         losses.push(loss, "train")
>>>     losses.summarize("train")
>>>     # Validate model
>>>     for batch in dataloader["val"]:
>>>         x, y = batch
>>>         y_hat = model(x)
>>>         loss = criterion(y_hat, y)
>>>         losses.push(loss, "val")
>>>     losses.summarize("val")
>>> # Log
>>> losses.log(savedir, epoch, 0)
>>> losses.plot(savedir)

Parameters:

  • dataloader (Dict[str, DataLoader]) –

    Dict containing the dataloaders

  • losses (str, default: 'loss' ) –

    Name of losses, by default "loss"

log(savepath, iteration, iterations_before_val)

Log losses to wandb and local file

Parameters:

  • savepath (str) –

    Where to store the losses .csv file

  • iteration (int) –

    Current iteration

  • iterations_before_val (int) –

    Number of training iterations before validation

plot(savepath, y_log=True)

Plot losses to a file

Parameters:

  • savepath (str) –

    Where to store the PDF file

  • y_log (bool, default: True ) –

    Plot y axis in logarithmic scale, by default True

push(loss, phase, name='loss')

Push loss to object

Parameters:

  • loss (Union[Dict[str, Tensor], Tensor]) –

    Single loss or dict of losses

  • phase (str) –

    Phase to which the pushed loss(es) belong

  • name (str, default: 'loss' ) –

    Name of loss (only necessary if loss is not a dict), by default "loss"
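Since push also accepts a dict, a composite objective can log each loss term under its own name. The accumulator below is a hypothetical stand-in written only to illustrate the push/summarize semantics described here; it is not the actual Losses implementation, whose internals may differ.

```python
from statistics import mean

class LossAccumulator:
    """Hypothetical stand-in mimicking the push/summarize pattern."""

    def __init__(self):
        self.buffer = {}   # "phase/name" -> values pushed this epoch
        self.history = {}  # "phase/name" -> per-epoch means

    def push(self, loss, phase, name="loss"):
        # Accept either a single value or a dict of named values.
        items = loss.items() if isinstance(loss, dict) else [(name, loss)]
        for key, value in items:
            self.buffer.setdefault(f"{phase}/{key}", []).append(float(value))

    def summarize(self, phase):
        # Reduce buffered values to one mean per loss name for this epoch.
        for key in list(self.buffer):
            if key.startswith(f"{phase}/"):
                self.history.setdefault(key, []).append(mean(self.buffer.pop(key)))

acc = LossAccumulator()
acc.push({"mse": 0.4, "adversarial": 0.1}, phase="train")
acc.push({"mse": 0.2, "adversarial": 0.3}, phase="train")
acc.summarize("train")  # one mean per loss name for the epoch
```

With the real class, the equivalent call would be `losses.push({"mse": mse_loss, "adversarial": adv_loss}, "train")` inside the batch loop shown above.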

reset()

Reset losses (start a new epoch)

summarize(phase)

Summarize losses for this epoch

Parameters:

  • phase (str) –

    Phase for which to summarize the loss

Metrics(dataloader, metrics, denormalize_fn=None)

Bases: object

Object to log metrics

Examples:

In a training loop, metrics can be logged as follows:

>>> from ldctbench.utils.metrics import Metrics
>>> # Setup training
>>> dataloader = {"train": DataLoader(Dataset("train"), ...), "val": DataLoader(Dataset("val"), ...)}
>>> metrics = Metrics(dataloader, metrics=["SSIM", "PSNR", "RMSE"])
>>> # Perform training and validation routine
>>> for epoch in range(n_epochs):
>>>     # Train model
>>>     ...
>>>     # Validate model
>>>     for batch in dataloader["val"]:
>>>         x, y = batch
>>>         y_hat = model(x)
>>>         metrics.push(y_hat, y)
>>>     metrics.summarize()
>>> # Log
>>> metrics.log(savedir, epoch, 0)
>>> metrics.plot(savedir)

Parameters:

  • dataloader (Dict[str, DataLoader]) –

    Dict containing the dataloaders

  • metrics (List[str]) –

    Names of metrics to log. Each must be one of RMSE | SSIM | PSNR

  • denormalize_fn (Optional[Callable], default: None ) –

    Function to use for denormalizing images before computing metrics, by default None
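denormalize_fn is useful when the network operates on normalized intensities but metrics such as PSNR should be computed on the original value range. A minimal sketch, assuming mean/std normalization; the statistics below are placeholders, not values from the benchmark:

```python
def make_denormalize_fn(mean: float, std: float):
    """Return a callable inverting (x - mean) / std normalization."""
    def denormalize(x):
        return x * std + mean
    return denormalize

# Placeholder statistics, purely illustrative.
denormalize_fn = make_denormalize_fn(mean=1000.0, std=500.0)
denormalize_fn(0.0)  # maps the normalized value back to the original range
```

The resulting callable could then be passed as `Metrics(dataloader, metrics=["SSIM", "PSNR", "RMSE"], denormalize_fn=denormalize_fn)`.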

log(savepath, iteration, iterations_before_val)

Log metrics to wandb and local file

Parameters:

  • savepath (str) –

    Where to store the metrics .csv file

  • iteration (int) –

    Current iteration

  • iterations_before_val (int) –

    Number of training iterations before validation

plot(savepath)

Plot metrics to a file

Parameters:

  • savepath (str) –

    Where to store the PDF file

push(targets, predictions)

Compute metrics for given targets and predictions

Parameters:

  • targets (Union[ndarray, Tensor]) –

    Ground truth reference

  • predictions (Union[ndarray, Tensor]) –

    Prediction by the network

Raises:

  • ValueError

    If the shapes of predictions and targets are not identical

  • ValueError

    If metric provided in init function is not in SSIM | PSNR | RMSE
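For reference, the two distance-based metrics can be sketched in a few lines of plain Python. This is an illustration of the definitions, not the benchmark's implementation, which may differ (e.g. in the choice of data range); SSIM requires windowed statistics and is omitted here.

```python
import math

def rmse(target, prediction):
    """Root-mean-square error between two equally shaped sequences."""
    if len(target) != len(prediction):
        raise ValueError("Shapes of predictions and targets must be identical")
    n = len(target)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(target, prediction)) / n)

def psnr(target, prediction, data_range=1.0):
    """Peak signal-to-noise ratio in dB for a given data range."""
    return 20 * math.log10(data_range / rmse(target, prediction))

y = [0.0, 0.5, 1.0]       # ground truth reference
y_hat = [0.1, 0.5, 0.9]   # prediction by the network
error = rmse(y, y_hat)
quality = psnr(y, y_hat)  # higher is better
```

A mismatched input such as `rmse([0.0], [0.0, 1.0])` raises the same kind of ValueError documented above.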

reset()

Reset metrics (start new epoch)

summarize()

Summarize metric for this epoch