Soft Actor Critic (SAC)
In the previous post, we discussed Proximal Policy Optimization (PPO) and its strengths, which have made it a popular choice in recent years. However, as an on-policy method, PPO suffers from a key drawback: limited sample efficiency. This limitation is one of the main reasons I’d like to shift focus in this post to Soft Actor-Critic (SAC) [arXiv], a highly effective off-policy alternative.
Let’s dive in.
SAC Introduction
What is SAC, and why is it considered a cornerstone like PPO?
Soft Actor-Critic (SAC) is a model-free, off-policy reinforcement learning algorithm that stands out for its ability to balance exploitation and exploration effectively. Like PPO, SAC is widely used as a baseline in research due to its robustness and strong empirical performance across a variety of continuous control tasks.
What makes SAC special is its use of entropy regularization in the objective function. Instead of only aiming to maximize expected rewards, SAC also explicitly encourages the agent to explore diverse actions by maximizing the entropy of the policy. This added randomness prevents the agent from prematurely converging to suboptimal policies and helps it better explore the environment. As a result, SAC often achieves superior sample efficiency and stability compared to traditional methods, especially in continuous action spaces.
The main components of SAC are:
- Actor: learns a stochastic policy that outputs a probability distribution over actions
- Critic: two Q-function networks are used to estimate the expected return for state-action pairs
As an off-policy method, SAC uses a different policy for data collection (exploration) than for learning (exploitation) (see On-Off Policy for more detail). It is built on the principle of maximizing not only expected rewards but also the entropy of the policy, which encourages more diverse action selection and leads to better exploration. The core idea is to optimize a modified objective function that balances reward maximization with entropy:
$$J(\pi) = \mathbb{E}\left[\sum_{t} \gamma^{t}\big(r(s_t,a_t) + \alpha \mathcal{H}(\pi(\cdot|s_t))\big)\right]$$
where:
- $r(s_t,a_t)$: reward for taking action $a_t$ in state $s_t$
- $\mathcal{H}(\pi(\cdot|s_t)) = -\mathbb{E}_{a_t \sim \pi}[\log \pi(a_t|s_t)]$: entropy of the policy
- $\alpha$: temperature parameter that controls the trade-off between reward and entropy
- $\gamma$: discount factor for future rewards
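To make the entropy term concrete, here is a small standalone illustration (not part of the post's implementation; the distribution and its standard deviation are arbitrary) showing that the analytic entropy of a diagonal Gaussian policy matches the Monte Carlo estimate $-\mathbb{E}_{a \sim \pi}[\log \pi(a|s)]$ used in the objective.

import torch

# A diagonal Gaussian "policy" over a 3-dimensional action (values are arbitrary)
dist = torch.distributions.Normal(torch.zeros(3), torch.full((3,), 0.5))

analytic = dist.entropy().sum()                            # closed-form entropy
samples = dist.sample((10_000,))
monte_carlo = -dist.log_prob(samples).sum(dim=-1).mean()   # -E[log pi(a|s)]

print(analytic.item(), monte_carlo.item())                 # the two agree closely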
In brief, the actor in SAC learns a policy that selects actions to maximize both expected rewards and policy entropy, encouraging exploration. Meanwhile, the critic estimates the Q-values of state-action pairs, which are used to guide the actor's updates. Owing to its structure, SAC can be viewed as a continuous action space counterpart to Double DQN, incorporating similar principles such as target networks and value smoothing to stabilize learning.
In detail, the updates of the components are as follows:
- Actor: the policy is updated to minimize the KL divergence between the policy and the distribution induced by the Q-function, which amounts to minimizing
$$J_{\pi}(\theta) = \mathbb{E}_{s_t \sim \mathcal{D},\, a_t \sim \pi_{\theta}}\left[\alpha \log \pi_{\theta}(a_t|s_t) - \min_{i=1,2} Q_{\phi_i}(s_t,a_t)\right]$$
- Critic: the Q-functions are trained to minimize the soft Bellman residual
$$J_Q(\phi_i) = \mathbb{E}_{(s_t,a_t,r_t,s_{t+1}) \sim \mathcal{D}}\left[\big(Q_{\phi_i}(s_t,a_t) - (r_t + \gamma V_{\hat{\psi}}(s_{t+1}))\big)^2\right]$$ where $\mathcal{D}$ is the replay buffer and $V_{\hat{\psi}}$ comes from the target Q-networks, approximately $V_{\hat{\psi}}(s_{t+1}) \approx \mathbb{E}_{a_{t+1} \sim \pi}\left[\min_{i=1,2} Q_{\hat{\phi}_i}(s_{t+1},a_{t+1}) - \alpha \log \pi_{\theta}(a_{t+1}|s_{t+1})\right]$
- Temperature: learned dynamically to target a desired entropy level
$$J(\alpha) = \mathbb{E}_{a_t \sim \pi}\left[-\alpha \log \pi(a_t|s_t) - \alpha \hat{\mathcal{H}}\right]$$ where $\hat{\mathcal{H}}$ is the target entropy
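Before the full implementation, the sketch below shows how these three objectives translate into PyTorch losses. All tensors are random placeholders standing in for network outputs (for brevity the same log_probs placeholder is reused for the current and next actions); the real version appears later in update_policy.

import torch
import torch.nn.functional as F

alpha, gamma, target_entropy = 0.2, 0.99, -3.0
rewards, dones = torch.rand(4, 1), torch.zeros(4, 1)
q1 = torch.rand(4, 1, requires_grad=True)                 # current Q-estimates
q2 = torch.rand(4, 1, requires_grad=True)
next_q1, next_q2 = torch.rand(4, 1), torch.rand(4, 1)     # from the target networks
log_probs = torch.randn(4, 1, requires_grad=True)         # log pi(a|s) of sampled actions
log_alpha = torch.zeros(1, requires_grad=True)

# Critic: regress both Q-networks onto the soft Bellman target,
# built from the minimum of the two target Q-values
target_q = rewards + (1 - dones) * gamma * (torch.min(next_q1, next_q2) - alpha * log_probs.detach())
critic_loss = F.mse_loss(q1, target_q) + F.mse_loss(q2, target_q)

# Actor: trade the Q-value off against the entropy (log-probability) term
actor_loss = (alpha * log_probs - torch.min(q1, q2)).mean()

# Temperature: drive the policy entropy toward the target entropy
alpha_loss = -(log_alpha * (log_probs + target_entropy).detach()).mean()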
Just as there is no single cure for every disease, SAC comes with its own strengths and limitations.
Pros and Cons
Pros
- Sample Efficiency: off-policy nature allows reuse of past experiences, reducing the number of interactions needed
- Robust Exploration: entropy regularization promotes diverse action selection, preventing premature convergence to suboptimal policies
- Continuous Action Space: well-suited for continuous actions
Cons
- Complexity: maintaining multiple neural networks (two critics, their targets, the actor, and the temperature parameter) increases computational cost and tuning effort
- Limited to Model-Free: as a model-free method, SAC cannot exploit a learnable environment model, so model-based approaches may be more sample efficient when an accurate model can be learned
SAC Implementation
To compare SAC with PPO, we will use the same environment as in the Proximal Policy Optimization (PPO) post so their differences can be seen directly. You can find the full code here; below, we walk through the core components one by one.
Actor Network
The actor (policy) network outputs the mean and log standard deviation of a Gaussian, samples from it with the reparameterization trick so gradients can be backpropagated through the sampling process, and squashes the result with tanh to keep actions within the environment's bounds.
# Imports shared by the code snippets in this post
import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Actor(nn.Module):
    def __init__(self, observation_space, action_space, hidden_dim=256):
        super(Actor, self).__init__()
        self.conv_layers = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.ReLU()
)
conv_out_size = self._get_conv_output_size(observation_space.shape)
self.shared_layers = nn.Sequential(
nn.Linear(conv_out_size, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
self.mean_net = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, action_space.shape[0])
)
self.log_std_net = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, action_space.shape[0])
)
        # Registered as buffers so they move with the module when it is sent to another device
        self.register_buffer("action_scale", torch.FloatTensor((action_space.high - action_space.low) / 2.0))
        self.register_buffer("action_bias", torch.FloatTensor((action_space.high + action_space.low) / 2.0))
self.LOG_STD_MIN = -20
self.LOG_STD_MAX = 2
def _get_conv_output_size(self, input_shape):
"""Calculate the output size of convolutional layers
Size: 1 * 8 * 8 * 64 = 4096
"""
dummy_input = torch.zeros(1, input_shape[2], input_shape[0], input_shape[1])
dummy_output = self.conv_layers(dummy_input)
return int(np.prod(dummy_output.size()))
def forward(self, state):
conv_out = self.conv_layers(state)
conv_out = conv_out.view(conv_out.size(0), -1)
features = self.shared_layers(conv_out)
mean = self.mean_net(features)
log_std = self.log_std_net(features)
log_std = torch.clamp(log_std, self.LOG_STD_MIN, self.LOG_STD_MAX)
return mean, log_std
def sample_action(self, state):
mean, log_std = self.forward(state)
std = log_std.exp()
normal = torch.distributions.Normal(mean, std)
x_t = normal.rsample()
y_t = torch.tanh(x_t)
action = y_t * self.action_scale + self.action_bias
log_prob = normal.log_prob(x_t)
log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
log_prob = log_prob.sum(axis=-1, keepdim=True)
mean = torch.tanh(mean) * self.action_scale + self.action_bias
return action, log_prob, mean
Actor Network - CSY
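The crucial call above is normal.rsample(): the reparameterization trick expresses the sample as mean + std * eps with eps drawn from a standard normal, so gradients flow back into the policy parameters. A tiny standalone check (values are arbitrary):

import torch

mean = torch.zeros(1, requires_grad=True)
log_std = torch.zeros(1, requires_grad=True)
dist = torch.distributions.Normal(mean, log_std.exp())

a = dist.rsample()                  # differentiable sample: mean + std * eps
loss = (torch.tanh(a) ** 2).sum()   # any downstream loss
loss.backward()
print(mean.grad, log_std.grad)      # both gradients are populated

Had we used dist.sample() instead, loss.backward() would fail because the sample is detached from mean and log_std. The torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6) term in sample_action is the change-of-variables correction for the tanh squashing, with the 1e-6 guarding against taking the log of zero.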
Critic Network
The Critic class in SAC adopts a structure similar to Double DQN, implementing two parallel Q-networks to estimate the expected return of a given state-action pair. By maintaining two separate Q-functions and using the minimum of the two estimates when computing the target value, the architecture mitigates overestimation bias.
class Critic(nn.Module):
def __init__(self, observation_space, action_space, hidden_dim=256):
super(Critic, self).__init__()
self.conv_layers = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.ReLU()
)
conv_out_size = self._get_conv_output_size(observation_space.shape)
self.q1_net = nn.Sequential(
nn.Linear(conv_out_size + action_space.shape[0], hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
self.q2_net = nn.Sequential(
nn.Linear(conv_out_size + action_space.shape[0], hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def _get_conv_output_size(self, input_shape):
dummy_input = torch.zeros(1, input_shape[2], input_shape[0], input_shape[1])
dummy_output = self.conv_layers(dummy_input)
return int(np.prod(dummy_output.size()))
def forward(self, state, action):
conv_out = self.conv_layers(state)
conv_out = conv_out.view(conv_out.size(0), -1)
x = torch.cat([conv_out, action], dim=1)
q1 = self.q1_net(x)
q2 = self.q2_net(x)
        return q1, q2

Critic Network - CSY
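As a quick sanity check, the sketch below instantiates the critic with hypothetical Box spaces: 96x96 RGB observations and a 3-dimensional action, assumptions chosen to match the conv output size noted in the Actor's docstring. Both heads return one Q-value per sample, and the target computation later takes their element-wise minimum.

import numpy as np
import torch
from gymnasium.spaces import Box

# Hypothetical spaces for illustration only
obs_space = Box(low=0, high=255, shape=(96, 96, 3), dtype=np.uint8)
act_space = Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

critic = Critic(obs_space, act_space)
states = torch.rand(4, 3, 96, 96)    # batch of preprocessed observations
actions = torch.rand(4, 3)
q1, q2 = critic(states, actions)
print(q1.shape, q2.shape)            # torch.Size([4, 1]) torch.Size([4, 1])
print(torch.min(q1, q2).shape)       # targets use the element-wise minimum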
Replay Buffer
Since SAC is off-policy, we store past transitions in a replay buffer and reuse them during training, which improves both sample efficiency and stability.
class ReplayBuffer:
def __init__(self, capacity, device):
self.device = device
self.capacity = capacity
self.states = []
self.actions = []
self.rewards = []
self.next_states = []
self.dones = []
self.position = 0
def push(self, state, action, reward, next_state, done):
if len(self.states) < self.capacity:
self.states.append(state)
self.actions.append(action)
self.rewards.append(reward)
self.next_states.append(next_state)
self.dones.append(done)
if state.dim() == 4 and state.size(0) == 1:
state = state.squeeze(0)
if next_state.dim() == 4 and next_state.size(0) == 1:
next_state = next_state.squeeze(0)
if action.dim() == 2 and action.size(0) == 1:
action = action.squeeze(0)
self.states[self.position] = state.cpu()
self.actions[self.position] = action.cpu()
self.rewards[self.position] = reward
self.next_states[self.position] = next_state.cpu()
self.dones[self.position] = done
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size):
batch_indices = random.sample(range(len(self.states)), batch_size)
states = torch.stack([self.states[i] for i in batch_indices]).to(self.device)
actions = torch.stack([self.actions[i] for i in batch_indices]).to(self.device)
rewards = torch.FloatTensor([self.rewards[i] for i in batch_indices]).to(self.device)
next_states = torch.stack([self.next_states[i] for i in batch_indices]).to(self.device)
dones = torch.FloatTensor([self.dones[i] for i in batch_indices]).to(self.device)
return states, actions, rewards, next_states, dones
def __len__(self):
return len(self.states)
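A minimal usage sketch with dummy tensors (shapes follow the same 96x96 RGB / 3-dimensional action assumption as above):

import torch

buffer = ReplayBuffer(capacity=1000, device=torch.device("cpu"))
for _ in range(8):
    state = torch.rand(3, 96, 96)        # preprocessed (C, H, W) observation
    next_state = torch.rand(3, 96, 96)
    action = torch.rand(3)               # 3-dimensional continuous action
    buffer.push(state, action, reward=1.0, next_state=next_state, done=False)

states, actions, rewards, next_states, dones = buffer.sample(batch_size=4)
print(states.shape, actions.shape)       # torch.Size([4, 3, 96, 96]) torch.Size([4, 3])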
SAC Agent
Finally, we integrate all the components described above to construct the SAC agent.
class SACAgent:
def __init__(self, env, lr=3e-4, gamma=0.99, tau=0.005, alpha=0.2,
batch_size=256, buffer_size=1000000, auto_entropy_tuning=True):
self.env = env
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.lr = lr
self.gamma = gamma
self.tau = tau
self.batch_size = batch_size
self.auto_entropy_tuning = auto_entropy_tuning
self.actor = Actor(env.observation_space, env.action_space).to(self.device)
self.critic = Critic(env.observation_space, env.action_space).to(self.device)
self.critic_target = Critic(env.observation_space, env.action_space).to(self.device)
self.replay_buffer = ReplayBuffer(buffer_size, self.device)
for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
target_param.data.copy_(param.data)
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)
if auto_entropy_tuning:
self.target_entropy = -torch.prod(torch.Tensor(env.action_space.shape)).item()
self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
self.alpha_optimizer = optim.Adam([self.log_alpha], lr=lr)
else:
self.alpha = alpha
self.episode_rewards = deque(maxlen=100)
self.episode_lengths = deque(maxlen=100)
self.training_metrics = {
'q1_values': deque(maxlen=1000),
'q2_values': deque(maxlen=1000),
'entropy': deque(maxlen=1000),
'policy_loss': deque(maxlen=1000),
'value_loss': deque(maxlen=1000),
'alpha_values': deque(maxlen=1000),
'episode_count': 0
}
@property
def alpha(self):
if self.auto_entropy_tuning:
return self.log_alpha.exp()
else:
return self._alpha
@alpha.setter
def alpha(self, value):
self._alpha = value
def preprocess_state(self, state):
if isinstance(state, tuple):
state = state[0]
state = np.transpose(state, (2, 0, 1))
state = state / 255.0
return torch.FloatTensor(state).to(self.device)
def select_action(self, state, evaluate=False):
if state.dim() == 3:
state = state.unsqueeze(0)
if evaluate:
with torch.no_grad():
_, _, action = self.actor.sample_action(state)
else:
with torch.no_grad():
action, _, _ = self.actor.sample_action(state)
return action
def collect_experience(self, num_steps):
state, _ = self.env.reset()
state = self.preprocess_state(state)
episode_reward = 0
episode_length = 0
for _ in range(num_steps):
action = self.select_action(state)
            # Convert the sampled action to a NumPy array for the environment,
            # dropping the batch dimension if present
            if action.dim() == 1:
                action_for_env = action.cpu().numpy()
            else:
                action_for_env = action.squeeze(0).cpu().numpy()
next_state, reward, done, truncated, _ = self.env.step(action_for_env)
next_state = self.preprocess_state(next_state)
self.replay_buffer.push(state, action, reward, next_state, done or truncated)
episode_reward += reward
episode_length += 1
state = next_state
if done or truncated:
self.episode_rewards.append(episode_reward)
self.episode_lengths.append(episode_length)
state, _ = self.env.reset()
state = self.preprocess_state(state)
episode_reward = 0
episode_length = 0
def update_policy(self):
if len(self.replay_buffer) < self.batch_size:
return {}
states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
with torch.no_grad():
next_actions, next_log_probs, _ = self.actor.sample_action(next_states)
next_q1, next_q2 = self.critic_target(next_states, next_actions)
next_q = torch.min(next_q1, next_q2) - self.alpha * next_log_probs
target_q = rewards.unsqueeze(1) + (1 - dones.unsqueeze(1)) * self.gamma * next_q
current_q1, current_q2 = self.critic(states, actions)
critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
self.training_metrics['q1_values'].extend(current_q1.detach().cpu().numpy().flatten())
self.training_metrics['q2_values'].extend(current_q2.detach().cpu().numpy().flatten())
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
new_actions, log_probs, _ = self.actor.sample_action(states)
q1_new, q2_new = self.critic(states, new_actions)
q_new = torch.min(q1_new, q2_new)
entropy = -log_probs.mean()
self.training_metrics['entropy'].append(entropy.item())
actor_loss = (self.alpha * log_probs - q_new).mean()
self.training_metrics['policy_loss'].append(actor_loss.item())
self.training_metrics['value_loss'].append(critic_loss.item())
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
alpha_loss = None
if self.auto_entropy_tuning:
alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
self.alpha_optimizer.zero_grad()
alpha_loss.backward()
self.alpha_optimizer.step()
current_alpha = self.alpha.item() if hasattr(self.alpha, 'item') else self.alpha
self.training_metrics['alpha_values'].append(current_alpha)
for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
target_param.data.copy_(target_param.data * (1.0 - self.tau) + param.data * self.tau)
return {
'critic_loss': critic_loss.item(),
'actor_loss': actor_loss.item(),
'alpha_loss': alpha_loss.item() if alpha_loss is not None else 0,
'alpha': current_alpha,
'entropy': entropy.item(),
'mean_q1': current_q1.mean().item(),
'mean_q2': current_q2.mean().item(),
'std_q1': current_q1.std().item(),
'std_q2': current_q2.std().item()
        }

SAC Agent - CSY
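Putting everything together, a training loop might look like the sketch below. The environment name, step counts, and update ratio are illustrative assumptions rather than the exact settings of the linked code.

import gymnasium as gym
import numpy as np

# Stand-in environment: an image-based continuous-control task
# (CarRacing-v2 is assumed here; substitute the environment from the PPO post)
env = gym.make("CarRacing-v2", continuous=True)
agent = SACAgent(env)

for iteration in range(100):
    agent.collect_experience(num_steps=1000)   # gather transitions into the replay buffer
    for _ in range(1000):                      # roughly one gradient step per environment step
        metrics = agent.update_policy()
    if agent.episode_rewards:
        print(f"iter {iteration}: mean reward {np.mean(agent.episode_rewards):.1f}")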
Conclusion
SAC remains a cornerstone of RL, much like PPO, and is widely adopted as a baseline in research and experimentation thanks to its robustness and sample efficiency. Recent advances in model-based and offline RL have introduced strong alternatives that challenge its dominance. Nonetheless, SAC and PPO continue to represent foundational milestones of the model-free RL landscape: SAC from the off-policy perspective and PPO from the on-policy perspective.