mirror of
https://github.com/lukewilson2002/autotrader.git
synced 2025-08-02 05:09:33 +00:00
Working on trading and backtesting loop
This commit is contained in:
@@ -16,12 +16,19 @@ var (
|
||||
)
|
||||
|
||||
func Backtest(trader *Trader) {
|
||||
switch broker := trader.Broker.(type) {
|
||||
case *TestBroker:
|
||||
trader.Init() // Initialize the trader and strategy.
|
||||
for !trader.EOF {
|
||||
trader.Tick()
|
||||
trader.Tick() // Allow the trader to process the current candlesticks.
|
||||
broker.Advance() // Give the trader access to the next candlestick.
|
||||
}
|
||||
log.Println("Backtest complete.")
|
||||
log.Println("Stats:")
|
||||
log.Println(trader.Stats())
|
||||
log.Println(trader.Stats().String())
|
||||
default:
|
||||
log.Fatalf("Backtesting is only supported with a TestBroker. Got %T", broker)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBroker is a broker that can be used for testing. It implements the Broker interface and fulfills orders
|
||||
@@ -54,11 +61,13 @@ func (b *TestBroker) CandleIndex() int {
|
||||
// Advance advances the test broker to the next candle in the input data. This should be done at the end of the
|
||||
// strategy loop.
|
||||
func (b *TestBroker) Advance() {
|
||||
if b.candleCount < b.Data.Len() {
|
||||
b.candleCount++
|
||||
}
|
||||
}
|
||||
|
||||
func (b *TestBroker) Candles(symbol string, frequency string, count int) (*DataFrame, error) {
|
||||
if b.Data != nil && b.candleCount > b.Data.Len() { // We have data and we are at the end of it.
|
||||
if b.Data != nil && b.candleCount >= b.Data.Len() { // We have data and we are at the end of it.
|
||||
return b.Data.Copy(0, -1).(*DataFrame), ErrEOF
|
||||
} else if b.DataBroker != nil && b.Data == nil { // We have a data broker but no data.
|
||||
// Fetch a lot of candles from the broker so we don't keep asking.
|
||||
@@ -137,6 +146,26 @@ func (b *TestBroker) NAV() float64 {
|
||||
return nav
|
||||
}
|
||||
|
||||
func (b *TestBroker) OpenOrders() []Order {
|
||||
orders := make([]Order, 0, len(b.orders))
|
||||
for _, order := range b.orders {
|
||||
if !order.Fulfilled() {
|
||||
orders = append(orders, order)
|
||||
}
|
||||
}
|
||||
return orders
|
||||
}
|
||||
|
||||
func (b *TestBroker) OpenPositions() []Position {
|
||||
positions := make([]Position, 0, len(b.positions))
|
||||
for _, position := range b.positions {
|
||||
if !position.Closed() {
|
||||
positions = append(positions, position)
|
||||
}
|
||||
}
|
||||
return positions
|
||||
}
|
||||
|
||||
func (b *TestBroker) Orders() []Order {
|
||||
return b.orders
|
||||
}
|
||||
|
@@ -57,6 +57,8 @@ type Broker interface {
|
||||
Candles(symbol string, frequency string, count int) (*DataFrame, error)
|
||||
MarketOrder(symbol string, units float64, stopLoss, takeProfit float64) (Order, error)
|
||||
NAV() float64 // NAV returns the net asset value of the account.
|
||||
OpenOrders() []Order
|
||||
OpenPositions() []Position
|
||||
// Orders returns a slice of orders that have been placed with the broker. If an order has been canceled or
|
||||
// filled, it will not be returned.
|
||||
Orders() []Order
|
||||
|
@@ -1,49 +1,46 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
auto "github.com/fivemoreminix/autotrader"
|
||||
)
|
||||
|
||||
type SMAStrategy struct {
|
||||
i int
|
||||
period1, period2 int
|
||||
}
|
||||
|
||||
func (s *SMAStrategy) Init(_trader *auto.Trader) {
|
||||
fmt.Println("Init")
|
||||
s.i = 0
|
||||
func (s *SMAStrategy) Init(_ *auto.Trader) {
|
||||
}
|
||||
|
||||
func (s *SMAStrategy) Next(_trader *auto.Trader) {
|
||||
fmt.Println("Next " + fmt.Sprint(s.i))
|
||||
s.i++
|
||||
func (s *SMAStrategy) Next(t *auto.Trader) {
|
||||
sma1 := t.Data().Closes().Rolling(s.period1).Mean()
|
||||
sma2 := t.Data().Closes().Rolling(s.period2).Mean()
|
||||
log.Println(t.Data().Close(-1) - sma1.Float(-1))
|
||||
// If the shorter SMA crosses above the longer SMA, buy.
|
||||
if crossover(sma1, sma2) {
|
||||
t.Buy(1000)
|
||||
} else if crossover(sma2, sma1) {
|
||||
t.Sell(1000)
|
||||
}
|
||||
}
|
||||
|
||||
// crossover returns true if s1 crosses above s2 at the latest float.
|
||||
func crossover(s1, s2 auto.Series) bool {
|
||||
return s1.Float(-1) > s2.Float(-1) && s1.Float(-2) <= s2.Float(-2)
|
||||
}
|
||||
|
||||
func main() {
|
||||
// token := os.Environ["OANDA_TOKEN"]
|
||||
// accountId := os.Environ["OANDA_ACCOUNT_ID"]
|
||||
|
||||
// if token == "" || accountId == "" {
|
||||
// fmt.Println("Please set OANDA_TOKEN and OANDA_ACCOUNT_ID environment variables")
|
||||
// os.Exit(1)
|
||||
// }
|
||||
|
||||
data, err := auto.EURUSD()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
auto.Backtest(auto.NewTrader(auto.TraderConfig{
|
||||
// auto.NewOandaBroker(auto.OandaConfig{
|
||||
// Token: "YOUR_TOKEN",
|
||||
// AccountID: "101-001-14983263-001",
|
||||
// DemoAccount: true,
|
||||
// }),
|
||||
Broker: auto.NewTestBroker(nil, data, 10000, 50, 0.0002, 0),
|
||||
Strategy: &SMAStrategy{},
|
||||
Strategy: &SMAStrategy{period1: 20, period2: 40},
|
||||
Symbol: "EUR_USD",
|
||||
Frequency: "D",
|
||||
CandlesToKeep: 100,
|
||||
CandlesToKeep: 1000,
|
||||
}))
|
||||
}
|
||||
|
98
data.go
98
data.go
@@ -1,6 +1,7 @@
|
||||
package autotrader
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -8,6 +9,8 @@ import (
|
||||
"math"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
df "github.com/rocketlaunchr/dataframe-go"
|
||||
@@ -42,13 +45,14 @@ type Series interface {
|
||||
Values() []interface{} // Values is the same as ValueRange(0, -1).
|
||||
Float(i int) float64
|
||||
Int(i int) int64
|
||||
String(i int) string
|
||||
Str(i int) string
|
||||
Time(i int) time.Time
|
||||
}
|
||||
|
||||
type Frame interface {
|
||||
Copy(start, end int) Frame
|
||||
Len() int
|
||||
String() string
|
||||
|
||||
// Easy access functions.
|
||||
Date(i int) time.Time
|
||||
@@ -67,6 +71,7 @@ type Frame interface {
|
||||
ContainsDOHLCV() bool // ContainsDOHLCV returns true if the frame contains the columns: Date, Open, High, Low, Close, Volume.
|
||||
|
||||
PushCandle(date time.Time, open, high, low, close, volume float64) error
|
||||
PushValues(values map[string]interface{}) error
|
||||
PushSeries(s ...Series) error
|
||||
RemoveSeries(name string)
|
||||
|
||||
@@ -75,7 +80,7 @@ type Frame interface {
|
||||
Value(column string, i int) interface{}
|
||||
Float(column string, i int) float64
|
||||
Int(column string, i int) int64
|
||||
String(column string, i int) string
|
||||
Str(column string, i int) string
|
||||
// Time returns the value of the column at index i. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned.
|
||||
Time(column string, i int) time.Time
|
||||
}
|
||||
@@ -374,7 +379,7 @@ func (s *DataSeries) Int(i int) int64 {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DataSeries) String(i int) string {
|
||||
func (s *DataSeries) Str(i int) string {
|
||||
val := s.Value(i)
|
||||
if val == nil {
|
||||
return ""
|
||||
@@ -438,6 +443,56 @@ func (d *DataFrame) Len() int {
|
||||
return length
|
||||
}
|
||||
|
||||
func (d *DataFrame) String() string {
|
||||
names := d.Names() // Defines the order of the columns.
|
||||
series := make([]Series, len(names))
|
||||
for i, name := range names {
|
||||
series[i] = d.Series(name)
|
||||
}
|
||||
|
||||
buffer := new(bytes.Buffer)
|
||||
t := tabwriter.NewWriter(buffer, 0, 0, 1, ' ', 0)
|
||||
fmt.Fprintf(t, "%T[%dx%d]\n", d, d.Len(), len(d.series))
|
||||
fmt.Fprintln(t, "\t", strings.Join(names, "\t"), "\t")
|
||||
|
||||
printRow := func(i int) {
|
||||
row := make([]string, len(series))
|
||||
for j, s := range series {
|
||||
switch typ := s.Value(i).(type) {
|
||||
case time.Time:
|
||||
row[j] = typ.Format("2006-01-02 15:04:05")
|
||||
case string:
|
||||
row[j] = fmt.Sprintf("%q", typ)
|
||||
default:
|
||||
row[j] = fmt.Sprintf("%v", typ)
|
||||
}
|
||||
}
|
||||
fmt.Fprintln(t, strconv.Itoa(i), "\t", strings.Join(row, "\t"), "\t")
|
||||
}
|
||||
|
||||
// Print the first ten rows and the last ten rows if the DataFrame has more than 20 rows.
|
||||
if d.Len() > 20 {
|
||||
for i := 0; i < 10; i++ {
|
||||
printRow(i)
|
||||
}
|
||||
fmt.Fprintf(t, "...\t")
|
||||
for range names {
|
||||
fmt.Fprint(t, "\t") // Keeps alignment.
|
||||
}
|
||||
fmt.Fprintln(t) // Print new line character.
|
||||
for i := 10; i > 0; i-- {
|
||||
printRow(d.Len() - i)
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < d.Len(); i++ {
|
||||
printRow(i)
|
||||
}
|
||||
}
|
||||
|
||||
t.Flush()
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
// Date returns the value of the Date column at index i. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned.
|
||||
// This is the equivalent to calling Time("Date", i).
|
||||
func (d *DataFrame) Date(i int) time.Time {
|
||||
@@ -541,6 +596,19 @@ func (d *DataFrame) PushCandle(date time.Time, open, high, low, close, volume fl
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DataFrame) PushValues(values map[string]interface{}) error {
|
||||
if len(d.series) == 0 {
|
||||
return fmt.Errorf("DataFrame has no columns") // TODO: could create the columns here.
|
||||
}
|
||||
for name, value := range values {
|
||||
if _, ok := d.series[name]; !ok {
|
||||
return fmt.Errorf("DataFrame does not contain column %q", name)
|
||||
}
|
||||
d.series[name].Push(value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DataFrame) PushSeries(series ...Series) error {
|
||||
if d.series == nil {
|
||||
d.series = make(map[string]Series, len(series))
|
||||
@@ -558,6 +626,17 @@ func (d *DataFrame) PushSeries(series ...Series) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DataFrame) RemoveSeries(name string) {
|
||||
s, ok := d.series[name]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.SignalDisconnect("LengthChanged", d.onSeriesLengthChanged)
|
||||
s.SignalDisconnect("NameChanged", d.onSeriesNameChanged)
|
||||
delete(d.series, name)
|
||||
delete(d.rowCounts, name)
|
||||
}
|
||||
|
||||
func (d *DataFrame) onSeriesLengthChanged(args ...interface{}) {
|
||||
if len(args) != 2 {
|
||||
panic(fmt.Sprintf("expected two arguments, got %d", len(args)))
|
||||
@@ -586,17 +665,6 @@ func (d *DataFrame) onSeriesNameChanged(args ...interface{}) {
|
||||
d.series[newName].SignalConnect("NameChanged", d.onSeriesNameChanged, newName)
|
||||
}
|
||||
|
||||
func (d *DataFrame) RemoveSeries(name string) {
|
||||
s, ok := d.series[name]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.SignalDisconnect("LengthChanged", d.onSeriesLengthChanged)
|
||||
s.SignalDisconnect("NameChanged", d.onSeriesNameChanged)
|
||||
delete(d.series, name)
|
||||
delete(d.rowCounts, name)
|
||||
}
|
||||
|
||||
func (d *DataFrame) Names() []string {
|
||||
return maps.Keys(d.series)
|
||||
}
|
||||
@@ -654,7 +722,7 @@ func (d *DataFrame) Int(column string, i int) int64 {
|
||||
}
|
||||
|
||||
// String returns the value of the column at index i casted to string. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, "" is returned.
|
||||
func (d *DataFrame) String(column string, i int) string {
|
||||
func (d *DataFrame) Str(column string, i int) string {
|
||||
val := d.Value(column, i)
|
||||
if val == nil {
|
||||
return ""
|
||||
|
57
trader.go
57
trader.go
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-co-op/gocron"
|
||||
"github.com/rocketlaunchr/dataframe-go"
|
||||
)
|
||||
|
||||
// Trader acts as the primary interface to the broker and strategy. To the strategy, it provides all the information
|
||||
@@ -25,7 +26,6 @@ type Trader struct {
|
||||
|
||||
data *DataFrame
|
||||
sched *gocron.Scheduler
|
||||
idx int
|
||||
stats *DataFrame // Performance (financial) reporting and statistics.
|
||||
}
|
||||
|
||||
@@ -72,30 +72,71 @@ func (t *Trader) Run() {
|
||||
}
|
||||
}
|
||||
t.sched.Do(t.Tick) // Set the function to be run when the interval repeats.
|
||||
|
||||
t.Init()
|
||||
t.sched.StartBlocking()
|
||||
}
|
||||
|
||||
func (t *Trader) Init() {
|
||||
t.Strategy.Init(t)
|
||||
t.stats = NewDataFrame(
|
||||
NewDataSeries(dataframe.NewSeriesTime("Date", nil)),
|
||||
NewDataSeries(dataframe.NewSeriesFloat64("Equity", nil)),
|
||||
)
|
||||
}
|
||||
|
||||
// Tick updates the current state of the market and runs the strategy.
|
||||
func (t *Trader) Tick() {
|
||||
t.Log.Println("Tick")
|
||||
if t.idx == 0 {
|
||||
t.Strategy.Init(t)
|
||||
}
|
||||
t.fetchData()
|
||||
t.Strategy.Next(t)
|
||||
t.fetchData() // Fetch the latest candlesticks from the broker.
|
||||
// t.Log.Println(t.data.Close(-1))
|
||||
t.Strategy.Next(t) // Run the strategy.
|
||||
|
||||
// Update the stats.
|
||||
t.stats.PushValues(map[string]interface{}{
|
||||
"Date": t.data.Date(-1),
|
||||
"Equity": t.Broker.NAV(),
|
||||
})
|
||||
}
|
||||
|
||||
func (t *Trader) fetchData() {
|
||||
var err error
|
||||
t.data, err = t.Broker.Candles(t.Symbol, t.Frequency, t.CandlesToKeep)
|
||||
if err == ErrEOF {
|
||||
t.EOF = true
|
||||
t.Log.Println("End of data")
|
||||
if t.sched != nil && t.sched.IsRunning() {
|
||||
t.sched.Clear()
|
||||
}
|
||||
} else if err != nil {
|
||||
panic(err) // TODO: implement safe shutdown procedure
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Trader) Buy(units float64) {
|
||||
t.Log.Printf("Buy %f units", units)
|
||||
t.closeOrdersAndPositions()
|
||||
t.Broker.MarketOrder(t.Symbol, units, 0.0, 0.0)
|
||||
}
|
||||
|
||||
func (t *Trader) Sell(units float64) {
|
||||
t.Log.Printf("Sell %f units", units)
|
||||
t.closeOrdersAndPositions()
|
||||
t.Broker.MarketOrder(t.Symbol, -units, 0.0, 0.0)
|
||||
}
|
||||
|
||||
func (t *Trader) closeOrdersAndPositions() {
|
||||
for _, order := range t.Broker.OpenOrders() {
|
||||
if order.Symbol() == t.Symbol {
|
||||
order.Cancel()
|
||||
}
|
||||
}
|
||||
for _, position := range t.Broker.OpenPositions() {
|
||||
if position.Symbol() == t.Symbol {
|
||||
position.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type TraderConfig struct {
|
||||
Broker Broker
|
||||
Strategy Strategy
|
||||
@@ -114,6 +155,6 @@ func NewTrader(config TraderConfig) *Trader {
|
||||
Frequency: config.Frequency,
|
||||
CandlesToKeep: config.CandlesToKeep,
|
||||
Log: logger,
|
||||
stats: NewDataFrame(nil),
|
||||
stats: NewDataFrame(),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user