Working on trading and backtesting loop

This commit is contained in:
Luke I. Wilson
2023-05-15 19:52:40 -05:00
parent b467d03c73
commit 9b6f7962f2
5 changed files with 191 additions and 54 deletions

View File

@@ -16,12 +16,19 @@ var (
) )
func Backtest(trader *Trader) { func Backtest(trader *Trader) {
switch broker := trader.Broker.(type) {
case *TestBroker:
trader.Init() // Initialize the trader and strategy.
for !trader.EOF { for !trader.EOF {
trader.Tick() trader.Tick() // Allow the trader to process the current candlesticks.
broker.Advance() // Give the trader access to the next candlestick.
} }
log.Println("Backtest complete.") log.Println("Backtest complete.")
log.Println("Stats:") log.Println("Stats:")
log.Println(trader.Stats()) log.Println(trader.Stats().String())
default:
log.Fatalf("Backtesting is only supported with a TestBroker. Got %T", broker)
}
} }
// TestBroker is a broker that can be used for testing. It implements the Broker interface and fulfills orders // TestBroker is a broker that can be used for testing. It implements the Broker interface and fulfills orders
@@ -54,11 +61,13 @@ func (b *TestBroker) CandleIndex() int {
// Advance advances the test broker to the next candle in the input data. This should be done at the end of the // Advance advances the test broker to the next candle in the input data. This should be done at the end of the
// strategy loop. // strategy loop.
func (b *TestBroker) Advance() { func (b *TestBroker) Advance() {
if b.candleCount < b.Data.Len() {
b.candleCount++ b.candleCount++
}
} }
func (b *TestBroker) Candles(symbol string, frequency string, count int) (*DataFrame, error) { func (b *TestBroker) Candles(symbol string, frequency string, count int) (*DataFrame, error) {
if b.Data != nil && b.candleCount > b.Data.Len() { // We have data and we are at the end of it. if b.Data != nil && b.candleCount >= b.Data.Len() { // We have data and we are at the end of it.
return b.Data.Copy(0, -1).(*DataFrame), ErrEOF return b.Data.Copy(0, -1).(*DataFrame), ErrEOF
} else if b.DataBroker != nil && b.Data == nil { // We have a data broker but no data. } else if b.DataBroker != nil && b.Data == nil { // We have a data broker but no data.
// Fetch a lot of candles from the broker so we don't keep asking. // Fetch a lot of candles from the broker so we don't keep asking.
@@ -137,6 +146,26 @@ func (b *TestBroker) NAV() float64 {
return nav return nav
} }
func (b *TestBroker) OpenOrders() []Order {
orders := make([]Order, 0, len(b.orders))
for _, order := range b.orders {
if !order.Fulfilled() {
orders = append(orders, order)
}
}
return orders
}
func (b *TestBroker) OpenPositions() []Position {
positions := make([]Position, 0, len(b.positions))
for _, position := range b.positions {
if !position.Closed() {
positions = append(positions, position)
}
}
return positions
}
func (b *TestBroker) Orders() []Order { func (b *TestBroker) Orders() []Order {
return b.orders return b.orders
} }

View File

@@ -57,6 +57,8 @@ type Broker interface {
Candles(symbol string, frequency string, count int) (*DataFrame, error) Candles(symbol string, frequency string, count int) (*DataFrame, error)
MarketOrder(symbol string, units float64, stopLoss, takeProfit float64) (Order, error) MarketOrder(symbol string, units float64, stopLoss, takeProfit float64) (Order, error)
NAV() float64 // NAV returns the net asset value of the account. NAV() float64 // NAV returns the net asset value of the account.
OpenOrders() []Order
OpenPositions() []Position
// Orders returns a slice of orders that have been placed with the broker. If an order has been canceled or // Orders returns a slice of orders that have been placed with the broker. If an order has been canceled or
// filled, it will not be returned. // filled, it will not be returned.
Orders() []Order Orders() []Order

View File

@@ -1,49 +1,46 @@
package main package main
import ( import (
"fmt" "log"
auto "github.com/fivemoreminix/autotrader" auto "github.com/fivemoreminix/autotrader"
) )
type SMAStrategy struct { type SMAStrategy struct {
i int period1, period2 int
} }
func (s *SMAStrategy) Init(_trader *auto.Trader) { func (s *SMAStrategy) Init(_ *auto.Trader) {
fmt.Println("Init")
s.i = 0
} }
func (s *SMAStrategy) Next(_trader *auto.Trader) { func (s *SMAStrategy) Next(t *auto.Trader) {
fmt.Println("Next " + fmt.Sprint(s.i)) sma1 := t.Data().Closes().Rolling(s.period1).Mean()
s.i++ sma2 := t.Data().Closes().Rolling(s.period2).Mean()
log.Println(t.Data().Close(-1) - sma1.Float(-1))
// If the shorter SMA crosses above the longer SMA, buy.
if crossover(sma1, sma2) {
t.Buy(1000)
} else if crossover(sma2, sma1) {
t.Sell(1000)
}
}
// crossover returns true if s1 crosses above s2 at the latest float.
func crossover(s1, s2 auto.Series) bool {
return s1.Float(-1) > s2.Float(-1) && s1.Float(-2) <= s2.Float(-2)
} }
func main() { func main() {
// token := os.Environ["OANDA_TOKEN"]
// accountId := os.Environ["OANDA_ACCOUNT_ID"]
// if token == "" || accountId == "" {
// fmt.Println("Please set OANDA_TOKEN and OANDA_ACCOUNT_ID environment variables")
// os.Exit(1)
// }
data, err := auto.EURUSD() data, err := auto.EURUSD()
if err != nil { if err != nil {
panic(err) panic(err)
} }
auto.Backtest(auto.NewTrader(auto.TraderConfig{ auto.Backtest(auto.NewTrader(auto.TraderConfig{
// auto.NewOandaBroker(auto.OandaConfig{
// Token: "YOUR_TOKEN",
// AccountID: "101-001-14983263-001",
// DemoAccount: true,
// }),
Broker: auto.NewTestBroker(nil, data, 10000, 50, 0.0002, 0), Broker: auto.NewTestBroker(nil, data, 10000, 50, 0.0002, 0),
Strategy: &SMAStrategy{}, Strategy: &SMAStrategy{period1: 20, period2: 40},
Symbol: "EUR_USD", Symbol: "EUR_USD",
Frequency: "D", Frequency: "D",
CandlesToKeep: 100, CandlesToKeep: 1000,
})) }))
} }

98
data.go
View File

@@ -1,6 +1,7 @@
package autotrader package autotrader
import ( import (
"bytes"
"encoding/csv" "encoding/csv"
"errors" "errors"
"fmt" "fmt"
@@ -8,6 +9,8 @@ import (
"math" "math"
"os" "os"
"strconv" "strconv"
"strings"
"text/tabwriter"
"time" "time"
df "github.com/rocketlaunchr/dataframe-go" df "github.com/rocketlaunchr/dataframe-go"
@@ -42,13 +45,14 @@ type Series interface {
Values() []interface{} // Values is the same as ValueRange(0, -1). Values() []interface{} // Values is the same as ValueRange(0, -1).
Float(i int) float64 Float(i int) float64
Int(i int) int64 Int(i int) int64
String(i int) string Str(i int) string
Time(i int) time.Time Time(i int) time.Time
} }
type Frame interface { type Frame interface {
Copy(start, end int) Frame Copy(start, end int) Frame
Len() int Len() int
String() string
// Easy access functions. // Easy access functions.
Date(i int) time.Time Date(i int) time.Time
@@ -67,6 +71,7 @@ type Frame interface {
ContainsDOHLCV() bool // ContainsDOHLCV returns true if the frame contains the columns: Date, Open, High, Low, Close, Volume. ContainsDOHLCV() bool // ContainsDOHLCV returns true if the frame contains the columns: Date, Open, High, Low, Close, Volume.
PushCandle(date time.Time, open, high, low, close, volume float64) error PushCandle(date time.Time, open, high, low, close, volume float64) error
PushValues(values map[string]interface{}) error
PushSeries(s ...Series) error PushSeries(s ...Series) error
RemoveSeries(name string) RemoveSeries(name string)
@@ -75,7 +80,7 @@ type Frame interface {
Value(column string, i int) interface{} Value(column string, i int) interface{}
Float(column string, i int) float64 Float(column string, i int) float64
Int(column string, i int) int64 Int(column string, i int) int64
String(column string, i int) string Str(column string, i int) string
// Time returns the value of the column at index i. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. // Time returns the value of the column at index i. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned.
Time(column string, i int) time.Time Time(column string, i int) time.Time
} }
@@ -374,7 +379,7 @@ func (s *DataSeries) Int(i int) int64 {
} }
} }
func (s *DataSeries) String(i int) string { func (s *DataSeries) Str(i int) string {
val := s.Value(i) val := s.Value(i)
if val == nil { if val == nil {
return "" return ""
@@ -438,6 +443,56 @@ func (d *DataFrame) Len() int {
return length return length
} }
func (d *DataFrame) String() string {
names := d.Names() // Defines the order of the columns.
series := make([]Series, len(names))
for i, name := range names {
series[i] = d.Series(name)
}
buffer := new(bytes.Buffer)
t := tabwriter.NewWriter(buffer, 0, 0, 1, ' ', 0)
fmt.Fprintf(t, "%T[%dx%d]\n", d, d.Len(), len(d.series))
fmt.Fprintln(t, "\t", strings.Join(names, "\t"), "\t")
printRow := func(i int) {
row := make([]string, len(series))
for j, s := range series {
switch typ := s.Value(i).(type) {
case time.Time:
row[j] = typ.Format("2006-01-02 15:04:05")
case string:
row[j] = fmt.Sprintf("%q", typ)
default:
row[j] = fmt.Sprintf("%v", typ)
}
}
fmt.Fprintln(t, strconv.Itoa(i), "\t", strings.Join(row, "\t"), "\t")
}
// Print the first ten rows and the last ten rows if the DataFrame has more than 20 rows.
if d.Len() > 20 {
for i := 0; i < 10; i++ {
printRow(i)
}
fmt.Fprintf(t, "...\t")
for range names {
fmt.Fprint(t, "\t") // Keeps alignment.
}
fmt.Fprintln(t) // Print new line character.
for i := 10; i > 0; i-- {
printRow(d.Len() - i)
}
} else {
for i := 0; i < d.Len(); i++ {
printRow(i)
}
}
t.Flush()
return buffer.String()
}
// Date returns the value of the Date column at index i. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. // Date returns the value of the Date column at index i. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned.
// This is the equivalent to calling Time("Date", i). // This is the equivalent to calling Time("Date", i).
func (d *DataFrame) Date(i int) time.Time { func (d *DataFrame) Date(i int) time.Time {
@@ -541,6 +596,19 @@ func (d *DataFrame) PushCandle(date time.Time, open, high, low, close, volume fl
return nil return nil
} }
func (d *DataFrame) PushValues(values map[string]interface{}) error {
if len(d.series) == 0 {
return fmt.Errorf("DataFrame has no columns") // TODO: could create the columns here.
}
for name, value := range values {
if _, ok := d.series[name]; !ok {
return fmt.Errorf("DataFrame does not contain column %q", name)
}
d.series[name].Push(value)
}
return nil
}
func (d *DataFrame) PushSeries(series ...Series) error { func (d *DataFrame) PushSeries(series ...Series) error {
if d.series == nil { if d.series == nil {
d.series = make(map[string]Series, len(series)) d.series = make(map[string]Series, len(series))
@@ -558,6 +626,17 @@ func (d *DataFrame) PushSeries(series ...Series) error {
return nil return nil
} }
func (d *DataFrame) RemoveSeries(name string) {
s, ok := d.series[name]
if !ok {
return
}
s.SignalDisconnect("LengthChanged", d.onSeriesLengthChanged)
s.SignalDisconnect("NameChanged", d.onSeriesNameChanged)
delete(d.series, name)
delete(d.rowCounts, name)
}
func (d *DataFrame) onSeriesLengthChanged(args ...interface{}) { func (d *DataFrame) onSeriesLengthChanged(args ...interface{}) {
if len(args) != 2 { if len(args) != 2 {
panic(fmt.Sprintf("expected two arguments, got %d", len(args))) panic(fmt.Sprintf("expected two arguments, got %d", len(args)))
@@ -586,17 +665,6 @@ func (d *DataFrame) onSeriesNameChanged(args ...interface{}) {
d.series[newName].SignalConnect("NameChanged", d.onSeriesNameChanged, newName) d.series[newName].SignalConnect("NameChanged", d.onSeriesNameChanged, newName)
} }
func (d *DataFrame) RemoveSeries(name string) {
s, ok := d.series[name]
if !ok {
return
}
s.SignalDisconnect("LengthChanged", d.onSeriesLengthChanged)
s.SignalDisconnect("NameChanged", d.onSeriesNameChanged)
delete(d.series, name)
delete(d.rowCounts, name)
}
func (d *DataFrame) Names() []string { func (d *DataFrame) Names() []string {
return maps.Keys(d.series) return maps.Keys(d.series)
} }
@@ -654,7 +722,7 @@ func (d *DataFrame) Int(column string, i int) int64 {
} }
// String returns the value of the column at index i casted to string. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, "" is returned. // String returns the value of the column at index i casted to string. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, "" is returned.
func (d *DataFrame) String(column string, i int) string { func (d *DataFrame) Str(column string, i int) string {
val := d.Value(column, i) val := d.Value(column, i)
if val == nil { if val == nil {
return "" return ""

View File

@@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/go-co-op/gocron" "github.com/go-co-op/gocron"
"github.com/rocketlaunchr/dataframe-go"
) )
// Trader acts as the primary interface to the broker and strategy. To the strategy, it provides all the information // Trader acts as the primary interface to the broker and strategy. To the strategy, it provides all the information
@@ -25,7 +26,6 @@ type Trader struct {
data *DataFrame data *DataFrame
sched *gocron.Scheduler sched *gocron.Scheduler
idx int
stats *DataFrame // Performance (financial) reporting and statistics. stats *DataFrame // Performance (financial) reporting and statistics.
} }
@@ -72,30 +72,71 @@ func (t *Trader) Run() {
} }
} }
t.sched.Do(t.Tick) // Set the function to be run when the interval repeats. t.sched.Do(t.Tick) // Set the function to be run when the interval repeats.
t.Init()
t.sched.StartBlocking() t.sched.StartBlocking()
} }
func (t *Trader) Init() {
t.Strategy.Init(t)
t.stats = NewDataFrame(
NewDataSeries(dataframe.NewSeriesTime("Date", nil)),
NewDataSeries(dataframe.NewSeriesFloat64("Equity", nil)),
)
}
// Tick updates the current state of the market and runs the strategy. // Tick updates the current state of the market and runs the strategy.
func (t *Trader) Tick() { func (t *Trader) Tick() {
t.Log.Println("Tick") t.fetchData() // Fetch the latest candlesticks from the broker.
if t.idx == 0 { // t.Log.Println(t.data.Close(-1))
t.Strategy.Init(t) t.Strategy.Next(t) // Run the strategy.
}
t.fetchData() // Update the stats.
t.Strategy.Next(t) t.stats.PushValues(map[string]interface{}{
"Date": t.data.Date(-1),
"Equity": t.Broker.NAV(),
})
} }
func (t *Trader) fetchData() { func (t *Trader) fetchData() {
var err error var err error
t.data, err = t.Broker.Candles(t.Symbol, t.Frequency, t.CandlesToKeep) t.data, err = t.Broker.Candles(t.Symbol, t.Frequency, t.CandlesToKeep)
if err == ErrEOF { if err == ErrEOF {
t.EOF = true
t.Log.Println("End of data") t.Log.Println("End of data")
if t.sched != nil && t.sched.IsRunning() {
t.sched.Clear() t.sched.Clear()
}
} else if err != nil { } else if err != nil {
panic(err) // TODO: implement safe shutdown procedure panic(err) // TODO: implement safe shutdown procedure
} }
} }
func (t *Trader) Buy(units float64) {
t.Log.Printf("Buy %f units", units)
t.closeOrdersAndPositions()
t.Broker.MarketOrder(t.Symbol, units, 0.0, 0.0)
}
func (t *Trader) Sell(units float64) {
t.Log.Printf("Sell %f units", units)
t.closeOrdersAndPositions()
t.Broker.MarketOrder(t.Symbol, -units, 0.0, 0.0)
}
func (t *Trader) closeOrdersAndPositions() {
for _, order := range t.Broker.OpenOrders() {
if order.Symbol() == t.Symbol {
order.Cancel()
}
}
for _, position := range t.Broker.OpenPositions() {
if position.Symbol() == t.Symbol {
position.Close()
}
}
}
type TraderConfig struct { type TraderConfig struct {
Broker Broker Broker Broker
Strategy Strategy Strategy Strategy
@@ -114,6 +155,6 @@ func NewTrader(config TraderConfig) *Trader {
Frequency: config.Frequency, Frequency: config.Frequency,
CandlesToKeep: config.CandlesToKeep, CandlesToKeep: config.CandlesToKeep,
Log: logger, Log: logger,
stats: NewDataFrame(nil), stats: NewDataFrame(),
} }
} }