
Commit

Fixed division by zero in QR decomposition. Issue #1058 (#1473)
tensorflow/tfjs#1058
The sign() function returns 0 for an input of 0, which causes a division by zero in the QR decomposition function qr() when there is a zero on the diagonal.

BUG
jarno-r authored and dsmilkov committed Aug 9, 2019
1 parent b484b28 commit 2ae5a8a
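
To make the failure mode concrete, here is a minimal sketch (not part of the commit) of how sign(0) propagates through the Householder step in qr2d(). The variable names mirror those in linalg_ops.ts below; the numeric values are made up for illustration.

import * as tf from '@tensorflow/tfjs';

// Assume the pivot r[j][j] is exactly 0 and the column below it has norm 2.
const rjj = tf.tensor2d([[0]]);
const normX = tf.tensor2d([[2]]);
const s = rjj.sign().neg();        // sign(0) === 0, so s === 0 instead of ±1
const u1 = rjj.sub(s.mul(normX));  // 0 - 0 * 2 === 0
const wPre = tf.tensor2d([[1], [2]]).div(u1);  // divides by zero
wPre.print();                      // entries are Infinity, which then poisons Q and R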
Showing 5 changed files with 49 additions and 30 deletions.
9 changes: 4 additions & 5 deletions src/io/passthrough_test.ts
@@ -115,8 +115,8 @@ describeWithFlags('Passthrough Saver', BROWSER_ENVS, () => {

describeWithFlags('Passthrough Loader', BROWSER_ENVS, () => {
it('load topology and weights: legacy signature', async () => {
const passthroughHandler = tf.io.fromMemory(
modelTopology1, weightSpecs1, weightData1);
const passthroughHandler =
tf.io.fromMemory(modelTopology1, weightSpecs1, weightData1);
const modelArtifacts = await passthroughHandler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightSpecs1);
@@ -147,9 +147,8 @@ describeWithFlags('Passthrough Loader', BROWSER_ENVS, () => {
});

it('load model topology only', async () => {
const passthroughHandler = tf.io.fromMemory({
modelTopology: modelTopology1
});
const passthroughHandler =
tf.io.fromMemory({modelTopology: modelTopology1});
const modelArtifacts = await passthroughHandler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(undefined);
6 changes: 3 additions & 3 deletions src/jasmine_util.ts
@@ -40,8 +40,8 @@ export const SYNC_BACKEND_ENVS: Constraints = {
};

export const HAS_WORKER = {
predicate: () => typeof(Worker) !== 'undefined'
&& typeof(Blob) !== 'undefined' && typeof(URL) !== 'undefined'
predicate: () => typeof (Worker) !== 'undefined' &&
typeof (Blob) !== 'undefined' && typeof (URL) !== 'undefined'
};

export const HAS_NODE_WORKER = {
@@ -52,7 +52,7 @@ export const HAS_NODE_WORKER = {
} catch {
hasWorker = false;
}
return typeof(process) !== 'undefined' && hasWorker;
return typeof (process) !== 'undefined' && hasWorker;
}
};

48 changes: 27 additions & 21 deletions src/ops/concat_test.ts
@@ -88,7 +88,7 @@ describeWithFlags('concat1d', ALL_ENVS, () => {
expectArraysClose(await result.data(), expected);
});

it('concat complex input', async() => {
it('concat complex input', async () => {
// [1+1j, 2+2j]
const c1 = tf.complex([1, 2], [1, 2]);
// [3+3j, 4+4j]
@@ -234,7 +234,7 @@ describeWithFlags('concat2d', ALL_ENVS, () => {
expectArraysEqual(await res2.data(), []);
});

it('concat complex input axis=0', async() => {
it('concat complex input axis=0', async () => {
// [[1+1j, 2+2j], [3+3j, 4+4j]]
const c1 = tf.complex([[1, 2], [3, 4]], [[1, 2], [3, 4]]);
// [[5+5j, 6+6j], [7+7j, 8+8j]]
@@ -247,7 +247,7 @@ describeWithFlags('concat2d', ALL_ENVS, () => {
expectArraysClose(await result.data(), expected);
});

it('concat complex input axis=1', async() => {
it('concat complex input axis=1', async () => {
// [[1+1j, 2+2j], [3+3j, 4+4j]]
const c1 = tf.complex([[1, 2], [3, 4]], [[1, 2], [3, 4]]);
// [[5+5j, 6+6j], [7+7j, 8+8j]]
@@ -500,50 +500,56 @@ describeWithFlags('concat3d', ALL_ENVS, () => {
expectArraysClose(await values.data(), [1, 2, 3, 4, 5, 6]);
});

it('concat complex input axis=0', async() => {
it('concat complex input axis=0', async () => {
// [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
const c1 = tf.complex(
[[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
const c1 =
tf.complex([[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
// [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
const c2 = tf.complex(
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);

const axis = 0;
const result = tf.concat([c1, c2], axis);
const expected = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12];
const expected = [
1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12
];
expect(result.dtype).toEqual('complex64');
expectArraysClose(await result.data(), expected);
});

it('concat complex input axis=1', async() => {
it('concat complex input axis=1', async () => {
// [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
const c1 = tf.complex(
[[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
const c1 =
tf.complex([[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
// [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
const c2 = tf.complex(
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);

const axis = 1;
const result = tf.concat([c1, c2], axis);
const expected = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12];
const expected = [
1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12
];
expect(result.dtype).toEqual('complex64');
expectArraysClose(await result.data(), expected);
});

it('concat complex input axis=1', async() => {
it('concat complex input axis=1', async () => {
// [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
const c1 = tf.complex(
[[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
const c1 =
tf.complex([[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
// [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
const c2 = tf.complex(
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);

const axis = 2;
const result = tf.concat([c1, c2], axis);
const expected = [1, 1, 2, 2, 7, 7, 8, 8, 3, 3, 4, 4,
9, 9, 10, 10, 5, 5, 6, 6, 11, 11, 12, 12];
const expected = [
1, 1, 2, 2, 7, 7, 8, 8, 3, 3, 4, 4,
9, 9, 10, 10, 5, 5, 6, 6, 11, 11, 12, 12
];
expect(result.dtype).toEqual('complex64');
expectArraysClose(await result.data(), expected);
});
5 changes: 4 additions & 1 deletion src/ops/linalg_ops.ts
@@ -215,7 +215,10 @@ function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] {
const rjEnd1 = r.slice([j, j], [m - j, 1]);
const normX = rjEnd1.norm();
const rjj = r.slice([j, j], [1, 1]);
const s = rjj.sign().neg() as Tensor2D;

// The sign() function returns 0 on 0, which causes division by zero.
const s = tensor2d([[-1]]).where(rjj.greater(0), tensor2d([[1]]));

const u1 = rjj.sub(s.mul(normX)) as Tensor2D;
const wPre = rjEnd1.div(u1);
if (wPre.shape[0] === 1) {
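
As a rough sanity check of the new sign selection (a sketch with made-up values, not code from the commit): the where() expression always evaluates to -1 or +1, so u1 stays nonzero whenever the column norm is nonzero, even for a zero pivot.

import * as tf from '@tensorflow/tfjs';

const rjj = tf.tensor2d([[0]]);    // the problematic zero pivot
const normX = tf.tensor2d([[2]]);  // assumed nonzero column norm
// -1 where rjj > 0, +1 everywhere else (including rjj === 0); never 0.
const s = tf.tensor2d([[-1]]).where(rjj.greater(0), tf.tensor2d([[1]]));
const u1 = rjj.sub(s.mul(normX));  // 0 - (+1) * 2 === -2, safely nonzero
u1.print();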
11 changes: 11 additions & 0 deletions src/ops/linalg_ops_test.ts
@@ -140,6 +140,17 @@ describeWithFlags('qr', ALL_ENVS, () => {
[[-8.3066, 8.3066, -2.4077], [0, 4.5826, -2.1822], [0, 0, 7.6447]]);
});

it('3x3, zero on diagonal', async () => {
const x = tensor2d([[0, 2, 2], [1, 1, 1], [0, 1, 2]], [3, 3]);
const [q, r] = tf.linalg.qr(x);
expectArraysClose(await q.data(), [
[0., -0.89442719, 0.4472136], [1., 0., 0.], [0., -0.4472136, -0.89442719]
]);
expectArraysClose(
await r.data(),
[[1., 1., 1.], [0., -2.23606798, -2.68328157], [0., 0., -0.89442719]]);
});

it('3x2, fullMatrices = default false', async () => {
const x = tensor2d([[1, 2], [3, -3], [-2, 1]], [3, 2]);
const [q, r] = tf.linalg.qr(x);
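
The expected Q and R in the new test are hard-coded. An alternative way to check the same case, sketched here and not part of the commit, is to verify the defining properties of the decomposition: Q·R reconstructs x and Q is orthogonal. It reuses the tensor2d, tf, and expectArraysClose helpers already imported in linalg_ops_test.ts.

it('3x3, zero on diagonal: Q*R reconstructs x (illustrative)', async () => {
  const x = tensor2d([[0, 2, 2], [1, 1, 1], [0, 1, 2]], [3, 3]);
  const [q, r] = tf.linalg.qr(x);
  // Q * R should reproduce the input even with a zero pivot.
  expectArraysClose(await q.matMul(r).data(), await x.data());
  // Q should be orthogonal: Q^T * Q ≈ I.
  expectArraysClose(
      await q.transpose().matMul(q).data(), await tf.eye(3).data());
});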
