Diwa
Lightweight implementation of an Artificial Neural Network for resource-constrained environments
Loading...
Searching...
No Matches
diwa.h
Go to the documentation of this file.
1/*
2 * This file is part of the Diwa library.
3 * Copyright (c) 2024 Nathanne Isip
4 *
5 * Permission is hereby granted, free of charge, to any person obtaining a copy
6 * of this software and associated documentation files (the "Software"), to deal
7 * in the Software without restriction, including without limitation the rights
8 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 * copies of the Software, and to permit persons to whom the Software is
10 * furnished to do so, subject to the following conditions:
11 *
12 * The above copyright notice and this permission notice shall be included in
13 * all copies or substantial portions of the Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 * THE SOFTWARE.
22 */
23
41#ifndef DIWA_H
42#define DIWA_H
43
44#ifdef ARDUINO
45# include <Arduino.h>
46# include <FS.h>
47#else
48# include <fstream>
49# include <stdint.h>
50#endif
51
52#include <diwa_activations.h>
53#include <math.h>
54
/**
 * @brief Enumeration representing various error codes that may occur during
 *        the operation of the Diwa library.
 *
 * Enumerators are declared in this order, so NO_ERROR is 0 and each
 * following code is one greater than the previous one.
 */
typedef enum {
    NO_ERROR,             ///< Operation completed without error.
    INVALID_PARAM_VALUES, ///< Invalid parameter values were supplied.
    MODEL_READ_ERROR,     ///< Reading a model failed.
    MODEL_SAVE_ERROR,     ///< Saving a model failed.
    INVALID_MAGIC_NUMBER, ///< Model data carried an unexpected magic number.
    STREAM_NOT_OPEN,      ///< The given file/stream was not open.
    MALLOC_FAILED,        ///< A memory allocation failed.

    #ifdef ARDUINO
    NO_ESP_PSRAM          ///< Arduino builds only; presumably "ESP PSRAM unavailable" — confirm.
    #endif
} DiwaError;
74
93class Diwa final {
94private:
95 int inputNeurons;
96 int hiddenNeurons;
97 int hiddenLayers;
98 int outputNeurons;
100 int weightCount;
101 int neuronCount;
103 double *weights;
104 double *outputs;
105 double *deltas;
107 diwa_activation activation;
117 void randomizeWeights();
118
131 DiwaError initializeWeights();
132
144 bool testInference(double *testInput, double *testExpectedOutput);
145
146public:
154
162
180 int inputNeurons,
181 int hiddenLayers,
182 int hiddenNeurons,
183 int outputNeurons,
184 bool randomizeWeights = true
185 );
186
199 double* inference(double *inputs);
200
214 void train(
215 double learningRate,
216 double *inputNeurons,
217 double *outputNeurons
218 );
219
220 #ifdef ARDUINO
221
232 DiwaError loadFromFile(File annFile);
233
244 DiwaError saveToFile(File annFile);
245
246 #else
247
258 DiwaError loadFromFile(std::ifstream& annFile);
259
270 DiwaError saveToFile(std::ofstream& annFile);
271
272 #endif
273
287 double calculateAccuracy(double *testInput, double *testExpectedOutput, int epoch);
288
302 double calculateLoss(double *testInput, double *testExpectedOutput, int epoch);
303
316
330
343
359 int recommendedHiddenLayerCount(int numSamples, int alpha);
360};
361
362#endif // DIWA_H
Lightweight Feedforward Artificial Neural Network (ANN) library tailored for microcontrollers.
Definition diwa.h:93
double * inference(double *inputs)
Performs inference on the neural network.
double calculateAccuracy(double *testInput, double *testExpectedOutput, int epoch)
Calculates the accuracy of the neural network on test data.
DiwaError loadFromFile(std::ifstream &annFile)
Loads the neural network model from a file in a non-Arduino environment.
diwa_activation getActivationFunction() const
Retrieves the current activation function used by the neural network.
void train(double learningRate, double *inputNeurons, double *outputNeurons)
Trains the neural network using backpropagation.
int recommendedHiddenLayerCount(int numSamples, int alpha)
Calculates the recommended number of hidden layers based on the dataset size and complexity.
double calculateLoss(double *testInput, double *testExpectedOutput, int epoch)
Calculates the loss of the neural network on test data.
DiwaError initialize(int inputNeurons, int hiddenLayers, int hiddenNeurons, int outputNeurons, bool randomizeWeights=true)
Initializes the Diwa neural network with specified parameters.
~Diwa()
Destructor for the Diwa class.
Diwa()
Default constructor for the Diwa class.
DiwaError saveToFile(std::ofstream &annFile)
Saves the neural network model to a file in a non-Arduino environment.
void setActivationFunction(diwa_activation activation)
Sets the activation function for the neural network.
int recommendedHiddenNeuronCount()
Calculates the recommended number of hidden neurons based on the input and output neurons.
DiwaError
Enumeration representing various error codes that may occur during the operation of the Diwa library.
Definition diwa.h:61
@ MALLOC_FAILED
Definition diwa.h:68
@ MODEL_SAVE_ERROR
Definition diwa.h:65
@ STREAM_NOT_OPEN
Definition diwa.h:67
@ MODEL_READ_ERROR
Definition diwa.h:64
@ NO_ERROR
Definition diwa.h:62
@ INVALID_PARAM_VALUES
Definition diwa.h:63
@ INVALID_MAGIC_NUMBER
Definition diwa.h:66
Defines activation functions for use in the Diwa neural network.
double(* diwa_activation)(double)
Typedef for activation function pointer.
Definition diwa_activations.h:58