12 Robustness
This book is work in progress. We are happy to receive your feedback via science-book@christophmolnar.com
Machine learning systems should not only work under laboratory conditions – they should work in the wild! And yes, we mean that in the true sense of the word.
Imagine you are an animal ecologist studying the diversity and conservation of species in the Serengeti. You know that machine learning systems allow you to identify, count, and even describe animals from images alone, as illustrated in Figure 12.1. There are tons of motion sensor cameras throughout the Serengeti. Together with the predictions by the machine learning model, you’ll soon have an amazing dataset to tackle your research questions. Norouzzadeh et al. (2018) have indeed done an impressive job with their machine learning model: Their ensemble of trained convolutional neural networks (CNN) achieves 94.9% accuracy on the Snapshot Serengeti dataset (Swanson et al. 2015) – that is a comparable performance to human labelers. Sounds like a reliable tool to build your research on, right?
To human eyes, Figure 12.2 is the same image as Figure 12.1, but for the model, it is not. The model gets it all wrong now: the species, the count, and even the description. What has happened? The image is a so-called adversarial example. All current machine learning models, especially image classifiers, are susceptible to such well-engineered variations in the input that are imperceptible to the human eye but lead the prediction model astray.
But Norouzzadeh et al. (2018) didn’t do anything wrong. Quite the contrary! To us, the paper gives a role model for how we would like to see machine learning used in science:
- They explain in detail how they deal with class imbalances in the data and target leaks (see Chapter 8), which can occur because the camera sensors take three photos in succession.
- They take into account label noise and conduct confidence thresholding (see Chapter 13).
- They provide detailed reports on both overall performance and class performance (see Chapter 15) and usable open-source code (see Chapter 14).
But what about the fact that machine learning models are not robust to adversarial examples? Adversarial examples kicked off the robustness debate in machine learning (Szegedy et al. 2013), and we will explore their occurrence in some detail at the end of this chapter (Section 12.7). But not just any kind of noise fools the animal classifier. On real images taken in the wild, the model performs very well. And in scientific applications, there is usually no adversary who deceives your model on purpose, except perhaps Reviewer 2. We believe that scientists should not lose sleep over deceptive artificial inputs that never occur in practice. As you will see in this chapter, there are much more profound robustness issues in machine learning that should worry you!
The raven Goodman and his family live in the small town of Hatebird. It is a nice little city, with a gorgeous gothic church, the famous Hatebird Park, and a rose garden – but it’s also a crazy dangerous place. Parental supervision and hard personal experience taught Goodman to fear all humans, be they big ones with shooting metal sticks or small ones with stones. Goodman has lost too many friends and family members to these butcherly monsters. He therefore learnt a simple prediction rule: every human is a life threat. A correct prediction for the entire Hatebird population – 100% accuracy. After a group of Goodman’s archenemies burned down his home (luckily he and his loved ones were on a trip), Goodman and his family decide enough is enough. They are going to leave town. After traveling for days, they arrive in Raven’s Heaven, a little town in the south. It also has many beautiful parks and fountains, and the trees are beyond comprehension. However, there are humans everywhere; they have built fancy little houses for ravens (traps, no doubt) and even tried the oldest trick in the book and fed the birds (poisoned, no doubt). How can the other ravens be so stupid as to fall for these cheap tricks? Goodman and his family decide to hide in a dead tree outside of the city, finally a place without humans. They have a hard time, with bad food, strong winds, and little water, but at least it is safe.
12.1 What does robustness mean?
In everyday language, robustness is a property of a single entity. A washing machine can be robust if it works for many years without trouble. A person can be robust if she can handle many situations competently. However, this leaves open what robustness is about, beyond functioning well in general.
To detect robustness problems and solve them systematically, we must operationalize what we mean by robustness. We always need to specify the following ingredients (Freiesleben and Grote 2023):
- Robustness target is the thing that should be robust. For example, you might be interested in how robust the performance of the animal classifier is.
- Robustness modifier is the thing with respect to which the robustness target should be robust. For example, this could be the images on which you apply your animal classifier.
- Modifier domain specifies the relevant changes in the modifier to which the target should be robust. For example, this could be changes in the background lighting of the images.
- Target tolerance specifies how much the target is allowed to change if the modifier changes within the modifier domain. For example, you might be fine if the model performance decreases a bit for images with a darker background as long as the performance does not drop drastically.
This allows us to clearly define what we mean by robustness.
Definition: We call a robustness target robust with respect to a modifier if relevant interventions to the modifier within the domain do not lead to greater changes than specified by the target tolerance.
We can now talk in a more nuanced way about robustness. For example, the wildlife image classifier performance (robustness target) is relatively robust to changes in the lighting conditions when taking the images (relevant interventions) but less robust to targeted modifications (irrelevant interventions) on the input pixels. In this example, the modifier is the data distribution. Ultimately, to judge whether your model is suitable for a specific application, you have to check if it is robust to relevant interventions on modifiers. This forces you to think about the changes you expect to occur in deployment!
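The four ingredients of the definition can be made concrete in a few lines of code. The following is a minimal sketch under assumptions that are not from the wildlife study:
- A hypothetical fixed classifier working on a single brightness-like feature stands in for the animal classifier.
- Its accuracy is the robustness target.
- The dataset is the modifier.
- The modifier domain is a set of hypothetical lighting interventions that lower the feature by a constant.

```python
import numpy as np

def is_robust(target, modifier, interventions, tolerance):
    """Return (robust, changes): the target is robust if no intervention in the
    modifier domain changes it by more than the target tolerance."""
    baseline = target(modifier)
    changes = [abs(target(intervene(modifier)) - baseline) for intervene in interventions]
    return max(changes) <= tolerance, changes

# Hypothetical toy data: one feature, classes centred at 0 and 1.
rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=1000)
x = y + rng.normal(0.0, 0.3, size=1000)

def accuracy(data):
    """Robustness target: accuracy of a fixed toy classifier (threshold at 0.5)."""
    features, labels = data
    return np.mean((features > 0.5) == labels)

def darken(delta):
    """Hypothetical lighting intervention: lower the feature by a constant."""
    return lambda data: (data[0] - delta, data[1])

# Modifier domain A: mild darkening. Modifier domain B: strong darkening.
robust_mild, changes_mild = is_robust(accuracy, (x, y), [darken(d) for d in (0.05, 0.1)], tolerance=0.05)
robust_strong, changes_strong = is_robust(accuracy, (x, y), [darken(0.5)], tolerance=0.05)
print("robust to mild darkening:", robust_mild, "| robust to strong darkening:", robust_strong)
```

The same model counts as robust or not robust depending on the modifier domain and target tolerance you specify. The code does not decide which interventions are relevant; that judgement is yours.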
12.2 The two functions of robustness: auditing and robustifying strategies
Broadly speaking, there are usually two functions that robustness research deals with:
- Auditing: Is your model performance robust to relevant modifier changes? For example, you can check how the animal classifier performs when you darken the background of the test data.
- Robustifying strategies: What can you change in data collection and model selection to make your model more robust to relevant modifier changes? For example, if you train your model with more images taken at night, your model will generally perform better in this environment.
The two functions interact. You may audit the model and detect potential robustness weaknesses. To mitigate these weaknesses, you apply robustifying strategies. But what should you audit for?
12.3 Understand sources and types of data distribution shifts
In practice, the most crucial modifier is the data. Your training data may stem from one source, but if you now deploy your model, the data may look entirely different. The research literature distinguishes between different sources of data distribution shifts:
- Natural shifts occur because the natural conditions change. If you think of the animal classifier, you have to deal with:
  - varying weather or lighting conditions,
  - changes in the flora and fauna over time, or
  - new camera sensors placed at novel locations.
- Performative shifts are induced by the model itself and its effects on the data. Imagine a model that predicts which diseases a person will develop based on their behaviour. Provided with these predictions, the person may change her behaviour and thus invalidate the prediction.
- Adversarial shifts occur due to attackers who modify the data. An example is given in Figure 12.2.
There is a separate research literature on each of these sources (Hendrycks et al. 2021; Freiesleben and Grote 2023). For natural scientists, the most important distribution shifts are natural ones, whereas social scientists must care just as much about performative shifts. In our opinion, adversarial shifts are more of a problem in business or industry applications than in science.
More technically, we can also distinguish different types of data distribution shifts. Remember from the earlier chapters (Chapter 3, Chapter 8, and Chapter 9) that data in supervised machine learning is described as pairs \((x^{(i)}, y^{(i)})_{i=1, \ldots, n}\) sampled from an underlying distribution \(\mathbb{P}(X,Y)\), where \(X:=(X_1,\dots,X_p)\) describes the input features and \(Y\) the target variable. We can express the different types of data distribution shifts directly in terms of the way \(X\) and \(Y\) are distributed and their relationship to each other:
- Covariate shift describes a case where the distribution \(\mathbb{P}(X)\) has changed. In the example above, a covariate shift can occur if a camera was previously in the open sun but is now in the shade of a plant.
- Label shift describes a case where the distribution \(\mathbb{P}(Y)\) has changed. The installation of a new camera in the previously ignored riverine forests in the Serengeti, for example, will lead to many more hippos being observed.
- Concept shift describes a case where the conditional distribution \(\mathbb{P}(Y\mid X)\) has changed. In the wildlife example, this can happen when species are categorised anew: previously there was only the gazelle category, but afterwards it is subdivided into Grant’s gazelle, Thomson’s gazelle, and so on.
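The three types of shift differ in which part of the joint distribution changes. The following sketch simulates each of them with a hypothetical one-feature problem. All distributions and parameter values are illustrative assumptions, not properties of the wildlife data.

```python
import numpy as np

rng = np.random.default_rng(1)

def covariate_shift_data(n, x_mean):
    """P(X) depends on x_mean; P(Y | X) stays fixed (y = 1 if x + noise > 0)."""
    x = rng.normal(x_mean, 1.0, n)
    y = (x + rng.normal(0.0, 0.5, n) > 0).astype(int)
    return x, y

def label_shift_data(n, p_y):
    """P(Y) depends on p_y; P(X | Y) stays fixed (class-conditional Gaussians at -1 and +1)."""
    y = (rng.random(n) < p_y).astype(int)
    x = rng.normal(2.0 * y - 1.0, 1.0)
    return x, y

def concept_shift_data(n, threshold):
    """P(X) stays fixed; P(Y | X) changes with the labeling threshold."""
    x = rng.normal(0.0, 1.0, n)
    y = (x > threshold).astype(int)
    return x, y

def in_band(x):
    """Select inputs with 0 < x < 1 to compare labels for the same region of input space."""
    return (x > 0) & (x < 1)

x_before, _ = covariate_shift_data(5000, x_mean=0.0)
x_after, _ = covariate_shift_data(5000, x_mean=2.0)
_, y_before = label_shift_data(5000, p_y=0.5)
_, y_after = label_shift_data(5000, p_y=0.9)
xc_before, yc_before = concept_shift_data(5000, threshold=0.0)
xc_after, yc_after = concept_shift_data(5000, threshold=1.0)

print("covariate shift, mean of X:", x_before.mean(), "->", x_after.mean())
print("label shift, share of Y=1:", y_before.mean(), "->", y_after.mean())
# Same inputs, different labels: P(Y | X) has changed.
print("concept shift, P(Y=1 | 0<x<1):",
      yc_before[in_band(xc_before)].mean(), "->", yc_after[in_band(xc_after)].mean())
```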
Again, understanding what type of distribution shift you face can be vital for robustifying your model. But let’s first look into robustness audits!
12.4 Strategies to audit for robustness
Has the distribution shift already occurred or are you only anticipating it?
- Post-hoc audit: The distribution shift has occurred already. You have collected new real data after the shift and can evaluate the quality of your model on this data. In the wildlife example, you may have only trained your model with data from the dry season, and now, in the middle of the wet season, you want to evaluate the model’s performance.
- Anticipatory audit: The distribution shift has not yet occurred, so you have no real data about the expected changes. For example, you are in the dry season and want to predict how your model will perform in the wet season.
12.4.1 Post-hoc audit
A post-hoc audit is comparatively simple: just analyze your model’s performance on the new data. In the wildlife example, you might want to compare the model performance between dry-season (pre-shift) and wet-season (post-shift) data.1
Understand whether and how the data differs
But how can distribution shifts be recognised at all? You need to monitor your data continuously and check its properties. For tabular data, there is a variety of summary statistics like means, standard deviations, ranges, or correlation coefficients. The situation is similar for text, where you have word frequencies, word lengths, the similarity of word embeddings (e.g. their cosine similarity), or text sentiment. When these statistics start to vary, your distribution shift alarm bells should ring. For images, summary statistics are less effective. Instead, eyeballing data samples over time, informed by domain knowledge, can be more effective. For example, you can observe differences in brightness or colour between data from the wet and dry seasons.
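Monitoring summary statistics can be automated. Below is a minimal sketch for tabular data. It assumes that a shift in a feature’s mean by more than half a reference standard deviation should ring the alarm bells. That threshold, the feature layout, and the simulated wet-season data are all illustrative assumptions.

```python
import numpy as np

def summarize(X):
    """Per-feature summary statistics of a tabular dataset (rows = instances)."""
    return {"mean": X.mean(axis=0), "std": X.std(axis=0),
            "min": X.min(axis=0), "max": X.max(axis=0)}

def drift_alarm(X_ref, X_new, threshold=0.5):
    """Flag features whose mean moved by more than `threshold` reference standard deviations."""
    ref, new = summarize(X_ref), summarize(X_new)
    standardized_shift = np.abs(new["mean"] - ref["mean"]) / ref["std"]
    return np.flatnonzero(standardized_shift > threshold), standardized_shift

rng = np.random.default_rng(2)
X_dry = rng.normal(size=(1000, 3))   # reference data (pre-shift)
X_wet = rng.normal(size=(1000, 3))   # new data ...
X_wet[:, 0] += 1.0                   # ... in which only feature 0 has drifted

flagged, shifts = drift_alarm(X_dry, X_wet)
print("drifting features:", flagged, "standardized shifts:", shifts.round(2))
```

In practice, you would track such statistics over time, for example per week of camera-trap data. You would also look at more than the mean, since the standard deviation, range, or correlations can shift while the mean stays put.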
There are also modality-independent automated strategies for detecting data distribution shifts – called out-of-distribution (OOD) detection (Yang et al. 2024). We discuss these strategies in more detail below in Section 12.5.2, when we talk about robustifying strategies.
Analyze the different errors
Just as important as understanding the distribution shift itself is understanding how it affects model performance. Comparing overall performance before and after the shift is only the tip of the iceberg of a proper audit: false positive and false negative rates may also change. There are often contextual preferences for bounding either false positives or false negatives; for example, false negatives can have very severe consequences in medical diagnosis. Similarly, grouping and comparing errors across classes can be informative. For example, the original model may have distinguished different gazelle species in the dry season based on their preferred habitat, but this strategy fails in the wet season.
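A per-class error analysis needs little more than a confusion matrix. The sketch below computes the following:
- overall accuracy;
- one-vs-rest false negative rates per class;
- one-vs-rest false positive rates per class.

The class labels and the dry- and wet-season predictions are hypothetical toy values chosen to make the arithmetic easy to follow.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes):
    """Rows: true class, columns: predicted class."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

def error_report(y_true, y_pred, n_classes):
    """Overall accuracy plus one-vs-rest false negative and false positive rates per class."""
    cm = confusion_matrix(y_true, y_pred, n_classes)
    tp = np.diag(cm)
    fn = cm.sum(axis=1) - tp               # instances of the class that were missed
    fp = cm.sum(axis=0) - tp               # instances wrongly assigned to the class
    negatives = cm.sum() - cm.sum(axis=1)  # instances not belonging to the class
    return {"accuracy": tp.sum() / cm.sum(),
            "fnr": fn / cm.sum(axis=1),
            "fpr": fp / negatives}

# Hypothetical classes: 0 = Grant's gazelle, 1 = Thomson's gazelle, 2 = other.
y_true = [0, 0, 1, 1, 2, 2]
dry = error_report(y_true, [0, 0, 1, 1, 2, 2], n_classes=3)
wet = error_report(y_true, [1, 0, 0, 1, 2, 2], n_classes=3)
for name in ("accuracy", "fnr", "fpr"):
    print(name, dry[name], "->", wet[name])
```

Here the overall accuracy drop hides where the errors happen. In the wet season, the two gazelle classes are confused with each other while the third class is unaffected.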
Interpretation methods can point to the sources of errors
Interpretation methods (see Chapter 10) give you insight into which features the model relies on and how these features impact model performance. Feature attribution techniques, for example, allow you to analyze which image regions a classifier attends to for specific predictions. You should be alarmed if the model classifies Grant’s gazelles by looking at the background rather than the animals (Ribeiro, Singh, and Guestrin 2016). For tabular data, you may compare global feature importances computed on data from before and after the distribution shift.
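One way to compare global feature importances before and after a shift is permutation feature importance: the drop in performance when a single feature’s values are shuffled. Below is a minimal NumPy sketch under hypothetical assumptions:
- A fixed toy model uses two features.
- After the simulated shift, feature 1 (think of a background variable) barely varies any more.

```python
import numpy as np

def permutation_importance(predict, X, y, n_repeats=5, seed=0):
    """Mean drop in accuracy when each feature column is randomly permuted."""
    rng = np.random.default_rng(seed)
    base = np.mean(predict(X) == y)
    importances = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        drops = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            X_perm[:, j] = rng.permutation(X_perm[:, j])  # break the link between feature j and y
            drops.append(base - np.mean(predict(X_perm) == y))
        importances[j] = np.mean(drops)
    return importances

def predict(X):
    """Hypothetical fixed model relying on both features."""
    return (X[:, 0] + X[:, 1] > 0).astype(int)

rng = np.random.default_rng(3)
X_before = rng.normal(size=(2000, 2))
X_after = X_before.copy()
X_after[:, 1] *= 0.001                   # feature 1 barely varies after the shift
y_before, y_after = predict(X_before), predict(X_after)   # labels follow the same rule

imp_before = permutation_importance(predict, X_before, y_before)
imp_after = permutation_importance(predict, X_after, y_after)
print("before:", imp_before.round(3), "after:", imp_after.round(3))
```

A large change in importance between the two datasets is a hint that the model’s behaviour on post-shift data rests on different features than it did during development.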
12.4.2 Anticipatory audit
An anticipatory robustness audit is more demanding than a post-hoc audit. The shift has not occurred yet – so you have no data after the shift. You need to specify the shift qualitatively based on the sources and types of shifts above. To also quantify the effect of an anticipated shift on model performance, you additionally have to generate synthetic data that reflects the shift.
Specify the shift qualitatively
What shift do you expect, a natural or a performative shift? Is it a shift in covariates, in labels, or even in concepts? Misspecifying the shift renders your audit a farce. For example, mistaking a concept shift for a covariate shift leads to entirely different auditing results and requires different robustifying strategies. Similarly, auditing your model for adversarial robustness in a non-adversarial setting is simply pointless. Consult your domain knowledge! What aspects of your domain do you expect to vary, and to which of the above types of shifts do they belong? Think of the seasonal shift in the wildlife example: a wildlife researcher knows the seasonal effects on the Serengeti environment and its wild inhabitants.
Generate (semi-)synthetic data and systematically test your model on it
For a quantitative audit, you need data that reflects the distribution shift. As you anticipate the shift rather than observe it, you have to generate such data synthetically by translating your qualitative knowledge about the shift into a data-generating procedure. There are various strategies to do so, collectively called data augmentation, which we discuss in depth below (see Section 12.5.3). The simplest approach is to apply transformations to your training data (Hendrycks et al. 2021), e.g. image filters that turn day data into night data (see Figure 12.4).
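As an illustration of such an audit, here is a sketch of a crude day-to-night filter and a comparison of accuracy on original versus filtered images. Several pieces are hypothetical assumptions:
- the filter parameters;
- the random toy images;
- the toy predictor, which stands in for a trained classifier and deliberately relies on overall brightness.

Images are assumed to be float RGB arrays in [0, 1] with shape (N, H, W, 3).

```python
import numpy as np

def day_to_night(images, brightness=0.3, tint=(0.9, 0.9, 1.1), noise_std=0.03, seed=0):
    """Crude day-to-night filter for float RGB images in [0, 1] with shape (N, H, W, 3)."""
    rng = np.random.default_rng(seed)
    gray = images.mean(axis=-1, keepdims=True)               # simple luminance proxy
    night = 0.5 * images + 0.5 * gray                        # partly desaturate
    night = brightness * night * np.asarray(tint)            # darken and tint towards blue
    night = night + rng.normal(0.0, noise_std, night.shape)  # sensor noise
    return np.clip(night, 0.0, 1.0)

def audit(predict, images, labels, transform):
    """Accuracy on the original images vs. on the synthetically shifted images."""
    return np.mean(predict(images) == labels), np.mean(predict(transform(images)) == labels)

def predict(images):
    """Hypothetical toy classifier that (unwisely) relies on overall brightness."""
    return (images.mean(axis=(1, 2, 3)) > 0.5).astype(int)

rng = np.random.default_rng(4)
labels = np.repeat([0, 1], 10)
images = np.concatenate([rng.uniform(0.0, 0.6, (10, 8, 8, 3)),    # class 0: darker scenes
                         rng.uniform(0.4, 1.0, (10, 8, 8, 3))])   # class 1: brighter scenes
acc_day, acc_night = audit(predict, images, labels, day_to_night)
print(f"day accuracy: {acc_day:.2f}, synthetic night accuracy: {acc_night:.2f}")
```

Whether such a filter is a faithful model of real night-time camera-trap images is exactly the kind of assumption the audit results depend on.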
Suppose you have created a synthetic dataset. Does this mean you have finally reached the point where a post-hoc audit starts? Yes and no; two big differences remain. First, you already have a profound understanding of the data because you generated it. Second, the insights your synthetic data offer depend on your assumptions about the shift.
Interpretation methods allow you to anticipate shifts and detect model weaknesses
Interpretation methods allow for an audit without specifying the expected distribution shift or generating synthetic data. Instead, you analyze your pre-shift data in a more exploratory way. For example, if feature attribution methods indicate that the model relies on background features like trees to classify an animal, the model may perform worse under varying background conditions. This alerts you both to a potentially dangerous distribution shift (changes in the background) and to the reason why your model fails (reliance on the background). Similarly, if a spurious feature has a large feature importance, your model may fail in settings where this feature varies: a COVID risk predictor trained in Munich that relies on street names is likely to fail in Paris. Finally, feature effect methods such as individual conditional expectation (ICE) curves and counterfactual explanations describe model behaviour in relevant counterfactual scenarios.
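ICE curves are simple to compute. For each instance, you vary one feature over a grid while keeping all other features fixed, and you record the predictions. The sketch below uses a hypothetical model with an interaction term. The grid and data are illustrative assumptions.

```python
import numpy as np

def ice_curves(predict, X, feature, grid):
    """ICE: one curve per instance, varying `feature` over `grid`, other features fixed."""
    curves = np.empty((X.shape[0], len(grid)))
    for k, value in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, feature] = value          # counterfactual: "what if this feature had this value?"
        curves[:, k] = predict(X_mod)
    return curves

def predict(X):
    """Hypothetical model: the effect of feature 0 depends on feature 1."""
    return 2.0 * X[:, 0] + X[:, 0] * X[:, 1]

rng = np.random.default_rng(5)
X = rng.normal(size=(50, 2))
grid = np.linspace(-1.0, 1.0, 5)
curves = ice_curves(predict, X, feature=0, grid=grid)
pdp = curves.mean(axis=0)                            # partial dependence = average ICE curve
slopes = np.diff(curves, axis=1) / np.diff(grid)     # per-instance slopes, here 2 + x1
print("ICE slopes range from", slopes[:, 0].min().round(2), "to", slopes[:, 0].max().round(2))
```

The spread of ICE slopes reveals that the effect of feature 0 depends on feature 1. If feature 1 is expected to shift in deployment, the model’s behaviour with respect to feature 0 will shift with it.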
What you need to audit your model for depends heavily on the application context. The wildlife model, which acts as a labeling tool, requires less extensive auditing than machine-learning-based medical diagnosis. The required audits depend on the risks of errors, the instability of the environment, and other domain characteristics. For high-risk applications like those in medicine, there are often additional legal requirements, such as those set out in European AI legislation.
12.5 Strategies to robustify your model
Say your audit has shown that your model is not robust to relevant distribution shifts. How can you robustify it? Below is a list of common robustifying strategies, ordered by their position in the machine learning pipeline. Each of them is discussed in more detail below:
- Control the source: Some shifts are under the control of the model authority. For example, the purchase of a new brand of camera sensor can lead to greater distribution shifts than sticking with the old brand.
- Filter out-of-distribution data: You could train a second model to filter out data that significantly differs from the training data. For example, the machine learning model should not provide predictions for animals that are not included in the training set, leaving these cases for human labelers.
- Gather data representative of the shift: You may gather additional (real or synthetic) data that accounts for the shift. For instance, you could use image filters to generate synthetic wet-season data from dry-season data.
- Carefully select and construct features: The lack of robustness in your model may be due to its reliance on incorrect features or improper feature encoding. For instance, by removing the image background, shifts in the background will no longer affect the model’s performance.
- Choose inductive biases wisely: Some shifts can be accounted for by adjusting your modeling assumptions (see Chapter 9). For counting animals, you may choose an architecture that generalizes to high animal counts, like sequential subitizing (Chattopadhyay et al. 2017).
- Transfer learning to obtain robust representations: Often, your model is sensitive to distribution shifts because its learned representations overfit the training data. In such cases, transfer learning – reusing representations learned by other models trained on the same data modality – can help. Norouzzadeh et al. (2018) demonstrate that transferring representations from common image classifiers increases the data efficiency of the wildlife classifier.
Your choice of robustifying strategy should be informed by your audit. If you caused the distribution shift yourself, you might be able to control it. When your model performs poorly only on rare anomalies, filtering out these cases might be the best approach. If the distribution shift is unavoidable and causes a significant performance drop, the most common strategy is to augment your data and retrain the model. If the distribution shift is unavoidable but data augmentation is difficult, you may need to engineer features, adjust the modeling assumptions, or use transfer learning.
12.5.1 Control the source
Distribution shifts rarely just happen without anything you can do about them. Often, you play an active role: you caused the distribution shift through an action. You installed a camera sensor from a different brand and suddenly the performance drops? You put food next to the camera to spot more animals and suddenly the animal counts exceed what your model can handle? In some cases, rather than adapting the data or the model, it can be smarter to tackle the very source of the shift. You can exert control to guarantee a stable environment, e.g. by using only cameras from the same brand.
But not all distribution shifts are within your control. You cannot keep the Serengeti in the wet season all year. Other sources can be controlled but not without unwanted side effects: like installing constant lighting conditions around your camera, which may attract certain animals and repel others.
12.5.2 Filter out-of-distribution data
There will always be data that presents a challenge for your model, especially data that is very different from the training data. Out-of-distribution (OOD) detectors enable you to filter out data on which your model would perform substantially worse. These data can then be handled separately, e.g. by a human labeler. By handling difficult data separately, OOD detection robustifies overall performance in deployment. However, OOD detectors only act as filters; they do not solve the initial prediction problem. For substantial data distribution shifts, where much of the incoming data is OOD, OOD detectors therefore won’t help.
How to find out if a data point is OOD? With OOD detectors! Sometimes people further differentiate between methods designed to detect anomalies (rare data from different distributions), novelties (data from a shifting distribution), or outliers (rare data within training distribution) (Ruff et al. 2021; Yang et al. 2024). Here, we focus on OOD detection more generally. There are four different approaches (Yang et al. 2024):
- Classification-based methods phrase OOD detection as a classification problem. Some classification-based methods require data labeled as within and outside of the distribution. Others leverage uncertainty-aware classifiers and label data with high classification uncertainty as OOD.
- Density-based methods model explicitly the probability density of the training data. Data whose density lies below a certain threshold is labeled as OOD.
- Distance-based methods calculate the distance between a given data point and the centroids or prototypes of a dataset. Data whose distance exceeds a certain threshold is labeled as OOD.
- Reconstruction-based methods use autoencoder methods to detect OOD data. Data with higher reconstruction error are labeled as OOD.
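To make the distance-based idea concrete, here is a minimal sketch that uses the Mahalanobis distance to the training centroid. This is one simple choice among many distance-based detectors. The 99% quantile used to set the threshold and the toy data are hypothetical assumptions.

```python
import numpy as np

class MahalanobisOODDetector:
    """Distance-based OOD detection: Mahalanobis distance to the training centroid."""

    def __init__(self, quantile=0.99):
        self.quantile = quantile   # share of training data treated as in-distribution

    def fit(self, X):
        self.mean_ = X.mean(axis=0)
        self.precision_ = np.linalg.pinv(np.cov(X, rowvar=False))
        self.tau_ = np.quantile(self.score(X), self.quantile)  # threshold from training distances
        return self

    def score(self, X):
        """Mahalanobis distance of each row of X to the training centroid."""
        diff = X - self.mean_
        return np.sqrt(np.einsum("ij,jk,ik->i", diff, self.precision_, diff))

    def predict(self, X):
        """True = OOD."""
        return self.score(X) > self.tau_

rng = np.random.default_rng(6)
X_train = rng.normal(size=(1000, 2)) @ np.array([[2.0, 0.0], [1.0, 0.5]])  # correlated toy data
detector = MahalanobisOODDetector().fit(X_train)
print(detector.predict(np.array([[0.0, 0.0], [10.0, -10.0]])))
```

Unlike the plain Euclidean distance, the Mahalanobis distance accounts for the correlations in the training data. A point can be far from the centroid along a direction in which the training data barely varies, and still be flagged.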
In our opinion, reconstruction-based techniques have advantages over the other approaches: unlike classification-based methods, you do not need to label data as (non-)OOD; unlike density-based methods, you do not need to specify a model of the probability density; and unlike distance-based methods, you do not need to construct complex reference points such as centroids or prototypes. For these reasons, and also to give you an intuition for OOD detection in general, let us take a deeper look at reconstruction-based OOD detectors. For an overview of all the different types of OOD detectors, check out the review papers by Ruff et al. (2021), Yang et al. (2024), and Chalapathy and Chawla (2019).
Autoencoders are a neural-network-based method to project high-dimensional inputs into a low-dimensional feature space (often called latent space) with a minimal loss of information. This is achieved through a two-stage architecture:
- Encoder: Projects the high-dimensional input into a prespecified low-dimensional latent space.
- Decoder: Maps the projected inputs from the latent space back to the original space.
The encoder and decoder mappings are optimized to minimize the reconstruction error on the training data. The reconstruction error of a given input describes the difference (according to some metric) between this input and the output obtained by first encoding and then decoding it. In our example, this would mean calculating the mean squared error between an impala image and its reconstruction after passing through the encoder and decoder networks.
Out-of-distribution detection with autoencoders
Say \(x\) is the data point you want to classify as within or outside of the training distribution. The encoder network can be described as a mapping \(E:\mathbb{R}^h\rightarrow \mathbb{R}^l\) from a high-dimensional feature space \(\mathbb{R}^h\) to a low-dimensional latent space \(\mathbb{R}^l\) (with \(l<h\)). Similarly, the decoder network is a mapping \(D:\mathbb{R}^l\rightarrow \mathbb{R}^h\), and \(L\) is a distance function on the input space \(\mathbb{R}^h\) (e.g. mean squared error). Then, the reconstruction error of \(x\) is defined as: \[\text{RE}(x):=L\big(x,D(E(x))\big).\] This allows us to define a simple OOD detector: \[\text{OOD}(x):=\begin{cases} \text{OOD} & \text{if } \text{RE}(x)>\tau, \\ \text{not OOD} & \text{otherwise.} \end{cases}\] Intuitively, any data point with a reconstruction error above a threshold \(\tau\) counts as OOD. This primitive approach faces clear limitations:
- OOD data can also have a low reconstruction error. Extending the reconstruction error with the Mahalanobis distance in the latent space can tackle this problem (Denouden et al. 2018).
- Training autoencoders is hard: you have to make architectural choices based on domain knowledge (e.g. CNNs for images), define an appropriate latent space (i.e. just the right size to capture the training distribution without information loss), and choose an appropriate loss function (Ruff et al. 2021).
- Training autoencoders requires representative training data (Chalapathy and Chawla 2019), but in practice the data is often unbalanced.
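The reconstruction-based detector \(\text{OOD}(x)\) can be sketched in a few lines. To keep the example short and free of a deep learning framework, it makes a simplifying assumption: it uses a linear autoencoder. For squared error, the optimal linear encoder and decoder span the subspace of the top principal components, so the sketch fits them in closed form with an SVD instead of gradient descent. A real image application would use a nonlinear (e.g. convolutional) autoencoder instead. The toy data, the latent size \(l=2\), and the 99% quantile used to choose \(\tau\) are hypothetical.

```python
import numpy as np

class LinearAutoencoderOOD:
    """Reconstruction-based OOD detector with a linear autoencoder fitted via SVD (PCA)."""

    def __init__(self, latent_dim, quantile=0.99):
        self.latent_dim = latent_dim    # l, the size of the latent space
        self.quantile = quantile        # used to set the threshold tau

    def fit(self, X):
        self.mu_ = X.mean(axis=0)
        _, _, vt = np.linalg.svd(X - self.mu_, full_matrices=False)
        self.W_ = vt[: self.latent_dim].T                    # shape (h, l)
        self.tau_ = np.quantile(self.reconstruction_error(X), self.quantile)
        return self

    def encode(self, X):                # E: R^h -> R^l
        return (X - self.mu_) @ self.W_

    def decode(self, Z):                # D: R^l -> R^h
        return Z @ self.W_.T + self.mu_

    def reconstruction_error(self, X):  # RE(x) = L(x, D(E(x))) with L = mean squared error
        return np.mean((X - self.decode(self.encode(X))) ** 2, axis=1)

    def predict(self, X):               # True = OOD, i.e. RE(x) > tau
        return self.reconstruction_error(X) > self.tau_

# Hypothetical data: 10-dimensional inputs lying close to a 2-dimensional subspace.
rng = np.random.default_rng(7)
A = rng.normal(size=(2, 10))

def sample_in_distribution(n):
    return rng.normal(size=(n, 2)) @ A + 0.05 * rng.normal(size=(n, 10))

ood = LinearAutoencoderOOD(latent_dim=2).fit(sample_in_distribution(1000))
X_new = sample_in_distribution(200)     # same distribution as the training data
X_out = rng.normal(size=(20, 10))       # does not lie near the subspace
print("flagged in-distribution:", ood.predict(X_new).mean(), "flagged OOD:", ood.predict(X_out).mean())
```

The choice of \(\tau\) trades off false alarms against missed OOD data. Setting it at a high quantile of the training reconstruction errors is a common heuristic, not a rule.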
12.5.3 Gather data representative of the shift
Why are distribution shifts a problem for machine learning models? In the case of a covariate shift (i.e. shift of \(\mathbb{P}(X)\)) or a label shift (i.e. shift of \(\mathbb{P}(Y)\)), your model fails because it has never seen this kind of data and is unable to extrapolate. In the case of a concept shift (i.e. shift in \(\mathbb{P}(Y\mid X)\)), your model may have faced similar inputs but the learned dependencies became unreliable.
In any case, the most prominent solution in the research literature for how to robustify your model to such shifts is the same – gather more data that reflects the distribution shift! There are two strategies to obtain such data:
- Gather real labeled data with active learning: You may want real labeled data. Active learning concerns the systematic search for data worth labeling.
- Augment your data: Gathering real data comes with high costs and you have limited control over which kind of data you get. Data augmentation is concerned with generating synthetic instances using relevant domain knowledge.
Active learning
Which data should you label to best robustify your model? Labeling all data is often too expensive and time-consuming. The literature differentiates three active learning setups, based on how the labeler receives the data (Settles 2009):
- In membership query synthesis any input in the input space is a potential candidate for labeling, even completely unrealistic inputs.
- In stream-based sampling you receive data sequentially and have to decide whether to label this data.
- In pool-based sampling you receive a big data sample and have to choose a subsample that you want to label.
Active learning is largely concerned with automating the selection of data for labeling (Settles 2009; Ren et al. 2021). The selection can be based on high prediction uncertainty, proximity to the decision boundary, randomness, expected model error, expected training effect, or data representativeness. In the wildlife example, Norouzzadeh et al. (2018) suggest labeling the images with the highest prediction uncertainty. Implementations of many of these active learning strategies are readily available in Python packages like modAL (Danka and Horvath 2018) and ALiPy (Tang, Li, and Huang 2019).
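Pool-based uncertainty sampling fits in a few lines. The sketch below measures uncertainty by the entropy of the predicted class probabilities. Entropy is one common measure; treat it as an assumption here rather than the exact criterion used by Norouzzadeh et al. (2018). The probabilities are hypothetical.

```python
import numpy as np

def predictive_entropy(probs):
    """Entropy of each row of predicted class probabilities (higher = more uncertain)."""
    p = np.clip(probs, 1e-12, 1.0)     # avoid log(0)
    return -(p * np.log(p)).sum(axis=1)

def select_for_labeling(probs, k):
    """Pool-based uncertainty sampling: indices of the k most uncertain instances."""
    return np.argsort(-predictive_entropy(probs))[:k]

# Hypothetical predicted probabilities for four unlabeled camera-trap images (three classes).
probs = np.array([[0.98, 0.01, 0.01],
                  [0.34, 0.33, 0.33],
                  [0.60, 0.30, 0.10],
                  [0.50, 0.50, 0.00]])
print(select_for_labeling(probs, k=2))   # -> [1 2]
```

The confident first image is left unlabeled, while the near-uniform second image is the top candidate for a human labeler.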
Humans also often have intuitions about what data to track and label. When people incorporate their domain knowledge (Ciravegna et al. 2023) and dynamically interact with the model using interpretability techniques (Ghai et al. 2021), this can significantly improve active learning strategies.
Data augmentation
Data augmentation is about creating synthetic data. Wait, don’t you also need active learning to label these data? In some cases yes, namely if you want to find the label for an arbitrary input. However, when we talk about data augmentation, we usually create data for which we know the label. Data augmentation has been a particular focus in computer vision (Mumuni and Mumuni 2022; Shorten and Khoshgoftaar 2019), but there is a growing literature on data augmentation in natural language processing (Feng et al. 2021).
We can generally differentiate the following two types of data augmentation (Mumuni and Mumuni 2022):
- Data transformation: Transformations applied to labeled data that are known not to change the label. Selecting the right transformations is an excellent way to incorporate domain knowledge. Focus on transformations you expect to occur in practice.
- Data synthesis: Creation of entirely new data with known labels. Data synthesis may be based on generative models or Computer Aided Design (CAD) models. The aim is to generate synthetic data that shares important properties with your training data but varies from it in relevant aspects.
In computer vision, a variety of standard transformations have been considered. Geometric transformations change classical geometric properties like the angle, position, direction, or size of an image. Photometric transformations, on the other hand, focus on attributes like coloring, saturation, contrast, or camera artifacts. Transformations can concern the entire image or only a region. They may delete, replace, swap, or recombine regions of the image.
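A few of these transformations can be sketched directly in NumPy:
- horizontal flips and random crops (geometric);
- brightness and contrast changes (photometric).

The parameter ranges below are arbitrary assumptions. Images are assumed to be float arrays in [0, 1] with shape (H, W, C). Libraries such as Augmentor provide more complete, tested versions.

```python
import numpy as np

def horizontal_flip(img):
    """Geometric: mirror an (H, W, C) image left-right."""
    return img[:, ::-1, :]

def random_crop(img, size, rng):
    """Geometric: cut out a random size x size patch."""
    h, w = img.shape[:2]
    top = rng.integers(0, h - size + 1)
    left = rng.integers(0, w - size + 1)
    return img[top:top + size, left:left + size, :]

def adjust_brightness(img, delta):
    """Photometric: add a constant to all pixel values."""
    return np.clip(img + delta, 0.0, 1.0)

def adjust_contrast(img, factor):
    """Photometric: scale deviations from the mean intensity."""
    mean = img.mean()
    return np.clip((img - mean) * factor + mean, 0.0, 1.0)

def augment(img, rng, crop_size=24):
    """Apply a random combination of the transformations above; the label stays unchanged."""
    if rng.random() < 0.5:
        img = horizontal_flip(img)
    img = random_crop(img, crop_size, rng)
    img = adjust_brightness(img, rng.uniform(-0.2, 0.2))
    return adjust_contrast(img, rng.uniform(0.8, 1.2))

rng = np.random.default_rng(8)
image = rng.uniform(size=(32, 32, 3))      # stand-in for a camera-trap image
augmented = [augment(image, rng) for _ in range(4)]
print([a.shape for a in augmented])
```

Which transformations are label-preserving is domain knowledge. A mirrored zebra is still a zebra, but a vertically flipped camera-trap image is not something you expect in deployment.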
In natural language tasks, standard transformations are the random insertion, deletion, or swapping of words in a sentence, or the substitution of words with their synonyms. With a trained language model, more sophisticated transformations become possible, such as back translation, where sentences are translated back and forth between two languages; text paraphrasing, where the same content is rephrased without changing its meaning; and style transformations, where the same content is described in different styles (e.g. more formal vs less formal language).
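The simple word-level transformations need no language model. Below is a pure-Python sketch of synonym replacement, random deletion, and random swapping. The tiny synonym dictionary is a hypothetical toy. In practice, one might draw synonyms from a thesaurus such as WordNet, which NLTK provides access to.

```python
import random

# Hypothetical toy synonym dictionary.
SYNONYMS = {"large": ["big", "huge"], "herd": ["group"], "crossing": ["fording"]}

def synonym_replacement(words, rng, p=0.5):
    """Replace each word that has known synonyms with probability p."""
    return [rng.choice(SYNONYMS[w]) if w in SYNONYMS and rng.random() < p else w
            for w in words]

def random_deletion(words, rng, p=0.1):
    """Delete each word with probability p, but never return an empty sentence."""
    kept = [w for w in words if rng.random() >= p]
    return kept if kept else [rng.choice(words)]

def random_swap(words, rng, n_swaps=1):
    """Swap the positions of two randomly chosen words, n_swaps times."""
    words = list(words)
    for _ in range(n_swaps):
        i, j = rng.randrange(len(words)), rng.randrange(len(words))
        words[i], words[j] = words[j], words[i]
    return words

rng = random.Random(0)
sentence = "a large herd of wildebeest is crossing the river".split()
for transform in (synonym_replacement, random_deletion, random_swap):
    print(transform.__name__, ":", " ".join(transform(sentence, rng)))
```

Such transformations are cheap but can change meaning, for example by deleting a negation. Keep the probabilities low and check samples by hand.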
Data synthesis is more challenging. Computer-aided design (CAD) allows you to model physical objects by hand and perform various geometric and photometric transformations in a physically adequate way. Neural rendering learns 3D scene representations from 2D images and thereby avoids hand-crafted CAD modeling. The most common approach to data synthesis in computer vision is using generative models like generative adversarial networks (GANs) or variational autoencoders (VAEs). These models generate realistic images, which can be sampled conditionally on a desired target class. In natural language processing, the best generator models are large language models like ChatGPT or LLaMA, which can be prompted to generate data with a certain content.
A GAN is a generative model that can generate highly realistic data (I. Goodfellow et al. 2014). It is trained using two sub-models:
- Generator: This neural network model is designed to generate data. It takes random noise from a (Gaussian) distribution as inputs and transforms the noise into ideally realistic data.
- Discriminator: This network is designed to distinguish real from artificial data. It takes inputs and decides whether they are real data or synthetic data.
The two sub-models are trained iteratively in a zero-sum game against each other. The generator aims to generate data that the discriminator mistakes for real data; the discriminator aims to perfectly separate real from synthetic data. Over the course of training, both models get better at their tasks, which ideally leads to a well-performing generator.
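The alternating training scheme can be illustrated on one-dimensional data with a linear generator `G(z) = a*z + b` and a logistic discriminator `D(x) = sigmoid(w*x + c)`, with gradients written out by hand. This is a toy sketch under strong assumptions. The "real" data is a hypothetical Gaussian, the generator uses the common non-saturating loss, and such a weak discriminator is mainly sensitive to differences in the mean. Real GANs use deep networks and an automatic-differentiation framework, and even this simple two-player setup may oscillate instead of settling down.

```python
import numpy as np

def sigmoid(t):
    return 0.5 * (1.0 + np.tanh(0.5 * t))  # numerically stable logistic function

def train_toy_gan(steps=1000, lr=0.05, batch=64, seed=0):
    """Alternating GAN updates for 1-D data.

    Generator:     G(z) = a * z + b, with z ~ N(0, 1)
    Discriminator: D(x) = sigmoid(w * x + c), the probability that x is real
    """
    rng = np.random.default_rng(seed)
    a, b = 1.0, 0.0   # generator parameters
    w, c = 0.1, 0.0   # discriminator parameters
    history = []
    for _ in range(steps):
        # Discriminator step: minimize -log D(x_real) - log(1 - D(G(z))).
        x_real = rng.normal(4.0, 1.0, size=batch)   # hypothetical "real" data
        x_fake = a * rng.normal(size=batch) + b
        s_real, s_fake = w * x_real + c, w * x_fake + c
        g_real = sigmoid(s_real) - 1.0   # d/ds of -log sigmoid(s)
        g_fake = sigmoid(s_fake)         # d/ds of -log(1 - sigmoid(s))
        w -= lr * (np.mean(g_real * x_real) + np.mean(g_fake * x_fake))
        c -= lr * (np.mean(g_real) + np.mean(g_fake))
        # Discriminator loss on this batch, before the update (stable via logaddexp).
        d_loss = np.mean(np.logaddexp(0.0, -s_real)) + np.mean(np.logaddexp(0.0, s_fake))

        # Generator step (non-saturating loss): minimize -log D(G(z)).
        z = rng.normal(size=batch)
        s_fake = w * (a * z + b) + c
        g_x = (sigmoid(s_fake) - 1.0) * w   # d/dx of -log D(x) at x = G(z)
        a -= lr * np.mean(g_x * z)
        b -= lr * np.mean(g_x)
        g_loss = np.mean(np.logaddexp(0.0, -s_fake))
        history.append((d_loss, g_loss))
    return (a, b), (w, c), history

(a, b), (w, c), history = train_toy_gan()
```

The two updates inside the loop are the whole game: the discriminator descends its classification loss, then the generator descends the loss of fooling the freshly updated discriminator.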
The divide between data transformations and synthesis is less clear-cut than suggested above. Relevant transformations often go beyond simple geometric or photometric ones: think of changing the entire background or adding other animals to the Serengeti images. Such complex transformations often demand data synthesis methods. For example, data augmentation GANs (Antoniou, Storkey, and Edwards 2017) and conditional GANs (Isola et al. 2017) allow the generation of new instances conditioned on a given data instance.
Common Python packages for data augmentation are the Augmentor package for computer vision (Bloice, Stocker, and Holzinger 2017), ImageDataGenerator in Keras, and the Natural Language Toolkit (NLTK) for natural language processing (Bird 2006).
What happens after you obtain the data?
Let’s say you have gathered the required data. What should you do? You could retrain the entire model: merge the newly collected data with your training data and run your machine learning algorithm again. Another strategy is to fine-tune your existing model by training it further on the new data. Fine-tuning is indeed less work than retraining, but depending on its specifics, it often introduces a bias either towards the training data or towards the new data. Instead of retraining or fine-tuning, you may also decide to train a separate model exclusively on the newly collected data.
If you face a covariate or a label shift, retraining and fine-tuning are both reasonable strategies. Old and new data can be merged in one model because the predictive pattern stays intact across both. Fine-tuning is particularly advisable if you obtain a small amount of high-quality data with active learning and want to emphasize this data in training. In a concept shift, on the other hand, the predictive pattern differs between old and new data, so training a separate model is the only option.
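The three options can be compared on a toy covariate shift, where the inputs move but the labeling rule stays the same. This is a minimal sketch, assuming a logistic-regression model trained by full-batch gradient descent in NumPy. `fit_logreg`, `make_data`, and all numbers are hypothetical. Fine-tuning is modeled as a warm start from the old weights with only a few additional epochs.

```python
import numpy as np

def fit_logreg(X, y, w0=None, epochs=500, lr=0.5):
    """Logistic regression by full-batch gradient descent; w0 allows a warm start."""
    Xb = np.hstack([X, np.ones((len(X), 1))])        # append a bias column
    w = np.zeros(Xb.shape[1]) if w0 is None else w0.copy()
    for _ in range(epochs):
        p = 0.5 * (1.0 + np.tanh(0.5 * (Xb @ w)))    # stable sigmoid
        w -= lr * Xb.T @ (p - y) / len(y)            # gradient of the mean log-loss
    return w

def accuracy(w, X, y):
    Xb = np.hstack([X, np.ones((len(X), 1))])
    return np.mean((Xb @ w > 0) == y)

rng = np.random.default_rng(0)

def make_data(n, shift):
    # Covariate shift: the inputs move, the labeling rule y = [x0 > x1] stays the same.
    X = rng.normal(loc=shift, size=(n, 2))
    return X, (X[:, 0] > X[:, 1]).astype(float)

X_old, y_old = make_data(400, shift=0.0)
X_new, y_new = make_data(100, shift=2.0)
X_test, y_test = make_data(1000, shift=2.0)

w_old = fit_logreg(X_old, y_old)
w_retrain = fit_logreg(np.vstack([X_old, X_new]), np.concatenate([y_old, y_new]))
w_finetune = fit_logreg(X_new, y_new, w0=w_old, epochs=50)   # a few steps from old weights
w_separate = fit_logreg(X_new, y_new)                         # new data only

for name, w in [("retrain", w_retrain), ("fine-tune", w_finetune), ("separate", w_separate)]:
    print(name, accuracy(w, X_test, y_test))
```

Under a concept shift you would instead change the labeling rule for the new data (for example to `x0 < x1`). Merging old and new data then produces contradictory labels, which is why a separate model is needed in that case.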
Does data augmentation really improve robustness?
This question is difficult to answer in general because it depends on the domain and the data augmentation approach. It is a mantra in machine learning that more data is always better. But what if the data is synthetic? In many settings, data augmentation indeed effectively improves robustness (Hendrycks et al. 2021; Mumuni and Mumuni 2022). However, in some cases, (adversarial) robustness conflicts with predictive performance on the original dataset (Tsipras et al. 2018). This is unsurprising: robustness means insensitivity to certain features, and insensitivity to predictive features leads to a performance drop. Whether we have to trade off (adversarial) robustness against performance depends on the data augmentation approach (Rebuffi et al. 2021). Interestingly, training your model exclusively on synthetic data may drive it mad (Alemohammad et al. 2023)…
12.5.4 Carefully select and construct features
As we discussed in Chapter 3, the input features and the target feature are key modeling choices that every modeler faces. Achieving a robust model with unreliable features can be impossible. Imagine having to predict an animal’s species based only on the background against which it was sighted: the slightest distribution shift will diminish model performance. Similarly, if you are forced to predict a specific species but 75% of the data contains no animals (Norouzzadeh et al. 2018), your model is doomed to fail from the start!
Choosing reasonable input and target features is challenging and requires (causal) domain knowledge (see Chapter 9 and Chapter 11). There exist various approaches to obtain better input features:
- Feature selection describes approaches to select an optimal subset of features from a feature set (Chandrashekar and Sahin 2014). The subset is chosen based on criteria like the (conditional) mutual information between input and target, or the performance of a trained classifier on the subset of input features (see e.g. conditional feature importance in Chapter 10). Feature selection reduces the dimensionality of the data and filters unreliable or noisy features, thereby improving robustness. Feature selection algorithms are often tailored to tabular data but are hard to apply to image, text, or speech data.
- Feature engineering describes approaches that transform the input features. One option is to apply hand-crafted transformations, which:
- describe statistical properties, e.g. interaction terms,
- encode domain knowledge, e.g. graph structure in graph neural networks,
- have physical meaning, e.g. edge detectors,
- reduce dimensionality, e.g. principal components analysis (PCA),
- provide useful encodings, e.g. bucketing or one-hot-encoding, or
- emphasize particularly important features, e.g. time.
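A filter-style feature selection based on mutual information can be sketched in a few lines. This minimal sketch assumes continuous features, discretizes each one into quantile bins, and computes a plug-in estimate of its mutual information with a discrete target. The data is synthetic, and only the second column carries signal. Plug-in estimates are biased upward for small samples and many bins, so treat the scores as a ranking rather than as exact values.

```python
import numpy as np

def quantile_bins(x, n_bins=10):
    # Discretize a continuous feature into roughly equally populated bins 0..n_bins-1.
    edges = np.quantile(x, np.linspace(0.0, 1.0, n_bins + 1)[1:-1])
    return np.digitize(x, edges)

def mutual_information(a, b):
    """Plug-in mutual information (in nats) between two non-negative integer arrays."""
    joint = np.zeros((a.max() + 1, b.max() + 1))
    np.add.at(joint, (a, b), 1.0)
    joint /= joint.sum()
    # Outer product of the marginals: the joint distribution under independence.
    expected = joint.sum(axis=1, keepdims=True) @ joint.sum(axis=0, keepdims=True)
    nz = joint > 0
    return float(np.sum(joint[nz] * np.log(joint[nz] / expected[nz])))

def select_top_k(X, y, k):
    # Score every feature independently and keep the k most informative ones.
    scores = np.array([mutual_information(quantile_bins(X[:, j]), y) for j in range(X.shape[1])])
    return np.argsort(scores)[::-1][:k], scores

rng = np.random.default_rng(0)
n = 2000
signal = rng.normal(size=n)
y = (signal + 0.3 * rng.normal(size=n) > 0).astype(int)
X = np.column_stack([rng.normal(size=n), signal, rng.normal(size=(n, 2))])  # column 1 is informative
selected, scores = select_top_k(X, y, k=1)
```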
The target encoding can also be improved. Norouzzadeh et al. (2018), for example, split their prediction target into two parts: in Task 1, a model discriminates between images with and without animals; in Task 2, a second model classifies the images that contain animals into different species. This split into two tasks substantially improves the classifier’s robustness. Similarly, target features often come in a hierarchical form (Vens et al. 2008): Thomson’s gazelle and Grant’s gazelle, for instance, are different species of gazelle. Such additional structure can be encoded through hierarchical multi-label classification, for example with decision trees, and thereby improve the robustness of the classifier.
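One way to encode a label hierarchy is a multi-hot vector in which every label also switches on all of its ancestors. In the sketch below, the hierarchy and the node names are hypothetical, and the top-level node `animal` plays the role of Task 1 (animal present or not) from the two-stage setup described above.

```python
# Hypothetical label hierarchy: each node points to its parent (None = root level).
PARENT = {
    "animal": None,
    "gazelle": "animal",
    "thomsons_gazelle": "gazelle",
    "grants_gazelle": "gazelle",
    "zebra": "animal",
}
NODES = list(PARENT)  # fixed order of the output units

def ancestors_and_self(label):
    # Walk up the hierarchy from a leaf to the root.
    path = []
    while label is not None:
        path.append(label)
        label = PARENT[label]
    return path

def encode(label):
    """Multi-hot target: the label and all of its ancestors are set to 1."""
    vec = [0] * len(NODES)
    if label is None:  # empty image: no animal, so every unit is 0
        return vec
    for node in ancestors_and_self(label):
        vec[NODES.index(node)] = 1
    return vec

targets = {label: encode(label) for label in ["thomsons_gazelle", "zebra", None]}
```

A classifier trained on such targets gets partial credit for confusing two gazelle species, instead of treating that mistake like confusing a gazelle with an empty image.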
12.5.5 Choose inductive biases wisely
The relationship between modeling choices and predictive performance has already been discussed extensively in Chapter 8 and Chapter 9. The key insight was: the better suited the inductive bias (e.g. model class, architecture, hyperparameters, loss) is to your learning environment, the faster you will learn a high-performing model. A shifted data distribution is nothing but another learning environment. Modeling choices therefore provide another robustifying strategy. For example:
- CNNs are (approximately) translation invariant: convolutional layers are translation equivariant, and pooling makes the output largely insensitive to the animal’s position in the image.
- Dropout, which randomly switches off neurons during training, improves robustness to adversarial attacks (Wang et al. 2018). The proposed reason is that dropout enforces smoother representations in the latent space.
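Both examples can be checked numerically. The sketch below is a minimal one-dimensional illustration in NumPy: it verifies that a convolution with wrap-around padding is translation equivariant (shifting the input shifts the feature map), that global max pooling turns this into invariance, and it implements inverted dropout. The kernel and signal are hypothetical, and real CNN layers with zero padding and strides are only approximately equivariant near the borders.

```python
import numpy as np

def circular_conv1d(signal, kernel):
    """Cross-correlation with wrap-around padding (what deep-learning 'conv' layers compute)."""
    n = len(signal)
    return np.array([
        sum(kernel[j] * signal[(i + j) % n] for j in range(len(kernel)))
        for i in range(n)
    ])

def inverted_dropout(activations, p, rng):
    """Zero each unit with probability p and rescale survivors so the expectation is unchanged."""
    keep = rng.random(activations.shape) >= p
    return activations * keep / (1.0 - p)

rng = np.random.default_rng(0)
x = rng.normal(size=16)               # a 1-D stand-in for an image row
kernel = np.array([1.0, -2.0, 1.0])   # hypothetical edge-like filter
shift = 3                             # move the "animal" three pixels to the right

feature_map = circular_conv1d(x, kernel)
shifted_map = circular_conv1d(np.roll(x, shift), kernel)

# Equivariance: shifted input -> shifted feature map.
# Invariance after global max pooling: the pooled value does not move at all.
equivariant = np.allclose(shifted_map, np.roll(feature_map, shift))
invariant = np.isclose(shifted_map.max(), feature_map.max())
```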
Improving robustness to distribution shifts with data augmentation and with inductive modeling biases are two sides of the same coin. The former turns knowledge about (anticipated) distribution shifts into data instances; the latter turns it into modeling assumptions. CNNs are only one example where this parallel is evident; there are many similar architectural solutions that encode rotation and scaling invariance (Mumuni and Mumuni 2021). Similarly, graph neural networks allow encoding many forms of domain knowledge into the graph structure and node/edge properties (Corso et al. 2024; Wu et al. 2020; Battaglia et al. 2018).
Is it better to apply data transformations or to encode invariances into the inductive modeling assumptions? On the one hand, some data transformations are difficult to encode as inductive biases into the model. Think about dry versus wet season data: there is no easy way to encode an invariance to seasonal changes. On the other hand, if it is possible to encode invariances as inductive biases, you should. Your model is then guaranteed to obey them, whereas data augmentation only makes it more likely that the model learns the invariances. Furthermore, there is a computational trade-off: more data requires more compute, whereas better inductive biases usually improve computational efficiency.
12.5.6 Transfer learning to obtain robust representations
Transfer learning is about transferring knowledge from one task to another, so that a new task does not have to be learned from scratch but can build upon existing knowledge. Even when tasks differ substantially, they often share certain aspects. For example, classifying pets and classifying wildlife both require learning higher-order representations that allow us to differentiate animals, even though the animals, their actions, and the image backgrounds differ. By inducing more domain-general knowledge, transfer learning makes the model more robust to common distribution shifts. We distinguish different kinds of transfer learning according to the knowledge they transfer:
- Feature extraction reuses the representations learned in one task for another task. For example, say you have a general image classifier like ResNet or Inception V3 that has been trained on a huge dataset like ImageNet. The network’s weights and activations then store powerful representations, which can be reused to make the wildlife classifier robust against common image perturbations. Feature extraction methods commonly use the penultimate layer, right before the final prediction. Feature extraction is the most popular transfer learning technique and was also used by Norouzzadeh et al. (2018), leading to better data efficiency.
- Fine-tuning uses the learned parameters of a model trained for one task as the starting point for another task. One may, for example, take the trained ResNet classifier, replace the output layer, and train the model on the wildlife images. Unlike in feature extraction, the model can adapt the representations learned by ResNet on ImageNet and tailor them to the wildlife case. Norouzzadeh et al. (2018) also tried this, however, without significantly improving overall performance.
- Multi-task learning trains a single core model to perform multiple related tasks simultaneously. The idea is to force the core model to learn representations that are useful across different tasks, leading to better representations and better robustness to common changes. One could, for example, use the same core model to classify trees, classify wildlife animals, and predict the time of day. By optimizing for such a diverse set of tasks, the core model has to learn representations that work for all of them.
- Self-supervised learning can be seen as one specific approach to learning a core model that is useful across a variety of tasks. It masks certain parts of the data and tries to infer them from the rest. Thereby, self-supervision learns patterns that generalize. One could, for example, mask the animals’ heads to learn the interdependencies between animal heads and bodies.
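The difference between feature extraction and fine-tuning comes down to which weights receive gradient updates. In this minimal NumPy sketch, `W_pretrained` is a hypothetical stand-in for the pretrained weights of a backbone such as ResNet, reduced to a single ReLU layer, and a new logistic output layer is trained on top of it. With `fine_tune=False` the backbone stays frozen; with `fine_tune=True` it is updated as well. In practice you would do this in a deep-learning framework by freezing or unfreezing layers.

```python
import numpy as np

def sigmoid(t):
    return 0.5 * (1.0 + np.tanh(0.5 * t))  # numerically stable logistic function

def train_on_backbone(X, y, W_backbone, fine_tune=False, epochs=200, lr=0.1, seed=0):
    """Train a new logistic output layer on features relu(X @ W).

    fine_tune=False -> feature extraction: the backbone W stays frozen.
    fine_tune=True  -> fine-tuning: gradients also update (a copy of) W.
    """
    rng = np.random.default_rng(seed)
    W = W_backbone.copy()                          # never modify the pretrained weights in place
    v = rng.normal(scale=0.01, size=W.shape[1])    # freshly initialized output layer
    c = 0.0
    for _ in range(epochs):
        Z = X @ W                                  # backbone pre-activations
        H = np.maximum(Z, 0.0)                     # penultimate-layer features
        g = (sigmoid(H @ v + c) - y) / len(y)      # d(mean log-loss)/d(logit)
        if fine_tune:
            grad_W = X.T @ (np.outer(g, v) * (Z > 0))   # backpropagate through the ReLU
            W -= lr * grad_W
        v -= lr * (H.T @ g)
        c -= lr * g.sum()
    return W, v, c

def predict(X, W, v, c):
    return sigmoid(np.maximum(X @ W, 0.0) @ v + c) > 0.5

rng = np.random.default_rng(1)
X = rng.normal(size=(200, 5))
y = (X[:, 0] + X[:, 1] > 0).astype(float)   # hypothetical target task
W_pretrained = rng.normal(size=(5, 16))     # stand-in for pretrained backbone weights

W_fe, v_fe, c_fe = train_on_backbone(X, y, W_pretrained, fine_tune=False)
W_ft, v_ft, c_ft = train_on_backbone(X, y, W_pretrained, fine_tune=True)
```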
Transfer learning can be key if data is scarce for the specific task at hand but widely available for related tasks. Core models now exist for more and more data modalities, such as ChatGPT for text data or ResNet for image data. Trained on vast amounts of data, these core models have learned representations so powerful that they make downstream models robust to a wide range of standard distribution shifts. We believe that such transfer learning from core models will play an essential role in future scientific applications. The data and knowledge that went into these models should not go unused.
12.6 Generalization and causality are closely linked to robustness
Remember the three kinds of generalization that we discussed in Chapter 8? Generalization to predict in theory, generalization to predict in practice, and generalization to the phenomenon. Generalization to predict in theory concerns prediction for a static data distribution, a natural requirement for machine learning models. Robustness usually goes one step further: the machine learning model should perform well in all practically relevant scenarios. That means the performance of the model should be robust under expected data distribution shifts. Generalization to predict in practice and robustness therefore often mean the same thing: the model should work under natural conditions.
Causality is about generalization to the phenomenon at hand. We want an accurate representation of the data-generating process. This constitutes a link between robustness and causality. Why? If we have an accurate representation of the phenomenon, we can simulate all kinds of alternative scenarios and provide predictions under all kinds of distribution shifts:
- You know what led to the distribution shift? Then you can simulate the shift with your causal model and still provide optimal predictions (Arjovsky et al. 2019; Kamath et al. 2021). Even your initial machine learning prediction model may perform robustly under certain shifts (König, Freiesleben, and Grosse-Wentrup 2023).
- You do not know what led to the distribution shift? Say you only receive data indicative of the shift. Then causal models allow you to generate data for various possible shifts and compare them to the observed data (Cranmer, Brehmer, and Louppe 2020).
Even beyond these cases, causal models have an essential property that makes them appealing for machine learning robustness research – causal models are modular. Say you train a machine learning model to learn the joint distribution of two variables \(A\) and \(B\), namely \(\mathbb{P}(A,B)\). Then, as soon as your distribution of \(A\) or of \(B\) changes, you have to learn an entirely new model.
Instead, assume you know that \(A\) causes \(B\). Then you can split your learning task into two components, namely \(\mathbb{P}(B\mid A)\) and \(\mathbb{P}(A)\). This provides you again with the joint distribution because \[\mathbb{P}(A,B)=\mathbb{P}(A)\mathbb{P}(B\mid A).\] But now if \(\mathbb{P}(A)\) shifts, all you need to update is your model of \(\mathbb{P}(A)\), whereas \(\mathbb{P}(B\mid A)\) must remain stable because it is a causal mechanistic relationship (see Chapter 11). This modularity makes causal relations worth learning!
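The modularity argument can be checked by simulation. This minimal sketch assumes two binary variables where \(A\) causes \(B\) through a fixed, hypothetical mechanism \(\mathbb{P}(B=1\mid A)\). After \(\mathbb{P}(A)\) shifts, only the marginal is re-estimated from new observations of \(A\), and the recombined joint distribution still matches the new data.

```python
import numpy as np

rng = np.random.default_rng(0)
P_B1_GIVEN_A = np.array([0.2, 0.7])  # hypothetical stable mechanism P(B=1 | A=a)

def sample(n, p_a1):
    # Draw A from its (possibly shifted) marginal, then B from the fixed mechanism.
    a = (rng.random(n) < p_a1).astype(int)
    b = (rng.random(n) < P_B1_GIVEN_A[a]).astype(int)
    return a, b

def fit_marginal(a):
    p1 = a.mean()
    return np.array([1.0 - p1, p1])

def fit_conditional(a, b):
    # Rows index a in {0, 1}; columns index b in {0, 1}.
    cond = np.zeros((2, 2))
    for val in (0, 1):
        p1 = b[a == val].mean()
        cond[val] = [1.0 - p1, p1]
    return cond

def empirical_joint(a, b):
    joint = np.zeros((2, 2))
    np.add.at(joint, (a, b), 1.0)
    return joint / len(a)

# Before the shift: learn both modules from the old data.
a_old, b_old = sample(100_000, p_a1=0.3)
p_a_old = fit_marginal(a_old)
p_b_given_a = fit_conditional(a_old, b_old)

# After the shift in P(A): re-estimate only the marginal, reuse the mechanism.
a_new, b_new = sample(100_000, p_a1=0.8)
p_a_new = fit_marginal(a_new)
joint_pred = p_a_new[:, None] * p_b_given_a   # P(A, B) = P(A) * P(B | A)
joint_old = p_a_old[:, None] * p_b_given_a    # the stale joint, for comparison
```

Here `b_new` is only used to check the result; the update itself needs nothing but new observations of \(A\).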
12.7 The riddle of adversarial examples
Why does a model as good as that of Norouzzadeh et al. (2018) make mistakes like the one in Figure 12.2? Behind this question lies the riddle of adversarial examples. Adversarial examples are inputs that are modified in a way that leaves them indistinguishable from the original to humans but drastically changes the model’s prediction, leading to misclassification. The reasons for this behavior are still only partially understood:
The first hypothesis was that adversarial examples are unlikely instances that are poorly represented in the original training data (Szegedy et al. 2013). However, if this were true, adding adversarial examples to the training data should get rid of the problem; the thing is, it didn’t (Zhang et al. 2019). Even if you train your model on a wide variety of adversarial examples, there will still be new adversarial examples in the direct vicinity. Famously, I. J. Goodfellow, Shlens, and Szegedy (2014) came up with a new proposal: adversarial examples arise because machine learning models are too linear. Most machine learning models are still based on activation functions that behave largely linearly, like Rectified Linear Units (ReLUs). Thus, for high-dimensional inputs, many tiny changes to individual features add up (due to linearity) to a significant change in the prediction. This insight led to an efficient algorithm to compute adversarial examples, the Fast Gradient Sign Method (I. J. Goodfellow, Shlens, and Szegedy 2014). The thing is, linear activation functions turned out to be neither necessary nor sufficient for adversarial examples (Tanay and Griffin 2016). Adversarial examples showed up not only for ReLU-based neural networks but for all kinds of machine learning models (Han et al. 2023). Moreover, these examples transfer between different models trained on the same dataset: you could generate an adversarial example for a ResNet model and apply it successfully to another model with an entirely different architecture (Papernot, McDaniel, and Goodfellow 2016). Linearity couldn’t explain this strange behavior.
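The Fast Gradient Sign Method perturbs every input feature by ±ε in the direction that increases the loss. For a linear model, the logit then moves by exactly ε times the ℓ1 norm of the weights towards the wrong class, and that norm grows with the input dimension; this is the linearity argument in miniature. This is a minimal sketch, assuming a hypothetical logistic-regression model with random weights rather than a trained image classifier.

```python
import numpy as np

def sigmoid(t):
    return 0.5 * (1.0 + np.tanh(0.5 * t))  # numerically stable logistic function

def fgsm(x, y, w, b, eps):
    """Fast Gradient Sign Method against a logistic-regression classifier p = sigmoid(w @ x + b)."""
    p = sigmoid(w @ x + b)
    grad_x = (p - y) * w               # gradient of the log-loss with respect to the input x
    return x + eps * np.sign(grad_x)   # move each feature by eps, uphill in the loss

rng = np.random.default_rng(0)
d = 1000                               # input dimension, e.g. a small flattened image
w, b = rng.normal(size=d), 0.0         # hypothetical model weights
x = rng.normal(size=d)
x += (5.0 - (w @ x + b)) * w / (w @ w)  # shift x along w so the clean logit is exactly 5
y = 1.0                                 # the model confidently (and correctly) predicts class 1

x_adv = fgsm(x, y, w, b, eps=0.05)
clean_logit, adv_logit = w @ x + b, w @ x_adv + b
```

No single feature changes by more than 0.05, yet with 1000 features the logit drops by roughly 40, which flips the prediction.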
Adversarial examples are not bugs, they are features
Ilyas et al. (2019) gave a novel explanation: adversarial examples arise because human perception and machine learning perception operate differently. Machine learning models learn patterns that are stable across the dataset, including patterns that humans don’t rely on. For example, a model may classify an elephant based on the texture of its skin rather than its shape or its trunk. Slight changes to the elephant’s skin will not change a human’s classification, but they can change the model’s prediction.
Ilyas et al. (2019) arrive at this explanation through the following experiment: Say you have a trained model and generate a set of adversarial examples for it. Now, train a new model from scratch exclusively on these adversarial examples, labeled with their (wrong) adversarial target classes. Surprisingly, this model performs well on the original data with correct labels. The natural explanation for this behavior is that the adversarial examples are strange in a meaningful way: they contain features that are genuinely predictive on ordinary data points. This also explains why adversarial examples transfer between different models trained on the same data: they differ from real data in ways that are usually indicative of a different class. Unfortunately, this explanation of adversarial examples (which we deem plausible) does not come with easy fixes. It implies that adversarial examples can only be avoided by relying on the very same features as humans. However, doing so would limit the space of possible solutions and thereby the predictive performance. In science, new predictive patterns undiscovered by humans might be exactly what you are after.
Another factor leading to adversarial examples is that machine learning models do not distinguish causes from effects or from spurious correlations. Thus, changing a feature that is associated with a different class but does not causally affect the target will lead to a misclassification (Freiesleben 2022). For example, one could change the elephant’s skin texture to resemble a hippo’s without touching causal features like its shape or trunk. Adding causal assumptions may therefore protect against certain adversarial examples (Schölkopf 2022). But again, this is not an easy fix in practice. Adversarial examples remain a phenomenon that we have to live with, and it is unclear whether getting rid of the phenomenon is always desirable. In science, we might even strive for the weird predictive features invisible to humans.
12.8 Robustness is a constant challenge
Robustness in machine learning is not something you achieve once and for all. Robustness is like the boulder that Sisyphus² keeps pushing to the top of the hill just to see it roll down again. But if you want a reliable machine learning model, the only way is to keep on pushing! Your environment will change, and your model’s performance will eventually drop. What is important is to stay vigilant and constantly adapt.
This chapter should have made clear that robustifying models is multifaceted. Each step in your pipeline can contribute, and each step should be kept flexible so that you can react quickly to changes. Each chapter in this book is an ally on your never-ending robustness journey: you need models that generalize (Chapter 8), you have to incorporate domain knowledge (Chapter 9), you must audit your model with interpretability techniques (Chapter 10), you should be aware of prediction uncertainties (Chapter 13), and ultimately you should incorporate causality (Chapter 11).
Many problems remain challenging in robustness research:
- Translating a lack of robustness into a robustifying strategy is difficult. Data augmentation can provide a first heuristic solution but there are no guarantees. Also, different shifts require different robustifying strategies.
- Adversarial examples remain relatively mysterious. Can they be avoided, and if so, should they be?
- Synthetic data is often the least costly way to robustify your model. But if the data does not resemble the characteristics of real data, training your model on it is pointless.
Being a machine learning robustness engineer is like being a scientist: it is a constant struggle with nature. Nevertheless, to paraphrase Camus³, we must imagine machine learning robustness engineers as happy people.
1. To evaluate your model’s performance on the new data, the data must be labeled. We discuss in Section 12.5.3 how to label data systematically using active learning.
2. Sisyphus was a figure in Greek mythology: the King of Corinth, infamous for his cunning and deceit, which often got him into trouble with the gods. One of the most famous stories involving Sisyphus is his punishment in the afterlife. According to Greek mythology, after his death, Sisyphus was condemned by the gods to roll a boulder uphill for eternity, only for it to roll back down every time he neared the top. Such endless and futile tasks are therefore called Sisyphean, and the term is often used as a metaphor for a task that never ends and ultimately feels pointless.
3. Albert Camus was a French philosopher, author, and journalist, best known for his philosophy of the absurd. In his philosophical essay The Myth of Sisyphus (1942), Camus explores the concept of the absurd: the inherent conflict between the human desire to find meaning in life and the universe’s indifference to human concerns. Camus argues that despite the apparent pointlessness of Sisyphus’s task, Sisyphus can still find happiness and meaning in his existence by embracing the absurdity of life and finding fulfillment in the act of rebellion against it. He therefore famously stated that “we must imagine Sisyphus happy”. Indeed, connecting these deep thoughts to machine learning robustness is a bit far-fetched…