Offline model-based reinforcement learning with causal structured world models

Zhengmao ZHU, Honglong TIAN, Xionghui CHEN, Kun ZHANG, Yang YU

Front. Comput. Sci. ›› 2025, Vol. 19 ›› Issue (4) : 194347. DOI: 10.1007/s11704-024-3946-y
Artificial Intelligence
RESEARCH ARTICLE


Abstract

Model-based methods have recently shown promise for offline reinforcement learning (RL), which aims at learning good policies from historical data without interacting with the environment. Previous model-based offline RL methods employ a straightforward prediction method that maps the states and actions directly to the next-step states. However, such a prediction method tends to capture spurious relations caused by the sampling policy preference behind the offline data. It is sensible that the environment model should focus on causal influences, which can facilitate learning an effective policy that generalizes well to unseen states. In this paper, we first provide theoretical results showing that causal environment models can outperform plain environment models in offline RL by incorporating the causal structure into the generalization error bound. We also propose a practical algorithm, oFfline mOdel-based reinforcement learning with CaUsal Structured World Models (FOCUS), to illustrate the feasibility of learning and leveraging causal structure in offline RL. Experimental results on two benchmarks show that FOCUS reconstructs the underlying causal structure accurately and robustly, and, as a result, outperforms both model-based offline RL algorithms and causal model-based offline RL algorithms.


Keywords

reinforcement learning / offline reinforcement learning / model-based reinforcement learning / causal discovery

Cite this article

Zhengmao ZHU, Honglong TIAN, Xionghui CHEN, Kun ZHANG, Yang YU. Offline model-based reinforcement learning with causal structured world models. Front. Comput. Sci., 2025, 19(4): 194347 https://doi.org/10.1007/s11704-024-3946-y

1 Introduction

Offline Reinforcement Learning (RL) is a learning paradigm where policies are learned entirely from previously collected data. Offline RL is gaining popularity since it enables RL algorithms to scale to several real-world applications, e.g., autonomous driving [1] and healthcare [2], where trial-and-error is too expensive. In the offline setting, Model-Based Reinforcement Learning (MBRL) is a popular framework that involves learning a predictive environment model for policy optimization [3]. The effectiveness of this approach relies on the accuracy of the learned environment model.
However, current offline MBRL approaches usually have poor generalization because the environment models tend to capture spurious correlations that only exist in the offline data, resulting in erroneous predictions. For instance, in autonomous driving, if offline data is acquired from a driver who always turns on the wipers and presses the brake pedal on rainy days, such a preference will result in a spurious correlation between “the wiper is turned on” and “the speed drops” in the data, which will also be captured by the environment model. Once we employ this environment model for policy learning, the agent will likely urge the driver to turn on the wiper when the vehicle’s speed is too high, because it believes that “the wiper is turned on” has an effect on “the speed drops”, which is not sensible and is potentially hazardous. Intuitively, leveraging the causal structure of observed variables can avoid taking spurious correlations as causal influences and thus facilitate the learning of an environment model with enhanced generalizability. Recent empirical evidence also indicates that inducing the causal structure is important for improving the generalization ability of deep learning models [4-7]. Despite such evidence, it is still unknown whether and how causal structure improves model generalization in offline RL.
For this purpose, we first provide theoretical support for the aforementioned intuition: we show that a causal environment model can outperform a plain environment model on generalization for offline RL. From the causal perspective, we categorize the variables in states and actions into two groups: causal variables and spurious variables. We then formalize the process of learning an environment model with both types of variables in states. On the basis of the formalization, we quantify the effect of spurious dependencies on the generalization error bound and thereby demonstrate that integrating causal structures can assist in minimizing this bound.
We also propose a practical offline causal MBRL algorithm, FOCUS, to illustrate the feasibility of learning causal structure in offline RL. Learning the causal structure from offline data, also known as causal discovery from observations [8], is a crucial phase of FOCUS. However, causal discovery from observations requires a huge number of hypothesis tests, which is computationally expensive. To tackle this problem, we utilize the time-series property of RL data to reduce the number of hypothesis tests. Specifically, we incorporate the constraint that the future cannot cause the past into the PC algorithm [8], which seeks to uncover causal relationships based on inferred conditional independence relations. Consequently, we can reduce the number of conditional independence tests and determine the causal direction. In addition, we employ kernel-based conditional independence tests [9], which can be applied to continuous variables without assuming a specific functional form between the variables or a particular data distribution.
In conclusion, this paper makes the following key contributions.
● It theoretically demonstrates that a causal environment model outperforms a plain environment model in offline RL with respect to the generalization error bound.
● It proposes a practical algorithm, FOCUS, and illustrates the feasibility of learning and employing a causal environment model for offline MBRL.
● Our experimental results validate the theoretical claims, showing that FOCUS outperforms baseline models and other existing causal MBRL algorithms in the offline setting.

2 Related work

RL algorithms with causal structure learning can be roughly categorized by the type of causal discovery method they employ. We first discuss relevant causal discovery methods, followed by related RL algorithms.

2.1 Causal discovery methods

Causal discovery methods can be categorized into two main types: intervention-based methods and observation-based methods. Intervention-based methods determine the causal relationship between variables by intervening on them and obtaining a counterfactual outcome different from what is currently observed. By contrast, observation-based methods rely solely on existing data to discover causal relationships without performing any intervention. Since interventions are not feasible in the offline reinforcement learning environment discussed in this paper, we will primarily focus on observation-based causal discovery methods.
Observation-based causal discovery methods can generally be divided into two categories: constraint-based methods and score-based methods. Constraint-based methods use statistical tests (conditional independence tests) to find the causal skeleton and determine the causal directions up to the Markov equivalence class. For example, [10] is based on explicit estimation of the conditional densities or their variants, exploiting the difference between the characteristic functions of these conditional densities. Estimating the conditional densities or related quantities is difficult, which deteriorates the testing performance, especially when the conditioning set is large. Score-based methods evaluate the quality of candidate causal models with some score function and output one or multiple graphs having the optimal score [11]. [12] discretizes the conditioning set into a set of bins and transforms conditional independence (CI) into unconditional independence within each bin. Inevitably, due to the curse of dimensionality, as the conditioning set becomes larger, the required sample size increases dramatically.

2.2 Causal discovery in reinforcement learning

Causal discovery methods can be utilized in policy learning and model learning in RL.
In policy learning, [7] proposes a causal imitation learning algorithm, which learns the causal structure between states and actions. However, it assumes that we can query experts for actions and use interventional data for causal discovery, which is impractical in offline RL.
In world model learning for RL, various methods have been proposed, but they can hardly be used in offline RL. For example, [13] (LNCM) views data sampled from different policies as data with soft interventions and uses score-based methods with the log-likelihood on “interventional” data as the score function. Its implicit assumption, that data is sampled from multiple policies and labeled with its sampling policy, does not hold in general for offline RL and is realistic only in the online setting. [14] samples causal structures from a multivariate Bernoulli distribution and scores those structures according to the log-likelihood on interventional data. Based on the scores, it calculates gradients for the parameters of the multivariate Bernoulli distribution and updates the parameters iteratively. [6] utilizes the speed of adaptation to learn the causal direction, but it does not provide a complete algorithm for learning causal structure. Given that in offline RL the sampling policy has unknown preferences and interaction with the environment is not available, the above methods are not practical for learning the causal structure in offline RL.

3 Preliminaries

3.1 Markov decision process (MDP)

We describe the RL environment as an MDP given by the five-tuple $\langle S, A, P, R, \gamma \rangle$ [15], where $S$ is a finite set of states; $A$ is a finite set of actions; $P$ is the transition function, with $P(s' \mid s, a)$ denoting the next-state distribution after taking action $a$ in state $s$; $R$ is a reward function, with $R(s, a)$ denoting the expected immediate reward gained by taking action $a$ in state $s$; and $\gamma \in [0, 1]$ is a discount factor. An agent chooses actions according to a policy $a \sim \pi(s)$, which updates the system state $s' \sim P(s, a)$, yielding a reward $r \sim R(s, a)$. The agent's goal is to maximize the expected cumulative return by learning a good policy, $\max_\pi \mathbb{E}_P\big[\sum_{t=1}^{T} \gamma^t R(s_t, a_t)\big]$. The state-action value $Q^\pi$ of a policy $\pi$ is the expected discounted reward of executing action $a$ from state $s$ and subsequently following policy $\pi$: $Q^\pi(s, a) := R(s, a) + \gamma\,\mathbb{E}_{s' \sim P,\, a' \sim \pi}\big[Q^\pi(s', a')\big]$.
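The recursive definition of $Q^\pi$ can be made concrete with a small tabular example. The sketch below is a minimal illustration only (the MDP, its dimensions, and the uniform policy are made-up assumptions, not anything from the paper); it simply iterates the Bellman backup until the values converge.

```python
import numpy as np

# Hypothetical toy MDP (3 states, 2 actions); nothing here comes from the paper.
n_s, n_a, gamma = 3, 2, 0.9
rng = np.random.default_rng(0)
P = rng.dirichlet(np.ones(n_s), size=(n_s, n_a))   # P[s, a] is a distribution over next states
R = rng.uniform(0, 1, size=(n_s, n_a))             # expected immediate reward R(s, a)
pi = np.full((n_s, n_a), 1.0 / n_a)                # uniform policy pi(a|s)

# Iterate the Bellman backup Q(s, a) = R(s, a) + gamma * E_{s'~P, a'~pi}[Q(s', a')].
Q = np.zeros((n_s, n_a))
for _ in range(1000):
    V = (pi * Q).sum(axis=1)        # V(s') = E_{a'~pi}[Q(s', a')]
    Q_new = R + gamma * (P @ V)     # expectation over s' ~ P(.|s, a)
    if np.abs(Q_new - Q).max() < 1e-8:
        break
    Q = Q_new
print(Q)                            # state-action values of the uniform policy
```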

3.2 Offline model-based reinforcement learning

In the offline RL setting, the algorithm only has access to a static dataset $D = \{(s, a, r, s')\}$ collected by one behavior policy $\pi_B$ or a mixture of behavior policies, and further interaction with the environment is not available. When we use model-based approaches to solve offline RL problems, we learn a virtual environment model $\hat{P}$ for transition prediction from the offline data. With the learned environment model $\hat{P}$, we can define a new MDP $\langle S, A, \hat{P}, R, \gamma \rangle$. Similarly, we can define the value function $\hat{Q}^\pi$ under $\hat{P}$. A standard model-based RL algorithm (in the online setting) learns a virtual model by maximum-likelihood estimation on trajectory data collected by running the latest policy, which guarantees that the virtual model remains accurate as long as the policy keeps exploring [16,17]. In the offline RL setting, where we only have access to data collected by previous policies, the accuracy of the virtual model under an exploring policy cannot be guaranteed. Therefore, recent techniques all build on the idea of pessimism, which regularizes the original problem based on how confident the agent is about the learned model [3,18]. Specifically, the policy only visits states where the learned model is confident in its predictions.
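To make the pessimism idea concrete, the following sketch shows one common instantiation used by MOPO-style methods: an ensemble of learned dynamics models supplies both a prediction and an uncertainty estimate, and the reward used for policy learning is penalized by that uncertainty. The function name, the ensemble shape, and the disagreement-based uncertainty measure are illustrative assumptions rather than the exact formulation of any particular algorithm.

```python
import numpy as np

def penalized_reward(ensemble_next_means, reward, penalty_coef=1.0):
    """Pessimistic reward r_tilde = r - penalty_coef * u(s, a) for offline MBRL.

    ensemble_next_means: (n_models, state_dim) array of next-state means predicted
        by the members of a learned dynamics ensemble for a single (s, a) pair.
    Here the uncertainty u(s, a) is the largest deviation of any member from the
    ensemble mean; MOPO itself penalizes with the maximum predicted standard
    deviation, which plays the same role.
    """
    disagreement = np.linalg.norm(
        ensemble_next_means - ensemble_next_means.mean(axis=0), axis=1
    ).max()
    return reward - penalty_coef * disagreement

# Example: 5 ensemble members predicting a 4-dimensional next state.
preds = np.random.default_rng(0).normal(scale=0.1, size=(5, 4))
print(penalized_reward(preds, reward=1.0, penalty_coef=0.5))
```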

4 Theoretical advantages of causal world models

In this section, we present theoretical evidence demonstrating the superior performance of a causal environment model over a traditional environment model in the context of offline RL. This underscores the importance of leveraging an appropriate causal structure to reduce the generalization error bounds of offline MBRL algorithms. More specifically, we formalize the process by which an environment model absorbs spurious correlations and then quantify their influence. We evaluate the effect of these spurious correlations on generalization ability in terms of the model prediction error bound and the policy evaluation error bound within the offline RL framework.
For the purpose of simplicity, we operate under the assumption that the causal structure is accurately known. As a result, our theoretical findings do not account for errors that could potentially arise from inaccurate structures. We further assume that the causal relationships are linear and that all causal variables are fully observed. The detailed theoretical proofs are provided in Appendix A1.

4.1 Problem formalization

Plain world models. In the context of this paper, we use the term “plain world models” to refer to the traditional forward dynamics model, $P_\theta(s_t \mid s_{t-1}, a_{t-1})$, which ignores the underlying causal structure between the current state-action pair and the future state:
$$P_\theta(s_t \mid s_{t-1}, a_{t-1}) = \prod_{i=1}^{n} P(s_{t,i} \mid s_{t-1}, a_{t-1}),$$
where $s_{t,i}$ denotes the $i$th dimension of the state $s_t$.
Causal world models. By contrast, we use the term “causal world models” to refer to the forward dynamics model equipped with the causal structure:
$$P_{\mathrm{causal}}(s_t \mid s_{t-1}, a_{t-1}) = \prod_{i=1}^{n} P\Big(s_{t,i} \,\Big|\, \bigcup_{j \in \mathrm{Pa}_i} s_{t-1,j},\, a_{t-1}\Big),$$
where $\mathrm{Pa}_i$ denotes the set of dimensions of state $s_{t-1}$ that are causal parents of variable $s_{t,i}$, and $\bigcup_{j \in \mathrm{Pa}_i} s_{t-1,j}$ denotes the corresponding set of causal-parent variables of $s_{t,i}$.
Model learning. Model learning in MBRL aims to predict the future state given the current state and action, which can be viewed as a supervised learning problem. For simplicity, we formulate it as supervised learning, i.e., we use $X$ to represent $(s_{t-1}, a_{t-1})$ and $Y$ to represent one dimension of $s_t$.
Let $D$ denote the data distribution from which we have samples $(X, Y) \sim D$, $X \in \mathbb{R}^n$. The goal of model learning is equivalent to learning a linear function $f$ to predict $Y$ given $X$. From the causal perspective, $Y$ is generated only from its causal parent variables rather than from all the variables in $X$. Therefore we can categorize the variables in $X$ into two groups, $X = (X_{\mathrm{causal}}, X_{\mathrm{spurious}})$.
$X_{\mathrm{causal}}$ represents the causal parent variables of $Y$, that is, $Y = X_{\mathrm{causal}}\beta + \epsilon_{\mathrm{causal}}$, where $\beta$ is the ground truth and $\epsilon_{\mathrm{causal}}$ is a zero-mean noise term independent of $X_{\mathrm{causal}}$, i.e., $X_{\mathrm{causal}} \perp\!\!\!\perp \epsilon_{\mathrm{causal}}$.
$X_{\mathrm{spurious}}$ represents the spurious variables that are independent of $Y$, i.e., $X_{\mathrm{spurious}} \perp\!\!\!\perp Y$, but $X_{\mathrm{spurious}}$ and $X_{\mathrm{causal}}$ are strongly related in the biased offline data. In other words, $X_{\mathrm{spurious}}$ can be predicted from $X_{\mathrm{causal}}$ with small error, i.e., $X_{\mathrm{spurious}} = X_{\mathrm{causal}}\gamma_{\mathrm{spurious}} + \epsilon_{\mathrm{spurious}}$, where $\epsilon_{\mathrm{spurious}}$ is the regression error with zero mean and small variance $\sigma_{\mathrm{spu}}^2$.
For clarity of presentation, we write $X_{\mathrm{cau}} \triangleq X \odot \omega_{\mathrm{cau}}$ ($\odot$ denotes element-wise multiplication) in place of $X_{\mathrm{causal}}$, where $\mathrm{cau}$ records the indices of $X_{\mathrm{causal}}$ in $X$ and $(\omega_{\mathrm{cau}})_i = \mathbb{I}(i \in \mathrm{cau})$. Correspondingly, we write $X_{\mathrm{spu}} \triangleq X \odot \omega_{\mathrm{spu}}$ in place of $X_{\mathrm{spurious}}$.
Example: Suppose that $X = [x_1, x_2, x_3]$ and $(x_1, x_3)$ are the causal parents. Then $\omega_{\mathrm{cau}} = [1, 0, 1]$ and $\omega_{\mathrm{spu}} = [0, 1, 0]$, so $X_{\mathrm{cau}} = [x_1, 0, x_3]$ and $X_{\mathrm{spu}} = [0, x_2, 0]$.
According to the definition of $X_{\mathrm{cau}}$, we have $Y = (X \odot \omega_{\mathrm{cau}})\beta + \epsilon_{\mathrm{cau}}$, where $\omega_{\mathrm{cau}} \odot \beta$ is the global optimal solution of the following optimization problem:
Definition 1 (Optimization problem 1).
$$\min_{\beta}\; \mathbb{E}_{(X,Y)\sim D}\left[X\beta - Y\right]^2.$$
However, in the offline setting, we only have biased data $D_{\mathrm{train}}$ sampled by a given policy $\pi_{\mathrm{train}}$, for which the optimization problem becomes:
Definition 2 (Optimization problem 2).
$$\min_{\beta}\; \mathbb{E}_{(X,Y)\sim D_{\mathrm{train}}}\left[X\beta - Y\right]^2.$$
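The consequence of solving optimization problem 2 on biased data can be reproduced in a few lines. The sketch below is a synthetic illustration (not the paper's benchmark, and with an arbitrarily chosen ridge coefficient rather than the Hoerl-Kennard choice): in the training data $X_{\mathrm{spu}}$ is almost a copy of $X_{\mathrm{cau}}$, so the regression spreads weight onto the spurious dimension; on test data where that coupling is broken, the prediction error grows, while a regression restricted to $X_{\mathrm{cau}}$ is unaffected.

```python
import numpy as np

rng = np.random.default_rng(0)
n, beta, k = 5000, 2.0, 0.1          # k: an illustrative ridge coefficient

# Biased training data: X_spu is almost a deterministic copy of X_cau.
x_cau = rng.normal(size=n)
x_spu = x_cau + 0.01 * rng.normal(size=n)       # gamma_spu = 1, tiny sigma_spu
y = beta * x_cau + 0.5 * rng.normal(size=n)     # Y depends only on X_cau

X = np.column_stack([x_cau, x_spu])
# Ridge solution of the regularized problem on the biased data.
beta_hat = np.linalg.solve(X.T @ X / n + k * np.eye(2), X.T @ y / n)
print("learned weights (causal, spurious):", beta_hat)   # weight leaks onto the spurious dim

# Test data from another policy: X_spu is no longer tied to X_cau.
x_cau_t, x_spu_t = rng.normal(size=n), rng.normal(size=n)
y_t = beta * x_cau_t + 0.5 * rng.normal(size=n)
X_t = np.column_stack([x_cau_t, x_spu_t])
print("plain model  test MSE:", np.mean((X_t @ beta_hat - y_t) ** 2))   # large
print("causal model test MSE:", np.mean((beta * x_cau_t - y_t) ** 2))   # ~ noise level
```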

4.2 Model prediction error bound

In this subsection, we assume a causal structure of the RL environment and spurious relations in offline data. We highlight that these spurious relationships can contribute to the formation of an ill-posed problem within the model learning process, hence increasing the model prediction error bound. Specifically, we quantitatively measure the impact of spurious relations on model learning as shown in Theorem 1. This theorem provides the relations between the properties of spurious relations and the model prediction error bound.

4.2.1 The spurious theorem

In Lemma 1, we prove that optimization problem 2 has multiple optimal solutions due to the strong linear relatedness of $X_{\mathrm{spu}}$ and $X_{\mathrm{cau}}$ in $D_{\mathrm{train}}$.
Lemma 1. Given that $\omega_{\mathrm{cau}} \odot \beta$ is the optimal solution of optimization problem 1, suppose that in $D_{\mathrm{train}}$, $X_{\mathrm{spu}} = (X \odot \omega_{\mathrm{cau}})\gamma_{\mathrm{spu}} + \epsilon_{\mathrm{spu}}$, where $\mathbb{E}_{D_{\mathrm{train}}}[\epsilon_{\mathrm{spu}}] = 0$ and $\gamma_{\mathrm{spu}} \neq 0$. Then $\hat{\beta}_{\mathrm{spu}} \triangleq \omega_{\mathrm{cau}} \odot (\beta - \lambda\gamma_{\mathrm{spu}}) + \lambda\,\omega_{\mathrm{spu}}$ is also an optimal solution of optimization problem 2 for any $\lambda$:
$$\mathbb{E}_{(X,Y)\sim D_{\mathrm{train}}}\!\left[\big|X(\omega_{\mathrm{cau}} \odot \beta) - Y\big|^2 \,\Big|\, X\right] = \mathbb{E}_{(X,Y)\sim D_{\mathrm{train}}}\!\left[\big|X\hat{\beta}_{\mathrm{spu}} - Y\big|^2 \,\Big|\, X\right].$$
The most popular method for handling such an ill-posed problem is to add a regularization term on the parameters $\beta$ [19]:
Definition 3 (Optimization problem 3).
$$\min_{\beta}\; \mathbb{E}_{(X,Y)\sim D_{\mathrm{train}}}\left[X\beta - Y\right]^2 + k\|\beta\|^2,$$
where $k$ is a coefficient.
Optimization problem 3 takes the form of ridge regression, for which the Hoerl-Kennard formula [20] provides a closed-form choice of $k$. Given this closed-form choice of $k$, we can derive the optimal solution of optimization problem 3, which corresponds to a particular value of $\lambda$.
In the following, we first present the solution for $\lambda$ under a given $k$ in Lemma 2, and then present the model prediction error bound in terms of $\lambda$ in Theorem 1. For ease of understanding, we provide a simple version in which the dimensions of $X_{\mathrm{cau}}$ and $X_{\mathrm{spu}}$ are both one ($|X_{\mathrm{cau}}| = |X_{\mathrm{spu}}| = 1$).
Lemma 2 ($\lambda$ Lemma). Let $\sigma_{\mathrm{cau}}^2$ denote the variance of the noise $\epsilon_{\mathrm{cau}}$ and $\sigma_{\mathrm{spu}}^2$ the variance of $\epsilon_{\mathrm{spu}}$. With $k$ in optimization problem 3 chosen by the Hoerl-Kennard formula, $k = \sigma_{\mathrm{spu}}^2/\beta^2$, the solution for $\lambda$ in optimization problem 3 is:
$$\lambda = \frac{\sigma_{\mathrm{cau}}^2\,\beta\,\gamma_{\mathrm{spu}}\,k}{\sigma_{\mathrm{cau}}^2\sigma_{\mathrm{spu}}^2 + \sigma_{\mathrm{cau}}^2\gamma_{\mathrm{spu}}^2 k + \sigma_{\mathrm{cau}}^2 k + \sigma_{\mathrm{spu}}^2 k + k^2} = \frac{\beta\,\gamma_{\mathrm{spu}}}{\beta^2 + \gamma_{\mathrm{spu}}^2 + 1 + \frac{\sigma_{\mathrm{spu}}^2}{\sigma_{\mathrm{cau}}^2}\left(1 + \frac{1}{\beta^2}\right)}.$$
Based on Lemma 2, we can see that the smaller $\sigma_{\mathrm{spu}}^2$ is (i.e., the more strongly $X_{\mathrm{spu}}$ and $X_{\mathrm{cau}}$ are related in the training dataset $D_{\mathrm{train}}$), the larger $\lambda$ is. We also prove that the coefficient $\lambda$ is bounded:
Proposition 1. With $\lambda$ given as in Lemma 2, we have the bound $-\frac{1}{2} \leq \lambda \leq \frac{1}{2}$.
Theorem 1 (Spurious theorem). Let $D = \{(X, Y)\}$ denote the data distribution, $\hat{\beta}_{\mathrm{spu}}$ denote the solution in Lemma 1 with $\lambda$ as in Lemma 2, and $\hat{Y}_{\mathrm{spu}} = X\hat{\beta}_{\mathrm{spu}}$ denote the prediction. Suppose that the data values are bounded, $|X_i|_1 \leq X_{\max}$ for $i = 1, \ldots, n$, and that the error of the optimal solution $\epsilon_{\mathrm{cau}}$ is also bounded, $|\epsilon_{\mathrm{cau}}|_1 \leq \epsilon_c$. Then we have the model prediction error bound:
$$\mathbb{E}_{(X,Y)\sim D}\!\left[\,\big|\hat{Y}_{\mathrm{spu}} - Y\big|_1 \,\Big|\, X\right] \leq X_{\max}\,|\lambda|_1\,\big(|\gamma_{\mathrm{spu}}|_1 + 1\big) + \epsilon_c.$$
Theorem 1 shows that
● the upper bound of the model prediction error $|\hat{Y}_{\mathrm{spu}} - Y|_1$ increases by $X_{\max}|\lambda|_1(|\gamma_{\mathrm{spu}}|_1 + 1)$ for each spurious variable $X_{\mathrm{spu}}$ introduced into the model;
● when $X_{\mathrm{spu}}$ and $X_{\mathrm{cau}}$ are more strongly related (which means a larger $\lambda$), the increase in the model prediction error bound caused by $X_{\mathrm{spu}}$ is larger.
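As a quick numerical check of Lemma 2 and Proposition 1, the sketch below evaluates the closed-form $\lambda$ for several values of $\sigma_{\mathrm{spu}}^2$ (all parameter values are made up) and confirms that $\lambda$ grows as the spurious relatedness strengthens (i.e., as $\sigma_{\mathrm{spu}}^2$ shrinks) while remaining within $[-\tfrac{1}{2}, \tfrac{1}{2}]$.

```python
import numpy as np

def lam(beta, gamma_spu, sigma_cau2, sigma_spu2):
    # Closed form of lambda from Lemma 2 (one causal and one spurious dimension).
    return (beta * gamma_spu) / (
        beta**2 + gamma_spu**2 + 1 + (sigma_spu2 / sigma_cau2) * (1 + 1 / beta**2)
    )

beta, gamma_spu, sigma_cau2 = 2.0, 1.0, 0.25      # made-up parameter values
for sigma_spu2 in [1.0, 0.1, 0.01, 0.001]:        # stronger relatedness as it shrinks
    v = lam(beta, gamma_spu, sigma_cau2, sigma_spu2)
    assert -0.5 <= v <= 0.5                       # Proposition 1
    print(f"sigma_spu^2 = {sigma_spu2:<6}  lambda = {v:.3f}")
```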

4.3 Policy evaluation error bound

Building on the model prediction error bound established earlier, we proceed to delineate the policy evaluation error bound, a metric that offers a more direct evaluation of the world model in MBRL. In this subsection, we apply the spurious theorem (Theorem 1) to offline MBRL. Our investigation explores how the policy evaluation error bound escalates in response to an increase in the number of spurious variables and the strengthening of spurious relationships.
Assuming that the state values and rewards are bounded, $|S_{t,i}|_1 \leq S_{\max}$ and $R_t \leq R_{\max}$, and denoting the maximum of $\lambda$ by $\lambda_{\max}$ and the maximum of $|\gamma_{\mathrm{spu}}|_1$ by $\gamma_{\max}$, we establish the policy evaluation error bound in Theorem 2.
Theorem 2 (RL spurious theorem). Consider an MDP with state dimension $n_s$ and action dimension $n_a$, and a data-collecting policy $\pi_D$. Let $M$ denote the true transition model, and let $M_\theta$ denote the learned model whose $i$th component $M_\theta^i$ predicts the $i$th dimension using the spurious variable set $\mathrm{spu}_i$ and the causal variable set $\mathrm{cau}_i$, i.e., $\hat{S}_{t+1,i} = M_\theta^i\big((S_t, A_t) \odot \omega_{\mathrm{cau}_i \cup \mathrm{spu}_i}\big)$. Let $V_\pi^{M_\theta}$ denote the value of policy $\pi$ under model $M_\theta$, and $V_\pi^{M}$ the corresponding value under $M$. For an arbitrary policy $\pi$ with bounded divergence, i.e., $\max_S D_{\mathrm{KL}}\big(\pi(\cdot \mid S), \pi_D(\cdot \mid S)\big) \leq \epsilon_\pi$, we have the policy evaluation error bound:
$$\big|V_\pi^{M_\theta} - V_\pi^{M}\big| \;\leq\; \frac{2\sqrt{2}\,R_{\max}}{(1-\gamma)^2}\sqrt{\epsilon_\pi} \;+\; \frac{R_{\max}\,\gamma}{2(1-\gamma)^2}\,S_{\max}\Big[\,n_s\,\epsilon_c + (1+\gamma_{\max})\,\lambda_{\max}\,n_s(n_s+n_a)\,R_{\mathrm{spu}}\Big],$$
where
$$R_{\mathrm{spu}} = \frac{\sum_{i=1}^{n_s}|\mathrm{spu}_i|}{n_s(n_s+n_a)},$$
which represents the spurious variable density, that is, the fraction of spurious variables among all input variables.
Theorem 2 shows the relation between the policy evaluation error bound and the spurious variable density, which indicates that:
● When we use a non-causal model in which all the spurious variables are used as inputs, $R_{\mathrm{spu}}$ reaches its maximum value $\bar{R}_{\mathrm{spu}} < 1$. By contrast, under the optimal causal structure, $R_{\mathrm{spu}}$ reaches its minimum value of 0.
● The density of spurious variables $R_{\mathrm{spu}}$ and the correlation strength of spurious variables $\lambda_{\max}$ both influence the policy evaluation error bound. However, if we exclude all the spurious variables, i.e., $R_{\mathrm{spu}} = 0$, the correlation strength of the spurious variables has no effect.
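The spurious variable density in Theorem 2 is straightforward to compute from a structure mask. The sketch below, with an invented 3-dimensional state and 1-dimensional action, counts for each predicted dimension how many inputs the model uses that are not causal parents, and reports $R_{\mathrm{spu}}$.

```python
import numpy as np

# Structure masks: entry [i, j] = 1 means input variable j is used when
# predicting state dimension i at the next step.  Values are hypothetical.
n_s, n_a = 3, 1
true_parents = np.array([[1, 0, 0, 1],
                         [0, 1, 0, 0],
                         [1, 1, 1, 0]])           # ground-truth causal parents
model_inputs = np.ones((n_s, n_s + n_a), int)     # a plain model uses every input

spurious_per_dim = ((model_inputs == 1) & (true_parents == 0)).sum(axis=1)
R_spu = spurious_per_dim.sum() / (n_s * (n_s + n_a))
print("spurious variables per dimension:", spurious_per_dim)
print("R_spu =", R_spu)                           # 0 for the optimal causal mask
```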

5 FOCUS algorithm

Having established the necessity of a causal environment model in offline RL, we proceed in this section to propose a practical offline MBRL algorithm, FOCUS, to illustrate the feasibility of learning and using causal structure in offline RL. Our algorithm consists of two distinct steps: the discovery of the causal structure from offline data, followed by the effective integration of the identified structure with an offline MBRL algorithm. During the first step, the offline setting imposes limitations on the selection of causal discovery methods. Specifically, methods that involve interventions or randomized experiments are deemed unsuitable as interactions with the environment are prohibited. When considering approaches for observational data, score-based methods necessitate data from distinct sampling policies and a known form of causal mechanisms, neither of which are present in offline data collected from a single policy. Constraint-based methods, on the other hand, do not presuppose any specific form of causal mechanisms, but they fail to distinguish structures within one Markov equivalence class and may prove inefficient due to numerous independence tests. To address these challenges, FOCUS extends the PC algorithm (a derivative of constraint-based methods) and rectifies its shortcomings by incorporating sequential information. In the second step, FOCUS initializes the environment model with the learned causal structure and then proceeds to learn the environment model along with the policy.

5.1 Preliminary

5.1.1 Assumptions

First, we assume the causal Markov condition and faithfulness for the environment transition, under which we can use conditional independence tests to infer the causal graph [8]. Second, we claim that the offline data and the data obtained through the learned policy share the same causal structure, so that the learned structure can be applied in unseen states. Specifically, policy preference affects the relations between variables by causing quantitative changes in the causal relations and by introducing spurious relations between otherwise independent variables. Because our method distinguishes spurious correlations from causal influences, policy preference does not result in qualitative changes to the causal structure. Consequently, the above claim holds in offline RL.

5.1.2 Conditional independence test

Independence and conditional independence (CI) play a central role in causal discovery [8,21,22]. Generally speaking, the CI relationship $X \perp\!\!\!\perp Y \mid Z$ allows us to drop $Y$ when constructing a probabilistic model for $X$ from $(Y, Z)$. There are multiple CI testing methods for various conditions, each of which provides the correct conclusion only under the corresponding condition. The kernel-based conditional independence test (KCI test) [9] is designed for continuous variables and does not assume a specific functional form between the variables or a particular data distribution.

5.1.3 Conditional variables

Besides the specific CI test method, the conclusion of conditional independence testing also depends on the conditional variable Z, that is, different conditional variables can lead to different conclusions. Taking the triple (X,Y,Z) as an example, there are three typical structures, namely, Chain, Fork, and Collider as shown in Fig.1. Chain: There exists causation between X and Y but conditioning on Z leads to independence. Fork: There does not exist causation between X and Y but not conditioning on Z leads to non-independence. Collider: There does not exist causation between X and Y but conditioning on Z leads to non-independence.
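The role of the conditioning set can be reproduced with simulated data. The sketch below is purely illustrative: it generates a Collider $X \to Z \leftarrow Y$ with independent $X$ and $Y$ and shows that, while the raw correlation between $X$ and $Y$ is near zero, the correlation of their residuals after regressing out $Z$ (a simple linear proxy for conditioning on $Z$) is clearly nonzero.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20000
x = rng.normal(size=n)
y = rng.normal(size=n)                  # independent of x
z = x + y + 0.1 * rng.normal(size=n)    # collider: common child of x and y

print("corr(X, Y)           :", np.corrcoef(x, y)[0, 1])     # ~ 0
# A linear proxy for conditioning on Z: correlate the residuals after regressing on Z.
rx = x - np.polyval(np.polyfit(z, x, 1), z)
ry = y - np.polyval(np.polyfit(z, y, 1), z)
print("corr(X, Y | Z) proxy :", np.corrcoef(rx, ry)[0, 1])   # strongly negative
```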

5.2 Causal structure learning

5.2.1 Applying the Independence test in RL

Based on the preliminaries, given the two target variables $X, Y$ and the condition variable $Z$, the KCI test returns a probability value $p = f_{\mathrm{KCI}}(X, Y, Z) \in [0, 1]$, which measures the probability that $X$ and $Y$ are conditionally independent given the condition $Z$. To transform this probability into a binary conclusion about whether causation exists, we choose a threshold $p^*$ such that
$$Causation(X, Y) = \mathbb{I}\big(f_{\mathrm{KCI}}(X, Y, Z) \leq p^*\big) \in \{0, 1\},$$
where $Causation(X, Y) = 1$ indicates that causation exists and $0$ indicates independence. Details on choosing $p^*$ can be found in Appendix A2.
In model learning for RL, the variables are the states and actions of the current and next time steps, and the causal structure refers to whether a variable at time step $t$ (e.g., the $i$th dimension, $X_t^i$) causes another variable at time step $t+1$ (e.g., the $j$th dimension, $X_{t+1}^j$). With the KCI test, we obtain the causal relation through the function $Causation(\cdot, \cdot)$ for each variable pair $(X_t^i, X_{t+1}^j)$ and then form the causal structure matrix $G$:
$$G_{i,j} = Causation(X_t^i, X_{t+1}^j),$$
where $G_{i,j}$ is the element in row $i$ and column $j$ of $G$.

5.2.2 Choosing the conditional variable in RL

As stated in the preliminaries, unsuitable conditional variables can reverse the conclusion of an independence test. The conditional variable set must include the intermediate variable of a Chain and the common parent variable of a Fork, but must not include the common child variable of a Collider. Traditionally, the independence test traverses all possible combinations of conditional variables before reaching a conclusion, which is inefficient. In RL, however, we can reduce the number of conditional independence tests by imposing the restriction that the future cannot cause the past. In fact, this constraint restricts the possible conditional variable sets to a small number, so we can discuss every feasible set of conditional variables case by case. For simplicity, we exclude two types of scenarios from the discussion:
Fig.1 The three basic structures for (X,Y,Z)


Impossible situations. We exclude impossible situations such as those in Fig.2(i) (bottom left) by the temporal property of RL data. Specifically, the direction of causation cannot be from a variable at time step t+1 to one at time step t, because an effect cannot happen before its cause.
Compound situations. We only discuss the basic situations and exclude the compound situations, e.g., Fig.2(j) (bottom right), which is a compound of Fig.2(a) and Fig.2(c). In such compound situations, the target variables $X_t^i$ and $X_{t+1}^j$ have direct causation (otherwise it would not be a compound situation). When a causal relationship exists, an independence test cannot arrive at a conclusion of “independence”. Hence, we need not worry about wrongly concluding that no causal relationship exists and thereby mistakenly excluding causal parent variables.
As shown in Fig.2, we list all conceivable situations involving the target variables $X_t^i$, $X_{t+1}^j$ and a condition variable $X_t^k$ or $X_{t+1}^k$ in the environment model. With the preliminary knowledge of causal discovery, we analyze the following basic situations:
Top line: In Fig.2(a) and 2(b), whether $X_t^k$ is included in the conditional variable set does not affect the conclusion about causation. In Fig.2(c), although $X_t^k$ is an intermediate variable in a Chain and conditioning on $X_t^k$ leads to the conclusion that $X_t^i$ and $X_{t+1}^j$ are independent, the causal parent set of $X_{t+1}^j$ will include $X_t^k$ when the causal relationship between $X_t^k$ and $X_{t+1}^j$ is tested, which offsets the effect of excluding $X_t^i$. In Fig.2(d), conditioning on $X_t^k$ is necessary for reaching the correct conclusion about causation, since $X_t^k$ is the common causal parent in a Fork structure.
Bottom line: In Fig.2(e) and 2(f), whether $X_{t+1}^k$ is included in the conditional variable set does not affect the conclusion about causation. In Fig.2(g), not conditioning on $X_{t+1}^k$ is necessary to obtain the correct conclusion about causation, since $X_{t+1}^k$ is the common child in a Collider structure. In Fig.2(h), although $X_{t+1}^k$ is an intermediate variable in a Chain and not conditioning on $X_{t+1}^k$ leads to the conclusion that $X_t^i$ and $X_{t+1}^j$ are dependent, including $X_t^i$ in the causal parent set of $X_{t+1}^j$ causes no problem, since $X_t^i$ does indirectly cause $X_{t+1}^j$. Based on the case-by-case discussion above, we conclude our principle for choosing conditional variables in RL:
Fig.2 The basic, impossible and compound situations of the causation between target variables and condition variables. In the basic situations, Top line: (a)−(d) list the situations where the condition variable $X^k$ is at time step t. Bottom line: Similarly, (e)−(h) list the situations where the condition variable $X^k$ is at time step t+1. Bottom left: (i) lists impossible situations where the direction of causation is from the future to the past. Bottom right: (j) lists compound situations where the causal structure is a compound of basic situations


Fig.3 The architecture of FOCUS. Given offline data, FOCUS learns a p value matrix by KCI test and then gets the causal structure by choosing a p threshold. After combining the learned causal structure with the neural network, FOCUS learns the policy through an offline MBRL algorithm


● Condition on the other variables at time step t.
● Do not condition on the variables at time step t+1.
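The sketch below puts the two rules together in a self-contained form. Since the KCI test is kernel-based and requires more machinery, a partial-correlation (Fisher-z) test is used here as a simple linear stand-in, which is an assumption of this sketch rather than what FOCUS uses; the structure of the loop (condition on the remaining time-t variables, never on time-(t+1) variables, then threshold the p-values) is the part being illustrated.

```python
import numpy as np
from scipy import stats

def ci_pvalue(data, i, j, cond):
    """Partial-correlation (Fisher-z) CI test: a linear stand-in for the KCI test.
    data: (N, d) array; i, j: columns of the two target variables; cond: columns
    of the conditioning variables."""
    sub = data[:, [i, j] + list(cond)]
    prec = np.linalg.pinv(np.corrcoef(sub, rowvar=False))
    r = np.clip(-prec[0, 1] / np.sqrt(prec[0, 0] * prec[1, 1]), -0.999999, 0.999999)
    z = 0.5 * np.log((1 + r) / (1 - r)) * np.sqrt(len(sub) - len(cond) - 3)
    return 2 * (1 - stats.norm.cdf(abs(z)))

def learn_structure(inputs_t, states_next, p_threshold=0.05):
    """G[i, j] = 1 if input variable i at step t is judged to cause state dim j at t+1.
    Each test conditions on the remaining step-t variables and never on step-(t+1)
    variables, following the principle above."""
    n_in, n_out = inputs_t.shape[1], states_next.shape[1]
    data = np.hstack([inputs_t, states_next])
    G = np.zeros((n_in, n_out), int)
    for i in range(n_in):
        cond = [k for k in range(n_in) if k != i]     # other time-t variables only
        for j in range(n_out):
            p = ci_pvalue(data, i, n_in + j, cond)
            G[i, j] = int(p < p_threshold)            # small p-value => dependence => edge
    return G

# Tiny synthetic check with hypothetical dynamics: s'_0 <- (s_0, a), s'_1 <- s_1.
rng = np.random.default_rng(0)
N = 5000
s, a = rng.normal(size=(N, 2)), rng.normal(size=(N, 1))
s_next = np.column_stack([s[:, 0] + a[:, 0] + 0.1 * rng.normal(size=N),
                          0.5 * s[:, 1] + 0.1 * rng.normal(size=N)])
print(learn_structure(np.hstack([s, a]), s_next))
```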

5.3 Combining learned causal structure with an offline MBRL algorithm

We combine the learned causal structure with an offline MBRL algorithm, MOPO [3], to create a causal offline MBRL algorithm as in Fig.3. The entire learning procedure can be found in Algorithms 1 and 2. Notice that our causal model learning method can be combined with any offline MBRL algorithm theoretically.
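One simple way to inject a learned structure G into a neural environment model is to mask the inputs of each per-dimension prediction head. The PyTorch sketch below is a minimal illustration of this idea; the class name, network sizes, and the hypothetical structure G are assumptions rather than FOCUS's exact architecture, and any offline MBRL algorithm that learns a forward model (such as MOPO) could use such a masked model in place of its plain one.

```python
import torch
import torch.nn as nn

class CausalDynamicsModel(nn.Module):
    """Per-dimension forward model whose inputs are masked by a causal structure.

    mask: (input_dim, state_dim) binary tensor; mask[i, j] = 1 means input variable i
    (a state or action dimension at step t) is a causal parent of state dim j at t+1.
    """
    def __init__(self, mask, hidden=64):
        super().__init__()
        self.register_buffer("mask", mask.float())
        input_dim, state_dim = mask.shape
        self.heads = nn.ModuleList([
            nn.Sequential(nn.Linear(input_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1))
            for _ in range(state_dim)
        ])

    def forward(self, s, a):
        x = torch.cat([s, a], dim=-1)                          # (batch, input_dim)
        preds = [head(x * self.mask[:, j]) for j, head in enumerate(self.heads)]
        return torch.cat(preds, dim=-1)                        # predicted next state

# Hypothetical 2-d state, 1-d action, and a learned structure G.
G = torch.tensor([[1, 0], [0, 1], [1, 0]])
model = CausalDynamicsModel(G)
print(model(torch.randn(8, 2), torch.randn(8, 1)).shape)       # torch.Size([8, 2])
```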


6 Experiments

In this study, we assess the performance of FOCUS on a variety of offline RL datasets to demonstrate its effectiveness. Specifically, we aim to show that (1) FOCUS enables the learning of a causal environment model in offline RL, and (2) a causal environment model surpasses a standard environment model and other related offline RL methods in policy performance.
To validate the first point, we examine the accuracy, efficiency, and robustness of causal structure learning. Regarding the second point, we evaluate policy return and generalization ability within the context of offline MBRL.

6.1 Experimental settings

6.1.1 Baselines

We compare FOCUS with state-of-the-art offline MBRL algorithms, MOPO, MOReL, and COMBO, as well as with an online RL algorithm, LNCM, that also learns causal structure.
MOPO [3] is a popular and well-known offline MBRL algorithm that outperforms standard model-based RL algorithms and prior state-of-the-art model-free offline RL algorithms on existing offline RL benchmarks. The central idea of MOPO is to artificially penalize rewards by the uncertainty of model predictions, hence avoiding erroneous predictions in unseen states. MOPO serves as the control group with a plain environment model.
MOReL [18] is an offline MBRL algorithm that learns a pessimistic MDP (P-MDP) from the offline dataset and then learns a near-optimal policy in this P-MDP.
COMBO [23] is also an offline algorithm, which conservatively estimates the value function by penalizing it in out-of-support states generated through model rollouts.
LNCM (Learning Neural Causal Models from Unknown Interventions) [13] is an online algorithm whose causal structure learning method can be adapted to the offline setting with a simple adjustment. We take LNCM as an example to show that an online method cannot be directly converted into an offline RL algorithm.

6.1.2 Environment

Toy Car Driving. Toy Car Driving is a typical RL environment in which the agent controls its direction and velocity to finish various tasks, including avoiding obstacles and navigating. In this paper, we use a 2D Toy Car Driving environment where the task of the car is to arrive at the destination (the visualization can be found in the Appendix). The state includes the direction $d$, the velocity $v$, the velocity on the x-axis $v_x$, the velocity on the y-axis $v_y$, and the position $(p_x, p_y)$. The action is the steering angle $a$. We design the underlying causal structure to better demonstrate how spurious relations appear and to highlight their influence on model learning (the structure can be found in the Appendix).
Example: As shown in Fig.4, when the velocity $v_{t-1}$ stays constant due to an imperfect sampling policy, $(v_x)_t$ and $(v_y)_t$ are strongly related through $(v_x)_t^2 + (v_y)_t^2 = v_{t-1}^2$, so one can represent the other. Since we design the dynamics such that $(p_y)_{t+1} - (p_y)_t = (v_y)_t$, $(v_x)_t$ and $(p_y)_{t+1} - (p_y)_t$ are also strongly related, so $(v_x)_t$ becomes a spurious variable for $(p_y)_{t+1}$ given $(p_y)_t$, even though $(v_x)_t$ is not a causal parent of $(p_y)_{t+1}$. By contrast, when the data is uniformly sampled with various velocities, this spuriousness does not exist.
Fig.4 The visualization of the example. The red dotted arrow indicates that $(v_x)_t$ is a spurious variable for $(p_y)_{t+1}$
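The spurious relation in this example can be reproduced directly. The sketch below is an illustrative simulation rather than the benchmark's actual dynamics: with the speed held fixed, $(v_x)_t$ and $(v_y)_t$ are tied together and $(v_x)_t$ becomes strongly correlated with $(p_y)_{t+1} - (p_y)_t$, even though only $(v_y)_t$ causes it; with uniformly sampled speeds the correlation essentially disappears.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10000

def corr_vx_dpy(speed):
    theta = rng.uniform(0, np.pi / 2, size=n)          # heading angles
    vx, vy = speed * np.cos(theta), speed * np.sin(theta)
    dpy = vy + 0.01 * rng.normal(size=n)               # (p_y)_{t+1} - (p_y)_t = (v_y)_t
    return np.corrcoef(vx, dpy)[0, 1]

print("fixed speed :", corr_vx_dpy(np.full(n, 1.0)))           # large in magnitude (spurious)
print("varied speed:", corr_vx_dpy(rng.uniform(0.1, 2.0, n)))  # close to zero
```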


MuJoCo. MuJoCo [24] is the most popular benchmark for evaluating performance in continuous control, where the state variables represent the positions, angles, and velocities of the agent. Each dimension of the state in MuJoCo has a specific meaning and is highly abstract, which facilitates causal structure learning.

6.1.3 Offline data.

We prepare three offline data sets, Random, Medium, and Medium-Replay, for Car Driving and MuJoCo. The MuJoCo datasets are from D4RL [25]. To provide an intuitive illustration of the data sets, we use heat maps to visualize the offline data sets of Car Driving in Fig.5:
Fig.5 The heat map of the three offline data sets in the Car Driving. The coordinates represent two dimensions of the state. The brightness at each data point represents its probability of being visited and high brightness represents high probability. (a) Random; (b) Medium; (c) Medium-Replay


● In Random, data is collected by random sampling policies, and its distribution is similar to a multivariate Gaussian distribution.
● In Medium, data is collected by a fixed but not well-trained policy and its distribution is the least diverse.
● In Medium-Replay, data is a collection of trajectories sampled during training, which is the most diverse.

6.2 The validation for spurious theorems

Our theoretical analysis reveals that
● The model prediction error bound increases with increasing strength of spurious relations and increasing number of induced spurious variables.
● The policy evaluation error bound increases with increasing strength of spurious relations and increasing spurious variable density.
To validate our theoretical findings, we conduct experiments in multiple RL environments whose latent causal structures are diverse, namely, the number of spurious variables, the spurious relatedness, and the spurious variable density vary. The RL environment is built based on the Inverted Pendulum task of MuJoCo. In this environment, the spurious relatedness is changed by modifying the hyper-parameters, e.g., the friction coefficient. The spurious variable density is changed by introducing external variables. The number of spurious variables is changed by introducing external variables while keeping the spurious variable density fixed. We estimate the model and policy prediction error bounds by taking the maximum of the corresponding errors over five repeated experiments.
The results in Fig.6 show that the model prediction error bound of the plain world model increases approximately linearly with the relatedness λ, the number and density of spurious variables. By contrast, the model prediction error bound of the causal world model does not noticeably change with the involved variables above, and it remains at a relatively low level, which supports our spurious theorem (Theorem 1). The results in Fig.7 show similar conclusions for the policy evaluation error bound, which validates our RL spurious theorem (Theorem 2).
Fig.6 Comparison of Causal World Models and Plain World Models to validate the spurious theorem. We evaluate the model prediction error on the relatedness, number and density of spurious variables in the offline dataset. (a) Comparison on relatedness λ; (b) comparison on spurious number; (c) comparison on spurious density


Fig.7 Comparison of Causal World Models and Plain World Models to validate the RL spurious theorem. We evaluate the policy evaluation error on the relatedness, number and density of spurious variables in the offline dataset. (a) Comparison on relatedness λ; (b) comparison on spurious number; (c) comparison on spurious density


6.3 Causal structure learning

We compare FOCUS with the baselines on causal structure learning in terms of accuracy, efficiency, and robustness. Accuracy is evaluated by viewing structure learning as a classification problem, where causation is the positive class and independence the negative class. Efficiency is evaluated by the number of samples required to obtain a stable structure. Robustness is evaluated by the variance over multiple experiments. The results in Tab.1 show that FOCUS surpasses LNCM in accuracy, robustness, and efficiency in causal structure learning. Note that LNCM also has a low variance, but only because it predicts the probability of causation between any variable pair to be around 50%, which makes its robustness meaningless.
Tab.1 The results on causal structure learning of our model and the baselines. Both the accuracy and the variance are calculated by five times experiments. FOCUS (-KCI) represents FOCUS with a linear independence test. FOCUS (-CONDITION) represents FOCUS with choosing all other variables as conditional variables
INDEX FOCUS LNCM FOCUS(-KCI) FOCUS(-CONDITION)
ACCURACY 0.993 0.52 0.62 0.65
ROBUSTNESS 0.001 0.025 0.173 0.212
EFFICIENCY(SAMPLES) 1×10^6 1×10^7 1×10^6 1×10^6
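The structure-learning metrics in Tab.1 can be computed by treating each potential edge as one binary decision. The sketch below, with made-up matrices, computes accuracy against a ground-truth structure and robustness as the variance of accuracy across repeated runs.

```python
import numpy as np

def structure_accuracy(G_pred, G_true):
    # Each (i, j) entry is one binary decision: edge (positive) vs. no edge (negative).
    return (G_pred == G_true).mean()

G_true = np.array([[1, 0], [0, 1], [1, 0]])           # hypothetical ground truth
runs = [np.array([[1, 0], [0, 1], [1, 0]]),           # structures from repeated runs
        np.array([[1, 0], [0, 1], [1, 1]])]
accs = [structure_accuracy(G, G_true) for G in runs]
print("accuracy  :", np.mean(accs))
print("robustness:", np.var(accs))                    # variance across runs
```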

6.4 Policy learning

6.4.1 Policy return

We evaluate the performance of FOCUS and baselines in the two benchmarks on three typical offline data sets. The results in Tab.2 and Tab.3 show that FOCUS outperforms baselines by a significant margin in most data sets.
Tab.2 The comparison on converged policy return in the two benchmarks
ENV CAR DRIVING MUJOCO(INVERTED PENDULUM)
RANDOM MEDIUM REPLAY RANDOM MEDIUM REPLAY
FOCUS 68.1±20.9 58.9±41.3 86.2±18.2 23.5±17.9 24.9±14.1 49.2±19.0
MOPO 30.3±49.9 50.1±34.2 46.2±28.1 8.5±6.2 2.5±0.08 43.4±7.7
LNCM 9.9±42.5 5.4±32.5 11.4±24.0 13.3±0.9 3.1±0.7 16.3±6.4
Tab.3 Results for D4RL datasets. For “ORIGIN” version, we take the results from their original papers. For “FOCUS” version, each number is the averaged score at the last iteration of training, averaged over 3 random seeds. We bold the highest score across all methods
DATASET TYPE    ENVIRONMENT    MOPO (ORIGIN / FOCUS)    MOReL (ORIGIN / FOCUS)    COMBO (ORIGIN / FOCUS)
RANDOM          HALFCHEETAH    35 / 37                  25 / 34                   38 / 44
RANDOM          HOPPER         11 / 30                  53 / 63                   17 / 23
RANDOM          WALKER2D       13 / 28                  37 / 54                   7 / 11
MEDIUM          HALFCHEETAH    42 / 49                  42 / 61                   54 / 60
MEDIUM          HOPPER         28 / 30                  95 / 102                  103 / 107
MEDIUM          WALKER2D       17 / 27                  77 / 85                   81 / 94
MEDIUM-REPLAY   HALFCHEETAH    53 / 58                  40 / 47                   55 / 60
MEDIUM-REPLAY   HOPPER         67 / 71                  93 / 90                   89 / 85
MEDIUM-REPLAY   WALKER2D       39 / 44                  49 / 54                   56 / 63
As shown in Tab.2, on the Random datasets, FOCUS has the most significant performance gains over the baselines in both benchmarks, owing to the accuracy of causal structure learning in FOCUS. By contrast, on Medium-Replay, the performance gains of FOCUS are the smallest, since the high data diversity of Medium-Replay leads to weak relatedness of spurious variables (corresponding to a small λ), which verifies our theory. On Medium, the results in the two benchmarks differ. In Car Driving, the relatively high score of LNCM does not mean that LNCM is the best; rather, all three methods fail. This failure indicates that extremely biased data makes even the causal model fail to generalize. In the Inverted Pendulum, the advantage of FOCUS indicates that causal environment models depend less on data diversity, since FOCUS can still reach high scores on such a biased dataset where the baselines fail.
In Tab.3, we showcase results from more complex experiments, including HalfCheetah, Hopper, and Walker2D, alongside additional baseline methods (MOReL and COMBO). To ensure fair comparisons, we integrate the core module of FOCUS, namely learning and applying causal environment models, into these MBRL algorithms, thereby creating various FOCUS-version algorithms. We then compare these enhanced versions (denoted as “FOCUS”) against their original counterparts (denoted as “ORIGIN”), which do not integrate causal structures. On Random and Medium datasets, FOCUS-version algorithms significantly outperform the control groups; on Medium-Replay, FOCUS still achieves the best performance in 7 out of 9 settings. The results also confirm that FOCUS can be combined with most MBRL algorithms and effectively enhance performance.

6.4.2 Generalization ability

The generalization ability of FOCUS refers to whether it can learn a good policy from data with limited size and low diversity. We therefore designed datasets ranging from 1% to 100% of the original data size, and datasets mixing 20% to 80% of other datasets, so that FOCUS and the baselines can be compared on datasets of different sizes and diversities. The results in Fig.8 show that the advantage of FOCUS over MOPO is much more significant for small data sizes. On the dataset of 1% size, the advantage of FOCUS is relatively insignificant because the size is too small. The results in Fig.9 show that FOCUS performs well even with a small ratio of Medium-Replay data, while the baseline performs well only with a large ratio, which indicates that FOCUS is less dependent on data diversity. Related experiments on more environments can be found in Appendix A2.
Fig.8 The comparison for data size. The X% in the x-axis represents that the data size is X% of the original size. The ratio Y% in the y-axis represents the score ratio of FOCUS over the baseline MOPO


Fig.9 The comparison for data diversity. The dataset is produced by mixing Medium-Replay and Medium with different ratios. The X% on the x-axis represents that the data is mixed from (100−X)% of Medium and X% of Medium-Replay


6.5 Ablation study

To evaluate the contribution of each component of our causal discovery method, we perform an ablation study of FOCUS. The results in Tab.1 show that the KCI test and our principle for choosing conditional variables contribute to both the accuracy and the robustness of causal structure learning.

7 Conclusion

In this paper, we point out that the spurious correlations hinder the generalization ability of current offline MBRL algorithms, and that incorporating the causal structure into the model can improve generalization by removing spurious correlations. We provide theoretical support for the statement that utilizing a causal environment model reduces the generalization error bound in offline RL. We also propose a practical algorithm, FOCUS, to address the problem of learning causal structure in offline RL. The main idea of FOCUS is to leverage conditional independence tests for causal discovery, which does not need further assumptions on the causal mechanism. In FOCUS, we address the difficulties of extending the PC algorithm in offline RL, particularly to reduce the number of independence tests by leveraging sequential information. Extensive experiments on the typical benchmarks demonstrate that FOCUS performs accurate and robust causal structure learning, surpassing offline RL baselines by a significant margin.
We would like to note that in our theoretical results (Theorems 1 and 2), we assume that the true causal structure is already known. However, in practice, we must learn it from data before applying it (Section 5), which introduces additional theoretical error. Since quantifying the uncertainty of a causal structure learned from data is recognized to be a difficult task, we will derive the generalization error bound under a learned causal structure as part of our future work.

Zhengmao Zhu received the BSc degree from the Department of Mathematical Sciences, Zhejiang University, China in June 2018. He is currently pursuing the PhD degree with the School of Artificial Intelligence, Nanjing University, China. His current research interests mainly include reinforcement learning and causal learning. His works have been accepted at top conferences in artificial intelligence, including NeurIPS, AAAI, etc. He has served as a reviewer for NeurIPS, ICML, AAAI, etc

Honglong Tian received the BSc degree from the Software Institute of Nanjing University, China in June 2022. He is currently a graduate student at the Software Institute of Nanjing University, China. His current research interests mainly include reinforcement learning. His works have been accepted at top conferences in artificial intelligence, including NeurIPS, AAAI, etc

Xionghui Chen received the BSc degree from Southeast University, China in 2018. Currently, he is working towards the PhD degree in the National Key Lab for Novel Software Technology, the School of Artificial Intelligence, Nanjing University, China. His research focuses on handling the challenges of reinforcement learning in real-world applications. His works have been accepted at top conferences in artificial intelligence, including NeurIPS, AAMAS, DAI, KDD, etc. He has served as a reviewer for NeurIPS, IJCAI, KDD, DAI, etc

Kun Zhang received the BS degree in automation from the University of Science and Technology of China, China in 2001, and the PhD degree in computer science from The Chinese University of Hong Kong, China in 2005. He is currently an associate professor with the Philosophy Department and an Affiliate Faculty Member with the Machine Learning Department, Carnegie Mellon University, Pittsburgh, USA. His research interests lie in causality, machine learning, and artificial intelligence, especially in causal discovery, hidden causal representation learning, transfer learning, and general-purpose artificial intelligence

Yang Yu received the PhD degree in computer science from Nanjing University, China in 2011, and is currently a professor at the School of Artificial Intelligence, Nanjing University, China. His research interests include machine learning, mainly reinforcement learning and derivative-free optimization for learning. Prof. Yu was granted the CCF-IEEE CS Young Scientist Award in 2020, recognized as one of AI's 10 to Watch by IEEE Intelligent Systems, and received the PAKDD Early Career Award in 2018. His team won the Champion of the 2018 OpenAI Retro Contest on transfer reinforcement learning and the 2021 ICAPS Learning to Run a Power Network Challenge with Trust. He has served as an AC for NeurIPS, ICML, IJCAI, AAAI, etc

References

[1]
Yu F, Xian W, Chen Y, Liu F, Liao M, Madhavan V, Darrell T. BDD100K: a diverse driving video database with scalable annotation tooling. 2018, arXiv preprint arXiv: 1805.04687
[2]
Gottesman O, Johansson F, Komorowski M, Faisal A, Sontag D, Doshi-Velez F, Celi L A. Guidelines for reinforcement learning in healthcare. Nature Medicine, 2019, 25(1): 16–18
[3]
Yu T, Thomas G, Yu L, Ermon S, Zou J, Levine S, Finn C, Ma T. MOPO: model-based offline policy optimization. In: Proceedings of the 34th International Conference on Neural Information Processing Systems. 2020, 1185
[4]
Bengio Y, Deleu T, Rahaman N, Ke N R, Lachapelle S, Bilaniuk O, Goyal A, Pal C J. A meta-transfer objective for learning to disentangle causal mechanisms. In: Proceedings of the 8th International Conference on Learning Representations. 2020
[5]
de Haan P, Jayaraman D, Levine S. Causal confusion in imitation learning. In: Proceedings of the 33rd International Conference on Neural Information Processing Systems. 2019, 1049
[6]
Tenenbaum J. Building machines that learn and think like people. In: Proceedings of the 17th International Conference on Autonomous Agents and MultiAgent Systems. 2018, 5
[7]
Edmonds M, Kubricht J, Summers C, Zhu Y, Rothrock B, Zhu S C, Lu H. Human causal transfer: challenges for deep reinforcement learning. In: Proceedings of the 40th Annual Meeting of the Cognitive Science Society. 2018
[8]
Spirtes P, Glymour C N, Scheines R. Causation, Prediction, and Search. 2nd ed. Cambridge: MIT Press, 2000
[9]
Zhang K, Peters J, Janzing D, Schölkopf B. Kernel-based conditional independence test and application in causal discovery. In: Proceedings of the 27th Conference on Uncertainty in Artificial Intelligence. 2011, 804−813
[10]
Sun X, Janzing D, Schölkopf B, Fukumizu K. A kernel-based causal learning algorithm. In: Proceedings of the 24th International Conference on Machine Learning. 2007, 855−862
[11]
Heckerman D, Meek C, Cooper G. A Bayesian approach to causal discovery. In: Holmes D E, Jain L C, eds. Innovations in Machine Learning: Theory and Applications. Berlin, Heidelberg: Springer, 2006, 1−28
[12]
Margaritis D. Distribution-free learning of Bayesian network structure in continuous domains. In: Proceedings of the 20th National Conference on Artificial Intelligence. 2005, 825−830
[13]
Ke N R, Bilaniuk O, Goyal A, Bauer S, Larochelle H, Schölkopf B, Mozer M C, Pal C, Bengio Y. Learning neural causal models from unknown interventions. 2019, arXiv preprint arXiv: 1910.01075
[14]
Wang Z, Xiao X, Xu Z, Zhu Y, Stone P. Causal dynamics learning for task-independent state abstraction. In: Proceedings of the 39th International Conference on Machine Learning. 2022, 23151−23180
[15]
Bellman R. A Markovian decision process. Journal of Mathematics and Mechanics, 1957, 6(5): 679–684
[16]
Kurutach T, Clavera I, Duan Y, Tamar A, Abbeel P. Model-ensemble trust-region policy optimization. In: Proceedings of the 6th International Conference on Learning Representations. 2018
[17]
Williams G, Wagener N, Goldfain B, Drews P, Rehg J M, Boots B, Theodorou E A. Information theoretic MPC for model-based reinforcement learning. In: Proceedings of the IEEE International Conference on Robotics and Automation (ICRA). 2017, 1714−1721
[18]
Kidambi R, Rajeswaran A, Netrapalli P, Joachims T. MOReL: model-based offline reinforcement learning. In: Proceedings of the 34th International Conference on Neural Information Processing Systems. 2020, 1830
[19]
Akkaya I, Andrychowicz M, Chociej M, Litwin M, McGrew B, Petron A, Paino A, Plappert M, Powell G, Ribas R, Schneider J, Tezak N, Tworek J, Welinder P, Weng L, Yuan Q, Zaremba W, Zhang L. Solving Rubik’s cube with a robot hand. 2019, arXiv preprint arXiv: 1910.07113
[20]
Hoerl A E, Kennard R W. Ridge regression: biased estimation for nonorthogonal problems. Technometrics, 2000, 42(1): 80–86
[21]
Pearl J. Causality: Models, Reasoning, and Inference. Cambridge: Cambridge University Press, 2000
[22]
Koller D, Friedman N. Probabilistic Graphical Models: Principles and Techniques. Cambridge: MIT Press, 2009
[23]
Yu T, Kumar A, Rafailov R, Rajeswaran A, Levine S, Finn C. COMBO: conservative offline model-based policy optimization. In: Proceedings of the 35th International Conference on Neural Information Processing Systems. 2021, 2218
[24]
Todorov E, Erez T, Tassa Y. MuJoCo: a physics engine for model-based control. In: Proceedings of the IEEE/RSJ International Conference on Intelligent Robots and Systems. 2012, 5026−5033
[25]
Fu J, Kumar A, Nachum O, Tucker G, Levine S. D4RL: datasets for deep data-driven reinforcement learning. 2020, arXiv preprint arXiv: 2004.07219
[26]
Xu T, Li Z, Yu Y. Error bounds of imitating policies and environments. In: Proceedings of the 34th International Conference on Neural Information Processing Systems. 2020, 1320

Competing interests

The authors declare that they have no competing interests or financial conflicts to disclose.

RIGHTS & PERMISSIONS

2025 Higher Education Press