00001
00002
00003
00004
00005
00006
00007
00008
00009 #ifndef NEURAL_NET_H
00010 #define NEURAL_NET_H
00011
#include <climits>
#include <new>
#include <string>
#include <vector>
#include "neuron.h"
#include "input_file_buffer.h"
00016
00017 namespace tesseract {
00018
00019
00020 static const float kMinInputRange = 1e-6f;
00021
00022 class NeuralNet {
00023 public:
00024 NeuralNet();
00025 virtual ~NeuralNet();
00026
00027 static NeuralNet *FromFile(const string file_name);
00028
00029 static NeuralNet *FromInputBuffer(InputFileBuffer *ib);
00030
00031 template <typename Type> bool FeedForward(const Type *inputs,
00032 Type *outputs);
00033
00034
00035
00036 template <typename Type> bool GetNetOutput(const Type *inputs,
00037 int output_id,
00038 Type *output);
00039
00040 int in_cnt() const { return in_cnt_; }
00041 int out_cnt() const { return out_cnt_; }
00042
00043 protected:
00044 struct Node;
00045
00046 struct WeightedNode {
00047 Node *input_node;
00048 float input_weight;
00049 };
00050
00051
00052 struct Node {
00053 float out;
00054 float bias;
00055 int fan_in_cnt;
00056 WeightedNode *inputs;
00057 };
00058
00059
00060
00061 bool read_only_;
00062
00063 int in_cnt_;
00064
00065 int out_cnt_;
00066
00067 int neuron_cnt_;
00068
00069 int wts_cnt_;
00070
00071 Neuron *neurons_;
00072
00073
00074
00075
00076 static const int kWgtChunkSize = 0x10000;
00077
00078
00079 static const unsigned int kNetSignature = 0xFEFEABD0;
00080
00081 int alloc_wgt_cnt_;
00082
00083 vector<vector<float> *>wts_vec_;
00084
00085 bool auto_encoder_;
00086
00087 vector<float> inputs_max_;
00088
00089 vector<float> inputs_min_;
00090
00091 vector<float> inputs_mean_;
00092
00093 vector<float> inputs_std_dev_;
00094
00095
00096 vector<Node> fast_nodes_;
00097
00098 void Init();
00099
00100 void Clear() {
00101 for (int node = 0; node < neuron_cnt_; node++) {
00102 neurons_[node].Clear();
00103 }
00104 }
00105
00106 template<class ReadBuffType> bool ReadBinary(ReadBuffType *input_buff) {
00107
00108 Init();
00109
00110 unsigned int read_val;
00111 unsigned int auto_encode;
00112
00113 if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
00114 return false;
00115 }
00116 if (read_val != kNetSignature) {
00117 return false;
00118 }
00119 if (input_buff->Read(&auto_encode, sizeof(auto_encode)) !=
00120 sizeof(auto_encode)) {
00121 return false;
00122 }
00123 auto_encoder_ = auto_encode;
00124
00125 if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
00126 return false;
00127 }
00128 neuron_cnt_ = read_val;
00129 if (neuron_cnt_ <= 0) {
00130 return false;
00131 }
00132
00133 neurons_ = new Neuron[neuron_cnt_];
00134 if (neurons_ == NULL) {
00135 return false;
00136 }
00137
00138 if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
00139 return false;
00140 }
00141 in_cnt_ = read_val;
00142 if (in_cnt_ <= 0) {
00143 return false;
00144 }
00145
00146 if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
00147 return false;
00148 }
00149 out_cnt_ = read_val;
00150 if (out_cnt_ <= 0) {
00151 return false;
00152 }
00153
00154 for (int idx = 0; idx < neuron_cnt_; idx++) {
00155 neurons_[idx].set_id(idx);
00156
00157 if (idx < in_cnt_) {
00158 neurons_[idx].set_node_type(Neuron::Input);
00159 } else if (idx >= (neuron_cnt_ - out_cnt_)) {
00160 neurons_[idx].set_node_type(Neuron::Output);
00161 } else {
00162 neurons_[idx].set_node_type(Neuron::Hidden);
00163 }
00164 }
00165
00166 for (int node_idx = 0; node_idx < neuron_cnt_; node_idx++) {
00167
00168 if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
00169 return false;
00170 }
00171
00172 int fan_out_cnt = read_val;
00173 for (int fan_out_idx = 0; fan_out_idx < fan_out_cnt; fan_out_idx++) {
00174
00175 if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
00176 return false;
00177 }
00178
00179 if (!SetConnection(node_idx, read_val)) {
00180 return false;
00181 }
00182 }
00183 }
00184
00185 for (int node_idx = 0; node_idx < neuron_cnt_; node_idx++) {
00186
00187 if (!neurons_[node_idx].ReadBinary(input_buff)) {
00188 return false;
00189 }
00190 }
00191
00192 inputs_mean_.resize(in_cnt_);
00193 inputs_std_dev_.resize(in_cnt_);
00194 inputs_min_.resize(in_cnt_);
00195 inputs_max_.resize(in_cnt_);
00196
00197 if (input_buff->Read(&(inputs_mean_.front()),
00198 sizeof(inputs_mean_[0]) * in_cnt_) !=
00199 sizeof(inputs_mean_[0]) * in_cnt_) {
00200 return false;
00201 }
00202 if (input_buff->Read(&(inputs_std_dev_.front()),
00203 sizeof(inputs_std_dev_[0]) * in_cnt_) !=
00204 sizeof(inputs_std_dev_[0]) * in_cnt_) {
00205 return false;
00206 }
00207 if (input_buff->Read(&(inputs_min_.front()),
00208 sizeof(inputs_min_[0]) * in_cnt_) !=
00209 sizeof(inputs_min_[0]) * in_cnt_) {
00210 return false;
00211 }
00212 if (input_buff->Read(&(inputs_max_.front()),
00213 sizeof(inputs_max_[0]) * in_cnt_) !=
00214 sizeof(inputs_max_[0]) * in_cnt_) {
00215 return false;
00216 }
00217
00218 if (read_only_) {
00219 return CreateFastNet();
00220 }
00221 return true;
00222 }
00223
00224
00225 bool SetConnection(int from, int to);
00226
00227
00228 bool CreateFastNet();
00229
00230
00231
00232 float *AllocWgt(int wgt_cnt);
00233
00234 template <typename Type> bool FastFeedForward(const Type *inputs,
00235 Type *outputs);
00236
00237
00238
00239
00240 template <typename Type> bool FastGetNetOutput(const Type *inputs,
00241 int output_id,
00242 Type *output);
00243 };
00244 }
00245
#endif  // NEURAL_NET_H