Implement Equatable protocol in a class hierarchy

I have the following simple class hierarchy: A conforms to Hashable and has a single subclass B. Adding distinct instances to a Set (or a Dictionary) results in weird behaviour because Set internally calls static func == (lhs: A, rhs: A) -> Bool (instead of the one defined on B), while the hash value is correctly computed by the overridden implementation in B. As a result, the set will contain unexpected elements. Things get worse with the new non-deterministic hashValue in Swift 4.2, which causes the set to contain either 1 or 2 elements depending on the seed value.

How does one properly handle class hierarchy and conformance to Equatable?

class A: Hashable {
    let a: Int
    
    init(a: Int) { self.a = a }

    static func == (lhs: A, rhs: A) -> Bool { return lhs.a == rhs.a }
    
    func hash(into hasher: inout Hasher) { hasher.combine(a) }
}

class B: A {
    let b: Int
    
    init(a: Int, b: Int) {
        self.b = b
        super.init(a: a)
    }
    
    static func == (lhs: B, rhs: B) -> Bool { return lhs.a == rhs.a && lhs.b == rhs.b }
    
    override func hash(into hasher: inout Hasher) {
        hasher.combine(a)
        hasher.combine(b)
    }
}


let b1 = B(a: 1, b: 2)
let b2 = B(a: 1, b: 3)

var set = Set<B>()

set.insert(b1)
set.insert(b2)

print(set.count) // sometimes 1, sometimes 2
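
A minimal sketch of why this happens, reusing the classes above. The helper genericEquals is a hypothetical name introduced only for illustration; it compares two values through the Equatable conformance, which is roughly what Set does internally. B's == is an overload rather than an override, so it is only picked when the compiler statically knows both operands are B.

```swift
class A: Hashable {
    let a: Int
    init(a: Int) { self.a = a }
    static func == (lhs: A, rhs: A) -> Bool { return lhs.a == rhs.a }
    func hash(into hasher: inout Hasher) { hasher.combine(a) }
}

class B: A {
    let b: Int
    init(a: Int, b: Int) {
        self.b = b
        super.init(a: a)
    }
    // An overload, not an override: B inherits its Equatable conformance
    // from A, and that conformance still points at A's ==.
    static func == (lhs: B, rhs: B) -> Bool { return lhs.a == rhs.a && lhs.b == rhs.b }
    override func hash(into hasher: inout Hasher) {
        hasher.combine(a)
        hasher.combine(b)
    }
}

// Hypothetical helper: compares through the Equatable conformance.
func genericEquals<T: Equatable>(_ lhs: T, _ rhs: T) -> Bool {
    return lhs == rhs
}

let b1 = B(a: 1, b: 2)
let b2 = B(a: 1, b: 3)

let direct = (b1 == b2)                 // false: overload resolution picks B's ==
let viaGeneric = genericEquals(b1, b2)  // true: the conformance uses A's ==
print(direct, viaGeneric)
```

Since the hash is virtual but == is not, the two disagree, and whether b1 and b2 land in the same bucket decides whether the set ends up with one element or two.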

There are sort of two answers here:

  1. The Liskov Substitution Principle says that any subclass needs to be able to fill in where a superclass is expected. That means that if a superclass's == only tests a, the subclass can't start also making b significant. What would happen when you test an A against a B in that case? Or a B against another subclass C? This isn't just about == either; if someone is trying to make a set of all distinct A values, and then they find that the value 3 is in there five times, they'd be really confused.

  2. What you can do in some cases is have subclasses be more equal than superclasses. That is, you can say "all A values are only equal when a is equal, but B values can be equal whenever a is equal or when b is equal". (I have no idea why you'd want to do this, but it doesn't break any rules.) In this case, you have to be able to override equality checking, but since == must be declared static and static implies final you can't do that directly. Instead, you can have == call out to an instance method, perhaps equals(_:), and then override that. This is what Foundation's NSObject type does.

Note that if you don't want to get this tricky, a safe option is always to make == compare object identity using ===, and have the hash value be based on the hash of ObjectIdentifier(self). That doesn't deduplicate instances based on any property values, but you may not actually need that.
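
A minimal sketch of that identity-based option. The class names Node and SpecialNode are hypothetical and only serve the illustration.

```swift
class Node: Hashable {
    let label: String
    init(label: String) { self.label = label }

    // Two references are equal only if they point at the same instance.
    static func == (lhs: Node, rhs: Node) -> Bool { return lhs === rhs }

    // Hash the object's identity, not its properties.
    func hash(into hasher: inout Hasher) {
        hasher.combine(ObjectIdentifier(self))
    }
}

// Subclasses need no extra work: identity is already unambiguous.
class SpecialNode: Node {}

let n1 = Node(label: "x")
let n2 = Node(label: "x")          // same label, different instance
let n3 = SpecialNode(label: "x")

var nodes = Set<Node>()
nodes.insert(n1)
nodes.insert(n1)                   // same instance again, no effect
nodes.insert(n2)
nodes.insert(n3)
print(nodes.count)                 // 3
```

Because neither == nor the hash looks at any stored property, subclasses cannot break the contract by adding fields.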


To make this more concrete, there's a problem with your code even if == were somehow dispatched dynamically exactly as you wished. Consider:

let a = A(a: 42)
let b0 = B(a: 42, b: 42)
a == b0 // true

This is, of course, unlikely to be what you actually want. Calling out to an instance method both (a) does what you want; and (b) automagically helps you avoid the problem above:

class A: Hashable {
  let a: Int

  init(a: Int) {
    self.a = a
  }

  static func == (lhs: A, rhs: A) -> Bool {
    return lhs.isEqual(to: rhs)
  }

  func isEqual(to other: A) -> Bool {
    return self.a == other.a
  }

  func hash(into hasher: inout Hasher) {
    hasher.combine(a)
  }
}

class B: A {
  let b: Int

  init(a: Int, b: Int) {
    self.b = b
    super.init(a: a)
  }

  override func isEqual(to other: A) -> Bool {
    // Note how we don't have a choice here:
    // Swift _requires_ us to cast or else we can't compare the property `b`.
    //
    // By doing so, it prevents us from the problem outlined above (provided,
    // of course, that any subclass of B does the same thing in turn).
    guard let other = other as? B else { return false }
    return self.a == other.a && self.b == other.b
  }

  override func hash(into hasher: inout Hasher) {
    hasher.combine(a)
    hasher.combine(b)
  }
}

Oh, and don't forget the rule about hash values: a == b implies a.hashValue == b.hashValue. (Not necessarily the other way around, though.) Otherwise when you go to look something up in a Set you'll be looking in the wrong place. So if you make a subclass more equal-able than the superclass, you'll have to override the hash value too, and in such a way that doesn't mess with how the superclass defined it. (This means my example about "either a or b being equal" doesn't actually work if you want to conform to Hashable and not just Equatable, because you may end up comparing a base class to a subclass.)
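
One way to keep that rule intact, sketched under some assumptions: the names Base and Derived are hypothetical, and == uses an exact-type check before calling an overridable isEqual(to:). Here the subclass makes equality stricter but simply inherits the superclass's hash. That stays consistent because two equal Derived values necessarily have equal a, and therefore equal hashes; unequal values that share a hash are merely a collision, which is allowed.

```swift
class Base: Hashable {
    let a: Int
    init(a: Int) { self.a = a }

    static func == (lhs: Base, rhs: Base) -> Bool {
        return type(of: lhs) == type(of: rhs) && lhs.isEqual(to: rhs)
    }

    func isEqual(to other: Base) -> Bool { return a == other.a }

    // Hashes only `a`. Subclasses may inherit this unchanged.
    func hash(into hasher: inout Hasher) { hasher.combine(a) }
}

class Derived: Base {
    let b: Int
    init(a: Int, b: Int) {
        self.b = b
        super.init(a: a)
    }

    override func isEqual(to other: Base) -> Bool {
        guard let other = other as? Derived else { return false }
        return super.isEqual(to: other) && b == other.b
    }
    // No hash(into:) override: equal Derived values share `a`, so they hash alike.
}

let d1 = Derived(a: 1, b: 2)
let d2 = Derived(a: 1, b: 2)   // equal to d1
let d3 = Derived(a: 1, b: 3)   // same hash as d1, but not equal
let plain = Base(a: 1)         // same hash as d1, but a different type

var items = Set<Base>()
items.insert(d1)
items.insert(d2)
items.insert(d3)
items.insert(plain)
print(items.count)             // 3
```

Mixing b into the hash as well would also be valid here; the only hard requirement is that values comparing equal never produce different hashes.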

IIRC this is a pretty good article on the problem in general: https://www.artima.com/lejava/articles/equality.html

(caveats: it's about Java and I haven't re-read it to see how well it applies to Swift)


You can do something like the following (as @Ben_Cohen pointed out):

import Foundation

protocol AnyEquatable {
    func equals(rhs: AnyEquatable) -> Bool
    func canEqualReverseDispatch(lhs: AnyEquatable) -> Bool
}

/*final*/ func ==(lhs: AnyEquatable, rhs: AnyEquatable) -> Bool {
    return lhs.equals(rhs: rhs) // Fix the type of the LHS using dynamic dispatch.
}
/*final*/ func !=(lhs: AnyEquatable, rhs: AnyEquatable) -> Bool {
    return !lhs.equals(rhs: rhs) // Fix the type of the LHS using dynamic dispatch.
}

class Point2D: AnyEquatable {
    let x: Double
    let y: Double
    init(x: Double, y: Double) {
        self.x = x
        self.y = y
    }
    func equals(rhs: Point2D) -> Bool {
        return x == rhs.x && y == rhs.y
    }
    func equals(rhs: AnyEquatable) -> Bool {
        guard rhs.canEqualReverseDispatch(lhs: self), let r = rhs as? Point2D else { // Fix type of RHS via a failable cast.
            return false // or fatalError("Coding Error: unequatable types; lhs: \(self), rhs: \(rhs).")
        }
        return equals(rhs: r) // LHS and RHS both Point2Ds.
    }
    func canEqualReverseDispatch(lhs: AnyEquatable) -> Bool {
        return lhs is Point2D // By default derived types may be equal.
    }
}

let p20 = Point2D(x: 0, y: 0)
let p21 = Point2D(x: 1, y: 1)
p20 == p20 // T
p20 == p21 // F

// PointPolar can be added retrospectively (main point of technique!) and can be equal to a Point2D.
class PointPolar: Point2D {
    init(rho: Double, theta: Double) {
        super.init(x: rho * cos(theta), y: rho * sin(theta))
    }
}

let pp0 = PointPolar(rho: 0, theta: 0)
pp0 == p20 // T
p20 == pp0 // T
pp0 == p21 // F
p21 == pp0 // F

// Point3D can be added retrospectively (main point of technique!), but must be always unequal to a Point2D.
class Point3D: Point2D {
    let z: Double
    init(x: Double, y: Double, z: Double) {
        self.z = z
        super.init(x: x, y: y)
    }
    func equals(rhs: Point3D) -> Bool {
        return x == rhs.x && y == rhs.y && z == rhs.z
    }
    override func equals(rhs: AnyEquatable) -> Bool {
        guard rhs.canEqualReverseDispatch(lhs: self), let r = rhs as? Point3D else { // Fix type of RHS via a failable cast.
            return false // or fatalError("Coding Error: unequatable types; lhs: \(self), rhs: \(rhs).")
        }
        return equals(rhs: r) // LHS and RHS both Point3Ds.
    }
    override func canEqualReverseDispatch(lhs: AnyEquatable) -> Bool {
        return lhs is Point3D // Make Point3D unequal to Point2D.
    }
}

let p30 = Point3D(x: 0, y: 0, z: 0)
let p31 = Point3D(x: 1, y: 1, z: 1)
p30 == p30 // T
p30 == p31 // F

p20 == p30 // F
p30 == p20 // F
p21 == p30 // F
p30 == p21 // F
p20 == p31 // F
p31 == p20 // F
p21 == p31 // F
p31 == p21 // F

var result = ""
let ps: [AnyEquatable] = [p20, p21, pp0, p30, p31]
for po in ps {
    for pi in ps {
        result += ", \(po == pi)"
    }
}
result // TFTFF FTFFF TFTFF FFFTF FFFFT

Thanks for all the feedback! It's really interesting how all this is simple and complex at the same time.

Having tried it in various languages over the years, the recipe I follow goes like this:

  • Static == function on the base class that calls a normal instance method, usually called equals(_ other: BaseClass)
  • Each subclass implements equals in a sensible way per its own implementation but the rule is that the other object must have exactly the same type as self in order to compare equal.

So your class A would have

static func == (a: A, b: A) -> Bool { return a.equals(b) }

func equals(_ other: A) -> Bool
{
    guard type(of: other) == type(of: self) else { return false }
    return self.a == other.a
}

And B would have:

override func equals(_ other: A) -> Bool
{
    guard let otherAsB = other as? B else { return false }
    return self.a == otherAsB.a && self.b == otherAsB.b
}

This is the only way I have found to guarantee symmetry (i.e. a == b implies b == a) and transitivity (a == b && b == c implies a == c).
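
A self-contained sketch of this recipe that checks symmetry across the hierarchy. One assumption beyond the snippets above: B also repeats the exact-type check, so that a further subclass of B would not compare equal to a plain B.

```swift
class A: Equatable {
    let a: Int
    init(a: Int) { self.a = a }

    static func == (lhs: A, rhs: A) -> Bool { return lhs.equals(rhs) }

    func equals(_ other: A) -> Bool {
        // Exact-type check: an A never equals an instance of a subclass.
        guard type(of: other) == type(of: self) else { return false }
        return a == other.a
    }
}

class B: A {
    let b: Int
    init(a: Int, b: Int) {
        self.b = b
        super.init(a: a)
    }

    override func equals(_ other: A) -> Bool {
        // Assumption: repeat the exact-type check before casting.
        guard type(of: other) == type(of: self), let other = other as? B else {
            return false
        }
        return a == other.a && b == other.b
    }
}

let base = A(a: 1)
let derived = B(a: 1, b: 2)
let sameDerived = B(a: 1, b: 2)

let forward = (base == derived)         // false
let backward = (derived == base)        // false, so symmetry holds
let sameType = (derived == sameDerived) // true
print(forward, backward, sameType)
```

Both mixed comparisons give the same answer regardless of operand order, which is what the exact-type rule buys.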


As a minor improvement, we could move the type check type(of: self) == type(of: other) into the static == function.
This would allow us to override isEqual(to:) in the subclass and call super.isEqual(to:).
By doing this, we don't need to re-check all the properties of the superclass in the subclass's isEqual(to:) function.

The next problem with this approach, though, is that we now have an isEqual(to:) function that cannot be made private (otherwise the subclass would not be able to override it).
The function does not check the type itself and therefore may crash when called with a parameter that does not have the right type.

But since we, the programmers, are the only ones who should be calling the internal isEqual(to:) function, that feels like a smaller drawback than copy-pasting the equality checks, in my opinion.

Alternatively, we could define all super- and subclasses in the same file and make the isEqual(to:) function fileprivate. (see the link below)

Example: see below or here for the updated version using fileprivate and extensions.

class Vehicle: Equatable {
    let name: String
    
    init(name: String) {
        self.name = name
    }
    
    func isEqual(to other: Vehicle) -> Bool {
        return self.name == other.name
    }
    
    static func == (lhs: Vehicle, rhs: Vehicle) -> Bool {
        return type(of: lhs) == type(of: rhs) && lhs.isEqual(to: rhs) // If the types mismatch, lhs.isEqual(to:) will never be called.
    }
}

class Car: Vehicle {
    let seats: Int
    
    init(name: String, seats: Int) {
        self.seats = seats
        super.init(name: name)
    }
    
    // This function is only called after == has checked that self and other have the same type, so we can assume that other is a Car too.
    override func isEqual(to other: Vehicle) -> Bool {
        assert(type(of: other) == type(of: self), "`other` does not have the same type as `self`. The static function `==` should have prevented that.")
        let otherCar = other as! Car
        return super.isEqual(to: other) && self.seats == otherCar.seats
    }
}

class Bicycle: Vehicle {
    let color: String
    
    init(name: String, color: String) {
        self.color = color
        super.init(name: name)
    }
    
    // This function is only called after == has checked that self and other have the same type, so we can assume that other is a Bicycle too.
    override func isEqual(to other: Vehicle) -> Bool {
        assert(type(of: other) == type(of: self), "`other` does not have the same type as `self`. The static function `==` should have prevented that.")
        let otherBicycle = other as! Bicycle
        return super.isEqual(to: other) && self.color == otherBicycle.color
    }
}