diff --git a/fixed.go b/fixed.go index ff26b03..e791a0e 100644 --- a/fixed.go +++ b/fixed.go @@ -244,6 +244,39 @@ func sign(fp int64) int64 { return 1 } +func (f Fixed) Sqrt() Fixed { + if f.IsNaN() || f.LessThan(ZERO) { + return NaN + } + if f.Equal(ZERO) { + return ZERO + } + + // Use Newton's method. Use an epsilon of 0.0000001 to prevent issues with + // rounding. On each iteration, next guess is the average of f/f0 and f0 + // (where f0 is the last guess). + f0 := NewI(1, 0) + eps := NewI(1, 7) + TWO := NewI(2, 0) + + // Calculate the first guess using f/f0/2 + f0/2 to avoid overflowing + f0 = f.Div(f0).Div(TWO).Add(f0.Div(TWO)) + + // Bail out if we get NaN anyways (otherwise it loops forever) + if f0.IsNaN() { + return NaN + } + + // Keep iterating until it converges + for { + f1 := f.Div(f0).Add(f0).Div(TWO) + if f0.Sub(f1).Abs().LessThanOrEqual(eps) { + return f1 + } + f0 = f1 + } +} + // Round returns a rounded (half-up, away from zero) to n decimal places func (f Fixed) Round(n int) Fixed { if f.IsNaN() { diff --git a/fixed_test.go b/fixed_test.go index 4c3df57..7e56427 100644 --- a/fixed_test.go +++ b/fixed_test.go @@ -367,6 +367,92 @@ func TestMulDiv(t *testing.T) { } +func TestSqrt(t *testing.T) { + f := NewS("0") + eps := NewI(1, 7) + if f.Sqrt().Sub(NewS("0")).Abs().GreaterThan(eps) { + t.Error("Sqrt(0) should be 0, got", f.Sqrt()) + } + + f = NewS("1") + if f.Sqrt().Sub(NewS("1")).Abs().GreaterThan(eps) { + t.Error("Sqrt(1) should be 1, got", f.Sqrt()) + } + + f = NewS("4") + if f.Sqrt().Sub(NewS("2")).Abs().GreaterThan(eps) { + t.Error("Sqrt(4) should be 2, got", f.Sqrt()) + } + + f = NewS("9") + if f.Sqrt().Sub(NewS("3")).Abs().GreaterThan(eps) { + t.Error("Sqrt(9) should be 3, got", f.Sqrt()) + } + + f = NewS("16") + if f.Sqrt().Sub(NewS("4")).Abs().GreaterThan(eps) { + t.Error("Sqrt(16) should be 4, got", f.Sqrt()) + } + + f = NewS("2") + if f.Sqrt().Sub(NewS("1.4142136")).Abs().GreaterThan(eps) { + t.Error("Sqrt(2) should be 1.4142136, got", f.Sqrt()) + } + + f = NewS("0.25") + if f.Sqrt().Sub(NewS("0.5")).Abs().GreaterThan(eps) { + t.Error("Sqrt(0.25) should be 0.5, got", f.Sqrt()) + } + + f = NewS("0.0625") + if f.Sqrt().Sub(NewS("0.25")).Abs().GreaterThan(eps) { + t.Error("Sqrt(0.0625) should be 0.25, got", f.Sqrt()) + } + + f = NewS("100") + if f.Sqrt().Sub(NewS("10")).Abs().GreaterThan(eps) { + t.Error("Sqrt(100) should be 10, got", f.Sqrt()) + } + + f = NewS("99999999999.9999900") + if f.Sqrt().Sub(NewS("316227.766017")).Abs().GreaterThan(eps) { + t.Error("Sqrt(99999999999.9999900) should be 316227.766017, got", f.Sqrt()) + } + + f = NewS("0.0000001") + if f.Sqrt().Sub(NewS("0.0003162")).Abs().GreaterThan(eps) { + t.Error("Sqrt(0.0000001) should be 0.0003162, got", f.Sqrt()) + } + + f = NewS("123.456") + if f.Sqrt().Sub(NewS("11.1110756")).Abs().GreaterThan(eps) { + t.Error("Sqrt(123.456) should be 11.1110756, got", f.Sqrt()) + } + + f = NewS("NaN") + if !f.Sqrt().IsNaN() { + t.Error("Sqrt(NaN) should be NaN, got", f.Sqrt()) + } + + f = NewS("-1") + if !f.Sqrt().IsNaN() { + t.Error("Sqrt(-1) should be NaN, got", f.Sqrt()) + } + + f = NewS("-123.456") + if !f.Sqrt().IsNaN() { + t.Error("Sqrt(-123.456) should be NaN, got", f.Sqrt()) + } + + f = NewS("2") + sqrt := f.Sqrt() + squared := sqrt.Mul(sqrt) + diff := squared.Sub(f).Abs() + if diff.GreaterThan(eps) { + t.Error("Sqrt(2) accuracy check failed: sqrt^2 =", squared, "diff from 2 =", diff) + } +} + func TestNegatives(t *testing.T) { f0 := NewS("99") f1 := NewS("100")