跳转至

三分搜索

问题

已知函数 f(x) 在区间 [l, r] 上单峰且连续,求 f(x)[l, r]上的极值。

算法

使用标准三分搜索算法,每次迭代将当前区间的长度缩小 \frac{1}{3}

理论上最终收敛到极值点或极值区间的左端,但实际中由于复点误差导致的判等错误,当 f(x) 在区间 [l_m, r_m] 上取极值时,最终定位往往在线段内的随机点,但返回的函数值仍可视为 f(x)[l, r] 上的极值。

代码

[-1000, 1000] 范围内搜索单峰连续函数 f(x) 极小值

double ternary_search() {
    double l = -1000, r = 1000;
    for (int i = 0; i < 300; i++) {
        double m1 = (2 * l + r) / 3;
        double m2 = (l + 2 * r) / 3;
        if (f(m1) >= f(m2)) l = m1;
        else r = m2;
    }
    return f(l);
}

注: 迭代 100 次后区间长度为先前的 2.4597 \times 10^{-18} ,迭代 1000 次之后区间长度为缩减到原来的 1.4881 \times 10^{-53} ,故一般数百次的迭代能够满足算法竞赛中的精度要求。

变体 - 平面三分搜索

当需要求解的单峰连续二元函数 f_{x, y}(x, y) 在矩形区域 x \in [l_x, r_x], y \in [l_y, r_y] 上的极值时,可使用平面三分搜索算法。具体地,分两层调用一维三分搜索算法。

从几何意义来看,平面三分搜索首先利用一维三分搜索求解给定 x 时函数在 y 方向上的极值 f_x(x) = \max_{y \in [l_y, r_y]}f(x, y),然后利用三分搜索求解 f_x(x)x 方向上的极值。

代码

[-1000, 1000] \times [-1000, 1000] 上搜索单峰函数 f(x) 的极小值

double t_search_y(double x) {
    double l = -1000, r = 1000;
    for (int i = 0; i < 300; i++) {
        double m1 = (2 * l + r) / 3;
        double m2 = (l + 2 * r) / 3;
        if (f(x, m1) >= f(x, m2)) l = m1;
        else r = m2;
    }
    return f(x, l);
}
double t_search_x() {
    double l = -1000, r = 1000;
    for (int i = 0; i < 300; i++) {
        double m1 = (2 * l + r) / 3;
        double m2 = (l + 2 * r) / 3;
        if (t_search_y(m1) >= t_search_y(m2)) l = m1;
        else r = m2;
    }
    return t_search_y(l);
}

模板封装

一维三分搜索

class TernarySearch {
   public:
    TernarySearch(function<double(double)> _f) : f(_f) {
        l = -1000, r = 1000;
    }

    double MaxValue() {
        for (int i = 0; i < iter_num; i++) {
            double m1 = (2 * l + r) / 3;
            double m2 = (l + 2 * r) / 3;
            if (f(m1) >= f(m2)) l = m1;
            else r = m2;
        }
        return f(l);
    }

   private:
    int iter_num = 300;
    double l, r;
    function<double(double)> f;
};

平面三分搜索

class TernarySearch2D {
   public:
    TernarySearch2D(function<double(double, double)> _f) : f(_f) {
        lx = -1000, rx = 1000, ly = -1000, ry = -1000;
    }

    double MaxValue() {
        double l = lx, r = rx;
        for (int i = 0; i < iter_num; i++) {
            double m1 = (2 * l + r) / 3;
            double m2 = (l + 2 * r) / 3;
            if (MaxValueY(m1) >= MaxValueY(m2)) l = m1;
            else r = m2;
        }
        return MaxValueY(l);
    }

   private:
    double MaxValueY(double x) {
        double l = ly, r = ry;
        for (int i = 0; i < iter_num; i++) {
            double m1 = (2 * l + r) / 3;
            double m2 = (l + 2 * r) / 3;
            if (f(x, m1) >= f(x, m2)) l = m1;
            else r = m2;
        }
        return f(x, l);
    }

    int iter_num = 300;
    double lx, rx, ly, ry;
    function<double(double, double)> f;
};