001/*- 002 * Copyright 2017 Diamond Light Source Ltd. 003 * 004 * All rights reserved. This program and the accompanying materials 005 * are made available under the terms of the Eclipse Public License v1.0 006 * which accompanies this distribution, and is available at 007 * http://www.eclipse.org/legal/epl-v10.html 008 */ 009 010package org.eclipse.january.dataset; 011 012import java.util.Arrays; 013import java.util.List; 014 015/** 016 * Class to run over a single dataset with NumPy broadcasting to promote shapes 017 * which have lower rank and outputs to a second dataset 018 * @since 2.1 019 */ 020public class BooleanNullIterator extends BooleanIteratorBase { 021 022 /** 023 * @param a 024 * @param o (can be null for new dataset, or a) 025 */ 026 public BooleanNullIterator(Dataset a, Dataset o) { 027 this(a, o, false); 028 } 029 030 /** 031 * @param a 032 * @param o (can be null for new dataset, or a) 033 * @param createIfNull if true create the output dataset if that is null 034 * (by default, can create float or complex datasets) 035 */ 036 public BooleanNullIterator(Dataset a, Dataset o, boolean createIfNull) { 037 this(a, o, createIfNull, false, true); 038 } 039 040 /** 041 * @param a 042 * @param o (can be null for new dataset, or a) 043 * @param createIfNull if true create the output dataset if that is null 044 * @param allowInteger if true, can create integer datasets 045 * @param allowComplex if true, can create complex datasets 046 */ 047 public BooleanNullIterator(Dataset a, Dataset o, boolean createIfNull, boolean allowInteger, boolean allowComplex) { 048 super(true, a, null, o); 049 List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), o == null ? null : o.getShapeRef()); 050 051 BroadcastUtils.checkItemSize(a, o); 052 053 maxShape = fullShapes.remove(0); 054 055 oStride = null; 056 if (o != null && !Arrays.equals(maxShape, o.getShapeRef())) { 057 throw new IllegalArgumentException("Output does not match broadcasted shape"); 058 } 059 060 aShape = fullShapes.remove(0); 061 062 int rank = maxShape.length; 063 endrank = rank - 1; 064 065 aDataset = a.reshape(aShape); 066 aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape); 067 if (outputA) { 068 oStride = aStride; 069 oDelta = null; 070 oStep = 0; 071 } else if (o != null) { 072 oStride = BroadcastUtils.createBroadcastStrides(o, maxShape); 073 oDelta = new int[rank]; 074 oStep = o.getElementsPerItem(); 075 } else if (createIfNull) { 076 int is = aDataset.getElementsPerItem(); 077 Class<? extends Dataset> dc = aDataset.getClass(); 078 if (aDataset.isComplex() && !allowComplex) { 079 is = 1; 080 dc = InterfaceUtils.getBestFloatInterface(dc); 081 } else if (!aDataset.hasFloatingPointElements() && !allowInteger) { 082 dc = InterfaceUtils.getBestFloatInterface(dc); 083 } 084 oDataset = DatasetFactory.zeros(is, dc, maxShape); 085 oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape); 086 oDelta = new int[rank]; 087 oStep = is; 088 } else { 089 oDelta = null; 090 oStep = 0; 091 } 092 093 pos = new int[rank]; 094 aDelta = new int[rank]; 095 for (int j = endrank; j >= 0; j--) { 096 aDelta[j] = aStride[j] * aShape[j]; 097 if (oDelta != null) { 098 oDelta[j] = oStride[j] * maxShape[j]; 099 } 100 } 101 102 aStart = aDataset.getOffset(); 103 aMax = endrank < 0 ? aStep + aStart : Integer.MIN_VALUE; 104 oStart = oDelta == null ? 0 : oDataset.getOffset(); 105 reset(); 106 } 107 108 @Override 109 public boolean hasNext() { 110 int j = endrank; 111 for (; j >= 0; j--) { 112 pos[j]++; 113 index += aStride[j]; 114 if (oDelta != null) { 115 oIndex += oStride[j]; 116 } 117 if (pos[j] >= maxShape[j]) { 118 pos[j] = 0; 119 index -= aDelta[j]; // reset these dimensions 120 if (oDelta != null) { 121 oIndex -= oDelta[j]; 122 } 123 } else { 124 break; 125 } 126 } 127 if (j == -1) { 128 if (endrank >= 0) { 129 return false; 130 } 131 index += aStep; 132 if (oDelta != null) { 133 oIndex += oStep; 134 } 135 } 136 if (outputA) { 137 oIndex = index; 138 } 139 140 if (index == aMax) { 141 return false; 142 } 143 144 return true; 145 } 146 147 @Override 148 public void reset() { 149 for (int i = 0; i <= endrank; i++) { 150 pos[i] = 0; 151 } 152 153 if (endrank >= 0) { 154 pos[endrank] = -1; 155 index = aStart - aStride[endrank]; 156 oIndex = oStart - (oStride == null ? 0 : oStride[endrank]); 157 } else { 158 index = aStart - aStep; 159 oIndex = oStart - oStep; 160 } 161 } 162}