1 Introduction
Offline Reinforcement Learning (RL) is a learning paradigm where policies are learned entirely from previously collected data. Offline RL is gaining popularity since it enables RL algorithms to scale to several real-world applications, e.g., autonomous driving [1] and healthcare [2], where trial-and-error is too expensive. In the offline setting, Model-Based Reinforcement Learning (MBRL) is a popular framework that involves learning a predictive environment model for policy optimization [3]. The effectiveness of this approach relies on the accuracy of the learned environment model.
However, current offline MBRL approaches usually have poor generalization because the environment models tend to capture spurious correlations that only exist in the offline data, resulting in erroneous predictions. For instance, in autonomous driving, if offline data is acquired from a driver who always turns on the wiper and presses the brake pedal on rainy days, such a preference will result in a spurious correlation between “the wiper is turned on” and “the speed is dropped” in the data, which will also be captured by the environment model. Once we employ this environment model for policy learning, the agent will likely urge the driver to turn on the wiper when the vehicle’s speed is too high, because it believes that “the wiper is turned on” has an effect on “the speed is dropped”, which is not sensible and potentially hazardous. Intuitively, leveraging the causal structure of observed variables can avoid taking spurious correlations as causal influences and thus facilitate the learning of an environment model with enhanced generalizability. Recent empirical evidence also indicates that inducing the causal structure is important to improve the generalization ability of deep learning models [4-7]. Despite such evidence, it is still unknown whether and how the causal structure improves model generalization in offline RL.
To answer this question, we first provide theoretical support for the aforementioned intuition: we show that a causal environment model can outperform a plain environment model in generalization for offline RL. From the causal perspective, we categorize the variables in states and actions into two groups: causal variables and spurious variables. We then formalize the process of learning an environment model in which both types of variables appear in the states. On the basis of this formalization, we quantify the effect of spurious dependencies on the generalization error bound and thereby demonstrate that integrating causal structures helps minimize this bound.
We also propose a practical offline causal MBRL algorithm, FOCUS, to illustrate the feasibility of learning causal structure in offline RL. Learning the causal structure from offline data, also known as causal discovery from observations [8], is a crucial phase of FOCUS. However, causal discovery from observations requires a huge number of hypothesis tests, which is computationally expensive. To tackle this problem, we utilize the time-series property of RL data to reduce the number of hypothesis tests. Specifically, we incorporate the constraint that the future cannot cause the past into the PC algorithm [8], which seeks to uncover causal relationships based on inferred conditional independence relations. Consequently, we can reduce the number of conditional independence tests and determine the causal direction. In addition, we employ kernel-based conditional independence tests [9], which can be applied to continuous variables without assuming a specific functional form between the variables or a particular data distribution.
In conclusion, this paper makes the following key contributions.
● It theoretically demonstrates that a causal environment model outperforms a plain environment model in offline RL with respect to the generalization error bound.
● It proposes a practical algorithm, FOCUS, and illustrates the feasibility of learning and employing a causal environment model for offline MBRL.
● Our experimental results validate the theoretical claims, showing that FOCUS outperforms baseline models and other existing causal MBRL algorithms in the offline setting.
2 Related work
RL algorithms that perform causal structure learning can be roughly categorized by the type of causal discovery method they employ. We first discuss relevant causal discovery methods, followed by related RL algorithms.
2.1 Causal discovery methods
Causal discovery methods can be categorized into two main types: intervention-based methods and observation-based methods. Intervention-based methods determine the causal relationship between variables by intervening on them and obtaining a counterfactual outcome different from what is currently observed. By contrast, observation-based methods rely solely on existing data to discover causal relationships without performing any intervention. Since interventions are not feasible in the offline reinforcement learning environment discussed in this paper, we will primarily focus on observation-based causal discovery methods.
Observation-based causal discovery methods can generally be divided into two categories: constraint-based methods and score-based methods. Constraint-based methods use statistical tests (conditional independence tests) to find the causal skeleton and determine the causal directions up to the Markov equivalence class. For example, [10] is based on explicit estimation of the conditional densities or their variants, which exploits the difference between the characteristic functions of these conditional densities. The estimation of the conditional densities or related quantities is difficult, which deteriorates the testing performance, especially when the conditioning set is not small enough. Score-based methods evaluate the quality of candidate causal models with some score functions and output one or multiple graphs having the optimal score [11]. [12] discretizes the conditioning set into a set of bins and transforms conditional independence (CI) into unconditional independence in each bin. Inevitably, due to the curse of dimensionality, as the conditioning set becomes larger, the required sample size increases dramatically.
2.2 Causal discovery in reinforcement learning
Causal discovery methods can be utilized in policy learning and model learning in RL.
In policy learning, [7] proposes a causal imitation learning algorithm, which learns the causal structure between states and actions. However, it assumes that we can query experts for actions and use interventional data for causal discovery, which is impractical in offline RL.
In world model learning for RL, various methods have been proposed, but they can hardly be used in offline RL. For example, [13] (LNCM) views data sampled from different policies as data with soft interventions and uses score-based methods with the log-likelihood on “interventional” data as the score function. Its implicit assumption that data is sampled from multiple policies and that each sample is labeled by its sampling policy is not a general assumption in offline RL and only holds in online RL. [14] samples causal structures from a multivariate Bernoulli distribution and scores those structures according to the log-likelihood on interventional data. Based on the scores, it calculates the gradients for the parameters of the multivariate Bernoulli distribution and updates the parameters iteratively. [6] utilizes the speed of adaptation to learn the causal direction, but it does not provide a complete algorithm for learning the causal structure. Given the properties of offline RL that the sampling policy has unknown preferences and that interactions with the environment are not available, the above methods are not practical for learning the causal structure in offline RL.
3 Preliminaries
3.1 Markov decision process (MDP)
We describe the RL environment as an MDP specified by the five-tuple $(S, A, T, R, \gamma)$ [15], where $S$ is a finite set of states; $A$ is a finite set of actions; $T$ is the transition function, with $T(\cdot \mid s, a)$ denoting the next-state distribution after taking action $a$ in state $s$; $R$ is a reward function, with $R(s, a)$ denoting the expected immediate reward gained by taking action $a$ in state $s$; and $\gamma \in (0, 1)$ is a discount factor. An agent chooses actions $a_t$ according to a policy $\pi(a_t \mid s_t)$, which updates the system state to $s_{t+1} \sim T(\cdot \mid s_t, a_t)$, yielding a reward $r_t = R(s_t, a_t)$. The agent’s goal is to maximize the expected cumulative return by learning a good policy $\pi$. The state-action value $Q^{\pi}(s, a)$ of a policy $\pi$ is the expected discounted reward of executing action $a$ from state $s$ and subsequently following policy $\pi$: $Q^{\pi}(s, a) = \mathbb{E}_{\pi}\big[\sum_{t=0}^{\infty} \gamma^{t} R(s_t, a_t) \mid s_0 = s, a_0 = a\big]$.
3.2 Offline model-based reinforcement learning
In the offline RL setting, the algorithm only has access to a static dataset $\mathcal{D} = \{(s, a, r, s')\}$ collected by one behavior policy or a mixture of behavior policies $\pi_{B}$, and further interactions with the environment are not available. When we use model-based approaches to solve offline RL problems, we learn a virtual environment model $\widehat{T}$ for transition prediction from the offline data. With the learned environment model $\widehat{T}$, we can define a new MDP $\widehat{M} = (S, A, \widehat{T}, R, \gamma)$. Similarly, we can also define the value function $\widehat{Q}^{\pi}$ with respect to $\widehat{M}$. A standard model-based RL algorithm (in an online setting) learns a virtual model by fitting it with a maximum-likelihood estimate on the trajectory data collected by running the latest policy, which guarantees that the virtual model can remain accurate as long as the policy keeps exploring [16,17]. In the offline RL setting, where we only have access to the data collected by previous policies, the accuracy of the virtual model under an exploring policy cannot be guaranteed. Therefore, recent techniques all build on the idea of pessimism, which regularizes the original problem based on how confident the agent is about the learned model [3,18]. Specifically, the policy only visits the states where the learned model is confident in its predictions.
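As a minimal sketch of this pessimism principle (not the exact penalty of MOPO [3]; the function and argument names, and the use of ensemble disagreement as the uncertainty proxy, are our own assumptions), the reward used for policy optimization can be penalized as follows:

import numpy as np

def pessimistic_reward(reward, next_state_preds, penalty_coef=1.0):
    # next_state_preds: array of shape (n_ensemble, state_dim) holding the
    # next-state predictions of an ensemble of learned dynamics models for one (s, a).
    # The disagreement between ensemble members serves as an uncertainty proxy u(s, a).
    uncertainty = float(np.linalg.norm(np.std(next_state_preds, axis=0)))
    # Penalized reward: r_tilde = r - penalty_coef * u(s, a), so the policy is
    # discouraged from visiting regions where the learned model is unreliable.
    return reward - penalty_coef * uncertainty

MOPO derives its uncertainty estimate from its learned probabilistic ensemble; the sketch only conveys the shape of the penalty.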
4 Theoretical advantages of causal world models
In this section, we present theoretical evidence demonstrating the superior generalization of a causal environment model over a plain environment model in the context of offline RL. This underscores the importance of leveraging an appropriate causal structure to reduce the generalization error bounds of offline MBRL algorithms. More specifically, we formalize the process by which an environment model picks up spurious correlations and then quantify their influence on generalization, measured by the model prediction error bound and the policy evaluation error bound in the offline RL framework.
For simplicity, we operate under the assumption that the causal structure is accurately known. As a result, our theoretical findings do not account for errors that could arise from inaccurate structures. We further assume that the causal relationships are linear and that all causal variables are fully observed. Detailed proofs are deferred to Appendix A1.
4.1 Problem formalization
Plain world models. In the context of this paper, we use the term “plain world models” to refer to the traditional forward dynamics model, which ignores the underlying causal structure between the current state-action pair and the future state and predicts each dimension of the next state from all current variables: $P(s' \mid s, a) = \prod_{i=1}^{N} P(s'_i \mid s, a)$, where $N$ denotes the dimension of the state $s$.
Causal world models. By contrast, we use the term “causal world models” to refer to the forward dynamics model that respects the causal structure: $P(s' \mid s, a) = \prod_{i=1}^{N} P\big(s'_i \mid \mathrm{PA}(s'_i)\big)$, where $\mathrm{PA}(s'_i)$ denotes the variable set containing the causal parents of the variable $s'_i$, i.e., the dimensions of the current state and action that are causal parents of $s'_i$.
Model learning. Model learning in MBRL aims to predict the future state given the current state and action, which can be viewed as a supervised learning problem. For simplicity, we formulate it as supervised learning, i.e., we use $x$ to represent $(s, a)$ and $y$ to represent one dimension of $s'$.
Let $\mu$ denote the data distribution from which we have samples $(x, y) \sim \mu$. The goal of model learning is equivalent to learning a linear function $f(x) = \theta^{\top} x$ to predict $y$ given $x$. From the causal perspective, $y$ is generated from only its causal parent variables rather than all the variables in $x$. Therefore, we can categorize the variables in $x$ into two groups, $x = (x_{ca}, x_{spu})$.
● $x_{ca}$ represents the causal parent variables of $y$, that is, $y = \theta_{ca}^{*\top} x_{ca} + \epsilon$, where $\theta_{ca}^{*}$ is the ground truth and $\epsilon$ is a noise with zero mean that is independent of $x$, i.e., $\epsilon \perp x$.
● $x_{spu}$ represents the spurious variables that have no causal influence on $y$, i.e., $y \perp x_{spu} \mid x_{ca}$, but $x_{spu}$ and $x_{ca}$ have strong relatedness in biased offline data. In other words, $x_{spu}$ can be predicted from $x_{ca}$ with small error, i.e., $x_{spu} = w^{\top} x_{ca} + e$, where $e$ is the regression error with zero mean and small variance.
For a clearer representation, we use $m_{ca} \odot \theta$ (where $\odot$ represents element-wise multiplication) to replace $\theta_{ca}$, where the binary mask $m_{ca}$ records the indices of $x_{ca}$ in $x$; correspondingly, we use $m_{spu} \odot \theta$ to replace $\theta_{spu}$.
Example: Suppose that $x_1$ and $x_3$ are the causal parents; then $m_{ca} = (1, 0, 1, 0, \ldots, 0)$ and $m_{spu} = (0, 1, 0, 1, \ldots, 1)$, and $x_{ca} = m_{ca} \odot x$, $x_{spu} = m_{spu} \odot x$.
According to the definition of $x_{ca}$, we have $y = (m_{ca} \odot \theta^{*})^{\top} x + \epsilon$, where $\theta^{*}$ is the global optimal solution of the following optimization problem:
Definition 1 (Optimization problem 1). $\min_{\theta} \ \mathbb{E}_{(x, y) \sim \mu}\big[(y - \theta^{\top} x)^{2}\big]$.
However, in the offline setting, we only have biased data sampled by a given policy $\pi$, where the optimization problem becomes:
Definition 2 (Optimization problem 2). $\min_{\theta} \ \mathbb{E}_{(x, y) \sim \mu_{\pi}}\big[(y - \theta^{\top} x)^{2}\big]$, where $\mu_{\pi}$ denotes the biased offline data distribution induced by $\pi$.
4.2 Model prediction error bound
In this subsection, we assume a causal structure of the RL environment and spurious relations in the offline data. We highlight that these spurious relations make the model learning problem ill-posed, hence increasing the model prediction error bound. Specifically, we quantitatively measure the impact of spurious relations on model learning in Theorem 1, which relates the properties of the spurious relations to the model prediction error bound.
4.2.1 The spurious theorem
In Lemma 1, we prove that optimization problem 2 has multiple optimal solutions due to the strong linear relatedness of $x_{ca}$ and $x_{spu}$ in the biased distribution $\mu_{\pi}$.
Lemma 1. Given that $\theta^{*}$ is the optimal solution of optimization problem 1, suppose that $x_{spu} = w^{\top} x_{ca} + e$ in $\mu_{\pi}$, where $e$ has zero mean and small variance; then $\theta_{\alpha}$ is also an optimal solution of optimization problem 2 for any $\alpha$:
The most popular method for solving such an ill-posed problem is to add a regularization term on the parameters [19]:
Definition 3 (Optimization problem 3). $\min_{\theta} \ \mathbb{E}_{(x, y) \sim \mu_{\pi}}\big[(y - \theta^{\top} x)^{2}\big] + k \, \lVert \theta \rVert_{2}^{2}$,
where $k$ is a regularization coefficient.
The form of optimization problem 3 corresponds to ridge regression, whose regularization coefficient admits a closed-form choice via the Hoerl-Kennard formula [20]. Given this closed-form choice of $k$, we can derive the optimal solution of optimization problem 3, which corresponds to one particular solution among those characterized in Lemma 1.
In the following, we first give the solution of the spurious coefficient $\theta_{spu}$ under the $k$ given in Lemma 2, and then give the model prediction error bound induced by $\theta_{spu}$ in Theorem 1. For ease of understanding, we provide a simple version in which the dimensions of $x_{ca}$ and $x_{spu}$ are both one.
Lemma 2 ($\theta_{spu}$ lemma). Let $\sigma^{2}$ denote the variance of the noise $\epsilon$ and $\sigma_{e}^{2}$ denote the variance of $e$. With $k$ in optimization problem 3 chosen by the Hoerl-Kennard formula, the solution of the spurious coefficient $\theta_{spu}$ in optimization problem 3 satisfies:
Based on Lemma 2, we find that the smaller $\sigma_{e}^{2}$ is (which means that $x_{spu}$ and $x_{ca}$ have stronger relatedness in the training dataset), the larger $\theta_{spu}$ becomes. We also prove that the coefficient $\theta_{spu}$ is bounded:
Proposition 1. Given $\theta_{spu}$ as in Lemma 2, its bound is:
Theorem 1 (Spurious theorem). Let $\mu$ denote the data distribution, $\theta$ denote the solution in Lemma 1 with $\theta_{spu}$ as in Lemma 2, and $\hat{y} = \theta^{\top} x$ denote the prediction. Suppose that the data values are bounded and that the error of the optimal solution is also bounded; then we have the model prediction error bound:
Theorem 1 shows that (see also the numerical sketch below):
● the upper bound of the model prediction error increases by an additional term for each spurious variable induced into the model;
● when $x_{spu}$ and $x_{ca}$ have stronger relatedness (which means a larger $\theta_{spu}$), the increment of the model prediction error bound led by $\theta_{spu}$ is larger.
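The effect described by Theorem 1 can be reproduced in a few lines. The sketch below is purely illustrative: the data-generating process, the sample size, and the (deliberately large) ridge coefficient are our own choices rather than the setting of the proof. It fits a ridge regression on biased data in which the spurious variable closely tracks the causal parent, and compares the test error on data where that relation no longer holds:

import numpy as np

rng = np.random.default_rng(0)
n = 5000

def make_data(residual_var, spurious_independent=False):
    x_ca = rng.normal(size=n)                      # causal parent of y
    if spurious_independent:
        x_spu = rng.normal(size=n)                 # test data: the spurious relation is gone
    else:
        x_spu = 0.8 * x_ca + np.sqrt(residual_var) * rng.normal(size=n)
    y = 2.0 * x_ca + 0.1 * rng.normal(size=n)      # y depends only on x_ca
    return np.stack([x_ca, x_spu], axis=1), y

def ridge(X, y, k):
    return np.linalg.solve(X.T @ X + k * np.eye(X.shape[1]), X.T @ y)

X_train, y_train = make_data(residual_var=0.001)          # biased offline data
theta_plain = ridge(X_train, y_train, k=10.0)              # uses both variables
theta_causal = ridge(X_train[:, :1], y_train, k=10.0)      # uses only the causal parent

X_test, y_test = make_data(residual_var=0.001, spurious_independent=True)
err_plain = np.mean((X_test @ theta_plain - y_test) ** 2)
err_causal = np.mean((X_test[:, :1] @ theta_causal - y_test) ** 2)
# theta_plain places substantial weight on the spurious variable, so err_plain >> err_causal.
print(theta_plain, err_plain, err_causal)

Running the sketch, the plain model spreads a substantial part of its weight onto the spurious variable and its test error grows accordingly, matching the two bullet points above.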
4.3 Policy evaluation error bound
Building on the model prediction error bound established earlier, we proceed to delineate the policy evaluation error bound, a metric that offers a more direct evaluation of the world model in MBRL. In this subsection, we apply the spurious theorem (Theorem 1) to offline MBRL. Our investigation explores how the policy evaluation error bound escalates in response to an increase in the number of spurious variables and the strengthening of spurious relationships.
Assuming that the state values and the rewards are bounded, and denoting their respective maxima by $S_{\max}$ and $R_{\max}$, we establish the policy evaluation error bound in Theorem 2.
Theorem 2 (RL spurious theorem). Given an MDP with state dimension $d_{s}$ and action dimension $d_{a}$, and a data-collecting policy $\pi_{B}$, let $T$ denote the true transition model and $\widehat{T}$ denote the learned model that predicts each dimension $i$ with spurious variable set $x_{spu}^{i}$ and causal variables $x_{ca}^{i}$. Let $V^{\pi}_{\widehat{M}}$ denote the policy value of the policy $\pi$ in the learned model $\widehat{M}$ and $V^{\pi}_{M}$ the corresponding value in the true MDP $M$. For an arbitrary policy $\pi$ with bounded divergence from $\pi_{B}$, we have the policy evaluation error bound:
where $\rho$ represents the spurious variable density, that is, the ratio of spurious variables to all input variables.
Theorem 2 shows the relation between the policy evaluation error bound and the spurious variable density, which indicates that:
● When we use a non-causal model in which all the spurious variables are taken as input, $\rho$ reaches its maximum value. By contrast, under the optimal causal structure, $\rho$ reaches its minimum value of 0.
● Both the density of spurious variables and the correlation strength of spurious variables influence the policy evaluation error bound. However, if we exclude all the spurious variables, i.e., $\rho = 0$, the correlation strength of the spurious variables has no effect.
5 FOCUS algorithm
Having established the necessity of a causal environment model in offline RL, we proceed in this section to propose a practical offline MBRL algorithm, FOCUS, to illustrate the feasibility of learning and using causal structure in offline RL. Our algorithm consists of two distinct steps: the discovery of the causal structure from offline data, followed by the effective integration of the identified structure with an offline MBRL algorithm. During the first step, the offline setting imposes limitations on the selection of causal discovery methods. Specifically, methods that involve interventions or randomized experiments are deemed unsuitable as interactions with the environment are prohibited. When considering approaches for observational data, score-based methods necessitate data from distinct sampling policies and a known form of causal mechanisms, neither of which are present in offline data collected from a single policy. Constraint-based methods, on the other hand, do not presuppose any specific form of causal mechanisms, but they fail to distinguish structures within one Markov equivalence class and may prove inefficient due to numerous independence tests. To address these challenges, FOCUS extends the PC algorithm (a derivative of constraint-based methods) and rectifies its shortcomings by incorporating sequential information. In the second step, FOCUS initializes the environment model with the learned causal structure and then proceeds to learn the environment model along with the policy.
5.1 Preliminary
5.1.1 Assumptions
First, we assume the causal Markov condition and faithfulness in the environment transition, under which we can use conditional independence tests to infer the causal graph [8]. Second, we claim that the offline data and the data obtained through the learned policy share the same causal structure, so that the learned structure can be applied in unseen states. Specifically, policy preference affects the relations between variables only quantitatively: it changes the strength of causal relations and introduces spurious relations between otherwise independent variables. Because we distinguish spurious correlations from causal influences, the policy preference does not result in qualitative changes in the causal structure, and the above claim therefore holds in offline RL.
5.1.2 Conditional independence test
Independence and conditional independence (CI) play a central role in causal discovery [8,21,22]. Generally speaking, the CI relationship $X \perp Y \mid Z$ allows us to drop $Y$ when constructing a probabilistic model for $X$ with $(Y, Z)$. There are multiple CI testing methods for various conditions, each providing the correct conclusion only under its corresponding condition. The kernel-based conditional independence test (KCI test) [9] is proposed for continuous variables without assuming a specific functional form between the variables or a particular data distribution.
5.1.3 Conditional variables
Besides the specific CI test method, the conclusion of conditional independence testing also depends on the conditional variable $Z$, that is, different conditional variables can lead to different conclusions. Taking the triple $(X, Y, Z)$ as an example, there are three typical structures, namely Chain, Fork, and Collider, as shown in Fig.1. Chain ($X \rightarrow Z \rightarrow Y$): there exists causation between $X$ and $Y$, but conditioning on $Z$ leads to independence. Fork ($X \leftarrow Z \rightarrow Y$): there is no causation between $X$ and $Y$, but not conditioning on $Z$ leads to non-independence. Collider ($X \rightarrow Z \leftarrow Y$): there is no causation between $X$ and $Y$, but conditioning on $Z$ leads to non-independence.
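These three patterns can be verified numerically. The following sketch is a simple linear-Gaussian illustration (using partial correlation in place of a full kernel-based test; all names are ours) that generates data from each structure and compares the marginal and conditional dependence between $X$ and $Y$:

import numpy as np

rng = np.random.default_rng(0)
n = 20000

def partial_corr(x, y, z):
    # Correlation between the residuals of x and y after linearly regressing out z.
    rx = x - np.polyval(np.polyfit(z, x, 1), z)
    ry = y - np.polyval(np.polyfit(z, y, 1), z)
    return np.corrcoef(rx, ry)[0, 1]

# Chain X -> Z -> Y: X causes Y through Z; conditioning on Z yields independence.
x = rng.normal(size=n); z = x + rng.normal(size=n); y = z + rng.normal(size=n)
print(np.corrcoef(x, y)[0, 1], partial_corr(x, y, z))   # large, ~0

# Fork X <- Z -> Y: no causation between X and Y; not conditioning on Z yields dependence.
z = rng.normal(size=n); x = z + rng.normal(size=n); y = z + rng.normal(size=n)
print(np.corrcoef(x, y)[0, 1], partial_corr(x, y, z))   # large, ~0

# Collider X -> Z <- Y: no causation between X and Y; conditioning on Z yields dependence.
x = rng.normal(size=n); y = rng.normal(size=n); z = x + y + rng.normal(size=n)
print(np.corrcoef(x, y)[0, 1], partial_corr(x, y, z))   # ~0, clearly nonzero

A kernel-based test such as KCI replaces the partial correlation here, so the same reasoning applies without the linear-Gaussian assumption.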
5.2 Causal structure learning
5.2.1 Applying the Independence test in RL
Based on the preliminaries, given the two target variables and the condition variable, the KCI test returns a probability value $P$, which measures the probability that the two target variables are conditionally independent given the condition. To transform this probability into a binary conclusion of whether causation exists, we design a threshold $\epsilon$ and define $g(P) = 1$ if $P \le \epsilon$ and $g(P) = 0$ otherwise, where $0$ represents independence and $1$ represents that causation exists. Details of choosing $\epsilon$ can be found in Appendix A2.
In model learning for RL, the variables are the state and action dimensions of the current and next time steps, and the causal structure refers to whether a variable at time step $t$ (e.g., the $i$th dimension) causes another variable at time step $t+1$ (e.g., the $j$th dimension). With the KCI test, we obtain the causal relation through the function $g$ for each variable pair and then form the causal structure matrix $M$, where $M_{j,i}$, the element in row $j$ and column $i$ of $M$, indicates whether input variable $i$ at time step $t$ causes output variable $j$ at time step $t+1$.
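A minimal sketch of this step is shown below. Here kci_pvalue is a placeholder for any KCI implementation [9], choose_condition is the conditioning-set rule discussed in Section 5.2.2, and the array layout (current-step variables as columns of data_t, next-step state variables as columns of data_t1) is our own convention:

import numpy as np

def g(p_value, threshold):
    # Binarize the KCI p-value: 1 means causation exists (independence rejected), 0 means independent.
    return 1 if p_value <= threshold else 0

def learn_structure_matrix(data_t, data_t1, kci_pvalue, choose_condition, threshold):
    # data_t:  (n, d_in)  state and action dimensions at time step t
    # data_t1: (n, d_out) state dimensions at time step t+1
    # kci_pvalue(x, y, Z) -> p-value for x _||_ y | Z (placeholder for a KCI test)
    d_in, d_out = data_t.shape[1], data_t1.shape[1]
    M = np.zeros((d_out, d_in), dtype=int)
    for j in range(d_out):                      # each next-step variable
        for i in range(d_in):                   # each current-step variable
            Z = choose_condition(i, data_t, data_t1)
            M[j, i] = g(kci_pvalue(data_t[:, i], data_t1[:, j], Z), threshold)
    return M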
5.2.2 Choosing the conditional variable in RL
As stated in the preliminaries, unsuitable conditional variables can reverse the conclusion of independence testing. The conditional variable set must include the intermediate variable of a Chain and the common parent variable of a Fork, but not the common child variable of a Collider. Traditionally, the independence test traverses all possible combinations of the conditional variables and then reaches a conclusion, which is inefficient. In RL, however, we can reduce the number of conditional independence tests by imposing the restriction that the future cannot cause the past. This constraint restricts the possible conditional variable sets to a small number, so we can discuss every feasible collection of conditional variables case by case. For simplicity, we exclude two types of scenarios from the discussion:
Fig.1 The three basic structures for the triple $(X, Y, Z)$
● Impossible situations. We exclude impossible situations such as Fig.2(i) (bottom left) by the temporal property of RL data. Specifically, the direction of causation cannot be from a variable at time step $t+1$ to one at time step $t$, because the effect cannot happen before the cause.
● Compound situations. We only discuss the basic situations and exclude compound situations, e.g., Fig.2(j) (bottom right), which is a compound of Fig.2(a) and (c). In such compound situations, the target variables have direct causation (otherwise the situation would not be compound). When a causal relationship exists, an independence test cannot arrive at the conclusion of “independence”. Hence, we need not worry about mistakenly concluding that a causal relationship does not exist and thereby excluding causal parent variables.
As shown in Fig.2, we list all conceivable situations involving the target variables and the condition variable in the environment model. With the preliminary knowledge of causal discovery, we investigate the following basic situations:
Top line: In Fig.2(a) and (b), whether the condition variable is included in the conditional variable set does not affect the conclusion about causation. In Fig.2(c), although the condition variable is an intermediate variable in a Chain and conditioning on it leads to the conclusion that the target variables are independent, the causal parent set of the output variable will include the condition variable when the causal relationship between the condition variable and the output variable is tested, which offsets the influence of excluding the input variable. In Fig.2(d), conditioning on the condition variable is necessary for reaching the correct conclusion about causation, since it is the common causal parent in a Fork structure.
Bottom line: In Fig.2(e) and (f), whether the condition variable is included in the conditional variable set does not affect the conclusion about causation. In Fig.2(g), not conditioning on the condition variable is necessary for reaching the correct conclusion, since it is the common child in a Collider structure. In Fig.2(h), although the condition variable is an intermediate variable in a Chain and not conditioning on it leads to the conclusion that the target variables are dependent, including the input variable in the causal parent set of the output variable does not cause any problem, since the input variable does indirectly cause the output variable. Based on the case-by-case discussion above, we conclude our principle for choosing conditional variables in RL (a code sketch follows the two rules below):
Fig.2 The basic, impossible, and compound situations of the causation between target variables and condition variables. In the basic situations, top line: (a)−(d) list the situations where the condition variable is at time step $t$; bottom line: (e)−(h) list the situations where the condition variable is at time step $t+1$. Bottom left: (i) lists impossible situations where the direction of causation is from the future to the past. Bottom right: (j) lists compound situations where the causal structure is a compound of basic situations
Fig.3 The architecture of FOCUS. Given offline data, FOCUS learns a p-value matrix by the KCI test and then obtains the causal structure by choosing a threshold. After combining the learned causal structure with the neural network, FOCUS learns the policy through an offline MBRL algorithm
● Condition on the other variables at time step $t$.
● Do not condition on the variables at time step $t+1$.
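In code, the principle reduces to the following conditioning-set choice for the loop sketched in Section 5.2.1 (again a sketch with our own function name and array layout):

import numpy as np

def choose_condition(i, data_t, data_t1):
    # Rule 1: condition on all the other variables of time step t.
    # Rule 2: never condition on variables of time step t+1 (potential colliders),
    #         so data_t1 is accepted only to make the exclusion explicit.
    return np.delete(data_t, i, axis=1)

Plugging choose_condition into the learn_structure_matrix sketch above yields the causal structure matrix M used in the next subsection.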
5.3 Combining learned causal structure with an offline MBRL algorithm
We combine the learned causal structure with an offline MBRL algorithm, MOPO [3], to create a causal offline MBRL algorithm, as shown in Fig.3. The entire learning procedure can be found in Algorithms 1 and 2. Note that, in principle, our causal model learning method can be combined with any offline MBRL algorithm.
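The integration can be as simple as masking the input of each predicted dimension with the learned structure matrix before it enters the dynamics network. The PyTorch-style sketch below uses our own module and argument names and is not the exact FOCUS implementation:

import torch
import torch.nn as nn

class CausalDynamicsModel(nn.Module):
    # One small MLP head per next-state dimension; the input of head j is masked by
    # the learned causal structure matrix M (M[j, i] = 1 iff input variable i causes output j).
    def __init__(self, structure_matrix, hidden=200):
        super().__init__()
        self.register_buffer("M", torch.as_tensor(structure_matrix, dtype=torch.float32))
        d_out, d_in = self.M.shape
        self.heads = nn.ModuleList([
            nn.Sequential(nn.Linear(d_in, hidden), nn.ReLU(), nn.Linear(hidden, 1))
            for _ in range(d_out)
        ])

    def forward(self, state_action):
        # state_action: (batch, d_in); each head only sees its causal parents.
        preds = [head(state_action * self.M[j]) for j, head in enumerate(self.heads)]
        return torch.cat(preds, dim=-1)  # (batch, d_out) predicted next state

During policy optimization, a model of this form simply replaces the plain dynamics model inside the offline MBRL algorithm, leaving the rest of the pipeline unchanged.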
6 Experiments
In this study, we assess the performance of FOCUS on a variety of offline RL datasets to demonstrate its effectiveness. Specifically, we aim to show that (1) FOCUS enables the learning of a causal environment model in offline RL, and (2) a causal environment model surpasses a standard environment model and other related offline RL methods in policy performance.
To validate the first point, we examine the accuracy, efficiency, and robustness of causal structure learning. For the second point, we evaluate policy return and generalization ability within the context of offline MBRL.
6.1 Experimental settings
6.1.1 Baselines
We compare FOCUS with the state-of-the-art offline MBRL algorithms MOPO, MOReL, and COMBO, as well as with online RL algorithms that also learn causal structure.
● MOPO [3] is a popular and well-known offline MBRL algorithm that outperforms standard model-based RL algorithms and prior state-of-the-art model-free offline RL algorithms on existing offline RL benchmarks. The central idea of MOPO is to artificially penalize rewards by the uncertainty of model predictions, hence avoiding erroneous predictions in unseen states. MOPO can be seen as the blank control with a plain environment model.
● MOReL [18] is an offline MBRL algorithm that learns a pessimistic MDP (P-MDP) using the offline dataset and then learns a near-optimal policy in this P-MDP.
● COMBO [23] is also an offline MBRL algorithm, which conservatively estimates the value function by penalizing it in out-of-support states generated through model rollouts.
● LNCM (Learning Neural Causal Models from Unknown Interventions) [13] is an online algorithm whose causal structure learning method can be transferred to the offline setting with a simple adjustment. We take LNCM as an example to show that an online method cannot be directly converted into an offline RL algorithm.
6.1.2 Environment
Toy Car Driving. Toy Car Driving is a typical RL environment where the agent controls its direction and velocity to finish various tasks, including avoiding obstacles and navigating. In this paper, we use a 2D Toy Car Driving environment where the task of the car is to arrive at the destination (the visualization can be found in the Appendix). The state includes the direction $d$, the velocity $v$, the velocity on the $x$-axis $v_x$, the velocity on the $y$-axis $v_y$, and the position $(p_x, p_y)$. The action is the steering angle. We design the underlying causal structure to better demonstrate how spurious relations appear and to highlight their influence on model learning (the structure can be found in the Appendixes).
Example: As shown in Fig.4, when the velocity $v$ stays (almost) constant due to an imperfect sampling policy, $v_x$ and $v_y$ have strong relatedness and one can represent the other. Since we design that the next $x$-position is caused by $v_x$, $v_y$ and the next $x$-position also have strong relatedness, which leads to $v_y$ becoming a spurious variable for the next $x$-position given $v_x$, despite that $v_y$ is not its causal parent. By contrast, when the data is uniformly sampled with various velocities, this spuriousness does not exist. The sketch following Fig.4 reproduces this effect numerically.
Fig.4 The visualization of the example. The red dotted arrow indicates that $v_y$ is a spurious variable for the next $x$-position
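The sketch below is our own simplified simulation of this example (the physics is reduced to $v_x = v\cos d$, $v_y = v\sin d$, and an $x$-displacement caused only by $v_x$); it is intended solely to show how the biased sampling policy creates the spurious correlation:

import numpy as np

rng = np.random.default_rng(0)
n = 10000

def rollout(speed, heading):
    vx, vy = speed * np.cos(heading), speed * np.sin(heading)
    next_x = vx                     # the x-displacement is caused by vx only
    return vx, vy, next_x

heading = rng.uniform(0.0, np.pi / 2, size=n)

# Biased offline data: the sampling policy keeps the speed (almost) constant,
# so vy is tightly tied to vx and looks predictive of the next x-position.
speed = 1.0 + 0.01 * rng.normal(size=n)
vx, vy, next_x = rollout(speed, heading)
print(np.corrcoef(vy, next_x)[0, 1])   # strong (negative) correlation

# Diverse data: speeds vary, so the spurious relation largely disappears.
speed = rng.uniform(0.2, 2.0, size=n)
vx, vy, next_x = rollout(speed, heading)
print(np.corrcoef(vy, next_x)[0, 1])   # much weaker correlation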
MuJoCo. MuJoCo [24] is the most popular benchmark for evaluating performance in continuous control, where the variables of the state represent the positions, angles, and velocities of the agent. Each state dimension in MuJoCo has a specific meaning and is highly abstract, which is convenient for causal structure learning.
6.1.3 Offline data
We prepare three offline data sets, Random, Medium, and Medium-Replay, for Car Driving and MuJoCo. The MuJoCo datasets are from D4RL [25]. To provide an intuitive illustration of the data sets, we use heat maps to visualize the offline data sets of Car Driving in Fig.5:
Fig.5 The heat map of the three offline data sets in Car Driving. The coordinates represent two dimensions of the state. The brightness at each data point represents its probability of being visited; high brightness represents high probability. (a) Random; (b) Medium; (c) Medium-Replay
● In Random, data is collected by random sampling policies, and its distribution is similar to a multivariate Gaussian distribution.
● In Medium, data is collected by a fixed but not well-trained policy and its distribution is the least diverse.
● In Medium-Replay, data is a collection of trajectories sampled during training, which is the most diverse.
6.2 The validation for spurious theorems
Our theoretical analysis reveals that
● The model prediction error bound increases with increasing strength of spurious relations and increasing number of induced spurious variables.
● The policy evaluation error bound increases with increasing strength of spurious relations and increasing spurious variable density.
To validate our theoretical findings, we conduct experiments in multiple RL environments whose latent causal structures are diverse, namely, the number of spurious variables, the spurious relatedness, and the spurious variable density vary. The RL environment is built on the Inverted Pendulum task of MuJoCo. In this environment, the spurious relatedness is changed by modifying the hyper-parameters, e.g., the friction coefficient. The spurious variable density is changed by introducing external variables. The number of spurious variables is changed by introducing external variables while keeping the spurious variable density fixed. We estimate the model and policy prediction error bounds by taking the maximum of those errors over five repeated experiments.
The results in Fig.6 show that the model prediction error bound of the plain world model increases approximately linearly with the relatedness $\lambda$, the number of spurious variables, and the density of spurious variables. By contrast, the model prediction error bound of the causal world model does not noticeably change with these quantities and remains at a relatively low level, which supports our spurious theorem (Theorem 1). The results in Fig.7 show similar conclusions for the policy evaluation error bound, which validates our RL spurious theorem (Theorem 2).
Fig.6 Comparison of Causal World Models and Plain World Models to validate the spurious theorem. We evaluate the model prediction error on the relatedness, number and density of spurious variables in the offline dataset. (a) Comparison on relatedness λ; (b) comparison on spurious number; (c) comparison on spurious density |
Fig.7 Comparison of Causal World Models and Plain World Models to validate the RL spurious theorem. We evaluate the policy evaluation error on the relatedness, number and density of spurious variables in the offline dataset. (a) Comparison on relatedness λ; (b) comparison on spurious number; (c) comparison on spurious density |
6.3 Causal structure learning
We compare FOCUS with baselines on causal structure learning in terms of accuracy, efficiency, and robustness. Accuracy is evaluated by viewing structure learning as a classification problem, where causation is the positive class and independence is the negative class. Efficiency is evaluated by the number of samples required to obtain a stable structure. Robustness is evaluated by the variance over multiple experiments. The results in Tab.1 show that FOCUS surpasses LNCM in accuracy, robustness, and efficiency of causal structure learning. Note that LNCM also has a low variance because it predicts roughly the same probability of causation for every variable pair, which means that its robustness is meaningless.
Tab.1 The results on causal structure learning of our model and the baselines. Both the accuracy and the variance are calculated over five repeated experiments. FOCUS (-KCI) represents FOCUS with a linear independence test. FOCUS (-CONDITION) represents FOCUS with all other variables chosen as conditional variables |
INDEX | FOCUS | LNCM | FOCUS(-KCI) | FOCUS(-CONDITION)
ACCURACY | 0.993 | 0.52 | 0.62 | 0.65
ROBUSTNESS | 0.001 | 0.025 | 0.173 | 0.212
EFFICIENCY(SAMPLES) | 1×10^6 | 1×10^7 | 1×10^6 | 1×10^6
6.4 Policy learning
6.4.1 Policy return
We evaluate the performance of FOCUS and baselines in the two benchmarks on three typical offline data sets. The results in Tab.2 and Tab.3 show that FOCUS outperforms baselines by a significant margin in most data sets.
Tab.2 The comparison on converged policy return in the two benchmarks |
ENV | CAR DRIVING | | MUJOCO(INVERTED PENDULUM) |
RANDOM | MEDIUM | REPLAY | | RANDOM | MEDIUM | REPLAY |
FOCUS | | | | | | | |
MOPO | | | | | | | |
LNCM | | | | | | | |
Tab.3 Results for D4RL datasets. For “ORIGIN” version, we take the results from their original papers. For “FOCUS” version, each number is the averaged score at the last iteration of training, averaged over 3 random seeds. We bold the highest score across all methods |
DATASET TYPE | ENVIRONMENT | MOPO ORIGIN | MOPO FOCUS | MOReL ORIGIN | MOReL FOCUS | COMBO ORIGIN | COMBO FOCUS
RANDOM | HALFCHEETAH | 35 | 37 | 25 | 34 | 38 | 44
RANDOM | HOPPER | 11 | 30 | 53 | 63 | 17 | 23
RANDOM | WALKER2D | 13 | 28 | 37 | 54 | 7 | 11
MEDIUM | HALFCHEETAH | 42 | 49 | 42 | 61 | 54 | 60
MEDIUM | HOPPER | 28 | 30 | 95 | 102 | 103 | 107
MEDIUM | WALKER2D | 17 | 27 | 77 | 85 | 81 | 94
MEDIUM-REPLAY | HALFCHEETAH | 53 | 58 | 40 | 47 | 55 | 60
MEDIUM-REPLAY | HOPPER | 67 | 71 | 93 | 90 | 89 | 85
MEDIUM-REPLAY | WALKER2D | 39 | 44 | 49 | 54 | 56 | 63
As shown in Tab.2, on the Random datasets, FOCUS has the most significant performance gains over the baselines in both benchmarks because of the accuracy of causal structure learning in FOCUS. By contrast, on Medium-Replay, the performance gains of FOCUS are the smallest, since the high data diversity of Medium-Replay leads to weak relatedness of the spurious variables (corresponding to a small relatedness $\lambda$), which verifies our theory. On Medium, the results in the two benchmarks differ. In Car Driving, the relatively high score of LNCM does not mean that LNCM is the best; rather, all three methods fail. The failure indicates that extremely biased data makes even the causal model fail to generalize. In the Inverted Pendulum, the advantage of FOCUS indicates that causal environment models depend less on data diversity, since FOCUS can still reach high scores on such a biased dataset where the baselines fail.
In Tab.3, we showcase results from more complex experiments, including HalfCheetah, Hopper, and Walker2D, alongside additional baseline methods (MOReL and COMBO). To ensure fair comparisons, we integrate the core module of FOCUS, namely learning and applying causal environment models, into these MBRL algorithms, thereby creating various FOCUS-version algorithms. We then compare these enhanced versions (denoted as “FOCUS”) against their original counterparts (denoted as “ORIGIN”), which do not integrate causal structures. On Random and Medium datasets, FOCUS-version algorithms significantly outperform the control groups; on Medium-Replay, FOCUS still achieves the best performance in 7 out of 9 settings. The results also confirm that FOCUS can be combined with most MBRL algorithms and effectively enhance performance.
6.4.2 Generalization ability
The generalization ability of FOCUS refers to whether it can learn a good policy from data of limited size and low diversity. We therefore construct datasets ranging over different fractions of the original data size, as well as datasets that mix Medium with different ratios of Medium-Replay, so that we can compare FOCUS and the baselines on datasets with different sizes and diversities. The results in Fig.8 show that the advantage of FOCUS over MOPO is much more significant at small data sizes. At the smallest data size, the advantage of FOCUS is relatively insignificant because the dataset is too small. The results in Fig.9 show that FOCUS performs well with a small ratio of Medium-Replay data, while the baseline performs well only with a large ratio, which indicates that FOCUS is less dependent on the diversity of data. Related experiments on more environments can be found in Appendix A2.
Fig.8 The comparison on data size. The value on the x-axis represents the fraction of the original data size. The ratio on the y-axis represents the score ratio of FOCUS over the baseline MOPO
Fig.9 The comparison on data diversity. The dataset is produced by mixing Medium-Replay and Medium with different ratios. The value on the x-axis represents the fraction of Medium-Replay data mixed with the Medium data
6.5 Ablation study
To evaluate the contribution of each component in our causal discovery method, we perform an ablation study of FOCUS. The results in Tab.1 show that the KCI test and our principle for choosing conditional variables contribute to both the accuracy and the robustness of causal structure learning.
7 Conclusion
In this paper, we point out that the spurious correlations hinder the generalization ability of current offline MBRL algorithms, and that incorporating the causal structure into the model can improve generalization by removing spurious correlations. We provide theoretical support for the statement that utilizing a causal environment model reduces the generalization error bound in offline RL. We also propose a practical algorithm, FOCUS, to address the problem of learning causal structure in offline RL. The main idea of FOCUS is to leverage conditional independence tests for causal discovery, which does not need further assumptions on the causal mechanism. In FOCUS, we address the difficulties of extending the PC algorithm in offline RL, particularly to reduce the number of independence tests by leveraging sequential information. Extensive experiments on the typical benchmarks demonstrate that FOCUS performs accurate and robust causal structure learning, surpassing offline RL baselines by a significant margin.
We would like to note that in our theoretical results (Theorems 1 and 2), we assume that the true causal structure is already known. However, in practice, we must learn it from data before applying it (Section 5), which introduces additional error not covered by our theory. Since quantifying the uncertainty of a causal structure learned from data is recognized to be difficult, we leave deriving the generalization error bound under a learned causal structure to future work.