diff --git a/backtesting.go b/backtesting.go index 1a5cdce..053390d 100644 --- a/backtesting.go +++ b/backtesting.go @@ -16,12 +16,19 @@ var ( ) func Backtest(trader *Trader) { - for !trader.EOF { - trader.Tick() + switch broker := trader.Broker.(type) { + case *TestBroker: + trader.Init() // Initialize the trader and strategy. + for !trader.EOF { + trader.Tick() // Allow the trader to process the current candlesticks. + broker.Advance() // Give the trader access to the next candlestick. + } + log.Println("Backtest complete.") + log.Println("Stats:") + log.Println(trader.Stats().String()) + default: + log.Fatalf("Backtesting is only supported with a TestBroker. Got %T", broker) } - log.Println("Backtest complete.") - log.Println("Stats:") - log.Println(trader.Stats()) } // TestBroker is a broker that can be used for testing. It implements the Broker interface and fulfills orders @@ -54,11 +61,13 @@ func (b *TestBroker) CandleIndex() int { // Advance advances the test broker to the next candle in the input data. This should be done at the end of the // strategy loop. func (b *TestBroker) Advance() { - b.candleCount++ + if b.candleCount < b.Data.Len() { + b.candleCount++ + } } func (b *TestBroker) Candles(symbol string, frequency string, count int) (*DataFrame, error) { - if b.Data != nil && b.candleCount > b.Data.Len() { // We have data and we are at the end of it. + if b.Data != nil && b.candleCount >= b.Data.Len() { // We have data and we are at the end of it. return b.Data.Copy(0, -1).(*DataFrame), ErrEOF } else if b.DataBroker != nil && b.Data == nil { // We have a data broker but no data. // Fetch a lot of candles from the broker so we don't keep asking. @@ -137,6 +146,26 @@ func (b *TestBroker) NAV() float64 { return nav } +func (b *TestBroker) OpenOrders() []Order { + orders := make([]Order, 0, len(b.orders)) + for _, order := range b.orders { + if !order.Fulfilled() { + orders = append(orders, order) + } + } + return orders +} + +func (b *TestBroker) OpenPositions() []Position { + positions := make([]Position, 0, len(b.positions)) + for _, position := range b.positions { + if !position.Closed() { + positions = append(positions, position) + } + } + return positions +} + func (b *TestBroker) Orders() []Order { return b.orders } diff --git a/broker.go b/broker.go index 3b704f9..2ff7f99 100644 --- a/broker.go +++ b/broker.go @@ -57,6 +57,8 @@ type Broker interface { Candles(symbol string, frequency string, count int) (*DataFrame, error) MarketOrder(symbol string, units float64, stopLoss, takeProfit float64) (Order, error) NAV() float64 // NAV returns the net asset value of the account. + OpenOrders() []Order + OpenPositions() []Position // Orders returns a slice of orders that have been placed with the broker. If an order has been canceled or // filled, it will not be returned. Orders() []Order diff --git a/cmd/sma_crossover.go b/cmd/sma_crossover.go index b241dcf..75e4a4b 100644 --- a/cmd/sma_crossover.go +++ b/cmd/sma_crossover.go @@ -1,49 +1,46 @@ package main import ( - "fmt" + "log" auto "github.com/fivemoreminix/autotrader" ) type SMAStrategy struct { - i int + period1, period2 int } -func (s *SMAStrategy) Init(_trader *auto.Trader) { - fmt.Println("Init") - s.i = 0 +func (s *SMAStrategy) Init(_ *auto.Trader) { } -func (s *SMAStrategy) Next(_trader *auto.Trader) { - fmt.Println("Next " + fmt.Sprint(s.i)) - s.i++ +func (s *SMAStrategy) Next(t *auto.Trader) { + sma1 := t.Data().Closes().Rolling(s.period1).Mean() + sma2 := t.Data().Closes().Rolling(s.period2).Mean() + log.Println(t.Data().Close(-1) - sma1.Float(-1)) + // If the shorter SMA crosses above the longer SMA, buy. + if crossover(sma1, sma2) { + t.Buy(1000) + } else if crossover(sma2, sma1) { + t.Sell(1000) + } +} + +// crossover returns true if s1 crosses above s2 at the latest float. +func crossover(s1, s2 auto.Series) bool { + return s1.Float(-1) > s2.Float(-1) && s1.Float(-2) <= s2.Float(-2) } func main() { - // token := os.Environ["OANDA_TOKEN"] - // accountId := os.Environ["OANDA_ACCOUNT_ID"] - - // if token == "" || accountId == "" { - // fmt.Println("Please set OANDA_TOKEN and OANDA_ACCOUNT_ID environment variables") - // os.Exit(1) - // } - data, err := auto.EURUSD() if err != nil { panic(err) } auto.Backtest(auto.NewTrader(auto.TraderConfig{ - // auto.NewOandaBroker(auto.OandaConfig{ - // Token: "YOUR_TOKEN", - // AccountID: "101-001-14983263-001", - // DemoAccount: true, - // }), Broker: auto.NewTestBroker(nil, data, 10000, 50, 0.0002, 0), - Strategy: &SMAStrategy{}, + Strategy: &SMAStrategy{period1: 20, period2: 40}, Symbol: "EUR_USD", Frequency: "D", - CandlesToKeep: 100, + CandlesToKeep: 1000, })) } diff --git a/data.go b/data.go index 9caf250..b715103 100644 --- a/data.go +++ b/data.go @@ -1,6 +1,7 @@ package autotrader import ( + "bytes" "encoding/csv" "errors" "fmt" @@ -8,6 +9,8 @@ import ( "math" "os" "strconv" + "strings" + "text/tabwriter" "time" df "github.com/rocketlaunchr/dataframe-go" @@ -42,13 +45,14 @@ type Series interface { Values() []interface{} // Values is the same as ValueRange(0, -1). Float(i int) float64 Int(i int) int64 - String(i int) string + Str(i int) string Time(i int) time.Time } type Frame interface { Copy(start, end int) Frame Len() int + String() string // Easy access functions. Date(i int) time.Time @@ -67,6 +71,7 @@ type Frame interface { ContainsDOHLCV() bool // ContainsDOHLCV returns true if the frame contains the columns: Date, Open, High, Low, Close, Volume. PushCandle(date time.Time, open, high, low, close, volume float64) error + PushValues(values map[string]interface{}) error PushSeries(s ...Series) error RemoveSeries(name string) @@ -75,7 +80,7 @@ type Frame interface { Value(column string, i int) interface{} Float(column string, i int) float64 Int(column string, i int) int64 - String(column string, i int) string + Str(column string, i int) string // Time returns the value of the column at index i. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. Time(column string, i int) time.Time } @@ -374,7 +379,7 @@ func (s *DataSeries) Int(i int) int64 { } } -func (s *DataSeries) String(i int) string { +func (s *DataSeries) Str(i int) string { val := s.Value(i) if val == nil { return "" @@ -438,6 +443,56 @@ func (d *DataFrame) Len() int { return length } +func (d *DataFrame) String() string { + names := d.Names() // Defines the order of the columns. + series := make([]Series, len(names)) + for i, name := range names { + series[i] = d.Series(name) + } + + buffer := new(bytes.Buffer) + t := tabwriter.NewWriter(buffer, 0, 0, 1, ' ', 0) + fmt.Fprintf(t, "%T[%dx%d]\n", d, d.Len(), len(d.series)) + fmt.Fprintln(t, "\t", strings.Join(names, "\t"), "\t") + + printRow := func(i int) { + row := make([]string, len(series)) + for j, s := range series { + switch typ := s.Value(i).(type) { + case time.Time: + row[j] = typ.Format("2006-01-02 15:04:05") + case string: + row[j] = fmt.Sprintf("%q", typ) + default: + row[j] = fmt.Sprintf("%v", typ) + } + } + fmt.Fprintln(t, strconv.Itoa(i), "\t", strings.Join(row, "\t"), "\t") + } + + // Print the first ten rows and the last ten rows if the DataFrame has more than 20 rows. + if d.Len() > 20 { + for i := 0; i < 10; i++ { + printRow(i) + } + fmt.Fprintf(t, "...\t") + for range names { + fmt.Fprint(t, "\t") // Keeps alignment. + } + fmt.Fprintln(t) // Print new line character. + for i := 10; i > 0; i-- { + printRow(d.Len() - i) + } + } else { + for i := 0; i < d.Len(); i++ { + printRow(i) + } + } + + t.Flush() + return buffer.String() +} + // Date returns the value of the Date column at index i. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. // This is the equivalent to calling Time("Date", i). func (d *DataFrame) Date(i int) time.Time { @@ -541,6 +596,19 @@ func (d *DataFrame) PushCandle(date time.Time, open, high, low, close, volume fl return nil } +func (d *DataFrame) PushValues(values map[string]interface{}) error { + if len(d.series) == 0 { + return fmt.Errorf("DataFrame has no columns") // TODO: could create the columns here. + } + for name, value := range values { + if _, ok := d.series[name]; !ok { + return fmt.Errorf("DataFrame does not contain column %q", name) + } + d.series[name].Push(value) + } + return nil +} + func (d *DataFrame) PushSeries(series ...Series) error { if d.series == nil { d.series = make(map[string]Series, len(series)) @@ -558,6 +626,17 @@ func (d *DataFrame) PushSeries(series ...Series) error { return nil } +func (d *DataFrame) RemoveSeries(name string) { + s, ok := d.series[name] + if !ok { + return + } + s.SignalDisconnect("LengthChanged", d.onSeriesLengthChanged) + s.SignalDisconnect("NameChanged", d.onSeriesNameChanged) + delete(d.series, name) + delete(d.rowCounts, name) +} + func (d *DataFrame) onSeriesLengthChanged(args ...interface{}) { if len(args) != 2 { panic(fmt.Sprintf("expected two arguments, got %d", len(args))) @@ -586,17 +665,6 @@ func (d *DataFrame) onSeriesNameChanged(args ...interface{}) { d.series[newName].SignalConnect("NameChanged", d.onSeriesNameChanged, newName) } -func (d *DataFrame) RemoveSeries(name string) { - s, ok := d.series[name] - if !ok { - return - } - s.SignalDisconnect("LengthChanged", d.onSeriesLengthChanged) - s.SignalDisconnect("NameChanged", d.onSeriesNameChanged) - delete(d.series, name) - delete(d.rowCounts, name) -} - func (d *DataFrame) Names() []string { return maps.Keys(d.series) } @@ -654,7 +722,7 @@ func (d *DataFrame) Int(column string, i int) int64 { } // String returns the value of the column at index i casted to string. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, "" is returned. -func (d *DataFrame) String(column string, i int) string { +func (d *DataFrame) Str(column string, i int) string { val := d.Value(column, i) if val == nil { return "" diff --git a/trader.go b/trader.go index aa65094..1bb901b 100644 --- a/trader.go +++ b/trader.go @@ -9,6 +9,7 @@ import ( "time" "github.com/go-co-op/gocron" + "github.com/rocketlaunchr/dataframe-go" ) // Trader acts as the primary interface to the broker and strategy. To the strategy, it provides all the information @@ -25,7 +26,6 @@ type Trader struct { data *DataFrame sched *gocron.Scheduler - idx int stats *DataFrame // Performance (financial) reporting and statistics. } @@ -72,30 +72,71 @@ func (t *Trader) Run() { } } t.sched.Do(t.Tick) // Set the function to be run when the interval repeats. + + t.Init() t.sched.StartBlocking() } +func (t *Trader) Init() { + t.Strategy.Init(t) + t.stats = NewDataFrame( + NewDataSeries(dataframe.NewSeriesTime("Date", nil)), + NewDataSeries(dataframe.NewSeriesFloat64("Equity", nil)), + ) +} + // Tick updates the current state of the market and runs the strategy. func (t *Trader) Tick() { - t.Log.Println("Tick") - if t.idx == 0 { - t.Strategy.Init(t) - } - t.fetchData() - t.Strategy.Next(t) + t.fetchData() // Fetch the latest candlesticks from the broker. + // t.Log.Println(t.data.Close(-1)) + t.Strategy.Next(t) // Run the strategy. + + // Update the stats. + t.stats.PushValues(map[string]interface{}{ + "Date": t.data.Date(-1), + "Equity": t.Broker.NAV(), + }) } func (t *Trader) fetchData() { var err error t.data, err = t.Broker.Candles(t.Symbol, t.Frequency, t.CandlesToKeep) if err == ErrEOF { + t.EOF = true t.Log.Println("End of data") - t.sched.Clear() + if t.sched != nil && t.sched.IsRunning() { + t.sched.Clear() + } } else if err != nil { panic(err) // TODO: implement safe shutdown procedure } } +func (t *Trader) Buy(units float64) { + t.Log.Printf("Buy %f units", units) + t.closeOrdersAndPositions() + t.Broker.MarketOrder(t.Symbol, units, 0.0, 0.0) +} + +func (t *Trader) Sell(units float64) { + t.Log.Printf("Sell %f units", units) + t.closeOrdersAndPositions() + t.Broker.MarketOrder(t.Symbol, -units, 0.0, 0.0) +} + +func (t *Trader) closeOrdersAndPositions() { + for _, order := range t.Broker.OpenOrders() { + if order.Symbol() == t.Symbol { + order.Cancel() + } + } + for _, position := range t.Broker.OpenPositions() { + if position.Symbol() == t.Symbol { + position.Close() + } + } +} + type TraderConfig struct { Broker Broker Strategy Strategy @@ -114,6 +155,6 @@ func NewTrader(config TraderConfig) *Trader { Frequency: config.Frequency, CandlesToKeep: config.CandlesToKeep, Log: logger, - stats: NewDataFrame(nil), + stats: NewDataFrame(), } }