diff --git a/OpenRA.Game/Support/ConditionExpression.cs b/OpenRA.Game/Support/ConditionExpression.cs index aafb972cce..260bd42d3e 100644 --- a/OpenRA.Game/Support/ConditionExpression.cs +++ b/OpenRA.Game/Support/ConditionExpression.cs @@ -13,6 +13,8 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Linq.Expressions; +using Expressions = System.Linq.Expressions; namespace OpenRA.Support { @@ -22,7 +24,7 @@ namespace OpenRA.Support readonly HashSet variables = new HashSet(); public IEnumerable Variables { get { return variables; } } - readonly Token[] postfix; + readonly Func, int> asFunction; enum CharClass { Whitespace, Operator, Mixed, Id, Digit } @@ -447,30 +449,16 @@ namespace OpenRA.Support if (currentOpeners.Count > 0) throw new InvalidDataException("Unclosed opening parenthesis at index {0}".F(currentOpeners.Peek().Index)); - // Convert to postfix (discarding parentheses) ready for evaluation - postfix = ToPostfix(tokens).ToArray(); + asFunction = new Compiler().Compile(ToPostfix(tokens).ToArray()); } - static int ParseSymbol(VariableToken t, IReadOnlyDictionary symbols) + static int ParseSymbol(string symbol, IReadOnlyDictionary symbols) { int value; - symbols.TryGetValue(t.Symbol, out value); + symbols.TryGetValue(symbol, out value); return value; } - static void ApplyBinaryOperation(Stack s, Func f) - { - var x = s.Pop(); - var y = s.Pop(); - s.Push(f(x, y)); - } - - static void ApplyUnaryOperation(Stack s, Func f) - { - var x = s.Pop(); - s.Push(f(x)); - } - static IEnumerable ToPostfix(IEnumerable tokens) { var s = new Stack(); @@ -500,41 +488,169 @@ namespace OpenRA.Support yield return s.Pop(); } - public int Evaluate(IReadOnlyDictionary symbols) + enum ExpressionType { Int, Bool } + + static readonly ParameterExpression SymbolsParam = + Expressions.Expression.Parameter(typeof(IReadOnlyDictionary), "symbols"); + static readonly ConstantExpression Zero = Expressions.Expression.Constant(0); + static readonly ConstantExpression One = Expressions.Expression.Constant(1); + static readonly ConstantExpression False = Expressions.Expression.Constant(false); + static readonly ConstantExpression True = Expressions.Expression.Constant(true); + + static Expression AsBool(Expression expression) { - var s = new Stack(); - foreach (var t in postfix) + return Expressions.Expression.GreaterThan(expression, Zero); + } + + static Expression AsNegBool(Expression expression) + { + return Expressions.Expression.LessThanOrEqual(expression, Zero); + } + + static Expression IfThenElse(Expression test, Expression ifTrue, Expression ifFalse) + { + return Expressions.Expression.Condition(test, ifTrue, ifFalse); + } + + class AstStack + { + readonly List expressions = new List(); + readonly List types = new List(); + + public ExpressionType PeekType() { return types[types.Count - 1]; } + + public Expression Peek(ExpressionType toType) { - switch (t.Type) + var fromType = types[types.Count - 1]; + var expression = expressions[expressions.Count - 1]; + if (toType == fromType) + return expression; + + switch (toType) { - case TokenType.And: - ApplyBinaryOperation(s, (x, y) => y > 0 ? x : y); - continue; - case TokenType.NotEquals: - ApplyBinaryOperation(s, (x, y) => (y != x) ? 1 : 0); - continue; - case TokenType.Or: - ApplyBinaryOperation(s, (x, y) => y > 0 ? y : x); - continue; - case TokenType.Equals: - ApplyBinaryOperation(s, (x, y) => (y == x) ? 1 : 0); - continue; - case TokenType.Not: - ApplyUnaryOperation(s, x => (x > 0) ? 0 : 1); - continue; - case TokenType.Number: - s.Push(((NumberToken)t).Value); - continue; - case TokenType.Variable: - s.Push(ParseSymbol((VariableToken)t, symbols)); - continue; - default: - throw new InvalidProgramException("Evaluate is missing an evaluator for TokenType.{0}".F( - Enum.GetValues()[(int)t.Type])); + case ExpressionType.Bool: + return IfThenElse(AsBool(expression), True, False); + case ExpressionType.Int: + return IfThenElse(expression, One, Zero); } + + throw new InvalidProgramException("Unable to convert ExpressionType.{0} to ExpressionType.{1}".F( + Enum.GetValues()[(int)fromType], Enum.GetValues()[(int)toType])); } - return s.Pop(); + public Expression Pop(ExpressionType type) + { + var expression = Peek(type); + expressions.RemoveAt(expressions.Count - 1); + types.RemoveAt(types.Count - 1); + return expression; + } + + public void Push(Expression expression, ExpressionType type) + { + expressions.Add(expression); + if (type == ExpressionType.Int) + if (expression.Type != typeof(int)) + throw new InvalidOperationException("Expected System.Int type instead of {0} for {1}".F(expression.Type, expression)); + + if (type == ExpressionType.Bool) + if (expression.Type != typeof(bool)) + throw new InvalidOperationException("Expected System.Boolean type instead of {0} for {1}".F(expression.Type, expression)); + types.Add(type); + } + + public void Push(Expression expression) + { + expressions.Add(expression); + if (expression.Type == typeof(int)) + types.Add(ExpressionType.Int); + else if (expression.Type == typeof(bool)) + types.Add(ExpressionType.Bool); + else + throw new InvalidOperationException("Unhandled result type {0} for {1}".F(expression.Type, expression)); + } + } + + class Compiler + { + readonly AstStack ast = new AstStack(); + + public Func, int> Compile(Token[] postfix) + { + foreach (var t in postfix) + { + switch (t.Type) + { + case TokenType.And: + { + var y = ast.Pop(ExpressionType.Bool); + var x = ast.Pop(ExpressionType.Bool); + ast.Push(Expressions.Expression.And(x, y)); + continue; + } + + case TokenType.Or: + { + var y = ast.Pop(ExpressionType.Bool); + var x = ast.Pop(ExpressionType.Bool); + ast.Push(Expressions.Expression.Or(x, y)); + continue; + } + + case TokenType.NotEquals: + { + var y = ast.Pop(ExpressionType.Int); + var x = ast.Pop(ExpressionType.Int); + ast.Push(Expressions.Expression.NotEqual(x, y)); + continue; + } + + case TokenType.Equals: + { + var y = ast.Pop(ExpressionType.Int); + var x = ast.Pop(ExpressionType.Int); + ast.Push(Expressions.Expression.Equal(x, y)); + continue; + } + + case TokenType.Not: + { + if (ast.PeekType() == ExpressionType.Bool) + ast.Push(Expressions.Expression.Not(ast.Pop(ExpressionType.Bool))); + else + ast.Push(AsNegBool(ast.Pop(ExpressionType.Int))); + continue; + } + + case TokenType.Number: + { + ast.Push(Expressions.Expression.Constant(((NumberToken)t).Value)); + continue; + } + + case TokenType.Variable: + { + var symbol = Expressions.Expression.Constant(((VariableToken)t).Symbol); + Func, int> parseSymbol = ParseSymbol; + ast.Push(Expressions.Expression.Call(parseSymbol.Method, symbol, SymbolsParam)); + continue; + } + + default: + throw new InvalidProgramException( + "ConditionExpression.Compiler.Compile() is missing an expression builder for TokenType.{0}".F( + Enum.GetValues()[(int)t.Type])); + } + } + + return Expressions.Expression.Lambda, int>>( + ast.Pop(ExpressionType.Int), SymbolsParam).Compile(); + } + } + + public int Evaluate(IReadOnlyDictionary symbols) + { + return asFunction(symbols); } } } diff --git a/OpenRA.Test/OpenRA.Game/ConditionExpressionTest.cs b/OpenRA.Test/OpenRA.Game/ConditionExpressionTest.cs index 0eb9118fbd..71394a1050 100644 --- a/OpenRA.Test/OpenRA.Game/ConditionExpressionTest.cs +++ b/OpenRA.Test/OpenRA.Game/ConditionExpressionTest.cs @@ -66,6 +66,13 @@ namespace OpenRA.Test AssertValue("-12", -12); } + [TestCase(TestName = "Booleans")] + public void TestBooleans() + { + AssertValue("false", 0); + AssertValue("true", 1); + } + [TestCase(TestName = "AND operation")] public void TestAnd() { @@ -75,8 +82,8 @@ namespace OpenRA.Test AssertFalse("false && true"); AssertValue("2 && false", 0); AssertValue("false && 2", 0); - AssertValue("3 && 2", 2); - AssertValue("2 && 3", 3); + AssertValue("3 && 2", 1); + AssertValue("2 && 3", 1); } [TestCase(TestName = "OR operation")] @@ -86,10 +93,10 @@ namespace OpenRA.Test AssertFalse("false || false"); AssertTrue("true || false"); AssertTrue("false || true"); - AssertValue("2 || false", 2); - AssertValue("false || 2", 2); - AssertValue("3 || 2", 3); - AssertValue("2 || 3", 2); + AssertValue("2 || false", 1); + AssertValue("false || 2", 1); + AssertValue("3 || 2", 1); + AssertValue("2 || 3", 1); } [TestCase(TestName = "Equals operation")]