算法 · 2026-04-22 · 10 分钟
算法知识库:机器学习基础算法实现
使用 JavaScript/TypeScript 实现机器学习基础算法,包括线性回归、K-Means 聚类和决策树。
机器学习基础算法实现
1. 线性回归
ts
class LinearRegression {
/** One learned coefficient per feature column; empty until fit() has run. */
private weights: number[] = [];

/** Learned intercept term, added to every prediction. */
private bias = 0;

/**
 * @param learningRate step size applied to each gradient-descent update
 * @param epochs number of full passes over the training data made by fit()
 */
constructor(
  private learningRate: number = 0.01,
  private epochs: number = 1000,
) {}
/**
 * Trains the model with batch gradient descent.
 *
 * The weights are reset to zero on every call; the bias keeps whatever value
 * it currently holds. Each epoch accumulates the prediction error over all
 * samples and then moves the weights and bias against the averaged gradient.
 *
 * @param X training samples, one row per sample; every row is expected to
 *          have the same number of features as the first row
 * @param y target value for each row of X (extra trailing targets are ignored)
 * @throws Error when X is empty or y has fewer entries than X has rows
 */
fit(X: number[][], y: number[]): void {
  const sampleCount = X.length;
  if (sampleCount === 0) {
    throw new Error('LinearRegression.fit: X must contain at least one sample');
  }
  if (y.length < sampleCount) {
    // A missing target would turn every gradient into NaN without any error.
    throw new Error(
      `LinearRegression.fit: expected ${sampleCount} targets but received ${y.length}`,
    );
  }

  const featureCount = X[0].length;
  this.weights = new Array<number>(featureCount).fill(0);

  for (let epoch = 0; epoch < this.epochs; epoch += 1) {
    const weightGradient = new Array<number>(featureCount).fill(0);
    let biasGradient = 0;

    for (let i = 0; i < sampleCount; i += 1) {
      const error = this.predictSingle(X[i]) - y[i];
      for (let j = 0; j < featureCount; j += 1) {
        weightGradient[j] += error * X[i][j];
      }
      biasGradient += error;
    }

    // Step against the gradient averaged over all samples.
    for (let j = 0; j < featureCount; j += 1) {
      this.weights[j] -= (this.learningRate * weightGradient[j]) / sampleCount;
    }
    this.bias -= (this.learningRate * biasGradient) / sampleCount;
  }
}
/**
 * Predicts a value for every row of X.
 *
 * @param X samples to evaluate, one row per sample
 * @returns one prediction per row, in the same order as X
 */
predict(X: number[][]): number[] {
  const outputs: number[] = [];
  for (const row of X) {
    outputs.push(this.predictSingle(row));
  }
  return outputs;
}
/**
 * Evaluates the linear model for one sample: bias plus the weighted sum of
 * its features.
 *
 * @param x feature values of a single sample
 * @returns the predicted value
 */
private predictSingle(x: number[]): number {
  return x.reduce(
    (total, featureValue, index) => total + this.weights[index] * featureValue,
    this.bias,
  );
}
}
2. K-Means 聚类
ts
class KMeans {
/** Current cluster centres, one coordinate array per cluster; empty until fit() has run. */
private centroids: number[][] = [];

/**
 * @param k number of clusters to form
 * @param maxIterations upper bound on assign/update rounds performed by fit()
 */
constructor(
  private k: number,
  private maxIterations: number = 100,
) {}
/**
 * Clusters the rows of X and returns the cluster index assigned to each row.
 *
 * Centroids are seeded with copies of k distinct, randomly chosen rows. Each
 * round assigns every row to its nearest centroid and then recomputes the
 * centroids; the loop stops once an assignment repeats or after
 * `maxIterations` rounds.
 *
 * @param X samples to cluster, one row per sample
 * @returns cluster index for each row of X, in the same order
 * @throws RangeError when k is not a positive integer or exceeds the number
 *         of rows (k distinct seed rows could never be drawn)
 */
fit(X: number[][]): number[] {
  const n = X.length;
  if (!Number.isInteger(this.k) || this.k < 1 || this.k > n) {
    throw new RangeError(
      `KMeans.fit: k must be an integer between 1 and the number of samples (${n}), got ${this.k}`,
    );
  }

  // Randomly initialise the centroids from k distinct rows.
  this.centroids = [];
  const chosen = new Set<number>();
  while (this.centroids.length < this.k) {
    const idx = Math.floor(Math.random() * n);
    if (!chosen.has(idx)) {
      chosen.add(idx);
      this.centroids.push([...X[idx]]);
    }
  }

  // Placeholder labels; they are only returned as-is when maxIterations is 0.
  let labels: number[] = new Array<number>(n).fill(0);
  for (let iter = 0; iter < this.maxIterations; iter += 1) {
    // Assign every point to its nearest centroid.
    const newLabels = X.map((point) => this.closestCentroid(point));

    // Convergence check. The first round only has the placeholder labels to
    // compare against, so it must never count as converged; otherwise an
    // all-zero first assignment would skip the centroid update entirely.
    const converged = iter > 0 && this.arraysEqual(labels, newLabels);
    labels = newLabels;
    if (converged) break;

    // Move each centroid to the mean of its assigned points.
    this.updateCentroids(X, labels);
  }
  return labels;
}
/**
 * Finds the centroid nearest to a point by Euclidean distance.
 *
 * @param point coordinates of the sample to classify
 * @returns index of the nearest centroid; ties go to the lowest index, and 0
 *          is returned when there are no centroids
 */
private closestCentroid(point: number[]): number {
  let bestIndex = 0;
  let bestDistance = Infinity;
  this.centroids.forEach((centroid, index) => {
    const distance = this.euclideanDistance(point, centroid);
    if (distance < bestDistance) {
      bestDistance = distance;
      bestIndex = index;
    }
  });
  return bestIndex;
}
/**
 * Moves every centroid to the mean of the points currently assigned to it.
 * A centroid with no assigned points keeps its previous position.
 *
 * @param X all samples, one row per sample
 * @param labels cluster index for each row of X
 */
private updateCentroids(X: number[][], labels: number[]): void {
  const memberCounts = new Array(this.k).fill(0);
  const dimension = X[0].length;
  const totals = this.centroids.map(() => new Array(dimension).fill(0));

  // Accumulate per-cluster coordinate sums and member counts.
  X.forEach((row, rowIndex) => {
    const cluster = labels[rowIndex];
    memberCounts[cluster] += 1;
    row.forEach((coordinate, axis) => {
      totals[cluster][axis] += coordinate;
    });
  });

  for (let cluster = 0; cluster < this.k; cluster += 1) {
    if (memberCounts[cluster] > 0) {
      totals[cluster].forEach((total, axis) => {
        this.centroids[cluster][axis] = total / memberCounts[cluster];
      });
    }
  }
}
/**
 * Computes the Euclidean distance between two points.
 *
 * @param a first point; its length determines how many coordinates are compared
 * @param b second point
 * @returns square root of the summed squared coordinate differences
 */
private euclideanDistance(a: number[], b: number[]): number {
  const squaredTotal = a.reduce(
    (total, coordinate, axis) => total + (coordinate - b[axis]) ** 2,
    0,
  );
  return Math.sqrt(squaredTotal);
}
/**
 * Checks whether two numeric arrays have the same length and strictly equal
 * elements at every index.
 */
private arraysEqual(a: number[], b: number[]): boolean {
  if (a.length !== b.length) {
    return false;
  }
  for (let index = 0; index < a.length; index += 1) {
    if (a[index] !== b[index]) {
      return false;
    }
  }
  return true;
}
}
3. 决策树
ts
/** A single training sample: its feature vector paired with its class label. */
interface DataPoint {
// Feature values of the sample, indexed by feature number.
features: number[];
// Class label of the sample.
label: number;
}
/**
 * Node of a binary decision tree.
 *
 * A leaf has `isLeaf` set and carries the predicted label in `value`; an
 * internal node carries the split (`feature`, `threshold`) and its two
 * subtrees.
 */
class DecisionTreeNode {
  /** True for terminal nodes; nodes start out as non-leaves. */
  isLeaf: boolean = false;

  /** Predicted label stored on a leaf. */
  value?: number;

  /** Index of the feature an internal node splits on. */
  feature?: number;

  /** Split point: samples with feature value <= threshold go left. */
  threshold?: number;

  /** Subtree for samples at or below the threshold. */
  left?: DecisionTreeNode;

  /** Subtree for samples above the threshold. */
  right?: DecisionTreeNode;
}
class DecisionTree {
/** Root of the fitted tree; undefined until fit() has run. */
private root?: DecisionTreeNode;

/**
 * @param maxDepth depth at which tree growth stops and a leaf is forced
 */
constructor(private maxDepth: number = 10) {}
/**
 * Builds the tree from labelled training data.
 *
 * @param X training samples, one feature row per sample
 * @param y class label for each row of X
 */
fit(X: number[][], y: number[]): void {
  const samples: DataPoint[] = [];
  X.forEach((features, index) => {
    samples.push({ features, label: y[index] });
  });
  this.root = this.buildTree(samples, 0);
}
/**
 * Predicts a label for every row of X.
 *
 * @param X samples to classify, one feature row per sample
 * @returns one predicted label per row, in the same order as X
 */
predict(X: number[][]): number[] {
  const predictions: number[] = [];
  for (const row of X) {
    predictions.push(this.predictSingle(row));
  }
  return predictions;
}
/**
 * Recursively grows the subtree for `data`.
 *
 * A leaf holding the majority label is produced when the data is empty, the
 * depth limit is reached, all labels agree, or the best available split
 * fails to separate the samples. Otherwise the data is partitioned at the
 * best split and both halves are grown one level deeper.
 *
 * @param data samples that reached this node
 * @param depth depth of this node (the root is 0)
 * @returns the root node of the subtree
 */
private buildTree(data: DataPoint[], depth: number): DecisionTreeNode {
  const makeLeaf = (): DecisionTreeNode => {
    const leaf = new DecisionTreeNode();
    leaf.isLeaf = true;
    leaf.value = this.majorityVote(data);
    return leaf;
  };

  // The emptiness check comes first so that nothing below reads data[0] of
  // an empty array.
  if (data.length === 0 || depth >= this.maxDepth || this.isPure(data)) {
    return makeLeaf();
  }

  const { feature, threshold } = this.findBestSplit(data);
  const leftData = data.filter((point) => point.features[feature] <= threshold);
  const rightData = data.filter((point) => point.features[feature] > threshold);

  // Samples with identical features but different labels cannot be
  // separated; recursing with an empty side would make no progress.
  if (leftData.length === 0 || rightData.length === 0) {
    return makeLeaf();
  }

  const node = new DecisionTreeNode();
  node.feature = feature;
  node.threshold = threshold;
  node.left = this.buildTree(leftData, depth + 1);
  node.right = this.buildTree(rightData, depth + 1);
  return node;
}
/**
 * Searches every feature for the threshold with the lowest weighted Gini
 * impurity.
 *
 * Candidate thresholds are the midpoints between consecutive values of a
 * feature once sorted. On ties the first candidate found wins.
 *
 * @param data samples to split; must be non-empty
 * @returns the chosen feature index and threshold (both 0 when no candidate
 *          was evaluated)
 */
private findBestSplit(data: DataPoint[]): { feature: number; threshold: number } {
  const best = { feature: 0, threshold: 0 };
  let lowestGini = Infinity;
  const featureCount = data[0].features.length;

  for (let column = 0; column < featureCount; column += 1) {
    const sorted = data.map((point) => point.features[column]);
    sorted.sort((lower, higher) => lower - higher);

    for (let upper = 1; upper < sorted.length; upper += 1) {
      const candidate = (sorted[upper - 1] + sorted[upper]) / 2;
      const score = this.calculateGini(data, column, candidate);
      if (!(score < lowestGini)) {
        continue;
      }
      lowestGini = score;
      best.feature = column;
      best.threshold = candidate;
    }
  }
  return best;
}
/**
 * Scores a candidate split as the size-weighted average Gini impurity of
 * the two sides.
 *
 * @param data samples to split
 * @param feature index of the feature being tested
 * @param threshold samples with feature value <= threshold go left, > go right
 * @returns weighted impurity; lower means a cleaner split
 */
private calculateGini(data: DataPoint[], feature: number, threshold: number): number {
  const below: DataPoint[] = [];
  const above: DataPoint[] = [];
  for (const point of data) {
    const featureValue = point.features[feature];
    if (featureValue <= threshold) {
      below.push(point);
    } else if (featureValue > threshold) {
      above.push(point);
    }
  }
  const weightedBelow = below.length * this.giniImpurity(below);
  const weightedAbove = above.length * this.giniImpurity(above);
  return (weightedBelow + weightedAbove) / data.length;
}
/**
 * Computes the Gini impurity of a set of samples: 1 minus the sum of the
 * squared label proportions.
 *
 * @param data samples whose labels are measured
 * @returns 0 for an empty or single-label set, larger values for mixed sets
 */
private giniImpurity(data: DataPoint[]): number {
  const total = data.length;
  if (total === 0) {
    return 0;
  }

  const tally = new Map<number, number>();
  data.forEach(({ label }) => {
    tally.set(label, (tally.get(label) ?? 0) + 1);
  });

  return [...tally.values()].reduce((impurity, count) => {
    const proportion = count / total;
    return impurity - proportion * proportion;
  }, 1);
}
/**
 * Reports whether every sample carries the same label.
 *
 * @param data samples to inspect
 * @returns true when all labels match; an empty set counts as pure, since
 *          there is nothing left to split
 */
private isPure(data: DataPoint[]): boolean {
  if (data.length === 0) {
    return true;
  }
  const firstLabel = data[0].label;
  return data.every((point) => point.label === firstLabel);
}
/**
 * Returns the most frequent label among the samples.
 *
 * @param data samples to count
 * @returns the label with the highest count; ties go to the label seen
 *          first, and 0 is returned for an empty set
 */
private majorityVote(data: DataPoint[]): number {
  const tally = new Map<number, number>();
  data.forEach(({ label }) => {
    tally.set(label, (tally.get(label) ?? 0) + 1);
  });

  let winner = 0;
  let winnerCount = 0;
  tally.forEach((count, label) => {
    if (count > winnerCount) {
      winnerCount = count;
      winner = label;
    }
  });
  return winner;
}
/**
 * Walks the tree from the root to a leaf for one sample.
 *
 * @param features feature values of the sample to classify
 * @returns the label stored on the leaf that was reached, or 0 when the tree
 *          has not been fitted or the walk ends without a value
 */
private predictSingle(features: number[]): number {
  let current = this.root;
  while (current != null && !current.isLeaf) {
    const goesLeft = features[current.feature!] <= current.threshold!;
    current = goesLeft ? current.left : current.right;
  }
  const leafValue = current?.value;
  return leafValue ? leafValue : 0;
}
}
4. 实现要点
- 线性回归使用梯度下降优化。
- K-Means 迭代更新质心。
- 决策树使用基尼不纯度选择分割。
- 这些是简化实现,实际应用中需要更多优化。
算法 · 机器学习 · JavaScript