"""Grid search result analysis.
- Breakdown of performance of different grid search combinations
- Heatmap and other comparison methods
"""
import textwrap
import warnings
from dataclasses import dataclass
from typing import List, TypeAlias
import numpy as np
import pandas as pd
from IPython.core.display_functions import display
import plotly.express as px
from plotly.graph_objs import Figure, Scatter
from tradeexecutor.backtest.grid_search import GridSearchResult
from tradeexecutor.state.types import USDollarAmount
from tradeexecutor.strategy.trading_strategy_universe import TradingStrategyUniverse
from tradeexecutor.utils.sort import unique_sort
from tradeexecutor.visual.benchmark import visualise_all_cash, visualise_portfolio_equity_curve
#: Metric columns that get a colour gradient and min/max highlighting
#: in :py:func:`render_grid_search_result_table`
VALUE_COLS = [
    "CAGR",
    "Max drawdown",
    "Sharpe",
    "Sortino",
    "Average position",
    "Median position",
]

#: Columns whose values are fractions rendered as percentages
PERCENT_COLS = [
    "CAGR",
    "Max drawdown",
    "Average position",
    "Median position",
    "Time in market",
]

#: Count columns, rendered with the general ``g`` number format
DATA_COLS = ["Positions", "Trades"]
def analyse_combination(
    r: "GridSearchResult",
    min_positions_threshold: int,
) -> dict:
    """Create a grid search result table row.

    - Create columns we can use to compare different grid search combinations

    :param r:
        Result of a single grid search combination.

    :param min_positions_threshold:
        If we did less positions than this amount, do not consider this a proper strategy.
        Filter out one position outliers.

    :return:
        Dict with the non-fixed parameter values first, followed by the metric columns.
        If the combination opened fewer positions than ``min_positions_threshold``,
        every metric except ``Positions`` is set to NaN.
    """
    row = {}
    param_names = []
    for param in r.combination.parameters:
        # Skip parameters that are single fixed value
        # and do not affect the grid search results
        if param.single:
            continue
        row[param.name] = param.value
        param_names.append(param.name)

    def clean(x):
        # Missing metric values are marked with "-" or an empty string; normalise them to NaN
        if isinstance(x, str) and x in ("-", ""):
            return np.nan
        return x

    def metric(name: str):
        # Take the first column of the metrics row. Use .iloc for the positional lookup:
        # Series[0] falling back to a position is deprecated in pandas 2.x
        return clean(r.metrics.loc[name].iloc[0])

    row.update({
        "Positions": r.summary.total_positions,
        "Trades": r.summary.total_trades,
        "Time in market": metric("Time in Market"),
        "CAGR": metric("Annualised return (raw)"),
        "Max drawdown": metric("Max Drawdown"),
        "Sharpe": metric("Sharpe"),
        "Sortino": metric("Sortino"),
        "Average position": r.summary.average_trade,
        "Median position": r.summary.median_trade,
    })

    # Clear all values except position count if this is not a good trade series
    if r.summary.total_positions < min_positions_threshold:
        for k in row:
            if k != "Positions" and k not in param_names:
                # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0
                row[k] = np.nan

    return row
def analyse_grid_search_result(
    results: List[GridSearchResult],
    min_positions_threshold: int = 5,
) -> pd.DataFrame:
    """Create a table showing grid search result of each combination.

    - Each row has labeled parameters of its combination
    - Each row has some metrics extracted from the results by :py:func:`analyse_combination`

    The output has the following row for each parameter combination:

    - Combination parameters
    - Positions and trade counts
    - Time in market
    - CAGR (compound annual growth rate)
    - Max drawdown
    - Sharpe
    - Sortino
    - Average and median position

    See also :py:func:`analyse_combination`.

    :param results:
        Output from :py:meth:`tradeexecutor.backtest.grid_search.perform_grid_search`.

    :param min_positions_threshold:
        If we did less positions than this amount, do not consider this a proper strategy.
        Filter out one position outliers.

    :return:
        Table of grid search combinations, indexed and sorted by the searchable
        parameters. If there are no searchable parameters, the default integer
        index is kept.
    """
    assert len(results) > 0, "No results"
    rows = [analyse_combination(r, min_positions_threshold) for r in results]
    df = pd.DataFrame(rows)

    # Parameter names are read from the first combination only
    param_names = [p.name for p in results[0].combination.searchable_parameters]

    # set_index([]) raises ValueError in pandas, so only index when there is something to index by
    if param_names:
        df = df.set_index(param_names).sort_index()

    return df
def visualise_table(*args, **kwargs):
    """Deprecated alias for :py:func:`render_grid_search_result_table`.

    Emits a :py:class:`DeprecationWarning` and forwards all arguments unchanged.
    """
    message = 'This function is deprecated. Use render_grid_search_result_table() instead'
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return render_grid_search_result_table(*args, **kwargs)
def render_grid_search_result_table(results: pd.DataFrame | list[GridSearchResult]) -> "pandas.io.formats.style.Styler":
    """Render a grid search combination table to notebook output.

    - Highlight winners and losers
    - Gradient based on the performance of a metric
    - Stripes for the input

    Example:

    .. code-block:: python

        grid_search_results = perform_grid_search(
            decide_trades,
            strategy_universe,
            combinations,
            max_workers=get_safe_max_workers_count(),
            trading_strategy_engine_version="0.5",
            multiprocess=True,
        )
        render_grid_search_result_table(grid_search_results)

    :param results:
        Output from :py:func:`perform_grid_search`,
        or a table already created by :py:func:`analyse_grid_search_result`.

        Only the known metric columns that are present in the table are styled,
        so tables missing some of them are rendered instead of raising ``KeyError``.

    :return:
        Styled DataFrame (pandas ``Styler``) for the notebook output
    """
    if isinstance(results, pd.DataFrame):
        df = results
    else:
        df = analyse_grid_search_result(results)

    # Restrict every style to the columns that actually exist.
    # Missing columns would raise KeyError, and an empty subset
    # makes background_gradient fail on a zero-size array at render time.
    value_cols = [c for c in VALUE_COLS if c in df.columns]
    percent_cols = [c for c in PERCENT_COLS if c in df.columns]
    data_cols = [c for c in DATA_COLS if c in df.columns]

    # https://stackoverflow.com/a/57152529/315168
    # TODO:
    # Diverge color gradient around zero
    # https://stackoverflow.com/a/60654669/315168
    formatted = df.style

    if value_cols:
        formatted = formatted.background_gradient(
            axis=0,
            subset=value_cols,
        ).highlight_min(
            color='pink',
            axis=0,
            subset=value_cols,
        ).highlight_max(
            color='darkgreen',
            axis=0,
            subset=value_cols,
        )

    if percent_cols:
        formatted = formatted.format(
            formatter="{:.2%}",
            subset=percent_cols,
        )

    if data_cols:
        formatted = formatted.format(
            # https://stackoverflow.com/a/12080042/315168
            subset=data_cols,
            formatter="{0:g}",
        )

    return formatted
def visualise_heatmap_2d(
    result: pd.DataFrame,
    parameter_1: str,
    parameter_2: str,
    metric: str,
    color_continuous_scale='Bluered_r',
    continuous_scale: bool | None = None,
) -> Figure:
    """Draw a heatmap square comparing two different parameters.

    The Plotly figure is returned, not displayed; call ``fig.show()`` to render it.

    :param result:
        Grid search results as a DataFrame.
        Created by :py:func:`analyse_grid_search_result`.

        Each ``(parameter_1, parameter_2)`` pair must appear only once, otherwise
        :py:meth:`pandas.DataFrame.pivot` raises ``ValueError``.

    :param parameter_1:
        Y axis

    :param parameter_2:
        X axis

    :param metric:
        Value to examine.
        ``"Annualised return"`` is mapped to ``"CAGR"`` if the table only has the latter.

    :param color_continuous_scale:
        The name of Plotly gradient used for the colour scale.

    :param continuous_scale:
        Are the X and Y scales continuous.

        X and Y scales cannot be continuous if they contain values like None or NaN.
        This will stretch the scale to infinity or zero.

        Set `True` to force continuous, `False` to force discreet steps, `None` to autodetect.

    :return:
        Plotly Figure object
    """
    # Reset multi-index so we can work with parameter 1 and 2 as series
    df = result.reset_index()

    # Backwards compatibility: callers may still ask for the "Annualised return" column name
    if metric == "Annualised return" and ("Annualised return" not in df.columns) and "CAGR" in df.columns:
        metric = "CAGR"

    # Detect any non-number values on axes
    if continuous_scale is None:
        continuous_scale = not (df[parameter_1].isna().any() or df[parameter_2].isna().any())

    # setting all column values to string will hint
    # Plotly to make all boxes same size regardless of value
    if not continuous_scale:
        df[parameter_1] = df[parameter_1].astype(str)
        df[parameter_2] = df[parameter_2].astype(str)

    df = df.pivot(index=parameter_1, columns=parameter_2, values=metric)

    # Format percents inside the cells and mouse hovers.
    # Missing values are left blank instead of being rendered as "nan".
    if metric in PERCENT_COLS:
        def format_cell(x):
            return "" if pd.isna(x) else f"{x * 100:,.2f}%"
    else:
        def format_cell(x):
            return "" if pd.isna(x) else f"{x:,.2f}"

    # DataFrame.applymap() is deprecated since pandas 2.1 in favour of DataFrame.map().
    # Check on the class so a column named "map" cannot confuse the lookup.
    if hasattr(pd.DataFrame, "map"):
        text = df.map(format_cell)
    else:
        text = df.applymap(format_cell)

    fig = px.imshow(
        df,
        labels=dict(x=parameter_2, y=parameter_1, color=metric),
        aspect="auto",
        title=metric,
        color_continuous_scale=color_continuous_scale,
    )

    fig.update_traces(text=text, texttemplate="%{text}")

    fig.update_layout(
        title={"text": metric},
        height=600,
    )
    return fig
def visualise_3d_scatter(
    flattened_result: pd.DataFrame,
    parameter_x: str,
    parameter_y: str,
    parameter_z: str,
    measured_metric: str,
    color_continuous_scale="Bluered_r",  # Reversed, blue = best
    height=600,
) -> Figure:
    """Draw a 3D scatter plot for grid search results.

    Create an interactive 3d chart to explore three different parameters and one performance measurement
    of the grid search results.

    Example:

    .. code-block:: python

        from tradeexecutor.analysis.grid_search import analyse_grid_search_result
        table = analyse_grid_search_result(grid_search_results)
        flattened_results = table.reset_index()
        flattened_results["CAGR %"] = flattened_results["CAGR"] * 100
        fig = visualise_3d_scatter(
            flattened_results,
            parameter_x="rsi_days",
            parameter_y="rsi_high",
            parameter_z="rsi_low",
            measured_metric="CAGR %"
        )
        fig.show()

    :param flattened_result:
        Grid search results as a DataFrame with the parameters as regular columns.
        Created by :py:func:`analyse_grid_search_result` followed by ``reset_index()``.

    :param parameter_x:
        X axis column name

    :param parameter_y:
        Y axis column name

    :param parameter_z:
        Z axis column name

    :param measured_metric:
        Output column we compare, used as the marker colour.
        E.g. `CAGR`

    :param color_continuous_scale:
        The name of Plotly gradient used for the colour scale.

        `See the Plotly continuos scale color gradient options <https://plotly.com/python/builtin-colorscales/>`__.

    :param height:
        Figure height in pixels.

    :return:
        Plotly figure to display
    """
    assert isinstance(flattened_result, pd.DataFrame)
    for column_name in (parameter_x, parameter_y, parameter_z, measured_metric):
        assert isinstance(column_name, str), f"Expected a column name string, got {column_name!r}"

    fig = px.scatter_3d(
        flattened_result,
        x=parameter_x,
        y=parameter_y,
        z=parameter_z,
        color=measured_metric,
        color_continuous_scale=color_continuous_scale,
        height=height,
    )

    return fig
def _get_hover_template(
result: GridSearchResult,
key_metrics = ("CAGR﹪", "Max Drawdown", "Time in Market", "Sharpe", "Sortino"), # See quantstats
percent_metrics = ("CAGR﹪", "Max Drawdown", "Time in Market"),
):
# Get metrics calculated with QuantStats
data = result.metrics["Strategy"]
metrics = {}
for name in key_metrics:
metrics[name] = data[name]
template = textwrap.dedent(f"""<b>{result.get_label()}</b><br><br>""")
for k, v in metrics.items():
if type(v) == int:
v = float(v)
if v in ("", None, "-"): # Messy third party code does not know how to mark no value
template += f"{k}: -<br>"
elif k in percent_metrics:
assert type(v) == float, f"Got unknown type: {k}: {v} ({type(v)}"
v *= 100
template += f"{k}: {v:.2f}%<br>"
else:
assert type(v) == float, f"Got unknown type: {k}: {v} ({type(v)}"
template += f"{k}: {v:.2f}<br>"
# Get trade metrics
for k, v in result.summary.get_trading_core_metrics().items():
template += f"{k}: {v}<br>"
return template
@dataclass(slots=True)
class TopGridSearchResult:
    """Best grid search results, one ranked list per metric.

    Built by :py:func:`find_best_grid_search_results`, which sorts
    each list in descending order of its metric.
    """

    #: Results with the highest CAGR
    cagr: list[GridSearchResult]

    #: Results with the highest Sharpe ratio
    sharpe: list[GridSearchResult]
def find_best_grid_search_results(grid_search_results: list[GridSearchResult], count=20, unique_only=True) -> TopGridSearchResult:
    """Pick the top grid search results by CAGR and by Sharpe.

    :param grid_search_results:
        All grid search results to rank.

    :param count:
        Maximum number of results kept in each ranked list.

    :param unique_only:
        Return unique value matches only.

        If multiple grid search results share the same metric (CAGR),
        filter out duplicates. Otherwise the table will be littered with duplicates.

    :return:
        Top lists, best first
    """
    order = unique_sort if unique_only else sorted

    def top_by(metric_of):
        # Highest metric first, truncated to the requested length
        ranked = order(grid_search_results, key=metric_of, reverse=True)
        return ranked[:count]

    return TopGridSearchResult(
        cagr=top_by(lambda r: r.get_cagr()),
        sharpe=top_by(lambda r: r.get_sharpe()),
    )
def visualise_grid_search_equity_curves(*args, **kwargs):
    """Deprecated.

    Forwards all arguments to
    :py:func:`tradeexecutor.visual.grid_search.visualise_grid_search_equity_curves`.
    """
    warnings.warn("use tradeexecutor.visual.grid_search.visualise_grid_search_equity_curves instead", DeprecationWarning, stacklevel=2)
    # Imported at call time rather than module level — presumably to avoid an import cycle (TODO confirm)
    from tradeexecutor.visual import grid_search as visual_grid_search
    return visual_grid_search.visualise_grid_search_equity_curves(*args, **kwargs)