From c00a468249631fa5a7f41763066dfc02f43fd72e Mon Sep 17 00:00:00 2001 From: "Luke I. Wilson" Date: Sun, 14 May 2023 15:28:02 -0500 Subject: [PATCH] Update everything to use new DataFrame wrapper --- backtesting.go | 21 ++++++++------------- backtesting_test.go | 29 +++++++++++++++-------------- broker.go | 4 +--- cmd/sma_crossover.go | 2 +- data.go | 14 ++++++++++++-- trader.go | 5 ++--- 6 files changed, 39 insertions(+), 36 deletions(-) diff --git a/backtesting.go b/backtesting.go index 5d136c4..cea1881 100644 --- a/backtesting.go +++ b/backtesting.go @@ -5,7 +5,6 @@ import ( "strconv" "time" - df "github.com/rocketlaunchr/dataframe-go" "golang.org/x/exp/rand" ) @@ -31,7 +30,7 @@ func Backtest(trader *Trader) { type TestBroker struct { SignalManager DataBroker Broker - Data *df.DataFrame + Data *DataFrame Cash float64 Leverage float64 Spread float64 // Number of pips to add to the price when buying and subtract when selling. (Forex) @@ -42,10 +41,10 @@ type TestBroker struct { positions []Position } -func (b *TestBroker) Candles(symbol string, frequency string, count int) (*df.DataFrame, error) { +func (b *TestBroker) Candles(symbol string, frequency string, count int) (*DataFrame, error) { // Check if we reached the end of the existing data. - if b.Data != nil && b.candleCount >= b.Data.NRows() { - return b.Data.Copy(), ErrEOF + if b.Data != nil && b.candleCount >= b.Data.Len() { + return b.Data.Copy(0, -1), ErrEOF } // Catch up to the start candles. @@ -59,7 +58,7 @@ func (b *TestBroker) Candles(symbol string, frequency string, count int) (*df.Da // candles does the same as the public Candles except it doesn't increment b.candleCount so that it can be used // internally to fetch candles without incrementing the count. -func (b *TestBroker) candles(symbol string, frequency string, count int) (*df.DataFrame, error) { +func (b *TestBroker) candles(symbol string, frequency string, count int) (*DataFrame, error) { if b.DataBroker != nil && b.Data == nil { // Fetch a lot of candles from the broker so we don't keep asking. candles, err := b.DataBroker.Candles(symbol, frequency, Max(count, 1000)) @@ -83,7 +82,7 @@ func (b *TestBroker) candles(symbol string, frequency string, count int) (*df.Da end := Max(b.candleCount, 1) - 1 start := Max(Max(b.candleCount, 1)-count, 0) - return b.Data.Copy(df.Range{Start: &start, End: &end}), nil + return b.Data.Copy(start, end), nil } func (b *TestBroker) MarketOrder(symbol string, units float64, stopLoss, takeProfit float64) (Order, error) { @@ -96,11 +95,7 @@ func (b *TestBroker) MarketOrder(symbol string, units float64, stopLoss, takePro return nil, err } } - closeIdx, err := b.Data.NameToColumn("Close") - if err != nil { - return nil, err - } - price := b.Data.Series[closeIdx].Value(Max(b.candleCount-1, 0)).(float64) // Get the last close price. + price := b.Data.Close(Max(b.candleCount-1, 0)) // Get the last close price. // Instantly fulfill the order. b.Cash -= price * units * LeverageToMargin(b.Leverage) @@ -138,7 +133,7 @@ func (b *TestBroker) Positions() []Position { return b.positions } -func NewTestBroker(dataBroker Broker, data *df.DataFrame, cash, leverage, spread float64, startCandles int) *TestBroker { +func NewTestBroker(dataBroker Broker, data *DataFrame, cash, leverage, spread float64, startCandles int) *TestBroker { return &TestBroker{ DataBroker: dataBroker, Data: data, diff --git a/backtesting_test.go b/backtesting_test.go index 468090e..1cd7801 100644 --- a/backtesting_test.go +++ b/backtesting_test.go @@ -37,29 +37,29 @@ func newTestingDataframe() *df.DataFrame { } func TestBacktestingBrokerCandles(t *testing.T) { - data := newTestingDataframe() + data := NewDataFrame(newTestingDataframe()) broker := NewTestBroker(nil, data, 0, 0, 0, 0) candles, err := broker.Candles("EUR_USD", "D", 3) if err != nil { t.Fatal(err) } - if candles.NRows() != 1 { - t.Errorf("Expected 1 candle, got %d", candles.NRows()) + if candles.Len() != 1 { + t.Errorf("Expected 1 candle, got %d", candles.Len()) } - if candles.Series[0].Value(0).(time.Time) != time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) { - t.Errorf("Expected first candle to be 2022-01-01, got %s", candles.Series[0].Value(0)) + if candles.Date(0) != time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) { + t.Errorf("Expected first candle to be 2022-01-01, got %s", candles.Date(0)) } candles, err = broker.Candles("EUR_USD", "D", 3) if err != nil { t.Fatal(err) } - if candles.NRows() != 2 { - t.Errorf("Expected 2 candles, got %d", candles.NRows()) + if candles.Len() != 2 { + t.Errorf("Expected 2 candles, got %d", candles.Len()) } - if candles.Series[0].Value(1).(time.Time) != time.Date(2022, 1, 2, 0, 0, 0, 0, time.UTC) { - t.Errorf("Expected second candle to be 2022-01-02, got %s", candles.Series[0].Value(1)) + if candles.Date(1) != time.Date(2022, 1, 2, 0, 0, 0, 0, time.UTC) { + t.Errorf("Expected second candle to be 2022-01-02, got %s", candles.Date(1)) } for i := 0; i < 7; i++ { // 7 because we want to call broker.Candles 9 times total @@ -71,11 +71,11 @@ func TestBacktestingBrokerCandles(t *testing.T) { t.Errorf("Candles is nil on iteration %d", i+1) } } - if candles.NRows() != 5 { - t.Errorf("Expected 5 candles, got %d", candles.NRows()) + if candles.Len() != 5 { + t.Errorf("Expected 5 candles, got %d", candles.Len()) } - if candles.Series[4].Value(4).(float64) != 1.3 { - t.Errorf("Expected the last closing price to be 1.3, got %f", candles.Series[4].Value(4)) + if candles.Close(4) != 1.3 { + t.Errorf("Expected the last closing price to be 1.3, got %f", candles.Close(4)) } } @@ -88,7 +88,8 @@ func TestBacktestingBrokerFunctions(t *testing.T) { } func TestBacktestingBrokerOrders(t *testing.T) { - broker := NewTestBroker(nil, newTestingDataframe(), 100_000, 50, 0, 0) + data := NewDataFrame(newTestingDataframe()) + broker := NewTestBroker(nil, data, 100_000, 50, 0, 0) timeBeforeOrder := time.Now() order, err := broker.MarketOrder("EUR_USD", 50_000, 0, 0) // Buy 50,000 USD for 1000 EUR with no stop loss or take profit if err != nil { diff --git a/broker.go b/broker.go index 00014fa..cacc061 100644 --- a/broker.go +++ b/broker.go @@ -3,8 +3,6 @@ package autotrader import ( "errors" "time" - - df "github.com/rocketlaunchr/dataframe-go" ) type OrderType string @@ -53,7 +51,7 @@ type Position interface { type Broker interface { // Candles returns a dataframe of candles for the given symbol, frequency, and count by querying the broker. - Candles(symbol string, frequency string, count int) (*df.DataFrame, error) + Candles(symbol string, frequency string, count int) (*DataFrame, error) MarketOrder(symbol string, units float64, stopLoss, takeProfit float64) (Order, error) NAV() float64 // NAV returns the net asset value of the account. // Orders returns a slice of orders that have been placed with the broker. If an order has been canceled or diff --git a/cmd/sma_crossover.go b/cmd/sma_crossover.go index 3ce884f..2c1a3bb 100644 --- a/cmd/sma_crossover.go +++ b/cmd/sma_crossover.go @@ -49,7 +49,7 @@ func main() { // AccountID: "101-001-14983263-001", // DemoAccount: true, // }), - Broker: auto.NewTestBroker(nil, data, 10000, 50, 0.0002, 0), + Broker: auto.NewTestBroker(nil, auto.NewDataFrame(data), 10000, 50, 0.0002, 0), Strategy: &SMAStrategy{}, Symbol: "EUR_USD", Frequency: "D", diff --git a/data.go b/data.go index cc2afe6..53c2e54 100644 --- a/data.go +++ b/data.go @@ -252,8 +252,18 @@ type DataFrame struct { data *df.DataFrame // DataFrame with a Date, Open, High, Low, Close, and Volume column. } -func (o *DataFrame) Copy() *DataFrame { - return &DataFrame{o.data.Copy()} +// Copy copies the DataFrame from start to end (inclusive). If end is -1, it will copy to the end of the DataFrame. If start is out of bounds, nil is returned. +func (o *DataFrame) Copy(start, end int) *DataFrame { + var _end *int + if start < 0 || start >= o.Len() { + return nil + } else if end >= 0 { + if end < start { + return nil + } + _end = &end + } + return &DataFrame{o.data.Copy(df.Range{Start: &start, End: _end})} } // Len returns the number of rows in the DataFrame or 0 if the DataFrame is nil. diff --git a/trader.go b/trader.go index 10b8028..632436f 100644 --- a/trader.go +++ b/trader.go @@ -9,7 +9,6 @@ import ( "time" "github.com/go-co-op/gocron" - df "github.com/rocketlaunchr/dataframe-go" ) // Trader acts as the primary interface to the broker and strategy. To the strategy, it provides all the information @@ -23,12 +22,12 @@ type Trader struct { CandlesToKeep int Log *log.Logger - data *df.DataFrame + data *DataFrame sched *gocron.Scheduler idx int } -func (t *Trader) Data() *df.DataFrame { +func (t *Trader) Data() *DataFrame { return t.data }