Index: src/examples/groovyShell/ArithmeticShell.groovy =================================================================== --- src/examples/groovyShell/ArithmeticShell.groovy (revision 0) +++ src/examples/groovyShell/ArithmeticShell.groovy (revision 0) @@ -0,0 +1,448 @@ + +import java.security.CodeSource +import org.codehaus.groovy.ast.ASTNode +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.GroovyCodeVisitor +import org.codehaus.groovy.ast.ModuleNode +import org.codehaus.groovy.ast.expr.BinaryExpression +import org.codehaus.groovy.ast.expr.ConstantExpression +import org.codehaus.groovy.ast.expr.Expression +import org.codehaus.groovy.ast.stmt.BlockStatement +import org.codehaus.groovy.ast.stmt.ExpressionStatement +import org.codehaus.groovy.classgen.GeneratorContext +import org.codehaus.groovy.control.CompilationUnit +import org.codehaus.groovy.control.CompilationUnit.PrimaryClassNodeOperation +import org.codehaus.groovy.control.CompilerConfiguration +import org.codehaus.groovy.control.Phases +import org.codehaus.groovy.control.SourceUnit +import org.codehaus.groovy.syntax.Types +import org.codehaus.groovy.ast.expr.MethodCallExpression +import org.codehaus.groovy.ast.expr.VariableExpression +import org.codehaus.groovy.ast.expr.ArgumentListExpression +import org.codehaus.groovy.ast.expr.PropertyExpression +import org.codehaus.groovy.ast.expr.UnaryMinusExpression +import org.codehaus.groovy.ast.expr.UnaryPlusExpression +import org.codehaus.groovy.ast.stmt.ForStatement +import org.codehaus.groovy.ast.stmt.WhileStatement +import org.codehaus.groovy.ast.stmt.DoWhileStatement +import org.codehaus.groovy.ast.stmt.IfStatement +import org.codehaus.groovy.ast.stmt.ReturnStatement +import org.codehaus.groovy.ast.stmt.AssertStatement +import org.codehaus.groovy.ast.stmt.TryCatchStatement +import org.codehaus.groovy.ast.stmt.SwitchStatement +import org.codehaus.groovy.ast.stmt.CaseStatement +import org.codehaus.groovy.ast.stmt.BreakStatement +import org.codehaus.groovy.ast.stmt.ContinueStatement +import org.codehaus.groovy.ast.stmt.ThrowStatement +import org.codehaus.groovy.ast.stmt.SynchronizedStatement +import org.codehaus.groovy.ast.stmt.CatchStatement +import org.codehaus.groovy.ast.expr.StaticMethodCallExpression +import org.codehaus.groovy.ast.expr.ConstructorCallExpression +import org.codehaus.groovy.ast.expr.TernaryExpression +import org.codehaus.groovy.ast.expr.ElvisOperatorExpression +import org.codehaus.groovy.ast.expr.PrefixExpression +import org.codehaus.groovy.ast.expr.PostfixExpression +import org.codehaus.groovy.ast.expr.BooleanExpression +import org.codehaus.groovy.ast.expr.ClosureExpression +import org.codehaus.groovy.ast.expr.TupleExpression +import org.codehaus.groovy.ast.expr.MapExpression +import org.codehaus.groovy.ast.expr.MapEntryExpression +import org.codehaus.groovy.ast.expr.ListExpression +import org.codehaus.groovy.ast.expr.RangeExpression +import org.codehaus.groovy.ast.expr.AttributeExpression +import org.codehaus.groovy.ast.expr.FieldExpression +import org.codehaus.groovy.ast.expr.MethodPointerExpression +import org.codehaus.groovy.ast.expr.ClassExpression +import org.codehaus.groovy.ast.expr.DeclarationExpression +import org.codehaus.groovy.ast.expr.RegexExpression +import org.codehaus.groovy.ast.expr.GStringExpression +import org.codehaus.groovy.ast.expr.ArrayExpression +import org.codehaus.groovy.ast.expr.SpreadExpression +import org.codehaus.groovy.ast.expr.SpreadMapExpression +import org.codehaus.groovy.ast.expr.NotExpression +import org.codehaus.groovy.ast.expr.BitwiseNegationExpression +import org.codehaus.groovy.ast.expr.CastExpression +import org.codehaus.groovy.ast.expr.ClosureListExpression +import org.codehaus.groovy.classgen.BytecodeExpression +import org.codehaus.groovy.control.MultipleCompilationErrorsException +import org.codehaus.groovy.control.messages.ExceptionMessage + +/** +* The arithmetic shell is similar to a GroovyShell in that it can evaluate text as +* code and return a result. It is not a subclass of GroovyShell because it does not widen +* the contract of GroovyShell, instead it narrows it. Using one of these shells like a +* GroovyShell would result in many runtime errors. +* +* @author Hamlet D'Arcy (hamletdrc@gmail.com) +*/ +public class ArithmeticShell { + + /** + * Compiles the text into a Groovy object and then executes it, returning the result. + * @param text + * the script to evaluate typed as a string + * @throws SecurityException + * most likely the script is doing something other that arithmetic + * @throws IllegalStateException + * if the script returns something other than a number + */ + public Number evaluate(String text) { + try { + SecureArithmeticClassLoader loader = new SecureArithmeticClassLoader() + Class clazz = loader.parseClass(addStaticImports(text)) + Script script = (Script)clazz.newInstance(); + Object result = script.run() + if (!(result instanceof Number)) throw new IllegalStateException("Script returned a non-number: $result"); + return (Number)result + } catch (SecurityException ex) { + throw new SecurityException("Could not evaluate script: $text", ex) + } catch (MultipleCompilationErrorsException mce) { + //this allows compilation errors to be seen by the user + mce.errorCollector.errors.each { + if (it instanceof ExceptionMessage && it.cause instanceof SecurityException) { + throw it.cause + } + } + throw mce + } + } + + /** + * Updates a script text with any desired static imports. + */ + private def addStaticImports(String text) { + return "import static java.lang.Math.*\n" + text + } +} + + +/** +* This classloader hooks the security enforcer into the compilation process. +*/ +class SecureArithmeticClassLoader extends GroovyClassLoader { + + protected CompilationUnit createCompilationUnit(CompilerConfiguration config, CodeSource codeSource) { + + CompilationUnit cu = super.createCompilationUnit(config, codeSource) + // wiring into the SEMANTIC_ANALYSIS phase will provide more type information + // that the CONVERSION phase. + cu.addPhaseOperation(new SecurityFilteringNodeOperation(), Phases.SEMANTIC_ANALYSIS) + return cu + } +} + +/** + * This operation will force only arithmetic operations to be compiled. + */ +private class SecurityFilteringNodeOperation extends PrimaryClassNodeOperation { + + public void call(SourceUnit source, GeneratorContext context, ClassNode classNode) { + + ModuleNode ast = source.getAST() + + if (ast.getImportPackages()) {throw new SecurityException("Package import statements are not allowed.")} + if (ast.getImports()) {throw new SecurityException("Import statements are not allowed.")} + if (ast.getStaticImportAliases()) {throw new SecurityException("Static import aliases are not allowed.")} + if (ast.getStaticImportFields()) {throw new SecurityException("Static field import statements are not allowed.")} + + //do not allow package names + if (source.getAST().getPackageName()) {throw new SecurityException("Package names are not allowed.")} + + //do not allow method definitions + if (ast.getMethods()) {throw new SecurityException("Method definition is not allowed.")} + + //enforce arithmetic only expressions + ast.getStatementBlock().visit(new ArithmeticExpressionEnforcer()) + } +} + +/** +* This code visitor throws a SecurityException if anything but an arithmetic expression is found. +* Normally, it would be easier to extend CodeVisitorSupport because that provides all the base +* methods to perform visits on the syntax tree and would make upgrading to newer versions of +* Groovy easier. However, that would mean that any new syntax in Groovy would be supported by +* this shell by default which is undesireable in this case. For instance, if a new metaprogramming +* trick gets introduced, this shell should _not_ allow it to be accessed without considerationg +* from the developer. +*/ +private class ArithmeticExpressionEnforcer implements GroovyCodeVisitor { + + private static final allowedTokens = [ + Types.PLUS, + Types.MINUS, + Types.MULTIPLY, + Types.DIVIDE, + Types.MOD, + Types.POWER, + Types.PLUS_PLUS, + Types.MINUS_MINUS, + Types.COMPARE_EQUAL, + Types.COMPARE_NOT_EQUAL, + Types.COMPARE_LESS_THAN, + Types.COMPARE_LESS_THAN_EQUAL, + Types.COMPARE_GREATER_THAN, + Types.COMPARE_GREATER_THAN_EQUAL, + ].asImmutable() + + private static final allowedConstantTypes = [ + Integer, + Float, + Long, + Double, + BigDecimal + ].asImmutable() + + private static final allowedReceivers = [ + Math, + Integer, + Float, + Double, + Long, + BigDecimal + ].asImmutable() + + private static final allowedStaticImports = [ + Math, + ].asImmutable() + + /** + * Block statements are allowed and traversal continues. + */ + public void visitBlockStatement(BlockStatement statement) { + //keep walking... + statement.getStatements().each { ASTNode child -> + child.visit(this) + } + } + + /** + * Expression statements must continue traversal. + */ + public void visitExpressionStatement(ExpressionStatement statement) { + Expression exp = statement.getExpression() + exp.visit(this) //keep walking... + } + + /** + * Binary expressions must have a numeric token (+, -, /, etc) and continue traversal. + */ + public void visitBinaryExpression(BinaryExpression expression) { + if (!allowedTokens.contains(expression.getOperation().getType())) { + throw new SecurityException("Unsupported token: ${expression.getOperation().getText() }") + } + expression.getLeftExpression().visit(this) + expression.getRightExpression().visit(this) + } + + /** + * Constants may be of only numeric core types. + */ + public void visitConstantExpression(ConstantExpression expression) { + Object value = expression.getValue() + if (!(allowedConstantTypes.contains(value.getClass()))) { + throw new SecurityException("""Unsupported constant type: ${ value.getClass() }, value: $value""") + } + } + + /** + * Method calls may only be invoked on a few core types + */ + public void visitMethodCallExpression(MethodCallExpression expression) { + Expression receiver = expression.getObjectExpression() + if (!(receiver instanceof ClassExpression)) { + throw new SecurityException("Unsupported method call: $receiver") + } + if (!allowedReceivers.contains(receiver.getType().getTypeClass())) { + throw new SecurityException("Unsupported method receiver: ${receiver.getText()}") + } + expression.getArguments().visit(this) //enforce arguments + } + + /** + * Argument expressions must continue to be processed + */ + public void visitArgumentlistExpression(ArgumentListExpression expression) { + expression.getExpressions().each { it.visit(this) } + } + + /** + * Property access allowed only on a few core Java types. + */ + public void visitPropertyExpression(PropertyExpression expression) { + Expression receiver = expression.getObjectExpression() + if (!(receiver instanceof ClassExpression)) { + throw new SecurityException("Unsupported method call: $receiver") + } + if (!allowedReceivers.contains(receiver.getType().getTypeClass())) { + throw new SecurityException("Unsupported method receiver: ${receiver.getText()}") + } + } + + /** + * The unary minus is allowed. + */ + public void visitUnaryMinusExpression(UnaryMinusExpression expression) { + expression.getExpression().visit(this) + } + + /** + * The unary plus operation is allowed. + */ + public void visitUnaryPlusExpression(UnaryPlusExpression expression) { + expression.getExpression().visit(this) + } + + /** + * Prefix operations like ++ and -- are allowed. + */ + public void visitPrefixExpression(PrefixExpression expression) { + if (!allowedTokens.contains(expression.getOperation().getType())) { + throw new SecurityException("Unsupported token: ${expression.getOperation().getText() }") + } + expression.getExpression().visit(this) + } + + /** + * Postfix operations like ++ and -- are allowed. + */ + public void visitPostfixExpression(PostfixExpression expression) { + if (!allowedTokens.contains(expression.getOperation().getType())) { + throw new SecurityException("Unsupported token: ${expression.getOperation().getText() }") + } + expression.getExpression().visit(this) + } + + /** + * Ternary expressions are allowed as long as they are arithmetic + */ + public void visitTernaryExpression(TernaryExpression expression) { + expression.getBooleanExpression().visit(this) + expression.getTrueExpression().visit(this) + expression.getFalseExpression().visit(this) + } + + /** + * Boolean expressions are allowed. + */ + public void visitBooleanExpression(BooleanExpression expression) { + expression.getExpression().visit(this) + } + + public void visitForLoop(ForStatement forStatement) { + throw new SecurityException("For statements forbidden in arithmetic shell.") + } + public void visitWhileLoop(WhileStatement whileStatement) { + throw new SecurityException("While statements forbidden in arithmetic shell.") + } + public void visitDoWhileLoop(DoWhileStatement doWhileStatement) { + throw new SecurityException("Do/while statements forbidden in arithmetic shell.") + } + public void visitIfElse(IfStatement ifStatement) { + throw new SecurityException("If statements forbidden in arithmetic shell.") + } + public void visitReturnStatement(ReturnStatement returnStatement) { + throw new SecurityException("Return statements forbidden in arithmetic shell.") + } + public void visitAssertStatement(AssertStatement assertStatement) { + throw new SecurityException("Assert statements forbidden in arithmetic shell.") + } + public void visitTryCatchFinally(TryCatchStatement tryCatchStatement) { + throw new SecurityException("Try/Catch statements forbidden in arithmetic shell.") + } + public void visitSwitch(SwitchStatement switchStatement) { + throw new SecurityException("Switch statements forbidden in arithmetic shell.") + } + public void visitCaseStatement(CaseStatement caseStatement) { + throw new SecurityException("Case statements forbidden in arithmetic shell.") + } + public void visitBreakStatement(BreakStatement breakStatement) { + throw new SecurityException("Break statements forbidden in arithmetic shell.") + } + public void visitContinueStatement(ContinueStatement continueStatement) { + throw new SecurityException("Continue statements forbidden in arithmetic shell.") + } + public void visitThrowStatement(ThrowStatement throwStatement) { + throw new SecurityException("Throw statements forbidden in arithmetic shell.") + } + public void visitSynchronizedStatement(SynchronizedStatement synchronizedStatement) { + throw new SecurityException("Synchronized statements forbidden in arithmetic shell.") + } + public void visitCatchStatement(CatchStatement catchStatement) { + throw new SecurityException("Catch statements forbidden in arithmetic shell.") + } + public void visitStaticMethodCallExpression(StaticMethodCallExpression staticMethodCallExpression) { + throw new SecurityException("Static method call expressions forbidden in arithmetic shell.") + } + public void visitConstructorCallExpression(ConstructorCallExpression constructorCallExpression) { + throw new SecurityException("Constructor call expressions forbidden in arithmetic shell.") + } + public void visitShortTernaryExpression(ElvisOperatorExpression elvisOperatorExpression) { + throw new SecurityException("Elvis operator expressions forbidden in arithmetic shell.") + } + public void visitClosureExpression(ClosureExpression closureExpression) { + throw new SecurityException("Closure expressions forbidden in arithmetic shell.") + } + public void visitTupleExpression(TupleExpression tupleExpression) { + throw new SecurityException("Tuple expressions forbidden in arithmetic shell.") + } + public void visitMapExpression(MapExpression mapExpression) { + throw new SecurityException("Map expressions forbidden in arithmetic shell.") + } + public void visitMapEntryExpression(MapEntryExpression mapEntryExpression) { + throw new SecurityException("Map entry expressions forbidden in arithmetic shell.") + } + public void visitListExpression(ListExpression listExpression) { + throw new SecurityException("List expressions forbidden in arithmetic shell.") + } + public void visitRangeExpression(RangeExpression rangeExpression) { + throw new SecurityException("Range expressions forbidden in arithmetic shell.") + } + public void visitAttributeExpression(AttributeExpression attributeExpression) { + throw new SecurityException("Attribute expressions forbidden in arithmetic shell.") + } + public void visitFieldExpression(FieldExpression fieldExpression) { + throw new SecurityException("Field expressions forbidden in arithmetic shell.") + } + public void visitMethodPointerExpression(MethodPointerExpression methodPointerExpression) { + throw new SecurityException("Method pointer expressions forbidden in arithmetic shell.") + } + public void visitVariableExpression(VariableExpression variableExpression) { + throw new SecurityException("Variable expressions forbidden in arithmetic shell.") + } + public void visitDeclarationExpression(DeclarationExpression declarationExpression) { + throw new SecurityException("Declaraion expressions forbidden in arithmetic shell.") + } + public void visitRegexExpression(RegexExpression regexExpression) { + throw new SecurityException("Regex expressions forbidden in arithmetic shell.") + } + public void visitGStringExpression(GStringExpression gStringExpression) { + throw new SecurityException("Groovy String expressions forbidden in arithmetic shell.") + } + public void visitArrayExpression(ArrayExpression arrayExpression) { + throw new SecurityException("Array expressions forbidden in arithmetic shell.") + } + public void visitSpreadExpression(SpreadExpression spreadExpression) { + throw new SecurityException("Spread expressions forbidden in arithmetic shell.") + } + public void visitSpreadMapExpression(SpreadMapExpression spreadMapExpression) { + throw new SecurityException("Spread map expressions forbidden in arithmetic shell.") + } + public void visitNotExpression(NotExpression notExpression) { + throw new SecurityException("Not expressions forbidden in arithmetic shell.") + } + public void visitBitwiseNegationExpression(BitwiseNegationExpression bitwiseNegationExpression) { + throw new SecurityException("Bitwise Negation expressions forbidden in arithmetic shell.") + } + public void visitCastExpression(CastExpression castExpression) { + throw new SecurityException("Cast expressions forbidden in arithmetic shell.") + } + public void visitClosureListExpression(ClosureListExpression closureListExpression) { + throw new SecurityException("Closure expressions forbidden in arithmetic shell.") + } + public void visitBytecodeExpression(BytecodeExpression bytecodeExpression) { + throw new SecurityException("Bytecode expressions forbidden in arithmetic shell.") + } + public void visitClassExpression(ClassExpression classExpression) { + throw new SecurityException("Class expressions forbidden in arithmetic shell.") + } +} Index: src/examples/groovyShell/ArithmeticShellTest.groovy =================================================================== --- src/examples/groovyShell/ArithmeticShellTest.groovy (revision 0) +++ src/examples/groovyShell/ArithmeticShellTest.groovy (revision 0) @@ -0,0 +1,40 @@ + +/** +* Unit test for ArithmeticShell. +* Requires JUnit to be in path, just like any other GroovyTestCase. +* +* @author Hamlet D'Arcy +*/ +class ArithmeticShellTest extends GroovyTestCase { + + public void testEvaluate_SuccessfulPaths() { + ArithmeticShell shell = new ArithmeticShell() + assertEquals(2.9073548971824276E135, shell.evaluate("((6L / 2f) - 1) ** 4.5e2")) + assertEquals(-6.816387600233341, shell.evaluate("10 * Math.sin(15/-20)")) + assertEquals(74.17310622494026, shell.evaluate("80*Math.E**(-(+(11++/40)**2))")) + assertEquals(2147483646, shell.evaluate("Integer.MAX_VALUE - ++2%2")) + assertEquals(6, shell.evaluate("++(5)")) + assertEquals(0, shell.evaluate("5 < 4 ? 1 : 0")) + assertEquals(0, shell.evaluate("5 != 4 ? 0 : 1 ")) + } + + public void testEvaluate_Failures() { + ArithmeticShell shell = new ArithmeticShell() + + shouldFail(SecurityException) { + shell.evaluate("Double.valueOf(\"5\")") + } + + shouldFail(SecurityException) { + shell.evaluate("import javax.swing.JLabel;5") + } + + shouldFail(SecurityException) { + shell.evaluate("def x = 5+3;x.toString()") + } + + shouldFail(SecurityException) { + shell.evaluate("new File();Double.valueOf(\"5\")") + } + } +}