diff --git a/src/lib.zig b/src/lib.zig index ecfeade..969bab6 100644 --- a/src/lib.zig +++ b/src/lib.zig @@ -1,10 +1,26 @@ const std = @import("std"); +const builtin = @import("builtin"); const testing = std.testing; -export fn add(a: i32, b: i32) i32 { - return a + b; +pub fn OnceCell(comptime T: type) type { + if (builtin.single_threaded) { + return @import("./singlethread.zig").OnceCell(T); + } else { + return @import("./multithread.zig").OnceCell(T); + } } -test "basic add functionality" { - try testing.expect(add(3, 7) == 10); +// pub fn Lazy(comptime T: type, comptime f: fn () T) type { +// _ = f; +// return struct { +// cell: OnceCell(T), +// }; +// } + +test { + // Import All files + _ = @import("./singlethread.zig"); + _ = @import("./multithread.zig"); + + std.testing.refAllDecls(@This()); } diff --git a/src/multithread.zig b/src/multithread.zig new file mode 100644 index 0000000..3280935 --- /dev/null +++ b/src/multithread.zig @@ -0,0 +1,252 @@ +const std = @import("std"); +const testing = std.testing; + +pub fn Lazy(comptime T: type, comptime f: fn () T) type { + return struct { + cell: OnceCell(T), + + const Self = @This(); + + pub fn init() Self { + return Self{ .cell = OnceCell(T).init() }; + } + + pub fn get(self: *Self) *T { + return self.cell.getOrInit(f); + } + + pub fn getConst(self: *Self) *const T { + return self.cell.getOrInit(f); + } + }; +} + +pub fn OnceCell(comptime T: type) type { + return struct { + cell: T = undefined, + mutex: std.Thread.Mutex = std.Thread.Mutex{}, + done: bool = false, + + const Self = @This(); + + pub fn init() Self { + return Self{}; + } + + pub fn get(self: *Self) ?*T { + if (self.isInitialize()) { + return &self.cell; + } + return null; + } + + pub fn getUnchecked(self: *Self) *T { + std.debug.assert(self.isInitialize()); + + return &self.cell; + } + + pub fn getOrInit(self: *Self, comptime f: fn () T) *T { + // Fast path check + if (self.get()) |value| { + return value; + } + + self.initialize(f); + std.debug.assert(self.isInitialize()); + + return self.getUnchecked(); + } + + // -------------------------------------------------------------------------------- + // Core API + // -------------------------------------------------------------------------------- + + fn isInitialize(self: Self) bool { + return @atomicLoad(bool, &self.done, .Acquire); + } + + fn initialize(self: *Self, comptime f: fn () T) void { + @setCold(true); + + self.mutex.lock(); + defer self.mutex.unlock(); + + // The first thread to acquire the mutex gets to run the initializer + + if (!self.done) { + self.cell = f(); + defer @atomicStore(bool, &self.done, true, .Release); + } + } + }; +} + +// -------------------------------------------------------------------------------- +// Testing +// -------------------------------------------------------------------------------- + +const allocator = testing.allocator; +fn return_1() i32 { + return 1; +} + +fn return_2() i32 { + return 2; +} + +fn returnMap() std.StringHashMap(i32) { + var map = std.StringHashMap(i32).init(allocator); + map.put("b", 2) catch @panic("unable to put b"); + return map; +} + +var globalMap = OnceCell(std.StringHashMap(i32)).init(); + +test "test global map" { + _ = globalMap.getOrInit(returnMap); + var r1 = globalMap.get().?; + try r1.*.put("a", 1); + + try testing.expect(r1.*.get("b") != null); + try testing.expect(r1.*.get("b").? == 2); + + // must be same hashmap + _ = globalMap.getOrInit(returnMap); + var r2 = globalMap.get().?; + + try testing.expect(r2.*.get("a") != null); + try testing.expect(r2.*.get("a").? == 1); + + defer r2.*.deinit(); +} + +test "test assume init" { + var cell1 = OnceCell(i32).init(); + const r1 = cell1.getOrInit(return_1); + const r2 = cell1.getUnchecked(); + + try testing.expect(r1.* == 1); + try testing.expect(r2.* == 1); +} + +test "test cell multi init" { + var cell1 = OnceCell(i32).init(); + var cell2 = OnceCell(i32).init(); + + const r1 = cell1.getOrInit(return_1); + const r2 = cell1.getOrInit(return_1); + const r3 = cell1.getOrInit(return_1); + + try testing.expect(r1.* == 1); + try testing.expectEqual(r1, r2); + try testing.expectEqual(r2, r3); + + const a1 = cell2.getOrInit(return_2); + const a2 = cell2.getOrInit(return_2); + const a3 = cell2.getOrInit(return_2); + + try testing.expect(a1.* == 2); + try testing.expectEqual(a1, a2); + try testing.expectEqual(a2, a3); +} + +var shared: i32 = 0; + +fn incrShared() i32 { + shared += 1; + return shared; +} + +var cell3 = OnceCell(i32).init(); + +test "test multithread shared value" { + var threads: [10]std.Thread = undefined; + defer for (threads) |handle| handle.join(); + + for (&threads) |*handle| { + handle.* = try std.Thread.spawn(.{}, struct { + fn thread_fn(x: u8) void { + _ = x; + _ = cell3.getOrInit(incrShared); + } + }.thread_fn, .{0}); + } + + try testing.expectEqual(@as(i32, 1), shared); +} + +var LazyMap = Lazy(std.StringHashMap(i32), returnMap).init(); + +test "test lazy" { + var map = LazyMap.get(); + defer map.*.deinit(); + try map.*.put("c", 3); + + var map2 = LazyMap.get(); + + try testing.expect(map2.*.get("c") != null); + try testing.expect(map2.*.get("c").? == 3); +} + +const MutexStringHashMap = struct { + mutex: std.Thread.Mutex, + str_map: std.StringHashMap(i32), + + const Self = @This(); + pub fn init() Self { + return Self{ + .mutex = std.Thread.Mutex{}, + .str_map = std.StringHashMap(i32).init(allocator), + }; + } + + pub fn deinit(self: *Self) void { + self.str_map.deinit(); + } + + pub fn borrow(self: *Self) *std.StringHashMap(i32) { + self.mutex.lock(); + return &self.str_map; + } + + pub fn restore(self: *Self) void { + self.mutex.unlock(); + } +}; + +fn returnMutexMap() MutexStringHashMap { + return MutexStringHashMap.init(); +} + +var LazyMutexMap = Lazy(MutexStringHashMap, returnMutexMap).init(); + +test "test lazy mutex map" { + var obj = LazyMutexMap.get(); + defer obj.*.deinit(); + + var map = obj.*.borrow(); + try map.*.put("v", 0); + obj.*.restore(); + + var threads: [10]std.Thread = undefined; + + for (&threads) |*handle| { + handle.* = try std.Thread.spawn(.{}, struct { + fn thread_fn(x: u8) !void { + _ = x; + var o = LazyMutexMap.get(); + var m = o.*.borrow(); + + const v = m.get("v").?; + try m.put("v", v + 1); + o.*.restore(); + } + }.thread_fn, .{0}); + } + + for (threads) |handle| handle.join(); + + const map2 = obj.*.borrow(); + try testing.expectEqual(@as(i32, 10), map2.*.get("v").?); +} diff --git a/src/singlethread.zig b/src/singlethread.zig new file mode 100644 index 0000000..2738d16 --- /dev/null +++ b/src/singlethread.zig @@ -0,0 +1,4 @@ +pub fn OnceCell(comptime T: type) type { + _ = T; + return struct {}; +}