Newer
Older
CGTrack / Assets / Scripts / Bezier.cs
using System.Collections.Generic;
using System.Linq;
using UnityEngine;

/// <summary>
/// Calculate points on a bezier curve.
/// </summary>
/// <see href="https://pomax.github.io/bezierinfo"/>
public class Bezier
{
        /// <summary>Control points of the curve, in order.</summary>
        private Vector3[] w;

        /// <summary>Derivative curve, computed on first use by <see cref="GetDerivate"/> and cached.</summary>
        /// <remarks>The cache is never invalidated, so <see cref="w"/> must not change after it is built.</remarks>
        private Bezier _derivate;

        /// <summary>
        /// Create a curve from the given control points.
        /// </summary>
        /// <param name="w">Control points. The array is stored by reference, not copied.</param>
        public Bezier(Vector3[] w)
        {
                this.w = w;
        }

        /// <summary>
        /// Create a curve that starts at <see cref="Vector3.zero"/>, passes the control points
        /// <paramref name="w"/>, and ends at <paramref name="end"/>.
        /// </summary>
        /// <param name="w">Intermediate control points (copied).</param>
        /// <param name="end">Last control point of the curve.</param>
        public Bezier(Vector3[] w, Vector3 end)
        {
                var _w = w.ToList();
                _w.Insert(0, Vector3.zero);
                _w.Add(end);
                this.w = _w.ToArray();
        }

        /// <summary>
        /// Get the derivative (hodograph) of this curve: a curve of one degree lower whose
        /// control points are <c>n * (w[i + 1] - w[i])</c>, where <c>n</c> is this curve's degree.
        /// </summary>
        /// <returns>
        /// The cached derivative curve. For a curve with at most one control point (a constant),
        /// this is the constant zero curve.
        /// </returns>
        public Bezier GetDerivate()
        {
                if (_derivate != null)
                        return _derivate;

                var degree = w.Length - 1;
                if (degree < 1)
                {
                        // A constant curve has a zero derivative; returning the curve itself would
                        // report the point's position as its velocity.
                        _derivate = new Bezier(new[] { Vector3.zero });
                        return _derivate;
                }

                var hodograph = new Vector3[degree];
                for (var i = 0; i < degree; i++)
                {
                        hodograph[i] = degree * (w[i + 1] - w[i]);
                }

                _derivate = new Bezier(hodograph);
                return _derivate;
        }

        /// <summary>
        /// Compute an orthonormal frame along the curve at <paramref name="t"/>, rolled around the
        /// forward axis by <paramref name="rotate"/> degrees.
        /// </summary>
        /// <remarks>
        /// Uses the construction from the Pomax primer. With <c>a</c> the normalized first derivative,
        /// <c>b = normalize(a + B''(t))</c> and <c>up = normalize(b × a)</c>.
        /// Curves with at most two control points have no curvature, so a fixed default frame is
        /// used instead. NOTE(review): that default ignores the line's actual direction.
        /// </remarks>
        /// <param name="t">Curve parameter.</param>
        /// <param name="rotate">Roll angle in degrees applied around <paramref name="forward"/>.</param>
        /// <param name="up">Up vector of the frame.</param>
        /// <param name="right">Right vector of the frame.</param>
        /// <param name="forward">Forward vector of the frame.</param>
        public void CalculateFrenet(float t, float rotate, out Vector3 up, out Vector3 right, out Vector3 forward)
        {
                if (w.Length <= 2)
                {
                        forward = Vector3.back;
                        right = Vector3.left;
                        up = Vector3.up;
                }
                else
                {
                        var velocity = GetDerivate();
                        var tangent = velocity.DeCasteljau(t).normalized;
                        var b = (tangent + velocity.GetDerivate().DeCasteljau(t)).normalized;
                        up = Vector3.Cross(b, tangent).normalized;
                        right = Vector3.Cross(up, tangent).normalized;
                        forward = Vector3.Cross(right, up).normalized;
                }

                var roll = Quaternion.AngleAxis(rotate, forward);
                up = roll * up;
                right = roll * right;
        }

        /// <summary>
        /// Evaluate this curve at <paramref name="t"/> using De Casteljau's algorithm.
        /// </summary>
        /// <param name="t">Curve parameter, normally in [0, 1].</param>
        /// <returns>The point on the curve.</returns>
        public Vector3 DeCasteljau(float t)
        {
                return DeCasteljau(w, t);
        }

        /// <summary>
        /// Evaluate the curve built by <see cref="Bezier(Vector3[], Vector3)"/> at <paramref name="t"/>.
        /// </summary>
        /// <param name="w">Intermediate control points.</param>
        /// <param name="end">Last control point; the first is <see cref="Vector3.zero"/>.</param>
        /// <param name="t">Curve parameter, normally in [0, 1].</param>
        /// <returns>The point on the curve.</returns>
        public static Vector3 DeCasteljau(Vector3[] w, Vector3 end, float t)
        {
                var curve = new Bezier(w, end);

                return curve.DeCasteljau(t);
        }

        /// <summary>
        /// Evaluate the Bézier curve with control points <paramref name="w"/> at <paramref name="t"/>
        /// using De Casteljau's algorithm.
        /// </summary>
        /// <param name="w">Control points; not modified.</param>
        /// <param name="t">
        /// Curve parameter, normally in [0, 1]. It is not clamped: values outside that range
        /// extrapolate the curve polynomial, consistent with <see cref="bezier"/>.
        /// </param>
        /// <returns>The point on the curve, or <see cref="Vector3.zero"/> for an empty array.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="w"/> is null.</exception>
        public static Vector3 DeCasteljau(Vector3[] w, float t)
        {
                if (w == null)
                        throw new System.ArgumentNullException(nameof(w));

                switch (w.Length)
                {
                        case 0:
                                return Vector3.zero;
                        case 1:
                                return w[0];
                }

                // Work on a copy so the caller's points stay intact. Each pass replaces points[i] with
                // the interpolation between points[i] and points[i + 1]; points[i + 1] is still from the
                // previous pass when it is read, so the update can happen in place.
                var points = (Vector3[]) w.Clone();
                for (var count = points.Length - 1; count > 0; count--)
                {
                        for (var i = 0; i < count; i++)
                        {
                                points[i] = (1f - t) * points[i] + t * points[i + 1];
                        }
                }

                return points[0];
        }

        /// <summary>
        /// Approximate the arc length by sampling the curve with an increasing number of segments
        /// until two consecutive estimates differ by at most <paramref name="error"/>.
        /// </summary>
        /// <remarks>
        /// The first estimate is the length of the control polygon (<c>SampleCurve(0)</c>).
        /// NOTE(review): with <paramref name="error"/> equal to 0, or extremely small, float rounding
        /// may keep consecutive estimates from converging and the loop may not terminate.
        /// </remarks>
        /// <param name="error">Convergence tolerance; must be non-negative.</param>
        /// <returns>The last length estimate.</returns>
        /// <exception cref="System.ArgumentOutOfRangeException">
        /// <paramref name="error"/> is negative or NaN. A negative tolerance can never be met.
        /// </exception>
        public float ApproximateLength(float error)
        {
                if (!(error >= 0f))
                        throw new System.ArgumentOutOfRangeException(nameof(error), error, "error must be a non-negative number.");

                var diff = float.MaxValue;
                var i = 0;
                var length = 0f;
                while (diff > error)
                {
                        var nextLength = CalculateLength(SampleCurve(i++));
                        diff = Mathf.Abs(length - nextLength);
                        length = nextLength;
                }

                return length;
        }

        /// <summary>
        /// Length of the polyline through <paramref name="w"/>.
        /// </summary>
        /// <param name="w">Polyline vertices.</param>
        /// <returns>Sum of the distances between consecutive points; 0 for fewer than two points.</returns>
        public static float CalculateLength(Vector3[] w)
        {
                var sum = 0f;
                if (w.Length < 2)
                        return sum;

                for (var i = 1; i < w.Length; i++)
                {
                        sum += Vector3.Distance(w[i - 1], w[i]);
                }

                return sum;
        }

        /// <summary>
        /// Sample the curve built by <see cref="Bezier(Vector3[], Vector3)"/>.
        /// </summary>
        /// <param name="w">Intermediate control points.</param>
        /// <param name="end">Last control point; the first is <see cref="Vector3.zero"/>.</param>
        /// <param name="sampleCount">Number of segments.</param>
        /// <returns>See <see cref="SampleCurve(int)"/>.</returns>
        public static Vector3[] SampleCurve(Vector3[] w, Vector3 end, int sampleCount)
        {
                var curve = new Bezier(w, end);

                return curve.SampleCurve(sampleCount);
        }

        /// <summary>
        /// Sample the curve at <c>sampleCount + 1</c> evenly spaced parameters from 0 to 1 inclusive.
        /// </summary>
        /// <param name="sampleCount">Number of segments between samples.</param>
        /// <returns>
        /// The sampled points. When <paramref name="sampleCount"/> is less than 1, a copy of the
        /// control points is returned instead. It is a copy so callers cannot mutate this curve.
        /// </returns>
        public Vector3[] SampleCurve(int sampleCount)
        {
                if (sampleCount < 1)
                        return (Vector3[]) w.Clone();

                var coords = new Vector3[sampleCount + 1];
                for (var i = 0; i <= sampleCount; i++)
                {
                        coords[i] = DeCasteljau(w, i / (float) sampleCount);
                }
                return coords;
        }

        /// <summary>
        /// Evaluate the Bézier curve with control points <paramref name="w"/> at <paramref name="t"/>
        /// using the Bernstein polynomial form.
        /// </summary>
        /// <param name="t">Curve parameter, normally in [0, 1]; not clamped.</param>
        /// <param name="w">Control points.</param>
        /// <returns>The point on the curve, or <see cref="Vector3.zero"/> for an empty array.</returns>
        public static Vector3 bezier(float t, Vector3[] w)
        {
                switch (w.Length)
                {
                        case 3:
                                return bezier2(t, w);
                        case 4:
                                return bezier3(t, w);
                }

                // General degree-n case: sum over k = 0..n (inclusive) of C(n,k) (1-t)^(n-k) t^k w[k].
                var n = w.Length - 1;
                var sum = Vector3.zero;
                for (var k = 0; k <= n; k++)
                {
                        sum += w[k] * nchoosek(n, k) * Mathf.Pow(1 - t, n - k) * Mathf.Pow(t, k);
                }

                return sum;
        }

        /// <summary>Closed-form quadratic Bézier; expects exactly three control points.</summary>
        private static Vector3 bezier2(float t, Vector3[] w)
        {
                var t2 = t * t;
                var mt = 1 - t;
                var mt2 = mt * mt;

                return w[0] * mt2 + w[1] * 2 * mt * t + w[2] * t2;
        }

        /// <summary>Closed-form cubic Bézier; expects exactly four control points.</summary>
        private static Vector3 bezier3(float t, Vector3[] w)
        {
                var t2 = t * t;
                var t3 = t2 * t;
                var mt = 1 - t;
                var mt2 = mt * mt;
                var mt3 = mt2 * mt;

                return w[0] * mt3 + 3 * w[1] * mt2 * t + 3 * w[2] * mt * t2 + w[3] * t3;
        }

        /// <summary>
        /// Calculate the binomial coefficient.
        /// </summary>
        /// <seealso href="http://csharphelper.com/blog/2014/08/calculate-the-binomial-coefficient-n-choose-k-efficiently-in-c/"/>
        /// <param name="n">n</param>
        /// <param name="k">k</param>
        /// <returns>C(n, k); 0 when <paramref name="k"/> is outside [0, <paramref name="n"/>].</returns>
        /// <exception cref="System.OverflowException">The result does not fit in an <c>int</c>.</exception>
        public static int nchoosek(int n, int k)
        {
                if (k < 0 || k > n)
                        return 0;

                // C(n, k) == C(n, n - k); iterate over the smaller of the two.
                k = Mathf.Min(k, n - k);

                decimal result = 1;
                for (var i = 1; i <= k; i++)
                {
                        // After this step result == C(n - k + i, i), an integer, so the division is exact.
                        result *= n - (k - i);
                        result /= i;
                }
                return (int) result;
        }
}