diff --git a/OpenRA.FileFormats/Verifier.cs b/OpenRA.FileFormats/Verifier.cs index 532e4d95ed..0300bad1d6 100644 --- a/OpenRA.FileFormats/Verifier.cs +++ b/OpenRA.FileFormats/Verifier.cs @@ -18,13 +18,169 @@ */ #endregion +using System; +using System.Reflection; +using System.Collections.Generic; +using System.Reflection.Emit; +using System.Linq; + namespace OpenRA.FileFormats { public static class Verifier { - public static bool IsSafe(string filename) + static readonly string[] AllowedPatterns = { + /* todo */ + }; + + public static bool IsSafe(string filename, List failures) { - return false; // todo + AppDomain.CurrentDomain.ReflectionOnlyAssemblyResolve += (s, a) => Assembly.ReflectionOnlyLoad(a.Name); + var flags = BindingFlags.Instance | BindingFlags.Static | + BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly; + + var assembly = Assembly.ReflectionOnlyLoadFrom(filename); + + var pinvokes = assembly.GetTypes() + .SelectMany(t => t.GetMethods(flags)) + .Where(m => (m.Attributes & MethodAttributes.PinvokeImpl) != 0) + .Select(m => m.Name).ToArray(); + + foreach (var pi in pinvokes) + failures.Add("P/Invoke: {0}".F(pi)); + + foreach (var fn in assembly + .GetTypes() + .SelectMany(x => x.GetMembers(flags)) + .SelectMany(x => FunctionsUsedBy(x)) + .Where(x => x.DeclaringType.Assembly != assembly) + .Select(x => string.Format("{0}:{1}", x.DeclaringType.FullName, x)) + .OrderBy(x => x) + .Distinct()) + if (!IsAllowed(fn)) + failures.Add("Unsafe function: {0}".F(fn)); + + return failures.Count > 0; + } + + static bool IsAllowed(string fn) + { + foreach (var p in AllowedPatterns) + if (p.EndsWith("*")) + { + if (fn.StartsWith(p.Substring(0, p.Length - 1))) return true; + } + else + { + if (fn == p) return true; + } + + return false; + } + + static IEnumerable FunctionsUsedBy( MemberInfo x ) + { + if( x is MethodInfo ) + { + var method = x as MethodInfo; + if (method.GetMethodBody() != null) + foreach( var fn in CheckIL( method.GetMethodBody().GetILAsByteArray(), x.Module, x.DeclaringType.GetGenericArguments(), method.GetGenericArguments() ) ) + yield return fn; + } + else if( x is ConstructorInfo ) + { + var method = x as ConstructorInfo; + if (method.GetMethodBody() != null) + foreach( var fn in CheckIL( method.GetMethodBody().GetILAsByteArray(), x.Module, x.DeclaringType.GetGenericArguments(), new Type[ 0 ] ) ) + yield return fn; + } + else if( x is FieldInfo ) + { + // ignore it + } + else if( x is PropertyInfo ) + { + var prop = x as PropertyInfo; + foreach( var method in prop.GetAccessors() ) + if (method.GetMethodBody() != null) + foreach( var fn in CheckIL( method.GetMethodBody().GetILAsByteArray(), x.Module, x.DeclaringType.GetGenericArguments(), new Type[ 0 ] ) ) + yield return fn; + } + else if( x is Type ) + { + // ... shouldn't happen, but does.... :O + } + else + throw new NotImplementedException(); + } + + static IEnumerable CheckIL( byte[] p, Module a, Type[] classGenerics, Type[] functionGenerics ) + { + var position = 0; + var ret = new List(); + while( position < p.Length ) + { + var opcode = OpCodeMap.GetOpCode( p, position ); + position += opcode.Size; + if( opcode.OperandType == OperandType.InlineMethod ) + ret.Add( BitConverter.ToInt32( p, position )); + position += OperandSize( opcode, p, position ); + } + return ret.Select( t => a.ResolveMethod( t, classGenerics, functionGenerics ) ); + } + + static int OperandSize( OpCode opcode, byte[] p, int position ) + { + switch( opcode.OperandType ) + { + case OperandType.InlineNone: + return 0; + case OperandType.InlineMethod: + case OperandType.InlineField: + case OperandType.InlineType: + case OperandType.InlineTok: + case OperandType.InlineString: + case OperandType.InlineI: + case OperandType.InlineBrTarget: + return 4; + case OperandType.ShortInlineBrTarget: + case OperandType.ShortInlineI: + case OperandType.ShortInlineVar: + return 1; + case OperandType.InlineSwitch: + var numSwitchArgs = BitConverter.ToUInt32( p, position ); + return (int)( 4 + 4 * numSwitchArgs ); + case OperandType.ShortInlineR: + return 4; + default: + throw new NotImplementedException("Unsupported: {0}".F(opcode.OperandType)); + } + } + } + + static class OpCodeMap + { + static readonly Dictionary simpleOps = new Dictionary(); + static readonly Dictionary feOps = new Dictionary(); + + static OpCodeMap() + { + foreach( var o in typeof( OpCodes ).GetFields( BindingFlags.Static | BindingFlags.Public ).Select( f => (OpCode)f.GetValue( null ) ) ) + { + if( o.Size == 1 ) + simpleOps.Add( (byte)o.Value, o ); + else if( o.Size == 2 ) + feOps.Add( (byte)( o.Value & 0xFF ), o ); + else + throw new NotImplementedException(); + } + } + + public static OpCode GetOpCode( byte[] input, int position ) + { + if( input[ position ] != 0xFE ) + return simpleOps[ input[ position ] ]; + else + return feOps[ input[ position + 1 ] ]; } } } diff --git a/OpenRA.Game/Game.cs b/OpenRA.Game/Game.cs index c62296c236..1a0a288bc3 100644 --- a/OpenRA.Game/Game.cs +++ b/OpenRA.Game/Game.cs @@ -31,6 +31,7 @@ using OpenRA.Network; using OpenRA.Support; using OpenRA.Traits; using Timer = OpenRA.Support.Timer; +using System.Collections.Generic; namespace OpenRA { @@ -75,10 +76,19 @@ namespace OpenRA // Mod assemblies assumed to contain a single namespace foreach (var a in m.Assemblies) - if (Verifier.IsSafe( Path.GetFullPath(a))) + { + var failures = new List(); + if (Verifier.IsSafe(Path.GetFullPath(a), failures)) asms.Add(Pair.New( - Assembly.LoadFile(Path.GetFullPath(a)), + Assembly.LoadFile(Path.GetFullPath(a)), Path.GetFileNameWithoutExtension(a))); + else + { + Log.Write("Assembly `{0}` cannot be verified. Failures:", a); + foreach (var f in failures) + Log.Write("\t{0}", f); + } + } ModAssemblies = asms.ToArray(); }