SorryToPerson logo
返回
算法 · 2026-04-22 · 10 分钟

算法知识库:机器学习基础算法实现

使用 JavaScript/TypeScript 实现机器学习基础算法,包括线性回归、K-Means 聚类和决策树。

机器学习基础算法实现

1. 线性回归

ts
/**
 * Linear regression (`y ≈ w · x + b`) trained with full-batch gradient descent.
 *
 * Every epoch averages the prediction error over all samples and moves the
 * weights and bias against that average gradient, which minimises the mean
 * squared error (the constant factor is folded into the learning rate).
 */
class LinearRegression {
  /** One weight per feature; empty until `fit` has been called. */
  private weights: number[] = [];
  private bias = 0;
  private learningRate: number;
  private epochs: number;

  /**
   * @param learningRate Step size applied to the averaged gradient.
   * @param epochs Number of full passes over the training data.
   */
  constructor(learningRate: number = 0.01, epochs: number = 1000) {
    this.learningRate = learningRate;
    this.epochs = epochs;
  }

  /**
   * Trains the model from scratch on the given samples.
   *
   * Any previously learned weights and bias are discarded, so repeated calls
   * are independent of each other.
   *
   * @param X Training samples, one row per sample, all rows of equal length.
   * @param y Target value for each row of `X`.
   * @throws Error if `X` is empty, if `y` has a different length than `X`,
   *   or if the rows of `X` do not all have the same number of features.
   */
  fit(X: number[][], y: number[]): void {
    const n = X.length;
    if (n === 0) {
      throw new Error('LinearRegression.fit: X must contain at least one sample');
    }
    if (y.length !== n) {
      throw new Error(
        `LinearRegression.fit: X has ${n} samples but y has ${y.length} targets`,
      );
    }

    const m = X[0].length;
    if (X.some((row) => row.length !== m)) {
      throw new Error('LinearRegression.fit: all rows of X must have the same length');
    }

    // Reset ALL learned state; leaving the old bias in place would make a
    // second fit start from a different point than a fresh model.
    this.weights = new Array<number>(m).fill(0);
    this.bias = 0;

    for (let epoch = 0; epoch < this.epochs; epoch += 1) {
      const dw = new Array<number>(m).fill(0);
      let db = 0;

      // Accumulate the gradient over the whole batch before updating.
      for (let i = 0; i < n; i += 1) {
        const error = this.predictSingle(X[i]) - y[i];

        for (let j = 0; j < m; j += 1) {
          dw[j] += error * X[i][j];
        }
        db += error;
      }

      for (let j = 0; j < m; j += 1) {
        this.weights[j] -= (this.learningRate * dw[j]) / n;
      }
      this.bias -= (this.learningRate * db) / n;
    }
  }

  /**
   * Predicts a target value for every row of `X`.
   *
   * @param X Samples with the same number of features used in `fit`.
   * @returns One prediction per row, in input order.
   */
  predict(X: number[][]): number[] {
    return X.map((row) => this.predictSingle(row));
  }

  /** Evaluates `w · x + b` for a single sample. */
  private predictSingle(x: number[]): number {
    let result = this.bias;
    for (let i = 0; i < x.length; i += 1) {
      result += this.weights[i] * x[i];
    }
    return result;
  }
}

2. K-Means 聚类

ts
/**
 * K-Means clustering (Lloyd's algorithm) with random initial centroids drawn
 * from distinct input points.
 */
class KMeans {
  private centroids: number[][] = [];
  private k: number;
  private maxIterations: number;

  /**
   * @param k Number of clusters to form.
   * @param maxIterations Upper bound on assign/update rounds per `fit`.
   */
  constructor(k: number, maxIterations: number = 100) {
    this.k = k;
    this.maxIterations = maxIterations;
  }

  /**
   * Clusters the given points and returns the cluster index of each point.
   *
   * @param X Points to cluster, one row per point.
   * @returns For each row of `X`, the index (`0..k-1`) of its cluster.
   * @throws RangeError if `k` is not a positive integer, or if `k` exceeds the
   *   number of points (there are not enough distinct points to seed from).
   */
  fit(X: number[][]): number[] {
    const n = X.length;
    if (!Number.isInteger(this.k) || this.k < 1) {
      throw new RangeError(`KMeans.fit: k must be a positive integer, got ${this.k}`);
    }
    if (this.k > n) {
      throw new RangeError(
        `KMeans.fit: k (${this.k}) cannot exceed the number of points (${n})`,
      );
    }

    // Randomly initialise the centroids from k distinct points.
    this.centroids = this.pickInitialCentroids(X);

    let labels = new Array<number>(n).fill(0);

    for (let iter = 0; iter < this.maxIterations; iter += 1) {
      // Assign every point to its nearest centroid.
      const newLabels = X.map((point) => this.closestCentroid(point));

      // Convergence check. Skipped on the first round because `labels` is
      // still the all-zero placeholder there; comparing against it would stop
      // before the centroids were ever updated (e.g. always for k = 1).
      if (iter > 0 && this.arraysEqual(labels, newLabels)) break;
      labels = newLabels;

      // Move each centroid to the mean of its assigned points.
      this.updateCentroids(X, labels);
    }

    return labels;
  }

  /**
   * Draws `k` distinct rows of `X` uniformly at random (partial Fisher–Yates
   * shuffle) and returns copies of them. Requires `k <= X.length`.
   */
  private pickInitialCentroids(X: number[][]): number[][] {
    const n = X.length;
    const order = Array.from({ length: n }, (_, i) => i);

    for (let i = 0; i < this.k; i += 1) {
      const j = i + Math.floor(Math.random() * (n - i));
      [order[i], order[j]] = [order[j], order[i]];
    }

    // Copy the rows so later centroid updates never write into the caller's data.
    return order.slice(0, this.k).map((idx) => [...X[idx]]);
  }

  /** Returns the index of the centroid nearest to `point` (lowest index wins ties). */
  private closestCentroid(point: number[]): number {
    let minDist = Infinity;
    let closest = 0;

    for (let i = 0; i < this.centroids.length; i += 1) {
      const dist = this.euclideanDistance(point, this.centroids[i]);
      if (dist < minDist) {
        minDist = dist;
        closest = i;
      }
    }

    return closest;
  }

  /**
   * Replaces each centroid with the mean of the points labelled with it.
   * A centroid with no assigned points keeps its previous position.
   */
  private updateCentroids(X: number[][], labels: number[]): void {
    const counts = new Array<number>(this.k).fill(0);
    const sums = this.centroids.map((centroid) =>
      new Array<number>(centroid.length).fill(0),
    );

    for (let i = 0; i < X.length; i += 1) {
      const label = labels[i];
      counts[label] += 1;
      for (let j = 0; j < X[i].length; j += 1) {
        sums[label][j] += X[i][j];
      }
    }

    for (let i = 0; i < this.k; i += 1) {
      if (counts[i] > 0) {
        for (let j = 0; j < sums[i].length; j += 1) {
          this.centroids[i][j] = sums[i][j] / counts[i];
        }
      }
    }
  }

  /** Euclidean (L2) distance between two points of equal dimension. */
  private euclideanDistance(a: number[], b: number[]): number {
    let sum = 0;
    for (let i = 0; i < a.length; i += 1) {
      sum += (a[i] - b[i]) ** 2;
    }
    return Math.sqrt(sum);
  }

  /** Element-wise equality of two label arrays. */
  private arraysEqual(a: number[], b: number[]): boolean {
    return a.length === b.length && a.every((val, idx) => val === b[idx]);
  }
}

3. 决策树

ts
/** A training sample: its feature vector paired with its class label. */
interface DataPoint {
  features: number[];
  label: number;
}

/**
 * Node of a binary decision tree.
 *
 * Internal nodes carry `feature`, `threshold`, `left` and `right`; samples with
 * `features[feature] <= threshold` go left, the rest go right. Leaf nodes have
 * `isLeaf === true` and carry the predicted label in `value`.
 */
class DecisionTreeNode {
  feature?: number;
  threshold?: number;
  left?: DecisionTreeNode;
  right?: DecisionTreeNode;
  value?: number;
  isLeaf = false;
}

/**
 * Classification tree that greedily picks, at every node, the axis-aligned
 * split with the lowest weighted Gini impurity.
 */
class DecisionTree {
  private root?: DecisionTreeNode;
  private maxDepth: number;

  /** @param maxDepth Maximum number of splits on any root-to-leaf path. */
  constructor(maxDepth: number = 10) {
    this.maxDepth = maxDepth;
  }

  /**
   * Builds the tree from the training data, replacing any earlier tree.
   *
   * @param X Training samples, one row per sample.
   * @param y Class label for each row of `X`.
   * @throws Error if `X` is empty or `y` has a different length than `X`.
   */
  fit(X: number[][], y: number[]): void {
    if (X.length === 0) {
      throw new Error('DecisionTree.fit: X must contain at least one sample');
    }
    if (y.length !== X.length) {
      throw new Error(
        `DecisionTree.fit: X has ${X.length} samples but y has ${y.length} labels`,
      );
    }

    const data: DataPoint[] = X.map((features, i) => ({ features, label: y[i] }));
    this.root = this.buildTree(data, 0);
  }

  /**
   * Predicts a label for every row of `X`.
   *
   * @returns One label per row; `0` for every row if `fit` was never called.
   */
  predict(X: number[][]): number[] {
    return X.map((row) => this.predictSingle(row));
  }

  /** Recursively grows the subtree for `data`, which sits at `depth`. */
  private buildTree(data: DataPoint[], depth: number): DecisionTreeNode {
    if (data.length === 0 || depth >= this.maxDepth || this.isPure(data)) {
      return this.makeLeaf(data);
    }

    const split = this.findBestSplit(data);
    if (split === undefined) {
      // Impure, but no feature can separate the samples (e.g. identical
      // feature vectors with conflicting labels): settle for a majority leaf.
      return this.makeLeaf(data);
    }

    const { feature, threshold } = split;
    const leftData = data.filter((point) => point.features[feature] <= threshold);
    const rightData = data.filter((point) => point.features[feature] > threshold);

    // Defensive: a side can still come out empty for non-comparable values
    // such as NaN. Recursing on it would make no progress, so stop here.
    if (leftData.length === 0 || rightData.length === 0) {
      return this.makeLeaf(data);
    }

    const node = new DecisionTreeNode();
    node.feature = feature;
    node.threshold = threshold;
    node.left = this.buildTree(leftData, depth + 1);
    node.right = this.buildTree(rightData, depth + 1);

    return node;
  }

  /** Creates a leaf that predicts the majority label of `data`. */
  private makeLeaf(data: DataPoint[]): DecisionTreeNode {
    const node = new DecisionTreeNode();
    node.isLeaf = true;
    node.value = this.majorityVote(data);
    return node;
  }

  /**
   * Searches all features for the threshold with the lowest weighted Gini
   * impurity. Candidates are the midpoints of consecutive sorted values.
   *
   * @returns The best split, or `undefined` when no candidate puts at least
   *   one sample on each side.
   */
  private findBestSplit(
    data: DataPoint[],
  ): { feature: number; threshold: number } | undefined {
    let bestGini = Infinity;
    let best: { feature: number; threshold: number } | undefined;

    for (let feature = 0; feature < data[0].features.length; feature += 1) {
      const values = data.map((point) => point.features[feature]).sort((a, b) => a - b);
      const maxValue = values[values.length - 1];
      let previousThreshold: number | undefined;

      for (let i = 1; i < values.length; i += 1) {
        const threshold = (values[i - 1] + values[i]) / 2;

        // A threshold at or above the maximum sends every sample left; such a
        // "split" is useless and would leave an empty right child.
        if (threshold >= maxValue) continue;
        // Runs of duplicate values repeat the same threshold; its impurity was
        // already evaluated and cannot be a strict improvement.
        if (threshold === previousThreshold) continue;
        previousThreshold = threshold;

        const gini = this.calculateGini(data, feature, threshold);
        if (gini < bestGini) {
          bestGini = gini;
          best = { feature, threshold };
        }
      }
    }

    return best;
  }

  /** Size-weighted average Gini impurity of the two sides of a split. */
  private calculateGini(data: DataPoint[], feature: number, threshold: number): number {
    const left = data.filter((point) => point.features[feature] <= threshold);
    const right = data.filter((point) => point.features[feature] > threshold);

    const leftGini = this.giniImpurity(left);
    const rightGini = this.giniImpurity(right);

    return (left.length * leftGini + right.length * rightGini) / data.length;
  }

  /** Gini impurity `1 - Σ pᵢ²` of the labels in `data`; `0` for an empty set. */
  private giniImpurity(data: DataPoint[]): number {
    if (data.length === 0) return 0;

    const labelCounts = new Map<number, number>();
    for (const point of data) {
      labelCounts.set(point.label, (labelCounts.get(point.label) ?? 0) + 1);
    }

    let impurity = 1;
    for (const count of labelCounts.values()) {
      const p = count / data.length;
      impurity -= p * p;
    }

    return impurity;
  }

  /** True when all samples share one label (vacuously true for an empty set). */
  private isPure(data: DataPoint[]): boolean {
    return data.every((point) => point.label === data[0].label);
  }

  /**
   * Most frequent label in `data`. Ties go to the label seen first; an empty
   * set yields `0`.
   */
  private majorityVote(data: DataPoint[]): number {
    const counts = new Map<number, number>();
    for (const point of data) {
      counts.set(point.label, (counts.get(point.label) ?? 0) + 1);
    }

    let maxCount = 0;
    let majorityLabel = 0;
    for (const [label, count] of counts) {
      if (count > maxCount) {
        maxCount = count;
        majorityLabel = label;
      }
    }

    return majorityLabel;
  }

  /** Walks from the root to a leaf for one sample and returns the leaf's label. */
  private predictSingle(features: number[]): number {
    let node = this.root;
    while (node !== undefined && !node.isLeaf) {
      const { feature, threshold } = node;
      // Internal nodes built by buildTree always carry both fields.
      if (feature === undefined || threshold === undefined) break;
      node = features[feature] <= threshold ? node.left : node.right;
    }
    return node?.value ?? 0;
  }
}

4. 实现要点

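  • 线性回归:使用批量梯度下降最小化均方误差,每轮用全部样本的平均梯度更新权重和偏置。
  • K-Means:交替执行"将样本分配到最近的质心"和"用簇内均值更新质心",直到标签不再变化或达到最大迭代次数。
  • 决策树:在每个节点上以各特征相邻取值的中点作为候选阈值,选择加权基尼不纯度最小的分割。
  • 以上均为教学用的简化实现;实际应用中还需要输入校验、特征缩放、更好的初始化(如 K-Means++)和剪枝等优化。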
算法 · 机器学习 · JavaScript