From 7a444f77fcdde2862776c54649ed484481e3ed95 Mon Sep 17 00:00:00 2001
From: David Calavera <david.calavera@gmail.com>
Date: Tue, 22 Sep 2009 10:46:14 +0200
Subject: [PATCH] [1.9] Array improvements to solve several specs

---
 src/org/jruby/RubyArray.java      |  131 +++++++++++++++++++++++++++++++------
 src/org/jruby/RubyEnumerable.java |   20 +++++-
 2 files changed, 129 insertions(+), 22 deletions(-)

diff --git a/src/org/jruby/RubyArray.java b/src/org/jruby/RubyArray.java
index e192d0b..7f0ee7a 100644
--- a/src/org/jruby/RubyArray.java
+++ b/src/org/jruby/RubyArray.java
@@ -963,12 +963,18 @@ public class RubyArray extends RubyObject implements List {
     /** rb_ary_insert
      * 
      */
-    @JRubyMethod(name = "insert")
+    @JRubyMethod(name = "insert", compat = CompatVersion.RUBY1_8)
     public IRubyObject insert(IRubyObject arg) {
         return this;
     }
 
-    @JRubyMethod(name = "insert")
+    @JRubyMethod(name = "insert", compat = CompatVersion.RUBY1_9)
+    public IRubyObject insert1_9(IRubyObject arg) {
+        modifyCheck();
+        return insert(arg);
+    }
+
+    @JRubyMethod(name = "insert", compat = CompatVersion.RUBY1_8)
     public IRubyObject insert(IRubyObject arg1, IRubyObject arg2) {
         long pos = RubyNumeric.num2long(arg1);
 
@@ -980,7 +986,13 @@ public class RubyArray extends RubyObject implements List {
         return this;
     }
 
-    @JRubyMethod(name = "insert", required = 1, rest = true)
+    @JRubyMethod(name = "insert", compat = CompatVersion.RUBY1_9)
+    public IRubyObject insert1_9(IRubyObject arg1, IRubyObject arg2) {
+        modifyCheck();
+        return insert(arg1, arg2);
+    }
+
+    @JRubyMethod(name = "insert", required = 1, rest = true, compat = CompatVersion.RUBY1_8)
     public IRubyObject insert(IRubyObject[] args) {
         if (args.length == 1) return this;
 
@@ -999,12 +1011,19 @@ public class RubyArray extends RubyObject implements List {
         return this;
     }
 
+    @JRubyMethod(name = "insert", required = 1, rest = true, compat = CompatVersion.RUBY1_9)
+    public IRubyObject insert1_9(IRubyObject[] args) {
+        modifyCheck();
+        return insert(args);
+    }
+
     /** rb_ary_dup
      * 
      */
     public final RubyArray aryDup() {
         RubyArray dup = new RubyArray(getRuntime(), getMetaClass(), this);
         dup.flags |= flags & TAINTED_F; // from DUP_SETUP
+        dup.flags |= flags & UNTRUSTED_F;
         // rb_copy_generic_ivar from DUP_SETUP here ...unlikely..
         return dup;
     }
@@ -1153,7 +1172,7 @@ public class RubyArray extends RubyObject implements List {
     /** rb_ary_push_m
      * FIXME: Whis is this named "push_m"?
      */
-    @JRubyMethod(name = "push", rest = true)
+    @JRubyMethod(name = "push", rest = true, compat = CompatVersion.RUBY1_8)
     public RubyArray push_m(IRubyObject[] items) {
         for (int i = 0; i < items.length; i++) {
             append(items[i]);
@@ -1162,6 +1181,12 @@ public class RubyArray extends RubyObject implements List {
         return this;
     }
 
+    @JRubyMethod(name = "push", rest = true, compat = CompatVersion.RUBY1_9)
+    public RubyArray push_m1_9(IRubyObject[] items) {
+        modifyCheck();
+        return push_m(items);
+    }
+
     /** rb_ary_pop
      *
      */
@@ -1233,15 +1258,21 @@ public class RubyArray extends RubyObject implements List {
         return result;
     }
 
-    @JRubyMethod(name = "unshift")
+    @JRubyMethod(name = "unshift", compat = CompatVersion.RUBY1_8)
     public IRubyObject unshift() {
         return this;
     }
 
+    @JRubyMethod(name = "unshift", compat = CompatVersion.RUBY1_9)
+    public IRubyObject unshift1_9() {
+        modifyCheck();
+        return unshift();
+    }
+
     /** rb_ary_unshift
      *
      */
-    @JRubyMethod(name = "unshift")
+    @JRubyMethod(name = "unshift", compat = CompatVersion.RUBY1_8)
     public IRubyObject unshift(IRubyObject item) {
         if (begin == 0 || isShared) {
             modify();
@@ -1276,7 +1307,13 @@ public class RubyArray extends RubyObject implements List {
         return this;
     }
 
-    @JRubyMethod(name = "unshift", rest = true)
+    @JRubyMethod(name = "unshift", compat = CompatVersion.RUBY1_9)
+    public IRubyObject unshift1_9(IRubyObject item) {
+        modifyCheck();
+        return unshift(item);
+    }
+
+    @JRubyMethod(name = "unshift", rest = true, compat = CompatVersion.RUBY1_8)
     public IRubyObject unshift(IRubyObject[] items) {
         long len = realLength;
         if (items.length == 0) return this;
@@ -1293,6 +1330,12 @@ public class RubyArray extends RubyObject implements List {
         return this;
     }
 
+    @JRubyMethod(name = "unshift", rest = true, compat = CompatVersion.RUBY1_9)
+    public IRubyObject unshift1_9(IRubyObject[] items) {
+        modifyCheck();
+        return unshift(items);
+    }
+
     /** rb_ary_includes
      * 
      */
@@ -1440,7 +1483,7 @@ public class RubyArray extends RubyObject implements List {
 	/** rb_ary_concat
      *
      */
-    @JRubyMethod(name = "concat", required = 1)
+    @JRubyMethod(name = "concat", required = 1, compat = CompatVersion.RUBY1_8)
     public RubyArray concat(IRubyObject obj) {
         RubyArray ary = obj.convertToArray();
         
@@ -1449,6 +1492,12 @@ public class RubyArray extends RubyObject implements List {
         return this;
     }
 
+    @JRubyMethod(name = "concat", required = 1, compat = CompatVersion.RUBY1_9)
+    public RubyArray concat1_9(IRubyObject obj) {
+        modifyCheck();
+        return concat(obj);
+    }
+
     /** inspect_ary
      * 
      */
@@ -1650,12 +1699,16 @@ public class RubyArray extends RubyObject implements List {
         if (realLength == 0) return RubyString.newEmptyString(runtime);
 
         boolean taint = isTaint() || sep.isTaint();
+        boolean untrusted = isUntrusted() || sep.isUntrusted();
 
         int len = 1;
         for (int i = begin; i < begin + realLength; i++) {            
             IRubyObject value;
             try {
                 value = values[i];
+                if (runtime.is1_9() && value.equals(this)) {
+                    throw runtime.newArgumentError("");
+                }
             } catch (ArrayIndexOutOfBoundsException e) {
                 concurrentModification();
                 return RubyString.newEmptyString(runtime);
@@ -1695,10 +1748,13 @@ public class RubyArray extends RubyObject implements List {
 
             buf.append(tmp.asString().getByteList());
             if (tmp.isTaint()) taint = true;
+            if (tmp.isUntrusted()) untrusted = true;
         }
 
         RubyString result = runtime.newString(buf); 
         if (taint) result.setTaint(true);
+        if (untrusted) result.untrust(context);
+        
         return result;
     }
 
@@ -2124,9 +2180,12 @@ public class RubyArray extends RubyObject implements List {
     /** rb_ary_collect
      *
      */
-    public RubyArray collect(ThreadContext context, Block block) {
+    @JRubyMethod(name = "collect", frame = true, compat = CompatVersion.RUBY1_8)
+    public IRubyObject collect(ThreadContext context, Block block) {
         Ruby runtime = context.getRuntime();
-        if (!block.isGiven()) return new RubyArray(runtime, runtime.getArray(), this);
+        if (!block.isGiven()) {
+            return new RubyArray(runtime, runtime.getArray(), this);
+        }
 
         RubyArray collect = new RubyArray(runtime, realLength);
 
@@ -2137,6 +2196,11 @@ public class RubyArray extends RubyObject implements List {
         return collect;
     }
 
+    @JRubyMethod(name = "collect", frame = true, compat = CompatVersion.RUBY1_9)
+    public IRubyObject collect1_9(ThreadContext context, Block block) {
+        return block.isGiven() ? collect(context, block) : enumeratorize(context.getRuntime(), this, "collect");
+    }
+
     /** rb_ary_collect_bang
      *
      */
@@ -2583,8 +2647,8 @@ public class RubyArray extends RubyObject implements List {
         return modified;
     }
 
-    @JRubyMethod(name = "flatten!")
-    public IRubyObject flatten_bang19(ThreadContext context) {
+    @JRubyMethod(name = "flatten!", compat = CompatVersion.RUBY1_8)
+    public IRubyObject flatten_bang1_8(ThreadContext context) {
         Ruby runtime = context.getRuntime();
 
         RubyArray result = new RubyArray(runtime, getMetaClass(), realLength);
@@ -2598,9 +2662,16 @@ public class RubyArray extends RubyObject implements List {
         return runtime.getNil();
     }
 
-    @JRubyMethod(name = "flatten!")
-    public IRubyObject flatten_bang19(ThreadContext context, IRubyObject arg) {
+    @JRubyMethod(name = "flatten!", compat = CompatVersion.RUBY1_9)
+    public IRubyObject flatten_bang1_9(ThreadContext context) {
+        modifyCheck();
+        return flatten_bang1_8(context);
+    }
+
+    @JRubyMethod(name = "flatten!", compat = CompatVersion.RUBY1_8)
+    public IRubyObject flatten_bang1_8(ThreadContext context, IRubyObject arg) {
         Ruby runtime = context.getRuntime();
+        
         int level = RubyNumeric.num2int(arg);
         if (level == 0) return this;
 
@@ -2614,6 +2685,12 @@ public class RubyArray extends RubyObject implements List {
         return runtime.getNil();
     }
 
+    @JRubyMethod(name = "flatten!", compat = CompatVersion.RUBY1_9)
+    public IRubyObject flatten_bang1_9(ThreadContext context, IRubyObject arg) {
+        modifyCheck();
+        return flatten_bang1_8(context, arg);
+    }
+
     @JRubyMethod(name = "flatten")
     public IRubyObject flatten19(ThreadContext context) {
         Ruby runtime = context.getRuntime();
@@ -2786,6 +2863,8 @@ public class RubyArray extends RubyObject implements List {
 
     @JRubyMethod(name = "uniq!", compat = CompatVersion.RUBY1_9)
     public IRubyObject uniq_bang19(ThreadContext context, Block block) {
+        modifyCheck();
+        
         if (!block.isGiven()) return uniq_bang(context);
         RubyHash hash = makeHash(context, block);
         if (realLength == hash.size()) return context.getRuntime().getNil();
@@ -2911,7 +2990,9 @@ public class RubyArray extends RubyObject implements List {
     public IRubyObject sort_bang(ThreadContext context, Block block) {
         modify();
         if (realLength > 1) {
-            flags |= TMPLOCK_ARR_F;
+            if (!context.getRuntime().is1_9()) {
+                flags |= TMPLOCK_ARR_F;
+            }
             try {
                 return block.isGiven() ? sortInternal(context, block): sortInternal(context);
             } finally {
@@ -3215,14 +3296,22 @@ public class RubyArray extends RubyObject implements List {
     /** rb_ary_choice
      * 
      */
-    @JRubyMethod(name = "choice")
+    @JRubyMethod(name = "choice",  compat = CompatVersion.RUBY1_8)
     public IRubyObject choice(ThreadContext context) {
-        Random random = context.getRuntime().getRandom();
-        int i = realLength;
-        if(i == 0) {
-            return context.getRuntime().getNil();
+        return (realLength == 0) ? context.getRuntime().getNil() : choiceCommon(context);
+    }
+    
+    @JRubyMethod(name = "choice", compat = CompatVersion.RUBY1_9)
+    public IRubyObject choice1_9(ThreadContext context) {
+        if (realLength == 0) {
+            throw context.getRuntime().newNoMethodError("undefined method 'choice' for []:Array", null, context.getRuntime().getNil());
         }
-        int r = random.nextInt(i);
+        return choiceCommon(context);
+    }
+    
+    public IRubyObject choiceCommon(ThreadContext context) {
+        Random random = context.getRuntime().getRandom();
+        int r = random.nextInt(realLength);
         return values[begin + r];
     }
 
diff --git a/src/org/jruby/RubyEnumerable.java b/src/org/jruby/RubyEnumerable.java
index 5910ec2..9c0f601 100644
--- a/src/org/jruby/RubyEnumerable.java
+++ b/src/org/jruby/RubyEnumerable.java
@@ -615,8 +615,26 @@ public class RubyEnumerable {
         return block.isGiven() ? reject(context, self, block) : enumeratorize(context.getRuntime(), self, "reject");
     }
 
-    @JRubyMethod(name = {"collect", "map"}, frame = true)
+    @JRubyMethod(name = {"collect", "map"}, frame = true, compat = CompatVersion.RUBY1_8)
     public static IRubyObject collect(ThreadContext context, IRubyObject self, final Block block) {
+        return collectCommon(context, self, block);
+    }
+    
+    @JRubyMethod(name = "collect", frame = true, compat = CompatVersion.RUBY1_9)
+    public static IRubyObject collect1_9(ThreadContext context, IRubyObject self, final Block block) {
+        return mapOrCollect("collect", context, self, block);        
+    }
+    
+    @JRubyMethod(name = "map", frame = true, compat = CompatVersion.RUBY1_9)
+    public static IRubyObject map1_9(ThreadContext context, IRubyObject self, final Block block) {
+        return mapOrCollect("map", context, self, block);
+    }
+    
+    private static IRubyObject mapOrCollect(String name, ThreadContext context, IRubyObject self, final Block block) {
+        return block.isGiven() ? collectCommon(context, self, block) : enumeratorize(context.getRuntime(), self, name);
+    }
+    
+    private static IRubyObject collectCommon(ThreadContext context, IRubyObject self, final Block block) {
         final Ruby runtime = context.getRuntime();
         final RubyArray result = runtime.newArray();
 
-- 
1.6.0.4


