• 回测研究
    • 加载策略
    • 载入历史数据
    • 撮合成交
    • 计算策略盈亏情况
    • 计算策略统计指标
    • 统计指标绘图
    • 回测引擎使用示例

    回测研究

    backtesting.py定义了回测引擎,下面主要介绍相关功能函数,以及回测引擎应用示例:

    加载策略

    把CTA策略逻辑,对应合约品种,以及参数设置(可在策略文件外修改)载入到回测引擎中。

    1. def add_strategy(self, strategy_class: type, setting: dict):
    2. """"""
    3. self.strategy_class = strategy_class
    4. self.strategy = strategy_class(
    5. self, strategy_class.__name__, self.vt_symbol, setting
    6. )

    载入历史数据

    负责载入对应品种的历史数据,大概有4个步骤:

    • 根据数据类型不同,分成K线模式和Tick模式;
    • 通过select().where()方法,有条件地从数据库中选取数据,其筛选标准包括:vt_symbol、 回测开始日期、回测结束日期、K线周期(K线模式下);
    • order_by(DbBarData.datetime)表示需要按照时间顺序载入数据;
    • 载入数据是以迭代方式进行的,数据最终存入self.history_data。
    1. def load_data(self):
    2. """"""
    3. self.output("开始加载历史数据")
    4.  
    5. if self.mode == BacktestingMode.BAR:
    6. s = (
    7. DbBarData.select()
    8. .where(
    9. (DbBarData.vt_symbol == self.vt_symbol)
    10. & (DbBarData.interval == self.interval)
    11. & (DbBarData.datetime >= self.start)
    12. & (DbBarData.datetime <= self.end)
    13. )
    14. .order_by(DbBarData.datetime)
    15. )
    16. self.history_data = [db_bar.to_bar() for db_bar in s]
    17. else:
    18. s = (
    19. DbTickData.select()
    20. .where(
    21. (DbTickData.vt_symbol == self.vt_symbol)
    22. & (DbTickData.datetime >= self.start)
    23. & (DbTickData.datetime <= self.end)
    24. )
    25. .order_by(DbTickData.datetime)
    26. )
    27. self.history_data = [db_tick.to_tick() for db_tick in s]
    28.  
    29. self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")

    撮合成交

    载入CTA策略以及相关历史数据后,策略会根据最新的数据来计算相关指标。若符合条件会生成交易信号,发出具体委托(buy/sell/short/cover),并且在下一根K线成交。

    根据委托类型的不同,回测引擎提供2种撮合成交机制来尽量模仿真实交易环节:

    • 限价单撮合成交:(以买入方向为例)先确定是否发生成交,成交标准为委托价>= 下一根K线的最低价;然后确定成交价格,成交价格为委托价与下一根K线开盘价的最小值。
    • 停止单撮合成交:(以买入方向为例)先确定是否发生成交,成交标准为委托价<= 下一根K线的最高价;然后确定成交价格,成交价格为委托价与下一根K线开盘价的最大值。

    下面展示在引擎中限价单撮合成交的流程:

    • 确定会撮合成交的价格;
    • 遍历限价单字典中的所有限价单,推送委托进入未成交队列的更新状态;
    • 判断成交状态,若出现成交,推送成交数据和委托数据;
    • 从字典中删除已成交的限价单。
    1. def cross_limit_order(self):
    2. """
    3. Cross limit order with last bar/tick data.
    4. """
    5. if self.mode == BacktestingMode.BAR:
    6. long_cross_price = self.bar.low_price
    7. short_cross_price = self.bar.high_price
    8. long_best_price = self.bar.open_price
    9. short_best_price = self.bar.open_price
    10. else:
    11. long_cross_price = self.tick.ask_price_1
    12. short_cross_price = self.tick.bid_price_1
    13. long_best_price = long_cross_price
    14. short_best_price = short_cross_price
    15.  
    16. for order in list(self.active_limit_orders.values()):
    17. # Push order update with status "not traded" (pending)
    18. if order.status == Status.SUBMITTING:
    19. order.status = Status.NOTTRADED
    20. self.strategy.on_order(order)
    21.  
    22. # Check whether limit orders can be filled.
    23. long_cross = (
    24. order.direction == Direction.LONG
    25. and order.price >= long_cross_price
    26. and long_cross_price > 0
    27. )
    28.  
    29. short_cross = (
    30. order.direction == Direction.SHORT
    31. and order.price <= short_cross_price
    32. and short_cross_price > 0
    33. )
    34.  
    35. if not long_cross and not short_cross:
    36. continue
    37.  
    38. # Push order udpate with status "all traded" (filled).
    39. order.traded = order.volume
    40. order.status = Status.ALLTRADED
    41. self.strategy.on_order(order)
    42.  
    43. self.active_limit_orders.pop(order.vt_orderid)
    44.  
    45. # Push trade update
    46. self.trade_count += 1
    47.  
    48. if long_cross:
    49. trade_price = min(order.price, long_best_price)
    50. pos_change = order.volume
    51. else:
    52. trade_price = max(order.price, short_best_price)
    53. pos_change = -order.volume
    54.  
    55. trade = TradeData(
    56. symbol=order.symbol,
    57. exchange=order.exchange,
    58. orderid=order.orderid,
    59. tradeid=str(self.trade_count),
    60. direction=order.direction,
    61. offset=order.offset,
    62. price=trade_price,
    63. volume=order.volume,
    64. time=self.datetime.strftime("%H:%M:%S"),
    65. gateway_name=self.gateway_name,
    66. )
    67. trade.datetime = self.datetime
    68.  
    69. self.strategy.pos += pos_change
    70. self.strategy.on_trade(trade)
    71.  
    72. self.trades[trade.vt_tradeid] = trade

    计算策略盈亏情况

    基于收盘价、当日持仓量、合约规模、滑点、手续费率等计算总盈亏与净盈亏,并且其计算结果以DataFrame格式输出,完成基于逐日盯市盈亏统计。

    下面展示盈亏情况的计算过程

    • 浮动盈亏 = 持仓量 (当日收盘价 - 昨日收盘价) 合约规模
    • 实际盈亏 = 持仓变化量 (当时收盘价 - 开仓成交价) 合约规模
    • 总盈亏 = 浮动盈亏 + 实际盈亏
    • 净盈亏 = 总盈亏 - 总手续费 - 总滑点
    1. def calculate_pnl(
    2. self,
    3. pre_close: float,
    4. start_pos: float,
    5. size: int,
    6. rate: float,
    7. slippage: float,
    8. ):
    9. """"""
    10. self.pre_close = pre_close
    11.  
    12. # Holding pnl is the pnl from holding position at day start
    13. self.start_pos = start_pos
    14. self.end_pos = start_pos
    15. self.holding_pnl = self.start_pos * \
    16. (self.close_price - self.pre_close) * size
    17.  
    18. # Trading pnl is the pnl from new trade during the day
    19. self.trade_count = len(self.trades)
    20.  
    21. for trade in self.trades:
    22. if trade.direction == Direction.LONG:
    23. pos_change = trade.volume
    24. else:
    25. pos_change = -trade.volume
    26.  
    27. turnover = trade.price * trade.volume * size
    28.  
    29. self.trading_pnl += pos_change * \
    30. (self.close_price - trade.price) * size
    31. self.end_pos += pos_change
    32. self.turnover += turnover
    33. self.commission += turnover * rate
    34. self.slippage += trade.volume * size * slippage
    35.  
    36. # Net pnl takes account of commission and slippage cost
    37. self.total_pnl = self.trading_pnl + self.holding_pnl
    38. self.net_pnl = self.total_pnl - self.commission - self.slippage

    计算策略统计指标

    calculate_statistics函数是基于逐日盯市盈亏情况(DateFrame格式)来计算衍生指标,如最大回撤、年化收益、盈亏比、夏普比率等。

    1. df["balance"] = df["net_pnl"].cumsum() + self.capital
    2. df["return"] = np.log(df["balance"] / df["balance"].shift(1)).fillna(0)
    3. df["highlevel"] = (
    4. df["balance"].rolling(
    5. min_periods=1, window=len(df), center=False).max()
    6. )
    7. df["drawdown"] = df["balance"] - df["highlevel"]
    8. df["ddpercent"] = df["drawdown"] / df["highlevel"] * 100
    9.  
    10. # Calculate statistics value
    11. start_date = df.index[0]
    12. end_date = df.index[-1]
    13.  
    14. total_days = len(df)
    15. profit_days = len(df[df["net_pnl"] > 0])
    16. loss_days = len(df[df["net_pnl"] < 0])
    17.  
    18. end_balance = df["balance"].iloc[-1]
    19. max_drawdown = df["drawdown"].min()
    20. max_ddpercent = df["ddpercent"].min()
    21.  
    22. total_net_pnl = df["net_pnl"].sum()
    23. daily_net_pnl = total_net_pnl / total_days
    24.  
    25. total_commission = df["commission"].sum()
    26. daily_commission = total_commission / total_days
    27.  
    28. total_slippage = df["slippage"].sum()
    29. daily_slippage = total_slippage / total_days
    30.  
    31. total_turnover = df["turnover"].sum()
    32. daily_turnover = total_turnover / total_days
    33.  
    34. total_trade_count = df["trade_count"].sum()
    35. daily_trade_count = total_trade_count / total_days
    36.  
    37. total_return = (end_balance / self.capital - 1) * 100
    38. annual_return = total_return / total_days * 240
    39. daily_return = df["return"].mean() * 100
    40. return_std = df["return"].std() * 100
    41.  
    42. if return_std:
    43. sharpe_ratio = daily_return / return_std * np.sqrt(240)
    44. else:
    45. sharpe_ratio = 0

    统计指标绘图

    通过matplotlib绘制4幅图:

    • 资金曲线图
    • 资金回撤图
    • 每日盈亏图
    • 每日盈亏分布图
    1. def show_chart(self, df: DataFrame = None):
    2. """"""
    3. if not df:
    4. df = self.daily_df
    5.  
    6. if df is None:
    7. return
    8.  
    9. plt.figure(figsize=(10, 16))
    10.  
    11. balance_plot = plt.subplot(4, 1, 1)
    12. balance_plot.set_title("Balance")
    13. df["balance"].plot(legend=True)
    14.  
    15. drawdown_plot = plt.subplot(4, 1, 2)
    16. drawdown_plot.set_title("Drawdown")
    17. drawdown_plot.fill_between(range(len(df)), df["drawdown"].values)
    18.  
    19. pnl_plot = plt.subplot(4, 1, 3)
    20. pnl_plot.set_title("Daily Pnl")
    21. df["net_pnl"].plot(kind="bar", legend=False, grid=False, xticks=[])
    22.  
    23. distribution_plot = plt.subplot(4, 1, 4)
    24. distribution_plot.set_title("Daily Pnl Distribution")
    25. df["net_pnl"].hist(bins=50)
    26.  
    27. plt.show()

    回测引擎使用示例

    • 导入回测引擎和CTA策略
    • 设置回测相关参数,如:品种、K线周期、回测开始和结束日期、手续费、滑点、合约规模、起始资金
    • 载入策略和数据到引擎中,运行回测。
    • 计算基于逐日统计盈利情况,计算统计指标,统计指标绘图。
    1. from vnpy.app.cta_strategy.backtesting import BacktestingEngine
    2. from vnpy.app.cta_strategy.strategies.boll_channel_strategy import (
    3. BollChannelStrategy,
    4. )
    5. from datetime import datetime
    6.  
    7. engine = BacktestingEngine()
    8. engine.set_parameters(
    9. vt_symbol="IF88.CFFEX",
    10. interval="1m",
    11. start=datetime(2018, 1, 1),
    12. end=datetime(2019, 1, 1),
    13. rate=3.0/10000,
    14. slippage=0.2,
    15. size=300,
    16. pricetick=0.2,
    17. capital=1_000_000,
    18. )
    19.  
    20. engine.add_strategy(AtrRsiStrategy, {})
    21. engine.load_data()
    22. engine.run_backtesting()
    23. df = engine.calculate_result()
    24. engine.calculate_statistics()
    25. engine.show_chart()