Data Augmentations for Improved (Large) Language Model Generalization


Conference paper


Amir Feder, Yoav Wald, Claudia Shi, S. Saria, David Blei
NeurIPS, 2023

Cite


APA
Feder, A., Wald, Y., Shi, C., Saria, S., & Blei, D. (2023). Data Augmentations for Improved (Large) Language Model Generalization. NeurIPS.


Chicago/Turabian
Feder, Amir, Yoav Wald, Claudia Shi, S. Saria, and David Blei. “Data Augmentations for Improved (Large) Language Model Generalization.” NeurIPS (2023).


MLA
Feder, Amir, et al. “Data Augmentations for Improved (Large) Language Model Generalization.” NeurIPS, 2023.


BibTeX

@inproceedings{amir2023a,
  title = {Data Augmentations for Improved (Large) Language Model Generalization},
  year = {2023},
  booktitle = {NeurIPS},
  author = {Feder, Amir and Wald, Yoav and Shi, Claudia and Saria, S. and Blei, David}
}

Abstract

The reliance of text classifiers on spurious correlations can lead to poor generalization at deployment, raising concerns about their use in safety-critical domains such as healthcare. In this work, we propose to use counterfactual data augmentation, guided by knowledge of the causal structure of the data, to simulate interventions on spurious features and to learn more robust text classifiers. We show that this strategy is appropriate in prediction problems where the label is spuriously correlated with an attribute. Under the assumptions of such problems, we discuss the favorable sample complexity of counterfactual data augmentation compared to importance re-weighting. Pragmatically, we match examples using auxiliary data, based on difference-in-differences (diff-in-diff) methodology, and use a large language model (LLM) to represent a conditional probability of text. Through extensive experimentation on learning caregiver-invariant predictors of clinical diagnoses from medical narratives and on semi-synthetic data, we demonstrate that our method for simulating interventions improves out-of-distribution (OOD) accuracy compared to baseline invariant learning algorithms.
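The abstract contrasts two ways of removing a spurious label-attribute correlation: counterfactual data augmentation and importance re-weighting. The Python sketch below illustrates that contrast on a toy binary problem, under loudly stated assumptions: the `Example` record, the 0/1 encoding of the attribute, and the helpers `make_biased_data`, `counterfactual_augment`, `importance_weights`, and `attr_gap` are all hypothetical, and the "counterfactual" is a one-feature flip rather than generated text. It is not the authors' method, which matches examples on auxiliary data with diff-in-diff and uses an LLM to model text.

```python
# Toy illustration (NOT the paper's implementation): counterfactual data
# augmentation vs. importance re-weighting when a label Y is spuriously
# correlated with an attribute Z. All names here are hypothetical.
import random
from collections import Counter
from dataclasses import dataclass


@dataclass(frozen=True)
class Example:
    core: int      # feature that actually determines the label (stand-in for clinical content)
    spurious: int  # attribute Z, e.g. a caregiver ID encoded as 0/1 (hypothetical)
    label: int     # Y


def make_biased_data(n, p_agree, seed=0):
    """Training set in which Z equals Y with probability p_agree."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        y = rng.randint(0, 1)
        z = y if rng.random() < p_agree else 1 - y
        data.append(Example(core=y, spurious=z, label=y))
    return data


def counterfactual_augment(data):
    """Append a copy of every example with Z intervened on, do(Z = 1 - z),
    keeping the core feature and the label fixed. The counterfactual here is
    a trivial feature flip; producing counterfactual text is not modeled."""
    return data + [Example(e.core, 1 - e.spurious, e.label) for e in data]


def importance_weights(data):
    """w(y, z) = P(y) P(z) / P(y, z), estimated from counts, so that Y and Z
    are independent under the re-weighted empirical distribution."""
    n = len(data)
    joint = Counter((e.label, e.spurious) for e in data)
    p_y = Counter(e.label for e in data)
    p_z = Counter(e.spurious for e in data)
    return [
        (p_y[e.label] / n) * (p_z[e.spurious] / n) / (joint[(e.label, e.spurious)] / n)
        for e in data
    ]


def effective_sample_size(weights):
    """Kish's effective sample size, (sum w)^2 / sum w^2; equals n only for uniform weights."""
    return sum(weights) ** 2 / sum(w * w for w in weights)


def attr_gap(data, weights=None):
    """Weighted P(Z=1 | Y=1) - P(Z=1 | Y=0); zero when Z carries no information about Y."""
    if weights is None:
        weights = [1.0] * len(data)

    def p_z1_given(y):
        pairs = [(e.spurious, w) for e, w in zip(data, weights) if e.label == y]
        return sum(w for z, w in pairs if z == 1) / sum(w for _, w in pairs)

    return p_z1_given(1) - p_z1_given(0)


train = make_biased_data(n=1000, p_agree=0.9)
augmented = counterfactual_augment(train)
rw_weights = importance_weights(train)

print(f"Y-Z gap, original:   {attr_gap(train):.3f}")
print(f"Y-Z gap, augmented:  {attr_gap(augmented):.3f}")
print(f"Y-Z gap, reweighted: {attr_gap(train, rw_weights):.3f}")
print(f"augmented size: {len(augmented)}, reweighted ESS: {effective_sample_size(rw_weights):.0f}")
```

In this toy setting both strategies remove the dependence between Y and Z, but re-weighting does so by placing large weights on the rare cells where Z disagrees with Y, which pushes the effective sample size well below n when the correlation is strong; augmentation keeps every original example and adds an intervened twin. This only loosely echoes the sample-complexity comparison mentioned in the abstract and is not a reproduction of the paper's analysis.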