From bf8680938d25f05113cd5da3aafe38753e7aa12b Mon Sep 17 00:00:00 2001
From: Stanley Bileschi
Date: Tue, 28 Aug 2018 08:02:17 -0400
Subject: [PATCH 01/13] docfix -- qr to namespace linalg

---
 src/ops/linalg_ops.ts | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts
index 4ea6886e2a..2d8314b9ca 100644
--- a/src/ops/linalg_ops.ts
+++ b/src/ops/linalg_ops.ts
@@ -117,7 +117,11 @@ function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D {
  *   - `R` has a shape of `[..., M, N]`.
  * @throws If the rank of `x` is less than 2.
  */
-/** @doc {heading: 'Operations', subheading: 'Linear Algebra'} */
+/**
+ * @doc {heading:'Operations',
+ *       subheading:'Linear Algebra',
+ *       namespace:'linalg'}
+ */
 function qr_(x: Tensor, fullMatrices = false): [Tensor, Tensor] {
   if (x.rank < 2) {
     throw new Error(

From 2e862cc732dbd1758f47aa275ea7582be0131cfc Mon Sep 17 00:00:00 2001
From: Stanley Bileschi
Date: Tue, 28 Aug 2018 09:02:07 -0400
Subject: [PATCH 02/13] sample code fixup

---
 src/ops/linalg_ops.ts | 26 ++++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts
index 2d8314b9ca..fa7d9551f1 100644
--- a/src/ops/linalg_ops.ts
+++ b/src/ops/linalg_ops.ts
@@ -32,6 +32,15 @@ import {tensor2d} from './tensor_ops';
 /**
  * Gram-Schmidt orthogonalization.
  *
+ * ```js
+ * const x = tf.tensor2d([[1, 2], [3, 4]]);
+ * let y = tf.linalg.gramSchmidt(x);
+ * console.log('Orthogonalized:');
+ * y.dot(y.transpose()).print(); // should be nearly the identity matrix.
+ * console.log('First row direction maintained:');
+ * console.log(y.get(0, 1) / y.get(0, 0)); // should be nearly 2.
+ * ```
+ *
  * @param xs The vectors to be orthogonalized, in one of the two following
  *   formats:
  *   - An Array of `Tensor1D`.
@@ -44,7 +53,11 @@ import {tensor2d} from './tensor_ops';
  *   are orthogonal (zero inner products). Normalization means that each
  *   vector or each row of the matrix has an L2 norm that equals `1`.
  */
-/** @doc {heading: 'Operations', subheading: 'Linear Algebra'} */
+/**
+ * @doc {heading:'Operations',
+ *       subheading:'Linear Algebra',
+ *       namespace:'linalg'}
+ */
 function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D {
   let inputIsTensor2D: boolean;
   if (Array.isArray(xs)) {
@@ -98,12 +111,21 @@ function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D {
  * [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf]
  * (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf)
  *
+ * ```js
+ * const a = tf.tensor2d([1, 2], [3, 4]);
+ * let [q, r] = tf.linalg.qr(a);
+ * console.log('Orthogonalized');
+ * q.dot(q.transpose()).print(); // should be nearly the identity matrix.
+ * console.log('Reconstructed');
+ * q.dot(r).print(); // should be nearly [[1, 2], [3, 4]];
+ * ```
+ *
  * @param x The `Tensor` to be QR-decomposed. Must have rank >= 2. Suppose
  *   it has the shape `[..., M, N]`.
  * @param fullMatrices An optional boolean parameter. Defaults to `false`.
  *   If `true`, compute full-sized `Q`. If `false` (the default),
  *   compute only the leading N columns of `Q` and `R`.
- * @return An `Array` of two `Tensor`s: `[Q, R]`. `Q` is a unitary matrix,
+ * @returns An `Array` of two `Tensor`s: `[Q, R]`. `Q` is a unitary matrix,
  *   i.e., its columns all have unit norm and are mutually orthogonal.
 *   If `M >= N`,
 *   If `fullMatrices` is `false` (default),

From e92840d27f910a660391b52c6c9d6ac4fa734be8 Mon Sep 17 00:00:00 2001
From: Stanley Bileschi
Date: Tue, 28 Aug 2018 10:20:01 -0400
Subject: [PATCH 03/13] fixup add additional output (Q & R)

---
 src/ops/linalg_ops.ts | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts
index fa7d9551f1..8d631b6efd 100644
--- a/src/ops/linalg_ops.ts
+++ b/src/ops/linalg_ops.ts
@@ -35,6 +35,7 @@ import {tensor2d} from './tensor_ops';
  * ```js
  * const x = tf.tensor2d([[1, 2], [3, 4]]);
  * let y = tf.linalg.gramSchmidt(x);
+ * y.print();
  * console.log('Orthogonalized:');
  * y.dot(y.transpose()).print(); // should be nearly the identity matrix.
  * console.log('First row direction maintained:');
@@ -112,8 +113,12 @@ function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D {
  * (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf)
  *
  * ```js
- * const a = tf.tensor2d([1, 2], [3, 4]);
+ * const a = tf.tensor2d([[1, 2], [3, 4]]);
  * let [q, r] = tf.linalg.qr(a);
+ * console.log('Q');
+ * q.print();
+ * console.log('R');
+ * r.print();
  * console.log('Orthogonalized');
  * q.dot(q.transpose()).print(); // should be nearly the identity matrix.
  * console.log('Reconstructed');

From 72faf3f4090316c6785075370105123b533fefdd Mon Sep 17 00:00:00 2001
From: Dirk Toewe
Date: Wed, 31 Oct 2018 10:55:53 +0100
Subject: [PATCH 04/13] "Low-Level" QR Decomposition, bandPart, triangularSolve

The QR Decomposition implementation uses Givens Rotations for elimination.
There are a few reasons why the Givens Rotation was chosen over the
Householder method:

* During a quick trial, there seemed to be no practical performance
  difference in JS between Householder and Givens. JS itself seems to
  introduce enough overhead to make the few extra FLOPS irrelevant.
* Givens Method is easier to implement in a numerically stable way
  (e.g. there is no underflow-safe norm necessary).
* Givens Method is easier to backpropagate.
* Givens Method ensures det(Q) = 1.
* Givens Rotations seem to be smoother when it comes to perturbation,
  resulting in smoother gradients.
* Givens Method is easier to parallelize.

As long as R is non-singular there is always a way to produce a canonical
representation from a QR Decomposition (e.g. make the diagonal of R all
positive). That also means that there is no compatibility issue with
whichever QR implementation TensorFlow for Python/C/C++ chooses.

Both `bandPart` and `triangularSolve` were necessary to implement the
symbolic backpropagation of the QR Decomposition.
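
For intuition, the elementary step shared by all of the new kernels is a
single Givens rotation. A minimal illustrative sketch (not a helper that
exists in this patch; the kernels inline this computation):

```js
// Returns the [cos, sin] pair of the Givens rotation that maps the
// pair (rJJ, rIJ) to (norm, 0), i.e. that eliminates the subdiagonal
// entry rIJ against the diagonal entry rJJ. Math.hypot computes the
// norm in an overflow/underflow-safe way, which is why no separate
// safe-norm trick (as in the Householder method) is needed.
function givensCS(rJJ, rIJ) {
  if (rIJ === 0.0) { return [1.0, 0.0]; } // already zero: identity rotation
  const norm = Math.hypot(rJJ, rIJ);
  return [rJJ / norm, rIJ / norm];
}
```

Applying the rotation to rows j and i replaces R[j][k] by s*R[i][k] + c*R[j][k]
and R[i][k] by c*R[i][k] - s*R[j][k], exactly as done in the kernels below.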
---
 src/ops/linalg_ops.ts      | 870 ++++++++++++++++++++++++++++++++++-----
 src/ops/linalg_ops_test.ts | 519 +++++++++++++++++-----
 2 files changed, 1163 insertions(+), 226 deletions(-)

diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts
index 51daf3ffa6..0db914a6e7 100644
--- a/src/ops/linalg_ops.ts
+++ b/src/ops/linalg_ops.ts
@@ -20,15 +20,21 @@
  */
 
 import {ENV} from '../environment';
-import {dispose} from '../globals';
+import {range, scalar} from './tensor_ops';
 import {Tensor, Tensor1D, Tensor2D} from '../tensor';
+import {TensorLike, TypedArray} from '../types';
+import {add, mul, sub} from './binary_ops';
+import {logicalAnd} from './logical_ops';
+import {complex, real, imag} from './complex_ops';
 import {assert} from '../util';
-import {eye, squeeze, stack, unstack} from './array_ops';
+import {convertToTensor} from '../tensor_util_env';
+import {squeeze, stack} from './array_ops';
 import {split} from './concat_split';
+import {matMul} from './matmul';
 import {norm} from './norm';
 import {op} from './operation';
 import {sum} from './reduction_ops';
-import {tensor2d} from './tensor_ops';
+import {upcastType} from '../types';
 
 /**
  * Gram-Schmidt orthogonalization.
@@ -106,12 +112,694 @@ function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D {
   }
 }
 
+/**
+ * Conjugates a tensor of matrices and then transposes the last two dimensions.
+ * The adjoint is also commonly known as the Hermitian Transpose. Does not yet
+ * work for complex data types.
+ *
+ * @param a Tensor of shape [...,M,N]. The tensor of matrices that is to be
+ *          transposed.
+ *
+ * @returns Tensor of shape [...,N,M]. The transpose of `a`.
+ */
+/**
+ * @doc {heading:'Operations',
+ *       subheading:'Linear Algebra',
+ *       namespace:'linalg'}
+ */
+function adjoint_<T extends Tensor>( a: T|TensorLike ): T
+{
+  let $a = convertToTensor(a,'a','adjoint');
+
+  const axes = Array.from( $a.shape, (_,i) => i );
+  axes[axes.length-2] = axes.length-1;
+  axes[axes.length-1] = axes.length-2;
+
+  if( $a.dtype.startsWith('complex') ) {
+    $a = complex( real($a), imag($a).neg() ); // <- TODO: implement tf.conj
+  }
+
+  return $a.transpose(axes);
+}
+
+/**
+ * Copies a tensor of matrices, setting everything outside a central band
+ * in each matrix to zero.
+ */
+/**
+ * @doc {heading:'Operations',
+ *       subheading:'Linear Algebra',
+ *       namespace:'linalg'}
+ */
+function bandPart_<T extends Tensor>(
+  a: T|TensorLike, numLower: number, numUpper: number
+): T
+{
+  if( numLower%1 !== 0 ){
+    throw new Error(`bandPart(): numLower=${numLower} not an integer.`);
+  }
+  if( numUpper%1 !== 0 ){
+    throw new Error(`bandPart(): numUpper=${numUpper} not an integer.`);
+  }
+
+  const $a = convertToTensor(a,'a','bandPart');
+
+  const [M,N] = $a.shape.slice(-2);
+
+  if( !(numLower <= M) ) {
+    throw new Error(`bandPart() check failed: numLower <= #rows.` );
+  }
+  if( !(numUpper <= N) ) {
+    throw new Error(`bandPart() check failed: numUpper <= #columns.`);
+  }
+
+  if( numLower < 0 ) { numLower = M; }
+  if( numUpper < 0 ) { numUpper = N; }
+
+  const i = range(0,M, 1, 'int32').reshape([-1,1]),
+        j = range(0,N, 1, 'int32');
+
+  const inBand = logicalAnd(
+    sub(i,j).lessEqual( scalar(numLower,'int32') ),
+    sub(j,i).lessEqual( scalar(numUpper,'int32') )
+  ).cast($a.dtype);
+
+  return mul($a,inBand);
+}
+
+function triangularSolveKernel(
+  l: Tensor, y: Tensor, lower: boolean, adjoint: boolean
+): Tensor
+{
+  if( ! l.dtype.startsWith('float') ) {
+    throw new Error(`triangularSolve(): l.dtype=${l.dtype} not supported.`);
+  }
+  if( !
y.dtype.startsWith('float') ) {
+    throw new Error(`triangularSolve(): y.dtype=${y.dtype} not supported.`);
+  }
+  if( l.rank < 2 ) {
+    throw new Error('triangularSolve(): l must be at least 2D.');
+  }
+  if( y.rank < 2 ) {
+    throw new Error('triangularSolve(): y must be at least 2D.');
+  }
+  if( l.rank !== y.rank ) {
+    throw new Error('triangularSolve(): l and y must have same rank.');
+  }
+  for( let i=l.rank-2; i-- > 0; ) {
+    if( l.shape[i] !== y.shape[i] ) {
+      throw new Error('triangularSolve(): leading dimensions do not match.');
+    }
+  }
+
+  const [N,M] = l.shape.slice(-2),
+        [I,J] = y.shape.slice(-2);
+  if( N !== M ) {
+    throw new Error('triangularSolve(): Last two axes of L not square.');
+  }
+  if( I !== M ) {
+    throw new Error('triangularSolve(): L and y do not match.');
+  }
+
+  const
+    rank = Math.max(l.rank, y.rank),
+    xShape = Array.from(l.shape);
+  xShape[rank-2] = I;
+  xShape[rank-1] = J;
+
+  // GENERATE RESULT DATA
+  const
+    dtype = 'float32',
+//  dtype = ( l.dtype === 'float64' ||
+//            y.dtype === 'float64' ) ? 'float64' : 'float32',
+    // tslint:disable
+    DTypeArray = Float32Array,
+    // tslint:enable
+//  DTypeArray = dtype === 'float32' ? Float32Array
+//                                   : Float64Array,
+    L = l.dataSync(),
+    X = DTypeArray.from( y.dataSync() ) as TypedArray;
+  l = undefined;
+  y = undefined;
+
+  for( let lOff = 0,
+           xOff = 0; xOff < X.length; xOff += N*J,
+                                      lOff += N*N )
+  {
+    if( ! adjoint )
+    {
+      if(lower)
+      { // FORWARD SUBSTITUTION
+        for( let i=0; i < I; i++ ) {
+          for( let k=0; k < i; k++ ) {
+          for( let j=0; j < J; j++ ) {
+            X[xOff + J*i+j] -= L[lOff + N*i+k] * X[xOff + J*k+j];
+          }}
+
+          for( let j=0; j < J; j++ ) {
+            X[xOff + J*i+j] /= L[lOff + N*i+i];
+          }
+        }
+      }
+      else
+      { // BACKWARD SUBSTITUTION
+        for( let i=I; i-- > 0; ) {
+          for( let j=J; j-- > 0; ) {
+            X[xOff + J*i+j] /= L[lOff + N*i+i];
+          }
+
+          for( let k=i; k-- > 0; ) {
+          for( let j=J; j-- > 0; ) {
+            X[xOff + J*k+j] -= L[lOff + N*k+i] * X[xOff + J*i+j];
+          }}
+        }
+      }
+    }
+    else
+    {
+      if(lower)
+      { // BACKWARD SUBSTITUTION (TRANSPOSED)
+        for( let i=I; i-- > 0; ) {
+          for( let j=J; j-- > 0; ) {
+            X[xOff + J*i+j] /= L[lOff + N*i+i];
+          }
+
+          for( let k=i; k-- > 0; ) {
+          for( let j=J; j-- > 0; ) {
+            X[xOff + J*k+j] -= L[lOff + N*i+k] * X[xOff + J*i+j];
+          }}
+        }
+      }
+      else
+      { // FORWARD SUBSTITUTION (TRANSPOSED)
+        for( let i=0; i < I; i++ ) {
+          for( let k=0; k < i; k++ ) {
+          for( let j=0; j < J; j++ ) {
+            X[xOff + J*i+j] -= L[lOff + N*k+i] * X[xOff + J*k+j];
+          }}
+
+          for( let j=0; j < J; j++ ) {
+            X[xOff + J*i+j] /= L[lOff + N*i+i];
+          }
+        }
+      }
+    }
+  }
+
+  return Tensor.make(xShape,{values: X},dtype);
+}
+
+/**
+ * Solves a triangular linear equation system (LES).
+ *
+ * @param l The triangular matrix of the LES.
+ * @param y The right-hand-side of the LES.
+ * @param lower If set to `true`, `l` is interpreted as lower triangular
+ *              matrix. The strict upper triangular entries are ignored.
+ *              If set to `false`, `l` is interpreted as upper triangular
+ *              matrix and the strict lower triangular entries are ignored.
+ * @param adjoint If set to `true`, the hermitian transpose of `l` is used in
+ *                the LES.
+ *
+ * @returns The solution of one of the following LES:
+ *          
+ *
lower=true, adjoint=false
tril(l) ∙x == y + *
lower=false, adjoint=false
triu(l) ∙x == y + *
lower=true, adjoint=true
tril(l)ᴴ∙x == y + *
lower=false, adjoint=true
triu(l)ᴴ∙x == y + *
+ */ +/** + * @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ +function triangularSolve_( + l: Tensor|TensorLike, y: Tensor|TensorLike, lower=true, adjoint=false +): Tensor +{ + // FIXME: if `l` is singular the right hand side could be + // checked for 0 and then some/any solution could be used + +// let [$l,$y] = broadcastMatrices( +// convertToTensor(l,'l','triangularSolve'), +// convertToTensor(y,'y','triangularSolve') +// ); + let $l = convertToTensor(l,'l','triangularSolve'), + $y = convertToTensor(y,'y','triangularSolve'); + l=undefined; + y=undefined; + if( $l.rank < 2 ){ + throw new Error(`triangularSolve(): l.rank must be at least 2.`); + } + if( $y.rank < 2 ){ + throw new Error(`triangularSolve(): y.rank must be at least 2.`); + } + + const dtype = upcastType($l.dtype, $y.dtype); + if( $l.dtype !== dtype ) { $l = $l.cast(dtype); } + if( $y.dtype !== dtype ) { $y = $y.cast(dtype); } + + // WHERE THE BACKPROP COMES FROM: + // x = L⁻¹∙y + // => dx = d(L⁻¹)∙y + L⁻¹∙dy = L⁻¹∙dy - L⁻¹∙dL∙L⁻¹∙y = L⁻¹∙dy - L⁻¹∙dL∙x + // => df = tr( (∂f/∂x)∙dxᵀ ) + // = tr( (∂f/∂x)∙dyᵀ∙L⁻ᵀ ) - tr( (∂f/∂x)∙yᵀ∙L⁻ᵀ∙dLᵀ∙L⁻ᵀ ) + // = tr( (∂f/∂x)ᵀ∙L⁻¹∙dy ) - tr( (∂f/∂x)∙yᵀ∙L⁻ᵀ∙(L⁻¹∙dL)ᵀ ) + // = tr( L⁻ᵀ∙(∂f/∂x) ∙dyᵀ) - tr( L⁻¹∙y∙(∂f/∂x)ᵀ∙ L⁻¹∙dL ) + // = tr( L⁻ᵀ∙(∂f/∂x) ∙dyᵀ) - tr( x∙(∂f/∂x)ᵀ∙ L⁻¹∙dL ) + // = tr( L⁻ᵀ∙(∂f/∂x) ∙dyᵀ) - tr( L⁻ᵀ ∙(∂f/∂x) ∙ xᵀ ∙dLᵀ ) + // => ∂f/∂y = L⁻ᵀ∙(∂f/∂x) + // ∂f/∂L = -L⁻ᵀ∙(∂f/∂x)∙xᵀ = ∂f/∂L = -(∂f/∂y)∙xᵀ + + // tslint:disable + // SEE: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L218 + // tslint:enable + return ENV.engine.runKernel( + (backend,saveFn) => { + const x = triangularSolveKernel($l,$y,lower,adjoint); + saveFn(x); + return x; + }, + {$l,$y}, + (dx,[x]) => { + const dy = triangularSolve($l, dx, lower, !adjoint); + return { + $l: () => { + let dl = adjoint ? matMul( x, dy, false, true) + : matMul(dy, x, false, true); + dl = dl.neg(); + dl = lower ? bandPart(dl,-1, 0) + : bandPart(dl, 0,-1); + return dl; + }, + $y: () => dy + }; + } + ); +} + +/** Computes the economic QR Decomposition. + */ +function qrEcoDecompKernel( a: Tensor ): [Tensor,Tensor] +{ + assert( + a.rank >= 2, + `qr(): input must have rank >= 2, got rank ${a.rank}.` + ); + assert( + ! a.dtype.startsWith('complex'), + `qr(): complex dtype not supported.` + ); + assert( + a.shape[a.rank-2] >= a.shape[a.rank-1], + `qr(): a.shape[-2] = ${a.shape[a.rank-2]}` + + ` < ${a.shape[a.rank-1]} = a.shape[-1].` + ); + + const dtype = 'float32', + // tslint:disable + DTypeArray = Float32Array, + // tslint:enable + qShape = Array.from( a.shape ), + rShape = Array.from( qShape ), + [N,M] = qShape.slice(-2); + rShape[rShape.length-2] = M; + Object.freeze(qShape); + Object.freeze(rShape); + + const Q = DTypeArray.from( a.dataSync() ); a = undefined; + const R = new DTypeArray(Q.length/N*M), + cs = new DTypeArray(M*2),// <- APPLY M ROTATIONS TO Q AT ONCE + r = (() => { + try { return cs.subarray(M); } + catch(e) { return new DTypeArray(M); } + })(); // <- space to temp. store rows of R not contained in result + + for( + let rOff=0, + qOff=0; qOff < Q.length; qOff += N*M, + rOff += M*M + ) + { + // HANDLE ENTRIES CONTAINED IN THE RESULT + for( let i=0; i < M; i++ ) + { + // COPY FROM Q TO R AND INIT Q + for( let j=0; j < M; j++ ) { + R[rOff+M*i+j] = Q[qOff+M*i+j]; + Q[qOff+M*i+j] = i !== j ? 
0.0 : 1.0; + } + + for( let j=0; j < i; j++ ) + { // USE GIVENS ROTATION TO ELIMINATE ELEMENT R_ji + const rIJ = R[rOff+M*i+j]; if( rIJ === 0.0 ){cs[2*j+0]=1.0; + cs[2*j+1]=0.0; continue;} + const rJJ = R[rOff+M*j+j], + norm = Math.hypot(rJJ,rIJ), + c = rJJ / norm, + s = rIJ / norm; + cs[2*j+0] = c; + cs[2*j+1] = s; + R[rOff + M*i+j] = 0.0; + R[rOff + M*j+j] = norm; + // ROTATE ROW i AND j IN R + for( let k=j; ++k < M; ) { + const ik = rOff + M*i+k, rIK = R[ik], + jk = rOff + M*j+k, rJK = R[jk]; + R[ik] = c*rIK - s*rJK; + R[jk] = s*rIK + c*rJK; + } + } + + // ROTATE COLUMNS IN Q (BUNDLED FOR BETTER CACHE LOCALITY) + for( let k=0; k <= i; k++ ) { + for( let j=0; j < i; j++ ) { + const c = cs[2*j+0], + s = cs[2*j+1], + ki = qOff + M*k+i, qKI = Q[ki], + kj = qOff + M*k+j, qKJ = Q[kj]; + Q[ki] = c*qKI - s*qKJ; + Q[kj] = s*qKI + c*qKJ; + }} + } + // HANDLE REMAINING ENTRIES NOT CONTAINED IN THE RESULT + for( let i=M; i < N; i++ ) + { + // INIT r + for( let j=0; j < M; j++ ) { + r[j] = Q[qOff+M*i+j]; Q[qOff+M*i+j] = 0.0; + } + + // USE GIVENS ROTATIONS TO ELIMINATE ELEMENT r completely + for( let j=0; j < M; j++ ) + { + const rJ = r[j]; if( rJ === 0.0 ) { cs[2*j+0]=1.0; + cs[2*j+1]=0.0; continue; } + const rJJ = R[rOff+M*j+j], + norm = Math.hypot(rJJ,rJ), + c = rJJ / norm, + s = rJ / norm; + R[rOff+M*j+j] = norm; + // ROTATE ROW i AND j IN R + for( let k=j; ++k < M; ) { + const jk = rOff + M*j+k, rJK = R[jk]; + R[jk] = s*r[k] + c*rJK; + r[ k] = c*r[k] - s*rJK; + } + cs[2*j+0] = c; + cs[2*j+1] = s; + } + + // ROTATE COLUMNS IN Q + for( let k=0; k <= i; k++ ) { let QK = i !== k ? 0.0 : 1.0; + for( let j=0; j < M; j++ ) { + const c = cs[2*j+0], + s = cs[2*j+1], qK = QK, + kj = qOff + M*k+j, qKJ = Q[kj]; + QK = c*qK - s*qKJ; + Q[kj]= s*qK + c*qKJ; + }} + } + } + + { + const q = Tensor.make(qShape, { values: Q }, dtype); + const r = Tensor.make(rShape, { values: R }, dtype); + + return [q,r]; + } +} + +/** Computes the full QR Decomposition an memoizes the + * Givens rotation angles in the process. + */ +function qrFullDecompKernel( a: Tensor ): [Tensor,Tensor,Tensor] +{ + assert( + a.rank >= 2, + `Error in linalg.qr: input must have rank >= 2, got rank ${a.rank}.` + ); + assert( + ! 
a.dtype.startsWith('complex'), + `Error in linalg.qr: complex dtype not supported.` + ); + + const dtype ='float32', + // tslint:disable + DTypeArray = Float32Array, + // tslint:enable + rShape = Array.from( a.shape ), + qShape = Array.from( a.shape ), + [M,N] = a.shape.slice(-2), + R = DTypeArray.from( a.dataSync() ); + a = undefined; + const L = Math.min(M,N), + Q = new DTypeArray( R.length/N*M ), + CS = new DTypeArray( R.length/N/M * 2 * ( + (L*(L-1) >>> 1) + Math.max(0,M-N)*N + )); + qShape[qShape.length-1] = M; + Object.freeze(qShape); + Object.freeze(rShape); + + let l = 0; + for( let qOff=0, + rOff=0; qOff < Q.length; qOff += M*M, + rOff += M*N ) + { + // INIT Q TO IDENTITY + for( let i=0; i < M; i++ ) { Q[qOff + M*i+i] = 1; } + + // BEGIN QR DECOMPOSITION + for( let i=1; i < M; i++ ) { const J = Math.min(i,N); + for( let j=0; j < J; j++ ) + { + // DETERMINE GIVENS ROTATION cos AND sin + const rIJ = R[rOff + N*i+j]; if( 0.0 === rIJ ) { CS[l++]=1.0; + CS[l++]=0.0; continue; } + const rJJ = R[rOff + N*j+j]; + let norm = Math.hypot(rJJ,rIJ), + c = rJJ / norm, + s = rIJ / norm; + CS[l++] = c; + CS[l++] = s; + R[rOff + N*j+j] = norm; + R[rOff + N*i+j] = 0; + // ROTATE ROWS IN R + for( let k=j; ++k < N; ) + { const rJK = R[rOff + N*j+k], + rIK = R[rOff + N*i+k]; + R[rOff + N*j+k] = s*rIK + c*rJK; + R[rOff + N*i+k] = c*rIK - s*rJK; + } + // ROTATE ROWS IN Qᵀ + for( let k=0; k <= i; k++ ) + { const qJK = Q[qOff + M*j+k], + qIK = Q[qOff + M*i+k]; + Q[qOff + M*j+k] = s*qIK + c*qJK; + Q[qOff + M*i+k] = c*qIK - s*qJK; + } + }} // END QR DECOMPOSITION + + // TRANSPOSE Q (was transposed for cache locality) + for( let i=0; i < M; i++ ) { + for( let j=0; j < i; j++ ) { + const qIJ = Q[qOff + M*i+j]; + Q[qOff + M*i+j] = Q[qOff + M*j+i]; + Q[qOff + M*j+i] = qIJ; + }} + } + assert( l === CS.length, `WTF: ${l} != ${CS.length}` ); + + const q = Tensor.make(qShape, {values: Q}, dtype); + const r = Tensor.make(rShape, {values: R}, dtype); + const cs = Tensor.make([CS.length], {values: CS}, dtype); + + return [q,r,cs]; +} + +/** Computes the backpropagation full QR Decomposition using + * memoized Givens rotation angles in the process. 
+ */ +function qrFullBackpropKernel( + q: Tensor, dq: Tensor, r: Tensor, dr: Tensor, cs: Tensor +): Tensor +{ + assert( q.rank === dq.rank, `q.rank == ${q.rank} != ${dq.rank} == dq.rank` ); + assert( q.rank === dr.rank, `q.rank == ${q.rank} != ${dr.rank} == dr.rank` ); + assert( q.rank === r.rank, `q.rank == ${q.rank} != ${ r.rank} == r.rank` ); + + assert( cs.rank === 1, `cs.rank == ${cs.rank} != 1` ); + + for( let i=q.rank-2; i-- > 0; ) + { + assert( + q.shape[i] === dq.shape[i], + `q.shape[${i}] == ${q.shape[i]} != ${dq.shape[i]} == dq.shape[${i}]` + ); + assert( + q.shape[i] === dr.shape[i], + `q.shape[${i}] == ${q.shape[i]} != ${dr.shape[i]} == dr.shape[${i}]` + ); + assert( + q.shape[i] === r.shape[i], + `q.shape[${i}] == ${q.shape[i]} != ${ r.shape[i]} == r.shape[${i}]` + ); + } + const rank = q.rank; + assert( + q.shape[rank-2] === q.shape[rank-1], + `q.shape[-2] == ${q.shape[rank-2]} != ${ q.shape[rank-1]} == q.shape[-1]` + ); + assert( + q.shape[rank-2] === dq.shape[rank-1], + `q.shape[-2] == ${q.shape[rank-2]} != ${dq.shape[rank-1]} == dq.shape[-1]` + ); + assert( + q.shape[rank-2] === dq.shape[rank-2], + `q.shape[-2] == ${q.shape[rank-2]} != ${dq.shape[rank-2]} == dq.shape[-2]` + ); + + assert( + r.shape[rank-2] === q.shape[rank-1], + `r.shape[-2] == ${r.shape[rank-2]} != ${ q.shape[rank-1]} == q.shape[-1]` + ); + assert( + r.shape[rank-1] === dr.shape[rank-1], + `r.shape[-1] == ${r.shape[rank-1]} != ${dr.shape[rank-1]} == dr.shape[-1]` + ); + assert( + r.shape[rank-2] === dr.shape[rank-2], + `r.shape[-2] == ${r.shape[rank-2]} != ${dr.shape[rank-2]} == dr.shape[-2]` + ); + + assert( + q.dtype === dq.dtype, `q.dtype == ${q.dtype} == ${ dq.dtype} == dq.dtype` + ); + assert( + q.dtype === dr.dtype, `q.dtype == ${q.dtype} == ${ dr.dtype} == dr.dtype` + ); + assert( + q.dtype === r.dtype, `q.dtype == ${q.dtype} == ${ r.dtype} == r.dtype` + ); + assert( + q.dtype === cs.dtype, `q.dtype == ${q.dtype} == ${cs.dtype} == cs.dtype` + ); + + assert( ! 
q.dtype.startsWith('complex'), `Complex dtype not supported.`); + + const dtype ='float32', + // tslint:disable + DTypeArray = Float32Array, + // tslint:enable + dAShape = Array.from( r.shape ), + [M,N] = dAShape.slice(-2); + const Q = DTypeArray.from( q.dataSync() ); q = undefined; + const dQ = DTypeArray.from( dq.dataSync() ); dq = undefined; + const R = DTypeArray.from( r.dataSync() ); r = undefined; + const dR = DTypeArray.from( dr.dataSync() ); dr = undefined; + const CS = cs.dataSync(); + Object.freeze(dAShape); + + let l = CS.length; + for( let rOff=R.length, + qOff=Q.length; qOff > 0; ) + { + qOff -= M*M; + rOff -= M*N; + + // TRANSPOSE Q (for cache locality) + for( let i=0; i < M; i++ ) { + for( let j=0; j < i; j++ ) { + const qIJ = Q[qOff + M*i+j]; + Q[qOff + M*i+j] = Q[qOff + M*j+i]; + Q[qOff + M*j+i] = qIJ; + }} + + // TRANSPOSE dQ (for cache locality) + for( let i=0; i < M; i++ ) { + for( let j=0; j < i; j++ ) { + const dQij = dQ[qOff + M*i+j]; + dQ[qOff + M*i+j] = dQ[qOff + M*j+i]; + dQ[qOff + M*j+i] = dQij; + }} + + // BEGIN QR DECOMPOSITION + for( let i=M; --i > 0; ) { const J = Math.min(i,N); + for( let j=J; j-- > 0; ) + { + // DETERMINE GIVENS ROTATION cos AND sin + const s = CS[--l]; if( 0 === s ) { continue; } + const c = CS[--l], + norm = R[rOff + N*j+j]; + + // ROTATE ROWS IN R + for( let k=j; k < N; k++ ) + { const rJK = R[rOff + N*j+k], + rIK = R[rOff + N*i+k]; + R[rOff + N*j+k] = c*rJK - s*rIK; + R[rOff + N*i+k] = s*rJK + c*rIK; + } + + // ROTATE ROWS IN Qᵀ + for( let k=0; k <= i; k++ ) + { const qJK = Q[qOff + M*j+k], + qIK = Q[qOff + M*i+k]; + Q[qOff + M*j+k] = c*qJK - s*qIK; + Q[qOff + M*i+k] = s*qJK + c*qIK; + } + + const rIJ = R[rOff + N*i+j], + rJJ = R[rOff + N*j+j], + dCdJ = + rIJ / norm * rIJ / norm**2, + dCdI = - rIJ / norm * rJJ / norm**2, + dSdJ = - rJJ / norm * rIJ / norm**2, + dSdI = + rJJ / norm * rJJ / norm**2; + let dj = 0.0, + di = 0.0; + + // ROTATE ROWS IN dR + for( let k=j; k < N; k++ ) + { const dRjk = dR[rOff + N*j+k], + dRik = dR[rOff + N*i+k]; + dR[rOff + N*j+k] = c*dRjk - s*dRik; + dR[rOff + N*i+k] = s*dRjk + c*dRik; + + const rJK = R[rOff + N*j+k], + rIK = R[rOff + N*i+k]; + + dj += dRjk*(rIK*dSdJ + rJK*dCdJ) + dRik*(rIK*dCdJ - rJK*dSdJ); + di += dRjk*(rIK*dSdI + rJK*dCdI) + dRik*(rIK*dCdI - rJK*dSdI); + } + + // ROTATE ROWS IN dQᵀ + for( let k=0; k <= i; k++ ) + { const dQjk = dQ[qOff + M*j+k], + dQik = dQ[qOff + M*i+k]; + dQ[qOff + M*j+k] = c*dQjk - s*dQik; + dQ[qOff + M*i+k] = s*dQjk + c*dQik; + + const qJK = Q[qOff + M*j+k], + qIK = Q[qOff + M*i+k]; + + dj += dQjk*(qIK*dSdJ + qJK*dCdJ) + dQik*(qIK*dCdJ - qJK*dSdJ); + di += dQjk*(qIK*dSdI + qJK*dCdI) + dQik*(qIK*dCdI - qJK*dSdI); + } + + dR[rOff + N*j+j] += dj; + dR[rOff + N*i+j] += di; + }} // END QR DECOMPOSITION + } + assert( 0 === l, `WTF: ${l} != 0` ); + + return Tensor.make(dAShape,{values: dR},dtype); +} + /** - * Compute QR decomposition of m-by-n matrix using Householder transformation. + * Compute QR decomposition of m-by-n matrix using Givens rotations. 
* - * Implementation based on - * [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf] - * (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf) + * See: http://www.math.usm.edu/lambers/mat610/sum10/lecture9.pdf * * ```js * const a = tf.tensor2d([[1, 2], [3, 4]]); @@ -150,114 +838,82 @@ function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D { * subheading:'Linear Algebra', * namespace:'linalg'} */ -function qr_(x: Tensor, fullMatrices = false): [Tensor, Tensor] { - if (x.rank < 2) { +function qr_( a: Tensor, fullMatrices = false ): [Tensor, Tensor] { + if( a.rank < 2 ) { throw new Error( - `qr() requires input tensor to have a rank >= 2, but got rank ${ - x.rank}`); - } else if (x.rank === 2) { - return qr2d(x as Tensor2D, fullMatrices); - } else { - // Rank > 2. - // TODO(cais): Below we split the input into individual 2D tensors, - // perform QR decomposition on them and then stack the results back - // together. We should explore whether this can be parallelized. - const outerDimsProd = x.shape.slice(0, x.shape.length - 2) - .reduce((value, prev) => value * prev); - const x2ds = unstack( - x.reshape([ - outerDimsProd, x.shape[x.shape.length - 2], - x.shape[x.shape.length - 1] - ]), - 0); - const q2ds: Tensor2D[] = []; - const r2ds: Tensor2D[] = []; - x2ds.forEach(x2d => { - const [q2d, r2d] = qr2d(x2d as Tensor2D, fullMatrices); - q2ds.push(q2d); - r2ds.push(r2d); - }); - const q = stack(q2ds, 0).reshape(x.shape); - const r = stack(r2ds, 0).reshape(x.shape); - return [q, r]; + `qr() requires input tensor to have a rank >= 2, but got rank ${a.rank}` + ); + } + if( a.dtype.startsWith('complex') ) { + throw new Error(`qr() not yet supported for complex tensors.`); } -} -function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] { - return ENV.engine.tidy(() => { - if (x.shape.length !== 2) { - throw new Error( - `qr2d() requires a 2D Tensor, but got a ${x.shape.length}D Tensor.`); - } + const [m,n] = a.shape.slice(-2); - const m = x.shape[0]; - const n = x.shape[1]; - - let q = eye(m) as Tensor2D; // Orthogonal transform so far. - let r = x.clone(); // Transformed matrix so far. - - const one2D = tensor2d([[1]], [1, 1]); - let w: Tensor2D = one2D.clone(); - - const iters = m >= n ? n : m; - for (let j = 0; j < iters; ++j) { - // This tidy within the for-loop ensures we clean up temporary - // tensors as soon as they are no longer needed. - const rTemp = r; - const wTemp = w; - const qTemp = q; - [w, r, q] = ENV.engine.tidy((): [Tensor2D, Tensor2D, Tensor2D] => { - // Find H = I - tau * w * w', to put zeros below R(j, j). - const rjEnd1 = r.slice([j, j], [m - j, 1]); - const normX = rjEnd1.norm(); - const rjj = r.slice([j, j], [1, 1]); - const s = rjj.sign().neg() as Tensor2D; - const u1 = rjj.sub(s.mul(normX)) as Tensor2D; - const wPre = rjEnd1.div(u1); - if (wPre.shape[0] === 1) { - w = one2D.clone(); - } else { - w = one2D.concat( - wPre.slice([1, 0], [wPre.shape[0] - 1, wPre.shape[1]]), 0) as - Tensor2D; - } - const tau = s.matMul(u1).div(normX).neg() as Tensor2D; - - // -- R := HR, Q := QH. 
- const rjEndAll = r.slice([j, 0], [m - j, n]); - const tauTimesW = tau.mul(w) as Tensor2D; - if (j === 0) { - r = rjEndAll.sub(tauTimesW.matMul(w.transpose().matMul(rjEndAll))); - } else { - r = r.slice([0, 0], [j, n]) - .concat( - rjEndAll.sub( - tauTimesW.matMul(w.transpose().matMul(rjEndAll))), - 0) as Tensor2D; - } - const qAllJEnd = q.slice([0, j], [m, q.shape[1] - j]); - if (j === 0) { - q = qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tauTimesW.transpose())); - } else { - q = q.slice([0, 0], [m, j]) - .concat( - qAllJEnd.sub( - qAllJEnd.matMul(w).matMul(tauTimesW.transpose())), - 1) as Tensor2D; + if( m === n || m > n && !fullMatrices ) + { + // FIXME: What if R is (nearly) singular? + return ENV.engine.runKernel( + (backend,saveFunc) => { + const [q,r] = qrEcoDecompKernel(a); + saveFunc(q); + saveFunc(r); + return [q,r]; + }, + {a}, + ([dq,dr], [q,r]) => ({ + a: () => { + // TODO: is tidy required here? + // tslint:disable + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L160 + // tslint:enable + const qdq = matMul(q,dq, true, false), + rdr = matMul(r,dr, false, true), + qdq_ = qdq.sub( adjoint(qdq) ), + rdr_ = rdr.sub( adjoint(rdr) ), + tril = bandPart( add(qdq_,rdr_), -1, 0 ); + + const triSolv = (x: Tensor,r: Tensor) => adjoint( + triangularSolve(r, adjoint(x), /*lower=*/false, /*adjoint_r*/false) + ); + + const gradA = matMul( q, dr.add( triSolv(tril,r) ) ), + gradB = triSolv( dq.sub( matMul(q,qdq) ), r ); + + return add(gradA,gradB); } - return [w, r, q]; - }); - dispose([rTemp, wTemp, qTemp]); - } + }) + ) as [Tensor, Tensor]; + } - if (!fullMatrices && m > n) { - q = q.slice([0, 0], [m, n]); - r = r.slice([0, 0], [n, n]); - } + let [q,r] = ENV.engine.runKernel( + (backend,saveFunc) => { + const [q,r,cs] = qrFullDecompKernel(a); + saveFunc(q); + saveFunc(r); + saveFunc(cs); + return [q,r]; + }, + {a}, + ([dq,dr], [q,r,cs]) => ({ + a: () => ENV.engine.runKernel( + (backend,saveFunc) => qrFullBackpropKernel(q,dq, r,dr, cs), + { $dq: dq, $dr: dr } + ) + }) + ); + + if( ! fullMatrices && m > n ) { + const end = a.shape.slice(); + q = q.slice([0, 0], end); end[end.length-2] = n; + r = r.slice([0, 0], end); + } - return [q, r]; - }) as [Tensor2D, Tensor2D]; + return [q,r]; } +export const adjoint = op({adjoint_}); +export const bandPart = op({bandPart_}); export const gramSchmidt = op({gramSchmidt_}); export const qr = op({qr_}); +export const triangularSolve = op({triangularSolve_}); diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index 7480d7fcee..f26078b6cc 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -16,12 +16,71 @@ */ import * as tf from '../index'; +import {ENV} from '../environment'; import {describeWithFlags} from '../jasmine_util'; -import {Tensor1D, Tensor2D} from '../tensor'; -import {ALL_ENVS, expectArraysClose, WEBGL_ENVS} from '../test_util'; +import {Scalar, Tensor, Tensor1D, Tensor2D} from '../tensor'; +import {ALL_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from '../test_util'; import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops'; +/** Returns a random integer in the range of [from,until). + */ +const randInt = (from: number, until: number) => { + return Math.floor(Math.random()*(until-from)) + from; +}; + +/** + * Computes the gradients using finite differences. + * + * SEE: https://en.wikipedia.org/wiki/Finite_difference + * + * FIXME this is terribly imprecise... wish there was + * double precision support *hint hint*. 
+ */ +const numDiff = (f: (_: Tensor) => Scalar) => (a: Tensor) => { + if( a.dtype !== 'float32' ) { + throw new Error(`numDiff(): dtype=${a.dtype} not supported.`); + } + + const aData = Float32Array.from( a.dataSync() ); + + const fVal = () => { + const scalar = f(a); + if( scalar.rank !== 0 ) { + throw new Error('f() returned a non-scalar value.'); + } + return scalar.dataSync()[0]; + }; + + return ENV.engine.tidy(() => { + a = Tensor.make(a.shape, {values: aData}); + + const dA = new Float32Array( aData.length ); + + for( let i=0; i < aData.length; i++ ) + { // use central difference + const aI = aData[i], + delta = Math.max( Math.abs(aI) * 2**-12, 2**-12 ), + aHi = aI + delta, + aLo = aI - delta; + + // DISPOSAL (HOPEFULLY) REMOVES DATA FROM GPU AND FORCES REUPLOAD + aData[i] = aLo; a.dispose(); + a = Tensor.make(a.shape, {values: aData}); + const fLo = fVal(); + + aData[i] = aHi; a.dispose(); + a = Tensor.make(a.shape, {values: aData}); + const fHi = fVal(); + + dA[i] = (fHi - fLo) / (aHi - aLo); + aData[i] = aI; + } + + return Tensor.make(a.shape,{values: dA}); + }); +}; + describeWithFlags('gramSchmidt-tiny', ALL_ENVS, () => { it('2x2, Array of Tensor1D', () => { const xs: Tensor1D[] = [ @@ -94,137 +153,359 @@ describeWithFlags('gramSchmidt-non-tiny', WEBGL_ENVS, () => { }); }); -describeWithFlags('qr', ALL_ENVS, () => { - it('1x1', () => { - const x = tensor2d([[10]], [1, 1]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose(q, tensor2d([[-1]], [1, 1])); - expectArraysClose(r, tensor2d([[-10]], [1, 1])); +describeWithFlags('adjoint', ALL_ENVS, () => { + it('2x3', () => { + const a = tf.tensor2d([[1,2,3], + [4,5,6]], [2,3]), + aT = tf.tensor2d([[1,4], + [2,5], + [3,6]],[3,2]); + expectArraysEqual( tf.linalg.adjoint(a), aT ); }); - - it('2x2', () => { - const x = tensor2d([[1, 3], [-2, -4]], [2, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, tensor2d([[-0.4472, -0.8944], [0.8944, -0.4472]], [2, 2])); - expectArraysClose(r, tensor2d([[-2.2361, -4.9193], [0, -0.8944]], [2, 2])); + it('3x2x1', () => { + const a = tf.tensor3d([[[1],[2]], + [[3],[4]], + [[5],[6]]], [3,2,1]), + aT = tf.tensor3d([[[1,2]], + [[3,4]], + [[5,6]]], [3,1,2]); + expectArraysEqual( tf.linalg.adjoint(a), aT ); }); +}); - it('2x2x2', () => { - const x = tensor3d([[[-1, -3], [2, 4]], [[1, 3], [-2, -4]]], [2, 2, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor3d( - [ - [[-0.4472, -0.8944], [0.8944, -0.4472]], - [[-0.4472, -0.8944], [0.8944, -0.4472]] - ], - [2, 2, 2])); - expectArraysClose( - r, - tensor3d( - [ - [[2.2361, 4.9193], [0, 0.8944]], - [[-2.2361, -4.9193], [0, -0.8944]] - ], - [2, 2, 2])); - }); +describeWithFlags('bandPart', ALL_ENVS, () => { + const la = tf.linalg; - it('2x1x2x2', () => { - const x = - tensor4d([[[[-1, -3], [2, 4]]], [[[1, 3], [-2, -4]]]], [2, 1, 2, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor4d( - [ - [[[-0.4472, -0.8944], [0.8944, -0.4472]]], - [[[-0.4472, -0.8944], [0.8944, -0.4472]]], - ], - [2, 1, 2, 2])); - expectArraysClose( - r, - tensor4d( - [ - [[[2.2361, 4.9193], [0, 0.8944]]], - [[[-2.2361, -4.9193], [0, -0.8944]]] - ], - [2, 1, 2, 2])); - }); + it('3x4', () => { + const a = tf.tensor2d([ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9,10,11,12] + ]); + expectArraysEqual( + la.bandPart(a,0,0), + tf.tensor2d([[1, 0, 0, 0], + [0, 6, 0, 0], + [0, 0,11, 0]]) + ); + expectArraysEqual( + la.bandPart(a,0,1), + tf.tensor2d([[1, 2, 0, 0], + [0, 6, 7, 0], + [0, 0,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,0,2), 
+ tf.tensor2d([[1, 2, 3, 0], + [0, 6, 7, 8], + [0, 0,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,0,2), + tf.tensor2d([[1, 2, 3, 0], + [0, 6, 7, 8], + [0, 0,11,12]]) + ); + for( const numUpper of [3,4,-1,-2] ) { + expectArraysEqual( + la.bandPart(a,0,numUpper), + tf.tensor2d([[1, 2, 3, 4], + [0, 6, 7, 8], + [0, 0,11,12]]) + ); + } - it('3x3', () => { - const x = tensor2d([[1, 3, 2], [-2, 0, 7], [8, -9, 4]], [3, 3]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor2d( - [ - [-0.1204, 0.8729, 0.4729], [0.2408, -0.4364, 0.8669], - [-0.9631, -0.2182, 0.1576] - ], - [3, 3])); - expectArraysClose( - r, - tensor2d( - [[-8.3066, 8.3066, -2.4077], [0, 4.5826, -2.1822], [0, 0, 7.6447]], - [3, 3])); - }); + expectArraysEqual( + la.bandPart(a,1,0), + tf.tensor2d([[1, 0, 0, 0], + [5, 6, 0, 0], + [0,10,11, 0]]) + ); + expectArraysEqual( + la.bandPart(a,1,1), + tf.tensor2d([[1, 2, 0, 0], + [5, 6, 7, 0], + [0,10,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,1,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [0,10,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,1,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [0,10,11,12]]) + ); + for( const numUpper of [3,4,-1,-2] ) { + expectArraysEqual( + la.bandPart(a,1,numUpper), + tf.tensor2d([[1, 2, 3, 4], + [5, 6, 7, 8], + [0,10,11,12]]) + ); + } - it('3x2, fullMatrices = default false', () => { - const x = tensor2d([[1, 2], [3, -3], [-2, 1]], [3, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor2d( - [[-0.2673, 0.9221], [-0.8018, -0.3738], [0.5345, -0.0997]], - [3, 2])); - expectArraysClose(r, tensor2d([[-3.7417, 2.4054], [0, 2.8661]], [2, 2])); - }); + for( const numLower of [2,3,-1,-2]) + { + expectArraysEqual( + la.bandPart(a,numLower,0), + tf.tensor2d([[1, 0, 0, 0], + [5, 6, 0, 0], + [9,10,11, 0]]) + ); + expectArraysEqual( + la.bandPart(a,numLower,1), + tf.tensor2d([[1, 2, 0, 0], + [5, 6, 7, 0], + [9,10,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,numLower,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [9,10,11,12]]) + ); + expectArraysEqual( + la.bandPart(a,numLower,2), + tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [9,10,11,12]]) + ); + for( const numUpper of [3,4,-1,-2] ) { + expectArraysEqual( + la.bandPart(a,numLower,numUpper), + tf.tensor2d([[1, 2, 3, 4], + [5, 6, 7, 8], + [9,10,11,12]]) + ); + } + } - it('3x2, fullMatrices = true', () => { - const x = tensor2d([[1, 2], [3, -3], [-2, 1]], [3, 2]); - const [q, r] = tf.linalg.qr(x, true); - expectArraysClose( - q, - tensor2d( - [ - [-0.2673, 0.9221, 0.2798], [-0.8018, -0.3738, 0.4663], - [0.5345, -0.0997, 0.8393] - ], - [3, 3])); - expectArraysClose( - r, tensor2d([[-3.7417, 2.4054], [0, 2.8661], [0, 0]], [3, 2])); + for( const numUpper of [0,1,2,3,4,-1,-2] ) { + for( const numLower of [0,1,2,3, -1,-2] ) { + const w = tf.randomUniform(a.shape), + f = (x: Tensor) => { + return la.bandPart(x,numLower,numUpper).mul(w).mean() as Scalar; + }, + g = numDiff(f), + h = tf.grad(f); + expectArraysClose( g(a), h(a) ); + }} }); +}); - it('2x3, fullMatrices = default false', () => { - const x = tensor2d([[1, 2, 3], [-3, -2, 1]], [2, 3]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor2d([[-0.3162278, -0.9486833], [0.9486833, -0.31622773]], [2, 2])); - expectArraysClose( - r, - tensor2d( - [[-3.162, -2.5298, -2.3842e-07], [0, -1.2649, -3.162]], [2, 3]), - ); - }); +describeWithFlags('triangularSolve', ALL_ENVS, () => { + const la = tf.linalg; - it('2x3, fullMatrices = true', () => { - const x = tensor2d([[1, 2, 3], [-3, -2, 
1]], [2, 3]); - const [q, r] = tf.linalg.qr(x, true); - expectArraysClose( - q, - tensor2d([[-0.3162278, -0.9486833], [0.9486833, -0.31622773]], [2, 2])); - expectArraysClose( - r, - tensor2d( - [[-3.162, -2.5298, -2.3842e-07], [0, -1.2649, -3.162]], [2, 3]), + const testWith = (L: Tensor, y: Tensor) => { + const test = (adjoint: boolean) => + { + let tril = la.bandPart(L,-1, 0), + triu = la.bandPart(L, 0,-1); + if( adjoint ) { + tril = la.adjoint(tril); + triu = la.adjoint(triu); + } + for( const lower of [true,undefined] ) + { + const x = la.triangularSolve(L,y, lower, adjoint); + const [a,b] = [y,tril.matMul(x)]; +// const [a,b] = broadcastMatrices( y, tril.matMul(x) ); + expectArraysClose(a,b); + } + const x = la.triangularSolve(L,y, /*lower=*/false, adjoint); + const [a,b] = [y,triu.matMul(x)];//broadcastMatrices( y, triu.matMul(x) ); + expectArraysClose(a,b); + + for( const lower of [false,true,undefined] ) + { + const w = tf.randomUniform(y.shape,-1,+1), + f = (L: Tensor, y: Tensor) => { + return la.triangularSolve(L,y,lower).mul(w).mean() as Scalar; + }, + [g1,g2] = tf.grads(f)([L,y]), + h1 = numDiff( (L: Tensor) => f(L,y) )(L), + h2 = numDiff( (y: Tensor) => f(L,y) )(y); + expectArraysClose(g1,h1); + expectArraysClose(g2,h2); + } + }; + test(undefined); + test(false); + test(true); + }; + + it('3x3', () => testWith( + tf.tensor2d([[1,2,3], + [4,5,6], + [7,8,9]]), + tf.tensor2d([[10,11], + [12,13], + [14,15]]) + )); + + for( let run=0; run < 16; run++ ) + { + const lShape = Array.from({ length: randInt(2,5) }, () => randInt(1,4) ), + yShape = lShape.slice(); + lShape[lShape.length-1] = lShape[lShape.length-2]; + + // RUN TEST + it(`random#${run}_${lShape.join('x')}_${yShape.join('x')}`, () => { + const ONE = tf.scalar(1), + TWO = tf.scalar(2); + const y = tf.randomUniform(yShape,-1,+1); + let L: Tensor = tf.randomUniform(lShape,-1,+1); + // SET THE DIAGONAL TO BE FAR FROM ZERO + const i = tf.range(0,lShape[lShape.length-2]).reshape([-1,1]), + j = tf.range(0,lShape[lShape.length-1]), + diag = tf.equal(i,j).cast('float32'), + magn = tf.randomNormal (lShape, /*mean=*/1,/*stdDev=*/0.1), + sign = tf.randomUniform(lShape, 0,2, 'int32') + .cast('float32').mul(TWO).sub(ONE); + L = tf.add( + diag.sub(ONE).mul(L), // <- off-diagonal + diag.mul(sign).mul(magn) // <- diagonal + ); + L = tf.clone(L); + testWith(L,y); + }); + } +}); + +describeWithFlags('qr', ALL_ENVS, () => { + const testWith = (a: Tensor) => { + const [m,n] = a.shape.slice(-2), + l = Math.min(m,n), + T = Array.from({ length: a.rank }, (_,i) => i ); + T[T.length-2] = T.length-1; + T[T.length-1] = T.length-2; + + for( const fullMatrices of [undefined,false,true] ) + { + const tril = (() => { + const [p,q] = fullMatrices ? [m,n] : [l,n], + i = tf.range(0,p).reshape([p,1]), + j = tf.range(0,q).reshape([1,q]); + return i.greater(j).cast('float32'); + })(); + const EYE = (() => { + const d = fullMatrices ? m : l; + return tf.stack( + Array.from( + { length: a.shape.slice(0,-2).reduce( (x,y) => x*y, 1 ) }, + () => tf.eye(d) + ) + ).reshape([...a.shape.slice(0,-2),d,d]); + })(); + const [q,r] = tf.linalg.qr(a,fullMatrices); + + // TEST SHAPE OF Q + expectArraysEqual( q.shape.slice(0,-1), a.shape.slice(0,-1) ); + expectArraysClose( q.shape.slice( -1), fullMatrices ? [m ] : [l ] ); + + // TEST SHAPE OF R + expectArraysEqual( r.shape.slice(0,-2), a.shape.slice(0,-2) ); + expectArraysClose( r.shape.slice( -2), fullMatrices ? 
[m,n] : [l,n] ); + + // TEST DECOMPOSITION (Q @ R == A) + expectArraysClose( q.matMul(r), a ); + + const qT = q.transpose(T); + + // TEST ORTHOGONALITY OF Q + if( fullMatrices || n >= m ) { + expectArraysClose( tf.matMul(q,qT), EYE ); + } + expectArraysClose( tf.matMul(qT,q), EYE ); + + // TEST TRIANGULARITY OF R + expectArraysEqual( tril.mul(r), tf.zeros(r.shape) ); + + // TEST GRADIENTS + const wQ = tf.randomUniform(q.shape,-1,+1), + wR = tf.randomUniform(r.shape,-1,+1), + f = (a: Tensor) => { + const [q,r] = tf.linalg.qr(a,fullMatrices); + return tf.add( + q.mul(wQ).mean(), + r.mul(wR).mean() + ) as Scalar; + }; + const g = numDiff(f); + const h = tf.grad(f); + try { + expectArraysClose( g(a), h(a) ); + } + catch(err) { + console.log('fullMatrices:', fullMatrices); +// const [q,r] = tf.linalg.qr(a,fullMatrices); + console.log('A:'); a .print(); +// console.log('Q:'); q .print(); +// console.log('R:'); r .print(); +// console.log('G:'); g(a).print(); +// console.log('H:'); h(a).print(); + throw err; + } + } + }; + + it('1x1', () => testWith( tensor2d([[10]], [1, 1]) ) ); + + it('2x2', () => testWith( tensor2d([[ 1, 3], + [-2,-4]], [2, 2]) ) ); + + it('2x2x2', () => testWith( tensor3d([[[-1,-3], + [ 2, 4]], + [[ 1, 3], + [-2,-4]]], [2, 2, 2]) ) ); + + it('2x1x2x2', () => testWith( tensor4d([[[[-1,-3], + [ 2, 4]]], + [[[ 1, 3], + [-2,-4]]]], [2, 1, 2, 2]) ) ); + + it('3x3', () => testWith( tensor2d([[ 1, 3, 2], + [-2, 0, 7], + [ 8,-9, 4]], [3, 3]) ) ); + + it('3x2', () => testWith( tensor2d([[ 1, 2], + [ 3,-3], + [-2, 1]], [3, 2]) ) ); + + it('2x3', () => testWith( tensor2d([[ 1, 2, 3], + [-3,-2, 1]], [2, 3]) ) ); + + for( let run=0; run < 128; run++ ) + { + const shape = Array.from({ length: randInt(2,5) }, () => randInt(1,4) ); + it( + `random#${run}_${shape.join('x')}`, + () => testWith( tf.randomUniform(shape,-1,+1) ) ); + } + + it('Is reasonably fast', () => { + // TODO is there a better way to test this with a timeout? + const N = 128, + A = tf.randomUniform([N,N],-1,+1), + wQ = tf.randomUniform([N,N],-1,+1), + wR = tf.randomUniform([N,N],-1,+1), + f = (a: Tensor) => { + const [q,r] = tf.linalg.qr(a); + return q.mul(wQ).mean().add( r.mul(wR).mean() ); + }; + const g = tf.grad(f); + // following hopefully prevents g(A) from being JITes/Optimized away... + expectArraysClose( g(A), g(A) ); }); it('Does not leak memory', () => { - const x = tensor2d([[1, 3], [-2, -4]], [2, 2]); + const x = tensor2d([[ 1, 3], + [-2,-4]], [2, 2]); // The first call to qr creates and keeps internal singleton tensors. // Subsequent calls should always create exactly two tensors. tf.linalg.qr(x); From c36a5aa3fe52a18f9fb23f4fec8279715882e0af Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Wed, 31 Oct 2018 14:10:08 +0100 Subject: [PATCH 05/13] Fixed linting error ("use const instead of let"). 
--- src/ops/linalg_ops.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts index acb78b2d4b..40f0e6d446 100644 --- a/src/ops/linalg_ops.ts +++ b/src/ops/linalg_ops.ts @@ -578,8 +578,8 @@ function qrFullDecompKernel( a: Tensor ): [Tensor,Tensor,Tensor] // DETERMINE GIVENS ROTATION cos AND sin const rIJ = R[rOff + N*i+j]; if( 0.0 === rIJ ) { CS[l++]=1.0; CS[l++]=0.0; continue; } - const rJJ = R[rOff + N*j+j]; - let norm = Math.hypot(rJJ,rIJ), + const rJJ = R[rOff + N*j+j], + norm = Math.hypot(rJJ,rIJ), c = rJJ / norm, s = rIJ / norm; CS[l++] = c; From dc0c54a9ca3f3bd4789248fa6a365d6787ce85dd Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Wed, 31 Oct 2018 14:34:33 +0100 Subject: [PATCH 06/13] Made adjoint() and bandPart() tests more lenient (for WebGL) --- src/ops/linalg_ops_test.ts | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index f26078b6cc..4117819aea 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -160,7 +160,7 @@ describeWithFlags('adjoint', ALL_ENVS, () => { aT = tf.tensor2d([[1,4], [2,5], [3,6]],[3,2]); - expectArraysEqual( tf.linalg.adjoint(a), aT ); + expectArraysClose( tf.linalg.adjoint(a), aT ); }); it('3x2x1', () => { const a = tf.tensor3d([[[1],[2]], @@ -169,7 +169,7 @@ describeWithFlags('adjoint', ALL_ENVS, () => { aT = tf.tensor3d([[[1,2]], [[3,4]], [[5,6]]], [3,1,2]); - expectArraysEqual( tf.linalg.adjoint(a), aT ); + expectArraysClose( tf.linalg.adjoint(a), aT ); }); }); @@ -182,32 +182,32 @@ describeWithFlags('bandPart', ALL_ENVS, () => { [5, 6, 7, 8], [9,10,11,12] ]); - expectArraysEqual( + expectArraysClose( la.bandPart(a,0,0), tf.tensor2d([[1, 0, 0, 0], [0, 6, 0, 0], [0, 0,11, 0]]) ); - expectArraysEqual( + expectArraysClose( la.bandPart(a,0,1), tf.tensor2d([[1, 2, 0, 0], [0, 6, 7, 0], [0, 0,11,12]]) ); - expectArraysEqual( + expectArraysClose( la.bandPart(a,0,2), tf.tensor2d([[1, 2, 3, 0], [0, 6, 7, 8], [0, 0,11,12]]) ); - expectArraysEqual( + expectArraysClose( la.bandPart(a,0,2), tf.tensor2d([[1, 2, 3, 0], [0, 6, 7, 8], [0, 0,11,12]]) ); for( const numUpper of [3,4,-1,-2] ) { - expectArraysEqual( + expectArraysClose( la.bandPart(a,0,numUpper), tf.tensor2d([[1, 2, 3, 4], [0, 6, 7, 8], @@ -215,32 +215,32 @@ describeWithFlags('bandPart', ALL_ENVS, () => { ); } - expectArraysEqual( + expectArraysClose( la.bandPart(a,1,0), tf.tensor2d([[1, 0, 0, 0], [5, 6, 0, 0], [0,10,11, 0]]) ); - expectArraysEqual( + expectArraysClose( la.bandPart(a,1,1), tf.tensor2d([[1, 2, 0, 0], [5, 6, 7, 0], [0,10,11,12]]) ); - expectArraysEqual( + expectArraysClose( la.bandPart(a,1,2), tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], [0,10,11,12]]) ); - expectArraysEqual( + expectArraysClose( la.bandPart(a,1,2), tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], [0,10,11,12]]) ); for( const numUpper of [3,4,-1,-2] ) { - expectArraysEqual( + expectArraysClose( la.bandPart(a,1,numUpper), tf.tensor2d([[1, 2, 3, 4], [5, 6, 7, 8], @@ -250,32 +250,32 @@ describeWithFlags('bandPart', ALL_ENVS, () => { for( const numLower of [2,3,-1,-2]) { - expectArraysEqual( + expectArraysClose( la.bandPart(a,numLower,0), tf.tensor2d([[1, 0, 0, 0], [5, 6, 0, 0], [9,10,11, 0]]) ); - expectArraysEqual( + expectArraysClose( la.bandPart(a,numLower,1), tf.tensor2d([[1, 2, 0, 0], [5, 6, 7, 0], [9,10,11,12]]) ); - expectArraysEqual( + expectArraysClose( la.bandPart(a,numLower,2), tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], [9,10,11,12]]) 
); - expectArraysEqual( + expectArraysClose( la.bandPart(a,numLower,2), tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], [9,10,11,12]]) ); for( const numUpper of [3,4,-1,-2] ) { - expectArraysEqual( + expectArraysClose( la.bandPart(a,numLower,numUpper), tf.tensor2d([[1, 2, 3, 4], [5, 6, 7, 8], @@ -405,11 +405,11 @@ describeWithFlags('qr', ALL_ENVS, () => { // TEST SHAPE OF Q expectArraysEqual( q.shape.slice(0,-1), a.shape.slice(0,-1) ); - expectArraysClose( q.shape.slice( -1), fullMatrices ? [m ] : [l ] ); + expectArraysEqual( q.shape.slice( -1), fullMatrices ? [m ] : [l ] ); // TEST SHAPE OF R expectArraysEqual( r.shape.slice(0,-2), a.shape.slice(0,-2) ); - expectArraysClose( r.shape.slice( -2), fullMatrices ? [m,n] : [l,n] ); + expectArraysEqual( r.shape.slice( -2), fullMatrices ? [m,n] : [l,n] ); // TEST DECOMPOSITION (Q @ R == A) expectArraysClose( q.matMul(r), a ); From 3d52a6d42beda132171b96836083ba6b2b79f98b Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Wed, 31 Oct 2018 15:21:03 +0100 Subject: [PATCH 07/13] Switched testing if `qr` and `triangularSolve` to CPU only --- src/ops/linalg_ops_test.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index 4117819aea..11be65b5c9 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -19,7 +19,7 @@ import * as tf from '../index'; import {ENV} from '../environment'; import {describeWithFlags} from '../jasmine_util'; import {Scalar, Tensor, Tensor1D, Tensor2D} from '../tensor'; -import {ALL_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from '../test_util'; +import {ALL_ENVS, CPU_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from '../test_util'; import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops'; @@ -297,7 +297,7 @@ describeWithFlags('bandPart', ALL_ENVS, () => { }); }); -describeWithFlags('triangularSolve', ALL_ENVS, () => { +describeWithFlags('triangularSolve', CPU_ENVS, () => { const la = tf.linalg; const testWith = (L: Tensor, y: Tensor) => { @@ -376,7 +376,7 @@ describeWithFlags('triangularSolve', ALL_ENVS, () => { } }); -describeWithFlags('qr', ALL_ENVS, () => { +describeWithFlags('qr', CPU_ENVS, () => { const testWith = (a: Tensor) => { const [m,n] = a.shape.slice(-2), l = Math.min(m,n), From f2d42073b05a4f0b5fbbf3cebea5862efc2f3d52 Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Wed, 31 Oct 2018 15:38:11 +0100 Subject: [PATCH 08/13] Switced `bandPart` testing to CPU only --- src/ops/linalg_ops_test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index 11be65b5c9..b0d1a84c84 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -173,7 +173,7 @@ describeWithFlags('adjoint', ALL_ENVS, () => { }); }); -describeWithFlags('bandPart', ALL_ENVS, () => { +describeWithFlags('bandPart', CPU_ENVS, () => { const la = tf.linalg; it('3x4', () => { From bae0780386b28f5f8d1b97e1333628e3676cb081 Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Wed, 31 Oct 2018 18:58:25 +0100 Subject: [PATCH 09/13] Implemented custom relative-tolerance comparison for testing `qr()`. 
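
The criterion is symmetric in both operands and scales with their magnitude:
two entries pass when ||x| - |y|| <= atol + rtol/2 * (|x| + |y|). A minimal
sketch of the check (in the helper below, `rtol` and `atol` default to
`TEST_EPSILON` and the failing index is reported on mismatch):

```js
// Sketch: magnitude-scaled closeness test with an absolute floor.
const isClose = (x, y, rtol, atol) => {
  x = Math.abs(x);
  y = Math.abs(y);
  return Math.abs(x - y) <= atol + rtol / 2 * (x + y);
};

isClose(1.00, 1.01, 0.02, 0); // true:  deviation 0.01 <= 0.0201
isClose(1.00, 1.10, 0.02, 0); // false: deviation 0.10 >  0.021
```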
--- src/ops/linalg_ops_test.ts | 46 +++++++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index b0d1a84c84..822b97c546 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -81,6 +81,48 @@ const numDiff = (f: (_: Tensor) => Scalar) => (a: Tensor) => { }); }; +function expectTensorsRelativelyClose( + actual: Tensor, expected: Tensor, rtol?: number, atol?: number +): void +{ + if( expected.shape.some( (s,i) => s !== actual.shape[i] ) ) { + throw new Error( + `Shapes [${actual.shape}] and [${expected.shape}] do not match.` + ); + } + + if( null == atol ) { atol = ENV.get('TEST_EPSILON'); } + if( null == rtol ) { rtol = ENV.get('TEST_EPSILON'); } + + const act = actual.dataSync(), + exp = expected.dataSync(); + + const isClose = (x: number, y: number) => { + x = Math.abs(x); + y = Math.abs(y); + return Math.abs(x-y) <= atol + rtol/2*(x+y); + }; + + for( let i=act.length; i-- > 0; ) { + if( ! isClose(act[i],exp[i]) ) + { + console.log( 'actual:'); actual.print(); + console.log('expected:'); expected.print(); + const idx = [], + shape = actual.shape; + for( let j=i, d=shape.length; d-- > 0; ) + { + const size = shape[d]; + idx.unshift(j % size); + j = Math.trunc(j / size); + } + throw new Error( + `actual[${idx}] = ${act[i]} != ${exp[i]} = expected[${idx}]` + ); + } + } +} + describeWithFlags('gramSchmidt-tiny', ALL_ENVS, () => { it('2x2, Array of Tensor1D', () => { const xs: Tensor1D[] = [ @@ -438,7 +480,9 @@ describeWithFlags('qr', CPU_ENVS, () => { const g = numDiff(f); const h = tf.grad(f); try { - expectArraysClose( g(a), h(a) ); + // since we're already losing precision on numDiff, + // this is the best we can do + expectTensorsRelativelyClose(g(a), h(a), /*rtol=*/1e-2, /*atol=*/1e-2); } catch(err) { console.log('fullMatrices:', fullMatrices); From a52ef84f2f56dd2270cee0e530050911afaf61d7 Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Fri, 2 Nov 2018 08:59:54 +0100 Subject: [PATCH 10/13] =?UTF-8?q?linalg=5Fops::qrEcoDecompKernel=20simplif?= =?UTF-8?q?ied=20and=20sped=20up=20O(m=C2=B3)=20to=20O(m=C2=B2n)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/ops/linalg_ops.ts | 189 +++++++++++++++---------------------- src/ops/linalg_ops_test.ts | 95 ++++++++++--------- 2 files changed, 125 insertions(+), 159 deletions(-) diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts index 40f0e6d446..7a916d99c3 100644 --- a/src/ops/linalg_ops.ts +++ b/src/ops/linalg_ops.ts @@ -162,29 +162,31 @@ function bandPart_( throw new Error(`bandPart(): numUpper=${numUpper} not an integer.`); } - const $a = convertToTensor(a,'a','bandPart'); + return ENV.engine.tidy( () => { + const $a = convertToTensor(a,'a','bandPart'); - const [M,N] = $a.shape.slice(-2); + const [M,N] = $a.shape.slice(-2); - if( !(numLower <= M) ) { - throw new Error(`bandPart() check failed: numLower <= #rows.` ); - } - if( !(numUpper <= N) ) { - throw new Error(`bandPart() check failed: numUpper <= #columns.`); - } + if( !(numLower <= M) ) { + throw new Error(`bandPart() check failed: numLower <= #rows.` ); + } + if( !(numUpper <= N) ) { + throw new Error(`bandPart() check failed: numUpper <= #columns.`); + } - if( numLower < 0 ) { numLower = M; } - if( numUpper < 0 ) { numUpper = N; } + if( numLower < 0 ) { numLower = M; } + if( numUpper < 0 ) { numUpper = N; } - const i = range(0,M, 1, 'int32').reshape([-1,1]), - j = range(0,N, 1, 'int32'); + const i = range(0,M, 
1, 'int32').reshape([-1,1]), + j = range(0,N, 1, 'int32'); - const inBand = logicalAnd( - sub(i,j).lessEqual( scalar(numLower,'int32') ), - sub(j,i).lessEqual( scalar(numUpper,'int32') ) - ).cast($a.dtype); + const inBand = logicalAnd( + sub(i,j).lessEqual( scalar(numLower,'int32') ), + sub(j,i).lessEqual( scalar(numUpper,'int32') ) + ).cast($a.dtype); - return mul($a,inBand); + return mul($a,inBand); + }); } function triangularSolveKernel( @@ -423,112 +425,75 @@ function qrEcoDecompKernel( a: Tensor ): [Tensor,Tensor] // tslint:enable qShape = Array.from( a.shape ), rShape = Array.from( qShape ), - [N,M] = qShape.slice(-2); - rShape[rShape.length-2] = M; + [M,N] = qShape.slice(-2); + rShape[rShape.length-2] = N; Object.freeze(qShape); Object.freeze(rShape); const Q = DTypeArray.from( a.dataSync() ); a = undefined; - const R = new DTypeArray(Q.length/N*M), - cs = new DTypeArray(M*2),// <- APPLY M ROTATIONS TO Q AT ONCE - r = (() => { - try { return cs.subarray(M); } - catch(e) { return new DTypeArray(M); } - })(); // <- space to temp. store rows of R not contained in result + const R = new DTypeArray( Q.length/M*N ), + cs = new DTypeArray( 2*M*N - N*(N+1) );// <- MEMOIZE ROTATIONS for( let rOff=0, - qOff=0; qOff < Q.length; qOff += N*M, - rOff += M*M + qOff=0; qOff < Q.length; qOff += M*N, + rOff += N*N ) { - // HANDLE ENTRIES CONTAINED IN THE RESULT - for( let i=0; i < M; i++ ) - { - // COPY FROM Q TO R AND INIT Q - for( let j=0; j < M; j++ ) { - R[rOff+M*i+j] = Q[qOff+M*i+j]; - Q[qOff+M*i+j] = i !== j ? 0.0 : 1.0; - } + let csi = 0; - for( let j=0; j < i; j++ ) - { // USE GIVENS ROTATION TO ELIMINATE ELEMENT R_ji - const rIJ = R[rOff+M*i+j]; if( rIJ === 0.0 ){cs[2*j+0]=1.0; - cs[2*j+1]=0.0; continue;} - const rJJ = R[rOff+M*j+j], - norm = Math.hypot(rJJ,rIJ), - c = rJJ / norm, - s = rIJ / norm; - cs[2*j+0] = c; - cs[2*j+1] = s; - R[rOff + M*i+j] = 0.0; - R[rOff + M*j+j] = norm; - // ROTATE ROW i AND j IN R - for( let k=j; ++k < M; ) { - const ik = rOff + M*i+k, rIK = R[ik], - jk = rOff + M*j+k, rJK = R[jk]; - R[ik] = c*rIK - s*rJK; - R[jk] = s*rIK + c*rJK; - } + for( let i=1; i < M; i++ ) { const J = Math.min(i,N); + for( let j=0; j < J; j++ ) + { // DETERMINE GIVENS ROTATION cos AND sin + const rIJ = Q[qOff + N*i+j]; if( 0.0 === rIJ ) {cs[csi++]=1.0; + cs[csi++]=0.0; continue;} + const rJJ = Q[qOff + N*j+j], + norm = Math.hypot(rJJ,rIJ), + c = rJJ / norm, + s = rIJ / norm; + cs[csi++] = c; + cs[csi++] = s; + Q[qOff + N*j+j] = norm; + Q[qOff + N*i+j] = 0; + // ROTATE ROWS IN R (WHICH IS CURRENTLY STORED IN Q) + for( let k=j; ++k < N; ) + { const rJK = Q[qOff + N*j+k], + rIK = Q[qOff + N*i+k]; + Q[qOff + N*j+k] = s*rIK + c*rJK; + Q[qOff + N*i+k] = c*rIK - s*rJK; } + }} - // ROTATE COLUMNS IN Q (BUNDLED FOR BETTER CACHE LOCALITY) - for( let k=0; k <= i; k++ ) { - for( let j=0; j < i; j++ ) { - const c = cs[2*j+0], - s = cs[2*j+1], - ki = qOff + M*k+i, qKI = Q[ki], - kj = qOff + M*k+j, qKJ = Q[kj]; - Q[ki] = c*qKI - s*qKJ; - Q[kj] = s*qKI + c*qKJ; - }} - } - // HANDLE REMAINING ENTRIES NOT CONTAINED IN THE RESULT - for( let i=M; i < N; i++ ) - { - // INIT r - for( let j=0; j < M; j++ ) { - r[j] = Q[qOff+M*i+j]; Q[qOff+M*i+j] = 0.0; - } + assert( csi === cs.length, `WTF: ${csi} !== ${cs.length}` ); - // USE GIVENS ROTATIONS TO ELIMINATE ELEMENT r completely - for( let j=0; j < M; j++ ) - { - const rJ = r[j]; if( rJ === 0.0 ) { cs[2*j+0]=1.0; - cs[2*j+1]=0.0; continue; } - const rJJ = R[rOff+M*j+j], - norm = Math.hypot(rJJ,rJ), - c = rJJ / norm, - s = rJ / norm; - R[rOff+M*j+j] = norm; - // 
ROTATE ROW i AND j IN R - for( let k=j; ++k < M; ) { - const jk = rOff + M*j+k, rJK = R[jk]; - R[jk] = s*r[k] + c*rJK; - r[ k] = c*r[k] - s*rJK; - } - cs[2*j+0] = c; - cs[2*j+1] = s; + // COPY R FROM Q -> R + for( let i=0; i < N; i++ ) { + for( let j=i; j < N; j++ ) { + R[rOff + N*i+j] = Q[qOff + N*i+j]; + Q[qOff + N*i+j] = i !== j ? 0.0 : 1.0; + }} + + // COMPUTE Q + for( let i=M; --i > 0; ) { const J = Math.min(i,N); + for( let j=J; j-- > 0; ) + { const s = cs[--csi], + c = cs[--csi]; + // ROTATE ROWS IN Q + for( let k=N; k-- > 0; ) + { const qJK = Q[qOff + N*j+k], + qIK = Q[qOff + N*i+k]; + Q[qOff + N*j+k] = c*qJK - s*qIK; + Q[qOff + N*i+k] = s*qJK + c*qIK; } + }} - // ROTATE COLUMNS IN Q - for( let k=0; k <= i; k++ ) { let QK = i !== k ? 0.0 : 1.0; - for( let j=0; j < M; j++ ) { - const c = cs[2*j+0], - s = cs[2*j+1], qK = QK, - kj = qOff + M*k+j, qKJ = Q[kj]; - QK = c*qK - s*qKJ; - Q[kj]= s*qK + c*qKJ; - }} - } + assert( csi === 0, `WTF: ${csi} !== 0` ); } - { - const q = Tensor.make(qShape, { values: Q }, dtype); - const r = Tensor.make(rShape, { values: R }, dtype); + const q = Tensor.make(qShape, { values: Q }, dtype); + const r = Tensor.make(rShape, { values: R }, dtype); - return [q,r]; - } + return [q,r]; } /** Computes the full QR Decomposition an memoizes the @@ -545,7 +510,7 @@ function qrFullDecompKernel( a: Tensor ): [Tensor,Tensor,Tensor] `Error in linalg.qr: complex dtype not supported.` ); - const dtype ='float32', + const dtype = 'float32', // tslint:disable DTypeArray = Float32Array, // tslint:enable @@ -750,12 +715,12 @@ function qrFullBackpropKernel( Q[qOff + M*i+k] = s*qJK + c*qIK; } - const rIJ = R[rOff + N*i+j], - rJJ = R[rOff + N*j+j], - dCdJ = + rIJ / norm * rIJ / norm**2, - dCdI = - rIJ / norm * rJJ / norm**2, - dSdJ = - rJJ / norm * rIJ / norm**2, - dSdI = + rJJ / norm * rJJ / norm**2; + const rIJ = R[rOff + N*i+j] / norm, + rJJ = R[rOff + N*j+j] / norm, + dCdJ = +rIJ*rIJ / norm, + dCdI = -rIJ*rJJ / norm, + dSdJ = -rJJ*rIJ / norm, + dSdI = +rJJ*rJJ / norm; let dj = 0.0, di = 0.0; diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index 822b97c546..d7d0918412 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -44,37 +44,33 @@ const numDiff = (f: (_: Tensor) => Scalar) => (a: Tensor) => { const aData = Float32Array.from( a.dataSync() ); - const fVal = () => { - const scalar = f(a); - if( scalar.rank !== 0 ) { - throw new Error('f() returned a non-scalar value.'); - } - return scalar.dataSync()[0]; - }; + const eps = Math.sqrt( ENV.get('EPSILON') ); return ENV.engine.tidy(() => { - a = Tensor.make(a.shape, {values: aData}); const dA = new Float32Array( aData.length ); for( let i=0; i < aData.length; i++ ) { // use central difference - const aI = aData[i], - delta = Math.max( Math.abs(aI) * 2**-12, 2**-12 ), - aHi = aI + delta, - aLo = aI - delta; - - // DISPOSAL (HOPEFULLY) REMOVES DATA FROM GPU AND FORCES REUPLOAD - aData[i] = aLo; a.dispose(); - a = Tensor.make(a.shape, {values: aData}); - const fLo = fVal(); - - aData[i] = aHi; a.dispose(); - a = Tensor.make(a.shape, {values: aData}); - const fHi = fVal(); - - dA[i] = (fHi - fLo) / (aHi - aLo); - aData[i] = aI; + const x = aData[i], + h = Math.max( Math.abs(x)*eps, eps ); + + const g = ( x: number ) => ENV.engine.tidy( () => { + aData[i] = x; + + const b = Tensor.make(a.shape, {values: aData}); + const scalar = f(b); + + if( scalar.rank !== 0 ) { + throw new Error('f() returned a non-scalar value.'); + } + + return scalar.dataSync()[0]; + }); + + // 
https://www.geometrictools.com/Documentation/FiniteDifferences.pdf + dA[i] = (-g(x+2*h) + 8*g(x+h) - 8*g(x-h) + g(x-2*h) ) / (12*h); + aData[i] = x; // <- undo modifications } return Tensor.make(a.shape,{values: dA}); @@ -202,7 +198,7 @@ describeWithFlags('adjoint', ALL_ENVS, () => { aT = tf.tensor2d([[1,4], [2,5], [3,6]],[3,2]); - expectArraysClose( tf.linalg.adjoint(a), aT ); + expectArraysEqual( tf.linalg.adjoint(a), aT ); }); it('3x2x1', () => { const a = tf.tensor3d([[[1],[2]], @@ -211,7 +207,7 @@ describeWithFlags('adjoint', ALL_ENVS, () => { aT = tf.tensor3d([[[1,2]], [[3,4]], [[5,6]]], [3,1,2]); - expectArraysClose( tf.linalg.adjoint(a), aT ); + expectArraysEqual( tf.linalg.adjoint(a), aT ); }); }); @@ -224,32 +220,32 @@ describeWithFlags('bandPart', CPU_ENVS, () => { [5, 6, 7, 8], [9,10,11,12] ]); - expectArraysClose( + expectArraysEqual( la.bandPart(a,0,0), tf.tensor2d([[1, 0, 0, 0], [0, 6, 0, 0], [0, 0,11, 0]]) ); - expectArraysClose( + expectArraysEqual( la.bandPart(a,0,1), tf.tensor2d([[1, 2, 0, 0], [0, 6, 7, 0], [0, 0,11,12]]) ); - expectArraysClose( + expectArraysEqual( la.bandPart(a,0,2), tf.tensor2d([[1, 2, 3, 0], [0, 6, 7, 8], [0, 0,11,12]]) ); - expectArraysClose( + expectArraysEqual( la.bandPart(a,0,2), tf.tensor2d([[1, 2, 3, 0], [0, 6, 7, 8], [0, 0,11,12]]) ); for( const numUpper of [3,4,-1,-2] ) { - expectArraysClose( + expectArraysEqual( la.bandPart(a,0,numUpper), tf.tensor2d([[1, 2, 3, 4], [0, 6, 7, 8], @@ -257,32 +253,32 @@ describeWithFlags('bandPart', CPU_ENVS, () => { ); } - expectArraysClose( + expectArraysEqual( la.bandPart(a,1,0), tf.tensor2d([[1, 0, 0, 0], [5, 6, 0, 0], [0,10,11, 0]]) ); - expectArraysClose( + expectArraysEqual( la.bandPart(a,1,1), tf.tensor2d([[1, 2, 0, 0], [5, 6, 7, 0], [0,10,11,12]]) ); - expectArraysClose( + expectArraysEqual( la.bandPart(a,1,2), tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], [0,10,11,12]]) ); - expectArraysClose( + expectArraysEqual( la.bandPart(a,1,2), tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], [0,10,11,12]]) ); for( const numUpper of [3,4,-1,-2] ) { - expectArraysClose( + expectArraysEqual( la.bandPart(a,1,numUpper), tf.tensor2d([[1, 2, 3, 4], [5, 6, 7, 8], @@ -292,32 +288,32 @@ describeWithFlags('bandPart', CPU_ENVS, () => { for( const numLower of [2,3,-1,-2]) { - expectArraysClose( + expectArraysEqual( la.bandPart(a,numLower,0), tf.tensor2d([[1, 0, 0, 0], [5, 6, 0, 0], [9,10,11, 0]]) ); - expectArraysClose( + expectArraysEqual( la.bandPart(a,numLower,1), tf.tensor2d([[1, 2, 0, 0], [5, 6, 7, 0], [9,10,11,12]]) ); - expectArraysClose( + expectArraysEqual( la.bandPart(a,numLower,2), tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], [9,10,11,12]]) ); - expectArraysClose( + expectArraysEqual( la.bandPart(a,numLower,2), tf.tensor2d([[1, 2, 3, 0], [5, 6, 7, 8], [9,10,11,12]]) ); for( const numUpper of [3,4,-1,-2] ) { - expectArraysClose( + expectArraysEqual( la.bandPart(a,numLower,numUpper), tf.tensor2d([[1, 2, 3, 4], [5, 6, 7, 8], @@ -389,9 +385,9 @@ describeWithFlags('triangularSolve', CPU_ENVS, () => { [14,15]]) )); - for( let run=0; run < 16; run++ ) + for( let run=0; run < 128; run++ ) { - const lShape = Array.from({ length: randInt(2,5) }, () => randInt(1,4) ), + const lShape = Array.from({ length: randInt(2,5) }, () => randInt(1,7) ), yShape = lShape.slice(); lShape[lShape.length-1] = lShape[lShape.length-2]; @@ -454,7 +450,14 @@ describeWithFlags('qr', CPU_ENVS, () => { expectArraysEqual( r.shape.slice( -2), fullMatrices ? 
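// economic mode computes only the leading l = min(m,n) rows of R,
// so R is expected to be [..,l,n] instead of the full [..,m,n]: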
[m,n] : [l,n] ); // TEST DECOMPOSITION (Q @ R == A) - expectArraysClose( q.matMul(r), a ); + try { + expectArraysClose( q.matMul(r), a ); + } catch(err) { + console.log('A'); a.print(); + console.log('Q'); q.print(); + console.log('R'); r.print(); + throw err; + } const qT = q.transpose(T); @@ -480,8 +483,6 @@ describeWithFlags('qr', CPU_ENVS, () => { const g = numDiff(f); const h = tf.grad(f); try { - // since we're already losing precision on numDiff, - // this is the best we can do expectTensorsRelativelyClose(g(a), h(a), /*rtol=*/1e-2, /*atol=*/1e-2); } catch(err) { @@ -525,7 +526,7 @@ describeWithFlags('qr', CPU_ENVS, () => { for( let run=0; run < 128; run++ ) { - const shape = Array.from({ length: randInt(2,5) }, () => randInt(1,4) ); + const shape = Array.from({ length: randInt(2,5) }, () => randInt(1,7) ); it( `random#${run}_${shape.join('x')}`, () => testWith( tf.randomUniform(shape,-1,+1) ) From a8e605c91fffc10fd61fd0e004c29f9269fb7715 Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Fri, 2 Nov 2018 10:06:18 +0100 Subject: [PATCH 11/13] Added check NaN/Infinity check to bandPart. --- src/ops/linalg_ops.ts | 12 ++++++++++++ src/ops/linalg_ops_test.ts | 34 ++++++++++++++++++++-------------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts index 7a916d99c3..97d3debdab 100644 --- a/src/ops/linalg_ops.ts +++ b/src/ops/linalg_ops.ts @@ -131,6 +131,10 @@ function adjoint_( a: T|TensorLike ): T { let $a = convertToTensor(a,'a','bandPart'); + if( $a.rank < 2 ) { + throw new Error(`adjoint(): a.rank = ${$a.rank} < 2.`); + } + const axes = Array.from( $a.shape, (_,i) => i ); axes[axes.length-2] = axes.length-1; axes[axes.length-1] = axes.length-2; @@ -165,6 +169,14 @@ function bandPart_( return ENV.engine.tidy( () => { const $a = convertToTensor(a,'a','bandPart'); + if( $a.rank < 2 ) { + throw new Error(`bandPart(): a.rank = ${$a.rank} < 2.`); + } + + if( ! isFinite($a.abs().max().dataSync()[0]) ) { + throw new Error(`bandPart(): NaN and Infinity not yet supported.`); + } + const [M,N] = $a.shape.slice(-2); if( !(numLower <= M) ) { diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index d7d0918412..50f4f5870a 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -198,7 +198,8 @@ describeWithFlags('adjoint', ALL_ENVS, () => { aT = tf.tensor2d([[1,4], [2,5], [3,6]],[3,2]); - expectArraysEqual( tf.linalg.adjoint(a), aT ); + // FIXME: shouldn't tf.transpose be lossless? + expectArraysClose( tf.linalg.adjoint(a), aT ); }); it('3x2x1', () => { const a = tf.tensor3d([[[1],[2]], @@ -207,13 +208,17 @@ describeWithFlags('adjoint', ALL_ENVS, () => { aT = tf.tensor3d([[[1,2]], [[3,4]], [[5,6]]], [3,1,2]); - expectArraysEqual( tf.linalg.adjoint(a), aT ); + // FIXME: shouldn't tf.transpose be lossless? + expectArraysClose( tf.linalg.adjoint(a), aT ); }); }); -describeWithFlags('bandPart', CPU_ENVS, () => { +describeWithFlags('bandPart', ALL_ENVS, () => { const la = tf.linalg; + // FIXME: shouldn't 1*x be lossless? It's even in the IEEE spec somewhere... 
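+ // (IEEE 754 does guarantee 1*x === x and 0*x === 0 exactly for finite x;
+ // the lossy cases are 0*Infinity and 0*NaN, which both yield NaN. That is
+ // presumably why bandPart() now rejects non-finite inputs up front.)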
+ const expectArraysEqual = expectArraysClose; + it('3x4', () => { const a = tf.tensor2d([ [1, 2, 3, 4], @@ -321,17 +326,18 @@ describeWithFlags('bandPart', CPU_ENVS, () => { ); } } - - for( const numUpper of [0,1,2,3,4,-1,-2] ) { - for( const numLower of [0,1,2,3, -1,-2] ) { - const w = tf.randomUniform(a.shape), - f = (x: Tensor) => { - return la.bandPart(x,numLower,numUpper).mul(w).mean() as Scalar; - }, - g = numDiff(f), - h = tf.grad(f); - expectArraysClose( g(a), h(a) ); - }} +// following test is only required for custom backend implementations +// +// for( const numUpper of [0,1,2,3,4,-1,-2] ) { +// for( const numLower of [0,1,2,3, -1,-2] ) { +// const w = tf.randomUniform(a.shape), +// f = (x: Tensor) => { +// return la.bandPart(x,numLower,numUpper).mul(w).mean() as Scalar; +// }, +// g = numDiff(f), +// h = tf.grad(f); +// expectArraysClose( g(a), h(a) ); +// }} }); }); From fb1c0d0dc1a085624eda216e6c761016509b44de Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Fri, 9 Nov 2018 17:19:42 +0100 Subject: [PATCH 12/13] Turned important assertions into errors, added code samples. --- src/ops/linalg_ops.ts | 244 +++++++++++++++++++++++++------------ src/ops/linalg_ops_test.ts | 24 ++-- 2 files changed, 181 insertions(+), 87 deletions(-) diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts index 87deda6e82..c401a2dd14 100644 --- a/src/ops/linalg_ops.ts +++ b/src/ops/linalg_ops.ts @@ -114,8 +114,21 @@ function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D { /** * Conjugates a tensor of matrices and then transposes the last two dimensions. - * The adjoint is also commonly known as the Hermitian Transpose. Does not yet - * work for complex data types. + * The adjoint is also commonly known as the Hermitian Transpose. + * + * ```js + * const a = tf.tensor3d([[[1, 2], + * [3, 4]], + * [[5, 6], + * [7, 8]]]); + * const aT = tf.linalg.adjoint(a); + * aT.print(); + * // Output: + * // [[[1, 3], + * // [2, 4]], + * // [[5, 7], + * // [6, 8]]] + * ``` * * @param a Tensor of shape [...,M,N]. The tensor of matrices that is to be * tranposed. @@ -148,7 +161,35 @@ function adjoint_( a: T|TensorLike ): T /** * Copies a tensor of matrices, setting everything outside a central band - * in each matrix to zero. + * in each matrix to zero. Does not yet support Infinity or NaN entries. + * + * ```js + * const a = tf.tensor2d([[11, 12, 13, 14], + * [21, 22, 23, 24], + * [31, 32, 33, 34], + * [41, 42, 43, 44]]); + * tf.linalg.bandPart(a,0,2); + * // Output: + * // [[11, 12, 13, 0], + * // [ 0, 22, 23, 24], + * // [ 0, 0, 33, 34], + * // [ 0, 0, 0, 44]] + * + * tf.linalg.bandPart(a,1,-1); + * // Output: + * // [[11, 12, 13, 14], + * // [21, 22, 23, 24], + * // [ 0, 32, 33, 34], + * // [ 0, 0, 43, 44]] + * ``` + * + * @param a Tensor of matrices from which the band part is extracted. + * @param numLower The number of subdiagonal lines to be copied. + * If set to `-1`, all entries below the diagonal are + * copied. + * @param numUpper The number of superdiagonal lines to be copied. + * If set to `-1`, all entries above the diagonal are + * copied. */ /** * @doc {heading:'Operations', @@ -236,7 +277,7 @@ function triangularSolveKernel( } const - rank = Math.max(l.rank, y.rank), + rank = l.rank, xShape = Array.from(l.shape); xShape[rank-2] = I; xShape[rank-1] = J; @@ -326,7 +367,7 @@ function triangularSolveKernel( /** * Solves a triangular linear equation system (LES). * - * @param l The triangular matrix of the . + * @param l The triangular matrix of the LES. 
* @param y The right-hand-side of the LES. * @param lower If set to `true`, `l` is interpreted as lower triangular * matrix. The strict upper triangular entries are ignore. @@ -417,19 +458,15 @@ function triangularSolve_( */ function qrEcoDecompKernel( a: Tensor ): [Tensor,Tensor] { - assert( - a.rank >= 2, - `qr(): input must have rank >= 2, got rank ${a.rank}.` - ); - assert( - ! a.dtype.startsWith('complex'), - `qr(): complex dtype not supported.` - ); - assert( - a.shape[a.rank-2] >= a.shape[a.rank-1], - `qr(): a.shape[-2] = ${a.shape[a.rank-2]}` - + ` < ${a.shape[a.rank-1]} = a.shape[-1].` - ); + if( a.rank < 2 ) { + throw new Error(`qrEco(): input must have rank >= 2, got rank ${a.rank}.`); + } + if( a.dtype !== 'float32' ) { + throw new Error(`qrEco(): only float32 currently supported as dtype.`); + } + if( a.shape[a.rank-2] < a.shape[a.rank-1] ) { + throw new Error(`qrEco(): a must have at least as many rows as columns`); + } const dtype = 'float32', // tslint:disable @@ -513,14 +550,12 @@ function qrEcoDecompKernel( a: Tensor ): [Tensor,Tensor] */ function qrFullDecompKernel( a: Tensor ): [Tensor,Tensor,Tensor] { - assert( - a.rank >= 2, - `Error in linalg.qr: input must have rank >= 2, got rank ${a.rank}.` - ); - assert( - ! a.dtype.startsWith('complex'), - `Error in linalg.qr: complex dtype not supported.` - ); + if( a.rank < 2 ) { + throw new Error(`qrEco(): input must have rank >= 2, got rank ${a.rank}.`); + } + if( a.dtype !== 'float32' ) { + throw new Error(`qrEco(): only float32 currently supported as dtype.`); + } const dtype = 'float32', // tslint:disable @@ -603,68 +638,119 @@ function qrFullBackpropKernel( q: Tensor, dq: Tensor, r: Tensor, dr: Tensor, cs: Tensor ): Tensor { - assert( q.rank === dq.rank, `q.rank == ${q.rank} != ${dq.rank} == dq.rank` ); - assert( q.rank === dr.rank, `q.rank == ${q.rank} != ${dr.rank} == dr.rank` ); - assert( q.rank === r.rank, `q.rank == ${q.rank} != ${ r.rank} == r.rank` ); + if( q.rank !== dq.rank ) { + throw new Error( + `qrFullBackprop(): q.rank == ${q.rank} != ${dq.rank} == dq.rank` + ); + } + if( q.rank !== dr.rank ) { + throw new Error( + `qrFullBackprop(): q.rank == ${q.rank} != ${dr.rank} == dr.rank` + ); + } + if( q.rank !== r.rank ) { + throw new Error( + `qrFullBackprop(): q.rank == ${q.rank} != ${ r.rank} == r.rank` + ); + } + + if( cs.rank !== 1 ) { + throw new Error(`qrFullBackprop(): cs.rank == ${cs.rank} != 1`); + } - assert( cs.rank === 1, `cs.rank == ${cs.rank} != 1` ); + const rank = q.rank; + + if( rank < 2 ) { + throw new Error( + `qrFullBackprop(): input must have rank >= 2, got rank ${rank}.` + ); + } - for( let i=q.rank-2; i-- > 0; ) + for( let i=rank-2; i-- > 0; ) { - assert( - q.shape[i] === dq.shape[i], - `q.shape[${i}] == ${q.shape[i]} != ${dq.shape[i]} == dq.shape[${i}]` + if( q.shape[i] !== dq.shape[i] ) { + throw new Error( + 'qrFullBackprop(): ' + + `q.shape[${i}] == ${q.shape[i]} != ${dq.shape[i]} == dq.shape[${i}]` + ); + } + if( q.shape[i] !== dr.shape[i] ) { + throw new Error( + 'qrFullBackprop(): ' + + `q.shape[${i}] == ${q.shape[i]} != ${dr.shape[i]} == dr.shape[${i}]` + ); + } + if( q.shape[i] !== r.shape[i] ) { + throw new Error( + 'qrFullBackprop(): ' + + `q.shape[${i}] == ${q.shape[i]} != ${ r.shape[i]} == r.shape[${i}]` + ); + } + } + + if( q.shape[rank-2] !== q.shape[rank-1] ) { + throw new Error( + 'qrFullBackprop(): ' + + `q.shape[-2] == ${q.shape[rank-2]} != ${ q.shape[rank-1]} == q.shape[-1]` ); - assert( - q.shape[i] === dr.shape[i], - `q.shape[${i}] == ${q.shape[i]} != ${dr.shape[i]} == 
dr.shape[${i}]` + } + if( q.shape[rank-2] !== dq.shape[rank-1] ) { + throw new Error( + 'qrFullBackprop(): ' + + `q.shape[-2] == ${q.shape[rank-2]} != ${dq.shape[rank-1]} == dq.shape[-1]` ); - assert( - q.shape[i] === r.shape[i], - `q.shape[${i}] == ${q.shape[i]} != ${ r.shape[i]} == r.shape[${i}]` + } + if( q.shape[rank-2] !== dq.shape[rank-2] ) { + throw new Error( + 'qrFullBackprop(): ' + + `q.shape[-2] == ${q.shape[rank-2]} != ${dq.shape[rank-2]} == dq.shape[-2]` + ); + } + if( r.shape[rank-2] !== q.shape[rank-1] ) { + throw new Error( + 'qrFullBackprop(): ' + + `r.shape[-2] == ${r.shape[rank-2]} != ${ q.shape[rank-1]} == q.shape[-1]` + ); + } + if( r.shape[rank-1] !== dr.shape[rank-1] ) { + throw new Error( + 'qrFullBackprop(): ' + + `r.shape[-1] == ${r.shape[rank-1]} != ${dr.shape[rank-1]} == dr.shape[-1]` + ); + } + if( r.shape[rank-2] !== dr.shape[rank-2] ) { + throw new Error( + 'qrFullBackprop(): ' + + `r.shape[-2] == ${r.shape[rank-2]} != ${dr.shape[rank-2]} == dr.shape[-2]` ); } - const rank = q.rank; - assert( - q.shape[rank-2] === q.shape[rank-1], - `q.shape[-2] == ${q.shape[rank-2]} != ${ q.shape[rank-1]} == q.shape[-1]` - ); - assert( - q.shape[rank-2] === dq.shape[rank-1], - `q.shape[-2] == ${q.shape[rank-2]} != ${dq.shape[rank-1]} == dq.shape[-1]` - ); - assert( - q.shape[rank-2] === dq.shape[rank-2], - `q.shape[-2] == ${q.shape[rank-2]} != ${dq.shape[rank-2]} == dq.shape[-2]` - ); - - assert( - r.shape[rank-2] === q.shape[rank-1], - `r.shape[-2] == ${r.shape[rank-2]} != ${ q.shape[rank-1]} == q.shape[-1]` - ); - assert( - r.shape[rank-1] === dr.shape[rank-1], - `r.shape[-1] == ${r.shape[rank-1]} != ${dr.shape[rank-1]} == dr.shape[-1]` - ); - assert( - r.shape[rank-2] === dr.shape[rank-2], - `r.shape[-2] == ${r.shape[rank-2]} != ${dr.shape[rank-2]} == dr.shape[-2]` - ); - assert( - q.dtype === dq.dtype, `q.dtype == ${q.dtype} == ${ dq.dtype} == dq.dtype` - ); - assert( - q.dtype === dr.dtype, `q.dtype == ${q.dtype} == ${ dr.dtype} == dr.dtype` - ); - assert( - q.dtype === r.dtype, `q.dtype == ${q.dtype} == ${ r.dtype} == r.dtype` - ); - assert( - q.dtype === cs.dtype, `q.dtype == ${q.dtype} == ${cs.dtype} == cs.dtype` - ); + if( q.dtype !== dq.dtype ) { + throw new Error( + `qrFullBackprop(): q.dtype == ${q.dtype} != ${dq.dtype} == dq.dtype` + ); + } + if( q.dtype !== dr.dtype ) { + throw new Error( + `qrFullBackprop(): q.dtype == ${q.dtype} != ${dr.dtype} == dr.dtype` + ); + } + if( q.dtype !== r.dtype ) { + throw new Error( + `qrFullBackprop(): q.dtype == ${q.dtype} != ${r.dtype} == r.dtype` + ); + } + if( q.dtype !== cs.dtype ) { + throw new Error( + `qrFullBackprop(): q.dtype == ${q.dtype} != ${cs.dtype} == cs.dtype` + ); + } - assert( ! q.dtype.startsWith('complex'), `Complex dtype not supported.`); + if( q.dtype !== 'float32' ) { + throw new Error( + `qrFullBackprop(): only float32 currently supported as dtype.` + ); + } const dtype ='float32', // tslint:disable diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index 50f4f5870a..af78aee1f7 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -30,14 +30,15 @@ const randInt = (from: number, until: number) => { }; /** - * Computes the gradients using finite differences. + * Computes the gradients using finite differences. Current + * implmentation uses an O(h⁴) central difference. * * SEE: https://en.wikipedia.org/wiki/Finite_difference * * FIXME this is terribly imprecise... wish there was * double precision support *hint hint*. 
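 *
 * For reference, the five-point stencil evaluated for each entry below is
 *
 *   f'(x) ≈ ( -f(x+2h) + 8*f(x+h) - 8*f(x-h) + f(x-2h) ) / (12*h)
 *
 * with a step size of h = max(|x|,1) * sqrt(EPSILON).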
*/ -const numDiff = (f: (_: Tensor) => Scalar) => (a: Tensor) => { +const numDiff = (f: (x: Tensor) => Scalar) => (a: Tensor) => { if( a.dtype !== 'float32' ) { throw new Error(`numDiff(): dtype=${a.dtype} not supported.`); } @@ -77,6 +78,10 @@ const numDiff = (f: (_: Tensor) => Scalar) => (a: Tensor) => { }); }; +/** + * An tensor equivalency assertion that uses a comparison operator + * that is very similar to NumPy's `is_close()` function. + */ function expectTensorsRelativelyClose( actual: Tensor, expected: Tensor, rtol?: number, atol?: number ): void @@ -199,6 +204,7 @@ describeWithFlags('adjoint', ALL_ENVS, () => { [2,5], [3,6]],[3,2]); // FIXME: shouldn't tf.transpose be lossless? + // Yet this fails on Travis with `expectArraysEqual`... expectArraysClose( tf.linalg.adjoint(a), aT ); }); it('3x2x1', () => { @@ -208,7 +214,6 @@ describeWithFlags('adjoint', ALL_ENVS, () => { aT = tf.tensor3d([[[1,2]], [[3,4]], [[5,6]]], [3,1,2]); - // FIXME: shouldn't tf.transpose be lossless? expectArraysClose( tf.linalg.adjoint(a), aT ); }); }); @@ -216,7 +221,9 @@ describeWithFlags('adjoint', ALL_ENVS, () => { describeWithFlags('bandPart', ALL_ENVS, () => { const la = tf.linalg; - // FIXME: shouldn't 1*x be lossless? It's even in the IEEE spec somewhere... + // FIXME: shouldn't 1*x be lossless? + // It's even in the IEEE spec somewhere... + // Yet this fails on Travis with `expectArraysEqual`... const expectArraysEqual = expectArraysClose; it('3x4', () => { @@ -357,11 +364,11 @@ describeWithFlags('triangularSolve', CPU_ENVS, () => { { const x = la.triangularSolve(L,y, lower, adjoint); const [a,b] = [y,tril.matMul(x)]; -// const [a,b] = broadcastMatrices( y, tril.matMul(x) ); expectArraysClose(a,b); } const x = la.triangularSolve(L,y, /*lower=*/false, adjoint); - const [a,b] = [y,triu.matMul(x)];//broadcastMatrices( y, triu.matMul(x) ); + const [a,b] = [y,triu.matMul(x)]; +// const [a,b] = broadcastMatrices( y, triu.matMul(x) ); expectArraysClose(a,b); for( const lower of [false,true,undefined] ) @@ -424,6 +431,7 @@ describeWithFlags('qr', CPU_ENVS, () => { const testWith = (a: Tensor) => { const [m,n] = a.shape.slice(-2), l = Math.min(m,n), + // Indices of matrix transpose. T = Array.from({ length: a.rank }, (_,i) => i ); T[T.length-2] = T.length-1; T[T.length-1] = T.length-2; @@ -493,8 +501,8 @@ describeWithFlags('qr', CPU_ENVS, () => { } catch(err) { console.log('fullMatrices:', fullMatrices); -// const [q,r] = tf.linalg.qr(a,fullMatrices); console.log('A:'); a .print(); +// const [q,r] = tf.linalg.qr(a,fullMatrices); // console.log('Q:'); q .print(); // console.log('R:'); r .print(); // console.log('G:'); g(a).print(); @@ -530,7 +538,7 @@ describeWithFlags('qr', CPU_ENVS, () => { it('2x3', () => testWith( tensor2d([[ 1, 2, 3], [-3,-2, 1]], [2, 3]) ) ); - for( let run=0; run < 128; run++ ) + for( let run=0; run < 128*1024; run++ ) { const shape = Array.from({ length: randInt(2,5) }, () => randInt(1,7) ); it( From 1bf683d21200d02ba75bfed00b98021c3c6aede1 Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Fri, 9 Nov 2018 17:24:16 +0100 Subject: [PATCH 13/13] Reduced the number of tests. 
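128*1024 randomized test cases per suite take far too long for CI.
128 runs still cover a broad sample of small batched matrix shapes
while keeping the suite's runtime reasonable.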
---
 src/ops/linalg_ops_test.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts
index af78aee1f7..a852c95439 100644
--- a/src/ops/linalg_ops_test.ts
+++ b/src/ops/linalg_ops_test.ts
@@ -538,7 +538,7 @@ describeWithFlags('qr', CPU_ENVS, () => {
   it('2x3', () => testWith( tensor2d([[ 1, 2, 3],
                                       [-3,-2, 1]], [2, 3]) ) );
 
-  for( let run=0; run < 128*1024; run++ )
+  for( let run=0; run < 128; run++ )
   {
     const shape = Array.from({ length: randInt(2,5) }, () => randInt(1,7) );
     it(