How to make Combine's prefix operator inclusive?

I'm probably missing something very simple, but I can't seem to figure it out.

My goal here is to prefix until I see the "end" value. I want this end value to be published, and then I want the publisher to finish. So how do I modify the following code so that "end" is the last thing printed, instead of the element before "end"?

import Combine

// Prints "a" and "b"; prefix(while:) drops "end" itself.
_ = Publishers.Sequence<[String], Never>(sequence: ["a", "b", "end", "c"])
    .prefix(while: { $0 != "end" })
    .sink { each in
        print(each)
    }
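For comparison, on an ordinary Swift Sequence, where no timing is involved, the inclusive behaviour being asked for can be expressed with a simple loop. This is only a sketch to pin down the desired semantics; `prefix(endingWith:)` is a hypothetical helper name, not part of the standard library or Combine:

```swift
// Plain Swift, no Combine: collect elements up to and including the first
// one for which `isLast` returns true.
// `prefix(endingWith:)` is a hypothetical name used only for illustration.
extension Sequence {
    func prefix(endingWith isLast: (Element) throws -> Bool) rethrows -> [Element] {
        var result: [Element] = []
        for element in self {
            result.append(element)
            // Stop right after appending the terminating element,
            // without looking at anything that follows it.
            if try isLast(element) { break }
        }
        return result
    }
}

let inclusive = ["a", "b", "end", "c"].prefix(endingWith: { $0 == "end" })
print(inclusive) // ["a", "b", "end"]
```

The difficulty in Combine is getting the same "stop right after the terminating element" behaviour when elements arrive over time.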

Thanks, that gets what I asked for, but it isn't the full behavior that I really need.

In particular, my real-world sequence will have time delays between items. I want the publisher to complete immediately after seeing/publishing "end"; I don't want it to have to wait for the next element after "end" to realize that it should complete.

I assume your prefix(while:) predicate isn't a simple comparison with a constant, but something where you don't know the exact final value in advance.
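(If the terminating value really were a known constant, a shorter composition might be enough: prefix(while:) finishes as soon as it sees "end", and append(_:) then re-emits that constant before finishing. This is a different technique from the scan approach below, sketched only for that special case; note that it would also append "end" if the upstream finished without ever sending it.)

```swift
import Combine

// Only applicable when the terminating value is a known constant:
// prefix(while:) finishes as soon as it sees "end" (dropping it),
// and append(_:) then re-emits that same constant before finishing.
var received: [String] = []
var finished = false

let subject = PassthroughSubject<String, Never>()
let ticket = subject
    .prefix(while: { $0 != "end" })
    .append("end")
    .sink(
        receiveCompletion: { _ in finished = true },
        receiveValue: { received.append($0) }
    )

subject.send("a")
subject.send("b")
subject.send("end")
// Completion has already been delivered here; no later element is needed.
```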

We can solve your problem by using the scan operator to pair each output with a flag indicating whether the prior output should have been the final output. Then we can use prefix to end the sequence when the flag is true. Finally, we can use compactMap to discard the flag and keep just the original output value.

import Combine

let sequence: AnyPublisher<String, Never> = [
    "hello",
    "world",
    "end-transmission",
    "garbage",
].publisher.eraseToAnyPublisher()

let pub = sequence
    .scan((current: String?.none, priorWasFinal: false)) { state, new in
        // state.current is only nil for the first output from upstream,
        // in which case there was no prior element to be final.
        let priorWasFinal = state.current?.hasPrefix("end-") ?? false
        return (current: new, priorWasFinal: priorWasFinal)
    }
    .prefix(while: { !$0.priorWasFinal })
    .compactMap { $0.current }
pub.sink { print($0) }

Output:

hello
world
end-transmission
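To see why this works, it can help to look at the intermediate pairs the scan step produces, before prefix(while:) and compactMap run. In this sketch, which reuses the same predicate, only "garbage" is paired with priorWasFinal == true, which is why prefix(while:) stops there while "end-transmission" still gets through:

```swift
import Combine

// Collect the (current, priorWasFinal) pairs emitted by the scan step,
// before prefix(while:) and compactMap are applied.
var pairs: [(String?, Bool)] = []

let cancellable = ["hello", "world", "end-transmission", "garbage"].publisher
    .scan((current: String?.none, priorWasFinal: false)) { state, new in
        let priorWasFinal = state.current?.hasPrefix("end-") ?? false
        return (current: new, priorWasFinal: priorWasFinal)
    }
    .sink { pair in
        print(pair)
        pairs.append((pair.current, pair.priorWasFinal))
    }
```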

We can wrap the logic into a Publisher extension method like this:

import Combine

extension Publisher {
    func cut(after predicate: @escaping (Output) -> Bool) -> AnyPublisher<Output, Failure> {
        return self
            .scan((current: Output?.none, priorWasFinal: false)) { state, new in
                // state.current is only nil for the first output from upstream,
                // in which case there was no prior element to be final.
                let priorWasFinal = state.current.map(predicate) ?? false
                return (current: new, priorWasFinal: priorWasFinal)
            }
            .prefix(while: { !$0.priorWasFinal })
            .compactMap { $0.current }
            .eraseToAnyPublisher()
    }
}

[
    "hello",
    "world",
    "end-transmission",
    "garbage",
].publisher
    .cut(after: { $0.hasPrefix("end-") })
    .sink { print($0) }

(Same output as the first example.)


Thanks for your help; that's the right API for sure.

But I think it still doesn't send .finished immediately after the final element is published. So in the example above, if "garbage" doesn't arrive until an hour after "end-transmission", then .finished also won't be sent until an hour after "end-transmission"... I "think" that's what I'm seeing in the debugger.

Do I need a custom publisher for this behavior?

Jesse
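One way to check this outside the debugger is to let a PassthroughSubject stand in for the delayed real-world source and record every event. The sketch below repeats the scan-based cut(after:) so it compiles on its own; after "end-transmission" is sent, the log should contain that value but no completion, which only arrives once "garbage" is sent:

```swift
import Combine

// Scan-based cut(after:), repeated so this snippet is self-contained.
extension Publisher {
    func cut(after predicate: @escaping (Output) -> Bool) -> AnyPublisher<Output, Failure> {
        self
            .scan((current: Output?.none, priorWasFinal: false)) { state, new in
                let priorWasFinal = state.current.map(predicate) ?? false
                return (current: new, priorWasFinal: priorWasFinal)
            }
            .prefix(while: { !$0.priorWasFinal })
            .compactMap { $0.current }
            .eraseToAnyPublisher()
    }
}

// Record every value and the completion, in order.
var events: [String] = []
let subject = PassthroughSubject<String, Never>()
let ticket = subject
    .cut(after: { $0.hasPrefix("end-") })
    .sink(
        receiveCompletion: { _ in events.append("finished") },
        receiveValue: { events.append($0) }
    )

subject.send("hello")
subject.send("end-transmission")
// "end-transmission" has been delivered, but the publisher has not finished yet.
let eventsBeforeNextElement = events

// Only the next upstream element lets scan mark the previous one as final.
subject.send("garbage")
let eventsAfterNextElement = events
```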

You're right. Do it with flatMap instead.

import Combine

// Each upstream value, plus an explicit end marker after the final one.
enum Cut<Value> {
    case output(Value)
    case end(Bool)

    var isEnd: Bool {
        if case .end(_) = self { return true }
        return false
    }

    var output: Value? {
        if case .output(let output) = self { return output }
        return nil
    }
}

extension Publisher {
    func cut(after predicate: @escaping (Output) -> Bool) -> AnyPublisher<Output, Failure> {
        return self.flatMap { (input: Output) -> AnyPublisher<Cut<Output>, Failure> in
            // Final element: emit it, then an end marker so that prefix finishes immediately.
            if predicate(input) {
                return [.output(input), .end(true)].publisher.setFailureType(to: Failure.self).eraseToAnyPublisher()
            } else {
                return [.output(input)].publisher.setFailureType(to: Failure.self).eraseToAnyPublisher()
            }
        }
        .prefix { !$0.isEnd }
        .compactMap { $0.output }
        .eraseToAnyPublisher()
    }
}

let subject = PassthroughSubject<String, Never>()

let ticket = subject.cut(after: { $0.hasPrefix("end-") })
    .sink(
        receiveCompletion: { print("completion: \($0)") },
        receiveValue: { print($0) }
    )

subject.send("hello")
subject.send("world")
subject.send("end-transmission")

There is, however, a bug that makes it crash if the Cut.end case has no associated value; hence the dummy Bool value.
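A quick way to check that this version completes without waiting for another element: this self-contained sketch repeats the code above (including the Cut enum with its dummy Bool), records every event from a PassthroughSubject, and inspects the log right after "end-transmission" is sent:

```swift
import Combine

// flatMap-based cut(after:), repeated so this snippet is self-contained.
enum Cut<Value> {
    case output(Value)
    case end(Bool)   // dummy payload, as in the reply above

    var isEnd: Bool {
        if case .end = self { return true }
        return false
    }

    var output: Value? {
        if case .output(let output) = self { return output }
        return nil
    }
}

extension Publisher {
    func cut(after predicate: @escaping (Output) -> Bool) -> AnyPublisher<Output, Failure> {
        flatMap { (input: Output) -> AnyPublisher<Cut<Output>, Failure> in
            // Emit the value, followed by an end marker if it is the final one.
            let items: [Cut<Output>] = predicate(input) ? [.output(input), .end(true)] : [.output(input)]
            return items.publisher.setFailureType(to: Failure.self).eraseToAnyPublisher()
        }
        .prefix { !$0.isEnd }
        .compactMap { $0.output }
        .eraseToAnyPublisher()
    }
}

var events: [String] = []
let subject = PassthroughSubject<String, Never>()
let ticket = subject
    .cut(after: { $0.hasPrefix("end-") })
    .sink(
        receiveCompletion: { _ in events.append("finished") },
        receiveValue: { events.append($0) }
    )

subject.send("hello")
subject.send("end-transmission")
// Completion follows the final value directly; no further element is needed.
let eventsAfterEnd = events

// Anything sent afterwards is ignored, because the upstream was cancelled.
subject.send("garbage")
let eventsAfterGarbage = events
```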


That's it, thanks again for your help. My mind needs to get more flexible about what flatMap can do.

Jesse

Here's a custom Publisher approach.

It's a copy-paste-modify job based on the TryPrefixWhile implementation in CombineX (GitHub: cx-org/CombineX), an open-source implementation of Apple's Combine.

It's a lot more code, but much of that is dependencies; the actual publisher logic isn't TOO complex. Performance is roughly 5x better than the composed flatMap solution, though depending on the use case it might not matter: for 1000 elements the timings were 0.001 vs. 0.005.

The flatMap solution is probably best for almost everyone... but it was fun to make a publisher!

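import Combine
import Foundation // on Apple platforms this brings in os_unfair_lock and the pthread APIs used below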

extension Publisher {
    public func cut(after predicate: @escaping (Output) -> Bool) -> Publishers.CutAfter<Self> {
        .init(upstream: self, predicate: predicate)
    }
}

extension Publishers {
    public struct CutAfter<Upstream: Publisher>: Publisher {
        public typealias Output = Upstream.Output
        public typealias Failure = Upstream.Failure

        public let upstream: Upstream
        public let predicate: (Upstream.Output) -> Bool

        public init(upstream: Upstream, predicate: @escaping (Publishers.CutAfter<Upstream>.Output) -> Bool) {
            self.upstream = upstream
            self.predicate = predicate
        }

        public func receive<S: Subscriber>(subscriber: S) where Upstream.Failure == S.Failure, Upstream.Output == S.Input {
            upstream
                .tryCut(after: predicate)
                .mapError {
                    $0 as! Failure
                }
                .receive(subscriber: subscriber)
        }
    }
}

extension Publisher {
    public func tryCut(after predicate: @escaping (Output) throws -> Bool) -> Publishers.TryCutAfter<Self> {
        .init(upstream: self, predicate: predicate)
    }
}

extension Publishers {
    public struct TryCutAfter<Upstream: Publisher>: Publisher {
        public typealias Output = Upstream.Output
        public typealias Failure = Error

        public let upstream: Upstream
        public let predicate: (Upstream.Output) throws -> Bool

        public init(upstream: Upstream, predicate: @escaping (Publishers.TryCutAfter<Upstream>.Output) throws -> Bool) {
            self.upstream = upstream
            self.predicate = predicate
        }

        public func receive<S: Subscriber>(subscriber: S) where Upstream.Output == S.Input, S.Failure == Publishers.TryCutAfter<Upstream>.Failure {
            let s = Inner(pub: self, sub: subscriber)
            upstream.subscribe(s)
        }
    }
}

extension Publishers.TryCutAfter {
    private final class Inner<S>:
        Subscription,
        Subscriber,
        CustomStringConvertible,
        CustomDebugStringConvertible
        where
        S: Subscriber,
        S.Input == Output,
        S.Failure == Failure {
        typealias Input = Upstream.Output
        typealias Failure = Upstream.Failure

        typealias Pub = Publishers.TryCutAfter<Upstream>
        typealias Sub = S
        typealias Predicate = (Upstream.Output) throws -> Bool

        let lock = Lock()
        let predicate: Predicate
        let sub: Sub

        var state = RelayState.waiting

        init(pub: Pub, sub: Sub) {
            predicate = pub.predicate
            self.sub = sub
        }

        deinit {
            lock.cleanupLock()
        }

        func request(_ demand: Subscribers.Demand) {
            lock.withLockGet(state.subscription)?.request(demand)
        }

        func cancel() {
            lock.withLockGet(state.complete())?.cancel()
        }

        func receive(subscription: Subscription) {
            guard lock.withLockGet(state.relay(subscription)) else {
                subscription.cancel()
                return
            }

            sub.receive(subscription: self)
        }

        func receive(_ input: Input) -> Subscribers.Demand {
            lock.lock()
            guard state.isRelaying else {
                lock.unlock()
                return .none
            }

            do {
                if try predicate(input) {
                    let subscription = state.complete()
                    lock.unlock()

                    subscription?.cancel()
                    _ = sub.receive(input)
                    sub.receive(completion: .finished)
                    return .none
                } else {
                    lock.unlock()
                    return sub.receive(input)
                }
            } catch {
                let subscription = state.complete()
                lock.unlock()

                subscription?.cancel()
                sub.receive(completion: .failure(error))
                return .none
            }
        }

        func receive(completion: Subscribers.Completion<Failure>) {
            complete(completion.mapError { $0 })
        }

        private func complete(_ completion: Subscribers.Completion<Error>) {
            guard let subscription = lock.withLockGet(state.complete()) else {
                return
            }

            subscription.cancel()
            sub.receive(completion: completion.mapError { $0 })
        }

        var description: String {
            "TryCutAfter"
        }

        var debugDescription: String {
            "TryCutAfter"
        }
    }
}

enum RelayState {
    case waiting
    case relaying(Subscription)
    case completed
}

extension RelayState {
    var isWaiting: Bool {
        switch self {
        case .waiting:
            return true
        default:
            return false
        }
    }

    var isRelaying: Bool {
        switch self {
        case .relaying:
            return true
        default:
            return false
        }
    }

    var isCompleted: Bool {
        switch self {
        case .completed:
            return true
        default:
            return false
        }
    }

    var subscription: Subscription? {
        switch self {
        case let .relaying(s):
            return s
        default:
            return nil
        }
    }
}

extension RelayState {
    func preconditionValue(file: StaticString = #file, line: UInt = #line) {
        if isWaiting {
            fatalError("Received value before receiving subscription", file: file, line: line)
        }
    }

    func preconditionCompletion(file: StaticString = #file, line: UInt = #line) {
        if isWaiting {
            fatalError("Received completion before receiving subscription", file: file, line: line)
        }
    }
}

extension RelayState {
    mutating func relay(_ subscription: Subscription) -> Bool {
        guard isWaiting else { return false }
        self = .relaying(subscription)
        return true
    }

    mutating func complete() -> Subscription? {
        defer {
            self = .completed
        }
        return subscription
    }
}

extension Subscribers.Completion {
    func mapError<NewFailure: Error>(_ transform: (Failure) -> NewFailure) -> Subscribers.Completion<NewFailure> {
        switch self {
        case .finished:
            return .finished
        case let .failure(error):
            return .failure(transform(error))
        }
    }
}

public protocol Locking {
    func lock()
    func tryLock() -> Bool
    func unlock()
}

extension Locking {
    public func withLock<T>(_ body: () throws -> T) rethrows -> T {
        lock(); defer { self.unlock() }
        return try body()
    }

    public func withLockGet<T>(_ body: @autoclosure () throws -> T) rethrows -> T {
        lock(); defer { self.unlock() }
        return try body()
    }
}

// MARK: - Lock

public struct Lock: Locking {
    private let _lock: UnsafeMutableRawPointer

    public init() {
        #if canImport(Darwin)
            if #available(macOS 10.12, iOS 10.0, tvOS 10.0, watchOS 3.0, *) {
                _lock = OSUnfairLock().raw
                return
            }
        #endif
        _lock = PThreadMutex(recursive: false).raw
    }

    public func cleanupLock() {
        #if canImport(Darwin)
            if #available(macOS 10.12, iOS 10.0, tvOS 10.0, watchOS 3.0, *) {
                _lock.as(OSUnfairLock.self).cleanupLock()
                return
            }
        #endif
        _lock.as(PThreadMutex.self).cleanupLock()
    }

    public func lock() {
        #if canImport(Darwin)
            if #available(macOS 10.12, iOS 10.0, tvOS 10.0, watchOS 3.0, *) {
                _lock.as(OSUnfairLock.self).lock()
                return
            }
        #endif
        _lock.as(PThreadMutex.self).lock()
    }

    public func tryLock() -> Bool {
        #if canImport(Darwin)
            if #available(macOS 10.12, iOS 10.0, tvOS 10.0, watchOS 3.0, *) {
                return _lock.as(OSUnfairLock.self).tryLock()
            }
        #endif
        return _lock.as(PThreadMutex.self).tryLock()
    }

    public func unlock() {
        #if canImport(Darwin)
            if #available(macOS 10.12, iOS 10.0, tvOS 10.0, watchOS 3.0, *) {
                _lock.as(OSUnfairLock.self).unlock()
                return
            }
        #endif
        _lock.as(PThreadMutex.self).unlock()
    }
}

// MARK: - RecursiveLock

public struct RecursiveLock: Locking {
    private let _lock: UnsafeMutableRawPointer

    public init() {
        #if canImport(DarwinPrivate)
            if #available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) {
                _lock = OSUnfairRecursiveLock().raw
                return
            }
        #endif
        _lock = PThreadMutex(recursive: true).raw
    }

    public func cleanupLock() {
        #if canImport(DarwinPrivate)
            if #available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) {
                _lock.as(OSUnfairRecursiveLock.self).cleanupLock()
                return
            }
        #endif
        _lock.as(PThreadMutex.self).cleanupLock()
    }

    public func lock() {
        #if canImport(DarwinPrivate)
            if #available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) {
                _lock.as(OSUnfairRecursiveLock.self).lock()
                return
            }
        #endif
        _lock.as(PThreadMutex.self).lock()
    }

    public func tryLock() -> Bool {
        #if canImport(DarwinPrivate)
            if #available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) {
                return _lock.as(OSUnfairRecursiveLock.self).tryLock()
            }
        #endif
        return _lock.as(PThreadMutex.self).tryLock()
    }

    public func unlock() {
        #if canImport(DarwinPrivate)
            if #available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) {
                _lock.as(OSUnfairRecursiveLock.self).unlock()
                return
            }
        #endif
        _lock.as(PThreadMutex.self).unlock()
    }
}

#if canImport(Darwin)

    // MARK: - OSUnfairLock

    private typealias OSUnfairLock = UnsafeMutablePointer<os_unfair_lock_s>

    @available(macOS 10.12, iOS 10.0, tvOS 10.0, watchOS 3.0, *)
    private extension UnsafeMutablePointer where Pointee == os_unfair_lock_s {
        init() {
            let l = UnsafeMutablePointer.allocate(capacity: 1)
            l.initialize(to: os_unfair_lock_s())
            self = l
        }

        func cleanupLock() {
            deinitialize(count: 1)
            deallocate()
        }

        func lock() {
            os_unfair_lock_lock(self)
        }

        func tryLock() -> Bool {
            os_unfair_lock_trylock(self)
        }

        func unlock() {
            os_unfair_lock_unlock(self)
        }
    }

// MARK: - OSUnfairRecursiveLock

    // TODO: Use os_unfair_recursive_lock_s
    #if canImport(DarwinPrivate)

        private typealias OSUnfairRecursiveLock = UnsafeMutablePointer<os_unfair_recursive_lock_s>

        @available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *)
        private extension UnsafeMutablePointer where Pointee == os_unfair_recursive_lock_s {
            init() {
                let l = UnsafeMutablePointer.allocate(capacity: 1)
                l.initialize(to: os_unfair_recursive_lock_s())
                self = l
            }

            func cleanupLock() {
                deinitialize(count: 1)
                deallocate()
            }

            func lock() {
                os_unfair_recursive_lock_lock(self)
            }

            func tryLock() -> Bool {
                let result = os_unfair_recursive_lock_trylock(self)
                return result
            }

            func unlock() {
                os_unfair_recursive_lock_unlock(self)
            }
        }

    #endif // canImport(DarwinPrivate)

#endif // canImport(Darwin)

// MARK: - PThreadMutex

private typealias PThreadMutex = UnsafeMutablePointer<pthread_mutex_t>

private extension UnsafeMutablePointer where Pointee == pthread_mutex_t {
    init(recursive: Bool) {
        let l = UnsafeMutablePointer<pthread_mutex_t>.allocate(capacity: 1)
        if recursive {
            var attr = pthread_mutexattr_t()
            pthread_mutexattr_init(&attr)
            pthread_mutexattr_settype(&attr, Int32(PTHREAD_MUTEX_RECURSIVE)).assertZero()
            pthread_mutex_init(l, &attr).assertZero()
        } else {
            pthread_mutex_init(l, nil).assertZero()
        }
        self = l
    }

    func cleanupLock() {
        pthread_mutex_destroy(self).assertZero()
        deinitialize(count: 1)
        deallocate()
    }

    func lock() {
        pthread_mutex_lock(self).assertZero()
    }

    func tryLock() -> Bool {
        pthread_mutex_trylock(self) == 0
    }

    func unlock() {
        pthread_mutex_unlock(self).assertZero()
    }
}

// MARK: Helpers

private extension UnsafeMutablePointer {
    @inline(__always)
    var raw: UnsafeMutableRawPointer {
        UnsafeMutableRawPointer(self)
    }
}

private extension UnsafeMutableRawPointer {
    @inline(__always)
    func `as`<T>(_: UnsafeMutablePointer<T>.Type) -> UnsafeMutablePointer<T> {
        assumingMemoryBound(to: T.self)
    }
}

private extension Int32 {
    @inline(__always)
    func assertZero() {
        // assert or precondition?
        assert(self == 0)
    }
}

Thanks for your input!

I was struggling with this myself and found this old thread.
Thanks for pointing to .flatMap and for using it to pass on the value as well as a marker.

I think your Cut type really has only two states: a value or an "end" marker. You can reuse Optional for that, so the implementation can be simplified to the following:

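import Combine

extension Publisher {
    func cut(after isLastElement: @escaping (Output) -> Bool) -> AnyPublisher<Output, Failure> {
        self
            // Emit each value; after the last one, also emit nil as the end marker.
            // The inner publishers never fail, so this uses the flatMap overload for
            // Never-failing inner publishers (macOS 11 / iOS 14 and later).
            .flatMap { isLastElement($0) ? [$0, nil].publisher : [$0].publisher }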
            .prefix { $0 != nil } // ← this is functionally the same as .isEnd
            .compactMap { $0 }    // ← this is functionally equivalent to .output
            .eraseToAnyPublisher()
    }
}
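A self-contained sketch exercising the Optional-based version with an upstream that can fail. It spells the flatMap closure out with explicit types and setFailureType(to:), which should make it independent of the newer flatMap overload for Never-failing inner publishers; LinkError and Recorder are hypothetical names used only for this demonstration:

```swift
import Combine

// A hypothetical error type, only for this demonstration.
struct LinkError: Error {}

// Optional-marker variant of cut(after:), spelled out with explicit types.
extension Publisher {
    func cut(after isLastElement: @escaping (Output) -> Bool) -> AnyPublisher<Output, Failure> {
        flatMap { value -> AnyPublisher<Output?, Failure> in
            // nil acts as the "end" marker that follows the final value.
            let items: [Output?] = isLastElement(value) ? [value, nil] : [value]
            return items.publisher
                .setFailureType(to: Failure.self)
                .eraseToAnyPublisher()
        }
        .prefix { $0 != nil }   // stop at the nil marker
        .compactMap { $0 }      // unwrap the remaining values
        .eraseToAnyPublisher()
    }
}

// Hypothetical helper: collects every value and completion, in order.
final class Recorder {
    var events: [String] = []
    private var cancellable: AnyCancellable?

    func attach(to publisher: AnyPublisher<String, LinkError>) {
        cancellable = publisher.sink(
            receiveCompletion: { [unowned self] completion in
                switch completion {
                case .finished: self.events.append("finished")
                case .failure: self.events.append("failed")
                }
            },
            receiveValue: { [unowned self] value in self.events.append(value) }
        )
    }
}

let subject = PassthroughSubject<String, LinkError>()
let normal = Recorder()
normal.attach(to: subject.cut(after: { $0.hasPrefix("end-") }))
subject.send("hello")
subject.send("end-transmission")
// Already finished, so this later failure is never delivered.
subject.send(completion: .failure(LinkError()))

// A failure that arrives before the final element is passed straight through.
let failing = PassthroughSubject<String, LinkError>()
let early = Recorder()
early.attach(to: failing.cut(after: { $0.hasPrefix("end-") }))
failing.send("hello")
failing.send(completion: .failure(LinkError()))
```

Because prefix cancels the upstream as soon as it sees the nil marker, anything the source sends afterwards, including a failure, never reaches the subscriber.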