Skip to content

Commit

Permalink
try impl
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanaasagi committed Jun 15, 2023
1 parent 2e0b548 commit 3013b47
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 4 deletions.
24 changes: 20 additions & 4 deletions src/lib.zig
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
const std = @import("std");
const builtin = @import("builtin");
const testing = std.testing;

export fn add(a: i32, b: i32) i32 {
return a + b;
pub fn OnceCell(comptime T: type) type {
if (builtin.single_threaded) {
return @import("./singlethread.zig").OnceCell(T);
} else {
return @import("./multithread.zig").OnceCell(T);
}
}

test "basic add functionality" {
try testing.expect(add(3, 7) == 10);
// pub fn Lazy(comptime T: type, comptime f: fn () T) type {
// _ = f;
// return struct {
// cell: OnceCell(T),
// };
// }

test {
// Import All files
_ = @import("./singlethread.zig");
_ = @import("./multithread.zig");

std.testing.refAllDecls(@This());
}
252 changes: 252 additions & 0 deletions src/multithread.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
const std = @import("std");
const testing = std.testing;

pub fn Lazy(comptime T: type, comptime f: fn () T) type {
return struct {
cell: OnceCell(T),

const Self = @This();

pub fn init() Self {
return Self{ .cell = OnceCell(T).init() };
}

pub fn get(self: *Self) *T {
return self.cell.getOrInit(f);
}

pub fn getConst(self: *Self) *const T {
return self.cell.getOrInit(f);
}
};
}

pub fn OnceCell(comptime T: type) type {
return struct {
cell: T = undefined,
mutex: std.Thread.Mutex = std.Thread.Mutex{},
done: bool = false,

const Self = @This();

pub fn init() Self {
return Self{};
}

pub fn get(self: *Self) ?*T {
if (self.isInitialize()) {
return &self.cell;
}
return null;
}

pub fn getUnchecked(self: *Self) *T {
std.debug.assert(self.isInitialize());

return &self.cell;
}

pub fn getOrInit(self: *Self, comptime f: fn () T) *T {
// Fast path check
if (self.get()) |value| {
return value;
}

self.initialize(f);
std.debug.assert(self.isInitialize());

return self.getUnchecked();
}

// --------------------------------------------------------------------------------
// Core API
// --------------------------------------------------------------------------------

fn isInitialize(self: Self) bool {
return @atomicLoad(bool, &self.done, .Acquire);
}

fn initialize(self: *Self, comptime f: fn () T) void {
@setCold(true);

self.mutex.lock();
defer self.mutex.unlock();

// The first thread to acquire the mutex gets to run the initializer

if (!self.done) {
self.cell = f();
defer @atomicStore(bool, &self.done, true, .Release);
}
}
};
}

// --------------------------------------------------------------------------------
// Testing
// --------------------------------------------------------------------------------

const allocator = testing.allocator;
fn return_1() i32 {
return 1;
}

fn return_2() i32 {
return 2;
}

fn returnMap() std.StringHashMap(i32) {
var map = std.StringHashMap(i32).init(allocator);
map.put("b", 2) catch @panic("unable to put b");
return map;
}

var globalMap = OnceCell(std.StringHashMap(i32)).init();

test "test global map" {
_ = globalMap.getOrInit(returnMap);
var r1 = globalMap.get().?;
try r1.*.put("a", 1);

try testing.expect(r1.*.get("b") != null);
try testing.expect(r1.*.get("b").? == 2);

// must be same hashmap
_ = globalMap.getOrInit(returnMap);
var r2 = globalMap.get().?;

try testing.expect(r2.*.get("a") != null);
try testing.expect(r2.*.get("a").? == 1);

defer r2.*.deinit();
}

test "test assume init" {
var cell1 = OnceCell(i32).init();
const r1 = cell1.getOrInit(return_1);
const r2 = cell1.getUnchecked();

try testing.expect(r1.* == 1);
try testing.expect(r2.* == 1);
}

test "test cell multi init" {
var cell1 = OnceCell(i32).init();
var cell2 = OnceCell(i32).init();

const r1 = cell1.getOrInit(return_1);
const r2 = cell1.getOrInit(return_1);
const r3 = cell1.getOrInit(return_1);

try testing.expect(r1.* == 1);
try testing.expectEqual(r1, r2);
try testing.expectEqual(r2, r3);

const a1 = cell2.getOrInit(return_2);
const a2 = cell2.getOrInit(return_2);
const a3 = cell2.getOrInit(return_2);

try testing.expect(a1.* == 2);
try testing.expectEqual(a1, a2);
try testing.expectEqual(a2, a3);
}

var shared: i32 = 0;

fn incrShared() i32 {
shared += 1;
return shared;
}

var cell3 = OnceCell(i32).init();

test "test multithread shared value" {
var threads: [10]std.Thread = undefined;
defer for (threads) |handle| handle.join();

for (&threads) |*handle| {
handle.* = try std.Thread.spawn(.{}, struct {
fn thread_fn(x: u8) void {
_ = x;
_ = cell3.getOrInit(incrShared);
}
}.thread_fn, .{0});
}

try testing.expectEqual(@as(i32, 1), shared);
}

var LazyMap = Lazy(std.StringHashMap(i32), returnMap).init();

test "test lazy" {
var map = LazyMap.get();
defer map.*.deinit();
try map.*.put("c", 3);

var map2 = LazyMap.get();

try testing.expect(map2.*.get("c") != null);
try testing.expect(map2.*.get("c").? == 3);
}

const MutexStringHashMap = struct {
mutex: std.Thread.Mutex,
str_map: std.StringHashMap(i32),

const Self = @This();
pub fn init() Self {
return Self{
.mutex = std.Thread.Mutex{},
.str_map = std.StringHashMap(i32).init(allocator),
};
}

pub fn deinit(self: *Self) void {
self.str_map.deinit();
}

pub fn borrow(self: *Self) *std.StringHashMap(i32) {
self.mutex.lock();
return &self.str_map;
}

pub fn restore(self: *Self) void {
self.mutex.unlock();
}
};

fn returnMutexMap() MutexStringHashMap {
return MutexStringHashMap.init();
}

var LazyMutexMap = Lazy(MutexStringHashMap, returnMutexMap).init();

test "test lazy mutex map" {
var obj = LazyMutexMap.get();
defer obj.*.deinit();

var map = obj.*.borrow();
try map.*.put("v", 0);
obj.*.restore();

var threads: [10]std.Thread = undefined;

for (&threads) |*handle| {
handle.* = try std.Thread.spawn(.{}, struct {
fn thread_fn(x: u8) !void {
_ = x;
var o = LazyMutexMap.get();
var m = o.*.borrow();

const v = m.get("v").?;
try m.put("v", v + 1);
o.*.restore();
}
}.thread_fn, .{0});
}

for (threads) |handle| handle.join();

const map2 = obj.*.borrow();
try testing.expectEqual(@as(i32, 10), map2.*.get("v").?);
}
4 changes: 4 additions & 0 deletions src/singlethread.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pub fn OnceCell(comptime T: type) type {
_ = T;
return struct {};
}

0 comments on commit 3013b47

Please sign in to comment.