SorryToPerson logo
返回
算法 · 2026-05-04 · 12 分钟

算法知识库:强化学习算法实现

用 JavaScript/TypeScript 实现强化学习的基础算法，如 Q-Learning、SARSA、策略梯度等。

强化学习算法实现

1. Q-Learning 算法

ts
/**
 * Tabular Q-Learning agent with an ε-greedy behaviour policy.
 *
 * Q-values live in a nested map `state -> action -> value`; a pair that has
 * never been written reads as 0. After every `update` call the exploration
 * rate ε is multiplied by `explorationDecay`, floored at `minExplorationRate`.
 *
 * The table and hyper-parameters are `protected` rather than `private`:
 * subclasses such as `SARSA` implement their own update rule on top of them,
 * and TypeScript does not let a subclass read `private` members.
 */
class QLearning {
  protected qTable: Map<string, Map<string, number>> = new Map();
  protected learningRate: number;
  protected discountFactor: number;
  protected explorationRate: number;
  protected minExplorationRate: number;
  protected explorationDecay: number;

  /**
   * @param learningRate α — step size of each temporal-difference update.
   * @param discountFactor γ — weight given to the estimated future value.
   * @param explorationRate initial ε for ε-greedy action selection.
   * @param minExplorationRate lower bound that ε never decays below.
   * @param explorationDecay factor ε is multiplied by after each `update`.
   */
  constructor(
    learningRate: number = 0.1,
    discountFactor: number = 0.9,
    explorationRate: number = 1.0,
    minExplorationRate: number = 0.01,
    explorationDecay: number = 0.995,
  ) {
    this.learningRate = learningRate;
    this.discountFactor = discountFactor;
    this.explorationRate = explorationRate;
    this.minExplorationRate = minExplorationRate;
    this.explorationDecay = explorationDecay;
  }

  /**
   * Returns Q(state, action), or 0 for a pair that was never set.
   * This is a pure read: looking up a state does not add an empty row for it.
   */
  getQValue(state: string, action: string): number {
    return this.qTable.get(state)?.get(action) ?? 0;
  }

  /** Stores Q(state, action) = value, creating the state's row on first use. */
  setQValue(state: string, action: string, value: number): void {
    let row = this.qTable.get(state);
    if (row === undefined) {
      row = new Map();
      this.qTable.set(state, row);
    }
    row.set(action, value);
  }

  /**
   * ε-greedy selection. With probability ε it returns a uniformly random
   * action. Otherwise it returns the action with the highest Q-value, with
   * ties going to the earliest entry in `actions`.
   *
   * `actions` should be non-empty. For an empty list the result is
   * `undefined` at runtime despite the `string` return type.
   */
  chooseAction(state: string, actions: string[]): string {
    if (Math.random() < this.explorationRate) {
      // Explore: uniform random action.
      return actions[Math.floor(Math.random() * actions.length)];
    }

    // Exploit: arg-max over the current Q estimates.
    let bestAction = actions[0];
    let bestQValue = this.getQValue(state, bestAction);
    for (const action of actions) {
      const qValue = this.getQValue(state, action);
      if (qValue > bestQValue) {
        bestQValue = qValue;
        bestAction = action;
      }
    }
    return bestAction;
  }

  /**
   * Applies one Q-Learning update for the transition (s, a, r, s'):
   *   Q(s, a) ← Q(s, a) + α · (r + γ · max_a' Q(s', a') − Q(s, a))
   * An empty `nextActions` list contributes 0 for the future term.
   * Every call also decays the exploration rate.
   */
  update(state: string, action: string, reward: number, nextState: string, nextActions: string[]): void {
    const currentQ = this.getQValue(state, action);

    // Best estimated value reachable from the next state.
    let maxNextQ = -Infinity;
    for (const nextAction of nextActions) {
      const nextQ = this.getQValue(nextState, nextAction);
      if (nextQ > maxNextQ) {
        maxNextQ = nextQ;
      }
    }
    if (maxNextQ === -Infinity) maxNextQ = 0;

    const newQ = currentQ + this.learningRate * (reward + this.discountFactor * maxNextQ - currentQ);
    this.setQValue(state, action, newQ);

    this.explorationRate = Math.max(this.minExplorationRate, this.explorationRate * this.explorationDecay);
  }

  /** Greedy policy: for each state with stored Q-values, the action with the highest value. */
  getPolicy(): Map<string, string> {
    const policy = new Map<string, string>();

    for (const [state, actionValues] of this.qTable) {
      let bestAction: string | undefined;
      let bestQValue = -Infinity;

      for (const [action, qValue] of actionValues) {
        if (qValue > bestQValue) {
          bestQValue = qValue;
          bestAction = action;
        }
      }

      // `undefined` (not '') is the "none found" sentinel, so an action named '' is still reported.
      if (bestAction !== undefined) {
        policy.set(state, bestAction);
      }
    }

    return policy;
  }
}

2. SARSA 算法

ts
/**
 * SARSA: on-policy temporal-difference control.
 *
 * It reuses the Q-table, ε-greedy `chooseAction` and hyper-parameters of
 * `QLearning`. The difference is the bootstrap term: it uses Q(s', a') for the
 * next action a' the agent actually takes, not max_a Q(s', a).
 */
class SARSA extends QLearning {
  /**
   * Applies one SARSA update for the transition (s, a, r, s', a'):
   *   Q(s, a) ← Q(s, a) + α · (r + γ · Q(s', a') − Q(s, a))
   * Every call also decays the exploration rate.
   *
   * @param state current state s.
   * @param action action a taken in s.
   * @param reward reward r received.
   * @param nextState successor state s'.
   * @param nextActions actions available in s'. An empty list makes the
   *   bootstrap term 0.
   * @param nextAction the action a' the agent will actually execute in s'.
   *   Pass it to get true on-policy SARSA. When omitted, a' is sampled here
   *   with the ε-greedy policy. That sampled action is not returned, so the
   *   agent may execute a different one and the update is then no longer
   *   on-policy.
   */
  override update(
    state: string,
    action: string,
    reward: number,
    nextState: string,
    nextActions: string[],
    nextAction?: string,
  ): void {
    const currentQ = this.getQValue(state, action);

    // Nothing to bootstrap from when no action is available in s'.
    let nextQ = 0;
    if (nextActions.length > 0) {
      const followUp = nextAction ?? this.chooseAction(nextState, nextActions);
      nextQ = this.getQValue(nextState, followUp);
    }

    const newQ = currentQ + this.learningRate * (reward + this.discountFactor * nextQ - currentQ);
    this.setQValue(state, action, newQ);

    this.explorationRate = Math.max(this.minExplorationRate, this.explorationRate * this.explorationDecay);
  }
}

3. 策略梯度算法 (REINFORCE)

ts
/**
 * REINFORCE (Monte-Carlo policy gradient) with a tabular softmax policy and a
 * running-average baseline.
 *
 * Each state keeps one preference θ(s, a) per action, and the action
 * probabilities are π(a|s) = softmax(θ(s, ·)). The update changes the
 * preferences rather than the probabilities themselves. This keeps π a valid
 * distribution (every probability in [0, 1], summing to 1) for any learning
 * rate or reward scale.
 */
class PolicyGradient {
  /** state -> action -> preference θ(s, a). */
  private preferences: Map<string, Map<string, number>> = new Map();
  private learningRate: number;
  private discountFactor: number;
  private baseline: number = 0;

  /**
   * @param learningRate step size α for the preference updates.
   * @param discountFactor γ used when computing returns (defaults to 0.99).
   */
  constructor(learningRate: number = 0.01, discountFactor: number = 0.99) {
    this.learningRate = learningRate;
    this.discountFactor = discountFactor;
  }

  /**
   * Returns π(·|state) over every action known for `state`.
   *
   * Actions in `actions` not yet known for `state` are registered with
   * preference 0, so a fresh state starts out uniform. The returned map is a
   * fresh copy; mutating it does not change the policy.
   */
  getActionProbabilities(state: string, actions: string[]): Map<string, number> {
    return this.softmax(this.ensureState(state, actions));
  }

  /** Samples an action from π(·|state). */
  chooseAction(state: string, actions: string[]): string {
    const probabilities = this.getActionProbabilities(state, actions);
    const random = Math.random();
    let cumulative = 0;

    for (const [action, prob] of probabilities) {
      cumulative += prob;
      if (random <= cumulative) {
        return action;
      }
    }

    // Reached only if floating-point round-off leaves `cumulative` just below `random`.
    return actions[0];
  }

  /**
   * Performs one REINFORCE update from a complete episode.
   *
   * For each step t, with advantage A_t = G_t − baseline, every action b of
   * s_t is adjusted as
   *   θ(s_t, b) += α · A_t · (1[b = a_t] − π(b|s_t)),
   * which is α · A_t · ∇ log π(a_t|s_t) for a softmax policy.
   */
  update(trajectory: Array<{ state: string; action: string; reward: number }>): void {
    // An empty episode carries no signal. Without this guard the mean return
    // would be 0/0 = NaN and would poison the baseline permanently.
    if (trajectory.length === 0) return;

    // Discounted returns G_t = r_t + γ · G_{t+1}, computed backwards.
    const returns = new Array<number>(trajectory.length);
    let G = 0;
    for (let t = trajectory.length - 1; t >= 0; t -= 1) {
      G = trajectory[t].reward + this.discountFactor * G;
      returns[t] = G;
    }

    // Exponential moving average of the episode's mean return, used as baseline.
    const avgReturn = returns.reduce((sum, r) => sum + r, 0) / returns.length;
    this.baseline = 0.9 * this.baseline + 0.1 * avgReturn;

    for (let t = 0; t < trajectory.length; t += 1) {
      const { state, action } = trajectory[t];
      const advantage = returns[t] - this.baseline;

      // Register the state/action if they were never seen via chooseAction.
      const prefs = this.ensureState(state, [action]);
      const probabilities = this.softmax(prefs);

      for (const [candidate, prob] of probabilities) {
        const indicator = candidate === action ? 1 : 0;
        prefs.set(candidate, (prefs.get(candidate) ?? 0) + this.learningRate * advantage * (indicator - prob));
      }
    }
  }

  /** Returns the live preference map for `state`, adding any missing actions at 0. */
  private ensureState(state: string, actions: string[]): Map<string, number> {
    let prefs = this.preferences.get(state);
    if (prefs === undefined) {
      prefs = new Map();
      this.preferences.set(state, prefs);
    }
    for (const action of actions) {
      if (!prefs.has(action)) prefs.set(action, 0);
    }
    return prefs;
  }

  /** Numerically stable softmax (subtracts the max preference before exponentiating). */
  private softmax(prefs: Map<string, number>): Map<string, number> {
    const maxPref = Math.max(...prefs.values());
    const result = new Map<string, number>();
    let total = 0;
    for (const [action, pref] of prefs) {
      const weight = Math.exp(pref - maxPref);
      result.set(action, weight);
      total += weight;
    }
    for (const [action, weight] of result) {
      result.set(action, weight / total);
    }
    return result;
  }
}

4. 带经验回放的 Q-Learning

ts
/**
 * A single recorded transition (s, a, r, s') kept in a replay buffer and fed
 * back into a Q-Learning style `update` when replayed.
 */
interface Experience {
  /** State s in which the action was taken. */
  state: string;
  /** Action a that was executed in `state`. */
  action: string;
  /** Reward r observed after executing `action`. */
  reward: number;
  /** Successor state s' reached after the action. */
  nextState: string;
  /** Actions available in `nextState`; an empty list yields no future-value term in `QLearning.update`. */
  nextActions: string[];
}

/**
 * Q-Learning that trains on uniformly sampled mini-batches of stored
 * transitions.
 *
 * The replay buffer has a fixed capacity of `bufferSize`. Once it is full,
 * each new experience overwrites the oldest one in O(1) (ring buffer). The
 * previous `Array.prototype.shift()` approach re-indexed the whole buffer on
 * every insertion.
 *
 * NOTE(review): every replayed transition goes through `this.update`, and the
 * base `QLearning.update` also decays the exploration rate. One `trainOnBatch`
 * call therefore decays ε `batchSize` times; tune `explorationDecay`
 * accordingly.
 */
class QLearningWithReplay extends QLearning {
  private replayBuffer: Experience[] = [];
  /** Slot holding the oldest experience, i.e. the next one to overwrite once the buffer is full. */
  private nextWriteIndex = 0;
  private bufferSize: number;
  private batchSize: number;

  /**
   * The first five parameters are forwarded to `QLearning`.
   *
   * @param bufferSize maximum number of stored experiences.
   * @param batchSize experiences replayed per `trainOnBatch` call.
   * @throws RangeError if `bufferSize` or `batchSize` is not a positive integer,
   *   or if `batchSize` exceeds `bufferSize` (training could then never start).
   */
  constructor(
    learningRate: number = 0.1,
    discountFactor: number = 0.9,
    explorationRate: number = 1.0,
    minExplorationRate: number = 0.01,
    explorationDecay: number = 0.995,
    bufferSize: number = 10000,
    batchSize: number = 32,
  ) {
    super(learningRate, discountFactor, explorationRate, minExplorationRate, explorationDecay);
    if (!Number.isInteger(bufferSize) || bufferSize < 1) {
      throw new RangeError(`bufferSize must be a positive integer, got ${bufferSize}`);
    }
    if (!Number.isInteger(batchSize) || batchSize < 1) {
      throw new RangeError(`batchSize must be a positive integer, got ${batchSize}`);
    }
    if (batchSize > bufferSize) {
      throw new RangeError(`batchSize (${batchSize}) must not exceed bufferSize (${bufferSize})`);
    }
    this.bufferSize = bufferSize;
    this.batchSize = batchSize;
  }

  /** Stores an experience, evicting the oldest one when the buffer is full. */
  addExperience(experience: Experience): void {
    if (this.replayBuffer.length < this.bufferSize) {
      this.replayBuffer.push(experience);
      return;
    }
    this.replayBuffer[this.nextWriteIndex] = experience;
    this.nextWriteIndex = (this.nextWriteIndex + 1) % this.bufferSize;
  }

  /** Replays one random mini-batch; does nothing until at least `batchSize` experiences are stored. */
  trainOnBatch(): void {
    if (this.replayBuffer.length < this.batchSize) return;

    for (const { state, action, reward, nextState, nextActions } of this.sampleBatch()) {
      this.update(state, action, reward, nextState, nextActions);
    }
  }

  /**
   * Draws `batchSize` distinct experiences uniformly at random (rejection
   * sampling on indices). Callers guarantee the buffer holds at least
   * `batchSize` items, so the loop terminates.
   */
  private sampleBatch(): Experience[] {
    const batch: Experience[] = [];
    const indices = new Set<number>();

    while (batch.length < this.batchSize) {
      const index = Math.floor(Math.random() * this.replayBuffer.length);
      if (!indices.has(index)) {
        indices.add(index);
        batch.push(this.replayBuffer[index]);
      }
    }

    return batch;
  }
}

5. 多臂赌博机 (ε-贪婪算法)

ts
/**
 * Multi-armed bandit agent with ε-greedy arm selection and sample-average
 * value estimates.
 */
class EpsilonGreedyBandit {
  private actionValues: number[];
  private actionCounts: number[];
  private epsilon: number;

  /**
   * @param numActions number of arms; must be a positive integer.
   * @param epsilon probability of picking a uniformly random arm instead of the best one.
   * @throws RangeError if `numActions` is not a positive integer.
   */
  constructor(numActions: number, epsilon: number = 0.1) {
    // With zero arms every later call would index empty arrays.
    if (!Number.isInteger(numActions) || numActions < 1) {
      throw new RangeError(`numActions must be a positive integer, got ${numActions}`);
    }
    this.actionValues = new Array<number>(numActions).fill(0);
    this.actionCounts = new Array<number>(numActions).fill(0);
    this.epsilon = epsilon;
  }

  /** Returns an arm index: random with probability ε, otherwise the highest estimate (first one on ties). */
  chooseAction(): number {
    if (Math.random() < this.epsilon) {
      // Explore.
      return Math.floor(Math.random() * this.actionValues.length);
    }

    // Exploit.
    let bestAction = 0;
    let bestValue = this.actionValues[0];
    for (let i = 1; i < this.actionValues.length; i += 1) {
      if (this.actionValues[i] > bestValue) {
        bestValue = this.actionValues[i];
        bestAction = i;
      }
    }
    return bestAction;
  }

  /**
   * Folds `reward` into the running mean for `action`:
   *   Q_n = Q_{n-1} + (reward − Q_{n-1}) / n
   *
   * @throws RangeError if `action` is not a valid arm index. Without this check
   *   an out-of-range index would silently write NaN into the estimates.
   */
  update(action: number, reward: number): void {
    if (!Number.isInteger(action) || action < 0 || action >= this.actionValues.length) {
      throw new RangeError(`action must be an integer in [0, ${this.actionValues.length}), got ${action}`);
    }

    this.actionCounts[action] += 1;
    const n = this.actionCounts[action];
    this.actionValues[action] += (reward - this.actionValues[action]) / n;
  }

  /** Returns a copy of the current value estimates, indexed by arm. */
  getActionValues(): number[] {
    return [...this.actionValues];
  }
}

6. 演员-评论家算法 (Actor-Critic)

ts
/**
 * One-step (TD(0)) actor-critic with a tabular softmax actor and a tabular
 * state-value critic.
 *
 * The actor keeps its own preferences θ(s, a) and updates them in place. The
 * earlier design went through `PolicyGradient.getActionProbabilities`, which
 * returns a copy, so its actor update never changed the policy.
 */
class ActorCritic {
  /** state -> action -> preference θ(s, a); π(·|s) = softmax(θ(s, ·)). */
  private actorPreferences: Map<string, Map<string, number>> = new Map();
  /** state -> estimated value V(s); unseen states read as 0. */
  private critic: Map<string, number> = new Map();
  private learningRateActor: number;
  private learningRateCritic: number;
  private discountFactor: number;

  /**
   * @param learningRateActor step size for the policy preferences.
   * @param learningRateCritic step size for the state-value estimates.
   * @param discountFactor γ used in the TD target.
   */
  constructor(learningRateActor: number = 0.01, learningRateCritic: number = 0.1, discountFactor: number = 0.9) {
    this.learningRateActor = learningRateActor;
    this.learningRateCritic = learningRateCritic;
    this.discountFactor = discountFactor;
  }

  /**
   * Returns π(·|state) over every action known for `state`. Unknown actions
   * from `actions` are registered with preference 0, so a fresh state is
   * uniform. The returned map is a copy.
   */
  getActionProbabilities(state: string, actions: string[]): Map<string, number> {
    return this.softmax(this.ensureState(state, actions));
  }

  /** Returns the critic's current estimate V(state), 0 if never updated. */
  getStateValue(state: string): number {
    return this.critic.get(state) ?? 0;
  }

  /** Samples an action from the actor's softmax policy. */
  chooseAction(state: string, actions: string[]): string {
    const probabilities = this.getActionProbabilities(state, actions);
    const random = Math.random();
    let cumulative = 0;
    for (const [action, prob] of probabilities) {
      cumulative += prob;
      if (random <= cumulative) {
        return action;
      }
    }
    // Reached only if floating-point round-off leaves `cumulative` just below `random`.
    return actions[0];
  }

  /**
   * One actor-critic step for the transition (s, a, r, s').
   *
   *   δ = r + γ · V(s') − V(s)                     (V(s') taken as 0 when `done`)
   *   V(s) += α_critic · δ
   *   θ(s, b) += α_actor · δ · (1[b = a] − π(b|s))   for every known action b
   *
   * @param done pass `true` when s' is terminal so its value is not bootstrapped.
   *   Defaults to `false`, which matches the previous behaviour.
   */
  update(state: string, action: string, reward: number, nextState: string, done: boolean = false): void {
    const currentValue = this.getStateValue(state);
    const nextValue = done ? 0 : this.getStateValue(nextState);

    // TD error: how much better the outcome was than the critic expected.
    const tdError = reward + this.discountFactor * nextValue - currentValue;

    // Critic update.
    this.critic.set(state, currentValue + this.learningRateCritic * tdError);

    // Actor update: softmax policy-gradient step scaled by the TD error.
    const prefs = this.ensureState(state, [action]);
    const probabilities = this.softmax(prefs);
    for (const [candidate, prob] of probabilities) {
      const indicator = candidate === action ? 1 : 0;
      prefs.set(candidate, (prefs.get(candidate) ?? 0) + this.learningRateActor * tdError * (indicator - prob));
    }
  }

  /** Returns the live preference map for `state`, adding any missing actions at 0. */
  private ensureState(state: string, actions: string[]): Map<string, number> {
    let prefs = this.actorPreferences.get(state);
    if (prefs === undefined) {
      prefs = new Map();
      this.actorPreferences.set(state, prefs);
    }
    for (const action of actions) {
      if (!prefs.has(action)) prefs.set(action, 0);
    }
    return prefs;
  }

  /** Numerically stable softmax (subtracts the max preference before exponentiating). */
  private softmax(prefs: Map<string, number>): Map<string, number> {
    const maxPref = Math.max(...prefs.values());
    const result = new Map<string, number>();
    let total = 0;
    for (const [action, pref] of prefs) {
      const weight = Math.exp(pref - maxPref);
      result.set(action, weight);
      total += weight;
    }
    for (const [action, weight] of result) {
      result.set(action, weight / total);
    }
    return result;
  }
}

7. 实现要点

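  • Q-Learning 是异策略（off-policy）的时序差分方法：用 Q 表估计动作价值，以下一状态的最大 Q 值作为更新目标，从而学习最优策略。
  • SARSA 是同策略（on-policy）的时序差分方法：更新目标使用智能体在下一状态实际采取的动作的 Q 值。
  • 策略梯度（REINFORCE）直接优化参数化策略：用回报减去基线得到的优势来调整动作概率，基线用于降低方差。
  • 经验回放把历史转移存入缓冲区并随机采样复用，既提高样本效率，又减弱相邻样本之间的相关性。
  • Actor-Critic 结合策略学习（演员）与价值学习（评论家）：评论家估计状态价值并给出 TD 误差，演员据此调整策略。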
算法 · 强化学习 · JavaScript