Files
OpenRA/OpenRA.Game/Exts.cs
RoosterDragon b4882a8b03 Avoid some allocations in MiniYaml.Merge.
During the merge operation, it is quite common to be dealing with a node that has no child nodes. When there are no such nodes, we can return early from some functions to avoid allocating new collections that will not be used.

In the MergePartial operation, reuse a dictionary as scratch space when checking for conflicts. We introduce a IntoDictionaryWithConflictLog helper to allow this. This avoids allocating a new dictionary for the conflict log that gets thrown away at each check.
2024-09-23 16:18:00 +02:00

704 lines
18 KiB
C#

#region Copyright & License Information
/*
* Copyright (c) The OpenRA Developers and Contributors
* This file is part of OpenRA, which is free software. It is made
* available to you under the terms of the GNU General Public License
* as published by the Free Software Foundation, either version 3 of
* the License, or (at your option) any later version. For more
* information, see COPYING.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Reflection;
using System.Text;
using OpenRA.Primitives;
using OpenRA.Support;
using OpenRA.Traits;
namespace OpenRA
{
public static class Exts
{
/// <summary>Returns <see cref="Color"/> of the <paramref name="actor"/>, taking <see cref="Actor.EffectiveOwner"/> into account.</summary>
public static Color OwnerColor(this Actor actor)
{
var effectiveOwner = actor.EffectiveOwner;
if (effectiveOwner != null && effectiveOwner.Disguised && actor.World.RenderPlayer != null)
return effectiveOwner.Owner.Color;
return actor.Owner.Color;
}
public static string FormatInvariant(this string format, params object[] args)
{
return string.Format(CultureInfo.InvariantCulture, format, args);
}
public static string FormatCurrent(this string format, params object[] args)
{
return string.Format(CultureInfo.CurrentCulture, format, args);
}
public static Lazy<T> Lazy<T>(Func<T> p) { return new Lazy<T>(p); }
public static IEnumerable<string> GetNamespaces(this Assembly a)
{
return a.GetTypes().Select(t => t.Namespace).Distinct().Where(n => n != null);
}
public static bool HasAttribute<TAttribute>(this MemberInfo mi)
where TAttribute : Attribute
{
return Attribute.IsDefined(mi, typeof(TAttribute));
}
public static TAttribute[] GetCustomAttributes<TAttribute>(this MemberInfo mi, bool inherit)
where TAttribute : Attribute
{
return (TAttribute[])mi.GetCustomAttributes(typeof(TAttribute), inherit);
}
public static T Clamp<T>(this T val, T min, T max) where T : IComparable<T>
{
if (val.CompareTo(min) < 0)
return min;
else if (val.CompareTo(max) > 0)
return max;
else
return val;
}
static int WindingDirectionTest(int2 v0, int2 v1, int2 p)
{
return Math.Sign((v1.X - v0.X) * (p.Y - v0.Y) - (p.X - v0.X) * (v1.Y - v0.Y));
}
public static bool PolygonContains(this int2[] polygon, int2 p)
{
var windingNumber = 0;
for (var i = 0; i < polygon.Length; i++)
{
var tv = polygon[i];
var nv = polygon[(i + 1) % polygon.Length];
if (tv.Y <= p.Y && nv.Y > p.Y && WindingDirectionTest(tv, nv, p) > 0)
windingNumber++;
else if (tv.Y > p.Y && nv.Y <= p.Y && WindingDirectionTest(tv, nv, p) < 0)
windingNumber--;
}
return windingNumber != 0;
}
public static bool LinesIntersect(int2 a, int2 b, int2 c, int2 d)
{
// If line segments AB and CD intersect:
// - the triangles ACD and BCD must have opposite sense (clockwise or anticlockwise)
// - the triangles CAB and DAB must have opposite sense
// Segments intersect if the orientation (clockwise or anticlockwise) of the two points in each line segment are opposite with respect to the other
// Assumes that lines are not collinear
return WindingDirectionTest(c, d, a) != WindingDirectionTest(c, d, b) && WindingDirectionTest(a, b, c) != WindingDirectionTest(a, b, d);
}
public static bool HasModifier(this Modifiers k, Modifiers mod)
{
// PERF: Enum.HasFlag is slower and requires allocations.
return (k & mod) == mod;
}
public static V GetOrAdd<K, V>(this Dictionary<K, V> d, K k)
where V : new()
{
return d.GetOrAdd(k, new V());
}
public static V GetOrAdd<K, V>(this Dictionary<K, V> d, K k, V v)
{
#if NET5_0_OR_GREATER
// SAFETY: Dictionary cannot be modified whilst the ref is alive.
ref var value = ref System.Runtime.InteropServices.CollectionsMarshal.GetValueRefOrAddDefault(d, k, out var exists);
if (!exists)
value = v;
return value;
#else
if (!d.TryGetValue(k, out var ret))
d.Add(k, ret = v);
return ret;
#endif
}
public static V GetOrAdd<K, V>(this Dictionary<K, V> d, K k, Func<K, V> createFn)
{
// Cannot use CollectionsMarshal.GetValueRefOrAddDefault here,
// the creation function could mutate the dictionary which would invalidate the ref.
if (!d.TryGetValue(k, out var ret))
d.Add(k, ret = createFn(k));
return ret;
}
public static T GetOrAdd<T>(this HashSet<T> set, T value)
{
if (!set.TryGetValue(value, out var ret))
set.Add(ret = value);
return ret;
}
public static T GetOrAdd<T>(this HashSet<T> set, T value, Func<T, T> createFn)
{
if (!set.TryGetValue(value, out var ret))
set.Add(ret = createFn(value));
return ret;
}
public static int IndexOf<T>(this T[] array, T value)
{
return Array.IndexOf(array, value);
}
public static T FirstOrDefault<T>(this T[] array, Predicate<T> match)
{
return Array.Find(array, match);
}
public static T FirstOrDefault<T>(this List<T> list, Predicate<T> match)
{
return list.Find(match);
}
public static T Random<T>(this IEnumerable<T> ts, MersenneTwister r)
{
return Random(ts, r, true);
}
public static T RandomOrDefault<T>(this IEnumerable<T> ts, MersenneTwister r)
{
return Random(ts, r, false);
}
static T Random<T>(IEnumerable<T> ts, MersenneTwister r, bool throws)
{
var xs = ts as IReadOnlyCollection<T>;
xs ??= ts.ToList();
if (xs.Count == 0)
{
if (throws)
throw new ArgumentException("Collection must not be empty.", nameof(ts));
else
return default;
}
else
return xs.ElementAt(r.Next(xs.Count));
}
public static Rectangle Union(this IEnumerable<Rectangle> rects)
{
// PERF: Avoid LINQ.
var first = true;
var result = Rectangle.Empty;
foreach (var rect in rects)
{
if (first)
{
first = false;
result = rect;
continue;
}
result = Rectangle.Union(rect, result);
}
return result;
}
public static float Product(this IEnumerable<float> xs)
{
return xs.Aggregate(1f, (a, x) => a * x);
}
public static IEnumerable<T> SymmetricDifference<T>(this IEnumerable<T> xs, IEnumerable<T> ys)
{
// this is probably a shockingly-slow way to do this, but it's concise.
return xs.Except(ys).Concat(ys.Except(xs));
}
public static IEnumerable<T> Iterate<T>(this T t, Func<T, T> f)
{
while (true)
{
yield return t;
t = f(t);
}
}
public static T MinBy<T, U>(this IEnumerable<T> ts, Func<T, U> selector)
{
return ts.CompareBy(selector, 1, true);
}
public static T MaxBy<T, U>(this IEnumerable<T> ts, Func<T, U> selector)
{
return ts.CompareBy(selector, -1, true);
}
public static T MinByOrDefault<T, U>(this IEnumerable<T> ts, Func<T, U> selector)
{
return ts.CompareBy(selector, 1, false);
}
public static T MaxByOrDefault<T, U>(this IEnumerable<T> ts, Func<T, U> selector)
{
return ts.CompareBy(selector, -1, false);
}
static T CompareBy<T, U>(this IEnumerable<T> ts, Func<T, U> selector, int modifier, bool throws)
{
var comparer = Comparer<U>.Default;
T t;
U u;
using (var e = ts.GetEnumerator())
{
if (!e.MoveNext())
if (throws)
throw new ArgumentException("Collection must not be empty.", nameof(ts));
else
return default;
t = e.Current;
u = selector(t);
while (e.MoveNext())
{
var nextT = e.Current;
var nextU = selector(nextT);
if (comparer.Compare(nextU, u) * modifier < 0)
{
t = nextT;
u = nextU;
}
}
return t;
}
}
public static int NextPowerOf2(int v)
{
--v;
v |= v >> 1;
v |= v >> 2;
v |= v >> 4;
v |= v >> 8;
++v;
return v;
}
public static bool IsPowerOf2(int v)
{
return (v & (v - 1)) == 0;
}
public static Size NextPowerOf2(this Size s) { return new Size(NextPowerOf2(s.Width), NextPowerOf2(s.Height)); }
public enum ISqrtRoundMode { Floor, Nearest, Ceiling }
public static int ISqrt(int number, ISqrtRoundMode round = ISqrtRoundMode.Floor)
{
if (number < 0)
throw new InvalidOperationException($"Attempted to calculate the square root of a negative integer: {number}");
return (int)ISqrt((uint)number, round);
}
public static uint ISqrt(uint number, ISqrtRoundMode round = ISqrtRoundMode.Floor)
{
var divisor = 1U << 30;
var root = 0U;
var remainder = number;
// Find the highest term in the divisor
while (divisor > number)
divisor >>= 2;
// Evaluate the root, two bits at a time
while (divisor != 0)
{
if (root + divisor <= remainder)
{
remainder -= root + divisor;
root += 2 * divisor;
}
root >>= 1;
divisor >>= 2;
}
// Adjust for other rounding modes
if (round == ISqrtRoundMode.Nearest && remainder > root)
root++;
else if (round == ISqrtRoundMode.Ceiling && root * root < number)
root++;
return root;
}
public static long ISqrt(long number, ISqrtRoundMode round = ISqrtRoundMode.Floor)
{
if (number < 0)
throw new InvalidOperationException($"Attempted to calculate the square root of a negative integer: {number}");
return (long)ISqrt((ulong)number, round);
}
public static ulong ISqrt(ulong number, ISqrtRoundMode round = ISqrtRoundMode.Floor)
{
var divisor = 1UL << 62;
var root = 0UL;
var remainder = number;
// Find the highest term in the divisor
while (divisor > number)
divisor >>= 2;
// Evaluate the root, two bits at a time
while (divisor != 0)
{
if (root + divisor <= remainder)
{
remainder -= root + divisor;
root += 2 * divisor;
}
root >>= 1;
divisor >>= 2;
}
// Adjust for other rounding modes
if (round == ISqrtRoundMode.Nearest && remainder > root)
root++;
else if (round == ISqrtRoundMode.Ceiling && root * root < number)
root++;
return root;
}
public static int MultiplyBySqrtTwo(short number)
{
return number * 46341 / 32768;
}
public static int MultiplyBySqrtTwoOverTwo(int number)
{
return (int)(number * 23170L / 32768L);
}
public static int IntegerDivisionRoundingAwayFromZero(int dividend, int divisor)
{
var quotient = Math.DivRem(dividend, divisor, out var remainder);
if (remainder == 0)
return quotient;
return quotient + (Math.Sign(dividend) == Math.Sign(divisor) ? 1 : -1);
}
public static string JoinWith<T>(this IEnumerable<T> ts, string j)
{
return string.Join(j, ts);
}
public static IEnumerable<T> Append<T>(this IEnumerable<T> ts, params T[] moreTs)
{
return ts.Concat(moreTs);
}
public static IEnumerable<T> Exclude<T>(this IEnumerable<T> ts, params T[] exclusions)
{
return ts.Except(exclusions);
}
public static HashSet<T> ToHashSet<T>(this IEnumerable<T> source)
{
return new HashSet<T>(source);
}
public static Dictionary<TKey, TSource> ToDictionaryWithConflictLog<TSource, TKey>(
this IEnumerable<TSource> source, Func<TSource, TKey> keySelector,
string debugName, Func<TKey, string> logKey, Func<TSource, string> logValue)
{
return ToDictionaryWithConflictLog(source, keySelector, x => x, debugName, logKey, logValue);
}
public static Dictionary<TKey, TElement> ToDictionaryWithConflictLog<TSource, TKey, TElement>(
this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector,
string debugName, Func<TKey, string> logKey = null, Func<TElement, string> logValue = null)
{
var output = new Dictionary<TKey, TElement>();
IntoDictionaryWithConflictLog(source, keySelector, elementSelector, debugName, output, logKey, logValue);
return output;
}
public static void IntoDictionaryWithConflictLog<TSource, TKey, TElement>(
this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector,
string debugName, Dictionary<TKey, TElement> output,
Func<TKey, string> logKey = null, Func<TElement, string> logValue = null)
{
// Fall back on ToString() if null functions are provided:
logKey ??= s => s.ToString();
logValue ??= s => s.ToString();
// Try to build a dictionary and log all duplicates found (if any):
Dictionary<TKey, List<string>> dupKeys = null;
var capacity = source is ICollection<TSource> collection ? collection.Count : 0;
output.Clear();
output.EnsureCapacity(capacity);
foreach (var item in source)
{
var key = keySelector(item);
var element = elementSelector(item);
// Discard elements with null keys
if (!typeof(TKey).IsValueType && key == null)
continue;
// Check for a key conflict:
if (!output.TryAdd(key, element))
{
dupKeys ??= new Dictionary<TKey, List<string>>();
if (!dupKeys.TryGetValue(key, out var dupKeyMessages))
{
// Log the initial conflicting value already inserted:
dupKeyMessages = new List<string>
{
logValue(output[key])
};
dupKeys.Add(key, dupKeyMessages);
}
// Log this conflicting value:
dupKeyMessages.Add(logValue(element));
}
}
// If any duplicates were found, throw a descriptive error
if (dupKeys != null)
{
var badKeysFormatted = new StringBuilder(
$"{debugName}, duplicate values found for the following keys: ");
foreach (var p in dupKeys)
badKeysFormatted.Append($"{logKey(p.Key)}: [{string.Join(",", p.Value)}]");
throw new ArgumentException(badKeysFormatted.ToString());
}
}
public static Color ColorLerp(float t, Color c1, Color c2)
{
return Color.FromArgb(
(int)(t * c2.A + (1 - t) * c1.A),
(int)(t * c2.R + (1 - t) * c1.R),
(int)(t * c2.G + (1 - t) * c1.G),
(int)(t * c2.B + (1 - t) * c1.B));
}
public static T[] MakeArray<T>(int count, Func<int, T> f)
{
var result = new T[count];
for (var i = 0; i < count; i++)
result[i] = f(i);
return result;
}
public static T[,] ResizeArray<T>(T[,] ts, T t, int width, int height)
{
var result = new T[width, height];
for (var i = 0; i < width; i++)
{
for (var j = 0; j < height; j++)
{
// Workaround for broken ternary operators in certain versions of mono
// (3.10 and certain versions of the 3.8 series): https://bugzilla.xamarin.com/show_bug.cgi?id=23319
if (i <= ts.GetUpperBound(0) && j <= ts.GetUpperBound(1))
result[i, j] = ts[i, j];
else
result[i, j] = t;
}
}
return result;
}
public static int ToBits(this IEnumerable<bool> bits)
{
var i = 0;
var result = 0;
foreach (var b in bits)
if (b)
result |= 1 << i++;
else
i++;
if (i > 33)
throw new InvalidOperationException("ToBits only accepts up to 32 values.");
return result;
}
public static byte ParseByteInvariant(string s)
{
return byte.Parse(s, NumberStyles.Integer, NumberFormatInfo.InvariantInfo);
}
public static short ParseInt16Invariant(string s)
{
return short.Parse(s, NumberStyles.Integer, NumberFormatInfo.InvariantInfo);
}
public static int ParseInt32Invariant(string s)
{
return int.Parse(s, NumberStyles.Integer, NumberFormatInfo.InvariantInfo);
}
public static bool TryParseInt32Invariant(string s, out int i)
{
return int.TryParse(s, NumberStyles.Integer, NumberFormatInfo.InvariantInfo, out i);
}
public static bool TryParseInt64Invariant(string s, out long i)
{
return long.TryParse(s, NumberStyles.Integer, NumberFormatInfo.InvariantInfo, out i);
}
public static string ToStringInvariant(this byte i)
{
return i.ToString(NumberFormatInfo.InvariantInfo);
}
public static string ToStringInvariant(this byte i, string format)
{
return i.ToString(format, NumberFormatInfo.InvariantInfo);
}
public static string ToStringInvariant(this int i)
{
return i.ToString(NumberFormatInfo.InvariantInfo);
}
public static string ToStringInvariant(this int i, string format)
{
return i.ToString(format, NumberFormatInfo.InvariantInfo);
}
public static bool IsTraitEnabled<T>(this T trait)
{
return trait is not IDisabledTrait disabledTrait || !disabledTrait.IsTraitDisabled;
}
public static T FirstEnabledTraitOrDefault<T>(this IEnumerable<T> ts)
{
// PERF: Avoid LINQ.
foreach (var t in ts)
if (t.IsTraitEnabled())
return t;
return default;
}
public static T FirstEnabledTraitOrDefault<T>(this T[] ts)
{
// PERF: Avoid LINQ.
foreach (var t in ts)
if (t.IsTraitEnabled())
return t;
return default;
}
public static T FirstEnabledConditionalTraitOrDefault<T>(this IEnumerable<T> ts) where T : IDisabledTrait
{
// PERF: Avoid LINQ.
foreach (var t in ts)
if (!t.IsTraitDisabled)
return t;
return default;
}
public static T FirstEnabledConditionalTraitOrDefault<T>(this T[] ts) where T : IDisabledTrait
{
// PERF: Avoid LINQ.
foreach (var t in ts)
if (!t.IsTraitDisabled)
return t;
return default;
}
public static LineSplitEnumerator SplitLines(this string str, char separator)
{
return new LineSplitEnumerator(str.AsSpan(), separator);
}
}
public ref struct LineSplitEnumerator
{
ReadOnlySpan<char> str;
readonly char separator;
public LineSplitEnumerator(ReadOnlySpan<char> str, char separator)
{
this.str = str;
this.separator = separator;
Current = default;
}
public readonly LineSplitEnumerator GetEnumerator() => this;
public bool MoveNext()
{
var span = str;
// Reach the end of the string
if (span.Length == 0)
return false;
var index = span.IndexOf(separator);
if (index == -1)
{
// The remaining string is an empty string
str = ReadOnlySpan<char>.Empty;
Current = span;
return true;
}
Current = span[..index];
str = span[(index + 1)..];
return true;
}
public ReadOnlySpan<char> Current { get; private set; }
}
public static class Enum<T>
{
public static T Parse(string s) { return (T)Enum.Parse(typeof(T), s); }
public static T[] GetValues() { return (T[])Enum.GetValues(typeof(T)); }
public static bool TryParse(string s, bool ignoreCase, out T value)
{
// The string may be a comma delimited list of values
var names = ignoreCase ? Enum.GetNames(typeof(T)).Select(x => x.ToLowerInvariant()) : Enum.GetNames(typeof(T));
var values = ignoreCase ? s.Split(',').Select(x => x.Trim().ToLowerInvariant()) : s.Split(',').Select(x => x.Trim());
if (values.Any(x => !names.Contains(x)))
{
value = default;
return false;
}
value = (T)Enum.Parse(typeof(T), s, ignoreCase);
return true;
}
}
}