TRANSCRIPT
Learning representations for counterfactual inference
Fredrik D. Johansson*, Uri Shalit*, David Sontag (*Equal contribution)
NIPS 2016 Deep Learning Symposium, December 2016
Talk today about two papers
• Fredrik D. Johansson, Uri Shalit, David Sontag, "Learning Representations for Counterfactual Inference", ICML 2016
• Uri Shalit, Fredrik D. Johansson, David Sontag, "Estimating individual treatment effect: generalization bounds and algorithms", arXiv:1606.03976
Code: https://github.com/clinicalml/cfrnet
Causal inference from observational data
• Patient "Anna" comes in with hypertension
  • Asian, 54, history of diabetes, blood pressure 150/95, …
• Which of the treatments t will cause Anna to have lower blood pressure?
  • Calcium channel blocker (t = 1)
  • ACE inhibitor (t = 0)
• Dataset of observational data from many patients: medications, blood tests, past diagnoses, demographics, …
Causal inference from observational data
How to best use observational data for individual-level causal inference?
Causal inference from observational data: Job training
• 1,000 unemployed persons
• Job training program with a capacity of 100
  • Training (t = 1)
  • No training (t = 0)
• Who should get job training?
  • For which persons will job training have the most impact?
• Observational data about thousands of people: job history, job training, education, skills, demographics, …
Observational data
• Dataset of features, actions and outcomes
• We do not control the actions
• We do not know the model generating the actions
Causal inference from observational data and reinforcement learning
• A robot on the sideline, learning by observing other robots playing robot football
• The sideline robot does not know the playing robots' internal model
• A form of off-policy learning / learning from demonstration
Outline: Background · Model · Experiments · Theory
Causal inference from observational data: Medication
• Patient "Anna" comes in with hypertension
  • Asian, 54, history of diabetes, blood pressure 150/95, …
• Which of the treatments t will lower Anna's blood pressure?
  • Calcium channel blocker (t = 1)
  • ACE inhibitor (t = 0)
• Dataset of observational data: medications, blood tests, past diagnoses, demographics, …
Regression modeling
• Build a regression model h from patient features x and treatment decision t to blood pressure (BP), using our observational data
• Input: Anna's features x, once with t = 1 and once with t = 0
• Output: predicted BP under each treatment
• Compare: h(x, t = 1) − h(x, t = 0) = ?
[Diagram: features x and treatment t feed into a network producing h, trained with loss(h(x, t), y)]
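As a hedged sketch of the comparison above (the data, coefficients, and effect size below are made up for illustration, not from the talk), the naive regression approach fits one model h(x, t) and differences its two counterfactual predictions:

```python
import numpy as np

# Toy simulation (all numbers invented): feature = age, binary treatment t,
# outcome y = blood pressure; treatment 1 lowers BP by 10 on average.
rng = np.random.default_rng(0)
n = 200
age = rng.uniform(40, 80, n)
t = rng.integers(0, 2, n)
y = 100 + 0.8 * age - 10 * t + rng.normal(0, 1, n)

# Fit one linear model h(x, t) on (features, treatment) -> outcome.
X = np.column_stack([np.ones(n), age, t])
beta, *_ = np.linalg.lstsq(X, y, rcond=None)

def h(x_age, x_t):
    return beta[0] + beta[1] * x_age + beta[2] * x_t

# Estimated effect for a 54-year-old: compare the two counterfactual predictions.
ite_hat = h(54.0, 1) - h(54.0, 0)
```

Because the simulated treatment here is assigned at random, the difference h(x, 1) − h(x, 0) recovers the true effect; as the next slide argues, with observational data this naive approach can fail.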
Not supervised learning!
• This is not a classic supervised learning problem
• Supervised learning is optimized to predict the outcome, not to differentiate the influence of t = 1 vs. t = 0
• What if our high-dimensional model threw away the treatment feature t?
• Maybe there's confounding: younger patients tend to get medication t = 1, older patients tend to get medication t = 0
Potential outcomes (Rubin & Neyman)
For every sample x ∈ 𝒳 and treatment t ∈ {0, 1}, there is a potential outcome Y_t | x
• Y_1 | x: blood pressure had they received treatment 1
• Y_0 | x: blood pressure had they received treatment 0
Individual treatment effect: ITE(x) := 𝔼[Y_1 − Y_0 | x]
We observe only one potential outcome, and not at random!
Example – patient blood pressure (BP)
Features: x = (age, gender); treatment: t ∈ {0, 1}

Factual (observed) set — (age, gender, treatment) → BP after medication:
(40, F, 1)  Y_1 = 140
(40, M, 1)  Y_1 = 145
(65, F, 0)  Y_0 = 170
(65, M, 0)  Y_0 = 175
(70, F, 0)  Y_0 = 165

Counterfactual set — (age, gender, treatment) → BP after medication:
(40, F, 0)  Y_0 = ?
(40, M, 0)  Y_0 = ?
(65, F, 1)  Y_1 = ?
(65, M, 1)  Y_1 = ?
(70, F, 1)  Y_1 = ?

The counterfactual set is our prediction set.
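To make the missing-data nature of ITE concrete, here is a toy sketch: if we (hypothetically) knew both potential outcomes for the five people above, each person's effect would simply be Y_1 − Y_0. The counterfactual values below are invented for illustration; in reality they are never observed.

```python
# Factual rows are from the slide's table; counterfactual outcomes are made up.
factual = [  # (age, sex, t, observed outcome Y_t)
    (40, "F", 1, 140), (40, "M", 1, 145),
    (65, "F", 0, 170), (65, "M", 0, 175), (70, "F", 0, 165),
]
counterfactual = [155, 160, 160, 165, 158]  # invented Y_{1-t} for the same people

ites = []
for (age, sex, t, y_obs), y_cf in zip(factual, counterfactual):
    # Route the observed and hypothetical outcomes to Y1 and Y0 by treatment.
    y1, y0 = (y_obs, y_cf) if t == 1 else (y_cf, y_obs)
    ites.append(y1 - y0)  # per-person effect, knowable only in simulation
```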
• Closely related to unsupervised domain adaptation
• No samples from the test set
• Can't perform cross-validation!
Outline: Background · Model · Experiments · Theory
Our work
• New neural-net based representation learning algorithm with explicit regularization for counterfactual estimation
• State-of-the-art on a previous benchmark and on a real-world causal inference task
• First error bound for estimating individual treatment effect (ITE)
When is this problem easier? Randomized controlled trials
[Figure: distribution of features x for the control (t = 0) and treated (t = 1) groups overlaps]
Randomized treatment → counterfactual and factual have identical distributions
When is this problem harder? Observational study
[Figure: distribution of features x differs between the control (t = 0) and treated (t = 1) groups]
Non-random treatment assignment → counterfactual and factual have different distributions
Learning more balanced representations
[Figure: the feature-space distributions p_treated(x) and p_control(x) differ; after mapping through a representation Φ, the induced distributions p^Φ_treated(x) and p^Φ_control(x) are closer]
Naïve neural network for estimating individual treatment effect (ITE)
[Diagram: features x and treatment t feed a network h, trained with loss(h(x, t), Y_t)]
Vanilla neural network for counterfactual regression (CFR)
[Diagram: features x pass through a representation Φ; Φ(x) and treatment t feed a prediction head h, trained with loss(h(Φ, t), Y_t)]
Balancing neural network for counterfactual regression (CFR)
[Diagram: as above, with an added penalty dist(p^Φ_treated, p^Φ_control) on the representation layer]
dist(p^Φ_treated, p^Φ_control):
• MMD distance (Gretton et al. 2012)
• Wasserstein distance (Villani 2008, Cuturi 2013)
Inspired by Domain-Adversarial Networks (Ganin et al., 2016):
(source domain, target domain) → (treated population, control population)
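One way to measure the imbalance penalty above is the MMD. A minimal sketch of a biased squared-MMD estimate with a Gaussian kernel (an illustration, not the paper's implementation; the kernel bandwidth `sigma` is an arbitrary choice here):

```python
import numpy as np

def rbf_kernel(a, b, sigma=1.0):
    # Gaussian kernel matrix between the rows of a and the rows of b.
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * sigma**2))

def mmd2(x, y, sigma=1.0):
    # Biased estimate of squared MMD between two samples x and y.
    return (rbf_kernel(x, x, sigma).mean()
            + rbf_kernel(y, y, sigma).mean()
            - 2 * rbf_kernel(x, y, sigma).mean())

rng = np.random.default_rng(0)
same = mmd2(rng.normal(0, 1, (200, 2)), rng.normal(0, 1, (200, 2)))     # near 0
shifted = mmd2(rng.normal(0, 1, (200, 2)), rng.normal(2, 1, (200, 2)))  # clearly > 0
```

In training, the same quantity would be computed between the treated and control rows of Φ(x) and added to the loss.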
Outline: Background · Model · Experiments · Theory
Evaluating counterfactual inference
• The train-test paradigm breaks: no observations from the counterfactual "test" set
• Can't do cross-validation for hyper-parameter selection
1) Simulated data: IHDP (Hill, 2011)
2) Real data: National Supported Work study (LaLonde, 1986; Todd & Smith, 2005)
  • The effect of job training on employment and income
  • Observational study with a randomized controlled trial subset
  • 3,212 samples, 8 features incl. education and previous income
Evaluating models with randomized controlled trials data
• We can't directly evaluate individual treatment effect (ITE) error, because we never see the counterfactual
• Every ITE estimator ÎTE(x) = f(x) implies a policy
  • Policy π_{f,λ}: 𝒳 → {0, 1}: treat all persons x with f(x) > λ, for a threshold λ
• Every policy π has a policy value:
  𝔼[Y_1 | π(x) = 1] · p(π = 1) + 𝔼[Y_0 | π(x) = 0] · p(π = 0)
[Figure: randomized controlled trial assignment vs. policy π, highlighting the units where they agree]
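A hedged sketch of this off-policy evaluation idea (toy data invented for illustration): because the trial's assignment is random, the policy value can be estimated from the units where the trial's assignment agrees with the policy's recommendation.

```python
import numpy as np

def policy_value(x, t, y, pi):
    # Estimate E[Y1 | pi(x)=1] p(pi=1) + E[Y0 | pi(x)=0] p(pi=0) from
    # randomized-trial data: among units the policy would treat, average
    # the outcomes of those the trial actually treated; symmetrically
    # for the units the policy would leave untreated.
    p = pi(x)
    v1 = y[(p == 1) & (t == 1)].mean() * (p == 1).mean()
    v0 = y[(p == 0) & (t == 0)].mean() * (p == 0).mean()
    return v1 + v0

# Made-up four-person trial; the policy treats anyone with x > 0.5.
x = np.array([0.2, 0.8, 0.3, 0.9])
t = np.array([0, 1, 1, 0])
y = np.array([1.0, 2.0, 3.0, 4.0])
val = policy_value(x, t, y, lambda x: (x > 0.5).astype(int))
```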
Evaluating model performance using randomized data (off-policy evaluation)
[Figure: control (t = 0) and treated (t = 1) groups]
Policy value: 𝔼[Y_1 | π(x) = 1] · p(π = 1) + 𝔼[Y_0 | π(x) = 0] · p(π = 0)
• National Supported Work: randomized trial embedded in an observational study
• Policy risk estimated on the randomized subsample
• CFR-2-2: our model, with 2 layers before Φ and 2 layers after Φ
Experimental results – National Supported Work Study (lower is better)

Method                                      Policy risk (std)
ℓ₁-reg. logistic regression                 0.23 ± 0.00
BART (Chipman, George & McCulloch, 2010)    0.24 ± 0.01
Causal forests (Wager & Athey, 2015)        0.17 ± 0.006
CFR-2-2 Vanilla                             0.16 ± 0.02
CFR-2-2 Wasserstein                         0.15 ± 0.02
CFR-2-2 MMD                                 0.13 ± 0.02

[Figure: policy risk curves for Causal forest, CFR-2-2 Vanilla, CFR-2-2 MMD and a random policy; lower is better]
Outline: Background · Model · Experiments · Theory
Theory of causal effect inference
• Standard results in statistics: asymptotic rate of convergence to the true average effect
  • Assumptions: we know the true model (consistency)
• Our result: generalization error bound for individual-level inference
  • Assumptions: the true model lies within a large model family, e.g. bounded Lipschitz functions
Theorem (informal)
• Let Ŷ_t^{Φ,h}(x) = h(Φ(x), t) for t = 0, 1
• ÎTE^{Φ,h}(x) := Ŷ_1^{Φ,h}(x) − Ŷ_0^{Φ,h}(x)
• If "strong ignorability" holds, and if dist is "nice" with respect to the true potential outcomes Y_0 and Y_1 and the representation Φ, then for all normalized Φ and h:

𝔼_x[error(ÎTE^{Φ,h}(x))] ≤ 2 · (𝔼_{x,t}[error(Ŷ_t^{Φ,h}(x))] + dist(p^Φ_treated, p^Φ_control))

• Left-hand side: expected error in estimating ITE
• First term: "supervised learning generalization error"
• Second term: distance between the Φ-induced distributions
• We minimize this upper bound with respect to Φ and h
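The bound suggests the form of the training objective: factual prediction error plus a weighted imbalance penalty on the representation. A toy sketch (the mean-difference `dist` is a simplistic stand-in for the MMD or Wasserstein distances the talk uses, and `alpha` is a hypothetical trade-off weight):

```python
import numpy as np

def mean_diff(a, b):
    # Toy proxy for dist: norm of the difference of group means.
    return np.linalg.norm(a.mean(axis=0) - b.mean(axis=0))

def bound_objective(phi, h, x, t, y, alpha, dist=mean_diff):
    # Upper-bound-style objective: factual loss on observed outcomes,
    # plus alpha times the imbalance of the Phi-induced distributions.
    rep = phi(x)
    factual = np.mean((h(rep, t) - y) ** 2)
    imbalance = dist(rep[t == 1], rep[t == 0])
    return factual + alpha * imbalance

# Tiny worked example with an identity representation and a linear head.
x = np.array([[1.0], [2.0], [3.0], [4.0]])
t = np.array([0, 0, 1, 1])
y = np.array([1.0, 2.0, 4.0, 5.0])
val = bound_objective(lambda x: x, lambda r, t: r[:, 0] + t, x, t, y, alpha=0.5)
```

Here the head fits the factual outcomes exactly (factual loss 0), so the objective equals 0.5 times the mean gap between groups; a real CFR model would instead learn Φ and h jointly by gradient descent.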
Summary
• Estimating individual treatment effect is different from supervised learning
  • Bears strong connections to domain adaptation
• We give new representation learning algorithms for estimating individual treatment effect
  • They use the MMD and Wasserstein distributional distances
• Experiments show our method is competitive with or better than the state of the art
• We give a new error bound for estimating individual treatment effect
• Fredrik D. Johansson, Uri Shalit, David Sontag, "Learning Representations for Counterfactual Inference", ICML 2016
• Uri Shalit, Fredrik D. Johansson, David Sontag, "Estimating individual treatment effect: generalization bounds and algorithms", arXiv:1606.03976
Acknowledgments: Justin Chiu (FAIR), Marco Cuturi (ENSAE/CREST), Jennifer Hill (NYU), Aahlad Manas (NYU), Sanjong Misra (U. Chicago), Esteban Tabak (NYU) and Stefan Wager (Columbia)
Thank you!