mirror of
https://github.com/lukewilson2002/autotrader.git
synced 2025-06-14 16:03:51 +00:00
Use type safety for RSI indicator
This commit is contained in:
parent
1516604889
commit
ed08bd275d
@ -7,28 +7,28 @@ import "math"
|
||||
// Traditionally, an RSI reading of 70 or above indicates an overbought condition, and a reading of 30 or below indicates an oversold condition.
|
||||
//
|
||||
// Typically, the RSI is calculated with a period of 14 days.
|
||||
func RSI(series *Series, periods int) *Series {
|
||||
func RSI(series *FloatSeries, periods int) *FloatSeries {
|
||||
// Calculate the difference between each day's close and the previous day's close.
|
||||
delta := series.Copy().Map(func(i int, v interface{}) interface{} {
|
||||
delta := series.Copy().MapReverse(func(i int, v float64) float64 {
|
||||
if i == 0 {
|
||||
return float64(0)
|
||||
return 0
|
||||
}
|
||||
return v.(float64) - series.Value(i-1).(float64)
|
||||
return v - series.Value(i-1)
|
||||
})
|
||||
// Calculate the average gain and average loss.
|
||||
avgGain := delta.Copy().
|
||||
Map(func(i int, val interface{}) interface{} { return math.Max(val.(float64), 0) }).
|
||||
Rolling(periods).Average()
|
||||
avgLoss := delta.Copy().
|
||||
Map(func(i int, val interface{}) interface{} { return math.Abs(math.Min(val.(float64), 0)) }).
|
||||
Rolling(periods).Average()
|
||||
avgGain := &FloatSeries{delta.Copy().
|
||||
Map(func(i int, val float64) float64 { return math.Max(val, 0) }).
|
||||
Rolling(periods).Average()}
|
||||
avgLoss := &FloatSeries{delta.Copy().
|
||||
Map(func(i int, val float64) float64 { return math.Abs(math.Min(val, 0)) }).
|
||||
Rolling(periods).Average()}
|
||||
// Calculate the RSI.
|
||||
return avgGain.Map(func(i int, val interface{}) interface{} {
|
||||
return avgGain.Map(func(i int, val float64) float64 {
|
||||
loss := avgLoss.Float(i)
|
||||
if loss == 0 {
|
||||
return float64(100)
|
||||
}
|
||||
return float64(100. - 100./(1.+val.(float64)/loss))
|
||||
return float64(100. - 100./(1.+val/loss))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -5,16 +5,16 @@ import (
|
||||
)
|
||||
|
||||
func TestRSI(t *testing.T) {
|
||||
prices := NewSeries("Prices", 1., 0., 2., 1., 3., 2., 4., 3., 5., 4., 6., 5., 7., 6.)
|
||||
prices := NewFloatSeries("Prices", 1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6)
|
||||
rsi := RSI(prices, 14)
|
||||
if rsi.Len() != 14 {
|
||||
t.Errorf("RSI length is %d, expected 14", rsi.Len())
|
||||
}
|
||||
if !EqualApprox(rsi.Float(0), 100) {
|
||||
t.Errorf("RSI[0] is %f, expected 0", rsi.Float(0))
|
||||
if !EqualApprox(rsi.Value(0), 100) {
|
||||
t.Errorf("RSI[0] is %f, expected 100", rsi.Value(0))
|
||||
}
|
||||
// TODO: check the expected RSI
|
||||
// if !EqualApprox(rsi.Float(-1), 61.02423) {
|
||||
// t.Errorf("RSI[-1] is %f, expected 100", rsi.Float(-1))
|
||||
// }
|
||||
if !EqualApprox(rsi.Value(-1), 63.157895) {
|
||||
t.Errorf("RSI[-1] is %f, expected 63.157895", rsi.Value(-1))
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user