
SAC In JAX (Personal Notes)

SAC is a well-known, sample-efficient off-policy RL algorithm, and 《动手学强化学习》 (Hands-on Reinforcement Learning) already provides a fairly complete implementation of it. JAX is a newer approach to writing neural networks built on functional programming: parameters are passed around explicitly and updates are pure functions. Using the book's implementation as the template, this post walks through a SAC-in-JAX version, with TensorBoard logging plus model save and load.
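To make the "functional" point concrete before the full SAC code, here is a minimal, self-contained sketch (a toy regression step, not part of the SAC implementation below): parameters live in an explicit pytree, the gradient step is a pure function that returns new parameters, and jax.jit compiles it.

import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # simple linear model with a mean-squared-error loss
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr=1e-2):
    # pure function: takes params, returns updated params (no in-place mutation)
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 3)), jnp.ones((8, 1))
params = sgd_step(params, x, y)  # a new params pytree comes back

The SAC code below follows the same pattern, just with Flax TrainStates instead of a bare dict.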
Install the following in advance: stable_baselines3==2.1.0, jax[cuda12_pip]==0.4.33, flax==0.9.0, tensorboard==2.14.0, tensorflow-probability==0.21.0, protobuf==3.20.3, mujoco==2.3.7. Note that the code below also imports distrax, so install a compatible version of it as well; anything else missing can be added by following the error messages.
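For reference, one way to install those pins in a single command (a sketch assuming a CUDA 12 machine; the extra find-links URL is JAX's public wheel index and is only needed if the CUDA wheels are not found on PyPI):

pip install "stable_baselines3==2.1.0" "flax==0.9.0" "tensorboard==2.14.0" \
    "tensorflow-probability==0.21.0" "protobuf==3.20.3" "mujoco==2.3.7" distrax \
    "jax[cuda12_pip]==0.4.33" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html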
Code:

import os
import jax
# import gym
import flax
import optax
import distrax
import random
import collections
import numpy as np
import flax.serialization
import jax.numpy as jnp
from tqdm import tqdm
import gymnasium as gym
from flax import linen as nn
from functools import partial
from datetime import datetime
from flax.training import train_state
from flax.training.train_state import TrainState
from stable_baselines3.common.logger import configure


# TrainState that additionally carries the target-network parameters.
class RLTrainState(TrainState):  # type: ignore[misc]
    target_params: flax.core.FrozenDict  # type: ignore[misc]


class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        transitions = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), action, reward, np.array(next_state), done

    def size(self):
        return len(self.buffer)


def save_model_state(train_state, path, name, n_steps):
    """Save a single TrainState with flax.serialization."""
    serialized_state = flax.serialization.to_bytes(train_state)
    os.makedirs(path, exist_ok=True)
    extended_path = os.path.join(path, f'{name}_{n_steps}.msgpack')
    with open(extended_path, 'wb') as f:
        f.write(serialized_state)
    print(f"  - saved: {extended_path}")


def load_state(path, name, n_steps, train_state):
    """Load a single TrainState from file with flax.serialization."""
    extended_path = os.path.join(path, f'{name}_{n_steps}.msgpack')
    with open(extended_path, 'rb') as f:
        train_state_loaded = f.read()
    return flax.serialization.from_bytes(train_state, train_state_loaded)


# Learnable log of the entropy coefficient alpha.
class EntropyCoef(nn.Module):
    ent_coef_init: float = 1.0

    @nn.compact
    def __call__(self, step) -> jnp.ndarray:
        log_ent_coef = self.param("log_ent_coef",
                                  init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)))
        return log_ent_coef


class Critic(nn.Module):
    obs_dim: int
    action_dim: int
    hidden_dim: int

    @nn.compact
    def __call__(self, obs, action):
        cat = jnp.concatenate([obs, action], axis=1)
        x = nn.Dense(self.hidden_dim)(cat)
        x = nn.relu(x)
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        # x = nn.relu(x)
        return x


# Stacks n_critics independent Critic networks via nn.vmap (the twin-Q trick).
class VectorCritic(nn.Module):
    obs_dim: int
    action_dim: int
    hidden_dim: int
    n_critics: int

    @nn.compact
    def __call__(self, obs, action):
        vmap_critic = nn.vmap(
            Critic,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=None,
            out_axes=0,
            axis_size=self.n_critics,
        )
        q_values = vmap_critic(
            obs_dim=self.obs_dim,
            action_dim=self.action_dim,
            hidden_dim=self.hidden_dim,
        )(obs, action)
        return q_values


class Actor(nn.Module):
    obs_dim: int
    action_dim: int
    hidden_dim: int
    action_scale: float

    @nn.compact
    def __call__(self, obs):
        x = nn.Dense(self.hidden_dim)(obs)
        x = nn.relu(x)
        mu = nn.Dense(self.action_dim)(x)
        std = nn.Dense(self.action_dim)(x)
        return mu, nn.softplus(std)

    # Sample a tanh-squashed Gaussian action and apply the change-of-variables
    # correction to its log-probability.
    @staticmethod
    @partial(jax.jit, static_argnames=["action_scale"])
    def sample_action(params, key, obs, actor_state, action_scale):
        mu, std = actor_state.apply_fn({"params": params}, obs)
        dist = distrax.Normal(loc=mu, scale=std)
        # tanh_dist = distrax.Transformed(dist, distrax.Block(distrax.Tanh(), ndims=1))
        # action = tanh_dist.sample(seed=key)
        # log_prob = tanh_dist.log_prob(action).sum(axis=-1)
        action = dist.sample(seed=key)
        log_prob = dist.log_prob(action)
        action = jnp.tanh(action)
        # log_prob = log_prob - jnp.log(1 - jnp.square(jnp.tanh(action)) + 1e-7)
        log_prob = log_prob - jnp.log(1 - jnp.square(action) + 1e-7)
        return action * action_scale, log_prob


class SAC:
    def __init__(self, obs_dim, action_dim, hidden_dim, batch_size,
                 actor_lr, critic_lr, alpha_lr,
                 tau=0.005, gamma=0.99, action_scale=1, target_entropy=0.01, train_alpha=False,
                 save_path=" ", base_name=" "):
        self.obs_dim, self.action_dim, self.hidden_dim = obs_dim, action_dim, hidden_dim
        self.batch_size = batch_size
        self.actor_lr, self.critic_lr, self.alpha_lr = actor_lr, critic_lr, alpha_lr
        self.tau, self.gamma, self.action_scale = tau, gamma, action_scale
        self.train_alpha = train_alpha
        self.save_path = save_path
        self.base_name = base_name
        self.actor = Actor(self.obs_dim, self.action_dim, self.hidden_dim, self.action_scale)
        self.critic = VectorCritic(self.obs_dim, self.action_dim, self.hidden_dim, n_critics=2)
        self.log_alpha = EntropyCoef(0.01)
        self.target_entropy = target_entropy
        self.key = jax.random.PRNGKey(0)
        self.key, actor_key, critic_key, alpha_key = jax.random.split(self.key, 4)
        actor_params = self.actor.init(actor_key, jnp.ones((self.batch_size, obs_dim)))['params']
        critic_params = self.critic.init(critic_key, jnp.ones((self.batch_size, obs_dim)),
                                         jnp.ones((self.batch_size, action_dim)))['params']
        critic_target_params = self.critic.init(critic_key, jnp.ones((self.batch_size, obs_dim)),
                                                jnp.ones((self.batch_size, action_dim)))['params']
        alpha_params = self.log_alpha.init(alpha_key, 0.0)['params']
        actor_optx = optax.adam(actor_lr)
        critic_optx = optax.adam(critic_lr)
        alpha_optx = optax.adam(alpha_lr)
        self.actor_model_state = train_state.TrainState.create(
            apply_fn=self.actor.apply, params=actor_params, tx=actor_optx)
        self.critic_model_state = RLTrainState.create(
            apply_fn=self.critic.apply, params=critic_params,
            target_params=critic_target_params, tx=critic_optx)
        self.alpha_model_state = train_state.TrainState.create(
            apply_fn=self.log_alpha.apply, params=alpha_params, tx=alpha_optx)

    def take_action(self, state):
        self.key, actor_key = jax.random.split(self.key, 2)
        obs = jnp.array([state])
        action, _ = Actor.sample_action(self.actor_model_state.params, actor_key, obs,
                                        self.actor_model_state, self.action_scale)
        return action[0]

    def update(self, transition_dict):
        (self.actor_model_state, self.critic_model_state, self.alpha_model_state, self.key), metrics = \
            self._train_step(self.actor_model_state, self.critic_model_state,
                             self.alpha_model_state, self.key, transition_dict,
                             self.action_scale, self.gamma, self.tau,
                             self.target_entropy, self.train_alpha)
        return metrics

    @staticmethod
    @partial(jax.jit, static_argnames=["action_scale", "gamma", "tau", "target_entropy", "train_alpha"])
    def _train_step(actor_model_state, critic_model_state, alpha_model_state, key, transition,
                    action_scale, gamma, tau, target_entropy, train_alpha):
        states = jnp.array(transition['states'])
        actions = jnp.array(transition['actions'])
        rewards = jnp.array(transition['rewards']).reshape(-1, 1)
        next_states = jnp.array(transition['next_states'])
        dones = jnp.array(transition['dones']).reshape(-1, 1)
        # rewards = (rewards + 8.0) / 8.0
        critic_loss, q1_loss, q2_loss, critic_model_state, key = SAC.update_critic(
            states, actions, rewards, next_states, dones,
            actor_model_state, critic_model_state, alpha_model_state, action_scale, gamma, key)
        actor_loss, actor_model_state, key = SAC.update_actor(
            states, actor_model_state, critic_model_state, alpha_model_state, action_scale, key)
        if train_alpha:
            alpha_loss, alpha_model_state, key = SAC.update_alpha(
                states, actor_model_state, alpha_model_state, action_scale, target_entropy, key)
        critic_model_state = SAC.soft_update(tau, critic_model_state)
        metrics = {
            "critic_loss": critic_loss,
            "actor_loss": actor_loss,
            "alpha_loss": alpha_loss if train_alpha else 0,
        }
        return (actor_model_state, critic_model_state, alpha_model_state, key), metrics

    @staticmethod
    @partial(jax.jit, static_argnames=["action_scale", "gamma"])
    def update_critic(states, actions, rewards, next_states, dones,
                      actor_model_state, critic_model_state, alpha_model_state, action_scale, gamma, key):
        def loss_fn(params):
            def calc_target(rewards, next_states, dones, key):  # compute the TD target
                now_key, actor_key, critic_key = jax.random.split(key, 3)
                next_actions, log_prob = Actor.sample_action(
                    actor_model_state.params, actor_key, next_states, actor_model_state, action_scale)
                entropy = -log_prob
                q_value = critic_model_state.apply_fn(
                    {"params": critic_model_state.target_params}, next_states, next_actions)
                log_alpha = alpha_model_state.apply_fn({"params": alpha_model_state.params}, 0)
                log_alpha = jax.lax.stop_gradient(log_alpha)
                # log_alpha = jnp.log(0.01)
                q1_value, q2_value = q_value[0], q_value[1]
                next_value = jax.lax.stop_gradient(
                    jnp.min(jnp.stack([q1_value, q2_value], axis=0), axis=0) + jnp.exp(log_alpha) * entropy)
                td_target = rewards + gamma * next_value * (1 - dones)
                return td_target, now_key

            td_target, now_key = calc_target(rewards, next_states, dones, key)
            current_q = critic_model_state.apply_fn({"params": params}, states, actions)
            current_q1, current_q2 = current_q[0], current_q[1]
            q1_loss = jnp.mean(jnp.square(td_target - current_q1))
            q2_loss = jnp.mean(jnp.square(td_target - current_q2))
            critic_loss = q1_loss + q2_loss
            return critic_loss, (q1_loss, q2_loss, now_key)

        (critic_loss, (q1_loss, q2_loss, now_key)), grads = jax.value_and_grad(
            loss_fn, has_aux=True)(critic_model_state.params)
        critic_model_state = critic_model_state.apply_gradients(grads=grads)
        return critic_loss, q1_loss, q2_loss, critic_model_state, now_key

    @staticmethod
    @partial(jax.jit, static_argnames=["action_scale"])
    def update_actor(states, actor_model_state, critic_model_state, alpha_model_state, action_scale, key):
        def loss_fn(params):
            now_key, actor_key = jax.random.split(key, 2)
            next_actions, log_prob = Actor.sample_action(params, actor_key, states,
                                                         actor_model_state, action_scale)
            entropy = -log_prob
            q_value = critic_model_state.apply_fn({"params": critic_model_state.params}, states, next_actions)
            log_alpha = alpha_model_state.apply_fn({"params": alpha_model_state.params}, 0)
            log_alpha = jax.lax.stop_gradient(log_alpha)
            # log_alpha = jnp.log(0.01)
            q1_value, q2_value = q_value[0], q_value[1]
            actor_loss = jnp.mean(-jnp.exp(log_alpha) * entropy
                                  - jnp.min(jnp.stack([q1_value, q2_value], axis=0), axis=0))
            return actor_loss, now_key

        (actor_loss, now_key), grads = jax.value_and_grad(loss_fn, has_aux=True)(actor_model_state.params)
        actor_model_state = actor_model_state.apply_gradients(grads=grads)
        return actor_loss, actor_model_state, now_key

    @staticmethod
    @partial(jax.jit, static_argnames=["action_scale", "target_entropy"])
    def update_alpha(states, actor_model_state, alpha_model_state, action_scale, target_entropy, key):
        def loss_fn(params):
            now_key, actor_key = jax.random.split(key, 2)
            next_actions, log_prob = Actor.sample_action(actor_model_state.params, actor_key, states,
                                                         actor_model_state, action_scale)
            entropy = -log_prob
            log_alpha = alpha_model_state.apply_fn({"params": params}, 0)
            alpha_loss = jnp.mean(jax.lax.stop_gradient(entropy - target_entropy) * jnp.exp(log_alpha))
            return alpha_loss, now_key

        (alpha_loss, now_key), grads = jax.value_and_grad(loss_fn, has_aux=True)(alpha_model_state.params)
        alpha_model_state = alpha_model_state.apply_gradients(grads=grads)
        return alpha_loss, alpha_model_state, now_key

    @staticmethod
    @partial(jax.jit, static_argnames=["tau"])
    def soft_update(tau, model_state):
        model_state = model_state.replace(
            target_params=optax.incremental_update(model_state.params, model_state.target_params, tau))
        return model_state

    def save(self, n_steps):
        print(f"Saving model to {self.save_path} ...")
        save_model_state(self.actor_model_state, self.save_path, f"{self.base_name}_actor", n_steps)
        save_model_state(self.critic_model_state, self.save_path, f"{self.base_name}_critic", n_steps)
        save_model_state(self.alpha_model_state, self.save_path, f"{self.base_name}_alpha", n_steps)

    def load(self, n_steps):
        print(f"Loading model from {self.save_path} ...")
        self.actor_model_state = load_state(self.save_path, f"{self.base_name}_actor", n_steps, self.actor_model_state)
        self.critic_model_state = load_state(self.save_path, f"{self.base_name}_critic", n_steps, self.critic_model_state)
        self.alpha_model_state = load_state(self.save_path, f"{self.base_name}_alpha", n_steps, self.alpha_model_state)
        print("Model loaded.")


def train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size, logger):
    return_list = []
    total_steps = 0
    for i in range(10):
        with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes / 10)):
                episode_return = 0
                state, _ = env.reset()
                done = False
                while not done:
                    # print(state)
                    action = agent.take_action(state)
                    next_state, reward, terminated, truncated, info = env.step(action)
                    done = terminated or truncated
                    replay_buffer.add(state, action, reward, next_state, done)
                    state = next_state
                    episode_return += reward
                    total_steps += 1
                    if replay_buffer.size() > minimal_size:
                        b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                        transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns,
                                           'rewards': b_r, 'dones': b_d}
                        metrics = agent.update(transition_dict)
                return_list.append(episode_return)
                if replay_buffer.size() > minimal_size:
                    if (i_episode + 1) % 5 == 0:
                        metrics_to_log = {
                            "return": episode_return,
                            **{f"loss/{k}": v for k, v in metrics.items()}  # Add a prefix to loss names
                        }
                        # write the contents of metrics_to_log to the logger
                        for key, value in metrics_to_log.items():
                            logger.record(key, value)
                        logger.dump(step=total_steps)
                # if (i_episode + 1) % 10 == 0:
                #     agent.save(total_steps)
                if (i_episode + 1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode + 1),
                                      'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)
    return return_list


# env_name = 'Pendulum-v1'
env_name = "Walker2d-v4"
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_scale = env.action_space.high[0]
random.seed(0)
# exit(0)
actor_lr = 3e-4
critic_lr = 7e-4
alpha_lr = 3e-4
num_episodes = 20000
hidden_dim = 256
gamma = 0.99
tau = 0.005  # soft-update coefficient
buffer_size = 1000000
minimal_size = 10000
batch_size = 256
train_alpha = True
target_entropy = -env.action_space.shape[0]
start_time = datetime.now().strftime('%Y%m%d_%H%M%S')
# start_time = "20250910_132225"
# steps = 10263
log_path = f"logs/sac_{env_name}_{start_time}/"
logger = configure(log_path, ["stdout", "tensorboard"])
replay_buffer = ReplayBuffer(buffer_size)
model_save_path = "logs/models"
model_base_name = f"sac_{env_name}_{start_time}"
agent = SAC(obs_dim=state_dim, action_dim=action_dim, hidden_dim=hidden_dim, batch_size=batch_size, actor_lr=actor_lr, critic_lr=critic_lr, alpha_lr=alpha_lr, tau=tau, gamma=gamma, action_scale=action_scale, target_entropy=target_entropy, train_alpha=train_alpha, save_path=model_save_path, base_name=model_base_name)
# agent.load(steps)
return_list = train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size, logger)
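After training, one way to exercise the save/load path and sanity-check the policy is roughly the following (a sketch that reuses the agent, env_name, and hyperparameters defined above; final_steps is just an illustrative tag used to name the checkpoint files):

# Illustrative only: tag the final checkpoint, reload it, and roll out one episode
# with the stochastic policy.
final_steps = len(return_list)   # any integer works; it only names the .msgpack files
agent.save(final_steps)          # writes {base_name}_actor/_critic/_alpha_{final_steps}.msgpack
agent.load(final_steps)          # restores the three TrainStates from those files

eval_env = gym.make(env_name)
obs, _ = eval_env.reset(seed=0)
done, episode_return = False, 0.0
while not done:
    action = agent.take_action(obs)
    obs, reward, terminated, truncated, _ = eval_env.step(np.asarray(action))
    episode_return += reward
    done = terminated or truncated
print("evaluation return after reload:", episode_return)

# The training curves written by the stable_baselines3 logger live under log_path
# and can be viewed with:  tensorboard --logdir logs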