Pythonでチャートに回帰直線を引く(トレンドを把握する)

今回は下図のようにチャートに回帰直線を引いてみます。

回帰直線

回帰とは

回帰とは得られたデータxを \(Y=f(X)\)という関数に当てはめることを言います。
この関数のことをモデルとも言います。

例えば、以下のようなデータ(点)があった場合、

f:id:ttt242242:20190905145606g:plain

以下のように\(Y=f(X)\)に当てはめることです。

f:id:ttt242242:20190905145716g:plain

で、この直線を回帰直線と言います。

今回はチャートに回帰直線を引いてみます。
この回帰によって簡易的にチャートのトレンドを把握するのに役立つかも知れません(主にトレンド把握には移動平均を使いますが)。

実際にはこの回帰直線の傾きを使って、トレンド具合を判断したほうが良いと思いますが、今回は単純に線を引いてみます。

実装

では、実装してみます。

データの取得

まずチャートデータを取得します。
例によってOANDA APIを用います。

OANDA APIを使うためには以下の記事を参考にしてください。

以下のコードでデータを取得できます。

def get_candles(instrument="USD_JPY", params=None, api=None):
    """
        足データを取得してDataFrameに変換
    """
    instruments_candles = instruments.InstrumentsCandles(
        instrument=instrument, params=params)

    api.request(instruments_candles)
    response = instruments_candles.response

    df = pd.DataFrame(response["candles"])

    return df


def get_prices():
    """
        指定した数だけデータの取得
    """

    access_token = "your access_token"
    api = API(access_token=access_token, environment="practice")

    params = {
        "granularity": GRANULARITY,
        "count": COUNT,
        "price": "B",
    }

    # 足データの取得
    candles = None
    for i in range(NB_ITR):
        print(i)
        new_candles = get_candles(
            instrument=INSTRUMENT, params=params, api=api)
        params["to"] = new_candles["time"].iloc[0]
        print(params["to"])
        candles = pd.concat([new_candles, candles])

    prices = np.array([x["c"] for x in candles["bid"].values])
    prices = prices.astype(np.float64)
    return prices, candles


# 為替データの取得とプロット
prices, _ = get_prices()
x = np.arange(len(prices))
plt.plot(x, prices)

上記のコードを実行すると、下記のようなグラフをプロットします。

回帰直線を引く

では、先程のチャートに回帰直線を引いてみます。

回帰直線はscikit-learnというライブラリを使えば簡単に引くことができます。

scikit-learnをインストールしていない人は下記のコマンドでインストールしてください。

pip install scikit-learn

回帰直線を求めるためにlinear_modelモジュールを用います。

では、回帰直線を求めてみましょう。

# 回帰直線の計算
reg = linear_model.LinearRegression()
x_ = x.reshape(-1, 1)
prices_ = prices.reshape(-1, 1)
result = reg.fit(x_, prices_)
a = result.coef_[0][0]  # 傾き
b = result.intercept_[0]    # 切片

回帰直線はfitメソッドを使って求めることができます。

回帰直線の傾きはcoef_に、切片はintercept_に格納されています。

では、プロットしてみます。

# 回帰直線プロット
y = []
for xi in x:
    y_ = a*xi + b
    y.append(y_)
plt.plot(x, y)
plt.show()

うまく引けていますね!傾きだけ知りたければ、aを標準出力してみましょう。

タイトルとURLをコピーしました