Diwa
Lightweight implementation of an Artificial Neural Network for resource-constrained environments
Loading...
Searching...
No Matches
diwa.h
Go to the documentation of this file.
/*
 * This file is part of the Diwa library.
 * Copyright (c) 2024 Nathanne Isip
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
23
41#ifndef DIWA_H
42#define DIWA_H
43
44#ifdef ARDUINO
45# include <Arduino.h>
46# include <FS.h>
47#else
48# include <fstream>
49# include <stdint.h>
50#endif
51
52#include <diwa_activations.h>
53#include <math.h>
54
61typedef enum {
69} DiwaError;
70
89class Diwa final {
90private:
91 int inputNeurons;
92 int hiddenNeurons;
93 int hiddenLayers;
94 int outputNeurons;
96 int weightCount;
97 int neuronCount;
99 double *weights;
100 double *outputs;
101 double *deltas;
103 diwa_activation activation;
113 void randomizeWeights();
114
127 DiwaError initializeWeights();
128
140 bool testInference(double *testInput, double *testExpectedOutput);
141
142public:
149 Diwa();
150
157 ~Diwa();
158
176 int inputNeurons,
177 int hiddenLayers,
178 int hiddenNeurons,
179 int outputNeurons,
180 bool randomizeWeights = true
181 );
182
195 double* inference(double *inputs);
196
210 void train(
211 double learningRate,
212 double *inputNeurons,
213 double *outputNeurons
214 );
215
216 #ifdef ARDUINO
217
228 DiwaError loadFromFile(File annFile);
229
240 DiwaError saveToFile(File annFile);
241
242 #else
243
254 DiwaError loadFromFile(std::ifstream& annFile);
255
266 DiwaError saveToFile(std::ofstream& annFile);
267
268 #endif
269
283 double calculateAccuracy(double *testInput, double *testExpectedOutput, int epoch);
284
298 double calculateLoss(double *testInput, double *testExpectedOutput, int epoch);
299
311 void setActivationFunction(diwa_activation activation);
312
326
339
355 int recommendedHiddenLayerCount(int numSamples, int alpha);
356};
357
358#endif // DIWA_H
Lightweight Feedforward Artificial Neural Network (ANN) library tailored for microcontrollers.
Definition diwa.h:89
double * inference(double *inputs)
Perform inference on the neural network.
Definition diwa.cpp:130
double calculateAccuracy(double *testInput, double *testExpectedOutput, int epoch)
Calculates the accuracy of the neural network on test data.
Definition diwa.cpp:466
DiwaError loadFromFile(std::ifstream &annFile)
Load neural network model from file in non-Arduino environment.
Definition diwa.cpp:378
diwa_activation getActivationFunction() const
Retrieves the current activation function used by the neural network.
Definition diwa.cpp:488
void train(double learningRate, double *inputNeurons, double *outputNeurons)
Train the neural network using backpropagation.
Definition diwa.cpp:182
int recommendedHiddenLayerCount(int numSamples, int alpha)
Calculates the recommended number of hidden layers based on the dataset size and complexity.
Definition diwa.cpp:499
double calculateLoss(double *testInput, double *testExpectedOutput, int epoch)
Calculates the loss of the neural network on test data.
Definition diwa.cpp:475
DiwaError initialize(int inputNeurons, int hiddenLayers, int hiddenNeurons, int outputNeurons, bool randomizeWeights=true)
Initializes the Diwa neural network with specified parameters.
Definition diwa.cpp:65
~Diwa()
Destructor for the Diwa class.
Definition diwa.cpp:52
Diwa()
Default constructor for the Diwa class.
Definition diwa.cpp:47
DiwaError saveToFile(std::ofstream &annFile)
Save neural network model to file in non-Arduino environment.
Definition diwa.cpp:432
void setActivationFunction(diwa_activation activation)
Sets the activation function for the neural network.
Definition diwa.cpp:484
int recommendedHiddenNeuronCount()
Calculates the recommended number of hidden neurons based on the input and output neurons.
Definition diwa.cpp:492
DiwaError
Enumeration representing various error codes that may occur during the operation of the Diwa library.
Definition diwa.h:61
@ MALLOC_FAILED
Definition diwa.h:68
@ MODEL_SAVE_ERROR
Definition diwa.h:65
@ STREAM_NOT_OPEN
Definition diwa.h:67
@ MODEL_READ_ERROR
Definition diwa.h:64
@ NO_ERROR
Definition diwa.h:62
@ INVALID_PARAM_VALUES
Definition diwa.h:63
@ INVALID_MAGIC_NUMBER
Definition diwa.h:66
Defines activation functions for use in the Diwa neural network.
double(* diwa_activation)(double)
Typedef for activation function pointer.
Definition diwa_activations.h:59