算法 · 2026-05-04 · 12 分钟
算法知识库:强化学习算法实现
使用 JavaScript/TypeScript 实现强化学习基础算法,如 Q-Learning、SARSA、策略梯度等。
强化学习算法实现
1. Q-Learning 算法
ts
/**
 * Tabular Q-Learning agent with an epsilon-greedy behaviour policy.
 *
 * Q-values live in a nested map (state -> action -> value); pairs that were
 * never written read as 0. After every `update` call the exploration rate is
 * multiplied by `explorationDecay` and floored at `minExplorationRate`.
 *
 * The hyper-parameter fields are `protected` rather than `private` so that
 * subclasses overriding `update` can read and decay them.
 */
class QLearning {
  private qTable: Map<string, Map<string, number>> = new Map();
  protected learningRate: number;
  protected discountFactor: number;
  protected explorationRate: number;
  protected minExplorationRate: number;
  protected explorationDecay: number;

  /**
   * @param learningRate Step size applied to the TD error.
   * @param discountFactor Weight of the bootstrapped future value.
   * @param explorationRate Initial probability of picking a random action.
   * @param minExplorationRate Lower bound the exploration rate decays towards.
   * @param explorationDecay Multiplier applied to the exploration rate per update.
   */
  constructor(
    learningRate: number = 0.1,
    discountFactor: number = 0.9,
    explorationRate: number = 1.0,
    minExplorationRate: number = 0.01,
    explorationDecay: number = 0.995,
  ) {
    this.learningRate = learningRate;
    this.discountFactor = discountFactor;
    this.explorationRate = explorationRate;
    this.minExplorationRate = minExplorationRate;
    this.explorationDecay = explorationDecay;
  }

  /**
   * Returns the stored Q-value for (state, action), or 0 when none was stored.
   * Pure read: it does not create table entries for unseen states.
   */
  getQValue(state: string, action: string): number {
    return this.qTable.get(state)?.get(action) ?? 0;
  }

  /** Stores `value` for (state, action), creating the state's row on demand. */
  setQValue(state: string, action: string, value: number): void {
    let row = this.qTable.get(state);
    if (row === undefined) {
      row = new Map();
      this.qTable.set(state, row);
    }
    row.set(action, value);
  }

  /**
   * Picks an action epsilon-greedily: with probability `explorationRate` a
   * uniformly random entry of `actions`, otherwise the entry with the highest
   * Q-value (the earliest one wins ties).
   */
  chooseAction(state: string, actions: string[]): string {
    if (Math.random() < this.explorationRate) {
      // Explore: uniformly random action.
      return actions[Math.floor(Math.random() * actions.length)];
    }
    // Exploit: action with the largest Q-value.
    let bestAction = actions[0];
    let bestQValue = this.getQValue(state, bestAction);
    for (const action of actions) {
      const qValue = this.getQValue(state, action);
      if (qValue > bestQValue) {
        bestQValue = qValue;
        bestAction = action;
      }
    }
    return bestAction;
  }

  /**
   * Applies one Q-Learning step,
   * Q(s,a) += learningRate * (reward + discountFactor * max_a' Q(s',a') - Q(s,a)),
   * then decays the exploration rate.
   *
   * @param nextActions Actions available in `nextState`; when empty the
   *   bootstrapped future value is 0.
   */
  update(state: string, action: string, reward: number, nextState: string, nextActions: string[]): void {
    const currentQ = this.getQValue(state, action);

    // Largest Q-value reachable from the next state.
    let maxNextQ = -Infinity;
    for (const nextAction of nextActions) {
      const nextQ = this.getQValue(nextState, nextAction);
      if (nextQ > maxNextQ) {
        maxNextQ = nextQ;
      }
    }
    if (maxNextQ === -Infinity) maxNextQ = 0;

    const newQ = currentQ + this.learningRate * (reward + this.discountFactor * maxNextQ - currentQ);
    this.setQValue(state, action, newQ);

    // Decay exploration, never going below the configured floor.
    this.explorationRate = Math.max(this.minExplorationRate, this.explorationRate * this.explorationDecay);
  }

  /**
   * Builds the greedy policy: for every state that has at least one stored
   * Q-value, the action whose stored value is highest.
   */
  getPolicy(): Map<string, string> {
    const policy = new Map<string, string>();
    for (const [state, actions] of this.qTable) {
      // `undefined` (not '') marks "nothing found" so an action named '' is kept.
      let bestAction: string | undefined;
      let bestQValue = -Infinity;
      for (const [action, qValue] of actions) {
        if (qValue > bestQValue) {
          bestQValue = qValue;
          bestAction = action;
        }
      }
      if (bestAction !== undefined) {
        policy.set(state, bestAction);
      }
    }
    return policy;
  }
}
// Section heading preserved from SOURCE: "2. SARSA 算法"
ts
/**
 * SARSA agent: same table and epsilon-greedy policy as `QLearning`, but the
 * update bootstraps from the Q-value of one concrete next action instead of
 * the maximum over all next actions.
 */
class SARSA extends QLearning {
  /**
   * Applies one SARSA step,
   * Q(s,a) += learningRate * (reward + discountFactor * Q(s',a') - Q(s,a)),
   * then decays the exploration rate.
   *
   * @param nextActions Actions available in `nextState`. When empty (and no
   *   `nextAction` is given) the bootstrapped value is 0.
   * @param nextAction Optional: the action the caller will actually take in
   *   `nextState`. When omitted, an action is drawn with `chooseAction`, which
   *   may differ from the one the caller later executes.
   */
  update(
    state: string,
    action: string,
    reward: number,
    nextState: string,
    nextActions: string[],
    nextAction?: string,
  ): void {
    const currentQ = this.getQValue(state, action);

    // Value of the follow-up action; stays 0 when there is nothing to bootstrap from.
    let nextQ = 0;
    if (nextAction !== undefined) {
      nextQ = this.getQValue(nextState, nextAction);
    } else if (nextActions.length > 0) {
      nextQ = this.getQValue(nextState, this.chooseAction(nextState, nextActions));
    }

    const newQ = currentQ + this.learningRate * (reward + this.discountFactor * nextQ - currentQ);
    this.setQValue(state, action, newQ);

    // Decay exploration, never going below the configured floor.
    this.explorationRate = Math.max(this.minExplorationRate, this.explorationRate * this.explorationDecay);
  }
}
// Section heading preserved from SOURCE: "3. 策略梯度算法 (REINFORCE)"
ts
/**
 * Tabular REINFORCE-style policy with a moving-average baseline.
 *
 * The policy is stored as explicit per-state action probabilities
 * (state -> action -> probability). A state's distribution is created as
 * uniform over the `actions` passed the first time the state is seen.
 */
class PolicyGradient {
  /** Floor applied to every probability so the distribution stays valid. */
  private static readonly MIN_PROBABILITY = 1e-6;

  private policy: Map<string, Map<string, number>> = new Map(); // state -> action -> probability
  private learningRate: number;
  private discountFactor: number;
  private baseline: number = 0;

  /**
   * @param learningRate Step size of the probability update.
   * @param discountFactor Discount used when computing returns.
   */
  constructor(learningRate: number = 0.01, discountFactor: number = 0.99) {
    this.learningRate = learningRate;
    this.discountFactor = discountFactor;
  }

  /**
   * Returns a copy of the action distribution for `state`, initialising it to
   * uniform over `actions` if the state has not been seen before. Mutating the
   * returned map does not affect the policy.
   */
  getActionProbabilities(state: string, actions: string[]): Map<string, number> {
    let probabilities = this.policy.get(state);
    if (probabilities === undefined) {
      // First visit: uniform distribution.
      probabilities = new Map<string, number>();
      const uniform = 1 / actions.length;
      for (const action of actions) {
        probabilities.set(action, uniform);
      }
      this.policy.set(state, probabilities);
    }
    return new Map(probabilities);
  }

  /**
   * Samples an action from the state's distribution. Falls back to
   * `actions[0]` if the cumulative sum never reaches the drawn number.
   */
  chooseAction(state: string, actions: string[]): string {
    const probabilities = this.getActionProbabilities(state, actions);
    const random = Math.random();
    let cumulative = 0;
    for (const [action, prob] of probabilities) {
      cumulative += prob;
      if (random <= cumulative) {
        return action;
      }
    }
    return actions[0];
  }

  /**
   * Updates the policy from one episode.
   *
   * Computes the discounted return of every step, folds the episode's mean
   * return into the baseline (0.9 old / 0.1 new), then shifts probability
   * towards actions whose return beat the baseline and away from the rest.
   * An empty trajectory is a no-op.
   *
   * @throws Error if a step references a state/action pair that has no
   *   probability yet (i.e. it was never passed through
   *   `getActionProbabilities` / `chooseAction`). Nothing is modified then.
   */
  update(trajectory: Array<{ state: string; action: string; reward: number }>): void {
    // An empty episode would make the mean return 0/0 = NaN and poison the baseline.
    if (trajectory.length === 0) return;

    // Validate first so a bad step cannot leave the policy half-updated.
    for (const { state, action } of trajectory) {
      if (this.policy.get(state)?.has(action) !== true) {
        throw new Error(`PolicyGradient.update: no probability for action "${action}" in state "${state}"`);
      }
    }

    // Discounted return of each time step, computed back to front.
    const returns: number[] = new Array<number>(trajectory.length);
    let G = 0;
    for (let t = trajectory.length - 1; t >= 0; t -= 1) {
      G = trajectory[t].reward + this.discountFactor * G;
      returns[t] = G;
    }

    // Moving-average baseline.
    const avgReturn = returns.reduce((sum, r) => sum + r, 0) / returns.length;
    this.baseline = 0.9 * this.baseline + 0.1 * avgReturn;

    for (let t = 0; t < trajectory.length; t += 1) {
      const { state, action } = trajectory[t];
      const advantage = returns[t] - this.baseline;
      const probabilities = this.policy.get(state)!; // presence checked above
      const currentProb = probabilities.get(action)!; // presence checked above

      // Raise the taken action's probability when advantage > 0, lower it otherwise;
      // the opposite change is spread evenly over the remaining actions.
      const step = (this.learningRate * (advantage * (1 - currentProb))) / currentProb;
      const otherCount = probabilities.size - 1;
      for (const [act, prob] of probabilities) {
        const next = act === action ? prob + step : prob - step / otherCount;
        // Clamp so a large step can never produce a zero or negative probability.
        probabilities.set(act, Math.max(PolicyGradient.MIN_PROBABILITY, next));
      }

      this.normalizeProbabilities(probabilities);
    }
  }

  /** Rescales the map in place so its values sum to 1. */
  private normalizeProbabilities(probabilities: Map<string, number>): void {
    const total = Array.from(probabilities.values()).reduce((sum, prob) => sum + prob, 0);
    for (const [action, prob] of probabilities) {
      probabilities.set(action, prob / total);
    }
  }
}
// Section heading preserved from SOURCE: "4. Q-Learning with 经验回放"
ts
/** One recorded transition, with the fields that `QLearning.update` takes as arguments. */
interface Experience {
/** State the action was taken in. */
state: string;
/** Action that was taken. */
action: string;
/** Reward recorded for this transition. */
reward: number;
/** State reached after the action. */
nextState: string;
/** Actions available in `nextState`. */
nextActions: string[];
}
/**
 * Q-Learning agent with an experience-replay buffer.
 *
 * Transitions are stored in a fixed-capacity ring buffer (the oldest entry is
 * overwritten once the buffer is full) and replayed in uniformly sampled
 * batches through the inherited `update` method.
 */
class QLearningWithReplay extends QLearning {
  private replayBuffer: Experience[] = [];
  private bufferSize: number;
  private batchSize: number;
  /** Slot that the next insertion overwrites once the buffer is full. */
  private nextWriteIndex = 0;

  /**
   * @param bufferSize Maximum number of stored transitions.
   * @param batchSize Number of transitions replayed per `trainOnBatch` call.
   *   The first five parameters are forwarded unchanged to `QLearning`.
   */
  constructor(
    learningRate: number = 0.1,
    discountFactor: number = 0.9,
    explorationRate: number = 1.0,
    minExplorationRate: number = 0.01,
    explorationDecay: number = 0.995,
    bufferSize: number = 10000,
    batchSize: number = 32,
  ) {
    super(learningRate, discountFactor, explorationRate, minExplorationRate, explorationDecay);
    this.bufferSize = bufferSize;
    this.batchSize = batchSize;
  }

  /** Stores a transition, evicting the oldest one when the buffer is full. */
  addExperience(experience: Experience): void {
    if (this.bufferSize <= 0) return; // zero capacity: nothing can be kept
    if (this.replayBuffer.length < this.bufferSize) {
      this.replayBuffer.push(experience);
      return;
    }
    // Full: overwrite the oldest slot in O(1) instead of shifting the whole array.
    this.replayBuffer[this.nextWriteIndex] = experience;
    this.nextWriteIndex = (this.nextWriteIndex + 1) % this.replayBuffer.length;
  }

  /**
   * Replays one batch of distinct, uniformly sampled transitions.
   *
   * Does nothing until enough transitions are stored. The batch is capped at
   * the buffer capacity, so a `batchSize` larger than `bufferSize` trains on
   * the whole buffer instead of never training. Each replayed transition goes
   * through `update`, which also decays the exploration rate.
   */
  trainOnBatch(): void {
    const target = Math.min(this.batchSize, this.bufferSize);
    if (target <= 0 || this.replayBuffer.length < target) return;

    for (const experience of this.sampleBatch(target)) {
      const { state, action, reward, nextState, nextActions } = experience;
      this.update(state, action, reward, nextState, nextActions);
    }
  }

  /**
   * Draws `count` distinct transitions (at most the number stored) using
   * Floyd's sampling algorithm: exactly one random draw per sample, no retries.
   */
  private sampleBatch(count: number = this.batchSize): Experience[] {
    const stored = this.replayBuffer.length;
    const wanted = Math.max(0, Math.min(Math.floor(count), stored));
    const chosen = new Set<number>();
    for (let upper = stored - wanted; upper < stored; upper += 1) {
      const candidate = Math.floor(Math.random() * (upper + 1));
      // If the candidate is already taken, `upper` itself is guaranteed to be free.
      chosen.add(chosen.has(candidate) ? upper : candidate);
    }
    return Array.from(chosen, (index) => this.replayBuffer[index]);
  }
}
// Section heading preserved from SOURCE: "5. 多臂赌博机 (ε-贪婪算法)"
ts
/**
 * Multi-armed bandit solver using the epsilon-greedy rule.
 *
 * Keeps, per arm, a pull count and the running mean of the rewards passed to
 * `update`.
 */
class EpsilonGreedyBandit {
  private actionValues: number[];
  private actionCounts: number[];
  private epsilon: number;

  /**
   * @param numActions Number of arms; both tables start filled with zeros.
   * @param epsilon Probability of picking a random arm in `chooseAction`.
   */
  constructor(numActions: number, epsilon: number = 0.1) {
    this.epsilon = epsilon;
    this.actionCounts = new Array(numActions).fill(0);
    this.actionValues = new Array(numActions).fill(0);
  }

  /**
   * Returns an arm index: a uniformly random one with probability `epsilon`,
   * otherwise the arm with the highest estimated value (lowest index on ties).
   */
  chooseAction(): number {
    const exploring = Math.random() < this.epsilon;
    if (exploring) {
      // Explore: any arm, uniformly.
      return Math.floor(Math.random() * this.actionValues.length);
    }
    // Exploit: index of the first strictly-largest estimate.
    return this.actionValues.reduce(
      (bestIndex, value, index, values) => (value > values[bestIndex] ? index : bestIndex),
      0,
    );
  }

  /** Records one pull of `action` and folds `reward` into that arm's running mean. */
  update(action: number, reward: number): void {
    const pulls = (this.actionCounts[action] += 1);
    // Incremental mean: new = old + (reward - old) / n.
    this.actionValues[action] += (reward - this.actionValues[action]) / pulls;
  }

  /** Returns a copy of the current value estimates. */
  getActionValues(): number[] {
    return this.actionValues.slice();
  }
}
// Section heading preserved from SOURCE: "6. 演员-评论家算法 (Actor-Critic)"
ts
/**
 * One-step tabular Actor-Critic.
 *
 * The critic is a state-value table updated with the TD error; the actor is a
 * table of per-state action probabilities nudged by the same TD error.
 *
 * The actor table is owned by this class and updated in place. Updating a
 * copy of the distribution would throw the policy change away on every step.
 */
class ActorCritic {
  /** Floor applied to an updated probability so the distribution stays valid. */
  private static readonly MIN_PROBABILITY = 1e-6;

  private policy: Map<string, Map<string, number>> = new Map(); // state -> action -> probability
  private critic: Map<string, number> = new Map(); // state-value function
  private learningRateActor: number;
  private learningRateCritic: number;
  private discountFactor: number;

  /**
   * @param learningRateActor Step size of the policy update.
   * @param learningRateCritic Step size of the value update.
   * @param discountFactor Weight of the next state's value in the TD target.
   */
  constructor(learningRateActor: number = 0.01, learningRateCritic: number = 0.1, discountFactor: number = 0.9) {
    this.learningRateActor = learningRateActor;
    this.learningRateCritic = learningRateCritic;
    this.discountFactor = discountFactor;
  }

  /**
   * Samples an action from the state's distribution, initialising it to
   * uniform over `actions` on first visit. Falls back to `actions[0]` if the
   * cumulative sum never reaches the drawn number.
   */
  chooseAction(state: string, actions: string[]): string {
    const probabilities = this.probabilitiesFor(state, actions);
    const random = Math.random();
    let cumulative = 0;
    for (const [action, prob] of probabilities) {
      cumulative += prob;
      if (random <= cumulative) {
        return action;
      }
    }
    return actions[0];
  }

  /**
   * Learns from one transition.
   *
   * Computes tdError = reward + discountFactor * V(nextState) - V(state),
   * moves V(state) towards the target, then changes the probability of
   * `action` in `state` in the direction of the TD error and renormalises.
   * If `state` has no distribution yet it is created containing only
   * `action`; if the state's distribution does not contain `action`, only the
   * critic is updated.
   */
  update(state: string, action: string, reward: number, nextState: string): void {
    const currentValue = this.critic.get(state) ?? 0;
    const nextValue = this.critic.get(nextState) ?? 0;

    // TD error drives both updates.
    const tdError = reward + this.discountFactor * nextValue - currentValue;

    // Critic (value function).
    this.critic.set(state, currentValue + this.learningRateCritic * tdError);

    // Actor (policy): operate on the stored map so the change persists.
    const probabilities = this.probabilitiesFor(state, [action]);
    const currentProb = probabilities.get(action);
    if (currentProb === undefined) return;

    const gradient = (tdError * (1 - currentProb)) / currentProb;
    const updated = currentProb + this.learningRateActor * gradient;
    // Clamp so a large negative step cannot produce a zero or negative probability.
    probabilities.set(action, Math.max(ActorCritic.MIN_PROBABILITY, updated));

    // Renormalise so the probabilities sum to 1 again.
    const total = Array.from(probabilities.values()).reduce((sum, prob) => sum + prob, 0);
    for (const [act, prob] of probabilities) {
      probabilities.set(act, prob / total);
    }
  }

  /** Returns the live distribution for `state`, creating a uniform one over `actions` if missing. */
  private probabilitiesFor(state: string, actions: string[]): Map<string, number> {
    let probabilities = this.policy.get(state);
    if (probabilities === undefined) {
      probabilities = new Map<string, number>();
      const uniform = 1 / actions.length;
      for (const action of actions) {
        probabilities.set(action, uniform);
      }
      this.policy.set(state, probabilities);
    }
    return probabilities;
  }
}
// Section heading preserved from SOURCE: "7. 实现要点"
- Q-Learning 使用 Q 表学习最优策略。
- SARSA 是同策略的时序差分学习。
- 策略梯度直接优化策略。
- 经验回放提高样本效率。
- Actor-Critic 结合策略和价值学习。
算法 · 强化学习 · JavaScript