Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integer midpoint #293

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions Sources/IntegerUtilities/Midpoint.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//===--- Midpoint.swift ---------------------------------------*- swift -*-===//
//
// This source file is part of the Swift Numerics open source project
//
// Copyright (c) 2024 Apple Inc. and the Swift Numerics project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
//
//===----------------------------------------------------------------------===//

/// The average of `a` and `b`, rounded to an integer according to `rule`.
///
/// Unlike commonly seen expressions such as `(a+b)/2` or `(a+b) >> 1` or
/// `a + (b-a)/2` (all of which may overflow), this function never overflows,
/// and the result is guaranteed to be representable in the result type.
///
/// The default rounding rule is `.down`, which matches the behavior of
/// `(a + b) >> 1` when that expression does not overflow. Rounding
/// `.towardZero` matches the behavior of `(a + b)/2` when that expression
/// does not overflow. All other rounding modes are supported.
///
/// Rounding `.down` is generally most efficient; if you do not have a
/// reason to chose a specific other rounding rule, you should use the
/// default.
@inlinable
public func midpoint<T: FixedWidthInteger>(
_ a: T,
_ b: T,
rounding rule: RoundingRule = .down
) -> T {
// Isolate bits in a + b with weight 2, and those with weight 1
let twos = a & b
let ones = a ^ b
let floor = twos &+ ones >> 1
let frac = ones & 1
switch rule {
case .toNearestOrDown:
fallthrough
case .down:
return floor
case .toNearestOrUp:
fallthrough
case .up:
return floor &+ frac
case .toNearestOrZero:
fallthrough
case .towardZero:
return floor &+ (floor < 0 ? frac : 0)
case .toNearestOrAway:
fallthrough
case .awayFromZero:
return floor &+ (floor >= 0 ? frac : 0)
case .toNearestOrEven:
return floor &+ (floor & frac)
case .toOdd:
return floor &+ (~floor & frac)
case .stochastically:
return floor &+ (Bool.random() ? frac : 0)
case .requireExact:
precondition(frac == 0)
return floor
}
}
2 changes: 1 addition & 1 deletion Sources/IntegerUtilities/RoundingRule.swift
Original file line number Diff line number Diff line change
Expand Up @@ -256,5 +256,5 @@ extension RoundingRule {
/// > Deprecated: Use `.toNearestOrAway` instead.
@inlinable
@available(*, deprecated, renamed: "toNearestOrAway")
static var toNearestOrAwayFromZero: Self { .toNearestOrAway }
public static var toNearestOrAwayFromZero: Self { .toNearestOrAway }
}
42 changes: 42 additions & 0 deletions Tests/IntegerUtilitiesTests/MidpointTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//===--- MidpointTests.swift ----------------------------------*- swift -*-===//
//
// This source file is part of the Swift Numerics open source project
//
// Copyright (c) 2024 Apple Inc. and the Swift Numerics project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

import IntegerUtilities
import XCTest

final class IntegerUtilitiesMidpointTests: XCTestCase {
func testMidpoint() {
for rule in [
RoundingRule.down,
.up,
.towardZero,
.awayFromZero,
.toNearestOrDown,
.toNearestOrUp,
.toNearestOrZero,
.toNearestOrAway,
.toNearestOrEven,
.toOdd
] {
for a in -128 ... 127 {
for b in -128 ... 127 {
let ref = (a + b).shifted(rightBy: 1, rounding: rule)
let tst = midpoint(Int8(a), Int8(b), rounding: rule)
if ref != tst {
print(rule, a, b, ref, tst, separator: "\t")
return
}
}
}
}
}
}