# NectarGAN API - LossManager
The loss manager is one of the core features of the NectarGAN API. It is a drop-in solution for managing, tracking, and logging everything related to loss in your model.
Source: `nectargan.losses.loss_manager.LossManager`
## Key Features
- Builds an easy-to-use wrapper around any loss function, letting you evaluate losses in your training script as simply as calling them directly, while dramatically expanding the backend functionality of any loss function registered with the `LossManager`.
- Caches loss function results in multiple formats, with easy-to-use mechanisms for recalling the values during training.
- An intelligent cache management system allows mean loss values to be cached in memory and dumped to a JSON log at your discretion, or automatically if a configurable cache limit is reached.
- Quickly initialize a configurable objective function with a pre-built loss spec, or register your own loss functions with the `LossManager` to build your own model objective from scratch, while still being able to use all the QOL features that the `LossManager` offers. You can even define your own reusable loss specs to feed to the `LossManager`, and it will take care of the rest.
## LossManager Dataclasses
To understand how the `LossManager` functions, and how it manages the data for the losses registered with it, we first have to take a quick look at two dataclasses at the core of its functionality. These are:
1. `nectargan.losses.lm_data.LMHistory`
Starting with the simpler of the two dataclasses, `LMHistory` has only one job: storing previous loss value history.
Every loss function registered with a `LossManager` instance has an `LMHistory` instance assigned to it, in a way that will be explained momentarily. An `LMHistory` instance contains just two lists, but they are dual-purpose. If loss logging is enabled, i.e. `LossManager(enable_logging=True)`, these two lists store the mean value of the loss result tensor and the current weight value of the loss, both as 32-bit floating point values, every time the parent loss function is called via `LossManager.compute_loss_xy()`.
If logging is disabled, however, each time `LossManager.compute_loss_xy()` is called for a given loss, both lists in that loss's `LMHistory` are cleared, after which the new values are appended to each list. In practice, this means that with `enable_logging=False` each list only ever stores a single value at any given time: the most recent loss mean and weight, respectively.
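To make the dual-purpose behavior concrete, here is a minimal sketch of what `LMHistory` and its append/clear logic could look like. This is an illustration, not NectarGAN's actual source; the real definition lives in `nectargan.losses.lm_data`, and the `record` helper here is hypothetical:

```python
from dataclasses import dataclass, field

@dataclass
class LMHistory:
    """Stores previous loss history for a single registered loss (sketch)."""
    losses: list[float] = field(default_factory=list)   # mean loss values
    weights: list[float] = field(default_factory=list)  # weight at compute time

def record(history: LMHistory, mean_loss: float, weight: float,
           enable_logging: bool) -> None:
    """Mimics the behavior described above: append when logging is on,
    clear-then-append when it is off (so only the latest pair survives)."""
    if not enable_logging:
        history.losses.clear()
        history.weights.clear()
    history.losses.append(mean_loss)
    history.weights.append(weight)
```

With `enable_logging=True` the lists grow with every call; with it disabled, each list holds exactly one value after any call.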
2. `nectargan.losses.lm_data.LMLoss`
This dataclass is responsible for storing all information about a registered loss. For every loss function that is registered with a `LossManager` instance via `LossManager.register_loss_fn()`, an `LMLoss` instance is created which describes the loss function. A full description of the values contained within an `LMLoss` instance can be seen by clicking on the above link, but here is a rough outline:
- `name`: a string, unique to this registered loss, which is used for lookup by various `LossManager` functions.
- `function`: a reference to the `torch.nn.Module` for the loss function. This can be almost any `Module`, as long as it has a forward function that returns a tensor. One caveat is that, currently, loss functions registered with the `LossManager` can only accept two input tensors for loss computation (`y`, `y_fake`), although I do plan to expand that at some point in the future.
- `loss_weight`: the weight value (lambda) to apply to the resulting loss tensor when the loss is called, before the tensor is returned by `LossManager.compute_loss_xy()`.
- `schedule`: a `Schedule` object defining a weight schedule for the given loss. If no `Schedule` is provided when the `LMLoss` is initialized, the provided `loss_weight` will be used for the duration of training.
- `last_loss_map`: not set when initializing an `LMLoss` object. It is instead initialized as a dummy tensor, and then used by the `LossManager` each time the parent loss function is run to store a detached version of the resulting loss tensor so it can be recalled for visualization.
- `history`: also not set at init time. A unique `LMHistory` object is automatically created and assigned to every loss registered with the loss manager.
- `tags`: an optional list of strings containing identifier tags which can be used to search for and filter registered losses in various `LossManager` functions.
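The field list above can be summarized as a dataclass skeleton. This is a dependency-free stand-in, not the real definition: in NectarGAN, `function` is a `torch.nn.Module`, `last_loss_map` is a `torch.Tensor`, and `schedule` is a `Schedule` object; plain Python types are used here so the sketch runs anywhere:

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Optional

@dataclass
class LMLoss:
    """Illustrative stand-in for nectargan.losses.lm_data.LMLoss."""
    name: str                                 # unique lookup key
    function: Callable[[Any, Any], Any]       # really a torch.nn.Module
    loss_weight: float = 1.0                  # lambda applied to the result
    schedule: Optional[Any] = None            # weight schedule, if provided
    tags: list[str] = field(default_factory=list)
    last_loss_map: Any = None                 # set by the LossManager at run time
    history: Any = None                       # a fresh LMHistory per loss
```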
## Using the LossManager
Initializing a new `LossManager` instance:

```python
from nectargan.config.config_manager import ConfigManager
from nectargan.losses.loss_manager import LossManager

config_manager = ConfigManager('path/to/config.json')

loss_manager = LossManager(
    config=config_manager.data,
    experiment_dir='/path/to/experiment/output/directory')
```
Register a new loss function with a `LossManager` instance:

```python
import torch.nn as nn

L1 = nn.L1Loss().to(config_manager.data.common.device)

loss_manager.register_loss_fn(
    loss_name='mylossfunction',
    loss_fn=L1,
    loss_weight=100.0,
    tags=['descriptive_lookup_tag'])
```
> [!WARNING]
> The `loss_name` you assign to your loss function when you register it must be unique among all other loss functions registered with that `LossManager` instance. If you attempt to register a loss function with a name that is already registered, the `LossManager` will raise an exception.
Running your loss function via the loss manager returns the result of the given loss function's `forward()` as a `torch.Tensor`. The tensor that is returned has the weight pre-applied by `LossManager.compute_loss_xy()` -> `LossManager._weight_loss()`, based on whatever the current weight value of the registered loss is:
```python
import torch

y = torch.Tensor()       # Ground truth
y_fake = torch.Tensor()  # Generator output

result: torch.Tensor = loss_manager.compute_loss_xy(
    loss_name='mylossfunction', x=y_fake, y=y)
```
## Querying Registered Loss Data
Data relating to losses registered with a given `LossManager` instance can be retrieved in a variety of ways, depending on exactly what data you are trying to query and what format you would like it returned in.
### Querying LMLoss Objects Directly

The most flexible method is to query the raw `LMLoss` objects directly. This can be done as follows:

```python
losses: dict[str, LMLoss] = loss_manager.get_registered_losses(query=None)
```
This will return all registered loss functions as a dict. The key for each loss will be the name it was registered with. So, for our example above, we could then query any info related to our `mylossfunction` loss like this:
```python
mylossfn: LMLoss = losses['mylossfunction']  # Query the LMLoss object
lossfn = mylossfn.function                   # Get the loss function module
loss_map = mylossfn.last_loss_map            # Get the most recent loss result as a torch.Tensor
```
> [!NOTE]
> There is one thing to be aware of when querying the `LMLoss` objects directly like this. Were you to do this, expecting to get the loss value and weight history lists:
>
> ```python
> values : list[float] = mylossfn.history.losses
> weights: list[float] = mylossfn.history.weights
> ```
>
> you would find that `values` and `weights` are empty lists. This is intentional. `LossManager.get_registered_losses()` has an optional flag called `strip` which defaults to `True`. If this flag is not overridden with a value of `False`, each loss's `history.losses` and `history.weights` lists are cleared.
>
> The reasoning behind this is that, depending on what the `history_buffer_size` of the `LossManager` is set to, these lists can get fairly long. If you have a significant number of registered losses, passing them around can become a fairly heavy task, so they are stripped by default to reduce memory overhead. If you need these values for whatever reason, though, just call `LossManager.get_registered_losses()` with `strip=False`. All that said, as long as the `LossManager`'s `history_buffer_size` is kept below a reasonable value (i.e. ~100,000), the cost realistically isn't all that concerning.
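The strip behavior can be sketched as follows. The `History`/`Loss` types below are hypothetical stand-ins for `LMHistory`/`LMLoss`, and this function only approximates what `LossManager.get_registered_losses()` does:

```python
import copy
from dataclasses import dataclass, field

@dataclass
class History:  # stand-in for LMHistory
    losses: list[float] = field(default_factory=list)
    weights: list[float] = field(default_factory=list)

@dataclass
class Loss:     # stand-in for LMLoss
    name: str
    history: History = field(default_factory=History)

def get_registered_losses(registry: dict[str, Loss],
                          strip: bool = True) -> dict[str, Loss]:
    """Return copies of the registered losses, clearing the (potentially
    long) history lists unless strip=False is passed."""
    out = copy.deepcopy(registry)
    if strip:
        for loss in out.values():
            loss.history.losses.clear()
            loss.history.weights.clear()
    return out
```

Because the copies are stripped rather than the registry itself, the `LossManager`'s own history remains intact either way.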
### Querying Loss Values (as a dict)

Loss values can be retrieved in dictionary form as follows:

```python
values: dict[str, float] = loss_manager.get_loss_values(precision=2)
```
This will return a dictionary of floating point values with lookup keys matching the corresponding loss's name. The optional precision argument tells the function how many digits after the decimal point to round the returned loss values to. The default is 2 to keep things tidy if you want to print or otherwise log the values, but you can increase it if you need more precise return values.
> [!IMPORTANT]
> `LossManager.get_loss_values()` uses the last value stored in `LossManager.history` for each loss. As such, this function should generally be called AFTER all of the registered loss functions have been run for the batch. Calling it before running some or all of the loss functions could lead to unexpected results.
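The "last stored value, rounded" contract can be approximated with a small sketch. This is a hypothetical helper operating on plain per-loss history lists, not NectarGAN's actual implementation:

```python
def get_loss_values(histories: dict[str, list[float]],
                    precision: int = 2) -> dict[str, float]:
    """Take the most recent cached mean for each loss and round it."""
    return {name: round(vals[-1], precision)
            for name, vals in histories.items()}
```

This also makes the warning above concrete: if a loss has not been run yet this batch, `vals[-1]` is simply whatever was cached last time.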
### Querying Loss Values (as a tensor)

The most recent loss tensor, which is detached and stored in the loss function's `LMLoss` every time the loss is run via `LossManager.compute_loss_xy()`, can be queried with:

```python
tensors: dict[str, torch.Tensor] = loss_manager.get_loss_tensors()
```
> [!IMPORTANT]
> `LossManager.get_loss_tensors()` uses the last value stored in the `LMLoss` objects' `last_loss_map`. As such, this function should generally be called AFTER all of the registered loss functions have been run for the batch. Calling it before running some or all of the loss functions could lead to unexpected results.
### Querying Loss Weights

The current weight value of all registered losses can be retrieved as follows:

```python
weights: dict[str, float] = loss_manager.get_loss_weights(precision=2)
```
This will return a dictionary of floating point values with lookup keys matching the corresponding loss's name. The values will represent the current weight of the loss at the time that the function was run. The optional precision argument tells the function how many digits after the decimal point to round the returned weight values to. The default is 2 to keep things tidy if you want to print or otherwise log the values, but you can increase it if you need more precise return values.
> [!IMPORTANT]
> `LossManager.get_loss_weights()` uses the last weight value stored in `LossManager.schedule.current_value`. Since this value is updated immediately after the associated loss has been run, the weight values that this function returns are the ones which will be applied the next time the loss is run.
### Using Tags to Query Registered Losses
All of the above loss querying functions accept an optional argument called query, which is a list of strings. This can be used to query loss values by tag; only loss values which have a tag matching one of the strings in the input query argument will be returned.
For example, when we registered our loss function above, we assigned it a tag of `descriptive_lookup_tag`. Were we then to query losses with that tag, it would look something like this:

```python
lossfns: dict[str, LMLoss] = loss_manager.get_registered_losses(query=['descriptive_lookup_tag'])
values : dict[str, float] = loss_manager.get_loss_values(query=['descriptive_lookup_tag'])
tensors: dict[str, Tensor] = loss_manager.get_loss_tensors(query=['descriptive_lookup_tag'])
weights: dict[str, float] = loss_manager.get_loss_weights(query=['descriptive_lookup_tag'])
```
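The tag-matching rule ("return losses with at least one tag in the query") presumably reduces to something like the following sketch; `TaggedLoss` and `filter_by_tags` are hypothetical stand-ins, not NectarGAN API:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class TaggedLoss:  # stand-in carrying just the tags field of LMLoss
    tags: list[str] = field(default_factory=list)

def filter_by_tags(losses: dict[str, TaggedLoss],
                   query: Optional[list[str]]) -> dict[str, TaggedLoss]:
    """Keep losses with at least one tag matching the query; query=None keeps all."""
    if query is None:
        return dict(losses)
    return {name: loss for name, loss in losses.items()
            if any(tag in loss.tags for tag in query)}
```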
## Dumping Cached Values to Loss Log
> [!NOTE]
> This section is only applicable if the given `LossManager` instance was initialized with `enable_logging=True`.
When a `LossManager` instance is initialized, an optional argument called `history_buffer_size` can be passed. This value defines how many values (stored as 32-bit floats) can be held in each list (i.e. `losses`, `weights`) of each `LMLoss`'s `LMHistory` container. By default, this is 50,000. So, with the default value, each registered loss is allowed to store 50,000 previous loss values and 50,000 previous loss weight values: 100,000 32-bit floats per registered loss. This is a perfectly acceptable fallback on any modern system (you could multiply this value by 100 and still be safe), and it felt fine on every dataset I tested, but you are also welcome to set the value higher when you initialize the `LossManager`. The Toolbox `Pix2pixTrainerWorker` actually bypasses this cap altogether by always setting the buffer size slightly higher than what the options selected in the UI would dictate.
Whenever any registered loss is run from a `LossManager` instance, the `LossManager` first does a quick check to see if the buffer for that loss is full. If it is, it dumps the loss's buffer to the log, clears the buffer, then runs the loss function it was originally going to run and appends the result to the newly freed buffer. This is a perfectly good way to handle loss logging with the `LossManager`: it is set-and-forget. Just give it a value for `history_buffer_size` and the `LossManager` will take care of the loss logging and memory management from there.
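The check-dump-clear-append flow just described can be sketched in a few lines. This is an illustrative helper, not the `LossManager`'s actual internals, and a plain list stands in for the JSON loss log:

```python
def append_with_dump(buffer: list[float], value: float,
                     buffer_size: int, log: list[list[float]]) -> None:
    """If the buffer is full, dump its contents to the log and clear it,
    then append the new value to the freed-up buffer."""
    if len(buffer) >= buffer_size:
        log.append(list(buffer))  # stand-in for writing the JSON loss log
        buffer.clear()
    buffer.append(value)
```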
However, if you would like more control over exactly when the logs are dumped (at the end of each epoch, for example, as in the core `Pix2pixTrainer`'s `on_epoch_end()` function, or every x epochs, as in the Toolbox `Pix2pixTrainerWorker`'s `run()` function), you can instead force the `LossManager` to dump its buffers to the loss log with:

```python
loss_manager.update_loss_log(silent=True, capture=False)
```
> [!TIP]
> `LossManager.update_loss_log()` has two optional boolean arguments, `silent` (default `True`) and `capture` (default `False`). If `silent` is `False`, the `LossManager` will print a string to the console after it has dumped its values to the log, with a timestamp showing how long the operation took. If `silent` is `False` and `capture` is `True`, however, the string that would have been printed is instead returned by the function.
## Loss Specs?
Loss specifications (specs) are a novel way to define reusable objective functions for your models. At their core, loss specs are just standard Python functions with one specific requirement: they must return a dictionary of string-mapped `LMLoss` objects.
See here for more information on loss specifications
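As a shape-only illustration of that return contract (using a minimal hypothetical `LMLoss` stand-in rather than the real dataclass, and made-up loss names):

```python
from dataclasses import dataclass

@dataclass
class LMLoss:  # minimal stand-in for nectargan.losses.lm_data.LMLoss
    name: str
    loss_weight: float = 1.0

def my_loss_spec() -> dict[str, LMLoss]:
    """A loss spec: any function returning string-mapped LMLoss objects."""
    return {
        'l1': LMLoss(name='l1', loss_weight=100.0),
        'gan': LMLoss(name='gan', loss_weight=1.0),
    }
```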
## Convenience Functions

### `LossManager.print_weights()`
Prints (or optionally returns) loss weight information.
Note: This function uses the last weight value stored in:
- `LossManager.schedule.current_value`
Since this value is updated immediately after the associated loss has
been run, the weight values that this function prints (or returns) are
the ones which will be applied the next time the loss is run.
By default, this function will print a string of all registered losses and their most recent weights, formatted as:

```
"Loss weights: {L_1_N}: {L_1_W} {L_2_N}: {L_2_W} ..."
```

Key:
- `L_X_N`: loss X name
- `L_X_W`: loss X weight