001/*-
002 * Copyright 2017 Diamond Light Source Ltd.
003 *
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 */
009
010package org.eclipse.january.dataset;
011
012import java.util.Arrays;
013import java.util.List;
014
015/**
016 * Class to run over a single dataset with NumPy broadcasting to promote shapes
017 * which have lower rank and outputs to a second dataset
018 * @since 2.1
019 */
020public class BooleanNullIterator extends BooleanIteratorBase {
021
022        /**
023         * @param a
024         * @param o (can be null for new dataset, or a)
025         */
026        public BooleanNullIterator(Dataset a, Dataset o) {
027                this(a, o, false);
028        }
029
030        /**
031         * @param a
032         * @param o (can be null for new dataset, or a)
033         * @param createIfNull if true create the output dataset if that is null
034         * (by default, can create float or complex datasets)
035         */
036        public BooleanNullIterator(Dataset a, Dataset o, boolean createIfNull) {
037                this(a, o, createIfNull, false, true);
038        }
039
040        /**
041         * @param a
042         * @param o (can be null for new dataset, or a)
043         * @param createIfNull if true create the output dataset if that is null
044         * @param allowInteger if true, can create integer datasets
045         * @param allowComplex if true, can create complex datasets
046         */
047        public BooleanNullIterator(Dataset a, Dataset o, boolean createIfNull, boolean allowInteger, boolean allowComplex) {
048                super(true, a, null, o);
049                List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), o == null ? null : o.getShapeRef());
050
051                BroadcastUtils.checkItemSize(a, o);
052
053                maxShape = fullShapes.remove(0);
054
055                oStride = null;
056                if (o != null && !Arrays.equals(maxShape, o.getShapeRef())) {
057                        throw new IllegalArgumentException("Output does not match broadcasted shape");
058                }
059
060                aShape = fullShapes.remove(0);
061
062                int rank = maxShape.length;
063                endrank = rank - 1;
064
065                aDataset = a.reshape(aShape);
066                aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape);
067                if (outputA) {
068                        oStride = aStride;
069                        oDelta = null;
070                        oStep = 0;
071                } else if (o != null) {
072                        oStride = BroadcastUtils.createBroadcastStrides(o, maxShape);
073                        oDelta = new int[rank];
074                        oStep = o.getElementsPerItem();
075                } else if (createIfNull) {
076                        int is = aDataset.getElementsPerItem();
077                        Class<? extends Dataset> dc = aDataset.getClass();
078                        if (aDataset.isComplex() && !allowComplex) {
079                                is = 1;
080                                dc = InterfaceUtils.getBestFloatInterface(dc);
081                        } else if (!aDataset.hasFloatingPointElements() && !allowInteger) {
082                                dc = InterfaceUtils.getBestFloatInterface(dc);
083                        }
084                        oDataset = DatasetFactory.zeros(is, dc, maxShape);
085                        oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape);
086                        oDelta = new int[rank];
087                        oStep = is;
088                } else {
089                        oDelta = null;
090                        oStep = 0;
091                }
092
093                pos = new int[rank];
094                aDelta = new int[rank];
095                for (int j = endrank; j >= 0; j--) {
096                        aDelta[j] = aStride[j] * aShape[j];
097                        if (oDelta != null) {
098                                oDelta[j] = oStride[j] * maxShape[j];
099                        }
100                }
101
102                aStart = aDataset.getOffset();
103                aMax = endrank < 0 ? aStep + aStart : Integer.MIN_VALUE;
104                oStart = oDelta == null ? 0 : oDataset.getOffset();
105                reset();
106        }
107
108        @Override
109        public boolean hasNext() {
110                int j = endrank;
111                for (; j >= 0; j--) {
112                        pos[j]++;
113                        index += aStride[j];
114                        if (oDelta != null) {
115                                oIndex += oStride[j];
116                        }
117                        if (pos[j] >= maxShape[j]) {
118                                pos[j] = 0;
119                                index -= aDelta[j]; // reset these dimensions
120                                if (oDelta != null) {
121                                        oIndex -= oDelta[j];
122                                }
123                        } else {
124                                break;
125                        }
126                }
127                if (j == -1) {
128                        if (endrank >= 0) {
129                                return false;
130                        }
131                        index += aStep;
132                        if (oDelta != null) {
133                                oIndex += oStep;
134                        }
135                }
136                if (outputA) {
137                        oIndex = index;
138                }
139
140                if (index == aMax) {
141                        return false;
142                }
143
144                return true;
145        }
146
147        @Override
148        public void reset() {
149                for (int i = 0; i <= endrank; i++) {
150                        pos[i] = 0;
151                }
152
153                if (endrank >= 0) {
154                        pos[endrank] = -1;
155                        index = aStart - aStride[endrank];
156                        oIndex = oStart - (oStride == null ? 0 : oStride[endrank]);
157                } else {
158                        index = aStart - aStep;
159                        oIndex = oStart - oStep;
160                }
161        }
162}