Skip to content

Commit

Permalink
Merge pull request #174 from leondavi/leondavi/tests
Browse files Browse the repository at this point in the history
[nerlNIF] nerltensor_conversion fix an issue
  • Loading branch information
leondavi committed May 12, 2023
2 parents cba66f2 + eec83d1 commit 3a40718
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 22 deletions.
40 changes: 29 additions & 11 deletions src_erl/erlBridge/nerlNIF.erl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
-module(nerlNIF).
-include_lib("kernel/include/logger.hrl").
-include("nerlTensor.hrl").

-import(nerl,[tic/0, toc/1]).

-export([init/0,create_nif/6,train_nif/5,call_to_train/6,predict_nif/2,call_to_predict/5,get_weights_nif/1,printTensor/2]).
-export([call_to_get_weights/1,call_to_set_weights/2]).
-export([decode_nif/2, nerltensor_binary_decode/2]).
-export([encode_nif/2, nerltensor_encode/5, nerltensor_conversion/2, get_all_binary_types/0]).
-export([encode_nif/2, nerltensor_encode/5, nerltensor_conversion/2, get_all_binary_types/0, get_all_nerltensor_list_types/0]).

-define(FILE_IDENTIFIER,"[NERLNIF] ").
-define(NERLNET_LIB,"libnerlnet").
Expand All @@ -22,7 +23,6 @@

%nerltensor
-define(NUMOF_DIMS,3).
-include("nerlTensor.hrl").

-export([nerltensor_sum_nif/3]).
-export([nerltensor_sum_erl/2]).
Expand Down Expand Up @@ -124,21 +124,39 @@ nerltensor_binary_decode(Binary, Type) when erlang:is_binary(Binary) and erlang:

% return the merged list of all supported binary types
get_all_binary_types() -> ?LIST_BINARY_FLOAT_NERLTENSOR_TYPE ++ ?LIST_BINARY_INT_NERLTENSOR_TYPE.

get_all_nerltensor_list_types() -> ?LIST_GROUP_NERLTENSOR_TYPE.
% nerltensor_conversion:
% Type is Binary then: Binary (Compressed Form) --> Erlang List
% Type is list then: Erlang List --> Binary
nerltensor_conversion({NerlTensor, Type}, ResType) ->
BinaryGroup = lists:member(Type, get_all_binary_types()), % compressed type
ListGroup = lists:member(Type, ?LIST_GROUP_NERLTENSOR_TYPE), % non compressed, list type
case ResType of
ResType when BinaryGroup -> decode_nif(NerlTensor,Type); % returns {Binary, Type}
ResType when ListGroup -> encode_nif(NerlTensor,Type); % returns {Binary, Type}
_ERROR -> error % TODO add log here
nerltensor_conversion({NerlTensor, Type}, ResType) ->
TypeListGroup = lists:member(Type, get_all_nerltensor_list_types()),
ResTypeListGroup = lists:member(ResType, get_all_nerltensor_list_types()),

{Operation, ErlType, BinType} =
case {TypeListGroup, ResTypeListGroup} of
{true, false} -> {encode, Type, ResType};
{false, true} -> {decode, ResType, Type};
_ -> throw("invalid types combination")
end,

BinTypeInteger = lists:member(BinType, ?LIST_BINARY_INT_NERLTENSOR_TYPE),
BinTypeFloat = lists:member(BinType, ?LIST_BINARY_FLOAT_NERLTENSOR_TYPE),

% Wrong combination guard
case ErlType of
erl_float when BinTypeFloat-> ok;
erl_int when BinTypeInteger -> ok;
_ -> throw("invalid types combination")
end,

case Operation of
encode -> encode_nif(NerlTensor, BinType);
decode -> decode_nif(NerlTensor, BinType);
_ -> throw("wrong operation")
end.

nerltensor_sum_erl({NerlTensorErlA, Type}, {NerlTensorErlB, Type}) ->
ListGroup = lists:member(Type, ?LIST_GROUP_NERLTENSOR_TYPE),
ListGroup = lists:member(Type, get_all_nerltensor_list_types()),
if ListGroup ->
DIMS = lists:sublist(NerlTensorErlA, 1, ?NUMOF_DIMS),
NerlTensorErlA_NODIMS = lists:sublist(NerlTensorErlA, ?NUMOF_DIMS + 1, length(NerlTensorErlA) - ?NUMOF_DIMS),
Expand Down
52 changes: 41 additions & 11 deletions src_erl/erlBridge/nerlTests.erl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
-import(nerlNIF,[init/0,create_nif/6,train_nif/5,call_to_train/6,predict_nif/2,call_to_predict/5,get_weights_nif/1,printTensor/2]).
-import(nerlNIF,[call_to_get_weights/1,call_to_set_weights/2]).
-import(nerlNIF,[decode_nif/2, nerltensor_binary_decode/2]).
-import(nerlNIF,[encode_nif/2, nerltensor_encode/5, nerltensor_conversion/2, get_all_binary_types/0]).
-import(nerlNIF,[encode_nif/2, nerltensor_encode/5, nerltensor_conversion/2, get_all_binary_types/0, get_all_nerltensor_list_types/0]).
-import(nerlNIF,[nerltensor_sum_nif/3]).
-import(nerl,[compare_floats_L/3, string_format/2, logger_settings/1]).

Expand All @@ -18,18 +18,19 @@ nerltest_print(String) ->
logger:notice(?NERLTEST_PRINT_STR++String).

% encode_decode test macros
-define(ENCODE_DECODE_ROUNDS, 100).
-define(DIMX_RAND_MAX, 200).
-define(DIMY_RAND_MAX, 200).
-define(DIMX_RAND_MAX, 2).
-define(DIMY_RAND_MAX, 2).
-define(SUM_NIF_ROUNDS, 100).
-define(ENCODE_DECODE_ROUNDS, 100).
-define(NERLTENSOR_CONVERSION_ROUNDS, 50).

run_tests()->
nerl:logger_settings(nerlTests),
DimXStr = integer_to_list(?DIMX_RAND_MAX),
DimYStr = integer_to_list(?DIMY_RAND_MAX),
nerltest_print("encode decode test starts "++integer_to_list(?ENCODE_DECODE_ROUNDS)++" tests up to ("++DimXStr++","++DimYStr++")"),
Tic_niftest_encode_decode = nerl:tic(),
niftest_encode_decode(?ENCODE_DECODE_ROUNDS,[]), % throws if a test fails
encode_decode_nifs_test(?ENCODE_DECODE_ROUNDS,[]), % throws if a test fails
{TDiff_niftest_encode_decode, TimeUnit} = nerl:toc(Tic_niftest_encode_decode),
nerltest_print(nerl:string_format("Elapsed: ~p~p",[TDiff_niftest_encode_decode,TimeUnit])),

Expand All @@ -52,6 +53,10 @@ run_tests()->
{TDiff_nerltensor_sum_nif_test_double, _} = nerl:toc(Tic_nerltensor_sum_nif_test_double),
nerltest_print(nerl:string_format("Elapsed: ~p~p, Avg nif operations: ~.4f~p",[TDiff_nerltensor_sum_nif_test_double,TimeUnit,PerformanceSumNifDouble,TimeUnit])),

nerltest_print("nerltensor_conversion_test starts "++integer_to_list(?NERLTENSOR_CONVERSION_ROUNDS)++" tests"),
nerltensor_conversion_test(?NERLTENSOR_CONVERSION_ROUNDS),

nerltest_print("Tests Completed"),
ok.

random_pick_nerltensor_type()->
Expand Down Expand Up @@ -97,18 +102,18 @@ nerltensor_sum_nif_test(Type, N, Performance) ->
{TocRes, _} = nerl:toc(Tic),
PerformanceNew = TocRes + Performance,
% io:format("ResultTensorCEnc ~p Type ~p~n",[ResultTensorCEnc, Type]),

{ResultTensorCEncDec, erl_float} = nerlNIF:nerltensor_conversion({ResultTensorCEnc, Type}, erl_float),
CompareFloats = nerl:compare_floats_L(ResultTensorCEncDec, ExpectedResult, 4), % Erlang accuracy is double
% io:format("ResultTensorCEncDec ~p~n",[ResultTensorCEncDec]),

% io:format("ResultTensorCEncDec ~p~n",[ResultTensorCEncDec]),
if
CompareFloats -> nerltensor_sum_nif_test(Type, N-1, PerformanceNew);
true -> throw(ner:string_format("test failed - not equal ~n ExpectedResult: ~p ~n ResultTensorCEncDec: ~p",[ExpectedResult, ResultTensorCEncDec]))
end.


niftest_encode_decode(0, _Res) -> ok ;
niftest_encode_decode(N, Res) ->
encode_decode_nifs_test(0, _Res) -> ok ;
encode_decode_nifs_test(N, Res) ->
EncodeType = random_pick_nerltensor_type(),
NerlTensor = generate_nerltensor_rand_dims(EncodeType),
{EncodedNerlTensor, NerlTensorType} = nerlNIF:encode_nif(NerlTensor, EncodeType),
Expand All @@ -119,8 +124,33 @@ niftest_encode_decode(N, Res) ->
FloatCase = EncodeType == float,
CompareFloats = nerl:compare_floats_L(NerlTensor, DecodedTensor, 6),
if
FloatCase and CompareFloats-> niftest_encode_decode(N-1, Res ++ []);
NerlTensor == DecodedTensor -> niftest_encode_decode(N-1, Res ++ []);
FloatCase and CompareFloats-> encode_decode_nifs_test(N-1, Res ++ []);
NerlTensor == DecodedTensor -> encode_decode_nifs_test(N-1, Res ++ []);
true -> throw(ner:string_format("test failed - not equal ~n Origin: ~p ~n EncDec: ~p",[{NerlTensor, EncodeType},{DecodedTensor, DecodedType}]))
end.

nerltensor_conversion_test(0) -> ok;
nerltensor_conversion_test(Rounds) ->
BinType = random_pick_nerltensor_type(),
BinFloatType = lists:member(BinType, ?LIST_BINARY_FLOAT_NERLTENSOR_TYPE),
RandomIndex = rand:uniform(length(nerlNIF:get_all_nerltensor_list_types())),
ErlType = lists:nth(RandomIndex, nerlNIF:get_all_nerltensor_list_types()),
NerlTensorErl = generate_nerltensor_rand_dims(BinType),
try
{NerlTensorEnc, _} = nerlNIF:nerltensor_conversion({NerlTensorErl,ErlType},BinType),
{NerlTensorEncDecErl, _ } = nerlNIF:nerltensor_conversion({NerlTensorEnc,BinType},ErlType),
CompareFloats = nerl:compare_floats_L(NerlTensorErl, NerlTensorEncDecErl, 6),
% io:format("test failed - not equal ~n Origin: ~p ~n EncDec: ~p",[NerlTensorErl,NerlTensorEncDecErl]),
if
BinFloatType and CompareFloats-> nerltensor_conversion_test(Rounds - 1);
NerlTensorErl == NerlTensorEncDecErl -> nerltensor_conversion_test(Rounds - 1);
true -> throw(nerl:string_format("test failed - not equal ~n Origin: ~p ~n EncDec: ~p",[NerlTensorErl,NerlTensorEncDecErl]))
end
catch
throw:Reason -> BinTypeInt = lists:member(BinType, ?LIST_BINARY_INT_NERLTENSOR_TYPE),
case {ErlType,BinTypeInt,BinFloatType} of
{erl_int, false, true} -> nerltensor_conversion_test(Rounds - 1); % continues normal
{erl_float, true, false} -> nerltensor_conversion_test(Rounds - 1); % continues normal
_ -> throw(nerl:string_format("unknown nerltensor conversion exception Reason: ~p",[Reason]))
end
end.

0 comments on commit 3a40718

Please sign in to comment.