diff --git a/rxjava-core/src/main/java/rx/observers/Subscribers.java b/rxjava-core/src/main/java/rx/observers/Subscribers.java index bc9496e3e3..4fb8f9dd48 100644 --- a/rxjava-core/src/main/java/rx/observers/Subscribers.java +++ b/rxjava-core/src/main/java/rx/observers/Subscribers.java @@ -7,29 +7,9 @@ import rx.functions.Action1; public class Subscribers { - - private static final Subscriber EMPTY = new Subscriber() { - - @Override - public final void onCompleted() { - // do nothing - } - - @Override - public final void onError(Throwable e) { - throw new OnErrorNotImplementedException(e); - } - - @Override - public final void onNext(Object args) { - // do nothing - } - - }; - @SuppressWarnings("unchecked") public static Subscriber empty() { - return (Subscriber) EMPTY; + return from(Observers.empty()); } public static Subscriber from(final Observer o) { diff --git a/rxjava-core/src/main/java/rx/operators/OperatorMulticast.java b/rxjava-core/src/main/java/rx/operators/OperatorMulticast.java index 4a5016b717..01d647316c 100644 --- a/rxjava-core/src/main/java/rx/operators/OperatorMulticast.java +++ b/rxjava-core/src/main/java/rx/operators/OperatorMulticast.java @@ -15,6 +15,7 @@ */ package rx.operators; +import java.util.BitSet; import rx.Observable; import rx.Subscriber; import rx.Subscription; @@ -49,9 +50,10 @@ public void call(Subscriber subscriber) { @Override public Subscription connect() { + Subscriber s = null; synchronized (guard) { if (subscription == null) { - subscription = source.unsafeSubscribe(new Subscriber() { + s = new Subscriber() { @Override public void onCompleted() { subject.onCompleted(); @@ -66,18 +68,24 @@ public void onError(Throwable e) { public void onNext(T args) { subject.onNext(args); } - }); + }; + subscription = s; } } + if (s != null) { + source.unsafeSubscribe(s); + } return Subscriptions.create(new Action0() { @Override public void call() { + Subscription s; synchronized (guard) { - if (subscription != null) { - subscription.unsubscribe(); - subscription = null; - } + s = subscription; + subscription = null; + } + if (s != null) { + s.unsubscribe(); } } }); diff --git a/rxjava-core/src/main/java/rx/operators/OperatorRefCount.java b/rxjava-core/src/main/java/rx/operators/OperatorRefCount.java index bfcc157951..3831dd63b0 100644 --- a/rxjava-core/src/main/java/rx/operators/OperatorRefCount.java +++ b/rxjava-core/src/main/java/rx/operators/OperatorRefCount.java @@ -15,6 +15,10 @@ */ package rx.operators; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.WeakHashMap; import rx.Observable.OnSubscribe; import rx.Subscriber; import rx.Subscription; @@ -31,32 +35,175 @@ public final class OperatorRefCount implements OnSubscribe { final ConnectableObservable source; final Object guard; /** Guarded by guard. */ - int count; + int index; /** Guarded by guard. */ + boolean emitting; + /** Guarded by guard. If true, indicates a connection request, false indicates a disconnect request. */ + List queue; + /** Manipulated while in the serialized section. */ + int count; + /** Manipulated while in the serialized section. */ Subscription connection; + /** Manipulated while in the serialized section. */ + final Map connectionStatus; + /** Occupied indicator. */ + private static final Object OCCUPIED = new Object(); public OperatorRefCount(ConnectableObservable source) { this.source = source; this.guard = new Object(); + this.connectionStatus = new WeakHashMap(); } @Override public void call(Subscriber t1) { + int id; + synchronized (guard) { + id = ++index; + } + final Token t = new Token(id); t1.add(Subscriptions.create(new Action0() { @Override public void call() { - synchronized (guard) { - if (--count == 0) { - connection.unsubscribe(); - connection = null; - } - } + disconnect(t); } })); source.unsafeSubscribe(t1); + connect(t); + } + private void connect(Token id) { + List localQueue; + synchronized (guard) { + if (emitting) { + if (queue == null) { + queue = new ArrayList(); + } + queue.add(id); + return; + } + + localQueue = queue; + queue = null; + emitting = true; + } + boolean once = true; + do { + drain(localQueue); + if (once) { + once = false; + doConnect(id); + } + synchronized (guard) { + localQueue = queue; + queue = null; + if (localQueue == null) { + emitting = false; + return; + } + } + } while (true); + } + private void disconnect(Token id) { + List localQueue; synchronized (guard) { + if (emitting) { + if (queue == null) { + queue = new ArrayList(); + } + queue.add(id.toDisconnect()); // negative value indicates disconnect + return; + } + + localQueue = queue; + queue = null; + emitting = true; + } + boolean once = true; + do { + drain(localQueue); + if (once) { + once = false; + doDisconnect(id); + } + synchronized (guard) { + localQueue = queue; + queue = null; + if (localQueue == null) { + emitting = false; + return; + } + } + } while (true); + } + private void drain(List localQueue) { + if (localQueue == null) { + return; + } + int n = localQueue.size(); + for (int i = 0; i < n; i++) { + Token id = localQueue.get(i); + if (id.isDisconnect()) { + doDisconnect(id); + } else { + doConnect(id); + } + } + } + private void doConnect(Token id) { + // this method is called only once per id + // if add succeeds, id was not yet disconnected + if (connectionStatus.put(id, OCCUPIED) == null) { if (count++ == 0) { connection = source.connect(); } + } else { + // connection exists due to disconnect, just remove + connectionStatus.remove(id); + } + } + private void doDisconnect(Token id) { + // this method is called only once per id + // if remove succeeds, id was connected + if (connectionStatus.remove(id) != null) { + if (--count == 0) { + connection.unsubscribe(); + connection = null; + } + } else { + // mark id as if connected + connectionStatus.put(id, OCCUPIED); + } + } + /** Token that represens a connection request or a disconnection request. */ + private static final class Token { + final int id; + public Token(int id) { + this.id = id; + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (obj.getClass() != getClass()) { + return false; + } + int other = ((Token)obj).id; + return id == other || -id == other; + } + + @Override + public int hashCode() { + return id < 0 ? -id : id; + } + public boolean isDisconnect() { + return id < 0; + } + public Token toDisconnect() { + if (id < 0) { + return this; + } + return new Token(-id); } } } diff --git a/rxjava-core/src/main/java/rx/operators/OperatorSampleWithTime.java b/rxjava-core/src/main/java/rx/operators/OperatorSampleWithTime.java index 2b92b9792e..31642ea439 100644 --- a/rxjava-core/src/main/java/rx/operators/OperatorSampleWithTime.java +++ b/rxjava-core/src/main/java/rx/operators/OperatorSampleWithTime.java @@ -80,6 +80,7 @@ public void onError(Throwable e) { @Override public void onCompleted() { subscriber.onCompleted(); + unsubscribe(); } @Override diff --git a/rxjava-core/src/test/java/rx/RefCountTests.java b/rxjava-core/src/test/java/rx/RefCountTests.java index 95a7ec2030..43726456dd 100644 --- a/rxjava-core/src/test/java/rx/RefCountTests.java +++ b/rxjava-core/src/test/java/rx/RefCountTests.java @@ -16,7 +16,7 @@ package rx; import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.*; import java.util.ArrayList; import java.util.List; @@ -25,11 +25,14 @@ import org.junit.Before; import org.junit.Test; +import org.mockito.InOrder; import org.mockito.MockitoAnnotations; import rx.functions.Action0; import rx.functions.Action1; +import rx.observers.Subscribers; import rx.schedulers.TestScheduler; +import rx.subjects.ReplaySubject; import rx.subscriptions.Subscriptions; public class RefCountTests { @@ -152,4 +155,50 @@ public void call(Long t1) { assertEquals(1L, list3.get(1).longValue()); } + + @Test + public void testAlreadyUnsubscribedClient() { + Subscriber done = Subscribers.empty(); + done.unsubscribe(); + + @SuppressWarnings("unchecked") + Observer o = mock(Observer.class); + + Observable result = Observable.just(1).publish().refCount(); + + result.subscribe(done); + + result.subscribe(o); + + verify(o).onNext(1); + verify(o).onCompleted(); + verify(o, never()).onError(any(Throwable.class)); + } + @Test + public void testAlreadyUnsubscribedInterleavesWithClient() { + ReplaySubject source = ReplaySubject.create(); + + Subscriber done = Subscribers.empty(); + done.unsubscribe(); + + @SuppressWarnings("unchecked") + Observer o = mock(Observer.class); + InOrder inOrder = inOrder(o); + + Observable result = source.publish().refCount(); + + result.subscribe(o); + + source.onNext(1); + + result.subscribe(done); + + source.onNext(2); + source.onCompleted(); + + inOrder.verify(o).onNext(1); + inOrder.verify(o).onNext(2); + inOrder.verify(o).onCompleted(); + verify(o, never()).onError(any(Throwable.class)); + } }