Performant byte-based String wrapper

I looked into how String is handling things, and I don’t think I can (and maybe also don’t want to) work at such a low level. Instead, on creation, I now check whether the string is ASCII-only, and if so, use the String implementations of ==, hash(into:), and <, since the characters of an all-ASCII String fulfill my byte-based equality as well.

Checking for all-ASCII is done using code I copied from Swift’s internal standard library:

This function has a comment pointing out possible optimizations via SIMD, but it is not trivial as-is either: its main loop reads and masks sixteen bytes (two 64-bit words) per iteration, followed by 8-, 4-, 2-, and 1-byte tail checks. That should be good enough for me, too.

With the is-ASCII check added, ByteString creation is twice as slow as before. However, creation was not a bottleneck for me before and still is not now. I could also offer a second initializer that skips the check (downgrading the “is ASCII” property to an “is known to be ASCII” property) if I need more short-lived ByteString values in the future where creation time is more critical.

The new implementation matches or beats Swift’s String time for == and < in a mixed ASCII/Unicode corpus with a heavy ASCII-only bias. (Ignoring creation time.) Hashing is still a bit slower, but only marginally so.

Here is the updated ByteString:

import Darwin

/// A `String` wrapper that remembers whether its UTF-8 representation is ASCII-only.
///
/// The `isASCII` flag is always derived from `value`; it is never taken from
/// the caller or from serialized data, so the two stored properties cannot
/// disagree.
public struct ByteString: Codable, Sendable {
    /// The wrapped string.
    public let value: String

    /// `true` when no UTF-8 code unit of `value` has its high bit set.
    public let isASCII: Bool

    /// Keys used by the (synthesized) encoder and the custom decoder below.
    /// Both stored properties are still encoded, so the serialized shape is
    /// the same as with a fully synthesized `Codable` conformance.
    private enum CodingKeys: String, CodingKey {
        case value
        case isASCII
    }

    /// Copied from the [internal Swift standard library](https://github.com/apple/swift/blob/main/stdlib/public/core/StringCreate.swift), licensed under the [Apache License 2.0](https://www.apache.org/licenses/)
    ///
    /// Returns `true` when every byte in `input` is below `0x80`.
    ///
    /// - Parameter input: The bytes to inspect. An empty buffer is reported as ASCII.
    /// - Returns: `false` as soon as a byte with the high bit set is found, `true` otherwise.
    @usableFromInline
    static func isAllASCII(_ input: UnsafeBufferPointer<UInt8>) -> Bool {
        if input.isEmpty { return true }

        let count = input.count
        // Non-empty buffer, so the base address is present.
        var pointer = UnsafeRawPointer(input.baseAddress.unsafelyUnwrapped)

        // The high bit of every byte lane; any overlap means a non-ASCII byte.
        let asciiMask64 = 0x8080_8080_8080_8080 as UInt64
        let asciiMask32 = UInt32(truncatingIfNeeded: asciiMask64)
        let asciiMask16 = UInt16(truncatingIfNeeded: asciiMask64)
        let asciiMask8 = UInt8(truncatingIfNeeded: asciiMask64)

        // End addresses of the largest prefix whose length is a multiple of
        // 16, 8, 4 and 2 bytes respectively. `&` binds tighter than `+`; the
        // parentheses only make that explicit.
        let end128 = pointer + (count & ~(MemoryLayout<(UInt64, UInt64)>.stride &- 1))
        let end64 = pointer + (count & ~(MemoryLayout<UInt64>.stride &- 1))
        let end32 = pointer + (count & ~(MemoryLayout<UInt32>.stride &- 1))
        let end16 = pointer + (count & ~(MemoryLayout<UInt16>.stride &- 1))
        let end = pointer + count

        // Main loop: sixteen bytes per iteration.
        while pointer < end128 {
            let pair = pointer.loadUnaligned(as: (UInt64, UInt64).self)
            let result = (pair.0 | pair.1) & asciiMask64
            guard result == 0 else { return false }
            pointer = pointer + MemoryLayout<(UInt64, UInt64)>.stride
        }

        // Fewer than sixteen bytes remain, so each narrower step runs at most once.
        if pointer < end64 {
            let value = pointer.loadUnaligned(as: UInt64.self)
            guard value & asciiMask64 == 0 else { return false }
            pointer = pointer + MemoryLayout<UInt64>.stride
        }

        if pointer < end32 {
            let value = pointer.loadUnaligned(as: UInt32.self)
            guard value & asciiMask32 == 0 else { return false }
            pointer = pointer + MemoryLayout<UInt32>.stride
        }

        if pointer < end16 {
            let value = pointer.loadUnaligned(as: UInt16.self)
            guard value & asciiMask16 == 0 else { return false }
            pointer = pointer + MemoryLayout<UInt16>.stride
        }

        if pointer < end {
            let value = pointer.loadUnaligned(fromByteOffset: 0, as: UInt8.self)
            guard value & asciiMask8 == 0 else { return false }
        }

        return true
    }

    /// Wraps `value` and computes `isASCII` by scanning its UTF-8 bytes once.
    ///
    /// - Parameter value: The string to wrap.
    @inlinable
    public init(_ value: String) {
        // `withUTF8` is mutating, hence the local copy.
        var value = value
        self.isASCII = value.withUTF8(ByteString.isAllASCII(_:))
        self.value = value
    }

    /// Decodes a `ByteString`, reading only the `value` key.
    ///
    /// Any serialized `isASCII` entry is ignored and the flag is recomputed
    /// from the decoded string, so tampered or stale data cannot produce a
    /// value whose flag contradicts its bytes.
    ///
    /// - Parameter decoder: The decoder to read from.
    /// - Throws: Whatever the decoder throws when the keyed container or the
    ///   `value` entry cannot be decoded as a `String`.
    public init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        let decoded = try container.decode(String.self, forKey: .value)
        self.init(decoded)
    }
}

extension ByteString: Hashable {
    /// Reports whether two values wrap the same string.
    ///
    /// Values whose `isASCII` flags differ are never equal. Two ASCII-only
    /// values are compared with `String`'s `==`; two non-ASCII values are
    /// compared by length and then byte-for-byte over their UTF-8 storage.
    @inlinable
    @_effects(readonly)
    public static func == (lhs: Self, rhs: Self) -> Bool {
        // One ASCII-only side and one non-ASCII side can never match.
        guard lhs.isASCII == rhs.isASCII else {
            return false
        }
        if lhs.isASCII {
            return lhs.value == rhs.value
        }

        // `withUTF8` is mutating, so work on local copies.
        var left = lhs.value
        var right = rhs.value
        return left.withUTF8 { leftBytes in
            right.withUTF8 { rightBytes in
                let length = leftBytes.count
                if length != rightBytes.count {
                    return false
                }
                let leftStart = leftBytes.baseAddress.unsafelyUnwrapped
                let rightStart = rightBytes.baseAddress.unsafelyUnwrapped
                // Same storage and same length: equal without touching the bytes.
                return leftStart == rightStart || memcmp(leftStart, rightStart, length) == 0
            }
        }
    }

    /// Feeds this value into `hasher`.
    ///
    /// ASCII-only values hash the wrapped `String` itself; all other values
    /// hash the raw UTF-8 bytes.
    @inlinable
    public func hash(into hasher: inout Hasher) {
        guard !isASCII else {
            hasher.combine(value)
            return
        }

        // `withUTF8` is mutating, so work on a local copy.
        var copy = value
        copy.withUTF8 { bytes in
            hasher.combine(bytes: UnsafeRawBufferPointer(bytes))
        }
    }
}

extension ByteString: Comparable {
    /// Orders two values lexicographically.
    ///
    /// When both sides are ASCII-only, `String`'s `<` is used. Otherwise the
    /// UTF-8 bytes are compared as unsigned values: the first differing byte
    /// decides, and when one byte sequence is a prefix of the other, the
    /// shorter one sorts first. Without that length tie-break, values such as
    /// `"é"` and `"éa"` would be neither less than, greater than, nor equal to
    /// each other.
    @inlinable
    @_effects(readonly)
    public static func < (lhs: ByteString, rhs: ByteString) -> Bool {
        if lhs.isASCII && rhs.isASCII {
            return lhs.value < rhs.value
        }

        // `withUTF8` is mutating, so work on local copies.
        var lhs = lhs.value
        var rhs = rhs.value
        return lhs.withUTF8 { lhsUTF8 in
            rhs.withUTF8 { rhsUTF8 in
                let commonCount = min(lhsUTF8.count, rhsUTF8.count)
                // Skip `memcmp` when there is nothing to compare (one side is
                // empty) or when both buffers start at the same address, in
                // which case the common prefix is trivially identical.
                if commonCount > 0,
                   let lhsBaseAddress = lhsUTF8.baseAddress,
                   let rhsBaseAddress = rhsUTF8.baseAddress,
                   lhsBaseAddress != rhsBaseAddress {
                    let order = memcmp(lhsBaseAddress, rhsBaseAddress, commonCount)
                    if order != 0 {
                        return order < 0
                    }
                }
                // The common prefix is equal: the shorter sequence sorts first,
                // and equal lengths mean the values are equal (not less).
                return lhsUTF8.count < rhsUTF8.count
            }
        }
    }
}

Thanks for your feedback!

2 Likes