diff --git a/backtesting.go b/backtesting.go index db8c57b..3ce7678 100644 --- a/backtesting.go +++ b/backtesting.go @@ -2,6 +2,7 @@ package autotrader import ( "errors" + "log" "strconv" "time" @@ -18,6 +19,9 @@ func Backtest(trader *Trader) { for !trader.EOF { trader.Tick() } + log.Println("Backtest complete.") + log.Println("Stats:") + log.Println(trader.Stats()) } // TestBroker is a broker that can be used for testing. It implements the Broker interface and fulfills orders diff --git a/backtesting_test.go b/backtesting_test.go index 1cd7801..255e192 100644 --- a/backtesting_test.go +++ b/backtesting_test.go @@ -4,8 +4,6 @@ import ( "strings" "testing" "time" - - df "github.com/rocketlaunchr/dataframe-go" ) const testDataCSV = `date,open,high,low,close,volume @@ -19,8 +17,8 @@ const testDataCSV = `date,open,high,low,close,volume 2022-01-08,1.25,1.3,1.2,1.1,150 2022-01-09,1.1,1.4,1.0,1.3,220` -func newTestingDataframe() *df.DataFrame { - data, err := ReadDataCSVFromReader(strings.NewReader(testDataCSV), DataCSVLayout{ +func newTestingDataframe() *DataFrame { + data, err := DataFrameFromCSVReaderLayout(strings.NewReader(testDataCSV), DataCSVLayout{ LatestFirst: false, DateFormat: "2006-01-02", Date: "date", @@ -37,7 +35,7 @@ func newTestingDataframe() *df.DataFrame { } func TestBacktestingBrokerCandles(t *testing.T) { - data := NewDataFrame(newTestingDataframe()) + data := newTestingDataframe() broker := NewTestBroker(nil, data, 0, 0, 0, 0) candles, err := broker.Candles("EUR_USD", "D", 3) @@ -88,7 +86,7 @@ func TestBacktestingBrokerFunctions(t *testing.T) { } func TestBacktestingBrokerOrders(t *testing.T) { - data := NewDataFrame(newTestingDataframe()) + data := newTestingDataframe() broker := NewTestBroker(nil, data, 100_000, 50, 0, 0) timeBeforeOrder := time.Now() order, err := broker.MarketOrder("EUR_USD", 50_000, 0, 0) // Buy 50,000 USD for 1000 EUR with no stop loss or take profit diff --git a/cmd/sma_crossover.go b/cmd/sma_crossover.go index 2c1a3bb..b241dcf 100644 --- a/cmd/sma_crossover.go +++ b/cmd/sma_crossover.go @@ -29,16 +29,7 @@ func main() { // os.Exit(1) // } - data, err := auto.ReadDataCSV("./EUR_USD Historical Data.csv", auto.DataCSVLayout{ - LatestFirst: true, - DateFormat: "01/02/2006", - Date: "\ufeff\"Date\"", - Open: "Open", - High: "High", - Low: "Low", - Close: "Price", - Volume: "Vol.", - }) + data, err := auto.EURUSD() if err != nil { panic(err) } @@ -49,7 +40,7 @@ func main() { // AccountID: "101-001-14983263-001", // DemoAccount: true, // }), - Broker: auto.NewTestBroker(nil, auto.NewDataFrame(data), 10000, 50, 0.0002, 0), + Broker: auto.NewTestBroker(nil, data, 10000, 50, 0.0002, 0), Strategy: &SMAStrategy{}, Symbol: "EUR_USD", Frequency: "D", diff --git a/data.go b/data.go index 3d3a629..9caf250 100644 --- a/data.go +++ b/data.go @@ -11,6 +11,7 @@ import ( "time" df "github.com/rocketlaunchr/dataframe-go" + "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) @@ -23,7 +24,11 @@ func EasyIndex(i, n int) int { } type Series interface { + Signaler + Copy(start, end int) Series + Name() string // Name returns the immutable name of the Series. + SetName(name string) Series Len() int // Statistical functions. @@ -58,10 +63,14 @@ type Frame interface { Lows() Series Closes() Series Volumes() Series + Contains(names ...string) bool // Contains returns true if the frame contains all the columns specified. + ContainsDOHLCV() bool // ContainsDOHLCV returns true if the frame contains the columns: Date, Open, High, Low, Close, Volume. - PushCandle(date time.Time, open, high, low, close, volume float64) Frame - // AddSeries(name string, s Series) error + PushCandle(date time.Time, open, high, low, close, volume float64) error + PushSeries(s ...Series) error + RemoveSeries(name string) + Names() []string Series(name string) Series Value(column string, i int) interface{} Float(column string, i int) float64 @@ -250,203 +259,15 @@ func (s *RollingSeries) Value(i int) interface{} { } // DataSeries is a Series that wraps a column of data. The data can be of the following types: float64, int64, string, or time.Time. +// +// Signals: +// - LengthChanged(int) - when the data is appended or an item is removed. +// - NameChanged(string) - when the name is changed. type DataSeries struct { + SignalManager data df.Series } -type DataFrame struct { - data *df.DataFrame // DataFrame with a Date, Open, High, Low, Close, and Volume column. -} - -// Copy copies the DataFrame from start to end (inclusive). If end is -1, it will copy to the end of the DataFrame. If start is out of bounds, nil is returned. -func (d *DataFrame) Copy(start, end int) Frame { - var _end *int - if start < 0 || start >= d.Len() { - return nil - } else if end >= 0 { - if end < start { - return nil - } - _end = &end - } - return &DataFrame{d.data.Copy(df.Range{Start: &start, End: _end})} -} - -// Len returns the number of rows in the DataFrame or 0 if the DataFrame is nil. -func (d *DataFrame) Len() int { - if d.data == nil { - return 0 - } - return d.data.NRows() -} - -// Date returns the value of the Date column at index i. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. -// This is the equivalent to calling Time("Date", i). -func (d *DataFrame) Date(i int) time.Time { - return d.Time("Date", i) -} - -// Open returns the open price of the candle at index i. The first candle is at index 0. A negative value for i (-n) can be used to get n candles from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. -// This is the equivalent to calling Float("Open", i). -func (d *DataFrame) Open(i int) float64 { - return d.Float("Open", i) -} - -// High returns the high price of the candle at index i. The first candle is at index 0. A negative value for i (-n) can be used to get n candles from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. -// This is the equivalent to calling Float("High", i). -func (d *DataFrame) High(i int) float64 { - return d.Float("High", i) -} - -// Low returns the low price of the candle at index i. The first candle is at index 0. A negative value for i (-n) can be used to get n candles from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. -// This is the equivalent to calling Float("Low", i). -func (d *DataFrame) Low(i int) float64 { - return d.Float("Low", i) -} - -// Close returns the close price of the candle at index i. The first candle is at index 0. A negative value for i (-n) can be used to get n candles from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. -// This is the equivalent to calling Float("Close", i). -func (d *DataFrame) Close(i int) float64 { - return d.Float("Close", i) -} - -// Volume returns the volume of the candle at index i. The first candle is at index 0. A negative value for i (-n) can be used to get n candles from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. -// This is the equivalent to calling Float("Volume", i). -func (d *DataFrame) Volume(i int) float64 { - return d.Float("Volume", i) -} - -// Dates returns a Series of all the dates in the DataFrame. -func (d *DataFrame) Dates() Series { - return d.Series("Date") -} - -// Opens returns a Series of all the open prices in the DataFrame. -func (d *DataFrame) Opens() Series { - return d.Series("Open") -} - -// Highs returns a Series of all the high prices in the DataFrame. -func (d *DataFrame) Highs() Series { - return d.Series("High") -} - -// Lows returns a Series of all the low prices in the DataFrame. -func (d *DataFrame) Lows() Series { - return d.Series("Low") -} - -// Closes returns a Series of all the close prices in the DataFrame. -func (d *DataFrame) Closes() Series { - return d.Series("Close") -} - -// Volumes returns a Series of all the volumes in the DataFrame. -func (d *DataFrame) Volumes() Series { - return d.Series("Volume") -} - -func (d *DataFrame) PushCandle(date time.Time, open, high, low, close, volume float64) Frame { - if d.data == nil { - d.data = df.NewDataFrame([]df.Series{ - df.NewSeriesTime("Date", nil, date), - df.NewSeriesFloat64("Open", nil, open), - df.NewSeriesFloat64("High", nil, high), - df.NewSeriesFloat64("Low", nil, low), - df.NewSeriesFloat64("Close", nil, close), - df.NewSeriesFloat64("Volume", nil, volume), - }...) - return d - } - d.data.Append(nil, date, open, high, low, close, volume) - return d -} - -// Series returns a Series of the column with the given name. If the column does not exist, nil is returned. -func (d *DataFrame) Series(name string) Series { - if d.data == nil { - return nil - } - colIdx, err := d.data.NameToColumn(name) - if err != nil { - return nil - } - return &DataSeries{d.data.Series[colIdx]} -} - -// Value returns the value of the column at index i. The first value is at index 0. A negative value for i can be used to get i values from the latest, like Python's negative indexing. If i is out of bounds, nil is returned. -func (d *DataFrame) Value(column string, i int) interface{} { - if d.data == nil { - return nil - } - i = EasyIndex(i, d.Len()) // Allow for negative indexing. - colIdx, err := d.data.NameToColumn(column) - if err != nil || i < 0 || i >= d.Len() { // Prevent out of bounds access. - return nil - } - return d.data.Series[colIdx].Value(i) -} - -// Float returns the value of the column at index i casted to float64. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. -func (d *DataFrame) Float(column string, i int) float64 { - val := d.Value(column, i) - if val == nil { - return 0 - } - switch val := val.(type) { - case float64: - return val - default: - return 0 - } -} - -// Int returns the value of the column at index i casted to int. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. -func (d *DataFrame) Int(column string, i int) int64 { - val := d.Value(column, i) - if val == nil { - return 0 - } - switch val := val.(type) { - case int64: - return val - default: - return 0 - } -} - -// String returns the value of the column at index i casted to string. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, "" is returned. -func (d *DataFrame) String(column string, i int) string { - val := d.Value(column, i) - if val == nil { - return "" - } - switch val := val.(type) { - case string: - return val - default: - return "" - } -} - -// Time returns the value of the column at index i casted to time.Time. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, time.Time{} is returned. -func (d *DataFrame) Time(column string, i int) time.Time { - val := d.Value(column, i) - if val == nil { - return time.Time{} - } - switch val := val.(type) { - case time.Time: - return val - default: - return time.Time{} - } -} - -func NewDataFrame(data *df.DataFrame) *DataFrame { - return &DataFrame{data} -} - // Copy copies the Series from start to end (inclusive). If end is -1, it will copy to the end of the Series. If start is out of bounds, nil is returned. func (s *DataSeries) Copy(start, end int) Series { var _end *int @@ -458,7 +279,20 @@ func (s *DataSeries) Copy(start, end int) Series { } _end = &end } - return &DataSeries{s.data.Copy(df.Range{Start: &start, End: _end})} + return &DataSeries{SignalManager{}, s.data.Copy(df.Range{Start: &start, End: _end})} +} + +func (s *DataSeries) Name() string { + return s.data.Name() +} + +func (s *DataSeries) SetName(name string) Series { + if name == s.Name() { + return s + } + s.data.Rename(name) + s.SignalEmit("NameChanged", name) + return s } func (s *DataSeries) Len() int { @@ -475,6 +309,7 @@ func (s *DataSeries) Rolling(period int) *RollingSeries { func (s *DataSeries) Push(value interface{}) Series { if s.data != nil { s.data.Append(value) + s.SignalEmit("LengthChanged", s.Len()) } return s } @@ -493,12 +328,10 @@ func (s *DataSeries) ValueRange(start, end int) []interface{} { return nil } start = EasyIndex(start, s.Len()) - if start < 0 || start >= s.Len() || end >= s.Len() { + if start < 0 || start >= s.Len() || end >= s.Len() || start > end { return nil } else if end < 0 { end = s.Len() - 1 - } else if start > end { - return nil } items := make([]interface{}, end-start+1) @@ -567,6 +400,293 @@ func (s *DataSeries) Time(i int) time.Time { } } +func NewDataSeries(data df.Series) *DataSeries { + return &DataSeries{SignalManager{}, data} +} + +type DataFrame struct { + series map[string]Series + rowCounts map[string]int + // data *df.DataFrame // DataFrame with a Date, Open, High, Low, Close, and Volume column. +} + +// Copy copies the DataFrame from start to end (inclusive). If end is -1, it will copy to the end of the DataFrame. If start is out of bounds, nil is returned. +func (d *DataFrame) Copy(start, end int) Frame { + out := &DataFrame{} + for _, v := range d.series { + newSeries := v.Copy(start, end) + out.PushSeries(newSeries) + } + return out +} + +// Len returns the number of rows in the DataFrame or 0 if the DataFrame is nil. A value less than zero means the +// DataFrame has Series of varying lengths. +func (d *DataFrame) Len() int { + if len(d.series) == 0 { + return 0 + } + // Check if all the Series have the same length. + var length int + for _, v := range d.rowCounts { + if length == 0 { + length = v + } else if length != v { + return -1 + } + } + return length +} + +// Date returns the value of the Date column at index i. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. +// This is the equivalent to calling Time("Date", i). +func (d *DataFrame) Date(i int) time.Time { + return d.Time("Date", i) +} + +// Open returns the open price of the candle at index i. The first candle is at index 0. A negative value for i (-n) can be used to get n candles from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. +// This is the equivalent to calling Float("Open", i). +func (d *DataFrame) Open(i int) float64 { + return d.Float("Open", i) +} + +// High returns the high price of the candle at index i. The first candle is at index 0. A negative value for i (-n) can be used to get n candles from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. +// This is the equivalent to calling Float("High", i). +func (d *DataFrame) High(i int) float64 { + return d.Float("High", i) +} + +// Low returns the low price of the candle at index i. The first candle is at index 0. A negative value for i (-n) can be used to get n candles from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. +// This is the equivalent to calling Float("Low", i). +func (d *DataFrame) Low(i int) float64 { + return d.Float("Low", i) +} + +// Close returns the close price of the candle at index i. The first candle is at index 0. A negative value for i (-n) can be used to get n candles from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. +// This is the equivalent to calling Float("Close", i). +func (d *DataFrame) Close(i int) float64 { + return d.Float("Close", i) +} + +// Volume returns the volume of the candle at index i. The first candle is at index 0. A negative value for i (-n) can be used to get n candles from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. +// This is the equivalent to calling Float("Volume", i). +func (d *DataFrame) Volume(i int) float64 { + return d.Float("Volume", i) +} + +// Dates returns a Series of all the dates in the DataFrame. +func (d *DataFrame) Dates() Series { + return d.Series("Date") +} + +// Opens returns a Series of all the open prices in the DataFrame. +func (d *DataFrame) Opens() Series { + return d.Series("Open") +} + +// Highs returns a Series of all the high prices in the DataFrame. +func (d *DataFrame) Highs() Series { + return d.Series("High") +} + +// Lows returns a Series of all the low prices in the DataFrame. +func (d *DataFrame) Lows() Series { + return d.Series("Low") +} + +// Closes returns a Series of all the close prices in the DataFrame. +func (d *DataFrame) Closes() Series { + return d.Series("Close") +} + +// Volumes returns a Series of all the volumes in the DataFrame. +func (d *DataFrame) Volumes() Series { + return d.Series("Volume") +} + +func (d *DataFrame) Contains(names ...string) bool { + for _, name := range names { + if _, ok := d.series[name]; !ok { + return false + } + } + return true +} + +func (d *DataFrame) ContainsDOHLCV() bool { + return d.Contains("Date", "Open", "High", "Low", "Close", "Volume") +} + +func (d *DataFrame) PushCandle(date time.Time, open, high, low, close, volume float64) error { + if len(d.series) == 0 { + d.PushSeries([]Series{ + NewDataSeries(df.NewSeriesTime("Date", nil, date)), + NewDataSeries(df.NewSeriesFloat64("Open", nil, open)), + NewDataSeries(df.NewSeriesFloat64("High", nil, high)), + NewDataSeries(df.NewSeriesFloat64("Low", nil, low)), + NewDataSeries(df.NewSeriesFloat64("Close", nil, close)), + NewDataSeries(df.NewSeriesFloat64("Volume", nil, volume)), + }...) + return nil + } + if !d.ContainsDOHLCV() { + return fmt.Errorf("DataFrame does not contain Date, Open, High, Low, Close, Volume columns") + } + d.series["Date"].Push(date) + d.series["Open"].Push(open) + d.series["High"].Push(high) + d.series["Low"].Push(low) + d.series["Close"].Push(close) + d.series["Volume"].Push(volume) + return nil +} + +func (d *DataFrame) PushSeries(series ...Series) error { + if d.series == nil { + d.series = make(map[string]Series, len(series)) + d.rowCounts = make(map[string]int, len(series)) + } + + for _, s := range series { + name := s.Name() + s.SignalConnect("LengthChanged", d.onSeriesLengthChanged, name) + s.SignalConnect("NameChanged", d.onSeriesNameChanged, name) + d.series[name] = s + d.rowCounts[name] = s.Len() + } + + return nil +} + +func (d *DataFrame) onSeriesLengthChanged(args ...interface{}) { + if len(args) != 2 { + panic(fmt.Sprintf("expected two arguments, got %d", len(args))) + } + newLen := args[0].(int) + name := args[1].(string) + d.rowCounts[name] = newLen +} + +func (d *DataFrame) onSeriesNameChanged(args ...interface{}) { + if len(args) != 2 { + panic(fmt.Sprintf("expected two arguments, got %d", len(args))) + } + newName := args[0].(string) + oldName := args[1].(string) + + d.series[newName] = d.series[oldName] + d.rowCounts[newName] = d.rowCounts[oldName] + delete(d.series, oldName) + delete(d.rowCounts, oldName) + + // Reconnect our signal handlers to update the name we use in the handlers. + d.series[newName].SignalDisconnect("LengthChanged", d.onSeriesLengthChanged) + d.series[newName].SignalDisconnect("NameChanged", d.onSeriesNameChanged) + d.series[newName].SignalConnect("LengthChanged", d.onSeriesLengthChanged, newName) + d.series[newName].SignalConnect("NameChanged", d.onSeriesNameChanged, newName) +} + +func (d *DataFrame) RemoveSeries(name string) { + s, ok := d.series[name] + if !ok { + return + } + s.SignalDisconnect("LengthChanged", d.onSeriesLengthChanged) + s.SignalDisconnect("NameChanged", d.onSeriesNameChanged) + delete(d.series, name) + delete(d.rowCounts, name) +} + +func (d *DataFrame) Names() []string { + return maps.Keys(d.series) +} + +// Series returns a Series of the column with the given name. If the column does not exist, nil is returned. +func (d *DataFrame) Series(name string) Series { + if len(d.series) == 0 { + return nil + } + v, ok := d.series[name] + if !ok { + return nil + } + return v +} + +// Value returns the value of the column at index i. The first value is at index 0. A negative value for i can be used to get i values from the latest, like Python's negative indexing. If i is out of bounds, nil is returned. +func (d *DataFrame) Value(column string, i int) interface{} { + if len(d.series) == 0 { + return nil + } + i = EasyIndex(i, d.Len()) // Allow for negative indexing. + if i < 0 || i >= d.Len() { // Prevent out of bounds access. + return nil + } + return d.series[column].Value(i) +} + +// Float returns the value of the column at index i casted to float64. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. +func (d *DataFrame) Float(column string, i int) float64 { + val := d.Value(column, i) + if val == nil { + return 0 + } + switch val := val.(type) { + case float64: + return val + default: + return 0 + } +} + +// Int returns the value of the column at index i casted to int. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, 0 is returned. +func (d *DataFrame) Int(column string, i int) int64 { + val := d.Value(column, i) + if val == nil { + return 0 + } + switch val := val.(type) { + case int64: + return val + default: + return 0 + } +} + +// String returns the value of the column at index i casted to string. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, "" is returned. +func (d *DataFrame) String(column string, i int) string { + val := d.Value(column, i) + if val == nil { + return "" + } + switch val := val.(type) { + case string: + return val + default: + return "" + } +} + +// Time returns the value of the column at index i casted to time.Time. The first value is at index 0. A negative value for i (-n) can be used to get n values from the latest, like Python's negative indexing. If i is out of bounds, time.Time{} is returned. +func (d *DataFrame) Time(column string, i int) time.Time { + val := d.Value(column, i) + if val == nil { + return time.Time{} + } + switch val := val.(type) { + case time.Time: + return val + default: + return time.Time{} + } +} + +func NewDataFrame(series ...Series) *DataFrame { + d := &DataFrame{} + d.PushSeries(series...) + return d +} + type DataCSVLayout struct { LatestFirst bool // Whether the latest data is first in the dataframe. If false, the latest data is last. DateFormat string // The format of the date column. Example: "03/22/2006". See https://pkg.go.dev/time#pkg-constants for more information. @@ -578,17 +698,8 @@ type DataCSVLayout struct { Volume string } -func ReadDataCSV(path string, layout DataCSVLayout) (*df.DataFrame, error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - defer f.Close() - return ReadDataCSVFromReader(f, layout) -} - -func ReadEURUSDDataCSV() (*df.DataFrame, error) { - return ReadDataCSV("./EUR_USD Historical Data.csv", DataCSVLayout{ +func EURUSD() (*DataFrame, error) { + return DataFrameFromCSVLayout("./EUR_USD Historical Data.csv", DataCSVLayout{ LatestFirst: true, DateFormat: "01/02/2006", Date: "\ufeff\"Date\"", @@ -600,8 +711,17 @@ func ReadEURUSDDataCSV() (*df.DataFrame, error) { }) } -func ReadDataCSVFromReader(r io.Reader, layout DataCSVLayout) (*df.DataFrame, error) { - data, err := ReadCSVFromReader(r, layout.DateFormat, layout.LatestFirst) +func DataFrameFromCSVLayout(path string, layout DataCSVLayout) (*DataFrame, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + return DataFrameFromCSVReaderLayout(f, layout) +} + +func DataFrameFromCSVReaderLayout(r io.Reader, layout DataCSVLayout) (*DataFrame, error) { + data, err := DataFrameFromCSVReader(r, layout.DateFormat, layout.LatestFirst) if err != nil { return data, err } @@ -626,23 +746,19 @@ func ReadDataCSVFromReader(r io.Reader, layout DataCSVLayout) (*df.DataFrame, er data.RemoveSeries(name) continue } - idx, err := data.NameToColumn(name) - if err != nil { - panic(err) - } - data.Series[idx].Rename(newName) + data.Series(name).SetName(newName) } - err = data.ReorderColumns([]string{"Date", "Open", "High", "Low", "Close", "Volume"}) - if err != nil { - return data, err - } + // err = data.ReorderColumns([]string{"Date", "Open", "High", "Low", "Close", "Volume"}) + // if err != nil { + // return data, err + // } // TODO: Reverse the dataframe if the latest data is first. return data, nil } -func ReadCSVFromReader(r io.Reader, dateLayout string, readReversed bool) (*df.DataFrame, error) { +func DataFrameFromCSVReader(r io.Reader, dateLayout string, readReversed bool) (*DataFrame, error) { csv := csv.NewReader(r) csv.LazyQuotes = true records, err := csv.ReadAll() @@ -654,7 +770,7 @@ func ReadCSVFromReader(r io.Reader, dateLayout string, readReversed bool) (*df.D return nil, errors.New("csv file must have at least 2 rows") } - seriesSlice := make([]df.Series, 0, 12) + dfSeriesSlice := make([]df.Series, 0, 12) // TODO: change Capacity to Size. initOptions := &df.SeriesInit{Capacity: len(records) - 1} @@ -674,7 +790,7 @@ func ReadCSVFromReader(r io.Reader, dateLayout string, readReversed bool) (*df.D } // Create the series columns and label them. - seriesSlice = append(seriesSlice, series) + dfSeriesSlice = append(dfSeriesSlice, series) } // Set the direction to iterate the records. @@ -694,7 +810,7 @@ func ReadCSVFromReader(r io.Reader, dateLayout string, readReversed bool) (*df.D // Add rows to the series. for j, val := range rec { - series := seriesSlice[j] + series := dfSeriesSlice[j] switch series.Type() { case "float64": val, err := strconv.ParseFloat(val, 64) @@ -720,11 +836,15 @@ func ReadCSVFromReader(r io.Reader, dateLayout string, readReversed bool) (*df.D case "string": series.Append(val) } - seriesSlice[j] = series + dfSeriesSlice[j] = series } } // NOTE: we specifically construct the DataFrame at the end of the function because it likes to set // state like number of rows and columns at initialization and won't let you change it later. - return df.NewDataFrame(seriesSlice...), nil + seriesSlice := make([]Series, len(dfSeriesSlice)) + for i, series := range dfSeriesSlice { + seriesSlice[i] = NewDataSeries(series) + } + return NewDataFrame(seriesSlice...), nil } diff --git a/data_test.go b/data_test.go index fa243dd..0a96fce 100644 --- a/data_test.go +++ b/data_test.go @@ -6,18 +6,15 @@ import ( ) func newTestingDataFrame() *DataFrame { - _dataframe, err := ReadEURUSDDataCSV() + data, err := EURUSD() if err != nil { - return nil + panic(err) } - return NewDataFrame(_dataframe) + return data } func TestDataSeries(t *testing.T) { data := newTestingDataFrame() - if data == nil { - t.Fatal("Could not create DataFrame") - } dates, closes := data.Dates(), data.Closes() @@ -39,9 +36,6 @@ func TestDataSeries(t *testing.T) { func TestDataFrame(t *testing.T) { data := newTestingDataFrame() - if data == nil { - t.Fatal("Could not create DataFrame") - } if data.Len() != 2610 { t.Fatalf("Expected 2610 rows, got %d", data.Len()) @@ -55,7 +49,11 @@ func TestDataFrame(t *testing.T) { t.Fatalf("Expected 2013-05-13, got %s", date.Format(time.DateOnly)) } - data.PushCandle(time.Date(2023, 5, 14, 0, 0, 0, 0, time.UTC), 1.0, 1.0, 1.0, 1.0, 1) + err := data.PushCandle(time.Date(2023, 5, 14, 0, 0, 0, 0, time.UTC), 1.0, 1.0, 1.0, 1.0, 1) + if err != nil { + t.Log(data.Names()) + t.Fatalf("Expected no error, got %s", err) + } if data.Len() != 2611 { t.Fatalf("Expected 2611 rows, got %d", data.Len()) } @@ -65,37 +63,34 @@ func TestDataFrame(t *testing.T) { } func TestReadDataCSV(t *testing.T) { - data, err := ReadEURUSDDataCSV() - if err != nil { - t.Fatal(err) - } + data := newTestingDataFrame() - if data.NRows() != 2610 { - t.Fatalf("Expected 2610 rows, got %d", data.NRows()) + if data.Len() != 2610 { + t.Fatalf("Expected 2610 rows, got %d", data.Len()) } if len(data.Names()) != 6 { t.Fatalf("Expected 6 columns, got %d", len(data.Names())) } - if data.Series[0].Name() != "Date" { - t.Fatalf("Expected Date column, got %s", data.Series[0].Name()) + if data.Series("Date") == nil { + t.Fatalf("Expected Date column, got nil") } - if data.Series[1].Name() != "Open" { - t.Fatalf("Expected Open column, got %s", data.Series[1].Name()) + if data.Series("Open") == nil { + t.Fatalf("Expected Open column, got nil") } - if data.Series[2].Name() != "High" { - t.Fatalf("Expected High column, got %s", data.Series[2].Name()) + if data.Series("High") == nil { + t.Fatalf("Expected High column, got nil") } - if data.Series[3].Name() != "Low" { - t.Fatalf("Expected Low column, got %s", data.Series[3].Name()) + if data.Series("Low") == nil { + t.Fatalf("Expected Low column, got nil") } - if data.Series[4].Name() != "Close" { - t.Fatalf("Expected Close column, got %s", data.Series[4].Name()) + if data.Series("Close") == nil { + t.Fatalf("Expected Close column, got nil") } - if data.Series[5].Name() != "Volume" { - t.Fatalf("Expected Volume column, got %s", data.Series[5].Name()) + if data.Series("Volume") == nil { + t.Fatalf("Expected Volume column, got nil") } - if data.Series[0].Type() != "time" { - t.Fatalf("Expected Date column type time, got %s", data.Series[0].Type()) + if data.Series("Date").Time(0).Equal(time.Time{}) { + t.Fatalf("Expected Date column to have type time.Time, got %s", data.Value("Date", 0)) } } diff --git a/signals.go b/signals.go index 1dd3daa..dcd049c 100644 --- a/signals.go +++ b/signals.go @@ -3,60 +3,68 @@ package autotrader import "reflect" type Signaler interface { - SignalConnect(signal string, handler func(interface{})) error // SignalConnect connects the handler to the signal. - SignalConnected(signal string, handler func(interface{})) bool // SignalConnected returns true if the handler is connected to the signal. - SignalConnections(signal string) []func(interface{}) // SignalConnections returns a slice of handlers connected to the signal. - SignalDisconnect(signal string, handler func(interface{})) // SignalDisconnect removes the handler from the signal. - SignalEmit(signal string, data interface{}) // SignalEmit emits the signal with the data. + SignalConnect(signal string, handler func(...interface{}), bindings ...interface{}) error // SignalConnect connects the handler to the signal. + SignalConnected(signal string, handler func(...interface{})) bool // SignalConnected returns true if the handler is connected to the signal. + SignalConnections(signal string) []SignalHandler // SignalConnections returns a slice of handlers connected to the signal. + SignalDisconnect(signal string, handler func(...interface{})) // SignalDisconnect removes the handler from the signal. + SignalEmit(signal string, data ...interface{}) // SignalEmit emits the signal with the data. +} + +type SignalHandler struct { + Callback func(...interface{}) + Bindings []interface{} } type SignalManager struct { - signalConnections map[string][]func(interface{}) + signalConnections map[string][]SignalHandler } -func (s *SignalManager) SignalConnect(signal string, handler func(interface{})) error { +func (s *SignalManager) SignalConnect(signal string, callback func(...interface{}), bindings ...interface{}) error { if s.signalConnections == nil { - s.signalConnections = make(map[string][]func(interface{})) + s.signalConnections = make(map[string][]SignalHandler) } - s.signalConnections[signal] = append(s.signalConnections[signal], handler) + s.signalConnections[signal] = append(s.signalConnections[signal], SignalHandler{callback, bindings}) return nil } -func (s *SignalManager) SignalConnected(signal string, handler func(interface{})) bool { +func (s *SignalManager) SignalConnected(signal string, callback func(...interface{})) bool { if s.signalConnections == nil { return false } for _, h := range s.signalConnections[signal] { - if reflect.ValueOf(h).Pointer() == reflect.ValueOf(handler).Pointer() { + if reflect.ValueOf(h.Callback).Pointer() == reflect.ValueOf(callback).Pointer() { return true } } return false } -func (s *SignalManager) SignalConnections(signal string) []func(interface{}) { +func (s *SignalManager) SignalConnections(signal string) []SignalHandler { if s.signalConnections == nil { return nil } return s.signalConnections[signal] } -func (s *SignalManager) SignalDisconnect(signal string, handler func(interface{})) { +func (s *SignalManager) SignalDisconnect(signal string, callback func(...interface{})) { if s.signalConnections == nil { return } for i, h := range s.signalConnections[signal] { - if reflect.ValueOf(h).Pointer() == reflect.ValueOf(handler).Pointer() { + if reflect.ValueOf(h.Callback).Pointer() == reflect.ValueOf(callback).Pointer() { s.signalConnections[signal] = append(s.signalConnections[signal][:i], s.signalConnections[signal][i+1:]...) } } } -func (s *SignalManager) SignalEmit(signal string, data interface{}) { +func (s *SignalManager) SignalEmit(signal string, data ...interface{}) { if s.signalConnections == nil { return } for _, handler := range s.signalConnections[signal] { - handler(data) + args := make([]interface{}, len(data)+len(handler.Bindings)) + copy(args, data) + copy(args[len(data):], handler.Bindings) + handler.Callback(args...) } }